import functools
from typing import Callable, TypeVar, Protocol

import numpy as np

from charged_shells.expansion import Expansion
from charged_shells.parameters import ModelParams

# Generic type variables used in the function-signature annotations below.
T, V = TypeVar('T'), TypeVar('V')
# Short alias for numpy arrays, used in isinstance checks on model parameters.
Array = np.ndarray


def map_over_expansion(f: Callable[..., V]) -> Callable[..., np.ndarray]:
    """Map a function ``f`` over all leading axes of an expansion.

    The returned wrapper flattens the expansion, calls
    ``f(single_expansion, *args, **kwargs)`` for every element in a plain
    Python loop (so it is slow for large expansions) and stacks the results.

    Args:
        f: Function taking a single ``Expansion`` as its first argument,
            followed by any extra positional/keyword arguments.

    Returns:
        A wrapper ``mapped_f(ex, *args, **kwargs)`` returning an array of
        shape ``ex.shape + result_shape``, where ``result_shape`` is the shape
        of each individual result of ``f`` (``()`` for scalar results).
    """

    @functools.wraps(f)
    def mapped_f(ex: Expansion, *args, **kwargs):
        og_shape = tuple(ex.shape)
        flat_ex = ex.flatten()
        n_items = int(np.prod(og_shape))  # np.prod(()) == 1 for a single expansion
        results = np.array([f(flat_ex[i], *args, **kwargs) for i in range(n_items)])
        # `results` has shape (n_items,) + result_shape. Taking the trailing shape
        # from the stacked array (rather than from results[0]) also handles scalar
        # results, list-valued results and empty expansions (n_items == 0).
        return results.reshape(og_shape + results.shape[1:])

    return mapped_f


def unravel_params(params: ModelParams) -> list[ModelParams]:
    """Split ``params`` into a list of parameter sets, one per array element.

    At most one of ``params.R`` / ``params.kappa`` may be a (non-0-d) numpy
    array. That parameter is iterated over along its first axis, while the
    other is kept fixed. If neither is an array, ``[params]`` is returned.

    NOTE(review): the new ``ModelParams`` objects are built from ``R`` and
    ``kappa`` only. If ``ModelParams`` has further fields, they are not copied
    over. Confirm against the ``ModelParams`` definition.

    Raises:
        NotImplementedError: If both ``R`` and ``kappa`` are arrays.
    """
    # 0-d arrays cannot be iterated over, so treat them like scalars.
    r_is_array = isinstance(params.R, Array) and params.R.ndim > 0
    kappa_is_array = isinstance(params.kappa, Array) and params.kappa.ndim > 0

    if r_is_array and kappa_is_array:
        # if this is to be implemented, watch out for implementations of mapping expansions that depend
        # on one of the parameters in ModelParams over other functions that also take the same ModelParameters
        raise NotImplementedError("Currently only unravel over a single parameter is supported. ")
    if r_is_array:
        return [ModelParams(R=r, kappa=params.kappa) for r in params.R]
    if kappa_is_array:
        return [ModelParams(R=params.R, kappa=kappa) for kappa in params.kappa]
    return [params]


def unravel_expansion_over_axis(ex: Expansion, axis: int | None, param_list_len: int) -> list[Expansion]:
    """Split an expansion into ``param_list_len`` pieces, one per parameter set.

    Args:
        ex: Expansion to split.
        axis: Leading axis of ``ex`` to split along. ``None`` means the whole
            expansion is reused for every parameter set.
        param_list_len: Number of parameter sets the pieces are paired with.

    Returns:
        A list of ``param_list_len`` expansions. If ``axis`` is ``None``, every
        entry is the same ``ex`` object.

    Raises:
        ValueError: If ``ex.shape[axis]`` differs from ``param_list_len``.
        IndexError: If ``axis`` is out of range for ``ex.shape``.
    """
    if axis is None:
        return [ex] * param_list_len
    axis_len = ex.shape[axis]
    if axis_len != param_list_len:
        raise ValueError(f'Parameter list has different length than the provided expansion axis, '
                         f'got param_list_len={param_list_len} and axis_len={axis_len}.')
    # `axis` was validated against ex.shape, but it is applied to ex.coefs, which
    # may carry extra trailing axes (presumably the coefficient axis; confirm).
    # A negative axis must therefore be converted to its position within ex.shape
    # first, otherwise e.g. axis=-1 would index the wrong axis of ex.coefs.
    if axis < 0:
        axis += len(ex.shape)
    return [Expansion(ex.l_array, np.take(ex.coefs, i, axis=axis)) for i in range(axis_len)]


