Browse Source

First, unsuccessful attempt at GuidedGradCam. Trained new model & removed clinical data temporarily

Ruben 2 months ago
parent
commit
f140373ea3
31 changed files with 92 additions and 29 deletions
  1. BIN
      ROC_2024-08-13_14.36.png
  2. BIN
      avgloss_epoch_curve_2024-08-13_14.36.png
  3. BIN
      cnn_net.pth
  4. 21 0
      cnn_net_data_2024-08-13_14.36.csv
  5. BIN
      confusion_matrix_2024-08-13_14.36.png
  6. 0 0
      figures/5-fold_7_26/5_folds_2024-07-26_14.26.txt
  7. 0 0
      figures/5-fold_7_26/ROC_5_Folds_2024-07-26_15.22.png
  8. 0 0
      figures/5-fold_7_26/cnn_net_data_2024-07-26_14.37.csv
  9. 0 0
      figures/5-fold_7_26/cnn_net_data_2024-07-26_14.48.csv
  10. 0 0
      figures/5-fold_7_26/cnn_net_data_2024-07-26_15.00.csv
  11. 0 0
      figures/5-fold_7_26/cnn_net_data_2024-07-26_15.11.csv
  12. 0 0
      figures/5-fold_7_26/cnn_net_data_2024-07-26_15.22.csv
  13. 0 0
      figures/Age_and_Sex_Data_Run/ROC_2024-05-30_15.48.png
  14. 0 0
      figures/Age_and_Sex_Data_Run/avgloss_epoch_curve_2024-05-30_15.48.png
  15. 0 0
      figures/Age_and_Sex_Data_Run/cnn_net_data_2024-05-30_15.48.csv
  16. 0 0
      figures/Age_and_Sex_Data_Run/confusion_matrix_2024-05-30_15.48.png
  17. 0 0
      figures/Pre-Clinical_Data_Run/ROC_2024-05-30_11.39.png
  18. 0 0
      figures/Pre-Clinical_Data_Run/avgloss_epoch_curve_2024-05-30_11.39.png
  19. 0 0
      figures/Pre-Clinical_Data_Run/cnn_net.pth
  20. 0 0
      figures/Pre-Clinical_Data_Run/cnn_net_data_2024-05-30_11.39.csv
  21. 0 0
      figures/Pre-Clinical_Data_Run/confusion_matrix_2024-05-30_11.39.png
  22. 0 0
      figures/ROC_2024-07-19_12.31.png
  23. 0 0
      figures/avgloss_epoch_curve_2024-07-19_12.30.png
  24. 0 0
      figures/cnn_net_data_2024-07-19_12.30.csv
  25. 0 0
      figures/confusion_matrix_2024-07-19_12.31.png
  26. 36 9
      main.py
  27. 1 1
      original_model/innvestigate/analyzer/__init__.py
  28. 15 7
      utils/CNN.py
  29. 5 3
      utils/dataset_sd_mean_finder.py
  30. 6 3
      utils/preprocess.py
  31. 8 6
      utils/train_methods.py

BIN
ROC_2024-08-13_14.36.png


BIN
avgloss_epoch_curve_2024-08-13_14.36.png


BIN
cnn_net.pth


+ 21 - 0
cnn_net_data_2024-08-13_14.36.csv

@@ -0,0 +1,21 @@
+,Epoch,Avg_loss,Time,Val_loss
+0,1,0.6887565185041988,82.2500913143158,0.6797546744346619
+0,2,0.682743440656101,138.5955469608307,0.678901731967926
+0,3,0.6790511222446666,192.37863397598267,0.6798441410064697
+0,4,0.6764011488241308,246.2428903579712,0.6738172173500061
+0,5,0.6745905350236332,300.25479340553284,0.6706101298332214
+0,6,0.6742749599849477,354.43093943595886,0.6625040769577026
+0,7,0.6726816296577454,408.60620403289795,0.6585175395011902
+0,8,0.6735070768524619,462.8501789569855,0.6575979590415955
+0,9,0.6730456282110775,517.0672903060913,0.6523853540420532
+0,10,0.6724514961242676,571.3123607635498,0.6538687944412231
+0,11,0.6720262976253734,625.621054649353,0.6569682359695435
+0,12,0.6721456226180581,679.8675563335419,0.6546973586082458
+0,13,0.6719080104547388,734.0622756481171,0.6512206196784973
+0,14,0.6719393344486461,788.1773836612701,0.6480571031570435
+0,15,0.671908504822675,842.2912216186523,0.653897225856781
+0,16,0.6717916411512038,896.4600219726562,0.650320827960968
+0,17,0.6717809789321002,950.6497626304626,0.6436631083488464
+0,18,0.6719728462836322,1004.8499879837036,0.6612203121185303
+0,19,0.6715883086709415,1059.0907924175262,0.6507071256637573
+0,20,0.6717346520984874,1113.300742149353,0.669158399105072

