| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266 |
- # pyright: basic
- from __future__ import annotations
- import json
- from pathlib import Path
- from typing import Any
- import numpy as np
- import pandas as pd
- import torch
- from torch.utils.data import ConcatDataset, DataLoader
- import xarray as xr
- from data.dataset import (
- divide_dataset_by_patient_id,
- initalize_dataloaders,
- load_adni_data_from_file,
- )
- from model.cnn import CNN3D
- def _enable_bayesian_sampling_mode(model: torch.nn.Module) -> None:
- # Keep stochastic Bayesian layers in training mode, but freeze BatchNorm
- # statistics to support batch_size=1 inference on holdout samples.
- model.train()
- for module in model.modules():
- if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
- module.eval()
- def _xls_pre(df: pd.DataFrame) -> pd.DataFrame:
- data = df[["Image Data ID", "Sex", "Age (current)"]].copy()
- data["Sex"] = data["Sex"].astype(str).str.strip()
- data = data.replace({"M": 0, "F": 1})
- return data
- def _training_seed(config: dict[str, Any], backend_dir: Path) -> int:
- config_json = backend_dir / "config.json"
- if config_json.exists():
- try:
- with config_json.open("r", encoding="utf-8") as f:
- trained = json.load(f)
- return int(trained["data"]["seed"])
- except Exception:
- pass
- return int(config["data"]["seed"])
def _build_holdout_loader(
    config: dict[str, Any], root_dir: Path, backend_dir: Path
) -> DataLoader:
    """Rebuild the held-out images (validation + test partitions) as one loader.

    The full dataset is loaded from the ``*.nii`` files and the clinical table
    named in ``config["data"]`` (both relative to ``root_dir``). It is then
    split with ``divide_dataset_by_patient_id`` using the seed returned by
    ``_training_seed(config, backend_dir)``. The validation and test
    partitions are concatenated into a single unshuffled loader with batch
    size 1.

    NOTE(review): matching the training-time split also assumes the split
    helper is deterministic for a given seed and input, and that ``glob``
    yields the NIfTI files in the same order as during training. Confirm both.
    """
    data_cfg = config["data"]
    nifti_paths = (root_dir / data_cfg["mri_files_path"]).resolve().glob("*.nii")
    table_path = (root_dir / data_cfg["xls_file_path"]).resolve()
    full_dataset = load_adni_data_from_file(
        nifti_paths,
        table_path,
        device=config["training"]["device"],
        xls_preprocessor=_xls_pre,
    )

    # Build the (image ID, patient ID) pairs consumed by the per-patient split:
    # drop rows missing either ID, normalise types/whitespace, and discard
    # blank patient IDs.
    id_cols = ["Image Data ID", "PTID"]
    id_table = pd.read_csv(table_path)
    id_table.columns = id_table.columns.str.strip()
    id_table = id_table.loc[:, id_cols].dropna(subset=id_cols)
    id_table = id_table.assign(
        **{
            "Image Data ID": id_table["Image Data ID"].astype(int),
            "PTID": id_table["PTID"].astype(str).str.strip(),
        }
    )
    id_table = id_table.loc[id_table["PTID"] != ""]
    image_patient_pairs = list(
        zip(id_table["Image Data ID"].tolist(), id_table["PTID"].tolist())
    )

    split_spec = tuple(data_cfg["data_splits"])
    split_seed = _training_seed(config, backend_dir)
    partitions = divide_dataset_by_patient_id(
        full_dataset,
        image_patient_pairs,
        split_spec,
        seed=split_seed,
    )
    _, val_dl, test_dl = initalize_dataloaders(partitions, batch_size=1)
    holdout = ConcatDataset([val_dl.dataset, test_dl.dataset])
    return DataLoader(holdout, batch_size=1, shuffle=False)
def _init_cnn(config: dict[str, Any]) -> torch.nn.Module:
    """Construct an untrained ``CNN3D`` sized from ``config``.

    Channel and class counts come from ``config["data"]``; ``droprate`` comes
    from ``config["training"]``. The network is cast to float32 and placed on
    ``config["training"]["device"]``.
    """
    data_cfg = config["data"]
    network = CNN3D(
        image_channels=int(data_cfg["image_channels"]),
        clin_data_channels=int(data_cfg["clin_data_channels"]),
        num_classes=int(data_cfg["num_classes"]),
        droprate=float(config["training"]["droprate"]),
    )
    network = network.float()
    return network.to(config["training"]["device"])
