expansion.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
from __future__ import annotations

import copy
import time
from dataclasses import dataclass, field
from functools import cached_property

import matplotlib.pyplot as plt
import numpy as np
import quaternionic
import spherical

import functions as fn
  11. Array = np.ndarray
  12. Quaternion = quaternionic.array
  13. class InvalidExpansion(Exception):
  14. pass
  15. @dataclass
  16. class Expansion:
  17. """Generic class for storing surface charge expansion coefficients."""
  18. l_array: Array
  19. coefs: Array
  20. _starting_coefs: Array = None # initialized with the __post_init__ method
  21. _rotations: Quaternion = Quaternion([1., 0., 0., 0.])
  22. def __post_init__(self):
  23. if self.coefs.shape[-1] != np.sum(2 * self.l_array + 1):
  24. raise InvalidExpansion('Number of expansion coefficients does not match the provided l_array.')
  25. if np.all(np.sort(self.l_array) != self.l_array) or np.all(np.unique(self.l_array) != self.l_array):
  26. raise InvalidExpansion('Array of l values should be unique and sorted.')
  27. self.coefs = self.coefs.astype(np.complex128)
  28. self._starting_coefs = np.copy(self.coefs)
  29. @property
  30. def max_l(self) -> int:
  31. return max(self.l_array)
  32. @cached_property
  33. def lm_arrays(self) -> (Array, Array):
  34. """Return l and m arrays containing all (l, m) pairs."""
  35. return full_fm_arrays(self.l_array)
  36. def repeat_over_m(self, arr: Array, axis=0) -> Array:
  37. if not arr.shape[axis] == len(self.l_array):
  38. raise ValueError('Array length should be equal to the number of l in the expansion.')
  39. return np.repeat(arr, 2 * self.l_array + 1, axis=axis)
  40. def rotate(self, rotations: Quaternion, rotate_existing=False):
  41. self._rotations = rotations
  42. coefs = self.coefs if rotate_existing else self._starting_coefs
  43. self.coefs = expansion_rotation(rotations, coefs, self.l_array)
  44. def rotate_euler(self, alpha: Array, beta: Array, gamma: Array, rotate_existing=False):
  45. # TODO: additional care required on the convention used to transform euler angles to quaternions
  46. # TODO: might be off for a minus sign for each? angle !!
  47. R_euler = quaternionic.array.from_euler_angles(alpha, beta, gamma)
  48. self.rotate(R_euler, rotate_existing=rotate_existing)
  49. def charge_value(self, theta: Array | float, phi: Array | float):
  50. if not isinstance(theta, Array):
  51. theta = np.array([theta])
  52. if not isinstance(phi, Array):
  53. phi = np.array([phi])
  54. theta, phi = np.broadcast_arrays(theta, phi)
  55. full_l_array, full_m_array = self.lm_arrays
  56. return np.sum(self.coefs[None, :] * fn.sph_harm(full_l_array[None, :], full_m_array[None, :],
  57. theta[:, None], phi[:, None]), axis=1)
  58. def clone(self) -> Expansion:
  59. return copy.deepcopy(self)
  60. class Expansion24(Expansion):
  61. def __init__(self, sigma2: float, sigma4: float, sigma0: float = 0.):
  62. l_array = np.array([0, 2, 4])
  63. coeffs = rot_sym_expansion(l_array, np.array([sigma0, sigma2, sigma4]))
  64. super().__init__(l_array, coeffs)
  65. class MappedExpansionQuad(Expansion):
  66. def __init__(self, a_bar: float, kappaR: float, sigma_m: float, max_l: int = 20, sigma0: float = 0):
  67. l_array = np.array([l for l in range(max_l + 1) if l % 2 == 0])
  68. coeffs = (2 * sigma_m * fn.coeff_C_diff(l_array, kappaR)
  69. * np.sqrt(4 * np.pi * (2 * l_array + 1)) * np.power(a_bar, l_array))
  70. coeffs[0] = sigma0
  71. coeffs = rot_sym_expansion(l_array, coeffs)
  72. super().__init__(l_array, coeffs)
  73. class SmearedCharges(Expansion):
  74. def __init__(self, omega_k: Array, lambda_k: Array | float, sigma1: float, l_max: int, sigma0: float = 0):
  75. omega_k = omega_k.reshape(-1, 2)
  76. if not isinstance(lambda_k, Array):
  77. lambda_k = np.full((len(omega_k),), lambda_k)
  78. if lambda_k.shape[-1] != omega_k.shape[0]:
  79. raise ValueError("Omega and lambda arrays should have the same length.")
  80. l_array = np.arange(l_max + 1)
  81. full_l_array, full_m_array = full_fm_arrays(l_array)
  82. theta_k = omega_k[:, 0]
  83. phi_k = omega_k[:, 1]
  84. summands = (lambda_k[None, :] / np.sinh(lambda_k[None, :])
  85. * fn.sph_bessel_i(full_l_array[:, None], lambda_k[None, :])
  86. * np.conj(fn.sph_harm(full_l_array[:, None], full_m_array[:, None],
  87. theta_k[None, :], phi_k[None, :])))
  88. coefs = 4 * np.pi * sigma1 * np.sum(summands, axis=1)
  89. coefs[0] = sigma0
  90. super().__init__(l_array, coefs)
  91. def full_fm_arrays(l_array: Array) -> (Array, Array):
  92. all_m_list = []
  93. for l in l_array:
  94. for i in range(2 * l + 1):
  95. all_m_list.append(-l + i)
  96. return np.repeat(l_array, 2 * l_array + 1), np.array(all_m_list)
  97. def rot_sym_expansion(l_array: Array, coeffs: Array) -> Array:
  98. """Create full expansion array for rotationally symmetric distributions with only m=0 terms different form 0."""
  99. full_coeffs = np.zeros(np.sum(2 * l_array + 1))
  100. full_coeffs[np.cumsum(2 * l_array + 1) - l_array - 1] = coeffs
  101. return full_coeffs
  102. def coeffs_fill_missing_l(expansion: Expansion, target_l_array: Array) -> Expansion:
  103. missing_l = np.setdiff1d(target_l_array, expansion.l_array, assume_unique=True)
  104. fill = np.zeros(np.sum(2 * missing_l + 1))
  105. full_l_array1, _ = expansion.lm_arrays
  106. # we search for where to place missing coeffs with the help of a boolean array and argmax function
  107. comparison_bool = (full_l_array1[:, None] - missing_l[None, :]) > 0
  108. indices = np.where(np.any(comparison_bool, axis=0), np.argmax(comparison_bool, axis=0), full_l_array1.shape[0])
  109. new_coeffs = np.insert(expansion.coefs, np.repeat(indices, 2 * missing_l + 1), fill, axis=-1)
  110. return Expansion(target_l_array, new_coeffs)
  111. def plot_theta_profile(ex: Expansion, phi=0, num=100):
  112. theta_vals = np.linspace(0, np.pi, num)
  113. charge = ex.charge_value(theta_vals, phi)
  114. plt.plot(theta_vals, charge)
  115. plt.show()
  116. def expansions_to_common_l(ex1: Expansion, ex2: Expansion) -> (Expansion, Expansion):
  117. common_l_array = np.union1d(ex1.l_array, ex2.l_array)
  118. return coeffs_fill_missing_l(ex1, common_l_array), coeffs_fill_missing_l(ex2, common_l_array)
  119. def expansion_rotation(rotations: Quaternion, coefs: Array, l_array: Array):
  120. """
  121. General function for rotations of expansion coefficients using WignerD matrices. Combines all rotations
  122. with each expansion given in coefs array.
  123. :param rotations: Quaternion array, last dimension is 4
  124. :param coefs: array of expansion coefficients
  125. :param l_array: array of all ell values of the expansion
  126. :return rotated coefficients, output shape is rotations.shape[:-1] + coefs.shape
  127. """
  128. rot_arrays = rotations.ndarray.reshape((-1, 4))
  129. coefs_reshaped = coefs.reshape((-1, coefs.shape[-1]))
  130. wigner_matrices = spherical.Wigner(np.max(l_array)).D(rot_arrays)
  131. new_coefs = np.zeros((rot_arrays.shape[0],) + coefs_reshaped.shape, dtype=np.complex128)
  132. for i, l in enumerate(l_array):
  133. Dlmn_slice = np.arange(l * (2 * l - 1) * (2 * l + 1) / 3, (l + 1) * (2 * l + 1) * (2 * l + 3) / 3).astype(int)
  134. all_m_indices = np.arange(np.sum(2 * l_array[:i] + 1), np.sum(2 * l_array[:i + 1] + 1))
  135. wm = wigner_matrices[:, Dlmn_slice].reshape((-1, 2*l+1, 2*l+1))
  136. new_coefs[..., all_m_indices] = np.einsum('rnm, qm -> rqn',
  137. wm, coefs_reshaped[:, all_m_indices])
  138. return new_coefs.reshape(rotations.ndarray.shape[:-1] + coefs.shape)
  139. if __name__ == '__main__':
  140. # ex = MappedExpansionQuad(0.44, 3, 1, 10)
  141. # ex = Expansion(np.arange(3), np.array([1, -1, 0, 1, 2, 0, 3, 0, 2]))
  142. ex = SmearedCharges(omega_k=np.array([[0, 0], [np.pi, 0]]), lambda_k=10, sigma1=0.001, l_max=10)
  143. # print(ex.coefs)
  144. plot_theta_profile(ex)
  145. # new_coeffs = expansion_rotation(Quaternion(np.arange(20).reshape(5, 4)).normalized, ex.coeffs, ex.l_array)
  146. # print(new_coeffs.shape)
  147. #
  148. # newnew_coeffs = expansion_rotation(Quaternion(np.arange(16).reshape(4, 4)).normalized, new_coeffs, ex.l_array)
  149. # print(newnew_coeffs.shape)
  150. # rot_angles = np.linspace(0, np.pi, 1000)
  151. # t0 = time.time()
  152. # ex.rotate_euler(0, np.pi / 2, -1)
  153. # t1 = time.time()
  154. # print(ex.coefs)
  155. # print(t1 - t0)