noise_analysis.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
# pyright: basic
from __future__ import annotations

import warnings
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import torch
from bayesian_torch.utils.util import predictive_entropy
from tqdm.auto import tqdm

from model.cnn import CNN3D

from .data_pipeline import build_holdout_loader
from .metrics import calibration_stats, performance_at_threshold
from .model_utils import configure_bayesian_sampling_mode
from .plotting import (
    plots_dir,
    save_clean_scan_image,
    save_noise_example_grid,
    save_metric_pair_plot,
    save_noise_metrics_plot,
)
from .runtime import write_json
  22. def _apply_scaled_noise(
  23. volume: torch.Tensor, sigma: float, intensity_range: float
  24. ) -> torch.Tensor:
  25. # Scale by global MRI intensity range measured from holdout set.
  26. return volume + (torch.randn_like(volume) * sigma * intensity_range)
  27. def _uniform_sigma_schedule(noise_sigmas: list[float]) -> list[float]:
  28. if not noise_sigmas:
  29. raise ValueError("noise_sigmas must contain at least one value")
  30. ordered = np.array(sorted(float(s) for s in noise_sigmas), dtype=float)
  31. if len(ordered) == 1:
  32. return [float(ordered[0])]
  33. uniform = np.linspace(
  34. float(ordered[0]), float(ordered[-1]), num=len(ordered), dtype=float
  35. )
  36. return [float(s) for s in uniform]
  37. def _load_ensemble_models(config: dict[str, Any]) -> list[torch.nn.Module]:
  38. model_dir = Path(config["output"]["ensemble_path"])
  39. model_files = sorted(model_dir.glob("model_run_*.pt"))
  40. if not model_files:
  41. raise FileNotFoundError(f"No ensemble model files found in {model_dir}")
  42. models: list[torch.nn.Module] = []
  43. for model_file in model_files:
  44. model = (
  45. CNN3D(
  46. image_channels=int(config["data"]["image_channels"]),
  47. clin_data_channels=int(config["data"]["clin_data_channels"]),
  48. num_classes=int(config["data"]["num_classes"]),
  49. droprate=float(config["training"]["droprate"]),
  50. )
  51. .float()
  52. .to(config["training"]["device"])
  53. )
  54. model.load_state_dict(
  55. torch.load(model_file, map_location=config["training"]["device"]),
  56. strict=False,
  57. )
  58. model.eval()
  59. models.append(model)
  60. return models
  61. def _load_bayesian_model(config: dict[str, Any]) -> torch.nn.Module:
  62. device = str(config["training"]["device"])
  63. try:
  64. from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn # type: ignore[import-untyped]
  65. except ImportError as e:
  66. raise ImportError(
  67. "bayesian_torch is required for bayesian noise analysis"
  68. ) from e
  69. model_path = Path(config["output"]["bayesian_path"]) / "model_bayesian.pt"
  70. if not model_path.exists():
  71. raise FileNotFoundError(f"Bayesian model checkpoint not found: {model_path}")
  72. model = (
  73. CNN3D(
  74. image_channels=int(config["data"]["image_channels"]),
  75. clin_data_channels=int(config["data"]["clin_data_channels"]),
  76. num_classes=int(config["data"]["num_classes"]),
  77. droprate=float(config["training"]["droprate"]),
  78. )
  79. .float()
  80. .to(config["training"]["device"])
  81. )
  82. prior_params: dict[str, float | bool | str] = {
  83. "prior_mu": 0.0,
  84. "prior_sigma": 1.0,
  85. "posterior_mu_init": 0.0,
  86. "posterior_rho_init": -3.0,
  87. "type": "Reparameterization",
  88. "moped_enable": False,
  89. "moped_delta": 0.5,
  90. }
  91. dnn_to_bnn(model, prior_params)
  92. model.to(device)
  93. model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
  94. model.to(device)
  95. configure_bayesian_sampling_mode(model, stochastic=False)
  96. return model
  97. def _infer_with_noise_ensemble(
  98. test_loader: torch.utils.data.DataLoader,
  99. models: list[torch.nn.Module],
  100. sigma: float,
  101. intensity_range: float,
  102. class_index: int,
  103. ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
  104. if not models:
  105. raise ValueError("No ensemble models were provided for noise inference")
  106. device = next(models[0].parameters()).device
  107. all_probs: list[float] = []
  108. all_confidence: list[float] = []
  109. all_stds: list[float] = []
  110. all_true: list[int] = []
  111. with torch.no_grad():
  112. batch_iter = tqdm(
  113. test_loader,
  114. total=len(test_loader),
  115. desc=f"ensemble sigma={sigma:g}",
  116. unit="batch",
  117. leave=False,
  118. )
  119. for mri, xls, labels, _ in batch_iter:
  120. mri_device = mri.float().to(device)
  121. xls_device = xls.float().to(device)
  122. labels_device = labels.to(device)
  123. noisy = _apply_scaled_noise(mri_device, sigma, intensity_range)
  124. preds = []
  125. for model in models:
  126. out = model((noisy, xls_device))
  127. preds.append(out[:, class_index].detach().cpu().numpy())
  128. pred_mat = np.stack(preds, axis=0)
  129. mean = pred_mat.mean(axis=0)
  130. confidence = mean
  131. std = pred_mat.std(axis=0)
  132. true = labels_device[:, class_index].detach().cpu().numpy().astype(int)
  133. all_probs.extend(mean.tolist())
  134. all_confidence.extend(confidence.tolist())
  135. all_stds.extend(std.tolist())
  136. all_true.extend(true.tolist())
  137. return (
  138. np.asarray(all_true),
  139. np.asarray(all_probs),
  140. np.asarray(all_confidence),
  141. np.asarray(all_stds),
  142. )
  143. def _infer_with_noise_bayesian(
  144. test_loader: torch.utils.data.DataLoader,
  145. model: torch.nn.Module,
  146. sigma: float,
  147. intensity_range: float,
  148. class_index: int,
  149. mc_passes: int,
  150. ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
  151. device = next(model.parameters()).device
  152. all_probs: list[float] = []
  153. all_confidence: list[float] = []
  154. all_stds: list[float] = []
  155. all_true: list[int] = []
  156. with torch.no_grad():
  157. batch_iter = tqdm(
  158. test_loader,
  159. total=len(test_loader),
  160. desc=f"bayesian sigma={sigma:g}",
  161. unit="batch",
  162. leave=False,
  163. )
  164. for mri, xls, labels, _ in batch_iter:
  165. mri_device = mri.float().to(device)
  166. xls_device = xls.float().to(device)
  167. labels_device = labels.to(device)
  168. noisy = _apply_scaled_noise(mri_device, sigma, intensity_range)
  169. draws = []
  170. for _ in range(mc_passes):
  171. out = model((noisy, xls_device))
  172. draws.append(out.detach().cpu().numpy())
  173. draw_mat = np.stack(draws, axis=0) # (mc_passes, batch, classes)
  174. mean = draw_mat.mean(axis=0)[:, class_index]
  175. confidence = mean
  176. entropy_uncertainty = predictive_entropy(draw_mat)
  177. true = labels_device[:, class_index].detach().cpu().numpy().astype(int)
  178. all_probs.extend(mean.tolist())
  179. all_confidence.extend(np.asarray(confidence, dtype=float).tolist())
  180. all_stds.extend(np.asarray(entropy_uncertainty, dtype=float).tolist())
  181. all_true.extend(true.tolist())
  182. return (
  183. np.asarray(all_true),
  184. np.asarray(all_probs),
  185. np.asarray(all_confidence),
  186. np.asarray(all_stds),
  187. )
  188. def run_noise_analysis(
  189. config: dict[str, Any],
  190. root_dir: Path,
  191. backend: str,
  192. output_dir: Path,
  193. class_index: int,
  194. noise_sigmas: list[float],
  195. threshold: float,
  196. calibration_bins: int,
  197. bayesian_mc_passes: int,
  198. ) -> dict[str, Any]:
  199. noise_sigmas = _uniform_sigma_schedule(noise_sigmas)
  200. test_loader = build_holdout_loader(
  201. config=config,
  202. root_dir=root_dir,
  203. seed=int(config["data"]["seed"]),
  204. )
  205. # intensity_min, intensity_max, intensity_range = _compute_mri_intensity_range(
  206. # dataset
  207. # )
  208. # Just use a fixed intensity range for noise scaling since all that matters is that it's consistent
  209. intensity_range = 10_000.0
  210. out_plots_dir = plots_dir(output_dir)
  211. examples_dir = out_plots_dir / "noise_examples"
  212. examples_dir.mkdir(parents=True, exist_ok=True)
  213. rows: list[dict[str, Any]] = []
  214. if backend == "ensemble":
  215. models = _load_ensemble_models(config)
  216. example_rows: list[tuple[float, torch.Tensor]] = []
  217. sigma_iter = tqdm(noise_sigmas, desc="Noise sweep (ensemble)", unit="sigma")
  218. for sigma in sigma_iter:
  219. sigma_iter.set_postfix_str(f"sigma={sigma:g}")
  220. y_true, y_prob, y_confidence, y_std = _infer_with_noise_ensemble(
  221. test_loader,
  222. models,
  223. sigma,
  224. intensity_range,
  225. class_index=class_index,
  226. )
  227. perf = performance_at_threshold(y_true, y_prob, threshold)
  228. cal, _ = calibration_stats(y_true, y_prob, bins=calibration_bins)
  229. rows.append(
  230. {
  231. "uncertainty_metric": "std",
  232. "noise_factor": float(sigma),
  233. "accuracy": float(perf["accuracy"]),
  234. "f1": float(perf["f1"]),
  235. "mce": float(cal["mce"]),
  236. "mean_confidence": float(np.nanmean(y_confidence)),
  237. "mean_model_output_probability": float(np.nanmean(y_prob)),
  238. "mean_std": float(np.nanmean(y_std)),
  239. "mean_predictive_entropy": float("nan"),
  240. "mri_intensity_range": float(intensity_range),
  241. }
  242. )
  243. with torch.no_grad():
  244. sample = next(iter(test_loader))
  245. original_mri = sample[0]
  246. device = next(models[0].parameters()).device
  247. original_device = original_mri.float().to(device)
  248. for sigma in noise_sigmas:
  249. noisy_mri = _apply_scaled_noise(original_device, sigma, intensity_range)
  250. example_rows.append((float(sigma), noisy_mri.detach().cpu()))
  251. save_noise_example_grid(
  252. original_mri=original_mri,
  253. noisy_by_sigma=example_rows,
  254. output_path=examples_dir / f"{backend}_noise_examples.png",
  255. title=f"{backend.title()} Noise Examples",
  256. max_images=9,
  257. n_rows=2,
  258. )
  259. save_clean_scan_image(
  260. original_mri=original_mri,
  261. output_path=examples_dir / f"{backend}_clean_scan_example.png",
  262. )
  263. elif backend == "bayesian":
  264. model = _load_bayesian_model(config)
  265. example_rows = []
  266. sigma_iter = tqdm(noise_sigmas, desc="Noise sweep (bayesian)", unit="sigma")
  267. for sigma in sigma_iter:
  268. sigma_iter.set_postfix_str(f"sigma={sigma:g}")
  269. y_true, y_prob, y_confidence, y_std = _infer_with_noise_bayesian(
  270. test_loader,
  271. model,
  272. sigma,
  273. intensity_range,
  274. class_index=class_index,
  275. mc_passes=bayesian_mc_passes,
  276. )
  277. perf = performance_at_threshold(y_true, y_prob, threshold)
  278. cal, _ = calibration_stats(y_true, y_prob, bins=calibration_bins)
  279. rows.append(
  280. {
  281. "uncertainty_metric": "predictive_entropy",
  282. "noise_factor": float(sigma),
  283. "accuracy": float(perf["accuracy"]),
  284. "f1": float(perf["f1"]),
  285. "mce": float(cal["mce"]),
  286. "mean_confidence": float(np.nanmean(y_confidence)),
  287. "mean_model_output_probability": float(np.nanmean(y_prob)),
  288. # Compatibility field name retained for downstream code.
  289. "mean_std": float(np.nanmean(y_std)),
  290. "mean_predictive_entropy": float(np.nanmean(y_std)),
  291. "mri_intensity_range": float(intensity_range),
  292. }
  293. )
  294. with torch.no_grad():
  295. sample = next(iter(test_loader))
  296. original_mri = sample[0]
  297. device = next(model.parameters()).device
  298. original_device = original_mri.float().to(device)
  299. for sigma in noise_sigmas:
  300. noisy_mri = _apply_scaled_noise(original_device, sigma, intensity_range)
  301. example_rows.append((float(sigma), noisy_mri.detach().cpu()))
  302. save_noise_example_grid(
  303. original_mri=original_mri,
  304. noisy_by_sigma=example_rows,
  305. output_path=examples_dir / f"{backend}_noise_examples.png",
  306. title=f"{backend.title()} Noise Examples",
  307. max_images=9,
  308. n_rows=2,
  309. )
  310. save_clean_scan_image(
  311. original_mri=original_mri,
  312. output_path=examples_dir / f"{backend}_clean_scan_example.png",
  313. )
  314. else:
  315. raise ValueError(f"Unsupported backend for noise analysis: {backend}")
  316. df = pd.DataFrame(rows).sort_values("noise_factor")
  317. csv_path = output_dir / "noise_sensitivity.csv"
  318. df.to_csv(csv_path, index=False)
  319. accuracy_plot_path = out_plots_dir / "noise_sensitivity_accuracy.png"
  320. f1_plot_path = out_plots_dir / "noise_sensitivity_f1.png"
  321. pair_plot_path = out_plots_dir / "noise_sensitivity_accuracy_f1.png"
  322. confidence_plot_path = out_plots_dir / "noise_confidence.png"
  323. confidence_uncertainty_pair_path = (
  324. out_plots_dir / "noise_confidence_predictive_uncertainty.png"
  325. if backend == "bayesian"
  326. else out_plots_dir / "noise_confidence_standard_deviation.png"
  327. )
  328. secondary_plot_name = (
  329. "noise_predictive_uncertainty.png"
  330. if backend == "bayesian"
  331. else "noise_standard_deviation.png"
  332. )
  333. secondary_plot_path = out_plots_dir / secondary_plot_name
  334. save_noise_metrics_plot(
  335. x=df["noise_factor"],
  336. y=df["accuracy"],
  337. legend_label="Accuracy",
  338. marker="o",
  339. x_label="Gaussian Noise Factor",
  340. y_label="Accuracy",
  341. title=f"Accuracy vs Noise ({backend})",
  342. output_path=accuracy_plot_path,
  343. plot_key="noise_sensitivity_accuracy",
  344. )
  345. save_noise_metrics_plot(
  346. x=df["noise_factor"],
  347. y=df["f1"],
  348. legend_label="F1",
  349. marker="s",
  350. x_label="Gaussian Noise Factor",
  351. y_label="F1",
  352. title=f"F1 vs Noise ({backend})",
  353. output_path=f1_plot_path,
  354. plot_key="noise_sensitivity_f1",
  355. )
  356. save_metric_pair_plot(
  357. x=df["noise_factor"],
  358. left_y=df["accuracy"],
  359. right_y=df["f1"],
  360. left_label="Accuracy",
  361. right_label="F1",
  362. x_label="Gaussian Noise Factor",
  363. y_label="Accuracy/F1",
  364. title=f"Accuracy and F1 vs Noise ({backend})",
  365. output_path=pair_plot_path,
  366. plot_key="noise_sensitivity_accuracy_f1",
  367. )
  368. secondary_label = (
  369. "Predictive Uncertainty" if backend == "bayesian" else "Standard Deviation"
  370. )
  371. save_noise_metrics_plot(
  372. x=df["noise_factor"],
  373. y=df["mean_confidence"],
  374. legend_label="Confidence",
  375. marker="o",
  376. x_label="Gaussian Noise Factor",
  377. y_label="Confidence",
  378. title=f"Confidence vs Noise ({backend})",
  379. output_path=confidence_plot_path,
  380. plot_key="noise_confidence",
  381. )
  382. save_noise_metrics_plot(
  383. x=df["noise_factor"],
  384. y=df["mean_std"],
  385. legend_label=secondary_label,
  386. marker="^",
  387. x_label="Gaussian Noise Factor",
  388. y_label=secondary_label,
  389. title=f"{secondary_label} vs Noise ({backend})",
  390. output_path=secondary_plot_path,
  391. plot_key=(
  392. "noise_predictive_uncertainty"
  393. if backend == "bayesian"
  394. else "noise_standard_deviation"
  395. ),
  396. )
  397. save_metric_pair_plot(
  398. x=df["noise_factor"],
  399. left_y=df["mean_confidence"],
  400. right_y=df["mean_std"],
  401. left_label="Confidence",
  402. right_label=secondary_label,
  403. x_label="Gaussian Noise Factor",
  404. y_label="Confidence/Uncertainty",
  405. title=f"Confidence and {secondary_label} vs Noise ({backend})",
  406. output_path=confidence_uncertainty_pair_path,
  407. plot_key=(
  408. "noise_confidence_predictive_uncertainty"
  409. if backend == "bayesian"
  410. else "noise_confidence_standard_deviation"
  411. ),
  412. )
  413. out = {
  414. "table": str(csv_path),
  415. "plots": {
  416. "accuracy": str(accuracy_plot_path),
  417. "f1": str(f1_plot_path),
  418. "accuracy_f1": str(pair_plot_path),
  419. "confidence": str(confidence_plot_path),
  420. (
  421. "confidence_predictive_uncertainty"
  422. if backend == "bayesian"
  423. else "confidence_standard_deviation"
  424. ): str(confidence_uncertainty_pair_path),
  425. (
  426. "predictive_uncertainty"
  427. if backend == "bayesian"
  428. else "standard_deviation"
  429. ): str(secondary_plot_path),
  430. },
  431. "noise_factors": noise_sigmas,
  432. "noise_sigmas": noise_sigmas,
  433. # "mri_intensity_min": float(intensity_min),
  434. # "mri_intensity_max": float(intensity_max),
  435. "mri_intensity_range": float(intensity_range),
  436. }
  437. write_json(output_dir / "noise_summary.json", out)
  438. return out