|
@@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
|
|
|
import matplotlib.ticker as mtick
|
|
|
|
|
|
|
|
|
+
|
|
|
# The data structures for this file are as follows
|
|
|
# models_dict: Dictionary - {model_id: model}
|
|
|
# predictions: DataArray - (data_id, model_id, prediction_value) - Prediction value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
|
|
@@ -581,7 +582,7 @@ def calculate_subset_statistics(predictions: xr.DataArray):
|
|
|
stdev = data.std(dim='model_id')[1]
|
|
|
entropy = (-mean * np.log(mean)).sum()
|
|
|
confidence = mean.max()
|
|
|
- actual = data[3]
|
|
|
+ actual = data[0][3]
|
|
|
predicted = mean.argmax()
|
|
|
correct = actual == predicted
|
|
|
|
|
@@ -600,17 +601,50 @@ def calculate_subset_statistics(predictions: xr.DataArray):
|
|
|
|
|
|
# Calculate accuracy and F1 for subset stats - sensitivity analysis
|
|
|
def calculate_sensitivity_analysis(subset_stats: xr.DataArray):
    """Tabulate accuracy and F1 for every value of the 'model_count' coordinate.

    Args:
        subset_stats: DataArray exposing a ``model_count`` coordinate; each
            value is selected in turn and handed to ``compute_metric``.

    Returns:
        A DataArray with dims ``('model_count', 'statistic')`` whose
        ``statistic`` labels are ``'accuracy'`` and ``'f1'``.
    """
    # Single source of truth for the statistic axis: it sizes the table,
    # labels the coordinate and drives the per-row metric computation.
    metric_names = ['accuracy', 'f1']
    model_counts = subset_stats.model_count

    # Start from an all-zero table; every row is overwritten in the loop below.
    table = xr.DataArray(
        np.zeros((len(model_counts), len(metric_names))),
        dims=('model_count', 'statistic'),
        coords={'model_count': model_counts, 'statistic': metric_names},
    )

    for count in model_counts:
        per_count = subset_stats.sel(model_count=count)
        # Row values are produced in the same order as the 'statistic' labels.
        table.loc[{'model_count': count}] = [
            compute_metric(per_count, name) for name in metric_names
        ]

    return table
|
|
|
+
|
|
|
+
|
|
|
def graph_sensitivity_analysis(
    sens_analysis: xr.DataArray, statistic, save_path, title, xlabel, ylabel
):
    """Plot one statistic against model count and write the figure to disk.

    Args:
        sens_analysis: DataArray with a 'statistic' dimension to select from
            and a 'model_count' coordinate used as the x axis.
        statistic: Label along 'statistic' to plot.
        save_path: Destination handed to ``Figure.savefig``.
        title: Axes title.
        xlabel: X-axis label.
        ylabel: Y-axis label.

    Returns:
        None. The only effect is the file written at ``save_path``.
    """
    data = sens_analysis.sel(statistic=statistic)

    xdata = data.coords['model_count'].values
    ydata = data.values

    fig, ax = plt.subplots()
    try:
        ax.plot(xdata, ydata)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

        # Save through the figure handle instead of pyplot's implicit
        # "current figure", so exactly the figure built here is written.
        fig.savefig(save_path)
    finally:
        # pyplot keeps a reference to every figure until it is closed;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)
|
|
|
+
|
|
|
+
|
|
|
def calculate_overall_stats(ensemble_statistics: xr.DataArray):
    """Return the ensemble's accuracy and F1 as plain Python scalars.

    Args:
        ensemble_statistics: DataArray forwarded unchanged to
            ``compute_metric`` for each metric.

    Returns:
        A dict with keys ``'accuracy'`` and ``'f1'``; each value is the
        result of calling ``.item()`` on the corresponding metric.
    """
    return {
        metric: compute_metric(ensemble_statistics, metric).item()
        for metric in ('accuracy', 'f1')
    }
|
|
|
+
|
|
|
|
|
|
# Main Function
|
|
|
def main():
|
|
@@ -706,6 +740,20 @@ def main():
|
|
|
)
|
|
|
print('Individual Thresholded Predictions Graphed')
|
|
|
|
|
|
+ # Compute subset statistics and graph
|
|
|
+ subset_stats = calculate_subset_statistics(predictions)
|
|
|
+ sens_analysis = calculate_sensitivity_analysis(subset_stats)
|
|
|
+ graph_sensitivity_analysis(
|
|
|
+ sens_analysis,
|
|
|
+ 'accuracy',
|
|
|
+ f'{V4_PATH}/sens_analysis.png',
|
|
|
+ 'Sensitivity Analysis',
|
|
|
+ '# of Models',
|
|
|
+ 'Accuracy',
|
|
|
+ )
|
|
|
+ print(sens_analysis.sel(statistic='accuracy'))
|
|
|
+ print(calculate_overall_stats(ensemble_statistics))
|
|
|
+
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
main()
|