12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965 |
- import jax.numpy as jnp
- from typing import Callable
- import jax
- from functools import partial
# Type alias used in the annotations of this module.
Array = jnp.ndarray
def neg_m(sph_harm: Callable, m: int) -> Callable:
    """Return Y_l^{-m} built from the non-negative-order harmonic ``sph_harm``.

    Uses the identity Y_l^{-m}(x) = (-1)**m * conj(Y_l^m(x)), which holds for the
    Condon-Shortley-phased ``Ylm`` functions of this module (e.g. ``Y11`` is
    negative).

    Args:
        sph_harm: callable evaluating Y_l^m for some m >= 0.
        m: the (non-negative) order of ``sph_harm``.

    Returns:
        A callable ``x -> Y_l^{-m}(x)``.
    """
    # Parenthesise the base: ``-1 ** m`` would parse as ``-(1 ** m)`` == -1.
    sign = (-1) ** m

    def wrapped(x):
        return sign * jnp.conj(sph_harm(x))

    return wrapped
# Complex spherical harmonics Y_l^m for 0 <= m <= l <= 7, written as polynomials
# in the Cartesian components of x (x[0], x[1], x[2]).  The m = 0 members are
# real-valued and are explicitly cast to complex128; the others become complex
# through the (x + iy) factor.  Negative orders are built with ``neg_m``.


def _x_plus_iy(x):
    """Return x[0] + i * x[1]."""
    return x[0] + 1j * x[1]


def _as_complex128(value):
    """Cast a real-valued harmonic to complex128."""
    return jax.lax.convert_element_type(value, new_dtype=jnp.complex128)


# ---- l = 0, 1 ----
def Y00(x):
    return _as_complex128(0.5 * jnp.sqrt(1 / jnp.pi))


def Y10(x):
    return _as_complex128(0.5 * jnp.sqrt(3 / jnp.pi) * x[2])


def Y11(x):
    return -0.5 * jnp.sqrt(1.5 / jnp.pi) * _x_plus_iy(x)


# ---- l = 2 ----
def Y20(x):
    z = x[2]
    return _as_complex128(0.25 * jnp.sqrt(5 / jnp.pi) * (3 * z ** 2 - 1))


def Y21(x):
    return -0.5 * jnp.sqrt(7.5 / jnp.pi) * _x_plus_iy(x) * x[2]


def Y22(x):
    return 0.25 * jnp.sqrt(7.5 / jnp.pi) * _x_plus_iy(x) ** 2


# ---- l = 3 ----
def Y30(x):
    z = x[2]
    return _as_complex128(0.25 * jnp.sqrt(7 / jnp.pi) * (5 * z ** 3 - 3 * z))


def Y31(x):
    z = x[2]
    return -0.125 * jnp.sqrt(21 / jnp.pi) * _x_plus_iy(x) * (5 * z ** 2 - 1)


def Y32(x):
    return 0.25 * jnp.sqrt(52.5 / jnp.pi) * _x_plus_iy(x) ** 2 * x[2]


def Y33(x):
    return -0.125 * jnp.sqrt(35 / jnp.pi) * _x_plus_iy(x) ** 3


# ---- l = 4 ----
def Y40(x):
    z = x[2]
    return _as_complex128(3 / 16 * jnp.sqrt(1 / jnp.pi) * (35 * z ** 4 - 30 * z ** 2 + 3))


def Y41(x):
    z = x[2]
    return -3 / 8 * jnp.sqrt(5 / jnp.pi) * _x_plus_iy(x) * (7 * z ** 3 - 3 * z)


def Y42(x):
    z = x[2]
    return 3 / 8 * jnp.sqrt(2.5 / jnp.pi) * _x_plus_iy(x) ** 2 * (7 * z ** 2 - 1)


def Y43(x):
    return -3 / 8 * jnp.sqrt(35 / jnp.pi) * _x_plus_iy(x) ** 3 * x[2]


def Y44(x):
    return 3 / 16 * jnp.sqrt(17.5 / jnp.pi) * _x_plus_iy(x) ** 4


# ---- l = 5 ----
def Y50(x):
    z = x[2]
    return _as_complex128(1 / 16 * jnp.sqrt(11 / jnp.pi) *
                          (63 * z ** 5 - 70 * z ** 3 + 15 * z))


def Y51(x):
    z = x[2]
    return (-1 / 16 * jnp.sqrt(82.5 / jnp.pi)
            * _x_plus_iy(x) * (21 * z ** 4 - 14 * z ** 2 + 1))


def Y52(x):
    z = x[2]
    return 1 / 8 * jnp.sqrt(577.5 / jnp.pi) * _x_plus_iy(x) ** 2 * (3 * z ** 3 - z)


def Y53(x):
    z = x[2]
    return -1 / 32 * jnp.sqrt(385 / jnp.pi) * _x_plus_iy(x) ** 3 * (9 * z ** 2 - 1)


def Y54(x):
    return 3 / 16 * jnp.sqrt(192.5 / jnp.pi) * _x_plus_iy(x) ** 4 * x[2]


