
Implemented entropy calculation

Nicholas Schense 2 months ago
parent
commit f18dfa3d01
1 file changed with 156 additions and 29 deletions
threshold.py

@@ -10,8 +10,9 @@ import sklearn.metrics as metrics
 from tqdm import tqdm
 import utils.metrics as met
 import itertools as it
+import matplotlib.ticker as ticker
 
-RUN = False
+RUN = True
 
 # CONFIGURATION
 if os.getenv('ADL_CONFIG_PATH') is None:
@@ -46,6 +47,7 @@ def get_predictions(config):
     test_set = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
         f'{ENSEMBLE_PATH}/val_dataset.pt'
     )
+    print(f'Loaded {len(test_set)} samples')
 
     # [([model results], labels)]
     results = []
@@ -87,12 +89,18 @@ def get_predictions(config):
     # [(ensemble predicted class, ensemble standard deviation, true label)]
     stdevs = []
 
+    # [(ensemble predicted class, ensemble entropy, true label)]
+    entropies = []
+
     for result in results:
         model_results, true_label = result
         # Get the ensemble mean and variance with numpy, as these are lists
         mean = np.mean(model_results, axis=0)
         variance = np.var(model_results, axis=0)
 
+        # Calculate the Shannon entropy of the ensemble-mean prediction
+        entropy = -1 * np.sum(mean * np.log(mean))
+
         # Calculate confidence and standard deviation
         confidence = (np.max(mean) - 0.5) * 2
         stdev = np.sqrt(variance)
@@ -112,12 +120,13 @@ def get_predictions(config):
 
         confidences.append((predicted_class, confidence, true_label, class_1, class_2))
         stdevs.append((predicted_class, pc_stdev, true_label, class_1, class_2))
+        entropies.append((predicted_class, entropy, true_label, class_1, class_2))
 
-    return results, confidences, stdevs, indv_results
+    return results, confidences, stdevs, entropies, indv_results
 
 
 if RUN:
-    results, confs, stdevs, indv_results = get_predictions(config)
+    results, confs, stdevs, entropies, indv_results = get_predictions(config)
     # Convert to pandas dataframes
     confs_df = pd.DataFrame(
         confs,
@@ -127,6 +136,11 @@ if RUN:
         stdevs, columns=['predicted_class', 'stdev', 'true_label', 'class_1', 'class_2']
     )
 
+    entropies_df = pd.DataFrame(
+        entropies,
+        columns=['predicted_class', 'entropy', 'true_label', 'class_1', 'class_2'],
+    )
+
     indv_df = pd.DataFrame(indv_results, columns=['class_1', 'class_2', 'true_label'])
 
     if not os.path.exists(V2_PATH):
@@ -134,10 +148,12 @@ if RUN:
 
     confs_df.to_csv(f'{V2_PATH}/ensemble_confidences.csv')
     stdevs_df.to_csv(f'{V2_PATH}/ensemble_stdevs.csv')
+    entropies_df.to_csv(f'{V2_PATH}/ensemble_entropies.csv')
     indv_df.to_csv(f'{V2_PATH}/individual_results.csv')
 else:
     confs_df = pd.read_csv(f'{V2_PATH}/ensemble_confidences.csv')
     stdevs_df = pd.read_csv(f'{V2_PATH}/ensemble_stdevs.csv')
+    entropies_df = pd.read_csv(f'{V2_PATH}/ensemble_entropies.csv')
     indv_df = pd.read_csv(f'{V2_PATH}/individual_results.csv')
 
 # Plot confidence vs standard deviation, and change color of dots based on if they are correct
@@ -147,12 +163,25 @@ incorrect_conf = confs_df[confs_df['predicted_class'] != confs_df['true_label']]
 correct_stdev = stdevs_df[stdevs_df['predicted_class'] == stdevs_df['true_label']]
 incorrect_stdev = stdevs_df[stdevs_df['predicted_class'] != stdevs_df['true_label']]
 
-plt.scatter(correct_conf['confidence'], correct_stdev['stdev'], color='green')
-plt.scatter(incorrect_conf['confidence'], incorrect_stdev['stdev'], color='red')
-plt.xlabel('Confidence')
-plt.ylabel('Standard Deviation')
+fig, ax = plt.subplots()
+plt.scatter(
+    correct_conf['confidence'],
+    correct_stdev['stdev'],
+    color='green',
+    label='Correct Prediction',
+)
+plt.scatter(
+    incorrect_conf['confidence'],
+    incorrect_stdev['stdev'],
+    color='red',
+    label='Incorrect Prediction',
+)
+plt.xlabel('Confidence (Raw Value)')
+plt.ylabel('Standard Deviation (Raw Value)')
 plt.title('Confidence vs Standard Deviation')
+plt.legend()
 plt.savefig(f'{V2_PATH}/confidence_vs_stdev.png')
+
 plt.close()
 
 
@@ -196,32 +225,36 @@ for quantile in iter_conf:
 accuracies_df = pd.DataFrame(accuracies_conf)
 
 # Plot the coverage
-plt.plot(accuracies_df['percentile'], accuracies_df['accuracy'], label='Ensemble')
+fig, ax = plt.subplots()
+plt.plot(accuracies_df['percentile'], accuracies_df['accuracy'], 'ob', label='Ensemble')
 plt.plot(
     accuracies_df['percentile'],
     [accuracy_indv] * len(accuracies_df['percentile']),
-    label='Individual',
-    linestyle='--',
+    'xr',
+    label='Individual (on entire dataset)',
 )
-plt.xlabel('Percentile')
+plt.xlabel('Minimum Confidence Percentile (Low to High)')
 plt.ylabel('Accuracy')
-plt.title('Coverage conf')
+plt.title('Confidence Accuracy Coverage Plot')
 plt.legend()
+ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
 plt.savefig(f'{V2_PATH}/coverage_conf.png')
 plt.close()
 
 # Plot coverage vs F1 for confidence
-plt.plot(accuracies_df['percentile'], accuracies_df['f1'], label='Ensemble')
+fig, ax = plt.subplots()
+plt.plot(accuracies_df['percentile'], accuracies_df['f1'], 'ob', label='Ensemble')
 plt.plot(
     accuracies_df['percentile'],
     [f1_indv] * len(accuracies_df['percentile']),
-    label='Individual',
-    linestyle='--',
+    'xr',
+    label='Individual (on entire dataset)',
 )
-plt.xlabel('Percentile')
+plt.xlabel('Minimum Confidence Percentile (Low to High)')
 plt.ylabel('F1')
-plt.title('Coverage F1')
+plt.title('Confidence F1 Coverage Plot')
 plt.legend()
+ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
 plt.savefig(f'{V2_PATH}/coverage_f1_conf.png')
 plt.close()
 
@@ -243,37 +276,45 @@ for quantile in iter_stdev:
 accuracies_stdev_df = pd.DataFrame(accuracies_stdev)
 
 # Plot the coverage
+fig, ax = plt.subplots()
 plt.plot(
-    accuracies_stdev_df['percentile'], accuracies_stdev_df['accuracy'], label='Ensemble'
+    accuracies_stdev_df['percentile'],
+    accuracies_stdev_df['accuracy'],
+    'ob',
+    label='Ensemble',
 )
 plt.plot(
     accuracies_stdev_df['percentile'],
     [accuracy_indv] * len(accuracies_stdev_df['percentile']),
-    label='Individual',
-    linestyle='--',
+    'xr',
+    label='Individual (on entire dataset)',
 )
-plt.xlabel('Percentile')
+plt.xlabel('Maximum Standard Deviation Percentile (High to Low)')
 plt.ylabel('Accuracy')
-plt.title('Coverage Stdev')
+plt.title('Standard Deviation Accuracy Coverage Plot')
 plt.legend()
 plt.gca().invert_xaxis()
+ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
 plt.savefig(f'{V2_PATH}/coverage_stdev.png')
 plt.close()
 
 # Plot coverage vs F1 for standard deviation
-plt.plot(accuracies_stdev_df['percentile'], accuracies_stdev_df['f1'], label='Ensemble')
+fig, ax = plt.subplots()
+plt.plot(
+    accuracies_stdev_df['percentile'], accuracies_stdev_df['f1'], 'ob', label='Ensemble'
+)
 plt.plot(
     accuracies_stdev_df['percentile'],
     [f1_indv] * len(accuracies_stdev_df['percentile']),
-    label='Individual',
-    linestyle='--',
+    'xr',
+    label='Individual (on entire dataset)',
 )
-plt.xlabel('Percentile')
+plt.xlabel('Maximum Standard Deviation Percentile (High to Low)')
 plt.ylabel('F1')
-plt.title('Coverage F1 Stdev')
+plt.title('Standard Deviation F1 Coverage Plot')
 plt.legend()
 plt.gca().invert_xaxis()
-
+ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
 plt.savefig(f'{V2_PATH}/coverage_f1_stdev.png')
 
 plt.close()
@@ -287,5 +328,91 @@ overall_accuracy = (
 overall_f1 = met.F1(
     confs_df['predicted_class'].to_numpy(), confs_df['true_label'].to_numpy()
 )
+# Calculate ECE and MCE
+conf_ece = met.ECE(
+    confs_df['predicted_class'].to_numpy(),
+    confs_df['confidence'].to_numpy(),
+    confs_df['true_label'].to_numpy(),
+)
+
+stdev_ece = met.ECE(
+    stdevs_df['predicted_class'].to_numpy(),
+    stdevs_df['stdev'].to_numpy(),
+    stdevs_df['true_label'].to_numpy(),
+)
+
+
+print(f'Overall accuracy: {overall_accuracy}, Overall F1: {overall_f1}')
+print(f'Confidence ECE: {conf_ece}')
+print(f'Standard Deviation ECE: {stdev_ece}')
+
+
+# Repeat for entropy
+quantiles_entropy = entropies_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
+    'entropy'
+]
+
+accuracies_entropy = []
+iter_entropy = it.islice(quantiles_entropy.items(), 0, None)
+for quantile in iter_entropy:
+    percentile = quantile[0]
+
+    filt = entropies_df[entropies_df['entropy'] <= quantile[1]]
+    accuracy = (
+        filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
+    )
+    f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
+
+    accuracies_entropy.append(
+        {'percentile': percentile, 'accuracy': accuracy, 'f1': f1}
+    )
+
+accuracies_entropy_df = pd.DataFrame(accuracies_entropy)
+
+# Plot the coverage
+fig, ax = plt.subplots()
+plt.plot(
+    accuracies_entropy_df['percentile'],
+    accuracies_entropy_df['accuracy'],
+    'ob',
+    label='Ensemble',
+)
+plt.plot(
+    accuracies_entropy_df['percentile'],
+    [accuracy_indv] * len(accuracies_entropy_df['percentile']),
+    'xr',
+    label='Individual (on entire dataset)',
+)
+plt.xlabel('Maximum Entropy Percentile (High to Low)')
+plt.ylabel('Accuracy')
+plt.title('Entropy Accuracy Coverage Plot')
+plt.legend()
+plt.gca().invert_xaxis()
+ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
+plt.savefig(f'{V2_PATH}/coverage_entropy.png')
+plt.close()
+
+# Plot coverage vs F1 for entropy
+fig, ax = plt.subplots()
+plt.plot(
+    accuracies_entropy_df['percentile'],
+    accuracies_entropy_df['f1'],
+    'ob',
+    label='Ensemble',
+)
+plt.plot(
+    accuracies_entropy_df['percentile'],
+    [f1_indv] * len(accuracies_entropy_df['percentile']),
+    'xr',
+    label='Individual (on entire dataset)',
+)
+plt.xlabel('Maximum Entropy Percentile (High to Low)')
+plt.ylabel('F1')
+plt.title('Entropy F1 Coverage Plot')
+plt.legend()
+plt.gca().invert_xaxis()
+ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
+plt.savefig(f'{V2_PATH}/coverage_f1_entropy.png')
+
+plt.close()
 
-print(f'Overall accuracy: {overall_accuracy}, Overall F1: {overall_f1}')
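
A note on the entropy introduced by this commit: it is the Shannon entropy of the ensemble-mean class distribution, H(p) = -sum(p * log(p)). As written, `np.log(mean)` produces -inf (and the product NaN) whenever a class probability is exactly zero, i.e. when the ensemble is unanimous. A minimal stand-alone sketch with an epsilon guard; the `predictive_entropy` name and the `eps` default are illustrative, not part of the commit:

import numpy as np

def predictive_entropy(mean_probs, eps=1e-12):
    # Shannon entropy of the ensemble-mean class distribution.
    # eps guards against log(0) when a class probability is exactly zero;
    # clipping leaves well-behaved inputs effectively unchanged.
    p = np.clip(mean_probs, eps, 1.0)
    return -np.sum(p * np.log(p))

Substituting this for the in-loop expression would keep the entropy CSVs NaN-free when every ensemble member agrees completely.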
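The new calibration numbers come from `met.ECE`, whose implementation lives in the project's `utils.metrics` and is not shown in this diff. The following is therefore only a sketch of the textbook binned Expected Calibration Error, assuming the same `(predictions, confidences, labels)` argument order; the `binned_ece` name and `n_bins` default are assumptions. It is also worth noting that `stdev_ece` feeds raw standard deviations into the same function, which only yields a meaningful ECE if the implementation tolerates scores that are not probabilities:

import numpy as np

def binned_ece(pred, conf, true, n_bins=10):
    # Expected Calibration Error: partition predictions into equal-width
    # confidence bins and take the size-weighted mean |accuracy - confidence|.
    # Assumes conf lies in [0, 1], as the commit's (max(mean) - 0.5) * 2 does
    # for the two-class case.
    pred, conf, true = map(np.asarray, (pred, conf, true))
    idx = np.minimum((conf * n_bins).astype(int), n_bins - 1)
    total = len(conf)
    error = 0.0
    for b in range(n_bins):
        mask = idx == b
        if not mask.any():
            continue
        bin_acc = np.mean(pred[mask] == true[mask])
        bin_conf = np.mean(conf[mask])
        error += (mask.sum() / total) * abs(bin_acc - bin_conf)
    return error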
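Finally, the confidence, standard-deviation, and entropy sections each repeat the same quantile / filter / accuracy / F1 loop. A possible consolidation, sketched here rather than taken from the commit (the `coverage_curve` name, the `keep_low` flag, and the `>=` direction for confidence are assumptions extrapolated from the visible entropy branch):

import numpy as np
import pandas as pd
import utils.metrics as met

def coverage_curve(df, metric, keep_low=True, n_quantiles=11):
    # Accuracy and F1 as the threshold on `metric` sweeps its quantiles.
    # keep_low=True retains rows at or below each cut (stdev, entropy);
    # keep_low=False retains rows at or above it (confidence).
    cuts = df[metric].quantile(np.linspace(0, 1, n_quantiles), interpolation='lower')
    rows = []
    for pct, cut in cuts.items():
        filt = df[df[metric] <= cut] if keep_low else df[df[metric] >= cut]
        acc = (filt['predicted_class'] == filt['true_label']).mean()
        f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
        rows.append({'percentile': pct, 'accuracy': acc, 'f1': f1})
    return pd.DataFrame(rows)

Under those assumptions, the entropy block would reduce to `accuracies_entropy_df = coverage_curve(entropies_df, 'entropy')`, with only the plotting left inline.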