import numpy as np
from pathlib import Path
from curvature_assembly import oriented_particle, pytree_transf, file_management
from curvature_assembly.data_protocols import *
import json
from enum import Enum
from typing import TypeVar
from functools import partial
import jax
from dataclasses import dataclass
import copy
from abc import ABC, abstractmethod


T = TypeVar("T")  # generic type variable available to this module


# --- JSON parameter files ---
SIMULATION_PARAMS_FILENAME = "simulation_params.json"
INTERACTION_PARAMS_FILENAME = "interaction_params.json"
PARAMS_GRAD_FILENAME = "param_grad.json"
PARAMS_GRAD_CLIPPED_FILENAME = "param_grad_clipped.json"
NEIGHBOR_LIST_PARAMS_FILENAME = "neighbor_list_params.json"
# --- result / trajectory data files ---
COST_FILENAME = "cost.dat"
SIMULATION_LOG_FILENAME = "simulation_log.npz"
COORD_HISTORY_FILENAME = "coord_history.npy"
ORIENT_HISTORY_FILENAME = "orient_history.npy"
BOX_FILENAME = "box_size.dat"


def coord_filename(idx: int):
    if idx is None:
        return 'coord.dat'
    return f'coord_{idx}.dat'


def orient_filename(idx: int):
    if idx is None:
        return 'orient.dat'
    return f'orient_{idx}.dat'


def weight_matrix_filename(idx: int):
    if idx is None:
        return 'weight_matrix.dat'
    return f'weight_matrix_{idx}.dat'


def export_simulation_params(params: SimulationParams, path: Path) -> None:
    """Write the attributes of `params` as one JSON object into SIMULATION_PARAMS_FILENAME inside `path`."""
    target = path.joinpath(SIMULATION_PARAMS_FILENAME)
    with target.open('w') as out_file:
        json.dump(vars(params), out_file)


def load_simulation_params(path: Path) -> dict:
    """Read SIMULATION_PARAMS_FILENAME from `path` and return the decoded JSON object."""
    source = path.joinpath(SIMULATION_PARAMS_FILENAME)
    with source.open('r') as in_file:
        return json.load(in_file)


def convert_arrays_to_lists(params: InteractionParams) -> dict:
    """
    Return the attributes of `params` as a dict with array values replaced by (nested) Python lists.

    Used for saving params in .json files, which cannot hold arrays. Jax arrays, numpy arrays and numpy
    scalars are converted via ``np.asarray(...).tolist()``; every other value is passed through unchanged.
    """
    no_array_dict = {}
    for key, val in vars(params).items():
        # numpy values need converting too: json.dump raises TypeError on np.ndarray / np.float32 etc.
        if isinstance(val, (jnp.ndarray, np.ndarray, np.generic)):
            val = np.asarray(val).tolist()
        no_array_dict[key] = val
    return no_array_dict


def convert_lists_to_arrays(params_dict: dict, force_float: bool = True) -> dict:
    """
    Return a new dict in which every list value of `params_dict` is converted to a jax array.

    Args:
        params_dict: Mapping, e.g. as decoded from a JSON file.
        force_float: If True, plain ``int`` values are converted to ``float``. Booleans are kept as
            booleans even though ``bool`` is a subclass of ``int``.

    Returns:
        A new dict; non-list values other than converted ints are passed through unchanged.
    """
    array_dict = {}
    for key, val in params_dict.items():
        # bool subclasses int; without the explicit exclusion True/False would become 1.0/0.0.
        if force_float and isinstance(val, int) and not isinstance(val, bool):
            val = float(val)
        array_dict[key] = jnp.array(val) if isinstance(val, list) else val
    return array_dict


def export_interaction_params(params: InteractionParams, path: Path, filename: str | None = None,
                              idx: int | None = None) -> None:
    """
    Export interaction parameters as a JSON dictionary. Jax arrays are converted to lists.

    Args:
        params: Parameters to export; serialized through `convert_arrays_to_lists`.
        path: Folder in which the file is written.
        filename: Target file name; defaults to INTERACTION_PARAMS_FILENAME.
        idx: Optional index. When given, ``_{idx}`` is inserted before the file extension
            (e.g. ``param_grad.json`` -> ``param_grad_3.json``), following the ``name_{idx}`` pattern of
            `coord_filename`. `export_simulation_state` and `export_cost_and_grad` pass this keyword.
    """
    if filename is None:
        filename = INTERACTION_PARAMS_FILENAME
    if idx is not None:
        name = Path(filename)
        filename = name.with_name(f'{name.stem}_{idx}{name.suffix}')
    file_path = path.joinpath(filename)
    file_management.overwrite_protection(file_path)
    with open(file_path, 'w') as f:
        json.dump(convert_arrays_to_lists(params), f)


def convert_enum_to_int(params: NeighborListParams) -> dict:
    """Return the attributes of `params` as a dict, replacing every Enum member by its ``.value``."""
    return {
        name: (member.value if isinstance(member, Enum) else member)
        for name, member in vars(params).items()
    }