BIN
confusion_matrix_2024-08-13_14.36.png


+ 0 - 0
figures/5-fold_7_26/5_folds_2024-07-26_14:26.txt → figures/5-fold_7_26/5_folds_2024-07-26_14.26.txt


+ 0 - 0
figures/5-fold_7_26/ROC_5_Folds_2024-07-26_15:22.png → figures/5-fold_7_26/ROC_5_Folds_2024-07-26_15.22.png


+ 0 - 0
figures/5-fold_7_26/cnn_net_data_2024-07-26_14:37.csv → figures/5-fold_7_26/cnn_net_data_2024-07-26_14.37.csv


+ 0 - 0
figures/5-fold_7_26/cnn_net_data_2024-07-26_14:48.csv → figures/5-fold_7_26/cnn_net_data_2024-07-26_14.48.csv


+ 0 - 0
figures/5-fold_7_26/cnn_net_data_2024-07-26_15:00.csv → figures/5-fold_7_26/cnn_net_data_2024-07-26_15.00.csv


+ 0 - 0
figures/5-fold_7_26/cnn_net_data_2024-07-26_15:11.csv → figures/5-fold_7_26/cnn_net_data_2024-07-26_15.11.csv


+ 0 - 0
figures/5-fold_7_26/cnn_net_data_2024-07-26_15:22.csv → figures/5-fold_7_26/cnn_net_data_2024-07-26_15.22.csv


+ 0 - 0
figures/Age_and_Sex_Data_Run/ROC_2024-05-30_15:48.png → figures/Age_and_Sex_Data_Run/ROC_2024-05-30_15.48.png


+ 0 - 0
figures/Age_and_Sex_Data_Run/avgloss_epoch_curve_2024-05-30_15:48.png → figures/Age_and_Sex_Data_Run/avgloss_epoch_curve_2024-05-30_15.48.png


+ 0 - 0
figures/Age_and_Sex_Data_Run/cnn_net_data_2024-05-30_15:48.csv → figures/Age_and_Sex_Data_Run/cnn_net_data_2024-05-30_15.48.csv


+ 0 - 0
figures/Age_and_Sex_Data_Run/confusion_matrix_2024-05-30_15:48.png → figures/Age_and_Sex_Data_Run/confusion_matrix_2024-05-30_15.48.png


+ 0 - 0
figures/Pre-Clinical_Data_Run/ROC_2024-05-30_11:39.png → figures/Pre-Clinical_Data_Run/ROC_2024-05-30_11.39.png


+ 0 - 0
figures/Pre-Clinical_Data_Run/avgloss_epoch_curve_2024-05-30_11:39.png → figures/Pre-Clinical_Data_Run/avgloss_epoch_curve_2024-05-30_11.39.png


+ 0 - 0
figures/Pre-Clinical Data Run/cnn_net.pth → figures/Pre-Clinical_Data_Run/cnn_net.pth


+ 0 - 0
figures/Pre-Clinical_Data_Run/cnn_net_data_2024-05-30_11:39.csv → figures/Pre-Clinical_Data_Run/cnn_net_data_2024-05-30_11.39.csv


+ 0 - 0
figures/Pre-Clinical_Data_Run/confusion_matrix_2024-05-30_11:39.png → figures/Pre-Clinical_Data_Run/confusion_matrix_2024-05-30_11.39.png


+ 0 - 0
figures/ROC_2024-07-19_12:31.png → figures/ROC_2024-07-19_12.31.png


+ 0 - 0
figures/avgloss_epoch_curve_2024-07-19_12:30.png → figures/avgloss_epoch_curve_2024-07-19_12.30.png


+ 0 - 0
figures/cnn_net_data_2024-07-19_12:30.csv → figures/cnn_net_data_2024-07-19_12.30.csv


