123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328 |
- from __future__ import annotations
- from typing import Callable
- import jax.numpy as jnp
- from curvature_assembly import (
- oriented_particle,
- data_protocols,
- patchy_interaction,
- multipole_interaction,
- )
- from jax_md import energy as jaxmd_energy
- from curvature_assembly.smap import oriented_pair
- from jax_md import partition, space, dataclasses
# Floating-point dtype shorthands.
f32 = jnp.float32
f64 = jnp.float64

# Type aliases used in the signatures of this module.
Array = jnp.ndarray
ContactFunction = Callable[..., Array]
DisplacementFn = space.DisplacementFn
NeighborListFormat = partition.NeighborListFormat
InteractionParams = data_protocols.InteractionParams
def weeks_chandler_andersen(
    dr: Array, sigma: Array = 1.0, epsilon: Array = 1.0, **unused_kwargs
) -> Array:
    """Evaluate the purely repulsive Weeks-Chandler-Andersen (WCA) potential.

    The Lennard-Jones energy is shifted up by ``epsilon`` and set to zero for
    separations at or beyond ``2**(1/6) * sigma``.

    Args:
        dr: Separation(s) at which the potential is evaluated.
        sigma: Length scale passed to the Lennard-Jones potential.
        epsilon: Energy scale passed to the Lennard-Jones potential; also the
            size of the upward shift.
        **unused_kwargs: Accepted and ignored.

    Returns:
        The WCA energy, with zeros wherever ``dr`` is not below the cutoff.
    """
    cutoff = jnp.power(2, 1 / 6) * sigma
    shifted_lj = jaxmd_energy.lennard_jones(dr, sigma=sigma, epsilon=epsilon) + epsilon
    return jnp.where(dr < cutoff, shifted_lj, 0.0)
@dataclasses.dataclass
class GbWcaParams:
    """Default parameter set for a Gaussian-band + WCA ellipsoid interaction.

    The field names match the keyword arguments of the energy functions built
    by ``gaussian_band_wca_ellipsoid_pair`` and
    ``gaussian_band_fh_wca_ellipsoid_pair``.
    NOTE(review): presumably instances are unpacked into those functions --
    confirm against the callers.
    """

    # Three values, all ones by default (passed as ``eigvals`` to the contact function).
    eigvals: Array = dataclasses.field(default_factory=lambda: jnp.ones((3,)))
    # Energy scale of the WCA term.
    epsilon: Array = 5.0
    # ``epsilon`` (depth) argument of the Morse term.
    d0: Array = 10.0
    # ``sigma`` argument of the Morse term.
    sigma: Array = 1.0
    # ``alpha`` argument of the Morse term.
    alpha: Array = 1.0
    # Band angle handed to the Gaussian band function (default pi / 2).
    band_theta: Array = jnp.pi / 2
    # Band width handed to the Gaussian band function.
    band_sigma: Array = 0.5
def gaussian_band_wca_ellipsoid_pair(
    displacement: DisplacementFn, contact_fn: ContactFunction, **cf_kwargs
) -> Callable[..., Array]:
    """Build a pair energy with a WCA core and a Gaussian-band-weighted Morse term.

    For each pair the energy is ``WCA(cf) + Morse(cf) * band``, where ``cf`` is
    the ellipsoid contact-function value and ``band`` is the value returned by
    ``patchy_interaction.gaussian_interaction_band``.

    Args:
        displacement: Displacement function handed to ``oriented_pair``.
        contact_fn: Contact function wrapped by
            ``oriented_particle.get_ellipsoid_contact_function_param``.
        **cf_kwargs: Extra keyword arguments forwarded to that wrapper.

    Returns:
        The energy function produced by ``oriented_pair``.
    """
    evaluate_contact = oriented_particle.get_ellipsoid_contact_function_param(
        contact_fn, **cf_kwargs
    )

    def patchy_wca_ellipsoid(
        dr: Array,
        eigsys1: Array,
        eigsys2: Array,
        eigvals: Array = jnp.array([1.0, 1.0, 1.0]),
        epsilon: Array = 5.0,
        d0: Array = 10,
        alpha: Array = 1.0,
        sigma: Array = 1.0,
        band_theta: Array = jnp.pi / 2,
        band_sigma: Array = 0.5,
    ) -> Array:
        """Energy of a single pair of oriented ellipsoids."""
        contact = evaluate_contact(dr, eigsys1, eigsys2, eigvals=eigvals)
        # The WCA core always uses sigma=1.0; the ``sigma`` argument only
        # enters the Morse term.
        repulsion = weeks_chandler_andersen(contact, sigma=1.0, epsilon=epsilon)
        attraction = jaxmd_energy.morse(contact, epsilon=d0, alpha=alpha, sigma=sigma)
        band_weight = patchy_interaction.gaussian_interaction_band(
            dr, eigsys1, eigsys2, band_theta, band_sigma
        )
        return repulsion + attraction * band_weight

    return oriented_pair(patchy_wca_ellipsoid, displacement)
