"""Pair energies for oriented (ellipsoidal or spherical) particles.

Each ``*_pair`` factory returns an energy function built with
``curvature_assembly.smap.oriented_pair`` from a per-pair energy. That
per-pair energy combines a Weeks-Chandler-Andersen (or Lennard-Jones) core
with orientation-dependent terms. The ``*Params`` dataclasses hold default
parameter sets for these energies.
"""

from __future__ import annotations

from typing import Callable

import jax.numpy as jnp

from curvature_assembly import (
    oriented_particle,
    data_protocols,
    patchy_interaction,
    multipole_interaction,
)
from jax_md import energy as jaxmd_energy
from curvature_assembly.smap import oriented_pair
from jax_md import partition, space, dataclasses

f32 = jnp.float32
f64 = jnp.float64

Array = jnp.ndarray
DisplacementFn = space.DisplacementFn
ContactFunction = Callable[..., Array]
NeighborListFormat = partition.NeighborListFormat
InteractionParams = data_protocols.InteractionParams


def weeks_chandler_andersen(
    dr: Array, sigma: Array = 1.0, epsilon: Array = 1.0, **unused_kwargs
) -> Array:
    """Repulsive part of the Lennard-Jones potential.

    The Lennard-Jones potential is shifted up by ``epsilon`` and cut off at
    its minimum, ``dr = 2**(1/6) * sigma``. The result is purely repulsive
    and reaches zero continuously at the cutoff.

    Args:
        dr: Separation (or contact-function value) at which to evaluate.
        sigma: Length scale of the Lennard-Jones potential.
        epsilon: Energy scale of the Lennard-Jones potential.
        **unused_kwargs: Ignored; lets callers pass a shared parameter set.

    Returns:
        The WCA energy, and ``0.0`` at or beyond the cutoff.
    """
    cutoff = jnp.power(2, 1 / 6) * sigma
    return jnp.where(
        dr < cutoff,
        jaxmd_energy.lennard_jones(dr, sigma=sigma, epsilon=epsilon) + epsilon,
        0.0,
    )


@dataclasses.dataclass
class GbWcaParams:
    """Default parameters for the Gaussian-band WCA ellipsoid pair energies."""

    # Ellipsoid eigenvalues passed to the contact function.
    eigvals: Array = dataclasses.field(default_factory=lambda: jnp.ones((3,)))
    # Energy scale of the WCA repulsion.
    epsilon: Array = 5.0
    # Depth of the Morse attraction.
    d0: Array = 10.0
    # Position of the Morse minimum (in contact-function units).
    sigma: Array = 1.0
    # Morse width parameter.
    alpha: Array = 1.0
    # Band parameters forwarded to the patchy band function.
    band_theta: Array = jnp.pi / 2
    band_sigma: Array = 0.5


def _gaussian_band_wca_ellipsoid_pair(
    displacement: DisplacementFn,
    contact_fn: ContactFunction,
    band_fn: Callable[..., Array],
    /,
    **cf_kwargs,
) -> Callable[..., Array]:
    """Build a WCA + Morse * band energy for a given band function.

    This is the shared implementation of the two public Gaussian-band
    factories. Its own parameters are positional-only, so every keyword in
    ``cf_kwargs`` is forwarded to the contact-function constructor without
    colliding with them.
    """
    contact_function = oriented_particle.get_ellipsoid_contact_function_param(
        contact_fn, **cf_kwargs
    )

    # The signature (names and defaults) is part of the oriented_pair
    # interface; keep it in sync with the public factories' documentation.
    def patchy_wca_ellipsoid(
        dr: Array,
        eigsys1: Array,
        eigsys2: Array,
        eigvals: Array = jnp.array([1.0, 1.0, 1.0]),
        epsilon: Array = 5.0,
        d0: Array = 10,
        alpha: Array = 1.0,
        sigma: Array = 1.0,
        band_theta: Array = jnp.pi / 2,
        band_sigma: Array = 0.5,
    ) -> Array:
        cf = contact_function(dr, eigsys1, eigsys2, eigvals=eigvals)
        # The WCA core is evaluated on the contact function with a fixed unit
        # length scale; only the Morse term uses ``sigma``.
        wca_repulsion = weeks_chandler_andersen(cf, sigma=1.0, epsilon=epsilon)
        ellipsoid_morse = jaxmd_energy.morse(
            cf, epsilon=d0, alpha=alpha, sigma=sigma
        )
        patchy_value = band_fn(dr, eigsys1, eigsys2, band_theta, band_sigma)
        return wca_repulsion + ellipsoid_morse * patchy_value

    return oriented_pair(patchy_wca_ellipsoid, displacement)


