"""Evaluate and plot the interaction energy of a pair of charge expansions along a path of rotations."""
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import quaternionic

from charged_shells import expansion, interactions, mapping
from charged_shells.parameters import ModelParams

Quaternion = quaternionic.array
Array = np.ndarray
Expansion = expansion.Expansion


@dataclass
class PairRotationalPath:
    """Piecewise path of rotation pairs applied to two particles.

    Each call to :meth:`add` (or :meth:`add_euler`) appends one section: rotations for the
    first particle, matching rotations for the second particle, and the x values used to plot
    that section. The x values of every new section are offset by the last x value of the
    previous section, so the sections are chained along a single x-axis.
    """

    rotations1: list[Quaternion] = field(default_factory=list)
    rotations2: list[Quaternion] = field(default_factory=list)
    x_axis: list[Array] = field(default_factory=list)
    # tick position on the chained x-axis -> tick label
    ticks: dict[float, str] = field(default_factory=dict)
    overlapping_last: bool = True
    # x values used by `add` when no explicit x_axis is given (see set_default_x_axis)
    _default_x_axis: Array | None = None

    def add(self, rotation1: Quaternion, rotation2: Quaternion, x_axis: Array | None = None,
            start_name: str | float | None = None, end_name: str | float | None = None):
        """Append a path section given explicit rotations for both particles.

        :param rotation1: rotations of the first particle along the section
        :param rotation2: rotations of the second particle; broadcast against ``rotation1``
        :param x_axis: x values of the section; if None, the default x-axis is used when set,
            otherwise ``np.arange(len(rotation1))``
        :param start_name: tick label at the section start (defaults to the numeric position)
        :param end_name: tick label at the section end (defaults to the numeric position)
        """
        rotation1, rotation2 = np.broadcast_arrays(rotation1, rotation2)
        self.rotations1.append(Quaternion(rotation1))
        self.rotations2.append(Quaternion(rotation2))
        if x_axis is None:
            x_axis = np.arange(len(rotation1)) if self._default_x_axis is None else self._default_x_axis
        self.add_x_axis(x_axis, start_name, end_name)

    def add_x_axis(self, x_axis: Array, start_name: str | float | None = None,
                   end_name: str | float | None = None):
        """Append x values for a new section and register ticks at its start and end.

        The section's x values are shifted by the last x value of the previous section
        (0 for the first section).

        :raises NotImplementedError: if ``overlapping_last`` is False
        """
        try:
            last_x_val = self.x_axis[-1][-1]
        except IndexError:
            last_x_val = 0  # no previous section (or an empty one): no offset

        if self.overlapping_last:
            self.x_axis.append(x_axis + last_x_val)
        else:
            raise NotImplementedError('Currently only overlapping end points for x-axes are supported.')

        # tick labels default to the (shifted) numeric x positions
        s_n = x_axis[0] + last_x_val if start_name is None else start_name
        e_n = x_axis[-1] + last_x_val if end_name is None else end_name
        start_position = float(self.x_axis[-1][0])
        end_position = float(self.x_axis[-1][-1])
        # the previous section's end label is kept unless an explicit start name overrides it
        if start_position not in self.ticks or start_name is not None:
            self.ticks[start_position] = s_n
        self.ticks[end_position] = e_n

    def set_default_x_axis(self, default_x_axis: Array):
        """Set the x values used by :meth:`add` when no ``x_axis`` argument is given."""
        self._default_x_axis = default_x_axis

    def add_euler(self, *, alpha1: Array = 0, beta1: Array = 0, gamma1: Array = 0,
                  alpha2: Array = 0, beta2: Array = 0, gamma2: Array = 0, **add_kwargs):
        """Append a section whose rotations are built from Euler angles.

        Rotations are created with ``quaternionic.array.from_euler_angles``; remaining keyword
        arguments (``x_axis``, ``start_name``, ``end_name``) are forwarded to :meth:`add`.
        """
        R1_euler = quaternionic.array.from_euler_angles(alpha1, beta1, gamma1)
        R2_euler = quaternionic.array.from_euler_angles(alpha2, beta2, gamma2)
        self.add(Quaternion(R1_euler), Quaternion(R2_euler), **add_kwargs)

    def stack_rotations(self) -> tuple[Quaternion, Quaternion]:
        """Return the rotations of all sections stacked with ``np.vstack``, for both particles."""
        return Quaternion(np.vstack(self.rotations1)), Quaternion(np.vstack(self.rotations2))

    def stack_x_axes(self) -> Array:
        """Return the chained x values of all sections as a single 1D array."""
        return np.hstack(self.x_axis)

    def plot_style(self, fig, ax, energy_units: interactions.EnergyUnit = 'kT', legend: bool = True,
                   size: tuple | None = (4, 1.7), legend_loc: str = 'upper left'):
        """Apply the common styling of path-energy plots to ``fig``/``ax``.

        Draws a zero line, optional legend, inward ticks, the energy y-label, the section
        ticks collected in ``self.ticks``, resizes the figure (unless ``size`` is None) and
        sets uniform line widths and margins.
        """
        ax.axhline(y=0, c='k', linestyle=':')
        if legend:
            ax.legend(fontsize=10, frameon=False, loc=legend_loc, bbox_to_anchor=(0.57, 1))
        ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=11)
        if energy_units == 'kT':
            energy_units = 'k_B T'
        ax.set_ylabel(f'$V [{energy_units}]$', fontsize=11)
        ax.set_xticks(list(self.ticks.keys()), list(self.ticks.values()), fontsize=11)
        if size is not None:
            fig.set_size_inches(size)
        for line in ax.get_lines():
            line.set_linewidth(1.2)
        # adjust the figure we were given, not whatever pyplot considers "current"
        fig.subplots_adjust(left=0.12, right=0.97, top=0.95, bottom=0.15)


