123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 |
- from functools import partial
- from typing import List, Union, Callable
- import jax.numpy as jnp
- import jax
- from curvature_assembly.spherical_harmonics import sph_harm_fn, real_sph_harm, sph_harm_not_fast, sph_harm_fn_custom, real_sph_harm_fn_custom_rev
# Alias for JAX arrays, used in the type annotations throughout this module.
Array = jnp.ndarray
def vec_in_eigensystem(eigsys: Array, vec: Array):
    """
    Project ``vec`` onto the columns of ``eigsys``.

    Component ``i`` of the result is ``eigsys[:, i] . vec``. When the columns of ``eigsys`` are orthonormal,
    these are the coordinates of ``vec`` in that basis.
    """
    basis_as_rows = jnp.transpose(eigsys)
    return jnp.dot(basis_as_rows, vec)
def safe_arctan2(x, y):
    """
    Elementwise ``jnp.arctan2(x, y)`` whose value and gradient stay finite when ``x == y == 0``.

    The argument order follows ``jnp.arctan2``: ``x`` is the sine-like (numerator) argument and ``y`` the
    cosine-like (denominator) one. The result lies in ``[-pi, pi]``, and 0 is returned at the origin.

    The gradient of arctan2 is undefined at the origin (0/0). A single ``jnp.where`` does not help, because the
    unselected branch still contributes NaN to the gradient. Hence the "double where" trick, see
    https://github.com/google/jax/issues/1052 and
    https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
    """
    at_origin = (x == 0) & (y == 0)
    # Move the origin to the harmless point (0, 1) before calling arctan2, so the branch discarded by the
    # outer `where` is evaluated (and differentiated) at a regular point.
    safe_x = jnp.where(at_origin, 0., x)
    safe_y = jnp.where(at_origin, 1., y)
    return jnp.where(at_origin, 0., jnp.arctan2(safe_x, safe_y))
def cart_to_sph(vec: Array) -> Array:
    """
    Transformation of a Cartesian 3-vector to the spherical coordinates theta and phi.

    ``theta = safe_arctan2(sqrt(x**2 + y**2), z)`` (polar angle) and ``phi = safe_arctan2(y, x)`` (azimuth).
    Only the direction of ``vec`` matters, so it does not need to be normalized.

    Returns:
        A single array of shape ``(2,)`` holding ``[theta, phi]``.
    """
    x, y, z = vec[0], vec[1], vec[2]
    theta = safe_arctan2(jnp.sqrt(x ** 2 + y ** 2), z)
    phi = safe_arctan2(y, x)
    return jnp.stack([theta, phi])
def patchy_interaction_general(lm_list: Union[tuple, List[tuple]]) -> Callable:
    """
    Orientational part for a general patchy particle interaction where patches are described by a linear combination
    of spherical harmonics. The form of the potential is inspired by the Kern-Frenkel patchy particle model.

    Args:
        lm_list: A single ``(l, m)`` pair, or a sequence (list or tuple) of ``(l, m)`` pairs, selecting the
            spherical harmonics of the expansion.

    Returns:
        ``fn(dr, eigsys1, eigsys2, lm_magnitudes)`` giving the orientational factor of a particle pair.
        ``lm_magnitudes`` holds one expansion coefficient per ``(l, m)`` pair. A scalar (array or Python number)
        is used for every pair.

    Raises:
        ValueError: If ``lm_list`` is empty or contains a pair with ``|m| > l``. The returned ``fn`` raises
            ValueError when the number of coefficients does not match the number of pairs.
    """
    # A lone (l, m) pair is wrapped into a one-element list; a sequence of pairs (also a tuple of tuples) is kept.
    if isinstance(lm_list, tuple) and lm_list and not isinstance(lm_list[0], (tuple, list)):
        lm_list = [lm_list]
    lm_list = [tuple(lm) for lm in lm_list]
    if not lm_list:
        raise ValueError('lm_list must contain at least one (l, m) pair.')
    if any(abs(m) > l for l, m in lm_list):
        raise ValueError('Spherical harmonics are only defined for |m|<=l.')
    l_list, m_list = zip(*lm_list)
    num_terms = len(lm_list)

    def fn(dr: Array, eigsys1: Array, eigsys2: Array, lm_magnitudes: Array) -> Array:
        # jnp.asarray lets callers pass plain Python scalars or lists as coefficients.
        lm_magnitudes = jnp.asarray(lm_magnitudes)
        if lm_magnitudes.ndim == 0:
            lm_magnitudes = jnp.full(num_terms, lm_magnitudes)
        if len(lm_magnitudes) != num_terms:
            raise ValueError(f'Length of lm_magnitudes array does not match the number of (l, m) expansion terms, '
                             f'got {len(lm_magnitudes)} and {num_terms}, respectively.')
        # dr points from the 2nd to the 1st particle (dr = r1 - r2). Each particle needs the direction towards
        # the other one, i.e. -dr for the first particle and dr for the second.
        normalized_dr = dr / jnp.sqrt(jnp.sum(dr ** 2))
        vec1 = vec_in_eigensystem(eigsys1, -normalized_dr)
        vec2 = vec_in_eigensystem(eigsys2, normalized_dr)
        patches_particle1 = real_sph_harm(vec1, l_list, m_list) @ lm_magnitudes
        patches_particle2 = real_sph_harm(vec2, l_list, m_list) @ lm_magnitudes
        # sign(p1) + sign(p2) is +2 for two positive patch values, -2 for two negative ones and 0 when the signs
        # differ (or one value is 0). The result is therefore -2*p1*p2 (< 0), +2*p1*p2 (> 0) or 0, respectively.
        # Intent: negative patches attract each other, positive patches repel, differently-signed patches give 0.
        # NOTE(review): this reading assumes a positive return value means attraction (i.e. the result scales an
        # attractive isotropic term, as stated for the band potentials in this module) - confirm with callers.
        return -(jnp.sign(patches_particle1) + jnp.sign(patches_particle2)) * patches_particle1 * patches_particle2

    return fn
