123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211 |
- import os
- import torch
- from utils.CNN import CNN_Net
- from torch import nn
- from torch.utils.data import DataLoader, ConcatDataset
- from torchvision import transforms
- from sklearn.model_selection import KFold, StratifiedKFold
- from utils.preprocess import prepare_datasets, prepare_predict
- import numpy as np
- import matplotlib.pyplot as plt
def reset_weights(m):
    """Re-initialise the direct children of ``m`` that expose ``reset_parameters``.

    Meant to be handed to ``nn.Module.apply`` so that every module in the
    network gets its own children reset, which avoids weights leaking from
    one cross-validation fold into the next. Each layer that is reset is
    printed.
    """
    resettable = (child for child in m.children() if hasattr(child, 'reset_parameters'))
    for child in resettable:
        print(f'Reset trainable parameters of layer = {child}')
        child.reset_parameters()
- if __name__ == '__main__':
- # Might have to replace datapaths or separate between training and testing
- model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
- CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth' # cnn_net.pth
- # mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/' # Small Test
- mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/' # Real data
- annotations_datapath = './data/data_wnx1/rschuurs/Pytorch_CNN-RNN/LP_ADNIMERGE.csv'
- properties = {
- "batch_size": 6,
- "padding": 0,
- "dilation": 1,
- "groups": 1,
- "bias": True,
- "padding_mode": "zeros",
- "drop_rate": 0
- }
- # Configuration options
- k_folds = 5 # TODO
- num_epochs = 1
- loss_function = nn.CrossEntropyLoss()
- # For fold results
- results = {}
- # Set fixed random number seed
- torch.manual_seed(42)
- training_data, val_data, test_data = prepare_datasets(mri_datapath, val_split=0.2, seed=12)
- dataset = ConcatDataset([training_data, test_data])
- # Define the K-fold Cross Validator
- kfold = KFold(n_splits=k_folds, shuffle=True)
- # Start print
- print('--------------------------------')
- # K-fold Cross Validation model evaluation
- for fold, (train_ids, test_ids) in enumerate(kfold.split(training_data)):
- # Print
- print(f'FOLD {fold}')
- print('--------------------------------')
- # Sample elements randomly from a given list of ids, no replacement.
- train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
- test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
- # Define data loaders for training and testing data in this fold
- trainloader = torch.utils.data.DataLoader(
- dataset,
- batch_size=10, sampler=train_subsampler)
- testloader = torch.utils.data.DataLoader(
- dataset,
- batch_size=10, sampler=test_subsampler)
- # Init the neural network
- network = CNN_Net(prps=properties, final_layer_size=2)
- network.apply(reset_weights)
- # Initialize optimizer
- optimizer = torch.optim.Adam(network.parameters(), lr=1e-5)
- # Run the training loop for defined number of epochs
- for epoch in range(0, num_epochs):
- # Print epoch
- print(f'Starting epoch {epoch + 1}')
- # Set current loss value
- current_loss = 0.0
- # Iterate over the DataLoader for training data
- for i, data in enumerate(trainloader, 0):
- # Get inputs
- inputs, targets = data
- # Zero the gradients
- optimizer.zero_grad()
- # Perform forward pass
- outputs = network(inputs)
- # Compute loss
- loss = loss_function(outputs, targets)
- # Perform backward pass
- loss.backward()
- # Perform optimization
- optimizer.step()
- # Print statistics
- current_loss += loss.item()
- if i % 500 == 499:
- print('Loss after mini-batch %5d: %.3f' %
- (i + 1, current_loss / 500))
- current_loss = 0.0
- # Process is complete.
- print('Training process has finished. Saving trained model.')
- # Print about testing
- print('Starting testing')
- # Saving the model
- save_path = f'./model-fold-{fold}.pth'
- torch.save(network.state_dict(), save_path)
- # Evaluation for this fold
- correct, total = 0, 0
- with torch.no_grad():
- predictions = []
- true_labels = []
- # Iterate over the test data and generate predictions
- for i, data in enumerate(testloader, 0):
- # Get inputs
- inputs, targets = data
- # Generate outputs
- outputs = network(inputs)
- # Set total and correct
- _, predicted = torch.max(outputs.data, 1)
- total += targets.size(0)
- correct += (predicted == targets).sum().item()
- predictions.extend(outputs.data[:, 1].cpu().numpy()) # Grabs probability of positive
- true_labels.extend(targets.cpu().numpy())
- # Print accuracy
- print('Accuracy for fold %d: %d %%' % (fold, 100.0 * correct / total))
- print('--------------------------------')
- results[fold] = 100.0 * (correct / total)
- # MAKES ROC CURVE
- thresholds = np.linspace(0, 1, num=50)
- tpr = []
- fpr = []
- acc = []
- true_labels = np.array(true_labels)
- for threshold in thresholds:
- # Thresholding the predictions (meaning all predictions above threshold are considered positive)
- thresholded_predictions = (predictions >= threshold).astype(int)
- # Calculating true positives, false positives, true negatives, false negatives
- true_positives = np.sum((thresholded_predictions == 1) & (true_labels == 1))
- false_positives = np.sum((thresholded_predictions == 1) & (true_labels == 0))
- true_negatives = np.sum((thresholded_predictions == 0) & (true_labels == 0))
- false_negatives = np.sum((thresholded_predictions == 0) & (true_labels == 1))
- accuracy = (true_positives + true_negatives) / (
- true_positives + false_positives + true_negatives + false_negatives)
- # Calculate TPR and FPR
- tpr.append(true_positives / (true_positives + false_negatives))
- fpr.append(false_positives / (false_positives + true_negatives))
- acc.append(accuracy)
- plt.plot(fpr, tpr, lw=2, label=f'ROC Fold {fold}')
- plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
- plt.xlim([0.0, 1.0])
- plt.ylim([0.0, 1.0])
- plt.xlabel('False Positive Rate (1 - Specificity)')
- plt.ylabel('True Positive Rate (Sensitivity)')
- plt.title('Receiver Operating Characteristic (ROC) Curve')
- plt.legend(loc="lower right")
- plt.savefig(f'./ROC_{k_folds}_Folds.png')
- plt.show()
- # Print fold results
- print(f'K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS')
- print('--------------------------------')
- sum = 0.0
- for key, value in results.items():
- print(f'Fold {key}: {value} %')
- sum += value
- print(f'Average: {sum / len(results.items())} %')
|