123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- from __future__ import annotations
- import time
- import sys
- import json
- from pathlib import Path
def _load_json_arg(position: int) -> dict:
    """Load and return the JSON document whose path is command-line argument ``position``.

    The file is decoded as UTF-8 (the encoding RFC 8259 mandates for JSON) rather than
    with the platform-dependent locale encoding that ``open`` uses by default.
    """
    with open(Path(sys.argv[position]), encoding="utf-8") as json_file:
        return json.load(json_file)


# Expected invocation: <script> CONFIG_JSON RUN_PARAMS_JSON INTERACTION_PARAMS_JSON
# Fail with a usage message instead of a bare IndexError when an argument is missing.
if len(sys.argv) < 4:
    raise SystemExit('usage: <script> CONFIG_JSON RUN_PARAMS_JSON INTERACTION_PARAMS_JSON')
config_data = _load_json_arg(1)
run_params = _load_json_arg(2)
int_params = _load_json_arg(3)
- from jax import config
- import os
- if config_data['device'] == 'cpu':
- os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={run_params["num_simulations"]}'
- elif config_data['device'] == 'gpu':
- # os.environ["CUDA_VISIBLE_DEVICES"] = "1"
- # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
- # os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".45"
- pass
- else:
- raise ValueError('Unknown device type, only "cpu" and "gpu" are supported.')
- config.update("jax_enable_x64", True) # needs to be set before importing jax
- # config.update("jax_debug_nans", True)
- import jax
- from jax_md import space, quantity, rigid_body, dataclasses
- from curvature_assembly import oriented_particle, ellipsoid_contact, io_functions, simulation, data_protocols
- from curvature_assembly import cost_functions, energy, parallelization, fit, pytree_transf, patchy_interaction
- import optax
- from functools import partial
- import jax.numpy as jnp
def _parse_binary_flag(name: str) -> bool:
    """Return the 0/1 configuration switch ``config_data[name]`` as a bool.

    Raises:
        ValueError: if the value is neither 0 nor 1.
    """
    value = config_data[name]
    if value == 1:
        return True
    if value == 0:
        return False
    raise ValueError(f'"{name}" configuration parameter must be 0 or 1.')


# whether parameters, results and metadata are exported to disk
SAVE = _parse_binary_flag('save')
# autodif_opt == 0 selects forward-only simulation (only_forward_calculation=True)
only_forward_calculation = not _parse_binary_flag('autodif_opt')
if only_forward_calculation:
    # overwrite number of iterations if we only do forward simulation
    run_params['num_iterations'] = 1
- def get_interaction_params(int_params_dict: dict) -> data_protocols.InteractionParams:
- match config_data["type"]:
- case "ferro":
- interaction_params = energy.FerroWcaParams(**io_functions.convert_lists_to_arrays(int_params_dict))
- case "patchy":
- interaction_params = energy.FerroWcaParams(**io_functions.convert_lists_to_arrays(int_params_dict))
- interaction_params = dataclasses.replace(interaction_params, q0=0.0)
- case "quad":
- interaction_params = energy.QuadWcaParams(**io_functions.convert_lists_to_arrays(int_params_dict))
- case _:
- raise ValueError('Unknown simulation type, currently only "ferro", "quad", and "patchy" are supported.')
- return interaction_params
def quad_attraction_invariant(params, d0_start, q0_start):
    """Rescale ``params.d0`` and ``params.q0`` so that d0 + q0**2 equals d0_start + q0_start**2.

    With ratio r = (d0_start + q0_start**2) / (d0 + q0**2), d0 is multiplied by r and
    q0 by sqrt(r), so d0 + q0**2 is multiplied by r and lands on the starting value.
    All other attributes are copied unchanged. The input object is not mutated; a new
    instance of ``type(params)`` is returned.

    NOTE(review): if d0 + q0**2 is zero the ratio divides by zero — confirm this cannot occur.
    """
    fields = dict(vars(params))  # copy, so `params` itself is left untouched
    scale = (d0_start + q0_start ** 2) / (fields["d0"] + fields["q0"] ** 2)
    fields["d0"] = fields["d0"] * scale
    fields["q0"] = fields["q0"] * jnp.sqrt(scale)
    return type(params)(**fields)
def params_to_optimize_to_bounds_kwargs():
    """Return the keyword arguments for ``fit.bounds_params``.

    The result always contains ``'lm_magnitudes'``, followed by every name listed in
    ``run_params["optimize_params"]``, each mapped to ``None``. Older run-parameter
    files without ``"optimize_params"`` yield only ``{'lm_magnitudes': None}``.
    """
    requested = run_params.get("optimize_params", ())
    return {'lm_magnitudes': None, **dict.fromkeys(requested)}
