Skip to content

Commit

Permalink
Allow passing keep_attrs, skipna to xESMF regridder (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
andersy005 authored Mar 31, 2022
1 parent 18b5ff4 commit 3d7eea7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
12 changes: 8 additions & 4 deletions ndpyramid/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,6 @@ def make_grid_pyramid(levels: int = 6) -> dt.DataTree:
data[str(level)] = make_grid_ds(level).chunk(-1)
return data

# data.to_zarr('gs://carbonplan-scratch/grids/epsg:3857/', consolidated=True)


def pyramid_regrid(
ds: xr.Dataset,
Expand All @@ -110,6 +108,7 @@ def pyramid_regrid(
weights_template: str = None,
method: str = 'bilinear',
regridder_kws: dict = None,
regridder_apply_kws: dict = None,
) -> dt.DataTree:
"""Make a pyramid using xesmf's regridders
Expand All @@ -124,9 +123,12 @@ def pyramid_regrid(
weights_template : str, optional
Filepath to write generated weights to, e.g. `'weights_{level}'`, by default None
method : str, optional
Regridding method. See ``xesmf.Regridder`` for valid options, by default 'bilinear'
Regridding method. See :py:class:`~xesmf.Regridder` for valid options, by default 'bilinear'
regridder_kws : dict
Keyword arguments to pass to regridder. Default is `{'periodic': True}`
regridder_apply_kws : dict
Keyword arguments such as `keep_attrs`, `skipna`, `na_thres`
to pass to :py:meth:`~xesmf.Regridder.__call__`. Default is None
Returns
-------
Expand Down Expand Up @@ -182,6 +184,8 @@ def pyramid_regrid(
regridder = xe.Regridder(ds, grid, method, weights=fn, **regridder_kws)

# regrid
pyramid[str(level)] = regridder(ds)
if regridder_apply_kws is None:
regridder_apply_kws = {}
pyramid[str(level)] = regridder(ds, **regridder_apply_kws)

return pyramid
17 changes: 12 additions & 5 deletions tests/test_pyramids.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def temperature():


def test_xarray_coarsened_pyramid(temperature):
print(temperature)
factors = [4, 2, 1]
pyramid = pyramid_coarsen(temperature, dims=('lat', 'lon'), factors=factors, boundary='trim')
assert pyramid.ds.attrs['multiscales']
Expand All @@ -24,7 +23,7 @@ def test_xarray_coarsened_pyramid(temperature):


def test_reprojected_pyramid(temperature):
rioxarray = pytest.importorskip('rioxarray') # noqa: F841
pytest.importorskip('rioxarray')
levels = 2
temperature = temperature.rio.write_crs('EPSG:4326')
pyramid = pyramid_reproject(temperature, levels=2)
Expand All @@ -33,10 +32,18 @@ def test_reprojected_pyramid(temperature):
pyramid.to_zarr(MemoryStore())


def test_regridded_pyramid(temperature):
xesmf = pytest.importorskip('xesmf') # noqa: F841
pyramid = pyramid_regrid(temperature, levels=2)
@pytest.mark.parametrize('regridder_apply_kws', [None, {'keep_attrs': True}])
def test_regridded_pyramid(temperature, regridder_apply_kws):
pytest.importorskip('xesmf')
pyramid = pyramid_regrid(temperature, levels=2, regridder_apply_kws=regridder_apply_kws)
assert pyramid.ds.attrs['multiscales']
expected_attrs = (
temperature['air'].attrs
if regridder_apply_kws is not None and regridder_apply_kws['keep_attrs']
else {}
)
assert pyramid['0'].ds.air.attrs == expected_attrs
assert pyramid['1'].ds.air.attrs == expected_attrs
pyramid.to_zarr(MemoryStore())


Expand Down

0 comments on commit 3d7eea7

Please sign in to comment.