# Signature of a function acting on a single expansion: f(ex, params) -> result.
SingleExpansionFn = Callable[[Expansion, ModelParams], T]


class TwoExpansionsFn(Protocol):
    """Callable protocol for functions acting on a pair of expansions.

    This is a ``Protocol`` instead of a ``Callable[...]`` alias so that the
    signature can include arbitrary keyword arguments (``**kwargs``), which
    ``Callable`` cannot express.
    """

    def __call__(self, ex1: Expansion, ex2: Expansion, params: ModelParams, **kwargs) -> T:
        ...


def parameter_map_single_expansion(f: SingleExpansionFn,
                                   match_expansion_axis_to_params: int | None = None) -> SingleExpansionFn:
    """Vectorize ``f(ex, params)`` over array-valued model parameters.

    ``params`` is split with ``unravel_params``. ``f`` is called once per
    resulting parameter set, and the results are stacked.

    Args:
        f: Function ``f(ex, params, **kwargs)`` acting on a single expansion.
        match_expansion_axis_to_params: If ``None``, the whole expansion is
            passed to every call of ``f``. Otherwise, the expansion is split
            along this leading axis, whose length must equal the number of
            parameter sets, and the i-th slice is paired with the i-th
            parameter set.

    Returns:
        Wrapper ``mapped_f(ex, params, **kwargs)``. Any keyword arguments are
        forwarded to ``f``.
        - If an axis is matched, the stacked results have the parameter axis
          moved back to that axis position.
        - Otherwise, results are stacked along a new leading axis and passed
          through ``np.squeeze``. This removes *all* length-1 axes, including
          any in the output of ``f``.
    """

    @functools.wraps(f)
    def mapped_f(ex: Expansion, params: ModelParams, **kwargs):
        params_list = unravel_params(params)
        expansion_list = unravel_expansion_over_axis(ex, match_expansion_axis_to_params, len(params_list))

        results = np.array([f(exp, prms, **kwargs) for exp, prms in zip(expansion_list, params_list)])

        axis = match_expansion_axis_to_params
        if axis is None:
            return np.squeeze(results)
        if axis < 0:
            # Interpret a negative axis relative to the expansion's leading axes,
            # not relative to the (possibly higher-dimensional) result array.
            axis += len(ex.shape)
        # return the params-matched axis to where it belongs
        return np.moveaxis(results, 0, axis)

    return mapped_f


def parameter_map_two_expansions(f: TwoExpansionsFn,
                                 match_expansion_axis_to_params: int | None = None,
                                 ) -> TwoExpansionsFn:
    """Vectorize ``f(ex1, ex2, params, **kwargs)`` over array-valued model parameters.

    ``params`` is split with ``unravel_params``. ``f`` is called once per
    resulting parameter set, and the results are stacked.

    Args:
        f: Function ``f(ex1, ex2, params, **kwargs)`` acting on two expansions.
        match_expansion_axis_to_params: If ``None``, both whole expansions are
            passed to every call of ``f``. Otherwise, both expansions are split
            along this leading axis, whose length must equal the number of
            parameter sets, and the i-th slices are paired with the i-th
            parameter set.

    Returns:
        Wrapper ``mapped_f(ex1, ex2, params, **kwargs)``. Keyword arguments are
        forwarded to ``f``.
        - If an axis is matched, the stacked results have the parameter axis
          moved back to that axis position.
        - Otherwise, results are stacked along a new leading axis and passed
          through ``np.squeeze``. This removes *all* length-1 axes, including
          any in the output of ``f``.
    """

    @functools.wraps(f)
    def mapped_f(ex1: Expansion, ex2: Expansion, params: ModelParams, **kwargs):
        params_list = unravel_params(params)
        n_params = len(params_list)
        expansion_list1 = unravel_expansion_over_axis(ex1, match_expansion_axis_to_params, n_params)
        expansion_list2 = unravel_expansion_over_axis(ex2, match_expansion_axis_to_params, n_params)

        results = np.array([f(exp1, exp2, prms, **kwargs)
                            for exp1, exp2, prms in zip(expansion_list1, expansion_list2, params_list)])

        axis = match_expansion_axis_to_params
        if axis is None:
            return np.squeeze(results)
        if axis < 0:
            # Interpret a negative axis relative to the leading axes of ex1,
            # not relative to the (possibly higher-dimensional) result array.
            # NOTE(review): assumes ex1 and ex2 have the same number of leading axes.
            axis += len(ex1.shape)
        # return the params-matched axis to where it belongs
        return np.moveaxis(results, 0, axis)

    return mapped_f