Skip to content

Commit

Permalink
explicitly check if ac_iter is empty and break out of loop if it is
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Sep 11, 2024
1 parent 8d0d3e8 commit 74e8e71
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 12 deletions.
10 changes: 5 additions & 5 deletions pfb/operators/gridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,11 +419,10 @@ def image_data_products(dsl,
ressq = (residual_vis*wgtp*residual_vis.conj()).real
ssq = ressq[mask>0].sum()
ovar = ssq/mask.sum()
# chi2_dofp = np.mean(ressq[mask>0])
if ovar:
# scale the natural weights
# RHS is weight relative to unity since wgtp included in ressq
# tmp = (l2_reweight_dof + 1)/(l2_reweight_dof + ressq/ovar)
print(np.mean(ressq[mask>0]/ovar), np.std(ressq[mask>0]/ovar))
wgt *= (l2_reweight_dof + 1)/(l2_reweight_dof + ressq/ovar)
else:
wgt = None
Expand Down Expand Up @@ -460,9 +459,10 @@ def image_data_products(dsl,
ssq = ressq[mask>0].sum()
ovar = ssq/mask.sum()
wgt /= ovar
# ressq = (residual_vis*wgt*residual_vis.conj()).real
# chi2_dof = np.mean(ressq[mask>0])
# print(f'Band {bandid} chi2-dof changed from {chi2_dofp} to {chi2_dof}')
ressq = (residual_vis*wgt*residual_vis.conj()).real
print(np.mean(ressq[mask>0]), np.std(ressq[mask>0]))

# import ipdb; ipdb.set_trace()

import matplotlib.pyplot as plt
from scipy.stats import norm
Expand Down
9 changes: 5 additions & 4 deletions pfb/operators/hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,33 +434,34 @@ def idot(self, x, mode='psf', x0=None):
assert ny == self.ny

if x0 is None:
# initialise with direct estimate
x0 = np.zeros_like(xtmp)
for b in range(self.nband):
x0[b] = hess_direct_slice(xtmp,
x0[b] = hess_direct_slice(xtmp[b],
xpad=self.xpad,
xhat=self.xhat,
xout=self.xout[b],
abspsf=self.abspsf[b],
taperxy=self.taperxy,
lastsize=self.ny_psf,
nthreads=self.nthreads,
eta=self.eta[b],
eta=self.eta[b]*np.sqrt(nx*ny),
mode='backward')
if self.beam[b] is not None:
mask = (self.xout[b] > 0) & (self.beam[b] > self.min_beam)
self.xout[b, mask] /= self.beam[b, mask]**2

if mode=='direct':
for b in range(self.nband):
self.xout[b] = hess_direct_slice(x,
self.xout[b] = hess_direct_slice(xtmp[b],
xpad=self.xpad,
xhat=self.xhat,
xout=self.xout[b],
abspsf=self.abspsf[b],
taperxy=self.taperxy,
lastsize=self.ny_psf,
nthreads=self.nthreads,
eta=self.eta[b],
eta=self.eta[b]*np.sqrt(nx*ny),
mode='backward')
if self.beam[b] is not None:
mask = (self.xout[b] > 0) & (self.beam[b] > self.min_beam)
Expand Down
9 changes: 6 additions & 3 deletions pfb/workers/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,9 @@ def _init(**kw):
except Exception as e:
raise e

try:
if result is not None:
times_out.append(result[0])
freqs_out.append(result[1])
except:
pass # no result if chunk fully flagged

if isinstance(completed_future.result(), BaseException):
print(completed_future.result())
Expand Down Expand Up @@ -372,6 +370,7 @@ def _init(**kw):
workers=worker,
key='image-'+uuid4().hex)


ac_iter.add(future)
associated_workers[future] = worker
n_launched += 1
Expand All @@ -384,6 +383,10 @@ def _init(**kw):
if opts.progressbar:
print(f"\rProcessing: {n_launched}/{nds}", end='', flush=True)

# this should not be necessary but just in case
if ac_iter.is_empty():
break

times_out = np.unique(times_out)
freqs_out = np.unique(freqs_out)

Expand Down

0 comments on commit 74e8e71

Please sign in to comment.