Skip to content

Commit

Permalink
Don't place nan in fits header when restoring
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Mar 25, 2024
1 parent d851b6f commit b97b229
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 17 deletions.
5 changes: 5 additions & 0 deletions pfb/parser/model2comps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ inputs:
default: 1e-10
info:
Multiple of the identity to add to the hessian for stability
model-out:
dtype: str
info:
Optional explicit output name.
Otherwise the default naming convention is used.

_include:
- (.)dist.yml
Expand Down
20 changes: 10 additions & 10 deletions pfb/utils/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,22 +100,22 @@ def add_beampars(hdr, GaussPar, GaussPars=None, unit2deg=1.0):
GaussPar - MFS beam pars
GaussPars - beam pars for cube
"""
if len(GaussPar) == 3:
if len(GaussPar) == 1:
GaussPar = GaussPar[0]
elif len(GaussPar) != 3:
raise ValueError('Invalid value for GaussPar')

if not np.isnan(GaussPar).any():
hdr['BMAJ'] = GaussPar[0]*unit2deg
hdr['BMIN'] = GaussPar[1]*unit2deg
hdr['BPA'] = GaussPar[2]*unit2deg
elif len(GaussPar) == 1:
hdr['BMAJ'] = GaussPar[0][0]*unit2deg
hdr['BMIN'] = GaussPar[0][1]*unit2deg
hdr['BPA'] = GaussPar[0][2]*unit2deg
else:
raise ValueError('Invalid value for GaussPar')

if GaussPars is not None:
for i in range(len(GaussPars)):
hdr['BMAJ' + str(i+1)] = GaussPars[i][0]*unit2deg
hdr['BMIN' + str(i+1)] = GaussPars[i][1]*unit2deg
hdr['PA' + str(i+1)] = GaussPars[i][2]*unit2deg
if not np.isnan(GaussPars[i]).any():
hdr['BMAJ' + str(i+1)] = GaussPars[i][0]*unit2deg
hdr['BMIN' + str(i+1)] = GaussPars[i][1]*unit2deg
hdr['PA' + str(i+1)] = GaussPars[i][2]*unit2deg

return hdr

Expand Down
2 changes: 1 addition & 1 deletion pfb/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,10 +536,10 @@ def func(xy, emaj, emin, pa):
Gausspars = []
for v in range(nband):
# make sure psf is normalised
psfv = psf[v] / psf[v].max()
if not psf[v].any():
Gausspars.append([np.nan, np.nan, np.nan])
continue
psfv = psf[v] / psf[v].max()
# find regions where psf is above level
mask = np.where(psfv > level, 1.0, 0)

Expand Down
27 changes: 22 additions & 5 deletions pfb/workers/model2comps.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,19 @@ def _model2comps(**kw):

basename = f'{opts.output_filename}_{opts.product.upper()}'
dds_name = f'{basename}_{opts.postfix}.dds'
coeff_name = f'{basename}_{opts.postfix}_{opts.model_name.lower()}.mds'
if opts.model_out is not None:
coeff_name = opts.model_out
else:
coeff_name = f'{basename}_{opts.postfix}_{opts.model_name.lower()}.mds'

if os.path.isdir(coeff_name):
if opts.overwrite:
print(f'Removing {coeff_name}', file=log)
import shutil
shutil.rmtree(coeff_name)
else:
raise RuntimeError(f"{coeff_name} exists. "
f"Set --overwrite if you meant to overwrite it.")

dds = xds_from_zarr(dds_name,
chunks={'x':-1,
Expand Down Expand Up @@ -121,15 +127,26 @@ def _model2comps(**kw):
wsums[...] = 1.0

if opts.min_val is not None:
model = np.where(np.abs(model) >= opts.min_val, model, 0.0)
mmfs = model[wsums>0]
if mmfs.ndim==3:
mmfs = np.mean(mmfs, axis=0)
elif mmfs.ndim==4:
mmfs = np.mean(mmfs, axis=(0,1))
model = np.where(np.abs(mmfs)[None, None] >= opts.min_val, model, 0.0)

if not np.any(model):
raise ValueError(f'Model is empty or has no components above {opts.min_val}')
radec = (dds[0].ra, dds[0].dec)

coeffs, Ix, Iy, expr, params, texpr, fexpr = \
fit_image_cube(mtimes, mfreqs, model, wgt=wsums,
nbasisf=opts.nbasisf, method=opts.fit_mode)
try:
coeffs, Ix, Iy, expr, params, texpr, fexpr = \
fit_image_cube(mtimes, mfreqs, model, wgt=wsums,
nbasisf=opts.nbasisf, method=opts.fit_mode)
except np.linalg.LinAlgError as e:
print(f"Exception {e} raised during fit ."
f"Do you perhaps have empty sub-bands?"
f"Try decreasing nbasisf", file=log)
quit()

# save interpolated dataset
data_vars = {
Expand Down
2 changes: 1 addition & 1 deletion pfb/workers/restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _restore(**kw):
# sanity check
try:
psf_mismatch_mfs = np.abs(psf_mfs.max() - 1.0)
psf_mismatch = np.abs(np.amax(psf, axis=(1,2)) - 1.0).max()
psf_mismatch = np.abs(np.amax(psf, axis=(1,2))[fmask] - 1.0).max()
assert psf_mismatch_mfs < 1e-5
assert psf_mismatch < 1e-5
except Exception as e:
Expand Down

0 comments on commit b97b229

Please sign in to comment.