from functools import partial
from typing import List, Union, Callable
import jax.numpy as jnp
import jax
from curvature_assembly.spherical_harmonics import sph_harm_fn, real_sph_harm, sph_harm_not_fast, sph_harm_fn_custom, real_sph_harm_fn_custom_rev

# Alias for JAX arrays, used in the type annotations throughout this module.
Array = jnp.ndarray


def vec_in_eigensystem(eigsys: Array, vec: Array):
    """Get vector components in the eigensystem."""
    return jnp.dot(jnp.transpose(eigsys), vec)


def safe_arctan2(x, y):
    """
    Version of arctan2 that works for zero-valued inputs. Look at https://github.com/google/jax/issues/1052
    and https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf

    Returns ``jnp.arctan2(x, y)``, the angle whose tangent is ``x / y`` with the quadrant chosen from the signs of
    both arguments, everywhere except at ``x == y == 0``. At that point the plain arctan2 is undefined and its
    gradient is NaN, so 0 is returned instead. The denominator is swapped for a harmless value *before* calling
    arctan2, so no NaN enters the backward pass either (the "double where" trick from the references above).

    Args:
        x: First argument of ``jnp.arctan2`` (the numerator-like quantity).
        y: Second argument of ``jnp.arctan2`` (the denominator-like quantity).

    Returns:
        Angles in ``[-pi, pi]``, broadcast from ``x`` and ``y``.
    """
    both_zero = jnp.logical_and(x == 0, y == 0)
    # Only the degenerate point is masked; every other input keeps the true arctan2 value,
    # including y <= 0 (left half-plane) and y == 0 with x != 0 (+-pi/2).
    safe_y = jnp.where(both_zero, 1., y)
    return jnp.where(both_zero, 0., jnp.arctan2(x, safe_y))


def cart_to_sph(vec: Array) -> Array:
    """
    Transformation to spherical coordinates theta and phi.

    Args:
        vec: Cartesian vector; only its first three components ``(x, y, z)`` are read.

    Returns:
        A single length-2 array ``[theta, phi]`` (not a tuple), where
        ``theta = safe_arctan2(sqrt(x**2 + y**2), z)`` is the angle measured from the z axis and
        ``phi = safe_arctan2(y, x)`` is the azimuthal angle in the xy plane.
    """
    x, y, z = vec[0], vec[1], vec[2]
    theta = safe_arctan2(jnp.sqrt(x ** 2 + y ** 2), z)
    phi = safe_arctan2(y, x)
    return jnp.stack([theta, phi])


def patchy_interaction_general(lm_list: Union[tuple, List[tuple]]) -> Callable:
    """
    Orientational part for a general patchy particle interaction where patches are described by a linear combination
    of spherical harmonics. The form of the potential is inspired by the Kern-Frenkel patchy particle model.

    Args:
        lm_list: A single ``(l, m)`` pair, or a list (or tuple) of ``(l, m)`` pairs, selecting the spherical
            harmonics used in the patch expansion.

    Returns:
        A function ``fn(dr, eigsys1, eigsys2, lm_magnitudes)`` that evaluates the orientational term for one
        particle pair. ``dr`` is the separation ``r1 - r2``. ``eigsys1`` and ``eigsys2`` are the bases passed to
        ``vec_in_eigensystem`` to express the inter-particle direction in each particle's frame.
        ``lm_magnitudes`` holds one expansion coefficient per ``(l, m)`` pair, or a single scalar used for every
        pair.

    Raises:
        ValueError: If ``lm_list`` is empty or contains a pair with ``|m| > l``.
    """
    # A bare (l, m) pair is wrapped; a tuple whose items are themselves pairs is treated like a list of pairs.
    if isinstance(lm_list, tuple) and lm_list and not isinstance(lm_list[0], (tuple, list)):
        lm_list = [lm_list]
    lm_list = list(lm_list)

    if not lm_list:
        raise ValueError('lm_list must contain at least one (l, m) pair.')
    if any(abs(m) > l for l, m in lm_list):
        raise ValueError('Spherical harmonics are only defined for |m|<=l.')

    l_list, m_list = zip(*lm_list)
    n_terms = len(lm_list)

    def fn(dr: Array, eigsys1: Array, eigsys2: Array, lm_magnitudes: Array) -> Array:
        """Orientational pair term for the (l, m) expansion captured by the enclosing factory."""
        # asarray lets callers pass Python scalars/lists as well as arrays (a bare float has no .shape)
        lm_magnitudes = jnp.asarray(lm_magnitudes)
        if lm_magnitudes.ndim == 0:
            lm_magnitudes = jnp.full(n_terms, lm_magnitudes)

        if len(lm_magnitudes) != n_terms:
            raise ValueError(f'Length of lm_magnitudes array does not match the number of (l, m) expansion terms, '
                             f'got {len(lm_magnitudes)} and {n_terms}, respectively.')

        # dr points from 2nd to 1st particle (dr = r1 - r2)
        # we need relative direction from one particle to another, so in the case of the first, we need to take -dr
        normalized_dr = dr / jnp.sqrt(jnp.sum(dr ** 2))
        vec1 = vec_in_eigensystem(eigsys1, -normalized_dr)
        vec2 = vec_in_eigensystem(eigsys2, normalized_dr)

        patches_particle1 = real_sph_harm(vec1, l_list, m_list) @ lm_magnitudes
        patches_particle2 = real_sph_harm(vec2, l_list, m_list) @ lm_magnitudes

        # Two negative patch values give a positive result (attraction, matching the "positive = attractive"
        # convention of the band potentials in this module). Two positive values give a negative result
        # (repulsion). Opposite signs give 0.
        return -(jnp.sign(patches_particle1) + jnp.sign(patches_particle2)) * patches_particle1 * patches_particle2

    return fn


