# pyright: basic from __future__ import annotations import json from pathlib import Path from typing import Any import numpy as np import torch from torch.utils.data import DataLoader from tqdm.auto import tqdm import xarray as xr from model.cnn import CNN3D from .data_pipeline import build_holdout_loader from .model_utils import configure_bayesian_sampling_mode def _training_seed(config: dict[str, Any], backend_dir: Path) -> int: config_json = backend_dir / "config.json" if config_json.exists(): try: with config_json.open("r", encoding="utf-8") as f: trained = json.load(f) return int(trained["data"]["seed"]) except Exception: pass return int(config["data"]["seed"]) def _init_cnn(config: dict[str, Any]) -> torch.nn.Module: return ( CNN3D( image_channels=int(config["data"]["image_channels"]), clin_data_channels=int(config["data"]["clin_data_channels"]), num_classes=int(config["data"]["num_classes"]), droprate=float(config["training"]["droprate"]), ) .float() .to(config["training"]["device"]) ) def _extract_img_id(img_id_tensor: Any) -> int: if hasattr(img_id_tensor, "detach"): return int(img_id_tensor.detach().cpu().item()) return int(img_id_tensor) def _evaluate_ensemble( config: dict[str, Any], backend_dir: Path, holdout_loader: DataLoader, ) -> xr.Dataset: device = str(config["training"]["device"]) model_files = sorted(backend_dir.glob("model_run_*.pt")) if not model_files: raise FileNotFoundError(f"No ensemble checkpoints found in {backend_dir}") n_models = len(model_files) n_samples = len(holdout_loader) n_classes = int(config["data"]["num_classes"]) predictions = np.zeros((n_models, n_samples, n_classes), dtype=np.float32) labels = np.zeros((n_samples, n_classes), dtype=np.float32) image_ids = np.zeros((n_samples,), dtype=int) model_iter = tqdm(model_files, desc="Ensemble checkpoints", unit="model") for model_i, model_file in enumerate(model_iter): model_iter.set_postfix_str(model_file.name) model = _init_cnn(config) model.load_state_dict( torch.load(model_file, map_location=device), strict=False, ) model.to(device) model.eval() with torch.no_grad(): sample_iter = tqdm( holdout_loader, total=n_samples, desc=f"{model_file.name}", unit="batch", leave=False, ) for sample_i, (mri, xls, label, img_id) in enumerate(sample_iter): mri_device = mri.float().to(device) xls_device = xls.float().to(device) output = model((mri_device, xls_device)) predictions[model_i, sample_i, :] = output.detach().cpu().numpy()[0, :] if model_i == 0: labels[sample_i, :] = label.detach().cpu().numpy()[0, :] image_ids[sample_i] = _extract_img_id(img_id) model_coord = [int(f.stem.split("_")[2]) for f in model_files] return xr.Dataset( { "predictions": xr.DataArray( predictions, dims=["model", "img_id", "img_class"], coords={ "model": model_coord, "img_id": image_ids, "img_class": list(range(n_classes)), }, ), "labels": xr.DataArray( labels, dims=["img_id", "label"], coords={ "img_id": image_ids, "label": list(range(n_classes)), }, ), } ) def _evaluate_bayesian( config: dict[str, Any], backend_dir: Path, holdout_loader: DataLoader, mc_passes: int, ) -> xr.Dataset: device = str(config["training"]["device"]) try: from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn # type: ignore[import-untyped] except ImportError as e: raise ImportError( "bayesian_torch is required to evaluate bayesian checkpoints" ) from e ckpt = backend_dir / "model_bayesian.pt" if not ckpt.exists(): raise FileNotFoundError(f"Bayesian checkpoint not found: {ckpt}") model = _init_cnn(config) prior_params: dict[str, float | bool | str] = { "prior_mu": 0.0, "prior_sigma": 1.0, "posterior_mu_init": 0.0, "posterior_rho_init": -3.0, "type": "Reparameterization", "moped_enable": False, "moped_delta": 0.5, } dnn_to_bnn(model, prior_params) model.to(device) model.load_state_dict(torch.load(ckpt, map_location=device), strict=False) model.to(device) configure_bayesian_sampling_mode( model, stochastic=False, freeze_batchnorm=True, ) n_samples = len(holdout_loader) n_classes = int(config["data"]["num_classes"]) predictions = np.zeros((mc_passes, n_samples, n_classes), dtype=np.float32) labels = np.zeros((n_samples, n_classes), dtype=np.float32) image_ids = np.zeros((n_samples,), dtype=int) with torch.no_grad(): pass_iter = tqdm(range(mc_passes), desc="Bayesian MC passes", unit="pass") for pass_i in pass_iter: pass_iter.set_postfix_str(f"pass={pass_i + 1}/{mc_passes}") sample_iter = tqdm( holdout_loader, total=n_samples, desc=f"MC pass {pass_i + 1}", unit="batch", leave=False, ) for sample_i, (mri, xls, label, img_id) in enumerate(sample_iter): mri_device = mri.float().to(device) xls_device = xls.float().to(device) output = model((mri_device, xls_device)) predictions[pass_i, sample_i, :] = output.detach().cpu().numpy()[0, :] if pass_i == 0: labels[sample_i, :] = label.detach().cpu().numpy()[0, :] image_ids[sample_i] = _extract_img_id(img_id) return xr.Dataset( { "predictions": xr.DataArray( predictions, dims=["sample", "img_id", "img_class"], coords={ "sample": list(range(mc_passes)), "img_id": image_ids, "img_class": list(range(n_classes)), }, ), "labels": xr.DataArray( labels, dims=["img_id", "label"], coords={ "img_id": image_ids, "label": list(range(n_classes)), }, ), } ) def ensure_backend_netcdf( config: dict[str, Any], root_dir: Path, backend: str, bayesian_mc_passes: int, ) -> Path: backend_dir = Path(config["output"][f"{backend}_path"]).expanduser().resolve() backend_dir.mkdir(parents=True, exist_ok=True) output_path = backend_dir / "model_evaluation_results.nc" if output_path.exists(): return output_path holdout_loader = build_holdout_loader( config=config, root_dir=root_dir, seed=_training_seed(config, backend_dir), ) if backend == "ensemble": ds = _evaluate_ensemble(config, backend_dir, holdout_loader) elif backend == "bayesian": ds = _evaluate_bayesian(config, backend_dir, holdout_loader, bayesian_mc_passes) else: raise ValueError(f"Unsupported backend: {backend}") ds.to_netcdf(output_path, mode="w") return output_path