io_functions.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638
  1. import numpy as np
  2. from pathlib import Path
  3. from curvature_assembly import oriented_particle, pytree_transf, file_management
  4. from curvature_assembly.data_protocols import *
  5. import json
  6. from enum import Enum
  7. from typing import TypeVar
  8. from functools import partial
  9. import jax
  10. from dataclasses import dataclass
  11. import copy
  12. from abc import ABC, abstractmethod
  13. T = TypeVar('T')
  14. SIMULATION_PARAMS_FILENAME = 'simulation_params.json'
  15. INTERACTION_PARAMS_FILENAME = 'interaction_params.json'
  16. PARAMS_GRAD_FILENAME = 'param_grad.json'
  17. PARAMS_GRAD_CLIPPED_FILENAME = 'param_grad_clipped.json'
  18. NEIGHBOR_LIST_PARAMS_FILENAME = 'neighbor_list_params.json'
  19. COST_FILENAME = 'cost.dat'
  20. SIMULATION_LOG_FILENAME = 'simulation_log.npz'
  21. COORD_HISTORY_FILENAME = 'coord_history.npy'
  22. ORIENT_HISTORY_FILENAME = 'orient_history.npy'
  23. BOX_FILENAME = 'box_size.dat'
  24. def coord_filename(idx: int):
  25. if idx is None:
  26. return 'coord.dat'
  27. return f'coord_{idx}.dat'
  28. def orient_filename(idx: int):
  29. if idx is None:
  30. return 'orient.dat'
  31. return f'orient_{idx}.dat'
  32. def weight_matrix_filename(idx: int):
  33. if idx is None:
  34. return 'weight_matrix.dat'
  35. return f'weight_matrix_{idx}.dat'
  36. def export_simulation_params(params: SimulationParams, path: Path) -> None:
  37. """Export simulation parameters as a dictionary."""
  38. with open(path.joinpath(SIMULATION_PARAMS_FILENAME), 'w') as f:
  39. json.dump(vars(params), f)
  40. def load_simulation_params(path: Path) -> dict:
  41. """Load simulation parameters as a dictionary."""
  42. with open(path.joinpath(SIMULATION_PARAMS_FILENAME), 'r') as f:
  43. params_dict = json.load(f)
  44. return params_dict
  45. def convert_arrays_to_lists(params: InteractionParams) -> dict:
  46. """
  47. Converts jax arrays in InteractionParams instance to lists and returns params as dict.
  48. Used for saving params in .json files.
  49. """
  50. no_array_dict = {}
  51. for key, val in vars(params).items():
  52. no_array_dict[key] = np.asarray(val).tolist() if isinstance(val, jnp.ndarray) else val
  53. return no_array_dict
  54. def convert_lists_to_arrays(params_dict: dict, force_float: bool = True) -> dict:
  55. """Converts list in a dictionary to jax arrays."""
  56. array_dict = {}
  57. for key, val in params_dict.items():
  58. if force_float and isinstance(val, int):
  59. val = float(val)
  60. array_dict[key] = jnp.array(val) if isinstance(val, list) else val
  61. return array_dict
  62. def export_interaction_params(params: InteractionParams, path: Path, filename: str = None) -> None:
  63. """Export interaction parameters as a dictionary. Jax arrays are converted to lists."""
  64. if filename is None:
  65. filename = INTERACTION_PARAMS_FILENAME
  66. filename = path.joinpath(filename)
  67. file_management.overwrite_protection(filename)
  68. with open(filename, 'w') as f:
  69. json.dump(convert_arrays_to_lists(params), f)
  70. def convert_enum_to_int(params: NeighborListParams) -> dict:
  71. no_array_dict = {}
  72. for key, val in vars(params).items():
  73. no_array_dict[key] = val.value if isinstance(val, Enum) else val
  74. return no_array_dict
  75. def export_neighbor_list_params(params: NeighborListParams, path: Path) -> None:
  76. """Export neighbor list parameters as a dictionary."""
  77. filename = path.joinpath(NEIGHBOR_LIST_PARAMS_FILENAME)
  78. file_management.overwrite_protection(filename)
  79. with open(filename, 'w') as f:
  80. json.dump(convert_enum_to_int(params), f)
  81. def export_cost(cost: Array, path: Path) -> None:
  82. """Export cost function array."""
  83. filename = path.joinpath(COST_FILENAME)
  84. file_management.overwrite_protection(filename)
  85. np.savetxt(filename, cost)
  86. export_param_grad = partial(export_interaction_params, filename=PARAMS_GRAD_FILENAME)
  87. export_param_grad_clipped = partial(export_interaction_params, filename=PARAMS_GRAD_CLIPPED_FILENAME)
  88. def load_interaction_params(path: Path, filename: str = None, convert_arrays=True) -> dict:
  89. """Load interaction parameters as a dictionary."""
  90. if filename is None:
  91. filename = INTERACTION_PARAMS_FILENAME
  92. with open(path.joinpath(filename), 'r') as f:
  93. params_dict = json.load(f)
  94. if convert_arrays:
  95. return convert_lists_to_arrays(params_dict)
  96. return params_dict
  97. load_param_grad = partial(load_interaction_params, filename=PARAMS_GRAD_FILENAME)
  98. load_param_grad_clipped = partial(load_interaction_params, filename=PARAMS_GRAD_CLIPPED_FILENAME)
  99. def load_cost(path: Path) -> np.ndarray:
  100. """Load cost function array."""
  101. filename = path.joinpath(COST_FILENAME)
  102. return np.loadtxt(filename)
  103. def save_single_config(body: rigid_body.RigidBody, folder: Path, save_idx: int = None) -> None:
  104. """General function for saving single config data."""
  105. np.savetxt(folder.joinpath(coord_filename(save_idx)).resolve(), body.center)
  106. np.savetxt(folder.joinpath(orient_filename(save_idx)).resolve(), body.orientation.vec)
  107. def load_single_config(folder: Path, save_idx: int = None) -> rigid_body.RigidBody:
  108. """General function for loading single config data."""
  109. coord = jnp.asarray(np.loadtxt(folder.joinpath(coord_filename(save_idx)).resolve()))
  110. orient = jnp.asarray(np.loadtxt(folder.joinpath(orient_filename(save_idx)).resolve()))
  111. return rigid_body.RigidBody(coord, rigid_body.Quaternion(orient))
  112. def init_config_folder_name(num: int, density: float) -> str:
  113. return f'n{num}rho{int(1000 * density)}'
  114. def save_initial_config(body: rigid_body.RigidBody, density: float, idx: int, init_folder: Path) -> None:
  115. """Save the initial RigidBody configuration with a given density and index."""
  116. save_folder = init_folder.joinpath(init_config_folder_name(body.center.shape[0], density))
  117. save_folder.mkdir(exist_ok=True, parents=True)
  118. save_single_config(body, save_folder, idx)
  119. def load_initial_config(n: int, density: float, idx: int, init_folder: Path) -> rigid_body.RigidBody:
  120. """Load the initial RigidBody configuration with a given density and index."""
  121. save_folder = init_folder.joinpath(init_config_folder_name(n, density))
  122. return load_single_config(save_folder, idx)
  123. def load_multiple_initial_configs(n: int, density: float, indices: list[int], init_folder: Path) \
  124. -> list[rigid_body.RigidBody]:
  125. """Load multiple initial RigidBody configurations with a given density and a list of indices."""
  126. save_folder = init_folder.joinpath(init_config_folder_name(n, density))
  127. return [load_single_config(save_folder, idx) for idx in indices]
  128. def load_multiple_initial_configs_single_object(n: int, density: float, indices: list[int], init_folder: Path,
  129. coord_rescale_factor: float = None) -> rigid_body.RigidBody:
  130. """Load multiple initial RigidBody configurations with a given density and a list of indices as a single object."""
  131. save_folder = init_folder.joinpath(init_config_folder_name(n, density))
  132. coord = []
  133. orient = []
  134. for i in indices:
  135. coord_i = jnp.asarray(np.loadtxt(save_folder.joinpath(coord_filename(i)).resolve()))
  136. if coord_rescale_factor is not None:
  137. coord_i *= coord_rescale_factor
  138. coord.append(coord_i)
  139. orient.append(jnp.asarray(np.loadtxt(save_folder.joinpath(orient_filename(i)).resolve())))
  140. return rigid_body.RigidBody(jnp.stack(coord, axis=0), rigid_body.Quaternion(jnp.stack(orient, axis=0)))
  141. def simulation_log_data_fields(simulation_log: SimulationLog) -> dict:
  142. """Return a dictionary of data fields in a simulation log object, ie ignoring other internal attributes."""
  143. data_dict = {}
  144. for key, val in vars(simulation_log).items():
  145. try:
  146. if val.shape[0] == pytree_transf.data_length(simulation_log, ignore_non_array_leaves=True):
  147. data_dict[key] = val[jnp.nonzero(val)] # exclude zero entries that may not have been populated
  148. except (AttributeError, IndexError):
  149. pass
  150. return data_dict
  151. def export_simulation_log(simulation_log: SimulationLog,
  152. folder: Path) -> None:
  153. """Export simulation log data in a single file."""
  154. file_management.overwrite_protection(folder.joinpath(SIMULATION_LOG_FILENAME))
  155. data_dict = simulation_log_data_fields(simulation_log)
  156. np.savez(folder.joinpath(SIMULATION_LOG_FILENAME), **data_dict)
  157. def load_simulation_log(folder: Path) -> dict:
  158. """Load simulation log data from file."""
  159. npz_file = np.load(folder.joinpath(SIMULATION_LOG_FILENAME))
  160. return dict(npz_file)
  161. def export_state_history(state_history: SimulationStateHistory, folder: Path) -> None:
  162. """Save simulation state history data."""
  163. # we exclude array indices that were not populated
  164. relevant_indices = jnp.nonzero(jnp.linalg.norm(state_history.coord, axis=(-2, -1)))
  165. np.save(folder.joinpath(COORD_HISTORY_FILENAME), state_history.coord[relevant_indices])
  166. np.save(folder.joinpath(ORIENT_HISTORY_FILENAME), state_history.orient[relevant_indices])
  167. def load_state_history(folder: Path) -> tuple[np.ndarray, np.ndarray]:
  168. """Save simulation state history data."""
  169. coord = np.load(folder.joinpath(COORD_HISTORY_FILENAME))
  170. orient = np.load(folder.joinpath(ORIENT_HISTORY_FILENAME))
  171. return coord, orient
  172. def export_simulation_state(body: rigid_body.RigidBody,
  173. simulation_params: SimulationParams,
  174. interaction_params: InteractionParams,
  175. folder: Path,
  176. idx: int) -> None:
  177. """Export config data along with simulation and interaction parameters used in simulation."""
  178. export_simulation_params(simulation_params, folder)
  179. export_interaction_params(interaction_params, folder, idx=idx)
  180. np.savetxt(folder.joinpath(f'coord_frame{idx}.dat').resolve(), body.center)
  181. np.savetxt(folder.joinpath(f'orient_frame{idx}.dat').resolve(), body.orientation.vec)
  182. np.savetxt(folder.joinpath(f'weight_matrix_frame{idx}.dat').resolve(),
  183. oriented_particle.get_weight_matrices(body.orientation, interaction_params.eigvals).reshape(-1, 9))
  184. def direct_visualization_export(body: rigid_body.RigidBody, eigvals: Array, export_folder: Path):
  185. coord = body.center
  186. weight_matrix = oriented_particle.get_weight_matrices(body.orientation, eigvals).reshape(-1, 9)
  187. np.savetxt(export_folder.joinpath(f'coord.dat').resolve(), coord)
  188. np.savetxt(export_folder.joinpath(f'weight_matrix.dat').resolve(), weight_matrix)
  189. def get_config_for_visualization(results_folder: Path,
  190. export_folder: Path,
  191. idx: int = -1,
  192. zero_cm: bool = True) -> None:
  193. interaction_params = load_interaction_params(results_folder)
  194. coord, orient = load_state_history(results_folder)
  195. if zero_cm:
  196. coord = cannonicalize_cm(coord)
  197. weight_matrix = oriented_particle.get_weight_matrices(
  198. rigid_body.Quaternion(orient[idx]), interaction_params['eigvals']).reshape(-1, 9)
  199. np.savetxt(export_folder.joinpath(f'coord.dat').resolve(), coord)
  200. np.savetxt(export_folder.joinpath(f'orient.dat').resolve(), orient)
  201. np.savetxt(export_folder.joinpath(f'weight_matrix.dat').resolve(), weight_matrix)
  202. def cannonicalize_cm(coord: Array) -> Array:
  203. cm = jnp.mean(coord, axis=-2)
  204. return coord - cm[..., None, :]
  205. def prepare_animation_data(results_folder: Path, interaction_params: dict,
  206. get_every: int = 1, zero_cm=True) -> dict[str, np.ndarray]:
  207. coord_history, orient_history = load_state_history(results_folder)
  208. if zero_cm:
  209. coord_history = cannonicalize_cm(coord_history)
  210. coord_history = coord_history[::get_every]
  211. orient_history = orient_history[::get_every]
  212. def weight_matrix_frame(quaternion_vec):
  213. return oriented_particle.get_weight_matrices(
  214. rigid_body.Quaternion(quaternion_vec), interaction_params['eigvals']).reshape(-1, 9)
  215. weight_matrix_hist = jax.vmap(weight_matrix_frame)(orient_history)
  216. eigensystem = oriented_particle.eigensystem(rigid_body.Quaternion(orient_history))
  217. return {'coord': coord_history, 'weight_matrix': weight_matrix_hist, 'eigensystem': eigensystem}
  218. def export_animation_data(anim_data: dict[str, np.ndarray], box: float, export_folder: Path) -> None:
  219. """Saves animation data matrices."""
  220. np.save(export_folder.joinpath('anim_coord'), anim_data['coord'])
  221. np.save(export_folder.joinpath('anim_weight_matrix'), anim_data['weight_matrix'])
  222. np.save(export_folder.joinpath('anim_eigensystem'), anim_data['eigensystem'])
  223. np.savetxt(export_folder.joinpath(BOX_FILENAME), np.asarray(box).reshape(1,))
  224. def export_cost_and_grad(cost: float,
  225. grad: InteractionParams | None,
  226. folder: Path,
  227. idx: int) -> None:
  228. """
  229. Export gradient data into a new file and append cost function value to an existing one. If gradient data
  230. was not calculated, pass None to the function.
  231. """
  232. with open(folder.joinpath('cost_function.dat'), 'a') as f:
  233. f.writelines(f'{cost: .4f}\n')
  234. if grad is not None:
  235. export_param_grad(grad, folder, idx=idx)
  236. class OptimizationSaver:
  237. def __init__(self, folder: Path, simulation_params: SimulationParams,
  238. overwrite_folder_with_no_results=False,
  239. folder_num: int = None):
  240. if folder_num is not None:
  241. self.base_folder = file_management.new_folder_with_number(folder, folder_num)
  242. else:
  243. self.base_folder = file_management.new_folder(folder)
  244. export_simulation_params(simulation_params, self.base_folder)
  245. self._export_results_happened = False
  246. self._export_inter_params_happened = False
  247. self._iter_folder = file_management.new_folder(self.base_folder.joinpath(f'iter_0'))
  248. def _get_iter_folder(self, check_happened: bool) -> Path:
  249. if check_happened:
  250. self._iter_folder = file_management.new_folder(self.base_folder.joinpath(f'iter'))
  251. self._export_results_happened = False
  252. self._export_inter_params_happened = False
  253. return self._iter_folder
  254. def _get_config_folder(self, config_idx: int) -> Path:
  255. folder = self._iter_folder.joinpath(f'config_{config_idx}')
  256. folder.mkdir(exist_ok=True)
  257. return folder
  258. def export_interaction_params(self, interaction_params: InteractionParams) -> None:
  259. folder = self._get_iter_folder(self._export_results_happened)
  260. export_interaction_params(interaction_params, folder)
  261. self._export_inter_params_happened = True
  262. def export_param_updates(self, updates: InteractionParams) -> None:
  263. folder = self._get_iter_folder(self._export_results_happened)
  264. export_interaction_params(updates, folder, filename='interaction_param_updates.json')
  265. self._export_inter_params_happened = True
  266. def export_run_params(self, run_params: dict):
  267. with open(self.base_folder.joinpath(f'run_params.json'), 'w') as f:
  268. json.dump(run_params, f)
  269. def export_additional_simulation_data(self, data: dict):
  270. with open(self.base_folder.joinpath(f'aux_simulation_data.json'), 'w') as f:
  271. json.dump(data, f)
  272. def export_cost_function_info(self, cost_fn):
  273. with open(self.base_folder.joinpath(f'cost_function_info.dat'), 'w') as f:
  274. f.write(str(cost_fn))
  275. def export_results(self, bptt_results: BpttResults, aux: SimulationAux) -> None:
  276. folder = self._get_iter_folder(self._export_results_happened)
  277. try:
  278. export_cost(bptt_results.cost, folder)
  279. export_param_grad(bptt_results.grad, folder)
  280. export_simulation_log(aux.log, folder)
  281. export_state_history(aux.state_history, folder)
  282. self._export_results_happened = True
  283. except ValueError:
  284. raise ValueError('For exporting multiple results, use method "export_multiple_results".')
  285. def export_multiple_results(self,
  286. bptt_results: BpttResults,
  287. aux: SimulationAux) -> None:
  288. self._get_iter_folder(self._export_results_happened)
  289. bptt_results_list = pytree_transf.split_to_list(bptt_results)
  290. aux_list = pytree_transf.split_to_list(aux)
  291. for config_idx, (result, a) in enumerate(zip(bptt_results_list, aux_list)):
  292. folder = self._get_config_folder(config_idx)
  293. export_cost(result.cost, folder)
  294. export_param_grad(result.grad, folder)
  295. export_simulation_log(a.log, folder)
  296. export_state_history(a.state_history, folder)
  297. self._export_results_happened = True
  298. def export_clipped_gradients(self, grad_clipped: InteractionParams):
  299. grad_clipped_list = pytree_transf.split_to_list(grad_clipped)
  300. for config_idx, grad in enumerate(grad_clipped_list):
  301. folder = self._get_config_folder(config_idx)
  302. export_param_grad_clipped(grad, folder)
  303. class NoResultsError(Exception):
  304. pass
  305. class OptimizationLoader:
  306. """Convenience class to load results of an optimization simulation."""
  307. def __init__(self, folder: Path):
  308. self.base_folder = folder.resolve()
  309. if not self.base_folder.exists():
  310. raise NoResultsError(f"Results folder {self.base_folder} does not exist.")
  311. def get_results_folder(self, iter_idx: int, config_idx: int = None):
  312. if iter_idx < 0:
  313. iter_idx = self.all_iter_indices()[iter_idx]
  314. if config_idx is None:
  315. return self.base_folder.joinpath(f'iter_{iter_idx}')
  316. return self.base_folder.joinpath(f'iter_{iter_idx}').joinpath(f'config_{config_idx}')
  317. def last_iter_idx(self):
  318. try:
  319. return self.all_iter_indices()[-1]
  320. except IndexError:
  321. return 0
  322. def all_config_indices(self, iter_idx: int = None) -> list:
  323. if iter_idx is None:
  324. iter_idx = self.last_iter_idx()
  325. iteration_folders = [folder for folder in self.get_results_folder(iter_idx).glob(f'config_*')]
  326. all_directory_nums = []
  327. for folder in iteration_folders:
  328. _, dir_num = file_management.split_base_and_num(folder.name, sep='_', no_num_return=0)
  329. all_directory_nums.append(dir_num)
  330. return sorted(all_directory_nums)
  331. def num_replicas(self):
  332. return len(self.all_config_indices())
  333. def all_iter_indices(self) -> list:
  334. iteration_folders = [folder for folder in self.base_folder.glob(f'iter_*')]
  335. all_directory_nums = []
  336. for folder in iteration_folders:
  337. _, dir_num = file_management.split_base_and_num(folder.name, sep='_', no_num_return=0)
  338. if not file_management.recursive_dir_empty(folder, ignore_top_level_files=True):
  339. all_directory_nums.append(dir_num)
  340. return sorted(all_directory_nums)
  341. def load_simulation_params(self, iter_idx: int = None, config_idx: int = None) -> dict:
  342. return load_simulation_params(self.base_folder)
  343. def box_size_at_number_density(self):
  344. simulation_params = self.load_simulation_params()
  345. return oriented_particle.box_size_at_number_density(simulation_params["num"],
  346. simulation_params["density"],
  347. spatial_dimension=3)
  348. def box_size_at_ellipsoid_density(self, iter_idx: int = None):
  349. simulation_params = self.load_simulation_params()
  350. if iter_idx is None:
  351. iter_idx = self.last_iter_idx()
  352. interaction_params = self.load_interaction_params(iter_idx)
  353. return oriented_particle.box_size_at_ellipsoid_density(simulation_params["num"],
  354. simulation_params["density"],
  355. interaction_params["eigvals"])
  356. def box_size(self, iter_idx: int = None):
  357. if iter_idx is None:
  358. iter_idx = self.last_iter_idx()
  359. interaction_params = self.load_interaction_params(iter_idx)
  360. particle_volume = oriented_particle.ellipsoid_volume(interaction_params["eigvals"])
  361. if jnp.all(jnp.isclose(particle_volume, 1., atol=1e-4)):
  362. return self.box_size_at_number_density()
  363. return self.box_size_at_ellipsoid_density(iter_idx=iter_idx)
  364. def load_additional_simulation_data(self) -> dict:
  365. with open(self.base_folder.joinpath(f'aux_simulation_data.json'), 'r') as f:
  366. data = json.load(f)
  367. return data
  368. def load_run_params(self) -> dict:
  369. with open(self.base_folder.joinpath(f'run_params.json'), 'r') as f:
  370. run_params = json.load(f)
  371. return run_params
  372. def load_interaction_params(self, iter_idx: int, config_idx: int = None, convert_arrays=True) -> dict:
  373. return load_interaction_params(self.get_results_folder(iter_idx), convert_arrays=convert_arrays)
  374. def load_multiple_interaction_params(self, iter_indices: list = None) -> dict:
  375. if iter_indices is None:
  376. iter_indices = self.all_iter_indices()
  377. return pytree_transf.stack([self.load_interaction_params(iter_idx) for iter_idx in iter_indices])
  378. def load_gradient(self, iter_idx: int, config_idx: int = None) -> dict:
  379. return load_param_grad(self.get_results_folder(iter_idx, config_idx))
  380. def load_clipped_gradient(self, iter_idx: int, config_idx: int = None) -> dict:
  381. return load_param_grad_clipped(self.get_results_folder(iter_idx, config_idx))
  382. def load_multiple_gradients(self, iter_idx: int, config_indices: list = None) -> list[dict, ...]:
  383. if config_indices is None:
  384. config_indices = self.all_config_indices()
  385. return [self.load_gradient(iter_idx, config_idx) for config_idx in config_indices]
  386. def load_cost(self, iter_idx: int, config_idx: int = None) -> np.ndarray:
  387. return load_cost(self.get_results_folder(iter_idx, config_idx))
  388. def load_cost_all_config(self, iter_idx: int):
  389. config_indices = self.all_config_indices()
  390. config_costs = []
  391. for config_idx in config_indices:
  392. config_costs.append(load_cost(self.get_results_folder(iter_idx, config_idx)))
  393. return np.stack(config_costs)
  394. def load_cost_all(self) -> np.ndarray:
  395. inter_indices = self.all_iter_indices()
  396. all_costs = []
  397. for iter_idx in inter_indices:
  398. all_costs.append(self.load_cost_all_config(iter_idx))
  399. return np.stack(all_costs)
  400. def load_simulation_log(self, iter_idx: int, config_idx: int = None) -> dict:
  401. return load_simulation_log(self.get_results_folder(iter_idx, config_idx))
  402. def load_body(self, iter_idx: int, config_idx: int, time_idx: int) -> rigid_body.RigidBody:
  403. trajectory = self.load_trajectory(iter_idx=iter_idx, config_idx=config_idx)
  404. return trajectory[time_idx]
  405. def load_multiple_bodies(self, iter_idx: int, time_idx: int = -1, config_indices: list[int] = None):
  406. if config_indices is None:
  407. config_indices = self.all_config_indices(iter_idx)
  408. return pytree_transf.stack([self.load_body(iter_idx, config_idx, time_idx) for config_idx in config_indices])
  409. def load_trajectory(self, iter_idx: int, config_idx: int) -> rigid_body.RigidBody:
  410. coord, orient = load_state_history(self.get_results_folder(iter_idx, config_idx))
  411. return rigid_body.RigidBody(jnp.asarray(coord), rigid_body.Quaternion(jnp.asarray(orient)))
  412. def load_multiple_config_trajectories(self, iter_idx: int, config_indices: list[int]) -> rigid_body.RigidBody:
  413. return pytree_transf.stack([self.load_trajectory(iter_idx, config_idx) for config_idx in config_indices])
  414. def load_all_iter_trajectories(self, config_idx: int) -> rigid_body.RigidBody:
  415. iter_indices = self.all_iter_indices()
  416. return pytree_transf.stack([self.load_trajectory(iter_idx, config_idx) for iter_idx in iter_indices])
  417. def get_config_for_visualization(self, iter_idx: int, config_idx: int, export_folder: Path,
  418. frame_idx: int = -1, zero_cm: bool = True):
  419. return get_config_for_visualization(self.get_results_folder(iter_idx, config_idx), export_folder,
  420. idx=frame_idx, zero_cm=zero_cm)
  421. def export_animation_data(self, iter_idx: int, config_idx: int, export_folder: Path,
  422. get_every: int = 1, zero_cm: bool = False) -> None:
  423. interaction_params = self.load_interaction_params(iter_idx)
  424. anim_data = prepare_animation_data(self.get_results_folder(iter_idx, config_idx), interaction_params,
  425. get_every=get_every, zero_cm=zero_cm)
  426. box_size = self.box_size(iter_idx=iter_idx)
  427. export_animation_data(anim_data, box_size, export_folder)
  428. @dataclass(frozen=True)
  429. class PlotData(ABC):
  430. """Abstract class for saving and loading plot data. Children should add attributes for this data."""
  431. run_params: dict
  432. results_path: Path
  433. @property
  434. @abstractmethod
  435. def cluster_data_folder(self) -> str:
  436. """Subclasses must define a folder name for saving/loading data."""
  437. pass
  438. @classmethod
  439. @abstractmethod
  440. def calculate_data(cls, results_folder: Path, **kwargs) -> "PlotData":
  441. """A method that fills the class with data."""
  442. pass
  443. @classmethod
  444. def get_save_path(cls, results_path) -> Path:
  445. results_base = results_path.parent
  446. results_filename = f'{results_path.name}.json'
  447. save_path_base = results_base / cls.cluster_data_folder
  448. save_path_base.mkdir(exist_ok=True)
  449. return save_path_base / results_filename
  450. def save(self) -> None:
  451. self_as_dict = vars(copy.deepcopy(self))
  452. for key in self_as_dict:
  453. if isinstance(self_as_dict[key], jnp.ndarray):
  454. self_as_dict[key] = np.asarray(self_as_dict[key]).tolist()
  455. self_as_dict['results_path'] = str(self_as_dict['results_path'])
  456. with open(self.get_save_path(self.results_path), 'w') as f:
  457. json.dump(self_as_dict, f)
  458. @classmethod
  459. def load(cls, results_path: Path) -> "PlotData":
  460. cluster_data_path = cls.get_save_path(results_path)
  461. if not cluster_data_path.exists():
  462. raise FileNotFoundError("Clustering results for the given path not yet exported.")
  463. with open(cluster_data_path, 'r') as f:
  464. cls_as_dict = json.load(f)
  465. for key in cls_as_dict:
  466. if isinstance(cls_as_dict[key], list):
  467. cls_as_dict[key] = jnp.asarray(cls_as_dict[key])
  468. cls_as_dict['results_path'] = Path(cls_as_dict['results_path'])
  469. return cls(**cls_as_dict)
  470. @classmethod
  471. def get_data(cls,
  472. results_folder: Path,
  473. recalculate: bool = False,
  474. **calculate_data_kwargs) -> "PlotData":
  475. if recalculate:
  476. print(f'Recalculating results for {results_folder}...')
  477. csd = cls.calculate_data(results_folder, **calculate_data_kwargs)
  478. csd.save()
  479. return csd
  480. try:
  481. csd = cls.load(results_folder)
  482. except FileNotFoundError:
  483. print(f'Results for folder {results_folder} not yet exported, calculating...')
  484. csd = cls.calculate_data(results_folder, **calculate_data_kwargs)
  485. csd.save()
  486. return csd
  487. def figure_file_name(base_file_name: str, results_folder: Path, *, iter_idx: int = None, config_idx: int = None,
  488. figure_folder: Path = Path("/home/andraz/CurvatureAssemblyFigures")):
  489. figure_folder.mkdir(exist_ok=True)
  490. results_folder_base_name, results_idx = file_management.split_base_and_num(results_folder.name, sep='_', no_num_return='')
  491. base = Path(base_file_name).stem
  492. suffix = Path(base_file_name).suffix
  493. fig_folder = figure_folder.joinpath(f'{results_folder_base_name}_{results_idx}')
  494. fig_folder.mkdir(exist_ok=True)
  495. if iter_idx is None and config_idx is None:
  496. return fig_folder.joinpath(f'{base}{suffix}')
  497. fig_folder = fig_folder.joinpath(f'{base}')
  498. fig_folder.mkdir(exist_ok=True)
  499. if config_idx is None:
  500. return fig_folder.joinpath(f'{base}_iter{iter_idx}{suffix}')
  501. if iter_idx is None:
  502. return fig_folder.joinpath(f'{base}_config{config_idx}{suffix}')
  503. fig_folder = fig_folder.joinpath(f'iter{iter_idx}')
  504. fig_folder.mkdir(exist_ok=True)
  505. return fig_folder.joinpath(f'{base}_iter{iter_idx}_config{config_idx}{suffix}')