def gaussian_band_wca_ellipsoid_pair(
    displacement: DisplacementFn, contact_fn: ContactFunction, **cf_kwargs
) -> Callable[..., Array]:
    """Ellipsoid pair energy: WCA core plus Morse attraction modulated by a band.

    Per pair, the energy is ``WCA(cf) + Morse(cf; D0=d0, alpha, sigma) * band``.
    ``cf`` is the ellipsoid contact function and ``band`` is
    ``patchy_interaction.gaussian_interaction_band``.

    Args:
        displacement: Displacement function of the simulation space.
        contact_fn: Contact function passed to
            ``oriented_particle.get_ellipsoid_contact_function_param``.
        **cf_kwargs: Extra keyword arguments for that contact-function
            constructor.

    Returns:
        The energy function produced by ``oriented_pair``.
    """
    return _gaussian_band_wca_ellipsoid_pair(
        displacement,
        contact_fn,
        patchy_interaction.gaussian_interaction_band,
        **cf_kwargs,
    )


def gaussian_band_fh_wca_ellipsoid_pair(
    displacement: DisplacementFn, contact_fn: ContactFunction, **cf_kwargs
) -> Callable[..., Array]:
    """Like ``gaussian_band_wca_ellipsoid_pair``, with a fixed-height band.

    The band factor is
    ``patchy_interaction.gaussian_interaction_band_fixed_height``. Every other
    term and parameter is the same as in ``gaussian_band_wca_ellipsoid_pair``.
    """
    return _gaussian_band_wca_ellipsoid_pair(
        displacement,
        contact_fn,
        patchy_interaction.gaussian_interaction_band_fixed_height,
        **cf_kwargs,
    )


@dataclasses.dataclass
class PatchyWcaParams:
    """Default parameters for ``patchy_wca_ellipsoid_pair``."""

    # Ellipsoid eigenvalues passed to the contact function.
    eigvals: Array = dataclasses.field(default_factory=lambda: jnp.ones((3,)))
    # Energy scale of the WCA repulsion.
    epsilon: Array = 5.0
    # Depth of the Morse attraction.
    d0: Array = 10.0
    # Position of the Morse minimum (in contact-function units).
    sigma: Array = 1.0
    # Morse width parameter.
    alpha: Array = 1.0
    # Magnitudes forwarded to the general patchy interaction.
    lm_magnitudes: Array = 1


def patchy_wca_ellipsoid_pair(
    displacement: DisplacementFn,
    contact_fn: ContactFunction,
    lm: tuple | list[tuple],
    **cf_kwargs,
) -> Callable[..., Array]:
    """Ellipsoid pair energy: WCA core plus Morse attraction times a patch term.

    Per pair, the energy is
    ``WCA(cf) + Morse(cf; D0=d0, alpha, sigma) * patchy``. ``patchy`` comes
    from ``patchy_interaction.patchy_interaction_general(lm)``, evaluated with
    ``lm_magnitudes``.

    Args:
        displacement: Displacement function of the simulation space.
        contact_fn: Contact function passed to
            ``oriented_particle.get_ellipsoid_contact_function_param``.
        lm: Index tuple(s) passed to ``patchy_interaction_general``.
        **cf_kwargs: Extra keyword arguments for the contact-function
            constructor.

    Returns:
        The energy function produced by ``oriented_pair``.
    """
    contact_function = oriented_particle.get_ellipsoid_contact_function_param(
        contact_fn, **cf_kwargs
    )
    patchy_function = patchy_interaction.patchy_interaction_general(lm)

    def patchy_wca_ellipsoid(
        dr: Array,
        eigsys1: Array,
        eigsys2: Array,
        eigvals: Array = jnp.array([1.0, 1.0, 1.0]),
        epsilon: Array = 5.0,
        d0: Array = 10,
        alpha: Array = 1.0,
        sigma: Array = 1.0,
        lm_magnitudes: Array = 1.0,
    ):
        cf = contact_function(dr, eigsys1, eigsys2, eigvals=eigvals)
        # Unit length scale for the WCA core; ``sigma`` only enters the Morse term.
        wca_repulsion = weeks_chandler_andersen(cf, sigma=1.0, epsilon=epsilon)
        ellipsoid_morse = jaxmd_energy.morse(
            cf, epsilon=d0, alpha=alpha, sigma=sigma
        )
        patchy_value = patchy_function(dr, eigsys1, eigsys2, lm_magnitudes)
        return wca_repulsion + ellipsoid_morse * patchy_value

    return oriented_pair(patchy_wca_ellipsoid, displacement)


