1234567891011121314151617181920212223242526272829303132333435363738394041 |
- import jax.numpy as jnp
- import jax
# Shorthand aliases used in signatures and constants throughout this module.
Array = jnp.ndarray
f32 = jnp.float32


def diagonal_mask(X: Array) -> Array:
    """Set the diagonal of a square matrix (optionally with a trailing axis) to zero.

    Adapted from jax_md's ``smap._diagonal_mask``.

    Before masking, ``jnp.nan_to_num`` is applied to the WHOLE input: NaN
    becomes 0 and +/-inf become the largest finite values of the dtype. This
    is required because ``0 * nan`` and ``0 * inf`` are NaN, so a NaN/inf on
    the diagonal would otherwise survive the multiplicative mask. Note that
    NaN/inf entries off the diagonal are rewritten as well.

    Args:
      X: Array of shape ``(N, N)`` or ``(N, N, D)``. For rank 3 the mask is
        broadcast over the trailing axis, zeroing every ``X[i, i, :]``.

    Returns:
      Array with the same shape as ``X`` whose diagonal entries are zero.
      For floating-point and integer inputs the dtype of ``X`` is preserved.

    Raises:
      ValueError: If ``X`` is not rank 2 or rank 3, or if its first two axes
        have different lengths.
    """
    ndim = len(X.shape)
    # Validate rank first: indexing X.shape[1] on a rank-0/1 input would
    # otherwise raise IndexError instead of the documented ValueError.
    if ndim not in (2, 3):
        raise ValueError(
            ('Diagonal mask can only mask rank-2 or rank-3 tensors. '
             'Found {}.'.format(ndim)))
    if X.shape[0] != X.shape[1]:
        raise ValueError(
            'Diagonal mask can only mask square matrices. Found {}x{}.'.format(
                X.shape[0], X.shape[1]))

    N = X.shape[0]
    X = jnp.nan_to_num(X)
    # A weakly-typed Python scalar (instead of a float32 constant) lets the
    # mask adopt X's dtype, so half-precision / integer inputs are not
    # promoted to float32.
    mask = 1 - jnp.eye(N, dtype=X.dtype)
    if ndim == 3:
        # Broadcast the (N, N) mask across the trailing axis.
        mask = jnp.reshape(mask, (N, N, 1))
    return mask * X
@jax.custom_vjp
def print_grad(x):
    """Identity function that prints the gradient flowing back through it.

    The forward pass returns ``x`` unchanged. During reverse-mode
    differentiation the incoming cotangent is printed via ``jax.debug.print``
    and then passed through untouched, so wrapping a value in ``print_grad``
    never changes the gradients that are computed.
    """
    return x


def _print_grad_fwd(x):
    """Forward rule: emit the primal output and keep no residuals."""
    primal_out, residuals = x, None
    return primal_out, residuals


def _print_grad_bwd(_residuals, cotangent):
    """Backward rule: print the cotangent, then pass it through unchanged."""
    jax.debug.print("grad: {}", cotangent)
    # custom_vjp expects a tuple with one cotangent per primal input.
    return (cotangent,)


print_grad.defvjp(fwd=_print_grad_fwd, bwd=_print_grad_bwd)
|