
Implemented metrics

Nicholas Schense · 4 months ago
commit fb1dbc94c2
1 changed file with 87 additions and 21 deletions

+ 87 - 21
threshold.py

@@ -9,6 +9,7 @@ import matplotlib.pyplot as plt
 import sklearn.metrics as metrics
 from tqdm import tqdm
 import utils.metrics as met
+import itertools as it
 
 RUN = False
 
@@ -102,12 +103,15 @@ def get_predictions(config):
         # Get the confidence and standard deviation of the predicted class
         print(stdev)
         pc_stdev = np.squeeze(stdev)[predicted_class]
+        # Get the individual classes
+        class_1 = mean[0][0]
+        class_2 = mean[0][1]
 
         # Get the true label
         true_label = true_label[1]
 
-        confidences.append((predicted_class, confidence, true_label))
-        stdevs.append((predicted_class, pc_stdev, true_label))
+        confidences.append((predicted_class, confidence, true_label, class_1, class_2))
+        stdevs.append((predicted_class, pc_stdev, true_label, class_1, class_2))
 
     return results, confidences, stdevs, indv_results
 
@@ -116,9 +120,12 @@ if RUN:
     results, confs, stdevs, indv_results = get_predictions(config)
     # Convert to pandas dataframes
     confs_df = pd.DataFrame(
-        confs, columns=['predicted_class', 'confidence', 'true_label']
+        confs,
+        columns=['predicted_class', 'confidence', 'true_label', 'class_1', 'class_2'],
+    )
+    stdevs_df = pd.DataFrame(
+        stdevs, columns=['predicted_class', 'stdev', 'true_label', 'class_1', 'class_2']
     )
-    stdevs_df = pd.DataFrame(stdevs, columns=['predicted_class', 'stdev', 'true_label'])
 
     indv_df = pd.DataFrame(indv_results, columns=['class_1', 'class_2', 'true_label'])
 
@@ -133,21 +140,21 @@ else:
     stdevs_df = pd.read_csv(f'{V2_PATH}/ensemble_stdevs.csv')
     indv_df = pd.read_csv(f'{V2_PATH}/individual_results.csv')
 
-# Plot confidence vs standard deviation
-plt.scatter(confs_df['confidence'], stdevs_df['stdev'])
+# Plot confidence vs standard deviation, and change color of dots based on if they are correct
+correct_conf = confs_df[confs_df['predicted_class'] == confs_df['true_label']]
+incorrect_conf = confs_df[confs_df['predicted_class'] != confs_df['true_label']]
+
+correct_stdev = stdevs_df[stdevs_df['predicted_class'] == stdevs_df['true_label']]
+incorrect_stdev = stdevs_df[stdevs_df['predicted_class'] != stdevs_df['true_label']]
+
+plt.scatter(correct_conf['confidence'], correct_stdev['stdev'], color='green')
+plt.scatter(incorrect_conf['confidence'], incorrect_stdev['stdev'], color='red')
 plt.xlabel('Confidence')
 plt.ylabel('Standard Deviation')
 plt.title('Confidence vs Standard Deviation')
 plt.savefig(f'{V2_PATH}/confidence_vs_stdev.png')
 plt.close()
 
-# Calculate Binning for Coverage
-# Sort Dataframes
-confs_df = confs_df.sort_values(by='confidence')
-stdevs_df = stdevs_df.sort_values(by='stdev')
-
-confs_df.to_csv(f'{V2_PATH}/ensemble_confidences.csv')
-stdevs_df.to_csv(f'{V2_PATH}/ensemble_stdevs.csv')
 
 # Calculate individual model accuracy
 # Determine predicted class
@@ -157,22 +164,34 @@ indv_df['predicted_class'] = indv_df['predicted_class'].apply(
 )
 indv_df['correct'] = indv_df['predicted_class'] == indv_df['true_label']
 accuracy_indv = indv_df['correct'].mean()
+f1_indv = met.F1(
+    indv_df['predicted_class'].to_numpy(), indv_df['true_label'].to_numpy()
+)
+auc_indv = metrics.roc_auc_score(
+    indv_df['true_label'].to_numpy(), indv_df['class_2'].to_numpy()
+)
 
 # Calculate percentiles for confidence and standard deviation
-quantiles_conf = confs_df.quantile(np.linspace(0, 1, 11))['confidence']
-quantiles_stdev = stdevs_df.quantile(np.linspace(0, 1, 11))['stdev']
+quantiles_conf = confs_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
+    'confidence'
+]
+quantiles_stdev = stdevs_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
+    'stdev'
+]
 
 accuracies_conf = []
 # Use the quantiles to calculate the coverage
-for quantile in quantiles_conf.items():
+iter_conf = it.islice(quantiles_conf.items(), 0, None)
+for quantile in iter_conf:
     percentile = quantile[0]
 
     filt = confs_df[confs_df['confidence'] >= quantile[1]]
     accuracy = (
         filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
     )
+    f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
 
-    accuracies_conf.append({'percentile': percentile, 'accuracy': accuracy})
+    accuracies_conf.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})
 
 accuracies_df = pd.DataFrame(accuracies_conf)
 
@@ -188,21 +207,38 @@ plt.xlabel('Percentile')
 plt.ylabel('Accuracy')
 plt.title('Coverage conf')
 plt.legend()
-plt.savefig(f'{V2_PATH}/coverage.png')
+plt.savefig(f'{V2_PATH}/coverage_conf.png')
 plt.close()
 
+# Plot coverage vs F1 for confidence
+plt.plot(accuracies_df['percentile'], accuracies_df['f1'], label='Ensemble')
+plt.plot(
+    accuracies_df['percentile'],
+    [f1_indv] * len(accuracies_df['percentile']),
+    label='Individual',
+    linestyle='--',
+)
+plt.xlabel('Percentile')
+plt.ylabel('F1')
+plt.title('Coverage F1')
+plt.legend()
+plt.savefig(f'{V2_PATH}/coverage_f1_conf.png')
+plt.close()
+
+
 # Repeat for standard deviation
 accuracies_stdev = []
-
-for quantile in quantiles_stdev.items():
+iter_stdev = it.islice(quantiles_stdev.items(), 0, None)
+for quantile in iter_stdev:
     percentile = quantile[0]
 
     filt = stdevs_df[stdevs_df['stdev'] <= quantile[1]]
     accuracy = (
         filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
     )
+    f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
 
-    accuracies_stdev.append({'percentile': percentile, 'accuracy': accuracy})
+    accuracies_stdev.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})
 
 accuracies_stdev_df = pd.DataFrame(accuracies_stdev)
 
@@ -223,3 +259,33 @@ plt.legend()
 plt.gca().invert_xaxis()
 plt.savefig(f'{V2_PATH}/coverage_stdev.png')
 plt.close()
+
+# Plot coverage vs F1 for standard deviation
+plt.plot(accuracies_stdev_df['percentile'], accuracies_stdev_df['f1'], label='Ensemble')
+plt.plot(
+    accuracies_stdev_df['percentile'],
+    [f1_indv] * len(accuracies_stdev_df['percentile']),
+    label='Individual',
+    linestyle='--',
+)
+plt.xlabel('Percentile')
+plt.ylabel('F1')
+plt.title('Coverage F1 Stdev')
+plt.legend()
+plt.gca().invert_xaxis()
+
+plt.savefig(f'{V2_PATH}/coverage_f1_stdev.png')
+
+plt.close()
+
+
+# Print overall accuracy
+overall_accuracy = (
+    confs_df[confs_df['predicted_class'] == confs_df['true_label']].shape[0]
+    / confs_df.shape[0]
+)
+overall_f1 = met.F1(
+    confs_df['predicted_class'].to_numpy(), confs_df['true_label'].to_numpy()
+)
+
+print(f'Overall accuracy: {overall_accuracy}, Overall F1: {overall_f1}')
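
For reference, the coverage plots added in this commit follow the usual selective-prediction recipe: threshold the test set at each confidence (or standard-deviation) decile, keep only the retained samples, and recompute accuracy and F1 on that subset. A minimal standalone sketch of that computation, assuming dataframes with the same predicted_class / true_label / confidence / stdev columns as above, and using sklearn.metrics.f1_score in place of the repo's met.F1 helper:

import numpy as np
import pandas as pd
from sklearn.metrics import f1_score

def coverage_curve(df, score_col, keep_high=True, n_bins=10):
    # keep_high=True keeps rows with score >= threshold (confidence);
    # keep_high=False keeps rows with score <= threshold (standard deviation).
    rows = []
    for q in np.linspace(0, 1, n_bins + 1):
        # 'lower' mirrors the quantile interpolation used in threshold.py above
        thresh = df[score_col].quantile(q, interpolation='lower')
        kept = df[df[score_col] >= thresh] if keep_high else df[df[score_col] <= thresh]
        rows.append({
            'percentile': q,
            'coverage': len(kept) / len(df),
            'accuracy': (kept['predicted_class'] == kept['true_label']).mean(),
            'f1': f1_score(kept['true_label'], kept['predicted_class']),
        })
    return pd.DataFrame(rows)

# Hypothetical usage, mirroring the two curves plotted above:
# conf_curve = coverage_curve(confs_df, 'confidence', keep_high=True)
# stdev_curve = coverage_curve(stdevs_df, 'stdev', keep_high=False)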