from __future__ import annotations

from typing import Any, Callable, TypeVar

import jax
import jax.numpy as jnp

T = TypeVar('T')
Array = jnp.ndarray


def all_equal(iterator) -> bool:
    """Check if all elements of an iterable are equal.

    Every element is compared with ``==`` against the first one. An empty
    iterable is considered to have all elements equal and returns True.
    """
    iterator = iter(iterator)
    try:
        first = next(iterator)
    except StopIteration:
        return True
    return all(first == x for x in iterator)


def repeat_fields(pytree: T, num: int) -> T:
    """Repeat each leaf of a PyTree ``num`` times along a new leading axis.

    Array leaves of shape ``s`` become arrays of shape ``(num, *s)``. Leaves
    that cannot be indexed with ``[None, ...]`` (e.g. Python numbers) become
    1-D arrays of length ``num`` filled with the leaf value.
    """
    def repeat(x):
        try:
            return jnp.repeat(x[None, ...], num, axis=0)
        except TypeError:
            # Python scalars are not subscriptable; fill a fresh array instead.
            return jnp.full((num,), x)

    return jax.tree_util.tree_map(repeat, pytree)


def merge(pytree_list: list[T]) -> T:
    """Merge the leaves of several PyTrees by concatenating along the leading axis.

    Unlike ``stack``, this concatenates over the existing leading axis rather
    than creating a new one. All PyTrees must share the same structure.
    """
    return jax.tree_util.tree_map(lambda *args: jnp.concatenate(args, axis=0), *pytree_list)


def stack(pytree_list: list[T]) -> T:
    """Merge the leaves of several PyTrees by stacking them along a new leading axis.

    All PyTrees must share the same structure.
    """
    return jax.tree_util.tree_map(lambda *args: jnp.stack(args), *pytree_list)


def axis_length(pytree: T, axis: int = 0) -> T:
    """Calculate the length of every pytree leaf along ``axis``.

    Returns a pytree with the same structure whose leaves are ints:
      * ``leaf.shape[axis]`` for leaves that have a ``shape``;
      * ``0`` for leaves without a ``shape`` attribute (e.g. Python numbers);
      * ``-1`` if ``axis`` is out of range for the leaf's shape (e.g. 0-d arrays).
    """
    def length(x):
        try:
            return x.shape[axis]
        except AttributeError:
            return 0
        except IndexError:
            return -1

    return jax.tree_util.tree_map(length, pytree)


class NoLengthError(Exception):
    """Raised when no single consistent length can be assigned to a pytree."""


def data_length(pytree: T, axis: int = 0, ignore_non_array_leaves: bool = False) -> int:
    """Assign a length to a pytree from the shapes of the arrays stored within it.

    Args:
        pytree: The pytree to measure.
        axis: The axis along which each leaf's length is read.
        ignore_non_array_leaves: If True, leaves without a ``shape`` attribute
            (e.g. auxiliary Python numbers) and array leaves with too few
            dimensions for ``axis`` are excluded before lengths are compared.

    Returns:
        The common length of all considered leaves along ``axis``.

    Raises:
        NoLengthError: If no leaves are left to compare, if their lengths
            disagree, or if every leaf is missing ``axis`` (which would
            otherwise yield the sentinel ``-1``).
    """
    leaves = jax.tree_util.tree_leaves(pytree)
    lengths, structure = jax.tree_util.tree_flatten(axis_length(pytree, axis=axis))
    if ignore_non_array_leaves:
        # Filter on the leaf itself rather than on ``length > 0``, so arrays that
        # are legitimately empty along ``axis`` still take part in the comparison.
        # ``leaves`` and ``lengths`` share one treedef, so zip pairs them correctly.
        lengths = [
            length for leaf, length in zip(leaves, lengths)
            if hasattr(leaf, 'shape') and length >= 0
        ]
    # A negative common value is axis_length's "axis out of range" sentinel,
    # not a real length.
    if lengths and all_equal(lengths) and lengths[0] >= 0:
        return lengths[0]
    raise NoLengthError(f'Pytree of type {type(pytree)} with structure {structure} cannot have a length assigned to it '
                        f'over axis {axis}.')


