|
@@ -5,28 +5,21 @@ import torch.nn as nn
|
|
import utils.CNN_Layers as CustomLayers
|
|
import utils.CNN_Layers as CustomLayers
|
|
import torch.nn.functional as F
|
|
import torch.nn.functional as F
|
|
import torch.optim as optim
|
|
import torch.optim as optim
|
|
-import utils.CNN_methods as CNN
|
|
|
|
import pandas as pd
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.pyplot as plt
|
|
import time
|
|
import time
|
|
import numpy as np
|
|
import numpy as np
|
|
-# from sklearn.metrics import roc_curve, auc
|
|
|
|
|
|
+from sklearn.metrics import roc_curve, auc
|
|
|
|
|
|
class CNN_Net(nn.Module):
|
|
class CNN_Net(nn.Module):
|
|
- def __init__(self, input, prps, final_layer_size=5):
|
|
|
|
|
|
+ def __init__(self, prps, final_layer_size=5):
|
|
super(CNN_Net, self).__init__()
|
|
super(CNN_Net, self).__init__()
|
|
self.final_layer_size = final_layer_size
|
|
self.final_layer_size = final_layer_size
|
|
self.device = device('cuda:0' if cuda.is_available() else 'cpu')
|
|
self.device = device('cuda:0' if cuda.is_available() else 'cpu')
|
|
print("CNN Initialized. Using: " + str(self.device))
|
|
print("CNN Initialized. Using: " + str(self.device))
|
|
|
|
|
|
- # GETS FIRST IMAGE FOR SIZE
|
|
|
|
- data_iter = iter(input)
|
|
|
|
- first_batch = next(data_iter)
|
|
|
|
- first_features = first_batch[0]
|
|
|
|
- image = first_features[0]
|
|
|
|
-
|
|
|
|
# LAYERS
|
|
# LAYERS
|
|
- print(f"CNN Model Initialization. Input size: {image.size()}")
|
|
|
|
|
|
+ print(f"CNN Model Initialization")
|
|
self.conv1 = CustomLayers.Conv_elu_maxpool_drop(1, 192, (11, 13, 11), stride=(4,4,4), pool=True, prps=prps)
|
|
self.conv1 = CustomLayers.Conv_elu_maxpool_drop(1, 192, (11, 13, 11), stride=(4,4,4), pool=True, prps=prps)
|
|
self.conv2 = CustomLayers.Conv_elu_maxpool_drop(192, 384, (5, 6, 5), stride=(1,1,1), pool=True, prps=prps)
|
|
self.conv2 = CustomLayers.Conv_elu_maxpool_drop(192, 384, (5, 6, 5), stride=(1,1,1), pool=True, prps=prps)
|
|
self.conv3_mid_flow = CustomLayers.Mid_flow(384, 384, prps=prps)
|
|
self.conv3_mid_flow = CustomLayers.Mid_flow(384, 384, prps=prps)
|
|
@@ -107,19 +100,14 @@ class CNN_Net(nn.Module):
|
|
losses.to_csv('./cnn_net_data.csv')
|
|
losses.to_csv('./cnn_net_data.csv')
|
|
|
|
|
|
# MAKES EPOCH VS AVG LOSS GRAPH
|
|
# MAKES EPOCH VS AVG LOSS GRAPH
|
|
- plt.plot(losses['Epoch'], losses['Avg_loss'])
|
|
|
|
|
|
+ plt.plot(losses['Epoch'], losses['Avg_loss'], label="Loss on Training")
|
|
plt.xlabel('Epoch')
|
|
plt.xlabel('Epoch')
|
|
plt.ylabel('Average Loss')
|
|
plt.ylabel('Average Loss')
|
|
- plt.title('Average Loss vs Epoch On Training')
|
|
|
|
- plt.savefig('./avgloss_epoch_curve.png')
|
|
|
|
- plt.show()
|
|
|
|
|
|
+ plt.title('Loss vs Epoch On Training & Validation data')
|
|
|
|
|
|
# MAKES EPOCH VS VALIDATION LOSS GRAPH
|
|
# MAKES EPOCH VS VALIDATION LOSS GRAPH
|
|
- plt.plot(losses['Epoch'], losses['Val_loss'])
|
|
|
|
- plt.xlabel('Epoch')
|
|
|
|
- plt.ylabel('Validation Loss')
|
|
|
|
- plt.title('Validation Loss vs Epoch On Training')
|
|
|
|
- plt.savefig('./valloss_epoch_curve.png')
|
|
|
|
|
|
+ plt.plot(losses['Epoch'], losses['Val_loss'], label="Loss on Validation")
|
|
|
|
+ plt.savefig('./avgloss_epoch_curve.png')
|
|
plt.show()
|
|
plt.show()
|
|
|
|
|
|
torch.save(self.state_dict(), PATH)
|
|
torch.save(self.state_dict(), PATH)
|
|
@@ -161,35 +149,17 @@ class CNN_Net(nn.Module):
|
|
else:
|
|
else:
|
|
# ROC
|
|
# ROC
|
|
thresholds = np.linspace(0, 1, num=50)
|
|
thresholds = np.linspace(0, 1, num=50)
|
|
- tpr = []
|
|
|
|
- fpr = []
|
|
|
|
- acc = []
|
|
|
|
-
|
|
|
|
-
|
|
|
|
- true_labels = np.array(true_labels)
|
|
|
|
|
|
|
|
- for threshold in thresholds:
|
|
|
|
- # Thresholding the predictions (meaning all predictions above threshold are considered positive)
|
|
|
|
- thresholded_predictions = (predictions >= threshold).astype(int)
|
|
|
|
|
|
+ # Calculate TPR and FPR
|
|
|
|
+ fpr, tpr, thresholds = roc_curve(true_labels, predictions)
|
|
|
|
|
|
- # Calculating true positives, false positives, true negatives, false negatives
|
|
|
|
- true_positives = np.sum((thresholded_predictions == 1) & (true_labels == 1))
|
|
|
|
- false_positives = np.sum((thresholded_predictions == 1) & (true_labels == 0))
|
|
|
|
- true_negatives = np.sum((thresholded_predictions == 0) & (true_labels == 0))
|
|
|
|
- false_negatives = np.sum((thresholded_predictions == 0) & (true_labels == 1))
|
|
|
|
|
|
+ # Calculate AUC
|
|
|
|
+ roc_auc = auc(fpr, tpr)
|
|
|
|
|
|
- accuracy = (true_positives + true_negatives) / (true_positives + false_positives + true_negatives + false_negatives)
|
|
|
|
-
|
|
|
|
- # Calculate TPR and FPR
|
|
|
|
- tpr.append(true_positives / (true_positives + false_negatives))
|
|
|
|
- fpr.append(false_positives / (false_positives + true_negatives))
|
|
|
|
- acc.append(accuracy)
|
|
|
|
-
|
|
|
|
-
|
|
|
|
- plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve')
|
|
|
|
|
|
+ plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC: {roc_auc})')
|
|
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
|
|
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
|
|
- plt.xlim([0.0, 1.0])
|
|
|
|
- plt.ylim([0.0, 1.0])
|
|
|
|
|
|
+ plt.xlim([0.0, 1.005])
|
|
|
|
+ plt.ylim([0.0, 1.005])
|
|
|
|
|
|
plt.xlabel('False Positive Rate (1 - Specificity)')
|
|
plt.xlabel('False Positive Rate (1 - Specificity)')
|
|
plt.ylabel('True Positive Rate (Sensitivity)')
|
|
plt.ylabel('True Positive Rate (Sensitivity)')
|
|
@@ -198,18 +168,6 @@ class CNN_Net(nn.Module):
|
|
plt.savefig('./ROC.png')
|
|
plt.savefig('./ROC.png')
|
|
plt.show()
|
|
plt.show()
|
|
|
|
|
|
- plt.plot(thresholds, acc)
|
|
|
|
- plt.xlabel('Thresholds')
|
|
|
|
- plt.ylabel('Accuracy')
|
|
|
|
- plt.title('Accuracy vs thresholds')
|
|
|
|
- plt.savefig('./acc.png')
|
|
|
|
- plt.show()
|
|
|
|
-
|
|
|
|
-
|
|
|
|
- # ROC ATTEMPT 2
|
|
|
|
- # fprRoc, tprRoc = roc_curve(true_labels, predictions)
|
|
|
|
- # plt.plot(fprRoc, tprRoc)
|
|
|
|
-
|
|
|
|
self.train()
|
|
self.train()
|
|
|
|
|
|
return(loss.item())
|
|
return(loss.item())
|