123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324 |
- from __future__ import annotations
- from typing import Callable, TypeVar
- import copy
- import jax
- import multiprocess as mp
- from functools import partial
- import math
- from curvature_assembly import pytree_transf
- import warnings
- T = TypeVar('T')
def all_equal(iterator) -> bool:
    """Check if all elements in an iterator compare equal (an empty iterator counts as equal)."""
    remaining = iter(iterator)
    try:
        reference = next(remaining)
    except StopIteration:
        # Nothing to compare: vacuously true.
        return True
    for element in remaining:
        if not (reference == element):
            return False
    return True
def get_argument_length(arg) -> int:
    """
    Get length of `arg` if it is a sequence, or its leading-dimension length if it
    is a pytree (dataclass) of arrays.

    Args:
        arg: The argument whose length should be determined.

    Returns:
        The length assigned to `arg`.

    Raises:
        TypeError: If a length cannot be assigned to this argument.
    """
    if pytree_transf.check_if_pytree_of_arrays(arg):
        return pytree_transf.data_length(arg)
    try:
        return len(arg)
    except TypeError as err:
        # Fixed typo in the original message ("lenght") and chain the cause for debuggability.
        raise TypeError(f'Cannot assign length to argument of type {type(arg)}') from err
def full_sequence_length(map_argnums: int | tuple, *args):
    """
    Return the length of the sequence(s) over which a function will be mapped.

    Args:
        map_argnums: The positional index(es) of the argument(s) over which the function will be mapped.
        args: The arguments passed to the function.

    Returns:
        The length of the sequence(s) over which the function will be mapped.

    Raises:
        ValueError: If `map_argnums` is empty, if any of the arguments over which the function is
            mapped is not a sequence or a pytree of arrays, or if the sequences have different lengths.
    """
    if isinstance(map_argnums, int):
        map_argnums = (map_argnums,)
    if not map_argnums:
        # Guard: an empty tuple would otherwise crash below with an opaque IndexError on lengths[0].
        raise ValueError('map_argnums must contain at least one argument index.')
    lengths = []
    for argnum in map_argnums:
        try:
            lengths.append(get_argument_length(args[argnum]))
        except TypeError as err:
            raise ValueError(f'Each argument over which a function is mapped should be a sequence '
                             f'or a pytree of arrays, got {type(args[argnum])} for argument {argnum}.') from err
    if all_equal(lengths):
        return lengths[0]
    else:
        # Original message was missing a space at the f-string join ("length,got lengths").
        raise ValueError(f'All arguments over which we map should be of the same length, '
                         f'got lengths {lengths} for args {map_argnums}, respectively.')
def canonicalize_args(map_argnums: int | tuple, *args) -> list:
    """
    Create a canonical list of arguments consisting of sequences with equal
    numbers of elements.

    Args:
        map_argnums: Argument indices that are already sequences of length num.
        *args: A variable number of arguments to be canonicalized. Each argument
            should either be a sequence (list, tuple, etc.) with length num,
            or a non-sequence type that can be repeated num times to create
            a sequence.

    Returns:
        A canonical list of arguments. Each element of the list is a sequence
        with `num` elements, either copied from the input argument or created by
        repeating a non-sequence argument num times.
    """
    argnums = (map_argnums,) if isinstance(map_argnums, int) else map_argnums
    # Raises ValueError if mapped arguments have different lengths.
    num = full_sequence_length(argnums, *args)
    canonical = []
    for idx, argument in enumerate(args):
        if idx in argnums:
            # Already a sequence of the right length; pass through untouched.
            canonical.append(argument)
        else:
            # Independent deep copies so mutating one element cannot affect the others.
            canonical.append([copy.deepcopy(argument) for _ in range(num)])
    return canonical
