from functools import partial
from typing import Callable, Optional

import jax
import jax.numpy as jnp
from jax_md import dataclasses, space, rigid_body

from curvature_assembly import oriented_particle

Array = jnp.ndarray


@dataclasses.dataclass
class Clusters:
    """Class to store cluster data.

    Attributes:
        n_clusters: number of clusters found.
        clusters: NxN array; row i lists the particle indices of cluster i, padded with N
            (the total number of particles).
        n_part_per_cluster: length-N array with the number of particles in each cluster row.
        masks: NxN int32 array; masks[i, j] is 1 if particle j belongs to cluster i, else 0.
    """
    n_clusters: int
    clusters: Array
    n_part_per_cluster: Array
    masks: Array


@dataclasses.dataclass
class Neighbors:
    """Class to store connectivity data.

    Attributes:
        node_next: array whose row i lists the indices of particles in contact with particle i,
            padded with the total number of particles.
        n_contacts_per_particle: number of contacts of each particle.
        n_links: total number of contact pairs (each pair counted once).
    """
    node_next: Array
    n_contacts_per_particle: Array
    n_links: int


def nonzero_1d(array: Array, size: Optional[int] = None, fill_value: Optional[float] = None) -> Array:
    """
    Apply jnp.nonzero independently to every 1-D slice along the last axis of array.

    Args:
        array: input array; nonzero entries are located along its last axis
        size: fixed number of indices returned per slice (needed for the result to have a static shape)
        fill_value: value used to pad slices with fewer than size nonzero entries

    Returns:
        Array of nonzero indices with the last axis of length size.
    """
    nonzero_fn = partial(jnp.nonzero, size=size, fill_value=fill_value)
    # jnp.nonzero returns a tuple with one index array per input dimension; for 1-D slices
    # this is a 1-tuple, so [0] extracts the index array itself.
    return jnp.vectorize(nonzero_fn, signature='(k)->(d)')(array)[0]


def get_neighboring_fn(displacement: space.DisplacementFn, max_contacts: int = 16) -> Callable:
    """
    Return distance based neighboring function used to get particle contact information.

    Args:
        displacement: displacement (or metric) function of the simulation space
        max_contacts: number of neighbor slots stored per particle in Neighbors.node_next

    Returns:
        Jitted function (coord, neigh_dist) -> Neighbors.
    """

    @jax.jit
    def neighboring(coord: Array, neigh_dist: float) -> Neighbors:
        metric = space.canonicalize_displacement_or_metric(displacement)
        dr = space.map_product(metric)(coord, coord)
        # zero diagonal so that a particle is never its own neighbor
        mask = jnp.float32(1.0) - jnp.eye(coord.shape[0])
        contacts = mask * jnp.array(dr < neigh_dist, dtype=jnp.int32)
        # every contact appears twice in the symmetric contacts matrix; note that the
        # multiplication by 0.5 makes n_links a floating point value
        n_links = 0.5 * jnp.sum(contacts, dtype=jnp.int32)
        n_contacts_per_particle = jnp.sum(contacts, axis=0, dtype=jnp.int32)
        # rows are padded with len(coord), an index one past the last particle; a particle with
        # more than max_contacts contacts only has its first max_contacts neighbors stored
        node_next = jnp.asarray(nonzero_1d(contacts, size=max_contacts, fill_value=len(coord)),
                                dtype=jnp.int32)
        return Neighbors(node_next, n_contacts_per_particle, n_links)

    return neighboring


