rotational_path.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. import numpy as np
  2. from dataclasses import dataclass, field
  3. import quaternionic
  4. import expansion
  5. from parameters import ModelParams
  6. import interactions
  7. import matplotlib.pyplot as plt
  8. from pathlib import Path
  9. Quaternion = quaternionic.array
  10. Array = np.ndarray
  11. Expansion = expansion.Expansion
  12. @dataclass
  13. class PairRotationalPath:
  14. rotations1: list[Quaternion] = field(default_factory=list)
  15. rotations2: list[Quaternion] = field(default_factory=list)
  16. x_axis: list[Array] = field(default_factory=list)
  17. overlapping_last: bool = True
  18. _default_x_axis: Array = None
  19. def add(self, rotation1: Quaternion, rotation2: Quaternion, x_axis: Array = None):
  20. rotation1, rotation2 = np.broadcast_arrays(rotation1, rotation2)
  21. self.rotations1.append(Quaternion(rotation1))
  22. self.rotations2.append(Quaternion(rotation2))
  23. if x_axis is None:
  24. x_axis = np.arange(len(rotation1)) if self._default_x_axis is None else self._default_x_axis
  25. self.add_x_axis(x_axis)
  26. def add_x_axis(self, x_axis: Array):
  27. try:
  28. last_x_val = self.x_axis[-1][-1]
  29. except IndexError:
  30. last_x_val = 0
  31. if self.overlapping_last:
  32. self.x_axis.append(x_axis + last_x_val)
  33. else:
  34. raise NotImplementedError('Currently only overlapping end points for x-axes are supported.')
  35. def set_default_x_axis(self, default_x_axis: Array):
  36. self._default_x_axis = default_x_axis
  37. def add_euler(self, *, alpha1: Array = 0, beta1: Array = 0, gamma1: Array = 0,
  38. alpha2: Array = 0, beta2: Array = 0, gamma2: Array = 0,
  39. x_axis: Array = None):
  40. R1_euler = quaternionic.array.from_euler_angles(alpha1, beta1, gamma1)
  41. R2_euler = quaternionic.array.from_euler_angles(alpha2, beta2, gamma2)
  42. self.add(Quaternion(R1_euler), Quaternion(R2_euler), x_axis)
  43. def stack_rotations(self) -> (Quaternion, Quaternion):
  44. return Quaternion(np.vstack(self.rotations1)), Quaternion(np.vstack(self.rotations2))
  45. def stack_x_axes(self) -> Array:
  46. return np.hstack(self.x_axis)
  47. @dataclass
  48. class PathEnergyPlot:
  49. """Path comparison for a pair of charge distributions, possibly at different parameter values."""
  50. expansion1: Expansion
  51. expansion2: Expansion
  52. rot_path: PairRotationalPath
  53. dist: float | Array
  54. params: ModelParams
  55. def __post_init__(self):
  56. if not isinstance(self.dist, Array):
  57. self.dist = np.array([self.dist])
  58. def evaluate_energy(self):
  59. energy = []
  60. for dist in self.dist:
  61. for params in self.params.unravel():
  62. energy.append(interactions.charged_shell_energy(self.expansion1, self.expansion2, dist, params))
  63. return np.squeeze(np.stack(energy, axis=-1))
  64. def path_energy(self):
  65. rotations1, rotations2 = self.rot_path.stack_rotations()
  66. self.expansion1.rotate(rotations1)
  67. self.expansion2.rotate(rotations2)
  68. return self.evaluate_energy()
  69. def section_energies(self):
  70. energy_list = []
  71. for rot1, rot2 in zip(self.rot_path.rotations1, self.rot_path.rotations2):
  72. self.expansion1.rotate(rot1)
  73. self.expansion2.rotate(rot2)
  74. energy_list.append(self.evaluate_energy())
  75. return energy_list
  76. def normalization(self, norm_euler_angles: dict):
  77. if norm_euler_angles is None:
  78. return np.array([1.])
  79. self.expansion1.rotate_euler(alpha=norm_euler_angles.get('alpha1', 0),
  80. beta=norm_euler_angles.get('beta1', 0),
  81. gamma=norm_euler_angles.get('gamma1', 0))
  82. self.expansion2.rotate_euler(alpha=norm_euler_angles.get('alpha2', 0),
  83. beta=norm_euler_angles.get('beta2', 0),
  84. gamma=norm_euler_angles.get('gamma2', 0))
  85. return np.abs(self.evaluate_energy())
  86. def evaluate_path(self, norm_euler_angles: dict = None):
  87. energy = self.path_energy()
  88. normalization = self.normalization(norm_euler_angles)
  89. energy = energy / normalization[None, ...]
  90. energy = energy.reshape(energy.shape[0], -1)
  91. return np.squeeze(energy)
  92. def plot(self, labels: list[str] = None, norm_euler_angles: dict = None, save_as: Path = None):
  93. energy = self.evaluate_path(norm_euler_angles=norm_euler_angles)
  94. x_axis = self.rot_path.stack_x_axes()
  95. fig, ax = plt.subplots()
  96. ax.axhline(y=0, c='k', linestyle=':')
  97. ax.plot(x_axis, np.squeeze(energy), label=labels)
  98. ax.legend(fontsize=12)
  99. ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
  100. ax.set_xlabel('angle', fontsize=13)
  101. ax.set_ylabel('U', fontsize=13)
  102. plt.tight_layout()
  103. if save_as is not None:
  104. plt.savefig(save_as, dpi=600)
  105. plt.show()
  106. def plot_sections(self, norm_euler_angles: dict = None, save_as: Path = None):
  107. energy_list = self.section_energies()
  108. normalization = self.normalization(norm_euler_angles)
  109. fig, ax = plt.subplots()
  110. ax.axhline(y=0, c='k', linestyle=':')
  111. for x_axis, energy in zip(self.rot_path.x_axis, energy_list):
  112. energy, norm = np.broadcast_arrays(energy, normalization)
  113. ax.plot(x_axis, energy / norm)
  114. # ax.legend(fontsize=12)
  115. ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
  116. ax.set_xlabel('angle', fontsize=13)
  117. ax.set_ylabel('U', fontsize=13)
  118. if save_as is not None:
  119. plt.savefig(save_as, dpi=600)
  120. plt.show()
  121. @dataclass
  122. class PathExpansionComparison:
  123. """Path comparison for different charge distribution models."""
  124. ex_list: list[Expansion]
  125. rot_path: PairRotationalPath
  126. dist: float
  127. params: ModelParams
  128. path_list: list[PathEnergyPlot] = field(default_factory=list)
  129. def __post_init__(self):
  130. for ex in self.ex_list:
  131. self.path_list.append(PathEnergyPlot(ex, ex.clone(), self.rot_path, self.dist, self.params))
  132. def evaluate_path(self, norm_euler_angles: dict = None):
  133. path_vals = []
  134. for path in self.path_list:
  135. path_vals.append(path.evaluate_path(norm_euler_angles))
  136. return np.stack(path_vals)
  137. def plot(self, labels: list[str] = None, norm_euler_angles: dict = None, save_as: Path = None):
  138. energy = self.evaluate_path(norm_euler_angles=norm_euler_angles)
  139. x_axis = self.rot_path.stack_x_axes()
  140. fig, ax = plt.subplots()
  141. ax.axhline(y=0, c='k', linestyle=':')
  142. ax.plot(x_axis, np.squeeze(energy).T, label=labels)
  143. ax.legend(fontsize=12)
  144. ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
  145. ax.set_xlabel('angle', fontsize=13)
  146. ax.set_ylabel('U', fontsize=13)
  147. plt.tight_layout()
  148. if save_as is not None:
  149. plt.savefig(save_as, dpi=600)
  150. plt.show()