def canonicalize_args_pmap(map_argnums: int | tuple, *args) -> list:
    """
    Create a canonical list of arguments consisting of dataclasses with all Array
    fields having the same leading dimension length.

    Args:
        map_argnums: Argument indices that already store arrays of length num to be mapped over.
        *args: A variable number of arguments to be canonicalized.

    Returns:
        A canonical list of arguments. Each element of the list is a sequence
        with `num` elements, either copied from the input argument or created by
        repeating a non-sequence argument num times.
    """
    argnums = (map_argnums,) if isinstance(map_argnums, int) else map_argnums
    # Raises ValueError if mapped arguments have different lengths.
    num = full_sequence_length(argnums, *args)
    canonical = []
    for idx, argument in enumerate(args):
        if idx in argnums:
            # Already carries the mapped leading dimension; keep as-is.
            canonical.append(argument)
            continue
        canonical.append(pytree_transf.repeat_fields(argument, num))
        try:
            existing_length = pytree_transf.data_length(argument)
        except pytree_transf.NoLengthError:
            # Argument has no leading dimension of its own; nothing to warn about.
            continue
        if existing_length == num:
            # Coincidence between an existing leading dimension and the mapped length is a
            # likely sign the caller forgot to include this argument in map_argnums.
            warnings.warn(f"Added a new leading dimension to argument {idx} with existing leading dimension "
                          f"length that is the same as the length of the mapped argument(s). Make sure that "
                          f"this is the desired behavior and this argument should not also be mapped over.")
    return canonical
def fill_to_length_num(num, *args):
    """
    Extend each argument in `args` with its last element until its length is a multiple of `num`.

    Args:
        num: The multiple to which the length of each argument should be extended.
        args: A variable number of arguments to be extended.

    Returns:
        A list of the extended arguments.
    """
    return [pytree_transf.extend_with_last_element(argument, num) for argument in args]
def get_slice(start_idx: int, slice_length: int, *args) -> list:
    """
    Return a slice of a specified length from each argument in a variable-length
    list of sequences.

    Args:
        start_idx: The starting index of the slice to be extracted from each sequence.
        slice_length: The length of the slice to be extracted from each sequence.
        *args: A variable-length list of sequences.

    Returns:
        A list of slices where each slice is extracted from the corresponding
        sequence in `args` starting at index `start_idx` and extending for `slice_length`
        elements.
    """
    if start_idx < 0 or slice_length < 0:
        raise ValueError("Start index and slice length must be non-negative.")
    stop_idx = start_idx + slice_length
    return [sequence[start_idx:stop_idx] for sequence in args]
def list_flatten(lst: list) -> list:
    """
    Flatten a list of nested lists by one level.
    """
    return [element for sublist in lst for element in sublist]
def segment_args_pool(num: int, num_cores: int, *args) -> list:
    """
    Segment the input arguments into a list of segments, with each segment containing
    a fraction of the arguments. This function can be used to split up a large computation
    across multiple processor cores using the multiprocess.Pool to speed up processing.

    Args:
        num: The total number of items to be segmented across cores.
        num_cores: The number of cores to be used for processing.
        *args: A variable-length list of sequences. Each sequence should be indexable
            and have a length equal to `num`.

    Returns:
        A list of segments, where each segment is a list of argument values
        extracted from the corresponding index range in the input sequences. The output
        list will have length `num_cores`, and each segment will have the
        same number of items, except for the last one that gets the remaining number of items.
    """
    # Ceiling division: every core gets segment_length items, the last one possibly fewer.
    segment_length = int(math.ceil(num / num_cores))
    return [get_slice(core_idx * segment_length, segment_length, *args)
            for core_idx in range(num_cores)]
def segment_args_pmap(num: int, num_devices: int, *args) -> list:
    """
    Segment the input arguments into a list of segments, with each segment containing
    a fraction of the arguments. This function can be used to split up a large computation
    across multiple computational units using jax.pmap to speed up processing.

    Args:
        num: The total number of items to be segmented across cores.
        num_devices: The number of devices to be used for processing.
        *args: A variable-length list of sequences. Each sequence should be indexable
            and have a length equal to `num`.

    Returns:
        A list of segments, where each segment is a list of argument values
        extracted from the corresponding index range in the input sequences. The output
        list will have length num_pmap_calculations, and each segment will have the
        same number of items, except for the last one that gets the remaining number of items.
    """
    # One pmap call per chunk of device_count items (last chunk may be shorter).
    num_pmap_calculations = int(math.ceil(num / num_devices))
    return [pytree_transf.get_slice(args, num_devices * call_idx, num_devices)
            for call_idx in range(num_pmap_calculations)]
