123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
import functools
from typing import Callable, Optional, TypeVar

import numpy as np

from charged_shells.expansion import Expansion
from charged_shells.parameters import ModelParams
# Short alias for the NumPy array type; used in isinstance checks to detect
# array-valued model parameters.
Array = np.ndarray

# Type variables for the generic callable signatures in this module.
T = TypeVar('T')
V = TypeVar('V')
def map_over_expansion(f: Callable[[Expansion, T], V]) -> Callable[[Expansion, T], V]:
    """Map a function ``f`` over all leading axes of an expansion.

    The returned wrapper flattens the expansion's leading axes, calls
    ``f(element, *args, **kwargs)`` once per element, and reshapes the stacked
    results to ``ex.shape + result_shape``. ``result_shape`` is the shape of each
    of f's return values, which is empty for scalar results. It uses a Python
    loop, so it is slow for large expansions.

    Args:
        f: Function applied to each single element of the flattened expansion,
            followed by any extra positional/keyword arguments given to the wrapper.

    Returns:
        A wrapper with the call signature ``mapped_f(ex, *args, **kwargs)`` that
        returns a NumPy array.
    """
    @functools.wraps(f)
    def mapped_f(ex: Expansion, *args, **kwargs):
        og_shape = tuple(ex.shape)
        flat_ex = ex.flatten()
        n_elements = int(np.prod(og_shape))
        results = [f(flat_ex[i], *args, **kwargs) for i in range(n_elements)]
        stacked = np.array(results)
        # Axis 0 of `stacked` enumerates the flattened elements. Any trailing axes
        # come from the shape of f's return values. Taking them from `stacked`
        # rather than `results[0].shape` also covers the cases below:
        # - plain Python scalars, which have no `.shape`;
        # - sequences such as tuples or lists;
        # - an empty expansion, where `results[0]` would raise IndexError. Here
        #   the result is an empty array of shape `og_shape`.
        return stacked.reshape(og_shape + stacked.shape[1:])

    return mapped_f
def unravel_params(params: ModelParams) -> list[ModelParams]:
    """Split model parameters with an array-valued field into a list of parameter sets.

    At most one of ``params.R`` and ``params.kappa`` may be an ``Array``. That
    field is iterated over, and one ``ModelParams(R=..., kappa=...)`` is built per
    element with the other field passed through as-is. If neither field is an
    ``Array``, ``[params]`` is returned unchanged.

    Args:
        params: Model parameters, possibly with an array-valued ``R`` or ``kappa``.

    Returns:
        A non-empty list of ``ModelParams`` when the iterated array is non-empty.

    Raises:
        NotImplementedError: If both ``params.R`` and ``params.kappa`` are arrays.
    """
    r_is_array = isinstance(params.R, Array)
    kappa_is_array = isinstance(params.kappa, Array)
    if r_is_array and kappa_is_array:
        # If this is ever implemented, watch out for mapped expansion functions
        # that depend on one of the ModelParams fields being combined with other
        # functions that take the same ModelParams.
        raise NotImplementedError("Currently only unravel over a single parameter is supported. ")
    if r_is_array:
        return [ModelParams(R=r, kappa=params.kappa) for r in params.R]
    if kappa_is_array:
        return [ModelParams(R=params.R, kappa=kappa) for kappa in params.kappa]
    # Neither field is array-valued. Previously this was a redundant guarded
    # branch with an implicit `None` fall-through.
    return [params]