@dataclasses.dataclass
class QuadWcaParams:
    """Default parameters for ``quadrupolar_wca_sphere_pair``."""

    # Eigenvalues of the particle shape (see ``init_unit_volume_particle``).
    eigvals: Array = dataclasses.field(default_factory=lambda: jnp.ones((3,)))
    # Energy scale of the WCA repulsion; also scales the Morse depth and
    # the quadrupole strength.
    epsilon: Array = 2.0
    # Relative Morse depth (the absolute depth is ``d0 * epsilon``).
    d0: Array = 10.0
    # Relative quadrupole strength.
    q0: Array = 10.0
    # Particle diameter / Morse minimum position.
    sigma: Array = 1.0
    # Morse width parameter.
    alpha: Array = 1.0
    # Magnitudes forwarded to the general patchy interaction.
    lm_magnitudes: Array = 1

    def init_unit_volume_particle(self) -> QuadWcaParams:
        """Return a copy whose ``eigvals`` come from ``eigenvalues_at_unit_volume``.

        ``self`` is not modified.
        """
        # Copy: ``vars(self)`` is the live instance ``__dict__``; writing into
        # it directly would silently change ``self`` as well.
        params_dict = dict(vars(self))
        # NOTE(review): this always starts from isotropic eigenvalues
        # [1, 1, 1] and ignores the current ``self.eigvals`` — confirm intended.
        params_dict["eigvals"] = oriented_particle.eigenvalues_at_unit_volume(
            jnp.array([1.0, 1.0, 1.0])
        )
        return QuadWcaParams(**params_dict)

    def init_lm_magnitudes(self, lm_magnitudes: Array) -> QuadWcaParams:
        """Return a copy with ``lm_magnitudes`` replaced; ``self`` is unchanged."""
        params_dict = dict(vars(self))
        params_dict["lm_magnitudes"] = lm_magnitudes
        return QuadWcaParams(**params_dict)


