"""
Set-up and execution of NVT rigid-body simulations with jax_md (Nose-Hoover and Langevin thermostats), including
logging of observables, storage of configuration history, and truncated backpropagation-through-time (BPTT)
gradient computation with respect to interaction parameters.
"""
from __future__ import annotations

import copy
import time
import warnings
from functools import partial
from typing import Any, Callable, Optional, TypeVar

import jax
import jax.numpy as jnp
from jax import jit, lax, random
from jax_md import dataclasses, partition, quantity, rigid_body, simulate, space

from curvature_assembly import cost_functions, data_protocols, oriented_particle, util

# import equinox

Array = jnp.ndarray
NeighborFn = partition.NeighborFn
NeighborListFormat = partition.NeighborListFormat

T = TypeVar('T')
InitFn = Callable[..., T]
ApplyFn = Callable[[T], T]
RigidBody = rigid_body.RigidBody
InteractionParams = data_protocols.InteractionParams
P = TypeVar('P', bound=InteractionParams)


@dataclasses.dataclass
class NVTSimulationParams:
    """
    Container for NVT simulation parameters.

    Attributes:
        num: number of particles (used as the particle count of the configuration history).
        density: system density (not read anywhere in this module).
        simulation_steps: total number of integration steps.
        dt: integration time step.
        kT: thermostat temperature.
        config_every: interval (in steps) at which observables and configurations are recorded.
        bptt_truncation: number of steps in one truncated-BPTT section.
    """
    num: int
    density: float
    simulation_steps: int
    dt: float
    kT: float
    config_every: int = 100
    bptt_truncation: int = 500


def get_higher_temp_equilibration_params(sim_params: NVTSimulationParams, new_kT: float) -> NVTSimulationParams:
    """
    Return new NVT simulation parameters for an equilibration simulation at temperature ``new_kT``.

    All other parameters are copied from ``sim_params``, except ``bptt_truncation``, which is set to the full
    number of simulation steps so that the equilibration run forms a single BPTT section.
    """
    params_dict = copy.deepcopy(vars(sim_params))
    params_dict['bptt_truncation'] = sim_params.simulation_steps
    params_dict['kT'] = new_kT
    return NVTSimulationParams(**params_dict)


@dataclasses.dataclass
class SimulationLogNoseHoover:
    """Dataclass for storing observables, invariants etc. during a simulation with a Nose-Hoover thermostat."""
    T: Array  # temperature
    E: Array  # energy returned by energy_fn(position, **params)
    K: Array  # kinetic energy
    H: Array  # Nose-Hoover chain invariant
    current_len: Array  # number of entries filled so far

    @staticmethod
    def create_empty(num_steps: int, save_every: int) -> SimulationLogNoseHoover:
        """Allocate zero-filled log arrays with one entry per ``save_every`` steps of a ``num_steps`` run."""
        size = num_steps // save_every
        return SimulationLogNoseHoover(T=jnp.zeros(size), E=jnp.zeros(size), K=jnp.zeros(size), H=jnp.zeros(size),
                                       current_len=0)

    def calculate_values(self, state, energy_fn, ellipsoid_mass, kT, **params
                         ) -> tuple[Array, Array, Array, Array]:
        """Compute temperature, energy, kinetic energy and the Nose-Hoover invariant for ``state``."""
        T = rigid_body.temperature(position=state.position, momentum=state.momentum, mass=ellipsoid_mass)
        E = energy_fn(state.position, **params)
        K = rigid_body.kinetic_energy(position=state.position, momentum=state.momentum, mass=ellipsoid_mass)
        H = simulate.nvt_nose_hoover_invariant(energy_fn, state, kT, **params)
        return T, E, K, H

    def update(self, T: Array, E: Array, K: Array, H: Array) -> SimulationLogNoseHoover:
        """Store the given values at position ``current_len`` and advance ``current_len`` by one."""
        idx = self.current_len
        return dataclasses.replace(self,
                                   T=self.T.at[idx].set(T),
                                   E=self.E.at[idx].set(E),
                                   K=self.K.at[idx].set(K),
                                   H=self.H.at[idx].set(H),
                                   current_len=idx + 1)

    def revert_last_nsteps(self, nsteps) -> SimulationLogNoseHoover:
        """Discard the last ``nsteps`` entries (they are overwritten by subsequent updates)."""
        return dataclasses.replace(self, current_len=self.current_len - nsteps)


