Skip to content

Commit

Permalink
Use factored out function in healpy wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-graham committed Nov 26, 2024
1 parent 85645ae commit 972e647
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions s2fft/transforms/c_backend_spherical.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import healpy
import jax.numpy as jnp
import numpy as np
Expand All @@ -8,7 +10,7 @@
from jax.interpreters import ad

from s2fft.sampling import reindex
from s2fft.utils import quadrature_jax
from s2fft.utils import iterative_refinement, quadrature_jax


@custom_vjp
Expand Down Expand Up @@ -427,11 +429,12 @@ def healpy_forward(f: jnp.ndarray, L: int, nside: int, iter: int = 3) -> jnp.nda
Astrophysical Journal 622.2 (2005): 759
"""
flm = healpy_map2alm(f, L, nside)
for _ in range(iter):
f_recov = healpy_alm2map(flm, L, nside)
f_error = f - f_recov
flm += healpy_map2alm(f_error, L, nside)
flm = iterative_refinement.forward_with_iterative_refinement(
f=f,
n_iter=iter,
forward_function=partial(healpy_map2alm, L=L, nside=nside),
backward_function=partial(healpy_alm2map, L=L, nside=nside),
)
return reindex.flm_hp_to_2d_fast(flm, L)


Expand Down

0 comments on commit 972e647

Please sign in to comment.