|
@@ -238,6 +238,26 @@ def common_entries(*dcts):
|
|
|
for i in set(dcts[0]).intersection(*dcts[1:]):
|
|
|
yield (i,) + tuple(d[i] for d in dcts)
|
|
|
|
|
|
+#Given ensemble statistics, calculate overall stats (ECE, MCE, Brier Score, NLL)
|
|
|
+def calculate_overall_statistics(ensemble_statistics):
|
|
|
+ predicted = ensemble_statistics['predicted']
|
|
|
+ actual = ensemble_statistics['actual']
|
|
|
+
|
|
|
+ # New dataframe to store the statistics
|
|
|
+ stats_df = pd.DataFrame(columns=['stat', 'ECE', 'MCE', 'Brier Score', 'NLL']).set_index('stat')
|
|
|
+
|
|
|
+ # Loop through and calculate the ECE, MCE, Brier Score, and NLL
|
|
|
+ for stat in ['confidence', 'entropy', 'stdev', 'raw_confidence']:
|
|
|
+ ece = met.ECE(predicted, ensemble_statistics[stat], actual)
|
|
|
+ mce = met.MCE(predicted, ensemble_statistics[stat], actual)
|
|
|
+ brier = met.brier_binary(ensemble_statistics[stat], actual)
|
|
|
+ nll = met.nll_binary(ensemble_statistics[stat], actual)
|
|
|
+
|
|
|
+ stats_df.loc[stat] = [ece, mce, brier, nll]
|
|
|
+
|
|
|
+ return stats_df
|
|
|
+
|
|
|
+
|
|
|
|
|
|
def main():
|
|
|
# Load the models
|
|
@@ -273,43 +293,14 @@ def main():
|
|
|
ensemble_statistics, 'confidence', low_to_high=False
|
|
|
)
|
|
|
|
|
|
- # Print ECE and MCE Values
|
|
|
- conf_ece = met.ECE(
|
|
|
- ensemble_statistics['predicted'],
|
|
|
- ensemble_statistics['confidence'],
|
|
|
- ensemble_statistics['actual'],
|
|
|
- )
|
|
|
- conf_mce = met.MCE(
|
|
|
- ensemble_statistics['predicted'],
|
|
|
- ensemble_statistics['confidence'],
|
|
|
- ensemble_statistics['actual'],
|
|
|
- )
|
|
|
-
|
|
|
- ent_ece = met.ECE(
|
|
|
- ensemble_statistics['predicted'],
|
|
|
- ensemble_statistics['entropy'],
|
|
|
- ensemble_statistics['actual'],
|
|
|
- )
|
|
|
- ent_mce = met.MCE(
|
|
|
- ensemble_statistics['predicted'],
|
|
|
- ensemble_statistics['entropy'],
|
|
|
- ensemble_statistics['actual'],
|
|
|
- )
|
|
|
+ raw_confidence = ensemble_statistics['confidence'].apply(lambda x: (x / 2) + 0.5)
|
|
|
+ ensemble_statistics.insert(4, 'raw_confidence', raw_confidence)
|
|
|
|
|
|
- stdev_ece = met.ECE(
|
|
|
- ensemble_statistics['predicted'],
|
|
|
- ensemble_statistics['stdev'],
|
|
|
- ensemble_statistics['actual'],
|
|
|
- )
|
|
|
- stdev_mce = met.MCE(
|
|
|
- ensemble_statistics['predicted'],
|
|
|
- ensemble_statistics['stdev'],
|
|
|
- ensemble_statistics['actual'],
|
|
|
- )
|
|
|
+ # Calculate overall statistics
|
|
|
+ overall_statistics = calculate_overall_statistics(ensemble_statistics)
|
|
|
|
|
|
- print(f'Confidence ECE: {conf_ece}, Confidence MCE: {conf_mce}')
|
|
|
- print(f'Entropy ECE: {ent_ece}, Entropy MCE: {ent_mce}')
|
|
|
- print(f'Stdev ECE: {stdev_ece}, Stdev MCE: {stdev_mce}')
|
|
|
+ # Print overall statistics
|
|
|
+ print(overall_statistics)
|
|
|
|
|
|
# Print overall ensemble statistics
|
|
|
print('Ensemble Statistics')
|