Andraz Gnidovec 1 miesiąc temu
commit
72558893ad

+ 7 - 0
.gitignore

@@ -0,0 +1,7 @@
+# Ignore compiled Python files
+curvature_assembly/__pycache__/
+curvature_assembly.egg-info/
+jax-md/
+*.pyc
+*.pyo
+config.json

+ 0 - 0
README.md


+ 0 - 0
curvature_assembly/__init__.py


+ 196 - 0
curvature_assembly/clustering.py

@@ -0,0 +1,196 @@
+from typing import Callable
+import jax
+import jax.numpy as jnp
+from jax_md import dataclasses, space, rigid_body
+from functools import partial
+from curvature_assembly import oriented_particle
+
+# Shorthand for JAX arrays, used in the type annotations of this module.
+Array = jnp.ndarray
+
+
+@dataclasses.dataclass
+class Clusters:
+    """Class to store cluster data.
+
+    The field comments below describe the values as produced by ``clustering`` for a system of N particles.
+    """
+    # number of clusters found
+    n_clusters: int
+    # NxN array; row i holds the particle indices of one cluster, padded with the fill value N
+    clusters: Array
+    # length-N array with the number of particles in each row of ``clusters``
+    n_part_per_cluster: Array
+    # NxN integer array; masks[i, j] is 1 if particle j belongs to the cluster in row i, else 0
+    masks: Array
+
+
+@dataclasses.dataclass
+class Neighbors:
+    """Class to store connectivity data.
+
+    The field comments below describe the values as produced by the neighboring functions in this module
+    for a system of N particles.
+    """
+    # (N, max_contacts) int32 array; row i lists the indices of particles in contact with particle i,
+    # padded with the fill value N
+    node_next: Array
+    # length-N int32 array with the number of contacts of each particle
+    n_contacts_per_particle: Array
+    # number of contacts in the system, each pair counted once; computed as 0.5 * (sum of the contact matrix),
+    # so the stored value is not of integer dtype despite the annotation
+    n_links: int
+
+
+def nonzero_1d(array: Array, size: int = None, fill_value: float = None) -> Array:
+    """Vectorization of the nonzero function over the second array axis."""
+    return jnp.vectorize(partial(jnp.nonzero, size=size, fill_value=fill_value), signature='(k)->(d)')(array)[0]
+
+
+def get_neighboring_fn(displacement: space.DisplacementFn, max_contacts: int = 16) -> Callable:
+    """Return distance based neighboring function used to get particle contact information."""
+
+    @jax.jit
+    def neighboring(coord: Array, neigh_dist: float) -> Neighbors:
+        metric = space.canonicalize_displacement_or_metric(displacement)
+        dr = space.map_product(metric)(coord, coord)
+        mask = jnp.float32(1.0) - jnp.eye(coord.shape[0])
+        contacts = mask * jnp.array(dr < neigh_dist, dtype=jnp.int32)
+        n_links = 0.5 * jnp.sum(contacts, dtype=jnp.int32)
+        n_contacts_per_particle = jnp.sum(contacts, axis=0, dtype=jnp.int32)
+        node_next = jnp.asarray(nonzero_1d(contacts, size=max_contacts, fill_value=len(coord)), dtype=jnp.int32)
+        return Neighbors(node_next, n_contacts_per_particle, n_links)
+
+    return neighboring
+
+
+def get_ellipsoid_neighboring_fn(displacement: space.DisplacementFn,
+                                 contact_fn: Callable[..., Array],
+                                 max_contacts: int = 16,
+                                 **cf_kwargs) -> Callable:
+    """Return contact function based neighboring function used to get particle contact information."""
+
+    contact_function = oriented_particle.get_ellipsoid_contact_function_param(contact_fn, **cf_kwargs)
+
+    @jax.jit
+    def neighboring(body: rigid_body.RigidBody, eigvals: Array, neigh_contact_fn: float) -> Neighbors:
+        num_particles = body.center.shape[0]
+        dr = space.map_product(displacement)(body.center, body.center)
+        eigsys = oriented_particle.eigensystem(body.orientation)
+        mapped_cf = jax.vmap(jax.vmap(partial(contact_function, eigvals=eigvals), (0, 0, None), 0), (0, None, 0), 0)
+        cf = mapped_cf(dr, eigsys, eigsys)
+        mask = jnp.float32(1.0) - jnp.eye(num_particles)
+        contacts = mask * jnp.array(cf < neigh_contact_fn, dtype=jnp.int32)
+        n_links = 0.5 * jnp.sum(contacts, dtype=jnp.int32)
+        n_contacts_per_particle = jnp.sum(contacts, axis=0, dtype=jnp.int32)
+        node_next = jnp.asarray(nonzero_1d(contacts, size=max_contacts, fill_value=num_particles), dtype=jnp.int32)
+        return Neighbors(node_next, n_contacts_per_particle, n_links)
+
+    return neighboring
+
+
+@partial(jax.jit, static_argnums=1)
+def clustering(neighbors: Neighbors, num_iter: int = 20) -> Clusters:
+    """
+    Clustering algorithm from de Oliveira et al.,
+    https://www.tandfonline.com/doi/full/10.1080/08927022.2020.1839661
+
+    Args:
+        neighbors: instance of Neighbors object where data about each particle neighbors is stored
+        num_iter: number of iterations to fix cluster labels of neighboring particles (static argument under jit)
+
+    Returns:
+        Clustering data stored in Clusters object with the following attributes (N is the number of particles):
+        n_clusters - number of clusters,
+        clusters - NxN array where each row i stores indices of particles in one cluster. Clusters[i, j] with
+        i >= n_clusters and/or j >= n_part_per_cluster[i] are set to the total number of particles in the system.
+        n_part_per_cluster - array of length N storing the sizes of all clusters. Elements with index
+        i >= n_clusters are 0.
+        masks - NxN int32 array where masks[i, j] is 1 if particle j is part of the cluster stored in row i
+        and 0 otherwise. Rows with i >= n_clusters contain only zeros.
+    """
+
+    num_particles = neighbors.node_next.shape[0]
+    particle_indices = jnp.arange(num_particles)
+    # label 0 is treated as "not labeled yet" by assign_labels below
+    nodeLpure = jnp.zeros(num_particles, dtype=jnp.int32)
+    nodeL = jnp.concatenate((nodeLpure, jnp.array([num_particles])))  # the last index is used as a dump for unused data
+
+    def new_label(nodeL, n_clusters, idx, *args):
+        # particle idx starts a new cluster and gets the next unused label
+        # NOTE(review): the counter starts at 0, so the first label issued equals the "unlabeled" value 0 - confirm
+        nodeL = nodeL.at[idx].set(n_clusters)
+        return nodeL, n_clusters + 1
+
+    def neighbor_labels(nodeL, n_clusters, idx, neigh_indices, label):
+        # particle idx and all of its neighbors take the given label; padded neighbor slots point to the
+        # dump element at index num_particles, which is therefore overwritten here as well
+        nodeL = nodeL.at[idx].set(label)
+        nodeL = nodeL.at[neigh_indices].set(label)
+        return nodeL, n_clusters
+
+    def assign_labels(nodeL_ncl, idx, no_labeled_neighbors_fn):
+        # scan body: carry is (nodeL, n_clusters), idx is the particle processed in this step
+        nodeL, n_clusters = nodeL_ncl
+        labels = nodeL[neighbors.node_next[idx]]
+        # largest neighbor label, ignoring padded slots (label num_particles); 0 if no neighbor is labeled
+        max_label = jnp.max(labels, where=(labels < num_particles), initial=0)
+        # smallest nonzero neighbor label
+        min_label = jnp.min(labels, where=(labels > 0), initial=num_particles)
+        nodeL, n_clusters = jax.lax.cond(max_label == 0, no_labeled_neighbors_fn, neighbor_labels,
+                                         nodeL, n_clusters, idx, neighbors.node_next[idx], min_label)
+        # restore the dump element so that padded neighbor slots keep reading num_particles
+        nodeL = nodeL.at[num_particles].set(num_particles)
+        return (nodeL, n_clusters), 0.
+
+    # assign labels to all clustered particles
+    n_clusters = 0
+    nodeL_ncl, _ = jax.lax.scan(partial(assign_labels, no_labeled_neighbors_fn=new_label),
+                                (nodeL, n_clusters), xs=particle_indices)
+    nodeL, _ = nodeL_ncl
+
+    def keep_label(nodeL, n_clusters, *args):
+        # used in the relabeling passes: particles without labeled neighbors keep their current label
+        return nodeL, n_clusters
+
+    def check_labels(condition, idx, nodeL):
+        # add 1 to condition value if all neighbors of particle idx share its label
+        labels = nodeL[neighbors.node_next[idx]]
+        # sorted unique labels, padded with num_particles; padded neighbor slots also contribute num_particles
+        unique_labels = jnp.unique(labels, size=3, fill_value=num_particles)
+        # check if all neighbor labels are the same:
+        all_same = (unique_labels[1] == num_particles) * (unique_labels[0] == nodeL[idx])
+        condition = jax.lax.cond(all_same, lambda x: condition + 1, lambda x: condition, 0.)
+        return condition, 0.
+
+    def fix_labels(nodeL):
+        # one more pass over all particles that only propagates existing labels (no new labels are created)
+        nodeL_ncl, _ = jax.lax.scan(partial(assign_labels, no_labeled_neighbors_fn=keep_label),
+                                    (nodeL, 0), xs=particle_indices)
+        nodeL, _ = nodeL_ncl
+        return nodeL
+
+    def relabel_iteration(nodeL, iteration):
+        # condition is the number of particles with all neighbors sharing its label
+        condition, _ = jax.lax.scan(partial(check_labels, nodeL=nodeL), 0., xs=particle_indices)
+        # if condition == num_particles, algorithm converged, and we return the same values for nodeL
+        nodeL = jax.lax.cond(condition == num_particles, lambda x: x, fix_labels, nodeL)
+        return nodeL, 0.
+
+    # fixing the cluster labels of particles (num_iter iterations)
+    nodeL, _ = jax.lax.scan(relabel_iteration, nodeL, xs=jnp.arange(num_iter))
+
+    # determine cluster id values for all clusters
+    nodeLpure = nodeL[:-1]  # drop the dump element
+    # sorted unique labels, padded to length num_particles with the fill value num_particles
+    id = jnp.unique(nodeLpure, size=num_particles, fill_value=num_particles)
+
+    @partial(jnp.vectorize, signature='()->(d)')
+    def set_cluster_vmap(cluster_id):
+        # indices of particles carrying the given label, padded with num_particles
+        return jnp.where(nodeLpure == cluster_id, size=num_particles, fill_value=num_particles)[0]
+
+    @partial(jnp.vectorize, signature='()->(d)')
+    def set_mask_vmap(cluster_id):
+        # 0/1 membership mask for the given label
+        return jnp.where(nodeLpure == cluster_id,
+                         jnp.ones(num_particles, dtype=jnp.int32),
+                         jnp.zeros(num_particles, dtype=jnp.int32))
+
+    clusters = set_cluster_vmap(id)
+    masks = set_mask_vmap(id)
+    # entries equal to num_particles are padding and are not counted
+    n_part_per_cluster = jnp.sum(clusters < num_particles, axis=1)
+    # a row represents a cluster if it contains at least one valid particle index
+    n_clusters = jnp.sum(jnp.min(clusters, axis=1) < num_particles)
+
+    return Clusters(n_clusters, clusters, n_part_per_cluster, masks)
+
+
+def get_cluster_particles(clusters: Clusters, idx: int) -> Array:
+    """Clip fill values from cluster data and return only indices of particles in cluster idx."""
+    return clusters.clusters[idx, :clusters.n_part_per_cluster[idx]]
+
+
+@jax.jit
+def get_cluster_mask(cluster: Array) -> Array:
+    """Create a mask for particles belonging to cluster given as array of indices."""
+    mask_extended = jnp.zeros(cluster.shape[0] + 1, dtype=jnp.int32)
+    mask_extended = mask_extended.at[cluster].set(1.)
+    return mask_extended[:-1]
+
+
+def get_all_cluster_masks(clusters: Clusters) -> Array:
+    """
+    Create a NxN array with each row representing a mask for one particle cluster. For rows with
+    i > clusters.n_clusters, the elements of all masks are False.
+    """
+    f = jax.vmap(get_cluster_mask)
+    return f(clusters.clusters)
+
+
+

+ 536 - 0
curvature_assembly/cost_functions.py

