"""Structural interfaces (``typing.Protocol``) used by the simulation code.

These protocols describe the parameter containers, logging/history objects and
result types that a differentiable simulation (``BpttSimulation``) consumes and
produces. "Bptt" presumably stands for backpropagation through time — the
results carry a ``cost`` and a ``grad`` with respect to the interaction
parameters.
"""
from __future__ import annotations

from typing import Any, Callable, Protocol

import jax.numpy as jnp
from jax_md import rigid_body

# Shorthand for JAX arrays used throughout these interfaces.
Array = jnp.ndarray


class SimulationParams(Protocol):
    """
    Interface for a container of simulation parameters, i.e. simulation
    parameters over which we do not intend to differentiate.
    """

    # NOTE(review): the meanings below are inferred from field names — confirm.
    num: int  # presumably the number of simulated bodies
    density: float  # presumably the system density
    simulation_steps: int  # presumably the number of integration steps
    dt: float  # presumably the integration time step
    config_every: int  # presumably the step interval between recorded configurations


class InteractionParams(Protocol):
    """
    Protocol for a container of interaction parameters.

    Gradient of the simulation should be taken over these values.
    """

    eigvals: Array


class NeighborListParams(Protocol):
    """Protocol for a container of neighbor list parameters."""


class SimulationLog(Protocol):
    """Protocol class for storing data during the simulation."""

    # NOTE(review): presumably the number of valid entries recorded so far — confirm.
    current_len: Array

    def calculate_values(
        self,
        state: Any,
        energy_fn: Callable,
        ellipsoid_mass: rigid_body.RigidBody,
        kT: float,
        **params: Any,
    ):
        """Compute the quantities to be logged for the given simulation ``state``.

        Args:
            state: The current simulation state (its concrete type is not fixed here).
            energy_fn: Energy function of the system.
            ellipsoid_mass: Mass of the rigid bodies, expressed as a
                ``jax_md.rigid_body.RigidBody``.
            kT: Temperature in energy units.
            **params: Extra keyword arguments, presumably forwarded to ``energy_fn``
                (TODO confirm against implementations).
        """
        ...

    def update(self, *args: Any):
        """Record new values in the log; the positional arguments are implementation-defined."""
        ...

    def revert_last_nsteps(self, nsteps: int):
        """Undo the last ``nsteps`` recorded entries (exact semantics implementation-defined)."""
        ...


class SimulationStateHistory(Protocol):
    """Protocol for a record of simulation states over time."""

    # NOTE(review): presumably positions and orientations of the rigid bodies
    # per recorded step — confirm shapes against implementations.
    coord: Array
    orient: Array
    # Presumably the number of valid recorded entries — confirm.
    current_len: Array

    def revert_last_nsteps(self, nsteps: int):
        """Undo the last ``nsteps`` recorded states (exact semantics implementation-defined)."""
        ...


class SimulationAux(Protocol):
    """Protocol for auxiliary data threaded through a ``BpttSimulation`` call.

    Bundles the simulation log and the state history.
    """

    log: SimulationLog
    state_history: SimulationStateHistory

    def reset_empty(self) -> SimulationAux:
        """Return an auxiliary-data object with its contents reset.

        Whether this returns a new object or a reset copy is left to the
        implementation.
        """
        ...


class BpttResults(Protocol):
    """Protocol for the outcome of a differentiable simulation run."""

    cost: Array  # scalar objective value (presumably; confirm)
    grad: InteractionParams  # gradient of ``cost`` with respect to the interaction parameters


# Signature of a differentiable simulation:
# (interaction params, <implementation-defined argument>, aux) -> (results, updated aux).
# NOTE(review): the middle ``Any`` argument is not specified here — presumably a
# simulation state or RNG key; confirm against implementations.
BpttSimulation = Callable[
    [InteractionParams, Any, SimulationAux], tuple[BpttResults, SimulationAux]
]