|
- from __future__ import annotations
- import time
- from functools import partial
- import jax
- from jax import lax, jit, random
- from curvature_assembly import oriented_particle, data_protocols, cost_functions, util
- from jax_md import simulate, rigid_body, dataclasses, space, partition, quantity
- import jax.numpy as jnp
- from typing import Callable, TypeVar, Optional, Any
- import warnings
- import copy
- # import equinox
# Short aliases for types used throughout this module.
Array = jnp.ndarray
NeighborFn, NeighborListFormat = partition.NeighborFn, partition.NeighborListFormat

# Generic signatures of the (init_fn, step_fn) pairs produced by the simulate.nvt_* builders used below.
T = TypeVar('T')
InitFn = Callable[..., T]
ApplyFn = Callable[[T], T]

RigidBody = rigid_body.RigidBody
InteractionParams = data_protocols.InteractionParams
# type variable for functions that return the same interaction-parameter class they receive
P = TypeVar('P', bound=InteractionParams)
@dataclasses.dataclass
class NVTSimulationParams:
    """
    Settings of a constant-NVT molecular dynamics run.

    ``config_every`` controls how often observables and configurations are recorded, and
    ``bptt_truncation`` is the number of steps per truncated-BPTT section.
    """
    num: int  # number of particles (sizes the saved configuration history)
    density: float  # not read in this module — NOTE(review): presumably used by callers to build the box
    simulation_steps: int  # total number of integration steps
    dt: float  # integration time step
    kT: float  # thermostat temperature handed to the simulate.nvt_* builders
    config_every: int = 100  # record log/configuration every this many steps
    bptt_truncation: int = 500  # steps per BPTT section
def get_higher_temp_equilibration_params(sim_params: NVTSimulationParams, new_kT: float) -> NVTSimulationParams:
    """
    Derive parameters for an equilibration run at temperature ``new_kT``.

    Every field is copied from ``sim_params`` except ``kT``, which becomes ``new_kT``, and
    ``bptt_truncation``, which is set to ``simulation_steps`` so the whole run forms a single BPTT section.
    """
    overrides = {
        'kT': new_kT,
        'bptt_truncation': sim_params.simulation_steps,
    }
    merged_fields = {**copy.deepcopy(vars(sim_params)), **overrides}
    return NVTSimulationParams(**merged_fields)
@dataclasses.dataclass
class SimulationLogNoseHoover:
    """
    Fixed-size log of observables recorded during a Nose-Hoover NVT simulation.

    Every array has one slot per logging event; ``current_len`` is the number of slots filled so far
    and therefore the index of the next write.
    """
    T: Array  # temperature (rigid_body.temperature)
    E: Array  # energy_fn evaluated at the current positions
    K: Array  # kinetic energy (rigid_body.kinetic_energy)
    H: Array  # Nose-Hoover invariant (simulate.nvt_nose_hoover_invariant)
    current_len: Array

    @staticmethod
    def create_empty(num_steps: int, save_every: int) -> SimulationLogNoseHoover:
        """Return a log whose buffers are zero-filled and hold ``num_steps // save_every`` entries."""
        n_entries = num_steps // save_every
        return SimulationLogNoseHoover(T=jnp.zeros(n_entries),
                                       E=jnp.zeros(n_entries),
                                       K=jnp.zeros(n_entries),
                                       H=jnp.zeros(n_entries),
                                       current_len=0)

    def calculate_values(self, state, energy_fn, ellipsoid_mass, kT,
                         **params) -> tuple[Array, Array, Array, Array]:
        """
        Compute ``(T, E, K, H)`` for ``state``.

        ``params`` are forwarded to ``energy_fn`` and to the Nose-Hoover invariant. The data stored in
        the log itself is not read.
        """
        T = rigid_body.temperature(position=state.position,
                                   momentum=state.momentum,
                                   mass=ellipsoid_mass)
        E = energy_fn(state.position, **params)
        K = rigid_body.kinetic_energy(position=state.position,
                                      momentum=state.momentum,
                                      mass=ellipsoid_mass)
        H = simulate.nvt_nose_hoover_invariant(energy_fn, state, kT, **params)
        return T, E, K, H

    def update(self, T: Array, E: Array, K: Array, H: Array) -> SimulationLogNoseHoover:
        """Return a new log with the values stored at index ``current_len`` and the length incremented."""
        idx = self.current_len
        return dataclasses.replace(self,
                                   T=self.T.at[idx].set(T),
                                   E=self.E.at[idx].set(E),
                                   K=self.K.at[idx].set(K),
                                   H=self.H.at[idx].set(H),
                                   current_len=idx + 1)

    def revert_last_nsteps(self, nsteps) -> SimulationLogNoseHoover:
        """Return a new log with ``current_len`` reduced by ``nsteps`` entries (buffer data is not cleared)."""
        return dataclasses.replace(self, current_len=self.current_len - nsteps)