def Y55(x):
    return -3 / 32 * jnp.sqrt(77 / jnp.pi) * _x_plus_iy(x) ** 5


# ---- l = 6 ----
def Y60(x):
    z = x[2]
    return _as_complex128(1 / 32 * jnp.sqrt(13 / jnp.pi) *
                          (231 * z ** 6 - 315 * z ** 4 + 105 * z ** 2 - 5))


def Y61(x):
    z = x[2]
    return (-1 / 16 * jnp.sqrt(136.5 / jnp.pi)
            * _x_plus_iy(x) * (33 * z ** 5 - 30 * z ** 3 + 5 * z))


def Y62(x):
    z = x[2]
    return (1 / 64 * jnp.sqrt(1365 / jnp.pi)
            * _x_plus_iy(x) ** 2 * (33 * z ** 4 - 18 * z ** 2 + 1))


def Y63(x):
    z = x[2]
    return -1 / 32 * jnp.sqrt(1365 / jnp.pi) * _x_plus_iy(x) ** 3 * (11 * z ** 3 - 3 * z)


def Y64(x):
    z = x[2]
    return 3 / 32 * jnp.sqrt(45.5 / jnp.pi) * _x_plus_iy(x) ** 4 * (11 * z ** 2 - 1)


def Y65(x):
    return -3 / 32 * jnp.sqrt(1001 / jnp.pi) * _x_plus_iy(x) ** 5 * x[2]


def Y66(x):
    return 1 / 64 * jnp.sqrt(3003 / jnp.pi) * _x_plus_iy(x) ** 6


# ---- l = 7 ----
def Y70(x):
    z = x[2]
    return _as_complex128(1 / 32 * jnp.sqrt(15 / jnp.pi) *
                          (429 * z ** 7 - 693 * z ** 5 + 315 * z ** 3 - 35 * z))


def Y71(x):
    z = x[2]
    return (-1 / 64 * jnp.sqrt(52.5 / jnp.pi)
            * _x_plus_iy(x) * (429 * z ** 6 - 495 * z ** 4 + 135 * z ** 2 - 5))


def Y72(x):
    z = x[2]
    return (3 / 64 * jnp.sqrt(35 / jnp.pi)
            * _x_plus_iy(x) ** 2 * (143 * z ** 5 - 110 * z ** 3 + 15 * z))


def Y73(x):
    z = x[2]
    return (-3 / 64 * jnp.sqrt(17.5 / jnp.pi)
            * _x_plus_iy(x) ** 3 * (143 * z ** 4 - 66 * z ** 2 + 3))


def Y74(x):
    z = x[2]
    return 3 / 32 * jnp.sqrt(192.5 / jnp.pi) * _x_plus_iy(x) ** 4 * (13 * z ** 3 - 3 * z)


def Y75(x):
    z = x[2]
    return -3 / 64 * jnp.sqrt(192.5 / jnp.pi) * _x_plus_iy(x) ** 5 * (13 * z ** 2 - 1)


def Y76(x):
    return 3 / 64 * jnp.sqrt(5005 / jnp.pi) * _x_plus_iy(x) ** 6 * x[2]


def Y77(x):
    return -3 / 64 * jnp.sqrt(357.5 / jnp.pi) * _x_plus_iy(x) ** 7
def get_sph_function(l, m):
    """Return a jitted callable evaluating the complex harmonic Y_l^m(x).

    Returns the integer 0 (not a callable) when (l, m) is not tabulated, i.e.
    when |m| > l (which includes l < 0) or when the degree exceeds what
    ``sph_harm_list`` holds.
    """
    # sph_harm_list is flattened with one entry per (l, m), at position
    # l * (l + 1) + m, so the bound check must use that flat index, not l.
    index = l * (l + 1) + m
    if abs(m) <= l and index < len(sph_harm_list):
        return jax.jit(sph_harm_list[index])
    return 0
- # @partial(jax.jit, static_argnums=(1, 2))
- # def sph_harm(x: Array, l: Array, m: Array) -> Array:
- # return jnp.stack([get_sph_function(sl, sm)(x) for sl, sm in zip(l, m)])
def sph_harm_fn(l: tuple, m: tuple) -> Callable[[Array,], Array]:
    """Return a function stacking the complex harmonics Y_{l[i]}^{m[i]}(x).

    Entry i of the output is ``get_sph_function(l[i], m[i])(x)``. No custom
    derivative rule is attached; JAX differentiates the polynomial forms.
    """
    def evaluate(x: Array):
        harmonics = [get_sph_function(degree, order)(x)
                     for degree, order in zip(l, m)]
        return jnp.stack(harmonics)

    return evaluate
