Quellcode durchsuchen

Fixed summary saving

Nicholas Schense vor 5 Monaten
Ursprung
Commit
c19d086c91
2 geänderte Dateien mit 7 neuen und 7 gelöschten Zeilen
  1. 4 4
      config.toml
  2. 3 3
      train_cnn.py

+ 4 - 4
config.toml

@@ -7,8 +7,8 @@ model_output = '/export/home/nschense/alzheimers/alzheimers_nn/saved_models/'
 
 [training]
 device = 'cuda:1'
-runs = 10
-max_epochs = 10
+runs = 3
+max_epochs = 3
 
 [dataset]
 validation_split = 0.4 #Splits the dataset into the train and validation/test set, 50% each for validation and test
@@ -16,7 +16,7 @@ validation_split = 0.4 #Splits the dataset into the train and validation/test se
 #|splt*0.5  | split*0.5      | 1-split   |
 
 [model]
-name = 'cnn-10x10'
+name = 'cnn-3x3'
 image_channels = 1
 clin_data_channels = 2
 
@@ -29,5 +29,5 @@ droprate = 0.5
 silent = false
 
 [ensemble]
-name = 'cnn-10x10'
+name = 'cnn-3x3'
 prune_threshold = 0.0 # Any models with accuracy below this threshold will be pruned, set to 0 to disable pruning

+ 3 - 3
train_cnn.py

@@ -96,10 +96,10 @@ for i in range(config['training']['runs']):
 
     # Save model
     if not os.path.exists(
-        config['paths']['model_output'] + '/' + str(config['model']['name'])
+        config['paths']['model_output'] + str(config['model']['name'] + '/models/')
     ):
         os.makedirs(
-            config['paths']['model_output'] + '/' + str(config['model']['name'])
+            config['paths']['model_output'] + str(config['model']['name']) + '/models/'
         )
 
     model_save_path = model_folder_path + 'models/' + str(i + 1) + '_s-' + str(seed)
@@ -114,5 +114,5 @@ for i in range(config['training']['runs']):
         index=True,
     )
 
-    with open(model_save_path + 'summary.txt', 'a') as f:
+    with open(model_folder_path + 'summary.txt', 'a') as f:
         f.write(f'{i + 1}: Test Accuracy: {tes_acc}\n')