# curvature_estimation.py
  1. from typing import Callable
  2. import jax.numpy as jnp
  3. from functools import partial
  4. import jax
  5. from jax_md import space
  6. from curvature_assembly import util
  7. import numpy as np
# Shorthand for the array type used in the annotations of this module.
Array = jnp.ndarray
  9. def construct_eigensystem(normal: Array) -> Array:
  10. """Construct eigensystem where z-axis is given by the surface normal at the given point."""
  11. psi = jnp.arccos(normal[2])
  12. phi = jnp.arctan2(normal[1], normal[0])
  13. return jnp.array([[-jnp.sin(phi), jnp.cos(phi), 0],
  14. [jnp.cos(psi) * jnp.cos(phi), jnp.cos(psi) * jnp.sin(phi), -jnp.sin(psi)],
  15. [normal[0], normal[1], normal[2]]]).T
  16. def position_to_local_coordinates(eigensystem: Array, center_position: Array, neighbors_coord_global: Array) -> Array:
  17. """
  18. Transform coordinates of neighbor particles to the coordinate system centered at center_position
  19. and axes given in eigensystem matrix.
  20. """
  21. return jnp.dot(neighbors_coord_global - center_position[None, :], eigensystem)
  22. def normal_to_local_coordinates(eigensystem: Array, neighbors_normal_global: Array) -> Array:
  23. """
  24. Transform normals at neighbor particle locations to the coordinate system centered at center_position
  25. and axes given in eigensystem matrix.
  26. """
  27. return jnp.dot(neighbors_normal_global, eigensystem)
  28. @partial(jnp.vectorize, signature='(d),(d)->()')
  29. def point_main_curvature(neighbor_coord_local: Array, neighbor_normal_local: Array) -> Array:
  30. """Normal curvature estimate given a point and its normal in the coordinate system of the center particle."""
  31. # 1e-8 makes it safe on diagonal where dist2d=0
  32. dist2d = jnp.sqrt(neighbor_coord_local[0] ** 2 + neighbor_coord_local[1] ** 2) + 1e-8
  33. n_xy = (neighbor_coord_local[0] * neighbor_normal_local[0] +
  34. neighbor_coord_local[1] * neighbor_normal_local[1]) / dist2d
  35. return -n_xy / (jnp.sqrt(n_xy ** 2 + neighbor_normal_local[2] ** 2) * dist2d)
  36. def construct_coefficient_matrix(neighbors_local: Array) -> Array:
  37. """Construct coefficient matrix for least square fitting of curvature parameters."""
  38. thetas = jnp.arctan2(neighbors_local[:, 1], neighbors_local[:, 0])
  39. sin_thetas = jnp.sin(thetas)
  40. cos_thetas = jnp.cos(thetas)
  41. coefficient_matrix = jnp.zeros(neighbors_local.shape)
  42. coefficient_matrix = coefficient_matrix.at[:, 0].set(cos_thetas ** 2)
  43. coefficient_matrix = coefficient_matrix.at[:, 1].set(cos_thetas * sin_thetas)
  44. coefficient_matrix = coefficient_matrix.at[:, 2].set(sin_thetas ** 2)
  45. return coefficient_matrix
  46. def principal_curvature(idx: int, coord: Array, normal: Array, neighbors: Array) -> (Array, Array):
  47. """
  48. Calculate the two principal curvatures at the `idx˙ particle.
  49. Algorithm from Zhang et al., "Curvature Estimation of 3D Point Cloud Surfaces Through the Fitting of Normal
  50. Section Curvatures", http://www.nlpr.ia.ac.cn/2008papers/gjhy/gh129.pdf
  51. """
  52. eigensystem = construct_eigensystem(normal[idx])
  53. coord_local = position_to_local_coordinates(eigensystem, coord[idx], coord)
  54. normal_local = normal_to_local_coordinates(eigensystem, normal)
  55. # jax.debug.print("{}", jnp.sum(neighbors[idx]))
  56. coefficient_matrix = construct_coefficient_matrix(coord_local) * neighbors[idx][:, None] # we add neighbors mask
  57. neighbor_curvatures = point_main_curvature(coord_local, normal_local) * neighbors[idx]
  58. # jax.debug.print("{}", neighbor_curvatures)
  59. # we get curvature parameters by least square fitting
  60. curve_params, residuals, rank, s = jnp.linalg.lstsq(coefficient_matrix, neighbor_curvatures)
  61. discriminant_sqrt = jnp.sqrt((curve_params[0] - curve_params[2]) ** 2 + 4 * curve_params[1] ** 2)
  62. curvature1 = 0.5 * (curve_params[0] + curve_params[2] - discriminant_sqrt)
  63. curvature2 = 0.5 * (curve_params[0] + curve_params[2] + discriminant_sqrt)
  64. # jax.debug.print("{}, {}", curvature1, curvature2)
  65. return curvature1, curvature2
  66. @jax.jit
  67. def minimum_spanning_tree(normals: Array, neighbors_full: Array) -> Array:
  68. """
  69. Determine minimum spanning tree for a graph with links based on neighbors and weights (energies) 1 - |ni.nj| using
  70. the Prim's algorithm.
  71. Args:
  72. normals: normal vectors at each node, shape (N, 3)
  73. neighbors_full: boolean matrix of shape (N, N) where True elements describe neighbor particles
  74. Returns:
  75. minimum spanning tree of the graph
  76. """
  77. num_nodes = normals.shape[0]
  78. # calculate weights, w[i, j] between 0 and 1
  79. weights = 1. - jnp.abs(jnp.einsum('nmk, nmk -> nm', normals[None, :, :], normals[:, None, :]))
  80. weights = weights + 10. * (1 - neighbors_full) # effectively erases connections between nodes
  81. selected_nodes = jnp.zeros(num_nodes, dtype=bool)
  82. selected_nodes = selected_nodes.at[0].set(True)
  83. node_order = jnp.zeros(num_nodes, dtype=jnp.int32)
  84. def add_node(carry, i):
  85. node_order, selected_nodes, weights = carry
  86. # find minimum energy link among all links connected to the already selected nodes of the spanning tree
  87. # argmin returns first occurrence in flattened array which we then unravel
  88. min_idx = jnp.unravel_index(jnp.argmin(weights[node_order]), shape=(num_nodes, num_nodes))
  89. # the next idx in the spanning tree will be the second element in min_idx
  90. selected_nodes = selected_nodes.at[min_idx[1]].set(True)
  91. node_order = node_order.at[i].set(min_idx[1])
  92. # erase possible remaining connections between all selected nodes to prevent formation of loops
  93. mask = selected_nodes[:, None] * selected_nodes[None, :]
  94. weights += 1. * mask
  95. return (node_order, selected_nodes, weights), jnp.array([node_order[min_idx[0]], min_idx[1]])
  96. _, link_list = jax.lax.scan(add_node, init=(node_order, selected_nodes, weights), xs=jnp.arange(1, num_nodes))
  97. return link_list
  98. @jax.jit
  99. def determine_normals(coordinates: Array, neighbors: Array) -> Array:
  100. """
  101. Determine local surface normals at each point with PCA and ensuring consistent surface orientation.
  102. Algorithm from Hoppe et al., "Surface Reconstruction from Unorganized Points",
  103. https://dl.acm.org/doi/pdf/10.1145/133994.134011
  104. """
  105. def get_normal(neighbor_mask):
  106. n_part = jnp.sum(neighbor_mask)
  107. neighbor_coord = coordinates * neighbor_mask[:, None]
  108. centered = (neighbor_coord - jnp.sum(neighbor_coord, axis=0) / (n_part + 1e-8)) * neighbor_mask[:, None]
  109. matrix = centered.T @ centered
  110. values, vecs = jnp.linalg.eigh(matrix)
  111. return vecs[:, 0] # eigenvector corresponding to the smallest eigenvalue
  112. normals = jax.vmap(get_normal)(neighbors)
  113. link_list = minimum_spanning_tree(normals, neighbors)
  114. def update_normals(normals, idx):
  115. normals = jax.lax.cond(jnp.sum(normals[link_list[idx, 0]] * normals[link_list[idx, 1]]) < 0,
  116. lambda x: x.at[link_list[idx, 1]].set(-x[link_list[idx, 1]]), lambda x: x, normals)
  117. return normals, 0.
  118. consistent_normals, _ = jax.lax.scan(update_normals, init=normals, xs=jnp.arange(len(link_list)))
  119. return consistent_normals
  120. @jax.jit
  121. def determine_tangent_planes(coordinates: Array, neighbors: Array) -> (Array, Array):
  122. """
  123. Determine local surface normals at each point with PCA and ensuring consistent surface orientation.
  124. Algorithm from Hoppe et al., "Surface Reconstruction from Unorganized Points",
  125. https://dl.acm.org/doi/pdf/10.1145/133994.134011
  126. """
  127. def get_normal(neighbor_mask):
  128. n_part = jnp.sum(neighbor_mask)
  129. neighbor_coord = coordinates * neighbor_mask[:, None]
  130. centered = (neighbor_coord - jnp.sum(neighbor_coord, axis=0) / (n_part + 1e-8)) * neighbor_mask[:, None]
  131. matrix = centered.T @ centered
  132. values, vecs = jnp.linalg.eigh(matrix)
  133. return jnp.sum(neighbor_coord, axis=0) / (n_part + 1e-8), vecs[:, 0] # eigenvector corresponding to the smallest eigenvalue
  134. centers, normals = jax.vmap(get_normal)(neighbors)
  135. link_list = minimum_spanning_tree(normals, neighbors)
  136. def update_normals(normals, idx):
  137. normals = jax.lax.cond(jnp.sum(normals[link_list[idx, 0]] * normals[link_list[idx, 1]]) < 0,
  138. lambda x: x.at[link_list[idx, 1]].set(-x[link_list[idx, 1]]), lambda x: x, normals)
  139. return normals, 0.
  140. consistent_normals, _ = jax.lax.scan(update_normals, init=normals, xs=jnp.arange(len(link_list)))
  141. return centers, consistent_normals
  142. def nearest_neighbors_fn(displacement_or_metric: space.DisplacementOrMetricFn,
  143. dist_cutoff: float) -> Callable[[Array], Array]:
  144. """
  145. Construct a function that determines boolean nearest neighbor matrix with dimensions (N, N)
  146. where element neighbors[i,j] is True only if particle j is among num_neighbors nearest neighbors of particle i.
  147. The returned array is NOT necessarily symmetric.
  148. """
  149. metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  150. def condition(d):
  151. return jax.lax.cond(0. < d < dist_cutoff, lambda x: 1., lambda x: 0., 0)
  152. # @jax.jit
  153. def nearest_neighbors_matrix(coord: Array) -> Array:
  154. dist = space.map_product(metric)(coord, coord)
  155. # jax.debug.print("Num neighbors: {}", jnp.sum(dist < dist_cutoff, axis=1))
  156. # jax.debug.print("Determinant: {}", jnp.linalg.det(util.diagonal_mask(dist < dist_cutoff)))
  157. # return util.diagonal_mask(dist < dist_cutoff)
  158. return dist < dist_cutoff
  159. # return util.diagonal_mask(jnp.asarray(np.random.randint(2, size=(coord.shape[0], coord.shape[0]))))
  160. # neighbors = jax.vmap(jax.vmap(condition))(dist)
  161. # return neighbors
  162. return nearest_neighbors_matrix
  163. def edge_detection_fn(displacement_or_metric: space.DisplacementOrMetricFn,
  164. classification_threshold: float,
  165. num_neighbors: int = 12) -> Callable[[Array, Array], Array]:
  166. """
  167. Construct edge detection function based on the distance between the center of mass of all the neighbors
  168. and the particle of interest. Adapted from: https://arxiv.org/pdf/1809.10468.pdf
  169. Args:
  170. displacement_or_metric: displacement or metric function
  171. classification_threshold: factor of resolution (the smallest distance between a particle and one of its
  172. neighbors) that determines if the center of mass for all neighboring particles is too far which classifies
  173. the particle as an edge particle
  174. num_neighbors: number of neighbor particles taken into account
  175. Returns:
  176. Edge detection function that takes the coordinates of all particles and boolean neighboring matrix
  177. """
  178. metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  179. @jax.jit
  180. def detect_edge_particles(coordinates: Array, neighbors: Array) -> Array:
  181. num_particles = coordinates.shape[0]
  182. dist_fn = jnp.vectorize(metric, signature='(d),(d)->()')
  183. def vmap_fn(idx):
  184. # fill value in jnp.where doesn't matter as we specify the exact number of neighbors
  185. neighbor_coord = coordinates[jnp.where(neighbors[idx], size=num_neighbors, fill_value=num_particles)]
  186. neigh_cm = jnp.mean(neighbor_coord, axis=0)
  187. dist = dist_fn(neighbor_coord, coordinates[idx])
  188. resolution = jnp.min(dist)
  189. edge_particle = jax.lax.cond(dist_fn(neigh_cm, coordinates[idx]) > classification_threshold * resolution,
  190. lambda x: True, lambda x: False, 0.)
  191. return edge_particle
  192. edge_particles = jax.vmap(vmap_fn)(jnp.arange(num_particles))
  193. return edge_particles
  194. return detect_edge_particles
  195. def gaussian_curvature_fn(displacement_or_metric: space.DisplacementOrMetricFn,
  196. dist_cutoff: float) -> Callable[[Array], Array]:
  197. """Construct a function that returns the gaussian curvature at each given point in a point cloud."""
  198. nearest_neighbors = nearest_neighbors_fn(displacement_or_metric, dist_cutoff=dist_cutoff)
  199. def gaussian_curvature(coord: Array) -> Array:
  200. num_particles = coord.shape[0]
  201. neighbors = nearest_neighbors(coord)
  202. # normals = determine_normals(coord, neighbors)
  203. centers, normals = determine_tangent_planes(coord, neighbors)
  204. # print(coord)
  205. curvature1, curvature2 = jax.vmap(partial(principal_curvature,
  206. coord=centers, # or coord=centers
  207. normal=normals,
  208. neighbors=util.diagonal_mask(neighbors)))(jnp.arange(num_particles))
  209. return curvature1 * curvature2
  210. return gaussian_curvature