sensitivity_analysis.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. ### This program runs a sensitivity analysis to determine the best number of models to use in the ensemble.
  2. import numpy as np
  3. import pandas as pd
  4. import matplotlib.pyplot as plt
  5. import torch
  6. import os
  7. import threshold_refac as th
  8. import pickle as pk
  9. import utils.models.cnn
  10. torch.serialization.safe_globals([utils.models.cnn.CNN])
  11. def main():
  12. config = th.load_config()
  13. ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
  14. V3_PATH = ENSEMBLE_PATH + "/v3"
  15. # Create the directory if it does not exist
  16. if not os.path.exists(V3_PATH):
  17. os.makedirs(V3_PATH)
  18. # Load the models
  19. device = torch.device(config["training"]["device"])
  20. models = th.load_models_v2(f"{ENSEMBLE_PATH}/models/", device)
  21. # Load Dataset
  22. dataset = torch.load(
  23. f"{ENSEMBLE_PATH}/test_dataset.pt", weights_only=False
  24. ) + torch.load(f"{ENSEMBLE_PATH}/val_dataset.pt", weights_only=False)
  25. if config["ensemble"]["run_models"]:
  26. # Get thre predicitons of the ensemble
  27. ensemble_predictions = th.ensemble_dataset_predictions(models, dataset, device)
  28. # Save to file using pickle
  29. with open(f"{V3_PATH}/ensemble_predictions.pk", "wb") as f:
  30. pk.dump(ensemble_predictions, f)
  31. else:
  32. # Load the predictions from file
  33. with open(f"{V3_PATH}/ensemble_predictions.pk", "rb") as f:
  34. ensemble_predictions = pk.load(f)
  35. # Now that we have the predictions, we can run the sensitivity analysis
  36. # We do this by getting the stats for each possible number of models in the ensemble
  37. # We will store the results in a dataframe with number of models and the stats
  38. results = pd.DataFrame(columns=["num_models", "ECE", "accuracy"]).set_index(
  39. "num_models"
  40. )
  41. for i in range(2, len(models) + 1):
  42. sel_preds = th.select_subset_models(ensemble_predictions, range(i))
  43. sel_stats = th.calculate_statistics(sel_preds)
  44. raw_confidence = sel_stats["confidence"].apply(lambda x: (x / 2) + 0.5)
  45. sel_stats.insert(4, "raw_confidence", raw_confidence)
  46. stats = th.calculate_overall_statistics(sel_stats)
  47. ece = stats.at["raw_confidence", "ECE"]
  48. accuracy = sel_stats["correct"].mean()
  49. results.loc[i] = (ece, accuracy)
  50. # Save the results to a file
  51. results.to_csv(f"{V3_PATH}/sensitivity_analysis.csv")
  52. # Plot the results
  53. plt.plot(results.index, results["ECE"])
  54. plt.xlabel("Number of Models")
  55. plt.ylabel("ECE")
  56. plt.title("Sensitivity Analysis")
  57. plt.savefig(f"{V3_PATH}/sensitivity_analysis.png")
  58. plt.close()
  59. plt.plot(results.index, results["accuracy"])
  60. plt.xlabel("Number of Models")
  61. plt.ylabel("Accuracy")
  62. plt.title("Sensitivity Analysis")
  63. plt.savefig(f"{V3_PATH}/sensitivity_analysis_accuracy.png")
  64. plt.close()
  65. if __name__ == "__main__":
  66. main()