Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Iterative refinement support for JAX and NumPy forward (spherical) transform implementations #241

Merged
merged 19 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
ba07b4d
Deduplicate dispatching logic in forward spherical transform
matt-graham Nov 13, 2024
243b75e
Prevent inverse_numpy updating flm arg in-place
matt-graham Nov 13, 2024
1c6e4a1
Iterative refinement support for jax and numpy forward spherical
matt-graham Nov 13, 2024
acd6d8e
Copy array in precomps to avoid inplace update
matt-graham Nov 13, 2024
badf99e
Test across different iter values
matt-graham Nov 13, 2024
cedf1f5
Use Optional instead of | None for Python 3.9 compat
matt-graham Nov 14, 2024
9369027
Add note about healpy wrapper gradient behaviour
matt-graham Nov 14, 2024
5c47269
Add comment to explain copy operations
matt-graham Nov 14, 2024
85645ae
Factor out iterative refinement function
matt-graham Nov 26, 2024
972e647
Use factored out function in healpy wrapper
matt-graham Nov 26, 2024
a2abd7c
Using deprecated typing.Callable to maintain py3.8 compat
matt-graham Dec 4, 2024
f50eaee
Only pass precomps to forward transform when iterating
matt-graham Dec 6, 2024
b288167
Deduplicate precompute transform wrappers
matt-graham Dec 6, 2024
d5fd078
Remove note about jax_healpy iter gradient accuracy
matt-graham Dec 6, 2024
8dbcc97
Add iterative refinement option to precompute forward transform
matt-graham Dec 6, 2024
8af3dd0
Raise error early on unrecognised method + add tests
matt-graham Dec 11, 2024
8919b2f
Use correct kernel construction functions
matt-graham Dec 11, 2024
a080eb6
Add iterative refinement to base spherical transform
matt-graham Dec 11, 2024
23fb7af
Test precompute transforms with iter > 0
matt-graham Dec 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions s2fft/base_transforms/spherical.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from functools import partial
from warnings import warn

import numpy as np

from s2fft import recursions
from s2fft.sampling import s2_samples as samples
from s2fft.utils import healpix_ffts as hp
from s2fft.utils import quadrature, resampling
from s2fft.utils import iterative_refinement, quadrature, resampling


