# pyright: basic
"""Matplotlib helpers that render evaluation figures and save them to disk.

Every ``save_*`` function builds one figure, writes it to ``output_path``
(creating missing parent directories) and closes it afterwards -- also when
plotting or saving fails -- so repeated calls never accumulate open figures
in pyplot's global figure registry.
"""

from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch


def plots_dir(output_dir: Path) -> Path:
    """Return ``output_dir / "plots"``, creating it and any parents if missing."""
    plots = output_dir / "plots"
    plots.mkdir(parents=True, exist_ok=True)
    return plots


@contextmanager
def _figure_saved_to(
    output_path: Path, **subplots_kwargs: Any
) -> Iterator[tuple[plt.Figure, Any]]:
    """Yield ``plt.subplots(**subplots_kwargs)`` and save the figure on exit.

    On normal exit the figure is laid out with ``tight_layout`` and written to
    *output_path* after its parent directories are created. The figure is
    closed in every case: without the ``finally`` an exception raised by the
    plotting code, ``mkdir`` or ``savefig`` (e.g. an unsupported file
    extension) would leave the figure registered with pyplot.
    """
    fig, axes = plt.subplots(**subplots_kwargs)
    try:
        yield fig, axes
        fig.tight_layout()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path)
    finally:
        plt.close(fig)


def save_performance_threshold_plot(
    df: pd.DataFrame, backend: str, output_path: Path
) -> None:
    """Plot accuracy and F1 against the decision threshold and save the figure.

    Args:
        df: Table with ``threshold``, ``accuracy`` and ``f1`` columns. Rows
            are drawn in their existing order (they are not sorted here).
        backend: Name shown in the plot title.
        output_path: Destination image file.
    """
    with _figure_saved_to(output_path, figsize=(10, 5)) as (_, ax):
        ax.plot(df["threshold"], df["accuracy"], label="accuracy", marker="o")
        ax.plot(df["threshold"], df["f1"], label="f1", marker="s")
        ax.set_xlabel("Threshold")
        ax.set_ylabel("Score")
        ax.set_title(f"Performance vs Threshold ({backend})")
        ax.grid(True, alpha=0.3)
        ax.legend()


def save_uncertainty_cutoff_plot(
    cutoff_df: pd.DataFrame,
    title_prefix: str,
    x_label: str,
    output_path: Path,
) -> None:
    """Plot accuracy (left) and F1 (right) against a restriction level.

    Args:
        cutoff_df: Table with ``uncertainty_type``, ``restriction_level``,
            ``accuracy`` and ``f1`` columns. Each uncertainty type becomes one
            line per subplot, with points sorted by ``restriction_level``.
        title_prefix: Text placed after "Accuracy vs " / "F1 vs " in the
            subplot titles.
        x_label: Label for the (shared) x-axis of both subplots.
        output_path: Destination image file.
    """
    with _figure_saved_to(
        output_path, nrows=1, ncols=2, figsize=(14, 5), sharex=True
    ) as (_, axes):
        for uncertainty_name, group in cutoff_df.groupby("uncertainty_type"):
            g = group.sort_values("restriction_level")
            axes[0].plot(
                g["restriction_level"], g["accuracy"], marker="o", label=uncertainty_name
            )
            axes[1].plot(
                g["restriction_level"], g["f1"], marker="s", label=uncertainty_name
            )
        axes[0].set_title(f"Accuracy vs {title_prefix}")
        axes[1].set_title(f"F1 vs {title_prefix}")
        for ax in axes:
            ax.set_xlabel(x_label)
            ax.grid(True, alpha=0.3)
            ax.legend()
        axes[0].set_ylabel("Accuracy")
        axes[1].set_ylabel("F1")


def save_calibration_plot(per_bin: np.ndarray, backend: str, output_path: Path) -> None:
    """Draw a reliability diagram against the ideal diagonal and save it.

    Args:
        per_bin: 2-D array read column-wise: column 0 is plotted on the x-axis
            (mean predicted probability), column 1 on the y-axis (empirical
            fraction positive). Rows whose column 1 is NaN are skipped
            (presumably bins that received no samples -- confirm with caller).
        backend: Legend label for the model curve and suffix of the title.
        output_path: Destination image file.
    """
    valid = ~np.isnan(per_bin[:, 1])
    with _figure_saved_to(output_path, figsize=(6, 6)) as (_, ax):
        ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="ideal")
        ax.plot(per_bin[valid, 0], per_bin[valid, 1], marker="o", label=backend)
        ax.set_xlabel("Mean Predicted Probability")
        ax.set_ylabel("Empirical Fraction Positive")
        ax.set_title(f"Reliability Diagram ({backend})")
        ax.legend()
        ax.grid(True, alpha=0.3)


