|
@@ -0,0 +1,965 @@
|
|
|
+import jax.numpy as jnp
|
|
|
+from typing import Callable
|
|
|
+import jax
|
|
|
+from functools import partial
|
|
|
+
|
|
|
+
|
|
|
+Array = jnp.ndarray
|
|
|
+
|
|
|
+
|
|
|
+def neg_m(sph_harm: Callable, m: int) -> Callable:
|
|
|
+ def wrapped(x):
|
|
|
+ return -1 ** m * jnp.conj(sph_harm(x))
|
|
|
+ return wrapped
|
|
|
+
|
|
|
+
|
|
|
+def Y00(x):
|
|
|
+ return jax.lax.convert_element_type(0.5 * jnp.sqrt(1 / jnp.pi), new_dtype=jnp.complex128)
|
|
|
+
|
|
|
+
|
|
|
+def Y10(x):
|
|
|
+ return jax.lax.convert_element_type(0.5 * jnp.sqrt(3 / jnp.pi) * x[2], new_dtype=jnp.complex128)
|
|
|
+
|
|
|
+
|
|
|
+def Y11(x):
|
|
|
+ return -0.5 * jnp.sqrt(1.5 / jnp.pi) * (x[0] + 1j * x[1])
|
|
|
+
|
|
|
+
|
|
|
+def Y20(x):
|
|
|
+ return jax.lax.convert_element_type(0.25 * jnp.sqrt(5 / jnp.pi) * (3 * x[2] ** 2 - 1),
|
|
|
+ new_dtype=jnp.complex128)
|
|
|
+ # return jax.lax.convert_element_type(0.3153915652525201 * (3 * x[2] ** 2 - 1),
|
|
|
+ # new_dtype=jnp.complex128)
|
|
|
+
|
|
|
+
|
|
|
+def Y21(x):
|
|
|
+ return -0.5 * jnp.sqrt(7.5 / jnp.pi) * (x[0] + 1j * x[1]) * x[2]
|
|
|
+ # return -0.7725484040463791 * (x[0] + 1j * x[1]) * x[2]
|
|
|
+
|
|
|
+
|
|
|
+def Y22(x):
|
|
|
+ return 0.25 * jnp.sqrt(7.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 2
|
|
|
+ # return 0.3862742020231896 * (x[0] + 1j * x[1]) ** 2
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+def Y30(x):
|
|
|
+ return jax.lax.convert_element_type(0.25 * jnp.sqrt(7 / jnp.pi) * (5 * x[2] ** 3 - 3 * x[2]),
|
|
|
+ new_dtype=jnp.complex128)
|
|
|
+
|
|
|
+
|
|
|
+def Y31(x):
|
|
|
+ return -0.125 * jnp.sqrt(21 / jnp.pi) * (x[0] + 1j * x[1]) * (5 * x[2] ** 2 - 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y32(x):
|
|
|
+ return 0.25 * jnp.sqrt(52.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 2 * x[2]
|
|
|
+
|
|
|
+
|
|
|
+def Y33(x):
|
|
|
+ return -0.125 * jnp.sqrt(35 / jnp.pi) * (x[0] + 1j * x[1]) ** 3
|
|
|
+
|
|
|
+
|
|
|
+def Y40(x):
|
|
|
+ return jax.lax.convert_element_type(3 / 16 * jnp.sqrt(1 / jnp.pi) * (35 * x[2] ** 4 - 30 * x[2] ** 2 + 3),
|
|
|
+ new_dtype=jnp.complex128)
|
|
|
+ # return jax.lax.convert_element_type(0.1057855469152043 * (35 * x[2] ** 4 - 30 * x[2] ** 2 + 3),
|
|
|
+ # new_dtype=jnp.complex128)
|
|
|
+
|
|
|
+
|
|
|
+def Y41(x):
|
|
|
+ return -3 / 8 * jnp.sqrt(5 / jnp.pi) * (x[0] + 1j * x[1]) * (7 * x[2] ** 3 - 3 * x[2])
|
|
|
+ # return -0.47308734787878 * (x[0] + 1j * x[1]) * (7 * x[2] ** 3 - 3 * x[2])
|
|
|
+
|
|
|
+
|
|
|
+def Y42(x):
|
|
|
+ return 3 / 8 * jnp.sqrt(2.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 2 * (7 * x[2] ** 2 -1)
|
|
|
+ # return 0.3345232717786446 * (x[0] + 1j * x[1]) ** 2 * (7 * x[2] ** 2 - 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y43(x):
|
|
|
+ return -3 / 8 * jnp.sqrt(35 / jnp.pi) * (x[0] + 1j * x[1]) ** 3 * x[2]
|
|
|
+ # return 1.251671470898352 * (x[0] + 1j * x[1]) ** 3 * x[2]
|
|
|
+
|
|
|
+
|
|
|
+def Y44(x):
|
|
|
+ return 3 / 16 * jnp.sqrt(17.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 4
|
|
|
+ # return 0.4425326924449826 * (x[0] + 1j * x[1]) ** 4
|
|
|
+
|
|
|
+
|
|
|
+def Y50(x):
|
|
|
+ return jax.lax.convert_element_type(1 / 16 * jnp.sqrt(11 / jnp.pi) *
|
|
|
+ (63 * x[2] ** 5 - 70 * x[2] ** 3 + 15 * x[2]),
|
|
|
+ new_dtype=jnp.complex128)
|
|
|
+
|
|
|
+
|
|
|
+def Y51(x):
|
|
|
+ return -1 / 16 * jnp.sqrt(82.5 / jnp.pi) *\
|
|
|
+ (x[0] + 1j * x[1]) * (21 * x[2] ** 4 - 14 * x[2] ** 2 + 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y52(x):
|
|
|
+ return 1 / 8 * jnp.sqrt(577.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 2 * (3 * x[2] ** 3 - x[2])
|
|
|
+
|
|
|
+
|
|
|
+def Y53(x):
|
|
|
+ return -1 / 32 * jnp.sqrt(385 / jnp.pi) * (x[0] + 1j * x[1]) ** 3 * (9 * x[2] ** 2 - 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y54(x):
|
|
|
+ return 3 / 16 * jnp.sqrt(192.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 4 * x[2]
|
|
|
+
|
|
|
+
|
|
|
+def Y55(x):
|
|
|
+ return -3 / 32 * jnp.sqrt(77 / jnp.pi) * (x[0] + 1j * x[1]) ** 5
|
|
|
+
|
|
|
+
|
|
|
+def Y60(x):
|
|
|
+ return jax.lax.convert_element_type(1 / 32 * jnp.sqrt(13 / jnp.pi) *
|
|
|
+ (231 * x[2] ** 6 - 315 * x[2] ** 4 + 105 * x[2] ** 2 - 5),
|
|
|
+ new_dtype=jnp.complex128)
|
|
|
+
|
|
|
+
|
|
|
+def Y61(x):
|
|
|
+ return -1 / 16 * jnp.sqrt(136.5 / jnp.pi) *\
|
|
|
+ (x[0] + 1j * x[1]) * (33 * x[2] ** 5 - 30 * x[2] ** 3 + 5 * x[2])
|
|
|
+
|
|
|
+
|
|
|
+def Y62(x):
|
|
|
+ return 1 / 64 * jnp.sqrt(1365 / jnp.pi) *\
|
|
|
+ (x[0] + 1j * x[1]) ** 2 * (33 * x[2] ** 4 - 18 * x[2] ** 2 + 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y63(x):
|
|
|
+ return -1 / 32 * jnp.sqrt(1365 / jnp.pi) * (x[0] + 1j * x[1]) ** 3 * (11 * x[2] ** 3 - 3 * x[2])
|
|
|
+
|
|
|
+
|
|
|
+def Y64(x):
|
|
|
+ return 3 / 32 * jnp.sqrt(45.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 4 * (11 * x[2] ** 2 - 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y65(x):
|
|
|
+ return -3 / 32 * jnp.sqrt(1001 / jnp.pi) * (x[0] + 1j * x[1]) ** 5 * x[2]
|
|
|
+
|
|
|
+
|
|
|
+def Y66(x):
|
|
|
+ return 1 / 64 * jnp.sqrt(3003 / jnp.pi) * (x[0] + 1j * x[1]) ** 6
|
|
|
+
|
|
|
+
|
|
|
+def Y70(x):
|
|
|
+ return jax.lax.convert_element_type(1 / 32 * jnp.sqrt(15 / jnp.pi) *
|
|
|
+ (429 * x[2] ** 7 - 693 * x[2] ** 5 + 315 * x[2] ** 3 - 35 * x[2]),
|
|
|
+ new_dtype=jnp.complex128)
|
|
|
+
|
|
|
+
|
|
|
+def Y71(x):
|
|
|
+ return -1 / 64 * jnp.sqrt(52.5 / jnp.pi) *\
|
|
|
+ (x[0] + 1j * x[1]) * (429 * x[2] ** 6 - 495 * x[2] ** 4 + 135 * x[2] ** 2 - 5)
|
|
|
+
|
|
|
+
|
|
|
+def Y72(x):
|
|
|
+ return 3 / 64 * jnp.sqrt(35 / jnp.pi) * \
|
|
|
+ (x[0] + 1j * x[1]) ** 2 * (143 * x[2] ** 5 - 110 * x[2] ** 3 + 15 * x[2])
|
|
|
+
|
|
|
+
|
|
|
+def Y73(x):
|
|
|
+ return -3 / 64 * jnp.sqrt(17.5 / jnp.pi) * \
|
|
|
+ (x[0] + 1j * x[1]) ** 3 * (143 * x[2] ** 4 - 66 * x[2] ** 2 + 3)
|
|
|
+
|
|
|
+
|
|
|
+def Y74(x):
|
|
|
+ return 3 / 32 * jnp.sqrt(192.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 4 * (13 * x[2] ** 3 - 3 * x[2])
|
|
|
+
|
|
|
+
|
|
|
+def Y75(x):
|
|
|
+ return -3 / 64 * jnp.sqrt(192.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 5 * (13 * x[2] ** 2 - 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y76(x):
|
|
|
+ return 3 / 64 * jnp.sqrt(5005 / jnp.pi) * (x[0] + 1j * x[1]) ** 6 * x[2]
|
|
|
+
|
|
|
+
|
|
|
+def Y77(x):
|
|
|
+ return -3 / 64 * jnp.sqrt(357.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 7
|
|
|
+
|
|
|
+
|
|
|
+def get_sph_function(l, m):
|
|
|
+ if abs(m) <= l < len(sph_harm_list):
|
|
|
+ return jax.jit(sph_harm_list[l * (l + 1) + m])
|
|
|
+ return 0
|
|
|
+
|
|
|
+
|
|
|
+# @partial(jax.jit, static_argnums=(1, 2))
|
|
|
+# def sph_harm(x: Array, l: Array, m: Array) -> Array:
|
|
|
+# return jnp.stack([get_sph_function(sl, sm)(x) for sl, sm in zip(l, m)])
|
|
|
+
|
|
|
+
|
|
|
+def sph_harm_fn(l: tuple, m: tuple) -> Callable[[Array,], Array]:
|
|
|
+
|
|
|
+ def f(x: Array):
|
|
|
+ return jnp.stack([get_sph_function(sl, sm)(x) for sl, sm in zip(l, m)])
|
|
|
+
|
|
|
+ return f
|
|
|
+
|
|
|
+
|
|
|
+def sph_harm_fn_custom(l: tuple, m: tuple) -> Callable[[Array,], Array]:
|
|
|
+
|
|
|
+ l_array = jnp.array(l)
|
|
|
+ m_array = jnp.array(m)
|
|
|
+
|
|
|
+ # @jax.custom_jvp
|
|
|
+ @jax.custom_vjp
|
|
|
+ def f(x: Array):
|
|
|
+ return jnp.stack([get_sph_function(sl, sm)(x) for sl, sm in zip(l, m)])
|
|
|
+
|
|
|
+ # @f.defjvp
|
|
|
+ def sph_harm_jvp(primals, tangents):
|
|
|
+ x, = primals
|
|
|
+ dx, = tangents
|
|
|
+
|
|
|
+ primal_out = f(x)
|
|
|
+
|
|
|
+ extanded_primal = jnp.zeros((l_array.shape[0] + 2,), dtype=primal_out.dtype)
|
|
|
+ extanded_primal = extanded_primal.at[1:-1].set(primal_out)
|
|
|
+ rho2 = x[0] ** 2 + x[1] ** 2 + 1e-8
|
|
|
+ rho = jnp.sqrt(rho2)
|
|
|
+
|
|
|
+ coeffs1 = (x[0] - 1j * x[1]) / rho * jnp.sqrt((l_array - m_array) * (l_array + m_array + 1))
|
|
|
+ coeffs2 = (x[0] + 1j * x[1]) / rho * jnp.sqrt((l_array - m_array + 1) * (l_array + m_array))
|
|
|
+ theta_derivatives = 0.5 * (coeffs1 * extanded_primal[2:] + coeffs2 * extanded_primal[:-2])
|
|
|
+ phi_derivatives = 1j * m_array * primal_out
|
|
|
+
|
|
|
+ x_derivatives = theta_derivatives * x[0] * x[2] / rho - phi_derivatives * x[1] / rho2
|
|
|
+ y_derivatives = theta_derivatives * x[1] * x[2] / rho + phi_derivatives * x[0] / rho2
|
|
|
+ z_derivatives = -theta_derivatives * rho
|
|
|
+
|
|
|
+ jacobian = jnp.array([x_derivatives, y_derivatives, z_derivatives])
|
|
|
+ tangent_out = jacobian.T @ dx
|
|
|
+ return primal_out, tangent_out
|
|
|
+
|
|
|
+ def sph_harm_fwd(x):
|
|
|
+ primal_out = f(x)
|
|
|
+
|
|
|
+ extanded_primal = jnp.zeros((l_array.shape[0] + 2,), dtype=primal_out.dtype)
|
|
|
+ extanded_primal = extanded_primal.at[1:-1].set(primal_out)
|
|
|
+ rho2 = x[0] ** 2 + x[1] ** 2 + 1e-8
|
|
|
+ rho = jnp.sqrt(rho2)
|
|
|
+
|
|
|
+ coeffs1 = (x[0] - 1j * x[1]) / rho * jnp.sqrt((l_array - m_array) * (l_array + m_array + 1))
|
|
|
+ coeffs2 = (x[0] + 1j * x[1]) / rho * jnp.sqrt((l_array - m_array + 1) * (l_array + m_array))
|
|
|
+ theta_derivatives = 0.5 * (coeffs1 * extanded_primal[2:] + coeffs2 * extanded_primal[:-2])
|
|
|
+ phi_derivatives = 1j * m_array * primal_out
|
|
|
+
|
|
|
+ x_derivatives = theta_derivatives * x[0] * x[2] / rho - phi_derivatives * x[1] / rho2
|
|
|
+ y_derivatives = theta_derivatives * x[1] * x[2] / rho + phi_derivatives * x[0] / rho2
|
|
|
+ z_derivatives = -theta_derivatives * rho
|
|
|
+
|
|
|
+ jacobian = jnp.array([x_derivatives, y_derivatives, z_derivatives]).T
|
|
|
+ return primal_out, jacobian
|
|
|
+
|
|
|
+ def sph_harm_rev(jacobian, y_bar):
|
|
|
+ return (y_bar @ jacobian,)
|
|
|
+
|
|
|
+ f.defvjp(sph_harm_fwd, sph_harm_rev)
|
|
|
+
|
|
|
+ return f
|
|
|
+
|
|
|
+sph_harm_list = [Y00,
|
|
|
+ neg_m(Y11, 1), Y10, Y11,
|
|
|
+ neg_m(Y22, 2), neg_m(Y21, 1), Y20, Y21, Y22,
|
|
|
+ neg_m(Y33, 3), neg_m(Y32, 2), neg_m(Y31, 1), Y30, Y31, Y32, Y33,
|
|
|
+ neg_m(Y44, 4), neg_m(Y43, 3), neg_m(Y42, 2), neg_m(Y41, 1), Y40, Y41, Y42, Y43, Y44,
|
|
|
+ neg_m(Y55, 5), neg_m(Y54, 4), neg_m(Y53, 3), neg_m(Y52, 2), neg_m(Y51, 1), Y50, Y51, Y52, Y53, Y54, Y55,
|
|
|
+ neg_m(Y66, 6), neg_m(Y65, 5), neg_m(Y64, 4), neg_m(Y63, 3), neg_m(Y62, 2), neg_m(Y61, 1), Y60, Y61, Y62, Y63, Y64, Y65, Y66,
|
|
|
+ neg_m(Y77, 7), neg_m(Y76, 6), neg_m(Y75, 5), neg_m(Y74, 4), neg_m(Y73, 3), neg_m(Y72, 2), neg_m(Y71, 1), Y70, Y71, Y72, Y73, Y74, Y75, Y76, Y77]
|
|
|
+
|
|
|
+
|
|
|
+def Y00real(x):
|
|
|
+ return 0.5 * jnp.sqrt(1 / jnp.pi)
|
|
|
+
|
|
|
+
|
|
|
+def Y1m1real(x):
|
|
|
+ return 0.5 * jnp.sqrt(3 / jnp.pi) * x[1]
|
|
|
+
|
|
|
+
|
|
|
+def Y10real(x):
|
|
|
+ return 0.5 * jnp.sqrt(3 / jnp.pi) * x[2]
|
|
|
+
|
|
|
+
|
|
|
+def Y11real(x):
|
|
|
+ return 0.5 * jnp.sqrt(3 / jnp.pi) * x[0]
|
|
|
+
|
|
|
+
|
|
|
+def Y2m2real(x):
|
|
|
+ return 0.5 * jnp.sqrt(15 / jnp.pi) * x[0] * x[1]
|
|
|
+
|
|
|
+
|
|
|
+def Y2m1real(x):
|
|
|
+ return 0.5 * jnp.sqrt(15 / jnp.pi) * x[2] * x[1]
|
|
|
+
|
|
|
+
|
|
|
+def Y20real(x):
|
|
|
+ return 0.25 * jnp.sqrt(5 / jnp.pi) * (3 * x[2] ** 2 - 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y21real(x):
|
|
|
+ return 0.5 * jnp.sqrt(15 / jnp.pi) * x[2] * x[0]
|
|
|
+
|
|
|
+
|
|
|
+def Y22real(x):
|
|
|
+ return 0.25 * jnp.sqrt(15 / jnp.pi) * (x[0] ** 2 - x[1] ** 2)
|
|
|
+
|
|
|
+
|
|
|
+def Y3m3real(x):
|
|
|
+ return 0.25 * jnp.sqrt(17.5 / jnp.pi) * x[1] * (3 * x[0] ** 2 - x[1] ** 2)
|
|
|
+
|
|
|
+
|
|
|
+def Y3m2real(x):
|
|
|
+ return 0.5 * jnp.sqrt(105 / jnp.pi) * x[0] * x[1] * x[2]
|
|
|
+
|
|
|
+
|
|
|
+def Y3m1real(x):
|
|
|
+ return 0.25 * jnp.sqrt(10.5 / jnp.pi) * x[1] * (5 * x[2] ** 2 - 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y30real(x):
|
|
|
+ return 0.25 * jnp.sqrt(7 / jnp.pi) * (5 * x[2] ** 3 - 3 * x[2])
|
|
|
+
|
|
|
+
|
|
|
+def Y31real(x):
|
|
|
+ return 0.25 * jnp.sqrt(10.5 / jnp.pi) * x[0] * (5 * x[2] ** 2 - 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y32real(x):
|
|
|
+ return 0.25 * jnp.sqrt(105 / jnp.pi) * (x[0] ** 2 - x[1] ** 2) * x[2]
|
|
|
+
|
|
|
+
|
|
|
+def Y33real(x):
|
|
|
+ return 0.25 * jnp.sqrt(17.5 / jnp.pi) * x[0] * (x[0] ** 2 - 3 * x[1] ** 2)
|
|
|
+
|
|
|
+
|
|
|
+def Y4m4real(x):
|
|
|
+ return 0.75 * jnp.sqrt(35 / jnp.pi) * x[0] * x[1] * (x[0] ** 2 - x[1] ** 2)
|
|
|
+
|
|
|
+
|
|
|
+def Y4m3real(x):
|
|
|
+ return 0.75 * jnp.sqrt(17.5 / jnp.pi) * x[1] * (3 * x[0] ** 2 - x[1] ** 2) * x[2]
|
|
|
+
|
|
|
+
|
|
|
+def Y4m2real(x):
|
|
|
+ return 0.75 * jnp.sqrt(5 / jnp.pi) * x[0] * x[1] * (7 * x[2] ** 2 - 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y4m1real(x):
|
|
|
+ return 0.75 * jnp.sqrt(2.5 / jnp.pi) * x[1] * (7 * x[2] ** 3 - 3 * x[2])
|
|
|
+
|
|
|
+
|
|
|
+def Y40real(x):
|
|
|
+ return 3 / 16 * jnp.sqrt(1 / jnp.pi) * (35 * x[2] ** 4 - 30 * x[2] ** 2 + 3)
|
|
|
+
|
|
|
+
|
|
|
+def Y41real(x):
|
|
|
+ return 0.75 * jnp.sqrt(2.5 / jnp.pi) * x[0] * (7 * x[2] ** 3 - 3 * x[2])
|
|
|
+
|
|
|
+
|
|
|
+def Y42real(x):
|
|
|
+ return 0.375 * jnp.sqrt(5 / jnp.pi) * (x[0] ** 2 - x[1] ** 2) * (7 * x[2] ** 2 - 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y43real(x):
|
|
|
+ return 0.75 * jnp.sqrt(17.5 / jnp.pi) * x[0] * (x[0] ** 2 - 3 * x[1] ** 2) * x[2]
|
|
|
+
|
|
|
+
|
|
|
+def Y44real(x):
|
|
|
+ return 0.1875 * jnp.sqrt(35 / jnp.pi) * (x[0] ** 2 * (x[0] ** 2 - 3 * x[1] ** 2) - x[1] ** 2 * (3 * x[0] ** 2 - x[1] ** 2))
|
|
|
+
|
|
|
+
|
|
|
+def Y5m5real(x):
|
|
|
+ return -3 / 16 * jnp.sqrt(38.5 / jnp.pi) * (5 * x[0] ** 4 * x[1] - 10 * x[0] ** 2 * x[1] ** 3 + x[1] ** 5)
|
|
|
+
|
|
|
+
|
|
|
+def Y5m4real(x):
|
|
|
+ return 3 / 4 * jnp.sqrt(385 / jnp.pi) * x[0] * x[1] * (x[1] ** 2 - x[0] ** 2) * x[2]
|
|
|
+
|
|
|
+
|
|
|
+def Y5m3real(x):
|
|
|
+ return -1 / 16 * jnp.sqrt(192.5 / jnp.pi) * x[1] * (3 * x[0] ** 2 - x[1] ** 2) * (9 * x[2] ** 2 - 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y5m2real(x):
|
|
|
+ return 1 / 4 * jnp.sqrt(1155 / jnp.pi) * x[0] * x[1] * (3 * x[2] ** 3 - x[2])
|
|
|
+
|
|
|
+
|
|
|
+def Y5m1real(x):
|
|
|
+ return -1 / 16 * jnp.sqrt(165 / jnp.pi) * x[1] * (21 * x[2] ** 4 - 14 * x[2] ** 2 + 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y50real(x):
|
|
|
+ return 1 / 16 * jnp.sqrt(11 / jnp.pi) * (63 * x[2] ** 5 - 70 * x[2] ** 3 + 15 * x[2])
|
|
|
+
|
|
|
+
|
|
|
+def Y51real(x):
|
|
|
+ return -1 / 16 * jnp.sqrt(165 / jnp.pi) * x[0] * (21 * x[2] ** 4 - 14 * x[2] ** 2 + 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y52real(x):
|
|
|
+ return 1 / 8 * jnp.sqrt(1155 / jnp.pi) * (x[0] ** 2 - x[1] ** 2) * (3 * x[2] ** 3 - x[2])
|
|
|
+
|
|
|
+
|
|
|
+def Y53real(x):
|
|
|
+ return -1 / 16 * jnp.sqrt(192.5 / jnp.pi) * x[0] * (x[0] ** 2 - 3 * x[1] ** 2) * (9 * x[2] ** 2 - 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y54real(x):
|
|
|
+ return 3 / 16 * jnp.sqrt(385 / jnp.pi) * (x[0] ** 4 - 6 * x[0] ** 2 * x[1] ** 2 + x[1] ** 4) * x[2]
|
|
|
+
|
|
|
+
|
|
|
+def Y55real(x):
|
|
|
+ return -3 / 16 * jnp.sqrt(38.5 / jnp.pi) * (x[0] ** 5 - 10 * x[0] ** 3 * x[1] ** 2 + 5 * x[0] * x[1] ** 4)
|
|
|
+
|
|
|
+
|
|
|
+def Y6m6real(x):
|
|
|
+ return 1 / 16 * jnp.sqrt(1501.5 / jnp.pi) * x[0] * x[1] * (-3 * x[0] ** 4 + 10 * x[0] ** 2 * x[1] ** 2 - 3 * x[1] ** 4)
|
|
|
+
|
|
|
+
|
|
|
+def Y6m5real(x):
|
|
|
+ return -3 / 16 * jnp.sqrt(500.5 / jnp.pi) * (5 * x[0] ** 4 * x[1] - 10 * x[0] ** 2 * x[1] ** 3 + x[1] ** 5) * x[2]
|
|
|
+
|
|
|
+
|
|
|
+def Y6m4real(x):
|
|
|
+ return 3 / 8 * jnp.sqrt(91 / jnp.pi) * x[0] * x[1] * (x[1] ** 2 - x[0] ** 2) * (11 * x[2] ** 2 - 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y6m3real(x):
|
|
|
+ return -1 / 16 * jnp.sqrt(682.5 / jnp.pi) * x[1] * (3 * x[0] ** 2 - x[1] ** 2) * (11 * x[2] ** 3 - 3 * x[2])
|
|
|
+
|
|
|
+
|
|
|
+def Y6m2real(x):
|
|
|
+ return 1 / 16 * jnp.sqrt(682.5 / jnp.pi) * x[0] * x[1] * (33 * x[2] ** 4 - 18 * x[2] ** 2 + 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y6m1real(x):
|
|
|
+ return -1 / 16 * jnp.sqrt(273 / jnp.pi) * x[1] * (33 * x[2] ** 5 - 30 * x[2] ** 3 + 5 * x[2])
|
|
|
+
|
|
|
+
|
|
|
+def Y60real(x):
|
|
|
+ return 1 / 32 * jnp.sqrt(13 / jnp.pi) * (231 * x[2] ** 6 - 315 * x[2] ** 4 + 105 * x[2] ** 2 - 5)
|
|
|
+
|
|
|
+
|
|
|
+def Y61real(x):
|
|
|
+ return -1 / 16 * jnp.sqrt(273 / jnp.pi) * x[0] * (33 * x[2] ** 5 - 30 * x[2] ** 3 + 5 * x[2])
|
|
|
+
|
|
|
+
|
|
|
+def Y62real(x):
|
|
|
+ return 1 / 32 * jnp.sqrt(682.5 / jnp.pi) * (x[0] ** 2 - x[1] ** 2) * (33 * x[2] ** 4 - 18 * x[2] ** 2 + 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y63real(x):
|
|
|
+ return -1 / 16 * jnp.sqrt(682.5 / jnp.pi) * x[0] * (x[0] ** 2 - 3 * x[1] ** 2) * (11 * x[2] ** 3 - 3 * x[2])
|
|
|
+
|
|
|
+
|
|
|
+def Y64real(x):
|
|
|
+ return 3 / 32 * jnp.sqrt(91 / jnp.pi) * (x[0] ** 4 - 6 * x[0] ** 2 * x[1] ** 2 + x[1] ** 4) * (11 * x[2] ** 2 - 1)
|
|
|
+
|
|
|
+
|
|
|
+def Y65real(x):
|
|
|
+ return -3 / 16 * jnp.sqrt(500.5 / jnp.pi) * (x[0] ** 5 - 10 * x[0] ** 3 * x[1] ** 2 + 5 * x[0] * x[1] ** 4) * x[2]
|
|
|
+
|
|
|
+
|
|
|
+def Y66real(x):
|
|
|
+ return 1 / 32 * jnp.sqrt(1501.5 / jnp.pi) * (x[0] ** 6 - 15 * x[0] ** 4 * x[1] ** 2 + 15 * x[0] ** 2 * x[1] ** 4 - x[1] ** 6)
|
|
|
+
|
|
|
+
|
|
|
+def get_real_sph_function(l, m):
|
|
|
+ if abs(m) <= l < len(sph_harm_list):
|
|
|
+ return jax.jit(real_sph_harm_list[l * (l + 1) + m])
|
|
|
+ return 0
|
|
|
+
|
|
|
+
|
|
|
+@partial(jax.jit, static_argnums=(1, 2))
|
|
|
+def real_sph_harm(x: Array, l: Array, m: Array) -> Array:
|
|
|
+ return jnp.stack([get_real_sph_function(sl, sm)(x) for sl, sm in zip(l, m)])
|
|
|
+
|
|
|
+
|
|
|
+def real_sph_harm_fn_custom_fwd(l_max: int) -> Callable[[Array,], Array]:
|
|
|
+
|
|
|
+ l_list = list(range(0, l_max + 1))
|
|
|
+ lm_list = []
|
|
|
+ for l in l_list:
|
|
|
+ for m in range(-l, l + 1):
|
|
|
+ lm_list.append((l, m))
|
|
|
+
|
|
|
+ l_list, m_list = zip(*lm_list)
|
|
|
+
|
|
|
+ l_array = jnp.array(l_list)
|
|
|
+ m_array = jnp.array(m_list)
|
|
|
+
|
|
|
+ # indices where derivative rules differ from the general case
|
|
|
+ m_one_indices = jnp.array([l * (l + 1) + 1 for l in range(0, l_max + 1) if l > 0])
|
|
|
+ m_zero_indices = jnp.array([l * (l + 1) for l in range(0, l_max + 1)])
|
|
|
+ m_minus_one_indices = jnp.array([l * (l + 1) - 1 for l in range(0, l_max + 1) if l > 0])
|
|
|
+
|
|
|
+ m_plus_one_factors = jnp.sign(m_array)
|
|
|
+ m_minus_one_factors = -jnp.sign(m_array)
|
|
|
+ minus_m_plus_one_factors = jnp.sign(m_array)
|
|
|
+ minus_m_minus_one_factors = jnp.sign(m_array)
|
|
|
+
|
|
|
+ # m = -1, 0, 1 special cases:
|
|
|
+ m_plus_one_factors = m_plus_one_factors.at[m_zero_indices].set(jnp.sqrt(2))
|
|
|
+ m_minus_one_factors = m_minus_one_factors.at[m_zero_indices].set(jnp.sqrt(2))
|
|
|
+
|
|
|
+ m_minus_one_factors = m_minus_one_factors.at[m_one_indices].set(-jnp.sqrt(2))
|
|
|
+ minus_m_minus_one_factors = minus_m_minus_one_factors.at[m_one_indices].set(0.)
|
|
|
+
|
|
|
+ m_plus_one_factors = m_plus_one_factors.at[m_minus_one_indices].set(0)
|
|
|
+ minus_m_plus_one_factors = minus_m_plus_one_factors.at[m_minus_one_indices].set(-jnp.sqrt(2))
|
|
|
+
|
|
|
+ @jax.custom_jvp
|
|
|
+ def f(x: Array):
|
|
|
+ return jnp.stack([get_real_sph_function(sl, sm)(x) for sl, sm in zip(l_list, m_list)])
|
|
|
+
|
|
|
+ @f.defjvp
|
|
|
+ def sph_harm_jvp(primals, tangents):
|
|
|
+ x, = primals
|
|
|
+ dx, = tangents
|
|
|
+
|
|
|
+ primal_out = f(x)
|
|
|
+
|
|
|
+ extanded_primal = jnp.zeros((l_array.shape[0] + 2,), dtype=primal_out.dtype)
|
|
|
+ mirrored_extanded_primal = jnp.zeros((l_array.shape[0] + 2,), dtype=primal_out.dtype)
|
|
|
+ extanded_primal = extanded_primal.at[1:-1].set(primal_out)
|
|
|
+ mirrored_extanded_primal = mirrored_extanded_primal.at[1:-1].set(primal_out[::-1])
|
|
|
+
|
|
|
+ rho2 = x[0] ** 2 + x[1] ** 2 + 1e-8
|
|
|
+ rho = jnp.sqrt(rho2)
|
|
|
+
|
|
|
+ coeffs1 = 1 / rho * jnp.sqrt((l_array - m_array) * (l_array + m_array + 1))
|
|
|
+ coeffs2 = 1 / rho * jnp.sqrt((l_array - m_array + 1) * (l_array + m_array))
|
|
|
+
|
|
|
+ theta_derivatives = 0.5 * (coeffs1 * (x[0] * m_plus_one_factors * extanded_primal[2:] +
|
|
|
+ x[1] * minus_m_plus_one_factors * mirrored_extanded_primal[2:]) +
|
|
|
+ coeffs2 * (x[0] * m_minus_one_factors * extanded_primal[:-2] +
|
|
|
+ x[1] * minus_m_minus_one_factors * mirrored_extanded_primal[:-2]))
|
|
|
+ phi_derivatives = m_array * primal_out
|
|
|
+
|
|
|
+ x_derivatives = theta_derivatives * x[0] * x[2] / rho - phi_derivatives * x[1] / rho2
|
|
|
+ y_derivatives = theta_derivatives * x[1] * x[2] / rho + phi_derivatives * x[0] / rho2
|
|
|
+ z_derivatives = -theta_derivatives * rho
|
|
|
+
|
|
|
+ jacobian = jnp.array([x_derivatives, y_derivatives, z_derivatives])
|
|
|
+ tangent_out = jacobian.T @ dx
|
|
|
+ return primal_out, tangent_out
|
|
|
+
|
|
|
+ return f
|
|
|
+
|
|
|
+
|
|
|
+def real_sph_harm_fn_custom_rev(l_max: int) -> Callable[[Array, ], Array]:
|
|
|
+ l_list = list(range(0, l_max + 1))
|
|
|
+ lm_list = []
|
|
|
+ for l in l_list:
|
|
|
+ for m in range(-l, l + 1):
|
|
|
+ lm_list.append((l, m))
|
|
|
+
|
|
|
+ l_list, m_list = zip(*lm_list)
|
|
|
+
|
|
|
+ l_array = jnp.array(l_list)
|
|
|
+ m_array = jnp.array(m_list)
|
|
|
+
|
|
|
+ # indices where derivative rules differ from the general case
|
|
|
+ m_one_indices = jnp.array([l * (l + 1) + 1 for l in range(0, l_max + 1) if l > 0])
|
|
|
+ m_zero_indices = jnp.array([l * (l + 1) for l in range(0, l_max + 1)])
|
|
|
+ m_minus_one_indices = jnp.array([l * (l + 1) - 1 for l in range(0, l_max + 1) if l > 0])
|
|
|
+
|
|
|
+ m_plus_one_factors = jnp.sign(m_array)
|
|
|
+ m_minus_one_factors = -jnp.sign(m_array)
|
|
|
+ minus_m_plus_one_factors = jnp.sign(m_array)
|
|
|
+ minus_m_minus_one_factors = jnp.sign(m_array)
|
|
|
+
|
|
|
+ # m = -1, 0, 1 special cases:
|
|
|
+ m_plus_one_factors = m_plus_one_factors.at[m_zero_indices].set(jnp.sqrt(2))
|
|
|
+ m_minus_one_factors = m_minus_one_factors.at[m_zero_indices].set(jnp.sqrt(2))
|
|
|
+
|
|
|
+ m_minus_one_factors = m_minus_one_factors.at[m_one_indices].set(-jnp.sqrt(2))
|
|
|
+ minus_m_minus_one_factors = minus_m_minus_one_factors.at[m_one_indices].set(0.)
|
|
|
+
|
|
|
+ m_plus_one_factors = m_plus_one_factors.at[m_minus_one_indices].set(0)
|
|
|
+ minus_m_plus_one_factors = minus_m_plus_one_factors.at[m_minus_one_indices].set(-jnp.sqrt(2))
|
|
|
+
|
|
|
+ @jax.custom_vjp
|
|
|
+ def f(x: Array):
|
|
|
+ return jnp.stack([get_real_sph_function(sl, sm)(x) for sl, sm in zip(l_list, m_list)])
|
|
|
+
|
|
|
+ def sph_harm_fwd(x):
|
|
|
+ primal_out = f(x)
|
|
|
+
|
|
|
+ extanded_primal = jnp.zeros((l_array.shape[0] + 2,), dtype=primal_out.dtype)
|
|
|
+ mirrored_extanded_primal = jnp.zeros((l_array.shape[0] + 2,), dtype=primal_out.dtype)
|
|
|
+ extanded_primal = extanded_primal.at[1:-1].set(primal_out)
|
|
|
+ mirrored_extanded_primal = mirrored_extanded_primal.at[1:-1].set(primal_out[::-1])
|
|
|
+
|
|
|
+ rho2 = x[0] ** 2 + x[1] ** 2 + 1e-8
|
|
|
+ rho = jnp.sqrt(rho2)
|
|
|
+
|
|
|
+ coeffs1 = 1 / rho * jnp.sqrt((l_array - m_array) * (l_array + m_array + 1))
|
|
|
+ coeffs2 = 1 / rho * jnp.sqrt((l_array - m_array + 1) * (l_array + m_array))
|
|
|
+
|
|
|
+ theta_derivatives = 0.5 * (coeffs1 * (x[0] * m_plus_one_factors * extanded_primal[2:] +
|
|
|
+ x[1] * minus_m_plus_one_factors * mirrored_extanded_primal[2:]) +
|
|
|
+ coeffs2 * (x[0] * m_minus_one_factors * extanded_primal[:-2] +
|
|
|
+ x[1] * minus_m_minus_one_factors * mirrored_extanded_primal[:-2]))
|
|
|
+ phi_derivatives = m_array * primal_out
|
|
|
+
|
|
|
+ x_derivatives = theta_derivatives * x[0] * x[2] / rho - phi_derivatives * x[1] / rho2
|
|
|
+ y_derivatives = theta_derivatives * x[1] * x[2] / rho + phi_derivatives * x[0] / rho2
|
|
|
+ z_derivatives = -theta_derivatives * rho
|
|
|
+
|
|
|
+ jacobian = jnp.array([x_derivatives, y_derivatives, z_derivatives]).T
|
|
|
+
|
|
|
+ return primal_out, jacobian
|
|
|
+
|
|
|
+ def sph_harm_rev(jacobian, y_bar):
|
|
|
+ return (y_bar @ jacobian,)
|
|
|
+
|
|
|
+ f.defvjp(sph_harm_fwd, sph_harm_rev)
|
|
|
+
|
|
|
+ return f
|
|
|
+
|
|
|
+
|
|
|
+real_sph_harm_list = [Y00real,
|
|
|
+ Y1m1real, Y10real, Y11real,
|
|
|
+ Y2m2real, Y2m1real, Y20real, Y21real, Y22real,
|
|
|
+ Y3m3real, Y3m2real, Y3m1real, Y30real, Y31real, Y32real, Y33real,
|
|
|
+ Y4m4real, Y4m3real, Y4m2real, Y4m1real, Y40real, Y41real, Y42real, Y43real, Y44real,
|
|
|
+ Y5m5real, Y5m4real, Y5m3real, Y5m2real, Y5m1real, Y50real, Y51real, Y52real, Y53real, Y54real, Y55real,
|
|
|
+ Y6m6real, Y6m5real, Y6m4real, Y6m3real, Y6m2real, Y6m1real, Y60real, Y61real, Y62real, Y63real, Y64real, Y65real, Y66real]
|
|
|
+
|
|
|
+
|
|
|
+# spherical harmonic prefactors up to l=7
|
|
|
+ylm_prefactors = jnp.array([0.5 * jnp.sqrt(1 / jnp.pi),
|
|
|
+
|
|
|
+ 0.5 * jnp.sqrt(1.5 / jnp.pi),
|
|
|
+ 0.5 * jnp.sqrt(3 / jnp.pi),
|
|
|
+ -0.5 * jnp.sqrt(1.5 / jnp.pi),
|
|
|
+
|
|
|
+ 0.25 * jnp.sqrt(7.5 / jnp.pi),
|
|
|
+ 0.5 * jnp.sqrt(7.5 / jnp.pi),
|
|
|
+ 0.25 * jnp.sqrt(5 / jnp.pi),
|
|
|
+ -0.5 * jnp.sqrt(7.5 / jnp.pi),
|
|
|
+ 0.25 * jnp.sqrt(7.5 / jnp.pi),
|
|
|
+
|
|
|
+ 0.125 * jnp.sqrt(35 / jnp.pi),
|
|
|
+ 0.25 * jnp.sqrt(52.5 / jnp.pi),
|
|
|
+ 0.125 * jnp.sqrt(21 / jnp.pi),
|
|
|
+ 0.25 * jnp.sqrt(7 / jnp.pi),
|
|
|
+ -0.125 * jnp.sqrt(21 / jnp.pi),
|
|
|
+ 0.25 * jnp.sqrt(52.5 / jnp.pi),
|
|
|
+ -0.125 * jnp.sqrt(35 / jnp.pi),
|
|
|
+
|
|
|
+ 3 / 16 * jnp.sqrt(17.5 / jnp.pi),
|
|
|
+ 3 / 8 * jnp.sqrt(35 / jnp.pi),
|
|
|
+ 3 / 8 * jnp.sqrt(2.5 / jnp.pi),
|
|
|
+ 3 / 8 * jnp.sqrt(5 / jnp.pi),
|
|
|
+ 3 / 16 * jnp.sqrt(1 / jnp.pi),
|
|
|
+ -3 / 8 * jnp.sqrt(5 / jnp.pi),
|
|
|
+ 3 / 8 * jnp.sqrt(2.5 / jnp.pi),
|
|
|
+ -3 / 8 * jnp.sqrt(35 / jnp.pi),
|
|
|
+ 3 / 16 * jnp.sqrt(17.5 / jnp.pi),
|
|
|
+
|
|
|
+ 3 / 32 * jnp.sqrt(77 / jnp.pi),
|
|
|
+ 3 / 16 * jnp.sqrt(192.5 / jnp.pi),
|
|
|
+ 1 / 32 * jnp.sqrt(385 / jnp.pi),
|
|
|
+ 1 / 8 * jnp.sqrt(577.5 / jnp.pi),
|
|
|
+ 1 / 16 * jnp.sqrt(82.5 / jnp.pi),
|
|
|
+ 1 / 16 * jnp.sqrt(11 / jnp.pi),
|
|
|
+ -1 / 16 * jnp.sqrt(82.5 / jnp.pi),
|
|
|
+ 1 / 8 * jnp.sqrt(577.5 / jnp.pi),
|
|
|
+ -1 / 32 * jnp.sqrt(385 / jnp.pi),
|
|
|
+ 3 / 16 * jnp.sqrt(192.5 / jnp.pi),
|
|
|
+ -3 / 32 * jnp.sqrt(77 / jnp.pi),
|
|
|
+
|
|
|
+ 1 / 64 * jnp.sqrt(3003 / jnp.pi),
|
|
|
+ 3 / 32 * jnp.sqrt(1001 / jnp.pi),
|
|
|
+ 3 / 32 * jnp.sqrt(45.5 / jnp.pi),
|
|
|
+ 1 / 32 * jnp.sqrt(1365 / jnp.pi),
|
|
|
+ 1 / 64 * jnp.sqrt(1365 / jnp.pi),
|
|
|
+ 1 / 16 * jnp.sqrt(136.5 / jnp.pi),
|
|
|
+ 1 / 32 * jnp.sqrt(13 / jnp.pi),
|
|
|
+ -1 / 16 * jnp.sqrt(136.5 / jnp.pi),
|
|
|
+ 1 / 64 * jnp.sqrt(1365 / jnp.pi),
|
|
|
+ -1 / 32 * jnp.sqrt(1365 / jnp.pi),
|
|
|
+ 3 / 32 * jnp.sqrt(45.5 / jnp.pi),
|
|
|
+ -3 / 32 * jnp.sqrt(1001 / jnp.pi),
|
|
|
+ 1 / 64 * jnp.sqrt(3003 / jnp.pi),
|
|
|
+
|
|
|
+ 3 / 64 * jnp.sqrt(357.5 / jnp.pi),
|
|
|
+ 3 / 64 * jnp.sqrt(5005 / jnp.pi),
|
|
|
+ 3 / 64 * jnp.sqrt(192.5 / jnp.pi),
|
|
|
+ 3 / 32 * jnp.sqrt(192.5 / jnp.pi),
|
|
|
+ 3 / 64 * jnp.sqrt(17.5 / jnp.pi),
|
|
|
+ 3 / 64 * jnp.sqrt(35 / jnp.pi),
|
|
|
+ 1 / 64 * jnp.sqrt(52.5 / jnp.pi),
|
|
|
+ 1 / 32 * jnp.sqrt(15 / jnp.pi),
|
|
|
+ -1 / 64 * jnp.sqrt(52.5 / jnp.pi),
|
|
|
+ 3 / 64 * jnp.sqrt(35 / jnp.pi),
|
|
|
+ -3 / 64 * jnp.sqrt(17.5 / jnp.pi),
|
|
|
+ 3 / 32 * jnp.sqrt(192.5 / jnp.pi),
|
|
|
+ -3 / 64 * jnp.sqrt(192.5 / jnp.pi),
|
|
|
+ 3 / 64 * jnp.sqrt(5005 / jnp.pi),
|
|
|
+ 3 / 64 * jnp.sqrt(357.5 / jnp.pi),
|
|
|
+ ])
|
|
|
+
|
|
|
+# coefficient array up to l=7
|
|
|
+z_coef_array = jnp.array([[1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [-1, 0, 3, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [-1, 0, 5, 0, 0, 0, 0, 0],
|
|
|
+ [0, -3, 0, 5, 0, 0, 0, 0],
|
|
|
+ [-1, 0, 5, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [-1, 0, 7, 0, 0, 0, 0, 0],
|
|
|
+ [0, -3, 0, 7, 0, 0, 0, 0],
|
|
|
+ [3, 0, -30, 0, 35, 0, 0, 0],
|
|
|
+ [0, -3, 0, 7, 0, 0, 0, 0],
|
|
|
+ [-1, 0, 7, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [-1, 0, 9, 0, 0, 0, 0, 0],
|
|
|
+ [0, -1, 0, 3, 0, 0, 0, 0],
|
|
|
+ [1, 0, -14, 0, 21, 0, 0, 0],
|
|
|
+ [0, 15, 0, -70, 0, 63, 0, 0],
|
|
|
+ [1, 0, -14, 0, 21, 0, 0, 0],
|
|
|
+ [0, -1, 0, 3, 0, 0, 0, 0],
|
|
|
+ [-1, 0, 9, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [-1, 0, 11, 0, 0, 0, 0, 0],
|
|
|
+ [0, -3, 0, 11, 0, 0, 0, 0],
|
|
|
+ [1, 0, -18, 0, 33, 0, 0, 0],
|
|
|
+ [0, 5, 0, -30, 0, 33, 0, 0],
|
|
|
+ [-5, 0, 105, 0, -315, 0, 231, 0],
|
|
|
+ [0, 5, 0, -30, 0, 33, 0, 0],
|
|
|
+ [1, 0, -18, 0, 33, 0, 0, 0],
|
|
|
+ [0, -3, 0, 11, 0, 0, 0, 0],
|
|
|
+ [-1, 0, 11, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [-1, 0, 13, 0, 0, 0, 0, 0],
|
|
|
+ [0, -3, 0, 13, 0, 0, 0, 0],
|
|
|
+ [3, 0, -66, 0, 143, 0, 0, 0],
|
|
|
+ [0, 15, 0, -110, 0, 143, 0, 0],
|
|
|
+ [-5, 0, 135, 0, -495, 0, 429, 0],
|
|
|
+ [0, -35, 0, 315, 0, -693, 0, 429],
|
|
|
+ [-5, 0, 135, 0, -495, 0, 429, 0],
|
|
|
+ [0, 15, 0, -110, 0, 143, 0, 0],
|
|
|
+ [3, 0, -66, 0, 143, 0, 0, 0],
|
|
|
+ [0, -3, 0, 13, 0, 0, 0, 0],
|
|
|
+ [-1, 0, 13, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ ])
|
|
|
+
|
|
|
+
|
|
|
+x_p_iy_coefs = jnp.array([[1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 0, 1, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 0, 1, 0, 0, 0, 0, 0],
|
|
|
+ [0, 0, 0, 1, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 0, 1, 0, 0, 0, 0, 0],
|
|
|
+ [0, 0, 0, 1, 0, 0, 0, 0],
|
|
|
+ [0, 0, 0, 0, 1, 0, 0, 0],
|
|
|
+
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 0, 1, 0, 0, 0, 0, 0],
|
|
|
+ [0, 0, 0, 1, 0, 0, 0, 0],
|
|
|
+ [0, 0, 0, 0, 1, 0, 0, 0],
|
|
|
+ [0, 0, 0, 0, 0, 1, 0, 0],
|
|
|
+
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 0, 1, 0, 0, 0, 0, 0],
|
|
|
+ [0, 0, 0, 1, 0, 0, 0, 0],
|
|
|
+ [0, 0, 0, 0, 1, 0, 0, 0],
|
|
|
+ [0, 0, 0, 0, 0, 1, 0, 0],
|
|
|
+ [0, 0, 0, 0, 0, 0, 1, 0],
|
|
|
+
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [0, 0, 1, 0, 0, 0, 0, 0],
|
|
|
+ [0, 0, 0, 1, 0, 0, 0, 0],
|
|
|
+ [0, 0, 0, 0, 1, 0, 0, 0],
|
|
|
+ [0, 0, 0, 0, 0, 1, 0, 0],
|
|
|
+ [0, 0, 0, 0, 0, 0, 1, 0],
|
|
|
+ [0, 0, 0, 0, 0, 0, 0, 1],
|
|
|
+ ])
|
|
|
+
|
|
|
+
|
|
|
+x_m_iy_coefs = jnp.array([[1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [0, 0, 1, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [0, 0, 0, 1, 0, 0, 0, 0],
|
|
|
+ [0, 0, 1, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [0, 0, 0, 0, 1, 0, 0, 0],
|
|
|
+ [0, 0, 0, 1, 0, 0, 0, 0],
|
|
|
+ [0, 0, 1, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [0, 0, 0, 0, 0, 1, 0, 0],
|
|
|
+ [0, 0, 0, 0, 1, 0, 0, 0],
|
|
|
+ [0, 0, 0, 1, 0, 0, 0, 0],
|
|
|
+ [0, 0, 1, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [0, 0, 0, 0, 0, 0, 1, 0],
|
|
|
+ [0, 0, 0, 0, 0, 1, 0, 0],
|
|
|
+ [0, 0, 0, 0, 1, 0, 0, 0],
|
|
|
+ [0, 0, 0, 1, 0, 0, 0, 0],
|
|
|
+ [0, 0, 1, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+
|
|
|
+ [0, 0, 0, 0, 0, 0, 0, 1],
|
|
|
+ [0, 0, 0, 0, 0, 0, 1, 0],
|
|
|
+ [0, 0, 0, 0, 0, 1, 0, 0],
|
|
|
+ [0, 0, 0, 0, 1, 0, 0, 0],
|
|
|
+ [0, 0, 0, 1, 0, 0, 0, 0],
|
|
|
+ [0, 0, 1, 0, 0, 0, 0, 0],
|
|
|
+ [0, 1, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ [1, 0, 0, 0, 0, 0, 0, 0],
|
|
|
+ ])
|
|
|
+
|
|
|
+
|
|
|
+def sph_harm_not_fast(l: Array, m: Array) -> Callable:
|
|
|
+ max_l = jnp.max(l)
|
|
|
+ idx = l * (l + 1) + m # l ** 2 + l + m
|
|
|
+
|
|
|
+ prefactors = ylm_prefactors[idx]
|
|
|
+ z_coefs = z_coef_array[idx, :max_l+1]
|
|
|
+ xy = x_p_iy_coefs[idx, :max_l+1]
|
|
|
+ xy_conj = x_m_iy_coefs[idx, :max_l+1]
|
|
|
+
|
|
|
+ powers = jnp.arange(max_l + 1, dtype=jnp.int32)
|
|
|
+
|
|
|
+ @partial(jnp.vectorize, signature='(d)->(k)')
|
|
|
+ def f(x):
|
|
|
+ z_powers = x[2] ** powers
|
|
|
+ xy_powers = (x[0] + 1j * x[1]) ** powers
|
|
|
+ xy_conj_powers = jnp.conj(xy_powers)
|
|
|
+ return prefactors * (z_coefs @ z_powers) * (xy @ xy_powers) * (xy_conj @ xy_conj_powers)
|
|
|
+
|
|
|
+ @partial(jnp.vectorize, signature='(d)->(k)')
|
|
|
+ def f_positive_m(x):
|
|
|
+ z_powers = x[2] ** powers
|
|
|
+ xy_powers = (x[0] + 1j * x[1]) ** m
|
|
|
+ return prefactors * (z_coefs @ z_powers) * xy_powers
|
|
|
+
|
|
|
+ if jnp.all(m >= 0):
|
|
|
+ return f_positive_m
|
|
|
+
|
|
|
+ return jax.jit(f)
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|