123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638 |
- import numpy as np
- from pathlib import Path
- from curvature_assembly import oriented_particle, pytree_transf, file_management
- from curvature_assembly.data_protocols import *
- import json
- from enum import Enum
- from typing import TypeVar
- from functools import partial
- import jax
- from dataclasses import dataclass
- import copy
- from abc import ABC, abstractmethod
# Generic type variable (kept for annotations elsewhere in the module).
T = TypeVar("T")

# --- JSON parameter files ---------------------------------------------------
SIMULATION_PARAMS_FILENAME = "simulation_params.json"
INTERACTION_PARAMS_FILENAME = "interaction_params.json"
PARAMS_GRAD_FILENAME = "param_grad.json"
PARAMS_GRAD_CLIPPED_FILENAME = "param_grad_clipped.json"
NEIGHBOR_LIST_PARAMS_FILENAME = "neighbor_list_params.json"

# --- numerical result files -------------------------------------------------
COST_FILENAME = "cost.dat"
SIMULATION_LOG_FILENAME = "simulation_log.npz"
COORD_HISTORY_FILENAME = "coord_history.npy"
ORIENT_HISTORY_FILENAME = "orient_history.npy"
BOX_FILENAME = "box_size.dat"
- def coord_filename(idx: int):
- if idx is None:
- return 'coord.dat'
- return f'coord_{idx}.dat'
- def orient_filename(idx: int):
- if idx is None:
- return 'orient.dat'
- return f'orient_{idx}.dat'
- def weight_matrix_filename(idx: int):
- if idx is None:
- return 'weight_matrix.dat'
- return f'weight_matrix_{idx}.dat'
def export_simulation_params(params: SimulationParams, path: Path) -> None:
    """Serialize the attributes of ``params`` as JSON into ``path / SIMULATION_PARAMS_FILENAME``.

    Unlike most exporters in this module, no overwrite protection is applied here.
    """
    target_file = path.joinpath(SIMULATION_PARAMS_FILENAME)
    with open(target_file, 'w') as out:
        attributes = vars(params)
        json.dump(attributes, out)
def load_simulation_params(path: Path) -> dict:
    """Read ``path / SIMULATION_PARAMS_FILENAME`` and return its JSON content as a dictionary."""
    source_file = path.joinpath(SIMULATION_PARAMS_FILENAME)
    with open(source_file, 'r') as src:
        return json.load(src)
def convert_arrays_to_lists(params: InteractionParams) -> dict:
    """
    Converts array attributes of an InteractionParams instance to lists and returns params as dict.
    Used for saving params in .json files.

    Both jax arrays and plain numpy arrays are converted, because ``json.dump`` cannot
    serialize either. All other attribute values are passed through unchanged.
    """
    return {
        key: np.asarray(val).tolist() if isinstance(val, (jnp.ndarray, np.ndarray)) else val
        for key, val in vars(params).items()
    }
def convert_lists_to_arrays(params_dict: dict, force_float: bool = True) -> dict:
    """Convert list values of a dictionary to jax arrays.

    Args:
        params_dict: typically the result of ``json.load`` on a parameters file.
        force_float: if True, scalar integer values are converted to Python floats.
            Booleans are left untouched even though ``bool`` subclasses ``int``.

    Returns:
        A new dict; list values become ``jnp.array``s and all other values are kept.
    """
    array_dict = {}
    for key, val in params_dict.items():
        # bool is a subclass of int: without the extra check, JSON true/false would become 1.0/0.0.
        if force_float and isinstance(val, int) and not isinstance(val, bool):
            val = float(val)
        array_dict[key] = jnp.array(val) if isinstance(val, list) else val
    return array_dict
def export_interaction_params(params: InteractionParams, path: Path, filename: str = None) -> None:
    """Write interaction parameters as JSON into ``path``, converting jax arrays to lists.

    Args:
        params: parameters to export.
        path: destination folder.
        filename: file name inside ``path``; defaults to ``INTERACTION_PARAMS_FILENAME``.
    """
    target = path.joinpath(INTERACTION_PARAMS_FILENAME if filename is None else filename)
    file_management.overwrite_protection(target)
    with open(target, 'w') as out:
        json.dump(convert_arrays_to_lists(params), out)
def convert_enum_to_int(params: NeighborListParams) -> dict:
    """Return the attributes of ``params`` as a dict, replacing Enum members by their ``.value``."""
    return {
        name: (field_value.value if isinstance(field_value, Enum) else field_value)
        for name, field_value in vars(params).items()
    }
def export_neighbor_list_params(params: NeighborListParams, path: Path) -> None:
    """Write neighbor list parameters as JSON into ``path / NEIGHBOR_LIST_PARAMS_FILENAME``.

    Enum-valued attributes are stored by their ``.value`` (via ``convert_enum_to_int``).
    """
    target = path.joinpath(NEIGHBOR_LIST_PARAMS_FILENAME)
    file_management.overwrite_protection(target)
    with open(target, 'w') as out:
        payload = convert_enum_to_int(params)
        json.dump(payload, out)
