# This script runs a sensitivity analysis to determine the best number of models to use in the ensemble.
import os
import pickle as pk

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

import threshold_refac as th
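

# The ECE values used below come from th.calculate_overall_statistics, whose
# implementation is not shown in this file. The helper here is only an
# illustrative sketch of a standard binned expected-calibration-error
# computation over confidences in [0, 1] and a boolean 'correct' column; it is
# not called by this script, and the bin count is an assumed default.
def _binned_ece_sketch(confidence, correct, n_bins=10):
    """Illustrative binned ECE: |accuracy - mean confidence| per bin, weighted by bin mass."""
    confidence = np.asarray(confidence, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        # Include the right edge in the last bin so confidence == 1.0 is counted.
        in_bin = (confidence >= lo) & ((confidence < hi) | (hi == edges[-1]))
        if in_bin.any():
            ece += in_bin.mean() * abs(correct[in_bin].mean() - confidence[in_bin].mean())
    return ece
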
def main():
    config = th.load_config()

    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
    V3_PATH = ENSEMBLE_PATH + '/v3'

    # Create the output directory if it does not exist
    os.makedirs(V3_PATH, exist_ok=True)

    # Load the trained ensemble members onto the configured device
    device = torch.device(config['training']['device'])
    models = th.load_models_v2(f'{ENSEMBLE_PATH}/models/', device)

    # Load the dataset (test split followed by validation split)
    dataset = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
        f'{ENSEMBLE_PATH}/val_dataset.pt'
    )
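    # Note: the `+` above relies on torch Dataset.__add__, which returns a
    # ConcatDataset combining the two splits.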
    if config['ensemble']['run_models']:
        # Get the predictions of the ensemble
        ensemble_predictions = th.ensemble_dataset_predictions(models, dataset, device)

        # Save the predictions to file using pickle
        with open(f'{V3_PATH}/ensemble_predictions.pk', 'wb') as f:
            pk.dump(ensemble_predictions, f)
    else:
        # Load the predictions from file
        with open(f'{V3_PATH}/ensemble_predictions.pk', 'rb') as f:
            ensemble_predictions = pk.load(f)
    # Now that we have the predictions, we can run the sensitivity analysis.
    # We do this by computing the statistics for each possible number of models
    # in the ensemble and storing them in a dataframe indexed by the number of models.
    results = pd.DataFrame(columns=['num_models', 'ECE', 'accuracy']).set_index('num_models')

    for i in range(2, len(models) + 1):
        # Evaluate the sub-ensemble made up of the first i models
        sel_preds = th.select_subset_models(ensemble_predictions, range(i))
        sel_stats = th.calculate_statistics(sel_preds)
        # (x / 2) + 0.5 maps a [-1, 1] confidence onto [0, 1]
        raw_confidence = sel_stats['confidence'].apply(lambda x: (x / 2) + 0.5)
        sel_stats.insert(4, 'raw_confidence', raw_confidence)
        stats = th.calculate_overall_statistics(sel_stats)
        ece = stats.at['raw_confidence', 'ECE']
        accuracy = sel_stats['correct'].mean()
        results.loc[i] = (ece, accuracy)
    # Save the results to a file
    results.to_csv(f'{V3_PATH}/sensitivity_analysis.csv')

    # Plot ECE against the number of models
    plt.plot(results.index, results['ECE'])
    plt.xlabel('Number of Models')
    plt.ylabel('ECE')
    plt.title('Sensitivity Analysis')
    plt.savefig(f'{V3_PATH}/sensitivity_analysis.png')
    plt.close()

    # Plot accuracy against the number of models
    plt.plot(results.index, results['accuracy'])
    plt.xlabel('Number of Models')
    plt.ylabel('Accuracy')
    plt.title('Sensitivity Analysis')
    plt.savefig(f'{V3_PATH}/sensitivity_analysis_accuracy.png')
    plt.close()
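

# Illustrative helper (a sketch, not used by this script): one way to read the
# saved CSV back in and choose an ensemble size, e.g. the smallest ensemble
# whose ECE is within a tolerance of the best observed ECE. The function name
# and the tolerance value are assumptions, not part of the analysis above.
def pick_ensemble_size(csv_path, tolerance=0.005):
    results = pd.read_csv(csv_path, index_col='num_models')
    best_ece = results['ECE'].min()
    candidates = results[results['ECE'] <= best_ece + tolerance]
    return int(candidates.index.min())
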

if __name__ == "__main__":
    main()