| 
					
				 | 
			
			
				@@ -3,22 +3,23 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # TODO ENSURE ITERATION WORKS 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import glob 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import nibabel as nib 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import numpy as np 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import random 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import torch 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from torch.utils.data import Dataset 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import pandas as pd  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import pandas as pd 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from torch.utils.data import DataLoader 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-''' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 Prepares CustomDatasets for training, validating, and testing CNN 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-''' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-def prepare_datasets(mri_dir, xls_file, val_split=0.2, seed=50): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def prepare_datasets( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    mri_dir, xls_file, val_split=0.2, seed=50, device=torch.device("cpu") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     rndm = random.Random(seed) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    xls_data = pd.read_csv(xls_file).set_index('Image Data ID') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    xls_data = pd.read_csv(xls_file).set_index("Image Data ID") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     raw_data = glob.glob(mri_dir + "*") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     AD_list = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     NL_list = [] 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -39,19 +40,21 @@ def prepare_datasets(mri_dir, xls_file, val_split=0.2, seed=50): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     rndm.shuffle(val_list) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     rndm.shuffle(test_list) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    train_dataset = ADNIDataset(train_list, xls_data) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    val_dataset = ADNIDataset(val_list, xls_data) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    test_dataset = ADNIDataset(test_list, xls_data) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    train_dataset = ADNIDataset(train_list, xls_data, device=device) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    val_dataset = ADNIDataset(val_list, xls_data, device=device) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    test_dataset = ADNIDataset(test_list, xls_data, device=device) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return train_dataset, val_dataset, test_dataset 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     # TODO  Normalize data? Later add / Extract clinical data? Which data? 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-''' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 Returns train_list, val_list and test_list in format [(image, id), ...] each 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-''' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-def get_train_val_test(AD_list, NL_list, val_split): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def get_train_val_test(AD_list, NL_list, val_split): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     train_list, val_list, test_list = [], [], [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     num_test_ad = int(len(AD_list) * val_split) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -84,53 +87,59 @@ def get_train_val_test(AD_list, NL_list, val_split): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 class ADNIDataset(Dataset): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def __init__(self, mri, xls: pd.DataFrame): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.mri_data = mri        # DATA IS A LIST WITH TUPLES (image_dir, class_id) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def __init__(self, mri, xls: pd.DataFrame, device=torch.device("cpu")): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.mri_data = mri  # DATA IS A LIST WITH TUPLES (image_dir, class_id) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.xls_data = xls 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.device = device 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def __len__(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return len(self.mri_data) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _xls_to_tensor(self, xls_data: pd.Series): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #Get used data 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Get used data 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # data = xls_data.loc[['Sex', 'Age (current)', 'PTID', 'DXCONFID (1=uncertain, 2= mild, 3= moderate, 4=high confidence)', 'Alz_csf']] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        data = xls_data.loc[["Sex", "Age (current)"]] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #data = xls_data.loc[['Sex', 'Age (current)', 'PTID', 'DXCONFID (1=uncertain, 2= mild, 3= moderate, 4=high confidence)', 'Alz_csf']] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        data = xls_data.loc[['Sex', 'Age (current)']] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        data.replace({'M': 0, 'F': 1}, inplace=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        data.replace({"M": 0, "F": 1}, inplace=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #Convert to tensor 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Convert to tensor 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         xls_tensor = torch.tensor(data.values.astype(float)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return xls_tensor 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def __getitem__(self, idx):     # RETURNS TUPLE WITH IMAGE AND CLASS_ID, BASED ON INDEX IDX 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def __getitem__( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self, idx 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ):  # RETURNS TUPLE WITH IMAGE AND CLASS_ID, BASED ON INDEX IDX 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         mri_path, class_id = self.mri_data[idx] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         mri = nib.load(mri_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         mri_data = mri.get_fdata() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         xls = self.xls_data.iloc[idx] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #Convert xls data to tensor 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Convert xls data to tensor 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         xls_tensor = self._xls_to_tensor(xls) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         mri_tensor = torch.from_numpy(mri_data).unsqueeze(0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         class_id = torch.tensor([class_id]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #Convert to one-hot and squeeze 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Convert to one-hot and squeeze 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         class_id = torch.nn.functional.one_hot(class_id, num_classes=2).squeeze(0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #Convert to float 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        mri_tensor = mri_tensor.float() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        xls_tensor = xls_tensor.float() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        class_id = class_id.float() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Convert to float 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        mri_tensor = mri_tensor.float().to(self.device) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        xls_tensor = xls_tensor.float().to(self.device) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        class_id = class_id.float().to(self.device) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return (mri_tensor, xls_tensor), class_id 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-def initalize_dataloaders(training_data, val_data, test_data, cuda_device=torch.device('cuda:0'), batch_size=64): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def initalize_dataloaders( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    training_data, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    val_data, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    test_data, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    batch_size=64, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     test_dataloader = DataLoader(test_data, batch_size=(batch_size // 4), shuffle=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    return train_dataloader, val_dataloader, test_dataloader 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return train_dataloader, val_dataloader, test_dataloader 
			 |