import glob
import os
import pickle as pk
import warnings

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import sklearn.metrics as metrics
import tomli as toml
import torch
from tqdm import tqdm

import utils.metrics as met

# Escalate warnings to errors so silent numerical issues surface immediately
warnings.filterwarnings('error')

# CONFIGURATION
# Load the TOML config from ADL_CONFIG_PATH if set, otherwise from ./config.toml
config_path = os.getenv('ADL_CONFIG_PATH', 'config.toml')
with open(config_path, 'rb') as f:
    config = toml.load(f)

ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
V3_PATH = ENSEMBLE_PATH + '/v3'

# Create the output directory if it does not exist
os.makedirs(V3_PATH, exist_ok=True)


# models is a dictionary with the model ids as keys and the models as values
def get_model_predictions(models, data):
    predictions = {}
    for model_id, model in models.items():
        model.eval()
        with torch.no_grad():
            # Get the predictions
            output = model(data)
            predictions[model_id] = output.detach().cpu().numpy()

    return predictions


def load_models_v2(folder, device):
    glob_path = os.path.join(folder, '*.pt')
    model_files = glob.glob(glob_path)
    model_dict = {}

    for model_file in model_files:
        model = torch.load(model_file, map_location=device)
        # The model id is the filename prefix before the first underscore
        model_id = os.path.basename(model_file).split('_')[0]
        model_dict[model_id] = model

    if len(model_dict) == 0:
        raise FileNotFoundError(
            'No models found in the specified directory: ' + folder
        )

    return model_dict


# Ensures that both the mri and xls tensors gain a batch dimension and are on the correct device
def preprocess_data(data, device):
    mri, xls = data
    mri = mri.unsqueeze(0).to(device)
    xls = xls.unsqueeze(0).to(device)
    return (mri, xls)


def ensemble_dataset_predictions(models, dataset, device):
    # For each datapoint, get the predictions of each model
    predictions = {}
    for i, (data, target) in tqdm(enumerate(dataset), total=len(dataset)):
        # Preprocess data
        data = preprocess_data(data, device)
        # predictions is a dictionary of tuples keyed by image id, with the
        # target as the first element and the per-model predictions dictionary
        # as the second
        predictions[i] = (
            target.detach().cpu().numpy(),
            get_model_predictions(models, data),
        )

    return predictions


# Given a dictionary of predictions, select one model and discard the rest
def select_individual_model(predictions, model_id):
    selected_model_predictions = {}
    for key, value in predictions.items():
        selected_model_predictions[key] = (
            value[0],
            {model_id: value[1][str(model_id)]},
        )
    return selected_model_predictions


# Given a dictionary of predictions, select a subset of models and discard the rest
def select_subset_models(predictions, model_ids):
    selected_model_predictions = {}
    for key, value in predictions.items():
        selected_model_predictions[key] = (
            value[0],
            {model_id: value[1][model_id] for model_id in model_ids},
        )
    return selected_model_predictions


# Given a dictionary of predictions, calculate statistics (mean, stdev, entropy,
# confidence, correctness) for each datapoint
# Returns a DataFrame of the form {data_id: (mean, stdev, entropy, confidence, correct, predicted, actual)}
def calculate_statistics(predictions):
    # Create a DataFrame with a column for each statistic
    stats_df = pd.DataFrame(
        columns=[
            'mean',
            'stdev',
            'entropy',
            'confidence',
            'correct',
            'predicted',
            'actual',
        ]
    )

    # Loop through each prediction
    for key, value in predictions.items():
        target = value[0]
        model_predictions = list(value[1].values())
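        # Added sketch (assumption): the aggregation below treats each model
        # output as a softmax vector of shape (1, 2) for a binary task, so the
        # positive-class index [1] used for the stdev is well defined; a
        # lightweight guard under that assumption:
        assert all(
            p.shape == model_predictions[0].shape for p in model_predictions
        ), 'model outputs must share a common shape for mean/std aggregation'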
        # Calculate the mean and stdev of the predictions across models;
        # [1] keeps the positive-class standard deviation
        mean = np.squeeze(np.mean(model_predictions, axis=0))
        stdev = np.squeeze(np.std(model_predictions, axis=0))[1]

        # Calculate the entropy of the mean prediction
        entropy = met.entropy(mean)

        # Confidence: distance of the top probability from chance, scaled to [0, 1]
        confidence = (np.max(mean) - 0.5) * 2

        # Calculate predicted and actual class labels
        predicted = np.argmax(mean)
        actual = np.argmax(target)

        # Determine whether the prediction is correct
        correct = predicted == actual

        # Add the statistics to the dataframe
        stats_df.loc[key] = [
            mean,
            stdev,
            entropy,
            confidence,
            correct,
            predicted,
            actual,
        ]

    return stats_df


# Takes a DataFrame of the form {data_id: statistic, ...} and sweeps thresholds over the given statistic
# Output of the form DataFrame(index=threshold, columns=[accuracy, f1])
def conduct_threshold_analysis(statistics, statistic_name, low_to_high=True):
    # Series of quantile values for the statistic, indexed by quantile level
    percentile_df = statistics[statistic_name].quantile(
        q=np.linspace(0.05, 0.95, num=18)
    )

    # DataFrame of the form {threshold: {metric: value}}
    thresholds_pd = pd.DataFrame(index=percentile_df.index, columns=['accuracy', 'f1'])

    for percentile, value in percentile_df.items():
        # Keep only the datapoints on the confident side of the threshold
        if low_to_high:
            filtered_statistics = statistics[statistics[statistic_name] < value]
        else:
            filtered_statistics = statistics[statistics[statistic_name] >= value]

        # Calculate accuracy on the retained subset
        accuracy = filtered_statistics['correct'].mean()

        # Calculate the F1 score on the retained subset
        predicted = filtered_statistics['predicted'].values
        actual = filtered_statistics['actual'].values
        f1 = metrics.f1_score(actual, predicted)

        # Add the metrics to the dataframe
        thresholds_pd.loc[percentile] = [accuracy, f1]

    return thresholds_pd

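
# Added sketch (not part of the original pipeline): one way to consume the table
# produced by conduct_threshold_analysis is to pick the quantile level whose
# retained subset maximizes F1; best_threshold_by_f1 is a hypothetical helper.
def best_threshold_by_f1(thresholds_pd):
    # Values are stored as objects, so cast to float before taking the argmax
    return thresholds_pd['f1'].astype(float).idxmax()
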

# Takes a Series of the form {threshold: metric} for a given statistic and plots
# the metric against the threshold.
# Can plot an additional line if given (used for individual-model results)
def plot_threshold_analysis(
    thresholds_metric, title, x_label, y_label, path, additional_set=None, flip=False
):
    # Initialize the plot
    fig, ax = plt.subplots()

    # Get the thresholds and metrics
    thresholds = list(thresholds_metric.index)
    metric = list(thresholds_metric.values)

    # Plot the metric against the threshold
    plt.plot(thresholds, metric, 'bo-', label='Ensemble')

    if additional_set is not None:
        # Get the thresholds and metrics
        thresholds = list(additional_set.index)
        metric = list(additional_set.values)

        # Plot the metric against the threshold
        plt.plot(thresholds, metric, 'rx-', label='Individual')

    if flip:
        ax.invert_xaxis()

    # Add labels
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend()
    ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))

    plt.savefig(path)
    plt.close()


# Code from https://stackoverflow.com/questions/16458340
# Yields the keys common to all dictionaries, paired with each dictionary's value
def common_entries(*dcts):
    if not dcts:
        return
    for i in set(dcts[0]).intersection(*dcts[1:]):
        yield (i,) + tuple(d[i] for d in dcts)


def main():
    # Load the models
    device = torch.device(config['training']['device'])
    models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)

    # Load the dataset (test and validation splits are concatenated)
    dataset = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
        f'{ENSEMBLE_PATH}/val_dataset.pt'
    )

    if config['ensemble']['run_models']:
        # Get the predictions of the ensemble
        ensemble_predictions = ensemble_dataset_predictions(models, dataset, device)

        # Save to file using pickle
        with open(f'{V3_PATH}/ensemble_predictions.pk', 'wb') as f:
            pk.dump(ensemble_predictions, f)
    else:
        # Load the predictions from file
        with open(f'{V3_PATH}/ensemble_predictions.pk', 'rb') as f:
            ensemble_predictions = pk.load(f)

    # Get the statistics and thresholds of the ensemble
    ensemble_statistics = calculate_statistics(ensemble_predictions)

    stdev_thresholds = conduct_threshold_analysis(
        ensemble_statistics, 'stdev', low_to_high=True
    )
    entropy_thresholds = conduct_threshold_analysis(
        ensemble_statistics, 'entropy', low_to_high=True
    )
    confidence_thresholds = conduct_threshold_analysis(
        ensemble_statistics, 'confidence', low_to_high=False
    )

    # Print ECE and MCE values for each uncertainty measure
    conf_ece = met.ECE(
        ensemble_statistics['predicted'],
        ensemble_statistics['confidence'],
        ensemble_statistics['actual'],
    )
    conf_mce = met.MCE(
        ensemble_statistics['predicted'],
        ensemble_statistics['confidence'],
        ensemble_statistics['actual'],
    )
    ent_ece = met.ECE(
        ensemble_statistics['predicted'],
        ensemble_statistics['entropy'],
        ensemble_statistics['actual'],
    )
    ent_mce = met.MCE(
        ensemble_statistics['predicted'],
        ensemble_statistics['entropy'],
        ensemble_statistics['actual'],
    )
    stdev_ece = met.ECE(
        ensemble_statistics['predicted'],
        ensemble_statistics['stdev'],
        ensemble_statistics['actual'],
    )
    stdev_mce = met.MCE(
        ensemble_statistics['predicted'],
        ensemble_statistics['stdev'],
        ensemble_statistics['actual'],
    )

    print(f'Confidence ECE: {conf_ece}, Confidence MCE: {conf_mce}')
    print(f'Entropy ECE: {ent_ece}, Entropy MCE: {ent_mce}')
    print(f'Stdev ECE: {stdev_ece}, Stdev MCE: {stdev_mce}')

    # Print overall ensemble statistics
    print('Ensemble Statistics')
    print(f"Accuracy: {ensemble_statistics['correct'].mean()}")
    print(
        f"F1 Score: {metrics.f1_score(ensemble_statistics['actual'], ensemble_statistics['predicted'])}"
    )

    # Get the predictions, statistics and thresholds of an individual model
    indv_id = config['ensemble']['individual_id']
    indv_predictions = select_individual_model(ensemble_predictions, indv_id)
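    # Optional sketch (assumption): select_subset_models, defined above, could be
    # exercised the same way to compare a sub-ensemble; the model ids here are
    # hypothetical:
    # subset_predictions = select_subset_models(ensemble_predictions, ['1', '2'])
    # subset_statistics = calculate_statistics(subset_predictions)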
    indv_statistics = calculate_statistics(indv_predictions)

    # Calculate entropy and confidence thresholds for the individual model
    indv_entropy_thresholds = conduct_threshold_analysis(
        indv_statistics, 'entropy', low_to_high=True
    )
    indv_confidence_thresholds = conduct_threshold_analysis(
        indv_statistics, 'confidence', low_to_high=False
    )

    # Plot the threshold analysis for standard deviation
    plot_threshold_analysis(
        stdev_thresholds['accuracy'],
        'Stdev Threshold Analysis for Accuracy',
        'Stdev Threshold',
        'Accuracy',
        f'{V3_PATH}/stdev_threshold_analysis.png',
        flip=True,
    )
    plot_threshold_analysis(
        stdev_thresholds['f1'],
        'Stdev Threshold Analysis for F1 Score',
        'Stdev Threshold',
        'F1 Score',
        f'{V3_PATH}/stdev_threshold_analysis_f1.png',
        flip=True,
    )

    # Plot the threshold analysis for entropy
    plot_threshold_analysis(
        entropy_thresholds['accuracy'],
        'Entropy Threshold Analysis for Accuracy',
        'Entropy Threshold',
        'Accuracy',
        f'{V3_PATH}/entropy_threshold_analysis.png',
        indv_entropy_thresholds['accuracy'],
        flip=True,
    )
    plot_threshold_analysis(
        entropy_thresholds['f1'],
        'Entropy Threshold Analysis for F1 Score',
        'Entropy Threshold',
        'F1 Score',
        f'{V3_PATH}/entropy_threshold_analysis_f1.png',
        indv_entropy_thresholds['f1'],
        flip=True,
    )

    # Plot the threshold analysis for confidence
    plot_threshold_analysis(
        confidence_thresholds['accuracy'],
        'Confidence Threshold Analysis for Accuracy',
        'Confidence Threshold',
        'Accuracy',
        f'{V3_PATH}/confidence_threshold_analysis.png',
        indv_confidence_thresholds['accuracy'],
    )
    plot_threshold_analysis(
        confidence_thresholds['f1'],
        'Confidence Threshold Analysis for F1 Score',
        'Confidence Threshold',
        'F1 Score',
        f'{V3_PATH}/confidence_threshold_analysis_f1.png',
        indv_confidence_thresholds['f1'],
    )


if __name__ == '__main__':
    main()
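# Usage (sketch): run this script directly; ADL_CONFIG_PATH may point at an
# alternate TOML config and falls back to ./config.toml when unset:
#   ADL_CONFIG_PATH=/path/to/config.toml python <this_script>.py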