| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 |
- # pyright: basic
- from __future__ import annotations
- import json
- from pathlib import Path
- from typing import Any
- import numpy as np
- import torch
- from torch.utils.data import DataLoader
- from tqdm.auto import tqdm
- import xarray as xr
- from model.cnn import CNN3D
- from .data_pipeline import build_holdout_loader
- from .model_utils import configure_bayesian_sampling_mode
def _training_seed(config: dict[str, Any], backend_dir: Path) -> int:
    """Return the data seed to use when rebuilding the holdout split.

    The seed stored under ``["data"]["seed"]`` in
    ``backend_dir / "config.json"`` takes precedence. When that file is
    missing, or cannot be read or parsed, ``config["data"]["seed"]`` is used
    instead.

    Args:
        config: Configuration mapping providing the fallback
            ``config["data"]["seed"]``.
        backend_dir: Directory that may contain a ``config.json`` file.

    Returns:
        The seed as an ``int``.

    Raises:
        KeyError: If the fallback ``config["data"]["seed"]`` is absent.
    """
    import warnings  # function-scope import keeps this block self-contained

    config_json = backend_dir / "config.json"
    try:
        with config_json.open("r", encoding="utf-8") as f:
            trained = json.load(f)
        return int(trained["data"]["seed"])
    except FileNotFoundError:
        # No saved config next to the checkpoints: falling back is expected.
        pass
    except (OSError, ValueError, KeyError, TypeError) as exc:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError
        # subclasses. The fallback is kept (best effort), but a seed that
        # silently differs from the saved one is worth surfacing.
        warnings.warn(
            f"Ignoring unusable training config {config_json} ({exc!r}); "
            "falling back to the seed from the supplied config.",
            RuntimeWarning,
            stacklevel=2,
        )
    return int(config["data"]["seed"])
def _init_cnn(config: dict[str, Any]) -> torch.nn.Module:
    """Construct a ``CNN3D`` from ``config`` and place it on the target device.

    Channel and class counts are read from ``config["data"]``; the dropout
    rate and the device are read from ``config["training"]``. The model is
    converted with ``.float()`` before being moved to the device.

    Args:
        config: Configuration mapping with ``"data"`` and ``"training"``
            sections.

    Returns:
        The freshly constructed model on ``config["training"]["device"]``.
    """
    data_cfg = config["data"]
    cnn_kwargs = {
        "image_channels": int(data_cfg["image_channels"]),
        "clin_data_channels": int(data_cfg["clin_data_channels"]),
        "num_classes": int(data_cfg["num_classes"]),
        "droprate": float(config["training"]["droprate"]),
    }
    network = CNN3D(**cnn_kwargs)
    network = network.float()
    target_device = config["training"]["device"]
    return network.to(target_device)
def _extract_img_id(img_id_tensor: Any) -> int:
    """Convert an image id to a plain Python ``int``.

    Objects exposing ``detach`` (for example torch tensors) are first
    unwrapped with ``.detach().cpu().item()``; anything else is passed to
    ``int`` directly.

    Args:
        img_id_tensor: The image id, as a tensor-like object or any value
            accepted by ``int``.

    Returns:
        The image id as an ``int``.
    """
    raw_id = img_id_tensor
    if hasattr(raw_id, "detach"):
        # Tensor-like: detach, bring to host, and unwrap the single value.
        raw_id = raw_id.detach().cpu().item()
    return int(raw_id)
