# pyright: basic
from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.axes import Axes
from matplotlib.figure import Figure

# Easily editable plot text overrides by plot key.
# Example:
#   "performance_threshold_accuracy": {
#       "title": "Custom Title",
#       "x_label": "Custom X",
#       "y_label": "Custom Y",
#   }
# Common keys:
# - performance_threshold_accuracy
# - performance_threshold_f1
# - performance_uncertainty_cutoff_accuracy
# - performance_uncertainty_cutoff_f1
# - performance_uncertainty_percentile_cutoff_accuracy
# - performance_uncertainty_percentile_cutoff_f1
# - calibration_reliability
# - noise_sensitivity_accuracy
# - noise_sensitivity_f1
# - noise_confidence
# - noise_standard_deviation
# - noise_predictive_uncertainty
# - boxplot
PLOT_TEXT_OVERRIDES: dict[str, dict[str, str]] = {}

# Axes-fraction anchor (x, y, horizontal alignment, vertical alignment) for
# each supported stats-box location name.
_STATS_BOX_ANCHORS: dict[str, tuple[float, float, str, str]] = {
    "upper left": (0.02, 0.98, "left", "top"),
    "upper right": (0.98, 0.98, "right", "top"),
    "lower left": (0.02, 0.02, "left", "bottom"),
    "lower right": (0.98, 0.02, "right", "bottom"),
}


def _resolve_plot_text(
    plot_key: str,
    default_title: str,
    default_x_label: str,
    default_y_label: str,
) -> tuple[str, str, str]:
    """Return ``(title, x_label, y_label)`` for ``plot_key``.

    Each element comes from ``PLOT_TEXT_OVERRIDES[plot_key]`` when that entry
    defines it, and from the corresponding default otherwise.
    """
    override = PLOT_TEXT_OVERRIDES.get(plot_key, {})
    return (
        override.get("title", default_title),
        override.get("x_label", default_x_label),
        override.get("y_label", default_y_label),
    )


@contextmanager
def _open_figure(*args: Any, **kwargs: Any) -> Iterator[tuple[Figure, Any]]:
    """Create a figure with ``plt.subplots`` and always close it on exit.

    Closing in ``finally`` keeps pyplot from accumulating open figures when
    plotting or saving raises part-way through.
    """
    fig, axes = plt.subplots(*args, **kwargs)
    try:
        yield fig, axes
    finally:
        plt.close(fig)


def _write_figure(fig: Figure, output_path: Path, **savefig_kwargs: Any) -> None:
    """Create the parent directory of ``output_path`` and save ``fig`` there."""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path, **savefig_kwargs)


def _coverage_path(output_path: Path) -> Path:
    """Return the sibling ``<stem>_coverage.png`` path for ``output_path``."""
    return output_path.parent / f"{output_path.stem}_coverage.png"


def annotate_stats_box(
    ax: Axes,
    lines: list[str],
    location: str = "upper left",
) -> None:
    """Draw ``lines`` as a rounded text box in a corner of ``ax``.

    Does nothing when ``lines`` is empty. An unrecognised ``location`` falls
    back to ``"upper left"``.
    """
    if not lines:
        return

    x, y, ha, va = _STATS_BOX_ANCHORS.get(
        location, _STATS_BOX_ANCHORS["upper left"]
    )
    ax.text(
        x,
        y,
        "\n".join(lines),
        transform=ax.transAxes,
        ha=ha,
        va=va,
        fontsize=10,
        bbox={
            "boxstyle": "round,pad=0.35",
            "facecolor": "white",
            "edgecolor": "#555555",
            "alpha": 0.9,
        },
    )


