# surface_fit_general.py (2.3 KB)

  1. from __future__ import annotations
  2. from jaxopt import GaussNewton, LevenbergMarquardt
  3. import jax.numpy as jnp
  4. from typing import Callable, TypeVar
  5. import jax
  6. from functools import partial
  7. from jax_md import dataclasses
  8. Array = jnp.ndarray
  9. T = TypeVar('T')
  10. @dataclasses.dataclass
  11. class GeneralQuadraticSurfaceParams:
  12. quadratic_form_flat: Array = dataclasses.field(default_factory=lambda: jnp.array([1, 0, 0, 1, 0, 1]))
  13. linear: Array = dataclasses.field(default_factory=lambda: jnp.zeros((3,)))
  14. @property
  15. def quadratic_form(self):
  16. a, b, c, d, e, f = self.quadratic_form_flat
  17. return jnp.array([[a, b, c],
  18. [b, d, e],
  19. [c, e, f]])
  20. # def to_array(self) -> Array:
  21. # return jnp.hstack((self.center, self.euler, jnp.array([self.radius])))
  22. # @staticmethod
  23. # def from_array(x: Array) -> QuadraticSurfaceParams:
  24. # return QuadraticSurfaceParams(x[:3], x[3:6], x[6])
  25. @partial(jax.jit, static_argnums=(3,))
  26. def quadratic_surface(params: GeneralQuadraticSurfaceParams, coord: Array, mask: Array, constant: int = -1) -> Array:
  27. """
  28. Residual function for fitting a general quadric surface to a group of particles defined by coord array and a mask over coord.
  29. """
  30. if constant not in (-1, 0, 1):
  31. raise ValueError(f"Quadratic surface constant should be -1, 0, or 1, got {constant}.")
  32. quadratic_term = jnp.sum(coord * (coord @ params.quadratic_form), axis=1)
  33. linear_term = jnp.sum(coord * params.linear, axis=1)
  34. return (quadratic_term + linear_term + constant) * mask
  35. GeneralSurfaceFn = Callable[[GeneralQuadraticSurfaceParams, Array, Array], Array]
  36. def surface_fit_gn(surface_fn: GeneralSurfaceFn,
  37. coord: Array,
  38. mask: Array,
  39. p0: GeneralQuadraticSurfaceParams,
  40. verbose: bool = False) -> GeneralQuadraticSurfaceParams:
  41. """
  42. Fit a surface to a group of particles defined by coord array and a mask over coord using the Gauss-Newton method.
  43. """
  44. gn = GaussNewton(residual_fun=surface_fn, maxiter=20, verbose=verbose)
  45. # we want to avoid "bytes-like object" TypeError if initial params are given as integers:
  46. p0 = jax.tree_util.tree_map(partial(jnp.asarray, dtype=jnp.float64), p0)
  47. opt = gn.run(p0, coord, mask)
  48. return opt.params