def _evaluate_ensemble(
    config: dict[str, Any],
    backend_dir: Path,
    holdout_loader: DataLoader,
) -> xr.Dataset:
    """Evaluate every ``model_run_*.pt`` checkpoint on the holdout loader.

    Each checkpoint in ``backend_dir`` is loaded into a fresh model built by
    ``_init_cnn`` and run over ``holdout_loader`` under ``torch.no_grad()``.
    Labels and image ids are recorded during the first checkpoint's pass
    only.

    Args:
        config: Configuration mapping; ``config["training"]["device"]`` and
            ``config["data"]["num_classes"]`` are read here.
        backend_dir: Directory containing the ``model_run_<index>.pt`` files.
        holdout_loader: Loader yielding ``(mri, xls, label, img_id)`` tuples.
            Each batch must contain exactly one sample.

    Returns:
        A dataset with ``predictions`` (dims ``model``, ``img_id``,
        ``img_class``) and ``labels`` (dims ``img_id``, ``label``).

    Raises:
        FileNotFoundError: If no ``model_run_*.pt`` file exists.
        ValueError: If a checkpoint name has no integer run index, or if the
            loader yields a batch whose size is not 1.
    """
    device = str(config["training"]["device"])
    model_files = sorted(backend_dir.glob("model_run_*.pt"))
    if not model_files:
        raise FileNotFoundError(f"No ensemble checkpoints found in {backend_dir}")

    # Parse the run indices up front so a malformed checkpoint name fails
    # before any model is evaluated instead of after the whole sweep.
    model_coord: list[int] = []
    for model_file in model_files:
        try:
            model_coord.append(int(model_file.stem.split("_")[2]))
        except ValueError as exc:
            raise ValueError(
                f"Cannot parse run index from checkpoint name: {model_file.name}"
            ) from exc

    n_models = len(model_files)
    # len() of a loader counts batches; it equals the number of samples only
    # because every batch is required to hold exactly one sample (see below).
    n_samples = len(holdout_loader)
    n_classes = int(config["data"]["num_classes"])
    predictions = np.zeros((n_models, n_samples, n_classes), dtype=np.float32)
    labels = np.zeros((n_samples, n_classes), dtype=np.float32)
    image_ids = np.zeros((n_samples,), dtype=int)

    model_iter = tqdm(model_files, desc="Ensemble checkpoints", unit="model")
    for model_i, model_file in enumerate(model_iter):
        model_iter.set_postfix_str(model_file.name)
        # _init_cnn already places the model on the configured device, so no
        # further .to(device) is needed after loading the weights in place.
        model = _init_cnn(config)
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (or pass weights_only=True where the installed torch
        # supports it). strict=False also tolerates missing/unexpected keys.
        model.load_state_dict(
            torch.load(model_file, map_location=device),
            strict=False,
        )
        model.eval()
        with torch.no_grad():
            sample_iter = tqdm(
                holdout_loader,
                total=n_samples,
                desc=f"{model_file.name}",
                unit="batch",
                leave=False,
            )
            for sample_i, (mri, xls, label, img_id) in enumerate(sample_iter):
                mri_device = mri.float().to(device)
                xls_device = xls.float().to(device)
                output = model((mri_device, xls_device))
                # Only row 0 is stored below; refuse larger batches rather
                # than silently dropping the remaining samples.
                if output.shape[0] != 1:
                    raise ValueError(
                        "Holdout loader must yield batches of size 1, "
                        f"got {output.shape[0]}"
                    )
                predictions[model_i, sample_i, :] = output.detach().cpu().numpy()[0, :]
                if model_i == 0:
                    labels[sample_i, :] = label.detach().cpu().numpy()[0, :]
                    image_ids[sample_i] = _extract_img_id(img_id)

    return xr.Dataset(
        {
            "predictions": xr.DataArray(
                predictions,
                dims=["model", "img_id", "img_class"],
                coords={
                    "model": model_coord,
                    "img_id": image_ids,
                    "img_class": list(range(n_classes)),
                },
            ),
            "labels": xr.DataArray(
                labels,
                dims=["img_id", "label"],
                coords={
                    "img_id": image_ids,
                    "label": list(range(n_classes)),
                },
            ),
        }
    )
def _evaluate_bayesian(
    config: dict[str, Any],
    backend_dir: Path,
    holdout_loader: DataLoader,
    mc_passes: int,
) -> xr.Dataset:
    """Evaluate the Bayesian checkpoint with repeated passes over the holdout.

    A model built by ``_init_cnn`` is converted with ``dnn_to_bnn``, loaded
    from ``backend_dir / "model_bayesian.pt"``, configured through
    ``configure_bayesian_sampling_mode`` and then run ``mc_passes`` times
    over ``holdout_loader`` under ``torch.no_grad()``. Labels and image ids
    are recorded during the first pass only.

    Args:
        config: Configuration mapping; ``config["training"]["device"]`` and
            ``config["data"]["num_classes"]`` are read here.
        backend_dir: Directory containing ``model_bayesian.pt``.
        holdout_loader: Loader yielding ``(mri, xls, label, img_id)`` tuples.
            Each batch must contain exactly one sample.
        mc_passes: Number of passes over the loader; must be at least 1.

    Returns:
        A dataset with ``predictions`` (dims ``sample``, ``img_id``,
        ``img_class``) and ``labels`` (dims ``img_id``, ``label``).

    Raises:
        ValueError: If ``mc_passes`` is below 1, or if the loader yields a
            batch whose size is not 1.
        ImportError: If ``bayesian_torch`` is not installed.
        FileNotFoundError: If the checkpoint file does not exist.
    """
    # With zero passes the labels/ids arrays would never be filled and an
    # all-zero dataset would be returned, so reject that up front.
    if mc_passes < 1:
        raise ValueError(f"mc_passes must be >= 1, got {mc_passes}")

    device = str(config["training"]["device"])
    try:
        from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn  # type: ignore[import-untyped]
    except ImportError as e:
        raise ImportError(
            "bayesian_torch is required to evaluate bayesian checkpoints"
        ) from e

    ckpt = backend_dir / "model_bayesian.pt"
    if not ckpt.exists():
        raise FileNotFoundError(f"Bayesian checkpoint not found: {ckpt}")

    model = _init_cnn(config)
    prior_params: dict[str, float | bool | str] = {
        "prior_mu": 0.0,
        "prior_sigma": 1.0,
        "posterior_mu_init": 0.0,
        "posterior_rho_init": -3.0,
        "type": "Reparameterization",
        "moped_enable": False,
        "moped_delta": 0.5,
    }
    dnn_to_bnn(model, prior_params)
    # Move the converted model to the device before loading the weights.
    # load_state_dict copies values in place, so no second .to() is needed.
    model.to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints. strict=False tolerates missing/unexpected keys.
    model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)
    # NOTE(review): stochastic=False is passed although several MC passes
    # follow; confirm in model_utils that per-pass sampling is still active,
    # otherwise every pass would produce identical predictions.
    configure_bayesian_sampling_mode(
        model,
        stochastic=False,
        freeze_batchnorm=True,
    )

    # len() of a loader counts batches; it equals the number of samples only
    # because every batch is required to hold exactly one sample (see below).
    n_samples = len(holdout_loader)
    n_classes = int(config["data"]["num_classes"])
    predictions = np.zeros((mc_passes, n_samples, n_classes), dtype=np.float32)
    labels = np.zeros((n_samples, n_classes), dtype=np.float32)
    image_ids = np.zeros((n_samples,), dtype=int)

    with torch.no_grad():
        pass_iter = tqdm(range(mc_passes), desc="Bayesian MC passes", unit="pass")
        for pass_i in pass_iter:
            pass_iter.set_postfix_str(f"pass={pass_i + 1}/{mc_passes}")
            sample_iter = tqdm(
                holdout_loader,
                total=n_samples,
                desc=f"MC pass {pass_i + 1}",
                unit="batch",
                leave=False,
            )
            for sample_i, (mri, xls, label, img_id) in enumerate(sample_iter):
                mri_device = mri.float().to(device)
                xls_device = xls.float().to(device)
                output = model((mri_device, xls_device))
                # Only row 0 is stored below; refuse larger batches rather
                # than silently dropping the remaining samples.
                if output.shape[0] != 1:
                    raise ValueError(
                        "Holdout loader must yield batches of size 1, "
                        f"got {output.shape[0]}"
                    )
                predictions[pass_i, sample_i, :] = output.detach().cpu().numpy()[0, :]
                if pass_i == 0:
                    labels[sample_i, :] = label.detach().cpu().numpy()[0, :]
                    image_ids[sample_i] = _extract_img_id(img_id)

    return xr.Dataset(
        {
            "predictions": xr.DataArray(
                predictions,
                dims=["sample", "img_id", "img_class"],
                coords={
                    "sample": list(range(mc_passes)),
                    "img_id": image_ids,
                    "img_class": list(range(n_classes)),
                },
            ),
            "labels": xr.DataArray(
                labels,
                dims=["img_id", "label"],
                coords={
                    "img_id": image_ids,
                    "label": list(range(n_classes)),
                },
            ),
        }
    )