def generate_lm_list(l_max: int,
                     only_non_neg_m: bool = False,
                     only_even_l: bool = False,
                     only_odd_l: bool = False) -> list:
    """
    Return list of all possible (l, m) for a given maximal l.

    Pairs are ordered by increasing l, and by increasing m within each l.

    Args:
        l_max: Largest l included (inclusive).
        only_non_neg_m: If True, restrict m to 0..l instead of -l..l.
        only_even_l: If True, keep only even l (0, 2, 4, ...).
        only_odd_l: If True, keep only odd l (1, 3, 5, ...).

    Raises:
        ValueError: If both ``only_even_l`` and ``only_odd_l`` are True.
    """
    if only_even_l and only_odd_l:
        raise ValueError('Parameters only_even_l and only_odd_l cannot both be True at the same time.')

    if only_even_l:
        first_l, l_step = 0, 2
    elif only_odd_l:
        first_l, l_step = 1, 2
    else:
        first_l, l_step = 0, 1

    return [(l, m)
            for l in range(first_l, l_max + 1, l_step)
            for m in range(0 if only_non_neg_m else -l, l + 1)]


def init_lm_coefs(lm_list: list[tuple], nonzero_list: list[tuple], init_values: list = None) -> jnp.ndarray:
    """
    Initialize lm coefficients for a given lm_list with desired values, normalized by their Euclidean norm.

    Each entry of ``lm_list`` that also appears in ``nonzero_list`` gets the value at the same position in
    ``init_values``. If ``init_values`` is not provided, every such entry gets 1. All other entries get 0. The
    resulting vector is then divided by its Euclidean norm.

    Args:
        lm_list: The ``(l, m)`` pairs defining the order of the output coefficients.
        nonzero_list: The ``(l, m)`` pairs that receive a value from ``init_values``.
        init_values: Values aligned position-by-position with ``nonzero_list``; defaults to all ones.

    Returns:
        Array with one coefficient per entry of ``lm_list``, divided by its Euclidean norm.

    Raises:
        ValueError: If ``lm_list`` is non-empty but none of its entries appears in ``nonzero_list``. The
            coefficient vector would be all zeros and normalizing it would silently produce NaNs.
    """
    if init_values is None:
        init_values = [1 for _ in nonzero_list]

    if lm_list and not any(lm in nonzero_list for lm in lm_list):
        raise ValueError('None of the (l, m) pairs in lm_list appear in nonzero_list, so the coefficients '
                         'cannot be normalized.')

    coef_list = []
    for lm in lm_list:
        try:
            idx = nonzero_list.index(lm)
        except ValueError:
            coef_list.append(0.)
        else:
            coef_list.append(init_values[idx])

    coefs = jnp.array(coef_list)
    return coefs / jnp.linalg.norm(coefs)


def patchy_interaction_band(dr: Array, eigsys1: Array, eigsys2: Array, theta: Array, sigma: Array):
    """
    Hard-edged band interaction between two particles.

    The direction to the partner particle is expressed in each particle's basis (``-dr`` for the first, ``dr`` for
    the second). Its third component ``z`` is then tested against the interval ``(cos(theta + sigma),
    cos(theta - sigma))`` with ``jnp.heaviside`` steps. Each particle contributes 1 strictly inside the interval,
    0.5 on an edge and 0 outside, and the two contributions are multiplied.
    """
    unit_dr = dr / jnp.sqrt(jnp.sum(dr ** 2))
    z1 = vec_in_eigensystem(eigsys1, -unit_dr)[2]
    z2 = vec_in_eigensystem(eigsys2, unit_dr)[2]

    z_low = jnp.cos(theta + sigma)
    z_high = jnp.cos(theta - sigma)

    def in_band(z):
        return jnp.heaviside(z_high - z, 0.5) * jnp.heaviside(z - z_low, 0.5)

    # return value should be positive for attractive patches
    # as this potential is usually combined with attractive isotropic term
    return in_band(z1) * in_band(z2)


