123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183 |
- from __future__ import annotations
- from typing import TypeVar, Callable, Any
- import jax.numpy as jnp
- import jax
# Generic type variable used in the pytree helper signatures below.
T = TypeVar('T')
# Shorthand for the JAX array type used in annotations.
Array = jnp.ndarray
def all_equal(iterator) -> bool:
    """Return True when every element of ``iterator`` compares equal to the first one.

    An empty iterable is considered all-equal.
    """
    items = iter(iterator)
    missing = object()
    head = next(items, missing)
    if head is missing:
        # Nothing to compare against.
        return True
    return all(head == item for item in items)
def repeat_fields(pytree: T, num: int) -> T:
    """Tile every leaf of ``pytree`` ``num`` times along a new leading axis.

    Array leaves gain a new axis 0 of size ``num``. Leaves that cannot be indexed
    (e.g. plain Python numbers) become a 1-D array of length ``num`` filled with that value.
    """
    def _tile(leaf):
        try:
            tiled = jnp.repeat(leaf[None, ...], num, axis=0)
        except TypeError:
            # `leaf[None, ...]` raises TypeError for non-subscriptable scalars.
            tiled = jnp.full((num,), leaf)
        return tiled

    return jax.tree_util.tree_map(_tile, pytree)
def merge(pytree_list: list[T]) -> T:
    """Concatenate corresponding leaves of several pytrees along their existing leading axis.

    Unlike `stack`, no new axis is created: leaves are joined along axis 0.
    """
    def _join(*leaves):
        return jnp.concatenate(leaves, axis=0)

    return jax.tree_util.tree_map(_join, *pytree_list)
def stack(pytree_list: list[T]) -> T:
    """Stack corresponding leaves of several pytrees along a new leading axis 0."""
    def _stack_leaves(*leaves):
        return jnp.stack(leaves, axis=0)

    return jax.tree_util.tree_map(_stack_leaves, *pytree_list)
def axis_length(pytree: T, axis: int = 0) -> T:
    """
    Calculate the size of `axis` for every leaf of a pytree.

    Args:
        pytree: A pytree whose leaves are typically arrays.
        axis: The axis whose size is reported (negative values index from the end).

    Returns:
        A pytree with the same structure whose leaves are integers:
        - ``leaf.shape[axis]`` for leaves that have a ``shape`` and that axis;
        - ``0`` for leaves without a ``shape`` attribute (e.g. plain Python numbers);
        - ``-1`` for leaves whose ``shape`` has no such axis (e.g. 0-d arrays).
    """
    def length(x):
        try:
            return x.shape[axis]
        except AttributeError:
            # Non-array leaf: reported as 0 (callers such as data_length rely on a number here).
            return 0
        except IndexError:
            # Array with too few dimensions to have `axis`.
            return -1

    return jax.tree_util.tree_map(length, pytree)
class NoLengthError(Exception):
    """Raised when no single consistent length can be assigned to a pytree."""
def data_length(pytree: T, axis: int = 0, ignore_non_array_leaves: bool = False) -> int:
    """
    Assign a length to a pytree from the shapes of the arrays stored within it.

    Args:
        pytree: A pytree whose leaves are (mostly) arrays.
        axis: The axis whose size is read from every leaf.
        ignore_non_array_leaves: If True, leaves without a ``shape`` attribute (e.g. Python numbers used as
            auxiliary data) and array leaves that have no such `axis` (e.g. 0-d arrays) are skipped instead
            of taking part in the comparison. Arrays that are empty along `axis` are still counted.

    Returns:
        The common size of `axis` over all considered leaves.

    Raises:
        NoLengthError: If no leaf is considered, or the considered leaves disagree on the size.
    """
    leading_dim = axis_length(pytree, axis=axis)
    lengths, structure = jax.tree_util.tree_flatten(leading_dim)
    if ignore_non_array_leaves:
        # axis_length reports 0 both for non-array leaves and for arrays that are genuinely empty along
        # `axis`, so decide what to drop by inspecting the original leaves rather than the value 0.
        # The two leaf lists line up because tree_map preserves the pytree structure.
        leaves = jax.tree_util.tree_leaves(pytree)
        lengths = [n for leaf, n in zip(leaves, lengths) if hasattr(leaf, 'shape') and n >= 0]
    if lengths and all_equal(lengths):
        return lengths[0]
    raise NoLengthError(f'Pytree of type {type(pytree)} with structure {structure} cannot have a length assigned to it '
                        f'over axis {axis}.')
