Skip to content

Commit

Permalink
catch and raise errors returned by futures
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Sep 5, 2024
1 parent 4cc7069 commit 3ab346d
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 40 deletions.
66 changes: 37 additions & 29 deletions pfb/operators/gridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def _comps2vis_impl(uvw,


def image_data_products(dsl,
counts,
dsp,
nx, ny,
nx_psf, ny_psf,
cellx, celly,
Expand Down Expand Up @@ -343,7 +343,6 @@ def image_data_products(dsl,
'y': y
}


# expects a list
if isinstance(dsl, str):
dsl = [dsl]
Expand Down Expand Up @@ -382,15 +381,14 @@ def image_data_products(dsl,
# output ds
dso = xr.Dataset(attrs=attrs, coords=coords)
dso['FREQ'] = (('chan',), freq)
if counts is not None:
dso['COUNTS'] = (('x', 'y'), counts)

if model is None:
if l2_reweight_dof:
raise ValueError('Requested l2 reweight but no model passed in. '
'Perhaps transfer model from somewhere?')
else:
# do not apply weights in this direction
# actually model vis, this saves memory
residual_vis = dirty2vis(
uvw=uvw,
freq=freq,
Expand All @@ -412,48 +410,58 @@ def image_data_products(dsl,
residual_vis += vis

if l2_reweight_dof:
ressq = (residual_vis*residual_vis.conj()).real
if dsp:
dsp = xds_from_list([dsp], drop_all_but='WEIGHT')
wgtp = dsp[0].WEIGHT.values
else:
wgtp = 1.0
# mask needs to be bool here
ressq = (residual_vis*wgtp*residual_vis.conj()).real
ssq = ressq[mask>0].sum()
ovar = ssq/mask.sum()
chi2_dofp = np.mean(ressq[mask>0]*wgt[mask>0])
mean_dev = np.mean(ressq[mask>0]/ovar)
# chi2_dofp = np.mean(ressq[mask>0])
if ovar:
wgt = (l2_reweight_dof + 1)/(l2_reweight_dof + ressq/ovar)
# now divide by ovar to scale to absolute units
# the chi2_dof after reweighting should be closer to one
wgt /= ovar
chi2_dof = np.mean(ressq[mask>0]*wgt[mask>0])
print(f'Band {bandid} chi2-dof changed from {chi2_dofp} to {chi2_dof} with mean deviation of {mean_dev}')
# scale the natural weights
# RHS is weight relative to unity since wgtp included in ressq
wgt *= (l2_reweight_dof + 1)/(l2_reweight_dof + ressq/ovar)
else:
wgt = None

# we usually want to re-evaluate this since the robustness may change
# re-evaluate since robustness and or wgt after reweight may change
if robustness is not None:
numba.set_num_threads(np.maximum(nthreads, 1))
numba_threads = np.maximum(nthreads, 1)
numba.set_num_threads(numba_threads)
counts = _compute_counts(uvw,
freq,
mask,
wgt,
nx, ny,
cellx, celly,
uvw.dtype,
ngrid=np.minimum(nthreads, 8), # limit number of grids
usign=1.0 if flip_u else -1.0,
vsign=1.0 if flip_v else -1.0)
imwgt = counts_to_weights(
freq,
mask,
wgt,
nx, ny,
cellx, celly,
uvw.dtype,
# limit number of grids
ngrid=np.minimum(numba_threads, 8),
usign=1.0 if flip_u else -1.0,
vsign=1.0 if flip_v else -1.0)
wgt = counts_to_weights(
counts,
uvw,
freq,
wgt,
nx, ny,
cellx, celly,
robustness,
usign=1.0 if flip_u else -1.0,
vsign=1.0 if flip_v else -1.0)
if wgt is not None:
wgt *= imwgt
else:
wgt = imwgt

if l2_reweight_dof:
# normalise to absolute units
ressq = (residual_vis*wgt*residual_vis.conj()).real
ssq = ressq[mask>0].sum()
ovar = ssq/mask.sum()
wgt /= ovar
ressq = (residual_vis*wgt*residual_vis.conj()).real
# chi2_dof = np.mean(ressq[mask>0])
# print(f'Band {bandid} chi2-dof changed from {chi2_dofp} to {chi2_dof}')

# these are always used together
if do_weight:
Expand Down
17 changes: 9 additions & 8 deletions pfb/utils/weighting.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,9 @@ def _compute_counts(uvw, freq, mask, wgt, nx, ny,
for r in range(bin_idx[g], bin_idx[g] + bin_counts[g]):
uvw_row = uvw[r]
wgt_row = wgt[r]
mask_row = mask[r]
for c in range(nchan):
if not mask[r, c]:
if not mask_row[c]:
continue
# current uv coords
chan_normfreq = normfreq[c]
Expand Down Expand Up @@ -136,9 +137,9 @@ def _es_kernel(x, beta, k):


@njit(nogil=True, cache=True, parallel=True)
def counts_to_weights(counts, uvw, freq, nx, ny,
cell_size_x, cell_size_y, robust,
usign=1.0, vsign=-1.0):
def counts_to_weights(counts, uvw, freq, weight, nx, ny,
cell_size_x, cell_size_y, robust,
usign=1.0, vsign=-1.0):
# ufreq
u_cell = 1/(nx*cell_size_x)
umax = np.abs(-1/cell_size_x/2 - u_cell/2)
Expand All @@ -152,9 +153,8 @@ def counts_to_weights(counts, uvw, freq, nx, ny,
nchan = freq.size
nrow = uvw.shape[0]

weights = np.zeros((nrow, nchan), dtype=counts.dtype)
if not counts.any():
return weights
return weight

# Briggs weighting factor
if robust > -2:
Expand All @@ -166,6 +166,7 @@ def counts_to_weights(counts, uvw, freq, nx, ny,
normfreq = freq / lightspeed
for r in prange(nrow):
uvw_row = uvw[r]
weight_row = weight[r]
for c in range(nchan):
# get current uv
chan_normfreq = normfreq[c]
Expand All @@ -176,8 +177,8 @@ def counts_to_weights(counts, uvw, freq, nx, ny,
# get v index
v_idx = int(np.floor((v_tmp + vmax)/v_cell))
if counts[u_idx, v_idx]:
weights[r, c] = 1.0/counts[u_idx, v_idx]
return weights
weight_row[c] = weight_row[c]/counts[u_idx, v_idx]
return weight



Expand Down
7 changes: 5 additions & 2 deletions pfb/workers/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,11 @@ def _grid(**kw):
freq_out = ds_dct['freq_out']

if from_cache:
out_ds = xr.open_zarr(f'{dds_store.url}/time{timeid}_band{bandid}.zarr',
out_ds_name = f'{dds_store.url}/time{timeid}_band{bandid}.zarr'
out_ds = xr.open_zarr(out_ds_name,
chunks=None)
else:
out_ds_name = None

# compute lm coordinates of target
if opts.target is not None:
Expand Down Expand Up @@ -457,7 +460,7 @@ def _grid(**kw):

fut = client.submit(image_data_products,
dsl,
None, # counts
out_ds_name,
nx, ny,
nx_psf, ny_psf,
cell_rad, cell_rad,
Expand Down
6 changes: 5 additions & 1 deletion pfb/workers/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,11 @@ def _init(**kw):
ac_iter = as_completed(futures)
for completed_future in ac_iter:

result = completed_future.result()
try:
result = completed_future.result()
except Exception as e:
raise e

try:
times_out.append(result[0])
freqs_out.append(result[1])
Expand Down

0 comments on commit 3ab346d

Please sign in to comment.