@dataclass
class PathEnergyPlot:
    """Path comparison for a pair of charge distributions, possibly at different parameter values."""

    expansion1: Expansion
    expansion2: Expansion
    rot_path: PairRotationalPath
    dist: float | Array
    params: ModelParams
    match_expansion_axis_to_params: int | None = None
    units: interactions.EnergyUnit = 'kT'

    def __post_init__(self):
        if not isinstance(self.dist, Array):
            self.dist = np.array([self.dist])
        # we add 1 to match_expansion_axis as rotations take the new leading axis
        # NOTE(review): this mutates the field, so re-running __post_init__ (e.g. through
        # dataclasses.replace) would shift the axis twice.
        if self.match_expansion_axis_to_params is not None:
            self.match_expansion_axis_to_params += 1

    def plot_style(self, fig, ax, energy_units: interactions.EnergyUnit = 'kT', legend: bool = True,
                   size: tuple = (8.25, 4.125)):
        """Style the plot using the rotational path's styling with a larger default size."""
        self.rot_path.plot_style(fig, ax, energy_units, legend=legend, size=size)

    def evaluate_energy(self):
        """Evaluate the pair energy of the expansions in their current orientation.

        The energy is computed for every distance in ``self.dist``; results are stacked along
        the last axis and size-1 axes are squeezed away.
        """
        energy = []
        for dist in self.dist:
            energy_fn = mapping.parameter_map_two_expansions(
                partial(interactions.charged_shell_energy, dist=dist, units=self.units),
                self.match_expansion_axis_to_params)
            energy.append(energy_fn(self.expansion1, self.expansion2, self.params))
        return np.squeeze(np.stack(energy, axis=-1))

    def path_energy(self):
        """Rotate both expansions by the full stacked path and evaluate the energy."""
        rotations1, rotations2 = self.rot_path.stack_rotations()
        self.expansion1.rotate(rotations1)
        self.expansion2.rotate(rotations2)
        return self.evaluate_energy()

    def section_energies(self):
        """Return a list with the energy evaluated separately for each path section."""
        energy_list = []
        for rot1, rot2 in zip(self.rot_path.rotations1, self.rot_path.rotations2):
            self.expansion1.rotate(rot1)
            self.expansion2.rotate(rot2)
            energy_list.append(self.evaluate_energy())
        return energy_list

    def normalization(self, norm_euler_angles: dict | None):
        """Return the absolute energy at a reference orientation used for normalization.

        :param norm_euler_angles: optional keys ``alpha1, beta1, gamma1, alpha2, beta2, gamma2``
            (missing keys default to 0); if None, ``np.array([1.])`` is returned and the
            expansions are not rotated.
        """
        if norm_euler_angles is None:
            return np.array([1.])
        self.expansion1.rotate_euler(alpha=norm_euler_angles.get('alpha1', 0),
                                     beta=norm_euler_angles.get('beta1', 0),
                                     gamma=norm_euler_angles.get('gamma1', 0))
        self.expansion2.rotate_euler(alpha=norm_euler_angles.get('alpha2', 0),
                                     beta=norm_euler_angles.get('beta2', 0),
                                     gamma=norm_euler_angles.get('gamma2', 0))
        return np.abs(self.evaluate_energy())

    def evaluate_path(self, norm_euler_angles: dict | None = None):
        """Return the (optionally normalized) energy along the whole path.

        The leading axis runs over the path points; all other axes are flattened into one
        and size-1 axes are squeezed.
        """
        # both steps rotate the stored expansions, so the path energy is computed first
        energy = self.path_energy()
        normalization = self.normalization(norm_euler_angles)
        energy = energy / normalization[None, ...]
        energy = energy.reshape(energy.shape[0], -1)
        return np.squeeze(energy)

    def plot(self, labels: list[str] | None = None, norm_euler_angles: dict | None = None,
             save_as: Path | None = None):
        """Plot the energy along the whole path, optionally saving the figure to ``save_as``."""
        energy = self.evaluate_path(norm_euler_angles=norm_euler_angles)
        x_axis = self.rot_path.stack_x_axes()
        # same size that plot_style applies below
        fig, ax = plt.subplots(figsize=(8.25, 4.125))
        ax.plot(x_axis, np.squeeze(energy), label=labels)
        self.plot_style(fig, ax, energy_units=self.units)
        if save_as is not None:
            fig.savefig(save_as, dpi=600)
        plt.show()

    def plot_sections(self, norm_euler_angles: dict | None = None, save_as: Path | None = None):
        """Plot each path section as a separate line, optionally saving to ``save_as``."""
        energy_list = self.section_energies()
        normalization = self.normalization(norm_euler_angles)
        fig, ax = plt.subplots()
        for x_axis, energy in zip(self.rot_path.x_axis, energy_list):
            energy, norm = np.broadcast_arrays(energy, normalization)
            ax.plot(x_axis, energy / norm)
        self.plot_style(fig, ax, energy_units=self.units, legend=False)
        if save_as is not None:
            fig.savefig(save_as, dpi=600)
        plt.show()