def check_if_pytree_of_arrays(pytree: T, allow_numbers: bool = True) -> bool:
    """Check if a pytree consists only of ndarray leaves.

    Args:
        pytree: The pytree to inspect.
        allow_numbers: If True, leaves that are not ``jnp.ndarray`` are still
            accepted as long as they have no ``__len__`` (e.g. Python numbers
            used as auxiliary data).

    Returns:
        True if every leaf is acceptable, False otherwise. An empty pytree
        returns True.
    """
    leaves = jax.tree_util.tree_leaves(pytree)
    for leaf in leaves:
        if not isinstance(leaf, jnp.ndarray):
            if allow_numbers and not hasattr(leaf, "__len__"):
                continue
            return False
    return True


def get_slice(pytree: T, start_idx: int, slice_length: int) -> T:
    """Return a PyTree with each leaf sliced along the leading axis.

    Args:
        pytree: A PyTree of ndarrays.
        start_idx: The starting index of the slice.
        slice_length: The number of elements to take from ``start_idx``.

    Returns:
        A PyTree with the same structure, each leaf being
        ``leaf[start_idx:start_idx + slice_length]``.
    """
    return jax.tree_util.tree_map(lambda x: x[start_idx:start_idx + slice_length], pytree)


def split_to_list(pytree: T) -> list[T]:
    """Split a pytree of arrays into a list of pytrees, one per leading-axis index.

    Raises:
        ValueError: If the pytree has leaves other than arrays and numbers.
        NoLengthError: If the leaves do not share a common leading-axis length.
    """
    if not check_if_pytree_of_arrays(pytree, allow_numbers=True):
        raise ValueError('Should get a pytree of arrays.')
    # NOTE(review): number leaves pass the check above, but data_length (called
    # without ignore_non_array_leaves) gives them length 0. Mixing them with
    # non-empty arrays therefore raises NoLengthError — confirm this is intended.
    length = data_length(pytree)
    return [jax.tree_util.tree_map(lambda x: x[idx], pytree) for idx in range(length)]


def map_over_leading_leaf_dimension(f: Callable[[T, Any], T], *pytrees: T, **kwargs):
    """Map ``f`` over the leading leaf dimension of one or more pytrees.

    The tuple of pytrees is split along its shared leading dimension. ``f``
    receives one slice of every pytree as positional arguments (plus
    ``kwargs``), and the per-slice results are stacked into a single object.
    All positional pytrees must be splittable by ``split_to_list``.
    """
    split_tree = split_to_list(pytrees)
    results = [f(*tree, **kwargs) for tree in split_tree]
    return stack(results)


def num_dimensions(pytree: T) -> T:
    """Return a pytree of each leaf's number of dimensions (0 for leaves without ``shape``)."""
    def num_dim(x):
        try:
            return len(x.shape)
        except AttributeError:
            return 0

    return jax.tree_util.tree_map(num_dim, pytree)


def num_extra_dimensions(pytree: T, og_pytree: T) -> int:
    """Determine how many extra leading dimensions ``pytree`` has compared to ``og_pytree``.

    Raises:
        ValueError: If the pytrees have different numbers of leaves, have no
            leaves at all, or the per-leaf dimension differences are not all equal.
    """
    pytree_dims = jax.tree_util.tree_leaves(num_dimensions(pytree))
    og_pytree_dims = jax.tree_util.tree_leaves(num_dimensions(og_pytree))
    if len(pytree_dims) != len(og_pytree_dims):
        # zip would silently truncate and compare unrelated leaves.
        raise ValueError('Pytrees have a different number of leaves.')
    dim_differences = [d1 - d2 for d1, d2 in zip(pytree_dims, og_pytree_dims)]
    if dim_differences and all_equal(dim_differences):
        return dim_differences[0]
    raise ValueError('No consistent extra leading dimensions found.')


