diff --git a/weather_mv/loader_pipeline/ee.py b/weather_mv/loader_pipeline/ee.py
index bff24ad5..094b3a17 100644
--- a/weather_mv/loader_pipeline/ee.py
+++ b/weather_mv/loader_pipeline/ee.py
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import csv
 import dataclasses
 import json
 import logging
+import math
 import os
 import re
 import shutil
@@ -27,7 +29,6 @@
 import apache_beam as beam
 import ee
 import numpy as np
-import xarray as xr
 from apache_beam.io.filesystems import FileSystems
 from apache_beam.io.gcp.gcsio import WRITE_CHUNK_SIZE
 from apache_beam.options.pipeline_options import PipelineOptions
@@ -36,7 +37,7 @@
 from google.auth.transport import requests
 from rasterio.io import MemoryFile
 
-from .sinks import ToDataSink, open_dataset, open_local, KwargsFactoryMixin
+from .sinks import ToDataSink, open_dataset, open_local, KwargsFactoryMixin, upload
 from .util import make_attrs_ee_compatible, RateLimit, validate_region, get_utc_timestamp
 
 logger = logging.getLogger(__name__)
@@ -51,6 +52,7 @@
     'IMAGE': '.tiff',
     'TABLE': '.csv'
 }
+ROWS_PER_WRITE = 10_000  # Number of rows per feature collection write.
 
 
 def is_compute_engine() -> bool:
@@ -486,21 +488,40 @@ def convert_to_asset(self, queue: Queue, uri: str):
             channel_names = []
             file_name = f'{asset_name}.csv'
 
-            df = xr.Dataset.to_dataframe(ds)
-            df = df.reset_index()
-            # NULL and NaN create data-type mismatch issue in ee therefore replacing all of them.
-            # fillna fills in NaNs, NULLs, and NaTs but we have to exclude NaTs.
-            non_nat = df.select_dtypes(exclude=['datetime', 'timedelta', 'datetimetz'])
-            df[non_nat.columns] = non_nat.fillna(-9999)
+            shape = math.prod(list(ds.dims.values()))
+            # Names of dimensions, coordinates and data variables.
+            dims = list(ds.dims)
+            coords = [c for c in list(ds.coords) if c not in dims]
+            vars = list(ds.data_vars)
+            header = dims + coords + vars
 
-            # Copy in-memory dataframe to gcs.
+            # Data of dimensions, coordinates and data variables.
+            dims_data = [ds[dim].data for dim in dims]
+            coords_data = [np.full((shape,), ds[coord].data) for coord in coords]
+            vars_data = [ds[var].data.flatten() for var in vars]
+            data = coords_data + vars_data
+
+            dims_shape = [len(ds[dim].data) for dim in dims]
+
+            def get_dims_data(index: int) -> t.List[t.Any]:
+                """Returns the dimension values for the given flattened index."""
+                return [
+                    dim[int(index / math.prod(dims_shape[i + 1:])) % len(dim)] for (i, dim) in enumerate(dims_data)
+                ]
+
+            # Copy CSV to gcs.
             target_path = os.path.join(self.asset_location, file_name)
-            with tempfile.NamedTemporaryFile() as tmp_df:
-                df.to_csv(tmp_df.name, index=False)
-                tmp_df.flush()
-                tmp_df.seek(0)
-                with FileSystems().create(target_path) as dst:
-                    shutil.copyfileobj(tmp_df, dst, WRITE_CHUNK_SIZE)
+            with tempfile.NamedTemporaryFile() as temp:
+                with open(temp.name, 'w', newline='') as f:
+                    writer = csv.writer(f)
+                    writer.writerows([header])
+                    # Write rows in batches, offsetting each row's flat index within the batch.
+                    for i in range(0, shape, ROWS_PER_WRITE):
+                        writer.writerows(
+                            [get_dims_data(i + j) + list(row) for j, row in enumerate(zip(*[d[i:i + ROWS_PER_WRITE] for d in data]))]
+                        )
+
+                upload(temp.name, target_path)
 
             asset_data = AssetData(
                 name=asset_name,
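
The core of the ee.py change: instead of materializing the whole dataset through xr.Dataset.to_dataframe, the table is streamed straight out of the xarray arrays. Each flattened row index is mapped back to its per-dimension coordinate values with the same stride arithmetic as np.unravel_index, and rows are written in ROWS_PER_WRITE batches, so peak memory scales with the batch size rather than the dataset. Below is a minimal standalone sketch of that path, with made-up dimension sizes and a single stand-in data variable; none of these names or values come from the pipeline itself.

import csv
import io
import math

import numpy as np

ROWS_PER_WRITE = 4                              # tiny batch size for the demo
dims_shape = [2, 3]                             # hypothetical (time, lat) sizes
dims_data = [np.arange(n) for n in dims_shape]  # stand-in coordinate arrays
shape = math.prod(dims_shape)                   # 6 rows in the flattened table
data = [np.arange(shape) * 1.5]                 # one stand-in data variable

def get_dims_data(index: int) -> list:
    """Returns the per-dimension coordinate values for a flattened row index."""
    return [
        dim[int(index / math.prod(dims_shape[i + 1:])) % len(dim)]
        for i, dim in enumerate(dims_data)
    ]

# The stride arithmetic agrees with np.unravel_index for every flat index.
for flat in range(shape):
    expected = [d[k] for d, k in zip(dims_data, np.unravel_index(flat, dims_shape))]
    assert get_dims_data(flat) == expected

# Batched write: slice each column, then zip the slices into rows.
buf = io.StringIO()
writer = csv.writer(buf)
writer.writerows([['time', 'lat', 'value']])
for i in range(0, shape, ROWS_PER_WRITE):
    writer.writerows(
        [get_dims_data(i + j) + list(row)
         for j, row in enumerate(zip(*[d[i:i + ROWS_PER_WRITE] for d in data]))]
    )
print(buf.getvalue())
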
diff --git a/weather_mv/loader_pipeline/sinks.py b/weather_mv/loader_pipeline/sinks.py
index 0f8cd561..22569bbb 100644
--- a/weather_mv/loader_pipeline/sinks.py
+++ b/weather_mv/loader_pipeline/sinks.py
@@ -326,6 +326,11 @@ def __open_dataset_file(filename: str,
                         False)
 
 
+def upload(src: str, dst: str) -> None:
+    """Uploads a file to the specified GCS bucket destination."""
+    subprocess.run(f'gsutil -m cp {src} {dst}'.split(), check=True, capture_output=True, text=True, input="n\n")
+
+
 def copy(src: str, dst: str) -> None:
     """Copy data via `gcloud alpha storage` or `gsutil`."""
     errors: t.List[subprocess.CalledProcessError] = []
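
The new upload helper shells out to gsutil rather than streaming through Beam's FileSystems. With check=True and capture_output=True set, a failure surfaces as subprocess.CalledProcessError with stderr attached, and the stdin input presumably declines any interactive prompt so a worker cannot hang. A hypothetical call site follows; the paths and bucket are invented, gsutil is assumed to be installed and authenticated, and the import assumes weather_mv.loader_pipeline is importable as a package.

import subprocess

from weather_mv.loader_pipeline.sinks import upload

try:
    upload('/tmp/forecast.csv', 'gs://example-bucket/assets/forecast.csv')
except subprocess.CalledProcessError as e:
    # stderr is available because upload() runs gsutil with capture_output=True.
    print(f'gsutil failed with exit code {e.returncode}: {e.stderr}')
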