123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196 |
- from typing import Callable
- import jax
- import jax.numpy as jnp
- from jax_md import dataclasses, space, rigid_body
- from functools import partial
- from curvature_assembly import oriented_particle
# Shorthand alias for JAX arrays, used in the type annotations of this module.
Array = jnp.ndarray
@dataclasses.dataclass
class Clusters:
    """Class to store cluster data, as returned by `clustering`."""
    # number of clusters found (rows of `clusters` that contain at least one particle)
    n_clusters: int
    # (N, N) particle indices; row i lists the members of cluster i, padded with N (the number of particles)
    clusters: Array
    # (N,) number of particles in each row of `clusters` (0 for padding rows)
    n_part_per_cluster: Array
    # (N, N) 0/1 masks; row i is 1 at the particles belonging to cluster i and 0 elsewhere
    masks: Array
@dataclasses.dataclass
class Neighbors:
    """Class to store connectivity data, as returned by the neighboring functions of this module."""
    # (N, max_contacts) indices of the neighbors of each particle, padded with N (the number of particles)
    node_next: Array
    # (N,) number of contacts of each particle
    n_contacts_per_particle: Array
    # number of contact pairs, i.e. half the number of nonzero entries of the contact matrix
    n_links: int
def nonzero_1d(array: Array, size: int = None, fill_value: float = None) -> Array:
    """
    Apply `jnp.nonzero` to every 1D slice taken along the last axis of `array`.

    Leading axes are treated as batch dimensions, so for a 2D array row i of the result holds the
    indices of the nonzero entries of row i of the input.

    Args:
        array: input array; its last axis is searched for nonzero entries.
        size: number of indices returned per slice. Required whenever the computation is traced,
            e.g. when `array` has batch dimensions (mapped by `jnp.vectorize`) or inside `jax.jit`.
        fill_value: value used to pad slices that have fewer than `size` nonzero entries.

    Returns:
        Integer array of shape `array.shape[:-1] + (size,)` when `size` is given.
    """
    slice_nonzero = partial(jnp.nonzero, size=size, fill_value=fill_value)
    batched_nonzero = jnp.vectorize(slice_nonzero, signature='(k)->(d)')
    # jnp.nonzero returns a tuple with one index array per input dimension; a 1D slice gives one entry
    index_tuple = batched_nonzero(array)
    return index_tuple[0]
def get_neighboring_fn(displacement: space.DisplacementFn, max_contacts: int = 16) -> Callable:
    """
    Return distance based neighboring function used to get particle contact information.

    Args:
        displacement: jax_md displacement (or metric) function.
        max_contacts: number of neighbor slots stored per particle in `Neighbors.node_next`.
            `jnp.nonzero` truncates to this size, so neighbors beyond it are silently dropped.

    Returns:
        Jitted function `neighboring(coord, neigh_dist) -> Neighbors`. Two distinct particles are in
        contact when their distance is smaller than `neigh_dist`.
    """
    # Canonicalize once at construction time instead of on every trace of the jitted closure.
    metric = space.canonicalize_displacement_or_metric(displacement)
    pairwise_distance = space.map_product(metric)

    @jax.jit
    def neighboring(coord: Array, neigh_dist: float) -> Neighbors:
        num_particles = coord.shape[0]
        # Boolean contact matrix with self-contacts removed.
        contacts = (pairwise_distance(coord, coord) < neigh_dist) & ~jnp.eye(num_particles, dtype=bool)
        # The distance matrix is symmetric, so every pair appears twice; integer division keeps
        # n_links an integer, matching the `int` annotation of Neighbors.n_links.
        n_links = jnp.sum(contacts, dtype=jnp.int32) // 2
        n_contacts_per_particle = jnp.sum(contacts, axis=0, dtype=jnp.int32)
        # Neighbor indices of each particle, padded with num_particles.
        node_next = jnp.asarray(nonzero_1d(contacts, size=max_contacts, fill_value=num_particles),
                                dtype=jnp.int32)
        return Neighbors(node_next, n_contacts_per_particle, n_links)

    return neighboring
