bayesian.py

from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss
import torch
import torch.nn as nn
import os
import tomli as toml
from tqdm import tqdm

from utils.models import cnn
from utils.data.datasets import prepare_datasets, initalize_dataloaders
# CONFIGURATION
# Load the TOML config from the path in ADL_CONFIG_PATH if set,
# otherwise fall back to ./config.toml.
config_path = os.getenv("ADL_CONFIG_PATH", "config.toml")
with open(config_path, "rb") as f:
    config = toml.load(f)
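# A minimal sketch of the expected config layout (key names inferred from
# the lookups below; the values shown are illustrative placeholders, not
# taken from the original repository):
#
#   [paths]
#   mri_data = "..."
#   xls_data = "..."
#
#   [dataset]
#   validation_split = 0.2
#
#   [training]
#   device = "cuda"
#   batch_size = 32
#   epochs = 10
#
#   [hyperparameters]
#   learning_rate = 1e-4
#   kl_weight = 1.0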
model = cnn.CNN()

# Convert the model to a Bayesian model. dnn_to_bnn takes a dict of prior
# parameters and modifies the model in place rather than returning a new
# model. Only prior_mu and prior_sigma come from the original code; the
# remaining entries follow the bayesian_torch example defaults.
bnn_prior_parameters = {
    "prior_mu": 0.0,
    "prior_sigma": 0.1,
    "posterior_mu_init": 0.0,
    "posterior_rho_init": -3.0,
    "type": "Reparameterization",
    "moped_enable": False,
}
dnn_to_bnn(model, bnn_prior_parameters)
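# After conversion, each supported layer (e.g. nn.Conv2d, nn.Linear) keeps a
# learned mean and rho per weight instead of a point estimate, and every
# forward pass draws a fresh weight sample via the reparameterization trick,
# so repeated calls to model(x) give different outputs by design.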
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=config["hyperparameters"]["learning_rate"]
)
train_set, val_set, test_set = prepare_datasets(
    config["paths"]["mri_data"],
    config["paths"]["xls_data"],
    config["dataset"]["validation_split"],
    config["training"]["device"],
)
train_loader, val_loader, test_loader = initalize_dataloaders(
    train_set, val_set, test_set, config["training"]["batch_size"]
)
# Train the model
for epoch in range(config["training"]["epochs"]):
    print(f"Epoch {epoch + 1} / {config['training']['epochs']}")
    model.train()
    for data, target in tqdm(train_loader):
        optimizer.zero_grad()
        output = model(data)
        # ELBO-style objective: cross-entropy plus the KL divergence of the
        # variational posteriors, scaled per sample. get_kl_loss(model) takes
        # no weight argument, so the config's kl_weight is applied here
        # rather than being passed to it as in the original.
        loss = criterion(output, target)
        kl = get_kl_loss(model)
        loss = loss + config["hyperparameters"]["kl_weight"] * kl / len(data)
        loss.backward()
        optimizer.step()
# Test the model. The original file is truncated here (it breaks off at
# "output_li"); the loop below is a minimal reconstruction that averages
# several stochastic forward passes per batch, the usual way to form a
# predictive distribution with bayesian_torch. num_mc_samples is an
# assumed name, not from the original config.
num_mc_samples = 10
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        output_list = [model(data) for _ in range(num_mc_samples)]
        output = torch.stack(output_list).mean(dim=0)
        pred = output.argmax(dim=1)
        correct += (pred == target).sum().item()
        total += target.numel()
print(f"Test accuracy: {correct / total:.4f}")
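# Because each forward pass samples new weights, the spread of the Monte
# Carlo outputs doubles as an uncertainty estimate. A hedged sketch, not in
# the original file: per-class predictive standard deviation computed from
# the same stacked samples inside the test loop.
#
#   probs = torch.stack(output_list).softmax(dim=-1)  # [num_mc, batch, classes]
#   predictive_mean = probs.mean(dim=0)
#   predictive_std = probs.std(dim=0)  # higher std => less confident prediction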