train_cnn.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. #MACHINE LEARNING
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. #GENERAL USE
  6. import random as rand
  7. #SYSTEM
  8. import tomli as toml
  9. import os
  10. #DATA PROCESSING
  11. from sklearn.model_selection import train_test_split
  12. #CUSTOM MODULES
  13. import utils.models.cnn as cnn
  14. from utils.data.datasets import prepare_datasets, initalize_dataloaders
  15. import utils.training as train
  16. #CONFIGURATION
  17. if os.getenv('ADL_CONFIG_PATH') is None:
  18. with open ('config.toml', 'rb') as f:
  19. config = toml.load(f)
  20. else:
  21. with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
  22. config = toml.load(f)
  23. for i in range(config['training']['runs']):
  24. #Set up the model
  25. model = cnn.CNN(config['model']['image_channels'], config['model']['clin_data_channels'], config['hyperparameters']['droprate']).float()
  26. criterion = nn.BCELoss()
  27. optimizer = optim.Adam(model.parameters(), lr = config['hyperparameters']['learning_rate'])
  28. #Generate seed for each run
  29. seed = rand.randint(0, 1000)
  30. #Prepare data
  31. train_dataset, val_dataset, test_dataset = prepare_datasets(config['paths']['mri_data'], config['paths']['xls_data'], config['dataset']['validation_split'], seed)
  32. train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders(train_dataset, val_dataset, test_dataset, config['hyperparameters']['batch_size'])
  33. #Train the model
  34. history = train.train_model(model, train_dataloader, val_dataloader, criterion, optimizer, config)
  35. #Save model
  36. if not os.path.exists(config['paths']['model_output'] + "/" + str(config['model']['name'])):
  37. os.makedirs(config['paths']['model_output'] + "/" + str(config['model']['name']))
  38. torch.save(model, config['paths']['model_output'] + "/" + str(config['model']['name']) + "/" + str(i) + "_" + "s-" + str(seed) + ".pt")