from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss
import torch
import torch.nn as nn
import os
import tomli as toml
from tqdm import tqdm

from utils.models import cnn
from utils.data.datasets import prepare_datasets, initalize_dataloaders

# CONFIGURATION
# Read the TOML config from ADL_CONFIG_PATH when set, falling back to ./config.toml.
config_path = os.getenv("ADL_CONFIG_PATH", "config.toml")
with open(config_path, "rb") as f:  # tomli requires binary mode
    config = toml.load(f)
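
# Expected config.toml keys, inferred from the lookups below:
#   [paths]           mri_data, xls_data
#   [dataset]         validation_split
#   [training]        device, batch_size, epochs
#   [hyperparameters] learning_rate, kl_weight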

model = cnn.CNN()

# Convert the model to a Bayesian model. dnn_to_bnn modifies the network in
# place (it returns None) and takes its prior/posterior settings as a dict;
# the posterior-init values below are the library's usual defaults.
bnn_prior = {
    "prior_mu": 0.0,
    "prior_sigma": 0.1,
    "posterior_mu_init": 0.0,
    "posterior_rho_init": -3.0,
    "type": "Reparameterization",
    "moped_enable": False,
}
dnn_to_bnn(model, bnn_prior)
model = model.to(config["training"]["device"])  # match the device the data is prepared on
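
# After conversion every Conv/Linear layer holds a weight distribution rather
# than a point estimate, so each forward pass samples fresh weights and
# repeated passes over the same input give slightly different outputs.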

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=config["hyperparameters"]["learning_rate"]
)

train_set, val_set, test_set = prepare_datasets(
    config["paths"]["mri_data"],
    config["paths"]["xls_data"],
    config["dataset"]["validation_split"],
    config["training"]["device"],
)
train_loader, val_loader, test_loader = initalize_dataloaders(
    train_set, val_set, test_set, config["training"]["batch_size"]
)

# Train the model
for epoch in range(config["training"]["epochs"]):
    print(f"Epoch {epoch + 1} / {config['training']['epochs']}")
    model.train()
    for data, target in tqdm(train_loader):
        optimizer.zero_grad()
        output = model(data)
        # ELBO-style loss: cross-entropy plus the KL divergence of the weight
        # posteriors. get_kl_loss takes only the model, so the configured
        # kl_weight is applied here; dividing the KL term (not the whole loss,
        # which is already batch-averaged) by the batch size keeps the two
        # terms on a comparable scale.
        loss = criterion(output, target)
        loss += config["hyperparameters"]["kl_weight"] * get_kl_loss(model) / len(data)
        loss.backward()
        optimizer.step()

    # Test the model on the validation split after each epoch
    # (a minimal single-pass accuracy check)
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in val_loader:
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    print(f"Validation accuracy: {correct / total:.4f}")
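
# A Bayesian network is usually queried by averaging several stochastic forward
# passes (Monte Carlo sampling) rather than trusting a single one. A minimal
# sketch on the test split; the 10-sample count is an arbitrary choice:
model.eval()
with torch.no_grad():
    for data, _ in test_loader:
        probs = torch.stack(
            [torch.softmax(model(data), dim=1) for _ in range(10)]
        )
        mean_prob = probs.mean(dim=0)   # predictive mean per class
        uncertainty = probs.std(dim=0)  # per-class spread = predictive uncertainty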