expansion.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import numpy as np
  2. from dataclasses import dataclass
  3. from functools import cached_property
  4. import functions as fn
  5. Array = np.ndarray
  6. class InvalidExpansion(Exception):
  7. pass
  8. @dataclass
  9. class Expansion:
  10. """Generic class for storing surface charge expansion coefficients."""
  11. l_array: Array
  12. coeffs: Array
  13. def __post_init__(self):
  14. """Validation of the given expansion."""
  15. if len(self.coeffs) != np.sum(2 * self.l_array + 1):
  16. raise InvalidExpansion('Number of expansion coefficients does not match the provided l_array.')
  17. if np.all(np.sort(self.l_array) != self.l_array) or np.all(np.unique(self.l_array) != self.l_array):
  18. raise InvalidExpansion('Array of l values should be unique and sorted.')
  19. @property
  20. def max_l(self) -> int:
  21. return max(self.l_array)
  22. @cached_property
  23. def lm_arrays(self) -> (Array, Array):
  24. """Return l and m arrays containing all (l, m) pairs."""
  25. all_m_list = []
  26. for l in self.l_array:
  27. for i in range(2 * l + 1):
  28. all_m_list.append(-l + i)
  29. return np.repeat(self.l_array, 2 * self.l_array + 1), np.array(all_m_list)
  30. def repeat_over_m(self, arr: Array, axis=0):
  31. if not arr.shape[axis] == len(self.l_array):
  32. raise ValueError('Array length should be equal to the number of l in the expansion.')
  33. return np.repeat(arr, 2 * self.l_array + 1, axis=axis)
  34. def rot_sym_expansion(l_array: Array, coeffs: Array) -> Array:
  35. """Create full expansion array for rotationally symmetric distributions with only m=0 terms different form 0."""
  36. full_coeffs = np.zeros(np.sum(2 * l_array + 1))
  37. full_coeffs[np.cumsum(2 * l_array + 1) - l_array - 1] = coeffs
  38. return full_coeffs
  39. class Expansion24(Expansion):
  40. def __init__(self, sigma2: float, sigma4: float, sigma0: float = 0.):
  41. l_array = np.array([0, 2, 4])
  42. coeffs = rot_sym_expansion(l_array, np.array([sigma0, sigma2, sigma4]))
  43. super().__init__(l_array, coeffs)
  44. class MappedExpansion(Expansion):
  45. def __init__(self, a_bar: float, kappaR: float, sigma_m: float, max_l: int = 20, sigma0: float = 0):
  46. l_array = np.array([l for l in range(max_l + 1) if l % 2 == 0])
  47. coeffs = (2 * sigma_m * fn.coeff_C_diff(l_array, kappaR)
  48. * np.sqrt(4 * np.pi * (2 * l_array + 1)) * np.power(a_bar, l_array))
  49. coeffs[0] = sigma0
  50. coeffs = rot_sym_expansion(l_array, coeffs)
  51. super().__init__(l_array, coeffs)
  52. if __name__ == '__main__':
  53. ex = MappedExpansion(0.44, 3, 1, 10)
  54. print(ex.coeffs)