# evaluate_models.py
#
# This program evaluates every model on the combined validation and test sets,
# then saves the results to a NetCDF file.
  2. import json
  3. import pathlib as pl
  4. import numpy as np
  5. import pandas as pd
  6. import torch
  7. import xarray as xr
  8. from torch.utils.data import DataLoader
  9. # Custom modules
  10. from data.dataset import (
  11. ADNIDataset,
  12. divide_dataset,
  13. initalize_dataloaders,
  14. load_adni_data_from_file,
  15. )
  16. # Config
  17. from model.cnn import CNN3D
  18. from utils.config import config
  19. mri_files = pl.Path(config["data"]["mri_files_path"]).glob("*.nii")
  20. xls_file = pl.Path(config["data"]["xls_file_path"])
  21. def xls_pre(df: pd.DataFrame) -> pd.DataFrame:
  22. """
  23. Preprocess the Excel DataFrame.
  24. This function can be customized to filter or modify the DataFrame as needed.
  25. """
  26. data = df[["Image Data ID", "Sex", "Age (current)"]]
  27. data["Sex"] = data["Sex"].str.strip() # type: ignore
  28. data = data.replace({"M": 0, "F": 1}) # type: ignore
  29. data.set_index("Image Data ID") # type: ignore
  30. return data
  31. dataset = load_adni_data_from_file(
  32. mri_files, xls_file, device=config["training"]["device"], xls_preprocessor=xls_pre
  33. )
  34. # Divide the dataset into training and validation sets, using the same seed as training
  35. with open(pl.Path(config["output"]["path"]) / "config.json") as f:
  36. training_config = json.load(f)
  37. try:
  38. loaded_seed = int(training_config["data"]["seed"])
  39. except (ValueError, KeyError) as e:
  40. print(
  41. f"Warning: No previous seed found for dataset division, using seed from config. Error: {e}"
  42. )
  43. loaded_seed = config["data"]["seed"]
  44. datasets = divide_dataset(dataset, config["data"]["data_splits"], seed=loaded_seed)
  45. # Initialize the dataloadersx
  46. train_loader, val_loader, test_loader = initalize_dataloaders(
  47. datasets, batch_size=config["training"]["batch_size"]
  48. )
  49. # Combine validation and test sets for final evaluation
  50. combined_loader: DataLoader[ADNIDataset] = torch.utils.data.DataLoader(
  51. torch.utils.data.ConcatDataset([val_loader.dataset, test_loader.dataset]),
  52. batch_size=1,
  53. shuffle=False,
  54. )
  55. # 50 models are too large to load into memory at once, so we will load and evaluate them one at a time
  56. model_dir = pl.Path(config["output"]["path"])
  57. model_files = sorted(model_dir.glob("model_run_*.pt"))
  58. placeholder = np.zeros(
  59. (len(model_files), len(combined_loader), config["data"]["num_classes"]),
  60. dtype=np.float32,
  61. ) # Placeholder for results
  62. # Get the total list of image_ids
  63. img_ids = [img_id for _, _, _, img_id in combined_loader.dataset]
  64. placeholder[:] = np.nan # Fill with NaNs for easier identification of missing data
  65. dimensions = ["model", "img_id", "img_class"]
  66. coords = {
  67. "model": [int(mf.stem.split("_")[2]) for mf in model_files],
  68. "img_id": img_ids,
  69. "img_class": list(range(config["data"]["num_classes"])),
  70. }
  71. results = xr.DataArray(placeholder, coords=coords, dims=dimensions)
  72. # Now initialize an additional dataarray to hold the labels per image
  73. labels_placeholder = np.zeros(
  74. (len(combined_loader), config["data"]["num_classes"]), dtype=np.float32
  75. )
  76. labels_placeholder[:] = np.nan
  77. labels_coords = {
  78. "img_id": img_ids,
  79. "label": list(range(config["data"]["num_classes"])),
  80. } # type: ignore
  81. labels = xr.DataArray(
  82. labels_placeholder, coords=labels_coords, dims=["img_id", "label"]
  83. )
  84. for model_file in model_files:
  85. model_num = int(model_file.stem.split("_")[2])
  86. print(f"Evaluating model {model_num}...")
  87. # Load the model state
  88. model = (
  89. CNN3D(
  90. image_channels=config["data"]["image_channels"],
  91. clin_data_channels=config["data"]["clin_data_channels"],
  92. num_classes=config["data"]["num_classes"],
  93. droprate=config["training"]["droprate"],
  94. )
  95. .float()
  96. .to(config["training"]["device"])
  97. )
  98. model.load_state_dict(
  99. torch.load(model_file, map_location=config["training"]["device"]), strict=False
  100. )
  101. model.eval()
  102. with torch.no_grad():
  103. for batch_idx, (mri, xls, label, img_id) in enumerate(combined_loader):
  104. outputs = model((mri.float(), xls.float()))
  105. probabilities = outputs.cpu().numpy()[0, :] # type: ignore
  106. results.loc[model_num, img_id, :] = probabilities # type: ignore
  107. labels.loc[int(img_id.cpu()), :] = label.cpu().numpy()[0, :] # type: ignore
  108. # Combine results and labels into a single Dataset
  109. output_set = xr.Dataset({"predictions": results, "labels": labels})
  110. # Save results to netcdf file
  111. output_path = pl.Path(config["output"]["path"]) / "model_evaluation_results.nc"
  112. output_set.to_netcdf(output_path, mode="w") # type: ignore
  113. print(f"Results saved to {output_path}")