from functools import partial
from typing import Any, Callable, Literal, Protocol

import jax
import jax.numpy as jnp
from jax_md import rigid_body, space

from curvature_assembly import oriented_particle, clustering, smap, surface_fit, surface_fit_general

Array = jnp.ndarray


class CostFn(Protocol):
    """Structural interface of the cost functions in this module.

    A cost function is called with the whole-system rigid body and the interaction parameters
    (the implementations below read at least ``interaction_params['eigvals']``) and returns the cost.
    """

    def __call__(self, body: rigid_body.RigidBody, **interaction_params: Any) -> Array:
        ...

    def __str__(self):
        ...


@jax.jit
def jit_flat_plane_cost_function(body: rigid_body.RigidBody, **interaction_params) -> Array:
    """Smallest eigenvalue of the gyration tensor of all particle centers.

    The value is zero when all centers lie in a single plane, so minimizing it favors flat configurations.
    Periodic boundary conditions are not taken into account. ``interaction_params`` are accepted only for
    compatibility with :class:`CostFn` and are not used.
    """
    centered = body.center - jnp.mean(body.center, axis=0)
    # (1/N) * sum_n outer(r_n, r_n), computed without materializing the (N, d, d) stack of outer products
    gyration = centered.T @ centered / body.center.shape[0]
    values, _ = jnp.linalg.eigh(gyration)  # eigenvalues in ascending order
    return values[0]


def single_cluster_cost(mask: Array, body: rigid_body.RigidBody) -> Array:
    """Smallest eigenvalue of the gyration tensor of the particles selected by ``mask``.

    :param mask: per-particle cluster mask (one entry per particle, used as a multiplicative weight)
    :param body: whole-system rigid body
    :return: flatness measure of the cluster (0 when the selected particles are coplanar)
    """
    weights = mask[:, None]
    n_part = jnp.sum(mask) + 1e-8  # 1e-8 guards against division by zero for empty masks
    cluster_particles = body.center * weights
    centered = (cluster_particles - jnp.sum(cluster_particles, axis=0) / n_part) * weights
    matrix = 1 / n_part * centered.T @ centered
    values, _ = jnp.linalg.eigh(matrix)
    return values[0]


def _pairwise_contact(contact_function, displacement, body: rigid_body.RigidBody, eigvals) -> Array:
    """N x N matrix of contact-function values between all pairs of particles.

    Entry ``[i, j]`` is ``contact_function(dr[i, j], eigsys[j], eigsys[i], eigvals=eigvals)`` where
    ``dr = space.map_product(displacement)(body.center, body.center)`` and
    ``eigsys = oriented_particle.eigensystem(body.orientation)``.
    """
    dr = space.map_product(displacement)(body.center, body.center)
    eigsys = oriented_particle.eigensystem(body.orientation)
    mapped_cf = jax.vmap(jax.vmap(partial(contact_function, eigvals=eigvals), (0, 0, None), 0), (0, None, 0), 0)
    return mapped_cf(dr, eigsys, eigsys)


class FlatPlaneClustersCost:
    """Sum of per-cluster flatness costs plus a penalty proportional to the number of clusters."""

    def __init__(self, displacement, contact_fn, num_clusters_penalty: float, neigh_contact_fn: float = 1.4):
        """
        :param displacement: displacement function used for neighbor search
        :param contact_fn: ellipsoid contact function used for neighbor search
        :param num_clusters_penalty: cost added per cluster
        :param neigh_contact_fn: value forwarded as ``neigh_contact_fn`` to the neighboring function
        """
        self.num_clusters_penalty = num_clusters_penalty
        self.neigh_contact_fn = neigh_contact_fn
        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)

    def __str__(self):
        return f"Flat plane clusters cost, penalty={self.num_clusters_penalty}"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
        neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'],
                                     neigh_contact_fn=self.neigh_contact_fn)
        clusters = clustering.clustering(neighbors)
        cluster_masks = clustering.get_all_cluster_masks(clusters)
        # WARNING: using cluster mask may lead to AD problems
        cluster_costs = jax.vmap(partial(single_cluster_cost, body=body))(cluster_masks)
        return jnp.sum(cluster_costs) + self.num_clusters_penalty * clusters.n_clusters


