| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461 |
- # pyright: basic
- from __future__ import annotations
- from pathlib import Path
- from typing import Any
- import matplotlib.pyplot as plt
- import numpy as np
- import pandas as pd
- import torch
- from bayesian_torch.utils.util import predictive_entropy
- from data.dataset import (
- divide_dataset_by_patient_id,
- initalize_dataloaders,
- load_adni_data_from_file,
- )
- from model.cnn import CNN3D
- from .metrics import calibration_stats, performance_at_threshold
- from .runtime import write_json
def _enable_bayesian_sampling_mode(model: torch.nn.Module) -> None:
    """Put ``model`` (and all submodules) into evaluation mode, in place.

    This only flips the training flag off (equivalent to ``model.eval()``),
    which disables dropout and batch-norm statistic updates. It does not
    itself toggle any weight sampling.

    NOTE(review): presumably relies on bayesian_torch layers sampling
    weights on every forward pass regardless of ``training`` -- confirm.
    """
    model.train(False)
def _central_slice(volume: torch.Tensor) -> np.ndarray:
    """Return the middle slice of a volume along its first spatial axis.

    A 5-D input has index 0 taken twice and a 4-D input once, leaving a
    3-D tensor. The slice at ``shape[0] // 2`` of that tensor is returned
    as a float64 NumPy array.

    Raises:
        ValueError: if the tensor is not 3-D after that reduction.
    """
    vol = volume.detach().cpu()
    # Peel off leading dims in the same order as a (batch, channel, ...) layout.
    for leading_rank in (5, 4):
        if vol.ndim == leading_rank:
            vol = vol[0]
    if vol.ndim != 3:
        raise ValueError(
            f"Expected a 3D volume after squeezing batch/channel, got shape {tuple(vol.shape)}"
        )
    mid = vol.shape[0] // 2
    return vol[mid].numpy().astype(float)
def _normalize_for_display(image: np.ndarray) -> np.ndarray:
    """Robustly rescale ``image`` to [0, 1] for plotting.

    Values are clipped to the 1st..99th percentile window and linearly
    mapped onto [0, 1]. If the window is empty (e.g. a constant image),
    an all-zero float array of the same shape is returned instead.
    """
    lo, hi = (float(np.percentile(image, q)) for q in (1, 99))
    if hi <= lo:
        return np.zeros_like(image, dtype=float)
    span = hi - lo
    return (np.clip(image, lo, hi) - lo) / span
def _save_noise_example_grid(
    original_mri: torch.Tensor,
    noisy_by_sigma: list[tuple[float, torch.Tensor]],
    output_path: Path,
    title: str,
) -> None:
    """Save a side-by-side grid of original vs. noisy central MRI slices.

    One row is drawn per ``(sigma, noisy_volume)`` pair. The left column
    always shows the central slice of ``original_mri`` and the right
    column shows the central slice of that row's noisy volume. Both are
    percentile-normalised for display.

    Args:
        original_mri: Clean volume (batch/channel dims allowed).
        noisy_by_sigma: ``(sigma, noisy_volume)`` pairs, one per row.
        output_path: Image file to write; parent directories are created.
        title: Figure super-title.

    Does nothing when ``noisy_by_sigma`` is empty.
    """
    if not noisy_by_sigma:
        return
    original_slice = _normalize_for_display(_central_slice(original_mri))
    n_rows = len(noisy_by_sigma)
    # squeeze=False keeps ``axes`` 2-D even for a single row.
    fig, axes = plt.subplots(n_rows, 2, figsize=(8, 3.2 * n_rows), squeeze=False)
    try:
        for (sigma, noisy_tensor), (ax_orig, ax_noisy) in zip(noisy_by_sigma, axes):
            noisy_slice = _normalize_for_display(_central_slice(noisy_tensor))
            ax_orig.imshow(original_slice, cmap="gray")
            ax_orig.set_title("Original")
            ax_orig.axis("off")
            ax_noisy.imshow(noisy_slice, cmap="gray")
            ax_noisy.set_title(f"Noisy sigma={sigma:g}")
            ax_noisy.axis("off")
        fig.suptitle(title)
        fig.tight_layout()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