def export_cost(cost: Array, path: Path) -> None:
    """Save the cost array as text into ``path / COST_FILENAME`` (overwrite-protected)."""
    cost_file = path.joinpath(COST_FILENAME)
    file_management.overwrite_protection(cost_file)
    np.savetxt(cost_file, cost)
# Gradients share the InteractionParams structure, so they reuse the same exporter,
# only with a different default file name.
export_param_grad = partial(export_interaction_params, filename=PARAMS_GRAD_FILENAME)
export_param_grad_clipped = partial(
    export_interaction_params,
    filename=PARAMS_GRAD_CLIPPED_FILENAME,
)
def load_interaction_params(path: Path, filename: str = None, convert_arrays=True) -> dict:
    """Read interaction parameters from a JSON file inside ``path``.

    Args:
        path: folder containing the file.
        filename: file name; defaults to ``INTERACTION_PARAMS_FILENAME``.
        convert_arrays: if True, list values become jax arrays (and integer scalars floats)
            via ``convert_lists_to_arrays``; otherwise the raw JSON dict is returned.
    """
    source = path.joinpath(filename if filename is not None else INTERACTION_PARAMS_FILENAME)
    with open(source, 'r') as src:
        params_dict = json.load(src)
    return convert_lists_to_arrays(params_dict) if convert_arrays else params_dict
# Gradient files use the interaction-parameter format; only the default file name differs.
load_param_grad = partial(load_interaction_params, filename=PARAMS_GRAD_FILENAME)
load_param_grad_clipped = partial(
    load_interaction_params,
    filename=PARAMS_GRAD_CLIPPED_FILENAME,
)
def load_cost(path: Path) -> np.ndarray:
    """Read the text-format cost array stored in ``path / COST_FILENAME``."""
    return np.loadtxt(path.joinpath(COST_FILENAME))
def save_single_config(body: rigid_body.RigidBody, folder: Path, save_idx: int = None) -> None:
    """Write the centers and orientation quaternion vectors of ``body`` as text files in ``folder``.

    File names come from ``coord_filename`` / ``orient_filename`` with ``save_idx``.
    """
    coord_path = folder.joinpath(coord_filename(save_idx)).resolve()
    np.savetxt(coord_path, body.center)
    orient_path = folder.joinpath(orient_filename(save_idx)).resolve()
    np.savetxt(orient_path, body.orientation.vec)
def load_single_config(folder: Path, save_idx: int = None) -> rigid_body.RigidBody:
    """Rebuild a RigidBody from the coordinate and orientation text files written by ``save_single_config``."""
    def _read(name: str):
        return jnp.asarray(np.loadtxt(folder.joinpath(name).resolve()))

    center = _read(coord_filename(save_idx))
    quaternion_vec = _read(orient_filename(save_idx))
    return rigid_body.RigidBody(center, rigid_body.Quaternion(quaternion_vec))
def init_config_folder_name(num: int, density: float) -> str:
    """Folder name for initial configs of ``num`` particles at ``density``, e.g. ``n100rho300`` for 0.3.

    The density is encoded in thousandths. It is rounded rather than truncated so that float
    representation error (e.g. ``1.001 * 1000 == 1000.9999999999999``) cannot shift the name
    down by one.
    """
    return f'n{num}rho{int(round(1000 * density))}'
def save_initial_config(body: rigid_body.RigidBody, density: float, idx: int, init_folder: Path) -> None:
    """Store ``body`` as initial config ``idx`` under ``init_folder / init_config_folder_name(N, density)``.

    The target folder (and missing parents) is created if needed.
    """
    num_particles = body.center.shape[0]
    save_folder = init_folder.joinpath(init_config_folder_name(num_particles, density))
    save_folder.mkdir(parents=True, exist_ok=True)
    save_single_config(body, save_folder, idx)
def load_initial_config(n: int, density: float, idx: int, init_folder: Path) -> rigid_body.RigidBody:
    """Load initial config ``idx`` of ``n`` particles at ``density`` from ``init_folder``."""
    return load_single_config(init_folder.joinpath(init_config_folder_name(n, density)), idx)
def load_multiple_initial_configs(n: int, density: float, indices: list[int], init_folder: Path) \
        -> list[rigid_body.RigidBody]:
    """Load several initial configs (one RigidBody per entry of ``indices``, same order)."""
    config_folder = init_folder.joinpath(init_config_folder_name(n, density))
    configs = []
    for config_idx in indices:
        configs.append(load_single_config(config_folder, config_idx))
    return configs
