|
@@ -1,6 +1,7 @@
|
|
|
import glob
|
|
|
import nibabel as nib
|
|
|
import numpy as np
|
|
|
+import pandas as pd
|
|
|
import random
|
|
|
import torch
|
|
|
from torch.utils.data import Dataset
|
|
@@ -14,6 +15,7 @@ Prepares CustomDatasets for training, validating, and testing CNN
|
|
|
def prepare_datasets(mri_dir, val_split=0.2, seed=50):
|
|
|
|
|
|
rndm = random.Random(seed)
|
|
|
+ csv = pd.read_csv("LP_ADNIMERGE.csv")
|
|
|
|
|
|
raw_data = glob.glob(mri_dir + "*")
|
|
|
|
|
@@ -25,10 +27,25 @@ def prepare_datasets(mri_dir, val_split=0.2, seed=50):
|
|
|
|
|
|
# TODO Check that image is in CSV?
|
|
|
for image in raw_data:
|
|
|
+ # FIND IMAGE IN CSV, GET ITS CLINICAL DATA
|
|
|
+ image_ID = re.search(r"_I(\d+)_", image).group(1)
|
|
|
+ if(image_ID==None): raise RuntimeError("Image not found in CSV!")
|
|
|
+ image_data = csv[csv["Image Data ID"] == int(image_ID)].loc[:,['Sex', 'Age (current)']] # M: 0, F: 1
|
|
|
+ image_data = image_data.iloc[0].tolist()
|
|
|
+ if (image_data[0] == 'M'):
|
|
|
+ image_data[0] = 0
|
|
|
+ elif(image_data[0] == 'F'):
|
|
|
+ image_data[0] = 1
|
|
|
+ else:
|
|
|
+ raise RuntimeError("Incorrect sex in an image")
|
|
|
+
|
|
|
+ image_data = tuple(image_data)
|
|
|
+
|
|
|
if "NL" in image:
|
|
|
- NL_list.append(image)
|
|
|
+ NL_list.append((image, 0, image_data))
|
|
|
elif "AD" in image:
|
|
|
- AD_list.append(image)
|
|
|
+ AD_list.append((image, 1, image_data))
|
|
|
+
|
|
|
|
|
|
print("Total AD: " + str(len(AD_list)))
|
|
|
print("Total NL: " + str(len(NL_list)))
|
|
@@ -95,36 +112,36 @@ def get_train_val_test(AD_list, NL_list, val_split):
|
|
|
|
|
|
# Sets up ADs
|
|
|
for image in AD_list[0:num_val_ad]:
|
|
|
- val_list.append((image, 1))
|
|
|
+ val_list.append(image)
|
|
|
|
|
|
for image in AD_list[num_val_ad:num_test_ad]:
|
|
|
- test_list.append((image, 1))
|
|
|
+ test_list.append(image)
|
|
|
|
|
|
for image in AD_list[num_test_ad:]:
|
|
|
- train_list.append((image, 1))
|
|
|
+ train_list.append(image)
|
|
|
|
|
|
# Sets up NLs
|
|
|
for image in NL_list[0:num_val_nl]:
|
|
|
- val_list.append((image, 0))
|
|
|
+ val_list.append(image)
|
|
|
|
|
|
for image in NL_list[num_val_nl:num_test_nl]:
|
|
|
- test_list.append((image, 0))
|
|
|
+ test_list.append(image)
|
|
|
|
|
|
for image in NL_list[num_test_nl:]:
|
|
|
- train_list.append((image, 0))
|
|
|
+ train_list.append(image)
|
|
|
|
|
|
return train_list, val_list, test_list
|
|
|
|
|
|
|
|
|
class CustomDataset(Dataset):
|
|
|
def __init__(self, list):
|
|
|
- self.data = list # DATA IS A LIST WITH TUPLES (image_dir, class_id)
|
|
|
+ self.data = list # INPUT DATA: (image_dir, class_id, (clinical_data))
|
|
|
|
|
|
def __len__(self):
|
|
|
return len(self.data)
|
|
|
|
|
|
- def __getitem__(self, idx): # RETURNS TUPLE WITH IMAGE AND CLASS_ID, BASED ON INDEX IDX
|
|
|
- mri_path, class_id = self.data[idx]
|
|
|
+ def __getitem__(self, idx): # RETURNS TUPLE: ((mri_data, [clinical_data]), class_id)
|
|
|
+ mri_path, class_id, clinical_data = self.data[idx]
|
|
|
mri = nib.load(mri_path)
|
|
|
image = np.asarray(mri.dataobj)
|
|
|
mri_data = np.asarray(np.expand_dims(image, axis=0))
|
|
@@ -134,4 +151,4 @@ class CustomDataset(Dataset):
|
|
|
# mri_tensor = torch.from_numpy(mri_array)
|
|
|
# class_id = torch.tensor([class_id]) TODO return tensor or just id (0, 1)??
|
|
|
|
|
|
- return mri_data, class_id
|
|
|
+ return (mri_data, clinical_data), class_id
|