import jax.numpy as jnp import jax.random import jax from jax_md import quantity, rigid_body, space from typing import Callable from curvature_assembly import monte_carlo, oriented_particle, smap, energy Array = jnp.ndarray def grid_init(num: int, box_size: float, initial_orient=None ) -> rigid_body.RigidBody: """ Initialize a 3D grid of particles within a box of given size. Args: num: Number of particles in the grid. box_size: The length of the box in which the grid is placed. initial_orient: Initial orientation of the particles. Default is None, which corresponds to an initial orientation quaternion (1., 0., 0., 0.). Returns: A RigidBody object containing the initial positions and orientations of the particles. """ Nmax = jnp.ceil(jnp.cbrt(num)) gridpoints_1d = jnp.arange(Nmax) * box_size / Nmax x = jnp.meshgrid(*(3 * (gridpoints_1d,))) y = jnp.vstack(list(map(jnp.ravel, x))).T position = y[:num] if initial_orient is None: initial_orient = jnp.array([1., 0., 0., 0.]) orientation = rigid_body.Quaternion(jnp.tile(initial_orient, (num, 1))) return rigid_body.RigidBody(position, orientation) def randomize_init_mc(num: int, density: float, contact_fn: Callable, mc_steps: int, kT: float, moving_distance: rigid_body.RigidBody = None, **cf_kwargs ) -> Callable[[jax.random.KeyArray], monte_carlo.MCMCState]: """ Create an MC simulation function that generates random positions and orientations of particles in a simulation box with periodic boundary conditions starting from a grid of particles. Args: num: the number of particles in the system density: the density of the system contact_fn: a function that calculates the contact distance between particles mc_steps: the number of Monte Carlo steps to take kT: the temperature parameter for Metropolis criterion moving_distance: a RigidBody object that holds the maximum distance by which a particle can move and reorientate. If not provided, a default scale is set based on the density of the simulation. **cf_kwargs: any additional keyword arguments that should be passed to the contact function Returns: A callable function that takes a jax.random.KeyArray and returns a monte_carlo.MCMCState object. """ box_size = quantity.box_size_at_number_density(num, density, spatial_dimension=3) displacement, shift = space.periodic(box_size) if moving_distance is None: # default scale for particle movement is approx 1 / 4 interparticle distance (taking into account particle size) # and default reorientation scale is pi/4 moving_distance = rigid_body.RigidBody(0.25 * (jnp.cbrt(1 / density) - jnp.cbrt(2)), jnp.pi / 4) energy_fn = oriented_particle.isotropic_to_cf_energy(energy.weeks_chandler_andersen, contact_fn, **cf_kwargs) energy_pair = smap.oriented_pair(energy_fn, displacement) energy_kwargs = {'sigma': 1, 'epsilon': 10} init_fn, apply_fn = monte_carlo.mc_mc(shift, energy_pair, kT, moving_distance) grid_state = grid_init(num, box_size) @jax.jit def scan_fn(state, i): state = apply_fn(state, **energy_kwargs) return state, state.accept def mc_simulation(key): init_state = init_fn(key, grid_state) state, accept_array = jax.lax.scan(scan_fn, init=init_state, xs=jnp.arange(mc_steps)) # print(jnp.mean(jnp.array(accept_array, dtype=jnp.float32))) return state return mc_simulation def rdf(displacement_or_metric: space.DisplacementOrMetricFn, positions: Array, density: float, r_min: float, r_max: float, num_bins: int) -> tuple[Array, Array]: """ Calculate the radial distribution function (RDF) of a set of particles in a simulation box. Args: displacement_or_metric: Displacement or metric function positions: An array of shape (num_particles, 3) containing the positions of the particles. density: number density of particles in the system r_min: The minimum radial distance to consider in the RDF calculation. r_max: The maximum radial distance to consider in the RDF calculation. num_bins: The number of bins to use in the RDF calculation. Returns: An array of shape (num_bins,) containing the midpoints of the radial distance bins and an array of shape (num_bins,) containing the values of the RDF for each bin. """ # Define the bin edges for the RDF bin_edges = jnp.linspace(r_min, r_max, num_bins + 1) # Create a histogram of the pairwise distances between particles metric = space.canonicalize_displacement_or_metric(displacement_or_metric) pairwise_distances = space.map_product(metric)(positions, positions) i, j = jnp.triu_indices(pairwise_distances.shape[0], 1) histogram, _ = jnp.histogram(pairwise_distances[i, j].flatten(), bins=bin_edges) # Calculate the RDF bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 bin_volumes = 4 / 3 * jnp.pi * (bin_edges[1:] ** 3 - bin_edges[:-1] ** 3) rdf = histogram / (density * bin_volumes * positions.shape[0] / 2) return bin_centers, rdf