simulation.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. from __future__ import annotations
  2. import time
  3. from functools import partial
  4. import jax
  5. from jax import lax, jit, random
  6. from curvature_assembly import oriented_particle, data_protocols, cost_functions, util
  7. from jax_md import simulate, rigid_body, dataclasses, space, partition, quantity
  8. import jax.numpy as jnp
  9. from typing import Callable, TypeVar, Optional, Any
  10. import warnings
  11. import copy
  12. # import equinox
  13. Array = jnp.ndarray
  14. NeighborFn = partition.NeighborFn
  15. NeighborListFormat = partition.NeighborListFormat
  16. T = TypeVar('T')
  17. InitFn = Callable[..., T]
  18. ApplyFn = Callable[[T], T]
  19. RigidBody = rigid_body.RigidBody
  20. InteractionParams = data_protocols.InteractionParams
  21. P = TypeVar('P', bound=InteractionParams)
  22. @dataclasses.dataclass
  23. class NVTSimulationParams:
  24. """
  25. Container for NVT simulation parameters.
  26. """
  27. num: int
  28. density: float
  29. simulation_steps: int
  30. dt: float
  31. kT: float
  32. config_every: int = 100
  33. bptt_truncation: int = 500
  34. def get_higher_temp_equilibration_params(sim_params: NVTSimulationParams, new_kT: float) -> NVTSimulationParams:
  35. """Get a new NVT simulation parameters for equilibration simulation at a given higher temperature."""
  36. params_dict = copy.deepcopy(vars(sim_params))
  37. params_dict['bptt_truncation'] = sim_params.simulation_steps
  38. params_dict['kT'] = new_kT
  39. return NVTSimulationParams(**params_dict)
  40. @dataclasses.dataclass
  41. class SimulationLogNoseHoover:
  42. """Dataclass for storing observables, invariants etc. during a simulation."""
  43. T: Array
  44. E: Array
  45. K: Array
  46. H: Array
  47. current_len: Array
  48. @staticmethod
  49. def create_empty(num_steps: int, save_every: int) -> SimulationLogNoseHoover:
  50. E = jnp.zeros(num_steps // save_every)
  51. T = jnp.zeros(num_steps // save_every)
  52. K = jnp.zeros(num_steps // save_every)
  53. H = jnp.zeros(num_steps // save_every)
  54. return SimulationLogNoseHoover(T, E, K, H, 0)
  55. def calculate_values(self, state, energy_fn, ellipsoid_mass, kT, **params) -> (float, float, float, float):
  56. T = rigid_body.temperature(position=state.position,
  57. momentum=state.momentum,
  58. mass=ellipsoid_mass)
  59. E = energy_fn(state.position, **params)
  60. K = rigid_body.kinetic_energy(position=state.position,
  61. momentum=state.momentum,
  62. mass=ellipsoid_mass)
  63. H = simulate.nvt_nose_hoover_invariant(energy_fn, state, kT, **params)
  64. return T, E, K, H
  65. def update(self, T: Array, E: Array, K: Array, H: Array) -> SimulationLogNoseHoover:
  66. idx = self.current_len
  67. log = dataclasses.replace(self, E=self.E.at[idx].set(E))
  68. log = dataclasses.replace(log, T=log.T.at[idx].set(T))
  69. log = dataclasses.replace(log, K=log.K.at[idx].set(K))
  70. log = dataclasses.replace(log, H=log.H.at[idx].set(H))
  71. log = dataclasses.replace(log, current_len=idx + 1)
  72. return log
  73. def revert_last_nsteps(self, nsteps) -> SimulationLogNoseHoover:
  74. log = dataclasses.replace(self, current_len=self.current_len - nsteps)
  75. return log
  76. @dataclasses.dataclass
  77. class SimulationLogLangevin:
  78. """Dataclass for storing observables, invariants etc. during a simulation."""
  79. T: Array
  80. E: Array
  81. K: Array
  82. current_len: Array
  83. @staticmethod
  84. def create_empty(num_steps: int, save_every: int) -> SimulationLogLangevin:
  85. E = jnp.zeros(num_steps // save_every)
  86. T = jnp.zeros(num_steps // save_every)
  87. K = jnp.zeros(num_steps // save_every)
  88. return SimulationLogLangevin(T, E, K, 0)
  89. def calculate_values(self, state, energy_fn, ellipsoid_mass, kT, **params) -> (float, float, float):
  90. T = rigid_body.temperature(position=state.position,
  91. momentum=state.momentum,
  92. mass=ellipsoid_mass)
  93. E = energy_fn(state.position, **params)
  94. K = rigid_body.kinetic_energy(position=state.position,
  95. momentum=state.momentum,
  96. mass=ellipsoid_mass)
  97. return T, E, K
  98. def update(self, T: Array, E: Array, K: Array) -> SimulationLogLangevin:
  99. idx = self.current_len
  100. log = dataclasses.replace(self, E=self.E.at[idx].set(E))
  101. log = dataclasses.replace(log, T=log.T.at[idx].set(T))
  102. log = dataclasses.replace(log, K=log.K.at[idx].set(K))
  103. log = dataclasses.replace(log, current_len=idx + 1)
  104. return log
  105. def revert_last_nsteps(self, nsteps) -> SimulationLogLangevin:
  106. log = dataclasses.replace(self, current_len=self.current_len - nsteps)
  107. return log
  108. @dataclasses.dataclass
  109. class SimulationStateHistory:
  110. """Dataclass for storing particle configurations during a simulation."""
  111. coord: Array
  112. orient: Array
  113. current_len: Array
  114. @staticmethod
  115. def create_empty(num_steps: int, n_particles: int, config_every: int) -> SimulationStateHistory:
  116. coord = jnp.zeros((num_steps // config_every, n_particles, 3))
  117. orient = jnp.zeros((num_steps // config_every, n_particles, 4))
  118. return SimulationStateHistory(coord, orient, 0)
  119. def update(self, coord: Array, orient: Array) -> SimulationStateHistory:
  120. idx = self.current_len
  121. state_history = dataclasses.replace(self, coord=self.coord.at[idx].set(coord))
  122. state_history = dataclasses.replace(state_history, orient=state_history.orient.at[idx].set(orient))
  123. state_history = dataclasses.replace(state_history, current_len=idx + 1)
  124. return state_history
  125. def revert_last_nsteps(self, nsteps) -> SimulationStateHistory:
  126. log = dataclasses.replace(self, current_len=self.current_len - nsteps)
  127. return log
  128. @dataclasses.dataclass
  129. class SimulationAux:
  130. """Dataclass for simulation auxiliary data."""
  131. log: data_protocols.SimulationLog
  132. state_history: data_protocols.SimulationStateHistory
  133. def revert_last_nsteps(self, nsteps, config_every):
  134. log = self.log.revert_last_nsteps(nsteps)
  135. state_history = self.state_history.revert_last_nsteps(nsteps // config_every)
  136. aux = dataclasses.replace(self, log=log)
  137. aux = dataclasses.replace(aux, state_history=state_history)
  138. return aux
  139. def reset_empty(self) -> SimulationAux:
  140. """
  141. Set current_len attribute of SimulationLog and SimulationStateHistory classes to 0 which effectively resets
  142. their empty state (current data will be overwritten in the next bptt simulation run).
  143. """
  144. # we use zeros_like() because of possible parallelization that adds an axis to current_len attribute
  145. empty_log = dataclasses.replace(self.log, current_len=jnp.zeros_like(self.log.current_len))
  146. empty_history = dataclasses.replace(self.state_history, current_len=jnp.zeros_like(self.state_history.current_len))
  147. aux = dataclasses.replace(self, log=empty_log)
  148. aux = dataclasses.replace(aux, state_history=empty_history)
  149. return aux
  150. def setup_nose_hoover(energy: Callable,
  151. shift: space.ShiftFn,
  152. simulation_params: NVTSimulationParams,
  153. **nose_hoover_kwargs) -> (InitFn, ApplyFn, SimulationAux):
  154. """
  155. Prepare functions and auxiliary data container for a molecular dynamics simulation using the Nose-Hoover thermostat.
  156. """
  157. log = SimulationLogNoseHoover.create_empty(simulation_params.simulation_steps, simulation_params.config_every)
  158. state_history = SimulationStateHistory.create_empty(simulation_params.simulation_steps,
  159. simulation_params.num,
  160. simulation_params.config_every)
  161. aux = SimulationAux(log=log, state_history=state_history)
  162. init_fn, step_fn = simulate.nvt_nose_hoover(energy, shift, simulation_params.dt, simulation_params.kT,
  163. **nose_hoover_kwargs)
  164. return init_fn, step_fn, aux
  165. def setup_langevin(energy: Callable,
  166. shift: space.ShiftFn,
  167. simulation_params: NVTSimulationParams,
  168. **langevin_kwargs) -> (InitFn, ApplyFn, SimulationAux):
  169. """
  170. Prepare functions and auxiliary data container for a molecular dynamics simulation using the Nose-Hoover thermostat.
  171. """
  172. log = SimulationLogLangevin.create_empty(simulation_params.simulation_steps, simulation_params.config_every)
  173. state_history = SimulationStateHistory.create_empty(simulation_params.simulation_steps,
  174. simulation_params.num,
  175. simulation_params.config_every)
  176. aux = SimulationAux(log=log, state_history=state_history)
  177. init_fn, step_fn = simulate.nvt_langevin(energy, shift, simulation_params.dt, simulation_params.kT,
  178. **langevin_kwargs)
  179. return init_fn, step_fn, aux
  180. def rescale_momenta_new_temperature(state: simulate.NVTNoseHooverState,
  181. new_kT: float,
  182. old_kT: float) -> simulate.NVTNoseHooverState:
  183. new_momentum_center = jnp.sqrt(new_kT / old_kT) * state.momentum.center
  184. new_momentum_orientation = jnp.sqrt(new_kT / old_kT) * state.momentum.orientation.vec
  185. return state.set(momentum=RigidBody(new_momentum_center, rigid_body.Quaternion(new_momentum_orientation)))
  186. def init_nose_hoover_new_temperature(state: simulate.NVTNoseHooverState,
  187. new_kT: float,
  188. old_kT: float,
  189. dt: float,
  190. chain_length: int = 5,
  191. chain_steps: int = 2,
  192. sy_steps: int = 3,
  193. tau: Optional[float] = None) -> simulate.NVTNoseHooverState:
  194. dt = simulate.f32(dt)
  195. if tau is None:
  196. tau = dt * 100
  197. tau = simulate.f32(tau)
  198. thermostat = simulate.nose_hoover_chain(dt, chain_length, chain_steps, sy_steps, tau)
  199. dof = quantity.count_dof(state.position)
  200. state = rescale_momenta_new_temperature(state, new_kT, old_kT)
  201. KE = simulate.kinetic_energy(state)
  202. return state.set(chain=thermostat.initialize(dof, KE, new_kT))
  203. def setup_langevin(energy: Callable,
  204. shift: space.ShiftFn,
  205. simulation_params: NVTSimulationParams,
  206. gamma: RigidBody = RigidBody(0.1, 0.1)) -> (InitFn, ApplyFn, SimulationAux):
  207. """
  208. Prepare functions and auxiliary data container for a molecular dynamics simulation using the Langevin thermostat.
  209. """
  210. log = SimulationLogLangevin.create_empty(simulation_params.simulation_steps, simulation_params.config_every)
  211. state_history = SimulationStateHistory.create_empty(simulation_params.simulation_steps,
  212. simulation_params.num,
  213. simulation_params.config_every)
  214. aux = SimulationAux(log=log, state_history=state_history)
  215. init_fn, step_fn = simulate.nvt_langevin(energy, shift, simulation_params.dt, simulation_params.kT,
  216. gamma=gamma)
  217. return init_fn, step_fn, aux
  218. def ellipsoid_unit_mass(eigvals: Array):
  219. return oriented_particle.ellipsoid_mass(jnp.array([1.]), eigvals)
  220. def simulation_step(state_aux_params: tuple[T, data_protocols.SimulationAux, InteractionParams],
  221. iteration_idx: int,
  222. step_fn: Callable,
  223. energy_fn: Callable,
  224. kT: float,
  225. config_every: int) -> (tuple[T, data_protocols.SimulationAux, InteractionParams], float):
  226. """Perform one simulation step and log the progress."""
  227. state, aux, params = state_aux_params
  228. log = aux.log
  229. state_history = aux.state_history
  230. # take a simulation step
  231. # params must be passed as a dictionary
  232. state = step_fn(state, **vars(params))
  233. def update_aux(l, h):
  234. new_log = l.update(*log.calculate_values(state,
  235. energy_fn,
  236. ellipsoid_unit_mass(params.eigvals),
  237. kT,
  238. **vars(params)))
  239. new_history = h.update(state.position.center,
  240. state.position.orientation.vec)
  241. return new_log, new_history
  242. # log information about simulation as well as the state history
  243. log, state_history = lax.cond((iteration_idx + 1) % config_every == 0,
  244. update_aux,
  245. lambda l, h: (l, h),
  246. log, state_history)
  247. aux = dataclasses.replace(aux, log=log)
  248. aux = dataclasses.replace(aux, state_history=state_history)
  249. return (state, aux, params), 0.
  250. def nvt_simulation_pair(init_fn: InitFn,
  251. step_fn: ApplyFn,
  252. aux: SimulationAux,
  253. energy: Callable[[...], Array],
  254. interaction_params: InteractionParams,
  255. simulation_params: NVTSimulationParams,
  256. body: RigidBody) -> (RigidBody, SimulationAux):
  257. # set all particle masses to 1
  258. ellipsoid_mass = oriented_particle.ellipsoid_mass(jnp.array([1.]), interaction_params.eigvals)
  259. # setup simulation
  260. scan_step = partial(simulation_step, step_fn=step_fn,
  261. energy_fn=energy, kT=simulation_params.kT, ellipsoid_mass=ellipsoid_mass,
  262. config_every=simulation_params.config_every)
  263. # initialize state
  264. key = random.PRNGKey(0)
  265. state = init_fn(key, body, mass=ellipsoid_mass)
  266. @jit
  267. def scan_to_jit(state_aux, num_steps):
  268. new_state_aux, _ = lax.scan(scan_step, state_aux, num_steps)
  269. return new_state_aux
  270. # run simulation
  271. print('Simulation start')
  272. t0 = time.perf_counter()
  273. state_and_aux = scan_to_jit((state, aux), jnp.arange(simulation_params.simulation_steps))
  274. state, aux = state_and_aux
  275. t1 = time.perf_counter()
  276. print(f'Simulation time: {t1 - t0}')
  277. return state.position, aux
  278. @dataclasses.dataclass
  279. class BPTTResults:
  280. grad: InteractionParams
  281. cost: Array
  282. current_len: int
  283. @staticmethod
  284. def create_empty(interaction_params: InteractionParams, n_steps: int) -> BPTTResults:
  285. grad = empty_grad_results(interaction_params, n_steps)
  286. cost = jnp.zeros((n_steps,))
  287. return BPTTResults(grad, cost, 0)
  288. def update(self, grad: InteractionParams, cost: float) -> BPTTResults:
  289. idx = self.current_len
  290. new_grad = jax.tree_util.tree_map(partial(update_gradient_history, idx=idx), self.grad, grad)
  291. r = dataclasses.replace(self, grad=new_grad)
  292. r = dataclasses.replace(r, cost=self.cost.at[idx].set(cost))
  293. r = dataclasses.replace(r, current_len=idx + 1)
  294. return r
  295. def empty_grad_results(interaction_params: P, num_rep: int) -> P:
  296. """
  297. Initializes an interactions parameters class to store gradients after each section of a truncated BPTT run.
  298. """
  299. history_dict = {}
  300. for key, value in vars(interaction_params).items():
  301. try:
  302. history_dict[key] = jnp.zeros((num_rep,) + value.shape)
  303. except AttributeError:
  304. history_dict[key] = jnp.zeros((num_rep,))
  305. return type(interaction_params)(**history_dict)
  306. def update_gradient_history(history_array: Array, grad_value: Array, idx: int) -> Array:
  307. """Update a gradient history array at idx with a new gradient value."""
  308. return history_array.at[idx].set(grad_value)
  309. def simple_forward_simulation(init_fn: InitFn,
  310. step_fn: ApplyFn,
  311. num_steps: int
  312. ) -> Callable[[InteractionParams, RigidBody, int, jax.random.PRNGKey], RigidBody]:
  313. """Elementary forward MD simulation, without any logging and configuration saving."""
  314. def simulation(interaction_params: InteractionParams,
  315. body: RigidBody,
  316. key: jax.random.PRNGKey):
  317. # initialize state
  318. state = init_fn(key, body, mass=ellipsoid_unit_mass(interaction_params.eigvals), **vars(interaction_params))
  319. def scan_step(state, i):
  320. return step_fn(state, **vars(interaction_params)), 0.
  321. state, _ = lax.scan(scan_step, state, jnp.arange(num_steps))
  322. return state
  323. return simulation
  324. def truncated_bptt_nvt_simulation(step_fn: ApplyFn,
  325. energy: Callable[[...], Array],
  326. cost_fn: cost_functions.CostFn,
  327. simulation_params: NVTSimulationParams,
  328. only_forward_calculation: bool = False) -> data_protocols.BpttSimulation:
  329. # simulation setup
  330. scan_step = partial(simulation_step, step_fn=step_fn,
  331. energy_fn=energy, kT=simulation_params.kT,
  332. config_every=simulation_params.config_every)
  333. # loop_fn = partial(equinox.internal.scan, kind='checkpointed', checkpoints=10)
  334. if simulation_params.simulation_steps < simulation_params.config_every:
  335. raise ValueError(f'Number of simulation steps must be higher or equal to the config_every value, '
  336. f'got {simulation_params.simulation_steps} and {simulation_params.config_every}, '
  337. f'respectively')
  338. n_iterations = simulation_params.simulation_steps // simulation_params.bptt_truncation
  339. if n_iterations == 0:
  340. raise ValueError(f'Number of simulation steps must be equal to or grater than BPTT truncation, '
  341. f'got {simulation_params.simulation_steps} and {simulation_params.bptt_truncation}, '
  342. f'respectively.')
  343. if n_iterations * simulation_params.bptt_truncation < simulation_params.simulation_steps:
  344. warnings.warn(f'Only {n_iterations * simulation_params.bptt_truncation} time steps will be calculated '
  345. f'as bptt truncation length does not divide the desired number of steps exactly.')
  346. def forward_function(params: InteractionParams, state, aux):
  347. (state, aux, _), _ = lax.scan(scan_step, (state, aux, params), jnp.arange(simulation_params.bptt_truncation))
  348. cost = cost_fn(state.position, **vars(params))
  349. return cost, (state, aux)
  350. grad_fn = jax.value_and_grad(forward_function, has_aux=True, argnums=(0,))
  351. def bptt_section(state_aux_params_results, i):
  352. state, aux, params, results = state_aux_params_results
  353. value, grad = grad_fn(params, state, aux)
  354. cost, (state, aux) = value
  355. results = results.update(grad[0], cost)
  356. return (state, aux, params, results), 0.
  357. def bptt_section_test(state_aux_params_results, i):
  358. state, aux, params, results = state_aux_params_results
  359. value = forward_function(params, state, aux)
  360. cost, (state, aux) = value
  361. results = results.update(jax.tree_util.tree_map(jnp.zeros_like, params), cost)
  362. return (state, aux, params, results), 0.
  363. if only_forward_calculation:
  364. bptt_section = bptt_section_test
  365. def simulation(interaction_params: InteractionParams,
  366. init_state: Any,
  367. aux: data_protocols.SimulationAux
  368. ):
  369. # initialize state
  370. # state = init_fn(key, body, mass=ellipsoid_unit_mass(interaction_params.eigvals), **vars(interaction_params))
  371. # initialize results object
  372. bptt_results = BPTTResults.create_empty(interaction_params, n_iterations)
  373. # run simulation
  374. state_aux_results, _ = jax.lax.scan(bptt_section,
  375. (init_state, aux, interaction_params, bptt_results),
  376. xs=jnp.arange(n_iterations))
  377. state, aux, params, bptt_results = state_aux_results
  378. return bptt_results, aux
  379. return simulation