123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- import jax
- from curvature_assembly.oriented_particle import OrientedParticleEnergy, eigensystem
- import jax.numpy as jnp
- from functools import partial
- from typing import Callable
- from jax_md import space, smap, util, partition, rigid_body
# Alias for the JAX array type; used in the return annotations of this module.
Array = jnp.ndarray
def oriented_pair(fn: OrientedParticleEnergy,
                  displacement: space.DisplacementFn,
                  ignore_unused_parameters: bool = False,
                  **kwargs) -> Callable[..., Array]:
    """Promote a pairwise oriented-particle energy to a whole-system energy.

    Every unordered pair ``(i, j)`` with ``i < j`` is evaluated exactly once
    (strict upper triangle of the pair matrix), so no double-counting
    correction is applied to the sum.

    Args:
        fn: Pair energy function that takes the displacement between two
            particle centers, the eigensystem of the first particle and the
            eigensystem of the second particle as its first three positional
            arguments. Further parameters are passed as keyword arguments.
        displacement: Displacement function used to compute the separation
            between the two particle centers of each pair.
        ignore_unused_parameters: Whether keyword arguments passed dynamically
            to the mapped function are dropped when they were not first
            specified as keyword arguments when calling ``oriented_pair(...)``.
        **kwargs: Parameters forwarded to ``fn``. Parameter combinators (as
            separated by ``smap._split_params_and_combinators``) are removed
            and not used. Every remaining value is passed unchanged to every
            pair, i.e. per-particle parameter arrays are NOT expanded to
            per-pair values here.

    Returns:
        A function ``fn_mapped(body, **dynamic_kwargs)`` that takes a
        ``rigid_body.RigidBody`` and returns the sum of ``fn`` over all pairs.
        ``dynamic_kwargs`` are forwarded both to ``displacement`` and, after
        merging with ``kwargs``, to ``fn``.
    """
    # Combinators are split off only so they are not forwarded to ``fn``;
    # parameters are not converted to per-pair values in this function.
    kwargs, _ = smap._split_params_and_combinators(kwargs)
    merge_dicts = partial(util.merge_dicts,
                          ignore_unused_parameters=ignore_unused_parameters)

    def fn_mapped(body: rigid_body.RigidBody, **dynamic_kwargs) -> Array:
        # Strict upper-triangle indices enumerate each pair i < j exactly once.
        rows, columns = jnp.triu_indices(body.center.shape[0], 1)
        particle1 = body[rows]
        particle2 = body[columns]
        dr = jax.vmap(partial(displacement, **dynamic_kwargs))(
            particle1.center, particle2.center)
        eigsys1 = eigensystem(particle1.orientation)
        eigsys2 = eigensystem(particle2.orientation)
        merged_kwargs = merge_dicts(kwargs, dynamic_kwargs)
        # ``fn`` acts on a single pair, so it is vmapped over the pair axis;
        # the keyword parameters are closed over and shared by all pairs.
        pair_energies = jax.vmap(partial(fn, **merged_kwargs))(
            dr, eigsys1, eigsys2)
        return util.high_precision_sum(pair_energies)

    return fn_mapped
def _vmap_over_pairs(fn: Callable[..., Array],
                     dr: Array,
                     eigsys1,
                     eigsys2,
                     params: dict) -> Array:
    """Evaluate a single-pair ``fn`` over a batch of pairs with ``jax.vmap``.

    Keyword parameters that are arrays whose leading axis has the same length
    as ``dr`` (one entry per pair) are mapped together with the pair data.
    All other parameters are closed over and shared by every pair.
    """
    n_pairs = dr.shape[0]
    pair_params = {}
    shared_params = {}
    for key, value in params.items():
        if util.is_array(value) and value.ndim > 0 and value.shape[0] == n_pairs:
            pair_params[key] = value
        else:
            shared_params[key] = value

    def single_pair(dr_ij, eigsys_i, eigsys_j, pair_params_ij):
        return fn(dr_ij, eigsys_i, eigsys_j, **shared_params, **pair_params_ij)

    return jax.vmap(single_pair)(dr, eigsys1, eigsys2, pair_params)


