123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288 |
- from typing import Callable
- import jax.numpy as jnp
- from functools import partial
- import jax
- from jax_md import space
- from curvature_assembly import util
- import numpy as np
Array = jnp.ndarray


def construct_eigensystem(normal: Array) -> Array:
    """
    Construct a local frame whose z-axis is the given surface normal.

    Args:
        normal: surface normal at the point, shape (3,); assumed to have unit length, otherwise the
            returned frame is not orthonormal.

    Returns:
        (3, 3) matrix whose columns are the local x, y and z axes expressed in global coordinates; the
        third column is ``normal`` itself. Right-multiplying global row vectors by this matrix yields their
        local coordinates. Note that the frame is left-handed (x cross y = -normal).
    """
    # Spherical angles of the normal. The z-component is clipped because a normal obtained from an
    # eigendecomposition can exceed 1 in magnitude by a rounding error, and arccos would then return NaN.
    psi = jnp.arccos(jnp.clip(normal[2], -1., 1.))
    phi = jnp.arctan2(normal[1], normal[0])
    sin_psi, cos_psi = jnp.sin(psi), jnp.cos(psi)
    sin_phi, cos_phi = jnp.sin(phi), jnp.cos(phi)
    return jnp.array([[-sin_phi, cos_phi, 0.],
                      [cos_psi * cos_phi, cos_psi * sin_phi, -sin_psi],
                      [normal[0], normal[1], normal[2]]]).T
def position_to_local_coordinates(eigensystem: Array, center_position: Array, neighbors_coord_global: Array) -> Array:
    """
    Express neighbor positions in a local frame.

    The local origin is ``center_position`` and the local axes are the columns of ``eigensystem``
    (as built by ``construct_eigensystem``): positions are first translated, then projected on the axes.
    """
    relative_positions = neighbors_coord_global - center_position[None, :]
    return relative_positions @ eigensystem
def normal_to_local_coordinates(eigensystem: Array, neighbors_normal_global: Array) -> Array:
    """
    Rotate normal vectors into the local frame spanned by the columns of ``eigensystem``.

    Normals are directions, so unlike positions they are only projected on the local axes, not translated.
    """
    return neighbors_normal_global @ eigensystem
@partial(jnp.vectorize, signature='(d),(d)->()')
def point_main_curvature(neighbor_coord_local: Array, neighbor_normal_local: Array) -> Array:
    """
    Estimate the normal-section curvature towards a single neighbor.

    Both arguments are expressed in the local frame of the center particle (z-axis along the center
    normal). The estimate is
        kappa = -n_xy / (sqrt(n_xy^2 + n_z^2) * sqrt(x^2 + y^2)),
    where n_xy is the neighbor normal projected on the in-plane direction pointing towards the neighbor.
    Leading dimensions are broadcast via ``jnp.vectorize``.
    """
    x, y = neighbor_coord_local[0], neighbor_coord_local[1]
    # The epsilon sits inside the square root. At the center particle itself (x = y = 0) this keeps dist2d
    # at 1e-8 in the forward pass, and it also keeps the gradient finite. With the epsilon added after the
    # sqrt, the derivative of sqrt(r^2) at r = 0 evaluates to 0 * inf = NaN, which survives even when the
    # caller multiplies that entry by a zero mask.
    dist2d = jnp.sqrt(x ** 2 + y ** 2 + 1e-16)
    n_xy = (x * neighbor_normal_local[0] + y * neighbor_normal_local[1]) / dist2d
    return -n_xy / (jnp.sqrt(n_xy ** 2 + neighbor_normal_local[2] ** 2) * dist2d)
