rotational_path.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import numpy as np
  2. from dataclasses import dataclass, field
  3. import quaternionic
  4. from charged_shells import expansion, interactions, mapping
  5. from charged_shells.parameters import ModelParams
  6. import matplotlib.pyplot as plt
  7. from pathlib import Path
  8. from functools import partial
  9. Quaternion = quaternionic.array
  10. Array = np.ndarray
  11. Expansion = expansion.Expansion
  12. @dataclass
  13. class PairRotationalPath:
  14. rotations1: list[Quaternion] = field(default_factory=list)
  15. rotations2: list[Quaternion] = field(default_factory=list)
  16. x_axis: list[Array] = field(default_factory=list)
  17. ticks: dict[float, str] = field(default_factory=dict)
  18. overlapping_last: bool = True
  19. _default_x_axis: Array = None
  20. def add(self, rotation1: Quaternion, rotation2: Quaternion, x_axis: Array = None,
  21. start_name: str | float = None, end_name: str | float = None):
  22. rotation1, rotation2 = np.broadcast_arrays(rotation1, rotation2)
  23. self.rotations1.append(Quaternion(rotation1))
  24. self.rotations2.append(Quaternion(rotation2))
  25. if x_axis is None:
  26. x_axis = np.arange(len(rotation1)) if self._default_x_axis is None else self._default_x_axis
  27. self.add_x_axis(x_axis, start_name, end_name)
  28. def add_x_axis(self, x_axis: Array, start_name: str | float = None, end_name: str | float = None):
  29. try:
  30. last_x_val = self.x_axis[-1][-1]
  31. except IndexError:
  32. last_x_val = 0
  33. if self.overlapping_last:
  34. self.x_axis.append(x_axis + last_x_val)
  35. else:
  36. raise NotImplementedError('Currently only overlapping end points for x-axes are supported.')
  37. # adding ticks to x_axis
  38. s_n = x_axis[0] + last_x_val if start_name is None else start_name # defaults to just numbers
  39. e_n = x_axis[-1] + last_x_val if end_name is None else end_name
  40. start_position = float(self.x_axis[-1][0])
  41. end_position = float(self.x_axis[-1][-1])
  42. if start_position not in self.ticks or start_name is not None: # allows overwriting previous end with new start
  43. self.ticks[start_position] = s_n
  44. self.ticks[end_position] = e_n
  45. def set_default_x_axis(self, default_x_axis: Array):
  46. self._default_x_axis = default_x_axis
  47. def add_euler(self, *, alpha1: Array = 0, beta1: Array = 0, gamma1: Array = 0,
  48. alpha2: Array = 0, beta2: Array = 0, gamma2: Array = 0, **add_kwargs):
  49. R1_euler = quaternionic.array.from_euler_angles(alpha1, beta1, gamma1)
  50. R2_euler = quaternionic.array.from_euler_angles(alpha2, beta2, gamma2)
  51. self.add(Quaternion(R1_euler), Quaternion(R2_euler), **add_kwargs)
  52. def stack_rotations(self) -> (Quaternion, Quaternion):
  53. return Quaternion(np.vstack(self.rotations1)), Quaternion(np.vstack(self.rotations2))
  54. def stack_x_axes(self) -> Array:
  55. return np.hstack(self.x_axis)
  56. def plot_style(self, fig, ax, energy_units: interactions.EnergyUnit = 'kT'):
  57. ax.axhline(y=0, c='k', linestyle=':')
  58. ax.legend(fontsize=18)
  59. ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=15)
  60. # ax.set_xlabel('angle', fontsize=15)
  61. ax.set_ylabel(f'U [{energy_units}]', fontsize=20)
  62. ax.set_xticks(list(self.ticks.keys()), list(self.ticks.values()), fontsize=18)
  63. fig.set_size_inches(plt.figaspect(0.5))
  64. plt.tight_layout()
  65. @dataclass
  66. class PathEnergyPlot:
  67. """Path comparison for a pair of charge distributions, possibly at different parameter values."""
  68. expansion1: Expansion
  69. expansion2: Expansion
  70. rot_path: PairRotationalPath
  71. dist: float | Array
  72. params: ModelParams
  73. match_expansion_axis_to_params: int | None = None
  74. units: interactions.EnergyUnit = 'kT'
  75. def __post_init__(self):
  76. if not isinstance(self.dist, Array):
  77. self.dist = np.array([self.dist])
  78. # we add 1 to match_expansion_axis as rotations take the new leading axis
  79. if self.match_expansion_axis_to_params is not None:
  80. self.match_expansion_axis_to_params += 1
  81. def plot_style(self, fig, ax, energy_units: interactions.EnergyUnit = 'kT'):
  82. self.rot_path.plot_style(fig, ax, energy_units)
  83. def evaluate_energy(self):
  84. energy = []
  85. for dist in self.dist:
  86. energy_fn = mapping.parameter_map_two_expansions(partial(interactions.charged_shell_energy,
  87. dist=dist, units=self.units),
  88. self.match_expansion_axis_to_params)
  89. energy.append(energy_fn(self.expansion1, self.expansion2, self.params))
  90. return np.squeeze(np.stack(energy, axis=-1))
  91. def path_energy(self):
  92. rotations1, rotations2 = self.rot_path.stack_rotations()
  93. self.expansion1.rotate(rotations1)
  94. self.expansion2.rotate(rotations2)
  95. return self.evaluate_energy()
  96. def section_energies(self):
  97. energy_list = []
  98. for rot1, rot2 in zip(self.rot_path.rotations1, self.rot_path.rotations2):
  99. self.expansion1.rotate(rot1)
  100. self.expansion2.rotate(rot2)
  101. energy_list.append(self.evaluate_energy())
  102. return energy_list
  103. def normalization(self, norm_euler_angles: dict):
  104. if norm_euler_angles is None:
  105. return np.array([1.])
  106. self.expansion1.rotate_euler(alpha=norm_euler_angles.get('alpha1', 0),
  107. beta=norm_euler_angles.get('beta1', 0),
  108. gamma=norm_euler_angles.get('gamma1', 0))
  109. self.expansion2.rotate_euler(alpha=norm_euler_angles.get('alpha2', 0),
  110. beta=norm_euler_angles.get('beta2', 0),
  111. gamma=norm_euler_angles.get('gamma2', 0))
  112. return np.abs(self.evaluate_energy())
  113. def evaluate_path(self, norm_euler_angles: dict = None):
  114. energy = self.path_energy()
  115. normalization = self.normalization(norm_euler_angles)
  116. energy = energy / normalization[None, ...]
  117. energy = energy.reshape(energy.shape[0], -1)
  118. return np.squeeze(energy)
  119. def plot(self, labels: list[str] = None, norm_euler_angles: dict = None, save_as: Path = None):
  120. energy = self.evaluate_path(norm_euler_angles=norm_euler_angles)
  121. x_axis = self.rot_path.stack_x_axes()
  122. fig, ax = plt.subplots(figsize=plt.figaspect(0.5))
  123. ax.plot(x_axis, np.squeeze(energy), label=labels)
  124. self.plot_style(fig, ax, energy_units=self.units)
  125. if save_as is not None:
  126. plt.savefig(save_as, dpi=600)
  127. plt.show()
  128. def plot_sections(self, norm_euler_angles: dict = None, save_as: Path = None):
  129. energy_list = self.section_energies()
  130. normalization = self.normalization(norm_euler_angles)
  131. fig, ax = plt.subplots()
  132. for x_axis, energy in zip(self.rot_path.x_axis, energy_list):
  133. energy, norm = np.broadcast_arrays(energy, normalization)
  134. ax.plot(x_axis, energy / norm)
  135. self.plot_style(fig, ax, energy_units=self.units)
  136. if save_as is not None:
  137. plt.savefig(save_as, dpi=600)
  138. plt.show()
  139. @dataclass
  140. class PathExpansionComparison:
  141. """Path comparison for different charge distribution models."""
  142. ex_list: list[Expansion]
  143. rot_path: PairRotationalPath
  144. dist: float
  145. params: ModelParams
  146. path_list: list[PathEnergyPlot] = field(default_factory=list)
  147. def __post_init__(self):
  148. for ex in self.ex_list:
  149. self.path_list.append(PathEnergyPlot(ex, ex.clone(), self.rot_path, self.dist, self.params))
  150. def evaluate_path(self, norm_euler_angles: dict = None):
  151. path_vals = []
  152. for path in self.path_list:
  153. path_vals.append(path.evaluate_path(norm_euler_angles))
  154. return np.squeeze(np.stack(path_vals))
  155. def plot(self, labels: list[str] = None, norm_euler_angles: dict = None, save_as: Path = None):
  156. energy = self.evaluate_path(norm_euler_angles=norm_euler_angles)
  157. x_axis = self.rot_path.stack_x_axes()
  158. fig, ax = plt.subplots()
  159. ax.axhline(y=0, c='k', linestyle=':')
  160. ax.plot(x_axis, energy, label=labels)
  161. ax.legend(fontsize=12)
  162. ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
  163. ax.set_xlabel('angle', fontsize=13)
  164. ax.set_ylabel('U', fontsize=13)
  165. plt.tight_layout()
  166. if save_as is not None:
  167. plt.savefig(save_as, dpi=600)
  168. plt.show()