# GENERAL PURPOSE
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import auc, classification_report, confusion_matrix, f1_score, roc_curve
from torch import nn, optim, stack
# TRAIN
def train(model, train_data, test_data, CNN_filepath, params, graphs=True):
    model.train()
    criterion = nn.CrossEntropyLoss(reduction='mean')
    optimizer = optim.Adam(model.parameters(), lr=1e-5)  # weight_decay=params['weight_decay'], betas=params['momentum'] could be re-enabled
    losses = pd.DataFrame(columns=['Epoch', 'Avg_loss', 'Val_loss', 'Time'])
    start_time = time.time()  # seconds
    epochs = params['epochs']
    for epoch in range(1, epochs + 1):  # loop over the dataset multiple times
        # Elapsed time and a linear projection of the total training time
        t = time.strftime("%H:%M:%S", time.gmtime(time.time() - start_time))
        t_total = time.strftime("%H:%M:%S", time.gmtime((time.time() - start_time) / epoch * epochs))
        print(f"{epoch / epochs * 100:.1f}% || {epoch}/{epochs} || Time: {t}/{t_total}")
        running_loss = 0.0
        # Batches & training
        for i, data in enumerate(train_data):
            # get the inputs; each batch is ((image_batch, clinical_features), labels)
            inputs = [data[0][0].to(model.device), stack(data[0][1], dim=0).to(model.device)]
            labels = data[1].to(model.device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)  # mean of the losses over the batch
            loss.backward()
            optimizer.step()
            # add the average batch loss to the running loss
            running_loss += loss.item()
            # progress report every 10 mini-batches; average over batches seen so far
            if i % 10 == 0 and i != 0:
                print(f"{i}/{len(train_data)}, running avg. loss: {running_loss / (i + 1):.4f}")
        # average training loss for the epoch: running loss / number of batches
        avg_loss = running_loss / len(train_data)
        print(f"Avg. loss: {avg_loss}")
        # loss on validation
        val_loss = evaluate(model, test_data, graphs=False)
        # NOTE: per-epoch F1 tracking was sketched here but left unimplemented
        losses = pd.concat([losses, pd.DataFrame([{'Epoch': int(epoch), 'Avg_loss': avg_loss,
                                                   'Val_loss': val_loss, 'Time': time.time() - start_time}])],
                           ignore_index=True)
    print('Finished Training')
    time_string = time.strftime("%Y-%m-%d_%H:%M", time.localtime())
    losses.to_csv(f'./cnn_net_data_{time_string}.csv')
    if graphs:
        # EPOCH VS AVG LOSS GRAPH ON TRAINING & VALIDATION DATA
        plt.plot(losses['Epoch'], losses['Avg_loss'], label="Loss on Training")
        plt.plot(losses['Epoch'], losses['Val_loss'], label="Loss on Validation")
        plt.xlabel('Epoch')
        plt.ylabel('Average Loss')
        plt.title('Loss vs Epoch on Training & Validation Data')
        plt.legend(loc="lower right")
        plt.savefig(f"./avgloss_epoch_curve_{time_string}.png")
        print("AVG LOSS EPOCH CURVE IN TRAINING DONE")
    torch.save(model.state_dict(), CNN_filepath)
    print("Model saved")
    return model
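# Illustrative usage sketch (not part of the pipeline): how train() is meant to be
# called. It assumes a model exposing a `.device` attribute and DataLoaders whose
# batches look like ((image_batch, clinical_features), labels), as unpacked above.
# `train_loader` and `test_loader` are hypothetical names.
def example_train_run(model, train_loader, test_loader):
    model.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(model.device)
    params = {'epochs': 20}  # assumed minimal params dict; train() only reads 'epochs'
    return train(model, train_loader, test_loader, './cnn_net.pth', params, graphs=True)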
# LOAD
def load(model, filepath):
    # map_location keeps the weights on the device the model already lives on
    model.load_state_dict(torch.load(filepath, map_location=model.device))
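# Illustrative sketch, assuming the same `.device` convention as above: restoring a
# saved checkpoint for inference. The checkpoint path is hypothetical.
def example_load_for_inference(model, filepath='./cnn_net.pth'):
    model.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(model.device)
    load(model, filepath)
    model.eval()  # inference mode: disables dropout, freezes batch-norm statistics
    return model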
# EVALUATE
def evaluate(model, val_data, graphs=True, k_folds=None, fold=None, results=None):
    eval_start = time.localtime()  # wall-clock timestamp, used to name output files
    correct, total = 0, 0
    running_val_loss = 0.0
    predictionsLabels, predictionsProbabilities, true_labels = [], [], []
    criterion = nn.CrossEntropyLoss(reduction='mean')
    model.eval()
    # since we're not training, we don't need to calculate gradients
    with torch.no_grad():
        for data in val_data:
            # each batch is ((image_batch, clinical_features), labels)
            images = [data[0][0].to(model.device), stack(data[0][1], dim=0).to(model.device)]
            labels = data[1].to(model.device)
            # calculate outputs by running images through the model
            outputs = model(images)
            loss = criterion(outputs, labels)  # mean loss for the batch
            running_val_loss += loss.item()
            # the class with the highest logit is the prediction
            predicted = torch.argmax(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # save predictions, positive-class probabilities and labels for the ROC curve
            if graphs:
                predictionsLabels.extend(predicted.cpu().numpy())
                # softmax converts the logits to probabilities; column 1 is the positive class
                predictionsProbabilities.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())
                true_labels.extend(labels.cpu().numpy())
    avg_val_loss = running_val_loss / len(val_data)  # mean validation loss over all batches
    # K-FOLD MODE
    if fold is not None:
        print(f'Accuracy for fold {fold}: {100.0 * correct / total:.1f} %')
        print('--------------------------------')
        results[fold] = 100.0 * (correct / total)
        true_labels = np.array(true_labels)
        # ROC: calculate TPR and FPR, then the area under the curve
        fpr, tpr, thresholds = roc_curve(true_labels, predictionsProbabilities)
        roc_auc = auc(fpr, tpr)
        time_string = time.strftime("%Y-%m-%d_%H:%M", eval_start)
        plt.plot(fpr, tpr, lw=2, label=f'ROC Fold {fold} (AUC: {roc_auc:.3f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.005])
        plt.ylim([0.0, 1.005])
        plt.xlabel('False Positive Rate (1 - Specificity)')
        plt.ylabel('True Positive Rate (Sensitivity)')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc="lower right")
        plt.savefig(f'./ROC_{k_folds}_Folds_{time_string}.png')
        print("SAVED ROC FOR K-FOLD")
        model.train()
        return results
    # NORMAL EVALUATION
    print(f'Accuracy of the network on {total} scans: {100 * correct // total}%')
    if not graphs:
        print(f'Validation loss: {avg_val_loss}')
    else:
        time_string = time.strftime("%Y-%m-%d_%H:%M", eval_start)
        # ROC: calculate TPR and FPR, then the area under the curve
        fpr, tpr, thresholds = roc_curve(true_labels, predictionsProbabilities)
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC: {roc_auc:.3f})')
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.005])
        plt.ylim([0.0, 1.005])
        plt.xlabel('False Positive Rate (1 - Specificity)')
        plt.ylabel('True Positive Rate (Sensitivity)')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc="lower right")
        plt.savefig(f'./ROC_{time_string}.png')
        print("SAVED ROC FOR NORMAL")
        # Confusion matrix
        cm = confusion_matrix(true_labels, predictionsLabels)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
        plt.xlabel('Predicted labels')
        plt.ylabel('True labels')
        plt.title('Confusion Matrix')
        plt.savefig(f'./confusion_matrix_{time_string}.png')
        # Classification report: per-class precision, recall and F1
        report = classification_report(true_labels, predictionsLabels)
        print(report)
    model.train()
    return avg_val_loss
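# Illustrative sketch of the k-fold path through evaluate(): each fold writes its
# accuracy into `results`, and each call adds one ROC curve to the shared figure, so
# the last savefig holds a curve per fold. `make_model` and `fold_loaders` are
# hypothetical; building the folds (e.g. with sklearn.model_selection.KFold) is not
# shown here.
def example_kfold_evaluation(make_model, fold_loaders):
    k_folds = len(fold_loaders)
    results = {}
    for fold, val_loader in enumerate(fold_loaders):
        model = make_model(fold)  # assumed to return a model already trained on the other folds
        results = evaluate(model, val_loader, graphs=True,
                           k_folds=k_folds, fold=fold, results=results)
    print(f'Average accuracy over {k_folds} folds: {sum(results.values()) / k_folds:.1f} %')
    return results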
# PREDICT
def predict(model, data):
    model.eval()
    all_labels, all_predicted = [], []
    with torch.no_grad():
        for batch in data:  # avoid shadowing the loader with the loop variable
            images, labels = batch[0].to(model.device), batch[1].to(model.device)
            outputs = model(images)
            # the class with the highest logit is the prediction
            all_predicted.append(torch.argmax(outputs, dim=1))
            all_labels.append(labels)
    model.train()
    # RETURNS (true, predicted) accumulated over every batch, not just the last one
    return torch.cat(all_labels), torch.cat(all_predicted)
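# Illustrative sketch: F1 tracking was noted but left unimplemented in train(). With
# predict() returning (true, predicted) over the whole set, sklearn's f1_score applies
# directly. `test_loader` is a hypothetical DataLoader yielding (images, labels)
# batches, matching what predict() unpacks.
def example_f1(model, test_loader):
    true, predicted = predict(model, test_loader)
    return f1_score(true.cpu().numpy(), predicted.cpu().numpy())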