util.py 973 B

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. import jax.numpy as jnp
  2. import jax
  3. Array = jnp.ndarray
  4. f32 = jnp.float32
  5. def diagonal_mask(X: Array) -> Array:
  6. """Sets the diagonal of a matrix to zero. A direct copy of jax_md.smap._diagonal_matrix()"""
  7. if X.shape[0] != X.shape[1]:
  8. raise ValueError(
  9. 'Diagonal mask can only mask square matrices. Found {}x{}.'.format(
  10. X.shape[0], X.shape[1]))
  11. if len(X.shape) > 3:
  12. raise ValueError(
  13. ('Diagonal mask can only mask rank-2 or rank-3 tensors. '
  14. 'Found {}.'.format(len(X.shape))))
  15. N = X.shape[0]
  16. X = jnp.nan_to_num(X)
  17. mask = f32(1.0) - jnp.eye(N, dtype=X.dtype)
  18. if len(X.shape) == 3:
  19. mask = jnp.reshape(mask, (N, N, 1))
  20. return mask * X
  21. @jax.custom_vjp
  22. def print_grad(x):
  23. return x
  24. def _print_grad_fwd(x):
  25. return x, None
  26. def _print_grad_bwd(_, grad):
  27. jax.debug.print("grad: {}", grad)
  28. return (grad,)
  29. print_grad.defvjp(_print_grad_fwd, _print_grad_bwd)