import torch
from torch import nn, optim, cat, stack
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report, f1_score
import seaborn as sns

# GENERAL PURPOSE
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import time


# TRAIN
def train(model, train_data, test_data, CNN_filepath, params, graphs=True):
    model.train()
    criterion = nn.CrossEntropyLoss(reduction='mean')
    optimizer = optim.Adam(model.parameters(), lr=1e-5)  # , weight_decay=params['weight_decay'], betas=params['momentum']
    losses = pd.DataFrame(columns=['Epoch', 'Avg_loss', 'Val_loss', 'Time'])
    start_time = time.time()  # seconds
    # model.init_history()
    epochs = params['epochs']
    for epoch in range(1, epochs + 1):  # loop over the dataset multiple times
        # Elapsed time and a rough projection of the total training time
        t = time.strftime("%H:%M:%S", time.gmtime(time.time() - start_time))
        t_remain = time.strftime("%H:%M:%S", time.gmtime((time.time() - start_time) / epoch * epochs))
        print(f"{epoch / epochs * 100:.1f}% || {epoch}/{epochs} || Time: {t}/{t_remain}")

        running_loss = 0.0
        predictions = []  # placeholder for F1-score tracking (see commented lines below)

        # Batches & training
        for i, data in enumerate(train_data, 0):
            # Each batch is [[image_batch, clinical_features], labels];
            # stack collates the per-feature clinical tensors before moving them to the device
            inputs = [data[0][0].to(model.device), stack(data[0][1], dim=0).to(model.device)]
            labels = data[1].to(model.device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)  # mean loss over the batch
            loss.backward()
            optimizer.step()

            # accumulate the average batch loss
            running_loss += loss.item()
            # progress report every 10 mini-batches (mean loss over the batches seen so far)
            if i % 10 == 0 and i != 0:
                print(f"{i}/{len(train_data)}, temp. loss: {running_loss / (i + 1)}")
            # Gets predictions for f1 metric
            # predictions = predictions.append(torch.max(outputs.data, 1)[1])

        # average loss over the epoch (running_loss / number of batches)
        avg_loss = running_loss / len(train_data)
        print(f"Avg. loss: {avg_loss}")

        # loss on validation
        val_loss = evaluate(model, test_data, graphs=False)  # , f1_validation
        losses = pd.concat([losses, pd.DataFrame([{'Epoch': int(epoch), 'Avg_loss': avg_loss,
                                                   'Val_loss': val_loss, 'Time': time.time() - start_time}])])
        # model.append_loss(running_loss)
        # model.append_val_loss(val_loss)
        # f1_training = f1_score(test_data.data, predictions.data)
        # model.append_metric(f1_training)
        # model.append_val_metric(f1_validation)

    print('Finished Training')
    time_string = time.strftime("%Y-%m-%d_%H:%M", time.localtime())
    losses.to_csv(f'./cnn_net_data_{time_string}.csv')

    if graphs:
        # EPOCH VS AVG LOSS GRAPH, training and validation on the same axes
        plt.figure()
        plt.plot(losses['Epoch'], losses['Avg_loss'], label="Loss on Training")
        plt.plot(losses['Epoch'], losses['Val_loss'], label="Loss on Validation")
        plt.xlabel('Epoch')
        plt.ylabel('Average Loss')
        plt.title('Loss vs Epoch On Training & Validation data')
        plt.legend(loc="lower right")
        plt.savefig(f"./avgloss_epoch_curve_{time_string}.png")
        print("AVG LOSS EPOCH CURVE IN TRAINING DONE")
        # plt.show()

    torch.save(model.state_dict(), CNN_filepath)
    print("Model saved")
    return model  # , model.parameters()
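# HYPOTHETICAL SKETCH (not part of the original module): train() and evaluate()
# assume each batch arrives as [[image_batch, clinical_features], labels], where
# clinical_features is a list of per-feature tensors that stack(..., dim=0)
# collates into one (n_features, batch) tensor. A Dataset that returns its
# clinical values as a tuple of scalars yields exactly that layout under the
# default DataLoader collate. Class and parameter names here are assumptions.
from torch.utils.data import Dataset


class ExampleScanDataset(Dataset):
    def __init__(self, images, clinical, labels):
        # images: (N, C, H, W) tensor; clinical: (N, F) tensor; labels: (N,) long tensor
        self.images, self.clinical, self.labels = images, clinical, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # A tuple of scalars collates to a list of (batch,) tensors,
        # matching the stack(data[0][1], dim=0) calls in train()/evaluate().
        return [self.images[idx], tuple(self.clinical[idx].tolist())], self.labels[idx]
        # Usage: DataLoader(ExampleScanDataset(imgs, clin, lbls), batch_size=16, shuffle=True)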
def load(model, filepath):
    model.load_state_dict(torch.load(filepath))


# EVALUATE
def evaluate(model, val_data, graphs=True, k_folds=None, fold=None, results=None):
    start_time = time.localtime()
    correct, total = 0, 0
    running_vloss = 0.0
    predictionsLabels, predictionsProbabilities, true_labels = [], [], []
    criterion = nn.CrossEntropyLoss(reduction='mean')
    model.eval()
    # since we're not training, we don't need to calculate gradients for the outputs
    with torch.no_grad():
        for data in val_data:
            images = [data[0][0].to(model.device), stack(data[0][1], dim=0).to(model.device)]
            labels = data[1].to(model.device)
            # calculate outputs by running images through the model
            outputs = model(images)
            running_vloss += criterion(outputs, labels).item()  # mean loss from batch

            # the class with the highest score is what we choose as prediction
            predicted = torch.max(outputs.data, 1)[1]
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Saves predictions, probabilities and labels for ROC / confusion matrix
            if graphs:
                predictionsLabels.extend(predicted.cpu().numpy())
                # softmax turns the two-class logits into a probability of the positive class
                predictionsProbabilities.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())
                true_labels.extend(labels.cpu().numpy())

    val_loss = running_vloss / len(val_data)  # mean loss over all validation batches

    # K-FOLD MODE
    if fold is not None:
        print('Accuracy for fold %d: %.1f %%' % (fold, 100.0 * correct / total))
        print('--------------------------------')
        results[fold] = 100.0 * (correct / total)
        true_labels = np.array(true_labels)

        # ROC: TPR/FPR and area under the curve; each fold adds its curve
        # to the same figure before it is saved
        fpr, tpr, thresholds = roc_curve(true_labels, predictionsProbabilities)
        roc_auc = auc(fpr, tpr)
        time_string = time.strftime("%Y-%m-%d_%H:%M", start_time)
        plt.plot(fpr, tpr, lw=2, label=f'ROC Fold {fold} (AUC: {roc_auc:.3f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.005])
        plt.ylim([0.0, 1.005])
        plt.xlabel('False Positive Rate (1 - Specificity)')
        plt.ylabel('True Positive Rate (Sensitivity)')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc="lower right")
        plt.savefig(f'./ROC_{k_folds}_Folds_{time_string}.png')
        print("SAVED ROC FOR K-FOLD")
        model.train()
        return results

    # NORMAL EVALUATION
    print(f'Accuracy of the network on {total} scans: {100 * correct / total:.1f}%')
    if not graphs:
        print(f'Validation loss: {val_loss}')
    else:
        time_string = time.strftime("%Y-%m-%d_%H:%M", start_time)

        # ROC: TPR/FPR and area under the curve
        fpr, tpr, thresholds = roc_curve(true_labels, predictionsProbabilities)
        roc_auc = auc(fpr, tpr)
        plt.figure()
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC: {roc_auc:.3f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.005])
        plt.ylim([0.0, 1.005])
        plt.xlabel('False Positive Rate (1 - Specificity)')
        plt.ylabel('True Positive Rate (Sensitivity)')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc="lower right")
        plt.savefig(f'./ROC_{time_string}.png')
        print("SAVED ROC FOR NORMAL")
        # plt.show()

        # Confusion matrix
        cm = confusion_matrix(true_labels, predictionsLabels)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
        plt.xlabel('Predicted labels')
        plt.ylabel('True labels')
        plt.title('Confusion Matrix')
        plt.savefig(f'./confusion_matrix_{time_string}.png')
        # plt.show()

        # Classification report (per-class precision, recall, F1)
        report = classification_report(true_labels, predictionsLabels)
        print(report)
        # f1_validation = f1_score(val_data, predictions.data)

    model.train()
    return val_loss  # , f1_validation
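# HYPOTHETICAL SKETCH (not part of the original module): one way to drive
# evaluate() in k-fold mode. `make_model` and `fold_loaders` are assumed
# stand-ins; evaluate() fills the shared `results` dict per fold and overlays
# every fold's ROC curve on the same figure before saving it.
def example_k_fold(make_model, fold_loaders, params, k_folds=5):
    results = {}
    for fold, (fold_train_loader, fold_val_loader) in enumerate(fold_loaders):
        model = make_model()  # fresh weights for every fold
        train(model, fold_train_loader, fold_val_loader, f'./cnn_fold_{fold}.pth', params, graphs=False)
        evaluate(model, fold_val_loader, graphs=True, k_folds=k_folds, fold=fold, results=results)
    print(f'Average accuracy over {k_folds} folds: {sum(results.values()) / len(results):.1f} %')
    return results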
# PREDICT
def predict(model, data):
    # NOTE: unlike train()/evaluate(), this path assumes image-only batches
    # of the form [images, labels], with no clinical features
    model.eval()
    all_labels, all_predicted = [], []
    with torch.no_grad():
        for batch in data:  # renamed from `data` to avoid shadowing the loader
            images, labels = batch[0].to(model.device), batch[1].to(model.device)
            outputs = model(images)
            # the class with the highest score is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            all_labels.append(labels)
            all_predicted.append(predicted)
    model.train()
    return cat(all_labels), cat(all_predicted)  # RETURNS (true, predicted)
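# HYPOTHETICAL SKETCH (not part of the original module): the minimal model
# interface this file assumes -- a `device` attribute plus a forward() that
# takes [image_batch, clinical_batch] and returns two-class logits. Layer
# sizes, channel counts, and the class name are placeholder assumptions.
class ExampleFusionCNN(nn.Module):
    def __init__(self, n_clinical=4):
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.classifier = nn.Linear(8 + n_clinical, 2)  # fused image + clinical features -> 2 classes

    def forward(self, inputs):
        images, clinical = inputs
        img_feats = self.features(images)
        # clinical arrives as (n_features, batch) after stack(..., dim=0);
        # transpose to (batch, n_features) before fusing
        clinical = clinical.t().float()
        return self.classifier(cat([img_feats, clinical], dim=1))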