def ensure_backend_netcdf(
    config: dict[str, Any],
    root_dir: Path,
    backend: str,
    bayesian_mc_passes: int,
) -> Path:
    """Return the backend's evaluation NetCDF file, creating it if missing.

    The file lives at ``<backend dir>/model_evaluation_results.nc``, where
    the backend directory comes from ``config["output"]["<backend>_path"]``
    and is created if needed. An existing file is returned as-is; otherwise
    the holdout loader is built, the backend is evaluated and the resulting
    dataset is written.

    Args:
        config: Configuration mapping; ``config["output"]`` supplies the
            backend directory.
        root_dir: Forwarded to ``build_holdout_loader``.
        backend: Either ``"ensemble"`` or ``"bayesian"``.
        bayesian_mc_passes: Number of passes forwarded to the Bayesian
            evaluation; unused for the ensemble backend.

    Returns:
        Path to the NetCDF results file.

    Raises:
        ValueError: If ``backend`` is not a supported value.
    """
    # Validate first so an unsupported backend fails with the intended error
    # before any directory is created or the holdout loader is built.
    if backend not in ("ensemble", "bayesian"):
        raise ValueError(f"Unsupported backend: {backend}")

    backend_dir = Path(config["output"][f"{backend}_path"]).expanduser().resolve()
    backend_dir.mkdir(parents=True, exist_ok=True)
    output_path = backend_dir / "model_evaluation_results.nc"
    if output_path.exists():
        return output_path

    holdout_loader = build_holdout_loader(
        config=config,
        root_dir=root_dir,
        seed=_training_seed(config, backend_dir),
    )
    if backend == "ensemble":
        ds = _evaluate_ensemble(config, backend_dir, holdout_loader)
    else:
        ds = _evaluate_bayesian(config, backend_dir, holdout_loader, bayesian_mc_passes)

    # Write to a sibling file and rename it into place: an interrupted write
    # must not leave a partial file at output_path, because its mere
    # existence is what later calls treat as a valid cached result.
    partial_path = output_path.with_name("model_evaluation_results.partial.nc")
    try:
        ds.to_netcdf(partial_path, mode="w")
        partial_path.replace(output_path)
    finally:
        # Present only if the write or the rename failed.
        if partial_path.exists():
            partial_path.unlink()
    return output_path
|