dataset.py

import nibabel as nib
import torch
import torch.utils.data as data
import pathlib as pl
import pandas as pd
from torch.utils.data import Subset, DataLoader
from jaxtyping import Float
from typing import Tuple, Iterator, Callable, List
from result import Ok, Err, Result


class ADNIDataset(data.Dataset):  # type: ignore
    """
    A PyTorch Dataset class for loading and processing MRI and Excel data
    from the ADNI dataset.
    """

    def __init__(
        self,
        mri_data: Float[torch.Tensor, "n_samples width height depth"],
        xls_data: Float[torch.Tensor, "n_samples features"],
        device: str = "cuda",
    ):
        """
        Args:
            mri_data (torch.Tensor): 4D tensor of MRI data with shape (n_samples, width, height, depth).
            xls_data (torch.Tensor): 2D tensor of Excel data with shape (n_samples, features).
            device (str): Device the tensors are moved to (default: "cuda").
        """
        self.mri_data = mri_data.float().to(device)
        self.xls_data = xls_data.float().to(device)

    def __len__(self) -> int:
        """
        Returns the number of samples in the dataset.
        """
        return self.mri_data.shape[0]  # 0th dimension is the number of samples

    def __getitem__(self, idx: int) -> Tuple[
        Float[torch.Tensor, "width height depth"],
        Float[torch.Tensor, "features"],
    ]:
        """
        Returns a sample from the dataset at the given index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: A tuple containing the MRI data and Excel data for the sample.
        """
        # Slice the data on the 0th dimension, corresponding to the sample index
        mri_sample = self.mri_data[idx]
        xls_sample = self.xls_data[idx]
        return mri_sample, xls_sample
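
# Illustrative sketch (not part of the original module): constructing the
# dataset directly from in-memory tensors. The shapes and the "cpu" device
# below are hypothetical stand-ins.
#
#   mri = torch.rand(4, 91, 109, 91)   # (n_samples, width, height, depth)
#   xls = torch.rand(4, 5)             # (n_samples, features)
#   ds = ADNIDataset(mri, xls, device="cpu")
#   volume, features = ds[0]           # one MRI volume and its feature row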


def load_adni_data_from_file(
    mri_files: Iterator[pl.Path],  # Paths to NIfTI files readable by nibabel
    xls_file: pl.Path,  # Path to the Excel file
    device: str = "cuda",
    xls_preprocessor: Callable[[pd.DataFrame], pd.DataFrame] = lambda x: x,
) -> Result[ADNIDataset, str]:
    """
    Loads MRI and Excel data from the ADNI dataset.

    Args:
        mri_files (Iterator[pl.Path]): Paths to the MRI (NIfTI) files.
        xls_file (pl.Path): Path to the Excel file.
        device (str): Device the tensors are moved to (default: "cuda").
        xls_preprocessor (Callable): Preprocessing applied to the Excel DataFrame before it is converted to a tensor.

    Returns:
        Result[ADNIDataset, str]: A Result object containing the ADNIDataset or an error message.
    """
    # Load the MRI data and stack the per-file volumes into a single 4D tensor
    mri_data_unstacked = [
        torch.from_numpy(nib.load(file).get_fdata()) for file in mri_files  # type: ignore # type checking does not work well with nibabel
    ]
    mri_data = torch.stack(mri_data_unstacked)

    # Load the Excel data
    xls_data = torch.from_numpy(  # type: ignore
        xls_preprocessor(pd.read_excel(xls_file)).to_numpy()  # type: ignore
    ).float()

    # Check that the number of samples in the MRI and Excel data match
    if mri_data.shape[0] == xls_data.shape[0]:
        return Ok(ADNIDataset(mri_data, xls_data, device=device))
    else:
        return Err(
            f"Sample count mismatch: {mri_data.shape[0]} MRI volumes vs "
            f"{xls_data.shape[0]} Excel rows"
        )
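
# Illustrative sketch (hypothetical paths, shown commented so the module stays
# import-safe): loading from disk, with an optional preprocessor that keeps
# only numeric columns before the DataFrame is converted to a tensor.
#
#   result = load_adni_data_from_file(
#       mri_files=pl.Path("data/mri").glob("*.nii.gz"),
#       xls_file=pl.Path("data/adni_clinical.xlsx"),
#       device="cpu",
#       xls_preprocessor=lambda df: df.select_dtypes(include="number"),
#   )
#   dataset = result.unwrap()  # raises if the sample counts did not match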


def divide_dataset(
    dataset: ADNIDataset,
    ratios: Tuple[float, float, float],
    seed: int,
) -> Result[List[data.Subset[ADNIDataset]], str]:
    """
    Divides the dataset into training, validation, and test sets.

    Args:
        dataset (ADNIDataset): The dataset to divide.
        ratios (Tuple[float, float, float]): The (train, validation, test) split ratios; must sum to 1.0.
        seed (int): Random seed used for the split, for reproducibility.

    Returns:
        Result[List[data.Subset[ADNIDataset]], str]: A Result object containing the subsets or an error message.
    """
    # Compare with a small tolerance: exact float equality would reject valid
    # splits such as (0.7, 0.2, 0.1), whose sum is not exactly 1.0 in floating point
    if abs(sum(ratios) - 1.0) > 1e-6:
        return Err("Ratios must sum to 1.0")

    # Set the random seed for reproducibility
    gen = torch.Generator().manual_seed(seed)
    return Ok(data.random_split(dataset, ratios, generator=gen))


def initalize_dataloaders(
    datasets: List[Subset[ADNIDataset]],
    batch_size: int = 64,
) -> List[DataLoader[ADNIDataset]]:
    """
    Initializes a DataLoader for each of the given datasets.

    Args:
        datasets (List[Subset[ADNIDataset]]): List of datasets to create DataLoaders for.
        batch_size (int): The batch size for each DataLoader.

    Returns:
        List[DataLoader[ADNIDataset]]: A list of DataLoaders for the datasets.
    """
    return [
        DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in datasets
    ]
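

# Usage sketch (illustrative, not part of the original module): building a
# dataset from synthetic tensors, splitting it 70/15/15 with a fixed seed, and
# wrapping each subset in a DataLoader. All shapes and values are hypothetical.
if __name__ == "__main__":
    mri = torch.rand(10, 91, 109, 91)  # synthetic stand-in for real MRI volumes
    xls = torch.rand(10, 5)            # synthetic stand-in for the Excel features
    dataset = ADNIDataset(mri, xls, device="cpu")

    split_result = divide_dataset(dataset, (0.7, 0.15, 0.15), seed=42)
    if isinstance(split_result, Ok):
        train_dl, val_dl, test_dl = initalize_dataloaders(
            split_result.unwrap(), batch_size=4
        )
        print(len(train_dl.dataset), len(val_dl.dataset), len(test_dl.dataset))
    else:
        print(split_result.unwrap_err())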