Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix div by zero in fitcleanbeam #96

Merged
merged 3 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pfb/parser/model2comps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ inputs:
default: 1e-10
info:
Multiple of the identity to add to the hessian for stability
model-out:
dtype: str
info:
Optional explicit output name.
Otherwise the default naming convention is used.

_include:
- (.)dist.yml
Expand Down
26 changes: 13 additions & 13 deletions pfb/utils/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,28 +94,28 @@ def compare_headers(hdr1, hdr2):
raise ValueError("Headers do not match on key %s. " % key, hdr1[key], hdr2[key])


def add_beampars(hdr, GaussPar, GaussPars=None):
def add_beampars(hdr, GaussPar, GaussPars=None, unit2deg=1.0):
"""
Add beam keywords to header.
GaussPar - MFS beam pars
GaussPars - beam pars for cube
"""
if len(GaussPar) == 3:
hdr['BMAJ'] = GaussPar[0]
hdr['BMIN'] = GaussPar[1]
hdr['BPA'] = GaussPar[2]
elif len(GaussPar) == 1:
hdr['BMAJ'] = GaussPar[0][0]
hdr['BMIN'] = GaussPar[0][1]
hdr['BPA'] = GaussPar[0][2]
else:
if len(GaussPar) == 1:
GaussPar = GaussPar[0]
elif len(GaussPar) != 3:
raise ValueError('Invalid value for GaussPar')

if not np.isnan(GaussPar).any():
hdr['BMAJ'] = GaussPar[0]*unit2deg
hdr['BMIN'] = GaussPar[1]*unit2deg
hdr['BPA'] = GaussPar[2]*unit2deg

if GaussPars is not None:
for i in range(len(GaussPars)):
hdr['BMAJ' + str(i+1)] = GaussPars[i][0]
hdr['BMIN' + str(i+1)] = GaussPars[i][1]
hdr['PA' + str(i+1)] = GaussPars[i][2]
if not np.isnan(GaussPars[i]).any():
hdr['BMAJ' + str(i+1)] = GaussPars[i][0]*unit2deg
hdr['BMIN' + str(i+1)] = GaussPars[i][1]*unit2deg
hdr['PA' + str(i+1)] = GaussPars[i][2]*unit2deg

return hdr

Expand Down
3 changes: 3 additions & 0 deletions pfb/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,9 @@ def func(xy, emaj, emin, pa):
Gausspars = []
for v in range(nband):
# make sure psf is normalised
if not psf[v].any():
Gausspars.append([np.nan, np.nan, np.nan])
continue
psfv = psf[v] / psf[v].max()
# find regions where psf is above level
mask = np.where(psfv > level, 1.0, 0)
Expand Down
27 changes: 22 additions & 5 deletions pfb/workers/model2comps.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,19 @@ def _model2comps(**kw):

basename = f'{opts.output_filename}_{opts.product.upper()}'
dds_name = f'{basename}_{opts.postfix}.dds'
coeff_name = f'{basename}_{opts.postfix}_{opts.model_name.lower()}.mds'
if opts.model_out is not None:
coeff_name = opts.model_out
else:
coeff_name = f'{basename}_{opts.postfix}_{opts.model_name.lower()}.mds'

if os.path.isdir(coeff_name):
if opts.overwrite:
print(f'Removing {coeff_name}', file=log)
import shutil
shutil.rmtree(coeff_name)
else:
raise RuntimeError(f"{coeff_name} exists. "
f"Set --overwrite if you meant to overwrite it.")

