Skip to content

Commit

Permalink
fix and improve sara test
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Aug 26, 2024
1 parent 415cdb7 commit 15021cd
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 111 deletions.
2 changes: 1 addition & 1 deletion pfb/operators/gridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def _comps2vis_impl(uvw,
y0 = mds.center_y
flip_u = mds.flip_u
flip_v = mds.flip_v
flip_w = mds.flip_v
flip_w = mds.flip_w
for t in range(ntime):
indt = slice(tbin_idx2[t], tbin_idx2[t] + tbin_cnts[t])
# TODO - clean up this logic. row_mapping holds the number of rows per
Expand Down
32 changes: 16 additions & 16 deletions pfb/operators/hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,10 @@ def hess_direct(x, # input image, not overwritten
mode='forward'):
nband, nx, ny = x.shape
xpad[...] = 0.0
if mode == 'forward':
xpad[:, 0:nx, 0:ny] = x / taperxy[None]
else:
xpad[:, 0:nx, 0:ny] = x * taperxy[None]
# if mode == 'forward':
# xpad[:, 0:nx, 0:ny] = x / taperxy[None]
# else:
xpad[:, 0:nx, 0:ny] = x * taperxy[None]
r2c(xpad, out=xhat, axes=(1,2),
forward=True, inorm=0, nthreads=nthreads)
if mode=='forward':
Expand All @@ -208,10 +208,10 @@ def hess_direct(x, # input image, not overwritten
lastsize=lastsize, inorm=2, nthreads=nthreads,
allow_overwriting_input=True)
xout[...] = xpad[:, 0:nx, 0:ny]
if mode=='forward':
xout /= taperxy[None]
else:
xout *= taperxy[None]
# if mode=='forward':
# xout /= taperxy[None]
# else:
xout *= taperxy[None]
return xout


Expand All @@ -228,10 +228,10 @@ def hess_direct_slice(x, # input image, not overwritten
mode='forward'):
nx, ny = x.shape
xpad[...] = 0.0
if mode == 'forward':
xpad[0:nx, 0:ny] = x / taperxy
else:
xpad[0:nx, 0:ny] = x * taperxy
# if mode == 'forward':
# xpad[0:nx, 0:ny] = x / taperxy
# else:
xpad[0:nx, 0:ny] = x * taperxy
r2c(xpad, out=xhat, axes=(0,1),
forward=True, inorm=0, nthreads=nthreads)
if mode=='forward':
Expand All @@ -242,8 +242,8 @@ def hess_direct_slice(x, # input image, not overwritten
lastsize=lastsize, inorm=2, nthreads=nthreads,
allow_overwriting_input=True)
xout[...] = xpad[0:nx, 0:ny]
if mode=='forward':
xout /= taperxy
else:
xout *= taperxy
# if mode=='forward':
# xout /= taperxy
# else:
xout *= taperxy
return xout
16 changes: 2 additions & 14 deletions pfb/workers/fluxmop.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def fluxmop(**kw):

print(f"All done after {time.time() - ti}s", file=log)

def _fluxmop(ddsi=None, **kw):
def _fluxmop(**kw):
opts = OmegaConf.create(kw)
OmegaConf.set_struct(opts, True)

Expand All @@ -175,19 +175,7 @@ def _fluxmop(ddsi=None, **kw):
dds_name = f'{basename}_{opts.suffix}.dds'
dds_store = DaskMSStore(dds_name)
dds_list = dds_store.fs.glob(f'{dds_store.url}/*.zarr')
if ddsi is not None:
dds = []
for ds in ddsi:
dds.append(ds.chunk({'row':-1,
'chan':-1,
'x':-1,
'y':-1,
'x_psf':-1,
'y_psf':-1,
'yo2':-1}))
else:
# are these sorted correctly?
dds = xds_from_url(dds_store.url)
dds = xds_from_url(dds_store.url)

nx, ny = dds[0].x.size, dds[0].y.size
nx_psf, ny_psf = dds[0].x_psf.size, dds[0].y_psf.size
Expand Down
3 changes: 2 additions & 1 deletion pfb/workers/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _grid(xdsi=None, **kw):
import fsspec
from daskms.fsspec_store import DaskMSStore
from pfb.utils.misc import set_image_size
from pfb.operators.gridder import image_data_products
from pfb.operators.gridder import image_data_products, wgridder_conventions
import xarray as xr
from pfb.utils.astrometry import get_coordinates
from africanus.coordinates import radec_to_lm
Expand Down Expand Up @@ -439,6 +439,7 @@ def _grid(xdsi=None, **kw):
# get the model
if opts.transfer_model_from:
from pfb.utils.misc import eval_coeffs_to_slice
_, _, _, x0, y0 = wgridder_conventions(l0, m0)
model = eval_coeffs_to_slice(
time_out,
freq_out,
Expand Down
22 changes: 5 additions & 17 deletions pfb/workers/klean.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def klean(**kw):
print(f"All done after {time.time() - ti}s", file=log)


def _klean(ddsi=None, **kw):
def _klean(**kw):
opts = OmegaConf.create(kw)
OmegaConf.set_struct(opts, True)

Expand All @@ -124,21 +124,9 @@ def _klean(ddsi=None, **kw):
dds_name = f'{basename}_{opts.suffix}.dds'
dds_store = DaskMSStore(dds_name)
dds_list = dds_store.fs.glob(f'{dds_store.url}/*.zarr')
if ddsi is not None:
dds = []
for ds in ddsi:
dds.append(ds.chunk({'row':-1,
'chan':-1,
'x':-1,
'y':-1,
'x_psf':-1,
'y_psf':-1,
'yo2':-1}))
else:
# are these sorted correctly?
drop_vars = ['UVW','WEIGHT','MASK']
dds = xds_from_list(dds_list, nthreads=opts.nthreads,
drop_vars=drop_vars)
drop_vars = ['UVW','WEIGHT','MASK']
dds = xds_from_list(dds_list, nthreads=opts.nthreads,
drop_vars=drop_vars)

nx, ny = dds[0].x.size, dds[0].y.size
nx_psf, ny_psf = dds[0].x_psf.size, dds[0].y_psf.size
Expand Down Expand Up @@ -281,7 +269,7 @@ def _klean(ddsi=None, **kw):
'center_y': dds[0].y0,
'flip_u': dds[0].flip_u,
'flip_v': dds[0].flip_v,
'flip_v': dds[0].flip_v,
'flip_w': dds[0].flip_w,
'ra': dds[0].ra,
'dec': dds[0].dec,
'stokes': opts.product, # I,Q,U,V, IQ/IV, IQUV
Expand Down
16 changes: 2 additions & 14 deletions pfb/workers/model2comps.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def model2comps(**kw):

print(f"All done after {time.time() - ti}s", file=log)

def _model2comps(ddsi=None, **kw):
def _model2comps(**kw):
opts = OmegaConf.create(kw)
OmegaConf.set_struct(opts, True)

Expand All @@ -92,19 +92,7 @@ def _model2comps(ddsi=None, **kw):
dds_name = f'{basename}_{opts.suffix}.dds'
dds_store = DaskMSStore(dds_name)
dds_list = dds_store.fs.glob(f'{dds_store.url}/*.zarr')
if ddsi is not None:
dds = []
for ds in ddsi:
dds.append(ds.chunk({'row':-1,
'chan':-1,
'x':-1,
'y':-1,
'x_psf':-1,
'y_psf':-1,
'yo2':-1}))
else:
# are these sorted correctly?
dds = xds_from_url(dds_store.url)
dds = xds_from_url(dds_store.url)

if opts.model_out is not None:
coeff_name = opts.model_out
Expand Down
24 changes: 6 additions & 18 deletions pfb/workers/sara.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def sara(**kw):
print(f"All done after {time.time() - ti}s", file=log)


def _sara(ddsi=None, **kw):
def _sara(**kw):
opts = OmegaConf.create(kw)
OmegaConf.set_struct(opts, True)

Expand Down Expand Up @@ -146,21 +146,9 @@ def _sara(ddsi=None, **kw):
dds_name = f'{basename}_{opts.suffix}.dds'
dds_store = DaskMSStore(dds_name)
dds_list = dds_store.fs.glob(f'{dds_store.url}/*.zarr')
if ddsi is not None:
dds = []
for ds in ddsi:
dds.append(ds.chunk({'row':-1,
'chan':-1,
'x':-1,
'y':-1,
'x_psf':-1,
'y_psf':-1,
'yo2':-1}))
else:
# are these sorted correctly?
dds = xds_from_list(dds_list,
drop_vars=['UVW', 'WEIGHT', 'MASK'],
nthreads=opts.nthreads)
dds = xds_from_list(dds_list,
drop_vars=['UVW', 'WEIGHT', 'MASK'],
nthreads=opts.nthreads)

nx, ny = dds[0].x.size, dds[0].y.size
nx_psf, ny_psf = dds[0].x_psf.size, dds[0].y_psf.size
Expand Down Expand Up @@ -374,14 +362,14 @@ def _sara(ddsi=None, **kw):
fits_oname + f'_{opts.suffix}_update_{k+1}.fits',
hdr_mfs)

print(f'Solving for model with lambda = {opts.rmsfactor*rms}', file=log)
modelp = deepcopy(model)
xtilde = model + opts.gamma * update
grad21 = lambda x: -precond(xtilde - x, 'forward')/opts.gamma
if iter0 == 0:
lam = opts.init_factor * opts.rmsfactor * rms
else:
lam = opts.rmsfactor*rms
print(f'Solving for model with lambda = {lam}', file=log)
model, dual = primal_dual(model,
dual,
lam,
Expand Down Expand Up @@ -436,7 +424,7 @@ def _sara(ddsi=None, **kw):
'center_y': dds[0].y0,
'flip_u': dds[0].flip_u,
'flip_v': dds[0].flip_v,
'flip_v': dds[0].flip_v,
'flip_w': dds[0].flip_w,
'ra': dds[0].ra,
'dec': dds[0].dec,
'stokes': opts.product, # I,Q,U,V, IQ/IV, IQUV
Expand Down
70 changes: 40 additions & 30 deletions tests/test_sara.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def test_sara(ms_name):
# TODO - currently we just check that this runs through.
# What should the passing criteria be?
'''
robustness = None
do_wgridding = True

# we need the client for the init step
from dask.distributed import LocalCluster, Client
Expand Down Expand Up @@ -120,7 +122,7 @@ def test_sara(ms_name):
pixsize_x=cell_rad,
pixsize_y=cell_rad,
epsilon=epsilon,
do_wgridding=True,
do_wgridding=do_wgridding,
divide_by_n=False,
flip_u=flip_u,
flip_v=flip_v,
Expand Down Expand Up @@ -167,8 +169,8 @@ def test_sara(ms_name):
grid_args["residual"] = False
grid_args["nthreads"] = 8
grid_args["overwrite"] = True
grid_args["robustness"] = 0.0
grid_args["do_wgridding"] = True
grid_args["robustness"] = robustness
grid_args["do_wgridding"] = do_wgridding
_grid(**grid_args)

dds_name = f'{outname}_main.dds'
Expand All @@ -183,61 +185,69 @@ def test_sara(ms_name):
sara_args["tol"] = tol
sara_args["gamma"] = 1.0
sara_args["pd_tol"] = [1e-3]
sara_args["rmsfactor"] = 0.1
sara_args["rmsfactor"] = 1.0
sara_args["epsfactor"] = 4.0
sara_args["l1reweight_from"] = 5
sara_args["bases"] = 'self,db1'
sara_args["nlevels"] = 3
sara_args["nthreads"] = 8
sara_args["do_wgridding"] = True
sara_args["do_wgridding"] = do_wgridding
sara_args["epsilon"] = epsilon
sara_args["fits_mfs"] = False
_sara(**sara_args)


# the computed by the grid worker should be idenitcal to that
# computed in sara when passing in model
# the residual computed by the grid worker should be identical
# to that computed in sara when transferring model
dds = xds_from_url(dds_name)
freqs_dds = []
times_dds = []
for ds in dds:
freqs_dds.append(ds.freq_out)
times_dds.append(ds.time_out)

freqs_dds = np.array(freqs_dds)
times_dds = np.array(times_dds)
freqs_dds = np.unique(freqs_dds)
times_dds = np.unique(times_dds)
ntime_dds = times_dds.size
nfreq_dds = freqs_dds.size
# grid data to produce dirty image
grid_args = {}
for key in schema.grid["inputs"].keys():
grid_args[key.replace("-", "_")] = schema.grid["inputs"][key]["default"]
# overwrite defaults
grid_args["output_filename"] = outname
grid_args["field_of_view"] = fov
grid_args["fits_mfs"] = False
grid_args["psf"] = False
grid_args["weight"] = False
grid_args["noise"] = False
grid_args["residual"] = True
grid_args["nthreads"] = 8
grid_args["overwrite"] = True
grid_args["robustness"] = robustness
grid_args["do_wgridding"] = do_wgridding
grid_args["transfer_model_from"] = f'{outname}_main_model.mds'
grid_args["suffix"] = 'subtract'
_grid(**grid_args)

dds2 = xds_from_url(f'{outname}_subtract.dds')

for ds, ds2 in zip(dds, dds2):
wsum = ds.WSUM.values
assert_allclose(1 + np.abs(ds.RESIDUAL.values)/wsum,
1 + np.abs(ds2.RESIDUAL.values)/wsum)

# degrid from coeffs populating MODEL_DATA
# residuals also need to be the same if we do the
# subtraction in visibility space
degrid_args = {}
for key in schema.degrid["inputs"].keys():
degrid_args[key.replace("-", "_")] = schema.degrid["inputs"][key]["default"]
degrid_args["ms"] = [str(test_dir / 'test_ascii_1h60.0s.MS')]
degrid_args["mds"] = f'{outname}_main_model.mds'
degrid_args["channels_per_image"] = 1
degrid_args["nthreads"] = 8
degrid_args["do_wgridding"] = True
degrid_args["do_wgridding"] = do_wgridding
_degrid(**degrid_args)

# manually place residual in CORRECTED_DATA
resid = xds.DATA.data - xds.MODEL_DATA.data
xds['CORRECTED_DATA'] = (('row','chan','coor'), resid)
writes = [xds_to_table(xds, ms_name, columns='CORRECTED_DATA')]
dask.compute(writes)

# gridding CORRECTED_DATA should return identical residuals
init_args = {}
for key in schema.init["inputs"].keys():
init_args[key.replace("-", "_")] = schema.init["inputs"][key]["default"]
# overwrite defaults
outname = str(test_dir / 'test2_I')
init_args["ms"] = [str(test_dir / 'test_ascii_1h60.0s.MS')]
init_args["output_filename"] = outname
init_args["data_column"] = "CORRECTED_DATA"
init_args["data_column"] = "DATA-MODEL_DATA"
# init_args["weight_column"] = 'WEIGHT_SPECTRUM'
init_args["flag_column"] = 'FLAG'
init_args["gain_table"] = None
Expand All @@ -259,8 +269,8 @@ def test_sara(ms_name):
grid_args["residual"] = False
grid_args["nthreads"] = 8
grid_args["overwrite"] = True
grid_args["robustness"] = 0.0
grid_args["do_wgridding"] = True
grid_args["robustness"] = robustness
grid_args["do_wgridding"] = do_wgridding
_grid(**grid_args)

dds_name = f'{outname}_main.dds'
Expand Down

0 comments on commit 15021cd

Please sign in to comment.