| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612 |
- # pyright: basic
- from __future__ import annotations
- from pathlib import Path
- import matplotlib.pyplot as plt
- import numpy as np
- import pandas as pd
- import torch
- from matplotlib.axes import Axes
# Easily editable plot text overrides, keyed by plot key.
# Each entry may set any of "title", "x_label" and "y_label"; a field that is
# missing (or a plot key that has no entry at all) falls back to the default
# text supplied by the plotting function (see _resolve_plot_text).
# Example:
# "performance_threshold": {
#     "title": "Custom Title",
#     "x_label": "Custom X",
#     "y_label": "Custom Y",
# }
# Common keys:
# - performance_threshold_accuracy
# - performance_threshold_f1
# - performance_uncertainty_cutoff_accuracy
# - performance_uncertainty_cutoff_f1
# - performance_uncertainty_percentile_cutoff_accuracy
# - performance_uncertainty_percentile_cutoff_f1
# - calibration_reliability
# - noise_sensitivity_accuracy
# - noise_sensitivity_f1
# - noise_confidence
# - noise_standard_deviation
# - noise_predictive_uncertainty
# - boxplot
PLOT_TEXT_OVERRIDES: dict[str, dict[str, str]] = {}
def _resolve_plot_text(
    plot_key: str,
    default_title: str,
    default_x_label: str,
    default_y_label: str,
) -> tuple[str, str, str]:
    """Return the ``(title, x_label, y_label)`` to use for the plot *plot_key*.

    Fields present in ``PLOT_TEXT_OVERRIDES[plot_key]`` win; every other field
    keeps the default passed in by the caller.
    """
    resolved = {
        "title": default_title,
        "x_label": default_x_label,
        "y_label": default_y_label,
    }
    # Layer the (possibly empty) override on top of the defaults.
    resolved = {**resolved, **PLOT_TEXT_OVERRIDES.get(plot_key, {})}
    return resolved["title"], resolved["x_label"], resolved["y_label"]
def annotate_stats_box(
    ax: Axes,
    lines: list[str],
    location: str = "upper left",
) -> None:
    """Draw *lines* as a rounded, semi-transparent text box in a corner of *ax*.

    Args:
        ax: Axes to annotate; the box is positioned in axes coordinates.
        lines: Text lines, joined with newlines. Nothing is drawn when empty.
        location: "upper left", "upper right", "lower left" or "lower right".
            Any other value falls back to "upper left".
    """
    if not lines:
        return

    # corner name -> (x, y, horizontal alignment, vertical alignment)
    corners: dict[str, tuple[float, float, str, str]] = {
        "upper left": (0.02, 0.98, "left", "top"),
        "upper right": (0.98, 0.98, "right", "top"),
        "lower left": (0.02, 0.02, "left", "bottom"),
        "lower right": (0.98, 0.02, "right", "bottom"),
    }
    corner = location if location in corners else "upper left"
    x_pos, y_pos, h_align, v_align = corners[corner]

    box_style = {
        "boxstyle": "round,pad=0.35",
        "facecolor": "white",
        "edgecolor": "#555555",
        "alpha": 0.9,
    }
    text = "\n".join(lines)
    ax.text(
        x_pos,
        y_pos,
        text,
        transform=ax.transAxes,
        ha=h_align,
        va=v_align,
        fontsize=10,
        bbox=box_style,
    )
