initial_conditions.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
from typing import Callable, Optional

import jax
import jax.numpy as jnp
import jax.random
from jax_md import quantity, rigid_body, space

from curvature_assembly import monte_carlo, oriented_particle, smap, energy
  7. Array = jnp.ndarray
  8. def grid_init(num: int,
  9. box_size: float,
  10. initial_orient=None
  11. ) -> rigid_body.RigidBody:
  12. """
  13. Initialize a 3D grid of particles within a box of given size.
  14. Args:
  15. num: Number of particles in the grid.
  16. box_size: The length of the box in which the grid is placed.
  17. initial_orient: Initial orientation of the particles. Default is None,
  18. which corresponds to an initial orientation quaternion (1., 0., 0., 0.).
  19. Returns:
  20. A RigidBody object containing the initial positions and orientations of the particles.
  21. """
  22. Nmax = jnp.ceil(jnp.cbrt(num))
  23. gridpoints_1d = jnp.arange(Nmax) * box_size / Nmax
  24. x = jnp.meshgrid(*(3 * (gridpoints_1d,)))
  25. y = jnp.vstack(list(map(jnp.ravel, x))).T
  26. position = y[:num]
  27. if initial_orient is None:
  28. initial_orient = jnp.array([1., 0., 0., 0.])
  29. orientation = rigid_body.Quaternion(jnp.tile(initial_orient, (num, 1)))
  30. return rigid_body.RigidBody(position, orientation)
  31. def randomize_init_mc(num: int,
  32. density: float,
  33. contact_fn: Callable,
  34. mc_steps: int,
  35. kT: float,
  36. moving_distance: rigid_body.RigidBody = None,
  37. **cf_kwargs
  38. ) -> Callable[[jax.random.KeyArray], monte_carlo.MCMCState]:
  39. """
  40. Create an MC simulation function that generates random positions and orientations of particles in a simulation box
  41. with periodic boundary conditions starting from a grid of particles.
  42. Args:
  43. num: the number of particles in the system
  44. density: the density of the system
  45. contact_fn: a function that calculates the contact distance between particles
  46. mc_steps: the number of Monte Carlo steps to take
  47. kT: the temperature parameter for Metropolis criterion
  48. moving_distance: a RigidBody object that holds the maximum distance by which a particle can move and
  49. reorientate. If not provided, a default scale is set based on the density of the simulation.
  50. **cf_kwargs: any additional keyword arguments that should be passed to the contact function
  51. Returns:
  52. A callable function that takes a jax.random.KeyArray and returns a monte_carlo.MCMCState object.
  53. """
  54. box_size = quantity.box_size_at_number_density(num, density, spatial_dimension=3)
  55. displacement, shift = space.periodic(box_size)
  56. if moving_distance is None:
  57. # default scale for particle movement is approx 1 / 4 interparticle distance (taking into account particle size)
  58. # and default reorientation scale is pi/4
  59. moving_distance = rigid_body.RigidBody(0.25 * (jnp.cbrt(1 / density) - jnp.cbrt(2)), jnp.pi / 4)
  60. energy_fn = oriented_particle.isotropic_to_cf_energy(energy.weeks_chandler_andersen, contact_fn, **cf_kwargs)
  61. energy_pair = smap.oriented_pair(energy_fn, displacement)
  62. energy_kwargs = {'sigma': 1, 'epsilon': 10}
  63. init_fn, apply_fn = monte_carlo.mc_mc(shift, energy_pair, kT, moving_distance)
  64. grid_state = grid_init(num, box_size)
  65. @jax.jit
  66. def scan_fn(state, i):
  67. state = apply_fn(state, **energy_kwargs)
  68. return state, state.accept
  69. def mc_simulation(key):
  70. init_state = init_fn(key, grid_state)
  71. state, accept_array = jax.lax.scan(scan_fn, init=init_state, xs=jnp.arange(mc_steps))
  72. # print(jnp.mean(jnp.array(accept_array, dtype=jnp.float32)))
  73. return state
  74. return mc_simulation
  75. def rdf(displacement_or_metric: space.DisplacementOrMetricFn,
  76. positions: Array,
  77. density: float,
  78. r_min: float,
  79. r_max: float,
  80. num_bins: int) -> tuple[Array, Array]:
  81. """
  82. Calculate the radial distribution function (RDF) of a set of particles in a simulation box.
  83. Args:
  84. displacement_or_metric: Displacement or metric function
  85. positions: An array of shape (num_particles, 3) containing the positions of the particles.
  86. density: number density of particles in the system
  87. r_min: The minimum radial distance to consider in the RDF calculation.
  88. r_max: The maximum radial distance to consider in the RDF calculation.
  89. num_bins: The number of bins to use in the RDF calculation.
  90. Returns:
  91. An array of shape (num_bins,) containing the midpoints of the radial distance bins and an array
  92. of shape (num_bins,) containing the values of the RDF for each bin.
  93. """
  94. # Define the bin edges for the RDF
  95. bin_edges = jnp.linspace(r_min, r_max, num_bins + 1)
  96. # Create a histogram of the pairwise distances between particles
  97. metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
  98. pairwise_distances = space.map_product(metric)(positions, positions)
  99. i, j = jnp.triu_indices(pairwise_distances.shape[0], 1)
  100. histogram, _ = jnp.histogram(pairwise_distances[i, j].flatten(), bins=bin_edges)
  101. # Calculate the RDF
  102. bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
  103. bin_volumes = 4 / 3 * jnp.pi * (bin_edges[1:] ** 3 - bin_edges[:-1] ** 3)
  104. rdf = histogram / (density * bin_volumes * positions.shape[0] / 2)
  105. return bin_centers, rdf