from typing import Callable, Tuple
import jax.numpy as jnp
from functools import partial
import jax
from jax_md import space
from curvature_assembly import util
import numpy as np

Array = jnp.ndarray


def construct_eigensystem(normal: Array) -> Array:
    """
    Construct a local orthonormal frame whose third axis is the given surface normal.

    Args:
        normal: surface normal at the point of interest, shape (3,); assumed to be of unit length

    Returns:
        (3, 3) matrix whose columns are the two in-plane axes followed by ``normal``.
    """
    # Clip before arccos: a normal that is unit-length only up to rounding error can have
    # |normal[2]| marginally above 1, which would otherwise turn the whole frame into NaN.
    psi = jnp.arccos(jnp.clip(normal[2], -1., 1.))
    phi = jnp.arctan2(normal[1], normal[0])
    return jnp.array([[-jnp.sin(phi), jnp.cos(phi), 0],
                      [jnp.cos(psi) * jnp.cos(phi), jnp.cos(psi) * jnp.sin(phi), -jnp.sin(psi)],
                      [normal[0], normal[1], normal[2]]]).T


def position_to_local_coordinates(eigensystem: Array,
                                  center_position: Array,
                                  neighbors_coord_global: Array) -> Array:
    """
    Transform coordinates of neighbor particles to the coordinate system centered at
    center_position with axes given by the columns of the eigensystem matrix.
    """
    return jnp.dot(neighbors_coord_global - center_position[None, :], eigensystem)


def normal_to_local_coordinates(eigensystem: Array, neighbors_normal_global: Array) -> Array:
    """
    Rotate normals at neighbor particle locations into the frame whose axes are the columns
    of the eigensystem matrix (normals are directions, so no translation is applied).
    """
    return jnp.dot(neighbors_normal_global, eigensystem)


@partial(jnp.vectorize, signature='(d),(d)->()')
def point_main_curvature(neighbor_coord_local: Array, neighbor_normal_local: Array) -> Array:
    """
    Normal-section curvature estimate along the direction from the center particle to one
    neighbor, given the neighbor's position and normal in the center particle's local frame.

    Computes -n_xy / (sqrt(n_xy^2 + n_z^2) * d), where d is the in-plane distance to the
    neighbor and n_xy is the neighbor normal's component along that in-plane direction.
    """
    # + 1e-8 keeps the expression finite for the center particle itself, where d = 0
    dist2d = jnp.sqrt(neighbor_coord_local[0] ** 2 + neighbor_coord_local[1] ** 2) + 1e-8
    n_xy = (neighbor_coord_local[0] * neighbor_normal_local[0]
            + neighbor_coord_local[1] * neighbor_normal_local[1]) / dist2d
    return -n_xy / (jnp.sqrt(n_xy ** 2 + neighbor_normal_local[2] ** 2) * dist2d)


def construct_coefficient_matrix(neighbors_local: Array) -> Array:
    """
    Construct the coefficient matrix for least-squares fitting of curvature parameters.

    Row i is [cos^2(t_i), cos(t_i) sin(t_i), sin^2(t_i)], where t_i is the in-plane polar
    angle of neighbor i in the local frame.
    """
    thetas = jnp.arctan2(neighbors_local[:, 1], neighbors_local[:, 0])
    cos_thetas = jnp.cos(thetas)
    sin_thetas = jnp.sin(thetas)
    return jnp.stack([cos_thetas ** 2, cos_thetas * sin_thetas, sin_thetas ** 2], axis=1)


def principal_curvature(idx: int, coord: Array, normal: Array, neighbors: Array) -> Tuple[Array, Array]:
    """
    Calculate the two principal curvatures at the `idx` particle.

    Algorithm from Zhang et al., "Curvature Estimation of 3D Point Cloud Surfaces Through the
    Fitting of Normal Section Curvatures", http://www.nlpr.ia.ac.cn/2008papers/gjhy/gh129.pdf

    Args:
        idx: index of the particle of interest
        coord: particle positions, shape (N, 3)
        normal: surface normals at the particles, shape (N, 3)
        neighbors: (N, N) neighbor mask; rows with zero entries are excluded from the fit

    Returns:
        (smaller, larger) principal curvature at particle ``idx``.
    """
    eigensystem = construct_eigensystem(normal[idx])
    coord_local = position_to_local_coordinates(eigensystem, coord[idx], coord)
    normal_local = normal_to_local_coordinates(eigensystem, normal)

    # Zeroed rows contribute nothing to the least-squares residual, which masks out non-neighbors.
    mask = neighbors[idx]
    coefficient_matrix = construct_coefficient_matrix(coord_local) * mask[:, None]
    neighbor_curvatures = point_main_curvature(coord_local, normal_local) * mask

    # Fit kappa(theta) = a cos^2 + b cos sin + c sin^2.
    curve_params = jnp.linalg.lstsq(coefficient_matrix, neighbor_curvatures)[0]
    a, b, c = curve_params[0], curve_params[1], curve_params[2]

    # The shape operator is [[a, b/2], [b/2, c]] because the fitted cross coefficient multiplies
    # cos*sin (not 2*cos*sin); its eigenvalues are (a + c -/+ sqrt((a - c)^2 + b^2)) / 2.
    discriminant_sqrt = jnp.sqrt((a - c) ** 2 + b ** 2)
    curvature1 = 0.5 * (a + c - discriminant_sqrt)
    curvature2 = 0.5 * (a + c + discriminant_sqrt)
    return curvature1, curvature2