def _confidence_certainty(y_prob: np.ndarray) -> np.ndarray:
    """Return ``2 * |p - 0.5|`` element-wise.

    For probabilities in [0, 1], the result is 0 at ``p = 0.5`` (maximally
    uncertain) and 1 at ``p = 0`` or ``p = 1`` (maximally certain).
    """
    distance_from_chance = np.abs(y_prob - 0.5)
    return distance_from_chance * 2.0
def _confidence_uncertainty(y_prob: np.ndarray) -> np.ndarray:
    """Return ``1 - _confidence_certainty(y_prob)`` element-wise.

    Larger values mean more uncertainty, which gives the value the same
    orientation as an ensemble standard deviation.
    """
    certainty = _confidence_certainty(y_prob)
    return 1.0 - certainty
def _apply_scaled_noise(
    volume: torch.Tensor,
    sigma: float,
    intensity_range: float = 65535.0,
) -> torch.Tensor:
    """Return ``volume`` plus zero-mean Gaussian noise with a dimensionless sigma.

    The noise standard deviation is ``sigma * intensity_range``, so
    ``sigma=1.0`` means the noise std equals the full intensity range.

    Args:
        volume: Floating-point tensor to perturb (not modified in place).
        sigma: Noise level relative to ``intensity_range``.
        intensity_range: Intensity span that ``sigma`` is relative to.
            The default of 65535 corresponds to a 16-bit unsigned range.
            NOTE(review): confirm this matches the loaded MRI intensities.

    Returns:
        A new tensor with the same shape, dtype and device as ``volume``.
    """
    return volume + (torch.randn_like(volume) * sigma * intensity_range)
def _xls_pre(df: pd.DataFrame) -> pd.DataFrame:
    """Select the clinical columns and encode ``Sex`` numerically.

    Keeps only ``Image Data ID``, ``Sex`` and ``Age (current)``. It works on
    a copy, so ``df`` is not modified. ``Sex`` is converted to ``str``,
    stripped, and mapped ``"M" -> 0`` / ``"F" -> 1``. Any other value
    (including ``"nan"`` produced from missing entries) is left unchanged.
    """
    data = df[["Image Data ID", "Sex", "Age (current)"]].copy()
    sex_codes = {"M": 0, "F": 1}
    # Encode only the Sex column; a frame-wide replace would also rewrite any
    # "M"/"F" cell in the other columns. The dict.get lookup leaves unknown
    # codes untouched (as replace did) and lets pandas infer an integer dtype
    # without relying on replace's deprecated implicit downcasting.
    data["Sex"] = (
        data["Sex"]
        .astype(str)
        .str.strip()
        .map(lambda code: sex_codes.get(code, code))
    )
    return data
def _build_test_loader(
    config: dict[str, Any], root_dir: Path
) -> torch.utils.data.DataLoader:
    """Rebuild the patient-grouped data split and return only its test loader.

    Loads every ``*.nii`` file under ``config["data"]["mri_files_path"]``
    together with the clinical CSV at ``config["data"]["xls_file_path"]``
    (both relative to ``root_dir``). Each image ID is paired with its
    patient ID (``PTID``), dropping rows with a missing or blank ID. The
    split uses ``config["data"]["data_splits"]`` and
    ``config["data"]["seed"]``.

    NOTE(review): the ``glob`` results are not sorted, so file order depends
    on the filesystem. Confirm the split does not depend on that order if
    this test set must match the one used during training.
    """
    data_cfg = config["data"]
    train_cfg = config["training"]
    mri_files = (root_dir / data_cfg["mri_files_path"]).resolve().glob("*.nii")
    xls_file = (root_dir / data_cfg["xls_file_path"]).resolve()

    dataset = load_adni_data_from_file(
        mri_files,
        xls_file,
        device=train_cfg["device"],
        xls_preprocessor=_xls_pre,
    )

    # Build (image_id, patient_id) pairs from the same clinical file.
    id_table = pd.read_csv(xls_file)
    id_table.columns = id_table.columns.str.strip()
    key_columns = ["Image Data ID", "PTID"]
    id_table = id_table[key_columns].dropna(subset=key_columns)
    image_ids = id_table["Image Data ID"].astype(int)
    patient_ids = id_table["PTID"].astype(str).str.strip()
    has_patient_id = patient_ids != ""
    ptids = list(
        zip(image_ids[has_patient_id].tolist(), patient_ids[has_patient_id].tolist())
    )

    splits = divide_dataset_by_patient_id(
        dataset,
        ptids,
        tuple(data_cfg["data_splits"]),
        seed=int(data_cfg["seed"]),
    )
    _, _, test_loader = initalize_dataloaders(
        splits,
        batch_size=int(train_cfg["batch_size"]),
    )
    return test_loader