def gaussian_band_fh_wca_ellipsoid_pair(
    displacement: DisplacementFn, contact_fn: ContactFunction, **cf_kwargs
) -> Callable[..., Array]:
    """Build a pair energy with a WCA core and a fixed-height Gaussian-band Morse term.

    Identical in structure to ``gaussian_band_wca_ellipsoid_pair`` except that
    the Morse term is weighted by
    ``patchy_interaction.gaussian_interaction_band_fixed_height``.

    Args:
        displacement: Displacement function handed to ``oriented_pair``.
        contact_fn: Contact function wrapped by
            ``oriented_particle.get_ellipsoid_contact_function_param``.
        **cf_kwargs: Extra keyword arguments forwarded to that wrapper.

    Returns:
        The energy function produced by ``oriented_pair``.
    """
    evaluate_contact = oriented_particle.get_ellipsoid_contact_function_param(
        contact_fn, **cf_kwargs
    )

    def patchy_wca_ellipsoid(
        dr: Array,
        eigsys1: Array,
        eigsys2: Array,
        eigvals: Array = jnp.array([1.0, 1.0, 1.0]),
        epsilon: Array = 5.0,
        d0: Array = 10,
        alpha: Array = 1.0,
        sigma: Array = 1.0,
        band_theta: Array = jnp.pi / 2,
        band_sigma: Array = 0.5,
    ) -> Array:
        """Energy of a single pair of oriented ellipsoids."""
        contact = evaluate_contact(dr, eigsys1, eigsys2, eigvals=eigvals)
        # The WCA core always uses sigma=1.0; the ``sigma`` argument only
        # enters the Morse term.
        repulsion = weeks_chandler_andersen(contact, sigma=1.0, epsilon=epsilon)
        attraction = jaxmd_energy.morse(contact, epsilon=d0, alpha=alpha, sigma=sigma)
        band_weight = patchy_interaction.gaussian_interaction_band_fixed_height(
            dr, eigsys1, eigsys2, band_theta, band_sigma
        )
        return repulsion + attraction * band_weight

    return oriented_pair(patchy_wca_ellipsoid, displacement)
@dataclasses.dataclass
class PatchyWcaParams:
    """Default parameter set for a patchy + WCA ellipsoid interaction.

    The field names match the keyword arguments of the energy function built
    by ``patchy_wca_ellipsoid_pair``.
    NOTE(review): presumably instances are unpacked into that function --
    confirm against the callers.
    """

    # Three values, all ones by default (passed as ``eigvals`` to the contact function).
    eigvals: Array = dataclasses.field(default_factory=lambda: jnp.ones((3,)))
    # Energy scale of the WCA term.
    epsilon: Array = 5.0
    # ``epsilon`` (depth) argument of the Morse term.
    d0: Array = 10.0
    # ``sigma`` argument of the Morse term.
    sigma: Array = 1.0
    # ``alpha`` argument of the Morse term.
    alpha: Array = 1.0
    # Magnitudes handed to the general patchy-interaction function.
    lm_magnitudes: Array = 1
