calibration_xarray.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. import threshold_xarray as th
  2. import xarray as xr
  3. import numpy as np
  4. import torch
  5. import os
  6. import sklearn.calibration as cal
  7. # The purpose of this file is to calibrate the data on the test set, and evaluate the calibration on the validation set.
  8. # We use scikit-learn's calibration module (sklearn.calibration) to do this.
  9. if __name__ == '__main__':
  10. print('Loading Config..B')
  11. config = th.load_config()
  12. ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
  13. V4_PATH = ENSEMBLE_PATH + '/v4'
  14. if not os.path.exists(V4_PATH):
  15. os.makedirs(V4_PATH)
  16. print('Config Loaded')
  17. # Load the predictions
  18. print('Loading Predictions...')
  19. val_preds = xr.open_dataset(f'{ENSEMBLE_PATH}/val_predictions.nc')
  20. test_preds = xr.open_dataset(f'{ENSEMBLE_PATH}/test_predictions.nc')
  21. print('Predictions Loaded')
  22. # Now the goal is to calibrate the test set, and evaluate the calibration on the validation set.
  23. # We do this by binning the data into 15 bins, and then calculating the mean of the predictions in each bin.
  24. # We then use this to calibrate the data.
  25. # First, get the statistics of both sets
  26. print('Calculating Statistics...')
  27. val_stats = th.compute_ensemble_statistics(val_preds)
  28. test_stats = th.compute_ensemble_statistics(test_preds)
  29. # Calibrate the test set
  30. print('Calibrating Test Set...')