diff --git a/s2fft/base_transforms/spherical.py b/s2fft/base_transforms/spherical.py index e4070e3c..98024c1e 100644 --- a/s2fft/base_transforms/spherical.py +++ b/s2fft/base_transforms/spherical.py @@ -1,3 +1,4 @@ +from functools import partial from warnings import warn import numpy as np @@ -5,7 +6,7 @@ from s2fft import recursions from s2fft.sampling import s2_samples as samples from s2fft.utils import healpix_ffts as hp -from s2fft.utils import quadrature, resampling +from s2fft.utils import iterative_refinement, quadrature, resampling def inverse( @@ -138,6 +139,7 @@ def forward( nside: int = None, reality: bool = False, L_lower: int = 0, + iter: int = 0, ) -> np.ndarray: r""" Compute forward spherical harmonic transform. @@ -164,20 +166,34 @@ def forward( L_lower (int, optional): Harmonic lower-bound. Transform will only be computed for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. Defaults to 0. + iter (int, optional): Number of iterative refinement iterations to use to + improve accuracy of forward transform (as an inverse of inverse transform). + Primarily of use with HEALPix sampling for which there is not a sampling + theorem, and round-tripping through the forward and inverse transforms will + introduce an error. + Returns: np.ndarray: Spherical harmonic coefficients. """ - return _forward( - f, - L, - spin, - sampling, - nside=nside, - method="sov_fft_vectorized", - reality=reality, - L_lower=L_lower, - ) + common_kwargs = { + "L": L, + "spin": spin, + "sampling": sampling, + "nside": nside, + "method": "sov_fft_vectorized", + "reality": reality, + "L_lower": L_lower, + } + if iter == 0: + return _forward(f, **common_kwargs) + else: + return iterative_refinement.forward_with_iterative_refinement( + f, + n_iter=iter, + forward_function=partial(_forward, **common_kwargs), + backward_function=partial(_inverse, **common_kwargs), + ) def _forward( diff --git a/s2fft/precompute_transforms/spherical.py b/s2fft/precompute_transforms/spherical.py index c6fe4a0d..7a1a7726 100644 --- a/s2fft/precompute_transforms/spherical.py +++ b/s2fft/precompute_transforms/spherical.py @@ -1,4 +1,5 @@ from functools import partial +from typing import Optional from warnings import warn import jax.numpy as jnp @@ -6,20 +7,26 @@ import torch from jax import jit +from s2fft.precompute_transforms import construct from s2fft.sampling import s2_samples as samples from s2fft.utils import healpix_ffts as hp -from s2fft.utils import resampling, resampling_jax, resampling_torch +from s2fft.utils import ( + iterative_refinement, + resampling, + resampling_jax, + resampling_torch, +) def inverse( flm: np.ndarray, L: int, spin: int = 0, - kernel: np.ndarray = None, + kernel: Optional[np.ndarray] = None, sampling: str = "mw", reality: bool = False, method: str = "jax", - nside: int = None, + nside: Optional[int] = None, ) -> np.ndarray: r""" Compute the inverse spherical harmonic transform via precompute. @@ -55,6 +62,8 @@ def inverse( np.ndarray: Pixel-space coefficients with shape. """ + if method not in _inverse_functions: + raise ValueError(f"Method {method} not recognised.") if reality and spin != 0: reality = False warn( @@ -62,14 +71,19 @@ def inverse( + "Defering to complex transform.", stacklevel=2, ) - if method == "numpy": - return inverse_transform(flm, kernel, L, sampling, reality, spin, nside) - elif method == "jax": - return inverse_transform_jax(flm, kernel, L, sampling, reality, spin, nside) - elif method == "torch": - return inverse_transform_torch(flm, kernel, L, sampling, reality, spin, nside) - else: - raise ValueError(f"Method {method} not recognised.") + common_kwargs = { + "L": L, + "sampling": sampling, + "reality": reality, + "spin": spin, + "nside": nside, + } + kernel = ( + _kernel_functions[method](forward=False, **common_kwargs) + if kernel is None + else kernel + ) + return _inverse_functions[method](flm, kernel, **common_kwargs) def inverse_transform( @@ -290,11 +304,12 @@ def forward( f: np.ndarray, L: int, spin: int = 0, - kernel: np.ndarray = None, + kernel: Optional[np.ndarray] = None, sampling: str = "mw", reality: bool = False, method: str = "jax", - nside: int = None, + nside: Optional[int] = None, + iter: int = 0, ) -> np.ndarray: r""" Compute the forward spherical harmonic transform via precompute. @@ -321,6 +336,12 @@ def forward( nside (int): HEALPix Nside resolution parameter. Only required if sampling="healpix". + iter (int, optional): Number of iterative refinement iterations to use to + improve accuracy of forward transform (as an inverse of inverse transform). + Primarily of use with HEALPix sampling for which there is not a sampling + theorem, and round-tripping through the forward and inverse transforms will + introduce an error. + Raises: ValueError: Transform method not recognised. @@ -330,6 +351,8 @@ def forward( np.ndarray: Spherical harmonic coefficients. """ + if method not in _forward_functions: + raise ValueError(f"Method {method} not recognised.") if reality and spin != 0: reality = False warn( @@ -337,14 +360,32 @@ def forward( + "Defering to complex transform.", stacklevel=2, ) - if method == "numpy": - return forward_transform(f, kernel, L, sampling, reality, spin, nside) - elif method == "jax": - return forward_transform_jax(f, kernel, L, sampling, reality, spin, nside) - elif method == "torch": - return forward_transform_torch(f, kernel, L, sampling, reality, spin, nside) + common_kwargs = { + "L": L, + "sampling": sampling, + "reality": reality, + "spin": spin, + "nside": nside, + } + kernel = ( + _kernel_functions[method](forward=True, **common_kwargs) + if kernel is None + else kernel + ) + if iter == 0: + return _forward_functions[method](f, kernel, **common_kwargs) else: - raise ValueError(f"Method {method} not recognised.") + inverse_kernel = _kernel_functions[method](forward=False, **common_kwargs) + return iterative_refinement.forward_with_iterative_refinement( + f=f, + n_iter=iter, + forward_function=partial( + _forward_functions[method], kernel=kernel, **common_kwargs + ), + backward_function=partial( + _inverse_functions[method], kernel=inverse_kernel, **common_kwargs + ), + ) def forward_transform( @@ -567,3 +608,23 @@ def forward_transform_torch( ) return flm * (-1) ** spin + + +_inverse_functions = { + "numpy": inverse_transform, + "jax": inverse_transform_jax, + "torch": inverse_transform_torch, +} + + +_forward_functions = { + "numpy": forward_transform, + "jax": forward_transform_jax, + "torch": forward_transform_torch, +} + +_kernel_functions = { + "numpy": partial(construct.spin_spherical_kernel, using_torch=False), + "jax": construct.spin_spherical_kernel_jax, + "torch": partial(construct.spin_spherical_kernel, using_torch=True), +} diff --git a/s2fft/transforms/c_backend_spherical.py b/s2fft/transforms/c_backend_spherical.py index da6a9ed4..06d1da9d 100644 --- a/s2fft/transforms/c_backend_spherical.py +++ b/s2fft/transforms/c_backend_spherical.py @@ -1,3 +1,5 @@ +from functools import partial + import healpy import jax.numpy as jnp import numpy as np @@ -8,7 +10,7 @@ from jax.interpreters import ad from s2fft.sampling import reindex -from s2fft.utils import quadrature_jax +from s2fft.utils import iterative_refinement, quadrature_jax @custom_vjp @@ -427,11 +429,12 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda Astrophysical Journal 622.2 (2005): 759 """ - flm = healpy_map2alm(f, L, nside) - for _ in range(iter): - f_recov = healpy_alm2map(flm, L, nside) - f_error = f - f_recov - flm += healpy_map2alm(f_error, L, nside) + flm = iterative_refinement.forward_with_iterative_refinement( + f=f, + n_iter=iter, + forward_function=partial(healpy_map2alm, L=L, nside=nside), + backward_function=partial(healpy_alm2map, L=L, nside=nside), + ) return reindex.flm_hp_to_2d_fast(flm, L) diff --git a/s2fft/transforms/otf_recursions.py b/s2fft/transforms/otf_recursions.py index 50718eb2..7031d04b 100644 --- a/s2fft/transforms/otf_recursions.py +++ b/s2fft/transforms/otf_recursions.py @@ -83,6 +83,9 @@ def inverse_latitudinal_step( precomps = generate_precomputes(L, -mm, sampling, nside, L_lower) lrenorm, vsign, cpi, cp2, indices = precomps + # Create copy to prevent in-place updates propagating to caller + lrenorm = lrenorm.copy() + for i in range(2): if not (reality and i == 0): m_offset = 1 if sampling in ["mwss", "healpix"] and i == 0 else 0 @@ -490,6 +493,9 @@ def forward_latitudinal_step( precomps = generate_precomputes(L, -mm, sampling, nside, True, L_lower) lrenorm, vsign, cpi, cp2, indices = precomps + # Create copy to prevent in-place updates propagating to caller + lrenorm = lrenorm.copy() + for i in range(2): if not (reality and i == 0): m_offset = 1 if sampling in ["mwss", "healpix"] and i == 0 else 0 diff --git a/s2fft/transforms/spherical.py b/s2fft/transforms/spherical.py index 1eb138ac..52f1db61 100644 --- a/s2fft/transforms/spherical.py +++ b/s2fft/transforms/spherical.py @@ -1,5 +1,5 @@ from functools import partial -from typing import List +from typing import List, Optional import jax.numpy as jnp import numpy as np @@ -10,6 +10,7 @@ from s2fft.transforms import otf_recursions as otf from s2fft.utils import healpix_ffts as hp from s2fft.utils import ( + iterative_refinement, quadrature, quadrature_jax, resampling, @@ -155,6 +156,9 @@ def inverse_numpy( m_start_ind = L - 1 if reality else 0 L0 = L_lower + # Copy flm argument to avoid in-place updates being propagated back to caller + flm = flm.copy() + # Apply harmonic normalisation flm[L0:] = np.einsum( "lm,l->lm", flm[L0:], np.sqrt((2 * np.arange(L0, L) + 1) / (4 * np.pi)) @@ -333,14 +337,14 @@ def forward( f: np.ndarray, L: int, spin: int = 0, - nside: int = None, + nside: Optional[int] = None, sampling: str = "mw", method: str = "numpy", reality: bool = False, - precomps: List = None, + precomps: Optional[List] = None, spmd: bool = False, L_lower: int = 0, - iter: int = 3, + iter: Optional[int] = None, _ssht_backend: int = 1, ) -> np.ndarray: r""" @@ -376,9 +380,13 @@ def forward( L_lower (int, optional): Harmonic lower-bound. Transform will only be computed for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. Defaults to 0. - iter (int, optional): Number of subiterations for healpy. Note that iterations - increase the precision of the forward transform, but reduce the accuracy of - the gradient pass. Between 2 and 3 iterations is a good compromise. + iter (int, optional): Number of iterative refinement iterations to use to + improve accuracy of forward transform (as an inverse of inverse transform). + Primarily of use with HEALPix sampling for which there is not a sampling + theorem, and round-tripping through the forward and inverse transforms will + introduce an error. If set to `None`, the default, 3 iterations will be used + if :code:`sampling == "healpix"` and :code:`method == "jax_healpy"` and zero + otherwise. Not used for `jax_ssht` method. _ssht_backend (int, optional, experimental): Whether to default to SSHT core (set to 0) recursions or pick up ducc0 (set to 1) accelerated experimental @@ -401,35 +409,34 @@ def forward( if spin >= 8 and method in ["numpy", "jax"]: raise Warning("Recursive transform may provide lower precision beyond spin ~ 8") - if method == "numpy": - return forward_numpy(f, L, spin, nside, sampling, reality, precomps, L_lower) - elif method == "jax": - return forward_jax( - f, - L, - spin, - nside, - sampling, - reality, - precomps, - spmd, - L_lower, - use_healpix_custom_primitive=False, - ) - elif method == "cuda": - return forward_jax( - f, - L, - spin, - nside, - sampling, - reality, - precomps, - spmd, - L_lower, - use_healpix_custom_primitive=True, + if iter is None: + iter = 3 if sampling.lower() == "healpix" and method == "jax_healpy" else 0 + if method in {"numpy", "jax", "cuda"}: + common_kwargs = { + "L": L, + "spin": spin, + "nside": nside, + "sampling": sampling, + "reality": reality, + "L_lower": L_lower, + } + forward_kwargs = {**common_kwargs, "precomps": precomps} + inverse_kwargs = common_kwargs + if method in {"jax", "cuda"}: + forward_kwargs["spmd"] = spmd + forward_kwargs["use_healpix_custom_primitive"] = method == "cuda" + inverse_kwargs["method"] = "jax" + inverse_kwargs["spmd"] = spmd + forward_function = forward_jax + else: + inverse_kwargs["method"] = "numpy" + forward_function = forward_numpy + return iterative_refinement.forward_with_iterative_refinement( + f=f, + n_iter=iter, + forward_function=partial(forward_function, **forward_kwargs), + backward_function=partial(inverse, **inverse_kwargs), ) - elif method == "jax_ssht": if sampling.lower() == "healpix": raise ValueError("SSHT does not support healpix sampling.") diff --git a/s2fft/utils/__init__.py b/s2fft/utils/__init__.py index e71a7bc7..fd5b9504 100644 --- a/s2fft/utils/__init__.py +++ b/s2fft/utils/__init__.py @@ -1,5 +1,6 @@ from . import ( healpix_ffts, + iterative_refinement, jax_primitive, quadrature, quadrature_jax, diff --git a/s2fft/utils/iterative_refinement.py b/s2fft/utils/iterative_refinement.py new file mode 100644 index 00000000..c8d87f66 --- /dev/null +++ b/s2fft/utils/iterative_refinement.py @@ -0,0 +1,45 @@ +"""Iterative scheme for improving accuracy of linear transforms.""" + +from typing import Callable, TypeVar + +T = TypeVar("T") + + +def forward_with_iterative_refinement( + f: T, + n_iter: int, + forward_function: Callable[[T], T], + backward_function: Callable[[T], T], +) -> T: + """ + Apply forward transform with iterative refinement to improve accuracy. + + `Iterative refinement `_ is a + general approach for improving the accuracy of numerial solutions to linear systems. + In the context of spherical harmonic transforms, given a forward transform which is + an _approximate_ inverse to a corresponding backward ('inverse') transform, + iterative refinement allows defining an iterative forward transform which is a more + accurate + + Args: + f: Array argument to forward transform (signal on sphere) to compute iteratively + refined forward transform at. + + n_iter: Number of refinement iterations to use, non-negative. + + forward_function: Function computing forward transform (approximate inverse of + backward transform). + + backward_function: Function computing backward ('inverse') transform. + + Returns: + Array output from iteratively refined forward transform (spherical harmonic + coefficients). + + """ + flm = forward_function(f) + for _ in range(n_iter): + f_recov = backward_function(flm) + f_error = f - f_recov + flm += forward_function(f_error) + return flm diff --git a/tests/test_spherical_base.py b/tests/test_spherical_base.py index 7499d0b9..ffcb6a08 100644 --- a/tests/test_spherical_base.py +++ b/tests/test_spherical_base.py @@ -14,6 +14,7 @@ sampling_to_test = ["mw", "mwss", "dh", "gl"] method_to_test = ["direct", "sov", "sov_fft", "sov_fft_vectorized"] reality_to_test = [False, True] +iter_to_test = [7, 10] @pytest.mark.parametrize("L", L_to_test) @@ -131,6 +132,32 @@ def test_transform_forward_healpix( np.testing.assert_allclose(flm_direct_hp, flm_check, atol=1e-14) +@pytest.mark.parametrize("nside", nside_to_test) +@pytest.mark.parametrize("reality", reality_to_test) +@pytest.mark.parametrize("iter", iter_to_test) +def test_transform_forward_healpix_iter( + flm_generator, nside: int, reality: bool, iter: int +): + sampling = "healpix" + L = 2 * nside + flm = flm_generator(L=L, reality=True) + f = spherical.inverse(flm, L, sampling=sampling, nside=nside, reality=reality) + flm_direct = spherical.forward( + f, + L, + sampling=sampling, + nside=nside, + reality=reality, + iter=iter, + ) + # With iter >> 0 round-trip error should be small + np.testing.assert_allclose(flm_direct, flm, atol=1e-14) + # Also check for consistency with healpy with iter > 0 + flm_direct_hp = samples.flm_2d_to_hp(flm_direct, L) + flm_check = hp.sphtfunc.map2alm(np.real(f), lmax=L - 1, iter=iter) + np.testing.assert_allclose(flm_direct_hp, flm_check, atol=1e-14) + + @pytest.mark.parametrize("nside", nside_to_test) def test_healpix_nside_to_L_exceptions(flm_generator, nside: int): sampling = "healpix" diff --git a/tests/test_spherical_custom_grads.py b/tests/test_spherical_custom_grads.py index c3fb3938..a42a1f1a 100644 --- a/tests/test_spherical_custom_grads.py +++ b/tests/test_spherical_custom_grads.py @@ -161,6 +161,7 @@ def func(flm): @pytest.mark.parametrize("L_lower", L_lower_to_test) @pytest.mark.parametrize("spin", spin_to_test) @pytest.mark.parametrize("reality", reality_to_test) +@pytest.mark.parametrize("iter", [0, 1, 2, 3]) @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_healpix_forward_custom_gradients( flm_generator, @@ -168,6 +169,7 @@ def test_healpix_forward_custom_gradients( L_lower: int, spin: int, reality: bool, + iter: int, ): sampling = "healpix" L = 2 * nside @@ -191,15 +193,17 @@ def test_healpix_forward_custom_gradients( ) def func(f): - flm = spherical.forward_jax( + flm = spherical.forward( f, L, + method="jax", spin=spin, nside=nside, L_lower=L_lower, reality=reality, precomps=precomps, sampling=sampling, + iter=iter, ) return jnp.sum(jnp.abs(flm - flm_target) ** 2) diff --git a/tests/test_spherical_precompute.py b/tests/test_spherical_precompute.py index 629c4aa9..3f50e7c5 100644 --- a/tests/test_spherical_precompute.py +++ b/tests/test_spherical_precompute.py @@ -1,3 +1,4 @@ +import jax import numpy as np import pyssht as ssht import pytest @@ -8,6 +9,8 @@ from s2fft.precompute_transforms.spherical import forward, inverse from s2fft.sampling import s2_samples as samples +jax.config.update("jax_enable_x64", True) + # Maximum spin number at which Price-McEwen recursion is sufficiently accurate. # For spins > PM_MAX_STABLE_SPIN one should default to the Risbo recursion. PM_MAX_STABLE_SPIN = 6 @@ -20,6 +23,7 @@ reality_to_test = [True, False] methods_to_test = ["numpy", "jax", "torch"] recursions_to_test = ["price-mcewen", "risbo", "auto"] +iter_to_test = [0, 3] @pytest.mark.parametrize("L", L_to_test) @@ -221,6 +225,7 @@ def test_transform_forward( @pytest.mark.parametrize("reality", reality_to_test) @pytest.mark.parametrize("method", methods_to_test) @pytest.mark.parametrize("recursion", recursions_to_test) +@pytest.mark.parametrize("iter", iter_to_test) def test_transform_forward_healpix( flm_generator, nside: int, @@ -228,12 +233,13 @@ def test_transform_forward_healpix( reality: bool, method: str, recursion: str, + iter: int, ): sampling = "healpix" L = ratio * nside flm = flm_generator(L=L, reality=True) f = base.inverse(flm, L, 0, sampling, nside, reality) - flm_check = base.forward(f, L, 0, sampling, nside, reality) + flm_check = base.forward(f, L, 0, sampling, nside, reality, iter=iter) kfunc = ( c.spin_spherical_kernel_jax @@ -254,6 +260,7 @@ def test_transform_forward_healpix( reality, method, nside, + iter, ) np.testing.assert_allclose(flm_recov, flm_check, atol=tol, rtol=tol) @@ -271,10 +278,11 @@ def test_transform_forward_healpix( reality, method, nside, + iter, ), ) else: - flm_recov = forward(f, L, 0, kernel, sampling, reality, method, nside) + flm_recov = forward(f, L, 0, kernel, sampling, reality, method, nside, iter) np.testing.assert_allclose(flm_recov, flm_check, atol=tol, rtol=tol) @@ -324,3 +332,19 @@ def test_transform_forward_high_spin( flm_recov = forward(f, L, spin, kernel, sampling, reality, "numpy") tol = 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-12 np.testing.assert_allclose(flm_recov, flm, atol=tol, rtol=tol) + + +def test_forward_transform_unrecognised_method_raises(): + method = "invalid_method" + L = 32 + f = np.zeros(samples.f_shape(L)) + with pytest.raises(ValueError, match=f"{method} not recognised"): + forward(f, L, method=method) + + +def test_inverse_transform_unrecognised_method_raises(): + method = "invalid_method" + L = 32 + flm = np.zeros(samples.flm_shape(L)) + with pytest.raises(ValueError, match=f"{method} not recognised"): + inverse(flm, L, method=method) diff --git a/tests/test_spherical_transform.py b/tests/test_spherical_transform.py index a4129fba..2183a51b 100644 --- a/tests/test_spherical_transform.py +++ b/tests/test_spherical_transform.py @@ -150,12 +150,14 @@ def test_transform_forward( @pytest.mark.parametrize("nside", nside_to_test) @pytest.mark.parametrize("method", method_to_test) @pytest.mark.parametrize("spmd", multiple_gpus) +@pytest.mark.parametrize("iter", [0, 1, 2, 3]) @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_transform_forward_healpix( flm_generator, nside: int, method: str, spmd: bool, + iter: int, ): sampling = "healpix" L = 2 * nside @@ -174,10 +176,11 @@ def test_transform_forward_healpix( reality=True, precomps=precomps, spmd=spmd, + iter=iter, ) flm_check = samples.flm_2d_to_hp(flm_check, L) - flm = hp.sphtfunc.map2alm(f, lmax=L - 1, iter=0) + flm = hp.sphtfunc.map2alm(f, lmax=L - 1, iter=iter) np.testing.assert_allclose(flm, flm_check, atol=1e-14)