def _plot_correct_incorrect_bars(
    ax: Axes,
    x_values: pd.Series,
    n_correct: pd.Series,
    n_incorrect: pd.Series,
) -> None:
    """Overlay faint correct/incorrect count bars on a twin y-axis of *ax*.

    Correct counts are drawn upwards in green and incorrect counts downwards
    in red, on a transparent twin axis without ticks or grid.

    Args:
        ax: Axes holding the main curves; a twin y-axis is created on it.
        x_values: Bar positions.
        n_correct: Correct-sample count per position.
        n_incorrect: Incorrect-sample count per position.

    Nothing is drawn when any of the three inputs is empty.
    """
    x = np.asarray(x_values, dtype=float)
    correct = np.asarray(n_correct, dtype=float)
    incorrect = np.asarray(n_incorrect, dtype=float)
    if x.size == 0 or correct.size == 0 or incorrect.size == 0:
        return

    # Bar width is 80% of the smallest gap between *distinct, finite* positions.
    # Duplicated or NaN positions are ignored so they cannot collapse the width
    # to 0 (invisible bars) or NaN.
    gaps = np.diff(np.unique(x[np.isfinite(x)]))
    width = float(gaps.min()) * 0.8 if gaps.size else 0.04

    # Symmetric y-limit from the largest finite count, floored at 1.0. Counts
    # may contain NaN (e.g. coerced values); NaN/Inf must never reach set_ylim,
    # which rejects non-finite limits.
    counts = np.concatenate((correct.ravel(), incorrect.ravel()))
    finite_counts = counts[np.isfinite(counts)]
    max_count = max(float(finite_counts.max()), 1.0) if finite_counts.size else 1.0

    bars_ax = ax.twinx()
    bars_ax.patch.set_alpha(0.0)  # keep the main axes visible underneath
    bars_ax.bar(
        x,
        correct,
        width=width,
        color="#2ca02c",
        alpha=0.2,
        label="correct",
        zorder=0,
        align="center",
    )
    bars_ax.bar(
        x,
        -incorrect,  # mirrored below the zero line
        width=width,
        color="#d62728",
        alpha=0.2,
        label="incorrect",
        zorder=0,
        align="center",
    )
    bars_ax.axhline(0.0, color="gray", linewidth=0.8, alpha=0.4)
    bars_ax.set_ylim(-1.15 * max_count, 1.15 * max_count)
    bars_ax.set_yticks([])
    bars_ax.grid(False)
- def save_coverage_bar_plot(
- x_values: pd.Series | np.ndarray,
- n_correct: pd.Series | np.ndarray,
- n_incorrect: pd.Series | np.ndarray,
- x_label: str,
- title: str,
- output_path: Path,
- ) -> None:
- """Save a standalone bar chart showing sample counts (correct vs incorrect)."""
- x = np.asarray(x_values, dtype=float)
- correct = np.asarray(n_correct, dtype=float)
- incorrect = np.asarray(n_incorrect, dtype=float)
- if x.size == 0 or correct.size == 0 or incorrect.size == 0:
- return
- width = float(np.diff(np.sort(x)).min()) * 0.8 if x.size > 1 else 0.04
- max_count = float(max(np.nanmax(correct), np.nanmax(incorrect), 1.0))
- fig, ax = plt.subplots(figsize=(10, 5))
- ax.bar(
- x,
- correct,
- width=width,
- color="#2ca02c",
- alpha=0.6,
- label="correct",
- align="center",
- )
- ax.bar(
- x,
- -incorrect,
- width=width,
- color="#d62728",
- alpha=0.6,
- label="incorrect",
- align="center",
- )
- ax.axhline(0.0, color="gray", linewidth=0.8, alpha=0.4)
- ax.set_ylim(-1.15 * max_count, 1.15 * max_count)
- ax.set_xlabel(x_label)
- ax.set_ylabel("Sample Count")
- ax.set_title(title)
- ax.legend()
- ax.grid(True, alpha=0.3)
- fig.tight_layout()
- output_path.parent.mkdir(parents=True, exist_ok=True)
- fig.savefig(output_path)
- plt.close(fig)
def plots_dir(output_dir: Path) -> Path:
    """Return ``output_dir / "plots"``, creating it (and any parents) if missing."""
    target = output_dir.joinpath("plots")
    target.mkdir(parents=True, exist_ok=True)
    return target