def patchy_wca_ellipsoid_pair(
    displacement: DisplacementFn,
    contact_fn: ContactFunction,
    lm: tuple | list[tuple],
    **cf_kwargs,
) -> Callable[..., Array]:
    """Build a pair energy with a WCA core and a patchy-weighted Morse term.

    For each pair the energy is ``WCA(cf) + Morse(cf) * patchy``, where ``cf``
    is the ellipsoid contact-function value and ``patchy`` comes from the
    function returned by ``patchy_interaction.patchy_interaction_general(lm)``.

    Args:
        displacement: Displacement function handed to ``oriented_pair``.
        contact_fn: Contact function wrapped by
            ``oriented_particle.get_ellipsoid_contact_function_param``.
        lm: Passed unchanged to ``patchy_interaction.patchy_interaction_general``.
        **cf_kwargs: Extra keyword arguments forwarded to the contact-function
            wrapper.

    Returns:
        The energy function produced by ``oriented_pair``.
    """
    evaluate_contact = oriented_particle.get_ellipsoid_contact_function_param(
        contact_fn, **cf_kwargs
    )
    evaluate_patchy = patchy_interaction.patchy_interaction_general(lm)

    def patchy_wca_ellipsoid(
        dr: Array,
        eigsys1: Array,
        eigsys2: Array,
        eigvals: Array = jnp.array([1.0, 1.0, 1.0]),
        epsilon: Array = 5.0,
        d0: Array = 10,
        alpha: Array = 1.0,
        sigma: Array = 1.0,
        lm_magnitudes: Array = 1.0,
    ):
        """Energy of a single pair of oriented ellipsoids."""
        contact = evaluate_contact(dr, eigsys1, eigsys2, eigvals=eigvals)
        # The WCA core always uses sigma=1.0; the ``sigma`` argument only
        # enters the Morse term.
        repulsion = weeks_chandler_andersen(contact, sigma=1.0, epsilon=epsilon)
        attraction = jaxmd_energy.morse(contact, epsilon=d0, alpha=alpha, sigma=sigma)
        patch_weight = evaluate_patchy(dr, eigsys1, eigsys2, lm_magnitudes)
        return repulsion + attraction * patch_weight

    return oriented_pair(patchy_wca_ellipsoid, displacement)
@dataclasses.dataclass
class QuadWcaParams:
    """Default parameter set for a quadrupolar + WCA interaction.

    The field names match the keyword arguments of the energy function built
    by ``quadrupolar_wca_sphere_pair``.
    NOTE(review): presumably instances are unpacked into that function --
    confirm against the callers.
    """

    # Three values, all ones by default.
    eigvals: Array = dataclasses.field(default_factory=lambda: jnp.ones((3,)))
    # Energy scale of the WCA term.
    epsilon: Array = 2.0
    # Multiplier of ``epsilon`` giving the Morse depth.
    d0: Array = 10.0
    # Prefactor of the quadrupolar strength.
    q0: Array = 10.0
    # Length scale shared by the WCA and Morse terms.
    sigma: Array = 1.0
    # ``alpha`` argument of the Morse term.
    alpha: Array = 1.0
    # Magnitudes handed to the general patchy-interaction function.
    lm_magnitudes: Array = 1

    def init_unit_volume_particle(self) -> QuadWcaParams:
        """Return a copy whose ``eigvals`` are the unit-volume eigenvalues.

        The eigenvalues are computed by
        ``oriented_particle.eigenvalues_at_unit_volume`` from ``[1, 1, 1]``.
        ``self`` is left untouched.
        """
        unit_eigvals = oriented_particle.eigenvalues_at_unit_volume(
            jnp.array([1.0, 1.0, 1.0])
        )
        # ``replace`` builds a new instance. Writing into ``vars(self)`` would
        # mutate this instance in place, because ``vars`` returns the live
        # ``__dict__``.
        return dataclasses.replace(self, eigvals=unit_eigvals)

    def init_lm_magnitudes(self, lm_magnitudes: Array) -> QuadWcaParams:
        """Return a copy with ``lm_magnitudes`` replaced; ``self`` is left untouched.

        Args:
            lm_magnitudes: New value for the ``lm_magnitudes`` field.
        """
        return dataclasses.replace(self, lm_magnitudes=lm_magnitudes)
