"""Sensitivity analysis to determine the best number of models to use in the ensemble."""

import os
import pickle as pk

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

import threshold_refac as th
import utils.models.cnn
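
# Allow-list the project's CNN class for torch's weights-only unpickler. On PyTorch >= 2.6,
# torch.load defaults to weights_only=True and rejects arbitrary pickled classes unless they
# are registered here (assumption: th.load_models_v2 loads full pickled model objects and
# relies on that default).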
torch.serialization.add_safe_globals([utils.models.cnn.CNN])


def main():
    config = th.load_config()

    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
    V3_PATH = ENSEMBLE_PATH + "/v3"

    # Create the output directory if it does not exist
    os.makedirs(V3_PATH, exist_ok=True)

    # Load the models
    device = torch.device(config["training"]["device"])
    models = th.load_models_v2(f"{ENSEMBLE_PATH}/models/", device)

    # Load the dataset
    dataset = torch.load(
        f"{ENSEMBLE_PATH}/test_dataset.pt", weights_only=False
    ) + torch.load(f"{ENSEMBLE_PATH}/val_dataset.pt", weights_only=False)
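    # Adding two torch Datasets with `+` yields a ConcatDataset, so the analysis below
    # runs over the combined test + validation samples.

    # config["ensemble"]["run_models"] toggles between recomputing the ensemble
    # predictions and reusing the pickled cache from an earlier run.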
    if config["ensemble"]["run_models"]:
        # Get the predictions of the ensemble
        ensemble_predictions = th.ensemble_dataset_predictions(models, dataset, device)

        # Save the predictions to file using pickle
        with open(f"{V3_PATH}/ensemble_predictions.pk", "wb") as f:
            pk.dump(ensemble_predictions, f)
    else:
        # Load the predictions from file
        with open(f"{V3_PATH}/ensemble_predictions.pk", "rb") as f:
            ensemble_predictions = pk.load(f)
    # With the predictions in hand, run the sensitivity analysis: compute the statistics
    # for each possible number of models in the ensemble and collect ECE and accuracy in
    # a dataframe indexed by the number of models.
    results = pd.DataFrame(columns=["num_models", "ECE", "accuracy"]).set_index(
        "num_models"
    )
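
    # Each candidate sub-ensemble is the first i models (indices 0..i-1), so this sweeps
    # prefix ensembles of size 2..N rather than every possible subset (assumption:
    # th.select_subset_models takes an iterable of model indices).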
    for i in range(2, len(models) + 1):
        sel_preds = th.select_subset_models(ensemble_predictions, range(i))
        sel_stats = th.calculate_statistics(sel_preds)
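        # Rescale confidence with (x / 2) + 0.5, which maps a signed score in [-1, 1]
        # onto [0, 1] so it can be treated as a probability for ECE (assumption: the
        # "confidence" column produced by calculate_statistics lives in [-1, 1]).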
        raw_confidence = sel_stats["confidence"].apply(lambda x: (x / 2) + 0.5)
        sel_stats.insert(4, "raw_confidence", raw_confidence)

        stats = th.calculate_overall_statistics(sel_stats)
        ece = stats.at["raw_confidence", "ECE"]
        accuracy = sel_stats["correct"].mean()
        results.loc[i] = (ece, accuracy)
    # Save the results to a file
    results.to_csv(f"{V3_PATH}/sensitivity_analysis.csv")

    # Plot ECE against ensemble size
    plt.plot(results.index, results["ECE"])
    plt.xlabel("Number of Models")
    plt.ylabel("ECE")
    plt.title("Sensitivity Analysis")
    plt.savefig(f"{V3_PATH}/sensitivity_analysis.png")
    plt.close()

    # Plot accuracy against ensemble size
    plt.plot(results.index, results["accuracy"])
    plt.xlabel("Number of Models")
    plt.ylabel("Accuracy")
    plt.title("Sensitivity Analysis")
    plt.savefig(f"{V3_PATH}/sensitivity_analysis_accuracy.png")
    plt.close()


if __name__ == "__main__":
    main()