holdout_evaluation.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. # pyright: basic
  2. from __future__ import annotations
  3. import json
  4. from pathlib import Path
  5. from typing import Any
  6. import numpy as np
  7. import pandas as pd
  8. import torch
  9. from torch.utils.data import ConcatDataset, DataLoader
  10. import xarray as xr
  11. from data.dataset import (
  12. divide_dataset_by_patient_id,
  13. initalize_dataloaders,
  14. load_adni_data_from_file,
  15. )
  16. from model.cnn import CNN3D
  17. def _enable_bayesian_sampling_mode(model: torch.nn.Module) -> None:
  18. # Keep stochastic Bayesian layers in training mode, but freeze BatchNorm
  19. # statistics to support batch_size=1 inference on holdout samples.
  20. model.train()
  21. for module in model.modules():
  22. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  23. module.eval()
  24. def _xls_pre(df: pd.DataFrame) -> pd.DataFrame:
  25. data = df[["Image Data ID", "Sex", "Age (current)"]].copy()
  26. data["Sex"] = data["Sex"].astype(str).str.strip()
  27. data = data.replace({"M": 0, "F": 1})
  28. return data
  29. def _training_seed(config: dict[str, Any], backend_dir: Path) -> int:
  30. config_json = backend_dir / "config.json"
  31. if config_json.exists():
  32. try:
  33. with config_json.open("r", encoding="utf-8") as f:
  34. trained = json.load(f)
  35. return int(trained["data"]["seed"])
  36. except Exception:
  37. pass
  38. return int(config["data"]["seed"])
  39. def _build_holdout_loader(
  40. config: dict[str, Any], root_dir: Path, backend_dir: Path
  41. ) -> DataLoader:
  42. mri_files = (root_dir / config["data"]["mri_files_path"]).resolve().glob("*.nii")
  43. xls_file = (root_dir / config["data"]["xls_file_path"]).resolve()
  44. dataset = load_adni_data_from_file(
  45. mri_files,
  46. xls_file,
  47. device=config["training"]["device"],
  48. xls_preprocessor=_xls_pre,
  49. )
  50. ptid_df = pd.read_csv(xls_file)
  51. ptid_df.columns = ptid_df.columns.str.strip()
  52. ptid_df = ptid_df[["Image Data ID", "PTID"]].dropna(
  53. subset=["Image Data ID", "PTID"]
  54. )
  55. ptid_df["Image Data ID"] = ptid_df["Image Data ID"].astype(int)
  56. ptid_df["PTID"] = ptid_df["PTID"].astype(str).str.strip()
  57. ptid_df = ptid_df[ptid_df["PTID"] != ""]
  58. ptids = list(zip(ptid_df["Image Data ID"].tolist(), ptid_df["PTID"].tolist()))
  59. datasets = divide_dataset_by_patient_id(
  60. dataset,
  61. ptids,
  62. tuple(config["data"]["data_splits"]),
  63. seed=_training_seed(config, backend_dir),
  64. )
  65. _, val_loader, test_loader = initalize_dataloaders(datasets, batch_size=1)
  66. combined = ConcatDataset([val_loader.dataset, test_loader.dataset])
  67. return DataLoader(combined, batch_size=1, shuffle=False)
  68. def _init_cnn(config: dict[str, Any]) -> torch.nn.Module:
  69. return (
  70. CNN3D(
  71. image_channels=int(config["data"]["image_channels"]),
  72. clin_data_channels=int(config["data"]["clin_data_channels"]),
  73. num_classes=int(config["data"]["num_classes"]),
  74. droprate=float(config["training"]["droprate"]),
  75. )
  76. .float()
  77. .to(config["training"]["device"])
  78. )
  79. def _extract_img_id(img_id_tensor: Any) -> int:
  80. if hasattr(img_id_tensor, "detach"):
  81. return int(img_id_tensor.detach().cpu().item())
  82. return int(img_id_tensor)
  83. def _evaluate_ensemble(
  84. config: dict[str, Any],
  85. backend_dir: Path,
  86. holdout_loader: DataLoader,
  87. ) -> xr.Dataset:
  88. device = str(config["training"]["device"])
  89. model_files = sorted(backend_dir.glob("model_run_*.pt"))
  90. if not model_files:
  91. raise FileNotFoundError(f"No ensemble checkpoints found in {backend_dir}")
  92. n_models = len(model_files)
  93. n_samples = len(holdout_loader)
  94. n_classes = int(config["data"]["num_classes"])
  95. predictions = np.zeros((n_models, n_samples, n_classes), dtype=np.float32)
  96. labels = np.zeros((n_samples, n_classes), dtype=np.float32)
  97. image_ids = np.zeros((n_samples,), dtype=int)
  98. for model_i, model_file in enumerate(model_files):
  99. model = _init_cnn(config)
  100. model.load_state_dict(
  101. torch.load(model_file, map_location=device),
  102. strict=False,
  103. )
  104. model.to(device)
  105. model.eval()
  106. with torch.no_grad():
  107. for sample_i, (mri, xls, label, img_id) in enumerate(holdout_loader):
  108. mri_device = mri.float().to(device)
  109. xls_device = xls.float().to(device)
  110. output = model((mri_device, xls_device))
  111. predictions[model_i, sample_i, :] = output.detach().cpu().numpy()[0, :]
  112. if model_i == 0:
  113. labels[sample_i, :] = label.detach().cpu().numpy()[0, :]
  114. image_ids[sample_i] = _extract_img_id(img_id)
  115. model_coord = [int(f.stem.split("_")[2]) for f in model_files]
  116. return xr.Dataset(
  117. {
  118. "predictions": xr.DataArray(
  119. predictions,
  120. dims=["model", "img_id", "img_class"],
  121. coords={
  122. "model": model_coord,
  123. "img_id": image_ids,
  124. "img_class": list(range(n_classes)),
  125. },
  126. ),
  127. "labels": xr.DataArray(
  128. labels,
  129. dims=["img_id", "label"],
  130. coords={
  131. "img_id": image_ids,
  132. "label": list(range(n_classes)),
  133. },
  134. ),
  135. }
  136. )
  137. def _evaluate_bayesian(
  138. config: dict[str, Any],
  139. backend_dir: Path,
  140. holdout_loader: DataLoader,
  141. mc_passes: int,
  142. ) -> xr.Dataset:
  143. device = str(config["training"]["device"])
  144. try:
  145. from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn # type: ignore[import-untyped]
  146. except ImportError as e:
  147. raise ImportError(
  148. "bayesian_torch is required to evaluate bayesian checkpoints"
  149. ) from e
  150. ckpt = backend_dir / "model_bayesian.pt"
  151. if not ckpt.exists():
  152. raise FileNotFoundError(f"Bayesian checkpoint not found: {ckpt}")
  153. model = _init_cnn(config)
  154. prior_params: dict[str, float | bool | str] = {
  155. "prior_mu": 0.0,
  156. "prior_sigma": 1.0,
  157. "posterior_mu_init": 0.0,
  158. "posterior_rho_init": -3.0,
  159. "type": "Reparameterization",
  160. "moped_enable": False,
  161. "moped_delta": 0.5,
  162. }
  163. dnn_to_bnn(model, prior_params)
  164. model.to(device)
  165. model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)
  166. model.to(device)
  167. _enable_bayesian_sampling_mode(model)
  168. n_samples = len(holdout_loader)
  169. n_classes = int(config["data"]["num_classes"])
  170. predictions = np.zeros((mc_passes, n_samples, n_classes), dtype=np.float32)
  171. labels = np.zeros((n_samples, n_classes), dtype=np.float32)
  172. image_ids = np.zeros((n_samples,), dtype=int)
  173. with torch.no_grad():
  174. for pass_i in range(mc_passes):
  175. for sample_i, (mri, xls, label, img_id) in enumerate(holdout_loader):
  176. mri_device = mri.float().to(device)
  177. xls_device = xls.float().to(device)
  178. output = model((mri_device, xls_device))
  179. predictions[pass_i, sample_i, :] = output.detach().cpu().numpy()[0, :]
  180. if pass_i == 0:
  181. labels[sample_i, :] = label.detach().cpu().numpy()[0, :]
  182. image_ids[sample_i] = _extract_img_id(img_id)
  183. return xr.Dataset(
  184. {
  185. "predictions": xr.DataArray(
  186. predictions,
  187. dims=["sample", "img_id", "img_class"],
  188. coords={
  189. "sample": list(range(mc_passes)),
  190. "img_id": image_ids,
  191. "img_class": list(range(n_classes)),
  192. },
  193. ),
  194. "labels": xr.DataArray(
  195. labels,
  196. dims=["img_id", "label"],
  197. coords={
  198. "img_id": image_ids,
  199. "label": list(range(n_classes)),
  200. },
  201. ),
  202. }
  203. )
  204. def ensure_backend_netcdf(
  205. config: dict[str, Any],
  206. root_dir: Path,
  207. backend: str,
  208. bayesian_mc_passes: int,
  209. ) -> Path:
  210. backend_dir = Path(config["output"][f"{backend}_path"]).expanduser().resolve()
  211. backend_dir.mkdir(parents=True, exist_ok=True)
  212. output_path = backend_dir / "model_evaluation_results.nc"
  213. if output_path.exists():
  214. return output_path
  215. holdout_loader = _build_holdout_loader(config, root_dir, backend_dir)
  216. if backend == "ensemble":
  217. ds = _evaluate_ensemble(config, backend_dir, holdout_loader)
  218. elif backend == "bayesian":
  219. ds = _evaluate_bayesian(config, backend_dir, holdout_loader, bayesian_mc_passes)
  220. else:
  221. raise ValueError(f"Unsupported backend: {backend}")
  222. ds.to_netcdf(output_path, mode="w")
  223. return output_path