def quadrupolar_wca_sphere_pair(
    displacement: DisplacementFn, lm: tuple | list[tuple], **cf_kwargs
) -> Callable[..., Array]:
    """Build a pair energy for spheres: WCA core, quadrupolar term, patchy Morse term.

    For each pair the energy is ``WCA(r) + quadrupolar + Morse(r) * patchy``,
    with ``r = space.distance(dr)``.

    Args:
        displacement: Displacement function handed to ``oriented_pair``.
        lm: Passed unchanged to ``patchy_interaction.patchy_interaction_general``.
        **cf_kwargs: Accepted for signature parity with the ellipsoid builders;
            not used here.

    Returns:
        The energy function produced by ``oriented_pair``.
    """
    evaluate_patchy = patchy_interaction.patchy_interaction_general(lm)

    def quadrupolar_wca_ellipsoid(
        dr: Array,
        eigsys1: Array,
        eigsys2: Array,
        epsilon: Array,
        d0: Array = 1,
        q0: Array = 1,
        alpha: Array = 1.0,
        sigma: Array = 1.0,
        lm_magnitudes: Array = 1.0,
        **unused_kwargs,
    ):
        """Energy of a single pair of oriented spheres."""
        # NOTE: the particle diameter is taken directly from ``sigma``; the
        # unit-volume value 2 * cbrt(3 / (4 * pi)) is not used.
        sigma_particle = sigma
        distance = space.distance(dr)

        repulsion = weeks_chandler_andersen(
            distance, sigma=sigma_particle, epsilon=epsilon
        )

        # NOTE: in the quadrupolar eigenvalues calculation, the exponent was
        # corrected from 5 to 5/2.
        quad_strength = q0 * sigma_particle ** (5 / 2) * jnp.sqrt(epsilon)
        quadrupolar = multipole_interaction.lin_quad_energy(
            dr,
            eigsys1,
            eigsys2,
            multipole_interaction.quadrupolar_eigenvalues(quad_strength, jnp.pi / 2),
        )

        attraction = jaxmd_energy.morse(
            distance, epsilon=d0 * epsilon, alpha=alpha, sigma=sigma_particle
        )
        patch_weight = evaluate_patchy(dr, eigsys1, eigsys2, lm_magnitudes)
        return repulsion + quadrupolar + attraction * patch_weight

    return oriented_pair(quadrupolar_wca_ellipsoid, displacement)
@dataclasses.dataclass
class FerroWcaParams:
    """Default parameter set for a ferro-orientational + WCA interaction.

    The field names match the keyword arguments of the energy function built
    by ``ferro_wca_sphere_pair``.
    NOTE(review): presumably instances are unpacked into that function --
    confirm against the callers.
    """

    # Three values, all ones by default.
    eigvals: Array = dataclasses.field(default_factory=lambda: jnp.ones((3,)))
    # Energy scale of the WCA term.
    epsilon: Array = 5.0
    # Multiplier of ``epsilon`` giving the Morse depth.
    d0: Array = 1.5
    # Enters the energy as ``q0**2`` multiplying the ferro term.
    q0: Array = 2.0
    # Length scale shared by the WCA and Morse terms.
    sigma: Array = 1.0
    # ``alpha`` argument of the Morse term.
    alpha: Array = 1.0
    # Magnitudes handed to the general patchy-interaction function.
    lm_magnitudes: Array = 1
    # ``softness`` argument of the ferro orientational energy.
    softness: Array = 1.5

    def init_unit_volume_particle(self) -> FerroWcaParams:
        """Return a copy whose ``eigvals`` are the unit-volume eigenvalues.

        The eigenvalues are computed by
        ``oriented_particle.eigenvalues_at_unit_volume`` from ``[1, 1, 1]``.
        ``self`` is left untouched.
        """
        unit_eigvals = oriented_particle.eigenvalues_at_unit_volume(
            jnp.array([1.0, 1.0, 1.0])
        )
        # ``replace`` builds a new instance. Writing into ``vars(self)`` would
        # mutate this instance in place, because ``vars`` returns the live
        # ``__dict__``.
        return dataclasses.replace(self, eigvals=unit_eigvals)

    def init_lm_magnitudes(self, lm_magnitudes: Array) -> FerroWcaParams:
        """Return a copy with ``lm_magnitudes`` replaced; ``self`` is left untouched.

        Args:
            lm_magnitudes: New value for the ``lm_magnitudes`` field.
        """
        return dataclasses.replace(self, lm_magnitudes=lm_magnitudes)
