"""Sensitivity analysis for ensemble size.

Runs (or loads cached) ensemble predictions and measures how calibration
(ECE) and accuracy change as the number of models in the ensemble grows,
to help pick the best number of models to use.
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import os
import threshold_refac as th
import pickle as pk


def _plot_metric(results, column, ylabel, out_path):
    """Plot ``results[column]`` against the ensemble size index and save it to ``out_path``."""
    plt.plot(results.index, results[column])
    plt.xlabel('Number of Models')
    plt.ylabel(ylabel)
    plt.title('Sensitivity Analysis')
    plt.savefig(out_path)
    # Close so the next plot starts on a fresh figure instead of overlaying this one.
    plt.close()


def main():
    """Run the ensemble-size sensitivity analysis and write its outputs.

    Steps:
      1. Load the config and the trained models from
         ``{paths.model_output}{ensemble.name}/models/``.
      2. Load the saved test and validation datasets and combine them with ``+``.
      3. If ``config['ensemble']['run_models']`` is truthy, run the ensemble over
         the dataset and pickle the predictions to ``v3/ensemble_predictions.pk``;
         otherwise load that pickle from a previous run.
      4. For each ensemble size ``n`` in ``2..len(models)``, select the first
         ``n`` models' predictions and record the ECE of the rescaled
         confidence and the mean of the ``correct`` column.
      5. Write ``v3/sensitivity_analysis.csv`` and two plots
         (``sensitivity_analysis.png`` for ECE,
         ``sensitivity_analysis_accuracy.png`` for accuracy).

    Raises:
        FileNotFoundError: if predictions are not re-run and no cached
            prediction file exists.
    """
    config = th.load_config()

    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
    V3_PATH = ENSEMBLE_PATH + '/v3'
    predictions_path = f'{V3_PATH}/ensemble_predictions.pk'

    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(V3_PATH, exist_ok=True)

    # Load the models
    device = torch.device(config['training']['device'])
    models = th.load_models_v2(f'{ENSEMBLE_PATH}/models/', device)

    # Load the test and validation splits and combine them into one evaluation set.
    # Presumably these are torch Datasets, for which `+` builds a ConcatDataset — confirm.
    # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which refuses
    # to unpickle arbitrary objects such as Dataset instances; if loading fails on a
    # newer torch, pass weights_only=False for these trusted, locally produced files.
    dataset = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
        f'{ENSEMBLE_PATH}/val_dataset.pt'
    )

    if config['ensemble']['run_models']:
        # Get the predictions of the ensemble and cache them for later runs.
        ensemble_predictions = th.ensemble_dataset_predictions(models, dataset, device)
        with open(predictions_path, 'wb') as f:
            pk.dump(ensemble_predictions, f)
    else:
        # Reuse predictions cached by a previous run. pickle is acceptable only because
        # this script writes the file itself; never point it at untrusted data.
        try:
            with open(predictions_path, 'rb') as f:
                ensemble_predictions = pk.load(f)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"No cached ensemble predictions at {predictions_path}; "
                "set config['ensemble']['run_models'] to true to generate them."
            ) from err

    # Sensitivity analysis: compute the stats for each possible ensemble size.
    # Rows are collected first and the DataFrame is built once, rather than grown with
    # .loc enlargement (which copies the frame each time and can leave object dtypes).
    rows = []
    for num_models in range(2, len(models) + 1):
        sel_preds = th.select_subset_models(ensemble_predictions, range(num_models))
        sel_stats = th.calculate_statistics(sel_preds)
        # Affine rescale x -> x / 2 + 0.5 (maps [0, 1] onto [0.5, 1]); presumably turns
        # the stored confidence into a probability-like score for ECE — confirm.
        raw_confidence = sel_stats['confidence'] / 2 + 0.5
        sel_stats.insert(4, 'raw_confidence', raw_confidence)
        stats = th.calculate_overall_statistics(sel_stats)
        ece = stats.at['raw_confidence', 'ECE']
        accuracy = sel_stats['correct'].mean()
        rows.append((num_models, ece, accuracy))

    results = pd.DataFrame(rows, columns=['num_models', 'ECE', 'accuracy']).set_index('num_models')

    # Save the results to a file
    results.to_csv(f'{V3_PATH}/sensitivity_analysis.csv')

    # Plot the results
    _plot_metric(results, 'ECE', 'ECE', f'{V3_PATH}/sensitivity_analysis.png')
    _plot_metric(results, 'accuracy', 'Accuracy', f'{V3_PATH}/sensitivity_analysis_accuracy.png')


if __name__ == "__main__":
    main()
 |