diff --git a/pfb/parser/hci.yaml b/pfb/parser/hci.yaml index 09c0d9965..f21377210 100644 --- a/pfb/parser/hci.yaml +++ b/pfb/parser/hci.yaml @@ -95,12 +95,12 @@ inputs: field-of-view: dtype: float abbreviation: fov - default: 3.0 + default: 1.0 info: Field of view in degrees super-resolution-factor: dtype: float - default: 2 + default: 1.4 abbreviation: srf info: Will over-sample Nyquist by this factor at max frequency @@ -121,21 +121,6 @@ inputs: abbreviation: rob info: Robustness factor for Briggs weighting. None means natural - filter-extreme-counts: - dtype: bool - default: false - info: - Replace extreme outliers in the weighting grid by their local mean - filter-nbox: - dtype: int - default: 16 - info: - The size of the box to compute local mean over - filter-level: - dtype: float - default: 10 - info: - The level above local mean that renders a grid weight extreme target: dtype: str info: @@ -149,27 +134,6 @@ inputs: A sensible value for this parameter depends on the level of RFI in the data. Small values (eg. 2) result in aggressive reweighting and should be avoided if the model is still incomplete. - host-address: - dtype: str - abbreviation: ha - info: - Address where the distributed client lives. - Uses LocalCluster if no address is provided - and scheduler is set to distributed. - nworkers: - dtype: int - default: 1 - abbreviation: nw - info: - Number of worker processes. - Use with distributed scheduler. - nvthreads: - dtype: int - abbreviation: nvt - info: - Number of threads used to scale vertically (eg. for FFTs and gridding). - Each dask thread can in principle spawn this many threads. - Will attempt to use half the available threads by default. progressbar: dtype: bool default: true @@ -180,10 +144,18 @@ inputs: default: zarr info: zarr or fits output + sigmainvsq: + dtype: float + default: 1e-5 + info: + Value to add to Hessian to make it invertible. + Smaller values tend to fit more noise and make the inversion less stable. _include: - (.)gridding.yml - (.)out.yml + - (.)dist.yml + - (.)cgopts.yml outputs: {} diff --git a/pfb/parser/restore.yaml b/pfb/parser/restore.yaml index 502b59447..9f1e7adbf 100644 --- a/pfb/parser/restore.yaml +++ b/pfb/parser/restore.yaml @@ -40,13 +40,16 @@ inputs: The default resolution is the native resolution in each imaging band. This parameter can be used to homogenise the resolution of the cubes. Set to (0,0,0) to use the resolution of the lowest band. - inflate_factor: + inflate-factor: dtype: float default: 1.5 info: Inflate the intrinsic resolution of the uniformly blurred image by this amount. - + drop-bands: + dtype: List[int] + info: + List of bands to discard outputs: {} diff --git a/pfb/utils/fits.py b/pfb/utils/fits.py index c65610492..b69c8a79f 100644 --- a/pfb/utils/fits.py +++ b/pfb/utils/fits.py @@ -58,7 +58,7 @@ def set_wcs(cell_x, cell_y, nx, ny, radec, freq, crpix3 = 1 w.wcs.crval = [radec[0]*180.0/np.pi, radec[1]*180.0/np.pi, ref_freq, 1] # LB - y axis treated differently because of stupid fits convention? - w.wcs.crpix = [1 + nx//2, ny//2, crpix3, 1] + w.wcs.crpix = [1 + nx//2,1 + ny//2, crpix3, 1] if np.size(freq) > 1: w.wcs.crval[2] = freq[0] @@ -176,12 +176,12 @@ def dds2fits(dsl, column, outname, norm_wsum=True, cube = np.zeros((nband, nx, ny)) wsums = np.zeros(nband) wsum = 0.0 - for ds in dds: + for i, ds in enumerate(dds): if ds.timeid == timeid: b = int(ds.bandid) - cube[b] = ds.get(column).values - wsums[b] = ds.wsum - wsum += wsums[b] + cube[i] = ds.get(column).values + wsums[i] = ds.wsum + wsum += wsums[i] radec = (ds.ra, ds.dec) cell_deg = np.rad2deg(ds.cell_rad) nx, ny = ds.get(column).shape diff --git a/pfb/utils/stokes2im.py b/pfb/utils/stokes2im.py index 2d20e09e0..f3e52d636 100644 --- a/pfb/utils/stokes2im.py +++ b/pfb/utils/stokes2im.py @@ -184,34 +184,6 @@ def single_stokes_image( mask = (~flag).astype(np.uint8) - - # TODO - we may want to apply this to bigger chunks of data - if opts.robustness is not None: - counts = _compute_counts( - uvw, - freq, - mask, - nx, - ny, - cell_rad, - cell_rad, - np.float64, # same type as uvw - weight, - ngrid=opts.nthreads) - - # counts will be accumulated on nvthreads grids in parallel - # so we need the sum here - counts = counts.sum(axis=0) - - # get rid of artificially high weights corresponding to - # nearly empty cells - if opts.filter_extreme_counts: - counts = _filter_extreme_counts(counts, - nbox=opts.filter_nbox, - nlevel=opts.filter_level) - - - if mds is not None: nband = fbin_idx.size model = np.zeros((nband, mds.npix_x, mds.npix_y), dtype=real_type) @@ -273,16 +245,22 @@ def single_stokes_image( weight = None if opts.robustness is not None: - if counts is None: - raise ValueError('counts are None but robustness specified. ' - 'This is probably a bug!') - imwgt = _counts_to_weights( + counts = _compute_counts(uvw, + freq, + mask, + weight, + nx, ny, + cellx, celly, + uvw.dtype, + ngrid=1) + + imwgt = counts_to_weights( counts, uvw, freq, nx, ny, - cell_rad, cell_rad, - opts.robustness) + cellx, celly, + robustness) if weight is not None: weight *= imwgt else: @@ -310,6 +288,42 @@ def single_stokes_image( rms = np.std(residual/wsum) + if opts.natural_grad: + from pfb.opt.pcg import pcg + from pfb.operators.hessian import _hessian_impl + from functools import partial + + hess = partial(_hessian_impl, + uvw=uvw, + weight=weight, + vis_mask=mask, + freq=freq, + beam=None, + cell=cell_rad, + x0=x0, + y0=y0, + do_wgridding=opts.do_wgridding, + epsilon=opts.epsilon, + double_accum=opts.double_accum, + nthreads=opts.nthreads, + sigmainvsq=opts.sigmainvsq*wsum, + wsum=1.0) # we haven't normalised residual + + x = pcg(hess, + residual, + x0=np.zeros_like(residual), + # M=precond, + tol=opts.cg_tol, + minit=1, + tol=opts.cg_tol, + maxit=opts.cg_maxit, + verbosity=opts.cg_verbose, + report_freq=opts.cg_report_freq, + backtrack=False, + return_resid=False) + else: + x = None + unix_time = quantity(f'{time_out}s').to_unix_time() utc = datetime.utcfromtimestamp(unix_time).strftime('%Y-%m-%d %H:%M:%S') @@ -322,6 +336,8 @@ def single_stokes_image( if opts.output_format == 'zarr': data_vars = {} data_vars['RESIDUAL'] = (('x', 'y'), residual.astype(np.float32)) + if x is not None: + data_vars['NATGRAD'] = (('x', 'y'), residual.astype(np.float32)) coords = {'chan': (('chan',), freq), 'time': (('time',), utime), diff --git a/pfb/workers/restore.py b/pfb/workers/restore.py index 0fee0e912..d3aea5fed 100644 --- a/pfb/workers/restore.py +++ b/pfb/workers/restore.py @@ -102,6 +102,16 @@ def _restore(**kw): dds_list = dds_store.fs.glob(f'{dds_store.url}/*.zarr') dds = xds_from_url(dds_store.url) + ddso = [] + ddso_list = [] + for ds, dsl in zip(dds, dds_list): + b = int(ds.bandid) + if b not in opts.drop_bands: + ddso.append(ds) + ddso_list.append(dsl) + dds = ddso + dds_list = ddso_list + freq_out = [] time_out = [] for ds in dds: diff --git a/pfb/workers/sara.py b/pfb/workers/sara.py index cce9d35f1..241e23cff 100644 --- a/pfb/workers/sara.py +++ b/pfb/workers/sara.py @@ -308,7 +308,6 @@ def _sara(ddsi=None, **kw): # we need an array to put the components in for reweighting outvar = np.zeros((nband, nbasis, Nymax, Nxmax), dtype=real_type) - # TODO - should we cache this? dual = np.zeros((nband, nbasis, Nymax, Nxmax), dtype=residual.dtype) if l1reweight_from == 0: print('Initialising with L1 reweighted', file=log) @@ -319,6 +318,12 @@ def _sara(ddsi=None, **kw): tmp = np.sum(outvar, axis=0) # exclude zeros from padding DWT's rms_comps = np.std(tmp[tmp!=0]) + print(f'rms_comps updated to {rms_comps}', file=log) + for i, base in enumerate(bases): + tmpb = tmp[i] + print(f'rms for base {base} is {np.std(tmpb[tmpb!=0])}', + file=log) + import ipdb; ipdb.set_trace() reweighter = partial(l1reweight_func, psiH=psi.dot, outvar=outvar, @@ -493,6 +498,11 @@ def _sara(ddsi=None, **kw): tmp = np.sum(outvar, axis=0) # exclude zeros from padding DWT's rms_comps = np.std(tmp[tmp!=0]) + print(f'rms_comps updated to {rms_comps}', file=log) + for i, base in enumerate(bases): + tmpb = tmp[i] + print(f'rms for base {base} is {np.std(tmpb[tmpb!=0])}', + file=log) reweighter = partial(l1reweight_func, psiH=psi.dot, outvar=outvar,