def ferro_wca_sphere_pair(
    displacement: DisplacementFn, lm: tuple | list[tuple], **cf_kwargs
) -> Callable[..., Array]:
    """Build a pair energy for spheres: WCA core plus orientation-weighted Morse term.

    For each pair the energy is ``WCA(r) + Morse(r) * (patchy + q0**2 * ferro)``,
    with ``r = space.distance(dr)``.

    Args:
        displacement: Displacement function handed to ``oriented_pair``.
        lm: Passed unchanged to ``patchy_interaction.patchy_interaction_general``.
        **cf_kwargs: Accepted for signature parity with the ellipsoid builders;
            not used here.

    Returns:
        The energy function produced by ``oriented_pair``.
    """
    evaluate_patchy = patchy_interaction.patchy_interaction_general(lm)

    def ferro_wca_ellipsoid(
        dr: Array,
        eigsys1: Array,
        eigsys2: Array,
        epsilon: Array = 5.0,
        d0: Array = 1,
        q0: Array = 2,
        alpha: Array = 1.0,
        sigma: Array = 1.0,
        lm_magnitudes: Array = 1.0,
        softness: Array = 1.5,
        **unused_kwargs,
    ):
        """Energy of a single pair of oriented spheres."""
        # NOTE: the particle diameter is taken directly from ``sigma``; the
        # unit-volume value 2 * cbrt(3 / (4 * pi)) is not used.
        sigma_particle = sigma
        distance = space.distance(dr)

        repulsion = weeks_chandler_andersen(
            distance, sigma=sigma_particle, epsilon=epsilon
        )
        ferro = multipole_interaction.ferro_orientational_energy(
            dr, eigsys1, eigsys2, softness=softness
        )
        attraction = jaxmd_energy.morse(
            distance, epsilon=d0 * epsilon, alpha=alpha, sigma=sigma_particle
        )
        patch_weight = evaluate_patchy(dr, eigsys1, eigsys2, lm_magnitudes)
        return repulsion + attraction * (patch_weight + q0**2 * ferro)

    return oriented_pair(ferro_wca_ellipsoid, displacement)
def quadrupolar_wca_ellipsoid_pair(
    displacement: DisplacementFn,
    contact_fn: ContactFunction,
    lm: tuple | list[tuple],
    **cf_kwargs,
) -> Callable[..., Array]:
    """Build a pair energy for ellipsoids: Lennard-Jones plus a quadrupolar term.

    For each pair the energy is ``LJ(cf) + d1 * quadrupolar``, where ``cf`` is
    the ellipsoid contact-function value. Despite the name, the full
    Lennard-Jones potential is used here rather than its WCA truncation, and
    no Morse or patchy contribution is added.

    Args:
        displacement: Displacement function handed to ``oriented_pair``.
        contact_fn: Contact function wrapped by
            ``oriented_particle.get_ellipsoid_contact_function_param``.
        lm: Passed to ``patchy_interaction.patchy_interaction_general``; the
            resulting function is currently not used in the energy.
        **cf_kwargs: Extra keyword arguments forwarded to the contact-function
            wrapper.

    Returns:
        The energy function produced by ``oriented_pair``.
    """
    evaluate_contact = oriented_particle.get_ellipsoid_contact_function_param(
        contact_fn, **cf_kwargs
    )
    # The patchy function is still constructed (as before) although the energy
    # below does not use it.
    patchy_interaction.patchy_interaction_general(lm)

    def quadrupolar_wca_ellipsoid(
        dr: Array,
        eigsys1: Array,
        eigsys2: Array,
        eigvals: Array = jnp.array([1.0, 1.0, 1.0]),
        epsilon: Array = 5.0,
        d0: Array = 10,
        d1: Array = 10,
        alpha: Array = 1.0,
        sigma: Array = 1.0,
        lm_magnitudes: Array = 1.0,
    ):
        """Energy of a single pair of oriented ellipsoids.

        ``d0``, ``alpha``, ``sigma`` and ``lm_magnitudes`` are accepted but do
        not enter the result.
        """
        contact = evaluate_contact(dr, eigsys1, eigsys2, eigvals=eigvals)
        lj_energy = jaxmd_energy.lennard_jones(contact, sigma=1.0, epsilon=epsilon)
        quad_eigenvalues = multipole_interaction.quadrupolar_eigenvalues(1.0)
        quadrupolar = multipole_interaction.quadrupolar_interaction(
            dr, eigsys1, eigsys2, quad_eigenvalues
        )
        return lj_energy + d1 * quadrupolar

    return oriented_pair(quadrupolar_wca_ellipsoid, displacement)
|