# mapping.py (4.5 KB)
  1. from typing import Callable, TypeVar
  2. import numpy as np
  3. from charged_shells.expansion import Expansion
  4. from charged_shells.parameters import ModelParams
  5. T = TypeVar('T')
  6. V = TypeVar('V')
  7. Array = np.ndarray
  8. def map_over_expansion(f: Callable[[Expansion, T], V]) -> Callable[[Expansion, T], V]:
  9. """Map a function f over all leading axes of an expansion. Uses for loops, so it is kinda slow."""
  10. def mapped_f(ex: Expansion, *args, **kwargs):
  11. og_shape = ex.shape
  12. flat_ex = ex.flatten()
  13. results = []
  14. for i in range(int(np.prod(og_shape))):
  15. results.append(f(flat_ex[i], *args, **kwargs))
  16. try:
  17. return np.array(results).reshape(og_shape + results[0].shape)
  18. except AttributeError:
  19. return np.array(results).reshape(og_shape)
  20. return mapped_f
  21. def unravel_params(params: ModelParams) -> list[ModelParams]:
  22. if isinstance(params.R, Array) and isinstance(params.kappa, Array):
  23. # if this is to be implemented, watch out for implementations of mapping expansions that depend
  24. # on one of the parameters in ModelParams over other functions that also take the same ModelParameters
  25. raise NotImplementedError("Currently only unravel over a single parameter is supported. ")
  26. if isinstance(params.R, Array):
  27. return [ModelParams(R=r, kappa=params.kappa) for r in params.R]
  28. if isinstance(params.kappa, Array):
  29. return [ModelParams(R=params.R, kappa=kappa) for kappa in params.kappa]
  30. if not (isinstance(params.R, Array) or isinstance(params.kappa, Array)):
  31. return [params]
  32. SingleExpansionFn = Callable[[Expansion, ModelParams], T]
  33. TwoExpansionsFn = Callable[[Expansion, Expansion, ModelParams], T]
  34. def parameter_map_single_expansion(f: SingleExpansionFn,
  35. match_expansion_axis_to_params: int = None) -> SingleExpansionFn:
  36. meap = match_expansion_axis_to_params # just a shorter variable name
  37. def mapped_f(ex: Expansion, params: ModelParams):
  38. params_list = unravel_params(params)
  39. if meap is not None:
  40. expansion_list = [Expansion(ex.l_array, np.take(ex.coefs, i, axis=meap)) for i in range(ex.shape[meap])]
  41. else:
  42. expansion_list = [ex for _ in params_list]
  43. if not len(expansion_list) == len(params_list):
  44. raise ValueError(f'Axis of expansion that is supposed to match params does not have the same length, got '
  45. f'len(unraveled params) = {len(params_list)} and '
  46. f'expansion.shape[{meap}] = {len(expansion_list)}')
  47. results = []
  48. for exp, prms in zip(expansion_list, params_list):
  49. results.append(f(exp, prms))
  50. if meap is not None:
  51. return np.array(results).swapaxes(0, meap) # return the params-matched axis to where it belongs
  52. return np.squeeze(np.array(results))
  53. return mapped_f
  54. def parameter_map_two_expansions(f: TwoExpansionsFn,
  55. match_expansion_axis_to_params: int = None) -> TwoExpansionsFn:
  56. meap = match_expansion_axis_to_params # just a shorter variable name
  57. def mapped_f(ex1: Expansion, ex2: Expansion, params: ModelParams):
  58. params_list = unravel_params(params)
  59. if meap is not None:
  60. expansion_list1 = [Expansion(ex1.l_array, np.take(ex1.coefs, i, axis=meap)) for i in range(ex1.shape[meap])]
  61. expansion_list2 = [Expansion(ex2.l_array, np.take(ex2.coefs, i, axis=meap)) for i in range(ex2.shape[meap])]
  62. else:
  63. expansion_list1 = [ex1 for _ in params_list]
  64. expansion_list2 = [ex2 for _ in params_list]
  65. if not (len(expansion_list1) == len(params_list) and len(expansion_list2) == len(params_list)):
  66. raise ValueError(f'Axis of at least one of the expansions that is supposed to match params '
  67. f'does not have the same length, got '
  68. f'len(unraveled params) = {len(params_list)} and '
  69. f'expansion1.shape[{meap}] = {len(expansion_list1)} and '
  70. f'expansion2.shape[{meap}] = {len(expansion_list2)}')
  71. results = []
  72. for exp1, exp2, prms in zip(expansion_list1, expansion_list2, params_list):
  73. results.append(f(exp1, exp2, prms))
  74. if meap is not None:
  75. return np.array(results).swapaxes(0, meap) # return the params-matched axis to where it belongs
  76. return np.squeeze(np.array(results))
  77. return mapped_f