def sph_harm_fn_custom(l: tuple, m: tuple) -> Callable[[Array,], Array]:
    """Build complex Y_l^m for the given (l, m) pairs with a custom VJP rule.

    Args:
        l: degrees, one per output entry.
        m: orders, one per output entry (same length as ``l``).

    Returns:
        ``f(x)`` -> complex array of shape ``(len(l),)`` whose entry i is
        ``get_sph_function(l[i], m[i])(x)``.

    The VJP treats each harmonic as a function of direction only. It takes the
    Cartesian Jacobian of the polynomial expressions and projects out the
    component along ``x``, which gives the surface gradient on the unit sphere.
    For |x| == 1 this equals the chain rule through the polar and azimuthal
    angles. It stays finite at the poles and does not depend on how the
    (l, m) pairs are ordered.

    NOTE(review): assumes ``x`` is a real unit 3-vector, as the polynomial
    forms of Y_l^m require; confirm at call sites.
    """
    pairs = tuple(zip(l, m))

    def _evaluate(x: Array) -> Array:
        return jnp.stack([get_sph_function(sl, sm)(x) for sl, sm in pairs])

    def _tangent_jacobian(x: Array) -> Array:
        # Complex Jacobian of shape (len(l), 3) for the real input x.
        cartesian = jax.jacfwd(_evaluate)(x)
        sq_norm = x @ x
        # x == 0 is not on the sphere; avoid 0/0 and leave the Jacobian as is.
        sq_norm = jnp.where(sq_norm > 0, sq_norm, 1.0)
        projector = jnp.eye(x.shape[0], dtype=x.dtype) - jnp.outer(x, x) / sq_norm
        return cartesian @ projector

    f = jax.custom_vjp(_evaluate)

    def sph_harm_fwd(x):
        return _evaluate(x), (_tangent_jacobian(x), x)

    def sph_harm_rev(residuals, y_bar):
        jacobian, x = residuals
        # Real input, complex output: JAX's cotangent for the real input is the
        # real part of y_bar @ J (no conjugation), in the input's dtype.
        return (jnp.real(y_bar @ jacobian).astype(x.dtype),)

    f.defvjp(sph_harm_fwd, sph_harm_rev)
    return f
# Complex harmonics in flat (l, m) order: index l * (l + 1) + m, m = -l..l.
# Each row lists Y_l^0..Y_l^l; negative orders are derived with ``neg_m``.
sph_harm_list = [
    neg_m(row[-order], -order) if order < 0 else row[order]
    for row in (
        (Y00,),
        (Y10, Y11),
        (Y20, Y21, Y22),
        (Y30, Y31, Y32, Y33),
        (Y40, Y41, Y42, Y43, Y44),
        (Y50, Y51, Y52, Y53, Y54, Y55),
        (Y60, Y61, Y62, Y63, Y64, Y65, Y66),
        (Y70, Y71, Y72, Y73, Y74, Y75, Y76, Y77),
    )
    for order in range(1 - len(row), len(row))
]
# Real spherical harmonics Y_{l,m} for l <= 4 as polynomials in x[0], x[1],
# x[2]. All prefactors are positive. For m > 0 the angular factor is built
# from x, for m < 0 from y (e.g. Y11real ~ x, Y1m1real ~ y).


def Y00real(x):
    return 0.5 * jnp.sqrt(1 / jnp.pi)


# ---- l = 1 ----
def Y1m1real(x):
    py = x[1]
    return 0.5 * jnp.sqrt(3 / jnp.pi) * py


def Y10real(x):
    pz = x[2]
    return 0.5 * jnp.sqrt(3 / jnp.pi) * pz


def Y11real(x):
    px = x[0]
    return 0.5 * jnp.sqrt(3 / jnp.pi) * px


# ---- l = 2 ----
def Y2m2real(x):
    px, py = x[0], x[1]
    return 0.5 * jnp.sqrt(15 / jnp.pi) * px * py


def Y2m1real(x):
    py, pz = x[1], x[2]
    return 0.5 * jnp.sqrt(15 / jnp.pi) * pz * py


def Y20real(x):
    pz = x[2]
    return 0.25 * jnp.sqrt(5 / jnp.pi) * (3 * pz ** 2 - 1)


def Y21real(x):
    px, pz = x[0], x[2]
    return 0.5 * jnp.sqrt(15 / jnp.pi) * pz * px


def Y22real(x):
    px, py = x[0], x[1]
    return 0.25 * jnp.sqrt(15 / jnp.pi) * (px ** 2 - py ** 2)


# ---- l = 3 ----
def Y3m3real(x):
    px, py = x[0], x[1]
    return 0.25 * jnp.sqrt(17.5 / jnp.pi) * py * (3 * px ** 2 - py ** 2)


def Y3m2real(x):
    px, py, pz = x[0], x[1], x[2]
    return 0.5 * jnp.sqrt(105 / jnp.pi) * px * py * pz


def Y3m1real(x):
    py, pz = x[1], x[2]
    return 0.25 * jnp.sqrt(10.5 / jnp.pi) * py * (5 * pz ** 2 - 1)


def Y30real(x):
    pz = x[2]
    return 0.25 * jnp.sqrt(7 / jnp.pi) * (5 * pz ** 3 - 3 * pz)


def Y31real(x):
    px, pz = x[0], x[2]
    return 0.25 * jnp.sqrt(10.5 / jnp.pi) * px * (5 * pz ** 2 - 1)


def Y32real(x):
    px, py, pz = x[0], x[1], x[2]
    return 0.25 * jnp.sqrt(105 / jnp.pi) * (px ** 2 - py ** 2) * pz


