Update sparse data ingestion to handle huge datasets (#376)
* Creating CSVs using file I/O instead of a pandas DataFrame

* Updated meshgrid call

* Removed meshgrid, since it could run out of memory for multiple dimensions

* Added comments.

* Resolved a CI error

* Using the csv library to write the CSV + small changes (see the sketch below).
deepgabani8 authored Aug 11, 2023
1 parent b8bb7be commit 3464844
Showing 2 changed files with 41 additions and 15 deletions.
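
The bullets above describe replacing the pandas to_dataframe/meshgrid path with a plain csv.writer that streams rows straight out of the xarray arrays. A minimal sketch of that idea on a tiny synthetic dataset (the dataset, variable names and output path here are illustrative, not taken from the commit):

    import csv

    import numpy as np
    import xarray as xr

    # Hypothetical 2x3 dataset: dimensions 'lat' and 'lon', one data variable 't2m'.
    ds = xr.Dataset(
        {'t2m': (('lat', 'lon'), np.arange(6, dtype=float).reshape(2, 3))},
        coords={'lat': [10.0, 20.0], 'lon': [100.0, 110.0, 120.0]},
    )

    with open('example.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(list(ds.sizes) + list(ds.data_vars))  # header: lat, lon, t2m
        # Walk the flattened variable; dimension values are looked up per index
        # instead of being materialised with meshgrid or Dataset.to_dataframe().
        flat = ds['t2m'].data.flatten()
        n_lon = ds.sizes['lon']
        for i, value in enumerate(flat):
            writer.writerow([ds['lat'].data[i // n_lon], ds['lon'].data[i % n_lon], value])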
51 changes: 36 additions & 15 deletions weather_mv/loader_pipeline/ee.py
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import csv
import dataclasses
import json
import logging
import math
import os
import re
import shutil
@@ -27,7 +29,6 @@
import apache_beam as beam
import ee
import numpy as np
import xarray as xr
from apache_beam.io.filesystems import FileSystems
from apache_beam.io.gcp.gcsio import WRITE_CHUNK_SIZE
from apache_beam.options.pipeline_options import PipelineOptions
@@ -36,7 +37,7 @@
from google.auth.transport import requests
from rasterio.io import MemoryFile

from .sinks import ToDataSink, open_dataset, open_local, KwargsFactoryMixin
from .sinks import ToDataSink, open_dataset, open_local, KwargsFactoryMixin, upload
from .util import make_attrs_ee_compatible, RateLimit, validate_region, get_utc_timestamp

logger = logging.getLogger(__name__)
@@ -51,6 +52,7 @@
'IMAGE': '.tiff',
'TABLE': '.csv'
}
ROWS_PER_WRITE = 10_000 # Number of rows per feature collection write.


def is_compute_engine() -> bool:
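
ROWS_PER_WRITE caps how many rows are converted to Python lists per writerows call. The dimension columns are the expensive part: expanding them to one value per row (what meshgrid or Dataset.to_dataframe would do) scales with the product of the dimension sizes, whereas the new code keeps each dimension as its own small 1-D array and computes per-row values on the fly. A back-of-the-envelope sketch with illustrative sizes (not from the commit):

    import math

    # Hypothetical ERA5-like grid: time x latitude x longitude.
    dims_shape = [1000, 721, 1440]
    rows = math.prod(dims_shape)        # 1,038,240,000 flattened rows
    expanded = rows * 3 * 8 / 1e9       # 3 float64 dimension columns fully expanded: ~24.9 GB
    kept = sum(dims_shape) * 8 / 1e3    # the same dimensions kept as 1-D arrays: ~25 KB
    print(rows, expanded, kept)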
@@ -486,21 +488,40 @@ def convert_to_asset(self, queue: Queue, uri: str):
channel_names = []
file_name = f'{asset_name}.csv'

df = xr.Dataset.to_dataframe(ds)
df = df.reset_index()
# NULL and NaN create data-type mismatch issue in ee therefore replacing all of them.
# fillna fills in NaNs, NULLs, and NaTs but we have to exclude NaTs.
non_nat = df.select_dtypes(exclude=['datetime', 'timedelta', 'datetimetz'])
df[non_nat.columns] = non_nat.fillna(-9999)
shape = math.prod(list(ds.dims.values()))
# Names of dimensions, coordinates and data variables.
dims = list(ds.dims)
coords = [c for c in list(ds.coords) if c not in dims]
vars = list(ds.data_vars)
header = dims + coords + vars

# Copy in-memory dataframe to gcs.
# Data of dimensions, coordinates and data variables.
dims_data = [ds[dim].data for dim in dims]
coords_data = [np.full((shape,), ds[coord].data) for coord in coords]
vars_data = [ds[var].data.flatten() for var in vars]
data = coords_data + vars_data

dims_shape = [len(ds[dim].data) for dim in dims]

def get_dims_data(index: int) -> t.List[t.Any]:
"""Returns dimensions for the given flattened index."""
return [
dim[int(index / math.prod(dims_shape[i+1:])) % len(dim)] for (i, dim) in enumerate(dims_data)
]

# Copy CSV to gcs.
target_path = os.path.join(self.asset_location, file_name)
with tempfile.NamedTemporaryFile() as tmp_df:
df.to_csv(tmp_df.name, index=False)
tmp_df.flush()
tmp_df.seek(0)
with FileSystems().create(target_path) as dst:
shutil.copyfileobj(tmp_df, dst, WRITE_CHUNK_SIZE)
with tempfile.NamedTemporaryFile() as temp:
with open(temp.name, 'w', newline='') as f:
writer = csv.writer(f)
writer.writerows([header])
# Write rows in batches.
for i in range(0, shape, ROWS_PER_WRITE):
writer.writerows(
[get_dims_data(i + j) + list(row) for j, row in enumerate(zip(*[d[i:i + ROWS_PER_WRITE] for d in data]))]
)

upload(temp.name, target_path)

asset_data = AssetData(
name=asset_name,
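get_dims_data recovers each dimension's value for a given row of the flattened output: the stride of dimension i is the product of the trailing dimension sizes, which is the usual row-major (C-order) unravelling that numpy.unravel_index performs. A small self-contained check of that equivalence (array contents are illustrative):

    import math

    import numpy as np

    dims_data = [np.array([2000, 2001]), np.array([10.0, 20.0, 30.0])]  # e.g. time, lat
    dims_shape = [len(d) for d in dims_data]

    def get_dims_data(index):
        # Same arithmetic as in ee.py: stride of dim i = product of the trailing shapes.
        return [
            dim[int(index / math.prod(dims_shape[i + 1:])) % len(dim)]
            for (i, dim) in enumerate(dims_data)
        ]

    for flat in range(math.prod(dims_shape)):
        unravelled = np.unravel_index(flat, dims_shape)
        assert get_dims_data(flat) == [d[j] for d, j in zip(dims_data, unravelled)]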
5 changes: 5 additions & 0 deletions weather_mv/loader_pipeline/sinks.py
@@ -326,6 +326,11 @@ def __open_dataset_file(filename: str,
False)


def upload(src: str, dst: str) -> None:
"""Uploads a file to the specified GCS bucket destination."""
subprocess.run(f'gsutil -m cp {src} {dst}'.split(), check=True, capture_output=True, text=True, input="n/n")


def copy(src: str, dst: str) -> None:
"""Copy data via `gcloud alpha storage` or `gsutil`."""
errors: t.List[subprocess.CalledProcessError] = []
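upload delegates the copy to the gsutil CLI (which handles parallel, resumable uploads) instead of streaming chunks through FileSystems as the old code did; check=True turns a failed copy into a CalledProcessError, and the input="n/n" string appears to pre-answer any interactive gsutil prompt so the call cannot hang on stdin. A minimal usage sketch from inside the package (paths are illustrative):

    # As in ee.py above, import the helper from the sibling module.
    from .sinks import upload

    # Hypothetical local file and destination bucket; the caller needs write
    # access to the target GCS location.
    upload('/tmp/asset_1234.csv', 'gs://example-bucket/assets/asset_1234.csv')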
