| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467 |
- # pyright: basic
- from __future__ import annotations
- from pathlib import Path
- from typing import Any
- import numpy as np
- import pandas as pd
- import torch
- from bayesian_torch.utils.util import predictive_entropy
- from model.cnn import CNN3D
- from .data_pipeline import build_holdout_loader
- from .metrics import calibration_stats, performance_at_threshold
- from .model_utils import configure_bayesian_sampling_mode
- from .plotting import (
- plots_dir,
- save_clean_scan_image,
- save_noise_example_grid,
- save_metric_pair_plot,
- save_noise_metrics_plot,
- )
- from .runtime import write_json
def _apply_scaled_noise(
    volume: torch.Tensor, sigma: float, intensity_range: float
) -> torch.Tensor:
    """Return a copy of ``volume`` perturbed by additive Gaussian noise.

    Standard-normal noise with the same shape, dtype and device as ``volume``
    is multiplied by ``sigma`` and then by ``intensity_range`` before being
    added, so ``sigma`` is expressed as a fraction of ``intensity_range``.

    Args:
        volume: Tensor to perturb; it is not modified in place.
        sigma: Relative noise level.
        intensity_range: Scale factor supplied by the caller.

    Returns:
        A new tensor equal to ``volume`` plus the scaled noise.
    """
    gaussian = torch.randn_like(volume)
    perturbation = gaussian * sigma * intensity_range
    return volume + perturbation
def _uniform_sigma_schedule(noise_sigmas: list[float]) -> list[float]:
    """Replace the requested sigmas with an evenly spaced grid.

    The result has as many entries as the input and spans the input's minimum
    to its maximum (inclusive). A single-element input is returned as a
    one-element list.

    Args:
        noise_sigmas: Requested noise levels, in any order.

    Returns:
        Ascending, evenly spaced noise levels as plain Python floats.

    Raises:
        ValueError: If ``noise_sigmas`` is empty.
    """
    if not noise_sigmas:
        raise ValueError("noise_sigmas must contain at least one value")

    levels = sorted(map(float, noise_sigmas))
    lowest, highest = levels[0], levels[-1]
    if len(levels) == 1:
        return [lowest]

    grid = np.linspace(lowest, highest, num=len(levels), dtype=float)
    return grid.tolist()
def _load_ensemble_models(config: dict[str, Any]) -> list[torch.nn.Module]:
    """Load every ensemble member checkpoint found in the configured directory.

    Checkpoints matching ``model_run_*.pt`` under
    ``config["output"]["ensemble_path"]`` are loaded in sorted filename order
    into freshly constructed ``CNN3D`` networks, which are moved to
    ``config["training"]["device"]`` and switched to eval mode.

    Args:
        config: Experiment configuration; reads ``output.ensemble_path``,
            ``training.device``, ``training.droprate`` and the ``data``
            channel/class counts.

    Returns:
        One eval-mode model per checkpoint file.

    Raises:
        FileNotFoundError: If no checkpoint file matches the pattern.
    """
    import warnings

    model_dir = Path(config["output"]["ensemble_path"])
    model_files = sorted(model_dir.glob("model_run_*.pt"))
    if not model_files:
        raise FileNotFoundError(f"No ensemble model files found in {model_dir}")

    # These values are identical for every member, so read them once.
    device = config["training"]["device"]
    image_channels = int(config["data"]["image_channels"])
    clin_data_channels = int(config["data"]["clin_data_channels"])
    num_classes = int(config["data"]["num_classes"])
    droprate = float(config["training"]["droprate"])

    models: list[torch.nn.Module] = []
    for model_file in model_files:
        model = (
            CNN3D(
                image_channels=image_channels,
                clin_data_channels=clin_data_channels,
                num_classes=num_classes,
                droprate=droprate,
            )
            .float()
            .to(device)
        )
        load_result = model.load_state_dict(
            torch.load(model_file, map_location=device),
            strict=False,
        )
        # strict=False tolerates key mismatches; surface them so a wrong
        # checkpoint cannot silently leave layers at their random init.
        if load_result.missing_keys or load_result.unexpected_keys:
            warnings.warn(
                f"Checkpoint {model_file} does not fully match the model: "
                f"missing keys {list(load_result.missing_keys)}, "
                f"unexpected keys {list(load_result.unexpected_keys)}",
                RuntimeWarning,
                stacklevel=2,
            )
        model.eval()
        models.append(model)
    return models
