dataset.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. import nibabel as nib
  2. import torch
  3. import torch.utils.data as data
  4. import pathlib as pl
  5. import pandas as pd
  6. from torch.utils.data import Subset, DataLoader
  7. import re
  8. from jaxtyping import Float
  9. from typing import Tuple, Iterator, Callable, List
  10. class ADNIDataset(data.Dataset): # type: ignore
  11. """
  12. A PyTorch Dataset class for loading
  13. and processing MRI and Excel data from the ADNI dataset.
  14. """
  15. def __init__(
  16. self,
  17. mri_data: Float[torch.Tensor, "n_samples channels width height depth"],
  18. xls_data: Float[torch.Tensor, "n_samples features"],
  19. expected_classes: Float[torch.Tensor, "classes"],
  20. device: str = "cuda",
  21. ):
  22. """
  23. Args:
  24. mri_data (torch.Tensor): 5D tensor of MRI data with shape (n_samples, channels, width, height, depth).
  25. xls_data (torch.Tensor): 2D tensor of Excel data with shape (n_samples, features).
  26. """
  27. self.mri_data = mri_data.float().to(device)
  28. self.xls_data = xls_data.float().to(device)
  29. self.expected_classes = expected_classes.float().to(device)
  30. def __len__(self) -> int:
  31. """
  32. Returns the number of samples in the dataset.
  33. """
  34. return self.mri_data.shape[0] # 0th dimension is the number of samples
  35. def __getitem__(self, idx: int) -> Tuple[
  36. Float[torch.Tensor, "channels width height depth"],
  37. Float[torch.Tensor, "features"],
  38. Float[torch.Tensor, "classes"],
  39. ]:
  40. """
  41. Returns a sample from the dataset at the given index.
  42. Args:
  43. idx (int): Index of the sample to retrieve.
  44. Returns:
  45. tuple: A tuple containing the MRI data and Excel data for the sample.
  46. """
  47. # Slices the data on the 0th dimension, corresponding to the sample index
  48. mri_sample = self.mri_data[idx]
  49. xls_sample = self.xls_data[idx]
  50. # Assuming expected_classes is a tensor of classes, we return it as well
  51. expected_classes = self.expected_classes[idx]
  52. return mri_sample, xls_sample, expected_classes
  53. def load_adni_data_from_file(
  54. mri_files: Iterator[pl.Path], # List of nibablel files
  55. xls_file: pl.Path, # Path to the Excel file
  56. device: str = "cuda",
  57. xls_preprocessor: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x,
  58. ) -> ADNIDataset:
  59. """
  60. Loads MRI and Excel data from the ADNI dataset.
  61. Args:
  62. mri_files (List[pl.Path]): List of paths to the MRI files.
  63. xls_file (pl.Path): Path to the Excel file.
  64. Returns:
  65. Result[ADNIDataset, str]: A Result object containing the ADNIDataset or an error message.
  66. """
  67. # Load the Excel data
  68. xls_values = xls_preprocessor(pd.read_csv(xls_file)) # type: ignore
  69. # Load the MRI data
  70. mri_data_unstacked: List[torch.Tensor] = []
  71. expected_classes_unstacked: List[torch.Tensor] = []
  72. xls_data_unstacked: List[torch.Tensor] = []
  73. img_ids: List[int] = []
  74. for file in mri_files:
  75. filename = file.stem
  76. match re.search(r".+?(?=_I)_I(\d+).+", filename):
  77. case None:
  78. raise ValueError(
  79. f"Filename {filename} does not match expected pattern."
  80. )
  81. case m:
  82. img_id = int(m.group(1))
  83. file_mri_data = torch.from_numpy(nib.load(file).get_fdata()) # type: ignore # type checking does not work well with nibabel
  84. # Read the filename to determine the expected class
  85. file_expected_class = torch.tensor([0.0, 0.0]) # Default to a tensor of zeros
  86. if "AD" in filename:
  87. file_expected_class = torch.tensor([1.0, 0.0])
  88. elif "CN" in filename:
  89. file_expected_class = torch.tensor([0.0, 1.0])
  90. mri_data_unstacked.append(file_mri_data)
  91. expected_classes_unstacked.append(file_expected_class)
  92. # Extract the corresponding row from the Excel data using the img_id
  93. xls_row = xls_values.loc[xls_values["Image Data ID"] == img_id]
  94. if xls_row.empty:
  95. raise ValueError(
  96. f"No matching row found in Excel data for Image Data ID {img_id}."
  97. )
  98. elif len(xls_row) > 1:
  99. raise ValueError(
  100. f"Multiple rows found in Excel data for Image Data ID {img_id}."
  101. )
  102. file_xls_data = torch.tensor(
  103. xls_row.drop(columns=["Image Data ID"]).values.flatten() # type: ignore
  104. )
  105. xls_data_unstacked.append(file_xls_data)
  106. img_ids.append(img_id)
  107. mri_data = torch.stack(mri_data_unstacked).unsqueeze(1)
  108. # Stack the list of tensors into a single tensor and unsqueeze along the channel dimension
  109. xls_data = torch.stack(
  110. xls_data_unstacked
  111. ) # Stack the list of tensors into a single tensor
  112. expected_classes = torch.stack(
  113. expected_classes_unstacked
  114. ) # Stack the list of expected classes into a single tensor
  115. return ADNIDataset(mri_data, xls_data, expected_classes, device=device)
  116. def divide_dataset(
  117. dataset: ADNIDataset,
  118. ratios: Tuple[float, float, float],
  119. seed: int,
  120. ) -> List[data.Subset[ADNIDataset]]:
  121. """
  122. Divides the dataset into training, validation, and test sets.
  123. Args:
  124. dataset (ADNIDataset): The dataset to divide.
  125. train_ratio (float): The ratio of the training set.
  126. val_ratio (float): The ratio of the validation set.
  127. test_ratio (float): The ratio of the test set.
  128. Returns:
  129. Result[List[data.Subset[ADNIDataset]], str]: A Result object containing the subsets or an error message.
  130. """
  131. if sum(ratios) != 1.0:
  132. raise ValueError(f"Ratios must sum to 1.0, got {ratios}.")
  133. # Set the random seed for reproducibility
  134. gen = torch.Generator().manual_seed(seed)
  135. return data.random_split(dataset, ratios, generator=gen)
  136. def initalize_dataloaders(
  137. datasets: List[Subset[ADNIDataset]],
  138. batch_size: int = 64,
  139. ) -> List[DataLoader[ADNIDataset]]:
  140. """
  141. Initializes the DataLoader for the given datasets.
  142. Args:
  143. datasets (List[Subset[ADNIDataset]]): List of datasets to create DataLoaders for.
  144. batch_size (int): The batch size for the DataLoader.
  145. Returns:
  146. List[DataLoader[ADNIDataset]]: A list of DataLoaders for the datasets.
  147. """
  148. return [
  149. DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in datasets
  150. ]