def _bar_geometry(
    x: np.ndarray,
    correct: np.ndarray,
    incorrect: np.ndarray,
) -> tuple[float, float]:
    """Return ``(bar_width, max_count)`` for a mirrored count bar chart.

    The width is 80% of the smallest gap between distinct finite x positions,
    or 0.04 when there are fewer than two such positions. Duplicates and
    non-finite positions are ignored so the width can be neither zero nor NaN.
    ``max_count`` is the largest finite count, floored at 1.0, so the derived
    axis limits stay finite even if every count is NaN.
    """
    distinct_x = np.unique(x[np.isfinite(x)])
    width = float(np.diff(distinct_x).min()) * 0.8 if distinct_x.size > 1 else 0.04

    counts = np.concatenate([correct.ravel(), incorrect.ravel()])
    counts = counts[np.isfinite(counts)]
    max_count = max(float(counts.max()), 1.0) if counts.size else 1.0
    return width, max_count


def _draw_mirrored_bars(
    ax: Axes,
    x: np.ndarray,
    correct: np.ndarray,
    incorrect: np.ndarray,
    **bar_kwargs: Any,
) -> None:
    """Draw correct counts upward and incorrect counts downward on ``ax``.

    Also draws a zero line and sets symmetric y-limits with 15% headroom.
    ``bar_kwargs`` are forwarded to both ``ax.bar`` calls.
    """
    width, max_count = _bar_geometry(x, correct, incorrect)
    ax.bar(
        x,
        correct,
        width=width,
        color="#2ca02c",
        label="correct",
        align="center",
        **bar_kwargs,
    )
    ax.bar(
        x,
        -incorrect,
        width=width,
        color="#d62728",
        label="incorrect",
        align="center",
        **bar_kwargs,
    )
    ax.axhline(0.0, color="gray", linewidth=0.8, alpha=0.4)
    ax.set_ylim(-1.15 * max_count, 1.15 * max_count)


def _plot_correct_incorrect_bars(
    ax: Axes,
    x_values: pd.Series,
    n_correct: pd.Series,
    n_incorrect: pd.Series,
) -> None:
    """Overlay faint correct/incorrect count bars on a twin y-axis of ``ax``.

    Returns without drawing when any of the three inputs is empty. The twin
    axis has a transparent background, no ticks and no grid.
    """
    x = np.asarray(x_values, dtype=float)
    correct = np.asarray(n_correct, dtype=float)
    incorrect = np.asarray(n_incorrect, dtype=float)
    if x.size == 0 or correct.size == 0 or incorrect.size == 0:
        return

    bars_ax = ax.twinx()
    bars_ax.patch.set_alpha(0.0)
    _draw_mirrored_bars(bars_ax, x, correct, incorrect, alpha=0.2, zorder=0)
    bars_ax.set_yticks([])
    bars_ax.grid(False)


def save_coverage_bar_plot(
    x_values: pd.Series | np.ndarray,
    n_correct: pd.Series | np.ndarray,
    n_incorrect: pd.Series | np.ndarray,
    x_label: str,
    title: str,
    output_path: Path,
) -> None:
    """Save a standalone bar chart showing sample counts (correct vs incorrect).

    Correct counts are drawn above zero and incorrect counts below. Nothing is
    written when any of the three inputs is empty.
    """
    x = np.asarray(x_values, dtype=float)
    correct = np.asarray(n_correct, dtype=float)
    incorrect = np.asarray(n_incorrect, dtype=float)
    if x.size == 0 or correct.size == 0 or incorrect.size == 0:
        return

    with _open_figure(figsize=(10, 5)) as (fig, ax):
        _draw_mirrored_bars(ax, x, correct, incorrect, alpha=0.6)
        ax.set_xlabel(x_label)
        ax.set_ylabel("Sample Count")
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        _write_figure(fig, output_path)


def plots_dir(output_dir: Path) -> Path:
    """Return ``output_dir / "plots"``, creating it if necessary."""
    plots = output_dir / "plots"
    plots.mkdir(parents=True, exist_ok=True)
    return plots


def _threshold_outcome_counts(df: pd.DataFrame) -> tuple[pd.Series, pd.Series]:
    """Return ``(n_correct, n_incorrect)`` as ``tp + tn`` and ``fp + fn``.

    Columns are coerced to numeric; unparseable entries become NaN.
    """

    def numeric(column: str) -> pd.Series:
        return pd.to_numeric(df[column], errors="coerce")

    return numeric("tp") + numeric("tn"), numeric("fp") + numeric("fn")


