dataset.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. import nibabel as nib
  2. import torch
  3. import torch.utils.data as data
  4. import pathlib as pl
  5. import pandas as pd
  6. from torch.utils.data import Subset, DataLoader
  7. import re
  8. import random
  9. from jaxtyping import Float
  10. from typing import Tuple, Iterator, Callable, List, Dict
  11. class ADNIDataset(data.Dataset): # type: ignore
  12. """
  13. A PyTorch Dataset class for loading
  14. and processing MRI and Excel data from the ADNI dataset.
  15. """
  16. def __init__(
  17. self,
  18. mri_data: Float[torch.Tensor, "n_samples channels width height depth"],
  19. xls_data: Float[torch.Tensor, "n_samples features"],
  20. expected_classes: Float[torch.Tensor, "classes"],
  21. filename_ids: List[int],
  22. device: str = "cuda",
  23. ):
  24. """
  25. Args:
  26. mri_data (torch.Tensor): 5D tensor of MRI data with shape (n_samples, channels, width, height, depth).
  27. xls_data (torch.Tensor): 2D tensor of Excel data with shape (n_samples, features).
  28. """
  29. self.mri_data = mri_data.float().to(device)
  30. self.xls_data = xls_data.float().to(device)
  31. self.expected_classes = expected_classes.float().to(device)
  32. self.filename_ids = filename_ids
  33. def __len__(self) -> int:
  34. """
  35. Returns the number of samples in the dataset.
  36. """
  37. return self.mri_data.shape[0] # 0th dimension is the number of samples
  38. def __getitem__(self, idx: int) -> Tuple[
  39. Float[torch.Tensor, "channels width height depth"],
  40. Float[torch.Tensor, "features"],
  41. Float[torch.Tensor, "classes"],
  42. int,
  43. ]:
  44. """
  45. Returns a sample from the dataset at the given index.
  46. Args:
  47. idx (int): Index of the sample to retrieve.
  48. Returns:
  49. tuple: A tuple containing the MRI data and Excel data for the sample.
  50. """
  51. # Slices the data on the 0th dimension, corresponding to the sample index
  52. mri_sample = self.mri_data[idx]
  53. xls_sample = self.xls_data[idx]
  54. # Assuming expected_classes is a tensor of classes, we return it as well
  55. expected_classes = self.expected_classes[idx]
  56. filename_id = self.filename_ids[idx]
  57. return mri_sample, xls_sample, expected_classes, filename_id
  58. def load_adni_data_from_file(
  59. mri_files: Iterator[pl.Path], # List of nibablel files
  60. xls_file: pl.Path, # Path to the Excel file
  61. device: str = "cuda",
  62. xls_preprocessor: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x,
  63. ) -> ADNIDataset:
  64. """
  65. Loads MRI and Excel data from the ADNI dataset.
  66. Args:
  67. mri_files (List[pl.Path]): List of paths to the MRI files.
  68. xls_file (pl.Path): Path to the Excel file.
  69. Returns:
  70. ADNIDataset: The loaded dataset.
  71. """
  72. # Load the Excel data
  73. xls_values = xls_preprocessor(pd.read_csv(xls_file)) # type: ignore
  74. # Load the MRI data
  75. mri_data_unstacked: List[torch.Tensor] = []
  76. expected_classes_unstacked: List[torch.Tensor] = []
  77. xls_data_unstacked: List[torch.Tensor] = []
  78. img_ids: List[int] = []
  79. for file in mri_files:
  80. filename = file.stem
  81. match re.search(r".+?(?=_I)_I(\d+).+", filename):
  82. case None:
  83. raise ValueError(
  84. f"Filename {filename} does not match expected pattern."
  85. )
  86. case m:
  87. img_id = int(m.group(1))
  88. file_mri_data = torch.from_numpy(nib.load(file).get_fdata()) # type: ignore # type checking does not work well with nibabel
  89. # Read the filename to determine the expected class
  90. file_expected_class = torch.tensor([0.0, 0.0]) # Default to a tensor of zeros
  91. if "AD" in filename:
  92. file_expected_class = torch.tensor([1.0, 0.0])
  93. elif "NL" in filename:
  94. file_expected_class = torch.tensor([0.0, 1.0])
  95. else:
  96. raise ValueError(
  97. f"Filename {filename} does not contain a valid class identifier (AD or CN)."
  98. )
  99. mri_data_unstacked.append(file_mri_data)
  100. expected_classes_unstacked.append(file_expected_class)
  101. # Extract the corresponding row from the Excel data using the img_id
  102. xls_row = xls_values.loc[xls_values["Image Data ID"] == img_id]
  103. if xls_row.empty:
  104. raise ValueError(
  105. f"No matching row found in Excel data for Image Data ID {img_id}."
  106. )
  107. elif len(xls_row) > 1:
  108. raise ValueError(
  109. f"Multiple rows found in Excel data for Image Data ID {img_id}."
  110. )
  111. file_xls_data = torch.tensor(
  112. xls_row.drop(columns=["Image Data ID"]).values.flatten() # type: ignore
  113. )
  114. xls_data_unstacked.append(file_xls_data)
  115. img_ids.append(img_id)
  116. mri_data = torch.stack(mri_data_unstacked).unsqueeze(1)
  117. # Stack the list of tensors into a single tensor and unsqueeze along the channel dimension
  118. xls_data = torch.stack(
  119. xls_data_unstacked
  120. ) # Stack the list of tensors into a single tensor
  121. expected_classes = torch.stack(
  122. expected_classes_unstacked
  123. ) # Stack the list of expected classes into a single tensor
  124. return ADNIDataset(mri_data, xls_data, expected_classes, img_ids, device=device)
  125. def divide_dataset(
  126. dataset: ADNIDataset,
  127. ratios: Tuple[float, float, float],
  128. seed: int,
  129. ) -> List[data.Subset[ADNIDataset]]:
  130. """
  131. Divides the dataset into training, validation, and test sets.
  132. Args:
  133. dataset (ADNIDataset): The dataset to divide.
  134. train_ratio (float): The ratio of the training set.
  135. val_ratio (float): The ratio of the validation set.
  136. test_ratio (float): The ratio of the test set.
  137. Returns:
  138. Result[List[data.Subset[ADNIDataset]], str]: A Result object containing the subsets or an error message.
  139. """
  140. if sum(ratios) != 1.0:
  141. raise ValueError(f"Ratios must sum to 1.0, got {ratios}.")
  142. # Set the random seed for reproducibility
  143. gen = torch.Generator().manual_seed(seed)
  144. return data.random_split(dataset, ratios, generator=gen)
  145. def initalize_dataloaders(
  146. datasets: List[Subset[ADNIDataset]],
  147. batch_size: int = 64,
  148. ) -> List[DataLoader[ADNIDataset]]:
  149. """
  150. Initializes the DataLoader for the given datasets.
  151. Args:
  152. datasets (List[Subset[ADNIDataset]]): List of datasets to create DataLoaders for.
  153. batch_size (int): The batch size for the DataLoader.
  154. Returns:
  155. List[DataLoader[ADNIDataset]]: A list of DataLoaders for the datasets.
  156. """
  157. return [
  158. DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in datasets
  159. ]
  160. def divide_dataset_by_patient_id(
  161. dataset: ADNIDataset,
  162. ptids: List[Tuple[int, str]],
  163. ratios: Tuple[float, float, float],
  164. seed: int,
  165. ) -> List[data.Subset[ADNIDataset]]:
  166. """
  167. Divides the dataset into training, validation, and test sets based on patient IDs.
  168. Ensures that all samples from the same patient are in the same set.
  169. Args:
  170. dataset (ADNIDataset): The dataset to divide.
  171. ptids (List[Tuple[int, str]]): A list of tuples containing image file IDs and their corresponding patient IDs.
  172. ratios (Tuple[float, float, float]): The ratios for training, validation, and test sets.
  173. seed (int): The random seed for reproducibility.
  174. Returns:
  175. List[data.Subset[ADNIDataset]]: A list of subsets for training, validation, and test sets.
  176. Notes:
  177. This split is grouped by PTID, so all images from the same patient are assigned
  178. to exactly one partition to avoid patient-level leakage across train/val/test.
  179. """
  180. if sum(ratios) != 1.0:
  181. raise ValueError(f"Ratios must sum to 1.0, got {ratios}.")
  182. if not ptids:
  183. raise ValueError("ptids list cannot be empty.")
  184. image_to_patient: Dict[int, str] = {}
  185. for image_id, patient_id in ptids:
  186. image_id_int = int(image_id)
  187. patient_id_str = str(patient_id).strip()
  188. if not patient_id_str or patient_id_str.lower() == "nan":
  189. raise ValueError(f"Invalid PTID for Image Data ID {image_id_int}.")
  190. if (
  191. image_id_int in image_to_patient
  192. and image_to_patient[image_id_int] != patient_id_str
  193. ):
  194. raise ValueError(
  195. f"Conflicting PTIDs for Image Data ID {image_id_int}: "
  196. f"{image_to_patient[image_id_int]} vs {patient_id_str}."
  197. )
  198. image_to_patient[image_id_int] = patient_id_str
  199. patient_to_indices: Dict[str, List[int]] = {}
  200. for idx, image_id in enumerate(dataset.filename_ids):
  201. if image_id not in image_to_patient:
  202. raise ValueError(
  203. f"Missing PTID mapping for dataset Image Data ID {image_id}."
  204. )
  205. patient_id = image_to_patient[image_id]
  206. if patient_id not in patient_to_indices:
  207. patient_to_indices[patient_id] = []
  208. patient_to_indices[patient_id].append(idx)
  209. shuffled_patients = list(patient_to_indices.keys())
  210. random.Random(seed).shuffle(shuffled_patients)
  211. train_cutoff = int(len(shuffled_patients) * ratios[0])
  212. val_cutoff = train_cutoff + int(len(shuffled_patients) * ratios[1])
  213. train_patients = shuffled_patients[:train_cutoff]
  214. val_patients = shuffled_patients[train_cutoff:val_cutoff]
  215. test_patients = shuffled_patients[val_cutoff:]
  216. train_patient_set = set(train_patients)
  217. val_patient_set = set(val_patients)
  218. test_patient_set = set(test_patients)
  219. if (
  220. train_patient_set & val_patient_set
  221. or train_patient_set & test_patient_set
  222. or val_patient_set & test_patient_set
  223. ):
  224. raise ValueError("Patient separation violated across train/val/test splits.")
  225. all_patients = set(patient_to_indices.keys())
  226. if train_patient_set | val_patient_set | test_patient_set != all_patients:
  227. raise ValueError("Not all patients were assigned to a split.")
  228. train_indices = [
  229. idx for patient_id in train_patients for idx in patient_to_indices[patient_id]
  230. ]
  231. val_indices = [
  232. idx for patient_id in val_patients for idx in patient_to_indices[patient_id]
  233. ]
  234. test_indices = [
  235. idx for patient_id in test_patients for idx in patient_to_indices[patient_id]
  236. ]
  237. train_index_set = set(train_indices)
  238. val_index_set = set(val_indices)
  239. test_index_set = set(test_indices)
  240. if (
  241. train_index_set & val_index_set
  242. or train_index_set & test_index_set
  243. or val_index_set & test_index_set
  244. ):
  245. raise ValueError("Sample index overlap detected across train/val/test splits.")
  246. all_split_indices = train_index_set | val_index_set | test_index_set
  247. expected_indices = set(range(len(dataset)))
  248. if all_split_indices != expected_indices:
  249. raise ValueError(
  250. "Split coverage check failed: not all dataset samples are assigned exactly once."
  251. )
  252. return [
  253. Subset(dataset, train_indices),
  254. Subset(dataset, val_indices),
  255. Subset(dataset, test_indices),
  256. ]