|
@@ -0,0 +1,128 @@
|
|
|
|
# This program evaluates every model on the combined validation and test sets, then saves the results to a NetCDF file.
|
|
|
|
+
|
|
|
|
+import torch
|
|
|
|
+import xarray as xr
|
|
|
|
+from torch.utils.data import DataLoader
|
|
|
|
+import numpy as np
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+# Config
|
|
|
|
+from model.cnn import CNN3D
|
|
|
|
+from utils.config import config
|
|
|
|
+import pathlib as pl
|
|
|
|
+import pandas as pd
|
|
|
|
+import json
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+# Custom modules
|
|
|
|
+from data.dataset import (
|
|
|
|
+ load_adni_data_from_file,
|
|
|
|
+ divide_dataset,
|
|
|
|
+ initalize_dataloaders,
|
|
|
|
+ ADNIDataset,
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
# Input locations come from the "data" section of the shared configuration.
data_paths = config["data"]

# Path.glob returns a lazy, single-use iterator over the *.nii files that sit
# directly inside the configured MRI directory (no recursion into subfolders).
mri_files = pl.Path(data_paths["mri_files_path"]).glob("*.nii")

# Spreadsheet that is handed to the dataset loader together with the MRI files.
xls_file = pl.Path(data_paths["xls_file_path"])
|
|
|
|
+
|
|
|
|
+
|
|
|
|
def xls_pre(df: pd.DataFrame) -> pd.DataFrame:
    """
    Preprocess the Excel DataFrame.
    This function can be customized to filter or modify the DataFrame as needed.

    Keeps only the "Image Data ID", "Sex" and "Age (current)" columns, strips
    surrounding whitespace from "Sex", and replaces the values "M" and "F"
    with 0 and 1 respectively. The input frame is left untouched.

    Args:
        df: Frame containing at least the three columns named above, with
            string values in "Sex".

    Returns:
        A new DataFrame with the three selected columns, in that order, and
        the same row index as ``df``.

    Raises:
        KeyError: If any of the three required columns is missing.
    """

    # Work on an explicit copy: assigning a column on a bare column selection
    # is a chained assignment that pandas warns about.
    data = df[["Image Data ID", "Sex", "Age (current)"]].copy()
    data["Sex"] = data["Sex"].str.strip()  # type: ignore
    data = data.replace({"M": 0, "F": 1})  # type: ignore

    # NOTE(review): the previous version called data.set_index("Image Data ID")
    # without using the result, which has no effect. The ID therefore stays a
    # regular column here; confirm what the loader expects before turning it
    # into the index.
    return data
|
|
|
|
+
|
|
|
|
+
|
|
|
|
# Build the dataset from the MRI files and the spreadsheet. The spreadsheet is
# passed through xls_pre, and the device comes from the training configuration.
eval_device = config["training"]["device"]
dataset = load_adni_data_from_file(
    mri_files,
    xls_file,
    device=eval_device,
    xls_preprocessor=xls_pre,
)
|
|
|
|
+
|
|
|
|
# Divide the dataset into training, validation and test splits, reusing the
# seed stored with the training run so the splits match the ones used there.
training_config_path = pl.Path(config["output"]["path"]) / "config.json"
try:
    with open(training_config_path, encoding="utf-8") as f:
        training_config = json.load(f)
    loaded_seed = int(training_config["data"]["seed"])
except (FileNotFoundError, ValueError, KeyError, TypeError) as e:
    # FileNotFoundError: no config.json was saved next to the models.
    # ValueError: malformed JSON (JSONDecodeError) or a non-numeric seed.
    # KeyError / TypeError: the seed entry is absent, null, or not nested
    # under "data" as expected.
    print(
        f"Warning: No previous seed found for dataset division, using seed from config. Error: {e}"
    )
    loaded_seed = config["data"]["seed"]

datasets = divide_dataset(dataset, config["data"]["data_splits"], seed=loaded_seed)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
# Build one dataloader per split. Only the validation and test loaders are
# used below; the training loader is unpacked but not evaluated.
train_loader, val_loader, test_loader = initalize_dataloaders(
    datasets,
    batch_size=config["training"]["batch_size"],
)

# Final evaluation runs over the validation samples followed by the test
# samples. With batch_size=1 and shuffle=False every batch holds exactly one
# sample and the iteration order is identical for every model.
held_out_samples = torch.utils.data.ConcatDataset(
    [val_loader.dataset, test_loader.dataset]
)
combined_loader: DataLoader[ADNIDataset] = DataLoader(
    held_out_samples,
    batch_size=1,
    shuffle=False,
)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
# Holding all 50 models in memory at once is not feasible, so the checkpoints
# are only listed here and are loaded and evaluated one at a time further down.
model_dir = pl.Path(config["output"]["path"])
# sorted() orders the paths lexicographically, not by run number; results are
# addressed by the "model" coordinate label below, so that order is harmless.
model_files = sorted(model_dir.glob("model_run_*.pt"))

num_classes = config["data"]["num_classes"]
num_batches = len(combined_loader)

# Result buffer of shape (model, batch, img_class), pre-filled with NaN so any
# entry that never gets written is easy to identify as missing data.
placeholder = np.full(
    (len(model_files), num_batches, num_classes),
    np.nan,
    dtype=np.float32,
)

dimensions = ["model", "batch", "img_class"]
coords = {
    # The run number is the third "_"-separated field of the file stem.
    "model": [int(path.stem.split("_")[2]) for path in model_files],
    "batch": list(range(num_batches)),
    "img_class": list(range(num_classes)),
}

results = xr.DataArray(placeholder, coords=coords, dims=dimensions)
|
|
|
|
+
|
|
|
|
# Evaluate each checkpoint in turn and record its outputs for every batch of
# the combined loader in `results`.
for model_file in model_files:
    # The run number is the third "_"-separated field of the file stem; it is
    # the label used on the "model" axis of `results`.
    model_num = int(model_file.stem.split("_")[2])
    print(f"Evaluating model {model_num}...")

    # Build a fresh network for this checkpoint; rebinding `model` on each
    # iteration lets the previous one be released.
    model = (
        CNN3D(
            image_channels=config["data"]["image_channels"],
            clin_data_channels=config["data"]["clin_data_channels"],
            num_classes=config["data"]["num_classes"],
            droprate=config["training"]["droprate"],
        )
        .float()
        .to(config["training"]["device"])
    )

    # NOTE: torch.load unpickles the file, so only checkpoints from a trusted
    # output directory should be evaluated here.
    state_dict = torch.load(model_file, map_location=config["training"]["device"])

    # strict=False tolerates key mismatches; report them instead of silently
    # evaluating a network whose weights were only partially restored.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"Warning: checkpoint {model_file.name} does not fully match the model. "
            f"Missing keys: {load_result.missing_keys}; "
            f"unexpected keys: {load_result.unexpected_keys}"
        )

    # Switch to evaluation mode before inference.
    model.eval()

    with torch.no_grad():
        # Labels are yielded by the loader but are not needed here; only the
        # model outputs are stored.
        for batch_idx, (mri_batch, xls_batch, _labels_batch) in enumerate(
            combined_loader
        ):
            outputs = model((mri_batch.float(), xls_batch.float()))
            # The loader uses batch_size=1, so row 0 is the only sample.
            # NOTE(review): whether these values are probabilities or raw
            # scores depends on CNN3D's final layer - confirm.
            probabilities = outputs.cpu().numpy()[0, :]  # type: ignore

            # Label-based assignment: model_num and batch_idx are coordinate
            # values on the "model" and "batch" axes, not positions.
            results.loc[model_num, batch_idx, :] = probabilities  # type: ignore
|
|
|
|
+
|
|
|
|
# Write the complete result array to a NetCDF file in the configured output
# directory; mode="w" overwrites a file left over from an earlier run.
output_path = pl.Path(config["output"]["path"]).joinpath(
    "model_evaluation_results.nc"
)
results.to_netcdf(path=output_path, mode="w")  # type: ignore
print(f"Results saved to {output_path}")
|