from dataclasses import dataclass, field
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import quaternionic

import expansion
import interactions
from parameters import ModelParams

Quaternion = quaternionic.array
Array = np.ndarray
Expansion = expansion.Expansion


def _style_energy_axes(ax) -> None:
    """Apply the tick/label styling shared by every energy plot in this module."""
    ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
    ax.set_xlabel('angle', fontsize=13)
    ax.set_ylabel('U', fontsize=13)


def _finish_figure(save_as: Path | None) -> None:
    """Tighten the layout, optionally save the current figure at 600 dpi, then show it."""
    plt.tight_layout()
    if save_as is not None:
        plt.savefig(save_as, dpi=600)
    plt.show()


@dataclass
class PairRotationalPath:
    """Piecewise rotational path for a pair of bodies.

    Every call to :meth:`add` / :meth:`add_euler` appends one *section*: a batch of
    rotations for body 1, the broadcast-matched batch for body 2, and an x-axis array
    used when plotting that section.
    """

    rotations1: list[Quaternion] = field(default_factory=list)
    rotations2: list[Quaternion] = field(default_factory=list)
    x_axis: list[Array] = field(default_factory=list)
    # When True, each new section's x-axis is shifted by the last x value of the
    # previous section (the only supported mode).
    overlapping_last: bool = True
    # x-axis used by add() when no explicit x_axis is passed; None means 0..n-1.
    _default_x_axis: Array | None = None

    def add(self, rotation1: Quaternion, rotation2: Quaternion, x_axis: Array | None = None):
        """Append one section of paired rotations.

        :param rotation1: rotations of the first body; broadcast against ``rotation2``.
        :param rotation2: rotations of the second body.
        :param x_axis: x values for this section. Defaults to the configured default
            x-axis, or to ``0..n-1`` where ``n`` is the number of rotations.
        """
        rotation1, rotation2 = np.broadcast_arrays(rotation1, rotation2)
        self.rotations1.append(Quaternion(rotation1))
        self.rotations2.append(Quaternion(rotation2))
        if x_axis is None:
            if self._default_x_axis is None:
                # Count rotations the same way stack_rotations (np.vstack -> atleast_2d)
                # does: a single quaternion of shape (4,) is ONE rotation, not four.
                n_rotations = np.atleast_2d(rotation1).shape[0]
                x_axis = np.arange(n_rotations)
            else:
                x_axis = self._default_x_axis
        self.add_x_axis(x_axis)

    def add_x_axis(self, x_axis: Array):
        """Append ``x_axis`` shifted so it starts where the previous section ended.

        :raises NotImplementedError: if ``overlapping_last`` is False.
        """
        if not self.overlapping_last:
            raise NotImplementedError('Currently only overlapping end points for x-axes are supported.')
        # Accept any array-like (e.g. a plain list), not only ndarrays.
        x_axis = np.asarray(x_axis)
        try:
            last_x_val = self.x_axis[-1][-1]
        except IndexError:
            # No previous section (or an empty one): start from zero.
            last_x_val = 0
        self.x_axis.append(x_axis + last_x_val)

    def set_default_x_axis(self, default_x_axis: Array):
        """Set the x-axis used by :meth:`add` when none is given explicitly."""
        self._default_x_axis = default_x_axis

    def add_euler(self, *, alpha1: Array = 0, beta1: Array = 0, gamma1: Array = 0,
                  alpha2: Array = 0, beta2: Array = 0, gamma2: Array = 0,
                  x_axis: Array | None = None):
        """Append one section given as Euler angles for each body (see :meth:`add`)."""
        self.add(Quaternion.from_euler_angles(alpha1, beta1, gamma1),
                 Quaternion.from_euler_angles(alpha2, beta2, gamma2),
                 x_axis)

    def stack_rotations(self) -> tuple[Quaternion, Quaternion]:
        """Return all sections' rotations for both bodies, stacked along axis 0."""
        return Quaternion(np.vstack(self.rotations1)), Quaternion(np.vstack(self.rotations2))

    def stack_x_axes(self) -> Array:
        """Return all sections' x-axes concatenated into one 1-D array."""
        return np.hstack(self.x_axis)


