From 83e44cf5be51fabd029eea9bf378cfad276c97f4 Mon Sep 17 00:00:00 2001 From: landmanbester Date: Fri, 15 Mar 2024 16:59:30 +0200 Subject: [PATCH] replace generated_jit with overloads --- pfb/utils/correlations.py | 2 +- pfb/utils/stokes.py | 60 ++++++++++++++++++++++++++++----------- 2 files changed, 44 insertions(+), 18 deletions(-) diff --git a/pfb/utils/correlations.py b/pfb/utils/correlations.py index c1544bfaf..2cb970bb6 100644 --- a/pfb/utils/correlations.py +++ b/pfb/utils/correlations.py @@ -1,5 +1,5 @@ import numpy as np -from numba import generated_jit, njit +from numba import njit from numba.types import literal from dask.graph_manipulation import clone import dask.array as da diff --git a/pfb/utils/stokes.py b/pfb/utils/stokes.py index da258690c..6940f9fc7 100644 --- a/pfb/utils/stokes.py +++ b/pfb/utils/stokes.py @@ -1,13 +1,12 @@ import numpy as np import numexpr as ne -from numba import generated_jit, njit, prange +from numba import njit, prange, literally from numba.extending import overload -from numba.types import literal from dask.graph_manipulation import clone import dask.array as da from xarray import Dataset from pfb.operators.gridder import vis2im -from quartical.utils.numba import coerce_literal +# from quartical.utils.numba import coerce_literal from operator import getitem from pfb.utils.beam import interp_beam from pfb.utils.misc import JIT_OPTIONS @@ -15,6 +14,13 @@ from quartical.utils.dask import Blocker +# for old style vs new style warnings +from numba.core.errors import NumbaPendingDeprecationWarning +import warnings + +warnings.simplefilter('ignore', category=NumbaPendingDeprecationWarning) + + def weight_from_sigma(sigma): weight = ne.evaluate('1.0/(sigma*sigma)', casting='same_kind') @@ -220,7 +226,7 @@ def single_stokes(ds=None, # Instead of BEAM we should have a pre-init step which computes # per facet best approximations to smooth beams as described in # https://www.overleaf.com/read/yzrsrdwxhxrd - npix = int(np.deg2rad(opts.max_field_of_view)/cell_rad) + npix = int(np.deg2rad(opts.max_field_of_view*1.1)/cell_rad) beam, l_beam, m_beam = interp_beam(freq_out/1e6, npix, npix, np.rad2deg(cell_rad), opts.beam_model, @@ -260,9 +266,12 @@ def single_stokes(ds=None, def weight_data(data, weight, flag, jones, tbin_idx, tbin_counts, ant1, ant2, pol, product, nc): - vis, wgt = _weight_data_impl(data, weight, flag, jones, + vis, wgt = _weight_data(data, weight, flag, jones, tbin_idx, tbin_counts, - ant1, ant2, pol, product, nc) + ant1, ant2, + literally(pol), + literally(product), + literally(nc)) out_dict = {} out_dict['vis'] = vis @@ -271,11 +280,28 @@ def weight_data(data, weight, flag, jones, tbin_idx, tbin_counts, return out_dict -@generated_jit(nopython=True, nogil=True, parallel=True) # cache=True +@njit(**JIT_OPTIONS) +def _weight_data(data, weight, flag, jones, tbin_idx, tbin_counts, + ant1, ant2, pol, product, nc): + + vis, wgt = _weight_data_impl(data, weight, flag, jones, + tbin_idx, tbin_counts, + ant1, ant2, + literally(pol), + literally(product), + literally(nc)) + + return vis, wgt + + def _weight_data_impl(data, weight, flag, jones, tbin_idx, tbin_counts, - ant1, ant2, pol, product, nc): + ant1, ant2, pol, product, nc): + raise NotImplementedError + - coerce_literal(weight_data, ["product", "pol", "nc"]) +@overload(_weight_data_impl, **JIT_OPTIONS) +def nb_weight_data_impl(data, weight, flag, jones, tbin_idx, tbin_counts, + ant1, ant2, pol, product, nc): vis_func, wgt_func = stokes_funcs(data, jones, product, pol, nc) @@ -340,12 +366,12 @@ def stokes_funcs(data, jones, product, pol, nc): # Full Stokes to corr operator # Is this the only difference between linear and circular pol? # What about paralactic angle rotation? - if pol == literal('linear'): + if pol.literal_value == 'linear': T = sm.Matrix([[1.0, 1.0, 0, 0], [0, 0, 1.0, 1.0j], [0, 0, 1.0, -1.0j], [1.0, -1.0, 0, 0]]) - elif pol == literal('circular'): + elif pol.literal_value == 'circular': T = sm.Matrix([[1.0, 0, 0, 1.0], [0, 1.0, 1.0j, 0], [0, 1.0, -1.0j, 0], @@ -360,13 +386,13 @@ def stokes_funcs(data, jones, product, pol, nc): C = Winv * (T.H * (Mpq.H * (Sinv * Vpq))) # C = T.H * (Mpq.H * (Sinv * Vpq)) - if product == literal('I'): + if product.literal_value == 'I': i = 0 - elif product == literal('Q'): + elif product.literal_value == 'Q': i = 1 - elif product == literal('U'): + elif product.literal_value == 'U': i = 2 - elif product == literal('V'): + elif product.literal_value == 'V': i = 3 else: raise ValueError(f"Unknown polarisation product {product}") @@ -451,7 +477,7 @@ def vfunc(gp, gq, W, V): sm.simplify(sm.expand(C[i]))) Djfn = njit(nogil=True, inline='always')(Dsymb) - if nc==literal('4'): + if nc.literal_value == '4': @njit(nogil=True, inline='always') def wfunc(gp, gq, W): gp00 = gp[0] @@ -484,7 +510,7 @@ def vfunc(gp, gq, W, V): gq00, gq11, W00, W01, W10, W11, V00, V01, V10, V11) - elif nc==literal('2'): + elif nc.literal_value == '2': @njit(nogil=True, inline='always') def wfunc(gp, gq, W): gp00 = gp[0]