expansion.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. from __future__ import annotations
  2. import numpy as np
  3. from dataclasses import dataclass
  4. from functools import cached_property
  5. import functions as fn
  6. import quaternionic
  7. import spherical
  8. import time
  9. import copy
  10. import matplotlib.pyplot as plt
  11. from typing import Callable, TypeVar
  12. from scipy.special import eval_legendre
  13. import plotly.graph_objects as go
  14. from pathlib import Path
  15. Array = np.ndarray
  16. Quaternion = quaternionic.array
  17. T = TypeVar('T')
  18. V = TypeVar('V')
  19. class InvalidExpansion(Exception):
  20. pass
  21. @dataclass
  22. class Expansion:
  23. """Generic class for storing surface charge expansion coefficients."""
  24. l_array: Array
  25. coefs: Array
  26. _starting_coefs: Array = None # initialized with the __post_init__ method
  27. _rotations: Quaternion = Quaternion([1., 0., 0., 0.])
  28. def __post_init__(self):
  29. if self.coefs.shape[-1] != np.sum(2 * self.l_array + 1):
  30. raise InvalidExpansion('Number of expansion coefficients does not match the provided l_array.')
  31. if np.all(np.sort(self.l_array) != self.l_array) or np.all(np.unique(self.l_array) != self.l_array):
  32. raise InvalidExpansion('Array of l values should be unique and sorted.')
  33. self.coefs = self.coefs.astype(np.complex128)
  34. self._starting_coefs = np.copy(self.coefs)
  35. def __getitem__(self, item):
  36. return Expansion(self.l_array, self.coefs[item])
  37. @property
  38. def max_l(self) -> int:
  39. return max(self.l_array)
  40. @property
  41. def shape(self):
  42. return self.coefs.shape[:-1]
  43. def flatten(self) -> Expansion:
  44. new_expansion = self.clone() # np.ndarray.flatten() also copies the array
  45. new_expansion.coefs = new_expansion.coefs.reshape(-1, new_expansion.coefs.shape[-1])
  46. new_expansion._rotations = new_expansion._rotations.reshape(-1, 4)
  47. return new_expansion
  48. def reshape(self, shape: tuple):
  49. self.coefs = self.coefs.reshape(shape + (self.coefs.shape[-1],))
  50. self._rotations = self._rotations.reshape(shape + (4,))
  51. @cached_property
  52. def lm_arrays(self) -> (Array, Array):
  53. """Return l and m arrays containing all (l, m) pairs."""
  54. return full_fm_arrays(self.l_array)
  55. def repeat_over_m(self, arr: Array, axis=0) -> Array:
  56. if not arr.shape[axis] == len(self.l_array):
  57. raise ValueError('Array length should be equal to the number of l in the expansion.')
  58. return np.repeat(arr, 2 * self.l_array + 1, axis=axis)
  59. def rotate(self, rotations: Quaternion, rotate_existing=False):
  60. # TODO: rotations are currently saved wrong if we start form existing coefficients not the og ones
  61. self._rotations = rotations
  62. coefs = self.coefs if rotate_existing else self._starting_coefs
  63. self.coefs = expansion_rotation(rotations, coefs, self.l_array)
  64. def rotate_euler(self, alpha: Array, beta: Array, gamma: Array, rotate_existing=False):
  65. # TODO: additional care required on the convention used to transform euler angles to quaternions
  66. # TODO: might be off for a minus sign for each? angle !!
  67. R_euler = quaternionic.array.from_euler_angles(alpha, beta, gamma)
  68. self.rotate(R_euler, rotate_existing=rotate_existing)
  69. def charge_value(self, theta: Array | float, phi: Array | float):
  70. if not isinstance(theta, Array):
  71. theta = np.array([theta])
  72. if not isinstance(phi, Array):
  73. phi = np.array([phi])
  74. theta, phi = np.broadcast_arrays(theta, phi)
  75. full_l_array, full_m_array = self.lm_arrays
  76. return np.squeeze(np.real(np.sum(self.coefs[..., None] * fn.sph_harm(full_l_array[:, None],
  77. full_m_array[:, None],
  78. theta[None, :], phi[None, :]), axis=-2)))
  79. def clone(self) -> Expansion:
  80. return copy.deepcopy(self)
  81. class Expansion24(Expansion):
  82. def __init__(self, sigma2: float, sigma4: float, sigma0: float = 0.):
  83. l_array = np.array([0, 2, 4])
  84. coefs = rot_sym_expansion(l_array, np.array([sigma0, sigma2, sigma4]))
  85. super().__init__(l_array, coefs)
  86. class MappedExpansionQuad(Expansion):
  87. def __init__(self, a_bar: Array | float, kappaR: Array | float, sigma_m: float, l_max: int = 20, sigma0: float = 0):
  88. a_bar, kappaR = np.broadcast_arrays(a_bar, kappaR)
  89. l_array = np.array([l for l in range(l_max + 1) if l % 2 == 0])
  90. a_bar, kappaR, l_array_expanded = np.broadcast_arrays(a_bar[..., None], kappaR[..., None], l_array[None, :])
  91. coefs = (2 * sigma_m * fn.coef_C_diff(l_array_expanded, kappaR)
  92. * np.sqrt(4 * np.pi * (2 * l_array_expanded + 1)) * np.power(a_bar, l_array_expanded))
  93. coefs[..., 0] = sigma0 / np.sqrt(4 * np.pi)
  94. coefs = np.squeeze(rot_sym_expansion(l_array, coefs))
  95. super().__init__(l_array, coefs)
  96. class GaussianCharges(Expansion):
  97. def __init__(self, omega_k: Array, lambda_k: Array | float, sigma1: float, l_max: int,
  98. sigma0: float = 0, equal_charges: bool = True):
  99. omega_k = omega_k.reshape(-1, 2)
  100. if not isinstance(lambda_k, Array):
  101. lambda_k = np.array([lambda_k])
  102. if equal_charges:
  103. if lambda_k.ndim > 1:
  104. raise ValueError(f'If equal_charges=True, lambda_k should be a 1D array, got shape {lambda_k.shape}')
  105. lambda_k = np.full((omega_k.shape[0], lambda_k.shape[0]), lambda_k).T
  106. if lambda_k.shape[-1] != omega_k.shape[0]:
  107. raise ValueError("Number of charges (length of omega_k) should match the last dimension of lambda_k array.")
  108. lambda_k = lambda_k.reshape(-1, omega_k.shape[0])
  109. l_array = np.arange(l_max + 1)
  110. full_l_array, full_m_array = full_fm_arrays(l_array)
  111. theta_k = omega_k[:, 0]
  112. phi_k = omega_k[:, 1]
  113. summands = (lambda_k[:, None, :] / np.sinh(lambda_k[:, None, :])
  114. * fn.sph_bessel_i(full_l_array[None, :, None], lambda_k[:, None, :])
  115. * np.conj(fn.sph_harm(full_l_array[None, :, None], full_m_array[None, :, None],
  116. theta_k[None, None, :], phi_k[None, None, :])))
  117. coefs = np.squeeze(4 * np.pi * sigma1 * np.sum(summands, axis=-1))
  118. coefs[..., 0] = sigma0 / np.sqrt(4 * np.pi)
  119. l_array, coefs = purge_unneeded_l(l_array, coefs)
  120. super().__init__(l_array, coefs)
  121. class SphericalCap(Expansion):
  122. def __init__(self, omega_k: Array, theta0_k: Array | float, sigma1: float, l_max: int, sigma0: float = 0,
  123. equal_sizes: bool = True):
  124. omega_k = omega_k.reshape(-1, 2)
  125. if not isinstance(theta0_k, Array):
  126. theta0_k = np.array(theta0_k)
  127. if equal_sizes:
  128. if theta0_k.ndim == 0:
  129. theta0_k = np.full(omega_k.shape[0], theta0_k)
  130. elif theta0_k.ndim == 1:
  131. theta0_k = np.full((omega_k.shape[0], theta0_k.shape[0]), theta0_k)
  132. else:
  133. raise ValueError(f'If equal_charges=True, theta0_k should be a 1D array, got shape {theta0_k.shape}')
  134. if theta0_k.shape[0] != omega_k.shape[0]:
  135. raise ValueError("Number of charges (length of omega_k) should match the last dimension of theta0_k array.")
  136. rotations = Quaternion(np.stack((np.cos(omega_k[..., 0] / 2),
  137. np.sin(omega_k[..., 1]) * np.sin(omega_k[..., 0] / 2),
  138. np.cos(omega_k[..., 1]) * np.sin(omega_k[..., 0] / 2),
  139. np.zeros_like(omega_k[..., 0]))).T)
  140. l_array = np.arange(l_max + 1)
  141. coefs_l0 = -sigma1 * (np.sqrt(np.pi / (2 * l_array[None, :] + 1)) *
  142. (eval_legendre(l_array[None, :] + 1, np.cos(theta0_k[..., None]))
  143. - eval_legendre(l_array[None, :] - 1, np.cos(theta0_k[..., None]))))
  144. coefs = rot_sym_expansion(l_array, coefs_l0)
  145. coefs_all_single_caps = expansion_rotation(rotations, coefs, l_array)
  146. # Rotating is implemented in such a way that it rotates every patch to every position,
  147. # hence the redundancy of out of diagonal elements.
  148. coefs_all = np.sum(np.diagonal(coefs_all_single_caps), axis=-1)
  149. if sigma0 is not None:
  150. coefs_all[..., 0] = sigma0 / np.sqrt(4 * np.pi)
  151. super().__init__(l_array, np.squeeze(coefs_all))
  152. def map_over_expansion(f: Callable[[Expansion, T], V]) -> Callable[[Expansion, T], V]:
  153. """Map a function f over the leading axes of an expansion. Uses for loops, so it is kinda slow."""
  154. def mapped_f(ex: Expansion, *args, **kwargs):
  155. og_shape = ex.shape
  156. flat_ex = ex.flatten()
  157. results = []
  158. for i in range(int(np.prod(og_shape))):
  159. results.append(f(flat_ex[i], *args, **kwargs))
  160. try:
  161. return np.array(results).reshape(og_shape + results[0].shape)
  162. except AttributeError:
  163. return np.array(results).reshape(og_shape)
  164. return mapped_f
  165. def full_fm_arrays(l_array: Array) -> (Array, Array):
  166. all_m_list = []
  167. for l in l_array:
  168. for i in range(2 * l + 1):
  169. all_m_list.append(-l + i)
  170. return np.repeat(l_array, 2 * l_array + 1), np.array(all_m_list)
  171. def rot_sym_expansion(l_array: Array, coefs: Array) -> Array:
  172. """Create full expansion array for rotationally symmetric distributions with only m=0 terms different form 0."""
  173. full_coefs = np.zeros(coefs.shape[:-1] + (np.sum(2 * l_array + 1),))
  174. full_coefs[..., np.cumsum(2 * l_array + 1) - l_array - 1] = coefs
  175. return full_coefs
  176. def coefs_fill_missing_l(expansion: Expansion, target_l_array: Array) -> Expansion:
  177. missing_l = np.setdiff1d(target_l_array, expansion.l_array, assume_unique=True)
  178. fill = np.zeros(np.sum(2 * missing_l + 1))
  179. full_l_array1, _ = expansion.lm_arrays
  180. # we search for where to place missing coefs with the help of a boolean array and argmax function
  181. comparison_bool = (full_l_array1[:, None] - missing_l[None, :]) > 0
  182. indices = np.where(np.any(comparison_bool, axis=0), np.argmax(comparison_bool, axis=0), full_l_array1.shape[0])
  183. new_coefs = np.insert(expansion.coefs, np.repeat(indices, 2 * missing_l + 1), fill, axis=-1)
  184. return Expansion(target_l_array, new_coefs)
  185. def m_indices_at_l(l_arr: Array, l_idx: int):
  186. return np.arange(np.sum(2 * l_arr[:l_idx] + 1), np.sum(2 * l_arr[:l_idx+1] + 1))
  187. def purge_unneeded_l(l_array: Array, coefs: Array) -> (Array, Array):
  188. def delete_zero_entries(l, l_arr, cfs):
  189. l_idx = np.where(l_arr == l)[0][0]
  190. m_indices = m_indices_at_l(l_arr, l_idx)
  191. if np.all(cfs[..., m_indices] == 0):
  192. return np.delete(l_arr, l_idx), np.delete(cfs, m_indices, axis=-1)
  193. return l_arr, cfs
  194. for l in l_array:
  195. l_array, coefs = delete_zero_entries(l, l_array, coefs)
  196. return l_array, coefs
  197. def plot_theta_profile(ex: Expansion, phi: float = 0, num: int = 100, theta_start: float = 0, theta_end: float = np.pi):
  198. theta_vals = np.linspace(theta_start, theta_end, num)
  199. charge = ex.charge_value(theta_vals, phi)
  200. plt.plot(theta_vals, charge.T)
  201. plt.show()
  202. def plot_theta_profile_multiple(ex_list: list[Expansion], label_list, phi: float = 0, num: int = 100, theta_start: float = 0, theta_end: float = np.pi):
  203. theta_vals = np.linspace(theta_start, theta_end, num)
  204. fig, ax = plt.subplots()
  205. for ex, label in zip(ex_list, label_list):
  206. ax.plot(theta_vals, ex.charge_value(theta_vals, phi).T, label=label)
  207. ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
  208. ax.set_xlabel(r'$\theta$', fontsize=13)
  209. ax.set_ylabel(r'$\sigma$', fontsize=13)
  210. plt.legend(fontsize=12)
  211. plt.tight_layout()
  212. plt.savefig(Path("/home/andraz/ChargedShells/Figures/charge_shape_comparison.png"), dpi=600)
  213. plt.show()
  214. def plot_charge_3d(ex: Expansion, num_theta=100, num_phi=100):
  215. theta = np.linspace(0, np.pi, num_theta)
  216. phi = np.linspace(0, 2 * np.pi, num_phi)
  217. theta, phi = np.meshgrid(theta, phi)
  218. r = ex.charge_value(theta.flatten(), phi.flatten()).reshape(theta.shape)
  219. # Convert spherical coordinates to Cartesian coordinates
  220. x = np.sin(theta) * np.cos(phi)
  221. y = np.sin(theta) * np.sin(phi)
  222. z = np.cos(theta)
  223. # Create a heatmap on the sphere
  224. fig = go.Figure(data=go.Surface(x=x, y=y, z=z, surfacecolor=r, colorscale='Jet'))
  225. fig.update_layout(scene=dict(aspectmode='data'))
  226. fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'))
  227. fig.show()
  228. def expansions_to_common_l(ex1: Expansion, ex2: Expansion) -> (Expansion, Expansion):
  229. common_l_array = np.union1d(ex1.l_array, ex2.l_array)
  230. return coefs_fill_missing_l(ex1, common_l_array), coefs_fill_missing_l(ex2, common_l_array)
  231. def expansion_rotation(rotations: Quaternion, coefs: Array, l_array: Array):
  232. """
  233. General function for rotations of expansion coefficients using WignerD matrices. Combines all rotations
  234. with each expansion given in coefs array.
  235. :param rotations: Quaternion array, last dimension is 4
  236. :param coefs: array of expansion coefficients
  237. :param l_array: array of all ell values of the expansion
  238. :return rotated coefficients, output shape is rotations.shape[:-1] + coefs.shape
  239. """
  240. rot_arrays = rotations.ndarray.reshape((-1, 4))
  241. coefs_reshaped = coefs.reshape((-1, coefs.shape[-1]))
  242. wigner_matrices = spherical.Wigner(np.max(l_array)).D(rot_arrays)
  243. new_coefs = np.zeros((rot_arrays.shape[0],) + coefs_reshaped.shape, dtype=np.complex128)
  244. for i, l in enumerate(l_array):
  245. Dlmn_slice = np.arange(l * (2 * l - 1) * (2 * l + 1) / 3, (l + 1) * (2 * l + 1) * (2 * l + 3) / 3).astype(int)
  246. all_m_indices = m_indices_at_l(l_array, i)
  247. wm = wigner_matrices[:, Dlmn_slice].reshape((-1, 2*l+1, 2*l+1))
  248. new_coefs[..., all_m_indices] = np.einsum('rnm, qm -> rqn',
  249. wm, coefs_reshaped[:, all_m_indices])
  250. return new_coefs.reshape(rotations.ndarray.shape[:-1] + coefs.shape)
  251. if __name__ == '__main__':
  252. # ex = MappedExpansionQuad(0.44, 3, 1, 10)
  253. # ex = Expansion(np.arange(3), np.array([1, -1, 0, 1, 2, 0, 3, 0, 2]))
  254. # ex = GaussianCharges(omega_k=np.array([[0, 0], [np.pi, 0]]), lambda_k=10, sigma1=0.001, l_max=10)
  255. ex = SphericalCap(np.array([[0, 0], [np.pi, 0]]), 0.5, 0.1, 70)
  256. # print(np.real(ex.coefs))
  257. # plot_theta_profile(ex, num=1000, theta_end=2 * np.pi, phi=0)
  258. plot_charge_3d(ex)
  259. # new_coeffs = expansion_rotation(Quaternion(np.arange(20).reshape(5, 4)).normalized, ex.coeffs, ex.l_array)
  260. # print(new_coeffs.shape)
  261. #
  262. # newnew_coeffs = expansion_rotation(Quaternion(np.arange(16).reshape(4, 4)).normalized, new_coeffs, ex.l_array)
  263. # print(newnew_coeffs.shape)
  264. # rot_angles = np.linspace(0, np.pi, 1000)
  265. # t0 = time.time()
  266. # ex.rotate_euler(0, np.pi / 2, -1)
  267. # t1 = time.time()
  268. # print(ex.coefs)
  269. # print(t1 - t0)