def generate_lm_list(l_max: int,
                     only_non_neg_m: bool = False,
                     only_even_l: bool = False,
                     only_odd_l: bool = False) -> list:
    """
    Enumerate the (l, m) index pairs of spherical harmonics with degree up to ``l_max`` (inclusive).

    Pairs are ordered by increasing ``l`` and, within one ``l``, by increasing ``m``.

    Args:
        l_max: Largest degree ``l`` to include.
        only_non_neg_m: Keep only ``m >= 0``; otherwise ``m`` runs from ``-l`` to ``l``.
        only_even_l: Keep only even degrees.
        only_odd_l: Keep only odd degrees.

    Raises:
        ValueError: If ``only_even_l`` and ``only_odd_l`` are both set.
    """
    if only_odd_l and only_even_l:
        raise ValueError('Parameters only_even_l and only_odd_l cannot both be True at the same time.')
    # Pick the first degree and the stride through the degrees.
    if only_even_l:
        l_start, l_step = 0, 2
    elif only_odd_l:
        l_start, l_step = 1, 2
    else:
        l_start, l_step = 0, 1
    return [
        (l, m)
        for l in range(l_start, l_max + 1, l_step)
        for m in range(0 if only_non_neg_m else -l, l + 1)
    ]
def init_lm_coefs(lm_list: list[tuple], nonzero_list: list[tuple], init_values: list = None) -> jnp.ndarray:
    """
    Build a unit-norm coefficient vector aligned with ``lm_list``.

    Each ``(l, m)`` of ``lm_list`` that also appears in ``nonzero_list`` gets the matching entry of ``init_values``.
    When ``init_values`` is None, that entry is 1. All other coefficients are 0. The vector is then divided by its
    Euclidean norm.

    Args:
        lm_list: ``(l, m)`` pairs defining the order of the returned coefficients.
        nonzero_list: ``(l, m)`` pairs that get a non-zero starting value. If a pair is listed more than once,
            its first occurrence is used. Pairs given as lists or tuples are treated alike.
        init_values: Values index-aligned with ``nonzero_list``. Defaults to all ones.

    Returns:
        1-D array of length ``len(lm_list)``. If every coefficient is zero (e.g. no pair of ``nonzero_list``
        occurs in ``lm_list``), the zero vector is returned instead of NaNs from dividing by a zero norm.

    Raises:
        IndexError: If a matched pair's position in ``nonzero_list`` lies beyond the end of ``init_values``.
    """
    if init_values is None:
        init_values = [1 for _ in nonzero_list]
    # Index of the first occurrence of each pair (what list.index would return), with O(1) lookups.
    first_index = {}
    for idx, lm in enumerate(nonzero_list):
        first_index.setdefault(tuple(lm), idx)
    coef_list = []
    for lm in lm_list:
        idx = first_index.get(tuple(lm))
        coef_list.append(0. if idx is None else init_values[idx])
    coefs = jnp.array(coef_list)
    norm = jnp.linalg.norm(coefs)
    # An all-zero vector has zero norm; dividing it by 1 instead keeps it zero rather than NaN.
    return coefs / jnp.where(norm > 0, norm, 1.)
def patchy_interaction_band(dr: Array, eigsys1: Array, eigsys2: Array, theta: Array, sigma: Array):
    """
    Step-function band interaction between two particles.

    For each particle, the normalized separation (``-dr`` for the first particle, ``dr`` for the second) is expressed
    in that particle's eigensystem, and its z component is compared with the band limits. A particle's factor is:
    - 1 when the polar angle ``arccos(z)`` lies inside ``[theta - sigma, theta + sigma]``;
    - 0 outside that interval;
    - fractional exactly on an edge (``jnp.heaviside(0, 0.5) == 0.5``).
    The result is the product of the two particles' factors.

    A band that reaches or passes a pole (``theta - sigma <= 0`` or ``theta + sigma >= pi``) includes that pole.
    """
    normalized_dr = dr / jnp.sqrt(jnp.sum(dr ** 2))
    vec1 = vec_in_eigensystem(eigsys1, -normalized_dr)
    vec2 = vec_in_eigensystem(eigsys2, normalized_dr)
    # z = cos(polar angle) decreases monotonically on [0, pi], so the angular band maps to
    # z in [cos(theta + sigma), cos(theta - sigma)]. Outside [0, pi], cos folds the angle back (it is even about
    # 0 and pi), which would wrongly exclude a pole lying inside the band. In that case the limit is moved
    # outside [-1, 1] instead.
    limit_z_minus = jnp.where(theta - sigma > 0, jnp.cos(theta - sigma), 2.)
    limit_z_plus = jnp.where(theta + sigma < jnp.pi, jnp.cos(theta + sigma), -2.)

    def in_band(z):
        return jnp.heaviside(limit_z_minus - z, 0.5) * jnp.heaviside(z - limit_z_plus, 0.5)

    # return value should be positive for attractive patches
    # as this potential is usually combined with attractive isotropic term
    return in_band(vec1[2]) * in_band(vec2[2])