def get_ellipsoid_neighboring_fn(displacement: space.DisplacementFn,
                                 contact_fn: Callable[..., Array],
                                 max_contacts: int = 16,
                                 **cf_kwargs) -> Callable:
    """
    Return contact function based neighboring function used to get particle contact information.

    Args:
        displacement: jax_md displacement function.
        contact_fn: contact function, wrapped by
            `oriented_particle.get_ellipsoid_contact_function_param` together with `cf_kwargs`.
        max_contacts: number of neighbor slots stored per particle in `Neighbors.node_next`.
            `jnp.nonzero` truncates to this size, so neighbors beyond it are silently dropped.
        **cf_kwargs: extra keyword arguments forwarded to the contact function wrapper.

    Returns:
        Jitted function `neighboring(body, eigvals, neigh_contact_fn) -> Neighbors`. Two distinct
        particles are in contact when their contact function value is smaller than `neigh_contact_fn`.
    """
    contact_function = oriented_particle.get_ellipsoid_contact_function_param(contact_fn, **cf_kwargs)
    pairwise_displacement = space.map_product(displacement)

    @jax.jit
    def neighboring(body: rigid_body.RigidBody, eigvals: Array, neigh_contact_fn: float) -> Neighbors:
        num_particles = body.center.shape[0]
        dr = pairwise_displacement(body.center, body.center)
        eigsys = oriented_particle.eigensystem(body.orientation)
        # Nested vmaps give cf[i, j] = contact_function(dr[i, j], eigsys[j], eigsys[i], eigvals=eigvals).
        pair_cf = partial(contact_function, eigvals=eigvals)
        row_cf = jax.vmap(pair_cf, (0, 0, None), 0)
        mapped_cf = jax.vmap(row_cf, (0, None, 0), 0)
        cf = mapped_cf(dr, eigsys, eigsys)
        # Boolean contact matrix with self-contacts removed.
        contacts = (cf < neigh_contact_fn) & ~jnp.eye(num_particles, dtype=bool)
        # Each pair is counted twice, assuming the contact function is symmetric under exchange of
        # the two particles (NOTE(review): confirm for every contact_fn used). Integer division keeps
        # n_links an integer, matching the `int` annotation of Neighbors.n_links.
        n_links = jnp.sum(contacts, dtype=jnp.int32) // 2
        n_contacts_per_particle = jnp.sum(contacts, axis=0, dtype=jnp.int32)
        # Neighbor indices of each particle, padded with num_particles.
        node_next = jnp.asarray(nonzero_1d(contacts, size=max_contacts, fill_value=num_particles),
                                dtype=jnp.int32)
        return Neighbors(node_next, n_contacts_per_particle, n_links)

    return neighboring
@partial(jax.jit, static_argnums=1)
def clustering(neighbors: Neighbors, num_iter: int = 20) -> Clusters:
    """
    Clustering algorithm from de Oliveira et.al.,
    https://www.tandfonline.com/doi/full/10.1080/08927022.2020.1839661

    Particles are labeled in one sequential pass. The labels of neighboring particles are then
    harmonized for up to `num_iter` passes. Once every particle either has no neighbors or shares its
    label with all of its neighbors, the remaining passes are skipped.

    Args:
        neighbors: instance of Neighbors object where data about each particle neighbors is stored.
            Only `neighbors.node_next` is read; its rows list neighbor indices padded with the total
            number of particles N.
        num_iter: (static) maximum number of iterations to fix cluster labels of neighboring particles.
            Long, chain-like clusters may need more iterations to be merged completely.

    Returns:
        Clustering data stored in Clusters object with the following attributes:
            n_clusters - number of clusters,
            clusters - NxN array where each row i < n_clusters stores the indices of particles in one
                cluster. Entries clusters[i, j] with i >= n_clusters and/or j >= n_part_per_cluster[i]
                are set to N.
            n_part_per_cluster - array of length N storing the number of particles in each cluster.
                Elements with index i >= n_clusters are 0.
            masks - NxN int32 array; row i is 1 at the particles of cluster i and 0 elsewhere
                (all zeros for i >= n_clusters).
    """
    num_particles = neighbors.node_next.shape[0]
    particle_indices = jnp.arange(num_particles)
    # Per-particle labels plus a trailing "dump" slot at index N. Padding entries of node_next
    # (value N) read from and write into this slot, so it is reset to N after every update.
    nodeL = jnp.concatenate((jnp.zeros(num_particles, dtype=jnp.int32), jnp.array([num_particles])))

    def new_label(nodeL, n_clusters, idx, *args):
        # No labeled neighbor: open a new cluster label.
        nodeL = nodeL.at[idx].set(n_clusters)
        return nodeL, n_clusters + 1

    def keep_label(nodeL, n_clusters, *args):
        return nodeL, n_clusters

    def neighbor_labels(nodeL, n_clusters, idx, neigh_indices, label):
        # Propagate the smallest positive neighbor label to the particle and all of its neighbors.
        nodeL = nodeL.at[idx].set(label)
        nodeL = nodeL.at[neigh_indices].set(label)
        return nodeL, n_clusters

    def assign_labels(nodeL_ncl, idx, no_labeled_neighbors_fn):
        nodeL, n_clusters = nodeL_ncl
        neigh_indices = neighbors.node_next[idx]
        labels = nodeL[neigh_indices]
        # Padding entries read the dump slot (value N) and are excluded from both reductions.
        max_label = jnp.max(labels, where=(labels < num_particles), initial=0)
        min_label = jnp.min(labels, where=(labels > 0), initial=num_particles)
        nodeL, n_clusters = jax.lax.cond(max_label == 0, no_labeled_neighbors_fn, neighbor_labels,
                                         nodeL, n_clusters, idx, neigh_indices, min_label)
        nodeL = nodeL.at[num_particles].set(num_particles)
        return (nodeL, n_clusters), None

    # First pass: assign labels to all particles.
    (nodeL, _), _ = jax.lax.scan(partial(assign_labels, no_labeled_neighbors_fn=new_label),
                                 (nodeL, 0), xs=particle_indices)

    def check_labels(n_converged, idx, nodeL):
        # Count particle idx as converged if all of its neighbors share its label. Particles without
        # neighbors (all entries are padding N) are also converged: a relabel pass never changes them.
        # Without this, a single isolated particle would keep the early exit below from ever triggering.
        labels = nodeL[neighbors.node_next[idx]]
        unique_labels = jnp.unique(labels, size=3, fill_value=num_particles)
        single_label = unique_labels[1] == num_particles
        matches_own = (unique_labels[0] == nodeL[idx]) | (unique_labels[0] == num_particles)
        return n_converged + (single_label & matches_own).astype(jnp.int32), None

    def fix_labels(nodeL):
        (nodeL, _), _ = jax.lax.scan(partial(assign_labels, no_labeled_neighbors_fn=keep_label),
                                     (nodeL, 0), xs=particle_indices)
        return nodeL

    def relabel_iteration(nodeL, _):
        n_converged, _ = jax.lax.scan(partial(check_labels, nodeL=nodeL), jnp.int32(0),
                                      xs=particle_indices)
        # Once every particle is converged, a relabel pass would be a no-op, so skip it.
        nodeL = jax.lax.cond(n_converged == num_particles, lambda x: x, fix_labels, nodeL)
        return nodeL, None

    # Fix the cluster labels of neighboring particles (at most num_iter passes).
    nodeL, _ = jax.lax.scan(relabel_iteration, nodeL, xs=jnp.arange(num_iter))

    # Distinct labels, sorted; padded with N so the number of rows is static.
    particle_labels = nodeL[:-1]
    cluster_ids = jnp.unique(particle_labels, size=num_particles, fill_value=num_particles)

    @partial(jnp.vectorize, signature='()->(d)')
    def cluster_members(cluster_id):
        return jnp.where(particle_labels == cluster_id, size=num_particles, fill_value=num_particles)[0]

    @partial(jnp.vectorize, signature='()->(d)')
    def cluster_mask(cluster_id):
        return jnp.where(particle_labels == cluster_id,
                         jnp.ones(num_particles, dtype=jnp.int32),
                         jnp.zeros(num_particles, dtype=jnp.int32))

    clusters = cluster_members(cluster_ids)
    masks = cluster_mask(cluster_ids)
    n_part_per_cluster = jnp.sum(clusters < num_particles, axis=1)
    n_clusters = jnp.sum(jnp.min(clusters, axis=1) < num_particles)
    return Clusters(n_clusters, clusters, n_part_per_cluster, masks)
