Skip to content

Commit

Permalink
Adds option for parallel weight generation with xESMF (#145)
Browse files Browse the repository at this point in the history
Co-authored-by: Anderson Banihirwe <[email protected]>
  • Loading branch information
norlandrhagen and andersy005 authored Nov 22, 2024
1 parent 4e87518 commit 764494e
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
5 changes: 4 additions & 1 deletion ndpyramid/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def pyramid_regrid(
projection: typing.Literal["web-mercator", "equidistant-cylindrical"] = "web-mercator",
target_pyramid: xr.DataTree = None,
levels: int = None,
parallel_weights: bool = True,
weights_pyramid: xr.DataTree = None,
method: str = "bilinear",
regridder_kws: dict = None,
Expand All @@ -217,6 +218,8 @@ def pyramid_regrid(
Number of levels in pyramid, by default None
weights_pyramid : xr.DataTree, optional
pyramid containing pregenerated weights
parallel_weights : Bool
Use dask to generate parallel weights
method : str, optional
Regridding method. See :py:class:`~xesmf.Regridder` for valid options, by default 'bilinear'
regridder_kws : dict
Expand Down Expand Up @@ -285,7 +288,7 @@ def pyramid_regrid(
grid = target_pyramid[str(level)].ds.load()
# get the regridder object
if weights_pyramid is None:
regridder = xe.Regridder(ds, grid, method, **regridder_kws)
regridder = xe.Regridder(ds, grid, method, parallel=parallel_weights, **regridder_kws)
else:
# Reconstruct weights into format that xESMF understands
# this is a hack that assumes the weights were generated by
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ known-first-party = ["ndpyramid"]


# Notebook ruff config
[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"*.ipynb" = [
"D100",
"E402",
Expand Down
6 changes: 5 additions & 1 deletion tests/test_pyramid_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ def test_regridded_pyramid(temperature, regridder_apply_kws, benchmark):
temperature = temperature.isel(time=slice(0, 5))
pyramid = benchmark(
lambda: pyramid_regrid(
temperature, levels=2, regridder_apply_kws=regridder_apply_kws, other_chunks={"time": 2}
temperature,
levels=2,
parallel_weights=False,
regridder_apply_kws=regridder_apply_kws,
other_chunks={"time": 2},
)
)
verify_bounds(pyramid)
Expand Down

0 comments on commit 764494e

Please sign in to comment.