Skip to content

Commit

Permalink
add option to drop nearly null bands
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Aug 2, 2024
1 parent 523fdc4 commit 2de08b1
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 80 deletions.
48 changes: 10 additions & 38 deletions pfb/parser/hci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ inputs:
field-of-view:
dtype: float
abbreviation: fov
default: 3.0
default: 1.0
info:
Field of view in degrees
super-resolution-factor:
dtype: float
default: 2
default: 1.4
abbreviation: srf
info:
Will over-sample Nyquist by this factor at max frequency
Expand All @@ -121,21 +121,6 @@ inputs:
abbreviation: rob
info:
Robustness factor for Briggs weighting. None means natural
filter-extreme-counts:
dtype: bool
default: false
info:
Replace extreme outliers in the weighting grid by their local mean
filter-nbox:
dtype: int
default: 16
info:
The size of the box to compute local mean over
filter-level:
dtype: float
default: 10
info:
The level above local mean that renders a grid weight extreme
target:
dtype: str
info:
Expand All @@ -149,27 +134,6 @@ inputs:
A sensible value for this parameter depends on the level of RFI in the data.
Small values (eg. 2) result in aggressive reweighting and should be avoided
if the model is still incomplete.
host-address:
dtype: str
abbreviation: ha
info:
Address where the distributed client lives.
Uses LocalCluster if no address is provided
and scheduler is set to distributed.
nworkers:
dtype: int
default: 1
abbreviation: nw
info:
Number of worker processes.
Use with distributed scheduler.
nvthreads:
dtype: int
abbreviation: nvt
info:
Number of threads used to scale vertically (eg. for FFTs and gridding).
Each dask thread can in principle spawn this many threads.
Will attempt to use half the available threads by default.
progressbar:
dtype: bool
default: true
Expand All @@ -180,10 +144,18 @@ inputs:
default: zarr
info:
zarr or fits output
sigmainvsq:
dtype: float
default: 1e-5
info:
Value to add to Hessian to make it invertible.
Smaller values tend to fit more noise and make the inversion less stable.

_include:
- (.)gridding.yml
- (.)out.yml
- (.)dist.yml
- (.)cgopts.yml

outputs:
{}
7 changes: 5 additions & 2 deletions pfb/parser/restore.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,16 @@ inputs:
The default resolution is the native resolution in each imaging band.
This parameter can be used to homogenise the resolution of the cubes.
Set to (0,0,0) to use the resolution of the lowest band.
inflate_factor:
inflate-factor:
dtype: float
default: 1.5
info:
Inflate the intrinsic resolution of the uniformly blurred image by
this amount.

drop-bands:
dtype: List[int]
info:
List of bands to discard

