# train_methods.py
  1. import torch
  2. from torch import nn, optim
  3. from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report
  4. import seaborn as sns
  5. # GENERAL PURPOSE
  6. import os
  7. import pandas as pd
  8. import numpy as np
  9. import matplotlib.pyplot as plt
  10. import time
  11. # TRAIN
  12. def train(model, train_data, test_data, CNN_filepath, epochs=20, graphs=True):
  13. model.train()
  14. criterion = nn.CrossEntropyLoss(reduction='mean')
  15. optimizer = optim.Adam(model.parameters(), lr=1e-5)
  16. losses = pd.DataFrame(columns=['Epoch', 'Avg_loss', 'Time'])
  17. start_time = time.time() # seconds
  18. for epoch in range(epochs): # loop over the dataset multiple times
  19. epoch += 1
  20. # Estimate & count training time
  21. t = time.strftime("%H:%M:%S", time.gmtime(time.time() - start_time))
  22. t_remain = time.strftime("%H:%M:%S", time.gmtime((time.time() - start_time)/epoch * epochs))
  23. print(f"{epoch/epochs * 100} || {epoch}/{epochs} || Time: {t}/{t_remain}")
  24. running_loss = 0.0
  25. # Batches & training
  26. for i, data in enumerate(train_data, 0):
  27. # get the inputs; data is a list of [inputs, labels]
  28. inputs, labels = data[0].to(model.device), data[1].to(model.device)
  29. # zero the parameter gradients
  30. optimizer.zero_grad()
  31. # forward + backward + optimize
  32. outputs = model.forward(inputs)
  33. loss = criterion(outputs, labels) # This loss is the mean of losses for the batch
  34. loss.backward()
  35. optimizer.step()
  36. # adds average batch loss to running loss
  37. running_loss += loss.item()
  38. # mini-batches for progress
  39. if(i%10==0 and i!=0):
  40. print(f"{i}/{len(train_data)}, temp. loss:{running_loss / len(train_data)}")
  41. # average loss
  42. avg_loss = running_loss / len(train_data) # Running_loss / number of batches
  43. print(f"Avg. loss: {avg_loss}")
  44. # loss on validation
  45. val_loss = evaluate(test_data, graphs)
  46. losses = losses.append({'Epoch':int(epoch), 'Avg_loss':avg_loss, 'Val_loss':val_loss, 'Time':time.time() - start_time}, ignore_index=True)
  47. print('Finished Training')
  48. losses.to_csv('./cnn_net_data.csv')
  49. if(graphs):
  50. # MAKES EPOCH VS AVG LOSS GRAPH
  51. plt.plot(losses['Epoch'], losses['Avg_loss'], label="Loss on Training")
  52. plt.xlabel('Epoch')
  53. plt.ylabel('Average Loss')
  54. plt.title('Loss vs Epoch On Training & Validation data')
  55. # MAKES EPOCH VS VALIDATION LOSS GRAPH
  56. plt.plot(losses['Epoch'], losses['Val_loss'], label="Loss on Validation")
  57. plt.savefig('./avgloss_epoch_curve.png')
  58. plt.show()
  59. torch.save(model.state_dict(), CNN_filepath)
  60. print("Model saved")
  61. def load(model, filepath):
  62. model.load_state_dict(torch.load(filepath))
  63. def evaluate(model, val_data, graphs=True):
  64. # EVALUATE MODEL
  65. correct = 0
  66. total = 0
  67. predictionsLabels = []
  68. predictionsProbabilities = []
  69. true_labels = []
  70. criterion = nn.CrossEntropyLoss(reduction='mean')
  71. model.eval()
  72. # since we're not training, we don't need to calculate the gradients for our outputs
  73. with torch.no_grad():
  74. for data in val_data:
  75. images, labels = data[0].to(model.device), data[1].to(model.device)
  76. # calculate outputs by running images through the network
  77. outputs = model.forward(images)
  78. # the class with the highest energy is what we choose as prediction
  79. loss = criterion(outputs, labels) # mean loss from batch
  80. # Gets accuracy
  81. _, predicted = torch.max(outputs.data, 1)
  82. total += labels.size(0)
  83. correct += (predicted == labels).sum().item()
  84. # Saves predictionsProbabilities and labels for ROC
  85. if(graphs):
  86. predictionsLabels.extend(predicted.cpu().numpy())
  87. predictionsProbabilities.extend(outputs.data[:, 1].cpu().numpy()) # Grabs probability of positive
  88. true_labels.extend(labels.cpu().numpy())
  89. print(f'Accuracy of the network on {total} scans: {100 * correct // total}%')
  90. if(not graphs): print(f'Validation loss: {loss.item()}')
  91. else:
  92. # ROC
  93. # Calculate TPR and FPR
  94. fpr, tpr, thresholds = roc_curve(true_labels, predictionsProbabilities)
  95. # Calculate AUC
  96. roc_auc = auc(fpr, tpr)
  97. plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC: {roc_auc})')
  98. plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
  99. plt.xlim([0.0, 1.005])
  100. plt.ylim([0.0, 1.005])
  101. plt.xlabel('False Positive Rate (1 - Specificity)')
  102. plt.ylabel('True Positive Rate (Sensitivity)')
  103. plt.title('Receiver Operating Characteristic (ROC) Curve')
  104. plt.legend(loc="lower right")
  105. plt.savefig('./ROC.png')
  106. plt.show()
  107. # Calculate confusion matrix
  108. cm = confusion_matrix(true_labels, predictionsLabels)
  109. # Plot confusion matrix
  110. plt.figure(figsize=(8, 6))
  111. sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
  112. plt.xlabel('Predicted labels')
  113. plt.ylabel('True labels')
  114. plt.title('Confusion Matrix')
  115. plt.savefig('./confusion_matrix.png')
  116. plt.show()
  117. # Classification Report
  118. report = classification_report(true_labels, predictionsLabels)
  119. print(report)
  120. model.train()
  121. return(loss.item())
  122. # PREDICT
  123. def predict(model, data):
  124. model.eval()
  125. with torch.no_grad():
  126. for data in data:
  127. images, labels = data[0].to(model.device), data[1].to(model.device)
  128. outputs = model.forward(images)
  129. # the class with the highest energy is what we choose as prediction
  130. _, predicted = torch.max(outputs.data, 1)
  131. model.train()
  132. return (labels, predicted) # RETURNS (true, predicted)