Quellcode durchsuchen

Started work on xarray migration from pandas

Nicholas Schense vor 3 Monaten
Ursprung
Commit
ea069c3f1f
3 geänderte Dateien mit 206 neuen und 2 gelöschten Zeilen
  1. 2 1
      config.toml
  2. 0 1
      threshold_refac.py
  3. 204 0
      threshold_xarray.py

+ 2 - 1
config.toml

@@ -28,9 +28,10 @@ droprate = 0.5
 
 [operation]
 silent = false
+exclude_blank_ids = false
 
 [ensemble]
 name = 'cnn-50x30'
 prune_threshold = 0.0 # Any models with accuracy below this threshold will be pruned, set to 0 to disable pruning
 individual_id = 1     # The id of the individual model to be used for the ensemble
-run_models = false    # If true, the ensemble will run the models to generate the predictions, otherwise will load from file
+run_models = true    # If true, the ensemble will run the models to generate the predictions, otherwise will load from file

+ 0 - 1
threshold_refac.py

@@ -350,7 +350,6 @@ def main():
         f'{ENSEMBLE_PATH}/val_dataset.pt'
     )
 
-    dataset = 
 
     if config['ensemble']['run_models']:
         # Get thre predicitons of the ensemble

+ 204 - 0
threshold_xarray.py

