ensemble.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import torch
  2. import pathlib
  3. import utils.models.cnn as c
  4. from typing import Tuple, List
  5. import xarray as xr
# A loaded model paired with a string label; load_models uses the model file's
# stem (file name without extension) as that label.
type ModelPair = Tuple[c.CNN, str]
# Container type for prediction data.
# NOTE(review): not referenced in the visible part of this file - confirm usage.
type ModelPredictionData = xr.DataArray
# Pair of tensors fed to a model; prepare_datasets names them MRI and XLS data,
# in that order.
type InputData = Tuple[torch.Tensor, torch.Tensor]
  9. # This file contains functions to ensemble a folder of models and evaluate them on a test set, with included uncertainty estimation.
  10. def load_models(folder: pathlib.Path, device: str) -> List[ModelPair]:
  11. model_files = folder.glob("*.pt")
  12. model_pairs: List[ModelPair] = []
  13. for model_file in model_files:
  14. model: c.CNN = torch.load(model_file, map_location=device, weights_only=False)
  15. # Extract model description from filename
  16. model_pairs.append((model, model_file.stem))
  17. return model_pairs
  18. def prepare_datasets(data: Tuple[torch.Tensor, torch.Tensor]) -> InputData:
  19. # Ensure the data is in the correct format
  20. mri_data.unsqueeze(0)
  21. xls_data.unsqueeze(0)
  22. # Combine MRI and XLS data into a tuple
  23. return (mri_data, xls_data)
  24. def get_model_names(models: List[ModelPair]) -> List[str]:
  25. # Extract model names from the model pairs
  26. return [model_pair[1] for model_pair in models]
  27. def get_model_objects(models: List[ModelPair]) -> List[c.CNN]:
  28. # Extract model objects from the model pairs
  29. return [model_pair[0] for model_pair in models]
  30. def ensemble_predict(models: List[c.CNN], input: InputData):
  31. predictions = []
  32. for model in models:
  33. model.eval()
  34. with torch.no_grad():
  35. # Apply model and extract positive class predictions
  36. output = model(input)[:, 1]
  37. predictions.append(output)
  38. # Calculate mean and variance of predictions
  39. predictions = torch.stack(predictions)
  40. mean = predictions.mean()
  41. variance = predictions.var()
  42. return mean, variance
  43. def ensemble_predict_strict_classes(models, input):
  44. predictions = []
  45. for model in models:
  46. model.eval()
  47. with torch.no_grad():
  48. # Apply model and extract prediction
  49. output = model(input)
  50. _, predicted = torch.max(output.data, 1)
  51. predictions.append(predicted.item())
  52. pos_votes = len([p for p in predictions if p == 1])
  53. neg_votes = len([p for p in predictions if p == 0])
  54. return pos_votes / len(models), pos_votes, neg_votes
  55. # Prune the ensemble by removing models with low accuracy on the test set, as determined in their _test_acc.txt files
  56. def prune_models(models, model_descs, folder, threshold):
  57. new_models = []
  58. new_descs = []
  59. for model, desc in zip(models, model_descs):
  60. with open(folder + desc + "_test_acc.txt", "r") as f:
  61. acc = float(f.read())
  62. if acc >= threshold:
  63. new_models.append(model)
  64. new_descs.append(desc)
  65. return new_models, new_descs