def export_neighbor_list_params(params: NeighborListParams, path: Path) -> None:
    """Write neighbor list parameters (Enum members stored by value) as JSON into `path`."""
    target = path.joinpath(NEIGHBOR_LIST_PARAMS_FILENAME)
    file_management.overwrite_protection(target)
    with target.open('w') as out_file:
        json.dump(convert_enum_to_int(params), out_file)


def export_cost(cost: Array, path: Path) -> None:
    """
    Export the cost function values as a text file (COST_FILENAME) inside `path`.

    A scalar (0-d) cost is written as a single value. Arrays with more than two dimensions are still
    rejected by ``np.savetxt``.
    """
    filename = path.joinpath(COST_FILENAME)
    file_management.overwrite_protection(filename)
    # np.savetxt only accepts 1-D/2-D input; a single scalar cost would otherwise raise ValueError.
    np.savetxt(filename, np.atleast_1d(np.asarray(cost)))


# Exporters that reuse `export_interaction_params` with a fixed target file name; all other
# positional/keyword arguments are forwarded unchanged by `partial`.
export_param_grad = partial(export_interaction_params, filename=PARAMS_GRAD_FILENAME)
export_param_grad_clipped = partial(export_interaction_params, filename=PARAMS_GRAD_CLIPPED_FILENAME)


def load_interaction_params(path: Path, filename: str = None, convert_arrays=True) -> dict:
    """
    Read interaction parameters from a JSON file in `path` (INTERACTION_PARAMS_FILENAME by default).

    With `convert_arrays`, the decoded dict is passed through `convert_lists_to_arrays` (lists become
    jax arrays); otherwise the raw decoded dict is returned.
    """
    source = path.joinpath(INTERACTION_PARAMS_FILENAME if filename is None else filename)
    with open(source, 'r') as in_file:
        loaded = json.load(in_file)
    return convert_lists_to_arrays(loaded) if convert_arrays else loaded

# Loaders that reuse `load_interaction_params` with a fixed source file name.
load_param_grad = partial(load_interaction_params, filename=PARAMS_GRAD_FILENAME)
load_param_grad_clipped = partial(load_interaction_params, filename=PARAMS_GRAD_CLIPPED_FILENAME)


def load_cost(path: Path) -> np.ndarray:
    """Read the cost function values stored in COST_FILENAME inside `path`."""
    return np.loadtxt(path.joinpath(COST_FILENAME))


def save_single_config(body: rigid_body.RigidBody, folder: Path, save_idx: int = None) -> None:
    """Write the centers and orientation quaternion vectors of `body` as text files in `folder`."""
    coord_path = folder.joinpath(coord_filename(save_idx)).resolve()
    np.savetxt(coord_path, body.center)
    orient_path = folder.joinpath(orient_filename(save_idx)).resolve()
    np.savetxt(orient_path, body.orientation.vec)


def load_single_config(folder: Path, save_idx: int = None) -> rigid_body.RigidBody:
    """Rebuild a RigidBody from the coordinate and orientation text files stored in `folder`."""
    def read_array(name: str):
        return jnp.asarray(np.loadtxt(folder.joinpath(name).resolve()))

    center = read_array(coord_filename(save_idx))
    quaternion_vec = read_array(orient_filename(save_idx))
    return rigid_body.RigidBody(center, rigid_body.Quaternion(quaternion_vec))


def init_config_folder_name(num: int, density: float) -> str:
    """
    Return the folder name for initial configurations of `num` particles at the given `density`.

    The density is encoded in thousandths, e.g. ``(100, 0.3) -> 'n100rho300'``. The value is rounded rather
    than truncated: ``1000 * 1.001`` evaluates slightly below 1001, and truncation would file density 1.001
    under the same folder as density 1.000. Note that folders written earlier with truncation carry the
    truncated number for such densities.
    """
    return f'n{num}rho{int(round(1000 * density))}'


def save_initial_config(body: rigid_body.RigidBody, density: float, idx: int, init_folder: Path) -> None:
    """Store `body` as initial configuration `idx` in the folder for its particle count and `density`."""
    num_particles = body.center.shape[0]
    target_folder = init_folder.joinpath(init_config_folder_name(num_particles, density))
    target_folder.mkdir(parents=True, exist_ok=True)
    save_single_config(body, target_folder, idx)


def load_initial_config(n: int, density: float, idx: int, init_folder: Path) -> rigid_body.RigidBody:
    """Read initial configuration `idx` of `n` particles at `density` from below `init_folder`."""
    folder_name = init_config_folder_name(n, density)
    return load_single_config(init_folder.joinpath(folder_name), idx)


def load_multiple_initial_configs(n: int, density: float, indices: list[int], init_folder: Path) \
        -> list[rigid_body.RigidBody]:
    """Read the initial configurations listed in `indices` (same `n` and `density`) as separate bodies."""
    config_folder = init_folder.joinpath(init_config_folder_name(n, density))
    bodies = []
    for config_idx in indices:
        bodies.append(load_single_config(config_folder, config_idx))
    return bodies


