Skip to content

Commit

Permalink
tweak freq mapping in hci
Browse files Browse the repository at this point in the history
  • Loading branch information
landmanbester committed Aug 21, 2024
1 parent 5272dfa commit fa83bb4
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 96 deletions.
12 changes: 7 additions & 5 deletions pfb/parser/hci.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
inputs:
ms:
dtype: URI
dtype: List[URI]
required: true
abbreviation: ms
info:
Expand All @@ -9,17 +9,17 @@ inputs:
dtype: List[int]
info:
List of SCAN_NUMBERS to image. Defaults to all.
Input as string eg. '[0,2]' if running from CLI
Input as comma separated list 0,2 if running from CLI
ddids:
dtype: List[int]
info:
List of DATA_DESC_ID's to images. Defaults to all.
Input as string eg. '[0,1]' if running from CLI
Input as comma separated list 0,2 if running from CLI
fields:
dtype: List[int]
info:
List of FIELD_ID's to image. Defaults to all.
Input as string eg. '[0,1,2]' if running from CLI
Input as comma separated list 0,2 if running from CLI
freq-range:
dtype: str
info:
Expand Down Expand Up @@ -60,9 +60,11 @@ inputs:
info:
Column containing data flags. Must be the same across MSs
gain-table:
dtype: URI
dtype: List[URI]
info:
Path to Quartical gain table containing NET gains.
There must be a table for each MS and glob(ms) and
glob(gt) should match up when running from CLI.
integrations-per-image:
dtype: int
abbreviation: ipi
Expand Down
2 changes: 1 addition & 1 deletion pfb/parser/init.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ inputs:
info:
Column containing data flags. Must be the same across MSs
gain-table:
dtype: List[str]
dtype: List[URI]
info:
Path to Quartical gain table containing NET gains.
There must be a table for each MS and glob(ms) and
Expand Down
3 changes: 2 additions & 1 deletion pfb/utils/stokes2im.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from casacore.quanta import quantity
from datetime import datetime
from ducc0.fft import c2r, r2c, good_size
from ducc0.misc import resize_thread_pool
from africanus.constants import c as lightspeed
import gc
iFs = np.fft.ifftshift
Expand Down Expand Up @@ -56,7 +57,7 @@ def single_stokes_image(
bandid=None,
timeid=None,
wid=None):

resize_thread_pool(opts.nthreads)