def check_if_pytree_of_arrays(pytree: T, allow_numbers: bool = True) -> bool:
    """Return True if every leaf of `pytree` is a ``jnp.ndarray``.

    When `allow_numbers` is true, leaves that have no ``__len__`` (e.g. Python numbers carrying
    auxiliary data) are accepted as well; sized non-array leaves such as strings never are.
    """
    def _acceptable(leaf):
        if isinstance(leaf, jnp.ndarray):
            return True
        return allow_numbers and not hasattr(leaf, "__len__")

    return all(_acceptable(leaf) for leaf in jax.tree_util.tree_leaves(pytree))
def get_slice(pytree: T, start_idx: int, slice_length: int) -> T:
    """
    Slice every leaf of a pytree along its leading axis.

    Args:
        pytree: A PyTree of ndarrays.
        start_idx: First index (inclusive) of the slice along axis 0.
        slice_length: Number of entries to take starting at `start_idx`.

    Returns:
        A PyTree with the same structure, each leaf restricted to
        ``[start_idx, start_idx + slice_length)`` along the first axis.
    """
    window = slice(start_idx, start_idx + slice_length)
    return jax.tree_util.tree_map(lambda leaf: leaf[window], pytree)
def split_to_list(pytree: T) -> list[T]:
    """Split a pytree of arrays into a list of pytrees, one per index along the leading axis.

    Raises:
        ValueError: If `pytree` fails `check_if_pytree_of_arrays` (numbers allowed).
    """
    # NOTE(review): number leaves pass the check below, but data_length is called without
    # ignore_non_array_leaves, so such leaves contribute a length of 0 and will most likely make
    # data_length raise NoLengthError — confirm whether number leaves are meant to be supported.
    if not check_if_pytree_of_arrays(pytree, allow_numbers=True):
        raise ValueError('Should get a pytree of arrays.')

    def _take(position):
        return jax.tree_util.tree_map(lambda leaf: leaf[position], pytree)

    return list(map(_take, range(data_length(pytree))))
def map_over_leading_leaf_dimension(f: Callable[[T, Any], T], *pytrees: T, **kwargs):
    """
    Apply `f` once per index of the shared leading axis of `pytrees` and stack the results.

    The positional pytrees are split together along their leading axis. `f` receives the per-index
    slices positionally, plus `kwargs` unchanged. The per-index outputs are stacked along a new
    leading axis. Every positional argument must therefore be splittable by `split_to_list`.
    """
    outputs = []
    for per_index_args in split_to_list(pytrees):
        outputs.append(f(*per_index_args, **kwargs))
    return stack(outputs)
def num_dimensions(pytree: T) -> T:
    """Map every leaf to its number of dimensions; leaves without a ``shape`` attribute map to 0."""
    return jax.tree_util.tree_map(lambda leaf: len(getattr(leaf, 'shape', ())), pytree)
def num_extra_dimensions(pytree: T, og_pytree: T) -> int:
    """
    Determine the number of extra leading dimensions compared to some original pytree of the same kind.

    Args:
        pytree: A pytree whose leaves may carry extra leading dimensions.
        og_pytree: The reference pytree with the same structure.

    Returns:
        The rank difference that is shared by every pair of corresponding leaves.

    Raises:
        ValueError: If the pytrees have different numbers of leaves, have no leaves at all, or the
            rank difference is not the same for every leaf.
    """
    pytree_dims = jax.tree_util.tree_leaves(num_dimensions(pytree))
    og_pytree_dims = jax.tree_util.tree_leaves(num_dimensions(og_pytree))
    # zip would silently drop unmatched leaves, so check the counts explicitly.
    if len(pytree_dims) != len(og_pytree_dims):
        raise ValueError(f'Pytrees have different numbers of leaves '
                         f'({len(pytree_dims)} vs {len(og_pytree_dims)}).')
    dim_differences = [d1 - d2 for d1, d2 in zip(pytree_dims, og_pytree_dims)]
    # An empty difference list would otherwise pass all_equal and fail on [0] with IndexError.
    if dim_differences and all_equal(dim_differences):
        return dim_differences[0]
    raise ValueError('No consistent extra leading dimensions found.')
