expansion.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. from __future__ import annotations
  2. import numpy as np
  3. from dataclasses import dataclass
  4. import charged_shells.functions as fn
  5. import quaternionic
  6. import spherical
  7. import copy
# Type aliases used in the annotations of this module.
Array = np.ndarray  # plain numpy arrays (coefficients, ell/m values, angles, ...)
Quaternion = quaternionic.array  # quaternionic arrays (last dimension of size 4) describing rotations
  10. class InvalidExpansion(Exception):
  11. pass
  12. @dataclass
  13. class Expansion:
  14. """Generic class for storing surface charge expansion coefficients."""
  15. l_array: Array
  16. coefs: Array
  17. _starting_coefs: Array = None # initialized with the __post_init__ method
  18. _rotations: Quaternion = None
  19. def __post_init__(self):
  20. if self.coefs.shape[-1] != np.sum(2 * self.l_array + 1):
  21. raise InvalidExpansion('Number of expansion coefficients does not match the provided l_array.')
  22. if np.all(np.sort(self.l_array) != self.l_array) or np.all(np.unique(self.l_array) != self.l_array):
  23. raise InvalidExpansion('Array of l values should be unique and sorted.')
  24. self.coefs = self.coefs.astype(np.complex128)
  25. self._starting_coefs = np.copy(self.coefs)
  26. self._rotations = Quaternion([1., 0., 0., 0.])
  27. def __getitem__(self, item):
  28. return Expansion(self.l_array, self.coefs[item])
  29. @property
  30. def max_l(self) -> int:
  31. return max(self.l_array)
  32. @property
  33. def shape(self):
  34. return self.coefs.shape[:-1]
  35. def flatten(self) -> Expansion:
  36. new_expansion = self.clone() # np.ndarray.flatten() also copies the array
  37. new_expansion.coefs = new_expansion.coefs.reshape(-1, new_expansion.coefs.shape[-1])
  38. new_expansion._rotations = new_expansion._rotations.reshape(-1, 4)
  39. return new_expansion
  40. def reshape(self, shape: tuple):
  41. self.coefs = self.coefs.reshape(shape + (self.coefs.shape[-1],))
  42. self._rotations = self._rotations.reshape(shape + (4,))
  43. @property
  44. def lm_arrays(self) -> (Array, Array):
  45. """Return l and m arrays containing all (l, m) pairs."""
  46. return full_lm_arrays(self.l_array)
  47. def repeat_over_m(self, arr: Array, axis=0) -> Array:
  48. if not arr.shape[axis] == len(self.l_array):
  49. raise ValueError('Array length should be equal to the number of l in the expansion.')
  50. return np.repeat(arr, 2 * self.l_array + 1, axis=axis)
  51. def rotate(self, rotations: Quaternion, rotate_existing=False):
  52. if rotate_existing:
  53. raise NotImplementedError("Rotation of possibly already rotated coefficients is not yet supported.")
  54. self._rotations = rotations
  55. self.coefs = expansion_rotation(rotations, self._starting_coefs, self.l_array)
  56. def rotate_euler(self, alpha: Array = 0, beta: Array = 0, gamma: Array = 0, rotate_existing=False):
  57. R_euler = quaternionic.array.from_euler_angles(alpha, beta, gamma)
  58. self.rotate(R_euler, rotate_existing=rotate_existing)
  59. def inverse_sign(self, exclude_00: bool = False):
  60. if self.l_array[0] == 0 and exclude_00:
  61. self.coefs[..., 1:] = -self.coefs[..., 1:]
  62. self._starting_coefs[..., 1:] = -self._starting_coefs[..., 1:]
  63. return self
  64. self.coefs = -self.coefs
  65. self._starting_coefs = -self._starting_coefs
  66. return self
  67. def charge_value(self, theta: Array | float, phi: Array | float):
  68. if not isinstance(theta, Array):
  69. theta = np.array([theta])
  70. if not isinstance(phi, Array):
  71. phi = np.array([phi])
  72. theta, phi = np.broadcast_arrays(theta, phi)
  73. full_l_array, full_m_array = self.lm_arrays
  74. return np.squeeze(np.real(np.sum(self.coefs[..., None] * fn.sph_harm(full_l_array[:, None],
  75. full_m_array[:, None],
  76. theta[None, :], phi[None, :]), axis=-2)))
  77. def clone(self) -> Expansion:
  78. return copy.deepcopy(self)
  79. def m_indices_at_l(l_arr: Array, l_idx: int):
  80. """
  81. For a given l_array and index l_idx for some ell in this array, get indices of all (ell, m) coefficients
  82. in coefficients array.
  83. """
  84. return np.arange(np.sum(2 * l_arr[:l_idx] + 1), np.sum(2 * l_arr[:l_idx+1] + 1))
  85. def full_lm_arrays(l_array: Array) -> (Array, Array):
  86. """From an array of l_values get arrays of ell and m that give you all pairs (ell, m)."""
  87. all_m_list = []
  88. for l in l_array:
  89. for i in range(2 * l + 1):
  90. all_m_list.append(-l + i)
  91. return np.repeat(l_array, 2 * l_array + 1), np.array(all_m_list)
  92. def coefs_fill_missing_l(expansion: Expansion, target_l_array: Array) -> Expansion:
  93. """Explicitly add missing expansion coefficients so that expansion includes all ell values from the target array."""
  94. missing_l = np.setdiff1d(target_l_array, expansion.l_array, assume_unique=True)
  95. fill = np.zeros(np.sum(2 * missing_l + 1))
  96. full_l_array1, _ = expansion.lm_arrays
  97. # we search for where to place missing coefs with the help of a boolean array and argmax function
  98. comparison_bool = (full_l_array1[:, None] - missing_l[None, :]) > 0
  99. indices = np.where(np.any(comparison_bool, axis=0), np.argmax(comparison_bool, axis=0), full_l_array1.shape[0])
  100. new_coefs = np.insert(expansion.coefs, np.repeat(indices, 2 * missing_l + 1), fill, axis=-1)
  101. return Expansion(target_l_array, new_coefs)
  102. def expansions_to_common_l(ex1: Expansion, ex2: Expansion) -> (Expansion, Expansion):
  103. """Explicitly add zero expansion coefficients so that both expansions include coefficients for the same ell."""
  104. common_l_array = np.union1d(ex1.l_array, ex2.l_array)
  105. return coefs_fill_missing_l(ex1, common_l_array), coefs_fill_missing_l(ex2, common_l_array)
  106. def expansion_rotation(rotations: Quaternion, coefs: Array, l_array: Array):
  107. """
  108. General function for rotations of expansion coefficients using WignerD matrices. Combines all rotations
  109. with each expansion given in coefs array.
  110. :param rotations: Quaternion array, last dimension is 4
  111. :param coefs: array of expansion coefficients
  112. :param l_array: array of all ell values of the expansion
  113. :return rotated coefficients, output shape is rotations.shape[:-1] + coefs.shape
  114. """
  115. rot_arrays = rotations.ndarray.reshape((-1, 4))
  116. coefs_reshaped = coefs.reshape((-1, coefs.shape[-1]))
  117. wigner_matrices = spherical.Wigner(np.max(l_array)).D(rot_arrays)
  118. new_coefs = np.zeros((rot_arrays.shape[0],) + coefs_reshaped.shape, dtype=np.complex128)
  119. for i, l in enumerate(l_array):
  120. Dlmn_slice = np.arange(l * (2 * l - 1) * (2 * l + 1) / 3, (l + 1) * (2 * l + 1) * (2 * l + 3) / 3).astype(int)
  121. all_m_indices = m_indices_at_l(l_array, i)
  122. wm = wigner_matrices[:, Dlmn_slice].reshape((-1, 2*l+1, 2*l+1))
  123. new_coefs[..., all_m_indices] = np.einsum('rnm, qm -> rqn',
  124. wm, coefs_reshaped[:, all_m_indices])
  125. return new_coefs.reshape(rotations.ndarray.shape[:-1] + coefs.shape)