training.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. import torch
  2. from tqdm import tqdm
  3. import os
  4. from utils.preprocess import prepare_datasets
  5. from torch.utils.data import DataLoader
  6. import pandas as pd
  7. import matplotlib.pyplot as plt
  8. from sklearn.metrics import ConfusionMatrixDisplay, roc_curve, roc_auc_score, RocCurveDisplay
  9. import numpy as np
  def _safe_ratio(numerator, denominator):
      """Return ``numerator / denominator``, or NaN when ``denominator`` is 0."""
      return numerator / denominator if denominator else float("nan")


  def _run_epoch(model, batches, num_batches, criterion, cuda_device, optimizer=None):
      """Make one pass over ``batches`` and return ``(mean_batch_loss, accuracy)``.

      Each batch is an ``(mri, xls, label)`` triple; the model is called with the
      tuple ``(mri, xls)``. When ``optimizer`` is given, every batch is
      back-propagated and an optimizer step is taken; otherwise the pass only
      evaluates (the caller wraps it in ``torch.no_grad()``).
      The loss is averaged over ``num_batches``; accuracy is the fraction of
      samples whose arg-max output matches the arg-max label.
      """
      total_loss = 0.0
      correct = 0
      seen = 0
      for mri, xls, label in batches:
          mri = mri.to(cuda_device).float()
          xls = xls.to(cuda_device).float()
          label = label.to(cuda_device).float()
          if optimizer is not None:
              optimizer.zero_grad()
          outputs = model((mri, xls))
          loss = criterion(outputs, label)
          if optimizer is not None:
              loss.backward()
              optimizer.step()
          total_loss += loss.item()
          # Labels are reduced with arg-max along dim 1, i.e. treated as one-hot /
          # per-class scores (assumption inherited from the original code — TODO confirm).
          predicted = outputs.argmax(dim=1)
          target = label.argmax(dim=1)
          correct += (predicted == target).sum().item()
          seen += target.numel()
      return _safe_ratio(total_loss, num_batches), _safe_ratio(correct, seen)


  def train_model(model, seed, timestamp, epochs, train_loader, val_loader, saved_model_path, model_name, optimizer, criterion, cuda_device=torch.device('cuda:0')):
      """Train ``model``, validate it after every epoch, save it, and return the metrics.

      Args:
          model: module called as ``model((mri, xls))``.
          seed: only used to build the saved file name.
          timestamp: string embedded in the saved file name.
          epochs: number of passes over ``train_loader``.
          train_loader / val_loader: iterables of ``(mri, xls, label)`` batches
              supporting ``len()``.
          saved_model_path: directory the model is pickled into (created if missing).
          model_name: prefix of the saved file name.
          optimizer / criterion: optimizer and loss function used for training.
          cuda_device: device every batch tensor is moved to.

      Returns:
          ``pandas.DataFrame`` indexed by ``epoch`` with columns ``train_loss``,
          ``train_acc``, ``val_loss`` and ``val_acc``. A metric is NaN when its
          loader yields no batches/samples.

      Side effects:
          Writes ``<saved_model_path>/<model_name>_t-<timestamp>_s-<seed>_e-<epochs>.pkl``
          via ``torch.save(model, ...)``. The model is left in eval mode.
      """
      train_losses, train_accs = [], []
      val_losses, val_accs = [], []
      for epoch in range(epochs):
          # Training pass: dropout / batch-norm must be in training mode.
          model.train()
          train_length = len(train_loader)
          progress = tqdm(train_loader, total=train_length,
                          desc="Epoch " + str(epoch) + "/" + str(epochs), unit="batch")
          train_loss, train_acc = _run_epoch(model, progress, train_length, criterion,
                                             cuda_device, optimizer)
          train_losses.append(train_loss)
          train_accs.append(train_acc)

          # Validation pass: inference behaviour and no gradient tracking.
          model.eval()
          with torch.no_grad():
              val_loss, val_acc = _run_epoch(model, val_loader, len(val_loader),
                                             criterion, cuda_device)
          val_losses.append(val_loss)
          val_accs.append(val_acc)

      print("--- SAVING MODEL ---")
      os.makedirs(saved_model_path, exist_ok=True)
      file_name = model_name + "_t-" + timestamp + "_s-" + str(seed) + "_e-" + str(epochs) + ".pkl"
      # os.path.join keeps the file inside the directory even without a trailing separator.
      torch.save(model, os.path.join(saved_model_path, file_name))

      df = pd.DataFrame({
          "train_loss": train_losses,
          "train_acc": train_accs,
          "val_loss": val_losses,
          "val_acc": val_accs,
      })
      df.index.name = "epoch"
      return df
  def test_model(model, test_loader, cuda_device=torch.device('cuda:0')):
      """Run ``model`` over ``test_loader`` and collect predictions for evaluation.

      The model is switched to eval mode for the pass, so dropout / batch-norm
      behave as at inference time. Afterwards it is restored to the mode the
      caller had it in.

      Returns:
          ``(predictions, actual, correct, incorrect, max_preds, max_actuals)``:
          - ``predictions``: output column 1 per sample. Per the original note this
            is the positive-class softmax score of a 2-class model (TODO confirm).
          - ``actual``: label column 1 per sample.
          - ``correct`` / ``incorrect``: number of samples whose arg-max output
            does / does not match the arg-max label.
          - ``max_preds`` / ``max_actuals``: arg-max class index of output / label.
      """
      correct = 0
      incorrect = 0
      predictions = []
      actual = []
      max_preds = []
      max_actuals = []
      was_training = model.training
      model.eval()
      try:
          with torch.no_grad():
              for mri, xls, labels in tqdm(test_loader, total=len(test_loader), desc="Testing", unit="batch"):
                  mri = mri.to(cuda_device).float()
                  xls = xls.to(cuda_device).float()
                  labels = labels.to(cuda_device).float()
                  outputs = model((mri, xls))
                  # Arg-max class indices, computed once and reused for both the
                  # counts and the returned index lists.
                  pred_cls = outputs.argmax(dim=1)
                  true_cls = labels.argmax(dim=1)
                  matches = pred_cls == true_cls
                  correct += matches.sum().item()
                  incorrect += (~matches).sum().item()
                  predictions.extend(outputs[:, 1].tolist())
                  actual.extend(labels[:, 1].tolist())
                  max_preds.extend(pred_cls.tolist())
                  max_actuals.extend(true_cls.tolist())
      finally:
          model.train(was_training)
      return predictions, actual, correct, incorrect, max_preds, max_actuals
  def initalize_dataloaders(mri_path, xls_path, val_split, seed, cuda_device=torch.device('cuda:0'), batch_size=64):
      """Split the data via ``prepare_datasets`` and wrap each split in a shuffling DataLoader.

      Each loader gets its own ``torch.Generator`` on ``cuda_device``, seeded with
      ``seed``. This makes the shuffle order follow ``seed``; an unseeded
      generator always starts from the same default seed. The test loader uses a
      quarter of ``batch_size``, clamped to at least 1 so small batch sizes stay valid.

      Returns:
          ``(train_dataloader, val_dataloader, test_dataloader, test_data)``.
      """
      training_data, val_data, test_data = prepare_datasets(mri_path, xls_path, val_split, seed)

      def make_loader(dataset, size):
          # One independent, seeded generator per loader.
          generator = torch.Generator(device=cuda_device)
          if seed is not None:
              generator.manual_seed(seed)
          return DataLoader(dataset, batch_size=size, shuffle=True, generator=generator)

      train_dataloader = make_loader(training_data, batch_size)
      test_dataloader = make_loader(test_data, max(1, batch_size // 4))
      val_dataloader = make_loader(val_data, batch_size)
      return train_dataloader, val_dataloader, test_dataloader, test_data
  def plot_results(train_acc, train_loss, val_acc, val_loss, model_name, timestamp, plot_path):
      """Save the accuracy and loss learning curves as two PNG files.

      Writes ``<model_name>_t-<timestamp>_acc.png`` and
      ``<model_name>_t-<timestamp>_loss.png`` inside ``plot_path``, creating the
      directory if needed. Each figure plots the training and validation series
      against their list index (the epoch).
      """
      os.makedirs(plot_path, exist_ok=True)
      curves = (
          ("Accuracy", train_acc, val_acc, "_acc.png"),
          ("Loss", train_loss, val_loss, "_loss.png"),
      )
      for metric, train_values, val_values, suffix in curves:
          fig, ax = plt.subplots()
          try:
              ax.plot(train_values, label="Training " + metric)
              ax.plot(val_values, label="Validation " + metric)
              ax.set_xlabel("Epoch")
              ax.set_ylabel(metric)
              ax.set_title(metric + " of " + model_name + " Model: " + timestamp)
              ax.legend()
              # os.path.join keeps the file inside plot_path even without a trailing separator.
              fig.savefig(os.path.join(plot_path, model_name + "_t-" + timestamp + suffix))
          finally:
              plt.close(fig)
  def plot_confusion_matrix(predicted, actual, model_name, timestamp, plot_path):
      """Save a confusion matrix of ``actual`` (rows) vs ``predicted`` (columns) labels.

      Writes ``<model_name>_t-<timestamp>_confusion_matrix.png`` inside
      ``plot_path``, creating the directory if needed.
      """
      os.makedirs(plot_path, exist_ok=True)
      # from_predictions takes (y_true, y_pred) and already draws the matrix on a
      # new figure. A second .plot() call would open another figure and leak the first.
      display = ConfusionMatrixDisplay.from_predictions(actual, predicted)
      try:
          display.figure_.savefig(
              os.path.join(plot_path, model_name + "_t-" + timestamp + "_confusion_matrix.png"))
      finally:
          plt.close(display.figure_)
  def plot_roc_curve(predicted, actual, model_name, timestamp, plot_path):
      """Save the ROC curve (with its AUC) of scores ``predicted`` against binary labels ``actual``.

      Writes ``<model_name>_t-<timestamp>_roc_curve.png`` inside ``plot_path``,
      creating the directory if needed. ``roc_auc_score`` raises ``ValueError``
      when ``actual`` contains only one class.
      """
      os.makedirs(plot_path, exist_ok=True)
      # Keep the converted arrays; the original conversions were discarded no-ops.
      predicted = np.asarray(predicted, dtype=np.float64)
      actual = np.asarray(actual, dtype=np.float64)
      fpr, tpr, _ = roc_curve(actual, predicted)
      auc = roc_auc_score(actual, predicted)
      # RocCurveDisplay.plot() creates its own figure, so no plt.figure() beforehand.
      display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=auc).plot()
      try:
          display.figure_.savefig(
              os.path.join(plot_path, model_name + "_t-" + timestamp + "_roc_curve.png"))
      finally:
          plt.close(display.figure_)
  def plot_image_selection(model, test_set, model_name, timestamp, plot_path, cuda_device=torch.device('cuda:0'), slice_index=80):
      """Plot 8 random test samples in a 2x4 grid, titled with prediction and label.

      Samples are drawn with replacement via ``np.random.randint``. For each one,
      the model output column 1 ("Pr") and label element 1 ("Ac") are shown. The
      image is ``mri`` with index ``slice_index`` selected on its dim 3, permuted
      to channel-last. The volume layout is assumed to be (C, H, W, D) — TODO confirm.
      Inference runs in eval mode under ``torch.no_grad()``, and the model's
      previous mode is restored afterwards. Writes
      ``<model_name>_t-<timestamp>_image_selection.png`` inside ``plot_path``.
      """
      os.makedirs(plot_path, exist_ok=True)
      samples = [test_set[np.random.randint(0, len(test_set))] for _ in range(8)]
      # Only the subplots figure is created (a preceding plt.figure() used to leak).
      fig, axs = plt.subplots(2, 4)
      was_training = model.training
      model.eval()
      try:
          with torch.no_grad():
              for i, (mri, xls, label) in enumerate(samples):
                  mri = mri.to(cuda_device).float().unsqueeze(0)
                  xls = xls.to(cuda_device).float().unsqueeze(0)
                  prediction = model((mri, xls))[:, 1]
                  sliced_image = torch.permute(
                      torch.select(torch.squeeze(mri, 0), 3, slice_index), (1, 2, 0)).cpu().numpy()
                  ax = axs[i // 4, i % 4]
                  ax.imshow(sliced_image, cmap="gray")
                  ax.set_title("Pr: " + str(round(prediction.item(), 3)) + ", \nAc: " + str(label[1].item()))
          fig.savefig(os.path.join(plot_path, model_name + "_t-" + timestamp + "_image_selection.png"))
      finally:
          model.train(was_training)
          plt.close(fig)