@@ -0,0 +1,211 @@
+import os
+import torch
+from utils.CNN import CNN_Net
+from torch import nn
+from torch.utils.data import DataLoader, ConcatDataset
+from torchvision import transforms
+from sklearn.model_selection import KFold, StratifiedKFold
+from utils.preprocess import prepare_datasets, prepare_predict
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+def reset_weights(m):
+    '''
+    Reset the model's trainable parameters before each fold
+    to avoid weight leakage between folds.
+    '''
+    for layer in m.children():
+        if hasattr(layer, 'reset_parameters'):
+            print(f'Reset trainable parameters of layer = {layer}')
+            layer.reset_parameters()
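+
+# reset_weights() is applied to the network via network.apply(reset_weights) at the
+# start of every fold below, so each fold trains from freshly initialized parameters.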
+
+if __name__ == '__main__':
+
+    # Might have to replace datapaths or separate between training and testing
+    model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
+    CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth'  # cnn_net.pth
+    # mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/'  # Small Test
+    mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/'  # Real data
+    annotations_datapath = './data/data_wnx1/rschuurs/Pytorch_CNN-RNN/LP_ADNIMERGE.csv'
+
+    properties = {
+        "batch_size": 6,
+        "padding": 0,
+        "dilation": 1,
+        "groups": 1,
+        "bias": True,
+        "padding_mode": "zeros",
+        "drop_rate": 0
+    }
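+    # These keys appear to be hyperparameters consumed by CNN_Net when it builds its
+    # layers (see utils/CNN.py); padding, dilation, groups, bias and padding_mode
+    # mirror the usual torch convolution arguments.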
+
+    # Configuration options
+    k_folds = 5  # TODO
+    num_epochs = 1
+    loss_function = nn.CrossEntropyLoss()
+
+    # For fold results
+    results = {}
+
+    # Set fixed random number seed
+    torch.manual_seed(42)
+
+    training_data, val_data, test_data = prepare_datasets(mri_datapath, val_split=0.2, seed=12)
+
+    dataset = ConcatDataset([training_data, test_data])
+
+    # Define the K-fold Cross Validator
+    kfold = KFold(n_splits=k_folds, shuffle=True)
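+    # If the two classes are imbalanced, StratifiedKFold (imported above) could be used
+    # instead; its split() also requires the per-sample labels (not gathered here),
+    # e.g. StratifiedKFold(n_splits=k_folds, shuffle=True).split(dataset, labels).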
+
+    # Start print
+    print('--------------------------------')
+
+    # K-fold Cross Validation model evaluation.
+    # The folds are drawn from the concatenated dataset defined above, so the split
+    # returned by prepare_datasets is not reused directly here.
+    for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
+
+        # Print
+        print(f'FOLD {fold}')
+        print('--------------------------------')
+
+        # Sample elements randomly from a given list of ids, no replacement.
+        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
+        test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
+
+        # Define data loaders for training and testing data in this fold
+        trainloader = torch.utils.data.DataLoader(
+            dataset,
+            batch_size=10, sampler=train_subsampler)
+        testloader = torch.utils.data.DataLoader(
+            dataset,
+            batch_size=10, sampler=test_subsampler)
+
+        # Init the neural network
+        network = CNN_Net(prps=properties, final_layer_size=2)
+        network.apply(reset_weights)
+
+        # Initialize optimizer
+        optimizer = torch.optim.Adam(network.parameters(), lr=1e-5)
+
+        # Run the training loop for defined number of epochs
+        for epoch in range(0, num_epochs):
+
+            # Print epoch
+            print(f'Starting epoch {epoch + 1}')
+
+            # Set current loss value
+            current_loss = 0.0
+
+            # Iterate over the DataLoader for training data
+            for i, data in enumerate(trainloader, 0):
+
+                # Get inputs
+                inputs, targets = data
+
+                # Zero the gradients
+                optimizer.zero_grad()
+
+                # Perform forward pass
+                outputs = network(inputs)
+
+                # Compute loss
+                loss = loss_function(outputs, targets)
+
+                # Perform backward pass
+                loss.backward()
+
+                # Perform optimization
+                optimizer.step()
+
+                # Print statistics
+                current_loss += loss.item()
+                if i % 500 == 499:
+                    print('Loss after mini-batch %5d: %.3f' %
+                          (i + 1, current_loss / 500))
+                    current_loss = 0.0
+
+        # Process is complete.
+        print('Training process has finished. Saving trained model.')
+
+        # Print about testing
+        print('Starting testing')
+
+        # Saving the model
+        save_path = f'./model-fold-{fold}.pth'
+        torch.save(network.state_dict(), save_path)
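+        # The saved fold checkpoint can later be restored with, for example:
+        #   network.load_state_dict(torch.load(save_path))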
+
+        # Evaluation for this fold
+        correct, total = 0, 0
+        with torch.no_grad():
+
+            predictions = []
+            true_labels = []
+
+            # Iterate over the test data and generate predictions
+            for i, data in enumerate(testloader, 0):
+                # Get inputs
+                inputs, targets = data
+
+                # Generate outputs
+                outputs = network(inputs)
+
+                # Set total and correct
+                _, predicted = torch.max(outputs.data, 1)
+                total += targets.size(0)
+                correct += (predicted == targets).sum().item()
+
+                # Class-1 score for the ROC curve (softmax turns the logits into probabilities)
+                predictions.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())
+                true_labels.extend(targets.cpu().numpy())
+
+            # Print accuracy
+            print('Accuracy for fold %d: %d %%' % (fold, 100.0 * correct / total))
+            print('--------------------------------')
+            results[fold] = 100.0 * (correct / total)
+
+        # MAKES ROC CURVE
+        thresholds = np.linspace(0, 1, num=50)
+        tpr = []
+        fpr = []
+        acc = []
+
+        predictions = np.array(predictions)  # needed so the >= comparison below broadcasts
+        true_labels = np.array(true_labels)
+
+        for threshold in thresholds:
+            # Thresholding the predictions (meaning all predictions above threshold are considered positive)
+            thresholded_predictions = (predictions >= threshold).astype(int)
+
+            # Calculating true positives, false positives, true negatives, false negatives
+            true_positives = np.sum((thresholded_predictions == 1) & (true_labels == 1))
+            false_positives = np.sum((thresholded_predictions == 1) & (true_labels == 0))
+            true_negatives = np.sum((thresholded_predictions == 0) & (true_labels == 0))
+            false_negatives = np.sum((thresholded_predictions == 0) & (true_labels == 1))
+
+            accuracy = (true_positives + true_negatives) / (
+                    true_positives + false_positives + true_negatives + false_negatives)
+
+            # Calculate TPR and FPR
+            tpr.append(true_positives / (true_positives + false_negatives))
+            fpr.append(false_positives / (false_positives + true_negatives))
+            acc.append(accuracy)
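+        # Alternative (not used here): sklearn.metrics can compute the same curve, e.g.
+        #   from sklearn.metrics import roc_curve, auc
+        #   sk_fpr, sk_tpr, _ = roc_curve(true_labels, predictions)
+        #   fold_auc = auc(sk_fpr, sk_tpr)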
+
+        plt.plot(fpr, tpr, lw=2, label=f'ROC Fold {fold}')
+        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
+        plt.xlim([0.0, 1.0])
+        plt.ylim([0.0, 1.0])
+
+        plt.xlabel('False Positive Rate (1 - Specificity)')
+        plt.ylabel('True Positive Rate (Sensitivity)')
+        plt.title('Receiver Operating Characteristic (ROC) Curve')
+        plt.legend(loc="lower right")
+
+    plt.savefig(f'./ROC_{k_folds}_Folds.png')
+    plt.show()
+
+    # Print fold results
+    print(f'K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS')
+    print('--------------------------------')
+    sum_acc = 0.0
+    for key, value in results.items():
+        print(f'Fold {key}: {value} %')
+        sum_acc += value
+    print(f'Average: {sum_acc / len(results)} %')