def _load_bayesian_model(config: dict[str, Any]) -> torch.nn.Module:
    """Build a ``CNN3D``, convert it with ``dnn_to_bnn`` and load its weights.

    The checkpoint is read from
    ``config["output"]["bayesian_path"] / "model_bayesian.pt"``. After loading,
    ``configure_bayesian_sampling_mode(model, stochastic=False)`` is applied.

    NOTE(review): presumably ``stochastic=False`` makes forward passes
    deterministic; callers that rely on Monte Carlo draws may need to
    re-enable sampling — confirm against ``model_utils``.

    Args:
        config: Experiment configuration; reads ``output.bayesian_path``,
            ``training.device``, ``training.droprate`` and the ``data``
            channel/class counts.

    Returns:
        The converted model on the configured device.

    Raises:
        ImportError: If ``bayesian_torch`` is not installed.
        FileNotFoundError: If the checkpoint file does not exist.
    """
    import warnings

    device = str(config["training"]["device"])
    try:
        from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn  # type: ignore[import-untyped]
    except ImportError as e:
        raise ImportError(
            "bayesian_torch is required for bayesian noise analysis"
        ) from e

    model_path = Path(config["output"]["bayesian_path"]) / "model_bayesian.pt"
    if not model_path.exists():
        raise FileNotFoundError(f"Bayesian model checkpoint not found: {model_path}")

    model = (
        CNN3D(
            image_channels=int(config["data"]["image_channels"]),
            clin_data_channels=int(config["data"]["clin_data_channels"]),
            num_classes=int(config["data"]["num_classes"]),
            droprate=float(config["training"]["droprate"]),
        )
        .float()
        .to(device)
    )
    prior_params: dict[str, float | bool | str] = {
        "prior_mu": 0.0,
        "prior_sigma": 1.0,
        "posterior_mu_init": 0.0,
        "posterior_rho_init": -3.0,
        "type": "Reparameterization",
        "moped_enable": False,
        "moped_delta": 0.5,
    }
    # The conversion replaces layers in place; move the model again so any
    # newly created parameters end up on the target device as well.
    dnn_to_bnn(model, prior_params)
    model.to(device)

    # load_state_dict copies into the existing parameters, so no further
    # device move is needed afterwards.
    load_result = model.load_state_dict(
        torch.load(model_path, map_location=device), strict=False
    )
    # strict=False tolerates key mismatches; surface them so a checkpoint that
    # does not match the converted architecture is not evaluated unnoticed.
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {model_path} does not fully match the model: "
            f"missing keys {list(load_result.missing_keys)}, "
            f"unexpected keys {list(load_result.unexpected_keys)}",
            RuntimeWarning,
            stacklevel=2,
        )
    configure_bayesian_sampling_mode(model, stochastic=False)
    return model
