xarray_sensitivity.py 3.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # For this project, we want to calculate the accuracy for many different numbers of models and model selections
  2. # We cannot calculate every possible permutation and combination of models, so we will use a random selection
  3. import math
  4. import itertools as it
  5. import xarray as xr
  6. import numpy as np
  7. import matplotlib.pyplot as plt
  8. import torch as th
  9. import random as rand
  10. import threshold_xarray as txr
  11. import os
  12. # Generate a selection of combinations given an iterable, combination length, and number of combinations
  13. # This checks the number of possible combinations and compares it to the requested number of combinations
  14. # If the number of requested combinations is more than the possible number of combinations, it wil error
  15. # If it is less, it will generate all possible combinations and select the requested number of combinations
  16. # If it is much less, it will randomly generate the requested number of combinations and check that they are unique
  17. # If it is equal, it will generate and return all possible combinations\
  18. def get_combinations(iterable, r, n_combinations):
  19. possible_combinations = math.comb(len(iterable), r)
  20. if n_combinations < possible_combinations:
  21. raise ValueError(
  22. f'Number of requested combinations {n_combinations} of length {r} on set of length {len(iterable)} is less than the possible number of combinations {possible_combinations}'
  23. )
  24. elif n_combinations == possible_combinations:
  25. return list(it.combinations(iterable, r))
  26. else:
  27. if n_combinations < possible_combinations / 5:
  28. combinations = []
  29. while len(combinations) < n_combinations:
  30. combination = rand.sample(iterable, r)
  31. if combination not in combinations:
  32. combinations.append(combination)
  33. return combinations
  34. else:
  35. combinations = list(it.combinations(iterable, r)) # All possible combinations
  36. return rand.sample(
  37. combinations, n_combinations
  38. ) # Randomly select n_combinations
  39. # Now that we have a function to generate combinations, we can generate a list of 49 * 50 + 1 combinations
  40. # This will be a list of 2451 combinations of 50 models
  41. models = list(range(50))
  42. combos = {}
  43. for i in range(49):
  44. combos[i] = get_combinations(models, i + 1, 50)
  45. combos[50] = [models]
  46. # Now that we have the list of combinations, we need the predictions
  47. print('Loading Config...')
  48. config = txr.load_config()
  49. ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
  50. V4_PATH = ENSEMBLE_PATH + '/v4'
  51. if not os.path.exists(V4_PATH):
  52. os.makedirs(V4_PATH)
  53. print('Config Loaded')
  54. test_predictions = xr.open_dataarray(f'{V4_PATH}/test_predictions.nc')
  55. val_predictions = xr.open_dataarray(f'{V4_PATH}/val_predictions.nc')
  56. # Prune Data
  57. print('Pruning Data...')
  58. if config['operation']['exclude_blank_ids']:
  59. excluded_data_ids = config['ensemble']['excluded_ids']
  60. test_predictions = txr.prune_data(test_predictions, excluded_data_ids)
  61. val_predictions = txr.prune_data(val_predictions, excluded_data_ids)
  62. # Concatenate Predictions
  63. predictions = xr.concat([test_predictions, val_predictions], dim='data_id')
  64. # Now that we have the list of predictions, we can calculate the accuracy and other stats for each combination
  65. # We will calculate the accuracy for each combination of models and save the results
  66. # Calculate the accuracy for each combination of models
  67. for num_models, model_combinations in combos.items():
  68. print(f'Calculating Accuracy for {num_models} Models')
  69. for i, model_combination in enumerate(model_combinations):
  70. print(f'Calculating Accuracy for Combination {i} of {len(model_combinations)}')
  71. model_predictions = predictions.sel(model=model_combination)
  72. # Calculate the accuracy
  73. num_correct =