"""Functions to ensemble a folder of saved PyTorch models and prune them
by recorded test-set accuracy, with simple uncertainty estimation."""

import os
from glob import glob

import torch
def load_models(folder, device):
    """Load every ``*.pt`` model in ``folder`` onto ``device``."""
    model_files = glob(os.path.join(folder, "*.pt"))
    models = []
    model_descs = []
    for model_file in model_files:
        # Assumes each .pt file holds a full pickled model object
        # (saved via ``torch.save(model, path)``), not just a state_dict.
        model = torch.load(model_file, map_location=device)
        models.append(model)
        # Use the filename (without extension) as the model's description.
        model_descs.append(os.path.splitext(os.path.basename(model_file))[0])
    return models, model_descs
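
# Assumed folder layout (inferred from the loading and pruning code in this
# module; the file names below are hypothetical):
#
#   models/
#       net_a.pt            full model object saved via torch.save(model, path)
#       net_a_test_acc.txt  plain-text test accuracy, e.g. "0.93"
#       net_b.pt
#       net_b_test_acc.txt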
def ensemble_predict(models, inputs):
    """Soft-voting prediction: mean positive-class score across models,
    with the across-model variance as an uncertainty estimate."""
    predictions = []
    for model in models:
        model.eval()
        with torch.no_grad():
            # Apply the model and keep the positive-class column
            # (assumes a two-column output of class scores).
            output = model(inputs)[:, 1]
        predictions.append(output)
    # Stack to shape (n_models, batch) and reduce across the model axis,
    # giving one mean and one variance per input.
    predictions = torch.stack(predictions)
    return predictions.mean(dim=0), predictions.var(dim=0)
def ensemble_predict_strict_classes(models, inputs):
    """Hard-voting prediction for a single input: each model casts one vote
    for class 0 or 1; returns the positive-vote fraction and both counts."""
    predictions = []
    for model in models:
        model.eval()
        with torch.no_grad():
            # Apply the model and take the argmax class
            # (``.item()`` assumes a batch of exactly one input).
            output = model(inputs)
            predicted = torch.argmax(output, dim=1)
        predictions.append(predicted.item())
    pos_votes = predictions.count(1)
    neg_votes = predictions.count(0)
    return pos_votes / len(models), pos_votes, neg_votes
# Prune the ensemble by dropping models whose test-set accuracy, as recorded
# in their ``<desc>_test_acc.txt`` files, falls below ``threshold``.
def prune_models(models, model_descs, folder, threshold):
    new_models = []
    new_descs = []
    for model, desc in zip(models, model_descs):
        # ``os.path.join`` avoids the missing-separator bug of ``folder + desc``.
        acc_path = os.path.join(folder, desc + "_test_acc.txt")
        with open(acc_path, "r") as f:
            acc = float(f.read())
        if acc >= threshold:
            new_models.append(model)
            new_descs.append(desc)
    return new_models, new_descs
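
# A minimal usage sketch, not part of the module's API: the "models" folder,
# the 0.8 threshold, and the random input below are hypothetical placeholders.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    models, model_descs = load_models("models", device)
    # Drop models whose recorded test accuracy is below 80%.
    models, model_descs = prune_models(models, model_descs, "models", 0.8)

    # Fake batch of one input; real code would use actual test data with
    # whatever shape the ensembled models expect.
    x = torch.randn(1, 3, 224, 224, device=device)
    mean, variance = ensemble_predict(models, x)
    frac, pos, neg = ensemble_predict_strict_classes(models, x)
    print(f"soft vote: mean={mean.item():.3f} var={variance.item():.3f}")
    print(f"hard vote: {pos}/{pos + neg} models predicted the positive class")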