@dataclasses.dataclass
class SimulationLogLangevin:
    """
    Fixed-size log of observables recorded during a Langevin NVT simulation.

    Every array has one slot per logging event; ``current_len`` is the number of slots filled so far
    and therefore the index of the next write.
    """
    T: Array  # temperature (rigid_body.temperature)
    E: Array  # energy_fn evaluated at the current positions
    K: Array  # kinetic energy (rigid_body.kinetic_energy)
    current_len: Array

    @staticmethod
    def create_empty(num_steps: int, save_every: int) -> SimulationLogLangevin:
        """Return a log whose buffers are zero-filled and hold ``num_steps // save_every`` entries."""
        n_entries = num_steps // save_every
        return SimulationLogLangevin(T=jnp.zeros(n_entries),
                                     E=jnp.zeros(n_entries),
                                     K=jnp.zeros(n_entries),
                                     current_len=0)

    def calculate_values(self, state, energy_fn, ellipsoid_mass, kT,
                         **params) -> tuple[Array, Array, Array]:
        """
        Compute ``(T, E, K)`` for ``state``; ``params`` are forwarded to ``energy_fn``.

        ``kT`` is unused here; it is accepted so this method can be called with the same arguments as
        ``SimulationLogNoseHoover.calculate_values``.
        """
        T = rigid_body.temperature(position=state.position,
                                   momentum=state.momentum,
                                   mass=ellipsoid_mass)
        E = energy_fn(state.position, **params)
        K = rigid_body.kinetic_energy(position=state.position,
                                      momentum=state.momentum,
                                      mass=ellipsoid_mass)
        return T, E, K

    def update(self, T: Array, E: Array, K: Array) -> SimulationLogLangevin:
        """Return a new log with the values stored at index ``current_len`` and the length incremented."""
        idx = self.current_len
        return dataclasses.replace(self,
                                   T=self.T.at[idx].set(T),
                                   E=self.E.at[idx].set(E),
                                   K=self.K.at[idx].set(K),
                                   current_len=idx + 1)

    def revert_last_nsteps(self, nsteps) -> SimulationLogLangevin:
        """Return a new log with ``current_len`` reduced by ``nsteps`` entries (buffer data is not cleared)."""
        return dataclasses.replace(self, current_len=self.current_len - nsteps)