def get_cluster_particles(clusters: Clusters, idx: int) -> Array:
    """
    Return only the indices of the particles in cluster `idx`, with the padding entries removed.

    The length of the result depends on the data (`n_part_per_cluster[idx]`), so this expects
    concrete arrays; a traced slice bound cannot be used under `jax.jit`.
    """
    n_members = clusters.n_part_per_cluster[idx]
    return clusters.clusters[idx, :n_members]
@partial(jax.jit, static_argnames=('num_particles',))
def get_cluster_mask(cluster: Array, num_particles: int = None) -> Array:
    """
    Create an int32 0/1 mask for particles belonging to a cluster given as an array of indices.

    Args:
        cluster: 1D array of particle indices. Entries equal to `num_particles` are treated as padding
            (the fill value used by `clustering`) and do not appear in the mask.
        num_particles: (static) total number of particles N, i.e. the length of the returned mask.
            Defaults to `cluster.shape[0]`, which is correct for the padded rows of
            `Clusters.clusters`. Pass it explicitly for unpadded index arrays, such as the output
            of `get_cluster_particles`.

    Returns:
        int32 array of length N with 1 at the indices listed in `cluster` and 0 elsewhere.
    """
    if num_particles is None:
        num_particles = cluster.shape[0]
    # One extra trailing slot absorbs the padding index N and is dropped afterwards.
    mask_extended = jnp.zeros(num_particles + 1, dtype=jnp.int32)
    mask_extended = mask_extended.at[cluster].set(1)
    return mask_extended[:-1]
def get_all_cluster_masks(clusters: Clusters) -> Array:
    """
    Create an NxN integer array with each row a 0/1 mask for one particle cluster.

    Row i is the mask of `clusters.clusters[i]`, computed with `get_cluster_mask`. Rows with
    i >= clusters.n_clusters contain only padding indices, so all of their elements are 0.
    """
    row_masks = jax.vmap(get_cluster_mask)
    return row_masks(clusters.clusters)
|