123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- import jax.numpy as jnp
- import jax
- from functools import partial
# Type alias for JAX arrays, used in the annotations below.
Array = jnp.ndarray
@partial(jnp.vectorize, signature='(d,d)->()')
def determinant(a):
    """Determinant of a 3x3 matrix by cofactor expansion down the first column.

    Used for the symmetric matrices of this module, but the expansion itself
    makes no symmetry assumption. Leading batch dimensions are mapped over by
    ``jnp.vectorize``.
    """
    # 2x2 minors obtained by deleting column 0 and row 0, 1, 2 respectively.
    minor_row0 = a[1, 1] * a[2, 2] - a[2, 1] * a[1, 2]
    minor_row1 = a[0, 1] * a[2, 2] - a[2, 1] * a[0, 2]
    minor_row2 = a[0, 1] * a[1, 2] - a[1, 1] * a[0, 2]
    return a[0, 0] * minor_row0 - a[1, 0] * minor_row1 + a[2, 0] * minor_row2
@partial(jnp.vectorize, signature='(d,d)->(d,d)')
def inverse(a):
    """Inverse of a symmetric 3x3 matrix. Much faster than jnp.linalg.inv.

    The matrix is assumed symmetric: only its upper triangle (``a[i, j]`` with
    ``i <= j``) is read. The result is the adjugate (symmetric for symmetric
    input) divided by the determinant. The determinant is taken by Laplace
    expansion along row 0, reusing the cofactors already computed for the
    adjugate rather than recomputing the same 2x2 minors.

    Leading batch dimensions are mapped over by ``jnp.vectorize``. A singular
    input divides by a zero determinant and therefore yields inf/nan entries.
    """
    # The six distinct cofactors of a symmetric 3x3 matrix.
    c00 = a[2, 2] * a[1, 1] - a[1, 2] ** 2
    c01 = a[0, 2] * a[1, 2] - a[2, 2] * a[0, 1]
    c02 = a[0, 1] * a[1, 2] - a[0, 2] * a[1, 1]
    c11 = a[2, 2] * a[0, 0] - a[0, 2] ** 2
    c12 = a[0, 1] * a[0, 2] - a[0, 0] * a[1, 2]
    c22 = a[0, 0] * a[1, 1] - a[0, 1] ** 2
    # det = sum_j a[0, j] * C[0, j]; uses the same (upper-triangle) entries as
    # the adjugate, so both stay consistent even for slightly asymmetric input.
    det = a[0, 0] * c00 + a[0, 1] * c01 + a[0, 2] * c02
    adjugate = jnp.array([[c00, c01, c02],
                          [c01, c11, c12],
                          [c02, c12, c22]])
    return adjugate / det
def matrix_c(lbd: float, mat1: Array, mat2: Array) -> Array:
    """Matrix C from the Perram and Wertheim article on ellipsoid contact function.

    C(lbd) = [(1 - lbd) * mat1 + lbd * mat2]^-1, i.e. the inverse of the
    linear blend of the two weight matrices, computed with this module's
    symmetric 3x3 ``inverse``.
    """
    blended = (1 - lbd) * mat1 + lbd * mat2
    return inverse(blended)
def perram_wertheim_objective(lbd: float, r12: Array, mat1: Array, mat2: Array) -> Array:
    """Perram-Wertheim objective F(lbd) = lbd * (1 - lbd) * r12^T C(lbd) r12.

    The contact function is the maximum of this objective over lbd in [0, 1].

    Args:
        lbd: interpolation parameter; a scalar (a per-pair array would not
            broadcast correctly against the matrices in ``matrix_c``).
        r12: distance vector between ellipsoid centers, shape (..., 3).
        mat1: weight matrix of the first ellipsoid, shape (..., 3, 3).
        mat2: weight matrix of the second ellipsoid, shape (..., 3, 3).
    Returns:
        Objective value with the broadcast batch shape of the inputs
        (a scalar for a single pair).
    """
    c = matrix_c(lbd, mat1, mat2)
    # Batched quadratic form r12^T C r12. For a single pair this equals
    # jnp.dot(r12, jnp.dot(c, r12)); einsum also handles leading batch axes,
    # matching the batching already supported by ``inverse``.
    quad_form = jnp.einsum('...i,...ij,...j->...', r12, c, r12)
    return lbd * (1 - lbd) * quad_form