class DistanceCost:
    """Half the sum of contact-function values over all ordered pairs of distinct particles."""

    def __init__(self, displacement, contact_fn, **cf_kwargs):
        """
        :param displacement: displacement function of the simulation space
        :param contact_fn: ellipsoid contact function wrapped by
            ``oriented_particle.get_ellipsoid_contact_function_param``
        :param cf_kwargs: extra keyword arguments for ``get_ellipsoid_contact_function_param``;
            ``num_steps`` defaults to 25
        """
        self.displacement = displacement
        cf_kwargs = {'num_steps': 25, **cf_kwargs}
        self.contact_function = oriented_particle.get_ellipsoid_contact_function_param(contact_fn, **cf_kwargs)

    def __str__(self):
        return "Distance cost"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
        num_particles = body.center.shape[0]
        cf = _pairwise_contact(self.contact_function, self.displacement, body, interaction_params['eigvals'])
        # exclude self-interaction on the diagonal
        mask = jnp.float32(1.0) - jnp.eye(num_particles)
        # each unordered pair appears twice in the matrix, hence the factor 0.5
        return 0.5 * jnp.sum(mask * cf)


class NumClustersCost:
    """Penalty proportional to the number of clusters in the system."""

    def __init__(self, displacement, contact_fn, num_clusters_penalty: float, neigh_contact_fn: float = 1.4):
        self.num_clusters_penalty = num_clusters_penalty
        self.neigh_contact_fn = neigh_contact_fn
        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)

    def __str__(self):
        return f"Num clusters cost, penalty={self.num_clusters_penalty}"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
        neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'],
                                     neigh_contact_fn=self.neigh_contact_fn)
        clusters = clustering.clustering(neighbors)
        return self.num_clusters_penalty * clusters.n_clusters


class FlatPlaneCost:
    """Sum of per-cluster flatness costs (see :func:`single_cluster_cost`)."""

    def __init__(self, displacement, contact_fn, neigh_contact_fn: float = 1.4):
        self.neigh_contact_fn = neigh_contact_fn
        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)

    def __str__(self):
        return "Flat plane cost"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
        neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'],
                                     neigh_contact_fn=self.neigh_contact_fn)
        clusters = clustering.clustering(neighbors)
        cluster_masks = clustering.get_all_cluster_masks(clusters)
        # WARNING: using cluster mask may lead to AD problems
        cluster_costs = jax.vmap(partial(single_cluster_cost, body=body))(cluster_masks)
        return jnp.sum(cluster_costs)


class SquaredClusterSizeCost:
    """``(sum n_c)^2 / sum n_c^2 - 1`` over cluster sizes ``n_c``.

    This is 0 for a single cluster and ``k - 1`` for ``k`` equally sized clusters.
    """

    def __init__(self, displacement, contact_fn, neigh_contact_fn: float = 1.4):
        self.neigh_contact_fn = neigh_contact_fn
        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)

    def __str__(self):
        return "Squared cluster size cost"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
        neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'],
                                     neigh_contact_fn=self.neigh_contact_fn)
        clusters = clustering.clustering(neighbors)
        num_particles = jnp.sum(clusters.n_part_per_cluster)
        return num_particles ** 2 / jnp.sum(clusters.n_part_per_cluster ** 2) - 1


class FlatNeighbors:
    """Mean local flatness of each particle's ``num_neighbors`` neighbors with the smallest contact function."""

    def __init__(self, displacement, contact_fn, num_neighbors: int = 6, **cf_kwargs):
        self.num_neighbors = num_neighbors
        self.displacement = displacement
        self.contact_function = oriented_particle.get_ellipsoid_contact_function_param(contact_fn, **cf_kwargs)

    def __str__(self):
        return f"Flat neighbors cost, num_neighbors={self.num_neighbors}"

    def local_flatness(self, neighbor_indices: Array, body: rigid_body.RigidBody):
        """Smallest eigenvalue of the gyration tensor of the particles at ``neighbor_indices``."""
        neighbor_coord = body.center[neighbor_indices]
        centered_r = neighbor_coord - jnp.mean(neighbor_coord, axis=0, keepdims=True)
        matrix = 1 / self.num_neighbors * centered_r.T @ centered_r
        values, _ = jnp.linalg.eigh(matrix)
        return values[0]

    def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
        cf = _pairwise_contact(self.contact_function, self.displacement, body, interaction_params['eigvals'])
        # column 0 of the sorted row is skipped: if we start with idx 0, also the central particle will count
        indices = jnp.argsort(cf, axis=1)[:, 1:self.num_neighbors + 1]
        neighbor_flatness = jax.vmap(partial(self.local_flatness, body=body))(indices)
        # TODO: should this be weighted somehow? Based on particle distance to others?
        return jnp.mean(neighbor_flatness)