def load_multiple_initial_configs_single_object(n: int, density: float, indices: list[int], init_folder: Path,
                                                coord_rescale_factor: float = None) -> rigid_body.RigidBody:
    """Load several initial configs stacked into one RigidBody (leading axis follows ``indices``).

    If ``coord_rescale_factor`` is given, every coordinate array is multiplied by it.
    """
    save_folder = init_folder.joinpath(init_config_folder_name(n, density))

    def _load(file_name: str):
        return jnp.asarray(np.loadtxt(save_folder.joinpath(file_name).resolve()))

    coords, orients = [], []
    for config_idx in indices:
        coord_i = _load(coord_filename(config_idx))
        if coord_rescale_factor is not None:
            coord_i = coord_i * coord_rescale_factor
        coords.append(coord_i)
        orients.append(_load(orient_filename(config_idx)))
    stacked_coord = jnp.stack(coords, axis=0)
    stacked_orient = jnp.stack(orients, axis=0)
    return rigid_body.RigidBody(stacked_coord, rigid_body.Quaternion(stacked_orient))
def simulation_log_data_fields(simulation_log: SimulationLog) -> dict:
    """Collect the data fields of a simulation log, ignoring other internal attributes.

    An attribute counts as a data field when its leading dimension equals
    ``pytree_transf.data_length(simulation_log, ignore_non_array_leaves=True)``. Zero entries are
    dropped from each field, and attributes without a usable ``shape[0]`` are skipped.
    """
    fields = {}
    for name, value in vars(simulation_log).items():
        try:
            leading_dim = value.shape[0]
            if leading_dim != pytree_transf.data_length(simulation_log, ignore_non_array_leaves=True):
                continue
            # exclude zero entries that may not have been populated
            fields[name] = value[jnp.nonzero(value)]
        except (AttributeError, IndexError):
            continue
    return fields
def export_simulation_log(simulation_log: SimulationLog,
                          folder: Path) -> None:
    """Save the data fields of ``simulation_log`` into one ``.npz`` archive (``SIMULATION_LOG_FILENAME``)."""
    log_file = folder.joinpath(SIMULATION_LOG_FILENAME)
    file_management.overwrite_protection(log_file)
    np.savez(log_file, **simulation_log_data_fields(simulation_log))
def load_simulation_log(folder: Path) -> dict:
    """Load simulation log data from ``folder / SIMULATION_LOG_FILENAME``.

    Returns:
        A dict mapping field names to numpy arrays. All arrays are read eagerly, so the
        archive can be (and is) closed before returning.
    """
    # np.load on an .npz keeps the file handle open; the context manager releases it.
    with np.load(folder.joinpath(SIMULATION_LOG_FILENAME)) as npz_file:
        return dict(npz_file)
def export_state_history(state_history: SimulationStateHistory, folder: Path) -> None:
    """Save the populated frames of the coordinate and orientation histories as ``.npy`` files.

    A frame is treated as populated when the norm of its coordinates (over the last two axes)
    is nonzero; the same frame selection is applied to both histories.
    """
    frame_norms = jnp.linalg.norm(state_history.coord, axis=(-2, -1))
    populated = jnp.nonzero(frame_norms)
    np.save(folder.joinpath(COORD_HISTORY_FILENAME), state_history.coord[populated])
    np.save(folder.joinpath(ORIENT_HISTORY_FILENAME), state_history.orient[populated])
def load_state_history(folder: Path) -> tuple[np.ndarray, np.ndarray]:
    """Load simulation state history data.

    Returns:
        ``(coord, orient)``: the arrays stored in ``COORD_HISTORY_FILENAME`` and
        ``ORIENT_HISTORY_FILENAME`` inside ``folder`` (as written by ``export_state_history``).
    """
    coord = np.load(folder.joinpath(COORD_HISTORY_FILENAME))
    orient = np.load(folder.joinpath(ORIENT_HISTORY_FILENAME))
    return coord, orient
def export_simulation_state(body: rigid_body.RigidBody,
                            simulation_params: SimulationParams,
                            interaction_params: InteractionParams,
                            folder: Path,
                            idx: int) -> None:
    """Export config data along with simulation and interaction parameters used in simulation.

    Writes into ``folder``:
      - simulation params (``SIMULATION_PARAMS_FILENAME``),
      - interaction params as ``<stem>_frame{idx}<suffix>``, derived from ``INTERACTION_PARAMS_FILENAME``,
      - ``coord_frame{idx}.dat``, ``orient_frame{idx}.dat`` and ``weight_matrix_frame{idx}.dat``
        (weight matrices reshaped to 9 values per row).
    """
    export_simulation_params(simulation_params, folder)
    # export_interaction_params has no ``idx`` parameter. Encode the frame index in the file name
    # instead, matching the ``_frame{idx}`` pattern below, so each frame gets its own file under
    # the exporter's overwrite protection.
    params_file = Path(INTERACTION_PARAMS_FILENAME)
    export_interaction_params(interaction_params, folder,
                              filename=f'{params_file.stem}_frame{idx}{params_file.suffix}')
    np.savetxt(folder.joinpath(f'coord_frame{idx}.dat').resolve(), body.center)
    np.savetxt(folder.joinpath(f'orient_frame{idx}.dat').resolve(), body.orientation.vec)
    np.savetxt(folder.joinpath(f'weight_matrix_frame{idx}.dat').resolve(),
               oriented_particle.get_weight_matrices(body.orientation, interaction_params.eigvals).reshape(-1, 9))
