diff --git a/pfb/utils/correlations.py b/pfb/utils/correlations.py index a198add48..c1544bfaf 100644 --- a/pfb/utils/correlations.py +++ b/pfb/utils/correlations.py @@ -188,65 +188,46 @@ def _weight_data(data, weight, jones, tbin_idx, tbin_counts, return _weight_data_impl(data[0], weight[0], jones[0][0][0], tbin_idx, tbin_counts, ant1, ant2) -@generated_jit(nopython=True, nogil=True, cache=True) +@njit(nopython=True, nogil=True, cache=True) def _weight_data_impl(data, weight, jones, tbin_idx, tbin_counts, ant1, ant2): + # for dask arrays we need to adjust the chunks to + # start counting from zero + tbin_idx -= tbin_idx.min() + nt = np.shape(tbin_idx)[0] + nrow, nchan, ncorr = data.shape + vis = np.zeros((nrow, nchan), dtype=data.dtype) + wgt = np.zeros((nrow, nchan), dtype=data.real.dtype) + + for t in range(nt): + for row in range(tbin_idx[t], + tbin_idx[t] + tbin_counts[t]): + p = int(ant1[row]) + q = int(ant2[row]) + gp = jones[t, p, :, 0] + gq = jones[t, q, :, 0] + for chan in range(nchan): + wval = wgt_func(gp[chan], gq[chan], + weight[row, chan]) + wgt[row, chan] = wval + vis[row, chan] = vis_func(gp[chan], gq[chan], + weight[row, chan], + data[row, chan]) #/wval - vis_func, wgt_func = corr_funcs(data, jones) - - def _impl(data, weight, jones, tbin_idx, tbin_counts, - ant1, ant2): - # for dask arrays we need to adjust the chunks to - # start counting from zero - tbin_idx -= tbin_idx.min() - nt = np.shape(tbin_idx)[0] - nrow, nchan, ncorr = data.shape - vis = np.zeros((nrow, nchan), dtype=data.dtype) - wgt = np.zeros((nrow, nchan), dtype=data.real.dtype) - - for t in range(nt): - for row in range(tbin_idx[t], - tbin_idx[t] + tbin_counts[t]): - p = int(ant1[row]) - q = int(ant2[row]) - gp = jones[t, p, :, 0] - gq = jones[t, q, :, 0] - for chan in range(nchan): - wval = wgt_func(gp[chan], gq[chan], - weight[row, chan]) - wgt[row, chan] = wval - vis[row, chan] = vis_func(gp[chan], gq[chan], - weight[row, chan], - data[row, chan]) #/wval - - return vis, wgt - return _impl - - -def corr_funcs(data, jones): - # The expressions for DIAG_DIAG and DIAG mode are essentially the same - if jones.ndim == 5: - # I and Q have identical weights - @njit(nogil=True, cache=True, inline='always') - def wfunc(gp, gq, W): - gp00 = gp[0] - gq00 = gq[0] - W0 = W[0] - return np.real(W0*gp00*gq00*np.conjugate(gp00)*np.conjugate(gq00)) - - @njit(nogil=True, cache=True, inline='always') - def vfunc(gp, gq, W, V): - gp00 = gp[0] - gq00 = gq[0] - W0 = W[0] - v00 = V[0] - return W0*gq00*v00*np.conjugate(gp00) - - return vfunc, wfunc - - # Full mode - elif jones.ndim == 6: - raise NotImplementedError("Full polarisation imaging not yet supported") + return vis, wgt - else: - raise ValueError("jones array has an unsupported number of dimensions") + +@njit(nogil=True, cache=True, inline='always') +def wgt_func(gp, gq, W): + gp00 = gp[0] + gq00 = gq[0] + W0 = W[0] + return np.real(W0*gp00*gq00*np.conjugate(gp00)*np.conjugate(gq00)) + +@njit(nogil=True, cache=True, inline='always') +def vis_func(gp, gq, W, V): + gp00 = gp[0] + gq00 = gq[0] + W0 = W[0] + v00 = V[0] + return W0*gq00*v00*np.conjugate(gp00) diff --git a/pfb/utils/stokes.py b/pfb/utils/stokes.py index 4b80fd43d..3a04fae1a 100644 --- a/pfb/utils/stokes.py +++ b/pfb/utils/stokes.py @@ -1,6 +1,7 @@ import numpy as np import numexpr as ne from numba import generated_jit, njit, prange +from numba.extending import overload from numba.types import literal from dask.graph_manipulation import clone import dask.array as da @@ -9,6 +10,7 @@ from quartical.utils.numba import coerce_literal from operator import getitem from pfb.utils.beam import interp_beam +from pfb.utils.misc import JIT_OPTIONS import dask from quartical.utils.dask import Blocker @@ -384,7 +386,7 @@ def stokes_funcs(data, jones, product, pol, nc): sm.simplify(sm.expand(C[i]))) Djfn = njit(nogil=True, inline='always')(Dsymb) - @njit(nogil=True, cache=True, inline='always') + @njit(nogil=True, inline='always') def wfunc(gp, gq, W): gp00 = gp[0,0] gp01 = gp[0,1] @@ -402,7 +404,7 @@ def wfunc(gp, gq, W): gq00, gq01, gq10, gq11, W00, W01, W10, W11).real - @njit(nogil=True, cache=True, inline='always') + @njit(nogil=True, inline='always') def vfunc(gp, gq, W, V): gp00 = gp[0,0] gp01 = gp[0,1] @@ -439,7 +441,7 @@ def vfunc(gp, gq, W, V): gq00, gq11, w0, w1, w2, w3), sm.simplify(sm.expand(W[i,i]))) - Wjfn = njit(nogil=True, cache=True, inline='always')(Wsymb) + Wjfn = njit(nogil=True, inline='always')(Wsymb) Dsymb = lambdify((gp00, gp11, @@ -450,7 +452,7 @@ def vfunc(gp, gq, W, V): Djfn = njit(nogil=True, cache=True, inline='always')(Dsymb) if nc==literal('4'): - @njit(nogil=True, cache=True, inline='always') + @njit(nogil=True, inline='always') def wfunc(gp, gq, W): gp00 = gp[0] gp11 = gp[1] @@ -464,7 +466,7 @@ def wfunc(gp, gq, W): gq00, gq11, W00, W01, W10, W11).real - @njit(nogil=True, cache=True, inline='always') + @njit(nogil=True, inline='always') def vfunc(gp, gq, W, V): gp00 = gp[0] gp11 = gp[1] @@ -483,7 +485,7 @@ def vfunc(gp, gq, W, V): W00, W01, W10, W11, V00, V01, V10, V11) elif nc==literal('2'): - @njit(nogil=True, cache=True, inline='always') + @njit(nogil=True, inline='always') def wfunc(gp, gq, W): gp00 = gp[0] gp11 = gp[1] @@ -497,7 +499,7 @@ def wfunc(gp, gq, W): gq00, gq11, W00, W01, W10, W11).real - @njit(nogil=True, cache=True, inline='always') + @njit(nogil=True, inline='always') def vfunc(gp, gq, W, V): gp00 = gp[0] gp11 = gp[1] @@ -530,3 +532,20 @@ def vfunc(gp, gq, W, V): # quit() return vfunc, wfunc + + +# @njit(**JIT_OPTIONS) +# def _weight_data(data, weight, flag, jones, tbin_idx, tbin_counts, +# ant1, ant2, pol, product, nc): +# return _weight_data_impl(data, weight, flag, jones, tbin_idx, tbin_counts, +# ant1, ant2, pol, product, nc) + +# def _weight_data_impl(data, weight, flag, jones, tbin_idx, tbin_counts, +# ant1, ant2, pol, product, nc): +# return NotImplementedError + +# @overload(_weight_data_impl, jit_options=JIT_OPTIONS) +# def nb_weight_data_impl(data, weight, flag, jones, tbin_idx, tbin_counts, +# ant1, ant2, pol, product, nc): + +# return