+ 0 - 0
figures/confusion_matrix_2024-07-19_12:31.png → figures/confusion_matrix_2024-07-19_12.31.png


+ 36 - 9
main.py

@@ -1,3 +1,5 @@
+from symbol import parameters
+
 import torch
 
 # FOR DATA
@@ -16,6 +18,11 @@ import platform
 import time
 current_time = time.localtime()
 
+
+# INTERPRETABILITY
+from captum.attr import GuidedGradCam
+
+
 print(time.strftime("%Y-%m-%d_%H:%M", current_time))
 print("--- RUNNING ---")
 print("Pytorch Version: " + torch. __version__)
@@ -46,13 +53,13 @@ CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth'       # cn
 # mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/'   # Small Test
 # big dataset
 mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/'   # Real data
-annotations_datapath = './data/data_wnx1/rschuurs/Pytorch_CNN-RNN/LP_ADNIMERGE.csv'
+csv_datapath = 'LP_ADNIMERGE.csv'
 
 # annotations_file = pd.read_csv(annotations_datapath)    # DataFrame
 # show_image(17508)
 
 # TODO: Datasets include multiple labels, such as medical info
-training_data, val_data, test_data = prepare_datasets(mri_datapath, val_split, seed)
+training_data, val_data, test_data = prepare_datasets(mri_datapath, csv_datapath, val_split, seed)
 
 # Create data loaders
 train_dataloader = DataLoader(training_data, batch_size=properties['batch_size'], shuffle=True, drop_last=True)
