from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss import torch import torch.nn as nn import os import tomli as toml from tqdm import tqdm from utils.models import cnn from utils.data.datasets import prepare_datasets, initalize_dataloaders # CONFIGURATION if os.getenv("ADL_CONFIG_PATH") is None: with open("config.toml", "rb") as f: config = toml.load(f) else: with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f: config = toml.load(f) model = cnn.CNN() # Convert the model to a Bayesian model model = dnn_to_bnn(model, prior_mu=0, prior_sigma=0.1) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adam( model.parameters(), config["hyperparameters"]["learning_rate"] ) train_set, val_set, test_set = prepare_datasets( config["paths"]["mri_data"], config["paths"]["xls_data"], config["dataset"]["validation_split"], config["training"]["device"], ) train_loader, val_loader, test_loader = initalize_dataloaders( train_set, val_set, test_set, config["training"]["batch_size"] ) # Train the model for epoch in range(config["training"]["epochs"]): print(f"Epoch {epoch + 1} / {config['training']['epochs']}") model.train() for batch_idx, (data, target) in tqdm(enumerate(train_loader)): optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss += get_kl_loss(model, config["hyperparameters"]["kl_weight"]) loss = loss / len(data) loss.backward() optimizer.step() #Test the model model.eval() with torch.no_grad(): output_li