@jax.jit
def minimum_spanning_tree(normals: Array, neighbors_full: Array) -> Array:
    """
    Determine minimum spanning tree for a graph with links based on neighbors and weights
    (energies) 1 - |ni.nj| using Prim's algorithm, starting from node 0.

    Args:
        normals: normal vectors at each node, shape (N, 3)
        neighbors_full: boolean matrix of shape (N, N) where True elements describe neighbor particles

    Returns:
        Array of shape (N - 1, 2); row k is [parent, child] for the k-th node added to the tree.
    """
    num_nodes = normals.shape[0]
    # w[i, j] = 1 - |n_i . n_j| lies in [0, 1]; non-neighbor pairs get +10 so they are only
    # used when no neighbor link to an unvisited node remains (disconnected neighbor graph).
    weights = 1. - jnp.abs(normals @ normals.T)
    weights = weights + 10. * (1 - neighbors_full)

    selected_nodes = jnp.zeros(num_nodes, dtype=bool).at[0].set(True)
    # node_order[k] is the k-th node added; unfilled slots hold 0 (node 0, already selected),
    # so the duplicate rows they produce below are harmless.
    node_order = jnp.zeros(num_nodes, dtype=jnp.int32)

    def add_node(carry, i):
        node_order, selected_nodes = carry
        # Links into already-selected nodes (including self-links) would close a loop, so they
        # are excluded outright rather than penalized.
        candidate_weights = jnp.where(selected_nodes[None, :], jnp.inf, weights[node_order])
        # argmin returns the first occurrence in the flattened array, which we unravel into
        # (position in node_order, newly added node)
        row, new_node = jnp.unravel_index(jnp.argmin(candidate_weights), candidate_weights.shape)
        selected_nodes = selected_nodes.at[new_node].set(True)
        node_order = node_order.at[i].set(new_node)
        return (node_order, selected_nodes), jnp.array([node_order[row], new_node])

    _, link_list = jax.lax.scan(add_node,
                                init=(node_order, selected_nodes),
                                xs=jnp.arange(1, num_nodes))
    return link_list


def _local_pca(coordinates: Array, neighbor_mask: Array) -> Tuple[Array, Array]:
    """Return the centroid of the masked neighbors and the smallest-eigenvalue eigenvector of their scatter matrix."""
    n_part = jnp.sum(neighbor_mask)
    neighbor_coord = coordinates * neighbor_mask[:, None]
    center = jnp.sum(neighbor_coord, axis=0) / (n_part + 1e-8)
    centered = (neighbor_coord - center) * neighbor_mask[:, None]
    _, vecs = jnp.linalg.eigh(centered.T @ centered)
    # eigh sorts eigenvalues in ascending order, so column 0 is the plane normal
    return center, vecs[:, 0]


def _orient_normals(normals: Array, neighbors: Array) -> Array:
    """Flip normals so that every child in the minimum spanning tree agrees in sign with its parent."""
    link_list = minimum_spanning_tree(normals, neighbors)

    def flip_child(normals, link):
        parent, child = link[0], link[1]
        # links are ordered so a parent is always oriented before its child
        flipped = jnp.where(jnp.dot(normals[parent], normals[child]) < 0, -normals[child], normals[child])
        return normals.at[child].set(flipped), None

    consistent_normals, _ = jax.lax.scan(flip_child, normals, link_list)
    return consistent_normals


@jax.jit
def determine_normals(coordinates: Array, neighbors: Array) -> Array:
    """
    Determine local surface normals at each point with PCA and ensure a consistent surface
    orientation.

    Algorithm from Hoppe et al., "Surface Reconstruction from Unorganized Points",
    https://dl.acm.org/doi/pdf/10.1145/133994.134011

    Returns:
        consistently oriented normals, shape (N, 3)
    """
    _, normals = jax.vmap(partial(_local_pca, coordinates))(neighbors)
    return _orient_normals(normals, neighbors)