def direct_visualization_export(body: rigid_body.RigidBody, eigvals: Array, export_folder: Path):
    """Write ``coord.dat`` and ``weight_matrix.dat`` for ``body`` into ``export_folder``.

    Weight matrices are reshaped to 9 values per row (presumably one flattened 3x3 matrix per particle).
    """
    flat_weights = oriented_particle.get_weight_matrices(body.orientation, eigvals).reshape(-1, 9)
    np.savetxt(export_folder.joinpath('coord.dat').resolve(), body.center)
    np.savetxt(export_folder.joinpath('weight_matrix.dat').resolve(), flat_weights)
def get_config_for_visualization(results_folder: Path,
                                 export_folder: Path,
                                 idx: int = -1,
                                 zero_cm: bool = True) -> None:
    """Export one frame of a stored trajectory for visualization.

    Args:
        results_folder: folder with interaction params and the state history.
        export_folder: destination for ``coord.dat``, ``orient.dat`` and ``weight_matrix.dat``.
        idx: frame index into the stored history (default: last frame).
        zero_cm: if True, the frame's coordinates are shifted to zero center of mass.
    """
    interaction_params = load_interaction_params(results_folder)
    coord_history, orient_history = load_state_history(results_folder)
    # Select a single frame: np.savetxt accepts only 1-D/2-D arrays, and the weight
    # matrices are computed for this frame only.
    coord = coord_history[idx]
    orient = orient_history[idx]
    if zero_cm:
        coord = cannonicalize_cm(coord)
    weight_matrix = oriented_particle.get_weight_matrices(
        rigid_body.Quaternion(orient), interaction_params['eigvals']).reshape(-1, 9)
    np.savetxt(export_folder.joinpath('coord.dat').resolve(), coord)
    np.savetxt(export_folder.joinpath('orient.dat').resolve(), orient)
    np.savetxt(export_folder.joinpath('weight_matrix.dat').resolve(), weight_matrix)
def cannonicalize_cm(coord: Array) -> Array:
    """Shift coordinates so their mean over axis -2 (the particle axis) is zero.

    Leading batch/time axes are preserved, e.g. ``(T, N, dim)`` inputs are centered per frame.
    """
    center_of_mass = jnp.mean(coord, axis=-2, keepdims=True)
    return coord - center_of_mass
def prepare_animation_data(results_folder: Path, interaction_params: dict,
                           get_every: int = 1, zero_cm=True) -> dict[str, np.ndarray]:
    """Build per-frame animation arrays from a stored trajectory.

    Centering (if ``zero_cm``) is applied before sub-sampling every ``get_every``-th frame.

    Returns:
        dict with ``'coord'``, ``'weight_matrix'`` (9 values per particle and frame) and ``'eigensystem'``.
    """
    coord_history, orient_history = load_state_history(results_folder)
    if zero_cm:
        coord_history = cannonicalize_cm(coord_history)
    frames = slice(None, None, get_every)
    coord_history = coord_history[frames]
    orient_history = orient_history[frames]

    eigvals = interaction_params['eigvals']
    weight_matrix_hist = jax.vmap(
        lambda quat_vec: oriented_particle.get_weight_matrices(
            rigid_body.Quaternion(quat_vec), eigvals).reshape(-1, 9)
    )(orient_history)
    eigensystem = oriented_particle.eigensystem(rigid_body.Quaternion(orient_history))
    return {'coord': coord_history, 'weight_matrix': weight_matrix_hist, 'eigensystem': eigensystem}
def export_animation_data(anim_data: dict[str, np.ndarray], box: float, export_folder: Path) -> None:
    """Saves animation data matrices.

    Writes ``anim_coord``, ``anim_weight_matrix`` and ``anim_eigensystem`` with ``np.save``
    (which appends ``.npy``), plus the box size as a one-element text file (``BOX_FILENAME``).
    """
    for file_stem, key in (('anim_coord', 'coord'),
                           ('anim_weight_matrix', 'weight_matrix'),
                           ('anim_eigensystem', 'eigensystem')):
        np.save(export_folder.joinpath(file_stem), anim_data[key])
    box_array = np.asarray(box).reshape(1,)
    np.savetxt(export_folder.joinpath(BOX_FILENAME), box_array)
