train.py

import numpy as np
import torch
from torch.utils.data import DataLoader
from datareader import DataReader
from sklearn.metrics import roc_auc_score
import time
import os


class Training:
    def __init__(self, main_path_to_data):
        self.main_path_to_data = main_path_to_data

    def train(self, train_info, valid_info, model, hyperparameters, path_to_model):
        """Train the model on train_info and validate it on valid_info after every epoch.

        Args:
            train_info (list): list of paths to one central slice per patient (shuffled)
            valid_info (list): list of paths to 10 central slices per patient (ordered)
            model (nn.Module): architecture of the model
            hyperparameters (dict): dictionary of hyperparameters (learning_rate, weight_decay,
                total_epoch and multiplicator for the exponential LR scheduler)
            path_to_model (string): absolute path to the folder where outputs will be saved

        Returns:
            aucs (list): validation AUC per epoch
            losses (list): running training loss per epoch
        """
        # 0. Check which device is available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # 1. Create folder to save the model weights, AUCs and losses
        try:
            os.mkdir(path_to_model)
        except FileExistsError:
            # Folder already exists, append the current time so nothing gets overwritten
            path_to_model = path_to_model + str(int(time.time()))
            os.mkdir(path_to_model)

        # 2. Load hyperparameters
        learning_rate = hyperparameters['learning_rate']
        weight_decay = hyperparameters['weight_decay']
        total_epoch = hyperparameters['total_epoch']
        multiplicator = hyperparameters['multiplicator']

        # 3. Create train and validation generators, batch_size = 10 for the validation generator (10 central slices per patient)
        train_datareader = DataReader(self.main_path_to_data, train_info)
        train_generator = DataLoader(train_datareader, batch_size=16, shuffle=True, pin_memory=True, num_workers=2)
        valid_datareader = DataReader(self.main_path_to_data, valid_info)
        valid_generator = DataLoader(valid_datareader, batch_size=10, shuffle=False, pin_memory=True, num_workers=2)

        # 4. Move model to the available device
        model.to(device)

        # 5. Define criterion function, optimizer and scheduler
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, multiplicator, last_epoch=-1)

        # 6. Create lists for tracking AUCs and losses during training
        aucs = []
        losses = []
        best_auc = -np.inf

        # 7. Run training
        for epoch in range(total_epoch):
            start = time.time()
            print('Epoch: %d/%d' % (epoch + 1, total_epoch))
            running_loss = 0

            # A) Train model
            model.train()  # put model in training mode
            for item_train in train_generator:
                # Load images (x) and labels (y)
                x, y = item_train
                x = x.to(device)
                y = y.to(device)

                # Forward pass
                optimizer.zero_grad()
                y_hat = model(x)
                loss = criterion(y_hat, y)

                # Backward pass
                loss.backward()
                optimizer.step()

                # Track loss change
                running_loss += loss.item()

            # B) Validate model
            predictions = []
            trues = []
            model.eval()  # put model in eval mode
            for item_valid in valid_generator:
                # Load images (x) and labels (y)
                x, y = item_valid
                x = x.to(device)
                y = y.to(device)

                # Forward pass
                with torch.no_grad():
                    y_hat = model(x)
                    # BCEWithLogitsLoss already embeds the sigmoid during training, so it has to be applied explicitly here
                    y_hat = torch.sigmoid(y_hat)

                predictions.append(np.mean(y_hat.cpu().numpy()))  # mean of the 10 slice-level predictions per patient
                trues.append(int(y.cpu().numpy()[0]))

            auc = roc_auc_score(trues, predictions)

            # C) Track changes, update LR, save best model
            print("AUC: ", auc, ", Running loss: ", running_loss / len(train_generator), ", Time: ", time.time() - start)
            if auc > best_auc:
                torch.save(model.state_dict(), os.path.join(path_to_model, 'trained_model_weights.pth'))
                best_auc = auc
            aucs.append(auc)
            losses.append(running_loss / len(train_generator))
            scheduler.step()

        np.save(os.path.join(path_to_model, 'AUCS.npy'), np.array(aucs))
        np.save(os.path.join(path_to_model, 'LOSSES.npy'), np.array(losses))

        return aucs, losses
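

# --- Illustrative usage (not part of the original script) ---
# A minimal sketch of how Training.train() might be called, assuming a DataReader-compatible
# dataset layout and a binary classifier that outputs a single logit per sample. The model,
# the paths and the info lists below are placeholders, not values from this repository.
if __name__ == '__main__':
    import torch.nn as nn

    # Placeholder model: any nn.Module returning one logit per sample works with BCEWithLogitsLoss.
    model = nn.Sequential(nn.Flatten(), nn.LazyLinear(1))

    hyperparameters = {
        'learning_rate': 1e-4,   # Adam learning rate
        'weight_decay': 1e-5,    # L2 regularisation
        'total_epoch': 20,       # number of training epochs
        'multiplicator': 0.95,   # per-epoch decay factor for ExponentialLR
    }

    trainer = Training('/path/to/data')   # placeholder data folder
    train_info = [...]                    # paths to one central slice per patient (shuffled)
    valid_info = [...]                    # paths to 10 central slices per patient (ordered)
    aucs, losses = trainer.train(train_info, valid_info, model, hyperparameters, '/path/to/output_model')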