@dataclasses.dataclass
class SimulationStateHistory:
    """
    Fixed-size buffer of saved particle configurations.

    ``coord`` stores centers with shape ``(n_frames, n_particles, 3)``, ``orient`` stores 4-component
    orientation vectors with shape ``(n_frames, n_particles, 4)``. ``current_len`` counts stored frames
    and is the index of the next write.
    """
    coord: Array
    orient: Array
    current_len: Array

    @staticmethod
    def create_empty(num_steps: int, n_particles: int, config_every: int) -> SimulationStateHistory:
        """Allocate zero-filled buffers for ``num_steps // config_every`` frames."""
        n_frames = num_steps // config_every
        return SimulationStateHistory(
            coord=jnp.zeros((n_frames, n_particles, 3)),
            orient=jnp.zeros((n_frames, n_particles, 4)),
            current_len=0,
        )

    def update(self, coord: Array, orient: Array) -> SimulationStateHistory:
        """Return a new history with the frame written at ``current_len`` and the length incremented."""
        slot = self.current_len
        return dataclasses.replace(
            self,
            coord=self.coord.at[slot].set(coord),
            orient=self.orient.at[slot].set(orient),
            current_len=slot + 1,
        )

    def revert_last_nsteps(self, nsteps) -> SimulationStateHistory:
        """Return a new history whose ``current_len`` is reduced by ``nsteps`` frames."""
        shortened = self.current_len - nsteps
        return dataclasses.replace(self, current_len=shortened)
@dataclasses.dataclass
class SimulationAux:
    """Auxiliary simulation data: the observable log together with the saved configuration history."""
    log: data_protocols.SimulationLog
    state_history: data_protocols.SimulationStateHistory

    def revert_last_nsteps(self, nsteps, config_every):
        """
        Discard whatever was recorded during the last ``nsteps`` simulation steps.

        ``simulation_step`` appends to the log and to the state history together (under one
        ``lax.cond``) once every ``config_every`` steps. Both containers are therefore rolled back by
        the same number of entries, ``nsteps // config_every``. Previously the log was rolled back by
        ``nsteps`` entries, which over-reverts it and can make its ``current_len`` negative.
        """
        n_entries = nsteps // config_every
        return dataclasses.replace(self,
                                   log=self.log.revert_last_nsteps(n_entries),
                                   state_history=self.state_history.revert_last_nsteps(n_entries))

    def reset_empty(self) -> SimulationAux:
        """
        Set current_len of the log and of the state history to 0, which effectively resets
        them to the empty state (current data will be overwritten in the next bptt simulation run).
        """
        # zeros_like() because parallelization may add an axis to the current_len attribute
        empty_log = dataclasses.replace(self.log, current_len=jnp.zeros_like(self.log.current_len))
        empty_history = dataclasses.replace(self.state_history,
                                            current_len=jnp.zeros_like(self.state_history.current_len))
        return dataclasses.replace(self, log=empty_log, state_history=empty_history)
def setup_nose_hoover(energy: Callable,
                      shift: space.ShiftFn,
                      simulation_params: NVTSimulationParams,
                      **nose_hoover_kwargs) -> (InitFn, ApplyFn, SimulationAux):
    """
    Build a Nose-Hoover NVT integrator together with an empty auxiliary-data container.

    Extra keyword arguments are forwarded unchanged to ``simulate.nvt_nose_hoover``.

    Returns:
        ``(init_fn, step_fn, aux)``. ``aux`` holds an empty ``SimulationLogNoseHoover`` and an empty
        ``SimulationStateHistory``, both sized for ``simulation_steps // config_every`` entries.
    """
    p = simulation_params
    aux = SimulationAux(
        log=SimulationLogNoseHoover.create_empty(p.simulation_steps, p.config_every),
        state_history=SimulationStateHistory.create_empty(p.simulation_steps, p.num, p.config_every),
    )
    init_fn, step_fn = simulate.nvt_nose_hoover(energy, shift, p.dt, p.kT, **nose_hoover_kwargs)
    return init_fn, step_fn, aux