def load_multiple_initial_configs_single_object(n: int, density: float, indices: list[int], init_folder: Path,
                                                coord_rescale_factor: float = None) -> rigid_body.RigidBody:
    """
    Read several initial configurations and stack them into one RigidBody along a new leading axis.

    If `coord_rescale_factor` is given, each configuration's centers are multiplied by it after conversion
    to a jax array; orientations are left unchanged.
    """
    config_folder = init_folder.joinpath(init_config_folder_name(n, density))

    def read_array(name: str):
        return jnp.asarray(np.loadtxt(config_folder.joinpath(name).resolve()))

    centers, quaternion_vecs = [], []
    for config_idx in indices:
        center = read_array(coord_filename(config_idx))
        if coord_rescale_factor is not None:
            center = center * coord_rescale_factor
        centers.append(center)
        quaternion_vecs.append(read_array(orient_filename(config_idx)))
    stacked_centers = jnp.stack(centers, axis=0)
    stacked_quaternions = jnp.stack(quaternion_vecs, axis=0)
    return rigid_body.RigidBody(stacked_centers, rigid_body.Quaternion(stacked_quaternions))


def simulation_log_data_fields(simulation_log: SimulationLog) -> dict:
    """
    Collect the data arrays of a simulation log, ignoring its other internal attributes.

    An attribute is treated as data when its leading dimension equals
    ``pytree_transf.data_length(simulation_log, ignore_non_array_leaves=True)``. Only the entries selected
    by ``jnp.nonzero`` are kept, which drops zero entries (presumably slots that were never populated).
    Attributes whose inspection raises AttributeError or IndexError (e.g. no ``shape``, or 0-d) are skipped.
    """
    fields = {}
    for name, value in vars(simulation_log).items():
        try:
            is_data = value.shape[0] == pytree_transf.data_length(simulation_log, ignore_non_array_leaves=True)
            if is_data:
                fields[name] = value[jnp.nonzero(value)]
        except (AttributeError, IndexError):
            continue
    return fields


def export_simulation_log(simulation_log: SimulationLog,
                          folder: Path) -> None:
    """Save the data fields of `simulation_log` into a single ``.npz`` archive in `folder`."""
    log_path = folder.joinpath(SIMULATION_LOG_FILENAME)
    file_management.overwrite_protection(log_path)
    np.savez(log_path, **simulation_log_data_fields(simulation_log))


def load_simulation_log(folder: Path) -> dict:
    """
    Load simulation log data from file.

    Returns a plain dict mapping every array name stored in the ``.npz`` archive to its numpy array.
    The archive is opened in a ``with`` block, so its file handle is closed once all arrays are read.
    """
    with np.load(folder.joinpath(SIMULATION_LOG_FILENAME)) as npz_file:
        # dict() reads every member eagerly, so the data stays valid after the file is closed.
        return dict(npz_file)


def export_state_history(state_history: SimulationStateHistory, folder: Path) -> None:
    """
    Save the coordinate and orientation histories as ``.npy`` files in `folder`.

    Entries along the leading axes (presumably frames) whose coordinates are all zero over the last two
    axes are treated as never populated and are dropped from both arrays.
    """
    frame_norms = jnp.linalg.norm(state_history.coord, axis=(-2, -1))
    populated = jnp.nonzero(frame_norms)
    np.save(folder.joinpath(COORD_HISTORY_FILENAME), state_history.coord[populated])
    np.save(folder.joinpath(ORIENT_HISTORY_FILENAME), state_history.orient[populated])


def load_state_history(folder: Path) -> tuple[np.ndarray, np.ndarray]:
    """
    Load simulation state history data saved by `export_state_history`.

    Returns:
        ``(coord, orient)`` numpy arrays read from COORD_HISTORY_FILENAME and ORIENT_HISTORY_FILENAME.
    """
    coord = np.load(folder.joinpath(COORD_HISTORY_FILENAME))
    orient = np.load(folder.joinpath(ORIENT_HISTORY_FILENAME))
    return coord, orient


def export_simulation_state(body: rigid_body.RigidBody,
                            simulation_params: SimulationParams,
                            interaction_params: InteractionParams,
                            folder: Path,
                            idx: int) -> None:
    """
    Export one frame's configuration together with the parameters used to simulate it.

    Writes the simulation and interaction parameters, then the particle centers, quaternion vectors and
    weight matrices (reshaped to rows of 9 values, presumably one flattened 3x3 matrix per particle) to
    ``*_frame{idx}.dat`` text files in `folder`.
    """
    export_simulation_params(simulation_params, folder)
    export_interaction_params(interaction_params, folder, idx=idx)

    coord_path = folder.joinpath(f'coord_frame{idx}.dat').resolve()
    np.savetxt(coord_path, body.center)

    orient_path = folder.joinpath(f'orient_frame{idx}.dat').resolve()
    np.savetxt(orient_path, body.orientation.vec)

    weight_path = folder.joinpath(f'weight_matrix_frame{idx}.dat').resolve()
    weight_matrices = oriented_particle.get_weight_matrices(body.orientation, interaction_params.eigvals)
    np.savetxt(weight_path, weight_matrices.reshape(-1, 9))


