# This program evaluates every model on the combined validation and test set, then saves the results to a netcdf file. import torch import xarray as xr from torch.utils.data import DataLoader import numpy as np # Config from model.cnn import CNN3D from utils.config import config import pathlib as pl import pandas as pd import json # Custom modules from data.dataset import ( load_adni_data_from_file, divide_dataset, initalize_dataloaders, ADNIDataset, ) mri_files = pl.Path(config["data"]["mri_files_path"]).glob("*.nii") xls_file = pl.Path(config["data"]["xls_file_path"]) def xls_pre(df: pd.DataFrame) -> pd.DataFrame: """ Preprocess the Excel DataFrame. This function can be customized to filter or modify the DataFrame as needed. """ data = df[["Image Data ID", "Sex", "Age (current)"]] data["Sex"] = data["Sex"].str.strip() # type: ignore data = data.replace({"M": 0, "F": 1}) # type: ignore data.set_index("Image Data ID") # type: ignore return data dataset = load_adni_data_from_file( mri_files, xls_file, device=config["training"]["device"], xls_preprocessor=xls_pre ) # Divide the dataset into training and validation sets, using the same seed as training with open(pl.Path(config["output"]["path"]) / "config.json") as f: training_config = json.load(f) try: loaded_seed = int(training_config["data"]["seed"]) except (ValueError, KeyError) as e: print( f"Warning: No previous seed found for dataset division, using seed from config. Error: {e}" ) loaded_seed = config["data"]["seed"] datasets = divide_dataset(dataset, config["data"]["data_splits"], seed=loaded_seed) # Initialize the dataloaders train_loader, val_loader, test_loader = initalize_dataloaders( datasets, batch_size=config["training"]["batch_size"] ) # Combine validation and test sets for final evaluation combined_loader: DataLoader[ADNIDataset] = torch.utils.data.DataLoader( torch.utils.data.ConcatDataset([val_loader.dataset, test_loader.dataset]), batch_size=1, shuffle=False, ) # 50 models are too large to load into memory at once, so we will load and evaluate them one at a time model_dir = pl.Path(config["output"]["path"]) model_files = sorted(model_dir.glob("model_run_*.pt")) placeholder = np.zeros( (len(model_files), len(combined_loader), config["data"]["num_classes"]), dtype=np.float32, ) # Placeholder for results placeholder[:] = np.nan # Fill with NaNs for easier identification of missing data dimensions = ["model", "batch", "img_class"] coords = { "model": [int(mf.stem.split("_")[2]) for mf in model_files], "batch": list(range(len(combined_loader))), "img_class": list(range(config["data"]["num_classes"])), } results = xr.DataArray(placeholder, coords=coords, dims=dimensions) for model_file in model_files: model_num = int(model_file.stem.split("_")[2]) print(f"Evaluating model {model_num}...") # Load the model state model = ( CNN3D( image_channels=config["data"]["image_channels"], clin_data_channels=config["data"]["clin_data_channels"], num_classes=config["data"]["num_classes"], droprate=config["training"]["droprate"], ) .float() .to(config["training"]["device"]) ) model.load_state_dict( torch.load(model_file, map_location=config["training"]["device"]), strict=False ) model.eval() with torch.no_grad(): for batch_idx, (mri_batch, xls_batch, labels_batch) in enumerate( combined_loader ): outputs = model((mri_batch.float(), xls_batch.float())) probabilities = outputs.cpu().numpy()[0, :] # type: ignore results.loc[model_num, batch_idx, :] = probabilities # type: ignore # Save results to netcdf file output_path = pl.Path(config["output"]["path"]) / "model_evaluation_results.nc" results.to_netcdf(output_path, mode="w") # type: ignore print(f"Results saved to {output_path}")