|
@@ -0,0 +1,124 @@
|
|
|
+import numpy as np
|
|
|
+import torch
|
|
|
+from torch.utils.data import DataLoader
|
|
|
+from datareader import DataReader
|
|
|
+from sklearn.metrics import roc_auc_score
|
|
|
+import time
|
|
|
+import os
|
|
|
+
|
|
|
+class Training():
|
|
|
+ def __init__(self, main_path_to_data):
|
|
|
+ self.main_path_to_data = main_path_to_data
|
|
|
+
|
|
|
+ def train(self, train_info, valid_info, model, hyperparameters, path_to_model):
|
|
|
+ """Function for training the model on train_info.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ train_info (list): list of paths to one central slice per patient (shuffled)
|
|
|
+ valid_info (list): list of paths to 10 central slices per patient (ordered)
|
|
|
+ model (nn.Module): architecture of the model
|
|
|
+ hyperparameters (dictionary): dictionary of hyperparameters (learning rate, weight decay, multiplicator)
|
|
|
+ path_to_model (string): absolute path to the folder where outputs will be saved
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ aucs (ndarray): array of AUCs during validation (one AUC per epoch)
|
|
|
+ losses (ndarray): array of LOSSes during training (one running loss per epoch)
|
|
|
+ """
|
|
|
+ # 0. Check which device is available
|
|
|
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
+
|
|
|
+ # 1. Create folder to save the model weights, aucs and losses
|
|
|
+ try:
|
|
|
+ os.mkdir(path_to_model)
|
|
|
+ except: # folder already exists, add current time to avoid duplication
|
|
|
+ os.mkdir(path_to_model + str(int(time.time())))
|
|
|
+
|
|
|
+ # 2. Load hyperparameters
|
|
|
+ learning_rate = hyperparameters['learning_rate']
|
|
|
+ weight_decay = hyperparameters['weight_decay']
|
|
|
+ total_epoch = hyperparameters['total_epoch']
|
|
|
+ multiplicator = hyperparameters['multiplicator']
|
|
|
+
|
|
|
+ # 4. Create train and validation generators, batch_size = 10 for validation generator (10 central slices)
|
|
|
+ train_datareader = DataReader(self.main_path_to_data, train_info)
|
|
|
+ train_generator = DataLoader(train_datareader, batch_size=16, shuffle=True, pin_memory=True, num_workers=2)
|
|
|
+
|
|
|
+ valid_datareader = DataReader(self.main_path_to_data, valid_info)
|
|
|
+ valid_generator = DataLoader(valid_datareader, batch_size=10, shuffle=False, pin_memory=True, num_workers=2)
|
|
|
+
|
|
|
+ # 5. Move model to the available device
|
|
|
+ model.to(device)
|
|
|
+
|
|
|
+ # 6. Define criterion function, optimizer and scheduler
|
|
|
+ criterion = torch.nn.BCEWithLogitsLoss()
|
|
|
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
|
|
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, multiplicator, last_epoch=-1)
|
|
|
+
|
|
|
+ # 7. Creat lists for tracking AUC and Losses during training
|
|
|
+ aucs = []
|
|
|
+ losses = []
|
|
|
+ best_auc = -np.inf
|
|
|
+
|
|
|
+ # 8. Run training
|
|
|
+ for epoch in range(total_epoch):
|
|
|
+ start = time.time()
|
|
|
+ print('Epoch: %d/%d' % (epoch + 1, total_epoch))
|
|
|
+
|
|
|
+ running_loss = 0
|
|
|
+ # A) Train model
|
|
|
+ model.train() # put model in training mode
|
|
|
+ for item_train in train_generator:
|
|
|
+ # Load images (x) and labels (y)
|
|
|
+ x, y = item_train
|
|
|
+ x = x.to(device)
|
|
|
+ y = y.to(device)
|
|
|
+
|
|
|
+ # Forward pass
|
|
|
+ optimizer.zero_grad()
|
|
|
+ y_hat = model.forward(x)
|
|
|
+ loss = criterion(y_hat, y)
|
|
|
+
|
|
|
+ # Backward pass
|
|
|
+ loss.backward()
|
|
|
+ optimizer.step()
|
|
|
+
|
|
|
+ # Track loss change
|
|
|
+ running_loss += loss.item()
|
|
|
+
|
|
|
+ # B) Validate model
|
|
|
+ predictions = []
|
|
|
+ trues = []
|
|
|
+
|
|
|
+ model.eval() # put model in eval mode
|
|
|
+ for item_valid in valid_generator:
|
|
|
+ # Load images (x) and labels (y)
|
|
|
+ x, y = item_valid
|
|
|
+ x = x.to(device)
|
|
|
+ y = y.to(device)
|
|
|
+
|
|
|
+ # Forward pass
|
|
|
+ with torch.no_grad():
|
|
|
+ y_hat = model.forward(x)
|
|
|
+ y_hat = torch.sigmoid(y_hat) # In training we are using BCEWithLogitsLoss for improved performance (sigmoid is already embedded), here we have to add it
|
|
|
+
|
|
|
+ predictions.append(np.mean(y_hat.cpu().numpy())) # Calculate mean of 10 predictions
|
|
|
+ trues.append(int(y.cpu().numpy()[0]))
|
|
|
+
|
|
|
+ auc = roc_auc_score(trues, predictions)
|
|
|
+
|
|
|
+ # C) Track changes, update LR, save best model
|
|
|
+ print("AUC: ", auc, ", Running loss: ", running_loss/len(train_generator), ", Time: ", time.time()-start)
|
|
|
+
|
|
|
+ if auc > best_auc:
|
|
|
+ torch.save(model.state_dict(), os.path.join(path_to_model, 'trained_model_weights.pth'))
|
|
|
+ best_auc = auc
|
|
|
+
|
|
|
+
|
|
|
+ aucs.append(auc)
|
|
|
+ losses.append(running_loss/len(train_generator))
|
|
|
+ scheduler.step()
|
|
|
+
|
|
|
+ np.save(os.path.join(path_to_model, 'AUCS.npy'), np.array(aucs))
|
|
|
+ np.save(os.path.join(path_to_model, 'LOSSES.npy'), np.array(losses))
|
|
|
+
|
|
|
+ return aucs, losses
|