|
@@ -19,18 +19,20 @@ class PairRotationalPath:
|
|
|
rotations1: list[Quaternion] = field(default_factory=list)
|
|
|
rotations2: list[Quaternion] = field(default_factory=list)
|
|
|
x_axis: list[Array] = field(default_factory=list)
|
|
|
+ ticks: dict[float, str] = field(default_factory=dict)
|
|
|
overlapping_last: bool = True
|
|
|
_default_x_axis: Array = None
|
|
|
|
|
|
- def add(self, rotation1: Quaternion, rotation2: Quaternion, x_axis: Array = None):
|
|
|
+ def add(self, rotation1: Quaternion, rotation2: Quaternion, x_axis: Array = None,
|
|
|
+ start_name: str | float = None, end_name: str | float = None):
|
|
|
rotation1, rotation2 = np.broadcast_arrays(rotation1, rotation2)
|
|
|
self.rotations1.append(Quaternion(rotation1))
|
|
|
self.rotations2.append(Quaternion(rotation2))
|
|
|
if x_axis is None:
|
|
|
x_axis = np.arange(len(rotation1)) if self._default_x_axis is None else self._default_x_axis
|
|
|
- self.add_x_axis(x_axis)
|
|
|
+ self.add_x_axis(x_axis, start_name, end_name)
|
|
|
|
|
|
- def add_x_axis(self, x_axis: Array):
|
|
|
+ def add_x_axis(self, x_axis: Array, start_name: str | float = None, end_name: str | float = None):
|
|
|
try:
|
|
|
last_x_val = self.x_axis[-1][-1]
|
|
|
except IndexError:
|
|
@@ -39,16 +41,23 @@ class PairRotationalPath:
|
|
|
self.x_axis.append(x_axis + last_x_val)
|
|
|
else:
|
|
|
raise NotImplementedError('Currently only overlapping end points for x-axes are supported.')
|
|
|
+ # adding ticks to x_axis
|
|
|
+ s_n = x_axis[0] + last_x_val if start_name is None else start_name # defaults to just numbers
|
|
|
+ e_n = x_axis[-1] + last_x_val if end_name is None else end_name
|
|
|
+ start_position = float(self.x_axis[-1][0])
|
|
|
+ end_position = float(self.x_axis[-1][-1])
|
|
|
+ if start_position not in self.ticks or start_name is not None: # allows overwriting previous end with new start
|
|
|
+ self.ticks[start_position] = s_n
|
|
|
+ self.ticks[end_position] = e_n
|
|
|
|
|
|
def set_default_x_axis(self, default_x_axis: Array):
|
|
|
self._default_x_axis = default_x_axis
|
|
|
|
|
|
def add_euler(self, *, alpha1: Array = 0, beta1: Array = 0, gamma1: Array = 0,
|
|
|
- alpha2: Array = 0, beta2: Array = 0, gamma2: Array = 0,
|
|
|
- x_axis: Array = None):
|
|
|
+ alpha2: Array = 0, beta2: Array = 0, gamma2: Array = 0, **add_kwargs):
|
|
|
R1_euler = quaternionic.array.from_euler_angles(alpha1, beta1, gamma1)
|
|
|
R2_euler = quaternionic.array.from_euler_angles(alpha2, beta2, gamma2)
|
|
|
- self.add(Quaternion(R1_euler), Quaternion(R2_euler), x_axis)
|
|
|
+ self.add(Quaternion(R1_euler), Quaternion(R2_euler), **add_kwargs)
|
|
|
|
|
|
def stack_rotations(self) -> (Quaternion, Quaternion):
|
|
|
return Quaternion(np.vstack(self.rotations1)), Quaternion(np.vstack(self.rotations2))
|
|
@@ -56,6 +65,16 @@ class PairRotationalPath:
|
|
|
def stack_x_axes(self) -> Array:
|
|
|
return np.hstack(self.x_axis)
|
|
|
|
|
|
+ def plot_style(self, fig, ax, energy_units: interactions.EnergyUnit = 'kT'):
|
|
|
+ ax.axhline(y=0, c='k', linestyle=':')
|
|
|
+ ax.legend(fontsize=18)
|
|
|
+ ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=15)
|
|
|
+ # ax.set_xlabel('angle', fontsize=15)
|
|
|
+ ax.set_ylabel(f'U [{energy_units}]', fontsize=20)
|
|
|
+ ax.set_xticks(list(self.ticks.keys()), list(self.ticks.values()), fontsize=18)
|
|
|
+ fig.set_size_inches(plt.figaspect(0.5))
|
|
|
+ plt.tight_layout()
|
|
|
+
|
|
|
|
|
|
@dataclass
|
|
|
class PathEnergyPlot:
|
|
@@ -76,15 +95,8 @@ class PathEnergyPlot:
|
|
|
if self.match_expansion_axis_to_params is not None:
|
|
|
self.match_expansion_axis_to_params += 1
|
|
|
|
|
|
- @staticmethod
|
|
|
- def plot_style(fig, ax, energy_units: interactions.EnergyUnit = 'kT'):
|
|
|
- ax.axhline(y=0, c='k', linestyle=':')
|
|
|
- ax.legend(fontsize=15)
|
|
|
- ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
|
|
|
- ax.set_xlabel('angle', fontsize=15)
|
|
|
- ax.set_ylabel(f'U [{energy_units}]', fontsize=15)
|
|
|
- fig.set_size_inches(plt.figaspect(0.5))
|
|
|
- plt.tight_layout()
|
|
|
+ def plot_style(self, fig, ax, energy_units: interactions.EnergyUnit = 'kT'):
|
|
|
+ self.rot_path.plot_style(fig, ax, energy_units)
|
|
|
|
|
|
def evaluate_energy(self):
|
|
|
energy = []
|