# train_methods.py

import torch
from torch import nn, optim, cat, stack
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report, f1_score
import seaborn as sns

# GENERAL PURPOSE
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import time


# TRAIN
def train(model, train_data, test_data, CNN_filepath, params, graphs=True):
    model.train()
    criterion = nn.CrossEntropyLoss(reduction='mean')
    optimizer = optim.Adam(model.parameters(), lr=1e-5)  # , weight_decay=params['weight_decay'], betas=params['momentum'])
    losses = pd.DataFrame(columns=['Epoch', 'Avg_loss', 'Val_loss', 'Time'])
    start_time = time.time()  # seconds
    # model.init_history()
    epochs = params["epochs"]
    for epoch in range(1, epochs + 1):  # loop over the dataset multiple times
        # Track elapsed time and estimate total training time
        elapsed = time.time() - start_time
        t = time.strftime("%H:%M:%S", time.gmtime(elapsed))
        t_total = time.strftime("%H:%M:%S", time.gmtime(elapsed / epoch * epochs))
        print(f"{epoch / epochs * 100:.1f}% || {epoch}/{epochs} || Time: {t}/{t_total}")
        running_loss = 0.0
        predictions = []
        # Batches & training
        for i, data in enumerate(train_data, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(model.device), data[1].to(model.device)
            # inputs, labels = [data[0][0].to(model.device), stack(data[0][1], dim=0).to(model.device)], data[1].to(model.device)  # TODO Clinical data not sent to model.device
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)  # this loss is the mean of losses for the batch
            loss.backward()
            optimizer.step()
            # adds average batch loss to running loss
            running_loss += loss.item()
            # progress report every 10 mini-batches (mean loss over the batches seen so far)
            if i % 10 == 0 and i != 0:
                print(f"{i}/{len(train_data)}, temp. loss: {running_loss / (i + 1)}")
            # Gets predictions for F1 metric
            # predictions.append(torch.max(outputs.data, 1)[1])
        # average loss over the epoch
        avg_loss = running_loss / len(train_data)  # running_loss / number of batches
        print(f"Avg. loss: {avg_loss}")
        # loss on validation
        val_loss = evaluate(model, test_data, graphs=False)  # , f1_validation
        losses = pd.concat([losses, pd.DataFrame([{'Epoch': int(epoch), 'Avg_loss': avg_loss,
                                                   'Val_loss': val_loss, 'Time': time.time() - start_time}])],
                           ignore_index=True)
        # model.append_loss(running_loss)
        # model.append_val_loss(val_loss)
        # f1_training = f1_score(test_data.data, predictions.data)
        # model.append_metric(f1_training)
        # model.append_val_metric(f1_validation)
    print('Finished Training')
    time_string = time.strftime("%Y-%m-%d_%H.%M", time.localtime())
    losses.to_csv(f'./cnn_net_data_{time_string}.csv')
    if graphs:
        # EPOCH VS AVG LOSS GRAPH on training and validation data
        plt.figure()  # fresh figure so earlier plots don't bleed in
        plt.plot(losses['Epoch'], losses['Avg_loss'], label="Loss on Training")
        plt.plot(losses['Epoch'], losses['Val_loss'], label="Loss on Validation")
        plt.xlabel('Epoch')
        plt.ylabel('Average Loss')
        plt.title('Loss vs Epoch on Training & Validation Data')
        plt.legend(loc="lower right")
        plt.savefig(f"./avgloss_epoch_curve_{time_string}.png")
        print("AVG LOSS EPOCH CURVE IN TRAINING DONE")
        # plt.show()
    torch.save(model.state_dict(), CNN_filepath)
    print("Model saved")
    return model  # , model.parameters()
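
# A sketch of the `params` dict train() expects. Only 'epochs' is read by the
# live code; 'weight_decay' and 'momentum' appear only in the commented-out
# optimizer line above, so the values below are illustrative assumptions,
# not project defaults:
# params = {'epochs': 20, 'weight_decay': 1e-4, 'momentum': (0.9, 0.999)}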

def load(model, filepath):
    # map_location keeps loading working when the checkpoint was saved on a different device
    model.load_state_dict(torch.load(filepath, map_location=model.device))

def evaluate(model, val_data, graphs=True, k_folds=None, fold=None, results=None):
    start_time = time.localtime()
    correct, total = 0, 0
    running_loss = 0.0
    predictionsLabels, predictionsProbabilities, true_labels = [], [], []
    # predictions = []
    criterion = nn.CrossEntropyLoss(reduction='mean')
    model.eval()
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in val_data:
            images, labels = data[0].to(model.device), data[1].to(model.device)
            # images, labels = [data[0][0].to(model.device), stack(data[0][1], dim=0).to(model.device)], data[1].to(model.device)  # TODO Clinical data not sent to model.device
            # calculate outputs by running images through the model
            outputs = model(images)
            loss = criterion(outputs, labels)  # mean loss from batch
            running_loss += loss.item()
            # the class with the highest energy is what we choose as prediction
            predicted = torch.max(outputs.data, 1)[1]
            # predictions.append(predicted)  # for F1 score
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # Saves predictions, probabilities, and labels for ROC
            if graphs:
                predictionsLabels.extend(predicted.cpu().numpy())
                predictionsProbabilities.extend(outputs.data[:, 1].cpu().numpy())  # grabs probability of the positive class
                true_labels.extend(labels.cpu().numpy())
    avg_val_loss = running_loss / len(val_data)  # mean loss over all validation batches, not just the last one
    # K-FOLD MODE
    if fold is not None:
        # Print accuracy
        print('Accuracy for fold %d: %d %%' % (fold, 100.0 * correct / total))
        print('--------------------------------')
        results[fold] = 100.0 * (correct / total)
        true_labels = np.array(true_labels)
        # ROC: calculate TPR and FPR
        fpr, tpr, thresholds = roc_curve(true_labels, predictionsProbabilities)
        time_string = time.strftime("%Y-%m-%d_%H.%M", start_time)
        # Calculate AUC
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, lw=2, label=f'ROC Fold {fold} (AUC: {roc_auc:.3f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.005])
        plt.ylim([0.0, 1.005])
        plt.xlabel('False Positive Rate (1 - Specificity)')
        plt.ylabel('True Positive Rate (Sensitivity)')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc="lower right")
        plt.savefig(f'./ROC_{k_folds}_Folds_{time_string}.png')
        print("SAVED ROC FOR K-FOLD")
        return results
    # NORMAL EVALUATION
    print(f'Accuracy of the network on {total} scans: {100 * correct // total}%')
    if not graphs:
        print(f'Validation loss: {avg_val_loss}')
    else:
        time_string = time.strftime("%Y-%m-%d_%H.%M", start_time)
        # ROC: calculate TPR and FPR
        fpr, tpr, thresholds = roc_curve(true_labels, predictionsProbabilities)
        # Calculate AUC
        roc_auc = auc(fpr, tpr)
        plt.figure()  # fresh figure so the ROC curve doesn't land on an earlier plot
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC: {roc_auc:.3f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.005])
        plt.ylim([0.0, 1.005])
        plt.xlabel('False Positive Rate (1 - Specificity)')
        plt.ylabel('True Positive Rate (Sensitivity)')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc="lower right")
        plt.savefig(f'./ROC_{time_string}.png')
        print("SAVED ROC FOR NORMAL")
        # plt.show()
        # Calculate and plot the confusion matrix
        cm = confusion_matrix(true_labels, predictionsLabels)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
        plt.xlabel('Predicted labels')
        plt.ylabel('True labels')
        plt.title('Confusion Matrix')
        plt.savefig(f'./confusion_matrix_{time_string}.png')
        # plt.show()
        # Classification Report
        report = classification_report(true_labels, predictionsLabels)
        print(report)
    # f1_validation = f1_score(val_data, predictions.data)
    model.train()
    return avg_val_loss  # , f1_validation

# PREDICT
def predict(model, data):
    model.eval()
    true_labels, predictions = [], []
    with torch.no_grad():
        for batch in data:  # renamed from `data` so the loop doesn't shadow the argument
            images, labels = batch[0].to(model.device), batch[1].to(model.device)
            outputs = model(images)
            # the class with the highest energy is what we choose as prediction
            predicted = torch.max(outputs.data, 1)[1]
            true_labels.append(labels)
            predictions.append(predicted)
    model.train()
    return cat(true_labels), cat(predictions)  # RETURNS (true, predicted) for every batch, not just the last
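
# ---------------------------------------------------------------------------
# Minimal smoke test (an illustrative sketch, not part of the original training
# pipeline): TinyCNN, the random tensors, and './tiny_cnn.pt' are all assumed
# names. It only demonstrates the calling convention these helpers expect: a
# model exposing a `.device` attribute and loaders yielding (inputs, labels).
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    class TinyCNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.net = nn.Sequential(
                nn.Conv2d(1, 4, kernel_size=3, padding=1), nn.ReLU(),
                nn.Flatten(), nn.Linear(4 * 8 * 8, 2))  # two classes, as evaluate() assumes

        def forward(self, x):
            return self.net(x)

    model = TinyCNN()
    model.to(model.device)
    # 16 random 8x8 single-channel "scans" with binary labels
    dataset = TensorDataset(torch.randn(16, 1, 8, 8), torch.randint(0, 2, (16,)))
    loader = DataLoader(dataset, batch_size=4)

    model = train(model, loader, loader, './tiny_cnn.pt', {'epochs': 1}, graphs=False)
    val_loss = evaluate(model, loader, graphs=False)
    true, pred = predict(model, loader)
    print(f"Smoke test done: val loss {val_loss:.4f}, {len(pred)} predictions")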