123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
- from typing import Protocol, Callable, TypeVar
- import jax
- import jax.numpy as jnp
- from curvature_assembly import data_protocols
- from jax_md import rigid_body, energy, quantity
- from functools import partial
- Array = jnp.ndarray
- T = TypeVar('T')
@partial(jnp.vectorize, signature='(d,d),(d,d)->(d,d)')
def qf_from_rotation(rotation: Array, eigen_qf: Array) -> Array:
    """
    Express a quadratic form given in the particle eigenframe in the world frame.

    Evaluates the product ``rotation * eigen_qf * rotation^T``. Because of the ``jnp.vectorize``
    signature, both arguments may carry (broadcastable) leading batch dimensions.
    """
    factors = [rotation, eigen_qf, jnp.transpose(rotation)]
    return jnp.linalg.multi_dot(factors)
@partial(jnp.vectorize, signature='(d)->(d,d)')
def make_diagonal(eigvals: Array) -> Array:
    """
    Create a diagonal matrix from a 1D array of eigenvalues.

    Works for any vector length ``d`` (the previous implementation only accepted length 3).
    Leading batch dimensions are handled by ``jnp.vectorize``, so an input of shape
    ``(..., d)`` yields an output of shape ``(..., d, d)``.
    """
    # The vectorized core always receives a 1D vector, for which jnp.diag builds the matrix.
    return jnp.diag(eigvals)
def eigensystem(orientation: rigid_body.Quaternion) -> Array:
    """
    Get the eigensystem matrix (eigenvectors as columns) for the given orientation.

    The result is the space-to-body rotation with its last two axes exchanged, i.e. the
    transpose of every rotation matrix in the (possibly batched) input.
    """
    space_to_body = rigid_body.space_to_body_rotation(orientation)
    # Exchanging the two trailing axes transposes each matrix and leaves batch axes untouched.
    return jnp.swapaxes(space_to_body, -1, -2)
def matrix_repr(orientation: rigid_body.Quaternion, eigvals: Array) -> Array:
    """
    Quadratic form of an oriented particle in the world frame.

    Builds the diagonal matrix of ``eigvals`` and rotates it with the eigensystem obtained
    from the quaternion ``orientation``.
    """
    axes = eigensystem(orientation)
    diagonal_form = make_diagonal(eigvals)
    return qf_from_rotation(axes, diagonal_form)
def get_weight_matrices(orientation: rigid_body.Quaternion, eigvals: Array) -> Array:
    """
    Weight matrices of the rigid bodies, whose eigenvalues are the squared semi-axis lengths.

    The squared semi-axes are obtained as the reciprocals of ``eigvals``.
    """
    squared_semiaxes = 1 / eigvals
    return matrix_repr(orientation, squared_semiaxes)
@partial(jnp.vectorize, signature='(),(d)->(d)')
def ellipsoid_moment_of_inertia(m, eigvals):
    """
    Principal moments of inertia of an ellipsoid of mass ``m``.

    ``eigvals`` holds three values whose reciprocals are used as the squared semi-axes
    ``a^2, b^2, c^2``; the result is ``m / 5 * [b^2 + c^2, a^2 + c^2, a^2 + b^2]``.
    Leading batch dimensions of ``m`` and ``eigvals`` are broadcast by ``jnp.vectorize``.
    """
    sq_a, sq_b, sq_c = 1 / eigvals
    pair_sums = jnp.array([sq_b + sq_c, sq_a + sq_c, sq_a + sq_b])
    return (m / 5) * pair_sums
def ellipsoid_mass(masses, eigvals) -> rigid_body.RigidBody:
    """
    Build the ``RigidBody`` mass object for ellipsoidal particles.

    The first field holds ``masses`` and the second the ellipsoid moments of inertia
    computed from ``masses`` and ``eigvals``.
    """
    inertia = ellipsoid_moment_of_inertia(masses, eigvals)
    return rigid_body.RigidBody(masses, inertia)