def setup_langevin(energy: Callable,
                   shift: space.ShiftFn,
                   simulation_params: NVTSimulationParams,
                   **langevin_kwargs) -> tuple[InitFn, ApplyFn, SimulationAux]:
    """
    Prepare functions and auxiliary data container for a molecular dynamics simulation using the Langevin thermostat.

    Extra keyword arguments are forwarded to ``simulate.nvt_langevin``.

    NOTE(review): another ``setup_langevin`` (with an explicit ``gamma`` parameter) is defined further
    down in this module. Being defined later, it replaces this one at import time, so callers never reach
    this definition. TODO: merge the two definitions.
    """
    log = SimulationLogLangevin.create_empty(simulation_params.simulation_steps, simulation_params.config_every)
    state_history = SimulationStateHistory.create_empty(simulation_params.simulation_steps,
                                                        simulation_params.num,
                                                        simulation_params.config_every)
    aux = SimulationAux(log=log, state_history=state_history)
    init_fn, step_fn = simulate.nvt_langevin(energy, shift, simulation_params.dt, simulation_params.kT,
                                             **langevin_kwargs)
    return init_fn, step_fn, aux
def rescale_momenta_new_temperature(state: simulate.NVTNoseHooverState,
                                    new_kT: float,
                                    old_kT: float) -> simulate.NVTNoseHooverState:
    """
    Return ``state`` with all momenta multiplied by ``sqrt(new_kT / old_kT)``.

    Both the center-of-mass and the orientation (quaternion) momenta are scaled. Kinetic energy is
    quadratic in momentum, so it changes by a factor ``new_kT / old_kT``. Presumably the momenta are
    assumed to be equilibrated at ``old_kT`` — TODO confirm with callers.
    """
    factor = jnp.sqrt(new_kT / old_kT)
    momentum = state.momentum
    scaled = RigidBody(factor * momentum.center,
                       rigid_body.Quaternion(factor * momentum.orientation.vec))
    return state.set(momentum=scaled)
def init_nose_hoover_new_temperature(state: simulate.NVTNoseHooverState,
                                     new_kT: float,
                                     old_kT: float,
                                     dt: float,
                                     chain_length: int = 5,
                                     chain_steps: int = 2,
                                     sy_steps: int = 3,
                                     tau: Optional[float] = None) -> simulate.NVTNoseHooverState:
    """
    Re-initialize a Nose-Hoover NVT state for the temperature ``new_kT``.

    Momenta are rescaled from ``old_kT`` to ``new_kT`` via ``rescale_momenta_new_temperature``. A fresh
    Nose-Hoover chain is then initialized at ``new_kT``, using the rescaled kinetic energy and the
    degrees of freedom of the positions. Positions are left unchanged.

    Args:
        dt, chain_length, chain_steps, sy_steps, tau: forwarded to ``simulate.nose_hoover_chain``;
            ``tau`` defaults to ``100 * dt``.
    """
    dt = simulate.f32(dt)
    tau = simulate.f32(dt * 100 if tau is None else tau)
    thermostat = simulate.nose_hoover_chain(dt, chain_length, chain_steps, sy_steps, tau)

    degrees_of_freedom = quantity.count_dof(state.position)
    rescaled = rescale_momenta_new_temperature(state, new_kT, old_kT)
    chain = thermostat.initialize(degrees_of_freedom, simulate.kinetic_energy(rescaled), new_kT)
    return rescaled.set(chain=chain)
def setup_langevin(energy: Callable,
                   shift: space.ShiftFn,
                   simulation_params: NVTSimulationParams,
                   gamma: RigidBody = RigidBody(0.1, 0.1),
                   **langevin_kwargs) -> tuple[InitFn, ApplyFn, SimulationAux]:
    """
    Prepare functions and auxiliary data container for a molecular dynamics simulation using the Langevin thermostat.

    This definition replaces an earlier ``setup_langevin`` in this module, which accepted arbitrary keyword
    arguments for ``simulate.nvt_langevin``. Those are accepted here as well, so both calling styles work.

    Args:
        energy: energy function passed to ``simulate.nvt_langevin``.
        shift: shift function passed to ``simulate.nvt_langevin``.
        simulation_params: provides ``dt``, ``kT`` and the sizes of the log / state-history buffers.
        gamma: friction passed to ``simulate.nvt_langevin``; default ``RigidBody(0.1, 0.1)`` gives the
            (center, orientation) components.
        **langevin_kwargs: any further keyword arguments, forwarded to ``simulate.nvt_langevin``.

    Returns:
        ``(init_fn, step_fn, aux)`` with ``aux`` holding an empty ``SimulationLogLangevin`` and an empty
        ``SimulationStateHistory``.
    """
    log = SimulationLogLangevin.create_empty(simulation_params.simulation_steps, simulation_params.config_every)
    state_history = SimulationStateHistory.create_empty(simulation_params.simulation_steps,
                                                        simulation_params.num,
                                                        simulation_params.config_every)
    aux = SimulationAux(log=log, state_history=state_history)
    init_fn, step_fn = simulate.nvt_langevin(energy, shift, simulation_params.dt, simulation_params.kT,
                                             gamma=gamma, **langevin_kwargs)
    return init_fn, step_fn, aux
