from __future__ import annotations
from typing import TypeVar, Callable, Any
import jax.numpy as jnp
import jax


T = TypeVar("T")  # type variable tying pytree arguments and return values together in signatures
Array = jnp.ndarray  # alias used in annotations for JAX arrays


def all_equal(iterator) -> bool:
    """Return True when every element of `iterator` compares equal to the first one.

    An empty iterable counts as all-equal. Elements are compared as ``first == element`` and iteration stops at the
    first mismatch, so the input is only consumed up to that point.
    """
    elements = iter(iterator)
    missing = object()
    first = next(elements, missing)
    if first is missing:
        return True
    for element in elements:
        if not (first == element):
            return False
    return True


def repeat_fields(pytree: T, num: int) -> T:
    """Tile every leaf of `pytree` `num` times along a new leading axis.

    Leaves that support ``leaf[None, ...]`` get a new axis 0 of size `num` holding copies of the leaf. Leaves where
    that indexing raises TypeError (e.g. Python scalars) become a 1-D array of length `num` filled with the leaf's
    value.
    """
    def _tile(leaf):
        try:
            expanded = leaf[None, ...]
            return jnp.repeat(expanded, num, axis=0)
        except TypeError:
            return jnp.full((num,), leaf)

    return jax.tree_util.tree_map(_tile, pytree)


def merge(pytree_list: list[T]) -> T:
    """Concatenate corresponding leaves of several pytrees along their existing leading axis.

    Unlike `stack`, no new axis is created: leaves of shape ``(n_i, ...)`` combine into one leaf of shape
    ``(sum(n_i), ...)``.
    """
    def _concat_leaves(*leaves):
        return jnp.concatenate(leaves, axis=0)

    return jax.tree_util.tree_map(_concat_leaves, *pytree_list)

def stack(pytree_list: list[T]) -> T:
    """Stack corresponding leaves of several pytrees along a new leading axis of size ``len(pytree_list)``."""
    def _stack_leaves(*leaves):
        return jnp.stack(leaves, axis=0)

    return jax.tree_util.tree_map(_stack_leaves, *pytree_list)


def axis_length(pytree: T, axis: int = 0) -> T:
    """
    Calculate the size of `axis` for every leaf of a pytree.

    Returns a pytree with the same structure whose leaves are ints:
      * ``leaf.shape[axis]`` for leaves that have a ``shape`` attribute with that axis;
      * 0 for leaves without a ``shape`` attribute (e.g. Python numbers). Note this is indistinguishable from a
        genuinely zero-length array;
      * -1 for leaves whose shape does not have `axis` (e.g. 0-d arrays, or `axis` beyond the leaf's rank).
    """
    def length(x):
        try:
            return x.shape[axis]
        except AttributeError:
            # Leaf has no `shape`, i.e. it is not array-like.
            return 0
        except IndexError:
            # Leaf has too few dimensions for `axis`.
            return -1
    return jax.tree_util.tree_map(length, pytree)


class NoLengthError(Exception):
    """Raised when no single length can be assigned to a pytree along the requested axis."""


def data_length(pytree: T, axis: int = 0, ignore_non_array_leaves: bool = False) -> int:
    """Assign a length to a pytree from the shapes of the arrays stored within it.

    The size of `axis` is read from every leaf via `axis_length`. A length is only assigned when all considered
    leaves agree on it.

    Args:
        pytree: Pytree whose leaves are (mostly) arrays.
        axis: Axis whose size is read from each leaf.
        ignore_non_array_leaves: If True, leave out leaves without a ``shape`` attribute (e.g. Python numbers) and
            array leaves that do not have `axis` (e.g. 0-d scalars), so auxiliary scalar data does not prevent a
            length from being assigned. Zero-length arrays are still taken into account.

    Returns:
        The common, non-negative size of `axis` across the considered leaves.

    Raises:
        NoLengthError: If no leaves are considered, the leaves disagree on the size, or none of them has `axis`.
    """
    leaves, structure = jax.tree_util.tree_flatten(pytree)
    # `axis_length` maps each leaf to an int, so its flattened leaves line up one-to-one with `leaves`.
    lengths = jax.tree_util.tree_leaves(axis_length(pytree, axis=axis))
    if ignore_non_array_leaves:
        # `axis_length` reports 0 both for non-array leaves and for genuinely empty arrays, so decide what to drop by
        # looking at the leaf itself. Filtering on `length > 0` would also discard empty arrays and could hide a
        # length mismatch.
        lengths = [length for leaf, length in zip(leaves, lengths) if hasattr(leaf, 'shape') and length >= 0]
    # -1 is `axis_length`'s "axis does not exist" sentinel and must never be reported as a length.
    if lengths and all_equal(lengths) and lengths[0] >= 0:
        return lengths[0]
    raise NoLengthError(f'Pytree of type {type(pytree)} with structure {structure} cannot have a length assigned to it '
                        f'over axis {axis}.')


