data_protocols.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. from __future__ import annotations
  2. from typing import Protocol, Callable, Any
  3. import jax.numpy as jnp
  4. from jax_md import rigid_body
  # Short alias for JAX array values, used throughout this module's
  # protocol annotations.
  Array = jnp.ndarray
  class SimulationParams(Protocol):
      """Structural interface for fixed (non-differentiated) simulation settings.

      Any object exposing the attributes below satisfies this protocol. These
      values configure a run but are not meant to be differentiated through;
      the differentiable counterpart is ``InteractionParams``.

      Attributes (meanings inferred from the names -- confirm with implementers):
          num: presumably the number of simulated bodies.
          density: presumably the system density.
          simulation_steps: presumably the number of integration steps to run.
          dt: presumably the integration time step.
          config_every: presumably the step stride at which configurations
              are recorded.
      """

      num: int
      density: float
      simulation_steps: int
      dt: float
      config_every: int
  class InteractionParams(Protocol):
      """Structural interface for the differentiable interaction parameters.

      The gradient of the simulation is meant to be taken with respect to the
      values held here.

      Attributes:
          eigvals: array-valued parameters; presumably eigenvalues that shape
              the interaction (shape/dtype are not fixed by this protocol --
              confirm with implementers).
      """

      # Annotated with the module-level ``Array`` alias, matching the other
      # protocols in this module, instead of spelling out ``jnp.ndarray``.
      eigvals: Array
  class NeighborListParams(Protocol):
      """Structural interface for neighbor-list configuration.

      NOTE: no members are declared yet, so a static type checker will accept
      any object where a ``NeighborListParams`` is expected. Declare the
      required attributes here once the neighbor-list settings are settled.
      """
  class SimulationLog(Protocol):
      """Structural interface for data logged while a simulation runs.

      Attributes:
          current_len: array-valued counter; presumably the number of entries
              recorded so far (confirm with implementers).
      """

      current_len: Array

      def calculate_values(
          self,
          state,
          energy_fn: Callable,
          ellipsoid_mass: rigid_body.RigidBody,
          kT: float,
          **params,
      ):
          """Presumably compute the quantities to log for ``state``.

          ``energy_fn`` is a callable energy function, ``ellipsoid_mass`` a
          ``rigid_body.RigidBody`` value, and ``kT`` presumably the thermal
          energy; ``**params`` are extra keyword parameters. The return value
          is implementation-defined.
          """
          ...

      def update(self, *args):
          """Record new data; the positional arguments are implementation-defined."""
          ...

      def revert_last_nsteps(self, nsteps: int):
          """Presumably discard the last ``nsteps`` logged steps (per implementer)."""
          ...
  class SimulationStateHistory(Protocol):
      """Structural interface for a stored history of simulation states.

      Attributes (array layouts are not fixed here -- confirm with implementers):
          coord: stored coordinates.
          orient: stored orientations.
          current_len: presumably the number of states stored so far.
      """

      coord: Array
      orient: Array
      current_len: Array

      def revert_last_nsteps(self, nsteps: int):
          """Presumably drop the last ``nsteps`` stored states (per implementer)."""
          ...
  class SimulationAux(Protocol):
      """Structural interface bundling auxiliary simulation data.

      This is a ``typing.Protocol``, not a dataclass: any object exposing the
      members below satisfies it.

      Attributes:
          log: the ``SimulationLog`` for the run.
          state_history: the ``SimulationStateHistory`` of stored states.
      """

      log: SimulationLog
      state_history: SimulationStateHistory

      def reset_empty(self) -> SimulationAux:
          """Return a ``SimulationAux`` reset to an empty state.

          Presumably the log and state history are cleared; confirm with
          implementers.
          """
          ...
  class BpttResults(Protocol):
      """Structural interface for the output of a BPTT simulation run.

      ("BPTT" presumably stands for backpropagation through time.)

      Attributes:
          cost: array-valued cost of the run.
          grad: gradients structured like ``InteractionParams``; presumably
              the gradient of ``cost`` with respect to those parameters.
      """

      cost: Array
      grad: InteractionParams
  # Signature of a BPTT (presumably backpropagation-through-time) simulation
  # callable. It takes the interaction parameters, a second argument whose type
  # is left unconstrained (``Any``), and the auxiliary data. It returns a
  # ``(BpttResults, SimulationAux)`` pair.
  # NOTE: evaluated at runtime (not an annotation), so ``tuple[...]`` needs
  # Python 3.9+.
  BpttSimulation = Callable[
      [InteractionParams, Any, SimulationAux],
      tuple[BpttResults, SimulationAux],
  ]