training.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. import torch
  2. from tqdm import tqdm
  3. import os
  4. from utils.preprocess import prepare_datasets
  5. from torch.utils.data import DataLoader
  6. import pandas as pd
  7. import matplotlib.pyplot as plt
  8. def train_model(model, seed, timestamp, epochs, train_loader, val_loader, saved_model_path, model_name, optimizer, criterion, cuda_device=torch.device('cuda:0')):
  9. #Print Shape of Image Data
  10. #Print Training Data Length
  11. print("Length of Training Data: ", len(train_loader))
  12. print("--- INITIALIZING MODEL ---")
  13. print("Seed: ", seed)
  14. epoch_number = 0
  15. print("--- TRAINING MODEL ---")
  16. train_losses = []
  17. train_accs = []
  18. val_losses = []
  19. val_accs = []
  20. for epoch in range(epochs):
  21. train_loss = 0
  22. train_incc = 0
  23. train_corr = 0
  24. #Training
  25. train_length = len(train_loader)
  26. for _, data in tqdm(enumerate(train_loader, 0), total=train_length, desc="Epoch " + str(epoch), unit="batch"):
  27. mri, xls, label = data
  28. optimizer.zero_grad()
  29. mri = mri.to(cuda_device).float()
  30. xls = xls.to(cuda_device).float()
  31. label = label.to(cuda_device).float()
  32. outputs = model((mri, xls))
  33. loss = criterion(outputs, label)
  34. loss.backward()
  35. optimizer.step()
  36. train_loss += loss.item()
  37. #Calculate Correct and Incorrect Predictions
  38. _, predicted = torch.max(outputs.data, 1)
  39. _, labels = torch.max(label.data, 1)
  40. train_corr += (predicted == labels).sum().item()
  41. train_incc += (predicted != labels).sum().item()
  42. train_losses.append(train_loss / train_length)
  43. train_accs.append(train_corr / (train_corr + train_incc))
  44. #Validation
  45. with torch.no_grad():
  46. val_loss = 0
  47. val_incc = 0
  48. val_corr = 0
  49. val_length = len(val_loader)
  50. for _, data in enumerate(val_loader, 0):
  51. mri, xls, label = data
  52. mri = mri.to(cuda_device).float()
  53. xls = xls.to(cuda_device).float()
  54. label = label.to(cuda_device).float()
  55. outputs = model((mri, xls))
  56. loss = criterion(outputs, label)
  57. val_loss += loss.item()
  58. _, predicted = torch.max(outputs.data, 1)
  59. _, labels = torch.max(label.data, 1)
  60. val_corr += (predicted == labels).sum().item()
  61. val_incc += (predicted != labels).sum().item()
  62. val_losses.append(val_loss / val_length)
  63. val_accs.append(val_corr / (val_corr + val_incc))
  64. epoch_number += 1
  65. print("--- SAVING MODEL ---")
  66. if not os.path.exists(saved_model_path):
  67. os.makedirs(saved_model_path)
  68. torch.save(model.state_dict(), saved_model_path + model_name + "_t-" + timestamp + "_s-" + str(seed) + "_e-" + str(epochs) + ".pt")
  69. #Create dataframe with training and validation losses and accuracies, set index to epoch
  70. df = pd.DataFrame()
  71. df["train_loss"] = train_losses
  72. df["train_acc"] = train_accs
  73. df["val_loss"] = val_losses
  74. df["val_acc"] = val_accs
  75. df.index.name = "epoch"
  76. return df
  77. def test_model(model, test_loader, cuda_device=torch.device('cuda:0')):
  78. print("--- TESTING MODEL ---")
  79. #Test model
  80. correct = 0
  81. incorrect = 0
  82. with torch.no_grad():
  83. length = len(test_loader)
  84. for i, data in tqdm(enumerate(test_loader, 0), total=length, desc="Testing", unit="batch"):
  85. mri, xls, label = data
  86. mri = mri.to(cuda_device).float()
  87. xls = xls.to(cuda_device).float()
  88. label = label.to(cuda_device).float()
  89. outputs = model((mri, xls))
  90. _, predicted = torch.max(outputs.data, 1)
  91. _, labels = torch.max(label.data, 1)
  92. incorrect += (predicted != labels).sum().item()
  93. correct += (predicted == labels).sum().item()
  94. print("Model Accuracy: ", 100 * correct / (correct + incorrect))
  95. def initalize_dataloaders(mri_path, xls_path, val_split, seed, cuda_device=torch.device('cuda:0')):
  96. training_data, val_data, test_data = prepare_datasets(mri_path, xls_path, val_split, seed)
  97. batch_size = 64
  98. train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, generator=torch.Generator(device=cuda_device))
  99. test_dataloader = DataLoader(test_data, batch_size=(batch_size // 4), shuffle=True, generator=torch.Generator(device=cuda_device))
  100. val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True, generator=torch.Generator(device=cuda_device))
  101. return train_dataloader, val_dataloader, test_dataloader
  102. def plot_results(train_acc, train_loss, val_acc, val_loss, model_name, timestamp, plot_path):
  103. #Create 2 plots, one for accuracy and one for loss
  104. if not os.path.exists(plot_path):
  105. os.makedirs(plot_path)
  106. #Accuracy Plot
  107. plt.figure()
  108. plt.plot(train_acc, label="Training Accuracy")
  109. plt.plot(val_acc, label="Validation Accuracy")
  110. plt.xlabel("Epoch")
  111. plt.ylabel("Accuracy")
  112. plt.title("Accuracy of " + model_name + " Model: " + timestamp)
  113. plt.legend()
  114. plt.savefig(plot_path + model_name + "_t-" + timestamp + "_acc.png")
  115. #Loss Plot
  116. plt.figure()
  117. plt.plot(train_loss, label="Training Loss")
  118. plt.plot(val_loss, label="Validation Loss")
  119. plt.xlabel("Epoch")
  120. plt.ylabel("Loss")
  121. plt.title("Loss of " + model_name + " Model: " + timestamp)
  122. plt.legend()
  123. plt.savefig(plot_path + model_name + "_t-" + timestamp + "_loss.png")