def leaf_norm(pytree: T, num_ld: int = 0, keepdims: bool = True) -> T:
    """
    Compute the Euclidean (Frobenius) norm of every leaf, optionally per leading index.

    Args:
        pytree: A pytree of arrays.
        num_ld: Number of leading dimensions to keep. The norm is taken over all remaining
            dimensions, separately for every index of those leading dimensions.
        keepdims: If True, reduced dimensions are kept with size 1 (so a leaf of shape
            ``(d1, ..., d_num_ld, ...)`` yields ``(d1, ..., d_num_ld, 1, ..., 1)``).

    Returns:
        A pytree with the same structure holding the per-leaf norms.

    Raises:
        NotImplementedError: If `num_ld` is negative.
    """
    if num_ld < 0:
        raise NotImplementedError(f'leaf_norm requires a non-negative number of leading dimensions, got {num_ld}.')

    def unitwise_norm(x: Array) -> Array:
        # |x|**2 equals x**2 for real inputs and gives the correct squared modulus for complex ones.
        squared_norm = jnp.sum(jnp.abs(x) ** 2, keepdims=keepdims)
        return jnp.sqrt(squared_norm)

    # Each vmap layer maps the norm over one more leading dimension.
    norm_fn = unitwise_norm
    for _ in range(num_ld):
        norm_fn = jax.vmap(norm_fn)
    return jax.tree_util.tree_map(norm_fn, pytree)
def broadcast_to(pytree1: T, pytree2: T) -> T:
    """
    Broadcast every leaf of `pytree1` to the shape of the corresponding leaf of `pytree2`.

    Args:
        pytree1: Pytree whose leaves (arrays or scalars) are broadcast.
        pytree2: Pytree with the same structure whose array leaves supply the target shapes.

    Returns:
        A pytree with the structure of `pytree1` whose leaves have the shapes of `pytree2`'s leaves.

    Raises:
        ValueError: If some leaf of `pytree2` has no ``shape`` to broadcast to.
    """
    # Only the target leaves need a shape. Checking them all up front gives a clear error instead of an
    # AttributeError from inside tree_map, while still accepting any array-like leaf (and scalar
    # leaves in pytree1, which jnp.broadcast_to handles).
    if not all(hasattr(leaf, 'shape') for leaf in jax.tree_util.tree_leaves(pytree2)):
        raise ValueError('Should get pytrees of arrays.')
    return jax.tree_util.tree_map(lambda x, y: jnp.broadcast_to(x, y.shape), pytree1, pytree2)
def all_data_to_single_array(pytree: T) -> jnp.ndarray:
    """Combine all array leaves of a pytree into one array with ``jnp.hstack``.

    Every leaf is first padded with trailing size-1 axes up to the largest rank in the pytree. The padded
    leaves are then horizontally stacked: axis 1 when that rank is at least 2, axis 0 for 1-D leaves.

    Raises:
        ValueError: If `pytree` is not a pytree of arrays (numbers are not allowed).
    """
    if not check_if_pytree_of_arrays(pytree, allow_numbers=False):
        raise ValueError('Should get a pytree of arrays.')
    leaves = jax.tree_util.tree_leaves(pytree)
    ranks = jax.tree_util.tree_leaves(num_dimensions(pytree))
    target_rank = max(ranks)
    padded = [leaf.reshape(leaf.shape + (1,) * (target_rank - rank)) for leaf, rank in zip(leaves, ranks)]
    return jnp.hstack(padded)
def extend_with_last_element(pytree: T, desired_length: int) -> T:
    """
    Pad every leaf along axis 0 by repeating its last entry until it has `desired_length` entries.

    Args:
        pytree: A pytree of arrays sharing the same leading length.
        desired_length: Target length along axis 0.

    Returns:
        A pytree with the same structure whose leaves have leading length `desired_length`.

    Raises:
        ValueError: If `pytree` is not a pytree of arrays, if its data length exceeds `desired_length`,
            or if padding is needed but the data is empty (there is no last element to repeat).
    """
    if not check_if_pytree_of_arrays(pytree, allow_numbers=False):
        raise ValueError('Should get a pytree of arrays.')
    length = data_length(pytree, axis=0)
    if length > desired_length:
        raise ValueError('Data length is larger than desired length so it cannot be extended.')
    append_length = desired_length - length
    if append_length > 0 and length == 0:
        # With no entries, x[-1:] is empty and padding would silently produce nothing.
        raise ValueError('Data is empty so there is no last element to extend it with.')

    def leaf_append(x):
        # Slicing (x[-1:]) keeps the leading axis, so repeating it yields the padding block directly.
        # It also works on empty leaves when no padding is needed.
        padding = jnp.repeat(x[-1:], append_length, axis=0)
        return jnp.concatenate([x, padding], axis=0)

    return jax.tree_util.tree_map(leaf_append, pytree)
|