cost_functions.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. from typing import Callable, Literal, Protocol, Any
  2. import jax
  3. import jax.numpy as jnp
  4. from curvature_assembly import oriented_particle, clustering, smap, surface_fit, surface_fit_general
  5. from jax_md import rigid_body, space
  6. from functools import partial
  7. Array = jnp.ndarray
  8. class CostFn(Protocol):
  9. def __call__(self, body: rigid_body.RigidBody, **interaction_params: dict) -> Array:
  10. ...
  11. def __str__(self):
  12. ...
  13. @jax.jit
  14. def jit_flat_plane_cost_function(body: rigid_body.RigidBody, **interaction_params) -> float:
  15. n_part = body.center.shape[0]
  16. cm = jnp.mean(body.center, axis=0)
  17. matrix = 1 / n_part * jnp.sum(jnp.einsum('ni, nj -> nij', body.center - cm, body.center - cm), axis=0)
  18. values, vecs = jnp.linalg.eigh(matrix)
  19. return values[0]
  20. def single_cluster_cost(mask, body):
  21. n_part = jnp.sum(mask)
  22. cluster_particles = body.center * mask[:, None]
  23. centered = (cluster_particles - jnp.sum(cluster_particles, axis=0) / (n_part + 1e-8)) * mask[:, None]
  24. matrix = 1 / (n_part + 1e-8) * centered.T @ centered
  25. values, vecs = jnp.linalg.eigh(matrix)
  26. return values[0]
  27. class FlatPlaneClustersCost:
  28. def __init__(self, displacement, contact_fn, num_clusters_penalty: float):
  29. self.num_clusters_penalty = num_clusters_penalty
  30. self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
  31. def __str__(self):
  32. return f"Flat plane clusters cost, penalty={self.num_clusters_penalty}"
  33. def __call__(self, body: rigid_body.RigidBody, **interaction_params: dict):
  34. neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
  35. clusters = clustering.clustering(neighbors)
  36. cluster_masks = clustering.get_all_cluster_masks(clusters) # WARNING: using cluster mask may lead to AD problems
  37. cluster_costs = jax.vmap(partial(single_cluster_cost, body=body))(cluster_masks)
  38. return jnp.sum(cluster_costs) + self.num_clusters_penalty * clusters.n_clusters
  39. class DistanceCost:
  40. def __init__(self, displacement, contact_fn):
  41. self.displacement = displacement
  42. self.contact_function = oriented_particle.get_ellipsoid_contact_function_param(contact_fn, num_steps=25)
  43. def __str__(self):
  44. return "Distance cost"
  45. def __call__(self, body: rigid_body.RigidBody, **interaction_params: dict) -> Array:
  46. num_particles = body.center.shape[0]
  47. dr = space.map_product(self.displacement)(body.center, body.center)
  48. eigsys = oriented_particle.eigensystem(body.orientation)
  49. cf = jax.vmap(jax.vmap(partial(self.contact_function, eigvals=interaction_params['eigvals']),
  50. (0, 0, None), 0), (0, None, 0), 0)(dr, eigsys, eigsys)
  51. mask = jnp.float32(1.0) - jnp.eye(num_particles)
  52. return 0.5 * jnp.sum(mask * cf)
  53. class NumClustersCost:
  54. def __init__(self, displacement, contact_fn, num_clusters_penalty: float):
  55. self.num_clusters_penalty = num_clusters_penalty
  56. self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
  57. def __str__(self):
  58. return f"Num clusters cost, penalty={self.num_clusters_penalty}"
  59. def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
  60. neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
  61. clusters = clustering.clustering(neighbors)
  62. return self.num_clusters_penalty * clusters.n_clusters
  63. class FlatPlaneCost:
  64. def __init__(self, displacement, contact_fn):
  65. self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
  66. def __str__(self):
  67. return "Flat plane cost"
  68. def __call__(self, body: rigid_body.RigidBody, **interaction_params):
  69. neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
  70. clusters = clustering.clustering(neighbors)
  71. cluster_masks = clustering.get_all_cluster_masks(clusters) # WARNING: using cluster mask may lead to AD problems
  72. cluster_costs = jax.vmap(partial(single_cluster_cost, body=body))(cluster_masks)
  73. return jnp.sum(cluster_costs)
  74. class SquaredClusterSizeCost:
  75. def __init__(self, displacement, contact_fn):
  76. self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
  77. def __str__(self):
  78. return "Squared cluster size cost"
  79. def __call__(self, body: rigid_body.RigidBody, **interaction_params):
  80. neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
  81. clusters = clustering.clustering(neighbors)
  82. num_particles = jnp.sum(clusters.n_part_per_cluster)
  83. return num_particles ** 2 / jnp.sum(clusters.n_part_per_cluster ** 2) - 1
  84. class FlatNeighbors:
  85. def __init__(self, displacement, contact_fn, num_neighbors: int = 6, **cf_kwargs):
  86. self.num_neighbors = num_neighbors
  87. self.displacement = displacement
  88. self.contact_function = oriented_particle.get_ellipsoid_contact_function_param(contact_fn, **cf_kwargs)
  89. def __str__(self):
  90. return f"Flat neighbors cost, num_neighbors={self.num_neighbors}"
  91. def local_flatness(self, neighbor_indices: Array, body: rigid_body.RigidBody):
  92. neighbor_coord = body.center[neighbor_indices]
  93. mean = jnp.mean(neighbor_coord, axis=0, keepdims=True)
  94. centered_r = neighbor_coord - mean
  95. matrix = 1 / self.num_neighbors * centered_r.T @ centered_r
  96. values, vecs = jnp.linalg.eigh(matrix)
  97. return values[0]
  98. def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
  99. dr = space.map_product(self.displacement)(body.center, body.center)
  100. eigsys = oriented_particle.eigensystem(body.orientation)
  101. mapped_cf = jax.vmap(jax.vmap(partial(self.contact_function, eigvals=interaction_params['eigvals']),
  102. (0, 0, None), 0), (0, None, 0), 0)
  103. cf = mapped_cf(dr, eigsys, eigsys)
  104. indices = jnp.argsort(cf, axis=1)[:, 1:self.num_neighbors+1] # if we start with idx 0, also the central particle will count
  105. neighbor_flatness = jax.vmap(partial(self.local_flatness, body=body))(indices)
  106. return jnp.mean(neighbor_flatness) # should this be weighted somehow? Based on particle distance to others?
  107. class WeightedFlatNeighbors:
  108. def __init__(self, displacement, weight_fn: Callable[[Array], Array]):
  109. self.weight_fn = weight_fn
  110. self.distance = jax.vmap(displacement, in_axes=(0, None))
  111. def __str__(self):
  112. return "Weighted flat neighbors cost"
  113. def local_flatness(self, particle_coord: Array, body: rigid_body.RigidBody):
  114. dr = self.distance(body.center, particle_coord) + 1e-8 # adding 1e-8 prevents Nan gradients
  115. neighbor_weights = self.weight_fn(jnp.linalg.norm(dr, axis=1))
  116. mean = 1 / jnp.sum(neighbor_weights) * jnp.sum(dr * neighbor_weights[:, None], axis=0, keepdims=True)
  117. centered_r = dr - mean
  118. matrix = 1 / jnp.sum(neighbor_weights) * centered_r.T @ jnp.diag(neighbor_weights) @ centered_r
  119. values, vecs = jnp.linalg.eigh(matrix)
  120. return values[0] * jnp.abs(values[1] - values[2])
  121. def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
  122. neighbor_flatness = jax.vmap(partial(self.local_flatness, body=body), in_axes=(0,))(body.center)
  123. return jnp.mean(neighbor_flatness) # should this be weighted somehow? Based on particle distance to others?
  124. class WeightedDistanceCost:
  125. def _init__(self, displacement,
  126. contact_fn,
  127. weight_fn: Callable[[Array], Array],
  128. **cf_kwargs):
  129. contact_function = oriented_particle.get_ellipsoid_contact_function_param(contact_fn, **cf_kwargs)
  130. self.cf_cost_fn = smap.oriented_pair(
  131. lambda r, e1, e2, **params: weight_fn(contact_function(r, e1, e2, params['eigvals'])),
  132. displacement)
  133. def __str__(self):
  134. return "Weighted distance cost"
  135. def __call__(self, body: rigid_body.RigidBody, **interaction_params) -> Array:
  136. num_particles = body.center.shape[0]
  137. return -self.cf_cost_fn(body, **interaction_params) / num_particles
  138. def normal_weight(x: Array, sigma: float, displacement: float = 1.):
  139. return 1 / (sigma * jnp.sqrt(2 * jnp.pi)) * jnp.exp(-0.5 * ((x - displacement) / sigma) ** 2)
  140. def center_of_mass_pbc(coord, box_size, mask):
  141. """Calculate the center of mass of a cluster given by mask, taking into account periodic boundary conditions."""
  142. num_cluster_particles = jnp.sum(mask)
  143. angle = coord * 2 * jnp.pi / box_size
  144. avg_sin = jnp.sum(jnp.sin(angle) * mask[:, None], axis=0) / num_cluster_particles
  145. avg_cos = jnp.sum(jnp.cos(angle) * mask[:, None], axis=0) / num_cluster_particles
  146. avg_angle = jnp.arctan2(-avg_sin, -avg_cos) + jnp.pi
  147. return box_size * avg_angle / (2 * jnp.pi)
  148. def displace_with_periodic_cm(coord, displacement, box_size, mask):
  149. """Displace all coordinates with a center of mass for a cluster given by mask."""
  150. cm = center_of_mass_pbc(coord, box_size, mask)
  151. mapped_displacement = jax.vmap(displacement, in_axes=(0, None))
  152. displaced_coord = mapped_displacement(coord, cm)
  153. return displaced_coord
  154. # @partial(jnp.vectorize, signature='(m),(m,m),(m,n),()->(m,n)')
  155. def contiguous_clusters(mask_pbc: Array, mask_box: Array, coord: Array, box_size: float):
  156. """
  157. Map clusters in a PBC box to a contiguous cluster with the center of mass in the coordinate origin.
  158. The algorithm is simple but can fail in certain cases. It is based on calculating distances between the center of
  159. mass for the whole cluster and all the centers of mass for the subclusters that we get by not taking into account
  160. periodic boundary conditions. The subclusters are displaced for the PBC period (box size) if this calculated
  161. distance is more than a half of box size in any component.
  162. :param mask_pbc: mask for a SINGLE cluster, taking into account PBC
  163. :param mask_box: N x N array containing all the masks for subclusters without PBC
  164. :param coord: coordinated of all particles in the box
  165. :param box_size: size of box side, assumes all sides are equal
  166. :return: displaced coordinates of all particles, with the center of mass of the given cluster in coordinate origin.
  167. """
  168. cm = center_of_mass_pbc(coord, box_size, mask_pbc)
  169. cm_subcl = jnp.sum(mask_box[..., None] * coord[None, ...], axis=1) / (jnp.sum(mask_box, axis=1) + 1e-6)[:,None]
  170. dist_between_cm = cm_subcl - cm[None, :]
  171. cluster_displacements = jnp.where(jnp.abs(dist_between_cm) > box_size / 2,
  172. jnp.sign(dist_between_cm) * jnp.full((coord.shape[0], 3), box_size,
  173. dtype=jnp.float64),
  174. jnp.zeros((coord.shape[0], 3)))
  175. # jnp.isclose must be used as we add 1e-12 to mask_pbc in some places to enable differentiation through clustering
  176. relevant_subclusters = jnp.all(jnp.isclose(mask_pbc[None, :] * mask_box, mask_box), axis=1)
  177. cluster_displacements = cluster_displacements * relevant_subclusters[:, None]
  178. # sum will just collapse the first dimension with no overlap as each particle can be a part of only one cluster
  179. particle_displacements = jnp.sum(cluster_displacements[:, None, :] * mask_box[:, :, None], axis=0)
  180. displaced_coord = coord - particle_displacements - cm
  181. return displaced_coord - jnp.sum(displaced_coord * mask_pbc[:, None], axis=0) / jnp.sum(mask_pbc)
  182. def box_displacement(Ra, Rb):
  183. """Calculate displacement vector in the simulation box without PBCs."""
  184. return space.pairwise_displacement(Ra, Rb)
  185. ResidualsAvgType = Literal['linear', 'quadratic']
  186. ResidualsAvgFn = Callable[[Array, Array], Array]
  187. def residuals_avg_fn_factory(which: ResidualsAvgType) -> ResidualsAvgFn:
  188. if which == 'linear':
  189. return lambda residuals, mask: jnp.sum(jnp.abs(residuals)) / (jnp.sum(mask) + 1e-6)
  190. if which == 'quadratic':
  191. return lambda residuals, mask: jnp.sum(residuals ** 2) / (jnp.sum(mask) + 1e-6)
  192. raise ValueError('Unknown type of residuals cost function.')
  193. @partial(jax.jit, static_argnums=(4,))
  194. def single_cluster_curv_radius(mask_pbc: Array,
  195. mask_box: Array,
  196. body: rigid_body.RigidBody,
  197. box_size: float,
  198. residuals_avg_fn: ResidualsAvgFn = residuals_avg_fn_factory('linear')):
  199. """
  200. Fit a circle to cluster particles. Cluster parameter should be an array of length N of cluster indices,
  201. filled to the end by values N for clusters smaller than the entire system size. (Such an array is exactly
  202. the output of clustering algorithm from clustering.py, saved in Clusters.clusters.) Body parameter is the
  203. whole system rigid_body.RigidBody.
  204. Returns:
  205. - fitted cluster radius
  206. - mean residuals for the fit
  207. """
  208. displaced_coord = contiguous_clusters(mask_pbc, mask_box, body.center, box_size) # frame of reference without PBC
  209. cluster_coord = displaced_coord * mask_pbc[:, None]
  210. matrix = 1 / jnp.sum(mask_pbc) * cluster_coord.T @ cluster_coord
  211. values, vecs = jnp.linalg.eigh(matrix)
  212. p0_1 = surface_fit.QuadraticSurfaceParams(center=2 * vecs[:, 0],
  213. radius=1.)
  214. p0_2 = surface_fit.QuadraticSurfaceParams(center=-2 * vecs[:, 0],
  215. radius=1.)
  216. opt1 = surface_fit.surface_fit_gn(surface_fit.spherical_surface, displaced_coord, mask_pbc, p0=p0_1)
  217. opt2 = surface_fit.surface_fit_gn(surface_fit.spherical_surface, displaced_coord, mask_pbc, p0=p0_2)
  218. mean_residuals1 = residuals_avg_fn(surface_fit.spherical_surface(opt1, displaced_coord, mask_pbc), mask_pbc)
  219. mean_residuals2 = residuals_avg_fn(surface_fit.spherical_surface(opt2, displaced_coord, mask_pbc), mask_pbc)
  220. # first we order solutions based on fit residuals
  221. (opt1, opt2), (mr1, mr2) = jax.lax.cond(mean_residuals1 < mean_residuals2,
  222. lambda: ((opt1, opt2), (mean_residuals1, mean_residuals2)),
  223. lambda: ((opt2, opt1), (mean_residuals2, mean_residuals1)))
  224. # the second criterion makes sure to select the fit that didn't fail to converge
  225. # (in most cases, at least one converges)
  226. return jax.lax.cond(jnp.abs(opt1.radius) < 1e4,
  227. lambda: (jnp.abs(opt1.radius), mr1),
  228. lambda: (jnp.abs(opt2.radius), mr2))
  229. class CurvedClustersCost:
  230. def __init__(self,
  231. displacement,
  232. box_size,
  233. contact_fn,
  234. target_radius,
  235. radius_cutoff_mul=100):
  236. self.box_size = box_size
  237. self.target_radius = target_radius
  238. self.displacement = displacement
  239. self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
  240. self.neighboring_box = clustering.get_ellipsoid_neighboring_fn(box_displacement, contact_fn)
  241. self.cutoff = radius_cutoff_mul
  242. def __str__(self):
  243. return f"Curved clusters cost, target_radius={self.target_radius}"
  244. def __call__(self, body: rigid_body.RigidBody, **interaction_params):
  245. neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
  246. neighbors_box = self.neighboring_box(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
  247. clusters = clustering.clustering(neighbors)
  248. clusters_box = clustering.clustering(neighbors_box)
  249. # we add 1e-12 to cluster masks to avoid NaN gradients that arise when the entire mask is 0
  250. # this can change the fit values for small clusters, however, we later only take into account clusters with N>3
  251. all_radii, _ = jax.vmap(partial(single_cluster_curv_radius,
  252. mask_box=clusters_box.masks,
  253. body=body,
  254. box_size=self.box_size))(clusters.masks + 1e-12)
  255. # as curvature radius of a cluster can get arbitrarily large, we cut off calculated radii at some value
  256. all_radii = jnp.where(all_radii > self.cutoff * self.target_radius, self.cutoff * self.target_radius, all_radii)
  257. cluster_weights = clusters.n_part_per_cluster / body.center.shape[0]
  258. relevant_clusters = clusters.n_part_per_cluster > 3
  259. cluster_weights = jnp.where(relevant_clusters, cluster_weights, 0.) # clusters with N < 3 are excluded
  260. num_of_weighted_clusters = jnp.sum(cluster_weights > 0) + 1e-6 # we add a small number to avoid dividing by 0
  261. return jnp.sum((jnp.log(all_radii / self.target_radius) ** 2) * cluster_weights) / num_of_weighted_clusters
  262. class CurvedClustersResidualsCost:
  263. def __init__(self,
  264. displacement,
  265. box_size,
  266. contact_fn,
  267. target_radius,
  268. radius_cutoff_mul=100,
  269. residuals_cost_factor=1,
  270. residuals_avg_type: ResidualsAvgType = 'linear'):
  271. self.box_size = box_size
  272. self.target_radius = target_radius
  273. self.displacement = displacement
  274. self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
  275. self.neighboring_box = clustering.get_ellipsoid_neighboring_fn(box_displacement, contact_fn)
  276. self.cutoff = radius_cutoff_mul
  277. self.res_cost_fac = residuals_cost_factor
  278. self.residuals_avg_type = residuals_avg_type
  279. self.residuals_avg_fn = residuals_avg_fn_factory(residuals_avg_type)
  280. def __str__(self):
  281. return (f"Curved clusters residuals cost, target_radius={self.target_radius}, "
  282. f"{self.residuals_avg_type} residuals with factor {self.res_cost_fac}.")
  283. def __call__(self, body: rigid_body.RigidBody, **interaction_params):
  284. neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
  285. neighbors_box = self.neighboring_box(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
  286. clusters = clustering.clustering(neighbors)
  287. clusters_box = clustering.clustering(neighbors_box)
  288. # we add 1e-12 to cluster masks to avoid NaN gradients that arise when the entire mask is 0
  289. # this can change the fit values for small clusters, however, we later only take into account clusters with N>3
  290. all_radii, all_residuals = jax.vmap(partial(single_cluster_curv_radius,
  291. mask_box=clusters_box.masks,
  292. body=body,
  293. box_size=self.box_size,
  294. residuals_avg_fn=self.residuals_avg_fn))(clusters.masks + 1e-12)
  295. cluster_weights = clusters.n_part_per_cluster / body.center.shape[0]
  296. relevant_clusters = clusters.n_part_per_cluster > 3
  297. cluster_weights = jnp.where(relevant_clusters, cluster_weights, 0.) # clusters with N < 3 are excluded
  298. num_of_weighted_clusters = jnp.sum(cluster_weights > 0) + 1e-6 # we add a small number to avoid dividing by 0
  299. # as curvature radius of a cluster can get arbitrarily large, we cut off calculated radii at some value
  300. all_radii = jnp.where(all_radii > self.cutoff * self.target_radius, self.cutoff * self.target_radius, all_radii)
  301. all_residuals = jnp.where(all_residuals > 1e2, 0., all_residuals)
  302. curvature_cost = jnp.log(all_radii / self.target_radius) ** 2
  303. residuals_cost = self.res_cost_fac * all_residuals
  304. return jnp.sum((curvature_cost + residuals_cost) * cluster_weights) # / num_of_weighted_clusters
  305. @partial(jax.jit, static_argnums=(4,))
  306. def single_cluster_quadratic_surface(mask_pbc: Array,
  307. mask_box: Array,
  308. body: rigid_body.RigidBody,
  309. box_size: float,
  310. surface_constant: int = -1,
  311. ) -> Array:
  312. """
  313. Fit a circle to cluster particles. Cluster parameter should be an array of length N of cluster indices,
  314. filled to the end by values N for clusters smaller than the entire system size. (Such an array is exactly
  315. the output of clustering algorithm from clustering.py, saved in Clusters.clusters.) Body parameter is the
  316. whole system rigid_body.RigidBody.
  317. Returns:
  318. - fitted quadratic surface eigenvalues
  319. """
  320. displaced_coord = contiguous_clusters(mask_pbc, mask_box, body.center, box_size)
  321. p0 = surface_fit_general.GeneralQuadraticSurfaceParams()
  322. opt = surface_fit_general.surface_fit_gn(partial(surface_fit_general.quadratic_surface, constant=surface_constant),
  323. displaced_coord, mask_pbc, p0=p0)
  324. return jnp.linalg.eigvalsh(opt.quadratic_form)
  325. class QuadraticSurfaceClustersCost:
  326. def __init__(self,
  327. displacement,
  328. box_size,
  329. contact_fn,
  330. target_eigvals,
  331. surface_constat=-1):
  332. self.box_size = box_size
  333. self.target_eigvals = jnp.sort(target_eigvals)
  334. self.displacement = displacement
  335. self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
  336. self.neighboring_box = clustering.get_ellipsoid_neighboring_fn(box_displacement, contact_fn)
  337. self.surface_constat = surface_constat
  338. def __str__(self):
  339. return f"Curved clusters cost, target_eigvals={self.target_eigvals}"
  340. def __call__(self, body: rigid_body.RigidBody, **interaction_params):
  341. neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
  342. neighbors_box = self.neighboring_box(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
  343. clusters = clustering.clustering(neighbors)
  344. clusters_box = clustering.clustering(neighbors_box)
  345. # we add 1e-12 to cluster masks to avoid NaN gradients that arise when the entire mask is 0
  346. # this can change the fit values for small clusters, however, we later only take into account clusters with N>3
  347. all_eigvals = jax.vmap(partial(single_cluster_quadratic_surface,
  348. mask_box=clusters_box.masks,
  349. body=body,
  350. box_size=self.box_size,
  351. surface_constant=self.surface_constat))(clusters.masks + 1e-12)
  352. # as curvature radius of a cluster can get arbitrarily large, we cut off calculated radii at some value
  353. cluster_weights = clusters.n_part_per_cluster / body.center.shape[0]
  354. relevant_clusters = clusters.n_part_per_cluster > 3
  355. cluster_weights = jnp.where(relevant_clusters, cluster_weights, 0.) # clusters with N < 3 are excluded
  356. num_of_weighted_clusters = jnp.sum(cluster_weights > 0) + 1e-6 # we add a small number to avoid dividing by 0
  357. return jnp.sum((all_eigvals - self.target_eigvals) ** 2 * cluster_weights[:, None]) / num_of_weighted_clusters
  358. def single_cluster_cylinder_radius(mask_pbc: Array,
  359. mask_box: Array,
  360. body: rigid_body.RigidBody,
  361. box_size: float):
  362. """
  363. Fit a cylinder to cluster particles. Cluster parameter should be an array of length N of cluster indices,
  364. filled to the end by values N for clusters smaller than the entire system size. (Such an array is exactly
  365. the output of clustering algorithm from clustering.py, saved in Clusters.clusters.) Body parameter is the
  366. whole system rigid_body.RigidBody.
  367. """
  368. displaced_coord = contiguous_clusters(mask_pbc, mask_box, body.center, box_size) # frame of reference without PBC
  369. cluster_coord = displaced_coord * mask_pbc[:, None]
  370. matrix = 1 / jnp.sum(mask_pbc) * cluster_coord.T @ cluster_coord
  371. values, vecs = jnp.linalg.eigh(matrix)
  372. euler_best_direction = jnp.array([jnp.arctan2(vecs[0, 2], vecs[1, 2]),
  373. jnp.arctan2(vecs[2, 2], jnp.sqrt(vecs[0, 2] ** 2 + vecs[1, 2] ** 2)),
  374. 0])
  375. p0_1 = surface_fit.QuadraticSurfaceParams(center=3 * vecs[:, 0],
  376. euler=euler_best_direction,
  377. radius=5.)
  378. p0_2 = surface_fit.QuadraticSurfaceParams(center=-3 * vecs[:, 0],
  379. euler=euler_best_direction,
  380. radius=5.)
  381. opt1 = surface_fit.surface_fit_gn(surface_fit.cylindrical_surface, displaced_coord, mask_pbc, p0=p0_1)
  382. opt2 = surface_fit.surface_fit_gn(surface_fit.cylindrical_surface, displaced_coord, mask_pbc, p0=p0_2)
  383. mean_residuals1 = jnp.sum(jnp.abs(surface_fit.spherical_surface(opt1, displaced_coord, mask_pbc))) / (
  384. jnp.sum(mask_pbc) + 1e-6)
  385. mean_residuals2 = jnp.sum(jnp.abs(surface_fit.spherical_surface(opt2, displaced_coord, mask_pbc))) / (
  386. jnp.sum(mask_pbc) + 1e-6)
  387. return jax.lax.cond(mean_residuals1 < mean_residuals2,
  388. lambda: jnp.abs(opt1.radius),
  389. lambda: jnp.abs(opt2.radius))
  390. # return jnp.minimum(jnp.abs(opt1.radius), jnp.abs(opt2.radius))
  391. class CylindricalClustersCost:
  392. def __init__(self, displacement,
  393. box_size,
  394. contact_fn,
  395. target_radius):
  396. self.box_size = box_size
  397. self.target_radius = target_radius
  398. self.displacement = displacement
  399. self.neighboring = clustering.get_ellipsoid_neighboring_fn(displacement, contact_fn)
  400. def __str__(self):
  401. return f"Cylindrical clusters cost, target_radius={self.target_radius}"
  402. def __call__(self, body: rigid_body.RigidBody, **interaction_params):
  403. neighbors = self.neighboring(body, eigvals=interaction_params['eigvals'], neigh_contact_fn=1.4)
  404. clusters = clustering.clustering(neighbors)
  405. # we add 1e-12 to cluster masks to avoid NaN gradients that arise when the entire mask is 0
  406. # this can change the fit values for small clusters, however, we later only take into account clusters with N > 3
  407. all_radii = jax.vmap(partial(single_cluster_cylinder_radius,
  408. body=body,
  409. displacement=self.displacement,
  410. box_size=self.box_size))(clusters.masks + 1e-12)
  411. # as curvature radius of a cluster can get arbitrarily large, we cut off calculated radii at 1000 * target radius
  412. all_radii = jnp.where(all_radii > 1000 * self.target_radius, 1000 * self.target_radius, all_radii)
  413. cluster_weights = clusters.n_part_per_cluster / body.center.shape[0]
  414. relevant_clusters = clusters.n_part_per_cluster > 3
  415. cluster_weights = jnp.where(relevant_clusters, cluster_weights, 0.) # clusters with N < 3 are excluded
  416. num_of_weighted_clusters = jnp.sum(cluster_weights > 0) + 1e-6 # we add a small number to avoid dividing by 0
  417. return jnp.sum((jnp.log(all_radii / self.target_radius) ** 2) * cluster_weights) / num_of_weighted_clusters
  418. class CostCombinator:
  419. def __init__(self, cost_fns: list[CostFn], coefficients: list[float]):
  420. if len(cost_fns) != len(coefficients):
  421. raise ValueError(f'Lengths of cost_fn list and coefficients list should be equal, '
  422. f'got {len(cost_fns)} and {len(coefficients)}, respectively')
  423. self.cost_fns = cost_fns
  424. self.coefficients = coefficients
  425. def __str__(self):
  426. return f"".join([f"{coef} x {cf} \n" for cf, coef in zip(self.cost_fns, self.coefficients)])
  427. def __call__(self, body: rigid_body.RigidBody, **interaction_params):
  428. return sum(coef * cf(body, **interaction_params) for cf, coef in zip(self.cost_fns, self.coefficients))