# Derivative of the Perram-Wertheim objective with respect to lbd, its first
# positional argument (the one jax.grad differentiates by default).
objective_grad = jax.grad(perram_wertheim_objective)
def evaluate_grad_step(carry: float, x: float, r12: Array, mat1: Array, mat2: Array) -> tuple[Array, float]:
    """One sign-of-gradient ascent step, shaped as a ``jax.lax.scan`` body.

    Moves the interpolation parameter ``carry`` by exactly ``x`` in the
    direction in which the Perram-Wertheim objective increases. If the
    gradient is exactly zero, ``carry`` does not move.

    Args:
        carry: current interpolation parameter lbd.
        x: step size for this iteration.
        r12: distance vector between ellipsoid centers (forwarded to the gradient).
        mat1: weight matrix of the first ellipsoid (forwarded to the gradient).
        mat2: weight matrix of the second ellipsoid (forwarded to the gradient).
    Returns:
        ``(new_lbd, 0.)``. The second element is a placeholder per-step scan
        output, which ``pw_contact_function`` discards.
    """
    grad = objective_grad(carry, r12, mat1, mat2)
    # Only the sign of the gradient is used, so the step length is always x.
    return carry + x * jnp.sign(grad), 0.
@partial(jax.jit, static_argnums=(3,))
def pw_contact_function(r12: Array, mat1: Array, mat2: Array, num_steps: int = 25, **unused_kwargs) -> Array:
    """
    Calculate Perram-Wertheim contact function. To ensure jax.gradient compatibility,
    a dumb gradient-based method is used where a fixed number of steps is taken to calculate the maximum
    of the objective function. Square root is taken to get linear distance dependence.

    Starting from lbd = 0.5, step k (k = 0, 1, ...) moves lbd by 1 / 2^(k + 2)
    toward increasing objective, so lbd stays within (0, 1).

    Args:
        r12: distance vector between ellipsoid centers.
        mat1: weight matrix of the first ellipsoid, with eigenvalues equal to squared semiaxis lengths.
        mat2: weight matrix of the second ellipsoid, with eigenvalues equal to squared semiaxis lengths.
        num_steps: number of steps in objective maximization (static under jit).
            Accuracy improves as 1 / 2^num_steps until limited by floating-point
            precision (with the default float32, steps beyond roughly 24 no
            longer change lbd).
    Returns:
        Perram-Wertheim contact function
    """
    # Step sizes 1/4, 1/8, ..., 1/2^(num_steps + 1). num_steps is static, so the
    # table is built from exact Python floats at trace time. The previous integer
    # form 2 ** (arange + 2) overflowed the default int32 once num_steps >= 30,
    # giving negative or infinite steps (and an inf/nan result).
    t_change = jnp.asarray([0.5 ** (k + 2) for k in range(num_steps)])
    step = partial(evaluate_grad_step, r12=r12, mat1=mat1, mat2=mat2)
    t_opt, _ = jax.lax.scan(step, init=0.5, xs=t_change)
    return jnp.sqrt(perram_wertheim_objective(t_opt, r12, mat1, mat2))
def bp_contact_function(r12: Array, mat1: Array, mat2: Array, **unused_kwargs) -> Array:
    """
    Berne-Pechukas contact function: the Perram-Wertheim objective evaluated at
    the fixed interpolation parameter t = 0.5 instead of at its maximum, which
    makes it a cheap approximation of the true Perram-Wertheim contact function.
    The square root gives a linear dependence on distance.

    Args:
        r12: distance vector between ellipsoid centers.
        mat1: weight matrix of the first ellipsoid, with eigenvalues equal to squared semiaxis lengths.
        mat2: weight matrix of the second ellipsoid, with eigenvalues equal to squared semiaxis lengths.
        **unused_kwargs: accepted and ignored.
    Returns:
        Berne-Pechukas contact function
    """
    midpoint = 0.5
    squared_contact = perram_wertheim_objective(midpoint, r12, mat1, mat2)
    return jnp.sqrt(squared_contact)
|