class WeightedFlatNeighbors:
    """Mean local flatness in which every particle's neighborhood is weighted by ``weight_fn`` of the distance."""

    def __init__(self, displacement, weight_fn: Callable[[Array], Array]):
        """
        :param displacement: displacement function of the simulation space
        :param weight_fn: maps an array of distances to per-particle weights
        """
        self.weight_fn = weight_fn
        # displacement of every particle (mapped over axis 0) relative to one fixed particle
        self.distance = jax.vmap(displacement, in_axes=(0, None))

    def __str__(self):
        return "Weighted flat neighbors cost"

    def local_flatness(self, particle_coord: Array, body: rigid_body.RigidBody):
        """Weighted flatness of the neighborhood of the particle at ``particle_coord``.

        Returns ``l0 * |l1 - l2|`` where ``l0 <= l1 <= l2`` are the eigenvalues of the weighted covariance
        matrix of displacements to all particles.
        """
        dr = self.distance(body.center, particle_coord) + 1e-8  # adding 1e-8 prevents NaN gradients
        neighbor_weights = self.weight_fn(jnp.linalg.norm(dr, axis=1))
        total_weight = jnp.sum(neighbor_weights)
        weights = neighbor_weights[:, None]
        mean = 1 / total_weight * jnp.sum(dr * weights, axis=0, keepdims=True)
        centered_r = dr - mean
        # weighted covariance; identical to centered_r.T @ diag(w) @ centered_r but without building an
        # N x N diagonal matrix for every particle
        matrix = 1 / total_weight * (centered_r * weights).T @ centered_r
        values, _ = jnp.linalg.eigh(matrix)
        return values[0] * jnp.abs(values[1] - values[2])

    def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
        neighbor_flatness = jax.vmap(partial(self.local_flatness, body=body), in_axes=(0,))(body.center)
        # TODO: should this be weighted somehow? Based on particle distance to others?
        return jnp.mean(neighbor_flatness)
class WeightedDistanceCost:
    """Negative ``weight_fn``-weighted pair contact cost (aggregated by ``smap.oriented_pair``) per particle."""

    def __init__(self, displacement, contact_fn, weight_fn: Callable[[Array], Array], **cf_kwargs):
        """
        :param displacement: displacement function of the simulation space
        :param contact_fn: ellipsoid contact function wrapped by
            ``oriented_particle.get_ellipsoid_contact_function_param``
        :param weight_fn: applied to contact-function values before aggregation
        :param cf_kwargs: extra keyword arguments for ``get_ellipsoid_contact_function_param``
        """
        contact_function = oriented_particle.get_ellipsoid_contact_function_param(contact_fn, **cf_kwargs)
        # eigvals is passed by keyword, the same way the other cost functions call the contact function
        self.cf_cost_fn = smap.oriented_pair(
            lambda r, e1, e2, **params: weight_fn(contact_function(r, e1, e2, eigvals=params['eigvals'])),
            displacement)

    def __str__(self):
        return "Weighted distance cost"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
        num_particles = body.center.shape[0]
        return -self.cf_cost_fn(body, **interaction_params) / num_particles


def normal_weight(x: Array, sigma: float, displacement: float = 1.):
    """Gaussian probability density with mean ``displacement`` and standard deviation ``sigma`` at ``x``."""
    return 1 / (sigma * jnp.sqrt(2 * jnp.pi)) * jnp.exp(-0.5 * ((x - displacement) / sigma) ** 2)


