"""Training, evaluation and plotting helpers for a model fed ``(mri, xls)`` input pairs.

Every loader/dataset used here yields ``(mri, xls, label)`` items. Labels are
compared to model outputs by argmax over dim 1, and column 1 is read as the
positive class, so labels are presumably one-hot over two classes — TODO confirm
against ``utils.preprocess.prepare_datasets``.
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import ConfusionMatrixDisplay, RocCurveDisplay, roc_auc_score, roc_curve
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils.preprocess import prepare_datasets


def _ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero (e.g. an empty loader)."""
    return numerator / denominator if denominator else float("nan")


def _to_device(batch, device):
    """Unpack an ``(mri, xls, label)`` batch and move each tensor to ``device``."""
    mri, xls, label = batch
    return mri.to(device), xls.to(device), label.to(device)


def _batch_correct(outputs, labels):
    """Count rows whose argmax over dim 1 agrees between ``outputs`` and ``labels``."""
    return (outputs.argmax(dim=1) == labels.argmax(dim=1)).sum().item()


def _evaluate(model, loader, criterion, device):
    """Return ``(mean batch loss, accuracy)`` of ``model`` over ``loader`` in eval mode without gradients."""
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for batch in loader:
            mri, xls, label = _to_device(batch, device)
            outputs = model((mri, xls))
            loss_sum += criterion(outputs, label).item()
            correct += _batch_correct(outputs, label)
            seen += label.shape[0]
    return _ratio(loss_sum, len(loader)), _ratio(correct, seen)


def train_model(model, seed, timestamp, epochs, train_loader, val_loader, saved_model_path, model_name,
                optimizer, criterion, cuda_device=torch.device('cuda:0')):
    """Train ``model`` for ``epochs`` epochs, validating after each epoch, then save it to disk.

    Args:
        model: Module called as ``model((mri, xls))``.
        seed: Only used to tag the checkpoint file name.
        timestamp: Only used to tag the checkpoint file name.
        epochs: Number of passes over ``train_loader``.
        train_loader: Iterable of ``(mri, xls, label)`` batches used for optimisation.
        val_loader: Iterable of ``(mri, xls, label)`` batches used for validation.
        saved_model_path: Directory for the checkpoint; created if missing.
        model_name: Prefix of the checkpoint file name.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, label)``.
        cuda_device: Device every batch tensor is moved to.

    Returns:
        DataFrame with columns ``train_loss``, ``train_acc``, ``val_loss`` and
        ``val_acc`` (one row per epoch, index named ``epoch``). Losses are means
        over batches; accuracies are argmax matches over samples. An empty
        loader yields NaN instead of raising ``ZeroDivisionError``.

    The model is left in eval mode after the final validation pass.
    """
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
    n_train_batches = len(train_loader)

    for epoch in range(epochs):
        # _evaluate switches to eval mode, so re-enable training behaviour
        # (e.g. dropout / batch-norm updates, if the model has them) every epoch.
        model.train()
        loss_sum, correct, seen = 0.0, 0, 0
        progress = tqdm(train_loader, total=n_train_batches, desc=f"Epoch {epoch + 1}/{epochs}", unit="batch")
        for batch in progress:
            mri, xls, label = _to_device(batch, cuda_device)
            optimizer.zero_grad()
            outputs = model((mri, xls))
            loss = criterion(outputs, label)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            correct += _batch_correct(outputs, label)
            seen += label.shape[0]
        history["train_loss"].append(_ratio(loss_sum, n_train_batches))
        history["train_acc"].append(_ratio(correct, seen))

        val_loss, val_acc = _evaluate(model, val_loader, criterion, cuda_device)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

    print("--- SAVING MODEL ---")
    os.makedirs(saved_model_path, exist_ok=True)
    checkpoint = os.path.join(saved_model_path, f"{model_name}_t-{timestamp}_s-{seed}_e-{epochs}.pkl")
    # NOTE(review): the save call was garbled in the source; reconstructed as a
    # state_dict save. Confirm that whatever loads these .pkl files expects a state_dict.
    torch.save(model.state_dict(), checkpoint)

    df = pd.DataFrame(history)
    df.index.name = "epoch"
    return df


def test_model(model, test_loader, cuda_device=torch.device('cuda:0')):
    """Run ``model`` over ``test_loader`` in eval mode and collect predictions.

    Returns:
        ``(predictions, actual, correct, incorrect, max_preds, max_actuals)``:
        column 1 of the outputs / labels per sample (the comment in the original
        notes the model ends in softmax, so this is the positive-class score),
        counts of argmax matches / mismatches, and the argmax class indices of
        outputs and labels per sample.
    """
    model.eval()
    correct = incorrect = 0
    predictions, actual, max_preds, max_actuals = [], [], [], []
    with torch.no_grad():
        for batch in tqdm(test_loader, total=len(test_loader), desc="Testing", unit="batch"):
            mri, xls, labels = _to_device(batch, cuda_device)
            outputs = model((mri, xls))

            pred_idx = outputs.argmax(dim=1)
            true_idx = labels.argmax(dim=1)
            n_correct = (pred_idx == true_idx).sum().item()
            correct += n_correct
            incorrect += labels.shape[0] - n_correct

            # Only two classes, so column 1 alone carries the positive-class signal.
            predictions.extend(outputs[:, 1].tolist())
            actual.extend(labels[:, 1].tolist())
            max_preds.extend(pred_idx.tolist())
            max_actuals.extend(true_idx.tolist())
    return predictions, actual, correct, incorrect, max_preds, max_actuals


