import jax.numpy as jnp import jax Array = jnp.ndarray f32 = jnp.float32 def diagonal_mask(X: Array) -> Array: """Sets the diagonal of a matrix to zero. A direct copy of jax_md.smap._diagonal_matrix()""" if X.shape[0] != X.shape[1]: raise ValueError( 'Diagonal mask can only mask square matrices. Found {}x{}.'.format( X.shape[0], X.shape[1])) if len(X.shape) > 3: raise ValueError( ('Diagonal mask can only mask rank-2 or rank-3 tensors. ' 'Found {}.'.format(len(X.shape)))) N = X.shape[0] X = jnp.nan_to_num(X) mask = f32(1.0) - jnp.eye(N, dtype=X.dtype) if len(X.shape) == 3: mask = jnp.reshape(mask, (N, N, 1)) return mask * X @jax.custom_vjp def print_grad(x): return x def _print_grad_fwd(x): return x, None def _print_grad_bwd(_, grad): jax.debug.print("grad: {}", grad) return (grad,) print_grad.defvjp(_print_grad_fwd, _print_grad_bwd)