123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182 |
- import numpy as np
- from dataclasses import dataclass, field
- import quaternionic
- import expansion
- from parameters import ModelParams
- import interactions
- import matplotlib.pyplot as plt
- from pathlib import Path
# Short aliases used throughout the annotations in this module.
Array = np.ndarray                   # plain NumPy array
Quaternion = quaternionic.array      # quaternion array class from `quaternionic`
Expansion = expansion.Expansion      # expansion class from the project-local `expansion` module
@dataclass
class PairRotationalPath:
    """Piecewise path of rotation pairs (one rotation for each of two bodies).

    The path is built section by section with :meth:`add` or :meth:`add_euler`.
    Each section stores one quaternion array per body plus an x-axis array used
    for plotting.  Each new section's x-axis is shifted by the last x value of
    the previous section, so consecutive sections share their end point.
    """

    # One quaternion array per section, for the first / second body.
    rotations1: list[Quaternion] = field(default_factory=list)
    rotations2: list[Quaternion] = field(default_factory=list)
    # One (already shifted) x-axis array per section.
    x_axis: list[Array] = field(default_factory=list)
    # Only True is supported: every section's x-axis is offset by the previous
    # section's last x value.
    overlapping_last: bool = True
    # x-axis used by `add` when the caller gives none; if None, np.arange over
    # the number of rotations in the section is used instead.
    _default_x_axis: Array | None = None

    def add(self, rotation1: Quaternion, rotation2: Quaternion, x_axis: Array | None = None):
        """Append one path section.

        Args:
            rotation1: Rotation(s) of the first body.
            rotation2: Rotation(s) of the second body; broadcast against
                ``rotation1`` with ``np.broadcast_arrays``.
            x_axis: Un-shifted x values for this section.  Defaults to the value
                set via :meth:`set_default_x_axis`, otherwise to
                ``np.arange(n_rotations)``.

        Raises:
            NotImplementedError: If ``overlapping_last`` is False.  The path is
                left unchanged in that case.
        """
        rotation1, rotation2 = np.broadcast_arrays(rotation1, rotation2)
        if x_axis is None:
            if self._default_x_axis is not None:
                x_axis = self._default_x_axis
            else:
                # A 1-D rotation array (e.g. a single quaternion built from scalar
                # Euler angles) becomes ONE row under np.vstack in
                # stack_rotations, so it contributes one x point, not
                # len(rotation1) points.
                n_rotations = rotation1.shape[0] if rotation1.ndim > 1 else 1
                x_axis = np.arange(n_rotations)
        # Compute the shifted x-axis before touching any list so that a failure
        # cannot leave rotations without a matching x-axis.
        shifted_x_axis = self._shifted_x_axis(x_axis)
        self.rotations1.append(Quaternion(rotation1))
        self.rotations2.append(Quaternion(rotation2))
        self.x_axis.append(shifted_x_axis)

    def add_x_axis(self, x_axis: Array):
        """Append ``x_axis`` shifted by the last x value already on the path.

        Raises:
            NotImplementedError: If ``overlapping_last`` is False.
        """
        self.x_axis.append(self._shifted_x_axis(x_axis))

    def _shifted_x_axis(self, x_axis: Array) -> Array:
        """Return ``x_axis`` offset so it starts at the previous section's last x value.

        Raises:
            NotImplementedError: If ``overlapping_last`` is False.
        """
        if not self.overlapping_last:
            raise NotImplementedError('Currently only overlapping end points for x-axes are supported.')
        # Accept lists/ranges too; `list + float` would raise TypeError.
        x_axis = np.asarray(x_axis)
        try:
            last_x_val = self.x_axis[-1][-1]
        except IndexError:
            # No previous section (or an empty one): start the path at 0.
            last_x_val = 0
        return x_axis + last_x_val

    def set_default_x_axis(self, default_x_axis: Array):
        """Set the x-axis that :meth:`add` uses when none is passed."""
        self._default_x_axis = default_x_axis

    def add_euler(self, *, alpha1: Array = 0, beta1: Array = 0, gamma1: Array = 0,
                  alpha2: Array = 0, beta2: Array = 0, gamma2: Array = 0,
                  x_axis: Array | None = None):
        """Append a section specified by Euler angles for both bodies.

        The angles are forwarded unchanged to
        ``quaternionic.array.from_euler_angles``; see its documentation for the
        angle convention and units.  Omitted angles default to 0.
        """
        rotation1 = quaternionic.array.from_euler_angles(alpha1, beta1, gamma1)
        rotation2 = quaternionic.array.from_euler_angles(alpha2, beta2, gamma2)
        # `add` broadcasts and wraps both arrays in Quaternion itself.
        self.add(rotation1, rotation2, x_axis)

    def stack_rotations(self) -> tuple[Quaternion, Quaternion]:
        """Concatenate all sections (``np.vstack``) into one quaternion array per body."""
        return Quaternion(np.vstack(self.rotations1)), Quaternion(np.vstack(self.rotations2))

    def stack_x_axes(self) -> Array:
        """Concatenate the per-section x-axes (``np.hstack``) into one array."""
        return np.hstack(self.x_axis)