def export_cost_and_grad(cost: float,
                         grad: InteractionParams | None,
                         folder: Path,
                         idx: int) -> None:
    """
    Export gradient data into a new file and append cost function value to an existing one. If gradient data
    was not calculated, pass None to the function.

    Args:
        cost: scalar cost, appended as one formatted line to ``folder / 'cost_function.dat'``.
        grad: gradient with the InteractionParams structure, or None to skip the gradient export.
        folder: destination folder.
        idx: index encoded in the gradient file name (``<stem>_{idx}<suffix>`` of ``PARAMS_GRAD_FILENAME``).
    """
    with open(folder.joinpath('cost_function.dat'), 'a') as f:
        f.write(f'{cost: .4f}\n')
    if grad is not None:
        # export_interaction_params has no ``idx`` parameter; put the index in the file name
        # so every call writes a new gradient file.
        grad_file = Path(PARAMS_GRAD_FILENAME)
        export_interaction_params(grad, folder, filename=f'{grad_file.stem}_{idx}{grad_file.suffix}')
class OptimizationSaver:
    """Export the results of an optimization run into a fresh results folder.

    ``base_folder`` holds run-wide files (simulation params, run params, aux data, cost-function info).
    Each iteration gets its own subfolder: ``iter_0`` first, later ones created by
    ``file_management.new_folder`` from the ``iter`` stem. Per-replica results go into ``config_<i>``
    subfolders of the current iteration folder.

    A new iteration folder is started lazily. The first export after results have been written
    (``_export_results_happened``) switches to a fresh folder.
    """

    def __init__(self, folder: Path, simulation_params: SimulationParams,
                 overwrite_folder_with_no_results=False,
                 folder_num: int = None):
        # NOTE(review): ``overwrite_folder_with_no_results`` is accepted but currently unused.
        if folder_num is not None:
            self.base_folder = file_management.new_folder_with_number(folder, folder_num)
        else:
            self.base_folder = file_management.new_folder(folder)
        export_simulation_params(simulation_params, self.base_folder)
        self._export_results_happened = False
        self._export_inter_params_happened = False
        self._iter_folder = file_management.new_folder(self.base_folder.joinpath('iter_0'))

    def _get_iter_folder(self, check_happened: bool) -> Path:
        """Return the current iteration folder, first starting a new one if ``check_happened`` is True."""
        if check_happened:
            self._iter_folder = file_management.new_folder(self.base_folder.joinpath('iter'))
            self._export_results_happened = False
            self._export_inter_params_happened = False
        return self._iter_folder

    def _get_config_folder(self, config_idx: int) -> Path:
        """Return (creating if needed) ``config_<config_idx>`` inside the current iteration folder."""
        folder = self._iter_folder.joinpath(f'config_{config_idx}')
        folder.mkdir(exist_ok=True)
        return folder

    def export_interaction_params(self, interaction_params: InteractionParams) -> None:
        folder = self._get_iter_folder(self._export_results_happened)
        export_interaction_params(interaction_params, folder)
        self._export_inter_params_happened = True

    def export_param_updates(self, updates: InteractionParams) -> None:
        folder = self._get_iter_folder(self._export_results_happened)
        export_interaction_params(updates, folder, filename='interaction_param_updates.json')
        self._export_inter_params_happened = True

    def export_run_params(self, run_params: dict):
        with open(self.base_folder.joinpath('run_params.json'), 'w') as f:
            json.dump(run_params, f)

    def export_additional_simulation_data(self, data: dict):
        with open(self.base_folder.joinpath('aux_simulation_data.json'), 'w') as f:
            json.dump(data, f)

    def export_cost_function_info(self, cost_fn):
        with open(self.base_folder.joinpath('cost_function_info.dat'), 'w') as f:
            f.write(str(cost_fn))

    def export_results(self, bptt_results: BpttResults, aux: SimulationAux) -> None:
        """Export the results of a single (unbatched) simulation into the current iteration folder."""
        folder = self._get_iter_folder(self._export_results_happened)
        try:
            export_cost(bptt_results.cost, folder)
            export_param_grad(bptt_results.grad, folder)
            export_simulation_log(aux.log, folder)
            export_state_history(aux.state_history, folder)
            self._export_results_happened = True
        except ValueError as err:
            # Keep the original error as the cause so the real failure stays visible.
            raise ValueError('For exporting multiple results, use method "export_multiple_results".') from err

    def export_multiple_results(self,
                                bptt_results: BpttResults,
                                aux: SimulationAux) -> None:
        """Split batched results per replica and export each into its own ``config_<i>`` folder."""
        self._get_iter_folder(self._export_results_happened)
        bptt_results_list = pytree_transf.split_to_list(bptt_results)
        aux_list = pytree_transf.split_to_list(aux)
        for config_idx, (result, a) in enumerate(zip(bptt_results_list, aux_list)):
            folder = self._get_config_folder(config_idx)
            export_cost(result.cost, folder)
            export_param_grad(result.grad, folder)
            export_simulation_log(a.log, folder)
            export_state_history(a.state_history, folder)
        self._export_results_happened = True

    def export_clipped_gradients(self, grad_clipped: InteractionParams):
        """Export per-replica clipped gradients into the ``config_<i>`` folders of the current iteration."""
        grad_clipped_list = pytree_transf.split_to_list(grad_clipped)
        for config_idx, grad in enumerate(grad_clipped_list):
            folder = self._get_config_folder(config_idx)
            export_param_grad_clipped(grad, folder)