# Signature of a function of one expansion and the model parameters, returning T.
SingleExpansionFn = Callable[[Expansion, ModelParams], T]
# Signature of a function of two expansions and the model parameters, returning T.
TwoExpansionsFn = Callable[[Expansion, Expansion, ModelParams], T]
def parameter_map_single_expansion(f: SingleExpansionFn,
                                   match_expansion_axis_to_params: Optional[int] = None) -> SingleExpansionFn:
    """Wrap ``f(ex, params)`` so it can be called with array-valued model parameters.

    The wrapper unravels ``params`` with ``unravel_params`` and calls ``f`` once
    per resulting parameter set, stacking the results with ``np.array``.

    Args:
        f: Function of a single expansion and scalar-valued model parameters.
        match_expansion_axis_to_params: If given, the expansion is sliced along
            this axis (``np.take(ex.coefs, i, axis=...)``). The i-th slice is
            paired with the i-th unraveled parameter set, and the stacked results
            have this axis moved back into place. If ``None``, the same expansion
            is reused for every parameter set.

    Returns:
        A wrapper ``mapped_f(ex, params)`` returning a NumPy array.

    Raises (from the wrapper):
        ValueError: If the matched expansion axis and the unraveled parameters
            differ in length.
    """
    meap = match_expansion_axis_to_params  # just a shorter variable name

    @functools.wraps(f)
    def mapped_f(ex: Expansion, params: ModelParams):
        params_list = unravel_params(params)
        if meap is not None:
            # NOTE(review): ex.shape[meap] and axis=meap of ex.coefs are assumed
            # to address the same axis. This presumably holds only for a
            # non-negative meap; confirm how Expansion.shape relates to coefs.shape.
            expansion_list = [Expansion(ex.l_array, np.take(ex.coefs, i, axis=meap))
                              for i in range(ex.shape[meap])]
        else:
            expansion_list = [ex] * len(params_list)
        if len(expansion_list) != len(params_list):
            raise ValueError(f'Axis of expansion that is supposed to match params does not have the same length, got '
                             f'len(unraveled params) = {len(params_list)} and '
                             f'expansion.shape[{meap}] = {len(expansion_list)}')
        results = np.array([f(exp, prms) for exp, prms in zip(expansion_list, params_list)])
        if meap is not None:
            # The results are stacked along axis 0. Moving that axis to position
            # meap is the exact inverse of slicing with np.take(..., axis=meap).
            # The previous swapaxes(0, meap) only did this for meap in {0, 1};
            # for larger meap it also permuted the other leading axes.
            return np.moveaxis(results, 0, meap)
        # NOTE: np.squeeze drops every length-1 axis of the stacked results,
        # not only the parameter axis.
        return np.squeeze(results)

    return mapped_f
def parameter_map_two_expansions(f: TwoExpansionsFn,
                                 match_expansion_axis_to_params: Optional[int] = None) -> TwoExpansionsFn:
    """Wrap ``f(ex1, ex2, params)`` so it can be called with array-valued model parameters.

    The wrapper unravels ``params`` with ``unravel_params`` and calls ``f`` once
    per resulting parameter set, stacking the results with ``np.array``.

    Args:
        f: Function of two expansions and scalar-valued model parameters.
        match_expansion_axis_to_params: If given, both expansions are sliced along
            this axis (``np.take(ex.coefs, i, axis=...)``). The i-th slices are
            paired with the i-th unraveled parameter set, and the stacked results
            have this axis moved back into place. If ``None``, the same two
            expansions are reused for every parameter set.

    Returns:
        A wrapper ``mapped_f(ex1, ex2, params)`` returning a NumPy array.

    Raises (from the wrapper):
        ValueError: If the matched axis of either expansion differs in length
            from the unraveled parameters.
    """
    meap = match_expansion_axis_to_params  # just a shorter variable name

    @functools.wraps(f)
    def mapped_f(ex1: Expansion, ex2: Expansion, params: ModelParams):
        params_list = unravel_params(params)
        if meap is not None:
            # NOTE(review): exN.shape[meap] and axis=meap of exN.coefs are assumed
            # to address the same axis. This presumably holds only for a
            # non-negative meap; confirm how Expansion.shape relates to coefs.shape.
            expansion_list1 = [Expansion(ex1.l_array, np.take(ex1.coefs, i, axis=meap))
                               for i in range(ex1.shape[meap])]
            expansion_list2 = [Expansion(ex2.l_array, np.take(ex2.coefs, i, axis=meap))
                               for i in range(ex2.shape[meap])]
        else:
            expansion_list1 = [ex1] * len(params_list)
            expansion_list2 = [ex2] * len(params_list)
        if len(expansion_list1) != len(params_list) or len(expansion_list2) != len(params_list):
            raise ValueError(f'Axis of at least one of the expansions that is supposed to match params '
                             f'does not have the same length, got '
                             f'len(unraveled params) = {len(params_list)} and '
                             f'expansion1.shape[{meap}] = {len(expansion_list1)} and '
                             f'expansion2.shape[{meap}] = {len(expansion_list2)}')
        results = np.array([f(exp1, exp2, prms)
                            for exp1, exp2, prms in zip(expansion_list1, expansion_list2, params_list)])
        if meap is not None:
            # The results are stacked along axis 0. Moving that axis to position
            # meap is the exact inverse of slicing with np.take(..., axis=meap).
            # The previous swapaxes(0, meap) only did this for meap in {0, 1};
            # for larger meap it also permuted the other leading axes.
            return np.moveaxis(results, 0, meap)
        # NOTE: np.squeeze drops every length-1 axis of the stacked results,
        # not only the parameter axis.
        return np.squeeze(results)

    return mapped_f
|