@@ -0,0 +1,204 @@
+#Rewritten Program to use xarray instead of pandas for thresholding
+
+import xarray as xr
+import torch 
+import numpy as np
+import os
+import glob
+import tomli as toml
+from tqdm import tqdm
+import utils.metrics as met
+
+if __name__ == '__main__':
+    main()
+
+
+#The datastructures for this file are as follows
+# models_dict: Dictionary - {model_id: model}
+# predictions: DataArray - (data_id, model_id, prediction_value) - Prediction value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
+# ensemble_statistics: DataArray - (data_id, statistic) - Statistic has coords ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
+# thresholded_predictions: DataArray - (quantile, statistic, metric) - Metric has coords ['accuracy, 'f1'] - only use 'stdev', 'entropy', 'confidence' for statistic
+
+#Additionally, we also have the thresholds and statistics for the individual models
+# indv_statistics: DataArray - (data_id, model_id, statistic) - Statistic has coords ['mean', 'entropy', 'confidence', 'correct', 'predicted', 'actual'] - No stdev as it cannot be calculated for a single model
+# indv_thresholds: DataArray - (model_id, quantile, statistic, metric) - Metric has coords ['accuracy', 'f1'] - only use 'entropy', 'confidence' for statistic
+
+#Additionally, we have some for the sensitivity analysis for number of models
+# sensitivity_statistics: DataArray - (data_id, model_count, statistic) - Statistic has coords ['accuracy', 'f1', 'ECE', 'MCE']
+
+# Loads configuration dictionary
+def load_config():
+    if os.getenv('ADL_CONFIG_PATH') is None:
+        with open('config.toml', 'rb') as f:
+            config = toml.load(f)
+    else:
+        with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
+            config = toml.load(f)
+
+    return config
+
+#Loads models into a dictionary
+def load_models_v2(folder, device):
+    glob_path = os.path.join(folder, '*.pt')
+    model_files = glob.glob(glob_path)
+    model_dict = {}
+
+    for model_file in model_files:
+        model = torch.load(model_file, map_location=device)
+        model_id = os.path.basename(model_file).split('_')[0]
+        model_dict[model_id] = model
+
+    if len(model_dict) == 0:
+        raise FileNotFoundError('No models found in the specified directory: ' + folder)
+
+    return model_dict
+
+# Ensures that both mri and xls tensors in the data are unsqueezed and are on the correct device
+def preprocess_data(data, device):
+    mri, xls = data
+    mri = mri.unsqueeze(0).to(device)
+    xls = xls.unsqueeze(0).to(device)
+    return (mri, xls)
+
+# Loads datasets and returns concatenated test and validation datasets
+def load_datasets(ensemble_path):
+    return torch.load(f'{ensemble_path}/test_dataset.pt') + torch.load(
+        f'{ensemble_path}/val_dataset.pt'
+    )
+
+# Gets the predictions for a set of models on a dataset
+def get_ensemble_predictions(models, dataset, device):
+    zeros = np.zeros((len(dataset), len(models), 4))
+    predictions = xr.DataArray(
+        zeros,
+        dims=('data_id', 'model_id', 'prediction_value'),
+        coords={
+            'data_id': range(len(dataset)),
+            'model_id': models.keys(),
+            'prediction_value': ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
+        }
+    )
+
+    for data_id, (data, target) in tqdm(enumerate(dataset)):
+        mri, xls = preprocess_data(data, device)
+        actual = list(target.cpu().numpy())
+        for model_id, model in models.items():
+            with torch.no_grad():
+                output = model(mri, xls)
+                prediction = list(output.cpu().numpy())
+
+                predictions.loc[{ 'data_id': data_id, 'model_id': model_id }] = prediction + actual
+
+    return predictions
+                
+# Compute the ensemble statistics given an array of predictions
+def compute_ensemble_statistics(predictions):
+    zeros = np.zeros((len(predictions.data_id), 7))
+
+    ensemble_statistics = xr.DataArray(
+        zeros,
+        dims=('data_id', 'statistic'),
+        coords={
+            'data_id': predictions.data_id,
+            'statistic': ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
+        }
+    )   
+
+    for data_id in predictions.data_id:
+        data = predictions.loc[{ 'data_id': data_id }]
+        mean = np.mean(data, axis=0)
+        stdev = np.std(data, axis=0)
+        entropy = -np.sum(mean * np.log2(mean + 1e-12))
+        confidence = np.max(mean)
+        
+        actual = data.iloc[:, 3].values
+        predicted = np.argmax(mean)
+        correct = actual == predicted
+
+        ensemble_statistics.loc[{ 'data_id': data_id }] = [mean, stdev, entropy, confidence, correct, predicted, actual]
+
+    return ensemble_statistics
+
+# Compute the thresholded predictions given an array of predictions
+def compute_thresholded_predictions(ensemble_statistics: xr.DataArray):
+    quantiles = np.linspace(0.05, 0.95, 19)
+    metrics = ['accuracy', 'f1']
+    statistics = ['stdev', 'entropy', 'confidence']
+    
+    zeros = np.zeros((len(quantiles), len(statistics), len(metrics)))
+
+    thresholded_predictions = xr.DataArray(
+        zeros,
+        dims=('quantile', 'statistic', 'metric'),
+        coords={
+            'quantile': quantiles,
+            'statistic': statistics,
+            'metric': metrics
+        }
+    )
+
+    for statistic in statistics:
+        #First, we must compute the quantiles for the statistic
+        quantile_values = np.quantiles(ensemble_statistics.loc[{ 'statistic': statistic }].values, quantiles, axis=0)
+
+        #Then, we must compute the metrics for each quantile
+        for i, quantile in enumerate(quantiles):
+            if low_to_high(statistic):
+                filtered_data = ensemble_statistics.where(ensemble_statistics.loc[{ 'statistic': statistic }] > quantile_values[i], drop=True)
+            else:
+                filtered_data = ensemble_statistics.where(ensemble_statistics.loc[{ 'statistic': statistic }] < quantile_values[i], drop=True)
+            for metric in metrics:
+                thresholded_predictions.loc[{ 'quantile': quantile, 'statistic': statistic, 'metric': metric }] = compute_metric(filtered_data, metric)
+    
+    return thresholded_predictions
+                
+# Truth function to determine if metric should be thresholded low to high or high to low
+# Low confidence is bad, high entropy is bad, high stdev is bad
+# So we threshold confidence low to high, entropy and stdev high to low
+# So any values BELOW the cutoff are removed for confidence, and any values ABOVE the cutoff are removed for entropy and stdev
+def low_to_high(stat):
+    return stat in ['confidence']
+
+# Compute a given metric on a DataArray of statstics
+def compute_metric(arr, metric):
+    if metric == 'accuracy':
+        return np.mean(arr.loc[{ 'statistic': 'correct' }])
+    elif metric == 'f1':
+        return met.F1(arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}])
+    else:
+        raise ValueError('Invalid metric: ' + metric)
+
+
+def main():
+    config = load_config()
+    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
+    V4_PATH = ENSEMBLE_PATH + '/v4'
+
+    if not os.path.exists(V4_PATH):
+        os.makedirs(V4_PATH)
+
+    # Load Datasets
+    dataset = load_datasets(ENSEMBLE_PATH)
+
+    # Get Predictions, either by running the models or loading them from a file
+    if config['ensemble']['run_models']:
+        # Load Models
+        device = torch.device(config['training']['device'])
+        models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
+
+        # Get Predictions
+        predictions = get_ensemble_predictions(models, dataset, device)
+
+        # Save Prediction
+        predictions.to_netcdf(f'{V4_PATH}/predictions.nc')
+
+    else:
+        predictions = xr.open_dataarray(f'{V4_PATH}/predictions.nc')
+
+    # Compute Ensemble Statistics
+    ensemble_statistics = compute_ensemble_statistics(predictions)
+    ensemble_statistics.to_netcdf(f'{V4_PATH}/ensemble_statistics.nc')
+
+    # Compute Thresholded Predictions
+    thresholded_predictions = compute_thresholded_predictions(ensemble_statistics)
+    thresholded_predictions.to_netcdf(f'{V4_PATH}/thresholded_predictions.nc')