# ellipsoid_contact.py (3.6 KB)

from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp
# Alias for the JAX array type, used in this module's type annotations.
Array = jnp.ndarray
  5. @partial(jnp.vectorize, signature='(d,d)->()')
  6. def determinant(a):
  7. """Determinant of a symmetric 3x3 matrix."""
  8. return a[0, 0] * (a[1, 1] * a[2, 2] - a[2, 1] * a[1, 2]) \
  9. - a[1, 0] * (a[0, 1] * a[2, 2] - a[2, 1] * a[0, 2]) \
  10. + a[2, 0] * (a[0, 1] * a[1, 2] - a[1, 1] * a[0, 2])
  11. @partial(jnp.vectorize, signature='(d,d)->(d,d)')
  12. def inverse(a):
  13. """Inverse of a symmetric 3x3 matrix. Much faster than jnp.linalg.inv."""
  14. det = determinant(a)
  15. inv = jnp.array([[a[2, 2] * a[1, 1] - a[1, 2] ** 2,
  16. a[0, 2] * a[1, 2] - a[2, 2] * a[0, 1],
  17. a[0, 1] * a[1, 2] - a[0, 2] * a[1, 1]],
  18. [a[0, 2] * a[1, 2] - a[2, 2] * a[0, 1],
  19. a[2, 2] * a[0, 0] - a[0, 2] ** 2,
  20. a[0, 1] * a[0, 2] - a[0, 0] * a[1, 2]],
  21. [a[0, 1] * a[1, 2] - a[0, 2] * a[1, 1],
  22. a[0, 1] * a[0, 2] - a[0, 0] * a[1, 2],
  23. a[0, 0] * a[1, 1] - a[0, 1] ** 2]])
  24. return inv / det
  25. def matrix_c(lbd: float, mat1: Array, mat2: Array) -> Array:
  26. """Matrix C from the Perram and Wertheim article on ellipsoid contact function."""
  27. return inverse(lbd * mat2 + (1 - lbd) * mat1)
  28. def perram_wertheim_objective(lbd: float, r12: Array, mat1: Array, mat2: Array) -> Array:
  29. c = matrix_c(lbd, mat1, mat2)
  30. return lbd * (1 - lbd) * jnp.dot(r12, jnp.dot(c, r12))
  31. objective_grad = jax.grad(perram_wertheim_objective, argnums=0)
  32. def evaluate_grad_step(carry: float, x: float, r12: Array, mat1: Array, mat2: Array) -> (Array, Array):
  33. grad = objective_grad(carry, r12, mat1, mat2)
  34. return carry + x * jnp.sign(grad), 0.
  35. @partial(jax.jit, static_argnums=(3,))
  36. def pw_contact_function(r12: Array, mat1: Array, mat2: Array, num_steps: int = 25, **unused_kwargs) -> Array:
  37. """
  38. Calculate Perram-Wertheim contact function. To ensure jax.gradient compatibility,
  39. a dumb gradient-based method is used where a fixed number of steps is taken to calculate the maximum
  40. of the objective function. Square root is taken to get linear distance dependence.
  41. Args:
  42. r12: distance vector between ellipsoid centers.
  43. mat1: weight matrix of the first ellipsoid, with eigenvalues equal to squared semiaxis lengths.
  44. mat2: weight matrix of the second ellipsoid, with eigenvalues equal to squared semiaxis lengths.
  45. num_steps: number of step in objective maximization. Accuracy improves as 1 / 2^num_steps.
  46. Returns:
  47. Perram-Wertheim contact function
  48. """
  49. powers = 2 ** (jnp.arange(num_steps) + 2) # powers of two
  50. t_change = 1 / powers
  51. t_opt, _ = jax.lax.scan(partial(evaluate_grad_step, r12=r12, mat1=mat1, mat2=mat2), init=0.5, xs=t_change)
  52. return jnp.sqrt(perram_wertheim_objective(t_opt, r12, mat1, mat2))
  53. def bp_contact_function(r12: Array, mat1: Array, mat2: Array, **unused_kwargs) -> Array:
  54. """
  55. Calculates Berne-Pechukas contact function which is an approximation for the true Perram-Wertheim contact function
  56. at the value of interpolation parameter t = 0.5. Square root is taken to get linear distance dependence.
  57. Args:
  58. r12: distance vector between ellipsoid centers.
  59. mat1: weight matrix of the first ellipsoid, with eigenvalues equal to squared semiaxis lengths.
  60. mat2: weight matrix of the second ellipsoid, with eigenvalues equal to squared semiaxis lengths.
  61. Returns:
  62. Berne-Pechukas contact function
  63. """
  64. return jnp.sqrt(perram_wertheim_objective(0.5, r12, mat1, mat2))