|
@@ -5,23 +5,17 @@ from utils.preprocess import prepare_datasets
|
|
|
from torch.utils.data import DataLoader
|
|
|
import pandas as pd
|
|
|
import matplotlib.pyplot as plt
|
|
|
+from sklearn.metrics import ConfusionMatrixDisplay, roc_curve, roc_auc_score, RocCurveDisplay
|
|
|
+import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_model(model, seed, timestamp, epochs, train_loader, val_loader, saved_model_path, model_name, optimizer, criterion, cuda_device=torch.device('cuda:0')):
|
|
|
- #Print Shape of Image Data
|
|
|
-
|
|
|
- #Print Training Data Length
|
|
|
- print("Length of Training Data: ", len(train_loader))
|
|
|
-
|
|
|
-
|
|
|
- print("--- INITIALIZING MODEL ---")
|
|
|
- print("Seed: ", seed)
|
|
|
+
|
|
|
epoch_number = 0
|
|
|
|
|
|
- print("--- TRAINING MODEL ---")
|
|
|
train_losses = []
|
|
|
train_accs = []
|
|
|
val_losses = []
|
|
@@ -34,7 +28,7 @@ def train_model(model, seed, timestamp, epochs, train_loader, val_loader, saved_
|
|
|
|
|
|
#Training
|
|
|
train_length = len(train_loader)
|
|
|
- for _, data in tqdm(enumerate(train_loader, 0), total=train_length, desc="Epoch " + str(epoch), unit="batch"):
|
|
|
+ for _, data in tqdm(enumerate(train_loader, 0), total=train_length, desc="Epoch " + str(epoch) + "/" + str(epochs), unit="batch"):
|
|
|
mri, xls, label = data
|
|
|
|
|
|
optimizer.zero_grad()
|
|
@@ -96,7 +90,7 @@ def train_model(model, seed, timestamp, epochs, train_loader, val_loader, saved_
|
|
|
if not os.path.exists(saved_model_path):
|
|
|
os.makedirs(saved_model_path)
|
|
|
|
|
|
- torch.save(model.state_dict(), saved_model_path + model_name + "_t-" + timestamp + "_s-" + str(seed) + "_e-" + str(epochs) + ".pt")
|
|
|
+ torch.save(model, saved_model_path + model_name + "_t-" + timestamp + "_s-" + str(seed) + "_e-" + str(epochs) + ".pkl")
|
|
|
|
|
|
#Create dataframe with training and validation losses and accuracies, set index to epoch
|
|
|
df = pd.DataFrame()
|
|
@@ -109,10 +103,12 @@ def train_model(model, seed, timestamp, epochs, train_loader, val_loader, saved_
|
|
|
return df
|
|
|
|
|
|
def test_model(model, test_loader, cuda_device=torch.device('cuda:0')):
|
|
|
- print("--- TESTING MODEL ---")
|
|
|
#Test model
|
|
|
correct = 0
|
|
|
incorrect = 0
|
|
|
+
|
|
|
+ predictions = []
|
|
|
+ actual = []
|
|
|
|
|
|
with torch.no_grad():
|
|
|
length = len(test_loader)
|
|
@@ -130,13 +126,16 @@ def test_model(model, test_loader, cuda_device=torch.device('cuda:0')):
|
|
|
|
|
|
incorrect += (predicted != labels).sum().item()
|
|
|
correct += (predicted == labels).sum().item()
|
|
|
+
|
|
|
+
|
|
|
+ predictions.extend(predicted.tolist())
|
|
|
+ actual.extend(labels.tolist())
|
|
|
+
|
|
|
+ return predictions, actual, correct, incorrect
|
|
|
|
|
|
- print("Model Accuracy: ", 100 * correct / (correct + incorrect))
|
|
|
-
|
|
|
-def initalize_dataloaders(mri_path, xls_path, val_split, seed, cuda_device=torch.device('cuda:0')):
|
|
|
+def initalize_dataloaders(mri_path, xls_path, val_split, seed, cuda_device=torch.device('cuda:0'), batch_size=64):
|
|
|
training_data, val_data, test_data = prepare_datasets(mri_path, xls_path, val_split, seed)
|
|
|
|
|
|
- batch_size = 64
|
|
|
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, generator=torch.Generator(device=cuda_device))
|
|
|
test_dataloader = DataLoader(test_data, batch_size=(batch_size // 4), shuffle=True, generator=torch.Generator(device=cuda_device))
|
|
|
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True, generator=torch.Generator(device=cuda_device))
|
|
@@ -158,6 +157,7 @@ def plot_results(train_acc, train_loss, val_acc, val_loss, model_name, timestamp
|
|
|
plt.title("Accuracy of " + model_name + " Model: " + timestamp)
|
|
|
plt.legend()
|
|
|
plt.savefig(plot_path + model_name + "_t-" + timestamp + "_acc.png")
|
|
|
+ plt.close()
|
|
|
|
|
|
#Loss Plot
|
|
|
plt.figure()
|
|
@@ -168,7 +168,30 @@ def plot_results(train_acc, train_loss, val_acc, val_loss, model_name, timestamp
|
|
|
plt.title("Loss of " + model_name + " Model: " + timestamp)
|
|
|
plt.legend()
|
|
|
plt.savefig(plot_path + model_name + "_t-" + timestamp + "_loss.png")
|
|
|
+ plt.close()
|
|
|
|
|
|
+def plot_confusion_matrix(predicted, actual, model_name, timestamp, plot_path):
|
|
|
+ #Create confusion matrix
|
|
|
+ if not os.path.exists(plot_path):
|
|
|
+ os.makedirs(plot_path)
|
|
|
|
|
|
+ ConfusionMatrixDisplay.from_predictions(predicted, actual).plot()
|
|
|
+ plt.savefig(plot_path + model_name + "_t-" + timestamp + "_confusion_matrix.png")
|
|
|
+ plt.close()
|
|
|
|
|
|
+def plot_roc_curve(predicted, actual, model_name, timestamp, plot_path):
|
|
|
+ #Create ROC Curve
|
|
|
+ if not os.path.exists(plot_path):
|
|
|
+ os.makedirs(plot_path)
|
|
|
+
|
|
|
+ np.array(predicted, dtype=np.float64)
|
|
|
+ np.array(actual, dtype=np.float64)
|
|
|
+
|
|
|
+ fpr, tpr, _ = roc_curve(actual, predicted)
|
|
|
+ print(fpr, tpr)
|
|
|
+ auc = roc_auc_score(actual, predicted)
|
|
|
+ plt.figure()
|
|
|
+ RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=auc).plot()
|
|
|
+ plt.savefig(plot_path + model_name + "_t-" + timestamp + "_roc_curve.png")
|
|
|
+ plt.close()
|
|
|
|