class NoResultsError(Exception):
    """Raised when an optimization results folder cannot be found."""
class OptimizationLoader:
    """Convenience class to load results of an optimization simulation.

    Reads the layout ``base_folder/iter_<k>/[config_<i>/]`` used by ``OptimizationSaver``.
    """

    def __init__(self, folder: Path):
        self.base_folder = folder.resolve()
        if not self.base_folder.exists():
            raise NoResultsError(f"Results folder {self.base_folder} does not exist.")

    def get_results_folder(self, iter_idx: int, config_idx: int = None):
        """Folder of iteration ``iter_idx`` or, if ``config_idx`` is given, of one of its configs.

        A negative ``iter_idx`` indexes into ``all_iter_indices()`` (e.g. -1 is the last one).
        """
        if iter_idx < 0:
            iter_idx = self.all_iter_indices()[iter_idx]
        iter_folder = self.base_folder.joinpath(f'iter_{iter_idx}')
        if config_idx is None:
            return iter_folder
        return iter_folder.joinpath(f'config_{config_idx}')

    def last_iter_idx(self):
        """Last entry of ``all_iter_indices()``, or 0 when there are no iterations."""
        try:
            return self.all_iter_indices()[-1]
        except IndexError:
            return 0

    def all_config_indices(self, iter_idx: int = None) -> list:
        """Sorted config indices found in iteration ``iter_idx`` (default: last iteration)."""
        if iter_idx is None:
            iter_idx = self.last_iter_idx()
        all_directory_nums = []
        for folder in self.get_results_folder(iter_idx).glob('config_*'):
            _, dir_num = file_management.split_base_and_num(folder.name, sep='_', no_num_return=0)
            all_directory_nums.append(dir_num)
        return sorted(all_directory_nums)

    def num_replicas(self):
        """Number of configs in the last iteration."""
        return len(self.all_config_indices())

    def all_iter_indices(self) -> list:
        """Sorted indices of ``iter_*`` folders that are non-empty.

        A folder counts as non-empty when ``file_management.recursive_dir_empty`` with
        ``ignore_top_level_files=True`` reports it as not empty.
        """
        all_directory_nums = []
        for folder in self.base_folder.glob('iter_*'):
            _, dir_num = file_management.split_base_and_num(folder.name, sep='_', no_num_return=0)
            if not file_management.recursive_dir_empty(folder, ignore_top_level_files=True):
                all_directory_nums.append(dir_num)
        return sorted(all_directory_nums)

    def load_simulation_params(self, iter_idx: int = None, config_idx: int = None) -> dict:
        """Simulation params are stored once per run; the index arguments are accepted but ignored."""
        return load_simulation_params(self.base_folder)

    def box_size_at_number_density(self):
        simulation_params = self.load_simulation_params()
        return oriented_particle.box_size_at_number_density(simulation_params["num"],
                                                            simulation_params["density"],
                                                            spatial_dimension=3)

    def box_size_at_ellipsoid_density(self, iter_idx: int = None):
        simulation_params = self.load_simulation_params()
        if iter_idx is None:
            iter_idx = self.last_iter_idx()
        interaction_params = self.load_interaction_params(iter_idx)
        return oriented_particle.box_size_at_ellipsoid_density(simulation_params["num"],
                                                               simulation_params["density"],
                                                               interaction_params["eigvals"])

    def box_size(self, iter_idx: int = None):
        """Box size at iteration ``iter_idx``.

        Uses the number-density formula if all particle volumes are ~1 (atol 1e-4), otherwise the
        ellipsoid-density formula.
        """
        if iter_idx is None:
            iter_idx = self.last_iter_idx()
        interaction_params = self.load_interaction_params(iter_idx)
        particle_volume = oriented_particle.ellipsoid_volume(interaction_params["eigvals"])
        if jnp.all(jnp.isclose(particle_volume, 1., atol=1e-4)):
            return self.box_size_at_number_density()
        return self.box_size_at_ellipsoid_density(iter_idx=iter_idx)

    def load_additional_simulation_data(self) -> dict:
        with open(self.base_folder.joinpath('aux_simulation_data.json'), 'r') as f:
            data = json.load(f)
        return data

    def load_run_params(self) -> dict:
        with open(self.base_folder.joinpath('run_params.json'), 'r') as f:
            run_params = json.load(f)
        return run_params

    def load_interaction_params(self, iter_idx: int, config_idx: int = None, convert_arrays=True) -> dict:
        """Interaction params live in the iteration folder; ``config_idx`` is accepted but ignored."""
        return load_interaction_params(self.get_results_folder(iter_idx), convert_arrays=convert_arrays)

    def load_multiple_interaction_params(self, iter_indices: list = None) -> dict:
        if iter_indices is None:
            iter_indices = self.all_iter_indices()
        return pytree_transf.stack([self.load_interaction_params(iter_idx) for iter_idx in iter_indices])

    def load_gradient(self, iter_idx: int, config_idx: int = None) -> dict:
        return load_param_grad(self.get_results_folder(iter_idx, config_idx))

    def load_clipped_gradient(self, iter_idx: int, config_idx: int = None) -> dict:
        return load_param_grad_clipped(self.get_results_folder(iter_idx, config_idx))

    def load_multiple_gradients(self, iter_idx: int, config_indices: list = None) -> list[dict]:
        """Gradients of several configs of iteration ``iter_idx`` (default: all configs of that iteration)."""
        if config_indices is None:
            # Enumerate the configs of the requested iteration, not of the last one.
            config_indices = self.all_config_indices(iter_idx)
        return [self.load_gradient(iter_idx, config_idx) for config_idx in config_indices]

    def load_cost(self, iter_idx: int, config_idx: int = None) -> np.ndarray:
        return load_cost(self.get_results_folder(iter_idx, config_idx))

    def load_cost_all_config(self, iter_idx: int):
        """Costs of all configs of iteration ``iter_idx``, stacked along a new leading axis."""
        config_indices = self.all_config_indices(iter_idx)
        return np.stack([self.load_cost(iter_idx, config_idx) for config_idx in config_indices])

    def load_cost_all(self) -> np.ndarray:
        """Costs of all configs of all iterations, shape ``(n_iter, n_config, ...)``."""
        return np.stack([self.load_cost_all_config(iter_idx) for iter_idx in self.all_iter_indices()])

    def load_simulation_log(self, iter_idx: int, config_idx: int = None) -> dict:
        return load_simulation_log(self.get_results_folder(iter_idx, config_idx))

    def load_body(self, iter_idx: int, config_idx: int, time_idx: int) -> rigid_body.RigidBody:
        trajectory = self.load_trajectory(iter_idx=iter_idx, config_idx=config_idx)
        return trajectory[time_idx]

    def load_multiple_bodies(self, iter_idx: int, time_idx: int = -1, config_indices: list[int] = None):
        if config_indices is None:
            config_indices = self.all_config_indices(iter_idx)
        return pytree_transf.stack([self.load_body(iter_idx, config_idx, time_idx) for config_idx in config_indices])

    def load_trajectory(self, iter_idx: int, config_idx: int) -> rigid_body.RigidBody:
        coord, orient = load_state_history(self.get_results_folder(iter_idx, config_idx))
        return rigid_body.RigidBody(jnp.asarray(coord), rigid_body.Quaternion(jnp.asarray(orient)))

    def load_multiple_config_trajectories(self, iter_idx: int, config_indices: list[int]) -> rigid_body.RigidBody:
        return pytree_transf.stack([self.load_trajectory(iter_idx, config_idx) for config_idx in config_indices])

    def load_all_iter_trajectories(self, config_idx: int) -> rigid_body.RigidBody:
        iter_indices = self.all_iter_indices()
        return pytree_transf.stack([self.load_trajectory(iter_idx, config_idx) for iter_idx in iter_indices])

    def get_config_for_visualization(self, iter_idx: int, config_idx: int, export_folder: Path,
                                     frame_idx: int = -1, zero_cm: bool = True):
        return get_config_for_visualization(self.get_results_folder(iter_idx, config_idx), export_folder,
                                            idx=frame_idx, zero_cm=zero_cm)

    def export_animation_data(self, iter_idx: int, config_idx: int, export_folder: Path,
                              get_every: int = 1, zero_cm: bool = False) -> None:
        interaction_params = self.load_interaction_params(iter_idx)
        anim_data = prepare_animation_data(self.get_results_folder(iter_idx, config_idx), interaction_params,
                                           get_every=get_every, zero_cm=zero_cm)
        box_size = self.box_size(iter_idx=iter_idx)
        export_animation_data(anim_data, box_size, export_folder)
