evaluate_models.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
# This program evaluates every model on the combined validation and test sets, then saves the results to a NetCDF file.
  2. import torch
  3. import xarray as xr
  4. from torch.utils.data import DataLoader
  5. import numpy as np
  6. # Config
  7. from model.cnn import CNN3D
  8. from utils.config import config
  9. import pathlib as pl
  10. import pandas as pd
  11. import json
  12. # Custom modules
  13. from data.dataset import (
  14. load_adni_data_from_file,
  15. divide_dataset,
  16. initalize_dataloaders,
  17. ADNIDataset,
  18. )
  19. mri_files = pl.Path(config["data"]["mri_files_path"]).glob("*.nii")
  20. xls_file = pl.Path(config["data"]["xls_file_path"])
  21. def xls_pre(df: pd.DataFrame) -> pd.DataFrame:
  22. """
  23. Preprocess the Excel DataFrame.
  24. This function can be customized to filter or modify the DataFrame as needed.
  25. """
  26. data = df[["Image Data ID", "Sex", "Age (current)"]]
  27. data["Sex"] = data["Sex"].str.strip() # type: ignore
  28. data = data.replace({"M": 0, "F": 1}) # type: ignore
  29. data.set_index("Image Data ID") # type: ignore
  30. return data
  31. dataset = load_adni_data_from_file(
  32. mri_files, xls_file, device=config["training"]["device"], xls_preprocessor=xls_pre
  33. )
  34. # Divide the dataset into training and validation sets, using the same seed as training
  35. with open(pl.Path(config["output"]["path"]) / "config.json") as f:
  36. training_config = json.load(f)
  37. try:
  38. loaded_seed = int(training_config["data"]["seed"])
  39. except (ValueError, KeyError) as e:
  40. print(
  41. f"Warning: No previous seed found for dataset division, using seed from config. Error: {e}"
  42. )
  43. loaded_seed = config["data"]["seed"]
  44. datasets = divide_dataset(dataset, config["data"]["data_splits"], seed=loaded_seed)
  45. # Initialize the dataloaders
  46. train_loader, val_loader, test_loader = initalize_dataloaders(
  47. datasets, batch_size=config["training"]["batch_size"]
  48. )
  49. # Combine validation and test sets for final evaluation
  50. combined_loader: DataLoader[ADNIDataset] = torch.utils.data.DataLoader(
  51. torch.utils.data.ConcatDataset([val_loader.dataset, test_loader.dataset]),
  52. batch_size=1,
  53. shuffle=False,
  54. )
  55. # 50 models are too large to load into memory at once, so we will load and evaluate them one at a time
  56. model_dir = pl.Path(config["output"]["path"])
  57. model_files = sorted(model_dir.glob("model_run_*.pt"))
  58. placeholder = np.zeros(
  59. (len(model_files), len(combined_loader), config["data"]["num_classes"]),
  60. dtype=np.float32,
  61. ) # Placeholder for results
  62. placeholder[:] = np.nan # Fill with NaNs for easier identification of missing data
  63. dimensions = ["model", "batch", "img_class"]
  64. coords = {
  65. "model": [int(mf.stem.split("_")[2]) for mf in model_files],
  66. "batch": list(range(len(combined_loader))),
  67. "img_class": list(range(config["data"]["num_classes"])),
  68. }
  69. results = xr.DataArray(placeholder, coords=coords, dims=dimensions)
  70. for model_file in model_files:
  71. model_num = int(model_file.stem.split("_")[2])
  72. print(f"Evaluating model {model_num}...")
  73. # Load the model state
  74. model = (
  75. CNN3D(
  76. image_channels=config["data"]["image_channels"],
  77. clin_data_channels=config["data"]["clin_data_channels"],
  78. num_classes=config["data"]["num_classes"],
  79. droprate=config["training"]["droprate"],
  80. )
  81. .float()
  82. .to(config["training"]["device"])
  83. )
  84. model.load_state_dict(
  85. torch.load(model_file, map_location=config["training"]["device"]), strict=False
  86. )
  87. model.eval()
  88. with torch.no_grad():
  89. for batch_idx, (mri_batch, xls_batch, labels_batch) in enumerate(
  90. combined_loader
  91. ):
  92. outputs = model((mri_batch.float(), xls_batch.float()))
  93. probabilities = outputs.cpu().numpy()[0, :] # type: ignore
  94. results.loc[model_num, batch_idx, :] = probabilities # type: ignore
  95. # Save results to netcdf file
  96. output_path = pl.Path(config["output"]["path"]) / "model_evaluation_results.nc"
  97. results.to_netcdf(output_path, mode="w") # type: ignore
  98. print(f"Results saved to {output_path}")