# Rewritten program to use xarray instead of pandas for thresholding

import xarray as xr
import torch
import numpy as np
import os
import glob
import tomli as toml
from tqdm import tqdm
import utils.metrics as met
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick

# The data structures for this file are as follows
# models_dict: Dictionary - {model_id: model}
# predictions: DataArray - (data_id, model_id, prediction_value) - prediction_value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
# ensemble_statistics: DataArray - (data_id, statistic) - statistic has coords ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
# thresholded_predictions: DataArray - (quantile, statistic, metric) - metric has coords ['accuracy', 'f1'] - only use 'stdev', 'entropy', 'confidence' for statistic

# Additionally, we also have the thresholds and statistics for the individual models
# indv_statistics: DataArray - (data_id, model_id, statistic) - statistic has coords ['mean', 'entropy', 'confidence', 'correct', 'predicted', 'actual'] - no stdev, as it cannot be calculated for a single model
# indv_thresholds: DataArray - (model_id, quantile, statistic, metric) - metric has coords ['accuracy', 'f1'] - only use 'entropy', 'confidence' for statistic

# Additionally, we have some for the sensitivity analysis on the number of models
# sensitivity_statistics: DataArray - (data_id, model_count, statistic) - statistic has coords ['accuracy', 'f1', 'ECE', 'MCE']


# Loads the configuration dictionary
def load_config():
    if os.getenv('ADL_CONFIG_PATH') is None:
        with open('config.toml', 'rb') as f:
            config = toml.load(f)
    else:
        with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
            config = toml.load(f)

    return config


# Loads models into a dictionary keyed by model id
def load_models_v2(folder, device):
    glob_path = os.path.join(folder, '*.pt')
    model_files = glob.glob(glob_path)
    model_dict = {}

    for model_file in model_files:
        model = torch.load(model_file, map_location=device)
        model_id = os.path.basename(model_file).split('_')[0]
        model_dict[model_id] = model

    if len(model_dict) == 0:
        raise FileNotFoundError('No models found in the specified directory: ' + folder)

    return model_dict


# Ensures that both the mri and xls tensors in the data are unsqueezed and on the correct device
def preprocess_data(data, device):
    mri, xls = data
    mri = mri.unsqueeze(0).to(device)
    xls = xls.unsqueeze(0).to(device)
    return (mri, xls)


# Loads the datasets and returns the concatenated test and validation datasets
def load_datasets(ensemble_path):
    return torch.load(f'{ensemble_path}/test_dataset.pt') + torch.load(
        f'{ensemble_path}/val_dataset.pt'
    )


# Gets the predictions for a set of models on a dataset
def get_ensemble_predictions(models, dataset, device):
    zeros = np.zeros((len(dataset), len(models), 4))
    predictions = xr.DataArray(
        zeros,
        dims=('data_id', 'model_id', 'prediction_value'),
        coords={
            'data_id': range(len(dataset)),
            'model_id': list(models.keys()),
            'prediction_value': [
                'negative_prediction',
                'positive_prediction',
                'negative_actual',
                'positive_actual',
            ],
        },
    )

    for data_id, (data, target) in tqdm(
        enumerate(dataset), total=len(dataset), unit='images'
    ):
        dat = preprocess_data(data, device)
        actual = list(target.cpu().numpy())
        for model_id, model in models.items():
            with torch.no_grad():
                output = model(dat)
                prediction = output.cpu().numpy().tolist()[0]

                predictions.loc[{'data_id': data_id, 'model_id': model_id}] = (
                    prediction + actual
                )

    return predictions
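

# Minimal sketch (comments only, not executed by the pipeline) of how the
# predictions DataArray is indexed; the model id 'm0' is hypothetical:
#
#   preds = get_ensemble_predictions(models, dataset, device)
#   pos = preds.sel(model_id='m0', prediction_value='positive_prediction')
#   pos.values  # 1-D array of positive-class probabilities, one per data_id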


# Compute the ensemble statistics given an array of predictions
def compute_ensemble_statistics(predictions: xr.DataArray):
    zeros = np.zeros((len(predictions.data_id), 7))

    ensemble_statistics = xr.DataArray(
        zeros,
        dims=('data_id', 'statistic'),
        coords={
            'data_id': predictions.data_id,
            'statistic': [
                'mean',
                'stdev',
                'entropy',
                'confidence',
                'correct',
                'predicted',
                'actual',
            ],
        },
    )

    for data_id in predictions.data_id:
        data = predictions.loc[{'data_id': data_id}]
        mean = data.mean(dim='model_id')[
            0:2
        ]  # Only take the predictions, not the actual values
        stdev = data.std(dim='model_id')[
            1
        ]  # Only need the standard deviation of the positive prediction
        entropy = (-mean * np.log(mean)).sum()

        # Compute confidence
        confidence = mean.max()

        # Only need one of the actual values, since they are all the same; just get the first positive_actual
        actual = data.loc[{'prediction_value': 'positive_actual'}][0]
        predicted = mean.argmax()
        correct = actual == predicted

        ensemble_statistics.loc[{'data_id': data_id}] = [
            mean[1],
            stdev,
            entropy,
            confidence,
            correct,
            predicted,
            actual,
        ]

    return ensemble_statistics


# Compute the thresholded predictions given an array of ensemble statistics
def compute_thresholded_predictions(input_stats: xr.DataArray):
    quantiles = np.linspace(0.05, 0.95, 19) * 100
    metrics = ['accuracy', 'f1']
    statistics = ['stdev', 'entropy', 'confidence']
    zeros = np.zeros((len(quantiles), len(statistics), len(metrics)))

    thresholded_predictions = xr.DataArray(
        zeros,
        dims=('quantile', 'statistic', 'metric'),
        coords={'quantile': quantiles, 'statistic': statistics, 'metric': metrics},
    )

    for statistic in statistics:
        # First, compute the quantile cutoffs for the statistic
        quantile_values = np.percentile(
            input_stats.sel(statistic=statistic).values, quantiles, axis=0
        )

        # Then, compute the metrics for each quantile
        for i, quantile in enumerate(quantiles):
            if low_to_high(statistic):
                mask = (
                    input_stats.sel(statistic=statistic) >= quantile_values[i]
                ).values
            else:
                mask = (
                    input_stats.sel(statistic=statistic) <= quantile_values[i]
                ).values

            # Filter the data based on the mask, selecting by data_id label so
            # the filter stays correct even after data_ids have been pruned
            filtered_data = input_stats.where(
                input_stats.data_id.isin(input_stats.data_id.values[mask]), drop=True
            )

            for metric in metrics:
                thresholded_predictions.loc[
                    {'quantile': quantile, 'statistic': statistic, 'metric': metric}
                ] = compute_metric(filtered_data, metric)

    return thresholded_predictions


# Truth function to determine whether a statistic should be thresholded low-to-high or high-to-low
# Low confidence is bad; high entropy and high stdev are bad
# So we threshold confidence low-to-high, and entropy and stdev high-to-low:
# any values BELOW the cutoff are removed for confidence, and any values ABOVE the cutoff are removed for entropy and stdev
def low_to_high(stat):
    return stat in ['confidence']


# Compute a given metric on a DataArray of statistics
def compute_metric(arr, metric):
    if metric == 'accuracy':
        return np.mean(arr.loc[{'statistic': 'correct'}])
    elif metric == 'f1':
        return met.F1(
            arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}]
        )
    else:
        raise ValueError('Invalid metric: ' + metric)
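

# Toy illustration (comments only, not part of the pipeline) of one
# thresholding step on four hypothetical confidence values:
#
#   conf = np.array([0.55, 0.70, 0.90, 0.99])
#   cutoff = np.percentile(conf, 50)  # -> 0.80 (linear interpolation)
#   kept = conf[conf >= cutoff]       # low_to_high('confidence') is True, so
#                                     # low-confidence samples are dropped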


# Graph a thresholded prediction for a given statistic and metric
def graph_thresholded_prediction(
    thresholded_predictions, statistic, metric, save_path, title, xlabel, ylabel
):
    data = thresholded_predictions.sel(statistic=statistic, metric=metric)

    x_data = data.coords['quantile'].values
    y_data = data.values

    fig, ax = plt.subplots()
    ax.plot(x_data, y_data, 'bx-', label='Ensemble')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.xaxis.set_major_formatter(mtick.PercentFormatter())

    # For high-to-low statistics, flip the x axis so "more data removed" reads left to right
    if not low_to_high(statistic):
        ax.invert_xaxis()

    plt.savefig(save_path)


# Graph all thresholded predictions
def graph_all_thresholded_predictions(thresholded_predictions, save_path):
    # Confidence vs. Accuracy
    graph_thresholded_prediction(
        thresholded_predictions,
        'confidence',
        'accuracy',
        f'{save_path}/confidence_accuracy.png',
        'Confidence vs. Accuracy',
        'Confidence',
        'Accuracy',
    )

    # Confidence vs. F1
    graph_thresholded_prediction(
        thresholded_predictions,
        'confidence',
        'f1',
        f'{save_path}/confidence_f1.png',
        'Confidence vs. F1',
        'Confidence',
        'F1',
    )

    # Entropy vs. Accuracy
    graph_thresholded_prediction(
        thresholded_predictions,
        'entropy',
        'accuracy',
        f'{save_path}/entropy_accuracy.png',
        'Entropy vs. Accuracy',
        'Entropy',
        'Accuracy',
    )

    # Entropy vs. F1
    graph_thresholded_prediction(
        thresholded_predictions,
        'entropy',
        'f1',
        f'{save_path}/entropy_f1.png',
        'Entropy vs. F1',
        'Entropy',
        'F1',
    )

    # Stdev vs. Accuracy
    graph_thresholded_prediction(
        thresholded_predictions,
        'stdev',
        'accuracy',
        f'{save_path}/stdev_accuracy.png',
        'Standard Deviation vs. Accuracy',
        'Standard Deviation',
        'Accuracy',
    )

    # Stdev vs. F1
    graph_thresholded_prediction(
        thresholded_predictions,
        'stdev',
        'f1',
        f'{save_path}/stdev_f1.png',
        'Standard Deviation vs. F1',
        'Standard Deviation',
        'F1',
    )


# Graph two statistics against each other, split by correctness
def graph_statistics(stats, x_stat, y_stat, save_path, title, xlabel, ylabel):
    # Filter for correct predictions (select by data_id label)
    correct = (stats.sel(statistic='correct') == 1).values
    c_stats = stats.where(
        stats.data_id.isin(stats.data_id.values[correct]), drop=True
    )

    # Filter for incorrect predictions
    incorrect = (stats.sel(statistic='correct') == 0).values
    i_stats = stats.where(
        stats.data_id.isin(stats.data_id.values[incorrect]), drop=True
    )

    # x and y data for correct and incorrect predictions
    x_data_c = c_stats.sel(statistic=x_stat).values
    y_data_c = c_stats.sel(statistic=y_stat).values
    x_data_i = i_stats.sel(statistic=x_stat).values
    y_data_i = i_stats.sel(statistic=y_stat).values

    fig, ax = plt.subplots()
    ax.plot(x_data_c, y_data_c, 'go', label='Correct')
    ax.plot(x_data_i, y_data_i, 'ro', label='Incorrect')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend()

    plt.savefig(save_path)


# Prune the data based on excluded data_ids
def prune_data(data, excluded_data_ids):
    return data.where(~data.data_id.isin(excluded_data_ids), drop=True)


# Calculate individual model statistics
def compute_individual_statistics(predictions: xr.DataArray):
    zeros = np.zeros((len(predictions.data_id), len(predictions.model_id), 6))

    indv_statistics = xr.DataArray(
        zeros,
        dims=('data_id', 'model_id', 'statistic'),
        coords={
            'data_id': predictions.data_id,
            'model_id': predictions.model_id,
            'statistic': [
                'mean',
                'entropy',
                'confidence',
                'correct',
                'predicted',
                'actual',
            ],
        },
    )

    for data_id in predictions.data_id:
        for model_id in predictions.model_id:
            data = predictions.loc[{'data_id': data_id, 'model_id': model_id}]
            mean = data[0:2]
            entropy = (-mean * np.log(mean)).sum()
            confidence = mean.max()
            actual = data[3]
            predicted = mean.argmax()
            correct = actual == predicted

            indv_statistics.loc[{'data_id': data_id, 'model_id': model_id}] = [
                mean[1],
                entropy,
                confidence,
                correct,
                predicted,
                actual,
            ]

    return indv_statistics
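

# Worked toy example (comments only) of the entropy term used above, for a
# hypothetical two-class prediction vector:
#
#   p = np.array([0.25, 0.75])
#   H = -(p * np.log(p)).sum()  # ~0.562 nats; 0 when one class has all the mass
#
# Higher entropy means the prediction is closer to uniform, i.e. less certain.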


# Compute individual model thresholds
def compute_individual_thresholds(input_stats: xr.DataArray):
    quantiles = np.linspace(0.05, 0.95, 19) * 100
    metrics = ['accuracy', 'f1']
    statistics = ['entropy', 'confidence']

    zeros = np.zeros(
        (len(input_stats.model_id), len(quantiles), len(statistics), len(metrics))
    )

    indv_thresholds = xr.DataArray(
        zeros,
        dims=('model_id', 'quantile', 'statistic', 'metric'),
        coords={
            'model_id': input_stats.model_id,
            'quantile': quantiles,
            'statistic': statistics,
            'metric': metrics,
        },
    )

    for model_id in input_stats.model_id:
        for statistic in statistics:
            # First, compute the quantile cutoffs for the statistic
            quantile_values = np.percentile(
                input_stats.sel(model_id=model_id, statistic=statistic).values,
                quantiles,
                axis=0,
            )

            # Then, compute the metrics for each quantile
            for i, quantile in enumerate(quantiles):
                if low_to_high(statistic):
                    mask = (
                        input_stats.sel(model_id=model_id, statistic=statistic)
                        >= quantile_values[i]
                    ).values
                else:
                    mask = (
                        input_stats.sel(model_id=model_id, statistic=statistic)
                        <= quantile_values[i]
                    ).values

                # Filter the data based on the mask, selecting by data_id label
                filtered_data = input_stats.where(
                    input_stats.data_id.isin(input_stats.data_id.values[mask]),
                    drop=True,
                )

                for metric in metrics:
                    indv_thresholds.loc[
                        {
                            'model_id': model_id,
                            'quantile': quantile,
                            'statistic': statistic,
                            'metric': metric,
                        }
                    ] = compute_metric(filtered_data, metric)

    return indv_thresholds


# Graph the individual model thresholded predictions alongside the ensemble
def graph_individual_thresholded_predictions(
    indv_thresholds,
    ensemble_thresholds,
    statistic,
    metric,
    save_path,
    title,
    xlabel,
    ylabel,
):
    data = indv_thresholds.sel(statistic=statistic, metric=metric)
    e_data = ensemble_thresholds.sel(statistic=statistic, metric=metric)

    x_data = data.coords['quantile'].values
    y_data = data.values

    e_x_data = e_data.coords['quantile'].values
    e_y_data = e_data.values

    fig, ax = plt.subplots()
    for model_id in data.coords['model_id'].values:
        model_data = data.sel(model_id=model_id)
        ax.plot(x_data, model_data)

    ax.plot(e_x_data, e_y_data, 'kx-', label='Ensemble')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.xaxis.set_major_formatter(mtick.PercentFormatter())

    if not low_to_high(statistic):
        ax.invert_xaxis()

    ax.legend()
    plt.savefig(save_path)


# Graph all individual thresholded predictions
def graph_all_individual_thresholded_predictions(
    indv_thresholds, ensemble_thresholds, save_path
):
    # Confidence vs. Accuracy
    graph_individual_thresholded_predictions(
        indv_thresholds,
        ensemble_thresholds,
        'confidence',
        'accuracy',
        f'{save_path}/indv/confidence_accuracy.png',
        'Confidence vs. Accuracy',
        'Confidence Percentile Threshold',
        'Accuracy',
    )

    # Confidence vs. F1
    graph_individual_thresholded_predictions(
        indv_thresholds,
        ensemble_thresholds,
        'confidence',
        'f1',
        f'{save_path}/indv/confidence_f1.png',
        'Confidence vs. F1',
        'Confidence Percentile Threshold',
        'F1',
    )

    # Entropy vs. Accuracy
    graph_individual_thresholded_predictions(
        indv_thresholds,
        ensemble_thresholds,
        'entropy',
        'accuracy',
        f'{save_path}/indv/entropy_accuracy.png',
        'Entropy vs. Accuracy',
        'Entropy Percentile Threshold',
        'Accuracy',
    )

    # Entropy vs. F1
    graph_individual_thresholded_predictions(
        indv_thresholds,
        ensemble_thresholds,
        'entropy',
        'f1',
        f'{save_path}/indv/entropy_f1.png',
        'Entropy vs. F1',
        'Entropy Percentile Threshold',
        'F1',
    )
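

# Hypothetical standalone call (the variables mirror main() below, which
# creates the required '<save_path>/indv' directory before calling this):
#
#   graph_all_individual_thresholded_predictions(
#       indv_thresholds, thresholded_predictions, V4_PATH
#   )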


# Calculate statistics of subsets of models for sensitivity analysis
def calculate_subset_statistics(predictions: xr.DataArray):
    # Calculate subsets for 1..N models
    subsets = range(1, len(predictions.model_id) + 1)
    zeros = np.zeros(
        (len(predictions.data_id), len(subsets), 7)
    )  # Include stdev, even though it is degenerate for a single model

    subset_stats = xr.DataArray(
        zeros,
        dims=('data_id', 'model_count', 'statistic'),
        coords={
            'data_id': predictions.data_id,
            'model_count': subsets,
            'statistic': [
                'mean',
                'stdev',
                'entropy',
                'confidence',
                'correct',
                'predicted',
                'actual',
            ],
        },
    )

    for data_id in predictions.data_id:
        for subset in subsets:
            data = predictions.sel(
                data_id=data_id, model_id=predictions.model_id[:subset]
            )
            mean = data.mean(dim='model_id')[0:2]
            stdev = data.std(dim='model_id')[1]
            entropy = (-mean * np.log(mean)).sum()
            confidence = mean.max()
            # The actual values are identical across models; take the first positive_actual
            actual = data.loc[{'prediction_value': 'positive_actual'}][0]
            predicted = mean.argmax()
            correct = actual == predicted

            subset_stats.loc[{'data_id': data_id, 'model_count': subset}] = [
                mean[1],
                stdev,
                entropy,
                confidence,
                correct,
                predicted,
                actual,
            ]

    return subset_stats


# Calculate accuracy, F1, ECE and MCE for the subset statistics - sensitivity analysis
def calculate_sensitivity_analysis(subset_stats: xr.DataArray):
    subsets = subset_stats.model_count
    stats = ['accuracy', 'f1', 'ECE', 'MCE']
    zeros = np.zeros((len(subsets), len(stats)))

    sens_analysis = xr.DataArray(
        zeros,
        dims=('model_count', 'statistic'),
        coords={'model_count': subsets, 'statistic': stats},
    )

    # TODO: compute accuracy, f1, ECE and MCE for each model_count
    return sens_analysis
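

# Sketch (comments only) of inspecting the subset statistics, e.g. watching
# the mean positive prediction stabilise as the ensemble grows:
#
#   subset_stats = calculate_subset_statistics(predictions)
#   subset_stats.sel(statistic='mean').mean(dim='data_id')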


# Main function
def main():
    print('Loading Config...')
    config = load_config()
    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
    V4_PATH = ENSEMBLE_PATH + '/v4'
    if not os.path.exists(V4_PATH):
        os.makedirs(V4_PATH)
    print('Config Loaded')

    # Load Datasets
    print('Loading Datasets...')
    dataset = load_datasets(ENSEMBLE_PATH)
    print('Datasets Loaded')

    # Get Predictions, either by running the models or loading them from a file
    if config['ensemble']['run_models']:
        # Load Models
        print('Loading Models...')
        device = torch.device(config['training']['device'])
        models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
        print('Models Loaded')

        # Get Predictions
        print('Getting Predictions...')
        predictions = get_ensemble_predictions(models, dataset, device)
        print('Predictions Loaded')

        # Save Predictions
        predictions.to_netcdf(f'{V4_PATH}/predictions.nc')
    else:
        predictions = xr.open_dataarray(f'{V4_PATH}/predictions.nc')

    # Prune Data
    print('Pruning Data...')
    if config['operation']['exclude_blank_ids']:
        excluded_data_ids = config['ensemble']['excluded_ids']
        predictions = prune_data(predictions, excluded_data_ids)

    # Compute Ensemble Statistics
    print('Computing Ensemble Statistics...')
    ensemble_statistics = compute_ensemble_statistics(predictions)
    ensemble_statistics.to_netcdf(f'{V4_PATH}/ensemble_statistics.nc')
    print('Ensemble Statistics Computed')

    # Compute Thresholded Predictions
    print('Computing Thresholded Predictions...')
    thresholded_predictions = compute_thresholded_predictions(ensemble_statistics)
    thresholded_predictions.to_netcdf(f'{V4_PATH}/thresholded_predictions.nc')
    print('Thresholded Predictions Computed')

    # Graph Thresholded Predictions
    print('Graphing Thresholded Predictions...')
    graph_all_thresholded_predictions(thresholded_predictions, V4_PATH)
    print('Thresholded Predictions Graphed')

    # Additional Graphs
    print('Graphing Additional Graphs...')
    # Confidence vs. stdev
    graph_statistics(
        ensemble_statistics,
        'confidence',
        'stdev',
        f'{V4_PATH}/confidence_stdev.png',
        'Confidence vs. Standard Deviation',
        'Confidence',
        'Standard Deviation',
    )
    print('Additional Graphs Graphed')

    # Compute Individual Statistics
    print('Computing Individual Statistics...')
    indv_statistics = compute_individual_statistics(predictions)
    indv_statistics.to_netcdf(f'{V4_PATH}/indv_statistics.nc')
    print('Individual Statistics Computed')

    # Compute Individual Thresholds
    print('Computing Individual Thresholds...')
    indv_thresholds = compute_individual_thresholds(indv_statistics)
    indv_thresholds.to_netcdf(f'{V4_PATH}/indv_thresholds.nc')
    print('Individual Thresholds Computed')

    # Graph Individual Thresholded Predictions
    print('Graphing Individual Thresholded Predictions...')
    if not os.path.exists(f'{V4_PATH}/indv'):
        os.makedirs(f'{V4_PATH}/indv')

    graph_all_individual_thresholded_predictions(
        indv_thresholds, thresholded_predictions, V4_PATH
    )
    print('Individual Thresholded Predictions Graphed')


if __name__ == '__main__':
    main()