dds = xds_from_zarr(dds_name,
chunks={'x':-1,
Expand Down Expand Up @@ -121,15 +127,26 @@ def _model2comps(**kw):
wsums[...] = 1.0

if opts.min_val is not None:
model = np.where(model >= opts.min_val, model, 0.0)
mmfs = model[wsums>0]
if mmfs.ndim==3:
mmfs = np.mean(mmfs, axis=0)
elif mmfs.ndim==4:
mmfs = np.mean(mmfs, axis=(0,1))
model = np.where(np.abs(mmfs)[None, None] >= opts.min_val, model, 0.0)

if not np.any(model):
raise ValueError(f'Model is empty or has no components above {opts.min_val}')
radec = (dds[0].ra, dds[0].dec)

coeffs, Ix, Iy, expr, params, texpr, fexpr = \
fit_image_cube(mtimes, mfreqs, model, wgt=wsums,
nbasisf=opts.nbasisf, method=opts.fit_mode)
try:
coeffs, Ix, Iy, expr, params, texpr, fexpr = \
fit_image_cube(mtimes, mfreqs, model, wgt=wsums,
nbasisf=opts.nbasisf, method=opts.fit_mode)
except np.linalg.LinAlgError as e:
print(f"Exception {e} raised during fit ."
f"Do you perhaps have empty sub-bands?"
f"Try decreasing nbasisf", file=log)
quit()

# save interpolated dataset
data_vars = {
Expand Down
87 changes: 48 additions & 39 deletions pfb/workers/restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,15 @@ def _restore(**kw):
hdr_mfs = set_wcs(cell_deg, cell_deg, nx, ny, radec, ref_freq)
hdr = set_wcs(cell_deg, cell_deg, nx, ny, radec, freq)


# stack cubes
dirty, model, residual, psf, _, _, wsums, _ = dds2cubes(dds,
nband,
apparent=True)
wsum = np.sum(wsums)
output_type = dirty.dtype
fmask = wsums > 0
if fmask.all():
raise ValueError("All data seem to be flagged")

if residual is None:
print('Warning, no residual in dds. '
Expand All @@ -92,54 +94,41 @@ def _restore(**kw):

if not model.any():
print("Warning - model is empty", file=log)
model_mfs = np.mean(model, axis=0)
model_mfs = np.mean(model[fmask], axis=0)

# lm in pixel coordinates
lpsf = -(nx//2) + np.arange(nx)
mpsf = -(ny//2) + np.arange(ny)
xx, yy = np.meshgrid(lpsf, mpsf, indexing='ij')

if psf is not None:
nx_psf = dds[0].x_psf.size
ny_psf = dds[0].y_psf.size
psf_mfs = np.sum(psf, axis=0)
psf[fmask] /= wsums[fmask, None, None]/wsum
# sanity check
assert (psf_mfs.max() - 1.0) < 2e-7
assert ((np.amax(psf, axis=(1,2)) - 1.0) < 2e-7).all()

# fit restoring psf
try:
psf_mismatch_mfs = np.abs(psf_mfs.max() - 1.0)
psf_mismatch = np.abs(np.amax(psf, axis=(1,2))[fmask] - 1.0).max()
assert psf_mismatch_mfs < 1e-5
assert psf_mismatch < 1e-5
except Exception as e:
max_mismatch = np.maximum(psf_mismatch_mfs, psf_mismatch)
print(f"Warning - PSF does not normlaise to one. "
f"Max mismatch is {max_mismatch:.3e}", file=log)

# fit restoring psf (pixel units)
GaussPar = fitcleanbeam(psf_mfs[None], level=0.5, pixsize=1.0)
cpsf_mfs = np.zeros(residual_mfs.shape, dtype=output_type)
lpsf = -(nx//2) + np.arange(nx)
mpsf = -(ny//2) + np.arange(ny)
xx, yy = np.meshgrid(lpsf, mpsf, indexing='ij')
cpsf_mfs = Gaussian2D(xx, yy, GaussPar[0], normalise=False)
image_mfs = convolve2gaussres(model_mfs[None], xx, yy,
GaussPar[0], opts.nvthreads,
norm_kernel=False)[0] # peak of kernel set to unity
image_mfs += residual_mfs
# convert pixel units to deg
GaussPar[0][0] *= cell_deg
GaussPar[0][1] *= cell_deg
hdr_mfs = add_beampars(hdr_mfs, GaussPar)

if any([i.isupper() for i in opts.outputs]):
GaussPars = fitcleanbeam(psf, level=0.5, pixsize=1.0) # pixel units
cpsf = np.zeros(residual.shape, dtype=output_type)
for v in range(opts.nband):
cpsf[v] = Gaussian2D(xx, yy, GaussPars[v], normalise=False)

image = np.zeros_like(model)
for b in range(nband):
image[b:b+1] = convolve2gaussres(model[b:b+1], xx, yy,
GaussPars[b], opts.nvthreads,
norm_kernel=False) # peak of kernel set to unity
image[b] += residual[b]

for i, gp in enumerate(GaussPars):
GaussPars[i] = [gp[0]*cell_deg, gp[1]*cell_deg, gp[2]]

hdr = add_beampars(hdr, GaussPar, GaussPars)
hdr_mfs = add_beampars(hdr_mfs, GaussPar, unit2deg=cell_deg)
GaussPars = fitcleanbeam(psf, level=0.5, pixsize=1.0)
hdr = add_beampars(hdr, GaussPar, GaussPars, unit2deg=cell_deg)

else:
print('Warning, no psf in dds. '
'Unable to add resolution info or make restored image. ',
file=log)
GaussPar = None
GaussPars = None

if 'm' in opts.outputs:
save_fits(model_mfs,
Expand Down Expand Up @@ -180,24 +169,44 @@ def _restore(**kw):
overwrite=opts.overwrite)

if 'i' in opts.outputs and psf is not None:
image_mfs = convolve2gaussres(model_mfs[None], xx, yy,
GaussPar[0], opts.nvthreads,
norm_kernel=False)[0] # peak of kernel set to unity
image_mfs += residual_mfs
save_fits(image_mfs,
f'{basename}_{opts.postfix}.image_mfs.fits',
hdr_mfs,
overwrite=opts.overwrite)

if 'I' in opts.outputs and psf is not None:
image = np.zeros_like(model)
for b in range(nband):
image[b:b+1] = convolve2gaussres(model[b:b+1], xx, yy,
GaussPars[b], opts.nvthreads,
norm_kernel=False) # peak of kernel set to unity
image[b] += residual[b]
save_fits(image,
f'{basename}_{opts.postfix}.image.fits',
hdr,
overwrite=opts.overwrite)

if 'c' in opts.outputs and psf is not None:
if 'c' in opts.outputs:
if GaussPar is None:
raise ValueError("Clean beam in output but no PSF in dds")
cpsf_mfs = Gaussian2D(xx, yy, GaussPar[0], normalise=False)
save_fits(cpsf_mfs,
f'{basename}_{opts.postfix}.cpsf_mfs.fits',
hdr_mfs,
overwrite=opts.overwrite)

if 'C' in opts.outputs and psf is not None:
if 'C' in opts.outputs:
if GaussPars is None:
raise ValueError("Clean beam in output but no PSF in dds")
cpsf = np.zeros(residual.shape, dtype=output_type)
for v in range(opts.nband):
gpar = GaussPars[v]
if not np.isnan(gpar).any():
cpsf[v] = Gaussian2D(xx, yy, gpar, normalise=False)
save_fits(cpsf,
f'{basename}_{opts.postfix}.cpsf.fits',
hdr,
Expand Down
Loading