train_cnn.py

# MACHINE LEARNING
import torch
import torch.nn as nn
import torch.optim as optim

# GENERAL USE
import random as rand

# SYSTEM
import tomli as toml
import os
import warnings

# CUSTOM MODULES
import utils.models.cnn as cnn
from utils.data.datasets import prepare_datasets, initalize_dataloaders
import utils.training as train
import utils.testing as testn
from utils.system import force_init_cudnn

# CONFIGURATION
# Load the TOML config from $ADL_CONFIG_PATH if set, otherwise from ./config.toml
if os.getenv("ADL_CONFIG_PATH") is None:
    with open("config.toml", "rb") as f:
        config = toml.load(f)
else:
    with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
        config = toml.load(f)
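
# NOTE (illustrative only): this script assumes a config.toml shaped roughly like the
# sketch below. The section and key names mirror the lookups made in this file; the
# values shown are placeholders, not the project's actual settings.
#
# [training]
# device = "cuda:0"
# runs = 5
#
# [model]
# name = "cnn"
# image_channels = 1
# clin_data_channels = 2
#
# [hyperparameters]
# droprate = 0.5
# learning_rate = 0.0001
# batch_size = 64
#
# [dataset]
# validation_split = 0.2
#
# [paths]
# mri_data = "/path/to/mri_data"
# xls_data = "/path/to/clinical_data.xls"
# model_output = "/path/to/model_output"
#
# [operation]
# silent = false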

# Force cuDNN initialization
force_init_cudnn(config["training"]["device"])

for i in range(config["training"]["runs"]):
    # Set up the model
    model = (
        cnn.CNN(
            config["model"]["image_channels"],
            config["model"]["clin_data_channels"],
            config["hyperparameters"]["droprate"],
        )
        .float()
        .to(config["training"]["device"])
    )
    criterion = nn.BCELoss()
    optimizer = optim.Adam(
        model.parameters(), lr=config["hyperparameters"]["learning_rate"]
    )

    # Generate a fresh seed for each run
    seed = rand.randint(0, 1000)

    # Prepare data
    train_dataset, val_dataset, test_dataset = prepare_datasets(
        config["paths"]["mri_data"],
        config["paths"]["xls_data"],
        config["dataset"]["validation_split"],
        seed,
        config["training"]["device"],
    )
    train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders(
        train_dataset,
        val_dataset,
        test_dataset,
        config["hyperparameters"]["batch_size"],
    )

    runs_num = config["training"]["runs"]
    if not config["operation"]["silent"]:
        print(f"Training model {i + 1} / {runs_num} with seed {seed}...")

    # Train the model, suppressing library warnings during the training loop
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        history = train.train_model(
            model, train_dataloader, val_dataloader, criterion, optimizer, config
        )

    # Test the model
    test_acc = testn.test_model(model, test_dataloader, config)

    # Save the trained model, its training history, and its test accuracy
    model_dir = os.path.join(
        config["paths"]["model_output"], str(config["model"]["name"])
    )
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    model_save_path = os.path.join(model_dir, f"{i + 1}_s-{seed}")

    torch.save(model, model_save_path + ".pt")
    history.to_csv(model_save_path + "_history.csv", index=True)
    with open(model_save_path + "_test_acc.txt", "w") as f:
        f.write(str(test_acc))
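
# Note: torch.save(model, ...) above pickles the entire model object, so reloading the
# saved .pt file later requires the utils.models.cnn class definition to be importable
# (and, on recent PyTorch releases, passing weights_only=False to torch.load).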