# train_cnn.py
#
# Trains `runs` independent CNN models on MRI + clinical data, all sharing one
# random seed for the dataset split, and saves each model, its training
# history, and its test accuracy under <model_output>/<model name>/.

# MACHINE LEARNING
import torch
import torch.nn as nn
import torch.optim as optim
import shutil  # NOTE(review): unused here — possibly needed by other tooling; confirm before removing

# GENERAL USE
import random as rand

# SYSTEM
import tomli as toml
import os
import warnings

# DATA PROCESSING
# CUSTOM MODULES
import utils.models.cnn as cnn
from utils.data.datasets import prepare_datasets, initalize_dataloaders
import utils.training as train
import utils.testing as testn
from utils.system import force_init_cudnn

# CONFIGURATION
# Use the path from ADL_CONFIG_PATH when set (and non-empty); otherwise fall
# back to ./config.toml. A single open() replaces the duplicated if/else.
config_path = os.getenv("ADL_CONFIG_PATH") or "config.toml"
with open(config_path, "rb") as f:
    config = toml.load(f)

# Force cuDNN initialization up front so later CUDA calls don't pay the cost.
force_init_cudnn(config["training"]["device"])

# One seed shared by every run in this set: all runs see the same data split,
# so they differ only in weight initialization.
seed = rand.randint(0, 1000)

for i in range(config["training"]["runs"]):
    # Set up the model
    model = (
        cnn.CNN(
            config["model"]["image_channels"],
            config["model"]["clin_data_channels"],
            config["hyperparameters"]["droprate"],
        )
        .float()
        .to(config["training"]["device"])
    )
    criterion = nn.BCELoss()
    optimizer = optim.Adam(
        model.parameters(), lr=config["hyperparameters"]["learning_rate"]
    )

    # Prepare data: split is deterministic for the shared seed.
    train_dataset, val_dataset, test_dataset = prepare_datasets(
        config["paths"]["mri_data"],
        config["paths"]["xls_data"],
        config["dataset"]["validation_split"],
        seed,
        config["training"]["device"],
    )
    train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders(
        train_dataset,
        val_dataset,
        test_dataset,
        config["hyperparameters"]["batch_size"],
    )

    runs_num = config["training"]["runs"]
    if not config["operation"]["silent"]:
        print(f"Training model {i + 1} / {runs_num} with seed {seed}...")

    # Train the model, silencing noisy library warnings for the duration.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        history = train.train_model(
            model, train_dataloader, val_dataloader, criterion, optimizer, config
        )

    # Test Model
    test_acc = testn.test_model(model, test_dataloader, config)

    # Save model, history, and test accuracy.
    # makedirs(exist_ok=True) replaces the racy exists()+makedirs pair.
    model_dir = os.path.join(
        config["paths"]["model_output"], str(config["model"]["name"])
    )
    os.makedirs(model_dir, exist_ok=True)
    model_save_path = os.path.join(model_dir, f"{i + 1}_s-{seed}")

    torch.save(model, model_save_path + ".pt")
    history.to_csv(model_save_path + "_history.csv", index=True)
    with open(model_save_path + "_test_acc.txt", "w") as f:
        f.write(str(test_acc))