# pyright: basic from __future__ import annotations from pathlib import Path from typing import Any import numpy as np import pandas as pd import torch from bayesian_torch.utils.util import predictive_entropy from model.cnn import CNN3D from .data_pipeline import build_dataset_and_test_loader from .metrics import calibration_stats, performance_at_threshold from .model_utils import configure_bayesian_sampling_mode from .plotting import ( plots_dir, save_noise_example_grid, save_noise_metrics_plot, ) from .runtime import write_json from .uncertainty import confidence_certainty, confidence_uncertainty def _apply_scaled_noise( volume: torch.Tensor, sigma: float, intensity_range: float ) -> torch.Tensor: # Scale by global MRI intensity range measured from holdout set. return volume + (torch.randn_like(volume) * sigma * intensity_range) def _uniform_sigma_schedule(noise_sigmas: list[float]) -> list[float]: if not noise_sigmas: raise ValueError("noise_sigmas must contain at least one value") ordered = np.array(sorted(float(s) for s in noise_sigmas), dtype=float) if len(ordered) == 1: return [float(ordered[0])] uniform = np.linspace( float(ordered[0]), float(ordered[-1]), num=len(ordered), dtype=float ) return [float(s) for s in uniform] def _load_ensemble_models(config: dict[str, Any]) -> list[torch.nn.Module]: model_dir = Path(config["output"]["ensemble_path"]) model_files = sorted(model_dir.glob("model_run_*.pt")) if not model_files: raise FileNotFoundError(f"No ensemble model files found in {model_dir}") models: list[torch.nn.Module] = [] for model_file in model_files: model = ( CNN3D( image_channels=int(config["data"]["image_channels"]), clin_data_channels=int(config["data"]["clin_data_channels"]), num_classes=int(config["data"]["num_classes"]), droprate=float(config["training"]["droprate"]), ) .float() .to(config["training"]["device"]) ) model.load_state_dict( torch.load(model_file, map_location=config["training"]["device"]), strict=False, ) model.eval() models.append(model) return models def _load_bayesian_model(config: dict[str, Any]) -> torch.nn.Module: device = str(config["training"]["device"]) try: from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn # type: ignore[import-untyped] except ImportError as e: raise ImportError( "bayesian_torch is required for bayesian noise analysis" ) from e model_path = Path(config["output"]["bayesian_path"]) / "model_bayesian.pt" if not model_path.exists(): raise FileNotFoundError(f"Bayesian model checkpoint not found: {model_path}") model = ( CNN3D( image_channels=int(config["data"]["image_channels"]), clin_data_channels=int(config["data"]["clin_data_channels"]), num_classes=int(config["data"]["num_classes"]), droprate=float(config["training"]["droprate"]), ) .float() .to(config["training"]["device"]) ) prior_params: dict[str, float | bool | str] = { "prior_mu": 0.0, "prior_sigma": 1.0, "posterior_mu_init": 0.0, "posterior_rho_init": -3.0, "type": "Reparameterization", "moped_enable": False, "moped_delta": 0.5, } dnn_to_bnn(model, prior_params) model.to(device) model.load_state_dict(torch.load(model_path, map_location=device), strict=False) model.to(device) configure_bayesian_sampling_mode(model, stochastic=False) return model def _infer_with_noise_ensemble( test_loader: torch.utils.data.DataLoader, models: list[torch.nn.Module], sigma: float, intensity_range: float, class_index: int, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: if not models: raise ValueError("No ensemble models were provided for noise inference") device = next(models[0].parameters()).device all_probs: list[float] = [] all_stds: list[float] = [] all_true: list[int] = [] with torch.no_grad(): for mri, xls, labels, _ in test_loader: mri_device = mri.float().to(device) xls_device = xls.float().to(device) labels_device = labels.to(device) noisy = _apply_scaled_noise(mri_device, sigma, intensity_range) preds = [] for model in models: out = model((noisy, xls_device)) preds.append(out[:, class_index].detach().cpu().numpy()) pred_mat = np.stack(preds, axis=0) mean = pred_mat.mean(axis=0) std = pred_mat.std(axis=0) true = labels_device[:, class_index].detach().cpu().numpy().astype(int) all_probs.extend(mean.tolist()) all_stds.extend(std.tolist()) all_true.extend(true.tolist()) return np.asarray(all_true), np.asarray(all_probs), np.asarray(all_stds) def _infer_with_noise_bayesian( test_loader: torch.utils.data.DataLoader, model: torch.nn.Module, sigma: float, intensity_range: float, class_index: int, mc_passes: int, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: device = next(model.parameters()).device all_probs: list[float] = [] all_stds: list[float] = [] all_true: list[int] = [] with torch.no_grad(): for mri, xls, labels, _ in test_loader: mri_device = mri.float().to(device) xls_device = xls.float().to(device) labels_device = labels.to(device) noisy = _apply_scaled_noise(mri_device, sigma, intensity_range) draws = [] for _ in range(mc_passes): out = model((noisy, xls_device)) draws.append(out.detach().cpu().numpy()) draw_mat = np.stack(draws, axis=0) # (mc_passes, batch, classes) mean = draw_mat.mean(axis=0)[:, class_index] entropy_uncertainty = predictive_entropy(draw_mat) true = labels_device[:, class_index].detach().cpu().numpy().astype(int) all_probs.extend(mean.tolist()) all_stds.extend(np.asarray(entropy_uncertainty, dtype=float).tolist()) all_true.extend(true.tolist()) return np.asarray(all_true), np.asarray(all_probs), np.asarray(all_stds) def run_noise_analysis( config: dict[str, Any], root_dir: Path, backend: str, output_dir: Path, class_index: int, noise_sigmas: list[float], threshold: float, calibration_bins: int, bayesian_mc_passes: int, ) -> dict[str, Any]: noise_sigmas = _uniform_sigma_schedule(noise_sigmas) dataset, test_loader = build_dataset_and_test_loader( config=config, root_dir=root_dir, seed=int(config["data"]["seed"]), ) # intensity_min, intensity_max, intensity_range = _compute_mri_intensity_range( # dataset # ) # Just use a fixed intensity range for noise scaling since all that matters is that it's consistent intensity_range = 10_000.0 out_plots_dir = plots_dir(output_dir) examples_dir = out_plots_dir / "noise_examples" examples_dir.mkdir(parents=True, exist_ok=True) rows: list[dict[str, Any]] = [] if backend == "ensemble": models = _load_ensemble_models(config) example_rows: list[tuple[float, torch.Tensor]] = [] for sigma in noise_sigmas: y_true, y_prob, y_std = _infer_with_noise_ensemble( test_loader, models, sigma, intensity_range, class_index=class_index, ) perf = performance_at_threshold(y_true, y_prob, threshold) cal, _ = calibration_stats(y_true, y_prob, bins=calibration_bins) rows.append( { "uncertainty_metric": "std", "noise_factor": float(sigma), "accuracy": float(perf["accuracy"]), "f1": float(perf["f1"]), "ece": float(cal["ece"]), "mce": float(cal["mce"]), "mean_confidence_certainty": float( np.nanmean(confidence_certainty(y_prob)) ), "mean_confidence_uncertainty": float( np.nanmean(confidence_uncertainty(y_prob)) ), "mean_std": float(np.nanmean(y_std)), "mean_predictive_entropy": float("nan"), "mri_intensity_range": float(intensity_range), } ) with torch.no_grad(): sample = next(iter(test_loader)) original_mri = sample[0] device = next(models[0].parameters()).device original_device = original_mri.float().to(device) for sigma in noise_sigmas: noisy_mri = _apply_scaled_noise(original_device, sigma, intensity_range) example_rows.append((float(sigma), noisy_mri.detach().cpu())) save_noise_example_grid( original_mri=original_mri, noisy_by_sigma=example_rows, output_path=examples_dir / f"{backend}_noise_examples.png", title=f"{backend.title()} Noise Examples", ) elif backend == "bayesian": model = _load_bayesian_model(config) example_rows = [] for sigma in noise_sigmas: y_true, y_prob, y_std = _infer_with_noise_bayesian( test_loader, model, sigma, intensity_range, class_index=class_index, mc_passes=bayesian_mc_passes, ) perf = performance_at_threshold(y_true, y_prob, threshold) cal, _ = calibration_stats(y_true, y_prob, bins=calibration_bins) rows.append( { "uncertainty_metric": "predictive_entropy", "noise_factor": float(sigma), "accuracy": float(perf["accuracy"]), "f1": float(perf["f1"]), "ece": float(cal["ece"]), "mce": float(cal["mce"]), "mean_confidence_certainty": float( np.nanmean(confidence_certainty(y_prob)) ), "mean_confidence_uncertainty": float( np.nanmean(confidence_uncertainty(y_prob)) ), # Compatibility field name retained for downstream code. "mean_std": float(np.nanmean(y_std)), "mean_predictive_entropy": float(np.nanmean(y_std)), "mri_intensity_range": float(intensity_range), } ) with torch.no_grad(): sample = next(iter(test_loader)) original_mri = sample[0] device = next(model.parameters()).device original_device = original_mri.float().to(device) for sigma in noise_sigmas: noisy_mri = _apply_scaled_noise(original_device, sigma, intensity_range) example_rows.append((float(sigma), noisy_mri.detach().cpu())) save_noise_example_grid( original_mri=original_mri, noisy_by_sigma=example_rows, output_path=examples_dir / f"{backend}_noise_examples.png", title=f"{backend.title()} Noise Examples", ) else: raise ValueError(f"Unsupported backend for noise analysis: {backend}") df = pd.DataFrame(rows).sort_values("noise_factor") csv_path = output_dir / "noise_sensitivity.csv" df.to_csv(csv_path, index=False) plot_path = out_plots_dir / "noise_sensitivity.png" uncertainty_plot_path = out_plots_dir / "noise_uncertainty.png" certainty_plot_path = out_plots_dir / "noise_confidence_certainty.png" save_noise_metrics_plot( x=df["noise_factor"], y_by_label=[ (df["accuracy"], "o", "accuracy"), (df["f1"], "s", "f1"), (df["ece"], "^", "ece"), ], x_label="Gaussian Noise Factor", y_label="Score", title=f"Noise Sensitivity ({backend})", output_path=plot_path, ) save_noise_metrics_plot( x=df["noise_factor"], y_by_label=[ (df["mean_confidence_uncertainty"], "o", "confidence_uncertainty"), (df["mean_std"], "s", "std_uncertainty"), ], x_label="Gaussian Noise Factor", y_label="Uncertainty", title=f"Uncertainty vs Noise ({backend})", output_path=uncertainty_plot_path, ) save_noise_metrics_plot( x=df["noise_factor"], y_by_label=[ (df["mean_confidence_certainty"], "o", "confidence_certainty"), ], x_label="Gaussian Noise Factor", y_label="Certainty", title=f"Confidence Certainty vs Noise ({backend})", output_path=certainty_plot_path, ) out = { "table": str(csv_path), "plot": str(plot_path), "uncertainty_plot": str(uncertainty_plot_path), "certainty_plot": str(certainty_plot_path), "noise_factors": noise_sigmas, "noise_sigmas": noise_sigmas, # "mri_intensity_min": float(intensity_min), # "mri_intensity_max": float(intensity_max), "mri_intensity_range": float(intensity_range), } write_json(output_dir / "noise_summary.json", out) return out