def ellipsoid_unit_mass(eigvals: Array):
    """Return ``oriented_particle.ellipsoid_mass`` for particles of mass 1 with the given ``eigvals``."""
    unit_mass = jnp.array([1.])
    return oriented_particle.ellipsoid_mass(unit_mass, eigvals)
def simulation_step(state_aux_params: tuple[T, data_protocols.SimulationAux, InteractionParams],
                    iteration_idx: int,
                    step_fn: Callable,
                    energy_fn: Callable,
                    kT: float,
                    config_every: int) -> (tuple[T, data_protocols.SimulationAux, InteractionParams], float):
    """
    ``lax.scan`` body: advance the simulation by one step and record progress.

    When ``(iteration_idx + 1)`` is a multiple of ``config_every``, the observables from
    ``log.calculate_values`` are appended to the log. The particle centers and orientation vectors are
    appended to the state history at the same time. Returns the new carry and a dummy ``0.`` scan output.
    """
    state, aux, params = state_aux_params
    # interaction parameters are handed to step_fn / energy_fn as keyword arguments
    param_kwargs = vars(params)
    state = step_fn(state, **param_kwargs)

    def record(log, history):
        mass = ellipsoid_unit_mass(params.eigvals)
        observables = log.calculate_values(state, energy_fn, mass, kT, **param_kwargs)
        new_history = history.update(state.position.center, state.position.orientation.vec)
        return log.update(*observables), new_history

    def skip(log, history):
        return log, history

    is_logging_step = (iteration_idx + 1) % config_every == 0
    new_log, new_history = lax.cond(is_logging_step, record, skip, aux.log, aux.state_history)
    new_aux = dataclasses.replace(aux, log=new_log, state_history=new_history)
    return (state, new_aux, params), 0.
def nvt_simulation_pair(init_fn: InitFn,
                        step_fn: ApplyFn,
                        aux: SimulationAux,
                        energy: Callable[..., Array],
                        interaction_params: InteractionParams,
                        simulation_params: NVTSimulationParams,
                        body: RigidBody) -> tuple[RigidBody, SimulationAux]:
    """
    Run one jitted NVT simulation of ``simulation_params.simulation_steps`` steps and return the final
    positions together with the filled auxiliary data.

    The state is initialized from ``body`` with unit particle mass and a fixed PRNG key (0). Observables
    and configurations are recorded into ``aux`` every ``simulation_params.config_every`` steps by
    ``simulation_step``. Wall-clock time (including compilation) is printed.
    """
    ellipsoid_mass = ellipsoid_unit_mass(interaction_params.eigvals)
    # simulation_step takes no mass argument; it derives the mass from the params itself
    scan_step = partial(simulation_step,
                        step_fn=step_fn,
                        energy_fn=energy,
                        kT=simulation_params.kT,
                        config_every=simulation_params.config_every)

    key = random.PRNGKey(0)
    # interaction parameters are passed at init too, as they are for every step_fn call
    state = init_fn(key, body, mass=ellipsoid_mass, **vars(interaction_params))

    @jit
    def scan_to_jit(carry, step_indices):
        new_carry, _ = lax.scan(scan_step, carry, step_indices)
        return new_carry

    print('Simulation start')
    t0 = time.perf_counter()
    # simulation_step expects a (state, aux, params) carry
    carry = scan_to_jit((state, aux, interaction_params), jnp.arange(simulation_params.simulation_steps))
    # JAX dispatches asynchronously: wait for the results so the timing covers the actual run
    carry = jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), carry)
    t1 = time.perf_counter()
    print(f'Simulation time: {t1 - t0}')
    state, aux, _ = carry
    return state.position, aux
