gnidovec 1 год назад
Родитель
Сommit
584526fb05
8 измененных файлов с 345 добавлено и 100 удалено
  1. 98 32
      expansion.py
  2. 2 2
      functions.py
  3. 2 2
      interactions.py
  4. 18 5
      parameters.py
  5. 55 0
      patch_size.py
  6. 58 0
      path_plot.py
  7. 5 6
      potentials.py
  8. 107 53
      rotational_path.py

+ 98 - 32
expansion.py

@@ -8,10 +8,13 @@ import spherical
 import time
 import copy
 import matplotlib.pyplot as plt
+from typing import Callable, TypeVar
 
 
 Array = np.ndarray
 Quaternion = quaternionic.array
+T = TypeVar('T')
+V = TypeVar('V')
 
 
 class InvalidExpansion(Exception):
@@ -35,10 +38,27 @@ class Expansion:
         self.coefs = self.coefs.astype(np.complex128)
         self._starting_coefs = np.copy(self.coefs)
 
+    def __getitem__(self, item):
+        return Expansion(self.l_array, self.coefs[item])
+
     @property
     def max_l(self) -> int:
         return max(self.l_array)
 
+    @property
+    def shape(self):
+        return self.coefs.shape[:-1]
+
+    def flatten(self) -> Expansion:
+        new_expansion = self.clone()
+        new_expansion.coefs = new_expansion.coefs.reshape(-1, new_expansion.coefs.shape[-1])
+        new_expansion._rotations = new_expansion._rotations.reshape(-1, 4)
+        return new_expansion
+
+    def reshape(self, shape):
+        self.coefs = self.coefs.reshape(shape + (self.coefs.shape[-1],))
+        self._rotations = self._rotations.reshape(shape + (4),)
+
     @cached_property
     def lm_arrays(self) -> (Array, Array):
         """Return l and m arrays containing all (l, m) pairs."""
@@ -50,6 +70,7 @@ class Expansion:
         return np.repeat(arr, 2 * self.l_array + 1, axis=axis)
 
     def rotate(self, rotations: Quaternion, rotate_existing=False):
+        # TODO: rotations are currently saved wrong if we start form existing coefficients not the og ones
         self._rotations = rotations
         coefs = self.coefs if rotate_existing else self._starting_coefs
         self.coefs = expansion_rotation(rotations, coefs, self.l_array)
@@ -67,8 +88,9 @@ class Expansion:
             phi = np.array([phi])
         theta, phi = np.broadcast_arrays(theta, phi)
         full_l_array, full_m_array = self.lm_arrays
-        return np.sum(self.coefs[None, :] * fn.sph_harm(full_l_array[None, :], full_m_array[None, :],
-                                                        theta[:, None], phi[:, None]), axis=1)
+        return np.squeeze(np.real(np.sum(self.coefs[..., None] * fn.sph_harm(full_l_array[:, None],
+                                                                             full_m_array[:, None],
+                                                                             theta[None, :], phi[None, :]), axis=-2)))
 
     def clone(self) -> Expansion:
         return copy.deepcopy(self)
@@ -78,42 +100,70 @@ class Expansion24(Expansion):
 
     def __init__(self, sigma2: float, sigma4: float, sigma0: float = 0.):
         l_array = np.array([0, 2, 4])
-        coeffs = rot_sym_expansion(l_array, np.array([sigma0, sigma2, sigma4]))
-        super().__init__(l_array, coeffs)
+        coefs = rot_sym_expansion(l_array, np.array([sigma0, sigma2, sigma4]))
+        super().__init__(l_array, coefs)
 
 
 class MappedExpansionQuad(Expansion):
 
-    def __init__(self, a_bar: float, kappaR: float, sigma_m: float, max_l: int = 20, sigma0: float = 0):
-        l_array = np.array([l for l in range(max_l + 1) if l % 2 == 0])
-        coeffs = (2 * sigma_m * fn.coeff_C_diff(l_array, kappaR)
-                  * np.sqrt(4 * np.pi * (2 * l_array + 1)) * np.power(a_bar, l_array))
-        coeffs[0] = sigma0
-        coeffs = rot_sym_expansion(l_array, coeffs)
-        super().__init__(l_array, coeffs)
+    def __init__(self, a_bar: Array | float, kappaR: Array | float, sigma_m: float, l_max: int = 20, sigma0: float = 0):
+        a_bar, kappaR = np.broadcast_arrays(a_bar, kappaR)
+
+        l_array = np.array([l for l in range(l_max + 1) if l % 2 == 0])
+        a_bar, kappaR, l_array_expanded = np.broadcast_arrays(a_bar[..., None], kappaR[..., None], l_array[None, :])
+
+        coefs = (2 * sigma_m * fn.coef_C_diff(l_array_expanded, kappaR)
+                  * np.sqrt(4 * np.pi * (2 * l_array_expanded + 1)) * np.power(a_bar, l_array_expanded))
+        coefs[..., 0] = sigma0
+        coefs = rot_sym_expansion(l_array, coefs)
+        super().__init__(l_array, coefs)
 
 
-class SmearedCharges(Expansion):
+class GaussianCharges(Expansion):
 
