from __future__ import annotations from jaxopt import GaussNewton, LevenbergMarquardt import jax.numpy as jnp from typing import Callable, TypeVar import jax from functools import partial from jax_md import dataclasses Array = jnp.ndarray T = TypeVar('T') @dataclasses.dataclass class GeneralQuadraticSurfaceParams: quadratic_form_flat: Array = dataclasses.field(default_factory=lambda: jnp.array([1, 0, 0, 1, 0, 1])) linear: Array = dataclasses.field(default_factory=lambda: jnp.zeros((3,))) @property def quadratic_form(self): a, b, c, d, e, f = self.quadratic_form_flat return jnp.array([[a, b, c], [b, d, e], [c, e, f]]) # def to_array(self) -> Array: # return jnp.hstack((self.center, self.euler, jnp.array([self.radius]))) # @staticmethod # def from_array(x: Array) -> QuadraticSurfaceParams: # return QuadraticSurfaceParams(x[:3], x[3:6], x[6]) @partial(jax.jit, static_argnums=(3,)) def quadratic_surface(params: GeneralQuadraticSurfaceParams, coord: Array, mask: Array, constant: int = -1) -> Array: """ Residual function for fitting a general quadric surface to a group of particles defined by coord array and a mask over coord. """ if constant not in (-1, 0, 1): raise ValueError(f"Quadratic surface constant should be -1, 0, or 1, got {constant}.") quadratic_term = jnp.sum(coord * (coord @ params.quadratic_form), axis=1) linear_term = jnp.sum(coord * params.linear, axis=1) return (quadratic_term + linear_term + constant) * mask GeneralSurfaceFn = Callable[[GeneralQuadraticSurfaceParams, Array, Array], Array] def surface_fit_gn(surface_fn: GeneralSurfaceFn, coord: Array, mask: Array, p0: GeneralQuadraticSurfaceParams, verbose: bool = False) -> GeneralQuadraticSurfaceParams: """ Fit a surface to a group of particles defined by coord array and a mask over coord using the Gauss-Newton method. """ gn = GaussNewton(residual_fun=surface_fn, maxiter=20, verbose=verbose) # we want to avoid "bytes-like object" TypeError if initial params are given as integers: p0 = jax.tree_util.tree_map(partial(jnp.asarray, dtype=jnp.float64), p0) opt = gn.run(p0, coord, mask) return opt.params