|
@@ -16,7 +16,10 @@ def load_models(folder, device):
|
|
|
for model_file in model_files:
|
|
|
model = torch.load(model_file, map_location=device)
|
|
|
models.append(model)
|
|
|
- model_descs.append(os.path.basename(model_file))
|
|
|
+
|
|
|
+ # Extract model description from filename
|
|
|
+ desc = os.path.basename(model_file)
|
|
|
+ model_descs.append(os.path.splitext(desc)[0])
|
|
|
|
|
|
return models, model_descs
|
|
|
|
|
@@ -46,10 +49,24 @@ def ensemble_predict_strict_classes(models, input):
|
|
|
# Apply model and extract prediction
|
|
|
output = model(input)
|
|
|
_, predicted = torch.max(output.data, 1)
|
|
|
- predictions.append(predicted)
|
|
|
+ predictions.append(predicted.item())
|
|
|
|
|
|
- predictions = torch.stack(predictions)
|
|
|
- pos_votes = (predictions == 1).sum()
|
|
|
- neg_votes = (predictions == 0).sum()
|
|
|
+ pos_votes = len([p for p in predictions if p == 1])
|
|
|
+ neg_votes = len([p for p in predictions if p == 0])
|
|
|
|
|
|
return pos_votes / len(models), pos_votes, neg_votes
|
|
|
+
|
|
|
+
|
|
|
+# Prune the ensemble by removing models with low accuracy on the test set, as determined in their tes_acc.txt files
|
|
|
+def prune_models(models, model_descs, folder, threshold):
|
|
|
+ new_models = []
|
|
|
+ new_descs = []
|
|
|
+
|
|
|
+ for model, desc in zip(models, model_descs):
|
|
|
+ with open(folder + desc + "_test_acc.txt", "r") as f:
|
|
|
+ acc = float(f.read())
|
|
|
+ if acc >= threshold:
|
|
|
+ new_models.append(model)
|
|
|
+ new_descs.append(desc)
|
|
|
+
|
|
|
+ return new_models, new_descs
|