expansion.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
from __future__ import annotations

import copy
import time
from dataclasses import dataclass, field
from functools import cached_property, wraps
from typing import Callable, TypeVar

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import quaternionic
import spherical
from scipy.special import eval_legendre

import functions as fn
  14. Array = np.ndarray
  15. Quaternion = quaternionic.array
  16. T = TypeVar('T')
  17. V = TypeVar('V')
  18. class InvalidExpansion(Exception):
  19. pass
  20. @dataclass
  21. class Expansion:
  22. """Generic class for storing surface charge expansion coefficients."""
  23. l_array: Array
  24. coefs: Array
  25. _starting_coefs: Array = None # initialized with the __post_init__ method
  26. _rotations: Quaternion = Quaternion([1., 0., 0., 0.])
  27. def __post_init__(self):
  28. if self.coefs.shape[-1] != np.sum(2 * self.l_array + 1):
  29. raise InvalidExpansion('Number of expansion coefficients does not match the provided l_array.')
  30. if np.all(np.sort(self.l_array) != self.l_array) or np.all(np.unique(self.l_array) != self.l_array):
  31. raise InvalidExpansion('Array of l values should be unique and sorted.')
  32. self.coefs = self.coefs.astype(np.complex128)
  33. self._starting_coefs = np.copy(self.coefs)
  34. def __getitem__(self, item):
  35. return Expansion(self.l_array, self.coefs[item])
  36. @property
  37. def max_l(self) -> int:
  38. return max(self.l_array)
  39. @property
  40. def shape(self):
  41. return self.coefs.shape[:-1]
  42. def flatten(self) -> Expansion:
  43. new_expansion = self.clone() # np.ndarray.flatten() also copies the array
  44. new_expansion.coefs = new_expansion.coefs.reshape(-1, new_expansion.coefs.shape[-1])
  45. new_expansion._rotations = new_expansion._rotations.reshape(-1, 4)
  46. return new_expansion
  47. def reshape(self, shape: tuple):
  48. self.coefs = self.coefs.reshape(shape + (self.coefs.shape[-1],))
  49. self._rotations = self._rotations.reshape(shape + (4,))
  50. @cached_property
  51. def lm_arrays(self) -> (Array, Array):
  52. """Return l and m arrays containing all (l, m) pairs."""
  53. return full_fm_arrays(self.l_array)
  54. def repeat_over_m(self, arr: Array, axis=0) -> Array:
  55. if not arr.shape[axis] == len(self.l_array):
  56. raise ValueError('Array length should be equal to the number of l in the expansion.')
  57. return np.repeat(arr, 2 * self.l_array + 1, axis=axis)
  58. def rotate(self, rotations: Quaternion, rotate_existing=False):
  59. # TODO: rotations are currently saved wrong if we start form existing coefficients not the og ones
  60. self._rotations = rotations
  61. coefs = self.coefs if rotate_existing else self._starting_coefs
  62. self.coefs = expansion_rotation(rotations, coefs, self.l_array)
  63. def rotate_euler(self, alpha: Array, beta: Array, gamma: Array, rotate_existing=False):
  64. # TODO: additional care required on the convention used to transform euler angles to quaternions
  65. # TODO: might be off for a minus sign for each? angle !!
  66. R_euler = quaternionic.array.from_euler_angles(alpha, beta, gamma)
  67. self.rotate(R_euler, rotate_existing=rotate_existing)
  68. def charge_value(self, theta: Array | float, phi: Array | float):
  69. if not isinstance(theta, Array):
  70. theta = np.array([theta])
  71. if not isinstance(phi, Array):
  72. phi = np.array([phi])
  73. theta, phi = np.broadcast_arrays(theta, phi)
  74. full_l_array, full_m_array = self.lm_arrays
  75. return np.squeeze(np.real(np.sum(self.coefs[..., None] * fn.sph_harm(full_l_array[:, None],
  76. full_m_array[:, None],
  77. theta[None, :], phi[None, :]), axis=-2)))
  78. def clone(self) -> Expansion:
  79. return copy.deepcopy(self)
  80. class Expansion24(Expansion):
  81. def __init__(self, sigma2: float, sigma4: float, sigma0: float = 0.):
  82. l_array = np.array([0, 2, 4])
  83. coefs = rot_sym_expansion(l_array, np.array([sigma0, sigma2, sigma4]))
  84. super().__init__(l_array, coefs)
  85. class MappedExpansionQuad(Expansion):
  86. def __init__(self, a_bar: Array | float, kappaR: Array | float, sigma_m: float, l_max: int = 20, sigma0: float = 0):
  87. a_bar, kappaR = np.broadcast_arrays(a_bar, kappaR)
  88. l_array = np.array([l for l in range(l_max + 1) if l % 2 == 0])
  89. a_bar, kappaR, l_array_expanded = np.broadcast_arrays(a_bar[..., None], kappaR[..., None], l_array[None, :])
  90. coefs = (2 * sigma_m * fn.coef_C_diff(l_array_expanded, kappaR)
  91. * np.sqrt(4 * np.pi * (2 * l_array_expanded + 1)) * np.power(a_bar, l_array_expanded))
  92. coefs[..., 0] = sigma0 / np.sqrt(4 * np.pi)
  93. coefs = np.squeeze(rot_sym_expansion(l_array, coefs))
  94. super().__init__(l_array, coefs)
  95. class GaussianCharges(Expansion):
  96. def __init__(self, omega_k: Array, lambda_k: Array | float, sigma1: float, l_max: int,
  97. sigma0: float = 0, equal_charges: bool = True):
  98. omega_k = omega_k.reshape(-1, 2)
  99. if not isinstance(lambda_k, Array):
  100. lambda_k = np.array([lambda_k])
  101. if equal_charges:
  102. if lambda_k.ndim > 1:
  103. raise ValueError(f'If equal_charges=True, lambda_k should be a 1D array, got shape {lambda_k.shape}')
  104. lambda_k = np.full((omega_k.shape[0], lambda_k.shape[0]), lambda_k).T
  105. if lambda_k.shape[-1] != omega_k.shape[0]:
  106. raise ValueError("Number of charges (length of omega_k) should match the last dimension of lambda_k array.")
  107. lambda_k = lambda_k.reshape(-1, omega_k.shape[0])
  108. l_array = np.arange(l_max + 1)
  109. full_l_array, full_m_array = full_fm_arrays(l_array)
  110. theta_k = omega_k[:, 0]
  111. phi_k = omega_k[:, 1]
  112. summands = (lambda_k[:, None, :] / np.sinh(lambda_k[:, None, :])
  113. * fn.sph_bessel_i(full_l_array[None, :, None], lambda_k[:, None, :])
  114. * np.conj(fn.sph_harm(full_l_array[None, :, None], full_m_array[None, :, None],
  115. theta_k[None, None, :], phi_k[None, None, :])))
  116. coefs = np.squeeze(4 * np.pi * sigma1 * np.sum(summands, axis=-1))
  117. coefs[..., 0] = sigma0 / np.sqrt(4 * np.pi)
  118. l_array, coefs = purge_unneeded_l(l_array, coefs)
  119. super().__init__(l_array, coefs)
  120. class SphericalCap(Expansion):
  121. def __init__(self, omega_k: Array, theta0_k: Array | float, sigma1: float, l_max: int, sigma0: float = 0,
  122. equal_sizes: bool = True):
  123. omega_k = omega_k.reshape(-1, 2)
  124. if not isinstance(theta0_k, Array):
  125. theta0_k = np.array(theta0_k)
  126. if equal_sizes:
  127. if theta0_k.ndim == 0:
  128. theta0_k = np.full(omega_k.shape[0], theta0_k)
  129. elif theta0_k.ndim == 1:
  130. theta0_k = np.full((omega_k.shape[0], theta0_k.shape[0]), theta0_k)
  131. else:
  132. raise ValueError(f'If equal_charges=True, theta0_k should be a 1D array, got shape {theta0_k.shape}')
  133. if theta0_k.shape[0] != omega_k.shape[0]:
  134. raise ValueError("Number of charges (length of omega_k) should match the last dimension of theta0_k array.")
  135. rotations = Quaternion(np.stack((np.cos(omega_k[..., 0] / 2),
  136. np.sin(omega_k[..., 1]) * np.sin(omega_k[..., 0] / 2),
  137. np.cos(omega_k[..., 1]) * np.sin(omega_k[..., 0] / 2),
  138. np.zeros_like(omega_k[..., 0]))).T)
  139. l_array = np.arange(l_max + 1)
  140. coefs_l0 = -sigma1 * (np.sqrt(np.pi / (2 * l_array[None, :] + 1)) *
  141. (eval_legendre(l_array[None, :] + 1, np.cos(theta0_k[..., None]))
  142. - eval_legendre(l_array[None, :] - 1, np.cos(theta0_k[..., None]))))
  143. coefs = rot_sym_expansion(l_array, coefs_l0)
  144. coefs_all_single_caps = expansion_rotation(rotations, coefs, l_array)
  145. # Rotating is implemented in such a way that it rotates every patch to every position,
  146. # hence the redundancy of out of diagonal elements.
  147. coefs_all = np.sum(np.diagonal(coefs_all_single_caps), axis=-1)
  148. if sigma0 is not None:
  149. coefs_all[..., 0] = sigma0 / np.sqrt(4 * np.pi)
  150. super().__init__(l_array, np.squeeze(coefs_all))
  151. def map_over_expansion(f: Callable[[Expansion, T], V]) -> Callable[[Expansion, T], V]:
  152. """Map a function f over the leading axes of an expansion. Uses for loops, so it is kinda slow."""
  153. def mapped_f(ex: Expansion, *args, **kwargs):
  154. og_shape = ex.shape
  155. flat_ex = ex.flatten()
  156. results = []
  157. for i in range(int(np.prod(og_shape))):
  158. results.append(f(flat_ex[i], *args, **kwargs))
  159. try:
  160. return np.array(results).reshape(og_shape + results[0].shape)
  161. except AttributeError:
  162. return np.array(results).reshape(og_shape)
  163. return mapped_f
  164. def full_fm_arrays(l_array: Array) -> (Array, Array):
  165. all_m_list = []
  166. for l in l_array:
  167. for i in range(2 * l + 1):
  168. all_m_list.append(-l + i)
  169. return np.repeat(l_array, 2 * l_array + 1), np.array(all_m_list)
  170. def rot_sym_expansion(l_array: Array, coefs: Array) -> Array:
  171. """Create full expansion array for rotationally symmetric distributions with only m=0 terms different form 0."""
  172. full_coefs = np.zeros(coefs.shape[:-1] + (np.sum(2 * l_array + 1),))
  173. full_coefs[..., np.cumsum(2 * l_array + 1) - l_array - 1] = coefs
  174. return full_coefs
  175. def coefs_fill_missing_l(expansion: Expansion, target_l_array: Array) -> Expansion:
  176. missing_l = np.setdiff1d(target_l_array, expansion.l_array, assume_unique=True)
  177. fill = np.zeros(np.sum(2 * missing_l + 1))
  178. full_l_array1, _ = expansion.lm_arrays
  179. # we search for where to place missing coefs with the help of a boolean array and argmax function
  180. comparison_bool = (full_l_array1[:, None] - missing_l[None, :]) > 0
  181. indices = np.where(np.any(comparison_bool, axis=0), np.argmax(comparison_bool, axis=0), full_l_array1.shape[0])
  182. new_coefs = np.insert(expansion.coefs, np.repeat(indices, 2 * missing_l + 1), fill, axis=-1)
  183. return Expansion(target_l_array, new_coefs)
  184. def m_indices_at_l(l_arr: Array, l_idx: int):
  185. return np.arange(np.sum(2 * l_arr[:l_idx] + 1), np.sum(2 * l_arr[:l_idx+1] + 1))
  186. def purge_unneeded_l(l_array: Array, coefs: Array) -> (Array, Array):
  187. def delete_zero_entries(l, l_arr, cfs):
  188. l_idx = np.where(l_arr == l)[0][0]
  189. m_indices = m_indices_at_l(l_arr, l_idx)
  190. if np.all(cfs[..., m_indices] == 0):
  191. return np.delete(l_arr, l_idx), np.delete(cfs, m_indices, axis=-1)
  192. return l_arr, cfs
  193. for l in l_array:
  194. l_array, coefs = delete_zero_entries(l, l_array, coefs)
  195. return l_array, coefs
  196. def plot_theta_profile(ex: Expansion, phi: float = 0, num: int = 100, theta_start: float = 0, theta_end: float = np.pi):
  197. theta_vals = np.linspace(theta_start, theta_end, num)
  198. charge = ex.charge_value(theta_vals, phi)
  199. plt.plot(theta_vals, charge.T)
  200. plt.show()
  201. def plot_charge_3d(ex: Expansion, num_theta=100, num_phi=100):
  202. theta = np.linspace(0, np.pi, num_theta)
  203. phi = np.linspace(0, 2 * np.pi, num_phi)
  204. theta, phi = np.meshgrid(theta, phi)
  205. r = ex.charge_value(theta.flatten(), phi.flatten()).reshape(theta.shape)
  206. # Convert spherical coordinates to Cartesian coordinates
  207. x = np.sin(theta) * np.cos(phi)
  208. y = np.sin(theta) * np.sin(phi)
  209. z = np.cos(theta)
  210. # Create a heatmap on the sphere
  211. fig = go.Figure(data=go.Surface(x=x, y=y, z=z, surfacecolor=r, colorscale='Jet'))
  212. fig.update_layout(scene=dict(aspectmode='data'))
  213. fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
  214. fig.show()
  215. def expansions_to_common_l(ex1: Expansion, ex2: Expansion) -> (Expansion, Expansion):
  216. common_l_array = np.union1d(ex1.l_array, ex2.l_array)
  217. return coefs_fill_missing_l(ex1, common_l_array), coefs_fill_missing_l(ex2, common_l_array)
  218. def expansion_rotation(rotations: Quaternion, coefs: Array, l_array: Array):
  219. """
  220. General function for rotations of expansion coefficients using WignerD matrices. Combines all rotations
  221. with each expansion given in coefs array.
  222. :param rotations: Quaternion array, last dimension is 4
  223. :param coefs: array of expansion coefficients
  224. :param l_array: array of all ell values of the expansion
  225. :return rotated coefficients, output shape is rotations.shape[:-1] + coefs.shape
  226. """
  227. rot_arrays = rotations.ndarray.reshape((-1, 4))
  228. coefs_reshaped = coefs.reshape((-1, coefs.shape[-1]))
  229. wigner_matrices = spherical.Wigner(np.max(l_array)).D(rot_arrays)
  230. new_coefs = np.zeros((rot_arrays.shape[0],) + coefs_reshaped.shape, dtype=np.complex128)
  231. for i, l in enumerate(l_array):
  232. Dlmn_slice = np.arange(l * (2 * l - 1) * (2 * l + 1) / 3, (l + 1) * (2 * l + 1) * (2 * l + 3) / 3).astype(int)
  233. all_m_indices = m_indices_at_l(l_array, i)
  234. wm = wigner_matrices[:, Dlmn_slice].reshape((-1, 2*l+1, 2*l+1))
  235. new_coefs[..., all_m_indices] = np.einsum('rnm, qm -> rqn',
  236. wm, coefs_reshaped[:, all_m_indices])
  237. return new_coefs.reshape(rotations.ndarray.shape[:-1] + coefs.shape)
  238. if __name__ == '__main__':
  239. # ex = MappedExpansionQuad(0.44, 3, 1, 10)
  240. # ex = Expansion(np.arange(3), np.array([1, -1, 0, 1, 2, 0, 3, 0, 2]))
  241. # ex = GaussianCharges(omega_k=np.array([[0, 0], [np.pi, 0]]), lambda_k=10, sigma1=0.001, l_max=10)
  242. ex = SphericalCap(np.array([[0, 0], [np.pi, 0]]), 0.5, 0.1, 70)
  243. print(np.real(ex.coefs))
  244. plot_theta_profile(ex, num=1000, theta_end=2 * np.pi, phi=0)
  245. # plot_charge_3d(ex)
  246. # new_coeffs = expansion_rotation(Quaternion(np.arange(20).reshape(5, 4)).normalized, ex.coeffs, ex.l_array)
  247. # print(new_coeffs.shape)
  248. #
  249. # newnew_coeffs = expansion_rotation(Quaternion(np.arange(16).reshape(4, 4)).normalized, new_coeffs, ex.l_array)
  250. # print(newnew_coeffs.shape)
  251. # rot_angles = np.linspace(0, np.pi, 1000)
  252. # t0 = time.time()
  253. # ex.rotate_euler(0, np.pi / 2, -1)
  254. # t1 = time.time()
  255. # print(ex.coefs)
  256. # print(t1 - t0)