patch_size.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. import numpy as np
  2. from scipy.optimize import bisect, root_scalar
  3. from charged_shells import expansion, mapping, parameters, potentials
  4. from typing import Callable
# Short local aliases for the types used in the function signatures of this module.
# Expansion: the expansion class from charged_shells.expansion (objects expose e.g. `charge_value`).
Expansion = expansion.Expansion
# Array: plain numpy array type, used where an argument may be an array or a scalar.
Array = np.ndarray
# ModelParams: model-parameter container from charged_shells.parameters.
ModelParams = parameters.ModelParams
  8. @mapping.map_over_expansion
  9. def charge_patch_size(ex: Expansion, phi: float = 0, theta0: Array | float = 0, theta1: Array | float = np.pi / 2):
  10. return bisect(lambda theta: ex.charge_value(theta, phi), theta0, theta1)
  11. def potential_patch_size(ex: Expansion, params: ModelParams,
  12. phi: float = 0, theta0: Array | float = 0, theta1: Array | float = np.pi / 2,
  13. match_expansion_axis_to_params: int = None):
  14. # this is more complicate to map over leading axes of the expansion as potential also depends on model parameters,
  15. # with some, such as kappaR, also being the parameters of the expansion in the first place. When mapping,
  16. # we must therefore provide the expansion axis that should match the collection of parameters in params.
  17. @mapping.map_over_expansion
  18. def potential_zero(exp: Expansion, prms: ModelParams):
  19. return bisect(lambda theta: potentials.charged_shell_potential(theta, phi, 1, exp, prms), theta0, theta1)
  20. return mapping.parameter_map_single_expansion(potential_zero, match_expansion_axis_to_params)(ex, params)
  21. def inverse_potential_patch_size(target_patch_size: float,
  22. ex_generator: Callable[[float], Expansion],
  23. x0: float,
  24. params: ModelParams, **ps_kwargs):
  25. def patch_size_dif(x):
  26. ex = ex_generator(x)
  27. return potential_patch_size(ex, params, **ps_kwargs) - target_patch_size
  28. root_result = root_scalar(patch_size_dif, x0=x0)
  29. if not root_result.converged:
  30. raise ValueError('No convergence. Patches of desired size might not be achievable in the given model. '
  31. 'Conversely, a common mistake might be target patch size input in degrees.')
  32. return root_result.root