preprocess.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import glob
  2. import nibabel as nib
  3. import numpy as np
  4. import pandas as pd
  5. import random
  6. # import torch
  7. from torch.utils.data import Dataset
  8. from torchvision.transforms import v2
  9. import re
  10. '''
  11. Prepares CustomDatasets for training, validating, and testing CNN
  12. '''
  13. def prepare_datasets(mri_dir, val_split=0.2, seed=50):
  14. rndm = random.Random(seed)
  15. csv = pd.read_csv("LP_ADNIMERGE.csv")
  16. raw_data = glob.glob(mri_dir + "*")
  17. AD_list = []
  18. NL_list = []
  19. print("--- DATA INFO ---")
  20. print("Amount of images: " + str(len(raw_data)))
  21. for image in raw_data:
  22. # FIND IMAGE IN CSV, GET IT'S CLINICAL DATA
  23. image_ID = re.search(r"_I(\d+)_", image).group(1)
  24. if(image_ID==None): raise RuntimeError("Image not found in CSV!")
  25. image_data = csv[csv["Image Data ID"] == int(image_ID)].loc[:,['Sex', 'Age (current)']] # M: 0, F: 1
  26. image_data = image_data.iloc[0].tolist()
  27. if (image_data[0] == 'M'):
  28. image_data[0] = 0
  29. elif(image_data[0] == 'F'):
  30. image_data[0] = 1
  31. else:
  32. raise RuntimeError("Incorrect sex in an image")
  33. image_data = tuple(image_data)
  34. if "NL" in image:
  35. NL_list.append((image, 0, image_data))
  36. elif "AD" in image:
  37. AD_list.append((image, 1, image_data))
  38. print("Total AD: " + str(len(AD_list)))
  39. print("Total NL: " + str(len(NL_list)))
  40. rndm.shuffle(AD_list)
  41. rndm.shuffle(NL_list)
  42. train_list, val_list, test_list = get_train_val_test(AD_list, NL_list, val_split)
  43. rndm.shuffle(train_list)
  44. rndm.shuffle(val_list)
  45. rndm.shuffle(test_list)
  46. print(f"DATA INITIALIZATION")
  47. print(f"Training size: {len(train_list)}")
  48. print(f"Validation size: {len(val_list)}")
  49. print(f"Test size: {len(test_list)}")
  50. transformation = v2.Compose([
  51. v2.Normalize([0.5],[0.5]), # TODO Get Vals from dataset
  52. # TODO CHOOSE WHAT TRANSFORMATIONS TO DO
  53. ])
  54. train_dataset = CustomDataset(train_list, transformation)
  55. val_dataset = CustomDataset(val_list, transformation)
  56. test_dataset = CustomDataset(test_list, transformation)
  57. return train_dataset, val_dataset, test_dataset
  58. def prepare_predict(mri_dir, IDs):
  59. raw_data = glob.glob(mri_dir + "*")
  60. image_list = []
  61. # Gets all images and prepares them for Dataset
  62. for ID in IDs:
  63. pattern = re.compile(ID)
  64. matches = [item for item in raw_data if pattern.search(item)]
  65. if (len(matches) != 1): print("No image found, or more than one")
  66. for match in matches:
  67. if "NL" in match: image_list.append((match, 0))
  68. if "AD" in match: image_list.append((match, 1))
  69. return CustomDataset(image_list)
  70. '''
  71. Returns train_list, val_list and test_list in format [(image, id), ...] each
  72. '''
  73. def get_train_val_test(AD_list, NL_list, val_split):
  74. train_list, val_list, test_list = [], [], []
  75. num_val_ad = int(len(AD_list) * val_split/2)
  76. num_val_nl = int(len(NL_list) * val_split/2)
  77. num_test_ad = int((len(AD_list) - num_val_ad) * val_split)
  78. num_test_nl = int((len(NL_list) - num_val_nl) * val_split)
  79. # Sets up ADs
  80. for image in AD_list[0:num_test_ad]:
  81. test_list.append(image)
  82. for image in AD_list[num_test_ad:num_test_ad+num_val_ad]:
  83. val_list.append(image)
  84. for image in AD_list[num_test_ad+num_val_ad:]:
  85. train_list.append(image)
  86. # Sets up NLs
  87. for image in NL_list[0:num_test_nl]:
  88. test_list.append(image)
  89. for image in NL_list[num_test_nl:num_test_nl+num_val_nl]:
  90. val_list.append(image)
  91. for image in NL_list[num_test_nl+num_val_nl:]:
  92. train_list.append(image)
  93. return train_list, val_list, test_list
  94. class CustomDataset(Dataset):
  95. def __init__(self, list, transform):
  96. self.data = list # INPUT DATA: (image_dir, class_id, (clinical_data))
  97. self.transform = transform
  98. def __len__(self):
  99. return len(self.data)
  100. def __getitem__(self, idx): # RETURNS TUPLE: ((mri_data, [clinical_data]), class_id)
  101. mri_path, class_id, clinical_data = self.data[idx]
  102. mri = nib.load(mri_path)
  103. image = np.asarray(mri.dataobj)
  104. mri_data = np.asarray(np.expand_dims(image, axis=0))
  105. mri_data = self.transform(mri_data)
  106. # mri_data = mri.get_fdata()
  107. # mri_array = np.array(mri)
  108. # mri_tensor = torch.from_numpy(mri_array)
  109. # class_id = torch.tensor([class_id]) TODO return tensor or just id (0, 1)??
  110. return (mri_data, clinical_data), class_id