# Rewritten program to use xarray instead of pandas for thresholding

import xarray as xr
import torch
import numpy as np
import os
import glob
import tomli as toml
from tqdm import tqdm
import utils.metrics as met

if __name__ == '__main__':
    main()

# The data structures for this file are as follows
# models_dict: Dictionary - {model_id: model}
# predictions: DataArray - (data_id, model_id, prediction_value) - prediction_value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
# ensemble_statistics: DataArray - (data_id, statistic) - statistic has coords ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
# thresholded_predictions: DataArray - (quantile, statistic, metric) - metric has coords ['accuracy', 'f1'] - only use 'stdev', 'entropy', 'confidence' for statistic

# Additionally, we have the thresholds and statistics for the individual models
# indv_statistics: DataArray - (data_id, model_id, statistic) - statistic has coords ['mean', 'entropy', 'confidence', 'correct', 'predicted', 'actual'] - no stdev, as it cannot be calculated for a single model
# indv_thresholds: DataArray - (model_id, quantile, statistic, metric) - metric has coords ['accuracy', 'f1'] - only use 'entropy', 'confidence' for statistic

# We also have data structures for the sensitivity analysis over the number of models
# sensitivity_statistics: DataArray - (data_id, model_count, statistic) - statistic has coords ['accuracy', 'f1', 'ECE', 'MCE']

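# Illustration only (not used by the pipeline): a toy `predictions` DataArray
# built with the schema documented above, showing how a single value is read
# back with .loc. The sizes and the model ids ('a', 'b') are made up.
def _example_predictions_schema():
    toy = xr.DataArray(
        np.zeros((2, 2, 4)),
        dims=('data_id', 'model_id', 'prediction_value'),
        coords={
            'data_id': range(2),
            'model_id': ['a', 'b'],
            'prediction_value': [
                'negative_prediction',
                'positive_prediction',
                'negative_actual',
                'positive_actual',
            ],
        },
    )
    # positive-class output of model 'a' for every data point
    return toy.loc[{'model_id': 'a', 'prediction_value': 'positive_prediction'}]
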
# Loads configuration dictionary
def load_config():
    if os.getenv('ADL_CONFIG_PATH') is None:
        with open('config.toml', 'rb') as f:
            config = toml.load(f)
    else:
        with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
            config = toml.load(f)

    return config

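# Illustration only: the config keys this script reads (see main() below),
# written as the Python dict that toml.load() would return. The values shown
# here are placeholders, not the project's real settings.
def _example_config():
    return {
        'paths': {'model_output': './out/'},
        'ensemble': {'name': 'ensemble_1', 'run_models': True},
        'training': {'device': 'cuda'},
    }
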
# Loads models into a dictionary keyed by model id
def load_models_v2(folder, device):
    glob_path = os.path.join(folder, '*.pt')
    model_files = glob.glob(glob_path)
    model_dict = {}

    for model_file in model_files:
        model = torch.load(model_file, map_location=device)
        model.eval()  # inference only: disable dropout / use running batch-norm stats
        model_id = os.path.basename(model_file).split('_')[0]
        model_dict[model_id] = model

    if len(model_dict) == 0:
        raise FileNotFoundError('No models found in the specified directory: ' + folder)

    return model_dict

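# Illustration only: with the naming scheme above, a checkpoint saved as
# '10_model.pt' (a hypothetical filename) is loaded under model_id '10'.
def _example_model_id():
    return os.path.basename('10_model.pt').split('_')[0]  # -> '10'
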
# Ensures that both mri and xls tensors in the data are unsqueezed and are on the correct device
def preprocess_data(data, device):
    mri, xls = data
    mri = mri.unsqueeze(0).to(device)
    xls = xls.unsqueeze(0).to(device)
    return (mri, xls)


# Loads datasets and returns concatenated test and validation datasets
def load_datasets(ensemble_path):
    return torch.load(f'{ensemble_path}/test_dataset.pt') + torch.load(
        f'{ensemble_path}/val_dataset.pt'
    )

# Gets the predictions for a set of models on a dataset
def get_ensemble_predictions(models, dataset, device):
    zeros = np.zeros((len(dataset), len(models), 4))
    predictions = xr.DataArray(
        zeros,
        dims=('data_id', 'model_id', 'prediction_value'),
        coords={
            'data_id': range(len(dataset)),
            'model_id': list(models.keys()),
            'prediction_value': [
                'negative_prediction',
                'positive_prediction',
                'negative_actual',
                'positive_actual',
            ],
        },
    )

    for data_id, (data, target) in tqdm(enumerate(dataset), total=len(dataset)):
        mri, xls = preprocess_data(data, device)
        actual = list(target.cpu().numpy())
        for model_id, model in models.items():
            with torch.no_grad():
                output = model(mri, xls)
                # flatten the (1, 2) batch output into [negative, positive]
                prediction = output.cpu().numpy().flatten().tolist()

                predictions.loc[{'data_id': data_id, 'model_id': model_id}] = (
                    prediction + actual
                )

    return predictions

# Compute the ensemble statistics given an array of predictions
def compute_ensemble_statistics(predictions):
    zeros = np.zeros((len(predictions.data_id), 7))

    ensemble_statistics = xr.DataArray(
        zeros,
        dims=('data_id', 'statistic'),
        coords={
            'data_id': predictions.data_id,
            'statistic': [
                'mean',
                'stdev',
                'entropy',
                'confidence',
                'correct',
                'predicted',
                'actual',
            ],
        },
    )

    for data_id in predictions.data_id:
        data = predictions.loc[{'data_id': data_id}]

        # Columns 0-1 hold the model outputs, columns 2-3 the (identical) labels.
        # Mean class probabilities across the ensemble:
        mean = data[:, 0:2].mean(dim='model_id').values
        # Ensemble disagreement, summarised on the positive-class output:
        stdev = data[:, 1].std(dim='model_id').values
        entropy = -np.sum(mean * np.log2(mean + 1e-12))
        confidence = np.max(mean)

        actual = data[:, 3].values[0]
        predicted = np.argmax(mean)
        correct = float(actual == predicted)

        # Each statistic slot holds a scalar; 'mean' and 'stdev' are summarised
        # by the positive-class prediction.
        ensemble_statistics.loc[{'data_id': data_id}] = [
            mean[1],
            stdev,
            entropy,
            confidence,
            correct,
            predicted,
            actual,
        ]

    return ensemble_statistics

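# Illustration only: the uncertainty statistics for a toy mean probability
# vector. With mean = [0.2, 0.8], entropy = -(0.2*log2(0.2) + 0.8*log2(0.8))
# ≈ 0.72 bits and confidence = max(mean) = 0.8.
def _example_uncertainty_statistics():
    mean = np.array([0.2, 0.8])
    entropy = -np.sum(mean * np.log2(mean + 1e-12))
    confidence = np.max(mean)
    return entropy, confidence
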
# Compute the thresholded predictions given an array of ensemble statistics
def compute_thresholded_predictions(ensemble_statistics: xr.DataArray):
    quantiles = np.linspace(0.05, 0.95, 19)
    metrics = ['accuracy', 'f1']
    statistics = ['stdev', 'entropy', 'confidence']

    zeros = np.zeros((len(quantiles), len(statistics), len(metrics)))

    thresholded_predictions = xr.DataArray(
        zeros,
        dims=('quantile', 'statistic', 'metric'),
        coords={
            'quantile': quantiles,
            'statistic': statistics,
            'metric': metrics,
        },
    )

    for statistic in statistics:
        # First, compute the quantile cutoffs for this statistic
        quantile_values = np.quantile(
            ensemble_statistics.loc[{'statistic': statistic}].values, quantiles, axis=0
        )

        # Then compute the metrics on the data retained at each cutoff
        for i, quantile in enumerate(quantiles):
            if low_to_high(statistic):
                filtered_data = ensemble_statistics.where(
                    ensemble_statistics.loc[{'statistic': statistic}] > quantile_values[i],
                    drop=True,
                )
            else:
                filtered_data = ensemble_statistics.where(
                    ensemble_statistics.loc[{'statistic': statistic}] < quantile_values[i],
                    drop=True,
                )
            for metric in metrics:
                thresholded_predictions.loc[
                    {'quantile': quantile, 'statistic': statistic, 'metric': metric}
                ] = compute_metric(filtered_data, metric)

    return thresholded_predictions

# Predicate to determine whether a statistic should be thresholded low-to-high or high-to-low.
# Low confidence is bad; high entropy and high stdev are bad.
# So we threshold confidence low to high, and entropy and stdev high to low:
# any values BELOW the cutoff are removed for confidence, and any values ABOVE the cutoff are removed for entropy and stdev.
def low_to_high(stat):
    return stat in ['confidence']

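# Illustration only (toy numbers): with confidences [0.6, 0.7, 0.8, 0.9] the
# 0.25 quantile cutoff is 0.675, and because confidence is thresholded
# low-to-high, only the samples above that cutoff ([0.7, 0.8, 0.9]) are kept.
def _example_confidence_threshold():
    confidences = np.array([0.6, 0.7, 0.8, 0.9])
    cutoff = np.quantile(confidences, 0.25)  # 0.675
    kept = confidences[confidences > cutoff]  # array([0.7, 0.8, 0.9])
    return cutoff, kept
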
# Compute a given metric on a DataArray of statistics
def compute_metric(arr, metric):
    if metric == 'accuracy':
        return np.mean(arr.loc[{'statistic': 'correct'}])
    elif metric == 'f1':
        return met.F1(arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}])
    else:
        raise ValueError('Invalid metric: ' + metric)


def main():
    config = load_config()
    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
    V4_PATH = ENSEMBLE_PATH + '/v4'

    if not os.path.exists(V4_PATH):
        os.makedirs(V4_PATH)

    # Load Datasets
    dataset = load_datasets(ENSEMBLE_PATH)

    # Get Predictions, either by running the models or loading them from a file
    if config['ensemble']['run_models']:
        # Load Models
        device = torch.device(config['training']['device'])
        models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)

        # Get Predictions
        predictions = get_ensemble_predictions(models, dataset, device)

        # Save Predictions
        predictions.to_netcdf(f'{V4_PATH}/predictions.nc')

    else:
        predictions = xr.open_dataarray(f'{V4_PATH}/predictions.nc')

    # Compute Ensemble Statistics
    ensemble_statistics = compute_ensemble_statistics(predictions)
    ensemble_statistics.to_netcdf(f'{V4_PATH}/ensemble_statistics.nc')

    # Compute Thresholded Predictions
    thresholded_predictions = compute_thresholded_predictions(ensemble_statistics)
    thresholded_predictions.to_netcdf(f'{V4_PATH}/thresholded_predictions.nc')