diff --git a/pfb/operators/gridder.py b/pfb/operators/gridder.py index 8fd63502..00db1290 100644 --- a/pfb/operators/gridder.py +++ b/pfb/operators/gridder.py @@ -299,7 +299,7 @@ def _comps2vis_impl(uvw, def image_data_products(dsl, - counts, + dsp, nx, ny, nx_psf, ny_psf, cellx, celly, @@ -343,7 +343,6 @@ def image_data_products(dsl, 'y': y } - # expects a list if isinstance(dsl, str): dsl = [dsl] @@ -382,8 +381,6 @@ def image_data_products(dsl, # output ds dso = xr.Dataset(attrs=attrs, coords=coords) dso['FREQ'] = (('chan',), freq) - if counts is not None: - dso['COUNTS'] = (('x', 'y'), counts) if model is None: if l2_reweight_dof: @@ -391,6 +388,7 @@ def image_data_products(dsl, 'Perhaps transfer model from somewhere?') else: # do not apply weights in this direction + # actually model vis, this saves memory residual_vis = dirty2vis( uvw=uvw, freq=freq, @@ -412,48 +410,58 @@ def image_data_products(dsl, residual_vis += vis if l2_reweight_dof: - ressq = (residual_vis*residual_vis.conj()).real + if dsp: + dsp = xds_from_list([dsp], drop_all_but='WEIGHT') + wgtp = dsp[0].WEIGHT.values + else: + wgtp = 1.0 # mask needs to be bool here + ressq = (residual_vis*wgtp*residual_vis.conj()).real ssq = ressq[mask>0].sum() ovar = ssq/mask.sum() - chi2_dofp = np.mean(ressq[mask>0]*wgt[mask>0]) - mean_dev = np.mean(ressq[mask>0]/ovar) + # chi2_dofp = np.mean(ressq[mask>0]) if ovar: - wgt = (l2_reweight_dof + 1)/(l2_reweight_dof + ressq/ovar) - # now divide by ovar to scale to absolute units - # the chi2_dof after reweighting should be closer to one - wgt /= ovar - chi2_dof = np.mean(ressq[mask>0]*wgt[mask>0]) - print(f'Band {bandid} chi2-dof changed from {chi2_dofp} to {chi2_dof} with mean deviation of {mean_dev}') + # scale the natural weights + # RHS is weight relative to unity since wgtp included in ressq + wgt *= (l2_reweight_dof + 1)/(l2_reweight_dof + ressq/ovar) else: wgt = None - # we usually want to re-evaluate this since the robustness may change + # re-evaluate since robustness and or wgt after reweight may change if robustness is not None: - numba.set_num_threads(np.maximum(nthreads, 1)) + numba_threads = np.maximum(nthreads, 1) + numba.set_num_threads(numba_threads) counts = _compute_counts(uvw, - freq, - mask, - wgt, - nx, ny, - cellx, celly, - uvw.dtype, - ngrid=np.minimum(nthreads, 8), # limit number of grids - usign=1.0 if flip_u else -1.0, - vsign=1.0 if flip_v else -1.0) - imwgt = counts_to_weights( + freq, + mask, + wgt, + nx, ny, + cellx, celly, + uvw.dtype, + # limit number of grids + ngrid=np.minimum(numba_threads, 8), + usign=1.0 if flip_u else -1.0, + vsign=1.0 if flip_v else -1.0) + wgt = counts_to_weights( counts, uvw, freq, + wgt, nx, ny, cellx, celly, robustness, usign=1.0 if flip_u else -1.0, vsign=1.0 if flip_v else -1.0) - if wgt is not None: - wgt *= imwgt - else: - wgt = imwgt + + if l2_reweight_dof: + # normalise to absolute units + ressq = (residual_vis*wgt*residual_vis.conj()).real + ssq = ressq[mask>0].sum() + ovar = ssq/mask.sum() + wgt /= ovar + ressq = (residual_vis*wgt*residual_vis.conj()).real + # chi2_dof = np.mean(ressq[mask>0]) + # print(f'Band {bandid} chi2-dof changed from {chi2_dofp} to {chi2_dof}') # these are always used together if do_weight: diff --git a/pfb/utils/weighting.py b/pfb/utils/weighting.py index 64f86b53..e5955bec 100644 --- a/pfb/utils/weighting.py +++ b/pfb/utils/weighting.py @@ -99,8 +99,9 @@ def _compute_counts(uvw, freq, mask, wgt, nx, ny, for r in range(bin_idx[g], bin_idx[g] + bin_counts[g]): uvw_row = uvw[r] wgt_row = wgt[r] + mask_row = mask[r] for c in range(nchan): - if not mask[r, c]: + if not mask_row[c]: continue # current uv coords chan_normfreq = normfreq[c] @@ -136,9 +137,9 @@ def _es_kernel(x, beta, k): @njit(nogil=True, cache=True, parallel=True) -def counts_to_weights(counts, uvw, freq, nx, ny, - cell_size_x, cell_size_y, robust, - usign=1.0, vsign=-1.0): +def counts_to_weights(counts, uvw, freq, weight, nx, ny, + cell_size_x, cell_size_y, robust, + usign=1.0, vsign=-1.0): # ufreq u_cell = 1/(nx*cell_size_x) umax = np.abs(-1/cell_size_x/2 - u_cell/2) @@ -152,9 +153,8 @@ def counts_to_weights(counts, uvw, freq, nx, ny, nchan = freq.size nrow = uvw.shape[0] - weights = np.zeros((nrow, nchan), dtype=counts.dtype) if not counts.any(): - return weights + return weight # Briggs weighting factor if robust > -2: @@ -166,6 +166,7 @@ def counts_to_weights(counts, uvw, freq, nx, ny, normfreq = freq / lightspeed for r in prange(nrow): uvw_row = uvw[r] + weight_row = weight[r] for c in range(nchan): # get current uv chan_normfreq = normfreq[c] @@ -176,8 +177,8 @@ def counts_to_weights(counts, uvw, freq, nx, ny, # get v index v_idx = int(np.floor((v_tmp + vmax)/v_cell)) if counts[u_idx, v_idx]: - weights[r, c] = 1.0/counts[u_idx, v_idx] - return weights + weight_row[c] = weight_row[c]/counts[u_idx, v_idx] + return weight diff --git a/pfb/workers/grid.py b/pfb/workers/grid.py index aba86369..c2de33fe 100644 --- a/pfb/workers/grid.py +++ b/pfb/workers/grid.py @@ -380,8 +380,11 @@ def _grid(**kw): freq_out = ds_dct['freq_out'] if from_cache: - out_ds = xr.open_zarr(f'{dds_store.url}/time{timeid}_band{bandid}.zarr', + out_ds_name = f'{dds_store.url}/time{timeid}_band{bandid}.zarr' + out_ds = xr.open_zarr(out_ds_name, chunks=None) + else: + out_ds_name = None # compute lm coordinates of target if opts.target is not None: @@ -457,7 +460,7 @@ def _grid(**kw): fut = client.submit(image_data_products, dsl, - None, # counts + out_ds_name, nx, ny, nx_psf, ny_psf, cell_rad, cell_rad, diff --git a/pfb/workers/init.py b/pfb/workers/init.py index d33bef86..3a9ae00c 100644 --- a/pfb/workers/init.py +++ b/pfb/workers/init.py @@ -324,7 +324,11 @@ def _init(**kw): ac_iter = as_completed(futures) for completed_future in ac_iter: - result = completed_future.result() + try: + result = completed_future.result() + except Exception as e: + raise e + try: times_out.append(result[0]) freqs_out.append(result[1])