# K-fold.py

import os
import torch
from utils.CNN import CNN_Net
from torch import nn
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms
from sklearn.model_selection import KFold, StratifiedKFold
from utils.preprocess import prepare_datasets, prepare_predict
import numpy as np
import matplotlib.pyplot as plt


def reset_weights(m):
    '''
    Reset model weights between folds to avoid
    weight leakage from one fold to the next.
    '''
    for layer in m.children():
        if hasattr(layer, 'reset_parameters'):
            print(f'Reset trainable parameters of layer = {layer}')
            layer.reset_parameters()
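

# Note: network.apply(reset_weights) below calls this function on every submodule
# recursively, so each layer that implements reset_parameters() starts every fold
# from a fresh initialization rather than the previous fold's weights.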


if __name__ == '__main__':
    # Might have to replace datapaths or separate between training and testing
    model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
    CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth'
    # mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/'  # Small test
    mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/'  # Real data
    annotations_datapath = './data/data_wnx1/rschuurs/Pytorch_CNN-RNN/LP_ADNIMERGE.csv'

    # Convolution/network properties handed to CNN_Net
    properties = {
        "batch_size": 6,
        "padding": 0,
        "dilation": 1,
        "groups": 1,
        "bias": True,
        "padding_mode": "zeros",
        "drop_rate": 0
    }
    # Configuration options
    k_folds = 5  # TODO
    num_epochs = 1
    loss_function = nn.CrossEntropyLoss()

    # For fold results
    results = {}

    # Set fixed random number seed
    torch.manual_seed(42)

    training_data, val_data, test_data = prepare_datasets(mri_datapath, val_split=0.2, seed=12)
    dataset = ConcatDataset([training_data, test_data])

    # Define the K-fold cross-validator
    kfold = KFold(n_splits=k_folds, shuffle=True)
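    # KFold.split() yields k_folds (train_indices, test_indices) pairs of index
    # arrays; every sample lands in the test split of exactly one fold.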
    # Start print
    print('--------------------------------')

    # K-fold cross-validation over the combined pool (training_data + test_data),
    # i.e. the same ConcatDataset the data loaders below index into
    for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
        # Print
        print(f'FOLD {fold}')
        print('--------------------------------')

        # Sample elements randomly from the given lists of ids, without replacement
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
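        # SubsetRandomSampler only draws the given indices (in random order), so both
        # loaders can wrap the same ConcatDataset while seeing disjoint samples.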
        # Define data loaders for the training and testing data in this fold
        trainloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=10, sampler=train_subsampler)
        testloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=10, sampler=test_subsampler)

        # Init the neural network
        network = CNN_Net(prps=properties, final_layer_size=2)
        network.apply(reset_weights)

        # Initialize optimizer
        optimizer = torch.optim.Adam(network.parameters(), lr=1e-5)
        # Run the training loop for the defined number of epochs
        for epoch in range(0, num_epochs):
            # Print epoch
            print(f'Starting epoch {epoch + 1}')

            # Set current loss value
            current_loss = 0.0

            # Iterate over the DataLoader for training data
            for i, data in enumerate(trainloader, 0):
                # Get inputs
                inputs, targets = data

                # Zero the gradients
                optimizer.zero_grad()

                # Perform forward pass
                outputs = network(inputs)

                # Compute loss
                loss = loss_function(outputs, targets)

                # Perform backward pass
                loss.backward()

                # Perform optimization
                optimizer.step()

                # Print statistics
                current_loss += loss.item()
                if i % 500 == 499:
                    print('Loss after mini-batch %5d: %.3f' %
                          (i + 1, current_loss / 500))
                    current_loss = 0.0
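        # Note: the running loss above is only reported every 500 mini-batches, so
        # with fewer than 500 batches per epoch no loss message will be printed.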
        # Process is complete.
        print('Training process has finished. Saving trained model.')

        # Print about testing
        print('Starting testing')

        # Saving the model
        save_path = f'./model-fold-{fold}.pth'
        torch.save(network.state_dict(), save_path)
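        # state_dict() stores only the learned parameters; to reload this fold's model,
        # build a CNN_Net with the same properties and call load_state_dict(torch.load(save_path)).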
        # Evaluation for this fold
        correct, total = 0, 0
        with torch.no_grad():
            predictions = []
            true_labels = []

            # Iterate over the test data and generate predictions
            for i, data in enumerate(testloader, 0):
                # Get inputs
                inputs, targets = data

                # Generate outputs
                outputs = network(inputs)

                # Set total and correct
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

                # Softmax converts the raw logits (as expected by CrossEntropyLoss)
                # into class probabilities; keep the positive-class probability
                probabilities = torch.softmax(outputs, dim=1)
                predictions.extend(probabilities[:, 1].cpu().numpy())
                true_labels.extend(targets.cpu().numpy())
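            # The argmax accuracy below corresponds to a fixed 0.5 threshold on the
            # positive-class probability; the ROC curve sweeps that threshold instead.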
            # Print accuracy
            print('Accuracy for fold %d: %d %%' % (fold, 100.0 * correct / total))
            print('--------------------------------')
            results[fold] = 100.0 * (correct / total)
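        # ROC construction: sweep a decision threshold over the positive-class
        # probability and record TPR = TP / (TP + FN) and FPR = FP / (FP + TN)
        # at each setting.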
        # MAKES ROC CURVE
        thresholds = np.linspace(0, 1, num=50)
        tpr = []
        fpr = []
        acc = []
        predictions = np.array(predictions)  # convert to arrays for element-wise comparison
        true_labels = np.array(true_labels)

        for threshold in thresholds:
            # Threshold the predictions (all predictions at or above the threshold are considered positive)
            thresholded_predictions = (predictions >= threshold).astype(int)

            # Calculate true positives, false positives, true negatives, false negatives
            true_positives = np.sum((thresholded_predictions == 1) & (true_labels == 1))
            false_positives = np.sum((thresholded_predictions == 1) & (true_labels == 0))
            true_negatives = np.sum((thresholded_predictions == 0) & (true_labels == 0))
            false_negatives = np.sum((thresholded_predictions == 0) & (true_labels == 1))

            accuracy = (true_positives + true_negatives) / (
                    true_positives + false_positives + true_negatives + false_negatives)

            # Calculate TPR and FPR
            tpr.append(true_positives / (true_positives + false_negatives))
            fpr.append(false_positives / (false_positives + true_negatives))
            acc.append(accuracy)

        plt.plot(fpr, tpr, lw=2, label=f'ROC Fold {fold}')
    # Finish the combined ROC plot (one curve per fold) and save it
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel('False Positive Rate (1 - Specificity)')
    plt.ylabel('True Positive Rate (Sensitivity)')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")
    plt.savefig(f'./ROC_{k_folds}_Folds.png')
    plt.show()
    # Print fold results
    print(f'K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS')
    print('--------------------------------')
    total_accuracy = 0.0
    for key, value in results.items():
        print(f'Fold {key}: {value} %')
        total_accuracy += value
    print(f'Average: {total_accuracy / len(results)} %')
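    # Possible extension (not part of the original script): also report the spread
    # of the fold accuracies, e.g. np.array(list(results.values())).std(), to see
    # how stable the model is across folds.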