interactions.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. from charged_shells import expansion, parameters
  2. import charged_shells.functions as fn
  3. from py3nj import wigner3j
  4. import numpy as np
  5. from typing import Literal
  6. import charged_shells.units_and_constants as uc
  7. Array = np.ndarray
  8. Expansion = expansion.Expansion
  9. ModelParams = parameters.ModelParams
  10. EnergyUnit = Literal['kT', 'eV', 'J']
  11. def energy_units(units: EnergyUnit, params: ModelParams) -> float:
  12. match units:
  13. case 'eV':
  14. # return 1 / (uc.CONSTANTS.e0 * uc.UNITS.voltage)
  15. return 1.
  16. case 'kT':
  17. return 1 / (params.temperature * uc.CONSTANTS.Boltzmann)
  18. case 'J':
  19. return uc.UNITS.energy
  20. def charged_shell_energy(ex1: Expansion, ex2: Expansion, params: ModelParams, dist: float = 2, units: EnergyUnit = 'kT',
  21. chunk_size: int = 1000):
  22. ex1, ex2 = expansion.expansions_to_common_l(ex1, ex2)
  23. dist = dist * params.R
  24. full_l_array, full_m_array = ex1.lm_arrays
  25. # determine indices of relevant elements in the sum
  26. indices_l, indices_p = np.nonzero(full_m_array[:, None] == full_m_array[None, :])
  27. flat_l = full_l_array[indices_l]
  28. flat_p = full_l_array[indices_p]
  29. flat_m = full_m_array[indices_l] # the same as full_m_array[indices_p]
  30. relevant_pairs, = np.nonzero(flat_l >= flat_p)
  31. flat_l = flat_l[relevant_pairs]
  32. flat_p = flat_p[relevant_pairs]
  33. flat_m = flat_m[relevant_pairs]
  34. indices_l = indices_l[relevant_pairs]
  35. indices_p = indices_p[relevant_pairs]
  36. charge_factor = np.real(ex1.coefs[..., indices_l] * np.conj(ex2.coefs[..., indices_p]) +
  37. (-1) ** (flat_l + flat_p) * ex1.coefs[..., indices_p] * np.conj(ex2.coefs[..., indices_l]))
  38. all_s_array = np.arange(2 * ex1.max_l + 1)
  39. bessels = fn.sph_bessel_k(all_s_array, params.kappa * dist)
  40. # additional selection that excludes terms where Wigner 3j symbols are 0 by definition
  41. s_bool1 = np.abs(flat_l[:, None] - all_s_array[None, :]) <= flat_p[:, None]
  42. s_bool2 = flat_p[:, None] <= (flat_l[:, None] + all_s_array[None, :])
  43. indices_lpm_all, indices_s_all = np.nonzero(s_bool1 * s_bool2)
  44. # indices array can get really large (a lot of combinations) so we split the calculation into chunks to preserve RAM
  45. # interestingly, this also leads to performance improvements if chunks are still large enough
  46. if chunk_size is None:
  47. chunk_size = len(indices_lpm_all)
  48. num_sections = np.ceil(len(indices_lpm_all) / chunk_size)
  49. energy = 0
  50. for indices_lpm, indices_s in zip(np.array_split(indices_lpm_all, num_sections),
  51. np.array_split(indices_s_all, num_sections)):
  52. l_vals = flat_l[indices_lpm]
  53. p_vals = flat_p[indices_lpm]
  54. m_vals = flat_m[indices_lpm]
  55. s_vals = all_s_array[indices_s]
  56. bessel_vals = bessels[indices_s]
  57. # While all other arrays are 1D, this one can have extra leading axes corresponding to leading dimensions
  58. # of expansion coefficients. The last dimension is the same as other arrays.
  59. charge_vals = charge_factor[..., indices_lpm]
  60. lps_terms = (2 * s_vals + 1) * np.sqrt((2 * l_vals + 1) * (2 * p_vals + 1))
  61. # the same combination of l, s, and p is repeated many times
  62. _, unique_indices1, inverse1 = np.unique(np.stack((l_vals, s_vals, p_vals)),
  63. axis=1, return_inverse=True, return_index=True)
  64. wigner1 = wigner3j(2 * l_vals[unique_indices1], 2 * s_vals[unique_indices1], 2 * p_vals[unique_indices1],
  65. 0, 0, 0)[inverse1]
  66. # all the combinations (l, s, p, m) are unique
  67. wigner2 = wigner3j(2 * l_vals, 2 * s_vals, 2 * p_vals,
  68. -2 * m_vals, 0, 2 * m_vals)
  69. constants = params.R ** 2 / (params.kappa * params.epsilon * uc.CONSTANTS.epsilon0) * energy_units(units, params)
  70. C_vals = fn.interaction_coef_C(l_vals, p_vals, params.kappaR)
  71. lspm_vals = C_vals * (-1) ** (l_vals + m_vals) * lps_terms * bessel_vals * wigner1 * wigner2
  72. broadcasted_lspm_vals = np.broadcast_to(lspm_vals, charge_vals.shape)
  73. rescale_at_equal_lp = np.where(l_vals == p_vals, 0.5, 1)
  74. energy += constants * np.sum(rescale_at_equal_lp * broadcasted_lspm_vals * charge_vals, axis=-1)
  75. return energy