def save_performance_threshold_plot(
    df: pd.DataFrame,
    backend: str,
    output_path: Path,
    metric_column: str,
    metric_label: str,
    plot_key: str,
) -> None:
    """Plot one metric against the decision threshold and save it.

    A companion ``<stem>_coverage.png`` bar chart of correct (tp + tn) versus
    incorrect (fp + fn) counts per threshold is written next to *output_path*.

    Args:
        df: Table with ``threshold``, ``tp``, ``tn``, ``fp``, ``fn`` and
            *metric_column* columns.
        backend: Name shown in the default titles.
        output_path: Destination of the metric plot; its parent directory is
            created if missing.
        metric_column: Column of *df* holding the metric values.
        metric_label: Legend entry and default y-axis label.
        plot_key: Key used to look up text overrides.
    """

    def _numeric(column: str) -> pd.Series:
        # Non-numeric entries become NaN instead of raising.
        return pd.to_numeric(df[column], errors="coerce")

    title, x_label, y_label = _resolve_plot_text(
        plot_key=plot_key,
        default_title=f"{metric_label} vs Decision Threshold ({backend})",
        default_x_label="Decision Threshold",
        default_y_label=metric_label,
    )
    correct_counts = _numeric("tp") + _numeric("tn")
    incorrect_counts = _numeric("fp") + _numeric("fn")

    figure, axis = plt.subplots(figsize=(10, 5))
    axis.plot(df["threshold"], df[metric_column], label=metric_label, marker="o")
    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)
    axis.set_title(title)
    axis.grid(True, alpha=0.3)
    axis.legend()
    figure.tight_layout()

    out_dir = output_path.parent
    out_dir.mkdir(parents=True, exist_ok=True)
    figure.savefig(output_path)
    plt.close(figure)

    # Companion chart with the sample counts behind each threshold.
    save_coverage_bar_plot(
        x_values=df["threshold"],
        n_correct=correct_counts,
        n_incorrect=incorrect_counts,
        x_label=x_label,
        title=f"Sample Distribution vs Decision Threshold ({backend})",
        output_path=out_dir / f"{output_path.stem}_coverage.png",
    )
def save_performance_threshold_pair_plot(
    df: pd.DataFrame,
    backend: str,
    output_path: Path,
    plot_key: str,
) -> None:
    """Plot accuracy and F1 against the decision threshold side by side.

    A companion ``<stem>_coverage.png`` bar chart of correct (tp + tn) versus
    incorrect (fp + fn) counts per threshold is written next to *output_path*.

    Args:
        df: Table with ``threshold``, ``accuracy``, ``f1``, ``tp``, ``tn``,
            ``fp`` and ``fn`` columns.
        backend: Name shown in the default titles.
        output_path: Destination of the paired plot; its parent directory is
            created if missing.
        plot_key: Key used to look up text overrides (the resolved y label is
            not used; each panel is labelled with its metric name).
    """

    def _numeric(column: str) -> pd.Series:
        # Non-numeric entries become NaN instead of raising.
        return pd.to_numeric(df[column], errors="coerce")

    title, x_label, _ = _resolve_plot_text(
        plot_key=plot_key,
        default_title=f"Accuracy and F1 vs Decision Threshold ({backend})",
        default_x_label="Decision Threshold",
        default_y_label="Accuracy/F1",
    )
    correct_counts = _numeric("tp") + _numeric("tn")
    incorrect_counts = _numeric("fp") + _numeric("fn")

    figure, panels = plt.subplots(1, 2, figsize=(14, 5), sharex=True)
    panel_specs = (
        ("accuracy", "Accuracy", "o"),
        ("f1", "F1", "s"),
    )
    for axis, (column, label, marker) in zip(panels, panel_specs):
        axis.plot(df["threshold"], df[column], label=label, marker=marker)
        axis.set_xlabel(x_label)
        axis.set_ylabel(label)
        axis.set_title(label)
        axis.grid(True, alpha=0.3)
        axis.legend()
    figure.suptitle(title)
    figure.tight_layout()

    out_dir = output_path.parent
    out_dir.mkdir(parents=True, exist_ok=True)
    figure.savefig(output_path)
    plt.close(figure)

    # Companion chart with the sample counts behind each threshold.
    save_coverage_bar_plot(
        x_values=df["threshold"],
        n_correct=correct_counts,
        n_incorrect=incorrect_counts,
        x_label=x_label,
        title=f"Sample Distribution vs Decision Threshold ({backend})",
        output_path=out_dir / f"{output_path.stem}_coverage.png",
    )
