Skip to content

Commit

Permalink
Pyramid create for creating pyramids with custom funcs (#120)
Browse files Browse the repository at this point in the history
Co-authored-by: Max Jones <[email protected]>
  • Loading branch information
ahuang11 and maxrjones authored Apr 8, 2024
1 parent 9a62b88 commit ec65ab4
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 25 deletions.
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ Top level API
:toctree: generated/

pyramid_coarsen
pyramid_create
pyramid_reproject
pyramid_regrid
22 changes: 22 additions & 0 deletions docs/generate-pyramids.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,25 @@ pyramid = pyramid_reproject(ds, levels=2)
# write the pyramid to zarr
pyramid.to_zarr('./path/to/write')
```

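There's also `pyramid_create` -- a more versatile alternative to `pyramid_coarsen`.
This function accepts a custom function that is called with `ds`, `factor`, and `dims` as positional arguments; any extra keyword arguments given to `pyramid_create` are forwarded to it.
Here, the `sel_coarsen` function uses `ds.sel` to perform coarsening by keeping every `factor`-th point along each dimension: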

```python
from ndpyramid import pyramid_create

def sel_coarsen(ds, factor, dims, **kwargs):
return ds.sel(**{dim: slice(None, None, factor) for dim in dims})

factors = [4, 2, 1]
pyramid = pyramid_create(
temperature,
dims=('lat', 'lon'),
factors=factors,
boundary='trim',
func=sel_coarsen,
method_label="slice_coarsen",
type_label='pick',
)
```
1 change: 1 addition & 0 deletions ndpyramid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# flake8: noqa

from .create import pyramid_create
from .coarsen import pyramid_coarsen
from .reproject import pyramid_reproject
from .regrid import pyramid_regrid
Expand Down
37 changes: 13 additions & 24 deletions ndpyramid/coarsen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import datatree as dt
import xarray as xr

from .utils import get_version, multiscales_template
from .create import pyramid_create


def pyramid_coarsen(
Expand All @@ -23,28 +23,17 @@ def pyramid_coarsen(
Additional keyword arguments to pass to xarray.Dataset.coarsen.
"""

# multiscales spec
save_kwargs = locals()
del save_kwargs['ds']

attrs = {
'multiscales': multiscales_template(
datasets=[{'path': str(i)} for i in range(len(factors))],
type='reduce',
method='pyramid_coarsen',
version=get_version(),
kwargs=save_kwargs,
)
}

# set up pyramid
plevels = {}

# pyramid data
for key, factor in enumerate(factors):
def coarsen(ds: xr.Dataset, factor: int, dims: list[str], **kwargs):
# merge dictionary via union operator
kwargs |= {d: factor for d in dims}
plevels[str(key)] = ds.coarsen(**kwargs).mean() # type: ignore

plevels['/'] = xr.Dataset(attrs=attrs)
return dt.DataTree.from_dict(plevels)
return ds.coarsen(**kwargs).mean() # type: ignore

return pyramid_create(
ds,
factors=factors,
dims=dims,
func=coarsen,
method_label='pyramid_coarsen',
type_label='reduce',
**kwargs,
)
70 changes: 70 additions & 0 deletions ndpyramid/create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import annotations # noqa: F401

from typing import Callable

import datatree as dt
import xarray as xr

from .utils import get_version, multiscales_template


def pyramid_create(
    ds: xr.Dataset,
    *,
    factors: list[int],
    dims: list[str],
    func: Callable,
    type_label: str = 'reduce',
    method_label: str | None = None,
    **kwargs,
) -> dt.DataTree:
    """Create a multiscale pyramid via a given function applied to a dataset.

    The generalized version of pyramid_coarsen.

    Parameters
    ----------
    ds : xarray.Dataset
        The dataset to apply the function to.
    factors : list[int]
        The factors to coarsen by. One pyramid level is produced per factor,
        in the order given.
    dims : list[str]
        The dimensions to coarsen.
    func : Callable
        The function to apply to the dataset; must accept the
        `ds`, `factor`, and `dims` as positional arguments, followed by
        any extra keyword arguments passed to this function.
    type_label : str, optional
        The type label to use as metadata for the multiscales spec.
        The default is 'reduce'.
    method_label : str, optional
        The method label to use as metadata for the multiscales spec.
        The default is the name of the function; callables without a
        ``__name__`` (e.g. ``functools.partial`` objects) fall back to the
        name of their type.
    kwargs : dict
        Additional keyword arguments to pass to the func.

    Returns
    -------
    datatree.DataTree
        A tree with one node per factor, keyed ``'0'``, ``'1'``, ... by the
        position of the factor in `factors`, plus a root node ``'/'`` holding
        an empty dataset whose attrs carry the ``multiscales`` metadata.
    """
    # Not every callable has a __name__ (functools.partial objects and
    # instances defining __call__ do not), so fall back to the type name
    # instead of raising AttributeError when no label was supplied.
    method = method_label or getattr(func, '__name__', None) or type(func).__name__

    # multiscales spec: record only the arguments that describe the
    # coarsening itself. Built explicitly so that locals added to this
    # function later cannot leak into the stored metadata.
    save_kwargs = {'factors': factors, 'dims': dims, 'kwargs': kwargs}

    attrs = {
        'multiscales': multiscales_template(
            datasets=[{'path': str(level)} for level in range(len(factors))],
            type=type_label,
            method=method,
            version=get_version(),
            kwargs=save_kwargs,
        )
    }

    # pyramid data: level i is `func` applied to the full-resolution dataset
    # with the i-th factor.
    plevels = {
        str(level): func(ds, factor, dims, **kwargs) for level, factor in enumerate(factors)
    }

    # the root node carries only the multiscales metadata
    plevels['/'] = xr.Dataset(attrs=attrs)
    return dt.DataTree.from_dict(plevels)
28 changes: 27 additions & 1 deletion tests/test_pyramids.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import xarray as xr
from zarr.storage import MemoryStore

from ndpyramid import pyramid_coarsen, pyramid_regrid, pyramid_reproject
from ndpyramid import pyramid_coarsen, pyramid_create, pyramid_regrid, pyramid_reproject
from ndpyramid.regrid import generate_weights_pyramid, make_grid_ds
from ndpyramid.testing import verify_bounds

Expand All @@ -22,6 +22,32 @@ def test_xarray_coarsened_pyramid(temperature, benchmark):
)
assert pyramid.ds.attrs['multiscales']
assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == len(factors)
assert pyramid.ds.attrs['multiscales'][0]['metadata']['method'] == 'pyramid_coarsen'
assert pyramid.ds.attrs['multiscales'][0]['type'] == 'reduce'
pyramid.to_zarr(MemoryStore())


@pytest.mark.parametrize('method_label', [None, 'sel_coarsen'])
def test_xarray_custom_coarsened_pyramid(temperature, benchmark, method_label):
    """Build a pyramid with a user-supplied coarsening function.

    Both parametrizations must record 'sel_coarsen' as the method: once
    through the explicit label and once through the function's own name.
    """

    def sel_coarsen(ds, factor, dims, **kwargs):
        # keep every `factor`-th label along each requested dimension
        strided = {dim: slice(None, None, factor) for dim in dims}
        return ds.sel(**strided)

    factors = [4, 2, 1]

    def build_pyramid():
        return pyramid_create(
            temperature,
            dims=('lat', 'lon'),
            factors=factors,
            boundary='trim',
            func=sel_coarsen,
            method_label=method_label,
            type_label='pick',
        )

    pyramid = benchmark(build_pyramid)

    multiscales = pyramid.ds.attrs['multiscales']
    assert multiscales
    spec = multiscales[0]
    assert len(spec['datasets']) == len(factors)
    assert spec['metadata']['method'] == 'sel_coarsen'
    assert spec['type'] == 'pick'

    # the resulting tree must be writable to a zarr store
    pyramid.to_zarr(MemoryStore())


Expand Down

0 comments on commit ec65ab4

Please sign in to comment.