def Y33real(x):
    px, py = x[0], x[1]
    return 0.25 * jnp.sqrt(17.5 / jnp.pi) * px * (px ** 2 - 3 * py ** 2)


# ---- l = 4 ----
def Y4m4real(x):
    px, py = x[0], x[1]
    return 0.75 * jnp.sqrt(35 / jnp.pi) * px * py * (px ** 2 - py ** 2)


def Y4m3real(x):
    px, py, pz = x[0], x[1], x[2]
    return 0.75 * jnp.sqrt(17.5 / jnp.pi) * py * (3 * px ** 2 - py ** 2) * pz


def Y4m2real(x):
    px, py, pz = x[0], x[1], x[2]
    return 0.75 * jnp.sqrt(5 / jnp.pi) * px * py * (7 * pz ** 2 - 1)


def Y4m1real(x):
    py, pz = x[1], x[2]
    return 0.75 * jnp.sqrt(2.5 / jnp.pi) * py * (7 * pz ** 3 - 3 * pz)


def Y40real(x):
    pz = x[2]
    return 3 / 16 * jnp.sqrt(1 / jnp.pi) * (35 * pz ** 4 - 30 * pz ** 2 + 3)


def Y41real(x):
    px, pz = x[0], x[2]
    return 0.75 * jnp.sqrt(2.5 / jnp.pi) * px * (7 * pz ** 3 - 3 * pz)


def Y42real(x):
    px, py, pz = x[0], x[1], x[2]
    return 0.375 * jnp.sqrt(5 / jnp.pi) * (px ** 2 - py ** 2) * (7 * pz ** 2 - 1)


def Y43real(x):
    px, py, pz = x[0], x[1], x[2]
    return 0.75 * jnp.sqrt(17.5 / jnp.pi) * px * (px ** 2 - 3 * py ** 2) * pz


def Y44real(x):
    px, py = x[0], x[1]
    return 0.1875 * jnp.sqrt(35 / jnp.pi) * (px ** 2 * (px ** 2 - 3 * py ** 2) - py ** 2 * (3 * px ** 2 - py ** 2))
# Real spherical harmonics Y_{l,m} for l = 5, 6 in the same convention as the
# l <= 4 entries of this module. That convention has no Condon-Shortley phase
# and positive prefactors. The m > 0 factor is Re((x + iy)**m), e.g.
# Y33real ~ x(x^2 - 3y^2). The m < 0 factor is Im((x + iy)**|m|), e.g.
# Y4m4real ~ xy(x^2 - y^2).
# The previous l = 5, 6 versions flipped the sign of every odd-m entry and of
# Y5m4, Y6m4 and Y6m6, which broke that convention.


# ---- l = 5 ----
def Y5m5real(x):
    px, py = x[0], x[1]
    return 3 / 16 * jnp.sqrt(38.5 / jnp.pi) * (5 * px ** 4 * py - 10 * px ** 2 * py ** 3 + py ** 5)


def Y5m4real(x):
    px, py, pz = x[0], x[1], x[2]
    return 3 / 4 * jnp.sqrt(385 / jnp.pi) * px * py * (px ** 2 - py ** 2) * pz


def Y5m3real(x):
    px, py, pz = x[0], x[1], x[2]
    return 1 / 16 * jnp.sqrt(192.5 / jnp.pi) * py * (3 * px ** 2 - py ** 2) * (9 * pz ** 2 - 1)


def Y5m2real(x):
    px, py, pz = x[0], x[1], x[2]
    return 1 / 4 * jnp.sqrt(1155 / jnp.pi) * px * py * (3 * pz ** 3 - pz)


def Y5m1real(x):
    py, pz = x[1], x[2]
    return 1 / 16 * jnp.sqrt(165 / jnp.pi) * py * (21 * pz ** 4 - 14 * pz ** 2 + 1)


def Y50real(x):
    pz = x[2]
    return 1 / 16 * jnp.sqrt(11 / jnp.pi) * (63 * pz ** 5 - 70 * pz ** 3 + 15 * pz)


def Y51real(x):
    px, pz = x[0], x[2]
    return 1 / 16 * jnp.sqrt(165 / jnp.pi) * px * (21 * pz ** 4 - 14 * pz ** 2 + 1)


def Y52real(x):
    px, py, pz = x[0], x[1], x[2]
    return 1 / 8 * jnp.sqrt(1155 / jnp.pi) * (px ** 2 - py ** 2) * (3 * pz ** 3 - pz)


def Y53real(x):
    px, py, pz = x[0], x[1], x[2]
    return 1 / 16 * jnp.sqrt(192.5 / jnp.pi) * px * (px ** 2 - 3 * py ** 2) * (9 * pz ** 2 - 1)


def Y54real(x):
    px, py, pz = x[0], x[1], x[2]
    return 3 / 16 * jnp.sqrt(385 / jnp.pi) * (px ** 4 - 6 * px ** 2 * py ** 2 + py ** 4) * pz


