From ba07b4da854ce1936a5aafc001e6293acb4e0990 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 13 Nov 2024 11:48:12 +0000 Subject: [PATCH 01/19] Deduplicate dispatching logic in forward spherical transform --- s2fft/transforms/spherical.py | 45 +++++++++++++---------------------- 1 file changed, 16 insertions(+), 29 deletions(-) diff --git a/s2fft/transforms/spherical.py b/s2fft/transforms/spherical.py index 1eb138ac..23064f3d 100644 --- a/s2fft/transforms/spherical.py +++ b/s2fft/transforms/spherical.py @@ -401,35 +401,22 @@ def forward( if spin >= 8 and method in ["numpy", "jax"]: raise Warning("Recursive transform may provide lower precision beyond spin ~ 8") - if method == "numpy": - return forward_numpy(f, L, spin, nside, sampling, reality, precomps, L_lower) - elif method == "jax": - return forward_jax( - f, - L, - spin, - nside, - sampling, - reality, - precomps, - spmd, - L_lower, - use_healpix_custom_primitive=False, - ) - elif method == "cuda": - return forward_jax( - f, - L, - spin, - nside, - sampling, - reality, - precomps, - spmd, - L_lower, - use_healpix_custom_primitive=True, - ) - + if method in {"numpy", "jax", "cuda"}: + kwargs = { + "f": f, + "L": L, + "spin": spin, + "nside": nside, + "sampling": sampling, + "reality": reality, + "precomps": precomps, + "L_lower": L_lower, + } + if method in {"jax", "cuda"}: + kwargs["spmd"] = spmd + kwargs["use_healpix_custom_primitive"] = method == "cuda" + forward_function = forward_numpy if method == "numpy" else forward_jax + return forward_function(**kwargs) elif method == "jax_ssht": if sampling.lower() == "healpix": raise ValueError("SSHT does not support healpix sampling.") From 243b75e5c555be9723cfd05c3585ac7a318036c0 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 13 Nov 2024 14:00:56 +0000 Subject: [PATCH 02/19] Prevent inverse_numpy updating flm arg in-place --- s2fft/transforms/spherical.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/s2fft/transforms/spherical.py b/s2fft/transforms/spherical.py index 23064f3d..987c1fe5 100644 --- a/s2fft/transforms/spherical.py +++ b/s2fft/transforms/spherical.py @@ -155,6 +155,9 @@ def inverse_numpy( m_start_ind = L - 1 if reality else 0 L0 = L_lower + # Copy flm argument to avoid in-place updates being propagated back to caller + flm = flm.copy() + # Apply harmonic normalisation flm[L0:] = np.einsum( "lm,l->lm", flm[L0:], np.sqrt((2 * np.arange(L0, L) + 1) / (4 * np.pi)) From 1c6e4a193dddfd04c47ac1ce8188b05e77b81ae8 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 13 Nov 2024 15:47:15 +0000 Subject: [PATCH 03/19] Iterative refinement support for jax and numpy forward spherical --- s2fft/transforms/spherical.py | 42 +++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/s2fft/transforms/spherical.py b/s2fft/transforms/spherical.py index 987c1fe5..1323912c 100644 --- a/s2fft/transforms/spherical.py +++ b/s2fft/transforms/spherical.py @@ -336,14 +336,14 @@ def forward( f: np.ndarray, L: int, spin: int = 0, - nside: int = None, + nside: int | None = None, sampling: str = "mw", method: str = "numpy", reality: bool = False, - precomps: List = None, + precomps: List | None = None, spmd: bool = False, L_lower: int = 0, - iter: int = 3, + iter: int | None = None, _ssht_backend: int = 1, ) -> np.ndarray: r""" @@ -379,9 +379,13 @@ def forward( L_lower (int, optional): Harmonic lower-bound. Transform will only be computed for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. Defaults to 0. - iter (int, optional): Number of subiterations for healpy. Note that iterations - increase the precision of the forward transform, but reduce the accuracy of - the gradient pass. Between 2 and 3 iterations is a good compromise. + iter (int, optional): Number of iterative refinement iterations to use to + improve accuracy of forward transform (as an inverse of inverse transform). + Primarily of use with HEALPix sampling for which there is not a sampling + theorem, and round-tripping through the forward and inverse transforms will + introduce an error. If set to `None`, the default, 3 iterations will be used + if :code:`sampling == "healpix"` and :code`method == "jax_healpy"` and zero + otherwise. Not used for `jax_ssht` method. _ssht_backend (int, optional, experimental): Whether to default to SSHT core (set to 0) recursions or pick up ducc0 (set to 1) accelerated experimental @@ -404,9 +408,10 @@ def forward( if spin >= 8 and method in ["numpy", "jax"]: raise Warning("Recursive transform may provide lower precision beyond spin ~ 8") + if iter is None: + iter = 3 if sampling.lower() == "healpix" and method == "jax_healpy" else 0 if method in {"numpy", "jax", "cuda"}: - kwargs = { - "f": f, + common_kwargs = { "L": L, "spin": spin, "nside": nside, @@ -416,10 +421,23 @@ def forward( "L_lower": L_lower, } if method in {"jax", "cuda"}: - kwargs["spmd"] = spmd - kwargs["use_healpix_custom_primitive"] = method == "cuda" - forward_function = forward_numpy if method == "numpy" else forward_jax - return forward_function(**kwargs) + forward_kwargs = { + **common_kwargs, + "spmd": spmd, + "use_healpix_custom_primitive": method == "cuda", + } + inverse_kwargs = {**common_kwargs, "method": "jax"} + forward_function = forward_jax + else: + forward_kwargs = common_kwargs + inverse_kwargs = {**common_kwargs, "method": "numpy"} + forward_function = forward_numpy + flm = forward_function(f, **forward_kwargs) + for _ in range(iter): + f_recov = inverse(flm, **inverse_kwargs) + f_error = f - f_recov + flm += forward_function(f_error, **forward_kwargs) + return flm elif method == "jax_ssht": if sampling.lower() == "healpix": raise ValueError("SSHT does not support healpix sampling.") From acd6d8e5acccfba95611367424b5c5c5636ffab6 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 13 Nov 2024 18:03:21 +0000 Subject: [PATCH 04/19] Copy array in precomps to avoid inplace update --- s2fft/transforms/otf_recursions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/s2fft/transforms/otf_recursions.py b/s2fft/transforms/otf_recursions.py index 50718eb2..eff408ee 100644 --- a/s2fft/transforms/otf_recursions.py +++ b/s2fft/transforms/otf_recursions.py @@ -82,6 +82,7 @@ def inverse_latitudinal_step( if precomps is None: precomps = generate_precomputes(L, -mm, sampling, nside, L_lower) lrenorm, vsign, cpi, cp2, indices = precomps + lrenorm = lrenorm.copy() for i in range(2): if not (reality and i == 0): @@ -489,6 +490,7 @@ def forward_latitudinal_step( if precomps is None: precomps = generate_precomputes(L, -mm, sampling, nside, True, L_lower) lrenorm, vsign, cpi, cp2, indices = precomps + lrenorm = lrenorm.copy() for i in range(2): if not (reality and i == 0): From badf99e550d985961e5244596902b547383ab3d8 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 13 Nov 2024 18:03:36 +0000 Subject: [PATCH 05/19] Test across different iter values --- tests/test_spherical_custom_grads.py | 6 +++++- tests/test_spherical_transform.py | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/test_spherical_custom_grads.py b/tests/test_spherical_custom_grads.py index c3fb3938..a42a1f1a 100644 --- a/tests/test_spherical_custom_grads.py +++ b/tests/test_spherical_custom_grads.py @@ -161,6 +161,7 @@ def func(flm): @pytest.mark.parametrize("L_lower", L_lower_to_test) @pytest.mark.parametrize("spin", spin_to_test) @pytest.mark.parametrize("reality", reality_to_test) +@pytest.mark.parametrize("iter", [0, 1, 2, 3]) @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_healpix_forward_custom_gradients( flm_generator, @@ -168,6 +169,7 @@ def test_healpix_forward_custom_gradients( L_lower: int, spin: int, reality: bool, + iter: int, ): sampling = "healpix" L = 2 * nside @@ -191,15 +193,17 @@ def test_healpix_forward_custom_gradients( ) def func(f): - flm = spherical.forward_jax( + flm = spherical.forward( f, L, + method="jax", spin=spin, nside=nside, L_lower=L_lower, reality=reality, precomps=precomps, sampling=sampling, + iter=iter, ) return jnp.sum(jnp.abs(flm - flm_target) ** 2) diff --git a/tests/test_spherical_transform.py b/tests/test_spherical_transform.py index a4129fba..2183a51b 100644 --- a/tests/test_spherical_transform.py +++ b/tests/test_spherical_transform.py @@ -150,12 +150,14 @@ def test_transform_forward( @pytest.mark.parametrize("nside", nside_to_test) @pytest.mark.parametrize("method", method_to_test) @pytest.mark.parametrize("spmd", multiple_gpus) +@pytest.mark.parametrize("iter", [0, 1, 2, 3]) @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_transform_forward_healpix( flm_generator, nside: int, method: str, spmd: bool, + iter: int, ): sampling = "healpix" L = 2 * nside @@ -174,10 +176,11 @@ def test_transform_forward_healpix( reality=True, precomps=precomps, spmd=spmd, + iter=iter, ) flm_check = samples.flm_2d_to_hp(flm_check, L) - flm = hp.sphtfunc.map2alm(f, lmax=L - 1, iter=0) + flm = hp.sphtfunc.map2alm(f, lmax=L - 1, iter=iter) np.testing.assert_allclose(flm, flm_check, atol=1e-14) From cedf1f57b71c995172d5a95685025d1813648d37 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 14 Nov 2024 08:42:15 +0000 Subject: [PATCH 06/19] Use Optional instead of | None for Python 3.9 compat --- s2fft/transforms/spherical.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/s2fft/transforms/spherical.py b/s2fft/transforms/spherical.py index 1323912c..1944c540 100644 --- a/s2fft/transforms/spherical.py +++ b/s2fft/transforms/spherical.py @@ -1,5 +1,5 @@ from functools import partial -from typing import List +from typing import List, Optional import jax.numpy as jnp import numpy as np @@ -336,14 +336,14 @@ def forward( f: np.ndarray, L: int, spin: int = 0, - nside: int | None = None, + nside: Optional[int] = None, sampling: str = "mw", method: str = "numpy", reality: bool = False, - precomps: List | None = None, + precomps: Optional[List] = None, spmd: bool = False, L_lower: int = 0, - iter: int | None = None, + iter: Optional[int] = None, _ssht_backend: int = 1, ) -> np.ndarray: r""" From 936902726a0b10f68cf212cdf6e28bdf700bb612 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 14 Nov 2024 12:01:29 +0000 Subject: [PATCH 07/19] Add note about healpy wrapper gradient behaviour --- s2fft/transforms/spherical.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/s2fft/transforms/spherical.py b/s2fft/transforms/spherical.py index 1944c540..d222397d 100644 --- a/s2fft/transforms/spherical.py +++ b/s2fft/transforms/spherical.py @@ -384,8 +384,11 @@ def forward( Primarily of use with HEALPix sampling for which there is not a sampling theorem, and round-tripping through the forward and inverse transforms will introduce an error. If set to `None`, the default, 3 iterations will be used - if :code:`sampling == "healpix"` and :code`method == "jax_healpy"` and zero - otherwise. Not used for `jax_ssht` method. + if :code:`sampling == "healpix"` and :code:`method == "jax_healpy"` and zero + otherwise. For the `healpy` wrappers specifically (that is when + :code:`method == "jax_healpy"`) increasing the number iterations increases + the accuracy of the forward transform, but reduce the accuracy of the + gradient pass. Not used for `jax_ssht` method. _ssht_backend (int, optional, experimental): Whether to default to SSHT core (set to 0) recursions or pick up ducc0 (set to 1) accelerated experimental From 5c472694f9f4610f5507ff09ae1e102ceb338d50 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 14 Nov 2024 12:01:57 +0000 Subject: [PATCH 08/19] Add comment to explain copy operations --- s2fft/transforms/otf_recursions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/s2fft/transforms/otf_recursions.py b/s2fft/transforms/otf_recursions.py index eff408ee..7031d04b 100644 --- a/s2fft/transforms/otf_recursions.py +++ b/s2fft/transforms/otf_recursions.py @@ -82,6 +82,8 @@ def inverse_latitudinal_step( if precomps is None: precomps = generate_precomputes(L, -mm, sampling, nside, L_lower) lrenorm, vsign, cpi, cp2, indices = precomps + + # Create copy to prevent in-place updates propagating to caller lrenorm = lrenorm.copy() for i in range(2): @@ -490,6 +492,8 @@ def forward_latitudinal_step( if precomps is None: precomps = generate_precomputes(L, -mm, sampling, nside, True, L_lower) lrenorm, vsign, cpi, cp2, indices = precomps + + # Create copy to prevent in-place updates propagating to caller lrenorm = lrenorm.copy() for i in range(2): From 85645ae7e1e375a654622147688e1b7182578f5c Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Tue, 26 Nov 2024 14:59:06 +0000 Subject: [PATCH 09/19] Factor out iterative refinement function --- s2fft/transforms/spherical.py | 13 ++++---- s2fft/utils/__init__.py | 1 + s2fft/utils/iterative_refinement.py | 46 +++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 6 deletions(-) create mode 100644 s2fft/utils/iterative_refinement.py diff --git a/s2fft/transforms/spherical.py b/s2fft/transforms/spherical.py index d222397d..3f711e40 100644 --- a/s2fft/transforms/spherical.py +++ b/s2fft/transforms/spherical.py @@ -10,6 +10,7 @@ from s2fft.transforms import otf_recursions as otf from s2fft.utils import healpix_ffts as hp from s2fft.utils import ( + iterative_refinement, quadrature, quadrature_jax, resampling, @@ -435,12 +436,12 @@ def forward( forward_kwargs = common_kwargs inverse_kwargs = {**common_kwargs, "method": "numpy"} forward_function = forward_numpy - flm = forward_function(f, **forward_kwargs) - for _ in range(iter): - f_recov = inverse(flm, **inverse_kwargs) - f_error = f - f_recov - flm += forward_function(f_error, **forward_kwargs) - return flm + return iterative_refinement.forward_with_iterative_refinement( + f=f, + n_iter=iter, + forward_function=partial(forward_function, **forward_kwargs), + backward_function=partial(inverse, **inverse_kwargs), + ) elif method == "jax_ssht": if sampling.lower() == "healpix": raise ValueError("SSHT does not support healpix sampling.") diff --git a/s2fft/utils/__init__.py b/s2fft/utils/__init__.py index e71a7bc7..fd5b9504 100644 --- a/s2fft/utils/__init__.py +++ b/s2fft/utils/__init__.py @@ -1,5 +1,6 @@ from . import ( healpix_ffts, + iterative_refinement, jax_primitive, quadrature, quadrature_jax, diff --git a/s2fft/utils/iterative_refinement.py b/s2fft/utils/iterative_refinement.py new file mode 100644 index 00000000..5ae19bc0 --- /dev/null +++ b/s2fft/utils/iterative_refinement.py @@ -0,0 +1,46 @@ +"""Iterative scheme for improving accuracy of linear transforms.""" + +from collections.abc import Callable +from typing import TypeVar + +T = TypeVar("T") + + +def forward_with_iterative_refinement( + f: T, + n_iter: int, + forward_function: Callable[[T], T], + backward_function: Callable[[T], T], +) -> T: + """ + Apply forward transform with iterative refinement to improve accuracy. + + `Iterative refinement `_ is a + general approach for improving the accuracy of numerial solutions to linear systems. + In the context of spherical harmonic transforms, given a forward transform which is + an _approximate_ inverse to a corresponding backward ('inverse') transform, + iterative refinement allows defining an iterative forward transform which is a more + accurate + + Args: + f: Array argument to forward transform (signal on sphere) to compute iteratively + refined forward transform at. + + n_iter: Number of refinement iterations to use, non-negative. + + forward_function: Function computing forward transform (approximate inverse of + backward transform). + + backward_function: Function computing backward ('inverse') transform. + + Returns: + Array output from iteratively refined forward transform (spherical harmonic + coefficients). + + """ + flm = forward_function(f) + for _ in range(n_iter): + f_recov = backward_function(flm) + f_error = f - f_recov + flm += forward_function(f_error) + return flm From 972e6474b1d37b2c8b66ce5bbb855936ae0730f0 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Tue, 26 Nov 2024 16:18:50 +0000 Subject: [PATCH 10/19] Use factored out function in healpy wrapper --- s2fft/transforms/c_backend_spherical.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/s2fft/transforms/c_backend_spherical.py b/s2fft/transforms/c_backend_spherical.py index da6a9ed4..06d1da9d 100644 --- a/s2fft/transforms/c_backend_spherical.py +++ b/s2fft/transforms/c_backend_spherical.py @@ -1,3 +1,5 @@ +from functools import partial + import healpy import jax.numpy as jnp import numpy as np @@ -8,7 +10,7 @@ from jax.interpreters import ad from s2fft.sampling import reindex -from s2fft.utils import quadrature_jax +from s2fft.utils import iterative_refinement, quadrature_jax @custom_vjp @@ -427,11 +429,12 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda Astrophysical Journal 622.2 (2005): 759 """ - flm = healpy_map2alm(f, L, nside) - for _ in range(iter): - f_recov = healpy_alm2map(flm, L, nside) - f_error = f - f_recov - flm += healpy_map2alm(f_error, L, nside) + flm = iterative_refinement.forward_with_iterative_refinement( + f=f, + n_iter=iter, + forward_function=partial(healpy_map2alm, L=L, nside=nside), + backward_function=partial(healpy_alm2map, L=L, nside=nside), + ) return reindex.flm_hp_to_2d_fast(flm, L) From a2abd7c1e855a741c44ad6dceaa83711ab4418a8 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 4 Dec 2024 17:01:39 +0000 Subject: [PATCH 11/19] Using deprecated typing.Callable to maintain py3.8 compat --- s2fft/utils/iterative_refinement.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/s2fft/utils/iterative_refinement.py b/s2fft/utils/iterative_refinement.py index 5ae19bc0..c8d87f66 100644 --- a/s2fft/utils/iterative_refinement.py +++ b/s2fft/utils/iterative_refinement.py @@ -1,7 +1,6 @@ """Iterative scheme for improving accuracy of linear transforms.""" -from collections.abc import Callable -from typing import TypeVar +from typing import Callable, TypeVar T = TypeVar("T") From f50eaeec0cce6478fee5d0760cc506222ad9cbcf Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Fri, 6 Dec 2024 19:52:43 +0000 Subject: [PATCH 12/19] Only pass precomps to forward transform when iterating --- s2fft/transforms/spherical.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/s2fft/transforms/spherical.py b/s2fft/transforms/spherical.py index 3f711e40..cf18c55c 100644 --- a/s2fft/transforms/spherical.py +++ b/s2fft/transforms/spherical.py @@ -421,20 +421,18 @@ def forward( "nside": nside, "sampling": sampling, "reality": reality, - "precomps": precomps, "L_lower": L_lower, } + forward_kwargs = {**common_kwargs, "precomps": precomps} + inverse_kwargs = common_kwargs if method in {"jax", "cuda"}: - forward_kwargs = { - **common_kwargs, - "spmd": spmd, - "use_healpix_custom_primitive": method == "cuda", - } - inverse_kwargs = {**common_kwargs, "method": "jax"} + forward_kwargs["spmd"] = spmd + forward_kwargs["use_healpix_custom_primitive"] = method == "cuda" + inverse_kwargs["method"] = "jax" + inverse_kwargs["spmd"] = spmd forward_function = forward_jax else: - forward_kwargs = common_kwargs - inverse_kwargs = {**common_kwargs, "method": "numpy"} + inverse_kwargs["method"] = "numpy" forward_function = forward_numpy return iterative_refinement.forward_with_iterative_refinement( f=f, From b2881672df6cf923f98b830ed305541e8a916853 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Fri, 6 Dec 2024 20:05:49 +0000 Subject: [PATCH 13/19] Deduplicate precompute transform wrappers --- s2fft/precompute_transforms/spherical.py | 28 ++++++++++++------------ 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/s2fft/precompute_transforms/spherical.py b/s2fft/precompute_transforms/spherical.py index c6fe4a0d..0c85cb8b 100644 --- a/s2fft/precompute_transforms/spherical.py +++ b/s2fft/precompute_transforms/spherical.py @@ -62,14 +62,14 @@ def inverse( + "Defering to complex transform.", stacklevel=2, ) - if method == "numpy": - return inverse_transform(flm, kernel, L, sampling, reality, spin, nside) - elif method == "jax": - return inverse_transform_jax(flm, kernel, L, sampling, reality, spin, nside) - elif method == "torch": - return inverse_transform_torch(flm, kernel, L, sampling, reality, spin, nside) - else: + inverse_functions = { + "numpy": inverse_transform, + "jax": inverse_transform_jax, + "torch": inverse_transform_torch, + } + if method not in inverse_functions: raise ValueError(f"Method {method} not recognised.") + return inverse_functions[method](flm, kernel, L, sampling, reality, spin, nside) def inverse_transform( @@ -337,14 +337,14 @@ def forward( + "Defering to complex transform.", stacklevel=2, ) - if method == "numpy": - return forward_transform(f, kernel, L, sampling, reality, spin, nside) - elif method == "jax": - return forward_transform_jax(f, kernel, L, sampling, reality, spin, nside) - elif method == "torch": - return forward_transform_torch(f, kernel, L, sampling, reality, spin, nside) - else: + forward_functions = { + "numpy": forward_transform, + "jax": forward_transform_jax, + "torch": forward_transform_torch, + } + if method not in forward_functions: raise ValueError(f"Method {method} not recognised.") + return forward_functions[method](f, kernel, L, sampling, reality, spin, nside) def forward_transform( From d5fd078184cfe2c8c98e9b16b4b65eaf54468756 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Fri, 6 Dec 2024 20:10:05 +0000 Subject: [PATCH 14/19] Remove note about jax_healpy iter gradient accuracy --- s2fft/transforms/spherical.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/s2fft/transforms/spherical.py b/s2fft/transforms/spherical.py index cf18c55c..52f1db61 100644 --- a/s2fft/transforms/spherical.py +++ b/s2fft/transforms/spherical.py @@ -386,10 +386,7 @@ def forward( theorem, and round-tripping through the forward and inverse transforms will introduce an error. If set to `None`, the default, 3 iterations will be used if :code:`sampling == "healpix"` and :code:`method == "jax_healpy"` and zero - otherwise. For the `healpy` wrappers specifically (that is when - :code:`method == "jax_healpy"`) increasing the number iterations increases - the accuracy of the forward transform, but reduce the accuracy of the - gradient pass. Not used for `jax_ssht` method. + otherwise. Not used for `jax_ssht` method. _ssht_backend (int, optional, experimental): Whether to default to SSHT core (set to 0) recursions or pick up ducc0 (set to 1) accelerated experimental From 8dbcc975fb7f09189ac627abd33fcd2400e05242 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Fri, 6 Dec 2024 20:34:53 +0000 Subject: [PATCH 15/19] Add iterative refinement option to precompute forward transform --- s2fft/precompute_transforms/spherical.py | 95 +++++++++++++++++++----- 1 file changed, 78 insertions(+), 17 deletions(-) diff --git a/s2fft/precompute_transforms/spherical.py b/s2fft/precompute_transforms/spherical.py index 0c85cb8b..a029c9e6 100644 --- a/s2fft/precompute_transforms/spherical.py +++ b/s2fft/precompute_transforms/spherical.py @@ -1,4 +1,5 @@ from functools import partial +from typing import Optional from warnings import warn import jax.numpy as jnp @@ -6,20 +7,26 @@ import torch from jax import jit +from s2fft.precompute_transforms import construct from s2fft.sampling import s2_samples as samples from s2fft.utils import healpix_ffts as hp -from s2fft.utils import resampling, resampling_jax, resampling_torch +from s2fft.utils import ( + iterative_refinement, + resampling, + resampling_jax, + resampling_torch, +) def inverse( flm: np.ndarray, L: int, spin: int = 0, - kernel: np.ndarray = None, + kernel: Optional[np.ndarray] = None, sampling: str = "mw", reality: bool = False, method: str = "jax", - nside: int = None, + nside: Optional[int] = None, ) -> np.ndarray: r""" Compute the inverse spherical harmonic transform via precompute. @@ -62,14 +69,21 @@ def inverse( + "Defering to complex transform.", stacklevel=2, ) - inverse_functions = { - "numpy": inverse_transform, - "jax": inverse_transform_jax, - "torch": inverse_transform_torch, + common_kwargs = { + "L": L, + "sampling": sampling, + "reality": reality, + "spin": spin, + "nside": nside, } - if method not in inverse_functions: + kernel = ( + _kernel_functions[method](forward=False, **common_kwargs) + if kernel is None + else kernel + ) + if method not in _inverse_functions: raise ValueError(f"Method {method} not recognised.") - return inverse_functions[method](flm, kernel, L, sampling, reality, spin, nside) + return _inverse_functions[method](flm, kernel, **common_kwargs) def inverse_transform( @@ -290,11 +304,12 @@ def forward( f: np.ndarray, L: int, spin: int = 0, - kernel: np.ndarray = None, + kernel: Optional[np.ndarray] = None, sampling: str = "mw", reality: bool = False, method: str = "jax", - nside: int = None, + nside: Optional[int] = None, + iter: int = 0, ) -> np.ndarray: r""" Compute the forward spherical harmonic transform via precompute. @@ -321,6 +336,12 @@ def forward( nside (int): HEALPix Nside resolution parameter. Only required if sampling="healpix". + iter (int, optional): Number of iterative refinement iterations to use to + improve accuracy of forward transform (as an inverse of inverse transform). + Primarily of use with HEALPix sampling for which there is not a sampling + theorem, and round-tripping through the forward and inverse transforms will + introduce an error. + Raises: ValueError: Transform method not recognised. @@ -337,14 +358,34 @@ def forward( + "Defering to complex transform.", stacklevel=2, ) - forward_functions = { - "numpy": forward_transform, - "jax": forward_transform_jax, - "torch": forward_transform_torch, + common_kwargs = { + "L": L, + "sampling": sampling, + "reality": reality, + "spin": spin, + "nside": nside, } - if method not in forward_functions: + kernel = ( + _kernel_functions[method](forward=True, **common_kwargs) + if kernel is None + else kernel + ) + if method not in _forward_functions: raise ValueError(f"Method {method} not recognised.") - return forward_functions[method](f, kernel, L, sampling, reality, spin, nside) + if iter == 0: + return _forward_functions[method](f, kernel, **common_kwargs) + else: + inverse_kernel = _kernel_functions[method](forward=False, **common_kwargs) + return iterative_refinement.forward_with_iterative_refinement( + f=f, + n_iter=iter, + forward_function=partial( + _forward_functions[method], kernel=kernel, **common_kwargs + ), + backward_function=partial( + _inverse_functions[method], kernel=inverse_kernel, **common_kwargs + ), + ) def forward_transform( @@ -567,3 +608,23 @@ def forward_transform_torch( ) return flm * (-1) ** spin + + +_inverse_functions = { + "numpy": inverse_transform, + "jax": inverse_transform_jax, + "torch": inverse_transform_torch, +} + + +_forward_functions = { + "numpy": forward_transform, + "jax": forward_transform_jax, + "torch": forward_transform_torch, +} + +_kernel_functions = { + "numpy": partial(construct.fourier_wigner_kernel, using_torch=False), + "jax": construct.fourier_wigner_kernel_jax, + "torch": partial(construct.fourier_wigner_kernel, using_torch=True), +} From 8af3dd0900331e525ce002515637a300607f00ba Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 11 Dec 2024 12:46:54 +0000 Subject: [PATCH 16/19] Raise error early on unrecognised method + add tests --- s2fft/precompute_transforms/spherical.py | 8 ++++---- tests/test_spherical_precompute.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/s2fft/precompute_transforms/spherical.py b/s2fft/precompute_transforms/spherical.py index a029c9e6..249ab72f 100644 --- a/s2fft/precompute_transforms/spherical.py +++ b/s2fft/precompute_transforms/spherical.py @@ -62,6 +62,8 @@ def inverse( np.ndarray: Pixel-space coefficients with shape. """ + if method not in _inverse_functions: + raise ValueError(f"Method {method} not recognised.") if reality and spin != 0: reality = False warn( @@ -81,8 +83,6 @@ def inverse( if kernel is None else kernel ) - if method not in _inverse_functions: - raise ValueError(f"Method {method} not recognised.") return _inverse_functions[method](flm, kernel, **common_kwargs) @@ -351,6 +351,8 @@ def forward( np.ndarray: Spherical harmonic coefficients. """ + if method not in _forward_functions: + raise ValueError(f"Method {method} not recognised.") if reality and spin != 0: reality = False warn( @@ -370,8 +372,6 @@ def forward( if kernel is None else kernel ) - if method not in _forward_functions: - raise ValueError(f"Method {method} not recognised.") if iter == 0: return _forward_functions[method](f, kernel, **common_kwargs) else: diff --git a/tests/test_spherical_precompute.py b/tests/test_spherical_precompute.py index 629c4aa9..972c3312 100644 --- a/tests/test_spherical_precompute.py +++ b/tests/test_spherical_precompute.py @@ -324,3 +324,19 @@ def test_transform_forward_high_spin( flm_recov = forward(f, L, spin, kernel, sampling, reality, "numpy") tol = 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-12 np.testing.assert_allclose(flm_recov, flm, atol=tol, rtol=tol) + + +def test_forward_transform_unrecognised_method_raises(): + method = "invalid_method" + L = 32 + f = np.zeros(samples.f_shape(L)) + with pytest.raises(ValueError, match=f"{method} not recognised"): + forward(f, L, method=method) + + +def test_inverse_transform_unrecognised_method_raises(): + method = "invalid_method" + L = 32 + flm = np.zeros(samples.flm_shape(L)) + with pytest.raises(ValueError, match=f"{method} not recognised"): + inverse(flm, L, method=method) From 8919b2f7baa1a06614ba22c182f888956049c513 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 11 Dec 2024 13:15:18 +0000 Subject: [PATCH 17/19] Use correct kernel construction functions --- s2fft/precompute_transforms/spherical.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/s2fft/precompute_transforms/spherical.py b/s2fft/precompute_transforms/spherical.py index 249ab72f..7a1a7726 100644 --- a/s2fft/precompute_transforms/spherical.py +++ b/s2fft/precompute_transforms/spherical.py @@ -624,7 +624,7 @@ def forward_transform_torch( } _kernel_functions = { - "numpy": partial(construct.fourier_wigner_kernel, using_torch=False), - "jax": construct.fourier_wigner_kernel_jax, - "torch": partial(construct.fourier_wigner_kernel, using_torch=True), + "numpy": partial(construct.spin_spherical_kernel, using_torch=False), + "jax": construct.spin_spherical_kernel_jax, + "torch": partial(construct.spin_spherical_kernel, using_torch=True), } From a080eb615cd672c8e94211e80f1c801b87a3696b Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 11 Dec 2024 16:29:53 +0000 Subject: [PATCH 18/19] Add iterative refinement to base spherical transform --- s2fft/base_transforms/spherical.py | 38 +++++++++++++++++++++--------- tests/test_spherical_base.py | 27 +++++++++++++++++++++ 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/s2fft/base_transforms/spherical.py b/s2fft/base_transforms/spherical.py index e4070e3c..98024c1e 100644 --- a/s2fft/base_transforms/spherical.py +++ b/s2fft/base_transforms/spherical.py @@ -1,3 +1,4 @@ +from functools import partial from warnings import warn import numpy as np @@ -5,7 +6,7 @@ from s2fft import recursions from s2fft.sampling import s2_samples as samples from s2fft.utils import healpix_ffts as hp -from s2fft.utils import quadrature, resampling +from s2fft.utils import iterative_refinement, quadrature, resampling def inverse( @@ -138,6 +139,7 @@ def forward( nside: int = None, reality: bool = False, L_lower: int = 0, + iter: int = 0, ) -> np.ndarray: r""" Compute forward spherical harmonic transform. @@ -164,20 +166,34 @@ def forward( L_lower (int, optional): Harmonic lower-bound. Transform will only be computed for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. Defaults to 0. + iter (int, optional): Number of iterative refinement iterations to use to + improve accuracy of forward transform (as an inverse of inverse transform). + Primarily of use with HEALPix sampling for which there is not a sampling + theorem, and round-tripping through the forward and inverse transforms will + introduce an error. + Returns: np.ndarray: Spherical harmonic coefficients. """ - return _forward( - f, - L, - spin, - sampling, - nside=nside, - method="sov_fft_vectorized", - reality=reality, - L_lower=L_lower, - ) + common_kwargs = { + "L": L, + "spin": spin, + "sampling": sampling, + "nside": nside, + "method": "sov_fft_vectorized", + "reality": reality, + "L_lower": L_lower, + } + if iter == 0: + return _forward(f, **common_kwargs) + else: + return iterative_refinement.forward_with_iterative_refinement( + f, + n_iter=iter, + forward_function=partial(_forward, **common_kwargs), + backward_function=partial(_inverse, **common_kwargs), + ) def _forward( diff --git a/tests/test_spherical_base.py b/tests/test_spherical_base.py index 7499d0b9..ffcb6a08 100644 --- a/tests/test_spherical_base.py +++ b/tests/test_spherical_base.py @@ -14,6 +14,7 @@ sampling_to_test = ["mw", "mwss", "dh", "gl"] method_to_test = ["direct", "sov", "sov_fft", "sov_fft_vectorized"] reality_to_test = [False, True] +iter_to_test = [7, 10] @pytest.mark.parametrize("L", L_to_test) @@ -131,6 +132,32 @@ def test_transform_forward_healpix( np.testing.assert_allclose(flm_direct_hp, flm_check, atol=1e-14) +@pytest.mark.parametrize("nside", nside_to_test) +@pytest.mark.parametrize("reality", reality_to_test) +@pytest.mark.parametrize("iter", iter_to_test) +def test_transform_forward_healpix_iter( + flm_generator, nside: int, reality: bool, iter: int +): + sampling = "healpix" + L = 2 * nside + flm = flm_generator(L=L, reality=True) + f = spherical.inverse(flm, L, sampling=sampling, nside=nside, reality=reality) + flm_direct = spherical.forward( + f, + L, + sampling=sampling, + nside=nside, + reality=reality, + iter=iter, + ) + # With iter >> 0 round-trip error should be small + np.testing.assert_allclose(flm_direct, flm, atol=1e-14) + # Also check for consistency with healpy with iter > 0 + flm_direct_hp = samples.flm_2d_to_hp(flm_direct, L) + flm_check = hp.sphtfunc.map2alm(np.real(f), lmax=L - 1, iter=iter) + np.testing.assert_allclose(flm_direct_hp, flm_check, atol=1e-14) + + @pytest.mark.parametrize("nside", nside_to_test) def test_healpix_nside_to_L_exceptions(flm_generator, nside: int): sampling = "healpix" From 23fb7afa57d37b451173f2e54717d5ed665bb7d8 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Wed, 11 Dec 2024 16:30:23 +0000 Subject: [PATCH 19/19] Test precompute transforms with iter > 0 --- tests/test_spherical_precompute.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_spherical_precompute.py b/tests/test_spherical_precompute.py index 972c3312..3f50e7c5 100644 --- a/tests/test_spherical_precompute.py +++ b/tests/test_spherical_precompute.py @@ -1,3 +1,4 @@ +import jax import numpy as np import pyssht as ssht import pytest @@ -8,6 +9,8 @@ from s2fft.precompute_transforms.spherical import forward, inverse from s2fft.sampling import s2_samples as samples +jax.config.update("jax_enable_x64", True) + # Maximum spin number at which Price-McEwen recursion is sufficiently accurate. # For spins > PM_MAX_STABLE_SPIN one should default to the Risbo recursion. PM_MAX_STABLE_SPIN = 6 @@ -20,6 +23,7 @@ reality_to_test = [True, False] methods_to_test = ["numpy", "jax", "torch"] recursions_to_test = ["price-mcewen", "risbo", "auto"] +iter_to_test = [0, 3] @pytest.mark.parametrize("L", L_to_test) @@ -221,6 +225,7 @@ def test_transform_forward( @pytest.mark.parametrize("reality", reality_to_test) @pytest.mark.parametrize("method", methods_to_test) @pytest.mark.parametrize("recursion", recursions_to_test) +@pytest.mark.parametrize("iter", iter_to_test) def test_transform_forward_healpix( flm_generator, nside: int, @@ -228,12 +233,13 @@ def test_transform_forward_healpix( reality: bool, method: str, recursion: str, + iter: int, ): sampling = "healpix" L = ratio * nside flm = flm_generator(L=L, reality=True) f = base.inverse(flm, L, 0, sampling, nside, reality) - flm_check = base.forward(f, L, 0, sampling, nside, reality) + flm_check = base.forward(f, L, 0, sampling, nside, reality, iter=iter) kfunc = ( c.spin_spherical_kernel_jax @@ -254,6 +260,7 @@ def test_transform_forward_healpix( reality, method, nside, + iter, ) np.testing.assert_allclose(flm_recov, flm_check, atol=tol, rtol=tol) @@ -271,10 +278,11 @@ def test_transform_forward_healpix( reality, method, nside, + iter, ), ) else: - flm_recov = forward(f, L, 0, kernel, sampling, reality, method, nside) + flm_recov = forward(f, L, 0, kernel, sampling, reality, method, nside, iter) np.testing.assert_allclose(flm_recov, flm_check, atol=tol, rtol=tol)