import functools
from typing import Callable, Optional, TypeVar

import numpy as np

from charged_shells.expansion import Expansion
from charged_shells.parameters import ModelParams

T = TypeVar('T')
V = TypeVar('V')
Array = np.ndarray


def map_over_expansion(f: Callable[..., V]) -> Callable[..., Array]:
    """Map a function f over all leading axes of an expansion.

    The returned wrapper flattens the expansion, calls ``f(flat_ex[i], *args, **kwargs)`` once per
    flattened element and reshapes the stacked results to ``ex.shape + shape_of_one_result``.
    Uses a Python loop, so it is kinda slow.

    Args:
        f: function taking an expansion as its first argument, followed by any extra arguments.

    Returns:
        A wrapper with signature ``(ex, *args, **kwargs) -> np.ndarray``.
    """

    @functools.wraps(f)
    def mapped_f(ex: Expansion, *args, **kwargs):
        og_shape = tuple(ex.shape)
        flat_ex = ex.flatten()
        # np.prod(()) == 1, so an expansion with empty leading shape results in exactly one call
        n_elements = int(np.prod(og_shape))
        results = [f(flat_ex[i], *args, **kwargs) for i in range(n_elements)]
        if not results:
            # A zero-length leading axis: there is no results[0] to take a shape from, so return
            # an empty array with the leading shape instead of raising IndexError.
            return np.array(results).reshape(og_shape)
        # np.shape gives () for plain Python scalars and the full shape for arrays / nested sequences
        return np.array(results).reshape(og_shape + np.shape(results[0]))

    return mapped_f


def unravel_params(params: ModelParams) -> list[ModelParams]:
    """Split parameters holding an array of ``R`` or ``kappa`` values into single-valued parameters.

    Returns:
        One ``ModelParams`` per element of whichever of ``params.R`` / ``params.kappa`` is an
        ``np.ndarray`` (iterating over its first axis), or ``[params]`` if neither is an array.

    Raises:
        NotImplementedError: if both ``params.R`` and ``params.kappa`` are arrays.
    """
    r_is_array = isinstance(params.R, Array)
    kappa_is_array = isinstance(params.kappa, Array)
    if r_is_array and kappa_is_array:
        # if this is to be implemented, watch out for implementations of mapping expansions that depend
        # on one of the parameters in ModelParams over other functions that also take the same ModelParameters
        raise NotImplementedError("Currently only unravel over a single parameter is supported. ")
    # NOTE(review): only R and kappa are forwarded to the new ModelParams — confirm ModelParams
    # has no other fields that should be carried over.
    if r_is_array:
        return [ModelParams(R=r, kappa=params.kappa) for r in params.R]
    if kappa_is_array:
        return [ModelParams(R=params.R, kappa=kappa) for kappa in params.kappa]
    # neither parameter is an array: nothing to unravel
    return [params]


SingleExpansionFn = Callable[[Expansion, ModelParams], T]
TwoExpansionsFn = Callable[[Expansion, Expansion, ModelParams], T]


def _split_along_axis(ex: Expansion, axis: int) -> list[Expansion]:
    """Return one sub-expansion per index of ``ex.coefs`` along ``axis`` (same ``l_array``)."""
    return [Expansion(ex.l_array, np.take(ex.coefs, i, axis=axis)) for i in range(ex.shape[axis])]


def _restack_results(results: list, meap: Optional[int]) -> Array:
    """Stack per-parameter results and put the parameter axis where the caller expects it.

    With ``meap`` given, results are stacked along a new leading axis that is moved back to position
    ``meap``. Without it, the stacked array is squeezed.
    """
    stacked = np.array(results)
    if meap is not None:
        # moveaxis keeps the relative order of the other axes; swapaxes(0, meap) would reorder the
        # axes in front of ``meap`` whenever meap >= 2.
        return np.moveaxis(stacked, 0, meap)
    # NOTE(review): squeeze removes every length-1 axis, not just the parameter axis — a size-1
    # leading axis of the expansion is dropped too. Kept as-is; confirm callers rely on it.
    return np.squeeze(stacked)


def parameter_map_single_expansion(f: SingleExpansionFn,
                                   match_expansion_axis_to_params: Optional[int] = None) -> SingleExpansionFn:
    """Map ``f(ex, params)`` over the values produced by ``unravel_params(params)``.

    Args:
        f: function of one expansion and a single-valued ``ModelParams``.
        match_expansion_axis_to_params: if given, the expansion is split along this axis and the
            i-th slice is paired with the i-th unraveled parameter; otherwise the full expansion is
            paired with every parameter.

    Returns:
        A wrapper ``(ex, params) -> np.ndarray``. With a matched axis the parameter axis sits at that
        axis of the output; otherwise the stacked output is squeezed.

    Raises:
        ValueError: if the matched axis length differs from the number of unraveled parameters.
    """
    meap = match_expansion_axis_to_params  # just a shorter variable name

    @functools.wraps(f)
    def mapped_f(ex: Expansion, params: ModelParams):
        params_list = unravel_params(params)
        if meap is not None:
            expansion_list = _split_along_axis(ex, meap)
        else:
            expansion_list = [ex] * len(params_list)
        if len(expansion_list) != len(params_list):
            raise ValueError(f'Axis of expansion that is supposed to match params does not have the same length, got '
                             f'len(unraveled params) = {len(params_list)} and '
                             f'expansion.shape[{meap}] = {len(expansion_list)}')
        results = [f(exp, prms) for exp, prms in zip(expansion_list, params_list)]
        return _restack_results(results, meap)

    return mapped_f


def parameter_map_two_expansions(f: TwoExpansionsFn,
                                 match_expansion_axis_to_params: Optional[int] = None) -> TwoExpansionsFn:
    """Map ``f(ex1, ex2, params)`` over the values produced by ``unravel_params(params)``.

    Args:
        f: function of two expansions and a single-valued ``ModelParams``.
        match_expansion_axis_to_params: if given, both expansions are split along this axis and the
            i-th slices are paired with the i-th unraveled parameter; otherwise the full expansions
            are paired with every parameter.

    Returns:
        A wrapper ``(ex1, ex2, params) -> np.ndarray``. With a matched axis the parameter axis sits at
        that axis of the output; otherwise the stacked output is squeezed.

    Raises:
        ValueError: if the matched axis length of either expansion differs from the number of
            unraveled parameters.
    """
    meap = match_expansion_axis_to_params  # just a shorter variable name

    @functools.wraps(f)
    def mapped_f(ex1: Expansion, ex2: Expansion, params: ModelParams):
        params_list = unravel_params(params)
        if meap is not None:
            expansion_list1 = _split_along_axis(ex1, meap)
            expansion_list2 = _split_along_axis(ex2, meap)
        else:
            expansion_list1 = [ex1] * len(params_list)
            expansion_list2 = [ex2] * len(params_list)
        if len(expansion_list1) != len(params_list) or len(expansion_list2) != len(params_list):
            raise ValueError(f'Axis of at least one of the expansions that is supposed to match params '
                             f'does not have the same length, got '
                             f'len(unraveled params) = {len(params_list)} and '
                             f'expansion1.shape[{meap}] = {len(expansion_list1)} and '
                             f'expansion2.shape[{meap}] = {len(expansion_list2)}')
        results = [f(exp1, exp2, prms) for exp1, exp2, prms in zip(expansion_list1, expansion_list2, params_list)]
        return _restack_results(results, meap)

    return mapped_f