train_cnn.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # MACHINE LEARNING
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. # GENERAL USE
  6. import random as rand
  7. # SYSTEM
  8. import tomli as toml
  9. import os
  10. # DATA PROCESSING
  11. from sklearn.model_selection import train_test_split
  12. # CUSTOM MODULES
  13. import utils.models.cnn as cnn
  14. from utils.data.datasets import prepare_datasets, initalize_dataloaders
  15. import utils.training as train
  16. # CONFIGURATION
  17. if os.getenv("ADL_CONFIG_PATH") is None:
  18. with open("config.toml", "rb") as f:
  19. config = toml.load(f)
  20. else:
  21. with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
  22. config = toml.load(f)
  23. for i in range(config["training"]["runs"]):
  24. # Set up the model
  25. model = cnn.CNN(
  26. config["model"]["image_channels"],
  27. config["model"]["clin_data_channels"],
  28. config["hyperparameters"]["droprate"],
  29. ).float()
  30. criterion = nn.BCELoss()
  31. optimizer = optim.Adam(
  32. model.parameters(), lr=config["hyperparameters"]["learning_rate"]
  33. )
  34. # Generate seed for each run
  35. seed = rand.randint(0, 1000)
  36. # Prepare data
  37. train_dataset, val_dataset, test_dataset = prepare_datasets(
  38. config["paths"]["mri_data"],
  39. config["paths"]["xls_data"],
  40. config["dataset"]["validation_split"],
  41. seed,
  42. )
  43. train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders(
  44. train_dataset,
  45. val_dataset,
  46. test_dataset,
  47. config["hyperparameters"]["batch_size"],
  48. )
  49. # Train the model
  50. history = train.train_model(
  51. model, train_dataloader, val_dataloader, criterion, optimizer, config
  52. )
  53. # Save model
  54. if not os.path.exists(
  55. config["paths"]["model_output"] + "/" + str(config["model"]["name"])
  56. ):
  57. os.makedirs(
  58. config["paths"]["model_output"] + "/" + str(config["model"]["name"])
  59. )
  60. torch.save(
  61. model,
  62. config["paths"]["model_output"]
  63. + "/"
  64. + str(config["model"]["name"])
  65. + "/"
  66. + str(i)
  67. + "_s-"
  68. + str(seed)
  69. + ".pt",
  70. )