# datasets.py
  1. # NEEDS TO BE FINISHED
  2. # TODO CHECK ABOUT IMAGE DIMENSIONS
  3. # TODO ENSURE ITERATION WORKS
  4. import glob
  5. import nibabel as nib
  6. import random
  7. import torch
  8. from torch.utils.data import Dataset
  9. import pandas as pd
  10. from torch.utils.data import DataLoader
  11. import math
  12. from typing import Tuple
  13. import pathlib as pl
  14. """
  15. Prepares CustomDatasets for training, validating, and testing CNN
  16. """
  17. def prepare_datasets(mri_dir, xls_file, val_split=0.2, seed=50, device=None):
  18. if device is None:
  19. device = torch.device("cpu")
  20. rndm = random.Random(seed)
  21. xls_data = pd.read_csv(xls_file)
  22. # Strip all trailing whitespace from dataframe
  23. xls_data = xls_data.replace(r"^ +| +$", r"", regex=True)
  24. # Strip all trailing whitespace from column names
  25. xls_data.columns = xls_data.columns.str.strip()
  26. xls_data = xls_data.set_index("Image Data ID")
  27. raw_data = glob.glob(mri_dir + "*")
  28. print(f"Found {len(raw_data)} images in {mri_dir}")
  29. AD_list = []
  30. NL_list = []
  31. # TODO Check that image is in CSV?
  32. for image in raw_data:
  33. if "NL" in image:
  34. NL_list.append(image)
  35. elif "AD" in image:
  36. AD_list.append(image)
  37. rndm.shuffle(AD_list)
  38. rndm.shuffle(NL_list)
  39. train_list, val_list, test_list = get_train_val_test(AD_list, NL_list, val_split)
  40. train_dataset = ADNIDataset(train_list, xls_data, device=device)
  41. val_dataset = ADNIDataset(val_list, xls_data, device=device)
  42. test_dataset = ADNIDataset(test_list, xls_data, device=device)
  43. return train_dataset, val_dataset, test_dataset
  44. # TODO Normalize data? Later add / Exctract clinical data? Which data?
  45. """
  46. Returns train_list, val_list and test_list in format [(image, id), ...] each
  47. """
  48. def get_train_val_test(AD_list, NL_list, val_split):
  49. train_list, val_list, test_list = [], [], []
  50. # For the purposes of this split, the val_split constitutes the validation and testing split, as they are divided evenly
  51. # get the overall length of the data
  52. AD_len = len(AD_list)
  53. NL_len = len(NL_list)
  54. # First, determine the length of each of the sets
  55. AD_val_len = int(math.ceil(AD_len * val_split * 0.5))
  56. NL_val_len = int(math.ceil(NL_len * val_split * 0.5))
  57. AD_test_len = int(math.floor(AD_len * val_split * 0.5))
  58. NL_test_len = int(math.floor(NL_len * val_split * 0.5))
  59. AD_train_len = AD_len - AD_val_len - AD_test_len
  60. NL_train_len = NL_len - NL_val_len - NL_test_len
  61. # Add the data to the sets
  62. for i in range(AD_train_len):
  63. train_list.append((AD_list[i], 1))
  64. for i in range(NL_train_len):
  65. train_list.append((NL_list[i], 0))
  66. for i in range(AD_train_len, AD_train_len + AD_val_len):
  67. val_list.append((AD_list[i], 1))
  68. for i in range(NL_train_len, NL_train_len + NL_val_len):
  69. val_list.append((NL_list[i], 0))
  70. for i in range(AD_train_len + AD_val_len, AD_len):
  71. test_list.append((AD_list[i], 1))
  72. for i in range(NL_train_len + NL_val_len, NL_len):
  73. test_list.append((NL_list[i], 0))
  74. return train_list, val_list, test_list
  75. class ADNIDataset(Dataset):
  76. def __init__(self, mri, xls: pd.DataFrame, data_dir: pl.Path, device:torch.device =torch.device("cpu"), ):
  77. self.mri_data = mri # DATA IS A LIST WITH TUPLES (image_dir, class_id)
  78. self.xls_data = xls
  79. self.device = device
  80. def __len__(self):
  81. return len(self.mri_data)
  82. def _xls_to_tensor(self, xls_data: pd.Series):
  83. # Get used data
  84. # data = xls_data.loc[['Sex', 'Age (current)', 'PTID', 'DXCONFID (1=uncertain, 2= mild, 3= moderate, 4=high confidence)', 'Alz_csf']]
  85. data = xls_data.loc[["Sex", "Age (current)"]]
  86. data.replace({"M": 0, "F": 1}, inplace=True)
  87. # Convert to tensor
  88. xls_tensor = torch.tensor(data.values.astype(float))
  89. return xls_tensor
  90. def __getitem__(
  91. self, idx: int
  92. ) -> Tuple[
  93. Tuple[torch.Tensor, torch.Tensor], torch.Tensor
  94. ]: # RETURNS TUPLE WITH IMAGE AND CLASS_ID, BASED ON INDEX IDX
  95. mri_path, class_id = self.mri_data[idx]
  96. mri_path = pl.Path(mri_path).name
  97. adj_path = self.
  98. mri = nib.load(mri_path)
  99. mri_data = mri.get_fdata()
  100. xls = self.xls_data.iloc[idx]
  101. # Convert xls data to tensor
  102. xls_tensor = self._xls_to_tensor(xls)
  103. mri_tensor = torch.from_numpy(mri_data).unsqueeze(0)
  104. class_id = torch.tensor([class_id])
  105. # Convert to one-hot and squeeze
  106. class_id = torch.nn.functional.one_hot(class_id, num_classes=2).squeeze(0)
  107. # Convert to float
  108. mri_tensor = mri_tensor.float().to(self.device)
  109. xls_tensor = xls_tensor.float().to(self.device)
  110. class_id = class_id.float().to(self.device)
  111. return (mri_tensor, xls_tensor), class_id
  112. def __iter__(self):
  113. for i in range(len(self)):
  114. yield self.__getitem__(i)
  115. def initalize_dataloaders(
  116. training_data,
  117. val_data,
  118. test_data,
  119. batch_size=64,
  120. ):
  121. train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
  122. test_dataloader = DataLoader(test_data, batch_size=(batch_size // 4), shuffle=True)
  123. val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
  124. return train_dataloader, val_dataloader, test_dataloader