"""Training, evaluation, data-loading and plotting helpers for the MRI + XLS model.

Every batch yielded by the loaders is unpacked as ``(mri, xls, label)``.
``label`` is compared to the model output via ``argmax`` along dim 1. That
presumably means one-hot / per-class target rows; confirm against
``utils.preprocess.prepare_datasets``.
"""
import os

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils.preprocess import prepare_datasets


def _ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is 0 (e.g. an empty loader)."""
    return numerator / denominator if denominator else float("nan")


def _ensure_dir(path):
    """Create directory ``path`` (and parents) if missing; an empty path means the CWD and is skipped."""
    if path:
        os.makedirs(path, exist_ok=True)


def _to_device(data, cuda_device):
    """Unpack an ``(mri, xls, label)`` batch and move each tensor to ``cuda_device`` as float32."""
    mri, xls, label = data
    return (
        mri.to(cuda_device).float(),
        xls.to(cuda_device).float(),
        label.to(cuda_device).float(),
    )


def _count_correct(outputs, label):
    """Return ``(correct, total)`` for one batch, comparing argmax over dim 1 of outputs and labels."""
    predicted = outputs.argmax(dim=1)
    expected = label.argmax(dim=1)
    correct = (predicted == expected).sum().item()
    return correct, expected.numel()


def _evaluate(model, loader, cuda_device, criterion=None, progress_desc=None):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Args:
        model: Module called as ``model((mri, xls))``.
        loader: Iterable of ``(mri, xls, label)`` batches with a ``len``.
        cuda_device: Device the batch tensors are moved to.
        criterion: Optional loss function; when None the returned loss is NaN.
        progress_desc: When given, iteration is wrapped in a tqdm bar with this label.

    Returns:
        ``(mean_batch_loss, correct, total)``. The loss is averaged over the
        number of batches (mean of per-batch losses), like the training loss.
    """
    # Disable dropout / use running batch-norm statistics while evaluating.
    model.eval()
    batches = loader
    if progress_desc is not None:
        batches = tqdm(loader, total=len(loader), desc=progress_desc, unit="batch")

    loss_sum = 0.0
    correct_sum = 0
    total_sum = 0
    with torch.no_grad():
        for data in batches:
            mri, xls, label = _to_device(data, cuda_device)
            outputs = model((mri, xls))
            if criterion is not None:
                loss_sum += criterion(outputs, label).item()
            correct, total = _count_correct(outputs, label)
            correct_sum += correct
            total_sum += total

    mean_loss = _ratio(loss_sum, len(loader)) if criterion is not None else float("nan")
    return mean_loss, correct_sum, total_sum


def train_model(model, seed, timestamp, epochs, train_loader, val_loader, saved_model_path, model_name, optimizer, criterion, cuda_device=torch.device('cuda:0')):
    """Train ``model`` for ``epochs`` epochs, validating after each, then save its weights.

    Args:
        model: Module called as ``model((mri, xls))``; assumed already on ``cuda_device``.
        seed: Only printed and embedded in the checkpoint file name here.
        timestamp: Embedded in the checkpoint file name.
        epochs: Number of passes over ``train_loader``.
        train_loader / val_loader: Iterables of ``(mri, xls, label)`` batches.
        saved_model_path: Directory for the checkpoint; created if missing.
        model_name: Prefix of the checkpoint file name.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, label)``.
        cuda_device: Device batches are moved to.

    Returns:
        DataFrame indexed by ``epoch`` with columns ``train_loss``, ``train_acc``,
        ``val_loss`` and ``val_acc``. Losses are means of per-batch losses; an
        empty loader yields NaN instead of raising ``ZeroDivisionError``.
        The model is left in eval mode on return.
    """
    print("Length of Training Data: ", len(train_loader))
    print("--- INITIALIZING MODEL ---")
    print("Seed: ", seed)
    print("--- TRAINING MODEL ---")

    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []

    for epoch in range(epochs):
        # Re-enable training behaviour every epoch: validation below switches to eval mode.
        model.train()
        train_loss = 0.0
        train_corr = 0
        train_total = 0
        train_length = len(train_loader)
        for data in tqdm(train_loader, total=train_length, desc=f"Epoch {epoch}", unit="batch"):
            mri, xls, label = _to_device(data, cuda_device)
            optimizer.zero_grad()
            outputs = model((mri, xls))
            loss = criterion(outputs, label)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            correct, total = _count_correct(outputs, label)
            train_corr += correct
            train_total += total

        train_losses.append(_ratio(train_loss, train_length))
        train_accs.append(_ratio(train_corr, train_total))

        val_loss, val_corr, val_total = _evaluate(model, val_loader, cuda_device, criterion=criterion)
        val_losses.append(val_loss)
        val_accs.append(_ratio(val_corr, val_total))

    print("--- SAVING MODEL ---")
    _ensure_dir(saved_model_path)
    # os.path.join so the file lands inside saved_model_path even without a trailing slash.
    checkpoint_name = f"{model_name}_t-{timestamp}_s-{seed}_e-{epochs}.pt"
    torch.save(model.state_dict(), os.path.join(saved_model_path, checkpoint_name))

    df = pd.DataFrame({
        "train_loss": train_losses,
        "train_acc": train_accs,
        "val_loss": val_losses,
        "val_acc": val_accs,
    })
    df.index.name = "epoch"
    return df


