
untested, first rearrangement of code for readability

Ruben, 5 months ago
parent
commit 224385dc6d
4 files changed, 230 additions and 234 deletions
  1. main.py (+49, -74)
  2. utils/CNN.py (+2, -157)
  3. utils/CNN_Layers.py (+2, -3)
  4. utils/train_methods.py (+177, -0)

main.py (+49, -74)

@@ -1,56 +1,26 @@
 import torch
-import torchvision
 
 # FOR DATA
-from utils.preprocess import prepare_datasets, prepare_predict
-from utils.show_image import show_image
+from utils.preprocess import prepare_datasets
+from utils.train_methods import train, load, evaluate, predict
 from utils.CNN import CNN_Net
 from torch.utils.data import DataLoader
 from torchvision import datasets
 
-from torch import nn
-from torchvision.transforms import ToTensor
-
-# import nonechucks as nc     # Used to load data in pytorch even when images are corrupted / unavailable (skips them)
-
-# FOR IMAGE VISUALIZATION
-import nibabel as nib
-
 # GENERAL PURPOSE
-import os
 import pandas as pd
 import numpy as np
 import matplotlib.pyplot as plt
-import glob
 import platform
 
-
-
 print("--- RUNNING ---")
 print("Pytorch Version: " + torch. __version__)
 print("Python Version: " + platform.python_version())
 
 # LOADING DATA
-# data & training properties:
 val_split = 0.2     # % of val and test, rest will be train
 seed = 12       # TODO Randomize seed
 
-# params = {
-#     "target_rows": 91,
-#     "target_cols": 109,
-#     "depth": 91,
-#     "axis": 1,
-#     "num_clinical": 2,
-#     "CNN_drop_rate": 0.3,
-#     "RNN_drop_rate": 0.1,
-#     # "CNN_w_regularizer": regularizers.l2(2e-2),
-#     # "RNN_w_regularizer": regularizers.l2(1e-6),
-#     "CNN_batch_size": 10,
-#     "RNN_batch_size": 5,
-#     "val_split": 0.2,
-#     "final_layer_size": 5
-# }
-
 properties = {
     "batch_size":6,
     "padding":0,
@@ -62,7 +32,6 @@ properties = {
 }
 
 
-# Might have to replace datapaths or separate between training and testing
 model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
 CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth'       # cnn_net.pth
 # mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/'   # Small Test
@@ -80,13 +49,7 @@ train_dataloader = DataLoader(training_data, batch_size=properties['batch_size']
 test_dataloader = DataLoader(test_data, batch_size=properties['batch_size'], shuffle=True)
 val_dataloader = DataLoader(val_data, batch_size=properties['batch_size'], shuffle=True)
 
-# for X, y in train_dataloader:
-#     print(f"Shape of X [Channels (colors), Y, X, Z]: {X.shape}")   # X & Y are from TOP LOOKING DOWN
-#     print(f"Shape of Y (Dataset?): {y.shape} {y.dtype}")
-#     break
-
-
-# Display 4 images and labels.
+# loads a few images to test (disabled: x starts at 0, so the loop below never runs)
 x = 0
 while x < 0:
     train_features, train_labels = next(iter(train_dataloader))
@@ -104,51 +67,63 @@ while x < 0:
     x = x+1
 
 
-train = False
-predict = False
+
+epochs = 20
+roc = True
 CNN = CNN_Net(prps=properties, final_layer_size=2)
 CNN.cuda()
 
-# RUN CNN
-if(train):
-    CNN.train_model(train_dataloader, test_dataloader, CNN_filepath, epochs=20)
-    CNN.evaluate_model(val_dataloader, roc=True)
-
-else:
-    CNN.load_state_dict(torch.load(CNN_filepath))
-    CNN.evaluate_model(val_dataloader, roc=True)
+train(CNN, train_dataloader, test_dataloader, CNN_filepath, epochs, graphs=True)
+load(CNN, CNN_filepath)
+evaluate(CNN, val_dataloader)
+predict(CNN, val_dataloader)
 
 
-# PREDICT MODE TO TEST INDIVIDUAL IMAGES
-if(predict):
-    on = True
-    print("---- Predict mode ----")
-    print("Integer for image")
-    print("x or X for exit")
 
-    while(on):
-        inp = input("Next image: ")
-        if(inp == None or inp.lower() == 'x' or not inp.isdigit()): on = False
-        else:
-            dataloader = DataLoader(prepare_predict(mri_datapath, [inp]), batch_size=properties['batch_size'], shuffle=True)
-            prediction = CNN.predict(dataloader)
-
-            features, labels = next(iter(dataloader), )
-            img = features[0].squeeze()
-            image = img[:, :, 40]
-            print(f"Expected class: {labels}")
-            print(f"Prediction: {prediction}")
-            plt.imshow(image, cmap="gray")
-            plt.show()
-
-print("--- END ---")
+# EXTRA
 
 
+# # PREDICT MODE TO TEST INDIVIDUAL IMAGES
+# if(predict):
+#     on = True
+#     print("---- Predict mode ----")
+#     print("Integer for image")
+#     print("x or X for exit")
+#
+#     while(on):
+#         inp = input("Next image: ")
+#         if(inp == None or inp.lower() == 'x' or not inp.isdigit()): on = False
+#         else:
+#             dataloader = DataLoader(prepare_predict(mri_datapath, [inp]), batch_size=properties['batch_size'], shuffle=True)
+#             prediction = CNN.predict(dataloader)
+#
+#             features, labels = next(iter(dataloader), )
+#             img = features[0].squeeze()
+#             image = img[:, :, 40]
+#             print(f"Expected class: {labels}")
+#             print(f"Prediction: {prediction}")
+#             plt.imshow(image, cmap="gray")
+#             plt.show()
+#
+# print("--- END ---")
 
 
-# EXTRA
+# params = {
+#     "target_rows": 91,
+#     "target_cols": 109,
+#     "depth": 91,
+#     "axis": 1,
+#     "num_clinical": 2,
+#     "CNN_drop_rate": 0.3,
+#     "RNN_drop_rate": 0.1,
+#     # "CNN_w_regularizer": regularizers.l2(2e-2),
+#     # "RNN_w_regularizer": regularizers.l2(1e-6),
+#     "CNN_batch_size": 10,
+#     "RNN_batch_size": 5,
+#     "val_split": 0.2,
+#     "final_layer_size": 5
+# }
 
-# will I need these params?
 '''
 params_dict = { 'CNN_w_regularizer': CNN_w_regularizer, 'RNN_w_regularizer': RNN_w_regularizer,
                'CNN_batch_size': CNN_batch_size, 'RNN_batch_size': RNN_batch_size,

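With training, loading, evaluation and prediction now living in utils/train_methods, the rewritten main.py calls train(), load(), evaluate() and predict() back to back. Below is a minimal sketch of how the old train/predict flags could still gate the new free functions; it assumes the objects already built in main.py (CNN, the three dataloaders, CNN_filepath, epochs), and only the DO_TRAIN / DO_PREDICT names are invented here.

DO_TRAIN = False      # hypothetical flag, plays the role of the old `train` variable
DO_PREDICT = True     # hypothetical flag, plays the role of the old `predict` variable

if DO_TRAIN:
    train(CNN, train_dataloader, test_dataloader, CNN_filepath, epochs, graphs=True)
else:
    load(CNN, CNN_filepath)      # reuse the weights saved at CNN_filepath

evaluate(CNN, val_dataloader)

if DO_PREDICT:
    true_labels, predicted = predict(CNN, val_dataloader)   # values from the last batch only
    print(f"Expected: {true_labels}")
    print(f"Predicted: {predicted}")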
utils/CNN.py (+2, -157)

@@ -31,6 +31,7 @@ class CNN_Net(nn.Module):
         self.fc1 = CustomLayers.Fc_elu_drop(113568, 20, prps=prps, softmax=False)      # TODO, concatenate clinical data after this
         self.fc2 = CustomLayers.Fc_elu_drop(20, final_layer_size, prps=prps, softmax=True)  # For now this works as output layer, though may be incorrect
 
+
     # FORWARDS
     def forward(self, x):
         x = self.conv1(x)
@@ -45,160 +46,4 @@ class CNN_Net(nn.Module):
 
         x = self.fc1(x)
         x = self.fc2(x)
-        return x
-
-    # TRAIN
-    def train_model(self, trainloader, testloader, PATH, epochs):
-        self.train()
-        criterion = nn.CrossEntropyLoss(reduction='mean')
-        optimizer = optim.Adam(self.parameters(), lr=1e-5)
-
-        losses = pd.DataFrame(columns=['Epoch', 'Avg_loss', 'Time'])
-        start_time = time.time()  # seconds
-
-        for epoch in range(epochs):  # loop over the dataset multiple times
-            epoch += 1
-
-            # Estimate & count training time
-            t = time.strftime("%H:%M:%S", time.gmtime(time.time() - start_time))
-            t_remain = time.strftime("%H:%M:%S", time.gmtime((time.time() - start_time)/epoch * epochs))
-            print(f"{epoch/epochs * 100} || {epoch}/{epochs} || Time: {t}/{t_remain}")
-
-            running_loss = 0.0
-
-            # Batches & training
-            for i, data in enumerate(trainloader, 0):
-                # get the inputs; data is a list of [inputs, labels]
-                inputs, labels = data[0].to(self.device), data[1].to(self.device)
-
-                # zero the parameter gradients
-                optimizer.zero_grad()
-
-                # forward + backward + optimize
-                outputs = self.forward(inputs)
-                loss = criterion(outputs, labels)   # This loss is the mean of losses for the batch
-                loss.backward()
-                optimizer.step()
-
-                # adds average batch loss to running loss
-                running_loss += loss.item()
-
-                # mini-batches for progress
-                if(i%10==0 and i!=0):
-                    print(f"{i}/{len(trainloader)}, temp. loss:{running_loss / len(trainloader)}")
-
-            # average loss
-            avg_loss = running_loss / len(trainloader)      # Running_loss / number of batches
-            print(f"Avg. loss: {avg_loss}")
-
-            # loss on validation
-            val_loss = self.evaluate_model(testloader, roc=False)
-
-            losses = losses.append({'Epoch':int(epoch), 'Avg_loss':avg_loss, 'Val_loss':val_loss, 'Time':time.time() - start_time}, ignore_index=True)
-
-
-        print('Finished Training')
-        losses.to_csv('./cnn_net_data.csv')
-
-        # MAKES EPOCH VS AVG LOSS GRAPH
-        plt.plot(losses['Epoch'], losses['Avg_loss'], label="Loss on Training")
-        plt.xlabel('Epoch')
-        plt.ylabel('Average Loss')
-        plt.title('Loss vs Epoch On Training & Validation data')
-
-        # MAKES EPOCH VS VALIDATION LOSS GRAPH
-        plt.plot(losses['Epoch'], losses['Val_loss'], label="Loss on Validation")
-        plt.savefig('./avgloss_epoch_curve.png')
-        plt.show()
-
-        torch.save(self.state_dict(), PATH)
-        print("Model saved")
-
-    # TEST
-    def evaluate_model(self, testloader, roc):
-        correct = 0
-        total = 0
-
-        predictionsLabels = []
-        predictionsProbabilities = []
-        true_labels = []
-
-        criterion = nn.CrossEntropyLoss(reduction='mean')
-        self.eval()
-        # since we're not training, we don't need to calculate the gradients for our outputs
-        with torch.no_grad():
-            for data in testloader:
-                images, labels = data[0].to(self.device), data[1].to(self.device)
-                # calculate outputs by running images through the network
-                outputs = self.forward(images)
-                # the class with the highest energy is what we choose as prediction
-
-                loss = criterion(outputs, labels)  # mean loss from batch
-
-                # Gets accuracy
-                _, predicted = torch.max(outputs.data, 1)
-                total += labels.size(0)
-                correct += (predicted == labels).sum().item()
-
-                # Saves predictionsProbabilities and labels for ROC
-                if(roc):
-                    predictionsLabels.extend(predicted.cpu().numpy())
-                    predictionsProbabilities.extend(outputs.data[:, 1].cpu().numpy())     # Grabs probability of positive
-                    true_labels.extend(labels.cpu().numpy())
-
-        print(f'Accuracy of the network on {total} scans: {100 * correct // total}%')
-
-        if(not roc): print(f'Validation loss: {loss.item()}')
-        else:
-            # ROC
-            # Calculate TPR and FPR
-            fpr, tpr, thresholds = roc_curve(true_labels, predictionsProbabilities)
-
-            # Calculate AUC
-            roc_auc = auc(fpr, tpr)
-
-            plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC: {roc_auc})')
-            plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
-            plt.xlim([0.0, 1.005])
-            plt.ylim([0.0, 1.005])
-
-            plt.xlabel('False Positive Rate (1 - Specificity)')
-            plt.ylabel('True Positive Rate (Sensitivity)')
-            plt.title('Receiver Operating Characteristic (ROC) Curve')
-            plt.legend(loc="lower right")
-            plt.savefig('./ROC.png')
-            plt.show()
-
-
-            # Calculate confusion matrix
-            cm = confusion_matrix(true_labels, predictionsLabels)
-
-            # Plot confusion matrix
-            plt.figure(figsize=(8, 6))
-            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
-            plt.xlabel('Predicted labels')
-            plt.ylabel('True labels')
-            plt.title('Confusion Matrix')
-            plt.savefig('./confusion_matrix.png')
-            plt.show()
-
-            # Classification Report
-            report = classification_report(true_labels, predictionsLabels)
-            print(report)
-
-        self.train()
-
-        return(loss.item())
-
-
-    # PREDICT
-    def predict(self, loader):
-        self.eval()
-        with torch.no_grad():
-            for data in loader:
-                images, labels = data[0].to(self.device), data[1].to(self.device)
-                outputs = self.forward(images)
-                # the class with the highest energy is what we choose as prediction
-                _, predicted = torch.max(outputs.data, 1)
-        self.train()
-        return predicted
+        return x
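CNN_Net is now reduced to its layers and forward(); the helpers in utils/train_methods move tensors with data[0].to(model.device), so any model handed to them still needs a device attribute, just as the old train_model/evaluate_model relied on self.device. A hypothetical stand-in (TinyNet is not part of the repo) showing the minimal interface those helpers assume:

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for CNN_Net: any model given to train()/evaluate()/predict()
    needs a .device attribute and a forward() returning class scores."""
    def __init__(self):
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.fc = nn.Linear(8, 2)        # two output classes, like final_layer_size=2
        self.to(self.device)

    def forward(self, x):
        return self.fc(x)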

utils/CNN_Layers.py (+2, -3)

@@ -1,6 +1,5 @@
-from torch import device, cuda
 import torch
-from torch import add
+# from torch import add
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
@@ -69,7 +68,7 @@ class Mid_flow(nn.Module):
         x = self.conv(x)
         # print(f"Output: {x.size()}")
 
-        x = add(x, residual)
+        x = torch.add(x, residual)
         x = self.elu(x)
 
         # return torch.matmul(x, self.weight) + self.bias       # TODO WHAT??? WEIGHT & BIAS YES OR NO?

utils/train_methods.py (+177, -0)

@@ -0,0 +1,177 @@
+import torch
+
+from torch import nn, optim
+from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report
+import seaborn as sns
+
+
+# GENERAL PURPOSE
+import os
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import time
+
+
+# TRAIN
+def train(model, train_data, test_data, CNN_filepath, epochs=20, graphs=True):
+    model.train()
+    criterion = nn.CrossEntropyLoss(reduction='mean')
+    optimizer = optim.Adam(model.parameters(), lr=1e-5)
+
+    losses = pd.DataFrame(columns=['Epoch', 'Avg_loss', 'Time'])
+    start_time = time.time()  # seconds
+
+    for epoch in range(epochs):  # loop over the dataset multiple times
+        epoch += 1
+
+        # Estimate & count training time
+        t = time.strftime("%H:%M:%S", time.gmtime(time.time() - start_time))
+        t_remain = time.strftime("%H:%M:%S", time.gmtime((time.time() - start_time)/epoch * epochs))
+        print(f"{epoch/epochs * 100} || {epoch}/{epochs} || Time: {t}/{t_remain}")
+
+        running_loss = 0.0
+
+        # Batches & training
+        for i, data in enumerate(train_data, 0):
+            # get the inputs; data is a list of [inputs, labels]
+            inputs, labels = data[0].to(model.device), data[1].to(model.device)
+
+            # zero the parameter gradients
+            optimizer.zero_grad()
+
+            # forward + backward + optimize
+            outputs = model.forward(inputs)
+            loss = criterion(outputs, labels)   # This loss is the mean of losses for the batch
+            loss.backward()
+            optimizer.step()
+
+            # adds average batch loss to running loss
+            running_loss += loss.item()
+
+            # mini-batches for progress
+            if(i%10==0 and i!=0):
+                print(f"{i}/{len(train_data)}, temp. loss:{running_loss / len(train_data)}")
+
+        # average loss
+        avg_loss = running_loss / len(train_data)      # Running_loss / number of batches
+        print(f"Avg. loss: {avg_loss}")
+
+        # loss on validation
+        val_loss = evaluate(model, test_data, graphs=False)   # per-epoch validation loss, no plots
+
+        losses = losses.append({'Epoch':int(epoch), 'Avg_loss':avg_loss, 'Val_loss':val_loss, 'Time':time.time() - start_time}, ignore_index=True)
+
+
+    print('Finished Training')
+    losses.to_csv('./cnn_net_data.csv')
+
+    if(graphs):
+        # MAKES EPOCH VS AVG LOSS GRAPH
+        plt.plot(losses['Epoch'], losses['Avg_loss'], label="Loss on Training")
+        plt.xlabel('Epoch')
+        plt.ylabel('Average Loss')
+        plt.title('Loss vs Epoch On Training & Validation data')
+
+        # MAKES EPOCH VS VALIDATION LOSS GRAPH
+        plt.plot(losses['Epoch'], losses['Val_loss'], label="Loss on Validation")
+        plt.savefig('./avgloss_epoch_curve.png')
+        plt.show()
+
+    torch.save(model.state_dict(), CNN_filepath)   # saved regardless of the graphs flag
+    print("Model saved")
+
+
+
+def load(model, filepath):
+    model.load_state_dict(torch.load(filepath))
+
+
+# EVALUATE
+def evaluate(model, val_data, graphs=True):
+    correct = 0
+    total = 0
+
+    predictionsLabels = []
+    predictionsProbabilities = []
+    true_labels = []
+
+    criterion = nn.CrossEntropyLoss(reduction='mean')
+    model.eval()
+    # since we're not training, we don't need to calculate the gradients for our outputs
+    with torch.no_grad():
+        for data in val_data:
+            images, labels = data[0].to(model.device), data[1].to(model.device)
+            # calculate outputs by running images through the network
+            outputs = model.forward(images)
+            # the class with the highest energy is what we choose as prediction
+
+            loss = criterion(outputs, labels)  # mean loss from batch
+
+            # Gets accuracy
+            _, predicted = torch.max(outputs.data, 1)
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+
+            # Saves predictionsProbabilities and labels for ROC
+            if(graphs):
+                predictionsLabels.extend(predicted.cpu().numpy())
+                predictionsProbabilities.extend(outputs.data[:, 1].cpu().numpy())     # Grabs probability of positive
+                true_labels.extend(labels.cpu().numpy())
+
+    print(f'Accuracy of the network on {total} scans: {100 * correct // total}%')
+
+    if(not graphs): print(f'Validation loss: {loss.item()}')
+    else:
+        # ROC
+        # Calculate TPR and FPR
+        fpr, tpr, thresholds = roc_curve(true_labels, predictionsProbabilities)
+
+        # Calculate AUC
+        roc_auc = auc(fpr, tpr)
+
+        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC: {roc_auc})')
+        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
+        plt.xlim([0.0, 1.005])
+        plt.ylim([0.0, 1.005])
+
+        plt.xlabel('False Positive Rate (1 - Specificity)')
+        plt.ylabel('True Positive Rate (Sensitivity)')
+        plt.title('Receiver Operating Characteristic (ROC) Curve')
+        plt.legend(loc="lower right")
+        plt.savefig('./ROC.png')
+        plt.show()
+
+
+        # Calculate confusion matrix
+        cm = confusion_matrix(true_labels, predictionsLabels)
+
+        # Plot confusion matrix
+        plt.figure(figsize=(8, 6))
+        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
+        plt.xlabel('Predicted labels')
+        plt.ylabel('True labels')
+        plt.title('Confusion Matrix')
+        plt.savefig('./confusion_matrix.png')
+        plt.show()
+
+        # Classification Report
+        report = classification_report(true_labels, predictionsLabels)
+        print(report)
+
+    model.train()
+
+    return(loss.item())
+
+
+# PREDICT
+def predict(model, loader):
+    model.eval()
+    with torch.no_grad():
+        for data in loader:
+            images, labels = data[0].to(model.device), data[1].to(model.device)
+            outputs = model.forward(images)
+            # the class with the highest energy is what we choose as prediction
+            _, predicted = torch.max(outputs.data, 1)
+    model.train()
+    return (labels, predicted)  # returns (true, predicted) from the last batch only
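One caveat carried over into the new train(): the per-epoch bookkeeping still uses losses = losses.append(...), which was deprecated in pandas 1.4 and removed in pandas 2.0. Below is a minimal sketch of an equivalent that collects plain dicts and builds the DataFrame once at the end; the loop and loss values are placeholders, only there to keep the snippet runnable.

import time
import pandas as pd

start_time = time.time()
loss_records = []                                   # one dict per epoch
for epoch in range(1, 4):                           # placeholder for the real training loop
    avg_loss, val_loss = 1.0 / epoch, 1.2 / epoch   # placeholder losses
    loss_records.append({'Epoch': epoch, 'Avg_loss': avg_loss,
                         'Val_loss': val_loss, 'Time': time.time() - start_time})

losses = pd.DataFrame(loss_records)                 # same columns as cnn_net_data.csv
losses.to_csv('./cnn_net_data.csv', index=False)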