def inverse(
Expand Down Expand Up @@ -138,6 +139,7 @@ def forward(
nside: int = None,
reality: bool = False,
L_lower: int = 0,
iter: int = 0,
) -> np.ndarray:
r"""
Compute forward spherical harmonic transform.
Expand All @@ -164,20 +166,34 @@ def forward(
L_lower (int, optional): Harmonic lower-bound. Transform will only be computed
for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. Defaults to 0.

iter (int, optional): Number of iterative refinement iterations to use to
improve accuracy of forward transform (as an inverse of inverse transform).
Primarily of use with HEALPix sampling for which there is not a sampling
theorem, and round-tripping through the forward and inverse transforms will
introduce an error.

Returns:
np.ndarray: Spherical harmonic coefficients.

"""
return _forward(
f,
L,
spin,
sampling,
nside=nside,
method="sov_fft_vectorized",
reality=reality,
L_lower=L_lower,
)
common_kwargs = {
"L": L,
"spin": spin,
"sampling": sampling,
"nside": nside,
"method": "sov_fft_vectorized",
"reality": reality,
"L_lower": L_lower,
}
if iter == 0:
return _forward(f, **common_kwargs)
else:
return iterative_refinement.forward_with_iterative_refinement(
f,
n_iter=iter,
forward_function=partial(_forward, **common_kwargs),
backward_function=partial(_inverse, **common_kwargs),
)


def _forward(
Expand Down
101 changes: 81 additions & 20 deletions s2fft/precompute_transforms/spherical.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,32 @@
from functools import partial
from typing import Optional
from warnings import warn

import jax.numpy as jnp
import numpy as np
import torch
from jax import jit

from s2fft.precompute_transforms import construct
from s2fft.sampling import s2_samples as samples
from s2fft.utils import healpix_ffts as hp
from s2fft.utils import resampling, resampling_jax, resampling_torch
from s2fft.utils import (
iterative_refinement,
resampling,
resampling_jax,
resampling_torch,
)


def inverse(
flm: np.ndarray,
L: int,
spin: int = 0,
kernel: np.ndarray = None,
kernel: Optional[np.ndarray] = None,
sampling: str = "mw",
reality: bool = False,
method: str = "jax",
nside: int = None,
nside: Optional[int] = None,
) -> np.ndarray:
r"""
Compute the inverse spherical harmonic transform via precompute.
Expand Down Expand Up @@ -55,21 +62,28 @@ def inverse(
np.ndarray: Pixel-space coefficients with shape.

"""
if method not in _inverse_functions:
raise ValueError(f"Method {method} not recognised.")
if reality and spin != 0:
reality = False
warn(
"Reality acceleration only supports spin 0 fields. "
+ "Defering to complex transform.",
stacklevel=2,
)
if method == "numpy":
return inverse_transform(flm, kernel, L, sampling, reality, spin, nside)
elif method == "jax":
return inverse_transform_jax(flm, kernel, L, sampling, reality, spin, nside)
elif method == "torch":
return inverse_transform_torch(flm, kernel, L, sampling, reality, spin, nside)
else:
raise ValueError(f"Method {method} not recognised.")
common_kwargs = {
"L": L,
"sampling": sampling,
"reality": reality,
"spin": spin,
"nside": nside,
}
kernel = (
_kernel_functions[method](forward=False, **common_kwargs)
if kernel is None
else kernel
)
return _inverse_functions[method](flm, kernel, **common_kwargs)


def inverse_transform(
Expand Down Expand Up @@ -290,11 +304,12 @@ def forward(
f: np.ndarray,
L: int,
spin: int = 0,
kernel: np.ndarray = None,
kernel: Optional[np.ndarray] = None,
sampling: str = "mw",
reality: bool = False,
method: str = "jax",
nside: int = None,
nside: Optional[int] = None,
iter: int = 0,
) -> np.ndarray:
r"""
Compute the forward spherical harmonic transform via precompute.
Expand All @@ -321,6 +336,12 @@ def forward(
nside (int): HEALPix Nside resolution parameter. Only required
if sampling="healpix".

iter (int, optional): Number of iterative refinement iterations to use to
improve accuracy of forward transform (as an inverse of inverse transform).
Primarily of use with HEALPix sampling for which there is not a sampling
theorem, and round-tripping through the forward and inverse transforms will
introduce an error.

Raises:
ValueError: Transform method not recognised.

Expand All @@ -330,21 +351,41 @@ def forward(
np.ndarray: Spherical harmonic coefficients.

"""
if method not in _forward_functions:
raise ValueError(f"Method {method} not recognised.")
if reality and spin != 0:
reality = False
warn(
"Reality acceleration only supports spin 0 fields. "
+ "Defering to complex transform.",
stacklevel=2,
)
if method == "numpy":
return forward_transform(f, kernel, L, sampling, reality, spin, nside)
elif method == "jax":
return forward_transform_jax(f, kernel, L, sampling, reality, spin, nside)
elif method == "torch":
return forward_transform_torch(f, kernel, L, sampling, reality, spin, nside)
common_kwargs = {
"L": L,
"sampling": sampling,
"reality": reality,
"spin": spin,
"nside": nside,
}
kernel = (
_kernel_functions[method](forward=True, **common_kwargs)
if kernel is None
else kernel
)
if iter == 0:
return _forward_functions[method](f, kernel, **common_kwargs)
else:
raise ValueError(f"Method {method} not recognised.")
inverse_kernel = _kernel_functions[method](forward=False, **common_kwargs)
return iterative_refinement.forward_with_iterative_refinement(
f=f,
n_iter=iter,
forward_function=partial(
_forward_functions[method], kernel=kernel, **common_kwargs
),
backward_function=partial(
_inverse_functions[method], kernel=inverse_kernel, **common_kwargs
),
)


def forward_transform(
Expand Down Expand Up @@ -567,3 +608,23 @@ def forward_transform_torch(
)

return flm * (-1) ** spin


_inverse_functions = {
"numpy": inverse_transform,
"jax": inverse_transform_jax,
"torch": inverse_transform_torch,
}


_forward_functions = {
"numpy": forward_transform,
"jax": forward_transform_jax,
"torch": forward_transform_torch,
}

_kernel_functions = {
"numpy": partial(construct.spin_spherical_kernel, using_torch=False),
"jax": construct.spin_spherical_kernel_jax,
"torch": partial(construct.spin_spherical_kernel, using_torch=True),
}
15 changes: 9 additions & 6 deletions s2fft/transforms/c_backend_spherical.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import healpy
import jax.numpy as jnp
import numpy as np
Expand All @@ -8,7 +10,7 @@
from jax.interpreters import ad

from s2fft.sampling import reindex
from s2fft.utils import quadrature_jax
from s2fft.utils import iterative_refinement, quadrature_jax


@custom_vjp
Expand Down Expand Up @@ -427,11 +429,12 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda
Astrophysical Journal 622.2 (2005): 759

"""
flm = healpy_map2alm(f, L, nside)
for _ in range(iter):
f_recov = healpy_alm2map(flm, L, nside)
f_error = f - f_recov
flm += healpy_map2alm(f_error, L, nside)
flm = iterative_refinement.forward_with_iterative_refinement(
f=f,
n_iter=iter,
forward_function=partial(healpy_map2alm, L=L, nside=nside),
backward_function=partial(healpy_alm2map, L=L, nside=nside),
)
return reindex.flm_hp_to_2d_fast(flm, L)


Expand Down
6 changes: 6 additions & 0 deletions s2fft/transforms/otf_recursions.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def inverse_latitudinal_step(
precomps = generate_precomputes(L, -mm, sampling, nside, L_lower)
lrenorm, vsign, cpi, cp2, indices = precomps

# Create copy to prevent in-place updates propagating to caller
lrenorm = lrenorm.copy()

for i in range(2):
if not (reality and i == 0):
m_offset = 1 if sampling in ["mwss", "healpix"] and i == 0 else 0
Expand Down Expand Up @@ -490,6 +493,9 @@ def forward_latitudinal_step(
precomps = generate_precomputes(L, -mm, sampling, nside, True, L_lower)
lrenorm, vsign, cpi, cp2, indices = precomps

# Create copy to prevent in-place updates propagating to caller
lrenorm = lrenorm.copy()

for i in range(2):
if not (reality and i == 0):
m_offset = 1 if sampling in ["mwss", "healpix"] and i == 0 else 0
Expand Down
Loading
Loading