def check_if_pytree_of_arrays(pytree: T, allow_numbers: bool = True) -> bool:
    """Return True if every leaf of `pytree` is a `jnp.ndarray`.

    When `allow_numbers` is True, non-array leaves are also accepted as long as they have no ``__len__`` attribute
    (e.g. Python ints/floats used as auxiliary data, but also any other unsized object). Sized leaves such as strings
    are always rejected. A pytree with no leaves passes.
    """
    def _leaf_ok(leaf):
        if isinstance(leaf, jnp.ndarray):
            return True
        return allow_numbers and not hasattr(leaf, "__len__")

    return all(_leaf_ok(leaf) for leaf in jax.tree_util.tree_leaves(pytree))


def get_slice(pytree: T, start_idx: int, slice_length: int) -> T:
    """
    Slice every leaf of a PyTree along its leading axis, keeping the PyTree structure.

    Args:
        pytree: A PyTree of ndarrays.
        start_idx: Index of the first entry to keep along the leading axis.
        slice_length: Maximum number of entries to keep. Standard slicing rules apply, so fewer are returned near the
            end of the data.

    Returns:
        A PyTree with the same structure whose leaves are ``leaf[start_idx:start_idx + slice_length]``.
    """
    window = slice(start_idx, start_idx + slice_length)
    return jax.tree_util.tree_map(lambda leaf: leaf[window], pytree)


def split_to_list(pytree: T) -> list[T]:
    """Unstack a pytree of arrays into a list of pytrees, one per index along the leading axis.

    Element ``i`` of the result has the same structure as `pytree`, with every leaf replaced by ``leaf[i]``.

    Raises:
        ValueError: If `pytree` fails `check_if_pytree_of_arrays` (number leaves allowed).
        NoLengthError: Propagated from `data_length` when the leaves do not share a leading-axis length.
    """
    # NOTE(review): number leaves pass the check above, but they cannot be indexed with ``leaf[i]`` and they make
    # `data_length` see a length of 0 -- confirm whether such leaves are really meant to be supported here.
    if check_if_pytree_of_arrays(pytree, allow_numbers=True):
        count = data_length(pytree)

        def _select(index):
            return jax.tree_util.tree_map(lambda leaf: leaf[index], pytree)

        return [_select(index) for index in range(count)]
    raise ValueError('Should get a pytree of arrays.')

def map_over_leading_leaf_dimension(f: Callable[[T, Any], T], *pytrees: T, **kwargs):
    """
    Apply `f` separately to each leading-axis slice of the given pytrees and stack the outputs.

    The positional pytrees are split together as one tuple pytree by `split_to_list`, so each of them must be a
    splittable pytree of arrays with the same leading-axis length. For slice ``i``, `f` is called as
    ``f(*slices_i, **kwargs)``, with the keyword arguments passed unchanged to every call. The per-slice outputs are
    combined with `stack`, which adds a new leading axis.
    """
    return stack([f(*slices, **kwargs) for slices in split_to_list(pytrees)])


def num_dimensions(pytree: T) -> T:
    """Map every leaf to its rank: ``len(leaf.shape)`` for leaves with a ``shape`` attribute, 0 otherwise."""
    def _rank(leaf):
        try:
            shape = leaf.shape
        except AttributeError:
            return 0
        return len(shape)

    return jax.tree_util.tree_map(_rank, pytree)


def num_extra_dimensions(pytree: T, og_pytree: T) -> int:
    """Determine the number of extra leading dimensions compared to some original pytree of the same kind.

    Leaves are paired in flattening order and their ranks (from `num_dimensions`) are subtracted. The difference
    must be the same for every pair.

    Args:
        pytree: Pytree whose leaves may carry extra dimensions.
        og_pytree: Reference pytree with the same number of leaves.

    Returns:
        The common rank difference ``rank(pytree leaf) - rank(og_pytree leaf)``.

    Raises:
        ValueError: If the pytrees have different numbers of leaves, have no leaves at all, or the rank differences
            are not all equal.
    """
    pytree_dims, _ = jax.tree_util.tree_flatten(num_dimensions(pytree))
    og_pytree_dims, _ = jax.tree_util.tree_flatten(num_dimensions(og_pytree))
    # `zip` would silently truncate mismatched pytrees, so check the leaf counts explicitly.
    if len(pytree_dims) != len(og_pytree_dims):
        raise ValueError(f'Pytrees have different numbers of leaves ({len(pytree_dims)} vs {len(og_pytree_dims)}), '
                         f'so their dimensions cannot be compared.')
    if not pytree_dims:
        raise ValueError('Cannot determine extra leading dimensions of pytrees without leaves.')
    dim_differences = [d1 - d2 for d1, d2 in zip(pytree_dims, og_pytree_dims)]
    if all_equal(dim_differences):
        return dim_differences[0]
    raise ValueError('No consistent extra leading dimensions found.')