@dataclass
class PathEnergyPlot:
    """Path comparison for a pair of charge distributions, possibly at different parameter values."""
    expansion1: Expansion
    expansion2: Expansion
    rot_path: PairRotationalPath
    dist: float | Array
    params: ModelParams

    def __post_init__(self):
        # evaluate_energy iterates over self.dist.  np.atleast_1d turns a scalar,
        # a list, or a 0-d array into an iterable 1-D array (a 0-d ndarray would
        # otherwise crash on iteration); arrays with ndim >= 1 pass through as-is.
        self.dist = np.atleast_1d(self.dist)

    def evaluate_energy(self):
        """Evaluate the interaction energy for every (distance, parameter set) pair.

        Calls ``interactions.charged_shell_energy`` with the expansions in their
        current state, looping over distances (outer) and ``params.unravel()``
        (inner).  Results are stacked along a new last axis and singleton axes
        are squeezed out.
        """
        energy = [interactions.charged_shell_energy(self.expansion1, self.expansion2, dist, params)
                  for dist in self.dist
                  for params in self.params.unravel()]
        return np.squeeze(np.stack(energy, axis=-1))

    def path_energy(self):
        """Rotate both expansions by the whole stacked path and evaluate the energy.

        NOTE(review): relies on ``Expansion.rotate`` setting the orientation used
        by ``evaluate_energy``; whether it is absolute or cumulative is not
        visible here — confirm.
        """
        rotations1, rotations2 = self.rot_path.stack_rotations()
        self.expansion1.rotate(rotations1)
        self.expansion2.rotate(rotations2)
        return self.evaluate_energy()

    def section_energies(self):
        """Return a list with the energy of each path section, evaluated separately."""
        energy_list = []
        for rot1, rot2 in zip(self.rot_path.rotations1, self.rot_path.rotations2):
            self.expansion1.rotate(rot1)
            self.expansion2.rotate(rot2)
            energy_list.append(self.evaluate_energy())
        return energy_list

    def normalization(self, norm_euler_angles: dict | None):
        """Absolute energy at a reference orientation, used to normalize path energies.

        Args:
            norm_euler_angles: Euler angles of the reference orientation, read
                from keys ``'alpha1'``, ``'beta1'``, ``'gamma1'``, ``'alpha2'``,
                ``'beta2'``, ``'gamma2'`` (missing keys default to 0).  ``None``
                disables normalization.

        Returns:
            ``np.array([1.])`` when ``norm_euler_angles`` is None, otherwise
            ``|evaluate_energy()|`` after calling ``rotate_euler`` on both
            expansions (this changes the expansions' state).
        """
        if norm_euler_angles is None:
            return np.array([1.])
        self.expansion1.rotate_euler(alpha=norm_euler_angles.get('alpha1', 0),
                                     beta=norm_euler_angles.get('beta1', 0),
                                     gamma=norm_euler_angles.get('gamma1', 0))
        self.expansion2.rotate_euler(alpha=norm_euler_angles.get('alpha2', 0),
                                     beta=norm_euler_angles.get('beta2', 0),
                                     gamma=norm_euler_angles.get('gamma2', 0))
        return np.abs(self.evaluate_energy())

    def evaluate_path(self, norm_euler_angles: dict | None = None):
        """Energy along the full path divided by the normalization.

        The path energy is computed BEFORE the normalization because the latter
        re-rotates the expansions.  The result is reshaped to 2-D keeping the
        leading axis (presumably one entry per path point — confirm), then
        squeezed.
        """
        energy = self.path_energy()
        normalization = self.normalization(norm_euler_angles)
        # Leading singleton axis so the normalization broadcasts over the path axis.
        energy = energy / normalization[None, ...]
        energy = energy.reshape(energy.shape[0], -1)
        return np.squeeze(energy)

    def plot(self, labels: list[str] | None = None, norm_euler_angles: dict | None = None,
             save_as: Path | None = None):
        """Plot the (optionally normalized) energy along the whole path.

        Args:
            labels: Legend labels for the plotted line(s); the legend is only
                drawn when labels are given.
            norm_euler_angles: See :meth:`normalization`.
            save_as: If given, the figure is saved there at 600 dpi before showing.
        """
        energy = self.evaluate_path(norm_euler_angles=norm_euler_angles)
        x_axis = self.rot_path.stack_x_axes()
        fig, ax = plt.subplots()
        ax.axhline(y=0, c='k', linestyle=':')
        # evaluate_path already returns a squeezed array.
        ax.plot(x_axis, energy, label=labels)
        if labels is not None:
            # Without labels matplotlib would only warn about an empty legend.
            ax.legend(fontsize=12)
        self._style_axes(ax)
        self._finish_figure(fig, save_as)

    def plot_sections(self, norm_euler_angles: dict | None = None, save_as: Path | None = None):
        """Plot each path section as its own line over its own x-axis.

        Args:
            norm_euler_angles: See :meth:`normalization`.
            save_as: If given, the figure is saved there at 600 dpi before showing.
        """
        energy_list = self.section_energies()
        normalization = self.normalization(norm_euler_angles)
        fig, ax = plt.subplots()
        ax.axhline(y=0, c='k', linestyle=':')
        for x_axis, energy in zip(self.rot_path.x_axis, energy_list):
            energy, norm = np.broadcast_arrays(energy, normalization)
            ax.plot(x_axis, energy / norm)
        self._style_axes(ax)
        self._finish_figure(fig, save_as)

    @staticmethod
    def _style_axes(ax):
        """Apply the tick and axis-label styling shared by the plotting methods."""
        ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
        ax.set_xlabel('angle', fontsize=13)
        ax.set_ylabel('U', fontsize=13)

    @staticmethod
    def _finish_figure(fig, save_as):
        """Tighten the layout, optionally save at 600 dpi, then show the figure."""
        fig.tight_layout()
        if save_as is not None:
            fig.savefig(save_as, dpi=600)
        plt.show()
