diff --git a/s2fft/transforms/c_backend_spherical.py b/s2fft/transforms/c_backend_spherical.py index 5b6df8f5..da6a9ed4 100644 --- a/s2fft/transforms/c_backend_spherical.py +++ b/s2fft/transforms/c_backend_spherical.py @@ -4,7 +4,8 @@ # C backend functions for which to provide JAX frontend. import pyssht -from jax import custom_vjp +from jax import core, custom_vjp +from jax.interpreters import ad from s2fft.sampling import reindex from s2fft.utils import quadrature_jax @@ -241,83 +242,181 @@ def _ssht_forward_bwd(res, flm): return f, None, None, None, None, None -@custom_vjp -def healpy_inverse(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray: - r""" - Compute the inverse scalar real spherical harmonic transform (HEALPix JAX). +# Link JAX gradients for C backend functions +ssht_inverse.defvjp(_ssht_inverse_fwd, _ssht_inverse_bwd) +ssht_forward.defvjp(_ssht_forward_fwd, _ssht_forward_bwd) - HEALPix is a C++ library which implements the scalar spherical harmonic transform - outlined in [1]. We make use of their healpy python bindings for which we provide - custom JAX frontends, hence providing support for automatic differentiation. Currently - these transforms can only be deployed on CPU, which is a limitation of the C++ library. - Args: - flm (jnp.ndarray): Spherical harmonic coefficients. +def _complex_dtype(real_dtype): + """ + Get complex datatype corresponding to a given real datatype. - L (int): Harmonic band-limit. + Derived from https://github.com/jax-ml/jax/blob/1471702adc28/jax/_src/lax/fft.py#L92 - nside (int, optional): HEALPix Nside resolution parameter. Only required - if sampling="healpix". Defaults to None. + Original license: - Returns: - jnp.ndarray: Signal on the sphere. + Copyright 2019 The JAX Authors. - Note: - [1] Gorski, Krzysztof M., et al. "HEALPix: A framework for high-resolution - discretization and fast analysis of data distributed on the sphere." The - Astrophysical Journal 622.2 (2005): 759 + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. """ - flm = reindex.flm_2d_to_hp_fast(flm, L) - f = jnp.array(healpy.alm2map(np.array(flm), lmax=L - 1, nside=nside)) - return f + return (np.zeros((), real_dtype) + np.zeros((), np.complex64)).dtype + + +def _real_dtype(complex_dtype): + """ + Get real datatype corresponding to a given complex datatype. + + Derived from https://github.com/jax-ml/jax/blob/1471702adc28/jax/_src/lax/fft.py#L93 + + Original license: + + Copyright 2019 The JAX Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + """ + return np.finfo(complex_dtype).dtype + + +def _healpy_map2alm_impl(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray: + return jnp.array(healpy.map2alm(np.array(f), lmax=L - 1, iter=0)) -def _healpy_inverse_fwd(flm: jnp.ndarray, L: int, nside: int): - """Private function which implements the forward pass for inverse jax_healpy.""" - res = ([], L, nside) - return healpy_inverse(flm, L, nside), res +def _healpy_map2alm_abstract_eval( + f: core.ShapedArray, L: int, nside: int +) -> core.ShapedArray: + return core.ShapedArray(shape=(L * (L + 1) // 2,), dtype=_complex_dtype(f.dtype)) -def _healpy_inverse_bwd(res, f): - """Private function which implements the backward pass for inverse jax_healpy.""" - _, L, nside = res - f_new = f * (12 * nside**2) / (4 * jnp.pi) - flm_out = jnp.array( - np.conj(healpy.map2alm(np.conj(np.array(f_new)), lmax=L - 1, iter=0)) + +def _healpy_map2alm_transpose(dflm: jnp.ndarray, L: int, nside: int): + scale_factors = ( + jnp.concatenate((jnp.ones(L), 2 * jnp.ones(L * (L - 1) // 2))) + * (3 * nside**2) + / jnp.pi ) - # iter MUST be zero otherwise gradient propagation is incorrect (JDM). - flm_out = reindex.flm_hp_to_2d_fast(flm_out, L) - m_conj = (-1) ** (jnp.arange(1, L) % 2) - flm_out = flm_out.at[..., L:].add( - jnp.flip(m_conj * jnp.conj(flm_out[..., : L - 1]), axis=-1) + return (jnp.conj(healpy_alm2map(jnp.conj(dflm) / scale_factors, L, nside)),) + + +_healpy_map2alm_p = core.Primitive("healpy_map2alm") +_healpy_map2alm_p.def_impl(_healpy_map2alm_impl) +_healpy_map2alm_p.def_abstract_eval(_healpy_map2alm_abstract_eval) +ad.deflinear(_healpy_map2alm_p, _healpy_map2alm_transpose) + + +def healpy_map2alm(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray: + """ + JAX wrapper for healpy map2alm function (forward spherical harmonic transform). + + This wrapper will return the spherical harmonic coefficients as a one dimensional + array using HEALPix (ring-ordered) indexing. To instead return a two-dimensional + array of harmonic coefficients use :py:func:`healpy_forward`. + + Args: + f (jnp.ndarray): Signal on the sphere. + + L (int): Harmonic band-limit. Equivalent to `lmax + 1` in healpy. + + nside (int): HEALPix Nside resolution parameter. + + Returns: + jnp.ndarray: Harmonic coefficients of signal f. + + """ + return _healpy_map2alm_p.bind(f, L=L, nside=nside) + + +def _healpy_alm2map_impl(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray: + return jnp.array(healpy.alm2map(np.array(flm), lmax=L - 1, nside=nside)) + + +def _healpy_alm2map_abstract_eval( + flm: core.ShapedArray, L: int, nside: int +) -> core.ShapedArray: + return core.ShapedArray(shape=(12 * nside**2,), dtype=_real_dtype(flm.dtype)) + + +def _healpy_alm2map_transpose(df: jnp.ndarray, L: int, nside: int) -> tuple: + scale_factors = ( + jnp.concatenate((jnp.ones(L), 2 * jnp.ones(L * (L - 1) // 2))) + * (3 * nside**2) + / jnp.pi ) - flm_out = flm_out.at[..., : L - 1].set(0) + # Scale factor above includes the inverse quadrature weight given by + # (12 * nside**2) / (4 * jnp.pi) = (3 * nside**2) / jnp.pi + # and also a factor of 2 for m>0 to account for the negative m. + # See explanation in this issue comment: + # https://github.com/astro-informatics/s2fft/issues/243#issuecomment-2500951488 + return (scale_factors * jnp.conj(healpy_map2alm(jnp.conj(df), L, nside)),) - return flm_out, None, None + +_healpy_alm2map_p = core.Primitive("healpy_alm2map") +_healpy_alm2map_p.def_impl(_healpy_alm2map_impl) +_healpy_alm2map_p.def_abstract_eval(_healpy_alm2map_abstract_eval) +ad.deflinear(_healpy_alm2map_p, _healpy_alm2map_transpose) + + +def healpy_alm2map(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray: + """ + JAX wrapper for healpy alm2map function (inverse spherical harmonic transform). + + This wrapper assumes the passed spherical harmonic coefficients are a one + dimensional array using HEALPix (ring-ordered) indexing. To instead pass a + two-dimensional array of harmonic coefficients use :py:func:`healpy_inverse`. + + Args: + flm (jnp.ndarray): Spherical harmonic coefficients. + + L (int): Harmonic band-limit. Equivalent to `lmax + 1` in healpy. + + nside (int): HEALPix Nside resolution parameter. + + Returns: + jnp.ndarray: Signal on the sphere. + + """ + return _healpy_alm2map_p.bind(flm, L=L, nside=nside) -@custom_vjp def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.ndarray: r""" Compute the forward scalar spherical harmonic transform (healpy JAX). HEALPix is a C++ library which implements the scalar spherical harmonic transform outlined in [1]. We make use of their healpy python bindings for which we provide - custom JAX frontends, hence providing support for automatic differentiation. Currently - these transforms can only be deployed on CPU, which is a limitation of the C++ library. + custom JAX frontends, hence providing support for automatic differentiation. + Currently these transforms can only be deployed on CPU, which is a limitation of the + C++ library. Args: f (jnp.ndarray): Signal on the sphere. L (int): Harmonic band-limit. - nside (int, optional): HEALPix Nside resolution parameter. Only required - if sampling="healpix". Defaults to None. + nside (int): HEALPix Nside resolution parameter. - iter (int, optional): Number of subiterations for healpy. Note that iterations - increase the precision of the forward transform, but reduce the accuracy of - the gradient pass. Between 2 and 3 iterations is a good compromise. + iter (int, optional): Number of subiterations (iterative refinement steps) for + healpy. Note that iterations increase the precision of the forward transform + as an inverse of inverse transform, but with a linear increase in + computational cost. Between 2 and 3 iterations is a good compromise. Returns: jnp.ndarray: Harmonic coefficients of signal f. @@ -328,28 +427,39 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda Astrophysical Journal 622.2 (2005): 759 """ - flm = jnp.array(healpy.map2alm(np.array(f), lmax=L - 1, iter=iter)) + flm = healpy_map2alm(f, L, nside) + for _ in range(iter): + f_recov = healpy_alm2map(flm, L, nside) + f_error = f - f_recov + flm += healpy_map2alm(f_error, L, nside) return reindex.flm_hp_to_2d_fast(flm, L) -def _healpy_forward_fwd(f: jnp.ndarray, L: int, nside: int, iter: int = 3): - """Private function which implements the forward pass for forward jax_healpy.""" - res = ([], L, nside, iter) - return healpy_forward(f, L, nside, iter), res +def healpy_inverse(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray: + r""" + Compute the inverse scalar real spherical harmonic transform (HEALPix JAX). + HEALPix is a C++ library which implements the scalar spherical harmonic transform + outlined in [1]. We make use of their healpy python bindings for which we provide + custom JAX frontends, hence providing support for automatic differentiation. + Currently these transforms can only be deployed on CPU, which is a limitation of the + C++ library. -def _healpy_forward_bwd(res, flm): - """Private function which implements the backward pass for forward jax_healpy.""" - _, L, nside, _ = res - flm_new = reindex.flm_2d_to_hp_fast(flm, L) - f = jnp.array( - np.conj(healpy.alm2map(np.conj(np.array(flm_new)), lmax=L - 1, nside=nside)) - ) - return f * (4 * jnp.pi) / (12 * nside**2), None, None, None + Args: + flm (jnp.ndarray): Spherical harmonic coefficients. + L (int): Harmonic band-limit. -# Link JAX gradients for C backend functions -ssht_inverse.defvjp(_ssht_inverse_fwd, _ssht_inverse_bwd) -ssht_forward.defvjp(_ssht_forward_fwd, _ssht_forward_bwd) -healpy_inverse.defvjp(_healpy_inverse_fwd, _healpy_inverse_bwd) -healpy_forward.defvjp(_healpy_forward_fwd, _healpy_forward_bwd) + nside (int): HEALPix Nside resolution parameter. + + Returns: + jnp.ndarray: Signal on the sphere. + + Note: + [1] Gorski, Krzysztof M., et al. "HEALPix: A framework for high-resolution + discretization and fast analysis of data distributed on the sphere." The + Astrophysical Journal 622.2 (2005): 759 + + """ + flm = reindex.flm_2d_to_hp_fast(flm, L) + return healpy_alm2map(flm, L, nside) diff --git a/tests/test_spherical_custom_grads.py b/tests/test_spherical_custom_grads.py index 15f0dcd5..c3fb3938 100644 --- a/tests/test_spherical_custom_grads.py +++ b/tests/test_spherical_custom_grads.py @@ -307,22 +307,16 @@ def func(f): @pytest.mark.parametrize("nside", nside_to_test) @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_healpix_c_backend_inverse_custom_gradients(flm_generator, nside: int): - sampling = "healpix" L = 2 * nside reality = True flm = flm_generator(L=L, reality=reality) - flm_target = flm_generator(L=L, reality=reality) - f_target = spherical.inverse_jax( - flm_target, L, nside=nside, sampling=sampling, reality=reality - ) def func(flm): - f = spherical.inverse( + return spherical.inverse( flm, L, 0, nside, sampling="healpix", method="jax_healpy", reality=True ) - return jnp.sum(jnp.abs(f - f_target) ** 2) - check_grads(func, (flm,), order=1, modes=("rev")) + check_grads(func, (flm,), order=2, modes=("fwd", "rev")) @pytest.mark.parametrize("nside", nside_to_test) @@ -334,16 +328,12 @@ def test_healpix_c_backend_forward_custom_gradients( sampling = "healpix" L = 2 * nside reality = True - flm_target = flm_generator(L=L, reality=reality) flm = flm_generator(L=L, reality=reality) f = spherical.inverse_jax(flm, L, nside=nside, sampling=sampling, reality=reality) def func(f): - flm = spherical.forward( + return spherical.forward( f, L, nside=nside, sampling="healpix", method="jax_healpy", iter=iter ) - return jnp.sum(jnp.abs(flm - flm_target) ** 2) - - rtol = [1e-6, 1e-2, 5e-2, 1e-2][iter] - check_grads(func, (f,), order=1, modes=("rev"), rtol=rtol) + check_grads(func, (f,), order=2, modes=("fwd", "rev"))