dataset.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. import nibabel as nib
  2. import torch
  3. import torch.utils.data as data
  4. import pathlib as pl
  5. import pandas as pd
  6. from torch.utils.data import Subset, DataLoader
  7. import re
  8. from jaxtyping import Float
  9. from typing import Tuple, Iterator, Callable, List
  class ADNIDataset(data.Dataset):  # type: ignore
      """
      A PyTorch Dataset pairing MRI volumes with tabular (spreadsheet) features,
      class targets and image IDs from the ADNI dataset.

      All three tensors are cast to float32 and moved to ``device`` once, in the
      constructor; indexing afterwards only slices the stored tensors.
      """

      def __init__(
          self,
          mri_data: Float[torch.Tensor, "n_samples channels width height depth"],
          xls_data: Float[torch.Tensor, "n_samples features"],
          expected_classes: Float[torch.Tensor, "n_samples classes"],
          filename_ids: List[int],
          device: str = "cuda",
      ):
          """
          Args:
              mri_data (torch.Tensor): 5D tensor of MRI data with shape
                  (n_samples, channels, width, height, depth).
              xls_data (torch.Tensor): 2D tensor of tabular data with shape
                  (n_samples, features).
              expected_classes (torch.Tensor): Per-sample target tensor; its
                  0th dimension is indexed by sample.
              filename_ids (List[int]): One image ID per sample, in the same
                  order as the tensors.
              device (str): Device the tensors are moved to. Defaults to "cuda".

          Raises:
              ValueError: If the inputs do not all describe the same number of
                  samples.
          """
          # All four inputs are indexed with the same sample index in
          # __getitem__, so a length mismatch would either mis-pair samples or
          # fail much later during iteration. Fail early with a clear message.
          n_samples = mri_data.shape[0]
          other_sizes = {
              "xls_data": xls_data.shape[0],
              "expected_classes": expected_classes.shape[0],
              "filename_ids": len(filename_ids),
          }
          mismatched = {
              name: size for name, size in other_sizes.items() if size != n_samples
          }
          if mismatched:
              raise ValueError(
                  f"All inputs must have the same number of samples as mri_data "
                  f"({n_samples}); mismatched sizes: {mismatched}."
              )

          self.mri_data = mri_data.float().to(device)
          self.xls_data = xls_data.float().to(device)
          self.expected_classes = expected_classes.float().to(device)
          self.filename_ids = filename_ids

      def __len__(self) -> int:
          """
          Returns the number of samples in the dataset.
          """
          return self.mri_data.shape[0]  # 0th dimension is the number of samples

      def __getitem__(self, idx: int) -> Tuple[
          Float[torch.Tensor, "channels width height depth"],
          Float[torch.Tensor, "features"],
          Float[torch.Tensor, "classes"],
          int,
      ]:
          """
          Returns a sample from the dataset at the given index.

          Args:
              idx (int): Index of the sample to retrieve.

          Returns:
              tuple: ``(mri_sample, xls_sample, expected_classes, filename_id)``
              for the sample, i.e. the MRI volume, the tabular features, the
              target tensor and the image ID.
          """
          # Every container is sliced on its 0th dimension, the sample index.
          mri_sample = self.mri_data[idx]
          xls_sample = self.xls_data[idx]
          expected_classes = self.expected_classes[idx]
          filename_id = self.filename_ids[idx]
          return mri_sample, xls_sample, expected_classes, filename_id
  def load_adni_data_from_file(
      mri_files: Iterator[pl.Path],  # Paths to the MRI files loadable by nibabel
      xls_file: pl.Path,  # Path to the tabular data file (read with pd.read_csv)
      device: str = "cuda",
      xls_preprocessor: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x,
  ) -> ADNIDataset:
      """
      Loads MRI volumes and their matching tabular rows into an ADNIDataset.

      For every MRI file the image ID and the class are parsed from the file
      stem, the volume is loaded with nibabel, and the single row of the table
      whose "Image Data ID" equals the image ID is attached to it.

      Args:
          mri_files (Iterator[pl.Path]): Paths to the MRI files. The stem must
              contain "_I<digits>" (the image ID) and either "AD" or "NL".
          xls_file (pl.Path): Path to the tabular file; it is read with
              ``pd.read_csv``.
          device (str): Device passed on to ADNIDataset. Defaults to "cuda".
          xls_preprocessor (Callable): Applied to the DataFrame right after it
              is read. Defaults to the identity function.

      Returns:
          ADNIDataset: The loaded dataset. MRI volumes are stacked and given a
          channel dimension at position 1; targets are one-hot,
          [1, 0] for "AD" and [0, 1] for "NL".

      Raises:
          ValueError: If a file stem has no parsable image ID, names neither
              class, or matches zero or more than one row of the table.
      """
      # Load the tabular data
      xls_values = xls_preprocessor(pd.read_csv(xls_file))  # type: ignore

      # Compiled once instead of being re-evaluated for every file. The trailing
      # `.*` (rather than `.+`) allows the ID to be the last thing in the stem;
      # with `.+` the wildcard had to consume the ID's final digit in that case,
      # silently truncating the ID (e.g. "..._I45108" parsed as 4510).
      image_id_pattern = re.compile(r".+?(?=_I)_I(\d+).*")

      mri_data_unstacked: List[torch.Tensor] = []
      expected_classes_unstacked: List[torch.Tensor] = []
      xls_data_unstacked: List[torch.Tensor] = []
      img_ids: List[int] = []

      for file in mri_files:
          filename = file.stem

          id_match = image_id_pattern.search(filename)
          if id_match is None:
              raise ValueError(f"Filename {filename} does not match expected pattern.")
          img_id = int(id_match.group(1))

          # Cast to float32 per file: ADNIDataset casts to float32 anyway, so
          # the values are unchanged, but the stacked tensor no longer has to
          # be held in float64 first.
          file_mri_data = torch.from_numpy(nib.load(file).get_fdata()).float()  # type: ignore # type checking does not work well with nibabel

          # Read the filename to determine the expected class.
          # NOTE(review): this is a substring test, and "AD" is checked first -
          # a stem that contains "AD" anywhere (for instance an "ADNI_" prefix)
          # is labelled AD before "NL" is considered. Confirm the file naming
          # scheme guarantees this cannot mislabel NL scans.
          if "AD" in filename:
              file_expected_class = torch.tensor([1.0, 0.0])
          elif "NL" in filename:
              file_expected_class = torch.tensor([0.0, 1.0])
          else:
              raise ValueError(
                  f"Filename {filename} does not contain a valid class identifier (AD or NL)."
              )

          mri_data_unstacked.append(file_mri_data)
          expected_classes_unstacked.append(file_expected_class)

          # Extract the corresponding row from the table using the img_id.
          # NOTE(review): img_id is an int, so the "Image Data ID" column must
          # be numeric by the time it gets here - presumably xls_preprocessor
          # takes care of that; verify against the caller.
          xls_row = xls_values.loc[xls_values["Image Data ID"] == img_id]
          if xls_row.empty:
              raise ValueError(
                  f"No matching row found in Excel data for Image Data ID {img_id}."
              )
          elif len(xls_row) > 1:
              raise ValueError(
                  f"Multiple rows found in Excel data for Image Data ID {img_id}."
              )

          # Everything except the ID column becomes the feature vector.
          file_xls_data = torch.tensor(
              xls_row.drop(columns=["Image Data ID"]).values.flatten()  # type: ignore
          )
          xls_data_unstacked.append(file_xls_data)
          img_ids.append(img_id)

      # Stack the volumes into a single tensor and add the channel dimension
      mri_data = torch.stack(mri_data_unstacked).unsqueeze(1)
      # Stack the per-file feature vectors into a single tensor
      xls_data = torch.stack(xls_data_unstacked)
      # Stack the per-file one-hot targets into a single tensor
      expected_classes = torch.stack(expected_classes_unstacked)

      return ADNIDataset(mri_data, xls_data, expected_classes, img_ids, device=device)
  def divide_dataset(
      dataset: ADNIDataset,
      ratios: Tuple[float, float, float],
      seed: int,
  ) -> List[data.Subset[ADNIDataset]]:
      """
      Randomly divides the dataset into three subsets (conventionally the
      training, validation, and test sets).

      Args:
          dataset (ADNIDataset): The dataset to divide.
          ratios (Tuple[float, float, float]): Fraction of the dataset assigned
              to each subset; the fractions must sum to 1.
          seed (int): Seed for the random generator that drives the split, so
              the same seed reproduces the same split.

      Returns:
          List[data.Subset[ADNIDataset]]: The subsets, in the same order as
          ``ratios``.

      Raises:
          ValueError: If the ratios do not sum to 1.
      """
      import math  # local import so this function does not rely on a file-level one

      # Compare with a tolerance: an exact `!= 1.0` test rejects perfectly
      # reasonable inputs such as (0.7, 0.2, 0.1), whose floating-point sum is
      # 0.9999999999999999.
      if not math.isclose(sum(ratios), 1.0):
          raise ValueError(f"Ratios must sum to 1.0, got {ratios}.")

      # Set the random seed for reproducibility
      gen = torch.Generator().manual_seed(seed)
      return data.random_split(dataset, ratios, generator=gen)
  def initalize_dataloaders(
      datasets: List[Subset[ADNIDataset]],
      batch_size: int = 64,
      *,
      shuffle: bool = True,
  ) -> List[DataLoader[ADNIDataset]]:
      """
      Initializes one DataLoader per given dataset.

      Args:
          datasets (List[Subset[ADNIDataset]]): List of datasets to create DataLoaders for.
          batch_size (int): The batch size used by every DataLoader. Defaults to 64.
          shuffle (bool): Whether every DataLoader reshuffles its data. Defaults
              to True, the previous hard-coded behaviour; pass False for
              loaders whose iteration order should be deterministic (for
              example evaluation loaders).

      Returns:
          List[DataLoader[ADNIDataset]]: The DataLoaders, in the same order as
          ``datasets``.
      """
      return [
          DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
          for dataset in datasets
      ]
  158. ]