From 3d7eea75d379a432c2a5b0ea05fd09e8e5115076 Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Thu, 31 Mar 2022 10:06:44 -0600 Subject: [PATCH] Allow passing `keep_attrs`, `skipna` to xESMF regridder (#26) --- ndpyramid/regrid.py | 12 ++++++++---- tests/test_pyramids.py | 17 ++++++++++++----- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/ndpyramid/regrid.py b/ndpyramid/regrid.py index de80927..b13e52b 100644 --- a/ndpyramid/regrid.py +++ b/ndpyramid/regrid.py @@ -100,8 +100,6 @@ def make_grid_pyramid(levels: int = 6) -> dt.DataTree: data[str(level)] = make_grid_ds(level).chunk(-1) return data - # data.to_zarr('gs://carbonplan-scratch/grids/epsg:3857/', consolidated=True) - def pyramid_regrid( ds: xr.Dataset, @@ -110,6 +108,7 @@ def pyramid_regrid( weights_template: str = None, method: str = 'bilinear', regridder_kws: dict = None, + regridder_apply_kws: dict = None, ) -> dt.DataTree: """Make a pyramid using xesmf's regridders @@ -124,9 +123,12 @@ def pyramid_regrid( weights_template : str, optional Filepath to write generated weights to, e.g. `'weights_{level}'`, by default None method : str, optional - Regridding method. See ``xesmf.Regridder`` for valid options, by default 'bilinear' + Regridding method. See :py:class:`~xesmf.Regridder` for valid options, by default 'bilinear' regridder_kws : dict Keyword arguments to pass to regridder. Default is `{'periodic': True}` + regridder_apply_kws : dict + Keyword arguments such as `keep_attrs`, `skipna`, `na_thres` + to pass to :py:meth:`~xesmf.Regridder.__call__`. Default is None Returns ------- @@ -182,6 +184,8 @@ def pyramid_regrid( regridder = xe.Regridder(ds, grid, method, weights=fn, **regridder_kws) # regrid - pyramid[str(level)] = regridder(ds) + if regridder_apply_kws is None: + regridder_apply_kws = {} + pyramid[str(level)] = regridder(ds, **regridder_apply_kws) return pyramid diff --git a/tests/test_pyramids.py b/tests/test_pyramids.py index 8cb484c..c59bba8 100644 --- a/tests/test_pyramids.py +++ b/tests/test_pyramids.py @@ -15,7 +15,6 @@ def temperature(): def test_xarray_coarsened_pyramid(temperature): - print(temperature) factors = [4, 2, 1] pyramid = pyramid_coarsen(temperature, dims=('lat', 'lon'), factors=factors, boundary='trim') assert pyramid.ds.attrs['multiscales'] @@ -24,7 +23,7 @@ def test_xarray_coarsened_pyramid(temperature): def test_reprojected_pyramid(temperature): - rioxarray = pytest.importorskip('rioxarray') # noqa: F841 + pytest.importorskip('rioxarray') levels = 2 temperature = temperature.rio.write_crs('EPSG:4326') pyramid = pyramid_reproject(temperature, levels=2) @@ -33,10 +32,18 @@ def test_reprojected_pyramid(temperature): pyramid.to_zarr(MemoryStore()) -def test_regridded_pyramid(temperature): - xesmf = pytest.importorskip('xesmf') # noqa: F841 - pyramid = pyramid_regrid(temperature, levels=2) +@pytest.mark.parametrize('regridder_apply_kws', [None, {'keep_attrs': True}]) +def test_regridded_pyramid(temperature, regridder_apply_kws): + pytest.importorskip('xesmf') + pyramid = pyramid_regrid(temperature, levels=2, regridder_apply_kws=regridder_apply_kws) assert pyramid.ds.attrs['multiscales'] + expected_attrs = ( + temperature['air'].attrs + if regridder_apply_kws is not None and regridder_apply_kws['keep_attrs'] + else {} + ) + assert pyramid['0'].ds.air.attrs == expected_attrs + assert pyramid['1'].ds.air.attrs == expected_attrs pyramid.to_zarr(MemoryStore())