K-fold.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import os
  2. import torch
  3. from utils.train_methods import train, evaluate
  4. from utils.CNN import CNN_Net
  5. from torch import nn
  6. from torch.utils.data import DataLoader, ConcatDataset
  7. from torchvision import transforms
  8. from sklearn.model_selection import KFold, StratifiedKFold
  9. from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report
  10. from utils.preprocess import prepare_datasets, prepare_predict
  11. import numpy as np
  12. import matplotlib.pyplot as plt
  13. import time
  14. def reset_weights(m):
  15. '''
  16. Try resetting model weights to avoid
  17. weight leakage.
  18. '''
  19. for layer in m.children():
  20. if hasattr(layer, 'reset_parameters'):
  21. print(f'Reset trainable parameters of layer = {layer}')
  22. layer.reset_parameters()
  23. if __name__ == '__main__':
  24. print("--- RUNNING K-FOLD ---")
  25. print("Pytorch Version: " + torch.__version__)
  26. current_time = time.localtime()
  27. print(time.strftime("%Y-%m-%d_%H:%M", current_time))
  28. # might have to replace datapaths or separate between training and testing
  29. model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
  30. CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth' # cnn_net.pth
  31. # small dataset
  32. # mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/' # Small Test
  33. # big dataset
  34. mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/' # Real data
  35. annotations_datapath = './data/data_wnx1/rschuurs/Pytorch_CNN-RNN/LP_ADNIMERGE.csv'
  36. properties = {
  37. "batch_size": 6,
  38. "padding": 0,
  39. "dilation": 1,
  40. "groups": 1,
  41. "bias": True,
  42. "padding_mode": "zeros",
  43. "drop_rate": 0
  44. }
  45. # Configuration options
  46. k_folds = 5
  47. num_epochs = 10
  48. loss_function = nn.CrossEntropyLoss()
  49. # For fold results
  50. results = {}
  51. # Set fixed random number seed
  52. torch.manual_seed(42) # todo
  53. training_data, val_data, test_data = prepare_datasets(mri_datapath, val_split=0.2, seed=12)
  54. dataset = ConcatDataset([training_data, test_data])
  55. # Define the K-fold Cross Validator
  56. kfold = KFold(n_splits=k_folds, shuffle=True)
  57. # Start print
  58. print('--------------------------------')
  59. # K-fold Cross Validation model evaluation
  60. for fold, (train_ids, test_ids) in enumerate(kfold.split(training_data)):
  61. # Print
  62. print(f'FOLD {fold}')
  63. print('--------------------------------')
  64. # Sample elements randomly from a given list of ids, no replacement.
  65. train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
  66. test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
  67. # Define data loaders for training and testing data in this fold
  68. trainloader = torch.utils.data.DataLoader(
  69. dataset,
  70. batch_size=10, sampler=train_subsampler)
  71. testloader = torch.utils.data.DataLoader(
  72. dataset,
  73. batch_size=10, sampler=test_subsampler)
  74. # Init the neural model
  75. model = CNN_Net(prps=properties, final_layer_size=2)
  76. model.apply(reset_weights)
  77. model.cuda()
  78. # Run the training loop for defined number of epochs
  79. train(model, trainloader, testloader, CNN_filepath, epochs=num_epochs, graphs=True)
  80. # Process is complete.
  81. print('Training process has finished. Saving trained model.')
  82. # Print about testing
  83. print('Starting testing')
  84. # Saving the model
  85. save_path = f'./model-fold-{fold}.pth'
  86. torch.save(model.state_dict(), save_path)
  87. # Evaluation for this fold
  88. results = evaluate(model, testloader, graphs=True, k_folds=k_folds, fold=fold, results=results)
  89. # Print fold results
  90. print(f'K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS')
  91. print('--------------------------------')
  92. sum = 0.0
  93. for key, value in results.items():
  94. print(f'Fold {key}: {value} %')
  95. sum += value
  96. print(f'Average: {sum / len(results.items())} %')
  97. # Saves to .txt if last one
  98. if(fold==k_folds-1):
  99. time_string = time.strftime("%Y-%m-%d_%H:%M", current_time)
  100. txt = open(f"{k_folds}_folds_{time_string}.txt", "w")
  101. txt.write('--------------------------------')
  102. txt.write(f'K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS')
  103. txt.write('--------------------------------')
  104. sum = 0.0
  105. for key, value in results.items():
  106. txt.write(f'Fold {key}: {value} %')
  107. sum += value
  108. txt.write(f'Average: {sum / len(results.items())} %')
  109. txt.close()