Kaynağa Gözat

Tweaks for ensembles, trained 10x10 ensemble, added ensemble pruning

Nicholas Schense 5 ay önce
ebeveyn
işleme
83204b3b99
4 değiştirilmiş dosya ile 38 ekleme ve 12 silme
  1. 2 1
      config.toml
  2. 11 3
      ensemble_predict.py
  3. 3 3
      train_cnn.py
  4. 22 5
      utils/ensemble.py

+ 2 - 1
config.toml

@@ -27,4 +27,5 @@ droprate = 0.5
 silent = false
 
 [ensemble]
-name = 'cnn-ensemble1'
+name = 'cnn-ensemble10x10'
+prune_threshold = 0.7      # Any models with accuracy below this threshold will be pruned, set to 0 to disable pruning

+ 11 - 3
ensemble_predict.py

@@ -20,6 +20,9 @@ force_init_cudnn(config["training"]["device"])
 
 ensemble_folder = config["paths"]["model_output"] + config["ensemble"]["name"] + "/"
 models, model_descs = ens.load_models(ensemble_folder, config["training"]["device"])
+models, model_descs = ens.prune_models(
+    models, model_descs, ensemble_folder, config["ensemble"]["prune_threshold"]
+)
 
 # Load test data
 test_dataset = prepare_datasets(
@@ -39,6 +42,7 @@ stdevs = []
 yes_votes = []
 no_votes = []
 
+
 for data, target in test_dataset:
     mri, xls = data
     mri = mri.unsqueeze(0)
@@ -62,7 +66,11 @@ for data, target in test_dataset:
 
 accuracy = correct / total
 
-with open(ensemble_folder + "ensemble_test_results.txt", "w") as f:
+with open(
+    ensemble_folder
+    + f"ensemble_test_results_{config['ensemble']['prune_threshold']}.txt",
+    "w",
+) as f:
     f.write("Accuracy: " + str(accuracy) + "\n")
     f.write("Correct: " + str(correct) + "\n")
     f.write("Total: " + str(total) + "\n")
@@ -75,8 +83,8 @@ with open(ensemble_folder + "ensemble_test_results.txt", "w") as f:
             + ", "
             + str(stdev)
             + ", "
-            + str(yes_votes.item())
+            + str(yes_votes)
             + ", "
-            + str(no_votes.item())
+            + str(no_votes)
             + "\n"
         )

+ 3 - 3
train_cnn.py

@@ -66,10 +66,10 @@ for i in range(config["training"]["runs"]):
         config["hyperparameters"]["batch_size"],
     )
 
+    runs_num = config["training"]["runs"]
+
     if not config["operation"]["silent"]:
-        print(
-            f"Training model {i + 1} / {config["training"]["runs"]} with seed {seed}..."
-        )
+        print(f"Training model {i + 1} / {runs_num} with seed {seed}...")
 
     # Train the model
     with warnings.catch_warnings():

+ 22 - 5
utils/ensemble.py

@@ -16,7 +16,10 @@ def load_models(folder, device):
     for model_file in model_files:
         model = torch.load(model_file, map_location=device)
         models.append(model)
-        model_descs.append(os.path.basename(model_file))
+
+        # Extract model description from filename
+        desc = os.path.basename(model_file)
+        model_descs.append(os.path.splitext(desc)[0])
 
     return models, model_descs
 
@@ -46,10 +49,24 @@ def ensemble_predict_strict_classes(models, input):
             # Apply model and extract prediction
             output = model(input)
             _, predicted = torch.max(output.data, 1)
-            predictions.append(predicted)
+            predictions.append(predicted.item())
 
-    predictions = torch.stack(predictions)
-    pos_votes = (predictions == 1).sum()
-    neg_votes = (predictions == 0).sum()
+    pos_votes = len([p for p in predictions if p == 1])
+    neg_votes = len([p for p in predictions if p == 0])
 
     return pos_votes / len(models), pos_votes, neg_votes
+
+
+# Prune the ensemble by removing models with low accuracy on the test set, as determined in their tes_acc.txt files
+def prune_models(models, model_descs, folder, threshold):
+    new_models = []
+    new_descs = []
+
+    for model, desc in zip(models, model_descs):
+        with open(folder + desc + "_test_acc.txt", "r") as f:
+            acc = float(f.read())
+            if acc >= threshold:
+                new_models.append(model)
+                new_descs.append(desc)
+
+    return new_models, new_descs