def direct_visualization_export(body: rigid_body.RigidBody, eigvals: Array, export_folder: Path):
    """Write the centers of `body` and its weight matrices (rows of 9 values) as text files for visualization."""
    centers = body.center
    weight_rows = oriented_particle.get_weight_matrices(body.orientation, eigvals).reshape(-1, 9)
    np.savetxt(export_folder.joinpath('coord.dat').resolve(), centers)
    np.savetxt(export_folder.joinpath('weight_matrix.dat').resolve(), weight_rows)


def get_config_for_visualization(results_folder: Path,
                                  export_folder: Path,
                                  idx: int = -1,
                                  zero_cm: bool = True) -> None:
    """
    Export a single frame of a stored trajectory for visualization.

    Reads the interaction parameters and the state history from `results_folder`, selects frame `idx`
    (default: the last one), optionally shifts it so the mean particle position is at the origin, and
    writes ``coord.dat``, ``orient.dat`` and ``weight_matrix.dat`` (rows of 9 values) to `export_folder`.
    """
    interaction_params = load_interaction_params(results_folder)
    coord_history, orient_history = load_state_history(results_folder)
    # Select the requested frame before saving: the full histories have at least 3 dimensions, which
    # np.savetxt rejects, and the weight matrices below only describe this frame anyway.
    coord = coord_history[idx]
    orient = orient_history[idx]
    if zero_cm:
        coord = cannonicalize_cm(coord)
    weight_matrix = oriented_particle.get_weight_matrices(
        rigid_body.Quaternion(orient), interaction_params['eigvals']).reshape(-1, 9)
    np.savetxt(export_folder.joinpath('coord.dat').resolve(), coord)
    np.savetxt(export_folder.joinpath('orient.dat').resolve(), orient)
    np.savetxt(export_folder.joinpath('weight_matrix.dat').resolve(), weight_matrix)


def cannonicalize_cm(coord: Array) -> Array:
    """
    Shift coordinates so that their mean over axis -2 (presumably the particle axis) is zero.

    Leading axes are preserved, so a single configuration or a stack of them can be passed.
    """
    center_of_mass = jnp.mean(coord, axis=-2, keepdims=True)
    return coord - center_of_mass


def prepare_animation_data(results_folder: Path, interaction_params: dict,
                           get_every: int = 1, zero_cm=True) -> dict[str, np.ndarray]:
    """
    Build per-frame arrays for animating a stored trajectory.

    Loads the state history from `results_folder`, optionally removes the mean position per frame, keeps
    every `get_every`-th frame, and computes the weight matrices (rows of 9 values) for each kept frame
    plus the eigensystem of the kept orientations.

    Returns:
        Dict with keys ``'coord'``, ``'weight_matrix'`` and ``'eigensystem'``.
    """
    coords, orients = load_state_history(results_folder)
    if zero_cm:
        coords = cannonicalize_cm(coords)
    coords, orients = coords[::get_every], orients[::get_every]
    eigvals = interaction_params['eigvals']

    def frame_weight_matrices(quat_vec):
        return oriented_particle.get_weight_matrices(rigid_body.Quaternion(quat_vec), eigvals).reshape(-1, 9)

    weight_matrices = jax.vmap(frame_weight_matrices)(orients)
    eigensystem = oriented_particle.eigensystem(rigid_body.Quaternion(orients))
    return {'coord': coords, 'weight_matrix': weight_matrices, 'eigensystem': eigensystem}


def export_animation_data(anim_data: dict[str, np.ndarray], box: float, export_folder: Path) -> None:
    """Save each animation array with ``np.save`` and the box size as a one-value text file."""
    outputs = (
        ('anim_coord', 'coord'),
        ('anim_weight_matrix', 'weight_matrix'),
        ('anim_eigensystem', 'eigensystem'),
    )
    for file_stem, key in outputs:
        np.save(export_folder.joinpath(file_stem), anim_data[key])
    box_values = np.asarray(box).reshape(1,)
    np.savetxt(export_folder.joinpath(BOX_FILENAME), box_values)


def export_cost_and_grad(cost: float,
                         grad: InteractionParams | None,
                         folder: Path,
                         idx: int) -> None:
    """
    Export gradient data into a new file and append the cost function value to an existing one.

    The cost is appended as one line (four decimals) to ``cost_function.dat`` in `folder`, which is
    created if missing. If gradient data was not calculated, pass None as `grad`; otherwise it is written
    with `export_param_grad`, forwarding `idx`.
    """
    with open(folder.joinpath('cost_function.dat'), 'a') as f:
        # write() the single formatted line; writelines() on a str iterates it character by character.
        f.write(f'{cost: .4f}\n')
    if grad is not None:
        export_param_grad(grad, folder, idx=idx)