def center_of_mass_pbc(coord, box_size, mask):
    """Calculate the center of mass of a cluster given by mask, taking into account periodic boundary conditions.

    Each coordinate component is mapped onto an angle, the mask-weighted mean of sin and cos is taken, and the
    mean angle is mapped back to a position. The ``arctan2(-s, -c) + pi`` form places the result in
    ``[0, box_size]``.
    """
    num_cluster_particles = jnp.sum(mask)
    angle = coord * 2 * jnp.pi / box_size
    avg_sin = jnp.sum(jnp.sin(angle) * mask[:, None], axis=0) / num_cluster_particles
    avg_cos = jnp.sum(jnp.cos(angle) * mask[:, None], axis=0) / num_cluster_particles
    avg_angle = jnp.arctan2(-avg_sin, -avg_cos) + jnp.pi
    return box_size * avg_angle / (2 * jnp.pi)


def displace_with_periodic_cm(coord, displacement, box_size, mask):
    """Displace all coordinates with a center of mass for a cluster given by mask."""
    cm = center_of_mass_pbc(coord, box_size, mask)
    mapped_displacement = jax.vmap(displacement, in_axes=(0, None))
    return mapped_displacement(coord, cm)


def contiguous_clusters(mask_pbc: Array, mask_box: Array, coord: Array, box_size: float):
    """
    Map clusters in a PBC box to a contiguous cluster with the center of mass in the coordinate origin.

    The algorithm is simple but can fail in certain cases. It is based on calculating distances between the center
    of mass for the whole cluster and all the centers of mass for the subclusters that we get by not taking into
    account periodic boundary conditions. The subclusters are displaced for the PBC period (box size) if this
    calculated distance is more than a half of box size in any component.

    :param mask_pbc: mask for a SINGLE cluster, taking into account PBC
    :param mask_box: N x N array containing all the masks for subclusters without PBC
    :param coord: coordinates of all particles in the box
    :param box_size: size of box side, assumes all sides are equal
    :return: displaced coordinates of all particles, with the center of mass of the given cluster in coordinate
        origin.
    """
    cm = center_of_mass_pbc(coord, box_size, mask_pbc)
    # plain (non-PBC) centers of mass of all subclusters; 1e-6 avoids division by zero for empty masks
    cm_subcl = (jnp.sum(mask_box[..., None] * coord[None, ...], axis=1)
                / (jnp.sum(mask_box, axis=1) + 1e-6)[:, None])
    dist_between_cm = cm_subcl - cm[None, :]
    # shift by one box length along every component where a subcluster is more than half a box away;
    # broadcasting keeps the coordinates' dtype and spatial dimension (no hard-coded float64 / 3D arrays)
    cluster_displacements = jnp.where(jnp.abs(dist_between_cm) > box_size / 2,
                                      jnp.sign(dist_between_cm) * box_size,
                                      0.)
    # jnp.isclose must be used as we add 1e-12 to mask_pbc in some places to enable differentiation through clustering
    relevant_subclusters = jnp.all(jnp.isclose(mask_pbc[None, :] * mask_box, mask_box), axis=1)
    cluster_displacements = cluster_displacements * relevant_subclusters[:, None]
    # sum will just collapse the first dimension with no overlap as each particle can be a part of only one cluster
    particle_displacements = jnp.sum(cluster_displacements[:, None, :] * mask_box[:, :, None], axis=0)
    displaced_coord = coord - particle_displacements - cm
    # re-center on the mask-weighted center of mass of the now contiguous cluster
    return displaced_coord - jnp.sum(displaced_coord * mask_pbc[:, None], axis=0) / jnp.sum(mask_pbc)


def box_displacement(Ra, Rb):
    """Calculate displacement vector in the simulation box without PBCs."""
    return space.pairwise_displacement(Ra, Rb)


ResidualsAvgType = Literal['linear', 'quadratic']
# (residuals, mask) -> averaged residual
ResidualsAvgFn = Callable[[Array, Array], Array]


