123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- import torch
- from tqdm import tqdm
- import os
- from utils.preprocess import prepare_datasets
- from torch.utils.data import DataLoader
- import pandas as pd
- import matplotlib.pyplot as plt
- from sklearn.metrics import ConfusionMatrixDisplay, roc_curve, roc_auc_score, RocCurveDisplay
- import numpy as np
def _run_epoch(model, loader, criterion, device, optimizer=None, desc=None):
    """Run one pass over ``loader``: a training pass if ``optimizer`` is given, else evaluation.

    Each batch is unpacked as ``(mri, xls, label)``. All three tensors are moved to
    ``device`` and cast to float, and the model is called as ``model((mri, xls))``.
    Accuracy compares ``argmax`` over dim 1 of the outputs and of ``label``. That
    assumes ``label`` is one-hot (or per-class scores) — TODO confirm against
    ``prepare_datasets``.

    Returns:
        ``(mean_batch_loss, accuracy)``. Both are ``nan`` when the loader yields no
        batches, instead of raising ``ZeroDivisionError``.
    """
    training = optimizer is not None
    # Select the correct behaviour for dropout / batch-norm in each phase.
    model.train(training)

    batches = loader
    if desc is not None:
        batches = tqdm(loader, total=len(loader), desc=desc, unit="batch")

    total_loss = 0.0
    correct = 0
    seen = 0
    num_batches = 0
    # Autograd bookkeeping is only needed for the training pass.
    with torch.set_grad_enabled(training):
        for mri, xls, label in batches:
            mri = mri.to(device).float()
            xls = xls.to(device).float()
            label = label.to(device).float()

            if training:
                optimizer.zero_grad()
            outputs = model((mri, xls))
            loss = criterion(outputs, label)
            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            labels = label.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            seen += labels.numel()
            num_batches += 1

    mean_loss = total_loss / num_batches if num_batches else float("nan")
    accuracy = correct / seen if seen else float("nan")
    return mean_loss, accuracy


def train_model(model, seed, timestamp, epochs, train_loader, val_loader, saved_model_path, model_name, optimizer, criterion, cuda_device=torch.device('cuda:0')):
    """Train ``model`` for ``epochs`` epochs, validate after each epoch, then save it.

    The model is put in ``train()`` mode for every training pass. For validation it
    is put in ``eval()`` mode and run with gradients disabled. It is left in eval
    mode when saved and returned.

    Args:
        model: Module to train; called as ``model((mri, xls))``.
        seed: Used only to build the saved file name.
        timestamp: Used only to build the saved file name.
        epochs: Number of epochs to run.
        train_loader: Loader yielding ``(mri, xls, label)`` training batches.
        val_loader: Loader yielding ``(mri, xls, label)`` validation batches.
        saved_model_path: Directory the model is saved into (created if missing).
        model_name: Prefix of the saved file name.
        optimizer: Optimizer stepping the model's parameters.
        criterion: Loss, called as ``criterion(outputs, label)``.
        cuda_device: Device every batch is moved to.

    Returns:
        pandas.DataFrame with columns ``train_loss``, ``train_acc``, ``val_loss`` and
        ``val_acc``, one row per epoch, and an index named ``"epoch"``. Losses are
        means over batches; accuracies are fractions of correctly classified samples.

    Side effects:
        Writes ``<model_name>_t-<timestamp>_s-<seed>_e-<epochs>.pkl`` into
        ``saved_model_path`` via ``torch.save`` (the whole pickled module).
    """
    train_losses, train_accs = [], []
    val_losses, val_accs = [], []

    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, cuda_device,
            optimizer=optimizer, desc=f"Epoch {epoch + 1}/{epochs}",
        )
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        val_loss, val_acc = _run_epoch(model, val_loader, criterion, cuda_device)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

    print("--- SAVING MODEL ---")
    os.makedirs(saved_model_path, exist_ok=True)
    # os.path.join works whether or not saved_model_path ends with a separator.
    # Plain concatenation wrote the file next to the directory when it did not.
    file_name = f"{model_name}_t-{timestamp}_s-{seed}_e-{epochs}.pkl"
    torch.save(model, os.path.join(saved_model_path, file_name))

    df = pd.DataFrame({
        "train_loss": train_losses,
        "train_acc": train_accs,
        "val_loss": val_losses,
        "val_acc": val_accs,
    })
    df.index.name = "epoch"
    return df
-
def test_model(model, test_loader, cuda_device=torch.device('cuda:0')):
    """Run ``model`` over ``test_loader`` and collect its class predictions.

    Each batch is unpacked as ``(mri, xls, label)``; the tensors are moved to
    ``cuda_device``, cast to float, and the model is called as ``model((mri, xls))``.
    Predicted and true classes are ``argmax`` over dim 1 of the outputs and of
    ``label`` respectively (assumes one-hot labels — TODO confirm).

    The model is switched to ``eval()`` mode first and stays in it afterwards.

    Returns:
        ``(predictions, actual, correct, incorrect)``:
        - ``predictions``: list of predicted class indices, in loader order.
        - ``actual``: list of true class indices, in loader order.
        - ``correct`` / ``incorrect``: number of matching and mismatching samples.
    """
    # Evaluation mode: without this, dropout / batch-norm ran in training mode.
    model.eval()

    predictions = []
    actual = []
    correct = 0
    incorrect = 0
    with torch.no_grad():
        for mri, xls, label in tqdm(test_loader, total=len(test_loader), desc="Testing", unit="batch"):
            mri = mri.to(cuda_device).float()
            xls = xls.to(cuda_device).float()
            label = label.to(cuda_device).float()
            outputs = model((mri, xls))

            predicted = outputs.argmax(dim=1)
            labels = label.argmax(dim=1)
            hits = (predicted == labels).sum().item()
            correct += hits
            incorrect += labels.numel() - hits

            predictions.extend(predicted.tolist())
            actual.extend(labels.tolist())

    return predictions, actual, correct, incorrect
