import jax.numpy as jnp
from typing import Callable
import jax
from functools import partial

Array = jnp.ndarray


def neg_m(sph_harm: Callable, m: int) -> Callable:
    """Build Y_l^{-m} from Y_l^{m} using Y_l^{-m} = (-1)^m * conj(Y_l^m).

    Args:
        sph_harm: function evaluating Y_l^m at a point ``x``.
        m: the (non-negative) order of ``sph_harm``.

    Returns:
        A function evaluating Y_l^{-m} at ``x``.
    """
    # The parentheses matter: ``-1 ** m`` parses as ``-(1 ** m)``, i.e. always -1.
    sign = (-1) ** m

    def wrapped(x):
        return sign * jnp.conj(sph_harm(x))

    return wrapped


# Complex spherical harmonics Y_l^m for m >= 0, including the Condon-Shortley
# phase (odd m carry a minus sign). They are written in Cartesian form using
# cos(theta) = z and sin(theta) * exp(i * phi) = x + i * y, so ``x`` is expected
# to be a unit vector (x, y, z). Components that are purely real are explicitly
# cast to complex128.

def Y00(x):
    return jax.lax.convert_element_type(0.5 * jnp.sqrt(1 / jnp.pi), new_dtype=jnp.complex128)


def Y10(x):
    return jax.lax.convert_element_type(0.5 * jnp.sqrt(3 / jnp.pi) * x[2], new_dtype=jnp.complex128)


def Y11(x):
    return -0.5 * jnp.sqrt(1.5 / jnp.pi) * (x[0] + 1j * x[1])


def Y20(x):
    return jax.lax.convert_element_type(0.25 * jnp.sqrt(5 / jnp.pi) * (3 * x[2] ** 2 - 1),
                                        new_dtype=jnp.complex128)


def Y21(x):
    return -0.5 * jnp.sqrt(7.5 / jnp.pi) * (x[0] + 1j * x[1]) * x[2]


