from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp

Array = jnp.ndarray


@partial(jnp.vectorize, signature='(d,d)->()')
def determinant(a):
    """Determinant of a 3x3 matrix.

    Computed by cofactor expansion along the first column, so the formula is
    valid for any 3x3 matrix (symmetric or not). Leading batch dimensions are
    handled by ``jnp.vectorize``; the core dimensions are expected to be 3x3.
    """
    return (a[0, 0] * (a[1, 1] * a[2, 2] - a[2, 1] * a[1, 2])
            - a[1, 0] * (a[0, 1] * a[2, 2] - a[2, 1] * a[0, 2])
            + a[2, 0] * (a[0, 1] * a[1, 2] - a[1, 1] * a[0, 2]))


@partial(jnp.vectorize, signature='(d,d)->(d,d)')
def inverse(a):
    """Inverse of a symmetric 3x3 matrix. Much faster than jnp.linalg.inv.

    The adjugate is assembled from the upper-triangle entries only (a[0, 1],
    a[0, 2], a[1, 2] stand in for their transposed counterparts), so the
    result is only correct when ``a`` is symmetric. Leading batch dimensions
    are handled by ``jnp.vectorize``.
    """
    det = determinant(a)
    # Adjugate (transposed cofactor matrix) of a symmetric matrix.
    adj = jnp.array([
        [a[2, 2] * a[1, 1] - a[1, 2] ** 2,
         a[0, 2] * a[1, 2] - a[2, 2] * a[0, 1],
         a[0, 1] * a[1, 2] - a[0, 2] * a[1, 1]],
        [a[0, 2] * a[1, 2] - a[2, 2] * a[0, 1],
         a[2, 2] * a[0, 0] - a[0, 2] ** 2,
         a[0, 1] * a[0, 2] - a[0, 0] * a[1, 2]],
        [a[0, 1] * a[1, 2] - a[0, 2] * a[1, 1],
         a[0, 1] * a[0, 2] - a[0, 0] * a[1, 2],
         a[0, 0] * a[1, 1] - a[0, 1] ** 2],
    ])
    return adj / det


def matrix_c(lbd: float, mat1: Array, mat2: Array) -> Array:
    """Matrix C from the Perram and Wertheim article on ellipsoid contact function.

    Returns the inverse of the interpolated matrix lbd * mat2 + (1 - lbd) * mat1,
    using the symmetric 3x3 ``inverse`` above.
    """
    return inverse(lbd * mat2 + (1 - lbd) * mat1)


def perram_wertheim_objective(lbd: float, r12: Array, mat1: Array, mat2: Array) -> Array:
    """Perram-Wertheim objective F(lbd) = lbd * (1 - lbd) * r12^T C(lbd) r12.

    Its maximum over the interpolation parameter ``lbd`` is what
    ``pw_contact_function`` searches for; ``bp_contact_function`` evaluates
    it at lbd = 0.5.
    """
    c = matrix_c(lbd, mat1, mat2)
    return lbd * (1 - lbd) * jnp.dot(r12, jnp.dot(c, r12))


# Derivative of the objective with respect to the interpolation parameter.
objective_grad = jax.grad(perram_wertheim_objective, argnums=0)


def evaluate_grad_step(carry: float, x: float, r12: Array, mat1: Array,
                       mat2: Array) -> Tuple[Array, float]:
    """One ``lax.scan`` step: move ``carry`` by ``x`` uphill along the gradient.

    Only the sign of the gradient is used, so each step has length exactly
    ``x`` (or zero when the gradient is exactly zero). The second output is a
    dummy per-step value required by ``lax.scan``.
    """
    grad = objective_grad(carry, r12, mat1, mat2)
    return carry + x * jnp.sign(grad), 0.


@partial(jax.jit, static_argnums=(3,))
def pw_contact_function(r12: Array, mat1: Array, mat2: Array, num_steps: int = 25,
                        **unused_kwargs) -> Array:
    """Calculate the Perram-Wertheim contact function.

    To stay compatible with ``jax.grad``, a simple sign-of-gradient search with
    a fixed number of steps is used to find the maximum of the objective. It
    starts at t = 0.5 and takes steps of size 1/4, 1/8, ..., so the error in
    the optimal t shrinks as ~1 / 2^num_steps. Presumably, with default
    float32 precision, steps beyond roughly 24 no longer change t. The square
    root is taken to get linear distance dependence.

    Args:
        r12: distance vector between ellipsoid centers.
        mat1: weight matrix of the first ellipsoid, with eigenvalues equal to
            squared semiaxis lengths.
        mat2: weight matrix of the second ellipsoid, with eigenvalues equal to
            squared semiaxis lengths.
        num_steps: number of steps in the objective maximization (static under
            jit).
        **unused_kwargs: accepted and ignored.

    Returns:
        Perram-Wertheim contact function.
    """
    # Step sizes 2^-2, 2^-3, ..., computed in floating point. Building them as
    # integer powers of two overflows int32 for num_steps >= 30, which yields
    # negative or infinite steps and a NaN result.
    t_change = 0.5 ** jnp.arange(2.0, num_steps + 2.0)
    step = partial(evaluate_grad_step, r12=r12, mat1=mat1, mat2=mat2)
    t_opt, _ = jax.lax.scan(step, init=0.5, xs=t_change)
    return jnp.sqrt(perram_wertheim_objective(t_opt, r12, mat1, mat2))


def bp_contact_function(r12: Array, mat1: Array, mat2: Array, **unused_kwargs) -> Array:
    """Calculate the Berne-Pechukas contact function.

    This approximates the true Perram-Wertheim contact function by evaluating
    the objective at the fixed interpolation parameter t = 0.5 instead of
    maximizing over t. The square root is taken to get linear distance
    dependence.

    Args:
        r12: distance vector between ellipsoid centers.
        mat1: weight matrix of the first ellipsoid, with eigenvalues equal to
            squared semiaxis lengths.
        mat2: weight matrix of the second ellipsoid, with eigenvalues equal to
            squared semiaxis lengths.
        **unused_kwargs: accepted and ignored.

    Returns:
        Berne-Pechukas contact function.
    """
    return jnp.sqrt(perram_wertheim_objective(0.5, r12, mat1, mat2))