Skip to content

Commit

Permalink
replace generated_jit with overloads
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Mar 15, 2024
1 parent a545cfd commit 83e44cf
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 18 deletions.
2 changes: 1 addition & 1 deletion pfb/utils/correlations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from numba import generated_jit, njit
from numba import njit
from numba.types import literal
from dask.graph_manipulation import clone
import dask.array as da
Expand Down
60 changes: 43 additions & 17 deletions pfb/utils/stokes.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
import numpy as np
import numexpr as ne
from numba import generated_jit, njit, prange
from numba import njit, prange, literally
from numba.extending import overload
from numba.types import literal
from dask.graph_manipulation import clone
import dask.array as da
from xarray import Dataset
from pfb.operators.gridder import vis2im
from quartical.utils.numba import coerce_literal
# from quartical.utils.numba import coerce_literal
from operator import getitem
from pfb.utils.beam import interp_beam
from pfb.utils.misc import JIT_OPTIONS
import dask
from quartical.utils.dask import Blocker


# for old style vs new style warnings
from numba.core.errors import NumbaPendingDeprecationWarning
import warnings

warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning)


def weight_from_sigma(sigma):
weight = ne.evaluate('1.0/(sigma*sigma)',
casting='same_kind')
Expand Down Expand Up @@ -220,7 +226,7 @@ def single_stokes(ds=None,
# Instead of BEAM we should have a pre-init step which computes
# per facet best approximations to smooth beams as described in
# https://www.overleaf.com/read/yzrsrdwxhxrd
npix = int(np.deg2rad(opts.max_field_of_view)/cell_rad)
npix = int(np.deg2rad(opts.max_field_of_view*1.1)/cell_rad)
beam, l_beam, m_beam = interp_beam(freq_out/1e6, npix, npix,
np.rad2deg(cell_rad),
opts.beam_model,
Expand Down Expand Up @@ -260,9 +266,12 @@ def single_stokes(ds=None,
def weight_data(data, weight, flag, jones, tbin_idx, tbin_counts,
ant1, ant2, pol, product, nc):

vis, wgt = _weight_data_impl(data, weight, flag, jones,
vis, wgt = _weight_data(data, weight, flag, jones,
tbin_idx, tbin_counts,
ant1, ant2, pol, product, nc)
ant1, ant2,
literally(pol),
literally(product),
literally(nc))

out_dict = {}
out_dict['vis'] = vis
Expand All @@ -271,11 +280,28 @@ def weight_data(data, weight, flag, jones, tbin_idx, tbin_counts,
return out_dict


@generated_jit(nopython=True, nogil=True, parallel=True) # cache=True
@njit(**JIT_OPTIONS)
def _weight_data(data, weight, flag, jones, tbin_idx, tbin_counts,
ant1, ant2, pol, product, nc):

vis, wgt = _weight_data_impl(data, weight, flag, jones,
tbin_idx, tbin_counts,
ant1, ant2,
literally(pol),
literally(product),
literally(nc))

return vis, wgt


def _weight_data_impl(data, weight, flag, jones, tbin_idx, tbin_counts,
ant1, ant2, pol, product, nc):
ant1, ant2, pol, product, nc):
raise NotImplementedError


coerce_literal(weight_data, ["product", "pol", "nc"])
@overload(_weight_data_impl, **JIT_OPTIONS)
def nb_weight_data_impl(data, weight, flag, jones, tbin_idx, tbin_counts,
ant1, ant2, pol, product, nc):

vis_func, wgt_func = stokes_funcs(data, jones, product, pol, nc)

Expand Down Expand Up @@ -340,12 +366,12 @@ def stokes_funcs(data, jones, product, pol, nc):
# Full Stokes to corr operator
# Is this the only difference between linear and circular pol?
# What about paralactic angle rotation?
if pol == literal('linear'):
if pol.literal_value == 'linear':
T = sm.Matrix([[1.0, 1.0, 0, 0],
[0, 0, 1.0, 1.0j],
[0, 0, 1.0, -1.0j],
[1.0, -1.0, 0, 0]])
elif pol == literal('circular'):
elif pol.literal_value == 'circular':
T = sm.Matrix([[1.0, 0, 0, 1.0],
[0, 1.0, 1.0j, 0],
[0, 1.0, -1.0j, 0],
Expand All @@ -360,13 +386,13 @@ def stokes_funcs(data, jones, product, pol, nc):
C = Winv * (T.H * (Mpq.H * (Sinv * Vpq)))
# C = T.H * (Mpq.H * (Sinv * Vpq))

if product == literal('I'):
if product.literal_value == 'I':
i = 0
elif product == literal('Q'):
elif product.literal_value == 'Q':
i = 1
elif product == literal('U'):
elif product.literal_value == 'U':
i = 2
elif product == literal('V'):
elif product.literal_value == 'V':
i = 3
else:
raise ValueError(f"Unknown polarisation product {product}")
Expand Down Expand Up @@ -451,7 +477,7 @@ def vfunc(gp, gq, W, V):
sm.simplify(sm.expand(C[i])))
Djfn = njit(nogil=True, inline='always')(Dsymb)

if nc==literal('4'):
if nc.literal_value == '4':
@njit(nogil=True, inline='always')
def wfunc(gp, gq, W):
gp00 = gp[0]
Expand Down Expand Up @@ -484,7 +510,7 @@ def vfunc(gp, gq, W, V):
gq00, gq11,
W00, W01, W10, W11,
V00, V01, V10, V11)
elif nc==literal('2'):
elif nc.literal_value == '2':
@njit(nogil=True, inline='always')
def wfunc(gp, gq, W):
gp00 = gp[0]
Expand Down

0 comments on commit 83e44cf

Please sign in to comment.