def residuals_avg_fn_factory(which: ResidualsAvgType) -> ResidualsAvgFn:
    """Return a function averaging fit residuals over the (mask-weighted) number of particles.

    :param which: 'linear' for mean absolute residual, 'quadratic' for mean squared residual
    :raises ValueError: for any other value
    """
    if which == 'linear':
        return lambda residuals, mask: jnp.sum(jnp.abs(residuals)) / (jnp.sum(mask) + 1e-6)
    if which == 'quadratic':
        return lambda residuals, mask: jnp.sum(residuals ** 2) / (jnp.sum(mask) + 1e-6)
    raise ValueError('Unknown type of residuals cost function.')


@partial(jax.jit, static_argnums=(4,))
def single_cluster_curv_radius(mask_pbc: Array, mask_box: Array, body: rigid_body.RigidBody, box_size: float,
                               residuals_avg_fn: ResidualsAvgFn = residuals_avg_fn_factory('linear')):
    """
    Fit a sphere to the particles of a single cluster.

    The cluster is first made contiguous (see :func:`contiguous_clusters`). Two fits are started with centers on
    opposite sides of the cluster, along the eigenvector of the smallest eigenvalue of the cluster's second-moment
    matrix (the normal of the best-fitting plane).

    :param mask_pbc: mask of a single cluster found with PBC
    :param mask_box: N x N masks of all subclusters found without PBC
    :param body: whole-system rigid body
    :param box_size: box side length
    :param residuals_avg_fn: (residuals, mask) -> scalar; static under jit
    :return: (absolute fitted radius, averaged residuals of the selected fit)
    """
    displaced_coord = contiguous_clusters(mask_pbc, mask_box, body.center, box_size)  # frame of reference without PBC
    cluster_coord = displaced_coord * mask_pbc[:, None]
    matrix = 1 / jnp.sum(mask_pbc) * cluster_coord.T @ cluster_coord
    _, vecs = jnp.linalg.eigh(matrix)
    p0_1 = surface_fit.QuadraticSurfaceParams(center=2 * vecs[:, 0], radius=1.)
    p0_2 = surface_fit.QuadraticSurfaceParams(center=-2 * vecs[:, 0], radius=1.)
    opt1 = surface_fit.surface_fit_gn(surface_fit.spherical_surface, displaced_coord, mask_pbc, p0=p0_1)
    opt2 = surface_fit.surface_fit_gn(surface_fit.spherical_surface, displaced_coord, mask_pbc, p0=p0_2)
    mean_residuals1 = residuals_avg_fn(surface_fit.spherical_surface(opt1, displaced_coord, mask_pbc), mask_pbc)
    mean_residuals2 = residuals_avg_fn(surface_fit.spherical_surface(opt2, displaced_coord, mask_pbc), mask_pbc)
    # first we order solutions based on fit residuals
    (opt1, opt2), (mr1, mr2) = jax.lax.cond(mean_residuals1 < mean_residuals2,
                                            lambda: ((opt1, opt2), (mean_residuals1, mean_residuals2)),
                                            lambda: ((opt2, opt1), (mean_residuals2, mean_residuals1)))
    # the second criterion makes sure to select the fit that didn't fail to converge
    # (a |radius| >= 1e4 is treated as a failed fit; in most cases, at least one converges)
    return jax.lax.cond(jnp.abs(opt1.radius) < 1e4,
                        lambda: (jnp.abs(opt1.radius), mr1),
                        lambda: (jnp.abs(opt2.radius), mr2))


def _pbc_and_box_clusters(neighboring, neighboring_box, body, eigvals, neigh_contact_fn):
    """Cluster the system with PBC (``neighboring``) and without PBC (``neighboring_box``).

    :return: ``(clusters, clusters_box)``
    """
    neighbors = neighboring(body, eigvals=eigvals, neigh_contact_fn=neigh_contact_fn)
    neighbors_box = neighboring_box(body, eigvals=eigvals, neigh_contact_fn=neigh_contact_fn)
    return clustering.clustering(neighbors), clustering.clustering(neighbors_box)


def _relevant_cluster_weights(clusters, num_particles):
    """Cluster weights ``n_c / N``; clusters with 3 or fewer particles get weight 0."""
    cluster_weights = clusters.n_part_per_cluster / num_particles
    return jnp.where(clusters.n_part_per_cluster > 3, cluster_weights, 0.)