def Y55real(x):
    px, py = x[0], x[1]
    return 3 / 16 * jnp.sqrt(38.5 / jnp.pi) * (px ** 5 - 10 * px ** 3 * py ** 2 + 5 * px * py ** 4)


# ---- l = 6 ----
def Y6m6real(x):
    px, py = x[0], x[1]
    return 1 / 16 * jnp.sqrt(1501.5 / jnp.pi) * px * py * (3 * px ** 4 - 10 * px ** 2 * py ** 2 + 3 * py ** 4)


def Y6m5real(x):
    px, py, pz = x[0], x[1], x[2]
    return 3 / 16 * jnp.sqrt(500.5 / jnp.pi) * (5 * px ** 4 * py - 10 * px ** 2 * py ** 3 + py ** 5) * pz


def Y6m4real(x):
    px, py, pz = x[0], x[1], x[2]
    return 3 / 8 * jnp.sqrt(91 / jnp.pi) * px * py * (px ** 2 - py ** 2) * (11 * pz ** 2 - 1)


def Y6m3real(x):
    px, py, pz = x[0], x[1], x[2]
    return 1 / 16 * jnp.sqrt(682.5 / jnp.pi) * py * (3 * px ** 2 - py ** 2) * (11 * pz ** 3 - 3 * pz)


def Y6m2real(x):
    px, py, pz = x[0], x[1], x[2]
    return 1 / 16 * jnp.sqrt(682.5 / jnp.pi) * px * py * (33 * pz ** 4 - 18 * pz ** 2 + 1)


def Y6m1real(x):
    py, pz = x[1], x[2]
    return 1 / 16 * jnp.sqrt(273 / jnp.pi) * py * (33 * pz ** 5 - 30 * pz ** 3 + 5 * pz)


def Y60real(x):
    pz = x[2]
    return 1 / 32 * jnp.sqrt(13 / jnp.pi) * (231 * pz ** 6 - 315 * pz ** 4 + 105 * pz ** 2 - 5)


def Y61real(x):
    px, pz = x[0], x[2]
    return 1 / 16 * jnp.sqrt(273 / jnp.pi) * px * (33 * pz ** 5 - 30 * pz ** 3 + 5 * pz)


def Y62real(x):
    px, py, pz = x[0], x[1], x[2]
    return 1 / 32 * jnp.sqrt(682.5 / jnp.pi) * (px ** 2 - py ** 2) * (33 * pz ** 4 - 18 * pz ** 2 + 1)


def Y63real(x):
    px, py, pz = x[0], x[1], x[2]
    return 1 / 16 * jnp.sqrt(682.5 / jnp.pi) * px * (px ** 2 - 3 * py ** 2) * (11 * pz ** 3 - 3 * pz)


def Y64real(x):
    px, py, pz = x[0], x[1], x[2]
    return 3 / 32 * jnp.sqrt(91 / jnp.pi) * (px ** 4 - 6 * px ** 2 * py ** 2 + py ** 4) * (11 * pz ** 2 - 1)


def Y65real(x):
    px, py, pz = x[0], x[1], x[2]
    return 3 / 16 * jnp.sqrt(500.5 / jnp.pi) * (px ** 5 - 10 * px ** 3 * py ** 2 + 5 * px * py ** 4) * pz


def Y66real(x):
    px, py = x[0], x[1]
    return 1 / 32 * jnp.sqrt(1501.5 / jnp.pi) * (px ** 6 - 15 * px ** 4 * py ** 2 + 15 * px ** 2 * py ** 4 - py ** 6)
def get_real_sph_function(l, m):
    """Return a jitted callable evaluating the real harmonic Y_{l,m}(x).

    Returns the integer 0 (not a callable) when (l, m) is not tabulated, i.e.
    when |m| > l (which includes l < 0) or when the degree exceeds what
    ``real_sph_harm_list`` holds.
    """
    # Bound-check against the *real* list, which has fewer degrees than the
    # complex one, using the flat index l * (l + 1) + m rather than l itself.
    index = l * (l + 1) + m
    if abs(m) <= l and index < len(real_sph_harm_list):
        return jax.jit(real_sph_harm_list[index])
    return 0
@partial(jax.jit, static_argnums=(1, 2))
def real_sph_harm(x: Array, l: Array, m: Array) -> Array:
    """Evaluate real harmonics Y_{l[i], m[i]}(x) and stack them.

    ``l`` and ``m`` are static arguments of the jit, so they must be hashable
    (e.g. tuples of ints).
    """
    harmonics = []
    for degree, order in zip(l, m):
        harmonics.append(get_real_sph_function(degree, order)(x))
    return jnp.stack(harmonics)