def contact_to_distance_cutoff(cf_cut: float, eigvals: Array) -> float:
    """
    Calculate a sufficient distance cutoff from the contact function cutoff.

    The contact function is expected to be the square root of the Perram-Wertheim contact
    function. The cutoff scales with ``2 / sqrt(min(eigvals))``, i.e. twice the longest
    semi-axis when eigenvalues are inverse squared semi-axes.
    """
    distance_per_contact = 2 / jnp.sqrt(jnp.min(eigvals))
    return distance_per_contact * cf_cut
def contact_to_distance_threshold(cf_cut: float, cf_theshold: float, eigvals: Array) -> float:
    """
    Map a threshold in the contact function to a distance threshold.

    The result is the difference between the distance cutoffs at ``cf_cut`` and at
    ``cf_cut - cf_theshold``, i.e. the distance that corresponds to a contact function change
    of ``cf_theshold`` at the very edge of the function range.
    """
    outer_distance = contact_to_distance_cutoff(cf_cut, eigvals)
    inner_distance = contact_to_distance_cutoff(cf_cut - cf_theshold, eigvals)
    return outer_distance - inner_distance
def distance_to_contact_cutoff(r_cut: float, eigvals: Array) -> float:
    """
    Calculate a sufficient contact function cutoff from the distance cutoff.

    Contact function value returned corresponds to the square root of the Perram-Wertheim
    contact function. This is the inverse of the mapping ``r_cut = 2 / sqrt(min(eigvals)) * cf_cut``.

    Fix: the square root of the smallest eigenvalue is used. Eigenvalues are inverse squared
    lengths, so the previous ``min(eigvals) * r_cut / 2`` was dimensionally inconsistent and
    did not invert the contact-to-distance mapping.
    """
    inverse_longest_semiaxis = jnp.sqrt(jnp.min(eigvals))
    return inverse_longest_semiaxis * r_cut / 2
def eigenvalues_at_unit_volume(eigenvalues: Array) -> Array:
    """
    Rescale the eigenvalues so that they describe unit volume ellipsoids.

    The ellipsoid volume ``V = 4 pi / 3 * prod(1 / sqrt(eigenvalues))`` is evaluated along the
    last axis, and the eigenvalues are multiplied by ``V ** (2 / 3)``. Reducing over the last
    axis only (instead of over all elements) lets a batch of shape ``(..., 3)`` be rescaled
    particle by particle; a single 1D eigenvalue vector gives the same result as before.
    """
    semiaxes = 1 / jnp.sqrt(eigenvalues)
    # keepdims retains a trailing axis of size one so the scale broadcasts against each row
    particle_volume = 4 * jnp.pi / 3 * jnp.prod(semiaxes, axis=-1, keepdims=True)
    return jnp.cbrt(particle_volume) ** 2 * eigenvalues
def eigenvalues_to_semiaxes(eigenvalues: Array) -> Array:
    """
    Calculate the ellipsoid semi-axes from the eigenvalues.

    Each semi-axis is ``1 / sqrt(eigenvalue)``; the result is returned sorted in ascending order.
    """
    semiaxes = 1 / jnp.sqrt(eigenvalues)
    return jnp.sort(semiaxes)
def canonicalize_eigvals(interaction_params: T) -> T:
    """
    Return a new parameters object of the same type with the eigenvalues rescaled
    so that they correspond to unit volume ellipsoidal particles.

    All other attributes are carried over unchanged; the input object is not modified.
    """
    # A shallow copy of the attribute dict suffices: only the 'eigvals' entry is rebound.
    fields = dict(vars(interaction_params))
    fields['eigvals'] = eigenvalues_at_unit_volume(fields['eigvals'])
    params_cls = type(interaction_params)
    return params_cls(**fields)
def box_size_at_number_density(particle_count: int,
                               number_density: float,
                               spatial_dimension: int = 3):
    """
    Box size for ``particle_count`` particles at the given number density.

    Delegates to ``jax_md.quantity.box_size_at_number_density``; the only addition is the
    default of three spatial dimensions.
    """
    box_size = quantity.box_size_at_number_density(
        particle_count, number_density, spatial_dimension=spatial_dimension)
    return box_size
def ellipsoid_volume(eigvals: Array):
    """
    Volume of an ellipsoid, ``4 / 3 * pi / prod(sqrt(eigvals))``.

    The product runs over the last axis, so an input of shape ``(..., d)`` gives one volume
    per leading index.
    """
    sphere_prefactor = 4 / 3 * jnp.pi
    inverse_semiaxes_product = jnp.prod(jnp.sqrt(eigvals), axis=-1)
    return sphere_prefactor / inverse_semiaxes_product
