@@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
 import matplotlib.ticker as mtick
 
+from tqdm import tqdm  # needed for the tqdm(...) progress bars added below
 
 # The datastructures for this file are as follows
 # models_dict: Dictionary - {model_id: model}
 # predictions: DataArray - (data_id, model_id, prediction_value) - Prediction value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
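For orientation, here is a minimal illustrative sketch (outside the patch) of the `predictions` layout those comments describe, with two data points, two hypothetical model ids, and the four `prediction_value` coordinates:

```python
import numpy as np
import xarray as xr

# Toy predictions array: 2 data points x 2 models x 4 prediction values.
predictions = xr.DataArray(
    np.zeros((2, 2, 4)),
    dims=('data_id', 'model_id', 'prediction_value'),
    coords={
        'data_id': range(2),
        'model_id': ['model_a', 'model_b'],  # hypothetical model ids
        'prediction_value': [
            'negative_prediction',
            'positive_prediction',
            'negative_actual',
            'positive_actual',
        ],
    },
)

# One model's record for one data point: the two class scores
# followed by the two one-hot 'actual' label entries.
print(predictions.loc[{'data_id': 0, 'model_id': 'model_a'}].values)
```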
@@ -65,19 +66,20 @@ def preprocess_data(data, device):
 
-# Loads datasets and returns concatenated test and validation datasets
+# Loads datasets and returns the test and validation datasets as a tuple
 def load_datasets(ensemble_path):
-    return torch.load(f'{ensemble_path}/test_dataset.pt') + torch.load(
-        f'{ensemble_path}/val_dataset.pt'
+    return (
+        torch.load(f'{ensemble_path}/test_dataset.pt'),
+        torch.load(f'{ensemble_path}/val_dataset.pt'),
     )
 
 
 # Gets the predictions for a set of models on a dataset
-def get_ensemble_predictions(models, dataset, device):
+def get_ensemble_predictions(models, dataset, device, id_offset=0):
     zeros = np.zeros((len(dataset), len(models), 4))
     predictions = xr.DataArray(
         zeros,
         dims=('data_id', 'model_id', 'prediction_value'),
         coords={
-            'data_id': range(len(dataset)),
+            'data_id': range(id_offset, len(dataset) + id_offset),
             'model_id': list(models.keys()),
             'prediction_value': [
                 'negative_prediction',
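The new `id_offset` parameter gives the test and validation prediction arrays disjoint `data_id` coordinates, which matters when the two arrays are concatenated along `data_id` later in `main()`. A sketch of the bookkeeping, assuming a hypothetical 100-item test set and 50-item validation set:

```python
# Hypothetical sizes: len(test_dataset) == 100, len(val_dataset) == 50.
test_ids = range(0, 100)    # get_ensemble_predictions(..., id_offset=0)
val_ids = range(100, 150)   # get_ensemble_predictions(..., id_offset=100)

# No collisions, so xr.concat([...], dim='data_id') keeps every id unique.
assert set(test_ids).isdisjoint(val_ids)
```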
@@ -98,9 +100,9 @@ def get_ensemble_predictions(models, dataset, device):
             output = model(dat)
             prediction = output.cpu().numpy().tolist()[0]
 
-            predictions.loc[{'data_id': data_id, 'model_id': model_id}] = (
-                prediction + actual
-            )
+            predictions.loc[
+                {'data_id': data_id + id_offset, 'model_id': model_id}
+            ] = prediction + actual
 
     return predictions
 
@@ -159,7 +161,7 @@ def compute_ensemble_statistics(predictions: xr.DataArray):
 
 # Compute the thresholded predictions given an array of predictions
 def compute_thresholded_predictions(input_stats: xr.DataArray):
-    quantiles = np.linspace(0.05, 0.95, 19) * 100
+    quantiles = np.linspace(0.00, 1.00, 21) * 100
     metrics = ['accuracy', 'f1']
     statistics = ['stdev', 'entropy', 'confidence']
 
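The replacement grid widens the thresholds from the 5th–95th percentiles (19 points) to the full 0–100% range in 5-point steps (21 points):

```python
import numpy as np

old = np.linspace(0.05, 0.95, 19) * 100  # 5, 10, ..., 95
new = np.linspace(0.00, 1.00, 21) * 100  # 0, 5, ..., 100
print(new)
```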
@@ -217,6 +219,13 @@ def compute_metric(arr, metric):
         return met.F1(
             arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}]
         )
+    elif metric == 'ece':
+        true_labels = arr.loc[{'statistic': 'actual'}].values
+        predicted_labels = arr.loc[{'statistic': 'predicted'}].values
+        confidences = arr.loc[{'statistic': 'confidence'}].values
+
+        return calculate_ece_stats(confidences, predicted_labels, true_labels)
+
     else:
         raise ValueError('Invalid metric: ' + metric)
 
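The new `'ece'` branch defers to `calculate_ece_stats`, defined later in this patch, which implements the standard binned expected calibration error: partition the predictions into M equal-width confidence bins B_m over n samples and accumulate

    \mathrm{ECE} = \sum_{m=1}^{M} \frac{|B_m|}{n} \, \bigl| \mathrm{acc}(B_m) - \mathrm{conf}(B_m) \bigr|

where acc(B_m) is the accuracy and conf(B_m) the mean confidence inside bin B_m.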
@@ -251,8 +260,8 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'confidence',
         'accuracy',
         f'{save_path}/confidence_accuracy.png',
-        'Confidence vs. Accuracy',
-        'Confidence',
+        'Coverage Analysis of Confidence vs. Accuracy',
+        'Minimum Confidence Percentile Threshold',
         'Accuracy',
     )
 
@@ -262,9 +271,9 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'confidence',
         'f1',
         f'{save_path}/confidence_f1.png',
-        'Confidence vs. F1',
-        'Confidence',
-        'F1',
+        'Coverage Analysis of Confidence vs. F1 Score',
+        'Minimum Confidence Percentile Threshold',
+        'F1 Score',
     )
 
     # Entropy Accuracy
@@ -273,8 +282,8 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'entropy',
         'accuracy',
         f'{save_path}/entropy_accuracy.png',
-        'Entropy vs. Accuracy',
-        'Entropy',
+        'Coverage Analysis of Entropy vs. Accuracy',
+        'Maximum Entropy Percentile Threshold',
         'Accuracy',
     )
 
@@ -285,9 +294,9 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'entropy',
         'f1',
         f'{save_path}/entropy_f1.png',
-        'Entropy vs. F1',
-        'Entropy',
-        'F1',
+        'Coverage Analysis of Entropy vs. F1 Score',
+        'Maximum Entropy Percentile Threshold',
+        'F1 Score',
     )
 
     # Stdev Accuracy
@@ -296,8 +305,8 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'stdev',
         'accuracy',
         f'{save_path}/stdev_accuracy.png',
-        'Standard Deviation vs. Accuracy',
-        'Standard Deviation',
+        'Coverage Analysis of Standard Deviation vs. Accuracy',
+        'Maximum Standard Deviation Percentile Threshold',
         'Accuracy',
     )
 
@@ -307,8 +316,8 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'stdev',
         'f1',
         f'{save_path}/stdev_f1.png',
-        'Standard Deviation vs. F1',
-        'Standard Deviation',
-        'F1',
+        'Coverage Analysis of Standard Deviation vs. F1 Score',
+        'Maximum Standard Deviation Percentile Threshold',
+        'F1 Score',
     )
 
@@ -371,7 +380,9 @@ def compute_individual_statistics(predictions: xr.DataArray):
         },
     )
 
-    for data_id in predictions.data_id:
+    for data_id in tqdm(
+        predictions.data_id, total=len(predictions.data_id), unit='images'
+    ):
         for model_id in predictions.model_id:
             data = predictions.loc[{'data_id': data_id, 'model_id': model_id}]
             mean = data[0:2]
@@ -414,7 +425,9 @@ def compute_individual_thresholds(input_stats: xr.DataArray):
         },
     )
 
-    for model_id in input_stats.model_id:
+    for model_id in tqdm(
+        input_stats.model_id, total=len(input_stats.model_id), unit='models'
+    ):
         for statistic in statistics:
             # First, we must compute the quantiles for the statistic
             quantile_values = np.percentile(
@@ -504,8 +517,8 @@ def graph_all_individual_thresholded_predictions(
         'confidence',
         'accuracy',
         f'{save_path}/indv/confidence_accuracy.png',
-        'Confidence vs. Accuracy',
-        'Confidence Percentile Threshold',
+        'Coverage Analysis of Confidence vs. Accuracy for All Models',
+        'Minimum Confidence Percentile Threshold',
         'Accuracy',
     )
 
@@ -516,9 +529,9 @@ def graph_all_individual_thresholded_predictions(
         'confidence',
         'f1',
         f'{save_path}/indv/confidence_f1.png',
-        'Confidence vs. F1',
-        'Confidence Percentile Threshold',
-        'F1',
+        'Coverage Analysis of Confidence vs. F1 Score for All Models',
+        'Minimum Confidence Percentile Threshold',
+        'F1 Score',
     )
 
     # Entropy Accuracy
@@ -528,8 +541,8 @@ def graph_all_individual_thresholded_predictions(
         'entropy',
         'accuracy',
         f'{save_path}/indv/entropy_accuracy.png',
-        'Entropy vs. Accuracy',
-        'Entropy Percentile Threshold',
+        'Coverage Analysis of Entropy vs. Accuracy for All Models',
+        'Maximum Entropy Percentile Threshold',
         'Accuracy',
     )
 
@@ -540,16 +553,17 @@ def graph_all_individual_thresholded_predictions(
         'entropy',
         'f1',
         f'{save_path}/indv/entropy_f1.png',
-        'Entropy vs. F1',
-        'Entropy Percentile Threshold',
-        'F1',
+        'Coverage Analysis of Entropy vs. F1 Score for All Models',
+        'Maximum Entropy Percentile Threshold',
+        'F1 Score',
     )
 
 
 # Calculate statistics of subsets of models for sensitivity analysis
 def calculate_subset_statistics(predictions: xr.DataArray):
-    # Calculate subsets for 1-50 models
-    subsets = range(1, len(predictions.model_id) + 1)
+    # Calculate subsets for 1-49 models
+    subsets = range(1, len(predictions.model_id))
+
     zeros = np.zeros(
         (len(predictions.data_id), len(subsets), 7)
     )  # Include stdev, but for 1 model set to NaN
@@ -572,7 +586,9 @@ def calculate_subset_statistics(predictions: xr.DataArray):
         },
     )
 
-    for data_id in predictions.data_id:
+    for data_id in tqdm(
+        predictions.data_id, total=len(predictions.data_id), unit='images'
+    ):
         for subset in subsets:
             data = predictions.sel(
                 data_id=data_id, model_id=predictions.model_id[:subset]
@@ -581,7 +597,7 @@ def calculate_subset_statistics(predictions: xr.DataArray):
             stdev = data.std(dim='model_id')[1]
             entropy = (-mean * np.log(mean)).sum()
             confidence = mean.max()
-            actual = data[3]
+            actual = data[0][3]
             predicted = mean.argmax()
             correct = actual == predicted
 
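The `data[0][3]` fix reflects that after `predictions.sel(data_id=..., model_id=...[:subset])` the selection still has a leading `model_id` dimension, so `data[3]` would index the wrong axis. A toy illustration of the equivalent label lookup (outside the patch, assuming the same coordinate layout as above):

```python
import numpy as np
import xarray as xr

data = xr.DataArray(
    np.arange(8.0).reshape(2, 4),  # 2 models x 4 prediction values
    dims=('model_id', 'prediction_value'),
    coords={
        'prediction_value': [
            'negative_prediction',
            'positive_prediction',
            'negative_actual',
            'positive_actual',
        ]
    },
)

# Positional data[0][3] is the first model's 'positive_actual' entry.
assert data[0][3] == data.isel(model_id=0).sel(prediction_value='positive_actual')
```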
@@ -600,17 +616,80 @@ def calculate_subset_statistics(predictions: xr.DataArray):
 
 # Calculate Accuracy, F1 and ECE for subset stats - sensitivity analysis
 def calculate_sensitivity_analysis(subset_stats: xr.DataArray):
-    subsets = subset_stats.subsets
-    stats = ['accuracy', 'f1', 'ECE', 'MCE']
+    subsets = subset_stats.model_count
+    stats = ['accuracy', 'f1', 'ece']
 
     zeros = np.zeros((len(subsets), len(stats)))
 
     sens_analysis = xr.DataArray(
         zeros,
         dims=('model_count', 'statistic'),
-        coords={'model_count': subsets, 'statistic': ['accuracy', 'f1', 'ECE', 'MCE']},
+        coords={'model_count': subsets, 'statistic': stats},
     )
 
+    for subset in tqdm(subsets, total=len(subsets), unit='model subsets'):
+        data = subset_stats.sel(model_count=subset)
+        acc = compute_metric(data, 'accuracy').item()
+        f1 = compute_metric(data, 'f1').item()
+        ece = compute_metric(data, 'ece').item()
+
+        sens_analysis.loc[{'model_count': subset.item()}] = [acc, f1, ece]
+
+    return sens_analysis
+
+
+def graph_sensitivity_analysis(
+    sens_analysis: xr.DataArray, statistic, save_path, title, xlabel, ylabel
+):
+    data = sens_analysis.sel(statistic=statistic)
+
+    xdata = data.coords['model_count'].values
+    ydata = data.values
+
+    fig, ax = plt.subplots()
+    ax.plot(xdata, ydata)
+    ax.set_title(title)
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+
+    plt.savefig(save_path)
+
+
+def calculate_overall_stats(ensemble_statistics: xr.DataArray):
+    accuracy = compute_metric(ensemble_statistics, 'accuracy')
+    f1 = compute_metric(ensemble_statistics, 'f1')
+
+    return {'accuracy': accuracy.item(), 'f1': f1.item()}
+
+
+# https://towardsdatascience.com/expected-calibration-error-ece-a-step-by-step-visual-explanation-with-python-code-c3e9aa12937d
+def calculate_ece_stats(confidences, predicted_labels, true_labels, bins=10):
+    bin_boundaries = np.linspace(0, 1, bins + 1)
+    bin_lowers = bin_boundaries[:-1]
+    bin_uppers = bin_boundaries[1:]
+
+    ece = np.zeros(1)
+
+    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
+        in_bin = np.logical_and(
+            confidences > bin_lower.item(), confidences <= bin_upper.item()
+        )
+        prob_in_bin = in_bin.mean()
+
+        if prob_in_bin.item() > 0:
+            # Bin accuracy is the fraction of correct predictions in the bin
+            accuracy_in_bin = (predicted_labels[in_bin] == true_labels[in_bin]).mean()
+
+            avg_confidence_in_bin = confidences[in_bin].mean()
+
+            ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prob_in_bin
+
+    return ece
+
+
+def plot_ece_graph(ece_stats, title, xlabel, ylabel, save_path):
+    fig, ax = plt.subplots()  # TODO: reliability diagram; currently a stub
+
 
 # Main Function
 def main():
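A quick sanity check of `calculate_ece_stats` on toy arrays (assumes the patched module is importable; the values are arbitrary):

```python
import numpy as np

confidences = np.array([0.9, 0.8, 0.6])  # max softmax probability per sample
predicted_labels = np.array([1, 0, 1])
true_labels = np.array([1, 0, 0])

# Returns a length-1 array; the third prediction is confidently wrong,
# so the ECE comes out non-zero.
print(calculate_ece_stats(confidences, predicted_labels, true_labels))
```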
@@ -625,7 +704,7 @@ def main():
 
     # Load Datasets
     print('Loading Datasets...')
-    dataset = load_datasets(ENSEMBLE_PATH)
+    (test_dataset, val_dataset) = load_datasets(ENSEMBLE_PATH)
    print('Datasets Loaded')
 
     # Get Predictions, either by running the models or loading them from a file
@@ -638,20 +717,29 @@ def main():
 
         # Get Predictions
         print('Getting Predictions...')
-        predictions = get_ensemble_predictions(models, dataset, device)
+        test_predictions = get_ensemble_predictions(models, test_dataset, device)
+        val_predictions = get_ensemble_predictions(
+            models, val_dataset, device, len(test_dataset)
+        )
         print('Predictions Loaded')
 
         # Save Prediction
-        predictions.to_netcdf(f'{V4_PATH}/predictions.nc')
+        test_predictions.to_netcdf(f'{V4_PATH}/test_predictions.nc')
+        val_predictions.to_netcdf(f'{V4_PATH}/val_predictions.nc')
     else:
-        predictions = xr.open_dataarray(f'{V4_PATH}/predictions.nc')
+        test_predictions = xr.open_dataarray(f'{V4_PATH}/test_predictions.nc')
+        val_predictions = xr.open_dataarray(f'{V4_PATH}/val_predictions.nc')
 
     # Prune Data
     print('Pruning Data...')
     if config['operation']['exclude_blank_ids']:
         excluded_data_ids = config['ensemble']['excluded_ids']
-        predictions = prune_data(predictions, excluded_data_ids)
+        test_predictions = prune_data(test_predictions, excluded_data_ids)
+        val_predictions = prune_data(val_predictions, excluded_data_ids)
+
+    # Concatenate Predictions
+    predictions = xr.concat([test_predictions, val_predictions], dim='data_id')
 
     # Compute Ensemble Statistics
     print('Computing Ensemble Statistics...')
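The concatenation step relies on the disjoint `data_id` ranges set up by `id_offset`. A minimal sketch (outside the patch) of what `xr.concat` does here:

```python
import numpy as np
import xarray as xr

a = xr.DataArray(np.zeros(2), dims='data_id', coords={'data_id': [0, 1]})
b = xr.DataArray(np.ones(1), dims='data_id', coords={'data_id': [2]})

merged = xr.concat([a, b], dim='data_id')
print(merged.data_id.values)  # [0 1 2]
```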
@@ -678,7 +766,7 @@ def main():
         'confidence',
         'stdev',
         f'{V4_PATH}/confidence_stdev.png',
-        'Confidence vs. Standard Deviation',
+        'Confidence and Standard Deviation for Predictions',
         'Confidence',
         'Standard Deviation',
     )
@@ -706,6 +794,29 @@ def main():
     )
     print('Individual Thresholded Predictions Graphed')
 
+    # Compute subset statistics and graph
+    print('Computing Sensitivity Analysis...')
+    subset_stats = calculate_subset_statistics(predictions)
+    sens_analysis = calculate_sensitivity_analysis(subset_stats)
+    graph_sensitivity_analysis(
+        sens_analysis,
+        'accuracy',
+        f'{V4_PATH}/sens_analysis.png',
+        'Sensitivity Analysis of Accuracy vs. # of Models',
+        '# of Models',
+        'Accuracy',
+    )
+    graph_sensitivity_analysis(
+        sens_analysis,
+        'ece',
+        f'{V4_PATH}/sens_analysis_ece.png',
+        'Sensitivity Analysis of ECE vs. # of Models',
+        '# of Models',
+        'ECE',
+    )
+    print(sens_analysis.sel(statistic='accuracy'))
+    print(calculate_overall_stats(ensemble_statistics))
+
 
 if __name__ == '__main__':
     main()