from functools import partial
from typing import Callable, List, Optional, Tuple, Union

import jax
import jax.numpy as jnp

# NOTE(review): only real_sph_harm is used in this module; the other names may be imported from here by other
# modules - confirm before removing them.
from curvature_assembly.spherical_harmonics import (
    real_sph_harm,
    real_sph_harm_fn_custom_rev,
    sph_harm_fn,
    sph_harm_fn_custom,
    sph_harm_not_fast,
)

Array = jnp.ndarray


def vec_in_eigensystem(eigsys: Array, vec: Array) -> Array:
    """
    Get vector components in the eigensystem.

    Computes ``eigsys.T @ vec``, i.e. the projection of ``vec`` onto every column of ``eigsys``.
    NOTE(review): these are the coordinates of ``vec`` in the eigenbasis only if the columns of ``eigsys``
    form an orthonormal basis - presumably guaranteed by the caller; confirm.
    """
    return jnp.dot(jnp.transpose(eigsys), vec)


def _relative_directions(dr: Array, eigsys1: Array, eigsys2: Array) -> Tuple[Array, Array]:
    """Return the unit direction towards the other particle, expressed in each particle's eigensystem."""
    # dr points from the 2nd to the 1st particle (dr = r1 - r2): particle 1 looks at particle 2 along -dr,
    # particle 2 looks at particle 1 along +dr.
    # NOTE(review): dr == 0 yields NaN here (0 / 0); presumably overlapping particles never reach this code.
    normalized_dr = dr / jnp.sqrt(jnp.sum(dr ** 2))
    return vec_in_eigensystem(eigsys1, -normalized_dr), vec_in_eigensystem(eigsys2, normalized_dr)


def _polar_angle(vec: Array) -> Array:
    """Polar angle of ``vec`` measured from the 3rd axis, ``arccos(vec[2])``, for a (nominally) unit vector."""
    # Rounding can push the cosine slightly outside [-1, 1], where arccos returns NaN.
    return jnp.arccos(jnp.clip(vec[2], -1., 1.))


def safe_arctan2(x, y):
    """
    Version of arctan2 that works for zero-valued inputs.

    Returns ``jnp.arctan2(x, y)`` for every input except the singular point ``x == y == 0``. At that point it
    returns 0 (the value ``jnp.arctan2(0., 0.)`` also gives) with a zero gradient instead of NaN. This uses the
    "double where" trick: the singular input is replaced by a harmless one *before* arctan2 is evaluated, so no
    NaN can enter the backward pass.
    Look at https://github.com/google/jax/issues/1052 and
    https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
    """
    # Only (0, 0) is masked: arctan2 is well defined and differentiable everywhere else, including y <= 0.
    singular = (x == 0.) & (y == 0.)
    safe_y = jnp.where(singular, 1., y)
    return jnp.where(singular, 0., jnp.arctan2(x, safe_y))


def cart_to_sph(vec: Array) -> Array:
    """
    Transformation to spherical coordinates theta and phi.

    Args:
        vec: Cartesian vector; only its first three components (x, y, z) are read.

    Returns:
        Array ``[theta, phi]`` of shape (2,): the polar angle measured from the +z axis (in [0, pi]) and the
        azimuthal angle in the xy-plane (in [-pi, pi]). The length of ``vec`` is discarded.
    """
    x, y, z = vec[0], vec[1], vec[2]
    # NOTE(review): the derivative of sqrt is infinite at 0, so gradients may still be non-finite for vectors
    # lying exactly on the z axis.
    rho = jnp.sqrt(x ** 2 + y ** 2)
    return jnp.stack([safe_arctan2(rho, z), safe_arctan2(y, x)])


