import os
import torch
from utils.CNN import CNN_Net
from torch import nn
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms
from sklearn.model_selection import KFold, StratifiedKFold
from utils.preprocess import prepare_datasets, prepare_predict
import numpy as np
import matplotlib.pyplot as plt


def reset_weights(m):
    '''
    Try resetting model weights to avoid weight leakage between folds.
    '''
    for layer in m.children():
        if hasattr(layer, 'reset_parameters'):
            print(f'Reset trainable parameters of layer = {layer}')
            layer.reset_parameters()


if __name__ == '__main__':

    # Might have to replace datapaths or separate between training and testing
    model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
    CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth'  # cnn_net.pth
    # mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/'  # Small Test
    mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/'  # Real data
    annotations_datapath = './data/data_wnx1/rschuurs/Pytorch_CNN-RNN/LP_ADNIMERGE.csv'

    properties = {
        "batch_size": 6,
        "padding": 0,
        "dilation": 1,
        "groups": 1,
        "bias": True,
        "padding_mode": "zeros",
        "drop_rate": 0
    }

    # Configuration options
    k_folds = 5  # TODO
    num_epochs = 1
    loss_function = nn.CrossEntropyLoss()

    # For fold results
    results = {}

    # Set fixed random number seed
    torch.manual_seed(42)

    training_data, val_data, test_data = prepare_datasets(mri_datapath, val_split=0.2, seed=12)
    dataset = ConcatDataset([training_data, test_data])

    # Define the K-fold Cross Validator.
    # random_state keeps the fold assignment reproducible, matching the fixed torch seed above.
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    # Start print
    print('--------------------------------')

    # K-fold Cross Validation model evaluation.
    # Split over the concatenated dataset, since that is what the samplers index into below.
    for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):

        print(f'FOLD {fold}')
        print('--------------------------------')

        # Sample elements randomly from a given list of ids, no replacement.
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

        # Define data loaders for training and testing data in this fold
        trainloader = torch.utils.data.DataLoader(
            dataset, batch_size=10, sampler=train_subsampler)
        testloader = torch.utils.data.DataLoader(
            dataset, batch_size=10, sampler=test_subsampler)

        # Init the neural network
        network = CNN_Net(prps=properties, final_layer_size=2)
        network.apply(reset_weights)

        # Initialize optimizer
        optimizer = torch.optim.Adam(network.parameters(), lr=1e-5)

        # Run the training loop for the defined number of epochs
        for epoch in range(num_epochs):

            print(f'Starting epoch {epoch + 1}')

            # Set current loss value
            current_loss = 0.0

            # Iterate over the DataLoader for training data
            for i, data in enumerate(trainloader, 0):

                # Get inputs
                inputs, targets = data

                # Zero the gradients
                optimizer.zero_grad()

                # Perform forward pass
                outputs = network(inputs)

                # Compute loss
                loss = loss_function(outputs, targets)

                # Perform backward pass
                loss.backward()

                # Perform optimization
                optimizer.step()

                # Print statistics
                current_loss += loss.item()
                if i % 500 == 499:
                    print('Loss after mini-batch %5d: %.3f' %
                          (i + 1, current_loss / 500))
                    current_loss = 0.0

        # Process is complete.
        print('Training process has finished. Saving trained model.')

        print('Starting testing')

        # Saving the model for this fold
        save_path = f'./model-fold-{fold}.pth'
        torch.save(network.state_dict(), save_path)

        # Evaluation for this fold
        correct, total = 0, 0
        with torch.no_grad():
            predictions = []
            true_labels = []

            # Iterate over the test data and generate predictions
            for i, data in enumerate(testloader, 0):

                # Get inputs
                inputs, targets = data

                # Generate outputs
                outputs = network(inputs)

                # Set total and correct
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

                # Softmax probability of the positive class, so the 0-1 thresholds below are meaningful
                predictions.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())
                true_labels.extend(targets.cpu().numpy())

            # Print accuracy
            print('Accuracy for fold %d: %d %%' % (fold, 100.0 * correct / total))
            print('--------------------------------')
            results[fold] = 100.0 * (correct / total)

        # MAKES ROC CURVE
        thresholds = np.linspace(0, 1, num=50)
        tpr = []
        fpr = []
        acc = []

        # Convert to arrays so elementwise comparisons below work
        predictions = np.array(predictions)
        true_labels = np.array(true_labels)

        for threshold in thresholds:
            # Thresholding the predictions (all predictions at or above the threshold are considered positive)
            thresholded_predictions = (predictions >= threshold).astype(int)

            # Calculating true positives, false positives, true negatives, false negatives
            true_positives = np.sum((thresholded_predictions == 1) & (true_labels == 1))
            false_positives = np.sum((thresholded_predictions == 1) & (true_labels == 0))
            true_negatives = np.sum((thresholded_predictions == 0) & (true_labels == 0))
            false_negatives = np.sum((thresholded_predictions == 0) & (true_labels == 1))

            accuracy = (true_positives + true_negatives) / (
                    true_positives + false_positives + true_negatives + false_negatives)

            # Calculate TPR and FPR
            tpr.append(true_positives / (true_positives + false_negatives))
            fpr.append(false_positives / (false_positives + true_negatives))
            acc.append(accuracy)

        # Per-fold ROC curve, accumulated on the shared figure
        plt.plot(fpr, tpr, lw=2, label=f'ROC Fold {fold}')

    # Finalize the combined ROC plot for all folds
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel('False Positive Rate (1 - Specificity)')
    plt.ylabel('True Positive Rate (Sensitivity)')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.savefig(f'./ROC_{k_folds}_Folds.png')
    plt.show()

    # Print fold results
    print(f'K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS')
    print('--------------------------------')
    sum_acc = 0.0
    for key, value in results.items():
        print(f'Fold {key}: {value} %')
        sum_acc += value
    print(f'Average: {sum_acc / len(results)} %')