def _load_ensemble_models(config: dict[str, Any]) -> list[torch.nn.Module]:
    """Load every ``model_run_*.pt`` checkpoint as an eval-mode ``CNN3D``.

    Checkpoints are read from ``config["output"]["ensemble_path"]`` in
    sorted filename order. Each one is loaded into a freshly built
    ``CNN3D`` on ``config["training"]["device"]``.

    Loading uses ``strict=False``. Any missing or unexpected keys are
    reported with a ``RuntimeWarning``, so an architecture/checkpoint
    mismatch does not silently leave layers at their initial weights.

    Raises:
        FileNotFoundError: if no matching checkpoint exists.
    """
    import warnings

    model_dir = Path(config["output"]["ensemble_path"])
    model_files = sorted(model_dir.glob("model_run_*.pt"))
    if not model_files:
        raise FileNotFoundError(f"No ensemble model files found in {model_dir}")

    data_cfg = config["data"]
    device = config["training"]["device"]
    models: list[torch.nn.Module] = []
    for model_file in model_files:
        model = (
            CNN3D(
                image_channels=int(data_cfg["image_channels"]),
                clin_data_channels=int(data_cfg["clin_data_channels"]),
                num_classes=int(data_cfg["num_classes"]),
                droprate=float(config["training"]["droprate"]),
            )
            .float()
            .to(device)
        )
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_file, map_location=device)
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            warnings.warn(
                f"Checkpoint {model_file} did not match CNN3D exactly: "
                f"missing={list(load_result.missing_keys)}, "
                f"unexpected={list(load_result.unexpected_keys)}",
                RuntimeWarning,
                stacklevel=2,
            )
        model.eval()
        models.append(model)
    return models
def _load_bayesian_model(config: dict[str, Any]) -> torch.nn.Module:
    """Build the Bayesian ``CNN3D`` and load ``model_bayesian.pt`` into it.

    The steps are:

    1. Build a deterministic ``CNN3D`` from ``config``.
    2. Convert it with bayesian_torch's ``dnn_to_bnn``, using the fixed
       prior settings below. The code relies on ``dnn_to_bnn`` mutating the
       model in place, since its return value is ignored.
    3. Load the checkpoint from ``config["output"]["bayesian_path"]``.
    4. Switch the model to eval mode.

    Loading uses ``strict=False``. Missing or unexpected keys are reported
    with a ``RuntimeWarning`` instead of being silently ignored.

    Raises:
        ImportError: if bayesian_torch's ``dnn_to_bnn`` cannot be imported.
        FileNotFoundError: if the checkpoint does not exist.
    """
    import warnings

    device = str(config["training"]["device"])
    try:
        from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn  # type: ignore[import-untyped]
    except ImportError as e:
        raise ImportError(
            "bayesian_torch is required for bayesian noise analysis"
        ) from e

    model_path = Path(config["output"]["bayesian_path"]) / "model_bayesian.pt"
    if not model_path.exists():
        raise FileNotFoundError(f"Bayesian model checkpoint not found: {model_path}")

    data_cfg = config["data"]
    model = (
        CNN3D(
            image_channels=int(data_cfg["image_channels"]),
            clin_data_channels=int(data_cfg["clin_data_channels"]),
            num_classes=int(data_cfg["num_classes"]),
            droprate=float(config["training"]["droprate"]),
        )
        .float()
        .to(device)
    )
    # These prior settings must match the ones used when the checkpoint was trained.
    prior_params: dict[str, float | bool | str] = {
        "prior_mu": 0.0,
        "prior_sigma": 1.0,
        "posterior_mu_init": 0.0,
        "posterior_rho_init": -3.0,
        "type": "Reparameterization",
        "moped_enable": False,
        "moped_delta": 0.5,
    }
    dnn_to_bnn(model, prior_params)
    # The conversion may create new layers; move everything to the target device.
    model.to(device)

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {model_path} did not match the Bayesian CNN3D exactly: "
            f"missing={list(load_result.missing_keys)}, "
            f"unexpected={list(load_result.unexpected_keys)}",
            RuntimeWarning,
            stacklevel=2,
        )
    _enable_bayesian_sampling_mode(model)
    return model