- def main():
- results_folder = Path(config_data['results_base_folder'])
- init_folder = Path(config_data['init_folder'])
- simulation_params = simulation.NVTSimulationParams(num=run_params["num_particles"],
- density=run_params["density"],
- simulation_steps=run_params["simulation_steps"],
- dt=run_params["dt"],
- kT=run_params["kT"],
- config_every=run_params["config_every"],
- bptt_truncation=run_params["bptt_truncation"]
- )
- # initialize interaction params and related optimization bounds
- interaction_params = get_interaction_params(int_params)
- lm_list = patchy_interaction.generate_lm_list(6, only_even_l=False, only_non_neg_m=False)
- interaction_params = interaction_params.init_lm_magnitudes(patchy_interaction.init_lm_coefs(lm_list,
- [(2,0)]))
- # interaction_params = interaction_params.init_unit_volume_particle()
- lower_bounds, upper_bounds = fit.bounds_params(interaction_params, **params_to_optimize_to_bounds_kwargs())
- # set displacement and shift functions
- box_size_old = quantity.box_size_at_number_density(simulation_params.num,
- simulation_params.density,
- spatial_dimension=3)
- box_size = oriented_particle.box_size_at_ellipsoid_density(simulation_params.num,
- simulation_params.density,
- interaction_params.eigvals)
- displacement, shift = space.periodic(box_size)
- # load initial config(s)
- body = io_functions.load_multiple_initial_configs_single_object(simulation_params.num,
- simulation_params.density,
- [i for i in range(run_params["num_simulations"])],
- init_folder,
- coord_rescale_factor=box_size / box_size_old)
- # define energy function
- contact_fn = ellipsoid_contact.bp_contact_function
- # contact_fn = ellipsoid_contact.pw_contact_function
- match config_data["type"]:
- case "ferro" | "patchy":
- energy_fn = energy.ferro_wca_sphere_pair(displacement=displacement, lm=lm_list)
- case "quad":
- energy_fn = energy.quadrupolar_wca_sphere_pair(displacement=displacement, lm=lm_list)
- energy_fn = jax.checkpoint(energy_fn)
- # select cost function
- cost_fn = cost_functions.CurvedClustersResidualsCost(displacement,
- box_size,
- contact_fn,
- target_radius=run_params["target_radius"],
- residuals_avg_type=run_params["residuals_average"],
- residuals_cost_factor=run_params["residuals_cost_factor"])
- # initialize optimization saver
- io_manager = io_functions.OptimizationSaver(results_folder.joinpath(Path(config_data['optimization_folder_name'])),
- simulation_params)
- # save metadata
- if SAVE:
- io_manager.export_cost_function_info(cost_fn)
- # a lot of run parameters are already saved as simulation_params, but it is easier to just save everything
- io_manager.export_run_params(run_params)
- io_manager.export_additional_simulation_data({'thermostat': config_data['thermostat'],
- 'lm_list': lm_list})
- ###########################################################################
- # SIMULATION FUNCTION CONSTRUCTION
- ###########################################################################
- # prepare functions and auxiliary data container for the simulation
- if config_data["thermostat"] == "langevin":
- gamma = run_params["langevin_gamma"]
- init_fn, step_fn, aux = simulation.setup_langevin(energy_fn, shift, simulation_params,
- gamma=rigid_body.RigidBody(gamma, gamma))
- elif config_data["thermostat"] == "nose-hoover":
- init_fn, step_fn, aux = simulation.setup_nose_hoover(energy_fn, shift, simulation_params,
- tau=simulation_params.dt * run_params["nose-hoover-tau"])
- else:
- raise NotImplementedError('Thermostat in config file should be "langevin" or "nose-hoover".')
- # we must add additional dimension to aux, compatible with body leading dimension, used for parallelization
- # and consistency over multiple evaluations of parallelized function
- aux = pytree_transf.repeat_fields(aux, pytree_transf.data_length(body))
- # create bptt simulation function
- bptt_simulation = simulation.truncated_bptt_nvt_simulation(step_fn,
- energy_fn,
- cost_fn,
- simulation_params,
- only_forward_calculation=only_forward_calculation)
- bptt_simulation = parallelization.pmap_segment_dispatch(jax.jit(bptt_simulation), map_argnums=(1, 2))
- # create optimization configuration
- optimizer = optax.adam(learning_rate=run_params["learning_rate"])
- opt_state = optimizer.init(interaction_params)
- param_rescalings = [partial(fit.normalize_param, param_name='lm_magnitudes'),]
- if config_data["type"] == "quad":
- param_rescalings.append(partial(quad_attraction_invariant,
- d0_start=interaction_params.d0, q0_start=interaction_params.q0))
- fit_step = fit.fit_bptt(bptt_simulation,
- optimizer.update,
- clipping=50,
- grad_time_weights=run_params["grad_time_weights"],
- param_rescalings=param_rescalings,
- lower_bounds=lower_bounds,
- upper_bounds=upper_bounds)
- ##############################################################################
- # RUN SIMULATION
- ##############################################################################
- init_keys = jax.random.split(jax.random.PRNGKey(0), pytree_transf.data_length(aux, axis=0))
- md_state = pytree_transf.map_over_leading_leaf_dimension(partial(init_fn,
- mass=simulation.ellipsoid_unit_mass(interaction_params.eigvals),
- **vars(interaction_params)),
- init_keys, body)
- # run optimization
- for i in range(run_params["num_iterations"]):
- if SAVE:
- io_manager.export_interaction_params(interaction_params)
- print(f'Params for iter {i}: ', interaction_params)
- t0 = time.perf_counter()
- interaction_params, opt_state, bptt_results, aux, grad_clipped = jax.block_until_ready(fit_step(interaction_params,
- opt_state,
- md_state,
- aux))
- t1 = time.perf_counter()
- print(f'Simulation time: {t1 - t0}')
- print(f'End cost: {bptt_results.cost[:, -1]}')
- if SAVE:
- io_manager.export_multiple_results(bptt_results, aux)
- io_manager.export_clipped_gradients(grad_clipped)
# main() runs only when this file is executed as a script. The module-level
# configuration loading and JAX setup above still run on import.
if __name__ == '__main__':
    main()
|