main.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. from __future__ import annotations
  2. import time
  3. import sys
  4. import json
  5. from pathlib import Path
  def _load_json(path: Path):
      """Return the parsed content of the JSON file at *path*.

      The file is decoded as UTF-8 explicitly, so the result does not depend on the
      platform's default locale encoding.
      """
      with open(path, encoding='utf-8') as json_file:
          return json.load(json_file)


  # The script takes three positional JSON files:
  #   argv[1] -> config_data, argv[2] -> run_params, argv[3] -> int_params
  if len(sys.argv) < 4:
      sys.exit(f'usage: {sys.argv[0]} CONFIG_JSON RUN_PARAMS_JSON INTERACTION_PARAMS_JSON')
  config_data = _load_json(Path(sys.argv[1]))
  run_params = _load_json(Path(sys.argv[2]))
  int_params = _load_json(Path(sys.argv[3]))
  12. from jax import config
  13. import os
  # Configure the XLA backend according to the "device" entry of the config file.
  if config_data['device'] == 'cpu':
      # Ask XLA for one host (CPU) device per simulation. XLA_FLAGS is read when the
      # XLA backend is initialised, so this has to happen before any JAX computation.
      device_flag = f'--xla_force_host_platform_device_count={run_params["num_simulations"]}'
      # Append rather than overwrite so XLA flags already exported by the user are
      # kept; the later occurrence of a repeated flag is expected to take effect.
      existing_flags = os.environ.get("XLA_FLAGS", "")
      os.environ["XLA_FLAGS"] = f"{existing_flags} {device_flag}".strip()
  elif config_data['device'] == 'gpu':
      # Optional GPU settings, intentionally left disabled:
      # os.environ["CUDA_VISIBLE_DEVICES"] = "1"
      # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
      # os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".45"
      pass
  else:
      raise ValueError('Unknown device type, only "cpu" and "gpu" are supported.')
  # Enable 64-bit floats. jax itself is already imported by `from jax import config`
  # above; what matters is that this runs before any JAX arrays are created.
  config.update("jax_enable_x64", True)
  # config.update("jax_debug_nans", True)
  25. import jax
  26. from jax_md import space, quantity, rigid_body, dataclasses
  27. from curvature_assembly import oriented_particle, ellipsoid_contact, io_functions, simulation, data_protocols
  28. from curvature_assembly import cost_functions, energy, parallelization, fit, pytree_transf, patchy_interaction
  29. import optax
  30. from functools import partial
  31. import jax.numpy as jnp
  # Translate the 0/1 switches from the config file into Python flags.
  # SAVE gates every export performed through the OptimizationSaver in main().
  if config_data['save'] not in (0, 1):
      raise ValueError('"save" configuration parameter must be 0 or 1.')
  SAVE = config_data['save'] == 1

  # autodif_opt == 0 requests a forward-only calculation (forwarded to the bptt
  # simulation builder as only_forward_calculation); a single iteration then suffices,
  # so the requested number of iterations is overwritten with 1.
  if config_data['autodif_opt'] not in (0, 1):
      raise ValueError('"autodif_opt" configuration parameter must be 0 or 1.')
  only_forward_calculation = config_data['autodif_opt'] == 0
  if only_forward_calculation:
      run_params['num_iterations'] = 1
  def get_interaction_params(int_params_dict: dict) -> data_protocols.InteractionParams:
      """Build the interaction-parameter object for the configured simulation type.

      The type is read from the module-level ``config_data["type"]``:
        * ``"ferro"``  -> ``energy.FerroWcaParams``
        * ``"patchy"`` -> ``energy.FerroWcaParams`` with ``q0`` replaced by 0.0
        * ``"quad"``   -> ``energy.QuadWcaParams``
      List values in *int_params_dict* are converted to arrays by
      ``io_functions.convert_lists_to_arrays`` before construction.

      Raises:
          ValueError: if the configured type is none of the above.
      """
      sim_type = config_data["type"]
      if sim_type == "quad":
          params_cls = energy.QuadWcaParams
      elif sim_type in ("ferro", "patchy"):
          params_cls = energy.FerroWcaParams
      else:
          raise ValueError('Unknown simulation type, currently only "ferro", "quad", and "patchy" are supported.')

      params = params_cls(**io_functions.convert_lists_to_arrays(int_params_dict))
      if sim_type == "patchy":
          # Same parameter container as "ferro", but with q0 forced to zero.
          params = dataclasses.replace(params, q0=0.0)
      return params
  def quad_attraction_invariant(params, d0_start, q0_start):
      """Rescale ``d0`` and ``q0`` so that ``d0 + q0**2`` keeps its starting value.

      With ``r = (d0_start + q0_start**2) / (d0 + q0**2)``, ``d0`` is multiplied by
      ``r`` and ``q0`` by ``sqrt(r)``, which maps ``d0 + q0**2`` back onto the starting
      value. A new instance of ``type(params)`` is built from ``vars(params)`` with the
      two fields replaced; *params* itself is left untouched.
      """
      fields = vars(params)
      target = d0_start + q0_start ** 2
      scale = target / (fields["d0"] + fields["q0"] ** 2)
      updated = {
          **fields,
          "d0": fields["d0"] * scale,
          "q0": fields["q0"] * jnp.sqrt(scale),
      }
      return type(params)(**updated)
  def params_to_optimize_to_bounds_kwargs():
      """Return the keyword arguments that main() passes to ``fit.bounds_params``.

      ``'lm_magnitudes'`` is always present; each name listed in the module-level
      ``run_params["optimize_params"]`` is added after it. Run-parameter files that
      predate the ``optimize_params`` entry are accepted and add nothing. Every value
      in the returned dict is ``None``.
      """
      extra_names = run_params.get("optimize_params", ())
      return {'lm_magnitudes': None, **dict.fromkeys(extra_names, None)}
  74. def main():
  75. results_folder = Path(config_data['results_base_folder'])
  76. init_folder = Path(config_data['init_folder'])
  77. simulation_params = simulation.NVTSimulationParams(num=run_params["num_particles"],
  78. density=run_params["density"],
  79. simulation_steps=run_params["simulation_steps"],
  80. dt=run_params["dt"],
  81. kT=run_params["kT"],
  82. config_every=run_params["config_every"],
  83. bptt_truncation=run_params["bptt_truncation"]
  84. )
  85. # initialize interaction params and related optimization bounds
  86. interaction_params = get_interaction_params(int_params)
  87. lm_list = patchy_interaction.generate_lm_list(6, only_even_l=False, only_non_neg_m=False)
  88. interaction_params = interaction_params.init_lm_magnitudes(patchy_interaction.init_lm_coefs(lm_list,
  89. [(2,0)]))
  90. # interaction_params = interaction_params.init_unit_volume_particle()
  91. lower_bounds, upper_bounds = fit.bounds_params(interaction_params, **params_to_optimize_to_bounds_kwargs())
  92. # set displacement and shift functions
  93. box_size_old = quantity.box_size_at_number_density(simulation_params.num,
  94. simulation_params.density,
  95. spatial_dimension=3)
  96. box_size = oriented_particle.box_size_at_ellipsoid_density(simulation_params.num,
  97. simulation_params.density,
  98. interaction_params.eigvals)
  99. displacement, shift = space.periodic(box_size)
  100. # load initial config(s)
  101. body = io_functions.load_multiple_initial_configs_single_object(simulation_params.num,
  102. simulation_params.density,
  103. [i for i in range(run_params["num_simulations"])],
  104. init_folder,
  105. coord_rescale_factor=box_size / box_size_old)
  106. # define energy function
  107. contact_fn = ellipsoid_contact.bp_contact_function
  108. # contact_fn = ellipsoid_contact.pw_contact_function
  109. match config_data["type"]:
  110. case "ferro" | "patchy":
  111. energy_fn = energy.ferro_wca_sphere_pair(displacement=displacement, lm=lm_list)
  112. case "quad":
  113. energy_fn = energy.quadrupolar_wca_sphere_pair(displacement=displacement, lm=lm_list)
  114. energy_fn = jax.checkpoint(energy_fn)
  115. # select cost function
  116. cost_fn = cost_functions.CurvedClustersResidualsCost(displacement,
  117. box_size,
  118. contact_fn,
  119. target_radius=run_params["target_radius"],
  120. residuals_avg_type=run_params["residuals_average"],
  121. residuals_cost_factor=run_params["residuals_cost_factor"])
  122. # initialize optimization saver
  123. io_manager = io_functions.OptimizationSaver(results_folder.joinpath(Path(config_data['optimization_folder_name'])),
  124. simulation_params)
  125. # save metadata
  126. if SAVE:
  127. io_manager.export_cost_function_info(cost_fn)
  128. # a lot of run parameters are already saved as simulation_params, but it is easier to just save everything
  129. io_manager.export_run_params(run_params)
  130. io_manager.export_additional_simulation_data({'thermostat': config_data['thermostat'],
  131. 'lm_list': lm_list})
  132. ###########################################################################
  133. # SIMULATION FUNCTION CONSTRUCTION
  134. ###########################################################################
  135. # prepare functions and auxiliary data container for the simulation
  136. if config_data["thermostat"] == "langevin":
  137. gamma = run_params["langevin_gamma"]
  138. init_fn, step_fn, aux = simulation.setup_langevin(energy_fn, shift, simulation_params,
  139. gamma=rigid_body.RigidBody(gamma, gamma))
  140. elif config_data["thermostat"] == "nose-hoover":
  141. init_fn, step_fn, aux = simulation.setup_nose_hoover(energy_fn, shift, simulation_params,
  142. tau=simulation_params.dt * run_params["nose-hoover-tau"])
  143. else:
  144. raise NotImplementedError('Thermostat in config file should be "langevin" or "nose-hoover".')
  145. # we must add additional dimension to aux, compatible with body leading dimension, used for parallelization
  146. # and consistency over multiple evaluations of parallelized function
  147. aux = pytree_transf.repeat_fields(aux, pytree_transf.data_length(body))
  148. # create bptt simulation function
  149. bptt_simulation = simulation.truncated_bptt_nvt_simulation(step_fn,
  150. energy_fn,
  151. cost_fn,
  152. simulation_params,
  153. only_forward_calculation=only_forward_calculation)
  154. bptt_simulation = parallelization.pmap_segment_dispatch(jax.jit(bptt_simulation), map_argnums=(1, 2))
  155. # create optimization configuration
  156. optimizer = optax.adam(learning_rate=run_params["learning_rate"])
  157. opt_state = optimizer.init(interaction_params)
  158. param_rescalings = [partial(fit.normalize_param, param_name='lm_magnitudes'),]
  159. if config_data["type"] == "quad":
  160. param_rescalings.append(partial(quad_attraction_invariant,
  161. d0_start=interaction_params.d0, q0_start=interaction_params.q0))
  162. fit_step = fit.fit_bptt(bptt_simulation,
  163. optimizer.update,
  164. clipping=50,
  165. grad_time_weights=run_params["grad_time_weights"],
  166. param_rescalings=param_rescalings,
  167. lower_bounds=lower_bounds,
  168. upper_bounds=upper_bounds)
  169. ##############################################################################
  170. # RUN SIMULATION
  171. ##############################################################################
  172. init_keys = jax.random.split(jax.random.PRNGKey(0), pytree_transf.data_length(aux, axis=0))
  173. md_state = pytree_transf.map_over_leading_leaf_dimension(partial(init_fn,
  174. mass=simulation.ellipsoid_unit_mass(interaction_params.eigvals),
  175. **vars(interaction_params)),
  176. init_keys, body)
  177. # run optimization
  178. for i in range(run_params["num_iterations"]):
  179. if SAVE:
  180. io_manager.export_interaction_params(interaction_params)
  181. print(f'Params for iter {i}: ', interaction_params)
  182. t0 = time.perf_counter()
  183. interaction_params, opt_state, bptt_results, aux, grad_clipped = jax.block_until_ready(fit_step(interaction_params,
  184. opt_state,
  185. md_state,
  186. aux))
  187. t1 = time.perf_counter()
  188. print(f'Simulation time: {t1 - t0}')
  189. print(f'End cost: {bptt_results.cost[:, -1]}')
  190. if SAVE:
  191. io_manager.export_multiple_results(bptt_results, aux)
  192. io_manager.export_clipped_gradients(grad_clipped)
  if __name__ == "__main__":
      # Start the optimization only when run as a script, not when imported.
      main()