@dataclasses.dataclass
class BPTTResults:
    """Results of a truncated BPTT run, one row per BPTT section."""
    grad: InteractionParams  # gradient history; every field has a leading section axis
    cost: Array  # cost value after each section
    current_len: int  # number of sections recorded so far, i.e. the next row to write

    @staticmethod
    def create_empty(interaction_params: InteractionParams, n_steps: int) -> BPTTResults:
        """Return zero-filled results with room for ``n_steps`` sections."""
        return BPTTResults(grad=empty_grad_results(interaction_params, n_steps),
                           cost=jnp.zeros((n_steps,)),
                           current_len=0)

    def update(self, grad: InteractionParams, cost: float) -> BPTTResults:
        """Store ``grad`` and ``cost`` in row ``current_len`` and advance ``current_len`` by one."""
        row = self.current_len

        def write_row(history, value):
            return update_gradient_history(history, value, idx=row)

        return dataclasses.replace(self,
                                   grad=jax.tree_util.tree_map(write_row, self.grad, grad),
                                   cost=self.cost.at[row].set(cost),
                                   current_len=row + 1)
def empty_grad_results(interaction_params: P, num_rep: int) -> P:
    """
    Build a zero-filled instance of ``type(interaction_params)`` that stores one gradient per
    section of a truncated BPTT run.

    Each field becomes an array of shape ``(num_rep,) + value.shape``. Values without a ``shape``
    attribute (e.g. Python scalars) become arrays of shape ``(num_rep,)``.
    """
    zero_histories = {
        name: jnp.zeros((num_rep,) + getattr(value, 'shape', ()))
        for name, value in vars(interaction_params).items()
    }
    return type(interaction_params)(**zero_histories)
def update_gradient_history(history_array: Array, grad_value: Array, idx: int) -> Array:
    """Return a copy of ``history_array`` whose entry at ``idx`` is replaced by ``grad_value``."""
    updated = history_array.at[idx].set(grad_value)
    return updated
def simple_forward_simulation(init_fn: InitFn,
                              step_fn: ApplyFn,
                              num_steps: int
                              ) -> Callable[[InteractionParams, RigidBody, Array], Any]:
    """
    Elementary forward MD simulation, without any logging and configuration saving.

    Returns ``simulation(interaction_params, body, key)``. It initializes a state from ``body`` with
    unit particle mass and the PRNG ``key``, applies ``step_fn`` ``num_steps`` times, and returns the
    final simulator state (the whole state, not only positions).
    """
    def simulation(interaction_params: InteractionParams,
                   body: RigidBody,
                   key: Array):
        # interaction parameters are passed to the integrator as keyword arguments
        param_kwargs = vars(interaction_params)
        state = init_fn(key, body, mass=ellipsoid_unit_mass(interaction_params.eigvals), **param_kwargs)

        def scan_step(current_state, _):
            return step_fn(current_state, **param_kwargs), 0.

        final_state, _ = lax.scan(scan_step, state, jnp.arange(num_steps))
        return final_state

    return simulation
