Просмотр исходного кода

Path plot style now part of PairRotationalPath

gnidovec 1 год назад
Родитель
Сommit
9cb152786e
1 измененных файлов с 27 добавлено и 15 удалено
  1. 27 15
      charged_shells/rotational_path.py

+ 27 - 15
charged_shells/rotational_path.py

@@ -19,18 +19,20 @@ class PairRotationalPath:
     rotations1: list[Quaternion] = field(default_factory=list)
     rotations2: list[Quaternion] = field(default_factory=list)
     x_axis: list[Array] = field(default_factory=list)
+    ticks: dict[float, str] = field(default_factory=dict)
     overlapping_last: bool = True
     _default_x_axis: Array = None
 
-    def add(self, rotation1: Quaternion, rotation2: Quaternion, x_axis: Array = None):
+    def add(self, rotation1: Quaternion, rotation2: Quaternion, x_axis: Array = None,
+            start_name: str | float = None, end_name: str | float = None):
         rotation1, rotation2 = np.broadcast_arrays(rotation1, rotation2)
         self.rotations1.append(Quaternion(rotation1))
         self.rotations2.append(Quaternion(rotation2))
         if x_axis is None:
             x_axis = np.arange(len(rotation1)) if self._default_x_axis is None else self._default_x_axis
-        self.add_x_axis(x_axis)
+        self.add_x_axis(x_axis, start_name, end_name)
 
-    def add_x_axis(self, x_axis: Array):
+    def add_x_axis(self, x_axis: Array, start_name: str | float = None, end_name: str | float = None):
         try:
             last_x_val = self.x_axis[-1][-1]
         except IndexError:
@@ -39,16 +41,23 @@ class PairRotationalPath:
             self.x_axis.append(x_axis + last_x_val)
         else:
             raise NotImplementedError('Currently only overlapping end points for x-axes are supported.')
+        # adding ticks to x_axis
+        s_n = x_axis[0] + last_x_val if start_name is None else start_name  # defaults to just numbers
+        e_n = x_axis[-1] + last_x_val if end_name is None else end_name
+        start_position = float(self.x_axis[-1][0])
+        end_position = float(self.x_axis[-1][-1])
+        if start_position not in self.ticks or start_name is not None:  # allows overwriting previous end with new start
+            self.ticks[start_position] = s_n
+        self.ticks[end_position] = e_n
 
     def set_default_x_axis(self, default_x_axis: Array):
         self._default_x_axis = default_x_axis
 
     def add_euler(self, *, alpha1: Array = 0, beta1: Array = 0, gamma1: Array = 0,
-                  alpha2: Array = 0, beta2: Array = 0, gamma2: Array = 0,
-                  x_axis: Array = None):
+                  alpha2: Array = 0, beta2: Array = 0, gamma2: Array = 0, **add_kwargs):
         R1_euler = quaternionic.array.from_euler_angles(alpha1, beta1, gamma1)
         R2_euler = quaternionic.array.from_euler_angles(alpha2, beta2, gamma2)
-        self.add(Quaternion(R1_euler), Quaternion(R2_euler), x_axis)
+        self.add(Quaternion(R1_euler), Quaternion(R2_euler), **add_kwargs)
 
     def stack_rotations(self) -> (Quaternion, Quaternion):
         return Quaternion(np.vstack(self.rotations1)), Quaternion(np.vstack(self.rotations2))
@@ -56,6 +65,16 @@ class PairRotationalPath:
     def stack_x_axes(self) -> Array:
         return np.hstack(self.x_axis)
 
+    def plot_style(self, fig, ax, energy_units: interactions.EnergyUnit = 'kT'):
+        ax.axhline(y=0, c='k', linestyle=':')
+        ax.legend(fontsize=18)
+        ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=15)
+        # ax.set_xlabel('angle', fontsize=15)
+        ax.set_ylabel(f'U [{energy_units}]', fontsize=20)
+        ax.set_xticks(list(self.ticks.keys()), list(self.ticks.values()), fontsize=18)
+        fig.set_size_inches(plt.figaspect(0.5))
+        plt.tight_layout()
+
 
 @dataclass
 class PathEnergyPlot:
@@ -76,15 +95,8 @@ class PathEnergyPlot:
         if self.match_expansion_axis_to_params is not None:
             self.match_expansion_axis_to_params += 1
 
-    @staticmethod
-    def plot_style(fig, ax, energy_units: interactions.EnergyUnit = 'kT'):
-        ax.axhline(y=0, c='k', linestyle=':')
-        ax.legend(fontsize=15)
-        ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
-        ax.set_xlabel('angle', fontsize=15)
-        ax.set_ylabel(f'U [{energy_units}]', fontsize=15)
-        fig.set_size_inches(plt.figaspect(0.5))
-        plt.tight_layout()
+    def plot_style(self, fig, ax, energy_units: interactions.EnergyUnit = 'kT'):
+        self.rot_path.plot_style(fig, ax, energy_units)
 
     def evaluate_energy(self):
         energy = []