# K-fold.py (4.7 KB)

  1. import os
  2. import torch
  3. from utils.train_methods import train, evaluate
  4. from utils.CNN import CNN_Net
  5. from torch import nn
  6. from torch.utils.data import DataLoader, ConcatDataset
  7. from torchvision import transforms
  8. from sklearn.model_selection import KFold, StratifiedKFold
  9. from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report
  10. from utils.preprocess import prepare_datasets, prepare_predict
  11. import numpy as np
  12. import matplotlib.pyplot as plt
  13. import time
  14. def reset_weights(m):
  15. '''
  16. Try resetting model weights to avoid
  17. weight leakage.
  18. '''
  19. for layer in m.children():
  20. if hasattr(layer, 'reset_parameters'):
  21. print(f'Reset trainable parameters of layer = {layer}')
  22. layer.reset_parameters()
  23. if __name__ == '__main__':
  24. print("--- RUNNING K-FOLD ---")
  25. print("Pytorch Version: " + torch.__version__)
  26. current_time = time.localtime()
  27. print(time.strftime("%Y-%m-%d_%H:%M", current_time))
  28. # might have to replace datapaths or separate between training and testing
  29. model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
  30. CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth' # cnn_net.pth
  31. # small dataset
  32. # mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/' # Small Test
  33. # big dataset
  34. mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/' # Real data
  35. annotations_datapath = './data/data_wnx1/rschuurs/Pytorch_CNN-RNN/LP_ADNIMERGE.csv'
  36. params = {
  37. "batch_size": 6,
  38. "padding": 0,
  39. "dilation": 1,
  40. "groups": 1,
  41. "bias": True,
  42. "padding_mode": "zeros",
  43. "drop_rate": 0,
  44. "epochs": 15,
  45. }
  46. # Configuration options
  47. k_folds = 5
  48. # num_epochs = 10
  49. loss_function = nn.CrossEntropyLoss()
  50. # For fold results
  51. results = {}
  52. # Set fixed random number seed
  53. torch.manual_seed(42) # todo
  54. training_data, val_data, test_data = prepare_datasets(mri_datapath, val_split=0.2, seed=12)
  55. dataset = ConcatDataset([training_data, test_data])
  56. # Define the K-fold Cross Validator
  57. kfold = KFold(n_splits=k_folds, shuffle=True)
  58. # Start print
  59. print('--------------------------------')
  60. # K-fold Cross Validation model evaluation
  61. for fold, (train_ids, test_ids) in enumerate(kfold.split(training_data)):
  62. # Print
  63. print(f'FOLD {fold}')
  64. print('--------------------------------')
  65. # Sample elements randomly from a given list of ids, no replacement.
  66. train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
  67. test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
  68. # Define data loaders for training and testing data in this fold
  69. trainloader = torch.utils.data.DataLoader(
  70. dataset,
  71. batch_size=10, sampler=train_subsampler)
  72. testloader = torch.utils.data.DataLoader(
  73. dataset,
  74. batch_size=10, sampler=test_subsampler)
  75. # Init the neural model
  76. model = CNN_Net(prps=params, final_layer_size=2)
  77. model.apply(reset_weights)
  78. model.cuda()
  79. # Run the training loop for defined number of epochs
  80. train(model, trainloader, testloader, CNN_filepath, params=params, graphs=False)
  81. # Process is complete.
  82. print('Training process has finished. Saving trained model.')
  83. # Print about testing
  84. print('Starting testing')
  85. # Saving the model
  86. save_path = f'./model-fold-{fold}.pth'
  87. torch.save(model.state_dict(), save_path)
  88. # Evaluation for this fold
  89. results = evaluate(model, testloader, graphs=True, k_folds=k_folds, fold=fold, results=results)
  90. # Print fold results
  91. print(f'K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS')
  92. print('--------------------------------')
  93. sum = 0.0
  94. for key, value in results.items():
  95. print(f'Fold {key}: {value} %')
  96. sum += value
  97. print(f'Average: {sum / len(results.items())} %')
  98. # Saves to .txt if last one
  99. if(fold==k_folds-1):
  100. time_string = time.strftime("%Y-%m-%d_%H:%M", current_time)
  101. txt = open(f"{k_folds}_folds_{time_string}.txt", "w")
  102. txt.write('--------------------------------\n')
  103. txt.write(f'K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS\n')
  104. txt.write('--------------------------------\n')
  105. sum = 0.0
  106. for key, value in results.items():
  107. txt.write(f'Fold {key}: {value}%\n')
  108. sum += value
  109. txt.write(f'Average: {sum / len(results.items())}%')
  110. txt.close()