import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import ConfusionMatrixDisplay, roc_curve, roc_auc_score, RocCurveDisplay
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils.preprocess import prepare_datasets


def _batch_to_device(data, device):
    """Unpack a ``(mri, xls, label)`` batch and move each tensor to *device* as float."""
    mri, xls, label = data
    return (
        mri.to(device).float(),
        xls.to(device).float(),
        label.to(device).float(),
    )


def _score_batch(outputs, label):
    """Return ``(predicted, labels, n_correct, n_incorrect)`` for one batch.

    Both *outputs* and *label* are reduced with an argmax over dim 1, so each
    must have a class dimension there (labels presumably one-hot -- confirm
    against the dataset).
    """
    predicted = outputs.detach().argmax(dim=1)
    labels = label.argmax(dim=1)
    n_correct = (predicted == labels).sum().item()
    return predicted, labels, n_correct, labels.numel() - n_correct


def _ratio(numerator, denominator):
    """Divide, yielding NaN instead of raising when *denominator* is zero."""
    return numerator / denominator if denominator else float("nan")


def _artifact_path(directory, model_name, timestamp, suffix):
    """Create *directory* if needed and return the path of an output file in it."""
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(directory, exist_ok=True)
    # os.path.join gives the same result as plain concatenation when
    # *directory* ends with a separator, and also works when it does not.
    return os.path.join(directory, model_name + "_t-" + timestamp + suffix)


def train_model(model, seed, timestamp, epochs, train_loader, val_loader, saved_model_path, model_name, optimizer, criterion, cuda_device=torch.device('cuda:0')):
    """Train *model*, validate after every epoch, save it, and return the metrics.

    Args:
        model: Callable module invoked as ``model((mri, xls))``.
        seed: Only used to build the checkpoint file name.
        timestamp: String used to build the checkpoint file name.
        epochs: Number of passes over *train_loader*.
        train_loader: Iterable of ``(mri, xls, label)`` training batches.
        val_loader: Iterable of ``(mri, xls, label)`` validation batches.
        saved_model_path: Directory the checkpoint is written to (created if
            missing).
        model_name: Prefix of the checkpoint file name.
        optimizer: Optimizer stepped once per training batch.
        criterion: Loss called as ``criterion(outputs, label)``.
        cuda_device: Device the batches are moved to.

    Returns:
        A ``pandas.DataFrame`` indexed by epoch with the columns
        ``train_loss``, ``train_acc``, ``val_loss`` and ``val_acc``. Losses are
        averaged per batch; a metric is NaN when its loader yielded no batches.

    Note:
        The model is left in evaluation mode when this function returns.
    """
    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []

    train_length = len(train_loader)
    val_length = len(val_loader)

    for epoch in range(epochs):
        # Training
        # Re-enable dropout / batch-norm updates, which the validation pass
        # of the previous epoch switched off.
        model.train()
        train_loss = 0.0
        train_corr = 0
        train_incc = 0

        progress = tqdm(
            train_loader,
            total=train_length,
            desc="Epoch " + str(epoch) + "/" + str(epochs),
            unit="batch",
        )
        for data in progress:
            mri, xls, label = _batch_to_device(data, cuda_device)

            optimizer.zero_grad()
            outputs = model((mri, xls))
            loss = criterion(outputs, label)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, _, n_correct, n_incorrect = _score_batch(outputs, label)
            train_corr += n_correct
            train_incc += n_incorrect

        train_losses.append(_ratio(train_loss, train_length))
        train_accs.append(_ratio(train_corr, train_corr + train_incc))

        # Validation
        # Evaluation mode so the reported metrics are not perturbed by
        # dropout or by batch-norm statistics being updated.
        model.eval()
        val_loss = 0.0
        val_corr = 0
        val_incc = 0
        with torch.no_grad():
            for data in val_loader:
                mri, xls, label = _batch_to_device(data, cuda_device)
                outputs = model((mri, xls))
                val_loss += criterion(outputs, label).item()
                _, _, n_correct, n_incorrect = _score_batch(outputs, label)
                val_corr += n_correct
                val_incc += n_incorrect

        val_losses.append(_ratio(val_loss, val_length))
        val_accs.append(_ratio(val_corr, val_corr + val_incc))

    print("--- SAVING MODEL ---")
    os.makedirs(saved_model_path, exist_ok=True)
    checkpoint_name = model_name + "_t-" + timestamp + "_s-" + str(seed) + "_e-" + str(epochs) + ".pkl"
    # The whole module object is pickled, not just its state_dict.
    torch.save(model, os.path.join(saved_model_path, checkpoint_name))

    # One row per epoch holding the training and validation metrics.
    df = pd.DataFrame(
        {
            "train_loss": train_losses,
            "train_acc": train_accs,
            "val_loss": val_losses,
            "val_acc": val_accs,
        }
    )
    df.index.name = "epoch"
    return df