def leaf_norm(pytree: T, num_ld: int = 0, keepdims: bool = True) -> T:
    """Compute the L2 norm of every leaf, treating ``num_ld`` leading axes as batch axes.

    Args:
        pytree: A pytree of arrays.
        num_ld: Number of shared leading dimensions to map over. The norm is
            taken over all remaining axes of each leaf.
        keepdims: Whether the reduced axes are kept with size 1.

    Raises:
        ValueError: If ``num_ld`` is negative.
    """
    def unitwise_norm(x: Array) -> Array:
        squared_norm = jnp.sum(x ** 2, keepdims=keepdims)
        return jnp.sqrt(squared_norm)

    if num_ld < 0:
        raise ValueError(f'num_ld must be non-negative, got {num_ld}.')
    norm_fn = unitwise_norm
    # Each vmap maps over one more leading dimension.
    for _ in range(num_ld):
        norm_fn = jax.vmap(norm_fn)
    return jax.tree_util.tree_map(norm_fn, pytree)


def broadcast_to(pytree1: T, pytree2: T) -> T:
    """Broadcast each leaf array of ``pytree1`` to the shape of the matching leaf in ``pytree2``.

    Raises:
        ValueError: If either pytree contains a non-array leaf.
    """
    # Both pytrees must be valid; rejecting only when both are invalid
    # would let number leaves slip through.
    if not check_if_pytree_of_arrays(pytree1, allow_numbers=False) \
            or not check_if_pytree_of_arrays(pytree2, allow_numbers=False):
        raise ValueError('Should get pytrees of arrays.')
    return jax.tree_util.tree_map(lambda x, y: jnp.broadcast_to(x, y.shape), pytree1, pytree2)


def all_data_to_single_array(pytree: T) -> jnp.ndarray:
    """Combine all leaf arrays of a pytree into one array with ``jnp.hstack``.

    Each leaf first gets trailing singleton axes appended, so all leaves share
    the largest leaf rank. ``jnp.hstack`` then joins them along axis 1 (or
    axis 0 when every leaf is 1-D).

    Raises:
        ValueError: If the pytree has a non-array leaf or no leaves at all.
    """
    if not check_if_pytree_of_arrays(pytree, allow_numbers=False):
        raise ValueError('Should get a pytree of arrays.')
    arrays = jax.tree_util.tree_leaves(pytree)
    if not arrays:
        raise ValueError('Should get a non-empty pytree of arrays.')
    max_tree_dim = max(array.ndim for array in arrays)
    reshaped_arrays = [
        array.reshape(array.shape + (1,) * (max_tree_dim - array.ndim))
        for array in arrays
    ]
    return jnp.hstack(reshaped_arrays)


def extend_with_last_element(pytree: T, desired_length: int) -> T:
    """Pad every leaf along axis 0 up to ``desired_length`` by repeating its last element.

    Raises:
        ValueError: If the pytree has non-array leaves, if it is already longer
            than ``desired_length``, or if it is empty but must be extended.
        NoLengthError: If the leaves do not share a common leading length.
    """
    if not check_if_pytree_of_arrays(pytree, allow_numbers=False):
        raise ValueError('Should get a pytree of arrays.')
    length = data_length(pytree, axis=0)
    if length > desired_length:
        raise ValueError('Data length is larger than desired length so it cannot be extended.')
    append_length = desired_length - length
    if append_length == 0:
        # Nothing to append; skip x[-1], which would fail on empty arrays.
        return jax.tree_util.tree_map(jnp.asarray, pytree)
    if length == 0:
        raise ValueError('Data is empty so there is no last element to extend it with.')

    def leaf_append(x):
        append_shape = (append_length,) + tuple(x.shape[1:])
        return jnp.append(x, jnp.full(append_shape, x[-1], dtype=x.dtype), axis=0)

    return jax.tree_util.tree_map(leaf_append, pytree)