class OptimizationSaver:
    """
    Writes the results of an optimization run into a folder hierarchy.

    The base folder (created through `file_management`) holds run-wide files such as the simulation
    parameters. Each optimization iteration gets its own subfolder, and batched results are split into
    ``config_<k>`` subfolders of it. New iteration folders are started lazily: the first export of
    parameters or results after results have been written switches to a fresh iteration folder.
    """

    def __init__(self, folder: Path, simulation_params: SimulationParams,
                 overwrite_folder_with_no_results=False,
                 folder_num: int = None):
        """
        Args:
            folder: Requested base folder; the actual folder is created by `file_management.new_folder`
                (or `file_management.new_folder_with_number` when `folder_num` is given).
            simulation_params: Exported once into the base folder.
            overwrite_folder_with_no_results: Accepted for compatibility; not used by this class.
            folder_num: Optional explicit number for the base folder.
        """
        if folder_num is not None:
            self.base_folder = file_management.new_folder_with_number(folder, folder_num)
        else:
            self.base_folder = file_management.new_folder(folder)
        export_simulation_params(simulation_params, self.base_folder)
        self._export_results_happened = False
        self._export_inter_params_happened = False
        # The first iteration folder is created eagerly; later ones come from _get_iter_folder.
        self._iter_folder = file_management.new_folder(self.base_folder.joinpath('iter_0'))

    def _get_iter_folder(self, check_happened: bool) -> Path:
        """
        Return the current iteration folder, first starting a new one if `check_happened` is True.

        Starting a new folder resets both "export happened" flags.
        """
        if check_happened:
            self._iter_folder = file_management.new_folder(self.base_folder.joinpath('iter'))
            self._export_results_happened = False
            self._export_inter_params_happened = False
        return self._iter_folder

    def _get_config_folder(self, config_idx: int) -> Path:
        """Return (creating it if needed) the ``config_<config_idx>`` folder of the current iteration."""
        folder = self._iter_folder.joinpath(f'config_{config_idx}')
        folder.mkdir(exist_ok=True)
        return folder

    def export_interaction_params(self, interaction_params: InteractionParams) -> None:
        """Write `interaction_params` into the (possibly new) current iteration folder."""
        folder = self._get_iter_folder(self._export_results_happened)
        export_interaction_params(interaction_params, folder)
        self._export_inter_params_happened = True

    def export_param_updates(self, updates: InteractionParams) -> None:
        """Write parameter `updates` to ``interaction_param_updates.json`` in the current iteration folder."""
        folder = self._get_iter_folder(self._export_results_happened)
        export_interaction_params(updates, folder, filename='interaction_param_updates.json')
        self._export_inter_params_happened = True

    def export_run_params(self, run_params: dict):
        """Write `run_params` as JSON to ``run_params.json`` in the base folder (overwriting)."""
        with open(self.base_folder.joinpath('run_params.json'), 'w') as f:
            json.dump(run_params, f)

    def export_additional_simulation_data(self, data: dict):
        """Write `data` as JSON to ``aux_simulation_data.json`` in the base folder (overwriting)."""
        with open(self.base_folder.joinpath('aux_simulation_data.json'), 'w') as f:
            json.dump(data, f)

    def export_cost_function_info(self, cost_fn):
        """Write ``str(cost_fn)`` to ``cost_function_info.dat`` in the base folder (overwriting)."""
        with open(self.base_folder.joinpath('cost_function_info.dat'), 'w') as f:
            f.write(str(cost_fn))

    def export_results(self, bptt_results: BpttResults, aux: SimulationAux) -> None:
        """
        Export cost, gradient, simulation log and state history of a single run to the iteration folder.

        Raises:
            ValueError: if any of the exporters raises ValueError. The message points to
                `export_multiple_results`, which handles batched results; the original error is chained.
        """
        folder = self._get_iter_folder(self._export_results_happened)
        try:
            export_cost(bptt_results.cost, folder)
            export_param_grad(bptt_results.grad, folder)
            export_simulation_log(aux.log, folder)
            export_state_history(aux.state_history, folder)
            self._export_results_happened = True
        except ValueError as err:
            # Chain the cause so the underlying failure is not hidden behind the hint.
            raise ValueError('For exporting multiple results, use method "export_multiple_results".') from err

    def export_multiple_results(self,
                                bptt_results: BpttResults,
                                aux: SimulationAux) -> None:
        """
        Export batched results, one ``config_<k>`` subfolder per configuration.

        `bptt_results` and `aux` are split into per-configuration pieces with `pytree_transf.split_to_list`.
        """
        self._get_iter_folder(self._export_results_happened)
        bptt_results_list = pytree_transf.split_to_list(bptt_results)
        aux_list = pytree_transf.split_to_list(aux)
        for config_idx, (result, a) in enumerate(zip(bptt_results_list, aux_list)):
            folder = self._get_config_folder(config_idx)
            export_cost(result.cost, folder)
            export_param_grad(result.grad, folder)
            export_simulation_log(a.log, folder)
            export_state_history(a.state_history, folder)
        self._export_results_happened = True

    def export_clipped_gradients(self, grad_clipped: InteractionParams):
        """Write per-configuration clipped gradients into the config folders of the current iteration."""
        grad_clipped_list = pytree_transf.split_to_list(grad_clipped)
        for config_idx, grad in enumerate(grad_clipped_list):
            folder = self._get_config_folder(config_idx)
            export_param_grad_clipped(grad, folder)


class NoResultsError(Exception):
    """Raised when a requested optimization results folder does not exist."""