def save_boxplot(
    data: list[np.ndarray],
    tick_labels: list[str],
    x_label: str,
    y_label: str,
    title: str,
    output_path: Path,
) -> None:
    """Draw one box per array in *data*, labelled by *tick_labels*, and save it.

    Note: the ``tick_labels`` keyword of ``Axes.boxplot`` requires
    matplotlib >= 3.9 (older releases called it ``labels``).
    """
    with _figure_saved_to(output_path, figsize=(9, 5)) as (_, ax):
        ax.boxplot(data, tick_labels=tick_labels)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.grid(True, axis="y", alpha=0.3)


def _central_slice(volume: torch.Tensor) -> np.ndarray:
    """Return the middle slice of a volume along its first spatial axis.

    A 5-D input first has index 0 of its leading (batch) axis taken; a 4-D
    result then has index 0 of its leading (channel) axis taken. What remains
    must be 3-D; the slice at ``shape[0] // 2`` is returned as a new float64
    array.

    Raises:
        ValueError: If the tensor is not 3-D after dropping batch/channel.
    """
    tensor = volume.detach().cpu()
    if tensor.ndim == 5:
        tensor = tensor[0]  # first item of the batch
    if tensor.ndim == 4:
        tensor = tensor[0]  # first channel
    if tensor.ndim != 3:
        raise ValueError(
            f"Expected a 3D volume after squeezing batch/channel, got shape {tuple(tensor.shape)}"
        )
    center_index = tensor.shape[0] // 2
    # Upcast in torch: Tensor.numpy() raises for bfloat16, which NumPy lacks.
    # copy=True keeps the result independent of `volume`'s storage.
    return tensor[center_index].to(torch.float64, copy=True).numpy()


def _normalize_for_display(image: np.ndarray) -> np.ndarray:
    """Contrast-stretch *image* to [0, 1] using its 1st/99th percentile window.

    The window is computed from finite pixels only, so a few NaN/inf values
    no longer turn the whole output into NaN. Values outside the window are
    clipped (+inf -> 1, -inf -> 0); NaN pixels stay NaN, which ``imshow``
    renders as "bad" pixels. Empty, constant, or entirely non-finite images
    map to zeros.
    """
    if image.size == 0:
        return np.zeros_like(image, dtype=float)
    finite = image[np.isfinite(image)]
    if finite.size == 0:
        return np.zeros_like(image, dtype=float)
    low = float(np.percentile(finite, 1))
    high = float(np.percentile(finite, 99))
    if high <= low:
        return np.zeros_like(image, dtype=float)
    clipped = np.clip(image, low, high)
    return (clipped - low) / (high - low)


def save_noise_example_grid(
    original_mri: torch.Tensor,
    noisy_by_sigma: list[tuple[float, torch.Tensor]],
    output_path: Path,
    title: str,
) -> None:
    """Save a grid comparing the central slice of a volume with noisy copies.

    Each row shows the original slice (left) next to one noisy version
    (right) titled with its noise factor; both are contrast-stretched for
    display. Nothing is written when *noisy_by_sigma* is empty.

    Args:
        original_mri: Volume accepted by ``_central_slice`` (3-D, or 4-/5-D
            with leading channel/batch axes).
        noisy_by_sigma: ``(noise_factor, noisy_volume)`` pairs, one row each.
        output_path: Destination image file.
        title: Figure-level title.
    """
    if not noisy_by_sigma:
        return
    original_slice = _normalize_for_display(_central_slice(original_mri))
    n_rows = len(noisy_by_sigma)
    # squeeze=False keeps `axes` 2-D (n_rows, 2) even when n_rows == 1.
    with _figure_saved_to(
        output_path,
        nrows=n_rows,
        ncols=2,
        figsize=(8, 3.2 * n_rows),
        squeeze=False,
    ) as (fig, axes):
        for (ax_orig, ax_noisy), (sigma, noisy_tensor) in zip(axes, noisy_by_sigma):
            noisy_slice = _normalize_for_display(_central_slice(noisy_tensor))
            ax_orig.imshow(original_slice, cmap="gray")
            ax_orig.set_title("Original")
            ax_orig.axis("off")
            ax_noisy.imshow(noisy_slice, cmap="gray")
            ax_noisy.set_title(f"Noisy factor={sigma:g}")
            ax_noisy.axis("off")
        fig.suptitle(title)


def save_noise_metrics_plot(
    x: pd.Series,
    y_by_label: list[tuple[pd.Series, str, str]],
    x_label: str,
    y_label: str,
    title: str,
    output_path: Path,
) -> None:
    """Plot several metric series against a shared x and save the figure.

    Args:
        x: Shared x-axis values.
        y_by_label: ``(series, marker, label)`` triples, one line each.
        x_label: X-axis label.
        y_label: Y-axis label.
        title: Plot title.
        output_path: Destination image file.
    """
    with _figure_saved_to(output_path, figsize=(10, 5)) as (_, ax):
        for series, marker, label in y_by_label:
            ax.plot(x, series, marker=marker, label=label)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        ax.legend()