12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485 |
- import numpy as np
- import pandas as pd
- import matplotlib.pyplot as plt
- import torch
- import os
- import threshold_refac as th
- import pickle as pk
def _save_line_plot(x, y, ylabel, path):
    """Plot ``y`` against ``x`` as a line chart and save it to ``path``.

    The x-axis label and the title are the fixed ones used by the
    sensitivity analysis. The current pyplot figure is closed after saving
    so that the next plot starts from a clean figure.
    """
    plt.plot(x, y)
    plt.xlabel('Number of Models')
    plt.ylabel(ylabel)
    plt.title('Sensitivity Analysis')
    plt.savefig(path)
    plt.close()


def main():
    """Run the ensemble-size sensitivity analysis and write its artifacts.

    Steps:
      1. Load the configuration and the trained models of the ensemble.
      2. Either compute the ensemble predictions on the concatenated
         test + validation datasets and cache them as a pickle, or read the
         previously cached pickle (controlled by
         ``config['ensemble']['run_models']``).
      3. For every ensemble size ``i`` from 2 up to the number of loaded
         models, compute the ECE of the rescaled confidence and the mean of
         the ``correct`` column.
      4. Write the table to ``sensitivity_analysis.csv`` and save one line
         plot per metric, all under ``<ensemble dir>/v3``.
    """
    config = th.load_config()

    ensemble_path = f"{config['paths']['model_output']}{config['ensemble']['name']}"
    v3_path = ensemble_path + '/v3'

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(v3_path, exist_ok=True)

    device = torch.device(config['training']['device'])
    # The models are needed in both branches: their count bounds the loop below.
    models = th.load_models_v2(f'{ensemble_path}/models/', device)

    predictions_file = f'{v3_path}/ensemble_predictions.pk'
    if config['ensemble']['run_models']:
        # Only load the datasets when predictions are actually recomputed;
        # the cached branch never touches them.
        dataset = torch.load(f'{ensemble_path}/test_dataset.pt') + torch.load(
            f'{ensemble_path}/val_dataset.pt'
        )
        ensemble_predictions = th.ensemble_dataset_predictions(models, dataset, device)

        with open(predictions_file, 'wb') as f:
            pk.dump(ensemble_predictions, f)
    else:
        # NOTE: pickle.load can execute arbitrary code. This is acceptable only
        # because the file is a local cache written by the branch above; never
        # point this at a file from an untrusted source.
        # NOTE(review): the cache is assumed to have been produced with the
        # same set of models that is loaded now - confirm before reusing it.
        with open(predictions_file, 'rb') as f:
            ensemble_predictions = pk.load(f)

    results = pd.DataFrame(columns=['num_models', 'ECE', 'accuracy']).set_index('num_models')
    for num_models in range(2, len(models) + 1):
        # range(num_models) is passed as the model selection; presumably these
        # are the indices of the first `num_models` models - verify against
        # th.select_subset_models.
        sel_preds = th.select_subset_models(ensemble_predictions, range(num_models))
        sel_stats = th.calculate_statistics(sel_preds)

        # Affine rescale x -> x / 2 + 0.5, done on the whole column at once.
        # The meaning/range of 'confidence' is defined in threshold_refac.
        raw_confidence = sel_stats['confidence'] / 2 + 0.5
        sel_stats.insert(4, 'raw_confidence', raw_confidence)

        stats = th.calculate_overall_statistics(sel_stats)
        ece = stats.at['raw_confidence', 'ECE']
        accuracy = sel_stats['correct'].mean()
        results.loc[num_models] = (ece, accuracy)

    results.to_csv(f'{v3_path}/sensitivity_analysis.csv')

    _save_line_plot(
        results.index, results['ECE'], 'ECE',
        f'{v3_path}/sensitivity_analysis.png',
    )
    _save_line_plot(
        results.index, results['accuracy'], 'Accuracy',
        f'{v3_path}/sensitivity_analysis_accuracy.png',
    )
# Script entry point: run the analysis only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|