# pyright: basic
from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import torch
from bayesian_torch.utils.util import predictive_entropy

from model.cnn import CNN3D

from .data_pipeline import build_holdout_loader
from .metrics import calibration_stats, performance_at_threshold
from .model_utils import configure_bayesian_sampling_mode
from .plotting import (
    plots_dir,
    save_clean_scan_image,
    save_noise_example_grid,
    save_metric_pair_plot,
    save_noise_metrics_plot,
)
from .runtime import write_json

# Fixed scale used to turn a relative noise factor into an absolute noise
# standard deviation. It is a constant rather than a measured statistic: all
# that matters is that the same scale is used for every sigma and backend.
_FIXED_MRI_INTENSITY_RANGE = 10_000.0

_SUPPORTED_BACKENDS = ("ensemble", "bayesian")


def _apply_scaled_noise(
    volume: torch.Tensor, sigma: float, intensity_range: float
) -> torch.Tensor:
    """Return ``volume`` plus zero-mean Gaussian noise.

    The noise standard deviation is ``sigma * intensity_range``; the input
    tensor is not modified in place.
    """
    return volume + (torch.randn_like(volume) * sigma * intensity_range)


def _uniform_sigma_schedule(noise_sigmas: list[float]) -> list[float]:
    """Replace the requested sigmas with an evenly spaced schedule.

    The result has the same number of entries as ``noise_sigmas`` and spans
    the same ``[min, max]`` interval, but the points are equally spaced.

    Raises:
        ValueError: if ``noise_sigmas`` is empty.
    """
    if not noise_sigmas:
        raise ValueError("noise_sigmas must contain at least one value")
    ordered = np.array(sorted(float(s) for s in noise_sigmas), dtype=float)
    if len(ordered) == 1:
        return [float(ordered[0])]
    uniform = np.linspace(
        float(ordered[0]), float(ordered[-1]), num=len(ordered), dtype=float
    )
    return [float(s) for s in uniform]


def _build_cnn(config: dict[str, Any]) -> torch.nn.Module:
    """Instantiate a float32 ``CNN3D`` from ``config`` on the training device."""
    return (
        CNN3D(
            image_channels=int(config["data"]["image_channels"]),
            clin_data_channels=int(config["data"]["clin_data_channels"]),
            num_classes=int(config["data"]["num_classes"]),
            droprate=float(config["training"]["droprate"]),
        )
        .float()
        .to(config["training"]["device"])
    )


def _load_ensemble_models(config: dict[str, Any]) -> list[torch.nn.Module]:
    """Load every ``model_run_*.pt`` checkpoint under the ensemble directory.

    Returns:
        The models, in sorted checkpoint-filename order, each in eval mode.

    Raises:
        FileNotFoundError: if the directory contains no matching checkpoint.
    """
    model_dir = Path(config["output"]["ensemble_path"])
    model_files = sorted(model_dir.glob("model_run_*.pt"))
    if not model_files:
        raise FileNotFoundError(f"No ensemble model files found in {model_dir}")

    device = config["training"]["device"]
    models: list[torch.nn.Module] = []
    for model_file in model_files:
        model = _build_cnn(config)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        # NOTE(review): strict=False means missing/unexpected keys are accepted
        # without an error, so a mismatched checkpoint loads silently —
        # confirm this leniency is intended.
        model.load_state_dict(
            torch.load(model_file, map_location=device),
            strict=False,
        )
        model.eval()
        models.append(model)
    return models


