train_cnn.py 3.1 KB

# MACHINE LEARNING
import torch
import torch.nn as nn
import torch.optim as optim
import shutil

# GENERAL USE
import random as rand

# SYSTEM
import tomli as toml
import os
import warnings

# DATA PROCESSING
# CUSTOM MODULES
import utils.models.cnn as cnn
from utils.data.datasets import prepare_datasets, initalize_dataloaders
import utils.training as train
import utils.testing as testn
from utils.system import force_init_cudnn
# CONFIGURATION
if os.getenv('ADL_CONFIG_PATH') is None:
    with open('config.toml', 'rb') as f:
        config = toml.load(f)
else:
    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
        config = toml.load(f)
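
# For reference, a minimal config.toml covering the keys this script reads might
# look like the sketch below; the values shown are illustrative placeholders,
# not the project's actual settings:
#
#   [paths]
#   mri_data = "/path/to/mri"
#   xls_data = "/path/to/clinical_data.xls"
#   model_output = "/path/to/saved_models"
#
#   [dataset]
#   validation_split = 0.2
#
#   [training]
#   device = "cuda:0"
#   runs = 10
#
#   [hyperparameters]
#   batch_size = 32
#   learning_rate = 0.0001
#   droprate = 0.5
#
#   [model]
#   name = "cnn_model"
#   image_channels = 1
#   clin_data_channels = 2
#
#   [operation]
#   silent = false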
# Force cuDNN initialization
force_init_cudnn(config['training']['device'])

# Generate seed for each set of runs
seed = rand.randint(0, 1000)

# Prepare data
train_dataset, val_dataset, test_dataset = prepare_datasets(
    config['paths']['mri_data'],
    config['paths']['xls_data'],
    config['dataset']['validation_split'],
    seed,
    config['training']['device'],
)
train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders(
    train_dataset,
    val_dataset,
    test_dataset,
    config['hyperparameters']['batch_size'],
)

# Save datasets
model_folder_path = (
    config['paths']['model_output'] + '/' + str(config['model']['name']) + '/'
)
if not os.path.exists(model_folder_path):
    os.makedirs(model_folder_path)

torch.save(train_dataset, model_folder_path + 'train_dataset.pt')
torch.save(val_dataset, model_folder_path + 'val_dataset.pt')
torch.save(test_dataset, model_folder_path + 'test_dataset.pt')
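
# Note: these are full Dataset objects saved with torch.save; if they need to be
# reloaded later (e.g. to re-test on the exact same split), torch.load on the
# corresponding .pt file should work, with weights_only=False on newer PyTorch
# versions, since the files contain arbitrary Python objects rather than state dicts.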
for i in range(config['training']['runs']):
    # Set up the model
    model = (
        cnn.CNN(
            config['model']['image_channels'],
            config['model']['clin_data_channels'],
            config['hyperparameters']['droprate'],
        )
        .float()
        .to(config['training']['device'])
    )
    criterion = nn.BCELoss()
    optimizer = optim.Adam(
        model.parameters(), lr=config['hyperparameters']['learning_rate']
    )

    runs_num = config['training']['runs']
    if not config['operation']['silent']:
        print(f'Training model {i + 1} / {runs_num} with seed {seed}...')

    # Train the model
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        history = train.train_model(
            model, train_dataloader, val_dataloader, criterion, optimizer, config
        )

    # Test the model
    test_acc = testn.test_model(model, test_dataloader, config)

    # Save the model and its training history
    if not os.path.exists(model_folder_path + 'models/'):
        os.makedirs(model_folder_path + 'models/')

    model_save_path = model_folder_path + 'models/' + str(i + 1) + '_s-' + str(seed)
    torch.save(
        model,
        model_save_path + '.pt',
    )
    history.to_csv(
        model_save_path + '_history.csv',
        index=True,
    )

    with open(model_folder_path + 'summary.txt', 'a') as f:
        f.write(f'{i + 1}: Test Accuracy: {test_acc}\n')
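
# Usage sketch (assuming the utils/ package is on the import path):
#   python train_cnn.py                                        # reads ./config.toml
#   ADL_CONFIG_PATH=/path/to/config.toml python train_cnn.py   # explicit config path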