class OptimizationLoader:
    """
    Convenience class to load results of an optimization simulation.

    Expects run-wide files in the base folder, one ``iter_<i>`` subfolder per iteration and, for batched
    runs, ``config_<k>`` subfolders inside an iteration folder.
    """

    def __init__(self, folder: Path):
        """
        Args:
            folder: Base folder of the optimization run.

        Raises:
            NoResultsError: if `folder` does not exist.
        """
        self.base_folder = folder.resolve()
        if not self.base_folder.exists():
            raise NoResultsError(f"Results folder {self.base_folder} does not exist.")

    def get_results_folder(self, iter_idx: int, config_idx: int = None):
        """
        Return the folder of iteration `iter_idx` or, if `config_idx` is given, of that configuration in it.

        A negative `iter_idx` indexes into `all_iter_indices()` (``-1`` is the last non-empty iteration).
        """
        if iter_idx < 0:
            iter_idx = self.all_iter_indices()[iter_idx]
        iter_folder = self.base_folder.joinpath(f'iter_{iter_idx}')
        if config_idx is None:
            return iter_folder
        return iter_folder.joinpath(f'config_{config_idx}')

    def last_iter_idx(self):
        """Return the highest index from `all_iter_indices()`, or 0 if that list is empty."""
        try:
            return self.all_iter_indices()[-1]
        except IndexError:
            return 0

    def all_config_indices(self, iter_idx: int = None) -> list:
        """Return the sorted indices of the ``config_*`` folders in iteration `iter_idx` (default: last)."""
        if iter_idx is None:
            iter_idx = self.last_iter_idx()
        all_directory_nums = []
        for folder in self.get_results_folder(iter_idx).glob('config_*'):
            # no_num_return=0: presumably the value reported for a name without a numeric suffix
            _, dir_num = file_management.split_base_and_num(folder.name, sep='_', no_num_return=0)
            all_directory_nums.append(dir_num)
        return sorted(all_directory_nums)

    def num_replicas(self):
        """Return the number of configuration folders in the last iteration."""
        return len(self.all_config_indices())

    def all_iter_indices(self) -> list:
        """
        Return the sorted indices of ``iter_*`` folders that hold results.

        Folders reported empty by `file_management.recursive_dir_empty` (with top-level files ignored)
        are skipped.
        """
        all_directory_nums = []
        for folder in self.base_folder.glob('iter_*'):
            _, dir_num = file_management.split_base_and_num(folder.name, sep='_', no_num_return=0)
            if not file_management.recursive_dir_empty(folder, ignore_top_level_files=True):
                all_directory_nums.append(dir_num)
        return sorted(all_directory_nums)

    def load_simulation_params(self, iter_idx: int = None, config_idx: int = None) -> dict:
        """Load the run-wide simulation parameters; `iter_idx` and `config_idx` are accepted but unused."""
        return load_simulation_params(self.base_folder)

    def box_size_at_number_density(self):
        """Return the box size computed from the particle number and density of the simulation params."""
        simulation_params = self.load_simulation_params()
        return oriented_particle.box_size_at_number_density(simulation_params["num"],
                                                            simulation_params["density"],
                                                            spatial_dimension=3)

    def box_size_at_ellipsoid_density(self, iter_idx: int = None):
        """Return the box size computed from the ellipsoid eigenvalues of iteration `iter_idx` (default: last)."""
        simulation_params = self.load_simulation_params()
        if iter_idx is None:
            iter_idx = self.last_iter_idx()
        interaction_params = self.load_interaction_params(iter_idx)
        return oriented_particle.box_size_at_ellipsoid_density(simulation_params["num"],
                                                               simulation_params["density"],
                                                               interaction_params["eigvals"])

    def box_size(self, iter_idx: int = None):
        """
        Return the box size for iteration `iter_idx` (default: last).

        If every particle volume is within 1e-4 of 1, the number-density box size is used; otherwise
        the ellipsoid-density box size.
        """
        if iter_idx is None:
            iter_idx = self.last_iter_idx()
        interaction_params = self.load_interaction_params(iter_idx)
        particle_volume = oriented_particle.ellipsoid_volume(interaction_params["eigvals"])
        if jnp.all(jnp.isclose(particle_volume, 1., atol=1e-4)):
            return self.box_size_at_number_density()
        return self.box_size_at_ellipsoid_density(iter_idx=iter_idx)

    def load_additional_simulation_data(self) -> dict:
        """Load ``aux_simulation_data.json`` from the base folder."""
        with open(self.base_folder.joinpath('aux_simulation_data.json'), 'r') as f:
            data = json.load(f)
        return data

    def load_run_params(self) -> dict:
        """Load ``run_params.json`` from the base folder."""
        with open(self.base_folder.joinpath('run_params.json'), 'r') as f:
            run_params = json.load(f)
        return run_params

    def load_interaction_params(self, iter_idx: int, config_idx: int = None, convert_arrays=True) -> dict:
        """Load the interaction params of iteration `iter_idx`; `config_idx` is accepted but unused."""
        return load_interaction_params(self.get_results_folder(iter_idx), convert_arrays=convert_arrays)

    def load_multiple_interaction_params(self, iter_indices: list = None) -> dict:
        """Load and stack the interaction params of `iter_indices` (default: all iterations)."""
        if iter_indices is None:
            iter_indices = self.all_iter_indices()
        return pytree_transf.stack([self.load_interaction_params(iter_idx) for iter_idx in iter_indices])

    def load_gradient(self, iter_idx: int, config_idx: int = None) -> dict:
        """Load the parameter gradient of an iteration (or of one configuration in it)."""
        return load_param_grad(self.get_results_folder(iter_idx, config_idx))

    def load_clipped_gradient(self, iter_idx: int, config_idx: int = None) -> dict:
        """Load the clipped parameter gradient of an iteration (or of one configuration in it)."""
        return load_param_grad_clipped(self.get_results_folder(iter_idx, config_idx))

    def load_multiple_gradients(self, iter_idx: int, config_indices: list = None) -> list[dict]:
        """Load the gradients of `config_indices` (default: all configs of iteration `iter_idx`)."""
        if config_indices is None:
            # Use the configurations present in *this* iteration, not in the last one.
            config_indices = self.all_config_indices(iter_idx)
        return [self.load_gradient(iter_idx, config_idx) for config_idx in config_indices]

    def load_cost(self, iter_idx: int, config_idx: int = None) -> np.ndarray:
        """Load the cost values of an iteration (or of one configuration in it)."""
        return load_cost(self.get_results_folder(iter_idx, config_idx))

    def load_cost_all_config(self, iter_idx: int):
        """Return the costs of all configurations of iteration `iter_idx`, stacked along a new leading axis."""
        # Use the configurations present in *this* iteration, not in the last one.
        config_indices = self.all_config_indices(iter_idx)
        config_costs = []
        for config_idx in config_indices:
            config_costs.append(load_cost(self.get_results_folder(iter_idx, config_idx)))
        return np.stack(config_costs)

    def load_cost_all(self) -> np.ndarray:
        """Return `load_cost_all_config` for every iteration, stacked along a new leading axis."""
        iter_indices = self.all_iter_indices()
        all_costs = []
        for iter_idx in iter_indices:
            all_costs.append(self.load_cost_all_config(iter_idx))
        return np.stack(all_costs)

    def load_simulation_log(self, iter_idx: int, config_idx: int = None) -> dict:
        """Load the simulation log arrays of an iteration (or of one configuration in it)."""
        return load_simulation_log(self.get_results_folder(iter_idx, config_idx))

    def load_body(self, iter_idx: int, config_idx: int, time_idx: int) -> rigid_body.RigidBody:
        """Return frame `time_idx` of the stored trajectory of one configuration."""
        trajectory = self.load_trajectory(iter_idx=iter_idx, config_idx=config_idx)
        return trajectory[time_idx]

    def load_multiple_bodies(self, iter_idx: int, time_idx: int = -1, config_indices: list[int] = None):
        """Stack frame `time_idx` of several configurations (default: all configs of `iter_idx`)."""
        if config_indices is None:
            config_indices = self.all_config_indices(iter_idx)
        return pytree_transf.stack([self.load_body(iter_idx, config_idx, time_idx) for config_idx in config_indices])

    def load_trajectory(self, iter_idx: int, config_idx: int) -> rigid_body.RigidBody:
        """Load the stored coordinate/orientation history of one configuration as a RigidBody."""
        coord, orient = load_state_history(self.get_results_folder(iter_idx, config_idx))
        return rigid_body.RigidBody(jnp.asarray(coord), rigid_body.Quaternion(jnp.asarray(orient)))

    def load_multiple_config_trajectories(self, iter_idx: int, config_indices: list[int]) -> rigid_body.RigidBody:
        """Stack the trajectories of several configurations of one iteration."""
        return pytree_transf.stack([self.load_trajectory(iter_idx, config_idx) for config_idx in config_indices])

    def load_all_iter_trajectories(self, config_idx: int) -> rigid_body.RigidBody:
        """Stack the trajectories of one configuration across all iterations."""
        iter_indices = self.all_iter_indices()
        return pytree_transf.stack([self.load_trajectory(iter_idx, config_idx) for iter_idx in iter_indices])

    def get_config_for_visualization(self, iter_idx: int, config_idx: int, export_folder: Path,
                                     frame_idx: int = -1, zero_cm: bool = True):
        """Export frame `frame_idx` of one configuration for visualization into `export_folder`."""
        return get_config_for_visualization(self.get_results_folder(iter_idx, config_idx), export_folder,
                                            idx=frame_idx, zero_cm=zero_cm)

    def export_animation_data(self, iter_idx: int, config_idx: int, export_folder: Path,
                              get_every: int = 1, zero_cm: bool = False) -> None:
        """Prepare and save animation data (and the box size) for one configuration of one iteration."""
        interaction_params = self.load_interaction_params(iter_idx)
        anim_data = prepare_animation_data(self.get_results_folder(iter_idx, config_idx), interaction_params,
                                           get_every=get_every, zero_cm=zero_cm)
        box_size = self.box_size(iter_idx=iter_idx)
        export_animation_data(anim_data, box_size, export_folder)


