from __future__ import annotations import time import sys import json from pathlib import Path with open(Path(sys.argv[1])) as config_file: config_data = json.load(config_file) with open(Path(sys.argv[2])) as run_data_file: run_params = json.load(run_data_file) with open(Path(sys.argv[3])) as int_param_file: int_params = json.load(int_param_file) from jax import config import os if config_data['device'] == 'cpu': os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={run_params["num_simulations"]}' elif config_data['device'] == 'gpu': # os.environ["CUDA_VISIBLE_DEVICES"] = "1" # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" # os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".45" pass else: raise ValueError('Unknown device type, only "cpu" and "gpu" are supported.') config.update("jax_enable_x64", True) # needs to be set before importing jax # config.update("jax_debug_nans", True) import jax from jax_md import space, quantity, rigid_body, dataclasses from curvature_assembly import oriented_particle, ellipsoid_contact, io_functions, simulation, data_protocols from curvature_assembly import cost_functions, energy, parallelization, fit, pytree_transf, patchy_interaction import optax from functools import partial import jax.numpy as jnp if config_data['save'] == 1: SAVE = True elif config_data['save'] == 0: SAVE = False else: raise ValueError('save" configuration parameter must be 0 or 1.') if config_data['autodif_opt'] == 1: only_forward_calculation = False elif config_data['autodif_opt'] == 0: only_forward_calculation = True run_params['num_iterations'] = 1 # overwrite number of iterations if we only do forward simulation else: raise ValueError('"autodif_opt" configuration parameter must be 0 or 1.') def get_interaction_params(int_params_dict: dict) -> data_protocols.InteractionParams: match config_data["type"]: case "ferro": interaction_params = energy.FerroWcaParams(**io_functions.convert_lists_to_arrays(int_params_dict)) case "patchy": interaction_params = energy.FerroWcaParams(**io_functions.convert_lists_to_arrays(int_params_dict)) interaction_params = dataclasses.replace(interaction_params, q0=0.0) case "quad": interaction_params = energy.QuadWcaParams(**io_functions.convert_lists_to_arrays(int_params_dict)) case _: raise ValueError('Unknown simulation type, currently only "ferro", "quad", and "patchy" are supported.') return interaction_params def quad_attraction_invariant(params, d0_start, q0_start): params_dict = vars(params) new_dict = params_dict.copy() invariant = d0_start + q0_start ** 2 current = params_dict["d0"] + params_dict["q0"] ** 2 rescaling = invariant / current new_dict["d0"] = params_dict["d0"] * rescaling new_dict["q0"] = params_dict["q0"] * jnp.sqrt(rescaling) return type(params)(**new_dict) def params_to_optimize_to_bounds_kwargs(): bounds_kwargs = {'lm_magnitudes': None} try: # backwards compatibility if optimize_params is not given for param in run_params["optimize_params"]: bounds_kwargs[param] = None except KeyError: pass return bounds_kwargs def main(): results_folder = Path(config_data['results_base_folder']) init_folder = Path(config_data['init_folder']) simulation_params = simulation.NVTSimulationParams(num=run_params["num_particles"], density=run_params["density"], simulation_steps=run_params["simulation_steps"], dt=run_params["dt"], kT=run_params["kT"], config_every=run_params["config_every"], bptt_truncation=run_params["bptt_truncation"] ) # initialize interaction params and related optimization bounds interaction_params = get_interaction_params(int_params) lm_list = patchy_interaction.generate_lm_list(6, only_even_l=False, only_non_neg_m=False) interaction_params = interaction_params.init_lm_magnitudes(patchy_interaction.init_lm_coefs(lm_list, [(2,0)])) # interaction_params = interaction_params.init_unit_volume_particle() lower_bounds, upper_bounds = fit.bounds_params(interaction_params, **params_to_optimize_to_bounds_kwargs()) # set displacement and shift functions box_size_old = quantity.box_size_at_number_density(simulation_params.num, simulation_params.density, spatial_dimension=3) box_size = oriented_particle.box_size_at_ellipsoid_density(simulation_params.num, simulation_params.density, interaction_params.eigvals) displacement, shift = space.periodic(box_size) # load initial config(s) body = io_functions.load_multiple_initial_configs_single_object(simulation_params.num, simulation_params.density, [i for i in range(run_params["num_simulations"])], init_folder, coord_rescale_factor=box_size / box_size_old) # define energy function contact_fn = ellipsoid_contact.bp_contact_function # contact_fn = ellipsoid_contact.pw_contact_function match config_data["type"]: case "ferro" | "patchy": energy_fn = energy.ferro_wca_sphere_pair(displacement=displacement, lm=lm_list) case "quad": energy_fn = energy.quadrupolar_wca_sphere_pair(displacement=displacement, lm=lm_list) energy_fn = jax.checkpoint(energy_fn) # select cost function cost_fn = cost_functions.CurvedClustersResidualsCost(displacement, box_size, contact_fn, target_radius=run_params["target_radius"], residuals_avg_type=run_params["residuals_average"], residuals_cost_factor=run_params["residuals_cost_factor"]) # initialize optimization saver io_manager = io_functions.OptimizationSaver(results_folder.joinpath(Path(config_data['optimization_folder_name'])), simulation_params) # save metadata if SAVE: io_manager.export_cost_function_info(cost_fn) # a lot of run parameters are already saved as simulation_params, but it is easier to just save everything io_manager.export_run_params(run_params) io_manager.export_additional_simulation_data({'thermostat': config_data['thermostat'], 'lm_list': lm_list}) ########################################################################### # SIMULATION FUNCTION CONSTRUCTION ########################################################################### # prepare functions and auxiliary data container for the simulation if config_data["thermostat"] == "langevin": gamma = run_params["langevin_gamma"] init_fn, step_fn, aux = simulation.setup_langevin(energy_fn, shift, simulation_params, gamma=rigid_body.RigidBody(gamma, gamma)) elif config_data["thermostat"] == "nose-hoover": init_fn, step_fn, aux = simulation.setup_nose_hoover(energy_fn, shift, simulation_params, tau=simulation_params.dt * run_params["nose-hoover-tau"]) else: raise NotImplementedError('Thermostat in config file should be "langevin" or "nose-hoover".') # we must add additional dimension to aux, compatible with body leading dimension, used for parallelization # and consistency over multiple evaluations of parallelized function aux = pytree_transf.repeat_fields(aux, pytree_transf.data_length(body)) # create bptt simulation function bptt_simulation = simulation.truncated_bptt_nvt_simulation(step_fn, energy_fn, cost_fn, simulation_params, only_forward_calculation=only_forward_calculation) bptt_simulation = parallelization.pmap_segment_dispatch(jax.jit(bptt_simulation), map_argnums=(1, 2)) # create optimization configuration optimizer = optax.adam(learning_rate=run_params["learning_rate"]) opt_state = optimizer.init(interaction_params) param_rescalings = [partial(fit.normalize_param, param_name='lm_magnitudes'),] if config_data["type"] == "quad": param_rescalings.append(partial(quad_attraction_invariant, d0_start=interaction_params.d0, q0_start=interaction_params.q0)) fit_step = fit.fit_bptt(bptt_simulation, optimizer.update, clipping=50, grad_time_weights=run_params["grad_time_weights"], param_rescalings=param_rescalings, lower_bounds=lower_bounds, upper_bounds=upper_bounds) ############################################################################## # RUN SIMULATION ############################################################################## init_keys = jax.random.split(jax.random.PRNGKey(0), pytree_transf.data_length(aux, axis=0)) md_state = pytree_transf.map_over_leading_leaf_dimension(partial(init_fn, mass=simulation.ellipsoid_unit_mass(interaction_params.eigvals), **vars(interaction_params)), init_keys, body) # run optimization for i in range(run_params["num_iterations"]): if SAVE: io_manager.export_interaction_params(interaction_params) print(f'Params for iter {i}: ', interaction_params) t0 = time.perf_counter() interaction_params, opt_state, bptt_results, aux, grad_clipped = jax.block_until_ready(fit_step(interaction_params, opt_state, md_state, aux)) t1 = time.perf_counter() print(f'Simulation time: {t1 - t0}') print(f'End cost: {bptt_results.cost[:, -1]}') if SAVE: io_manager.export_multiple_results(bptt_results, aux) io_manager.export_clipped_gradients(grad_clipped) if __name__ == '__main__': main()