@dataclass(frozen=True)
class PlotData(ABC):
    """Abstract class for saving and loading plot data. Children should add attributes for this data.

    Data is cached as JSON in ``<results_path.parent>/<cluster_data_folder>/<results_path.name>.json``.
    Array-valued attributes (jax or numpy) are stored as lists and loaded back as jax arrays.
    """
    run_params: dict
    results_path: Path

    @property
    @abstractmethod
    def cluster_data_folder(self) -> str:
        """Subclasses must define a folder name for saving/loading data."""
        pass

    @classmethod
    @abstractmethod
    def calculate_data(cls, results_folder: Path, **kwargs) -> "PlotData":
        """A method that fills the class with data."""
        pass

    @classmethod
    def get_save_path(cls, results_path) -> Path:
        """Return the JSON cache path for ``results_path``, creating the cache folder if needed."""
        # NOTE(review): read on the class, so subclasses must override ``cluster_data_folder`` with a
        # plain class attribute (str). An overriding @property would yield a property object here.
        save_path_base = results_path.parent / cls.cluster_data_folder
        save_path_base.mkdir(exist_ok=True)
        return save_path_base / f'{results_path.name}.json'

    def save(self) -> None:
        """Write all attributes to the JSON cache (arrays become nested lists)."""
        # A new dict is enough: values are replaced, never mutated, so no deepcopy is needed.
        self_as_dict = {
            key: np.asarray(val).tolist() if isinstance(val, (jnp.ndarray, np.ndarray)) else val
            for key, val in vars(self).items()
        }
        self_as_dict['results_path'] = str(self_as_dict['results_path'])
        with open(self.get_save_path(self.results_path), 'w') as f:
            json.dump(self_as_dict, f)

    @classmethod
    def load(cls, results_path: Path) -> "PlotData":
        """Rebuild an instance from the JSON cache.

        Raises:
            FileNotFoundError: if no cache exists for ``results_path``.
        """
        cluster_data_path = cls.get_save_path(results_path)
        if not cluster_data_path.exists():
            raise FileNotFoundError("Clustering results for the given path not yet exported.")
        with open(cluster_data_path, 'r') as f:
            cls_as_dict = json.load(f)
        for key in cls_as_dict:
            if isinstance(cls_as_dict[key], list):
                cls_as_dict[key] = jnp.asarray(cls_as_dict[key])
        cls_as_dict['results_path'] = Path(cls_as_dict['results_path'])
        return cls(**cls_as_dict)

    @classmethod
    def get_data(cls,
                 results_folder: Path,
                 recalculate: bool = False,
                 **calculate_data_kwargs) -> "PlotData":
        """Load cached data, or calculate and cache it if missing or if ``recalculate`` is True."""
        if recalculate:
            print(f'Recalculating results for {results_folder}...')
        else:
            try:
                return cls.load(results_folder)
            except FileNotFoundError:
                print(f'Results for folder {results_folder} not yet exported, calculating...')
        csd = cls.calculate_data(results_folder, **calculate_data_kwargs)
        csd.save()
        return csd