@@ -0,0 +1,536 @@
+from typing import Callable, Literal, Protocol, Any
+import jax
+import jax.numpy as jnp
+from curvature_assembly import oriented_particle, clustering, smap, surface_fit, surface_fit_general
+from jax_md import rigid_body, space
+from functools import partial
+
+
+# Shorthand for JAX arrays, used in the type annotations of this module.
+Array = jnp.ndarray
+
+class CostFn(Protocol):
+    """Structural type of a cost function: a callable on a rigid body state with a printable description."""
+
+    def __call__(self, body: rigid_body.RigidBody, **interaction_params: dict) -> Array:
+        """Evaluate the cost for the particle configuration ``body`` and the given interaction parameters."""
+        ...
+
+    def __str__(self):
+        """Return a description of the cost function."""
+        ...
+
+
+@jax.jit
+def jit_flat_plane_cost_function(body: rigid_body.RigidBody, **interaction_params) -> float:
+    n_part = body.center.shape[0]
+    cm = jnp.mean(body.center, axis=0)
+    matrix = 1 / n_part * jnp.sum(jnp.einsum('ni, nj -> nij', body.center - cm, body.center - cm), axis=0)
+    values, vecs = jnp.linalg.eigh(matrix)
+    return values[0]
+
+
+def single_cluster_cost(mask, body):
+    n_part = jnp.sum(mask)
+    cluster_particles = body.center * mask[:, None]
+    centered = (cluster_particles - jnp.sum(cluster_particles, axis=0) / (n_part + 1e-8)) * mask[:, None]
+    matrix = 1 / (n_part + 1e-8) * centered.T @ centered
+    values, vecs = jnp.linalg.eigh(matrix)
+    return values[0]
+
+
+class FlatPlaneClustersCost:
+
+    def __init__(self, displacement, contact_fn, num_clusters_penalty: float):
+        self.num_clusters_penalty = num_clusters_penalty
+        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
+
+    def __str__(self):
+        return f"Flat plane clusters cost, penalty={self.num_clusters_penalty}"
+
+    def __call__(self, body: rigid_body.RigidBody, **interaction_params: dict):
+        neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
+        clusters = clustering.clustering(neighbors)
+        cluster_masks = clustering.get_all_cluster_masks(clusters) # WARNING: using cluster mask may lead to AD problems
+        cluster_costs = jax.vmap(partial(single_cluster_cost, body=body))(cluster_masks)
+        return jnp.sum(cluster_costs) + self.num_clusters_penalty * clusters.n_clusters
+
+
+class DistanceCost:
+
+    def __init__(self, displacement, contact_fn):
+        self.displacement = displacement
+        self.contact_function = oriented_particle.get_ellipsoid_contact_function_param(contact_fn, num_steps=25)
+
+    def __str__(self):
+        return "Distance cost"
+
+    def __call__(self, body: rigid_body.RigidBody, **interaction_params: dict) -> Array:
+        num_particles = body.center.shape[0]
+        dr = space.map_product(self.displacement)(body.center, body.center)
+        eigsys = oriented_particle.eigensystem(body.orientation)
+        cf = jax.vmap(jax.vmap(partial(self.contact_function, eigvals=interaction_params['eigvals']),
+                               (0, 0, None), 0), (0, None, 0), 0)(dr, eigsys, eigsys)
+        mask = jnp.float32(1.0) - jnp.eye(num_particles)
+        return 0.5 * jnp.sum(mask * cf)
+
+
+class NumClustersCost:
+
+    def __init__(self, displacement, contact_fn, num_clusters_penalty: float):
+        self.num_clusters_penalty = num_clusters_penalty
+        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
+
+    def __str__(self):
+        return f"Num clusters cost, penalty={self.num_clusters_penalty}"
+
+    def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
+        neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
+        clusters = clustering.clustering(neighbors)
+        return self.num_clusters_penalty * clusters.n_clusters
+
+
+class FlatPlaneCost:
+
+    def __init__(self, displacement, contact_fn):
+        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
+
+    def __str__(self):
+        return "Flat plane cost"
+
+    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
+        neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
+        clusters = clustering.clustering(neighbors)
+        cluster_masks = clustering.get_all_cluster_masks(clusters) # WARNING: using cluster mask may lead to AD problems
+        cluster_costs = jax.vmap(partial(single_cluster_cost, body=body))(cluster_masks)
+        return jnp.sum(cluster_costs)
+
+
+class SquaredClusterSizeCost:
+
+    def __init__(self, displacement, contact_fn):
+        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
+
+    def __str__(self):
+        return "Squared cluster size cost"
+
+    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
+        neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
+        clusters = clustering.clustering(neighbors)
+        num_particles = jnp.sum(clusters.n_part_per_cluster)
+        return num_particles ** 2 / jnp.sum(clusters.n_part_per_cluster ** 2) - 1
+
+
+class FlatNeighbors:
+
+    def __init__(self, displacement, contact_fn, num_neighbors: int = 6, **cf_kwargs):
+        self.num_neighbors = num_neighbors
+        self.displacement = displacement
+        self.contact_function = oriented_particle.get_ellipsoid_contact_function_param(contact_fn, **cf_kwargs)
+
+    def __str__(self):
+        return f"Flat neighbors cost, num_neighbors={self.num_neighbors}"
+
+    def local_flatness(self, neighbor_indices: Array, body: rigid_body.RigidBody):
+        neighbor_coord = body.center[neighbor_indices]
+        mean = jnp.mean(neighbor_coord, axis=0, keepdims=True)
+        centered_r = neighbor_coord - mean
+        matrix = 1 / self.num_neighbors * centered_r.T @ centered_r
+        values, vecs = jnp.linalg.eigh(matrix)
+        return values[0]
+
+    def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
+        dr = space.map_product(self.displacement)(body.center, body.center)
+        eigsys = oriented_particle.eigensystem(body.orientation)
+        mapped_cf = jax.vmap(jax.vmap(partial(self.contact_function, eigvals=interaction_params['eigvals']),
+                                      (0, 0, None), 0), (0, None, 0), 0)
+        cf = mapped_cf(dr, eigsys, eigsys)
+        indices = jnp.argsort(cf, axis=1)[:, 1:self.num_neighbors+1]  # if we start with idx 0, also the central particle will count
+        neighbor_flatness = jax.vmap(partial(self.local_flatness, body=body))(indices)
+        return jnp.mean(neighbor_flatness)  # should this be weighted somehow? Based on particle distance to others?
+
+
+class WeightedFlatNeighbors:
+
+    def __init__(self, displacement, weight_fn: Callable[[Array], Array]):
+        self.weight_fn = weight_fn
+        self.distance = jax.vmap(displacement, in_axes=(0, None))
+
+    def __str__(self):
+        return "Weighted flat neighbors cost"
+
+    def local_flatness(self, particle_coord: Array, body: rigid_body.RigidBody):
+        dr = self.distance(body.center, particle_coord) + 1e-8  # adding 1e-8 prevents Nan gradients
+        neighbor_weights = self.weight_fn(jnp.linalg.norm(dr, axis=1))
+        mean = 1 / jnp.sum(neighbor_weights) * jnp.sum(dr * neighbor_weights[:, None], axis=0, keepdims=True)
+        centered_r = dr - mean
+        matrix = 1 / jnp.sum(neighbor_weights) * centered_r.T @ jnp.diag(neighbor_weights) @ centered_r
+        values, vecs = jnp.linalg.eigh(matrix)
+        return values[0] * jnp.abs(values[1] - values[2])
+
+    def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
+        neighbor_flatness = jax.vmap(partial(self.local_flatness, body=body), in_axes=(0,))(body.center)
+        return jnp.mean(neighbor_flatness)  # should this be weighted somehow? Based on particle distance to others?
+
+
+class WeightedDistanceCost:
+
+    def _init__(self, displacement,
+                         contact_fn,
+                         weight_fn: Callable[[Array], Array],
+                         **cf_kwargs):
+
+        contact_function = oriented_particle.get_ellipsoid_contact_function_param(contact_fn, **cf_kwargs)
+        self.cf_cost_fn = smap.oriented_pair(
+            lambda r, e1, e2, **params: weight_fn(contact_function(r, e1, e2, params['eigvals'])),
+            displacement)
+
+    def __str__(self):
+        return "Weighted distance cost"
+
+    def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
+        num_particles = body.center.shape[0]
+        return -self.cf_cost_fn(body, **interaction_params) / num_particles
+
+
+def normal_weight(x: Array, sigma: float, displacement: float = 1.):
+    return 1 / (sigma * jnp.sqrt(2 * jnp.pi)) * jnp.exp(-0.5 * ((x - displacement) / sigma) ** 2)
+
+
+def center_of_mass_pbc(coord, box_size, mask):
+    """Calculate the center of mass of a cluster given by mask, taking into account periodic boundary conditions."""
+    num_cluster_particles = jnp.sum(mask)
+    angle = coord * 2 * jnp.pi / box_size
+    avg_sin = jnp.sum(jnp.sin(angle) * mask[:, None], axis=0) / num_cluster_particles
+    avg_cos = jnp.sum(jnp.cos(angle) * mask[:, None], axis=0) / num_cluster_particles
+    avg_angle = jnp.arctan2(-avg_sin, -avg_cos) + jnp.pi
+    return box_size * avg_angle / (2 * jnp.pi)
+
+
+def displace_with_periodic_cm(coord, displacement, box_size, mask):
+    """Displace all coordinates with a center of mass for a cluster given by mask."""
+    cm = center_of_mass_pbc(coord, box_size, mask)
+    mapped_displacement = jax.vmap(displacement, in_axes=(0, None))
+    displaced_coord = mapped_displacement(coord, cm)
+    return displaced_coord
+
+
+# @partial(jnp.vectorize, signature='(m),(m,m),(m,n),()->(m,n)')
+def contiguous_clusters(mask_pbc: Array, mask_box: Array, coord: Array, box_size: float):
+    """
+    Map clusters in a PBC box to a contiguous cluster with the center of mass in the coordinate origin.
+    The algorithm is simple but can fail in certain cases. It is based on calculating distances between the center of
+    mass for the whole cluster and all the centers of mass for the subclusters that we get by not taking into account
+    periodic boundary conditions. The subclusters are displaced for the PBC period (box size) if this calculated
+    distance is more than a half of box size in any component.
+    :param mask_pbc: mask for a SINGLE cluster, taking into account PBC
+    :param mask_box: N x N array containing all the masks for subclusters without PBC
+    :param coord: coordinates of all particles in the box
+    :param box_size: size of box side, assumes all sides are equal
+    :return: displaced coordinates of all particles, with the center of mass of the given cluster in coordinate origin.
+    """
+    # periodic center of mass of the whole cluster
+    cm = center_of_mass_pbc(coord, box_size, mask_pbc)
+    # plain (non-periodic) center of mass of every subcluster; 1e-6 avoids division by zero for empty masks
+    cm_subcl = jnp.sum(mask_box[..., None] * coord[None, ...], axis=1) / (jnp.sum(mask_box, axis=1) + 1e-6)[:,None]
+
+    dist_between_cm = cm_subcl - cm[None, :]
+    # per subcluster and component: +-box_size if the subcluster lies more than half a box away from cm, else 0
+    # NOTE(review): the spatial dimension 3 and dtype float64 are hard-coded in the next statement
+    cluster_displacements = jnp.where(jnp.abs(dist_between_cm) > box_size / 2,
+                                      jnp.sign(dist_between_cm) * jnp.full((coord.shape[0], 3), box_size,
+                                                                       dtype=jnp.float64),
+                                      jnp.zeros((coord.shape[0], 3)))
+
+    # jnp.isclose must be used as we add 1e-12 to mask_pbc in some places to enable differentiation through clustering
+    # a subcluster is relevant if all of its particles are part of the PBC cluster
+    relevant_subclusters = jnp.all(jnp.isclose(mask_pbc[None, :] * mask_box, mask_box), axis=1)
+    cluster_displacements = cluster_displacements * relevant_subclusters[:, None]
+
+    # sum will just collapse the first dimension with no overlap as each particle can be a part of only one cluster
+    particle_displacements = jnp.sum(cluster_displacements[:, None, :] * mask_box[:, :, None], axis=0)
+
+    displaced_coord = coord - particle_displacements - cm
+    # subtract the mask-weighted mean so that the center of mass of the cluster ends up in the origin
+    return displaced_coord - jnp.sum(displaced_coord * mask_pbc[:, None], axis=0) / jnp.sum(mask_pbc)
+
+
+def box_displacement(Ra, Rb):
+    """Calculate displacement vector in the simulation box without PBCs."""
+    return space.pairwise_displacement(Ra, Rb)
+
+
+ResidualsAvgType = Literal['linear', 'quadratic']
+ResidualsAvgFn = Callable[[Array, Array], Array]
+
+def residuals_avg_fn_factory(which: ResidualsAvgType) -> ResidualsAvgFn:
+    if which == 'linear':
+        return lambda residuals, mask: jnp.sum(jnp.abs(residuals)) / (jnp.sum(mask) + 1e-6)
+    if which == 'quadratic':
+        return lambda residuals, mask: jnp.sum(residuals ** 2) / (jnp.sum(mask) + 1e-6)
+    raise ValueError('Unknown type of residuals cost function.')
+
+
+@partial(jax.jit, static_argnums=(4,))
+def single_cluster_curv_radius(mask_pbc: Array,
+                               mask_box: Array,
+                               body: rigid_body.RigidBody,
+                               box_size: float,
+                               residuals_avg_fn: ResidualsAvgFn = residuals_avg_fn_factory('linear')):
+    """
+    Fit a spherical surface to the particles of a single cluster and return its radius.
+
+    The cluster is first made contiguous and centered with ``contiguous_clusters``. The fit is then run twice,
+    starting from centers on opposite sides of the cluster, and one of the two results is returned.
+
+    Args:
+        mask_pbc: mask of the cluster of interest, taking into account PBC
+        mask_box: masks of all subclusters determined without PBC (see ``contiguous_clusters``)
+        body: rigid body state of the whole system; only ``body.center`` is used
+        box_size: size of the simulation box side
+        residuals_avg_fn: function (residuals, mask) -> average residual; static argument under jit
+
+    Returns:
+        - absolute value of the fitted cluster radius
+        - mean residuals for the fit
+    """
+    displaced_coord = contiguous_clusters(mask_pbc, mask_box, body.center, box_size)  # frame of reference without PBC
+    cluster_coord = displaced_coord * mask_pbc[:, None]
+    # second-moment matrix of the (already centered) cluster coordinates
+    matrix = 1 / jnp.sum(mask_pbc) * cluster_coord.T @ cluster_coord
+    values, vecs = jnp.linalg.eigh(matrix)
+    # two initial guesses: sphere centers displaced in opposite directions along the first eigenvector
+    p0_1 = surface_fit.QuadraticSurfaceParams(center=2 * vecs[:, 0],
+                                              radius=1.)
+    p0_2 = surface_fit.QuadraticSurfaceParams(center=-2 * vecs[:, 0],
+                                              radius=1.)
+    opt1 = surface_fit.surface_fit_gn(surface_fit.spherical_surface, displaced_coord, mask_pbc, p0=p0_1)
+    opt2 = surface_fit.surface_fit_gn(surface_fit.spherical_surface, displaced_coord, mask_pbc, p0=p0_2)
+
+    mean_residuals1 = residuals_avg_fn(surface_fit.spherical_surface(opt1, displaced_coord, mask_pbc), mask_pbc)
+    mean_residuals2 = residuals_avg_fn(surface_fit.spherical_surface(opt2, displaced_coord, mask_pbc), mask_pbc)
+
+    # first we order solutions based on fit residuals
+    (opt1, opt2), (mr1, mr2) = jax.lax.cond(mean_residuals1 < mean_residuals2,
+                                            lambda: ((opt1, opt2), (mean_residuals1, mean_residuals2)),
+                                            lambda: ((opt2, opt1), (mean_residuals2, mean_residuals1)))
+
+    # the second criterion makes sure to select the fit that didn't fail to converge
+    # (in most cases, at least one converges)
+    # a fit with |radius| >= 1e4 is treated as failed and the other solution is returned instead
+    return jax.lax.cond(jnp.abs(opt1.radius) < 1e4,
+                        lambda: (jnp.abs(opt1.radius), mr1),
+                        lambda: (jnp.abs(opt2.radius), mr2))
+
+
+class CurvedClustersCost:
+
+    def __init__(self,
+                 displacement,
+                 box_size,
+                 contact_fn,
+                 target_radius,
+                 radius_cutoff_mul=100):
+        self.box_size = box_size
+        self.target_radius = target_radius
+        self.displacement = displacement
+        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
+        self.neighboring_box = clustering.get_ellipsoid_neighboring_fn(box_displacement, contact_fn)
+        self.cutoff = radius_cutoff_mul
+
+    def __str__(self):
+        return f"Curved clusters cost, target_radius={self.target_radius}"
+
+    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
+        neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
+        neighbors_box = self.neighboring_box(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
+        clusters = clustering.clustering(neighbors)
+        clusters_box = clustering.clustering(neighbors_box)
+        # we add 1e-12 to cluster masks to avoid NaN gradients that arise when the entire mask is 0
+        # this can change the fit values for small clusters, however, we later only take into account clusters with N>3
+        all_radii, _ = jax.vmap(partial(single_cluster_curv_radius,
+                                        mask_box=clusters_box.masks,
+                                        body=body,
+                                        box_size=self.box_size))(clusters.masks + 1e-12)
+        # as curvature radius of a cluster can get arbitrarily large, we cut off calculated radii at some value
+        all_radii = jnp.where(all_radii > self.cutoff * self.target_radius, self.cutoff * self.target_radius, all_radii)
+        cluster_weights = clusters.n_part_per_cluster / body.center.shape[0]
+        relevant_clusters = clusters.n_part_per_cluster > 3
+        cluster_weights = jnp.where(relevant_clusters, cluster_weights, 0.)  # clusters with N < 3 are excluded
+        num_of_weighted_clusters = jnp.sum(cluster_weights > 0) + 1e-6  # we add a small number to avoid dividing by 0
+        return jnp.sum((jnp.log(all_radii / self.target_radius) ** 2) * cluster_weights) / num_of_weighted_clusters
+
+
+class CurvedClustersResidualsCost:
+
+    def __init__(self,
+                 displacement,
+                 box_size,
+                 contact_fn,
+                 target_radius,
+                 radius_cutoff_mul=100,
+                 residuals_cost_factor=1,
+                 residuals_avg_type: ResidualsAvgType = 'linear'):
+        self.box_size = box_size
+        self.target_radius = target_radius
+        self.displacement = displacement
+        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
+        self.neighboring_box = clustering.get_ellipsoid_neighboring_fn(box_displacement, contact_fn)
+        self.cutoff = radius_cutoff_mul
+        self.res_cost_fac = residuals_cost_factor
+        self.residuals_avg_type = residuals_avg_type
+        self.residuals_avg_fn = residuals_avg_fn_factory(residuals_avg_type)
+
+    def __str__(self):
+        return (f"Curved clusters residuals cost, target_radius={self.target_radius}, "
+                f"{self.residuals_avg_type} residuals with factor {self.res_cost_fac}.")
+
+    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
+        neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
+        neighbors_box = self.neighboring_box(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
+        clusters = clustering.clustering(neighbors)
+        clusters_box = clustering.clustering(neighbors_box)
+        # we add 1e-12 to cluster masks to avoid NaN gradients that arise when the entire mask is 0
+        # this can change the fit values for small clusters, however, we later only take into account clusters with N>3
+        all_radii, all_residuals = jax.vmap(partial(single_cluster_curv_radius,
+                                                    mask_box=clusters_box.masks,
+                                                    body=body,
+                                                    box_size=self.box_size,
+                                                    residuals_avg_fn=self.residuals_avg_fn))(clusters.masks + 1e-12)
+
+        cluster_weights = clusters.n_part_per_cluster / body.center.shape[0]
+        relevant_clusters = clusters.n_part_per_cluster > 3
+        cluster_weights = jnp.where(relevant_clusters, cluster_weights, 0.)  # clusters with N < 3 are excluded
+        num_of_weighted_clusters = jnp.sum(cluster_weights > 0) + 1e-6  # we add a small number to avoid dividing by 0
+
+        # as curvature radius of a cluster can get arbitrarily large, we cut off calculated radii at some value
+        all_radii = jnp.where(all_radii > self.cutoff * self.target_radius, self.cutoff * self.target_radius, all_radii)
+        all_residuals = jnp.where(all_residuals > 1e2, 0., all_residuals)
+        curvature_cost = jnp.log(all_radii / self.target_radius) ** 2
+        residuals_cost = self.res_cost_fac * all_residuals
+
+        return jnp.sum((curvature_cost + residuals_cost) * cluster_weights) # / num_of_weighted_clusters
+
+
+@partial(jax.jit, static_argnums=(4,))
+def single_cluster_quadratic_surface(mask_pbc: Array,
+                                     mask_box: Array,
+                                     body: rigid_body.RigidBody,
+                                     box_size: float,
+                                     surface_constant: int = -1,
+                                     ) -> Array:
+    """
+    Fit a circle to cluster particles. Cluster parameter should be an array of length N of cluster indices,
+    filled to the end by values N for clusters smaller than the entire system size. (Such an array is exactly
+    the output of clustering algorithm from clustering.py, saved in Clusters.clusters.) Body parameter is the
+    whole system rigid_body.RigidBody.
+    Returns:
+        - fitted quadratic surface eigenvalues
+    """
+    displaced_coord = contiguous_clusters(mask_pbc, mask_box, body.center, box_size)
+    p0 = surface_fit_general.GeneralQuadraticSurfaceParams()
+    opt = surface_fit_general.surface_fit_gn(partial(surface_fit_general.quadratic_surface, constant=surface_constant),
+                                             displaced_coord, mask_pbc, p0=p0)
+
+    return jnp.linalg.eigvalsh(opt.quadratic_form)
+
+
+class QuadraticSurfaceClustersCost:
+
+    def __init__(self,
+                 displacement,
+                 box_size,
+                 contact_fn,
+                 target_eigvals,
+                 surface_constat=-1):
+        self.box_size = box_size
+        self.target_eigvals = jnp.sort(target_eigvals)
+        self.displacement = displacement
+        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
+        self.neighboring_box = clustering.get_ellipsoid_neighboring_fn(box_displacement, contact_fn)
+        self.surface_constat = surface_constat
+
+    def __str__(self):
+        return f"Curved clusters cost, target_eigvals={self.target_eigvals}"
+
+    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
+        neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
+        neighbors_box = self.neighboring_box(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
+        clusters = clustering.clustering(neighbors)
+        clusters_box = clustering.clustering(neighbors_box)
+        # we add 1e-12 to cluster masks to avoid NaN gradients that arise when the entire mask is 0
+        # this can change the fit values for small clusters, however, we later only take into account clusters with N>3
+        all_eigvals = jax.vmap(partial(single_cluster_quadratic_surface,
+                                       mask_box=clusters_box.masks,
+                                       body=body,
+                                       box_size=self.box_size,
+                                       surface_constant=self.surface_constat))(clusters.masks + 1e-12)
+        # as curvature radius of a cluster can get arbitrarily large, we cut off calculated radii at some value
+        cluster_weights = clusters.n_part_per_cluster / body.center.shape[0]
+        relevant_clusters = clusters.n_part_per_cluster > 3
+        cluster_weights = jnp.where(relevant_clusters, cluster_weights, 0.)  # clusters with N < 3 are excluded
+        num_of_weighted_clusters = jnp.sum(cluster_weights > 0) + 1e-6  # we add a small number to avoid dividing by 0
+        return jnp.sum((all_eigvals - self.target_eigvals) ** 2 * cluster_weights[:, None]) / num_of_weighted_clusters
+
+
+def single_cluster_cylinder_radius(mask_pbc: Array,
+                                   mask_box: Array,
+                                   body: rigid_body.RigidBody,
+                                   box_size: float):
+    """
+    Fit a cylinder to cluster particles. Cluster parameter should be an array of length N of cluster indices,
+    filled to the end by values N for clusters smaller than the entire system size. (Such an array is exactly
+    the output of clustering algorithm from clustering.py, saved in Clusters.clusters.) Body parameter is the
+    whole system rigid_body.RigidBody.
+    """
+    displaced_coord = contiguous_clusters(mask_pbc, mask_box, body.center, box_size)  # frame of reference without PBC
+    cluster_coord = displaced_coord * mask_pbc[:, None]
+    matrix = 1 / jnp.sum(mask_pbc) * cluster_coord.T @ cluster_coord
+    values, vecs = jnp.linalg.eigh(matrix)
+    euler_best_direction = jnp.array([jnp.arctan2(vecs[0, 2], vecs[1, 2]),
+                                      jnp.arctan2(vecs[2, 2], jnp.sqrt(vecs[0, 2] ** 2 + vecs[1, 2] ** 2)),
+                                      0])
+    p0_1 = surface_fit.QuadraticSurfaceParams(center=3 * vecs[:, 0],
+                                              euler=euler_best_direction,
+                                              radius=5.)
+    p0_2 = surface_fit.QuadraticSurfaceParams(center=-3 * vecs[:, 0],
+                                              euler=euler_best_direction,
+                                              radius=5.)
+    opt1 = surface_fit.surface_fit_gn(surface_fit.cylindrical_surface, displaced_coord, mask_pbc, p0=p0_1)
+    opt2 = surface_fit.surface_fit_gn(surface_fit.cylindrical_surface, displaced_coord, mask_pbc, p0=p0_2)
+    mean_residuals1 = jnp.sum(jnp.abs(surface_fit.spherical_surface(opt1, displaced_coord, mask_pbc))) / (
+                jnp.sum(mask_pbc) + 1e-6)
+    mean_residuals2 = jnp.sum(jnp.abs(surface_fit.spherical_surface(opt2, displaced_coord, mask_pbc))) / (
+                jnp.sum(mask_pbc) + 1e-6)
+    return jax.lax.cond(mean_residuals1 < mean_residuals2,
+                        lambda: jnp.abs(opt1.radius),
+                        lambda: jnp.abs(opt2.radius))
+    # return jnp.minimum(jnp.abs(opt1.radius), jnp.abs(opt2.radius))
+
+
+class CylindricalClustersCost:
+
+    def __init__(self, displacement,
+                                box_size,
+                                contact_fn,
+                                target_radius):
+        self.box_size = box_size
+        self.target_radius = target_radius
+        self.displacement = displacement
+        self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
+
+    def __str__(self):
+        return f"Cylindrical clusters cost, target_radius={self.target_radius}"
+
+    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
+        neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
+        clusters = clustering.clustering(neighbors)
+        # we add 1e-12 to cluster masks to avoid NaN gradients that arise when the entire mask is 0
+        # this can change the fit values for small clusters, however, we later only take into account clusters with N > 3
+        all_radii = jax.vmap(partial(single_cluster_cylinder_radius,
+                                     body=body,
+                                     displacement=self.displacement,
+                                     box_size=self.box_size))(clusters.masks + 1e-12)
+        # as curvature radius of a cluster can get arbitrarily large, we cut off calculated radii at 1000 * target radius
+        all_radii = jnp.where(all_radii > 1000 * self.target_radius, 1000 * self.target_radius, all_radii)
+        cluster_weights = clusters.n_part_per_cluster / body.center.shape[0]
+        relevant_clusters = clusters.n_part_per_cluster > 3
+        cluster_weights = jnp.where(relevant_clusters, cluster_weights, 0.)  # clusters with N < 3 are excluded
+        num_of_weighted_clusters = jnp.sum(cluster_weights > 0) + 1e-6  # we add a small number to avoid dividing by 0
+        return jnp.sum((jnp.log(all_radii / self.target_radius) ** 2) * cluster_weights) / num_of_weighted_clusters
+
+
+class CostCombinator:
+
+    def __init__(self, cost_fns: list[CostFn], coefficients: list[float]):
+        if len(cost_fns) != len(coefficients):
+            raise ValueError(f'Lengths of cost_fn list and coefficients list should be equal, '
+                             f'got {len(cost_fns)} and {len(coefficients)}, respectively')
+        self.cost_fns = cost_fns
+        self.coefficients = coefficients
+
+    def __str__(self):
+        return f"".join([f"{coef} x {cf} \n" for cf, coef in zip(self.cost_fns, self.coefficients)])
+
+    def __call__(self, body: rigid_body.RigidBody, **interaction_params):
+        return sum(coef * cf(body, **interaction_params) for cf, coef in zip(self.cost_fns, self.coefficients))
+

+ 288 - 0
curvature_assembly/curvature_estimation.py

@@ -0,0 +1,288 @@
+from typing import Callable
+import jax.numpy as jnp
+from functools import partial
+import jax
+from jax_md import space
+from curvature_assembly import util
+import numpy as np
+
+Array = jnp.ndarray
+
+
+def construct_eigensystem(normal: Array) -> Array:
+    """Construct eigensystem where z-axis is given by the surface normal at the given point."""
+    psi = jnp.arccos(normal[2])
+    phi = jnp.arctan2(normal[1], normal[0])
+    return jnp.array([[-jnp.sin(phi), jnp.cos(phi), 0],
+                      [jnp.cos(psi) * jnp.cos(phi), jnp.cos(psi) * jnp.sin(phi), -jnp.sin(psi)],
+                      [normal[0], normal[1], normal[2]]]).T
+
+
+def position_to_local_coordinates(eigensystem: Array, center_position: Array, neighbors_coord_global: Array) -> Array:
+    """
+    Transform coordinates of neighbor particles to the coordinate system centered at center_position
+    and axes given in eigensystem matrix.
+    """
+    return jnp.dot(neighbors_coord_global - center_position[None, :], eigensystem)
+
+
+def normal_to_local_coordinates(eigensystem: Array, neighbors_normal_global: Array) -> Array:
+    """
+    Transform normals at neighbor particle locations to the coordinate system centered at center_position
+    and axes given in eigensystem matrix.
+    """
+    return jnp.dot(neighbors_normal_global, eigensystem)
+
+
+@partial(jnp.vectorize, signature='(d),(d)->()')
+def point_main_curvature(neighbor_coord_local: Array, neighbor_normal_local: Array) -> Array:
+    """Normal curvature estimate given a point and its normal in the coordinate system of the center particle."""
+    # 1e-8 makes it safe on diagonal where dist2d=0
+    dist2d = jnp.sqrt(neighbor_coord_local[0] ** 2 + neighbor_coord_local[1] ** 2) + 1e-8
+    n_xy = (neighbor_coord_local[0] * neighbor_normal_local[0] +
+            neighbor_coord_local[1] * neighbor_normal_local[1]) / dist2d
+    return -n_xy / (jnp.sqrt(n_xy ** 2 + neighbor_normal_local[2] ** 2) * dist2d)
+
+
+def construct_coefficient_matrix(neighbors_local: Array) -> Array:
+    """Construct coefficient matrix for least square fitting of curvature parameters."""
+    thetas = jnp.arctan2(neighbors_local[:, 1], neighbors_local[:, 0])
+    sin_thetas = jnp.sin(thetas)
+    cos_thetas = jnp.cos(thetas)
+
+    coefficient_matrix = jnp.zeros(neighbors_local.shape)
+    coefficient_matrix = coefficient_matrix.at[:, 0].set(cos_thetas ** 2)
+    coefficient_matrix = coefficient_matrix.at[:, 1].set(cos_thetas * sin_thetas)
+    coefficient_matrix = coefficient_matrix.at[:, 2].set(sin_thetas ** 2)
+
+    return coefficient_matrix
+
+
+def principal_curvature(idx: int, coord: Array, normal: Array, neighbors: Array) -> (Array, Array):
+    """
+    Calculate the two principal curvatures at the `idx˙ particle.
+    Algorithm from Zhang et al., "Curvature Estimation of 3D Point Cloud Surfaces Through the Fitting of Normal
+    Section Curvatures", http://www.nlpr.ia.ac.cn/2008papers/gjhy/gh129.pdf
+    """
+
+    eigensystem = construct_eigensystem(normal[idx])
+
+    coord_local = position_to_local_coordinates(eigensystem, coord[idx], coord)
+    normal_local = normal_to_local_coordinates(eigensystem, normal)
+
+    # jax.debug.print("{}", jnp.sum(neighbors[idx]))
+
+    coefficient_matrix = construct_coefficient_matrix(coord_local) * neighbors[idx][:, None]  # we add neighbors mask
+    neighbor_curvatures = point_main_curvature(coord_local, normal_local) * neighbors[idx]
+
+    # jax.debug.print("{}", neighbor_curvatures)
+
+    # we get curvature parameters by least square fitting
+    curve_params, residuals, rank, s = jnp.linalg.lstsq(coefficient_matrix, neighbor_curvatures)
+
+    discriminant_sqrt = jnp.sqrt((curve_params[0] - curve_params[2]) ** 2 + 4 * curve_params[1] ** 2)
+
+    curvature1 = 0.5 * (curve_params[0] + curve_params[2] - discriminant_sqrt)
+    curvature2 = 0.5 * (curve_params[0] + curve_params[2] + discriminant_sqrt)
+
+    # jax.debug.print("{}, {}", curvature1, curvature2)
+
+    return curvature1, curvature2
+
+
+@jax.jit
+def minimum_spanning_tree(normals: Array, neighbors_full: Array) -> Array:
+    """
+    Determine minimum spanning tree for a graph with links based on neighbors and weights (energies) 1 - |ni.nj| using
+    the Prim's algorithm.
+
+    Args:
+        normals: normal vectors at each node, shape (N, 3)
+        neighbors_full: boolean matrix of shape (N, N) where True elements describe neighbor particles
+
+    Returns:
+        minimum spanning tree of the graph
+    """
+
+    num_nodes = normals.shape[0]
+
+    # calculate weights, w[i, j] between 0 and 1
+    weights = 1. - jnp.abs(jnp.einsum('nmk, nmk -> nm', normals[None, :, :], normals[:, None, :]))
+    weights = weights + 10. * (1 - neighbors_full)  # effectively erases connections between nodes
+
+    selected_nodes = jnp.zeros(num_nodes, dtype=bool)
+    selected_nodes = selected_nodes.at[0].set(True)
+
+    node_order = jnp.zeros(num_nodes, dtype=jnp.int32)
+
+    def add_node(carry, i):
+        node_order, selected_nodes, weights = carry
+
+        # find minimum energy link among all links connected to the already selected nodes of the spanning tree
+        # argmin returns first occurrence in flattened array which we then unravel
+        min_idx = jnp.unravel_index(jnp.argmin(weights[node_order]), shape=(num_nodes, num_nodes))
+
+        # the next idx in the spanning tree will be the second element in min_idx
+        selected_nodes = selected_nodes.at[min_idx[1]].set(True)
+        node_order = node_order.at[i].set(min_idx[1])
+
+        # erase possible remaining connections between all selected nodes to prevent formation of loops
+        mask = selected_nodes[:, None] * selected_nodes[None, :]
+        weights += 1. * mask
+
+        return (node_order, selected_nodes, weights), jnp.array([node_order[min_idx[0]], min_idx[1]])
+
+    _, link_list = jax.lax.scan(add_node, init=(node_order, selected_nodes, weights), xs=jnp.arange(1, num_nodes))
+
+    return link_list
+
+
+@jax.jit
+def determine_normals(coordinates: Array, neighbors: Array) -> Array:
+    """
+    Determine local surface normals at each point with PCA and ensuring consistent surface orientation.
+    Algorithm from Hoppe et al., "Surface Reconstruction from Unorganized Points",
+    https://dl.acm.org/doi/pdf/10.1145/133994.134011
+    """
+
+    def get_normal(neighbor_mask):
+        n_part = jnp.sum(neighbor_mask)
+        neighbor_coord = coordinates * neighbor_mask[:, None]
+        centered = (neighbor_coord - jnp.sum(neighbor_coord, axis=0) / (n_part + 1e-8)) * neighbor_mask[:, None]
+        matrix = centered.T @ centered
+        values, vecs = jnp.linalg.eigh(matrix)
+        return vecs[:, 0]  # eigenvector corresponding to the smallest eigenvalue
+
+    normals = jax.vmap(get_normal)(neighbors)
+
+    link_list = minimum_spanning_tree(normals, neighbors)
+
+    def update_normals(normals, idx):
+        normals = jax.lax.cond(jnp.sum(normals[link_list[idx, 0]] * normals[link_list[idx, 1]]) < 0,
+                               lambda x: x.at[link_list[idx, 1]].set(-x[link_list[idx, 1]]), lambda x: x, normals)
+        return normals, 0.
+
+    consistent_normals, _ = jax.lax.scan(update_normals, init=normals, xs=jnp.arange(len(link_list)))
+
+    return consistent_normals
+
+
+@jax.jit
+def determine_tangent_planes(coordinates: Array, neighbors: Array) -> (Array, Array):
+    """
+    Determine local surface normals at each point with PCA and ensuring consistent surface orientation.
+    Algorithm from Hoppe et al., "Surface Reconstruction from Unorganized Points",
+    https://dl.acm.org/doi/pdf/10.1145/133994.134011
+    """
+
+    def get_normal(neighbor_mask):
+        n_part = jnp.sum(neighbor_mask)
+        neighbor_coord = coordinates * neighbor_mask[:, None]
+        centered = (neighbor_coord - jnp.sum(neighbor_coord, axis=0) / (n_part + 1e-8)) * neighbor_mask[:, None]
+        matrix = centered.T @ centered
+        values, vecs = jnp.linalg.eigh(matrix)
+        return jnp.sum(neighbor_coord, axis=0) / (n_part + 1e-8), vecs[:, 0]  # eigenvector corresponding to the smallest eigenvalue
+
+    centers, normals = jax.vmap(get_normal)(neighbors)
+
+    link_list = minimum_spanning_tree(normals, neighbors)
+
+    def update_normals(normals, idx):
+        normals = jax.lax.cond(jnp.sum(normals[link_list[idx, 0]] * normals[link_list[idx, 1]]) < 0,
+                               lambda x: x.at[link_list[idx, 1]].set(-x[link_list[idx, 1]]), lambda x: x, normals)
+        return normals, 0.
+
+    consistent_normals, _ = jax.lax.scan(update_normals, init=normals, xs=jnp.arange(len(link_list)))
+
+    return centers, consistent_normals
+
+
+def nearest_neighbors_fn(displacement_or_metric: space.DisplacementOrMetricFn,
+                         dist_cutoff: float) -> Callable[[Array], Array]:
+    """
+    Construct a function that determines boolean nearest neighbor matrix with dimensions (N, N)
+    where element neighbors[i,j] is True only if particle j is among num_neighbors nearest neighbors of particle i.
+    The returned array is NOT necessarily symmetric.
+    """
+
+    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
+    def condition(d):
+        return jax.lax.cond(0. < d < dist_cutoff, lambda x: 1., lambda x: 0., 0)
+
+    # @jax.jit
+    def nearest_neighbors_matrix(coord: Array) -> Array:
+        dist = space.map_product(metric)(coord, coord)
+        # jax.debug.print("Num neighbors: {}", jnp.sum(dist < dist_cutoff, axis=1))
+        # jax.debug.print("Determinant: {}", jnp.linalg.det(util.diagonal_mask(dist < dist_cutoff)))
+        # return util.diagonal_mask(dist < dist_cutoff)
+        return dist < dist_cutoff
+        # return util.diagonal_mask(jnp.asarray(np.random.randint(2, size=(coord.shape[0], coord.shape[0]))))
+
+        # neighbors = jax.vmap(jax.vmap(condition))(dist)
+        # return neighbors
+
+    return nearest_neighbors_matrix
+
+
+def edge_detection_fn(displacement_or_metric: space.DisplacementOrMetricFn,
+                      classification_threshold: float,
+                      num_neighbors: int = 12) -> Callable[[Array, Array], Array]:
+    """
+    Construct edge detection function based on the distance between the center of mass of all the neighbors
+    and the particle of interest. Adapted from: https://arxiv.org/pdf/1809.10468.pdf
+
+    Args:
+        displacement_or_metric: displacement or metric function
+        classification_threshold: factor of resolution (the smallest distance between a particle and one of its
+            neighbors) that determines if the center of mass for all neighboring particles is too far which classifies
+            the particle as an edge particle
+        num_neighbors: number of neighbor particles taken into account
+
+    Returns:
+        Edge detection function that takes the coordinates of all particles and boolean neighboring matrix
+    """
+
+    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
+
+    @jax.jit
+    def detect_edge_particles(coordinates: Array, neighbors: Array) -> Array:
+        num_particles = coordinates.shape[0]
+        dist_fn = jnp.vectorize(metric, signature='(d),(d)->()')
+
+        def vmap_fn(idx):
+            # fill value in jnp.where doesn't matter as we specify the exact number of neighbors
+            neighbor_coord = coordinates[jnp.where(neighbors[idx], size=num_neighbors, fill_value=num_particles)]
+            neigh_cm = jnp.mean(neighbor_coord, axis=0)
+            dist = dist_fn(neighbor_coord, coordinates[idx])
+            resolution = jnp.min(dist)
+            edge_particle = jax.lax.cond(dist_fn(neigh_cm, coordinates[idx]) > classification_threshold * resolution,
+                                         lambda x: True, lambda x: False, 0.)
+            return edge_particle
+
+        edge_particles = jax.vmap(vmap_fn)(jnp.arange(num_particles))
+
+        return edge_particles
+
+    return detect_edge_particles
+
+
+def gaussian_curvature_fn(displacement_or_metric: space.DisplacementOrMetricFn,
+                          dist_cutoff: float) -> Callable[[Array], Array]:
+    """Construct a function that returns the gaussian curvature at each given point in a point cloud."""
+    nearest_neighbors = nearest_neighbors_fn(displacement_or_metric, dist_cutoff=dist_cutoff)
+
+    def gaussian_curvature(coord: Array) -> Array:
+        num_particles = coord.shape[0]
+        neighbors = nearest_neighbors(coord)
+        # normals = determine_normals(coord, neighbors)
+        centers, normals = determine_tangent_planes(coord, neighbors)
+        # print(coord)
+
+        curvature1, curvature2 = jax.vmap(partial(principal_curvature,
+                                                  coord=centers,  # or coord=centers
+                                                  normal=normals,
+                                                  neighbors=util.diagonal_mask(neighbors)))(jnp.arange(num_particles))
+
+        return curvature1 * curvature2
+
+    return gaussian_curvature

+ 75 - 0
curvature_assembly/data_protocols.py

@@ -0,0 +1,75 @@
+from __future__ import annotations
+from typing import Protocol, Callable, Any
+import jax.numpy as jnp
+from jax_md import rigid_body
+
+Array = jnp.ndarray
+
+
+class SimulationParams(Protocol):
+    """
+    Interface for a container of simulation parameters, i.e. simulation parameters over which we do not
+    intend to differentiate.
+    """
+    # NOTE(review): the attribute meanings below are inferred from the names only -- confirm against
+    # the concrete parameter containers
+    num: int  # presumably the number of particles
+    density: float  # presumably the particle number density
+    simulation_steps: int  # presumably the total number of simulation steps
+    dt: float  # presumably the integration time step
+    config_every: int  # presumably the number of steps between stored configurations
+
+
+class InteractionParams(Protocol):
+    """
+    Protocol for a container of interaction parameters. Gradient of the simulation should be taken over these values.
+    """
+
+    # the only attribute required by the protocol; concrete containers may define further parameters
+    eigvals: jnp.ndarray
+
+
+class NeighborListParams(Protocol):
+    # marker protocol: no attributes or methods are required
+    """Protocol for a container of neighbor list parameters."""
+
+
+class SimulationLog(Protocol):
+    """Protocol class for storing data during the simulation."""
+
+    # NOTE(review): presumably the number of entries stored so far -- confirm against implementations
+    current_len: Array
+
+    # computes the values to be logged from the simulation state; return value is implementation defined
+    def calculate_values(self, state, energy_fn: Callable, ellipsoid_mass: rigid_body.RigidBody, kT: float, **params):
+        ...
+
+    # stores new values in the log; accepted arguments are implementation defined
+    def update(self, *args):
+        ...
+
+    # NOTE(review): presumably discards the last `nsteps` stored entries -- confirm against implementations
+    def revert_last_nsteps(self, nsteps: int):
+        ...
+
+
+class SimulationStateHistory(Protocol):
+    # Protocol for a container of stored simulation states.
+    # NOTE(review): attribute meanings are inferred from the names only -- confirm against implementations
+
+    coord: Array  # presumably stored particle positions
+    orient: Array  # presumably stored particle orientations
+    current_len: Array  # presumably the number of states stored so far
+
+    # NOTE(review): presumably discards the last `nsteps` stored states -- confirm against implementations
+    def revert_last_nsteps(self, nsteps: int):
+        ...
+
+
+class SimulationAux(Protocol):
+    """Dataclass for simulation auxiliary data."""
+    # auxiliary data consists of a simulation log and a history of simulation states
+    log: SimulationLog
+    state_history: SimulationStateHistory
+
+    # NOTE(review): presumably returns a container with an empty log and state history -- confirm
+    def reset_empty(self) -> SimulationAux:
+        ...
+
+
+class BpttResults(Protocol):
+    # Protocol for the result of one BPTT (presumably "backpropagation through time") simulation run.
+    cost: Array  # value of the cost function
+    grad: InteractionParams  # gradient, stored in a container with the structure of the interaction parameters
+
+
+# Signature of a BPTT simulation function: it takes interaction parameters, one argument of arbitrary type
+# and the auxiliary simulation data, and returns the results together with the auxiliary data.
+BpttSimulation = Callable[[InteractionParams, Any, SimulationAux],
+                           tuple[BpttResults, SimulationAux]]
+
+

+ 87 - 0
curvature_assembly/ellipsoid_contact.py

@@ -0,0 +1,87 @@
+import jax.numpy as jnp
+import jax
+from functools import partial
+
+Array = jnp.ndarray
+
+
+@partial(jnp.vectorize, signature='(d,d)->()')
+def determinant(a):
+    """Determinant of a symmetric 3x3 matrix."""
+    return a[0, 0] * (a[1, 1] * a[2, 2] - a[2, 1] * a[1, 2]) \
+           - a[1, 0] * (a[0, 1] * a[2, 2] - a[2, 1] * a[0, 2]) \
+           + a[2, 0] * (a[0, 1] * a[1, 2] - a[1, 1] * a[0, 2])
+
+
+@partial(jnp.vectorize, signature='(d,d)->(d,d)')
+def inverse(a):
+    """Inverse of a symmetric 3x3 matrix. Much faster than jnp.linalg.inv."""
+    det = determinant(a)
+    inv = jnp.array([[a[2, 2] * a[1, 1] - a[1, 2] ** 2,
+                      a[0, 2] * a[1, 2] - a[2, 2] * a[0, 1],
+                      a[0, 1] * a[1, 2] - a[0, 2] * a[1, 1]],
+                     [a[0, 2] * a[1, 2] - a[2, 2] * a[0, 1],
+                      a[2, 2] * a[0, 0] - a[0, 2] ** 2,
+                      a[0, 1] * a[0, 2] - a[0, 0] * a[1, 2]],
+                     [a[0, 1] * a[1, 2] - a[0, 2] * a[1, 1],
+                      a[0, 1] * a[0, 2] - a[0, 0] * a[1, 2],
+                      a[0, 0] * a[1, 1] - a[0, 1] ** 2]])
+
+    return inv / det
+
+
+def matrix_c(lbd: float, mat1: Array, mat2: Array) -> Array:
+    """Matrix C from the Perram and Wertheim article on ellipsoid contact function."""
+    return inverse(lbd * mat2 + (1 - lbd) * mat1)
+
+
+def perram_wertheim_objective(lbd: float, r12: Array, mat1: Array, mat2: Array) -> Array:
+    c = matrix_c(lbd, mat1, mat2)
+    return lbd * (1 - lbd) * jnp.dot(r12, jnp.dot(c, r12))
+
+
+# derivative of the Perram-Wertheim objective with respect to the interpolation parameter lbd (argument 0)
+objective_grad = jax.grad(perram_wertheim_objective, argnums=0)
+
+
+def evaluate_grad_step(carry: float, x: float, r12: Array, mat1: Array, mat2: Array) -> (Array, Array):
+    grad = objective_grad(carry, r12, mat1, mat2)
+    return carry + x * jnp.sign(grad), 0.
+
+
+@partial(jax.jit, static_argnums=(3,))
+def pw_contact_function(r12: Array, mat1: Array, mat2: Array, num_steps: int = 25, **unused_kwargs) -> Array:
+    """
+    Calculate Perram-Wertheim contact function. To ensure jax.gradient compatibility,
+    a dumb gradient-based method is used where a fixed number of steps is taken to calculate the maximum
+    of the objective function. Square root is taken to get linear distance dependence.
+
+    Args:
+        r12: distance vector between ellipsoid centers.
+        mat1: weight matrix of the first ellipsoid, with eigenvalues equal to squared semiaxis lengths.
+        mat2: weight matrix of the second ellipsoid, with eigenvalues equal to squared semiaxis lengths.
+        num_steps: number of step in objective maximization. Accuracy improves as 1 / 2^num_steps.
+
+    Returns:
+        Perram-Wertheim contact function
+    """
+    powers = 2 ** (jnp.arange(num_steps) + 2)  # powers of two
+    t_change = 1 / powers
+    t_opt, _ = jax.lax.scan(partial(evaluate_grad_step, r12=r12, mat1=mat1, mat2=mat2), init=0.5, xs=t_change)
+    return jnp.sqrt(perram_wertheim_objective(t_opt, r12, mat1, mat2))
+
+
+def bp_contact_function(r12: Array, mat1: Array, mat2: Array, **unused_kwargs) -> Array:
+    """
+    Calculates Berne-Pechukas contact function which is an approximation for the true Perram-Wertheim contact function
+    at the value of interpolation parameter t = 0.5. Square root is taken to get linear distance dependence.
+
+    Args:
+        r12: distance vector between ellipsoid centers.
+        mat1: weight matrix of the first ellipsoid, with eigenvalues equal to squared semiaxis lengths.
+        mat2: weight matrix of the second ellipsoid, with eigenvalues equal to squared semiaxis lengths.
+
+    Returns:
+        Berne-Pechukas contact function
+    """
+
+    return jnp.sqrt(perram_wertheim_objective(0.5, r12, mat1, mat2))

+ 328 - 0
curvature_assembly/energy.py

@@ -0,0 +1,328 @@
+from __future__ import annotations
+from typing import Callable
+import jax.numpy as jnp
+from curvature_assembly import (
+    oriented_particle,
+    data_protocols,
+    patchy_interaction,
+    multipole_interaction,
+)
+from jax_md import energy as jaxmd_energy
+from curvature_assembly.smap import oriented_pair
+from jax_md import partition, space, dataclasses
+
+f32 = jnp.float32
+f64 = jnp.float64
+Array = jnp.ndarray
+
+DisplacementFn = space.DisplacementFn
+ContactFunction = Callable[..., Array]
+NeighborListFormat = partition.NeighborListFormat
+InteractionParams = data_protocols.InteractionParams
+
+
+def weeks_chandler_andersen(
+    dr: Array, sigma: Array = 1.0, epsilon: Array = 1.0, **unused_kwargs
+) -> Array:
+    """Repulsive part of the Lennard-Jones potential."""
+    return jnp.where(
+        dr < jnp.power(2, 1 / 6) * sigma,
+        jaxmd_energy.lennard_jones(dr, sigma=sigma, epsilon=epsilon) + epsilon,
+        0.0,
+    )
+
+
+@dataclasses.dataclass
+class GbWcaParams:
+    # Interaction parameters whose field names match the keyword arguments of the pair energies built by
+    # gaussian_band_wca_ellipsoid_pair and gaussian_band_fh_wca_ellipsoid_pair.
+    eigvals: Array = dataclasses.field(default_factory=lambda: jnp.ones((3,)))  # passed to the contact function
+    epsilon: Array = 5.0  # epsilon of the WCA repulsion
+    d0: Array = 10.0  # passed as epsilon to the Morse term
+    sigma: Array = 1.0  # sigma of the Morse term
+    alpha: Array = 1.0  # alpha of the Morse term
+    band_theta: Array = jnp.pi / 2  # passed to the Gaussian interaction band function
+    band_sigma: Array = 0.5  # passed to the Gaussian interaction band function
+
+
+def gaussian_band_wca_ellipsoid_pair(
+    displacement: DisplacementFn, contact_fn: ContactFunction, **cf_kwargs
+) -> Callable[..., Array]:
+    contact_function = oriented_particle.get_ellipsoid_contact_function_param(
+        contact_fn, **cf_kwargs
+    )
+
+    def patchy_wca_ellipsoid(
+        dr: Array,
+        eigsys1: Array,
+        eigsys2: Array,
+        eigvals: Array = jnp.array([1.0, 1.0, 1.0]),
+        epsilon: Array = 5.0,
+        d0: Array = 10,
+        alpha: Array = 1.0,
+        sigma: Array = 1.0,
+        band_theta: Array = jnp.pi / 2,
+        band_sigma: Array = 0.5,
+    ) -> Array:
+        cf = contact_function(dr, eigsys1, eigsys2, eigvals=eigvals)
+        wca_repulsion = weeks_chandler_andersen(cf, sigma=1.0, epsilon=epsilon)
+        ellipsod_morse = jaxmd_energy.morse(cf, epsilon=d0, alpha=alpha, sigma=sigma)
+        patchy_value = patchy_interaction.gaussian_interaction_band(
+            dr, eigsys1, eigsys2, band_theta, band_sigma
+        )
+        # patchy_value = 0.
+        return wca_repulsion + ellipsod_morse * patchy_value
+
+    energy_fn = oriented_pair(patchy_wca_ellipsoid, displacement)
+
+    return energy_fn
+
+
+def gaussian_band_fh_wca_ellipsoid_pair(
+    displacement: DisplacementFn, contact_fn: ContactFunction, **cf_kwargs
+) -> Callable[..., Array]:
+    contact_function = oriented_particle.get_ellipsoid_contact_function_param(
+        contact_fn, **cf_kwargs
+    )
+
+    def patchy_wca_ellipsoid(
+        dr: Array,
+        eigsys1: Array,
+        eigsys2: Array,
+        eigvals: Array = jnp.array([1.0, 1.0, 1.0]),
+        epsilon: Array = 5.0,
+        d0: Array = 10,
+        alpha: Array = 1.0,
+        sigma: Array = 1.0,
+        band_theta: Array = jnp.pi / 2,
+        band_sigma: Array = 0.5,
+    ) -> Array:
+        cf = contact_function(dr, eigsys1, eigsys2, eigvals=eigvals)
+        wca_repulsion = weeks_chandler_andersen(cf, sigma=1.0, epsilon=epsilon)
+        ellipsod_morse = jaxmd_energy.morse(cf, epsilon=d0, alpha=alpha, sigma=sigma)
+        patchy_value = patchy_interaction.gaussian_interaction_band_fixed_height(
+            dr, eigsys1, eigsys2, band_theta, band_sigma
+        )
+        # patchy_value = 0.
+        return wca_repulsion + ellipsod_morse * patchy_value
+
+    energy_fn = oriented_pair(patchy_wca_ellipsoid, displacement)
+
+    return energy_fn
+
+
+@dataclasses.dataclass
+class PatchyWcaParams:
+    # Interaction parameters whose field names match the keyword arguments of the pair energy built by
+    # patchy_wca_ellipsoid_pair.
+    eigvals: Array = dataclasses.field(default_factory=lambda: jnp.ones((3,)))  # passed to the contact function
+    epsilon: Array = 5.0  # epsilon of the WCA repulsion
+    d0: Array = 10.0  # passed as epsilon to the Morse term
+    sigma: Array = 1.0  # sigma of the Morse term
+    alpha: Array = 1.0  # alpha of the Morse term
+    lm_magnitudes: Array = 1  # passed to the patchy interaction function
+
+
+def patchy_wca_ellipsoid_pair(
+    displacement: DisplacementFn,
+    contact_fn: ContactFunction,
+    lm: tuple | list[tuple],
+    **cf_kwargs,
+) -> Callable[..., Array]:
+    contact_function = oriented_particle.get_ellipsoid_contact_function_param(
+        contact_fn, **cf_kwargs
+    )
+    patchy_function = patchy_interaction.patchy_interaction_general(lm)
+
+    def patchy_wca_ellipsoid(
+        dr: Array,
+        eigsys1: Array,
+        eigsys2: Array,
+        eigvals: Array = jnp.array([1.0, 1.0, 1.0]),
+        epsilon: Array = 5.0,
+        d0: Array = 10,
+        alpha: Array = 1.0,
+        sigma: Array = 1.0,
+        lm_magnitudes: Array = 1.0,
+    ):
+        cf = contact_function(dr, eigsys1, eigsys2, eigvals=eigvals)
+        wca_repulsion = weeks_chandler_andersen(cf, sigma=1.0, epsilon=epsilon)
+        ellipsod_morse = jaxmd_energy.morse(cf, epsilon=d0, alpha=alpha, sigma=sigma)
+        patchy_value = patchy_function(dr, eigsys1, eigsys2, lm_magnitudes)
+        # patchy_value = 0.
+        return wca_repulsion + ellipsod_morse * patchy_value
+
+    energy_fn = oriented_pair(patchy_wca_ellipsoid, displacement)
+
+    return energy_fn
+
+
+@dataclasses.dataclass
+class QuadWcaParams:
+    eigvals: Array = dataclasses.field(default_factory=lambda: jnp.ones((3,)))
+    epsilon: Array = 2.0
+    d0: Array = 10.0
+    q0: Array = 10.0
+    sigma: Array = 1.0
+    alpha: Array = 1.0
+    lm_magnitudes: Array = 1
+
+    def init_unit_volume_particle(self) -> FerroWcaParams:
+        params_dict = vars(self)
+        params_dict["eigvals"] = oriented_particle.eigenvalues_at_unit_volume(
+            jnp.array([1.0, 1.0, 1.0])
+        )
+        return QuadWcaParams(**params_dict)
+
+    def init_lm_magnitudes(self, lm_magnitudes: Array) -> FerroWcaParams:
+        params_dict = vars(self)
+        params_dict["lm_magnitudes"] = lm_magnitudes
+        return QuadWcaParams(**params_dict)
+
+
+def quadrupolar_wca_sphere_pair(
+    displacement: DisplacementFn, lm: tuple | list[tuple], **cf_kwargs
+) -> Callable[..., Array]:
+    patchy_function = patchy_interaction.patchy_interaction_general(lm)
+
+    def quadrupolar_wca_ellipsoid(
+        dr: Array,
+        eigsys1: Array,
+        eigsys2: Array,
+        epsilon: Array,
+        # eigvals: Array = jnp.array([1., 1., 1.]),
+        d0: Array = 1,
+        q0: Array = 1,
+        alpha: Array = 1.0,
+        sigma: Array = 1.0,
+        lm_magnitudes: Array = 1.0,
+        **unused_kwargs,
+    ):
+        # NOTE: we take unit volume particles
+        # sigma_particle = 2 * jnp.cbrt(3 / (4 * jnp.pi))
+        sigma_particle = sigma
+
+        wca = weeks_chandler_andersen(
+            space.distance(dr), sigma=sigma_particle, epsilon=epsilon
+        )
+        # vdw = jaxmd_energy.lennard_jones(space.distance(dr), sigma=sigma, epsilon=1.)
+        quadrupolar = multipole_interaction.lin_quad_energy(
+            dr,
+            eigsys1,
+            eigsys2,
+            multipole_interaction.quadrupolar_eigenvalues(
+                q0 * sigma_particle ** (5 / 2) * jnp.sqrt(epsilon), jnp.pi / 2
+            ),
+        )
+        # NOTE: in quadrupolar eigenvalues calculation, exponent was corrected from 5 to 5/2
+        ellipsod_morse = jaxmd_energy.morse(
+            space.distance(dr), epsilon=d0 * epsilon, alpha=alpha, sigma=sigma_particle
+        )
+        patchy_value = patchy_function(dr, eigsys1, eigsys2, lm_magnitudes)
+        # patchy_value = 0.
+        return wca + quadrupolar + ellipsod_morse * patchy_value
+
+    energy_fn = oriented_pair(quadrupolar_wca_ellipsoid, displacement)
+
+    return energy_fn
+
+
+@dataclasses.dataclass
+class FerroWcaParams:
+    eigvals: Array = dataclasses.field(default_factory=lambda: jnp.ones((3,)))
+    epsilon: Array = 5.0
+    d0: Array = 1.5
+    q0: Array = 2.0
+    sigma: Array = 1.0
+    alpha: Array = 1.0
+    lm_magnitudes: Array = 1
+    softness: Array = 1.5
+
+    def init_unit_volume_particle(self) -> FerroWcaParams:
+        params_dict = vars(self)
+        params_dict["eigvals"] = oriented_particle.eigenvalues_at_unit_volume(
+            jnp.array([1.0, 1.0, 1.0])
+        )
+        return FerroWcaParams(**params_dict)
+
+    def init_lm_magnitudes(self, lm_magnitudes: Array) -> FerroWcaParams:
+        params_dict = vars(self)
+        params_dict["lm_magnitudes"] = lm_magnitudes
+        return FerroWcaParams(**params_dict)
+
+
+def ferro_wca_sphere_pair(
+    displacement: DisplacementFn, lm: tuple | list[tuple], **cf_kwargs
+) -> Callable[..., Array]:
+    patchy_function = patchy_interaction.patchy_interaction_general(lm)
+
+    def ferro_wca_ellipsoid(
+        dr: Array,
+        eigsys1: Array,
+        eigsys2: Array,
+        # eigvals: Array = jnp.array([1., 1., 1.]),
+        epsilon: Array = 5.0,
+        d0: Array = 1,
+        q0: Array = 2,
+        alpha: Array = 1.0,
+        sigma: Array = 1.0,
+        lm_magnitudes: Array = 1.0,
+        softness: Array = 1.5,
+        **unused_kwargs,
+    ):
+        # NOTE: we take unit volume particles
+        # sigma_particle = 2 * jnp.cbrt(3 / (4 * jnp.pi))
+        sigma_particle = sigma
+
+        wca = weeks_chandler_andersen(
+            space.distance(dr), sigma=sigma_particle, epsilon=epsilon
+        )
+        # vdw = jaxmd_energy.lennard_jones(space.distance(dr), sigma=sigma, epsilon=1.)
+        ferro = multipole_interaction.ferro_orientational_energy(
+            dr, eigsys1, eigsys2, softness=softness
+        )
+        morse = jaxmd_energy.morse(
+            space.distance(dr), epsilon=d0 * epsilon, alpha=alpha, sigma=sigma_particle
+        )
+        patchy_value = patchy_function(dr, eigsys1, eigsys2, lm_magnitudes)
+        # patchy_value = 0.
+        return wca + morse * (patchy_value + q0**2 * ferro)
+
+    energy_fn = oriented_pair(ferro_wca_ellipsoid, displacement)
+
+    return energy_fn
+
+
+def quadrupolar_wca_ellipsoid_pair(
+    displacement: DisplacementFn,
+    contact_fn: ContactFunction,
+    lm: tuple | list[tuple],
+    **cf_kwargs,
+) -> Callable[..., Array]:
+    contact_function = oriented_particle.get_ellipsoid_contact_function_param(
+        contact_fn, **cf_kwargs
+    )
+    patchy_function = patchy_interaction.patchy_interaction_general(lm)
+
+    def quadrupolar_wca_ellipsoid(
+        dr: Array,
+        eigsys1: Array,
+        eigsys2: Array,
+        eigvals: Array = jnp.array([1.0, 1.0, 1.0]),
+        epsilon: Array = 5.0,
+        d0: Array = 10,
+        d1: Array = 10,
+        alpha: Array = 1.0,
+        sigma: Array = 1.0,
+        lm_magnitudes: Array = 1.0,
+    ):
+        cf = contact_function(dr, eigsys1, eigsys2, eigvals=eigvals)
+        # wca_repulsion = weeks_chandler_andersen(cf, sigma=1., epsilon=epsilon)
+        vdw = jaxmd_energy.lennard_jones(cf, sigma=1.0, epsilon=epsilon)
+        quadrupolar = multipole_interaction.quadrupolar_interaction(
+            dr, eigsys1, eigsys2, multipole_interaction.quadrupolar_eigenvalues(1.0)
+        )
+        # ellipsod_morse = jaxmd_energy.morse(cf, epsilon=d0, alpha=alpha, sigma=sigma)
+        # patchy_value = patchy_function(dr, eigsys1, eigsys2, lm_magnitudes)
+        # patchy_value = 0.
+        return vdw + d1 * quadrupolar
+
+    energy_fn = oriented_pair(quadrupolar_wca_ellipsoid, displacement)
+
+    return energy_fn

+ 123 - 0
curvature_assembly/file_management.py

@@ -0,0 +1,123 @@
+from pathlib import Path
+import re
+
+class OverwriteError(Exception):
+    pass
+
+
+def overwrite_protection(path: Path) -> None:
+    if path.is_file():
+        raise OverwriteError(f'File {path} already exists and should not be overwritten.')
+    if path.is_dir():
+        raise OverwriteError(f'Directory {path} already exists and should not be overwritten.')
+
+
+def rm_directory(path: Path) -> None:
+    """Removes a directory with its entire content."""
+    for child in path.iterdir():
+        if child.is_file():
+            child.unlink()
+        else:
+            rm_directory(child)
+    path.rmdir()
+
+
+def recursive_dir_empty(path: Path, ignore_top_level_files=False) -> bool:
+    """Check if the directory is empty."""
+    if not path.exists():
+        return True
+
+    # Check if path is empty
+    has_next = next(path.iterdir(), None)
+    if has_next is None:
+        return True
+
+    # Iterate over items in dir_path
+    for item in path.iterdir():
+        if item.is_dir():
+            # Recursively check if subdirectory is empty or contains only empty subdirectories
+            if not recursive_dir_empty(item, ignore_top_level_files=False):  # files in subdirectories are never allowed
+                return False
+        if item.is_file() and not ignore_top_level_files:
+            return False
+
+    return True
+
+
+def split_base_and_num(name: str, sep: str, no_num_return=0):
+    separated = str(name).split(sep)
+    try:
+        num = int(separated[-1])
+        return sep.join(separated[:-1]), num
+    except ValueError:
+        return sep.join(separated), no_num_return
+
+
+def get_unique_filename(file_path: Path) -> Path:
+    # Get the base name and the extension
+    ext = file_path.suffix  # Get the file extension (e.g., .dat)
+    base_name, current_suffix = split_base_and_num(file_path.stem, '_', no_num_return=0)
+    directory = file_path.parent
+    existing_files = list(directory.glob(f"{base_name}_*{ext}"))
+
+    # Extract all existing suffixes (numbers) and find the highest one
+    suffixes = []
+    for file in existing_files:
+        suffix_part = file.stem[len(base_name) + 1:]  # Skip base name and underscore
+        if suffix_part.isdigit():
+            suffixes.append(int(suffix_part))
+
+    # Find the next available suffix
+    if suffixes:
+        new_suffix = max(suffixes) + 1
+    else:
+        new_suffix = 1  # Start with _1 if no suffixes are found
+
+    new_file_name = f"{base_name}_{new_suffix}{ext}"
+
+    return directory / new_file_name
+
+
+def new_folder(folder_name: Path, sep='_', mkdir: bool = True) -> Path:
+    """Create a new folder based on the given folder_name with unique (sequential) number added in the end."""
+
+    folder_name = folder_name.resolve()
+    if folder_name.is_file():
+        raise NameError(f'{folder_name} is not a directory.')
+
+    parent = folder_name.parent
+    base_str, num = split_base_and_num(folder_name.name, sep=sep, no_num_return=0)
+    pattern = re.compile(rf"^{base_str}{sep}\d+$")
+
+    all_directory_nums = []
+    for folder in parent.glob(f'{base_str}*'):
+        if folder.is_dir() and pattern.match(folder.name):
+            _, dir_num = split_base_and_num(folder.name, sep=sep, no_num_return=0)
+            all_directory_nums.append(dir_num)
+
+    try:
+        max_num = max(all_directory_nums)
+    except ValueError:  # if all_directory_nums is empty (i.e. no file with such name exists)
+        max_num = num
+
+    new_folder_name = parent.joinpath(base_str + sep + str(max_num))
+    if new_folder_name.exists(): # and not recursive_dir_empty(new_folder_name, ignore_top_level_files=ignore_files):
+        new_folder_name = parent.joinpath(base_str + sep + str(max_num + 1))
+
+    if mkdir:
+        new_folder_name.mkdir(parents=True, exist_ok=True)
+    return new_folder_name
+
+
+def new_folder_with_number(folder_name: Path, number: int, sep='_') -> Path:
+
+    folder_name = folder_name.resolve()
+    if folder_name.is_file():
+        raise NameError(f'{folder_name} is not a directory.')
+
+    parent = folder_name.parent
+    base_str, num = split_base_and_num(folder_name.name, sep=sep, no_num_return=0)
+
+    new_folder_name = parent.joinpath(base_str + sep + str(number))
+    new_folder_name.mkdir(parents=True, exist_ok=True)
+    return new_folder_name

+ 201 - 0
curvature_assembly/fit.py

@@ -0,0 +1,201 @@
+from __future__ import annotations
+import optax
+from curvature_assembly import data_protocols, pytree_transf, oriented_particle
+import jax
+from typing import Callable, Any
+import jax.numpy as jnp
+
+
+InteractionParams = data_protocols.InteractionParams
+SimulationAux = data_protocols.SimulationAux
+BpttResults = data_protocols.BpttResults
+BpttSimulation = data_protocols.BpttSimulation
+Array = jnp.ndarray
+
+
+def unitwise_clip(g_norm: Array,
+                  max_norm: Array,
+                  grad: Array,
+                  div_eps: float = 1e-6) -> Array:
+  """Applies gradient clipping unit-wise."""
+  # This little max(., div_eps) is distinct from the normal eps and just
+  # prevents division by zero. It technically should be impossible to engage.
+  clipped_grad = grad * (max_norm / jnp.maximum(g_norm, div_eps))
+  return jnp.where(g_norm < max_norm, grad, clipped_grad)
+
+
+def adaptive_grad_clip(grad, params, clipping: float, eps: float = 1e-3):
+    num_ed = pytree_transf.num_extra_dimensions(grad, params)
+    g_norm = pytree_transf.broadcast_to(pytree_transf.leaf_norm(grad, keepdims=True, num_ld=num_ed), grad)
+    p_norm = pytree_transf.broadcast_to(pytree_transf.leaf_norm(params, keepdims=True), grad)
+    # Maximum allowable leaf_norm
+    max_norm = jax.tree_util.tree_map(
+        lambda x: clipping * jnp.maximum(x, eps), p_norm)
+    # If grad leaf_norm > clipping * param_norm, rescale
+    return jax.tree_util.tree_map(unitwise_clip, g_norm, max_norm, grad)
+
+
+def get_grad_time_weights(grad: InteractionParams, time_weight_fn: Callable[[Array], Array], time_axis: int = 1):
+    """
+    Apply time-based weights to the gradients of interaction parameters.
+
+    Args:
+        grad: Gradients of interaction parameters, represented as a JAX PyTree.
+        time_weight_fn: A function that computes the time-based weights on a rescaled time interval [0, 1].
+        time_axis: The axis along which the time steps are represented in the `grad` PyTree. Default is 1.
+
+    Returns:
+        Gradients of interaction parameters with time-based weights applied.
+    """
+
+    num_timesteps = pytree_transf.data_length(grad, axis=time_axis)
+    weights = time_weight_fn(jnp.linspace(0, 1, num_timesteps, endpoint=True))
+    mean = jax.lax.cond(jnp.mean(weights) == 0, lambda x: 1., lambda x: jnp.mean(x), weights)
+    normalized_weights = weights / mean
+
+    def apply_weights(x):
+        expand_dims = tuple(i for i in range(len(x.shape)) if i != time_axis)
+        expanded_weights = jnp.expand_dims(normalized_weights, axis=expand_dims)
+        return x * expanded_weights
+
+    return jax.tree_util.tree_map(apply_weights, grad)
+
+
+
+def canonicalize_grad_results(grad: InteractionParams, params: InteractionParams) -> InteractionParams:
+    """
+    Make gradient leaf shapes compatible with interaction params, i.e. we take the average over all extra axes
+    compared to the original params shape.
+    """
+    results_num_ed = pytree_transf.num_extra_dimensions(grad, params)
+    return jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=tuple(range(results_num_ed))), grad)
+
+
+def bounds_params(params: InteractionParams, **opt_param_dict: dict[tuple]) -> (InteractionParams, InteractionParams):
+    params_dict = vars(params)
+    lower_bounds = {}
+    upper_bounds = {}
+    for key in params_dict.keys():
+        if key in opt_param_dict:
+            try:
+                lb, ub = opt_param_dict[key]
+                if not jnp.all(lb < ub):
+                    raise ValueError(f"Lower bounds should all be smaller than upper bounds, "
+                                     f"problem with parameter {key}")
+            except TypeError:
+                lb = None
+                ub = None
+            lower_bounds[key] = lb
+            upper_bounds[key] = ub
+        else:
+            lower_bounds[key] = params_dict[key]
+            upper_bounds[key] = params_dict[key]
+    return type(params)(**lower_bounds), type(params)(**upper_bounds)
+
+
+def map_into_bounds(params: InteractionParams,
+                    lower_bounds: InteractionParams | None,
+                    upper_bounds: InteractionParams | None):
+    """Map interaction parameters back into the interval between lower and upper bounds."""
+
+    leaves, treedef = jax.tree_util.tree_flatten(params)
+    # if lower and/or upper bounds are not provided, we must construct pytrees with the same structure as
+    # parameters structure and filled with None. Note that this will fail if parameters pytree contains any
+    # None values (will raise a ValueError).
+    if lower_bounds is None:
+        lower_bounds = jax.tree_util.tree_unflatten(treedef, [None] * len(leaves))
+    if upper_bounds is None:
+        upper_bounds = jax.tree_util.tree_unflatten(treedef, [None] * len(leaves))
+
+    def map_to_bounds(x, xmin, xmax):
+        try:
+            if jnp.any(xmin > xmax):
+                raise ValueError(f'Min bound cannot be larger than max bound, got {xmin} and {xmax}, respectively.')
+        except TypeError:
+            pass
+        if xmin is not None:
+            x = jnp.maximum(x, xmin)
+        if xmax is not None:
+            x = jnp.minimum(x, xmax)
+        return x
+    return jax.tree_util.tree_map(map_to_bounds, params, lower_bounds, upper_bounds)
+
+
+def normalize_param(params: InteractionParams, param_name: str, ord=None) -> InteractionParams:
+    params_dict = vars(params)
+    new_dict = params_dict.copy()  # shallow copy is enough as values (interaction_params elements) are jax arrays
+    new_dict[param_name] = params_dict[param_name] / jnp.linalg.norm(params_dict[param_name], keepdims=True, ord=ord)
+    return type(params)(**new_dict)
+
+
+TIME_WEIGHT_FN = {'constant': lambda x: jnp.ones_like(x),
+                  'linear': lambda x: x,
+                  'quadratic': lambda x: x ** 2,
+                  'exponential': lambda x: jnp.exp(x),
+                  'step_25': lambda x: jnp.heaviside(x - 0.249, 1),
+                  'step_50': lambda x: jnp.heaviside(x - 0.50, 1),
+                  'step_75': lambda x: jnp.heaviside(x - 0.749, 1),
+                  'step_100': lambda x: jnp.heaviside(x - 1., 1),
+                  'neg_linear': lambda x: 1 - x}
+
+
+def fit_bptt(simulation_fn: BpttSimulation,
+             optimizer_update: optax.TransformUpdateFn,
+             clipping: float,
+             grad_time_weights: str = None,
+             param_rescalings: list[Callable[[InteractionParams], InteractionParams]] = None,
+             lower_bounds: InteractionParams = None,
+             upper_bounds: InteractionParams = None,
+             time_axis: int = 1) -> Callable:
+
+    """
+    Construct the step function for meta optimization of parameters in a BPTT simulation.
+
+        Args:
+            simulation_fn: A function that performs the simulation and computes the gradients of interaction parameters.
+            optimizer_update: A function that updates the parameters using the computed gradients.
+            clipping: The maximum value to clip the gradients during training.
+            grad_time_weights: String that then maps into a function that computes time-based weights for the gradients.
+                Default is a function that assigns equal weights (ones) to all time steps.
+            param_rescalings: A list of functions that apply rescalings
+                or transformations to the interaction parameters during training. Default is an empty list.
+            lower_bounds: The lower bounds for the interaction parameters. Default is None.
+            upper_bounds: The upper bounds for the interaction parameters. Default is None.
+            time_axis: The axis along which the time steps are represented in the gradient PyTree. Default is 1.
+
+        Returns:
+            Callable: A step function that performs one training step.
+    """
+
+    if grad_time_weights is None:
+        grad_time_weights = 'constant'
+
+    try:
+        grad_time_weight_fn = TIME_WEIGHT_FN[grad_time_weights]
+    except KeyError:
+        raise ValueError(f'Invalid time weight parameter, {grad_time_weights} is not among the implemented weights.')
+
+    if param_rescalings is None:
+        param_rescalings = []
+
+    param_rescalings.insert(0, oriented_particle.canonicalize_eigvals)
+
+    def step(params: InteractionParams,
+             opt_state: optax.OptState,
+             md_state: Any,
+             aux: SimulationAux) -> (InteractionParams, optax.OptState,
+                                           BpttResults, SimulationAux, InteractionParams):
+
+        aux = aux.reset_empty()
+        bptt_results, aux = simulation_fn(params, md_state, aux)
+        grad_clipped = adaptive_grad_clip(bptt_results.grad, params, clipping)
+        grad_weighted = get_grad_time_weights(grad_clipped, grad_time_weight_fn, time_axis=time_axis)
+        grad_mean = canonicalize_grad_results(grad_weighted, params)
+        updates, opt_state = optimizer_update(grad_mean, opt_state)
+        params = optax.apply_updates(params, updates)
+        for fn in param_rescalings:
+            params = fn(params)
+        params = map_into_bounds(params, lower_bounds, upper_bounds)
+        return params, opt_state, bptt_results, aux, grad_clipped
+
+    return step

+ 135 - 0
curvature_assembly/initial_conditions.py

@@ -0,0 +1,135 @@
+import jax.numpy as jnp
+import jax.random
+import jax
+from jax_md import quantity, rigid_body, space
+from typing import Callable
+from curvature_assembly import monte_carlo, oriented_particle, smap, energy
+
+Array = jnp.ndarray
+
+
+def grid_init(num: int,
+              box_size: float,
+              initial_orient=None
+              ) -> rigid_body.RigidBody:
+    """
+    Initialize a 3D grid of particles within a box of given size.
+
+    Args:
+        num: Number of particles in the grid.
+        box_size: The length of the box in which the grid is placed.
+        initial_orient: Initial orientation of the particles. Default is None,
+            which corresponds to an initial orientation quaternion (1., 0., 0., 0.).
+
+    Returns:
+        A RigidBody object containing the initial positions and orientations of the particles.
+    """
+
+    Nmax = jnp.ceil(jnp.cbrt(num))
+    gridpoints_1d = jnp.arange(Nmax) * box_size / Nmax
+    x = jnp.meshgrid(*(3 * (gridpoints_1d,)))
+    y = jnp.vstack(list(map(jnp.ravel, x))).T
+    position = y[:num]
+
+    if initial_orient is None:
+        initial_orient = jnp.array([1., 0., 0., 0.])
+    orientation = rigid_body.Quaternion(jnp.tile(initial_orient, (num, 1)))
+
+    return rigid_body.RigidBody(position, orientation)
+
+
+def randomize_init_mc(num: int,
+                      density: float,
+                      contact_fn: Callable,
+                      mc_steps: int,
+                      kT: float,
+                      moving_distance: rigid_body.RigidBody = None,
+                      **cf_kwargs
+                      ) -> Callable[[jax.random.KeyArray], monte_carlo.MCMCState]:
+    """
+    Create an MC simulation function that generates random positions and orientations of particles in a simulation box
+    with periodic boundary conditions starting from a grid of particles.
+
+    Args:
+        num: the number of particles in the system
+        density: the density of the system
+        contact_fn: a function that calculates the contact distance between particles
+        mc_steps: the number of Monte Carlo steps to take
+        kT: the temperature parameter for Metropolis criterion
+        moving_distance: a RigidBody object that holds the maximum distance by which a particle can move and
+            reorientate. If not provided, a default scale is set based on the density of the simulation.
+        **cf_kwargs: any additional keyword arguments that should be passed to the contact function
+
+    Returns:
+        A callable function that takes a jax.random.KeyArray and returns a monte_carlo.MCMCState object.
+    """
+
+    box_size = quantity.box_size_at_number_density(num, density, spatial_dimension=3)
+    displacement, shift = space.periodic(box_size)
+
+    if moving_distance is None:
+        # default scale for particle movement is approx 1 / 4 interparticle distance (taking into account particle size)
+        # and default reorientation scale is pi/4
+        moving_distance = rigid_body.RigidBody(0.25 * (jnp.cbrt(1 / density) - jnp.cbrt(2)), jnp.pi / 4)
+
+    energy_fn = oriented_particle.isotropic_to_cf_energy(energy.weeks_chandler_andersen, contact_fn, **cf_kwargs)
+    energy_pair = smap.oriented_pair(energy_fn, displacement)
+    energy_kwargs = {'sigma': 1, 'epsilon': 10}
+
+    init_fn, apply_fn = monte_carlo.mc_mc(shift, energy_pair, kT, moving_distance)
+
+    grid_state = grid_init(num, box_size)
+
+    @jax.jit
+    def scan_fn(state, i):
+        state = apply_fn(state, **energy_kwargs)
+        return state, state.accept
+
+    def mc_simulation(key):
+        init_state = init_fn(key, grid_state)
+        state, accept_array = jax.lax.scan(scan_fn, init=init_state, xs=jnp.arange(mc_steps))
+        # print(jnp.mean(jnp.array(accept_array, dtype=jnp.float32)))
+        return state
+
+    return mc_simulation
+
+
+def rdf(displacement_or_metric: space.DisplacementOrMetricFn,
+        positions: Array,
+        density: float,
+        r_min: float,
+        r_max: float,
+        num_bins: int) -> tuple[Array, Array]:
+    """
+    Calculate the radial distribution function (RDF) of a set of particles in a simulation box.
+
+    Args:
+        displacement_or_metric: Displacement or metric function
+        positions: An array of shape (num_particles, 3) containing the positions of the particles.
+        density: number density of particles in the system
+        r_min: The minimum radial distance to consider in the RDF calculation.
+        r_max: The maximum radial distance to consider in the RDF calculation.
+        num_bins: The number of bins to use in the RDF calculation.
+
+    Returns:
+        An array of shape (num_bins,) containing the midpoints of the radial distance bins and an array
+        of shape (num_bins,) containing the values of the RDF for each bin.
+    """
+
+    # Define the bin edges for the RDF
+    bin_edges = jnp.linspace(r_min, r_max, num_bins + 1)
+
+    # Create a histogram of the pairwise distances between particles
+
+    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
+    pairwise_distances = space.map_product(metric)(positions, positions)
+    i, j = jnp.triu_indices(pairwise_distances.shape[0], 1)
+    histogram, _ = jnp.histogram(pairwise_distances[i, j].flatten(), bins=bin_edges)
+
+    # Calculate the RDF
+    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
+    bin_volumes = 4 / 3 * jnp.pi * (bin_edges[1:] ** 3 - bin_edges[:-1] ** 3)
+
+    rdf = histogram / (density * bin_volumes * positions.shape[0] / 2)
+
+    return bin_centers, rdf

+ 638 - 0
curvature_assembly/io_functions.py

@@ -0,0 +1,638 @@
+import numpy as np
+from pathlib import Path
+from curvature_assembly import oriented_particle, pytree_transf, file_management
+from curvature_assembly.data_protocols import *
+import json
+from enum import Enum
+from typing import TypeVar
+from functools import partial
+import jax
+from dataclasses import dataclass
+import copy
+from abc import ABC, abstractmethod
+
+
+T = TypeVar('T')
+
+
+SIMULATION_PARAMS_FILENAME = 'simulation_params.json'
+INTERACTION_PARAMS_FILENAME = 'interaction_params.json'
+PARAMS_GRAD_FILENAME = 'param_grad.json'
+PARAMS_GRAD_CLIPPED_FILENAME = 'param_grad_clipped.json'
+NEIGHBOR_LIST_PARAMS_FILENAME = 'neighbor_list_params.json'
+COST_FILENAME = 'cost.dat'
+SIMULATION_LOG_FILENAME = 'simulation_log.npz'
+COORD_HISTORY_FILENAME = 'coord_history.npy'
+ORIENT_HISTORY_FILENAME = 'orient_history.npy'
+BOX_FILENAME = 'box_size.dat'
+
+
+def coord_filename(idx: int):
+    if idx is None:
+        return 'coord.dat'
+    return f'coord_{idx}.dat'
+
+
+def orient_filename(idx: int):
+    if idx is None:
+        return 'orient.dat'
+    return f'orient_{idx}.dat'
+
+
+def weight_matrix_filename(idx: int):
+    if idx is None:
+        return 'weight_matrix.dat'
+    return f'weight_matrix_{idx}.dat'
+
+
+def export_simulation_params(params: SimulationParams, path: Path) -> None:
+    """Export simulation parameters as a dictionary."""
+    with open(path.joinpath(SIMULATION_PARAMS_FILENAME), 'w') as f:
+        json.dump(vars(params), f)
+
+
+def load_simulation_params(path: Path) -> dict:
+    """Load simulation parameters as a dictionary."""
+    with open(path.joinpath(SIMULATION_PARAMS_FILENAME), 'r') as f:
+        params_dict = json.load(f)
+    return params_dict
+
+
+def convert_arrays_to_lists(params: InteractionParams) -> dict:
+    """
+    Converts jax arrays in InteractionParams instance to lists and returns params as dict.
+    Used for saving params in .json files.
+    """
+    no_array_dict = {}
+    for key, val in vars(params).items():
+        no_array_dict[key] = np.asarray(val).tolist() if isinstance(val, jnp.ndarray) else val
+    return no_array_dict
+
+
+def convert_lists_to_arrays(params_dict: dict, force_float: bool = True) -> dict:
+    """Converts list in a dictionary to jax arrays."""
+    array_dict = {}
+    for key, val in params_dict.items():
+        if force_float and isinstance(val, int):
+            val = float(val)
+        array_dict[key] = jnp.array(val) if isinstance(val, list) else val
+    return array_dict
+
+
+def export_interaction_params(params: InteractionParams, path: Path, filename: str = None) -> None:
+    """Export interaction parameters as a dictionary. Jax arrays are converted to lists."""
+    if filename is None:
+        filename = INTERACTION_PARAMS_FILENAME
+    filename = path.joinpath(filename)
+    file_management.overwrite_protection(filename)
+    with open(filename, 'w') as f:
+        json.dump(convert_arrays_to_lists(params), f)
+
+
+def convert_enum_to_int(params: NeighborListParams) -> dict:
+    no_array_dict = {}
+    for key, val in vars(params).items():
+        no_array_dict[key] = val.value if isinstance(val, Enum) else val
+    return no_array_dict
+
+
+def export_neighbor_list_params(params: NeighborListParams, path: Path) -> None:
+    """Export neighbor list parameters as a dictionary."""
+    filename = path.joinpath(NEIGHBOR_LIST_PARAMS_FILENAME)
+    file_management.overwrite_protection(filename)
+    with open(filename, 'w') as f:
+        json.dump(convert_enum_to_int(params), f)
+
+
+def export_cost(cost: Array, path: Path) -> None:
+    """Export cost function array."""
+    filename = path.joinpath(COST_FILENAME)
+    file_management.overwrite_protection(filename)
+    np.savetxt(filename, cost)
+
+
+# Variants of export_interaction_params with the file name pre-set to the
+# gradient / clipped-gradient file names.
+export_param_grad = partial(export_interaction_params, filename=PARAMS_GRAD_FILENAME)
+export_param_grad_clipped = partial(export_interaction_params, filename=PARAMS_GRAD_CLIPPED_FILENAME)
+
+
+def load_interaction_params(path: Path, filename: str = None, convert_arrays=True) -> dict:
+    """Load interaction parameters as a dictionary."""
+    if filename is None:
+        filename = INTERACTION_PARAMS_FILENAME
+    with open(path.joinpath(filename), 'r') as f:
+        params_dict = json.load(f)
+    if convert_arrays:
+        return convert_lists_to_arrays(params_dict)
+    return params_dict
+
+# Variants of load_interaction_params with the file name pre-set to the
+# gradient / clipped-gradient file names.
+load_param_grad = partial(load_interaction_params, filename=PARAMS_GRAD_FILENAME)
+load_param_grad_clipped = partial(load_interaction_params, filename=PARAMS_GRAD_CLIPPED_FILENAME)
+
+
+def load_cost(path: Path) -> np.ndarray:
+    """Load cost function array."""
+    filename = path.joinpath(COST_FILENAME)
+    return np.loadtxt(filename)
+
+
+def save_single_config(body: rigid_body.RigidBody, folder: Path, save_idx: int = None) -> None:
+    """General function for saving single config data."""
+    np.savetxt(folder.joinpath(coord_filename(save_idx)).resolve(), body.center)
+    np.savetxt(folder.joinpath(orient_filename(save_idx)).resolve(), body.orientation.vec)
+
+
+def load_single_config(folder: Path, save_idx: int = None) -> rigid_body.RigidBody:
+    """General function for loading single config data."""
+    coord = jnp.asarray(np.loadtxt(folder.joinpath(coord_filename(save_idx)).resolve()))
+    orient = jnp.asarray(np.loadtxt(folder.joinpath(orient_filename(save_idx)).resolve()))
+    return rigid_body.RigidBody(coord, rigid_body.Quaternion(orient))
+
+
+def init_config_folder_name(num: int, density: float) -> str:
+    return f'n{num}rho{int(1000 * density)}'
+
+
+def save_initial_config(body: rigid_body.RigidBody, density: float, idx: int, init_folder: Path) -> None:
+    """Save the initial RigidBody configuration with a given density and index."""
+    save_folder = init_folder.joinpath(init_config_folder_name(body.center.shape[0], density))
+    save_folder.mkdir(exist_ok=True, parents=True)
+    save_single_config(body, save_folder, idx)
+
+
+def load_initial_config(n: int, density: float, idx: int, init_folder: Path) -> rigid_body.RigidBody:
+    """Load the initial RigidBody configuration with a given density and index."""
+    save_folder = init_folder.joinpath(init_config_folder_name(n, density))
+    return load_single_config(save_folder, idx)
+
+
+def load_multiple_initial_configs(n: int, density: float, indices: list[int], init_folder: Path) \
+        -> list[rigid_body.RigidBody]:
+    """Load multiple initial RigidBody configurations with a given density and a list of indices."""
+    save_folder = init_folder.joinpath(init_config_folder_name(n, density))
+    return [load_single_config(save_folder, idx) for idx in indices]
+
+
+def load_multiple_initial_configs_single_object(n: int, density: float, indices: list[int], init_folder: Path,
+                                                coord_rescale_factor: float = None) -> rigid_body.RigidBody:
+    """Load multiple initial RigidBody configurations with a given density and a list of indices as a single object."""
+    save_folder = init_folder.joinpath(init_config_folder_name(n, density))
+    coord = []
+    orient = []
+    for i in indices:
+        coord_i = jnp.asarray(np.loadtxt(save_folder.joinpath(coord_filename(i)).resolve()))
+        if coord_rescale_factor is not None:
+            coord_i *= coord_rescale_factor
+        coord.append(coord_i)
+        orient.append(jnp.asarray(np.loadtxt(save_folder.joinpath(orient_filename(i)).resolve())))
+    return rigid_body.RigidBody(jnp.stack(coord, axis=0), rigid_body.Quaternion(jnp.stack(orient, axis=0)))
+
+
+def simulation_log_data_fields(simulation_log: SimulationLog) -> dict:
+    """Return a dictionary of data fields in a simulation log object, ie ignoring other internal attributes."""
+    data_dict = {}
+    for key, val in vars(simulation_log).items():
+        try:
+            if val.shape[0] == pytree_transf.data_length(simulation_log, ignore_non_array_leaves=True):
+                data_dict[key] = val[jnp.nonzero(val)]  # exclude zero entries that may not have been populated
+        except (AttributeError, IndexError):
+            pass
+    return data_dict
+
+
+def export_simulation_log(simulation_log: SimulationLog,
+                          folder: Path) -> None:
+    """Export simulation log data in a single file."""
+    file_management.overwrite_protection(folder.joinpath(SIMULATION_LOG_FILENAME))
+    data_dict = simulation_log_data_fields(simulation_log)
+    np.savez(folder.joinpath(SIMULATION_LOG_FILENAME), **data_dict)
+
+
+def load_simulation_log(folder: Path) -> dict:
+    """Load simulation log data from file."""
+    npz_file = np.load(folder.joinpath(SIMULATION_LOG_FILENAME))
+    return dict(npz_file)
+
+
+def export_state_history(state_history: SimulationStateHistory, folder: Path) -> None:
+    """Save simulation state history data."""
+    # we exclude array indices that were not populated
+    relevant_indices = jnp.nonzero(jnp.linalg.norm(state_history.coord, axis=(-2, -1)))
+    np.save(folder.joinpath(COORD_HISTORY_FILENAME), state_history.coord[relevant_indices])
+    np.save(folder.joinpath(ORIENT_HISTORY_FILENAME), state_history.orient[relevant_indices])
+
+
+def load_state_history(folder: Path) -> tuple[np.ndarray, np.ndarray]:
+    """Save simulation state history data."""
+    coord = np.load(folder.joinpath(COORD_HISTORY_FILENAME))
+    orient = np.load(folder.joinpath(ORIENT_HISTORY_FILENAME))
+    return coord, orient
+
+
+def export_simulation_state(body: rigid_body.RigidBody,
+                            simulation_params: SimulationParams,
+                            interaction_params: InteractionParams,
+                            folder: Path,
+                            idx: int) -> None:
+    """Export config data along with simulation and interaction parameters used in simulation."""
+    export_simulation_params(simulation_params, folder)
+    export_interaction_params(interaction_params, folder, idx=idx)
+    np.savetxt(folder.joinpath(f'coord_frame{idx}.dat').resolve(), body.center)
+    np.savetxt(folder.joinpath(f'orient_frame{idx}.dat').resolve(), body.orientation.vec)
+    np.savetxt(folder.joinpath(f'weight_matrix_frame{idx}.dat').resolve(),
+               oriented_particle.get_weight_matrices(body.orientation, interaction_params.eigvals).reshape(-1, 9))
+
+
+def direct_visualization_export(body: rigid_body.RigidBody, eigvals: Array, export_folder: Path):
+    coord = body.center
+    weight_matrix = oriented_particle.get_weight_matrices(body.orientation, eigvals).reshape(-1, 9)
+    np.savetxt(export_folder.joinpath(f'coord.dat').resolve(), coord)
+    np.savetxt(export_folder.joinpath(f'weight_matrix.dat').resolve(), weight_matrix)
+
+
+def get_config_for_visualization(results_folder: Path,
+                                  export_folder: Path,
+                                  idx: int = -1,
+                                  zero_cm: bool = True) -> None:
+    interaction_params = load_interaction_params(results_folder)
+    coord, orient = load_state_history(results_folder)
+    if zero_cm:
+        coord = cannonicalize_cm(coord)
+    weight_matrix = oriented_particle.get_weight_matrices(
+        rigid_body.Quaternion(orient[idx]), interaction_params['eigvals']).reshape(-1, 9)
+    np.savetxt(export_folder.joinpath(f'coord.dat').resolve(), coord)
+    np.savetxt(export_folder.joinpath(f'orient.dat').resolve(), orient)
+    np.savetxt(export_folder.joinpath(f'weight_matrix.dat').resolve(), weight_matrix)
+
+
+def cannonicalize_cm(coord: Array) -> Array:
+    cm = jnp.mean(coord, axis=-2)
+    return coord - cm[..., None, :]
+
+
+def prepare_animation_data(results_folder: Path, interaction_params: dict,
+                           get_every: int = 1, zero_cm=True) -> dict[str, np.ndarray]:
+
+    coord_history, orient_history = load_state_history(results_folder)
+    if zero_cm:
+        coord_history = cannonicalize_cm(coord_history)
+    coord_history = coord_history[::get_every]
+    orient_history = orient_history[::get_every]
+
+    def weight_matrix_frame(quaternion_vec):
+        return oriented_particle.get_weight_matrices(
+            rigid_body.Quaternion(quaternion_vec), interaction_params['eigvals']).reshape(-1, 9)
+
+    weight_matrix_hist = jax.vmap(weight_matrix_frame)(orient_history)
+    eigensystem = oriented_particle.eigensystem(rigid_body.Quaternion(orient_history))
+
+    return {'coord': coord_history, 'weight_matrix': weight_matrix_hist, 'eigensystem': eigensystem}
+
+
+def export_animation_data(anim_data: dict[str, np.ndarray], box: float, export_folder: Path) -> None:
+    """Saves animation data matrices."""
+    np.save(export_folder.joinpath('anim_coord'), anim_data['coord'])
+    np.save(export_folder.joinpath('anim_weight_matrix'), anim_data['weight_matrix'])
+    np.save(export_folder.joinpath('anim_eigensystem'), anim_data['eigensystem'])
+    np.savetxt(export_folder.joinpath(BOX_FILENAME), np.asarray(box).reshape(1,))
+
+
+def export_cost_and_grad(cost: float,
+                         grad: InteractionParams | None,
+                         folder: Path,
+                         idx: int) -> None:
+    """
+    Export gradient data into a new file and append cost function value to an existing one. If gradient data
+    was not calculated, pass None to the function.
+    """
+    with open(folder.joinpath('cost_function.dat'), 'a') as f:
+        f.writelines(f'{cost: .4f}\n')
+    if grad is not None:
+        export_param_grad(grad, folder, idx=idx)
+
+
+class OptimizationSaver:
+
+    def __init__(self, folder: Path, simulation_params: SimulationParams,
+                 overwrite_folder_with_no_results=False,
+                 folder_num: int = None):
+        if folder_num is not None:
+            self.base_folder = file_management.new_folder_with_number(folder, folder_num)
+        else:
+            self.base_folder = file_management.new_folder(folder)
+        export_simulation_params(simulation_params, self.base_folder)
+        self._export_results_happened = False
+        self._export_inter_params_happened = False
+        self._iter_folder = file_management.new_folder(self.base_folder.joinpath(f'iter_0'))
+
+    def _get_iter_folder(self, check_happened: bool) -> Path:
+        if check_happened:
+            self._iter_folder = file_management.new_folder(self.base_folder.joinpath(f'iter'))
+            self._export_results_happened = False
+            self._export_inter_params_happened = False
+        return self._iter_folder
+
+    def _get_config_folder(self, config_idx: int) -> Path:
+        folder = self._iter_folder.joinpath(f'config_{config_idx}')
+        folder.mkdir(exist_ok=True)
+        return folder
+
+    def export_interaction_params(self, interaction_params: InteractionParams) -> None:
+        folder = self._get_iter_folder(self._export_results_happened)
+        export_interaction_params(interaction_params, folder)
+        self._export_inter_params_happened = True
+
+    def export_param_updates(self, updates: InteractionParams) -> None:
+        folder = self._get_iter_folder(self._export_results_happened)
+        export_interaction_params(updates, folder, filename='interaction_param_updates.json')
+        self._export_inter_params_happened = True
+
+    def export_run_params(self, run_params: dict):
+        with open(self.base_folder.joinpath(f'run_params.json'), 'w') as f:
+            json.dump(run_params, f)
+
+    def export_additional_simulation_data(self, data: dict):
+        with open(self.base_folder.joinpath(f'aux_simulation_data.json'), 'w') as f:
+            json.dump(data, f)
+
+    def export_cost_function_info(self, cost_fn):
+        with open(self.base_folder.joinpath(f'cost_function_info.dat'), 'w') as f:
+            f.write(str(cost_fn))
+
+    def export_results(self, bptt_results: BpttResults, aux: SimulationAux) -> None:
+        folder = self._get_iter_folder(self._export_results_happened)
+        try:
+            export_cost(bptt_results.cost, folder)
+            export_param_grad(bptt_results.grad, folder)
+            export_simulation_log(aux.log, folder)
+            export_state_history(aux.state_history, folder)
+            self._export_results_happened = True
+        except ValueError:
+            raise ValueError('For exporting multiple results, use method "export_multiple_results".')
+
+    def export_multiple_results(self,
+                                bptt_results: BpttResults,
+                                aux: SimulationAux) -> None:
+        self._get_iter_folder(self._export_results_happened)
+        bptt_results_list = pytree_transf.split_to_list(bptt_results)
+        aux_list = pytree_transf.split_to_list(aux)
+        for config_idx, (result, a) in enumerate(zip(bptt_results_list, aux_list)):
+            folder = self._get_config_folder(config_idx)
+            export_cost(result.cost, folder)
+            export_param_grad(result.grad, folder)
+            export_simulation_log(a.log, folder)
+            export_state_history(a.state_history, folder)
+        self._export_results_happened = True
+
+    def export_clipped_gradients(self, grad_clipped: InteractionParams):
+        grad_clipped_list = pytree_transf.split_to_list(grad_clipped)
+        for config_idx, grad in enumerate(grad_clipped_list):
+            folder = self._get_config_folder(config_idx)
+            export_param_grad_clipped(grad, folder)
+
+
+class NoResultsError(Exception):
+    pass
+
+
+class OptimizationLoader:
+    """Convenience class to load results of an optimization simulation."""
+
+    def __init__(self, folder: Path):
+        self.base_folder = folder.resolve()
+        if not self.base_folder.exists():
+            raise NoResultsError(f"Results folder {self.base_folder} does not exist.")
+
+    def get_results_folder(self, iter_idx: int, config_idx: int = None):
+        if iter_idx < 0:
+            iter_idx = self.all_iter_indices()[iter_idx]
+        if config_idx is None:
+            return self.base_folder.joinpath(f'iter_{iter_idx}')
+        return self.base_folder.joinpath(f'iter_{iter_idx}').joinpath(f'config_{config_idx}')
+
+    def last_iter_idx(self):
+        try:
+            return self.all_iter_indices()[-1]
+        except IndexError:
+            return 0
+
+    def all_config_indices(self, iter_idx: int = None) -> list:
+        if iter_idx is None:
+            iter_idx = self.last_iter_idx()
+        iteration_folders = [folder for folder in self.get_results_folder(iter_idx).glob(f'config_*')]
+        all_directory_nums = []
+        for folder in iteration_folders:
+            _, dir_num = file_management.split_base_and_num(folder.name, sep='_', no_num_return=0)
+            all_directory_nums.append(dir_num)
+        return sorted(all_directory_nums)
+
+    def num_replicas(self):
+        return len(self.all_config_indices())
+
+    def all_iter_indices(self) -> list:
+        iteration_folders = [folder for folder in self.base_folder.glob(f'iter_*')]
+        all_directory_nums = []
+        for folder in iteration_folders:
+            _, dir_num = file_management.split_base_and_num(folder.name, sep='_', no_num_return=0)
+            if not file_management.recursive_dir_empty(folder, ignore_top_level_files=True):
+                all_directory_nums.append(dir_num)
+        return sorted(all_directory_nums)
+
+    def load_simulation_params(self, iter_idx: int = None, config_idx: int = None) -> dict:
+        return load_simulation_params(self.base_folder)
+
+    def box_size_at_number_density(self):
+        simulation_params = self.load_simulation_params()
+        return oriented_particle.box_size_at_number_density(simulation_params["num"],
+                                                   simulation_params["density"],
+                                                   spatial_dimension=3)
+
+    def box_size_at_ellipsoid_density(self, iter_idx: int = None):
+        simulation_params = self.load_simulation_params()
+        if iter_idx is None:
+            iter_idx = self.last_iter_idx()
+        interaction_params = self.load_interaction_params(iter_idx)
+        return oriented_particle.box_size_at_ellipsoid_density(simulation_params["num"],
+                                                               simulation_params["density"],
+                                                               interaction_params["eigvals"])
+
+    def box_size(self, iter_idx: int = None):
+        if iter_idx is None:
+            iter_idx = self.last_iter_idx()
+        interaction_params = self.load_interaction_params(iter_idx)
+        particle_volume = oriented_particle.ellipsoid_volume(interaction_params["eigvals"])
+        if jnp.all(jnp.isclose(particle_volume, 1., atol=1e-4)):
+            return self.box_size_at_number_density()
+        return self.box_size_at_ellipsoid_density(iter_idx=iter_idx)
+
+    def load_additional_simulation_data(self) ->  dict:
+        with open(self.base_folder.joinpath(f'aux_simulation_data.json'), 'r') as f:
+            data = json.load(f)
+        return data
+
+    def load_run_params(self) -> dict:
+        with open(self.base_folder.joinpath(f'run_params.json'), 'r') as f:
+            run_params = json.load(f)
+        return run_params
+
+    def load_interaction_params(self, iter_idx: int, config_idx: int = None, convert_arrays=True) -> dict:
+        return load_interaction_params(self.get_results_folder(iter_idx), convert_arrays=convert_arrays)
+
+    def load_multiple_interaction_params(self, iter_indices: list = None) -> dict:
+        if iter_indices is None:
+            iter_indices = self.all_iter_indices()
+        return pytree_transf.stack([self.load_interaction_params(iter_idx) for iter_idx in iter_indices])
+
+    def load_gradient(self, iter_idx: int, config_idx: int = None) -> dict:
+        return load_param_grad(self.get_results_folder(iter_idx, config_idx))
+
+    def load_clipped_gradient(self, iter_idx: int, config_idx: int = None) -> dict:
+        return load_param_grad_clipped(self.get_results_folder(iter_idx, config_idx))
+
+    def load_multiple_gradients(self, iter_idx: int, config_indices: list = None) -> list[dict, ...]:
+        if config_indices is None:
+            config_indices = self.all_config_indices()
+        return [self.load_gradient(iter_idx, config_idx) for config_idx in config_indices]
+
+    def load_cost(self, iter_idx: int, config_idx: int = None) -> np.ndarray:
+        return load_cost(self.get_results_folder(iter_idx, config_idx))
+
+    def load_cost_all_config(self, iter_idx: int):
+        config_indices = self.all_config_indices()
+        config_costs = []
+        for config_idx in config_indices:
+            config_costs.append(load_cost(self.get_results_folder(iter_idx, config_idx)))
+        return np.stack(config_costs)
+
+    def load_cost_all(self) -> np.ndarray:
+        inter_indices = self.all_iter_indices()
+        all_costs = []
+        for iter_idx in inter_indices:
+            all_costs.append(self.load_cost_all_config(iter_idx))
+        return np.stack(all_costs)
+
+    def load_simulation_log(self, iter_idx: int, config_idx: int = None) -> dict:
+        return load_simulation_log(self.get_results_folder(iter_idx, config_idx))
+
+    def load_body(self, iter_idx: int, config_idx: int, time_idx: int) -> rigid_body.RigidBody:
+        trajectory = self.load_trajectory(iter_idx=iter_idx, config_idx=config_idx)
+        return trajectory[time_idx]
+
+    def load_multiple_bodies(self, iter_idx: int, time_idx: int = -1, config_indices: list[int] = None):
+        if config_indices is None:
+            config_indices = self.all_config_indices(iter_idx)
+        return pytree_transf.stack([self.load_body(iter_idx, config_idx, time_idx) for config_idx in config_indices])
+
+    def load_trajectory(self, iter_idx: int, config_idx: int) -> rigid_body.RigidBody:
+        coord, orient = load_state_history(self.get_results_folder(iter_idx, config_idx))
+        return rigid_body.RigidBody(jnp.asarray(coord), rigid_body.Quaternion(jnp.asarray(orient)))
+
+    def load_multiple_config_trajectories(self, iter_idx: int, config_indices: list[int]) -> rigid_body.RigidBody:
+        return pytree_transf.stack([self.load_trajectory(iter_idx, config_idx) for config_idx in config_indices])
+
+    def load_all_iter_trajectories(self, config_idx: int) -> rigid_body.RigidBody:
+        iter_indices = self.all_iter_indices()
+        return pytree_transf.stack([self.load_trajectory(iter_idx, config_idx) for iter_idx in iter_indices])
+
+    def get_config_for_visualization(self, iter_idx: int, config_idx: int, export_folder: Path,
+                                     frame_idx: int = -1, zero_cm: bool = True):
+        return get_config_for_visualization(self.get_results_folder(iter_idx, config_idx), export_folder,
+                                            idx=frame_idx, zero_cm=zero_cm)
+
+    def export_animation_data(self, iter_idx: int, config_idx: int, export_folder: Path,
+                              get_every: int = 1, zero_cm: bool = False) -> None:
+        interaction_params = self.load_interaction_params(iter_idx)
+        anim_data = prepare_animation_data(self.get_results_folder(iter_idx, config_idx), interaction_params,
+                                           get_every=get_every, zero_cm=zero_cm)
+        box_size = self.box_size(iter_idx=iter_idx)
+        export_animation_data(anim_data, box_size, export_folder)
+
+
+@dataclass(frozen=True)
+class PlotData(ABC):
+    """Abstract class for saving and loading plot data. Children should add attributes for this data."""
+
+    run_params: dict
+    results_path: Path
+
+    @property
+    @abstractmethod
+    def cluster_data_folder(self) -> str:
+        """Subclasses must define a folder name for saving/loading data."""
+        pass
+
+    @classmethod
+    @abstractmethod
+    def calculate_data(cls, results_folder: Path, **kwargs) -> "PlotData":
+        """A method that fills the class with data."""
+        pass
+
+    @classmethod
+    def get_save_path(cls, results_path) -> Path:
+        results_base = results_path.parent
+        results_filename = f'{results_path.name}.json'
+        save_path_base = results_base / cls.cluster_data_folder
+        save_path_base.mkdir(exist_ok=True)
+        return save_path_base / results_filename
+
+    def save(self) -> None:
+        self_as_dict = vars(copy.deepcopy(self))
+        for key in self_as_dict:
+            if isinstance(self_as_dict[key], jnp.ndarray):
+                self_as_dict[key] = np.asarray(self_as_dict[key]).tolist()
+        self_as_dict['results_path'] = str(self_as_dict['results_path'])
+        with open(self.get_save_path(self.results_path), 'w') as f:
+            json.dump(self_as_dict, f)
+
+    @classmethod
+    def load(cls, results_path: Path) -> "PlotData":
+        cluster_data_path = cls.get_save_path(results_path)
+        if not cluster_data_path.exists():
+            raise FileNotFoundError("Clustering results for the given path not yet exported.")
+        with open(cluster_data_path, 'r') as f:
+            cls_as_dict = json.load(f)
+        for key in cls_as_dict:
+            if isinstance(cls_as_dict[key], list):
+                cls_as_dict[key] = jnp.asarray(cls_as_dict[key])
+        cls_as_dict['results_path'] = Path(cls_as_dict['results_path'])
+        return cls(**cls_as_dict)
+
+    @classmethod
+    def get_data(cls,
+                 results_folder: Path,
+                 recalculate: bool = False,
+                 **calculate_data_kwargs) -> "PlotData":
+        if recalculate:
+            print(f'Recalculating results for {results_folder}...')
+            csd = cls.calculate_data(results_folder, **calculate_data_kwargs)
+            csd.save()
+            return csd
+        try:
+            csd = cls.load(results_folder)
+        except FileNotFoundError:
+            print(f'Results for folder {results_folder} not yet exported, calculating...')
+            csd = cls.calculate_data(results_folder, **calculate_data_kwargs)
+            csd.save()
+        return csd
+
+
+def figure_file_name(base_file_name: str, results_folder: Path, *, iter_idx: int = None, config_idx: int = None,
+                     figure_folder: Path = Path("/home/andraz/CurvatureAssemblyFigures")):
+    figure_folder.mkdir(exist_ok=True)
+    results_folder_base_name, results_idx = file_management.split_base_and_num(results_folder.name, sep='_', no_num_return='')
+    base = Path(base_file_name).stem
+    suffix = Path(base_file_name).suffix
+    fig_folder = figure_folder.joinpath(f'{results_folder_base_name}_{results_idx}')
+    fig_folder.mkdir(exist_ok=True)
+    if iter_idx is None and config_idx is None:
+        return fig_folder.joinpath(f'{base}{suffix}')
+
+    fig_folder = fig_folder.joinpath(f'{base}')
+    fig_folder.mkdir(exist_ok=True)
+    if config_idx is None:
+        return fig_folder.joinpath(f'{base}_iter{iter_idx}{suffix}')
+    if iter_idx is None:
+        return fig_folder.joinpath(f'{base}_config{config_idx}{suffix}')
+
+    fig_folder = fig_folder.joinpath(f'iter{iter_idx}')
+    fig_folder.mkdir(exist_ok=True)
+    return fig_folder.joinpath(f'{base}_iter{iter_idx}_config{config_idx}{suffix}')

+ 107 - 0
curvature_assembly/monte_carlo.py

@@ -0,0 +1,107 @@
+import jax.numpy as jnp
+import jax
+from jax_md import dataclasses, space, rigid_body
+from typing import Callable, TypeVar, Any
+import functools
+
+
+Array = jnp.ndarray
+T = TypeVar('T')
+InitFn = Callable[..., T]
+ApplyFn = Callable[[T], T]
+
+
+def random_unit_vector(key):
+    key, split = jax.random.split(key)
+    x1, x2 = jax.random.uniform(split, (2,), dtype=jnp.float64)
+    phi = 2 * jnp.pi * x1
+    cos_theta = 2 * x1 - 1
+    sin_theta = jnp.sqrt(1 - cos_theta ** 2)
+    sin_phi = jnp.sin(phi)
+    cos_phi = jnp.cos(phi)
+    return jnp.array([cos_phi * sin_theta, sin_phi * sin_theta, cos_theta])
+
+
+def random_quaternion(key, max_rotation):
+    key, axis_key, angle_key = jax.random.split(key, 3)
+    axis = random_unit_vector(axis_key)
+    angle = max_rotation * jax.random.uniform(angle_key, ())
+    sin_angle_2 = jnp.sin(angle / 2)
+    cos_angle_2 = jnp.cos(angle / 2)
+    q = jnp.array([cos_angle_2, sin_angle_2 * axis[0], sin_angle_2 * axis[1], sin_angle_2 * axis[2]])
+    return rigid_body.Quaternion(q)
+
+
+@functools.singledispatch
+def mc_move(position: Array, idx: int, key: jax.random.KeyArray, moving_distance: Array, shift: space.ShiftFn) -> Array:
+    move = moving_distance * random_unit_vector(key)
+    return position.at[idx].set(shift(position[idx], move))
+
+
+@mc_move.register(rigid_body.RigidBody)
+def _(position: rigid_body.RigidBody,
+      idx: int,
+      key: jax.random.KeyArray,
+      moving_distance: rigid_body.RigidBody,
+      shift: space.ShiftFn) -> rigid_body.RigidBody:
+
+    key, position_key, orientation_key = jax.random.split(key, 3)
+    position_move = moving_distance.center * jax.random.normal(key, (3,))
+    orientation_move = random_quaternion(orientation_key, moving_distance.orientation)
+
+    new_position = position.center.at[idx].set(shift(position.center[idx], position_move))
+    new_orientation_vec = position.orientation.vec.at[idx].set((orientation_move * position.orientation[idx]).vec)
+
+    return rigid_body.RigidBody(new_position, rigid_body.Quaternion(new_orientation_vec))
+
+
+@functools.singledispatch
+def num_particles(position: Array):
+    return position.shape[0]
+
+
+@num_particles.register(rigid_body.RigidBody)
+def _(position: rigid_body.RigidBody):
+    return position.center.shape[0]
+
+
+@dataclasses.dataclass
+class MCMCState:
+    """State carried between successive Monte Carlo steps."""
+    # current particle configuration (whatever `mc_move` and `num_particles` accept)
+    position: Any
+    # PRNG key to be split by the next step
+    key: Array
+    # whether the most recently proposed move was accepted
+    accept: bool
+
+
+def mc_mc(shift: space.ShiftFn,
+          energy_fn: Callable[..., Array],
+          kT: float,
+          moving_distance: Array
+          ) -> (InitFn, ApplyFn):
+
+    def init_fn(key, position) -> MCMCState:
+        return MCMCState(position, key, False)
+
+    def apply_fn(state: MCMCState, **kwargs) -> MCMCState:
+
+        position = state.position
+        N = num_particles(position)
+
+        # Move random particle for a random amount
+        key, particle_key, move_key, accept_key = jax.random.split(state.key, 4)
+        idx = jax.random.randint(particle_key, (2,), jnp.array(0), jnp.array(N))
+        new_position = mc_move(position, idx, move_key, moving_distance, shift)
+
+        # Compute the energy before the swap.
+        energy = energy_fn(position, **kwargs)
+
+        # Compute the energy after the swap.
+        new_energy = energy_fn(new_position, **kwargs)
+
+        # Accept or reject with a metropolis probability.
+        p = jax.random.uniform(accept_key, ())
+        accept_prob = jnp.minimum(1, jnp.exp(-(new_energy - energy) / kT))
+        position = jax.lax.cond(p < accept_prob, lambda x: x[0], lambda x: x[1], [new_position, position])
+
+        return MCMCState(position, key, p < accept_prob)
+
+    return init_fn, apply_fn

+ 78 - 0
curvature_assembly/multipole_interaction.py

@@ -0,0 +1,78 @@
+import jax.numpy as jnp
+import jax
+from curvature_assembly import oriented_particle
+
+Array = jnp.ndarray
+
+
+def quadrupolar_eigenvalues(q0: Array, theta: Array) -> Array:
+    return q0 * jnp.array([(jnp.cos(theta) + 3) / 4, (jnp.cos(theta) - 3) / 4, -jnp.cos(theta) / 2])
+
+
+def quadrupolar_interaction(dr: Array, eigsys1: Array, eigsys2: Array, eigvals: Array) -> Array:
+    """General quadrupolar interaction. However, it is really slow to evaluate in gradient-based simulations."""
+
+    distance2 = jnp.sum(dr ** 2)
+    distance4 = distance2 ** 2
+    distance = jnp.sqrt(distance2)
+
+    qf1 = oriented_particle.qf_from_rotation(eigsys1, oriented_particle.make_diagonal(eigvals))
+    qf2 = oriented_particle.qf_from_rotation(eigsys2, oriented_particle.make_diagonal(eigvals))
+
+    dr2 = jax.lax.dot_general(dr, dr, dimension_numbers=(((), ()), ((), ())))
+    dr4 = jax.lax.dot_general(dr2, dr2, dimension_numbers=(((), ()), ((), ())))
+
+    term1 = jnp.einsum('ijkl, ij, kl', dr4, qf1, qf2)
+    term2 = jnp.einsum('jk, ij, ik', dr2, qf1, qf2)
+    term3 = jnp.einsum('ij, ij', qf1, qf2)
+
+    return 1 / (3 * distance ** 5) * (35 * term1 / distance4 - 20 * term2 / distance2 + 2 * term3)
+
+
+def lin_quad_energy(dr: Array, eigsys1: Array, eigsys2: Array, eigvals: Array):
+    """Interaction between two linear quadrupoles with eigenvalues [1. -1, 0] in this exact order."""
+
+    q0 = eigvals[0]
+
+    mi = eigsys1[:, 0]
+    ni = eigsys1[:, 1]
+
+    mj = eigsys2[:, 0]
+    nj = eigsys2[:, 1]
+
+    dist = jnp.sqrt(jnp.sum(dr * dr))
+    rij_hat = dr / dist
+
+    mi_rij = jnp.sum(mi * rij_hat)
+    mj_rij = jnp.sum(mj * rij_hat)
+    ni_rij = jnp.sum(ni * rij_hat)
+    nj_rij = jnp.sum(nj * rij_hat)
+    mi_mj = jnp.sum(mi * mj)
+    ni_nj = jnp.sum(ni * nj)
+    mi_nj = jnp.sum(mi * nj)
+    ni_mj = jnp.sum(ni * mj)
+
+    Aij = mi_rij ** 2 * mj_rij ** 2 - mi_rij ** 2 * nj_rij ** 2 - ni_rij ** 2 * mj_rij ** 2 + ni_rij ** 2 * nj_rij ** 2
+    Bij = mi_mj * mi_rij * mj_rij - mi_nj * mi_rij * nj_rij - ni_mj * ni_rij * mj_rij + ni_nj * ni_rij * nj_rij
+    Cij = mi_mj ** 2 - mi_nj ** 2 - ni_mj ** 2 + ni_nj ** 2
+
+    return q0 ** 2 / (3 * dist ** 5) * (35 * Aij - 20 * Bij + 2 * Cij)
+
+
+def ferro_orientational_energy(dr: Array, eigsys1: Array, eigsys2: Array, softness: float = 1.5):
+    """
+    Ferromagnetic-like interaction between a pair of particles. Must be combined with some distance based term.
+    Softness parameter is a factor that scales the second term of the expansion and relates to the energy sensitivity
+    on deviations from the parallel configuration for side by side particles. Lower values mean more stiff potential.
+    Increasing it too much can lead to the preference for dipolar-like ordering (at softness = 3, effects notable at
+    softness >= 2).
+    """
+    pi = eigsys1[:, 2]
+    pj = eigsys2[:, 2]
+
+    dist = jnp.sqrt(jnp.sum(dr * dr))
+    rij_hat = dr / dist
+
+    # positive values for attraction as added distance based term should make the entire energy negative
+    return jnp.sum(pi * pj) - softness * jnp.sum(pi * rij_hat) * jnp.sum(pj * rij_hat)
+

+ 214 - 0
curvature_assembly/oriented_particle.py

@@ -0,0 +1,214 @@
+from typing import Protocol, Callable, TypeVar
+import jax
+import jax.numpy as jnp
+from curvature_assembly import data_protocols
+from jax_md import rigid_body, energy, quantity
+from functools import partial
+
+Array = jnp.ndarray
+T = TypeVar('T')
+
+
+@partial(jnp.vectorize, signature='(d,d),(d,d)->(d,d)')
+def qf_from_rotation(rotation: Array, eigen_qf: Array) -> Array:
+    """
+    Express a quadratic form given in the particle eigenframe in the world frame.
+
+    Computes R Q R^T, with R the rotation matrix describing the eigensystem orientation.
+    """
+    factors = (rotation, eigen_qf, jnp.transpose(rotation))
+    return jnp.linalg.multi_dot(factors)
+
+
+@partial(jnp.vectorize, signature='(d)->(d,d)')
+def make_diagonal(eigvals: Array) -> Array:
+    """
+    Create a diagonal matrix from a 1D array of eigenvalues.
+
+    Works for any length `d` of the last axis (consistent with the `(d)->(d,d)` signature);
+    leading axes are treated as batch dimensions by `jnp.vectorize`.
+    """
+    return jnp.diag(eigvals)
+
+
+def eigensystem(orientation: rigid_body.Quaternion) -> Array:
+    """Return the eigensystem matrix of the orientation, with eigenvectors stored as columns."""
+    space_to_body = rigid_body.space_to_body_rotation(orientation)
+    # exchanging the last two axes transposes each (possibly batched) rotation matrix
+    return jnp.swapaxes(space_to_body, -1, -2)
+
+
+def matrix_repr(orientation: rigid_body.Quaternion, eigvals: Array) -> Array:
+    """World-frame quadratic form of an oriented particle built from its eigenvalues and quaternion orientation."""
+    rotation = eigensystem(orientation)
+    eigen_qf = make_diagonal(eigvals)
+    return qf_from_rotation(rotation, eigen_qf)
+
+
+def get_weight_matrices(orientation: rigid_body.Quaternion, eigvals: Array) -> Array:
+    """
+    Weight matrices of the rigid bodies, i.e. quadratic forms whose eigenvalues are the inverses of `eigvals`
+    (squared semi-axes lengths when `eigvals` are inverse squared semi-axes).
+    """
+    inverse_eigvals = 1 / eigvals
+    return matrix_repr(orientation, inverse_eigvals)
+
+
+@partial(jnp.vectorize, signature='(),(d)->(d)')
+def ellipsoid_moment_of_inertia(m, eigvals):
+    """
+    Principal moments of inertia of a solid ellipsoid of mass `m`.
+
+    `eigvals` are the three inverse squared semi-axes, so the squared semi-axes are `1 / eigvals`.
+    """
+    a_sq, b_sq, c_sq = 1 / eigvals
+    principal = jnp.array([b_sq + c_sq, a_sq + c_sq, a_sq + b_sq])
+    return m / 5 * principal
+
+
+def ellipsoid_mass(masses, eigvals) -> rigid_body.RigidBody:
+    """Build a RigidBody holding the masses and the ellipsoid moments of inertia of the particles."""
+    inertia = ellipsoid_moment_of_inertia(masses, eigvals)
+    return rigid_body.RigidBody(masses, inertia)
+
+
+def contact_to_distance_cutoff(cf_cut: float, eigvals: Array) -> float:
+    """
+    Convert a contact function cutoff to a sufficient center-to-center distance cutoff.
+
+    The contact function is the square root of the Perram-Wertheim contact function; the conversion
+    uses the smallest eigenvalue of `eigvals`.
+    """
+    smallest = jnp.min(eigvals)
+    distance_per_contact = 2 / jnp.sqrt(smallest)
+    return distance_per_contact * cf_cut
+
+
+def contact_to_distance_threshold(cf_cut: float, cf_theshold: float, eigvals: Array) -> float:
+    """
+    Map a threshold in the contact function to a distance threshold.
+
+    The result is the difference of the distance cutoffs at `cf_cut` and at `cf_cut - cf_theshold`,
+    i.e. the minimal distance that comes from a particle move of `cf_theshold` at the very edge
+    of the function range.
+    """
+    outer = contact_to_distance_cutoff(cf_cut, eigvals)
+    inner = contact_to_distance_cutoff(cf_cut - cf_theshold, eigvals)
+    return outer - inner
+
+
+def distance_to_contact_cutoff(r_cut: float, eigvals: Array) -> float:
+    """
+    Calculate a sufficient contact function cutoff from the distance cutoff.
+    Contact function value returned corresponds to the square root of the Perram-Wertheim contact function.
+
+    This is the inverse map of `r_cut = 2 / sqrt(min(eigvals)) * cf_cut`.
+    """
+    # eigenvalues are inverse squared lengths, so the square root is required to obtain a dimensionless
+    # contact function value (and to invert the contact -> distance conversion exactly)
+    return jnp.sqrt(jnp.min(eigvals)) * r_cut / 2
+
+
+def eigenvalues_at_unit_volume(eigenvalues: Array) -> Array:
+    """Rescale the eigenvalues (inverse squared semi-axes) so that the ellipsoid has unit volume."""
+    semiaxes = 1 / jnp.sqrt(eigenvalues)
+    volume = 4 * jnp.pi / 3 * jnp.prod(semiaxes)
+    # lengths scale with volume^(-1/3), hence eigenvalues scale with volume^(2/3)
+    scale = jnp.cbrt(volume) ** 2
+    return scale * eigenvalues
+
+
+def eigenvalues_to_semiaxes(eigenvalues: Array) -> Array:
+    """Return the ellipsoid semiaxes, sorted in ascending order, computed as 1 / sqrt(eigenvalues)."""
+    semiaxes = 1 / jnp.sqrt(eigenvalues)
+    return jnp.sort(semiaxes)
+
+
+def canonicalize_eigvals(interaction_params: T) -> T:
+    """
+    Return a new instance of the same type as `interaction_params` in which the `eigvals` field
+    is rescaled to correspond to unit volume ellipsoidal particles. All other fields are reused as they are.
+    """
+    # a shallow copy of the attribute dict is enough because only the `eigvals` entry is replaced
+    fields = dict(vars(interaction_params))
+    fields['eigvals'] = eigenvalues_at_unit_volume(fields['eigvals'])
+    params_cls = type(interaction_params)
+    return params_cls(**fields)
+
+
+def box_size_at_number_density(particle_count: int,
+                               number_density: float,
+                               spatial_dimension: int = 3):
+    """Thin wrapper around `quantity.box_size_at_number_density` with a default spatial dimension of 3."""
+    box_size = quantity.box_size_at_number_density(
+        particle_count, number_density, spatial_dimension=spatial_dimension)
+    return box_size
+
+def ellipsoid_volume(eigvals: Array):
+    """Volume of an ellipsoid with inverse squared semi-axes `eigvals` (product taken over the last axis)."""
+    inverse_semiaxes_product = jnp.prod(jnp.sqrt(eigvals), axis=-1)
+    return 4 / 3 * jnp.pi / inverse_semiaxes_product
+
+def box_size_at_ellipsoid_density(particle_count: int,
+                                  density: float,
+                                  eigvals: Array):
+    """
+    Box side length at which the summed ellipsoid volume divided by the box volume equals `density`.
+
+    `eigvals` is either a single eigenvalue set shared by all `particle_count` particles or
+    a 2D array with one eigenvalue set per row.
+    """
+    if eigvals.ndim > 2:
+        raise ValueError("Eigenvalue matrix should have at most 2 dimensions.")
+    dim = eigvals.shape[-1]
+    volumes = ellipsoid_volume(eigvals)
+    if volumes.ndim == 0:
+        # one shared eigenvalue set: every particle has the same volume
+        volumes = jnp.full((particle_count,), volumes)
+    occupied_volume = jnp.sum(volumes)
+    return jnp.power(occupied_volume / density, 1 / dim)
+
+
+@jax.jit
+def update_interaction_params(grad: data_protocols.InteractionParams,
+                              interaction_params: data_protocols.InteractionParams,
+                              learning_rate: float) -> data_protocols.InteractionParams:
+    """
+    Perform one gradient descent step on the interaction parameters and rescale the new
+    ellipsoid eigenvalues so that they correspond to unit volume particles.
+    """
+    def descend(param, param_grad):
+        return param - learning_rate * param_grad
+
+    stepped = jax.tree_util.tree_map(descend, interaction_params, grad)
+    return canonicalize_eigvals(stepped)
+
+
+class OrientedParticleEnergy(Protocol):
+    """
+    Protocol specifying the signature for energy functions between oriented particles.
+
+    Implementations are called with the displacement `dr` between a pair of particles, the eigensystem
+    matrices of both particles and optional keyword parameters, and return the pair energy as an array.
+    """
+
+    def __call__(self, dr: Array, eigsys1: Array, eigsys2: Array, **kwargs) -> Array:
+        ...
+
+
+def get_ellipsoid_contact_function(contact_function: Callable[..., Array], eigvals: Array, **cf_kwargs):
+    """
+    Bind fixed eigenvalues to a contact function and return a function of `(dr, eigsys1, eigsys2)`.
+
+    Both particles share `eigvals`; the quadratic forms passed to `contact_function` are built from
+    the inverse eigenvalues rotated to the world frame. Per the module convention, `contact_function`
+    is the square root of the Perram-Wertheim contact function.
+    """
+    def pair_contact(dr: Array, eigsys1: Array, eigsys2: Array) -> Array:
+        world_qf1 = qf_from_rotation(eigsys1, make_diagonal(1 / eigvals))
+        world_qf2 = qf_from_rotation(eigsys2, make_diagonal(1 / eigvals))
+        return contact_function(dr, world_qf1, world_qf2, **cf_kwargs)
+
+    return pair_contact
+
+
+def get_ellipsoid_contact_function_param(contact_function: Callable[..., Array], **cf_kwargs):
+    """
+    Wrap a contact function so that it has the standardized call signature
+    `(dr, eigsys1, eigsys2, eigvals)`, with the eigenvalues supplied at call time.
+
+    The wrapper also converts the standard quadratic form eigenvalues for ellipsoids (inverse squares
+    of semiaxis lengths) to the weight matrix used in the Perram-Wertheim contact function
+    (eigenvalues are just semiaxes squared, without the inverse).
+    """
+    def pair_contact(dr: Array, eigsys1: Array, eigsys2: Array, eigvals: Array) -> Array:
+        world_qf1 = qf_from_rotation(eigsys1, make_diagonal(1 / eigvals))
+        world_qf2 = qf_from_rotation(eigsys2, make_diagonal(1 / eigvals))
+        return contact_function(dr, world_qf1, world_qf2, **cf_kwargs)
+
+    return pair_contact
+
+
+def isotropic_to_ellipsoid_energy(energy_fn: Callable[..., Array],
+                                  contact_function: Callable[..., Array],
+                                  eigvals: Array,
+                                  **cf_kwargs) -> OrientedParticleEnergy:
+    """
+    Promote an isotropic energy function to one acting between ellipsoids with fixed `eigvals`.
+    The contact function value replaces the distance argument of `energy_fn`.
+    """
+    contact = get_ellipsoid_contact_function(contact_function, eigvals, **cf_kwargs)
+
+    def ellipsoid_energy_fn(dr: Array, eigsys1: Array, eigsys2: Array, **kwargs) -> Array:
+        effective_distance = contact(dr, eigsys1, eigsys2)
+        return energy_fn(effective_distance, **kwargs)
+
+    return ellipsoid_energy_fn
+
+
+def isotropic_to_cf_energy(energy_fn: Callable[..., Array],
+                           contact_function: Callable[..., Array],
+                           **cf_kwargs) -> OrientedParticleEnergy:
+    """
+    Promotes an isotropic energy function to one acting between ellipsoids,
+    with a given contact function as a measure of distance.
+
+    Unlike the fixed-eigenvalue variant, the returned function takes the ellipsoid eigenvalues
+    at call time: `ellipsoid_energy_fn(dr, eigsys1, eigsys2, eigvals, **kwargs)`, where the remaining
+    keyword arguments are forwarded to `energy_fn`.
+    """
+
+    # cf_kwargs are bound to the contact function here, once; they must not be forwarded again below
+    cf = get_ellipsoid_contact_function_param(contact_function, **cf_kwargs)
+
+    def ellipsoid_energy_fn(dr: Array, eigsys1: Array, eigsys2: Array, eigvals: Array, **kwargs) -> Array:
+        # the parametrized contact function requires the eigenvalues as its fourth argument
+        return energy_fn(cf(dr, eigsys1, eigsys2, eigvals), **kwargs)
+
+    return ellipsoid_energy_fn
+
+
+def isotropic_to_ellipsoid_energy_with_cutoff(energy_fn: Callable[..., Array],
+                                              contact_function: Callable[..., Array],
+                                              eigvals: Array,
+                                              cf_onset: float,
+                                              cf_cutoff: float,
+                                              **cf_kwargs) -> OrientedParticleEnergy:
+    """
+    Promote an isotropic energy function to a truncated one acting between ellipsoids.
+
+    The contact function value replaces the distance argument, and `energy_fn` is wrapped with
+    `energy.multiplicative_isotropic_cutoff` using `cf_onset` and `cf_cutoff` (both expressed
+    in contact function units).
+    """
+    contact = get_ellipsoid_contact_function(contact_function, eigvals, **cf_kwargs)
+    truncated_energy = energy.multiplicative_isotropic_cutoff(energy_fn, cf_onset, cf_cutoff)
+
+    def ellipsoid_energy_fn(dr: Array, eigsys1: Array, eigsys2: Array, **kwargs) -> Array:
+        return truncated_energy(contact(dr, eigsys1, eigsys2), **kwargs)
+
+    return ellipsoid_energy_fn

+ 324 - 0
curvature_assembly/parallelization.py

@@ -0,0 +1,324 @@
+from __future__ import annotations
+from typing import Callable, TypeVar
+import copy
+import jax
+import multiprocess as mp
+from functools import partial
+import math
+from curvature_assembly import pytree_transf
+import warnings
+
+T = TypeVar('T')
+
+
+def all_equal(iterator) -> bool:
+    """Return True if every element of the iterable equals the first one (True for an empty iterable)."""
+    remaining = iter(iterator)
+    missing = object()
+    head = next(remaining, missing)
+    if head is missing:
+        return True
+    for element in remaining:
+        if not head == element:
+            return False
+    return True
+
+
+def get_argument_length(arg) -> int:
+    """
+    Return the length assigned to `arg`.
+
+    A pytree of arrays gets the length reported by `pytree_transf.data_length`; any other
+    argument gets `len(arg)`. Raises TypeError if `arg` does not support `len`.
+    """
+    is_array_pytree = pytree_transf.check_if_pytree_of_arrays(arg)
+    if is_array_pytree:
+        return pytree_transf.data_length(arg)
+    try:
+        length = len(arg)
+    except TypeError:
+        raise TypeError(f'Cannot assign lenght to argument of type {type(arg)}')
+    return length
+
+
+def full_sequence_length(map_argnums: int | tuple, *args):
+    """
+    Return the length of the sequence(s) over which a function will be mapped.
+
+    Args:
+    map_argnums: The positional index(es) of the argument(s) over which the function will be mapped.
+    args: The arguments passed to the function.
+
+    Returns:
+    The length of the sequence(s) over which the function will be mapped.
+
+    Raises:
+    ValueError: If no argument index is given, if any of the arguments over which the function is mapped
+        is not a sequence, or if the sequences have different lengths.
+    """
+
+    if isinstance(map_argnums, int):
+        map_argnums = (map_argnums,)
+
+    if not map_argnums:
+        # without this guard the final `lengths[0]` would fail with an uninformative IndexError
+        raise ValueError('At least one argument index must be given in map_argnums.')
+
+    lengths = []
+    for argnum in map_argnums:
+        try:
+            lengths.append(get_argument_length(args[argnum]))
+        except TypeError as err:
+            raise ValueError(f'Each argument over which a function is mapped should be a sequence '
+                             f'or a pytree of arrays, got {type(args[argnum])} for argument {argnum}.') from err
+
+    if not all_equal(lengths):
+        raise ValueError(f'All arguments over which we map should be of the same length, '
+                         f'got lengths {lengths} for args {map_argnums}, respectively.')
+    return lengths[0]
+
+
+def canonicalize_args(map_argnums: int | tuple, *args) -> list:
+    """
+    Bring all arguments to a common form: sequences with the same number of elements.
+
+    Args:
+        map_argnums: Indices of the arguments that are already sequences of the mapped length `num`.
+        *args: The arguments to be canonicalized. Arguments at `map_argnums` are kept as they are;
+            every other argument is replaced by a list of `num` independent deep copies of itself.
+
+    Returns:
+        A list with one entry per input argument, each entry holding `num` elements.
+
+    Raises:
+        ValueError: If the mapped arguments have different lengths.
+    """
+    if isinstance(map_argnums, int):
+        map_argnums = (map_argnums,)
+    num = full_sequence_length(map_argnums, *args)
+    return [
+        arg if i in map_argnums else [copy.deepcopy(arg) for _ in range(num)]
+        for i, arg in enumerate(args)
+    ]
+
+
+def canonicalize_args_pmap(map_argnums: int | tuple, *args) -> list:
+    """
+    Bring all arguments to a common form in which every array leaf has a leading axis of the mapped length.
+
+    Args:
+        map_argnums: Indices of the arguments that already store arrays of length `num` to be mapped over.
+        *args: The arguments to be canonicalized. Arguments at `map_argnums` are kept as they are; every
+            other argument gets a new leading axis of length `num` via `pytree_transf.repeat_fields`.
+
+    Returns:
+        A list with one entry per input argument.
+
+    Raises:
+        ValueError: If the mapped arguments have different lengths.
+    """
+    if isinstance(map_argnums, int):
+        map_argnums = (map_argnums,)
+    num = full_sequence_length(map_argnums, *args)
+    canonical_args = []
+    for i, arg in enumerate(args):
+        if i in map_argnums:
+            canonical_args.append(arg)
+            continue
+        canonical_args.append(pytree_transf.repeat_fields(arg, num))
+        try:
+            looks_mapped = pytree_transf.data_length(arg) == num
+        except pytree_transf.NoLengthError:
+            continue
+        if looks_mapped:
+            # the broadcast argument has the same length as the mapped ones; the user may have forgotten
+            # to list it in map_argnums
+            warnings.warn(f"Added a new leading dimension to argument {i} with existing leading dimension "
+                          f"length that is the same as the length of the mapped argument(s). Make sure that "
+                          f"this is the desired behavior and this argument should not also be mapped over.")
+    return canonical_args
+
+
+def fill_to_length_num(num, *args):
+    """
+    Apply `pytree_transf.extend_with_last_element(arg, num)` to every argument in `args`.
+
+    Args:
+        num: The value passed on to `extend_with_last_element` for each argument.
+        args: A variable number of arguments to be extended.
+
+    Returns:
+        A list of the extended arguments, in the input order.
+    """
+    return [pytree_transf.extend_with_last_element(arg, num) for arg in args]
+
+
+def get_slice(start_idx: int, slice_length: int, *args) -> list:
+    """
+    Cut the same window out of every sequence in `args`.
+
+    Args:
+        start_idx: Index of the first element of the window.
+        slice_length: Number of elements in the window.
+        *args: Sequences supporting slice indexing.
+
+    Returns:
+        A list with `arg[start_idx:start_idx + slice_length]` for each sequence in `args`.
+
+    Raises:
+        ValueError: If `start_idx` or `slice_length` is negative.
+    """
+    if min(start_idx, slice_length) < 0:
+        raise ValueError("Start index and slice length must be non-negative.")
+    window = slice(start_idx, start_idx + slice_length)
+    return [sequence[window] for sequence in args]
+
+
+def list_flatten(lst: list) -> list:
+    """
+    Flatten one level of nesting: concatenate the items of all sublists into a single list.
+    """
+    return [item for sublist in lst for item in sublist]
+
+
+def segment_args_pool(num: int, num_cores: int, *args) -> list:
+    """
+    Split the arguments into `num_cores` consecutive segments for processing with multiprocess.Pool.
+
+    Args:
+        num: The total number of items to be segmented across cores.
+        num_cores: The number of cores to be used for processing.
+        *args: Sequences that are indexable and have a length equal to `num`.
+
+    Returns:
+        A list of length `num_cores`. Each entry is a list with one slice per argument, covering
+        `ceil(num / num_cores)` consecutive items; trailing segments hold the remaining (possibly zero) items.
+    """
+    per_core = int(math.ceil(num / num_cores))
+    return [get_slice(per_core * core, per_core, *args) for core in range(num_cores)]
+
+
+def segment_args_pmap(num: int, num_devices: int, *args) -> list:
+    """
+    Split the arguments into consecutive segments of `num_devices` items for processing with jax.pmap.
+
+    Args:
+        num: The total number of items to be segmented.
+        num_devices: The number of devices to be used for processing.
+        *args: Arguments to be sliced along their leading axis by `pytree_transf.get_slice`.
+
+    Returns:
+        A list of `ceil(num / num_devices)` segments; segment `i` is the result of
+        `pytree_transf.get_slice(args, num_devices * i, num_devices)`.
+    """
+    num_segments = int(math.ceil(num / num_devices))
+    return [pytree_transf.get_slice(args, num_devices * segment, num_devices)
+            for segment in range(num_segments)]
+
+
+def cpu_segment_dispatch(f: Callable[..., T], num_cores: int, map_argnums: int | tuple = 0) -> Callable[..., list[T]]:
+    """
+    Embarrassingly-parallel function evaluation over multiple cores. Divides the input arguments into
+    segments and dispatches each segment to a different processor core. The idea of such implementation is
+    that each worker evaluates its whole segment sequentially, so the compilation of jax functions
+    only happens once at each core.
+
+    Args:
+        f: A function to be executed on the different input arguments in parallel.
+            Parallelization over keyword arguments is not supported.
+        num_cores: The number of processor cores to be used for parallel processing. Must be larger than 1.
+        map_argnums: index or a tuple of indices of function `f` arguments to map over. Default is 0.
+
+    Returns:
+        A new function that takes the same arguments as `f` and dispatches
+        the input arguments across multiple processor cores for parallel processing.
+        The returned function will return a flat list with one result per mapped item.
+
+    Raises:
+        ValueError: If `num_cores` is not larger than 1.
+    """
+    if num_cores <= 1:
+        raise ValueError("The number of cores must be an integer greater than 1.")
+
+    def sequential_f(args: list, **kwargs):
+        # evaluate f item by item over the segment assigned to one worker
+        return [f(*a, **kwargs) for a in zip(*args)]
+
+    def parallel_f(*args, **kwargs) -> list:
+        canonical_args = canonicalize_args(map_argnums, *args)
+        num = full_sequence_length(map_argnums, *args)
+        segments = segment_args_pool(num, num_cores, *canonical_args)
+        # the context manager releases the worker processes even if a task raises;
+        # `map` blocks until all results are available, so nothing is lost on exit
+        with mp.Pool(num_cores) as pool:
+            results = pool.map(partial(sequential_f, **kwargs), segments)
+        return list_flatten(results)
+    return parallel_f
+
+
+def pmap_segment_dispatch(f: Callable[..., T],
+                          map_argnums:  int | tuple[int, ...] = 0,
+                          backend: str = 'cpu',
+                          pmap_jit: bool = False) -> Callable[..., T]:
+    """
+    Embarrassingly-parallel function evaluation over multiple jax devices. Divides the input arguments into
+    segments of at most `device_count` items and evaluates each segment with `jax.pmap`.
+
+    Args:
+        f: A function to be mapped over the leading axis of `map_argnums` arguments in parallel.
+            Parallelization over keyword arguments is not supported.
+        map_argnums: index or a tuple of indices of function `f` arguments to map over. Default is 0.
+        backend: jax backend, 'cpu' or 'gpu'. For parallelization over multiple cpu cores,
+            os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=XX' should be set
+            at the beginning of the main script, with XX the number of cores.
+        pmap_jit: bool, whether to jit the pmap-ed function. This will raise a warning but can speed up
+            parallel calculations when num > device_count, at least on cpu.
+
+    Returns:
+        A new function that takes the same arguments as `f` and dispatches `map_argnums`
+        input arguments over the leading axis across multiple devices for parallel processing.
+        All return values of the mapped function will have a leading axis with a length corresponding
+        to the length of the `map_argnums` input arguments.
+
+    Raises:
+        ValueError: If the 'cpu' backend is requested but only one cpu device is available.
+    """
+
+    # the device count is queried once, when the dispatcher is created
+    device_count = jax.local_device_count(backend=backend)
+    if backend == 'cpu' and device_count == 1:
+        raise ValueError('Got cpu backend for parallelization but only 1 cpu device is available. '
+                         'Try setting os.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=XX\" '
+                         'at the beginning of the script.')
+
+    def parallel_f(*args, **kwargs):
+        # non-mapped arguments get a new leading axis so that every argument can be sliced the same way
+        canonical_args = canonicalize_args_pmap(map_argnums, *args)
+        num = full_sequence_length(map_argnums, *args)
+
+        def pmap_f(*x):
+            # keyword arguments are bound before pmap, so they are not mapped over
+            return jax.pmap(partial(f, **kwargs))(*x)
+
+        if pmap_jit:
+            # jit(pmap) raises UserWarning (https://github.com/google/jax/issues/2926) but using jit here prevents
+            # pmap seemingly tracing the code in every iteration of the following for loop, which results in
+            # faster computation when num > device_count
+            pmap_f = jax.jit(pmap_f)
+
+            # when jit-ing pmap, merging of results doesn't work if segments have different lengths, so we
+            # expand the arguments to a multiple of device_count
+            canonical_args = fill_to_length_num(math.ceil(num / device_count) * device_count, *canonical_args)
+
+        results = []
+        for arguments in segment_args_pmap(num, device_count, *canonical_args):
+            r = pmap_f(*arguments)
+            results.append(r)
+        # concatenate the per-segment results and keep only the first `num` entries
+        # (drops any padding added for the jit-ed variant)
+        return pytree_transf.get_slice(pytree_transf.merge(results), 0, num)
+
+    return parallel_f
+
+
+

+ 170 - 0
curvature_assembly/patchy_interaction.py

@@ -0,0 +1,170 @@
+from functools import partial
+from typing import List, Union, Callable
+import jax.numpy as jnp
+import jax
+from curvature_assembly.spherical_harmonics import sph_harm_fn, real_sph_harm, sph_harm_not_fast, sph_harm_fn_custom, real_sph_harm_fn_custom_rev
+
+Array = jnp.ndarray
+
+
+def vec_in_eigensystem(eigsys: Array, vec: Array):
+    """Return the components of `vec` in the basis given by the columns of `eigsys` (computes eigsys^T vec)."""
+    world_to_body = jnp.transpose(eigsys)
+    return jnp.dot(world_to_body, vec)
+
+
+def safe_arctan2(x, y):
+    """
+    Version of `jnp.arctan2(x, y)` that also works (value and gradient) when both inputs are zero.
+    Look at https://github.com/google/jax/issues/1052
+    and https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
+
+    Returns `jnp.arctan2(x, y)` everywhere except at x = y = 0, where the result is 0 with zero gradient.
+    """
+    at_origin = jnp.logical_and(x == 0, y == 0)
+    # arctan2 is only singular at the origin; substitute a harmless second argument there so that
+    # no NaN is produced in the branch that `where` discards (the "double where" trick)
+    safe_y = jnp.where(at_origin, 1., y)
+    return jnp.where(at_origin, 0., jnp.arctan2(x, safe_y))
+
+
+def cart_to_sph(vec: Array) -> (Array, Array):
+    """
+    Transform a Cartesian vector to the spherical angles theta and phi.
+
+    The result is a single array of length 2 holding (theta, phi), both computed with `safe_arctan2`.
+    """
+    in_plane = jnp.sqrt(vec[0] ** 2 + vec[1] ** 2)
+    theta = safe_arctan2(in_plane, vec[2])
+    phi = safe_arctan2(vec[1], vec[0])
+
+    angles = jnp.zeros(2, )
+    angles = angles.at[0].set(theta)
+    angles = angles.at[1].set(phi)
+    return angles
+
+
+def patchy_interaction_general(lm_list: Union[tuple, List[tuple]]) -> Callable:
+    """
+    Build the orientational part of a general patchy particle interaction where patches are described by
+    a linear combination of real spherical harmonics. The form of the potential is inspired by
+    the Kern-Frenkel patchy particle model.
+
+    `lm_list` is a single (l, m) tuple or a list of such tuples. The returned function has the signature
+    `fn(dr, eigsys1, eigsys2, lm_magnitudes)`; a scalar `lm_magnitudes` is used for all expansion terms.
+    Raises ValueError if |m| > l for any term or if the number of magnitudes does not match the number of terms.
+    """
+    terms = [lm_list] if isinstance(lm_list, tuple) else lm_list
+
+    l_list, m_list = zip(*terms)
+    l_array = jnp.array(l_list)
+    m_array = jnp.array(m_list)
+
+    if not jnp.all(jnp.abs(m_array) <= l_array):
+        raise ValueError(f'Spherical harmonics are only defined for |m|<=l.')
+
+    def fn(dr: Array, eigsys1: Array, eigsys2: Array, lm_magnitudes: Array) -> Array:
+        n_terms = len(terms)
+
+        if lm_magnitudes.shape == ():
+            lm_magnitudes = jnp.full(n_terms, lm_magnitudes)
+
+        if len(lm_magnitudes) != n_terms:
+            raise ValueError(f'Length of lm_magnitudes array does not match the number of (l, m) expansion terms, '
+                             f'got {len(lm_magnitudes)} and {len(terms)}, respectively.')
+
+        # dr points from 2nd to 1st particle (dr = r1 - r2)
+        # we need relative direction from one particle to another, so in the case of the first, we need to take -dr
+        direction = dr / jnp.sqrt(jnp.sum(dr ** 2))
+        seen_from_1 = vec_in_eigensystem(eigsys1, -direction)
+        seen_from_2 = vec_in_eigensystem(eigsys2, direction)
+
+        patch1 = real_sph_harm(seen_from_1, l_list, m_list) @ lm_magnitudes
+        patch2 = real_sph_harm(seen_from_2, l_list, m_list) @ lm_magnitudes
+
+        # energy contribution from patches is defined in such a way that negative patches attract each other,
+        # positive patches repulse and differently-signed patches have 0 energy
+        sign_sum = jnp.sign(patch1) + jnp.sign(patch2)
+        return -sign_sum * patch1 * patch2
+
+    return fn
+
+
+def generate_lm_list(l_max: int,
+                     only_non_neg_m: bool = False,
+                     only_even_l: bool = False,
+                     only_odd_l: bool = False) -> list:
+    """
+    Return the list of all (l, m) pairs with l <= l_max, ordered by l and then by m.
+
+    `only_even_l` / `only_odd_l` restrict l to even / odd values (they are mutually exclusive);
+    `only_non_neg_m` restricts m to 0 <= m <= l instead of -l <= m <= l.
+    """
+    if only_odd_l and only_even_l:
+        raise ValueError('Parameters only_even_l and only_odd_l cannot both be True at the same time.')
+
+    if only_even_l:
+        first_l, l_step = 0, 2
+    elif only_odd_l:
+        first_l, l_step = 1, 2
+    else:
+        first_l, l_step = 0, 1
+
+    return [
+        (l, m)
+        for l in range(first_l, l_max + 1, l_step)
+        for m in range(0 if only_non_neg_m else -l, l + 1)
+    ]
+
+
+def init_lm_coefs(lm_list: list[tuple], nonzero_list: list[tuple], init_values: list = None) -> jnp.ndarray:
+    """
+    Initialize normalized lm coefficients for a given lm_list.
+
+    Terms listed in `nonzero_list` get the corresponding entry of `init_values` (1 for every term if
+    `init_values` is not provided); all other terms get 0. The coefficient vector is divided by its norm.
+    """
+    values = [1 for _ in nonzero_list] if init_values is None else init_values
+    coefs = [values[nonzero_list.index(lm)] if lm in nonzero_list else 0. for lm in lm_list]
+    coef_array = jnp.array(coefs)
+    return coef_array / jnp.linalg.norm(jnp.array(coefs))
+
+
+def patchy_interaction_band(dr: Array, eigsys1: Array, eigsys2: Array, theta: Array, sigma: Array):
+    """
+    Step-like band patch interaction: product of Heaviside factors that select the case where the z component
+    of the inter-particle direction, expressed in each particle eigensystem, lies between
+    cos(theta + sigma) and cos(theta - sigma).
+    """
+    direction = dr / jnp.sqrt(jnp.sum(dr ** 2))
+    z1 = vec_in_eigensystem(eigsys1, -direction)[2]
+    z2 = vec_in_eigensystem(eigsys2, direction)[2]
+
+    z_low = jnp.cos(theta + sigma)
+    z_high = jnp.cos(theta - sigma)
+
+    # return value should be positive for attractive patches
+    # as this potential is usually combined with attractive isotropic term
+    return (jnp.heaviside(z_high - z1, 0.5) * jnp.heaviside(z1 - z_low, 0.5)
+            * jnp.heaviside(z_high - z2, 0.5) * jnp.heaviside(z2 - z_low, 0.5))
+
+
+@jax.custom_jvp
+def sigmoid(x):
+    """Logistic function 1 / (1 + exp(-x)) with a custom forward-mode derivative rule."""
+    denominator = 1 + jnp.exp(-x)
+    return 1 / denominator
+
+
+@sigmoid.defjvp
+def sigmoid_jvp(x, x_dot):
+    """JVP rule for `sigmoid`, using the identity s'(x) = s(x) * (1 - s(x))."""
+    value = sigmoid(x)
+    derivative = value * (1 - value)
+    return value, derivative * x_dot
+
+
+def gaussian_belt(x, theta, sigma) -> jnp.ndarray:
+    """Normalized Gaussian in `x`, centered at `theta` with standard deviation `sigma`."""
+    normalization = 1 / (sigma * jnp.sqrt(2 * jnp.pi))
+    deviation = (x - theta) / sigma
+    return normalization * jnp.exp(-0.5 * deviation ** 2)
+
+
+def gaussian_belt_fixed_height(x, theta, sigma) -> jnp.ndarray:
+    """Gaussian in `x`, centered at `theta` with width `sigma` and a peak value of 1 independent of `sigma`."""
+    deviation = (x - theta) / sigma
+    return jnp.exp(-0.5 * deviation ** 2)
+
+
+def gaussian_interaction_band(dr: Array, eigsys1: Array, eigsys2: Array, theta: Array, sigma: Array):
+    """
+    Smooth band patch interaction: product of normalized Gaussian belts evaluated at the polar angle of
+    the inter-particle direction in each particle eigensystem. The belts are centered at `theta`
+    and have width `sigma`.
+    """
+    direction = dr / jnp.sqrt(jnp.sum(dr ** 2))
+    polar1 = jnp.arccos(vec_in_eigensystem(eigsys1, -direction)[2])
+    polar2 = jnp.arccos(vec_in_eigensystem(eigsys2, direction)[2])
+
+    # return value should be positive for attractive patches
+    # as this potential is usually combined with attractive isotropic term
+    return gaussian_belt(polar1, theta=theta, sigma=sigma) * gaussian_belt(polar2, theta=theta, sigma=sigma)
+
+
+def gaussian_interaction_band_fixed_height(dr: Array, eigsys1: Array, eigsys2: Array, theta: Array, sigma: Array):
+    """
+    Smooth band patch interaction: product of unit-height Gaussian belts evaluated at the polar angle of
+    the inter-particle direction in each particle eigensystem. The belts are centered at `theta`
+    and have width `sigma`.
+    """
+    direction = dr / jnp.sqrt(jnp.sum(dr ** 2))
+    polar1 = jnp.arccos(vec_in_eigensystem(eigsys1, -direction)[2])
+    polar2 = jnp.arccos(vec_in_eigensystem(eigsys2, direction)[2])
+
+    # return value should be positive for attractive patches
+    # as this potential is usually combined with attractive isotropic term
+    return (gaussian_belt_fixed_height(polar1, theta=theta, sigma=sigma)
+            * gaussian_belt_fixed_height(polar2, theta=theta, sigma=sigma))

+ 183 - 0
curvature_assembly/pytree_transf.py

@@ -0,0 +1,183 @@
+from __future__ import annotations
+from typing import TypeVar, Callable, Any
+import jax.numpy as jnp
+import jax
+
+
+T = TypeVar('T')
+Array = jnp.ndarray
+
+
+def all_equal(iterator) -> bool:
+    """Return True when every element of the iterable equals the first one (vacuously True if empty)."""
+    remaining = iter(iterator)
+    missing = object()
+    first = next(remaining, missing)
+    if first is missing:
+        return True
+    for element in remaining:
+        if not first == element:
+            return False
+    return True
+
+
+def repeat_fields(pytree: T, num: int) -> T:
+    """Give every leaf of a PyTree a new leading axis of length `num` by repeating the leaf."""
+    def tile_leaf(leaf):
+        try:
+            with_new_axis = leaf[None, ...]
+            return jnp.repeat(with_new_axis, num, axis=0)
+        except TypeError:
+            # leaves that cannot be indexed this way (e.g. plain Python numbers) become a 1D array
+            return jnp.full((num,), leaf)
+
+    return jax.tree_util.tree_map(tile_leaf, pytree)
+
+
+def merge(pytree_list: list[T]) -> T:
+    """Concatenate corresponding leaves of several PyTrees along their existing leading axis."""
+    # Not to be confused with `stack`: here leaves are joined along axis 0 that they already have,
+    # whereas `stack` introduces a new leading axis.
+    def concatenate_leaves(*leaves):
+        return jnp.concatenate(leaves, axis=0)
+
+    return jax.tree_util.tree_map(concatenate_leaves, *pytree_list)
+
+def stack(pytree_list: list[T]) -> T:
+    """Combine corresponding leaves of several PyTrees by stacking them along a new leading axis."""
+    def stack_leaves(*leaves):
+        return jnp.stack(leaves)
+
+    return jax.tree_util.tree_map(stack_leaves, *pytree_list)
+
+
+def axis_length(pytree: T, axis: int = 0) -> T:
+    """
+    Calculate axis lengths of pytree leaves. Leaves without a `shape` attribute (non-array values) are
+    assigned 0. If `axis` is out of range for a leaf's shape, that leaf is assigned -1.
+
+    Returns a pytree of the same structure with an integer in place of every leaf.
+    """
+    def length(x):
+        try:
+            return x.shape[axis]
+        except AttributeError:
+            # leaf has no shape, e.g. a plain Python number
+            return 0
+        except IndexError:
+            # leaf has too few dimensions for the requested axis
+            return -1
+    return jax.tree_util.tree_map(length, pytree)
+
+
+class NoLengthError(Exception):
+    """Raised by `data_length` when no single length can be assigned to a pytree."""
+    pass
+
+
+def data_length(pytree: T, axis: int = 0, ignore_non_array_leaves: bool = False) -> int:
+    """
+    Assign a length to a pytree from shapes of arrays stored within it.
+
+    Args:
+        pytree: pytree whose leaves are inspected.
+        axis: leaf axis along which the length is measured.
+        ignore_non_array_leaves: if True, leaves whose reported length is not positive are left out
+            of the comparison.
+
+    Returns:
+        The common length of all considered leaves along `axis`.
+
+    Raises:
+        NoLengthError: if the considered leaves disagree on the length or there are no leaves to consider.
+    """
+    leading_dim = axis_length(pytree, axis=axis)
+    lengths, structure = jax.tree_util.tree_flatten(leading_dim)
+    # we want to exclude leaves that are not arrays (reported as length 0 by `axis_length`) as this might be
+    # some number auxiliary data; note that the `x > 0` filter also drops the -1 entries (axis out of range)
+    # and arrays whose length along `axis` is really 0
+    if ignore_non_array_leaves:
+        lengths = [x for x in lengths if x > 0]
+    if all_equal(lengths) and len(lengths) > 0:
+        return lengths[0]
+    raise NoLengthError(f'Pytree of type {type(pytree)} with structure {structure} cannot have a length assigned to it '
+                        f'over axis {axis}.')
+
+
+def check_if_pytree_of_arrays(pytree: T, allow_numbers: bool = True) -> bool:
+    """
+    Tell whether every leaf of a pytree is an ndarray.
+
+    With `allow_numbers` set, leaves that are not arrays are still accepted as long as they have no
+    `__len__` attribute (this is how scalar auxiliary data is recognized).
+    """
+    def is_acceptable(leaf) -> bool:
+        if isinstance(leaf, jnp.ndarray):
+            return True
+        return bool(allow_numbers and not hasattr(leaf, "__len__"))
+
+    leaves, _ = jax.tree_util.tree_flatten(pytree)
+    return all(is_acceptable(leaf) for leaf in leaves)
+
+
+def get_slice(pytree: T, start_idx: int, slice_length: int) -> T:
+    """
+    Slice every leaf of a PyTree along its leading axis, keeping the tree structure.
+
+    Args:
+        pytree: A PyTree of ndarrays holding the input data.
+        start_idx: Index of the first element to keep.
+        slice_length: Number of consecutive elements to keep.
+
+    Returns:
+        A PyTree of ndarrays in which each leaf is `leaf[start_idx:start_idx + slice_length]`.
+    """
+    stop_idx = start_idx + slice_length
+
+    def take(leaf):
+        return leaf[start_idx:stop_idx]
+
+    return jax.tree_util.tree_map(take, pytree)
+
+
+def split_to_list(pytree: T) -> list[T]:
+    """Split a pytree of arrays along the leading leaf axis into a list of pytrees, one per index."""
+    if not check_if_pytree_of_arrays(pytree, allow_numbers=True):
+        raise ValueError('Should get a pytree of arrays.')
+
+    def take(idx):
+        return jax.tree_util.tree_map(lambda leaf: leaf[idx], pytree)
+
+    return [take(idx) for idx in range(data_length(pytree))]
+
+def map_over_leading_leaf_dimension(f: Callable[[T, Any], T], *pytrees: T, **kwargs):
+    """
+    Apply `f` separately to every slice along the leading leaf dimension and stack the outputs.
+
+    The positional pytrees are split together with `split_to_list`; `f` is called once per slice with the
+    sliced pytrees as positional arguments and `kwargs` forwarded unchanged, and the per-slice results are
+    combined with `stack`. All positional arguments must therefore be splittable pytrees.
+    """
+    per_slice_results = []
+    for sliced_args in split_to_list(pytrees):
+        per_slice_results.append(f(*sliced_args, **kwargs))
+    return stack(per_slice_results)
+
+
+def num_dimensions(pytree: T) -> T:
+    """Map every leaf of a pytree to its number of array dimensions (0 for leaves without a shape)."""
+    def rank(leaf):
+        try:
+            shape = leaf.shape
+        except AttributeError:
+            return 0
+        return len(shape)
+
+    return jax.tree_util.tree_map(rank, pytree)
+
+
+def num_extra_dimensions(pytree: T, og_pytree: T) -> int:
+    """
+    Count how many more dimensions the leaves of `pytree` have than the matching leaves of `og_pytree`.
+
+    Raises:
+        ValueError: if the difference is not the same for all leaves.
+    """
+    leaf_dims = jax.tree_util.tree_flatten(num_dimensions(pytree))[0]
+    og_leaf_dims = jax.tree_util.tree_flatten(num_dimensions(og_pytree))[0]
+    extra = [new - old for new, old in zip(leaf_dims, og_leaf_dims)]
+    if not all_equal(extra):
+        raise ValueError('No consistent extra leading dimensions found.')
+    return extra[0]
+
+
+def leaf_norm(pytree: T, num_ld: int = 0, keepdims: bool = True) -> T:
+    """
+    Compute the Euclidean norm of every leaf, treating the first `num_ld` axes as batch axes.
+
+    Args:
+        pytree: pytree of arrays.
+        num_ld: number of common leading dimensions to map over (0, 1 or 2 are supported).
+        keepdims: forwarded to the sum over the non-batch axes.
+
+    Raises:
+        NotImplementedError: for any other value of `num_ld`.
+    """
+
+    def unitwise_norm(x: Array) -> Array:
+        return jnp.sqrt(jnp.sum(x ** 2, keepdims=keepdims))
+
+    if num_ld == 0:
+        norm_fn = unitwise_norm
+    elif num_ld == 1:
+        norm_fn = jax.vmap(unitwise_norm)
+    elif num_ld == 2:
+        norm_fn = jax.vmap(jax.vmap(unitwise_norm))
+    else:
+        raise NotImplementedError('Cannot calculate the leaf_norm of leaves with 3 or more common leading dimensions.')
+    return jax.tree_util.tree_map(norm_fn, pytree)
+
+
+def broadcast_to(pytree1: T, pytree2: T) -> T:
+    """
+    Broadcast all leaf arrays from one pytree to the shape of arrays in another pytree of the same type.
+
+    Args:
+        pytree1: pytree whose leaves are broadcast.
+        pytree2: pytree providing the target shape for each corresponding leaf.
+
+    Raises:
+        ValueError: if either argument contains a leaf that is not an array.
+    """
+    # both inputs must be pytrees of arrays, so the check has to fail as soon as ONE of them is not
+    if not check_if_pytree_of_arrays(pytree1, allow_numbers=False) \
+            or not check_if_pytree_of_arrays(pytree2, allow_numbers=False):
+        raise ValueError('Should get pytrees of arrays.')
+    return jax.tree_util.tree_map(lambda x, y: jnp.broadcast_to(x, y.shape), pytree1, pytree2)
+
+
+def all_data_to_single_array(pytree: T) -> jnp.ndarray:
+    """
+    Combine all leaf arrays of a pytree into one array.
+
+    Leaves with fewer dimensions than the highest-dimensional leaf get trailing axes of length 1 appended,
+    after which all leaves are joined with `jnp.hstack`.
+
+    Raises:
+        ValueError: if some leaf is not an array.
+    """
+    if not check_if_pytree_of_arrays(pytree, allow_numbers=False):
+        raise ValueError('Should get a pytree of arrays.')
+    arrays, _ = jax.tree_util.tree_flatten(pytree)
+    array_dims, _ = jax.tree_util.tree_flatten(num_dimensions(pytree))
+    max_tree_dim = max(array_dims)
+
+    def pad_trailing_axes(array, dim):
+        return array.reshape(array.shape + (1,) * (max_tree_dim - dim))
+
+    return jnp.hstack([pad_trailing_axes(array, dim) for array, dim in zip(arrays, array_dims)])
+
+
+def extend_with_last_element(pytree: T, desired_length: int) -> T:
+    """
+    Pad every leaf along the leading axis up to `desired_length` by repeating its last entry.
+
+    Raises:
+        ValueError: if some leaf is not an array or if the data is already longer than `desired_length`.
+    """
+    if not check_if_pytree_of_arrays(pytree, allow_numbers=False):
+        raise ValueError('Should get a pytree of arrays.')
+    current_length = data_length(pytree, axis=0)
+    if current_length > desired_length:
+        raise ValueError('Data length is larger than desired length so it cannot be extended.')
+    n_missing = desired_length - current_length
+
+    def pad_leaf(leaf):
+        # block of `n_missing` copies of the last entry, with the same trailing shape as the leaf
+        padding = jnp.full((n_missing,) + tuple(leaf.shape[1:]), leaf[-1])
+        return jnp.append(leaf, padding, axis=0)
+
+    return jax.tree_util.tree_map(pad_leaf, pytree)

+ 478 - 0
curvature_assembly/simulation.py

@@ -0,0 +1,478 @@
+from __future__ import annotations
+import time
+from functools import partial
+import jax
+from jax import lax, jit, random
+from curvature_assembly import oriented_particle, data_protocols, cost_functions, util
+from jax_md import simulate, rigid_body, dataclasses, space, partition, quantity
+import jax.numpy as jnp
+from typing import Callable, TypeVar, Optional, Any
+import warnings
+import copy
+# import equinox
+
+
+Array = jnp.ndarray
+NeighborFn = partition.NeighborFn
+NeighborListFormat = partition.NeighborListFormat
+T = TypeVar('T')
+InitFn = Callable[..., T]
+ApplyFn = Callable[[T], T]
+RigidBody = rigid_body.RigidBody
+InteractionParams = data_protocols.InteractionParams
+P = TypeVar('P', bound=InteractionParams)
+
+
+@dataclasses.dataclass
+class NVTSimulationParams:
+    """
+    Container for NVT simulation parameters.
+    """
+
+    # number of particles; sizes the stored configuration buffers in the setup functions
+    num: int
+    # not read in the visible part of this module -- presumably the particle number density; TODO confirm
+    density: float
+    # total number of integration steps requested
+    simulation_steps: int
+    # time step handed to the jax_md integrators
+    dt: float
+    # thermostat temperature handed to the jax_md integrators
+    kT: float
+    # observables and configurations are stored once every `config_every` steps
+    config_every: int = 100
+    # number of steps in one section of truncated backpropagation through time
+    bptt_truncation: int = 500
+
+
+def get_higher_temp_equilibration_params(sim_params: NVTSimulationParams, new_kT: float) -> NVTSimulationParams:
+    """
+    Build a copy of `sim_params` for an equilibration run at temperature `new_kT`.
+
+    Besides `kT`, the BPTT truncation length is set to the full number of simulation steps.
+    """
+    params_dict = copy.deepcopy(vars(sim_params))
+    params_dict.update(bptt_truncation=sim_params.simulation_steps, kT=new_kT)
+    return NVTSimulationParams(**params_dict)
+
+
+@dataclasses.dataclass
+class SimulationLogNoseHoover:
+    """Preallocated record of temperature, energy, kinetic energy and the Nose-Hoover invariant."""
+    T: Array
+    E: Array
+    K: Array
+    H: Array
+    current_len: Array
+
+    @staticmethod
+    def create_empty(num_steps: int, save_every: int) -> SimulationLogNoseHoover:
+        """Allocate zero-filled buffers with `num_steps // save_every` slots and the write index at 0."""
+        n_entries = num_steps // save_every
+        T, E, K, H = (jnp.zeros(n_entries) for _ in range(4))
+        return SimulationLogNoseHoover(T, E, K, H, 0)
+
+    def calculate_values(self, state, energy_fn, ellipsoid_mass, kT, **params) -> (float, float, float, float):
+        """Evaluate temperature, potential energy, kinetic energy and the thermostat invariant for `state`."""
+        body, momentum = state.position, state.momentum
+        T = rigid_body.temperature(position=body, momentum=momentum, mass=ellipsoid_mass)
+        E = energy_fn(body, **params)
+        K = rigid_body.kinetic_energy(position=body, momentum=momentum, mass=ellipsoid_mass)
+        H = simulate.nvt_nose_hoover_invariant(energy_fn, state, kT, **params)
+        return T, E, K, H
+
+    def update(self, T: Array, E: Array, K: Array, H: Array) -> SimulationLogNoseHoover:
+        """Return a new log with the given values written at the current index and the index advanced by one."""
+        idx = self.current_len
+        return dataclasses.replace(self,
+                                   T=self.T.at[idx].set(T),
+                                   E=self.E.at[idx].set(E),
+                                   K=self.K.at[idx].set(K),
+                                   H=self.H.at[idx].set(H),
+                                   current_len=idx + 1)
+
+    def revert_last_nsteps(self, nsteps) -> SimulationLogNoseHoover:
+        """Return a new log whose write index is moved back by `nsteps`; stored values are left in place."""
+        return dataclasses.replace(self, current_len=self.current_len - nsteps)
+
+
+@dataclasses.dataclass
+class SimulationLogLangevin:
+    """Preallocated record of temperature, potential energy and kinetic energy."""
+    T: Array
+    E: Array
+    K: Array
+    current_len: Array
+
+    @staticmethod
+    def create_empty(num_steps: int, save_every: int) -> SimulationLogLangevin:
+        """Allocate zero-filled buffers with `num_steps // save_every` slots and the write index at 0."""
+        n_entries = num_steps // save_every
+        T, E, K = (jnp.zeros(n_entries) for _ in range(3))
+        return SimulationLogLangevin(T, E, K, 0)
+
+    def calculate_values(self, state, energy_fn, ellipsoid_mass, kT, **params) -> (float, float, float):
+        """Evaluate temperature, potential energy and kinetic energy for `state` (`kT` is not used here)."""
+        body, momentum = state.position, state.momentum
+        T = rigid_body.temperature(position=body, momentum=momentum, mass=ellipsoid_mass)
+        E = energy_fn(body, **params)
+        K = rigid_body.kinetic_energy(position=body, momentum=momentum, mass=ellipsoid_mass)
+        return T, E, K
+
+    def update(self, T: Array, E: Array, K: Array) -> SimulationLogLangevin:
+        """Return a new log with the given values written at the current index and the index advanced by one."""
+        idx = self.current_len
+        return dataclasses.replace(self,
+                                   T=self.T.at[idx].set(T),
+                                   E=self.E.at[idx].set(E),
+                                   K=self.K.at[idx].set(K),
+                                   current_len=idx + 1)
+
+    def revert_last_nsteps(self, nsteps) -> SimulationLogLangevin:
+        """Return a new log whose write index is moved back by `nsteps`; stored values are left in place."""
+        return dataclasses.replace(self, current_len=self.current_len - nsteps)
+
+
+@dataclasses.dataclass
+class SimulationStateHistory:
+    """Preallocated record of particle centers and orientation quaternion components."""
+    coord: Array
+    orient: Array
+    current_len: Array
+
+    @staticmethod
+    def create_empty(num_steps: int, n_particles: int, config_every: int) -> SimulationStateHistory:
+        """Allocate zero-filled buffers for `num_steps // config_every` frames and set the write index to 0."""
+        n_frames = num_steps // config_every
+        return SimulationStateHistory(jnp.zeros((n_frames, n_particles, 3)),
+                                      jnp.zeros((n_frames, n_particles, 4)),
+                                      0)
+
+    def update(self, coord: Array, orient: Array) -> SimulationStateHistory:
+        """Return a new history with one frame written at the current index and the index advanced by one."""
+        idx = self.current_len
+        return dataclasses.replace(self,
+                                   coord=self.coord.at[idx].set(coord),
+                                   orient=self.orient.at[idx].set(orient),
+                                   current_len=idx + 1)
+
+    def revert_last_nsteps(self, nsteps) -> SimulationStateHistory:
+        """Return a new history whose write index is moved back by `nsteps`; stored frames are left in place."""
+        return dataclasses.replace(self, current_len=self.current_len - nsteps)
+
+
+@dataclasses.dataclass
+class SimulationAux:
+    """Dataclass for simulation auxiliary data."""
+    # record of observables
+    log: data_protocols.SimulationLog
+    # record of particle configurations
+    state_history: data_protocols.SimulationStateHistory
+
+    def revert_last_nsteps(self, nsteps, config_every):
+        """Return a copy in which the write indices of the log and of the state history are moved back."""
+        # NOTE(review): the log index is moved back by `nsteps` while the history index is moved back by
+        # `nsteps // config_every`; `simulation_step` advances both once per `config_every` steps, so the
+        # two offsets look inconsistent -- confirm what unit `nsteps` is meant to be in
+        log = self.log.revert_last_nsteps(nsteps)
+        state_history = self.state_history.revert_last_nsteps(nsteps // config_every)
+        aux = dataclasses.replace(self, log=log)
+        aux = dataclasses.replace(aux, state_history=state_history)
+        return aux
+
+    def reset_empty(self) -> SimulationAux:
+        """
+        Set current_len attribute of SimulationLog and SimulationStateHistory classes to 0 which effectively resets
+        their empty state (current data will be overwritten in the next bptt simulation run).
+        """
+        # we use zeros_like() because of possible parallelization that adds an axis to current_len attribute
+        empty_log = dataclasses.replace(self.log, current_len=jnp.zeros_like(self.log.current_len))
+        empty_history = dataclasses.replace(self.state_history, current_len=jnp.zeros_like(self.state_history.current_len))
+        aux = dataclasses.replace(self, log=empty_log)
+        aux = dataclasses.replace(aux, state_history=empty_history)
+        return aux
+
+
+def setup_nose_hoover(energy: Callable,
+                      shift: space.ShiftFn,
+                      simulation_params: NVTSimulationParams,
+                      **nose_hoover_kwargs) -> (InitFn, ApplyFn, SimulationAux):
+    """
+    Create the init/step functions of the Nose-Hoover NVT integrator together with empty logging containers.
+
+    Extra keyword arguments are forwarded to `simulate.nvt_nose_hoover`.
+    """
+    n_steps = simulation_params.simulation_steps
+    save_interval = simulation_params.config_every
+
+    aux = SimulationAux(
+        log=SimulationLogNoseHoover.create_empty(n_steps, save_interval),
+        state_history=SimulationStateHistory.create_empty(n_steps, simulation_params.num, save_interval),
+    )
+    init_fn, step_fn = simulate.nvt_nose_hoover(energy, shift, simulation_params.dt, simulation_params.kT,
+                                                **nose_hoover_kwargs)
+    return init_fn, step_fn, aux
+
+
+def setup_langevin(energy: Callable,
+                      shift: space.ShiftFn,
+                      simulation_params: NVTSimulationParams,
+                      **langevin_kwargs) -> (InitFn, ApplyFn, SimulationAux):
+    """
+    Prepare functions and auxiliary data container for a molecular dynamics simulation using the Langevin thermostat.
+
+    Extra keyword arguments are forwarded to `simulate.nvt_langevin`.
+    """
+    # NOTE(review): a second `setup_langevin` (with an explicit `gamma` argument) is defined further down in
+    # this module and replaces this definition at import time
+
+    log = SimulationLogLangevin.create_empty(simulation_params.simulation_steps, simulation_params.config_every)
+    state_history = SimulationStateHistory.create_empty(simulation_params.simulation_steps,
+                                                        simulation_params.num,
+                                                        simulation_params.config_every)
+    aux = SimulationAux(log=log, state_history=state_history)
+    init_fn, step_fn = simulate.nvt_langevin(energy, shift, simulation_params.dt, simulation_params.kT,
+                                                **langevin_kwargs)
+
+    return init_fn, step_fn, aux
+
+
+def rescale_momenta_new_temperature(state: simulate.NVTNoseHooverState,
+                                    new_kT: float,
+                                    old_kT: float) -> simulate.NVTNoseHooverState:
+    """Return `state` with center and orientation momenta multiplied by ``sqrt(new_kT / old_kT)``."""
+    scale = jnp.sqrt(new_kT / old_kT)
+    momentum = state.momentum
+    rescaled_momentum = RigidBody(scale * momentum.center,
+                                  rigid_body.Quaternion(scale * momentum.orientation.vec))
+    return state.set(momentum=rescaled_momentum)
+
+
+def init_nose_hoover_new_temperature(state: simulate.NVTNoseHooverState,
+                                     new_kT: float,
+                                     old_kT: float,
+                                     dt: float,
+                                     chain_length: int = 5,
+                                     chain_steps: int = 2,
+                                     sy_steps: int = 3,
+                                     tau: Optional[float] = None) -> simulate.NVTNoseHooverState:
+    """
+    Adapt a Nose-Hoover state to a new temperature.
+
+    The momenta are rescaled with `rescale_momenta_new_temperature` and the thermostat chain is
+    re-initialized for `new_kT` from the kinetic energy of the rescaled state.
+
+    Args:
+        state: current simulation state.
+        new_kT: target temperature.
+        old_kT: temperature the state was simulated at.
+        dt: integration time step.
+        chain_length, chain_steps, sy_steps: forwarded to `simulate.nose_hoover_chain`.
+        tau: thermostat time scale forwarded to `simulate.nose_hoover_chain`; defaults to `100 * dt`.
+
+    Returns:
+        State with rescaled momenta and a freshly initialized chain.
+    """
+
+    dt = simulate.f32(dt)
+    if tau is None:
+        tau = dt * 100
+    tau = simulate.f32(tau)
+
+    thermostat = simulate.nose_hoover_chain(dt, chain_length, chain_steps, sy_steps, tau)
+
+    dof = quantity.count_dof(state.position)
+
+    # the chain must be initialized from the kinetic energy AFTER rescaling the momenta
+    state = rescale_momenta_new_temperature(state, new_kT, old_kT)
+    KE = simulate.kinetic_energy(state)
+    return state.set(chain=thermostat.initialize(dof, KE, new_kT))
+
+
+def setup_langevin(energy: Callable,
+                   shift: space.ShiftFn,
+                   simulation_params: NVTSimulationParams,
+                   gamma: RigidBody = RigidBody(0.1, 0.1)) -> (InitFn, ApplyFn, SimulationAux):
+    """
+    Create the init/step functions of the Langevin NVT integrator together with empty logging containers.
+
+    `gamma` is passed to `simulate.nvt_langevin` as its `gamma` keyword argument.
+    """
+    n_steps = simulation_params.simulation_steps
+    save_interval = simulation_params.config_every
+
+    aux = SimulationAux(
+        log=SimulationLogLangevin.create_empty(n_steps, save_interval),
+        state_history=SimulationStateHistory.create_empty(n_steps, simulation_params.num, save_interval),
+    )
+    init_fn, step_fn = simulate.nvt_langevin(energy, shift, simulation_params.dt, simulation_params.kT,
+                                             gamma=gamma)
+    return init_fn, step_fn, aux
+
+
+def ellipsoid_unit_mass(eigvals: Array):
+    """Evaluate `oriented_particle.ellipsoid_mass` for a mass of 1 and the given `eigvals`."""
+    unit_mass = jnp.array([1.])
+    return oriented_particle.ellipsoid_mass(unit_mass, eigvals)
+
+
+def simulation_step(state_aux_params: tuple[T, data_protocols.SimulationAux, InteractionParams],
+                    iteration_idx: int,
+                    step_fn: Callable,
+                    energy_fn: Callable,
+                    kT: float,
+                    config_every: int) -> (tuple[T, data_protocols.SimulationAux, InteractionParams], float):
+    """
+    Perform one simulation step and log the progress.
+
+    The signature follows the `lax.scan` body convention: the first argument is the carry, the second the
+    scanned value, and the return value is `(new_carry, per_step_output)`.
+
+    Args:
+        state_aux_params: carry tuple `(state, aux, params)`.
+        iteration_idx: index of the current step; decides whether this step is logged.
+        step_fn: integrator step function, called as `step_fn(state, **vars(params))`.
+        energy_fn: energy function handed to the log's `calculate_values`.
+        kT: temperature handed to the log's `calculate_values`.
+        config_every: logging interval in steps.
+
+    Returns:
+        The updated carry and a dummy per-step output `0.`.
+    """
+    state, aux, params = state_aux_params
+    log = aux.log
+    state_history = aux.state_history
+
+    # take a simulation step
+    # params must be passed as a dictionary
+    state = step_fn(state, **vars(params))
+
+    def update_aux(l, h):
+        # evaluated for the state AFTER the step taken above
+        new_log = l.update(*log.calculate_values(state,
+                                                 energy_fn,
+                                                 ellipsoid_unit_mass(params.eigvals),
+                                                 kT,
+                                                 **vars(params)))
+        new_history = h.update(state.position.center,
+                               state.position.orientation.vec)
+        return new_log, new_history
+
+    # log information about simulation as well as the state history
+    # (only on every `config_every`-th step; otherwise both containers pass through unchanged)
+    log, state_history = lax.cond((iteration_idx + 1) % config_every == 0,
+                                  update_aux,
+                                  lambda l, h: (l, h),
+                                  log, state_history)
+
+    aux = dataclasses.replace(aux, log=log)
+    aux = dataclasses.replace(aux, state_history=state_history)
+
+    return (state, aux, params), 0.
+
+
+def nvt_simulation_pair(init_fn: InitFn,
+                        step_fn: ApplyFn,
+                        aux: SimulationAux,
+                        energy: Callable[[...], Array],
+                        interaction_params: InteractionParams,
+                        simulation_params: NVTSimulationParams,
+                        body: RigidBody) -> (RigidBody, SimulationAux):
+    """
+    Run a complete NVT simulation with logging and return the final configuration.
+
+    Args:
+        init_fn: integrator initialization function.
+        step_fn: integrator step function.
+        aux: empty logging containers.
+        energy: energy function used for logging.
+        interaction_params: interaction parameters passed to the step and energy functions.
+        simulation_params: simulation settings (number of steps, temperature, logging interval).
+        body: initial particle configuration.
+
+    Returns:
+        Final particle configuration and the filled logging containers.
+    """
+
+    # set all particle masses to 1
+    ellipsoid_mass = oriented_particle.ellipsoid_mass(jnp.array([1.]), interaction_params.eigvals)
+
+    # setup simulation
+    # `simulation_step` derives the mass from `params.eigvals` itself and has no `ellipsoid_mass`
+    # argument, so only the keywords it declares are bound here
+    scan_step = partial(simulation_step, step_fn=step_fn,
+                        energy_fn=energy, kT=simulation_params.kT,
+                        config_every=simulation_params.config_every)
+
+    # initialize state
+    key = random.PRNGKey(0)
+    state = init_fn(key, body, mass=ellipsoid_mass)
+
+    @jit
+    def scan_to_jit(carry, step_indices):
+        new_carry, _ = lax.scan(scan_step, carry, step_indices)
+        return new_carry
+
+    # run simulation
+    print('Simulation start')
+    t0 = time.perf_counter()
+    # `simulation_step` expects and returns a (state, aux, params) carry
+    state, aux, _ = scan_to_jit((state, aux, interaction_params),
+                                jnp.arange(simulation_params.simulation_steps))
+    t1 = time.perf_counter()
+    print(f'Simulation time: {t1 - t0}')
+
+    return state.position, aux
+
+
+@dataclasses.dataclass
+class BPTTResults:
+    """Per-section gradients and costs collected during a truncated BPTT run."""
+    grad: InteractionParams
+    cost: Array
+    current_len: int
+
+    @staticmethod
+    def create_empty(interaction_params: InteractionParams, n_steps: int) -> BPTTResults:
+        """Allocate zero-filled storage for `n_steps` sections and set the write index to 0."""
+        return BPTTResults(empty_grad_results(interaction_params, n_steps),
+                           jnp.zeros((n_steps,)),
+                           0)
+
+    def update(self, grad: InteractionParams, cost: float) -> BPTTResults:
+        """Return new results with `grad` and `cost` written at the current index and the index advanced."""
+        idx = self.current_len
+        write_at_idx = partial(update_gradient_history, idx=idx)
+        return dataclasses.replace(self,
+                                   grad=jax.tree_util.tree_map(write_at_idx, self.grad, grad),
+                                   cost=self.cost.at[idx].set(cost),
+                                   current_len=idx + 1)
+
+
+def empty_grad_results(interaction_params: P, num_rep: int) -> P:
+    """
+    Build an object of the same type as `interaction_params` in which every attribute is a zero array
+    with an extra leading axis of length `num_rep`, ready to hold one gradient per truncated BPTT section.
+    """
+    def zero_history(value):
+        try:
+            value_shape = value.shape
+        except AttributeError:
+            # attributes without a shape (plain numbers) get a 1D history
+            value_shape = ()
+        return jnp.zeros((num_rep,) + value_shape)
+
+    history_dict = {key: zero_history(value) for key, value in vars(interaction_params).items()}
+    return type(interaction_params)(**history_dict)
+
+
+def update_gradient_history(history_array: Array, grad_value: Array, idx: int) -> Array:
+    """Return a copy of `history_array` in which entry `idx` is replaced by `grad_value`."""
+    slot = history_array.at[idx]
+    return slot.set(grad_value)
+
+
+def simple_forward_simulation(init_fn: InitFn,
+                              step_fn: ApplyFn,
+                              num_steps: int
+                              ) -> Callable[[InteractionParams, RigidBody, int, jax.random.PRNGKey], RigidBody]:
+    """
+    Build a bare forward MD simulation of `num_steps` steps without logging or configuration saving.
+
+    The returned function takes `(interaction_params, body, key)` and returns the final integrator state.
+    """
+
+    def simulation(interaction_params: InteractionParams,
+                   body: RigidBody,
+                   key: jax.random.PRNGKey):
+        param_kwargs = vars(interaction_params)
+
+        # initialize state
+        unit_mass = ellipsoid_unit_mass(interaction_params.eigvals)
+        initial_state = init_fn(key, body, mass=unit_mass, **param_kwargs)
+
+        def scan_step(carry, _):
+            return step_fn(carry, **param_kwargs), 0.
+
+        final_state, _ = lax.scan(scan_step, initial_state, jnp.arange(num_steps))
+        return final_state
+
+    return simulation
+
+
+def truncated_bptt_nvt_simulation(step_fn: ApplyFn,
+                                  energy: Callable[[...], Array],
+                                  cost_fn: cost_functions.CostFn,
+                                  simulation_params: NVTSimulationParams,
+                                  only_forward_calculation: bool = False) -> data_protocols.BpttSimulation:
+    """
+    Build an NVT simulation function that differentiates a cost with truncated backpropagation through time.
+
+    The simulation is divided into `simulation_steps // bptt_truncation` sections of `bptt_truncation` steps.
+    After each section the cost is evaluated on the current configuration and differentiated with respect to
+    the interaction parameters; the gradient is taken through that section only.
+
+    Args:
+        step_fn: integrator step function.
+        energy: energy function used for logging.
+        cost_fn: cost function, called as `cost_fn(state.position, **vars(params))`.
+        simulation_params: simulation settings.
+        only_forward_calculation: if True, no gradients are computed and zeros are stored in their place.
+
+    Returns:
+        Function `simulation(interaction_params, init_state, aux)` returning `(bptt_results, aux)`.
+
+    Raises:
+        ValueError: if the number of simulation steps is smaller than `config_every` or `bptt_truncation`.
+    """
+
+    # simulation setup
+    scan_step = partial(simulation_step, step_fn=step_fn,
+                                      energy_fn=energy, kT=simulation_params.kT,
+                                      config_every=simulation_params.config_every)
+
+    # loop_fn = partial(equinox.internal.scan, kind='checkpointed', checkpoints=10)
+
+    if simulation_params.simulation_steps < simulation_params.config_every:
+        raise ValueError(f'Number of simulation steps must be higher or equal to the config_every value, '
+                         f'got {simulation_params.simulation_steps} and {simulation_params.config_every}, '
+                         f'respectively')
+
+    n_iterations = simulation_params.simulation_steps // simulation_params.bptt_truncation
+    if n_iterations == 0:
+        raise ValueError(f'Number of simulation steps must be equal to or grater than BPTT truncation, '
+                         f'got {simulation_params.simulation_steps} and {simulation_params.bptt_truncation}, '
+                         f'respectively.')
+    if n_iterations * simulation_params.bptt_truncation < simulation_params.simulation_steps:
+        warnings.warn(f'Only {n_iterations * simulation_params.bptt_truncation} time steps will be calculated '
+                      f'as bptt truncation length does not divide the desired number of steps exactly.')
+
+    def forward_function(params: InteractionParams, state, aux):
+        # one BPTT section: run `bptt_truncation` steps, then evaluate the cost on the final configuration
+        # NOTE(review): the step index restarts at 0 in every section, so logging via `config_every` is
+        # counted per section rather than over the whole run
+        (state, aux, _), _ = lax.scan(scan_step, (state, aux, params), jnp.arange(simulation_params.bptt_truncation))
+        cost = cost_fn(state.position, **vars(params))
+        return cost, (state, aux)
+
+    # differentiate with respect to `params` only; `argnums` is a tuple, so the gradient is a 1-tuple
+    grad_fn = jax.value_and_grad(forward_function, has_aux=True, argnums=(0,))
+
+    def bptt_section(state_aux_params_results, i):
+        state, aux, params, results = state_aux_params_results
+        value, grad = grad_fn(params, state, aux)
+        cost, (state, aux) = value
+        results = results.update(grad[0], cost)
+        return (state, aux, params, results), 0.
+
+    def bptt_section_test(state_aux_params_results, i):
+        # forward-only variant: same bookkeeping, zeros stored in place of gradients
+        state, aux, params, results = state_aux_params_results
+        value = forward_function(params, state, aux)
+        cost, (state, aux) = value
+        results = results.update(jax.tree_util.tree_map(jnp.zeros_like, params), cost)
+        return (state, aux, params, results), 0.
+
+    if only_forward_calculation:
+        bptt_section = bptt_section_test
+
+    def simulation(interaction_params: InteractionParams,
+                   init_state: Any,
+                   aux: data_protocols.SimulationAux
+                   ):
+
+        # initialize state
+        # state = init_fn(key, body, mass=ellipsoid_unit_mass(interaction_params.eigvals), **vars(interaction_params))
+
+        # initialize results object
+        bptt_results = BPTTResults.create_empty(interaction_params, n_iterations)
+
+        # run simulation
+        state_aux_results, _  = jax.lax.scan(bptt_section,
+                                             (init_state, aux, interaction_params, bptt_results),
+                                             xs=jnp.arange(n_iterations))
+        state, aux, params, bptt_results = state_aux_results
+
+        return bptt_results, aux
+
+    return simulation

+ 118 - 0
curvature_assembly/smap.py

@@ -0,0 +1,118 @@
+import jax
+from curvature_assembly.oriented_particle import OrientedParticleEnergy, eigensystem
+import jax.numpy as jnp
+from functools import partial
+from typing import Callable
+from jax_md import space, smap, util, partition, rigid_body
+
+Array = jnp.ndarray
+
+
+def oriented_pair(fn: OrientedParticleEnergy,
+                  displacement: space.DisplacementFn,
+                  ignore_unused_parameters: bool = False,
+                  **kwargs) -> Callable[..., Array]:
+    """
+    Promotes a function that acts on a pair of ellipses to one on a system.
+
+    Args:
+        fn: energy function that takes distance, eigensystem1, eigensystem2 as first three arguments.
+        displacement: displacement function that calculates distances between particles.
+        ignore_unused_parameters: A boolean that denotes whether dynamically specified keyword arguments
+            passed to the mapped function get ignored if they were not first specified as keyword arguments
+            when calling `oriented_pair(...)`.
+        kwargs: arguments providing parameters to the mapped function.
+
+    Return:
+        A function fn_mapped that takes a RigidBody object.
+    """
+
+    # NOTE(review): `param_combinators` is only referenced by commented-out code below, so per-particle
+    # parameter combination is currently not applied in this mapping
+    kwargs, param_combinators = smap._split_params_and_combinators(kwargs)
+    merge_dicts = partial(util.merge_dicts, ignore_unused_parameters=ignore_unused_parameters)
+
+    def fn_mapped(body: rigid_body.RigidBody, **dynamic_kwargs) -> Array:
+        # index pairs (i, j) with i < j, so every unordered pair is evaluated exactly once
+        rows, columns = jnp.triu_indices(body.center.shape[0], 1)
+        particle1 = body[rows]
+        particle2 = body[columns]
+        dr = jax.vmap(partial(displacement, **dynamic_kwargs))(particle1.center, particle2.center)
+        eigsys1 = eigensystem(particle1.orientation)
+        eigsys2 = eigensystem(particle2.orientation)
+
+        _kwargs = merge_dicts(kwargs, dynamic_kwargs)
+        # _kwargs = smap._kwargs_to_parameters(None, _kwargs, param_combinators)
+
+        all_pair_interctions = jax.vmap(partial(fn, **_kwargs))(dr, eigsys1, eigsys2)
+        return util.high_precision_sum(all_pair_interctions)
+
+    # def fn_mapped(body: rigid_body.RigidBody, **dynamic_kwargs) -> Array:
+    #     # this does not give the same results as the above fn_mapped, but it should?
+    #     d = space.map_product(partial(displacement, **dynamic_kwargs))
+    #     eigsys = eigensystem(body.orientation)
+    #     _kwargs = merge_dicts(kwargs, dynamic_kwargs)
+    #     _kwargs = smap._kwargs_to_parameters(None, _kwargs, param_combinators)
+    #     # print(_kwargs)
+    #     dr = d(body.center, body.center)
+    #     meshx, meshy = jnp.meshgrid(jnp.arange(body.center.shape[0]), jnp.arange(body.center.shape[0]))
+    #     eigsys1 = eigsys[meshx]
+    #     eigsys2 = eigsys[meshy]
+    #     # print(dr.shape, eigsys1, eigsys2)
+    #     return util.high_precision_sum(smap._diagonal_mask(fn(dr, eigsys1, eigsys2, **_kwargs)),
+    #                               axis=None, keepdims=False) * util.f32(0.5)
+
+    return fn_mapped
+
+
+def oriented_pair_neighbor_list(fn: OrientedParticleEnergy,
+                                displacement: space.DisplacementFn,
+                                ignore_unused_parameters: bool = False,
+                                **kwargs) -> Callable[..., Array]:
+    """
+    Promotes a function acting on pairs of particles to use neighbor lists.
+
+    Args:
+        fn: energy function that takes distance, eigensystem1, eigensystem2 as first three arguments.
+        displacement: displacement function that calculates distances between particles.
+        ignore_unused_parameters: A boolean that denotes whether dynamically specified keyword arguments
+            passed to the mapped function get ignored if they were not first specified as keyword arguments
+            when calling `oriented_pair_neighbor_list(...)`.
+        kwargs: arguments providing parameters to the mapped function.
+
+    Return:
+        A function `fn_mapped` that takes a RigidBody object and a NeighborList object specifying neighbors.
+
+    Raises:
+        NotImplementedError: (from `fn_mapped`) if the neighbor list is not in a sparse format.
+    """
+
+    kwargs, param_combinators = smap._split_params_and_combinators(kwargs)
+    merge_dicts = partial(util.merge_dicts, ignore_unused_parameters=ignore_unused_parameters)
+
+    def fn_mapped(body: rigid_body.RigidBody, neighbor: partition.NeighborList, **dynamic_kwargs) -> Array:
+
+        # the sum is halved by default; presumably because the plain sparse format lists every pair in
+        # both orders -- TODO confirm against the jax_md neighbor list formats
+        normalization = 2.0
+
+        if partition.is_sparse(neighbor.format):
+            particle1 = body[neighbor.idx[0]]
+            particle2 = body[neighbor.idx[1]]
+            dr = jax.vmap(partial(displacement, **dynamic_kwargs))(particle1.center, particle2.center)
+            eigsys1 = eigensystem(particle1.orientation)
+            eigsys2 = eigensystem(particle2.orientation)
+
+            mask = neighbor.idx[0] < body.center.shape[0]  # takes care of fill values in neighbor lists
+            if neighbor.format is partition.OrderedSparse:
+                normalization = 1.0
+        else:
+            raise NotImplementedError('Only sparse neighbor lists are currently supported.')
+
+        merged_kwargs = merge_dicts(kwargs, dynamic_kwargs)
+        merged_kwargs = smap._neighborhood_kwargs_to_params(neighbor.format,
+                                                            neighbor.idx,
+                                                            None,
+                                                            merged_kwargs,
+                                                            param_combinators)
+        out = jax.vmap(partial(fn, **merged_kwargs))(dr, eigsys1, eigsys2)
+        # append trailing singleton axes to the mask so it broadcasts against a multi-dimensional output
+        if out.ndim > mask.ndim:
+            ddim = out.ndim - mask.ndim
+            mask = jnp.reshape(mask, mask.shape + (1,) * ddim)
+        out *= mask
+
+        return util.high_precision_sum(out) / normalization
+
+    return fn_mapped

+ 965 - 0
curvature_assembly/spherical_harmonics.py

@@ -0,0 +1,965 @@
+import jax.numpy as jnp
+from typing import Callable
+import jax
+from functools import partial
+
+
+Array = jnp.ndarray
+
+
+def neg_m(sph_harm: Callable, m: int) -> Callable:
+    """Wrap a positive-order harmonic so that it evaluates the matching negative-order one.
+
+    Args:
+        sph_harm: function of `x` returning the harmonic of order `m`.
+        m: order of `sph_harm`.
+
+    Returns:
+        A function of `x` returning `(-1) ** m * conj(sph_harm(x))`.
+    """
+    # Parentheses are required: `-1 ** m` parses as `-(1 ** m)` and is -1 for every m.
+    sign = (-1) ** m
+
+    def wrapped(x):
+        return sign * jnp.conj(sph_harm(x))
+    return wrapped
+
+
+def Y00(x):
+    """Complex spherical harmonic l=0, m=0: a constant cast to complex128; `x` is not used."""
+    return jax.lax.convert_element_type(0.5 * jnp.sqrt(1 / jnp.pi), new_dtype=jnp.complex128)
+
+
+def Y10(x):
+    """Complex spherical harmonic l=1, m=0 of x = (x, y, z), cast to complex128; presumably x is a unit vector."""
+    return jax.lax.convert_element_type(0.5 * jnp.sqrt(3 / jnp.pi) * x[2], new_dtype=jnp.complex128)
+
+
+def Y11(x):
+    """Complex spherical harmonic l=1, m=1 of x = (x, y, z); presumably x is a unit vector."""
+    return -0.5 * jnp.sqrt(1.5 / jnp.pi) * (x[0] + 1j * x[1])
+
+
+def Y20(x):
+    """Complex spherical harmonic l=2, m=0 of x = (x, y, z), cast to complex128; presumably x is a unit vector."""
+    # The prefactor 0.25 * sqrt(5 / pi) evaluates to approx. 0.3153915652525201.
+    return jax.lax.convert_element_type(0.25 * jnp.sqrt(5 / jnp.pi) * (3 * x[2] ** 2 - 1),
+                                        new_dtype=jnp.complex128)
+
+
+def Y21(x):
+    """Complex spherical harmonic l=2, m=1 of x = (x, y, z); presumably x is a unit vector."""
+    # The prefactor -0.5 * sqrt(7.5 / pi) evaluates to approx. -0.7725484040463791.
+    return -0.5 * jnp.sqrt(7.5 / jnp.pi) * (x[0] + 1j * x[1]) * x[2]
+
+
+def Y22(x):
+    """Complex spherical harmonic l=2, m=2 of x = (x, y, z); presumably x is a unit vector."""
+    # The prefactor 0.25 * sqrt(7.5 / pi) evaluates to approx. 0.3862742020231896.
+    return 0.25 * jnp.sqrt(7.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 2
+
+
+
+def Y30(x):
+    """Complex spherical harmonic l=3, m=0 of x = (x, y, z), cast to complex128; presumably x is a unit vector."""
+    return jax.lax.convert_element_type(0.25 * jnp.sqrt(7 / jnp.pi) * (5 * x[2] ** 3 - 3 * x[2]),
+                                        new_dtype=jnp.complex128)
+
+
+def Y31(x):
+    """Complex spherical harmonic l=3, m=1 of x = (x, y, z); presumably x is a unit vector."""
+    return -0.125 * jnp.sqrt(21 / jnp.pi) * (x[0] + 1j * x[1]) * (5 * x[2] ** 2 - 1)
+
+
+def Y32(x):
+    """Complex spherical harmonic l=3, m=2 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.25 * jnp.sqrt(52.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 2 * x[2]
+
+
+def Y33(x):
+    """Complex spherical harmonic l=3, m=3 of x = (x, y, z); presumably x is a unit vector."""
+    return -0.125 * jnp.sqrt(35 / jnp.pi) * (x[0] + 1j * x[1]) ** 3
+
+
+def Y40(x):
+    """Complex spherical harmonic l=4, m=0 of x = (x, y, z), cast to complex128; presumably x is a unit vector."""
+    # The prefactor 3 / 16 * sqrt(1 / pi) evaluates to approx. 0.1057855469152043.
+    return jax.lax.convert_element_type(3 / 16 * jnp.sqrt(1 / jnp.pi) * (35 * x[2] ** 4 - 30 * x[2] ** 2 + 3),
+                                        new_dtype=jnp.complex128)
+
+
+def Y41(x):
+    """Complex spherical harmonic l=4, m=1 of x = (x, y, z); presumably x is a unit vector."""
+    # The prefactor -3 / 8 * sqrt(5 / pi) evaluates to approx. -0.47308734787878.
+    return -3 / 8 * jnp.sqrt(5 / jnp.pi) * (x[0] + 1j * x[1]) * (7 * x[2] ** 3 - 3 * x[2])
+
+
+def Y42(x):
+    """Complex spherical harmonic l=4, m=2 of x = (x, y, z); presumably x is a unit vector."""
+    # The prefactor 3 / 8 * sqrt(2.5 / pi) evaluates to approx. 0.3345232717786446.
+    return 3 / 8 * jnp.sqrt(2.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 2 * (7 * x[2] ** 2 -1)
+
+
+def Y43(x):
+    """Complex spherical harmonic l=4, m=3 of x = (x, y, z); presumably x is a unit vector."""
+    # The prefactor -3 / 8 * sqrt(35 / pi) evaluates to approx. -1.251671470898352 (note the minus sign).
+    return -3 / 8 * jnp.sqrt(35 / jnp.pi) * (x[0] + 1j * x[1]) ** 3 * x[2]
+
+
+def Y44(x):
+    """Complex spherical harmonic l=4, m=4 of x = (x, y, z); presumably x is a unit vector."""
+    # The prefactor 3 / 16 * sqrt(17.5 / pi) evaluates to approx. 0.4425326924449826.
+    return 3 / 16 * jnp.sqrt(17.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 4
+
+
+def Y50(x):
+    """Complex spherical harmonic l=5, m=0 of x = (x, y, z), cast to complex128; presumably x is a unit vector."""
+    return jax.lax.convert_element_type(1 / 16 * jnp.sqrt(11 / jnp.pi) *
+                                        (63 * x[2] ** 5 - 70 * x[2] ** 3 + 15 * x[2]),
+                                        new_dtype=jnp.complex128)
+
+
+def Y51(x):
+    """Complex spherical harmonic l=5, m=1 of x = (x, y, z); presumably x is a unit vector."""
+    return -1 / 16 * jnp.sqrt(82.5 / jnp.pi) *\
+        (x[0] + 1j * x[1]) * (21 * x[2] ** 4 - 14 * x[2] ** 2 + 1)
+
+
+def Y52(x):
+    """Complex spherical harmonic l=5, m=2 of x = (x, y, z); presumably x is a unit vector."""
+    return 1 / 8 * jnp.sqrt(577.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 2 * (3 * x[2] ** 3 - x[2])
+
+
+def Y53(x):
+    """Complex spherical harmonic l=5, m=3 of x = (x, y, z); presumably x is a unit vector."""
+    return -1 / 32 * jnp.sqrt(385 / jnp.pi) * (x[0] + 1j * x[1]) ** 3 * (9 * x[2] ** 2 - 1)
+
+
+def Y54(x):
+    """Complex spherical harmonic l=5, m=4 of x = (x, y, z); presumably x is a unit vector."""
+    return 3 / 16 * jnp.sqrt(192.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 4 * x[2]
+
+
+def Y55(x):
+    """Complex spherical harmonic l=5, m=5 of x = (x, y, z); presumably x is a unit vector."""
+    return -3 / 32 * jnp.sqrt(77 / jnp.pi) * (x[0] + 1j * x[1]) ** 5
+
+
+def Y60(x):
+    """Complex spherical harmonic l=6, m=0 of x = (x, y, z), cast to complex128; presumably x is a unit vector."""
+    return jax.lax.convert_element_type(1 / 32 * jnp.sqrt(13 / jnp.pi) *
+                                        (231 * x[2] ** 6 - 315 * x[2] ** 4 + 105 * x[2] ** 2 - 5),
+                                        new_dtype=jnp.complex128)
+
+
+def Y61(x):
+    """Complex spherical harmonic l=6, m=1 of x = (x, y, z); presumably x is a unit vector."""
+    return -1 / 16 * jnp.sqrt(136.5 / jnp.pi) *\
+        (x[0] + 1j * x[1]) * (33 * x[2] ** 5 - 30 * x[2] ** 3 + 5 * x[2])
+
+
+def Y62(x):
+    """Complex spherical harmonic l=6, m=2 of x = (x, y, z); presumably x is a unit vector."""
+    return 1 / 64 * jnp.sqrt(1365 / jnp.pi) *\
+        (x[0] + 1j * x[1]) ** 2 * (33 * x[2] ** 4 - 18 * x[2] ** 2 + 1)
+
+
+def Y63(x):
+    """Complex spherical harmonic l=6, m=3 of x = (x, y, z); presumably x is a unit vector."""
+    return -1 / 32 * jnp.sqrt(1365 / jnp.pi) * (x[0] + 1j * x[1]) ** 3 * (11 * x[2] ** 3 - 3 * x[2])
+
+
+def Y64(x):
+    """Complex spherical harmonic l=6, m=4 of x = (x, y, z); presumably x is a unit vector."""
+    return 3 / 32 * jnp.sqrt(45.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 4 * (11 * x[2] ** 2 - 1)
+
+
+def Y65(x):
+    """Complex spherical harmonic l=6, m=5 of x = (x, y, z); presumably x is a unit vector."""
+    return -3 / 32 * jnp.sqrt(1001 / jnp.pi) * (x[0] + 1j * x[1]) ** 5 * x[2]
+
+
+def Y66(x):
+    """Complex spherical harmonic l=6, m=6 of x = (x, y, z); presumably x is a unit vector."""
+    return 1 / 64 * jnp.sqrt(3003 / jnp.pi) * (x[0] + 1j * x[1]) ** 6
+
+
+def Y70(x):
+    """Complex spherical harmonic l=7, m=0 of x = (x, y, z), cast to complex128; presumably x is a unit vector."""
+    return jax.lax.convert_element_type(1 / 32 * jnp.sqrt(15 / jnp.pi) *
+                                        (429 * x[2] ** 7 - 693 * x[2] ** 5 + 315 * x[2] ** 3 - 35 * x[2]),
+                                        new_dtype=jnp.complex128)
+
+
+def Y71(x):
+    """Complex spherical harmonic l=7, m=1 of x = (x, y, z); presumably x is a unit vector."""
+    return -1 / 64 * jnp.sqrt(52.5 / jnp.pi) *\
+        (x[0] + 1j * x[1]) * (429 * x[2] ** 6 - 495 * x[2] ** 4 + 135 * x[2] ** 2 - 5)
+
+
+def Y72(x):
+    """Complex spherical harmonic l=7, m=2 of x = (x, y, z); presumably x is a unit vector."""
+    return 3 / 64 * jnp.sqrt(35 / jnp.pi) * \
+        (x[0] + 1j * x[1]) ** 2 * (143 * x[2] ** 5 - 110 * x[2] ** 3 + 15 * x[2])
+
+
+def Y73(x):
+    """Complex spherical harmonic l=7, m=3 of x = (x, y, z); presumably x is a unit vector."""
+    return -3 / 64 * jnp.sqrt(17.5 / jnp.pi) * \
+        (x[0] + 1j * x[1]) ** 3 * (143 * x[2] ** 4 - 66 * x[2] ** 2 + 3)
+
+
+def Y74(x):
+    """Complex spherical harmonic l=7, m=4 of x = (x, y, z); presumably x is a unit vector."""
+    return 3 / 32 * jnp.sqrt(192.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 4 * (13 * x[2] ** 3 - 3 * x[2])
+
+
+def Y75(x):
+    """Complex spherical harmonic l=7, m=5 of x = (x, y, z); presumably x is a unit vector."""
+    return -3 / 64 * jnp.sqrt(192.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 5 * (13 * x[2] ** 2 - 1)
+
+
+def Y76(x):
+    """Complex spherical harmonic l=7, m=6 of x = (x, y, z); presumably x is a unit vector."""
+    return 3 / 64 * jnp.sqrt(5005 / jnp.pi) * (x[0] + 1j * x[1]) ** 6 * x[2]
+
+
+def Y77(x):
+    """Complex spherical harmonic l=7, m=7 of x = (x, y, z); presumably x is a unit vector."""
+    return -3 / 64 * jnp.sqrt(357.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 7
+
+
+def get_sph_function(l, m):
+    """Look up the jitted complex spherical harmonic of degree `l` and order `m`.
+
+    Args:
+        l: degree of the harmonic.
+        m: order of the harmonic, with |m| <= l.
+
+    Returns:
+        `jax.jit` of the entry `l * (l + 1) + m` of `sph_harm_list`, or the integer 0 when (l, m) is not
+        tabulated (|m| > l or the degree exceeds the table).
+    """
+    index = l * (l + 1) + m
+    # The flat index, not the degree, has to be compared with the table length; comparing `l` itself
+    # lets untabulated degrees through and raises IndexError on the lookup.
+    if abs(m) <= l and index < len(sph_harm_list):
+        return jax.jit(sph_harm_list[index])
+    return 0
+
+
+# @partial(jax.jit, static_argnums=(1, 2))
+# def sph_harm(x: Array, l: Array, m: Array) -> Array:
+#     return jnp.stack([get_sph_function(sl, sm)(x) for sl, sm in zip(l, m)])
+
+
+def sph_harm_fn(l: tuple, m: tuple) -> Callable[[Array,], Array]:
+    """Create a function that evaluates and stacks the complex harmonics (l[i], m[i]) at a point.
+
+    Args:
+        l: degrees of the requested harmonics.
+        m: orders of the requested harmonics, paired element-wise with `l`.
+
+    Returns:
+        A function of `x` returning the harmonic values stacked in the order of the (l, m) pairs.
+    """
+
+    def f(x: Array):
+        values = []
+        for degree, order in zip(l, m):
+            harmonic = get_sph_function(degree, order)
+            values.append(harmonic(x))
+        return jnp.stack(values)
+
+    return f
+
+
+def sph_harm_fn_custom(l: tuple, m: tuple) -> Callable[[Array,], Array]:
+    """Create a stacked complex-harmonics function with a hand-written reverse-mode derivative.
+
+    Args:
+        l: degrees of the requested harmonics.
+        m: orders of the requested harmonics, paired element-wise with `l`.
+
+    Returns:
+        A `jax.custom_vjp` function of `x` returning the stacked harmonic values.
+
+    NOTE(review): the derivative rules read the neighbouring entries of the stacked output as the
+    m + 1 and m - 1 harmonics, so (l, m) presumably has to list every order of each degree in
+    increasing m - confirm against callers.
+    """
+
+    l_array = jnp.array(l)
+    m_array = jnp.array(m)
+
+    # @jax.custom_jvp
+    @jax.custom_vjp
+    def f(x: Array):
+        return jnp.stack([get_sph_function(sl, sm)(x) for sl, sm in zip(l, m)])
+
+    # @f.defjvp
+    # Forward-mode rule; with the decorator above commented out it is defined but never registered.
+    def sph_harm_jvp(primals, tangents):
+        x, = primals
+        dx, = tangents
+
+        primal_out = f(x)
+
+        # Pad with one zero at each end so that the shifted slices [2:] and [:-2] line up with primal_out.
+        extanded_primal = jnp.zeros((l_array.shape[0] + 2,), dtype=primal_out.dtype)
+        extanded_primal = extanded_primal.at[1:-1].set(primal_out)
+        # The 1e-8 offset keeps the divisions by rho finite when x[0] = x[1] = 0.
+        rho2 = x[0] ** 2 + x[1] ** 2 + 1e-8
+        rho = jnp.sqrt(rho2)
+
+        coeffs1 = (x[0] - 1j * x[1]) / rho * jnp.sqrt((l_array - m_array) * (l_array + m_array + 1))
+        coeffs2 = (x[0] + 1j * x[1]) / rho * jnp.sqrt((l_array - m_array + 1) * (l_array + m_array))
+        theta_derivatives = 0.5 * (coeffs1 * extanded_primal[2:] + coeffs2 * extanded_primal[:-2])
+        phi_derivatives = 1j * m_array * primal_out
+
+        # Combine the two angular derivatives into derivatives along the Cartesian components of x.
+        x_derivatives = theta_derivatives * x[0] * x[2] / rho - phi_derivatives * x[1] / rho2
+        y_derivatives = theta_derivatives * x[1] * x[2] / rho + phi_derivatives * x[0] / rho2
+        z_derivatives = -theta_derivatives * rho
+
+        jacobian = jnp.array([x_derivatives, y_derivatives, z_derivatives])
+        tangent_out = jacobian.T @ dx
+        return primal_out, tangent_out
+
+    # Forward pass of the custom VJP: returns the values and keeps the Jacobian as the residual.
+    def sph_harm_fwd(x):
+        primal_out = f(x)
+
+        # Pad with one zero at each end so that the shifted slices [2:] and [:-2] line up with primal_out.
+        extanded_primal = jnp.zeros((l_array.shape[0] + 2,), dtype=primal_out.dtype)
+        extanded_primal = extanded_primal.at[1:-1].set(primal_out)
+        # The 1e-8 offset keeps the divisions by rho finite when x[0] = x[1] = 0.
+        rho2 = x[0] ** 2 + x[1] ** 2 + 1e-8
+        rho = jnp.sqrt(rho2)
+
+        coeffs1 = (x[0] - 1j * x[1]) / rho * jnp.sqrt((l_array - m_array) * (l_array + m_array + 1))
+        coeffs2 = (x[0] + 1j * x[1]) / rho * jnp.sqrt((l_array - m_array + 1) * (l_array + m_array))
+        theta_derivatives = 0.5 * (coeffs1 * extanded_primal[2:] + coeffs2 * extanded_primal[:-2])
+        phi_derivatives = 1j * m_array * primal_out
+
+        x_derivatives = theta_derivatives * x[0] * x[2] / rho - phi_derivatives * x[1] / rho2
+        y_derivatives = theta_derivatives * x[1] * x[2] / rho + phi_derivatives * x[0] / rho2
+        z_derivatives = -theta_derivatives * rho
+
+        # One row per harmonic, one column per Cartesian component.
+        jacobian = jnp.array([x_derivatives, y_derivatives, z_derivatives]).T
+        return primal_out, jacobian
+
+    # Backward pass of the custom VJP: contracts the output cotangent with the stored Jacobian.
+    def sph_harm_rev(jacobian, y_bar):
+        return (y_bar @ jacobian,)
+
+    f.defvjp(sph_harm_fwd, sph_harm_rev)
+
+    return f
+
+# Complex harmonics for l = 0..7, one row per degree, ordered so that entry l * (l + 1) + m holds (l, m).
+# Negative orders are built from the positive-order functions through neg_m.
+sph_harm_list = [Y00,
+                 neg_m(Y11, 1), Y10, Y11,
+                 neg_m(Y22, 2), neg_m(Y21, 1), Y20, Y21, Y22,
+                 neg_m(Y33, 3), neg_m(Y32, 2), neg_m(Y31, 1), Y30, Y31, Y32, Y33,
+                 neg_m(Y44, 4), neg_m(Y43, 3), neg_m(Y42, 2), neg_m(Y41, 1), Y40, Y41, Y42, Y43, Y44,
+                 neg_m(Y55, 5), neg_m(Y54, 4), neg_m(Y53, 3), neg_m(Y52, 2), neg_m(Y51, 1), Y50, Y51, Y52, Y53, Y54, Y55,
+                 neg_m(Y66, 6), neg_m(Y65, 5), neg_m(Y64, 4), neg_m(Y63, 3), neg_m(Y62, 2), neg_m(Y61, 1), Y60, Y61, Y62, Y63, Y64, Y65, Y66,
+                 neg_m(Y77, 7), neg_m(Y76, 6), neg_m(Y75, 5), neg_m(Y74, 4), neg_m(Y73, 3), neg_m(Y72, 2), neg_m(Y71, 1), Y70, Y71, Y72, Y73, Y74, Y75, Y76, Y77]
+
+
+def Y00real(x):
+    """Real spherical harmonic l=0, m=0: a constant; `x` is not used."""
+    return 0.5 * jnp.sqrt(1 / jnp.pi)
+
+
+def Y1m1real(x):
+    """Real spherical harmonic l=1, m=-1 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.5 * jnp.sqrt(3 / jnp.pi) * x[1]
+
+
+def Y10real(x):
+    """Real spherical harmonic l=1, m=0 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.5 * jnp.sqrt(3 / jnp.pi) * x[2]
+
+
+def Y11real(x):
+    """Real spherical harmonic l=1, m=1 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.5 * jnp.sqrt(3 / jnp.pi) * x[0]
+
+
+def Y2m2real(x):
+    """Real spherical harmonic l=2, m=-2 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.5 * jnp.sqrt(15 / jnp.pi) * x[0] * x[1]
+
+
+def Y2m1real(x):
+    """Real spherical harmonic l=2, m=-1 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.5 * jnp.sqrt(15 / jnp.pi) * x[2] * x[1]
+
+
+def Y20real(x):
+    """Real spherical harmonic l=2, m=0 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.25 * jnp.sqrt(5 / jnp.pi) * (3 * x[2] ** 2 - 1)
+
+
+def Y21real(x):
+    """Real spherical harmonic l=2, m=1 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.5 * jnp.sqrt(15 / jnp.pi) * x[2] * x[0]
+
+
+def Y22real(x):
+    """Real spherical harmonic l=2, m=2 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.25 * jnp.sqrt(15 / jnp.pi) * (x[0] ** 2 - x[1] ** 2)
+
+
+def Y3m3real(x):
+    """Real spherical harmonic l=3, m=-3 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.25 * jnp.sqrt(17.5 / jnp.pi) * x[1] * (3 * x[0] ** 2 - x[1] ** 2)
+
+
+def Y3m2real(x):
+    """Real spherical harmonic l=3, m=-2 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.5 * jnp.sqrt(105 / jnp.pi) * x[0] * x[1] * x[2]
+
+
+def Y3m1real(x):
+    """Real spherical harmonic l=3, m=-1 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.25 * jnp.sqrt(10.5 / jnp.pi) * x[1] * (5 * x[2] ** 2 - 1)
+
+
+def Y30real(x):
+    """Real spherical harmonic l=3, m=0 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.25 * jnp.sqrt(7 / jnp.pi) * (5 * x[2] ** 3 - 3 * x[2])
+
+
+def Y31real(x):
+    """Real spherical harmonic l=3, m=1 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.25 * jnp.sqrt(10.5 / jnp.pi) * x[0] * (5 * x[2] ** 2 - 1)
+
+
+def Y32real(x):
+    """Real spherical harmonic l=3, m=2 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.25 * jnp.sqrt(105 / jnp.pi) * (x[0] ** 2 - x[1] ** 2) * x[2]
+
+
+def Y33real(x):
+    """Real spherical harmonic l=3, m=3 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.25 * jnp.sqrt(17.5 / jnp.pi) * x[0] * (x[0] ** 2 - 3 * x[1] ** 2)
+
+
+def Y4m4real(x):
+    """Real spherical harmonic l=4, m=-4 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.75 * jnp.sqrt(35 / jnp.pi) * x[0] * x[1] * (x[0] ** 2 - x[1] ** 2)
+
+
+def Y4m3real(x):
+    """Real spherical harmonic l=4, m=-3 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.75 * jnp.sqrt(17.5 / jnp.pi) * x[1] * (3 * x[0] ** 2 - x[1] ** 2) * x[2]
+
+
+def Y4m2real(x):
+    """Real spherical harmonic l=4, m=-2 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.75 * jnp.sqrt(5 / jnp.pi) * x[0] * x[1] * (7 * x[2] ** 2 - 1)
+
+
+def Y4m1real(x):
+    """Real spherical harmonic l=4, m=-1 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.75 * jnp.sqrt(2.5 / jnp.pi) * x[1] * (7 * x[2] ** 3 - 3 * x[2])
+
+
+def Y40real(x):
+    """Real spherical harmonic l=4, m=0 of x = (x, y, z); presumably x is a unit vector."""
+    return 3 / 16 * jnp.sqrt(1 / jnp.pi) * (35 * x[2] ** 4 - 30 * x[2] ** 2 + 3)
+
+
+def Y41real(x):
+    """Real spherical harmonic l=4, m=1 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.75 * jnp.sqrt(2.5 / jnp.pi) * x[0] * (7 * x[2] ** 3 - 3 * x[2])
+
+
+def Y42real(x):
+    """Real spherical harmonic l=4, m=2 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.375 * jnp.sqrt(5 / jnp.pi) * (x[0] ** 2 - x[1] ** 2) * (7 * x[2] ** 2 - 1)
+
+
+def Y43real(x):
+    """Real spherical harmonic l=4, m=3 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.75 * jnp.sqrt(17.5 / jnp.pi) * x[0] * (x[0] ** 2 - 3 * x[1] ** 2) * x[2]
+
+
+def Y44real(x):
+    """Real spherical harmonic l=4, m=4 of x = (x, y, z); presumably x is a unit vector."""
+    return 0.1875 * jnp.sqrt(35 / jnp.pi) * (x[0] ** 2 * (x[0] ** 2 - 3 * x[1] ** 2) - x[1] ** 2 * (3 * x[0] ** 2 - x[1] ** 2))
+
+
+def Y5m5real(x):
+    """Real spherical harmonic l=5, m=-5 of x = (x, y, z); presumably x is a unit vector."""
+    return -3 / 16 * jnp.sqrt(38.5 / jnp.pi) * (5 * x[0] ** 4 * x[1] - 10 * x[0] ** 2 * x[1] ** 3 + x[1] ** 5)
+
+
+def Y5m4real(x):
+    """Real spherical harmonic l=5, m=-4 of x = (x, y, z); presumably x is a unit vector."""
+    return 3 / 4 * jnp.sqrt(385 / jnp.pi) * x[0] * x[1] * (x[1] ** 2 - x[0] ** 2) * x[2]
+
+
+def Y5m3real(x):
+    """Real spherical harmonic l=5, m=-3 of x = (x, y, z); presumably x is a unit vector."""
+    return -1 / 16 * jnp.sqrt(192.5 / jnp.pi) * x[1] * (3 * x[0] ** 2 - x[1] ** 2) * (9 * x[2] ** 2 - 1)
+
+
+def Y5m2real(x):
+    """Real spherical harmonic l=5, m=-2 of x = (x, y, z); presumably x is a unit vector."""
+    return 1 / 4 * jnp.sqrt(1155 / jnp.pi) * x[0] * x[1] * (3 * x[2] ** 3 - x[2])
+
+
+def Y5m1real(x):
+    """Real spherical harmonic l=5, m=-1 of x = (x, y, z); presumably x is a unit vector."""
+    return -1 / 16 * jnp.sqrt(165 / jnp.pi) * x[1] * (21 * x[2] ** 4 - 14 * x[2] ** 2 + 1)
+
+
+def Y50real(x):
+    """Real spherical harmonic l=5, m=0 of x = (x, y, z); presumably x is a unit vector."""
+    return 1 / 16 * jnp.sqrt(11 / jnp.pi) * (63 * x[2] ** 5 - 70 * x[2] ** 3 + 15 * x[2])
+
+
+def Y51real(x):
+    """Real spherical harmonic l=5, m=1 of x = (x, y, z); presumably x is a unit vector."""
+    return -1 / 16 * jnp.sqrt(165 / jnp.pi) * x[0] * (21 * x[2] ** 4 - 14 * x[2] ** 2 + 1)
+
+
+def Y52real(x):
+    """Real spherical harmonic l=5, m=2 of x = (x, y, z); presumably x is a unit vector."""
+    return 1 / 8 * jnp.sqrt(1155 / jnp.pi) * (x[0] ** 2 - x[1] ** 2) * (3 * x[2] ** 3 - x[2])
+
+
+def Y53real(x):
+    """Real spherical harmonic l=5, m=3 of x = (x, y, z); presumably x is a unit vector."""
+    return -1 / 16 * jnp.sqrt(192.5 / jnp.pi) * x[0] * (x[0] ** 2 - 3 * x[1] ** 2) * (9 * x[2] ** 2 - 1)
+
+
+def Y54real(x):
+    """Real spherical harmonic l=5, m=4 of x = (x, y, z); presumably x is a unit vector."""
+    return 3 / 16 * jnp.sqrt(385 / jnp.pi) * (x[0] ** 4 - 6 * x[0] ** 2 * x[1] ** 2 + x[1] ** 4) * x[2]
+
+
+def Y55real(x):
+    """Real spherical harmonic l=5, m=5 of x = (x, y, z); presumably x is a unit vector."""
+    return -3 / 16 * jnp.sqrt(38.5 / jnp.pi) * (x[0] ** 5 - 10 * x[0] ** 3 * x[1] ** 2 + 5 * x[0] * x[1] ** 4)
+
+
+def Y6m6real(x):
+    """Real spherical harmonic l=6, m=-6 of x = (x, y, z); presumably x is a unit vector."""
+    return 1 / 16 * jnp.sqrt(1501.5 / jnp.pi) * x[0] * x[1] * (-3 * x[0] ** 4 + 10 * x[0] ** 2 * x[1] ** 2 - 3 * x[1] ** 4)
+
+
+def Y6m5real(x):
+    """Real spherical harmonic l=6, m=-5 of x = (x, y, z); presumably x is a unit vector."""
+    return -3 / 16 * jnp.sqrt(500.5 / jnp.pi) * (5 * x[0] ** 4 * x[1] - 10 * x[0] ** 2 * x[1] ** 3 + x[1] ** 5) * x[2]
+
+
+def Y6m4real(x):
+    """Real spherical harmonic l=6, m=-4 of x = (x, y, z); presumably x is a unit vector."""
+    return 3 / 8 * jnp.sqrt(91 / jnp.pi) *  x[0] * x[1] * (x[1] ** 2 - x[0] ** 2) * (11 * x[2] ** 2 - 1)
+
+
+def Y6m3real(x):
+    """Real spherical harmonic l=6, m=-3 of x = (x, y, z); presumably x is a unit vector."""
+    return -1 / 16 * jnp.sqrt(682.5 / jnp.pi) * x[1] * (3 * x[0] ** 2 - x[1] ** 2) * (11 * x[2] ** 3 - 3 * x[2])
+
+
+def Y6m2real(x):
+    """Real spherical harmonic l=6, m=-2 of x = (x, y, z); presumably x is a unit vector."""
+    return 1 / 16 * jnp.sqrt(682.5 / jnp.pi) * x[0] * x[1] * (33 * x[2] ** 4 - 18 * x[2] ** 2 + 1)
+
+
+def Y6m1real(x):
+    """Real spherical harmonic l=6, m=-1 of x = (x, y, z); presumably x is a unit vector."""
+    return -1 / 16 * jnp.sqrt(273 / jnp.pi) * x[1] * (33 * x[2] ** 5 - 30 * x[2] ** 3 + 5 * x[2])
+
+
+def Y60real(x):
+    """Real spherical harmonic l=6, m=0 of x = (x, y, z); presumably x is a unit vector."""
+    return 1 / 32 * jnp.sqrt(13 / jnp.pi) * (231 * x[2] ** 6 - 315 * x[2] ** 4 + 105 * x[2] ** 2 - 5)
+
+
+def Y61real(x):
+    """Real spherical harmonic l=6, m=1 of x = (x, y, z); presumably x is a unit vector."""
+    return -1 / 16 * jnp.sqrt(273 / jnp.pi) *  x[0] * (33 * x[2] ** 5 - 30 * x[2] ** 3 + 5 * x[2])
+
+
+def Y62real(x):
+    """Real spherical harmonic l=6, m=2 of x = (x, y, z); presumably x is a unit vector."""
+    return 1 / 32 * jnp.sqrt(682.5 / jnp.pi) * (x[0] ** 2 - x[1] ** 2) * (33 * x[2] ** 4 - 18 * x[2] ** 2 + 1)
+
+
+def Y63real(x):
+    """Real spherical harmonic l=6, m=3 of x = (x, y, z); presumably x is a unit vector."""
+    return -1 / 16 * jnp.sqrt(682.5 / jnp.pi) * x[0] * (x[0] ** 2 - 3 * x[1] ** 2) * (11 * x[2] ** 3 - 3 * x[2])
+
+
+def Y64real(x):
+    """Real spherical harmonic l=6, m=4 of x = (x, y, z); presumably x is a unit vector."""
+    return 3 / 32 * jnp.sqrt(91 / jnp.pi) * (x[0] ** 4 - 6 * x[0] ** 2 * x[1] ** 2 + x[1] ** 4) * (11 * x[2] ** 2 - 1)
+
+
+def Y65real(x):
+    """Real spherical harmonic l=6, m=5 of x = (x, y, z); presumably x is a unit vector."""
+    return -3 / 16 * jnp.sqrt(500.5 / jnp.pi) * (x[0] ** 5 - 10 * x[0] ** 3 * x[1] ** 2 + 5 * x[0] * x[1] ** 4) * x[2]
+
+
+def Y66real(x):
+    """Real spherical harmonic l=6, m=6 of x = (x, y, z); presumably x is a unit vector."""
+    return 1 / 32 * jnp.sqrt(1501.5 / jnp.pi) * (x[0] ** 6 - 15 * x[0] ** 4 * x[1] ** 2 + 15 * x[0] ** 2 * x[1] ** 4 - x[1] ** 6)
+
+
+def get_real_sph_function(l, m):
+    """Look up the jitted real spherical harmonic of degree `l` and order `m`.
+
+    Args:
+        l: degree of the harmonic.
+        m: order of the harmonic, with |m| <= l.
+
+    Returns:
+        `jax.jit` of the entry `l * (l + 1) + m` of `real_sph_harm_list`, or the integer 0 when (l, m)
+        is not tabulated (|m| > l or the degree exceeds the table).
+    """
+    index = l * (l + 1) + m
+    # Bound the flat index by the real-harmonics table itself: it is shorter than the complex table,
+    # so checking against `sph_harm_list` lets l = 7 through and raises IndexError on the lookup.
+    if abs(m) <= l and index < len(real_sph_harm_list):
+        return jax.jit(real_sph_harm_list[index])
+    return 0
+
+
+@partial(jax.jit, static_argnums=(1, 2))
+def real_sph_harm(x: Array, l: Array, m: Array) -> Array:
+    """Evaluate and stack the real harmonics (l[i], m[i]) at `x`.
+
+    `l` and `m` are static jit arguments and are iterated in Python while tracing, so they are
+    presumably passed as hashable sequences such as tuples - confirm against callers.
+    """
+    values = []
+    for degree, order in zip(l, m):
+        harmonic = get_real_sph_function(degree, order)
+        values.append(harmonic(x))
+    return jnp.stack(values)
+
+
+def real_sph_harm_fn_custom_fwd(l_max: int) -> Callable[[Array,], Array]:
+    """Create a function stacking all real harmonics up to degree `l_max`, with a hand-written JVP rule.
+
+    Args:
+        l_max: largest degree; harmonics (l, m) for l = 0..l_max and m = -l..l are stacked in that order.
+
+    Returns:
+        A `jax.custom_jvp` function of `x` returning the stacked harmonic values.
+    """
+
+    l_list = list(range(0, l_max + 1))
+    lm_list = []
+    for l in l_list:
+        for m in range(-l, l + 1):
+            lm_list.append((l, m))
+
+    l_list, m_list = zip(*lm_list)
+
+    l_array = jnp.array(l_list)
+    m_array = jnp.array(m_list)
+
+    # indices where derivative rules differ from the general case
+    # (explicit integer dtype so that the empty lists produced for l_max = 0 are still valid indices)
+    m_one_indices = jnp.array([l * (l + 1) + 1 for l in range(0, l_max + 1) if l > 0], dtype=int)
+    m_zero_indices = jnp.array([l * (l + 1) for l in range(0, l_max + 1)], dtype=int)
+    m_minus_one_indices = jnp.array([l * (l + 1) - 1 for l in range(0, l_max + 1) if l > 0], dtype=int)
+
+    # The factor arrays must be floating point: jnp.sign of the integer m_array is an integer array, and
+    # a sqrt(2) written into it through .at[...].set would be cast back to integer, i.e. truncated to 1.
+    sign_m = jnp.sign(m_array).astype(jnp.result_type(float))
+    m_plus_one_factors = sign_m
+    m_minus_one_factors = -sign_m
+    minus_m_plus_one_factors = sign_m
+    minus_m_minus_one_factors = sign_m
+
+    # m = -1, 0, 1 special cases:
+    m_plus_one_factors = m_plus_one_factors.at[m_zero_indices].set(jnp.sqrt(2))
+    m_minus_one_factors = m_minus_one_factors.at[m_zero_indices].set(jnp.sqrt(2))
+
+    m_minus_one_factors = m_minus_one_factors.at[m_one_indices].set(-jnp.sqrt(2))
+    minus_m_minus_one_factors = minus_m_minus_one_factors.at[m_one_indices].set(0.)
+
+    m_plus_one_factors = m_plus_one_factors.at[m_minus_one_indices].set(0)
+    minus_m_plus_one_factors = minus_m_plus_one_factors.at[m_minus_one_indices].set(-jnp.sqrt(2))
+
+    @jax.custom_jvp
+    def f(x: Array):
+        return jnp.stack([get_real_sph_function(sl, sm)(x) for sl, sm in zip(l_list, m_list)])
+
+    @f.defjvp
+    def sph_harm_jvp(primals, tangents):
+        x, = primals
+        dx, = tangents
+
+        primal_out = f(x)
+
+        # Pad with one zero at each end so that the shifted slices [2:] and [:-2] line up with primal_out.
+        extanded_primal = jnp.zeros((l_array.shape[0] + 2,), dtype=primal_out.dtype)
+        mirrored_extanded_primal = jnp.zeros((l_array.shape[0] + 2,), dtype=primal_out.dtype)
+        extanded_primal = extanded_primal.at[1:-1].set(primal_out)
+        # NOTE(review): the reversal runs over the whole stacked vector rather than per degree - confirm
+        # that this pairs each (l, m) with the intended (l, -m) entry when l_max > 0.
+        mirrored_extanded_primal = mirrored_extanded_primal.at[1:-1].set(primal_out[::-1])
+
+        # The 1e-8 offset keeps the divisions by rho finite when x[0] = x[1] = 0.
+        rho2 = x[0] ** 2 + x[1] ** 2 + 1e-8
+        rho = jnp.sqrt(rho2)
+
+        coeffs1 = 1 / rho * jnp.sqrt((l_array - m_array) * (l_array + m_array + 1))
+        coeffs2 = 1 / rho * jnp.sqrt((l_array - m_array + 1) * (l_array + m_array))
+
+        theta_derivatives = 0.5 * (coeffs1 * (x[0] * m_plus_one_factors * extanded_primal[2:] +
+                                              x[1] * minus_m_plus_one_factors * mirrored_extanded_primal[2:]) +
+                                   coeffs2 * (x[0] * m_minus_one_factors * extanded_primal[:-2] +
+                                              x[1] * minus_m_minus_one_factors * mirrored_extanded_primal[:-2]))
+        phi_derivatives = m_array * primal_out
+
+        # Combine the two angular derivatives into derivatives along the Cartesian components of x.
+        x_derivatives = theta_derivatives * x[0] * x[2] / rho - phi_derivatives * x[1] / rho2
+        y_derivatives = theta_derivatives * x[1] * x[2] / rho + phi_derivatives * x[0] / rho2
+        z_derivatives = -theta_derivatives * rho
+
+        jacobian = jnp.array([x_derivatives, y_derivatives, z_derivatives])
+        tangent_out = jacobian.T @ dx
+        return primal_out, tangent_out
+
+    return f
+
+
+def real_sph_harm_fn_custom_rev(l_max: int) -> Callable[[Array, ], Array]:
+    """Create a function stacking all real harmonics up to degree `l_max`, with a hand-written VJP rule.
+
+    Args:
+        l_max: largest degree; harmonics (l, m) for l = 0..l_max and m = -l..l are stacked in that order.
+
+    Returns:
+        A `jax.custom_vjp` function of `x` returning the stacked harmonic values.
+    """
+    l_list = list(range(0, l_max + 1))
+    lm_list = []
+    for l in l_list:
+        for m in range(-l, l + 1):
+            lm_list.append((l, m))
+
+    l_list, m_list = zip(*lm_list)
+
+    l_array = jnp.array(l_list)
+    m_array = jnp.array(m_list)
+
+    # indices where derivative rules differ from the general case
+    # (explicit integer dtype so that the empty lists produced for l_max = 0 are still valid indices)
+    m_one_indices = jnp.array([l * (l + 1) + 1 for l in range(0, l_max + 1) if l > 0], dtype=int)
+    m_zero_indices = jnp.array([l * (l + 1) for l in range(0, l_max + 1)], dtype=int)
+    m_minus_one_indices = jnp.array([l * (l + 1) - 1 for l in range(0, l_max + 1) if l > 0], dtype=int)
+
+    # The factor arrays must be floating point: jnp.sign of the integer m_array is an integer array, and
+    # a sqrt(2) written into it through .at[...].set would be cast back to integer, i.e. truncated to 1.
+    sign_m = jnp.sign(m_array).astype(jnp.result_type(float))
+    m_plus_one_factors = sign_m
+    m_minus_one_factors = -sign_m
+    minus_m_plus_one_factors = sign_m
+    minus_m_minus_one_factors = sign_m
+
+    # m = -1, 0, 1 special cases:
+    m_plus_one_factors = m_plus_one_factors.at[m_zero_indices].set(jnp.sqrt(2))
+    m_minus_one_factors = m_minus_one_factors.at[m_zero_indices].set(jnp.sqrt(2))
+
+    m_minus_one_factors = m_minus_one_factors.at[m_one_indices].set(-jnp.sqrt(2))
+    minus_m_minus_one_factors = minus_m_minus_one_factors.at[m_one_indices].set(0.)
+
+    m_plus_one_factors = m_plus_one_factors.at[m_minus_one_indices].set(0)
+    minus_m_plus_one_factors = minus_m_plus_one_factors.at[m_minus_one_indices].set(-jnp.sqrt(2))
+
+    @jax.custom_vjp
+    def f(x: Array):
+        return jnp.stack([get_real_sph_function(sl, sm)(x) for sl, sm in zip(l_list, m_list)])
+
+    # Forward pass of the custom VJP: returns the values and keeps the Jacobian as the residual.
+    def sph_harm_fwd(x):
+        primal_out = f(x)
+
+        # Pad with one zero at each end so that the shifted slices [2:] and [:-2] line up with primal_out.
+        extanded_primal = jnp.zeros((l_array.shape[0] + 2,), dtype=primal_out.dtype)
+        mirrored_extanded_primal = jnp.zeros((l_array.shape[0] + 2,), dtype=primal_out.dtype)
+        extanded_primal = extanded_primal.at[1:-1].set(primal_out)
+        # NOTE(review): the reversal runs over the whole stacked vector rather than per degree - confirm
+        # that this pairs each (l, m) with the intended (l, -m) entry when l_max > 0.
+        mirrored_extanded_primal = mirrored_extanded_primal.at[1:-1].set(primal_out[::-1])
+
+        # The 1e-8 offset keeps the divisions by rho finite when x[0] = x[1] = 0.
+        rho2 = x[0] ** 2 + x[1] ** 2 + 1e-8
+        rho = jnp.sqrt(rho2)
+
+        coeffs1 = 1 / rho * jnp.sqrt((l_array - m_array) * (l_array + m_array + 1))
+        coeffs2 = 1 / rho * jnp.sqrt((l_array - m_array + 1) * (l_array + m_array))
+
+        theta_derivatives = 0.5 * (coeffs1 * (x[0] * m_plus_one_factors * extanded_primal[2:] +
+                                              x[1] * minus_m_plus_one_factors * mirrored_extanded_primal[2:]) +
+                                   coeffs2 * (x[0] * m_minus_one_factors * extanded_primal[:-2] +
+                                              x[1] * minus_m_minus_one_factors * mirrored_extanded_primal[:-2]))
+        phi_derivatives = m_array * primal_out
+
+        # Combine the two angular derivatives into derivatives along the Cartesian components of x.
+        x_derivatives = theta_derivatives * x[0] * x[2] / rho - phi_derivatives * x[1] / rho2
+        y_derivatives = theta_derivatives * x[1] * x[2] / rho + phi_derivatives * x[0] / rho2
+        z_derivatives = -theta_derivatives * rho
+
+        # One row per harmonic, one column per Cartesian component.
+        jacobian = jnp.array([x_derivatives, y_derivatives, z_derivatives]).T
+
+        return primal_out, jacobian
+
+    # Backward pass of the custom VJP: contracts the output cotangent with the stored Jacobian.
+    def sph_harm_rev(jacobian, y_bar):
+        return (y_bar @ jacobian,)
+
+    f.defvjp(sph_harm_fwd, sph_harm_rev)
+
+    return f
+
+
+# Real harmonics for l = 0..6, one row per degree, ordered so that entry l * (l + 1) + m holds (l, m).
+# NOTE(review): the l = 5 and l = 6 functions carry an explicit minus sign for odd m while the l <= 4
+# functions do not - confirm that the mixed phase convention is intended.
+real_sph_harm_list = [Y00real,
+                      Y1m1real, Y10real, Y11real,
+                      Y2m2real, Y2m1real, Y20real, Y21real, Y22real,
+                      Y3m3real, Y3m2real, Y3m1real, Y30real, Y31real, Y32real, Y33real,
+                      Y4m4real, Y4m3real, Y4m2real, Y4m1real, Y40real, Y41real, Y42real, Y43real, Y44real,
+                      Y5m5real, Y5m4real, Y5m3real, Y5m2real, Y5m1real, Y50real, Y51real, Y52real, Y53real, Y54real, Y55real,
+                      Y6m6real, Y6m5real, Y6m4real, Y6m3real, Y6m2real, Y6m1real, Y60real, Y61real, Y62real, Y63real, Y64real, Y65real, Y66real]
+
+
+# Spherical harmonic prefactors up to l=7, one entry per (l, m) at index l * (l + 1) + m, grouped by
+# degree. Positive odd orders carry a minus sign, matching the Y<l><m> functions of this module.
+ylm_prefactors = jnp.array([0.5 * jnp.sqrt(1 / jnp.pi),
+
+                            0.5 * jnp.sqrt(1.5 / jnp.pi),
+                            0.5 * jnp.sqrt(3 / jnp.pi),
+                            -0.5 * jnp.sqrt(1.5 / jnp.pi),
+
+                            0.25 * jnp.sqrt(7.5 / jnp.pi),
+                            0.5 * jnp.sqrt(7.5 / jnp.pi),
+                            0.25 * jnp.sqrt(5 / jnp.pi),
+                            -0.5 * jnp.sqrt(7.5 / jnp.pi),
+                            0.25 * jnp.sqrt(7.5 / jnp.pi),
+
+                            0.125 * jnp.sqrt(35 / jnp.pi),
+                            0.25 * jnp.sqrt(52.5 / jnp.pi),
+                            0.125 * jnp.sqrt(21 / jnp.pi),
+                            0.25 * jnp.sqrt(7 / jnp.pi),
+                            -0.125 * jnp.sqrt(21 / jnp.pi),
+                            0.25 * jnp.sqrt(52.5 / jnp.pi),
+                            -0.125 * jnp.sqrt(35 / jnp.pi),
+
+                            3 / 16 * jnp.sqrt(17.5 / jnp.pi),
+                            3 / 8 * jnp.sqrt(35 / jnp.pi),
+                            3 / 8 * jnp.sqrt(2.5 / jnp.pi),
+                            3 / 8 * jnp.sqrt(5 / jnp.pi),
+                            3 / 16 * jnp.sqrt(1 / jnp.pi),
+                            -3 / 8 * jnp.sqrt(5 / jnp.pi),
+                            3 / 8 * jnp.sqrt(2.5 / jnp.pi),
+                            -3 / 8 * jnp.sqrt(35 / jnp.pi),
+                            3 / 16 * jnp.sqrt(17.5 / jnp.pi),
+
+                            3 / 32 * jnp.sqrt(77 / jnp.pi),
+                            3 / 16 * jnp.sqrt(192.5 / jnp.pi),
+                            1 / 32 * jnp.sqrt(385 / jnp.pi),
+                            1 / 8 * jnp.sqrt(577.5 / jnp.pi),
+                            1 / 16 * jnp.sqrt(82.5 / jnp.pi),
+                            1 / 16 * jnp.sqrt(11 / jnp.pi),
+                            -1 / 16 * jnp.sqrt(82.5 / jnp.pi),
+                            1 / 8 * jnp.sqrt(577.5 / jnp.pi),
+                            -1 / 32 * jnp.sqrt(385 / jnp.pi),
+                            3 / 16 * jnp.sqrt(192.5 / jnp.pi),
+                            -3 / 32 * jnp.sqrt(77 / jnp.pi),
+
+                            1 / 64 * jnp.sqrt(3003 / jnp.pi),
+                            3 / 32 * jnp.sqrt(1001 / jnp.pi),
+                            3 / 32 * jnp.sqrt(45.5 / jnp.pi),
+                            1 / 32 * jnp.sqrt(1365 / jnp.pi),
+                            1 / 64 * jnp.sqrt(1365 / jnp.pi),
+                            1 / 16 * jnp.sqrt(136.5 / jnp.pi),
+                            1 / 32 * jnp.sqrt(13 / jnp.pi),
+                            -1 / 16 * jnp.sqrt(136.5 / jnp.pi),
+                            1 / 64 * jnp.sqrt(1365 / jnp.pi),
+                            -1 / 32 * jnp.sqrt(1365 / jnp.pi),
+                            3 / 32 * jnp.sqrt(45.5 / jnp.pi),
+                            -3 / 32 * jnp.sqrt(1001 / jnp.pi),
+                            1 / 64 * jnp.sqrt(3003 / jnp.pi),
+
+                            3 / 64 * jnp.sqrt(357.5 / jnp.pi),
+                            3 / 64 * jnp.sqrt(5005 / jnp.pi),
+                            3 / 64 * jnp.sqrt(192.5 / jnp.pi),
+                            3 / 32 * jnp.sqrt(192.5 / jnp.pi),
+                            3 / 64 * jnp.sqrt(17.5 / jnp.pi),
+                            3 / 64 * jnp.sqrt(35 / jnp.pi),
+                            1 / 64 * jnp.sqrt(52.5 / jnp.pi),
+                            1 / 32 * jnp.sqrt(15 / jnp.pi),
+                            -1 / 64 * jnp.sqrt(52.5 / jnp.pi),
+                            3 / 64 * jnp.sqrt(35 / jnp.pi),
+                            -3 / 64 * jnp.sqrt(17.5 / jnp.pi),
+                            3 / 32 * jnp.sqrt(192.5 / jnp.pi),
+                            -3 / 64 * jnp.sqrt(192.5 / jnp.pi),
+                            3 / 64 * jnp.sqrt(5005 / jnp.pi),
+                            # l=7, m=7 is an odd positive order, so it is negative like Y77.
+                            -3 / 64 * jnp.sqrt(357.5 / jnp.pi),
+                            ])
+
+# Coefficient array up to l=7: one row per (l, m) at index l * (l + 1) + m, grouped by degree.
+# Column k holds the coefficient of x[2] ** k in the polynomial factor of the corresponding harmonic
+# (e.g. the l=2, m=0 row [-1, 0, 3, ...] encodes 3 * x[2] ** 2 - 1).
+z_coef_array = jnp.array([[1, 0, 0, 0, 0, 0, 0, 0],
+                          
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [-1, 0, 3, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [-1, 0, 5, 0, 0, 0, 0, 0],
+                          [0, -3, 0, 5, 0, 0, 0, 0],
+                          [-1, 0, 5, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [-1, 0, 7, 0, 0, 0, 0, 0],
+                          [0, -3, 0, 7, 0, 0, 0, 0],
+                          [3, 0, -30, 0, 35, 0, 0, 0],
+                          [0, -3, 0, 7, 0, 0, 0, 0],
+                          [-1, 0, 7, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [-1, 0, 9, 0, 0, 0, 0, 0],
+                          [0, -1, 0, 3, 0, 0, 0, 0],
+                          [1, 0, -14, 0, 21, 0, 0, 0],
+                          [0, 15, 0, -70, 0, 63, 0, 0],
+                          [1, 0, -14, 0, 21, 0, 0, 0],
+                          [0, -1, 0, 3, 0, 0, 0, 0],
+                          [-1, 0, 9, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [-1, 0, 11, 0, 0, 0, 0, 0],
+                          [0, -3, 0, 11, 0, 0, 0, 0],
+                          [1, 0, -18, 0, 33, 0, 0, 0],
+                          [0, 5, 0, -30, 0, 33, 0, 0],
+                          [-5, 0, 105, 0, -315, 0, 231, 0],
+                          [0, 5, 0, -30, 0, 33, 0, 0],
+                          [1, 0, -18, 0, 33, 0, 0, 0],
+                          [0, -3, 0, 11, 0, 0, 0, 0],
+                          [-1, 0, 11, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [-1, 0, 13, 0, 0, 0, 0, 0],
+                          [0, -3, 0, 13, 0, 0, 0, 0],
+                          [3, 0, -66, 0, 143, 0, 0, 0],
+                          [0, 15, 0, -110, 0, 143, 0, 0],
+                          [-5, 0, 135, 0, -495, 0, 429, 0],
+                          [0, -35, 0, 315, 0, -693, 0, 429],
+                          [-5, 0, 135, 0, -495, 0, 429, 0],
+                          [0, 15, 0, -110, 0, 143, 0, 0],
+                          [3, 0, -66, 0, 143, 0, 0, 0],
+                          [0, -3, 0, 13, 0, 0, 0, 0],
+                          [-1, 0, 13, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          ])
+
+
+# One-hot rows, one per (l, m) at index l * (l + 1) + m, grouped by degree: rows with m > 0 have their
+# single 1 in column m, all other rows in column 0.
+# NOTE(review): presumably column k selects the power (x[0] + 1j * x[1]) ** k of the harmonic - confirm
+# against the code that consumes this table.
+x_p_iy_coefs = jnp.array([[1, 0, 0, 0, 0, 0, 0, 0],
+
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [0, 0, 1, 0, 0, 0, 0, 0],
+
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [0, 0, 1, 0, 0, 0, 0, 0],
+                          [0, 0, 0, 1, 0, 0, 0, 0],
+
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [0, 0, 1, 0, 0, 0, 0, 0],
+                          [0, 0, 0, 1, 0, 0, 0, 0],
+                          [0, 0, 0, 0, 1, 0, 0, 0],
+
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [0, 0, 1, 0, 0, 0, 0, 0],
+                          [0, 0, 0, 1, 0, 0, 0, 0],
+                          [0, 0, 0, 0, 1, 0, 0, 0],
+                          [0, 0, 0, 0, 0, 1, 0, 0],
+
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [0, 0, 1, 0, 0, 0, 0, 0],
+                          [0, 0, 0, 1, 0, 0, 0, 0],
+                          [0, 0, 0, 0, 1, 0, 0, 0],
+                          [0, 0, 0, 0, 0, 1, 0, 0],
+                          [0, 0, 0, 0, 0, 0, 1, 0],
+
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [1, 0, 0, 0, 0, 0, 0, 0],
+                          [0, 1, 0, 0, 0, 0, 0, 0],
+                          [0, 0, 1, 0, 0, 0, 0, 0],
+                          [0, 0, 0, 1, 0, 0, 0, 0],
+                          [0, 0, 0, 0, 1, 0, 0, 0],
+                          [0, 0, 0, 0, 0, 1, 0, 0],
+                          [0, 0, 0, 0, 0, 0, 1, 0],
+                          [0, 0, 0, 0, 0, 0, 0, 1],
+                          ])
+
+
+# Selector table for the powers of (x - i*y) used by sph_harm_not_fast.
+# Rows are ordered by degree 0..7 and, within a degree, by order -degree..degree,
+# i.e. the row index is degree * (degree + 1) + order.
+# Each row is one-hot over the powers 0..7: power |order| for negative orders, power 0 otherwise.
+x_m_iy_coefs = jnp.array([[int(power == max(-order, 0)) for power in range(8)]
+                          for degree in range(8)
+                          for order in range(-degree, degree + 1)])
+
+
+def sph_harm_not_fast(l: Array, m: Array) -> Callable:
+    """
+    Build a function that evaluates one polynomial term per (l, m) pair for 3D input vectors.
+
+    Each pair selects row l * (l + 1) + m of the module-level tables ylm_prefactors, z_coef_array,
+    x_p_iy_coefs and x_m_iy_coefs. The returned function maps a vector (x, y, z) to
+    prefactor * (polynomial in z) * (selected power of x + iy) * (selected power of x - iy)
+    for every pair and is vectorized over leading axes of its argument.
+    Presumably these are spherical harmonics evaluated on unit vectors -- confirm against the tables.
+
+    If all m are non-negative, a non-jitted variant is returned that raises (x + iy) directly to the power m;
+    otherwise the general variant is returned wrapped in jax.jit.
+    """
+    table_rows = l * (l + 1) + m  # l ** 2 + l + m
+    n_powers = jnp.max(l) + 1
+
+    norm = ylm_prefactors[table_rows]
+    z_poly = z_coef_array[table_rows, :n_powers]
+    exponents = jnp.arange(n_powers, dtype=jnp.int32)
+
+    if jnp.all(m >= 0):
+        @partial(jnp.vectorize, signature='(d)->(k)')
+        def only_non_negative_m(x):
+            z_term = z_poly @ (x[2] ** exponents)
+            return norm * z_term * (x[0] + 1j * x[1]) ** m
+
+        return only_non_negative_m
+
+    plus_selector = x_p_iy_coefs[table_rows, :n_powers]
+    minus_selector = x_m_iy_coefs[table_rows, :n_powers]
+
+    @partial(jnp.vectorize, signature='(d)->(k)')
+    def general_m(x):
+        z_term = z_poly @ (x[2] ** exponents)
+        plus_powers = (x[0] + 1j * x[1]) ** exponents
+        minus_powers = jnp.conj(plus_powers)
+        return norm * z_term * (plus_selector @ plus_powers) * (minus_selector @ minus_powers)
+
+    return jax.jit(general_m)
+
+
+
+

+ 109 - 0
curvature_assembly/surface_fit.py

@@ -0,0 +1,109 @@
+from __future__ import annotations
+from jaxopt import GaussNewton, LevenbergMarquardt
+import jax.numpy as jnp
+from typing import Callable, TypeVar
+import jax
+from functools import partial
+from jax_md import dataclasses
+
+
+Array = jnp.ndarray
+T = TypeVar('T')
+
+
+@dataclasses.dataclass
+class QuadraticSurfaceParams:
+    """Fit parameters of a quadratic surface: center point, three Euler angles and a radius."""
+    center: Array = dataclasses.field(default_factory=lambda: jnp.zeros((3,)))
+    euler: Array = dataclasses.field(default_factory=lambda: jnp.zeros((3,)))
+    radius: float = 1.
+
+    def to_array(self) -> Array:
+        """Pack the parameters into one flat array ordered as (center, euler, radius)."""
+        radius_entry = jnp.array([self.radius])
+        return jnp.hstack((self.center, self.euler, radius_entry))
+
+    @staticmethod
+    def from_array(x: Array) -> QuadraticSurfaceParams:
+        """Inverse of to_array: entries 0-2 are the center, 3-5 the Euler angles and 6 the radius."""
+        center, euler, radius = x[:3], x[3:6], x[6]
+        return QuadraticSurfaceParams(center, euler, radius)
+
+
+def rotation_matrix(euler_angles: Array) -> Array:
+    """
+    Build a 3x3 rotation matrix from three angles (alpha, beta, gamma).
+
+    The result is the product Rz(gamma) @ Ry(beta) @ Rz(alpha): the z-rotation by alpha acts first,
+    then the y-axis factor with beta, then the z-rotation by gamma. Note the sign convention of the
+    y-axis factor, which has -sin(beta) in its top-right entry.
+    """
+    alpha, beta, gamma = euler_angles
+
+    def about_z(angle):
+        cos_a, sin_a = jnp.cos(angle), jnp.sin(angle)
+        return jnp.array([[cos_a, -sin_a, 0],
+                          [sin_a, cos_a, 0],
+                          [0, 0, 1]])
+
+    cos_b, sin_b = jnp.cos(beta), jnp.sin(beta)
+    about_y = jnp.array([[cos_b, 0, -sin_b],
+                         [0, 1, 0],
+                         [sin_b, 0, cos_b]])
+
+    return about_z(gamma) @ about_y @ about_z(alpha)
+
+
+@partial(jnp.vectorize, signature='(d,d),(d)->(d,d)')
+def quadratic_form(rot_mat, eigvals: Array):
+    """
+    Assemble the symmetric matrix R @ diag(eigvals) @ R^T from a matrix R and three eigenvalues.
+
+    Vectorized over leading batch axes of both arguments.
+    """
+    lam_1, lam_2, lam_3 = eigvals
+    diagonal = jnp.array([[lam_1, 0, 0],
+                          [0, lam_2, 0],
+                          [0, 0, lam_3]])
+    rotated = rot_mat @ diagonal
+    return rotated @ rot_mat.T
+
+
+# Common call signature of the surface residual functions: (params, coord, mask) -> per-particle residuals.
+SurfaceFn = Callable[[QuadraticSurfaceParams, Array, Array], Array]
+
+
+def spherical_surface(params: QuadraticSurfaceParams, coord: Array, mask: Array) -> Array:
+    """
+    Residuals for fitting a sphere to the particles given by the coord array and a mask over coord.
+
+    Each residual is the particle's distance from params.center minus |params.radius|, multiplied by
+    the corresponding mask entry. params.euler is not used.
+    """
+    distance_to_center = jnp.linalg.norm(coord - params.center, axis=1)
+    sphere_radius = jnp.abs(params.radius)
+    return (distance_to_center - sphere_radius) * mask
+
+
+def quadratic_surface(params: QuadraticSurfaceParams, coord: Array, mask: Array, qf_eigvals: Array) -> Array:
+    """
+    Residual function for fitting a quadratic surface to a group of particles defined by coord array and
+    a mask over coord.
+
+    The quadratic form Q is built from the rotation given by params.euler and the eigenvalues qf_eigvals.
+    The residual of a particle is sqrt(r^T Q r) - |params.radius|, with r its position relative to
+    params.center, multiplied by the corresponding mask entry.
+    """
+    shifted = coord - params.center
+    orientation = rotation_matrix(params.euler)
+    form = quadratic_form(orientation, qf_eigvals)
+    form_values = jnp.sum(shifted * (shifted @ form), axis=1)
+    # NOTE(review): the square root is NaN wherever the form value is negative (possible when an eigenvalue
+    # is negative), and multiplying by a zero mask entry does not remove a NaN -- confirm inputs avoid this.
+    # disabled alternative working with squared quantities:
+    # return (form_values ** 2 - jnp.abs(params.radius) ** 2) * mask
+    return (jnp.sqrt(form_values) - jnp.abs(params.radius)) * mask
+
+
+# Residual function with eigenvalues (1, 1, 0): the third principal axis does not contribute to the form.
+cylindrical_surface = partial(quadratic_surface, qf_eigvals=jnp.array([1., 1., 0.]))
+# Residual function with eigenvalues (1, 1, -2): the third principal axis enters with a negative sign.
+hyperbolic_surface = partial(quadratic_surface, qf_eigvals=jnp.array([1., 1., -2.]))
+
+
+def surface_fit_gn(surface_fn: SurfaceFn, coord: Array, mask: Array, p0: T, verbose: bool = False,
+                   maxiter: int = 20) -> T:
+    """
+    Fit a surface to a group of particles defined by coord array and a mask over coord using the Gauss-Newton method.
+
+    surface_fn is the residual function, called as surface_fn(params, coord, mask). p0 holds the initial
+    parameters (any pytree accepted by surface_fn); the optimized parameters are returned with the same
+    structure. maxiter is the maximum number of Gauss-Newton iterations (previously fixed at 20) and
+    verbose is forwarded to the solver.
+    """
+    gn = GaussNewton(residual_fun=surface_fn, maxiter=maxiter, verbose=verbose)
+    # we want to avoid "bytes-like object" TypeError if initial params are given as integers:
+    # (float64 is requested; this presumably relies on jax_enable_x64 being switched on by the caller)
+    p0 = jax.tree_util.tree_map(partial(jnp.asarray, dtype=jnp.float64), p0)
+    opt = gn.run(p0, coord, mask)
+    return opt.params
+
+
+def surface_fit_lm(surface_fn: SurfaceFn,
+                   coord: Array,
+                   mask: Array,
+                   p0: QuadraticSurfaceParams,
+                   verbose: bool = False,
+                   maxiter: int = 20) -> QuadraticSurfaceParams:
+    """
+    Fit a surface to a group of particles defined by coord array and a mask over coord
+    using the Levenberg-Marquardt method. Doesn't seem to work with gradient calculation over hyperparameters.
+
+    surface_fn is the residual function, called as surface_fn(params, coord, mask). p0 holds the initial
+    parameters and the optimized parameters are returned as a new QuadraticSurfaceParams. maxiter is the
+    maximum number of solver iterations (previously fixed at 20) and verbose is forwarded to the solver.
+    """
+
+    def unraveled_fn(x):
+        # the solver works on a flat parameter vector, so rebuild the dataclass before evaluating residuals
+        params = QuadraticSurfaceParams.from_array(x)
+        return surface_fn(params, coord, mask)
+
+    # cast all initial values to floating point arrays before flattening them
+    p0 = jax.tree_util.tree_map(partial(jnp.asarray, dtype=jnp.float64), p0)
+    p0_array = p0.to_array()
+
+    lm = LevenbergMarquardt(residual_fun=unraveled_fn, maxiter=maxiter, verbose=verbose)
+    opt = lm.run(p0_array)
+    return QuadraticSurfaceParams.from_array(opt.params)
+

+ 61 - 0
curvature_assembly/surface_fit_general.py

@@ -0,0 +1,61 @@
+from __future__ import annotations
+from jaxopt import GaussNewton, LevenbergMarquardt
+import jax.numpy as jnp
+from typing import Callable, TypeVar
+import jax
+from functools import partial
+from jax_md import dataclasses
+
+
+Array = jnp.ndarray
+T = TypeVar('T')
+
+
+@dataclasses.dataclass
+class GeneralQuadraticSurfaceParams:
+    """Parameters of a general quadric: six entries defining a symmetric 3x3 matrix and a linear term."""
+    quadratic_form_flat: Array = dataclasses.field(default_factory=lambda: jnp.array([1, 0, 0, 1, 0, 1]))
+    linear: Array = dataclasses.field(default_factory=lambda: jnp.zeros((3,)))
+
+    @property
+    def quadratic_form(self):
+        """Symmetric 3x3 matrix whose upper triangle, read row by row, is quadratic_form_flat."""
+        q_xx, q_xy, q_xz, q_yy, q_yz, q_zz = self.quadratic_form_flat
+        first_row = [q_xx, q_xy, q_xz]
+        second_row = [q_xy, q_yy, q_yz]
+        third_row = [q_xz, q_yz, q_zz]
+        return jnp.array([first_row, second_row, third_row])
+
+    # NOTE: the disabled helpers below refer to QuadraticSurfaceParams and to fields (center, euler, radius)
+    # that are not defined on this class.
+    # def to_array(self) -> Array:
+    #     return jnp.hstack((self.center, self.euler, jnp.array([self.radius])))
+
+    # @staticmethod
+    # def from_array(x: Array) -> QuadraticSurfaceParams:
+    #     return QuadraticSurfaceParams(x[:3], x[3:6], x[6])
+
+
+@partial(jax.jit, static_argnums=(3,))
+def quadratic_surface(params: GeneralQuadraticSurfaceParams, coord: Array, mask: Array, constant: int = -1) -> Array:
+    """
+    Residuals for fitting a general quadric surface to the particles given by coord and a mask over coord.
+
+    The residual of a particle at position r is r^T Q r + linear . r + constant, multiplied by its mask
+    entry, with Q = params.quadratic_form. constant is a static jit argument and must be -1, 0 or 1,
+    otherwise a ValueError is raised.
+    """
+    allowed_constants = (-1, 0, 1)
+    if constant not in allowed_constants:
+        raise ValueError(f"Quadratic surface constant should be -1, 0, or 1, got {constant}.")
+    form = params.quadratic_form
+    second_order = (coord * (coord @ form)).sum(axis=1)
+    first_order = (coord * params.linear).sum(axis=1)
+    residuals = second_order + first_order + constant
+    return residuals * mask
+
+
+# Call signature of the residual functions accepted by surface_fit_gn: (params, coord, mask) -> residuals.
+GeneralSurfaceFn = Callable[[GeneralQuadraticSurfaceParams, Array, Array], Array]
+
+
+def surface_fit_gn(surface_fn: GeneralSurfaceFn,
+                   coord: Array,
+                   mask: Array,
+                   p0: GeneralQuadraticSurfaceParams,
+                   verbose: bool = False,
+                   maxiter: int = 20) -> GeneralQuadraticSurfaceParams:
+    """
+    Fit a surface to a group of particles defined by coord array and a mask over coord using the Gauss-Newton method.
+
+    surface_fn is the residual function, called as surface_fn(params, coord, mask). p0 holds the initial
+    parameters and the optimized parameters are returned with the same structure. maxiter is the maximum
+    number of Gauss-Newton iterations (previously fixed at 20) and verbose is forwarded to the solver.
+    """
+    gn = GaussNewton(residual_fun=surface_fn, maxiter=maxiter, verbose=verbose)
+    # we want to avoid "bytes-like object" TypeError if initial params are given as integers:
+    # (the default quadratic_form_flat of GeneralQuadraticSurfaceParams is an integer array)
+    p0 = jax.tree_util.tree_map(partial(jnp.asarray, dtype=jnp.float64), p0)
+    opt = gn.run(p0, coord, mask)
+    return opt.params

+ 41 - 0
curvature_assembly/util.py

@@ -0,0 +1,41 @@
+import jax.numpy as jnp
+import jax
+
+
+Array = jnp.ndarray
+f32 = jnp.float32
+
+
+def diagonal_mask(X: Array) -> Array:
+    """
+    Sets the diagonal of a matrix to zero. A direct copy of jax_md.smap._diagonal_matrix()
+
+    X must have equal first two dimensions and rank at most 3, otherwise a ValueError is raised.
+    NaN entries are replaced through jnp.nan_to_num before the mask is applied; for a rank-3 input
+    the mask is broadcast over the trailing axis.
+    """
+    n_rows, n_cols = X.shape[0], X.shape[1]
+    if n_rows != n_cols:
+        raise ValueError(f'Diagonal mask can only mask square matrices. Found {n_rows}x{n_cols}.')
+    rank = len(X.shape)
+    if rank > 3:
+        raise ValueError(f'Diagonal mask can only mask rank-2 or rank-3 tensors. Found {rank}.')
+    cleaned = jnp.nan_to_num(X)
+    off_diagonal = f32(1.0) - jnp.eye(n_rows, dtype=cleaned.dtype)
+    if rank == 3:
+        off_diagonal = off_diagonal.reshape((n_rows, n_rows, 1))
+    return off_diagonal * cleaned
+
+
+@jax.custom_vjp
+def print_grad(x):
+    """Identity function whose backward pass prints the gradient flowing through it (debugging aid)."""
+    return x
+
+
+def _print_grad_fwd(x):
+    # forward rule: the output equals the input and no residuals are kept for the backward pass
+    primal_out, residuals = x, None
+    return primal_out, residuals
+
+
+def _print_grad_bwd(_, grad):
+    # backward rule: report the incoming cotangent and pass it on unchanged
+    jax.debug.print("grad: {}", grad)
+    cotangents = (grad,)
+    return cotangents
+
+
+print_grad.defvjp(_print_grad_fwd, _print_grad_bwd)

+ 10 - 0
interaction_params.json

@@ -0,0 +1,10 @@
+{
+  "eigvals": [4.0, 4.0, 4.0],
+  "epsilon": 5.0,
+  "d0": 5.0,
+  "q0": 0.5,
+  "sigma": 1.0,
+  "alpha": 1.0,
+  "lm_magnitudes": 1.0,
+  "softness": 1.9
+}

+ 236 - 0
main.py

@@ -0,0 +1,236 @@
+from __future__ import annotations
+import time
+import sys
+import json
+from pathlib import Path
+
+# The script takes three JSON files as positional command-line arguments, in this order:
+# general configuration, run parameters and interaction parameters.
+# They are read before jax is imported because the device setup below depends on their content.
+with open(Path(sys.argv[1])) as config_file:
+    config_data = json.load(config_file)
+with open(Path(sys.argv[2])) as run_data_file:
+    run_params = json.load(run_data_file)
+with open(Path(sys.argv[3])) as int_param_file:
+    int_params = json.load(int_param_file)
+
+from jax import config
+import os
+
+# device selection: on cpu, XLA is asked to expose one host device per simulation
+if config_data['device'] == 'cpu':
+    os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={run_params["num_simulations"]}'
+elif config_data['device'] == 'gpu':
+    # optional GPU environment settings, currently disabled:
+    # os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+    # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
+    # os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".45"
+    pass
+else:
+    raise ValueError('Unknown device type, only "cpu" and "gpu" are supported.')
+
+config.update("jax_enable_x64", True)  # enable 64-bit precision before the remaining jax-dependent imports below
+# config.update("jax_debug_nans", True)
+
+import jax
+from jax_md import space, quantity, rigid_body, dataclasses
+from curvature_assembly import oriented_particle, ellipsoid_contact, io_functions, simulation, data_protocols
+from curvature_assembly import cost_functions, energy, parallelization, fit, pytree_transf, patchy_interaction
+import optax
+from functools import partial
+import jax.numpy as jnp
+
+
+# translate the integer "save" switch of the configuration file into a boolean flag
+if config_data['save'] == 1:
+    SAVE = True
+elif config_data['save'] == 0:
+    SAVE = False
+else:
+    raise ValueError('"save" configuration parameter must be 0 or 1.')
+
+# "autodif_opt" selects between the optimization run (1) and a forward-only calculation (0)
+if config_data['autodif_opt'] == 1:
+    only_forward_calculation = False
+elif config_data['autodif_opt'] == 0:
+    only_forward_calculation = True
+    run_params['num_iterations'] = 1  # overwrite number of iterations if we only do forward simulation
+else:
+    raise ValueError('"autodif_opt" configuration parameter must be 0 or 1.')
+
+
+def get_interaction_params(int_params_dict: dict) -> data_protocols.InteractionParams:
+    """
+    Create the interaction parameter object matching the simulation type given in the configuration file.
+
+    "ferro" and "patchy" both use energy.FerroWcaParams, with q0 replaced by 0.0 for "patchy";
+    "quad" uses energy.QuadWcaParams. Any other type raises a ValueError.
+    """
+    simulation_type = config_data["type"]
+    if simulation_type not in ("ferro", "patchy", "quad"):
+        raise ValueError('Unknown simulation type, currently only "ferro", "quad", and "patchy" are supported.')
+
+    param_fields = io_functions.convert_lists_to_arrays(int_params_dict)
+    if simulation_type == "quad":
+        return energy.QuadWcaParams(**param_fields)
+
+    interaction_params = energy.FerroWcaParams(**param_fields)
+    if simulation_type == "patchy":
+        interaction_params = dataclasses.replace(interaction_params, q0=0.0)
+    return interaction_params
+
+
+def quad_attraction_invariant(params, d0_start, q0_start):
+    """
+    Rescale d0 and q0 of params so that d0 + q0 ** 2 equals its starting value d0_start + q0_start ** 2.
+
+    d0 is multiplied by the ratio of the starting to the current value and q0 by the square root of
+    that ratio; all other fields are copied unchanged into a new object of the same type as params.
+    """
+    fields = dict(vars(params))
+
+    target = d0_start + q0_start ** 2
+    present = fields["d0"] + fields["q0"] ** 2
+    scale = target / present
+
+    fields.update(d0=fields["d0"] * scale, q0=fields["q0"] * jnp.sqrt(scale))
+    return type(params)(**fields)
+
+
+def params_to_optimize_to_bounds_kwargs():
+    """
+    Build the keyword arguments (parameter name -> None) that main passes on to fit.bounds_params.
+
+    'lm_magnitudes' is always included; the names listed under run_params["optimize_params"] are added
+    when that entry exists.
+    """
+    try:  # backwards compatibility if optimize_params is not given
+        requested = run_params["optimize_params"]
+    except KeyError:
+        requested = ()
+    return dict.fromkeys(('lm_magnitudes', *requested))
+
+
+def main():
+    """
+    Run the optimization described by the three JSON input files.
+
+    Builds the simulation and interaction parameters, loads the initial configurations, constructs the
+    energy function, cost function and truncated-BPTT simulation, and then iterates fit steps, printing
+    the parameters, timing and end cost of every iteration and exporting results when SAVE is set.
+    """
+
+    results_folder = Path(config_data['results_base_folder'])
+    init_folder = Path(config_data['init_folder'])
+
+    simulation_params = simulation.NVTSimulationParams(num=run_params["num_particles"],
+                                                       density=run_params["density"],
+                                                       simulation_steps=run_params["simulation_steps"],
+                                                       dt=run_params["dt"],
+                                                       kT=run_params["kT"],
+                                                       config_every=run_params["config_every"],
+                                                       bptt_truncation=run_params["bptt_truncation"]
+                                                       )
+
+    # initialize interaction params and related optimization bounds
+    interaction_params = get_interaction_params(int_params)
+    lm_list = patchy_interaction.generate_lm_list(6, only_even_l=False, only_non_neg_m=False)
+    interaction_params = interaction_params.init_lm_magnitudes(patchy_interaction.init_lm_coefs(lm_list,
+                                                                                                [(2,0)]))
+    # interaction_params = interaction_params.init_unit_volume_particle()
+    lower_bounds, upper_bounds = fit.bounds_params(interaction_params, **params_to_optimize_to_bounds_kwargs())
+
+    # set displacement and shift functions
+    # two box sizes are computed: box_size_old from the plain number density and box_size from the
+    # ellipsoid-based density; their ratio rescales the loaded initial coordinates below
+    box_size_old = quantity.box_size_at_number_density(simulation_params.num,
+                                                       simulation_params.density,
+                                                       spatial_dimension=3)
+    box_size = oriented_particle.box_size_at_ellipsoid_density(simulation_params.num,
+                                                                simulation_params.density,
+                                                                interaction_params.eigvals)
+    displacement, shift = space.periodic(box_size)
+
+    # load initial config(s)
+    body = io_functions.load_multiple_initial_configs_single_object(simulation_params.num,
+                                                                    simulation_params.density,
+                                                                    [i for i in range(run_params["num_simulations"])],
+                                                                    init_folder,
+                                                                    coord_rescale_factor=box_size / box_size_old)
+
+    # define energy function
+    contact_fn = ellipsoid_contact.bp_contact_function
+    # contact_fn = ellipsoid_contact.pw_contact_function
+
+    # NOTE(review): this match has no default case; an unknown type already raises in
+    # get_interaction_params above, so energy_fn is always bound when this point is reached
+    match config_data["type"]:
+        case "ferro" | "patchy":
+            energy_fn = energy.ferro_wca_sphere_pair(displacement=displacement, lm=lm_list)
+        case "quad":
+            energy_fn = energy.quadrupolar_wca_sphere_pair(displacement=displacement, lm=lm_list)
+    energy_fn = jax.checkpoint(energy_fn)
+
+    # select cost function
+    cost_fn = cost_functions.CurvedClustersResidualsCost(displacement,
+                                                         box_size,
+                                                         contact_fn,
+                                                         target_radius=run_params["target_radius"],
+                                                         residuals_avg_type=run_params["residuals_average"],
+                                                         residuals_cost_factor=run_params["residuals_cost_factor"])
+
+    # initialize optimization saver
+    io_manager = io_functions.OptimizationSaver(results_folder.joinpath(Path(config_data['optimization_folder_name'])),
+                                                    simulation_params)
+    # save metadata
+    if SAVE:
+        io_manager.export_cost_function_info(cost_fn)
+        # a lot of run parameters are already saved as simulation_params, but it is easier to just save everything
+        io_manager.export_run_params(run_params)
+        io_manager.export_additional_simulation_data({'thermostat': config_data['thermostat'],
+                                                      'lm_list': lm_list})
+
+    ###########################################################################
+    # SIMULATION FUNCTION CONSTRUCTION
+    ###########################################################################
+
+    # prepare functions and auxiliary data container for the simulation
+    if config_data["thermostat"] == "langevin":
+        gamma = run_params["langevin_gamma"]
+        init_fn, step_fn, aux = simulation.setup_langevin(energy_fn, shift, simulation_params,
+                                                          gamma=rigid_body.RigidBody(gamma, gamma))
+    elif config_data["thermostat"] == "nose-hoover":
+        # the "nose-hoover-tau" run parameter is given in units of the time step dt
+        init_fn, step_fn, aux = simulation.setup_nose_hoover(energy_fn, shift, simulation_params,
+                                                             tau=simulation_params.dt * run_params["nose-hoover-tau"])
+    else:
+        raise NotImplementedError('Thermostat in config file should be "langevin" or "nose-hoover".')
+
+    # we must add additional dimension to aux, compatible with body leading dimension, used for parallelization
+    # and consistency over multiple evaluations of parallelized function
+    aux = pytree_transf.repeat_fields(aux, pytree_transf.data_length(body))
+
+    # create bptt simulation function
+    bptt_simulation = simulation.truncated_bptt_nvt_simulation(step_fn,
+                                                               energy_fn,
+                                                               cost_fn,
+                                                               simulation_params,
+                                                               only_forward_calculation=only_forward_calculation)
+    bptt_simulation = parallelization.pmap_segment_dispatch(jax.jit(bptt_simulation), map_argnums=(1, 2))
+
+    # create optimization configuration
+    optimizer = optax.adam(learning_rate=run_params["learning_rate"])
+    opt_state = optimizer.init(interaction_params)
+    # rescalings handed to fit.fit_bptt; for "quad" runs d0 + q0 ** 2 is additionally kept at its initial value
+    param_rescalings = [partial(fit.normalize_param, param_name='lm_magnitudes'),]
+    if config_data["type"] == "quad":
+        param_rescalings.append(partial(quad_attraction_invariant,
+                                        d0_start=interaction_params.d0, q0_start=interaction_params.q0))
+
+    fit_step = fit.fit_bptt(bptt_simulation,
+                            optimizer.update,
+                            clipping=50,
+                            grad_time_weights=run_params["grad_time_weights"],
+                            param_rescalings=param_rescalings,
+                            lower_bounds=lower_bounds,
+                            upper_bounds=upper_bounds)
+
+    ##############################################################################
+    # RUN SIMULATION
+    ##############################################################################
+
+    # one PRNG key per simulation, all derived from the fixed seed 0
+    init_keys = jax.random.split(jax.random.PRNGKey(0), pytree_transf.data_length(aux, axis=0))
+    md_state = pytree_transf.map_over_leading_leaf_dimension(partial(init_fn,
+                                                                     mass=simulation.ellipsoid_unit_mass(interaction_params.eigvals),
+                                                                     **vars(interaction_params)),
+                                                             init_keys, body)
+
+    # run optimization
+    # md_state is not reassigned inside the loop, so every iteration receives the same initial state;
+    # interaction_params, opt_state and aux are carried over from one iteration to the next
+    for i in range(run_params["num_iterations"]):
+        if SAVE:
+            io_manager.export_interaction_params(interaction_params)
+
+        print(f'Params for iter {i}: ', interaction_params)
+
+        t0 = time.perf_counter()
+        # block_until_ready makes the timing below include the actual computation
+        interaction_params, opt_state, bptt_results, aux, grad_clipped = jax.block_until_ready(fit_step(interaction_params,
+                                                                                                        opt_state,
+                                                                                                        md_state,
+                                                                                                        aux))
+        t1 = time.perf_counter()
+        print(f'Simulation time: {t1 - t0}')
+
+        print(f'End cost: {bptt_results.cost[:, -1]}')
+
+        if SAVE:
+            io_manager.export_multiple_results(bptt_results, aux)
+            io_manager.export_clipped_gradients(grad_clipped)
+
+if __name__ == '__main__':
+    main()

+ 19 - 0
run_params.json

@@ -0,0 +1,19 @@
+{
+  "num_particles": 50,
+  "density": 0.03,
+  "simulation_steps": 100000,
+  "dt": 0.0001,
+  "kT": 1.0,
+  "config_every": 100,
+  "target_radius": 4.0,
+  "num_simulations": 4,
+  "num_iterations": 3,
+  "learning_rate": 0.01,
+  "bptt_truncation": 500,
+  "residuals_average": "linear",
+  "residuals_cost_factor": 1.0,
+  "grad_time_weights": "linear",
+  "langevin_gamma": 5.0,
+  "nose-hoover-tau": 50.0,
+  "optimize_params": ["softness", "q0"]
+}

+ 3 - 0
setup.py

@@ -0,0 +1,3 @@
+from setuptools import setup
+
+setup(name='curvature-assembly', version='1.0')

+ 13 - 0
template_config.json

@@ -0,0 +1,13 @@
+{
+    "type": "ferro",
+    "results_base_folder": "/path/to/my/results/directory/",
+    "optimization_folder_name": "optimization",
+    "validation_folder_name": "validation",
+    "forward_folder_name": "forward",
+    "init_folder": "/path/to/my/init_configs",
+    "save": 1,
+    "autodif_opt": 1,
+    "device": "cpu",
+    "thermostat": "nose-hoover"
+  }
+