def patchy_interaction_general(lm_list: Union[tuple, List[tuple]]) -> Callable:
    """
    Orientational part for a general patchy particle interaction where patches are described by a linear combination
    of spherical harmonics. The form of the potential is inspired by the Kern-Frenkel patchy particle model.

    Args:
        lm_list: A single ``(l, m)`` pair or a list of pairs selecting the real spherical harmonics of the expansion.

    Returns:
        Function ``fn(dr, eigsys1, eigsys2, lm_magnitudes)`` evaluating the orientational term for a particle pair.
        ``dr = r1 - r2``; ``eigsys1``/``eigsys2`` are the particles' eigensystems; ``lm_magnitudes`` is either a
        scalar (shared by all terms) or one expansion coefficient per ``(l, m)`` pair. ``fn`` raises ValueError when
        the number of coefficients does not match ``lm_list``.

    Raises:
        ValueError: If ``lm_list`` is empty or any pair violates ``|m| <= l``.
    """
    if isinstance(lm_list, tuple):
        lm_list = [lm_list]
    if len(lm_list) == 0:
        raise ValueError('lm_list must contain at least one (l, m) pair.')
    l_list, m_list = zip(*lm_list)
    if any(abs(m) > l for l, m in lm_list):
        raise ValueError('Spherical harmonics are only defined for |m|<=l.')

    def fn(dr: Array, eigsys1: Array, eigsys2: Array, lm_magnitudes: Array) -> Array:
        # Accept Python scalars/lists too; a scalar weight is broadcast to every (l, m) term.
        lm_magnitudes = jnp.asarray(lm_magnitudes)
        if lm_magnitudes.shape == ():
            lm_magnitudes = jnp.full(len(lm_list), lm_magnitudes)
        if len(lm_magnitudes) != len(lm_list):
            raise ValueError(f'Length of lm_magnitudes array does not match the number of (l, m) expansion terms, '
                             f'got {len(lm_magnitudes)} and {len(lm_list)}, respectively.')

        vec1, vec2 = _relative_directions(dr, eigsys1, eigsys2)
        # Patch value of each particle in the direction of the other one: the spherical harmonics evaluated there,
        # weighted by lm_magnitudes (presumably real_sph_harm returns one value per (l, m) pair - confirm).
        patches_particle1 = real_sph_harm(vec1, l_list, m_list) @ lm_magnitudes
        patches_particle2 = real_sph_harm(vec2, l_list, m_list) @ lm_magnitudes

        # Following the convention stated for the band potentials in this module (positive value = attraction):
        # two negative patch values give a positive result (attraction), two positive values a negative result
        # (repulsion), and patch values of opposite sign (or a zero value) give exactly 0.
        return -(jnp.sign(patches_particle1) + jnp.sign(patches_particle2)) * patches_particle1 * patches_particle2

    return fn


def generate_lm_list(l_max: int,
                     only_non_neg_m: bool = False,
                     only_even_l: bool = False,
                     only_odd_l: bool = False) -> list:
    """
    Return list of all possible (l, m) for a given maximal l.

    Args:
        l_max: Largest l to include.
        only_non_neg_m: Include only m >= 0 instead of -l <= m <= l.
        only_even_l: Include only even l (0, 2, 4, ...).
        only_odd_l: Include only odd l (1, 3, 5, ...).

    Returns:
        List of ``(l, m)`` tuples ordered by increasing l, then increasing m.

    Raises:
        ValueError: If ``only_even_l`` and ``only_odd_l`` are both True.
    """
    if only_odd_l and only_even_l:
        raise ValueError('Parameters only_even_l and only_odd_l cannot both be True at the same time.')
    if only_even_l:
        l_values = range(0, l_max + 1, 2)
    elif only_odd_l:
        l_values = range(1, l_max + 1, 2)
    else:
        l_values = range(0, l_max + 1)
    return [(l, m) for l in l_values for m in range(0 if only_non_neg_m else -l, l + 1)]


def init_lm_coefs(lm_list: list[tuple], nonzero_list: list[tuple], init_values: Optional[list] = None) -> jnp.ndarray:
    """
    Initialize lm coefficients for a given lm_list with desired values.

    Every ``(l, m)`` of ``lm_list`` that also appears in ``nonzero_list`` gets the entry of ``init_values`` at the
    position of its first occurrence in ``nonzero_list``; all other coefficients are 0. If ``init_values`` is not
    provided, every coefficient listed in ``nonzero_list`` is 1. The resulting vector is normalized to unit
    Euclidean norm.

    Raises:
        ValueError: If ``lm_list`` is non-empty but all resulting coefficients are zero (cannot be normalized).
    """
    if init_values is None:
        init_values = [1 for _ in nonzero_list]
    coef_list = []
    for lm in lm_list:
        try:
            idx = nonzero_list.index(lm)
        except ValueError:
            coef_list.append(0.)
        else:
            coef_list.append(init_values[idx])
    coefs = jnp.array(coef_list)
    norm = jnp.linalg.norm(coefs)
    # A zero vector would silently turn into all-NaN coefficients after normalization.
    if coefs.size > 0 and norm == 0:
        raise ValueError('Cannot normalize lm coefficients: all of them are zero.')
    return coefs / norm