@jax.custom_jvp
def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``, with a hand-written derivative (see ``sigmoid_jvp``)."""
    return 1 / (1 + jnp.exp(-x))


@sigmoid.defjvp
def sigmoid_jvp(primals, tangents):
    """
    JVP rule for ``sigmoid``: d sigmoid / dx = sigmoid(x) * (1 - sigmoid(x)).

    ``defjvp`` calls the rule with a tuple of primal inputs and a tuple of tangents, one entry per argument.
    Expressing the derivative through the primal output avoids the inf/inf (NaN) that differentiating
    ``1 / (1 + exp(-x))`` directly produces for large negative ``x``.
    """
    x, = primals
    x_dot, = tangents
    primal_out = sigmoid(x)
    tangent_out = primal_out * (1 - primal_out) * x_dot
    return primal_out, tangent_out
def gaussian_belt(x, theta, sigma) -> jnp.ndarray:
    """
    Normalized Gaussian in ``x`` centred at ``theta`` with standard deviation ``sigma``.

    The prefactor ``1 / (sigma * sqrt(2 * pi))`` makes it integrate to 1 over the real line, so the peak height
    grows as ``sigma`` shrinks.
    """
    normalization = 1 / (sigma * jnp.sqrt(2 * jnp.pi))
    reduced = (x - theta) / sigma
    return normalization * jnp.exp(-0.5 * reduced ** 2)
def gaussian_belt_fixed_height(x, theta, sigma) -> jnp.ndarray:
    """Unnormalized Gaussian ``exp(-0.5 * ((x - theta) / sigma) ** 2)``; its peak value at ``x == theta`` is 1."""
    reduced = (x - theta) / sigma
    return jnp.exp(-0.5 * reduced ** 2)
def gaussian_interaction_band(dr: Array, eigsys1: Array, eigsys2: Array, theta: Array, sigma: Array):
    """
    Smooth band interaction built from normalized Gaussians in the polar angle.

    For each particle, the normalized separation (``-dr`` for the first particle, ``dr`` for the second) is expressed
    in that particle's eigensystem. Its polar angle ``arccos(z)`` is passed through ``gaussian_belt`` centred at
    ``theta`` with width ``sigma``. The result is the product of the two particles' factors.
    """
    normalized_dr = dr / jnp.sqrt(jnp.sum(dr ** 2))
    vec1 = vec_in_eigensystem(eigsys1, -normalized_dr)
    vec2 = vec_in_eigensystem(eigsys2, normalized_dr)
    belt = partial(gaussian_belt, theta=theta, sigma=sigma)
    # Rounding in the normalization and the basis change can push |z| marginally above 1, where arccos is NaN.
    polar1 = jnp.arccos(jnp.clip(vec1[2], -1., 1.))
    polar2 = jnp.arccos(jnp.clip(vec2[2], -1., 1.))
    # return value should be positive for attractive patches
    # as this potential is usually combined with attractive isotropic term
    return belt(polar1) * belt(polar2)
def gaussian_interaction_band_fixed_height(dr: Array, eigsys1: Array, eigsys2: Array, theta: Array, sigma: Array):
    """
    Smooth band interaction built from unit-height Gaussians in the polar angle.

    For each particle, the normalized separation (``-dr`` for the first particle, ``dr`` for the second) is expressed
    in that particle's eigensystem. Its polar angle ``arccos(z)`` is passed through ``gaussian_belt_fixed_height``
    centred at ``theta`` with width ``sigma``. The result is the product of the two particles' factors and so is
    at most 1.
    """
    normalized_dr = dr / jnp.sqrt(jnp.sum(dr ** 2))
    vec1 = vec_in_eigensystem(eigsys1, -normalized_dr)
    vec2 = vec_in_eigensystem(eigsys2, normalized_dr)
    belt = partial(gaussian_belt_fixed_height, theta=theta, sigma=sigma)
    # Rounding in the normalization and the basis change can push |z| marginally above 1, where arccos is NaN.
    polar1 = jnp.arccos(jnp.clip(vec1[2], -1., 1.))
    polar2 = jnp.arccos(jnp.clip(vec2[2], -1., 1.))
    # return value should be positive for attractive patches
    # as this potential is usually combined with attractive isotropic term
    return belt(polar1) * belt(polar2)
|