monte_carlo.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import jax.numpy as jnp
  2. import jax
  3. from jax_md import dataclasses, space, rigid_body
  4. from typing import Callable, TypeVar, Any
  5. import functools
# Shorthand for the JAX array type used in annotations in this module.
Array = jnp.ndarray
# Generic state type shared by InitFn and ApplyFn.
T = TypeVar('T')
# Initializer: accepts arbitrary arguments and returns a state of type T.
InitFn = Callable[..., T]
# Step function: maps a state of type T to another state of type T.
ApplyFn = Callable[[T], T]
  10. def random_unit_vector(key):
  11. key, split = jax.random.split(key)
  12. x1, x2 = jax.random.uniform(split, (2,), dtype=jnp.float64)
  13. phi = 2 * jnp.pi * x1
  14. cos_theta = 2 * x1 - 1
  15. sin_theta = jnp.sqrt(1 - cos_theta ** 2)
  16. sin_phi = jnp.sin(phi)
  17. cos_phi = jnp.cos(phi)
  18. return jnp.array([cos_phi * sin_theta, sin_phi * sin_theta, cos_theta])
  19. def random_quaternion(key, max_rotation):
  20. key, axis_key, angle_key = jax.random.split(key, 3)
  21. axis = random_unit_vector(axis_key)
  22. angle = max_rotation * jax.random.uniform(angle_key, ())
  23. sin_angle_2 = jnp.sin(angle / 2)
  24. cos_angle_2 = jnp.cos(angle / 2)
  25. q = jnp.array([cos_angle_2, sin_angle_2 * axis[0], sin_angle_2 * axis[1], sin_angle_2 * axis[2]])
  26. return rigid_body.Quaternion(q)
  27. @functools.singledispatch
  28. def mc_move(position: Array, idx: int, key: jax.random.KeyArray, moving_distance: Array, shift: space.ShiftFn) -> Array:
  29. move = moving_distance * random_unit_vector(key)
  30. return position.at[idx].set(shift(position[idx], move))
  31. @mc_move.register(rigid_body.RigidBody)
  32. def _(position: rigid_body.RigidBody,
  33. idx: int,
  34. key: jax.random.KeyArray,
  35. moving_distance: rigid_body.RigidBody,
  36. shift: space.ShiftFn) -> rigid_body.RigidBody:
  37. key, position_key, orientation_key = jax.random.split(key, 3)
  38. position_move = moving_distance.center * jax.random.normal(key, (3,))
  39. orientation_move = random_quaternion(orientation_key, moving_distance.orientation)
  40. new_position = position.center.at[idx].set(shift(position.center[idx], position_move))
  41. new_orientation_vec = position.orientation.vec.at[idx].set((orientation_move * position.orientation[idx]).vec)
  42. return rigid_body.RigidBody(new_position, rigid_body.Quaternion(new_orientation_vec))
  43. @functools.singledispatch
  44. def num_particles(position: Array):
  45. return position.shape[0]
  46. @num_particles.register(rigid_body.RigidBody)
  47. def _(position: rigid_body.RigidBody):
  48. return position.center.shape[0]
@dataclasses.dataclass
class MCMCState:
    """State carried between Monte Carlo steps."""
    # Current configuration (whatever position object was given to init_fn).
    position: Any
    # PRNG key to be split on the next step.
    key: Array
    # Whether the most recent trial move was accepted; False right after init.
    accept: bool
  54. def mc_mc(shift: space.ShiftFn,
  55. energy_fn: Callable[..., Array],
  56. kT: float,
  57. moving_distance: Array
  58. ) -> (InitFn, ApplyFn):
  59. def init_fn(key, position) -> MCMCState:
  60. return MCMCState(position, key, False)
  61. def apply_fn(state: MCMCState, **kwargs) -> MCMCState:
  62. position = state.position
  63. N = num_particles(position)
  64. # Move random particle for a random amount
  65. key, particle_key, move_key, accept_key = jax.random.split(state.key, 4)
  66. idx = jax.random.randint(particle_key, (2,), jnp.array(0), jnp.array(N))
  67. new_position = mc_move(position, idx, move_key, moving_distance, shift)
  68. # Compute the energy before the swap.
  69. energy = energy_fn(position, **kwargs)
  70. # Compute the energy after the swap.
  71. new_energy = energy_fn(new_position, **kwargs)
  72. # Accept or reject with a metropolis probability.
  73. p = jax.random.uniform(accept_key, ())
  74. accept_prob = jnp.minimum(1, jnp.exp(-(new_energy - energy) / kT))
  75. position = jax.lax.cond(p < accept_prob, lambda x: x[0], lambda x: x[1], [new_position, position])
  76. return MCMCState(position, key, p < accept_prob)
  77. return init_fn, apply_fn