def save_performance_threshold_plot(
    df: pd.DataFrame,
    backend: str,
    output_path: Path,
    metric_column: str,
    metric_label: str,
    plot_key: str,
) -> None:
    """Save ``metric_column`` against ``threshold`` plus a coverage bar plot.

    ``df`` must provide ``threshold``, ``metric_column`` and the ``tp``,
    ``tn``, ``fp``, ``fn`` columns. The coverage plot is written next to
    ``output_path`` as ``<stem>_coverage.png``.
    """
    title, x_label, y_label = _resolve_plot_text(
        plot_key=plot_key,
        default_title=f"{metric_label} vs Decision Threshold ({backend})",
        default_x_label="Decision Threshold",
        default_y_label=metric_label,
    )
    n_correct, n_incorrect = _threshold_outcome_counts(df)

    with _open_figure(figsize=(10, 5)) as (fig, ax):
        ax.plot(df["threshold"], df[metric_column], label=metric_label, marker="o")
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        ax.legend()
        fig.tight_layout()
        _write_figure(fig, output_path)

    # Generate separate coverage bar plot
    save_coverage_bar_plot(
        x_values=df["threshold"],
        n_correct=n_correct,
        n_incorrect=n_incorrect,
        x_label=x_label,
        title=f"Sample Distribution vs Decision Threshold ({backend})",
        output_path=_coverage_path(output_path),
    )


def save_performance_threshold_pair_plot(
    df: pd.DataFrame,
    backend: str,
    output_path: Path,
    plot_key: str,
) -> None:
    """Save side-by-side accuracy and F1 vs threshold panels plus coverage.

    ``df`` must provide ``threshold``, ``accuracy``, ``f1`` and the ``tp``,
    ``tn``, ``fp``, ``fn`` columns. Each panel is labelled with its own
    metric name, so a ``y_label`` override for ``plot_key`` is not drawn.
    """
    title, x_label, _ = _resolve_plot_text(
        plot_key=plot_key,
        default_title=f"Accuracy and F1 vs Decision Threshold ({backend})",
        default_x_label="Decision Threshold",
        default_y_label="Accuracy/F1",
    )
    n_correct, n_incorrect = _threshold_outcome_counts(df)

    panels = [("accuracy", "Accuracy", "o"), ("f1", "F1", "s")]
    with _open_figure(1, 2, figsize=(14, 5), sharex=True) as (fig, axes):
        for ax, (metric_col, metric_label, marker) in zip(axes, panels):
            ax.plot(
                df["threshold"], df[metric_col], label=metric_label, marker=marker
            )
            ax.set_xlabel(x_label)
            ax.set_ylabel(metric_label)
            ax.set_title(f"{metric_label}")
            ax.grid(True, alpha=0.3)
            ax.legend()
        fig.suptitle(title)
        fig.tight_layout()
        _write_figure(fig, output_path)

    # Generate separate coverage bar plot
    save_coverage_bar_plot(
        x_values=df["threshold"],
        n_correct=n_correct,
        n_incorrect=n_incorrect,
        x_label=x_label,
        title=f"Sample Distribution vs Decision Threshold ({backend})",
        output_path=_coverage_path(output_path),
    )


def _save_cutoff_coverage_plot(
    cutoff_df: pd.DataFrame,
    title_prefix: str,
    x_label: str,
    output_path: Path,
) -> None:
    """Save the coverage bar plot for one representative uncertainty type.

    The representative is the first ``uncertainty_type`` in sorted order
    (missing values ignored). Its rows, ordered by ``restriction_level``,
    supply the ``n_correct`` / ``n_incorrect`` counts. Nothing is written
    when there is no non-missing uncertainty type.
    """
    kinds = cutoff_df["uncertainty_type"].dropna()
    if kinds.empty:
        return

    # Compare against the raw value (not its string form) so non-string
    # uncertainty types still select their own rows.
    representative = kinds.sort_values().iloc[0]
    rep = cutoff_df[cutoff_df["uncertainty_type"] == representative].sort_values(
        "restriction_level"
    )
    save_coverage_bar_plot(
        x_values=rep["restriction_level"],
        n_correct=pd.to_numeric(rep["n_correct"], errors="coerce"),
        n_incorrect=pd.to_numeric(rep["n_incorrect"], errors="coerce"),
        x_label=x_label,
        title=f"Sample Coverage vs {title_prefix}",
        output_path=_coverage_path(output_path),
    )


