123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- import jax.numpy as jnp
- import jax
- from jax_md import dataclasses, space, rigid_body
- from typing import Callable, TypeVar, Any
- import functools
# Generic state type threaded through every (init_fn, apply_fn) pair.
T = TypeVar("T")
# init_fn(...) -> state: builds the initial state from arbitrary inputs.
InitFn = Callable[..., T]
# apply_fn(state) -> state: advances the state by one step.
ApplyFn = Callable[[T], T]
# Shorthand for JAX arrays in annotations.
Array = jnp.ndarray
def random_unit_vector(key):
    """Sample a direction uniformly on the unit sphere in three dimensions.

    Uses the parametrisation ``phi ~ U[0, 2*pi)`` and ``cos(theta) ~ U[-1, 1)``
    with two *independent* uniform draws. This gives a uniform distribution on
    the sphere, which is also inversion-symmetric (``v`` and ``-v`` are equally
    likely), as a symmetric Metropolis proposal requires.

    Args:
      key: JAX PRNG key.

    Returns:
      Array of shape ``(3,)`` with unit Euclidean norm. Its dtype is JAX's
      default float (float64 only when ``jax_enable_x64`` is set).
    """
    u_phi, u_cos = jax.random.uniform(key, (2,))
    phi = 2 * jnp.pi * u_phi
    # cos(theta) must come from the second, independent sample. Deriving both
    # angles from one sample ties theta to phi and puts every direction on a
    # single spiral curve instead of covering the sphere.
    cos_theta = 2 * u_cos - 1
    sin_theta = jnp.sqrt(1 - cos_theta ** 2)
    return jnp.array([jnp.cos(phi) * sin_theta,
                      jnp.sin(phi) * sin_theta,
                      cos_theta])
def random_quaternion(key, max_rotation):
    """Draw a random rotation as a quaternion.

    The rotation axis comes from ``random_unit_vector``. The rotation angle is
    ``max_rotation`` times a uniform sample in ``[0, 1)``. The quaternion is
    ``(cos(a/2), sin(a/2) * axis)`` with the scalar part first.

    Args:
      key: JAX PRNG key.
      max_rotation: Upper bound (exclusive) on the rotation angle.

    Returns:
      ``rigid_body.Quaternion`` wrapping the 4-component vector.
    """
    _, k_axis, k_angle = jax.random.split(key, 3)
    direction = random_unit_vector(k_axis)
    theta = max_rotation * jax.random.uniform(k_angle, ())
    half_theta = theta / 2
    s = jnp.sin(half_theta)
    components = [jnp.cos(half_theta)] + [s * direction[i] for i in range(3)]
    return rigid_body.Quaternion(jnp.array(components))
@functools.singledispatch
def mc_move(position: Array, idx: int, key: Array, moving_distance: Array, shift: space.ShiftFn) -> Array:
    """Propose a trial displacement of one particle in a plain position array.

    This is the default overload. Other position types (e.g.
    ``rigid_body.RigidBody``) register their own implementations, chosen by
    the type of ``position``.

    Args:
      position: Particle positions, indexed along the leading axis. The
        displacement has 3 components (from ``random_unit_vector``), so
        positions are assumed to be 3-D.
      idx: Index of the particle to move.
      key: JAX PRNG key. It is annotated ``Array`` because
        ``jax.random.KeyArray`` has been removed from JAX, and annotations
        here are evaluated at import time.
      moving_distance: Scales a random unit vector to form the displacement.
      shift: Called as ``shift(position[idx], displacement)`` to produce the
        particle's new position.

    Returns:
      A copy of ``position`` with only row ``idx`` replaced.
    """
    move = moving_distance * random_unit_vector(key)
    return position.at[idx].set(shift(position[idx], move))