def construct_coefficient_matrix(neighbors_local: Array) -> Array:
    """
    Build the design matrix for the least-squares fit of the normal-curvature model
        kappa(theta) = A cos^2(theta) + B cos(theta) sin(theta) + C sin^2(theta).

    Args:
        neighbors_local: neighbor positions in the local frame of the center particle, shape (N, d) with
            d >= 2; only the first two (in-plane) components are used.

    Returns:
        Array of shape (N, 3) whose rows are [cos^2, cos*sin, sin^2] of each neighbor's in-plane angle.
    """
    x, y = neighbors_local[:, 0], neighbors_local[:, 1]
    # A point at the origin (the center particle itself) has no defined angle. arctan2(0, 0) is 0 in the
    # forward pass but its derivative is 0/0 = NaN. Substituting (x, y) = (1, 0) there gives the same angle
    # theta = 0 while keeping gradients finite (the "double where" trick).
    at_origin = (x == 0) & (y == 0)
    thetas = jnp.arctan2(jnp.where(at_origin, 0., y), jnp.where(at_origin, 1., x))
    sin_thetas = jnp.sin(thetas)
    cos_thetas = jnp.cos(thetas)
    # Stack explicitly instead of filling a zeros array of the input's shape, so the result always has
    # exactly the three model columns, whatever the number of coordinate columns of the input.
    return jnp.stack([cos_thetas ** 2, cos_thetas * sin_thetas, sin_thetas ** 2], axis=-1)
def principal_curvature(idx: int, coord: Array, normal: Array, neighbors: Array) -> (Array, Array):
    """
    Calculate the two principal curvatures at the ``idx``-th particle.

    Algorithm from Zhang et al., "Curvature Estimation of 3D Point Cloud Surfaces Through the Fitting of Normal
    Section Curvatures", http://www.nlpr.ia.ac.cn/2008papers/gjhy/gh129.pdf

    Args:
        idx: index of the center particle
        coord: particle positions, shape (N, 3)
        normal: normals at the particles, shape (N, 3)
        neighbors: (N, N) neighbor mask; row ``idx`` selects the particles used in the fit. It should not
            include ``idx`` itself, whose in-plane angle is undefined and would add a spurious kappa = 0 row.

    Returns:
        (curvature1, curvature2) with curvature1 <= curvature2.
    """
    eigensystem = construct_eigensystem(normal[idx])
    coord_local = position_to_local_coordinates(eigensystem, coord[idx], coord)
    normal_local = normal_to_local_coordinates(eigensystem, normal)
    mask = neighbors[idx]
    # Non-neighbors get all-zero rows and a zero right-hand side, so they do not influence the fit.
    coefficient_matrix = construct_coefficient_matrix(coord_local) * mask[:, None]
    neighbor_curvatures = point_main_curvature(coord_local, normal_local) * mask
    # Least-squares fit of kappa(theta) = a cos^2 + b cos sin + c sin^2.
    curve_params = jnp.linalg.lstsq(coefficient_matrix, neighbor_curvatures)[0]
    a, b, c = curve_params[0], curve_params[1], curve_params[2]
    # By Euler's formula kappa(theta) = t^T W t with t = (cos, sin) and W = [[a, b/2], [b/2, c]]: the
    # cross term b cos sin equals 2 * (b/2) cos sin. The principal curvatures are the eigenvalues of W,
    # (a + c)/2 -/+ sqrt((a - c)^2 + b^2)/2. Using 4 * b^2 here would double the off-diagonal element.
    discriminant_sqrt = jnp.sqrt((a - c) ** 2 + b ** 2)
    curvature1 = 0.5 * (a + c - discriminant_sqrt)
    curvature2 = 0.5 * (a + c + discriminant_sqrt)
    return curvature1, curvature2
@jax.jit
def minimum_spanning_tree(normals: Array, neighbors_full: Array) -> Array:
    """
    Determine a minimum spanning tree of the neighbor graph with Prim's algorithm, starting from node 0.

    Links connect neighbor particles and carry weights (energies) 1 - |ni.nj|, which lie between 0 and 1
    for unit normals.

    Args:
        normals: normal vectors at each node, shape (N, 3)
        neighbors_full: boolean matrix of shape (N, N) where True elements describe neighbor particles.
            The diagonal may be True or False; self-links are never selected.

    Returns:
        Integer array of shape (N - 1, 2). Row k is the link [parent, child] added in step k. The parent is
        always a node already in the tree, so traversing the rows in order reaches every node exactly once.
    """
    num_nodes = normals.shape[0]
    weights = 1. - jnp.abs(normals @ normals.T)
    # Links between non-neighbors are heavily penalized rather than removed, so disconnected groups of
    # particles are still joined into one tree (through the cheapest such link) once nothing else is left.
    weights = weights + 10. * (1. - jnp.asarray(neighbors_full, dtype=weights.dtype))

    selected_nodes = jnp.zeros(num_nodes, dtype=bool).at[0].set(True)

    def add_node(selected_nodes, _):
        # Prim's step: take the cheapest link from a node inside the tree to a node outside it. All other
        # links (inside-inside, which would close a loop or be a self-link, and outside-outside) are masked
        # with +inf, so the choice is exact regardless of how many steps have already been taken.
        crossing = selected_nodes[:, None] & ~selected_nodes[None, :]
        candidate_weights = jnp.where(crossing, weights, jnp.inf)
        parent, child = jnp.unravel_index(jnp.argmin(candidate_weights), candidate_weights.shape)
        selected_nodes = selected_nodes.at[child].set(True)
        return selected_nodes, jnp.array([parent, child])

    _, link_list = jax.lax.scan(add_node, init=selected_nodes, xs=None, length=num_nodes - 1)
    return link_list