def cpu_segment_dispatch(f: Callable[..., T], num_cores: int, map_argnums: int | tuple = 0) -> Callable[..., list[T]]:
    """
    Embarrassingly-parallel function evaluation over multiple cores. Divides the input arguments into
    segments and dispatches each segment to a different processor core. The idea of such implementation
    is that the compilation of jax functions only happens once on each core.

    Args:
        f: A function to be executed on the different input arguments in parallel.
            Parallelization over keyword arguments is not supported.
        num_cores: The number of processor cores to be used for parallel processing; must be greater than 1.
        map_argnums: Index or a tuple of indices of function `f` arguments to map over. Default is 0.

    Returns:
        A new function that takes the same arguments as `f` and dispatches
        the input arguments across multiple processor cores for parallel processing.
        The returned function will return a list of the results from each
        parallel processing segment.

    Raises:
        ValueError: If `num_cores` is not greater than 1.
    """
    if num_cores <= 1:
        # Message now matches the check: 1 is a positive integer but still rejected,
        # because single-core dispatch would only add multiprocessing overhead.
        raise ValueError("The number of cores must be an integer greater than 1.")

    def sequential_f(args: list, **kwargs):
        # Evaluate f sequentially over one segment of the zipped argument sequences.
        return [f(*a, **kwargs) for a in zip(*args)]

    def parallel_f(*args, **kwargs) -> list:
        canonical_args = canonicalize_args(map_argnums, *args)
        num = full_sequence_length(map_argnums, *args)
        # Context manager terminates the worker pool even if mapping raises
        # (the original leaked the Pool by never closing it).
        with mp.Pool(num_cores) as pool:
            results = pool.map(partial(sequential_f, **kwargs),
                               segment_args_pool(num, num_cores, *canonical_args))
        return list_flatten(results)

    return parallel_f
def pmap_segment_dispatch(f: Callable[..., T],
                          map_argnums: int | tuple[int, ...] = 0,
                          backend: str = 'cpu',
                          pmap_jit: bool = False) -> Callable[..., T]:
    """
    Embarrassingly-parallel function evaluation over multiple jax devices. Divides the input arguments into
    segments and dispatches each segment to a different processor core.

    Args:
        f: A function to be mapped over the leading axis of `map_argnums` arguments in parallel.
            Parallelization over keyword arguments is not supported.
        map_argnums: index or a tuple of indices of function `f` arguments to map over. Default is 0.
        backend: jax backend, 'cpu' or 'gpu'. For parallelization over multiple cpu cores,
            os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=XX' should be set
            at the beginning of the main script, with XX the number of cores.
        pmap_jit: bool, whether to jit the pmap-ed function. This will raise a warning but can speed up
            parallel calculations when num > device_count, at least on cpu.

    Returns:
        A new function that takes the same arguments as `f` and dispatches `map_argnums`
        input arguments over the leading axis across multiple devices for parallel processing.
        All return values of the mapped function will have a leading axis with a length corresponding
        to the length of the `map_argnums` input arguments.

    Raises:
        ValueError: If the cpu backend is selected but only one cpu device is visible to jax.
    """
    device_count = jax.local_device_count(backend=backend)
    if backend == 'cpu' and device_count == 1:
        raise ValueError('Got cpu backend for parallelization but only 1 cpu device is available. '
                         'Try setting os.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=XX\" '
                         'at the beginning of the script.')

    def parallel_f(*args, **kwargs):
        # Broadcast non-mapped arguments so every argument has the same leading dimension.
        canonical_args = canonicalize_args_pmap(map_argnums, *args)
        num = full_sequence_length(map_argnums, *args)

        def pmap_f(*x):
            # kwargs are closed over via partial, so pmap only maps over positional args.
            return jax.pmap(partial(f, **kwargs))(*x)

        if pmap_jit:
            # jit(pmap) raises UserWarning (https://github.com/google/jax/issues/2926) but using jit here prevents
            # pmap seemingly tracing the code in every iteration of the following for loop, which results in
            # faster computation when num > device_count
            pmap_f = jax.jit(pmap_f)
            # when jit-ing pmap, merging of results doesn't work if segments have different lengths, so we
            # expand the arguments to a multiple of device_count
            canonical_args = fill_to_length_num(math.ceil(num / device_count) * device_count, *canonical_args)
        results = []
        # Process device_count-sized chunks sequentially; each chunk runs on all devices in parallel.
        for arguments in segment_args_pmap(num, device_count, *canonical_args):
            r = pmap_f(*arguments)
            results.append(r)
        # Merge per-chunk results along the leading axis and trim any padding added above.
        return pytree_transf.get_slice(pytree_transf.merge(results), 0, num)

    return parallel_f
|