| 
														
															@@ -23,57 +23,69 @@ from utils.system import force_init_cudnn 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 # CONFIGURATION 
														 | 
														
														 | 
														
															 # CONFIGURATION 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-if os.getenv("ADL_CONFIG_PATH") is None: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    with open("config.toml", "rb") as f: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+if os.getenv('ADL_CONFIG_PATH') is None: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    with open('config.toml', 'rb') as f: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         config = toml.load(f) 
														 | 
														
														 | 
														
															         config = toml.load(f) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 else: 
														 | 
														
														 | 
														
															 else: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         config = toml.load(f) 
														 | 
														
														 | 
														
															         config = toml.load(f) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 # Force cuDNN initialization 
														 | 
														
														 | 
														
															 # Force cuDNN initialization 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-force_init_cudnn(config["training"]["device"]) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+force_init_cudnn(config['training']['device']) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 # Generate seed for each set of runs 
														 | 
														
														 | 
														
															 # Generate seed for each set of runs 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 seed = rand.randint(0, 1000) 
														 | 
														
														 | 
														
															 seed = rand.randint(0, 1000) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-for i in range(config["training"]["runs"]): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# Prepare data 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+train_dataset, val_dataset, test_dataset = prepare_datasets( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    config['paths']['mri_data'], 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    config['paths']['xls_data'], 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    config['dataset']['validation_split'], 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    seed, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    config['training']['device'], 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    train_dataset, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    val_dataset, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    test_dataset, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    config['hyperparameters']['batch_size'], 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# Save datasets 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+model_folder_path = ( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    config['paths']['model_output'] + '/' + str(config['model']['name']) + '/' 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+if not os.path.exists(model_folder_path): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    os.makedirs(model_folder_path) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+torch.save(train_dataset, model_folder_path + 'train_dataset.pt') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+torch.save(val_dataset, model_folder_path + 'val_dataset.pt') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+torch.save(test_dataset, model_folder_path + 'test_dataset.pt') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+for i in range(config['training']['runs']): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     # Set up the model 
														 | 
														
														 | 
														
															     # Set up the model 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     model = ( 
														 | 
														
														 | 
														
															     model = ( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         cnn.CNN( 
														 | 
														
														 | 
														
															         cnn.CNN( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            config["model"]["image_channels"], 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            config["model"]["clin_data_channels"], 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            config["hyperparameters"]["droprate"], 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            config['model']['image_channels'], 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            config['model']['clin_data_channels'], 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            config['hyperparameters']['droprate'], 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         ) 
														 | 
														
														 | 
														
															         ) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         .float() 
														 | 
														
														 | 
														
															         .float() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        .to(config["training"]["device"]) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        .to(config['training']['device']) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     ) 
														 | 
														
														 | 
														
															     ) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     criterion = nn.BCELoss() 
														 | 
														
														 | 
														
															     criterion = nn.BCELoss() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     optimizer = optim.Adam( 
														 | 
														
														 | 
														
															     optimizer = optim.Adam( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        model.parameters(), lr=config["hyperparameters"]["learning_rate"] 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        model.parameters(), lr=config['hyperparameters']['learning_rate'] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     ) 
														 | 
														
														 | 
														
															     ) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    # Prepare data 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    train_dataset, val_dataset, test_dataset = prepare_datasets( 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        config["paths"]["mri_data"], 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        config["paths"]["xls_data"], 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        config["dataset"]["validation_split"], 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        seed, 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        config["training"]["device"], 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    ) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders( 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        train_dataset, 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        val_dataset, 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        test_dataset, 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        config["hyperparameters"]["batch_size"], 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    ) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    runs_num = config['training']['runs'] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    runs_num = config["training"]["runs"] 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    if not config["operation"]["silent"]: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        print(f"Training model {i + 1} / {runs_num} with seed {seed}...") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    if not config['operation']['silent']: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        print(f'Training model {i + 1} / {runs_num} with seed {seed}...') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     # Train the model 
														 | 
														
														 | 
														
															     # Train the model 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     with warnings.catch_warnings(): 
														 | 
														
														 | 
														
															     with warnings.catch_warnings(): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        warnings.simplefilter("ignore") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        warnings.simplefilter('ignore') 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															         history = train.train_model( 
														 | 
														
														 | 
														
															         history = train.train_model( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															             model, train_dataloader, val_dataloader, criterion, optimizer, config 
														 | 
														
														 | 
														
															             model, train_dataloader, val_dataloader, criterion, optimizer, config 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -84,31 +96,23 @@ for i in range(config["training"]["runs"]): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     # Save model 
														 | 
														
														 | 
														
															     # Save model 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     if not os.path.exists( 
														 | 
														
														 | 
														
															     if not os.path.exists( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        config["paths"]["model_output"] + "/" + str(config["model"]["name"]) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        config['paths']['model_output'] + '/' + str(config['model']['name']) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     ): 
														 | 
														
														 | 
														
															     ): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         os.makedirs( 
														 | 
														
														 | 
														
															         os.makedirs( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            config["paths"]["model_output"] + "/" + str(config["model"]["name"]) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            config['paths']['model_output'] + '/' + str(config['model']['name']) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         ) 
														 | 
														
														 | 
														
															         ) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    model_save_path = ( 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        config["paths"]["model_output"] 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        + "/" 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        + str(config["model"]["name"]) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        + "/" 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        + str(i + 1) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        + "_s-" 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        + str(seed) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    ) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    model_save_path = model_folder_path + 'models/' + str(i + 1) + '_s-' + str(seed) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     torch.save( 
														 | 
														
														 | 
														
															     torch.save( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         model, 
														 | 
														
														 | 
														
															         model, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        model_save_path + ".pt", 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        model_save_path + '.pt', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     ) 
														 | 
														
														 | 
														
															     ) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     history.to_csv( 
														 | 
														
														 | 
														
															     history.to_csv( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        model_save_path + "_history.csv", 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        model_save_path + '_history.csv', 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         index=True, 
														 | 
														
														 | 
														
															         index=True, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     ) 
														 | 
														
														 | 
														
															     ) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    with open(model_save_path + "_test_acc.txt", "w") as f: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        f.write(str(tes_acc)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    with open(model_save_path + 'summary.txt', 'a') as f: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        f.write(f'{i + 1}: Test Accuracy: {tes_acc}\n') 
														 |