@dataclasses.dataclass
class SimulationLogLangevin:
    """Dataclass for storing observables, invariants etc. during a simulation with a Langevin thermostat."""
    T: Array  # temperature
    E: Array  # energy returned by energy_fn(position, **params)
    K: Array  # kinetic energy
    current_len: Array  # number of entries filled so far

    @staticmethod
    def create_empty(num_steps: int, save_every: int) -> SimulationLogLangevin:
        """Allocate zero-filled log arrays with one entry per ``save_every`` steps of a ``num_steps`` run."""
        size = num_steps // save_every
        return SimulationLogLangevin(T=jnp.zeros(size), E=jnp.zeros(size), K=jnp.zeros(size), current_len=0)

    def calculate_values(self, state, energy_fn, ellipsoid_mass, kT, **params) -> tuple[Array, Array, Array]:
        """
        Compute temperature, energy and kinetic energy for ``state``.

        ``kT`` is unused; it is accepted so that the signature matches ``SimulationLogNoseHoover.calculate_values``.
        """
        T = rigid_body.temperature(position=state.position, momentum=state.momentum, mass=ellipsoid_mass)
        E = energy_fn(state.position, **params)
        K = rigid_body.kinetic_energy(position=state.position, momentum=state.momentum, mass=ellipsoid_mass)
        return T, E, K

    def update(self, T: Array, E: Array, K: Array) -> SimulationLogLangevin:
        """Store the given values at position ``current_len`` and advance ``current_len`` by one."""
        idx = self.current_len
        return dataclasses.replace(self,
                                   T=self.T.at[idx].set(T),
                                   E=self.E.at[idx].set(E),
                                   K=self.K.at[idx].set(K),
                                   current_len=idx + 1)

    def revert_last_nsteps(self, nsteps) -> SimulationLogLangevin:
        """Discard the last ``nsteps`` entries (they are overwritten by subsequent updates)."""
        return dataclasses.replace(self, current_len=self.current_len - nsteps)


@dataclasses.dataclass
class SimulationStateHistory:
    """Dataclass for storing particle configurations during a simulation."""
    coord: Array  # (num_saved, n_particles, 3) particle centers
    orient: Array  # (num_saved, n_particles, 4) orientation quaternion components
    current_len: Array  # number of configurations stored so far

    @staticmethod
    def create_empty(num_steps: int, n_particles: int, config_every: int) -> SimulationStateHistory:
        """Allocate zero-filled storage for one configuration per ``config_every`` steps of a ``num_steps`` run."""
        num_saved = num_steps // config_every
        coord = jnp.zeros((num_saved, n_particles, 3))
        orient = jnp.zeros((num_saved, n_particles, 4))
        return SimulationStateHistory(coord, orient, 0)

    def update(self, coord: Array, orient: Array) -> SimulationStateHistory:
        """Store a configuration at position ``current_len`` and advance ``current_len`` by one."""
        idx = self.current_len
        return dataclasses.replace(self,
                                   coord=self.coord.at[idx].set(coord),
                                   orient=self.orient.at[idx].set(orient),
                                   current_len=idx + 1)

    def revert_last_nsteps(self, nsteps) -> SimulationStateHistory:
        """Discard the last ``nsteps`` stored configurations (they are overwritten by subsequent updates)."""
        return dataclasses.replace(self, current_len=self.current_len - nsteps)