def get_ellipsoid_neighboring_fn(displacement: space.DisplacementFn,
                                 contact_fn: Callable[..., Array],
                                 max_contacts: int = 16,
                                 **cf_kwargs) -> Callable:
    """
    Return contact function based neighboring function used to get particle contact information.

    Args:
        displacement: displacement function of the simulation space
        contact_fn: contact function passed to oriented_particle.get_ellipsoid_contact_function_param
        max_contacts: number of neighbor slots stored per particle in Neighbors.node_next
        **cf_kwargs: additional keyword arguments forwarded together with contact_fn

    Returns:
        Jitted function (body, eigvals, neigh_contact_fn) -> Neighbors, where two particles are
        in contact if their contact function value is below neigh_contact_fn.
    """
    contact_function = oriented_particle.get_ellipsoid_contact_function_param(contact_fn, **cf_kwargs)

    @jax.jit
    def neighboring(body: rigid_body.RigidBody, eigvals: Array, neigh_contact_fn: float) -> Neighbors:
        num_particles = body.center.shape[0]
        dr = space.map_product(displacement)(body.center, body.center)
        eigsys = oriented_particle.eigensystem(body.orientation)
        # evaluate the contact function for all particle pairs: the inner vmap runs over the
        # first two arguments, the outer one over the first and the third
        # NOTE(review): presumably the arguments are (dr, eigsys of one particle, eigsys of the
        # other) - confirm against oriented_particle
        mapped_cf = jax.vmap(jax.vmap(partial(contact_function, eigvals=eigvals), (0, 0, None), 0),
                             (0, None, 0), 0)
        cf = mapped_cf(dr, eigsys, eigsys)
        # zero diagonal so that a particle is never its own neighbor
        mask = jnp.float32(1.0) - jnp.eye(num_particles)
        contacts = mask * jnp.array(cf < neigh_contact_fn, dtype=jnp.int32)
        # every contact appears twice in the contacts matrix; note that the multiplication
        # by 0.5 makes n_links a floating point value
        n_links = 0.5 * jnp.sum(contacts, dtype=jnp.int32)
        n_contacts_per_particle = jnp.sum(contacts, axis=0, dtype=jnp.int32)
        # rows are padded with num_particles, an index one past the last particle
        node_next = jnp.asarray(nonzero_1d(contacts, size=max_contacts, fill_value=num_particles),
                                dtype=jnp.int32)
        return Neighbors(node_next, n_contacts_per_particle, n_links)

    return neighboring


