
dask-ms does not scale with big datasets (spectral line datasets) #214

Closed
miguelcarcamov opened this issue Jun 6, 2022 · 21 comments


miguelcarcamov commented Jun 6, 2022

  • dask-ms version: 0.2.6
  • Python version: 3.9.7
  • Operating System: Ubuntu Server 20.04

Description

I'm trying to apply a phase correction to my visibilities, so I loop over the MS list and apply a phase shift to each dataset. Then I write the visibilities using xds_to_table and compute the resulting dask writes.

This should be really straightforward; however, it takes more than 2 days to run.

What I Did

import sys
import os
import numpy as np
# import re
# from astropy.io import fits
# from astropy.units import Quantity

import pyralysis
import pyralysis.io
# from pyralysis.transformers.weighting_schemes import Robust
from pyralysis.units import lambdas_equivalencies
import astropy.units as un
import dask.array as da

from pyralysis.units import array_unit_conversion


def apply_gain_shift(file_ms,
                     file_ms_output='output_dask.ms',
                     alpha_R=1.,
                     Shift=False,
                     file_ms_ref=False):

    # file_ms_ref : reference ms for pointing

    print("applying shift with alpha_R = ", alpha_R," Shift = ", Shift)
    print("file_ms :", file_ms)
    print("file_ms_output :", file_ms_output)
    print(
        "building output ms structure by copying from file_ms to file_ms_output"
    )

    os.system("rm -rf " + file_ms_output)
    os.system("rsync -a " + file_ms + "/  " + file_ms_output + "/")

    reader = pyralysis.io.DaskMS(input_name=file_ms)
    dataset = reader.read()

    field_dataset = dataset.field.dataset

    if Shift:
        delta_x = Shift[0] * np.pi / (180. * 3600.)
        delta_y = Shift[1] * np.pi / (180. * 3600.)

    for ms in dataset.ms_list:  # loops over spws
        uvw = ms.visibilities.uvw.data
        spw_id = ms.spw_id
        pol_id = ms.polarization_id
        ncorrs = dataset.polarization.ncorrs[pol_id]
        nchans = dataset.spws.nchans[spw_id]

        uvw_broadcast = da.tile(uvw, nchans).reshape((len(uvw), nchans, 3))

        chans = dataset.spws.dataset[spw_id].CHAN_FREQ.data.squeeze(
            axis=0).compute() * un.Hz

        chans_broadcast = chans[np.newaxis, :, np.newaxis]

        uvw_lambdas = uvw_broadcast / chans_broadcast.to(un.m, un.spectral())

        # uvw_lambdas = array_unit_conversion(
        #    array=uvw_broadcast,
        #    unit=un.lambdas,
        #    equivalencies=lambdas_equivalencies(restfreq=chans_broadcast))

        uvw_lambdas = da.map_blocks(lambda x: x.value,
                                    uvw_lambdas,
                                    dtype=np.float64)

        if Shift:
            print("applying gain and shift")
            uus = uvw_lambdas[:, :, 0]
            vvs = uvw_lambdas[:, :, 1]
            eulerphase = alpha_R * da.exp(
                2j * np.pi *
                (uus * delta_x + vvs * delta_y)).astype(np.complex64)
            ms.visibilities.data *= eulerphase[:, :, np.newaxis]
        else:
            print("applying gain")
            ms.visibilities.data *= alpha_R

    if file_ms_output:
        print("PUNCH OUPUT MS")
        if file_ms_ref:
            print(
                "paste pointing center from reference vis file into output vis file"
            )
            print("loading reference ms")

            ref_reader = pyralysis.io.DaskMS(input_name=file_ms_ref)
            ref_dataset = ref_reader.read()
            field_dataset = ref_dataset.field.dataset

            if len(field_dataset) == len(dataset.field.dataset):
                for i, row in enumerate(dataset.field.dataset):
                    row['REFERENCE_DIR'] = field_dataset[i].REFERENCE_DIR
                    row['PHASE_DIR'] = field_dataset[i].PHASE_DIR
            else:
                for i, row in enumerate(dataset.field.dataset):
                    row['REFERENCE_DIR'] = field_dataset[0].REFERENCE_DIR
                    row['PHASE_DIR'] = field_dataset[0].PHASE_DIR

            # Write FIELD TABLE
            print(" Write FIELD TABLE ")
            reader.write_xarray_ds(dataset=dataset.field.dataset,
                                   ms_name=file_ms_output,
                                   table_name="FIELD")
            # Write MAIN TABLE
            print(" Write MAIN TABLE ")
            reader.write(dataset=dataset,
                         ms_name=file_ms_output,
                         columns="DATA")

    return

I'm using pyralysis's I/O module, which uses dask-ms.

Cheers,


JSKenyon commented Jun 6, 2022

This is less of a problem with dask-ms and more of a problem with the underlying casacore tables. The issue is twofold:

  1. dask-ms cannot read the Measurement Set in parallel (outside of some circumstances) due to thread-safety issues (or rather, the complete lack thereof).
  2. The current python-casacore wrappers do not drop the Python GIL, which effectively stalls all dask-based parallelism.

There aren't any amazing solutions at present. My plan is to move away from the underlying casacore tables and instead store everything in zarr arrays. dask-ms convert already allows you to convert a whole dataset over to zarr. I have made a start on, but not yet finished, making xds_from_storage more uniform so that software will be able to round-trip MS to MS and zarr to zarr without any code changes.
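For reference, a minimal sketch of the Python-level conversion (assuming the experimental zarr writer in daskms.experimental.zarr; "input.ms" and "output.zarr" are placeholder paths):

import dask
from daskms import xds_from_ms
from daskms.experimental.zarr import xds_to_zarr

# Read the MS lazily as a list of xarray datasets backed by dask arrays
datasets = xds_from_ms("input.ms", chunks={"row": 100000})

# Write the same datasets out as zarr, which supports parallel writes
writes = xds_to_zarr(datasets, "output.zarr")
dask.compute(writes)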

@miguelcarcamov (Author)

Thanks for the reply @JSKenyon. Yes, I was worried it could be that. Well, let me know if there are any updates on incorporating zarr arrays into dask-ms. Also, how can I use the convert function inside my code? A follow-up question to this: do you know whether ngCASA also uses the python-casacore wrappers to read Measurement Sets?
Cheers,


JSKenyon commented Jun 7, 2022

Also, how can I use the convert function inside my code?

You would likely need to use some of the lower level stuff - you can take a look at https://github.com/ratt-ru/dask-ms/blob/process-executors/daskms/apps/convert.py to see how this happens in dask-ms convert.

Do you know if ngCASA also uses python-casacore wrappers to read Measurement Sets?

I am not sure - I should actually reach out to them. Last I checked, they were also going the route of converting casacore tables to xarray datasets backed by zarr. Unfortunately, most of the limitations are in the CASA tables themselves and would require a major rewrite to change.

It is technically possible to read a Measurement Set from multiple processes (rather than threads), but it is not possible to write in parallel. A single write will also lock out all reads. The best option will likely be to read the MS in parallel from multiple processes and then write to a format which supports parallel writes (e.g. zarr, if you handle chunks correctly).
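A rough sketch of that multi-process setup with dask.distributed (the worker counts here are arbitrary placeholders):

from dask.distributed import Client, LocalCluster

# One single-threaded worker per process: python-casacore calls hold the
# GIL, so threads within one process gain nothing, but separate processes
# can each read through their own table handles.
cluster = LocalCluster(processes=True, n_workers=4, threads_per_worker=1)
client = Client(cluster)

# ...then the usual xds_from_ms(...) read followed by an xds_to_zarr(...)
# write, computed through this client.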


miguelcarcamov commented Jun 7, 2022

Follow-up question: what if we write the data using a TaQL query (as a short-term patch)? I remember the casacore developers helped me a lot in writing fast write-to-MS code - check the writeMS function. Also, how does CASA do it then? The reads and writes in CASA are really, really fast!
Also, what if we used the C++ casacore wrappers? Would that be fast?

Cheers,
Miguel


bennahugo commented Jun 7, 2022

@miguelcarcamov the C++ table system underneath opens a set of tables only once, so even if you use TaQL or other table objects, they point to the same table with the same table lock under the hood. If you use TaQL to select a smaller set to read and/or write it will of course make things faster - it looks like you are doing that in the C++ snippet.

I've experimented with adding multi-threaded reading support to the casacore table system (see casacore/casacore#1167). However, this is still a work in progress, and making the table system thread-safe is quite a large undertaking. I have shown that you can get the same scaling as you would get from multiple processes.

As @JSKenyon mentioned, currently all calls to the table system via python-casacore are GIL-locked. If you want really experimental Python support you can check out casacore/python-casacore#209.

If you follow the discussion you will see that there are caveats. Because the table system is not thread-safe, you should not pass tables between threads (edit: or open a table proxy with tables initially opened in other threads). The next task I will do on this thread is to make the table system fully thread-safe so that this can be done safely.

We don't yet fully support this in dask-ms (and won't until a full implementation is completed for casacore and python-casacore).


bennahugo commented Jun 7, 2022

An alternative I have for you is to do what the WSRT archive did and split databases out by scan -- at least on the data I have.

(edit: If you want to do this in a single script you will still need to use multiple processes due to the aforementioned GIL locking of python-casacore operations)

This way you get a table lock per table object, and you can possibly chunk things that way. It is, however, very cumbersome and not worth it if you intend this as a once-off operation?
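For what it's worth, a rough sketch of doing such a split with dask-ms grouping and the experimental zarr writer rather than separate casacore tables (the grouping columns, chunking and paths are placeholders):

import dask
from daskms import xds_from_ms
from daskms.experimental.zarr import xds_to_zarr

# One dataset per (DATA_DESC_ID, SCAN_NUMBER) combination
datasets = xds_from_ms(
    "input.ms",
    group_cols=["DATA_DESC_ID", "SCAN_NUMBER"],
    chunks={"row": 100000},
)

# Write each scan group to its own store so that later steps can
# read and process the scans independently
writes = [xds_to_zarr([ds], f"scan_{i}.zarr") for i, ds in enumerate(datasets)]
dask.compute(writes)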

@miguelcarcamov (Author)

Yep, in fact that's what I'm doing when reading and writing an MS. That was advice the casacore developers gave me.
What I will do for now is split the MS and apply the shift to each part. Then I will concatenate them all back together.

@bennahugo (Collaborator)

Ok, sorry about that -- unfortunately it is a fundamental limitation of the data format itself. Currently the only way you would get parallelism is to have hefty operations per chunk that significantly outweigh the reading/writing. Hopefully, with some support coming from the future MeerKAT / SKA data processor, there will be some traction to make the data format more amenable to fine-grained parallelism.

@miguelcarcamov (Author)

In fact, I realized that if you simulate a ~12h observation at one frequency (1 channel), you get a dataset with ~3M rows, and dask-ms reads and writes it very slowly as well.


bennahugo commented Jul 2, 2022 via email


bennahugo commented Jul 2, 2022 via email


miguelcarcamov commented Jul 2, 2022

Hi @bennahugo, sorry for not being clear on this. Check this dataset. I have read it using

from daskms import xds_from_ms
ms = xds_from_ms("9.8-9.5.ms", taql_where="!FLAG_ROW", index_cols=["SCAN_NUMBER", "TIME", "ANTENNA1", "ANTENNA2"])
ms[0].MODEL_DATA.data.compute()

This takes ~3m 44s, and it feels too long to me considering it's just one channel and two correlations.

I agree that one might need to change the chunks though. Is there any way to change the chunks from the dask-ms API?

@sjperkins (Member)

Hi @bennahugo, sorry for not being clear on this. Check this dataset. I have read it using

from daskms import xds_from_ms
ms = xds_from_ms("9.8-9.5.ms", taql_where="!FLAG_ROW", index_cols=["SCAN_NUMBER", "TIME", "ANTENNA1", "ANTENNA2"])
ms[0].MODEL_DATA.data.compute()

This takes ~3m 44s, and it feels too long to me considering it's just one channel and two correlations.

I agree that one might need to change the chunks though. Is there any way to change the chunks from the dask-ms API?

@miguelcarcamov I suspect the index_cols are the issue here, as they're probably producing a non-contiguous disk access order. Could you confirm by repeating the test with the ordering (and grouping) disabled?
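Something along these lines (a sketch only - adjust the path to your MS; both grouping and ordering are empty so rows are read in their natural on-disk order):

from daskms import xds_from_ms

ms = xds_from_ms(
    "9.8-9.5.ms",
    taql_where="!FLAG_ROW",
    group_cols=[],
    index_cols=[],
)
ms[0].MODEL_DATA.data.compute()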

dask-ms used to warn about excessively fragmented row orderings, and we're thinking about reintroducing these warnings so that these kinds of performance problems are more obvious to the user.


miguelcarcamov commented Jul 2, 2022

Hi @sjperkins, I have updated to dask-ms==0.2.9 and that was enough to get a very decent speed-up; it went from ~3 minutes to ~2 seconds. I'm going to keep testing with group_cols=[] and index_cols=[]. I have a question though: if I use group_cols=[], how does dask-ms order the list of xarray datasets when each spectral window has a different number of channels, or is this only valid when you have 1 spectral window and 1 channel? I noticed that with group_cols=[] the time for a compute drops to the order of ~300 ms!


sjperkins commented Jul 4, 2022

Hi @sjperkins, I have updated to dask-ms==0.2.9 and that was enough to get a very decent speed-up; it went from ~3 minutes to ~2 seconds.

Great to hear.

I'm going to keep testing with group_cols=[] and index_cols=[]. I have a question though: if I use group_cols=[], how does dask-ms order the list of xarray datasets when each spectral window has a different number of channels, or is this only valid when you have 1 spectral window and 1 channel?

Yes, this would only be valid for an MS with 1 spectral window, but you could have any number of channels in that SPW.

I noticed that with group_cols=[] the time for a compute drops to the order of ~300 ms!

At minimum, I think it's necessary to have DATA_DESC_ID in group_cols to handle multiple spectral windows. To handle this, GROUPING TAQL queries are performed to determine:

  1. the unique GROUPING values, DATA_DESC_ID 0, 1, 2, 3, .... for example.
  2. the rows associated with these unique GROUPING values.

This has some cost to it, probably because it needs to load each grouping column from disk.
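Illustratively (a sketch; I'm assuming here that the grouping values end up in each returned dataset's attributes):

from daskms import xds_from_ms

# One dataset per unique DATA_DESC_ID, so each spectral window keeps
# its own channel count
datasets = xds_from_ms("input.ms", group_cols=["DATA_DESC_ID"])

for ds in datasets:
    print(ds.attrs, dict(ds.dims))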


sjperkins commented Jul 4, 2022

I'm going to close this issue, but feel free to reopen if you think I've done this in error.


JSKenyon commented Jul 4, 2022

I am glad you seem to have come right @miguelcarcamov. There are actually a host of factors at play here.

Is there any way to change the chunks from the dask-ms API?

Absolutely - it takes a chunks argument in the same way that xds_from_table does (see https://dask-ms.readthedocs.io/en/latest/api.html#daskms.xds_from_table). I suspect that the poor behaviour you were seeing is because the default chunk size for row is 10000: it is obviously very inefficient to read 3M rows 10000 rows at a time. The changes in 0.2.9 somewhat mitigate the overheads involved in those tiny reads, but I suspect the "correct" approach would be something like:

from daskms import xds_from_ms

ms = xds_from_ms(
    "9.8-9.5.ms",
    taql_where="!FLAG_ROW",
    index_cols=["SCAN_NUMBER", "TIME", "ANTENNA1", "ANTENNA2"],
    group_cols=["DATA_DESC_ID"],
    chunks={"row": 100000}
)

ms[0].MODEL_DATA.data.compute()

Edit: Do note that using taql_where, index_cols and group_cols will have some associated overhead.

@sjperkins (Member)

Edit: Do note that using taql_where, index_cols and group_cols will have some associated overhead.

I've created an issue tracking documentation of performance tuning concerns:

* [Document performance tuning #225](https://github.com/ratt-ru/dask-ms/issues/225)

@miguelcarcamov (Author)

Edit: Do note that using taql_where, index_cols and group_cols will have some associated overhead.

I've created an issue tracking documentation of performance tuning concerns:

* [Document performance tuning #225](https://github.com/ratt-ru/dask-ms/issues/225)

Two follow-up questions:

  1. How can I choose my storage manager, and use it to read and write my data?
  2. Is the zarr IO currently working in version 0.2.9?

@sjperkins (Member)


  1. How can I choose my storage manager, and use it to read and write my data?

If you use descriptor="ms" in calls to xds_to_table, dask-ms will try to create TiledStorage Managers based on the chunk sizes of the dask arrays.

This is supported by building CASA Table and Data Manager descriptors from xarray Datasets. There's an undocumented API to do this here: https://github.com/ratt-ru/dask-ms/tree/master/daskms/descriptors
and https://github.com/ratt-ru/dask-ms/blob/master/daskms/descriptors/ms.py is probably what you're interested in.
https://github.com/ratt-ru/dask-ms/blob/master/daskms/descriptors/ratt_ms.py is an example of a descriptor builder that adds support for fixed shape bitflag columns.
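Roughly, a sketch of that (assuming datasets read with xds_from_ms; the output name and chunk sizes are placeholders):

import dask
from daskms import xds_from_ms, xds_to_table

# The chunk sizes chosen here inform the tile shapes of the storage managers
datasets = xds_from_ms("input.ms", chunks={"row": 100000, "chan": 64})

# descriptor="ms" asks dask-ms to build an MS-like table descriptor, with
# tiled storage managers derived from the dask chunk sizes
writes = xds_to_table(datasets, "output.ms", columns="ALL", descriptor="ms")
dask.compute(writes)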

  2. Is the zarr IO currently working in version 0.2.9?

Yes, see the functionality in daskms.experimental.zarr.xds_{from,to}_zarr, although we're still iterating on this.

But the idea is for xarray Datasets to represent data independently of the on-disk format, and for conversion between formats to occur through these Datasets. So there's:

  • daskms.xds_{from,to}_table
  • daskms.experimental.arrow.xds_{from,to}_parquet
  • daskms.experimental.zarr.xds_{from,to}_zarr
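And reading the zarr store back is symmetric (a small sketch; "output.zarr" is whatever store you wrote earlier):

from daskms.experimental.zarr import xds_from_zarr

# Returns the same kind of list of xarray Datasets you'd get from reading
# the MS, so downstream code shouldn't need to change
zarr_datasets = xds_from_zarr("output.zarr")
print(zarr_datasets[0])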

@sjperkins (Member)

I'd recommend looking through the test cases to get an understanding of how the above functionality works
