|
@@ -3,22 +3,23 @@
|
|
|
# TODO ENSURE ITERATION WORKS
|
|
|
import glob
|
|
|
import nibabel as nib
|
|
|
-import numpy as np
|
|
|
import random
|
|
|
import torch
|
|
|
from torch.utils.data import Dataset
|
|
|
-import pandas as pd
|
|
|
+import pandas as pd
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
|
|
|
-
|
|
|
-'''
|
|
|
+"""
|
|
|
Prepares CustomDatasets for training, validating, and testing CNN
|
|
|
-'''
|
|
|
-def prepare_datasets(mri_dir, xls_file, val_split=0.2, seed=50):
|
|
|
+"""
|
|
|
+
|
|
|
|
|
|
+def prepare_datasets(
|
|
|
+ mri_dir, xls_file, val_split=0.2, seed=50, device=torch.device("cpu")
|
|
|
+):
|
|
|
rndm = random.Random(seed)
|
|
|
- xls_data = pd.read_csv(xls_file).set_index('Image Data ID')
|
|
|
+ xls_data = pd.read_csv(xls_file).set_index("Image Data ID")
|
|
|
raw_data = glob.glob(mri_dir + "*")
|
|
|
AD_list = []
|
|
|
NL_list = []
|
|
@@ -39,19 +40,21 @@ def prepare_datasets(mri_dir, xls_file, val_split=0.2, seed=50):
|
|
|
rndm.shuffle(val_list)
|
|
|
rndm.shuffle(test_list)
|
|
|
|
|
|
- train_dataset = ADNIDataset(train_list, xls_data)
|
|
|
- val_dataset = ADNIDataset(val_list, xls_data)
|
|
|
- test_dataset = ADNIDataset(test_list, xls_data)
|
|
|
+ train_dataset = ADNIDataset(train_list, xls_data, device=device)
|
|
|
+ val_dataset = ADNIDataset(val_list, xls_data, device=device)
|
|
|
+ test_dataset = ADNIDataset(test_list, xls_data, device=device)
|
|
|
|
|
|
return train_dataset, val_dataset, test_dataset
|
|
|
|
|
|
# TODO Normalize data? Later add / extract clinical data? Which data?
|
|
|
|
|
|
-'''
|
|
|
+
|
|
|
+"""
|
|
|
Returns train_list, val_list and test_list in format [(image, id), ...] each
|
|
|
-'''
|
|
|
-def get_train_val_test(AD_list, NL_list, val_split):
|
|
|
+"""
|
|
|
|
|
|
+
|
|
|
+def get_train_val_test(AD_list, NL_list, val_split):
|
|
|
train_list, val_list, test_list = [], [], []
|
|
|
|
|
|
num_test_ad = int(len(AD_list) * val_split)
|
|
@@ -84,53 +87,59 @@ def get_train_val_test(AD_list, NL_list, val_split):
|
|
|
|
|
|
|
|
|
class ADNIDataset(Dataset):
    """Dataset pairing an MRI volume with tabular features and a one-hot label.

    Each item is ``((mri_tensor, xls_tensor), class_id)``; all three tensors
    are converted to float32 and moved to ``device``.
    """

    def __init__(self, mri, xls: pd.DataFrame, device=torch.device("cpu")):
        """Store the sample list, the tabular data and the target device.

        Args:
            mri: List of ``(image_path, class_id)`` tuples.
            xls: Table holding the tabular data; rows are read by position.
            device: Device every returned tensor is moved to.
        """
        self.mri_data = mri  # DATA IS A LIST WITH TUPLES (image_dir, class_id)
        self.xls_data = xls
        self.device = device

    def __len__(self):
        """Return the number of ``(image_path, class_id)`` samples."""
        return len(self.mri_data)

    def _xls_to_tensor(self, xls_data: pd.Series):
        """Convert one table row into a tensor of the features in use.

        Only "Sex" and "Age (current)" are read. "M" is encoded as 0 and "F"
        as 1; the remaining values are cast to float as they are.

        Args:
            xls_data: One row of the table.

        Returns:
            A float64 tensor with one entry per selected column.
        """
        # Get used data
        # Other candidate columns, currently unused:
        # 'PTID', 'DXCONFID (1=uncertain, 2= mild, 3= moderate, 4=high confidence)', 'Alz_csf'
        selected = xls_data.loc[["Sex", "Age (current)"]]

        # Build a new Series instead of replacing in place: ``selected`` is
        # derived from the caller's data, and in-place edits on derived
        # pandas objects can warn or leak back into the source.
        encoded = selected.replace({"M": 0, "F": 1})

        # Convert to tensor
        xls_tensor = torch.tensor(encoded.values.astype(float))

        return xls_tensor

    def __getitem__(self, idx):
        """Return ``((mri_tensor, xls_tensor), class_id)`` for sample ``idx``.

        The image is loaded from disk on every call.
        """
        mri_path, class_id = self.mri_data[idx]
        mri = nib.load(mri_path)
        mri_data = mri.get_fdata()

        # NOTE(review): the table row is picked by position, so it belongs to
        # the same subject as the MRI only if ``xls_data`` rows are ordered
        # exactly like ``mri_data`` — confirm against the code that builds
        # (and shuffles) the sample lists.
        xls = self.xls_data.iloc[idx]

        # Convert xls data to tensor
        xls_tensor = self._xls_to_tensor(xls)

        # Add a leading dimension in front of the volume
        # (presumably the channel axis expected by the CNN — TODO confirm).
        mri_tensor = torch.from_numpy(mri_data).unsqueeze(0)

        # Convert to one-hot and squeeze: shape (1,) -> (1, 2) -> (2,)
        class_id = torch.tensor([class_id])
        class_id = torch.nn.functional.one_hot(class_id, num_classes=2).squeeze(0)

        # Convert to float and move to the configured device
        mri_tensor = mri_tensor.float().to(self.device)
        xls_tensor = xls_tensor.float().to(self.device)
        class_id = class_id.float().to(self.device)

        return (mri_tensor, xls_tensor), class_id
|
|
|
-
|
|
|
-
|
|
|
-def initalize_dataloaders(training_data, val_data, test_data, cuda_device=torch.device('cuda:0'), batch_size=64):
|
|
|
+
|
|
|
+
|
|
|
def initalize_dataloaders(
    training_data,
    val_data,
    test_data,
    batch_size=64,
):
    """Wrap the three datasets in shuffling ``DataLoader`` objects.

    The training and validation loaders use ``batch_size``; the test loader
    uses a quarter of it, but never less than one sample per batch.

    Args:
        training_data: Dataset used for training.
        val_data: Dataset used for validation.
        test_data: Dataset used for testing.
        batch_size: Batch size of the training and validation loaders.

    Returns:
        Tuple ``(train_dataloader, val_dataloader, test_dataloader)``.
    """
    # ``batch_size // 4`` is 0 for batch sizes below 4, which DataLoader
    # rejects, so clamp the test batch size to at least 1.
    test_batch_size = max(1, batch_size // 4)

    train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=test_batch_size, shuffle=True)
    val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)
    return train_dataloader, val_dataloader, test_dataloader
|