Skip to content

Commit

Permalink
allow passing in list of measurement sets from cli
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Nov 14, 2023
1 parent 0902433 commit a283959
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 40 deletions.
2 changes: 1 addition & 1 deletion pfb/parser/init.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
inputs:
ms:
dtype: str
dtype: List[str]
required: true
abbreviation: ms
info:
Expand Down
90 changes: 76 additions & 14 deletions pfb/utils/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,51 +2,100 @@
from functools import partial
from katbeam import JimBeam
import dask.array as da
from numba.core.errors import NumbaDeprecationWarning
import warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=NumbaDeprecationWarning)
from africanus.rime.fast_beam_cubes import beam_cube_dde
from africanus.rime import parallactic_angles

def Id_beam(l, m):
return np.ones(l.shape, dtype=float)

def _interp_beam_impl(freq, nx, ny, cell_deg, btype):
def _interp_beam_impl(freq, nx, ny, cell_deg, btype,
utime=None, ant_pos=None, phase_dir=None):
'''
A function that returns an object array containing a function
returning beam values given (l,m) coordinates at a single frequency.
Frequency mapped to imaging band extenally. Result is meant to be
passed into eval_beam below.
'''
l = (-(nx//2) + np.arange(nx)) * cell_deg
m = (-(ny//2) + np.arange(ny)) * cell_deg
ll, mm = np.meshgrid(l, m, indexing='ij')

if isinstance(freq, np.ndarray):
assert freq.size == 1
freq = freq[0]
if btype is None:
beam = Id_beam
beam = np.ones((nx, ny), dtype=float)
elif btype.endswith('.npz'):
# these are expected to be in the format given here
# https://archive-gw-1.kat.ac.za/public/repository/10.48479/wdb0-h061/index.html
dct = np.load(btype)
beam = dct['abeam']
l = np.deg2rad(dct['ldeg'])
m = np.deg2rad(dct['mdeg'])
ll, mm = np.meshgrid(l, m, indexing='ij')
lm = np.vstack((ll.flatten(), mm.flatten())).T
beam_extents = np.array([[l.min(), l.max()], [m.min(), m.max()]])
bfreqs = dct['freq']
beam_amp = (beam[0, :, :, :] * beam[0, :, :, :].conj() +
beam[-1, :, :, :] * beam[-1, :, :, :].conj())/2.0
beam_amp = np.transpose(beam_amp, (1,2,0))[:, :, :, None, None].real
else:
btype = btype.lower()
btype = btype.replace('-', '_')
l = (-(nx//2) + np.arange(nx)) * cell_deg
m = (-(ny//2) + np.arange(ny)) * cell_deg
ll, mm = np.meshgrid(l, m, indexing='ij')
if btype in ["kbl", "kb_l", "katbeam_l"]:
# katbeam L band
beam = partial(JimBeam('MKAT-AA-L-JIM-2020').I, freqMHz=freq)
beam_amp = JimBeam('MKAT-AA-L-JIM-2020').I(ll.flatten(),
mm.flatten(),
freqMHz=freq/1e6)
elif btype in ["kbuhf", "kb_uhf", "katbeam_uhf"]:
# katbeam L band
beam = partial(JimBeam('MKAT-AA-UHF-JIM-2020').I, freqMHz=freq)
beam_amp = JimBeam('MKAT-AA-UHF-JIM-2020').I(ll.flatten(),
mm.flatten(),
freqMHz=freq/1e6)
else:
raise ValueError(f"Unknown beam model {btype}")
return np.array([beam], dtype=object)
beam_amp = beam_amp[:, :, None, None, None]

parangles = parallactic_angles(utime, ant_pos, phase_dir, backend='')
# mean over antanna nant -> 1
parangles = np.mean(parangles, axis=1, keepdims=True)
nant = 1
# beam_cube_dde requirements
nband = 1
ntimes = utime.size
ant_scale = np.ones((nant, nband, 2), dtype=np.float64)
point_errs = np.zeros((ntimes, nant, nband, 2), dtype=np.float64)
beam_image = beam_cube_dde(np.ascontiguousarray(beam_amp),
beam_extents, bfreqs,
lm, parangles, point_errs,
ant_scale, np.array((freq,))).squeeze()
return beam_image


def interp_beam(freq, nx, ny, cell_deg, btype):
def interp_beam(freq, nx, ny, cell_deg, btype,
utime=None, ant_pos=None, phase_dir=None):
'''
Blockwise wrapper that returns an object array containing a function
returning beam values given (l,m) coordinates at a single frequency.
Frequency mapped to imaging band extenally. Result is meant to be
passed into eval_beam below.
'''
return da.blockwise(_interp_beam_impl, '1',
if btype.endwith('.npz'):
dct = np.load(btype)
nx = dct['ldeg'].size
ny = dct['mdeg'].size

return da.blockwise(_interp_beam_impl, 'xy',
freq, None,
nx, None,
ny, None,
cell_deg, None,
btype, None,
new_axes={'1': 1},
utime, None,
ant_pos, None,
phase_dir, None,
new_axes={'x': nx, 'y': ny},
dtype=object)

def _eval_beam_impl(beam_object_array, l, m):
Expand All @@ -67,3 +116,16 @@ def eval_beam(beam_object_array, l, m):
l, lout,
m, mout,
dtype=float)


def get_beam_meta(msname):
from pyrap.tables import table
ms = table(msname)
time = ms.getcol('TIME')
utime = np.unique(time)
field = table(f'{msname}::FIELD')
phase_dir = field.getcol('PHASE_DIR').squeeze()
ant = table(f'{msname}::ANTENNA')
ant_pos = ant.getcol('POSITION')

return utime, phase_dir, ant_pos
4 changes: 3 additions & 1 deletion pfb/utils/stokes.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def single_stokes(ds=None,
vis = output_dict['vis']
wgt = output_dict['wgt']

import ipdb; ipdb.set_trace()

if isinstance(opts.radec, str):
raise NotImplementedError()
elif isinstance(opts.radec, np.ndarray) and not np.array_equal(radec, opts.radec):
Expand Down Expand Up @@ -217,7 +219,7 @@ def single_stokes(ds=None,
# https://www.overleaf.com/read/yzrsrdwxhxrd
npix = int(np.deg2rad(opts.max_field_of_view)/cell_rad)
beam = interp_beam(freq_out/1e6, npix, npix, np.rad2deg(cell_rad), opts.beam_model)
data_vars['BEAM'] = (('scalar'), beam)
data_vars['BEAM'] = (('x','y'), beam)

coords = {'chan': (('chan',), freq)} #,
# 'row': (('row',), ds.ROWID.values)}
Expand Down
2 changes: 1 addition & 1 deletion pfb/workers/fwdbwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def hesspsi(x):

# get clean beam area to convert residual units during l1reweighting
# TODO - could refine this with comparison between dirty and restored
# if contiuing the deconvolution
# if continuing the deconvolution
GaussPar = fitcleanbeam(psf_mfs[None], level=0.5, pixsize=1.0)[0]
pix_per_beam = GaussPar[0]*GaussPar[1]*np.pi/4
print(f"Number of pixels per beam estimated as {pix_per_beam}",
Expand Down
38 changes: 21 additions & 17 deletions pfb/workers/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,28 @@ def init(**kw):
pyscilog.log_to_file(f'{ldir}/init_{timestamp}.log')
print(f'Logs will be written to {str(ldir)}/init_{timestamp}.log', file=log)
from daskms.fsspec_store import DaskMSStore
msstore = DaskMSStore(opts.ms.rstrip('/'))
ms = msstore.fs.glob(opts.ms.rstrip('/'))
try:
assert len(ms) > 0
opts.ms = list(map(msstore.fs.unstrip_protocol, ms))
except:
raise ValueError(f"No MS at {opts.ms}")

if opts.gain_table is not None:
gainstore = DaskMSStore(opts.gain_table.rstrip('/'))
gt = gainstore.fs.glob(opts.gain_table.rstrip('/'))
msnames = []
for ms in opts.ms:
msstore = DaskMSStore(ms.rstrip('/'))
mslist = msstore.fs.glob(ms.rstrip('/'))
try:
assert len(gt) > 0
opts.gain_table = list(map(gainstore.fs.unstrip_protocol, gt))
except Exception as e:
raise ValueError(f"No gain table at {opts.gain_table}")

if opts.product.upper() not in ["I"]:
assert len(mslist) > 0
msnames.append(*list(map(msstore.fs.unstrip_protocol, mslist)))
except:
raise ValueError(f"No MS at {ms}")
opts.ms = msnames
if opts.gain_table is not None:
gainnames = []
for gt in opts.gain_table:
gainstore = DaskMSStore(gt.rstrip('/'))
gtlist = gainstore.fs.glob(gt.rstrip('/'))
try:
assert len(gtlist) > 0
gainnames.append(*list(map(gainstore.fs.unstrip_protocol, gt)))
except Exception as e:
raise ValueError(f"No gain table at {gt}")
opts.gain_table = gainnames
if opts.product.upper() not in ["I","Q"]:
# , "Q", "U", "V", "XX", "YX", "XY",
# "YY", "RR", "RL", "LR", "LL"]:
raise NotImplementedError(f"Product {opts.product} not yet supported")
Expand Down
10 changes: 5 additions & 5 deletions pfb/workers/spotless.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,11 @@ def _spotless(ddsi=None, **kw):
print('Solving for model', file=log)
modelp = deepcopy(model)
data = residual + psf_convolve(model)
# grad21 = lambda x: psf_convolve(x) - data
def grad21(x):
res = psf_convolve(x) - data
res[fsel] *= sfactor
return res
grad21 = lambda x: psf_convolve(x) - data
# def grad21(x):
# res = psf_convolve(x) - data
# res[fsel] *= sfactor
# return res
model, dual = primal_dual(model,
dual,
opts.rmsfactor*rms,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

"QuartiCal[degrid]"
"@git+https://github.com/ratt-ru/QuartiCal.git"
"@v0.2.1-degridder"
"@bandpass_smoothing"
# "stimela"
# "@git+https://github.com/caracal-pipeline/stimela.git"
# "@FIASCO3"
Expand Down

0 comments on commit a283959

Please sign in to comment.