def _load_bayesian_model(config: dict[str, Any]) -> torch.nn.Module:
    """Load the Bayesian checkpoint ``model_bayesian.pt`` into a converted CNN3D.

    Raises:
        ImportError: if ``bayesian_torch``'s ``dnn_to_bnn`` cannot be imported.
        FileNotFoundError: if the checkpoint file does not exist.
    """
    device = str(config["training"]["device"])
    try:
        from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn  # type: ignore[import-untyped]
    except ImportError as e:
        raise ImportError(
            "bayesian_torch is required for bayesian noise analysis"
        ) from e

    model_path = Path(config["output"]["bayesian_path"]) / "model_bayesian.pt"
    if not model_path.exists():
        raise FileNotFoundError(f"Bayesian model checkpoint not found: {model_path}")

    model = _build_cnn(config)
    prior_params: dict[str, float | bool | str] = {
        "prior_mu": 0.0,
        "prior_sigma": 1.0,
        "posterior_mu_init": 0.0,
        "posterior_rho_init": -3.0,
        "type": "Reparameterization",
        "moped_enable": False,
        "moped_delta": 0.5,
    }
    dnn_to_bnn(model, prior_params)
    # Move again after conversion so any layers created by dnn_to_bnn end up
    # on the target device before the checkpoint is copied in.
    model.to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints. strict=False silently accepts missing/unexpected keys.
    model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
    # NOTE(review): stochastic=False is passed here, yet the Bayesian inference
    # path performs several forward passes per batch — confirm the helper's
    # semantics so those passes are not all identical.
    configure_bayesian_sampling_mode(model, stochastic=False)
    return model


def _noisy_batch(
    mri: torch.Tensor,
    xls: torch.Tensor,
    labels: torch.Tensor,
    device: torch.device,
    sigma: float,
    intensity_range: float,
    class_index: int,
) -> tuple[torch.Tensor, torch.Tensor, np.ndarray]:
    """Move one batch to ``device``, add noise to the MRI, extract int labels.

    Returns:
        ``(noisy_mri, xls_on_device, true_labels)`` where ``true_labels`` is
        column ``class_index`` of ``labels`` as an integer NumPy array.
    """
    mri_device = mri.float().to(device)
    xls_device = xls.float().to(device)
    labels_device = labels.to(device)
    noisy = _apply_scaled_noise(mri_device, sigma, intensity_range)
    true = labels_device[:, class_index].detach().cpu().numpy().astype(int)
    return noisy, xls_device, true


