Skip to content

Commit

Permalink
Factor out iterative refinement function
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-graham committed Nov 26, 2024
1 parent 5c47269 commit 85645ae
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 6 deletions.
13 changes: 7 additions & 6 deletions s2fft/transforms/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from s2fft.transforms import otf_recursions as otf
from s2fft.utils import healpix_ffts as hp
from s2fft.utils import (
iterative_refinement,
quadrature,
quadrature_jax,
resampling,
Expand Down Expand Up @@ -435,12 +436,12 @@ def forward(
forward_kwargs = common_kwargs
inverse_kwargs = {**common_kwargs, "method": "numpy"}
forward_function = forward_numpy
flm = forward_function(f, **forward_kwargs)
for _ in range(iter):
f_recov = inverse(flm, **inverse_kwargs)
f_error = f - f_recov
flm += forward_function(f_error, **forward_kwargs)
return flm
return iterative_refinement.forward_with_iterative_refinement(
f=f,
n_iter=iter,
forward_function=partial(forward_function, **forward_kwargs),
backward_function=partial(inverse, **inverse_kwargs),
)
elif method == "jax_ssht":
if sampling.lower() == "healpix":
raise ValueError("SSHT does not support healpix sampling.")
Expand Down
1 change: 1 addition & 0 deletions s2fft/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import (
healpix_ffts,
iterative_refinement,
jax_primitive,
quadrature,
quadrature_jax,
Expand Down
46 changes: 46 additions & 0 deletions s2fft/utils/iterative_refinement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Iterative scheme for improving accuracy of linear transforms."""

from collections.abc import Callable
from typing import TypeVar

T = TypeVar("T")


def forward_with_iterative_refinement(
f: T,
n_iter: int,
forward_function: Callable[[T], T],
backward_function: Callable[[T], T],
) -> T:
"""
Apply forward transform with iterative refinement to improve accuracy.
`Iterative refinement <https://en.wikipedia.org/wiki/Iterative_refinement>`_ is a
general approach for improving the accuracy of numerial solutions to linear systems.
In the context of spherical harmonic transforms, given a forward transform which is
an _approximate_ inverse to a corresponding backward ('inverse') transform,
iterative refinement allows defining an iterative forward transform which is a more
accurate
Args:
f: Array argument to forward transform (signal on sphere) to compute iteratively
refined forward transform at.
n_iter: Number of refinement iterations to use, non-negative.
forward_function: Function computing forward transform (approximate inverse of
backward transform).
backward_function: Function computing backward ('inverse') transform.
Returns:
Array output from iteratively refined forward transform (spherical harmonic
coefficients).
"""
flm = forward_function(f)
for _ in range(n_iter):
f_recov = backward_function(flm)
f_error = f - f_recov
flm += forward_function(f_error)
return flm

0 comments on commit 85645ae

Please sign in to comment.