123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536 |
- from typing import Callable, Literal, Protocol, Any
- import jax
- import jax.numpy as jnp
- from curvature_assembly import oriented_particle, clustering, smap, surface_fit, surface_fit_general
- from jax_md import rigid_body, space
- from functools import partial
Array = jnp.ndarray  # array type used in annotations throughout this module


class CostFn(Protocol):
    """Structural type shared by the cost functions in this module.

    An implementation is called with a rigid-body configuration and the
    interaction parameters as keyword arguments, and returns the cost as an
    array. ``str()`` on it gives a human-readable description of the cost.
    """

    def __call__(self, body: rigid_body.RigidBody, **interaction_params: dict) -> Array:
        """Evaluate the cost for ``body``."""

    def __str__(self):
        """Describe the cost function."""
@jax.jit
def jit_flat_plane_cost_function(body: rigid_body.RigidBody, **interaction_params) -> float:
    """Return the smallest eigenvalue of the covariance of all particle centers.

    The value is zero when every particle center lies in one plane, so it
    measures how far the whole configuration is from being flat. Positions are
    used as given (no periodic-boundary unwrapping). ``interaction_params`` is
    accepted only for interface compatibility and is not used.
    """
    positions = body.center
    deviations = positions - jnp.mean(positions, axis=0)
    outer_products = jnp.einsum('ni, nj -> nij', deviations, deviations)
    covariance = 1 / positions.shape[0] * jnp.sum(outer_products, axis=0)
    # eigenvalues come back in ascending order, so index 0 is the smallest
    return jnp.linalg.eigvalsh(covariance)[0]
def single_cluster_cost(mask, body):
    """Return the flatness cost of a single cluster selected by ``mask``.

    :param mask: per-particle weights selecting the cluster (1 for members, 0 otherwise);
        its length must match the number of rows of ``body.center``
    :param body: object exposing particle positions as ``body.center``
    :return: smallest eigenvalue of the mask-weighted position covariance
        (zero for a planar cluster)
    """
    weights = mask[:, None]
    # the epsilon keeps the normalisation finite for an all-zero mask
    norm = jnp.sum(mask) + 1e-8
    masked_positions = body.center * weights
    centroid = jnp.sum(masked_positions, axis=0) / norm
    deviations = (masked_positions - centroid) * weights
    covariance = 1 / norm * deviations.T @ deviations
    return jnp.linalg.eigvalsh(covariance)[0]
class FlatPlaneClustersCost:
    """Sum of per-cluster flatness costs plus a penalty proportional to the number of clusters."""

    def __init__(self, displacement, contact_fn, num_clusters_penalty: float):
        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
        self.num_clusters_penalty = num_clusters_penalty

    def __str__(self):
        return f"Flat plane clusters cost, penalty={self.num_clusters_penalty}"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params: dict):
        neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
        clusters = clustering.clustering(neighbors)
        # WARNING: using cluster masks may lead to automatic-differentiation problems
        masks = clustering.get_all_cluster_masks(clusters)
        per_cluster = jax.vmap(lambda m: single_cluster_cost(m, body))(masks)
        penalty = self.num_clusters_penalty * clusters.n_clusters
        return jnp.sum(per_cluster) + penalty
class DistanceCost:
    """Half the sum of contact-function values over all ordered pairs of distinct particles."""

    def __init__(self, displacement, contact_fn):
        self.displacement = displacement
        self.contact_function = oriented_particle.get_ellipsoid_contact_function_param(contact_fn, num_steps=25)

    def __str__(self):
        return "Distance cost"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params: dict) -> Array:
        pair_dr = space.map_product(self.displacement)(body.center, body.center)
        eigsys = oriented_particle.eigensystem(body.orientation)
        pair_fn = partial(self.contact_function, eigvals=interaction_params['eigvals'])
        # entry [i, j] is pair_fn(pair_dr[i, j], eigsys[j], eigsys[i])
        row_fn = jax.vmap(pair_fn, in_axes=(0, 0, None), out_axes=0)
        matrix_fn = jax.vmap(row_fn, in_axes=(0, None, 0), out_axes=0)
        contact_values = matrix_fn(pair_dr, eigsys, eigsys)
        # the diagonal (particle with itself) is excluded
        off_diagonal = jnp.float32(1.0) - jnp.eye(body.center.shape[0])
        return 0.5 * jnp.sum(off_diagonal * contact_values)