def save_uncertainty_cutoff_plot(
    cutoff_df: pd.DataFrame,
    title_prefix: str,
    x_label: str,
    output_path: Path,
    metric_column: str,
    metric_label: str,
    plot_key: str,
) -> None:
    """Plot a metric against the restriction level, one line per uncertainty type.

    A companion ``<stem>_coverage.png`` bar chart is written next to
    *output_path*. It shows the ``n_correct`` / ``n_incorrect`` counts of a
    single representative uncertainty type: the first one in sort order.

    Args:
        cutoff_df: Table with ``uncertainty_type``, ``restriction_level`` and
            *metric_column* columns; ``n_correct`` and ``n_incorrect`` are read
            for the coverage chart.
        title_prefix: Text appended after "vs" in the default titles.
        x_label: Default x-axis label.
        output_path: Destination of the metric plot; its parent directory is
            created if missing.
        metric_column: Column of *cutoff_df* holding the metric values.
        metric_label: Default y-axis label and first part of the default title.
        plot_key: Key used to look up text overrides.
    """
    title, x_label_final, y_label = _resolve_plot_text(
        plot_key=plot_key,
        default_title=f"{metric_label} vs {title_prefix}",
        default_x_label=x_label,
        default_y_label=metric_label,
    )
    fig, ax = plt.subplots(figsize=(10, 5))
    for uncertainty_name, group in cutoff_df.groupby("uncertainty_type"):
        g = group.sort_values("restriction_level")
        ax.plot(
            g["restriction_level"],
            g[metric_column],
            marker="o",
            label=uncertainty_name,
        )
    ax.set_title(title)
    ax.set_xlabel(x_label_final)
    ax.set_ylabel(y_label)
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.tight_layout()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path)
    plt.close(fig)

    # Generate separate coverage bar plot for one representative type.
    type_keys = cutoff_df["uncertainty_type"].dropna()
    if type_keys.empty:
        return
    # Keep the key in its native type: casting it to str would make the
    # equality filter below match nothing for non-string keys and silently
    # skip the coverage plot.
    rep_key = type_keys.sort_values().iloc[0]
    rep = cutoff_df[cutoff_df["uncertainty_type"] == rep_key].sort_values(
        "restriction_level"
    )
    coverage_path = output_path.parent / f"{output_path.stem}_coverage.png"
    save_coverage_bar_plot(
        x_values=rep["restriction_level"],
        n_correct=pd.to_numeric(rep["n_correct"], errors="coerce"),
        n_incorrect=pd.to_numeric(rep["n_incorrect"], errors="coerce"),
        x_label=x_label_final,
        title=f"Sample Coverage vs {title_prefix}",
        output_path=coverage_path,
    )
def save_uncertainty_cutoff_pair_plot(
    cutoff_df: pd.DataFrame,
    title_prefix: str,
    x_label: str,
    output_path: Path,
    plot_key: str,
) -> None:
    """Plot accuracy and F1 against the restriction level side by side.

    Each panel holds one line per uncertainty type. A companion
    ``<stem>_coverage.png`` bar chart is written next to *output_path*. It
    shows the ``n_correct`` / ``n_incorrect`` counts of a single
    representative uncertainty type: the first one in sort order.

    Args:
        cutoff_df: Table with ``uncertainty_type``, ``restriction_level``,
            ``accuracy`` and ``f1`` columns; ``n_correct`` and ``n_incorrect``
            are read for the coverage chart.
        title_prefix: Text appended after "vs" in the default titles.
        x_label: Default x-axis label.
        output_path: Destination of the paired plot; its parent directory is
            created if missing.
        plot_key: Key used to look up text overrides (the resolved y label is
            not used; each panel is labelled with its metric name).
    """
    title, x_label_final, _ = _resolve_plot_text(
        plot_key=plot_key,
        default_title=f"Accuracy and F1 vs {title_prefix}",
        default_x_label=x_label,
        default_y_label="Accuracy/F1",
    )
    fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharex=True)
    for uncertainty_name, group in cutoff_df.groupby("uncertainty_type"):
        g = group.sort_values("restriction_level")
        axes[0].plot(
            g["restriction_level"], g["accuracy"], marker="o", label=uncertainty_name
        )
        axes[1].plot(
            g["restriction_level"], g["f1"], marker="s", label=uncertainty_name
        )
    axes[0].set_title("Accuracy")
    axes[1].set_title("F1")
    for ax, metric_label in [(axes[0], "Accuracy"), (axes[1], "F1")]:
        ax.set_xlabel(x_label_final)
        ax.set_ylabel(metric_label)
        ax.grid(True, alpha=0.3)
        ax.legend()
    fig.suptitle(title)
    fig.tight_layout()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path)
    plt.close(fig)

    # Generate separate coverage bar plot for one representative type.
    type_keys = cutoff_df["uncertainty_type"].dropna()
    if type_keys.empty:
        return
    # Keep the key in its native type: casting it to str would make the
    # equality filter below match nothing for non-string keys and silently
    # skip the coverage plot.
    rep_key = type_keys.sort_values().iloc[0]
    rep = cutoff_df[cutoff_df["uncertainty_type"] == rep_key].sort_values(
        "restriction_level"
    )
    coverage_path = output_path.parent / f"{output_path.stem}_coverage.png"
    save_coverage_bar_plot(
        x_values=rep["restriction_level"],
        n_correct=pd.to_numeric(rep["n_correct"], errors="coerce"),
        n_incorrect=pd.to_numeric(rep["n_incorrect"], errors="coerce"),
        x_label=x_label_final,
        title=f"Sample Coverage vs {title_prefix}",
        output_path=coverage_path,
    )
