123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
- import jax.numpy as jnp
- import jax.random
- import jax
- from jax_md import quantity, rigid_body, space
- from typing import Callable
- from curvature_assembly import monte_carlo, oriented_particle, smap, energy
- Array = jnp.ndarray
def grid_init(num: int,
              box_size: float,
              initial_orient=None
              ) -> rigid_body.RigidBody:
    """
    Initialize a 3D cubic grid of particles within a box of given size.

    The smallest cubic lattice with at least ``num`` sites is laid out with
    spacing ``box_size / n_side`` along each axis (starting at the origin),
    and the first ``num`` lattice sites are kept.

    Args:
        num: Number of particles in the grid.
        box_size: The length of the box in which the grid is placed.
        initial_orient: Initial orientation quaternion shared by all particles.
            Default is None, which corresponds to the orientation quaternion
            (1., 0., 0., 0.).

    Returns:
        A RigidBody object containing the initial positions (an array with
        ``num`` rows and 3 columns) and the orientations of the particles.
    """
    # Number of grid points per side: the smallest integer n with n**3 >= num.
    # Computed with exact integer arithmetic: a floating-point cube root
    # (float32 by default in JAX) can come out slightly above an integer for
    # perfect cubes, and ceil() would then add an unnecessary extra layer,
    # leaving the box only partially filled.
    n_side = max(1, round(num ** (1 / 3)))
    while n_side ** 3 < num:
        n_side += 1
    while n_side > 1 and (n_side - 1) ** 3 >= num:
        n_side -= 1

    gridpoints_1d = jnp.arange(n_side) * box_size / n_side
    mesh = jnp.meshgrid(*(3 * (gridpoints_1d,)))
    # Flatten each coordinate mesh and stack into (n_side**3, 3), keep first num.
    position = jnp.vstack([jnp.ravel(coord) for coord in mesh]).T[:num]

    if initial_orient is None:
        initial_orient = jnp.array([1., 0., 0., 0.])
    orientation = rigid_body.Quaternion(jnp.tile(initial_orient, (num, 1)))
    return rigid_body.RigidBody(position, orientation)
def randomize_init_mc(num: int,
                      density: float,
                      contact_fn: Callable,
                      mc_steps: int,
                      kT: float,
                      moving_distance: rigid_body.RigidBody = None,
                      **cf_kwargs
                      ) -> Callable[[jax.Array], monte_carlo.MCMCState]:
    """
    Create an MC simulation function that randomizes particle positions and orientations.

    The returned function starts from a cubic grid of particles (see ``grid_init``)
    in a periodic simulation box and runs ``mc_steps`` Monte Carlo steps with a
    Weeks-Chandler-Andersen energy evaluated on the contact distance given by
    ``contact_fn``.

    Args:
        num: the number of particles in the system
        density: the number density of the system; sets the box size
        contact_fn: a function that calculates the contact distance between particles
        mc_steps: the number of Monte Carlo steps to take
        kT: the temperature parameter for the Metropolis criterion
        moving_distance: a RigidBody object that holds the maximum distance by which
            a particle can move and reorientate. If not provided, a default scale
            is set based on the density of the simulation.
        **cf_kwargs: any additional keyword arguments that should be passed to the
            contact function

    Returns:
        A callable that takes a JAX PRNG key and returns the final
        monte_carlo.MCMCState after ``mc_steps`` steps.
    """
    box_size = quantity.box_size_at_number_density(num, density, spatial_dimension=3)
    displacement, shift = space.periodic(box_size)

    if moving_distance is None:
        # Default scale for particle movement is approx 1/4 of the interparticle
        # distance (taking into account particle size); default reorientation
        # scale is pi/4.
        # NOTE(review): the translation scale is 0 at density == 0.5 and negative
        # above it (cbrt(1/density) <= cbrt(2)); confirm this is intended for
        # dense systems.
        moving_distance = rigid_body.RigidBody(0.25 * (jnp.cbrt(1 / density) - jnp.cbrt(2)), jnp.pi / 4)

    energy_fn = oriented_particle.isotropic_to_cf_energy(energy.weeks_chandler_andersen, contact_fn, **cf_kwargs)
    energy_pair = smap.oriented_pair(energy_fn, displacement)
    # Fixed WCA parameters passed to every energy evaluation.
    energy_kwargs = {'sigma': 1, 'epsilon': 10}

    init_fn, apply_fn = monte_carlo.mc_mc(shift, energy_pair, kT, moving_distance)
    grid_state = grid_init(num, box_size)

    @jax.jit
    def scan_fn(state, i):
        # One MC step; the per-step acceptance is emitted as the scan output.
        state = apply_fn(state, **energy_kwargs)
        return state, state.accept

    def mc_simulation(key):
        init_state = init_fn(key, grid_state)
        state, _ = jax.lax.scan(scan_fn, init=init_state, xs=jnp.arange(mc_steps))
        return state

    return mc_simulation
def rdf(displacement_or_metric: space.DisplacementOrMetricFn,
        positions: Array,
        density: float,
        r_min: float,
        r_max: float,
        num_bins: int) -> tuple[Array, Array]:
    """
    Compute the radial distribution function g(r) of a particle configuration.

    Every unordered pair of distinct particles is counted once. The pair
    distances are histogrammed into ``num_bins`` equal-width spherical shells
    between ``r_min`` and ``r_max``. The counts are normalized by
    ``density * shell_volume * N / 2``, where N is the number of particles.

    Args:
        displacement_or_metric: Displacement or metric function.
        positions: An array of shape (num_particles, 3) containing the positions of the particles.
        density: Number density of particles in the system.
        r_min: The minimum radial distance to consider in the RDF calculation.
        r_max: The maximum radial distance to consider in the RDF calculation.
        num_bins: The number of bins to use in the RDF calculation.

    Returns:
        A tuple ``(bin_centers, g)``: an array of shape (num_bins,) with the
        midpoints of the radial bins, and an array of shape (num_bins,) with the
        RDF value in each bin.
    """
    n_particles = positions.shape[0]

    metric_fn = space.canonicalize_displacement_or_metric(displacement_or_metric)
    distance_matrix = space.map_product(metric_fn)(positions, positions)

    # Strictly upper-triangular entries: each distinct pair exactly once, with
    # the zero self-distances on the diagonal excluded.
    rows, cols = jnp.triu_indices(distance_matrix.shape[0], k=1)
    pair_distances = distance_matrix[rows, cols]

    edges = jnp.linspace(r_min, r_max, num_bins + 1)
    counts, _ = jnp.histogram(pair_distances, bins=edges)

    inner, outer = edges[:-1], edges[1:]
    centers = 0.5 * (inner + outer)
    shell_volumes = (4 / 3) * jnp.pi * (outer ** 3 - inner ** 3)
    ideal_counts = density * shell_volumes * n_particles / 2
    return centers, counts / ideal_counts
|