Przeglądaj źródła

Begin threshold refactor

Nicholas Schense 4 miesięcy temu
rodzic
commit
7775cf4b28
2 zmienionych plików z 158 dodań i 76 usunięć
  1. 14 3
      planning.md
  2. 144 73
      threshold.py

+ 14 - 3
planning.md

@@ -3,9 +3,9 @@
 As of now, we have a program set up to be able to:
 
 - train an individual model with specific hyperparameters
-- train a ensemble of models with the identical hyperparameters 
-- evaluate the accuracy of an ensemble of models 
-- perform a coverage analysis on an ensemble of models 
+- train a ensemble of models with the identical hyperparameters
+- evaluate the accuracy of an ensemble of models
+- perform a coverage analysis on an ensemble of models
 
 The goal of this rewrite is to preserve those functions while making the program significantly cleaner and easier to use, and to make it easier to extend with new functionality in the future as well. The hope is for this project to take approximately ~1-2 days, and be completed by Monday (6/17). The additional features that I would like to implement are:
 
@@ -16,3 +16,14 @@ The goal of this rewrite is to preserve those functions while making the program
 - Implementation of new metrics and ensembles
 - Deterministic dataloading (for a specified seed, the data used is set and does not change, even if the loading methods do)
 
+## Further Planning as of 7/8/24
+
+- With the implementation of uncertainty through standard deviation, confidence and entropy, next steps are
+  - Refactor current threshold implementation - very very messy and difficult to add new features
+  - Enable checking images for incorrect prediction, and predictions off of the main curve for stdev-conf curve thing
+  - Investigate physician confidence, and compare to uncertianty predictions
+  - Deep dive standard deviation
+  - Box plot?
+  - Investigate calibration - do we need it?
+  - Consider manuscript - should be thinking about writing
+  

+ 144 - 73
threshold.py

@@ -12,7 +12,43 @@ import utils.metrics as met
 import itertools as it
 import matplotlib.ticker as ticker
 
-RUN = True
+
+# Define plotting helper function
+def plot_coverage(
+    percentiles,
+    ensemble_results,
+    individual_results,
+    title,
+    x_lablel,
+    y_label,
+    save_path,
+    flip=False,
+):
+    fig, ax = plt.subplots()
+    plt.plot(
+        percentiles,
+        ensemble_results,
+        'ob',
+        label='Ensemble',
+    )
+    plt.plot(
+        percentiles,
+        individual_results,
+        'xr',
+        label='Individual (on entire dataset)',
+    )
+    plt.xlabel(x_lablel)
+    plt.ylabel(y_label)
+    plt.title(title)
+    plt.legend()
+    if flip:
+        plt.gca().invert_xaxis()
+    ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
+    plt.savefig(save_path)
+    plt.close()
+
+
+RUN = False
 
 # CONFIGURATION
 if os.getenv('ADL_CONFIG_PATH') is None:
@@ -109,7 +145,6 @@ def get_predictions(config):
         predicted_class = np.argmax(mean)
 
         # Get the confidence and standard deviation of the predicted class
-        print(stdev)
         pc_stdev = np.squeeze(stdev)[predicted_class]
         # Get the individual classes
         class_1 = mean[0][0]
@@ -156,6 +191,7 @@ else:
     entropies_df = pd.read_csv(f'{V2_PATH}/ensemble_entropies.csv')
     indv_df = pd.read_csv(f'{V2_PATH}/individual_results.csv')
 
+
 # Plot confidence vs standard deviation, and change color of dots based on if they are correct
 correct_conf = confs_df[confs_df['predicted_class'] == confs_df['true_label']]
 incorrect_conf = confs_df[confs_df['predicted_class'] != confs_df['true_label']]
@@ -163,6 +199,13 @@ incorrect_conf = confs_df[confs_df['predicted_class'] != confs_df['true_label']]
 correct_stdev = stdevs_df[stdevs_df['predicted_class'] == stdevs_df['true_label']]
 incorrect_stdev = stdevs_df[stdevs_df['predicted_class'] != stdevs_df['true_label']]
 