def real_sph_harm_fn_custom_fwd(l_max: int) -> Callable[[Array,], Array]:
    """Build all real harmonics up to degree ``l_max`` with a custom JVP rule.

    Args:
        l_max: highest degree to include.

    Returns:
        ``f(x)`` -> array of shape ``((l_max + 1) ** 2,)`` holding
        ``get_real_sph_function(l, m)(x)`` ordered by l = 0..l_max and, within
        each degree, by m = -l..l.

    The JVP treats each harmonic as a function of direction only. It takes the
    Cartesian Jacobian of the polynomial expressions and projects out the
    component along ``x``, which gives the surface gradient on the unit sphere.
    For |x| == 1 this equals the chain rule through the polar and azimuthal
    angles, and it stays finite at the poles.

    The previous hand-written rule used m * Y_{l,m} for the azimuthal derivative.
    That is the complex-harmonic rule; for real harmonics the azimuthal
    derivative is -m * Y_{l,-m}.

    NOTE(review): assumes ``x`` is a real unit 3-vector; confirm at call sites.
    """
    lm_pairs = [(l, m) for l in range(l_max + 1) for m in range(-l, l + 1)]

    def _evaluate(x: Array) -> Array:
        return jnp.stack([get_real_sph_function(sl, sm)(x) for sl, sm in lm_pairs])

    def _tangent_jacobian(x: Array) -> Array:
        # Jacobian of shape (len(lm_pairs), 3) with the radial part removed.
        cartesian = jax.jacfwd(_evaluate)(x)
        sq_norm = x @ x
        # x == 0 is not on the sphere; avoid 0/0 and leave the Jacobian as is.
        sq_norm = jnp.where(sq_norm > 0, sq_norm, 1.0)
        projector = jnp.eye(x.shape[0], dtype=x.dtype) - jnp.outer(x, x) / sq_norm
        return cartesian @ projector

    f = jax.custom_jvp(_evaluate)

    @f.defjvp
    def sph_harm_jvp(primals, tangents):
        x, = primals
        dx, = tangents
        return _evaluate(x), _tangent_jacobian(x) @ dx

    return f
def real_sph_harm_fn_custom_rev(l_max: int) -> Callable[[Array, ], Array]:
    """Build all real harmonics up to degree ``l_max`` with a custom VJP rule.

    Args:
        l_max: highest degree to include.

    Returns:
        ``f(x)`` -> array of shape ``((l_max + 1) ** 2,)`` holding
        ``get_real_sph_function(l, m)(x)`` ordered by l = 0..l_max and, within
        each degree, by m = -l..l.

    The VJP uses the Cartesian Jacobian of the polynomial expressions with
    the component along ``x`` projected out. That is the surface gradient on
    the unit sphere, equal to the chain rule through the polar and azimuthal
    angles for |x| == 1. The previous ladder-operator rule used m * Y_{l,m} as
    the azimuthal derivative, which is wrong for real harmonics.

    NOTE(review): assumes ``x`` is a real unit 3-vector; confirm at call sites.
    """
    lm_pairs = [(l, m) for l in range(l_max + 1) for m in range(-l, l + 1)]

    def _evaluate(x: Array) -> Array:
        return jnp.stack([get_real_sph_function(sl, sm)(x) for sl, sm in lm_pairs])

    def _tangent_jacobian(x: Array) -> Array:
        # Jacobian of shape (len(lm_pairs), 3) with the radial part removed.
        cartesian = jax.jacfwd(_evaluate)(x)
        sq_norm = x @ x
        # x == 0 is not on the sphere; avoid 0/0 and leave the Jacobian as is.
        sq_norm = jnp.where(sq_norm > 0, sq_norm, 1.0)
        projector = jnp.eye(x.shape[0], dtype=x.dtype) - jnp.outer(x, x) / sq_norm
        return cartesian @ projector

    f = jax.custom_vjp(_evaluate)

    def sph_harm_fwd(x):
        return _evaluate(x), _tangent_jacobian(x)

    def sph_harm_rev(jacobian, y_bar):
        return (y_bar @ jacobian,)

    f.defvjp(sph_harm_fwd, sph_harm_rev)
    return f