@dataclass
class PathExpansionComparison:
    """Path comparison for different charge distribution models."""

    ex_list: list[Expansion]
    rot_path: PairRotationalPath
    dist: float
    params: ModelParams
    path_list: list[PathEnergyPlot] = field(default_factory=list)

    def __post_init__(self):
        # each model interacts with a clone of itself along the shared rotational path
        for ex in self.ex_list:
            self.path_list.append(PathEnergyPlot(ex, ex.clone(), self.rot_path, self.dist, self.params))

    def evaluate_path(self, norm_euler_angles: dict | None = None):
        """Return the path energies of all models stacked along a new leading axis (squeezed)."""
        path_vals = [path.evaluate_path(norm_euler_angles) for path in self.path_list]
        return np.squeeze(np.stack(path_vals))

    def plot(self, labels: list[str] | None = None, norm_euler_angles: dict | None = None,
             save_as: Path | None = None):
        """Plot the path energy of every model, optionally saving the figure to ``save_as``."""
        energy = self.evaluate_path(norm_euler_angles=norm_euler_angles)
        x_axis = self.rot_path.stack_x_axes()
        fig, ax = plt.subplots()
        ax.axhline(y=0, c='k', linestyle=':')
        ax.plot(x_axis, energy, label=labels)
        ax.legend(fontsize=12)
        ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
        ax.set_xlabel('angle', fontsize=13)
        ax.set_ylabel('U', fontsize=13)
        fig.tight_layout()
        if save_as is not None:
            fig.savefig(save_as, dpi=600)
        plt.show()