def save_uncertainty_cutoff_plot(
    cutoff_df: pd.DataFrame,
    title_prefix: str,
    x_label: str,
    output_path: Path,
    metric_column: str,
    metric_label: str,
    plot_key: str,
) -> None:
    """Save ``metric_column`` vs ``restriction_level``, one line per type.

    ``cutoff_df`` is grouped by ``uncertainty_type``. A coverage bar plot for
    one representative type is written next to ``output_path`` as
    ``<stem>_coverage.png``.
    """
    title, x_label_final, y_label = _resolve_plot_text(
        plot_key=plot_key,
        default_title=f"{metric_label} vs {title_prefix}",
        default_x_label=x_label,
        default_y_label=metric_label,
    )

    with _open_figure(figsize=(10, 5)) as (fig, ax):
        for uncertainty_name, group in cutoff_df.groupby("uncertainty_type"):
            ordered = group.sort_values("restriction_level")
            ax.plot(
                ordered["restriction_level"],
                ordered[metric_column],
                marker="o",
                label=uncertainty_name,
            )
        ax.set_title(title)
        ax.set_xlabel(x_label_final)
        ax.set_ylabel(y_label)
        ax.grid(True, alpha=0.3)
        ax.legend()
        fig.tight_layout()
        _write_figure(fig, output_path)

    # Generate separate coverage bar plot
    _save_cutoff_coverage_plot(cutoff_df, title_prefix, x_label_final, output_path)


def save_uncertainty_cutoff_pair_plot(
    cutoff_df: pd.DataFrame,
    title_prefix: str,
    x_label: str,
    output_path: Path,
    plot_key: str,
) -> None:
    """Save side-by-side accuracy and F1 vs ``restriction_level`` panels.

    One line per ``uncertainty_type`` is drawn in each panel. Each panel is
    labelled with its own metric name, so a ``y_label`` override for
    ``plot_key`` is not drawn. A coverage bar plot for one representative
    type is written next to ``output_path`` as ``<stem>_coverage.png``.
    """
    title, x_label_final, _ = _resolve_plot_text(
        plot_key=plot_key,
        default_title=f"Accuracy and F1 vs {title_prefix}",
        default_x_label=x_label,
        default_y_label="Accuracy/F1",
    )

    with _open_figure(1, 2, figsize=(14, 5), sharex=True) as (fig, axes):
        for uncertainty_name, group in cutoff_df.groupby("uncertainty_type"):
            ordered = group.sort_values("restriction_level")
            axes[0].plot(
                ordered["restriction_level"],
                ordered["accuracy"],
                marker="o",
                label=uncertainty_name,
            )
            axes[1].plot(
                ordered["restriction_level"],
                ordered["f1"],
                marker="s",
                label=uncertainty_name,
            )

        for ax, metric_label in zip(axes, ("Accuracy", "F1")):
            ax.set_title(metric_label)
            ax.set_xlabel(x_label_final)
            ax.set_ylabel(metric_label)
            ax.grid(True, alpha=0.3)
            ax.legend()
        fig.suptitle(title)
        fig.tight_layout()
        _write_figure(fig, output_path)

    # Generate separate coverage bar plot
    _save_cutoff_coverage_plot(cutoff_df, title_prefix, x_label_final, output_path)