@dataclasses.dataclass
class SimulationAux:
    """Dataclass for simulation auxiliary data: an observables log and a configuration history."""
    log: data_protocols.SimulationLog
    state_history: data_protocols.SimulationStateHistory

    def revert_last_nsteps(self, nsteps, config_every):
        """Revert the log by ``nsteps`` entries and the state history by ``nsteps // config_every`` entries."""
        # NOTE(review): simulation_step records log entries and configurations together (every config_every
        # steps), yet the log is reverted by nsteps while the history by nsteps // config_every. Confirm that
        # this asymmetry is what callers expect.
        return dataclasses.replace(self,
                                   log=self.log.revert_last_nsteps(nsteps),
                                   state_history=self.state_history.revert_last_nsteps(nsteps // config_every))

    def reset_empty(self) -> SimulationAux:
        """
        Set current_len attribute of SimulationLog and SimulationStateHistory classes to 0 which effectively
        resets their empty state (current data will be overwritten in the next bptt simulation run).
        """
        # we use zeros_like() because of possible parallelization that adds an axis to current_len attribute
        empty_log = dataclasses.replace(self.log, current_len=jnp.zeros_like(self.log.current_len))
        empty_history = dataclasses.replace(self.state_history,
                                            current_len=jnp.zeros_like(self.state_history.current_len))
        return dataclasses.replace(self, log=empty_log, state_history=empty_history)


def setup_nose_hoover(energy: Callable,
                      shift: space.ShiftFn,
                      simulation_params: NVTSimulationParams,
                      **nose_hoover_kwargs) -> tuple[InitFn, ApplyFn, SimulationAux]:
    """
    Prepare functions and auxiliary data container for a molecular dynamics simulation using the Nose-Hoover
    thermostat.

    Args:
        energy: energy function passed to ``simulate.nvt_nose_hoover``.
        shift: shift function of the simulation space.
        simulation_params: NVT simulation parameters (dt, kT, number of steps, logging interval...).
        **nose_hoover_kwargs: extra keyword arguments forwarded to ``simulate.nvt_nose_hoover``.

    Returns:
        ``(init_fn, step_fn, aux)`` where ``aux`` holds an empty Nose-Hoover log and configuration history.
    """
    log = SimulationLogNoseHoover.create_empty(simulation_params.simulation_steps, simulation_params.config_every)
    state_history = SimulationStateHistory.create_empty(simulation_params.simulation_steps,
                                                        simulation_params.num,
                                                        simulation_params.config_every)
    aux = SimulationAux(log=log, state_history=state_history)
    init_fn, step_fn = simulate.nvt_nose_hoover(energy, shift, simulation_params.dt, simulation_params.kT,
                                                **nose_hoover_kwargs)
    return init_fn, step_fn, aux


def setup_langevin(energy: Callable,
                   shift: space.ShiftFn,
                   simulation_params: NVTSimulationParams,
                   gamma: RigidBody = RigidBody(0.1, 0.1),
                   **langevin_kwargs) -> tuple[InitFn, ApplyFn, SimulationAux]:
    """
    Prepare functions and auxiliary data container for a molecular dynamics simulation using the Langevin
    thermostat.

    Args:
        energy: energy function passed to ``simulate.nvt_langevin``.
        shift: shift function of the simulation space.
        simulation_params: NVT simulation parameters (dt, kT, number of steps, logging interval...).
        gamma: friction coefficients for the center and orientation degrees of freedom.
        **langevin_kwargs: extra keyword arguments forwarded to ``simulate.nvt_langevin``.

    Returns:
        ``(init_fn, step_fn, aux)`` where ``aux`` holds an empty Langevin log and configuration history.
    """
    # This single definition replaces two previous, conflicting definitions of setup_langevin (the later one
    # silently shadowed the earlier one); it accepts both calling conventions.
    log = SimulationLogLangevin.create_empty(simulation_params.simulation_steps, simulation_params.config_every)
    state_history = SimulationStateHistory.create_empty(simulation_params.simulation_steps,
                                                        simulation_params.num,
                                                        simulation_params.config_every)
    aux = SimulationAux(log=log, state_history=state_history)
    init_fn, step_fn = simulate.nvt_langevin(energy, shift, simulation_params.dt, simulation_params.kT,
                                             gamma=gamma, **langevin_kwargs)
    return init_fn, step_fn, aux


def rescale_momenta_new_temperature(state: simulate.NVTNoseHooverState,
                                    new_kT: float,
                                    old_kT: float) -> simulate.NVTNoseHooverState:
    """
    Scale translational and rotational momenta by ``sqrt(new_kT / old_kT)``, so the kinetic energy (quadratic in
    momenta) is rescaled by ``new_kT / old_kT``.
    """
    factor = jnp.sqrt(new_kT / old_kT)
    new_momentum_center = factor * state.momentum.center
    new_momentum_orientation = factor * state.momentum.orientation.vec
    return state.set(momentum=RigidBody(new_momentum_center, rigid_body.Quaternion(new_momentum_orientation)))


def init_nose_hoover_new_temperature(state: simulate.NVTNoseHooverState,
                                     new_kT: float,
                                     old_kT: float,
                                     dt: float,
                                     chain_length: int = 5,
                                     chain_steps: int = 2,
                                     sy_steps: int = 3,
                                     tau: Optional[float] = None) -> simulate.NVTNoseHooverState:
    """
    Move an existing Nose-Hoover state to temperature ``new_kT``.

    Momenta are rescaled from ``old_kT`` to ``new_kT`` and the thermostat chain is re-initialized at ``new_kT``.
    ``tau`` defaults to ``100 * dt``.
    """
    dt = simulate.f32(dt)
    if tau is None:
        tau = dt * 100
    tau = simulate.f32(tau)
    thermostat = simulate.nose_hoover_chain(dt, chain_length, chain_steps, sy_steps, tau)

    dof = quantity.count_dof(state.position)
    state = rescale_momenta_new_temperature(state, new_kT, old_kT)
    KE = simulate.kinetic_energy(state)
    return state.set(chain=thermostat.initialize(dof, KE, new_kT))


def ellipsoid_unit_mass(eigvals: Array):
    """Ellipsoid mass object for particles with unit mass and the given ellipsoid eigenvalues."""
    return oriented_particle.ellipsoid_mass(jnp.array([1.]), eigvals)


def simulation_step(state_aux_params: tuple[T, data_protocols.SimulationAux, InteractionParams],
                    iteration_idx: int,
                    step_fn: Callable,
                    energy_fn: Callable,
                    kT: float,
                    config_every: int) -> tuple[tuple[T, data_protocols.SimulationAux, InteractionParams], float]:
    """
    Perform one simulation step and log the progress. Designed to be used as a ``lax.scan`` body.

    Args:
        state_aux_params: scan carry ``(state, aux, params)``. ``params`` is forwarded to ``step_fn`` and
            ``energy_fn`` as keyword arguments through ``vars(params)``.
        iteration_idx: index of the current step (the scanned-over value).
        step_fn: integrator step function.
        energy_fn: energy function used for logging.
        kT: temperature forwarded to ``log.calculate_values``.
        config_every: observables and the configuration are recorded when ``(iteration_idx + 1)`` is divisible
            by this value.

    Returns:
        The updated carry ``(state, aux, params)`` and a dummy per-step output ``0.``.
    """
    state, aux, params = state_aux_params
    log = aux.log
    state_history = aux.state_history

    # take a simulation step; params must be passed as a dictionary
    state = step_fn(state, **vars(params))

    def update_aux(l, h):
        values = l.calculate_values(state, energy_fn, ellipsoid_unit_mass(params.eigvals), kT, **vars(params))
        new_log = l.update(*values)
        new_history = h.update(state.position.center, state.position.orientation.vec)
        return new_log, new_history

    # log information about simulation as well as the state history
    log, state_history = lax.cond((iteration_idx + 1) % config_every == 0,
                                  update_aux,
                                  lambda l, h: (l, h),
                                  log, state_history)

    aux = dataclasses.replace(aux, log=log, state_history=state_history)
    return (state, aux, params), 0.


def nvt_simulation_pair(init_fn: InitFn,
                        step_fn: ApplyFn,
                        aux: SimulationAux,
                        energy: Callable[..., Array],
                        interaction_params: InteractionParams,
                        simulation_params: NVTSimulationParams,
                        body: RigidBody) -> tuple[RigidBody, SimulationAux]:
    """
    Run a single (non-differentiated) NVT simulation of ``simulation_params.simulation_steps`` steps from ``body``.

    The state is initialized with a fixed PRNG key (0) and unit particle masses. Prints the wall-clock time of
    the run.

    Returns:
        The final particle positions and the filled auxiliary data (log and configuration history).
    """
    # set all particle masses to 1
    ellipsoid_mass = ellipsoid_unit_mass(interaction_params.eigvals)

    # setup simulation (simulation_step takes no mass argument; it computes unit masses itself when logging)
    scan_step = partial(simulation_step,
                        step_fn=step_fn,
                        energy_fn=energy,
                        kT=simulation_params.kT,
                        config_every=simulation_params.config_every)

    # initialize state; interaction params are passed so the initial forces use the same potential as the steps
    key = random.PRNGKey(0)
    state = init_fn(key, body, mass=ellipsoid_mass, **vars(interaction_params))

    @jit
    def scan_to_jit(state_aux_params, steps):
        new_state_aux_params, _ = lax.scan(scan_step, state_aux_params, steps)
        return new_state_aux_params

    # run simulation (the measured time includes JIT compilation, as scan_to_jit is created anew on every call)
    print('Simulation start')
    t0 = time.perf_counter()
    state, aux, _ = scan_to_jit((state, aux, interaction_params), jnp.arange(simulation_params.simulation_steps))
    t1 = time.perf_counter()
    print(f'Simulation time: {t1 - t0}')

    return state.position, aux


@dataclasses.dataclass
class BPTTResults:
    """Per-section results of a truncated BPTT run: gradients of the cost and the cost values."""
    grad: InteractionParams  # each field has a leading axis over BPTT sections
    cost: Array  # cost at the end of each BPTT section
    current_len: int  # number of sections stored so far

    @staticmethod
    def create_empty(interaction_params: InteractionParams, n_steps: int) -> BPTTResults:
        """Allocate zero-filled storage for ``n_steps`` BPTT sections."""
        grad = empty_grad_results(interaction_params, n_steps)
        cost = jnp.zeros((n_steps,))
        return BPTTResults(grad, cost, 0)

    def update(self, grad: InteractionParams, cost: float) -> BPTTResults:
        """Store the gradient and cost of one section at position ``current_len`` and advance it by one."""
        idx = self.current_len
        new_grad = jax.tree_util.tree_map(partial(update_gradient_history, idx=idx), self.grad, grad)
        return dataclasses.replace(self,
                                   grad=new_grad,
                                   cost=self.cost.at[idx].set(cost),
                                   current_len=idx + 1)


def empty_grad_results(interaction_params: P, num_rep: int) -> P:
    """
    Initializes an interactions parameters class to store gradients after each section of a truncated BPTT run.

    Every field gets a zero array of shape ``(num_rep,) + value.shape``; fields without a ``shape`` attribute
    (e.g. Python scalars) get shape ``(num_rep,)``.
    """
    history_dict = {}
    for key, value in vars(interaction_params).items():
        try:
            history_dict[key] = jnp.zeros((num_rep,) + value.shape)
        except AttributeError:
            history_dict[key] = jnp.zeros((num_rep,))
    return type(interaction_params)(**history_dict)


def update_gradient_history(history_array: Array, grad_value: Array, idx: int) -> Array:
    """Update a gradient history array at idx with a new gradient value."""
    return history_array.at[idx].set(grad_value)


def simple_forward_simulation(init_fn: InitFn,
                              step_fn: ApplyFn,
                              num_steps: int,
                              ) -> Callable[[InteractionParams, RigidBody, Array], Any]:
    """
    Elementary forward MD simulation, without any logging and configuration saving.

    Returns:
        A function ``simulation(interaction_params, body, key)`` that initializes a state with unit masses and
        runs ``num_steps`` integration steps, returning the final simulation state.
    """

    def simulation(interaction_params: InteractionParams, body: RigidBody, key: jax.random.PRNGKey):
        # initialize state
        state = init_fn(key, body, mass=ellipsoid_unit_mass(interaction_params.eigvals), **vars(interaction_params))

        def scan_step(state, i):
            return step_fn(state, **vars(interaction_params)), 0.

        state, _ = lax.scan(scan_step, state, jnp.arange(num_steps))
        return state

    return simulation


def truncated_bptt_nvt_simulation(step_fn: ApplyFn,
                                  energy: Callable[..., Array],
                                  cost_fn: cost_functions.CostFn,
                                  simulation_params: NVTSimulationParams,
                                  only_forward_calculation: bool = False) -> data_protocols.BpttSimulation:
    """
    Build a simulation function that computes cost gradients with truncated backpropagation through time.

    The run is split into ``simulation_steps // bptt_truncation`` sections of ``bptt_truncation`` steps. For each
    section the cost is evaluated on the final positions and differentiated with respect to the interaction
    parameters; the incoming state of a section is not differentiated, so gradients do not flow across section
    boundaries.

    Args:
        step_fn: integrator step function.
        energy: energy function used for logging.
        cost_fn: cost evaluated on the positions at the end of each section.
        simulation_params: NVT simulation parameters.
        only_forward_calculation: if True, skip differentiation and store zero gradients.

    Returns:
        A function ``simulation(interaction_params, init_state, aux)`` returning ``(bptt_results, aux)``.

    Raises:
        ValueError: if ``simulation_steps < config_every``, if ``bptt_truncation`` is not positive, or if
            ``simulation_steps < bptt_truncation``.
    """
    if simulation_params.simulation_steps < simulation_params.config_every:
        raise ValueError(f'Number of simulation steps must be higher or equal to the config_every value, '
                         f'got {simulation_params.simulation_steps} and {simulation_params.config_every}, '
                         f'respectively')
    if simulation_params.bptt_truncation <= 0:
        raise ValueError(f'BPTT truncation must be a positive number of steps, '
                         f'got {simulation_params.bptt_truncation}.')

    n_iterations = simulation_params.simulation_steps // simulation_params.bptt_truncation
    if n_iterations == 0:
        raise ValueError(f'Number of simulation steps must be equal to or greater than BPTT truncation, '
                         f'got {simulation_params.simulation_steps} and {simulation_params.bptt_truncation}, '
                         f'respectively.')
    if n_iterations * simulation_params.bptt_truncation < simulation_params.simulation_steps:
        warnings.warn(f'Only {n_iterations * simulation_params.bptt_truncation} time steps will be calculated '
                      f'as bptt truncation length does not divide the desired number of steps exactly.')

    # simulation setup
    scan_step = partial(simulation_step,
                        step_fn=step_fn,
                        energy_fn=energy,
                        kT=simulation_params.kT,
                        config_every=simulation_params.config_every)
    # loop_fn = partial(equinox.internal.scan, kind='checkpointed', checkpoints=10)

    def forward_function(params: InteractionParams, state, aux):
        """Run one BPTT section and return the cost at its end, with ``(state, aux)`` as auxiliary output."""
        (state, aux, _), _ = lax.scan(scan_step, (state, aux, params), jnp.arange(simulation_params.bptt_truncation))
        cost = cost_fn(state.position, **vars(params))
        return cost, (state, aux)

    grad_fn = jax.value_and_grad(forward_function, has_aux=True, argnums=(0,))

    def bptt_section(carry, i):
        """Run one section and record its cost and the cost gradient w.r.t. the interaction parameters."""
        state, aux, params, results = carry
        (cost, (state, aux)), grad = grad_fn(params, state, aux)
        results = results.update(grad[0], cost)
        return (state, aux, params, results), 0.

    def bptt_section_forward_only(carry, i):
        """Run one section and record its cost with zero gradients (no differentiation)."""
        state, aux, params, results = carry
        cost, (state, aux) = forward_function(params, state, aux)
        results = results.update(jax.tree_util.tree_map(jnp.zeros_like, params), cost)
        return (state, aux, params, results), 0.

    section_fn = bptt_section_forward_only if only_forward_calculation else bptt_section

    def simulation(interaction_params: InteractionParams,
                   init_state: Any,
                   aux: data_protocols.SimulationAux,
                   ):
        # the initial state is built by the caller and passed in as init_state
        bptt_results = BPTTResults.create_empty(interaction_params, n_iterations)
        (_, aux, _, bptt_results), _ = lax.scan(section_fn,
                                                (init_state, aux, interaction_params, bptt_results),
                                                xs=jnp.arange(n_iterations))
        return bptt_results, aux

    return simulation