weather-mv will ingest data into BQ from Zarr much faster. (#357)
* Fixed issues found loading Zarr into BQ.

Found a couple of errors when loading a Zarr dataset into BigQuery.

* Base weather-tools install requires gcsfs.

* Zarr datasets are not normalized by default.

* Parallel Zarr ingestion into BQ.

* Fix setup.py syntax error.

* Fixing Zarr + Xarray-Beam support.

* Added happy-path unit test for parallel Zarr reading into BQ.

* Fix flake8 issues.

* Better whitespace.

* Adding open_ds kwargs to open zarr.

* Attempting to fix pickling issues.

* Another attempt to fix pickling error, now in transform.

* Experiment: is xbeam.open_zarr the issue?

* Adding engine=zarr.

* open_zarr --> open_dataset w/ engine.

* Delete regrid.

* Pinned Zarr version.

* Hard-coded current CL for Docker image.

* Remove unnecessary delete.

* Only recent years.

* All data w/ streaming inserts.

* Experiment: added windowing.

* Documented `timestamp_row` fn.

* Self-review: Prepared changes for PR.

* Small cleanup.

* Remove debug isel.

* Added types to `to_rows()`.

* Fixed flake8 lint errors.

* Better types for `to_rows()`.

* Test updated and 'chunks' removed from zarr_kwargs.

* Zarr version updated.

---------

Co-authored-by: dabhi_cusp <[email protected]>
alxmrs and dabhicusp authored Aug 16, 2023
1 parent 3464844 commit 9cdf0e7
Showing 12 changed files with 163 additions and 90 deletions.
3 changes: 2 additions & 1 deletion ci3.8.yml
@@ -16,7 +16,7 @@ dependencies:
- requests=2.28.1
- netcdf4=1.6.1
- rioxarray=0.13.4
- xarray-beam=0.3.1
- xarray-beam=0.6.2
- ecmwf-api-client=1.6.3
- fsspec=2022.11.0
- gcsfs=2022.11.0
@@ -33,6 +33,7 @@ dependencies:
- ruff==0.0.260
- google-cloud-sdk=410.0.0
- aria2=1.36.0
- zarr=2.15.0
- pip:
- cython==0.29.34
- earthengine-api==0.1.329
3 changes: 2 additions & 1 deletion ci3.9.yml
@@ -16,7 +16,7 @@ dependencies:
- requests=2.28.1
- netcdf4=1.6.1
- rioxarray=0.13.4
- xarray-beam=0.3.1
- xarray-beam=0.6.2
- ecmwf-api-client=1.6.3
- fsspec=2022.11.0
- gcsfs=2022.11.0
@@ -33,6 +33,7 @@ dependencies:
- aria2=1.36.0
- xarray==2023.1.0
- ruff==0.0.260
- zarr=2.15.0
- pip:
- cython==0.29.34
- earthengine-api==0.1.329
3 changes: 2 additions & 1 deletion environment.yml
@@ -4,7 +4,7 @@ channels:
dependencies:
- python=3.8.13
- apache-beam=2.40.0
- xarray-beam=0.3.1
- xarray-beam=0.6.2
- xarray=2023.1.0
- fsspec=2022.11.0
- gcsfs=2022.11.0
@@ -25,6 +25,7 @@ dependencies:
- google-cloud-sdk=410.0.0
- aria2=1.36.0
- pip=22.3
- zarr=2.15.0
- pip:
- cython==0.29.34
- earthengine-api==0.1.329
6 changes: 4 additions & 2 deletions setup.py
@@ -57,8 +57,9 @@
"earthengine-api>=0.1.263",
"pyproj", # requires separate binary installation!
"gdal", # requires separate binary installation!
"xarray-beam==0.3.1",
"xarray-beam==0.6.2",
"gcsfs==2022.11.0",
"zarr==2.15.0",
]

weather_sp_requirements = [
@@ -82,6 +83,7 @@
"memray",
"pytest-memray",
"h5py",
"pooch",
]

all_test_requirements = beam_gcp_requirements + weather_dl_requirements + \
@@ -115,7 +117,7 @@

],
python_requires='>=3.8, <3.10',
install_requires=['apache-beam[gcp]==2.40.0'],
install_requires=['apache-beam[gcp]==2.40.0', 'gcsfs==2022.11.0'],
use_scm_version=True,
setup_requires=['setuptools_scm'],
scripts=['weather_dl/weather-dl', 'weather_mv/weather-mv', 'weather_sp/weather-sp'],
135 changes: 80 additions & 55 deletions weather_mv/loader_pipeline/bq.py
@@ -24,8 +24,10 @@
import geojson
import numpy as np
import xarray as xr
import xarray_beam as xbeam
from apache_beam.io import WriteToBigQuery, BigQueryDisposition
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.transforms import window
from google.cloud import bigquery
from xarray.core.utils import ensure_us_time_resolution

@@ -236,73 +238,90 @@ def extract_rows(self, uri: str, coordinates: t.List[t.Dict]) -> t.Iterator[t.Dict]:
with open_dataset(uri, self.xarray_open_dataset_kwargs, self.disable_grib_schema_normalization,
self.tif_metadata_for_datetime, is_zarr=self.zarr) as ds:
data_ds: xr.Dataset = _only_target_vars(ds, self.variables)
yield from self.to_rows(coordinates, data_ds, uri)

first_ts_raw = data_ds.time[0].values if isinstance(data_ds.time.values,
np.ndarray) else data_ds.time.values
first_time_step = to_json_serializable_type(first_ts_raw)

for it in coordinates:
# Use those index values to select a Dataset containing one row of data.
row_ds = data_ds.loc[it]

# Create a Name-Value map for data columns. Result looks like:
# {'d': -2.0187, 'cc': 0.007812, 'z': 50049.8, 'rr': None}
row = {n: to_json_serializable_type(ensure_us_time_resolution(v.values))
for n, v in row_ds.data_vars.items()}

# Serialize coordinates.
it = {k: to_json_serializable_type(v) for k, v in it.items()}

# Add indexed coordinates.
row.update(it)
# Add un-indexed coordinates.
for c in row_ds.coords:
if c not in it and (not self.variables or c in self.variables):
row[c] = to_json_serializable_type(ensure_us_time_resolution(row_ds[c].values))

# Add import metadata.
row[DATA_IMPORT_TIME_COLUMN] = self.import_time
row[DATA_URI_COLUMN] = uri
row[DATA_FIRST_STEP] = first_time_step

longitude = ((row['longitude'] + 180) % 360) - 180
row[GEO_POINT_COLUMN] = fetch_geo_point(row['latitude'], longitude)
row[GEO_POLYGON_COLUMN] = (
fetch_geo_polygon(row["latitude"], longitude, self.lat_grid_resolution, self.lon_grid_resolution)
if not self.skip_creating_polygon
else None
)
# 'row' ends up looking like:
# {'latitude': 88.0, 'longitude': 2.0, 'time': '2015-01-01 06:00:00', 'd': -2.0187, 'cc': 0.007812,
# 'z': 50049.8, 'data_import_time': '2020-12-05 00:12:02.424573 UTC', ...}
beam.metrics.Metrics.counter('Success', 'ExtractRows').inc()
yield row
def to_rows(self, coordinates: t.Iterable[t.Dict], ds: xr.Dataset, uri: str) -> t.Iterator[t.Dict]:
first_ts_raw = (
ds.time[0].values if isinstance(ds.time.values, np.ndarray)
else ds.time.values
)
first_time_step = to_json_serializable_type(first_ts_raw)
for it in coordinates:
# Use those index values to select a Dataset containing one row of data.
row_ds = ds.loc[it]

# Create a Name-Value map for data columns. Result looks like:
# {'d': -2.0187, 'cc': 0.007812, 'z': 50049.8, 'rr': None}
row = {n: to_json_serializable_type(ensure_us_time_resolution(v.values))
for n, v in row_ds.data_vars.items()}

# Serialize coordinates.
it = {k: to_json_serializable_type(v) for k, v in it.items()}

# Add indexed coordinates.
row.update(it)
# Add un-indexed coordinates.
for c in row_ds.coords:
if c not in it and (not self.variables or c in self.variables):
row[c] = to_json_serializable_type(ensure_us_time_resolution(row_ds[c].values))

# Add import metadata.
row[DATA_IMPORT_TIME_COLUMN] = self.import_time
row[DATA_URI_COLUMN] = uri
row[DATA_FIRST_STEP] = first_time_step

longitude = ((row['longitude'] + 180) % 360) - 180
row[GEO_POINT_COLUMN] = fetch_geo_point(row['latitude'], longitude)
row[GEO_POLYGON_COLUMN] = (
fetch_geo_polygon(row["latitude"], longitude, self.lat_grid_resolution, self.lon_grid_resolution)
if not self.skip_creating_polygon
else None
)
# 'row' ends up looking like:
# {'latitude': 88.0, 'longitude': 2.0, 'time': '2015-01-01 06:00:00', 'd': -2.0187, 'cc': 0.007812,
# 'z': 50049.8, 'data_import_time': '2020-12-05 00:12:02.424573 UTC', ...}
beam.metrics.Metrics.counter('Success', 'ExtractRows').inc()
yield row

def chunks_to_rows(self, _, ds: xr.Dataset) -> t.Iterator[t.Dict]:
uri = ds.attrs.get(DATA_URI_COLUMN, '')
# Re-calculate import time for streaming extractions.
if not self.import_time or self.zarr:
self.import_time = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
yield from self.to_rows(get_coordinates(ds, uri), ds, uri)

def expand(self, paths):
"""Extract rows of variables from data paths into a BigQuery table."""
extracted_rows = (
if not self.zarr:
extracted_rows = (
paths
| 'PrepareCoordinates' >> beam.FlatMap(self.prepare_coordinates)
| beam.Reshuffle()
| 'ExtractRows' >> beam.FlatMapTuple(self.extract_rows)
)

if not self.dry_run:
(
extracted_rows
| 'WriteToBigQuery' >> WriteToBigQuery(
project=self.table.project,
dataset=self.table.dataset_id,
table=self.table.table_id,
write_disposition=BigQueryDisposition.WRITE_APPEND,
create_disposition=BigQueryDisposition.CREATE_NEVER)
)
else:
(
extracted_rows
| 'Log Extracted Rows' >> beam.Map(logger.debug)
ds, chunks = xbeam.open_zarr(self.first_uri, **self.xarray_open_dataset_kwargs)
ds.attrs[DATA_URI_COLUMN] = self.first_uri
extracted_rows = (
paths
| 'OpenChunks' >> xbeam.DatasetToChunks(ds, chunks)
| 'ExtractRows' >> beam.FlatMapTuple(self.chunks_to_rows)
| 'Window' >> beam.WindowInto(window.FixedWindows(60))
| 'AddTimestamp' >> beam.Map(timestamp_row)
)

if self.dry_run:
return extracted_rows | 'Log Rows' >> beam.Map(logger.info)
return (
extracted_rows
| 'WriteToBigQuery' >> WriteToBigQuery(
project=self.table.project,
dataset=self.table.dataset_id,
table=self.table.table_id,
write_disposition=BigQueryDisposition.WRITE_APPEND,
create_disposition=BigQueryDisposition.CREATE_NEVER)
)


def map_dtype_to_sql_type(var_type: np.dtype) -> str:
"""Maps a np.dtype to a suitable BigQuery column type."""
@@ -343,6 +362,12 @@ def to_table_schema(columns: t.List[t.Tuple[str, str]]) -> t.List[bigquery.SchemaField]:
return fields


def timestamp_row(it: t.Dict) -> window.TimestampedValue:
"""Associate an extracted row with the import_time timestamp."""
timestamp = it[DATA_IMPORT_TIME_COLUMN].timestamp()
return window.TimestampedValue(it, timestamp)


def fetch_geo_point(lat: float, long: float) -> str:
"""Calculates a geography point from an input latitude and longitude."""
if lat > LATITUDE_RANGE[1] or lat < LATITUDE_RANGE[0]:
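Reviewer note: the new Zarr branch of `expand()` above reads as a small standalone pipeline — open the store once with xarray-beam, fan its chunks out to workers, convert each chunk to rows, and window the rows for streaming inserts. A minimal sketch under assumptions: the bucket path is a placeholder, an inline `chunk_to_rows` stands in for `ToBigQuery.chunks_to_rows`, a print sink replaces `WriteToBigQuery`, and timestamps are attached before windowing (the transform above applies them in the opposite order).

import time

import apache_beam as beam
import xarray_beam as xbeam
from apache_beam.transforms import window

ZARR_URI = 'gs://my-bucket/source.zarr'  # placeholder path


def chunk_to_rows(key, chunk):
    """Simplified stand-in for ToBigQuery.chunks_to_rows: one dict per grid point."""
    yield from chunk.to_dataframe().reset_index().to_dict('records')


ds, chunks = xbeam.open_zarr(ZARR_URI)  # lazy Dataset plus the store's chunking scheme

with beam.Pipeline() as p:
    (
        p
        | 'OpenChunks' >> xbeam.DatasetToChunks(ds, chunks)  # emits (key, xr.Dataset) pairs
        | 'ExtractRows' >> beam.FlatMapTuple(chunk_to_rows)
        | 'AddTimestamp' >> beam.Map(lambda row: window.TimestampedValue(row, time.time()))
        | 'Window' >> beam.WindowInto(window.FixedWindows(60))  # bounds each streaming-insert batch
        | 'LogRows' >> beam.Map(print)  # WriteToBigQuery in the real pipeline
    )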
48 changes: 41 additions & 7 deletions weather_mv/loader_pipeline/bq_test.py
@@ -15,6 +15,7 @@
import json
import logging
import os
import tempfile
import typing as t
import unittest

@@ -23,6 +24,8 @@
import pandas as pd
import simplejson
import xarray as xr
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that, is_not_empty
from google.cloud.bigquery import SchemaField

from .bq import (
@@ -205,13 +208,13 @@ def extract(self, data_path, *, variables=None, area=None, open_dataset_kwargs=N
skip_creating_polygon: bool = False) -> t.Iterator[t.Dict]:
if zarr_kwargs is None:
zarr_kwargs = {}
op = ToBigQuery.from_kwargs(first_uri=data_path, dry_run=True, zarr=zarr, zarr_kwargs=zarr_kwargs,
output_table='foo.bar.baz', variables=variables, area=area,
xarray_open_dataset_kwargs=open_dataset_kwargs, import_time=import_time,
infer_schema=False, tif_metadata_for_datetime=tif_metadata_for_datetime,
skip_region_validation=True,
disable_grib_schema_normalization=disable_grib_schema_normalization,
coordinate_chunk_size=1000, skip_creating_polygon=skip_creating_polygon)
op = ToBigQuery.from_kwargs(
first_uri=data_path, dry_run=True, zarr=zarr, zarr_kwargs=zarr_kwargs,
output_table='foo.bar.baz', variables=variables, area=area,
xarray_open_dataset_kwargs=open_dataset_kwargs, import_time=import_time, infer_schema=False,
tif_metadata_for_datetime=tif_metadata_for_datetime, skip_region_validation=True,
disable_grib_schema_normalization=disable_grib_schema_normalization, coordinate_chunk_size=1000,
skip_creating_polygon=skip_creating_polygon)
coords = op.prepare_coordinates(data_path)
for uri, chunk in coords:
yield from op.extract_rows(uri, chunk)
@@ -737,5 +740,36 @@ def test_multiple_editions__with_vars__includes_coordinates_in_vars__with_schema
self.assertRowsEqual(actual, expected)


class ExtractRowsFromZarrTest(ExtractRowsTestBase):

def setUp(self) -> None:
super().setUp()
self.tmpdir = tempfile.TemporaryDirectory()

def tearDown(self) -> None:
super().tearDown()
self.tmpdir.cleanup()

def test_extracts_rows(self):
input_zarr = os.path.join(self.tmpdir.name, 'air_temp.zarr')

ds = (
xr.tutorial.open_dataset('air_temperature', cache_dir=self.test_data_folder)
.isel(time=slice(0, 4), lat=slice(0, 4), lon=slice(0, 4))
.rename(dict(lon='longitude', lat='latitude'))
)
ds.to_zarr(input_zarr)

op = ToBigQuery.from_kwargs(
first_uri=input_zarr, zarr_kwargs=dict(), dry_run=True, zarr=True, output_table='foo.bar.baz',
variables=list(), area=list(), xarray_open_dataset_kwargs=dict(), import_time=None, infer_schema=False,
tif_metadata_for_datetime=None, skip_region_validation=True, disable_grib_schema_normalization=False,
)

with TestPipeline() as p:
result = p | op
assert_that(result, is_not_empty())


if __name__ == '__main__':
unittest.main()
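Side note: `xr.tutorial.open_dataset('air_temperature', ...)` downloads a sample file, which is what pulls the new `pooch` test dependency into setup.py. A purely synthetic store of the same shape would exercise the path as well — a hypothetical alternative, not part of this change:

import numpy as np
import pandas as pd
import xarray as xr


def make_tiny_zarr(path: str) -> None:
    """Write a 4x4x4 air-temperature-like cube matching the slices used in the test above."""
    ds = xr.Dataset(
        {'air': (('time', 'latitude', 'longitude'), 280.0 + np.random.rand(4, 4, 4))},
        coords={
            'time': pd.date_range('2013-01-01', periods=4, freq='6H'),
            'latitude': np.linspace(75.0, 67.5, 4),
            'longitude': np.linspace(200.0, 207.5, 4),
        },
    )
    ds.to_zarr(path)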
8 changes: 4 additions & 4 deletions weather_mv/loader_pipeline/pipeline.py
@@ -27,7 +27,7 @@
from .streaming import GroupMessagesByFixedWindows, ParsePaths

logger = logging.getLogger(__name__)
SDK_CONTAINER_IMAGE='gcr.io/weather-tools-prod/weather-tools:0.0.0'
SDK_CONTAINER_IMAGE = 'gcr.io/weather-tools-prod/weather-tools:0.0.0'


def configure_logger(verbosity: int) -> None:
@@ -55,8 +55,9 @@ def pipeline(known_args: argparse.Namespace, pipeline_args: t.List[str]) -> None
known_args.first_uri = next(iter(all_uris))

with beam.Pipeline(argv=pipeline_args) as p:
if known_args.topic or known_args.subscription:

if known_args.zarr:
paths = p
elif known_args.topic or known_args.subscription:
paths = (
p
# Windowing is based on this code sample:
@@ -140,7 +141,6 @@ def run(argv: t.List[str]) -> t.Tuple[argparse.Namespace, t.List[str]]:
# Validate Zarr arguments
if known_args.uris.endswith('.zarr'):
known_args.zarr = True
known_args.zarr_kwargs['chunks'] = known_args.zarr_kwargs.get('chunks', None)

if known_args.zarr_kwargs and not known_args.zarr:
raise ValueError('`--zarr_kwargs` argument is only allowed with valid Zarr input URI.')
2 changes: 1 addition & 1 deletion weather_mv/loader_pipeline/regrid_test.py
@@ -122,7 +122,7 @@ def test_zarr__coarsen(self):
self.Op,
first_uri=input_zarr,
output_path=output_zarr,
zarr_input_chunks={"time": 5},
zarr_input_chunks={"time": 25},
zarr=True
)

4 changes: 2 additions & 2 deletions weather_mv/loader_pipeline/sinks.py
@@ -177,7 +177,7 @@ def _replace_dataarray_names_with_long_names(ds: xr.Dataset):
datetime_value_ms = None
try:
datetime_value_s = (int(end_time.timestamp()) if end_time is not None
else int(ds.attrs[tif_metadata_for_datetime]) / 1000.0)
else int(ds.attrs[tif_metadata_for_datetime]) / 1000.0)
ds = ds.assign_coords({'time': datetime.datetime.utcfromtimestamp(datetime_value_s)})
except KeyError:
raise RuntimeError(f"Invalid datetime metadata of tif: {tif_metadata_for_datetime}.")
@@ -380,7 +380,7 @@ def open_dataset(uri: str,
"""Open the dataset at 'uri' and return a xarray.Dataset."""
try:
if is_zarr:
ds: xr.Dataset = xr.open_dataset(uri, engine='zarr', **open_dataset_kwargs)
ds: xr.Dataset = _add_is_normalized_attr(xr.open_dataset(uri, engine='zarr', **open_dataset_kwargs), False)
beam.metrics.Metrics.counter('Success', 'ReadNetcdfData').inc()
yield ds
ds.close()
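The `_add_is_normalized_attr` helper isn't shown in this hunk; presumably it only records whether grib-style schema normalization was applied, so later stages can branch on it. A guess at its shape — the attribute name is assumed, not taken from this diff:

def _add_is_normalized_attr(ds: xr.Dataset, value: bool) -> xr.Dataset:
    # Attribute name assumed; see the full sinks.py for the actual implementation.
    ds.attrs['is_normalized'] = value
    return ds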
1 change: 1 addition & 0 deletions weather_mv/loader_pipeline/sinks_test.py
@@ -112,6 +112,7 @@ def test_opens_zarr(self):
with open_dataset(self.test_zarr_path, is_zarr=True, open_dataset_kwargs={}) as ds:
self.assertIsNotNone(ds)
self.assertEqual(list(ds.data_vars), ['cape', 'd2m'])

def test_open_dataset__fits_memory_bounds(self):
with write_netcdf() as test_netcdf_path:
with limit_memory(max_memory=30):