Ruben 6 місяців тому
батько
коміт
539c232cc5
2 змінених файлів з 21 додано та 7 видалено
  1. BIN
      confusion_matrix.png
  2. 21 7
      utils/CNN.py

BIN
confusion_matrix.png


+ 21 - 7
utils/CNN.py

@@ -9,7 +9,8 @@ import pandas as pd
 import matplotlib.pyplot as plt
 import time
 import numpy as np
-from sklearn.metrics import roc_curve, auc
+from sklearn.metrics import roc_curve, auc, confusion_matrix
+import seaborn as sns
 
 class CNN_Net(nn.Module):
     def __init__(self, prps, final_layer_size=5):
@@ -118,7 +119,8 @@ class CNN_Net(nn.Module):
         correct = 0
         total = 0
 
-        predictions = []
+        predictionsLabels = []
+        predictionsProbabilities = []
         true_labels = []
 
         criterion = nn.CrossEntropyLoss(reduction='mean')
@@ -138,9 +140,10 @@ class CNN_Net(nn.Module):
                 total += labels.size(0)
                 correct += (predicted == labels).sum().item()
 
-                # Saves predictions and labels for ROC
+                # Saves predictionsProbabilities and labels for ROC
                 if(roc):
-                    predictions.extend(outputs.data[:,1].cpu().numpy())     # Grabs probability of positive
+                    predictionsLabels.extend(predicted.cpu().numpy())
+                    predictionsProbabilities.extend(outputs.data[:, 1].cpu().numpy())     # Grabs probability of positive
                     true_labels.extend(labels.cpu().numpy())
 
         print(f'Accuracy of the network on {total} scans: {100 * correct // total}%')
@@ -148,10 +151,8 @@ class CNN_Net(nn.Module):
         if(not roc): print(f'Validation loss: {loss.item()}')
         else:
             # ROC
-            thresholds = np.linspace(0, 1, num=50)
-
             # Calculate TPR and FPR
-            fpr, tpr, thresholds = roc_curve(true_labels, predictions)
+            fpr, tpr, thresholds = roc_curve(true_labels, predictionsProbabilities)
 
             # Calculate AUC
             roc_auc = auc(fpr, tpr)
@@ -168,6 +169,19 @@ class CNN_Net(nn.Module):
             plt.savefig('./ROC.png')
             plt.show()
 
+
+            # Calculate confusion matrix
+            cm = confusion_matrix(true_labels, predictionsLabels)
+
+            # Plot confusion matrix
+            plt.figure(figsize=(8, 6))
+            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
+            plt.xlabel('Predicted labels')
+            plt.ylabel('True labels')
+            plt.title('Confusion Matrix')
+            plt.savefig('./confusion_matrix.png')
+            plt.show()
+
         self.train()
 
         return(loss.item())