class NumClustersCost:
    """Cost equal to a fixed penalty times the number of clusters found in the configuration."""

    def __init__(self, displacement, contact_fn, num_clusters_penalty: float):
        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
        self.num_clusters_penalty = num_clusters_penalty

    def __str__(self):
        return f"Num clusters cost, penalty={self.num_clusters_penalty}"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
        neighbor_info = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
        n_clusters = clustering.clustering(neighbor_info).n_clusters
        return self.num_clusters_penalty * n_clusters
class FlatPlaneCost:
    """Sum over clusters of the smallest eigenvalue of each cluster's position covariance."""

    def __init__(self, displacement, contact_fn):
        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)

    def __str__(self):
        return "Flat plane cost"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
        neighbor_info = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
        clusters = clustering.clustering(neighbor_info)
        # WARNING: using cluster masks may lead to automatic-differentiation problems
        masks = clustering.get_all_cluster_masks(clusters)
        return jnp.sum(jax.vmap(lambda m: single_cluster_cost(m, body))(masks))
class SquaredClusterSizeCost:
    """Cost ``N**2 / sum_k n_k**2 - 1`` with ``n_k`` the cluster sizes and ``N`` their sum.

    The cost is zero when all particles belong to a single cluster and grows as
    particles are split across more clusters.
    """

    def __init__(self, displacement, contact_fn):
        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)

    def __str__(self):
        return "Squared cluster size cost"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
        neighbor_info = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
        sizes = clustering.clustering(neighbor_info).n_part_per_cluster
        total = jnp.sum(sizes)
        return total ** 2 / jnp.sum(sizes ** 2) - 1
class FlatNeighbors:
    """Mean local flatness of each particle's ``num_neighbors`` nearest neighbours.

    Nearness is measured by the contact function. For every particle, the
    neighbours with the smallest contact values (excluding the first-ranked
    entry, assumed to be the particle itself) are selected. The smallest
    eigenvalue of their position covariance is then averaged over all particles.
    """

    def __init__(self, displacement, contact_fn, num_neighbors: int = 6, **cf_kwargs):
        self.num_neighbors = num_neighbors
        self.displacement = displacement
        self.contact_function = oriented_particle.get_ellipsoid_contact_function_param(contact_fn, **cf_kwargs)

    def __str__(self):
        return f"Flat neighbors cost, num_neighbors={self.num_neighbors}"

    def local_flatness(self, neighbor_indices: Array, body: rigid_body.RigidBody):
        """Smallest covariance eigenvalue of the particles at ``neighbor_indices``."""
        points = body.center[neighbor_indices]
        deviations = points - jnp.mean(points, axis=0, keepdims=True)
        covariance = 1 / self.num_neighbors * deviations.T @ deviations
        return jnp.linalg.eigvalsh(covariance)[0]

    def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
        pair_dr = space.map_product(self.displacement)(body.center, body.center)
        eigsys = oriented_particle.eigensystem(body.orientation)
        pair_fn = partial(self.contact_function, eigvals=interaction_params['eigvals'])
        contact_matrix = jax.vmap(jax.vmap(pair_fn, (0, 0, None), 0), (0, None, 0), 0)(pair_dr, eigsys, eigsys)
        # skip column 0 of each sorted row: it is expected to be the central particle itself
        nearest = jnp.argsort(contact_matrix, axis=1)[:, 1:self.num_neighbors + 1]
        flatness = jax.vmap(lambda idx: self.local_flatness(idx, body=body))(nearest)
        # NOTE(review): unweighted mean; weighting by distance to the neighbours may be worth considering
        return jnp.mean(flatness)
