|
@@ -2,12 +2,9 @@
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
-import torchvision
|
|
|
|
|
|
#GENERAL USE
|
|
|
-import numpy as np
|
|
|
-import pandas as pd
|
|
|
-from datetime import datetime
|
|
|
+import random as rand
|
|
|
|
|
|
#SYSTEM
|
|
|
import tomli as toml
|
|
@@ -18,6 +15,8 @@ from sklearn.model_selection import train_test_split
|
|
|
|
|
|
#CUSTOM MODULES
|
|
|
import utils.models.cnn as cnn
|
|
|
+from utils.data.datasets import prepare_datasets, initalize_dataloaders
|
|
|
+import utils.training as train
|
|
|
|
|
|
#CONFIGURATION
|
|
|
if os.getenv('ADL_CONFIG_PATH') is None:
|
|
@@ -27,14 +26,25 @@ else:
|
|
|
with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
|
|
|
config = toml.load(f)
|
|
|
|
|
|
-
|
|
|
-#Set up the model
|
|
|
-model = cnn.CNN(config)
|
|
|
-criterion = nn.BCELoss()
|
|
|
-optimizer = optim.Adam(model.parameters(), lr = config['training']['learning_rate'])
|
|
|
-
|
|
|
-#Load datasets
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
+for i in range(config['training']['runs']):
|
|
|
+ #Set up the model
|
|
|
+ model = cnn.CNN(config['model']['image_channels'], config['model']['clin_data_channels'], config['hyperparameters']['droprate']).float()
|
|
|
+ criterion = nn.BCELoss()
|
|
|
+ optimizer = optim.Adam(model.parameters(), lr = config['hyperparameters']['learning_rate'])
|
|
|
+
|
|
|
+ #Generate seed for each run
|
|
|
+ seed = rand.randint(0, 1000)
|
|
|
+
|
|
|
+ #Prepare data
|
|
|
+ train_dataset, val_dataset, test_dataset = prepare_datasets(config['paths']['mri_data'], config['paths']['xls_data'], config['dataset']['validation_split'], seed)
|
|
|
+ train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders(train_dataset, val_dataset, test_dataset, config['hyperparameters']['batch_size'])
|
|
|
+
|
|
|
+ #Train the model
|
|
|
+ history = train.train_model(model, train_dataloader, val_dataloader, criterion, optimizer, config)
|
|
|
+
|
|
|
+ #Save model
|
|
|
+ if not os.path.exists(config['paths']['model_output'] + "/" + str(config['model']['name'])):
|
|
|
+ os.makedirs(config['paths']['model_output'] + "/" + str(config['model']['name']))
|
|
|
+
|
|
|
+ torch.save(model, config['paths']['model_output'] + "/" + str(config['model']['name']) + "/" + str(i) + "_" + "s-" + str(seed) + ".pt")
|
|
|
|