from __future__ import annotations
from jaxopt import GaussNewton, LevenbergMarquardt
import jax.numpy as jnp
from typing import Callable, TypeVar
import jax
from functools import partial
from jax_md import dataclasses


# Generic parameter-pytree type: surface_fit_gn returns the same type it is given.
T = TypeVar('T')
# Shorthand used in the annotations throughout this module.
Array = jnp.ndarray


@dataclasses.dataclass
class QuadraticSurfaceParams:
    """Fit parameters shared by the sphere and quadric residual functions.

    ``to_array`` / ``from_array`` use the flat layout
    ``[center (3 entries), euler (3 entries), radius (1 entry)]``, i.e. a
    length-7 vector when ``center`` and ``euler`` are length-3 vectors.
    """

    # Origin of the surface; the residual functions subtract it from every coordinate.
    center: Array = dataclasses.field(default_factory=partial(jnp.zeros, (3,)))
    # Angles (alpha, beta, gamma) that quadratic_surface passes to rotation_matrix.
    euler: Array = dataclasses.field(default_factory=partial(jnp.zeros, (3,)))
    # The residual functions use abs(radius), so only its magnitude matters.
    radius: float = 1.

    def to_array(self) -> Array:
        """Pack the parameters into one flat vector ordered center, euler, radius."""
        pieces = (self.center, self.euler, jnp.array([self.radius]))
        return jnp.hstack(pieces)

    @staticmethod
    def from_array(x: Array) -> QuadraticSurfaceParams:
        """Inverse of ``to_array``: split a flat vector into center, euler and radius."""
        center, euler, radius = x[:3], x[3:6], x[6]
        return QuadraticSurfaceParams(center=center, euler=euler, radius=radius)


def rotation_matrix(euler_angles: Array) -> Array:
    """Compose a 3x3 rotation matrix from Z-Y-Z Euler angles.

    ``euler_angles`` is unpacked as ``(alpha, beta, gamma)``. The result is
    ``Rz(gamma) @ Ry(beta) @ Rz(alpha)``, a product of rotation matrices
    (orthogonal, determinant 1).

    NOTE: the middle factor is ``[[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]]``.
    That is the transpose of the conventional right-handed rotation about y, so it
    rotates by ``-beta``. The z factors use the conventional sign. This is kept
    as-is because fitted ``euler`` values depend on it.
    """
    alpha, beta, gamma = euler_angles

    def _about_z(theta):
        # Conventional right-handed rotation by theta about the z axis.
        cos_t, sin_t = jnp.cos(theta), jnp.sin(theta)
        return jnp.array([[cos_t, -sin_t, 0],
                          [sin_t, cos_t, 0],
                          [0, 0, 1]])

    cos_b, sin_b = jnp.cos(beta), jnp.sin(beta)
    about_y = jnp.array([[cos_b, 0, -sin_b],
                         [0, 1, 0],
                         [sin_b, 0, cos_b]])

    return _about_z(gamma) @ about_y @ _about_z(alpha)


@partial(jnp.vectorize, signature='(d,d),(d)->(d,d)')
def quadratic_form(rot_mat: Array, eigvals: Array) -> Array:
    """Build the symmetric matrix ``rot_mat @ diag(eigvals) @ rot_mat.T``.

    The matrix has eigenvalues ``eigvals`` along the columns of ``rot_mat``.
    ``jnp.vectorize`` broadcasts over any leading batch dimensions.

    Args:
        rot_mat: ``(d, d)`` matrix (a rotation in this module's usage).
        eigvals: ``(d,)`` eigenvalues.

    Returns:
        ``(d, d)`` matrix, with any leading batch dimensions broadcast.
    """
    # jnp.diag honors the (d,d),(d) signature for any d. The previous explicit
    # 3x3 construction unpacked exactly three eigenvalues and failed for other d.
    return rot_mat @ jnp.diag(eigvals) @ jnp.transpose(rot_mat)


# Signature of a residual function accepted by the fitters, after any extra
# keyword arguments have been bound (e.g. via functools.partial).
SurfaceFn = Callable[
    [QuadraticSurfaceParams, Array, Array],  # params, coord, mask
    Array,  # per-particle residuals
]


def spherical_surface(params: QuadraticSurfaceParams, coord: Array, mask: Array) -> Array:
    """
    Masked per-particle residuals for fitting a sphere to the particles in ``coord``.

    Each residual is ``|coord_i - params.center| - |params.radius|`` (Euclidean norm
    over axis 1), multiplied by ``mask``. The residual is zero for a particle lying
    on the sphere and zeroed for a masked-out particle.

    NOTE(review): the gradient of jnp.linalg.norm is NaN for a zero-length vector,
    and multiplying by ``mask`` does not clear a NaN. A particle sitting exactly at
    ``params.center`` (even a masked-out one) would therefore put NaN in the Jacobian.
    """
    offsets = coord - params.center
    distances = jnp.linalg.norm(offsets, axis=1)
    residual = distances - jnp.abs(params.radius)
    return residual * mask


