Sfoglia il codice sorgente

Implemented ECE and MCE

Nicholas Schense 3 mesi fa
parent
commit
06a736e4f7
2 ha cambiato i file con 39 aggiunte e 1 eliminazioni
  1. 1 1
      config.toml
  2. 38 0
      threshold_refac.py

+ 1 - 1
config.toml

@@ -32,4 +32,4 @@ silent = false
 name = 'cnn-50x30'
 prune_threshold = 0.0 # Any models with accuracy below this threshold will be pruned, set to 0 to disable pruning
 individual_id = 1     # The id of the individual model to be used for the ensemble
-run_models = true     # If true, the ensemble will run the models to generate the predictions, otherwise will load from file
+run_models = false    # If true, the ensemble will run the models to generate the predictions, otherwise will load from file

+ 38 - 0
threshold_refac.py

@@ -273,6 +273,44 @@ def main():
         ensemble_statistics, 'confidence', low_to_high=False
     )
 
+    # Print ECE and MCE Values
+    conf_ece = met.ECE(
+        ensemble_statistics['predicted'],
+        ensemble_statistics['confidence'],
+        ensemble_statistics['actual'],
+    )
+    conf_mce = met.MCE(
+        ensemble_statistics['predicted'],
+        ensemble_statistics['confidence'],
+        ensemble_statistics['actual'],
+    )
+
+    ent_ece = met.ECE(
+        ensemble_statistics['predicted'],
+        ensemble_statistics['entropy'],
+        ensemble_statistics['actual'],
+    )
+    ent_mce = met.MCE(
+        ensemble_statistics['predicted'],
+        ensemble_statistics['entropy'],
+        ensemble_statistics['actual'],
+    )
+
+    stdev_ece = met.ECE(
+        ensemble_statistics['predicted'],
+        ensemble_statistics['stdev'],
+        ensemble_statistics['actual'],
+    )
+    stdev_mce = met.MCE(
+        ensemble_statistics['predicted'],
+        ensemble_statistics['stdev'],
+        ensemble_statistics['actual'],
+    )
+
+    print(f'Confidence ECE: {conf_ece}, Confidence MCE: {conf_mce}')
+    print(f'Entropy ECE: {ent_ece}, Entropy MCE: {ent_mce}')
+    print(f'Stdev ECE: {stdev_ece}, Stdev MCE: {stdev_mce}')
+
     # Print overall ensemble statistics
     print('Ensemble Statistics')
     print(f"Accuracy: {ensemble_statistics['correct'].mean()}")