def truncated_bptt_nvt_simulation(step_fn: ApplyFn,
                                  energy: Callable[..., Array],
                                  cost_fn: cost_functions.CostFn,
                                  simulation_params: NVTSimulationParams,
                                  only_forward_calculation: bool = False) -> data_protocols.BpttSimulation:
    """
    Build an NVT simulation function that uses truncated backpropagation through time.

    The trajectory is split into ``simulation_steps // bptt_truncation`` sections of ``bptt_truncation``
    steps. After every section ``cost_fn(positions, **params)`` is evaluated. Unless
    ``only_forward_calculation`` is set, its gradient w.r.t. the interaction parameters is also taken,
    through that section only. The final state and aux of one section seed the next.

    Args:
        step_fn: integrator step, called as ``step_fn(state, **vars(params))``.
        energy: energy function used when logging observables.
        cost_fn: cost evaluated on the final positions of every section.
        simulation_params: provides simulation_steps, kT, config_every and bptt_truncation.
        only_forward_calculation: if True, skip differentiation and record zero gradients.

    Returns:
        ``simulation(interaction_params, init_state, aux) -> (BPTTResults, aux)``.

    Raises:
        ValueError: if ``config_every`` or ``bptt_truncation`` is not positive, if
            ``simulation_steps < config_every``, or if ``simulation_steps < bptt_truncation``.
    """
    num_steps = simulation_params.simulation_steps
    truncation = simulation_params.bptt_truncation
    config_every = simulation_params.config_every

    if config_every <= 0 or truncation <= 0:
        raise ValueError(f'config_every and bptt_truncation must be positive, '
                         f'got {config_every} and {truncation}, respectively.')
    if num_steps < config_every:
        raise ValueError(f'Number of simulation steps must be higher or equal to the config_every value, '
                         f'got {num_steps} and {config_every}, '
                         f'respectively')
    n_iterations = num_steps // truncation
    if n_iterations == 0:
        raise ValueError(f'Number of simulation steps must be equal to or greater than BPTT truncation, '
                         f'got {num_steps} and {truncation}, '
                         f'respectively.')
    if n_iterations * truncation < num_steps:
        warnings.warn(f'Only {n_iterations * truncation} time steps will be calculated '
                      f'as bptt truncation length does not divide the desired number of steps exactly.')

    scan_step = partial(simulation_step, step_fn=step_fn,
                        energy_fn=energy, kT=simulation_params.kT,
                        config_every=config_every)

    def forward_function(params: InteractionParams, state, aux, section_idx):
        # Global step indices keep simulation_step's logging at a uniform config_every cadence over the
        # whole trajectory. Per-section indices would restart the cadence in every section, which is
        # wrong whenever bptt_truncation is not a multiple of config_every.
        step_indices = section_idx * truncation + jnp.arange(truncation)
        (state, aux, _), _ = lax.scan(scan_step, (state, aux, params), step_indices)
        cost = cost_fn(state.position, **vars(params))
        return cost, (state, aux)

    # differentiate only w.r.t. params; the final state and aux are returned as auxiliary outputs
    grad_fn = jax.value_and_grad(forward_function, has_aux=True, argnums=(0,))

    def bptt_section(carry, section_idx):
        state, aux, params, results = carry
        (cost, (state, aux)), (params_grad,) = grad_fn(params, state, aux, section_idx)
        results = results.update(params_grad, cost)
        return (state, aux, params, results), 0.

    def forward_only_section(carry, section_idx):
        state, aux, params, results = carry
        cost, (state, aux) = forward_function(params, state, aux, section_idx)
        zero_grad = jax.tree_util.tree_map(jnp.zeros_like, params)
        results = results.update(zero_grad, cost)
        return (state, aux, params, results), 0.

    section_fn = forward_only_section if only_forward_calculation else bptt_section

    def simulation(interaction_params: InteractionParams,
                   init_state: Any,
                   aux: data_protocols.SimulationAux
                   ):
        bptt_results = BPTTResults.create_empty(interaction_params, n_iterations)
        (_, aux, _, bptt_results), _ = lax.scan(section_fn,
                                                (init_state, aux, interaction_params, bptt_results),
                                                xs=jnp.arange(n_iterations))
        return bptt_results, aux

    return simulation
|