outputs:
{}
10 changes: 5 additions & 5 deletions pfb/utils/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def set_wcs(cell_x, cell_y, nx, ny, radec, freq,
crpix3 = 1
w.wcs.crval = [radec[0]*180.0/np.pi, radec[1]*180.0/np.pi, ref_freq, 1]
# LB - y axis treated differently because of stupid fits convention?
w.wcs.crpix = [1 + nx//2, ny//2, crpix3, 1]
w.wcs.crpix = [1 + nx//2,1 + ny//2, crpix3, 1]

if np.size(freq) > 1:
w.wcs.crval[2] = freq[0]
Expand Down Expand Up @@ -176,12 +176,12 @@ def dds2fits(dsl, column, outname, norm_wsum=True,
cube = np.zeros((nband, nx, ny))
wsums = np.zeros(nband)
wsum = 0.0
for ds in dds:
for i, ds in enumerate(dds):
if ds.timeid == timeid:
b = int(ds.bandid)
cube[b] = ds.get(column).values
wsums[b] = ds.wsum
wsum += wsums[b]
cube[i] = ds.get(column).values
wsums[i] = ds.wsum
wsum += wsums[i]
radec = (ds.ra, ds.dec)
cell_deg = np.rad2deg(ds.cell_rad)
nx, ny = ds.get(column).shape
Expand Down
84 changes: 50 additions & 34 deletions pfb/utils/stokes2im.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,34 +184,6 @@ def single_stokes_image(

mask = (~flag).astype(np.uint8)


# TODO - we may want to apply this to bigger chunks of data
if opts.robustness is not None:
counts = _compute_counts(
uvw,
freq,
mask,
nx,
ny,
cell_rad,
cell_rad,
np.float64, # same type as uvw
weight,
ngrid=opts.nthreads)

# counts will be accumulated on nvthreads grids in parallel
# so we need the sum here
counts = counts.sum(axis=0)

# get rid of artificially high weights corresponding to
# nearly empty cells
if opts.filter_extreme_counts:
counts = _filter_extreme_counts(counts,
nbox=opts.filter_nbox,
nlevel=opts.filter_level)



if mds is not None:
nband = fbin_idx.size
model = np.zeros((nband, mds.npix_x, mds.npix_y), dtype=real_type)
Expand Down Expand Up @@ -273,16 +245,22 @@ def single_stokes_image(
weight = None

if opts.robustness is not None:
if counts is None:
raise ValueError('counts are None but robustness specified. '
'This is probably a bug!')
imwgt = _counts_to_weights(
counts = _compute_counts(uvw,
freq,
mask,
weight,
nx, ny,
cellx, celly,
uvw.dtype,
ngrid=1)

imwgt = counts_to_weights(
counts,
uvw,
freq,
nx, ny,
cell_rad, cell_rad,
opts.robustness)
cellx, celly,
robustness)
if weight is not None:
weight *= imwgt
else:
Expand Down Expand Up @@ -310,6 +288,42 @@ def single_stokes_image(

rms = np.std(residual/wsum)

if opts.natural_grad:
from pfb.opt.pcg import pcg
from pfb.operators.hessian import _hessian_impl
from functools import partial

hess = partial(_hessian_impl,
uvw=uvw,
weight=weight,
vis_mask=mask,
freq=freq,
beam=None,
cell=cell_rad,
x0=x0,
y0=y0,
do_wgridding=opts.do_wgridding,
epsilon=opts.epsilon,
double_accum=opts.double_accum,
nthreads=opts.nthreads,
sigmainvsq=opts.sigmainvsq*wsum,
wsum=1.0) # we haven't normalised residual

x = pcg(hess,
residual,
x0=np.zeros_like(residual),
# M=precond,
tol=opts.cg_tol,
minit=1,
tol=opts.cg_tol,
maxit=opts.cg_maxit,
verbosity=opts.cg_verbose,
report_freq=opts.cg_report_freq,
backtrack=False,
return_resid=False)
else:
x = None

unix_time = quantity(f'{time_out}s').to_unix_time()
utc = datetime.utcfromtimestamp(unix_time).strftime('%Y-%m-%d %H:%M:%S')

Expand All @@ -322,6 +336,8 @@ def single_stokes_image(
if opts.output_format == 'zarr':
data_vars = {}
data_vars['RESIDUAL'] = (('x', 'y'), residual.astype(np.float32))
if x is not None:
data_vars['NATGRAD'] = (('x', 'y'), residual.astype(np.float32))

coords = {'chan': (('chan',), freq),
'time': (('time',), utime),
Expand Down
10 changes: 10 additions & 0 deletions pfb/workers/restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ def _restore(**kw):
dds_list = dds_store.fs.glob(f'{dds_store.url}/*.zarr')
dds = xds_from_url(dds_store.url)

ddso = []
ddso_list = []
for ds, dsl in zip(dds, dds_list):
b = int(ds.bandid)
if b not in opts.drop_bands:
ddso.append(ds)
ddso_list.append(dsl)
dds = ddso
dds_list = ddso_list

freq_out = []
time_out = []
for ds in dds:
Expand Down
12 changes: 11 additions & 1 deletion pfb/workers/sara.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,6 @@ def _sara(ddsi=None, **kw):
# we need an array to put the components in for reweighting
outvar = np.zeros((nband, nbasis, Nymax, Nxmax), dtype=real_type)

# TODO - should we cache this?
dual = np.zeros((nband, nbasis, Nymax, Nxmax), dtype=residual.dtype)
if l1reweight_from == 0:
print('Initialising with L1 reweighted', file=log)
Expand All @@ -319,6 +318,12 @@ def _sara(ddsi=None, **kw):
tmp = np.sum(outvar, axis=0)
# exclude zeros from padding DWT's
rms_comps = np.std(tmp[tmp!=0])
print(f'rms_comps updated to {rms_comps}', file=log)
for i, base in enumerate(bases):
tmpb = tmp[i]
print(f'rms for base {base} is {np.std(tmpb[tmpb!=0])}',
file=log)
import ipdb; ipdb.set_trace()
reweighter = partial(l1reweight_func,
psiH=psi.dot,
outvar=outvar,
Expand Down Expand Up @@ -493,6 +498,11 @@ def _sara(ddsi=None, **kw):
tmp = np.sum(outvar, axis=0)
# exclude zeros from padding DWT's
rms_comps = np.std(tmp[tmp!=0])
print(f'rms_comps updated to {rms_comps}', file=log)
for i, base in enumerate(bases):
tmpb = tmp[i]
print(f'rms for base {base} is {np.std(tmpb[tmpb!=0])}',
file=log)
reweighter = partial(l1reweight_func,
psiH=psi.dot,
outvar=outvar,
Expand Down

0 comments on commit 2de08b1

Please sign in to comment.