@@ -1,31 +1,31 @@
-#Rewritten Program to use xarray instead of pandas for thresholding
+# Rewritten program to use xarray instead of pandas for thresholding

import xarray as xr
-import torch
+import torch
import numpy as np
import os
import glob
import tomli as toml
from tqdm import tqdm
import utils.metrics as met
+import matplotlib.pyplot as plt
+import matplotlib.ticker as mtick

-if __name__ == '__main__':
-    main()

-
-#The datastructures for this file are as follows
+# The data structures for this file are as follows
# models_dict: Dictionary - {model_id: model}
# predictions: DataArray - (data_id, model_id, prediction_value) - prediction_value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
# ensemble_statistics: DataArray - (data_id, statistic) - statistic has coords ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
# thresholded_predictions: DataArray - (quantile, statistic, metric) - metric has coords ['accuracy', 'f1'] - only use 'stdev', 'entropy', 'confidence' for statistic

-#Additionally, we also have the thresholds and statistics for the individual models
+# Additionally, we also have the thresholds and statistics for the individual models
# indv_statistics: DataArray - (data_id, model_id, statistic) - statistic has coords ['mean', 'entropy', 'confidence', 'correct', 'predicted', 'actual'] - no stdev, as it cannot be calculated for a single model
# indv_thresholds: DataArray - (model_id, quantile, statistic, metric) - metric has coords ['accuracy', 'f1'] - only use 'entropy', 'confidence' for statistic

-#Additionally, we have some for the sensitivity analysis for number of models
+# Additionally, we have some for the sensitivity analysis over the number of models
# sensitivity_statistics: DataArray - (data_id, model_count, statistic) - statistic has coords ['accuracy', 'f1', 'ECE', 'MCE']
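
+# Illustrative sketch of how these arrays are indexed (assuming a
+# hypothetical model id 'm0'; real ids come from the loaded model files):
+#   predictions.loc[{'data_id': 0, 'model_id': 'm0'}]
+#       -> [negative_prediction, positive_prediction, negative_actual, positive_actual]
+#   ensemble_statistics.sel(data_id=0, statistic='confidence')
+#       -> scalar ensemble confidence for sample 0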
+
# Loads configuration dictionary
def load_config():
    if os.getenv('ADL_CONFIG_PATH') is None:
@@ -37,7 +38,8 @@ def load_config():

    return config

-#Loads models into a dictionary
+
+# Loads models into a dictionary
def load_models_v2(folder, device):
    glob_path = os.path.join(folder, '*.pt')
    model_files = glob.glob(glob_path)
@@ -53,6 +54,7 @@ def load_models_v2(folder, device):

    return model_dict

+
# Ensures that both mri and xls tensors in the data are unsqueezed and are on the correct device
def preprocess_data(data, device):
    mri, xls = data
@@ -60,12 +62,14 @@ def preprocess_data(data, device):
    xls = xls.unsqueeze(0).to(device)
    return (mri, xls)

+
# Loads datasets and returns concatenated test and validation datasets
def load_datasets(ensemble_path):
    return torch.load(f'{ensemble_path}/test_dataset.pt') + torch.load(
        f'{ensemble_path}/val_dataset.pt'
    )
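
+# Note: torch Dataset objects implement __add__ as concatenation (yielding a
+# ConcatDataset), so the '+' above joins the test and validation sets.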

+
# Gets the predictions for a set of models on a dataset
def get_ensemble_predictions(models, dataset, device):
    zeros = np.zeros((len(dataset), len(models), 4))
@@ -74,25 +78,35 @@ def get_ensemble_predictions(models, dataset, device):
        dims=('data_id', 'model_id', 'prediction_value'),
        coords={
            'data_id': range(len(dataset)),
-            'model_id': models.keys(),
-            'prediction_value': ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
-        }
+            'model_id': list(models.keys()),
+            'prediction_value': [
+                'negative_prediction',
+                'positive_prediction',
+                'negative_actual',
+                'positive_actual',
+            ],
+        },
    )

-    for data_id, (data, target) in tqdm(enumerate(dataset)):
-        mri, xls = preprocess_data(data, device)
+    for data_id, (data, target) in tqdm(
+        enumerate(dataset), total=len(dataset), unit='images'
+    ):
+        dat = preprocess_data(data, device)
        actual = list(target.cpu().numpy())
        for model_id, model in models.items():
            with torch.no_grad():
-                output = model(mri, xls)
-                prediction = list(output.cpu().numpy())
+                output = model(dat)
+                prediction = output.cpu().numpy().tolist()[0]

-        predictions.loc[{ 'data_id': data_id, 'model_id': model_id }] = prediction + actual
+            predictions.loc[{'data_id': data_id, 'model_id': model_id}] = (
+                prediction + actual
+            )

    return predictions
-
+
+
# Compute the ensemble statistics given an array of predictions
-def compute_ensemble_statistics(predictions):
+def compute_ensemble_statistics(predictions: xr.DataArray):
    zeros = np.zeros((len(predictions.data_id), 7))

    ensemble_statistics = xr.DataArray(
@@ -100,58 +114,93 @@ def compute_ensemble_statistics(predictions):
        dims=('data_id', 'statistic'),
        coords={
            'data_id': predictions.data_id,
-            'statistic': ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
-        }
-    )
+            'statistic': [
+                'mean',
+                'stdev',
+                'entropy',
+                'confidence',
+                'correct',
+                'predicted',
+                'actual',
+            ],
+        },
+    )

    for data_id in predictions.data_id:
-        data = predictions.loc[{ 'data_id': data_id }]
-        mean = np.mean(data, axis=0)
-        stdev = np.std(data, axis=0)
-        entropy = -np.sum(mean * np.log2(mean + 1e-12))
-        confidence = np.max(mean)
-
-        actual = data.iloc[:, 3].values
-        predicted = np.argmax(mean)
+        data = predictions.loc[{'data_id': data_id}]
+        mean = data.mean(dim='model_id')[
+            0:2
+        ]  # Only take the predictions, not the actual values
+        stdev = data.std(dim='model_id')[
+            1
+        ]  # Only need the standard deviation of the positive prediction
+        entropy = (-mean * np.log(mean)).sum()
+
+        # Compute confidence
+        confidence = mean.max()
+
+        # Only need one of the actual values, since they are all the same; take the first 'positive_actual'
+        actual = data.loc[{'prediction_value': 'positive_actual'}][0]
+        predicted = mean.argmax()
        correct = actual == predicted

-        ensemble_statistics.loc[{ 'data_id': data_id }] = [mean, stdev, entropy, confidence, correct, predicted, actual]
+        ensemble_statistics.loc[{'data_id': data_id}] = [
+            mean[1],
+            stdev,
+            entropy,
+            confidence,
+            correct,
+            predicted,
+            actual,
+        ]

    return ensemble_statistics
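
+# Note on the statistics above: 'entropy' is the Shannon entropy
+# H = -sum_c p_c * ln(p_c) of the mean prediction, and 'confidence' is the
+# probability of the predicted class, max_c p_c; this assumes the two
+# prediction values are softmax outputs that sum to 1.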

+
# Compute the thresholded predictions given an array of predictions
-def compute_thresholded_predictions(ensemble_statistics: xr.DataArray):
-    quantiles = np.linspace(0.05, 0.95, 19)
+def compute_thresholded_predictions(input_stats: xr.DataArray):
+    quantiles = np.linspace(0.05, 0.95, 19) * 100
    metrics = ['accuracy', 'f1']
    statistics = ['stdev', 'entropy', 'confidence']
-
+
    zeros = np.zeros((len(quantiles), len(statistics), len(metrics)))

    thresholded_predictions = xr.DataArray(
        zeros,
        dims=('quantile', 'statistic', 'metric'),
-        coords={
-            'quantile': quantiles,
-            'statistic': statistics,
-            'metric': metrics
-        }
+        coords={'quantile': quantiles, 'statistic': statistics, 'metric': metrics},
    )

    for statistic in statistics:
-        #First, we must compute the quantiles for the statistic
-        quantile_values = np.quantiles(ensemble_statistics.loc[{ 'statistic': statistic }].values, quantiles, axis=0)
+        # First, we must compute the quantiles for the statistic
+        quantile_values = np.percentile(
+            input_stats.sel(statistic=statistic).values, quantiles, axis=0
+        )

-        #Then, we must compute the metrics for each quantile
+        # Then, we must compute the metrics for each quantile
        for i, quantile in enumerate(quantiles):
            if low_to_high(statistic):
-                filtered_data = ensemble_statistics.where(ensemble_statistics.loc[{ 'statistic': statistic }] > quantile_values[i], drop=True)
+                mask = (
+                    input_stats.sel(statistic=statistic) >= quantile_values[i]
+                ).values
            else:
-                filtered_data = ensemble_statistics.where(ensemble_statistics.loc[{ 'statistic': statistic }] < quantile_values[i], drop=True)
+                mask = (
+                    input_stats.sel(statistic=statistic) <= quantile_values[i]
+                ).values
+
+            # Filter the data based on the mask
+            filtered_data = input_stats.where(
+                input_stats.data_id.isin(np.where(mask)), drop=True
+            )
+
            for metric in metrics:
-                thresholded_predictions.loc[{ 'quantile': quantile, 'statistic': statistic, 'metric': metric }] = compute_metric(filtered_data, metric)
-
+                thresholded_predictions.loc[
+                    {'quantile': quantile, 'statistic': statistic, 'metric': metric}
+                ] = compute_metric(filtered_data, metric)
+
    return thresholded_predictions
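
+# Worked example of the filtering above: for statistic='confidence' at the
+# 20th percentile, the mask keeps the ~80% most confident samples;
+# np.where(mask) yields their integer data_ids, isin() drops all other rows,
+# and accuracy/F1 are then recomputed on the retained subset.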

-
+
+
# Truth function to determine if metric should be thresholded low to high or high to low
# Low confidence is bad, high entropy is bad, high stdev is bad
# So we threshold confidence low to high, entropy and stdev high to low
@@ -159,35 +208,438 @@ def compute_thresholded_predictions(ensemble_statistics: xr.DataArray):
def low_to_high(stat):
    return stat in ['confidence']

+
# Compute a given metric on a DataArray of statistics
def compute_metric(arr, metric):
    if metric == 'accuracy':
-        return np.mean(arr.loc[{ 'statistic': 'correct' }])
+        return np.mean(arr.loc[{'statistic': 'correct'}])
    elif metric == 'f1':
-        return met.F1(arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}])
+        return met.F1(
+            arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}]
+        )
    else:
        raise ValueError('Invalid metric: ' + metric)

+
+# Graph a thresholded prediction for a given statistic and metric
+def graph_thresholded_prediction(
+    thresholded_predictions, statistic, metric, save_path, title, xlabel, ylabel
+):
+    data = thresholded_predictions.sel(statistic=statistic, metric=metric)
+
+    x_data = data.coords['quantile'].values
+    y_data = data.values
+
+    fig, ax = plt.subplots()
+    ax.plot(x_data, y_data, 'bx-', label='Ensemble')
+    ax.set_title(title)
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+    ax.xaxis.set_major_formatter(mtick.PercentFormatter())
+
+    if not low_to_high(statistic):
+        ax.invert_xaxis()
+
+    plt.savefig(save_path)
+
+
+# Graph all thresholded predictions
+def graph_all_thresholded_predictions(thresholded_predictions, save_path):
+    # Confidence Accuracy
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'confidence',
+        'accuracy',
+        f'{save_path}/confidence_accuracy.png',
+        'Confidence vs. Accuracy',
+        'Confidence',
+        'Accuracy',
+    )
+
+    # Confidence F1
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'confidence',
+        'f1',
+        f'{save_path}/confidence_f1.png',
+        'Confidence vs. F1',
+        'Confidence',
+        'F1',
+    )
+
+    # Entropy Accuracy
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'entropy',
+        'accuracy',
+        f'{save_path}/entropy_accuracy.png',
+        'Entropy vs. Accuracy',
+        'Entropy',
+        'Accuracy',
+    )
+
+    # Entropy F1
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'entropy',
+        'f1',
+        f'{save_path}/entropy_f1.png',
+        'Entropy vs. F1',
+        'Entropy',
+        'F1',
+    )
+
+    # Stdev Accuracy
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'stdev',
+        'accuracy',
+        f'{save_path}/stdev_accuracy.png',
+        'Standard Deviation vs. Accuracy',
+        'Standard Deviation',
+        'Accuracy',
+    )
+
+    # Stdev F1
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'stdev',
+        'f1',
+        f'{save_path}/stdev_f1.png',
+        'Standard Deviation vs. F1',
+        'Standard Deviation',
+        'F1',
+    )
+
+
+# Graph two statistics against each other
+def graph_statistics(stats, x_stat, y_stat, save_path, title, xlabel, ylabel):
+    # Filter for correct predictions
+    c_stats = stats.where(
+        stats.data_id.isin(np.where((stats.sel(statistic='correct') == 1).values)),
+        drop=True,
+    )
+
+    # Filter for incorrect predictions
+    i_stats = stats.where(
+        stats.data_id.isin(np.where((stats.sel(statistic='correct') == 0).values)),
+        drop=True,
+    )
+
+    # x and y data for correct and incorrect predictions
+    x_data_c = c_stats.sel(statistic=x_stat).values
+    y_data_c = c_stats.sel(statistic=y_stat).values
+
+    x_data_i = i_stats.sel(statistic=x_stat).values
+    y_data_i = i_stats.sel(statistic=y_stat).values
+
+    fig, ax = plt.subplots()
+    ax.plot(x_data_c, y_data_c, 'go', label='Correct')
+    ax.plot(x_data_i, y_data_i, 'ro', label='Incorrect')
+    ax.set_title(title)
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+    ax.legend()
+
+    plt.savefig(save_path)
+
+
+# Prune the data based on excluded data_ids
+def prune_data(data, excluded_data_ids):
+    return data.where(~data.data_id.isin(excluded_data_ids), drop=True)
+
+
+# Calculate individual model statistics
+def compute_individual_statistics(predictions: xr.DataArray):
+    zeros = np.zeros((len(predictions.data_id), len(predictions.model_id), 6))
+
+    indv_statistics = xr.DataArray(
+        zeros,
+        dims=('data_id', 'model_id', 'statistic'),
+        coords={
+            'data_id': predictions.data_id,
+            'model_id': predictions.model_id,
+            'statistic': [
+                'mean',
+                'entropy',
+                'confidence',
+                'correct',
+                'predicted',
+                'actual',
+            ],
+        },
+    )
+
+    for data_id in predictions.data_id:
+        for model_id in predictions.model_id:
+            data = predictions.loc[{'data_id': data_id, 'model_id': model_id}]
+            mean = data[0:2]
+            entropy = (-mean * np.log(mean)).sum()
+            confidence = mean.max()
+            actual = data[3]
+            predicted = mean.argmax()
+            correct = actual == predicted
+
+            indv_statistics.loc[{'data_id': data_id, 'model_id': model_id}] = [
+                mean[1],
+                entropy,
+                confidence,
+                correct,
+                predicted,
+                actual,
+            ]
+
+    return indv_statistics
+
+
+# Compute individual model thresholds
+def compute_individual_thresholds(input_stats: xr.DataArray):
+    quantiles = np.linspace(0.05, 0.95, 19) * 100
+    metrics = ['accuracy', 'f1']
+    statistics = ['entropy', 'confidence']
+
+    zeros = np.zeros(
+        (len(input_stats.model_id), len(quantiles), len(statistics), len(metrics))
+    )
+
+    indv_thresholds = xr.DataArray(
+        zeros,
+        dims=('model_id', 'quantile', 'statistic', 'metric'),
+        coords={
+            'model_id': input_stats.model_id,
+            'quantile': quantiles,
+            'statistic': statistics,
+            'metric': metrics,
+        },
+    )
+
+    for model_id in input_stats.model_id:
+        for statistic in statistics:
+            # First, we must compute the quantiles for the statistic
+            quantile_values = np.percentile(
+                input_stats.sel(model_id=model_id, statistic=statistic).values,
+                quantiles,
+                axis=0,
+            )
+
+            # Then, we must compute the metrics for each quantile
+            for i, quantile in enumerate(quantiles):
+                if low_to_high(statistic):
+                    mask = (
+                        input_stats.sel(model_id=model_id, statistic=statistic)
+                        >= quantile_values[i]
+                    ).values
+                else:
+                    mask = (
+                        input_stats.sel(model_id=model_id, statistic=statistic)
+                        <= quantile_values[i]
+                    ).values
+
+                # Filter the data based on the mask
+                filtered_data = input_stats.where(
+                    input_stats.data_id.isin(np.where(mask)), drop=True
+                )
+
+                for metric in metrics:
+                    indv_thresholds.loc[
+                        {
+                            'model_id': model_id,
+                            'quantile': quantile,
+                            'statistic': statistic,
+                            'metric': metric,
+                        }
+                    ] = compute_metric(filtered_data, metric)
+
+    return indv_thresholds
+
+
+# Graph individual model thresholded predictions
+def graph_individual_thresholded_predictions(
+    indv_thresholds,
+    ensemble_thresholds,
+    statistic,
+    metric,
+    save_path,
+    title,
+    xlabel,
+    ylabel,
+):
+    data = indv_thresholds.sel(statistic=statistic, metric=metric)
+    e_data = ensemble_thresholds.sel(statistic=statistic, metric=metric)
+
+    x_data = data.coords['quantile'].values
+    y_data = data.values
+
+    e_x_data = e_data.coords['quantile'].values
+    e_y_data = e_data.values
+
+    fig, ax = plt.subplots()
+    for model_id in data.coords['model_id'].values:
+        model_data = data.sel(model_id=model_id)
+        ax.plot(x_data, model_data)
+
+    ax.plot(e_x_data, e_y_data, 'kx-', label='Ensemble')
+
+    ax.set_title(title)
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+    ax.xaxis.set_major_formatter(mtick.PercentFormatter())
+
+    if not low_to_high(statistic):
+        ax.invert_xaxis()
+
+    ax.legend()
+    plt.savefig(save_path)
+
+
+# Graph all individual thresholded predictions
+def graph_all_individual_thresholded_predictions(
+    indv_thresholds, ensemble_thresholds, save_path
+):
+    # Confidence Accuracy
+    graph_individual_thresholded_predictions(
+        indv_thresholds,
+        ensemble_thresholds,
+        'confidence',
+        'accuracy',
+        f'{save_path}/indv/confidence_accuracy.png',
+        'Confidence vs. Accuracy',
+        'Confidence Percentile Threshold',
+        'Accuracy',
+    )
+
+    # Confidence F1
+    graph_individual_thresholded_predictions(
+        indv_thresholds,
+        ensemble_thresholds,
+        'confidence',
+        'f1',
+        f'{save_path}/indv/confidence_f1.png',
+        'Confidence vs. F1',
+        'Confidence Percentile Threshold',
+        'F1',
+    )
+
+    # Entropy Accuracy
+    graph_individual_thresholded_predictions(
+        indv_thresholds,
+        ensemble_thresholds,
+        'entropy',
+        'accuracy',
+        f'{save_path}/indv/entropy_accuracy.png',
+        'Entropy vs. Accuracy',
+        'Entropy Percentile Threshold',
+        'Accuracy',
+    )
+
+    # Entropy F1
+    graph_individual_thresholded_predictions(
+        indv_thresholds,
+        ensemble_thresholds,
+        'entropy',
+        'f1',
+        f'{save_path}/indv/entropy_f1.png',
+        'Entropy vs. F1',
+        'Entropy Percentile Threshold',
+        'F1',
+    )
+
+
+# Calculate statistics of subsets of models for sensitivity analysis
+def calculate_subset_statistics(predictions: xr.DataArray):
+    # Calculate subsets for 1..N models, where N is the full ensemble size
+    subsets = range(1, len(predictions.model_id) + 1)
+    zeros = np.zeros(
+        (len(predictions.data_id), len(subsets), 7)
+    )  # Include stdev, though it is not meaningful for a subset of one model
+
+    subset_stats = xr.DataArray(
+        zeros,
+        dims=('data_id', 'model_count', 'statistic'),
+        coords={
+            'data_id': predictions.data_id,
+            'model_count': subsets,
+            'statistic': [
+                'mean',
+                'stdev',
+                'entropy',
+                'confidence',
+                'correct',
+                'predicted',
+                'actual',
+            ],
+        },
+    )
+
+    for data_id in predictions.data_id:
+        for subset in subsets:
+            data = predictions.sel(
+                data_id=data_id, model_id=predictions.model_id[:subset]
+            )
+            mean = data.mean(dim='model_id')[0:2]
+            stdev = data.std(dim='model_id')[1]
+            entropy = (-mean * np.log(mean)).sum()
+            confidence = mean.max()
+            actual = data.loc[{'prediction_value': 'positive_actual'}][0]
+            predicted = mean.argmax()
+            correct = actual == predicted
+
+            subset_stats.loc[{'data_id': data_id, 'model_count': subset}] = [
+                mean[1],
+                stdev,
+                entropy,
+                confidence,
+                correct,
+                predicted,
+                actual,
+            ]
+
+    return subset_stats
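+
+# Note: the subsets above are nested prefixes of the model_id coordinate, so
+# this sensitivity analysis assumes model order is arbitrary; averaging over
+# random subsets of each size would be a more robust (hypothetical) variant.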
+
+
+# Calculate Accuracy, F1, ECE and MCE for subset stats - sensitivity analysis
+def calculate_sensitivity_analysis(subset_stats: xr.DataArray):
+    subsets = subset_stats.model_count
+    stats = ['accuracy', 'f1', 'ECE', 'MCE']
+
+    zeros = np.zeros((len(subsets), len(stats)))
+
+    sens_analysis = xr.DataArray(
+        zeros,
+        dims=('model_count', 'statistic'),
+        coords={'model_count': subsets, 'statistic': stats},
+    )
+
+    for subset in subsets.values:
+        data = subset_stats.sel(model_count=subset)
+        for metric in ['accuracy', 'f1']:
+            sens_analysis.loc[{'model_count': subset, 'statistic': metric}] = (
+                compute_metric(data, metric)
+            )
+
+    # NOTE: ECE and MCE are left as zeros for now; the calibration helpers
+    # needed to fill them are not part of this diff.
+
+    return sens_analysis
+
+
+# Main Function
def main():
+    print('Loading Config...')
    config = load_config()
    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
    V4_PATH = ENSEMBLE_PATH + '/v4'

    if not os.path.exists(V4_PATH):
        os.makedirs(V4_PATH)
+    print('Config Loaded')

    # Load Datasets
+    print('Loading Datasets...')
    dataset = load_datasets(ENSEMBLE_PATH)
+    print('Datasets Loaded')

    # Get Predictions, either by running the models or loading them from a file
    if config['ensemble']['run_models']:
        # Load Models
+        print('Loading Models...')
        device = torch.device(config['training']['device'])
        models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
+        print('Models Loaded')

        # Get Predictions
+        print('Getting Predictions...')
        predictions = get_ensemble_predictions(models, dataset, device)
+        print('Predictions Computed')

        # Save Predictions
        predictions.to_netcdf(f'{V4_PATH}/predictions.nc')
@@ -195,10 +647,65 @@ def main():
    else:
        predictions = xr.open_dataarray(f'{V4_PATH}/predictions.nc')

+    # Prune Data
+    print('Pruning Data...')
+    if config['operation']['exclude_blank_ids']:
+        excluded_data_ids = config['ensemble']['excluded_ids']
+        predictions = prune_data(predictions, excluded_data_ids)
+
    # Compute Ensemble Statistics
+    print('Computing Ensemble Statistics...')
    ensemble_statistics = compute_ensemble_statistics(predictions)
    ensemble_statistics.to_netcdf(f'{V4_PATH}/ensemble_statistics.nc')
+    print('Ensemble Statistics Computed')

    # Compute Thresholded Predictions
+    print('Computing Thresholded Predictions...')
    thresholded_predictions = compute_thresholded_predictions(ensemble_statistics)
    thresholded_predictions.to_netcdf(f'{V4_PATH}/thresholded_predictions.nc')
+    print('Thresholded Predictions Computed')
+
+    # Graph Thresholded Predictions
+    print('Graphing Thresholded Predictions...')
+    graph_all_thresholded_predictions(thresholded_predictions, V4_PATH)
+    print('Thresholded Predictions Graphed')
+
+    # Additional Graphs
+    print('Graphing Additional Graphs...')
+    # Confidence vs stdev
+    graph_statistics(
+        ensemble_statistics,
+        'confidence',
+        'stdev',
+        f'{V4_PATH}/confidence_stdev.png',
+        'Confidence vs. Standard Deviation',
+        'Confidence',
+        'Standard Deviation',
+    )
+    print('Additional Graphs Graphed')
+
+    # Compute Individual Statistics
+    print('Computing Individual Statistics...')
+    indv_statistics = compute_individual_statistics(predictions)
+    indv_statistics.to_netcdf(f'{V4_PATH}/indv_statistics.nc')
+    print('Individual Statistics Computed')
+
+    # Compute Individual Thresholds
+    print('Computing Individual Thresholds...')
+    indv_thresholds = compute_individual_thresholds(indv_statistics)
+    indv_thresholds.to_netcdf(f'{V4_PATH}/indv_thresholds.nc')
+    print('Individual Thresholds Computed')
+
+    # Graph Individual Thresholded Predictions
+    print('Graphing Individual Thresholded Predictions...')
+    if not os.path.exists(f'{V4_PATH}/indv'):
+        os.makedirs(f'{V4_PATH}/indv')
+
+    graph_all_individual_thresholded_predictions(
+        indv_thresholds, thresholded_predictions, V4_PATH
+    )
+    print('Individual Thresholded Predictions Graphed')
+
+
+if __name__ == '__main__':
+    main()