def _infer_with_noise_ensemble(
    test_loader: torch.utils.data.DataLoader,
    models: list[torch.nn.Module],
    sigma: float,
    class_index: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run every ensemble member on noise-perturbed test inputs.

    For each batch, one noisy copy of the MRI is drawn (noise level
    ``sigma``) and fed to every model together with the clinical features.

    Returns:
        ``(y_true, y_prob, y_std)``, one entry per test sample:

        - ``y_true`` is ``labels[:, class_index]`` cast to ``int``. This
          presumably means labels are per-class indicators; confirm.
        - ``y_prob`` is the mean over models of ``output[:, class_index]``.
        - ``y_std`` is the population (ddof=0) std of that value across
          models.

    Raises:
        ValueError: if ``models`` is empty.
    """
    if not models:
        raise ValueError("No ensemble models were provided for noise inference")
    device = next(models[0].parameters()).device
    all_probs: list[float] = []
    all_stds: list[float] = []
    all_true: list[int] = []
    with torch.no_grad():
        for mri, xls, labels, _ in test_loader:
            xls_device = xls.float().to(device)
            # One noise draw per batch, shared by all members, so members are
            # compared on identical inputs.
            noisy = _apply_scaled_noise(mri.float().to(device), sigma)
            preds = [
                model((noisy, xls_device))[:, class_index].detach().cpu().numpy()
                for model in models
            ]
            pred_mat = np.stack(preds, axis=0)  # (n_models, batch)
            # Labels are only needed on the host; skip the device round trip.
            true = labels[:, class_index].detach().cpu().numpy().astype(int)
            all_probs.extend(pred_mat.mean(axis=0).tolist())
            all_stds.extend(pred_mat.std(axis=0).tolist())
            all_true.extend(true.tolist())
    return np.asarray(all_true), np.asarray(all_probs), np.asarray(all_stds)
def _infer_with_noise_bayesian(
    test_loader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    sigma: float,
    class_index: int,
    mc_passes: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run ``mc_passes`` stochastic forward passes on noise-perturbed test inputs.

    For each batch, one noisy copy of the MRI is drawn (noise level
    ``sigma``) and reused for all ``mc_passes`` forward passes.

    Returns:
        ``(y_true, y_prob, y_entropy)``, one entry per test sample:

        - ``y_true`` is ``labels[:, class_index]`` cast to ``int``.
        - ``y_prob`` is the MC mean of ``output[:, class_index]``.
        - ``y_entropy`` is bayesian_torch's ``predictive_entropy`` of the
          MC draws.

    Raises:
        ValueError: if ``mc_passes`` is less than 1.
    """
    if mc_passes < 1:
        raise ValueError(f"mc_passes must be at least 1, got {mc_passes}")
    device = next(model.parameters()).device
    all_probs: list[float] = []
    all_entropies: list[float] = []
    all_true: list[int] = []
    with torch.no_grad():
        for mri, xls, labels, _ in test_loader:
            xls_device = xls.float().to(device)
            # One noise draw per batch, reused across all MC passes.
            noisy = _apply_scaled_noise(mri.float().to(device), sigma)
            draws = [
                model((noisy, xls_device)).detach().cpu().numpy()
                for _ in range(mc_passes)
            ]
            # Shape (mc_passes, *output.shape); the indexing below assumes the
            # output is (batch, classes).
            draw_mat = np.stack(draws, axis=0)
            mean = draw_mat.mean(axis=0)[:, class_index]
            entropy_uncertainty = predictive_entropy(draw_mat)
            # Labels are only needed on the host; skip the device round trip.
            true = labels[:, class_index].detach().cpu().numpy().astype(int)
            all_probs.extend(mean.tolist())
            all_entropies.extend(np.asarray(entropy_uncertainty, dtype=float).tolist())
            all_true.extend(true.tolist())
    return np.asarray(all_true), np.asarray(all_probs), np.asarray(all_entropies)
def _summarize_noise_level(
    uncertainty_metric: str,
    sigma: float,
    y_true: np.ndarray,
    y_prob: np.ndarray,
    y_uncertainty: np.ndarray,
    threshold: float,
    calibration_bins: int,
) -> dict[str, Any]:
    """Build one ``noise_sensitivity.csv`` row for a single noise level."""
    perf = performance_at_threshold(y_true, y_prob, threshold)
    cal, _ = calibration_stats(y_true, y_prob, bins=calibration_bins)
    mean_uncertainty = float(np.nanmean(y_uncertainty))
    return {
        "uncertainty_metric": uncertainty_metric,
        "sigma": float(sigma),
        "accuracy": float(perf["accuracy"]),
        "f1": float(perf["f1"]),
        "ece": float(cal["ece"]),
        "mce": float(cal["mce"]),
        "mean_confidence_certainty": float(
            np.nanmean(_confidence_certainty(y_prob))
        ),
        "mean_confidence_uncertainty": float(
            np.nanmean(_confidence_uncertainty(y_prob))
        ),
        # Compatibility field name retained for downstream code: it holds the
        # ensemble std, or the predictive entropy for the Bayesian backend.
        "mean_std": mean_uncertainty,
        "mean_predictive_entropy": (
            mean_uncertainty
            if uncertainty_metric == "predictive_entropy"
            else float("nan")
        ),
    }


def _save_backend_noise_examples(
    test_loader: torch.utils.data.DataLoader,
    device: torch.device,
    noise_sigmas: list[float],
    examples_dir: Path,
    backend: str,
) -> None:
    """Save an original-vs-noisy slice grid from the loader's first batch, if any."""
    first_batch = next(iter(test_loader), None)
    if first_batch is None:
        return  # Empty loader: nothing to illustrate.
    original_mri = first_batch[0]
    with torch.no_grad():
        original_device = original_mri.float().to(device)
        example_rows = [
            (float(sigma), _apply_scaled_noise(original_device, sigma).detach().cpu())
            for sigma in noise_sigmas
        ]
    _save_noise_example_grid(
        original_mri=original_mri,
        noisy_by_sigma=example_rows,
        output_path=examples_dir / f"{backend}_noise_examples.png",
        title=f"{backend.title()} Noise Examples",
    )


def _save_line_plot(
    df: pd.DataFrame,
    series: list[tuple[str, str, str]],
    ylabel: str,
    title: str,
    output_path: Path,
) -> None:
    """Plot ``(column, marker, label)`` series against ``df["sigma"]`` and save it."""
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        for column, marker, label in series:
            ax.plot(df["sigma"], df[column], marker=marker, label=label)
        ax.set_xlabel("Gaussian Noise Sigma")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        ax.legend()
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


def run_noise_analysis(
    config: dict[str, Any],
    root_dir: Path,
    backend: str,
    output_dir: Path,
    class_index: int,
    noise_sigmas: list[float],
    threshold: float,
    calibration_bins: int,
    bayesian_mc_passes: int,
) -> dict[str, Any]:
    """Measure how performance, calibration and uncertainty degrade with input noise.

    For each sigma in ``noise_sigmas``, the test split is re-evaluated on
    Gaussian-perturbed MRIs using the chosen ``backend``:

    - ``"ensemble"``: member std is the uncertainty metric.
    - ``"bayesian"``: MC predictive entropy is the uncertainty metric.

    Outputs written to ``output_dir``:

    - a CSV table;
    - three line plots (scores, uncertainty, confidence certainty);
    - an example-image grid under ``noise_examples/``;
    - a JSON summary.

    Args:
        config: Experiment configuration (data, training and output sections).
        root_dir: Base directory that data paths in ``config`` are relative to.
        backend: ``"ensemble"`` or ``"bayesian"``.
        output_dir: Directory receiving all artifacts.
        class_index: Output column / label column evaluated.
        noise_sigmas: Dimensionless noise levels to test.
        threshold: Decision threshold passed to ``performance_at_threshold``.
        calibration_bins: Number of bins for ``calibration_stats``.
        bayesian_mc_passes: MC forward passes per batch (Bayesian backend only).

    Returns:
        A dict of artifact paths plus the tested ``noise_sigmas``. The same
        dict is also written to ``noise_summary.json``.

    Raises:
        ValueError: for an unsupported backend or an empty ``noise_sigmas``.
            Both are checked before any data or models are loaded.
    """
    if backend not in ("ensemble", "bayesian"):
        raise ValueError(f"Unsupported backend for noise analysis: {backend}")
    if not noise_sigmas:
        raise ValueError("noise_sigmas must contain at least one sigma value")

    test_loader = _build_test_loader(config, root_dir)
    examples_dir = output_dir / "noise_examples"
    examples_dir.mkdir(parents=True, exist_ok=True)

    rows: list[dict[str, Any]] = []
    if backend == "ensemble":
        uncertainty_metric = "std"
        models = _load_ensemble_models(config)
        device = next(models[0].parameters()).device
        for sigma in noise_sigmas:
            y_true, y_prob, y_unc = _infer_with_noise_ensemble(
                test_loader,
                models,
                sigma,
                class_index=class_index,
            )
            rows.append(
                _summarize_noise_level(
                    uncertainty_metric,
                    sigma,
                    y_true,
                    y_prob,
                    y_unc,
                    threshold,
                    calibration_bins,
                )
            )
    else:
        uncertainty_metric = "predictive_entropy"
        model = _load_bayesian_model(config)
        device = next(model.parameters()).device
        for sigma in noise_sigmas:
            y_true, y_prob, y_unc = _infer_with_noise_bayesian(
                test_loader,
                model,
                sigma,
                class_index=class_index,
                mc_passes=bayesian_mc_passes,
            )
            rows.append(
                _summarize_noise_level(
                    uncertainty_metric,
                    sigma,
                    y_true,
                    y_prob,
                    y_unc,
                    threshold,
                    calibration_bins,
                )
            )

    _save_backend_noise_examples(test_loader, device, noise_sigmas, examples_dir, backend)

    df = pd.DataFrame(rows).sort_values("sigma")
    csv_path = output_dir / "noise_sensitivity.csv"
    df.to_csv(csv_path, index=False)

    plot_path = output_dir / "noise_sensitivity.png"
    _save_line_plot(
        df,
        [("accuracy", "o", "accuracy"), ("f1", "s", "f1"), ("ece", "^", "ece")],
        ylabel="Score",
        title=f"Noise Sensitivity ({backend})",
        output_path=plot_path,
    )

    uncertainty_plot_path = output_dir / "noise_uncertainty.png"
    _save_line_plot(
        df,
        [
            ("mean_confidence_uncertainty", "o", "confidence_uncertainty"),
            # Label by the real metric: "mean_std" holds entropy for Bayesian runs.
            ("mean_std", "s", f"{uncertainty_metric}_uncertainty"),
        ],
        ylabel="Uncertainty",
        title=f"Uncertainty vs Noise ({backend})",
        output_path=uncertainty_plot_path,
    )

    certainty_plot_path = output_dir / "noise_confidence_certainty.png"
    _save_line_plot(
        df,
        [("mean_confidence_certainty", "o", "confidence_certainty")],
        ylabel="Certainty",
        title=f"Confidence Certainty vs Noise ({backend})",
        output_path=certainty_plot_path,
    )

    out = {
        "table": str(csv_path),
        "plot": str(plot_path),
        "uncertainty_plot": str(uncertainty_plot_path),
        "certainty_plot": str(certainty_plot_path),
        "noise_sigmas": noise_sigmas,
    }
    write_json(output_dir / "noise_summary.json", out)
    return out
|