Skip to content

Commit

Permalink
flip_v everywhere and correct fits crpix2 convention
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Aug 26, 2024
1 parent 2108b7a commit 415cdb7
Show file tree
Hide file tree
Showing 18 changed files with 876 additions and 897 deletions.
225 changes: 86 additions & 139 deletions pfb/operators/gridder.py

Large diffs are not rendered by default.

176 changes: 78 additions & 98 deletions pfb/operators/hessian.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,14 @@
import numpy as np
import dask
import dask.array as da
from ducc0.wgridder import vis2dirty, dirty2vis
from ducc0.wgridder.experimental import vis2dirty, dirty2vis
from ducc0.fft import r2c, c2r
from ducc0.misc import make_noncritical
from uuid import uuid4
from pfb.operators.psf import (psf_convolve_slice,
psf_convolve_cube)


def hessian_xds(x, xds, hessopts, wsum, sigmainv, mask,
compute=True, use_beam=True):
'''
Vis space Hessian reduction over dataset.
Hessian will be applied to x
'''
if not isinstance(x, da.Array):
x = da.from_array(x, chunks=(1, -1, -1),
name="x-" + uuid4().hex)

if not isinstance(mask, da.Array):
mask = da.from_array(mask, chunks=(-1, -1),
name="mask-" + uuid4().hex)

assert mask.ndim == 2

nband, nx, ny = x.shape

# LB - what is the point of specifying name?
convims = [da.zeros((nx, ny),
chunks=(-1, -1), name="zeros-" + uuid4().hex)
for _ in range(nband)]

for ds in xds:
wgt = ds.WEIGHT.data
vis_mask = ds.MASK.data
uvw = ds.UVW.data
freq = ds.FREQ.data
b = ds.bandid
if use_beam:
beam = ds.BEAM.data * mask
else:
# TODO - separate implementation without
# unnecessary beam application
beam = mask

convim = hessian(x[b], uvw, wgt, vis_mask, freq, beam, hessopts)

convims[b] += convim

convim = da.stack(convims)/wsum

if sigmainv:
convim += x * sigmainv**2

if compute:
return convim.compute()
else:
return convim


