CNN.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. from torch import device, cuda
  2. import torch
  3. from torch import add
  4. import torch.nn as nn
  5. import utils.CNN_Layers as CustomLayers
  6. import torch.nn.functional as F
  7. import torch.optim as optim
  8. import pandas as pd
  9. import matplotlib.pyplot as plt
  10. import time
  11. import numpy as np
  12. from sklearn.metrics import roc_curve, auc
  13. class CNN_Net(nn.Module):
  14. def __init__(self, prps, final_layer_size=5):
  15. super(CNN_Net, self).__init__()
  16. self.final_layer_size = final_layer_size
  17. self.device = device('cuda:0' if cuda.is_available() else 'cpu')
  18. print("CNN Initialized. Using: " + str(self.device))
  19. # LAYERS
  20. print(f"CNN Model Initialization")
  21. self.conv1 = CustomLayers.Conv_elu_maxpool_drop(1, 192, (11, 13, 11), stride=(4,4,4), pool=True, prps=prps)
  22. self.conv2 = CustomLayers.Conv_elu_maxpool_drop(192, 384, (5, 6, 5), stride=(1,1,1), pool=True, prps=prps)
  23. self.conv3_mid_flow = CustomLayers.Mid_flow(384, 384, prps=prps)
  24. self.conv4_sepConv = CustomLayers.Conv_elu_maxpool_drop(384, 96,(3, 4, 3), stride=(1,1,1), pool=True, prps=prps,
  25. sep_conv=True)
  26. self.conv5_sepConv = CustomLayers.Conv_elu_maxpool_drop(96, 48, (3, 4, 3), stride=(1, 1, 1), pool=True,
  27. prps=prps, sep_conv=True)
  28. self.fc1 = CustomLayers.Fc_elu_drop(113568, 20, prps=prps, softmax=False) # TODO, concatenate clinical data after this
  29. self.fc2 = CustomLayers.Fc_elu_drop(20, final_layer_size, prps=prps, softmax=True) # For now this works as output layer, though may be incorrect
  30. # FORWARDS
  31. def forward(self, x):
  32. x = self.conv1(x)
  33. x = self.conv2(x)
  34. x = self.conv3_mid_flow(x)
  35. x = self.conv4_sepConv(x)
  36. x = self.conv5_sepConv(x)
  37. # FLATTEN x
  38. flatten_size = x.size(1) * x.size(2) * x.size(3) * x.size(4)
  39. x = x.view(-1, flatten_size)
  40. x = self.fc1(x)
  41. x = self.fc2(x)
  42. return x
  43. # TRAIN
  44. def train_model(self, trainloader, testloader, PATH, epochs):
  45. self.train()
  46. criterion = nn.CrossEntropyLoss(reduction='mean')
  47. optimizer = optim.Adam(self.parameters(), lr=1e-5)
  48. losses = pd.DataFrame(columns=['Epoch', 'Avg_loss', 'Time'])
  49. start_time = time.time() # seconds
  50. for epoch in range(epochs): # loop over the dataset multiple times
  51. epoch += 1
  52. # Estimate & count training time
  53. t = time.strftime("%H:%M:%S", time.gmtime(time.time() - start_time))
  54. t_remain = time.strftime("%H:%M:%S", time.gmtime((time.time() - start_time)/epoch * epochs))
  55. print(f"{epoch/epochs * 100} || {epoch}/{epochs} || Time: {t}/{t_remain}")
  56. running_loss = 0.0
  57. # Batches & training
  58. for i, data in enumerate(trainloader, 0):
  59. # get the inputs; data is a list of [inputs, labels]
  60. inputs, labels = data[0].to(self.device), data[1].to(self.device)
  61. # zero the parameter gradients
  62. optimizer.zero_grad()
  63. # forward + backward + optimize
  64. outputs = self.forward(inputs)
  65. loss = criterion(outputs, labels) # This loss is the mean of losses for the batch
  66. loss.backward()
  67. optimizer.step()
  68. # adds average batch loss to running loss
  69. running_loss += loss.item()
  70. # mini-batches for progress
  71. if(i%10==0 and i!=0):
  72. print(f"{i}/{len(trainloader)}, temp. loss:{running_loss / len(trainloader)}")
  73. # average loss
  74. avg_loss = running_loss / len(trainloader) # Running_loss / number of batches
  75. print(f"Avg. loss: {avg_loss}")
  76. # loss on validation
  77. val_loss = self.evaluate_model(testloader, roc=False)
  78. losses = losses.append({'Epoch':int(epoch), 'Avg_loss':avg_loss, 'Val_loss':val_loss, 'Time':time.time() - start_time}, ignore_index=True)
  79. print('Finished Training')
  80. losses.to_csv('./cnn_net_data.csv')
  81. # MAKES EPOCH VS AVG LOSS GRAPH
  82. plt.plot(losses['Epoch'], losses['Avg_loss'], label="Loss on Training")
  83. plt.xlabel('Epoch')
  84. plt.ylabel('Average Loss')
  85. plt.title('Loss vs Epoch On Training & Validation data')
  86. # MAKES EPOCH VS VALIDATION LOSS GRAPH
  87. plt.plot(losses['Epoch'], losses['Val_loss'], label="Loss on Validation")
  88. plt.savefig('./avgloss_epoch_curve.png')
  89. plt.show()
  90. torch.save(self.state_dict(), PATH)
  91. print("Model saved")
  92. # TEST
  93. def evaluate_model(self, testloader, roc):
  94. correct = 0
  95. total = 0
  96. predictions = []
  97. true_labels = []
  98. criterion = nn.CrossEntropyLoss(reduction='mean')
  99. self.eval()
  100. # since we're not training, we don't need to calculate the gradients for our outputs
  101. with torch.no_grad():
  102. for data in testloader:
  103. images, labels = data[0].to(self.device), data[1].to(self.device)
  104. # calculate outputs by running images through the network
  105. outputs = self.forward(images)
  106. # the class with the highest energy is what we choose as prediction
  107. loss = criterion(outputs, labels) # mean loss from batch
  108. # Gets accuracy
  109. _, predicted = torch.max(outputs.data, 1)
  110. total += labels.size(0)
  111. correct += (predicted == labels).sum().item()
  112. # Saves predictions and labels for ROC
  113. if(roc):
  114. predictions.extend(outputs.data[:,1].cpu().numpy()) # Grabs probability of positive
  115. true_labels.extend(labels.cpu().numpy())
  116. print(f'Accuracy of the network on {total} scans: {100 * correct // total}%')
  117. if(not roc): print(f'Validation loss: {loss.item()}')
  118. else:
  119. # ROC
  120. thresholds = np.linspace(0, 1, num=50)
  121. # Calculate TPR and FPR
  122. fpr, tpr, thresholds = roc_curve(true_labels, predictions)
  123. # Calculate AUC
  124. roc_auc = auc(fpr, tpr)
  125. plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC: {roc_auc})')
  126. plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
  127. plt.xlim([0.0, 1.005])
  128. plt.ylim([0.0, 1.005])
  129. plt.xlabel('False Positive Rate (1 - Specificity)')
  130. plt.ylabel('True Positive Rate (Sensitivity)')
  131. plt.title('Receiver Operating Characteristic (ROC) Curve')
  132. plt.legend(loc="lower right")
  133. plt.savefig('./ROC.png')
  134. plt.show()
  135. self.train()
  136. return(loss.item())
  137. # PREDICT
  138. def predict(self, loader):
  139. self.eval()
  140. with torch.no_grad():
  141. for data in loader:
  142. images, labels = data[0].to(self.device), data[1].to(self.device)
  143. outputs = self.forward(images)
  144. # the class with the highest energy is what we choose as prediction
  145. _, predicted = torch.max(outputs.data, 1)
  146. self.train()
  147. return predicted