sensitivity_analysis.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. ### This file is a program to run a sensitivity analysis to determine the best number of models to use in the ensemble.
  2. import numpy as np
  3. import pandas as pd
  4. import matplotlib.pyplot as plt
  5. import torch
  6. import os
  7. import threshold_refac as th
  8. import pickle as pk
  9. def main():
  10. config = th.load_config()
  11. ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
  12. V3_PATH = ENSEMBLE_PATH + '/v3'
  13. # Create the directory if it does not exist
  14. if not os.path.exists(V3_PATH):
  15. os.makedirs(V3_PATH)
  16. # Load the models
  17. device = torch.device(config['training']['device'])
  18. models = th.load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
  19. # Load Dataset
  20. dataset = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
  21. f'{ENSEMBLE_PATH}/val_dataset.pt'
  22. )
  23. if config['ensemble']['run_models']:
  24. # Get thre predicitons of the ensemble
  25. ensemble_predictions = th.ensemble_dataset_predictions(models, dataset, device)
  26. # Save to file using pickle
  27. with open(f'{V3_PATH}/ensemble_predictions.pk', 'wb') as f:
  28. pk.dump(ensemble_predictions, f)
  29. else:
  30. # Load the predictions from file
  31. with open(f'{V3_PATH}/ensemble_predictions.pk', 'rb') as f:
  32. ensemble_predictions = pk.load(f)
  33. # Now that we have the predictions, we can run the sensitivity analysis
  34. #We do this by getting the stats for each possible number of models in the ensemble
  35. # We will store the results in a dataframe with number of models and the stats
  36. results = pd.DataFrame(columns=['num_models', 'ECE', 'accuracy']).set_index('num_models')
  37. for i in range(2, len(models) + 1):
  38. sel_preds = th.select_subset_models(ensemble_predictions, range(i))
  39. sel_stats = th.calculate_statistics(sel_preds)
  40. raw_confidence = sel_stats['confidence'].apply(lambda x: (x / 2) + 0.5)
  41. sel_stats.insert(4, 'raw_confidence', raw_confidence)
  42. stats = th.calculate_overall_statistics(sel_stats)
  43. ece = stats.at['raw_confidence', 'ECE']
  44. accuracy = sel_stats['correct'].mean()
  45. results.loc[i] = (ece, accuracy)
  46. # Save the results to a file
  47. results.to_csv(f'{V3_PATH}/sensitivity_analysis.csv')
  48. # Plot the results
  49. plt.plot(results.index, results['ECE'])
  50. plt.xlabel('Number of Models')
  51. plt.ylabel('ECE')
  52. plt.title('Sensitivity Analysis')
  53. plt.savefig(f'{V3_PATH}/sensitivity_analysis.png')
  54. plt.close()
  55. plt.plot(results.index, results['accuracy'])
  56. plt.xlabel('Number of Models')
  57. plt.ylabel('Accuracy')
  58. plt.title('Sensitivity Analysis')
  59. plt.savefig(f'{V3_PATH}/sensitivity_analysis_accuracy.png')
  60. plt.close()
  61. if __name__ == "__main__":
  62. main()