oriented_particle.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. from typing import Protocol, Callable, TypeVar
  2. import jax
  3. import jax.numpy as jnp
  4. from curvature_assembly import data_protocols
  5. from jax_md import rigid_body, energy, quantity
  6. from functools import partial
# Shorthand for the JAX array type, used in the type annotations throughout this module.
Array = jnp.ndarray
# Generic type variable, used where a function returns the same type it receives
# (e.g. canonicalize_eigvals).
T = TypeVar('T')
  9. @partial(jnp.vectorize, signature='(d,d),(d,d)->(d,d)')
  10. def qf_from_rotation(rotation: Array, eigen_qf: Array) -> Array:
  11. """Get particle quadratic form in world frame given the rotation matrix that describes eigensystem orientation."""
  12. return jnp.linalg.multi_dot((rotation, eigen_qf, jnp.transpose(rotation)))
  13. @partial(jnp.vectorize, signature='(d)->(d,d)')
  14. def make_diagonal(eigvals: Array) -> Array:
  15. """Create diagonal matrix from an 1D array of length 3."""
  16. a, b, c = eigvals
  17. return jnp.array([[a, 0, 0],
  18. [0, b, 0],
  19. [0, 0, c]])
  20. def eigensystem(orientation: rigid_body.Quaternion) -> Array:
  21. """Get eigensystem matrix with eigenvectors as columns."""
  22. return jnp.moveaxis(rigid_body.space_to_body_rotation(orientation), -1, -2)
  23. def matrix_repr(orientation: rigid_body.Quaternion, eigvals: Array) -> Array:
  24. """Quadratic form of the oriented particle given the matrix eigenvalues and quaternion orientation."""
  25. return qf_from_rotation(eigensystem(orientation), make_diagonal(eigvals))
  26. def get_weight_matrices(orientation: rigid_body.Quaternion, eigvals: Array) -> Array:
  27. """Weight matrices of the rigid body with squared semi-axes lengths as matrix eigenvalues."""
  28. return matrix_repr(orientation, 1 / eigvals)
  29. @partial(jnp.vectorize, signature='(),(d)->(d)')
  30. def ellipsoid_moment_of_inertia(m, eigvals):
  31. eig1, eig2, eig3 = eigvals
  32. a2 = 1 / eig1
  33. b2 = 1 / eig2
  34. c2 = 1 / eig3
  35. return m / 5 * jnp.array([b2 + c2, a2 + c2, a2 + b2])
  36. def ellipsoid_mass(masses, eigvals) -> rigid_body.RigidBody:
  37. """Get an Ellipsoid with the mass and moment of inertia for each particle."""
  38. return rigid_body.RigidBody(masses, ellipsoid_moment_of_inertia(masses, eigvals))
  39. def contact_to_distance_cutoff(cf_cut: float, eigvals: Array) -> float:
  40. """
  41. Calculate a sufficient distance cutoff from the contact function cutoff.
  42. Contact function should be the square root of the Perram-Wertheim contact function.
  43. """
  44. return 2 / jnp.sqrt(jnp.min(eigvals)) * cf_cut
  45. def contact_to_distance_threshold(cf_cut: float, cf_theshold: float, eigvals: Array) -> float:
  46. """Map from threshold in contact function to the distance threshold. We take the minimal distance
  47. that comes from the particle move for cf_threshold at the very edge of the function range."""
  48. return contact_to_distance_cutoff(cf_cut, eigvals) - contact_to_distance_cutoff(cf_cut - cf_theshold, eigvals)
  49. def distance_to_contact_cutoff(r_cut: float, eigvals: Array) -> float:
  50. """
  51. Calculate a sufficient contact function cutoff from the distance cutoff.
  52. Contact function value returned corresponds to the square root of the Perram-Wertheim contact function.
  53. """
  54. return jnp.min(eigvals) * r_cut / 2
  55. def eigenvalues_at_unit_volume(eigenvalues: Array) -> Array:
  56. """Rescales the eigenvalues to get unit volume ellipsoids."""
  57. particle_volume = 4 * jnp.pi / 3 * jnp.prod(1 / jnp.sqrt(eigenvalues))
  58. return jnp.cbrt(particle_volume) ** 2 * eigenvalues
  59. def eigenvalues_to_semiaxes(eigenvalues: Array) -> Array:
  60. """Calculate ellipsoid semiaxes from eigenvalues."""
  61. return jnp.sort(1 / jnp.sqrt(eigenvalues))
  62. def canonicalize_eigvals(interaction_params: T) -> T:
  63. """
  64. Create a new InteractionParams instance with transformed eigenvalues
  65. so that they correspond to unit volume ellipsoidal particles.
  66. """
  67. params_dict = vars(interaction_params)
  68. new_dict = params_dict.copy() # shallow copy is enough as values (interaction_params elements) are jax arrays
  69. new_dict['eigvals'] = eigenvalues_at_unit_volume(params_dict['eigvals'])
  70. return type(interaction_params)(**new_dict)
  71. def box_size_at_number_density(particle_count: int,
  72. number_density: float,
  73. spatial_dimension: int = 3):
  74. return quantity.box_size_at_number_density(particle_count,
  75. number_density,
  76. spatial_dimension=spatial_dimension)
  77. def ellipsoid_volume(eigvals: Array):
  78. return 4 / 3 * jnp.pi / jnp.prod(jnp.sqrt(eigvals), axis=-1)
  79. def box_size_at_ellipsoid_density(particle_count: int,
  80. density: float,
  81. eigvals: Array):
  82. if eigvals.ndim > 2:
  83. raise ValueError("Eigenvalue matrix should have at most 2 dimensions.")
  84. spatial_dimension = eigvals.shape[-1]
  85. particle_volume = ellipsoid_volume(eigvals)
  86. if particle_volume.ndim == 0:
  87. particle_volume = jnp.full((particle_count,), particle_volume)
  88. total_particle_volume = jnp.sum(particle_volume)
  89. return jnp.power(total_particle_volume / density, 1 / spatial_dimension)
  90. @jax.jit
  91. def update_interaction_params(grad: data_protocols.InteractionParams,
  92. interaction_params: data_protocols.InteractionParams,
  93. learning_rate: float) -> data_protocols.InteractionParams:
  94. """
  95. Update interaction parameters with gradient descent step. Rescales the new ellipsoid eigenvalues
  96. so that they correspond to unit volume particles.
  97. """
  98. new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, interaction_params, grad)
  99. return canonicalize_eigvals(new_params)
  100. class OrientedParticleEnergy(Protocol):
  101. """Protocol specifying the signature for energy functions between oriented particles."""
  102. def __call__(self, dr: Array, eigsys1: Array, eigsys2: Array, **kwargs) -> Array:
  103. ...
  104. def get_ellipsoid_contact_function(contact_function: Callable[..., Array], eigvals: Array, **cf_kwargs):
  105. """
  106. Return a function that calculates square root of the Perram-Wertheim contact function between a pair of ellipsoids.
  107. """
  108. def fun(dr: Array, eigsys1: Array, eigsys2: Array) -> Array:
  109. qf1 = qf_from_rotation(eigsys1, make_diagonal(1 / eigvals))
  110. qf2 = qf_from_rotation(eigsys2, make_diagonal(1 / eigvals))
  111. return contact_function(dr, qf1, qf2, **cf_kwargs)
  112. return fun
  113. def get_ellipsoid_contact_function_param(contact_function: Callable[..., Array], **cf_kwargs):
  114. """
  115. Return a function that calculates the contact function between a pair of ellipsoids with a standardized call
  116. signature. It also does the transform from the standard quadratic form eigenvalues for ellipsoids (where
  117. eigenvalues are invere squares of semiaxis lenghts) to the weight matrix used in the Perram-Wertheim contact
  118. function (eigenvalues are just semiaxes squared, without the inverse).
  119. """
  120. def fun(dr: Array, eigsys1: Array, eigsys2: Array, eigvals: Array) -> Array:
  121. qf1 = qf_from_rotation(eigsys1, make_diagonal(1 / eigvals))
  122. qf2 = qf_from_rotation(eigsys2, make_diagonal(1 / eigvals))
  123. return contact_function(dr, qf1, qf2, **cf_kwargs)
  124. return fun
  125. def isotropic_to_ellipsoid_energy(energy_fn: Callable[..., Array],
  126. contact_function: Callable[..., Array],
  127. eigvals: Array,
  128. **cf_kwargs) -> OrientedParticleEnergy:
  129. """Promotes an isotropic energy function to one acting between ellipsoids,
  130. with a given contact function as a measure of distance."""
  131. cf = get_ellipsoid_contact_function(contact_function, eigvals, **cf_kwargs)
  132. def ellipsoid_energy_fn(dr: Array, eigsys1: Array, eigsys2: Array, **kwargs) -> Array:
  133. return energy_fn(cf(dr, eigsys1, eigsys2), **kwargs)
  134. return ellipsoid_energy_fn
  135. def isotropic_to_cf_energy(energy_fn: Callable[..., Array],
  136. contact_function: Callable[..., Array],
  137. **cf_kwargs) -> OrientedParticleEnergy:
  138. """Promotes an isotropic energy function to one acting between ellipsoids,
  139. with a given contact function as a measure of distance."""
  140. cf = get_ellipsoid_contact_function_param(contact_function, **cf_kwargs)
  141. def ellipsoid_energy_fn(dr: Array, eigsys1: Array, eigsys2: Array, **kwargs) -> Array:
  142. return energy_fn(cf(dr, eigsys1, eigsys2, **cf_kwargs), **kwargs)
  143. return ellipsoid_energy_fn
  144. def isotropic_to_ellipsoid_energy_with_cutoff(energy_fn: Callable[..., Array],
  145. contact_function: Callable[..., Array],
  146. eigvals: Array,
  147. cf_onset: float,
  148. cf_cutoff: float,
  149. **cf_kwargs) -> OrientedParticleEnergy:
  150. """
  151. Promotes an isotropic energy function to one acting between ellipsoids,
  152. with a given contact function as a measure of distance.
  153. Adds the multiplicative isotropic cutoff to get a truncated function.
  154. """
  155. cf = get_ellipsoid_contact_function(contact_function, eigvals, **cf_kwargs)
  156. def ellipsoid_energy_fn(dr: Array, eigsys1: Array, eigsys2: Array, **kwargs) -> Array:
  157. return energy.multiplicative_isotropic_cutoff(
  158. energy_fn, cf_onset, cf_cutoff)(cf(dr, eigsys1, eigsys2), **kwargs)
  159. return ellipsoid_energy_fn