def quadratic_surface(params: QuadraticSurfaceParams, coord: Array, mask: Array, qf_eigvals: Array) -> Array:
    """
    Residual function for fitting a quadric surface to a group of particles defined by coord array and a
    mask over coord.

    With ``rel = coord - params.center``, ``R = rotation_matrix(params.euler)`` and
    ``Q = R @ diag(qf_eigvals) @ R.T``, each particle's quadratic form is ``q = rel^T Q rel``.
    The surface is the level set ``q = radius**2``, e.g. ``qf_eigvals=[1, 1, 0]`` is a circular
    cylinder.

    Args:
        params: uses ``center``, ``euler`` and ``radius`` (only ``abs(radius)`` matters).
        coord: particle coordinates, one per row.
        mask: per-particle mask; residuals are multiplied by it.
        qf_eigvals: eigenvalues of the quadratic form in the rotated frame.

    Returns:
        Per-particle residuals ``ssqrt(q) - |radius|`` times ``mask``. ``ssqrt`` is a signed
        square root: equal to ``sqrt(q)`` for ``q >= 0`` and ``-sqrt(-q)`` for ``q < 0``.
    """
    relative_coord = coord - params.center
    qf = quadratic_form(rotation_matrix(params.euler), qf_eigvals)
    q = jnp.sum(relative_coord * (relative_coord @ qf), axis=1)

    # q < 0 is reachable whenever qf_eigvals has mixed signs (e.g. hyperbolic_surface).
    # A plain sqrt returns NaN there, and NaN * mask stays NaN, so even masked-out particles
    # used to poison the solver. The signed square root is a continuous, monotone extension
    # that is identical to sqrt(q) for q >= 0.
    abs_q = jnp.abs(q)
    nonzero = abs_q > 0
    # Double-`where`: sqrt's derivative is infinite at 0. Feeding it a safe dummy value at
    # q == 0 keeps the Jacobian finite (the unselected branch's gradient still flows through
    # `where`).
    root = jnp.sqrt(jnp.where(nonzero, abs_q, 1.))
    signed_root = jnp.where(nonzero, jnp.sign(q) * root, 0.)
    return (signed_root - jnp.abs(params.radius)) * mask


cylindrical_surface = partial(
    quadratic_surface,
    # In the rotated frame: x'^2 + y'^2 = r^2, a circular cylinder whose axis is
    # the third column of the rotation matrix (that direction is unconstrained).
    qf_eigvals=jnp.array([1., 1., 0.]),
)
hyperbolic_surface = partial(
    quadratic_surface,
    # In the rotated frame: x'^2 + y'^2 - 2 z'^2 = r^2. This is an indefinite form
    # whose level set is a hyperboloid of one sheet for r != 0.
    qf_eigvals=jnp.array([1., 1., -2.]),
)


def surface_fit_gn(surface_fn: SurfaceFn,
                   coord: Array,
                   mask: Array,
                   p0: T,
                   verbose: bool = False,
                   maxiter: int = 20) -> T:
    """
    Fit a surface to a group of particles defined by coord array and a mask over coord using the Gauss-Newton method.

    Args:
        surface_fn: residual function, called by the solver as ``surface_fn(params, coord, mask)``.
        coord: particle coordinates, passed to ``surface_fn`` via ``GaussNewton.run``.
        mask: mask over ``coord``, passed to ``surface_fn`` via ``GaussNewton.run``.
        p0: initial parameters (a pytree); every leaf is cast to the default floating dtype.
        verbose: forwarded to ``jaxopt.GaussNewton``.
        maxiter: maximum number of Gauss-Newton iterations (20 was the previous hard-coded value).

    Returns:
        ``opt.params`` from the Gauss-Newton run.
    """
    gn = GaussNewton(residual_fun=surface_fn, maxiter=maxiter, verbose=verbose)
    # We want to avoid a "bytes-like object" TypeError if initial params are given as integers,
    # so cast every leaf to floating point. `float` resolves to float64 when jax x64 is enabled
    # and float32 otherwise. That is the same dtype jnp.float64 produced, but without the
    # "requested dtype not available" warning when x64 is disabled.
    p0 = jax.tree_util.tree_map(partial(jnp.asarray, dtype=float), p0)
    opt = gn.run(p0, coord, mask)
    return opt.params


def surface_fit_lm(surface_fn: SurfaceFn,
                   coord: Array,
                   mask: Array,
                   p0: QuadraticSurfaceParams,
                   verbose: bool = False,
                   maxiter: int = 20) -> QuadraticSurfaceParams:
    """
    Fit a surface to a group of particles defined by coord array and a mask over coord
    using the Levenberg-Marquardt method.

    The solver works on the flat vector produced by ``QuadraticSurfaceParams.to_array``.
    The residual wrapper rebuilds a ``QuadraticSurfaceParams`` before calling ``surface_fn``.

    ``coord`` and ``mask`` are passed to ``lm.run`` as explicit arguments rather than being
    captured in a closure, matching ``surface_fit_gn``. jaxopt differentiates the solution
    only with respect to explicit arguments of ``run``, and JAX custom-VJP rules reject
    differentiation with respect to closed-over values. The earlier closure-based version is
    presumably why gradient calculation over hyperparameters did not work.

    Args:
        surface_fn: residual function called as ``surface_fn(params, coord, mask)``.
        coord: particle coordinates.
        mask: mask over ``coord``.
        p0: initial parameters; every leaf is cast to the default floating dtype.
        verbose: forwarded to ``jaxopt.LevenbergMarquardt``.
        maxiter: maximum number of iterations (20 was the previous hard-coded value).

    Returns:
        The optimized parameters as a ``QuadraticSurfaceParams``.
    """

    def unraveled_fn(x, coord_arg, mask_arg):
        # Rebuild structured params from the flat solver vector.
        params = QuadraticSurfaceParams.from_array(x)
        return surface_fn(params, coord_arg, mask_arg)

    # `float` gives float64 under jax x64 and float32 otherwise, without the warning that
    # requesting jnp.float64 emits when x64 is disabled.
    p0 = jax.tree_util.tree_map(partial(jnp.asarray, dtype=float), p0)
    p0_array = p0.to_array()

    lm = LevenbergMarquardt(residual_fun=unraveled_fn, maxiter=maxiter, verbose=verbose)
    opt = lm.run(p0_array, coord, mask)
    return QuadraticSurfaceParams.from_array(opt.params)