class WeightedFlatNeighbors:
    """Mean local flatness around each particle, with neighbours weighted by distance.

    For each particle, displacement vectors to all particles are weighted by
    ``weight_fn(|dr|)``. A weighted covariance of these vectors is formed, and its
    eigenvalues ``l0 <= l1 <= l2`` are combined as ``l0 * |l1 - l2|``.
    """

    def __init__(self, displacement, weight_fn: Callable[[Array], Array]):
        self.weight_fn = weight_fn
        # maps displacement(Ra, Rb) over the rows of Ra with a fixed Rb
        self.distance = jax.vmap(displacement, in_axes=(0, None))

    def __str__(self):
        return "Weighted flat neighbors cost"

    def local_flatness(self, particle_coord: Array, body: rigid_body.RigidBody):
        """Weighted flatness measure of the neighbourhood of ``particle_coord``."""
        # adding 1e-8 prevents NaN gradients of the norm at zero displacement (the particle itself)
        dr = self.distance(body.center, particle_coord) + 1e-8
        neighbor_weights = self.weight_fn(jnp.linalg.norm(dr, axis=1))
        total_weight = jnp.sum(neighbor_weights)
        mean = 1 / total_weight * jnp.sum(dr * neighbor_weights[:, None], axis=0, keepdims=True)
        centered_r = dr - mean
        # equivalent to centered_r.T @ diag(w) @ centered_r, without building an N x N diagonal matrix
        matrix = 1 / total_weight * (centered_r * neighbor_weights[:, None]).T @ centered_r
        values = jnp.linalg.eigvalsh(matrix)
        return values[0] * jnp.abs(values[1] - values[2])

    def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
        neighbor_flatness = jax.vmap(partial(self.local_flatness, body=body), in_axes=(0,))(body.center)
        # NOTE(review): unweighted mean over particles
        return jnp.mean(neighbor_flatness)
class WeightedDistanceCost:
    """Negative per-particle sum of ``weight_fn`` applied to pairwise contact-function values.

    The pair sum is delegated to ``smap.oriented_pair`` (presumably a sum over particle
    pairs — confirm against smap). The result is negated and divided by the particle count.
    """

    def __init__(self, displacement,
                 contact_fn,
                 weight_fn: Callable[[Array], Array],
                 **cf_kwargs):
        # This was previously misspelled ``_init__``, so it never ran: construction with
        # arguments raised TypeError and ``cf_cost_fn`` was never set.
        contact_function = oriented_particle.get_ellipsoid_contact_function_param(contact_fn, **cf_kwargs)

        def weighted_pair(r, e1, e2, **params):
            # eigvals passed by keyword, like every other contact-function call in this module
            return weight_fn(contact_function(r, e1, e2, eigvals=params['eigvals']))

        self.cf_cost_fn = smap.oriented_pair(weighted_pair, displacement)

    def __str__(self):
        return "Weighted distance cost"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
        num_particles = body.center.shape[0]
        return -self.cf_cost_fn(body, **interaction_params) / num_particles
def normal_weight(x: Array, sigma: float, displacement: float = 1.):
    """Gaussian density with mean ``displacement`` and standard deviation ``sigma``, evaluated at ``x``."""
    z = (x - displacement) / sigma
    normalisation = sigma * jnp.sqrt(2 * jnp.pi)
    return 1 / normalisation * jnp.exp(-0.5 * z ** 2)
def center_of_mass_pbc(coord, box_size, mask):
    """Calculate the center of mass of a cluster given by mask, taking into account periodic boundary conditions.

    Each coordinate is mapped to an angle on a circle of circumference ``box_size``.
    The mask-weighted means of the sines and cosines are converted back to an angle,
    which avoids the wrap-around artefacts of a plain arithmetic mean.

    :param coord: particle coordinates, one row per particle
    :param box_size: side length of the periodic box
    :param mask: per-particle weights selecting the cluster
    :return: center of mass, each component in ``[0, box_size]``
    """
    weights = mask[:, None]
    count = jnp.sum(mask)
    theta = coord * 2 * jnp.pi / box_size
    mean_sin = jnp.sum(jnp.sin(theta) * weights, axis=0) / count
    mean_cos = jnp.sum(jnp.cos(theta) * weights, axis=0) / count
    # arctan2 of the negated means lies in [-pi, pi]; shifting by pi gives an angle in [0, 2*pi]
    mean_theta = jnp.pi + jnp.arctan2(-mean_sin, -mean_cos)
    return mean_theta * box_size / (2 * jnp.pi)
def displace_with_periodic_cm(coord, displacement, box_size, mask):
    """Displace all coordinates with a center of mass for a cluster given by mask.

    Returns ``displacement(coord[i], cm)`` for every particle ``i``, where ``cm`` is the
    periodic center of mass of the masked cluster (see ``center_of_mass_pbc``).
    """
    center = center_of_mass_pbc(coord, box_size, mask)
    return jax.vmap(displacement, in_axes=(0, None))(coord, center)
