datasets.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # NEEDS TO BE FINISHED
  2. # TODO CHECK ABOUT IMAGE DIMENSIONS
  3. # TODO ENSURE ITERATION WORKS
import glob
import os
import random

import nibabel as nib
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
  11. """
  12. Prepares CustomDatasets for training, validating, and testing CNN
  13. """
  14. def prepare_datasets(
  15. mri_dir, xls_file, val_split=0.2, seed=50, device=torch.device("cpu")
  16. ):
  17. rndm = random.Random(seed)
  18. xls_data = pd.read_csv(xls_file).set_index("Image Data ID")
  19. raw_data = glob.glob(mri_dir + "*")
  20. AD_list = []
  21. NL_list = []
  22. # TODO Check that image is in CSV?
  23. for image in raw_data:
  24. if "NL" in image:
  25. NL_list.append(image)
  26. elif "AD" in image:
  27. AD_list.append(image)
  28. rndm.shuffle(AD_list)
  29. rndm.shuffle(NL_list)
  30. train_list, val_list, test_list = get_train_val_test(AD_list, NL_list, val_split)
  31. rndm.shuffle(train_list)
  32. rndm.shuffle(val_list)
  33. rndm.shuffle(test_list)
  34. train_dataset = ADNIDataset(train_list, xls_data, device=device)
  35. val_dataset = ADNIDataset(val_list, xls_data, device=device)
  36. test_dataset = ADNIDataset(test_list, xls_data, device=device)
  37. return train_dataset, val_dataset, test_dataset
# TODO: Normalize the data? Later, add/extract clinical data (which fields?).
  39. """
  40. Returns train_list, val_list and test_list in format [(image, id), ...] each
  41. """
  42. def get_train_val_test(AD_list, NL_list, val_split):
  43. train_list, val_list, test_list = [], [], []
  44. num_test_ad = int(len(AD_list) * val_split)
  45. num_test_nl = int(len(NL_list) * val_split)
  46. num_val_ad = int((len(AD_list) - num_test_ad) * val_split)
  47. num_val_nl = int((len(NL_list) - num_test_nl) * val_split)
  48. # Sets up ADs
  49. for image in AD_list[0:num_val_ad]:
  50. val_list.append((image, 1))
  51. for image in AD_list[num_val_ad:num_test_ad]:
  52. test_list.append((image, 1))
  53. for image in AD_list[num_test_ad:]:
  54. train_list.append((image, 1))
  55. # Sets up NLs
  56. for image in NL_list[0:num_val_nl]:
  57. val_list.append((image, 0))
  58. for image in NL_list[num_val_nl:num_test_nl]:
  59. test_list.append((image, 0))
  60. for image in NL_list[num_test_nl:]:
  61. train_list.append((image, 0))
  62. return train_list, val_list, test_list
  63. class ADNIDataset(Dataset):
  64. def __init__(self, mri, xls: pd.DataFrame, device=torch.device("cpu")):
  65. self.mri_data = mri # DATA IS A LIST WITH TUPLES (image_dir, class_id)
  66. self.xls_data = xls
  67. self.device = device
  68. def __len__(self):
  69. return len(self.mri_data)
  70. def _xls_to_tensor(self, xls_data: pd.Series):
  71. # Get used data
  72. # data = xls_data.loc[['Sex', 'Age (current)', 'PTID', 'DXCONFID (1=uncertain, 2= mild, 3= moderate, 4=high confidence)', 'Alz_csf']]
  73. data = xls_data.loc[["Sex", "Age (current)"]]
  74. data.replace({"M": 0, "F": 1}, inplace=True)
  75. # Convert to tensor
  76. xls_tensor = torch.tensor(data.values.astype(float))
  77. return xls_tensor
  78. def __getitem__(
  79. self, idx
  80. ): # RETURNS TUPLE WITH IMAGE AND CLASS_ID, BASED ON INDEX IDX
  81. mri_path, class_id = self.mri_data[idx]
  82. mri = nib.load(mri_path)
  83. mri_data = mri.get_fdata()
  84. xls = self.xls_data.iloc[idx]
  85. # Convert xls data to tensor
  86. xls_tensor = self._xls_to_tensor(xls)
  87. mri_tensor = torch.from_numpy(mri_data).unsqueeze(0)
  88. class_id = torch.tensor([class_id])
  89. # Convert to one-hot and squeeze
  90. class_id = torch.nn.functional.one_hot(class_id, num_classes=2).squeeze(0)
  91. # Convert to float
  92. mri_tensor = mri_tensor.float().to(self.device)
  93. xls_tensor = xls_tensor.float().to(self.device)
  94. class_id = class_id.float().to(self.device)
  95. return (mri_tensor, xls_tensor), class_id
  96. def initalize_dataloaders(
  97. training_data,
  98. val_data,
  99. test_data,
  100. batch_size=64,
  101. ):
  102. train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
  103. test_dataloader = DataLoader(test_data, batch_size=(batch_size // 4), shuffle=True)
  104. val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
  105. return train_dataloader, val_dataloader, test_dataloader