123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
- # MACHINE LEARNING
- import torch
- import torch.nn as nn
- import torch.optim as optim
- import shutil
- # GENERAL USE
- import random as rand
- # SYSTEM
- import tomli as toml
- import os
- import warnings
- # DATA PROCESSING
- # CUSTOM MODULES
- import utils.models.cnn as cnn
- from utils.data.datasets import prepare_datasets, initalize_dataloaders
- import utils.training as train
- import utils.testing as testn
- from utils.system import force_init_cudnn
# CONFIGURATION
# The config file location comes from the ADL_CONFIG_PATH environment
# variable; when the variable is unset, "config.toml" in the current working
# directory is used. TOML must be opened in binary mode for toml.load().
with open(os.getenv("ADL_CONFIG_PATH", "config.toml"), "rb") as f:
    config = toml.load(f)
# Force cuDNN initialization
force_init_cudnn(config["training"]["device"])

# Generate seed for each set of runs: the seed is drawn once here and the same
# value is passed to prepare_datasets() and used in the output file names for
# every run of this invocation.
seed = rand.randint(0, 1000)

# Number of independent models to train; read once and reused for both the
# loop bound and the progress message.
runs_num = config["training"]["runs"]

for i in range(runs_num):
    # Set up the model (a fresh model, loss and optimizer for every run)
    model = (
        cnn.CNN(
            config["model"]["image_channels"],
            config["model"]["clin_data_channels"],
            config["hyperparameters"]["droprate"],
        )
        .float()
        .to(config["training"]["device"])
    )
    criterion = nn.BCELoss()
    optimizer = optim.Adam(
        model.parameters(), lr=config["hyperparameters"]["learning_rate"]
    )

    # Prepare data
    # NOTE(review): the datasets are rebuilt on every run with the same seed;
    # this is kept inside the loop because prepare_datasets() is opaque from
    # here — confirm whether it could be hoisted.
    train_dataset, val_dataset, test_dataset = prepare_datasets(
        config["paths"]["mri_data"],
        config["paths"]["xls_data"],
        config["dataset"]["validation_split"],
        seed,
        config["training"]["device"],
    )
    train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders(
        train_dataset,
        val_dataset,
        test_dataset,
        config["hyperparameters"]["batch_size"],
    )

    if not config["operation"]["silent"]:
        print(f"Training model {i + 1} / {runs_num} with seed {seed}...")

    # Train the model
    # All warnings raised during training are suppressed.
    # NOTE(review): the original indentation was lost; it is assumed that only
    # the training call sits inside the catch_warnings() context — confirm.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        history = train.train_model(
            model, train_dataloader, val_dataloader, criterion, optimizer, config
        )

    # Test Model
    tes_acc = testn.test_model(model, test_dataloader, config)

    # Save model
    # Outputs go to <model_output>/<model name>/, one set of files per run,
    # named "<run number>_s-<seed>" plus a suffix.
    model_dir = f"{config['paths']['model_output']}/{config['model']['name']}"
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test when several jobs share the output directory.
    os.makedirs(model_dir, exist_ok=True)
    model_save_path = f"{model_dir}/{i + 1}_s-{seed}"

    # The whole model object is serialized (not just its state_dict).
    torch.save(model, model_save_path + ".pt")
    # history is written via its to_csv() method, index included.
    history.to_csv(model_save_path + "_history.csv", index=True)
    with open(model_save_path + "_test_acc.txt", "w") as f:
        f.write(str(tes_acc))
|