def contiguous_clusters(mask_pbc: Array, mask_box: Array, coord: Array, box_size: float):
    """
    Map a cluster in a PBC box to a contiguous cluster centred at the coordinate origin.

    The algorithm is simple but can fail in certain cases. It compares the PBC-aware center of
    mass of the whole cluster with the centers of mass of the subclusters obtained without
    periodic boundary conditions. A subcluster is shifted by one box period (box size) along
    every axis on which this distance is more than half of the box size.

    :param mask_pbc: mask for a SINGLE cluster, taking into account PBC
    :param mask_box: N x N array containing all the masks for subclusters without PBC
    :param coord: coordinates of all particles in the box, one row per particle
    :param box_size: size of box side, assumes all sides are equal
    :return: displaced coordinates of all particles, with the mask-weighted mean of the given
        cluster at the coordinate origin.
    """
    cm = center_of_mass_pbc(coord, box_size, mask_pbc)
    # plain (non-PBC) centers of mass of the subclusters; the epsilon keeps empty mask rows finite
    subcluster_sizes = jnp.sum(mask_box, axis=1) + 1e-6
    cm_subcl = jnp.sum(mask_box[..., None] * coord[None, ...], axis=1) / subcluster_sizes[:, None]
    dist_between_cm = cm_subcl - cm[None, :]
    # Shift by +-box_size on each axis where the subcluster lies more than half a box away.
    # Broadcasting the box size keeps the coordinate dtype (no float64 request, which JAX
    # truncates with a warning when x64 is disabled) and works for any spatial dimension.
    cluster_displacements = jnp.where(jnp.abs(dist_between_cm) > box_size / 2,
                                      jnp.sign(dist_between_cm) * box_size,
                                      0.)
    # jnp.isclose must be used as we add 1e-12 to mask_pbc in some places to enable differentiation through clustering
    relevant_subclusters = jnp.all(jnp.isclose(mask_pbc[None, :] * mask_box, mask_box), axis=1)
    cluster_displacements = cluster_displacements * relevant_subclusters[:, None]
    # sum just collapses the first dimension with no overlap, as each particle belongs to only one subcluster
    particle_displacements = jnp.sum(cluster_displacements[:, None, :] * mask_box[:, :, None], axis=0)
    displaced_coord = coord - particle_displacements - cm
    return displaced_coord - jnp.sum(displaced_coord * mask_pbc[:, None], axis=0) / jnp.sum(mask_pbc)
def box_displacement(Ra, Rb):
    """Calculate displacement vector in the simulation box without PBCs.

    Delegates to jax_md's ``space.pairwise_displacement`` (presumably ``Ra - Rb``).
    """
    free_displacement = space.pairwise_displacement
    return free_displacement(Ra, Rb)
# supported ways of averaging fit residuals over a cluster
ResidualsAvgType = Literal['linear', 'quadratic']
# averaging function signature: (residuals, mask) -> averaged residual
ResidualsAvgFn = Callable[[Array, Array], Array]


def residuals_avg_fn_factory(which: ResidualsAvgType) -> ResidualsAvgFn:
    """Build a function that averages fit residuals over the particles of a cluster.

    ``'linear'`` averages absolute residuals and ``'quadratic'`` averages squared residuals.
    Both divide the residual sum by the mask sum plus ``1e-6``, which avoids division by zero.

    :raises ValueError: if ``which`` is not a supported averaging type
    """
    if which not in ('linear', 'quadratic'):
        raise ValueError('Unknown type of residuals cost function.')

    if which == 'linear':
        def average(residuals, mask):
            return jnp.sum(jnp.abs(residuals)) / (jnp.sum(mask) + 1e-6)
    else:
        def average(residuals, mask):
            return jnp.sum(residuals ** 2) / (jnp.sum(mask) + 1e-6)
    return average
