diff --git a/weather_mv/README.md b/weather_mv/README.md
index e99f5b4c..f77ed41b 100644
--- a/weather_mv/README.md
+++ b/weather_mv/README.md
@@ -398,6 +398,8 @@ _Command options_:
 * `--initialization_time_regex`: A Regex string to get the initialization time from the filename.
 * `--forecast_time_regex`: A Regex string to get the forecast/end time from the filename.
 * `--group_common_hypercubes`: A flag that allows to split up large grib files into multiple level-wise ImageCollections / COGS.
+* `--tiff_config`: Configuration for source data with more than two dimensions, given as a JSON string with a single key, `dims` (an array of dimension names). A separate tiff file is created for every combination of values of the dimensions listed in `dims`. By default, multiple tiff files are generated whenever the source data has multiple `time` or `step` values. For grib files containing multiple datasets, not every dataset necessarily has every dimension; the variables of a dataset that does not contain a dimension listed in `dims` are stored as bands in each generated tiff file.
+  For example: `{"dims": ["isobaricInhPa"]}`.
 
 Invoke with `ee -h` or `earthengine --help` to see the full range of options.
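An illustrative aside, not part of the patch: the pipeline decodes `--tiff_config` with `type=json.loads` (see the `ee.py` change below), so the flag value must be valid JSON. A minimal, standalone argparse sketch of that wiring; the parser object here is hypothetical:

```python
import argparse
import json

parser = argparse.ArgumentParser()
# Mirrors the ee.py flag added below: the value is decoded straight into a dict.
parser.add_argument('--tiff_config', type=json.loads, default={"dims": []},
                    help='Config to split generated assets by the given dimensions.')

args = parser.parse_args(['--tiff_config', '{"dims": ["isobaricInhPa"]}'])
assert args.tiff_config == {"dims": ["isobaricInhPa"]}
```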
diff --git a/weather_mv/loader_pipeline/bq.py b/weather_mv/loader_pipeline/bq.py
index bc71da01..9deea6d3 100644
--- a/weather_mv/loader_pipeline/bq.py
+++ b/weather_mv/loader_pipeline/bq.py
@@ -150,6 +150,9 @@ def validate_arguments(cls, known_args: argparse.Namespace, pipeline_args: t.Lis
         if known_args.area:
             assert len(known_args.area) == 4, 'Must specify exactly 4 lat/long values for area: N, W, S, E boundaries.'
 
+        # Add a check for tiff_config.
+        if pipeline_options_dict.get('tiff_config'):
+            raise RuntimeError('--tiff_config can be specified only for earth engine ingestions.')
         # Add a check for group_common_hypercubes.
         if pipeline_options_dict.get('group_common_hypercubes'):
             raise RuntimeError('--group_common_hypercubes can be specified only for earth engine ingestions.')
diff --git a/weather_mv/loader_pipeline/ee.py b/weather_mv/loader_pipeline/ee.py
index 2ad914dd..2a14e20e 100644
--- a/weather_mv/loader_pipeline/ee.py
+++ b/weather_mv/loader_pipeline/ee.py
@@ -243,6 +243,7 @@ class ToEarthEngine(ToDataSink):
     band_names_mapping: str
     initialization_time_regex: str
     forecast_time_regex: str
+    tiff_config: t.Dict
 
     @classmethod
     def add_parser_arguments(cls, subparser: argparse.ArgumentParser):
@@ -283,6 +284,8 @@ def add_parser_arguments(cls, subparser: argparse.ArgumentParser):
                                 help='A Regex string to get the initialization time from the filename.')
         subparser.add_argument('--forecast_time_regex', type=str, default=None,
                                 help='A Regex string to get the forecast/end time from the filename.')
+        subparser.add_argument('--tiff_config', type=json.loads, default={"dims": []},
+                                help='Config to split generated assets by the given dimensions.')
 
     @classmethod
     def validate_arguments(cls, known_args: argparse.Namespace, pipeline_args: t.List[str]) -> None:
@@ -398,8 +401,10 @@ def process(self, uri: str) -> t.Iterator[str]:
 
         # Checks if the asset is already present in the GCS bucket or not.
         target_path = os.path.join(
-            self.asset_location, f'{asset_name}{ASSET_TYPE_TO_EXTENSION_MAPPING[self.ee_asset_type]}')
-        if not self.force_overwrite and FileSystems.exists(target_path):
+            self.asset_location, f'{asset_name}*{ASSET_TYPE_TO_EXTENSION_MAPPING[self.ee_asset_type]}')
+        files = FileSystems.match([target_path])
+
+        if not self.force_overwrite and files[0].metadata_list:
             logger.info(f'Asset file {target_path} already exists in GCS bucket. Skipping...')
             return
@@ -430,6 +435,7 @@ class ConvertToAsset(beam.DoFn, beam.PTransform, KwargsFactoryMixin):
     band_names_dict: t.Optional[t.Dict] = None
     initialization_time_regex: t.Optional[str] = None
     forecast_time_regex: t.Optional[str] = None
+    tiff_config: t.Optional[t.Dict] = None
 
     def add_to_queue(self, queue: Queue, item: t.Any):
         """Adds a new item to the queue.
@@ -451,7 +457,8 @@ def convert_to_asset(self, queue: Queue, uri: str):
                           band_names_dict=self.band_names_dict,
                           initialization_time_regex=self.initialization_time_regex,
                           forecast_time_regex=self.forecast_time_regex,
-                          group_common_hypercubes=self.group_common_hypercubes) as ds_list:
+                          group_common_hypercubes=self.group_common_hypercubes,
+                          tiff_config=self.tiff_config) as ds_list:
             if not isinstance(ds_list, list):
                 ds_list = [ds_list]
@@ -460,8 +467,9 @@ def convert_to_asset(self, queue: Queue, uri: str):
                 data = list(ds.values())
                 asset_name = get_ee_safe_name(uri)
                 channel_names = [da.name for da in data]
-                start_time, end_time, is_normalized = (attrs.get(key) for key in
-                                                       ('start_time', 'end_time', 'is_normalized'))
+                start_time, end_time, is_normalized, forecast_hour = (attrs.get(key) for key in
+                                                                      ('start_time', 'end_time', 'is_normalized', 'forecast_hour'))
+                dtype, crs, transform = (attrs.pop(key) for key in ['dtype', 'crs', 'transform'])
                 attrs.update({'is_normalized': str(is_normalized)})  # EE properties does not support bool.
                 # Adding job_start_time to properites.
@@ -469,25 +477,38 @@ def convert_to_asset(self, queue: Queue, uri: str):
                 # Make attrs EE ingestable.
                 attrs = make_attrs_ee_compatible(attrs)
 
+                if start_time:
+                    st = re.sub("[^0-9]", "", start_time)
+                    asset_name = f"{asset_name}_{st}"
+                if forecast_hour:
+                    asset_name = f"{asset_name}_FH-{forecast_hour}"
+                if self.tiff_config:
+                    for var in set(self.tiff_config["dims"]).difference(['time', 'step']):
+                        var_val = ds.get(var)
+                        if var_val is not None:
+                            if var_val >= 10:
+                                asset_name = f"{asset_name}_{var}_{var_val.values:.0f}"
+                            else:
+                                asset_name = f"{asset_name}_{var}_{var_val.values:.2f}".replace('.', '_')
                 if self.group_common_hypercubes:
                     level, height = (attrs.pop(key) for key in ['level', 'height'])
                     safe_level_name = get_ee_safe_name(level)
                     asset_name = f'{asset_name}_{safe_level_name}'
+                asset_name = get_ee_safe_name(asset_name)
 
                 # For tiff ingestions.
                 if self.ee_asset_type == 'IMAGE':
                     file_name = f'{asset_name}.tiff'
-
                     with MemoryFile() as memfile:
                         with memfile.open(driver='COG',
-                                          dtype=dtype,
-                                          width=data[0].data.shape[1],
-                                          height=data[0].data.shape[0],
-                                          count=len(data),
-                                          nodata=np.nan,
-                                          crs=crs,
-                                          transform=transform,
-                                          compress='lzw') as f:
+                                          dtype=dtype,
+                                          width=data[0].data.shape[1],
+                                          height=data[0].data.shape[0],
+                                          count=len(data),
+                                          nodata=np.nan,
+                                          crs=crs,
+                                          transform=transform,
+                                          compress='lzw') as f:
                             for i, da in enumerate(data):
                                 f.write(da, i+1)
                                 # Making the channel name EE-safe before adding it as a band name.
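Aside (illustrative, not part of the patch): the block above assembles asset names from the URI-derived base plus a start-time digit string, an optional `FH-<forecast hour>` suffix, and one `<dim>_<value>` suffix per extra dimension in `tiff_config["dims"]`; values below 10 keep two decimals with `.` replaced by `_`. Because the stored file names now carry these suffixes, the existence check earlier in `process` matches `{asset_name}*{extension}` with `FileSystems.match` instead of testing one exact path. A standalone sketch of the naming rule; the helper below is hypothetical, not pipeline code:

```python
import re


def build_asset_name(base: str, start_time: str, forecast_hour: int, dim_values: dict) -> str:
    """Hypothetical helper mirroring the suffixing logic in ConvertToAsset above."""
    name = f"{base}_{re.sub('[^0-9]', '', start_time)}"  # keep only the digits of the timestamp
    if forecast_hour:
        name = f"{name}_FH-{forecast_hour}"
    for dim, value in dim_values.items():
        if value >= 10:
            name = f"{name}_{dim}_{value:.0f}"  # e.g. isobaricInhPa 850 -> _isobaricInhPa_850
        else:
            name = f"{name}_{dim}_{value:.2f}".replace('.', '_')  # e.g. 0.1 -> _depthBelowLandLayer_0_10
    return name


print(build_asset_name('test_data', '2023-06-14T00:00:00Z', 6, {'depthBelowLandLayer': 0.1}))
# -> test_data_20230614000000_FH-6_depthBelowLandLayer_0_10
```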
@@ -539,11 +560,10 @@ def get_dims_data(index: int) -> t.List[t.Any]:
                         writer.writerows(
                             [get_dims_data(i) + list(row) for row in zip(
                                 *[d[i:i + ROWS_PER_WRITE] for d in data]
-                            )]
+                             )]
                         )
 
             upload(temp.name, target_path)
-
             asset_data = AssetData(
                 name=asset_name,
                 target_path=target_path,
diff --git a/weather_mv/loader_pipeline/ee_test.py b/weather_mv/loader_pipeline/ee_test.py
index 7627152e..6f64c5f4 100644
--- a/weather_mv/loader_pipeline/ee_test.py
+++ b/weather_mv/loader_pipeline/ee_test.py
@@ -65,24 +65,98 @@ class ConvertToAssetTests(TestDataBase):
     def setUp(self) -> None:
         super().setUp()
         self.tmpdir = tempfile.TemporaryDirectory()
-        self.convert_to_image_asset = ConvertToAsset(asset_location=self.tmpdir.name)
-        self.convert_to_table_asset = ConvertToAsset(asset_location=self.tmpdir.name, ee_asset_type='TABLE')
+        self.convert_to_image_asset = ConvertToAsset(
+            asset_location=self.tmpdir.name,
+            tiff_config={"dims": []},
+        )
+        self.convert_to_table_asset = ConvertToAsset(
+            asset_location=self.tmpdir.name,
+            ee_asset_type='TABLE'
+        )
 
     def tearDown(self):
         self.tmpdir.cleanup()
 
     def test_convert_to_image_asset(self):
         data_path = f'{self.test_data_folder}/test_data_grib_single_timestep'
-        asset_path = os.path.join(self.tmpdir.name, 'test_data_grib_single_timestep.tiff')
+        asset_path = os.path.join(self.tmpdir.name, 'test_data_grib_single_timestep_20211018060000.tiff')
 
         next(self.convert_to_image_asset.process(data_path))
 
         # The size of tiff is expected to be more than grib.
         self.assertTrue(os.path.getsize(asset_path) > os.path.getsize(data_path))
 
+    def test_convert_to_multiple_image_assets(self):
+        data_path = f'{self.test_data_folder}/test_data_20180101.nc'
+        # A file with multiple time values generates a separate tiff file per timestamp.
+        time_arr = ["20180102060000", "20180102070000", "20180102080000", "20180102090000",
+                    "20180102100000", "20180102110000", "20180102120000", "20180102130000",
+                    "20180102140000", "20180102150000", "20180102160000", "20180102170000",
+                    "20180102180000", "20180102190000", "20180102200000", "20180102210000",
+                    "20180102220000", "20180102230000", "20180103000000", "20180103010000",
+                    "20180103020000", "20180103030000", "20180103040000", "20180103050000",
+                    "20180103060000"]
+        it = self.convert_to_image_asset.process(data_path)
+        total_assets_size = 0
+        for time in time_arr:
+            next(it)
+            asset_path = os.path.join(self.tmpdir.name, f'test_data_20180101_{time}.tiff')
+            self.assertTrue(os.path.lexists(asset_path))
+            total_assets_size += os.path.getsize(asset_path)
+
+        # The combined size of all tiffs is expected to be more than the source file.
+        self.assertTrue(total_assets_size > os.path.getsize(data_path))
+
+    def test_convert_to_multiple_image_assets_with_grib_multiple_edition_default_behaviour(self):
+        # Default behaviour: when the user does not provide any dimension in tiff_config, a grib
+        # file with multiple time values still generates a separate tiff file per timestamp.
+        data_path = f'{self.test_data_folder}/test_data_grib_multiple_edition_multiple_timestep.grib2'
+        time_arr = ["20230614000000", "20230614060000", "20230614120000"]
+        it = self.convert_to_image_asset.process(data_path)
+        total_assets_size = 0
+        for time in time_arr:
+            next(it)
+            asset_path = os.path.join(self.tmpdir.name,
+                                      f'test_data_grib_multiple_edition_multiple_timestep_{time}_FH-6.tiff')
+            self.assertTrue(os.path.lexists(asset_path))
+            total_assets_size += os.path.getsize(asset_path)
+
+        self.assertTrue(total_assets_size > os.path.getsize(data_path))
+
+    def test_convert_to_multiple_image_assets_with_grib_multiple_edition(self):
+        data_path = f'{self.test_data_folder}/test_data_grib_multiple_edition_multiple_timestep.grib2'
+        expected = [
+            {"time": "20230614000000", "depthBelowLandLayer": "0_00"},
+            {"time": "20230614000000", "depthBelowLandLayer": "0_10"},
+            {"time": "20230614060000", "depthBelowLandLayer": "0_00"},
+            {"time": "20230614060000", "depthBelowLandLayer": "0_10"},
+            {"time": "20230614120000", "depthBelowLandLayer": "0_00"},
+            {"time": "20230614120000", "depthBelowLandLayer": "0_10"}
+        ]
+        convert_to_image_asset = ConvertToAsset(
+            asset_location=self.tmpdir.name,
+            tiff_config={"dims": ['depthBelowLandLayer']},
+        )
+        it = convert_to_image_asset.process(data_path)
+        total_assets_size = 0
+        for obj in expected:
+            next(it)
+            time, dbl = obj.values()
+            asset_name = f'test_data_grib_multiple_edition_multiple_timestep_{time}_FH-6_depthBelowLandLayer_{dbl}.tiff'
+            asset_path = os.path.join(self.tmpdir.name, asset_name)
+            self.assertTrue(os.path.lexists(asset_path))
+            total_assets_size += os.path.getsize(asset_path)
+
+        self.assertTrue(total_assets_size > os.path.getsize(data_path))
+
     def test_convert_to_image_asset__with_multiple_grib_edition(self):
         data_path = f'{self.test_data_folder}/test_data_grib_multiple_edition_single_timestep.bz2'
-        asset_path = os.path.join(self.tmpdir.name, 'test_data_grib_multiple_edition_single_timestep.tiff')
+        asset_path = os.path.join(
+            self.tmpdir.name,
+            'test_data_grib_multiple_edition_single_timestep_20211210120000_FH-8.tiff',
+        )
 
         next(self.convert_to_image_asset.process(data_path))
 
@@ -91,7 +165,7 @@ def test_convert_to_image_asset__with_multiple_grib_edition(self):
 
     def test_convert_to_table_asset(self):
         data_path = f'{self.test_data_folder}/test_data_grib_single_timestep'
-        asset_path = os.path.join(self.tmpdir.name, 'test_data_grib_single_timestep.csv')
+        asset_path = os.path.join(self.tmpdir.name, 'test_data_grib_single_timestep_20211018060000.csv')
 
         next(self.convert_to_table_asset.process(data_path))
 
@@ -100,7 +174,10 @@ def test_convert_to_table_asset(self):
 
     def test_convert_to_table_asset__with_multiple_grib_edition(self):
         data_path = f'{self.test_data_folder}/test_data_grib_multiple_edition_single_timestep.bz2'
-        asset_path = os.path.join(self.tmpdir.name, 'test_data_grib_multiple_edition_single_timestep.csv')
+        asset_path = os.path.join(
+            self.tmpdir.name,
+            'test_data_grib_multiple_edition_single_timestep_20211210120000_FH-8.csv',
+        )
 
         next(self.convert_to_table_asset.process(data_path))
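Aside (illustrative, not part of the patch): the 14-digit timestamps hard-coded in the tests above are just the asset-name form of the dataset's time coordinate; the pipeline renders each time as a UTC string and the naming code keeps only its digits. A small sketch of that derivation:

```python
import datetime
import re

import numpy as np

np_time = np.datetime64('2018-01-02T06:00:00')
timestamp = np_time.astype('datetime64[s]').astype(int)  # seconds since the Unix epoch
utc_string = datetime.datetime.utcfromtimestamp(timestamp).strftime('%Y-%m-%dT%H:%M:%SZ')
print(re.sub('[^0-9]', '', utc_string))  # -> 20180102060000
```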
diff --git a/weather_mv/loader_pipeline/sinks.py b/weather_mv/loader_pipeline/sinks.py
index fac3e3b9..4ff752ba 100644
--- a/weather_mv/loader_pipeline/sinks.py
+++ b/weather_mv/loader_pipeline/sinks.py
@@ -18,7 +18,9 @@
 import dataclasses
 import datetime
 import inspect
+import itertools
 import logging
+from operator import itemgetter
 import os
 import re
 import shutil
@@ -34,6 +36,7 @@
 import xarray as xr
 from apache_beam.io.filesystem import CompressionTypes, FileSystem, CompressedFile, DEFAULT_READ_BUFFER_SIZE
 from pyproj import Transformer
+from collections import defaultdict
 
 TIF_TRANSFORM_CRS_TO = "EPSG:4326"
 # A constant for all the things in the coords key set that aren't the level name.
@@ -220,14 +223,18 @@ def _to_utc_timestring(np_time: np.datetime64) -> str:
     return datetime.datetime.utcfromtimestamp(timestamp).strftime('%Y-%m-%dT%H:%M:%SZ')
 
 
-def _add_is_normalized_attr(ds: xr.Dataset, value: bool) -> xr.Dataset:
+def _add_is_normalized_attr(ds_list: t.Union[xr.Dataset, t.List[xr.Dataset]],
+                            value: bool) -> t.Union[xr.Dataset, t.List[xr.Dataset]]:
     """Adds is_normalized to the attrs of the xarray.Dataset.
 
     This attribute represents if the dataset is the merged dataset (i.e. created by combining N datasets,
     specifically for normalizing grib's schema) or not.
     """
-    ds.attrs['is_normalized'] = value
-    return ds
+    if isinstance(ds_list, list):
+        for dataset in ds_list:
+            dataset.attrs['is_normalized'] = value
+    else:
+        ds_list.attrs['is_normalized'] = value
+    return ds_list
 
 
 def _is_3d_da(da):
@@ -329,18 +336,265 @@ def __normalize_grib_dataset(filename: str,
     return _data_array_list
 
 
+def create_partition_configs(ds: xr.Dataset, tiff_config_dims: t.List[str]):
+    """Produces index-based partition configs and groups them according to `tiff_config_dims`.
+
+    Returns the list of groups, the (possibly expanded) dataset, and the set of dimensions to flatten.
+    For example, if the dimensions are 'step' and 'isobaricInhPa', it produces a list like:
+
+    Case 1 -- default (`tiff_config_dims` is []): grouping is done only on the values of the
+    'time' and 'step' dimensions.
+    [
+        [
+            {'step': 0, 'isobaricInhPa': 0},
+            {'step': 0, 'isobaricInhPa': 1}
+        ],
+        [
+            {'step': 1, 'isobaricInhPa': 0},
+            {'step': 1, 'isobaricInhPa': 1}
+        ]
+    ]
+
+    Case 2 -- `tiff_config_dims` is ['isobaricInhPa']: along with 'time' and 'step',
+    'isobaricInhPa' is also considered for grouping.
+    [
+        [{'step': 0, 'isobaricInhPa': 0}],
+        [{'step': 0, 'isobaricInhPa': 1}],
+        [{'step': 1, 'isobaricInhPa': 0}],
+        [{'step': 1, 'isobaricInhPa': 1}]
+    ]
+    """
+    default_tiff_dims = set([x for x in ['time', 'step'] if x in ds.keys()])
+
+    # 'time' and 'step' are the default dimensions used to split the dataset.
+    # If they are present but not indexed as dimensions, expand them into dimensions,
+    # because the dataset is split by index.
+    for dim in default_tiff_dims:
+        if not ds.dims.get(dim):
+            ds = ds.expand_dims(dim)
+
+    # Along with 'time' and 'step', add the user-preferred dimensions.
+    default_tiff_dims = default_tiff_dims.union(tiff_config_dims)
+    coords_set = set(ds.coords.keys())
+    flat_dims = set(coords_set).difference(DEFAULT_COORD_KEYS)
+    flat_dims.difference_update(default_tiff_dims)  # The dimensions/coords that need to be flattened.
+
+    # Store the dimensions of the dataset excluding latitude and longitude;
+    # these dimensions are used to create the groups of partitions.
+    latitude_dims = ds.coords['latitude'].dims
+    longitude_dims = ds.coords['longitude'].dims
+    dims = [dim for dim in ds.dims if dim not in latitude_dims + longitude_dims]
+    _dims = [range(len(ds[x])) for x in dims]
+    default_tiff_dims = dims if not default_tiff_dims else default_tiff_dims
+
+    iter = list(itertools.product(*_dims))  # Produces a Cartesian cross-product over the ranges of the dimensions.
+    configs = []
+    for i in iter:
+        _config = {}
+        for idx, val in enumerate(i):  # Indexes are converted back to keys, producing one dict per combination.
+            _config[dims[idx]] = val
+        configs.append(_config)
+    groups = []
+    for k, g in itertools.groupby(configs, key=itemgetter(*default_tiff_dims)):  # Group the configs.
+        groups.append(list(g))
+    return groups, ds, flat_dims
+
+
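Aside (illustrative, not part of the patch): the grouping above can be pictured without xarray. The sketch below builds every index combination for two toy dimensions with `itertools.product`, then groups them with `itertools.groupby`, reproducing "case 1" from the docstring; switching `partition_dims` to include `'isobaricInhPa'` reproduces "case 2". All names and sizes here are toy values.

```python
import itertools
from operator import itemgetter

dims = {'step': 2, 'isobaricInhPa': 2}        # dimension name -> number of index values
partition_dims = ['step']                     # case 1: only the default time/step-style grouping
# partition_dims = ['step', 'isobaricInhPa']  # case 2: the user adds 'isobaricInhPa' via --tiff_config

# One dict per combination of dimension indexes, e.g. {'step': 0, 'isobaricInhPa': 1}.
configs = [dict(zip(dims, combo))
           for combo in itertools.product(*(range(n) for n in dims.values()))]

# Adjacent configs with the same partition-dimension values end up in the same group.
groups = [list(g) for _, g in itertools.groupby(configs, key=itemgetter(*partition_dims))]
print(groups)
# case 1 -> [[{'step': 0, 'isobaricInhPa': 0}, {'step': 0, 'isobaricInhPa': 1}],
#            [{'step': 1, 'isobaricInhPa': 0}, {'step': 1, 'isobaricInhPa': 1}]]
```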
+def create_partition(ds: xr.Dataset, group: t.List[dict], flat_dims: t.Set[str], coords_set: t.Set[str]):
+    da_units = {}
+    da_list = []
+    for conf in group:
+        dataset = ds.isel(conf)
+        for var in dataset.data_vars.keys():
+            da = dataset[var]
+            attrs = da.attrs
+            forecast_hour = int(da.step.values / np.timedelta64(1, 'h')) if 'step' in coords_set else None
+
+            # We are going to treat the time field as start_time and the
+            # valid_time field as the end_time for EE purposes. Also, get the
+            # times into UTC timestrings.
+            start_time = _to_utc_timestring(da.time.values)
+            end_time = _to_utc_timestring(da.valid_time.values) if 'valid_time' in coords_set else start_time
+
+            attrs['forecast_hour'] = forecast_hour  # Stick the forecast hour in the metadata as well, that's useful.
+            attrs['start_time'] = start_time
+            attrs['end_time'] = end_time
+
+            channel_name = ''
+            for flat_dim in flat_dims:
+                if da[flat_dim].values >= 10:
+                    channel_name += f'{flat_dim}_{da[flat_dim].values:.0f}_'
+                else:
+                    channel_name += f'{flat_dim}_{da[flat_dim].values:.2f}_'.replace('.', '_')
+                da = da.drop(flat_dim)
+            if flat_dims and attrs.get("GRIB_stepType"):
+                channel_name += f'{attrs.get("GRIB_stepType")}_{var}'
+            else:
+                channel_name += var
+            channel_name = re.sub(r'[^a-zA-Z0-9_]+', r'_', channel_name)
+            da_units['unit_' + channel_name] = None
+            if 'units' in attrs:
+                da_units['unit_' + channel_name] = attrs['units']
+
+            da.name = channel_name
+            da_list.append(da)
+
+    return (da_list, da_units)
+
+
+def validate_tiff_config(ds_dims: t.List, tiff_config: t.Dict):
+    if tiff_config['dims']:
+        for dim in tiff_config['dims']:
+            if dim not in ds_dims:
+                raise RuntimeError("Please provide valid dimensions for '--tiff_config'.")
+
+
+def __partition_dataset(ds: xr.Dataset,
+                        tiff_config: t.Optional[t.Dict] = None) -> t.Union[xr.Dataset, t.List[xr.Dataset]]:
+    """Partitions the dataset according to the user-provided tiff_config.
+
+    By default, if a 'time' or 'step' dimension is present in the source data, the dataset is
+    partitioned along it; any dimensions the user lists in tiff_config['dims'] are also used
+    for partitioning.
+    """
+    coords_set = set(ds.coords.keys())
+    _merged_dataset_list = []
+
+    if tiff_config:
+        validate_tiff_config(ds.dims.keys(), tiff_config)
+
+        groups, ds, flat_dims = create_partition_configs(ds, tiff_config['dims'])
+        for group in groups:
+            ds_attrs = ds.attrs
+            _data_array_list, dv_units_dict = create_partition(ds, group, flat_dims, coords_set)
+
+            # Stick the forecast hour, start_time, end_time and the data variables' units
+            # in the ds attrs as well, that's useful.
+            ds_attrs['forecast_hour'] = _data_array_list[0].attrs['forecast_hour']
+            ds_attrs['start_time'] = _data_array_list[0].attrs['start_time']
+            ds_attrs['end_time'] = _data_array_list[0].attrs['end_time']
+            ds_attrs.update(**dv_units_dict)
+
+            merged_dataset = xr.merge(_data_array_list)
+            merged_dataset.attrs.clear()
+            merged_dataset.attrs.update(ds_attrs)
+            _merged_dataset_list.append(merged_dataset)
+
+        return _merged_dataset_list
+    return ds
+
+
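Aside (illustrative, not part of the patch): the sketch below shows, on a toy xarray dataset, what the partitioning above produces -- one small dataset per `isel` selection, each still two-dimensional in latitude/longitude and therefore writable as a single tiff. The variable names, sizes, and values are made up.

```python
import itertools

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"temperature": (("time", "isobaricInhPa", "latitude", "longitude"),
                     np.random.rand(2, 3, 4, 5))},
    coords={
        "time": np.array(["2023-06-14T00", "2023-06-14T06"], dtype="datetime64[ns]"),
        "isobaricInhPa": [1000.0, 850.0, 500.0],
        "latitude": np.linspace(90, -90, 4),
        "longitude": np.linspace(0, 360, 5, endpoint=False),
    },
)

# Partition on 'time' plus a user-requested dimension, as --tiff_config '{"dims": ["isobaricInhPa"]}'
# would: one output dataset per (time, isobaricInhPa) index combination.
partition_dims = ["time", "isobaricInhPa"]
index_ranges = [range(ds.sizes[d]) for d in partition_dims]

partitions = [ds.isel(dict(zip(partition_dims, combo)))
              for combo in itertools.product(*index_ranges)]
print(len(partitions))  # 2 * 3 = 6 partitions, each still gridded over latitude/longitude
```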
+def __partition_grib_dataset(filename: str,
+                             tiff_config: t.Dict) -> t.List[xr.Dataset]:
+    """Partitions every dataset of a multi-dataset grib file.
+
+    By default, each dataset in the list is partitioned along its 'time' and 'step' dimensions.
+    If the user specifies a dimension and a dataset contains it, that dimension is also used for
+    partitioning. Datasets that do not contain a user-specified dimension are stored as bands in
+    all of the generated tiff files.
+
+    The data arrays are collected into a nested dictionary. With common_dims being the
+    intersection of tiff_config["dims"] and ds.dims, the structure is:
+    {
+        time_step_value1: {
+            common_dims_value1: [list of data arrays with common_dims_value1],
+            common_dims_value2: [list of data arrays with common_dims_value2],
+            ...
+            bands: [list of data arrays not containing a user-given dimension]
+        },
+        time_step_value2: {
+            common_dims_value1: [list of data arrays with common_dims_value1],
+            common_dims_value2: [list of data arrays with common_dims_value2],
+            ...
+            bands: [list of data arrays not containing a user-given dimension]
+        },
+    }
+
+    For each common_dims_value, the data array list is then merged into a single dataset.
+    """
+    ds_list = cfgrib.open_datasets(filename)
+    da_dict = defaultdict(lambda: defaultdict(list))
+    da_units_dict = defaultdict(dict)
+    ds_attrs = ds_list[0].attrs
+    _merged_dataset_list = []
+    ds_dims = []
+
+    for ds in ds_list:
+        ds_dims.extend(ds.dims.keys())
+    ds_dims = list(set(ds_dims))
+    validate_tiff_config(ds_dims, tiff_config)
+
+    for ds in ds_list:
+        # Store the common dimensions between tiff_config and this dataset.
+        common_dims = set(tiff_config['dims']).intersection(set(ds.dims))
+        coords_set = set(ds.coords.keys())
+
+        groups, ds, flat_dims = create_partition_configs(ds, list(common_dims))
+        for group in groups:
+            da_list, da_units = create_partition(ds, group, flat_dims, coords_set)
+            dims_val = ""
+            da = da_list[0]
+            time_step_val = f"time_{da.attrs['start_time']}_step_{da.attrs['forecast_hour']}"
+
+            # If common_dims is empty, add the data arrays to "bands" for the respective
+            # time_step_value; otherwise add them to the respective common_dims_value.
+            if not common_dims:
+                da_dict[time_step_val]["bands"].extend(da_list)
+                da_units_dict["bands"].update(**da_units)
+            else:
+                for _dim in common_dims:
+                    dims_val += f"{_dim}_{da[_dim].values}"
+                da_dict[time_step_val][dims_val].extend(da_list)
+                da_units_dict[dims_val].update(**da_units)
+
+    # Add the data arrays collected under "bands" to every common_dims_value of the
+    # respective time_step_value.
+    if tiff_config['dims']:
+        for keys in da_dict:
+            if da_dict[keys].get("bands"):
+                bands = da_dict[keys]["bands"]
+                bands_attr = da_units_dict["bands"]
+                for dims_val in da_dict[keys]:
+                    da_dict[keys][dims_val].extend(bands)
+                    da_units_dict[dims_val].update(**bands_attr)
+                del da_dict[keys]["bands"]
+                del da_units_dict["bands"]
+
+    # Merge each data array list into a dataset and add the respective unit attributes to it.
+    for time_step in da_dict:
+        for key, _da_list in da_dict[time_step].items():
+            attrs_to_add = ds_attrs
+
+            # Stick the forecast hour, start_time, end_time and the data variables' units
+            # in the ds attrs as well, that's useful.
+            attrs_to_add['forecast_hour'] = _da_list[0].attrs['forecast_hour']
+            attrs_to_add['start_time'] = _da_list[0].attrs['start_time']
+            attrs_to_add['end_time'] = _da_list[0].attrs['end_time']
+            attrs_to_add.update(**da_units_dict[key])
+
+            merged_ds = xr.merge(_da_list)
+            merged_ds.attrs.clear()
+            merged_ds.attrs.update(attrs_to_add)
+            _merged_dataset_list.append(merged_ds)
+
+    return _merged_dataset_list
+
+
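Aside (illustrative, not part of the patch): a toy version of the nested-dictionary bookkeeping described in the docstring above. Strings stand in for real grib data arrays; the point is how per-dimension buckets and the shared "bands" bucket combine before merging.

```python
from collections import defaultdict

da_dict = defaultdict(lambda: defaultdict(list))

# Arrays from a dataset that has the requested 'depthBelowLandLayer' dimension:
da_dict['time_20230614_step_6']['depthBelowLandLayer_0.0'].append('stl1')
da_dict['time_20230614_step_6']['depthBelowLandLayer_0.1'].append('stl2')
# Arrays from a dataset without the requested dimension land in the 'bands' bucket:
da_dict['time_20230614_step_6']['bands'].append('t2m')

# Before merging, the 'bands' arrays are appended to every dimension bucket so that
# each generated tiff also carries them as extra bands.
for time_step, buckets in da_dict.items():
    bands = buckets.pop('bands', [])
    for dims_val in buckets:
        buckets[dims_val].extend(bands)

print(dict(da_dict['time_20230614_step_6']))
# -> {'depthBelowLandLayer_0.0': ['stl1', 't2m'], 'depthBelowLandLayer_0.1': ['stl2', 't2m']}
```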
 def __open_dataset_file(filename: str,
                         uri_extension: str,
                         disable_grib_schema_normalization: bool,
                         open_dataset_kwargs: t.Optional[t.Dict] = None,
-                        group_common_hypercubes: t.Optional[bool] = False) -> t.Union[xr.Dataset, t.List[xr.Dataset]]:
+                        group_common_hypercubes: t.Optional[bool] = False,
+                        tiff_config: t.Optional[t.Dict] = None) -> t.Union[xr.Dataset, t.List[xr.Dataset]]:
     """Opens the dataset at 'uri' and returns a xarray.Dataset."""
     # add a flag to group common hypercubes
     if group_common_hypercubes:
         return __normalize_grib_dataset(filename, group_common_hypercubes)
 
     if open_dataset_kwargs:
-        return _add_is_normalized_attr(xr.open_dataset(filename, **open_dataset_kwargs), False)
+        ds = xr.open_dataset(filename, **open_dataset_kwargs)
+        return _add_is_normalized_attr(__partition_dataset(ds, tiff_config), False)
 
     # If URI extension is .tif, try opening file by specifying engine="rasterio".
     if uri_extension in ['.tif', '.tiff']:
@@ -348,7 +602,8 @@
 
     # If no open kwargs are available and URI extension is other than tif, make educated guesses about the dataset.
     try:
-        return _add_is_normalized_attr(xr.open_dataset(filename), False)
+        ds = xr.open_dataset(filename)
+        return _add_is_normalized_attr(__partition_dataset(ds, tiff_config), False)
     except ValueError as e:
         e_str = str(e)
         if not ("Consider explicitly selecting one of the installed engines" in e_str and "cfgrib" in e_str):
@@ -358,6 +613,8 @@
         logger.warning("Assuming grib.")
         logger.info("Normalizing the grib schema, name of the data variables will look like "
                     "'___'.")
+        if tiff_config:
+            return _add_is_normalized_attr(__partition_grib_dataset(filename, tiff_config), True)
         return _add_is_normalized_attr(__normalize_grib_dataset(filename), True)
 
     # Trying with explicit engine for cfgrib.
@@ -428,7 +685,8 @@ def open_dataset(uri: str,
                  initialization_time_regex: t.Optional[str] = None,
                  forecast_time_regex: t.Optional[str] = None,
                  group_common_hypercubes: t.Optional[bool] = False,
-                 is_zarr: bool = False) -> t.Iterator[xr.Dataset]:
+                 is_zarr: bool = False,
+                 tiff_config: t.Optional[t.Dict] = None) -> t.Iterator[t.Union[xr.Dataset, t.List[xr.Dataset]]]:
     """Open the dataset at 'uri' and return a xarray.Dataset."""
     try:
         local_open_dataset_kwargs = start_date = end_date = None
@@ -449,20 +707,20 @@
         with open_local(uri) as local_path:
             _, uri_extension = os.path.splitext(uri)
             xr_datasets: xr.Dataset = __open_dataset_file(local_path,
-                                                          uri_extension,
-                                                          disable_grib_schema_normalization,
-                                                          local_open_dataset_kwargs,
-                                                          group_common_hypercubes)
-            # Extracting dtype, crs and transform from the dataset.
+                                                          uri_extension,
+                                                          disable_grib_schema_normalization,
+                                                          local_open_dataset_kwargs,
+                                                          group_common_hypercubes,
+                                                          tiff_config)
+            # Extracting dtype, crs and transform from the dataset & storing them as attributes.
             try:
                 with rasterio.open(local_path, 'r') as f:
                     dtype, crs, transform = (f.profile.get(key) for key in ['dtype', 'crs', 'transform'])
             except rasterio.errors.RasterioIOError:
                 logger.warning('Cannot parse projection and data type information for Dataset %r.', uri)
 
-            if group_common_hypercubes:
+            if isinstance(xr_datasets, list):
                 total_size_in_bytes = 0
-
                 for xr_dataset in xr_datasets:
                     xr_dataset.attrs.update({'dtype': dtype, 'crs': crs, 'transform': transform})
                     total_size_in_bytes += xr_dataset.nbytes
@@ -487,10 +745,10 @@ def open_dataset(uri: str,
             logger.info(f'opened dataset size: {xr_dataset.nbytes}')
 
             beam.metrics.Metrics.counter('Success', 'ReadNetcdfData').inc()
-            yield xr_datasets if group_common_hypercubes else xr_dataset
+            yield xr_datasets if isinstance(xr_datasets, list) else xr_dataset
 
             # Releasing any resources linked to the object(s).
-            if group_common_hypercubes:
+            if isinstance(xr_datasets, list):
                 for xr_dataset in xr_datasets:
                     xr_dataset.close()
             else:
diff --git a/weather_mv/loader_pipeline/sinks_test.py b/weather_mv/loader_pipeline/sinks_test.py
index 0d8838da..02152a63 100644
--- a/weather_mv/loader_pipeline/sinks_test.py
+++ b/weather_mv/loader_pipeline/sinks_test.py
@@ -68,10 +68,10 @@ def write_netcdf():
     data_arr = np.random.uniform(low=0, high=0.1, size=(5, lat_dim, lon_dim))
 
     ds = xr.Dataset(
-        {"var_1": (('time', 'lat', 'lon'), data_arr)},
+        {"var_1": (('time', 'latitude', 'longitude'), data_arr)},
         coords={
-            "lat": lat,
-            "lon": lon,
+            "latitude": lat,
+            "longitude": lon,
         })
 
     with tempfile.NamedTemporaryFile() as fp:
         ds.to_netcdf(fp.name)
diff --git a/weather_mv/test_data/test_data_grib_multiple_edition_multiple_timestep.grib2 b/weather_mv/test_data/test_data_grib_multiple_edition_multiple_timestep.grib2
new file mode 100644
index 00000000..522517c0
Binary files /dev/null and b/weather_mv/test_data/test_data_grib_multiple_edition_multiple_timestep.grib2 differ
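Aside (illustrative, not part of the patch): a hedged usage sketch of the new `open_dataset` behaviour. When `tiff_config` is supplied, the generator may yield a list of per-partition datasets rather than a single dataset, which is why the callers above switch from checking `group_common_hypercubes` to `isinstance(xr_datasets, list)`. The import path and URI below are placeholders.

```python
from loader_pipeline.sinks import open_dataset  # assumed import path within weather_mv

uri = 'gs://my-bucket/forecast.grib2'  # hypothetical input file
with open_dataset(uri, tiff_config={'dims': ['isobaricInhPa']}) as ds_list:
    if not isinstance(ds_list, list):  # a single, unpartitioned dataset is still possible
        ds_list = [ds_list]
    for ds in ds_list:
        # Each partition carries start_time / end_time / forecast_hour in its attrs,
        # which ConvertToAsset uses to build the asset name.
        print(ds.attrs.get('start_time'), ds.attrs.get('forecast_hour'), list(ds.data_vars))
```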