@jax.jit
def determine_tangent_planes(coordinates: Array, neighbors: Array) -> Tuple[Array, Array]:
    """
    Determine local tangent planes (neighborhood centroid and surface normal) at each point
    with PCA and ensure a consistent surface orientation of the normals.

    Algorithm from Hoppe et al., "Surface Reconstruction from Unorganized Points",
    https://dl.acm.org/doi/pdf/10.1145/133994.134011

    Returns:
        (centers, normals): neighborhood centroids and consistently oriented normals, each (N, 3)
    """
    centers, normals = jax.vmap(partial(_local_pca, coordinates))(neighbors)
    return centers, _orient_normals(normals, neighbors)


def nearest_neighbors_fn(displacement_or_metric: space.DisplacementOrMetricFn,
                         dist_cutoff: float) -> Callable[[Array], Array]:
    """
    Construct a function that returns the boolean neighbor matrix of shape (N, N), where
    neighbors[i, j] is True iff the distance between particles i and j is below dist_cutoff.

    The diagonal is True whenever dist_cutoff > 0 (a particle is at distance 0 from itself),
    and the matrix is symmetric as long as the metric is.
    """
    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)

    def nearest_neighbors_matrix(coord: Array) -> Array:
        dist = space.map_product(metric)(coord, coord)
        return dist < dist_cutoff

    return nearest_neighbors_matrix


def edge_detection_fn(displacement_or_metric: space.DisplacementOrMetricFn,
                      classification_threshold: float,
                      num_neighbors: int = 12) -> Callable[[Array, Array], Array]:
    """
    Construct edge detection function based on the distance between the center of mass of all
    the neighbors and the particle of interest.
    Adapted from: https://arxiv.org/pdf/1809.10468.pdf

    Args:
        displacement_or_metric: displacement or metric function
        classification_threshold: factor of resolution (the smallest distance between a particle
            and one of its neighbors) that determines if the center of mass for all neighboring
            particles is too far which classifies the particle as an edge particle
        num_neighbors: maximum number of neighbor particles taken into account (the first ones
            by index); particles with fewer neighbors use all of them

    Returns:
        Edge detection function that takes the coordinates of all particles and the neighbor
        matrix and returns a boolean array marking edge particles
    """
    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
    dist_fn = jnp.vectorize(metric, signature='(d),(d)->()')

    @jax.jit
    def detect_edge_particles(coordinates: Array, neighbors: Array) -> Array:
        num_particles = coordinates.shape[0]
        slots = jnp.arange(num_neighbors)

        def is_edge(idx):
            # Fixed-size index list for jit; slots beyond the actual neighbor count are padded
            # with a valid dummy index and masked out so they cannot bias the result.
            neighbor_idx = jnp.nonzero(neighbors[idx], size=num_neighbors, fill_value=0)[0]
            valid = slots < jnp.count_nonzero(neighbors[idx])
            neighbor_coord = coordinates[neighbor_idx]

            n_valid = jnp.maximum(jnp.sum(valid), 1)
            neigh_cm = jnp.sum(jnp.where(valid[:, None], neighbor_coord, 0.), axis=0) / n_valid
            dist = jnp.where(valid, dist_fn(neighbor_coord, coordinates[idx]), jnp.inf)
            # NOTE(review): if `neighbors` includes the particle itself, resolution is 0 and any
            # offset of the center of mass classifies it as an edge — confirm callers mask the diagonal.
            resolution = jnp.min(dist)
            return dist_fn(neigh_cm, coordinates[idx]) > classification_threshold * resolution

        return jax.vmap(is_edge)(jnp.arange(num_particles))

    return detect_edge_particles


def gaussian_curvature_fn(displacement_or_metric: space.DisplacementOrMetricFn,
                          dist_cutoff: float) -> Callable[[Array], Array]:
    """Construct a function that returns the gaussian curvature at each given point in a point cloud."""
    nearest_neighbors = nearest_neighbors_fn(displacement_or_metric, dist_cutoff=dist_cutoff)

    def gaussian_curvature(coord: Array) -> Array:
        num_particles = coord.shape[0]
        neighbors = nearest_neighbors(coord)
        # PCA neighborhood centroids are used as the surface sample points
        centers, normals = determine_tangent_planes(coord, neighbors)
        # util.diagonal_mask presumably removes self-pairs from the fit — confirm in util
        curvatures_at = partial(principal_curvature,
                                coord=centers,
                                normal=normals,
                                neighbors=util.diagonal_mask(neighbors))
        curvature1, curvature2 = jax.vmap(curvatures_at)(jnp.arange(num_particles))
        return curvature1 * curvature2

    return gaussian_curvature