def _infer_with_noise_ensemble(
    test_loader: torch.utils.data.DataLoader,
    models: list[torch.nn.Module],
    sigma: float,
    intensity_range: float,
    class_index: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run every ensemble member on noise-corrupted inputs.

    All members see the same noisy batch. Per sample, the members' outputs for
    ``class_index`` are reduced to their mean, the mean of ``|output - 0.5|``
    (confidence), and their standard deviation.

    Returns:
        ``(y_true, y_prob, y_confidence, y_std)`` as 1-D arrays over all
        samples yielded by ``test_loader``.

    Raises:
        ValueError: if ``models`` is empty.
    """
    if not models:
        raise ValueError("No ensemble models were provided for noise inference")
    device = next(models[0].parameters()).device

    all_probs: list[float] = []
    all_confidence: list[float] = []
    all_stds: list[float] = []
    all_true: list[int] = []

    with torch.no_grad():
        for mri, xls, labels, _ in test_loader:
            noisy, xls_device, true = _noisy_batch(
                mri, xls, labels, device, sigma, intensity_range, class_index
            )
            # Rows are ensemble members, columns are samples in the batch.
            pred_mat = np.stack(
                [
                    model((noisy, xls_device))[:, class_index].detach().cpu().numpy()
                    for model in models
                ],
                axis=0,
            )
            all_probs.extend(pred_mat.mean(axis=0).tolist())
            all_confidence.extend(np.abs(pred_mat - 0.5).mean(axis=0).tolist())
            all_stds.extend(pred_mat.std(axis=0).tolist())
            all_true.extend(true.tolist())

    return (
        np.asarray(all_true),
        np.asarray(all_probs),
        np.asarray(all_confidence),
        np.asarray(all_stds),
    )


def _infer_with_noise_bayesian(
    test_loader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    sigma: float,
    intensity_range: float,
    class_index: int,
    mc_passes: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run ``mc_passes`` forward passes of one model on noise-corrupted inputs.

    Each batch is corrupted once and the same noisy batch is fed to every
    pass. Per sample, the passes are reduced to the mean output for
    ``class_index``, the mean of ``|output - 0.5|`` (confidence), and the
    value returned by ``bayesian_torch``'s ``predictive_entropy`` over all
    draws.

    Returns:
        ``(y_true, y_prob, y_confidence, y_uncertainty)`` as arrays over all
        samples yielded by ``test_loader``; the last entry holds the
        predictive-entropy values.

    Raises:
        ValueError: if ``mc_passes`` is smaller than 1.
    """
    if mc_passes < 1:
        raise ValueError(f"mc_passes must be at least 1, got {mc_passes}")
    device = next(model.parameters()).device

    all_probs: list[float] = []
    all_confidence: list[float] = []
    all_uncertainty: list[float] = []
    all_true: list[int] = []

    with torch.no_grad():
        for mri, xls, labels, _ in test_loader:
            noisy, xls_device, true = _noisy_batch(
                mri, xls, labels, device, sigma, intensity_range, class_index
            )
            # Stacked as (mc_passes, batch, classes).
            draw_mat = np.stack(
                [
                    model((noisy, xls_device)).detach().cpu().numpy()
                    for _pass in range(mc_passes)
                ],
                axis=0,
            )
            mean = draw_mat.mean(axis=0)[:, class_index]
            confidence = np.abs(draw_mat[:, :, class_index] - 0.5).mean(axis=0)
            entropy_uncertainty = predictive_entropy(draw_mat)

            all_probs.extend(mean.tolist())
            all_confidence.extend(np.asarray(confidence, dtype=float).tolist())
            all_uncertainty.extend(
                np.asarray(entropy_uncertainty, dtype=float).tolist()
            )
            all_true.extend(true.tolist())

    return (
        np.asarray(all_true),
        np.asarray(all_probs),
        np.asarray(all_confidence),
        np.asarray(all_uncertainty),
    )


def _noise_metrics_row(
    backend: str,
    sigma: float,
    intensity_range: float,
    threshold: float,
    calibration_bins: int,
    y_true: np.ndarray,
    y_prob: np.ndarray,
    y_confidence: np.ndarray,
    y_uncertainty: np.ndarray,
) -> dict[str, Any]:
    """Summarise one sigma's predictions as a row of the sensitivity table."""
    perf = performance_at_threshold(y_true, y_prob, threshold)
    cal, _ = calibration_stats(y_true, y_prob, bins=calibration_bins)
    is_bayesian = backend == "bayesian"
    mean_uncertainty = float(np.nanmean(y_uncertainty))
    return {
        "uncertainty_metric": "predictive_entropy" if is_bayesian else "std",
        "noise_factor": float(sigma),
        "accuracy": float(perf["accuracy"]),
        "f1": float(perf["f1"]),
        "mce": float(cal["mce"]),
        "mean_confidence": float(np.nanmean(y_confidence)),
        "mean_model_output_probability": float(np.nanmean(y_prob)),
        # For the bayesian backend this field carries predictive entropy; the
        # name is retained for compatibility with downstream code.
        "mean_std": mean_uncertainty,
        "mean_predictive_entropy": mean_uncertainty if is_bayesian else float("nan"),
        "mri_intensity_range": float(intensity_range),
    }


def _save_noise_examples(
    test_loader: torch.utils.data.DataLoader,
    device: torch.device,
    noise_sigmas: list[float],
    intensity_range: float,
    examples_dir: Path,
    backend: str,
) -> None:
    """Write example images of the first holdout batch, clean and corrupted."""
    with torch.no_grad():
        sample = next(iter(test_loader))
        original_mri = sample[0]
        original_device = original_mri.float().to(device)
        example_rows: list[tuple[float, torch.Tensor]] = [
            (
                float(sigma),
                _apply_scaled_noise(original_device, sigma, intensity_range)
                .detach()
                .cpu(),
            )
            for sigma in noise_sigmas
        ]
    save_noise_example_grid(
        original_mri=original_mri,
        noisy_by_sigma=example_rows,
        output_path=examples_dir / f"{backend}_noise_examples.png",
        title=f"{backend.title()} Noise Examples",
        max_images=9,
        n_rows=2,
    )
    save_clean_scan_image(
        original_mri=original_mri,
        output_path=examples_dir / f"{backend}_clean_scan_example.png",
    )


def run_noise_analysis(
    config: dict[str, Any],
    root_dir: Path,
    backend: str,
    output_dir: Path,
    class_index: int,
    noise_sigmas: list[float],
    threshold: float,
    calibration_bins: int,
    bayesian_mc_passes: int,
) -> dict[str, Any]:
    """Measure how model performance and uncertainty change with input noise.

    For each noise factor the holdout set is corrupted with Gaussian noise and
    evaluated with the selected backend. The function writes
    ``noise_sensitivity.csv``, a set of plots, example images, and
    ``noise_summary.json`` under ``output_dir``.

    Args:
        config: Project configuration; the ``data``, ``training`` and
            ``output`` sections are read.
        root_dir: Passed through to ``build_holdout_loader``.
        backend: ``"ensemble"`` or ``"bayesian"``.
        output_dir: Directory receiving the table, plots and summary.
        class_index: Output/label column that is evaluated.
        noise_sigmas: Requested noise factors; they are replaced by an evenly
            spaced schedule with the same count, minimum and maximum.
        threshold: Decision threshold passed to ``performance_at_threshold``.
        calibration_bins: Bin count passed to ``calibration_stats``.
        bayesian_mc_passes: Forward passes per batch for the bayesian backend.

    Returns:
        The summary dictionary that is also written to ``noise_summary.json``.

    Raises:
        ValueError: for an unsupported ``backend``, an empty ``noise_sigmas``,
            or (bayesian backend) ``bayesian_mc_passes`` below 1.
    """
    # Validate cheap arguments before loading data or creating directories.
    if backend not in _SUPPORTED_BACKENDS:
        raise ValueError(f"Unsupported backend for noise analysis: {backend}")
    is_bayesian = backend == "bayesian"
    if is_bayesian and bayesian_mc_passes < 1:
        raise ValueError(
            f"bayesian_mc_passes must be at least 1, got {bayesian_mc_passes}"
        )

    noise_sigmas = _uniform_sigma_schedule(noise_sigmas)
    test_loader = build_holdout_loader(
        config=config,
        root_dir=root_dir,
        seed=int(config["data"]["seed"]),
    )
    intensity_range = _FIXED_MRI_INTENSITY_RANGE

    out_plots_dir = plots_dir(output_dir)
    examples_dir = out_plots_dir / "noise_examples"
    examples_dir.mkdir(parents=True, exist_ok=True)

    if is_bayesian:
        bayesian_model = _load_bayesian_model(config)
        device = next(bayesian_model.parameters()).device

        def infer(sigma: float) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
            return _infer_with_noise_bayesian(
                test_loader,
                bayesian_model,
                sigma,
                intensity_range,
                class_index=class_index,
                mc_passes=bayesian_mc_passes,
            )

    else:
        ensemble = _load_ensemble_models(config)
        device = next(ensemble[0].parameters()).device

        def infer(sigma: float) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
            return _infer_with_noise_ensemble(
                test_loader,
                ensemble,
                sigma,
                intensity_range,
                class_index=class_index,
            )

    rows: list[dict[str, Any]] = []
    for sigma in noise_sigmas:
        y_true, y_prob, y_confidence, y_uncertainty = infer(sigma)
        rows.append(
            _noise_metrics_row(
                backend,
                sigma,
                intensity_range,
                threshold,
                calibration_bins,
                y_true,
                y_prob,
                y_confidence,
                y_uncertainty,
            )
        )

    _save_noise_examples(
        test_loader, device, noise_sigmas, intensity_range, examples_dir, backend
    )

    df = pd.DataFrame(rows).sort_values("noise_factor")
    csv_path = output_dir / "noise_sensitivity.csv"
    df.to_csv(csv_path, index=False)

    # Backend-dependent naming of the uncertainty plots and summary keys.
    uncertainty_slug = "predictive_uncertainty" if is_bayesian else "standard_deviation"
    secondary_label = "Predictive Uncertainty" if is_bayesian else "Standard Deviation"

    accuracy_plot_path = out_plots_dir / "noise_sensitivity_accuracy.png"
    f1_plot_path = out_plots_dir / "noise_sensitivity_f1.png"
    pair_plot_path = out_plots_dir / "noise_sensitivity_accuracy_f1.png"
    confidence_plot_path = out_plots_dir / "noise_confidence.png"
    confidence_uncertainty_pair_path = (
        out_plots_dir / f"noise_confidence_{uncertainty_slug}.png"
    )
    secondary_plot_path = out_plots_dir / f"noise_{uncertainty_slug}.png"

    x_label = "Gaussian Noise Factor"
    noise_axis = df["noise_factor"]

    save_noise_metrics_plot(
        x=noise_axis,
        y=df["accuracy"],
        legend_label="Accuracy",
        marker="o",
        x_label=x_label,
        y_label="Accuracy",
        title=f"Accuracy vs Noise ({backend})",
        output_path=accuracy_plot_path,
        plot_key="noise_sensitivity_accuracy",
    )
    save_noise_metrics_plot(
        x=noise_axis,
        y=df["f1"],
        legend_label="F1",
        marker="s",
        x_label=x_label,
        y_label="F1",
        title=f"F1 vs Noise ({backend})",
        output_path=f1_plot_path,
        plot_key="noise_sensitivity_f1",
    )
    save_metric_pair_plot(
        x=noise_axis,
        left_y=df["accuracy"],
        right_y=df["f1"],
        left_label="Accuracy",
        right_label="F1",
        x_label=x_label,
        y_label="Accuracy/F1",
        title=f"Accuracy and F1 vs Noise ({backend})",
        output_path=pair_plot_path,
        plot_key="noise_sensitivity_accuracy_f1",
    )
    save_noise_metrics_plot(
        x=noise_axis,
        y=df["mean_confidence"],
        legend_label="Confidence",
        marker="o",
        x_label=x_label,
        y_label="Confidence",
        title=f"Confidence vs Noise ({backend})",
        output_path=confidence_plot_path,
        plot_key="noise_confidence",
    )
    save_noise_metrics_plot(
        x=noise_axis,
        y=df["mean_std"],
        legend_label=secondary_label,
        marker="^",
        x_label=x_label,
        y_label=secondary_label,
        title=f"{secondary_label} vs Noise ({backend})",
        output_path=secondary_plot_path,
        plot_key=f"noise_{uncertainty_slug}",
    )
    save_metric_pair_plot(
        x=noise_axis,
        left_y=df["mean_confidence"],
        right_y=df["mean_std"],
        left_label="Confidence",
        right_label=secondary_label,
        x_label=x_label,
        y_label="Confidence/Uncertainty",
        title=f"Confidence and {secondary_label} vs Noise ({backend})",
        output_path=confidence_uncertainty_pair_path,
        plot_key=f"noise_confidence_{uncertainty_slug}",
    )

    out = {
        "table": str(csv_path),
        "plots": {
            "accuracy": str(accuracy_plot_path),
            "f1": str(f1_plot_path),
            "accuracy_f1": str(pair_plot_path),
            "confidence": str(confidence_plot_path),
            f"confidence_{uncertainty_slug}": str(confidence_uncertainty_pair_path),
            uncertainty_slug: str(secondary_plot_path),
        },
        # Both keys carry the same evenly spaced schedule.
        "noise_factors": noise_sigmas,
        "noise_sigmas": noise_sigmas,
        "mri_intensity_range": float(intensity_range),
    }
    write_json(output_dir / "noise_summary.json", out)
    return out