@@ -63,7 +70,7 @@ test_dataloader = DataLoader(test_data, batch_size=properties['batch_size'], shu
 # loads a few images to test
 x = 0
 while x < 0:
-    train_features, train_labels = next(iter(train_dataloader))[0]
+    train_features, train_labels = next(iter(train_dataloader))
     # print(f"Feature batch shape: {train_features.size()}")
     img = train_features[0].squeeze()
     print(f"Feature batch shape: {img.size()}")
@@ -72,23 +79,43 @@ while x < 0:
     label = train_labels[0]
     print(f"Label: {label}")
     plt.imshow(image, cmap="gray")
-    plt.savefig(f"./Image{x}_IS:{label}.png")
+    # plt.savefig(f"./Image{x}_IS:{label}.png")
     plt.show()
 
     x = x+1
 
 
-epochs = 20
 roc = True
 CNN = CNN_Net(prps=properties, final_layer_size=2)
 CNN.cuda()
 
-train(CNN, train_dataloader, val_dataloader, CNN_filepath, epochs, graphs=True)
-# load(CNN, CNN_filepath)
-evaluate(CNN, test_dataloader)
-predict(CNN, test_dataloader)
+# train(CNN, train_dataloader, val_dataloader, CNN_filepath, properties, graphs=True)
+load(CNN, CNN_filepath)
+# evaluate(CNN, test_dataloader)
+# predict(CNN, test_dataloader)
+
+print(CNN)
+CNN.eval()
+
+guided_gc = GuidedGradCam(CNN, CNN.conv5_sepConv)   # Performed on LAST convolution layer
+# input = torch.randn(1, 1, 91, 109, 91, requires_grad=True).cuda()
+
+# TODO MAKE BATCH SIZE 1 FOR THIS TO WORK??
+train_features, train_labels = next(iter(train_dataloader))
+while(train_labels[0] == 0):
+    train_features, train_labels = next(iter(train_dataloader))
+
+attr = guided_gc.attribute(train_features.cuda(), 0) #, interpolate_mode="area")
 
+# draw the attributes
+attr = attr.unsqueeze(0)
+attr = attr.cpu().detach().numpy()
+attr = np.clip(attr, 0, 1)
+plt.imshow(attr)
+plt.show()
 
+print("Done w/ attributions")
+print(attr)
 
 # EXTRA
 

+ 1 - 1
original_model/innvestigate/analyzer/__init__.py

@@ -115,7 +115,7 @@ analyzers = {
 
     # Pattern based
     "pattern.net": PatternNet,
-    "pattern.attribution": PatternAttribution,
+    "pattern.attr": PatternAttribution,
 }
 
 

+ 15 - 7
utils/CNN.py

@@ -22,21 +22,29 @@ class CNN_Net(nn.Module):   # , talos.utils.TorchHistory):
         self.conv5_sepConv = CustomLayers.Conv_elu_maxpool_drop(96, 48, (3, 4, 3), stride=(1, 1, 1), pool=True,
                                                                 prps=prps, sep_conv=True)
         self.fc1 = CustomLayers.Fc_elu_drop(113568, 10, prps=prps, softmax=False)      # TODO, concatenate clinical data after this
-        self.fc2 = CustomLayers.Fc_elu_drop(20, final_layer_size, prps=prps, softmax=True)  # For now this works as output layer, though may be incorrect
+        self.fc2 = CustomLayers.Fc_elu_drop(10, final_layer_size, prps=prps, softmax=True)  # For now this works as output layer, though may be incorrect
         self.fc_clinical1 = CustomLayers.Fc_elu_drop(2, 30, prps=prps, softmax=False)
         self.fc_clinical2 = CustomLayers.Fc_elu_drop(30,10, prps=prps, softmax=False)
 
+        # TESTING
+        self.gradients = None
+
+    def activations_hook(self, grad):
+        self.gradients = grad
 
     # FORWARDS
     def forward(self, x):
-        clinical_data = x[1].to(torch.float32)
-        x = x[0]
+        # clinical_data = x[1].to(torch.float32)
+        # x = x[0]
         x = self.conv1(x)
         x = self.conv2(x)
         x = self.conv3_mid_flow(x)
         x = self.conv4_sepConv(x)
         x = self.conv5_sepConv(x)
 
+        # REGISTER HOOK
+        # h = x.register_hook(self.activations_hook)
+
         # FLATTEN x
         flatten_size = x.size(1) * x.size(2) * x.size(3) * x.size(4)
         x = x.view(-1, flatten_size)
@@ -44,11 +52,11 @@ class CNN_Net(nn.Module):   # , talos.utils.TorchHistory):
         x = self.fc1(x)
 
         # Clinical
-        clinical_data = torch.transpose(clinical_data, 0, 1)
-        clinical_data = self.fc_clinical1(clinical_data)
-        clinical_data = self.fc_clinical2(clinical_data)
+        # clinical_data = torch.transpose(clinical_data, 0, 1)
+        # clinical_data = self.fc_clinical1(clinical_data)
+        # clinical_data = self.fc_clinical2(clinical_data)
 
-        x = cat((x, clinical_data), dim=1)
+        # x = cat((x, clinical_data), dim=1)    # TODO IGNORES CLINICAL DATA
 
         x = self.fc2(x)
         return x

+ 5 - 3
utils/dataset_sd_mean_finder.py

@@ -9,7 +9,8 @@ CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth'       # cn
 # mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/'   # Small Test
 # big dataset
 mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/'   # Real data
-annotations_datapath = './data/data_wnx1/rschuurs/Pytorch_CNN-RNN/LP_ADNIMERGE.csv'
+csv_datapath = '../LP_ADNIMERGE.csv'
+                        # '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/LP_ADNIMERGE.csv')
 
 
 # LOADING DATA
@@ -33,14 +34,14 @@ properties = {
 
 
 # TODO: Datasets include multiple labels, such as medical info
-training_data, val_data, test_data = prepare_datasets(mri_datapath, val_split, seed)
+training_data, val_data, test_data = prepare_datasets(mri_datapath, csv_datapath, val_split, seed)
 
 # Create data loaders
 train_dataloader = DataLoader(training_data, batch_size=properties['batch_size'], shuffle=True, drop_last=True)
 val_dataloader = DataLoader(val_data, batch_size=properties['batch_size'], shuffle=True)        # Used during training
 test_dataloader = DataLoader(test_data, batch_size=properties['batch_size'], shuffle=True)      # Used at end for graphs
 
-
+print("STARTING")
 
 
 # HERE'S ACTUAL CODE
@@ -48,6 +49,7 @@ mean = 0.
 std = 0.
 nb_samples = 0.
 for data in train_dataloader:
+    print(data)
     batch_samples = data.size(0)
     data = data.view(batch_samples, data.size(1), -1)
     mean += data.mean(2).sum(0)

+ 6 - 3
utils/preprocess.py

@@ -12,10 +12,10 @@ import re
 '''
 Prepares CustomDatasets for training, validating, and testing CNN
 '''
-def prepare_datasets(mri_dir, val_split=0.2, seed=50):
+def prepare_datasets(mri_dir, csv_dir, val_split=0.2, seed=50):
 
     rndm = random.Random(seed)
-    csv = pd.read_csv("LP_ADNIMERGE.csv")
+    csv = pd.read_csv(csv_dir)
 
     raw_data = glob.glob(mri_dir + "*")
 
@@ -144,9 +144,12 @@ class CustomDataset(Dataset):
         mri_data = np.asarray(np.expand_dims(image, axis=0))
         mri_data = self.transform(mri_data)
 
+        # print(mri_data.dtype)
+
         # mri_data = mri.get_fdata()
         # mri_array = np.array(mri)
         # mri_tensor = torch.from_numpy(mri_array)
         # class_id = torch.tensor([class_id]) TODO return tensor or just id (0, 1)??
 
-        return (mri_data, clinical_data), class_id
+        # return (mri_data, clinical_data), class_id
+        return mri_data, class_id

+ 8 - 6
utils/train_methods.py

@@ -24,7 +24,7 @@ def train(model, train_data, test_data, CNN_filepath, params, graphs=True):
 
     # model.init_history()
 
-    epochs = params['epochs']
+    epochs = params["epochs"]
     for epoch in range(epochs):  # loop over the dataset multiple times
         epoch += 1
 
@@ -40,7 +40,8 @@ def train(model, train_data, test_data, CNN_filepath, params, graphs=True):
         # Batches & training
         for i, data in enumerate(train_data, 0):
             # get the inputs; data is a list of [inputs, labels]
-            inputs, labels = [data[0][0].to(model.device), stack(data[0][1], dim=0).to(model.device)], data[1].to(model.device) # TODO Clinical data not sent to model.device
+            inputs, labels = data[0].to(model.device), data[1].to(model.device)
+            # inputs, labels = [data[0][0].to(model.device), stack(data[0][1], dim=0).to(model.device)], data[1].to(model.device)   # TODO Clinical data not sent to model.device
 
             # zero the parameter gradients
             optimizer.zero_grad()
@@ -81,7 +82,7 @@ def train(model, train_data, test_data, CNN_filepath, params, graphs=True):
     print('Finished Training')
 
     start_time = time.localtime()
-    time_string = time.strftime("%Y-%m-%d_%H:%M", start_time)
+    time_string = time.strftime("%Y-%m-%d_%H.%M", start_time)
     losses.to_csv(f'./cnn_net_data_{time_string}.csv')
 
     if(graphs):
@@ -120,7 +121,8 @@ def evaluate(model, val_data, graphs=True, k_folds=None, fold=None, results=None
     # since we're not training, we don't need to calculate the gradients for our outputs
     with torch.no_grad():
         for data in val_data:
-            images, labels = [data[0][0].to(model.device), stack(data[0][1], dim=0).to(model.device)], data[1].to(model.device)  # TODO Clinical data not sent to model.device
+            images, labels = data[0].to(model.device), data[1].to(model.device)
+            # images, labels = [data[0][0].to(model.device), stack(data[0][1], dim=0).to(model.device)], data[1].to(model.device)  # TODO Clinical data not sent to model.device
 
             # calculate outputs by running images through the model
             outputs = model.forward(images)
@@ -153,7 +155,7 @@ def evaluate(model, val_data, graphs=True, k_folds=None, fold=None, results=None
         # ROC
         # Calculate TPR and FPR
         fpr, tpr, thresholds = roc_curve(true_labels, predictionsProbabilities)
-        time_string = time.strftime("%Y-%m-%d_%H:%M", start_time)
+        time_string = time.strftime("%Y-%m-%d_%H.%M", start_time)
 
         # Calculate AUC
         roc_auc = auc(fpr, tpr)
@@ -179,7 +181,7 @@ def evaluate(model, val_data, graphs=True, k_folds=None, fold=None, results=None
 
     if(not graphs): print(f'Validation loss: {loss.item()}')
     else:
-        time_string = time.strftime("%Y-%m-%d_%H:%M", start_time)
+        time_string = time.strftime("%Y-%m-%d_%H.%M", start_time)
         # ROC
         # Calculate TPR and FPR
         fpr, tpr, thresholds = roc_curve(true_labels, predictionsProbabilities)