Explorar el Código

unfinished clinical_data implementation

Ruben hace 5 meses
padre
commit
c3c69c6c7b
Se han modificado 4 ficheros con 49 adiciones y 19 borrados
  1. 2 2
      main.py
  2. 16 3
      utils/CNN.py
  3. 29 12
      utils/preprocess.py
  4. 2 2
      utils/train_methods.py

+ 2 - 2
main.py

@@ -39,9 +39,9 @@ properties = {
 model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
 CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth'       # cnn_net.pth
 # small dataset
-# mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/'   # Small Test
+mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/'   # Small Test
 # big dataset
-mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/'   # Real data
+# mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/'   # Real data
 annotations_datapath = './data/data_wnx1/rschuurs/Pytorch_CNN-RNN/LP_ADNIMERGE.csv'
 
 # annotations_file = pd.read_csv(annotations_datapath)    # DataFrame

+ 16 - 3
utils/CNN.py

@@ -1,4 +1,5 @@
-from torch import device, cuda
+import torch
+from torch import device, cuda, cat, stack
 import torch.nn as nn
 import utils.CNN_Layers as CustomLayers
 
@@ -19,12 +20,16 @@ class CNN_Net(nn.Module):
                                                                 sep_conv=True)
         self.conv5_sepConv = CustomLayers.Conv_elu_maxpool_drop(96, 48, (3, 4, 3), stride=(1, 1, 1), pool=True,
                                                                 prps=prps, sep_conv=True)
-        self.fc1 = CustomLayers.Fc_elu_drop(113568, 20, prps=prps, softmax=False)      # TODO, concatenate clinical data after this
-        self.fc2 = CustomLayers.Fc_elu_drop(20, final_layer_size, prps=prps, softmax=True)  # For now this works as output layer, though may be incorrect
+        self.fc1 = CustomLayers.Fc_elu_drop(113568, 10, prps=prps, softmax=False)      # TODO, concatenate clinical data after this
+        self.fc2 = CustomLayers.Fc_elu_drop(10, final_layer_size, prps=prps, softmax=True)  # For now this works as output layer, though may be incorrect
+        self.fc_clinical1 = CustomLayers.Fc_elu_drop(6, 30, prps=prps, softmax=False)
+        self.fc_clinical2 = CustomLayers.Fc_elu_drop(30,10, prps=prps, softmax=False)
 
 
     # FORWARDS
     def forward(self, x):
+        clinical_data = x[1].to(torch.float32)
+        x = x[0]
         x = self.conv1(x)
         x = self.conv2(x)
         x = self.conv3_mid_flow(x)
@@ -36,5 +41,13 @@ class CNN_Net(nn.Module):
         x = x.view(-1, flatten_size)
 
         x = self.fc1(x)
+
+        # Clinical
+        clinical_data = self.fc_clinical1(clinical_data)
+        clinical_data = self.fc_clinical2(clinical_data)
+
+        x = cat((x, clinical_data), dim=1)
+        print(x.shape)
+
         x = self.fc2(x)
         return x

+ 29 - 12
utils/preprocess.py

@@ -1,6 +1,7 @@
 import glob
 import nibabel as nib
 import numpy as np
+import pandas as pd
 import random
 import torch
 from torch.utils.data import Dataset
@@ -14,6 +15,7 @@ Prepares CustomDatasets for training, validating, and testing CNN
 def prepare_datasets(mri_dir, val_split=0.2, seed=50):
 
     rndm = random.Random(seed)
+    csv = pd.read_csv("LP_ADNIMERGE.csv")
 
     raw_data = glob.glob(mri_dir + "*")
 
@@ -25,10 +27,25 @@ def prepare_datasets(mri_dir, val_split=0.2, seed=50):
 
     # TODO Check that image is in CSV?
     for image in raw_data:
+        # FIND IMAGE IN CSV, GET IT'S CLINICAL DATA
+        image_ID = re.search(r"_I(\d+)_", image).group(1)
+        if(image_ID==None): raise RuntimeError("Image not found in CSV!")
+        image_data = csv[csv["Image Data ID"] == int(image_ID)].loc[:,['Sex', 'Age (current)']]     # M: 0, F: 1
+        image_data = image_data.iloc[0].tolist()
+        if (image_data[0] == 'M'):
+            image_data[0] = 0
+        elif(image_data[0] == 'F'):
+            image_data[0] = 1
+        else:
+            raise RuntimeError("Incorrect sex in an image")
+
+        image_data = tuple(image_data)
+
         if "NL" in image:
-            NL_list.append(image)
+            NL_list.append((image, 0, image_data))
         elif "AD" in image:
-            AD_list.append(image)
+            AD_list.append((image, 1, image_data))
+
 
     print("Total AD: " + str(len(AD_list)))
     print("Total NL: " + str(len(NL_list)))
@@ -95,36 +112,36 @@ def get_train_val_test(AD_list, NL_list, val_split):
 
     # Sets up ADs
     for image in AD_list[0:num_val_ad]:
-        val_list.append((image, 1))
+        val_list.append(image)
 
     for image in AD_list[num_val_ad:num_test_ad]:
-        test_list.append((image, 1))
+        test_list.append(image)
 
     for image in AD_list[num_test_ad:]:
-        train_list.append((image, 1))
+        train_list.append(image)
 
     # Sets up NLs
     for image in NL_list[0:num_val_nl]:
-        val_list.append((image, 0))
+        val_list.append(image)
 
     for image in NL_list[num_val_nl:num_test_nl]:
-        test_list.append((image, 0))
+        test_list.append(image)
 
     for image in NL_list[num_test_nl:]:
-        train_list.append((image, 0))
+        train_list.append(image)
 
     return train_list, val_list, test_list
 
 
 class CustomDataset(Dataset):
     def __init__(self, list):
-        self.data = list        # DATA IS A LIST WITH TUPLES (image_dir, class_id)
+        self.data = list        # INPUT DATA: (image_dir, class_id, (clinical_data))
 
     def __len__(self):
         return len(self.data)
 
-    def __getitem__(self, idx):     # RETURNS TUPLE WITH IMAGE AND CLASS_ID, BASED ON INDEX IDX
-        mri_path, class_id = self.data[idx]
+    def __getitem__(self, idx):     # RETURNS TUPLE: ((mri_data, [clinical_data]), class_id)
+        mri_path, class_id, clinical_data = self.data[idx]
         mri = nib.load(mri_path)
         image = np.asarray(mri.dataobj)
         mri_data = np.asarray(np.expand_dims(image, axis=0))
@@ -134,4 +151,4 @@ class CustomDataset(Dataset):
         # mri_tensor = torch.from_numpy(mri_array)
         # class_id = torch.tensor([class_id]) TODO return tensor or just id (0, 1)??
 
-        return mri_data, class_id
+        return (mri_data, clinical_data), class_id

+ 2 - 2
utils/train_methods.py

@@ -1,6 +1,6 @@
 import torch
 
-from torch import nn, optim
+from torch import nn, optim, cat, stack
 from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report
 import seaborn as sns
 
@@ -35,7 +35,7 @@ def train(model, train_data, test_data, CNN_filepath, epochs=20, graphs=True):
         # Batches & training
         for i, data in enumerate(train_data, 0):
             # get the inputs; data is a list of [inputs, labels]
-            inputs, labels = data[0].to(model.device), data[1].to(model.device)
+            inputs, labels = [data[0][0].to(model.device), stack(data[0][1], dim=0).to(model.device)], data[1].to(model.device) # TODO Clinical data not sent to model.device
 
             # zero the parameter gradients
             optimizer.zero_grad()