# Real harmonics in flat (l, m) order: index l * (l + 1) + m, m = -l..l, l <= 6.
real_sph_harm_list = [
    harmonic
    for degree_row in (
        (Y00real,),
        (Y1m1real, Y10real, Y11real),
        (Y2m2real, Y2m1real, Y20real, Y21real, Y22real),
        (Y3m3real, Y3m2real, Y3m1real, Y30real, Y31real, Y32real, Y33real),
        (Y4m4real, Y4m3real, Y4m2real, Y4m1real, Y40real, Y41real, Y42real, Y43real, Y44real),
        (Y5m5real, Y5m4real, Y5m3real, Y5m2real, Y5m1real, Y50real,
         Y51real, Y52real, Y53real, Y54real, Y55real),
        (Y6m6real, Y6m5real, Y6m4real, Y6m3real, Y6m2real, Y6m1real, Y60real,
         Y61real, Y62real, Y63real, Y64real, Y65real, Y66real),
    )
    for harmonic in degree_row
]
# Prefactors of the complex spherical harmonics for l = 0..7, one entry per
# flat index l * (l + 1) + m (m = -l..l within each degree).
# Entries for m >= 0 carry the sign of the matching ``Ylm`` function (negative
# for odd m, e.g. ``Y11``, ``Y77``). Entries for m < 0 are positive because
# Y_l^{-m} = (-1)**m * conj(Y_l^m) cancels that sign.
ylm_prefactors = jnp.array([
    # l = 0
    0.5 * jnp.sqrt(1 / jnp.pi),
    # l = 1
    0.5 * jnp.sqrt(1.5 / jnp.pi),
    0.5 * jnp.sqrt(3 / jnp.pi),
    -0.5 * jnp.sqrt(1.5 / jnp.pi),
    # l = 2
    0.25 * jnp.sqrt(7.5 / jnp.pi),
    0.5 * jnp.sqrt(7.5 / jnp.pi),
    0.25 * jnp.sqrt(5 / jnp.pi),
    -0.5 * jnp.sqrt(7.5 / jnp.pi),
    0.25 * jnp.sqrt(7.5 / jnp.pi),
    # l = 3
    0.125 * jnp.sqrt(35 / jnp.pi),
    0.25 * jnp.sqrt(52.5 / jnp.pi),
    0.125 * jnp.sqrt(21 / jnp.pi),
    0.25 * jnp.sqrt(7 / jnp.pi),
    -0.125 * jnp.sqrt(21 / jnp.pi),
    0.25 * jnp.sqrt(52.5 / jnp.pi),
    -0.125 * jnp.sqrt(35 / jnp.pi),
    # l = 4
    3 / 16 * jnp.sqrt(17.5 / jnp.pi),
    3 / 8 * jnp.sqrt(35 / jnp.pi),
    3 / 8 * jnp.sqrt(2.5 / jnp.pi),
    3 / 8 * jnp.sqrt(5 / jnp.pi),
    3 / 16 * jnp.sqrt(1 / jnp.pi),
    -3 / 8 * jnp.sqrt(5 / jnp.pi),
    3 / 8 * jnp.sqrt(2.5 / jnp.pi),
    -3 / 8 * jnp.sqrt(35 / jnp.pi),
    3 / 16 * jnp.sqrt(17.5 / jnp.pi),
    # l = 5
    3 / 32 * jnp.sqrt(77 / jnp.pi),
    3 / 16 * jnp.sqrt(192.5 / jnp.pi),
    1 / 32 * jnp.sqrt(385 / jnp.pi),
    1 / 8 * jnp.sqrt(577.5 / jnp.pi),
    1 / 16 * jnp.sqrt(82.5 / jnp.pi),
    1 / 16 * jnp.sqrt(11 / jnp.pi),
    -1 / 16 * jnp.sqrt(82.5 / jnp.pi),
    1 / 8 * jnp.sqrt(577.5 / jnp.pi),
    -1 / 32 * jnp.sqrt(385 / jnp.pi),
    3 / 16 * jnp.sqrt(192.5 / jnp.pi),
    -3 / 32 * jnp.sqrt(77 / jnp.pi),
    # l = 6
    1 / 64 * jnp.sqrt(3003 / jnp.pi),
    3 / 32 * jnp.sqrt(1001 / jnp.pi),
    3 / 32 * jnp.sqrt(45.5 / jnp.pi),
    1 / 32 * jnp.sqrt(1365 / jnp.pi),
    1 / 64 * jnp.sqrt(1365 / jnp.pi),
    1 / 16 * jnp.sqrt(136.5 / jnp.pi),
    1 / 32 * jnp.sqrt(13 / jnp.pi),
    -1 / 16 * jnp.sqrt(136.5 / jnp.pi),
    1 / 64 * jnp.sqrt(1365 / jnp.pi),
    -1 / 32 * jnp.sqrt(1365 / jnp.pi),
    3 / 32 * jnp.sqrt(45.5 / jnp.pi),
    -3 / 32 * jnp.sqrt(1001 / jnp.pi),
    1 / 64 * jnp.sqrt(3003 / jnp.pi),
    # l = 7
    3 / 64 * jnp.sqrt(357.5 / jnp.pi),
    3 / 64 * jnp.sqrt(5005 / jnp.pi),
    3 / 64 * jnp.sqrt(192.5 / jnp.pi),
    3 / 32 * jnp.sqrt(192.5 / jnp.pi),
    3 / 64 * jnp.sqrt(17.5 / jnp.pi),
    3 / 64 * jnp.sqrt(35 / jnp.pi),
    1 / 64 * jnp.sqrt(52.5 / jnp.pi),
    1 / 32 * jnp.sqrt(15 / jnp.pi),
    -1 / 64 * jnp.sqrt(52.5 / jnp.pi),
    3 / 64 * jnp.sqrt(35 / jnp.pi),
    -3 / 64 * jnp.sqrt(17.5 / jnp.pi),
    3 / 32 * jnp.sqrt(192.5 / jnp.pi),
    -3 / 64 * jnp.sqrt(192.5 / jnp.pi),
    3 / 64 * jnp.sqrt(5005 / jnp.pi),
    # (7, 7): odd m, so negative like Y77 (was missing its minus sign)
    -3 / 64 * jnp.sqrt(357.5 / jnp.pi),
])
# Coefficients of the z-polynomial factor of Y_l^m for l = 0..7, in ascending
# powers of z.  The polynomial depends only on |m|, so _Z_POLY_BY_ABS_M[l][k]
# serves both m = +k and m = -k.
_Z_POLY_BY_ABS_M = (
    ((1,),),
    ((0, 1), (1,)),
    ((-1, 0, 3), (0, 1), (1,)),
    ((0, -3, 0, 5), (-1, 0, 5), (0, 1), (1,)),
    ((3, 0, -30, 0, 35), (0, -3, 0, 7), (-1, 0, 7), (0, 1), (1,)),
    ((0, 15, 0, -70, 0, 63), (1, 0, -14, 0, 21), (0, -1, 0, 3), (-1, 0, 9),
     (0, 1), (1,)),
    ((-5, 0, 105, 0, -315, 0, 231), (0, 5, 0, -30, 0, 33), (1, 0, -18, 0, 33),
     (0, -3, 0, 11), (-1, 0, 11), (0, 1), (1,)),
    ((0, -35, 0, 315, 0, -693, 0, 429), (-5, 0, 135, 0, -495, 0, 429),
     (0, 15, 0, -110, 0, 143), (3, 0, -66, 0, 143), (0, -3, 0, 13),
     (-1, 0, 13), (0, 1), (1,)),
)

