
Completed refactor to xarray - working on minimal calibration and sensitivity analysis

Nicholas Schense 1 month ago
parent
commit 1083b2301b
4 changed files with 559 additions and 56 deletions
  1. config.toml (+2 -1)
  2. threshold.py (+1 -2)
  3. threshold_refac.py (+0 -4)
  4. threshold_xarray.py (+556 -49)

config.toml (+2 -1)

@@ -34,4 +34,5 @@ exclude_blank_ids = false
 name = 'cnn-50x30'
 prune_threshold = 0.0 # Any models with accuracy below this threshold will be pruned, set to 0 to disable pruning
 individual_id = 1     # The id of the individual model to be used for the ensemble
-run_models = true    # If true, the ensemble will run the models to generate the predictions, otherwise will load from file
+run_models = false    # If true, the ensemble will run the models to generate the predictions, otherwise will load from file
+excluded_ids = []     # List of data ids to be excluded from the ensemble
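A minimal sketch of how the two ensemble keys touched here are consumed (the load pattern and key names come from load_config and main in threshold_xarray.py below; the file path is illustrative):

    import tomli as toml

    with open('config.toml', 'rb') as f:  # tomli requires a binary file handle
        config = toml.load(f)

    if config['ensemble']['run_models']:
        ...  # run the models and regenerate predictions
    else:
        ...  # load cached predictions from disk
    excluded = config['ensemble']['excluded_ids']  # new key; [] disables pruning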

threshold.py (+1 -2)

@@ -2,7 +2,6 @@ import pandas as pd
 import numpy as np
 import os
 import tomli as toml
-from utils.data.datasets import prepare_datasets
 import utils.ensemble as ens
 import torch
 import matplotlib.pyplot as plt
@@ -91,7 +90,7 @@ def get_predictions(config):
     # [(class_1, class_2, true_label)]
     indv_results = []
 
-    for i, (data, target) in tqdm(
+    for _, (data, target) in tqdm(
         enumerate(test_set),
         total=len(test_set),
         desc='Getting predictions',

threshold_refac.py (+0 -4)

@@ -2,19 +2,15 @@ import pandas as pd
 import numpy as np
 import os
 import tomli as toml
-from utils.data.datasets import prepare_datasets
-import utils.ensemble as ens
 import torch
 import matplotlib.pyplot as plt
 import sklearn.metrics as metrics
 from tqdm import tqdm
 import utils.metrics as met
-import itertools as it
 import matplotlib.ticker as ticker
 import glob
 import pickle as pk
 import warnings
-import random as rand
 
 warnings.filterwarnings('error')
 

threshold_xarray.py (+556 -49)

@@ -1,31 +1,31 @@
-#Rewritten Program to use xarray instead of pandas for thresholding
+# Rewritten Program to use xarray instead of pandas for thresholding
 
 import xarray as xr
-import torch 
+import torch
 import numpy as np
 import os
 import glob
 import tomli as toml
 from tqdm import tqdm
 import utils.metrics as met
+import matplotlib.pyplot as plt
+import matplotlib.ticker as mtick
 
-if __name__ == '__main__':
-    main()
 
-
-#The datastructures for this file are as follows
+# The datastructures for this file are as follows
 # models_dict: Dictionary - {model_id: model}
 # predictions: DataArray - (data_id, model_id, prediction_value) - Prediction value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
 # ensemble_statistics: DataArray - (data_id, statistic) - Statistic has coords ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
# thresholded_predictions: DataArray - (quantile, statistic, metric) - Metric has coords ['accuracy', 'f1'] - only use 'stdev', 'entropy', 'confidence' for statistic
 
-#Additionally, we also have the thresholds and statistics for the individual models
+# Additionally, we also have the thresholds and statistics for the individual models
 # indv_statistics: DataArray - (data_id, model_id, statistic) - Statistic has coords ['mean', 'entropy', 'confidence', 'correct', 'predicted', 'actual'] - No stdev as it cannot be calculated for a single model
 # indv_thresholds: DataArray - (model_id, quantile, statistic, metric) - Metric has coords ['accuracy', 'f1'] - only use 'entropy', 'confidence' for statistic
 
-#Additionally, we have some for the sensitivity analysis for number of models
+# Additionally, we have structures for the sensitivity analysis over the number of models
 # sensitivity_statistics: DataArray - (data_id, model_count, statistic) - Statistic has coords ['accuracy', 'f1', 'ECE', 'MCE']
 
+
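For reference, a toy construction of the predictions DataArray documented in the comment block above (two samples by two models, dummy values; the dims and coords mirror get_ensemble_predictions below):

    import numpy as np
    import xarray as xr

    predictions = xr.DataArray(
        np.zeros((2, 2, 4)),
        dims=('data_id', 'model_id', 'prediction_value'),
        coords={
            'data_id': range(2),
            'model_id': ['model_0', 'model_1'],
            'prediction_value': [
                'negative_prediction',
                'positive_prediction',
                'negative_actual',
                'positive_actual',
            ],
        },
    )

    # One sample across all models, model outputs only:
    sample = predictions.sel(
        data_id=0,
        prediction_value=['negative_prediction', 'positive_prediction'],
    )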
 # Loads configuration dictionary
 def load_config():
     if os.getenv('ADL_CONFIG_PATH') is None:
@@ -37,7 +37,8 @@ def load_config():
 
     return config
 
-#Loads models into a dictionary
+
+# Loads models into a dictionary
 def load_models_v2(folder, device):
     glob_path = os.path.join(folder, '*.pt')
     model_files = glob.glob(glob_path)
@@ -53,6 +54,7 @@ def load_models_v2(folder, device):
 
     return model_dict
 
+
 # Ensures that both mri and xls tensors in the data are unsqueezed and are on the correct device
 def preprocess_data(data, device):
     mri, xls = data
@@ -60,12 +62,14 @@ def preprocess_data(data, device):
     xls = xls.unsqueeze(0).to(device)
     return (mri, xls)
 
+
 # Loads datasets and returns concatenated test and validation datasets
 def load_datasets(ensemble_path):
     return torch.load(f'{ensemble_path}/test_dataset.pt') + torch.load(
         f'{ensemble_path}/val_dataset.pt'
     )
 
+
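A hedged sketch of the per-item layout load_datasets and preprocess_data assume: each dataset element is ((mri, xls), target), and preprocess_data adds a batch dimension and moves both tensors to the device (the tensor shapes here are illustrative, not taken from the source):

    import torch

    mri = torch.randn(91, 109, 91)     # hypothetical MRI volume shape
    xls = torch.randn(16)              # hypothetical tabular feature vector
    target = torch.tensor([0.0, 1.0])  # one-hot [negative, positive]
    item = ((mri, xls), target)        # what the concatenated dataset yields

    device = torch.device('cpu')
    data, _ = item
    mri_b, xls_b = (t.unsqueeze(0).to(device) for t in data)  # as in preprocess_data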
 # Gets the predictions for a set of models on a dataset
 def get_ensemble_predictions(models, dataset, device):
     zeros = np.zeros((len(dataset), len(models), 4))
@@ -74,25 +78,35 @@ def get_ensemble_predictions(models, dataset, device):
         dims=('data_id', 'model_id', 'prediction_value'),
         coords={
             'data_id': range(len(dataset)),
-            'model_id': models.keys(),
-            'prediction_value': ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
-        }
+            'model_id': list(models.keys()),
+            'prediction_value': [
+                'negative_prediction',
+                'positive_prediction',
+                'negative_actual',
+                'positive_actual',
+            ],
+        },
     )
 
-    for data_id, (data, target) in tqdm(enumerate(dataset)):
-        mri, xls = preprocess_data(data, device)
+    for data_id, (data, target) in tqdm(
+        enumerate(dataset), total=len(dataset), unit='images'
+    ):
+        dat = preprocess_data(data, device)
         actual = list(target.cpu().numpy())
         for model_id, model in models.items():
             with torch.no_grad():
-                output = model(mri, xls)
-                prediction = list(output.cpu().numpy())
+                output = model(dat)
+                prediction = output.cpu().numpy().tolist()[0]
 
-                predictions.loc[{ 'data_id': data_id, 'model_id': model_id }] = prediction + actual
+                predictions.loc[{'data_id': data_id, 'model_id': model_id}] = (
+                    prediction + actual
+                )
 
     return predictions
-                
+
+
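One detail of the loop above worth spelling out: the row stored per (data_id, model_id) is plain list concatenation of the model's two outputs with the two one-hot target entries, matching the four prediction_value coords. A toy check:

    prediction = [0.2, 0.8]    # [negative_prediction, positive_prediction]
    actual = [0.0, 1.0]        # [negative_actual, positive_actual]
    row = prediction + actual  # -> [0.2, 0.8, 0.0, 1.0]
    assert len(row) == 4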
 # Compute the ensemble statistics given an array of predictions
-def compute_ensemble_statistics(predictions):
+def compute_ensemble_statistics(predictions: xr.DataArray):
     zeros = np.zeros((len(predictions.data_id), 7))
 
     ensemble_statistics = xr.DataArray(
@@ -100,58 +114,93 @@ def compute_ensemble_statistics(predictions):
         dims=('data_id', 'statistic'),
         coords={
             'data_id': predictions.data_id,
-            'statistic': ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
-        }
-    )   
+            'statistic': [
+                'mean',
+                'stdev',
+                'entropy',
+                'confidence',
+                'correct',
+                'predicted',
+                'actual',
+            ],
+        },
+    )
 
     for data_id in predictions.data_id:
-        data = predictions.loc[{ 'data_id': data_id }]
-        mean = np.mean(data, axis=0)
-        stdev = np.std(data, axis=0)
-        entropy = -np.sum(mean * np.log2(mean + 1e-12))
-        confidence = np.max(mean)
-        
-        actual = data.iloc[:, 3].values
-        predicted = np.argmax(mean)
+        data = predictions.loc[{'data_id': data_id}]
+        mean = data.mean(dim='model_id')[
+            0:2
+        ]  # Only take the predictions, not the actual
+        stdev = data.std(dim='model_id')[
+            1
+        ]  # Only need the standard deviation of the positive prediction
+        entropy = (-mean * np.log(mean)).sum()
+
+        # Compute confidence
+        confidence = mean.max()
+
+        # Only need one of the actual values since they are all the same; take the first 'positive_actual'
+        actual = data.loc[{'prediction_value': 'positive_actual'}][0]
+        predicted = mean.argmax()
         correct = actual == predicted
 
-        ensemble_statistics.loc[{ 'data_id': data_id }] = [mean, stdev, entropy, confidence, correct, predicted, actual]
+        ensemble_statistics.loc[{'data_id': data_id}] = [
+            mean[1],
+            stdev,
+            entropy,
+            confidence,
+            correct,
+            predicted,
+            actual,
+        ]
 
     return ensemble_statistics
 
+
 # Compute the thresholded predictions given an array of predictions
-def compute_thresholded_predictions(ensemble_statistics: xr.DataArray):
-    quantiles = np.linspace(0.05, 0.95, 19)
+def compute_thresholded_predictions(input_stats: xr.DataArray):
+    quantiles = np.linspace(0.05, 0.95, 19) * 100
     metrics = ['accuracy', 'f1']
     statistics = ['stdev', 'entropy', 'confidence']
-    
+
     zeros = np.zeros((len(quantiles), len(statistics), len(metrics)))
 
     thresholded_predictions = xr.DataArray(
         zeros,
         dims=('quantile', 'statistic', 'metric'),
-        coords={
-            'quantile': quantiles,
-            'statistic': statistics,
-            'metric': metrics
-        }
+        coords={'quantile': quantiles, 'statistic': statistics, 'metric': metrics},
     )
 
     for statistic in statistics:
-        #First, we must compute the quantiles for the statistic
-        quantile_values = np.quantiles(ensemble_statistics.loc[{ 'statistic': statistic }].values, quantiles, axis=0)
+        # First, we must compute the quantiles for the statistic
+        quantile_values = np.percentile(
+            input_stats.sel(statistic=statistic).values, quantiles, axis=0
+        )
 
-        #Then, we must compute the metrics for each quantile
+        # Then, we must compute the metrics for each quantile
         for i, quantile in enumerate(quantiles):
             if low_to_high(statistic):
-                filtered_data = ensemble_statistics.where(ensemble_statistics.loc[{ 'statistic': statistic }] > quantile_values[i], drop=True)
+                mask = (
+                    input_stats.sel(statistic=statistic) >= quantile_values[i]
+                ).values
             else:
-                filtered_data = ensemble_statistics.where(ensemble_statistics.loc[{ 'statistic': statistic }] < quantile_values[i], drop=True)
+                mask = (
+                    input_stats.sel(statistic=statistic) <= quantile_values[i]
+                ).values
+
+            # Filter the data based on the mask
+            filtered_data = input_stats.where(
+                input_stats.data_id.isin(np.where(mask)), drop=True
+            )
+
             for metric in metrics:
-                thresholded_predictions.loc[{ 'quantile': quantile, 'statistic': statistic, 'metric': metric }] = compute_metric(filtered_data, metric)
-    
+                thresholded_predictions.loc[
+                    {'quantile': quantile, 'statistic': statistic, 'metric': metric}
+                ] = compute_metric(filtered_data, metric)
+
     return thresholded_predictions
-                
+
+
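The core of compute_thresholded_predictions reduces to percentile cutoffs plus a keep-mask; a self-contained toy version of that step (values are made up):

    import numpy as np

    values = np.array([0.1, 0.4, 0.55, 0.7, 0.9])  # e.g. per-sample confidence
    quantiles = np.linspace(0.05, 0.95, 19) * 100  # percentile positions, as above
    cutoffs = np.percentile(values, quantiles)

    # For confidence, low_to_high() is True: keep samples at or above the cutoff.
    i = 9  # the 50th percentile
    kept = values[values >= cutoffs[i]]  # metrics are then computed on `kept`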
 # Truth function to determine if metric should be thresholded low to high or high to low
 # Low confidence is bad, high entropy is bad, high stdev is bad
 # So we threshold confidence low to high, entropy and stdev high to low
@@ -159,35 +208,438 @@ def compute_thresholded_predictions(ensemble_statistics: xr.DataArray):
 def low_to_high(stat):
     return stat in ['confidence']
 
+
 # Compute a given metric on a DataArray of statistics
 def compute_metric(arr, metric):
     if metric == 'accuracy':
-        return np.mean(arr.loc[{ 'statistic': 'correct' }])
+        return np.mean(arr.loc[{'statistic': 'correct'}])
     elif metric == 'f1':
-        return met.F1(arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}])
+        return met.F1(
+            arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}]
+        )
     else:
         raise ValueError('Invalid metric: ' + metric)
 
 
+# Graph a thresholded prediction for a given statistic and metric
+def graph_thresholded_prediction(
+    thresholded_predictions, statistic, metric, save_path, title, xlabel, ylabel
+):
+    data = thresholded_predictions.sel(statistic=statistic, metric=metric)
+
+    x_data = data.coords['quantile'].values
+    y_data = data.values
+
+    fig, ax = plt.subplots()
+    ax.plot(x_data, y_data, 'bx-', label='Ensemble')
+    ax.set_title(title)
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+    ax.xaxis.set_major_formatter(mtick.PercentFormatter())
+
+    if not low_to_high(statistic):
+        ax.invert_xaxis()
+
+    plt.savefig(save_path)
+
+
+# Graph all thresholded predictions
+def graph_all_thresholded_predictions(thresholded_predictions, save_path):
+    # Confidence Accuracy
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'confidence',
+        'accuracy',
+        f'{save_path}/confidence_accuracy.png',
+        'Confidence vs. Accuracy',
+        'Confidence',
+        'Accuracy',
+    )
+
+    # Confidence F1
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'confidence',
+        'f1',
+        f'{save_path}/confidence_f1.png',
+        'Confidence vs. F1',
+        'Confidence',
+        'F1',
+    )
+
+    # Entropy Accuracy
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'entropy',
+        'accuracy',
+        f'{save_path}/entropy_accuracy.png',
+        'Entropy vs. Accuracy',
+        'Entropy',
+        'Accuracy',
+    )
+
+    # Entropy F1
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'entropy',
+        'f1',
+        f'{save_path}/entropy_f1.png',
+        'Entropy vs. F1',
+        'Entropy',
+        'F1',
+    )
+
+    # Stdev Accuracy
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'stdev',
+        'accuracy',
+        f'{save_path}/stdev_accuracy.png',
+        'Standard Deviation vs. Accuracy',
+        'Standard Deviation',
+        'Accuracy',
+    )
+
+    # Stdev F1
+    graph_thresholded_prediction(
+        thresholded_predictions,
+        'stdev',
+        'f1',
+        f'{save_path}/stdev_f1.png',
+        'Standard Deviation vs. F1',
+        'Standard Deviation',
+        'F1',
+    )
+
+
+# Graph two statistics against each other
+def graph_statistics(stats, x_stat, y_stat, save_path, title, xlabel, ylabel):
+    # Filter for correct predictions
+    c_stats = stats.where(
+        stats.data_id.isin(np.where((stats.sel(statistic='correct') == 1).values)),
+        drop=True,
+    )
+
+    # Filter for incorrect predictions
+    i_stats = stats.where(
+        stats.data_id.isin(np.where((stats.sel(statistic='correct') == 0).values)),
+        drop=True,
+    )
+
+    # x and y data for correct and incorrect predictions
+    x_data_c = c_stats.sel(statistic=x_stat).values
+    y_data_c = c_stats.sel(statistic=y_stat).values
+
+    x_data_i = i_stats.sel(statistic=x_stat).values
+    y_data_i = i_stats.sel(statistic=y_stat).values
+
+    fig, ax = plt.subplots()
+    ax.plot(x_data_c, y_data_c, 'go', label='Correct')
+    ax.plot(x_data_i, y_data_i, 'ro', label='Incorrect')
+    ax.set_title(title)
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+    ax.legend()
+
+    plt.savefig(save_path)
+
+
+# Prune the data based on excluded data_ids
+def prune_data(data, excluded_data_ids):
+    return data.where(~data.data_id.isin(excluded_data_ids), drop=True)
+
+
+# Calculate individual model statistics
+def compute_individual_statistics(predictions: xr.DataArray):
+    zeros = np.zeros((len(predictions.data_id), len(predictions.model_id), 6))
+
+    indv_statistics = xr.DataArray(
+        zeros,
+        dims=('data_id', 'model_id', 'statistic'),
+        coords={
+            'data_id': predictions.data_id,
+            'model_id': predictions.model_id,
+            'statistic': [
+                'mean',
+                'entropy',
+                'confidence',
+                'correct',
+                'predicted',
+                'actual',
+            ],
+        },
+    )
+
+    for data_id in predictions.data_id:
+        for model_id in predictions.model_id:
+            data = predictions.loc[{'data_id': data_id, 'model_id': model_id}]
+            mean = data[0:2]
+            entropy = (-mean * np.log(mean)).sum()
+            confidence = mean.max()
+            actual = data[3]
+            predicted = mean.argmax()
+            correct = actual == predicted
+
+            indv_statistics.loc[{'data_id': data_id, 'model_id': model_id}] = [
+                mean[1],
+                entropy,
+                confidence,
+                correct,
+                predicted,
+                actual,
+            ]
+
+    return indv_statistics
+
+
+# Compute individual model thresholds
+def compute_individual_thresholds(input_stats: xr.DataArray):
+    quantiles = np.linspace(0.05, 0.95, 19) * 100
+    metrics = ['accuracy', 'f1']
+    statistics = ['entropy', 'confidence']
+
+    zeros = np.zeros(
+        (len(input_stats.model_id), len(quantiles), len(statistics), len(metrics))
+    )
+
+    indv_thresholds = xr.DataArray(
+        zeros,
+        dims=('model_id', 'quantile', 'statistic', 'metric'),
+        coords={
+            'model_id': input_stats.model_id,
+            'quantile': quantiles,
+            'statistic': statistics,
+            'metric': metrics,
+        },
+    )
+
+    for model_id in input_stats.model_id:
+        for statistic in statistics:
+            # First, we must compute the quantiles for the statistic
+            quantile_values = np.percentile(
+                input_stats.sel(model_id=model_id, statistic=statistic).values,
+                quantiles,
+                axis=0,
+            )
+
+            # Then, we must compute the metrics for each quantile
+            for i, quantile in enumerate(quantiles):
+                if low_to_high(statistic):
+                    mask = (
+                        input_stats.sel(model_id=model_id, statistic=statistic)
+                        >= quantile_values[i]
+                    ).values
+                else:
+                    mask = (
+                        input_stats.sel(model_id=model_id, statistic=statistic)
+                        <= quantile_values[i]
+                    ).values
+
+                # Filter the data based on the mask
+                filtered_data = input_stats.where(
+                    input_stats.data_id.isin(np.where(mask)), drop=True
+                )
+
+                for metric in metrics:
+                    indv_thresholds.loc[
+                        {
+                            'model_id': model_id,
+                            'quantile': quantile,
+                            'statistic': statistic,
+                            'metric': metric,
+                        }
+                    ] = compute_metric(filtered_data, metric)
+
+    return indv_thresholds
+
+
+# Graph individual model thresholded predictions
+def graph_individual_thresholded_predictions(
+    indv_thresholds,
+    ensemble_thresholds,
+    statistic,
+    metric,
+    save_path,
+    title,
+    xlabel,
+    ylabel,
+):
+    data = indv_thresholds.sel(statistic=statistic, metric=metric)
+    e_data = ensemble_thresholds.sel(statistic=statistic, metric=metric)
+
+    x_data = data.coords['quantile'].values
+    y_data = data.values
+
+    e_x_data = e_data.coords['quantile'].values
+    e_y_data = e_data.values
+
+    fig, ax = plt.subplots()
+    for model_id in data.coords['model_id'].values:
+        model_data = data.sel(model_id=model_id)
+        ax.plot(x_data, model_data)
+
+    ax.plot(e_x_data, e_y_data, 'kx-', label='Ensemble')
+
+    ax.set_title(title)
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+    ax.xaxis.set_major_formatter(mtick.PercentFormatter())
+
+    if not low_to_high(statistic):
+        ax.invert_xaxis()
+
+    ax.legend()
+    plt.savefig(save_path)
+
+
+# Graph all individual thresholded predictions
+def graph_all_individual_thresholded_predictions(
+    indv_thresholds, ensemble_thresholds, save_path
+):
+    # Confidence Accuracy
+    graph_individual_thresholded_predictions(
+        indv_thresholds,
+        ensemble_thresholds,
+        'confidence',
+        'accuracy',
+        f'{save_path}/indv/confidence_accuracy.png',
+        'Confidence vs. Accuracy',
+        'Confidence Percentile Threshold',
+        'Accuracy',
+    )
+
+    # Confidence F1
+    graph_individual_thresholded_predictions(
+        indv_thresholds,
+        ensemble_thresholds,
+        'confidence',
+        'f1',
+        f'{save_path}/indv/confidence_f1.png',
+        'Confidence vs. F1',
+        'Confidence Percentile Threshold',
+        'F1',
+    )
+
+    # Entropy Accuracy
+    graph_individual_thresholded_predictions(
+        indv_thresholds,
+        ensemble_thresholds,
+        'entropy',
+        'accuracy',
+        f'{save_path}/indv/entropy_accuracy.png',
+        'Entropy vs. Accuracy',
+        'Entropy Percentile Threshold',
+        'Accuracy',
+    )
+
+    # Entropy F1
+    graph_individual_thresholded_predictions(
+        indv_thresholds,
+        ensemble_thresholds,
+        'entropy',
+        'f1',
+        f'{save_path}/indv/entropy_f1.png',
+        'Entropy vs. F1',
+        'Entropy Percentile Threshold',
+        'F1',
+    )
+
+
+# Calculate statistics of subsets of models for sensitivity analysis
+def calculate_subset_statistics(predictions: xr.DataArray):
+    # Calculate subsets for 1-50 models
+    subsets = range(1, len(predictions.model_id) + 1)
+    zeros = np.zeros(
+        (len(predictions.data_id), len(subsets), 7)
+    )  # Include stdev, but for 1 model set to NaN
+
+    subset_stats = xr.DataArray(
+        zeros,
+        dims=('data_id', 'model_count', 'statistic'),
+        coords={
+            'data_id': predictions.data_id,
+            'model_count': subsets,
+            'statistic': [
+                'mean',
+                'stdev',
+                'entropy',
+                'confidence',
+                'correct',
+                'predicted',
+                'actual',
+            ],
+        },
+    )
+
+    for data_id in predictions.data_id:
+        for subset in subsets:
+            data = predictions.sel(
+                data_id=data_id, model_id=predictions.model_id[:subset]
+            )
+            mean = data.mean(dim='model_id')[0:2]
+            stdev = data.std(dim='model_id')[1]
+            entropy = (-mean * np.log(mean)).sum()
+            confidence = mean.max()
+            actual = data.loc[{'prediction_value': 'positive_actual'}][0]
+            predicted = mean.argmax()
+            correct = actual == predicted
+
+            subset_stats.loc[{'data_id': data_id, 'model_count': subset}] = [
+                mean[1],
+                stdev,
+                entropy,
+                confidence,
+                correct,
+                predicted,
+                actual,
+            ]
+
+    return subset_stats
+
+
+# Calculate accuracy, F1, ECE and MCE for subset stats - sensitivity analysis
+def calculate_sensitivity_analysis(subset_stats: xr.DataArray):
+    subsets = subset_stats.model_count
+    stats = ['accuracy', 'f1', 'ECE', 'MCE']
+
+    zeros = np.zeros((len(subsets), len(stats)))
+
+    sens_analysis = xr.DataArray(
+        zeros,
+        dims=('model_count', 'statistic'),
+        coords={'model_count': subsets, 'statistic': stats},
+    )
+
+
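Note that calculate_sensitivity_analysis above allocates sens_analysis but does not yet fill or return it; the commit message flags calibration and sensitivity analysis as work in progress. A minimal sketch of the remaining body, assuming accuracy and F1 reuse compute_metric (ECE/MCE would need calibration helpers this diff does not show):

    for subset in subsets:
        stats_for_subset = subset_stats.sel(model_count=subset)
        for stat in ('accuracy', 'f1'):
            sens_analysis.loc[{'model_count': subset, 'statistic': stat}] = (
                compute_metric(stats_for_subset, stat)
            )
        # ECE and MCE go here once the calibration metrics land
    return sens_analysis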
+# Main Function
 def main():
+    print('Loading Config...')
     config = load_config()
     ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
     V4_PATH = ENSEMBLE_PATH + '/v4'
 
     if not os.path.exists(V4_PATH):
         os.makedirs(V4_PATH)
+    print('Config Loaded')
 
     # Load Datasets
+    print('Loading Datasets...')
     dataset = load_datasets(ENSEMBLE_PATH)
+    print('Datasets Loaded')
 
     # Get Predictions, either by running the models or loading them from a file
     if config['ensemble']['run_models']:
         # Load Models
+        print('Loading Models...')
         device = torch.device(config['training']['device'])
         models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
+        print('Models Loaded')
 
         # Get Predictions
+        print('Getting Predictions...')
         predictions = get_ensemble_predictions(models, dataset, device)
+        print('Predictions Loaded')
 
         # Save Prediction
         predictions.to_netcdf(f'{V4_PATH}/predictions.nc')
@@ -195,10 +647,65 @@ def main():
     else:
         predictions = xr.open_dataarray(f'{V4_PATH}/predictions.nc')
 
+    # Prune Data
+    print('Pruning Data...')
+    if config['operation']['exclude_blank_ids']:
+        excluded_data_ids = config['ensemble']['excluded_ids']
+        predictions = prune_data(predictions, excluded_data_ids)
+
     # Compute Ensemble Statistics
+    print('Computing Ensemble Statistics...')
     ensemble_statistics = compute_ensemble_statistics(predictions)
     ensemble_statistics.to_netcdf(f'{V4_PATH}/ensemble_statistics.nc')
+    print('Ensemble Statistics Computed')
 
     # Compute Thresholded Predictions
+    print('Computing Thresholded Predictions...')
     thresholded_predictions = compute_thresholded_predictions(ensemble_statistics)
     thresholded_predictions.to_netcdf(f'{V4_PATH}/thresholded_predictions.nc')
+    print('Thresholded Predictions Computed')
+
+    # Graph Thresholded Predictions
+    print('Graphing Thresholded Predictions...')
+    graph_all_thresholded_predictions(thresholded_predictions, V4_PATH)
+    print('Thresholded Predictions Graphed')
+
+    # Additional Graphs
+    print('Graphing Additional Graphs...')
+    # Confidence vs stdev
+    graph_statistics(
+        ensemble_statistics,
+        'confidence',
+        'stdev',
+        f'{V4_PATH}/confidence_stdev.png',
+        'Confidence vs. Standard Deviation',
+        'Confidence',
+        'Standard Deviation',
+    )
+    print('Additional Graphs Graphed')
+
+    # Compute Individual Statistics
+    print('Computing Individual Statistics...')
+    indv_statistics = compute_individual_statistics(predictions)
+    indv_statistics.to_netcdf(f'{V4_PATH}/indv_statistics.nc')
+    print('Individual Statistics Computed')
+
+    # Compute Individual Thresholds
+    print('Computing Individual Thresholds...')
+    indv_thresholds = compute_individual_thresholds(indv_statistics)
+    indv_thresholds.to_netcdf(f'{V4_PATH}/indv_thresholds.nc')
+    print('Individual Thresholds Computed')
+
+    # Graph Individual Thresholded Predictions
+    print('Graphing Individual Thresholded Predictions...')
+    if not os.path.exists(f'{V4_PATH}/indv'):
+        os.makedirs(f'{V4_PATH}/indv')
+
+    graph_all_individual_thresholded_predictions(
+        indv_thresholds, thresholded_predictions, V4_PATH
+    )
+    print('Individual Thresholded Predictions Graphed')
+
+
+if __name__ == '__main__':
+    main()