@jax.custom_jvp
def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))`` with a hand-written JVP rule."""
    return 1 / (1 + jnp.exp(-x))


@sigmoid.defjvp
def sigmoid_jvp(x, x_dot):
    """
    JVP rule for ``sigmoid``: d sigmoid = sigmoid * (1 - sigmoid) * dx.

    JAX calls a ``defjvp`` rule with two tuples, the primal arguments and their tangents, each holding one entry
    per argument of ``sigmoid``. The previous body used them as scalars, which made any differentiation of
    ``sigmoid`` fail. The parameter names are kept for compatibility.
    """
    (x_primal,) = x
    (x_tangent,) = x_dot
    primal_out = sigmoid(x_primal)
    tangent_out = primal_out * (1 - primal_out) * x_tangent
    return primal_out, tangent_out


def gaussian_belt(x, theta, sigma) -> jnp.ndarray:
    """Normalized Gaussian profile in ``x`` with mean ``theta`` and standard deviation ``sigma``."""
    z = (x - theta) / sigma
    prefactor = 1 / (sigma * jnp.sqrt(2 * jnp.pi))
    return prefactor * jnp.exp(-0.5 * z ** 2)


def gaussian_belt_fixed_height(x, theta, sigma) -> jnp.ndarray:
    """Unnormalized Gaussian profile in ``x`` centred at ``theta`` with width ``sigma``; its peak value is 1."""
    z = (x - theta) / sigma
    return jnp.exp(-0.5 * z ** 2)


def gaussian_interaction_band(dr: Array, eigsys1: Array, eigsys2: Array, theta: Array, sigma: Array):
    """
    Smooth band interaction with a normalized Gaussian angular profile.

    The direction to the partner particle is expressed in each particle's basis (``-dr`` for the first, ``dr`` for
    the second). Each particle's value is ``gaussian_belt(arccos(z), theta, sigma)``, where ``z`` is the third
    component. For an orthonormal basis, ``arccos(z)`` is the angle from that basis' third axis (assumption:
    ``eigsys`` is orthonormal; confirm at the call sites). The two values are multiplied.
    """
    normalized_dr = dr / jnp.sqrt(jnp.sum(dr ** 2))
    vec1 = vec_in_eigensystem(eigsys1, -normalized_dr)
    vec2 = vec_in_eigensystem(eigsys2, normalized_dr)

    belt = partial(gaussian_belt, theta=theta, sigma=sigma)

    # Rounding in the normalization / basis change can push |z| slightly above 1, where arccos returns NaN.
    # Clip back onto its domain. Note the derivative of arccos is still unbounded at |z| == 1.
    angle1 = jnp.arccos(jnp.clip(vec1[2], -1., 1.))
    angle2 = jnp.arccos(jnp.clip(vec2[2], -1., 1.))

    # return value should be positive for attractive patches
    # as this potential is usually combined with attractive isotropic term
    return belt(angle1) * belt(angle2)


def gaussian_interaction_band_fixed_height(dr: Array, eigsys1: Array, eigsys2: Array, theta: Array, sigma: Array):
    """
    Smooth band interaction with a unit-height (unnormalized) Gaussian angular profile.

    Same construction as the normalized Gaussian band, but each particle's value is
    ``gaussian_belt_fixed_height(arccos(z), theta, sigma)``. Here ``z`` is the third component of the direction to
    the partner particle expressed in that particle's basis (``-dr`` for the first, ``dr`` for the second). The two
    values are multiplied.
    """
    normalized_dr = dr / jnp.sqrt(jnp.sum(dr ** 2))
    vec1 = vec_in_eigensystem(eigsys1, -normalized_dr)
    vec2 = vec_in_eigensystem(eigsys2, normalized_dr)

    belt = partial(gaussian_belt_fixed_height, theta=theta, sigma=sigma)

    # Rounding in the normalization / basis change can push |z| slightly above 1, where arccos returns NaN.
    # Clip back onto its domain. Note the derivative of arccos is still unbounded at |z| == 1.
    angle1 = jnp.arccos(jnp.clip(vec1[2], -1., 1.))
    angle2 = jnp.arccos(jnp.clip(vec2[2], -1., 1.))

    # return value should be positive for attractive patches
    # as this potential is usually combined with attractive isotropic term
    return belt(angle1) * belt(angle2)