def Y22(x):
    return 0.25 * jnp.sqrt(7.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 2


def Y30(x):
    return jax.lax.convert_element_type(0.25 * jnp.sqrt(7 / jnp.pi) * (5 * x[2] ** 3 - 3 * x[2]),
                                        new_dtype=jnp.complex128)


def Y31(x):
    return -0.125 * jnp.sqrt(21 / jnp.pi) * (x[0] + 1j * x[1]) * (5 * x[2] ** 2 - 1)


def Y32(x):
    return 0.25 * jnp.sqrt(52.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 2 * x[2]


def Y33(x):
    return -0.125 * jnp.sqrt(35 / jnp.pi) * (x[0] + 1j * x[1]) ** 3


def Y40(x):
    return jax.lax.convert_element_type(
        3 / 16 * jnp.sqrt(1 / jnp.pi) * (35 * x[2] ** 4 - 30 * x[2] ** 2 + 3),
        new_dtype=jnp.complex128)


def Y41(x):
    return -3 / 8 * jnp.sqrt(5 / jnp.pi) * (x[0] + 1j * x[1]) * (7 * x[2] ** 3 - 3 * x[2])


def Y42(x):
    return 3 / 8 * jnp.sqrt(2.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 2 * (7 * x[2] ** 2 - 1)


def Y43(x):
    return -3 / 8 * jnp.sqrt(35 / jnp.pi) * (x[0] + 1j * x[1]) ** 3 * x[2]


def Y44(x):
    return 3 / 16 * jnp.sqrt(17.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 4


def Y50(x):
    return jax.lax.convert_element_type(
        1 / 16 * jnp.sqrt(11 / jnp.pi) * (63 * x[2] ** 5 - 70 * x[2] ** 3 + 15 * x[2]),
        new_dtype=jnp.complex128)


def Y51(x):
    return -1 / 16 * jnp.sqrt(82.5 / jnp.pi) * \
        (x[0] + 1j * x[1]) * (21 * x[2] ** 4 - 14 * x[2] ** 2 + 1)


def Y52(x):
    return 1 / 8 * jnp.sqrt(577.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 2 * (3 * x[2] ** 3 - x[2])


def Y53(x):
    return -1 / 32 * jnp.sqrt(385 / jnp.pi) * (x[0] + 1j * x[1]) ** 3 * (9 * x[2] ** 2 - 1)


def Y54(x):
    return 3 / 16 * jnp.sqrt(192.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 4 * x[2]


def Y55(x):
    return -3 / 32 * jnp.sqrt(77 / jnp.pi) * (x[0] + 1j * x[1]) ** 5


def Y60(x):
    return jax.lax.convert_element_type(
        1 / 32 * jnp.sqrt(13 / jnp.pi) * (231 * x[2] ** 6 - 315 * x[2] ** 4 + 105 * x[2] ** 2 - 5),
        new_dtype=jnp.complex128)


def Y61(x):
    return -1 / 16 * jnp.sqrt(136.5 / jnp.pi) * \
        (x[0] + 1j * x[1]) * (33 * x[2] ** 5 - 30 * x[2] ** 3 + 5 * x[2])


def Y62(x):
    return 1 / 64 * jnp.sqrt(1365 / jnp.pi) * \
        (x[0] + 1j * x[1]) ** 2 * (33 * x[2] ** 4 - 18 * x[2] ** 2 + 1)


def Y63(x):
    return -1 / 32 * jnp.sqrt(1365 / jnp.pi) * (x[0] + 1j * x[1]) ** 3 * (11 * x[2] ** 3 - 3 * x[2])


def Y64(x):
    return 3 / 32 * jnp.sqrt(45.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 4 * (11 * x[2] ** 2 - 1)


def Y65(x):
    return -3 / 32 * jnp.sqrt(1001 / jnp.pi) * (x[0] + 1j * x[1]) ** 5 * x[2]


def Y66(x):
    return 1 / 64 * jnp.sqrt(3003 / jnp.pi) * (x[0] + 1j * x[1]) ** 6


def Y70(x):
    return jax.lax.convert_element_type(
        1 / 32 * jnp.sqrt(15 / jnp.pi)
        * (429 * x[2] ** 7 - 693 * x[2] ** 5 + 315 * x[2] ** 3 - 35 * x[2]),
        new_dtype=jnp.complex128)


def Y71(x):
    return -1 / 64 * jnp.sqrt(52.5 / jnp.pi) * \
        (x[0] + 1j * x[1]) * (429 * x[2] ** 6 - 495 * x[2] ** 4 + 135 * x[2] ** 2 - 5)


def Y72(x):
    return 3 / 64 * jnp.sqrt(35 / jnp.pi) * \
        (x[0] + 1j * x[1]) ** 2 * (143 * x[2] ** 5 - 110 * x[2] ** 3 + 15 * x[2])


def Y73(x):
    return -3 / 64 * jnp.sqrt(17.5 / jnp.pi) * \
        (x[0] + 1j * x[1]) ** 3 * (143 * x[2] ** 4 - 66 * x[2] ** 2 + 3)


def Y74(x):
    return 3 / 32 * jnp.sqrt(192.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 4 * (13 * x[2] ** 3 - 3 * x[2])


def Y75(x):
    return -3 / 64 * jnp.sqrt(192.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 5 * (13 * x[2] ** 2 - 1)


def Y76(x):
    return 3 / 64 * jnp.sqrt(5005 / jnp.pi) * (x[0] + 1j * x[1]) ** 6 * x[2]


def Y77(x):
    return -3 / 64 * jnp.sqrt(357.5 / jnp.pi) * (x[0] + 1j * x[1]) ** 7


def get_sph_function(l, m):
    """Return a jitted function evaluating the complex Y_l^m, or 0 if unsupported.

    Supported pairs are those tabulated in ``sph_harm_list``: 0 <= l <= 7 and
    |m| <= l. The entry for (l, m) lives at index l * (l + 1) + m.

    Returns:
        ``jax.jit`` of the tabulated function, or the integer 0 for an
        unsupported (l, m). Callers that immediately call the result will then
        fail with a TypeError.
    """
    # Degree l occupies indices l**2 .. (l + 1)**2 - 1, so the whole degree must fit.
    if abs(m) <= l and (l + 1) ** 2 <= len(sph_harm_list):
        return jax.jit(sph_harm_list[l * (l + 1) + m])
    return 0


def sph_harm_fn(l: tuple, m: tuple) -> Callable[[Array,], Array]:
    """Return f(x) stacking the complex Y_{l_i}^{m_i}(x) for each pair in zip(l, m).

    Derivatives of f come from plain JAX autodiff of the Cartesian polynomials.
    """
    def f(x: Array):
        return jnp.stack([get_sph_function(sl, sm)(x) for sl, sm in zip(l, m)])
    return f


def sph_harm_fn_custom(l: tuple, m: tuple) -> Callable[[Array,], Array]:
    """Return f(x) stacking the complex Y_{l_i}^{m_i}(x), with an analytic custom VJP.

    The VJP is the derivative of Y(x / |x|) at a unit vector ``x``. It is built
    from the angular chain rule and the ladder relation

        dY_l^m/dtheta = 1/2 * [ sqrt((l - m)(l + m + 1)) e^{-i phi} Y_l^{m+1}
                              - sqrt((l + m)(l - m + 1)) e^{+i phi} Y_l^{m-1} ].

    Y_l^{m+1} and Y_l^{m-1} are read from the neighbouring output entries (i + 1
    and i - 1). Every degree in ``l`` must therefore appear as the complete,
    contiguous run m = -l, ..., l, as in the ``sph_harm_list`` ordering;
    otherwise the derivative is wrong.

    A 1e-8 regulariser on x^2 + y^2 keeps the rule finite at the poles.
    """
    l_array = jnp.array(l)
    m_array = jnp.array(m)
    # Ladder coefficients of the raising (m -> m + 1) and lowering (m -> m - 1) terms.
    raise_coef = jnp.sqrt((l_array - m_array) * (l_array + m_array + 1))
    lower_coef = jnp.sqrt((l_array - m_array + 1) * (l_array + m_array))

    @jax.custom_vjp
    def f(x: Array):
        return jnp.stack([get_sph_function(sl, sm)(x) for sl, sm in zip(l, m)])

    def sph_harm_fwd(x):
        primal_out = f(x)
        zero = jnp.zeros((1,), dtype=primal_out.dtype)
        # Neighbours Y_l^{m+1} / Y_l^{m-1}, zero-padded at the ends. Their
        # coefficients vanish at m = l / m = -l respectively.
        upper = jnp.concatenate([primal_out[1:], zero])
        lower = jnp.concatenate([zero, primal_out[:-1]])

        rho2 = x[0] ** 2 + x[1] ** 2 + 1e-8
        rho = jnp.sqrt(rho2)  # sin(theta) for a unit vector
        e_minus = (x[0] - 1j * x[1]) / rho  # e^{-i phi}
        e_plus = (x[0] + 1j * x[1]) / rho  # e^{+i phi}

        d_theta = 0.5 * (raise_coef * e_minus * upper - lower_coef * e_plus * lower)
        d_phi = 1j * m_array * primal_out

        # Chain rule with d(theta)/dx = (xz/rho, yz/rho, -rho), d(phi)/dx = (-y, x, 0)/rho^2.
        d_x = d_theta * x[0] * x[2] / rho - d_phi * x[1] / rho2
        d_y = d_theta * x[1] * x[2] / rho + d_phi * x[0] / rho2
        d_z = -d_theta * rho
        jacobian = jnp.stack([d_x, d_y, d_z], axis=-1)  # (len(l), 3)
        return primal_out, (jacobian, x)

    def sph_harm_rev(residuals, y_bar):
        jacobian, x = residuals
        x_bar = y_bar @ jacobian
        # For a real input mapped to a complex output, JAX's cotangent is the
        # real part of y_bar @ J; it must also carry the input's dtype.
        if not jnp.iscomplexobj(x):
            x_bar = jnp.real(x_bar)
        return (x_bar.astype(x.dtype),)

    f.defvjp(sph_harm_fwd, sph_harm_rev)
    return f


# Complex Y_l^m for l <= 7, at index l * (l + 1) + m (m = -l, ..., l per degree).
sph_harm_list = [
    Y00,
    neg_m(Y11, 1), Y10, Y11,
    neg_m(Y22, 2), neg_m(Y21, 1), Y20, Y21, Y22,
    neg_m(Y33, 3), neg_m(Y32, 2), neg_m(Y31, 1), Y30, Y31, Y32, Y33,
    neg_m(Y44, 4), neg_m(Y43, 3), neg_m(Y42, 2), neg_m(Y41, 1), Y40, Y41, Y42, Y43, Y44,
    neg_m(Y55, 5), neg_m(Y54, 4), neg_m(Y53, 3), neg_m(Y52, 2), neg_m(Y51, 1),
    Y50, Y51, Y52, Y53, Y54, Y55,
    neg_m(Y66, 6), neg_m(Y65, 5), neg_m(Y64, 4), neg_m(Y63, 3), neg_m(Y62, 2), neg_m(Y61, 1),
    Y60, Y61, Y62, Y63, Y64, Y65, Y66,
    neg_m(Y77, 7), neg_m(Y76, 6), neg_m(Y75, 5), neg_m(Y74, 4), neg_m(Y73, 3), neg_m(Y72, 2),
    neg_m(Y71, 1),
    Y70, Y71, Y72, Y73, Y74, Y75, Y76, Y77,
]


# Real spherical harmonics Y_{l,m} in Cartesian form for a unit vector
# x = (x, y, z). The m < 0 entries are odd in y and the m >= 0 entries are even in y.
# NOTE(review): the sign conventions are not uniform. For l <= 4 the odd-m terms
# are positive (e.g. Y11real = +c * x), while for l = 5 and 6 they are negative
# (e.g. Y51real, Y55real). Y5m4real and Y6m4real use (y^2 - x^2) where Y4m4real
# uses (x^2 - y^2). Confirm which convention callers expect.

def Y00real(x):
    return 0.5 * jnp.sqrt(1 / jnp.pi)


def Y1m1real(x):
    return 0.5 * jnp.sqrt(3 / jnp.pi) * x[1]


def Y10real(x):
    return 0.5 * jnp.sqrt(3 / jnp.pi) * x[2]


def Y11real(x):
    return 0.5 * jnp.sqrt(3 / jnp.pi) * x[0]


def Y2m2real(x):
    return 0.5 * jnp.sqrt(15 / jnp.pi) * x[0] * x[1]


def Y2m1real(x):
    return 0.5 * jnp.sqrt(15 / jnp.pi) * x[2] * x[1]


def Y20real(x):
    return 0.25 * jnp.sqrt(5 / jnp.pi) * (3 * x[2] ** 2 - 1)


def Y21real(x):
    return 0.5 * jnp.sqrt(15 / jnp.pi) * x[2] * x[0]


def Y22real(x):
    return 0.25 * jnp.sqrt(15 / jnp.pi) * (x[0] ** 2 - x[1] ** 2)


def Y3m3real(x):
    return 0.25 * jnp.sqrt(17.5 / jnp.pi) * x[1] * (3 * x[0] ** 2 - x[1] ** 2)


def Y3m2real(x):
    return 0.5 * jnp.sqrt(105 / jnp.pi) * x[0] * x[1] * x[2]


def Y3m1real(x):
    return 0.25 * jnp.sqrt(10.5 / jnp.pi) * x[1] * (5 * x[2] ** 2 - 1)


def Y30real(x):
    return 0.25 * jnp.sqrt(7 / jnp.pi) * (5 * x[2] ** 3 - 3 * x[2])


def Y31real(x):
    return 0.25 * jnp.sqrt(10.5 / jnp.pi) * x[0] * (5 * x[2] ** 2 - 1)


def Y32real(x):
    return 0.25 * jnp.sqrt(105 / jnp.pi) * (x[0] ** 2 - x[1] ** 2) * x[2]


def Y33real(x):
    return 0.25 * jnp.sqrt(17.5 / jnp.pi) * x[0] * (x[0] ** 2 - 3 * x[1] ** 2)


def Y4m4real(x):
    return 0.75 * jnp.sqrt(35 / jnp.pi) * x[0] * x[1] * (x[0] ** 2 - x[1] ** 2)


def Y4m3real(x):
    return 0.75 * jnp.sqrt(17.5 / jnp.pi) * x[1] * (3 * x[0] ** 2 - x[1] ** 2) * x[2]


def Y4m2real(x):
    return 0.75 * jnp.sqrt(5 / jnp.pi) * x[0] * x[1] * (7 * x[2] ** 2 - 1)


def Y4m1real(x):
    return 0.75 * jnp.sqrt(2.5 / jnp.pi) * x[1] * (7 * x[2] ** 3 - 3 * x[2])


def Y40real(x):
    return 3 / 16 * jnp.sqrt(1 / jnp.pi) * (35 * x[2] ** 4 - 30 * x[2] ** 2 + 3)


def Y41real(x):
    return 0.75 * jnp.sqrt(2.5 / jnp.pi) * x[0] * (7 * x[2] ** 3 - 3 * x[2])


def Y42real(x):
    return 0.375 * jnp.sqrt(5 / jnp.pi) * (x[0] ** 2 - x[1] ** 2) * (7 * x[2] ** 2 - 1)


def Y43real(x):
    return 0.75 * jnp.sqrt(17.5 / jnp.pi) * x[0] * (x[0] ** 2 - 3 * x[1] ** 2) * x[2]


def Y44real(x):
    return 0.1875 * jnp.sqrt(35 / jnp.pi) * (x[0] ** 2 * (x[0] ** 2 - 3 * x[1] ** 2)
                                             - x[1] ** 2 * (3 * x[0] ** 2 - x[1] ** 2))


def Y5m5real(x):
    return -3 / 16 * jnp.sqrt(38.5 / jnp.pi) * \
        (5 * x[0] ** 4 * x[1] - 10 * x[0] ** 2 * x[1] ** 3 + x[1] ** 5)


def Y5m4real(x):
    return 3 / 4 * jnp.sqrt(385 / jnp.pi) * x[0] * x[1] * (x[1] ** 2 - x[0] ** 2) * x[2]


def Y5m3real(x):
    return -1 / 16 * jnp.sqrt(192.5 / jnp.pi) * \
        x[1] * (3 * x[0] ** 2 - x[1] ** 2) * (9 * x[2] ** 2 - 1)


def Y5m2real(x):
    return 1 / 4 * jnp.sqrt(1155 / jnp.pi) * x[0] * x[1] * (3 * x[2] ** 3 - x[2])


def Y5m1real(x):
    return -1 / 16 * jnp.sqrt(165 / jnp.pi) * x[1] * (21 * x[2] ** 4 - 14 * x[2] ** 2 + 1)


def Y50real(x):
    return 1 / 16 * jnp.sqrt(11 / jnp.pi) * (63 * x[2] ** 5 - 70 * x[2] ** 3 + 15 * x[2])


def Y51real(x):
    return -1 / 16 * jnp.sqrt(165 / jnp.pi) * x[0] * (21 * x[2] ** 4 - 14 * x[2] ** 2 + 1)


def Y52real(x):
    return 1 / 8 * jnp.sqrt(1155 / jnp.pi) * (x[0] ** 2 - x[1] ** 2) * (3 * x[2] ** 3 - x[2])


def Y53real(x):
    return -1 / 16 * jnp.sqrt(192.5 / jnp.pi) * \
        x[0] * (x[0] ** 2 - 3 * x[1] ** 2) * (9 * x[2] ** 2 - 1)


def Y54real(x):
    return 3 / 16 * jnp.sqrt(385 / jnp.pi) * \
        (x[0] ** 4 - 6 * x[0] ** 2 * x[1] ** 2 + x[1] ** 4) * x[2]


def Y55real(x):
    return -3 / 16 * jnp.sqrt(38.5 / jnp.pi) * \
        (x[0] ** 5 - 10 * x[0] ** 3 * x[1] ** 2 + 5 * x[0] * x[1] ** 4)


def Y6m6real(x):
    return 1 / 16 * jnp.sqrt(1501.5 / jnp.pi) * \
        x[0] * x[1] * (-3 * x[0] ** 4 + 10 * x[0] ** 2 * x[1] ** 2 - 3 * x[1] ** 4)


def Y6m5real(x):
    return -3 / 16 * jnp.sqrt(500.5 / jnp.pi) * \
        (5 * x[0] ** 4 * x[1] - 10 * x[0] ** 2 * x[1] ** 3 + x[1] ** 5) * x[2]


def Y6m4real(x):
    return 3 / 8 * jnp.sqrt(91 / jnp.pi) * \
        x[0] * x[1] * (x[1] ** 2 - x[0] ** 2) * (11 * x[2] ** 2 - 1)


def Y6m3real(x):
    return -1 / 16 * jnp.sqrt(682.5 / jnp.pi) * \
        x[1] * (3 * x[0] ** 2 - x[1] ** 2) * (11 * x[2] ** 3 - 3 * x[2])


def Y6m2real(x):
    return 1 / 16 * jnp.sqrt(682.5 / jnp.pi) * \
        x[0] * x[1] * (33 * x[2] ** 4 - 18 * x[2] ** 2 + 1)


def Y6m1real(x):
    return -1 / 16 * jnp.sqrt(273 / jnp.pi) * \
        x[1] * (33 * x[2] ** 5 - 30 * x[2] ** 3 + 5 * x[2])


def Y60real(x):
    return 1 / 32 * jnp.sqrt(13 / jnp.pi) * \
        (231 * x[2] ** 6 - 315 * x[2] ** 4 + 105 * x[2] ** 2 - 5)
def Y61real(x):
    return -1 / 16 * jnp.sqrt(273 / jnp.pi) * \
        x[0] * (33 * x[2] ** 5 - 30 * x[2] ** 3 + 5 * x[2])


def Y62real(x):
    return 1 / 32 * jnp.sqrt(682.5 / jnp.pi) * \
        (x[0] ** 2 - x[1] ** 2) * (33 * x[2] ** 4 - 18 * x[2] ** 2 + 1)


def Y63real(x):
    return -1 / 16 * jnp.sqrt(682.5 / jnp.pi) * \
        x[0] * (x[0] ** 2 - 3 * x[1] ** 2) * (11 * x[2] ** 3 - 3 * x[2])


def Y64real(x):
    return 3 / 32 * jnp.sqrt(91 / jnp.pi) * \
        (x[0] ** 4 - 6 * x[0] ** 2 * x[1] ** 2 + x[1] ** 4) * (11 * x[2] ** 2 - 1)


def Y65real(x):
    return -3 / 16 * jnp.sqrt(500.5 / jnp.pi) * \
        (x[0] ** 5 - 10 * x[0] ** 3 * x[1] ** 2 + 5 * x[0] * x[1] ** 4) * x[2]


def Y66real(x):
    return 1 / 32 * jnp.sqrt(1501.5 / jnp.pi) * \
        (x[0] ** 6 - 15 * x[0] ** 4 * x[1] ** 2 + 15 * x[0] ** 2 * x[1] ** 4 - x[1] ** 6)


def get_real_sph_function(l, m):
    """Return a jitted function evaluating the real Y_{l,m}, or 0 if unsupported.

    Supported pairs are those tabulated in ``real_sph_harm_list``: 0 <= l <= 6
    and |m| <= l. The entry for (l, m) lives at index l * (l + 1) + m.
    """
    # Bound by the *real* table, which stops at l = 6, and require the whole degree to fit.
    if abs(m) <= l and (l + 1) ** 2 <= len(real_sph_harm_list):
        return jax.jit(real_sph_harm_list[l * (l + 1) + m])
    return 0


@partial(jax.jit, static_argnums=(1, 2))
def real_sph_harm(x: Array, l: Array, m: Array) -> Array:
    """Stack the real Y_{l_i,m_i}(x); ``l`` and ``m`` must be hashable (e.g. tuples)."""
    return jnp.stack([get_real_sph_function(sl, sm)(x) for sl, sm in zip(l, m)])


def _real_sph_harm_evaluator(l_max: int) -> Callable[[Array], Array]:
    """Return evaluate(x) stacking every real Y_{l,m}(x) with l <= l_max.

    Entries are ordered by l, then m = -l, ..., l (index l * (l + 1) + m).

    Raises:
        ValueError: if l_max is negative or beyond the tabulated degrees.
    """
    if l_max < 0 or (l_max + 1) ** 2 > len(real_sph_harm_list):
        raise ValueError(
            f"l_max={l_max} is out of range; real spherical harmonics are tabulated "
            f"for {len(real_sph_harm_list)} (l, m) pairs")
    functions = [get_real_sph_function(l, m) for l in range(l_max + 1) for m in range(-l, l + 1)]

    def evaluate(x):
        return jnp.stack([fn(x) for fn in functions])

    return evaluate


def _tangential_jacobian(evaluate: Callable[[Array], Array], x: Array):
    """Return (evaluate(x), J) where J is the (k, 3) Jacobian with the radial part removed.

    ``x`` is assumed to be a unit vector. Then J = dP/dx (I - x x^T) is the
    derivative of Y(x / |x|) at |x| = 1, i.e. the gradient on the unit sphere.
    """
    primal_out = evaluate(x)
    jacobian = jax.jacfwd(evaluate)(x)
    jacobian = jacobian - (jacobian @ x)[:, None] * x[None, :]
    return primal_out, jacobian


def real_sph_harm_fn_custom_fwd(l_max: int) -> Callable[[Array, ], Array]:
    """Return f(x) stacking all real Y_{l,m}(x) with l <= l_max, with a custom JVP.

    The JVP differentiates on the unit sphere: the tangent ``dx`` is first
    projected onto the plane orthogonal to ``x`` (assumed to be a unit vector).
    The result is therefore the derivative of Y(x / |x|) at |x| = 1, computed by
    autodiff of the Cartesian polynomials.

    Raises:
        ValueError: if l_max is negative or exceeds the tabulated degrees (6).
    """
    evaluate = _real_sph_harm_evaluator(l_max)

    @jax.custom_jvp
    def f(x: Array):
        return evaluate(x)

    @f.defjvp
    def sph_harm_jvp(primals, tangents):
        x, = primals
        dx, = tangents
        # (I - x x^T) dx: drop the radial component of the tangent.
        dx_tangential = dx - jnp.dot(x, dx) * x
        return jax.jvp(evaluate, (x,), (dx_tangential,))

    return f


def real_sph_harm_fn_custom_rev(l_max: int) -> Callable[[Array, ], Array]:
    """Return f(x) stacking all real Y_{l,m}(x) with l <= l_max, with a custom VJP.

    The VJP uses the same unit-sphere derivative as ``real_sph_harm_fn_custom_fwd``.
    The forward pass stores the (k, 3) tangential Jacobian as the residual.

    Raises:
        ValueError: if l_max is negative or exceeds the tabulated degrees (6).
    """
    evaluate = _real_sph_harm_evaluator(l_max)

    @jax.custom_vjp
    def f(x: Array):
        return evaluate(x)

    def sph_harm_fwd(x):
        return _tangential_jacobian(evaluate, x)

    def sph_harm_rev(jacobian, y_bar):
        return (y_bar @ jacobian,)

    f.defvjp(sph_harm_fwd, sph_harm_rev)
    return f


# Real Y_{l,m} for l <= 6, at index l * (l + 1) + m (m = -l, ..., l per degree).
real_sph_harm_list = [
    Y00real,
    Y1m1real, Y10real, Y11real,
    Y2m2real, Y2m1real, Y20real, Y21real, Y22real,
    Y3m3real, Y3m2real, Y3m1real, Y30real, Y31real, Y32real, Y33real,
    Y4m4real, Y4m3real, Y4m2real, Y4m1real, Y40real, Y41real, Y42real, Y43real, Y44real,
    Y5m5real, Y5m4real, Y5m3real, Y5m2real, Y5m1real, Y50real, Y51real, Y52real, Y53real,
    Y54real, Y55real,
    Y6m6real, Y6m5real, Y6m4real, Y6m3real, Y6m2real, Y6m1real, Y60real, Y61real, Y62real,
    Y63real, Y64real, Y65real, Y66real,
]

# Highest degree covered by the coefficient tables below, and their row order:
# row l * (l + 1) + m holds (l, m), for m = -l, ..., l.
_YLM_MAX_L = 7
_YLM_LM_PAIRS = [(l, m) for l in range(_YLM_MAX_L + 1) for m in range(-l, l + 1)]

# |prefactor| of complex Y_l^{+-m}, keyed by (l, |m|).
_YLM_NORMS = {
    (0, 0): 0.5 * jnp.sqrt(1 / jnp.pi),
    (1, 0): 0.5 * jnp.sqrt(3 / jnp.pi), (1, 1): 0.5 * jnp.sqrt(1.5 / jnp.pi),
    (2, 0): 0.25 * jnp.sqrt(5 / jnp.pi), (2, 1): 0.5 * jnp.sqrt(7.5 / jnp.pi),
    (2, 2): 0.25 * jnp.sqrt(7.5 / jnp.pi),
    (3, 0): 0.25 * jnp.sqrt(7 / jnp.pi), (3, 1): 0.125 * jnp.sqrt(21 / jnp.pi),
    (3, 2): 0.25 * jnp.sqrt(52.5 / jnp.pi), (3, 3): 0.125 * jnp.sqrt(35 / jnp.pi),
    (4, 0): 3 / 16 * jnp.sqrt(1 / jnp.pi), (4, 1): 3 / 8 * jnp.sqrt(5 / jnp.pi),
    (4, 2): 3 / 8 * jnp.sqrt(2.5 / jnp.pi), (4, 3): 3 / 8 * jnp.sqrt(35 / jnp.pi),
    (4, 4): 3 / 16 * jnp.sqrt(17.5 / jnp.pi),
    (5, 0): 1 / 16 * jnp.sqrt(11 / jnp.pi), (5, 1): 1 / 16 * jnp.sqrt(82.5 / jnp.pi),
    (5, 2): 1 / 8 * jnp.sqrt(577.5 / jnp.pi), (5, 3): 1 / 32 * jnp.sqrt(385 / jnp.pi),
    (5, 4): 3 / 16 * jnp.sqrt(192.5 / jnp.pi), (5, 5): 3 / 32 * jnp.sqrt(77 / jnp.pi),
    (6, 0): 1 / 32 * jnp.sqrt(13 / jnp.pi), (6, 1): 1 / 16 * jnp.sqrt(136.5 / jnp.pi),
    (6, 2): 1 / 64 * jnp.sqrt(1365 / jnp.pi), (6, 3): 1 / 32 * jnp.sqrt(1365 / jnp.pi),
    (6, 4): 3 / 32 * jnp.sqrt(45.5 / jnp.pi), (6, 5): 3 / 32 * jnp.sqrt(1001 / jnp.pi),
    (6, 6): 1 / 64 * jnp.sqrt(3003 / jnp.pi),
    (7, 0): 1 / 32 * jnp.sqrt(15 / jnp.pi), (7, 1): 1 / 64 * jnp.sqrt(52.5 / jnp.pi),
    (7, 2): 3 / 64 * jnp.sqrt(35 / jnp.pi), (7, 3): 3 / 64 * jnp.sqrt(17.5 / jnp.pi),
    (7, 4): 3 / 32 * jnp.sqrt(192.5 / jnp.pi), (7, 5): 3 / 64 * jnp.sqrt(192.5 / jnp.pi),
    (7, 6): 3 / 64 * jnp.sqrt(5005 / jnp.pi), (7, 7): 3 / 64 * jnp.sqrt(357.5 / jnp.pi),
}

# Coefficients c_k of the z-polynomial sum_k c_k z^k in Y_l^{+-m}, keyed by (l, |m|).
_Z_POLY_COEFS = {
    (0, 0): [1],
    (1, 0): [0, 1], (1, 1): [1],
    (2, 0): [-1, 0, 3], (2, 1): [0, 1], (2, 2): [1],
    (3, 0): [0, -3, 0, 5], (3, 1): [-1, 0, 5], (3, 2): [0, 1], (3, 3): [1],
    (4, 0): [3, 0, -30, 0, 35], (4, 1): [0, -3, 0, 7], (4, 2): [-1, 0, 7], (4, 3): [0, 1],
    (4, 4): [1],
    (5, 0): [0, 15, 0, -70, 0, 63], (5, 1): [1, 0, -14, 0, 21], (5, 2): [0, -1, 0, 3],
    (5, 3): [-1, 0, 9], (5, 4): [0, 1], (5, 5): [1],
    (6, 0): [-5, 0, 105, 0, -315, 0, 231], (6, 1): [0, 5, 0, -30, 0, 33],
    (6, 2): [1, 0, -18, 0, 33], (6, 3): [0, -3, 0, 11], (6, 4): [-1, 0, 11], (6, 5): [0, 1],
    (6, 6): [1],
    (7, 0): [0, -35, 0, 315, 0, -693, 0, 429], (7, 1): [-5, 0, 135, 0, -495, 0, 429],
    (7, 2): [0, 15, 0, -110, 0, 143], (7, 3): [3, 0, -66, 0, 143], (7, 4): [0, -3, 0, 13],
    (7, 5): [-1, 0, 13], (7, 6): [0, 1], (7, 7): [1],
}

# Signed spherical harmonic prefactors up to l = 7. With the Condon-Shortley
# phase, Y_l^m carries (-1)^m for m > 0; Y_l^{-m} = (-1)^m conj(Y_l^m) is then
# always positive.
ylm_prefactors = jnp.array([(-1) ** max(m, 0) * _YLM_NORMS[(l, abs(m))]
                            for l, m in _YLM_LM_PAIRS])

# z-polynomial coefficient array up to l = 7 (columns: powers z^0 .. z^7).
z_coef_array = jnp.array([
    _Z_POLY_COEFS[(l, abs(m))] + [0] * (_YLM_MAX_L + 1 - len(_Z_POLY_COEFS[(l, abs(m))]))
    for l, m in _YLM_LM_PAIRS
])

# One-hot selectors of the power of (x + iy), which is max(m, 0),
# and of (x - iy), which is max(-m, 0).
x_p_iy_coefs = jnp.array([[int(p == max(m, 0)) for p in range(_YLM_MAX_L + 1)]
                          for l, m in _YLM_LM_PAIRS])
x_m_iy_coefs = jnp.array([[int(p == max(-m, 0)) for p in range(_YLM_MAX_L + 1)]
                          for l, m in _YLM_LM_PAIRS])


def sph_harm_not_fast(l: Array, m: Array) -> Callable:
    """Build a vectorised evaluator of the complex Y_l^m from the coefficient tables.

    Each Y_l^m is evaluated as
        prefactor * P(z) * (x + i y)^{max(m, 0)} * (x - i y)^{max(-m, 0)},
    with P's coefficients taken from ``z_coef_array``. The returned function
    maps x of shape (..., d), using components 0..2, to (..., len(l)).

    Args:
        l: integer array (or sequence) of degrees. It must be concrete, not
            traced, because it determines table slices.
        m: integer array (or sequence) of orders, the same length as ``l``.

    Raises:
        ValueError: if any l is outside [0, 7] or |m| > l. Without this check,
            out-of-range pairs would read a neighbouring table row, or be clamped
            by JAX's gather, and silently return wrong values.
    """
    l = jnp.asarray(l)
    m = jnp.asarray(m)
    max_supported_l = z_coef_array.shape[1] - 1
    if (bool(jnp.any(l < 0)) or bool(jnp.any(l > max_supported_l))
            or bool(jnp.any(jnp.abs(m) > l))):
        raise ValueError(f"need 0 <= l <= {max_supported_l} and |m| <= l, got l={l}, m={m}")

    max_l = int(jnp.max(l))
    idx = l * (l + 1) + m  # l ** 2 + l + m
    prefactors = ylm_prefactors[idx]
    z_coefs = z_coef_array[idx, :max_l + 1]
    xy = x_p_iy_coefs[idx, :max_l + 1]
    xy_conj = x_m_iy_coefs[idx, :max_l + 1]
    powers = jnp.arange(max_l + 1, dtype=jnp.int32)

    @partial(jnp.vectorize, signature='(d)->(k)')
    def f(x):
        z_powers = x[2] ** powers
        xy_powers = (x[0] + 1j * x[1]) ** powers
        xy_conj_powers = jnp.conj(xy_powers)
        return prefactors * (z_coefs @ z_powers) * (xy @ xy_powers) * (xy_conj @ xy_conj_powers)

    @partial(jnp.vectorize, signature='(d)->(k)')
    def f_positive_m(x):
        # With m >= 0 only, (x - iy) never appears and (x + iy)^m can be used directly.
        z_powers = x[2] ** powers
        xy_powers = (x[0] + 1j * x[1]) ** m
        return prefactors * (z_coefs @ z_powers) * xy_powers

    if jnp.all(m >= 0):
        return f_positive_m
    return jax.jit(f)