-    def __init__(self, omega_k: Array, lambda_k: Array | float, sigma1: float, l_max: int, sigma0: float = 0):
+    def __init__(self, omega_k: Array, lambda_k: Array | float, sigma1: float, l_max: int,
+                 sigma0: float = 0, equal_charges=True):
         omega_k = omega_k.reshape(-1, 2)
         if not isinstance(lambda_k, Array):
-            lambda_k = np.full((len(omega_k),), lambda_k)
+            lambda_k = np.array([lambda_k])
+        if equal_charges:
+            if lambda_k.ndim > 1:
+                raise ValueError(f'If equal_charges=True, lambda_k should be a 1D array, got shape {lambda_k.shape}')
+            lambda_k = np.full((omega_k.shape[0], lambda_k.shape[0]), lambda_k).T
         if lambda_k.shape[-1] != omega_k.shape[0]:
-            raise ValueError("Omega and lambda arrays should have the same length.")
+            raise ValueError("Number of charges (length of omega_k) should match the last dimension of lambda_k array.")
+        lambda_k = lambda_k.reshape(-1, omega_k.shape[0])
         l_array = np.arange(l_max + 1)
         full_l_array, full_m_array = full_fm_arrays(l_array)
         theta_k = omega_k[:, 0]
         phi_k = omega_k[:, 1]
-        summands = (lambda_k[None, :] / np.sinh(lambda_k[None, :])
-                    * fn.sph_bessel_i(full_l_array[:, None], lambda_k[None, :])
-                    * np.conj(fn.sph_harm(full_l_array[:, None], full_m_array[:, None],
-                                          theta_k[None, :], phi_k[None, :])))
-        coefs = 4 * np.pi * sigma1 * np.sum(summands, axis=1)
-        coefs[0] = sigma0
+        summands = (lambda_k[:, None, :] / np.sinh(lambda_k[:, None, :])
+                    * fn.sph_bessel_i(full_l_array[None, :, None], lambda_k[:, None, :])
+                    * np.conj(fn.sph_harm(full_l_array[None, :, None], full_m_array[None, :, None],
+                                          theta_k[None, None, :], phi_k[None, None, :])))
+        coefs = np.squeeze(4 * np.pi * sigma1 * np.sum(summands, axis=-1))
+        coefs[..., 0] = sigma0
+        l_array, coefs = purge_unneeded_l(l_array, coefs)
         super().__init__(l_array, coefs)
 
 
+def map_over_expansion(f: Callable[[Expansion, T], V]) -> Callable[[Expansion, T], V]:
+    """Map a function f over the leading axes of an expansion."""
+
+    def mapped_f(ex: Expansion, *args, **kwargs):
+        og_shape = ex.shape
+        flat_ex = ex.flatten()
+        results = []
+        for i in range(np.prod(og_shape)):
+            results.append(f(flat_ex[i], *args, **kwargs))
+        try:
+            return np.array(results).reshape(og_shape + results[0].shape)
+        except AttributeError:
+            return np.array(results).reshape(og_shape)
+
+    return mapped_f
+
+
 def full_fm_arrays(l_array: Array) -> (Array, Array):
     all_m_list = []
     for l in l_array:
@@ -122,22 +172,38 @@ def full_fm_arrays(l_array: Array) -> (Array, Array):
     return np.repeat(l_array, 2 * l_array + 1), np.array(all_m_list)
 
 
-def rot_sym_expansion(l_array: Array, coeffs: Array) -> Array:
+def rot_sym_expansion(l_array: Array, coefs: Array) -> Array:
     """Create full expansion array for rotationally symmetric distributions with only m=0 terms different form 0."""
