noise_analysis.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
# pyright: basic
from __future__ import annotations

import warnings
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import torch
from bayesian_torch.utils.util import predictive_entropy
from model.cnn import CNN3D

from .data_pipeline import build_holdout_loader
from .metrics import calibration_stats, performance_at_threshold
from .model_utils import configure_bayesian_sampling_mode
from .plotting import (
    plots_dir,
    save_clean_scan_image,
    save_noise_example_grid,
    save_metric_pair_plot,
    save_noise_metrics_plot,
)
from .runtime import write_json
  21. def _apply_scaled_noise(
  22. volume: torch.Tensor, sigma: float, intensity_range: float
  23. ) -> torch.Tensor:
  24. # Scale by global MRI intensity range measured from holdout set.
  25. return volume + (torch.randn_like(volume) * sigma * intensity_range)
  26. def _uniform_sigma_schedule(noise_sigmas: list[float]) -> list[float]:
  27. if not noise_sigmas:
  28. raise ValueError("noise_sigmas must contain at least one value")
  29. ordered = np.array(sorted(float(s) for s in noise_sigmas), dtype=float)
  30. if len(ordered) == 1:
  31. return [float(ordered[0])]
  32. uniform = np.linspace(
  33. float(ordered[0]), float(ordered[-1]), num=len(ordered), dtype=float
  34. )
  35. return [float(s) for s in uniform]
  36. def _load_ensemble_models(config: dict[str, Any]) -> list[torch.nn.Module]:
  37. model_dir = Path(config["output"]["ensemble_path"])
  38. model_files = sorted(model_dir.glob("model_run_*.pt"))
  39. if not model_files:
  40. raise FileNotFoundError(f"No ensemble model files found in {model_dir}")
  41. models: list[torch.nn.Module] = []
  42. for model_file in model_files:
  43. model = (
  44. CNN3D(
  45. image_channels=int(config["data"]["image_channels"]),
  46. clin_data_channels=int(config["data"]["clin_data_channels"]),
  47. num_classes=int(config["data"]["num_classes"]),
  48. droprate=float(config["training"]["droprate"]),
  49. )
  50. .float()
  51. .to(config["training"]["device"])
  52. )
  53. model.load_state_dict(
  54. torch.load(model_file, map_location=config["training"]["device"]),
  55. strict=False,
  56. )
  57. model.eval()
  58. models.append(model)
  59. return models
  60. def _load_bayesian_model(config: dict[str, Any]) -> torch.nn.Module:
  61. device = str(config["training"]["device"])
  62. try:
  63. from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn # type: ignore[import-untyped]
  64. except ImportError as e:
  65. raise ImportError(
  66. "bayesian_torch is required for bayesian noise analysis"
  67. ) from e
  68. model_path = Path(config["output"]["bayesian_path"]) / "model_bayesian.pt"
  69. if not model_path.exists():
  70. raise FileNotFoundError(f"Bayesian model checkpoint not found: {model_path}")
  71. model = (
  72. CNN3D(
  73. image_channels=int(config["data"]["image_channels"]),
  74. clin_data_channels=int(config["data"]["clin_data_channels"]),
  75. num_classes=int(config["data"]["num_classes"]),
  76. droprate=float(config["training"]["droprate"]),
  77. )
  78. .float()
  79. .to(config["training"]["device"])
  80. )
  81. prior_params: dict[str, float | bool | str] = {
  82. "prior_mu": 0.0,
  83. "prior_sigma": 1.0,
  84. "posterior_mu_init": 0.0,
  85. "posterior_rho_init": -3.0,
  86. "type": "Reparameterization",
  87. "moped_enable": False,
  88. "moped_delta": 0.5,
  89. }
  90. dnn_to_bnn(model, prior_params)
  91. model.to(device)
  92. model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
  93. model.to(device)
  94. configure_bayesian_sampling_mode(model, stochastic=False)
  95. return model
  96. def _infer_with_noise_ensemble(
  97. test_loader: torch.utils.data.DataLoader,
  98. models: list[torch.nn.Module],
  99. sigma: float,
  100. intensity_range: float,
  101. class_index: int,
  102. ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
  103. if not models:
  104. raise ValueError("No ensemble models were provided for noise inference")
  105. device = next(models[0].parameters()).device
  106. all_probs: list[float] = []
  107. all_confidence: list[float] = []
  108. all_stds: list[float] = []
  109. all_true: list[int] = []
  110. with torch.no_grad():
  111. for mri, xls, labels, _ in test_loader:
  112. mri_device = mri.float().to(device)
  113. xls_device = xls.float().to(device)
  114. labels_device = labels.to(device)
  115. noisy = _apply_scaled_noise(mri_device, sigma, intensity_range)
  116. preds = []
  117. for model in models:
  118. out = model((noisy, xls_device))
  119. preds.append(out[:, class_index].detach().cpu().numpy())
  120. pred_mat = np.stack(preds, axis=0)
  121. mean = pred_mat.mean(axis=0)
  122. confidence = np.abs(pred_mat - 0.5).mean(axis=0)
  123. std = pred_mat.std(axis=0)
  124. true = labels_device[:, class_index].detach().cpu().numpy().astype(int)
  125. all_probs.extend(mean.tolist())
  126. all_confidence.extend(confidence.tolist())
  127. all_stds.extend(std.tolist())
  128. all_true.extend(true.tolist())
  129. return (
  130. np.asarray(all_true),
  131. np.asarray(all_probs),
  132. np.asarray(all_confidence),
  133. np.asarray(all_stds),
  134. )
  135. def _infer_with_noise_bayesian(
  136. test_loader: torch.utils.data.DataLoader,
  137. model: torch.nn.Module,
  138. sigma: float,
  139. intensity_range: float,
  140. class_index: int,
  141. mc_passes: int,
  142. ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
  143. device = next(model.parameters()).device
  144. all_probs: list[float] = []
  145. all_confidence: list[float] = []
  146. all_stds: list[float] = []
  147. all_true: list[int] = []
  148. with torch.no_grad():
  149. for mri, xls, labels, _ in test_loader:
  150. mri_device = mri.float().to(device)
  151. xls_device = xls.float().to(device)
  152. labels_device = labels.to(device)
  153. noisy = _apply_scaled_noise(mri_device, sigma, intensity_range)
  154. draws = []
  155. for _ in range(mc_passes):
  156. out = model((noisy, xls_device))
  157. draws.append(out.detach().cpu().numpy())
  158. draw_mat = np.stack(draws, axis=0) # (mc_passes, batch, classes)
  159. mean = draw_mat.mean(axis=0)[:, class_index]
  160. confidence = np.abs(draw_mat[:, :, class_index] - 0.5).mean(axis=0)
  161. entropy_uncertainty = predictive_entropy(draw_mat)
  162. true = labels_device[:, class_index].detach().cpu().numpy().astype(int)
  163. all_probs.extend(mean.tolist())
  164. all_confidence.extend(np.asarray(confidence, dtype=float).tolist())
  165. all_stds.extend(np.asarray(entropy_uncertainty, dtype=float).tolist())
  166. all_true.extend(true.tolist())
  167. return (
  168. np.asarray(all_true),
  169. np.asarray(all_probs),
  170. np.asarray(all_confidence),
  171. np.asarray(all_stds),
  172. )
  173. def run_noise_analysis(
  174. config: dict[str, Any],
  175. root_dir: Path,
  176. backend: str,
  177. output_dir: Path,
  178. class_index: int,
  179. noise_sigmas: list[float],
  180. threshold: float,
  181. calibration_bins: int,
  182. bayesian_mc_passes: int,
  183. ) -> dict[str, Any]:
  184. noise_sigmas = _uniform_sigma_schedule(noise_sigmas)
  185. test_loader = build_holdout_loader(
  186. config=config,
  187. root_dir=root_dir,
  188. seed=int(config["data"]["seed"]),
  189. )
  190. # intensity_min, intensity_max, intensity_range = _compute_mri_intensity_range(
  191. # dataset
  192. # )
  193. # Just use a fixed intensity range for noise scaling since all that matters is that it's consistent
  194. intensity_range = 10_000.0
  195. out_plots_dir = plots_dir(output_dir)
  196. examples_dir = out_plots_dir / "noise_examples"
  197. examples_dir.mkdir(parents=True, exist_ok=True)
  198. rows: list[dict[str, Any]] = []
  199. if backend == "ensemble":
  200. models = _load_ensemble_models(config)
  201. example_rows: list[tuple[float, torch.Tensor]] = []
  202. for sigma in noise_sigmas:
  203. y_true, y_prob, y_confidence, y_std = _infer_with_noise_ensemble(
  204. test_loader,
  205. models,
  206. sigma,
  207. intensity_range,
  208. class_index=class_index,
  209. )
  210. perf = performance_at_threshold(y_true, y_prob, threshold)
  211. cal, _ = calibration_stats(y_true, y_prob, bins=calibration_bins)
  212. rows.append(
  213. {
  214. "uncertainty_metric": "std",
  215. "noise_factor": float(sigma),
  216. "accuracy": float(perf["accuracy"]),
  217. "f1": float(perf["f1"]),
  218. "mce": float(cal["mce"]),
  219. "mean_confidence": float(np.nanmean(y_confidence)),
  220. "mean_model_output_probability": float(np.nanmean(y_prob)),
  221. "mean_std": float(np.nanmean(y_std)),
  222. "mean_predictive_entropy": float("nan"),
  223. "mri_intensity_range": float(intensity_range),
  224. }
  225. )
  226. with torch.no_grad():
  227. sample = next(iter(test_loader))
  228. original_mri = sample[0]
  229. device = next(models[0].parameters()).device
  230. original_device = original_mri.float().to(device)
  231. for sigma in noise_sigmas:
  232. noisy_mri = _apply_scaled_noise(original_device, sigma, intensity_range)
  233. example_rows.append((float(sigma), noisy_mri.detach().cpu()))
  234. save_noise_example_grid(
  235. original_mri=original_mri,
  236. noisy_by_sigma=example_rows,
  237. output_path=examples_dir / f"{backend}_noise_examples.png",
  238. title=f"{backend.title()} Noise Examples",
  239. max_images=9,
  240. n_rows=2,
  241. )
  242. save_clean_scan_image(
  243. original_mri=original_mri,
  244. output_path=examples_dir / f"{backend}_clean_scan_example.png",
  245. )
  246. elif backend == "bayesian":
  247. model = _load_bayesian_model(config)
  248. example_rows = []
  249. for sigma in noise_sigmas:
  250. y_true, y_prob, y_confidence, y_std = _infer_with_noise_bayesian(
  251. test_loader,
  252. model,
  253. sigma,
  254. intensity_range,
  255. class_index=class_index,
  256. mc_passes=bayesian_mc_passes,
  257. )
  258. perf = performance_at_threshold(y_true, y_prob, threshold)
  259. cal, _ = calibration_stats(y_true, y_prob, bins=calibration_bins)
  260. rows.append(
  261. {
  262. "uncertainty_metric": "predictive_entropy",
  263. "noise_factor": float(sigma),
  264. "accuracy": float(perf["accuracy"]),
  265. "f1": float(perf["f1"]),
  266. "mce": float(cal["mce"]),
  267. "mean_confidence": float(np.nanmean(y_confidence)),
  268. "mean_model_output_probability": float(np.nanmean(y_prob)),
  269. # Compatibility field name retained for downstream code.
  270. "mean_std": float(np.nanmean(y_std)),
  271. "mean_predictive_entropy": float(np.nanmean(y_std)),
  272. "mri_intensity_range": float(intensity_range),
  273. }
  274. )
  275. with torch.no_grad():
  276. sample = next(iter(test_loader))
  277. original_mri = sample[0]
  278. device = next(model.parameters()).device
  279. original_device = original_mri.float().to(device)
  280. for sigma in noise_sigmas:
  281. noisy_mri = _apply_scaled_noise(original_device, sigma, intensity_range)
  282. example_rows.append((float(sigma), noisy_mri.detach().cpu()))
  283. save_noise_example_grid(
  284. original_mri=original_mri,
  285. noisy_by_sigma=example_rows,
  286. output_path=examples_dir / f"{backend}_noise_examples.png",
  287. title=f"{backend.title()} Noise Examples",
  288. max_images=9,
  289. n_rows=2,
  290. )
  291. save_clean_scan_image(
  292. original_mri=original_mri,
  293. output_path=examples_dir / f"{backend}_clean_scan_example.png",
  294. )
  295. else:
  296. raise ValueError(f"Unsupported backend for noise analysis: {backend}")
  297. df = pd.DataFrame(rows).sort_values("noise_factor")
  298. csv_path = output_dir / "noise_sensitivity.csv"
  299. df.to_csv(csv_path, index=False)
  300. accuracy_plot_path = out_plots_dir / "noise_sensitivity_accuracy.png"
  301. f1_plot_path = out_plots_dir / "noise_sensitivity_f1.png"
  302. pair_plot_path = out_plots_dir / "noise_sensitivity_accuracy_f1.png"
  303. confidence_plot_path = out_plots_dir / "noise_confidence.png"
  304. confidence_uncertainty_pair_path = (
  305. out_plots_dir / "noise_confidence_predictive_uncertainty.png"
  306. if backend == "bayesian"
  307. else out_plots_dir / "noise_confidence_standard_deviation.png"
  308. )
  309. secondary_plot_name = (
  310. "noise_predictive_uncertainty.png"
  311. if backend == "bayesian"
  312. else "noise_standard_deviation.png"
  313. )
  314. secondary_plot_path = out_plots_dir / secondary_plot_name
  315. save_noise_metrics_plot(
  316. x=df["noise_factor"],
  317. y=df["accuracy"],
  318. legend_label="Accuracy",
  319. marker="o",
  320. x_label="Gaussian Noise Factor",
  321. y_label="Accuracy",
  322. title=f"Accuracy vs Noise ({backend})",
  323. output_path=accuracy_plot_path,
  324. plot_key="noise_sensitivity_accuracy",
  325. )
  326. save_noise_metrics_plot(
  327. x=df["noise_factor"],
  328. y=df["f1"],
  329. legend_label="F1",
  330. marker="s",
  331. x_label="Gaussian Noise Factor",
  332. y_label="F1",
  333. title=f"F1 vs Noise ({backend})",
  334. output_path=f1_plot_path,
  335. plot_key="noise_sensitivity_f1",
  336. )
  337. save_metric_pair_plot(
  338. x=df["noise_factor"],
  339. left_y=df["accuracy"],
  340. right_y=df["f1"],
  341. left_label="Accuracy",
  342. right_label="F1",
  343. x_label="Gaussian Noise Factor",
  344. y_label="Accuracy/F1",
  345. title=f"Accuracy and F1 vs Noise ({backend})",
  346. output_path=pair_plot_path,
  347. plot_key="noise_sensitivity_accuracy_f1",
  348. )
  349. secondary_label = (
  350. "Predictive Uncertainty" if backend == "bayesian" else "Standard Deviation"
  351. )
  352. save_noise_metrics_plot(
  353. x=df["noise_factor"],
  354. y=df["mean_confidence"],
  355. legend_label="Confidence",
  356. marker="o",
  357. x_label="Gaussian Noise Factor",
  358. y_label="Confidence",
  359. title=f"Confidence vs Noise ({backend})",
  360. output_path=confidence_plot_path,
  361. plot_key="noise_confidence",
  362. )
  363. save_noise_metrics_plot(
  364. x=df["noise_factor"],
  365. y=df["mean_std"],
  366. legend_label=secondary_label,
  367. marker="^",
  368. x_label="Gaussian Noise Factor",
  369. y_label=secondary_label,
  370. title=f"{secondary_label} vs Noise ({backend})",
  371. output_path=secondary_plot_path,
  372. plot_key=(
  373. "noise_predictive_uncertainty"
  374. if backend == "bayesian"
  375. else "noise_standard_deviation"
  376. ),
  377. )
  378. save_metric_pair_plot(
  379. x=df["noise_factor"],
  380. left_y=df["mean_confidence"],
  381. right_y=df["mean_std"],
  382. left_label="Confidence",
  383. right_label=secondary_label,
  384. x_label="Gaussian Noise Factor",
  385. y_label="Confidence/Uncertainty",
  386. title=f"Confidence and {secondary_label} vs Noise ({backend})",
  387. output_path=confidence_uncertainty_pair_path,
  388. plot_key=(
  389. "noise_confidence_predictive_uncertainty"
  390. if backend == "bayesian"
  391. else "noise_confidence_standard_deviation"
  392. ),
  393. )
  394. out = {
  395. "table": str(csv_path),
  396. "plots": {
  397. "accuracy": str(accuracy_plot_path),
  398. "f1": str(f1_plot_path),
  399. "accuracy_f1": str(pair_plot_path),
  400. "confidence": str(confidence_plot_path),
  401. (
  402. "confidence_predictive_uncertainty"
  403. if backend == "bayesian"
  404. else "confidence_standard_deviation"
  405. ): str(confidence_uncertainty_pair_path),
  406. (
  407. "predictive_uncertainty"
  408. if backend == "bayesian"
  409. else "standard_deviation"
  410. ): str(secondary_plot_path),
  411. },
  412. "noise_factors": noise_sigmas,
  413. "noise_sigmas": noise_sigmas,
  414. # "mri_intensity_min": float(intensity_min),
  415. # "mri_intensity_max": float(intensity_max),
  416. "mri_intensity_range": float(intensity_range),
  417. }
  418. write_json(output_dir / "noise_summary.json", out)
  419. return out