from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp
from jax_md import space, smap, util, partition, rigid_body

from curvature_assembly.oriented_particle import OrientedParticleEnergy, eigensystem

Array = jnp.ndarray


def oriented_pair(fn: OrientedParticleEnergy,
                  displacement: space.DisplacementFn,
                  ignore_unused_parameters: bool = False,
                  **kwargs) -> Callable[..., Array]:
    """
    Promotes a function that acts on a pair of oriented particles (ellipses) to one on a system.

    Every unordered pair of particles (i < j) is evaluated exactly once, so the
    summed energy needs no double-counting correction.

    Args:
        fn: energy function that takes distance, eigensystem1, eigensystem2 as first three arguments.
        displacement: displacement function that calculates distances between particles.
        ignore_unused_parameters: A boolean that denotes whether dynamically specified keyword arguments
            passed to the mapped function get ignored if they were not first specified as keyword arguments
            when calling `oriented_pair(...)`.
        kwargs: arguments providing parameters to the mapped function. They are passed unchanged to every
            pair evaluation; per-particle parameter arrays are NOT combined into per-pair values here.

    Return:
        A function fn_mapped that takes a RigidBody object (plus optional dynamic keyword arguments) and
        returns the energy summed over all particle pairs.
    """
    # NOTE(review): param_combinators is unused because per-pair parameter conversion
    # (smap._kwargs_to_parameters) is disabled in this function; kept for API symmetry.
    kwargs, param_combinators = smap._split_params_and_combinators(kwargs)
    merge_dicts = partial(util.merge_dicts, ignore_unused_parameters=ignore_unused_parameters)

    def fn_mapped(body: rigid_body.RigidBody, **dynamic_kwargs) -> Array:
        # Indices of all unordered pairs (i, j) with i < j.
        rows, columns = jnp.triu_indices(body.center.shape[0], 1)
        particle1 = body[rows]
        particle2 = body[columns]
        # Dynamic kwargs (e.g. a box) are forwarded to the displacement function as well.
        dr = jax.vmap(partial(displacement, **dynamic_kwargs))(particle1.center, particle2.center)
        eigsys1 = eigensystem(particle1.orientation)
        eigsys2 = eigensystem(particle2.orientation)
        merged_kwargs = merge_dicts(kwargs, dynamic_kwargs)
        pair_energies = jax.vmap(partial(fn, **merged_kwargs))(dr, eigsys1, eigsys2)
        return util.high_precision_sum(pair_energies)

    # NOTE(review): a dense variant (space.map_product + jnp.meshgrid) used to live here and
    # disagreed with this implementation. It indexed orientations with jnp.meshgrid's default
    # 'xy' indexing, which pairs dr[i, j] with orientations (j, i), and it did not vmap `fn`.
    return fn_mapped


def oriented_pair_neighbor_list(fn: OrientedParticleEnergy,
                                displacement: space.DisplacementFn,
                                ignore_unused_parameters: bool = False,
                                **kwargs) -> Callable[..., Array]:
    """
    Promotes a function acting on pairs of particles to use neighbor lists.

    Args:
        fn: energy function that takes distance, eigensystem1, eigensystem2 as first three arguments.
        displacement: displacement function that calculates distances between particles.
        ignore_unused_parameters: A boolean that denotes whether dynamically specified keyword arguments
            passed to the mapped function get ignored if they were not first specified as keyword arguments
            when calling `oriented_pair_neighbor_list(...)`.
        kwargs: arguments providing parameters to the mapped function. Array-valued parameters are
            converted to per-pair values with `smap._neighborhood_kwargs_to_params` (using the given
            combinators) and mapped together with the pair geometry; scalar parameters are shared.

    Return:
        A function `fn_mapped` that takes a RigidBody object and a NeighborList object specifying neighbors
        and returns the energy summed over all neighbor pairs.

    Raises:
        NotImplementedError: (from `fn_mapped`) if the neighbor list is not in a sparse format.
    """
    kwargs, param_combinators = smap._split_params_and_combinators(kwargs)
    merge_dicts = partial(util.merge_dicts, ignore_unused_parameters=ignore_unused_parameters)

    def fn_mapped(body: rigid_body.RigidBody, neighbor: partition.NeighborList, **dynamic_kwargs) -> Array:
        if not partition.is_sparse(neighbor.format):
            raise NotImplementedError('Only sparse neighbor lists are currently supported.')

        n_particles = body.center.shape[0]
        # Fill values in sparse neighbor lists are indices >= n_particles. JAX clamps
        # out-of-bounds gathers, so such entries evaluate a (particle, same particle) pair at
        # dr == 0; they are zeroed out below.
        mask = neighbor.idx[0] < n_particles
        particle1 = body[neighbor.idx[0]]
        particle2 = body[neighbor.idx[1]]
        dr = jax.vmap(partial(displacement, **dynamic_kwargs))(particle1.center, particle2.center)
        eigsys1 = eigensystem(particle1.orientation)
        eigsys2 = eigensystem(particle2.orientation)
        # OrderedSparse lists each pair once; other sparse formats count each pair twice.
        normalization = 1.0 if neighbor.format is partition.OrderedSparse else 2.0

        merged_kwargs = merge_dicts(kwargs, dynamic_kwargs)
        merged_kwargs = smap._neighborhood_kwargs_to_params(neighbor.format, neighbor.idx, None,
                                                            merged_kwargs, param_combinators)
        # After conversion, array parameters hold one value per neighbor pair and must be mapped
        # alongside dr; binding them via partial() would hand the whole array to every pair.
        pair_params = {k: v for k, v in merged_kwargs.items() if jnp.ndim(v) > 0}
        shared_params = {k: v for k, v in merged_kwargs.items() if jnp.ndim(v) == 0}

        def pair_energy(dr_ij, eigsys_i, eigsys_j, params_ij):
            return fn(dr_ij, eigsys_i, eigsys_j, **shared_params, **params_ij)

        out = jax.vmap(pair_energy)(dr, eigsys1, eigsys2, pair_params)
        if out.ndim > mask.ndim:
            ddim = out.ndim - mask.ndim
            mask = jnp.reshape(mask, mask.shape + (1,) * ddim)
        # jnp.where instead of `out * mask`: a non-finite energy on a padded self-pair would
        # otherwise become 0 * inf = NaN and poison the sum. (Gradients can still be NaN if fn
        # is singular at dr == 0 — fn should guard that case itself.)
        out = jnp.where(mask, out, 0.0)
        return util.high_precision_sum(out) / normalization

    return fn_mapped