class CurvedClustersCost:
    """Weighted mean squared log-ratio between fitted sphere radii of clusters and a target radius."""

    def __init__(self, displacement, box_size, contact_fn, target_radius, radius_cutoff_mul=100,
                 neigh_contact_fn: float = 1.4):
        self.box_size = box_size
        self.target_radius = target_radius
        self.displacement = displacement
        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
        self.neighboring_box = clustering.get_ellipsoid_neighboring_fn(box_displacement, contact_fn)
        self.cutoff = radius_cutoff_mul
        self.neigh_contact_fn = neigh_contact_fn

    def __str__(self):
        return f"Curved clusters cost, target_radius={self.target_radius}"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
        clusters, clusters_box = _pbc_and_box_clusters(self.neighboring, self.neighboring_box, body,
                                                       interaction_params['eigvals'], self.neigh_contact_fn)
        # we add 1e-12 to cluster masks to avoid NaN gradients that arise when the entire mask is 0
        # this can change the fit values for small clusters, however, we later only take into account clusters with N>3
        all_radii, _ = jax.vmap(partial(single_cluster_curv_radius, mask_box=clusters_box.masks, body=body,
                                        box_size=self.box_size))(clusters.masks + 1e-12)
        # as curvature radius of a cluster can get arbitrarily large, we cut off calculated radii at some value
        max_radius = self.cutoff * self.target_radius
        all_radii = jnp.where(all_radii > max_radius, max_radius, all_radii)
        cluster_weights = _relevant_cluster_weights(clusters, body.center.shape[0])
        num_of_weighted_clusters = jnp.sum(cluster_weights > 0) + 1e-6  # small number avoids dividing by 0
        return jnp.sum((jnp.log(all_radii / self.target_radius) ** 2) * cluster_weights) / num_of_weighted_clusters


class CurvedClustersResidualsCost:
    """Weighted sum of curvature cost and fit-residual cost of clusters."""

    def __init__(self, displacement, box_size, contact_fn, target_radius, radius_cutoff_mul=100,
                 residuals_cost_factor=1, residuals_avg_type: ResidualsAvgType = 'linear',
                 neigh_contact_fn: float = 1.4):
        self.box_size = box_size
        self.target_radius = target_radius
        self.displacement = displacement
        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
        self.neighboring_box = clustering.get_ellipsoid_neighboring_fn(box_displacement, contact_fn)
        self.cutoff = radius_cutoff_mul
        self.res_cost_fac = residuals_cost_factor
        self.residuals_avg_type = residuals_avg_type
        self.residuals_avg_fn = residuals_avg_fn_factory(residuals_avg_type)
        self.neigh_contact_fn = neigh_contact_fn

    def __str__(self):
        return (f"Curved clusters residuals cost, target_radius={self.target_radius}, "
                f"{self.residuals_avg_type} residuals with factor {self.res_cost_fac}.")

    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
        clusters, clusters_box = _pbc_and_box_clusters(self.neighboring, self.neighboring_box, body,
                                                       interaction_params['eigvals'], self.neigh_contact_fn)
        # we add 1e-12 to cluster masks to avoid NaN gradients that arise when the entire mask is 0
        # this can change the fit values for small clusters, however, we later only take into account clusters with N>3
        all_radii, all_residuals = jax.vmap(partial(single_cluster_curv_radius, mask_box=clusters_box.masks,
                                                    body=body, box_size=self.box_size,
                                                    residuals_avg_fn=self.residuals_avg_fn))(clusters.masks + 1e-12)
        cluster_weights = _relevant_cluster_weights(clusters, body.center.shape[0])
        # as curvature radius of a cluster can get arbitrarily large, we cut off calculated radii at some value
        max_radius = self.cutoff * self.target_radius
        all_radii = jnp.where(all_radii > max_radius, max_radius, all_radii)
        # residuals above 1e2 are zeroed instead of entering the cost
        all_residuals = jnp.where(all_residuals > 1e2, 0., all_residuals)
        curvature_cost = jnp.log(all_radii / self.target_radius) ** 2
        residuals_cost = self.res_cost_fac * all_residuals
        # NOTE: unlike CurvedClustersCost, the weighted sum is not divided by the number of relevant clusters
        return jnp.sum((curvature_cost + residuals_cost) * cluster_weights)


