energy.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. from __future__ import annotations
  2. from typing import Callable
  3. import jax.numpy as jnp
  4. from curvature_assembly import (
  5. oriented_particle,
  6. data_protocols,
  7. patchy_interaction,
  8. multipole_interaction,
  9. )
  10. from jax_md import energy as jaxmd_energy
  11. from curvature_assembly.smap import oriented_pair
  12. from jax_md import partition, space, dataclasses
  13. f32 = jnp.float32
  14. f64 = jnp.float64
  15. Array = jnp.ndarray
  16. DisplacementFn = space.DisplacementFn
  17. ContactFunction = Callable[..., Array]
  18. NeighborListFormat = partition.NeighborListFormat
  19. InteractionParams = data_protocols.InteractionParams
  20. def weeks_chandler_andersen(
  21. dr: Array, sigma: Array = 1.0, epsilon: Array = 1.0, **unused_kwargs
  22. ) -> Array:
  23. """Repulsive part of the Lennard-Jones potential."""
  24. return jnp.where(
  25. dr < jnp.power(2, 1 / 6) * sigma,
  26. jaxmd_energy.lennard_jones(dr, sigma=sigma, epsilon=epsilon) + epsilon,
  27. 0.0,
  28. )
  29. @dataclasses.dataclass
  30. class GbWcaParams:
  31. eigvals: Array = dataclasses.field(default_factory=lambda: jnp.ones((3,)))
  32. epsilon: Array = 5.0
  33. d0: Array = 10.0
  34. sigma: Array = 1.0
  35. alpha: Array = 1.0
  36. band_theta: Array = jnp.pi / 2
  37. band_sigma: Array = 0.5
  38. def gaussian_band_wca_ellipsoid_pair(
  39. displacement: DisplacementFn, contact_fn: ContactFunction, **cf_kwargs
  40. ) -> Callable[..., Array]:
  41. contact_function = oriented_particle.get_ellipsoid_contact_function_param(
  42. contact_fn, **cf_kwargs
  43. )
  44. def patchy_wca_ellipsoid(
  45. dr: Array,
  46. eigsys1: Array,
  47. eigsys2: Array,
  48. eigvals: Array = jnp.array([1.0, 1.0, 1.0]),
  49. epsilon: Array = 5.0,
  50. d0: Array = 10,
  51. alpha: Array = 1.0,
  52. sigma: Array = 1.0,
  53. band_theta: Array = jnp.pi / 2,
  54. band_sigma: Array = 0.5,
  55. ) -> Array:
  56. cf = contact_function(dr, eigsys1, eigsys2, eigvals=eigvals)
  57. wca_repulsion = weeks_chandler_andersen(cf, sigma=1.0, epsilon=epsilon)
  58. ellipsod_morse = jaxmd_energy.morse(cf, epsilon=d0, alpha=alpha, sigma=sigma)
  59. patchy_value = patchy_interaction.gaussian_interaction_band(
  60. dr, eigsys1, eigsys2, band_theta, band_sigma
  61. )
  62. # patchy_value = 0.
  63. return wca_repulsion + ellipsod_morse * patchy_value
  64. energy_fn = oriented_pair(patchy_wca_ellipsoid, displacement)
  65. return energy_fn
  66. def gaussian_band_fh_wca_ellipsoid_pair(
  67. displacement: DisplacementFn, contact_fn: ContactFunction, **cf_kwargs
  68. ) -> Callable[..., Array]:
  69. contact_function = oriented_particle.get_ellipsoid_contact_function_param(
  70. contact_fn, **cf_kwargs
  71. )
  72. def patchy_wca_ellipsoid(
  73. dr: Array,
  74. eigsys1: Array,
  75. eigsys2: Array,
  76. eigvals: Array = jnp.array([1.0, 1.0, 1.0]),
  77. epsilon: Array = 5.0,
  78. d0: Array = 10,
  79. alpha: Array = 1.0,
  80. sigma: Array = 1.0,
  81. band_theta: Array = jnp.pi / 2,
  82. band_sigma: Array = 0.5,
  83. ) -> Array:
  84. cf = contact_function(dr, eigsys1, eigsys2, eigvals=eigvals)
  85. wca_repulsion = weeks_chandler_andersen(cf, sigma=1.0, epsilon=epsilon)
  86. ellipsod_morse = jaxmd_energy.morse(cf, epsilon=d0, alpha=alpha, sigma=sigma)
  87. patchy_value = patchy_interaction.gaussian_interaction_band_fixed_height(
  88. dr, eigsys1, eigsys2, band_theta, band_sigma
  89. )
  90. # patchy_value = 0.
  91. return wca_repulsion + ellipsod_morse * patchy_value
  92. energy_fn = oriented_pair(patchy_wca_ellipsoid, displacement)
  93. return energy_fn
  94. @dataclasses.dataclass
  95. class PatchyWcaParams:
  96. eigvals: Array = dataclasses.field(default_factory=lambda: jnp.ones((3,)))
  97. epsilon: Array = 5.0
  98. d0: Array = 10.0
  99. sigma: Array = 1.0
  100. alpha: Array = 1.0
  101. lm_magnitudes: Array = 1
  102. def patchy_wca_ellipsoid_pair(
  103. displacement: DisplacementFn,
  104. contact_fn: ContactFunction,
  105. lm: tuple | list[tuple],
  106. **cf_kwargs,
  107. ) -> Callable[..., Array]:
  108. contact_function = oriented_particle.get_ellipsoid_contact_function_param(
  109. contact_fn, **cf_kwargs
  110. )
  111. patchy_function = patchy_interaction.patchy_interaction_general(lm)
  112. def patchy_wca_ellipsoid(
  113. dr: Array,
  114. eigsys1: Array,
  115. eigsys2: Array,
  116. eigvals: Array = jnp.array([1.0, 1.0, 1.0]),
  117. epsilon: Array = 5.0,
  118. d0: Array = 10,
  119. alpha: Array = 1.0,
  120. sigma: Array = 1.0,
  121. lm_magnitudes: Array = 1.0,
  122. ):
  123. cf = contact_function(dr, eigsys1, eigsys2, eigvals=eigvals)
  124. wca_repulsion = weeks_chandler_andersen(cf, sigma=1.0, epsilon=epsilon)
  125. ellipsod_morse = jaxmd_energy.morse(cf, epsilon=d0, alpha=alpha, sigma=sigma)
  126. patchy_value = patchy_function(dr, eigsys1, eigsys2, lm_magnitudes)
  127. # patchy_value = 0.
  128. return wca_repulsion + ellipsod_morse * patchy_value
  129. energy_fn = oriented_pair(patchy_wca_ellipsoid, displacement)
  130. return energy_fn
  131. @dataclasses.dataclass
  132. class QuadWcaParams:
  133. eigvals: Array = dataclasses.field(default_factory=lambda: jnp.ones((3,)))
  134. epsilon: Array = 2.0
  135. d0: Array = 10.0
  136. q0: Array = 10.0
  137. sigma: Array = 1.0
  138. alpha: Array = 1.0
  139. lm_magnitudes: Array = 1
  140. def init_unit_volume_particle(self) -> FerroWcaParams:
  141. params_dict = vars(self)
  142. params_dict["eigvals"] = oriented_particle.eigenvalues_at_unit_volume(
  143. jnp.array([1.0, 1.0, 1.0])
  144. )
  145. return QuadWcaParams(**params_dict)
  146. def init_lm_magnitudes(self, lm_magnitudes: Array) -> FerroWcaParams:
  147. params_dict = vars(self)
  148. params_dict["lm_magnitudes"] = lm_magnitudes
  149. return QuadWcaParams(**params_dict)
  150. def quadrupolar_wca_sphere_pair(
  151. displacement: DisplacementFn, lm: tuple | list[tuple], **cf_kwargs
  152. ) -> Callable[..., Array]:
  153. patchy_function = patchy_interaction.patchy_interaction_general(lm)
  154. def quadrupolar_wca_ellipsoid(
  155. dr: Array,
  156. eigsys1: Array,
  157. eigsys2: Array,
  158. epsilon: Array,
  159. # eigvals: Array = jnp.array([1., 1., 1.]),
  160. d0: Array = 1,
  161. q0: Array = 1,
  162. alpha: Array = 1.0,
  163. sigma: Array = 1.0,
  164. lm_magnitudes: Array = 1.0,
  165. **unused_kwargs,
  166. ):
  167. # NOTE: we take unit volume particles
  168. # sigma_particle = 2 * jnp.cbrt(3 / (4 * jnp.pi))
  169. sigma_particle = sigma
  170. wca = weeks_chandler_andersen(
  171. space.distance(dr), sigma=sigma_particle, epsilon=epsilon
  172. )
  173. # vdw = jaxmd_energy.lennard_jones(space.distance(dr), sigma=sigma, epsilon=1.)
  174. quadrupolar = multipole_interaction.lin_quad_energy(
  175. dr,
  176. eigsys1,
  177. eigsys2,
  178. multipole_interaction.quadrupolar_eigenvalues(
  179. q0 * sigma_particle ** (5 / 2) * jnp.sqrt(epsilon), jnp.pi / 2
  180. ),
  181. )
  182. # NOTE: in quadrupolar eigenvalues calculation, exponent was corrected from 5 to 5/2
  183. ellipsod_morse = jaxmd_energy.morse(
  184. space.distance(dr), epsilon=d0 * epsilon, alpha=alpha, sigma=sigma_particle
  185. )
  186. patchy_value = patchy_function(dr, eigsys1, eigsys2, lm_magnitudes)
  187. # patchy_value = 0.
  188. return wca + quadrupolar + ellipsod_morse * patchy_value
  189. energy_fn = oriented_pair(quadrupolar_wca_ellipsoid, displacement)
  190. return energy_fn
  191. @dataclasses.dataclass
  192. class FerroWcaParams:
  193. eigvals: Array = dataclasses.field(default_factory=lambda: jnp.ones((3,)))
  194. epsilon: Array = 5.0
  195. d0: Array = 1.5
  196. q0: Array = 2.0
  197. sigma: Array = 1.0
  198. alpha: Array = 1.0
  199. lm_magnitudes: Array = 1
  200. softness: Array = 1.5
  201. def init_unit_volume_particle(self) -> FerroWcaParams:
  202. params_dict = vars(self)
  203. params_dict["eigvals"] = oriented_particle.eigenvalues_at_unit_volume(
  204. jnp.array([1.0, 1.0, 1.0])
  205. )
  206. return FerroWcaParams(**params_dict)
  207. def init_lm_magnitudes(self, lm_magnitudes: Array) -> FerroWcaParams:
  208. params_dict = vars(self)
  209. params_dict["lm_magnitudes"] = lm_magnitudes
  210. return FerroWcaParams(**params_dict)
  211. def ferro_wca_sphere_pair(
  212. displacement: DisplacementFn, lm: tuple | list[tuple], **cf_kwargs
  213. ) -> Callable[..., Array]:
  214. patchy_function = patchy_interaction.patchy_interaction_general(lm)
  215. def ferro_wca_ellipsoid(
  216. dr: Array,
  217. eigsys1: Array,
  218. eigsys2: Array,
  219. # eigvals: Array = jnp.array([1., 1., 1.]),
  220. epsilon: Array = 5.0,
  221. d0: Array = 1,
  222. q0: Array = 2,
  223. alpha: Array = 1.0,
  224. sigma: Array = 1.0,
  225. lm_magnitudes: Array = 1.0,
  226. softness: Array = 1.5,
  227. **unused_kwargs,
  228. ):
  229. # NOTE: we take unit volume particles
  230. # sigma_particle = 2 * jnp.cbrt(3 / (4 * jnp.pi))
  231. sigma_particle = sigma
  232. wca = weeks_chandler_andersen(
  233. space.distance(dr), sigma=sigma_particle, epsilon=epsilon
  234. )
  235. # vdw = jaxmd_energy.lennard_jones(space.distance(dr), sigma=sigma, epsilon=1.)
  236. ferro = multipole_interaction.ferro_orientational_energy(
  237. dr, eigsys1, eigsys2, softness=softness
  238. )
  239. morse = jaxmd_energy.morse(
  240. space.distance(dr), epsilon=d0 * epsilon, alpha=alpha, sigma=sigma_particle
  241. )
  242. patchy_value = patchy_function(dr, eigsys1, eigsys2, lm_magnitudes)
  243. # patchy_value = 0.
  244. return wca + morse * (patchy_value + q0**2 * ferro)
  245. energy_fn = oriented_pair(ferro_wca_ellipsoid, displacement)
  246. return energy_fn
  247. def quadrupolar_wca_ellipsoid_pair(
  248. displacement: DisplacementFn,
  249. contact_fn: ContactFunction,
  250. lm: tuple | list[tuple],
  251. **cf_kwargs,
  252. ) -> Callable[..., Array]:
  253. contact_function = oriented_particle.get_ellipsoid_contact_function_param(
  254. contact_fn, **cf_kwargs
  255. )
  256. patchy_function = patchy_interaction.patchy_interaction_general(lm)
  257. def quadrupolar_wca_ellipsoid(
  258. dr: Array,
  259. eigsys1: Array,
  260. eigsys2: Array,
  261. eigvals: Array = jnp.array([1.0, 1.0, 1.0]),
  262. epsilon: Array = 5.0,
  263. d0: Array = 10,
  264. d1: Array = 10,
  265. alpha: Array = 1.0,
  266. sigma: Array = 1.0,
  267. lm_magnitudes: Array = 1.0,
  268. ):
  269. cf = contact_function(dr, eigsys1, eigsys2, eigvals=eigvals)
  270. # wca_repulsion = weeks_chandler_andersen(cf, sigma=1., epsilon=epsilon)
  271. vdw = jaxmd_energy.lennard_jones(cf, sigma=1.0, epsilon=epsilon)
  272. quadrupolar = multipole_interaction.quadrupolar_interaction(
  273. dr, eigsys1, eigsys2, multipole_interaction.quadrupolar_eigenvalues(1.0)
  274. )
  275. # ellipsod_morse = jaxmd_energy.morse(cf, epsilon=d0, alpha=alpha, sigma=sigma)
  276. # patchy_value = patchy_function(dr, eigsys1, eigsys2, lm_magnitudes)
  277. # patchy_value = 0.
  278. return vdw + d1 * quadrupolar
  279. energy_fn = oriented_pair(quadrupolar_wca_ellipsoid, displacement)
  280. return energy_fn