noise_analysis.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. # pyright: basic
  2. from __future__ import annotations
  3. from pathlib import Path
  4. from typing import Any
  5. import numpy as np
  6. import pandas as pd
  7. import torch
  8. from bayesian_torch.utils.util import predictive_entropy
  9. from model.cnn import CNN3D
  10. from .data_pipeline import build_dataset_and_test_loader
  11. from .metrics import calibration_stats, performance_at_threshold
  12. from .model_utils import configure_bayesian_sampling_mode
  13. from .plotting import (
  14. plots_dir,
  15. save_noise_example_grid,
  16. save_noise_metrics_plot,
  17. )
  18. from .runtime import write_json
  19. from .uncertainty import confidence_certainty, confidence_uncertainty
  20. def _apply_scaled_noise(
  21. volume: torch.Tensor, sigma: float, intensity_range: float
  22. ) -> torch.Tensor:
  23. # Scale by global MRI intensity range measured from holdout set.
  24. return volume + (torch.randn_like(volume) * sigma * intensity_range)
  25. def _uniform_sigma_schedule(noise_sigmas: list[float]) -> list[float]:
  26. if not noise_sigmas:
  27. raise ValueError("noise_sigmas must contain at least one value")
  28. ordered = np.array(sorted(float(s) for s in noise_sigmas), dtype=float)
  29. if len(ordered) == 1:
  30. return [float(ordered[0])]
  31. uniform = np.linspace(
  32. float(ordered[0]), float(ordered[-1]), num=len(ordered), dtype=float
  33. )
  34. return [float(s) for s in uniform]
  35. def _load_ensemble_models(config: dict[str, Any]) -> list[torch.nn.Module]:
  36. model_dir = Path(config["output"]["ensemble_path"])
  37. model_files = sorted(model_dir.glob("model_run_*.pt"))
  38. if not model_files:
  39. raise FileNotFoundError(f"No ensemble model files found in {model_dir}")
  40. models: list[torch.nn.Module] = []
  41. for model_file in model_files:
  42. model = (
  43. CNN3D(
  44. image_channels=int(config["data"]["image_channels"]),
  45. clin_data_channels=int(config["data"]["clin_data_channels"]),
  46. num_classes=int(config["data"]["num_classes"]),
  47. droprate=float(config["training"]["droprate"]),
  48. )
  49. .float()
  50. .to(config["training"]["device"])
  51. )
  52. model.load_state_dict(
  53. torch.load(model_file, map_location=config["training"]["device"]),
  54. strict=False,
  55. )
  56. model.eval()
  57. models.append(model)
  58. return models
  59. def _load_bayesian_model(config: dict[str, Any]) -> torch.nn.Module:
  60. device = str(config["training"]["device"])
  61. try:
  62. from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn # type: ignore[import-untyped]
  63. except ImportError as e:
  64. raise ImportError(
  65. "bayesian_torch is required for bayesian noise analysis"
  66. ) from e
  67. model_path = Path(config["output"]["bayesian_path"]) / "model_bayesian.pt"
  68. if not model_path.exists():
  69. raise FileNotFoundError(f"Bayesian model checkpoint not found: {model_path}")
  70. model = (
  71. CNN3D(
  72. image_channels=int(config["data"]["image_channels"]),
  73. clin_data_channels=int(config["data"]["clin_data_channels"]),
  74. num_classes=int(config["data"]["num_classes"]),
  75. droprate=float(config["training"]["droprate"]),
  76. )
  77. .float()
  78. .to(config["training"]["device"])
  79. )
  80. prior_params: dict[str, float | bool | str] = {
  81. "prior_mu": 0.0,
  82. "prior_sigma": 1.0,
  83. "posterior_mu_init": 0.0,
  84. "posterior_rho_init": -3.0,
  85. "type": "Reparameterization",
  86. "moped_enable": False,
  87. "moped_delta": 0.5,
  88. }
  89. dnn_to_bnn(model, prior_params)
  90. model.to(device)
  91. model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
  92. model.to(device)
  93. configure_bayesian_sampling_mode(model, stochastic=False)
  94. return model
  95. def _infer_with_noise_ensemble(
  96. test_loader: torch.utils.data.DataLoader,
  97. models: list[torch.nn.Module],
  98. sigma: float,
  99. intensity_range: float,
  100. class_index: int,
  101. ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
  102. if not models:
  103. raise ValueError("No ensemble models were provided for noise inference")
  104. device = next(models[0].parameters()).device
  105. all_probs: list[float] = []
  106. all_stds: list[float] = []
  107. all_true: list[int] = []
  108. with torch.no_grad():
  109. for mri, xls, labels, _ in test_loader:
  110. mri_device = mri.float().to(device)
  111. xls_device = xls.float().to(device)
  112. labels_device = labels.to(device)
  113. noisy = _apply_scaled_noise(mri_device, sigma, intensity_range)
  114. preds = []
  115. for model in models:
  116. out = model((noisy, xls_device))
  117. preds.append(out[:, class_index].detach().cpu().numpy())
  118. pred_mat = np.stack(preds, axis=0)
  119. mean = pred_mat.mean(axis=0)
  120. std = pred_mat.std(axis=0)
  121. true = labels_device[:, class_index].detach().cpu().numpy().astype(int)
  122. all_probs.extend(mean.tolist())
  123. all_stds.extend(std.tolist())
  124. all_true.extend(true.tolist())
  125. return np.asarray(all_true), np.asarray(all_probs), np.asarray(all_stds)
  126. def _infer_with_noise_bayesian(
  127. test_loader: torch.utils.data.DataLoader,
  128. model: torch.nn.Module,
  129. sigma: float,
  130. intensity_range: float,
  131. class_index: int,
  132. mc_passes: int,
  133. ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
  134. device = next(model.parameters()).device
  135. all_probs: list[float] = []
  136. all_stds: list[float] = []
  137. all_true: list[int] = []
  138. with torch.no_grad():
  139. for mri, xls, labels, _ in test_loader:
  140. mri_device = mri.float().to(device)
  141. xls_device = xls.float().to(device)
  142. labels_device = labels.to(device)
  143. noisy = _apply_scaled_noise(mri_device, sigma, intensity_range)
  144. draws = []
  145. for _ in range(mc_passes):
  146. out = model((noisy, xls_device))
  147. draws.append(out.detach().cpu().numpy())
  148. draw_mat = np.stack(draws, axis=0) # (mc_passes, batch, classes)
  149. mean = draw_mat.mean(axis=0)[:, class_index]
  150. entropy_uncertainty = predictive_entropy(draw_mat)
  151. true = labels_device[:, class_index].detach().cpu().numpy().astype(int)
  152. all_probs.extend(mean.tolist())
  153. all_stds.extend(np.asarray(entropy_uncertainty, dtype=float).tolist())
  154. all_true.extend(true.tolist())
  155. return np.asarray(all_true), np.asarray(all_probs), np.asarray(all_stds)
  156. def run_noise_analysis(
  157. config: dict[str, Any],
  158. root_dir: Path,
  159. backend: str,
  160. output_dir: Path,
  161. class_index: int,
  162. noise_sigmas: list[float],
  163. threshold: float,
  164. calibration_bins: int,
  165. bayesian_mc_passes: int,
  166. ) -> dict[str, Any]:
  167. noise_sigmas = _uniform_sigma_schedule(noise_sigmas)
  168. dataset, test_loader = build_dataset_and_test_loader(
  169. config=config,
  170. root_dir=root_dir,
  171. seed=int(config["data"]["seed"]),
  172. )
  173. # intensity_min, intensity_max, intensity_range = _compute_mri_intensity_range(
  174. # dataset
  175. # )
  176. # Just use a fixed intensity range for noise scaling since all that matters is that it's consistent
  177. intensity_range = 10_000.0
  178. out_plots_dir = plots_dir(output_dir)
  179. examples_dir = out_plots_dir / "noise_examples"
  180. examples_dir.mkdir(parents=True, exist_ok=True)
  181. rows: list[dict[str, Any]] = []
  182. if backend == "ensemble":
  183. models = _load_ensemble_models(config)
  184. example_rows: list[tuple[float, torch.Tensor]] = []
  185. for sigma in noise_sigmas:
  186. y_true, y_prob, y_std = _infer_with_noise_ensemble(
  187. test_loader,
  188. models,
  189. sigma,
  190. intensity_range,
  191. class_index=class_index,
  192. )
  193. perf = performance_at_threshold(y_true, y_prob, threshold)
  194. cal, _ = calibration_stats(y_true, y_prob, bins=calibration_bins)
  195. rows.append(
  196. {
  197. "uncertainty_metric": "std",
  198. "noise_factor": float(sigma),
  199. "accuracy": float(perf["accuracy"]),
  200. "f1": float(perf["f1"]),
  201. "ece": float(cal["ece"]),
  202. "mce": float(cal["mce"]),
  203. "mean_confidence_certainty": float(
  204. np.nanmean(confidence_certainty(y_prob))
  205. ),
  206. "mean_confidence_uncertainty": float(
  207. np.nanmean(confidence_uncertainty(y_prob))
  208. ),
  209. "mean_std": float(np.nanmean(y_std)),
  210. "mean_predictive_entropy": float("nan"),
  211. "mri_intensity_range": float(intensity_range),
  212. }
  213. )
  214. with torch.no_grad():
  215. sample = next(iter(test_loader))
  216. original_mri = sample[0]
  217. device = next(models[0].parameters()).device
  218. original_device = original_mri.float().to(device)
  219. for sigma in noise_sigmas:
  220. noisy_mri = _apply_scaled_noise(original_device, sigma, intensity_range)
  221. example_rows.append((float(sigma), noisy_mri.detach().cpu()))
  222. save_noise_example_grid(
  223. original_mri=original_mri,
  224. noisy_by_sigma=example_rows,
  225. output_path=examples_dir / f"{backend}_noise_examples.png",
  226. title=f"{backend.title()} Noise Examples",
  227. )
  228. elif backend == "bayesian":
  229. model = _load_bayesian_model(config)
  230. example_rows = []
  231. for sigma in noise_sigmas:
  232. y_true, y_prob, y_std = _infer_with_noise_bayesian(
  233. test_loader,
  234. model,
  235. sigma,
  236. intensity_range,
  237. class_index=class_index,
  238. mc_passes=bayesian_mc_passes,
  239. )
  240. perf = performance_at_threshold(y_true, y_prob, threshold)
  241. cal, _ = calibration_stats(y_true, y_prob, bins=calibration_bins)
  242. rows.append(
  243. {
  244. "uncertainty_metric": "predictive_entropy",
  245. "noise_factor": float(sigma),
  246. "accuracy": float(perf["accuracy"]),
  247. "f1": float(perf["f1"]),
  248. "ece": float(cal["ece"]),
  249. "mce": float(cal["mce"]),
  250. "mean_confidence_certainty": float(
  251. np.nanmean(confidence_certainty(y_prob))
  252. ),
  253. "mean_confidence_uncertainty": float(
  254. np.nanmean(confidence_uncertainty(y_prob))
  255. ),
  256. # Compatibility field name retained for downstream code.
  257. "mean_std": float(np.nanmean(y_std)),
  258. "mean_predictive_entropy": float(np.nanmean(y_std)),
  259. "mri_intensity_range": float(intensity_range),
  260. }
  261. )
  262. with torch.no_grad():
  263. sample = next(iter(test_loader))
  264. original_mri = sample[0]
  265. device = next(model.parameters()).device
  266. original_device = original_mri.float().to(device)
  267. for sigma in noise_sigmas:
  268. noisy_mri = _apply_scaled_noise(original_device, sigma, intensity_range)
  269. example_rows.append((float(sigma), noisy_mri.detach().cpu()))
  270. save_noise_example_grid(
  271. original_mri=original_mri,
  272. noisy_by_sigma=example_rows,
  273. output_path=examples_dir / f"{backend}_noise_examples.png",
  274. title=f"{backend.title()} Noise Examples",
  275. )
  276. else:
  277. raise ValueError(f"Unsupported backend for noise analysis: {backend}")
  278. df = pd.DataFrame(rows).sort_values("noise_factor")
  279. csv_path = output_dir / "noise_sensitivity.csv"
  280. df.to_csv(csv_path, index=False)
  281. plot_path = out_plots_dir / "noise_sensitivity.png"
  282. uncertainty_plot_path = out_plots_dir / "noise_uncertainty.png"
  283. certainty_plot_path = out_plots_dir / "noise_confidence_certainty.png"
  284. save_noise_metrics_plot(
  285. x=df["noise_factor"],
  286. y_by_label=[
  287. (df["accuracy"], "o", "accuracy"),
  288. (df["f1"], "s", "f1"),
  289. (df["ece"], "^", "ece"),
  290. ],
  291. x_label="Gaussian Noise Factor",
  292. y_label="Score",
  293. title=f"Noise Sensitivity ({backend})",
  294. output_path=plot_path,
  295. )
  296. save_noise_metrics_plot(
  297. x=df["noise_factor"],
  298. y_by_label=[
  299. (df["mean_confidence_uncertainty"], "o", "confidence_uncertainty"),
  300. (df["mean_std"], "s", "std_uncertainty"),
  301. ],
  302. x_label="Gaussian Noise Factor",
  303. y_label="Uncertainty",
  304. title=f"Uncertainty vs Noise ({backend})",
  305. output_path=uncertainty_plot_path,
  306. )
  307. save_noise_metrics_plot(
  308. x=df["noise_factor"],
  309. y_by_label=[
  310. (df["mean_confidence_certainty"], "o", "confidence_certainty"),
  311. ],
  312. x_label="Gaussian Noise Factor",
  313. y_label="Certainty",
  314. title=f"Confidence Certainty vs Noise ({backend})",
  315. output_path=certainty_plot_path,
  316. )
  317. out = {
  318. "table": str(csv_path),
  319. "plot": str(plot_path),
  320. "uncertainty_plot": str(uncertainty_plot_path),
  321. "certainty_plot": str(certainty_plot_path),
  322. "noise_factors": noise_sigmas,
  323. "noise_sigmas": noise_sigmas,
  324. # "mri_intensity_min": float(intensity_min),
  325. # "mri_intensity_max": float(intensity_max),
  326. "mri_intensity_range": float(intensity_range),
  327. }
  328. write_json(output_dir / "noise_summary.json", out)
  329. return out