-    full_coeffs = np.zeros(np.sum(2 * l_array + 1))
-    full_coeffs[np.cumsum(2 * l_array + 1) - l_array - 1] = coeffs
-    return full_coeffs
+    full_coefs = np.zeros(coefs.shape[:-1] + (np.sum(2 * l_array + 1),))
+    full_coefs[..., np.cumsum(2 * l_array + 1) - l_array - 1] = coefs
+    return full_coefs
 
 
-def coeffs_fill_missing_l(expansion: Expansion, target_l_array: Array) -> Expansion:
+def coefs_fill_missing_l(expansion: Expansion, target_l_array: Array) -> Expansion:
     missing_l = np.setdiff1d(target_l_array, expansion.l_array, assume_unique=True)
     fill = np.zeros(np.sum(2 * missing_l + 1))
     full_l_array1, _ = expansion.lm_arrays
-    # we search for where to place missing coeffs with the help of a boolean array and argmax function
+    # we search for where to place missing coefs with the help of a boolean array and argmax function
     comparison_bool = (full_l_array1[:, None] - missing_l[None, :]) > 0
     indices = np.where(np.any(comparison_bool, axis=0), np.argmax(comparison_bool, axis=0), full_l_array1.shape[0])
-    new_coeffs = np.insert(expansion.coefs, np.repeat(indices, 2 * missing_l + 1), fill, axis=-1)
-    return Expansion(target_l_array, new_coeffs)
+    new_coefs = np.insert(expansion.coefs, np.repeat(indices, 2 * missing_l + 1), fill, axis=-1)
+    return Expansion(target_l_array, new_coefs)
+
+
+def m_indices_at_l(l_arr: Array, l_idx: int):
+    return np.arange(np.sum(2 * l_arr[:l_idx] + 1), np.sum(2 * l_arr[:l_idx+1] + 1))
+
+
+def purge_unneeded_l(l_array: Array, coefs: Array) -> (Array, Array):
+    def delete_zero_entries(l, l_arr, cfs):
+        l_idx = np.where(l_arr == l)[0][0]
+        m_indices = m_indices_at_l(l_arr, l_idx)
+        if np.all(cfs[..., m_indices] == 0):
+            return np.delete(l_arr, l_idx), np.delete(cfs, m_indices, axis=-1)
+        return l_arr, cfs
+    for l in l_array:
+        l_array, coefs = delete_zero_entries(l, l_array, coefs)
+    return l_array, coefs
 
 
 def plot_theta_profile(ex: Expansion, phi=0, num=100):
@@ -149,7 +215,7 @@ def plot_theta_profile(ex: Expansion, phi=0, num=100):
 
 def expansions_to_common_l(ex1: Expansion, ex2: Expansion) -> (Expansion, Expansion):
     common_l_array = np.union1d(ex1.l_array, ex2.l_array)
-    return coeffs_fill_missing_l(ex1, common_l_array),  coeffs_fill_missing_l(ex2, common_l_array)
+    return coefs_fill_missing_l(ex1, common_l_array),  coefs_fill_missing_l(ex2, common_l_array)
 
 
 def expansion_rotation(rotations: Quaternion, coefs: Array, l_array: Array):
@@ -167,7 +233,7 @@ def expansion_rotation(rotations: Quaternion, coefs: Array, l_array: Array):
     new_coefs = np.zeros((rot_arrays.shape[0],) + coefs_reshaped.shape, dtype=np.complex128)
     for i, l in enumerate(l_array):
         Dlmn_slice = np.arange(l * (2 * l - 1) * (2 * l + 1) / 3, (l + 1) * (2 * l + 1) * (2 * l + 3) / 3).astype(int)
-        all_m_indices = np.arange(np.sum(2 * l_array[:i] + 1), np.sum(2 * l_array[:i + 1] + 1))
+        all_m_indices = m_indices_at_l(l_array, i)
         wm = wigner_matrices[:, Dlmn_slice].reshape((-1, 2*l+1, 2*l+1))
         new_coefs[..., all_m_indices] = np.einsum('rnm, qm -> rqn',
                                                    wm, coefs_reshaped[:, all_m_indices])
@@ -178,7 +244,7 @@ if __name__ == '__main__':
 
     # ex = MappedExpansionQuad(0.44, 3, 1, 10)
     # ex = Expansion(np.arange(3), np.array([1, -1, 0, 1, 2, 0, 3, 0, 2]))
