from typing import Protocol, Callable, TypeVar
import jax
import jax.numpy as jnp
from curvature_assembly import data_protocols
from jax_md import rigid_body, energy, quantity
from functools import partial

Array = jnp.ndarray
T = TypeVar('T')


@partial(jnp.vectorize, signature='(d,d),(d,d)->(d,d)')
def qf_from_rotation(rotation: Array, eigen_qf: Array) -> Array:
    """
    Get particle quadratic form in world frame given the rotation matrix that describes eigensystem orientation.

    Computes ``rotation @ eigen_qf @ rotation.T``; leading batch dimensions are broadcast by ``jnp.vectorize``.
    """
    return jnp.linalg.multi_dot((rotation, eigen_qf, jnp.transpose(rotation)))


@partial(jnp.vectorize, signature='(d)->(d,d)')
def make_diagonal(eigvals: Array) -> Array:
    """
    Create a diagonal matrix from a 1D array of eigenvalues.

    Works for any length ``d``; leading batch dimensions are broadcast by ``jnp.vectorize``.
    """
    return jnp.diag(eigvals)


def eigensystem(orientation: rigid_body.Quaternion) -> Array:
    """Get eigensystem matrix with eigenvectors as columns (transpose of the space-to-body rotation)."""
    return jnp.moveaxis(rigid_body.space_to_body_rotation(orientation), -1, -2)


def matrix_repr(orientation: rigid_body.Quaternion, eigvals: Array) -> Array:
    """Quadratic form of the oriented particle given the matrix eigenvalues and quaternion orientation."""
    return qf_from_rotation(eigensystem(orientation), make_diagonal(eigvals))


def get_weight_matrices(orientation: rigid_body.Quaternion, eigvals: Array) -> Array:
    """
    Weight matrices of the rigid body with squared semi-axes lengths as matrix eigenvalues.

    ``eigvals`` are inverse squared semiaxes, so their reciprocals are the squared semiaxes.
    """
    return matrix_repr(orientation, 1 / eigvals)


@partial(jnp.vectorize, signature='(),(d)->(d)')
def ellipsoid_moment_of_inertia(m: Array, eigvals: Array) -> Array:
    """
    Principal moments of inertia of a solid ellipsoid.

    With squared semiaxes ``a2, b2, c2 = 1 / eigvals``, returns ``m / 5 * (b2 + c2, a2 + c2, a2 + b2)``.
    Only three-dimensional eigenvalue vectors are supported.
    """
    eig1, eig2, eig3 = eigvals
    a2 = 1 / eig1
    b2 = 1 / eig2
    c2 = 1 / eig3
    return m / 5 * jnp.array([b2 + c2, a2 + c2, a2 + b2])


def ellipsoid_mass(masses, eigvals) -> rigid_body.RigidBody:
    """Get an Ellipsoid with the mass and moment of inertia for each particle."""
    return rigid_body.RigidBody(masses, ellipsoid_moment_of_inertia(masses, eigvals))


def contact_to_distance_cutoff(cf_cut: float, eigvals: Array) -> float:
    """
    Calculate a sufficient distance cutoff from the contact function cutoff.

    Contact function should be the square root of the Perram-Wertheim contact function.
    The result is ``cf_cut`` times twice the largest semiaxis, ``1 / sqrt(min(eigvals))``.
    """
    return 2 / jnp.sqrt(jnp.min(eigvals)) * cf_cut


def contact_to_distance_threshold(cf_cut: float, cf_theshold: float, eigvals: Array) -> float:
    """
    Map a threshold in the contact function to a distance threshold.

    We take the distance change that corresponds to a contact-function change of ``cf_theshold``
    at the very edge of the function range (just below ``cf_cut``).
    """
    # NOTE: the parameter name keeps its historical spelling so keyword callers keep working.
    return contact_to_distance_cutoff(cf_cut, eigvals) - contact_to_distance_cutoff(cf_cut - cf_theshold, eigvals)


