# train_methods.py

import torch
from torch import nn, optim, cat, stack
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report, f1_score
import seaborn as sns

# GENERAL PURPOSE
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import time

# TRAIN
def train(model, train_data, test_data, CNN_filepath, params, graphs=True):
    model.train()
    criterion = nn.CrossEntropyLoss(reduction='mean')
    optimizer = optim.Adam(model.parameters(), lr=1e-5)  # , weight_decay=params['weight_decay'], betas=params['momentum']
    losses = pd.DataFrame(columns=['Epoch', 'Avg_loss', 'Val_loss', 'Time'])
    start_time = time.time()  # seconds
    # model.init_history()
    epochs = params['epochs']
    for epoch in range(1, epochs + 1):  # loop over the dataset multiple times
        # Report elapsed time and the estimated total training time
        t = time.strftime("%H:%M:%S", time.gmtime(time.time() - start_time))
        t_total = time.strftime("%H:%M:%S", time.gmtime((time.time() - start_time) / epoch * epochs))
        print(f"{epoch / epochs * 100:.0f}% || {epoch}/{epochs} || Time: {t}/{t_total}")
        running_loss = 0.0
        predictions = []
        # Batches & training
        for i, data in enumerate(train_data, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = [data[0][0].to(model.device), stack(data[0][1], dim=0).to(model.device)], data[1].to(model.device)  # TODO Clinical data not sent to model.device
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)  # This loss is the mean of losses for the batch
            loss.backward()
            optimizer.step()
            # adds average batch loss to running loss
            running_loss += loss.item()
            # mini-batches for progress
            if i % 10 == 0 and i != 0:
                print(f"{i}/{len(train_data)}, temp. loss: {running_loss / (i + 1)}")  # average over the batches seen so far
            # Gets predictions for f1 metric
            # predictions = predictions.append(torch.max(outputs.data, 1)[1])
        # average loss over the epoch
        avg_loss = running_loss / len(train_data)  # running_loss / number of batches
        print(f"Avg. loss: {avg_loss}")
        # loss on validation
        val_loss = evaluate(model, test_data, graphs=False)  # , f1_validation
        losses = pd.concat([losses, pd.DataFrame([{'Epoch': int(epoch), 'Avg_loss': avg_loss, 'Val_loss': val_loss, 'Time': time.time() - start_time}])])
        # model.append_loss(running_loss)
        # model.append_val_loss(val_loss)
        # f1_training = f1_score(test_data.data, predictions.data)
        # model.append_metric(f1_training)
        # model.append_val_metric(f1_validation)
    print('Finished Training')
    end_time = time.localtime()
    time_string = time.strftime("%Y-%m-%d_%H-%M", end_time)  # colon-free so the filename is valid on Windows too
    losses.to_csv(f'./cnn_net_data_{time_string}.csv')
    if graphs:
        # MAKES EPOCH VS AVG LOSS GRAPH
        plt.plot(losses['Epoch'], losses['Avg_loss'], label="Loss on Training")
        plt.xlabel('Epoch')
        plt.ylabel('Average Loss')
        plt.title('Loss vs Epoch on Training & Validation Data')
        # PLOTS EPOCH VS VALIDATION LOSS ON THE SAME GRAPH
        plt.plot(losses['Epoch'], losses['Val_loss'], label="Loss on Validation")
        plt.legend(loc="lower right")
        plt.savefig(f"./avgloss_epoch_curve_{time_string}.png")
        print("AVG LOSS EPOCH CURVE IN TRAINING DONE")
        # plt.show()
    torch.save(model.state_dict(), CNN_filepath)
    print("Model saved")
    return model  # , model.parameters()
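

# Minimal usage sketch (an assumption, not part of the original project). It shows the
# nested batch structure train() expects: each batch is
# [[image_batch, list_of_clinical_vectors], labels], since the training loop indexes
# data[0][0] for images and stacks data[0][1] for clinical features. _ToyNet, the
# random tensors, and './toy_cnn.pth' are hypothetical placeholders.
def _train_usage_example():
    class _ToyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.device = 'cpu'  # train() reads a .device attribute off the model
            self.img_branch = nn.Sequential(nn.Flatten(), nn.Linear(16 * 16, 8))
            self.head = nn.Linear(8 + 2, 2)

        def forward(self, inputs):
            img, clin = inputs  # matches train()'s [image, clinical] packing
            return self.head(cat([self.img_branch(img), clin], dim=1))

    images = torch.randn(4, 1, 16, 16)             # one batch of 4 single-channel images
    clinical = [torch.randn(2) for _ in range(4)]  # stacked to shape (4, 2) inside train()
    labels = torch.randint(0, 2, (4,))             # binary class labels
    batch = [[images, clinical], labels]
    train(_ToyNet(), [batch], [batch], './toy_cnn.pth', {'epochs': 1}, graphs=False)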


def load(model, filepath):
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    model.load_state_dict(torch.load(filepath, map_location=model.device))


# EVALUATE
def evaluate(model, val_data, graphs=True, k_folds=None, fold=None, results=None):
    start_time = time.localtime()
    correct, total = 0, 0
    running_val_loss = 0.0
    predictionsLabels, predictionsProbabilities, true_labels = [], [], []
    # predictions = []
    criterion = nn.CrossEntropyLoss(reduction='mean')
    model.eval()
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in val_data:
            images, labels = [data[0][0].to(model.device), stack(data[0][1], dim=0).to(model.device)], data[1].to(model.device)  # TODO Clinical data not sent to model.device
            # calculate outputs by running images through the model
            outputs = model(images)
            loss = criterion(outputs, labels)  # mean loss from batch
            running_val_loss += loss.item()
            # the class with the highest score is what we choose as the prediction
            predicted = torch.max(outputs.data, 1)[1]
            # predictions = predictions.append(predicted) # for F1 score
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # Saves probabilities and labels for ROC (also needed in k-fold mode)
            if graphs or fold is not None:
                predictionsLabels.extend(predicted.cpu().numpy())
                predictionsProbabilities.extend(outputs.data[:, 1].cpu().numpy())  # probability of the positive class
                true_labels.extend(labels.cpu().numpy())
    val_loss = running_val_loss / len(val_data)  # average over all batches, not just the last one
    # K-FOLD MODE
    if fold is not None:
        # Print accuracy for this fold
        print(f'Accuracy for fold {fold}: {100.0 * correct / total:.1f} %')
        print('--------------------------------')
        results[fold] = 100.0 * (correct / total)
        true_labels = np.array(true_labels)
        # ROC
        # Calculate TPR and FPR
        fpr, tpr, thresholds = roc_curve(true_labels, predictionsProbabilities)
        time_string = time.strftime("%Y-%m-%d_%H-%M", start_time)
        # Calculate AUC; every fold is drawn onto the same shared figure
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, lw=2, label=f'ROC Fold {fold} (AUC: {roc_auc:.3f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.005])
        plt.ylim([0.0, 1.005])
        plt.xlabel('False Positive Rate (1 - Specificity)')
        plt.ylabel('True Positive Rate (Sensitivity)')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc="lower right")
        plt.savefig(f'./ROC_{k_folds}_Folds_{time_string}.png')
        print("SAVED ROC FOR K-FOLD")
        model.train()  # restore training mode before returning
        return results
    # NORMAL EVALUATION
    print(f'Accuracy of the network on {total} scans: {100 * correct // total}%')
    if not graphs:
        print(f'Validation loss: {val_loss}')
    else:
        time_string = time.strftime("%Y-%m-%d_%H-%M", start_time)
        # ROC
        # Calculate TPR and FPR
        fpr, tpr, thresholds = roc_curve(true_labels, predictionsProbabilities)
        # Calculate AUC
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC: {roc_auc:.3f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.005])
        plt.ylim([0.0, 1.005])
        plt.xlabel('False Positive Rate (1 - Specificity)')
        plt.ylabel('True Positive Rate (Sensitivity)')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc="lower right")
        plt.savefig(f'./ROC_{time_string}.png')
        print("SAVED ROC FOR NORMAL")
        # plt.show()
        # Calculate confusion matrix on a fresh figure
        cm = confusion_matrix(true_labels, predictionsLabels)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
        plt.xlabel('Predicted labels')
        plt.ylabel('True labels')
        plt.title('Confusion Matrix')
        plt.savefig(f'./confusion_matrix_{time_string}.png')
        # plt.show()
        # Classification Report
        report = classification_report(true_labels, predictionsLabels)
        print(report)
        # f1_validation = f1_score(val_data, predictions.data)
    model.train()
    return val_loss  # , f1_validation
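

# Hypothetical k-fold driver (an assumption, not in the original): when `fold` is
# passed, evaluate() doubles as a per-fold scorer, filling results[fold] with accuracy
# and drawing each fold's ROC curve onto one shared figure. `fold_loaders` is assumed
# to hold one validation loader per fold.
def run_kfold_evaluation(model, fold_loaders):
    results = {}
    for fold, loader in enumerate(fold_loaders):
        results = evaluate(model, loader, graphs=True, k_folds=len(fold_loaders),
                           fold=fold, results=results)
    print(f"Mean accuracy across folds: {sum(results.values()) / len(results):.1f}%")
    return results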


# PREDICT
def predict(model, data):
    # NOTE: unlike train()/evaluate(), batches here are plain (images, labels) pairs
    model.eval()
    all_labels, all_predicted = [], []
    with torch.no_grad():
        for batch in data:  # renamed from `data` to avoid shadowing the argument
            images, labels = batch[0].to(model.device), batch[1].to(model.device)
            outputs = model(images)
            # the class with the highest score is what we choose as the prediction
            _, predicted = torch.max(outputs.data, 1)
            all_labels.append(labels)
            all_predicted.append(predicted)
    model.train()
    return cat(all_labels), cat(all_predicted)  # RETURNS (true, predicted) over all batches, not just the last
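

# Hypothetical convenience wrapper (an assumption, not in the original): predict()
# returns the concatenated (true, predicted) label tensors, so overall accuracy is
# a single mean away.
def prediction_accuracy(model, data):
    true, pred = predict(model, data)
    return (true == pred).float().mean().item()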