with worker_client() as client:
(data, data2, ant1, ant2, uvw, frow, flag, sigma, weight,
Expand Down
153 changes: 66 additions & 87 deletions pfb/workers/hci.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# flake8: noqa
import os
import sys
from pathlib import Path
from contextlib import ExitStack
from pfb.workers.main import cli
import click
Expand All @@ -26,88 +25,69 @@ def hci(**kw):
'''
defaults.update(kw)
opts = OmegaConf.create(defaults)
timestamp = time.strftime("%Y%m%d-%H%M%S")
ldir = Path(opts.log_directory).resolve()
ldir.mkdir(parents=True, exist_ok=True)
pyscilog.log_to_file(f'{ldir}/hci_{timestamp}.log')
print(f'Logs will be written to {str(ldir)}/fastim_{timestamp}.log', file=log)

from pfb.utils.naming import set_output_names
opts, basedir, oname = set_output_names(opts)

import psutil
nthreads = psutil.cpu_count(logical=True)
ncpu = psutil.cpu_count(logical=False)
if opts.nthreads is None:
opts.nthreads = nthreads//2
ncpu = ncpu//2

if opts.product.upper() not in ["I","Q", "U", "V"]:
raise NotImplementedError(f"Product {opts.product} not yet supported")


from daskms.fsspec_store import DaskMSStore
msstore = DaskMSStore(opts.ms.rstrip('/'))
mslist = msstore.fs.glob(opts.ms.rstrip('/'))
try:
assert len(mslist) == 1
except:
raise ValueError(f"There must be a single MS corresponding "
f"to {opts.ms}")
opts.ms = mslist[0]
if opts.gain_table is not None:
gainstore = DaskMSStore(opts.gain_table.rstrip('/'))
gtlist = gainstore.fs.glob(opts.gain_table.rstrip('/'))
msnames = []
for ms in opts.ms:
msstore = DaskMSStore(ms.rstrip('/'))
mslist = msstore.fs.glob(ms.rstrip('/'))
try:
assert len(gtlist) == 1
except Exception as e:
raise ValueError(f"There must be a single gain table "
f"corresponding to {opts.gain_table}")
opts.gain_table = gtlist[0]
if opts.transfer_model_from is not None:
tmf = opts.transfer_model_from.rstrip('/')
modelstore = DaskMSStore(tmf)
try:
assert modelstore.exists()
except Exception as e:
raise ValueError(f"There must be a single model corresponding "
f"to {opts.transfer_model_from}")
opts.transfer_model_from = modelstore.url
assert len(mslist) > 0
msnames.append(*list(map(msstore.fs.unstrip_protocol, mslist)))
except:
raise ValueError(f"No MS at {ms}")
opts.ms = msnames
if opts.gain_table is not None:
gainnames = []
for gt in opts.gain_table:
gainstore = DaskMSStore(gt.rstrip('/'))
gtlist = gainstore.fs.glob(gt.rstrip('/'))
try:
assert len(gtlist) > 0
gainnames.append(*list(map(gainstore.fs.unstrip_protocol, gtlist)))
except Exception as e:
raise ValueError(f"No gain table at {gt}")
opts.gain_table = gainnames

OmegaConf.set_struct(opts, True)

timestamp = time.strftime("%Y%m%d-%H%M%S")
logname = f'{str(opts.log_directory)}/hci_{timestamp}.log'
pyscilog.log_to_file(logname)
print(f'Logs will be written to {logname}', file=log)

# TODO - prettier config printing
print('Input Options:', file=log)
for key in opts.keys():
print(' %25s = %s' % (key, opts[key]), file=log)

basename = f'{basedir}/{oname}'

from pfb import set_envs
from ducc0.misc import resize_thread_pool, thread_pool_size
resize_thread_pool(opts.nthreads)
set_envs(opts.nthreads, ncpu)

with ExitStack() as stack:
os.environ["OMP_NUM_THREADS"] = str(opts.nthreads)
os.environ["OPENBLAS_NUM_THREADS"] = str(opts.nthreads)
os.environ["MKL_NUM_THREADS"] = str(opts.nthreads)
os.environ["VECLIB_MAXIMUM_THREADS"] = str(opts.nthreads)
paths = sys.path
ppath = [paths[i] for i in range(len(paths)) if 'pfb/bin' in paths[i]]
if len(ppath):
ldpath = ppath[0].replace('bin', 'lib')
ldcurrent = os.environ.get('LD_LIBRARY_PATH', '')
os.environ["LD_LIBRARY_PATH"] = f'{ldpath}:{ldcurrent}'
# TODO - should we fall over in else?
os.environ["NUMBA_NUM_THREADS"] = str(opts.nthreads)

import numexpr as ne
max_cores = ne.detect_number_of_cores()
ne_threads = min(max_cores, opts.nthreads)
os.environ["NUMEXPR_NUM_THREADS"] = str(ne_threads)
import dask
dask.config.set(**{'array.slicing.split_large_chunks': False})

# set up client
host_address = opts.host_address or os.environ.get("DASK_SCHEDULER_ADDRESS")
if host_address is not None:
from distributed import Client
print("Initialising distributed client.", file=log)
client = stack.enter_context(Client(host_address))
else:
from dask.distributed import Client, LocalCluster
print("Initialising client with LocalCluster.", file=log)
cluster = LocalCluster(processes=True,
n_workers=opts.nworkers,
threads_per_worker=1,
memory_limit=0,
asynchronous=False)
cluster = stack.enter_context(cluster)
client = stack.enter_context(Client(cluster,
direct_to_workers=False))

client.wait_for_workers(opts.nworkers)
client.amm.stop()
from pfb import set_client
from distributed import wait, get_client
client = set_client(opts.nworkers, stack, log)

ti = time.time()
_hci(**opts)
Expand Down Expand Up @@ -135,7 +115,7 @@ def _hci(**kw):
from pfb.utils.stokes2im import single_stokes_image
import xarray as xr

basename = f'{opts.output_filename}_{opts.product.upper()}'
basename = f'{opts.output_filename}'

fdsstore = DaskMSStore(f'{basename}.fds')
if fdsstore.exists():
Expand All @@ -150,9 +130,10 @@ def _hci(**kw):
fs.makedirs(fdsstore.url, exist_ok=True)

if opts.gain_table is not None:
gain_name = "::".join(opts.gain_table.rstrip('/').rsplit("/", 1))
tmpf = lambda x: '::'.join(x.rsplit('/', 1))
gain_names = list(map(tmpf, opts.gain_table))
else:
gain_name = None
gain_names = None

if opts.freq_range is not None:
fmin, fmax = opts.freq_range.strip(' ').split(':')
Expand All @@ -177,18 +158,20 @@ def _hci(**kw):
freqs, utimes, ms_chunks, gain_chunks, radecs, \
chan_widths, uv_max, antpos, poltype = \
construct_mappings(opts.ms,
gain_name,
gain_names,
ipi=opts.integrations_per_image,
cpi=opts.channels_per_degrid_image,
freq_min=freq_min,
freq_max=freq_max)

max_freq = 0
ms = opts.ms
ms = opts.ms[0]
gain_name = gain_names[0]
# for ms in opts.ms:
for idt in freqs[ms].keys():
freq = freqs[ms][idt]
max_freq = np.maximum(max_freq, freq.max())
mask = (freq <= freq_max) & (freq >= freq_min)
max_freq = np.maximum(max_freq, freq[mask].max())

# cell size
cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)
Expand Down Expand Up @@ -293,20 +276,19 @@ def _hci(**kw):

# a flat list to use with as_completed
datasets = []

for ids, ds in enumerate(xds):
fid = ds.FIELD_ID
ddid = ds.DATA_DESC_ID
scanid = ds.SCAN_NUMBER
# TODO - cleaner syntax
if opts.fields is not None:
if fid not in opts.fields:
if fid not in list(map(int, opts.fields)):
continue
if opts.ddids is not None:
if ddid not in opts.ddids:
if ddid not in list(map(int, opts.ddids)):
continue
if opts.scans is not None:
if scanid not in opts.scans:
if scanid not in list(map(int, opts.scans)):
continue


Expand All @@ -328,9 +310,6 @@ def _hci(**kw):
# select all rows for output dataset
Irow = slice(ridx[0], ridx[-1] + rcnts[-1])

fitr = enumerate(zip(freq_mapping[ms][idt]['start_indices'],
freq_mapping[ms][idt]['counts']))

# TODO - cpdi to cpgi mapping
# assumes cpdi is integer multiple of cpgi
nbandi = freq_mapping[ms][idt]['start_indices'].size
Expand All @@ -343,7 +322,7 @@ def _hci(**kw):
cpdi = nfreqs
else:
cpdi = opts.channels_per_degrid_image
fbins_per_band = int(cpgi / cpdi)
fbins_per_band = int(np.round(cpgi / cpdi))
nband = int(np.ceil(nbandi/fbins_per_band))

for fi in range(nband):
Expand All @@ -370,9 +349,9 @@ def _hci(**kw):
freqs[ms][idt][Inu],
utimes[ms][idt][It],
ridx, rcnts,
fidx, fcnts,
fidx-fidx.min(), fcnts, # start counting from zero
radecs[ms][idt],
fi, ti])
fi, ti, ms])

futures = []
associated_workers = {}
Expand All @@ -382,7 +361,7 @@ def _hci(**kw):
while idle_workers: # Seed each worker with a task.

(subds, jones, freqsi, utimesi, ridx, rcnts, fidx, fcnts,
radeci, fi, ti) = datasets[n_launched]
radeci, fi, ti, ms) = datasets[n_launched]
data2 = None if dc2 is None else getattr(subds, dc2).data
sc = opts.sigma_column
sigma = None if sc is None else getattr(subds, sc).data
Expand Down Expand Up @@ -438,7 +417,7 @@ def _hci(**kw):
continue

(subds, jones, freqsi, utimesi, ridx, rcnts, fidx, fcnts,
radeci, fi, ti) = datasets[n_launched]
radeci, fi, ti, ms) = datasets[n_launched]
data2 = None if dc2 is None else getattr(subds, dc2).data
sc = opts.sigma_column
sigma = None if sc is None else getattr(subds, sc).data
Expand Down
2 changes: 0 additions & 2 deletions pfb/workers/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def init(**kw):
opts.nthreads = nthreads//2
ncpu = ncpu//2

OmegaConf.set_struct(opts, True)

if opts.product.upper() not in ["I","Q", "U", "V"]:
raise NotImplementedError(f"Product {opts.product} not yet supported")

Expand Down

0 comments on commit fa83bb4

Please sign in to comment.