Explorar o código

Path saving fixes

Nicholas Schense hai 5 meses
pai
achega
ca0b9e9c1f
Modificáronse 2 ficheiros con 13 adicións e 11 borrados
  1. 4 4
      config.toml
  2. 9 7
      threshold.py

+ 4 - 4
config.toml

@@ -7,8 +7,8 @@ model_output = '/export/home/nschense/alzheimers/alzheimers_nn/saved_models/'
 
 [training]
 device = 'cuda:1'
-runs = 3
-max_epochs = 3
+runs = 30
+max_epochs = 10
 
 [dataset]
 validation_split = 0.4 #Splits the dataset into the train and validation/test set, 50% each for validation and test
@@ -16,7 +16,7 @@ validation_split = 0.4 #Splits the dataset into the train and validation/test se
 #|splt*0.5  | split*0.5      | 1-split   |
 
 [model]
-name = 'cnn-3x3'
+name = 'cnn-30x10'
 image_channels = 1
 clin_data_channels = 2
 
@@ -29,5 +29,5 @@ droprate = 0.5
 silent = false
 
 [ensemble]
-name = 'cnn-3x3'
+name = 'cnn-30x10'
 prune_threshold = 0.0 # Any models with accuracy below this threshold will be pruned, set to 0 to disable pruning

+ 9 - 7
threshold.py

@@ -23,18 +23,20 @@ else:
 # This function returns a list of the accuracies given a threshold
 def threshold(config):
     # First, get the model data
-    ts, vs, test_set = prepare_datasets(
-        config['paths']['mri_data'],
-        config['paths']['xls_data'],
-        config['dataset']['validation_split'],
-        944,
-        config['training']['device'],
+    test_set = torch.load(
+        config['paths']['model_output']
+        + config['ensemble']['name']
+        + '/test_dataset.pt'
+    )
+
+    vs = torch.load(
+        config['paths']['model_output'] + config['ensemble']['name'] + '/val_dataset.pt'
     )
 
     test_set = test_set + vs
 
     models, _ = ens.load_models(
-        config['paths']['model_output'] + config['ensemble']['name'] + '/',
+        config['paths']['model_output'] + config['ensemble']['name'] + '/models/',
         config['training']['device'],
     )