@partial(jax.jit, static_argnums=(4,))
def single_cluster_curv_radius(mask_pbc: Array,
                               mask_box: Array,
                               body: rigid_body.RigidBody,
                               box_size: float,
                               residuals_avg_fn: ResidualsAvgFn = residuals_avg_fn_factory('linear')):
    """
    Fit a sphere to one cluster and return its radius and averaged fit residual.

    The cluster selected by ``mask_pbc`` is first made contiguous across periodic boundaries
    and centred at the origin (``contiguous_clusters``). Two Gauss-Newton sphere fits are
    started from centers at +2 and -2 along the eigenvector of the smallest eigenvalue of
    the cluster's position matrix. The fit with the smaller averaged residual is preferred
    unless its |radius| is not below 1e4, which is treated as a failed fit; in that case the
    other fit is returned.

    :param mask_pbc: per-particle mask of a SINGLE cluster, taking PBC into account
    :param mask_box: N x N array with the masks of all subclusters found without PBC
    :param body: whole-system rigid body; only ``body.center`` is used
    :param box_size: side length of the simulation box (all sides assumed equal)
    :param residuals_avg_fn: ``(residuals, mask) -> scalar`` averaging function; static under jit
    :return: tuple ``(|radius|, averaged residual)`` of the selected fit
    """
    displaced_coord = contiguous_clusters(mask_pbc, mask_box, body.center, box_size)  # frame of reference without PBC
    cluster_coord = displaced_coord * mask_pbc[:, None]
    matrix = 1 / jnp.sum(mask_pbc) * cluster_coord.T @ cluster_coord
    values, vecs = jnp.linalg.eigh(matrix)
    p0_1 = surface_fit.QuadraticSurfaceParams(center=2 * vecs[:, 0],
                                              radius=1.)
    p0_2 = surface_fit.QuadraticSurfaceParams(center=-2 * vecs[:, 0],
                                              radius=1.)
    opt1 = surface_fit.surface_fit_gn(surface_fit.spherical_surface, displaced_coord, mask_pbc, p0=p0_1)
    opt2 = surface_fit.surface_fit_gn(surface_fit.spherical_surface, displaced_coord, mask_pbc, p0=p0_2)
    mean_residuals1 = residuals_avg_fn(surface_fit.spherical_surface(opt1, displaced_coord, mask_pbc), mask_pbc)
    mean_residuals2 = residuals_avg_fn(surface_fit.spherical_surface(opt2, displaced_coord, mask_pbc), mask_pbc)
    # first we order solutions based on fit residuals
    (opt1, opt2), (mr1, mr2) = jax.lax.cond(mean_residuals1 < mean_residuals2,
                                            lambda: ((opt1, opt2), (mean_residuals1, mean_residuals2)),
                                            lambda: ((opt2, opt1), (mean_residuals2, mean_residuals1)))
    # the second criterion makes sure to select the fit that didn't fail to converge
    # (in most cases, at least one converges)
    return jax.lax.cond(jnp.abs(opt1.radius) < 1e4,
                        lambda: (jnp.abs(opt1.radius), mr1),
                        lambda: (jnp.abs(opt2.radius), mr2))
class CurvedClustersCost:
    """Penalise the deviation of fitted cluster curvature radii from a target radius.

    A sphere is fitted to every cluster. Radii are capped at
    ``radius_cutoff_mul * target_radius``. The cost is the sum of
    ``log(R / target_radius)**2`` over clusters with more than 3 particles, each weighted
    by the fraction of particles in it, divided by the number of such clusters.
    """

    def __init__(self,
                 displacement,
                 box_size,
                 contact_fn,
                 target_radius,
                 radius_cutoff_mul=100):
        self.box_size = box_size
        self.target_radius = target_radius
        self.displacement = displacement
        self.cutoff = radius_cutoff_mul
        # neighbouring with PBC gives the clusters; without PBC it gives their subclusters
        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
        self.neighboring_box = clustering.get_ellipsoid_neighboring_fn(box_displacement, contact_fn)

    def __str__(self):
        return f"Curved clusters cost, target_radius={self.target_radius}"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
        particle_eigvals = interaction_params['eigvals']
        clusters = clustering.clustering(
            self.neighboring(body, eigvals=particle_eigvals, neigh_contact_fn=1.4))
        clusters_box = clustering.clustering(
            self.neighboring_box(body, eigvals=particle_eigvals, neigh_contact_fn=1.4))
        fit_radius = partial(single_cluster_curv_radius,
                             mask_box=clusters_box.masks,
                             body=body,
                             box_size=self.box_size)
        # 1e-12 is added to the masks to avoid NaN gradients from all-zero masks; it can change
        # fits of small clusters, but only clusters with more than 3 particles are weighted below
        radii, _ = jax.vmap(fit_radius)(clusters.masks + 1e-12)
        # fitted radii can get arbitrarily large, so they are capped
        max_radius = self.cutoff * self.target_radius
        radii = jnp.where(radii > max_radius, max_radius, radii)
        sizes = clusters.n_part_per_cluster
        # clusters with 3 or fewer particles get zero weight
        weights = jnp.where(sizes > 3, sizes / body.center.shape[0], 0.)
        n_weighted = jnp.sum(weights > 0) + 1e-6  # epsilon avoids dividing by 0
        log_ratio_sq = jnp.log(radii / self.target_radius) ** 2
        return jnp.sum(log_ratio_sq * weights) / n_weighted