@dataclass(frozen=True)
class PlotData(ABC):
    """
    Abstract class for saving and loading plot data. Children should add attributes for this data.

    Instances are stored as JSON at ``<results_path.parent>/<cluster_data_folder>/<results_path.name>.json``.
    Array attributes are written as nested lists, and every top-level list is turned back into a jax array
    on load, so subclasses should not keep non-numeric data in plain lists.
    """

    run_params: dict
    results_path: Path

    @property
    @abstractmethod
    def cluster_data_folder(self) -> str:
        """
        Subclasses must define a folder name for saving/loading data.

        `get_save_path` reads this attribute on the class itself, so subclasses should override it with a
        plain class-level string rather than a property.
        """
        pass

    @classmethod
    @abstractmethod
    def calculate_data(cls, results_folder: Path, **kwargs) -> "PlotData":
        """A method that fills the class with data."""
        pass

    @classmethod
    def get_save_path(cls, results_path) -> Path:
        """Return the JSON path for `results_path`, creating the cluster data folder if it is missing."""
        results_base = results_path.parent
        results_filename = f'{results_path.name}.json'
        save_path_base = results_base / cls.cluster_data_folder
        save_path_base.mkdir(exist_ok=True)
        return save_path_base / results_filename

    def save(self) -> None:
        """Serialize this instance to the JSON file given by `get_save_path`."""
        # A shallow copy suffices: only top-level values are replaced below, and it leaves the frozen
        # instance's own __dict__ untouched without deep-copying every array.
        self_as_dict = dict(vars(self))
        for key, value in self_as_dict.items():
            # numpy values need converting too, json.dump cannot serialize them.
            if isinstance(value, (jnp.ndarray, np.ndarray, np.generic)):
                self_as_dict[key] = np.asarray(value).tolist()
        self_as_dict['results_path'] = str(self_as_dict['results_path'])
        with open(self.get_save_path(self.results_path), 'w') as f:
            json.dump(self_as_dict, f)

    @classmethod
    def load(cls, results_path: Path) -> "PlotData":
        """
        Load a previously saved instance for `results_path`.

        Raises:
            FileNotFoundError: if nothing has been saved for `results_path` yet.
        """
        cluster_data_path = cls.get_save_path(results_path)
        if not cluster_data_path.exists():
            raise FileNotFoundError("Clustering results for the given path not yet exported.")
        with open(cluster_data_path, 'r') as f:
            cls_as_dict = json.load(f)
        for key in cls_as_dict:
            if isinstance(cls_as_dict[key], list):
                cls_as_dict[key] = jnp.asarray(cls_as_dict[key])
        cls_as_dict['results_path'] = Path(cls_as_dict['results_path'])
        return cls(**cls_as_dict)

    @classmethod
    def get_data(cls,
                 results_folder: Path,
                 recalculate: bool = False,
                 **calculate_data_kwargs) -> "PlotData":
        """
        Return the plot data for `results_folder`, loading it from disk when possible.

        The data is (re)calculated with `calculate_data` and saved when `recalculate` is True or when no
        saved data exists yet.
        """
        if recalculate:
            print(f'Recalculating results for {results_folder}...')
        else:
            try:
                return cls.load(results_folder)
            except FileNotFoundError:
                print(f'Results for folder {results_folder} not yet exported, calculating...')
        csd = cls.calculate_data(results_folder, **calculate_data_kwargs)
        csd.save()
        return csd