def quadrupolar_wca_sphere_pair(
    displacement: DisplacementFn, lm: tuple | list[tuple], **cf_kwargs
) -> Callable[..., Array]:
    """Sphere pair energy: WCA core + linear quadrupole + Morse times patch.

    All radial terms are evaluated on the centre-to-centre distance
    ``|dr|``. Per pair, the energy is
    ``WCA(r; sigma, epsilon) + quad + Morse(r; D0=d0*epsilon, alpha, sigma) * patchy``.
    ``quad`` is ``multipole_interaction.lin_quad_energy`` with quadrupolar
    eigenvalues built from strength ``q0 * sigma**(5/2) * sqrt(epsilon)``.

    Args:
        displacement: Displacement function of the simulation space.
        lm: Index tuple(s) passed to ``patchy_interaction_general``.
        **cf_kwargs: Accepted for interface parity with the ellipsoid
            factories; ignored here.

    Returns:
        The energy function produced by ``oriented_pair``.
    """
    patchy_function = patchy_interaction.patchy_interaction_general(lm)

    def quadrupolar_wca_ellipsoid(
        dr: Array,
        eigsys1: Array,
        eigsys2: Array,
        epsilon: Array,
        d0: Array = 1,
        q0: Array = 1,
        alpha: Array = 1.0,
        sigma: Array = 1.0,
        lm_magnitudes: Array = 1.0,
        **unused_kwargs,
    ):
        # NOTE: the particle diameter is taken directly from ``sigma``
        # (an earlier variant used the unit-volume sphere diameter
        # 2 * cbrt(3 / (4 * pi))).
        sigma_particle = sigma
        distance = space.distance(dr)
        wca = weeks_chandler_andersen(
            distance, sigma=sigma_particle, epsilon=epsilon
        )
        # NOTE: the quadrupole-strength exponent was corrected from 5 to 5/2.
        quadrupolar = multipole_interaction.lin_quad_energy(
            dr,
            eigsys1,
            eigsys2,
            multipole_interaction.quadrupolar_eigenvalues(
                q0 * sigma_particle ** (5 / 2) * jnp.sqrt(epsilon), jnp.pi / 2
            ),
        )
        morse = jaxmd_energy.morse(
            distance, epsilon=d0 * epsilon, alpha=alpha, sigma=sigma_particle
        )
        patchy_value = patchy_function(dr, eigsys1, eigsys2, lm_magnitudes)
        return wca + quadrupolar + morse * patchy_value

    return oriented_pair(quadrupolar_wca_ellipsoid, displacement)


@dataclasses.dataclass
class FerroWcaParams:
    """Default parameters for ``ferro_wca_sphere_pair``."""

    # Eigenvalues of the particle shape (see ``init_unit_volume_particle``).
    eigvals: Array = dataclasses.field(default_factory=lambda: jnp.ones((3,)))
    # Energy scale of the WCA repulsion; also scales the Morse depth.
    epsilon: Array = 5.0
    # Relative Morse depth (the absolute depth is ``d0 * epsilon``).
    d0: Array = 1.5
    # The ferro orientational term is weighted by ``q0**2``.
    q0: Array = 2.0
    # Particle diameter / Morse minimum position.
    sigma: Array = 1.0
    # Morse width parameter.
    alpha: Array = 1.0
    # Magnitudes forwarded to the general patchy interaction.
    lm_magnitudes: Array = 1
    # Forwarded to ``multipole_interaction.ferro_orientational_energy``.
    softness: Array = 1.5

    def init_unit_volume_particle(self) -> FerroWcaParams:
        """Return a copy whose ``eigvals`` come from ``eigenvalues_at_unit_volume``.

        ``self`` is not modified.
        """
        # Copy: ``vars(self)`` is the live instance ``__dict__``; writing into
        # it directly would silently change ``self`` as well.
        params_dict = dict(vars(self))
        # NOTE(review): this always starts from isotropic eigenvalues
        # [1, 1, 1] and ignores the current ``self.eigvals`` — confirm intended.
        params_dict["eigvals"] = oriented_particle.eigenvalues_at_unit_volume(
            jnp.array([1.0, 1.0, 1.0])
        )
        return FerroWcaParams(**params_dict)

    def init_lm_magnitudes(self, lm_magnitudes: Array) -> FerroWcaParams:
        """Return a copy with ``lm_magnitudes`` replaced; ``self`` is unchanged."""
        params_dict = dict(vars(self))
        params_dict["lm_magnitudes"] = lm_magnitudes
        return FerroWcaParams(**params_dict)