def save_calibration_plot(per_bin: np.ndarray, backend: str, output_path: Path) -> None:
    """Save a reliability diagram from a 2-D ``per_bin`` array.

    Column 0 of ``per_bin`` is plotted on the x-axis and column 1 on the
    y-axis; rows whose column 1 is NaN are skipped. A dashed diagonal marks
    the ideal line.
    """
    title, x_label, y_label = _resolve_plot_text(
        plot_key="calibration_reliability",
        default_title=f"Reliability Diagram ({backend})",
        default_x_label="Mean Predicted Probability",
        default_y_label="Empirical Fraction Positive",
    )

    populated = ~np.isnan(per_bin[:, 1])
    with _open_figure(figsize=(6, 6)) as (fig, ax):
        ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="ideal")
        ax.plot(
            per_bin[populated, 0], per_bin[populated, 1], marker="o", label=backend
        )
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        _write_figure(fig, output_path)


def save_boxplot(
    data: list[np.ndarray],
    tick_labels: list[str],
    x_label: str,
    y_label: str,
    title: str,
    output_path: Path,
) -> None:
    """Save a box plot of ``data`` with one box per entry.

    Each tick label is suffixed with the number of values in its group.
    Title and axis labels can be overridden through the ``"boxplot"`` key of
    ``PLOT_TEXT_OVERRIDES``.
    """
    title_final, x_label_final, y_label_final = _resolve_plot_text(
        plot_key="boxplot",
        default_title=title,
        default_x_label=x_label,
        default_y_label=y_label,
    )
    labels_with_n = [
        f"{label}\n(n={len(np.asarray(values, dtype=float))})"
        for label, values in zip(tick_labels, data)
    ]

    with _open_figure(figsize=(9, 5)) as (fig, ax):
        ax.boxplot(data, tick_labels=labels_with_n)
        ax.set_xlabel(x_label_final)
        ax.set_ylabel(y_label_final)
        ax.set_title(title_final)
        ax.grid(True, axis="y", alpha=0.3)
        fig.tight_layout()
        _write_figure(fig, output_path)


def _central_slice(volume: torch.Tensor) -> np.ndarray:
    """Return the middle slice along the first axis of a 3-D volume.

    A 5-D input has its first element taken, and then a 4-D input has its
    first element taken, before the 3-D check.

    Raises:
        ValueError: if the tensor is not 3-D after that reduction.
    """
    tensor = volume.detach().cpu()
    if tensor.ndim == 5:
        tensor = tensor[0]
    if tensor.ndim == 4:
        tensor = tensor[0]
    if tensor.ndim != 3:
        raise ValueError(
            "Expected a 3D volume after squeezing batch/channel, "
            f"got shape {tuple(tensor.shape)}"
        )
    center_index = tensor.shape[0] // 2
    return tensor[center_index].numpy().astype(float)


def _normalize_for_display(image: np.ndarray) -> np.ndarray:
    """Rescale ``image`` to [0, 1] between its 1st and 99th percentiles.

    Values outside that range are clipped. Returns an all-zero array when the
    two percentiles do not span a positive range.
    """
    low = float(np.percentile(image, 1))
    high = float(np.percentile(image, 99))
    if high <= low:
        return np.zeros_like(image, dtype=float)
    clipped = np.clip(image, low, high)
    return (clipped - low) / (high - low)


