# This program evaluates every model on the combined validation and test set,
# then saves the results to a netCDF file.

import json
import pathlib as pl

import numpy as np
import pandas as pd
import torch
import xarray as xr
from torch.utils.data import ConcatDataset, DataLoader

# Model and config
from model.cnn import CNN3D
from utils.config import config

# Custom modules
from data.dataset import (
    ADNIDataset,
    divide_dataset,
    initalize_dataloaders,
    load_adni_data_from_file,
)

mri_files = pl.Path(config["data"]["mri_files_path"]).glob("*.nii")
xls_file = pl.Path(config["data"]["xls_file_path"])

def xls_pre(df: pd.DataFrame) -> pd.DataFrame:
    """
    Preprocess the clinical-data Excel DataFrame.

    Keeps only the image ID, sex, and age columns, strips whitespace from the
    sex column, encodes sex as 0 (M) / 1 (F), and indexes the frame by image ID.
    """
    data = df[["Image Data ID", "Sex", "Age (current)"]].copy()
    data["Sex"] = data["Sex"].str.strip()  # type: ignore
    data = data.replace({"M": 0, "F": 1})  # type: ignore
    data = data.set_index("Image Data ID")  # type: ignore
    return data
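
# Illustrative example of what xls_pre produces (hypothetical values):
#
#   raw = pd.DataFrame(
#       {"Image Data ID": ["I100"], "Sex": ["M "], "Age (current)": [73.0]}
#   )
#   xls_pre(raw)
#   #                 Sex  Age (current)
#   # Image Data ID
#   # I100              0           73.0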

dataset = load_adni_data_from_file(
    mri_files, xls_file, device=config["training"]["device"], xls_preprocessor=xls_pre
)

# Divide the dataset into training, validation, and test sets using the same
# seed that was used during training, so the splits match the original run.
with open(pl.Path(config["output"]["path"]) / "config.json") as f:
    training_config = json.load(f)

try:
    loaded_seed = int(training_config["data"]["seed"])
except (ValueError, KeyError) as e:
    print(
        f"Warning: no previous seed found for dataset division; "
        f"falling back to the seed in the current config. Error: {e}"
    )
    loaded_seed = config["data"]["seed"]

datasets = divide_dataset(dataset, config["data"]["data_splits"], seed=loaded_seed)

# Initialize the dataloaders
train_loader, val_loader, test_loader = initalize_dataloaders(
    datasets, batch_size=config["training"]["batch_size"]
)

# Combine the validation and test sets for the final evaluation. batch_size=1
# means every index along the "batch" dimension corresponds to a single sample.
combined_loader: DataLoader[ADNIDataset] = DataLoader(
    ConcatDataset([val_loader.dataset, test_loader.dataset]),
    batch_size=1,
    shuffle=False,
)
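
# Optional sanity check (assumes the underlying datasets implement __len__,
# which ConcatDataset already requires): with batch_size=1 the loader yields
# one sample per batch, so its length equals the validation + test sample count.
assert len(combined_loader) == len(val_loader.dataset) + len(test_loader.dataset)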

# The 50 trained models are too large to load into memory at once, so load and
# evaluate them one at a time.
model_dir = pl.Path(config["output"]["path"])
model_files = sorted(model_dir.glob("model_run_*.pt"))

# Pre-allocate the results array, filled with NaN so any missing evaluations
# are easy to identify later.
placeholder = np.full(
    (len(model_files), len(combined_loader), config["data"]["num_classes"]),
    np.nan,
    dtype=np.float32,
)

dimensions = ["model", "batch", "img_class"]
coords = {
    "model": [int(mf.stem.split("_")[2]) for mf in model_files],
    "batch": list(range(len(combined_loader))),
    "img_class": list(range(config["data"]["num_classes"])),
}
results = xr.DataArray(placeholder, coords=coords, dims=dimensions)
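
# Note: `results.loc[model_num, batch_idx, :]` below selects by coordinate
# label rather than by position, so the run numbers parsed from the filenames
# do not need to be contiguous or zero-based.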

for model_file in model_files:
    model_num = int(model_file.stem.split("_")[2])
    print(f"Evaluating model {model_num}...")

    # Rebuild the architecture and load the saved weights for this run
    model = (
        CNN3D(
            image_channels=config["data"]["image_channels"],
            clin_data_channels=config["data"]["clin_data_channels"],
            num_classes=config["data"]["num_classes"],
            droprate=config["training"]["droprate"],
        )
        .float()
        .to(config["training"]["device"])
    )
    model.load_state_dict(
        torch.load(model_file, map_location=config["training"]["device"]), strict=False
    )
    model.eval()  # disable dropout for evaluation

    with torch.no_grad():
        for batch_idx, (mri_batch, xls_batch, labels_batch) in enumerate(
            combined_loader
        ):
            outputs = model((mri_batch.float(), xls_batch.float()))
            # batch_size is 1, so keep the single row of class outputs
            probabilities = outputs.cpu().numpy()[0, :]  # type: ignore
            results.loc[model_num, batch_idx, :] = probabilities  # type: ignore

# Save the results to a netCDF file
output_path = pl.Path(config["output"]["path"]) / "model_evaluation_results.nc"
results.to_netcdf(output_path, mode="w")  # type: ignore
print(f"Results saved to {output_path}")