holdout_evaluation.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. # pyright: basic
  2. from __future__ import annotations
  3. import json
  4. from pathlib import Path
  5. from typing import Any
  6. import numpy as np
  7. import torch
  8. from torch.utils.data import DataLoader
  9. from tqdm.auto import tqdm
  10. import xarray as xr
  11. from model.cnn import CNN3D
  12. from .data_pipeline import build_holdout_loader
  13. from .model_utils import configure_bayesian_sampling_mode
  14. def _training_seed(config: dict[str, Any], backend_dir: Path) -> int:
  15. config_json = backend_dir / "config.json"
  16. if config_json.exists():
  17. try:
  18. with config_json.open("r", encoding="utf-8") as f:
  19. trained = json.load(f)
  20. return int(trained["data"]["seed"])
  21. except Exception:
  22. pass
  23. return int(config["data"]["seed"])
  24. def _init_cnn(config: dict[str, Any]) -> torch.nn.Module:
  25. return (
  26. CNN3D(
  27. image_channels=int(config["data"]["image_channels"]),
  28. clin_data_channels=int(config["data"]["clin_data_channels"]),
  29. num_classes=int(config["data"]["num_classes"]),
  30. droprate=float(config["training"]["droprate"]),
  31. )
  32. .float()
  33. .to(config["training"]["device"])
  34. )
  35. def _extract_img_id(img_id_tensor: Any) -> int:
  36. if hasattr(img_id_tensor, "detach"):
  37. return int(img_id_tensor.detach().cpu().item())
  38. return int(img_id_tensor)
  39. def _evaluate_ensemble(
  40. config: dict[str, Any],
  41. backend_dir: Path,
  42. holdout_loader: DataLoader,
  43. ) -> xr.Dataset:
  44. device = str(config["training"]["device"])
  45. model_files = sorted(backend_dir.glob("model_run_*.pt"))
  46. if not model_files:
  47. raise FileNotFoundError(f"No ensemble checkpoints found in {backend_dir}")
  48. n_models = len(model_files)
  49. n_samples = len(holdout_loader)
  50. n_classes = int(config["data"]["num_classes"])
  51. predictions = np.zeros((n_models, n_samples, n_classes), dtype=np.float32)
  52. labels = np.zeros((n_samples, n_classes), dtype=np.float32)
  53. image_ids = np.zeros((n_samples,), dtype=int)
  54. model_iter = tqdm(model_files, desc="Ensemble checkpoints", unit="model")
  55. for model_i, model_file in enumerate(model_iter):
  56. model_iter.set_postfix_str(model_file.name)
  57. model = _init_cnn(config)
  58. model.load_state_dict(
  59. torch.load(model_file, map_location=device),
  60. strict=False,
  61. )
  62. model.to(device)
  63. model.eval()
  64. with torch.no_grad():
  65. sample_iter = tqdm(
  66. holdout_loader,
  67. total=n_samples,
  68. desc=f"{model_file.name}",
  69. unit="batch",
  70. leave=False,
  71. )
  72. for sample_i, (mri, xls, label, img_id) in enumerate(sample_iter):
  73. mri_device = mri.float().to(device)
  74. xls_device = xls.float().to(device)
  75. output = model((mri_device, xls_device))
  76. predictions[model_i, sample_i, :] = output.detach().cpu().numpy()[0, :]
  77. if model_i == 0:
  78. labels[sample_i, :] = label.detach().cpu().numpy()[0, :]
  79. image_ids[sample_i] = _extract_img_id(img_id)
  80. model_coord = [int(f.stem.split("_")[2]) for f in model_files]
  81. return xr.Dataset(
  82. {
  83. "predictions": xr.DataArray(
  84. predictions,
  85. dims=["model", "img_id", "img_class"],
  86. coords={
  87. "model": model_coord,
  88. "img_id": image_ids,
  89. "img_class": list(range(n_classes)),
  90. },
  91. ),
  92. "labels": xr.DataArray(
  93. labels,
  94. dims=["img_id", "label"],
  95. coords={
  96. "img_id": image_ids,
  97. "label": list(range(n_classes)),
  98. },
  99. ),
  100. }
  101. )
  102. def _evaluate_bayesian(
  103. config: dict[str, Any],
  104. backend_dir: Path,
  105. holdout_loader: DataLoader,
  106. mc_passes: int,
  107. ) -> xr.Dataset:
  108. device = str(config["training"]["device"])
  109. try:
  110. from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn # type: ignore[import-untyped]
  111. except ImportError as e:
  112. raise ImportError(
  113. "bayesian_torch is required to evaluate bayesian checkpoints"
  114. ) from e
  115. ckpt = backend_dir / "model_bayesian.pt"
  116. if not ckpt.exists():
  117. raise FileNotFoundError(f"Bayesian checkpoint not found: {ckpt}")
  118. model = _init_cnn(config)
  119. prior_params: dict[str, float | bool | str] = {
  120. "prior_mu": 0.0,
  121. "prior_sigma": 1.0,
  122. "posterior_mu_init": 0.0,
  123. "posterior_rho_init": -3.0,
  124. "type": "Reparameterization",
  125. "moped_enable": False,
  126. "moped_delta": 0.5,
  127. }
  128. dnn_to_bnn(model, prior_params)
  129. model.to(device)
  130. model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)
  131. model.to(device)
  132. configure_bayesian_sampling_mode(
  133. model,
  134. stochastic=False,
  135. freeze_batchnorm=True,
  136. )
  137. n_samples = len(holdout_loader)
  138. n_classes = int(config["data"]["num_classes"])
  139. predictions = np.zeros((mc_passes, n_samples, n_classes), dtype=np.float32)
  140. labels = np.zeros((n_samples, n_classes), dtype=np.float32)
  141. image_ids = np.zeros((n_samples,), dtype=int)
  142. with torch.no_grad():
  143. pass_iter = tqdm(range(mc_passes), desc="Bayesian MC passes", unit="pass")
  144. for pass_i in pass_iter:
  145. pass_iter.set_postfix_str(f"pass={pass_i + 1}/{mc_passes}")
  146. sample_iter = tqdm(
  147. holdout_loader,
  148. total=n_samples,
  149. desc=f"MC pass {pass_i + 1}",
  150. unit="batch",
  151. leave=False,
  152. )
  153. for sample_i, (mri, xls, label, img_id) in enumerate(sample_iter):
  154. mri_device = mri.float().to(device)
  155. xls_device = xls.float().to(device)
  156. output = model((mri_device, xls_device))
  157. predictions[pass_i, sample_i, :] = output.detach().cpu().numpy()[0, :]
  158. if pass_i == 0:
  159. labels[sample_i, :] = label.detach().cpu().numpy()[0, :]
  160. image_ids[sample_i] = _extract_img_id(img_id)
  161. return xr.Dataset(
  162. {
  163. "predictions": xr.DataArray(
  164. predictions,
  165. dims=["sample", "img_id", "img_class"],
  166. coords={
  167. "sample": list(range(mc_passes)),
  168. "img_id": image_ids,
  169. "img_class": list(range(n_classes)),
  170. },
  171. ),
  172. "labels": xr.DataArray(
  173. labels,
  174. dims=["img_id", "label"],
  175. coords={
  176. "img_id": image_ids,
  177. "label": list(range(n_classes)),
  178. },
  179. ),
  180. }
  181. )
  182. def ensure_backend_netcdf(
  183. config: dict[str, Any],
  184. root_dir: Path,
  185. backend: str,
  186. bayesian_mc_passes: int,
  187. ) -> Path:
  188. backend_dir = Path(config["output"][f"{backend}_path"]).expanduser().resolve()
  189. backend_dir.mkdir(parents=True, exist_ok=True)
  190. output_path = backend_dir / "model_evaluation_results.nc"
  191. if output_path.exists():
  192. return output_path
  193. holdout_loader = build_holdout_loader(
  194. config=config,
  195. root_dir=root_dir,
  196. seed=_training_seed(config, backend_dir),
  197. )
  198. if backend == "ensemble":
  199. ds = _evaluate_ensemble(config, backend_dir, holdout_loader)
  200. elif backend == "bayesian":
  201. ds = _evaluate_bayesian(config, backend_dir, holdout_loader, bayesian_mc_passes)
  202. else:
  203. raise ValueError(f"Unsupported backend: {backend}")
  204. ds.to_netcdf(output_path, mode="w")
  205. return output_path