def save_calibration_plot(per_bin: np.ndarray, backend: str, output_path: Path) -> None:
    """Save a reliability diagram for *backend*.

    Column 0 of *per_bin* is drawn on the x-axis and column 1 on the y-axis,
    next to a dashed diagonal labelled "ideal". Rows whose column 1 is NaN are
    left out.

    Args:
        per_bin: 2D array with one row per bin.
        backend: Name used for the curve's legend entry and the default title.
        output_path: Destination file; its parent directory is created if
            missing.
    """
    has_value = np.logical_not(np.isnan(per_bin[:, 1]))
    bin_x = per_bin[has_value, 0]
    bin_y = per_bin[has_value, 1]

    title, x_label, y_label = _resolve_plot_text(
        plot_key="calibration_reliability",
        default_title=f"Reliability Diagram ({backend})",
        default_x_label="Mean Predicted Probability",
        default_y_label="Empirical Fraction Positive",
    )

    figure, axis = plt.subplots(figsize=(6, 6))
    # Reference diagonal first, then the measured curve.
    axis.plot([0, 1], [0, 1], linestyle="--", color="gray", label="ideal")
    axis.plot(bin_x, bin_y, marker="o", label=backend)
    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)
    axis.set_title(title)
    axis.legend()
    axis.grid(True, alpha=0.3)
    figure.tight_layout()

    output_path.parent.mkdir(parents=True, exist_ok=True)
    figure.savefig(output_path)
    plt.close(figure)
def save_boxplot(
    data: list[np.ndarray],
    tick_labels: list[str],
    x_label: str,
    y_label: str,
    title: str,
    output_path: Path,
) -> None:
    """Save a box plot with one box per entry of *data*.

    Every tick label gets a second line ``(n=<count>)`` with the number of
    values in the corresponding entry.

    Args:
        data: One array of values per box.
        tick_labels: One label per box, paired with *data* by position.
        x_label: Default x-axis label.
        y_label: Default y-axis label.
        title: Default title.
        output_path: Destination file; its parent directory is created if
            missing.

    The text can be overridden through the "boxplot" plot key.
    """
    title_final, x_label_final, y_label_final = _resolve_plot_text(
        plot_key="boxplot",
        default_title=title,
        default_x_label=x_label,
        default_y_label=y_label,
    )

    annotated_labels = []
    for label, values in zip(tick_labels, data):
        sample_count = len(np.asarray(values, dtype=float))
        annotated_labels.append(f"{label}\n(n={sample_count})")

    figure, axis = plt.subplots(figsize=(9, 5))
    axis.boxplot(data, tick_labels=annotated_labels)
    axis.set_xlabel(x_label_final)
    axis.set_ylabel(y_label_final)
    axis.set_title(title_final)
    axis.grid(True, axis="y", alpha=0.3)
    figure.tight_layout()

    output_path.parent.mkdir(parents=True, exist_ok=True)
    figure.savefig(output_path)
    plt.close(figure)
