| 
					
				 | 
			
			
				@@ -1,31 +1,31 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#Rewritten Program to use xarray instead of pandas for thresholding 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Rewritten Program to use xarray instead of pandas for thresholding 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import xarray as xr 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import torch  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import torch 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import numpy as np 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import os 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import glob 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import tomli as toml 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from tqdm import tqdm 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import utils.metrics as met 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import matplotlib.pyplot as plt 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import matplotlib.ticker as mtick 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-if __name__ == '__main__': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    main() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#The datastructures for this file are as follows 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# The data structures for this file are as follows 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # models_dict: Dictionary - {model_id: model} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # predictions: DataArray - (data_id, model_id, prediction_value) - Prediction value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # ensemble_statistics: DataArray - (data_id, statistic) - Statistic has coords ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # thresholded_predictions: DataArray - (quantile, statistic, metric) - Metric has coords ['accuracy', 'f1'] - only use 'stdev', 'entropy', 'confidence' for statistic 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#Additionally, we also have the thresholds and statistics for the individual models 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Additionally, we also have the thresholds and statistics for the individual models 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # indv_statistics: DataArray - (data_id, model_id, statistic) - Statistic has coords ['mean', 'entropy', 'confidence', 'correct', 'predicted', 'actual'] - No stdev as it cannot be calculated for a single model 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # indv_thresholds: DataArray - (model_id, quantile, statistic, metric) - Metric has coords ['accuracy', 'f1'] - only use 'entropy', 'confidence' for statistic 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#Additionally, we have some for the sensitivity analysis for number of models 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Additionally, we have some for the sensitivity analysis for number of models 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # sensitivity_statistics: DataArray - (data_id, model_count, statistic) - Statistic has coords ['accuracy', 'f1', 'ECE', 'MCE'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # Loads configuration dictionary 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def load_config(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     if os.getenv('ADL_CONFIG_PATH') is None: 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -37,7 +37,8 @@ def load_config(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return config 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-#Loads models into a dictionary 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Loads models into a dictionary 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def load_models_v2(folder, device): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     glob_path = os.path.join(folder, '*.pt') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     model_files = glob.glob(glob_path) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -53,6 +54,7 @@ def load_models_v2(folder, device): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return model_dict 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # Ensures that both mri and xls tensors in the data are unsqueezed and are on the correct device 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def preprocess_data(data, device): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     mri, xls = data 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -60,12 +62,14 @@ def preprocess_data(data, device): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     xls = xls.unsqueeze(0).to(device) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return (mri, xls) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # Loads datasets and returns concatenated test and validation datasets 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def load_datasets(ensemble_path): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return torch.load(f'{ensemble_path}/test_dataset.pt') + torch.load( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         f'{ensemble_path}/val_dataset.pt' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # Gets the predictions for a set of models on a dataset 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def get_ensemble_predictions(models, dataset, device): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     zeros = np.zeros((len(dataset), len(models), 4)) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -74,25 +78,35 @@ def get_ensemble_predictions(models, dataset, device): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         dims=('data_id', 'model_id', 'prediction_value'), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         coords={ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             'data_id': range(len(dataset)), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            'model_id': models.keys(), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            'prediction_value': ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'model_id': list(models.keys()), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'prediction_value': [ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'negative_prediction', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'positive_prediction', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'negative_actual', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'positive_actual', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        }, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    for data_id, (data, target) in tqdm(enumerate(dataset)): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        mri, xls = preprocess_data(data, device) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for data_id, (data, target) in tqdm( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        enumerate(dataset), total=len(dataset), unit='images' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        dat = preprocess_data(data, device) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         actual = list(target.cpu().numpy()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         for model_id, model in models.items(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             with torch.no_grad(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                output = model(mri, xls) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                prediction = list(output.cpu().numpy()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                output = model(dat) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                prediction = output.cpu().numpy().tolist()[0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                predictions.loc[{ 'data_id': data_id, 'model_id': model_id }] = prediction + actual 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                predictions.loc[{'data_id': data_id, 'model_id': model_id}] = ( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    prediction + actual 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return predictions 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # Compute the ensemble statistics given an array of predictions 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-def compute_ensemble_statistics(predictions): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def compute_ensemble_statistics(predictions: xr.DataArray): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     zeros = np.zeros((len(predictions.data_id), 7)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     ensemble_statistics = xr.DataArray( 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -100,58 +114,93 @@ def compute_ensemble_statistics(predictions): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         dims=('data_id', 'statistic'), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         coords={ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             'data_id': predictions.data_id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            'statistic': ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    )    
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'statistic': [ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'mean', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'stdev', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'entropy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'confidence', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'correct', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'predicted', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'actual', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        }, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     for data_id in predictions.data_id: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        data = predictions.loc[{ 'data_id': data_id }] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        mean = np.mean(data, axis=0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        stdev = np.std(data, axis=0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        entropy = -np.sum(mean * np.log2(mean + 1e-12)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        confidence = np.max(mean) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        actual = data.iloc[:, 3].values 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        predicted = np.argmax(mean) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        data = predictions.loc[{'data_id': data_id}] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        mean = data.mean(dim='model_id')[ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            0:2 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ]  # Only take the predictions, not the actual 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        stdev = data.std(dim='model_id')[ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            1 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ]  # Only need the standard deviation of the positive prediction 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        entropy = (-mean * np.log(mean)).sum() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Compute confidence 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        confidence = mean.max() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Only need one of the actual values, since they are all the same; just take the first 'positive_actual' entry 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        actual = data.loc[{'prediction_value': 'positive_actual'}][0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        predicted = mean.argmax() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         correct = actual == predicted 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ensemble_statistics.loc[{ 'data_id': data_id }] = [mean, stdev, entropy, confidence, correct, predicted, actual] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ensemble_statistics.loc[{'data_id': data_id}] = [ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            mean[1], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            stdev, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            entropy, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            confidence, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            correct, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            predicted, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            actual, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return ensemble_statistics 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # Compute the thresholded predictions given an array of predictions 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-def compute_thresholded_predictions(ensemble_statistics: xr.DataArray): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    quantiles = np.linspace(0.05, 0.95, 19) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def compute_thresholded_predictions(input_stats: xr.DataArray): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    quantiles = np.linspace(0.05, 0.95, 19) * 100 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     metrics = ['accuracy', 'f1'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     statistics = ['stdev', 'entropy', 'confidence'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     zeros = np.zeros((len(quantiles), len(statistics), len(metrics))) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     thresholded_predictions = xr.DataArray( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         zeros, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         dims=('quantile', 'statistic', 'metric'), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        coords={ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            'quantile': quantiles, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            'statistic': statistics, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            'metric': metrics 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        coords={'quantile': quantiles, 'statistic': statistics, 'metric': metrics}, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     for statistic in statistics: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #First, we must compute the quantiles for the statistic 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        quantile_values = np.quantiles(ensemble_statistics.loc[{ 'statistic': statistic }].values, quantiles, axis=0) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # First, we must compute the quantiles for the statistic 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        quantile_values = np.percentile( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            input_stats.sel(statistic=statistic).values, quantiles, axis=0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #Then, we must compute the metrics for each quantile 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # Then, we must compute the metrics for each quantile 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         for i, quantile in enumerate(quantiles): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if low_to_high(statistic): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                filtered_data = ensemble_statistics.where(ensemble_statistics.loc[{ 'statistic': statistic }] > quantile_values[i], drop=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                mask = ( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    input_stats.sel(statistic=statistic) >= quantile_values[i] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                ).values 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                filtered_data = ensemble_statistics.where(ensemble_statistics.loc[{ 'statistic': statistic }] < quantile_values[i], drop=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                mask = ( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    input_stats.sel(statistic=statistic) <= quantile_values[i] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                ).values 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # Filter the data based on the mask 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            filtered_data = input_stats.where( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                input_stats.data_id.isin(np.where(mask)), drop=True 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             for metric in metrics: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                thresholded_predictions.loc[{ 'quantile': quantile, 'statistic': statistic, 'metric': metric }] = compute_metric(filtered_data, metric) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                thresholded_predictions.loc[ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    {'quantile': quantile, 'statistic': statistic, 'metric': metric} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                ] = compute_metric(filtered_data, metric) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return thresholded_predictions 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # Truth function to determine if metric should be thresholded low to high or high to low 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # Low confidence is bad, high entropy is bad, high stdev is bad 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # So we threshold confidence low to high, entropy and stdev high to low 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -159,35 +208,438 @@ def compute_thresholded_predictions(ensemble_statistics: xr.DataArray): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def low_to_high(stat): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return stat in ['confidence'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 # Compute a given metric on a DataArray of statistics 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def compute_metric(arr, metric): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     if metric == 'accuracy': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return np.mean(arr.loc[{ 'statistic': 'correct' }]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return np.mean(arr.loc[{'statistic': 'correct'}]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     elif metric == 'f1': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        return met.F1(arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return met.F1( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         raise ValueError('Invalid metric: ' + metric) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Graph a thresholded prediction for a given statistic and metric 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def graph_thresholded_prediction( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    thresholded_predictions, statistic, metric, save_path, title, xlabel, ylabel 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    data = thresholded_predictions.sel(statistic=statistic, metric=metric) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    x_data = data.coords['quantile'].values 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    y_data = data.values 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    fig, ax = plt.subplots() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.plot(x_data, y_data, 'bx-', label='Ensemble') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.set_title(title) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.set_xlabel(xlabel) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.set_ylabel(ylabel) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.xaxis.set_major_formatter(mtick.PercentFormatter()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if not low_to_high(statistic): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ax.invert_xaxis() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    plt.savefig(save_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Graph all thresholded predictions 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def graph_all_thresholded_predictions(thresholded_predictions, save_path): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Confidence Accuracy 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    graph_thresholded_prediction( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        thresholded_predictions, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'confidence', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'accuracy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        f'{save_path}/confidence_accuracy.png', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Confidence vs. Accuracy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Confidence', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Accuracy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Confidence F1 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    graph_thresholded_prediction( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        thresholded_predictions, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'confidence', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'f1', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        f'{save_path}/confidence_f1.png', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Confidence vs. F1', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Confidence', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'F1', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Entropy Accuracy 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    graph_thresholded_prediction( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        thresholded_predictions, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'entropy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'accuracy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        f'{save_path}/entropy_accuracy.png', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Entropy vs. Accuracy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Entropy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Accuracy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Entropy F1 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    graph_thresholded_prediction( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        thresholded_predictions, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'entropy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'f1', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        f'{save_path}/entropy_f1.png', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Entropy vs. F1', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Entropy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'F1', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Stdev Accuracy 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    graph_thresholded_prediction( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        thresholded_predictions, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'stdev', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'accuracy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        f'{save_path}/stdev_accuracy.png', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Standard Deviation vs. Accuracy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Standard Deviation', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Accuracy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Stdev F1 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    graph_thresholded_prediction( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        thresholded_predictions, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'stdev', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'f1', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        f'{save_path}/stdev_f1.png', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Standard Deviation vs. F1', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Standard Deviation', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'F1', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Graph two statistics against each other 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def graph_statistics(stats, x_stat, y_stat, save_path, title, xlabel, ylabel): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Filter for correct predictions 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    c_stats = stats.where( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        stats.data_id.isin(np.where((stats.sel(statistic='correct') == 1).values)), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        drop=True, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Filter for incorrect predictions 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    i_stats = stats.where( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        stats.data_id.isin(np.where((stats.sel(statistic='correct') == 0).values)), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        drop=True, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # x and y data for correct and incorrect predictions 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    x_data_c = c_stats.sel(statistic=x_stat).values 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    y_data_c = c_stats.sel(statistic=y_stat).values 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    x_data_i = i_stats.sel(statistic=x_stat).values 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    y_data_i = i_stats.sel(statistic=y_stat).values 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    fig, ax = plt.subplots() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.plot(x_data_c, y_data_c, 'go', label='Correct') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.plot(x_data_i, y_data_i, 'ro', label='Incorrect') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.set_title(title) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.set_xlabel(xlabel) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.set_ylabel(ylabel) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.legend() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    plt.savefig(save_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Prune the data based on excluded data_ids 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def prune_data(data, excluded_data_ids): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return data.where(~data.data_id.isin(excluded_data_ids), drop=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Calculate individual model statistics 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def compute_individual_statistics(predictions: xr.DataArray): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    zeros = np.zeros((len(predictions.data_id), len(predictions.model_id), 6)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    indv_statistics = xr.DataArray( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        zeros, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        dims=('data_id', 'model_id', 'statistic'), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        coords={ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'data_id': predictions.data_id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'model_id': predictions.model_id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'statistic': [ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'mean', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'entropy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'confidence', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'correct', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'predicted', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'actual', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        }, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for data_id in predictions.data_id: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for model_id in predictions.model_id: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            data = predictions.loc[{'data_id': data_id, 'model_id': model_id}] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            mean = data[0:2] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            entropy = (-mean * np.log(mean)).sum() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            confidence = mean.max() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            actual = data[3] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            predicted = mean.argmax() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            correct = actual == predicted 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            indv_statistics.loc[{'data_id': data_id, 'model_id': model_id}] = [ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                mean[1], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                entropy, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                confidence, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                correct, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                predicted, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                actual, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return indv_statistics 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Compute individual model thresholds 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def compute_individual_thresholds(input_stats: xr.DataArray): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    quantiles = np.linspace(0.05, 0.95, 19) * 100 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    metrics = ['accuracy', 'f1'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    statistics = ['entropy', 'confidence'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    zeros = np.zeros( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        (len(input_stats.model_id), len(quantiles), len(statistics), len(metrics)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    indv_thresholds = xr.DataArray( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        zeros, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        dims=('model_id', 'quantile', 'statistic', 'metric'), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        coords={ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'model_id': input_stats.model_id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'quantile': quantiles, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'statistic': statistics, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'metric': metrics, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        }, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for model_id in input_stats.model_id: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for statistic in statistics: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # First, we must compute the quantiles for the statistic 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            quantile_values = np.percentile( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                input_stats.sel(model_id=model_id, statistic=statistic).values, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                quantiles, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                axis=0, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # Then, we must compute the metrics for each quantile 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for i, quantile in enumerate(quantiles): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if low_to_high(statistic): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    mask = ( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        input_stats.sel(model_id=model_id, statistic=statistic) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        >= quantile_values[i] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    ).values 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    mask = ( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        input_stats.sel(model_id=model_id, statistic=statistic) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        <= quantile_values[i] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    ).values 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                # Filter the data based on the mask 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                filtered_data = input_stats.where( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    input_stats.data_id.isin(np.where(mask)), drop=True 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                for metric in metrics: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    indv_thresholds.loc[ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            'model_id': model_id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            'quantile': quantile, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            'statistic': statistic, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            'metric': metric, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    ] = compute_metric(filtered_data, metric) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return indv_thresholds 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Graph individual model thresholded predictions 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def graph_individual_thresholded_predictions( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    indv_thresholds, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ensemble_thresholds, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    statistic, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    metric, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    save_path, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    title, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    xlabel, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ylabel, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    data = indv_thresholds.sel(statistic=statistic, metric=metric) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    e_data = ensemble_thresholds.sel(statistic=statistic, metric=metric) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    x_data = data.coords['quantile'].values 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    y_data = data.values 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    e_x_data = e_data.coords['quantile'].values 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    e_y_data = e_data.values 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    fig, ax = plt.subplots() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for model_id in data.coords['model_id'].values: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        model_data = data.sel(model_id=model_id) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ax.plot(x_data, model_data) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.plot(e_x_data, e_y_data, 'kx-', label='Ensemble') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.set_title(title) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.set_xlabel(xlabel) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.set_ylabel(ylabel) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.xaxis.set_major_formatter(mtick.PercentFormatter()) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if not low_to_high(statistic): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ax.invert_xaxis() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ax.legend() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    plt.savefig(save_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Graph all individual thresholded predictions 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def graph_all_individual_thresholded_predictions( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    indv_thresholds, ensemble_thresholds, save_path 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Confidence Accuracy 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    graph_individual_thresholded_predictions( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        indv_thresholds, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ensemble_thresholds, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'confidence', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'accuracy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        f'{save_path}/indv/confidence_accuracy.png', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Confidence vs. Accuracy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Confidence Percentile Threshold', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Accuracy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Confidence F1 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    graph_individual_thresholded_predictions( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        indv_thresholds, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ensemble_thresholds, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'confidence', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'f1', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        f'{save_path}/indv/confidence_f1.png', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Confidence vs. F1', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Confidence Percentile Threshold', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'F1', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Entropy Accuracy 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    graph_individual_thresholded_predictions( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        indv_thresholds, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ensemble_thresholds, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'entropy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'accuracy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        f'{save_path}/indv/entropy_accuracy.png', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Entropy vs. Accuracy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Entropy Percentile Threshold', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Accuracy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Entropy F1 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    graph_individual_thresholded_predictions( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        indv_thresholds, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ensemble_thresholds, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'entropy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'f1', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        f'{save_path}/indv/entropy_f1.png', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Entropy vs. F1', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Entropy Percentile Threshold', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'F1', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Calculate statistics of subsets of models for sensitivity analysis 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def calculate_subset_statistics(predictions: xr.DataArray): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Calculate subsets for 1-50 models 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    subsets = range(1, len(predictions.model_id) + 1) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    zeros = np.zeros( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        (len(predictions.data_id), len(subsets), 7) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    )  # Include stdev, but for 1 models set to NaN 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    subset_stats = xr.DataArray( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        zeros, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        dims=('data_id', 'model_count', 'statistic'), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        coords={ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'data_id': predictions.data_id, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'model_count': subsets, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            'statistic': [ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'mean', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'stdev', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'entropy', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'confidence', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'correct', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'predicted', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                'actual', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        }, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for data_id in predictions.data_id: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for subset in subsets: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            data = predictions.sel( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                data_id=data_id, model_id=predictions.model_id[:subset] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            mean = data.mean(dim='model_id')[0:2] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            stdev = data.std(dim='model_id')[1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            entropy = (-mean * np.log(mean)).sum() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            confidence = mean.max() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            actual = data[3] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            predicted = mean.argmax() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            correct = actual == predicted 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            subset_stats.loc[{'data_id': data_id, 'model_count': subset}] = [ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                mean[1], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                stdev, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                entropy, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                confidence, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                correct, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                predicted, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                actual, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return subset_stats 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Calculate Accuracy, F1 and ECE for subset stats - sensitivity analysis 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def calculate_sensitivity_analysis(subset_stats: xr.DataArray): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    subsets = subset_stats.subsets 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    stats = ['accuracy', 'f1', 'ECE', 'MCE'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    zeros = np.zeros((len(subsets), len(stats))) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    sens_analysis = xr.DataArray( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        zeros, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        dims=('model_count', 'statistic'), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        coords={'model_count': subsets, 'statistic': ['accuracy', 'f1', 'ECE', 'MCE']}, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# Main Function 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def main(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Loading Config...') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     config = load_config() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     V4_PATH = ENSEMBLE_PATH + '/v4' 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     if not os.path.exists(V4_PATH): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         os.makedirs(V4_PATH) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Config Loaded') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     # Load Datasets 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Loading Datasets...') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     dataset = load_datasets(ENSEMBLE_PATH) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Datasets Loaded') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     # Get Predictions, either by running the models or loading them from a file 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     if config['ensemble']['run_models']: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # Load Models 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        print('Loading Models...') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         device = torch.device(config['training']['device']) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        print('Models Loaded') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # Get Predictions 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        print('Getting Predictions...') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         predictions = get_ensemble_predictions(models, dataset, device) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        print('Predictions Loaded') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # Save Prediction 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         predictions.to_netcdf(f'{V4_PATH}/predictions.nc') 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -195,10 +647,65 @@ def main(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         predictions = xr.open_dataarray(f'{V4_PATH}/predictions.nc') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Prune Data 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Pruning Data...') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if config['operation']['exclude_blank_ids']: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        excluded_data_ids = config['ensemble']['excluded_ids'] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        predictions = prune_data(predictions, excluded_data_ids) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     # Compute Ensemble Statistics 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Computing Ensemble Statistics...') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     ensemble_statistics = compute_ensemble_statistics(predictions) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     ensemble_statistics.to_netcdf(f'{V4_PATH}/ensemble_statistics.nc') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Ensemble Statistics Computed') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     # Compute Thresholded Predictions 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Computing Thresholded Predictions...') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     thresholded_predictions = compute_thresholded_predictions(ensemble_statistics) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     thresholded_predictions.to_netcdf(f'{V4_PATH}/thresholded_predictions.nc') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Thresholded Predictions Computed') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Graph Thresholded Predictions 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Graphing Thresholded Predictions...') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    graph_all_thresholded_predictions(thresholded_predictions, V4_PATH) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Thresholded Predictions Graphed') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Additional Graphs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Graphing Additional Graphs...') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Confidence vs stdev 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    graph_statistics( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ensemble_statistics, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'confidence', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'stdev', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        f'{V4_PATH}/confidence_stdev.png', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Confidence vs. Standard Deviation', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Confidence', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        'Standard Deviation', 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Additional Graphs Graphed') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Compute Individual Statistics 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Computing Individual Statistics...') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    indv_statistics = compute_individual_statistics(predictions) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    indv_statistics.to_netcdf(f'{V4_PATH}/indv_statistics.nc') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Individual Statistics Computed') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Compute Individual Thresholds 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Computing Individual Thresholds...') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    indv_thresholds = compute_individual_thresholds(indv_statistics) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    indv_thresholds.to_netcdf(f'{V4_PATH}/indv_thresholds.nc') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Individual Thresholds Computed') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Graph Individual Thresholded Predictions 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Graphing Individual Thresholded Predictions...') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    if not os.path.exists(f'{V4_PATH}/indv'): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        os.makedirs(f'{V4_PATH}/indv') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    graph_all_individual_thresholded_predictions( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        indv_thresholds, thresholded_predictions, V4_PATH 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print('Individual Thresholded Predictions Graphed') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
			 |