def distance_to_contact_cutoff(r_cut: float, eigvals: Array) -> float:
    """
    Calculate a sufficient contact function cutoff from the distance cutoff.

    Contact function value returned corresponds to the square root of the Perram-Wertheim contact function.
    This is the inverse of ``contact_to_distance_cutoff``: ``r_cut * sqrt(min(eigvals)) / 2``.
    """
    # eigvals are inverse squared lengths, so the square root is needed to get a dimensionless value.
    return jnp.sqrt(jnp.min(eigvals)) * r_cut / 2


def eigenvalues_at_unit_volume(eigenvalues: Array) -> Array:
    """
    Rescale the eigenvalues to get unit volume ellipsoids.

    The volume is computed along the last axis, so a batch of eigenvalue vectors (e.g. one row per
    particle or species) is rescaled row by row.
    """
    particle_volume = 4 * jnp.pi / 3 * jnp.prod(1 / jnp.sqrt(eigenvalues), axis=-1, keepdims=True)
    # Scaling eigenvalues by s scales the volume by s**(-3/2), so s = V**(2/3) gives unit volume.
    return jnp.cbrt(particle_volume) ** 2 * eigenvalues


def eigenvalues_to_semiaxes(eigenvalues: Array) -> Array:
    """Calculate ellipsoid semiaxes from eigenvalues, sorted ascending along the last axis."""
    return jnp.sort(1 / jnp.sqrt(eigenvalues))


def canonicalize_eigvals(interaction_params: T) -> T:
    """
    Create a new InteractionParams instance with transformed eigenvalues
    so that they correspond to unit volume ellipsoidal particles.

    The instance must expose its fields through ``vars()`` and its type must accept them as keyword arguments.
    """
    params_dict = vars(interaction_params)
    new_dict = params_dict.copy()  # shallow copy is enough as values (interaction_params elements) are jax arrays
    new_dict['eigvals'] = eigenvalues_at_unit_volume(params_dict['eigvals'])
    return type(interaction_params)(**new_dict)


def box_size_at_number_density(particle_count: int,
                               number_density: float,
                               spatial_dimension: int = 3):
    """Box size for a given particle number density (thin wrapper around jax_md's ``quantity`` helper)."""
    return quantity.box_size_at_number_density(particle_count,
                                               number_density,
                                               spatial_dimension=spatial_dimension)


def ellipsoid_volume(eigvals: Array) -> Array:
    """Ellipsoid volume ``4/3 * pi * prod(semiaxes)`` with semiaxes ``1 / sqrt(eigvals)`` along the last axis."""
    return 4 / 3 * jnp.pi / jnp.prod(jnp.sqrt(eigvals), axis=-1)


def box_size_at_ellipsoid_density(particle_count: int,
                                  density: float,
                                  eigvals: Array):
    """
    Box side length such that the ellipsoids occupy the given volume fraction ``density``.

    A 1D ``eigvals`` is treated as shared by all ``particle_count`` particles; a 2D ``eigvals`` has its
    per-row volumes summed (presumably one row per particle — confirm against callers).

    Raises:
        ValueError: if ``eigvals`` has more than two dimensions.
    """
    if eigvals.ndim > 2:
        raise ValueError("Eigenvalue matrix should have at most 2 dimensions.")
    spatial_dimension = eigvals.shape[-1]
    particle_volume = ellipsoid_volume(eigvals)
    if particle_volume.ndim == 0:
        particle_volume = jnp.full((particle_count,), particle_volume)
    total_particle_volume = jnp.sum(particle_volume)
    return jnp.power(total_particle_volume / density, 1 / spatial_dimension)


@jax.jit
def update_interaction_params(grad: data_protocols.InteractionParams,
                              interaction_params: data_protocols.InteractionParams,
                              learning_rate: float) -> data_protocols.InteractionParams:
    """
    Update interaction parameters with gradient descent step. Rescales the new ellipsoid eigenvalues
    so that they correspond to unit volume particles.
    """
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, interaction_params, grad)
    return canonicalize_eigvals(new_params)


class OrientedParticleEnergy(Protocol):
    """Protocol specifying the signature for energy functions between oriented particles."""

    def __call__(self, dr: Array, eigsys1: Array, eigsys2: Array, **kwargs) -> Array:
        ...


