Browse Source

Still working on the sensitivity analysis - interesting results for accuracy

Nicholas Schense 8 months ago
parent
commit
e9b380d66a
1 changed files with 52 additions and 4 deletions
  1. 52 4
      threshold_xarray.py

+ 52 - 4
threshold_xarray.py

@@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
 import matplotlib.ticker as mtick
 
 
+
 # The datastructures for this file are as follows
 # models_dict: Dictionary - {model_id: model}
 # predictions: DataArray - (data_id, model_id, prediction_value) - Prediction value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
@@ -581,7 +582,7 @@ def calculate_subset_statistics(predictions: xr.DataArray):
             stdev = data.std(dim='model_id')[1]
             entropy = (-mean * np.log(mean)).sum()
             confidence = mean.max()
-            actual = data[3]
+            actual = data[0][3]
             predicted = mean.argmax()
             correct = actual == predicted
 
@@ -600,17 +601,50 @@ def calculate_subset_statistics(predictions: xr.DataArray):
 
 # Calculate accuracy and F1 for subset stats - sensitivity analysis
 def calculate_sensitivity_analysis(subset_stats: xr.DataArray):
-    subsets = subset_stats.subsets
-    stats = ['accuracy', 'f1', 'ECE', 'MCE']
+    subsets = subset_stats.model_count
+    stats = ['accuracy', 'f1']
 
     zeros = np.zeros((len(subsets), len(stats)))
 
     sens_analysis = xr.DataArray(
         zeros,
         dims=('model_count', 'statistic'),
-        coords={'model_count': subsets, 'statistic': ['accuracy', 'f1', 'ECE', 'MCE']},
+        coords={'model_count': subsets, 'statistic': ['accuracy', 'f1']},
     )
 
+    for subset in subsets:
+        data = subset_stats.sel(model_count=subset)
+        acc = compute_metric(data, 'accuracy')
+        f1 = compute_metric(data, 'f1')
+
+        sens_analysis.loc[{'model_count': subset}] = [acc, f1]
+
+    return sens_analysis
+
+
def graph_sensitivity_analysis(
    sens_analysis: xr.DataArray, statistic, save_path, title, xlabel, ylabel
):
    """Plot one sensitivity-analysis statistic against ensemble size and save it.

    Parameters
    ----------
    sens_analysis : xr.DataArray
        (model_count, statistic) array as produced by the sensitivity analysis.
    statistic : str
        Which statistic to plot (e.g. 'accuracy').
    save_path : str
        File path the figure is written to.
    title, xlabel, ylabel : str
        Labels applied to the figure.
    """
    data = sens_analysis.sel(statistic=statistic)

    xdata = data.coords['model_count'].values
    ydata = data.values

    fig, ax = plt.subplots()
    ax.plot(xdata, ydata)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    plt.savefig(save_path)
    # Close the figure explicitly: without this, repeated calls accumulate
    # open matplotlib figures and leak memory.
    plt.close(fig)
+
+
def calculate_overall_stats(ensemble_statistics: xr.DataArray):
    """Return overall accuracy and F1 for the full ensemble as plain floats."""
    results = {}
    for stat in ('accuracy', 'f1'):
        results[stat] = compute_metric(ensemble_statistics, stat).item()
    return results
+
 
 # Main Function
 def main():
@@ -706,6 +740,20 @@ def main():
     )
     print('Individual Thresholded Predictions Graphed')
 
+    # Compute subset statistics and graph
+    subset_stats = calculate_subset_statistics(predictions)
+    sens_analysis = calculate_sensitivity_analysis(subset_stats)
+    graph_sensitivity_analysis(
+        sens_analysis,
+        'accuracy',
+        f'{V4_PATH}/sens_analysis.png',
+        'Sensitivity Analysis',
+        '# of Models',
+        'Accuracy',
+    )
+    print(sens_analysis.sel(statistic='accuracy'))
+    print(calculate_overall_stats(ensemble_statistics))
+
 
 if __name__ == '__main__':
     main()