def figure_file_name(base_file_name: str, results_folder: Path, *, iter_idx: int = None, config_idx: int = None,
                     figure_folder: Path = Path("/home/andraz/CurvatureAssemblyFigures")):
    """
    Build the path of a figure file for a results folder, creating the needed folders on the way.

    Layout below `figure_folder` (``<run>`` is the results folder's base name and number joined by ``_``):
        no indices:   <run>/<base><suffix>
        iter only:    <run>/<base>/<base>_iter<i><suffix>
        config only:  <run>/<base>/<base>_config<k><suffix>
        both:         <run>/<base>/iter<i>/<base>_iter<i>_config<k><suffix>
    NOTE(review): the default `figure_folder` is a user-specific absolute path.
    """
    figure_folder.mkdir(exist_ok=True)
    run_name, run_number = file_management.split_base_and_num(results_folder.name, sep='_', no_num_return='')
    requested = Path(base_file_name)
    stem, extension = requested.stem, requested.suffix
    run_folder = figure_folder.joinpath(f'{run_name}_{run_number}')
    run_folder.mkdir(exist_ok=True)
    if iter_idx is None and config_idx is None:
        return run_folder.joinpath(f'{stem}{extension}')

    plot_folder = run_folder.joinpath(f'{stem}')
    plot_folder.mkdir(exist_ok=True)
    if config_idx is None:
        return plot_folder.joinpath(f'{stem}_iter{iter_idx}{extension}')
    if iter_idx is None:
        return plot_folder.joinpath(f'{stem}_config{config_idx}{extension}')

    iter_folder = plot_folder.joinpath(f'iter{iter_idx}')
    iter_folder.mkdir(exist_ok=True)
    return iter_folder.joinpath(f'{stem}_iter{iter_idx}_config{config_idx}{extension}')