class CurvedClustersResidualsCost:
    """Curvature-radius cost plus a weighted penalty on sphere-fit residuals.

    For every cluster with more than 3 particles, the per-cluster term is
    ``log(R / target_radius)**2 + residuals_cost_factor * averaged_residual``. It is weighted
    by the fraction of particles in the cluster, and the weighted terms are summed.
    """

    def __init__(self,
                 displacement,
                 box_size,
                 contact_fn,
                 target_radius,
                 radius_cutoff_mul=100,
                 residuals_cost_factor=1,
                 residuals_avg_type: ResidualsAvgType = 'linear'):
        self.box_size = box_size
        self.target_radius = target_radius
        self.displacement = displacement
        self.cutoff = radius_cutoff_mul
        self.res_cost_fac = residuals_cost_factor
        self.residuals_avg_type = residuals_avg_type
        self.residuals_avg_fn = residuals_avg_fn_factory(residuals_avg_type)
        # neighbouring with PBC gives the clusters; without PBC it gives their subclusters
        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
        self.neighboring_box = clustering.get_ellipsoid_neighboring_fn(box_displacement, contact_fn)

    def __str__(self):
        return (f"Curved clusters residuals cost, target_radius={self.target_radius}, "
                f"{self.residuals_avg_type} residuals with factor {self.res_cost_fac}.")

    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
        particle_eigvals = interaction_params['eigvals']
        clusters = clustering.clustering(
            self.neighboring(body, eigvals=particle_eigvals, neigh_contact_fn=1.4))
        clusters_box = clustering.clustering(
            self.neighboring_box(body, eigvals=particle_eigvals, neigh_contact_fn=1.4))
        fit = partial(single_cluster_curv_radius,
                      mask_box=clusters_box.masks,
                      body=body,
                      box_size=self.box_size,
                      residuals_avg_fn=self.residuals_avg_fn)
        # 1e-12 is added to the masks to avoid NaN gradients from all-zero masks; it can change
        # fits of small clusters, but only clusters with more than 3 particles are weighted below
        radii, residuals = jax.vmap(fit)(clusters.masks + 1e-12)
        # fitted radii can get arbitrarily large, so they are capped
        max_radius = self.cutoff * self.target_radius
        radii = jnp.where(radii > max_radius, max_radius, radii)
        # averaged residuals above 1e2 are zeroed (presumably failed fits — confirm intent)
        residuals = jnp.where(residuals > 1e2, 0., residuals)
        sizes = clusters.n_part_per_cluster
        # clusters with 3 or fewer particles get zero weight
        weights = jnp.where(sizes > 3, sizes / body.center.shape[0], 0.)
        per_cluster = jnp.log(radii / self.target_radius) ** 2 + self.res_cost_fac * residuals
        # the weighted sum is not normalised by the number of contributing clusters
        return jnp.sum(per_cluster * weights)
@partial(jax.jit, static_argnums=(4,))
def single_cluster_quadratic_surface(mask_pbc: Array,
                                     mask_box: Array,
                                     body: rigid_body.RigidBody,
                                     box_size: float,
                                     surface_constant: int = -1,
                                     ) -> Array:
    """
    Fit a general quadratic surface to one cluster and return the eigenvalues of its quadratic form.

    The cluster selected by ``mask_pbc`` is made contiguous across periodic boundaries
    (``contiguous_clusters``). It is then fitted with
    ``surface_fit_general.quadratic_surface`` using ``constant=surface_constant``, starting
    from default ``GeneralQuadraticSurfaceParams``.

    :param mask_pbc: per-particle mask of a SINGLE cluster, taking PBC into account
    :param mask_box: N x N array with the masks of all subclusters found without PBC
    :param body: whole-system rigid body; only ``body.center`` is used
    :param box_size: side length of the simulation box (all sides assumed equal)
    :param surface_constant: constant term of the quadratic surface; static under jit
    :return: eigenvalues of the fitted quadratic form, in ascending order
    """
    contiguous_coord = contiguous_clusters(mask_pbc, mask_box, body.center, box_size)
    surface_fn = partial(surface_fit_general.quadratic_surface, constant=surface_constant)
    initial_params = surface_fit_general.GeneralQuadraticSurfaceParams()
    fitted = surface_fit_general.surface_fit_gn(surface_fn, contiguous_coord, mask_pbc, p0=initial_params)
    return jnp.linalg.eigvalsh(fitted.quadratic_form)
