1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- import math
- import itertools as it
- import xarray as xr
- import numpy as np
- import matplotlib.pyplot as plt
- import torch as th
- import random as rand
- import threshold_xarray as txr
- import os
def get_combinations(iterable, r, n_combinations):
    """Return ``n_combinations`` distinct ``r``-element combinations of ``iterable``.

    Elements at different positions of ``iterable`` are treated as distinct,
    exactly as ``itertools.combinations`` does. Each combination is a list whose
    elements keep their relative order from ``iterable``, so the same subset is
    never returned twice in a different order.

    Combinations are returned as lists rather than tuples because xarray's
    ``DataArray.sel`` interprets a tuple as one scalar label, whereas a list is
    treated as a collection of labels to select.

    Strategy, chosen by how much of the combination space is requested:

    * every combination: enumerate with ``itertools.combinations`` (in its
      deterministic lexicographic order);
    * fewer than one fifth of them: rejection-sample random index sets, which
      avoids materialising the full (possibly enormous) combination space;
    * otherwise: enumerate all combinations and ``random.sample`` from them.

    Args:
        iterable: Elements to choose from; it is consumed once into a list.
        r: Number of elements in each combination.
        n_combinations: Number of distinct combinations to return.

    Returns:
        A list of ``n_combinations`` lists, each of length ``r``. Unless every
        combination is requested, the selection uses the global ``random`` state.

    Raises:
        ValueError: if ``n_combinations`` is negative or exceeds
            ``math.comb(len(iterable), r)``, or if ``r`` is negative.
    """
    pool = list(iterable)
    if n_combinations < 0:
        raise ValueError(
            f'Number of requested combinations must be non-negative, got {n_combinations}'
        )
    possible_combinations = math.comb(len(pool), r)
    # Cannot return more distinct combinations than exist.
    if n_combinations > possible_combinations:
        raise ValueError(
            f'Number of requested combinations {n_combinations} of length {r} on set of length {len(pool)} is greater than the possible number of combinations {possible_combinations}'
        )
    if n_combinations == possible_combinations:
        return [list(combo) for combo in it.combinations(pool, r)]
    if n_combinations < possible_combinations / 5:
        # Sparse request: draw random index sets instead of enumerating all
        # combinations. Sorting the indices gives each subset one canonical
        # form, so a permutation of an already-chosen subset is rejected.
        # Collisions are rare because at most 20% of the space is requested.
        seen = set()
        combinations = []
        while len(combinations) < n_combinations:
            indices = tuple(sorted(rand.sample(range(len(pool)), r)))
            if indices not in seen:
                seen.add(indices)
                combinations.append([pool[j] for j in indices])
        return combinations
    # Dense request: the full enumeration is comparable in size to the output.
    all_combinations = list(it.combinations(pool, r))
    return [list(combo) for combo in rand.sample(all_combinations, n_combinations)]
-
# Candidate sub-ensembles: for every ensemble size 1..len(models), up to
# N_COMBINATIONS distinct model subsets of that size. The dict key is the
# number of models in each subset (it is printed as such in the loop below).
# min(...) caps the request at the number of subsets that exist, e.g. exactly
# one subset (all models) for the full-size ensemble.
N_COMBINATIONS = 50
models = list(range(50))
combos = {
    num_models: get_combinations(
        models,
        num_models,
        min(N_COMBINATIONS, math.comb(len(models), num_models)),
    )
    for num_models in range(1, len(models) + 1)
}
print('Loading Config...')
config = txr.load_config()
# NOTE(review): plain concatenation assumes paths.model_output ends with a path
# separator (or is intended as a name prefix) -- confirm against the config.
ENSEMBLE_PATH = f"{config['paths']['model_output']}{config['ensemble']['name']}"
V4_PATH = ENSEMBLE_PATH + '/v4'
# exist_ok=True replaces the exists()/makedirs() pair, which could race with
# another process creating the directory between the check and the call.
os.makedirs(V4_PATH, exist_ok=True)
print('Config Loaded')

# Test and validation predictions previously written under the v4 directory.
test_predictions = xr.open_dataarray(f'{V4_PATH}/test_predictions.nc')
val_predictions = xr.open_dataarray(f'{V4_PATH}/val_predictions.nc')

print('Pruning Data...')
if config['operation']['exclude_blank_ids']:
    # Presumably removes the listed data ids from each split (txr.prune_data
    # is defined elsewhere) -- confirm.
    excluded_data_ids = config['ensemble']['excluded_ids']
    test_predictions = txr.prune_data(test_predictions, excluded_data_ids)
    val_predictions = txr.prune_data(val_predictions, excluded_data_ids)
# Evaluate on both splits together, stacked along the data_id dimension.
predictions = xr.concat([test_predictions, val_predictions], dim='data_id')
- for num_models, model_combinations in combos.items():
- print(f'Calculating Accuracy for {num_models} Models')
- for i, model_combination in enumerate(model_combinations):
- print(f'Calculating Accuracy for Combination {i} of {len(model_combinations)}')
- model_predictions = predictions.sel(model=model_combination)
-
-
- num_correct =
|