Skip to content

Commit

Permalink
Raise error early on unrecognised method + add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-graham committed Dec 11, 2024
1 parent 8dbcc97 commit 8af3dd0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
8 changes: 4 additions & 4 deletions s2fft/precompute_transforms/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def inverse(
np.ndarray: Pixel-space coefficients with shape.
"""
if method not in _inverse_functions:
raise ValueError(f"Method {method} not recognised.")
if reality and spin != 0:
reality = False
warn(
Expand All @@ -81,8 +83,6 @@ def inverse(
if kernel is None
else kernel
)
if method not in _inverse_functions:
raise ValueError(f"Method {method} not recognised.")
return _inverse_functions[method](flm, kernel, **common_kwargs)


Expand Down Expand Up @@ -351,6 +351,8 @@ def forward(
np.ndarray: Spherical harmonic coefficients.
"""
if method not in _forward_functions:
raise ValueError(f"Method {method} not recognised.")
if reality and spin != 0:
reality = False
warn(
Expand All @@ -370,8 +372,6 @@ def forward(
if kernel is None
else kernel
)
if method not in _forward_functions:
raise ValueError(f"Method {method} not recognised.")
if iter == 0:
return _forward_functions[method](f, kernel, **common_kwargs)
else:
Expand Down
16 changes: 16 additions & 0 deletions tests/test_spherical_precompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,19 @@ def test_transform_forward_high_spin(
flm_recov = forward(f, L, spin, kernel, sampling, reality, "numpy")
tol = 1e-8 if sampling.lower() in ["dh", "gl"] else 1e-12
np.testing.assert_allclose(flm_recov, flm, atol=tol, rtol=tol)


def test_forward_transform_unrecognised_method_raises():
method = "invalid_method"
L = 32
f = np.zeros(samples.f_shape(L))
with pytest.raises(ValueError, match=f"{method} not recognised"):
forward(f, L, method=method)


def test_inverse_transform_unrecognised_method_raises():
method = "invalid_method"
L = 32
flm = np.zeros(samples.flm_shape(L))
with pytest.raises(ValueError, match=f"{method} not recognised"):
inverse(flm, L, method=method)

0 comments on commit 8af3dd0

Please sign in to comment.