expansion.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. from __future__ import annotations
  2. import numpy as np
  3. from dataclasses import dataclass
  4. from functools import cached_property
  5. import functions as fn
  6. import quaternionic
  7. import spherical
  8. import time
  9. import copy
  10. import matplotlib.pyplot as plt
  11. from typing import Callable, TypeVar
  12. Array = np.ndarray
  13. Quaternion = quaternionic.array
  14. T = TypeVar('T')
  15. V = TypeVar('V')
  16. class InvalidExpansion(Exception):
  17. pass
  18. @dataclass
  19. class Expansion:
  20. """Generic class for storing surface charge expansion coefficients."""
  21. l_array: Array
  22. coefs: Array
  23. _starting_coefs: Array = None # initialized with the __post_init__ method
  24. _rotations: Quaternion = Quaternion([1., 0., 0., 0.])
  25. def __post_init__(self):
  26. if self.coefs.shape[-1] != np.sum(2 * self.l_array + 1):
  27. raise InvalidExpansion('Number of expansion coefficients does not match the provided l_array.')
  28. if np.all(np.sort(self.l_array) != self.l_array) or np.all(np.unique(self.l_array) != self.l_array):
  29. raise InvalidExpansion('Array of l values should be unique and sorted.')
  30. self.coefs = self.coefs.astype(np.complex128)
  31. self._starting_coefs = np.copy(self.coefs)
  32. def __getitem__(self, item):
  33. return Expansion(self.l_array, self.coefs[item])
  34. @property
  35. def max_l(self) -> int:
  36. return max(self.l_array)
  37. @property
  38. def shape(self):
  39. return self.coefs.shape[:-1]
  40. def flatten(self) -> Expansion:
  41. new_expansion = self.clone() # np.ndarray.flatten() also copies the array
  42. new_expansion.coefs = new_expansion.coefs.reshape(-1, new_expansion.coefs.shape[-1])
  43. new_expansion._rotations = new_expansion._rotations.reshape(-1, 4)
  44. return new_expansion
  45. def reshape(self, shape: tuple):
  46. self.coefs = self.coefs.reshape(shape + (self.coefs.shape[-1],))
  47. self._rotations = self._rotations.reshape(shape + (4,))
  48. @cached_property
  49. def lm_arrays(self) -> (Array, Array):
  50. """Return l and m arrays containing all (l, m) pairs."""
  51. return full_fm_arrays(self.l_array)
  52. def repeat_over_m(self, arr: Array, axis=0) -> Array:
  53. if not arr.shape[axis] == len(self.l_array):
  54. raise ValueError('Array length should be equal to the number of l in the expansion.')
  55. return np.repeat(arr, 2 * self.l_array + 1, axis=axis)
  56. def rotate(self, rotations: Quaternion, rotate_existing=False):
  57. # TODO: rotations are currently saved wrong if we start form existing coefficients not the og ones
  58. self._rotations = rotations
  59. coefs = self.coefs if rotate_existing else self._starting_coefs
  60. self.coefs = expansion_rotation(rotations, coefs, self.l_array)
  61. def rotate_euler(self, alpha: Array, beta: Array, gamma: Array, rotate_existing=False):
  62. # TODO: additional care required on the convention used to transform euler angles to quaternions
  63. # TODO: might be off for a minus sign for each? angle !!
  64. R_euler = quaternionic.array.from_euler_angles(alpha, beta, gamma)
  65. self.rotate(R_euler, rotate_existing=rotate_existing)
  66. def charge_value(self, theta: Array | float, phi: Array | float):
  67. if not isinstance(theta, Array):
  68. theta = np.array([theta])
  69. if not isinstance(phi, Array):
  70. phi = np.array([phi])
  71. theta, phi = np.broadcast_arrays(theta, phi)
  72. full_l_array, full_m_array = self.lm_arrays
  73. return np.squeeze(np.real(np.sum(self.coefs[..., None] * fn.sph_harm(full_l_array[:, None],
  74. full_m_array[:, None],
  75. theta[None, :], phi[None, :]), axis=-2)))
  76. def clone(self) -> Expansion:
  77. return copy.deepcopy(self)
  78. class Expansion24(Expansion):
  79. def __init__(self, sigma2: float, sigma4: float, sigma0: float = 0.):
  80. l_array = np.array([0, 2, 4])
  81. coefs = rot_sym_expansion(l_array, np.array([sigma0, sigma2, sigma4]))
  82. super().__init__(l_array, coefs)
  83. class MappedExpansionQuad(Expansion):
  84. def __init__(self, a_bar: Array | float, kappaR: Array | float, sigma_m: float, l_max: int = 20, sigma0: float = 0):
  85. a_bar, kappaR = np.broadcast_arrays(a_bar, kappaR)
  86. l_array = np.array([l for l in range(l_max + 1) if l % 2 == 0])
  87. a_bar, kappaR, l_array_expanded = np.broadcast_arrays(a_bar[..., None], kappaR[..., None], l_array[None, :])
  88. coefs = (2 * sigma_m * fn.coef_C_diff(l_array_expanded, kappaR)
  89. * np.sqrt(4 * np.pi * (2 * l_array_expanded + 1)) * np.power(a_bar, l_array_expanded))
  90. coefs[..., 0] = sigma0
  91. coefs = rot_sym_expansion(l_array, coefs)
  92. super().__init__(l_array, coefs)
  93. class GaussianCharges(Expansion):
  94. def __init__(self, omega_k: Array, lambda_k: Array | float, sigma1: float, l_max: int,
  95. sigma0: float = 0, equal_charges=True):
  96. omega_k = omega_k.reshape(-1, 2)
  97. if not isinstance(lambda_k, Array):
  98. lambda_k = np.array([lambda_k])
  99. if equal_charges:
  100. if lambda_k.ndim > 1:
  101. raise ValueError(f'If equal_charges=True, lambda_k should be a 1D array, got shape {lambda_k.shape}')
  102. lambda_k = np.full((omega_k.shape[0], lambda_k.shape[0]), lambda_k).T
  103. if lambda_k.shape[-1] != omega_k.shape[0]:
  104. raise ValueError("Number of charges (length of omega_k) should match the last dimension of lambda_k array.")
  105. lambda_k = lambda_k.reshape(-1, omega_k.shape[0])
  106. l_array = np.arange(l_max + 1)
  107. full_l_array, full_m_array = full_fm_arrays(l_array)
  108. theta_k = omega_k[:, 0]
  109. phi_k = omega_k[:, 1]
  110. summands = (lambda_k[:, None, :] / np.sinh(lambda_k[:, None, :])
  111. * fn.sph_bessel_i(full_l_array[None, :, None], lambda_k[:, None, :])
  112. * np.conj(fn.sph_harm(full_l_array[None, :, None], full_m_array[None, :, None],
  113. theta_k[None, None, :], phi_k[None, None, :])))
  114. coefs = np.squeeze(4 * np.pi * sigma1 * np.sum(summands, axis=-1))
  115. coefs[..., 0] = sigma0
  116. l_array, coefs = purge_unneeded_l(l_array, coefs)
  117. super().__init__(l_array, coefs)
  118. def map_over_expansion(f: Callable[[Expansion, T], V]) -> Callable[[Expansion, T], V]:
  119. """Map a function f over the leading axes of an expansion. Uses for loops, so it is kinda slow."""
  120. def mapped_f(ex: Expansion, *args, **kwargs):
  121. og_shape = ex.shape
  122. flat_ex = ex.flatten()
  123. results = []
  124. for i in range(np.prod(og_shape)):
  125. results.append(f(flat_ex[i], *args, **kwargs))
  126. try:
  127. return np.array(results).reshape(og_shape + results[0].shape)
  128. except AttributeError:
  129. return np.array(results).reshape(og_shape)
  130. return mapped_f
  131. def full_fm_arrays(l_array: Array) -> (Array, Array):
  132. all_m_list = []
  133. for l in l_array:
  134. for i in range(2 * l + 1):
  135. all_m_list.append(-l + i)
  136. return np.repeat(l_array, 2 * l_array + 1), np.array(all_m_list)
  137. def rot_sym_expansion(l_array: Array, coefs: Array) -> Array:
  138. """Create full expansion array for rotationally symmetric distributions with only m=0 terms different form 0."""
  139. full_coefs = np.zeros(coefs.shape[:-1] + (np.sum(2 * l_array + 1),))
  140. full_coefs[..., np.cumsum(2 * l_array + 1) - l_array - 1] = coefs
  141. return full_coefs
  142. def coefs_fill_missing_l(expansion: Expansion, target_l_array: Array) -> Expansion:
  143. missing_l = np.setdiff1d(target_l_array, expansion.l_array, assume_unique=True)
  144. fill = np.zeros(np.sum(2 * missing_l + 1))
  145. full_l_array1, _ = expansion.lm_arrays
  146. # we search for where to place missing coefs with the help of a boolean array and argmax function
  147. comparison_bool = (full_l_array1[:, None] - missing_l[None, :]) > 0
  148. indices = np.where(np.any(comparison_bool, axis=0), np.argmax(comparison_bool, axis=0), full_l_array1.shape[0])
  149. new_coefs = np.insert(expansion.coefs, np.repeat(indices, 2 * missing_l + 1), fill, axis=-1)
  150. return Expansion(target_l_array, new_coefs)
  151. def m_indices_at_l(l_arr: Array, l_idx: int):
  152. return np.arange(np.sum(2 * l_arr[:l_idx] + 1), np.sum(2 * l_arr[:l_idx+1] + 1))
  153. def purge_unneeded_l(l_array: Array, coefs: Array) -> (Array, Array):
  154. def delete_zero_entries(l, l_arr, cfs):
  155. l_idx = np.where(l_arr == l)[0][0]
  156. m_indices = m_indices_at_l(l_arr, l_idx)
  157. if np.all(cfs[..., m_indices] == 0):
  158. return np.delete(l_arr, l_idx), np.delete(cfs, m_indices, axis=-1)
  159. return l_arr, cfs
  160. for l in l_array:
  161. l_array, coefs = delete_zero_entries(l, l_array, coefs)
  162. return l_array, coefs
  163. def plot_theta_profile(ex: Expansion, phi=0, num=100):
  164. theta_vals = np.linspace(0, np.pi, num)
  165. charge = ex.charge_value(theta_vals, phi)
  166. plt.plot(theta_vals, charge)
  167. plt.show()
  168. def expansions_to_common_l(ex1: Expansion, ex2: Expansion) -> (Expansion, Expansion):
  169. common_l_array = np.union1d(ex1.l_array, ex2.l_array)
  170. return coefs_fill_missing_l(ex1, common_l_array), coefs_fill_missing_l(ex2, common_l_array)
  171. def expansion_rotation(rotations: Quaternion, coefs: Array, l_array: Array):
  172. """
  173. General function for rotations of expansion coefficients using WignerD matrices. Combines all rotations
  174. with each expansion given in coefs array.
  175. :param rotations: Quaternion array, last dimension is 4
  176. :param coefs: array of expansion coefficients
  177. :param l_array: array of all ell values of the expansion
  178. :return rotated coefficients, output shape is rotations.shape[:-1] + coefs.shape
  179. """
  180. rot_arrays = rotations.ndarray.reshape((-1, 4))
  181. coefs_reshaped = coefs.reshape((-1, coefs.shape[-1]))
  182. wigner_matrices = spherical.Wigner(np.max(l_array)).D(rot_arrays)
  183. new_coefs = np.zeros((rot_arrays.shape[0],) + coefs_reshaped.shape, dtype=np.complex128)
  184. for i, l in enumerate(l_array):
  185. Dlmn_slice = np.arange(l * (2 * l - 1) * (2 * l + 1) / 3, (l + 1) * (2 * l + 1) * (2 * l + 3) / 3).astype(int)
  186. all_m_indices = m_indices_at_l(l_array, i)
  187. wm = wigner_matrices[:, Dlmn_slice].reshape((-1, 2*l+1, 2*l+1))
  188. new_coefs[..., all_m_indices] = np.einsum('rnm, qm -> rqn',
  189. wm, coefs_reshaped[:, all_m_indices])
  190. return new_coefs.reshape(rotations.ndarray.shape[:-1] + coefs.shape)
  191. if __name__ == '__main__':
  192. # ex = MappedExpansionQuad(0.44, 3, 1, 10)
  193. # ex = Expansion(np.arange(3), np.array([1, -1, 0, 1, 2, 0, 3, 0, 2]))
  194. ex = GaussianCharges(omega_k=np.array([[0, 0], [np.pi, 0]]), lambda_k=10, sigma1=0.001, l_max=10)
  195. # print(ex.coefs)
  196. plot_theta_profile(ex)
  197. # new_coeffs = expansion_rotation(Quaternion(np.arange(20).reshape(5, 4)).normalized, ex.coeffs, ex.l_array)
  198. # print(new_coeffs.shape)
  199. #
  200. # newnew_coeffs = expansion_rotation(Quaternion(np.arange(16).reshape(4, 4)).normalized, new_coeffs, ex.l_array)
  201. # print(newnew_coeffs.shape)
  202. # rot_angles = np.linspace(0, np.pi, 1000)
  203. # t0 = time.time()
  204. # ex.rotate_euler(0, np.pi / 2, -1)
  205. # t1 = time.time()
  206. # print(ex.coefs)
  207. # print(t1 - t0)