123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109 |
- from __future__ import annotations
- from jaxopt import GaussNewton, LevenbergMarquardt
- import jax.numpy as jnp
- from typing import Callable, TypeVar
- import jax
- from functools import partial
- from jax_md import dataclasses
# Shorthand for JAX arrays in type annotations.
Array = jnp.ndarray
# Generic parameter-container type: the Gauss-Newton fit returns the type it is given.
T = TypeVar("T")
@dataclasses.dataclass
class QuadraticSurfaceParams:
    """Parameters of a surface to fit: a centre, three Euler angles and a radius.

    ``to_array`` and ``from_array`` convert to and from a single flat vector
    laid out as ``[center (3), euler (3), radius (1)]``.
    """

    # Point subtracted from the particle coordinates by the residual functions.
    center: Array = dataclasses.field(default_factory=lambda: jnp.zeros(3))
    # Angles (alpha, beta, gamma) consumed by ``rotation_matrix``.
    euler: Array = dataclasses.field(default_factory=lambda: jnp.zeros(3))
    # Target radius; the residual functions only use its absolute value.
    radius: float = 1.0

    def to_array(self) -> Array:
        """Pack the fields into one flat array, ordered center, euler, radius."""
        radius_as_vector = jnp.array([self.radius])
        pieces = (self.center, self.euler, radius_as_vector)
        return jnp.hstack(pieces)

    @staticmethod
    def from_array(x: Array) -> QuadraticSurfaceParams:
        """Rebuild the parameters from a flat array produced by ``to_array``."""
        center, euler, radius = x[:3], x[3:6], x[6]
        return QuadraticSurfaceParams(center=center, euler=euler, radius=radius)
def rotation_matrix(euler_angles: Array) -> Array:
    """Build a 3x3 rotation matrix from three angles ``(alpha, beta, gamma)``.

    The result is ``Z(gamma) @ Y(beta) @ Z(alpha)``, where ``Z`` is the usual
    rotation about the z axis. The middle factor ``Y`` carries ``-sin(beta)``
    in its top-right entry and ``+sin(beta)`` in its bottom-left entry, i.e.
    it equals the conventional right-handed y rotation evaluated at ``-beta``.

    Args:
        euler_angles: a length-3 sequence/array unpacked as alpha, beta, gamma.

    Returns:
        The composed 3x3 matrix.
    """
    alpha, beta, gamma = euler_angles

    def about_z(angle):
        # Standard counter-clockwise rotation about the z axis.
        cos_a, sin_a = jnp.cos(angle), jnp.sin(angle)
        return jnp.array([[cos_a, -sin_a, 0],
                          [sin_a, cos_a, 0],
                          [0, 0, 1]])

    cos_b, sin_b = jnp.cos(beta), jnp.sin(beta)
    about_y = jnp.array([[cos_b, 0, -sin_b],
                         [0, 1, 0],
                         [sin_b, 0, cos_b]])

    return about_z(gamma) @ about_y @ about_z(alpha)
@partial(jnp.vectorize, signature='(d,d),(d)->(d,d)')
def quadratic_form(rot_mat: Array, eigvals: Array) -> Array:
    """Assemble the symmetric matrix ``rot_mat @ diag(eigvals) @ rot_mat.T``.

    The function is wrapped in ``jnp.vectorize`` with core signature
    ``(d,d),(d)->(d,d)``, so leading batch dimensions on either argument are
    broadcast. The diagonal matrix is built with ``jnp.diag`` so that any
    dimension ``d`` allowed by the signature works, not only ``d == 3``.

    Args:
        rot_mat: ``(d, d)`` matrix whose columns act as the eigenvectors.
        eigvals: length-``d`` vector of eigenvalues.

    Returns:
        The ``(d, d)`` matrix of the quadratic form.
    """
    eig_mat = jnp.diag(eigvals)
    return rot_mat @ eig_mat @ jnp.transpose(rot_mat)
# Residual-function signature used by the fit routines: (params, coord, mask) -> residuals.
SurfaceFn = Callable[
    [QuadraticSurfaceParams, Array, Array],
    Array,
]
def spherical_surface(params: "QuadraticSurfaceParams", coord: Array, mask: Array) -> Array:
    """Residuals for fitting a sphere to a masked group of particles.

    For every row of ``coord`` the residual is
    ``(|coord - center| - |radius|) * mask``.

    Args:
        params: provides ``center`` and ``radius`` (only ``|radius|`` is used).
        coord: particle coordinates, one particle per row (the norm is taken
            over axis 1).
        mask: mask over ``coord``; entries equal to zero remove a particle
            from the fit. Expected to broadcast against the per-particle
            residual vector.

    Returns:
        The masked residuals.
    """
    offset = coord - params.center
    sq_dist = jnp.sum(offset * offset, axis=1)
    # The square root has an infinite derivative at zero, so a masked particle
    # sitting exactly at the centre (e.g. zero padding with a zero centre)
    # would yield 0 * inf = NaN in the gradient even though it is multiplied
    # by mask == 0. Feed such entries a harmless value before the sqrt.
    # Unmasked particles are untouched, so one lying exactly on the centre
    # still has an undefined gradient.
    safe_sq_dist = jnp.where(mask != 0, sq_dist, 1.0)
    return (jnp.sqrt(safe_sq_dist) - jnp.abs(params.radius)) * mask