def get_ellipsoid_contact_function(contact_function: Callable[..., Array], eigvals: Array, **cf_kwargs):
    """
    Return a function that calculates square root of the Perram-Wertheim contact function between a pair of ellipsoids.

    The weight-matrix eigenvalues passed to the contact function are ``1 / eigvals`` (squared semiaxes).
    """
    # eigvals are fixed by the closure, so the diagonal weight matrix is built once.
    weight_diag = make_diagonal(1 / eigvals)

    def fun(dr: Array, eigsys1: Array, eigsys2: Array) -> Array:
        qf1 = qf_from_rotation(eigsys1, weight_diag)
        qf2 = qf_from_rotation(eigsys2, weight_diag)
        return contact_function(dr, qf1, qf2, **cf_kwargs)

    return fun


def get_ellipsoid_contact_function_param(contact_function: Callable[..., Array], **cf_kwargs):
    """
    Return a function that calculates the contact function between a pair of ellipsoids with a standardized call
    signature. It also does the transform from the standard quadratic form eigenvalues for ellipsoids (where
    eigenvalues are inverse squares of semiaxis lengths) to the weight matrix used in the Perram-Wertheim contact
    function (eigenvalues are just semiaxes squared, without the inverse).
    """

    def fun(dr: Array, eigsys1: Array, eigsys2: Array, eigvals: Array) -> Array:
        qf1 = qf_from_rotation(eigsys1, make_diagonal(1 / eigvals))
        qf2 = qf_from_rotation(eigsys2, make_diagonal(1 / eigvals))
        return contact_function(dr, qf1, qf2, **cf_kwargs)

    return fun


def isotropic_to_ellipsoid_energy(energy_fn: Callable[..., Array],
                                  contact_function: Callable[..., Array],
                                  eigvals: Array,
                                  **cf_kwargs) -> OrientedParticleEnergy:
    """Promotes an isotropic energy function to one acting between ellipsoids,
    with a given contact function as a measure of distance."""
    cf = get_ellipsoid_contact_function(contact_function, eigvals, **cf_kwargs)

    def ellipsoid_energy_fn(dr: Array, eigsys1: Array, eigsys2: Array, **kwargs) -> Array:
        return energy_fn(cf(dr, eigsys1, eigsys2), **kwargs)

    return ellipsoid_energy_fn


def isotropic_to_cf_energy(energy_fn: Callable[..., Array],
                           contact_function: Callable[..., Array],
                           **cf_kwargs) -> OrientedParticleEnergy:
    """
    Promotes an isotropic energy function to one acting between ellipsoids,
    with a given contact function as a measure of distance.

    Unlike ``isotropic_to_ellipsoid_energy``, the ellipsoid eigenvalues are not fixed here but supplied on every
    call of the returned function as ``eigvals``; all remaining keyword arguments go to ``energy_fn``.
    """
    cf = get_ellipsoid_contact_function_param(contact_function, **cf_kwargs)

    def ellipsoid_energy_fn(dr: Array, eigsys1: Array, eigsys2: Array, eigvals: Array, **kwargs) -> Array:
        # cf_kwargs are already bound inside `cf`; it only needs the per-call eigenvalues.
        return energy_fn(cf(dr, eigsys1, eigsys2, eigvals), **kwargs)

    return ellipsoid_energy_fn


def isotropic_to_ellipsoid_energy_with_cutoff(energy_fn: Callable[..., Array],
                                              contact_function: Callable[..., Array],
                                              eigvals: Array,
                                              cf_onset: float,
                                              cf_cutoff: float,
                                              **cf_kwargs) -> OrientedParticleEnergy:
    """
    Promotes an isotropic energy function to one acting between ellipsoids,
    with a given contact function as a measure of distance.
    Adds the multiplicative isotropic cutoff to get a truncated function.
    """
    cf = get_ellipsoid_contact_function(contact_function, eigvals, **cf_kwargs)
    # The truncated isotropic energy does not depend on the call arguments, so build it once.
    truncated_energy_fn = energy.multiplicative_isotropic_cutoff(energy_fn, cf_onset, cf_cutoff)

    def ellipsoid_energy_fn(dr: Array, eigsys1: Array, eigsys2: Array, **kwargs) -> Array:
        return truncated_energy_fn(cf(dr, eigsys1, eigsys2), **kwargs)

    return ellipsoid_energy_fn