# CNN.py

from torch import device, cuda
import torch
import torch.nn as nn
import utils.CNN_Layers as CustomLayers
import torch.optim as optim
import pandas as pd
import matplotlib.pyplot as plt
import time
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report
import seaborn as sns

class CNN_Net(nn.Module):
    """3D CNN classifier built from the custom blocks in utils.CNN_Layers."""

    def __init__(self, prps, final_layer_size=5):
        super(CNN_Net, self).__init__()
        self.final_layer_size = final_layer_size
        self.device = device('cuda:0' if cuda.is_available() else 'cpu')
        print("CNN Initialized. Using: " + str(self.device))

        # LAYERS
        print("CNN Model Initialization")
        self.conv1 = CustomLayers.Conv_elu_maxpool_drop(1, 192, (11, 13, 11), stride=(4, 4, 4), pool=True, prps=prps)
        self.conv2 = CustomLayers.Conv_elu_maxpool_drop(192, 384, (5, 6, 5), stride=(1, 1, 1), pool=True, prps=prps)
        self.conv3_mid_flow = CustomLayers.Mid_flow(384, 384, prps=prps)
        self.conv4_sepConv = CustomLayers.Conv_elu_maxpool_drop(384, 96, (3, 4, 3), stride=(1, 1, 1), pool=True,
                                                                prps=prps, sep_conv=True)
        self.conv5_sepConv = CustomLayers.Conv_elu_maxpool_drop(96, 48, (3, 4, 3), stride=(1, 1, 1), pool=True,
                                                                prps=prps, sep_conv=True)
        self.fc1 = CustomLayers.Fc_elu_drop(113568, 20, prps=prps, softmax=False)  # TODO: concatenate clinical data after this
        # For now this works as the output layer, though it may be incorrect:
        # CrossEntropyLoss expects raw logits, so softmax=True here means the loss
        # applies log-softmax on top of an already-softmaxed output.
        self.fc2 = CustomLayers.Fc_elu_drop(20, final_layer_size, prps=prps, softmax=True)

        # Ensure the parameters live on the same device the inputs are moved to
        self.to(self.device)

    # FORWARDS
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3_mid_flow(x)
        x = self.conv4_sepConv(x)
        x = self.conv5_sepConv(x)

        # FLATTEN x (keep the batch dimension, collapse the rest)
        x = x.view(x.size(0), -1)

        x = self.fc1(x)
        x = self.fc2(x)
        return x
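
    # NOTE: fc1's in_features (113568) is hard-wired to one particular input volume
    # shape. A minimal sketch for recomputing it if the input shape changes; the
    # (1, 1, 91, 109, 91) dummy shape below is an assumption, not taken from this file:
    #
    #   with torch.no_grad():
    #       dummy = torch.zeros(1, 1, 91, 109, 91, device=self.device)
    #       out = self.conv5_sepConv(self.conv4_sepConv(
    #           self.conv3_mid_flow(self.conv2(self.conv1(dummy)))))
    #       flat_features = out.view(1, -1).size(1)  # value to pass as fc1's in_features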

    # TRAIN
    def train_model(self, trainloader, testloader, PATH, epochs):
        self.train()
        criterion = nn.CrossEntropyLoss(reduction='mean')
        optimizer = optim.Adam(self.parameters(), lr=1e-5)

        losses = pd.DataFrame(columns=['Epoch', 'Avg_loss', 'Val_loss', 'Time'])
        start_time = time.time()  # seconds

        for epoch in range(epochs):  # loop over the dataset multiple times
            epoch += 1

            # Elapsed time & estimated total training time
            t = time.strftime("%H:%M:%S", time.gmtime(time.time() - start_time))
            t_remain = time.strftime("%H:%M:%S", time.gmtime((time.time() - start_time) / epoch * epochs))
            print(f"{epoch / epochs * 100:.1f}% || {epoch}/{epochs} || Time: {t}/{t_remain}")

            running_loss = 0.0

            # Batches & training
            for i, data in enumerate(trainloader, 0):
                # get the inputs; data is a list of [inputs, labels]
                inputs, labels = data[0].to(self.device), data[1].to(self.device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = self.forward(inputs)
                loss = criterion(outputs, labels)  # mean of the losses for the batch
                loss.backward()
                optimizer.step()

                # add the average batch loss to the running loss
                running_loss += loss.item()

                # periodic progress report (running average over the batches seen so far)
                if i % 10 == 0 and i != 0:
                    print(f"{i}/{len(trainloader)}, temp. loss: {running_loss / (i + 1)}")

            # average loss for the epoch
            avg_loss = running_loss / len(trainloader)  # running_loss / number of batches
            print(f"Avg. loss: {avg_loss}")

            # loss on validation
            val_loss = self.evaluate_model(testloader, roc=False)

            # DataFrame.append was removed in pandas 2.x; use pd.concat instead
            losses = pd.concat(
                [losses, pd.DataFrame([{'Epoch': int(epoch), 'Avg_loss': avg_loss,
                                        'Val_loss': val_loss, 'Time': time.time() - start_time}])],
                ignore_index=True)

        print('Finished Training')
        losses.to_csv('./cnn_net_data.csv')

        # EPOCH VS AVG LOSS GRAPH (training & validation)
        plt.plot(losses['Epoch'], losses['Avg_loss'], label="Loss on Training")
        plt.plot(losses['Epoch'], losses['Val_loss'], label="Loss on Validation")
        plt.xlabel('Epoch')
        plt.ylabel('Average Loss')
        plt.title('Loss vs Epoch on Training & Validation Data')
        plt.legend(loc="upper right")
        plt.savefig('./avgloss_epoch_curve.png')
        plt.show()

        torch.save(self.state_dict(), PATH)
        print("Model saved")
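
    # To reload the saved weights later, a minimal sketch:
    #   net = CNN_Net(prps)
    #   net.load_state_dict(torch.load(PATH, map_location=net.device))
    #   net.eval()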

    # TEST
    def evaluate_model(self, testloader, roc):
        correct = 0
        total = 0
        predictionsLabels = []
        predictionsProbabilities = []
        true_labels = []
        criterion = nn.CrossEntropyLoss(reduction='mean')
        running_loss = 0.0

        self.eval()
        # since we're not training, we don't need to calculate gradients for the outputs
        with torch.no_grad():
            for data in testloader:
                images, labels = data[0].to(self.device), data[1].to(self.device)

                # calculate outputs by running images through the network
                outputs = self.forward(images)
                loss = criterion(outputs, labels)  # mean loss for the batch
                running_loss += loss.item()

                # the class with the highest score is what we choose as the prediction
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                # save predictions, probabilities and labels for the ROC curve
                if roc:
                    predictionsLabels.extend(predicted.cpu().numpy())
                    # probability of the positive class; assumes binary classification
                    # and a softmax output layer
                    predictionsProbabilities.extend(outputs[:, 1].cpu().numpy())
                    true_labels.extend(labels.cpu().numpy())

        # average the loss over all batches, not just the last one
        avg_loss = running_loss / len(testloader)
        print(f'Accuracy of the network on {total} scans: {100 * correct // total}%')

        if not roc:
            print(f'Validation loss: {avg_loss}')
        else:
            # ROC: calculate FPR, TPR and AUC
            fpr, tpr, thresholds = roc_curve(true_labels, predictionsProbabilities)
            roc_auc = auc(fpr, tpr)

            plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC: {roc_auc:.3f})')
            plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
            plt.xlim([0.0, 1.005])
            plt.ylim([0.0, 1.005])
            plt.xlabel('False Positive Rate (1 - Specificity)')
            plt.ylabel('True Positive Rate (Sensitivity)')
            plt.title('Receiver Operating Characteristic (ROC) Curve')
            plt.legend(loc="lower right")
            plt.savefig('./ROC.png')
            plt.show()

            # Confusion matrix
            cm = confusion_matrix(true_labels, predictionsLabels)
            plt.figure(figsize=(8, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
            plt.xlabel('Predicted labels')
            plt.ylabel('True labels')
            plt.title('Confusion Matrix')
            plt.savefig('./confusion_matrix.png')
            plt.show()

            # Classification report
            report = classification_report(true_labels, predictionsLabels)
            print(report)

        self.train()
        return avg_loss

    # PREDICT
    def predict(self, loader):
        self.eval()
        predictions = []
        with torch.no_grad():
            for data in loader:
                images = data[0].to(self.device)  # labels (data[1]) are not needed for prediction
                outputs = self.forward(images)
                # the class with the highest score is what we choose as the prediction
                _, predicted = torch.max(outputs, 1)
                predictions.append(predicted)
        self.train()
        # concatenate so predictions for every batch are returned, not just the last one
        return torch.cat(predictions)
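
# Minimal usage sketch (assumptions: `prps` is the hyperparameter dict expected by
# utils.CNN_Layers, and the loaders yield (volume, label) batches shaped (N, 1, D, H, W)):
#
#   net = CNN_Net(prps, final_layer_size=2)
#   net.train_model(trainloader, testloader, PATH='./cnn_net.pth', epochs=20)
#   net.evaluate_model(testloader, roc=True)
#   predictions = net.predict(testloader)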