- def _extract_img_id(img_id_tensor: Any) -> int:
- if hasattr(img_id_tensor, "detach"):
- return int(img_id_tensor.detach().cpu().item())
- return int(img_id_tensor)
- def _evaluate_ensemble(
- config: dict[str, Any],
- backend_dir: Path,
- holdout_loader: DataLoader,
- ) -> xr.Dataset:
- device = str(config["training"]["device"])
- model_files = sorted(backend_dir.glob("model_run_*.pt"))
- if not model_files:
- raise FileNotFoundError(f"No ensemble checkpoints found in {backend_dir}")
- n_models = len(model_files)
- n_samples = len(holdout_loader)
- n_classes = int(config["data"]["num_classes"])
- predictions = np.zeros((n_models, n_samples, n_classes), dtype=np.float32)
- labels = np.zeros((n_samples, n_classes), dtype=np.float32)
- image_ids = np.zeros((n_samples,), dtype=int)
- for model_i, model_file in enumerate(model_files):
- model = _init_cnn(config)
- model.load_state_dict(
- torch.load(model_file, map_location=device),
- strict=False,
- )
- model.to(device)
- model.eval()
- with torch.no_grad():
- for sample_i, (mri, xls, label, img_id) in enumerate(holdout_loader):
- mri_device = mri.float().to(device)
- xls_device = xls.float().to(device)
- output = model((mri_device, xls_device))
- predictions[model_i, sample_i, :] = output.detach().cpu().numpy()[0, :]
- if model_i == 0:
- labels[sample_i, :] = label.detach().cpu().numpy()[0, :]
- image_ids[sample_i] = _extract_img_id(img_id)
- model_coord = [int(f.stem.split("_")[2]) for f in model_files]
- return xr.Dataset(
- {
- "predictions": xr.DataArray(
- predictions,
- dims=["model", "img_id", "img_class"],
- coords={
- "model": model_coord,
- "img_id": image_ids,
- "img_class": list(range(n_classes)),
- },
- ),
- "labels": xr.DataArray(
- labels,
- dims=["img_id", "label"],
- coords={
- "img_id": image_ids,
- "label": list(range(n_classes)),
- },
- ),
- }
- )
- def _evaluate_bayesian(
- config: dict[str, Any],
- backend_dir: Path,
- holdout_loader: DataLoader,
- mc_passes: int,
- ) -> xr.Dataset:
- device = str(config["training"]["device"])
- try:
- from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn # type: ignore[import-untyped]
- except ImportError as e:
- raise ImportError(
- "bayesian_torch is required to evaluate bayesian checkpoints"
- ) from e
- ckpt = backend_dir / "model_bayesian.pt"
- if not ckpt.exists():
- raise FileNotFoundError(f"Bayesian checkpoint not found: {ckpt}")
- model = _init_cnn(config)
- prior_params: dict[str, float | bool | str] = {
- "prior_mu": 0.0,
- "prior_sigma": 1.0,
- "posterior_mu_init": 0.0,
- "posterior_rho_init": -3.0,
- "type": "Reparameterization",
- "moped_enable": False,
- "moped_delta": 0.5,
- }
- dnn_to_bnn(model, prior_params)
- model.to(device)
- model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)
- model.to(device)
- _enable_bayesian_sampling_mode(model)
- n_samples = len(holdout_loader)
- n_classes = int(config["data"]["num_classes"])
- predictions = np.zeros((mc_passes, n_samples, n_classes), dtype=np.float32)
- labels = np.zeros((n_samples, n_classes), dtype=np.float32)
- image_ids = np.zeros((n_samples,), dtype=int)
- with torch.no_grad():
- for pass_i in range(mc_passes):
- for sample_i, (mri, xls, label, img_id) in enumerate(holdout_loader):
- mri_device = mri.float().to(device)
- xls_device = xls.float().to(device)
- output = model((mri_device, xls_device))
- predictions[pass_i, sample_i, :] = output.detach().cpu().numpy()[0, :]
- if pass_i == 0:
- labels[sample_i, :] = label.detach().cpu().numpy()[0, :]
- image_ids[sample_i] = _extract_img_id(img_id)
- return xr.Dataset(
- {
- "predictions": xr.DataArray(
- predictions,
- dims=["sample", "img_id", "img_class"],
- coords={
- "sample": list(range(mc_passes)),
- "img_id": image_ids,
- "img_class": list(range(n_classes)),
- },
- ),
- "labels": xr.DataArray(
- labels,
- dims=["img_id", "label"],
- coords={
- "img_id": image_ids,
- "label": list(range(n_classes)),
- },
- ),
- }
- )
def ensure_backend_netcdf(
    config: dict[str, Any],
    root_dir: Path,
    backend: str,
    bayesian_mc_passes: int,
) -> Path:
    """Return the NetCDF file with holdout predictions for ``backend``.

    The result lives at ``<config["output"]["<backend>_path"]>/
    model_evaluation_results.nc``. If that file already exists, it is returned
    as-is. Its contents are not checked against the current config or
    ``bayesian_mc_passes``. Otherwise the holdout set is rebuilt, evaluated
    with the requested backend, and written to that path.

    The file is first written to a sibling ``*.partial.nc`` path and then
    renamed into place. An interrupted write therefore never leaves a
    truncated file that later calls would mistake for a valid cache.

    NOTE(review): unlike the data paths, the output path is resolved against
    the current working directory, not ``root_dir``. Confirm this is intended.

    Args:
        config: Experiment configuration.
        root_dir: Base directory for the data paths in ``config["data"]``.
        backend: ``"ensemble"`` or ``"bayesian"``.
        bayesian_mc_passes: Monte-Carlo passes (used only for ``"bayesian"``).

    Raises:
        ValueError: ``backend`` is not supported. This is checked before any
            directory is created or data is loaded.
    """
    if backend not in ("ensemble", "bayesian"):
        raise ValueError(f"Unsupported backend: {backend}")

    backend_dir = Path(config["output"][f"{backend}_path"]).expanduser().resolve()
    backend_dir.mkdir(parents=True, exist_ok=True)
    output_path = backend_dir / "model_evaluation_results.nc"
    if output_path.exists():
        return output_path

    holdout_loader = _build_holdout_loader(config, root_dir, backend_dir)
    if backend == "ensemble":
        ds = _evaluate_ensemble(config, backend_dir, holdout_loader)
    else:
        ds = _evaluate_bayesian(config, backend_dir, holdout_loader, bayesian_mc_passes)

    # Keep the ".nc" suffix on the temporary file and publish it with an
    # atomic same-directory rename only after the write has fully succeeded.
    partial_path = output_path.with_name(
        f"{output_path.stem}.partial{output_path.suffix}"
    )
    try:
        ds.to_netcdf(partial_path, mode="w")
        partial_path.replace(output_path)
    except BaseException:
        partial_path.unlink(missing_ok=True)
        raise
    return output_path
|