def oriented_pair_neighbor_list(fn: OrientedParticleEnergy,
                                displacement: space.DisplacementFn,
                                ignore_unused_parameters: bool = False,
                                **kwargs) -> Callable[..., Array]:
    """Promote a pairwise oriented-particle energy to use a neighbor list.

    Args:
        fn: Pair energy function that takes the displacement between two
            particle centers, the eigensystem of the first particle and the
            eigensystem of the second particle as its first three positional
            arguments. Further parameters are passed as keyword arguments.
        displacement: Displacement function used to compute the separation
            between the two particle centers of each neighbor pair.
        ignore_unused_parameters: Whether keyword arguments passed dynamically
            to the mapped function are dropped when they were not first
            specified as keyword arguments when calling
            ``oriented_pair_neighbor_list(...)``.
        **kwargs: Parameters forwarded to ``fn``. They are converted to
            per-pair values with ``smap._neighborhood_kwargs_to_params``
            (using any parameter combinators given here). Converted arrays with
            one entry per neighbor pair are mapped pair by pair; everything
            else is shared by all pairs.

    Returns:
        A function ``fn_mapped(body, neighbor, **dynamic_kwargs)`` that takes a
        ``rigid_body.RigidBody`` and a sparse ``partition.NeighborList`` and
        returns the sum of ``fn`` over all valid neighbor pairs, divided by 2
        unless the list format is ``partition.OrderedSparse``.
        ``fn_mapped`` raises ``NotImplementedError`` for non-sparse formats.
    """
    kwargs, param_combinators = smap._split_params_and_combinators(kwargs)
    merge_dicts = partial(util.merge_dicts,
                          ignore_unused_parameters=ignore_unused_parameters)

    def fn_mapped(body: rigid_body.RigidBody,
                  neighbor: partition.NeighborList,
                  **dynamic_kwargs) -> Array:
        if not partition.is_sparse(neighbor.format):
            raise NotImplementedError('Only sparse neighbor lists are currently supported.')

        particle1 = body[neighbor.idx[0]]
        particle2 = body[neighbor.idx[1]]
        dr = jax.vmap(partial(displacement, **dynamic_kwargs))(
            particle1.center, particle2.center)
        eigsys1 = eigensystem(particle1.orientation)
        eigsys2 = eigensystem(particle2.orientation)
        # Padding ("fill") entries of the neighbor list carry an index that is
        # not smaller than the number of particles; they are masked out below.
        mask = neighbor.idx[0] < body.center.shape[0]
        # Non-ordered sparse lists are halved; presumably they contain each
        # pair in both directions (i, j) and (j, i).
        normalization = 1.0 if neighbor.format is partition.OrderedSparse else 2.0

        merged_kwargs = merge_dicts(kwargs, dynamic_kwargs)
        merged_kwargs = smap._neighborhood_kwargs_to_params(neighbor.format,
                                                            neighbor.idx,
                                                            None,
                                                            merged_kwargs,
                                                            param_combinators)
        # Per-pair parameters produced above must be mapped together with the
        # pair data; closing over them would hand every pair the full array.
        out = _vmap_over_pairs(fn, dr, eigsys1, eigsys2, merged_kwargs)
        if out.ndim > mask.ndim:
            ddim = out.ndim - mask.ndim
            mask = jnp.reshape(mask, mask.shape + (1,) * ddim)
        # ``where`` instead of multiplication: padded pairs may evaluate ``fn``
        # at degenerate geometry, and inf * 0 would turn the whole sum into NaN.
        out = jnp.where(mask, out, jnp.zeros_like(out))
        return util.high_precision_sum(out) / normalization

    return fn_mapped
|