@mc_move.register(rigid_body.RigidBody)
def _(position: rigid_body.RigidBody,
      idx: int,
      key: Array,
      moving_distance: rigid_body.RigidBody,
      shift: space.ShiftFn) -> rigid_body.RigidBody:
    """Propose a joint translation and rotation of one rigid body.

    Args:
      position: Rigid bodies, each with a ``center`` row and an
        ``orientation`` quaternion.
      idx: Index of the body to move.
      key: JAX PRNG key. It is annotated ``Array`` because
        ``jax.random.KeyArray`` has been removed from JAX.
      moving_distance: ``center`` scales a standard-normal 3-vector used as
        the translation. ``orientation`` is the (exclusive) maximum rotation
        angle passed to ``random_quaternion``.
      shift: Called as ``shift(center[idx], translation)`` to produce the new
        center.

    Returns:
      A new ``rigid_body.RigidBody`` in which only body ``idx`` has changed.
    """
    _, position_key, orientation_key = jax.random.split(key, 3)
    # Each random draw uses its own dedicated subkey.
    position_move = moving_distance.center * jax.random.normal(position_key, (3,))
    orientation_move = random_quaternion(orientation_key, moving_distance.orientation)

    new_center = position.center.at[idx].set(shift(position.center[idx], position_move))
    # Left-multiply the chosen body's orientation by the random rotation.
    rotated = orientation_move * position.orientation[idx]
    new_orientation_vec = position.orientation.vec.at[idx].set(rotated.vec)
    return rigid_body.RigidBody(new_center, rigid_body.Quaternion(new_orientation_vec))
@functools.singledispatch
def num_particles(position: Array):
    """Return how many particles a configuration holds.

    Default overload for plain arrays: the count is the size of the leading
    axis. Other position types register their own overloads on this function.
    """
    shape = position.shape
    return shape[0]
@num_particles.register(rigid_body.RigidBody)
def _(position: rigid_body.RigidBody):
    """Particle count for rigid bodies: one per row of ``position.center``."""
    centers = position.center
    return centers.shape[0]
@dataclasses.dataclass
class MCMCState:
    """State carried between Metropolis Monte Carlo steps built by ``mc_mc``.

    Field order matters: instances are constructed positionally as
    ``MCMCState(position, key, accept)``.
    """
    # Current configuration: a plain array or a rigid_body.RigidBody.
    position: Any
    # PRNG key. Each apply_fn step splits it and stores a fresh subkey here.
    key: Array
    # Whether the most recent proposal was accepted. init_fn stores the Python
    # value False; apply_fn stores the JAX boolean result of `p < accept_prob`.
    accept: bool
def mc_mc(shift: space.ShiftFn,
          energy_fn: Callable[..., Array],
          kT: float,
          moving_distance: Array
          ) -> (InitFn, ApplyFn):
    """Build a single-particle Metropolis Monte Carlo sampler.

    Each call to ``apply_fn`` picks one particle uniformly at random, proposes
    a new configuration with ``mc_move``, and accepts it with probability
    ``min(1, exp(-(E_new - E_old) / kT))``. No Hastings correction is applied,
    so this assumes ``mc_move``'s proposal is symmetric.

    Args:
      shift: Displacement function ``shift(R, dR)``, passed through to
        ``mc_move``.
      energy_fn: Returns the total energy of a configuration. Keyword
        arguments given to ``apply_fn`` are forwarded to it.
      kT: Temperature in units of energy.
      moving_distance: Step size passed through to ``mc_move``. For plain
        arrays it scales a random unit vector. For ``rigid_body.RigidBody``
        positions it is a ``RigidBody`` whose ``center`` scales the
        translation and whose ``orientation`` bounds the rotation angle.

    Returns:
      ``(init_fn, apply_fn)``.
      - ``init_fn(key, position)`` creates the initial ``MCMCState``.
      - ``apply_fn(state, **kwargs)`` performs one MC step and returns the
        next ``MCMCState``.
    """
    def init_fn(key, position) -> MCMCState:
        """Wrap a starting configuration and PRNG key; nothing accepted yet."""
        return MCMCState(position, key, False)

    def apply_fn(state: MCMCState, **kwargs) -> MCMCState:
        """Perform one Metropolis step (``energy_fn`` is evaluated twice)."""
        position = state.position
        N = num_particles(position)

        key, particle_key, move_key, accept_key = jax.random.split(state.key, 4)
        # Pick ONE particle. A scalar index matches mc_move's `idx: int`
        # contract. A shape-(2,) index would displace two (possibly identical)
        # particles by the same vector.
        idx = jax.random.randint(particle_key, (), 0, N)
        new_position = mc_move(position, idx, move_key, moving_distance, shift)

        # Energies before and after the trial move.
        energy = energy_fn(position, **kwargs)
        new_energy = energy_fn(new_position, **kwargs)

        # Metropolis acceptance test.
        p = jax.random.uniform(accept_key, ())
        accept_prob = jnp.minimum(1, jnp.exp(-(new_energy - energy) / kT))
        accept = p < accept_prob
        position = jax.lax.cond(accept,
                                lambda operands: operands[0],
                                lambda operands: operands[1],
                                (new_position, position))
        return MCMCState(position, key, accept)

    return init_fn, apply_fn
|