@jax.jit
def determine_normals(coordinates: Array, neighbors: Array) -> Array:
    """
    Estimate consistently oriented surface normals of a point cloud.

    The normal at each particle is the eigenvector with the smallest eigenvalue of its neighborhood's
    covariance matrix (PCA). Signs are then made consistent by propagating the orientation along a minimum
    spanning tree of the neighbor graph, starting from particle 0.
    Algorithm from Hoppe et al., "Surface Reconstruction from Unorganized Points",
    https://dl.acm.org/doi/pdf/10.1145/133994.134011

    Args:
        coordinates: particle positions, shape (N, 3)
        neighbors: (N, N) neighbor mask; row i selects the particles used for the normal at particle i

    Returns:
        normals, shape (N, 3)
    """
    # determine_tangent_planes performs exactly the same PCA and orientation steps; only the plane centers
    # it additionally returns are discarded here.
    _, consistent_normals = determine_tangent_planes(coordinates, neighbors)
    return consistent_normals
@jax.jit
def determine_tangent_planes(coordinates: Array, neighbors: Array) -> (Array, Array):
    """
    Estimate tangent planes (center point and consistently oriented normal) of a point cloud.

    For every particle, the plane passes through the centroid of its neighborhood. Its normal is the
    eigenvector of the neighborhood covariance matrix with the smallest eigenvalue (PCA). Normal signs are
    then made consistent by propagating the orientation along a minimum spanning tree of the neighbor graph,
    starting from particle 0.
    Algorithm from Hoppe et al., "Surface Reconstruction from Unorganized Points",
    https://dl.acm.org/doi/pdf/10.1145/133994.134011

    Args:
        coordinates: particle positions, shape (N, 3)
        neighbors: (N, N) neighbor mask; row i selects the particles that define the plane at particle i

    Returns:
        (centers, normals), both of shape (N, 3)
    """
    def plane_from_mask(neighbor_mask):
        column_mask = neighbor_mask[:, None]
        masked = coordinates * column_mask
        # small epsilon guards against division by zero for a particle without neighbors
        centroid = jnp.sum(masked, axis=0) / (jnp.sum(neighbor_mask) + 1e-8)
        deviations = (masked - centroid) * column_mask
        _, eigenvectors = jnp.linalg.eigh(deviations.T @ deviations)
        # eigh sorts eigenvalues in ascending order, so column 0 is the direction of least variance
        return centroid, eigenvectors[:, 0]

    centers, normals = jax.vmap(plane_from_mask)(neighbors)
    tree_links = minimum_spanning_tree(normals, neighbors)

    def orient(current, link):
        parent, child = link[0], link[1]
        flip = jnp.sum(current[parent] * current[child]) < 0
        child_normal = current[child]
        return current.at[child].set(jnp.where(flip, -child_normal, child_normal)), None

    consistent_normals, _ = jax.lax.scan(orient, normals, tree_links)
    return centers, consistent_normals
