from __future__ import annotations import optax from curvature_assembly import data_protocols, pytree_transf, oriented_particle import jax from typing import Callable, Any import jax.numpy as jnp InteractionParams = data_protocols.InteractionParams SimulationAux = data_protocols.SimulationAux BpttResults = data_protocols.BpttResults BpttSimulation = data_protocols.BpttSimulation Array = jnp.ndarray def unitwise_clip(g_norm: Array, max_norm: Array, grad: Array, div_eps: float = 1e-6) -> Array: """Applies gradient clipping unit-wise.""" # This little max(., div_eps) is distinct from the normal eps and just # prevents division by zero. It technically should be impossible to engage. clipped_grad = grad * (max_norm / jnp.maximum(g_norm, div_eps)) return jnp.where(g_norm < max_norm, grad, clipped_grad) def adaptive_grad_clip(grad, params, clipping: float, eps: float = 1e-3): num_ed = pytree_transf.num_extra_dimensions(grad, params) g_norm = pytree_transf.broadcast_to(pytree_transf.leaf_norm(grad, keepdims=True, num_ld=num_ed), grad) p_norm = pytree_transf.broadcast_to(pytree_transf.leaf_norm(params, keepdims=True), grad) # Maximum allowable leaf_norm max_norm = jax.tree_util.tree_map( lambda x: clipping * jnp.maximum(x, eps), p_norm) # If grad leaf_norm > clipping * param_norm, rescale return jax.tree_util.tree_map(unitwise_clip, g_norm, max_norm, grad) def get_grad_time_weights(grad: InteractionParams, time_weight_fn: Callable[[Array], Array], time_axis: int = 1): """ Apply time-based weights to the gradients of interaction parameters. Args: grad: Gradients of interaction parameters, represented as a JAX PyTree. time_weight_fn: A function that computes the time-based weights on a rescaled time interval [0, 1]. time_axis: The axis along which the time steps are represented in the `grad` PyTree. Default is 1. Returns: Gradients of interaction parameters with time-based weights applied. """ num_timesteps = pytree_transf.data_length(grad, axis=time_axis) weights = time_weight_fn(jnp.linspace(0, 1, num_timesteps, endpoint=True)) mean = jax.lax.cond(jnp.mean(weights) == 0, lambda x: 1., lambda x: jnp.mean(x), weights) normalized_weights = weights / mean def apply_weights(x): expand_dims = tuple(i for i in range(len(x.shape)) if i != time_axis) expanded_weights = jnp.expand_dims(normalized_weights, axis=expand_dims) return x * expanded_weights return jax.tree_util.tree_map(apply_weights, grad) def canonicalize_grad_results(grad: InteractionParams, params: InteractionParams) -> InteractionParams: """ Make gradient leaf shapes compatible with interaction params, i.e. we take the average over all extra axes compared to the original params shape. """ results_num_ed = pytree_transf.num_extra_dimensions(grad, params) return jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=tuple(range(results_num_ed))), grad) def bounds_params(params: InteractionParams, **opt_param_dict: dict[tuple]) -> (InteractionParams, InteractionParams): params_dict = vars(params) lower_bounds = {} upper_bounds = {} for key in params_dict.keys(): if key in opt_param_dict: try: lb, ub = opt_param_dict[key] if not jnp.all(lb < ub): raise ValueError(f"Lower bounds should all be smaller than upper bounds, " f"problem with parameter {key}") except TypeError: lb = None ub = None lower_bounds[key] = lb upper_bounds[key] = ub else: lower_bounds[key] = params_dict[key] upper_bounds[key] = params_dict[key] return type(params)(**lower_bounds), type(params)(**upper_bounds) def map_into_bounds(params: InteractionParams, lower_bounds: InteractionParams | None, upper_bounds: InteractionParams | None): """Map interaction parameters back into the interval between lower and upper bounds.""" leaves, treedef = jax.tree_util.tree_flatten(params) # if lower and/or upper bounds are not provided, we must construct pytrees with the same structure as # parameters structure and filled with None. Note that this will fail if parameters pytree contains any # None values (will raise a ValueError). if lower_bounds is None: lower_bounds = jax.tree_util.tree_unflatten(treedef, [None] * len(leaves)) if upper_bounds is None: upper_bounds = jax.tree_util.tree_unflatten(treedef, [None] * len(leaves)) def map_to_bounds(x, xmin, xmax): try: if jnp.any(xmin > xmax): raise ValueError(f'Min bound cannot be larger than max bound, got {xmin} and {xmax}, respectively.') except TypeError: pass if xmin is not None: x = jnp.maximum(x, xmin) if xmax is not None: x = jnp.minimum(x, xmax) return x return jax.tree_util.tree_map(map_to_bounds, params, lower_bounds, upper_bounds) def normalize_param(params: InteractionParams, param_name: str, ord=None) -> InteractionParams: params_dict = vars(params) new_dict = params_dict.copy() # shallow copy is enough as values (interaction_params elements) are jax arrays new_dict[param_name] = params_dict[param_name] / jnp.linalg.norm(params_dict[param_name], keepdims=True, ord=ord) return type(params)(**new_dict) TIME_WEIGHT_FN = {'constant': lambda x: jnp.ones_like(x), 'linear': lambda x: x, 'quadratic': lambda x: x ** 2, 'exponential': lambda x: jnp.exp(x), 'step_25': lambda x: jnp.heaviside(x - 0.249, 1), 'step_50': lambda x: jnp.heaviside(x - 0.50, 1), 'step_75': lambda x: jnp.heaviside(x - 0.749, 1), 'step_100': lambda x: jnp.heaviside(x - 1., 1), 'neg_linear': lambda x: 1 - x} def fit_bptt(simulation_fn: BpttSimulation, optimizer_update: optax.TransformUpdateFn, clipping: float, grad_time_weights: str = None, param_rescalings: list[Callable[[InteractionParams], InteractionParams]] = None, lower_bounds: InteractionParams = None, upper_bounds: InteractionParams = None, time_axis: int = 1) -> Callable: """ Construct the step function for meta optimization of parameters in a BPTT simulation. Args: simulation_fn: A function that performs the simulation and computes the gradients of interaction parameters. optimizer_update: A function that updates the parameters using the computed gradients. clipping: The maximum value to clip the gradients during training. grad_time_weights: String that then maps into a function that computes time-based weights for the gradients. Default is a function that assigns equal weights (ones) to all time steps. param_rescalings: A list of functions that apply rescalings or transformations to the interaction parameters during training. Default is an empty list. lower_bounds: The lower bounds for the interaction parameters. Default is None. upper_bounds: The upper bounds for the interaction parameters. Default is None. time_axis: The axis along which the time steps are represented in the gradient PyTree. Default is 1. Returns: Callable: A step function that performs one training step. """ if grad_time_weights is None: grad_time_weights = 'constant' try: grad_time_weight_fn = TIME_WEIGHT_FN[grad_time_weights] except KeyError: raise ValueError(f'Invalid time weight parameter, {grad_time_weights} is not among the implemented weights.') if param_rescalings is None: param_rescalings = [] param_rescalings.insert(0, oriented_particle.canonicalize_eigvals) def step(params: InteractionParams, opt_state: optax.OptState, md_state: Any, aux: SimulationAux) -> (InteractionParams, optax.OptState, BpttResults, SimulationAux, InteractionParams): aux = aux.reset_empty() bptt_results, aux = simulation_fn(params, md_state, aux) grad_clipped = adaptive_grad_clip(bptt_results.grad, params, clipping) grad_weighted = get_grad_time_weights(grad_clipped, grad_time_weight_fn, time_axis=time_axis) grad_mean = canonicalize_grad_results(grad_weighted, params) updates, opt_state = optimizer_update(grad_mean, opt_state) params = optax.apply_updates(params, updates) for fn in param_rescalings: params = fn(params) params = map_into_bounds(params, lower_bounds, upper_bounds) return params, opt_state, bptt_results, aux, grad_clipped return step