def test_model(model, test_loader, cuda_device=torch.device('cuda:0')):
    """Run *model* over *test_loader* and collect its class predictions.

    The model is put in evaluation mode for the duration of the call and its
    previous training/evaluation mode is restored afterwards.

    Args:
        model: Callable module invoked as ``model((mri, xls))``.
        test_loader: Iterable of ``(mri, xls, label)`` batches.
        cuda_device: Device the batches are moved to.

    Returns:
        A tuple ``(predictions, actual, correct, incorrect)``: two flat lists
        of class indices (argmax of the outputs and of the labels) followed by
        the number of matching and non-matching samples.
    """
    correct = 0
    incorrect = 0
    predictions = []
    actual = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data in tqdm(test_loader, total=len(test_loader), desc="Testing", unit="batch"):
                mri, xls, label = _batch_to_device(data, cuda_device)
                outputs = model((mri, xls))
                predicted, labels, n_correct, n_incorrect = _score_batch(outputs, label)
                correct += n_correct
                incorrect += n_incorrect
                predictions.extend(predicted.tolist())
                actual.extend(labels.tolist())
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)

    return predictions, actual, correct, incorrect


def initalize_dataloaders(mri_path, xls_path, val_split, seed, cuda_device=torch.device('cuda:0'), batch_size=64):
    """Build shuffled train, validation and test ``DataLoader`` objects.

    The datasets come from ``prepare_datasets(mri_path, xls_path, val_split,
    seed)``. The test loader uses a quarter of *batch_size* (integer division).

    Returns:
        ``(train_dataloader, val_dataloader, test_dataloader)``.
    """
    training_data, val_data, test_data = prepare_datasets(mri_path, xls_path, val_split, seed)

    def make_loader(dataset, size):
        # NOTE(review): each loader gets its own generator on *cuda_device*;
        # this presumably relies on the project making CUDA the default
        # device elsewhere -- confirm before changing. The generators are
        # not seeded from *seed*.
        return DataLoader(
            dataset,
            batch_size=size,
            shuffle=True,
            generator=torch.Generator(device=cuda_device),
        )

    train_dataloader = make_loader(training_data, batch_size)
    test_dataloader = make_loader(test_data, batch_size // 4)
    val_dataloader = make_loader(val_data, batch_size)
    return train_dataloader, val_dataloader, test_dataloader


def _plot_metric_curves(train_values, val_values, metric, title, out_file):
    """Draw training and validation curves of one metric and save them to *out_file*."""
    fig, ax = plt.subplots()
    ax.plot(train_values, label="Training " + metric)
    ax.plot(val_values, label="Validation " + metric)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(metric)
    ax.set_title(title)
    ax.legend()
    fig.savefig(out_file)
    plt.close(fig)


def plot_results(train_acc, train_loss, val_acc, val_loss, model_name, timestamp, plot_path):
    """Save one accuracy plot and one loss plot for a training run.

    The files are written into *plot_path* (created if missing) as
    ``<model_name>_t-<timestamp>_acc.png`` and
    ``<model_name>_t-<timestamp>_loss.png``.
    """
    _plot_metric_curves(
        train_acc,
        val_acc,
        "Accuracy",
        "Accuracy of " + model_name + " Model: " + timestamp,
        _artifact_path(plot_path, model_name, timestamp, "_acc.png"),
    )
    _plot_metric_curves(
        train_loss,
        val_loss,
        "Loss",
        "Loss of " + model_name + " Model: " + timestamp,
        _artifact_path(plot_path, model_name, timestamp, "_loss.png"),
    )


def plot_confusion_matrix(predicted, actual, model_name, timestamp, plot_path):
    """Save a confusion matrix as ``<model_name>_t-<timestamp>_confusion_matrix.png``.

    Args:
        predicted: Predicted class labels.
        actual: Ground-truth class labels.
        model_name: Prefix of the output file name.
        timestamp: String included in the output file name.
        plot_path: Output directory (created if missing).
    """
    out_file = _artifact_path(plot_path, model_name, timestamp, "_confusion_matrix.png")
    fig, ax = plt.subplots()
    # scikit-learn expects (y_true, y_pred); passing them the other way round
    # transposes the matrix relative to its "True label"/"Predicted label" axes.
    # from_predictions already draws, so no extra .plot() call is needed.
    ConfusionMatrixDisplay.from_predictions(actual, predicted, ax=ax)
    fig.savefig(out_file)
    plt.close(fig)


def plot_roc_curve(predicted, actual, model_name, timestamp, plot_path):
    """Save a ROC curve as ``<model_name>_t-<timestamp>_roc_curve.png``.

    Args:
        predicted: Scores or predicted labels for the positive class.
        actual: Ground-truth labels; ``roc_curve`` requires a binary target.
        model_name: Prefix of the output file name.
        timestamp: String included in the output file name.
        plot_path: Output directory (created if missing).
    """
    out_file = _artifact_path(plot_path, model_name, timestamp, "_roc_curve.png")

    # Keep the converted arrays: the scores become float64, the labels stay
    # as given so scikit-learn still sees them as discrete classes.
    y_score = np.asarray(predicted, dtype=np.float64)
    y_true = np.asarray(actual)

    fpr, tpr, _ = roc_curve(y_true, y_score)
    print(fpr, tpr)
    auc = roc_auc_score(y_true, y_score)

    # Draw on an explicit axes so exactly one figure is created and closed.
    fig, ax = plt.subplots()
    RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=auc).plot(ax=ax)
    fig.savefig(out_file)
    plt.close(fig)