# Rewritten program to use xarray instead of pandas for thresholding
import glob
import os

import numpy as np
import tomli as toml
import torch
import xarray as xr
from tqdm import tqdm

import utils.metrics as met

# The data structures for this file are as follows:
#   models_dict: Dictionary - {model_id: model}
#   predictions: DataArray - (data_id, model_id, prediction_value) - prediction_value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
#   ensemble_statistics: DataArray - (data_id, statistic) - statistic has coords ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
#   thresholded_predictions: DataArray - (quantile, statistic, metric) - metric has coords ['accuracy', 'f1'] - only 'stdev', 'entropy', and 'confidence' are used for statistic
# Additionally, we also have the thresholds and statistics for the individual models:
#   indv_statistics: DataArray - (data_id, model_id, statistic) - statistic has coords ['mean', 'entropy', 'confidence', 'correct', 'predicted', 'actual'] - no stdev, as it cannot be calculated for a single model
#   indv_thresholds: DataArray - (model_id, quantile, statistic, metric) - metric has coords ['accuracy', 'f1'] - only 'entropy' and 'confidence' are used for statistic
# Additionally, we have some for the sensitivity analysis on the number of models:
#   sensitivity_statistics: DataArray - (data_id, model_count, statistic) - statistic has coords ['accuracy', 'f1', 'ECE', 'MCE']
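# As a hedged illustration of the layouts above (the coordinate names match
# the arrays built below; the values are not real output), a single cell of
# thresholded_predictions would be read as:
#   thresholded_predictions.loc[
#       {'quantile': 0.5, 'statistic': 'confidence', 'metric': 'accuracy'}
#   ]
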
# Loads the configuration dictionary, from ADL_CONFIG_PATH if set and from
# ./config.toml otherwise
def load_config():
    if os.getenv('ADL_CONFIG_PATH') is None:
        with open('config.toml', 'rb') as f:
            config = toml.load(f)
    else:
        with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
            config = toml.load(f)

    return config

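# A minimal sketch of the configuration keys this script actually reads;
# the values shown are illustrative assumptions, not a real configuration:
#
#   [paths]
#   model_output = './out/'
#
#   [ensemble]
#   name = 'ensemble_v1'
#   run_models = true
#
#   [training]
#   device = 'cuda'
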
# Loads models into a dictionary keyed by model_id
def load_models_v2(folder, device):
    glob_path = os.path.join(folder, '*.pt')
    model_files = glob.glob(glob_path)
    model_dict = {}

    for model_file in model_files:
        model = torch.load(model_file, map_location=device)
        model_id = os.path.basename(model_file).split('_')[0]
        model_dict[model_id] = model

    if len(model_dict) == 0:
        raise FileNotFoundError(
            'No models found in the specified directory: ' + folder
        )

    return model_dict

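# The model_id is whatever precedes the first underscore in the file name,
# so (as an assumed example) 'models/0_best.pt' would be keyed as '0'.
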
# Ensures that both the mri and xls tensors in the data are unsqueezed to a
# batch of one and are on the correct device
def preprocess_data(data, device):
    mri, xls = data
    mri = mri.unsqueeze(0).to(device)
    xls = xls.unsqueeze(0).to(device)
    return (mri, xls)

# Loads the saved datasets and returns the concatenated test and validation
# datasets
def load_datasets(ensemble_path):
    return torch.load(f'{ensemble_path}/test_dataset.pt') + torch.load(
        f'{ensemble_path}/val_dataset.pt'
    )

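# This relies on the saved objects supporting '+' concatenation, e.g. plain
# lists of samples or torch Datasets (whose '+' yields a ConcatDataset).
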
# Gets the predictions for a set of models on a dataset
def get_ensemble_predictions(models, dataset, device):
    zeros = np.zeros((len(dataset), len(models), 4))
    predictions = xr.DataArray(
        zeros,
        dims=('data_id', 'model_id', 'prediction_value'),
        coords={
            'data_id': range(len(dataset)),
            'model_id': list(models.keys()),
            'prediction_value': [
                'negative_prediction',
                'positive_prediction',
                'negative_actual',
                'positive_actual',
            ],
        },
    )

    for data_id, (data, target) in tqdm(enumerate(dataset), total=len(dataset)):
        mri, xls = preprocess_data(data, device)
        actual = target.cpu().numpy().squeeze().tolist()
        for model_id, model in models.items():
            with torch.no_grad():
                output = model(mri, xls)
            # Flatten the batch-of-one output to its two prediction values
            prediction = output.cpu().numpy().squeeze().tolist()
            predictions.loc[{'data_id': data_id, 'model_id': model_id}] = (
                prediction + actual
            )

    return predictions
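# As an assumed example (not real output), a single slice
# predictions.loc[{'data_id': 0, 'model_id': '0'}] would then hold
# something like [0.2, 0.8, 0.0, 1.0]: the model's two outputs followed by
# the one-hot actual label.
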
# Computes the ensemble statistics given an array of predictions
def compute_ensemble_statistics(predictions):
    zeros = np.zeros((len(predictions.data_id), 7))
    ensemble_statistics = xr.DataArray(
        zeros,
        dims=('data_id', 'statistic'),
        coords={
            'data_id': predictions.data_id,
            'statistic': [
                'mean',
                'stdev',
                'entropy',
                'confidence',
                'correct',
                'predicted',
                'actual',
            ],
        },
    )

    for data_id in predictions.data_id:
        data = predictions.loc[{'data_id': data_id}]

        # Mean of the two class predictions across models, and the standard
        # deviation of the positive prediction across models
        mean = data[:, 0:2].mean(dim='model_id').values
        stdev = data[:, 1].std(dim='model_id').item()
        entropy = -np.sum(mean * np.log2(mean + 1e-12))
        confidence = np.max(mean)

        # The actual label is identical across models, so take it from the
        # first model ('positive_actual' is column 3)
        actual = data[0, 3].item()
        predicted = np.argmax(mean)
        correct = float(actual == predicted)

        # 'mean' is stored as the mean positive prediction across models
        ensemble_statistics.loc[{'data_id': data_id}] = [
            mean[1],
            stdev,
            entropy,
            confidence,
            correct,
            predicted,
            actual,
        ]

    return ensemble_statistics
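# As a hedged usage sketch, the mean positive prediction for the first data
# point would then be read back as:
#   ensemble_statistics.loc[{'data_id': 0, 'statistic': 'mean'}].item()
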
# Computes the thresholded predictions given an array of ensemble statistics
def compute_thresholded_predictions(ensemble_statistics: xr.DataArray):
    quantiles = np.linspace(0.05, 0.95, 19)
    metrics = ['accuracy', 'f1']
    statistics = ['stdev', 'entropy', 'confidence']

    zeros = np.zeros((len(quantiles), len(statistics), len(metrics)))
    thresholded_predictions = xr.DataArray(
        zeros,
        dims=('quantile', 'statistic', 'metric'),
        coords={'quantile': quantiles, 'statistic': statistics, 'metric': metrics},
    )

    for statistic in statistics:
        # First, we must compute the quantiles for the statistic
        quantile_values = np.quantile(
            ensemble_statistics.loc[{'statistic': statistic}].values,
            quantiles,
            axis=0,
        )

        # Then, we must compute the metrics for each quantile
        for i, quantile in enumerate(quantiles):
            if low_to_high(statistic):
                filtered_data = ensemble_statistics.where(
                    ensemble_statistics.loc[{'statistic': statistic}]
                    > quantile_values[i],
                    drop=True,
                )
            else:
                filtered_data = ensemble_statistics.where(
                    ensemble_statistics.loc[{'statistic': statistic}]
                    < quantile_values[i],
                    drop=True,
                )

            for metric in metrics:
                thresholded_predictions.loc[
                    {'quantile': quantile, 'statistic': statistic, 'metric': metric}
                ] = compute_metric(filtered_data, metric)

    return thresholded_predictions
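# As a hedged usage sketch, the quantile giving the best accuracy under
# confidence thresholding could then be found with:
#   thresholded_predictions.sel(
#       statistic='confidence', metric='accuracy'
#   ).idxmax('quantile')
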
# Truth function to determine whether a statistic should be thresholded low
# to high or high to low.
# Low confidence is bad; high entropy and high stdev are bad.
# So we threshold confidence low to high, and entropy and stdev high to low:
# any values BELOW the cutoff are removed for confidence, and any values
# ABOVE the cutoff are removed for entropy and stdev.
def low_to_high(stat):
    return stat in ['confidence']

# Computes a given metric on a DataArray of statistics
def compute_metric(arr, metric):
    if metric == 'accuracy':
        return np.mean(arr.loc[{'statistic': 'correct'}]).item()
    elif metric == 'f1':
        return met.F1(
            arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}]
        )
    else:
        raise ValueError('Invalid metric: ' + metric)

def main():
    config = load_config()

    ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
    V4_PATH = ENSEMBLE_PATH + '/v4'
    if not os.path.exists(V4_PATH):
        os.makedirs(V4_PATH)

    # Load the datasets
    dataset = load_datasets(ENSEMBLE_PATH)

    # Get the predictions, either by running the models or by loading them
    # from a file
    if config['ensemble']['run_models']:
        # Load the models
        device = torch.device(config['training']['device'])
        models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)

        # Get and save the predictions
        predictions = get_ensemble_predictions(models, dataset, device)
        predictions.to_netcdf(f'{V4_PATH}/predictions.nc')
    else:
        predictions = xr.open_dataarray(f'{V4_PATH}/predictions.nc')

    # Compute and save the ensemble statistics
    ensemble_statistics = compute_ensemble_statistics(predictions)
    ensemble_statistics.to_netcdf(f'{V4_PATH}/ensemble_statistics.nc')

    # Compute and save the thresholded predictions
    thresholded_predictions = compute_thresholded_predictions(ensemble_statistics)
    thresholded_predictions.to_netcdf(f'{V4_PATH}/thresholded_predictions.nc')


if __name__ == '__main__':
    main()