from __future__ import annotations

from typing import Any, Callable

import jax
import jax.numpy as jnp
import optax

from curvature_assembly import data_protocols, oriented_particle, pytree_transf

InteractionParams = data_protocols.InteractionParams
SimulationAux = data_protocols.SimulationAux
BpttResults = data_protocols.BpttResults
BpttSimulation = data_protocols.BpttSimulation

Array = jnp.ndarray
def unitwise_clip(g_norm: Array,
                  max_norm: Array,
                  grad: Array,
                  div_eps: float = 1e-6) -> Array:
    """Applies gradient clipping unit-wise."""
    # This little max(., div_eps) is distinct from the normal eps and just
    # prevents division by zero. It technically should be impossible to engage.
    clipped_grad = grad * (max_norm / jnp.maximum(g_norm, div_eps))
    return jnp.where(g_norm < max_norm, grad, clipped_grad)
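
# Illustrative sketch (values made up): a gradient whose norm exceeds max_norm
# is rescaled onto the max_norm sphere with its direction preserved, while a
# smaller gradient passes through unchanged:
#
#   g = jnp.array([3., 4.])                                 # norm 5
#   unitwise_clip(jnp.linalg.norm(g), jnp.asarray(1.), g)   # -> [0.6, 0.8], norm 1
#   unitwise_clip(jnp.linalg.norm(g), jnp.asarray(10.), g)  # -> [3., 4.], unchanged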
def adaptive_grad_clip(grad: InteractionParams,
                       params: InteractionParams,
                       clipping: float,
                       eps: float = 1e-3) -> InteractionParams:
    """Clip each gradient leaf so its norm does not exceed `clipping` times the corresponding parameter norm."""
    num_ed = pytree_transf.num_extra_dimensions(grad, params)
    g_norm = pytree_transf.broadcast_to(pytree_transf.leaf_norm(grad, keepdims=True, num_ld=num_ed), grad)
    p_norm = pytree_transf.broadcast_to(pytree_transf.leaf_norm(params, keepdims=True), grad)
    # Maximum allowable leaf norm.
    max_norm = jax.tree_util.tree_map(
        lambda x: clipping * jnp.maximum(x, eps), p_norm)
    # If a gradient leaf norm exceeds clipping * param_norm, rescale it.
    return jax.tree_util.tree_map(unitwise_clip, g_norm, max_norm, grad)
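
# Usage sketch, assuming `pytree_transf.leaf_norm` returns per-leaf norms with
# shapes broadcastable against the gradient leaves (this module's convention):
#
#   grad_clipped = adaptive_grad_clip(grad, params, clipping=0.1)
#
# keeps every gradient leaf norm at or below 10 % of the corresponding
# parameter leaf norm, in the spirit of NFNets' adaptive gradient clipping.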
def get_grad_time_weights(grad: InteractionParams,
                          time_weight_fn: Callable[[Array], Array],
                          time_axis: int = 1) -> InteractionParams:
    """
    Apply time-based weights to the gradients of interaction parameters.

    Args:
        grad: Gradients of interaction parameters, represented as a JAX pytree.
        time_weight_fn: A function that computes the time-based weights on the rescaled time interval [0, 1].
        time_axis: The axis along which time steps are stored in the `grad` pytree. Default is 1.

    Returns:
        Gradients of interaction parameters with time-based weights applied.
    """
    num_timesteps = pytree_transf.data_length(grad, axis=time_axis)
    weights = time_weight_fn(jnp.linspace(0, 1, num_timesteps, endpoint=True))
    # Guard against division by zero when all weights vanish (e.g. a step
    # function whose threshold lies beyond the last time step).
    mean = jax.lax.cond(jnp.mean(weights) == 0, lambda w: 1., lambda w: jnp.mean(w), weights)
    normalized_weights = weights / mean

    def apply_weights(x):
        # Expand the 1d weight vector over all axes except the time axis.
        expand_dims = tuple(i for i in range(x.ndim) if i != time_axis)
        expanded_weights = jnp.expand_dims(normalized_weights, axis=expand_dims)
        return x * expanded_weights

    return jax.tree_util.tree_map(apply_weights, grad)
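
# Worked example: with 'linear' weights and num_timesteps = 5,
# jnp.linspace(0, 1, 5) gives [0., 0.25, 0.5, 0.75, 1.] with mean 0.5, so the
# normalized weights are [0., 0.5, 1., 1.5, 2.]: late-time gradient
# contributions are up-weighted while the overall gradient scale is preserved.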
def canonicalize_grad_results(grad: InteractionParams, params: InteractionParams) -> InteractionParams:
    """
    Make gradient leaf shapes compatible with interaction params by averaging
    over all extra leading axes relative to the original parameter shapes.
    """
    results_num_ed = pytree_transf.num_extra_dimensions(grad, params)
    return jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=tuple(range(results_num_ed))), grad)
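
# Example: if a parameter leaf has shape (3,) while the corresponding gradient
# leaf carries two extra leading axes, e.g. shape (ensemble, time, 3) after
# time weighting (axis names are illustrative), those leading axes are
# averaged away, leaving an update of shape (3,).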
def bounds_params(params: InteractionParams,
                  **opt_param_dict: tuple | None) -> tuple[InteractionParams, InteractionParams]:
    """
    Construct lower- and upper-bound pytrees for the parameters to optimize.

    Parameters listed in `opt_param_dict` get the provided (lower, upper)
    bounds, or no bounds at all if the value cannot be unpacked into a pair
    (e.g. None). All other parameters are fixed at their current values by
    setting both bounds equal to them.
    """
    params_dict = vars(params)
    lower_bounds = {}
    upper_bounds = {}
    for key in params_dict.keys():
        if key in opt_param_dict:
            try:
                lb, ub = opt_param_dict[key]
                if not jnp.all(lb < ub):
                    raise ValueError(f"Lower bounds should all be smaller than upper bounds, "
                                     f"problem with parameter {key}.")
            except TypeError:
                lb = None
                ub = None
            lower_bounds[key] = lb
            upper_bounds[key] = ub
        else:
            lower_bounds[key] = params_dict[key]
            upper_bounds[key] = params_dict[key]
    return type(params)(**lower_bounds), type(params)(**upper_bounds)
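
# Usage sketch with a hypothetical InteractionParams having fields `eps` and
# `sigma` (field names for illustration only):
#
#   lb, ub = bounds_params(params, eps=(0.1, 10.), sigma=None)
#
# Here `eps` is constrained to [0.1, 10], `sigma` is left unbounded, and every
# field not mentioned is pinned to its current value.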
def map_into_bounds(params: InteractionParams,
                    lower_bounds: InteractionParams | None,
                    upper_bounds: InteractionParams | None) -> InteractionParams:
    """Map interaction parameters back into the interval between lower and upper bounds."""
    leaves, treedef = jax.tree_util.tree_flatten(params)
    # If lower and/or upper bounds are not provided, construct pytrees with the
    # same structure as the parameters, filled with None. Note that this fails
    # (with a ValueError) if the parameters pytree itself contains None values.
    if lower_bounds is None:
        lower_bounds = jax.tree_util.tree_unflatten(treedef, [None] * len(leaves))
    if upper_bounds is None:
        upper_bounds = jax.tree_util.tree_unflatten(treedef, [None] * len(leaves))

    def map_to_bounds(x, xmin, xmax):
        try:
            if jnp.any(xmin > xmax):
                raise ValueError(f'Min bound cannot be larger than max bound, got {xmin} and {xmax}, respectively.')
        except TypeError:
            pass  # one or both bounds are None, i.e. unbounded on that side
        if xmin is not None:
            x = jnp.maximum(x, xmin)
        if xmax is not None:
            x = jnp.minimum(x, xmax)
        return x

    return jax.tree_util.tree_map(map_to_bounds, params, lower_bounds, upper_bounds)
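
# Sketch: clamping is leaf-wise, e.g. a leaf value of [-1., 5.] with bounds
# (0., 2.) maps to [0., 2.]; a None bound simply skips clamping on that side.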
def normalize_param(params: InteractionParams, param_name: str, ord=None) -> InteractionParams:
    """Return a copy of `params` with the `param_name` leaf normalized to unit norm."""
    params_dict = vars(params)
    new_dict = params_dict.copy()  # a shallow copy is enough as values (interaction params elements) are jax arrays
    new_dict[param_name] = params_dict[param_name] / jnp.linalg.norm(params_dict[param_name], keepdims=True, ord=ord)
    return type(params)(**new_dict)
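
# Sketch: for a hypothetical params object with a vector field 'direction',
# normalize_param(params, 'direction') rescales that field to unit L2 norm,
# e.g. [3., 4.] -> [0.6, 0.8].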
TIME_WEIGHT_FN = {'constant': lambda x: jnp.ones_like(x),
                  'linear': lambda x: x,
                  'quadratic': lambda x: x ** 2,
                  'exponential': lambda x: jnp.exp(x),
                  'step_25': lambda x: jnp.heaviside(x - 0.249, 1),
                  'step_50': lambda x: jnp.heaviside(x - 0.50, 1),
                  'step_75': lambda x: jnp.heaviside(x - 0.749, 1),
                  'step_100': lambda x: jnp.heaviside(x - 1., 1),
                  'neg_linear': lambda x: 1 - x}
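
# Example lookup: TIME_WEIGHT_FN['step_50'](jnp.linspace(0, 1, 5)) gives
# [0, 0, 1, 1, 1], so gradients from the first half of the trajectory are
# discarded. The 0.249 and 0.749 thresholds presumably guard against
# floating-point round-off at the exact quarter points.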
def fit_bptt(simulation_fn: BpttSimulation,
             optimizer_update: optax.TransformUpdateFn,
             clipping: float,
             grad_time_weights: str | None = None,
             param_rescalings: list[Callable[[InteractionParams], InteractionParams]] | None = None,
             lower_bounds: InteractionParams | None = None,
             upper_bounds: InteractionParams | None = None,
             time_axis: int = 1) -> Callable:
    """
    Construct the step function for meta-optimization of parameters in a BPTT simulation.

    Args:
        simulation_fn: A function that performs the simulation and computes the gradients of interaction parameters.
        optimizer_update: A function that updates the parameters using the computed gradients.
        clipping: The maximum gradient-to-parameter norm ratio used for adaptive gradient clipping.
        grad_time_weights: Key into TIME_WEIGHT_FN selecting the function that computes time-based gradient
            weights. Defaults to 'constant', which assigns equal weights (ones) to all time steps.
        param_rescalings: A list of functions that apply rescalings or transformations to the interaction
            parameters after each update. Default is an empty list.
        lower_bounds: The lower bounds for the interaction parameters. Default is None.
        upper_bounds: The upper bounds for the interaction parameters. Default is None.
        time_axis: The axis along which time steps are stored in the gradient pytree. Default is 1.

    Returns:
        Callable: A step function that performs one training step.
    """
    if grad_time_weights is None:
        grad_time_weights = 'constant'
    try:
        grad_time_weight_fn = TIME_WEIGHT_FN[grad_time_weights]
    except KeyError:
        raise ValueError(f'Invalid time weight parameter, {grad_time_weights} is not among the implemented weights.')

    if param_rescalings is None:
        param_rescalings = []
    # Prepend the canonical eigenvalue rescaling without mutating the caller's list.
    param_rescalings = [oriented_particle.canonicalize_eigvals] + list(param_rescalings)

    def step(params: InteractionParams,
             opt_state: optax.OptState,
             md_state: Any,
             aux: SimulationAux) -> tuple[InteractionParams, optax.OptState,
                                          BpttResults, SimulationAux, InteractionParams]:
        aux = aux.reset_empty()
        bptt_results, aux = simulation_fn(params, md_state, aux)
        grad_clipped = adaptive_grad_clip(bptt_results.grad, params, clipping)
        grad_weighted = get_grad_time_weights(grad_clipped, grad_time_weight_fn, time_axis=time_axis)
        grad_mean = canonicalize_grad_results(grad_weighted, params)
        updates, opt_state = optimizer_update(grad_mean, opt_state)
        params = optax.apply_updates(params, updates)
        for fn in param_rescalings:
            params = fn(params)
        params = map_into_bounds(params, lower_bounds, upper_bounds)
        return params, opt_state, bptt_results, aux, grad_clipped

    return step
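
# Training-loop sketch (hedged: `simulation_fn`, `init_params`, `md_state`,
# `aux`, and `num_steps` are assumed to be set up elsewhere in the codebase):
#
#   optimizer = optax.adam(learning_rate=1e-2)
#   opt_state = optimizer.init(init_params)
#   step = fit_bptt(simulation_fn, optimizer.update, clipping=0.1,
#                   grad_time_weights='linear')
#   params = init_params
#   for i in range(num_steps):
#       params, opt_state, results, aux, grad = step(params, opt_state, md_state, aux)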