threshold_xarray.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. #Rewritten Program to use xarray instead of pandas for thresholding
  2. import xarray as xr
  3. import torch
  4. import numpy as np
  5. import os
  6. import glob
  7. import tomli as toml
  8. from tqdm import tqdm
  9. import utils.metrics as met
# Script entry point: run the pipeline only when this file is executed directly.
# NOTE(review): `main` is defined further down in this file, so calling it at
# this point raises NameError when the module is run as a script. This guard
# should be moved below the definition of `main` (conventionally the last
# lines of the file).
if __name__ == '__main__':
    main()
  12. #The datastructures for this file are as follows
  13. # models_dict: Dictionary - {model_id: model}
  14. # predictions: DataArray - (data_id, model_id, prediction_value) - Prediction value has coords ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
  15. # ensemble_statistics: DataArray - (data_id, statistic) - Statistic has coords ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
  16. # thresholded_predictions: DataArray - (quantile, statistic, metric) - Metric has coords ['accuracy', 'f1'] - only use 'stdev', 'entropy', 'confidence' for statistic
  17. #Additionally, we also have the thresholds and statistics for the individual models
  18. # indv_statistics: DataArray - (data_id, model_id, statistic) - Statistic has coords ['mean', 'entropy', 'confidence', 'correct', 'predicted', 'actual'] - No stdev as it cannot be calculated for a single model
  19. # indv_thresholds: DataArray - (model_id, quantile, statistic, metric) - Metric has coords ['accuracy', 'f1'] - only use 'entropy', 'confidence' for statistic
  20. #Additionally, we have some for the sensitivity analysis for number of models
  21. # sensitivity_statistics: DataArray - (data_id, model_count, statistic) - Statistic has coords ['accuracy', 'f1', 'ECE', 'MCE']
  22. # Loads configuration dictionary
  23. def load_config():
  24. if os.getenv('ADL_CONFIG_PATH') is None:
  25. with open('config.toml', 'rb') as f:
  26. config = toml.load(f)
  27. else:
  28. with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
  29. config = toml.load(f)
  30. return config
  31. #Loads models into a dictionary
  32. def load_models_v2(folder, device):
  33. glob_path = os.path.join(folder, '*.pt')
  34. model_files = glob.glob(glob_path)
  35. model_dict = {}
  36. for model_file in model_files:
  37. model = torch.load(model_file, map_location=device)
  38. model_id = os.path.basename(model_file).split('_')[0]
  39. model_dict[model_id] = model
  40. if len(model_dict) == 0:
  41. raise FileNotFoundError('No models found in the specified directory: ' + folder)
  42. return model_dict
  43. # Ensures that both mri and xls tensors in the data are unsqueezed and are on the correct device
  44. def preprocess_data(data, device):
  45. mri, xls = data
  46. mri = mri.unsqueeze(0).to(device)
  47. xls = xls.unsqueeze(0).to(device)
  48. return (mri, xls)
  49. # Loads datasets and returns concatenated test and validation datasets
  50. def load_datasets(ensemble_path):
  51. return torch.load(f'{ensemble_path}/test_dataset.pt') + torch.load(
  52. f'{ensemble_path}/val_dataset.pt'
  53. )
  54. # Gets the predictions for a set of models on a dataset
  55. def get_ensemble_predictions(models, dataset, device):
  56. zeros = np.zeros((len(dataset), len(models), 4))
  57. predictions = xr.DataArray(
  58. zeros,
  59. dims=('data_id', 'model_id', 'prediction_value'),
  60. coords={
  61. 'data_id': range(len(dataset)),
  62. 'model_id': models.keys(),
  63. 'prediction_value': ['negative_prediction', 'positive_prediction', 'negative_actual', 'positive_actual']
  64. }
  65. )
  66. for data_id, (data, target) in tqdm(enumerate(dataset)):
  67. mri, xls = preprocess_data(data, device)
  68. actual = list(target.cpu().numpy())
  69. for model_id, model in models.items():
  70. with torch.no_grad():
  71. output = model(mri, xls)
  72. prediction = list(output.cpu().numpy())
  73. predictions.loc[{ 'data_id': data_id, 'model_id': model_id }] = prediction + actual
  74. return predictions
  75. # Compute the ensemble statistics given an array of predictions
  76. def compute_ensemble_statistics(predictions):
  77. zeros = np.zeros((len(predictions.data_id), 7))
  78. ensemble_statistics = xr.DataArray(
  79. zeros,
  80. dims=('data_id', 'statistic'),
  81. coords={
  82. 'data_id': predictions.data_id,
  83. 'statistic': ['mean', 'stdev', 'entropy', 'confidence', 'correct', 'predicted', 'actual']
  84. }
  85. )
  86. for data_id in predictions.data_id:
  87. data = predictions.loc[{ 'data_id': data_id }]
  88. mean = np.mean(data, axis=0)
  89. stdev = np.std(data, axis=0)
  90. entropy = -np.sum(mean * np.log2(mean + 1e-12))
  91. confidence = np.max(mean)
  92. actual = data.iloc[:, 3].values
  93. predicted = np.argmax(mean)
  94. correct = actual == predicted
  95. ensemble_statistics.loc[{ 'data_id': data_id }] = [mean, stdev, entropy, confidence, correct, predicted, actual]
  96. return ensemble_statistics
  97. # Compute the thresholded metrics (accuracy, f1) at each quantile cutoff, given the ensemble statistics
  98. def compute_thresholded_predictions(ensemble_statistics: xr.DataArray):
  99. quantiles = np.linspace(0.05, 0.95, 19)
  100. metrics = ['accuracy', 'f1']
  101. statistics = ['stdev', 'entropy', 'confidence']
  102. zeros = np.zeros((len(quantiles), len(statistics), len(metrics)))
  103. thresholded_predictions = xr.DataArray(
  104. zeros,
  105. dims=('quantile', 'statistic', 'metric'),
  106. coords={
  107. 'quantile': quantiles,
  108. 'statistic': statistics,
  109. 'metric': metrics
  110. }
  111. )
  112. for statistic in statistics:
  113. #First, we must compute the quantiles for the statistic
  114. quantile_values = np.quantiles(ensemble_statistics.loc[{ 'statistic': statistic }].values, quantiles, axis=0)
  115. #Then, we must compute the metrics for each quantile
  116. for i, quantile in enumerate(quantiles):
  117. if low_to_high(statistic):
  118. filtered_data = ensemble_statistics.where(ensemble_statistics.loc[{ 'statistic': statistic }] > quantile_values[i], drop=True)
  119. else:
  120. filtered_data = ensemble_statistics.where(ensemble_statistics.loc[{ 'statistic': statistic }] < quantile_values[i], drop=True)
  121. for metric in metrics:
  122. thresholded_predictions.loc[{ 'quantile': quantile, 'statistic': statistic, 'metric': metric }] = compute_metric(filtered_data, metric)
  123. return thresholded_predictions
  124. # Truth function to determine if metric should be thresholded low to high or high to low
  125. # Low confidence is bad, high entropy is bad, high stdev is bad
  126. # So we threshold confidence low to high, entropy and stdev high to low
  127. # So any values BELOW the cutoff are removed for confidence, and any values ABOVE the cutoff are removed for entropy and stdev
  128. def low_to_high(stat):
  129. return stat in ['confidence']
  130. # Compute a given metric on a DataArray of statistics
  131. def compute_metric(arr, metric):
  132. if metric == 'accuracy':
  133. return np.mean(arr.loc[{ 'statistic': 'correct' }])
  134. elif metric == 'f1':
  135. return met.F1(arr.loc[{'statistic': 'predicted'}], arr.loc[{'statistic': 'actual'}])
  136. else:
  137. raise ValueError('Invalid metric: ' + metric)
  138. def main():
  139. config = load_config()
  140. ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
  141. V4_PATH = ENSEMBLE_PATH + '/v4'
  142. if not os.path.exists(V4_PATH):
  143. os.makedirs(V4_PATH)
  144. # Load Datasets
  145. dataset = load_datasets(ENSEMBLE_PATH)
  146. # Get Predictions, either by running the models or loading them from a file
  147. if config['ensemble']['run_models']:
  148. # Load Models
  149. device = torch.device(config['training']['device'])
  150. models = load_models_v2(f'{ENSEMBLE_PATH}/models/', device)
  151. # Get Predictions
  152. predictions = get_ensemble_predictions(models, dataset, device)
  153. # Save Prediction
  154. predictions.to_netcdf(f'{V4_PATH}/predictions.nc')
  155. else:
  156. predictions = xr.open_dataarray(f'{V4_PATH}/predictions.nc')
  157. # Compute Ensemble Statistics
  158. ensemble_statistics = compute_ensemble_statistics(predictions)
  159. ensemble_statistics.to_netcdf(f'{V4_PATH}/ensemble_statistics.nc')
  160. # Compute Thresholded Predictions
  161. thresholded_predictions = compute_thresholded_predictions(ensemble_statistics)
  162. thresholded_predictions.to_netcdf(f'{V4_PATH}/thresholded_predictions.nc')