@partial(jax.jit, static_argnums=(4,))
def single_cluster_quadratic_surface(mask_pbc: Array, mask_box: Array, body: rigid_body.RigidBody, box_size: float,
                                     surface_constant: int = -1,
                                     ) -> Array:
    """
    Fit a general quadratic surface to the particles of a single cluster.

    :param mask_pbc: mask of a single cluster found with PBC
    :param mask_box: N x N masks of all subclusters found without PBC
    :param body: whole-system rigid body
    :param box_size: box side length
    :param surface_constant: ``constant`` passed to ``surface_fit_general.quadratic_surface``; static under jit
    :return: eigenvalues (ascending) of the fitted quadratic form
    """
    displaced_coord = contiguous_clusters(mask_pbc, mask_box, body.center, box_size)
    p0 = surface_fit_general.GeneralQuadraticSurfaceParams()
    opt = surface_fit_general.surface_fit_gn(partial(surface_fit_general.quadratic_surface, constant=surface_constant),
                                             displaced_coord, mask_pbc, p0=p0)
    return jnp.linalg.eigvalsh(opt.quadratic_form)


class QuadraticSurfaceClustersCost:
    """Weighted mean squared difference between fitted quadratic-form eigenvalues and target eigenvalues."""

    def __init__(self, displacement, box_size, contact_fn, target_eigvals, surface_constat=-1,
                 neigh_contact_fn: float = 1.4):
        self.box_size = box_size
        self.target_eigvals = jnp.sort(target_eigvals)  # fitted eigenvalues are also in ascending order
        self.displacement = displacement
        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
        self.neighboring_box = clustering.get_ellipsoid_neighboring_fn(box_displacement, contact_fn)
        self.surface_constat = surface_constat
        self.neigh_contact_fn = neigh_contact_fn

    def __str__(self):
        # NOTE(review): same prefix as CurvedClustersCost; kept unchanged in case callers rely on it
        return f"Curved clusters cost, target_eigvals={self.target_eigvals}"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
        clusters, clusters_box = _pbc_and_box_clusters(self.neighboring, self.neighboring_box, body,
                                                       interaction_params['eigvals'], self.neigh_contact_fn)
        # we add 1e-12 to cluster masks to avoid NaN gradients that arise when the entire mask is 0
        # this can change the fit values for small clusters, however, we later only take into account clusters with N>3
        all_eigvals = jax.vmap(partial(single_cluster_quadratic_surface, mask_box=clusters_box.masks, body=body,
                                       box_size=self.box_size,
                                       surface_constant=self.surface_constat))(clusters.masks + 1e-12)
        cluster_weights = _relevant_cluster_weights(clusters, body.center.shape[0])
        num_of_weighted_clusters = jnp.sum(cluster_weights > 0) + 1e-6  # small number avoids dividing by 0
        return jnp.sum((all_eigvals - self.target_eigvals) ** 2 * cluster_weights[:, None]) / num_of_weighted_clusters


