Skip to content

Commit

Permalink
Merge pull request #256 from astro-informatics/mmg/fix-generate-preco…
Browse files Browse the repository at this point in the history
…mps-pass-through

Fix pass through of arguments to `generate_precomputes` in NumPy forward spherical transform
  • Loading branch information
matt-graham authored Dec 18, 2024
2 parents a9e7c0c + be0cf52 commit 3ff7699
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
9 changes: 8 additions & 1 deletion s2fft/transforms/otf_recursions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,14 @@ def inverse_latitudinal_step(
half_slices = [el + mm + 1, el - mm + 1]

if precomps is None:
precomps = generate_precomputes(L, -mm, sampling, nside, L_lower)
precomps = generate_precomputes(
L=L,
spin=-mm,
sampling=sampling,
nside=nside,
forward=False,
L_lower=L_lower,
)
lrenorm, vsign, cpi, cp2, indices = precomps

# Create copy to prevent in-place updates propagating to caller
Expand Down
15 changes: 12 additions & 3 deletions tests/test_spherical_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
@pytest.mark.parametrize("method", method_to_test)
@pytest.mark.parametrize("reality", reality_to_test)
@pytest.mark.parametrize("spmd", multiple_gpus)
@pytest.mark.parametrize("use_generate_precomputes", [True, False])
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_transform_inverse(
flm_generator,
Expand All @@ -37,6 +38,7 @@ def test_transform_inverse(
method: str,
reality: bool,
spmd: bool,
use_generate_precomputes: bool,
):
if reality and spin != 0:
pytest.skip("Reality only valid for scalar fields (spin=0).")
Expand All @@ -52,7 +54,10 @@ def test_transform_inverse(
Reality=reality,
)

precomps = generate_precomputes(L, spin, sampling, L_lower=L_lower)
if use_generate_precomputes:
precomps = generate_precomputes(L, spin, sampling, L_lower=L_lower)
else:
precomps = None
f = spherical.inverse(
flm,
L,
Expand Down Expand Up @@ -106,6 +111,7 @@ def test_transform_inverse_healpix(
@pytest.mark.parametrize("method", method_to_test)
@pytest.mark.parametrize("reality", reality_to_test)
@pytest.mark.parametrize("spmd", multiple_gpus)
@pytest.mark.parametrize("use_generate_precomputes", [True, False])
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_transform_forward(
flm_generator,
Expand All @@ -116,6 +122,7 @@ def test_transform_forward(
method: str,
reality: bool,
spmd: bool,
use_generate_precomputes: bool,
):
if reality and spin != 0:
pytest.skip("Reality only valid for scalar fields (spin=0).")
Expand All @@ -131,8 +138,10 @@ def test_transform_forward(
Spin=spin,
Reality=reality,
)

precomps = generate_precomputes(L, spin, sampling, None, True, L_lower)
if use_generate_precomputes:
precomps = generate_precomputes(L, spin, sampling, None, True, L_lower)
else:
precomps = None
flm_check = spherical.forward(
f,
L,
Expand Down

0 comments on commit 3ff7699

Please sign in to comment.