# pyright: basic from __future__ import annotations from pathlib import Path from typing import Any import matplotlib.pyplot as plt import numpy as np import pandas as pd import torch from bayesian_torch.utils.util import predictive_entropy from data.dataset import ( divide_dataset_by_patient_id, initalize_dataloaders, load_adni_data_from_file, ) from model.cnn import CNN3D from .metrics import calibration_stats, performance_at_threshold from .runtime import write_json def _enable_bayesian_sampling_mode(model: torch.nn.Module) -> None: model.eval() def _central_slice(volume: torch.Tensor) -> np.ndarray: tensor = volume.detach().cpu() if tensor.ndim == 5: tensor = tensor[0] if tensor.ndim == 4: tensor = tensor[0] if tensor.ndim != 3: raise ValueError( f"Expected a 3D volume after squeezing batch/channel, got shape {tuple(tensor.shape)}" ) center_index = tensor.shape[0] // 2 return tensor[center_index].numpy().astype(float) def _normalize_for_display(image: np.ndarray) -> np.ndarray: low = float(np.percentile(image, 1)) high = float(np.percentile(image, 99)) if high <= low: return np.zeros_like(image, dtype=float) clipped = np.clip(image, low, high) return (clipped - low) / (high - low) def _save_noise_example_grid( original_mri: torch.Tensor, noisy_by_sigma: list[tuple[float, torch.Tensor]], output_path: Path, title: str, ) -> None: if not noisy_by_sigma: return original_slice = _normalize_for_display(_central_slice(original_mri)) n_rows = len(noisy_by_sigma) fig, axes = plt.subplots(n_rows, 2, figsize=(8, 3.2 * n_rows)) if n_rows == 1: axes = np.array([axes]) for row_idx, (sigma, noisy_tensor) in enumerate(noisy_by_sigma): noisy_slice = _normalize_for_display(_central_slice(noisy_tensor)) ax_orig, ax_noisy = axes[row_idx] ax_orig.imshow(original_slice, cmap="gray") ax_orig.set_title(f"Original") ax_orig.axis("off") ax_noisy.imshow(noisy_slice, cmap="gray") ax_noisy.set_title(f"Noisy sigma={sigma:g}") ax_noisy.axis("off") fig.suptitle(title) fig.tight_layout() output_path.parent.mkdir(parents=True, exist_ok=True) fig.savefig(output_path) plt.close(fig) def _confidence_certainty(y_prob: np.ndarray) -> np.ndarray: # 2 * |p - 0.5| maps 0 -> very uncertain and 1 -> very certain. return 2.0 * np.abs(y_prob - 0.5) def _confidence_uncertainty(y_prob: np.ndarray) -> np.ndarray: # Invert certainty so that larger values mean more uncertainty, matching std. return 1.0 - _confidence_certainty(y_prob) def _apply_scaled_noise(volume: torch.Tensor, sigma: float) -> torch.Tensor: # Make sigma dimensionless with respect to MRI intensity scale. # sigma=1.0 means noise std equals total MRI intensity range, which is 65535 . return volume + (torch.randn_like(volume) * sigma * 65535) def _xls_pre(df: pd.DataFrame) -> pd.DataFrame: data = df[["Image Data ID", "Sex", "Age (current)"]].copy() data["Sex"] = data["Sex"].astype(str).str.strip() data = data.replace({"M": 0, "F": 1}) return data def _build_test_loader( config: dict[str, Any], root_dir: Path ) -> torch.utils.data.DataLoader: mri_files = (root_dir / config["data"]["mri_files_path"]).resolve().glob("*.nii") xls_file = (root_dir / config["data"]["xls_file_path"]).resolve() dataset = load_adni_data_from_file( mri_files, xls_file, device=config["training"]["device"], xls_preprocessor=_xls_pre, ) ptid_df = pd.read_csv(xls_file) ptid_df.columns = ptid_df.columns.str.strip() ptid_df = ptid_df[["Image Data ID", "PTID"]].dropna( subset=["Image Data ID", "PTID"] ) ptid_df["Image Data ID"] = ptid_df["Image Data ID"].astype(int) ptid_df["PTID"] = ptid_df["PTID"].astype(str).str.strip() ptid_df = ptid_df[ptid_df["PTID"] != ""] ptids = list(zip(ptid_df["Image Data ID"].tolist(), ptid_df["PTID"].tolist())) splits = divide_dataset_by_patient_id( dataset, ptids, tuple(config["data"]["data_splits"]), seed=int(config["data"]["seed"]), ) _, _, test_loader = initalize_dataloaders( splits, batch_size=int(config["training"]["batch_size"]), ) return test_loader def _load_ensemble_models(config: dict[str, Any]) -> list[torch.nn.Module]: model_dir = Path(config["output"]["ensemble_path"]) model_files = sorted(model_dir.glob("model_run_*.pt")) if not model_files: raise FileNotFoundError(f"No ensemble model files found in {model_dir}") models: list[torch.nn.Module] = [] for model_file in model_files: model = ( CNN3D( image_channels=int(config["data"]["image_channels"]), clin_data_channels=int(config["data"]["clin_data_channels"]), num_classes=int(config["data"]["num_classes"]), droprate=float(config["training"]["droprate"]), ) .float() .to(config["training"]["device"]) ) model.load_state_dict( torch.load(model_file, map_location=config["training"]["device"]), strict=False, ) model.eval() models.append(model) return models def _load_bayesian_model(config: dict[str, Any]) -> torch.nn.Module: device = str(config["training"]["device"]) try: from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn # type: ignore[import-untyped] except ImportError as e: raise ImportError( "bayesian_torch is required for bayesian noise analysis" ) from e model_path = Path(config["output"]["bayesian_path"]) / "model_bayesian.pt" if not model_path.exists(): raise FileNotFoundError(f"Bayesian model checkpoint not found: {model_path}") model = ( CNN3D( image_channels=int(config["data"]["image_channels"]), clin_data_channels=int(config["data"]["clin_data_channels"]), num_classes=int(config["data"]["num_classes"]), droprate=float(config["training"]["droprate"]), ) .float() .to(config["training"]["device"]) ) prior_params: dict[str, float | bool | str] = { "prior_mu": 0.0, "prior_sigma": 1.0, "posterior_mu_init": 0.0, "posterior_rho_init": -3.0, "type": "Reparameterization", "moped_enable": False, "moped_delta": 0.5, } dnn_to_bnn(model, prior_params) model.to(device) model.load_state_dict(torch.load(model_path, map_location=device), strict=False) model.to(device) _enable_bayesian_sampling_mode(model) return model def _infer_with_noise_ensemble( test_loader: torch.utils.data.DataLoader, models: list[torch.nn.Module], sigma: float, class_index: int, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: if not models: raise ValueError("No ensemble models were provided for noise inference") device = next(models[0].parameters()).device all_probs: list[float] = [] all_stds: list[float] = [] all_true: list[int] = [] with torch.no_grad(): for mri, xls, labels, _ in test_loader: mri_device = mri.float().to(device) xls_device = xls.float().to(device) labels_device = labels.to(device) noisy = _apply_scaled_noise(mri_device, sigma) preds = [] for model in models: out = model((noisy, xls_device)) preds.append(out[:, class_index].detach().cpu().numpy()) pred_mat = np.stack(preds, axis=0) mean = pred_mat.mean(axis=0) std = pred_mat.std(axis=0) true = labels_device[:, class_index].detach().cpu().numpy().astype(int) all_probs.extend(mean.tolist()) all_stds.extend(std.tolist()) all_true.extend(true.tolist()) return np.asarray(all_true), np.asarray(all_probs), np.asarray(all_stds) def _infer_with_noise_bayesian( test_loader: torch.utils.data.DataLoader, model: torch.nn.Module, sigma: float, class_index: int, mc_passes: int, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: device = next(model.parameters()).device all_probs: list[float] = [] all_stds: list[float] = [] all_true: list[int] = [] with torch.no_grad(): for mri, xls, labels, _ in test_loader: mri_device = mri.float().to(device) xls_device = xls.float().to(device) labels_device = labels.to(device) noisy = _apply_scaled_noise(mri_device, sigma) draws = [] for _ in range(mc_passes): out = model((noisy, xls_device)) draws.append(out.detach().cpu().numpy()) draw_mat = np.stack(draws, axis=0) # (mc_passes, batch, classes) mean = draw_mat.mean(axis=0)[:, class_index] entropy_uncertainty = predictive_entropy(draw_mat) true = labels_device[:, class_index].detach().cpu().numpy().astype(int) all_probs.extend(mean.tolist()) all_stds.extend(np.asarray(entropy_uncertainty, dtype=float).tolist()) all_true.extend(true.tolist()) return np.asarray(all_true), np.asarray(all_probs), np.asarray(all_stds) def run_noise_analysis( config: dict[str, Any], root_dir: Path, backend: str, output_dir: Path, class_index: int, noise_sigmas: list[float], threshold: float, calibration_bins: int, bayesian_mc_passes: int, ) -> dict[str, Any]: test_loader = _build_test_loader(config, root_dir) examples_dir = output_dir / "noise_examples" examples_dir.mkdir(parents=True, exist_ok=True) rows: list[dict[str, Any]] = [] if backend == "ensemble": models = _load_ensemble_models(config) example_rows: list[tuple[float, torch.Tensor]] = [] for sigma in noise_sigmas: y_true, y_prob, y_std = _infer_with_noise_ensemble( test_loader, models, sigma, class_index=class_index, ) perf = performance_at_threshold(y_true, y_prob, threshold) cal, _ = calibration_stats(y_true, y_prob, bins=calibration_bins) rows.append( { "uncertainty_metric": "std", "sigma": float(sigma), "accuracy": float(perf["accuracy"]), "f1": float(perf["f1"]), "ece": float(cal["ece"]), "mce": float(cal["mce"]), "mean_confidence_certainty": float( np.nanmean(_confidence_certainty(y_prob)) ), "mean_confidence_uncertainty": float( np.nanmean(_confidence_uncertainty(y_prob)) ), "mean_std": float(np.nanmean(y_std)), "mean_predictive_entropy": float("nan"), } ) with torch.no_grad(): sample = next(iter(test_loader)) original_mri = sample[0] device = next(models[0].parameters()).device original_device = original_mri.float().to(device) for sigma in noise_sigmas: noisy_mri = _apply_scaled_noise(original_device, sigma) example_rows.append((float(sigma), noisy_mri.detach().cpu())) _save_noise_example_grid( original_mri=original_mri, noisy_by_sigma=example_rows, output_path=examples_dir / f"{backend}_noise_examples.png", title=f"{backend.title()} Noise Examples", ) elif backend == "bayesian": model = _load_bayesian_model(config) example_rows = [] for sigma in noise_sigmas: y_true, y_prob, y_std = _infer_with_noise_bayesian( test_loader, model, sigma, class_index=class_index, mc_passes=bayesian_mc_passes, ) perf = performance_at_threshold(y_true, y_prob, threshold) cal, _ = calibration_stats(y_true, y_prob, bins=calibration_bins) rows.append( { "uncertainty_metric": "predictive_entropy", "sigma": float(sigma), "accuracy": float(perf["accuracy"]), "f1": float(perf["f1"]), "ece": float(cal["ece"]), "mce": float(cal["mce"]), "mean_confidence_certainty": float( np.nanmean(_confidence_certainty(y_prob)) ), "mean_confidence_uncertainty": float( np.nanmean(_confidence_uncertainty(y_prob)) ), # Compatibility field name retained for downstream code. "mean_std": float(np.nanmean(y_std)), "mean_predictive_entropy": float(np.nanmean(y_std)), } ) with torch.no_grad(): sample = next(iter(test_loader)) original_mri = sample[0] device = next(model.parameters()).device original_device = original_mri.float().to(device) for sigma in noise_sigmas: noisy_mri = _apply_scaled_noise(original_device, sigma) example_rows.append((float(sigma), noisy_mri.detach().cpu())) _save_noise_example_grid( original_mri=original_mri, noisy_by_sigma=example_rows, output_path=examples_dir / f"{backend}_noise_examples.png", title=f"{backend.title()} Noise Examples", ) else: raise ValueError(f"Unsupported backend for noise analysis: {backend}") df = pd.DataFrame(rows).sort_values("sigma") csv_path = output_dir / "noise_sensitivity.csv" df.to_csv(csv_path, index=False) fig, ax = plt.subplots(figsize=(10, 5)) ax.plot(df["sigma"], df["accuracy"], marker="o", label="accuracy") ax.plot(df["sigma"], df["f1"], marker="s", label="f1") ax.plot(df["sigma"], df["ece"], marker="^", label="ece") ax.set_xlabel("Gaussian Noise Sigma") ax.set_ylabel("Score") ax.set_title(f"Noise Sensitivity ({backend})") ax.grid(True, alpha=0.3) ax.legend() fig.tight_layout() plot_path = output_dir / "noise_sensitivity.png" fig.savefig(plot_path) plt.close(fig) fig_u, ax_u = plt.subplots(figsize=(10, 5)) ax_u.plot( df["sigma"], df["mean_confidence_uncertainty"], marker="o", label="confidence_uncertainty", ) ax_u.plot(df["sigma"], df["mean_std"], marker="s", label="std_uncertainty") ax_u.set_xlabel("Gaussian Noise Sigma") ax_u.set_ylabel("Uncertainty") ax_u.set_title(f"Uncertainty vs Noise ({backend})") ax_u.grid(True, alpha=0.3) ax_u.legend() fig_u.tight_layout() uncertainty_plot_path = output_dir / "noise_uncertainty.png" fig_u.savefig(uncertainty_plot_path) plt.close(fig_u) fig_c, ax_c = plt.subplots(figsize=(10, 5)) ax_c.plot( df["sigma"], df["mean_confidence_certainty"], marker="o", label="confidence_certainty", ) ax_c.set_xlabel("Gaussian Noise Sigma") ax_c.set_ylabel("Certainty") ax_c.set_title(f"Confidence Certainty vs Noise ({backend})") ax_c.grid(True, alpha=0.3) ax_c.legend() fig_c.tight_layout() certainty_plot_path = output_dir / "noise_confidence_certainty.png" fig_c.savefig(certainty_plot_path) plt.close(fig_c) out = { "table": str(csv_path), "plot": str(plot_path), "uncertainty_plot": str(uncertainty_plot_path), "certainty_plot": str(certainty_plot_path), "noise_sigmas": noise_sigmas, } write_json(output_dir / "noise_summary.json", out) return out