def single_cluster_cylinder_radius(mask_pbc: Array, mask_box: Array, body: rigid_body.RigidBody, box_size: float,
                                   residuals_avg_fn: ResidualsAvgFn = residuals_avg_fn_factory('linear')):
    """
    Fit a cylinder to the particles of a single cluster and return the absolute radius of the better fit.

    Two fits are started with centers on opposite sides of the cluster along the normal of its best-fitting plane,
    with the initial axis orientation derived from the eigenvector of the largest second-moment eigenvalue.

    :param mask_pbc: mask of a single cluster found with PBC
    :param mask_box: N x N masks of all subclusters found without PBC
    :param body: whole-system rigid body
    :param box_size: box side length
    :param residuals_avg_fn: (residuals, mask) -> scalar used to rank the two fits
    """
    displaced_coord = contiguous_clusters(mask_pbc, mask_box, body.center, box_size)  # frame of reference without PBC
    cluster_coord = displaced_coord * mask_pbc[:, None]
    matrix = 1 / jnp.sum(mask_pbc) * cluster_coord.T @ cluster_coord
    _, vecs = jnp.linalg.eigh(matrix)
    euler_best_direction = jnp.array([jnp.arctan2(vecs[0, 2], vecs[1, 2]),
                                      jnp.arctan2(vecs[2, 2], jnp.sqrt(vecs[0, 2] ** 2 + vecs[1, 2] ** 2)),
                                      0])
    p0_1 = surface_fit.QuadraticSurfaceParams(center=3 * vecs[:, 0], euler=euler_best_direction, radius=5.)
    p0_2 = surface_fit.QuadraticSurfaceParams(center=-3 * vecs[:, 0], euler=euler_best_direction, radius=5.)
    opt1 = surface_fit.surface_fit_gn(surface_fit.cylindrical_surface, displaced_coord, mask_pbc, p0=p0_1)
    opt2 = surface_fit.surface_fit_gn(surface_fit.cylindrical_surface, displaced_coord, mask_pbc, p0=p0_2)
    # residuals must be evaluated with the same (cylindrical) surface that was fitted
    mean_residuals1 = residuals_avg_fn(surface_fit.cylindrical_surface(opt1, displaced_coord, mask_pbc), mask_pbc)
    mean_residuals2 = residuals_avg_fn(surface_fit.cylindrical_surface(opt2, displaced_coord, mask_pbc), mask_pbc)
    return jax.lax.cond(mean_residuals1 < mean_residuals2,
                        lambda: jnp.abs(opt1.radius),
                        lambda: jnp.abs(opt2.radius))


class CylindricalClustersCost:
    """Weighted mean squared log-ratio between fitted cylinder radii of clusters and a target radius."""

    def __init__(self, displacement, box_size, contact_fn, target_radius, radius_cutoff_mul=1000,
                 neigh_contact_fn: float = 1.4):
        self.box_size = box_size
        self.target_radius = target_radius
        self.displacement = displacement
        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
        # clustering without PBC is required by contiguous_clusters (mask_box argument)
        self.neighboring_box = clustering.get_ellipsoid_neighboring_fn(box_displacement, contact_fn)
        self.cutoff = radius_cutoff_mul
        self.neigh_contact_fn = neigh_contact_fn

    def __str__(self):
        return f"Cylindrical clusters cost, target_radius={self.target_radius}"

    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
        clusters, clusters_box = _pbc_and_box_clusters(self.neighboring, self.neighboring_box, body,
                                                       interaction_params['eigvals'], self.neigh_contact_fn)
        # we add 1e-12 to cluster masks to avoid NaN gradients that arise when the entire mask is 0
        # this can change the fit values for small clusters, however, we later only take into account clusters with N > 3
        all_radii = jax.vmap(partial(single_cluster_cylinder_radius, mask_box=clusters_box.masks, body=body,
                                     box_size=self.box_size))(clusters.masks + 1e-12)
        # as curvature radius of a cluster can get arbitrarily large, we cut off calculated radii at some value
        max_radius = self.cutoff * self.target_radius
        all_radii = jnp.where(all_radii > max_radius, max_radius, all_radii)
        cluster_weights = _relevant_cluster_weights(clusters, body.center.shape[0])
        num_of_weighted_clusters = jnp.sum(cluster_weights > 0) + 1e-6  # small number avoids dividing by 0
        return jnp.sum((jnp.log(all_radii / self.target_radius) ** 2) * cluster_weights) / num_of_weighted_clusters


class CostCombinator:
    """Linear combination ``sum_i coefficients[i] * cost_fns[i](body, **interaction_params)``."""

    def __init__(self, cost_fns: list[CostFn], coefficients: list[float]):
        if len(cost_fns) != len(coefficients):
            raise ValueError(f'Lengths of cost_fn list and coefficients list should be equal, '
                             f'got {len(cost_fns)} and {len(coefficients)}, respectively')
        self.cost_fns = cost_fns
        self.coefficients = coefficients

    def __str__(self):
        return "".join(f"{coef} x {cf} \n" for cf, coef in zip(self.cost_fns, self.coefficients))

    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
        return sum(coef * cf(body, **interaction_params) for cf, coef in zip(self.cost_fns, self.coefficients))