"""Run a sensitivity analysis over the number of models in the ensemble.

For every ensemble size from 2 up to the number of loaded models, the script
computes the ECE and the accuracy of the ensemble built from the first
``i`` models, stores the results in a CSV file and plots both metrics
against the ensemble size.
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import os
import threshold_refac as th
import pickle as pk


def _get_ensemble_predictions(config, models, device, ensemble_path, v3_path):
    """Return the ensemble predictions, computing or loading them.

    When ``config['ensemble']['run_models']`` is truthy the models are run on
    the concatenation of the test and validation datasets and the result is
    cached in ``<v3_path>/ensemble_predictions.pk``; otherwise the cached
    predictions are read back from that file.
    """
    predictions_file = f'{v3_path}/ensemble_predictions.pk'

    if config['ensemble']['run_models']:
        # The dataset is only needed when the models are actually run, so it
        # is loaded here rather than unconditionally.
        # NOTE(review): recent PyTorch releases changed the default of
        # ``weights_only`` in ``torch.load`` — confirm these dataset files
        # still load on the PyTorch version in use.
        dataset = torch.load(f'{ensemble_path}/test_dataset.pt') + torch.load(
            f'{ensemble_path}/val_dataset.pt'
        )
        # Get the predictions of the ensemble
        ensemble_predictions = th.ensemble_dataset_predictions(
            models, dataset, device
        )
        # Save to file using pickle
        with open(predictions_file, 'wb') as f:
            pk.dump(ensemble_predictions, f)
        return ensemble_predictions

    # Load the predictions from file.
    # SECURITY: pickle can execute arbitrary code while loading; only read
    # prediction files produced by a trusted run of this script.
    with open(predictions_file, 'rb') as f:
        return pk.load(f)


def _save_metric_plot(results, column, ylabel, out_path):
    """Plot ``results[column]`` against the index and save it to *out_path*."""
    plt.plot(results.index, results[column])
    plt.xlabel('Number of Models')
    plt.ylabel(ylabel)
    plt.title('Sensitivity Analysis')
    plt.savefig(out_path)
    # Close the figure so the next plot starts from a clean canvas.
    plt.close()


def main():
    """Run the sensitivity analysis and write the CSV and the two plots.

    Outputs are written to the ``v3`` sub-directory of the ensemble
    directory: ``ensemble_predictions.pk`` (when the models are run),
    ``sensitivity_analysis.csv``, ``sensitivity_analysis.png`` (ECE) and
    ``sensitivity_analysis_accuracy.png`` (accuracy).
    """
    config = th.load_config()
    ensemble_path = (
        f"{config['paths']['model_output']}{config['ensemble']['name']}"
    )
    v3_path = ensemble_path + '/v3'

    # Create the directory if it does not exist. ``exist_ok=True`` avoids the
    # race between checking for the directory and creating it.
    os.makedirs(v3_path, exist_ok=True)

    # Load the models
    device = torch.device(config['training']['device'])
    models = th.load_models_v2(f'{ensemble_path}/models/', device)

    ensemble_predictions = _get_ensemble_predictions(
        config, models, device, ensemble_path, v3_path
    )

    # Now that we have the predictions, we can run the sensitivity analysis.
    # We do this by getting the stats for each possible number of models in
    # the ensemble. The results are stored in a dataframe indexed by the
    # number of models.
    results = pd.DataFrame(
        columns=['num_models', 'ECE', 'accuracy']
    ).set_index('num_models')

    for num_models in range(2, len(models) + 1):
        # Use the first ``num_models`` models of the ensemble.
        sel_preds = th.select_subset_models(
            ensemble_predictions, range(num_models)
        )
        sel_stats = th.calculate_statistics(sel_preds)

        # Linear rescale of the confidence: x / 2 + 0.5.
        # NOTE(review): presumably this maps a [0, 1] score onto [0.5, 1] —
        # confirm against th.calculate_statistics.
        raw_confidence = sel_stats['confidence'] / 2 + 0.5
        sel_stats.insert(4, 'raw_confidence', raw_confidence)

        stats = th.calculate_overall_statistics(sel_stats)
        ece = stats.at['raw_confidence', 'ECE']
        accuracy = sel_stats['correct'].mean()
        results.loc[num_models] = (ece, accuracy)

    # Save the results to a file
    results.to_csv(f'{v3_path}/sensitivity_analysis.csv')

    # Plot the results
    _save_metric_plot(
        results, 'ECE', 'ECE', f'{v3_path}/sensitivity_analysis.png'
    )
    _save_metric_plot(
        results,
        'accuracy',
        'Accuracy',
        f'{v3_path}/sensitivity_analysis_accuracy.png',
    )


if __name__ == "__main__":
    main()