fit.py

from __future__ import annotations

from typing import Any, Callable

import jax
import jax.numpy as jnp
import optax

from curvature_assembly import data_protocols, oriented_particle, pytree_transf

InteractionParams = data_protocols.InteractionParams
SimulationAux = data_protocols.SimulationAux
BpttResults = data_protocols.BpttResults
BpttSimulation = data_protocols.BpttSimulation

Array = jnp.ndarray


def unitwise_clip(g_norm: Array,
                  max_norm: Array,
                  grad: Array,
                  div_eps: float = 1e-6) -> Array:
    """Applies gradient clipping unit-wise."""
    # This little max(., div_eps) is distinct from the normal eps and just
    # prevents division by zero. It technically should be impossible to engage.
    clipped_grad = grad * (max_norm / jnp.maximum(g_norm, div_eps))
    return jnp.where(g_norm < max_norm, grad, clipped_grad)
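

# Illustrative sketch (not part of the original module): unitwise_clip rescales
# a gradient whose norm exceeds max_norm down to exactly max_norm and passes
# smaller gradients through unchanged. The values below are arbitrary.
def _example_unitwise_clip():
    g = jnp.array([3.0, 4.0])  # norm 5
    clipped = unitwise_clip(jnp.linalg.norm(g), jnp.asarray(2.0), g)
    # clipped == [1.2, 1.6]: g rescaled to norm 2
    passed = unitwise_clip(jnp.linalg.norm(g), jnp.asarray(10.0), g)
    # passed == [3., 4.]: unchanged, since 5 < 10
    return clipped, passed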


def adaptive_grad_clip(grad, params, clipping: float, eps: float = 1e-3):
    """Clip each gradient leaf to at most `clipping` times the norm of the
    corresponding parameter leaf (adaptive gradient clipping)."""
    num_ed = pytree_transf.num_extra_dimensions(grad, params)
    g_norm = pytree_transf.broadcast_to(pytree_transf.leaf_norm(grad, keepdims=True, num_ld=num_ed), grad)
    p_norm = pytree_transf.broadcast_to(pytree_transf.leaf_norm(params, keepdims=True), grad)
    # Maximum allowable leaf_norm
    max_norm = jax.tree_util.tree_map(
        lambda x: clipping * jnp.maximum(x, eps), p_norm)
    # If grad leaf_norm > clipping * param_norm, rescale
    return jax.tree_util.tree_map(unitwise_clip, g_norm, max_norm, grad)
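

# Illustrative sketch: adaptive_grad_clip applied to simple dict pytrees. This
# assumes the pytree_transf helpers report zero extra dimensions when grad and
# params share leaf shapes; the numbers are arbitrary.
def _example_adaptive_grad_clip():
    params = {'a': jnp.array([1.0, 0.0]), 'b': jnp.array([0.5])}
    grad = {'a': jnp.array([10.0, 0.0]), 'b': jnp.array([0.01])}
    # Leaf 'a' exceeds 0.1 * ||params['a']|| and is rescaled to that norm;
    # leaf 'b' is already small enough and passes through unchanged.
    return adaptive_grad_clip(grad, params, clipping=0.1)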


def get_grad_time_weights(grad: InteractionParams,
                          time_weight_fn: Callable[[Array], Array],
                          time_axis: int = 1):
    """
    Apply time-based weights to the gradients of interaction parameters.

    Args:
        grad: Gradients of interaction parameters, represented as a JAX PyTree.
        time_weight_fn: A function that computes the time-based weights on a rescaled time interval [0, 1].
        time_axis: The axis along which the time steps are represented in the `grad` PyTree. Default is 1.

    Returns:
        Gradients of interaction parameters with time-based weights applied.
    """
    num_timesteps = pytree_transf.data_length(grad, axis=time_axis)
    weights = time_weight_fn(jnp.linspace(0, 1, num_timesteps, endpoint=True))
    # Guard against an all-zero weight vector: fall back to mean 1 so the
    # normalization below never divides by zero.
    mean = jax.lax.cond(jnp.mean(weights) == 0, lambda x: 1., lambda x: jnp.mean(x), weights)
    normalized_weights = weights / mean

    def apply_weights(x):
        expand_dims = tuple(i for i in range(len(x.shape)) if i != time_axis)
        expanded_weights = jnp.expand_dims(normalized_weights, axis=expand_dims)
        return x * expanded_weights

    return jax.tree_util.tree_map(apply_weights, grad)
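

# Illustrative sketch: with a linear weighting, gradient contributions from
# early time steps are suppressed and late ones emphasized; the weights are
# normalized to mean 1 so the overall gradient scale is preserved. Assumes
# pytree_transf.data_length returns the size along `time_axis` (here 5).
def _example_grad_time_weights():
    grad = {'w': jnp.ones((2, 5, 3))}  # (batch, time, features); shapes arbitrary
    # Along axis 1 the weights are linspace(0, 1, 5) / mean = [0., .5, 1., 1.5, 2.]
    return get_grad_time_weights(grad, lambda t: t, time_axis=1)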


def canonicalize_grad_results(grad: InteractionParams, params: InteractionParams) -> InteractionParams:
    """
    Make gradient leaf shapes compatible with interaction params by averaging
    over all extra leading axes relative to the original parameter shapes.
    """
    results_num_ed = pytree_transf.num_extra_dimensions(grad, params)
    return jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=tuple(range(results_num_ed))), grad)
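

# Illustrative sketch: a BPTT run typically leaves extra leading axes (e.g.
# time and trajectory/batch) on each gradient leaf; canonicalize_grad_results
# averages them away so the result matches the parameter shapes again. The
# shapes below are arbitrary.
def _example_canonicalize_grad():
    params = {'w': jnp.zeros(3)}
    grad = {'w': jnp.ones((4, 7, 3))}  # two extra leading axes
    return canonicalize_grad_results(grad, params)  # {'w': array of shape (3,)}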


def bounds_params(params: InteractionParams,
                  **opt_param_dict: tuple | None) -> tuple[InteractionParams, InteractionParams]:
    """
    Construct lower- and upper-bound pytrees for the given interaction params.

    Parameters listed in `opt_param_dict` with a (lower, upper) tuple get those
    bounds; parameters listed with a non-tuple value (e.g. None) are left
    unbounded. Parameters not listed at all are pinned to their current values
    by setting both bounds equal to them.
    """
    params_dict = vars(params)
    lower_bounds = {}
    upper_bounds = {}
    for key in params_dict.keys():
        if key in opt_param_dict:
            try:
                lb, ub = opt_param_dict[key]
                if not jnp.all(lb < ub):
                    raise ValueError(f"Lower bounds should all be smaller than upper bounds, "
                                     f"problem with parameter {key}")
            except TypeError:
                lb = None
                ub = None
            lower_bounds[key] = lb
            upper_bounds[key] = ub
        else:
            lower_bounds[key] = params_dict[key]
            upper_bounds[key] = params_dict[key]
    return type(params)(**lower_bounds), type(params)(**upper_bounds)
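

# Illustrative sketch: MyParams is a hypothetical stand-in for a concrete
# InteractionParams dataclass. Parameters given a (lower, upper) tuple get real
# bounds, parameters given None stay unbounded, and unlisted parameters are
# pinned to their current values.
def _example_bounds_params():
    import dataclasses

    @dataclasses.dataclass
    class MyParams:
        alpha: Array
        sigma: Array
        kappa: Array

    params = MyParams(alpha=jnp.array(0.5), sigma=jnp.array(2.0), kappa=jnp.array(1.0))
    lb, ub = bounds_params(params, alpha=(0.0, 1.0), sigma=None)
    # lb.alpha == 0.0, ub.alpha == 1.0   -> explicit bounds
    # lb.sigma is None, ub.sigma is None -> unbounded
    # lb.kappa == ub.kappa == 1.0        -> pinned (not optimized)
    return lb, ub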


def map_into_bounds(params: InteractionParams,
                    lower_bounds: InteractionParams | None,
                    upper_bounds: InteractionParams | None) -> InteractionParams:
    """Map interaction parameters back into the interval between lower and upper bounds."""
    leaves, treedef = jax.tree_util.tree_flatten(params)
    # If lower and/or upper bounds are not provided, we must construct pytrees with
    # the same structure as the parameters, filled with None. Note that this will
    # fail (raise a ValueError) if the parameters pytree itself contains any None values.
    if lower_bounds is None:
        lower_bounds = jax.tree_util.tree_unflatten(treedef, [None] * len(leaves))
    if upper_bounds is None:
        upper_bounds = jax.tree_util.tree_unflatten(treedef, [None] * len(leaves))

    def map_to_bounds(x, xmin, xmax):
        try:
            if jnp.any(xmin > xmax):
                raise ValueError(f'Min bound cannot be larger than max bound, got {xmin} and {xmax}, respectively.')
        except TypeError:
            pass
        if xmin is not None:
            x = jnp.maximum(x, xmin)
        if xmax is not None:
            x = jnp.minimum(x, xmax)
        return x

    return jax.tree_util.tree_map(map_to_bounds, params, lower_bounds, upper_bounds)
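

# Illustrative sketch: parameters are clamped element-wise into
# [lower_bounds, upper_bounds]; a None on either side leaves it open.
def _example_map_into_bounds():
    params = {'a': jnp.array([-1.0, 0.5, 3.0])}
    lower = {'a': jnp.array(0.0)}
    upper = {'a': jnp.array(1.0)}
    return map_into_bounds(params, lower, upper)  # {'a': [0., 0.5, 1.]}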


def normalize_param(params: InteractionParams, param_name: str, ord=None) -> InteractionParams:
    """Return params with the leaf `param_name` divided by its own norm (of order `ord`)."""
    params_dict = vars(params)
    new_dict = params_dict.copy()  # shallow copy is enough as values (interaction_params elements) are jax arrays
    new_dict[param_name] = params_dict[param_name] / jnp.linalg.norm(params_dict[param_name], keepdims=True, ord=ord)
    return type(params)(**new_dict)
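

# Illustrative sketch: normalize_param divides the named leaf by its own norm,
# e.g. to keep a direction-like parameter on the unit sphere between optimizer
# steps. MyParams is again a hypothetical stand-in.
def _example_normalize_param():
    import dataclasses

    @dataclasses.dataclass
    class MyParams:
        axis: Array

    params = MyParams(axis=jnp.array([3.0, 4.0]))
    return normalize_param(params, 'axis')  # axis -> [0.6, 0.8]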


TIME_WEIGHT_FN = {'constant': lambda x: jnp.ones_like(x),
                  'linear': lambda x: x,
                  'quadratic': lambda x: x ** 2,
                  'exponential': lambda x: jnp.exp(x),
                  'step_25': lambda x: jnp.heaviside(x - 0.249, 1),
                  'step_50': lambda x: jnp.heaviside(x - 0.50, 1),
                  'step_75': lambda x: jnp.heaviside(x - 0.749, 1),
                  'step_100': lambda x: jnp.heaviside(x - 1., 1),
                  'neg_linear': lambda x: 1 - x}


def fit_bptt(simulation_fn: BpttSimulation,
             optimizer_update: optax.TransformUpdateFn,
             clipping: float,
             grad_time_weights: str | None = None,
             param_rescalings: list[Callable[[InteractionParams], InteractionParams]] | None = None,
             lower_bounds: InteractionParams | None = None,
             upper_bounds: InteractionParams | None = None,
             time_axis: int = 1) -> Callable:
    """
    Construct the step function for meta-optimization of parameters in a BPTT simulation.

    Args:
        simulation_fn: A function that performs the simulation and computes the gradients of interaction parameters.
        optimizer_update: A function that updates the parameters using the computed gradients.
        clipping: Adaptive gradient clipping factor; each gradient leaf is clipped to at most
            `clipping` times the norm of the corresponding parameter leaf.
        grad_time_weights: Name of a weighting scheme in TIME_WEIGHT_FN used to compute time-based
            weights for the gradients. Default is 'constant', which assigns equal weights (ones)
            to all time steps.
        param_rescalings: A list of functions that apply rescalings or transformations to the
            interaction parameters after each update; `oriented_particle.canonicalize_eigvals`
            is always applied first. Default is an empty list.
        lower_bounds: The lower bounds for the interaction parameters. Default is None.
        upper_bounds: The upper bounds for the interaction parameters. Default is None.
        time_axis: The axis along which the time steps are represented in the gradient PyTree. Default is 1.

    Returns:
        Callable: A step function that performs one training step.
    """
    if grad_time_weights is None:
        grad_time_weights = 'constant'
    try:
        grad_time_weight_fn = TIME_WEIGHT_FN[grad_time_weights]
    except KeyError:
        raise ValueError(f'Invalid time weight parameter, {grad_time_weights} is not among the implemented weights.')

    if param_rescalings is None:
        param_rescalings = []
    # Copy before prepending so a caller-provided list is not mutated across calls.
    param_rescalings = [oriented_particle.canonicalize_eigvals] + list(param_rescalings)

    def step(params: InteractionParams,
             opt_state: optax.OptState,
             md_state: Any,
             aux: SimulationAux) -> tuple[InteractionParams, optax.OptState,
                                          BpttResults, SimulationAux, InteractionParams]:
        aux = aux.reset_empty()
        bptt_results, aux = simulation_fn(params, md_state, aux)
        grad_clipped = adaptive_grad_clip(bptt_results.grad, params, clipping)
        grad_weighted = get_grad_time_weights(grad_clipped, grad_time_weight_fn, time_axis=time_axis)
        grad_mean = canonicalize_grad_results(grad_weighted, params)
        updates, opt_state = optimizer_update(grad_mean, opt_state)
        params = optax.apply_updates(params, updates)
        for fn in param_rescalings:
            params = fn(params)
        params = map_into_bounds(params, lower_bounds, upper_bounds)
        return params, opt_state, bptt_results, aux, grad_clipped

    return step
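

# Usage sketch (illustrative): `simulation_fn`, `init_params`, `md_state` and
# `aux` are assumed to come from the surrounding curvature_assembly setup; the
# optimizer choice and hyperparameters are arbitrary.
def _example_training_loop(simulation_fn, init_params, md_state, aux, num_steps=10):
    optimizer = optax.adam(learning_rate=1e-2)
    opt_state = optimizer.init(init_params)
    step = fit_bptt(simulation_fn, optimizer.update,
                    clipping=0.1, grad_time_weights='linear')
    params = init_params
    for _ in range(num_steps):
        params, opt_state, bptt_results, aux, grad_clipped = step(
            params, opt_state, md_state, aux)
    return params, opt_state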