|
@@ -10,8 +10,9 @@ import sklearn.metrics as metrics
|
|
|
from tqdm import tqdm
|
|
|
import utils.metrics as met
|
|
|
import itertools as it
|
|
|
+import matplotlib.ticker as ticker
|
|
|
|
|
|
-RUN = False
|
|
|
+RUN = True
|
|
|
|
|
|
# CONFIGURATION
|
|
|
if os.getenv('ADL_CONFIG_PATH') is None:
|
|
@@ -46,6 +47,7 @@ def get_predictions(config):
|
|
|
test_set = torch.load(f'{ENSEMBLE_PATH}/test_dataset.pt') + torch.load(
|
|
|
f'{ENSEMBLE_PATH}/val_dataset.pt'
|
|
|
)
|
|
|
+ print(f'Loaded {len(test_set)} samples')
|
|
|
|
|
|
# [([model results], labels)]
|
|
|
results = []
|
|
@@ -87,12 +89,18 @@ def get_predictions(config):
|
|
|
# [(ensemble predicted class, ensemble standard deviation, true label)]
|
|
|
stdevs = []
|
|
|
|
|
|
+ # [(ensemble predicted class, ensemble entropy, true label)]
|
|
|
+ entropies = []
|
|
|
+
|
|
|
for result in results:
|
|
|
model_results, true_label = result
|
|
|
# Get the ensemble mean and variance with numpy, as these are lists
|
|
|
mean = np.mean(model_results, axis=0)
|
|
|
variance = np.var(model_results, axis=0)
|
|
|
|
|
|
+ # Calculate the entropy
|
|
|
+ entropy = -1 * np.sum(mean * np.log(mean))
|
|
|
+
|
|
|
# Calculate confidence and standard deviation
|
|
|
confidence = (np.max(mean) - 0.5) * 2
|
|
|
stdev = np.sqrt(variance)
|
|
@@ -112,12 +120,13 @@ def get_predictions(config):
|
|
|
|
|
|
confidences.append((predicted_class, confidence, true_label, class_1, class_2))
|
|
|
stdevs.append((predicted_class, pc_stdev, true_label, class_1, class_2))
|
|
|
+ entropies.append((predicted_class, entropy, true_label, class_1, class_2))
|
|
|
|
|
|
- return results, confidences, stdevs, indv_results
|
|
|
+ return results, confidences, stdevs, entropies, indv_results
|
|
|
|
|
|
|
|
|
if RUN:
|
|
|
- results, confs, stdevs, indv_results = get_predictions(config)
|
|
|
+ results, confs, stdevs, entropies, indv_results = get_predictions(config)
|
|
|
# Convert to pandas dataframes
|
|
|
confs_df = pd.DataFrame(
|
|
|
confs,
|
|
@@ -127,6 +136,11 @@ if RUN:
|
|
|
stdevs, columns=['predicted_class', 'stdev', 'true_label', 'class_1', 'class_2']
|
|
|
)
|
|
|
|
|
|
+ entropies_df = pd.DataFrame(
|
|
|
+ entropies,
|
|
|
+ columns=['predicted_class', 'entropy', 'true_label', 'class_1', 'class_2'],
|
|
|
+ )
|
|
|
+
|
|
|
indv_df = pd.DataFrame(indv_results, columns=['class_1', 'class_2', 'true_label'])
|
|
|
|
|
|
if not os.path.exists(V2_PATH):
|
|
@@ -134,10 +148,12 @@ if RUN:
|
|
|
|
|
|
confs_df.to_csv(f'{V2_PATH}/ensemble_confidences.csv')
|
|
|
stdevs_df.to_csv(f'{V2_PATH}/ensemble_stdevs.csv')
|
|
|
+ entropies_df.to_csv(f'{V2_PATH}/ensemble_entropies.csv')
|
|
|
indv_df.to_csv(f'{V2_PATH}/individual_results.csv')
|
|
|
else:
|
|
|
confs_df = pd.read_csv(f'{V2_PATH}/ensemble_confidences.csv')
|
|
|
stdevs_df = pd.read_csv(f'{V2_PATH}/ensemble_stdevs.csv')
|
|
|
+ entropies_df = pd.read_csv(f'{V2_PATH}/ensemble_entropies.csv')
|
|
|
indv_df = pd.read_csv(f'{V2_PATH}/individual_results.csv')
|
|
|
|
|
|
# Plot confidence vs standard deviation, and change color of dots based on if they are correct
|
|
@@ -147,12 +163,25 @@ incorrect_conf = confs_df[confs_df['predicted_class'] != confs_df['true_label']]
|
|
|
correct_stdev = stdevs_df[stdevs_df['predicted_class'] == stdevs_df['true_label']]
|
|
|
incorrect_stdev = stdevs_df[stdevs_df['predicted_class'] != stdevs_df['true_label']]
|
|
|
|
|
|
-plt.scatter(correct_conf['confidence'], correct_stdev['stdev'], color='green')
|
|
|
-plt.scatter(incorrect_conf['confidence'], incorrect_stdev['stdev'], color='red')
|
|
|
-plt.xlabel('Confidence')
|
|
|
-plt.ylabel('Standard Deviation')
|
|
|
+plot, ax = plt.subplots()
|
|
|
+plt.scatter(
|
|
|
+ correct_conf['confidence'],
|
|
|
+ correct_stdev['stdev'],
|
|
|
+ color='green',
|
|
|
+ label='Correct Prediction',
|
|
|
+)
|
|
|
+plt.scatter(
|
|
|
+ incorrect_conf['confidence'],
|
|
|
+ incorrect_stdev['stdev'],
|
|
|
+ color='red',
|
|
|
+ label='Incorrect Prediction',
|
|
|
+)
|
|
|
+plt.xlabel('Confidence (Raw Value)')
|
|
|
+plt.ylabel('Standard Deviation (Raw Value)')
|
|
|
plt.title('Confidence vs Standard Deviation')
|
|
|
+plt.legend()
|
|
|
plt.savefig(f'{V2_PATH}/confidence_vs_stdev.png')
|
|
|
+
|
|
|
plt.close()
|
|
|
|
|
|
|
|
@@ -196,32 +225,36 @@ for quantile in iter_conf:
|
|
|
accuracies_df = pd.DataFrame(accuracies_conf)
|
|
|
|
|
|
# Plot the coverage
|
|
|
-plt.plot(accuracies_df['percentile'], accuracies_df['accuracy'], label='Ensemble')
|
|
|
+fig, ax = plt.subplots()
|
|
|
+plt.plot(accuracies_df['percentile'], accuracies_df['accuracy'], 'ob', label='Ensemble')
|
|
|
plt.plot(
|
|
|
accuracies_df['percentile'],
|
|
|
[accuracy_indv] * len(accuracies_df['percentile']),
|
|
|
- label='Individual',
|
|
|
- linestyle='--',
|
|
|
+ 'xr',
|
|
|
+ label='Individual (on entire dataset)',
|
|
|
)
|
|
|
-plt.xlabel('Percentile')
|
|
|
+plt.xlabel('Minimum Confidence Percentile (Low to High)')
|
|
|
plt.ylabel('Accuracy')
|
|
|
-plt.title('Coverage conf')
|
|
|
+plt.title('Confidence Accuracy Coverage Plot')
|
|
|
plt.legend()
|
|
|
+ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
|
|
|
plt.savefig(f'{V2_PATH}/coverage_conf.png')
|
|
|
plt.close()
|
|
|
|
|
|
# Plot coverage vs F1 for confidence
|
|
|
-plt.plot(accuracies_df['percentile'], accuracies_df['f1'], label='Ensemble')
|
|
|
+fig, ax = plt.subplots()
|
|
|
+plt.plot(accuracies_df['percentile'], accuracies_df['f1'], 'ob', label='Ensemble')
|
|
|
plt.plot(
|
|
|
accuracies_df['percentile'],
|
|
|
[f1_indv] * len(accuracies_df['percentile']),
|
|
|
- label='Individual',
|
|
|
- linestyle='--',
|
|
|
+ 'xr',
|
|
|
+ label='Individual (on entire dataset)',
|
|
|
)
|
|
|
-plt.xlabel('Percentile')
|
|
|
+plt.xlabel('Minimum Confidence Percentile (Low to High)')
|
|
|
plt.ylabel('F1')
|
|
|
-plt.title('Coverage F1')
|
|
|
+plt.title('Confidence F1 Coverage Plot')
|
|
|
plt.legend()
|
|
|
+ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
|
|
|
plt.savefig(f'{V2_PATH}/coverage_f1_conf.png')
|
|
|
plt.close()
|
|
|
|
|
@@ -243,37 +276,45 @@ for quantile in iter_stdev:
|
|
|
accuracies_stdev_df = pd.DataFrame(accuracies_stdev)
|
|
|
|
|
|
# Plot the coverage
|
|
|
+fig, ax = plt.subplots()
|
|
|
plt.plot(
|
|
|
- accuracies_stdev_df['percentile'], accuracies_stdev_df['accuracy'], label='Ensemble'
|
|
|
+ accuracies_stdev_df['percentile'],
|
|
|
+ accuracies_stdev_df['accuracy'],
|
|
|
+ 'ob',
|
|
|
+ label='Ensemble',
|
|
|
)
|
|
|
plt.plot(
|
|
|
accuracies_stdev_df['percentile'],
|
|
|
[accuracy_indv] * len(accuracies_stdev_df['percentile']),
|
|
|
- label='Individual',
|
|
|
- linestyle='--',
|
|
|
+ 'xr',
|
|
|
+ label='Individual (on entire dataset)',
|
|
|
)
|
|
|
-plt.xlabel('Percentile')
|
|
|
+plt.xlabel('Maximum Standard Deviation Percentile (High to Low)')
|
|
|
plt.ylabel('Accuracy')
|
|
|
-plt.title('Coverage Stdev')
|
|
|
+plt.title('Standard Deviation Accuracy Coverage Plot')
|
|
|
plt.legend()
|
|
|
plt.gca().invert_xaxis()
|
|
|
+ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
|
|
|
plt.savefig(f'{V2_PATH}/coverage_stdev.png')
|
|
|
plt.close()
|
|
|
|
|
|
# Plot coverage vs F1 for standard deviation
|
|
|
-plt.plot(accuracies_stdev_df['percentile'], accuracies_stdev_df['f1'], label='Ensemble')
|
|
|
+fig, ax = plt.subplots()
|
|
|
+plt.plot(
|
|
|
+ accuracies_stdev_df['percentile'], accuracies_stdev_df['f1'], 'ob', label='Ensemble'
|
|
|
+)
|
|
|
plt.plot(
|
|
|
accuracies_stdev_df['percentile'],
|
|
|
[f1_indv] * len(accuracies_stdev_df['percentile']),
|
|
|
- label='Individual',
|
|
|
- linestyle='--',
|
|
|
+ 'xr',
|
|
|
+ label='Individual (on entire dataset)',
|
|
|
)
|
|
|
-plt.xlabel('Percentile')
|
|
|
+plt.xlabel('Maximum Standard Deviation Percentile (High to Low)')
|
|
|
plt.ylabel('F1')
|
|
|
-plt.title('Coverage F1 Stdev')
|
|
|
+plt.title('Standard Deviation F1 Coverage Plot')
|
|
|
plt.legend()
|
|
|
plt.gca().invert_xaxis()
|
|
|
-
|
|
|
+ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
|
|
|
plt.savefig(f'{V2_PATH}/coverage_f1_stdev.png')
|
|
|
|
|
|
plt.close()
|
|
@@ -287,5 +328,91 @@ overall_accuracy = (
|
|
|
overall_f1 = met.F1(
|
|
|
confs_df['predicted_class'].to_numpy(), confs_df['true_label'].to_numpy()
|
|
|
)
|
|
|
+# Calculate ECE and MCE
|
|
|
+conf_ece = met.ECE(
|
|
|
+ confs_df['predicted_class'].to_numpy(),
|
|
|
+ confs_df['confidence'].to_numpy(),
|
|
|
+ confs_df['true_label'].to_numpy(),
|
|
|
+)
|
|
|
+
|
|
|
+stdev_ece = met.ECE(
|
|
|
+ stdevs_df['predicted_class'].to_numpy(),
|
|
|
+ stdevs_df['stdev'].to_numpy(),
|
|
|
+ stdevs_df['true_label'].to_numpy(),
|
|
|
+)
|
|
|
+
|
|
|
+
|
|
|
+print(f'Overall accuracy: {overall_accuracy}, Overall F1: {overall_f1},')
|
|
|
+print(f'Confidence ECE: {conf_ece}')
|
|
|
+print(f'Standard Deviation ECE: {stdev_ece}')
|
|
|
+
|
|
|
+
|
|
|
+# Repeat for entropy
|
|
|
+quantiles_entropy = entropies_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
|
|
|
+ 'entropy'
|
|
|
+]
|
|
|
+
|
|
|
+accuracies_entropy = []
|
|
|
+iter_entropy = it.islice(quantiles_entropy.items(), 0, None)
|
|
|
+for quantile in iter_entropy:
|
|
|
+ percentile = quantile[0]
|
|
|
+
|
|
|
+ filt = entropies_df[entropies_df['entropy'] <= quantile[1]]
|
|
|
+ accuracy = (
|
|
|
+ filt[filt['predicted_class'] == filt['true_label']].shape[0] / filt.shape[0]
|
|
|
+ )
|
|
|
+ f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
|
|
|
+
|
|
|
+ accuracies_entropy.append(
|
|
|
+ {'percentile': percentile, 'accuracy': accuracy, 'f1': f1}
|
|
|
+ )
|
|
|
+
|
|
|
+accuracies_entropy_df = pd.DataFrame(accuracies_entropy)
|
|
|
+
|
|
|
+# Plot the coverage
|
|
|
+fig, ax = plt.subplots()
|
|
|
+plt.plot(
|
|
|
+ accuracies_entropy_df['percentile'],
|
|
|
+ accuracies_entropy_df['accuracy'],
|
|
|
+ 'ob',
|
|
|
+ label='Ensemble',
|
|
|
+)
|
|
|
+plt.plot(
|
|
|
+ accuracies_entropy_df['percentile'],
|
|
|
+ [accuracy_indv] * len(accuracies_entropy_df['percentile']),
|
|
|
+ 'xr',
|
|
|
+ label='Individual (on entire dataset)',
|
|
|
+)
|
|
|
+plt.xlabel('Maximum Entropy Percentile (High to Low)')
|
|
|
+plt.ylabel('Accuracy')
|
|
|
+plt.title('Entropy Accuracy Coverage Plot')
|
|
|
+plt.legend()
|
|
|
+plt.gca().invert_xaxis()
|
|
|
+ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
|
|
|
+plt.savefig(f'{V2_PATH}/coverage_entropy.png')
|
|
|
+plt.close()
|
|
|
+
|
|
|
+# Plot coverage vs F1 for entropy
|
|
|
+fig, ax = plt.subplots()
|
|
|
+plt.plot(
|
|
|
+ accuracies_entropy_df['percentile'],
|
|
|
+ accuracies_entropy_df['f1'],
|
|
|
+ 'ob',
|
|
|
+ label='Ensemble',
|
|
|
+)
|
|
|
+plt.plot(
|
|
|
+ accuracies_entropy_df['percentile'],
|
|
|
+ [f1_indv] * len(accuracies_entropy_df['percentile']),
|
|
|
+ 'xr',
|
|
|
+ label='Individual (on entire dataset)',
|
|
|
+)
|
|
|
+plt.xlabel('Maximum Entropy Percentile (High to Low)')
|
|
|
+plt.ylabel('F1')
|
|
|
+plt.title('Entropy F1 Coverage Plot')
|
|
|
+plt.legend()
|
|
|
+plt.gca().invert_xaxis()
|
|
|
+ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
|
|
|
+plt.savefig(f'{V2_PATH}/coverage_f1_entropy.png')
|
|
|
+
|
|
|
+plt.close()
|
|
|
|
|
|
-print(f'Overall accuracy: {overall_accuracy}, Overall F1: {overall_f1}')
|