# One zero-padded row of 8 coefficients per flat index l * (l + 1) + m.
z_coef_array = jnp.array([
    list(_Z_POLY_BY_ABS_M[l][abs(m)]) + [0] * (8 - len(_Z_POLY_BY_ABS_M[l][abs(m)]))
    for l in range(8)
    for m in range(-l, l + 1)
])
# One-hot selector of the power of (x + iy) in Y_l^m, for each flat index
# l * (l + 1) + m with l <= 7: the power is m for m > 0 and 0 otherwise.
x_p_iy_coefs = jnp.array([
    [1 if power == max(m, 0) else 0 for power in range(8)]
    for l in range(8)
    for m in range(-l, l + 1)
])
# One-hot selector of the power of (x - iy) in Y_l^m, for each flat index
# l * (l + 1) + m with l <= 7: the power is -m for m < 0 and 0 otherwise.
x_m_iy_coefs = jnp.array([
    [1 if power == max(-m, 0) else 0 for power in range(8)]
    for l in range(8)
    for m in range(-l, l + 1)
])
def sph_harm_not_fast(l: Array, m: Array) -> Callable:
    """Build a vectorized evaluator of complex Y_l^m from the coefficient tables.

    Each harmonic is evaluated as
        prefactor * P(z) * (x + iy) ** max(m, 0) * (x - iy) ** max(-m, 0),
    with the prefactor from ``ylm_prefactors``, the polynomial ``P`` from
    ``z_coef_array`` and the power selectors from ``x_p_iy_coefs`` /
    ``x_m_iy_coefs``.

    Args:
        l: 1-D integer array (or sequence) of degrees.
        m: 1-D integer array (or sequence) of orders, same length as ``l``.

    Returns:
        A jitted function mapping ``x`` of shape ``(..., d)`` to complex values of
        shape ``(..., len(l))``. It reads ``x[0], x[1], x[2]`` as x, y, z;
        presumably unit vectors, to be confirmed by callers.

    Raises:
        ValueError: if some pair has |m| > l or lies beyond the tabulated degrees.
    """
    l = jnp.asarray(l)
    m = jnp.asarray(m)
    idx = l * (l + 1) + m  # flat (l, m) row index: l ** 2 + l + m
    # JAX clamps out-of-bounds gather indices instead of raising, so an
    # unsupported (l, m) would otherwise silently read the wrong table row.
    if bool(jnp.any(jnp.abs(m) > l)) or bool(jnp.any(idx >= ylm_prefactors.shape[0])):
        raise ValueError(
            "unsupported (l, m) pairs: need |m| <= l and "
            f"l * (l + 1) + m < {ylm_prefactors.shape[0]}; got l={l}, m={m}")
    max_l = int(jnp.max(l))
    prefactors = ylm_prefactors[idx]
    z_coefs = z_coef_array[idx, :max_l + 1]
    xy = x_p_iy_coefs[idx, :max_l + 1]
    xy_conj = x_m_iy_coefs[idx, :max_l + 1]
    powers = jnp.arange(max_l + 1, dtype=jnp.int32)

    @partial(jnp.vectorize, signature='(d)->(k)')
    def f(x):
        z_powers = x[2] ** powers
        xy_powers = (x[0] + 1j * x[1]) ** powers
        xy_conj_powers = jnp.conj(xy_powers)  # (x - iy) ** k for real x, y
        return prefactors * (z_coefs @ z_powers) * (xy @ xy_powers) * (xy_conj @ xy_conj_powers)

    @partial(jnp.vectorize, signature='(d)->(k)')
    def f_positive_m(x):
        # With every m >= 0 the (x - iy) factor has power 0, so it is skipped.
        z_powers = x[2] ** powers
        xy_powers = (x[0] + 1j * x[1]) ** m
        return prefactors * (z_coefs @ z_powers) * xy_powers

    # Jit both variants so callers get the same kind of callable either way.
    if jnp.all(m >= 0):
        return jax.jit(f_positive_m)
    return jax.jit(f)
|