dataset.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
import math
import pathlib as pl
import random
import re
from typing import Tuple, Iterator, Callable, List, Dict

import nibabel as nib
import pandas as pd
import torch
import torch.utils.data as data
from jaxtyping import Float
from torch.utils.data import Subset, DataLoader
  11. def _row_to_float_tensor(row: pd.DataFrame, *, image_id: int) -> torch.Tensor:
  12. values = row.drop(columns=["Image Data ID"]).iloc[0]
  13. numeric = pd.to_numeric(values, errors="coerce")
  14. if numeric.isna().any():
  15. bad_columns = [
  16. column for column, value in numeric.items() if pd.isna(value)
  17. ]
  18. raise ValueError(
  19. f"Non-numeric Excel values for Image Data ID {image_id}: {bad_columns}"
  20. )
  21. return torch.tensor(numeric.to_numpy(dtype="float32"))
  22. class ADNIDataset(data.Dataset): # type: ignore
  23. """
  24. A PyTorch Dataset class for loading
  25. and processing MRI and Excel data from the ADNI dataset.
  26. """
  27. def __init__(
  28. self,
  29. mri_data: Float[torch.Tensor, "n_samples channels width height depth"],
  30. xls_data: Float[torch.Tensor, "n_samples features"],
  31. expected_classes: Float[torch.Tensor, "classes"],
  32. filename_ids: List[int],
  33. device: str = "cuda",
  34. ):
  35. """
  36. Args:
  37. mri_data (torch.Tensor): 5D tensor of MRI data with shape (n_samples, channels, width, height, depth).
  38. xls_data (torch.Tensor): 2D tensor of Excel data with shape (n_samples, features).
  39. """
  40. self.mri_data = mri_data.float().to(device)
  41. self.xls_data = xls_data.float().to(device)
  42. self.expected_classes = expected_classes.float().to(device)
  43. self.filename_ids = filename_ids
  44. def __len__(self) -> int:
  45. """
  46. Returns the number of samples in the dataset.
  47. """
  48. return self.mri_data.shape[0] # 0th dimension is the number of samples
  49. def __getitem__(self, idx: int) -> Tuple[
  50. Float[torch.Tensor, "channels width height depth"],
  51. Float[torch.Tensor, "features"],
  52. Float[torch.Tensor, "classes"],
  53. int,
  54. ]:
  55. """
  56. Returns a sample from the dataset at the given index.
  57. Args:
  58. idx (int): Index of the sample to retrieve.
  59. Returns:
  60. tuple: A tuple containing the MRI data and Excel data for the sample.
  61. """
  62. # Slices the data on the 0th dimension, corresponding to the sample index
  63. mri_sample = self.mri_data[idx]
  64. xls_sample = self.xls_data[idx]
  65. # Assuming expected_classes is a tensor of classes, we return it as well
  66. expected_classes = self.expected_classes[idx]
  67. filename_id = self.filename_ids[idx]
  68. return mri_sample, xls_sample, expected_classes, filename_id
  69. def load_adni_data_from_file(
  70. mri_files: Iterator[pl.Path], # List of nibablel files
  71. xls_file: pl.Path, # Path to the Excel file
  72. device: str = "cuda",
  73. xls_preprocessor: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x,
  74. ) -> ADNIDataset:
  75. """
  76. Loads MRI and Excel data from the ADNI dataset.
  77. Args:
  78. mri_files (List[pl.Path]): List of paths to the MRI files.
  79. xls_file (pl.Path): Path to the Excel file.
  80. Returns:
  81. ADNIDataset: The loaded dataset.
  82. """
  83. # Load the Excel data
  84. xls_values = xls_preprocessor(pd.read_csv(xls_file)) # type: ignore
  85. # Load the MRI data
  86. mri_data_unstacked: List[torch.Tensor] = []
  87. expected_classes_unstacked: List[torch.Tensor] = []
  88. xls_data_unstacked: List[torch.Tensor] = []
  89. img_ids: List[int] = []
  90. for file in mri_files:
  91. filename = file.stem
  92. match re.search(r".+?(?=_I)_I(\d+).+", filename):
  93. case None:
  94. raise ValueError(
  95. f"Filename {filename} does not match expected pattern."
  96. )
  97. case m:
  98. img_id = int(m.group(1))
  99. file_mri_data = torch.from_numpy(nib.load(file).get_fdata()) # type: ignore # type checking does not work well with nibabel
  100. # Read the filename to determine the expected class
  101. file_expected_class = torch.tensor([0.0, 0.0]) # Default to a tensor of zeros
  102. if "AD" in filename:
  103. file_expected_class = torch.tensor([1.0, 0.0])
  104. elif "NL" in filename:
  105. file_expected_class = torch.tensor([0.0, 1.0])
  106. else:
  107. raise ValueError(
  108. f"Filename {filename} does not contain a valid class identifier (AD or CN)."
  109. )
  110. mri_data_unstacked.append(file_mri_data)
  111. expected_classes_unstacked.append(file_expected_class)
  112. # Extract the corresponding row from the Excel data using the img_id
  113. xls_row = xls_values.loc[xls_values["Image Data ID"] == img_id]
  114. if xls_row.empty:
  115. raise ValueError(
  116. f"No matching row found in Excel data for Image Data ID {img_id}."
  117. )
  118. elif len(xls_row) > 1:
  119. raise ValueError(
  120. f"Multiple rows found in Excel data for Image Data ID {img_id}."
  121. )
  122. file_xls_data = _row_to_float_tensor(xls_row, image_id=img_id)
  123. xls_data_unstacked.append(file_xls_data)
  124. img_ids.append(img_id)
  125. mri_data = torch.stack(mri_data_unstacked).unsqueeze(1)
  126. # Stack the list of tensors into a single tensor and unsqueeze along the channel dimension
  127. xls_data = torch.stack(
  128. xls_data_unstacked
  129. ) # Stack the list of tensors into a single tensor
  130. expected_classes = torch.stack(
  131. expected_classes_unstacked
  132. ) # Stack the list of expected classes into a single tensor
  133. return ADNIDataset(mri_data, xls_data, expected_classes, img_ids, device=device)
  134. def divide_dataset(
  135. dataset: ADNIDataset,
  136. ratios: Tuple[float, float, float],
  137. seed: int,
  138. ) -> List[data.Subset[ADNIDataset]]:
  139. """
  140. Divides the dataset into training, validation, and test sets.
  141. Args:
  142. dataset (ADNIDataset): The dataset to divide.
  143. train_ratio (float): The ratio of the training set.
  144. val_ratio (float): The ratio of the validation set.
  145. test_ratio (float): The ratio of the test set.
  146. Returns:
  147. Result[List[data.Subset[ADNIDataset]], str]: A Result object containing the subsets or an error message.
  148. """
  149. if sum(ratios) != 1.0:
  150. raise ValueError(f"Ratios must sum to 1.0, got {ratios}.")
  151. # Set the random seed for reproducibility
  152. gen = torch.Generator().manual_seed(seed)
  153. return data.random_split(dataset, ratios, generator=gen)
  154. def initalize_dataloaders(
  155. datasets: List[Subset[ADNIDataset]],
  156. batch_size: int = 64,
  157. ) -> List[DataLoader[ADNIDataset]]:
  158. """
  159. Initializes the DataLoader for the given datasets.
  160. Args:
  161. datasets (List[Subset[ADNIDataset]]): List of datasets to create DataLoaders for.
  162. batch_size (int): The batch size for the DataLoader.
  163. Returns:
  164. List[DataLoader[ADNIDataset]]: A list of DataLoaders for the datasets.
  165. """
  166. return [
  167. DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in datasets
  168. ]
  169. def divide_dataset_by_patient_id(
  170. dataset: ADNIDataset,
  171. ptids: List[Tuple[int, str]],
  172. ratios: Tuple[float, float, float],
  173. seed: int,
  174. ) -> List[data.Subset[ADNIDataset]]:
  175. """
  176. Divides the dataset into training, validation, and test sets based on patient IDs.
  177. Ensures that all samples from the same patient are in the same set.
  178. Args:
  179. dataset (ADNIDataset): The dataset to divide.
  180. ptids (List[Tuple[int, str]]): A list of tuples containing image file IDs and their corresponding patient IDs.
  181. ratios (Tuple[float, float, float]): The ratios for training, validation, and test sets.
  182. seed (int): The random seed for reproducibility.
  183. Returns:
  184. List[data.Subset[ADNIDataset]]: A list of subsets for training, validation, and test sets.
  185. Notes:
  186. This split is grouped by PTID, so all images from the same patient are assigned
  187. to exactly one partition to avoid patient-level leakage across train/val/test.
  188. """
  189. if sum(ratios) != 1.0:
  190. raise ValueError(f"Ratios must sum to 1.0, got {ratios}.")
  191. if not ptids:
  192. raise ValueError("ptids list cannot be empty.")
  193. image_to_patient: Dict[int, str] = {}
  194. for image_id, patient_id in ptids:
  195. image_id_int = int(image_id)
  196. patient_id_str = str(patient_id).strip()
  197. if not patient_id_str or patient_id_str.lower() == "nan":
  198. raise ValueError(f"Invalid PTID for Image Data ID {image_id_int}.")
  199. if (
  200. image_id_int in image_to_patient
  201. and image_to_patient[image_id_int] != patient_id_str
  202. ):
  203. raise ValueError(
  204. f"Conflicting PTIDs for Image Data ID {image_id_int}: "
  205. f"{image_to_patient[image_id_int]} vs {patient_id_str}."
  206. )
  207. image_to_patient[image_id_int] = patient_id_str
  208. patient_to_indices: Dict[str, List[int]] = {}
  209. for idx, image_id in enumerate(dataset.filename_ids):
  210. if image_id not in image_to_patient:
  211. raise ValueError(
  212. f"Missing PTID mapping for dataset Image Data ID {image_id}."
  213. )
  214. patient_id = image_to_patient[image_id]
  215. if patient_id not in patient_to_indices:
  216. patient_to_indices[patient_id] = []
  217. patient_to_indices[patient_id].append(idx)
  218. shuffled_patients = list(patient_to_indices.keys())
  219. random.Random(seed).shuffle(shuffled_patients)
  220. train_cutoff = int(len(shuffled_patients) * ratios[0])
  221. val_cutoff = train_cutoff + int(len(shuffled_patients) * ratios[1])
  222. train_patients = shuffled_patients[:train_cutoff]
  223. val_patients = shuffled_patients[train_cutoff:val_cutoff]
  224. test_patients = shuffled_patients[val_cutoff:]
  225. train_patient_set = set(train_patients)
  226. val_patient_set = set(val_patients)
  227. test_patient_set = set(test_patients)
  228. if (
  229. train_patient_set & val_patient_set
  230. or train_patient_set & test_patient_set
  231. or val_patient_set & test_patient_set
  232. ):
  233. raise ValueError("Patient separation violated across train/val/test splits.")
  234. all_patients = set(patient_to_indices.keys())
  235. if train_patient_set | val_patient_set | test_patient_set != all_patients:
  236. raise ValueError("Not all patients were assigned to a split.")
  237. train_indices = [
  238. idx for patient_id in train_patients for idx in patient_to_indices[patient_id]
  239. ]
  240. val_indices = [
  241. idx for patient_id in val_patients for idx in patient_to_indices[patient_id]
  242. ]
  243. test_indices = [
  244. idx for patient_id in test_patients for idx in patient_to_indices[patient_id]
  245. ]
  246. train_index_set = set(train_indices)
  247. val_index_set = set(val_indices)
  248. test_index_set = set(test_indices)
  249. if (
  250. train_index_set & val_index_set
  251. or train_index_set & test_index_set
  252. or val_index_set & test_index_set
  253. ):
  254. raise ValueError("Sample index overlap detected across train/val/test splits.")
  255. all_split_indices = train_index_set | val_index_set | test_index_set
  256. expected_indices = set(range(len(dataset)))
  257. if all_split_indices != expected_indices:
  258. raise ValueError(
  259. "Split coverage check failed: not all dataset samples are assigned exactly once."
  260. )
  261. return [
  262. Subset(dataset, train_indices),
  263. Subset(dataset, val_indices),
  264. Subset(dataset, test_indices),
  265. ]