# MACHINE LEARNING
import torch
import torch.nn as nn
import torch.optim as optim
import shutil

# GENERAL USE
import random as rand

# SYSTEM
import tomli as toml
import os
import warnings

# DATA PROCESSING

# CUSTOM MODULES
import utils.models.cnn as cnn
from utils.data.datasets import prepare_datasets, initalize_dataloaders
import utils.training as train
import utils.testing as testn
from utils.system import force_init_cudnn


# CONFIGURATION
# The ADL_CONFIG_PATH environment variable, when set, overrides the default
# 'config.toml' in the current working directory. The TOML loader requires
# the file to be opened in binary mode.
with open(os.getenv('ADL_CONFIG_PATH', 'config.toml'), 'rb') as f:
    config = toml.load(f)

# Initialize cuDNN up front on the device the config asks us to train on.
force_init_cudnn(config['training']['device'])

# A single seed in [0, 1000] is shared by every run of this invocation. It is
# passed to prepare_datasets (presumably to fix the data split; TODO confirm)
# and is embedded in each saved model's filename.
seed = rand.randrange(0, 1001)

# Build the train/validation/test datasets, then wrap them in dataloaders.
_paths = config['paths']
train_dataset, val_dataset, test_dataset = prepare_datasets(
    _paths['mri_data'],
    _paths['xls_data'],
    config['dataset']['validation_split'],
    seed,
    config['training']['device'],
)
train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders(
    train_dataset, val_dataset, test_dataset, config['hyperparameters']['batch_size']
)

# Save datasets
# Every artifact for this model is written under <model_output>/<model name>/.
model_folder_path = (
    config['paths']['model_output'] + '/' + str(config['model']['name']) + '/'
)

# exist_ok=True replaces the racy "exists? then create" pair with one call.
os.makedirs(model_folder_path, exist_ok=True)

# Persist the exact splits used for training so they can be reloaded later.
torch.save(train_dataset, model_folder_path + 'train_dataset.pt')
torch.save(val_dataset, model_folder_path + 'val_dataset.pt')
torch.save(test_dataset, model_folder_path + 'test_dataset.pt')

# Train, test, and save config['training']['runs'] independent models,
# all on the same datasets/dataloaders prepared above.
runs_num = config['training']['runs']

# Per-run checkpoints and histories go in <model_folder_path>models/. The path
# is derived from model_folder_path so the directory created here is exactly
# the one torch.save writes into below.
models_dir_path = model_folder_path + 'models/'
os.makedirs(models_dir_path, exist_ok=True)

for i in range(runs_num):
    # Set up a fresh model, loss, and optimizer for this run
    model = (
        cnn.CNN(
            config['model']['image_channels'],
            config['model']['clin_data_channels'],
            config['hyperparameters']['droprate'],
        )
        .float()
        .to(config['training']['device'])
    )
    # NOTE(review): BCELoss expects probabilities, so this presumably relies on
    # the CNN ending in a sigmoid; confirm against utils.models.cnn.
    criterion = nn.BCELoss()
    optimizer = optim.Adam(
        model.parameters(), lr=config['hyperparameters']['learning_rate']
    )

    if not config['operation']['silent']:
        print(f'Training model {i + 1} / {runs_num} with seed {seed}...')

    # Train the model; every warning emitted during training is suppressed.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')

        history = train.train_model(
            model, train_dataloader, val_dataloader, criterion, optimizer, config
        )

    # Test Model
    tes_acc = testn.test_model(model, test_dataloader, config)

    # Save the model and its training history; the filename encodes the
    # 1-based run index and the shared seed.
    model_save_path = f'{models_dir_path}{i + 1}_s-{seed}'

    torch.save(model, model_save_path + '.pt')

    history.to_csv(model_save_path + '_history.csv', index=True)

    # Append one line per run to the shared summary file.
    with open(model_folder_path + 'summary.txt', 'a', encoding='utf-8') as f:
        f.write(f'{i + 1}: Test Accuracy: {tes_acc}\n')