pytree_transf.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. from __future__ import annotations
  2. from typing import TypeVar, Callable, Any
  3. import jax.numpy as jnp
  4. import jax
  5. T = TypeVar('T')
  6. Array = jnp.ndarray
  def all_equal(iterator) -> bool:
      """Return True when every element of an iterable compares equal to the first one.

      An iterable that yields nothing is considered all-equal and gives True.
      """
      remaining = iter(iterator)
      nothing = object()  # private sentinel so that any real value (including None) can be the first element
      head = next(remaining, nothing)
      if head is nothing:
          return True
      for candidate in remaining:
          if not (head == candidate):
              return False
      return True
  def repeat_fields(pytree: T, num: int) -> T:
      """Give every leaf of a PyTree a new leading axis of size `num` filled with copies of the leaf.

      Leaves for which the array-style path raises `TypeError` (e.g. plain Python numbers, which cannot be
      indexed with `[None, ...]`) are instead expanded into a 1-D array of length `num` via `jnp.full`.
      """
      def tile_leaf(leaf):
          try:
              with_new_axis = leaf[None, ...]
              return jnp.repeat(with_new_axis, num, axis=0)
          except TypeError:
              # Fallback for leaves that do not support array indexing.
              return jnp.full((num,), leaf)

      return jax.tree_util.tree_map(tile_leaf, pytree)
  def merge(pytree_list: list[T]) -> T:
      """Combine several PyTrees leaf-by-leaf by concatenating corresponding leaves along axis 0.

      Note the contrast with `stack`: concatenation joins along the leaves' existing leading axis,
      whereas stacking introduces a brand new leading axis.
      """
      def concat_leaves(*leaves):
          return jnp.concatenate(leaves, axis=0)

      return jax.tree_util.tree_map(concat_leaves, *pytree_list)
  def stack(pytree_list: list[T]) -> T:
      """Combine several PyTrees leaf-by-leaf by stacking corresponding leaves along a new leading axis."""
      def stack_leaves(*leaves):
          return jnp.stack(leaves)

      return jax.tree_util.tree_map(stack_leaves, *pytree_list)
  def axis_length(pytree: T, axis: int = 0) -> T:
      """
      Map every leaf of a pytree to the size of its `axis`-th dimension.

      Leaves without a `shape` attribute (non-array values) are mapped to 0. Leaves whose shape has no
      entry at index `axis` are mapped to -1.
      """
      def size_along_axis(leaf):
          try:
              leaf_shape = leaf.shape
          except AttributeError:
              # Not an array-like leaf.
              return 0
          try:
              return leaf_shape[axis]
          except IndexError:
              # The leaf has too few dimensions for the requested axis.
              return -1

      return jax.tree_util.tree_map(size_along_axis, pytree)
  class NoLengthError(Exception):
      """Signals that no single length could be assigned to a pytree."""
  def data_length(pytree: T, axis: int = 0, ignore_non_array_leaves: bool = False) -> int:
      """Derive one common length for a pytree from the sizes of its leaves along `axis`.

      Args:
          pytree: The pytree to inspect.
          axis: The axis whose size is read from every leaf (via `axis_length`).
          ignore_non_array_leaves: When True, leaves whose reported size is not strictly positive are left
              out of the comparison.

      Returns:
          The size shared by all considered leaves.

      Raises:
          NoLengthError: If no leaves remain to be considered or their sizes disagree.
      """
      per_leaf_sizes = axis_length(pytree, axis=axis)
      lengths, structure = jax.tree_util.tree_flatten(per_leaf_sizes)
      if ignore_non_array_leaves:
          # Non-array leaves (possibly numeric auxiliary data) are reported as 0 by `axis_length`.
          # The `> 0` test also drops the -1 markers and genuinely empty axes.
          lengths = [size for size in lengths if size > 0]
      if not lengths or not all_equal(lengths):
          raise NoLengthError(f'Pytree of type {type(pytree)} with structure {structure} cannot have a length assigned to it '
                              f'over axis {axis}.')
      return lengths[0]
  def check_if_pytree_of_arrays(pytree: T, allow_numbers: bool = True) -> bool:
      """Tell whether every leaf of a pytree is a `jnp.ndarray`, optionally tolerating number-like leaves.

      With `allow_numbers` set, a leaf that is not a `jnp.ndarray` is still accepted as long as it has no
      `__len__` attribute (used here as the test for auxiliary number leaves).
      """
      def acceptable(leaf) -> bool:
          if isinstance(leaf, jnp.ndarray):
              return True
          return bool(allow_numbers) and not hasattr(leaf, "__len__")

      leaves, _ = jax.tree_util.tree_flatten(pytree)
      return all(acceptable(leaf) for leaf in leaves)
  def get_slice(pytree: T, start_idx: int, slice_length: int) -> T:
      """
      Take the same window out of every leaf of a PyTree along the leading axis.

      Args:
          pytree: A PyTree of ndarrays, representing the input data.
          start_idx: An integer, the index at which the window starts.
          slice_length: An integer, the number of entries in the window.

      Returns:
          A PyTree with the structure of the input in which every leaf is `leaf[start_idx:start_idx+slice_length]`.
      """
      def take_window(leaf):
          window = slice(start_idx, start_idx + slice_length)
          return leaf[window]

      return jax.tree_util.tree_map(take_window, pytree)
  def split_to_list(pytree: T) -> list[T]:
      """Split a pytree along the leading axis of its leaves into a list of pytrees.

      The list has `data_length(pytree)` entries; entry `i` has the same structure as the input with
      every leaf replaced by `leaf[i]`.

      Raises:
          ValueError: If `check_if_pytree_of_arrays` (numbers allowed) rejects the input.
      """
      if not check_if_pytree_of_arrays(pytree, allow_numbers=True):
          raise ValueError('Should get a pytree of arrays.')
      pieces = []
      for position in range(data_length(pytree)):
          # `position` is bound as a default so each mapped function indexes with its own value.
          pieces.append(jax.tree_util.tree_map(lambda leaf, i=position: leaf[i], pytree))
      return pieces
  def map_over_leading_leaf_dimension(f: Callable[[T, Any], T], *pytrees: T, **kwargs):
      """
      Apply `f` once per index of the leading leaf dimension and stack the individual results.

      The positional pytrees are split together (as one tuple) with `split_to_list`, `f` is called on each
      split-off tuple of pytrees together with `**kwargs`, and the outputs are combined with `stack`.
      All positional arguments must therefore be splittable pytrees.
      """
      per_index_args = split_to_list(pytrees)
      outputs = []
      for index_args in per_index_args:
          outputs.append(f(*index_args, **kwargs))
      return stack(outputs)
  def num_dimensions(pytree: T) -> T:
      """Map every leaf of a pytree to its number of dimensions; leaves without a `shape` count as 0."""
      def rank_of(leaf):
          try:
              leaf_shape = leaf.shape
          except AttributeError:
              # Non-array leaf.
              return 0
          return len(leaf_shape)

      return jax.tree_util.tree_map(rank_of, pytree)
  def num_extra_dimensions(pytree: T, og_pytree: T) -> int:
      """Determine how many more dimensions the leaves of `pytree` have than those of an original pytree.

      Args:
          pytree: The pytree whose leaves may carry additional dimensions.
          og_pytree: The original pytree of the same kind to compare against.

      Returns:
          The difference in number of dimensions, which must be the same for every pair of leaves.

      Raises:
          ValueError: If the difference varies between leaves, or if there are no leaves to compare.
      """
      pytree_dims, _ = jax.tree_util.tree_flatten(num_dimensions(pytree))
      og_pytree_dims, _ = jax.tree_util.tree_flatten(num_dimensions(og_pytree))
      dim_differences = [d1 - d2 for d1, d2 in zip(pytree_dims, og_pytree_dims)]
      # `all_equal` is vacuously True for an empty list, so emptiness is checked first; otherwise
      # indexing below would fail with an IndexError instead of the documented ValueError.
      if dim_differences and all_equal(dim_differences):
          return dim_differences[0]
      raise ValueError('No consistent extra leading dimensions found.')
  def leaf_norm(pytree: T, num_ld: int = 0, keepdims: bool = True) -> T:
      """Compute, for every leaf, the square root of the sum of its squared entries.

      The first `num_ld` axes of every leaf are treated as common leading (batch) axes: the reduction is
      applied separately for each index along them by wrapping it in `jax.vmap` once per leading axis.

      Args:
          pytree: A pytree whose leaves are arrays with at least `num_ld` dimensions.
          num_ld: Number of common leading dimensions to map over. Any non-negative integer is supported.
          keepdims: Forwarded to `jnp.sum` for the reduction over the non-leading axes.

      Returns:
          A pytree with the structure of the input holding the per-leaf norms.

      Raises:
          NotImplementedError: If `num_ld` is negative or not a whole number.
      """
      depth = int(num_ld)
      if depth != num_ld or depth < 0:
          raise NotImplementedError('Cannot calculate the leaf_norm for a negative or non-integer number of '
                                    'common leading dimensions.')

      def unitwise_norm(x: Array) -> Array:
          squared_norm = jnp.sum(x ** 2, keepdims=keepdims)
          return jnp.sqrt(squared_norm)

      # Each vmap layer peels off one leading axis, so `depth` layers leave the norm acting on the rest.
      norm_fn = unitwise_norm
      for _ in range(depth):
          norm_fn = jax.vmap(norm_fn)
      return jax.tree_util.tree_map(norm_fn, pytree)
  def broadcast_to(pytree1: T, pytree2: T) -> T:
      """Broadcast all leaf arrays from one pytree to the shape of arrays in another pytree of the same type.

      Args:
          pytree1: The pytree whose leaves are broadcast.
          pytree2: The pytree whose leaves supply the target shapes.

      Returns:
          A pytree in which every leaf of `pytree1` is broadcast to the shape of the matching `pytree2` leaf.

      Raises:
          ValueError: If either input is not a pytree consisting only of arrays.
      """
      # Both inputs must pass the check, so the error is raised as soon as either one fails it.
      if not check_if_pytree_of_arrays(pytree1, allow_numbers=False) \
              or not check_if_pytree_of_arrays(pytree2, allow_numbers=False):
          raise ValueError('Should get pytrees of arrays.')
      return jax.tree_util.tree_map(lambda x, y: jnp.broadcast_to(x, y.shape), pytree1, pytree2)
  def all_data_to_single_array(pytree: T) -> jnp.ndarray:
      """Join all leaf arrays of a pytree into one array.

      Every leaf is first reshaped by appending trailing singleton axes until it has as many dimensions as
      the leaf with the most dimensions; the reshaped leaves are then combined with `jnp.hstack`.

      Raises:
          ValueError: If the input is not a pytree consisting only of arrays.
      """
      if not check_if_pytree_of_arrays(pytree, allow_numbers=False):
          raise ValueError('Should get a pytree of arrays.')
      leaves, _ = jax.tree_util.tree_flatten(pytree)
      ranks, _ = jax.tree_util.tree_flatten(num_dimensions(pytree))
      target_rank = max(ranks)
      padded = [
          leaf.reshape(leaf.shape + (1,) * (target_rank - rank))
          for leaf, rank in zip(leaves, ranks)
      ]
      return jnp.hstack(padded)
  def extend_with_last_element(pytree: T, desired_length: int) -> T:
      """Lengthen every leaf along axis 0 to `desired_length` by repeating its last entry.

      Raises:
          ValueError: If the input is not a pytree consisting only of arrays, or if its current length
              (from `data_length` over axis 0) already exceeds `desired_length`.
      """
      if not check_if_pytree_of_arrays(pytree, allow_numbers=False):
          raise ValueError('Should get a pytree of arrays.')
      current_length = data_length(pytree, axis=0)
      if current_length > desired_length:
          raise ValueError('Data length is larger than desired length so it cannot be extended.')
      missing = desired_length - current_length

      def pad_leaf(leaf):
          # The filler block has the leaf's shape except along axis 0, where it has `missing` entries.
          filler_shape = list(leaf.shape)
          filler_shape[0] = missing
          filler = jnp.full(filler_shape, leaf[-1])
          return jnp.append(leaf, filler, axis=0)

      return jax.tree_util.tree_map(pad_leaf, pytree)