
Completed refactor to xarray - working on minimal calibration and sensitivity analysis

Nicholas Schense committed 1 month ago
commit 1083b2301b
4 changed files with 559 additions and 56 deletions

  1. config.toml (+2 -1)
  2. threshold.py (+1 -2)
  3. threshold_refac.py (+0 -4)
  4. threshold_xarray.py (+556 -49)

+ 2 - 1
config.toml

@@ -34,4 +34,5 @@ exclude_blank_ids = false
 name = 'cnn-50x30'
 prune_threshold = 0.0 # Any models with accuracy below this threshold will be pruned, set to 0 to disable pruning
 individual_id = 1     # The id of the individual model to be used for the ensemble
-run_models = true    # If true, the ensemble will run the models to generate the predictions, otherwise will load from file
+run_models = false    # If true, the ensemble will run the models to generate the predictions, otherwise will load from file
+excluded_ids = []     # List of data ids to be excluded from the ensemble
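
Aside: the new excluded_ids list and the flipped run_models flag are consumed by threshold_xarray.py below (see prune_data and main). A minimal sketch of how this block is read, assuming the same tomli-based load_config pattern used later in this commit:

import os
import tomli as toml

# Sketch only: fall back to a local config.toml when ADL_CONFIG_PATH is unset.
config_path = os.getenv('ADL_CONFIG_PATH', 'config.toml')
with open(config_path, 'rb') as f:  # tomli parses binary file objects
    config = toml.load(f)

run_models = config['ensemble']['run_models']      # False -> load cached predictions.nc
excluded_ids = config['ensemble']['excluded_ids']  # [] -> nothing pruned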

+ 1 - 2
threshold.py

@@ -2,7 +2,6 @@ import pandas as pd
 import numpy as np
 import os
 import tomli as toml
-from utils.data.datasets import prepare_datasets
 import utils.ensemble as ens
 import torch
 import matplotlib.pyplot as plt
@@ -91,7 +90,7 @@ def get_predictions(config):
     # [(class_1, class_2, true_label)]
     indv_results = []

-    for i, (data, target) in tqdm(
+    for _, (data, target) in tqdm(
         enumerate(test_set),
         total=len(test_set),
         desc='Getting predictions',
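
Note: with the index discarded, the enumerate wrapper is itself redundant; a minimal sketch of the equivalent loop, assuming test_set has a defined len():

from tqdm import tqdm

# Equivalent to the loop above once the index is unused; body unchanged.
for data, target in tqdm(test_set, total=len(test_set), desc='Getting predictions'):
    ...  # existing prediction logic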

+ 0 - 4
threshold_refac.py

@@ -2,19 +2,15 @@ import pandas as pd
 import numpy as np
 import os
 import tomli as toml
-from utils.data.datasets import prepare_datasets
-import utils.ensemble as ens
 import torch
 import matplotlib.pyplot as plt
 import sklearn.metrics as metrics
 from tqdm import tqdm
 import utils.metrics as met
-import itertools as it
 import matplotlib.ticker as ticker
 import glob
 import pickle as pk
 import warnings

 warnings.filterwarnings('error')


+ 556 - 49
threshold_xarray.py

@@ -1,31 +1,31 @@
-#Rewritten Program to use xarray instead of pandas for thresholding
+# Rewritten Program to use xarray instead of pandas for thresholding

 import xarray as xr
-import torch 
+import torch
 import numpy as np
 import os
 import glob
 import tomli as toml
 from tqdm import tqdm
 import utils.metrics as met
+import matplotlib.pyplot as plt
+import matplotlib.ticker as mtick

-if __name__ == '__main__':
-    main()

-
-#The datastructures for this file are as follows
+# The datastructures for this file are as follows
 # models_dict: Dictionary - {model_id: model}
 # predictions: DataArray - (data_id, model_id, prediction_value) - Prediction value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
 # ensemble_statistics: DataArray - (data_id, statistic) - Statistic has coords ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
 # thresholded_predictions: DataArray - (quantile, statistic, metric) - Metric has coords ['accuracy, 'f1'] - only use 'stdev', 'entropy', 'confidence' for statistic

-#Additionally, we also have the thresholds and statistics for the individual models
+# Additionally, we also have the thresholds and statistics for the individual models
 # indv_statistics: DataArray - (data_id, model_id, statistic) - Statistic has coords ['mean', 'entropy', 'confidence', 'correct', 'predicted', 'actual'] - No stdev as it cannot be calculated for a single model
 # indv_thresholds: DataArray - (model_id, quantile, statistic, metric) - Metric has coords ['accuracy', 'f1'] - only use 'entropy', 'confidence' for statistic

-#Additionally, we have some for the sensitivity analysis for number of models
+# Additionally, we have some for the sensitivity analysis for number of models
 # sensitivity_statistics: DataArray - (data_id, model_count, statistic) - Statistic has coords ['accuracy', 'f1', 'ECE', 'MCE']

+
 # Loads configuration dictionary
 def load_config():
     if os.getenv('ADL_CONFIG_PATH') is None:
@@ -37,7 +37,8 @@ def load_config():

     return config

-#Loads models into a dictionary
+
+# Loads models into a dictionary
 def load_models_v2(folder, device):
     glob_path = os.path.join(folder, '*.pt')
     model_files = glob.glob(glob_path)
@@ -53,6 +54,7 @@ def load_models_v2(folder, device):

     return model_dict

+
 # Ensures that both mri and xls tensors in the data are unsqueezed and are on the correct device
 def preprocess_data(data, device):
     mri, xls = data
@@ -60,12 +62,14 @@ def preprocess_data(data, device):
     xls = xls.unsqueeze(0).to(device)
     return (mri, xls)

+
 # Loads datasets and returns concatenated test and validation datasets
 def load_datasets(ensemble_path):
     return torch.load(f'{ensemble_path}/test_dataset.pt') + torch.load(
         f'{ensemble_path}/val_dataset.pt'
     )

+
 # Gets the predictions for a set of models on a dataset
 def get_ensemble_predictions(models, dataset, device):
     zeros = np.zeros((len(dataset), len(models), 4))
@@ -74,25 +78,35 @@ def get_ensemble_predictions(models, dataset, device):
         dims=('data_id', 'model_id', 'prediction_value'),
         coords={
             'data_id': range(len(dataset)),
-            'model_id': models.keys(),
-            'prediction_value': ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
-        }
+            'model_id': list(models.keys()),
+            'prediction_value': [
+                'negative_prediction',
+                'positive_prediction',
+                'negative_actual',
+                'positive_actual',
+            ],
+        },
     )

-    for data_id, (data, target) in tqdm(enumerate(dataset)):
-        mri, xls = preprocess_data(data, device)
+    for data_id, (data, target) in tqdm(
+        enumerate(dataset), total=len(dataset), unit='images'
+    ):
+        dat = preprocess_data(data, device)
         actual = list(target.cpu().numpy())
         for model_id, model in models.items():
             with torch.no_grad():
-                output = model(mri, xls)
-                prediction = list(output.cpu().numpy())
+                output = model(dat)
+                prediction = output.cpu().numpy().tolist()[0]

-                predictions.loc[{ 'data_id': data_id, 'model_id': model_id }] = prediction + actual
+                predictions.loc[{'data_id': data_id, 'model_id': model_id}] = (
+                    prediction + actual
+                )

     return predictions
-                
+
+
 # Compute the ensemble statistics given an array of predictions
-def compute_ensemble_statistics(predictions):
+def compute_ensemble_statistics(predictions: xr.DataArray):
     zeros = np.zeros((len(predictions.data_id), 7))

     ensemble_statistics = xr.DataArray(
@@ -100,58 +114,93 @@ def compute_ensemble_statistics(predictions):
         dims=('data_id', 'statistic'),
         coords={
             'data_id': predictions.data_id,
-            'statistic': ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
-        }
-    )   
+            'statistic': [
+                'mean',
+                'stdev',
+                'entropy',
+                'confidence',
+                'correct',
+                'predicted',
+                'actual',
+            ],
+        },
+    )

     for data_id in predictions.data_id:
-        data = predictions.loc[{ 'data_id': data_id }]
-        mean = np.mean(data, axis=0)
-        stdev = np.std(data, axis=0)
-        entropy = -np.sum(mean * np.log2(mean + 1e-12))
-        confidence = np.max(mean)
-        
-        actual = data.iloc[:, 3].values
-        predicted = np.argmax(mean)
+        data = predictions.loc[{'data_id': data_id}]
+        mean = data.mean(dim='model_id')[
+            0:2
+        ]  # Only take the predictions, not the actual
+        stdev = data.std(dim='model_id')[
+            1
+        ]  # Only need the standard deviation of the postive prediction
+        entropy = (-mean * np.log(mean)).sum()
+
+        # Compute confidence
+        confidence = mean.max()
+
+        # only need one of the actual values, since they are all the same, just get the first actual_positive
+        actual = data.loc[{'prediction_value': 'positive_actual'}][0]
+        predicted = mean.argmax()
         correct = actual == predicted

-        ensemble_statistics.loc[{ 'data_id': data_id }] = [mean, stdev, entropy, confidence, correct, predicted, actual]
+        ensemble_statistics.loc[{'data_id': data_id}] = [
+            mean[1],
+            stdev,
+            entropy,
+            confidence,
+            correct,
+            predicted,
+            actual,
+        ]

     return ensemble_statistics

+
 # Compute the thresholded predictions given an array of predictions
-def compute_thresholded_predictions(ensemble_statistics: xr.DataArray):
-    quantiles = np.linspace(0.05, 0.95, 19)
+def compute_thresholded_predictions(input_stats: xr.DataArray):
+    quantiles = np.linspace(0.05, 0.95, 19) * 100
     metrics = ['accuracy', 'f1']
     statistics = ['stdev', 'entropy', 'confidence']
-    
+
     zeros = np.zeros((len(quantiles), len(statistics), len(metrics)))

     thresholded_predictions = xr.DataArray(
         zeros,
         dims=('quantile', 'statistic', 'metric'),
-        coords={
-            'quantile': quantiles,
-            'statistic': statistics,
-            'metric': metrics
-        }
+        coords={'quantile': quantiles, 'statistic': statistics, 'metric': metrics},
     )

     for statistic in statistics:
-        #First, we must compute the quantiles for the statistic
-        quantile_values = np.quantiles(ensemble_statistics.loc[{ 'statistic': statistic }].values, quantiles, axis=0)
+        # First, we must compute the quantiles for the statistic
+        quantile_values = np.percentile(
+            input_stats.sel(statistic=statistic).values, quantiles, axis=0
+        )

-        #Then, we must compute the metrics for each quantile
+        # Then, we must compute the metrics for each quantile
         for i, quantile in enumerate(quantiles):
             if low_to_high(statistic):
-                filtered_data = ensemble_statistics.where(ensemble_statistics.loc[{ 'statistic': statistic }] > quantile_values[i], drop=True)
+                mask = (
+                    input_stats.sel(statistic=statistic) >= quantile_values[i]
+                ).values
             else:
-                filtered_data = ensemble_statistics.where(ensemble_statistics.loc[{ 'statistic': statistic }] < quantile_values[i], drop=True)
+                mask = (
+                    input_stats.sel(statistic=statistic) <= quantile_values[i]
+                ).values
+
+            # Filter the data based on the mask
+            filtered_data = input_stats.where(
+                input_stats.data_id.isin(np.where(mask)), drop=True
+            )
+
             for metric in metrics:
-                thresholded_predictions.loc[{ 'quantile': quantile, 'statistic': statistic, 'metric': metric }] = compute_metric(filtered_data, metric)
-    
+                thresholded_predictions.loc[
+                    {'quantile': quantile, 'statistic': statistic, 'metric': metric}
+                ] = compute_metric(filtered_data, metric)
+
     return thresholded_predictions
-                
+
+
 # Truth function to determine if metric should be thresholded low to high or high to low
 # Low confidence is bad, high entropy is bad, high stdev is bad
 # So we threshold confidence low to high, entropy and stdev high to low
@@ -159,35 +208,438 @@ def compute_thresholded_predictions(ensemble_statistics: xr.DataArray):
 def low_to_high(stat):
     return stat in ['confidence']

+
 # Compute a given metric on a DataArray of statstics
 def compute_metric(arr, metric):
     if metric == 'accuracy':
-        return np.mean(arr.loc[{ 'statistic': 'correct' }])
+        return np.mean(arr.loc[{'statistic': 'correct'}])
     elif metric == 'f1':
-        return met.F1(arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}])
+        return met.F1(
+            arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}]
+        )
     else:
         raise ValueError('Invalid metric: ' + metric)


+# Graph a thresholded prediction for a given statistic and metric
+def graph_thresholded_prediction(
+    thresholded_predictions, statistic, metric, save_path, title, xlabel, ylabel
+):
+    data = thresholded_predictions.sel(statistic=statistic, metric=metric)
+
+    x_data = data.coords['quantile'].values
+    y_data = data.values
+
+    fig, ax = plt.subplots()
+    ax.plot(x_data, y_data, 'bx-', label='Ensemble')
+    ax.set_title(title)
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+    ax.xaxis.set_major_formatter(mtick.PercentFormatter())
+
+    if not low_to_high(statistic):
+        ax.invert_xaxis()
+
+    plt.savefig(save_path)
+
+
+# Graph all thresholded predictions
+def graph_all_thresholded_predictions(thresholded_predictions, save_path):
+    # Confidence Accuracy
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'confidence',
+        'accuracy',
+        f'{save_path}/confidence_accuracy.png',
+        'Confidence vs. Accuracy',
+        'Confidence',
+        'Accuracy',
+    )
+
+    # Confidence F1
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'confidence',
+        'f1',
+        f'{save_path}/confidence_f1.png',
+        'Confidence vs. F1',
+        'Confidence',
+        'F1',
+    )
+
+    # Entropy Accuracy
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'entropy',
+        'accuracy',
+        f'{save_path}/entropy_accuracy.png',
+        'Entropy vs. Accuracy',
+        'Entropy',
+        'Accuracy',
+    )
+
+    # Entropy F1
+
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'entropy',
+        'f1',
+        f'{save_path}/entropy_f1.png',
+        'Entropy vs. F1',
+        'Entropy',
+        'F1',
+    )
+
+    # Stdev Accuracy
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'stdev',
+        'accuracy',
+        f'{save_path}/stdev_accuracy.png',
+        'Standard Deviation vs. Accuracy',
+        'Standard Deviation',
+        'Accuracy',
+    )
+
+    # Stdev F1
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'stdev',
+        'f1',
+        f'{save_path}/stdev_f1.png',
+        'Standard Deviation vs. F1',
+        'Standard Deviation',
+        'F1',
+    )
+
+
+# Graph two statistics against each other
+def graph_statistics(stats, x_stat, y_stat, save_path, title, xlabel, ylabel):
+    # Filter for correct predictions
+    c_stats = stats.where(
+        stats.data_id.isin(np.where((stats.sel(statistic='correct') == 1).values)),
+        drop=True,
+    )
+
+    # Filter for incorrect predictions
+    i_stats = stats.where(
+        stats.data_id.isin(np.where((stats.sel(statistic='correct') == 0).values)),
+        drop=True,
+    )
+
+    # x and y data for correct and incorrect predictions
+    x_data_c = c_stats.sel(statistic=x_stat).values
+    y_data_c = c_stats.sel(statistic=y_stat).values
+
+    x_data_i = i_stats.sel(statistic=x_stat).values
+    y_data_i = i_stats.sel(statistic=y_stat).values
+
+    fig, ax = plt.subplots()
+    ax.plot(x_data_c, y_data_c, 'go', label='Correct')
+    ax.plot(x_data_i, y_data_i, 'ro', label='Incorrect')
+    ax.set_title(title)
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+    ax.legend()
+
+    plt.savefig(save_path)
+
+
+# Prune the data based on excluded data_ids
+def prune_data(data, excluded_data_ids):
+    return data.where(~data.data_id.isin(excluded_data_ids), drop=True)
+
+
+# Calculate individual model statistics
+def compute_individual_statistics(predictions: xr.DataArray):
+    zeros = np.zeros((len(predictions.data_id), len(predictions.model_id), 6))
+
+    indv_statistics = xr.DataArray(
+        zeros,
+        dims=('data_id', 'model_id', 'statistic'),
+        coords={
+            'data_id': predictions.data_id,
+            'model_id': predictions.model_id,
+            'statistic': [
+                'mean',
+                'entropy',
+                'confidence',
+                'correct',
+                'predicted',
+                'actual',
+            ],
+        },
+    )
+
+    for data_id in predictions.data_id:
+        for model_id in predictions.model_id:
+            data = predictions.loc[{'data_id': data_id, 'model_id': model_id}]
+            mean = data[0:2]
+            entropy = (-mean * np.log(mean)).sum()
+            confidence = mean.max()
+            actual = data[3]
+            predicted = mean.argmax()
+            correct = actual == predicted
+
+            indv_statistics.loc[{'data_id': data_id, 'model_id': model_id}] = [
+                mean[1],
+                entropy,
+                confidence,
+                correct,
+                predicted,
+                actual,
+            ]
+
+    return indv_statistics
+
+
+# Compute individual model thresholds
+def compute_individual_thresholds(input_stats: xr.DataArray):
+    quantiles = np.linspace(0.05, 0.95, 19) * 100
+    metrics = ['accuracy', 'f1']
+    statistics = ['entropy', 'confidence']
+
+    zeros = np.zeros(
+        (len(input_stats.model_id), len(quantiles), len(statistics), len(metrics))
+    )
+
+    indv_thresholds = xr.DataArray(
+        zeros,
+        dims=('model_id', 'quantile', 'statistic', 'metric'),
+        coords={
+            'model_id': input_stats.model_id,
+            'quantile': quantiles,
+            'statistic': statistics,
+            'metric': metrics,
+        },
+    )
+
+    for model_id in input_stats.model_id:
+        for statistic in statistics:
+            # First, we must compute the quantiles for the statistic
+            quantile_values = np.percentile(
+                input_stats.sel(model_id=model_id, statistic=statistic).values,
+                quantiles,
+                axis=0,
+            )
+
+            # Then, we must compute the metrics for each quantile
+            for i, quantile in enumerate(quantiles):
+                if low_to_high(statistic):
+                    mask = (
+                        input_stats.sel(model_id=model_id, statistic=statistic)
+                        >= quantile_values[i]
+                    ).values
+                else:
+                    mask = (
+                        input_stats.sel(model_id=model_id, statistic=statistic)
+                        <= quantile_values[i]
+                    ).values
+
+                # Filter the data based on the mask
+                filtered_data = input_stats.where(
+                    input_stats.data_id.isin(np.where(mask)), drop=True
+                )
+
+                for metric in metrics:
+                    indv_thresholds.loc[
+                        {
+                            'model_id': model_id,
+                            'quantile': quantile,
+                            'statistic': statistic,
+                            'metric': metric,
+                        }
+                    ] = compute_metric(filtered_data, metric)
+
+    return indv_thresholds
+
+
+# Graph individual model thresholded predictions
+def graph_individual_thresholded_predictions(
+    indv_thresholds,
+    ensemble_thresholds,
+    statistic,
+    metric,
+    save_path,
+    title,
+    xlabel,
+    ylabel,
+):
+    data = indv_thresholds.sel(statistic=statistic, metric=metric)
+    e_data = ensemble_thresholds.sel(statistic=statistic, metric=metric)
+
+    x_data = data.coords['quantile'].values
+    y_data = data.values
+
+    e_x_data = e_data.coords['quantile'].values
+    e_y_data = e_data.values
+
+    fig, ax = plt.subplots()
+    for model_id in data.coords['model_id'].values:
+        model_data = data.sel(model_id=model_id)
+        ax.plot(x_data, model_data)
+
+    ax.plot(e_x_data, e_y_data, 'kx-', label='Ensemble')
+
+    ax.set_title(title)
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+    ax.xaxis.set_major_formatter(mtick.PercentFormatter())
+
+    if not low_to_high(statistic):
+        ax.invert_xaxis()
+
+    ax.legend()
+    plt.savefig(save_path)
+
+
+# Graph all individual thresholded predictions
+def graph_all_individual_thresholded_predictions(
+    indv_thresholds, ensemble_thresholds, save_path
+):
+    # Confidence Accuracy
+    graph_individual_thresholded_predictions(
+        indv_thresholds,
+        ensemble_thresholds,
+        'confidence',
+        'accuracy',
+        f'{save_path}/indv/confidence_accuracy.png',
+        'Confidence vs. Accuracy',
+        'Confidence Percentile Threshold',
+        'Accuracy',
+    )
+
+    # Confidence F1
+    graph_individual_thresholded_predictions(
+        indv_thresholds,
+        ensemble_thresholds,
+        'confidence',
+        'f1',
+        f'{save_path}/indv/confidence_f1.png',
+        'Confidence vs. F1',
+        'Confidence Percentile Threshold',
+        'F1',
+    )
+
+    # Entropy Accuracy
+    graph_individual_thresholded_predictions(
+        indv_thresholds,
+        ensemble_thresholds,
+        'entropy',
+        'accuracy',
+        f'{save_path}/indv/entropy_accuracy.png',
+        'Entropy vs. Accuracy',
+        'Entropy Percentile Threshold',
+        'Accuracy',
+    )
+
+    # Entropy F1
+    graph_individual_thresholded_predictions(
+        indv_thresholds,
+        ensemble_thresholds,
+        'entropy',
+        'f1',
+        f'{save_path}/indv/entropy_f1.png',
+        'Entropy vs. F1',
+        'Entropy Percentile Threshold',
+        'F1',
+    )
+
+
+# Calculate statistics of subsets of models for sensitivity analysis
+def calculate_subset_statistics(predictions: xr.DataArray):
+    # Calculate subsets for 1-50 models
+    subsets = range(1, len(predictions.model_id) + 1)
+    zeros = np.zeros(
+        (len(predictions.data_id), len(subsets), 7)
+    )  # Include stdev, but for 1 models set to NaN
+
+    subset_stats = xr.DataArray(
+        zeros,
+        dims=('data_id', 'model_count', 'statistic'),
+        coords={
+            'data_id': predictions.data_id,
+            'model_count': subsets,
+            'statistic': [
+                'mean',
+                'stdev',
+                'entropy',
+                'confidence',
+                'correct',
+                'predicted',
+                'actual',
+            ],
+        },
+    )
+
+    for data_id in predictions.data_id:
+        for subset in subsets:
+            data = predictions.sel(
+                data_id=data_id, model_id=predictions.model_id[:subset]
+            )
+            mean = data.mean(dim='model_id')[0:2]
+            stdev = data.std(dim='model_id')[1]
+            entropy = (-mean * np.log(mean)).sum()
+            confidence = mean.max()
+            actual = data[3]
+            predicted = mean.argmax()
+            correct = actual == predicted
+
+            subset_stats.loc[{'data_id': data_id, 'model_count': subset}] = [
+                mean[1],
+                stdev,
+                entropy,
+                confidence,
+                correct,
+                predicted,
+                actual,
+            ]
+
+    return subset_stats
+
+
+# Calculate Accuracy, F1 and ECE for subset stats - sensityvity analysis
+def calculate_sensitivity_analysis(subset_stats: xr.DataArray):
+    subsets = subset_stats.subsets
+    stats = ['accuracy', 'f1', 'ECE', 'MCE']
+
+    zeros = np.zeros((len(subsets), len(stats)))
+
+    sens_analysis = xr.DataArray(
+        zeros,
+        dims=('model_count', 'statistic'),
+        coords={'model_count': subsets, 'statistic': ['accuracy', 'f1', 'ECE', 'MCE']},
+    )
+
+
+# Main Function
 def main():
+    print('Loading Config...')
     config = load_config()
     ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
     V4_PATH = ENSEMBLE_PATH + '/v4'

     if not os.path.exists(V4_PATH):
         os.makedirs(V4_PATH)
+    print('Config Loaded')

     # Load Datasets
+    print('Loading Datasets...')
     dataset = load_datasets(ENSEMBLE_PATH)
+    print('Datasets Loaded')

     # Get Predictions, either by running the models or loading them from a file
     if config['ensemble']['run_models']:
         # Load Models
+        print('Loading Models...')
         device = torch.device(config['training']['device'])
         models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
+        print('Models Loaded')

         # Get Predictions
+        print('Getting Predictions...')
         predictions = get_ensemble_predictions(models, dataset, device)
+        print('Predictions Loaded')

         # Save Prediction
         predictions.to_netcdf(f'{V4_PATH}/predictions.nc')
@@ -195,10 +647,65 @@ def main():
     else:
         predictions = xr.open_dataarray(f'{V4_PATH}/predictions.nc')

+    # Prune Data
+    print('Pruning Data...')
+    if config['operation']['exclude_blank_ids']:
+        excluded_data_ids = config['ensemble']['excluded_ids']
+        predictions = prune_data(predictions, excluded_data_ids)
+
     # Compute Ensemble Statistics
+    print('Computing Ensemble Statistics...')
     ensemble_statistics = compute_ensemble_statistics(predictions)
     ensemble_statistics.to_netcdf(f'{V4_PATH}/ensemble_statistics.nc')
+    print('Ensemble Statistics Computed')

     # Compute Thresholded Predictions
+    print('Computing Thresholded Predictions...')
     thresholded_predictions = compute_thresholded_predictions(ensemble_statistics)
     thresholded_predictions.to_netcdf(f'{V4_PATH}/thresholded_predictions.nc')
+    print('Thresholded Predictions Computed')
+
+    # Graph Thresholded Predictions
+    print('Graphing Thresholded Predictions...')
+    graph_all_thresholded_predictions(thresholded_predictions, V4_PATH)
+    print('Thresholded Predictions Graphed')
+
+    # Additional Graphs
+    print('Graphing Additional Graphs...')
+    # Confidence vs stdev
+    graph_statistics(
+        ensemble_statistics,
+        'confidence',
+        'stdev',
+        f'{V4_PATH}/confidence_stdev.png',
+        'Confidence vs. Standard Deviation',
+        'Confidence',
+        'Standard Deviation',
+    )
+    print('Additional Graphs Graphed')
+
+    # Compute Individual Statistics
+    print('Computing Individual Statistics...')
+    indv_statistics = compute_individual_statistics(predictions)
+    indv_statistics.to_netcdf(f'{V4_PATH}/indv_statistics.nc')
+    print('Individual Statistics Computed')
+
+    # Compute Individual Thresholds
+    print('Computing Individual Thresholds...')
+    indv_thresholds = compute_individual_thresholds(indv_statistics)
+    indv_thresholds.to_netcdf(f'{V4_PATH}/indv_thresholds.nc')
+    print('Individual Thresholds Computed')
+
+    # Graph Individual Thresholded Predictions
+    print('Graphing Individual Thresholded Predictions...')
+    if not os.path.exists(f'{V4_PATH}/indv'):
+        os.makedirs(f'{V4_PATH}/indv')
+
+    graph_all_individual_thresholded_predictions(
+        indv_thresholds, thresholded_predictions, V4_PATH
+    )
+    print('Individual Thresholded Predictions Graphed')
+
+
+if __name__ == '__main__':
+    main()
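
One loose end, consistent with the commit message's "working on ... sensitivity analysis": calculate_sensitivity_analysis allocates sens_analysis but never fills or returns it, and it reads subset_stats.subsets although calculate_subset_statistics names that dimension model_count. A minimal sketch of how the loop could be completed, reusing compute_metric from the same module; ECE and MCE are omitted because their helpers are not part of this commit:

# Hypothetical completion, not the committed version; assumes this module's
# imports (xarray as xr, numpy as np) and its compute_metric helper.
def calculate_sensitivity_analysis(subset_stats: xr.DataArray):
    model_counts = subset_stats.model_count.values
    stats = ['accuracy', 'f1']

    sens_analysis = xr.DataArray(
        np.zeros((len(model_counts), len(stats))),
        dims=('model_count', 'statistic'),
        coords={'model_count': model_counts, 'statistic': stats},
    )

    for count in model_counts:
        # dims of the slice: (data_id, statistic)
        data = subset_stats.sel(model_count=count)
        for metric in stats:
            sens_analysis.loc[{'model_count': count, 'statistic': metric}] = (
                compute_metric(data, metric)
            )

    return sens_analysis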