class QuadraticSurfaceClustersCost:
    """Penalise the deviation of fitted quadratic-form eigenvalues from target eigenvalues.

    For every cluster with more than 3 particles, the squared differences between the
    sorted fitted eigenvalues and the sorted targets are weighted by the fraction of
    particles in the cluster. They are summed and divided by the number of such clusters.
    """

    def __init__(self,
                 displacement,
                 box_size,
                 contact_fn,
                 target_eigvals,
                 surface_constat=-1):
        self.box_size = box_size
        # sorted so they line up with the ascending eigenvalues returned by the fit
        self.target_eigvals = jnp.sort(target_eigvals)
        self.displacement = displacement
        # NOTE(review): parameter/attribute name is a misspelling of "surface_constant"; kept for compatibility
        self.surface_constat = surface_constat
        # neighbouring with PBC gives the clusters; without PBC it gives their subclusters
        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
        self.neighboring_box = clustering.get_ellipsoid_neighboring_fn(box_displacement, contact_fn)

    def __str__(self):
        # NOTE(review): label says "Curved clusters cost" although this is the quadratic-surface cost
        return f"Curved clusters cost, target_eigvals={self.target_eigvals}"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
        particle_eigvals = interaction_params['eigvals']
        clusters = clustering.clustering(
            self.neighboring(body, eigvals=particle_eigvals, neigh_contact_fn=1.4))
        clusters_box = clustering.clustering(
            self.neighboring_box(body, eigvals=particle_eigvals, neigh_contact_fn=1.4))
        fit = partial(single_cluster_quadratic_surface,
                      mask_box=clusters_box.masks,
                      body=body,
                      box_size=self.box_size,
                      surface_constant=self.surface_constat)
        # 1e-12 is added to the masks to avoid NaN gradients from all-zero masks; it can change
        # fits of small clusters, but only clusters with more than 3 particles are weighted below
        fitted_eigvals = jax.vmap(fit)(clusters.masks + 1e-12)
        sizes = clusters.n_part_per_cluster
        # clusters with 3 or fewer particles get zero weight
        weights = jnp.where(sizes > 3, sizes / body.center.shape[0], 0.)
        n_weighted = jnp.sum(weights > 0) + 1e-6  # epsilon avoids dividing by 0
        sq_dev = (fitted_eigvals - self.target_eigvals) ** 2
        return jnp.sum(sq_dev * weights[:, None]) / n_weighted
def single_cluster_cylinder_radius(mask_pbc: Array,
                                   mask_box: Array,
                                   body: rigid_body.RigidBody,
                                   box_size: float):
    """
    Fit a cylinder to one cluster and return the absolute radius of the better fit.

    The cluster selected by ``mask_pbc`` is made contiguous across periodic boundaries
    (``contiguous_clusters``). Two Gauss-Newton cylinder fits are started from centers at
    +3 and -3 along the eigenvector of the smallest eigenvalue of the cluster's position
    matrix. Both use an initial orientation derived from the eigenvector of the largest
    eigenvalue and an initial radius of 5. The fit with the smaller mean absolute
    cylindrical residual is selected.

    :param mask_pbc: per-particle mask of a SINGLE cluster, taking PBC into account
    :param mask_box: N x N array with the masks of all subclusters found without PBC
    :param body: whole-system rigid body; only ``body.center`` is used
    :param box_size: side length of the simulation box (all sides assumed equal)
    :return: absolute radius of the selected cylinder fit
    """
    displaced_coord = contiguous_clusters(mask_pbc, mask_box, body.center, box_size)  # frame of reference without PBC
    cluster_coord = displaced_coord * mask_pbc[:, None]
    matrix = 1 / jnp.sum(mask_pbc) * cluster_coord.T @ cluster_coord
    _, vecs = jnp.linalg.eigh(matrix)
    long_axis = vecs[:, 2]
    euler_best_direction = jnp.array([jnp.arctan2(long_axis[0], long_axis[1]),
                                      jnp.arctan2(long_axis[2], jnp.sqrt(long_axis[0] ** 2 + long_axis[1] ** 2)),
                                      0])
    p0_1 = surface_fit.QuadraticSurfaceParams(center=3 * vecs[:, 0],
                                              euler=euler_best_direction,
                                              radius=5.)
    p0_2 = surface_fit.QuadraticSurfaceParams(center=-3 * vecs[:, 0],
                                              euler=euler_best_direction,
                                              radius=5.)
    opt1 = surface_fit.surface_fit_gn(surface_fit.cylindrical_surface, displaced_coord, mask_pbc, p0=p0_1)
    opt2 = surface_fit.surface_fit_gn(surface_fit.cylindrical_surface, displaced_coord, mask_pbc, p0=p0_2)
    num_cluster_particles = jnp.sum(mask_pbc) + 1e-6
    # Residuals are evaluated with the same cylindrical model that was fitted (previously the
    # spherical model was used here, so the two cylinder fits were ranked by the wrong surface).
    mean_residuals1 = jnp.sum(jnp.abs(surface_fit.cylindrical_surface(opt1, displaced_coord, mask_pbc))) / num_cluster_particles
    mean_residuals2 = jnp.sum(jnp.abs(surface_fit.cylindrical_surface(opt2, displaced_coord, mask_pbc))) / num_cluster_particles
    return jax.lax.cond(mean_residuals1 < mean_residuals2,
                        lambda: jnp.abs(opt1.radius),
                        lambda: jnp.abs(opt2.radius))
