smap.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. import jax
  2. from curvature_assembly.oriented_particle import OrientedParticleEnergy, eigensystem
  3. import jax.numpy as jnp
  4. from functools import partial
  5. from typing import Callable
  6. from jax_md import space, smap, util, partition, rigid_body
  7. Array = jnp.ndarray
  8. def oriented_pair(fn: OrientedParticleEnergy,
  9. displacement: space.DisplacementFn,
  10. ignore_unused_parameters: bool = False,
  11. **kwargs) -> Callable[..., Array]:
  12. """
  13. Promotes a function that acts on a pair of ellipses to one on a system.
  14. Args:
  15. fn: energy function that takes distance, eigensystem1, eigensystem2 as first three arguments.
  16. displacement: displacement function that calculates distances between particles.
  17. ignore_unused_parameters: A boolean that denotes whether dynamically specified keyword arguments
  18. passed to the mapped function get ignored if they were not first specified as keyword arguments
  19. when calling `oriented_pair(...)`.
  20. kwargs: arguments providing parameters to the mapped function.
  21. Return:
  22. A function fn_mapped that takes a RigidBody object.
  23. """
  24. kwargs, param_combinators = smap._split_params_and_combinators(kwargs)
  25. merge_dicts = partial(util.merge_dicts, ignore_unused_parameters=ignore_unused_parameters)
  26. def fn_mapped(body: rigid_body.RigidBody, **dynamic_kwargs) -> Array:
  27. rows, columns = jnp.triu_indices(body.center.shape[0], 1)
  28. particle1 = body[rows]
  29. particle2 = body[columns]
  30. dr = jax.vmap(partial(displacement, **dynamic_kwargs))(particle1.center, particle2.center)
  31. eigsys1 = eigensystem(particle1.orientation)
  32. eigsys2 = eigensystem(particle2.orientation)
  33. _kwargs = merge_dicts(kwargs, dynamic_kwargs)
  34. # _kwargs = smap._kwargs_to_parameters(None, _kwargs, param_combinators)
  35. all_pair_interctions = jax.vmap(partial(fn, **_kwargs))(dr, eigsys1, eigsys2)
  36. return util.high_precision_sum(all_pair_interctions)
  37. # def fn_mapped(body: rigid_body.RigidBody, **dynamic_kwargs) -> Array:
  38. # # this does not give the same results as the above fn_mapped, but it should?
  39. # d = space.map_product(partial(displacement, **dynamic_kwargs))
  40. # eigsys = eigensystem(body.orientation)
  41. # _kwargs = merge_dicts(kwargs, dynamic_kwargs)
  42. # _kwargs = smap._kwargs_to_parameters(None, _kwargs, param_combinators)
  43. # # print(_kwargs)
  44. # dr = d(body.center, body.center)
  45. # meshx, meshy = jnp.meshgrid(jnp.arange(body.center.shape[0]), jnp.arange(body.center.shape[0]))
  46. # eigsys1 = eigsys[meshx]
  47. # eigsys2 = eigsys[meshy]
  48. # # print(dr.shape, eigsys1, eigsys2)
  49. # return util.high_precision_sum(smap._diagonal_mask(fn(dr, eigsys1, eigsys2, **_kwargs)),
  50. # axis=None, keepdims=False) * util.f32(0.5)
  51. return fn_mapped
  52. def oriented_pair_neighbor_list(fn: OrientedParticleEnergy,
  53. displacement: space.DisplacementFn,
  54. ignore_unused_parameters: bool = False,
  55. **kwargs) -> Callable[..., Array]:
  56. """
  57. Promotes a function acting on pairs of particles to use neighbor lists.
  58. Args:
  59. fn: energy function that takes distance, eigensystem1, eigensystem2 as first three arguments.
  60. displacement: displacement function that calculates distances between particles.
  61. ignore_unused_parameters: A boolean that denotes whether dynamically specified keyword arguments
  62. passed to the mapped function get ignored if they were not first specified as keyword arguments
  63. when calling `oriented_pair(...)`.
  64. kwargs: arguments providing parameters to the mapped function.
  65. Return:
  66. A function `fn_mapped` that takes a RigidBody object and a NeighborList object specifying neighbors.
  67. """
  68. kwargs, param_combinators = smap._split_params_and_combinators(kwargs)
  69. merge_dicts = partial(util.merge_dicts, ignore_unused_parameters=ignore_unused_parameters)
  70. def fn_mapped(body: rigid_body.RigidBody, neighbor: partition.NeighborList, **dynamic_kwargs) -> Array:
  71. normalization = 2.0
  72. if partition.is_sparse(neighbor.format):
  73. particle1 = body[neighbor.idx[0]]
  74. particle2 = body[neighbor.idx[1]]
  75. dr = jax.vmap(partial(displacement, **dynamic_kwargs))(particle1.center, particle2.center)
  76. eigsys1 = eigensystem(particle1.orientation)
  77. eigsys2 = eigensystem(particle2.orientation)
  78. mask = neighbor.idx[0] < body.center.shape[0] # takes care of fill values in neighbor lists
  79. if neighbor.format is partition.OrderedSparse:
  80. normalization = 1.0
  81. else:
  82. raise NotImplementedError('Only sparse neighbor lists are currently supported.')
  83. merged_kwargs = merge_dicts(kwargs, dynamic_kwargs)
  84. merged_kwargs = smap._neighborhood_kwargs_to_params(neighbor.format,
  85. neighbor.idx,
  86. None,
  87. merged_kwargs,
  88. param_combinators)
  89. out = jax.vmap(partial(fn, **merged_kwargs))(dr, eigsys1, eigsys2)
  90. if out.ndim > mask.ndim:
  91. ddim = out.ndim - mask.ndim
  92. mask = jnp.reshape(mask, mask.shape + (1,) * ddim)
  93. out *= mask
  94. return util.high_precision_sum(out) / normalization
  95. return fn_mapped