@@ -23,57 +23,69 @@ from utils.system import force_init_cudnn
 
 
 # CONFIGURATION
-if os.getenv("ADL_CONFIG_PATH") is None:
-    with open("config.toml", "rb") as f:
+if os.getenv('ADL_CONFIG_PATH') is None:
+    with open('config.toml', 'rb') as f:
         config = toml.load(f)
 else:
-    with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
+    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
         config = toml.load(f)
 
 # Force cuDNN initialization
-force_init_cudnn(config["training"]["device"])
+force_init_cudnn(config['training']['device'])
 # Generate seed for each set of runs
 seed = rand.randint(0, 1000)
 
-for i in range(config["training"]["runs"]):
+# Prepare data
+train_dataset, val_dataset, test_dataset = prepare_datasets(
+    config['paths']['mri_data'],
+    config['paths']['xls_data'],
+    config['dataset']['validation_split'],
+    seed,
+    config['training']['device'],
+)
+train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders(
+    train_dataset,
+    val_dataset,
+    test_dataset,
+    config['hyperparameters']['batch_size'],
+)
+
+# Save datasets
+model_folder_path = (
+    config['paths']['model_output'] + '/' + str(config['model']['name']) + '/'
+)
+
+if not os.path.exists(model_folder_path):
+    os.makedirs(model_folder_path)
+
+torch.save(train_dataset, model_folder_path + 'train_dataset.pt')
+torch.save(val_dataset, model_folder_path + 'val_dataset.pt')
+torch.save(test_dataset, model_folder_path + 'test_dataset.pt')
+
+for i in range(config['training']['runs']):
     # Set up the model
     model = (
         cnn.CNN(
-            config["model"]["image_channels"],
-            config["model"]["clin_data_channels"],
-            config["hyperparameters"]["droprate"],
+            config['model']['image_channels'],
+            config['model']['clin_data_channels'],
+            config['hyperparameters']['droprate'],
         )
         .float()
-        .to(config["training"]["device"])
+        .to(config['training']['device'])
     )
     criterion = nn.BCELoss()
     optimizer = optim.Adam(
-        model.parameters(), lr=config["hyperparameters"]["learning_rate"]
+        model.parameters(), lr=config['hyperparameters']['learning_rate']
     )
 
-    # Prepare data
-    train_dataset, val_dataset, test_dataset = prepare_datasets(
-        config["paths"]["mri_data"],
-        config["paths"]["xls_data"],
-        config["dataset"]["validation_split"],
-        seed,
-        config["training"]["device"],
-    )
-    train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders(
-        train_dataset,
-        val_dataset,
-        test_dataset,
-        config["hyperparameters"]["batch_size"],
-    )
+    runs_num = config['training']['runs']
 
-    runs_num = config["training"]["runs"]
-
-    if not config["operation"]["silent"]:
-        print(f"Training model {i + 1} / {runs_num} with seed {seed}...")
+    if not config['operation']['silent']:
+        print(f'Training model {i + 1} / {runs_num} with seed {seed}...')
 
     # Train the model
     with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
+        warnings.simplefilter('ignore')
 
         history = train.train_model(
             model, train_dataloader, val_dataloader, criterion, optimizer, config
@@ -84,31 +96,23 @@ for i in range(config["training"]["runs"]):
 
     # Save model
     if not os.path.exists(
-        config["paths"]["model_output"] + "/" + str(config["model"]["name"])
+        model_folder_path + 'models/'
     ):
         os.makedirs(
-            config["paths"]["model_output"] + "/" + str(config["model"]["name"])
+            model_folder_path + 'models/'
         )
 
-    model_save_path = (
-        config["paths"]["model_output"]
-        + "/"
-        + str(config["model"]["name"])
-        + "/"
-        + str(i + 1)
-        + "_s-"
-        + str(seed)
-    )
+    model_save_path = model_folder_path + 'models/' + str(i + 1) + '_s-' + str(seed)
 
     torch.save(
         model,
-        model_save_path + ".pt",
+        model_save_path + '.pt',
     )
 
     history.to_csv(
-        model_save_path + "_history.csv",
+        model_save_path + '_history.csv',
         index=True,
     )
 
-    with open(model_save_path + "_test_acc.txt", "w") as f:
-        f.write(str(tes_acc))
+    with open(model_folder_path + 'summary.txt', 'a') as f:
+        f.write(f'{i + 1}: Test Accuracy: {tes_acc}\n')