def box_size_at_ellipsoid_density(particle_count: int,
                                  density: float,
                                  eigvals: Array):
    """
    Box side length at which the ellipsoids occupy the volume fraction ``density``.

    ``eigvals`` is either a single eigenvalue vector shared by all ``particle_count``
    particles or a 2D array with one row per particle. The total particle volume divided by
    ``density`` gives the box volume, whose ``d``-th root (``d`` being the length of the last
    axis of ``eigvals``) is returned.

    Raises:
        ValueError: if ``eigvals`` has more than two dimensions.
    """
    if eigvals.ndim > 2:
        raise ValueError("Eigenvalue matrix should have at most 2 dimensions.")

    spatial_dimension = eigvals.shape[-1]
    volumes = ellipsoid_volume(eigvals)
    if volumes.ndim == 0:
        # One shared eigenvalue vector: every particle has the same volume.
        volumes = jnp.full((particle_count,), volumes)

    occupied_volume = jnp.sum(volumes)
    box_volume = occupied_volume / density
    return jnp.power(box_volume, 1 / spatial_dimension)
@jax.jit
def update_interaction_params(grad: data_protocols.InteractionParams,
                              interaction_params: data_protocols.InteractionParams,
                              learning_rate: float) -> data_protocols.InteractionParams:
    """
    Perform one gradient descent step on the interaction parameters.

    Every leaf ``p`` of ``interaction_params`` is replaced by ``p - learning_rate * g`` with
    ``g`` the matching leaf of ``grad``. The new ellipsoid eigenvalues are then rescaled so
    that they correspond to unit volume particles.
    """
    def descend(param, param_grad):
        return param - learning_rate * param_grad

    stepped_params = jax.tree_util.tree_map(descend, interaction_params, grad)
    return canonicalize_eigvals(stepped_params)
class OrientedParticleEnergy(Protocol):
    """
    Structural type for energy functions acting between a pair of oriented particles.

    A conforming callable takes the displacement ``dr`` and the eigensystem matrices of both
    particles, plus arbitrary keyword arguments, and returns the energy as an array.
    """

    def __call__(self, dr: Array, eigsys1: Array, eigsys2: Array, **kwargs) -> Array:
        ...
def get_ellipsoid_contact_function(contact_function: Callable[..., Array], eigvals: Array, **cf_kwargs):
    """
    Return a pairwise contact function for ellipsoids with fixed eigenvalues.

    The returned function maps ``(dr, eigsys1, eigsys2)`` to
    ``contact_function(dr, qf1, qf2, **cf_kwargs)``, where ``qf1`` and ``qf2`` are the weight
    matrices (eigenvalues ``1 / eigvals``) rotated by the respective eigensystems.
    ``contact_function`` is expected to return the square root of the Perram-Wertheim
    contact function.
    """
    def fun(dr: Array, eigsys1: Array, eigsys2: Array) -> Array:
        # Both particles share the same diagonal weight matrix; only the orientation differs.
        weights = make_diagonal(1 / eigvals)
        first_qf = qf_from_rotation(eigsys1, weights)
        second_qf = qf_from_rotation(eigsys2, weights)
        return contact_function(dr, first_qf, second_qf, **cf_kwargs)

    return fun
def get_ellipsoid_contact_function_param(contact_function: Callable[..., Array], **cf_kwargs):
    """
    Return a contact function between a pair of ellipsoids with a standardized call signature
    ``(dr, eigsys1, eigsys2, eigvals)``.

    It also converts the standard quadratic form eigenvalues of the ellipsoids (inverse
    squares of the semi-axis lengths) to the weight matrices used in the Perram-Wertheim
    contact function (eigenvalues are the semi-axes squared, without the inverse).
    """
    def fun(dr: Array, eigsys1: Array, eigsys2: Array, eigvals: Array) -> Array:
        # Both particles share the same diagonal weight matrix; only the orientation differs.
        weights = make_diagonal(1 / eigvals)
        first_qf = qf_from_rotation(eigsys1, weights)
        second_qf = qf_from_rotation(eigsys2, weights)
        return contact_function(dr, first_qf, second_qf, **cf_kwargs)

    return fun
