surface_fit.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. from __future__ import annotations
  2. from jaxopt import GaussNewton, LevenbergMarquardt
  3. import jax.numpy as jnp
  4. from typing import Callable, TypeVar
  5. import jax
  6. from functools import partial
  7. from jax_md import dataclasses
  8. Array = jnp.ndarray
  9. T = TypeVar('T')
  10. @dataclasses.dataclass
  11. class QuadraticSurfaceParams:
  12. center: Array = dataclasses.field(default_factory=lambda: jnp.zeros((3,)))
  13. euler: Array = dataclasses.field(default_factory=lambda: jnp.zeros((3,)))
  14. radius: float = 1.
  15. def to_array(self) -> Array:
  16. return jnp.hstack((self.center, self.euler, jnp.array([self.radius])))
  17. @staticmethod
  18. def from_array(x: Array) -> QuadraticSurfaceParams:
  19. return QuadraticSurfaceParams(x[:3], x[3:6], x[6])
  20. def rotation_matrix(euler_angles: Array) -> Array:
  21. alpha, beta, gamma = euler_angles
  22. Rz1 = jnp.array([[jnp.cos(alpha), -jnp.sin(alpha), 0],
  23. [jnp.sin(alpha), jnp.cos(alpha), 0],
  24. [0, 0, 1]])
  25. Ry = jnp.array([[jnp.cos(beta), 0, -jnp.sin(beta)],
  26. [0, 1, 0],
  27. [jnp.sin(beta), 0, jnp.cos(beta)]])
  28. Rz2 = jnp.array([[jnp.cos(gamma), -jnp.sin(gamma), 0],
  29. [jnp.sin(gamma), jnp.cos(gamma), 0],
  30. [0, 0, 1]])
  31. return Rz2 @ Ry @ Rz1
  32. @partial(jnp.vectorize, signature='(d,d),(d)->(d,d)')
  33. def quadratic_form(rot_mat, eigvals: Array):
  34. a, b, c = eigvals
  35. eig_mat = jnp.array([[a, 0, 0],
  36. [0, b, 0],
  37. [0, 0, c]])
  38. return rot_mat @ eig_mat @ jnp.transpose(rot_mat)
# Type of the residual callables accepted by the fitters in this module: invoked as
# surface_fn(params, coord, mask) and returning an array of per-particle residuals.
SurfaceFn = Callable[[QuadraticSurfaceParams, Array, Array], Array]
  40. def spherical_surface(params: QuadraticSurfaceParams, coord: Array, mask: Array) -> Array:
  41. """
  42. Residual function for fitting a spherical surface to a group of particles defined by coord array and
  43. a mask over coord.
  44. """
  45. return (jnp.linalg.norm(coord - params.center, axis=1) - jnp.abs(params.radius)) * mask
  46. def quadratic_surface(params: QuadraticSurfaceParams, coord: Array, mask: Array, qf_eigvals: Array) -> Array:
  47. """
  48. Residual function for fitting a cylinder to a group of particles defined by coord array and a mask over coord.
  49. """
  50. relative_coord = coord - params.center
  51. qf = quadratic_form(rotation_matrix(params.euler), qf_eigvals)
  52. # return (jnp.sum(relative_coord * (relative_coord @ qf), axis=1) ** 2 - jnp.abs(params.radius) ** 2) * mask
  53. return (jnp.sqrt(jnp.sum(relative_coord * (relative_coord @ qf), axis=1)) - jnp.abs(params.radius)) * mask
# Cylinder: with u = rotation_matrix(euler).T @ (x - center), eigenvalues (1, 1, 0) give
# q = u_x^2 + u_y^2, so the residual is the distance from the axis along the third column of
# the rotation matrix minus |radius|.
cylindrical_surface = partial(quadratic_surface, qf_eigvals=jnp.array([1., 1., 0.]))
# Hyperboloid: eigenvalues (1, 1, -2) give a zero residual on u_x^2 + u_y^2 - 2 u_z^2 = radius^2
# in the same rotated frame. NOTE(review): q is negative for points with
# u_x^2 + u_y^2 < 2 u_z^2; how those are handled depends on quadratic_surface's square root.
hyperbolic_surface = partial(quadratic_surface, qf_eigvals=jnp.array([1., 1., -2.]))
  56. def surface_fit_gn(surface_fn: SurfaceFn, coord: Array, mask: Array, p0: T, verbose: bool = False) -> T:
  57. """
  58. Fit a surface to a group of particles defined by coord array and a mask over coord using the Gauss-Newton method.
  59. """
  60. gn = GaussNewton(residual_fun=surface_fn, maxiter=20, verbose=verbose)
  61. # we want to avoid "bytes-like object" TypeError if initial params are given as integers:
  62. p0 = jax.tree_util.tree_map(partial(jnp.asarray, dtype=jnp.float64), p0)
  63. opt = gn.run(p0, coord, mask)
  64. return opt.params
  65. def surface_fit_lm(surface_fn: SurfaceFn,
  66. coord: Array,
  67. mask: Array,
  68. p0: QuadraticSurfaceParams,
  69. verbose: bool = False) -> QuadraticSurfaceParams:
  70. """
  71. Fit a surface to a group of particles defined by coord array and a mask over coord
  72. using the Levenberg-Marquardt method. Doesn't seem to work with gradient calculation over hyperparameters.
  73. """
  74. def unraveled_fn(x):
  75. params = QuadraticSurfaceParams.from_array(x)
  76. return surface_fn(params, coord, mask)
  77. p0 = jax.tree_util.tree_map(partial(jnp.asarray, dtype=jnp.float64), p0)
  78. p0_array = p0.to_array()
  79. lm = LevenbergMarquardt(residual_fun=unraveled_fn, maxiter=20, verbose=verbose)
  80. opt = lm.run(p0_array)
  81. return QuadraticSurfaceParams.from_array(opt.params)