def _central_slice(volume: torch.Tensor) -> np.ndarray:
    """Return the middle slice along the first axis of a 3D volume as float64.

    A 5D input is first indexed at 0 along its leading axis, then a 4D input is
    indexed at 0 again, so inputs with leading batch and/or channel axes are
    reduced to their first 3D volume.

    Args:
        volume: Tensor with 3, 4 or 5 dimensions, on any device and of any
            dtype that torch can cast to float64.

    Returns:
        A new 2D ``float64`` array that does not share memory with *volume*.

    Raises:
        ValueError: If the tensor is not 3D after dropping the leading axes.
    """
    tensor = volume.detach()
    if tensor.ndim == 5:
        tensor = tensor[0]
    if tensor.ndim == 4:
        tensor = tensor[0]
    if tensor.ndim != 3:
        raise ValueError(
            f"Expected a 3D volume after squeezing batch/channel, got shape {tuple(tensor.shape)}"
        )
    # Slice first so that only one 2D plane is moved to the CPU.
    middle = tensor[tensor.shape[0] // 2]
    # Cast in torch rather than NumPy: Tensor.numpy() rejects dtypes NumPy has
    # no equivalent for (e.g. bfloat16). copy=True guarantees a fresh buffer
    # even when the input is already a float64 CPU tensor.
    return middle.to(device="cpu", dtype=torch.float64, copy=True).numpy()
def _normalize_for_display(image: np.ndarray) -> np.ndarray:
    """Rescale *image* to ``[0, 1]`` using its 1st-99th percentile window.

    Values below the 1st percentile map to 0 and values above the 99th map
    to 1. The percentiles are computed over finite pixels only, so a few
    NaN/Inf pixels no longer blank the whole image: NaN pixels stay NaN and
    infinite pixels are clipped to the window edges.

    Args:
        image: Array of pixel intensities.

    Returns:
        A float array with the shape of *image*. It is all zeros when the
        image has no finite pixel or when the percentile window is empty
        (e.g. a constant image).
    """
    pixels = np.asarray(image)
    finite = pixels[np.isfinite(pixels)]
    if finite.size == 0:
        # Nothing to derive a window from (empty or fully non-finite input).
        return np.zeros_like(pixels, dtype=float)
    low, high = (float(v) for v in np.percentile(finite, [1, 99]))
    if high <= low:
        return np.zeros_like(pixels, dtype=float)
    clipped = np.clip(pixels, low, high)
    return (clipped - low) / (high - low)
def save_noise_example_grid(
    original_mri: torch.Tensor,
    noisy_by_sigma: list[tuple[float, torch.Tensor]],
    output_path: Path,
    title: str,
    max_images: int = 9,
    n_rows: int = 2,
) -> None:
    """Save a grid of central slices of noisy volumes, one panel per noise level.

    When there are more entries than *max_images*, a subset evenly spread over
    the list (first and last included) is shown.

    Args:
        original_mri: Accepted for interface compatibility; not used here.
        noisy_by_sigma: ``(sigma, volume)`` pairs in display order.
        output_path: Destination file; its parent directory is created if
            missing.
        title: Figure title.
        max_images: Upper bound on the number of panels (at least 1 is shown).
        n_rows: Number of grid rows (values below 1 are treated as 1).

    Nothing is written when *noisy_by_sigma* is empty.
    """
    if not noisy_by_sigma:
        return

    total = len(noisy_by_sigma)
    wanted = max(1, min(int(max_images), total))
    if wanted < total:
        # Sample indices across the full tested range, including first and last.
        spread = np.round(np.linspace(0, total - 1, num=wanted)).astype(int)
        chosen = set(spread.tolist())
        # Top up with the lowest unused indices should rounding have merged any.
        for candidate in range(total):
            if len(chosen) >= wanted:
                break
            chosen.add(candidate)
        selected = [noisy_by_sigma[position] for position in sorted(chosen)]
    else:
        selected = noisy_by_sigma

    panel_count = len(selected)
    rows = max(1, int(n_rows))
    cols = int(np.ceil(panel_count / rows))
    figure, grid = plt.subplots(rows, cols, figsize=(3.8 * cols, 3.2 * rows))
    panels = np.atleast_1d(grid).reshape(-1)

    for panel, (sigma, noisy_tensor) in zip(panels, selected):
        picture = _normalize_for_display(_central_slice(noisy_tensor))
        panel.imshow(picture, cmap="gray")
        panel.set_title(f"Noise factor={sigma:g}")
        panel.axis("off")
    # Hide the grid cells that received no image.
    for spare in panels[panel_count:]:
        spare.axis("off")

    figure.suptitle(title)
    figure.tight_layout()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    figure.savefig(output_path)
    plt.close(figure)
def save_clean_scan_image(
    original_mri: torch.Tensor,
    output_path: Path,
) -> None:
    """Save the normalized central slice of *original_mri* as a borderless image.

    Args:
        original_mri: Volume whose central slice is rendered in grayscale.
        output_path: Destination file; its parent directory is created if
            missing.
    """
    central = _central_slice(original_mri)
    picture = _normalize_for_display(central)

    figure, axis = plt.subplots(figsize=(4, 4))
    axis.imshow(picture, cmap="gray")
    axis.axis("off")
    # Let the image fill the whole canvas: no margins, no padding on save.
    figure.subplots_adjust(left=0, right=1, top=1, bottom=0)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    figure.savefig(output_path, bbox_inches="tight", pad_inches=0)
    plt.close(figure)
def save_noise_metrics_plot(
    x: pd.Series,
    y: pd.Series,
    legend_label: str,
    marker: str,
    x_label: str,
    y_label: str,
    title: str,
    output_path: Path,
    plot_key: str,
) -> None:
    """Save a single-line plot of *y* against *x*.

    Args:
        x: Values for the x-axis.
        y: Values for the y-axis.
        legend_label: Legend entry of the line.
        marker: Matplotlib marker used for the data points.
        x_label: Default x-axis label.
        y_label: Default y-axis label.
        title: Default title.
        output_path: Destination file; its parent directory is created if
            missing.
        plot_key: Key used to look up text overrides.
    """
    resolved_title, resolved_x, resolved_y = _resolve_plot_text(
        plot_key=plot_key,
        default_title=title,
        default_x_label=x_label,
        default_y_label=y_label,
    )

    figure, axis = plt.subplots(figsize=(10, 5))
    axis.plot(x, y, marker=marker, label=legend_label)
    axis.set_xlabel(resolved_x)
    axis.set_ylabel(resolved_y)
    axis.set_title(resolved_title)
    axis.grid(True, alpha=0.3)
    axis.legend()
    figure.tight_layout()

    output_path.parent.mkdir(parents=True, exist_ok=True)
    figure.savefig(output_path)
    plt.close(figure)
def save_metric_pair_plot(
    x: pd.Series,
    left_y: pd.Series,
    right_y: pd.Series,
    left_label: str,
    right_label: str,
    x_label: str,
    y_label: str,
    title: str,
    output_path: Path,
    plot_key: str,
) -> None:
    """Save two side-by-side line plots that share the same x values.

    Each panel uses its series label as legend entry, y-axis label and panel
    title; the resolved y label is not used.

    Args:
        x: Values for the shared x-axis.
        left_y: Series drawn in the left panel (circle markers).
        right_y: Series drawn in the right panel (square markers).
        left_label: Label of the left panel.
        right_label: Label of the right panel.
        x_label: Default x-axis label.
        y_label: Default y-axis label, passed to the override lookup only.
        title: Default figure title.
        output_path: Destination file; its parent directory is created if
            missing.
        plot_key: Key used to look up text overrides.
    """
    resolved_title, resolved_x, _ = _resolve_plot_text(
        plot_key=plot_key,
        default_title=title,
        default_x_label=x_label,
        default_y_label=y_label,
    )

    figure, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(14, 5), sharex=True)
    left_ax.plot(x, left_y, marker="o", label=left_label)
    right_ax.plot(x, right_y, marker="s", label=right_label)
    for panel, caption in ((left_ax, left_label), (right_ax, right_label)):
        panel.set_xlabel(resolved_x)
        panel.set_ylabel(caption)
        panel.set_title(caption)
        panel.grid(True, alpha=0.3)
        panel.legend()
    figure.suptitle(resolved_title)
    figure.tight_layout()

    output_path.parent.mkdir(parents=True, exist_ok=True)
    figure.savefig(output_path)
    plt.close(figure)
|