@dataclass
class PathExpansionComparison:
    """Path comparison for different charge distribution models."""
    ex_list: list[Expansion]
    rot_path: PairRotationalPath
    # Forwarded to PathEnergyPlot, which also accepts an array of distances.
    dist: float | Array
    params: ModelParams
    path_list: list[PathEnergyPlot] = field(default_factory=list)

    def __post_init__(self):
        # Each model is paired with a clone of itself along the shared rotational path.
        for ex in self.ex_list:
            self.path_list.append(PathEnergyPlot(ex, ex.clone(), self.rot_path, self.dist, self.params))

    def evaluate_path(self, norm_euler_angles: dict | None = None):
        """Stack every model's path energy; axis 0 indexes the models in ``ex_list``.

        Args:
            norm_euler_angles: Forwarded to ``PathEnergyPlot.evaluate_path``.
        """
        return np.stack([path.evaluate_path(norm_euler_angles) for path in self.path_list])

    def plot(self, labels: list[str] | None = None, norm_euler_angles: dict | None = None,
             save_as: Path | None = None):
        """Plot the path energy of every model on shared axes.

        Args:
            labels: One legend label per model; the legend is only drawn when
                labels are given.
            norm_euler_angles: Forwarded to :meth:`evaluate_path`.
            save_as: If given, the figure is saved there at 600 dpi before showing.
        """
        energy = self.evaluate_path(norm_euler_angles=norm_euler_angles)
        x_axis = self.rot_path.stack_x_axes()
        fig, ax = plt.subplots()
        ax.axhline(y=0, c='k', linestyle=':')
        # Rows are models; transposing makes each model a column, i.e. one line
        # per model in the common 2-D (models, path points) case.
        ax.plot(x_axis, np.squeeze(energy).T, label=labels)
        if labels is not None:
            # Without labels matplotlib would only warn about an empty legend.
            ax.legend(fontsize=12)
        ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
        ax.set_xlabel('angle', fontsize=13)
        ax.set_ylabel('U', fontsize=13)
        fig.tight_layout()
        if save_as is not None:
            fig.savefig(save_as, dpi=600)
        plt.show()
|