|
@@ -9,7 +9,8 @@ import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.pyplot as plt
|
|
import time
|
|
import time
|
|
import numpy as np
|
|
import numpy as np
|
|
-from sklearn.metrics import roc_curve, auc
|
|
|
|
|
|
+from sklearn.metrics import roc_curve, auc, confusion_matrix
|
|
|
|
+import seaborn as sns
|
|
|
|
|
|
class CNN_Net(nn.Module):
|
|
class CNN_Net(nn.Module):
|
|
def __init__(self, prps, final_layer_size=5):
|
|
def __init__(self, prps, final_layer_size=5):
|
|
@@ -118,7 +119,8 @@ class CNN_Net(nn.Module):
|
|
correct = 0
|
|
correct = 0
|
|
total = 0
|
|
total = 0
|
|
|
|
|
|
- predictions = []
|
|
|
|
|
|
+ predictionsLabels = []
|
|
|
|
+ predictionsProbabilities = []
|
|
true_labels = []
|
|
true_labels = []
|
|
|
|
|
|
criterion = nn.CrossEntropyLoss(reduction='mean')
|
|
criterion = nn.CrossEntropyLoss(reduction='mean')
|
|
@@ -138,9 +140,10 @@ class CNN_Net(nn.Module):
|
|
total += labels.size(0)
|
|
total += labels.size(0)
|
|
correct += (predicted == labels).sum().item()
|
|
correct += (predicted == labels).sum().item()
|
|
|
|
|
|
- # Saves predictions and labels for ROC
|
|
|
|
|
|
+ # Saves predictionsProbabilities and labels for ROC
|
|
if(roc):
|
|
if(roc):
|
|
- predictions.extend(outputs.data[:,1].cpu().numpy()) # Grabs probability of positive
|
|
|
|
|
|
+ predictionsLabels.extend(predicted.cpu().numpy())
|
|
|
|
+ predictionsProbabilities.extend(outputs.data[:, 1].cpu().numpy()) # Grabs probability of positive
|
|
true_labels.extend(labels.cpu().numpy())
|
|
true_labels.extend(labels.cpu().numpy())
|
|
|
|
|
|
print(f'Accuracy of the network on {total} scans: {100 * correct // total}%')
|
|
print(f'Accuracy of the network on {total} scans: {100 * correct // total}%')
|
|
@@ -148,10 +151,8 @@ class CNN_Net(nn.Module):
|
|
if(not roc): print(f'Validation loss: {loss.item()}')
|
|
if(not roc): print(f'Validation loss: {loss.item()}')
|
|
else:
|
|
else:
|
|
# ROC
|
|
# ROC
|
|
- thresholds = np.linspace(0, 1, num=50)
|
|
|
|
-
|
|
|
|
# Calculate TPR and FPR
|
|
# Calculate TPR and FPR
|
|
- fpr, tpr, thresholds = roc_curve(true_labels, predictions)
|
|
|
|
|
|
+ fpr, tpr, thresholds = roc_curve(true_labels, predictionsProbabilities)
|
|
|
|
|
|
# Calculate AUC
|
|
# Calculate AUC
|
|
roc_auc = auc(fpr, tpr)
|
|
roc_auc = auc(fpr, tpr)
|
|
@@ -168,6 +169,19 @@ class CNN_Net(nn.Module):
|
|
plt.savefig('./ROC.png')
|
|
plt.savefig('./ROC.png')
|
|
plt.show()
|
|
plt.show()
|
|
|
|
|
|
|
|
+
|
|
|
|
+ # Calculate confusion matrix
|
|
|
|
+ cm = confusion_matrix(true_labels, predictionsLabels)
|
|
|
|
+
|
|
|
|
+ # Plot confusion matrix
|
|
|
|
+ plt.figure(figsize=(8, 6))
|
|
|
|
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
|
|
|
|
+ plt.xlabel('Predicted labels')
|
|
|
|
+ plt.ylabel('True labels')
|
|
|
|
+ plt.title('Confusion Matrix')
|
|
|
|
+ plt.savefig('./confusion_matrix.png')
|
|
|
|
+ plt.show()
|
|
|
|
+
|
|
self.train()
|
|
self.train()
|
|
|
|
|
|
return(loss.item())
|
|
return(loss.item())
|