12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- from typing import Callable, TypeVar, Protocol
- import numpy as np
- from charged_shells.expansion import Expansion
- from charged_shells.parameters import ModelParams
# Generic type variables used in the callable signatures of this module.
T, V = TypeVar('T'), TypeVar('V')

# Short alias for NumPy arrays; used in isinstance checks when unravelling parameters.
Array = np.ndarray
def map_over_expansion(f: Callable[[Expansion, T], V]) -> Callable[[Expansion, T], V]:
    """Map a function ``f`` over all leading axes of an expansion.

    The returned wrapper does the following:
      * flattens ``ex``;
      * calls ``f(flat_ex[i], *args, **kwargs)`` for every flat index;
      * stacks the results into an array of shape ``ex.shape + np.shape(result)``.

    It uses a Python for loop, so it is slow for large expansions.

    Args:
        f: Function taking a single expansion as its first argument. Any further positional or
            keyword arguments given to the wrapper are forwarded unchanged.

    Returns:
        A wrapper with the same call signature as ``f`` that accepts a batched expansion.
    """
    from functools import wraps

    @wraps(f)
    def mapped_f(ex: Expansion, *args, **kwargs):
        og_shape = ex.shape
        flat_ex = ex.flatten()
        n_items = int(np.prod(og_shape))
        results = [f(flat_ex[i], *args, **kwargs) for i in range(n_items)]
        # np.shape handles Python scalars (-> ()), sequences and arrays alike. The original try/except
        # on `.shape` broke for list results and raised IndexError for an empty batch, where there is
        # no sample to read a trailing shape from (we fall back to ()).
        trailing_shape = np.shape(results[0]) if results else ()
        return np.array(results).reshape(og_shape + trailing_shape)

    return mapped_f
def unravel_params(params: ModelParams) -> list[ModelParams]:
    """Split ``params`` into a list of ``ModelParams`` with non-array ``R`` and ``kappa``.

    Behavior depends on which parameters are NumPy arrays:
      * only ``params.R``: one ``ModelParams`` per element of ``R``, with ``kappa`` unchanged;
      * only ``params.kappa``: one ``ModelParams`` per element of ``kappa``, with ``R`` unchanged;
      * neither: ``[params]`` itself is returned.

    Raises:
        NotImplementedError: If both ``params.R`` and ``params.kappa`` are arrays.
    """
    r_is_array = isinstance(params.R, Array)
    kappa_is_array = isinstance(params.kappa, Array)
    if r_is_array and kappa_is_array:
        # if this is to be implemented, watch out for implementations of mapping expansions that depend
        # on one of the parameters in ModelParams over other functions that also take the same ModelParameters
        raise NotImplementedError("Currently only unravel over a single parameter is supported. ")
    # NOTE(review): only R and kappa are passed to the new ModelParams; if ModelParams has further
    # fields they fall back to their defaults here — confirm against the ModelParams definition.
    if r_is_array:
        return [ModelParams(R=r, kappa=params.kappa) for r in params.R]
    if kappa_is_array:
        return [ModelParams(R=params.R, kappa=kappa) for kappa in params.kappa]
    # Neither parameter is an array, so there is nothing to unravel.
    return [params]
def unravel_expansion_over_axis(ex: Expansion, axis: int | None, param_list_len: int) -> list[Expansion]:
    """Split ``ex`` along one axis so it can be zipped with a list of parameter sets.

    Args:
        ex: Expansion to split.
        axis: Axis of ``ex.shape`` to split along. Negative values count from the end of ``ex.shape``.
            If ``None``, ``ex`` is not split but repeated ``param_list_len`` times (the same object
            in every entry).
        param_list_len: Number of parameter sets the pieces will be paired with.

    Returns:
        A list of expansions: ``param_list_len`` copies of ``ex`` when ``axis`` is ``None``,
        otherwise the ``ex.shape[axis]`` slices of ``ex`` along ``axis``.

    Raises:
        IndexError: If ``axis`` is out of range for ``ex.shape``.
        ValueError: If ``ex.shape[axis]`` differs from ``param_list_len``.
    """
    if axis is None:
        return [ex] * param_list_len
    axis_len = ex.shape[axis]  # also rejects an out-of-range axis with IndexError
    if axis_len != param_list_len:
        raise ValueError(f'Parameter list has different length than the provided expansion axis, '
                         f'got param_list_len={param_list_len} and axis_len={axis_len}.')
    # The axis was validated against ex.shape, but it is applied to ex.coefs. Presumably coefs carries
    # extra trailing coefficient dimension(s) beyond the leading ex.shape (TODO confirm). In that case a
    # negative axis would address a different dimension of coefs, so normalize it first.
    axis %= len(ex.shape)
    return [Expansion(ex.l_array, np.take(ex.coefs, i, axis=axis)) for i in range(axis_len)]
