noise_analysis.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. # pyright: basic
  2. from __future__ import annotations
  3. from pathlib import Path
  4. from typing import Any
  5. import matplotlib.pyplot as plt
  6. import numpy as np
  7. import pandas as pd
  8. import torch
  9. from bayesian_torch.utils.util import predictive_entropy
  10. from data.dataset import (
  11. divide_dataset_by_patient_id,
  12. initalize_dataloaders,
  13. load_adni_data_from_file,
  14. )
  15. from model.cnn import CNN3D
  16. from .metrics import calibration_stats, performance_at_threshold
  17. from .runtime import write_json
  18. def _enable_bayesian_sampling_mode(model: torch.nn.Module) -> None:
  19. model.eval()
  20. def _central_slice(volume: torch.Tensor) -> np.ndarray:
  21. tensor = volume.detach().cpu()
  22. if tensor.ndim == 5:
  23. tensor = tensor[0]
  24. if tensor.ndim == 4:
  25. tensor = tensor[0]
  26. if tensor.ndim != 3:
  27. raise ValueError(
  28. f"Expected a 3D volume after squeezing batch/channel, got shape {tuple(tensor.shape)}"
  29. )
  30. center_index = tensor.shape[0] // 2
  31. return tensor[center_index].numpy().astype(float)
  32. def _normalize_for_display(image: np.ndarray) -> np.ndarray:
  33. low = float(np.percentile(image, 1))
  34. high = float(np.percentile(image, 99))
  35. if high <= low:
  36. return np.zeros_like(image, dtype=float)
  37. clipped = np.clip(image, low, high)
  38. return (clipped - low) / (high - low)
  39. def _save_noise_example_grid(
  40. original_mri: torch.Tensor,
  41. noisy_by_sigma: list[tuple[float, torch.Tensor]],
  42. output_path: Path,
  43. title: str,
  44. ) -> None:
  45. if not noisy_by_sigma:
  46. return
  47. original_slice = _normalize_for_display(_central_slice(original_mri))
  48. n_rows = len(noisy_by_sigma)
  49. fig, axes = plt.subplots(n_rows, 2, figsize=(8, 3.2 * n_rows))
  50. if n_rows == 1:
  51. axes = np.array([axes])
  52. for row_idx, (sigma, noisy_tensor) in enumerate(noisy_by_sigma):
  53. noisy_slice = _normalize_for_display(_central_slice(noisy_tensor))
  54. ax_orig, ax_noisy = axes[row_idx]
  55. ax_orig.imshow(original_slice, cmap="gray")
  56. ax_orig.set_title(f"Original")
  57. ax_orig.axis("off")
  58. ax_noisy.imshow(noisy_slice, cmap="gray")
  59. ax_noisy.set_title(f"Noisy sigma={sigma:g}")
  60. ax_noisy.axis("off")
  61. fig.suptitle(title)
  62. fig.tight_layout()
  63. output_path.parent.mkdir(parents=True, exist_ok=True)
  64. fig.savefig(output_path)
  65. plt.close(fig)
  66. def _confidence_certainty(y_prob: np.ndarray) -> np.ndarray:
  67. # 2 * |p - 0.5| maps 0 -> very uncertain and 1 -> very certain.
  68. return 2.0 * np.abs(y_prob - 0.5)
  69. def _confidence_uncertainty(y_prob: np.ndarray) -> np.ndarray:
  70. # Invert certainty so that larger values mean more uncertainty, matching std.
  71. return 1.0 - _confidence_certainty(y_prob)
  72. def _apply_scaled_noise(volume: torch.Tensor, sigma: float) -> torch.Tensor:
  73. # Make sigma dimensionless with respect to MRI intensity scale.
  74. # sigma=1.0 means noise std equals total MRI intensity range, which is 65535 .
  75. return volume + (torch.randn_like(volume) * sigma * 65535)
  76. def _xls_pre(df: pd.DataFrame) -> pd.DataFrame:
  77. data = df[["Image Data ID", "Sex", "Age (current)"]].copy()
  78. data["Sex"] = data["Sex"].astype(str).str.strip()
  79. data = data.replace({"M": 0, "F": 1})
  80. return data
  81. def _build_test_loader(
  82. config: dict[str, Any], root_dir: Path
  83. ) -> torch.utils.data.DataLoader:
  84. mri_files = (root_dir / config["data"]["mri_files_path"]).resolve().glob("*.nii")
  85. xls_file = (root_dir / config["data"]["xls_file_path"]).resolve()
  86. dataset = load_adni_data_from_file(
  87. mri_files,
  88. xls_file,
  89. device=config["training"]["device"],
  90. xls_preprocessor=_xls_pre,
  91. )
  92. ptid_df = pd.read_csv(xls_file)
  93. ptid_df.columns = ptid_df.columns.str.strip()
  94. ptid_df = ptid_df[["Image Data ID", "PTID"]].dropna(
  95. subset=["Image Data ID", "PTID"]
  96. )
  97. ptid_df["Image Data ID"] = ptid_df["Image Data ID"].astype(int)
  98. ptid_df["PTID"] = ptid_df["PTID"].astype(str).str.strip()
  99. ptid_df = ptid_df[ptid_df["PTID"] != ""]
  100. ptids = list(zip(ptid_df["Image Data ID"].tolist(), ptid_df["PTID"].tolist()))
  101. splits = divide_dataset_by_patient_id(
  102. dataset,
  103. ptids,
  104. tuple(config["data"]["data_splits"]),
  105. seed=int(config["data"]["seed"]),
  106. )
  107. _, _, test_loader = initalize_dataloaders(
  108. splits,
  109. batch_size=int(config["training"]["batch_size"]),
  110. )
  111. return test_loader
  112. def _load_ensemble_models(config: dict[str, Any]) -> list[torch.nn.Module]:
  113. model_dir = Path(config["output"]["ensemble_path"])
  114. model_files = sorted(model_dir.glob("model_run_*.pt"))
  115. if not model_files:
  116. raise FileNotFoundError(f"No ensemble model files found in {model_dir}")
  117. models: list[torch.nn.Module] = []
  118. for model_file in model_files:
  119. model = (
  120. CNN3D(
  121. image_channels=int(config["data"]["image_channels"]),
  122. clin_data_channels=int(config["data"]["clin_data_channels"]),
  123. num_classes=int(config["data"]["num_classes"]),
  124. droprate=float(config["training"]["droprate"]),
  125. )
  126. .float()
  127. .to(config["training"]["device"])
  128. )
  129. model.load_state_dict(
  130. torch.load(model_file, map_location=config["training"]["device"]),
  131. strict=False,
  132. )
  133. model.eval()
  134. models.append(model)
  135. return models
  136. def _load_bayesian_model(config: dict[str, Any]) -> torch.nn.Module:
  137. device = str(config["training"]["device"])
  138. try:
  139. from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn # type: ignore[import-untyped]
  140. except ImportError as e:
  141. raise ImportError(
  142. "bayesian_torch is required for bayesian noise analysis"
  143. ) from e
  144. model_path = Path(config["output"]["bayesian_path"]) / "model_bayesian.pt"
  145. if not model_path.exists():
  146. raise FileNotFoundError(f"Bayesian model checkpoint not found: {model_path}")
  147. model = (
  148. CNN3D(
  149. image_channels=int(config["data"]["image_channels"]),
  150. clin_data_channels=int(config["data"]["clin_data_channels"]),
  151. num_classes=int(config["data"]["num_classes"]),
  152. droprate=float(config["training"]["droprate"]),
  153. )
  154. .float()
  155. .to(config["training"]["device"])
  156. )
  157. prior_params: dict[str, float | bool | str] = {
  158. "prior_mu": 0.0,
  159. "prior_sigma": 1.0,
  160. "posterior_mu_init": 0.0,
  161. "posterior_rho_init": -3.0,
  162. "type": "Reparameterization",
  163. "moped_enable": False,
  164. "moped_delta": 0.5,
  165. }
  166. dnn_to_bnn(model, prior_params)
  167. model.to(device)
  168. model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
  169. model.to(device)
  170. _enable_bayesian_sampling_mode(model)
  171. return model
  172. def _infer_with_noise_ensemble(
  173. test_loader: torch.utils.data.DataLoader,
  174. models: list[torch.nn.Module],
  175. sigma: float,
  176. class_index: int,
  177. ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
  178. if not models:
  179. raise ValueError("No ensemble models were provided for noise inference")
  180. device = next(models[0].parameters()).device
  181. all_probs: list[float] = []
  182. all_stds: list[float] = []
  183. all_true: list[int] = []
  184. with torch.no_grad():
  185. for mri, xls, labels, _ in test_loader:
  186. mri_device = mri.float().to(device)
  187. xls_device = xls.float().to(device)
  188. labels_device = labels.to(device)
  189. noisy = _apply_scaled_noise(mri_device, sigma)
  190. preds = []
  191. for model in models:
  192. out = model((noisy, xls_device))
  193. preds.append(out[:, class_index].detach().cpu().numpy())
  194. pred_mat = np.stack(preds, axis=0)
  195. mean = pred_mat.mean(axis=0)
  196. std = pred_mat.std(axis=0)
  197. true = labels_device[:, class_index].detach().cpu().numpy().astype(int)
  198. all_probs.extend(mean.tolist())
  199. all_stds.extend(std.tolist())
  200. all_true.extend(true.tolist())
  201. return np.asarray(all_true), np.asarray(all_probs), np.asarray(all_stds)
  202. def _infer_with_noise_bayesian(
  203. test_loader: torch.utils.data.DataLoader,
  204. model: torch.nn.Module,
  205. sigma: float,
  206. class_index: int,
  207. mc_passes: int,
  208. ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
  209. device = next(model.parameters()).device
  210. all_probs: list[float] = []
  211. all_stds: list[float] = []
  212. all_true: list[int] = []
  213. with torch.no_grad():
  214. for mri, xls, labels, _ in test_loader:
  215. mri_device = mri.float().to(device)
  216. xls_device = xls.float().to(device)
  217. labels_device = labels.to(device)
  218. noisy = _apply_scaled_noise(mri_device, sigma)
  219. draws = []
  220. for _ in range(mc_passes):
  221. out = model((noisy, xls_device))
  222. draws.append(out.detach().cpu().numpy())
  223. draw_mat = np.stack(draws, axis=0) # (mc_passes, batch, classes)
  224. mean = draw_mat.mean(axis=0)[:, class_index]
  225. entropy_uncertainty = predictive_entropy(draw_mat)
  226. true = labels_device[:, class_index].detach().cpu().numpy().astype(int)
  227. all_probs.extend(mean.tolist())
  228. all_stds.extend(np.asarray(entropy_uncertainty, dtype=float).tolist())
  229. all_true.extend(true.tolist())
  230. return np.asarray(all_true), np.asarray(all_probs), np.asarray(all_stds)
  231. def run_noise_analysis(
  232. config: dict[str, Any],
  233. root_dir: Path,
  234. backend: str,
  235. output_dir: Path,
  236. class_index: int,
  237. noise_sigmas: list[float],
  238. threshold: float,
  239. calibration_bins: int,
  240. bayesian_mc_passes: int,
  241. ) -> dict[str, Any]:
  242. test_loader = _build_test_loader(config, root_dir)
  243. examples_dir = output_dir / "noise_examples"
  244. examples_dir.mkdir(parents=True, exist_ok=True)
  245. rows: list[dict[str, Any]] = []
  246. if backend == "ensemble":
  247. models = _load_ensemble_models(config)
  248. example_rows: list[tuple[float, torch.Tensor]] = []
  249. for sigma in noise_sigmas:
  250. y_true, y_prob, y_std = _infer_with_noise_ensemble(
  251. test_loader,
  252. models,
  253. sigma,
  254. class_index=class_index,
  255. )
  256. perf = performance_at_threshold(y_true, y_prob, threshold)
  257. cal, _ = calibration_stats(y_true, y_prob, bins=calibration_bins)
  258. rows.append(
  259. {
  260. "uncertainty_metric": "std",
  261. "sigma": float(sigma),
  262. "accuracy": float(perf["accuracy"]),
  263. "f1": float(perf["f1"]),
  264. "ece": float(cal["ece"]),
  265. "mce": float(cal["mce"]),
  266. "mean_confidence_certainty": float(
  267. np.nanmean(_confidence_certainty(y_prob))
  268. ),
  269. "mean_confidence_uncertainty": float(
  270. np.nanmean(_confidence_uncertainty(y_prob))
  271. ),
  272. "mean_std": float(np.nanmean(y_std)),
  273. "mean_predictive_entropy": float("nan"),
  274. }
  275. )
  276. with torch.no_grad():
  277. sample = next(iter(test_loader))
  278. original_mri = sample[0]
  279. device = next(models[0].parameters()).device
  280. original_device = original_mri.float().to(device)
  281. for sigma in noise_sigmas:
  282. noisy_mri = _apply_scaled_noise(original_device, sigma)
  283. example_rows.append((float(sigma), noisy_mri.detach().cpu()))
  284. _save_noise_example_grid(
  285. original_mri=original_mri,
  286. noisy_by_sigma=example_rows,
  287. output_path=examples_dir / f"{backend}_noise_examples.png",
  288. title=f"{backend.title()} Noise Examples",
  289. )
  290. elif backend == "bayesian":
  291. model = _load_bayesian_model(config)
  292. example_rows = []
  293. for sigma in noise_sigmas:
  294. y_true, y_prob, y_std = _infer_with_noise_bayesian(
  295. test_loader,
  296. model,
  297. sigma,
  298. class_index=class_index,
  299. mc_passes=bayesian_mc_passes,
  300. )
  301. perf = performance_at_threshold(y_true, y_prob, threshold)
  302. cal, _ = calibration_stats(y_true, y_prob, bins=calibration_bins)
  303. rows.append(
  304. {
  305. "uncertainty_metric": "predictive_entropy",
  306. "sigma": float(sigma),
  307. "accuracy": float(perf["accuracy"]),
  308. "f1": float(perf["f1"]),
  309. "ece": float(cal["ece"]),
  310. "mce": float(cal["mce"]),
  311. "mean_confidence_certainty": float(
  312. np.nanmean(_confidence_certainty(y_prob))
  313. ),
  314. "mean_confidence_uncertainty": float(
  315. np.nanmean(_confidence_uncertainty(y_prob))
  316. ),
  317. # Compatibility field name retained for downstream code.
  318. "mean_std": float(np.nanmean(y_std)),
  319. "mean_predictive_entropy": float(np.nanmean(y_std)),
  320. }
  321. )
  322. with torch.no_grad():
  323. sample = next(iter(test_loader))
  324. original_mri = sample[0]
  325. device = next(model.parameters()).device
  326. original_device = original_mri.float().to(device)
  327. for sigma in noise_sigmas:
  328. noisy_mri = _apply_scaled_noise(original_device, sigma)
  329. example_rows.append((float(sigma), noisy_mri.detach().cpu()))
  330. _save_noise_example_grid(
  331. original_mri=original_mri,
  332. noisy_by_sigma=example_rows,
  333. output_path=examples_dir / f"{backend}_noise_examples.png",
  334. title=f"{backend.title()} Noise Examples",
  335. )
  336. else:
  337. raise ValueError(f"Unsupported backend for noise analysis: {backend}")
  338. df = pd.DataFrame(rows).sort_values("sigma")
  339. csv_path = output_dir / "noise_sensitivity.csv"
  340. df.to_csv(csv_path, index=False)
  341. fig, ax = plt.subplots(figsize=(10, 5))
  342. ax.plot(df["sigma"], df["accuracy"], marker="o", label="accuracy")
  343. ax.plot(df["sigma"], df["f1"], marker="s", label="f1")
  344. ax.plot(df["sigma"], df["ece"], marker="^", label="ece")
  345. ax.set_xlabel("Gaussian Noise Sigma")
  346. ax.set_ylabel("Score")
  347. ax.set_title(f"Noise Sensitivity ({backend})")
  348. ax.grid(True, alpha=0.3)
  349. ax.legend()
  350. fig.tight_layout()
  351. plot_path = output_dir / "noise_sensitivity.png"
  352. fig.savefig(plot_path)
  353. plt.close(fig)
  354. fig_u, ax_u = plt.subplots(figsize=(10, 5))
  355. ax_u.plot(
  356. df["sigma"],
  357. df["mean_confidence_uncertainty"],
  358. marker="o",
  359. label="confidence_uncertainty",
  360. )
  361. ax_u.plot(df["sigma"], df["mean_std"], marker="s", label="std_uncertainty")
  362. ax_u.set_xlabel("Gaussian Noise Sigma")
  363. ax_u.set_ylabel("Uncertainty")
  364. ax_u.set_title(f"Uncertainty vs Noise ({backend})")
  365. ax_u.grid(True, alpha=0.3)
  366. ax_u.legend()
  367. fig_u.tight_layout()
  368. uncertainty_plot_path = output_dir / "noise_uncertainty.png"
  369. fig_u.savefig(uncertainty_plot_path)
  370. plt.close(fig_u)
  371. fig_c, ax_c = plt.subplots(figsize=(10, 5))
  372. ax_c.plot(
  373. df["sigma"],
  374. df["mean_confidence_certainty"],
  375. marker="o",
  376. label="confidence_certainty",
  377. )
  378. ax_c.set_xlabel("Gaussian Noise Sigma")
  379. ax_c.set_ylabel("Certainty")
  380. ax_c.set_title(f"Confidence Certainty vs Noise ({backend})")
  381. ax_c.grid(True, alpha=0.3)
  382. ax_c.legend()
  383. fig_c.tight_layout()
  384. certainty_plot_path = output_dir / "noise_confidence_certainty.png"
  385. fig_c.savefig(certainty_plot_path)
  386. plt.close(fig_c)
  387. out = {
  388. "table": str(csv_path),
  389. "plot": str(plot_path),
  390. "uncertainty_plot": str(uncertainty_plot_path),
  391. "certainty_plot": str(certainty_plot_path),
  392. "noise_sigmas": noise_sigmas,
  393. }
  394. write_json(output_dir / "noise_summary.json", out)
  395. return out