diff --git a/pfb/deconv/clark.py b/pfb/deconv/clark.py index 3f5af657..2d33d72b 100644 --- a/pfb/deconv/clark.py +++ b/pfb/deconv/clark.py @@ -116,9 +116,9 @@ def clark(ID, model = np.zeros((nband, nx, ny), dtype=ID.dtype) IR = ID.copy() # pre-allocate arrays for doing FFT's - xout = empty_noncritical(ID.shape, dtype=ID.dtype) - xpad = empty_noncritical(PSF.shape, dtype=ID.dtype) - xhat = empty_noncritical(PSFHAT.shape, dtype=PSFHAT.dtype) + xout = empty_noncritical(ID.shape, dtype='f8') + xpad = empty_noncritical(PSF.shape, dtype='f8') + xhat = empty_noncritical(PSFHAT.shape, dtype='c16') # square avoids abs of full array IRsearch = np.sum(IR, axis=0)**2 pq = IRsearch.argmax() diff --git a/pfb/operators/gridder.py b/pfb/operators/gridder.py index 00db1290..37f690f6 100644 --- a/pfb/operators/gridder.py +++ b/pfb/operators/gridder.py @@ -423,6 +423,7 @@ def image_data_products(dsl, if ovar: # scale the natural weights # RHS is weight relative to unity since wgtp included in ressq + # tmp = (l2_reweight_dof + 1)/(l2_reweight_dof + ressq/ovar) wgt *= (l2_reweight_dof + 1)/(l2_reweight_dof + ressq/ovar) else: wgt = None @@ -459,10 +460,29 @@ def image_data_products(dsl, ssq = ressq[mask>0].sum() ovar = ssq/mask.sum() wgt /= ovar - ressq = (residual_vis*wgt*residual_vis.conj()).real + # ressq = (residual_vis*wgt*residual_vis.conj()).real # chi2_dof = np.mean(ressq[mask>0]) # print(f'Band {bandid} chi2-dof changed from {chi2_dofp} to {chi2_dof}') + import matplotlib.pyplot as plt + from scipy.stats import norm + x = np.linspace(-5, 5, 150) + y = norm.pdf(x, 0, 1) + fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(8, 12)) + ax[0,0].hist((residual_vis.real*wgtp).ravel(), bins=15, density=True) + ax[0,0].plot(x, y, 'k') + ax[0,1].hist((residual_vis.real*wgt).ravel(), bins=15, density=True) + ax[0,1].plot(x, y, 'k') + ax[1,0].hist((residual_vis.imag*wgtp).ravel(), bins=15, density=True) + ax[1,0].plot(x, y, 'k') + ax[1,1].hist((residual_vis.imag*wgt).ravel(), bins=15, density=True) + ax[1,1].plot(x, y, 'k') + import os + cwd = os.getcwd() + bid = dso.attrs['bandid'] + fig.savefig(f'{cwd}/resid_hist_{bid}.png') + # import ipdb; ipdb.set_trace() + # these are always used together if do_weight: dso['WEIGHT'] = (('row','chan'), wgt) diff --git a/pfb/operators/hessian.py b/pfb/operators/hessian.py index ac1ca43f..f20a7c00 100644 --- a/pfb/operators/hessian.py +++ b/pfb/operators/hessian.py @@ -360,7 +360,7 @@ def idot(self, x, mode='psf', x0=None): x0=x0[b], tol=self.cgtol, maxit=self.cgmaxit, - minit=1, + minit=3, verbosity=self.cgverbose, report_freq=self.cgrf, backtrack=False, diff --git a/pfb/opt/pcg.py b/pfb/opt/pcg.py index 0494032b..e5e11ad3 100644 --- a/pfb/opt/pcg.py +++ b/pfb/opt/pcg.py @@ -1,4 +1,6 @@ +from time import time import numpy as np +import numexpr as ne from functools import partial import dask.array as da from distributed import wait @@ -46,14 +48,35 @@ def M(x): return x x = x0 eps = 1.0 stall_count = 0 + xp = x.copy() + rp = r.copy() + tcopy = 0.0 + tA = 0.0 + tvdot = 0.0 + tupdate = 0.0 + tp = 0.0 + tnorm = 0.0 + tii = time() while (eps > tol or k < minit) and k < maxit and stall_count < 5: - xp = x.copy() - rp = r.copy() + ti = time() + np.copyto(xp, x) + np.copyto(rp, r) + tcopy += (time() - ti) + ti = time() Ap = A(p) + tA += (time() - ti) + ti = time() rnorm = np.vdot(r, y) alpha = rnorm / np.vdot(p, Ap) - x = xp + alpha * p - r = rp + alpha * Ap + tvdot += (time() - ti) + ti = time() + ne.evaluate('xp + alpha*p', + out=x) + ne.evaluate('rp + alpha*Ap', + out=r) + # x = xp + alpha * p + # r = rp + alpha * Ap + tupdate += (time() - ti) y = M(r) rnorm_next = np.vdot(r, y) while rnorm_next > rnorm and backtrack: # TODO - better line search @@ -63,16 +86,22 @@ def M(x): return x y = M(r) rnorm_next = np.vdot(r, y) + ti = time() beta = rnorm_next / rnorm - p = beta * p - y + ne.evaluate('beta*p-y', + out=p) + # p = beta * p - y + tp += (time() - ti) # if p is zero we should stop if not np.any(p): break rnorm = rnorm_next k += 1 epsp = eps + ti = time() eps = norm_diff(x, xp) phi = rnorm / phi0 + tnorm += (time() - ti) if np.abs(epsp - eps) < 1e-3*tol: stall_count += 1 @@ -80,6 +109,21 @@ def M(x): return x if not k % report_freq and verbosity > 1: print(f"At iteration {k} eps = {eps:.3e}, phi = {phi:.3e}") # file=log) + ttot = time() - tii + tcopy /= ttot + tA /= ttot + tvdot /= ttot + tupdate /= ttot + tp /= ttot + tnorm /= ttot + ttally = tcopy + tA + tvdot + tupdate + tp + tnorm + print('tcopy = ', tcopy) + print('tA = ', tA) + print('tvdot = ', tvdot) + print('tupdate = ', tupdate) + print('tp = ', tp) + print('tnorm = ', tnorm) + print('ttally = ', ttally) if k >= maxit: if verbosity: @@ -124,16 +168,16 @@ def M(x): return x / eta M = None for k in range(nband): - xpad = empty_noncritical((nx_psf, lastsize), dtype=b.dtype, order='C') - xhat = empty_noncritical((nx_psf, nyo2), dtype=psfhat.dtype, order='C') - xout = empty_noncritical((nx, ny), dtype=b.dtype, order='C') + xpad = empty_noncritical((nx_psf, lastsize), dtype=b.dtype) + xhat = empty_noncritical((nx_psf, nyo2), dtype=psfhat.dtype) + xout = empty_noncritical((nx, ny), dtype=b.dtype) A = partial(_hessian_psf_slice, - xpad, - xhat, - xout, - psfhat[k], - beam[k], - lastsize, + xpad=xpad, + xhat=xhat, + xout=xout, + abspsf=np.abs(psfhat[k]), + beam=beam[k], + lastsize=lastsize, nthreads=nthreads, eta=eta) diff --git a/pfb/workers/klean.py b/pfb/workers/klean.py index b7497f5f..b290f7bb 100644 --- a/pfb/workers/klean.py +++ b/pfb/workers/klean.py @@ -63,32 +63,23 @@ def klean(**kw): dds = xds_from_url(dds_store.url) - from pfb.utils.fits import dds2fits, dds2fits_mfs + from pfb.utils.fits import dds2fits if opts.fits_mfs or opts.fits: print(f"Writing fits files to {fits_oname}_{opts.suffix}", file=log) - - # convert to fits files - if opts.fits_mfs: - dds2fits_mfs(dds, - 'RESIDUAL', - f'{fits_oname}_{opts.suffix}', - norm_wsum=True) - dds2fits_mfs(dds, - 'MODEL', - f'{fits_oname}_{opts.suffix}', - norm_wsum=False) - - if opts.fits_cubes: dds2fits(dds, 'RESIDUAL', f'{fits_oname}_{opts.suffix}', - norm_wsum=True) + norm_wsum=True, + do_mfs=opts.fits_mfs, + do_cube=opts.fits_cubes) dds2fits(dds, 'MODEL', f'{fits_oname}_{opts.suffix}', - norm_wsum=False) + norm_wsum=False, + do_mfs=opts.fits_mfs, + do_cube=opts.fits_cubes) print(f"All done after {time.time() - ti}s", file=log) diff --git a/pfb/workers/sara.py b/pfb/workers/sara.py index 71bb064f..3c0bfed1 100644 --- a/pfb/workers/sara.py +++ b/pfb/workers/sara.py @@ -325,6 +325,7 @@ def _sara(**kw): best_rmax = rmax best_model = model.copy() diverge_count = 0 + eps = 1.0 print(f"Iter {iter0}: peak residual = {rmax:.3e}, rms = {rms:.3e}", file=log) if opts.skip_model: @@ -336,7 +337,8 @@ def _sara(**kw): residual *= beam # avoid copy update = precond.idot(residual, mode=opts.hess_approx, - x0=update) + # only warm start close to convergence + x0=update if eps < 0.1 else None) update_mfs = np.mean(update, axis=0) save_fits(update_mfs, fits_oname + f'_{opts.suffix}_update_{k+1}.fits',