@dataclass
class PathEnergyPlot:
    """Path comparison for a pair of charge distributions, possibly at different parameter values."""

    expansion1: Expansion
    expansion2: Expansion
    rot_path: PairRotationalPath
    dist: float | Array
    params: ModelParams

    def __post_init__(self):
        # Always store distances as an (at least) 1-D array so evaluate_energy can
        # iterate over them; this also covers 0-d arrays and plain sequences.
        self.dist = np.atleast_1d(self.dist)

    def evaluate_energy(self) -> Array:
        """Evaluate the pair energy for every distance and every unravelled parameter set.

        Results are stacked along the last axis (distance-major, then parameters) and
        singleton dimensions are squeezed away.
        """
        energy = [interactions.charged_shell_energy(self.expansion1, self.expansion2, dist, params)
                  for dist in self.dist
                  for params in self.params.unravel()]
        return np.squeeze(np.stack(energy, axis=-1))

    def path_energy(self) -> Array:
        """Rotate both expansions along the full stacked path and evaluate the energy.

        NOTE(review): mutates ``expansion1``/``expansion2`` via ``rotate``; presumably
        ``rotate`` sets the orientation rather than composing with the previous one —
        confirm in ``expansion.Expansion``.
        """
        rotations1, rotations2 = self.rot_path.stack_rotations()
        self.expansion1.rotate(rotations1)
        self.expansion2.rotate(rotations2)
        return self.evaluate_energy()

    def section_energies(self) -> list[Array]:
        """Evaluate the energy separately for each section of the rotational path."""
        energy_list = []
        for rot1, rot2 in zip(self.rot_path.rotations1, self.rot_path.rotations2):
            # Rotation mutates the expansions, so the loop order matters.
            self.expansion1.rotate(rot1)
            self.expansion2.rotate(rot2)
            energy_list.append(self.evaluate_energy())
        return energy_list

    def normalization(self, norm_euler_angles: dict | None) -> Array:
        """Return ``|energy|`` at a reference orientation, or ``[1.]`` if none is given.

        :param norm_euler_angles: optional keys ``alpha1``, ``beta1``, ``gamma1``,
            ``alpha2``, ``beta2``, ``gamma2``; missing keys default to 0.
        """
        if norm_euler_angles is None:
            return np.array([1.])
        self.expansion1.rotate_euler(alpha=norm_euler_angles.get('alpha1', 0),
                                     beta=norm_euler_angles.get('beta1', 0),
                                     gamma=norm_euler_angles.get('gamma1', 0))
        self.expansion2.rotate_euler(alpha=norm_euler_angles.get('alpha2', 0),
                                     beta=norm_euler_angles.get('beta2', 0),
                                     gamma=norm_euler_angles.get('gamma2', 0))
        return np.abs(self.evaluate_energy())

    def evaluate_path(self, norm_euler_angles: dict | None = None) -> Array:
        """Return the path energy divided by the normalization, flattened to 2-D then squeezed."""
        # Order matters: normalization() re-rotates the expansions after the path evaluation.
        energy = self.path_energy()
        normalization = self.normalization(norm_euler_angles)
        energy = energy / normalization[None, ...]
        energy = energy.reshape(energy.shape[0], -1)
        return np.squeeze(energy)

    def plot(self, labels: list[str] | None = None, norm_euler_angles: dict | None = None,
             save_as: Path | None = None):
        """Plot the (normalized) energy along the whole path against the stacked x-axis."""
        energy = self.evaluate_path(norm_euler_angles=norm_euler_angles)
        x_axis = self.rot_path.stack_x_axes()
        _, ax = plt.subplots()
        ax.axhline(y=0, c='k', linestyle=':')
        ax.plot(x_axis, np.squeeze(energy), label=labels)
        # Only draw a legend when there is something to label (avoids matplotlib's
        # "No artists with labels found" warning).
        if labels is not None:
            ax.legend(fontsize=12)
        _style_energy_axes(ax)
        _finish_figure(save_as)

    def plot_sections(self, norm_euler_angles: dict | None = None, save_as: Path | None = None):
        """Plot each path section as its own line, normalized like :meth:`plot`."""
        energy_list = self.section_energies()
        normalization = self.normalization(norm_euler_angles)
        _, ax = plt.subplots()
        ax.axhline(y=0, c='k', linestyle=':')
        for x_axis, energy in zip(self.rot_path.x_axis, energy_list):
            energy, norm = np.broadcast_arrays(energy, normalization)
            ax.plot(x_axis, energy / norm)
        _style_energy_axes(ax)
        _finish_figure(save_as)


@dataclass
class PathExpansionComparison:
    """Path comparison for different charge distribution models."""

    ex_list: list[Expansion]
    rot_path: PairRotationalPath
    dist: float | Array
    params: ModelParams
    path_list: list[PathEnergyPlot] = field(default_factory=list)

    def __post_init__(self):
        # Each model is paired with a clone of itself along the shared rotational path.
        for ex in self.ex_list:
            self.path_list.append(PathEnergyPlot(ex, ex.clone(), self.rot_path, self.dist, self.params))

    def evaluate_path(self, norm_euler_angles: dict | None = None) -> Array:
        """Evaluate every model's path energy and stack the results along a new axis 0."""
        return np.stack([path.evaluate_path(norm_euler_angles) for path in self.path_list])

    def plot(self, labels: list[str] | None = None, norm_euler_angles: dict | None = None,
             save_as: Path | None = None):
        """Plot all models' path energies on one set of axes (one line per model)."""
        energy = self.evaluate_path(norm_euler_angles=norm_euler_angles)
        x_axis = self.rot_path.stack_x_axes()
        _, ax = plt.subplots()
        ax.axhline(y=0, c='k', linestyle=':')
        ax.plot(x_axis, np.squeeze(energy).T, label=labels)
        if labels is not None:
            ax.legend(fontsize=12)
        _style_energy_axes(ax)
        _finish_figure(save_as)