Bläddra i källkod

More work done on overall stats

Nicholas Schense 4 månader sedan
förälder
incheckning
f2e7f78a40
3 ändrade filer med 38 tillägg och 36 borttagningar
  1. 1 1
      config.toml
  2. 26 35
      threshold_refac.py
  3. 11 0
      utils/metrics.py

+ 1 - 1
config.toml

@@ -32,4 +32,4 @@ silent = false
 name = 'cnn-50x30'
 prune_threshold = 0.0 # Any models with accuracy below this threshold will be pruned, set to 0 to disable pruning
 individual_id = 1     # The id of the individual model to be used for the ensemble
-run_models = false    # If true, the ensemble will run the models to generate the predictions, otherwise will load from file
+run_models = true    # If true, the ensemble will run the models to generate the predictions, otherwise will load from file

+ 26 - 35
threshold_refac.py

@@ -238,6 +238,26 @@ def common_entries(*dcts):
     for i in set(dcts[0]).intersection(*dcts[1:]):
         yield (i,) + tuple(d[i] for d in dcts)
 
+#Given ensemble statistics, calculate overall stats (ECE, MCE, Brier Score, NLL)
+def calculate_overall_statistics(ensemble_statistics):
+    predicted = ensemble_statistics['predicted']
+    actual = ensemble_statistics['actual']
+
+    # New dataframe to store the statistics
+    stats_df = pd.DataFrame(columns=['stat', 'ECE', 'MCE', 'Brier Score', 'NLL']).set_index('stat')
+
+    # Loop through and calculate the ECE, MCE, Brier Score, and NLL
+    for stat in ['confidence', 'entropy', 'stdev', 'raw_confidence']:
+        ece = met.ECE(predicted, ensemble_statistics[stat], actual)
+        mce = met.MCE(predicted, ensemble_statistics[stat], actual)
+        brier = met.brier_binary(ensemble_statistics[stat], actual)
+        nll = met.nll_binary(ensemble_statistics[stat], actual)
+
+        stats_df.loc[stat] = [ece, mce, brier, nll]
+    
+    return stats_df
+
+
 
 def main():
     # Load the models
@@ -273,43 +293,14 @@ def main():
         ensemble_statistics, 'confidence', low_to_high=False
     )
 
-    # Print ECE and MCE Values
-    conf_ece = met.ECE(
-        ensemble_statistics['predicted'],
-        ensemble_statistics['confidence'],
-        ensemble_statistics['actual'],
-    )
-    conf_mce = met.MCE(
-        ensemble_statistics['predicted'],
-        ensemble_statistics['confidence'],
-        ensemble_statistics['actual'],
-    )
-
-    ent_ece = met.ECE(
-        ensemble_statistics['predicted'],
-        ensemble_statistics['entropy'],
-        ensemble_statistics['actual'],
-    )
-    ent_mce = met.MCE(
-        ensemble_statistics['predicted'],
-        ensemble_statistics['entropy'],
-        ensemble_statistics['actual'],
-    )
+    raw_confidence = ensemble_statistics['confidence'].apply(lambda x: (x / 2) + 0.5)
+    ensemble_statistics.insert(4, 'raw_confidence', raw_confidence)
 
-    stdev_ece = met.ECE(
-        ensemble_statistics['predicted'],
-        ensemble_statistics['stdev'],
-        ensemble_statistics['actual'],
-    )
-    stdev_mce = met.MCE(
-        ensemble_statistics['predicted'],
-        ensemble_statistics['stdev'],
-        ensemble_statistics['actual'],
-    )
+    # Calculate overall statistics
+    overall_statistics = calculate_overall_statistics(ensemble_statistics)
 
-    print(f'Confidence ECE: {conf_ece}, Confidence MCE: {conf_mce}')
-    print(f'Entropy ECE: {ent_ece}, Entropy MCE: {ent_mce}')
-    print(f'Stdev ECE: {stdev_ece}, Stdev MCE: {stdev_mce}')
+    # Print overall statistics
+    print(overall_statistics)
 
     # Print overall ensemble statistics
     print('Ensemble Statistics')

+ 11 - 0
utils/metrics.py

@@ -75,3 +75,14 @@ def AUC(confidences, true_labels):
 
 def entropy(confidences):
     return -np.sum(confidences * np.log(confidences))
+
+### Negative Log Likelyhood for binary classification
+def nll_binary(confidences, true_labels):
+    return -np.sum(np.log(confidences[true_labels == 1])) - np.sum(np.log(1 - confidences[true_labels == 0]))
+
+### Breier score for binary classification
+def brier_binary(confidences, true_labels):
+    return np.mean((confidences - true_labels) ** 2)
+
+
+