from __future__ import annotations from jaxopt import GaussNewton, LevenbergMarquardt import jax.numpy as jnp from typing import Callable, TypeVar import jax from functools import partial from jax_md import dataclasses Array = jnp.ndarray T = TypeVar('T') @dataclasses.dataclass class QuadraticSurfaceParams: center: Array = dataclasses.field(default_factory=lambda: jnp.zeros((3,))) euler: Array = dataclasses.field(default_factory=lambda: jnp.zeros((3,))) radius: float = 1. def to_array(self) -> Array: return jnp.hstack((self.center, self.euler, jnp.array([self.radius]))) @staticmethod def from_array(x: Array) -> QuadraticSurfaceParams: return QuadraticSurfaceParams(x[:3], x[3:6], x[6]) def rotation_matrix(euler_angles: Array) -> Array: alpha, beta, gamma = euler_angles Rz1 = jnp.array([[jnp.cos(alpha), -jnp.sin(alpha), 0], [jnp.sin(alpha), jnp.cos(alpha), 0], [0, 0, 1]]) Ry = jnp.array([[jnp.cos(beta), 0, -jnp.sin(beta)], [0, 1, 0], [jnp.sin(beta), 0, jnp.cos(beta)]]) Rz2 = jnp.array([[jnp.cos(gamma), -jnp.sin(gamma), 0], [jnp.sin(gamma), jnp.cos(gamma), 0], [0, 0, 1]]) return Rz2 @ Ry @ Rz1 @partial(jnp.vectorize, signature='(d,d),(d)->(d,d)') def quadratic_form(rot_mat, eigvals: Array): a, b, c = eigvals eig_mat = jnp.array([[a, 0, 0], [0, b, 0], [0, 0, c]]) return rot_mat @ eig_mat @ jnp.transpose(rot_mat) SurfaceFn = Callable[[QuadraticSurfaceParams, Array, Array], Array] def spherical_surface(params: QuadraticSurfaceParams, coord: Array, mask: Array) -> Array: """ Residual function for fitting a spherical surface to a group of particles defined by coord array and a mask over coord. """ return (jnp.linalg.norm(coord - params.center, axis=1) - jnp.abs(params.radius)) * mask def quadratic_surface(params: QuadraticSurfaceParams, coord: Array, mask: Array, qf_eigvals: Array) -> Array: """ Residual function for fitting a cylinder to a group of particles defined by coord array and a mask over coord. """ relative_coord = coord - params.center qf = quadratic_form(rotation_matrix(params.euler), qf_eigvals) # return (jnp.sum(relative_coord * (relative_coord @ qf), axis=1) ** 2 - jnp.abs(params.radius) ** 2) * mask return (jnp.sqrt(jnp.sum(relative_coord * (relative_coord @ qf), axis=1)) - jnp.abs(params.radius)) * mask cylindrical_surface = partial(quadratic_surface, qf_eigvals=jnp.array([1., 1., 0.])) hyperbolic_surface = partial(quadratic_surface, qf_eigvals=jnp.array([1., 1., -2.])) def surface_fit_gn(surface_fn: SurfaceFn, coord: Array, mask: Array, p0: T, verbose: bool = False) -> T: """ Fit a surface to a group of particles defined by coord array and a mask over coord using the Gauss-Newton method. """ gn = GaussNewton(residual_fun=surface_fn, maxiter=20, verbose=verbose) # we want to avoid "bytes-like object" TypeError if initial params are given as integers: p0 = jax.tree_util.tree_map(partial(jnp.asarray, dtype=jnp.float64), p0) opt = gn.run(p0, coord, mask) return opt.params def surface_fit_lm(surface_fn: SurfaceFn, coord: Array, mask: Array, p0: QuadraticSurfaceParams, verbose: bool = False) -> QuadraticSurfaceParams: """ Fit a surface to a group of particles defined by coord array and a mask over coord using the Levenberg-Marquardt method. Doesn't seem to work with gradient calculation over hyperparameters. """ def unraveled_fn(x): params = QuadraticSurfaceParams.from_array(x) return surface_fn(params, coord, mask) p0 = jax.tree_util.tree_map(partial(jnp.asarray, dtype=jnp.float64), p0) p0_array = p0.to_array() lm = LevenbergMarquardt(residual_fun=unraveled_fn, maxiter=20, verbose=verbose) opt = lm.run(p0_array) return QuadraticSurfaceParams.from_array(opt.params)