def initalize_dataloaders(mri_path, xls_path, val_split, seed, cuda_device=torch.device('cuda:0'), batch_size=64):
    """Build shuffled train / validation / test DataLoaders from ``prepare_datasets``.

    Args:
        mri_path, xls_path, val_split, seed: Passed straight to ``prepare_datasets``.
        cuda_device: Device of the ``torch.Generator`` each loader shuffles with.
        batch_size: Batch size for the train and validation loaders. The test loader
            uses ``batch_size // 4``, clamped to at least 1.

    Returns:
        ``(train_dataloader, val_dataloader, test_dataloader)``.
    """
    training_data, val_data, test_data = prepare_datasets(mri_path, xls_path, val_split, seed)

    def make_loader(dataset, size):
        # NOTE(review): a generator on cuda_device presumably relies on the default
        # tensor device being set to CUDA elsewhere so the sampler can use it — confirm.
        return DataLoader(dataset, batch_size=size, shuffle=True, generator=torch.Generator(device=cuda_device))

    train_dataloader = make_loader(training_data, batch_size)
    # Clamp to 1: for batch_size < 4 the quarter size would be 0, which DataLoader rejects.
    test_dataloader = make_loader(test_data, max(1, batch_size // 4))
    val_dataloader = make_loader(val_data, batch_size)
    return train_dataloader, val_dataloader, test_dataloader
def _save_curve_plot(train_values, val_values, metric, model_name, timestamp, out_file):
    """Plot a training/validation pair of one metric against epoch and save it to ``out_file``."""
    fig, ax = plt.subplots()
    try:
        ax.plot(train_values, label="Training " + metric)
        ax.plot(val_values, label="Validation " + metric)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(metric)
        ax.set_title(metric + " of " + model_name + " Model: " + timestamp)
        ax.legend()
        fig.savefig(out_file)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def plot_results(train_acc, train_loss, val_acc, val_loss, model_name, timestamp, plot_path):
    """Save accuracy and loss curves (training vs. validation, per epoch) as two PNGs.

    Writes ``<model_name>_t-<timestamp>_acc.png`` and
    ``<model_name>_t-<timestamp>_loss.png`` into ``plot_path``, creating the
    directory if it does not exist.
    """
    os.makedirs(plot_path, exist_ok=True)
    # os.path.join also works when plot_path lacks a trailing separator.
    base = os.path.join(plot_path, model_name + "_t-" + timestamp)
    _save_curve_plot(train_acc, val_acc, "Accuracy", model_name, timestamp, base + "_acc.png")
    _save_curve_plot(train_loss, val_loss, "Loss", model_name, timestamp, base + "_loss.png")
-
def plot_confusion_matrix(predicted, actual, model_name, timestamp, plot_path):
    """Save a confusion matrix of true (``actual``) vs. ``predicted`` labels as a PNG.

    Written to ``<model_name>_t-<timestamp>_confusion_matrix.png`` inside
    ``plot_path``, which is created if missing.
    """
    os.makedirs(plot_path, exist_ok=True)
    # sklearn's signature is from_predictions(y_true, y_pred). The arguments were
    # previously reversed, which transposed the matrix and swapped the axis labels.
    display = ConfusionMatrixDisplay.from_predictions(actual, predicted)
    # from_predictions already draws into its own figure. The extra .plot() call
    # opened a second figure and left this one unclosed.
    fig = display.figure_
    try:
        fig.savefig(os.path.join(plot_path, model_name + "_t-" + timestamp + "_confusion_matrix.png"))
    finally:
        plt.close(fig)
-
def plot_roc_curve(predicted, actual, model_name, timestamp, plot_path):
    """Save a ROC curve (annotated with its AUC) as a PNG.

    Args:
        predicted: Scores for the positive class, one per sample.
        actual: True labels; ``roc_curve`` only supports binary labels.
        model_name, timestamp: Used to build the output file name.
        plot_path: Output directory (created if missing).

    Writes ``<model_name>_t-<timestamp>_roc_curve.png`` into ``plot_path``.

    NOTE(review): if this receives the ``predictions`` list from ``test_model``
    (argmax class indices), the "scores" are hard 0/1 labels and the curve has a
    single operating point. Passing positive-class probabilities would give a
    meaningful curve.
    """
    os.makedirs(plot_path, exist_ok=True)

    # The original built these arrays but discarded the results.
    scores = np.asarray(predicted, dtype=np.float64)
    y_true = np.asarray(actual)

    fpr, tpr, _ = roc_curve(y_true, scores)
    auc = roc_auc_score(y_true, scores)

    # Draw on an explicit axes. Previously plt.figure() created an empty figure that
    # was never closed, because RocCurveDisplay.plot() opened its own figure.
    fig, ax = plt.subplots()
    try:
        RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=auc).plot(ax=ax)
        fig.savefig(os.path.join(plot_path, model_name + "_t-" + timestamp + "_roc_curve.png"))
    finally:
        plt.close(fig)
-
|