def ferro_wca_sphere_pair(
    displacement: DisplacementFn, lm: tuple | list[tuple], **cf_kwargs
) -> Callable[..., Array]:
    """Sphere pair energy: WCA core + Morse times (patch + ferro orientation).

    All radial terms are evaluated on the centre-to-centre distance
    ``|dr|``. Per pair, the energy is
    ``WCA(r; sigma, epsilon) + Morse(r; D0=d0*epsilon, alpha, sigma) * (patchy + q0**2 * ferro)``.
    ``ferro`` is ``multipole_interaction.ferro_orientational_energy`` with the
    given ``softness``.

    Args:
        displacement: Displacement function of the simulation space.
        lm: Index tuple(s) passed to ``patchy_interaction_general``.
        **cf_kwargs: Accepted for interface parity with the ellipsoid
            factories; ignored here.

    Returns:
        The energy function produced by ``oriented_pair``.
    """
    patchy_function = patchy_interaction.patchy_interaction_general(lm)

    def ferro_wca_ellipsoid(
        dr: Array,
        eigsys1: Array,
        eigsys2: Array,
        epsilon: Array = 5.0,
        d0: Array = 1,
        q0: Array = 2,
        alpha: Array = 1.0,
        sigma: Array = 1.0,
        lm_magnitudes: Array = 1.0,
        softness: Array = 1.5,
        **unused_kwargs,
    ):
        # NOTE: the particle diameter is taken directly from ``sigma``
        # (an earlier variant used the unit-volume sphere diameter
        # 2 * cbrt(3 / (4 * pi))).
        sigma_particle = sigma
        distance = space.distance(dr)
        wca = weeks_chandler_andersen(
            distance, sigma=sigma_particle, epsilon=epsilon
        )
        ferro = multipole_interaction.ferro_orientational_energy(
            dr, eigsys1, eigsys2, softness=softness
        )
        morse = jaxmd_energy.morse(
            distance, epsilon=d0 * epsilon, alpha=alpha, sigma=sigma_particle
        )
        patchy_value = patchy_function(dr, eigsys1, eigsys2, lm_magnitudes)
        return wca + morse * (patchy_value + q0**2 * ferro)

    return oriented_pair(ferro_wca_ellipsoid, displacement)


def quadrupolar_wca_ellipsoid_pair(
    displacement: DisplacementFn,
    contact_fn: ContactFunction,
    lm: tuple | list[tuple],
    **cf_kwargs,
) -> Callable[..., Array]:
    """Ellipsoid pair energy: Lennard-Jones on the contact function + quadrupole.

    Per pair, the energy is ``LJ(cf; sigma=1, epsilon) + d1 * quad``.
    ``quad`` is ``multipole_interaction.quadrupolar_interaction`` with
    ``quadrupolar_eigenvalues(1.0)``. Despite the name, the core is the full
    Lennard-Jones potential, not its WCA-truncated form.

    Args:
        displacement: Displacement function of the simulation space.
        contact_fn: Contact function passed to
            ``oriented_particle.get_ellipsoid_contact_function_param``.
        lm: Index tuple(s) passed to ``patchy_interaction_general``.
        **cf_kwargs: Extra keyword arguments for the contact-function
            constructor.

    Returns:
        The energy function produced by ``oriented_pair``.
    """
    contact_function = oriented_particle.get_ellipsoid_contact_function_param(
        contact_fn, **cf_kwargs
    )
    # NOTE: currently unused because the Morse * patchy term is disabled
    # below; kept so that ``lm`` is still processed at construction time and
    # re-enabling the term is a one-line change.
    patchy_function = patchy_interaction.patchy_interaction_general(lm)

    def quadrupolar_wca_ellipsoid(
        dr: Array,
        eigsys1: Array,
        eigsys2: Array,
        eigvals: Array = jnp.array([1.0, 1.0, 1.0]),
        epsilon: Array = 5.0,
        d0: Array = 10,
        d1: Array = 10,
        alpha: Array = 1.0,
        sigma: Array = 1.0,
        lm_magnitudes: Array = 1.0,
    ):
        # NOTE: ``d0``, ``alpha``, ``sigma`` and ``lm_magnitudes`` are accepted
        # for interface parity but unused while the Morse * patchy attraction
        # (and the WCA-only core) are disabled in favour of full LJ.
        cf = contact_function(dr, eigsys1, eigsys2, eigvals=eigvals)
        vdw = jaxmd_energy.lennard_jones(cf, sigma=1.0, epsilon=epsilon)
        quadrupolar = multipole_interaction.quadrupolar_interaction(
            dr, eigsys1, eigsys2, multipole_interaction.quadrupolar_eigenvalues(1.0)
        )
        return vdw + d1 * quadrupolar

    return oriented_pair(quadrupolar_wca_ellipsoid, displacement)