preprocess.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # NEEDS TO BE FINISHED
  2. # TODO CHECK ABOUT IMAGE DIMENSIONS
  3. # TODO ENSURE ITERATION WORKS
  4. import glob
  5. import nibabel as nib
  6. import numpy as np
  7. import random
  8. import torch
  9. from torch.utils.data import Dataset
  10. '''
  11. Prepares CustomDatasets for training, validating, and testing CNN
  12. '''
  13. def prepare_datasets(mri_dir, val_split=0.2, seed=50):
  14. rndm = random.Random(seed)
  15. raw_data = glob.glob(mri_dir + "*")
  16. AD_list = []
  17. NL_list = []
  18. print("--- DATA INFO ---")
  19. print("Amount of images: " + str(len(raw_data)))
  20. # TODO Check that image is in CSV?
  21. for image in raw_data:
  22. if "NL" in image:
  23. NL_list.append(image)
  24. elif "AD" in image:
  25. AD_list.append(image)
  26. print("Total AD: " + str(len(AD_list)))
  27. print("Total NL: " + str(len(NL_list)))
  28. rndm.shuffle(AD_list)
  29. rndm.shuffle(NL_list)
  30. train_list, val_list, test_list = get_train_val_test(AD_list, NL_list, val_split)
  31. rndm.shuffle(train_list)
  32. rndm.shuffle(val_list)
  33. rndm.shuffle(test_list)
  34. train_dataset = CustomDataset(train_list)
  35. val_dataset = CustomDataset(val_list)
  36. test_dataset = CustomDataset(test_list)
  37. return train_dataset, val_dataset, test_dataset
  38. # TODO Normalize data? Later add / extract clinical data? Which data?
  39. '''
  40. Returns train_list, val_list and test_list in format [(image, id), ...] each
  41. '''
  42. def get_train_val_test(AD_list, NL_list, val_split):
  43. train_list, val_list, test_list = [], [], []
  44. num_test_ad = int(len(AD_list) * val_split)
  45. num_test_nl = int(len(NL_list) * val_split)
  46. num_val_ad = int((len(AD_list) - num_test_ad) * val_split)
  47. num_val_nl = int((len(NL_list) - num_test_nl) * val_split)
  48. # Sets up ADs
  49. for image in AD_list[0:num_val_ad]:
  50. val_list.append((image, 1))
  51. for image in AD_list[num_val_ad:num_test_ad]:
  52. test_list.append((image, 1))
  53. for image in AD_list[num_test_ad:]:
  54. train_list.append((image, 1))
  55. # Sets up NLs
  56. for image in NL_list[0:num_val_nl]:
  57. val_list.append((image, 0))
  58. for image in NL_list[num_val_nl:num_test_nl]:
  59. test_list.append((image, 0))
  60. for image in NL_list[num_test_nl:]:
  61. train_list.append((image, 0))
  62. return train_list, val_list, test_list
  63. class CustomDataset(Dataset):
  64. def __init__(self, list):
  65. self.data = list # DATA IS A LIST WITH TUPLES (image_dir, class_id)
  66. def __len__(self):
  67. return len(self.data)
  68. def __getitem__(self, idx): # RETURNS TUPLE WITH IMAGE AND CLASS_ID, BASED ON INDEX IDX
  69. mri_path, class_id = self.data[idx]
  70. mri = nib.load(mri_path)
  71. mri_data = mri.get_fdata()
  72. # mri_array = np.array(mri)
  73. # mri_tensor = torch.from_numpy(mri_array)
  74. # class_id = torch.tensor([class_id]) TODO return tensor or just id (0, 1)??
  75. return mri_data, class_id