# interactions.py (3.7 KB)

  1. import expansion
  2. import parameters
  3. import functions as fn
  4. from py3nj import wigner3j
  5. import numpy as np
  6. import time
  7. from typing import Literal
  8. import units_and_constants as uc
  # Short aliases for project and NumPy types referenced in this module.
  Expansion = expansion.Expansion
  ModelParams = parameters.ModelParams
  Array = np.ndarray
  # Names accepted for the unit of returned energies (see ``energy_units``).
  EnergyUnit = Literal['kT', 'eV', 'J']
  13. def energy_units(units: EnergyUnit, params: ModelParams) -> float:
  14. match units:
  15. case 'eV':
  16. return 1 / (uc.CONSTANTS.e0 * uc.UNITS.voltage)
  17. case 'kT':
  18. return 1 / (params.temperature * uc.CONSTANTS.Boltzmann)
  19. case 'J':
  20. return uc.UNITS.energy
  def charged_shell_energy(ex1: Expansion, ex2: Expansion, dist: float, params: ModelParams, units: EnergyUnit = 'kT'):
      """Compute the interaction energy between two charged shells given as expansions.

      Both expansions are first brought to a common maximal ``l``. The energy is
      then a sum over (l, p, m, s) terms combining the expansion coefficients,
      ``fn.sph_bessel_k(s, kappa * dist)``, two Wigner 3j symbols and
      ``fn.interaction_coeff_C(l, p, kappaR)``.

      Args:
          ex1, ex2: expansions of the two shells. Their coefficient arrays may carry
              extra leading axes; these must broadcast against each other and are
              kept in the result.
          dist: distance between the shells in units of ``params.R`` (presumably
              centre-to-centre - confirm).
          params: model parameters; reads ``R``, ``kappa``, ``kappaR``, ``epsilon``
              and, for ``units='kT'``, ``temperature``.
          units: unit of the returned energy, see ``energy_units``.

      Returns:
          The interaction energy. This is a scalar for 1-D coefficient arrays;
          otherwise it is an array with the leading shape of the coefficients.
      """
      ex1, ex2 = expansion.expansions_to_common_l(ex1, ex2)
      dist = dist * params.R  # convert from units of R

      full_l_array, full_m_array = ex1.lm_arrays
      # The sum only runs over coefficient pairs with equal m: collect all such index pairs.
      indices_l, indices_p = np.nonzero(full_m_array[:, None] == full_m_array[None, :])
      flat_l = full_l_array[indices_l]
      flat_p = full_l_array[indices_p]
      flat_m = full_m_array[indices_l]  # identical to full_m_array[indices_p]

      charge_factor = np.real(ex1.coefs[..., indices_l] * np.conj(ex2.coefs[..., indices_p])
                              + (-1) ** (flat_l + flat_p) * ex1.coefs[..., indices_p] * np.conj(ex2.coefs[..., indices_l]))

      all_s_array = np.arange(2 * ex1.max_l + 1)
      bessels = fn.sph_bessel_k(all_s_array, params.kappa * dist)

      # Additionally drop (l, p, s) combinations violating the triangle condition
      # |l - s| <= p <= l + s, for which the Wigner 3j symbols vanish by definition.
      s_bool1 = np.abs(flat_l[:, None] - all_s_array[None, :]) <= flat_p[:, None]
      s_bool2 = flat_p[:, None] <= (flat_l[:, None] + all_s_array[None, :])
      indices_lpm, indices_s = np.nonzero(s_bool1 & s_bool2)

      l_vals = flat_l[indices_lpm]
      p_vals = flat_p[indices_lpm]
      m_vals = flat_m[indices_lpm]
      s_vals = all_s_array[indices_s]
      bessel_vals = bessels[indices_s]
      # All other per-term arrays are 1-D. This one keeps any leading axes of the
      # expansion coefficients; its last axis matches the other arrays.
      charge_vals = charge_factor[..., indices_lpm]

      lps_terms = (2 * s_vals + 1) * np.sqrt((2 * l_vals + 1) * (2 * p_vals + 1))

      # The same (l, s, p) combination repeats for many m, so the m-independent 3j symbol
      # is evaluated once per unique combination and scattered back.
      _, unique_indices1, inverse1 = np.unique(np.stack((l_vals, s_vals, p_vals)),
                                               axis=1, return_inverse=True, return_index=True)
      # Some NumPy 2.0 releases return a non-1-D inverse when ``axis`` is given; flatten it
      # so the gathered symbols stay 1-D like every other per-term array.
      inverse1 = np.ravel(inverse1)
      # py3nj's wigner3j takes doubled angular-momentum arguments (2*j, 2*m).
      wigner1 = wigner3j(2 * l_vals[unique_indices1], 2 * s_vals[unique_indices1], 2 * p_vals[unique_indices1],
                         0, 0, 0)[inverse1]
      # Every (l, s, p, m) combination is unique, so no deduplication is possible here.
      wigner2 = wigner3j(2 * l_vals, 2 * s_vals, 2 * p_vals,
                         -2 * m_vals, 0, 2 * m_vals)

      constants = params.R ** 2 / (params.kappa * params.epsilon * uc.CONSTANTS.epsilon0) * energy_units(units, params)

      C_vals = fn.interaction_coeff_C(l_vals, p_vals, params.kappaR)
      lspm_vals = C_vals * (-1) ** (l_vals + m_vals) * lps_terms * bessel_vals * wigner1 * wigner2
      # lspm_vals is 1-D and broadcasts against the trailing axis of charge_vals.
      return 0.5 * constants * np.sum(lspm_vals * charge_vals, axis=-1)
  def _main() -> None:
      """Time one ``charged_shell_energy`` evaluation of an expansion against its clone.

      Kept in a function so the demo's ``params``/``ex1``/``ex2``/``dist`` do not become
      module globals that shadow the function parameters of the same names.
      """
      params = ModelParams(R=150, kappaR=3)
      ex1 = expansion.MappedExpansionQuad(0.44, params.kappaR, 0.001, max_l=20)
      ex2 = ex1.clone()
      dist = 2.  # in units of params.R

      t0 = time.perf_counter()
      energy = charged_shell_energy(ex1, ex2, dist, params)
      t1 = time.perf_counter()
      print('energy: ', energy)
      print('time: ', t1 - t0)


  if __name__ == '__main__':
      _main()