Skip to content

Commit

Permalink
don't cache lambdified functions
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Feb 21, 2024
1 parent 1889d09 commit 4afef38
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 65 deletions.
97 changes: 39 additions & 58 deletions pfb/utils/correlations.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,65 +188,46 @@ def _weight_data(data, weight, jones, tbin_idx, tbin_counts,
return _weight_data_impl(data[0], weight[0], jones[0][0][0],
tbin_idx, tbin_counts, ant1, ant2)

@generated_jit(nopython=True, nogil=True, cache=True)
@njit(nopython=True, nogil=True, cache=True)
def _weight_data_impl(data, weight, jones, tbin_idx, tbin_counts,
ant1, ant2):
# for dask arrays we need to adjust the chunks to
# start counting from zero
tbin_idx -= tbin_idx.min()
nt = np.shape(tbin_idx)[0]
nrow, nchan, ncorr = data.shape
vis = np.zeros((nrow, nchan), dtype=data.dtype)
wgt = np.zeros((nrow, nchan), dtype=data.real.dtype)

for t in range(nt):
for row in range(tbin_idx[t],
tbin_idx[t] + tbin_counts[t]):
p = int(ant1[row])
q = int(ant2[row])
gp = jones[t, p, :, 0]
gq = jones[t, q, :, 0]
for chan in range(nchan):
wval = wgt_func(gp[chan], gq[chan],
weight[row, chan])
wgt[row, chan] = wval
vis[row, chan] = vis_func(gp[chan], gq[chan],
weight[row, chan],
data[row, chan]) #/wval

vis_func, wgt_func = corr_funcs(data, jones)

def _impl(data, weight, jones, tbin_idx, tbin_counts,
ant1, ant2):
# for dask arrays we need to adjust the chunks to
# start counting from zero
tbin_idx -= tbin_idx.min()
nt = np.shape(tbin_idx)[0]
nrow, nchan, ncorr = data.shape
vis = np.zeros((nrow, nchan), dtype=data.dtype)
wgt = np.zeros((nrow, nchan), dtype=data.real.dtype)

for t in range(nt):
for row in range(tbin_idx[t],
tbin_idx[t] + tbin_counts[t]):
p = int(ant1[row])
q = int(ant2[row])
gp = jones[t, p, :, 0]
gq = jones[t, q, :, 0]
for chan in range(nchan):
wval = wgt_func(gp[chan], gq[chan],
weight[row, chan])
wgt[row, chan] = wval
vis[row, chan] = vis_func(gp[chan], gq[chan],
weight[row, chan],
data[row, chan]) #/wval

return vis, wgt
return _impl


def corr_funcs(data, jones):
# The expressions for DIAG_DIAG and DIAG mode are essentially the same
if jones.ndim == 5:
# I and Q have identical weights
@njit(nogil=True, cache=True, inline='always')
def wfunc(gp, gq, W):
gp00 = gp[0]
gq00 = gq[0]
W0 = W[0]
return np.real(W0*gp00*gq00*np.conjugate(gp00)*np.conjugate(gq00))

@njit(nogil=True, cache=True, inline='always')
def vfunc(gp, gq, W, V):
gp00 = gp[0]
gq00 = gq[0]
W0 = W[0]
v00 = V[0]
return W0*gq00*v00*np.conjugate(gp00)

return vfunc, wfunc

# Full mode
elif jones.ndim == 6:
raise NotImplementedError("Full polarisation imaging not yet supported")
return vis, wgt

else:
raise ValueError("jones array has an unsupported number of dimensions")

@njit(nogil=True, cache=True, inline='always')
def wgt_func(gp, gq, W):
gp00 = gp[0]
gq00 = gq[0]
W0 = W[0]
return np.real(W0*gp00*gq00*np.conjugate(gp00)*np.conjugate(gq00))

@njit(nogil=True, cache=True, inline='always')
def vis_func(gp, gq, W, V):
gp00 = gp[0]
gq00 = gq[0]
W0 = W[0]
v00 = V[0]
return W0*gq00*v00*np.conjugate(gp00)
33 changes: 26 additions & 7 deletions pfb/utils/stokes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import numexpr as ne
from numba import generated_jit, njit, prange
from numba.extending import overload
from numba.types import literal
from dask.graph_manipulation import clone
import dask.array as da
Expand All @@ -9,6 +10,7 @@
from quartical.utils.numba import coerce_literal
from operator import getitem
from pfb.utils.beam import interp_beam
from pfb.utils.misc import JIT_OPTIONS
import dask
from quartical.utils.dask import Blocker

Expand Down Expand Up @@ -384,7 +386,7 @@ def stokes_funcs(data, jones, product, pol, nc):
sm.simplify(sm.expand(C[i])))
Djfn = njit(nogil=True, inline='always')(Dsymb)

