|
@@ -0,0 +1,131 @@
|
|
|
|
+# NEEDS TO BE FINISHED
|
|
|
|
+# TODO CHECK ABOUT IMAGE DIMENSIONS
|
|
|
|
+# TODO ENSURE ITERATION WORKS
|
|
|
|
+import glob
|
|
|
|
+import nibabel as nib
|
|
|
|
+import numpy as np
|
|
|
|
+import random
|
|
|
|
+import torch
|
|
|
|
+from torch.utils.data import Dataset
|
|
|
|
+import pandas as pd
|
|
|
|
+from torch.utils.data import DataLoader
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+'''
|
|
|
|
+Prepares the ADNIDataset objects used for training, validating, and testing the CNN
|
|
|
|
+'''
|
|
|
|
def prepare_datasets(mri_dir, xls_file, val_split=0.2, seed=50):
    """Build the train, validation and test ``ADNIDataset`` objects.

    Every entry directly inside ``mri_dir`` is assigned to the NL or AD group
    according to the marker found in its file name ("NL" is tested first, then
    "AD"); entries carrying neither marker are skipped. Both groups are
    shuffled, split by ``get_train_val_test`` and the resulting lists are
    shuffled again so the two classes are interleaved.

    Args:
        mri_dir: Directory that contains the MRI files.
        xls_file: Path of the clinical CSV file. It must have an
            'Image Data ID' column, which becomes the index of the table.
        val_split: Split fraction, forwarded unchanged to
            ``get_train_val_test``.
        seed: Seed of the private ``random.Random`` used for every shuffle.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)``; each dataset
        receives its own list of ``(image_path, class_id)`` tuples and the
        complete clinical table.
    """
    # Local import so this function does not rely on a module-level `import os`.
    import os

    rndm = random.Random(seed)
    xls_data = pd.read_csv(xls_file).set_index('Image Data ID')

    # glob returns matches in arbitrary order; sorting first makes the seeded
    # shuffles below produce the same split on every machine. os.path.join
    # keeps the pattern correct whether or not mri_dir ends with a separator.
    raw_data = sorted(glob.glob(os.path.join(mri_dir, "*")))

    AD_list = []
    NL_list = []

    # TODO Check that image is in CSV?
    for image in raw_data:
        # Only the file name is inspected: a parent directory such as "ADNI/"
        # would otherwise match "AD" for every single file.
        file_name = os.path.basename(image)
        if "NL" in file_name:
            NL_list.append(image)
        elif "AD" in file_name:
            AD_list.append(image)

    rndm.shuffle(AD_list)
    rndm.shuffle(NL_list)

    train_list, val_list, test_list = get_train_val_test(AD_list, NL_list, val_split)

    # The split lists come back grouped by class; shuffle to mix the classes.
    rndm.shuffle(train_list)
    rndm.shuffle(val_list)
    rndm.shuffle(test_list)

    train_dataset = ADNIDataset(train_list, xls_data)
    val_dataset = ADNIDataset(val_list, xls_data)
    test_dataset = ADNIDataset(test_list, xls_data)

    return train_dataset, val_dataset, test_dataset
|
|
|
|
+
|
|
|
|
+ # TODO Normalize data? Later add / Extract clinical data? Which data?
|
|
|
|
+
|
|
|
|
+'''
|
|
|
|
+Returns train_list, val_list and test_list, each a list of (image_path, class_id) tuples (class_id is 1 for AD, 0 for NL)
|
|
|
|
+'''
|
|
|
|
def get_train_val_test(AD_list, NL_list, val_split):
    """Split the AD and NL image lists into labelled train/val/test lists.

    For each class, ``int(len * val_split)`` images are reserved for testing
    and ``int((len - num_test) * val_split)`` images for validation; all
    remaining images are used for training. The three parts are taken as
    consecutive, non-overlapping slices (validation, then test, then train),
    so every image ends up in exactly one list.

    Args:
        AD_list: Image paths of the AD class (labelled 1).
        NL_list: Image paths of the NL class (labelled 0).
        val_split: Fraction in [0, 1] used for both the test and the
            validation share.

    Returns:
        Tuple ``(train_list, val_list, test_list)``; each is a list of
        ``(image, class_id)`` tuples with the AD entries ahead of the NL ones.

    Raises:
        ValueError: If ``val_split`` is outside [0, 1].
    """
    if not 0 <= val_split <= 1:
        raise ValueError(f"val_split must be between 0 and 1, got {val_split!r}")

    train_list, val_list, test_list = [], [], []

    # AD is processed before NL so the output order matches the class order.
    for images, class_id in ((AD_list, 1), (NL_list, 0)):
        num_test = int(len(images) * val_split)
        num_val = int((len(images) - num_test) * val_split)

        # The test slice must start where validation ends and span num_test
        # items; slicing [num_val:num_test] would yield only
        # num_test - num_val test images and enlarge the training share.
        test_end = num_val + num_test

        val_list.extend((image, class_id) for image in images[:num_val])
        test_list.extend((image, class_id) for image in images[num_val:test_end])
        train_list.extend((image, class_id) for image in images[test_end:])

    return train_list, val_list, test_list
