from __future__ import annotations from typing import Callable, TypeVar import copy import jax import multiprocess as mp from functools import partial import math from curvature_assembly import pytree_transf import warnings T = TypeVar('T') def all_equal(iterator) -> bool: """Check if all elements in an iterator are equal.""" iterator = iter(iterator) try: first = next(iterator) except StopIteration: return True return all(first == x for x in iterator) def get_argument_length(arg) -> int: """ Get length of arg if arg is a sequence or assign a length to arg if arg is a dataclass of arrays. Raises TypeError if length cannot be assigned to this argument. """ if pytree_transf.check_if_pytree_of_arrays(arg): return pytree_transf.data_length(arg) try: return len(arg) except TypeError: raise TypeError(f'Cannot assign lenght to argument of type {type(arg)}') def full_sequence_length(map_argnums: int | tuple, *args): """ Return the length of the sequence(s) over which a function will be mapped. Args: map_argnums: The positional index(es) of the argument(s) over which the function will be mapped. args: The arguments passed to the function. Returns: The length of the sequence(s) over which the function will be mapped. Raises: ValueError: If any of the arguments over which the function is mapped is not a sequence, or if the sequences have different lengths. """ if isinstance(map_argnums, int): map_argnums = (map_argnums,) lengths = [] for argnum in map_argnums: try: lengths.append(get_argument_length(args[argnum])) except TypeError: raise ValueError(f'Each argument over which a function is mapped should be a sequence ' f'or a pytree of arrays, got {type(args[argnum])} for argument {argnum}.') if all_equal(lengths): return lengths[0] else: raise ValueError(f'All arguments over which we map should be of the same length,' f'got lengths {lengths} for args {map_argnums}, respectively.') def canonicalize_args(map_argnums: int | tuple, *args) -> list: """ Create a canonical list of arguments consisting of sequences with equal numbers of elements. Args: map_argnums: Argument indices that are already sequences of length num. *args: A variable number of arguments to be canonicalized. Each argument should either be a sequence (list, tuple, etc.) with length num, or a non-sequence type that can be repeated num times to create a sequence. Returns: A canonical list of arguments. Each element of the list is a sequence with `num` elements, either copied from the input argument or created by repeating a non-sequence argument num times. """ if isinstance(map_argnums, int): map_argnums = (map_argnums,) num = full_sequence_length(map_argnums, *args) # ValueError if mapped arguments have different lengths canonical_args = [] for i, arg in enumerate(args): if not i in map_argnums: canonical_args.append([copy.deepcopy(arg) for _ in range(num)]) else: canonical_args.append(arg) return canonical_args def canonicalize_args_pmap(map_argnums: int | tuple, *args) -> list: """ Create a canonical list of arguments consisting of dataclasses with all Array fields having the same leading dimension length. Args: map_argnums: Argument indices that already store arrays of length num to be mapped over. *args: A variable number of arguments to be canonicalized. Returns: A canonical list of arguments. Each element of the list is a sequence with `num` elements, either copied from the input argument or created by repeating a non-sequence argument num times. """ if isinstance(map_argnums, int): map_argnums = (map_argnums,) num = full_sequence_length(map_argnums, *args) # ValueError if mapped arguments have different lengths canonical_args = [] for i, arg in enumerate(args): if not i in map_argnums: canonical_args.append(pytree_transf.repeat_fields(arg, num)) try: if pytree_transf.data_length(arg) == num: warnings.warn(f"Added a new leading dimension to argument {i} with existing leading dimension " f"length that is the same as the length of the mapped argument(s). Make sure that " f"this is the desired behavior and this argument should not also be mapped over.") except pytree_transf.NoLengthError: pass else: canonical_args.append(arg) return canonical_args def fill_to_length_num(num, *args): """ Extends each argument in `args` with its last element until its length is a multiple of `num`. Args: num: The multiple to which the length of each argument should be extended. args: A variable number of arguments to be extended. Returns: A list of the extended arguments. """ filled_args = [] for arg in args: filled_args.append(pytree_transf.extend_with_last_element(arg, num)) return filled_args def get_slice(start_idx: int, slice_length: int, *args) -> list: """ Return a slice of a specified length from each argument in a variable-length list of sequences. Args: start_idx: The starting index of the slice to be extracted from each sequence. slice_length: The length of the slice to be extracted from each sequence. *args: A variable-length list of sequences. Returns: A list of slices where each slice is extracted from the corresponding sequence in `args` starting at index `start_idx` and extending for `slice_length` elements. """ if start_idx < 0 or slice_length < 0: raise ValueError("Start index and slice length must be non-negative.") return [arg[start_idx:start_idx+slice_length] for arg in args] def list_flatten(lst: list) -> list: """ Flatten a list of nested lists. """ flat_list = [] for sublist in lst: for item in sublist: flat_list.append(item) return flat_list def segment_args_pool(num: int, num_cores: int, *args) -> list: """ Segment the input arguments into a list of segments, with each segment containing a fraction of the arguments. This function can be used to split up a large computation across multiple processor cores using the multiprocess.Pool to speed up processing. Args: num: The total number of items to be segmented across cores. num_cores: The number of cores to be used for processing. *args: A variable-length list of sequences. Each sequence should be indexable and have a length equal to `num`. Returns: A list of segments, where each segment is a list of argument values extracted from the corresponding index range in the input sequences. The output list will have length `num_cores`, and each segment will have the same number of items, except for the last one that gets the remaining number of items. """ segment_length = int(math.ceil(num / num_cores)) args_list = [] for i in range(num_cores): args_list.append(get_slice(segment_length*i, segment_length, *args)) return args_list def segment_args_pmap(num: int, num_devices: int, *args) -> list: """ Segment the input arguments into a list of segments, with each segment containing a fraction of the arguments. This function can be used to split up a large computation across multiple computational units using jax.pmap to speed up processing. Args: num: The total number of items to be segmented across cores. num_devices: The number of devices to be used for processing. *args: A variable-length list of sequences. Each sequence should be indexable and have a length equal to `num`. Returns: A list of segments, where each segment is a list of argument values extracted from the corresponding index range in the input sequences. The output list will have length num_pmap_calculations, and each segment will have the same number of items, except for the last one that gets the remaining number of items. """ num_pmap_calculations = int(math.ceil(num / num_devices)) args_list = [] for i in range(num_pmap_calculations): args_list.append(pytree_transf.get_slice(args, num_devices * i, num_devices)) return args_list def cpu_segment_dispatch(f: Callable[..., T], num_cores: int, map_argnums: int | tuple = 0) -> Callable[..., list[T]]: """ Embarrassingly-parallel function evaluation over multiple cores. Divides the input arguments into segments and dispatches each segment to a different processor core. The idea of such implementation is that jax functions that the compilation of jax functions only happens once at each core. Args: f: A function to be executed on the different input arguments in parallel. Parallelization over keyword arguments is not supported. num_cores: The number of processor cores to be used for parallel processing. map_argnums: index or a tuple of indices of function `f` arguments to map over. Default is 0. Returns: A new function that takes the same arguments as `f` and dispatches the input arguments across multiple processor cores for parallel processing. The returned function will return a list of the results from each parallel processing segment. """ if num_cores <= 1: raise ValueError("The number of cores must be a positive integer.") def sequential_f(args: list, **kwargs): seq_results = [] for i, a in enumerate(zip(*args)): seq_results.append(f(*a, **kwargs)) return seq_results def parallel_f(*args, **kwargs) -> list: canonical_args = canonicalize_args(map_argnums, *args) num = full_sequence_length(map_argnums, *args) threads = mp.Pool(num_cores) results = threads.map(partial(sequential_f, **kwargs), segment_args_pool(num, num_cores, *canonical_args)) return list_flatten(results) return parallel_f def pmap_segment_dispatch(f: Callable[..., T], map_argnums: int | tuple[int, ...] = 0, backend: str = 'cpu', pmap_jit: bool = False) -> Callable[..., T]: """ Embarrassingly-parallel function evaluation over multiple jax devices. Divides the input arguments into segments and dispatches each segment to a different processor core. Args: f: A function to be mapped over the leading axis of `map_argnums` arguments in parallel. Parallelization over keyword arguments is not supported. map_argnums: index or a tuple of indices of function `f` arguments to map over. Default is 0. backend: jax backend, 'cpu' or 'gpu'. For parallelization over multiple cpu cores, os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=XX' should be set at the beginning of the main script, with XX the number of cores. pmap_jit: bool, whether to jit the pmap-ed function. This will raise a warning but can speed up parallel calculations when num > device_count, at least on cpu. Returns: A new function that takes the same arguments as `f` and dispatches `map_argnums` input arguments over the leading axis across multiple devices for parallel processing. All return values of the mapped function will have a leading axis with a length corresponding to the length of the `map_argnums` input arguments. """ device_count = jax.local_device_count(backend=backend) if backend == 'cpu' and device_count == 1: raise ValueError('Got cpu backend for parallelization but only 1 cpu device is available. ' 'Try setting os.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=XX\" ' 'at the beginning of the script.') def parallel_f(*args, **kwargs): canonical_args = canonicalize_args_pmap(map_argnums, *args) num = full_sequence_length(map_argnums, *args) def pmap_f(*x): return jax.pmap(partial(f, **kwargs))(*x) if pmap_jit: # jit(pmap) raises UserWarning (https://github.com/google/jax/issues/2926) but using jit here prevents # pmap seemingly tracing the code in every iteration of the following for loop, which results in # faster computation when num > device_count pmap_f = jax.jit(pmap_f) # when jit-ing pmap, merging of results doesn't work if segments have different lengths, so we # expand the arguments to a multiple of device_count canonical_args = fill_to_length_num(math.ceil(num / device_count) * device_count, *canonical_args) results = [] for arguments in segment_args_pmap(num, device_count, *canonical_args): r = pmap_f(*arguments) results.append(r) return pytree_transf.get_slice(pytree_transf.merge(results), 0, num) return parallel_f