123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- from __future__ import annotations
- from typing import Protocol, Callable, Any
- import jax.numpy as jnp
- from jax_md import rigid_body
# Shorthand alias used by the annotations in this module for JAX arrays.
Array = jnp.ndarray
class SimulationParams(Protocol):
    """Structural type for the fixed settings of a simulation run.

    Any object exposing the attributes below satisfies this protocol. These
    are configuration values that the simulation is *not* meant to be
    differentiated with respect to.
    """

    # NOTE(review): the meanings below are inferred from the names -- confirm.
    num: int  # presumably the number of particles
    density: float  # presumably the system density
    simulation_steps: int  # presumably the total number of integration steps
    dt: float  # presumably the integration time step
    config_every: int  # presumably the interval (in steps) between saved configs
class InteractionParams(Protocol):
    """Protocol for the container of differentiable interaction parameters.

    The simulation's gradient is taken with respect to these values. A
    gradient with the same structure is reported as ``BpttResults.grad``.
    Fields are annotated with the module's ``Array`` alias, like every other
    protocol in this module.
    """

    # NOTE(review): presumably eigenvalues describing the particle/interaction
    # shape -- confirm against the concrete implementation.
    eigvals: Array
class NeighborListParams(Protocol):
    """Structural type for neighbor-list configuration.

    It currently declares no members, so any object satisfies it
    structurally; it serves as a named placeholder in signatures.
    """

    ...
class SimulationLog(Protocol):
    """Structural type for an object that records data while a simulation runs."""

    # NOTE(review): presumably the number of entries recorded so far -- confirm.
    current_len: Array

    def calculate_values(
        self,
        state,
        energy_fn: Callable,
        ellipsoid_mass: rigid_body.RigidBody,
        kT: float,
        **params,
    ):
        """Compute the quantities to log (inferred from the name -- confirm).

        Receives the current ``state``, the energy function ``energy_fn``, the
        rigid-body mass ``ellipsoid_mass``, ``kT`` (presumably the thermal
        energy k_B*T) and any extra keyword ``params``. The return type is not
        specified by this protocol.
        """

    def update(self, *args):
        """Record new data; the positional arguments are implementation-defined."""

    def revert_last_nsteps(self, nsteps: int):
        """Discard the last ``nsteps`` recorded entries (inferred from the name -- confirm)."""
class SimulationStateHistory(Protocol):
    """Structural type for a record of past simulation states.

    NOTE(review): the field descriptions are inferred from the names -- confirm.
    """

    coord: Array  # presumably stored particle positions
    orient: Array  # presumably stored particle orientations
    current_len: Array  # presumably the number of entries stored so far

    def revert_last_nsteps(self, nsteps: int):
        """Roll the history back by ``nsteps`` entries (inferred from the name -- confirm)."""
class SimulationAux(Protocol):
    """Protocol for the auxiliary data carried alongside a simulation.

    This is a structural interface, not a concrete class: any object with the
    attributes and method below satisfies it (implementations may well be
    dataclasses). It bundles the simulation log and the state history, and is
    both taken and returned by a ``BpttSimulation`` call.
    """

    log: SimulationLog  # data recorded during the simulation
    state_history: SimulationStateHistory  # record of past simulation states

    def reset_empty(self) -> SimulationAux:
        """Return a ``SimulationAux``.

        NOTE(review): presumably a fresh instance whose log and state history
        are emptied -- confirm against the concrete implementation.
        """
        ...
class BpttResults(Protocol):
    """Structural type for the output of a BPTT run.

    BPTT is presumably backpropagation through time.
    """

    # NOTE(review): presumably the objective value being optimised -- confirm.
    cost: Array
    # Gradient with the same structure as the interaction parameters.
    grad: InteractionParams
# Signature of a differentiable simulation run. It takes:
#   - the interaction parameters;
#   - a second argument of unspecified type (``Any``);
#   - the current ``SimulationAux``.
# It returns the run's results together with a ``SimulationAux``.
# NOTE: this alias is evaluated eagerly at import time. The ``__future__``
# import only defers annotations, so the builtin ``tuple[...]`` subscript
# here requires Python 3.9+.
BpttSimulation = Callable[
    [InteractionParams, Any, SimulationAux],
    tuple[BpttResults, SimulationAux],
]
|