def quadratic_surface(params: "QuadraticSurfaceParams", coord: Array, mask: Array, qf_eigvals: Array) -> Array:
    """Residuals for fitting a quadric surface to a masked group of particles.

    With ``x = coord - center`` and ``Q = R diag(qf_eigvals) R^T`` (``R`` from
    ``rotation_matrix(params.euler)``), the residual of each particle is
    ``(sqrt(x^T Q x) - |radius|) * mask``.

    Args:
        params: provides ``center``, ``euler`` and ``radius``.
        coord: particle coordinates, one particle per row.
        mask: mask over ``coord``; entries equal to zero remove a particle
            from the fit. Expected to broadcast against the per-particle
            residual vector.
        qf_eigvals: eigenvalues of the quadratic form, which select the
            surface type (e.g. ``[1, 1, 0]`` for a cylinder).

    Returns:
        The masked residuals. For an unmasked particle with ``x^T Q x < 0``
        (possible when an eigenvalue is negative) the residual is NaN.
    """
    offset = coord - params.center
    qf = quadratic_form(rotation_matrix(params.euler), qf_eigvals)
    form_value = jnp.sum(offset * (offset @ qf), axis=1)
    # sqrt yields NaN for negative input and an infinite derivative at zero.
    # Multiplying by mask == 0 afterwards cannot remove either (0 * NaN is
    # NaN), so masked entries are given a harmless value before the sqrt.
    safe_value = jnp.where(mask != 0, form_value, 1.0)
    return (jnp.sqrt(safe_value) - jnp.abs(params.radius)) * mask
# Cylinder: eigenvalues (1, 1, 0), so the form does not depend on the third rotated axis.
cylindrical_surface = partial(
    quadratic_surface,
    qf_eigvals=jnp.array([1.0, 1.0, 0.0]),
)
# Hyperbolic surface: eigenvalues (1, 1, -2); the third rotated axis enters with a negative sign.
hyperbolic_surface = partial(
    quadratic_surface,
    qf_eigvals=jnp.array([1.0, 1.0, -2.0]),
)
def surface_fit_gn(surface_fn: SurfaceFn,
                   coord: Array,
                   mask: Array,
                   p0: T,
                   verbose: bool = False,
                   maxiter: int = 20) -> T:
    """Fit a surface to a masked group of particles with the Gauss-Newton method.

    Args:
        surface_fn: residual function called as ``surface_fn(params, coord, mask)``.
        coord: particle coordinates, forwarded unchanged to ``surface_fn``.
        mask: mask over ``coord``, forwarded unchanged to ``surface_fn``.
        p0: initial parameters (a pytree); every leaf is cast to a
            floating-point array before the solve.
        verbose: forwarded to the jaxopt solver.
        maxiter: maximum number of solver iterations (defaults to 20, the
            value that used to be hard-coded).

    Returns:
        The ``params`` field of the solver result.
    """
    # Integer initial parameters trigger a "bytes-like object" TypeError in
    # the solver, so cast every leaf to floating point first. Requesting the
    # Python ``float`` type selects JAX's default float width (float64 only
    # when x64 is enabled) without the truncation warning that an explicit
    # ``jnp.float64`` request emits when x64 is disabled.
    p0 = jax.tree_util.tree_map(partial(jnp.asarray, dtype=float), p0)
    solver = GaussNewton(residual_fun=surface_fn, maxiter=maxiter, verbose=verbose)
    return solver.run(p0, coord, mask).params
def surface_fit_lm(surface_fn: SurfaceFn,
                   coord: Array,
                   mask: Array,
                   p0: QuadraticSurfaceParams,
                   verbose: bool = False,
                   maxiter: int = 20) -> QuadraticSurfaceParams:
    """Fit a surface to a masked group of particles with the Levenberg-Marquardt method.

    The solver works on a flat parameter vector, so ``p0`` is flattened with
    ``to_array`` and the solution is rebuilt with ``from_array``.

    Args:
        surface_fn: residual function called as ``surface_fn(params, coord, mask)``.
        coord: particle coordinates, forwarded unchanged to ``surface_fn``.
        mask: mask over ``coord``, forwarded unchanged to ``surface_fn``.
        p0: initial parameters; every field is cast to a floating-point array.
        verbose: forwarded to the jaxopt solver.
        maxiter: maximum number of solver iterations (defaults to 20, the
            value that used to be hard-coded).

    Returns:
        The fitted parameters.

    NOTE(review): the original author observed that this routine does not seem
    to work with gradient calculation over hyperparameters. ``coord`` and
    ``mask`` are now passed through ``run`` (as in ``surface_fit_gn``) instead
    of being captured in a closure; confirm whether that resolves the issue.
    """
    def flat_residuals(x, coord, mask):
        # Adapter: the solver hands over a flat vector, surface_fn wants the dataclass.
        return surface_fn(QuadraticSurfaceParams.from_array(x), coord, mask)

    # Cast to floating point so integer initial values are accepted. The
    # Python ``float`` type selects JAX's default float width without the
    # warning an explicit ``jnp.float64`` request emits when x64 is disabled.
    p0 = jax.tree_util.tree_map(partial(jnp.asarray, dtype=float), p0)
    solver = LevenbergMarquardt(residual_fun=flat_residuals, maxiter=maxiter, verbose=verbose)
    solution = solver.run(p0.to_array(), coord, mask)
    return QuadraticSurfaceParams.from_array(solution.params)
|