multipole_interaction.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import jax.numpy as jnp
  2. import jax
  3. from curvature_assembly import oriented_particle
  4. Array = jnp.ndarray
  5. def quadrupolar_eigenvalues(q0: Array, theta: Array) -> Array:
  6. return q0 * jnp.array([(jnp.cos(theta) + 3) / 4, (jnp.cos(theta) - 3) / 4, -jnp.cos(theta) / 2])
  7. def quadrupolar_interaction(dr: Array, eigsys1: Array, eigsys2: Array, eigvals: Array) -> Array:
  8. """General quadrupolar interaction. However, it is really slow to evaluate in gradient-based simulations."""
  9. distance2 = jnp.sum(dr ** 2)
  10. distance4 = distance2 ** 2
  11. distance = jnp.sqrt(distance2)
  12. qf1 = oriented_particle.qf_from_rotation(eigsys1, oriented_particle.make_diagonal(eigvals))
  13. qf2 = oriented_particle.qf_from_rotation(eigsys2, oriented_particle.make_diagonal(eigvals))
  14. dr2 = jax.lax.dot_general(dr, dr, dimension_numbers=(((), ()), ((), ())))
  15. dr4 = jax.lax.dot_general(dr2, dr2, dimension_numbers=(((), ()), ((), ())))
  16. term1 = jnp.einsum('ijkl, ij, kl', dr4, qf1, qf2)
  17. term2 = jnp.einsum('jk, ij, ik', dr2, qf1, qf2)
  18. term3 = jnp.einsum('ij, ij', qf1, qf2)
  19. return 1 / (3 * distance ** 5) * (35 * term1 / distance4 - 20 * term2 / distance2 + 2 * term3)
  20. def lin_quad_energy(dr: Array, eigsys1: Array, eigsys2: Array, eigvals: Array):
  21. """Interaction between two linear quadrupoles with eigenvalues [1. -1, 0] in this exact order."""
  22. q0 = eigvals[0]
  23. mi = eigsys1[:, 0]
  24. ni = eigsys1[:, 1]
  25. mj = eigsys2[:, 0]
  26. nj = eigsys2[:, 1]
  27. dist = jnp.sqrt(jnp.sum(dr * dr))
  28. rij_hat = dr / dist
  29. mi_rij = jnp.sum(mi * rij_hat)
  30. mj_rij = jnp.sum(mj * rij_hat)
  31. ni_rij = jnp.sum(ni * rij_hat)
  32. nj_rij = jnp.sum(nj * rij_hat)
  33. mi_mj = jnp.sum(mi * mj)
  34. ni_nj = jnp.sum(ni * nj)
  35. mi_nj = jnp.sum(mi * nj)
  36. ni_mj = jnp.sum(ni * mj)
  37. Aij = mi_rij ** 2 * mj_rij ** 2 - mi_rij ** 2 * nj_rij ** 2 - ni_rij ** 2 * mj_rij ** 2 + ni_rij ** 2 * nj_rij ** 2
  38. Bij = mi_mj * mi_rij * mj_rij - mi_nj * mi_rij * nj_rij - ni_mj * ni_rij * mj_rij + ni_nj * ni_rij * nj_rij
  39. Cij = mi_mj ** 2 - mi_nj ** 2 - ni_mj ** 2 + ni_nj ** 2
  40. return q0 ** 2 / (3 * dist ** 5) * (35 * Aij - 20 * Bij + 2 * Cij)
  41. def ferro_orientational_energy(dr: Array, eigsys1: Array, eigsys2: Array, softness: float = 1.5):
  42. """
  43. Ferromagnetic-like interaction between a pair of particles. Must be combined with some distance based term.
  44. Softness parameter is a factor that scales the second term of the expansion and relates to the energy sensitivity
  45. on deviations from the parallel configuration for side by side particles. Lower values mean more stiff potential.
  46. Increasing it too much can lead to the preference for dipolar-like ordering (at softness = 3, effects notable at
  47. softness >= 2).
  48. """
  49. pi = eigsys1[:, 2]
  50. pj = eigsys2[:, 2]
  51. dist = jnp.sqrt(jnp.sum(dr * dr))
  52. rij_hat = dr / dist
  53. # positive values for attraction as added distance based term should make the entire energy negative
  54. return jnp.sum(pi * pj) - softness * jnp.sum(pi * rij_hat) * jnp.sum(pj * rij_hat)