"""Train, test and save several CNN runs on one shared data split.

Configuration is read from ``config.toml``, or from the file named by the
``ADL_CONFIG_PATH`` environment variable when it is set. The train/val/test
datasets are prepared once and saved under ``<model_output>/<model name>/``.
Then, for each of ``config['training']['runs']`` runs, a fresh CNN is built,
trained, and tested. Each run's model (``.pt``) and training history
(``_history.csv``) are saved under ``models/``, and one test-accuracy line is
appended to ``summary.txt``.
"""

# MACHINE LEARNING
import torch
import torch.nn as nn
import torch.optim as optim
import shutil

# GENERAL USE
import random as rand

# SYSTEM
import tomli as toml
import os
import warnings

# DATA PROCESSING

# CUSTOM MODULES
import utils.models.cnn as cnn
from utils.data.datasets import prepare_datasets, initalize_dataloaders
import utils.training as train
import utils.testing as testn
from utils.system import force_init_cudnn

# CONFIGURATION
# ADL_CONFIG_PATH, when set, overrides the default config file location.
config_path = os.getenv('ADL_CONFIG_PATH')
if config_path is None:
    config_path = 'config.toml'
with open(config_path, 'rb') as f:
    config = toml.load(f)

# Force cuDNN initialization
force_init_cudnn(config['training']['device'])

# One seed per invocation. It is passed to prepare_datasets below, so every
# run in this loop shares the same datasets.
# NOTE(review): torch's RNG is not seeded here, so model initialisation still
# differs between runs — confirm that is intended.
seed = rand.randint(0, 1000)

# Prepare data
train_dataset, val_dataset, test_dataset = prepare_datasets(
    config['paths']['mri_data'],
    config['paths']['xls_data'],
    config['dataset']['validation_split'],
    seed,
    config['training']['device'],
)

train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders(
    train_dataset,
    val_dataset,
    test_dataset,
    config['hyperparameters']['batch_size'],
)

# Save datasets
model_folder_path = (
    config['paths']['model_output'] + '/' + str(config['model']['name']) + '/'
)
os.makedirs(model_folder_path, exist_ok=True)

torch.save(train_dataset, model_folder_path + 'train_dataset.pt')
torch.save(val_dataset, model_folder_path + 'val_dataset.pt')
torch.save(test_dataset, model_folder_path + 'test_dataset.pt')

runs_num = config['training']['runs']
# Per-run models live in a subfolder of the same folder as the datasets.
# Derive it from model_folder_path so the save path and the created
# directory always agree.
models_folder_path = model_folder_path + 'models/'

for i in range(runs_num):
    run_id = i + 1

    # Set up a fresh model/optimizer for every run
    model = (
        cnn.CNN(
            config['model']['image_channels'],
            config['model']['clin_data_channels'],
            config['hyperparameters']['droprate'],
        )
        .float()
        .to(config['training']['device'])
    )

    criterion = nn.BCELoss()
    optimizer = optim.Adam(
        model.parameters(), lr=config['hyperparameters']['learning_rate']
    )

    if not config['operation']['silent']:
        print(f'Training model {run_id} / {runs_num} with seed {seed}...')

    # Train the model; all warnings raised during training are suppressed.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        history = train.train_model(
            model, train_dataloader, val_dataloader, criterion, optimizer, config
        )

    # Test Model
    test_acc = testn.test_model(model, test_dataloader, config)

    # Save model, its training history and a summary line
    os.makedirs(models_folder_path, exist_ok=True)
    model_save_path = models_folder_path + str(run_id) + '_s-' + str(seed)

    torch.save(
        model,
        model_save_path + '.pt',
    )
    history.to_csv(
        model_save_path + '_history.csv',
        index=True,
    )

    with open(model_folder_path + 'summary.txt', 'a') as f:
        f.write(f'{run_id}: Test Accuracy: {test_acc}\n')