def initalize_dataloaders(mri_path, xls_path, val_split, seed, cuda_device=torch.device('cuda:0'), batch_size=64):
    """Build shuffled train/val/test DataLoaders from ``prepare_datasets``.

    The test loader uses a quarter of ``batch_size`` (at least 1). Each loader's
    shuffle generator lives on ``cuda_device`` and is seeded with ``seed`` (when
    it is not None) so shuffling is reproducible per seed.

    Returns:
        ``(train_dataloader, val_dataloader, test_dataloader, test_data)``.
    """
    training_data, val_data, test_data = prepare_datasets(mri_path, xls_path, val_split, seed)

    def _generator():
        generator = torch.Generator(device=cuda_device)
        if seed is not None:
            generator.manual_seed(seed)
        return generator

    train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, generator=_generator())
    # DataLoader rejects batch_size=0, which batch_size // 4 yields for batch_size < 4.
    test_dataloader = DataLoader(test_data, batch_size=max(batch_size // 4, 1), shuffle=True,
                                 generator=_generator())
    val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True, generator=_generator())
    return train_dataloader, val_dataloader, test_dataloader, test_data


def _save_curve_plot(series, ylabel, title, path):
    """Plot ``(values, label)`` pairs against epoch on a fresh figure, save it to ``path`` and close it."""
    fig, ax = plt.subplots()
    for values, label in series:
        ax.plot(values, label=label)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    fig.savefig(path)
    plt.close(fig)


def plot_results(train_acc, train_loss, val_acc, val_loss, model_name, timestamp, plot_path):
    """Save an accuracy plot and a loss plot (train vs. validation per epoch) into ``plot_path``."""
    os.makedirs(plot_path, exist_ok=True)
    prefix = os.path.join(plot_path, f"{model_name}_t-{timestamp}")
    _save_curve_plot(
        [(train_acc, "Training Accuracy"), (val_acc, "Validation Accuracy")],
        "Accuracy",
        "Accuracy of " + model_name + " Model: " + timestamp,
        prefix + "_acc.png",
    )
    _save_curve_plot(
        [(train_loss, "Training Loss"), (val_loss, "Validation Loss")],
        "Loss",
        "Loss of " + model_name + " Model: " + timestamp,
        prefix + "_loss.png",
    )


def plot_confusion_matrix(predicted, actual, model_name, timestamp, plot_path):
    """Save a confusion matrix of ``predicted`` vs. ``actual`` class labels into ``plot_path``."""
    os.makedirs(plot_path, exist_ok=True)
    # from_predictions takes (y_true, y_pred) and already draws the figure;
    # passing them swapped transposed the matrix, and a second .plot() leaked a figure.
    display = ConfusionMatrixDisplay.from_predictions(actual, predicted)
    display.figure_.savefig(os.path.join(plot_path, f"{model_name}_t-{timestamp}_confusion_matrix.png"))
    plt.close(display.figure_)


def plot_roc_curve(predicted, actual, model_name, timestamp, plot_path):
    """Save the ROC curve (with AUC) of positive-class scores ``predicted`` against ``actual`` labels."""
    os.makedirs(plot_path, exist_ok=True)
    predicted = np.asarray(predicted, dtype=np.float64)
    actual = np.asarray(actual, dtype=np.float64)
    fpr, tpr, _ = roc_curve(actual, predicted)
    auc = roc_auc_score(actual, predicted)
    # Draw on an explicit axes so no stray empty figure is left open.
    fig, ax = plt.subplots()
    RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=auc).plot(ax=ax)
    fig.savefig(os.path.join(plot_path, f"{model_name}_t-{timestamp}_roc_curve.png"))
    plt.close(fig)


def plot_image_selection(model, test_set, model_name, timestamp, plot_path, cuda_device=torch.device('cuda:0')):
    """Plot 8 random test samples in a 2x4 grid, titled with the model's column-1 output and the true label."""
    os.makedirs(plot_path, exist_ok=True)
    # Samples are drawn with replacement, so the same item can appear twice.
    samples = [test_set[np.random.randint(0, len(test_set))] for _ in range(8)]

    model.eval()
    fig, axs = plt.subplots(2, 4)
    with torch.no_grad():
        for ax, (mri, xls, label) in zip(axs.flat, samples):
            mri = mri.to(cuda_device).unsqueeze(0)
            xls = xls.to(cuda_device).unsqueeze(0)
            prediction = model((mri, xls))[:, 1]
            # NOTE(review): slice expression reconstructed from a garbled source: drop the
            # batch dim, take index 80 along dim 3, then move dim 0 last for imshow. Confirm
            # against the volume layout produced by prepare_datasets.
            sliced_image = torch.permute(torch.select(torch.squeeze(mri, 0), 3, 80), (1, 2, 0)).cpu().numpy()
            ax.imshow(sliced_image, cmap="gray")
            ax.set_title("Pr: " + str(round(prediction.item(), 3)) + ", \nAc: " + str(label[1].item()))
    fig.savefig(os.path.join(plot_path, f"{model_name}_t-{timestamp}_image_selection.png"))
    plt.close(fig)