import functools
from typing import Any, Callable, TypeVar

import jax
import jax.numpy as jnp
from jax_md import dataclasses, rigid_body, space

Array = jnp.ndarray

T = TypeVar('T')
InitFn = Callable[..., T]
ApplyFn = Callable[[T], T]


def random_unit_vector(key: Array) -> Array:
    """Return a random 3D unit vector, uniformly distributed on the sphere.

    The azimuth ``phi`` is drawn uniformly in [0, 2*pi) and ``cos(theta)``
    uniformly in [-1, 1), each from its own independent uniform sample.

    Args:
        key: JAX PRNG key.

    Returns:
        Array of shape ``(3,)`` with unit norm.
    """
    _, subkey = jax.random.split(key)
    # Requested as float64; JAX only honours this when x64 mode is enabled.
    u_phi, u_theta = jax.random.uniform(subkey, (2,), dtype=jnp.float64)
    phi = 2 * jnp.pi * u_phi
    # cos(theta) must come from the *second* uniform. Reusing the first one
    # ties theta to phi, so all directions lie on one spiral curve and the
    # move proposal is neither isotropic nor symmetric (detailed balance
    # would be violated).
    cos_theta = 2 * u_theta - 1
    sin_theta = jnp.sqrt(1 - cos_theta ** 2)
    return jnp.array([jnp.cos(phi) * sin_theta,
                      jnp.sin(phi) * sin_theta,
                      cos_theta])


def random_quaternion(key: Array, max_rotation) -> rigid_body.Quaternion:
    """Return a random rotation encoded as a unit quaternion.

    The rotation axis is uniform on the sphere (see ``random_unit_vector``)
    and the rotation angle is uniform in ``[0, max_rotation)``.

    Args:
        key: JAX PRNG key.
        max_rotation: Largest rotation angle in radians (scalar).

    Returns:
        ``rigid_body.Quaternion`` whose ``vec`` is
        ``[cos(a/2), sin(a/2) * axis]`` for the sampled angle ``a``.
    """
    _, axis_key, angle_key = jax.random.split(key, 3)
    axis = random_unit_vector(axis_key)
    angle = max_rotation * jax.random.uniform(angle_key, ())
    half_sin = jnp.sin(angle / 2)
    half_cos = jnp.cos(angle / 2)
    q = jnp.array([half_cos,
                   half_sin * axis[0],
                   half_sin * axis[1],
                   half_sin * axis[2]])
    return rigid_body.Quaternion(q)


@functools.singledispatch
def mc_move(position: Array,
            idx: int,
            key: Array,
            moving_distance: Array,
            shift: space.ShiftFn) -> Array:
    """Propose a trial move of particle ``idx`` for point particles.

    The particle is displaced by exactly ``moving_distance`` along a
    direction drawn uniformly on the sphere (3D only). The displacement is
    applied through ``shift`` so boundary conditions are respected.

    Args:
        position: Particle positions; indexed by particle along axis 0.
        idx: Index of the particle to move.
        key: JAX PRNG key.
        moving_distance: Length of the displacement.
        shift: jax_md shift function ``shift(R, dR)``.

    Returns:
        A new positions array with only row ``idx`` changed.
    """
    move = moving_distance * random_unit_vector(key)
    return position.at[idx].set(shift(position[idx], move))


@mc_move.register(rigid_body.RigidBody)
def _(position: rigid_body.RigidBody,
      idx: int,
      key: Array,
      moving_distance: rigid_body.RigidBody,
      shift: space.ShiftFn) -> rigid_body.RigidBody:
    """Propose a trial translation + rotation of rigid body ``idx``.

    Args:
        position: Current rigid-body state (centers and quaternions).
        idx: Index of the body to move.
        key: JAX PRNG key.
        moving_distance: ``.center`` scales a Gaussian 3D displacement;
            ``.orientation`` is the maximum rotation angle passed to
            ``random_quaternion``.
        shift: jax_md shift function applied to the center.

    Returns:
        A new ``RigidBody`` with only body ``idx`` changed.
    """
    _, position_key, orientation_key = jax.random.split(key, 3)
    # Use the dedicated subkey for the translation.
    position_move = moving_distance.center * jax.random.normal(position_key, (3,))
    orientation_move = random_quaternion(orientation_key, moving_distance.orientation)

    new_center = position.center.at[idx].set(
        shift(position.center[idx], position_move))
    # Left-multiply: apply the random rotation on top of the current one.
    # NOTE(review): the product is not renormalised; floating-point drift in
    # the quaternion norm may accumulate over very long runs.
    rotated = (orientation_move * position.orientation[idx]).vec
    new_orientation_vec = position.orientation.vec.at[idx].set(rotated)
    return rigid_body.RigidBody(new_center,
                                rigid_body.Quaternion(new_orientation_vec))


@functools.singledispatch
def num_particles(position: Array) -> int:
    """Return the number of particles, i.e. the leading dimension of ``position``."""
    return position.shape[0]
@num_particles.register(rigid_body.RigidBody)
def _(position: rigid_body.RigidBody) -> int:
    """Return the number of rigid bodies, i.e. the leading dimension of the centers."""
    return position.center.shape[0]


@dataclasses.dataclass
class MCMCState:
    """State carried between Metropolis Monte Carlo steps."""
    position: Any  # positions array, or a rigid_body.RigidBody
    key: Array  # PRNG key consumed by the next step
    accept: bool  # whether the most recent proposal was accepted


def mc_mc(shift: space.ShiftFn,
          energy_fn: Callable[..., Array],
          kT: float,
          moving_distance: Array
          ) -> (InitFn, ApplyFn):
    """Build a single-particle Metropolis Monte Carlo sampler.

    Each step picks one particle uniformly at random, proposes a move for
    it with ``mc_move`` and accepts with probability
    ``min(1, exp(-(E_new - E_old) / kT))``.

    Args:
        shift: jax_md shift function used to apply displacements.
        energy_fn: Total-energy function ``energy_fn(position, **kwargs)``.
        kT: Temperature in energy units.
        moving_distance: Move size forwarded to ``mc_move``. For point
            particles it is the displacement length. For rigid bodies it is a
            ``RigidBody`` holding the displacement scale and the max rotation.

    Returns:
        ``(init_fn, apply_fn)``. ``init_fn(key, position)`` builds the initial
        state. ``apply_fn(state, **kwargs)`` performs one MC step and forwards
        ``kwargs`` to ``energy_fn``.
    """

    def init_fn(key, position) -> MCMCState:
        return MCMCState(position, key, False)

    def apply_fn(state: MCMCState, **kwargs) -> MCMCState:
        position = state.position
        n = num_particles(position)

        key, particle_key, move_key, accept_key = jax.random.split(state.key, 4)
        # Scalar index: exactly one particle is moved per step. (Shape (2,)
        # would move two particles by the same displacement.)
        idx = jax.random.randint(particle_key, (), 0, n)
        new_position = mc_move(position, idx, move_key, moving_distance, shift)

        # Energies before and after the proposed move.
        # NOTE(review): the old energy is recomputed every step; caching it in
        # the state would halve the energy evaluations.
        energy = energy_fn(position, **kwargs)
        new_energy = energy_fn(new_position, **kwargs)

        # Metropolis acceptance criterion.
        p = jax.random.uniform(accept_key, ())
        accept_prob = jnp.minimum(1, jnp.exp(-(new_energy - energy) / kT))
        accept = p < accept_prob

        position = jax.lax.cond(accept,
                                lambda x: x[0],
                                lambda x: x[1],
                                [new_position, position])
        return MCMCState(position, key, accept)

    return init_fn, apply_fn