# rotational_path.py (5.2 KB)
  1. import numpy as np
  2. from dataclasses import dataclass, field
  3. import quaternionic
  4. import expansion
  5. from parameters import ModelParams
  6. import interactions
  7. import matplotlib.pyplot as plt
  8. from pathlib import Path
  9. Quaternion = quaternionic.array
  10. Array = np.ndarray
  11. Expansion = expansion.Expansion
  12. @dataclass
  13. class PairRotationalPath:
  14. rotations1: list[Quaternion] = field(default_factory=list)
  15. rotations2: list[Quaternion] = field(default_factory=list)
  16. x_axis: list[Array] = field(default_factory=list)
  17. overlapping_last: bool = True
  18. _default_x_axis: Array = None
  19. def add(self, rotation1: Quaternion, rotation2: Quaternion, x_axis: Array = None):
  20. rotation1, rotation2 = np.broadcast_arrays(rotation1, rotation2)
  21. self.rotations1.append(Quaternion(rotation1))
  22. self.rotations2.append(Quaternion(rotation2))
  23. if x_axis is None:
  24. x_axis = np.arange(len(rotation1)) if self._default_x_axis is None else self._default_x_axis
  25. self.add_x_axis(x_axis)
  26. def add_x_axis(self, x_axis: Array):
  27. try:
  28. last_x_val = self.x_axis[-1][-1]
  29. except IndexError:
  30. last_x_val = 0
  31. if self.overlapping_last:
  32. self.x_axis.append(x_axis + last_x_val)
  33. else:
  34. raise NotImplementedError('Currently only overlapping end points for x-axes are supported.')
  35. def set_default_x_axis(self, default_x_axis: Array):
  36. self._default_x_axis = default_x_axis
  37. def add_euler(self, *, alpha1: Array = 0, beta1: Array = 0, gamma1: Array = 0,
  38. alpha2: Array = 0, beta2: Array = 0, gamma2: Array = 0,
  39. x_axis: Array = None):
  40. R1_euler = quaternionic.array.from_euler_angles(alpha1, beta1, gamma1)
  41. R2_euler = quaternionic.array.from_euler_angles(alpha2, beta2, gamma2)
  42. self.add(Quaternion(R1_euler), Quaternion(R2_euler), x_axis)
  43. def stack_rotations(self) -> (Quaternion, Quaternion):
  44. return Quaternion(np.vstack(self.rotations1)), Quaternion(np.vstack(self.rotations2))
  45. def stack_x_axes(self) -> Array:
  46. return np.hstack(self.x_axis)
  47. @dataclass
  48. class PathEnergyPlot:
  49. expansion1: Expansion
  50. expansion2: Expansion
  51. rot_path: PairRotationalPath
  52. dist: float | Array
  53. params: ModelParams
  54. def __post_init__(self):
  55. if not isinstance(self.dist, Array):
  56. self.dist = np.array([self.dist])
  57. def evaluate_energy(self):
  58. energy = []
  59. for dist in self.dist:
  60. for params in self.params.unravel():
  61. energy.append(interactions.charged_shell_energy(self.expansion1, self.expansion2, dist, params))
  62. return np.squeeze(np.stack(energy, axis=-1))
  63. def path_energy(self):
  64. rotations1, rotations2 = self.rot_path.stack_rotations()
  65. self.expansion1.rotate(rotations1)
  66. self.expansion2.rotate(rotations2)
  67. return self.evaluate_energy()
  68. def section_energies(self):
  69. energy_list = []
  70. for rot1, rot2 in zip(self.rot_path.rotations1, self.rot_path.rotations2):
  71. self.expansion1.rotate(rot1)
  72. self.expansion2.rotate(rot2)
  73. energy_list.append(self.evaluate_energy())
  74. return energy_list
  75. def normalization(self, norm_euler_angles: dict):
  76. if norm_euler_angles is None:
  77. return np.array([1.])
  78. self.expansion1.rotate_euler(alpha=norm_euler_angles.get('alpha1', 0),
  79. beta=norm_euler_angles.get('beta1', 0),
  80. gamma=norm_euler_angles.get('gamma1', 0))
  81. self.expansion2.rotate_euler(alpha=norm_euler_angles.get('alpha2', 0),
  82. beta=norm_euler_angles.get('beta2', 0),
  83. gamma=norm_euler_angles.get('gamma2', 0))
  84. return np.abs(self.evaluate_energy())
  85. def plot(self, labels: list[str] = None, norm_euler_angles: dict = None):
  86. energy = self.path_energy()
  87. normalization = self.normalization(norm_euler_angles)
  88. energy = energy / normalization[None, ...]
  89. energy = energy.reshape(energy.shape[0], -1)
  90. x_axis = self.rot_path.stack_x_axes()
  91. fig, ax = plt.subplots()
  92. ax.axhline(y=0, c='k', linestyle=':')
  93. ax.plot(x_axis, energy, label=labels)
  94. ax.legend(fontsize=12)
  95. ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
  96. ax.set_xlabel('angle', fontsize=13)
  97. ax.set_ylabel('U', fontsize=13)
  98. plt.show()
  99. def plot_sections(self, norm_euler_angles: dict = None, save_as: Path = None):
  100. energy_list = self.section_energies()
  101. normalization = self.normalization(norm_euler_angles)
  102. fig, ax = plt.subplots()
  103. ax.axhline(y=0, c='k', linestyle=':')
  104. for x_axis, energy in zip(self.rot_path.x_axis, energy_list):
  105. energy, norm = np.broadcast_arrays(energy, normalization)
  106. ax.plot(x_axis, energy / norm)
  107. # ax.legend(fontsize=12)
  108. ax.tick_params(which='both', direction='in', top=True, right=True, labelsize=12)
  109. ax.set_xlabel('angle', fontsize=13)
  110. ax.set_ylabel('U', fontsize=13)
  111. if save_as is not None:
  112. plt.savefig(save_as, dpi=600)
  113. plt.show()