@partial(jax.jit, static_argnums=1)
def clustering(neighbors: Neighbors, num_iter: int = 20) -> Clusters:
    """
    Clustering algorithm from de Oliveira et.al., https://www.tandfonline.com/doi/full/10.1080/08927022.2020.1839661

    Args:
        neighbors: instance of Neighbors object where data about each particle neighbors is stored
        num_iter: number of iterations to fix cluster labels of neighboring particles

    Returns:
        Clustering data stored in Clusters object with the following attributes:
            n_clusters - number of clusters,
            n_part_per_cluster - array of length N storing the lengths of all clusters.
                                 Elements with index i >= n_clusters are 0.
            clusters - NxN array where each row i stores indices of particles in one cluster.
                       clusters[i, j] with i >= n_clusters and/or j >= n_part_per_cluster[i]
                       are set to the total number of particles in the system.
            masks - NxN int32 array where masks[i, j] is 1 if particle j belongs to the cluster
                    stored in row i and 0 otherwise.
    """
    num_particles = neighbors.node_next.shape[0]
    particle_indices = jnp.arange(num_particles)

    nodeLpure = jnp.zeros(num_particles, dtype=jnp.int32)
    # the last index is used as a dump for unused data
    nodeL = jnp.concatenate((nodeLpure, jnp.array([num_particles])))

    def new_label(nodeL, n_clusters, idx, *args):
        # start a new cluster: particle idx gets the next free label
        nodeL = nodeL.at[idx].set(n_clusters)
        return nodeL, n_clusters + 1

    def neighbor_labels(nodeL, n_clusters, idx, neigh_indices, label):
        # particle idx and all of its neighbors take the given label; padding entries in
        # neigh_indices point to the dump slot, which is restored in assign_labels
        nodeL = nodeL.at[idx].set(label)
        nodeL = nodeL.at[neigh_indices].set(label)
        return nodeL, n_clusters

    def assign_labels(nodeL_ncl, idx, no_labeled_neighbors_fn):
        nodeL, n_clusters = nodeL_ncl
        labels = nodeL[neighbors.node_next[idx]]
        # padding entries carry the value num_particles and a label of 0 is treated as unlabeled,
        # so both are excluded when looking for the largest / smallest neighbor label
        max_label = jnp.max(labels, where=(labels < num_particles), initial=0)
        min_label = jnp.min(labels, where=(labels > 0), initial=num_particles)
        nodeL, n_clusters = jax.lax.cond(max_label == 0,
                                         no_labeled_neighbors_fn,
                                         neighbor_labels,
                                         nodeL, n_clusters, idx, neighbors.node_next[idx], min_label)
        # restore the dump slot that neighbor_labels may have overwritten
        nodeL = nodeL.at[num_particles].set(num_particles)
        return (nodeL, n_clusters), 0.

    # assign labels to all clustered particles
    n_clusters = 0
    nodeL_ncl, _ = jax.lax.scan(partial(assign_labels, no_labeled_neighbors_fn=new_label),
                                (nodeL, n_clusters),
                                xs=particle_indices)
    nodeL, _ = nodeL_ncl

    def keep_label(nodeL, n_clusters, *args):
        # used in the fixing passes, where no new labels are created
        return nodeL, n_clusters

    def check_labels(condition, idx, nodeL):
        # add 1 to condition value if all neighbors of particle idx share its label
        labels = nodeL[neighbors.node_next[idx]]
        unique_labels = jnp.unique(labels, size=3, fill_value=num_particles)
        # check if all neighbor labels are the same:
        # NOTE(review): for a particle without neighbors all entries of labels equal num_particles,
        # so all_same looks to be False for it - confirm this is the intended behavior
        all_same = (unique_labels[1] == num_particles) * (unique_labels[0] == nodeL[idx])
        condition = jax.lax.cond(all_same, lambda x: condition + 1, lambda x: condition, 0.)
        return condition, 0.

    def fix_labels(nodeL):
        nodeL_ncl, _ = jax.lax.scan(partial(assign_labels, no_labeled_neighbors_fn=keep_label),
                                    (nodeL, 0),
                                    xs=particle_indices)
        nodeL, _ = nodeL_ncl
        return nodeL

    def relabel_iteration(nodeL, iteration):
        # condition is the number of particles with all neighbors sharing its label
        condition, _ = jax.lax.scan(partial(check_labels, nodeL=nodeL), 0., xs=particle_indices)
        # if condition == num_particles, algorithm converged, and we return the same values for nodeL
        nodeL = jax.lax.cond(condition == num_particles, lambda x: x, fix_labels, nodeL)
        return nodeL, 0.

    # fixing the cluster labels of particles (num_iter iterations)
    nodeL, _ = jax.lax.scan(relabel_iteration, nodeL, xs=jnp.arange(num_iter))

    # determine cluster id values for all clusters
    nodeLpure = nodeL[:-1]  # drop the dump slot
    cluster_ids = jnp.unique(nodeLpure, size=num_particles, fill_value=num_particles)

    @partial(jnp.vectorize, signature='()->(d)')
    def set_cluster_vmap(cluster_id):
        return jnp.where(nodeLpure == cluster_id, size=num_particles, fill_value=num_particles)[0]

    @partial(jnp.vectorize, signature='()->(d)')
    def set_mask_vmap(cluster_id):
        return jnp.where(nodeLpure == cluster_id,
                         jnp.ones(num_particles, dtype=jnp.int32),
                         jnp.zeros(num_particles, dtype=jnp.int32))

    clusters = set_cluster_vmap(cluster_ids)
    masks = set_mask_vmap(cluster_ids)

    # padding entries equal num_particles, so only valid particle indices are counted
    n_part_per_cluster = jnp.sum(clusters < num_particles, axis=1)
    # a row describes a real cluster if it contains at least one valid particle index
    n_clusters = jnp.sum(jnp.min(clusters, axis=1) < num_particles)

    return Clusters(n_clusters, clusters, n_part_per_cluster, masks)


def get_cluster_particles(clusters: Clusters, idx: int) -> Array:
    """Clip fill values from cluster data and return only indices of particles in cluster idx."""
    # the slice length depends on array data, so the values in clusters must be concrete here
    return clusters.clusters[idx, :clusters.n_part_per_cluster[idx]]


@jax.jit
def get_cluster_mask(cluster: Array) -> Array:
    """
    Create a mask for particles belonging to cluster given as array of indices.

    Args:
        cluster: 1-D array of particle indices, padded with cluster.shape[0] as the fill value

    Returns:
        int32 array of the same length as cluster with 1 at the listed particle indices and 0 elsewhere.
    """
    # one extra slot at the end collects the writes of the padding index
    mask_extended = jnp.zeros(cluster.shape[0] + 1, dtype=jnp.int32)
    # integer update value: writing a float into the int32 array is an unsafe cast
    mask_extended = mask_extended.at[cluster].set(1)
    return mask_extended[:-1]


def get_all_cluster_masks(clusters: Clusters) -> Array:
    """
    Create a NxN array with each row representing a mask for one particle cluster. For rows with
    i >= clusters.n_clusters, the elements of all masks are 0.
    """
    f = jax.vmap(get_cluster_mask)
    return f(clusters.clusters)