+correct_ent = entropies_df[
+    entropies_df['predicted_class'] == entropies_df['true_label']
+]
+incorrect_ent = entropies_df[
+    entropies_df['predicted_class'] != entropies_df['true_label']
+]
+
 plot, ax = plt.subplots()
 plt.scatter(
     correct_conf['confidence'],
@@ -184,8 +227,30 @@ plt.savefig(f'{V2_PATH}/confidence_vs_stdev.png')
 
 plt.close()
 
+# Do the same for confidence vs entropy
+plot, ax = plt.subplots()
+plt.scatter(
+    correct_conf['confidence'],
+    correct_ent['entropy'],
+    color='green',
+    label='Correct Prediction',
+)
+plt.scatter(
+    incorrect_conf['confidence'],
+    incorrect_ent['entropy'],
+    color='red',
+    label='Incorrect Prediction',
+)
+plt.xlabel('Confidence (Raw Value)')
+plt.ylabel('Entropy (Raw Value)')
+plt.title('Confidence vs Entropy')
+plt.legend()
+plt.savefig(f'{V2_PATH}/confidence_vs_entropy.png')
+
+plt.close()
+
 
-# Calculate individual model accuracy
+# Calculate individual model accuracy and entropy
 # Determine predicted class
 indv_df['predicted_class'] = indv_df[['class_1', 'class_2']].idxmax(axis=1)
 indv_df['predicted_class'] = indv_df['predicted_class'].apply(
@@ -199,6 +264,9 @@ f1_indv = met.F1(
 auc_indv = metrics.roc_auc_score(
     indv_df['true_label'].to_numpy(), indv_df['class_2'].to_numpy()
 )
+indv_df['entropy'] = -1 * indv_df[['class_1', 'class_2']].apply(
+    lambda x: x * np.log(x), axis=0
+).sum(axis=1)
 
 # Calculate percentiles for confidence and standard deviation
 quantiles_conf = confs_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
@@ -208,6 +276,18 @@ quantiles_stdev = stdevs_df.quantile(np.linspace(0, 1, 11), interpolation='lower
     'stdev'
 ]
 
+# Additionally for individual confidence
+quantiles_indv_conf = indv_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
+    'class_2'
+]
+
+# For indivual entropy
+quantiles_indv_entropy = indv_df.quantile(np.linspace(0, 1, 11), interpolation='lower')[
+    'entropy'
+]
+
+#
+
 accuracies_conf = []
 # Use the quantiles to calculate the coverage
 iter_conf = it.islice(quantiles_conf.items(), 0, None)
@@ -224,40 +304,56 @@ for quantile in iter_conf:
 
 accuracies_df = pd.DataFrame(accuracies_conf)
 
-# Plot the coverage
-fig, ax = plt.subplots()
-plt.plot(accuracies_df['percentile'], accuracies_df['accuracy'], 'ob', label='Ensemble')
-plt.plot(
+indv_conf = []
+# Use the quantiles to calculate the coverage
+iter_conf = it.islice(quantiles_indv_conf.items(), 0, None)
+for quantile in iter_conf:
+    percentile = quantile[0]
+
+    filt = indv_df[indv_df['class_2'] >= quantile[1]]
+    accuracy = filt['correct'].mean()
+    f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
+
+    indv_conf.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})
+
+indv_conf_df = pd.DataFrame(indv_conf)
+
+# Do the same for entropy
+indv_entropy = []
+iter_entropy = it.islice(quantiles_indv_entropy.items(), 0, None)
+for quantile in iter_entropy:
+    percentile = quantile[0]
+
+    filt = indv_df[indv_df['entropy'] <= quantile[1]]
+    accuracy = filt['correct'].mean()
+    f1 = met.F1(filt['predicted_class'].to_numpy(), filt['true_label'].to_numpy())
+
+    indv_entropy.append({'percentile': percentile, 'accuracy': accuracy, 'f1': f1})
+
+indv_entropy_df = pd.DataFrame(indv_entropy)
+
+
+# Plot the coverage for confidence and accuracy
+plot_coverage(
     accuracies_df['percentile'],
-    [accuracy_indv] * len(accuracies_df['percentile']),
-    'xr',
-    label='Individual (on entire dataset)',
+    accuracies_df['accuracy'],
+    indv_conf_df['accuracy'],
+    'Confidence Accuracy Coverage Plot',
+    'Minimum Confidence Percentile (Low to High)',
+    'Accuracy',
+    f'{V2_PATH}/coverage_conf.png',
 )
-plt.xlabel('Minimum Confidence Percentile (Low to High)')
-plt.ylabel('Accuracy')
-plt.title('Confidence Accuracy Coverage Plot')
-plt.legend()
-ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
-plt.savefig(f'{V2_PATH}/coverage_conf.png')
-plt.close()
 
-# Plot coverage vs F1 for confidence
-fig, ax = plt.subplots()
-plt.plot(accuracies_df['percentile'], accuracies_df['f1'], 'ob', label='Ensemble')
-plt.plot(
+# Plot the coverage for confidence and F1
+plot_coverage(
     accuracies_df['percentile'],
-    [f1_indv] * len(accuracies_df['percentile']),
-    'xr',
-    label='Individual (on entire dataset)',
+    accuracies_df['f1'],
+    indv_conf_df['f1'],
+    'Confidence F1 Coverage Plot',
+    'Minimum Confidence Percentile (Low to High)',
+    'F1',
+    f'{V2_PATH}/f1_coverage_conf.png',
 )
-plt.xlabel('Minimum Confidence Percentile (Low to High)')
-plt.ylabel('F1')
-plt.title('Confidence F1 Coverage Plot')
-plt.legend()
-ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
-plt.savefig(f'{V2_PATH}/coverage_f1_conf.png')
-plt.close()
-
 
 # Repeat for standard deviation
 accuracies_stdev = []
@@ -275,7 +371,6 @@ for quantile in iter_stdev:
 
 accuracies_stdev_df = pd.DataFrame(accuracies_stdev)
 
-# Plot the coverage
 fig, ax = plt.subplots()
 plt.plot(
     accuracies_stdev_df['percentile'],
@@ -369,50 +464,26 @@ for quantile in iter_entropy:
 
 accuracies_entropy_df = pd.DataFrame(accuracies_entropy)
 
-# Plot the coverage
-fig, ax = plt.subplots()
-plt.plot(
+
+# Plot the coverage for entropy and accuracy
+plot_coverage(
     accuracies_entropy_df['percentile'],
     accuracies_entropy_df['accuracy'],
-    'ob',
-    label='Ensemble',
+    indv_entropy_df['accuracy'],
+    'Entropy Accuracy Coverage Plot',
+    'Minimum Entropy Percentile (Low to High)',
+    'Accuracy',
+    f'{V2_PATH}/coverage_entropy.png',
 )
-plt.plot(
-    accuracies_entropy_df['percentile'],
-    [accuracy_indv] * len(accuracies_entropy_df['percentile']),
-    'xr',
-    label='Individual (on entire dataset)',
-)
-plt.xlabel('Maximum Entropy Percentile (High to Low)')
-plt.ylabel('Accuracy')
-plt.title('Entropy Accuracy Coverage Plot')
-plt.legend()
-plt.gca().invert_xaxis()
-ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
-plt.savefig(f'{V2_PATH}/coverage_entropy.png')
-plt.close()
 
-# Plot coverage vs F1 for entropy
-fig, ax = plt.subplots()
-plt.plot(
+# Plot the coverage for entropy and F1
+plot_coverage(
     accuracies_entropy_df['percentile'],
     accuracies_entropy_df['f1'],
-    'ob',
-    label='Ensemble',
+    indv_entropy_df['f1'],
+    'Entropy F1 Coverage Plot',
+    'Maximum Entropy Percentile (High to Low)',
+    'F1',
+    f'{V2_PATH}/f1_coverage_entropy.png',
+    flip=True,
 )
-plt.plot(
-    accuracies_entropy_df['percentile'],
-    [f1_indv] * len(accuracies_entropy_df['percentile']),
-    'xr',
-    label='Individual (on entire dataset)',
-)
-plt.xlabel('Maximum Entropy Percentile (High to Low)')
-plt.ylabel('F1')
-plt.title('Entropy F1 Coverage Plot')
-plt.legend()
-plt.gca().invert_xaxis()
-ax.xaxis.set_major_formatter(ticker.PercentFormatter(xmax=1.0))
-plt.savefig(f'{V2_PATH}/coverage_f1_entropy.png')
-
-plt.close()
-