12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
from __future__ import annotations

from functools import partial
from typing import Callable, TypeVar

import jax
import jax.numpy as jnp
from jax_md import dataclasses
from jaxopt import GaussNewton, LevenbergMarquardt
# Shorthand used in this module's type annotations for JAX arrays.
Array = jax.numpy.ndarray
# Generic type variable.
T = TypeVar("T")
@dataclasses.dataclass
class GeneralQuadraticSurfaceParams:
    """Parameters of a general quadric surface ``x^T Q x + l . x + k = 0``.

    Attributes:
        quadratic_form_flat: The six independent entries ``(a, b, c, d, e, f)``
            of the symmetric 3x3 matrix ``Q``. The ``quadratic_form`` property
            expands them to ``[[a, b, c], [b, d, e], [c, e, f]]``. Defaults to
            the entries of the identity matrix.
        linear: The linear coefficient vector ``l``. Defaults to zeros.

    The constant term ``k`` is not stored in this class.
    """

    # Float literals keep the default a floating-point array. Integer-valued
    # initial parameters are known to make the jaxopt solvers fail, so an
    # integer default would be a trap for callers relying on it.
    quadratic_form_flat: Array = dataclasses.field(
        default_factory=lambda: jnp.array([1.0, 0.0, 0.0, 1.0, 0.0, 1.0]))
    linear: Array = dataclasses.field(default_factory=lambda: jnp.zeros((3,)))

    @property
    def quadratic_form(self) -> Array:
        """Return the symmetric 3x3 matrix ``Q`` rebuilt from ``quadratic_form_flat``."""
        a, b, c, d, e, f = self.quadratic_form_flat
        return jnp.array([[a, b, c],
                          [b, d, e],
                          [c, e, f]])
@partial(jax.jit, static_argnums=(3,))
def quadratic_surface(params: GeneralQuadraticSurfaceParams, coord: Array, mask: Array, constant: int = -1) -> Array:
    """Compute masked residuals of the quadric ``x^T A x + l . x + constant``.

    One residual is produced per point (row of ``coord``).

    Args:
        params: Surface parameters. ``params.quadratic_form`` supplies ``A``
            and ``params.linear`` supplies ``l``.
        coord: Point coordinates, one point per row. The row-wise reductions
            below assume a 2-D array, presumably of shape ``(n, 3)``.
        mask: Per-point factors multiplied into the residuals. A zero entry
            removes that point's contribution.
        constant: Constant term; must be -1, 0 or 1. It is a static argument
            of ``jax.jit``, so each distinct value is compiled separately.

    Returns:
        The masked residual for every row of ``coord``.

    Raises:
        ValueError: If ``constant`` is not one of -1, 0, 1.
    """
    permitted_constants = (-1, 0, 1)
    if constant not in permitted_constants:
        raise ValueError(f"Quadratic surface constant should be -1, 0, or 1, got {constant}.")
    # Row-wise x^T A x: form x^T A for every point, then dot it with x.
    row_times_form = coord @ params.quadratic_form
    quadratic_part = (coord * row_times_form).sum(axis=1)
    linear_part = (coord * params.linear).sum(axis=1)
    unmasked = quadratic_part + linear_part + constant
    return unmasked * mask
# Signature of the residual functions used by ``surface_fit_gn``:
# (params, coord, mask) -> residual array.
GeneralSurfaceFn = Callable[
    [GeneralQuadraticSurfaceParams, Array, Array],
    Array,
]
def surface_fit_gn(surface_fn: GeneralSurfaceFn,
                   coord: Array,
                   mask: Array,
                   p0: GeneralQuadraticSurfaceParams,
                   verbose: bool = False,
                   maxiter: int = 20,
                   tol: float | None = None) -> GeneralQuadraticSurfaceParams:
    """Fit a surface to masked particles using the Gauss-Newton method.

    Args:
        surface_fn: Residual function. The solver calls it as
            ``surface_fn(params, coord, mask)``.
        coord: Particle coordinates, forwarded unchanged to ``surface_fn``.
        mask: Mask over ``coord``, forwarded unchanged to ``surface_fn``.
        p0: Initial parameters. Every pytree leaf is cast to JAX's default
            floating dtype before solving.
        verbose: Forwarded to ``jaxopt.GaussNewton``.
        maxiter: Maximum number of Gauss-Newton iterations. The default of 20
            matches the value previously hard-coded here.
        tol: Convergence tolerance passed to ``jaxopt.GaussNewton``. ``None``
            keeps jaxopt's own default.

    Returns:
        The fitted parameters. The final solver state is discarded, so
        whether the solver converged within ``maxiter`` is not reported.
    """
    # Only pass ``tol`` when the caller set it, so jaxopt's default still applies.
    solver_options = {} if tol is None else {"tol": tol}
    gn = GaussNewton(residual_fun=surface_fn, maxiter=maxiter,
                     verbose=verbose, **solver_options)
    # Integer-valued initial params trigger a "bytes-like object" TypeError
    # inside jaxopt, so cast every leaf to floating point first.
    # ``dtype=float`` resolves to float64 when x64 is enabled and float32
    # otherwise. This avoids the truncation warning JAX emits when
    # jnp.float64 is requested while x64 is disabled.
    p0 = jax.tree_util.tree_map(partial(jnp.asarray, dtype=float), p0)
    opt = gn.run(p0, coord, mask)
    return opt.params
|