parallelization.py

from __future__ import annotations
from typing import Callable, TypeVar
import copy
import jax
import multiprocess as mp
from functools import partial
import math
from curvature_assembly import pytree_transf
import warnings

T = TypeVar('T')


def all_equal(iterator) -> bool:
    """Check if all elements in an iterator are equal."""
    iterator = iter(iterator)
    try:
        first = next(iterator)
    except StopIteration:
        return True
    return all(first == x for x in iterator)


def get_argument_length(arg) -> int:
    """
    Get the length of `arg` if it is a sequence, or assign a length to it if it is a dataclass of arrays.

    Raises:
        TypeError: If a length cannot be assigned to this argument.
    """
    if pytree_transf.check_if_pytree_of_arrays(arg):
        return pytree_transf.data_length(arg)
    try:
        return len(arg)
    except TypeError:
        raise TypeError(f'Cannot assign length to argument of type {type(arg)}')


def full_sequence_length(map_argnums: int | tuple, *args):
    """
    Return the length of the sequence(s) over which a function will be mapped.

    Args:
        map_argnums: The positional index(es) of the argument(s) over which the function will be mapped.
        args: The arguments passed to the function.

    Returns:
        The length of the sequence(s) over which the function will be mapped.

    Raises:
        ValueError: If any of the arguments over which the function is mapped is not a sequence,
            or if the sequences have different lengths.
    """
    if isinstance(map_argnums, int):
        map_argnums = (map_argnums,)
    lengths = []
    for argnum in map_argnums:
        try:
            lengths.append(get_argument_length(args[argnum]))
        except TypeError:
            raise ValueError(f'Each argument over which a function is mapped should be a sequence '
                             f'or a pytree of arrays, got {type(args[argnum])} for argument {argnum}.')
    if all_equal(lengths):
        return lengths[0]
    else:
        raise ValueError(f'All arguments over which we map should be of the same length, '
                         f'got lengths {lengths} for args {map_argnums}, respectively.')
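

# A minimal usage sketch (hypothetical values), assuming plain Python lists fall through to the
# len() branch of get_argument_length: with map_argnums=(0, 2), only arguments 0 and 2 are
# required to be sequences of equal length, and that shared length is returned.
#
#     >>> full_sequence_length((0, 2), [1, 2, 3], 'unused', ['a', 'b', 'c'])
#     3
#     >>> full_sequence_length(0, [1, 2], [3, 4, 5])  # only argument 0 is checked
#     2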


def canonicalize_args(map_argnums: int | tuple, *args) -> list:
    """
    Create a canonical list of arguments consisting of sequences with equal
    numbers of elements.

    Args:
        map_argnums: Argument indices that are already sequences of length num.
        *args: A variable number of arguments to be canonicalized. Each argument
            should either be a sequence (list, tuple, etc.) with length num,
            or a non-sequence type that can be repeated num times to create
            a sequence.

    Returns:
        A canonical list of arguments. Each element of the list is a sequence
        with `num` elements, either copied from the input argument or created by
        repeating a non-sequence argument num times.
    """
    if isinstance(map_argnums, int):
        map_argnums = (map_argnums,)
    num = full_sequence_length(map_argnums, *args)  # ValueError if mapped arguments have different lengths
    canonical_args = []
    for i, arg in enumerate(args):
        if i not in map_argnums:
            canonical_args.append([copy.deepcopy(arg) for _ in range(num)])
        else:
            canonical_args.append(arg)
    return canonical_args
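

# A minimal sketch of the broadcasting behavior (hypothetical values), again assuming plain lists
# fall through to len(): the mapped argument is kept as-is, while the non-mapped argument is
# deep-copied once per element so that zip(*canonical_args) yields complete argument tuples.
#
#     >>> canonicalize_args(0, [1, 2, 3], {'beta': 0.1})
#     [[1, 2, 3], [{'beta': 0.1}, {'beta': 0.1}, {'beta': 0.1}]]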


def canonicalize_args_pmap(map_argnums: int | tuple, *args) -> list:
    """
    Create a canonical list of arguments consisting of dataclasses with all Array
    fields having the same leading dimension length.

    Args:
        map_argnums: Argument indices that already store arrays of length num to be mapped over.
        *args: A variable number of arguments to be canonicalized.

    Returns:
        A canonical list of arguments. Each element of the list is a pytree whose
        array fields have a leading dimension of length `num`, either taken from the
        input argument or created by repeating a non-mapped argument num times.
    """
    if isinstance(map_argnums, int):
        map_argnums = (map_argnums,)
    num = full_sequence_length(map_argnums, *args)  # ValueError if mapped arguments have different lengths
    canonical_args = []
    for i, arg in enumerate(args):
        if i not in map_argnums:
            canonical_args.append(pytree_transf.repeat_fields(arg, num))
            try:
                if pytree_transf.data_length(arg) == num:
                    warnings.warn(f"Added a new leading dimension to argument {i} with existing leading "
                                  f"dimension length that is the same as the length of the mapped argument(s). "
                                  f"Make sure that this is the desired behavior and this argument should not "
                                  f"also be mapped over.")
            except pytree_transf.NoLengthError:
                pass
        else:
            canonical_args.append(arg)
    return canonical_args


def fill_to_length_num(num, *args):
    """
    Extend each argument in `args` with its last element until its length is a multiple of `num`.

    Args:
        num: The multiple to which the length of each argument should be extended.
        args: A variable number of arguments to be extended.

    Returns:
        A list of the extended arguments.
    """
    filled_args = []
    for arg in args:
        filled_args.append(pytree_transf.extend_with_last_element(arg, num))
    return filled_args


def get_slice(start_idx: int, slice_length: int, *args) -> list:
    """
    Return a slice of a specified length from each argument in a variable-length
    list of sequences.

    Args:
        start_idx: The starting index of the slice to be extracted from each sequence.
        slice_length: The length of the slice to be extracted from each sequence.
        *args: A variable-length list of sequences.

    Returns:
        A list of slices where each slice is extracted from the corresponding
        sequence in `args`, starting at index `start_idx` and extending for `slice_length`
        elements.
    """
    if start_idx < 0 or slice_length < 0:
        raise ValueError("Start index and slice length must be non-negative.")
    return [arg[start_idx:start_idx + slice_length] for arg in args]
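

# Example (hypothetical sequences): slices that run past the end are silently truncated, which the
# segmentation helpers below rely on for the last, possibly shorter, segment.
#
#     >>> get_slice(2, 2, [0, 1, 2, 3, 4], ['a', 'b', 'c', 'd', 'e'])
#     [[2, 3], ['c', 'd']]
#     >>> get_slice(4, 2, [0, 1, 2, 3, 4])
#     [[4]]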


def list_flatten(lst: list) -> list:
    """Flatten a list of nested lists."""
    flat_list = []
    for sublist in lst:
        for item in sublist:
            flat_list.append(item)
    return flat_list


def segment_args_pool(num: int, num_cores: int, *args) -> list:
    """
    Segment the input arguments into a list of segments, with each segment containing
    a fraction of the arguments. This function can be used to split up a large computation
    across multiple processor cores using multiprocess.Pool to speed up processing.

    Args:
        num: The total number of items to be segmented across cores.
        num_cores: The number of cores to be used for processing.
        *args: A variable-length list of sequences. Each sequence should be indexable
            and have a length equal to `num`.

    Returns:
        A list of segments, where each segment is a list of argument values
        extracted from the corresponding index range in the input sequences. The output
        list will have length `num_cores`, and each segment will have the
        same number of items, except for the last one, which gets the remaining items.
    """
    segment_length = int(math.ceil(num / num_cores))
    args_list = []
    for i in range(num_cores):
        args_list.append(get_slice(segment_length * i, segment_length, *args))
    return args_list
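

# Sketch of the segmentation (hypothetical values): 5 items on 2 cores gives segments of
# ceil(5 / 2) = 3 items, with the last segment holding the remaining 2.
#
#     >>> segment_args_pool(5, 2, [0, 1, 2, 3, 4])
#     [[[0, 1, 2]], [[3, 4]]]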


def segment_args_pmap(num: int, num_devices: int, *args) -> list:
    """
    Segment the input arguments into a list of segments, with each segment containing
    a fraction of the arguments. This function can be used to split up a large computation
    across multiple computational units using jax.pmap to speed up processing.

    Args:
        num: The total number of items to be segmented across devices.
        num_devices: The number of devices to be used for processing.
        *args: A variable-length list of sequences. Each sequence should be indexable
            and have a length equal to `num`.

    Returns:
        A list of segments, where each segment is a list of argument values
        extracted from the corresponding index range in the input sequences. The output
        list will have length `num_pmap_calculations`, and each segment will have the
        same number of items, except for the last one, which gets the remaining items.
    """
    num_pmap_calculations = int(math.ceil(num / num_devices))
    args_list = []
    for i in range(num_pmap_calculations):
        args_list.append(pytree_transf.get_slice(args, num_devices * i, num_devices))
    return args_list
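

# Sketch (hypothetical values): unlike segment_args_pool, each segment here holds `num_devices`
# items (one per device) and the number of segments grows with `num`. Slicing is delegated to
# pytree_transf.get_slice, which is assumed to behave like get_slice above but on pytree leaves;
# e.g. for num=5 items on 2 devices, segment_args_pmap(5, 2, xs) yields 3 segments holding
# items 0-1, 2-3 and 4.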


def cpu_segment_dispatch(f: Callable[..., T], num_cores: int, map_argnums: int | tuple = 0) -> Callable[..., list[T]]:
    """
    Embarrassingly parallel function evaluation over multiple cores. Divides the input arguments
    into segments and dispatches each segment to a different processor core. The idea of this
    implementation is that compilation of jax functions only happens once on each core.

    Args:
        f: A function to be executed on the different input arguments in parallel.
            Parallelization over keyword arguments is not supported.
        num_cores: The number of processor cores to be used for parallel processing.
        map_argnums: Index or a tuple of indices of function `f` arguments to map over. Default is 0.

    Returns:
        A new function that takes the same arguments as `f` and dispatches
        the input arguments across multiple processor cores for parallel processing.
        The returned function will return a list of the results from each
        parallel processing segment.
    """
    if num_cores <= 1:
        raise ValueError("The number of cores must be an integer greater than 1.")

    def sequential_f(args: list, **kwargs):
        seq_results = []
        for a in zip(*args):
            seq_results.append(f(*a, **kwargs))
        return seq_results

    def parallel_f(*args, **kwargs) -> list:
        canonical_args = canonicalize_args(map_argnums, *args)
        num = full_sequence_length(map_argnums, *args)
        with mp.Pool(num_cores) as pool:
            results = pool.map(partial(sequential_f, **kwargs),
                               segment_args_pool(num, num_cores, *canonical_args))
        return list_flatten(results)

    return parallel_f
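

# A minimal usage sketch (hypothetical function and values): `simulate` is mapped over its first
# argument, while `steps` is broadcast to every call.
#
#     def simulate(param, steps):
#         return param * steps
#
#     parallel_simulate = cpu_segment_dispatch(simulate, num_cores=4)
#     results = parallel_simulate([0.1, 0.2, 0.3, 0.4], 100)  # list of 4 results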


def pmap_segment_dispatch(f: Callable[..., T],
                          map_argnums: int | tuple[int, ...] = 0,
                          backend: str = 'cpu',
                          pmap_jit: bool = False) -> Callable[..., T]:
    """
    Embarrassingly parallel function evaluation over multiple jax devices. Divides the input
    arguments into segments and dispatches each segment to a different device.

    Args:
        f: A function to be mapped over the leading axis of `map_argnums` arguments in parallel.
            Parallelization over keyword arguments is not supported.
        map_argnums: Index or a tuple of indices of function `f` arguments to map over. Default is 0.
        backend: Jax backend, 'cpu' or 'gpu'. For parallelization over multiple cpu cores,
            os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=XX' should be set
            at the beginning of the main script, with XX the number of cores.
        pmap_jit: Whether to jit the pmap-ed function. This will raise a warning but can speed up
            parallel calculations when num > device_count, at least on cpu.

    Returns:
        A new function that takes the same arguments as `f` and dispatches `map_argnums`
        input arguments over the leading axis across multiple devices for parallel processing.
        All return values of the mapped function will have a leading axis with a length
        corresponding to the length of the `map_argnums` input arguments.
    """
    device_count = jax.local_device_count(backend=backend)
    if backend == 'cpu' and device_count == 1:
        raise ValueError('Got cpu backend for parallelization but only 1 cpu device is available. '
                         'Try setting os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=XX" '
                         'at the beginning of the script.')

    def parallel_f(*args, **kwargs):
        canonical_args = canonicalize_args_pmap(map_argnums, *args)
        num = full_sequence_length(map_argnums, *args)

        def pmap_f(*x):
            return jax.pmap(partial(f, **kwargs))(*x)

        if pmap_jit:
            # jit(pmap) raises UserWarning (https://github.com/google/jax/issues/2926), but using jit
            # here prevents pmap from seemingly re-tracing the code in every iteration of the following
            # for loop, which results in faster computation when num > device_count.
            pmap_f = jax.jit(pmap_f)
            # When jit-ing pmap, merging of results doesn't work if segments have different lengths,
            # so we extend the arguments to a multiple of device_count.
            canonical_args = fill_to_length_num(math.ceil(num / device_count) * device_count, *canonical_args)

        results = []
        for arguments in segment_args_pmap(num, device_count, *canonical_args):
            results.append(pmap_f(*arguments))
        return pytree_transf.get_slice(pytree_transf.merge(results), 0, num)

    return parallel_f
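

# A minimal usage sketch (hypothetical function and values). For a cpu run, the device count must
# be set before jax is imported:
#
#     import os
#     os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
#
#     import jax.numpy as jnp
#
#     def energy(x):
#         return jnp.sum(x ** 2)
#
#     parallel_energy = pmap_segment_dispatch(energy, backend='cpu')
#     energies = parallel_energy(jnp.arange(12.).reshape(6, 2))  # leading axis of length 6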