def patchy_interaction_band(dr: Array, eigsys1: Array, eigsys2: Array, theta: Array, sigma: Array):
    """
    Step-function band patch interaction.

    Each particle carries a band of polar angles [theta - sigma, theta + sigma], measured from the 3rd axis of its
    eigensystem. The result is 1 when each particle sees the other one inside its band and 0 otherwise; each of the
    four band-edge step functions contributes 0.5 exactly on its edge.
    """
    vec1, vec2 = _relative_directions(dr, eigsys1, eigsys2)
    # Clip the band edges to the physical polar-angle range [0, pi]: cos is not monotonic outside it, so e.g. a
    # polar cap with theta < sigma would otherwise collapse to a single ring.
    limit_z_plus = jnp.cos(jnp.minimum(theta + sigma, jnp.pi))
    limit_z_minus = jnp.cos(jnp.maximum(theta - sigma, 0.))

    def in_band(z):
        return jnp.heaviside(limit_z_minus - z, 0.5) * jnp.heaviside(z - limit_z_plus, 0.5)

    # return value should be positive for attractive patches
    # as this potential is usually combined with attractive isotropic term
    return in_band(vec1[2]) * in_band(vec2[2])


@jax.custom_jvp
def sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))`` with a hand-written derivative rule."""
    return 1 / (1 + jnp.exp(-x))


@sigmoid.defjvp
def sigmoid_jvp(primals, tangents):
    """
    JVP rule for ``sigmoid``: d sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)) * dx.

    ``custom_jvp.defjvp`` passes primal and tangent inputs as tuples (one entry per argument of ``sigmoid``),
    so they are unpacked first.
    """
    (x,) = primals
    (x_dot,) = tangents
    primal_out = sigmoid(x)
    tangent_out = primal_out * (1 - primal_out) * x_dot
    return primal_out, tangent_out


def gaussian_belt(x, theta, sigma) -> jnp.ndarray:
    """Normalized Gaussian in ``x`` with mean ``theta`` and standard deviation ``sigma`` (unit area on the real line)."""
    return 1 / (sigma * jnp.sqrt(2 * jnp.pi)) * jnp.exp(-0.5 * ((x - theta) / sigma) ** 2)


def gaussian_belt_fixed_height(x, theta, sigma) -> jnp.ndarray:
    """Gaussian in ``x`` centered at ``theta`` with width ``sigma`` and peak value 1 (at ``x == theta``)."""
    return jnp.exp(-0.5 * ((x - theta) / sigma) ** 2)


def gaussian_interaction_band(dr: Array, eigsys1: Array, eigsys2: Array, theta: Array, sigma: Array):
    """
    Smooth band patch interaction.

    Product of ``gaussian_belt`` values for the polar angles (from the 3rd eigensystem axis) at which each particle
    sees the other one.
    """
    vec1, vec2 = _relative_directions(dr, eigsys1, eigsys2)
    belt = partial(gaussian_belt, theta=theta, sigma=sigma)
    # return value should be positive for attractive patches
    # as this potential is usually combined with attractive isotropic term
    return belt(_polar_angle(vec1)) * belt(_polar_angle(vec2))


def gaussian_interaction_band_fixed_height(dr: Array, eigsys1: Array, eigsys2: Array, theta: Array, sigma: Array):
    """
    Smooth band patch interaction with unit peak height.

    Like ``gaussian_interaction_band`` but uses ``gaussian_belt_fixed_height``, so the result is at most 1.
    """
    vec1, vec2 = _relative_directions(dr, eigsys1, eigsys2)
    belt = partial(gaussian_belt_fixed_height, theta=theta, sigma=sigma)
    # return value should be positive for attractive patches
    # as this potential is usually combined with attractive isotropic term
    return belt(_polar_angle(vec1)) * belt(_polar_angle(vec2))