-    ex = SmearedCharges(omega_k=np.array([[0, 0], [np.pi, 0]]), lambda_k=10, sigma1=0.001, l_max=10)
+    ex = GaussianCharges(omega_k=np.array([[0, 0], [np.pi, 0]]), lambda_k=10, sigma1=0.001, l_max=10)
     # print(ex.coefs)
     plot_theta_profile(ex)
 

+ 2 - 2
functions.py

@@ -14,7 +14,7 @@ def sph_harm(l, m, theta, phi, **kwargs):
     return sps.sph_harm(m, l, phi, theta, **kwargs)
 
 
-def interaction_coeff_C(l, p, x):
+def interaction_coef_C(l, p, x):
     return x * sps.iv(l + 1 / 2, x) * sps.iv(p + 1 / 2, x)
 
 
@@ -22,5 +22,5 @@ def coefficient_Cpm(l, x):
     return x * sps.kv(l + 1 / 2, x) * sps.iv(l + 1 / 2, x)
 
 
-def coeff_C_diff(l, x):
+def coef_C_diff(l, x):
     return 1 / (x * sps.iv(l + 1 / 2, x) * sps.kv(l + 3 / 2, x))

+ 2 - 2
interactions.py

@@ -72,7 +72,7 @@ def charged_shell_energy(ex1: Expansion, ex2: Expansion, dist: float, params: Mo
 
     constants = params.R ** 2 / (params.kappa * params.epsilon * uc.CONSTANTS.epsilon0) * energy_units(units, params)
 
-    C_vals = fn.interaction_coeff_C(l_vals, p_vals, params.kappaR)
+    C_vals = fn.interaction_coef_C(l_vals, p_vals, params.kappaR)
     lspm_vals = C_vals * (-1) ** (l_vals + m_vals) * lps_terms * bessel_vals * wigner1 * wigner2
     broadcasted_lspm_vals = np.broadcast_to(lspm_vals, charge_vals.shape)
 
@@ -82,7 +82,7 @@ def charged_shell_energy(ex1: Expansion, ex2: Expansion, dist: float, params: Mo
 if __name__ == '__main__':
 
     params = ModelParams(R=150, kappaR=3)
-    ex1 = expansion.MappedExpansionQuad(0.44, params.kappaR, 0.001, max_l=20)
+    ex1 = expansion.MappedExpansionQuad(0.44, params.kappaR, 0.001, l_max=20)
     ex2 = ex1.clone()
 
     dist = 2.

+ 18 - 5
parameters.py

@@ -1,14 +1,18 @@
+from __future__ import annotations
 from dataclasses import dataclass
 import numpy as np
 import units_and_constants as uc
 
 
-@dataclass
+Array = np.ndarray
+
+
+@dataclass(kw_only=True)
 class ModelParams:
-    R: float
-    kappa: float = None
-    kappaR: float = None
-    c0: float = None
+    R: float | Array
+    kappa: float | Array = None
+    kappaR: float | Array = None
+    c0: float | Array = None
     epsilon: float = 80
     temperature: float = 293
 
@@ -16,6 +20,15 @@ class ModelParams:
         self.kappa, self.kappaR, self.c0 = screening_calculator(self.R, self.temperature, self.epsilon,
                                                                 self.kappa, self.kappaR, self.c0)
 
+    def unravel(self) -> list[ModelParams]:
+        params_list = []
+        all_r = np.array([self.R]) if not isinstance(self.R, Array) else self.R
+        all_kappa = np.array([self.kappa]) if not isinstance(self.kappa, Array) else self.kappa
+        for r in all_r:
+            for kappa in all_kappa:
+                params_list.append(ModelParams(R=r, kappa=kappa))
+        return params_list
+
 
 def bjerrum(temp: float, epsilon: float) -> float:
     return uc.CONSTANTS.e0 ** 2 / (4 * np.pi * epsilon * uc.CONSTANTS.epsilon0 * uc.CONSTANTS.Boltzmann * temp)

+ 55 - 0
patch_size.py

@@ -0,0 +1,55 @@
+import numpy as np
+from scipy.optimize import bisect
+import expansion
+from matplotlib import pyplot as plt
+import parameters
+import potentials
+
+Expansion = expansion.Expansion
+Array = np.ndarray
+ModelParams = parameters.ModelParams
+
+
+@expansion.map_over_expansion
+def charge_patch_size(ex: Expansion, phi: float = 0, theta0: Array | float = 0, theta1: Array | float = np.pi / 2):
+    return bisect(lambda theta: ex.charge_value(theta, phi), theta0, theta1)
+
+
+def potential_patch_size(ex: Expansion, params: ModelParams,
+                         phi: float = 0, theta0: Array | float = 0, theta1: Array | float = np.pi / 2,
+                         match_expansion_axis_to_params: int = None):
+
+    meatp = match_expansion_axis_to_params
+
+    @expansion.map_over_expansion
+    def potential_zero(exp: Expansion, prms: ModelParams):
+        return bisect(lambda theta: potentials.charged_shell_potential(theta, phi, 1, exp, prms), theta0, theta1)
+
+    print(ex.shape)
+
+    params_list = params.unravel()
+    if meatp is not None:
+        expansion_list = [Expansion(ex.l_array, np.take(ex.coefs, i, axis=meatp)) for i in range(ex.shape[meatp])]
+    else:
+        expansion_list = [ex]
+
+    results = []
+    for exp, prms in zip(expansion_list, params_list):
+        results.append(potential_zero(exp, prms))
+
+
+if __name__ == '__main__':
+
+    a_bar = np.linspace(0.2, 0.8, 100)
+    kappaR = np.array([0.26, 1, 3, 10, 26])
+    params = ModelParams(R=150, kappaR=kappaR)
+    ex = expansion.MappedExpansionQuad(a_bar=a_bar[:, None], sigma_m=0.001, l_max=20, kappaR=kappaR[None, :])
+    print(ex.shape)
+
+    patch_size = charge_patch_size(ex)
+    patch_size_pot = potential_patch_size(ex, params, match_expansion_axis_to_params=1)
+
+    plt.plot(a_bar, patch_size * 180 / np.pi, label=kappaR)
+    plt.legend()
+    plt.show()
+

+ 58 - 0
path_plot.py

@@ -0,0 +1,58 @@
+import numpy as np
+from rotational_path import PairRotationalPath, PathEnergyPlot
+import expansion
+from parameters import ModelParams
+from pathlib import Path
+
+zero_to_pi_half = np.linspace(0, np.pi/2, 100, endpoint=True)
+pi_half_to_pi = np.linspace(np.pi/2, np.pi, 100, endpoint=True)
+
+QuadPath = PairRotationalPath()
+QuadPath.set_default_x_axis(zero_to_pi_half)
+QuadPath.add_euler(beta1=np.pi/2, beta2=zero_to_pi_half)
+QuadPath.add_euler(beta1=zero_to_pi_half[::-1], beta2=zero_to_pi_half[::-1])
+QuadPath.add_euler(beta1=zero_to_pi_half)
+QuadPath.add_euler(beta1=zero_to_pi_half[::-1], beta2=zero_to_pi_half)
+QuadPath.add_euler(beta1=np.pi/2, beta2=zero_to_pi_half, alpha2=np.pi/2)
+QuadPath.add_euler(beta1=np.pi/2, beta2=np.pi/2, alpha1=zero_to_pi_half[::-1])
+
+DipolePath = PairRotationalPath()
+DipolePath.set_default_x_axis(zero_to_pi_half)
+DipolePath.add_euler(beta1=pi_half_to_pi[::-1])
+DipolePath.add_euler(beta1=zero_to_pi_half[::-1])
+DipolePath.add_euler(beta1=zero_to_pi_half, beta2=zero_to_pi_half)
+DipolePath.add_euler(beta1=np.pi/2, beta2=np.pi/2, alpha1=zero_to_pi_half)
+DipolePath.add_euler(beta1=np.pi/2, alpha1=np.pi/2, beta2=pi_half_to_pi)
+DipolePath.add_euler(beta1=np.pi/2, beta2=pi_half_to_pi[::-1], alpha2=np.pi)
+DipolePath.add_euler(beta1=zero_to_pi_half[::-1], beta2=pi_half_to_pi, alpha2=np.pi)
+DipolePath.add_euler(beta1=zero_to_pi_half, beta2=pi_half_to_pi[::-1], alpha2=np.pi)
+DipolePath.add_euler(beta1=pi_half_to_pi, beta2=zero_to_pi_half[::-1], alpha2=np.pi)
+
+DipolePath2 = PairRotationalPath()
+DipolePath2.set_default_x_axis(zero_to_pi_half)
+DipolePath2.add_euler(beta1=pi_half_to_pi[::-1])
+DipolePath2.add_euler(beta1=zero_to_pi_half[::-1])
+DipolePath2.add_euler(beta1=zero_to_pi_half, beta2=zero_to_pi_half)
+DipolePath2.add_euler(beta1=np.pi/2, beta2=np.pi/2, alpha1=zero_to_pi_half)
+DipolePath2.add_euler(beta1=np.pi/2, alpha1=np.pi/2, beta2=pi_half_to_pi)
+DipolePath2.add_euler(beta1=zero_to_pi_half[::-1], beta2=pi_half_to_pi[::-1])
+DipolePath2.add_euler(beta1=zero_to_pi_half[::-1], beta2=np.pi)
+DipolePath2.add_euler(beta1=zero_to_pi_half, beta2=pi_half_to_pi[::-1], alpha2=np.pi)
+DipolePath2.add_euler(beta1=pi_half_to_pi, beta2=zero_to_pi_half[::-1], alpha2=np.pi)
+
+
+if __name__ == '__main__':
+
+    params = ModelParams(R=150, kappaR=3)
+    # ex1 = expansion.MappedExpansionQuad(np.array([0.35, 0.44, 0.6]), params.kappaR, 0.001, max_l=20)
+    # ex1 = expansion.Expansion24(sigma2=0.001, sigma4=0)
+    # ex1 = expansion.Expansion(l_array=np.array([1]), coefs=expansion.rot_sym_expansion(np.array([1]), np.array([0.001])))
+    # ex1 = expansion.GaussianCharges(omega_k=np.array([[0, 0], [np.pi, 0]]), lambda_k=np.array([5]), sigma1=0.001, l_max=10)
+    ex1 = expansion.GaussianCharges(omega_k=np.array([np.pi, 0]), lambda_k=np.array([1, 5, 10, 15]), sigma1=0.001, l_max=10)
+    ex2 = ex1.clone()
+
+    path_plot = PathEnergyPlot(ex1, ex2, DipolePath2, dist=2., params=params)
+    path_plot.plot(labels=[rf'$\lambda$={l}' for l in [1, 5, 10, 15]],
+                   # norm_euler_angles={'beta1': np.pi/2}
+                   )
+    # path_plot.plot_sections(save_as=Path('/home/andraz/ChargedShells/Figures/dipole_path2.png'))

+ 5 - 6
potentials.py

@@ -25,11 +25,7 @@ def charged_shell_potential(theta: Array | float,
     :param ex: Expansion object detailing patch distribution
     :param params: ModelParams object specifying parameter values for the model
     """
-    if isinstance(theta, float):
-        theta = np.full_like(phi, theta)
-
-    if isinstance(phi, float):
-        phi = np.full_like(theta, phi)
+    theta, phi = np.broadcast_arrays(theta, phi)
 
     if not theta.shape == phi.shape:
         raise ValueError('theta and phi arrays should have the same shape.')
@@ -38,6 +34,9 @@ def charged_shell_potential(theta: Array | float,
     dist = dist * params.R
     l_factors = (fn.coefficient_Cpm(ex.l_array, params.kappaR) * fn.sph_bessel_k(ex.l_array, params.kappa * dist)
                  / fn.sph_bessel_k(ex.l_array, params.kappaR))
+    l_factors = ex.repeat_over_m(l_factors)
+
+
 
     return (1 / (params.kappa * params.epsilon * uc.CONSTANTS.epsilon0)
             * np.real(np.sum(ex.repeat_over_m(l_factors)[None, :] * ex.coefs
@@ -47,7 +46,7 @@ def charged_shell_potential(theta: Array | float,
 if __name__ == '__main__':
 
     params = ModelParams(R=150, kappaR=3)
-    ex = expansion.MappedExpansionQuad(0.44, params.kappaR, 0.001, max_l=10)
+    ex = expansion.MappedExpansionQuad(np.array([0.44, 0.5]), params.kappaR, 0.001, l_max=10)
 
     theta = np.linspace(0, np.pi, 1000)
     phi = 0.

+ 107 - 53
rotational_path.py

@@ -5,10 +5,12 @@ import expansion
 from parameters import ModelParams
 import interactions
 import matplotlib.pyplot as plt
+from pathlib import Path
 
 
 Quaternion = quaternionic.array
 Array = np.ndarray
+Expansion = expansion.Expansion
 
 
 @dataclass
@@ -16,67 +18,119 @@ class PairRotationalPath:
 
     rotations1: list[Quaternion] = field(default_factory=list)
     rotations2: list[Quaternion] = field(default_factory=list)
+    x_axis: list[Array] = field(default_factory=list)
+    overlapping_last: bool = True
+    _default_x_axis: Array = None
 
-    def add(self, rotation1: Quaternion, rotation2: Quaternion):
+    def add(self, rotation1: Quaternion, rotation2: Quaternion, x_axis: Array = None):
         rotation1, rotation2 = np.broadcast_arrays(rotation1, rotation2)
-        self.rotations1.append(rotation1)
-        self.rotations2.append(rotation2)
+        self.rotations1.append(Quaternion(rotation1))
+        self.rotations2.append(Quaternion(rotation2))
+        if x_axis is None:
+            x_axis = np.arange(len(rotation1)) if self._default_x_axis is None else self._default_x_axis
+        self.add_x_axis(x_axis)
+
+    def add_x_axis(self, x_axis: Array):
+        try:
+            last_x_val = self.x_axis[-1][-1]
+        except IndexError:
+            last_x_val = 0
+        if self.overlapping_last:
+            self.x_axis.append(x_axis + last_x_val)
+        else:
+            raise NotImplementedError('Currently only overlapping end points for x-axes are supported.')
+
+    def set_default_x_axis(self, default_x_axis: Array):
+        self._default_x_axis = default_x_axis
 
     def add_euler(self, *, alpha1: Array = 0, beta1: Array = 0, gamma1: Array = 0,
-                  alpha2: Array = 0, beta2: Array = 0, gamma2: Array = 0):
+                  alpha2: Array = 0, beta2: Array = 0, gamma2: Array = 0,
+                  x_axis: Array = None):
         R1_euler = quaternionic.array.from_euler_angles(alpha1, beta1, gamma1)
         R2_euler = quaternionic.array.from_euler_angles(alpha2, beta2, gamma2)
-        self.add(Quaternion(R1_euler), Quaternion(R2_euler))
+        self.add(Quaternion(R1_euler), Quaternion(R2_euler), x_axis)
 
-    def get_rotations(self) -> (Quaternion, Quaternion):
+    def stack_rotations(self) -> (Quaternion, Quaternion):
         return Quaternion(np.vstack(self.rotations1)), Quaternion(np.vstack(self.rotations2))
 
-
-zero_to_pi_half = np.linspace(0, np.pi/2, 100, endpoint=False)
-pi_half_to_pi = np.linspace(np.pi/2, np.pi, 100, endpoint=False)
-
-QuadPath = PairRotationalPath()
-QuadPath.add_euler(beta1=np.pi/2, beta2=zero_to_pi_half)
-QuadPath.add_euler(beta1=zero_to_pi_half[::-1], beta2=zero_to_pi_half[::-1])
-QuadPath.add_euler(beta1=zero_to_pi_half)
-QuadPath.add_euler(beta1=zero_to_pi_half[::-1], beta2=zero_to_pi_half)
-QuadPath.add_euler(beta1=np.pi/2, beta2=zero_to_pi_half, alpha2=np.pi/2)
-QuadPath.add_euler(beta1=np.pi/2, beta2=np.pi/2, alpha1=zero_to_pi_half[::-1])
-
-DipolePath = PairRotationalPath()
-DipolePath.add_euler(beta1=pi_half_to_pi[::-1])
-DipolePath.add_euler(beta1=zero_to_pi_half[::-1])
-DipolePath.add_euler(beta1=zero_to_pi_half, beta2=zero_to_pi_half)
-DipolePath.add_euler(beta1=np.pi/2, beta2=np.pi/2, alpha1=zero_to_pi_half)
-DipolePath.add_euler(beta1=np.pi/2, alpha1=np.pi/2, beta2=zero_to_pi_half)
-DipolePath.add_euler(beta1=np.pi/2, beta2=pi_half_to_pi[::-1], alpha2=np.pi)
-DipolePath.add_euler(beta1=zero_to_pi_half[::-1], beta2=pi_half_to_pi, alpha2=np.pi)
-DipolePath.add_euler(beta1=zero_to_pi_half, beta2=pi_half_to_pi[::-1], alpha2=np.pi)
-DipolePath.add_euler(beta1=pi_half_to_pi, beta2=zero_to_pi_half[::-1], alpha2=np.pi)
-
-
-if __name__ == '__main__':
-
-    params = ModelParams(R=150, kappaR=3)
-    # ex1 = expansion.MappedExpansionQuad(0.44, params.kappaR, 0.001, max_l=20)
-    # ex1 = expansion.Expansion24(sigma2=0.001, sigma4=0)
-    # ex1 = expansion.Expansion(l_array=np.array([1]), coefs=expansion.rot_sym_expansion(np.array([1]), np.array([0.001])))
-    # ex1 = expansion.SmearedCharges(omega_k=np.array([[0, 0], [np.pi, 0]]), lambda_k=5, sigma1=0.001, l_max=10)
-    ex1 = expansion.SmearedCharges(omega_k=np.array([np.pi, 0]), lambda_k=3, sigma1=0.001, l_max=10)
-    ex2 = ex1.clone()
-
-    # rotations1, rotations2 = QuadPath.get_rotations()
-    rotations1, rotations2 = DipolePath.get_rotations()
-    ex1.rotate(rotations1)
-    ex2.rotate(rotations2)
-
-    dist = 2
-
-    energy = interactions.charged_shell_energy(ex1, ex2, dist, params)
-
-    plt.plot(energy)
-    plt.show()
-
-
+    def stack_x_axes(self) -> Array:
+        return np.hstack(self.x_axis)
 
 
+@dataclass
+class PathEnergyPlot:
+
+    expansion1: Expansion
+    expansion2: Expansion
+    rot_path: PairRotationalPath
+    dist: float | Array
+    params: ModelParams
+
+    def __post_init__(self):
+        if not isinstance(self.dist, Array):
+            self.dist = np.array([self.dist])
+
+    def evaluate_energy(self):
+        energy = []
+        for dist in self.dist:
+            for params in self.params.unravel():
+                energy.append(interactions.charged_shell_energy(self.expansion1, self.expansion2, dist, params))
+        return np.squeeze(np.stack(energy, axis=-1))
+
+    def path_energy(self):
+        rotations1, rotations2 = self.rot_path.stack_rotations()
+        self.expansion1.rotate(rotations1)
+        self.expansion2.rotate(rotations2)
+        return self.evaluate_energy()
+
+    def section_energies(self):
+        energy_list = []
+        for rot1, rot2 in zip(self.rot_path.rotations1, self.rot_path.rotations2):
+            self.expansion1.rotate(rot1)
+            self.expansion2.rotate(rot2)
+            energy_list.append(self.evaluate_energy())
+        return energy_list
+
+    def normalization(self, norm_euler_angles: dict):
+        if norm_euler_angles is None:
+            return np.array([1.])
+        self.expansion1.rotate_euler(alpha=norm_euler_angles.get('alpha1', 0),
+                                     beta=norm_euler_angles.get('beta1', 0),
+                                     gamma=norm_euler_angles.get('gamma1', 0))
+        self.expansion2.rotate_euler(alpha=norm_euler_angles.get('alpha2', 0),
+                                     beta=norm_euler_angles.get('beta2', 0),
+                                     gamma=norm_euler_angles.get('gamma2', 0))
+        return np.abs(self.evaluate_energy())
+
+    def plot(self, labels: list[str] = None, norm_euler_angles: dict = None):
+        energy = self.path_energy()
+        normalization = self.normalization(norm_euler_angles)
+        energy = energy / normalization[None, ...]
+        energy = energy.reshape(energy.shape[0], -1)
+        x_axis = self.rot_path.stack_x_axes()
+
+        fig, ax = plt.subplots()
+        ax.axhline(y=0, c='k', linestyle=':')
+        ax.plot(x_axis, energy, label=labels)
+        ax.legend(fontsize=12)
+        ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
+        ax.set_xlabel('angle', fontsize=13)
+        ax.set_ylabel('U', fontsize=13)
+        plt.show()
+
+    def plot_sections(self, norm_euler_angles: dict = None, save_as: Path = None):
+        energy_list = self.section_energies()
+        normalization = self.normalization(norm_euler_angles)
+
+        fig, ax = plt.subplots()
+        ax.axhline(y=0, c='k', linestyle=':')
+        for x_axis, energy in zip(self.rot_path.x_axis, energy_list):
+            energy, norm = np.broadcast_arrays(energy, normalization)
+            ax.plot(x_axis, energy / norm)
+        # ax.legend(fontsize=12)
+        ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
+        ax.set_xlabel('angle', fontsize=13)
+        ax.set_ylabel('U', fontsize=13)
+        if save_as is not None:
+            plt.savefig(save_as, dpi=600)
+        plt.show()