CNN.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. from torch import device, cuda
  2. import torch
  3. from torch import add
  4. import torch.nn as nn
  5. import utils.CNN_Layers as CustomLayers
  6. import torch.nn.functional as F
  7. import torch.optim as optim
  8. import utils.CNN_methods as CNN
  9. import pandas as pd
  10. import matplotlib.pyplot as plt
  11. import time
  12. import numpy as np
  13. # from sklearn.metrics import roc_curve, auc
  class CNN_Net(nn.Module):
      """3D convolutional classifier built from the project layers in ``utils.CNN_Layers``.

      Layer stack:
      - two conv/ELU/max-pool/dropout blocks;
      - a "mid flow" block;
      - two separable-conv blocks;
      - two fully connected layers.

      The class also bundles a training loop, an evaluation routine and a batch
      prediction helper. The evaluation routine reports accuracy and the mean
      validation loss, and can optionally produce a threshold-sweep ROC curve.
      """

      def __init__(self, input, prps, final_layer_size=5):
          """Build the layers and place the model on the selected device.

          Args:
              input: Iterable of batches (e.g. a DataLoader). ``batch[0]`` must hold
                  the input tensor. Only the first batch is read, and only to log
                  the sample size. The name shadows the builtin but is kept for
                  caller compatibility.
              prps: Layer hyper-parameters forwarded unchanged to every custom layer.
              final_layer_size: Number of output units of the last layer.
          """
          super().__init__()
          self.final_layer_size = final_layer_size
          self.device = device('cuda:0' if cuda.is_available() else 'cpu')
          print("CNN Initialized. Using: " + str(self.device))
          # Peek at the first sample purely to report its size.
          first_batch = next(iter(input))
          image = first_batch[0][0]
          print(f"CNN Model Initialization. Input size: {image.size()}")
          # LAYERS
          self.conv1 = CustomLayers.Conv_elu_maxpool_drop(1, 192, (11, 13, 11), stride=(4, 4, 4), pool=True, prps=prps)
          self.conv2 = CustomLayers.Conv_elu_maxpool_drop(192, 384, (5, 6, 5), stride=(1, 1, 1), pool=True, prps=prps)
          self.conv3_mid_flow = CustomLayers.Mid_flow(384, 384, prps=prps)
          self.conv4_sepConv = CustomLayers.Conv_elu_maxpool_drop(384, 96, (3, 4, 3), stride=(1, 1, 1), pool=True,
                                                                  prps=prps, sep_conv=True)
          self.conv5_sepConv = CustomLayers.Conv_elu_maxpool_drop(96, 48, (3, 4, 3), stride=(1, 1, 1), pool=True,
                                                                  prps=prps, sep_conv=True)
          # NOTE(review): 113568 is presumably fc1's in-features, i.e. the flattened
          # conv5 output for the expected input volume. It must change if the input
          # size changes. TODO confirm.
          self.fc1 = CustomLayers.Fc_elu_drop(113568, 20, prps=prps, softmax=False)  # TODO, concatenate clinical data after this
          # NOTE(review): if softmax=True makes fc2 emit probabilities, then
          # nn.CrossEntropyLoss (which applies log-softmax itself) receives already
          # normalised values. Confirm against CustomLayers.Fc_elu_drop.
          self.fc2 = CustomLayers.Fc_elu_drop(20, final_layer_size, prps=prps, softmax=True)
          # Every loop below moves its data to self.device, so the parameters must be
          # there too. .to() is a no-op if they already are.
          self.to(self.device)

      def forward(self, x):
          """Run a batch through the network and return the fc2 output."""
          x = self.conv1(x)
          x = self.conv2(x)
          x = self.conv3_mid_flow(x)
          x = self.conv4_sepConv(x)
          x = self.conv5_sepConv(x)
          # Flatten every dimension except the batch dimension. Unlike view(),
          # this also works on non-contiguous tensors.
          x = torch.flatten(x, 1)
          x = self.fc1(x)
          return self.fc2(x)

      def train_model(self, trainloader, testloader, PATH, epochs):
          """Train with Adam and cross-entropy, log per-epoch losses, plot them and save weights.

          Args:
              trainloader: Iterable of (inputs, labels) batches. len() is used in
                  the progress messages.
              testloader: Iterable of (inputs, labels) batches, evaluated after
                  every epoch.
              PATH: Destination for ``torch.save(self.state_dict())``.
              epochs: Number of passes over trainloader.

          Side effects:
              Writes ./cnn_net_data.csv, ./avgloss_epoch_curve.png,
              ./valloss_epoch_curve.png and PATH, and prints progress.

          Raises:
              ValueError: if trainloader or testloader yields no data.
          """
          self.train()
          criterion = nn.CrossEntropyLoss(reduction='mean')
          optimizer = optim.Adam(self.parameters(), lr=1e-5)
          records = []
          start_time = time.time()  # seconds
          for epoch in range(1, epochs + 1):
              elapsed = time.time() - start_time
              # Project the total run time from the epochs already finished.
              # Nothing is known before the first epoch completes.
              completed = epoch - 1
              projected_total = elapsed / completed * epochs if completed else 0.0
              t = time.strftime("%H:%M:%S", time.gmtime(elapsed))
              t_total = time.strftime("%H:%M:%S", time.gmtime(projected_total))
              print(f"{epoch/epochs * 100} || {epoch}/{epochs} || Time: {t}/{t_total}")
              avg_loss = self._train_one_epoch(trainloader, criterion, optimizer)
              print(f"Avg. loss: {avg_loss}")
              val_loss = self.evaluate_model(testloader, roc=False)
              records.append({'Epoch': int(epoch), 'Avg_loss': avg_loss,
                              'Time': time.time() - start_time, 'Val_loss': val_loss})
          print('Finished Training')
          # DataFrame.append was removed in pandas 2.0, so build the frame once here.
          # The column order matches what the old incremental append produced.
          losses = pd.DataFrame(records, columns=['Epoch', 'Avg_loss', 'Time', 'Val_loss'])
          losses.to_csv('./cnn_net_data.csv')
          self._plot_series(losses['Epoch'], losses['Avg_loss'], 'Epoch', 'Average Loss',
                            'Average Loss vs Epoch On Training', './avgloss_epoch_curve.png')
          self._plot_series(losses['Epoch'], losses['Val_loss'], 'Epoch', 'Validation Loss',
                            'Validation Loss vs Epoch On Training', './valloss_epoch_curve.png')
          torch.save(self.state_dict(), PATH)
          print("Model saved")

      def _train_one_epoch(self, trainloader, criterion, optimizer):
          """Run one optimisation pass over trainloader and return the mean of the batch losses."""
          running_loss = 0.0
          n_batches = 0
          for i, data in enumerate(trainloader):
              inputs, labels = data[0].to(self.device), data[1].to(self.device)
              optimizer.zero_grad()
              outputs = self.forward(inputs)
              loss = criterion(outputs, labels)  # mean loss over this batch
              loss.backward()
              optimizer.step()
              running_loss += loss.item()
              n_batches = i + 1
              if i % 10 == 0 and i != 0:
                  # Average over the batches seen so far, not over the whole epoch.
                  print(f"{i}/{len(trainloader)}, temp. loss:{running_loss / n_batches}")
          if n_batches == 0:
              raise ValueError("trainloader yielded no batches")
          return running_loss / n_batches

      def evaluate_model(self, testloader, roc=False):
          """Print accuracy on testloader and return the mean cross-entropy loss.

          Args:
              testloader: Iterable of (inputs, labels) batches.
              roc: If True, also sweep 50 thresholds over column 1 of the model
                  output. It then saves and shows the ROC curve (./ROC.png) and
                  the accuracy-vs-threshold curve (./acc.png). This mode is
                  meant for labels encoded as 0/1.

          Returns:
              Cross-entropy averaged over every sample in testloader.

          Raises:
              ValueError: if testloader yields no samples.

          The module's train/eval mode on entry is restored before returning.
          """
          was_training = self.training
          # Summed per-sample loss divided by the sample count gives the exact
          # dataset mean, independent of batch sizes.
          criterion = nn.CrossEntropyLoss(reduction='sum')
          correct = 0
          total = 0
          loss_sum = 0.0
          scores = []
          true_labels = []
          self.eval()
          try:
              with torch.no_grad():
                  for data in testloader:
                      images, labels = data[0].to(self.device), data[1].to(self.device)
                      outputs = self.forward(images)
                      loss_sum += criterion(outputs, labels).item()
                      # The highest-scoring output unit is the predicted class.
                      predicted = outputs.argmax(dim=1)
                      total += labels.size(0)
                      correct += (predicted == labels).sum().item()
                      if roc:
                          # NOTE(review): column 1 is treated as the positive-class
                          # probability, which only holds if fc2 emits probabilities.
                          scores.extend(outputs[:, 1].cpu().numpy())
                          true_labels.extend(labels.cpu().numpy())
          finally:
              self.train(was_training)
          if total == 0:
              raise ValueError("testloader yielded no samples")
          val_loss = loss_sum / total
          print(f'Accuracy of the network on {total} scans: {100 * correct // total}%')
          if roc:
              self._plot_roc(*self._roc_points(scores, true_labels))
          else:
              print(f'Validation loss: {val_loss}')
          return val_loss

      @staticmethod
      def _roc_points(scores, true_labels, num_thresholds=50):
          """Return (thresholds, fpr, tpr, accuracy) arrays for a threshold sweep over [0, 1].

          Prediction and label rules:
          - A sample is predicted positive when its score is >= the threshold.
          - Labels equal to 1 are positives and labels equal to 0 are negatives.
          - Rates whose denominator is zero are reported as nan.
          """
          scores = np.asarray(scores)
          true_labels = np.asarray(true_labels)
          is_pos = true_labels == 1
          is_neg = true_labels == 0
          thresholds = np.linspace(0, 1, num=num_thresholds)
          fpr, tpr, acc = [], [], []
          for threshold in thresholds:
              pred_pos = scores >= threshold
              tp = np.sum(pred_pos & is_pos)
              fp = np.sum(pred_pos & is_neg)
              tn = np.sum(~pred_pos & is_neg)
              fn = np.sum(~pred_pos & is_pos)
              counted = tp + fp + tn + fn
              tpr.append(tp / (tp + fn) if tp + fn else np.nan)
              fpr.append(fp / (fp + tn) if fp + tn else np.nan)
              acc.append((tp + tn) / counted if counted else np.nan)
          return thresholds, np.array(fpr), np.array(tpr), np.array(acc)

      def _plot_roc(self, thresholds, fpr, tpr, acc):
          """Save and show the ROC curve (./ROC.png) and accuracy vs threshold (./acc.png)."""
          fig = plt.figure()
          plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve')
          plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
          plt.xlim([0.0, 1.0])
          plt.ylim([0.0, 1.0])
          plt.xlabel('False Positive Rate (1 - Specificity)')
          plt.ylabel('True Positive Rate (Sensitivity)')
          plt.title('Receiver Operating Characteristic (ROC) Curve')
          plt.legend(loc="lower right")
          plt.savefig('./ROC.png')
          plt.show()
          plt.close(fig)
          self._plot_series(thresholds, acc, 'Thresholds', 'Accuracy', 'Accuracy vs thresholds', './acc.png')

      @staticmethod
      def _plot_series(x, y, xlabel, ylabel, title, path):
          """Draw y against x on a fresh figure, save it to path, show it, then close it."""
          # A new figure per plot prevents curves from piling onto one axes when
          # plt.show() does not clear the figure (e.g. non-interactive backends).
          fig = plt.figure()
          plt.plot(x, y)
          plt.xlabel(xlabel)
          plt.ylabel(ylabel)
          plt.title(title)
          plt.savefig(path)
          plt.show()
          plt.close(fig)

      def predict(self, loader):
          """Return the predicted class index for every sample in loader.

          Predictions from all batches are concatenated in loader order. Each
          batch is read as ``data[0]``, and any labels it carries are ignored.
          The module's train/eval mode on entry is restored before returning.

          Raises:
              ValueError: if loader yields no batches.
          """
          was_training = self.training
          self.eval()
          batch_predictions = []
          try:
              with torch.no_grad():
                  for data in loader:
                      images = data[0].to(self.device)
                      # The highest-scoring output unit is the predicted class.
                      batch_predictions.append(self.forward(images).argmax(dim=1))
          finally:
              self.train(was_training)
          if not batch_predictions:
              raise ValueError("loader yielded no batches")
          return torch.cat(batch_predictions)