def test_model(model, test_loader, cuda_device=torch.device('cuda:0')):
    """Evaluate ``model`` on ``test_loader`` in eval mode, print and return accuracy in percent.

    Returns NaN (and prints it) when the loader yields no samples.
    """
    print("--- TESTING MODEL ---")
    _, correct, total = _evaluate(model, test_loader, cuda_device, progress_desc="Testing")
    accuracy = _ratio(100 * correct, total)
    print("Model Accuracy: ", accuracy)
    return accuracy


def initalize_dataloaders(mri_path, xls_path, val_split, seed, cuda_device=torch.device('cuda:0'), batch_size=64):
    """Build shuffled train / validation / test DataLoaders from ``prepare_datasets``.

    Args:
        mri_path, xls_path, val_split, seed: Forwarded to ``prepare_datasets``.
        seed: Also seeds each loader's shuffle generator (when not None) so the
            shuffle order follows ``seed``. Previously the generators were
            unseeded and shuffle order did not depend on ``seed``.
        cuda_device: Device of the shuffle generators (as in the original code).
        batch_size: Train/validation batch size; the test loader uses a quarter
            of it (at least 1). Defaults to the previous hard-coded 64.

    Returns:
        ``(train_dataloader, val_dataloader, test_dataloader)``.
    """
    training_data, val_data, test_data = prepare_datasets(mri_path, xls_path, val_split, seed)

    def _make_loader(dataset, size):
        # One generator per loader so each loader's shuffle stream is independent.
        generator = torch.Generator(device=cuda_device)
        if seed is not None:
            generator.manual_seed(seed)
        return DataLoader(dataset, batch_size=size, shuffle=True, generator=generator)

    train_dataloader = _make_loader(training_data, batch_size)
    # batch_size // 4 would be 0 (invalid for DataLoader) when batch_size < 4.
    test_dataloader = _make_loader(test_data, max(1, batch_size // 4))
    val_dataloader = _make_loader(val_data, batch_size)
    return train_dataloader, val_dataloader, test_dataloader


def _save_curve_plot(train_values, val_values, metric, model_name, timestamp, out_file):
    """Plot training vs. validation ``metric`` per epoch, save it to ``out_file``, and close the figure."""
    fig, ax = plt.subplots()
    try:
        ax.plot(train_values, label=f"Training {metric}")
        ax.plot(val_values, label=f"Validation {metric}")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(metric)
        ax.set_title(f"{metric} of {model_name} Model: {timestamp}")
        ax.legend()
        fig.savefig(out_file)
    finally:
        # Close explicitly; open pyplot figures otherwise accumulate across calls.
        plt.close(fig)


def plot_results(train_acc, train_loss, val_acc, val_loss, model_name, timestamp, plot_path):
    """Save accuracy and loss curves as ``<model_name>_t-<timestamp>_{acc,loss}.png`` in ``plot_path``.

    ``plot_path`` is created if missing. The files are joined with
    ``os.path.join`` so they land inside it even without a trailing slash.
    """
    _ensure_dir(plot_path)
    _save_curve_plot(
        train_acc, val_acc, "Accuracy", model_name, timestamp,
        os.path.join(plot_path, f"{model_name}_t-{timestamp}_acc.png"),
    )
    _save_curve_plot(
        train_loss, val_loss, "Loss", model_name, timestamp,
        os.path.join(plot_path, f"{model_name}_t-{timestamp}_loss.png"),
    )