Skip to content

Commit

Permalink
fix failing tests. Report timings in pcg. Plot residual histograms af…
Browse files Browse the repository at this point in the history
…ter reweighting
  • Loading branch information
landmanbester committed Sep 9, 2024
1 parent d2fe71a commit 0973d5c
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 36 deletions.
6 changes: 3 additions & 3 deletions pfb/deconv/clark.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def clark(ID,
model = np.zeros((nband, nx, ny), dtype=ID.dtype)
IR = ID.copy()
# pre-allocate arrays for doing FFT's
xout = empty_noncritical(ID.shape, dtype=ID.dtype)
xpad = empty_noncritical(PSF.shape, dtype=ID.dtype)
xhat = empty_noncritical(PSFHAT.shape, dtype=PSFHAT.dtype)
xout = empty_noncritical(ID.shape, dtype='f8')
xpad = empty_noncritical(PSF.shape, dtype='f8')
xhat = empty_noncritical(PSFHAT.shape, dtype='c16')
# square avoids abs of full array
IRsearch = np.sum(IR, axis=0)**2
pq = IRsearch.argmax()
Expand Down
22 changes: 21 additions & 1 deletion pfb/operators/gridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ def image_data_products(dsl,
if ovar:
# scale the natural weights
# RHS is weight relative to unity since wgtp included in ressq
# tmp = (l2_reweight_dof + 1)/(l2_reweight_dof + ressq/ovar)
wgt *= (l2_reweight_dof + 1)/(l2_reweight_dof + ressq/ovar)
else:
wgt = None
Expand Down Expand Up @@ -459,10 +460,29 @@ def image_data_products(dsl,
ssq = ressq[mask>0].sum()
ovar = ssq/mask.sum()
wgt /= ovar
ressq = (residual_vis*wgt*residual_vis.conj()).real
# ressq = (residual_vis*wgt*residual_vis.conj()).real
# chi2_dof = np.mean(ressq[mask>0])
# print(f'Band {bandid} chi2-dof changed from {chi2_dofp} to {chi2_dof}')

import matplotlib.pyplot as plt
from scipy.stats import norm
x = np.linspace(-5, 5, 150)
y = norm.pdf(x, 0, 1)
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(8, 12))
ax[0,0].hist((residual_vis.real*wgtp).ravel(), bins=15, density=True)
ax[0,0].plot(x, y, 'k')
ax[0,1].hist((residual_vis.real*wgt).ravel(), bins=15, density=True)
ax[0,1].plot(x, y, 'k')
ax[1,0].hist((residual_vis.imag*wgtp).ravel(), bins=15, density=True)
ax[1,0].plot(x, y, 'k')
ax[1,1].hist((residual_vis.imag*wgt).ravel(), bins=15, density=True)
ax[1,1].plot(x, y, 'k')
import os
cwd = os.getcwd()
bid = dso.attrs['bandid']
fig.savefig(f'{cwd}/resid_hist_{bid}.png')
# import ipdb; ipdb.set_trace()

# these are always used together
if do_weight:
dso['WEIGHT'] = (('row','chan'), wgt)
Expand Down
2 changes: 1 addition & 1 deletion pfb/operators/hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def idot(self, x, mode='psf', x0=None):
x0=x0[b],
tol=self.cgtol,
maxit=self.cgmaxit,
minit=1,
minit=3,
verbosity=self.cgverbose,
report_freq=self.cgrf,
backtrack=False,
Expand Down
72 changes: 58 additions & 14 deletions pfb/opt/pcg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from time import time
import numpy as np
import numexpr as ne
from functools import partial
import dask.array as da
from distributed import wait
Expand Down Expand Up @@ -46,14 +48,35 @@ def M(x): return x
x = x0
eps = 1.0
stall_count = 0
xp = x.copy()
rp = r.copy()
tcopy = 0.0
tA = 0.0
tvdot = 0.0
tupdate = 0.0
tp = 0.0
tnorm = 0.0
tii = time()
while (eps > tol or k < minit) and k < maxit and stall_count < 5:
xp = x.copy()
rp = r.copy()
ti = time()
np.copyto(xp, x)
np.copyto(rp, r)
tcopy += (time() - ti)
ti = time()
Ap = A(p)
tA += (time() - ti)
ti = time()
rnorm = np.vdot(r, y)
alpha = rnorm / np.vdot(p, Ap)
x = xp + alpha * p
r = rp + alpha * Ap
tvdot += (time() - ti)
ti = time()
ne.evaluate('xp + alpha*p',
out=x)
ne.evaluate('rp + alpha*Ap',
out=r)
# x = xp + alpha * p
# r = rp + alpha * Ap
tupdate += (time() - ti)
y = M(r)
rnorm_next = np.vdot(r, y)
while rnorm_next > rnorm and backtrack: # TODO - better line search
Expand All @@ -63,23 +86,44 @@ def M(x): return x
y = M(r)
rnorm_next = np.vdot(r, y)

ti = time()
beta = rnorm_next / rnorm
p = beta * p - y
ne.evaluate('beta*p-y',
out=p)
# p = beta * p - y
tp += (time() - ti)
# if p is zero we should stop
if not np.any(p):
break
rnorm = rnorm_next
k += 1
epsp = eps
ti = time()
eps = norm_diff(x, xp)
phi = rnorm / phi0
tnorm += (time() - ti)

if np.abs(epsp - eps) < 1e-3*tol:
stall_count += 1

if not k % report_freq and verbosity > 1:
print(f"At iteration {k} eps = {eps:.3e}, phi = {phi:.3e}")
# file=log)
ttot = time() - tii
tcopy /= ttot
tA /= ttot
tvdot /= ttot
tupdate /= ttot
tp /= ttot
tnorm /= ttot
ttally = tcopy + tA + tvdot + tupdate + tp + tnorm
print('tcopy = ', tcopy)
print('tA = ', tA)
print('tvdot = ', tvdot)
print('tupdate = ', tupdate)
print('tp = ', tp)
print('tnorm = ', tnorm)
print('ttally = ', ttally)

if k >= maxit:
if verbosity:
Expand Down Expand Up @@ -124,16 +168,16 @@ def M(x): return x / eta
M = None

for k in range(nband):
xpad = empty_noncritical((nx_psf, lastsize), dtype=b.dtype, order='C')
xhat = empty_noncritical((nx_psf, nyo2), dtype=psfhat.dtype, order='C')
xout = empty_noncritical((nx, ny), dtype=b.dtype, order='C')
xpad = empty_noncritical((nx_psf, lastsize), dtype=b.dtype)
xhat = empty_noncritical((nx_psf, nyo2), dtype=psfhat.dtype)
xout = empty_noncritical((nx, ny), dtype=b.dtype)
A = partial(_hessian_psf_slice,
xpad,
xhat,
xout,
psfhat[k],
beam[k],
lastsize,
xpad=xpad,
xhat=xhat,
xout=xout,
abspsf=np.abs(psfhat[k]),
beam=beam[k],
lastsize=lastsize,
nthreads=nthreads,
eta=eta)

Expand Down
23 changes: 7 additions & 16 deletions pfb/workers/klean.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,32 +63,23 @@ def klean(**kw):

dds = xds_from_url(dds_store.url)

from pfb.utils.fits import dds2fits, dds2fits_mfs
from pfb.utils.fits import dds2fits


if opts.fits_mfs or opts.fits:
print(f"Writing fits files to {fits_oname}_{opts.suffix}", file=log)

# convert to fits files
if opts.fits_mfs:
dds2fits_mfs(dds,
'RESIDUAL',
f'{fits_oname}_{opts.suffix}',
norm_wsum=True)
dds2fits_mfs(dds,
'MODEL',
f'{fits_oname}_{opts.suffix}',
norm_wsum=False)

if opts.fits_cubes:
dds2fits(dds,
'RESIDUAL',
f'{fits_oname}_{opts.suffix}',
norm_wsum=True)
norm_wsum=True,
do_mfs=opts.fits_mfs,
do_cube=opts.fits_cubes)
dds2fits(dds,
'MODEL',
f'{fits_oname}_{opts.suffix}',
norm_wsum=False)
norm_wsum=False,
do_mfs=opts.fits_mfs,
do_cube=opts.fits_cubes)

print(f"All done after {time.time() - ti}s", file=log)

Expand Down
4 changes: 3 additions & 1 deletion pfb/workers/sara.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def _sara(**kw):
best_rmax = rmax
best_model = model.copy()
diverge_count = 0
eps = 1.0
print(f"Iter {iter0}: peak residual = {rmax:.3e}, rms = {rms:.3e}",
file=log)
if opts.skip_model:
Expand All @@ -336,7 +337,8 @@ def _sara(**kw):
residual *= beam # avoid copy
update = precond.idot(residual,
mode=opts.hess_approx,
x0=update)
# only warm start close to convergence
x0=update if eps < 0.1 else None)
update_mfs = np.mean(update, axis=0)
save_fits(update_mfs,
fits_oname + f'_{opts.suffix}_update_{k+1}.fits',
Expand Down

0 comments on commit 0973d5c

Please sign in to comment.