ensemble.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import torch
  2. import os
  3. from glob import glob
  4. # This file contains functions to ensemble a folder of models and evaluate them on a test set, with included uncertainty estimation.
  5. def load_models(folder, device):
  6. glob_path = os.path.join(folder, "*.pt")
  7. model_files = glob(glob_path)
  8. models = []
  9. model_descs = []
  10. for model_file in model_files:
  11. model = torch.load(model_file, map_location=device)
  12. models.append(model)
  13. # Extract model description from filename
  14. desc = os.path.basename(model_file)
  15. model_descs.append(os.path.splitext(desc)[0])
  16. return models, model_descs
  17. def ensemble_predict(models, input):
  18. predictions = []
  19. for model in models:
  20. model.eval()
  21. with torch.no_grad():
  22. # Apply model and extract positive class predictions
  23. output = model(input)[:, 1]
  24. predictions.append(output)
  25. # Calculate mean and variance of predictions
  26. predictions = torch.stack(predictions)
  27. mean = predictions.mean()
  28. variance = predictions.var()
  29. return mean, variance
  30. def ensemble_predict_strict_classes(models, input):
  31. predictions = []
  32. for model in models:
  33. model.eval()
  34. with torch.no_grad():
  35. # Apply model and extract prediction
  36. output = model(input)
  37. _, predicted = torch.max(output.data, 1)
  38. predictions.append(predicted.item())
  39. pos_votes = len([p for p in predictions if p == 1])
  40. neg_votes = len([p for p in predictions if p == 0])
  41. return pos_votes / len(models), pos_votes, neg_votes
  42. # Prune the ensemble by removing models with low accuracy on the test set, as determined in their tes_acc.txt files
  43. def prune_models(models, model_descs, folder, threshold):
  44. new_models = []
  45. new_descs = []
  46. for model, desc in zip(models, model_descs):
  47. with open(folder + desc + "_test_acc.txt", "r") as f:
  48. acc = float(f.read())
  49. if acc >= threshold:
  50. new_models.append(model)
  51. new_descs.append(desc)
  52. return new_models, new_descs