@njit(nogil=True, cache=True, inline='always')
@njit(nogil=True, inline='always')
def wfunc(gp, gq, W):
gp00 = gp[0,0]
gp01 = gp[0,1]
Expand All @@ -402,7 +404,7 @@ def wfunc(gp, gq, W):
gq00, gq01, gq10, gq11,
W00, W01, W10, W11).real

@njit(nogil=True, cache=True, inline='always')
@njit(nogil=True, inline='always')
def vfunc(gp, gq, W, V):
gp00 = gp[0,0]
gp01 = gp[0,1]
Expand Down Expand Up @@ -439,7 +441,7 @@ def vfunc(gp, gq, W, V):
gq00, gq11,
w0, w1, w2, w3),
sm.simplify(sm.expand(W[i,i])))
Wjfn = njit(nogil=True, cache=True, inline='always')(Wsymb)
Wjfn = njit(nogil=True, inline='always')(Wsymb)


Dsymb = lambdify((gp00, gp11,
Expand All @@ -450,7 +452,7 @@ def vfunc(gp, gq, W, V):
Djfn = njit(nogil=True, cache=True, inline='always')(Dsymb)

if nc==literal('4'):
@njit(nogil=True, cache=True, inline='always')
@njit(nogil=True, inline='always')
def wfunc(gp, gq, W):
gp00 = gp[0]
gp11 = gp[1]
Expand All @@ -464,7 +466,7 @@ def wfunc(gp, gq, W):
gq00, gq11,
W00, W01, W10, W11).real

@njit(nogil=True, cache=True, inline='always')
@njit(nogil=True, inline='always')
def vfunc(gp, gq, W, V):
gp00 = gp[0]
gp11 = gp[1]
Expand All @@ -483,7 +485,7 @@ def vfunc(gp, gq, W, V):
W00, W01, W10, W11,
V00, V01, V10, V11)
elif nc==literal('2'):
@njit(nogil=True, cache=True, inline='always')
@njit(nogil=True, inline='always')
def wfunc(gp, gq, W):
gp00 = gp[0]
gp11 = gp[1]
Expand All @@ -497,7 +499,7 @@ def wfunc(gp, gq, W):
gq00, gq11,
W00, W01, W10, W11).real

@njit(nogil=True, cache=True, inline='always')
@njit(nogil=True, inline='always')
def vfunc(gp, gq, W, V):
gp00 = gp[0]
gp11 = gp[1]
Expand Down Expand Up @@ -530,3 +532,20 @@ def vfunc(gp, gq, W, V):
# quit()

return vfunc, wfunc


# @njit(**JIT_OPTIONS)
# def _weight_data(data, weight, flag, jones, tbin_idx, tbin_counts,
# ant1, ant2, pol, product, nc):
# return _weight_data_impl(data, weight, flag, jones, tbin_idx, tbin_counts,
# ant1, ant2, pol, product, nc)

# def _weight_data_impl(data, weight, flag, jones, tbin_idx, tbin_counts,
# ant1, ant2, pol, product, nc):
# return NotImplementedError

# @overload(_weight_data_impl, jit_options=JIT_OPTIONS)
# def nb_weight_data_impl(data, weight, flag, jones, tbin_idx, tbin_counts,
# ant1, ant2, pol, product, nc):

# return

0 comments on commit 4afef38

Please sign in to comment.