def nearest_neighbors_fn(displacement_or_metric: space.DisplacementOrMetricFn,
                         dist_cutoff: float) -> Callable[[Array], Array]:
    """
    Construct a function that determines the boolean neighbor matrix of a point cloud.

    Element neighbors[i, j] of the returned (N, N) matrix is True when the distance between particles i and
    j, as measured by ``displacement_or_metric``, is smaller than ``dist_cutoff``.

    For a metric with d(x, x) = 0 and a positive cutoff, the diagonal is True. Callers that must exclude the
    particle itself need to mask the diagonal. The matrix is symmetric whenever the metric is.
    """
    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)

    def nearest_neighbors_matrix(coord: Array) -> Array:
        dist = space.map_product(metric)(coord, coord)
        return dist < dist_cutoff

    return nearest_neighbors_matrix
def edge_detection_fn(displacement_or_metric: space.DisplacementOrMetricFn,
                      classification_threshold: float,
                      num_neighbors: int = 12) -> Callable[[Array, Array], Array]:
    """
    Construct edge detection function based on the distance between the center of mass of all the neighbors
    and the particle of interest. Adapted from: https://arxiv.org/pdf/1809.10468.pdf

    Args:
        displacement_or_metric: displacement or metric function
        classification_threshold: factor of resolution (the smallest distance between a particle and one of its
            neighbors) that determines if the center of mass for all neighboring particles is too far, which
            classifies the particle as an edge particle
        num_neighbors: maximal number of neighbor particles taken into account; if the neighboring matrix
            marks more neighbors, the ``num_neighbors`` closest ones are used

    Returns:
        Edge detection function that takes the coordinates of all particles (N, d) and the boolean (N, N)
        neighboring matrix. It returns a boolean array of shape (N,) that is True for edge particles.
        A particle is never counted as its own neighbor, even if the diagonal of the neighboring matrix is True.
        A particle without any neighbors is not classified as an edge particle.
    """
    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)

    @jax.jit
    def detect_edge_particles(coordinates: Array, neighbors: Array) -> Array:
        num_particles = coordinates.shape[0]
        dist_fn = jnp.vectorize(metric, signature='(d),(d)->()')
        k = min(num_neighbors, num_particles)

        def is_edge(idx):
            center = coordinates[idx]
            # Exclude the particle itself: its zero distance would make the resolution 0 and classify
            # every particle as an edge particle.
            is_neighbor = jnp.logical_and(neighbors[idx], jnp.arange(num_particles) != idx)
            dist = jnp.where(is_neighbor, dist_fn(coordinates, center), jnp.inf)
            # Take the k closest neighbors. Rows with fewer than k neighbors yield padding entries with
            # infinite distance; they are excluded below instead of being averaged in as real particles.
            neg_dist, nearest = jax.lax.top_k(-dist, k)
            valid = jnp.isfinite(neg_dist)
            num_valid = jnp.maximum(jnp.sum(valid), 1)
            neigh_cm = jnp.sum(jnp.where(valid[:, None], coordinates[nearest], 0.), axis=0) / num_valid
            resolution = -jnp.max(neg_dist)  # distance to the closest neighbor, inf if there is none
            return dist_fn(neigh_cm, center) > classification_threshold * resolution

        return jax.vmap(is_edge)(jnp.arange(num_particles))

    return detect_edge_particles
def gaussian_curvature_fn(displacement_or_metric: space.DisplacementOrMetricFn,
                          dist_cutoff: float) -> Callable[[Array], Array]:
    """
    Construct a function that returns the gaussian curvature at each given point in a point cloud.

    The returned function does three things:
    - finds neighbors within ``dist_cutoff``;
    - fits a tangent plane (centroid and consistently oriented normal) to every neighborhood;
    - returns, per particle, the product of the two principal curvatures estimated around the plane centers.
    """
    find_neighbors = nearest_neighbors_fn(displacement_or_metric, dist_cutoff=dist_cutoff)

    def gaussian_curvature(coord: Array) -> Array:
        neighbors = find_neighbors(coord)
        centers, normals = determine_tangent_planes(coord, neighbors)
        # The tangent-plane fit uses the raw neighbor matrix. The curvature fit gets it through
        # util.diagonal_mask, presumably to drop each particle's own entry (verify against util).
        curvature_at = partial(principal_curvature,
                               coord=centers,
                               normal=normals,
                               neighbors=util.diagonal_mask(neighbors))
        kappa_1, kappa_2 = jax.vmap(curvature_at)(jnp.arange(coord.shape[0]))
        return kappa_1 * kappa_2

    return gaussian_curvature
|