def _hessian_impl(x,
uvw=None,
weight=None,
Expand All @@ -68,13 +17,27 @@ def _hessian_impl(x,
beam=None,
x0=0.0,
y0=0.0,
flip_u=False,
flip_v=True,
flip_w=False,
cell=None,
do_wgridding=None,
epsilon=None,
double_accum=None,
nthreads=None,
sigmainvsq=None,
wsum=1.0):
'''
Apply vis space Hessian approximation on a slice of an image.
Important!
x0, y0, flip_u, flip_v and flip_w must be consistent with the
conventions defined in pfb.operators.gridder.wgridder_conventions
These are inputs here to allow for testing but should generally be taken
from the attrs of the datasets produced by
pfb.operators.gridder.image_data_products
'''
if not x.any():
return np.zeros_like(x)
nx, ny = x.shape
Expand All @@ -86,27 +49,34 @@ def _hessian_impl(x,
pixsize_y=cell,
center_x=x0,
center_y=y0,
flip_u=flip_u,
flip_v=flip_v,
flip_w=flip_w,
epsilon=epsilon,
nthreads=nthreads,
do_wgridding=do_wgridding,
divide_by_n=False)

convim = vis2dirty(uvw=uvw,
freq=freq,
vis=mvis,
wgt=weight,
mask=vis_mask,
npix_x=nx,
npix_y=ny,
pixsize_x=cell,
pixsize_y=cell,
center_x=x0,
center_y=y0,
epsilon=epsilon,
nthreads=nthreads,
do_wgridding=do_wgridding,
double_precision_accumulation=double_accum,
divide_by_n=False)
convim = vis2dirty(
uvw=uvw,
freq=freq,
vis=mvis,
wgt=weight,
mask=vis_mask,
npix_x=nx,
npix_y=ny,
pixsize_x=cell,
pixsize_y=cell,
center_x=x0,
center_y=y0,
flip_u=flip_u,
flip_v=flip_v,
flip_w=flip_w,
epsilon=epsilon,
nthreads=nthreads,
do_wgridding=do_wgridding,
double_precision_accumulation=double_accum,
divide_by_n=False)
convim /= wsum

if beam is not None:
Expand All @@ -118,24 +88,25 @@ def _hessian_impl(x,
return convim


def _hessian(x, uvw, weight, vis_mask, freq, beam, hessopts):
return _hessian_impl(x, uvw[0][0], weight[0][0], vis_mask[0][0], freq[0],
beam, **hessopts)
# Kept in case we need them in the future
# def _hessian(x, uvw, weight, vis_mask, freq, beam, hessopts):
# return _hessian_impl(x, uvw[0][0], weight[0][0], vis_mask[0][0], freq[0],
# beam, **hessopts)

def hessian(x, uvw, weight, vis_mask, freq, beam, hessopts):
if beam is None:
bout = None
else:
bout = ('nx', 'ny')
return da.blockwise(_hessian, ('nx', 'ny'),
x, ('nx', 'ny'),
uvw, ('row', 'three'),
weight, ('row', 'chan'),
vis_mask, ('row', 'chan'),
freq, ('chan',),
beam, bout,
hessopts, None,
dtype=x.dtype)
# def hessian(x, uvw, weight, vis_mask, freq, beam, hessopts):
# if beam is None:
# bout = None
# else:
# bout = ('nx', 'ny')
# return da.blockwise(_hessian, ('nx', 'ny'),
# x, ('nx', 'ny'),
# uvw, ('row', 'three'),
# weight, ('row', 'chan'),
# vis_mask, ('row', 'chan'),
# freq, ('chan',),
# beam, bout,
# hessopts, None,
# dtype=x.dtype)


def _hessian_psf_slice(
Expand Down Expand Up @@ -167,10 +138,10 @@ def _hessian_psf_slice(
if wsum is not None:
xout /= wsum

# if sigmainv:
# xout += x * sigmainv
if sigmainv:
xout += x * sigmainv

return xout + x * sigmainv
return xout


def hessian_psf_cube(
Expand Down Expand Up @@ -202,7 +173,10 @@ def hessian_psf_cube(
if wsum is not None:
xout /= wsum

return xout + x * sigmainv
if sigmainv:
xout += x * sigmainv

return xout
else:
raise NotImplementedError

Expand All @@ -220,7 +194,10 @@ def hess_direct(x, # input image, not overwritten
mode='forward'):
nband, nx, ny = x.shape
xpad[...] = 0.0
xpad[:, 0:nx, 0:ny] = x * taperxy[None]
if mode == 'forward':
xpad[:, 0:nx, 0:ny] = x / taperxy[None]
else:
xpad[:, 0:nx, 0:ny] = x * taperxy[None]
r2c(xpad, out=xhat, axes=(1,2),
forward=True, inorm=0, nthreads=nthreads)
if mode=='forward':
Expand All @@ -231,7 +208,11 @@ def hess_direct(x, # input image, not overwritten
lastsize=lastsize, inorm=2, nthreads=nthreads,
allow_overwriting_input=True)
xout[...] = xpad[:, 0:nx, 0:ny]
return xout * taperxy[None]
if mode=='forward':
xout /= taperxy[None]
else:
xout *= taperxy[None]
return xout


def hess_direct_slice(x, # input image, not overwritten
Expand All @@ -254,16 +235,15 @@ def hess_direct_slice(x, # input image, not overwritten
r2c(xpad, out=xhat, axes=(0,1),
forward=True, inorm=0, nthreads=nthreads)
if mode=='forward':
# xhat *= (psfhat + sigmainvsq)
xhat *= psfhat
xhat *= (psfhat + sigmainvsq)
else:
# xhat /= (psfhat + sigmainvsq)
xhat /= psfhat
xhat /= (psfhat + sigmainvsq)
c2r(xhat, axes=(0, 1), forward=False, out=xpad,
lastsize=lastsize, inorm=2, nthreads=nthreads,
allow_overwriting_input=True)
xout[...] = xpad[0:nx, 0:ny]
if mode=='foward':
return xout / taperxy
if mode=='forward':
xout /= taperxy
else:
return xout * taperxy
xout *= taperxy
return xout
4 changes: 2 additions & 2 deletions pfb/opt/pcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def pcg_dds(ds_name,

# set precond if PSF is present
if 'PSFHAT' in ds and use_psf:
psfhat = np.abs(ds.PSFHAT.values)/wsum + sigma
psfhat = np.abs(ds.PSFHAT.values)/wsum
ds.drop_vars(('PSFHAT'))
nx_psf, nyo2 = psfhat.shape
ny_psf = 2*(nyo2-1) # is this always the case?
Expand Down Expand Up @@ -323,7 +323,7 @@ def pcg_dds(ds_name,
taperxy=taperxy,
lastsize=ny_psf,
nthreads=nthreads,
sigmainvsq=1.0, # not used
sigmainvsq=sigma,
mode='backward')

x0 = precond(j)
Expand Down
8 changes: 6 additions & 2 deletions pfb/utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,12 @@ def __init__(self, xds_list, opts, bandid, cache_path, max_freq, uv_max):
self.uv_max = uv_max
nx, ny, nx_psf, ny_psf, cell_N, cell_rad = set_image_size(uv_max,
max_freq,
opts)
opts.field_of_view,
opts.super_resolution_factor,
opts.cell_size,
opts.nx,
opts.ny,
opts.psf_oversize)
cell_deg = np.rad2deg(cell_rad)
cell_size = cell_deg * 3600
# print(f"Super resolution factor = {cell_N/cell_rad}", file=log)
Expand Down Expand Up @@ -354,7 +359,6 @@ def set_residual(self, k, x=None):
self.cell_rad, self.cell_rad,
self.cache_path, # output_name (same as dsl names?)
x,
x0=self.x0, y0=self.y0,
nthreads=self.nthreads,
epsilon=self.opts.epsilon,
do_wgridding=self.opts.do_wgridding,
Expand Down
22 changes: 16 additions & 6 deletions pfb/utils/fits.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
from astropy.io import fits
from astropy.wcs import WCS
from pfb.utils.misc import to4d
import dask.array as da
from dask import delayed
from datetime import datetime
Expand All @@ -10,6 +9,19 @@
from pfb.utils.naming import xds_from_list


def to4d(data):
if data.ndim == 4:
return data
elif data.ndim == 2:
return data[None, None]
elif data.ndim == 3:
return data[None]
elif data.ndim == 1:
return data[None, None, None]
else:
raise ValueError("Only arrays with ndim <= 4 can be broadcast to 4D.")


def data_from_header(hdr, axis=3):
npix = hdr['NAXIS' + str(axis)]
refpix = hdr['CRPIX' + str(axis)]
Expand All @@ -20,13 +32,13 @@ def data_from_header(hdr, axis=3):

def load_fits(name, dtype=np.float32):
data = fits.getdata(name)
data = np.transpose(to4d(data)[:, :, ::-1], axes=(0, 1, 3, 2))
data = np.transpose(to4d(data), axes=(0, 1, 3, 2))
return np.require(data, dtype=dtype, requirements='C')


def save_fits(data, name, hdr, overwrite=True, dtype=np.float32):
hdu = fits.PrimaryHDU(header=hdr)
data = np.transpose(to4d(data), axes=(0, 1, 3, 2))[:, :, ::-1]
data = np.transpose(to4d(data), axes=(0, 1, 3, 2))
hdu.data = np.require(data, dtype=dtype, requirements='F')
hdu.writeto(name, overwrite=overwrite)
return
Expand Down Expand Up @@ -62,9 +74,7 @@ def set_wcs(cell_x, cell_y, nx, ny, radec, freq,
ref_freq = freq
crpix3 = 1
w.wcs.crval = [radec[0]*180.0/np.pi, radec[1]*180.0/np.pi, ref_freq, 1]
# y axis treated differently because of wgridder convention?
# https://github.com/mreineck/ducc/issues/34
w.wcs.crpix = [1 + nx//2, ny//2, crpix3, 1]
w.wcs.crpix = [1 + nx//2, 1 + ny//2, crpix3, 1]

header = w.to_header()
header['RESTFRQ'] = ref_freq
Expand Down
Loading

0 comments on commit 415cdb7

Please sign in to comment.