# Signature of a function acting on a single expansion: f(ex, params) -> result.
SingleExpansionFn = Callable[[Expansion, ModelParams], T]


class TwoExpansionsFn(Protocol):
    """Structural type of a function acting on a pair of expansions: f(ex1, ex2, params, **kwargs).

    Written as a ``Protocol`` rather than a ``Callable`` alias because ``Callable`` cannot
    express the arbitrary keyword arguments in the signature.
    """

    def __call__(self, ex1: Expansion, ex2: Expansion, params: ModelParams, **kwargs) -> T:
        ...
def parameter_map_single_expansion(f: SingleExpansionFn,
                                   match_expansion_axis_to_params: int | None = None) -> SingleExpansionFn:
    """Vectorize ``f(ex, params)`` over array-valued model parameters.

    ``params`` is split with ``unravel_params`` and ``f`` is called once per resulting parameter set.

    Args:
        f: Function ``f(ex, params, **kwargs)`` for non-array ``params``.
        match_expansion_axis_to_params: Controls how ``ex`` is paired with the parameter sets.
            If given, ``ex`` is split along this axis, the i-th slice is paired with the i-th
            parameter set, and the parameter axis of the output is moved back to this position.
            If ``None``, the whole ``ex`` is passed with every parameter set.

    Returns:
        A wrapper ``mapped_f(ex, params, **kwargs)`` that returns the stacked results.
        Without axis matching, the parameter axis is leading, and it is dropped when there is
        only one parameter set. Extra keyword arguments are forwarded to ``f``.
    """
    from functools import wraps

    @wraps(f)
    def mapped_f(ex: Expansion, params: ModelParams, **kwargs):
        params_list = unravel_params(params)
        expansion_list = unravel_expansion_over_axis(ex, match_expansion_axis_to_params, len(params_list))
        results = np.array([f(exp, prms, **kwargs) for exp, prms in zip(expansion_list, params_list)])
        if match_expansion_axis_to_params is not None:
            # return the params-matched axis to where it belongs
            return np.moveaxis(results, 0, match_expansion_axis_to_params)
        if len(params_list) == 1:
            # Drop only the artificial parameter axis. A bare np.squeeze would also remove any size-1
            # axes that belong to the expansion or to f's output.
            return np.squeeze(results, axis=0)
        return results

    return mapped_f
def parameter_map_two_expansions(f: TwoExpansionsFn,
                                 match_expansion_axis_to_params: int | None = None,
                                 ) -> TwoExpansionsFn:
    """Vectorize ``f(ex1, ex2, params, **kwargs)`` over array-valued model parameters.

    ``params`` is split with ``unravel_params`` and ``f`` is called once per resulting parameter set.

    Args:
        f: Function ``f(ex1, ex2, params, **kwargs)`` for non-array ``params``.
        match_expansion_axis_to_params: Controls how the expansions are paired with the parameter
            sets. If given, both ``ex1`` and ``ex2`` are split along this axis, the i-th slices are
            paired with the i-th parameter set, and the parameter axis of the output is moved back
            to this position. If ``None``, the whole expansions are passed with every parameter set.

    Returns:
        A wrapper ``mapped_f(ex1, ex2, params, **kwargs)`` that returns the stacked results.
        Without axis matching, the parameter axis is leading, and it is dropped when there is
        only one parameter set. Extra keyword arguments are forwarded to ``f``.
    """
    from functools import wraps

    @wraps(f)
    def mapped_f(ex1: Expansion, ex2: Expansion, params: ModelParams, **kwargs):
        params_list = unravel_params(params)
        expansion_list1 = unravel_expansion_over_axis(ex1, match_expansion_axis_to_params, len(params_list))
        expansion_list2 = unravel_expansion_over_axis(ex2, match_expansion_axis_to_params, len(params_list))
        results = np.array([f(exp1, exp2, prms, **kwargs)
                            for exp1, exp2, prms in zip(expansion_list1, expansion_list2, params_list)])
        if match_expansion_axis_to_params is not None:
            # return the params-matched axis to where it belongs
            return np.moveaxis(results, 0, match_expansion_axis_to_params)
        if len(params_list) == 1:
            # Drop only the artificial parameter axis. A bare np.squeeze would also remove any size-1
            # axes that belong to the expansions or to f's output.
            return np.squeeze(results, axis=0)
        return results

    return mapped_f
|