|
|
|
|
+
|
|
|
|
+
|
|
|
|
class ADNIDataset(Dataset):
    """Dataset yielding an MRI volume, two clinical features and a one-hot label.

    Each item has the form ``((mri_tensor, xls_tensor), class_id)``:
    ``mri_tensor`` is the loaded volume with one extra leading dimension,
    ``xls_tensor`` holds the encoded sex and the age, and ``class_id`` is a
    one-hot tensor with two entries.
    """

    # Integer codes substituted for the values of the 'Sex' column.
    _SEX_CODES = {'M': 0, 'F': 1}

    def __init__(self, mri, xls: pd.DataFrame):
        """Store the image list and the clinical table.

        Args:
            mri: List of ``(image_path, class_id)`` tuples.
            xls: Clinical table from which one row is read per item.
        """
        self.mri_data = mri
        self.xls_data = xls

    def __len__(self):
        """Return the number of ``(image_path, class_id)`` entries."""
        return len(self.mri_data)

    def _xls_to_tensor(self, xls_data: pd.Series):
        """Convert one clinical row into a float64 tensor ``[sex, age]``.

        'M' is encoded as 0 and 'F' as 1; any other value is passed through
        unchanged and must itself be convertible to ``float``.

        Raises:
            KeyError: If 'Sex' or 'Age (current)' is missing from the row.
            ValueError: If a value cannot be converted to ``float``.
        """
        # Other candidate columns that are currently unused: 'PTID',
        # 'DXCONFID (1=uncertain, 2= mild, 3= moderate, 4=high confidence)',
        # 'Alz_csf'.
        sex = xls_data['Sex']
        age = xls_data['Age (current)']

        # An explicit lookup replaces the in-place Series.replace, which
        # mutated a derived Series and relied on pandas' deprecated
        # downcasting of object data.
        sex = self._SEX_CODES.get(sex, sex)

        return torch.tensor([float(sex), float(age)], dtype=torch.float64)

    def __getitem__(self, idx):
        """Return ``((mri_tensor, xls_tensor), class_id)`` for position ``idx``."""
        mri_path, class_id = self.mri_data[idx]
        mri_data = nib.load(mri_path).get_fdata()

        # NOTE(review): this picks the clinical row by position in the table,
        # which is not tied to the image stored at `idx`. The intended lookup
        # is presumably by the image's ID in the table index; confirm how the
        # ID is encoded in the file name before changing this.
        xls = self.xls_data.iloc[idx]
        xls_tensor = self._xls_to_tensor(xls)

        # Add a leading dimension in front of the volume's own dimensions
        # (presumably the channel axis expected by the CNN; confirm).
        mri_tensor = torch.from_numpy(mri_data).unsqueeze(0)

        # one_hot on a 1-element tensor gives shape (1, 2); squeeze to (2,).
        class_id = torch.nn.functional.one_hot(
            torch.tensor([class_id]), num_classes=2
        ).squeeze(0)

        return (mri_tensor, xls_tensor), class_id
|
|
|
|
+
|
|
|
|
+
|
|
|
|
def initalize_dataloaders(training_data, val_data, test_data, cuda_device=torch.device('cuda:0'), batch_size=64):
    """Wrap the three datasets in shuffling ``DataLoader`` objects.

    Every loader shuffles its data and is given its own ``torch.Generator``
    created on ``cuda_device``. The training and validation loaders use
    ``batch_size``; the test loader uses ``batch_size // 4``.

    Args:
        training_data: Dataset for the training loader.
        val_data: Dataset for the validation loader.
        test_data: Dataset for the test loader.
        cuda_device: Device on which each loader's generator is created.
        batch_size: Batch size of the training and validation loaders.

    Returns:
        Tuple ``(train_dataloader, val_dataloader, test_dataloader)``.
    """
    # NOTE(review): a generator on a non-CPU device presumably only works when
    # torch's default device matches it; confirm against the training script.
    def _shuffled_loader(dataset, size):
        # One fresh generator per loader, all on the requested device.
        rng = torch.Generator(device=cuda_device)
        return DataLoader(dataset, batch_size=size, shuffle=True, generator=rng)

    train_dataloader = _shuffled_loader(training_data, batch_size)
    test_dataloader = _shuffled_loader(test_data, batch_size // 4)
    val_dataloader = _shuffled_loader(val_data, batch_size)

    return train_dataloader, val_dataloader, test_dataloader
|