| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376 |
- # pyright: basic
- from __future__ import annotations
- from pathlib import Path
- from typing import Any
- import numpy as np
- import pandas as pd
- import torch
- from bayesian_torch.utils.util import predictive_entropy
- from model.cnn import CNN3D
- from .data_pipeline import build_dataset_and_test_loader
- from .metrics import calibration_stats, performance_at_threshold
- from .model_utils import configure_bayesian_sampling_mode
- from .plotting import (
- plots_dir,
- save_noise_example_grid,
- save_noise_metrics_plot,
- )
- from .runtime import write_json
- from .uncertainty import confidence_certainty, confidence_uncertainty
- def _apply_scaled_noise(
- volume: torch.Tensor, sigma: float, intensity_range: float
- ) -> torch.Tensor:
- # Scale by global MRI intensity range measured from holdout set.
- return volume + (torch.randn_like(volume) * sigma * intensity_range)
- def _uniform_sigma_schedule(noise_sigmas: list[float]) -> list[float]:
- if not noise_sigmas:
- raise ValueError("noise_sigmas must contain at least one value")
- ordered = np.array(sorted(float(s) for s in noise_sigmas), dtype=float)
- if len(ordered) == 1:
- return [float(ordered[0])]
- uniform = np.linspace(
- float(ordered[0]), float(ordered[-1]), num=len(ordered), dtype=float
- )
- return [float(s) for s in uniform]
- def _load_ensemble_models(config: dict[str, Any]) -> list[torch.nn.Module]:
- model_dir = Path(config["output"]["ensemble_path"])
- model_files = sorted(model_dir.glob("model_run_*.pt"))
- if not model_files:
- raise FileNotFoundError(f"No ensemble model files found in {model_dir}")
- models: list[torch.nn.Module] = []
- for model_file in model_files:
- model = (
- CNN3D(
- image_channels=int(config["data"]["image_channels"]),
- clin_data_channels=int(config["data"]["clin_data_channels"]),
- num_classes=int(config["data"]["num_classes"]),
- droprate=float(config["training"]["droprate"]),
- )
- .float()
- .to(config["training"]["device"])
- )
- model.load_state_dict(
- torch.load(model_file, map_location=config["training"]["device"]),
- strict=False,
- )
- model.eval()
- models.append(model)
- return models
- def _load_bayesian_model(config: dict[str, Any]) -> torch.nn.Module:
- device = str(config["training"]["device"])
- try:
- from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn # type: ignore[import-untyped]
- except ImportError as e:
- raise ImportError(
- "bayesian_torch is required for bayesian noise analysis"
- ) from e
- model_path = Path(config["output"]["bayesian_path"]) / "model_bayesian.pt"
- if not model_path.exists():
- raise FileNotFoundError(f"Bayesian model checkpoint not found: {model_path}")
- model = (
- CNN3D(
- image_channels=int(config["data"]["image_channels"]),
- clin_data_channels=int(config["data"]["clin_data_channels"]),
- num_classes=int(config["data"]["num_classes"]),
- droprate=float(config["training"]["droprate"]),
- )
- .float()
- .to(config["training"]["device"])
- )
- prior_params: dict[str, float | bool | str] = {
- "prior_mu": 0.0,
- "prior_sigma": 1.0,
- "posterior_mu_init": 0.0,
- "posterior_rho_init": -3.0,
- "type": "Reparameterization",
- "moped_enable": False,
- "moped_delta": 0.5,
- }
- dnn_to_bnn(model, prior_params)
- model.to(device)
- model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
- model.to(device)
- configure_bayesian_sampling_mode(model, stochastic=False)
- return model
def _infer_with_noise_ensemble(
    test_loader: torch.utils.data.DataLoader,
    models: list[torch.nn.Module],
    sigma: float,
    intensity_range: float,
    class_index: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Score the test set with every ensemble member on noise-perturbed MRI.

    For each batch, a single noisy copy of the MRI is drawn with
    ``_apply_scaled_noise``. It is fed, together with the clinical tensor, to
    every model, so all members see the same perturbation. Only the
    ``class_index`` column of each model's output is kept.

    Returns:
        ``(y_true, y_prob, y_std)``, one entry per sample. ``y_true`` is the
        ``class_index`` column of the labels cast to int. ``y_prob`` is the
        across-member mean of the model outputs. ``y_std`` is their
        across-member population standard deviation.

    Raises:
        ValueError: if ``models`` is empty.
    """
    if not models:
        raise ValueError("No ensemble models were provided for noise inference")
    device = next(models[0].parameters()).device

    truth_values: list[int] = []
    mean_values: list[float] = []
    spread_values: list[float] = []
    with torch.no_grad():
        for mri, xls, labels, _ in test_loader:
            clinical = xls.float().to(device)
            noisy = _apply_scaled_noise(mri.float().to(device), sigma, intensity_range)
            # Rows = ensemble members, columns = samples in the batch.
            member_scores = np.stack(
                [
                    member((noisy, clinical))[:, class_index].detach().cpu().numpy()
                    for member in models
                ],
                axis=0,
            )
            mean_values += member_scores.mean(axis=0).tolist()
            spread_values += member_scores.std(axis=0).tolist()
            truth_values += (
                labels[:, class_index].detach().cpu().numpy().astype(int).tolist()
            )
    return np.asarray(truth_values), np.asarray(mean_values), np.asarray(spread_values)
- def _infer_with_noise_bayesian(
- test_loader: torch.utils.data.DataLoader,
- model: torch.nn.Module,
- sigma: float,
- intensity_range: float,
- class_index: int,
- mc_passes: int,
- ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
- device = next(model.parameters()).device
- all_probs: list[float] = []
- all_stds: list[float] = []
- all_true: list[int] = []
- with torch.no_grad():
- for mri, xls, labels, _ in test_loader:
- mri_device = mri.float().to(device)
- xls_device = xls.float().to(device)
- labels_device = labels.to(device)
- noisy = _apply_scaled_noise(mri_device, sigma, intensity_range)
- draws = []
- for _ in range(mc_passes):
- out = model((noisy, xls_device))
- draws.append(out.detach().cpu().numpy())
- draw_mat = np.stack(draws, axis=0) # (mc_passes, batch, classes)
- mean = draw_mat.mean(axis=0)[:, class_index]
- entropy_uncertainty = predictive_entropy(draw_mat)
- true = labels_device[:, class_index].detach().cpu().numpy().astype(int)
- all_probs.extend(mean.tolist())
- all_stds.extend(np.asarray(entropy_uncertainty, dtype=float).tolist())
- all_true.extend(true.tolist())
- return np.asarray(all_true), np.asarray(all_probs), np.asarray(all_stds)
def _noise_metrics_row(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    y_uncertainty: np.ndarray,
    *,
    sigma: float,
    uncertainty_metric: str,
    threshold: float,
    calibration_bins: int,
    intensity_range: float,
) -> dict[str, Any]:
    """Summarise one noise level's predictions as a row of the sensitivity table."""
    perf = performance_at_threshold(y_true, y_prob, threshold)
    cal, _ = calibration_stats(y_true, y_prob, bins=calibration_bins)
    mean_uncertainty = float(np.nanmean(y_uncertainty))
    return {
        "uncertainty_metric": uncertainty_metric,
        "noise_factor": float(sigma),
        "accuracy": float(perf["accuracy"]),
        "f1": float(perf["f1"]),
        "ece": float(cal["ece"]),
        "mce": float(cal["mce"]),
        "mean_confidence_certainty": float(np.nanmean(confidence_certainty(y_prob))),
        "mean_confidence_uncertainty": float(
            np.nanmean(confidence_uncertainty(y_prob))
        ),
        # Holds the backend's own uncertainty (predictive entropy for the
        # Bayesian backend); the column name is kept for downstream code.
        "mean_std": mean_uncertainty,
        "mean_predictive_entropy": (
            mean_uncertainty
            if uncertainty_metric == "predictive_entropy"
            else float("nan")
        ),
        "mri_intensity_range": float(intensity_range),
    }


def _save_noise_examples(
    test_loader: torch.utils.data.DataLoader,
    device: torch.device,
    noise_sigmas: list[float],
    intensity_range: float,
    output_path: Path,
    title: str,
) -> None:
    """Render the first test batch's MRI at every noise level into one image grid."""
    with torch.no_grad():
        original_mri = next(iter(test_loader))[0]
        original_device = original_mri.float().to(device)
        noisy_by_sigma = [
            (
                float(sigma),
                _apply_scaled_noise(original_device, sigma, intensity_range)
                .detach()
                .cpu(),
            )
            for sigma in noise_sigmas
        ]
    save_noise_example_grid(
        original_mri=original_mri,
        noisy_by_sigma=noisy_by_sigma,
        output_path=output_path,
        title=title,
    )


def run_noise_analysis(
    config: dict[str, Any],
    root_dir: Path,
    backend: str,
    output_dir: Path,
    class_index: int,
    noise_sigmas: list[float],
    threshold: float,
    calibration_bins: int,
    bayesian_mc_passes: int,
    intensity_range: float = 10_000.0,
) -> dict[str, Any]:
    """Measure how additive Gaussian MRI noise affects a trained model.

    ``noise_sigmas`` is first replaced by an evenly spaced schedule between
    its minimum and maximum. For each sigma, the test set is re-scored with
    noise of standard deviation ``sigma * intensity_range`` added to the MRI
    volumes. The following are recorded per sigma:
      * accuracy and F1 at ``threshold``;
      * ECE and MCE with ``calibration_bins`` bins;
      * mean confidence certainty and uncertainty;
      * the backend's own uncertainty: the ensemble std, or the Bayesian
        predictive entropy.

    Files written:
      * ``<output_dir>/noise_sensitivity.csv``: one row per noise level.
      * ``noise_sensitivity.png``, ``noise_uncertainty.png`` and
        ``noise_confidence_certainty.png`` under ``plots_dir(output_dir)``.
      * ``<backend>_noise_examples.png`` under ``.../noise_examples``.
      * ``<output_dir>/noise_summary.json``: the returned dict.

    Args:
        backend: ``"ensemble"`` or ``"bayesian"``.
        bayesian_mc_passes: forward passes per batch (Bayesian backend only).
        intensity_range: scale applied to every sigma. Only consistency
            across runs matters; the default is the previously hard-coded
            value.

    Returns:
        Paths of the table and plots, the sigma schedule actually used, and
        the intensity range.

    Raises:
        ValueError: for an unsupported ``backend`` (checked before any data
            is loaded) or an empty ``noise_sigmas``.
    """
    # backend -> (value of the "uncertainty_metric" column, uncertainty plot label)
    uncertainty_by_backend = {
        "ensemble": ("std", "std_uncertainty"),
        "bayesian": ("predictive_entropy", "entropy_uncertainty"),
    }
    if backend not in uncertainty_by_backend:
        raise ValueError(f"Unsupported backend for noise analysis: {backend}")
    uncertainty_metric, uncertainty_label = uncertainty_by_backend[backend]

    noise_sigmas = _uniform_sigma_schedule(noise_sigmas)
    _, test_loader = build_dataset_and_test_loader(
        config=config,
        root_dir=root_dir,
        seed=int(config["data"]["seed"]),
    )
    out_plots_dir = plots_dir(output_dir)
    examples_dir = out_plots_dir / "noise_examples"
    examples_dir.mkdir(parents=True, exist_ok=True)

    models: list[torch.nn.Module] = (
        _load_ensemble_models(config)
        if backend == "ensemble"
        else [_load_bayesian_model(config)]
    )

    rows: list[dict[str, Any]] = []
    for sigma in noise_sigmas:
        if backend == "ensemble":
            y_true, y_prob, y_unc = _infer_with_noise_ensemble(
                test_loader,
                models,
                sigma,
                intensity_range,
                class_index=class_index,
            )
        else:
            y_true, y_prob, y_unc = _infer_with_noise_bayesian(
                test_loader,
                models[0],
                sigma,
                intensity_range,
                class_index=class_index,
                mc_passes=bayesian_mc_passes,
            )
        rows.append(
            _noise_metrics_row(
                y_true,
                y_prob,
                y_unc,
                sigma=sigma,
                uncertainty_metric=uncertainty_metric,
                threshold=threshold,
                calibration_bins=calibration_bins,
                intensity_range=intensity_range,
            )
        )

    _save_noise_examples(
        test_loader,
        device=next(models[0].parameters()).device,
        noise_sigmas=noise_sigmas,
        intensity_range=intensity_range,
        output_path=examples_dir / f"{backend}_noise_examples.png",
        title=f"{backend.title()} Noise Examples",
    )

    df = pd.DataFrame(rows).sort_values("noise_factor")
    csv_path = output_dir / "noise_sensitivity.csv"
    df.to_csv(csv_path, index=False)

    plot_path = out_plots_dir / "noise_sensitivity.png"
    uncertainty_plot_path = out_plots_dir / "noise_uncertainty.png"
    certainty_plot_path = out_plots_dir / "noise_confidence_certainty.png"
    save_noise_metrics_plot(
        x=df["noise_factor"],
        y_by_label=[
            (df["accuracy"], "o", "accuracy"),
            (df["f1"], "s", "f1"),
            (df["ece"], "^", "ece"),
        ],
        x_label="Gaussian Noise Factor",
        y_label="Score",
        title=f"Noise Sensitivity ({backend})",
        output_path=plot_path,
    )
    save_noise_metrics_plot(
        x=df["noise_factor"],
        y_by_label=[
            (df["mean_confidence_uncertainty"], "o", "confidence_uncertainty"),
            # Label reflects what the column actually holds for this backend.
            (df["mean_std"], "s", uncertainty_label),
        ],
        x_label="Gaussian Noise Factor",
        y_label="Uncertainty",
        title=f"Uncertainty vs Noise ({backend})",
        output_path=uncertainty_plot_path,
    )
    save_noise_metrics_plot(
        x=df["noise_factor"],
        y_by_label=[
            (df["mean_confidence_certainty"], "o", "confidence_certainty"),
        ],
        x_label="Gaussian Noise Factor",
        y_label="Certainty",
        title=f"Confidence Certainty vs Noise ({backend})",
        output_path=certainty_plot_path,
    )

    out = {
        "table": str(csv_path),
        "plot": str(plot_path),
        "uncertainty_plot": str(uncertainty_plot_path),
        "certainty_plot": str(certainty_plot_path),
        "noise_factors": noise_sigmas,
        "noise_sigmas": noise_sigmas,
        "mri_intensity_range": float(intensity_range),
    }
    write_json(output_dir / "noise_summary.json", out)
    return out
|