import os
import time

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report

from utils.train_methods import train, evaluate
from utils.CNN import CNN_Net
from utils.preprocess import prepare_datasets, prepare_predict


def reset_weights(m):
    '''
    Reset the model weights to avoid weight leakage between folds.
    '''
    for layer in m.children():
        if hasattr(layer, 'reset_parameters'):
            print(f'Reset trainable parameters of layer = {layer}')
            layer.reset_parameters()
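
# reset_weights is applied once per fold via model.apply(reset_weights) below, so each
# fold starts from freshly initialised parameters rather than continuing from the
# weights trained on the previous fold.
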
if __name__ == '__main__':
    print("--- RUNNING K-FOLD ---")
    print("PyTorch version: " + torch.__version__)
    current_time = time.localtime()
    print(time.strftime("%Y-%m-%d_%H:%M", current_time))

    # Data paths may need to be replaced, or separated between training and testing machines.
    model_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN'
    CNN_filepath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/cnn_net.pth'

    # Small test dataset
    # mri_datapath = '/data/data_wnx1/rschuurs/Pytorch_CNN-RNN/PET_volumes_customtemplate_float32/'
    # Full dataset
    mri_datapath = '/data/data_wnx1/_Data/AlzheimersDL/CNN+RNN-2class-1cnn+data/PET_volumes_customtemplate_float32/'
    annotations_datapath = './data/data_wnx1/rschuurs/Pytorch_CNN-RNN/LP_ADNIMERGE.csv'

    # Model/training hyperparameters, passed to CNN_Net (prps) and train (params)
    params = {
        "batch_size": 6,
        "padding": 0,
        "dilation": 1,
        "groups": 1,
        "bias": True,
        "padding_mode": "zeros",
        "drop_rate": 0,
        "epochs": 15,
    }

    # Configuration options
    k_folds = 5
    loss_function = nn.CrossEntropyLoss()

    # Per-fold results
    results = {}

    # Set fixed random number seed
    torch.manual_seed(42)  # todo

    training_data, val_data, test_data = prepare_datasets(mri_datapath, val_split=0.2, seed=12)
    dataset = ConcatDataset([training_data, test_data])

    # Define the K-fold cross validator (fixed random_state so the fold assignment is reproducible)
    kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

    print('--------------------------------')

    # K-fold cross-validation model evaluation.
    # Split indices over the concatenated dataset, since the samplers below
    # index into `dataset` (training_data followed by test_data).
    for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
        print(f'FOLD {fold}')
        print('--------------------------------')

        # Sample elements randomly from the given lists of indices, without replacement
        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
        test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

        # Data loaders for the training and testing data in this fold
        trainloader = DataLoader(dataset, batch_size=10, sampler=train_subsampler)
        testloader = DataLoader(dataset, batch_size=10, sampler=test_subsampler)
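
        # Note: the loader batch size here (10) is set independently of params["batch_size"] (6),
        # which is passed on to CNN_Net and train() separately.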

        # Initialise the network and reset its weights for this fold
        model = CNN_Net(prps=params, final_layer_size=2)
        model.apply(reset_weights)
        model.cuda()

        # Run the training loop for the configured number of epochs
        train(model, trainloader, testloader, CNN_filepath, params=params, graphs=False)

        print('Training process has finished. Saving trained model.')

        # Save the model trained on this fold
        save_path = f'./model-fold-{fold}.pth'
        torch.save(model.state_dict(), save_path)

        print('Starting testing')

        # Evaluation for this fold
        results = evaluate(model, testloader, graphs=True, k_folds=k_folds, fold=fold, results=results)

        # Print the fold results accumulated so far
        print(f'K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS')
        print('--------------------------------')
        total = 0.0
        for key, value in results.items():
            print(f'Fold {key}: {value} %')
            total += value
        print(f'Average: {total / len(results)} %')

        # Write the results to a .txt file after the last fold
        if fold == k_folds - 1:
            time_string = time.strftime("%Y-%m-%d_%H:%M", current_time)
            with open(f"{k_folds}_folds_{time_string}.txt", "w") as txt:
                txt.write('--------------------------------\n')
                txt.write(f'K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS\n')
                txt.write('--------------------------------\n')
                total = 0.0
                for key, value in results.items():
                    txt.write(f'Fold {key}: {value}%\n')
                    total += value
                txt.write(f'Average: {total / len(results)}%')
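
# Minimal sketch (not part of the script above): reloading one of the per-fold
# checkpoints for inference. The constructor arguments mirror the ones used during
# training; the checkpoint path './model-fold-0.pth' is only an example.
#
#     model = CNN_Net(prps=params, final_layer_size=2)
#     model.load_state_dict(torch.load('./model-fold-0.pth'))
#     model.cuda()
#     model.eval()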