def leaf_norm(pytree: T, num_ld: int = 0, keepdims: bool = True) -> T:
    """Compute the Euclidean (Frobenius) norm of every leaf array in a pytree.

    Args:
        pytree: Pytree whose leaves are arrays.
        num_ld: Number of leading dimensions shared by the leaves that act as batch dimensions. The norm is taken
            separately for each index into them, via `num_ld` nested ``jax.vmap`` calls.
        keepdims: If True, the reduced (non-batch) axes are kept with size 1, so every output leaf has the same rank
            as the corresponding input leaf.

    Returns:
        Pytree with the same structure whose leaves are the real-valued norms.

    Raises:
        ValueError: If `num_ld` is negative.
    """
    if num_ld < 0:
        raise ValueError(f'num_ld must be non-negative, got {num_ld}.')

    def unitwise_norm(x: Array) -> Array:
        # x * conj(x) equals x ** 2 for real inputs and |x| ** 2 for complex ones, so the norm is always real.
        squared_norm = jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims)
        return jnp.sqrt(squared_norm)

    # Wrap once per leading batch dimension instead of hard-coding the nesting depth.
    norm_fn = unitwise_norm
    for _ in range(num_ld):
        norm_fn = jax.vmap(norm_fn)
    return jax.tree_util.tree_map(norm_fn, pytree)


def broadcast_to(pytree1: T, pytree2: T) -> T:
    """Broadcast all leaf arrays from one pytree to the shape of arrays in another pytree of the same type.

    Each leaf ``x`` of `pytree1` is broadcast with ``jnp.broadcast_to`` to ``y.shape``, where ``y`` is the
    corresponding leaf of `pytree2`; `pytree2` must share the structure of `pytree1`.

    Raises:
        ValueError: If either pytree contains a leaf that is not an array.
    """
    # Reject the call if *either* pytree has a non-array leaf, as the error message states.
    if not check_if_pytree_of_arrays(pytree1, allow_numbers=False) \
            or not check_if_pytree_of_arrays(pytree2, allow_numbers=False):
        raise ValueError('Should get pytrees of arrays.')
    return jax.tree_util.tree_map(lambda x, y: jnp.broadcast_to(x, y.shape), pytree1, pytree2)


def all_data_to_single_array(pytree: T) -> jnp.ndarray:
    """Join every leaf array of `pytree` into a single array with ``jnp.hstack``.

    Leaves are taken in ``jax.tree_util`` flattening order. Lower-rank leaves first get trailing singleton axes, so
    that all leaves share the rank of the highest-rank leaf. ``jnp.hstack`` then joins them: along axis 1 when that
    rank is at least 2, along axis 0 otherwise.

    Raises:
        ValueError: If `pytree` has non-array leaves, or has no leaves at all (``max`` of an empty sequence).
    """
    if not check_if_pytree_of_arrays(pytree, allow_numbers=False):
        raise ValueError('Should get a pytree of arrays.')
    leaves = jax.tree_util.tree_leaves(pytree)
    ranks = jax.tree_util.tree_leaves(num_dimensions(pytree))
    target_rank = max(ranks)
    padded = [leaf.reshape(leaf.shape + (1,) * (target_rank - rank)) for leaf, rank in zip(leaves, ranks)]
    return jnp.hstack(padded)


def extend_with_last_element(pytree: T, desired_length: int) -> T:
    """Pad a pytree of arrays along the leading axis by repeating each leaf's last entry.

    Every leaf ``x`` is extended with copies of ``x[-1]`` until its leading axis has `desired_length` entries.

    Args:
        pytree: Pytree whose leaves are all arrays sharing the same leading-axis length.
        desired_length: Leading-axis length of every leaf in the result.

    Returns:
        Pytree with the same structure and padded leaves. If no padding is needed, `pytree` itself is returned.

    Raises:
        ValueError: If `pytree` has non-array leaves, its data length exceeds `desired_length`, or it has no entries
            along the leading axis while padding is required (there is no last element to repeat).
        NoLengthError: If the leaves do not share a leading-axis length (raised by `data_length`).
    """
    if not check_if_pytree_of_arrays(pytree, allow_numbers=False):
        raise ValueError('Should get a pytree of arrays.')
    length = data_length(pytree, axis=0)
    if length > desired_length:
        raise ValueError('Data length is larger than desired length so it cannot be extended.')
    append_length = desired_length - length
    if append_length == 0:
        # Nothing to pad. This also avoids reading x[-1] from empty leaves.
        return pytree
    if length <= 0:
        raise ValueError('Data is empty along the leading axis, so there is no last element to extend it with.')

    def leaf_append(x):
        # Same shape as `x` except for `append_length` entries along the leading axis.
        append_shape = (append_length,) + tuple(x.shape[1:])
        return jnp.append(x, jnp.full(append_shape, x[-1]), axis=0)

    return jax.tree_util.tree_map(leaf_append, pytree)