def isotropic_to_ellipsoid_energy(energy_fn: Callable[..., Array],
                                  contact_function: Callable[..., Array],
                                  eigvals: Array,
                                  **cf_kwargs) -> OrientedParticleEnergy:
    """
    Promote an isotropic energy function to one acting between ellipsoids.

    The contact function value between the two ellipsoids (with fixed ``eigvals``) replaces
    the distance as the argument of ``energy_fn``. Keyword arguments given to the returned
    function are forwarded to ``energy_fn``; ``cf_kwargs`` go to ``contact_function``.
    """
    cf = get_ellipsoid_contact_function(contact_function, eigvals, **cf_kwargs)

    def ellipsoid_energy_fn(dr: Array, eigsys1: Array, eigsys2: Array, **kwargs) -> Array:
        contact_value = cf(dr, eigsys1, eigsys2)
        return energy_fn(contact_value, **kwargs)

    return ellipsoid_energy_fn
def isotropic_to_cf_energy(energy_fn: Callable[..., Array],
                           contact_function: Callable[..., Array],
                           **cf_kwargs) -> OrientedParticleEnergy:
    """
    Promote an isotropic energy function to one acting between ellipsoids whose eigenvalues
    are supplied at call time.

    The returned function takes ``eigvals`` as a keyword argument, evaluates the contact
    function for these eigenvalues and passes its value to ``energy_fn`` together with the
    remaining keyword arguments. ``cf_kwargs`` are forwarded to ``contact_function``.

    Fix: the previous implementation called the parametrized contact function without its
    required ``eigvals`` argument and forwarded ``cf_kwargs`` to it a second time, which
    raised ``TypeError`` unless ``cf_kwargs`` consisted of exactly ``eigvals``. That usage
    keeps working: an ``eigvals`` entry of ``cf_kwargs`` serves as the default.

    Raises:
        TypeError: (from the returned function) if no eigenvalues are available.
    """
    cf = get_ellipsoid_contact_function_param(contact_function, **cf_kwargs)
    # Backward compatibility with construction-time eigenvalues (see docstring).
    default_eigvals = cf_kwargs.get('eigvals')

    def ellipsoid_energy_fn(dr: Array, eigsys1: Array, eigsys2: Array, **kwargs) -> Array:
        # eigvals parametrizes the contact function, so it must not reach energy_fn.
        eigvals = kwargs.pop('eigvals', default_eigvals)
        if eigvals is None:
            raise TypeError("ellipsoid_energy_fn() missing required keyword argument: 'eigvals'")
        return energy_fn(cf(dr, eigsys1, eigsys2, eigvals), **kwargs)

    return ellipsoid_energy_fn
def isotropic_to_ellipsoid_energy_with_cutoff(energy_fn: Callable[..., Array],
                                              contact_function: Callable[..., Array],
                                              eigvals: Array,
                                              cf_onset: float,
                                              cf_cutoff: float,
                                              **cf_kwargs) -> OrientedParticleEnergy:
    """
    Promote an isotropic energy function to one acting between ellipsoids,
    with a given contact function as a measure of distance.

    The energy is truncated with ``energy.multiplicative_isotropic_cutoff`` applied in
    contact function space, using ``cf_onset`` and ``cf_cutoff`` as onset and cutoff.
    Keyword arguments given to the returned function are forwarded to the truncated energy
    function; ``cf_kwargs`` go to ``contact_function``.
    """
    cf = get_ellipsoid_contact_function(contact_function, eigvals, **cf_kwargs)
    # Build the truncated energy once instead of re-wrapping energy_fn on every evaluation.
    truncated_energy_fn = energy.multiplicative_isotropic_cutoff(energy_fn, cf_onset, cf_cutoff)

    def ellipsoid_energy_fn(dr: Array, eigsys1: Array, eigsys2: Array, **kwargs) -> Array:
        return truncated_energy_fn(cf(dr, eigsys1, eigsys2), **kwargs)

    return ellipsoid_energy_fn
|