clustering.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. from typing import Callable
  2. import jax
  3. import jax.numpy as jnp
  4. from jax_md import dataclasses, space, rigid_body
  5. from functools import partial
  6. from curvature_assembly import oriented_particle
# Shorthand for JAX array type used in the type annotations of this module.
Array = jnp.ndarray
@dataclasses.dataclass
class Clusters:
    """Class to store cluster data.

    Instances are constructed positionally in this module, so the field order below is part
    of the interface. N denotes the number of particles.
    """
    # Number of clusters found (`clustering` stores a JAX scalar here, despite the int annotation).
    n_clusters: int
    # (N, N) array; row i holds the particle indices of cluster i, padded with the value N.
    clusters: Array
    # Length-N array with the number of particles in each cluster; padding rows hold 0.
    n_part_per_cluster: Array
    # (N, N) int32 array; masks[i, k] is 1 if particle k belongs to cluster i, else 0.
    masks: Array
@dataclasses.dataclass
class Neighbors:
    """Class to store connectivity data.

    Instances are constructed positionally by the neighboring functions in this module, so the
    field order below is part of the interface. N denotes the number of particles.
    """
    # (N, max_contacts) int32 array; row i lists the indices of particle i's neighbors,
    # padded with the value N.
    node_next: Array
    # Length-N int32 array with the number of contacts of each particle.
    n_contacts_per_particle: Array
    # Number of links between particles, computed by the neighboring functions as half of the
    # total number of entries in the (self-excluded) contact matrix.
    n_links: int
  21. def nonzero_1d(array: Array, size: int = None, fill_value: float = None) -> Array:
  22. """Vectorization of the nonzero function over the second array axis."""
  23. return jnp.vectorize(partial(jnp.nonzero, size=size, fill_value=fill_value), signature='(k)->(d)')(array)[0]
  24. def get_neighboring_fn(displacement: space.DisplacementFn, max_contacts: int = 16) -> Callable:
  25. """Return distance based neighboring function used to get particle contact information."""
  26. @jax.jit
  27. def neighboring(coord: Array, neigh_dist: float) -> Neighbors:
  28. metric = space.canonicalize_displacement_or_metric(displacement)
  29. dr = space.map_product(metric)(coord, coord)
  30. mask = jnp.float32(1.0) - jnp.eye(coord.shape[0])
  31. contacts = mask * jnp.array(dr < neigh_dist, dtype=jnp.int32)
  32. n_links = 0.5 * jnp.sum(contacts, dtype=jnp.int32)
  33. n_contacts_per_particle = jnp.sum(contacts, axis=0, dtype=jnp.int32)
  34. node_next = jnp.asarray(nonzero_1d(contacts, size=max_contacts, fill_value=len(coord)), dtype=jnp.int32)
  35. return Neighbors(node_next, n_contacts_per_particle, n_links)
  36. return neighboring
  37. def get_ellipsoid_neighboring_fn(displacement: space.DisplacementFn,
  38. contact_fn: Callable[..., Array],
  39. max_contacts: int = 16,
  40. **cf_kwargs) -> Callable:
  41. """Return contact function based neighboring function used to get particle contact information."""
  42. contact_function = oriented_particle.get_ellipsoid_contact_function_param(contact_fn, **cf_kwargs)
  43. @jax.jit
  44. def neighboring(body: rigid_body.RigidBody, eigvals: Array, neigh_contact_fn: float) -> Neighbors:
  45. num_particles = body.center.shape[0]
  46. dr = space.map_product(displacement)(body.center, body.center)
  47. eigsys = oriented_particle.eigensystem(body.orientation)
  48. mapped_cf = jax.vmap(jax.vmap(partial(contact_function, eigvals=eigvals), (0, 0, None), 0), (0, None, 0), 0)
  49. cf = mapped_cf(dr, eigsys, eigsys)
  50. mask = jnp.float32(1.0) - jnp.eye(num_particles)
  51. contacts = mask * jnp.array(cf < neigh_contact_fn, dtype=jnp.int32)
  52. n_links = 0.5 * jnp.sum(contacts, dtype=jnp.int32)
  53. n_contacts_per_particle = jnp.sum(contacts, axis=0, dtype=jnp.int32)
  54. node_next = jnp.asarray(nonzero_1d(contacts, size=max_contacts, fill_value=num_particles), dtype=jnp.int32)
  55. return Neighbors(node_next, n_contacts_per_particle, n_links)
  56. return neighboring
  57. @partial(jax.jit, static_argnums=1)
  58. def clustering(neighbors: Neighbors, num_iter: int = 20) -> Clusters:
  59. """
  60. Clustering algorithm from de Oliveira et.al.,
  61. https://www.tandfonline.com/doi/full/10.1080/08927022.2020.1839661
  62. Args:
  63. neighbors: instance of Neighbors object where data about each particle neighbors is stored
  64. num_iter: number of iterations to fix cluster labels of neighboring particles
  65. Returns:
  66. Clustering data stored in Clusters object with the following attributes:
  67. n_clusters - number of clusters,
  68. n_part_per_cluster - array of length N storing the lengths of all clusters. Elements with index i > n_clusters
  69. are set to the total number of particles in the system.
  70. clusters - NxN array where each row i stores indices of particles in one cluster. Clusters[i, j] with
  71. i > n_clusters and/or j > n_part_per_cluster[i] are set to the total number of particles in the system.
  72. """
  73. num_particles = neighbors.node_next.shape[0]
  74. particle_indices = jnp.arange(num_particles)
  75. nodeLpure = jnp.zeros(num_particles, dtype=jnp.int32)
  76. nodeL = jnp.concatenate((nodeLpure, jnp.array([num_particles]))) # the last index is used as a dump for unused data
  77. def new_label(nodeL, n_clusters, idx, *args):
  78. nodeL = nodeL.at[idx].set(n_clusters)
  79. return nodeL, n_clusters + 1
  80. def neighbor_labels(nodeL, n_clusters, idx, neigh_indices, label):
  81. nodeL = nodeL.at[idx].set(label)
  82. nodeL = nodeL.at[neigh_indices].set(label)
  83. return nodeL, n_clusters
  84. def assign_labels(nodeL_ncl, idx, no_labeled_neighbors_fn):
  85. nodeL, n_clusters = nodeL_ncl
  86. labels = nodeL[neighbors.node_next[idx]]
  87. max_label = jnp.max(labels, where=(labels < num_particles), initial=0)
  88. min_label = jnp.min(labels, where=(labels > 0), initial=num_particles)
  89. nodeL, n_clusters = jax.lax.cond(max_label == 0, no_labeled_neighbors_fn, neighbor_labels,
  90. nodeL, n_clusters, idx, neighbors.node_next[idx], min_label)
  91. nodeL = nodeL.at[num_particles].set(num_particles)
  92. return (nodeL, n_clusters), 0.
  93. # assign labels to all clustered particles
  94. n_clusters = 0
  95. nodeL_ncl, _ = jax.lax.scan(partial(assign_labels, no_labeled_neighbors_fn=new_label),
  96. (nodeL, n_clusters), xs=particle_indices)
  97. nodeL, _ = nodeL_ncl
  98. def keep_label(nodeL, n_clusters, *args):
  99. return nodeL, n_clusters
  100. def check_labels(condition, idx, nodeL):
  101. # add 1 to condition value if all neighbors of particle idx share its label
  102. labels = nodeL[neighbors.node_next[idx]]
  103. unique_labels = jnp.unique(labels, size=3, fill_value=num_particles)
  104. # check if all neighbor labels are the same:
  105. all_same = (unique_labels[1] == num_particles) * (unique_labels[0] == nodeL[idx])
  106. condition = jax.lax.cond(all_same, lambda x: condition + 1, lambda x: condition, 0.)
  107. return condition, 0.
  108. def fix_labels(nodeL):
  109. nodeL_ncl, _ = jax.lax.scan(partial(assign_labels, no_labeled_neighbors_fn=keep_label),
  110. (nodeL, 0), xs=particle_indices)
  111. nodeL, _ = nodeL_ncl
  112. return nodeL
  113. def relabel_iteration(nodeL, iteration):
  114. # condition is the number of particles with all neighbors sharing its label
  115. condition, _ = jax.lax.scan(partial(check_labels, nodeL=nodeL), 0., xs=particle_indices)
  116. # if condition == num_particles, algorithm converged, and we return the same values for nodeL
  117. nodeL = jax.lax.cond(condition == num_particles, lambda x: x, fix_labels, nodeL)
  118. return nodeL, 0.
  119. # fixing the cluster labels of particles (max_iter iterations)
  120. nodeL, _ = jax.lax.scan(relabel_iteration, nodeL, xs=jnp.arange(num_iter))
  121. # determine cluster id values for all clusters
  122. nodeLpure = nodeL[:-1]
  123. id = jnp.unique(nodeLpure, size=num_particles, fill_value=num_particles)
  124. @partial(jnp.vectorize, signature='()->(d)')
  125. def set_cluster_vmap(cluster_id):
  126. return jnp.where(nodeLpure == cluster_id, size=num_particles, fill_value=num_particles)[0]
  127. @partial(jnp.vectorize, signature='()->(d)')
  128. def set_mask_vmap(cluster_id):
  129. return jnp.where(nodeLpure == cluster_id,
  130. jnp.ones(num_particles, dtype=jnp.int32),
  131. jnp.zeros(num_particles, dtype=jnp.int32))
  132. clusters = set_cluster_vmap(id)
  133. masks = set_mask_vmap(id)
  134. n_part_per_cluster = jnp.sum(clusters < num_particles, axis=1)
  135. n_clusters = jnp.sum(jnp.min(clusters, axis=1) < num_particles)
  136. return Clusters(n_clusters, clusters, n_part_per_cluster, masks)
  137. def get_cluster_particles(clusters: Clusters, idx: int) -> Array:
  138. """Clip fill values from cluster data and return only indices of particles in cluster idx."""
  139. return clusters.clusters[idx, :clusters.n_part_per_cluster[idx]]
  140. @jax.jit
  141. def get_cluster_mask(cluster: Array) -> Array:
  142. """Create a mask for particles belonging to cluster given as array of indices."""
  143. mask_extended = jnp.zeros(cluster.shape[0] + 1, dtype=jnp.int32)
  144. mask_extended = mask_extended.at[cluster].set(1.)
  145. return mask_extended[:-1]
  146. def get_all_cluster_masks(clusters: Clusters) -> Array:
  147. """
  148. Create a NxN array with each row representing a mask for one particle cluster. For rows with
  149. i > clusters.n_clusters, the elements of all masks are False.
  150. """
  151. f = jax.vmap(get_cluster_mask)
  152. return f(clusters.clusters)