preprocess.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import glob
  2. import nibabel as nib
  3. import numpy as np
  4. import pandas as pd
  5. import random
  6. import torch
  7. from torch.utils.data import Dataset
  8. import torchvision.transforms as transforms
  9. import re
  10. '''
  11. Prepares CustomDatasets for training, validating, and testing CNN
  12. '''
  13. def prepare_datasets(mri_dir, val_split=0.2, seed=50):
  14. rndm = random.Random(seed)
  15. csv = pd.read_csv("LP_ADNIMERGE.csv")
  16. raw_data = glob.glob(mri_dir + "*")
  17. AD_list = []
  18. NL_list = []
  19. print("--- DATA INFO ---")
  20. print("Amount of images: " + str(len(raw_data)))
  21. # TODO Check that image is in CSV?
  22. for image in raw_data:
  23. # FIND IMAGE IN CSV, GET IT'S CLINICAL DATA
  24. image_ID = re.search(r"_I(\d+)_", image).group(1)
  25. if(image_ID==None): raise RuntimeError("Image not found in CSV!")
  26. image_data = csv[csv["Image Data ID"] == int(image_ID)].loc[:,['Sex', 'Age (current)']] # M: 0, F: 1
  27. image_data = image_data.iloc[0].tolist()
  28. if (image_data[0] == 'M'):
  29. image_data[0] = 0
  30. elif(image_data[0] == 'F'):
  31. image_data[0] = 1
  32. else:
  33. raise RuntimeError("Incorrect sex in an image")
  34. image_data = tuple(image_data)
  35. if "NL" in image:
  36. NL_list.append((image, 0, image_data))
  37. elif "AD" in image:
  38. AD_list.append((image, 1, image_data))
  39. print("Total AD: " + str(len(AD_list)))
  40. print("Total NL: " + str(len(NL_list)))
  41. rndm.shuffle(AD_list)
  42. rndm.shuffle(NL_list)
  43. train_list, val_list, test_list = get_train_val_test(AD_list, NL_list, val_split)
  44. rndm.shuffle(train_list)
  45. rndm.shuffle(val_list)
  46. rndm.shuffle(test_list)
  47. print(f"DATA INITIALIZATION")
  48. print(f"Training size: {len(train_list)}")
  49. print(f"Validation size: {len(val_list)}")
  50. print(f"Test size: {len(test_list)}")
  51. # # TRANSFORM
  52. # transform = transforms.Compose([
  53. # transforms.Grayscale(num_output_channels=1)
  54. # ])
  55. train_dataset = CustomDataset(train_list)
  56. val_dataset = CustomDataset(val_list)
  57. test_dataset = CustomDataset(test_list)
  58. return train_dataset, val_dataset, test_dataset
# TODO: Should the data be normalized? Should more clinical data be added/extracted later, and if so which fields?
  60. def prepare_predict(mri_dir, IDs):
  61. raw_data = glob.glob(mri_dir + "*")
  62. image_list = []
  63. # Gets all images and prepares them for Dataset
  64. for ID in IDs:
  65. pattern = re.compile(ID)
  66. matches = [item for item in raw_data if pattern.search(item)]
  67. if (len(matches) != 1): print("No image found, or more than one")
  68. for match in matches:
  69. if "NL" in match: image_list.append((match, 0))
  70. if "AD" in match: image_list.append((match, 1))
  71. return CustomDataset(image_list)
  72. '''
  73. Returns train_list, val_list and test_list in format [(image, id), ...] each
  74. '''
  75. def get_train_val_test(AD_list, NL_list, val_split):
  76. train_list, val_list, test_list = [], [], []
  77. num_test_ad = int(len(AD_list) * val_split)
  78. num_test_nl = int(len(NL_list) * val_split)
  79. num_val_ad = int((len(AD_list) - num_test_ad) * val_split)
  80. num_val_nl = int((len(NL_list) - num_test_nl) * val_split)
  81. # Sets up ADs
  82. for image in AD_list[0:num_val_ad]:
  83. val_list.append(image)
  84. for image in AD_list[num_val_ad:num_test_ad]:
  85. test_list.append(image)
  86. for image in AD_list[num_test_ad:]:
  87. train_list.append(image)
  88. # Sets up NLs
  89. for image in NL_list[0:num_val_nl]:
  90. val_list.append(image)
  91. for image in NL_list[num_val_nl:num_test_nl]:
  92. test_list.append(image)
  93. for image in NL_list[num_test_nl:]:
  94. train_list.append(image)
  95. return train_list, val_list, test_list
  96. class CustomDataset(Dataset):
  97. def __init__(self, list):
  98. self.data = list # INPUT DATA: (image_dir, class_id, (clinical_data))
  99. def __len__(self):
  100. return len(self.data)
  101. def __getitem__(self, idx): # RETURNS TUPLE: ((mri_data, [clinical_data]), class_id)
  102. mri_path, class_id, clinical_data = self.data[idx]
  103. mri = nib.load(mri_path)
  104. image = np.asarray(mri.dataobj)
  105. mri_data = np.asarray(np.expand_dims(image, axis=0))
  106. # mri_data = mri.get_fdata()
  107. # mri_array = np.array(mri)
  108. # mri_tensor = torch.from_numpy(mri_array)
  109. # class_id = torch.tensor([class_id]) TODO return tensor or just id (0, 1)??
  110. return (mri_data, clinical_data), class_id