class CylindricalClustersCost:
    """Penalise the deviation of fitted cylinder radii of clusters from a target radius.

    Radii are capped at ``radius_cutoff_mul * target_radius``. The cost is the sum of
    ``log(R / target_radius)**2`` over clusters with more than 3 particles, each weighted by
    the fraction of particles in it, divided by the number of such clusters.
    """

    def __init__(self, displacement,
                 box_size,
                 contact_fn,
                 target_radius,
                 radius_cutoff_mul=1000):
        self.box_size = box_size
        self.target_radius = target_radius
        self.displacement = displacement
        self.cutoff = radius_cutoff_mul
        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
        # neighbouring without PBC: its clusters are the subclusters that
        # single_cluster_cylinder_radius needs (mask_box) to unwrap PBC clusters
        self.neighboring_box = clustering.get_ellipsoid_neighboring_fn(box_displacement, contact_fn)

    def __str__(self):
        return f"Cylindrical clusters cost, target_radius={self.target_radius}"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
        neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
        neighbors_box = self.neighboring_box(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
        clusters = clustering.clustering(neighbors)
        clusters_box = clustering.clustering(neighbors_box)
        # we add 1e-12 to cluster masks to avoid NaN gradients that arise when the entire mask is 0
        # this can change the fit values for small clusters, however, we later only take into account clusters with N > 3
        # (single_cluster_cylinder_radius takes mask_box and has no `displacement` parameter)
        all_radii = jax.vmap(partial(single_cluster_cylinder_radius,
                                     mask_box=clusters_box.masks,
                                     body=body,
                                     box_size=self.box_size))(clusters.masks + 1e-12)
        # as curvature radius of a cluster can get arbitrarily large, we cut off calculated radii
        max_radius = self.cutoff * self.target_radius
        all_radii = jnp.where(all_radii > max_radius, max_radius, all_radii)
        cluster_weights = clusters.n_part_per_cluster / body.center.shape[0]
        relevant_clusters = clusters.n_part_per_cluster > 3
        cluster_weights = jnp.where(relevant_clusters, cluster_weights, 0.)  # clusters with N <= 3 are excluded
        num_of_weighted_clusters = jnp.sum(cluster_weights > 0) + 1e-6  # we add a small number to avoid dividing by 0
        return jnp.sum((jnp.log(all_radii / self.target_radius) ** 2) * cluster_weights) / num_of_weighted_clusters
class CostCombinator:
    """Linear combination ``sum_i coefficients[i] * cost_fns[i](body, **params)`` of cost functions.

    :raises ValueError: if ``cost_fns`` and ``coefficients`` differ in length
    """

    def __init__(self, cost_fns: list[CostFn], coefficients: list[float]):
        n_fns, n_coefs = len(cost_fns), len(coefficients)
        if n_fns != n_coefs:
            raise ValueError(f'Lengths of cost_fn list and coefficients list should be equal, '
                             f'got {n_fns} and {n_coefs}, respectively')
        self.cost_fns = cost_fns
        self.coefficients = coefficients

    def __str__(self):
        # one "<coef> x <cost description>" line per term
        return "".join(f"{coef} x {cf} \n" for cf, coef in zip(self.cost_fns, self.coefficients))

    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
        total = 0
        for cost_fn, coef in zip(self.cost_fns, self.coefficients):
            total = total + coef * cost_fn(body, **interaction_params)
        return total
|