patchy_interaction.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. from functools import partial
  2. from typing import List, Union, Callable
  3. import jax.numpy as jnp
  4. import jax
  5. from curvature_assembly.spherical_harmonics import sph_harm_fn, real_sph_harm, sph_harm_not_fast, sph_harm_fn_custom, real_sph_harm_fn_custom_rev
# Alias used in this module's type annotations for JAX arrays.
Array = jnp.ndarray
  7. def vec_in_eigensystem(eigsys: Array, vec: Array):
  8. """Get vector components in the eigensystem."""
  9. return jnp.dot(jnp.transpose(eigsys), vec)
  10. def safe_arctan2(x, y):
  11. """
  12. Version of arctan2 that works for zero-valued inputs. Look at https://github.com/google/jax/issues/1052
  13. and https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
  14. """
  15. safe_y = jnp.where(y > 0., y, 1.)
  16. return jnp.where(y > 0, jnp.arctan2(x, safe_y), 1.)
  17. def cart_to_sph(vec: Array) -> (Array, Array):
  18. """Transformation to spherical coordinates theta and phi."""
  19. sph_coord = jnp.zeros(2, )
  20. sph_coord = sph_coord.at[0].set(safe_arctan2(jnp.sqrt(vec[0] ** 2 + vec[1] ** 2), vec[2]))
  21. sph_coord = sph_coord.at[1].set(safe_arctan2(vec[1], vec[0]))
  22. return sph_coord
  23. def patchy_interaction_general(lm_list: Union[tuple, List[tuple]]) -> Callable:
  24. """
  25. Orientational part for a general patchy particle interaction where patches are described by a linear combination
  26. of spherical harmonics. The form of the potential is inspired by the Kern-Frenkel patchy particle model.
  27. """
  28. if isinstance(lm_list, tuple):
  29. lm_list = [lm_list]
  30. l_list, m_list = zip(*lm_list)
  31. l_array = jnp.array(l_list)
  32. m_array = jnp.array(m_list)
  33. # sph_harm = real_sph_harm_fn_custom_rev(6)
  34. if not jnp.all(jnp.abs(m_array) <= l_array):
  35. raise ValueError(f'Spherical harmonics are only defined for |m|<=l.')
  36. def fn(dr: Array, eigsys1: Array, eigsys2: Array, lm_magnitudes: Array) -> Array:
  37. if lm_magnitudes.shape == ():
  38. lm_magnitudes = jnp.full(len(lm_list), lm_magnitudes)
  39. if len(lm_magnitudes) != len(lm_list):
  40. raise ValueError(f'Length of lm_magnitudes array does not match the number of (l, m) expansion terms, '
  41. f'got {len(lm_magnitudes)} and {len(lm_list)}, respectively.')
  42. # dr points from 2nd to 1st particle (dr = r1 - r2)
  43. # we need relative direction from one particle to another, so in the case of the first, we need to take -dr
  44. normalized_dr = dr / jnp.sqrt(jnp.sum(dr ** 2))
  45. vec1 = vec_in_eigensystem(eigsys1, -normalized_dr)
  46. vec2 = vec_in_eigensystem(eigsys2, normalized_dr)
  47. # patches_particle1 = jnp.real(sph_harm(vec1)) @ lm_magnitudes
  48. # patches_particle2 = jnp.real(sph_harm(vec2)) @ lm_magnitudes
  49. patches_particle1 = real_sph_harm(vec1, l_list, m_list) @ lm_magnitudes
  50. patches_particle2 = real_sph_harm(vec2, l_list, m_list) @ lm_magnitudes
  51. # energy contribution from patches is defined in such a way that negative patches attract each other,
  52. # positive patches repulse and differently-signed patches have 0 energy
  53. return -(jnp.sign(patches_particle1) + jnp.sign(patches_particle2)) * patches_particle1 * patches_particle2
  54. return fn
  55. def generate_lm_list(l_max: int,
  56. only_non_neg_m: bool = False,
  57. only_even_l: bool = False,
  58. only_odd_l: bool = False) -> list:
  59. """Return list of all possible (l, m) for a given maximal l."""
  60. if only_odd_l and only_even_l:
  61. raise ValueError('Parameters only_even_l and only_odd_l cannot both be True at the same time.')
  62. lm_list = []
  63. if only_even_l:
  64. l_list = list(range(0, l_max + 1, 2))
  65. elif only_odd_l:
  66. l_list = list(range(1, l_max + 1, 2))
  67. else:
  68. l_list = list(range(0, l_max + 1))
  69. for l in l_list:
  70. min_m = 0 if only_non_neg_m else -l
  71. for m in range(min_m, l + 1):
  72. lm_list.append((l, m))
  73. return lm_list
  74. def init_lm_coefs(lm_list: list[tuple], nonzero_list: list[tuple], init_values: list = None) -> jnp.ndarray:
  75. """
  76. Initialize lm coefficients for a given lm_list with desired values. Default is 0. if init_values is not provided.
  77. """
  78. if init_values is None:
  79. init_values = [1 for _ in nonzero_list]
  80. coef_list = []
  81. for lm in lm_list:
  82. try:
  83. idx = nonzero_list.index(lm)
  84. coef_list.append(init_values[idx])
  85. except ValueError:
  86. coef_list.append(0.)
  87. return jnp.array(coef_list) / jnp.linalg.norm(jnp.array(coef_list))
  88. def patchy_interaction_band(dr: Array, eigsys1: Array, eigsys2: Array, theta: Array, sigma: Array):
  89. normalized_dr = dr / jnp.sqrt(jnp.sum(dr ** 2))
  90. vec1 = vec_in_eigensystem(eigsys1, -normalized_dr)
  91. vec2 = vec_in_eigensystem(eigsys2, normalized_dr)
  92. limit_z_plus = jnp.cos(theta + sigma)
  93. limit_z_minus = jnp.cos(theta - sigma)
  94. # return value should be positive for attractive patches
  95. # as this potential is usually combined with attractive isotropic term
  96. return jnp.heaviside(limit_z_minus - vec1[2], 0.5) * jnp.heaviside(vec1[2] - limit_z_plus, 0.5) * \
  97. jnp.heaviside(limit_z_minus - vec2[2], 0.5) * jnp.heaviside(vec2[2] - limit_z_plus, 0.5)
  98. @jax.custom_jvp
  99. def sigmoid(x):
  100. return 1 / (1 + jnp.exp(-x))
  101. @sigmoid.defjvp
  102. def sigmoid_jvp(x, x_dot):
  103. primal_out = sigmoid(x)
  104. tangent_out = primal_out * (1 - primal_out) * x_dot
  105. return primal_out, tangent_out
  106. def gaussian_belt(x, theta, sigma) -> jnp.ndarray:
  107. return 1 / (sigma * jnp.sqrt(2 * jnp.pi)) * jnp.exp(-0.5 * ((x - theta) / sigma) ** 2)
  108. def gaussian_belt_fixed_height(x, theta, sigma) -> jnp.ndarray:
  109. return jnp.exp(-0.5 * ((x - theta) / sigma) ** 2)
  110. def gaussian_interaction_band(dr: Array, eigsys1: Array, eigsys2: Array, theta: Array, sigma: Array):
  111. normalized_dr = dr / jnp.sqrt(jnp.sum(dr ** 2))
  112. vec1 = vec_in_eigensystem(eigsys1, -normalized_dr)
  113. vec2 = vec_in_eigensystem(eigsys2, normalized_dr)
  114. belt = partial(gaussian_belt, theta=theta, sigma=sigma)
  115. # return value should be positive for attractive patches
  116. # as this potential is usually combined with attractive isotropic term
  117. return belt(jnp.arccos(vec1[2])) * belt(jnp.arccos(vec2[2]))
  118. def gaussian_interaction_band_fixed_height(dr: Array, eigsys1: Array, eigsys2: Array, theta: Array, sigma: Array):
  119. normalized_dr = dr / jnp.sqrt(jnp.sum(dr ** 2))
  120. vec1 = vec_in_eigensystem(eigsys1, -normalized_dr)
  121. vec2 = vec_in_eigensystem(eigsys2, normalized_dr)
  122. belt = partial(gaussian_belt_fixed_height, theta=theta, sigma=sigma)
  123. # return value should be positive for attractive patches
  124. # as this potential is usually combined with attractive isotropic term
  125. return belt(jnp.arccos(vec1[2])) * belt(jnp.arccos(vec2[2]))