def save_noise_example_grid(
    original_mri: torch.Tensor,
    noisy_by_sigma: list[tuple[float, torch.Tensor]],
    output_path: Path,
    title: str,
    max_images: int = 9,
    n_rows: int = 2,
) -> None:
    """Save a grid of central slices, one per selected noise level.

    At most ``max_images`` entries of ``noisy_by_sigma`` are shown (at least
    one). When there are more entries than that, they are sampled evenly
    across the list, including the first and the last. ``original_mri`` is
    currently not used by this function. Nothing is written when
    ``noisy_by_sigma`` is empty.
    """
    if not noisy_by_sigma:
        return

    n_total = len(noisy_by_sigma)
    target = max(1, min(int(max_images), n_total))
    if target >= n_total:
        selected = noisy_by_sigma
    else:
        # Sample indices across the full tested range, including first and
        # last. With target < n_total the positions are more than 1 apart, so
        # rounding cannot merge two of them and no top-up pass is needed.
        positions = np.round(np.linspace(0, n_total - 1, num=target)).astype(int)
        selected = [noisy_by_sigma[i] for i in np.unique(positions).tolist()]

    n_images = len(selected)
    n_rows = max(1, int(n_rows))
    n_cols = int(np.ceil(n_images / n_rows))

    with _open_figure(
        n_rows, n_cols, figsize=(3.8 * n_cols, 3.2 * n_rows)
    ) as (fig, axes):
        axes_flat = np.atleast_1d(axes).reshape(-1)
        for ax, (sigma, noisy_tensor) in zip(axes_flat, selected):
            noisy_slice = _normalize_for_display(_central_slice(noisy_tensor))
            ax.imshow(noisy_slice, cmap="gray")
            ax.set_title(f"Noise factor={sigma:g}")
            ax.axis("off")

        # Blank out grid cells that have no image.
        for spare_ax in axes_flat[n_images:]:
            spare_ax.axis("off")

        fig.suptitle(title)
        fig.tight_layout()
        _write_figure(fig, output_path)


def save_clean_scan_image(
    original_mri: torch.Tensor,
    output_path: Path,
) -> None:
    """Save the normalized central slice of ``original_mri`` without axes.

    The image fills the whole figure and is saved with a tight bounding box
    and no padding.
    """
    image = _normalize_for_display(_central_slice(original_mri))

    with _open_figure(figsize=(4, 4)) as (fig, ax):
        ax.imshow(image, cmap="gray")
        ax.axis("off")
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        _write_figure(fig, output_path, bbox_inches="tight", pad_inches=0)


def save_noise_metrics_plot(
    x: pd.Series,
    y: pd.Series,
    legend_label: str,
    marker: str,
    x_label: str,
    y_label: str,
    title: str,
    output_path: Path,
    plot_key: str,
) -> None:
    """Save a single line plot of ``y`` against ``x``.

    ``title``, ``x_label`` and ``y_label`` are defaults that an entry for
    ``plot_key`` in ``PLOT_TEXT_OVERRIDES`` can replace.
    """
    title_final, x_label_final, y_label_final = _resolve_plot_text(
        plot_key=plot_key,
        default_title=title,
        default_x_label=x_label,
        default_y_label=y_label,
    )

    with _open_figure(figsize=(10, 5)) as (fig, ax):
        ax.plot(x, y, marker=marker, label=legend_label)
        ax.set_xlabel(x_label_final)
        ax.set_ylabel(y_label_final)
        ax.set_title(title_final)
        ax.grid(True, alpha=0.3)
        ax.legend()
        fig.tight_layout()
        _write_figure(fig, output_path)


def save_metric_pair_plot(
    x: pd.Series,
    left_y: pd.Series,
    right_y: pd.Series,
    left_label: str,
    right_label: str,
    x_label: str,
    y_label: str,
    title: str,
    output_path: Path,
    plot_key: str,
) -> None:
    """Save two side-by-side line plots that share the x-axis.

    The left panel shows ``left_y`` and the right panel ``right_y``; each
    panel uses its own label as both title and y-axis label, so ``y_label``
    (and any override of it) is not drawn. ``title`` is used as the figure
    title.
    """
    title_final, x_label_final, _ = _resolve_plot_text(
        plot_key=plot_key,
        default_title=title,
        default_x_label=x_label,
        default_y_label=y_label,
    )

    panels = [(left_y, left_label, "o"), (right_y, right_label, "s")]
    with _open_figure(1, 2, figsize=(14, 5), sharex=True) as (fig, axes):
        for ax, (series, name, marker) in zip(axes, panels):
            ax.plot(x, series, marker=marker, label=name)
            ax.set_xlabel(x_label_final)
            ax.set_ylabel(name)
            ax.set_title(name)
            ax.grid(True, alpha=0.3)
            ax.legend()
        fig.suptitle(title_final)
        fig.tight_layout()
        _write_figure(fig, output_path)