def _infer_with_noise_ensemble(
    test_loader: torch.utils.data.DataLoader,
    models: list[torch.nn.Module],
    sigma: float,
    intensity_range: float,
    class_index: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run every ensemble member on noise-perturbed MRI batches.

    Each batch is perturbed once via ``_apply_scaled_noise`` and the same noisy
    input is fed to all members, together with the unmodified second batch
    element. Only column ``class_index`` of the outputs and labels is used.

    Args:
        test_loader: Iterable yielding 4-tuples ``(mri, xls, labels, _)``.
        models: Ensemble members; the device of the first one is used.
        sigma: Relative noise level passed to ``_apply_scaled_noise``.
        intensity_range: Noise scale passed to ``_apply_scaled_noise``.
        class_index: Column of the model output / label matrix to evaluate.

    Returns:
        ``(y_true, y_prob, y_confidence, y_std)`` with one entry per sample:
        integer labels, across-member mean output, across-member mean of
        ``|output - 0.5|``, and across-member standard deviation.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if not models:
        raise ValueError("No ensemble models were provided for noise inference")

    device = next(models[0].parameters()).device
    truths: list[int] = []
    mean_probs: list[float] = []
    confidences: list[float] = []
    spreads: list[float] = []

    with torch.no_grad():
        for mri, xls, labels, _unused in test_loader:
            mri_on_device = mri.float().to(device)
            clinical = xls.float().to(device)
            targets = labels.to(device)
            noisy = _apply_scaled_noise(mri_on_device, sigma, intensity_range)

            # Rows are ensemble members, columns are samples of this batch.
            member_outputs = np.stack(
                [
                    member((noisy, clinical))[:, class_index].detach().cpu().numpy()
                    for member in models
                ],
                axis=0,
            )

            mean_probs.extend(member_outputs.mean(axis=0).tolist())
            confidences.extend(np.abs(member_outputs - 0.5).mean(axis=0).tolist())
            spreads.extend(member_outputs.std(axis=0).tolist())
            batch_truth = targets[:, class_index].detach().cpu().numpy().astype(int)
            truths.extend(batch_truth.tolist())

    y_true = np.asarray(truths)
    y_prob = np.asarray(mean_probs)
    y_confidence = np.asarray(confidences)
    y_std = np.asarray(spreads)
    return y_true, y_prob, y_confidence, y_std
def _infer_with_noise_bayesian(
    test_loader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    sigma: float,
    intensity_range: float,
    class_index: int,
    mc_passes: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run repeated forward passes of one model on noise-perturbed MRI batches.

    Each batch is perturbed once via ``_apply_scaled_noise``; the same noisy
    input is then passed through ``model`` ``mc_passes`` times.

    NOTE(review): ``_load_bayesian_model`` configures the model with
    ``stochastic=False``; if that makes forward passes deterministic the
    draws collected here are identical — confirm the intended sampling mode.

    Args:
        test_loader: Iterable yielding 4-tuples ``(mri, xls, labels, _)``.
        model: Model to evaluate; its first parameter's device is used.
        sigma: Relative noise level passed to ``_apply_scaled_noise``.
        intensity_range: Noise scale passed to ``_apply_scaled_noise``.
        class_index: Column of the model output / label matrix to evaluate.
        mc_passes: Number of forward passes per batch.

    Returns:
        ``(y_true, y_prob, y_confidence, y_uncertainty)`` with one entry per
        sample: integer labels, across-pass mean output for ``class_index``,
        across-pass mean of ``|output - 0.5|``, and the value returned by
        ``predictive_entropy`` over all passes and classes.
    """
    device = next(model.parameters()).device
    truths: list[int] = []
    mean_probs: list[float] = []
    confidences: list[float] = []
    uncertainties: list[float] = []

    with torch.no_grad():
        for mri, xls, labels, _unused in test_loader:
            mri_on_device = mri.float().to(device)
            clinical = xls.float().to(device)
            targets = labels.to(device)
            noisy = _apply_scaled_noise(mri_on_device, sigma, intensity_range)

            # Stacked shape: (mc_passes, batch, classes).
            samples = np.stack(
                [
                    model((noisy, clinical)).detach().cpu().numpy()
                    for _pass in range(mc_passes)
                ],
                axis=0,
            )

            batch_mean = samples.mean(axis=0)[:, class_index]
            batch_confidence = np.abs(samples[:, :, class_index] - 0.5).mean(axis=0)
            batch_entropy = predictive_entropy(samples)
            batch_truth = targets[:, class_index].detach().cpu().numpy().astype(int)

            mean_probs.extend(batch_mean.tolist())
            confidences.extend(np.asarray(batch_confidence, dtype=float).tolist())
            uncertainties.extend(np.asarray(batch_entropy, dtype=float).tolist())
            truths.extend(batch_truth.tolist())

    y_true = np.asarray(truths)
    y_prob = np.asarray(mean_probs)
    y_confidence = np.asarray(confidences)
    y_uncertainty = np.asarray(uncertainties)
    return y_true, y_prob, y_confidence, y_uncertainty
def _noise_metrics_row(
    *,
    is_bayesian: bool,
    sigma: float,
    intensity_range: float,
    y_true: np.ndarray,
    y_prob: np.ndarray,
    y_confidence: np.ndarray,
    y_uncertainty: np.ndarray,
    threshold: float,
    calibration_bins: int,
) -> dict[str, Any]:
    """Summarise one noise level as a single row of the sensitivity table."""
    perf = performance_at_threshold(y_true, y_prob, threshold)
    cal, _ = calibration_stats(y_true, y_prob, bins=calibration_bins)
    mean_uncertainty = float(np.nanmean(y_uncertainty))
    return {
        "uncertainty_metric": "predictive_entropy" if is_bayesian else "std",
        "noise_factor": float(sigma),
        "accuracy": float(perf["accuracy"]),
        "f1": float(perf["f1"]),
        "mce": float(cal["mce"]),
        "mean_confidence": float(np.nanmean(y_confidence)),
        "mean_model_output_probability": float(np.nanmean(y_prob)),
        # "mean_std" is kept for both backends for downstream compatibility;
        # for the Bayesian backend it holds the predictive entropy.
        "mean_std": mean_uncertainty,
        "mean_predictive_entropy": (
            mean_uncertainty if is_bayesian else float("nan")
        ),
        "mri_intensity_range": float(intensity_range),
    }


def _save_noise_example_images(
    test_loader: torch.utils.data.DataLoader,
    device: torch.device,
    noise_sigmas: list[float],
    intensity_range: float,
    examples_dir: Path,
    backend: str,
) -> None:
    """Save a noisy-example grid and the clean scan for the first loader batch."""
    with torch.no_grad():
        original_mri = next(iter(test_loader))[0]
        original_device = original_mri.float().to(device)
        noisy_by_sigma = [
            (
                float(sigma),
                _apply_scaled_noise(original_device, sigma, intensity_range)
                .detach()
                .cpu(),
            )
            for sigma in noise_sigmas
        ]
    save_noise_example_grid(
        original_mri=original_mri,
        noisy_by_sigma=noisy_by_sigma,
        output_path=examples_dir / f"{backend}_noise_examples.png",
        title=f"{backend.title()} Noise Examples",
        max_images=9,
        n_rows=2,
    )
    save_clean_scan_image(
        original_mri=original_mri,
        output_path=examples_dir / f"{backend}_clean_scan_example.png",
    )


def _save_noise_sensitivity_plots(
    df: pd.DataFrame, backend: str, out_plots_dir: Path
) -> dict[str, str]:
    """Write all metric-vs-noise plots and return their paths keyed by name."""
    is_bayesian = backend == "bayesian"
    # One slug drives every backend-specific file name, plot key and dict key.
    uncertainty_slug = "predictive_uncertainty" if is_bayesian else "standard_deviation"
    uncertainty_label = (
        "Predictive Uncertainty" if is_bayesian else "Standard Deviation"
    )
    x_label = "Gaussian Noise Factor"
    noise = df["noise_factor"]

    accuracy_plot_path = out_plots_dir / "noise_sensitivity_accuracy.png"
    f1_plot_path = out_plots_dir / "noise_sensitivity_f1.png"
    pair_plot_path = out_plots_dir / "noise_sensitivity_accuracy_f1.png"
    confidence_plot_path = out_plots_dir / "noise_confidence.png"
    uncertainty_plot_key = f"noise_{uncertainty_slug}"
    uncertainty_plot_path = out_plots_dir / f"{uncertainty_plot_key}.png"
    confidence_uncertainty_key = f"noise_confidence_{uncertainty_slug}"
    confidence_uncertainty_path = out_plots_dir / f"{confidence_uncertainty_key}.png"

    save_noise_metrics_plot(
        x=noise,
        y=df["accuracy"],
        legend_label="Accuracy",
        marker="o",
        x_label=x_label,
        y_label="Accuracy",
        title=f"Accuracy vs Noise ({backend})",
        output_path=accuracy_plot_path,
        plot_key="noise_sensitivity_accuracy",
    )
    save_noise_metrics_plot(
        x=noise,
        y=df["f1"],
        legend_label="F1",
        marker="s",
        x_label=x_label,
        y_label="F1",
        title=f"F1 vs Noise ({backend})",
        output_path=f1_plot_path,
        plot_key="noise_sensitivity_f1",
    )
    save_metric_pair_plot(
        x=noise,
        left_y=df["accuracy"],
        right_y=df["f1"],
        left_label="Accuracy",
        right_label="F1",
        x_label=x_label,
        y_label="Accuracy/F1",
        title=f"Accuracy and F1 vs Noise ({backend})",
        output_path=pair_plot_path,
        plot_key="noise_sensitivity_accuracy_f1",
    )
    save_noise_metrics_plot(
        x=noise,
        y=df["mean_confidence"],
        legend_label="Confidence",
        marker="o",
        x_label=x_label,
        y_label="Confidence",
        title=f"Confidence vs Noise ({backend})",
        output_path=confidence_plot_path,
        plot_key="noise_confidence",
    )
    save_noise_metrics_plot(
        x=noise,
        y=df["mean_std"],
        legend_label=uncertainty_label,
        marker="^",
        x_label=x_label,
        y_label=uncertainty_label,
        title=f"{uncertainty_label} vs Noise ({backend})",
        output_path=uncertainty_plot_path,
        plot_key=uncertainty_plot_key,
    )
    save_metric_pair_plot(
        x=noise,
        left_y=df["mean_confidence"],
        right_y=df["mean_std"],
        left_label="Confidence",
        right_label=uncertainty_label,
        x_label=x_label,
        y_label="Confidence/Uncertainty",
        title=f"Confidence and {uncertainty_label} vs Noise ({backend})",
        output_path=confidence_uncertainty_path,
        plot_key=confidence_uncertainty_key,
    )

    return {
        "accuracy": str(accuracy_plot_path),
        "f1": str(f1_plot_path),
        "accuracy_f1": str(pair_plot_path),
        "confidence": str(confidence_plot_path),
        f"confidence_{uncertainty_slug}": str(confidence_uncertainty_path),
        uncertainty_slug: str(uncertainty_plot_path),
    }


def run_noise_analysis(
    config: dict[str, Any],
    root_dir: Path,
    backend: str,
    output_dir: Path,
    class_index: int,
    noise_sigmas: list[float],
    threshold: float,
    calibration_bins: int,
    bayesian_mc_passes: int,
) -> dict[str, Any]:
    """Measure how predictions degrade as Gaussian noise is added to the MRI.

    For every noise level the holdout set is evaluated with the chosen
    backend, and accuracy, F1, calibration error, confidence and an
    uncertainty measure are recorded. Results are written to
    ``noise_sensitivity.csv``, a set of plots, example images and
    ``noise_summary.json`` under ``output_dir``.

    Args:
        config: Experiment configuration passed on to the loaders.
        root_dir: Project root forwarded to ``build_holdout_loader``.
        backend: ``"ensemble"`` or ``"bayesian"``.
        output_dir: Directory receiving the table, plots and summary.
        class_index: Output/label column to evaluate.
        noise_sigmas: Requested noise levels; they are replaced by an evenly
            spaced grid over the same range (see ``_uniform_sigma_schedule``).
        threshold: Decision threshold for ``performance_at_threshold``.
        calibration_bins: Bin count for ``calibration_stats``.
        bayesian_mc_passes: Forward passes per batch for the Bayesian backend.

    Returns:
        A summary dict with the table path, plot paths, the noise levels
        actually used and the intensity range; the same dict is written to
        ``noise_summary.json``.

    Raises:
        ValueError: If ``backend`` is unsupported or ``noise_sigmas`` is empty.
    """
    # Reject an unknown backend before any data is loaded or directories
    # are created.
    if backend not in ("ensemble", "bayesian"):
        raise ValueError(f"Unsupported backend for noise analysis: {backend}")
    is_bayesian = backend == "bayesian"

    noise_sigmas = _uniform_sigma_schedule(noise_sigmas)
    test_loader = build_holdout_loader(
        config=config,
        root_dir=root_dir,
        seed=int(config["data"]["seed"]),
    )
    # A fixed intensity range is used for noise scaling; all that matters is
    # that it is consistent across runs.
    intensity_range = 10_000.0

    out_plots_dir = plots_dir(output_dir)
    examples_dir = out_plots_dir / "noise_examples"
    examples_dir.mkdir(parents=True, exist_ok=True)

    # Bind the backend-specific model(s) into one inference callable so the
    # per-sigma loop below is shared by both backends.
    if is_bayesian:
        bayesian_model = _load_bayesian_model(config)
        device = next(bayesian_model.parameters()).device

        def infer(
            sigma: float,
        ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
            return _infer_with_noise_bayesian(
                test_loader,
                bayesian_model,
                sigma,
                intensity_range,
                class_index=class_index,
                mc_passes=bayesian_mc_passes,
            )

    else:
        ensemble = _load_ensemble_models(config)
        device = next(ensemble[0].parameters()).device

        def infer(
            sigma: float,
        ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
            return _infer_with_noise_ensemble(
                test_loader,
                ensemble,
                sigma,
                intensity_range,
                class_index=class_index,
            )

    rows: list[dict[str, Any]] = []
    for sigma in noise_sigmas:
        y_true, y_prob, y_confidence, y_uncertainty = infer(sigma)
        rows.append(
            _noise_metrics_row(
                is_bayesian=is_bayesian,
                sigma=sigma,
                intensity_range=intensity_range,
                y_true=y_true,
                y_prob=y_prob,
                y_confidence=y_confidence,
                y_uncertainty=y_uncertainty,
                threshold=threshold,
                calibration_bins=calibration_bins,
            )
        )

    _save_noise_example_images(
        test_loader, device, noise_sigmas, intensity_range, examples_dir, backend
    )

    df = pd.DataFrame(rows).sort_values("noise_factor")
    csv_path = output_dir / "noise_sensitivity.csv"
    df.to_csv(csv_path, index=False)

    plot_paths = _save_noise_sensitivity_plots(df, backend, out_plots_dir)

    out = {
        "table": str(csv_path),
        "plots": plot_paths,
        # Both keys carry the same list; both are kept for existing consumers.
        "noise_factors": noise_sigmas,
        "noise_sigmas": noise_sigmas,
        "mri_intensity_range": float(intensity_range),
    }
    write_json(output_dir / "noise_summary.json", out)
    return out
|