
Worked on calibration (little success) and began tweaking graphs for presentation and poster

Nicholas Schense 8 months ago
commit 4549b2b349
2 changed files with 125 additions and 38 deletions
  1. calibration_xarray.py: 40 additions, 0 deletions
  2. threshold_xarray.py: 85 additions, 38 deletions

+ 40 - 0
calibration_xarray.py

@@ -0,0 +1,40 @@
+import threshold_xarray as th
+import xarray as xr
+import numpy as np
+import torch
+import os
+
+import sklearn.calibration as cal
+
+
+# The purpose of this file is to calibrate the model predictions on the test set and evaluate the calibration on the validation set.
+# We are using scikit-learn's calibration utilities to do this.
+
+if __name__ == '__main__':
+    print('Loading Config...')
+    config = th.load_config()
+    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
+    V4_PATH = ENSEMBLE_PATH + '/v4'
+
+    if not os.path.exists(V4_PATH):
+        os.makedirs(V4_PATH)
+    print('Config Loaded')
+
+    # Load the predictions
+    print('Loading Predictions...')
+    val_preds = xr.open_dataarray(f'{ENSEMBLE_PATH}/val_predictions.nc')
+    test_preds = xr.open_dataarray(f'{ENSEMBLE_PATH}/test_predictions.nc')
+    print('Predictions Loaded')
+
+    # Now the goal is to calibrate the test set, and evaluate the calibration on the validation set.
+    # We do this by binning the data into 15 bins, and then calculating the mean of the predictions in each bin.
+    # We then use this to calibrate the data.
+
+    # First, get the statistics of both sets
+    print('Calculating Statistics...')
+    val_stats = th.compute_ensemble_statistics(val_preds)
+    test_stats = th.compute_ensemble_statistics(test_preds)
+
+    # Calibrate the test set
+    print('Calibrating Test Set...')
+ 
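The new script stops just short of the calibration step itself. A minimal sketch of how it might continue, assuming compute_ensemble_statistics returns a DataArray with 'mean' (ensemble mean positive-class probability) and 'actual' (true label) entries along its 'statistic' dimension, as the helpers in threshold_xarray.py suggest; all variable names below are illustrative, not part of the commit:

    # Histogram-binning calibration: fit the per-bin observed positive frequency
    # on the test set, then remap the validation-set confidences through it.
    test_probs = test_stats.sel(statistic='mean').values
    test_labels = test_stats.sel(statistic='actual').values
    val_probs = val_stats.sel(statistic='mean').values

    bin_edges = np.linspace(0.0, 1.0, 16)  # 15 uniform bins
    test_bins = np.clip(np.digitize(test_probs, bin_edges) - 1, 0, 14)
    val_bins = np.clip(np.digitize(val_probs, bin_edges) - 1, 0, 14)

    # Observed positive frequency per bin (fall back to the bin centre if a bin is empty)
    bin_centres = (bin_edges[:-1] + bin_edges[1:]) / 2
    bin_freq = np.array([
        test_labels[test_bins == b].mean() if np.any(test_bins == b) else bin_centres[b]
        for b in range(15)
    ])

    calibrated_val_probs = bin_freq[val_bins]

scikit-learn's cal.calibration_curve(test_labels, test_probs, n_bins=15) recovers the same per-bin frequencies for non-empty bins and is the natural way to plot before/after reliability curves.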

+ 85 - 38
threshold_xarray.py

@@ -12,7 +12,6 @@ import matplotlib.pyplot as plt
 import matplotlib.ticker as mtick
 
 
-
 # The datastructures for this file are as follows
 # models_dict: Dictionary - {model_id: model}
 # predictions: DataArray - (data_id, model_id, prediction_value) - Prediction value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
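For orientation, a quick indexing sketch for the predictions DataArray described above (the coordinate names are taken from the comment; the selection syntax assumes the last dimension is named prediction_value):

    # Ensemble-wide positive-class probabilities, shape (data_id, model_id)
    pos_preds = predictions.sel(prediction_value='positive_prediction')
    # Matching one-hot positive ground truth for the same rows
    pos_actual = predictions.sel(prediction_value='positive_actual')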
@@ -66,13 +65,14 @@ def preprocess_data(data, device):
 
 # Loads datasets and returns concatenated test and validation datasets
 def load_datasets(ensemble_path):
-    return torch.load(f'{ensemble_path}/test_dataset.pt') + torch.load(
-        f'{ensemble_path}/val_dataset.pt'
+    return (
+        torch.load(f'{ensemble_path}/test_dataset.pt'),
+        torch.load(f'{ensemble_path}/val_dataset.pt'),
     )
 
 
 # Gets the predictions for a set of models on a dataset
-def get_ensemble_predictions(models, dataset, device):
+def get_ensemble_predictions(models, dataset, device, id_offset=0):
     zeros = np.zeros((len(dataset), len(models), 4))
     predictions = xr.DataArray(
         zeros,
@@ -99,9 +99,9 @@ def get_ensemble_predictions(models, dataset, device):
                 output = model(dat)
                 prediction = output.cpu().numpy().tolist()[0]
 
-                predictions.loc[{'data_id': data_id, 'model_id': model_id}] = (
-                    prediction + actual
-                )
+                predictions.loc[
+                    {'data_id': data_id + id_offset, 'model_id': model_id}
+                ] = prediction + actual
 
     return predictions
 
@@ -218,6 +218,7 @@ def compute_metric(arr, metric):
         return met.F1(
             arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}]
         )
+
     else:
         raise ValueError('Invalid metric: ' + metric)
 
@@ -252,8 +253,8 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'confidence',
         'accuracy',
         f'{save_path}/confidence_accuracy.png',
-        'Confidence vs. Accuracy',
-        'Confidence',
+        'Coverage Analysis of Confidence vs. Accuracy',
+        'Minimum Confidence Percentile Threshold',
         'Accuracy',
     )
 
@@ -263,9 +264,9 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'confidence',
         'f1',
         f'{save_path}/confidence_f1.png',
-        'Confidence vs. F1',
-        'Confidence',
-        'F1',
+        'Coverage Analysis of Confidence vs. F1 Score',
+        'Minimum Confidence Percentile Threshold',
+        'F1 Score',
     )
 
     # Entropy Accuracy
@@ -274,8 +275,8 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'entropy',
         'accuracy',
         f'{save_path}/entropy_accuracy.png',
-        'Entropy vs. Accuracy',
-        'Entropy',
+        'Coverage Analysis of Entropy vs. Accuracy',
+        'Maximum Entropy Percentile Threshold',
         'Accuracy',
     )
 
@@ -286,9 +287,9 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'entropy',
         'f1',
         f'{save_path}/entropy_f1.png',
-        'Entropy vs. F1',
-        'Entropy',
-        'F1',
+        'Coverage Analysis of Entropy vs. F1 Score',
+        'Maximum Entropy Percentile Threshold',
+        'F1 Score',
     )
 
     # Stdev Accuracy
@@ -297,8 +298,8 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'stdev',
         'accuracy',
         f'{save_path}/stdev_accuracy.png',
-        'Standard Deviation vs. Accuracy',
-        'Standard Deviation',
+        'Coverage Analysis of Standard Deviation vs. Accuracy',
+        'Maximum Standard Deviation Percentile Threshold',
         'Accuracy',
     )
 
@@ -308,8 +309,8 @@ def graph_all_thresholded_predictions(thresholded_predictions, save_path):
         'stdev',
         'f1',
         f'{save_path}/stdev_f1.png',
-        'Standard Deviation vs. F1',
-        'Standard Deviation',
+        'Coverage Analysis of Standard Deviation vs. F1 Score',
+        'Maximum Standard Deviation Percentile Threshold',
         'F1',
     )
 
@@ -505,8 +506,8 @@ def graph_all_individual_thresholded_predictions(
         'confidence',
         'accuracy',
         f'{save_path}/indv/confidence_accuracy.png',
-        'Confidence vs. Accuracy',
-        'Confidence Percentile Threshold',
+        'Coverage Analysis of Confidence vs. Accuracy for All Models',
+        'Minimum Confidence Percentile Threshold',
         'Accuracy',
     )
 
@@ -517,9 +518,9 @@ def graph_all_individual_thresholded_predictions(
         'confidence',
         'f1',
         f'{save_path}/indv/confidence_f1.png',
-        'Confidence vs. F1',
-        'Confidence Percentile Threshold',
-        'F1',
+        'Coverage Analysis of Confidence vs. F1 Score for All Models',
+        'Minimum Confidence Percentile Threshold',
+        'F1 Score',
     )
 
     # Entropy Accuracy
@@ -529,8 +530,8 @@ def graph_all_individual_thresholded_predictions(
         'entropy',
         'accuracy',
         f'{save_path}/indv/entropy_accuracy.png',
-        'Entropy vs. Accuracy',
-        'Entropy Percentile Threshold',
+        'Coverage Analysis of Entropy vs. Accuracy for All Models',
+        'Maximum Entropy Percentile Threshold',
         'Accuracy',
     )
 
@@ -541,9 +542,9 @@ def graph_all_individual_thresholded_predictions(
         'entropy',
         'f1',
         f'{save_path}/indv/entropy_f1.png',
-        'Entropy vs. F1',
-        'Entropy Percentile Threshold',
-        'F1',
+        'Coverage Analysis of Entropy vs. F1 Score for All Models',
+        'Maximum Entropy Percentile Threshold',
+        'F1 Score',
     )
 
 
@@ -646,6 +647,43 @@ def calculate_overall_stats(ensemble_statistics: xr.DataArray):
     return {'accuracy': accuracy.item(), 'f1': f1.item()}
 
 
+# https://towardsdatascience.com/expected-calibration-error-ece-a-step-by-step-visual-explanation-with-python-code-c3e9aa12937d
+def calculate_ece_stats(statistics, bins=10):
+    bin_boundaries = np.linspace(0, 1, bins + 1)
+    bin_lowers = bin_boundaries[:-1]
+    bin_uppers = bin_boundaries[1:]
+
+    confidences = ((statistics.sel(statistic='mean').values) - 0.5) * 2
+    accuracies = statistics.sel(statistic='correct').values
+
+    ece = np.zeros(1)
+    bin_accuracies = xr.DataArray(
+        np.zeros(bins), dims=('lower_bound'), coords={'lower_bound': bin_lowers}
+    )
+
+    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
+        in_bin = np.logical_and(
+            confidences > bin_lower.item(), confidences <= bin_upper.item()
+        )
+        prob_in_bin = in_bin.mean()
+
+        if prob_in_bin.item() > 0:
+            accuracy_in_bin = accuracies[in_bin].mean()
+
+            bin_accuracies.loc[{'lower_bound': bin_lower}] = accuracy_in_bin
+            avg_confidence_in_bin = confidences[in_bin].mean()
+            ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prob_in_bin
+
+    bin_accuracies.attrs['ece'] = ece
+    bin_accuracies.attrs['bin_number'] = bins
+
+    return bin_accuracies
+
+
+def plot_ece_graph(ece_stats, title, xlabel, ylabel, save_path):
+    fig, ax = plt.subplots()
+
+
 # Main Function
 def main():
     print('Loading Config...')
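calculate_ece_stats implements the standard expected calibration error, ECE = sum over bins of (n_b / N) * |accuracy(b) - confidence(b)|, returning the per-bin accuracies as a DataArray with the scalar ECE stored in its attrs; plot_ece_graph is still a stub in this commit. A sketch of how both might be used (ensemble_statistics is the variable name assumed from main() below, and the plotting body is illustrative rather than the author's final version):

    # Sketch: 10-bin ECE over the ensemble statistics computed in main()
    ece_stats = calculate_ece_stats(ensemble_statistics, bins=10)
    print('ECE:', ece_stats.attrs['ece'].item())

    # Possible reliability diagram: per-bin accuracy bars against the identity
    # line that a perfectly calibrated ensemble would follow.
    def plot_ece_graph(ece_stats, title, xlabel, ylabel, save_path):
        fig, ax = plt.subplots()
        width = 1.0 / ece_stats.attrs['bin_number']
        ax.bar(ece_stats['lower_bound'].values, ece_stats.values,
               width=width, align='edge', edgecolor='black')
        ax.plot([0, 1], [0, 1], linestyle='--', color='gray')
        ax.set(title=title, xlabel=xlabel, ylabel=ylabel, xlim=(0, 1), ylim=(0, 1))
        fig.savefig(save_path)
        plt.close(fig)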
@@ -659,7 +697,7 @@ def main():
 
     # Load Datasets
     print('Loading Datasets...')
-    dataset = load_datasets(ENSEMBLE_PATH)
+    (test_dataset, val_dataset) = load_datasets(ENSEMBLE_PATH)
     print('Datasets Loaded')
 
     # Get Predictions, either by running the models or loading them from a file
@@ -672,20 +710,29 @@ def main():
 
         # Get Predictions
         print('Getting Predictions...')
-        predictions = get_ensemble_predictions(models, dataset, device)
+        test_predictions = get_ensemble_predictions(models, test_dataset, device)
+        val_predictions = get_ensemble_predictions(
+            models, val_dataset, device, len(test_dataset)
+        )
         print('Predictions Loaded')
 
         # Save Prediction
-        predictions.to_netcdf(f'{V4_PATH}/predictions.nc')
+        test_predictions.to_netcdf(f'{V4_PATH}/test_predictions.nc')
+        val_predictions.to_netcdf(f'{V4_PATH}/val_predictions.nc')
 
     else:
-        predictions = xr.open_dataarray(f'{V4_PATH}/predictions.nc')
+        test_predictions = xr.open_dataarray(f'{V4_PATH}/test_predictions.nc')
+        val_predictions = xr.open_dataarray(f'{V4_PATH}/val_predictions.nc')
 
     # Prune Data
     print('Pruning Data...')
     if config['operation']['exclude_blank_ids']:
         excluded_data_ids = config['ensemble']['excluded_ids']
-        predictions = prune_data(predictions, excluded_data_ids)
+        test_predictions = prune_data(test_predictions, excluded_data_ids)
+        val_predictions = prune_data(val_predictions, excluded_data_ids)
+
+    # Concatenate Predictions
+    predictions = xr.concat([test_predictions, val_predictions], dim='data_id')
 
     # Compute Ensemble Statistics
     print('Computing Ensemble Statistics...')
@@ -712,7 +759,7 @@ def main():
         'confidence',
         'stdev',
         f'{V4_PATH}/confidence_stdev.png',
-        'Confidence vs. Standard Deviation',
+        'Confidence and Standard Deviation for Predictions',
         'Confidence',
         'Standard Deviation',
     )
@@ -747,7 +794,7 @@ def main():
         sens_analysis,
         'accuracy',
         f'{V4_PATH}/sens_analysis.png',
-        'Sensitivity Analysis',
+        'Sensitivity Analysis of Accuracy vs. # of Models',
         '# of Models',
         'Accuracy',
     )