def figure_file_name(base_file_name: str, results_folder: Path, *, iter_idx: int = None, config_idx: int = None,
                     figure_folder: Path = Path("/home/andraz/CurvatureAssemblyFigures")):
    """Build the output path of a figure, creating all folders on the way.

    Under ``figure_folder/<results_base>_<results_num>/``, with ``base``/``suffix`` from ``base_file_name``:
      - neither index: ``<base><suffix>``
      - iter only:     ``<base>/<base>_iter{iter_idx}<suffix>``
      - config only:   ``<base>/<base>_config{config_idx}<suffix>``
      - both:          ``<base>/iter{iter_idx}/<base>_iter{iter_idx}_config{config_idx}<suffix>``

    NOTE(review): the default ``figure_folder`` is a user-specific absolute path; pass one explicitly
    on other machines.
    """
    # parents=True so a custom, not-yet-existing nested figure folder works too.
    figure_folder.mkdir(parents=True, exist_ok=True)
    results_folder_base_name, results_idx = file_management.split_base_and_num(results_folder.name, sep='_',
                                                                                no_num_return='')
    base = Path(base_file_name).stem
    suffix = Path(base_file_name).suffix
    fig_folder = figure_folder.joinpath(f'{results_folder_base_name}_{results_idx}')
    fig_folder.mkdir(exist_ok=True)
    if iter_idx is None and config_idx is None:
        return fig_folder.joinpath(f'{base}{suffix}')
    fig_folder = fig_folder.joinpath(f'{base}')
    fig_folder.mkdir(exist_ok=True)
    if config_idx is None:
        return fig_folder.joinpath(f'{base}_iter{iter_idx}{suffix}')
    if iter_idx is None:
        return fig_folder.joinpath(f'{base}_config{config_idx}{suffix}')
    fig_folder = fig_folder.joinpath(f'iter{iter_idx}')
    fig_folder.mkdir(exist_ok=True)
    return fig_folder.joinpath(f'{base}_iter{iter_idx}_config{config_idx}{suffix}')
|