Skip to content

Commit

Permalink
Allow reusing weights saved in a pyramid during xESMF regridding (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
andersy005 authored Apr 11, 2022
1 parent 9d21440 commit 935878f
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 18 deletions.
1 change: 1 addition & 0 deletions ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies:
- rasterio
- rioxarray
- scipy
- sparse>=0.13.0
- xarray
- xarray-datatree>=0.0.4
- xesmf
Expand Down
102 changes: 87 additions & 15 deletions ndpyramid/regrid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations # noqa: F401

import itertools
import pathlib

import datatree as dt
import numpy as np
Expand All @@ -10,6 +9,35 @@
from .utils import add_metadata_and_zarr_encoding, get_version, multiscales_template


def xesmf_weights_to_xarray(regridder) -> xr.Dataset:
    """Pack the sparse weights of a regridder into an xarray Dataset.

    Parameters
    ----------
    regridder
        Object exposing ``weights.data`` (with ``data`` and ``coords``
        members) plus ``n_in`` and ``n_out`` attributes.
        NOTE(review): presumably an ``xesmf.Regridder`` whose weights are
        backed by a ``sparse.COO`` matrix -- confirm against callers.

    Returns
    -------
    xr.Dataset
        Dataset with variables ``S`` (weight values), ``col`` and ``row``
        along the ``n_s`` dimension, and ``n_in`` / ``n_out`` stored as
        attributes. Indices are shifted by +1 relative to the sparse
        coordinates (i.e. stored one-based).
    """
    sparse_weights = regridder.weights.data
    row_index = sparse_weights.coords[0, :]
    col_index = sparse_weights.coords[1, :]

    variables = {
        'S': ('n_s', sparse_weights.data),
        'col': ('n_s', col_index + 1),
        'row': ('n_s', row_index + 1),
    }
    return xr.Dataset(
        variables,
        attrs={'n_in': regridder.n_in, 'n_out': regridder.n_out},
    )


def _reconstruct_xesmf_weights(ds_w):
    """Rebuild a sparse weights DataArray from a stored weights dataset.

    Parameters
    ----------
    ds_w
        Dataset holding ``S``, ``col`` and ``row`` variables and ``n_in`` /
        ``n_out`` attributes, with one-based ``col`` / ``row`` indices
        (the layout written by ``xesmf_weights_to_xarray``).

    Returns
    -------
    xr.DataArray
        DataArray named ``'weights'`` with dims ``('out_dim', 'in_dim')``
        wrapping a ``sparse.COO`` matrix of shape ``(n_out, n_in)``.
    """
    import sparse
    import xarray as xr

    # Stored indices are one-based; shift them back to zero-based positions.
    col_index = ds_w['col'].values - 1
    row_index = ds_w['row'].values - 1
    values = ds_w['S'].values

    shape = (ds_w.attrs['n_out'], ds_w.attrs['n_in'])
    coordinates = np.stack([row_index, col_index])
    matrix = sparse.COO(coordinates, values, shape)
    return xr.DataArray(matrix, dims=('out_dim', 'in_dim'), name='weights')


def make_grid_ds(level: int, pixels_per_tile: int = 128) -> xr.Dataset:
"""Make a dataset representing a target grid
Expand Down Expand Up @@ -97,11 +125,52 @@ def make_grid_pyramid(levels: int = 6) -> dt.DataTree:
return data


def generate_weights_pyramid(
    ds_in: xr.Dataset, levels: int, method: str = 'bilinear', regridder_kws: dict = None
) -> dt.DataTree:
    """Build a pyramid of regridding weights, one node per zoom level.

    Parameters
    ----------
    ds_in : xr.Dataset
        Input dataset to regrid
    levels : int
        Number of levels in the pyramid
    method : str, optional
        Regridding method. See :py:class:`~xesmf.Regridder` for valid options, by default 'bilinear'
    regridder_kws : dict, optional
        Keyword arguments to pass to :py:class:`~xesmf.Regridder`. They are
        merged on top of the default `{'periodic': True}`.

    Returns
    -------
    weights : dt.DataTree
        Multiscale weights. Each level is stored under its stringified
        index; ``levels`` and ``regrid_method`` are recorded as attributes
        on the root dataset.
    """
    # xesmf is imported lazily so the module can be imported without it.
    import xesmf as xe

    if regridder_kws is None:
        regridder_kws = {}
    # Caller-supplied options override the periodic default.
    options = {'periodic': True, **regridder_kws}

    tree = dt.DataTree()
    for level in range(levels):
        target_grid = make_grid_ds(level=level)
        level_regridder = xe.Regridder(ds_in, target_grid, method, **options)
        tree[str(level)] = xesmf_weights_to_xarray(level_regridder)

    tree.ds.attrs['levels'] = levels
    tree.ds.attrs['regrid_method'] = method

    return tree


def pyramid_regrid(
ds: xr.Dataset,
target_pyramid: dt.DataTree = None,
levels: int = None,
weights_template: str = None,
weights_pyramid: dt.DataTree = None,
method: str = 'bilinear',
regridder_kws: dict = None,
regridder_apply_kws: dict = None,
Expand All @@ -118,8 +187,8 @@ def pyramid_regrid(
Target grids, if not provided, they will be generated, by default None
levels : int, optional
Number of levels in pyramid, by default None
weights_template : str, optional
Filepath to write generated weights to, e.g. `'weights_{level}'`, by default None
weights_pyramid : dt.DataTree, optional
pyramid containing pregenerated weights
method : str, optional
Regridding method. See :py:class:`~xesmf.Regridder` for valid options, by default 'bilinear'
regridder_kws : dict
Expand Down Expand Up @@ -147,14 +216,15 @@ def pyramid_regrid(
if levels is None:
levels = len(target_pyramid.keys()) # TODO: get levels from the pyramid metadata

if regridder_kws is None:
regridder_kws = {'periodic': True}
regridder_kws = {} if regridder_kws is None else regridder_kws
regridder_kws = {'periodic': True, **regridder_kws}

# multiscales spec
save_kwargs = locals()
del save_kwargs['ds']
del save_kwargs['target_pyramid']
del save_kwargs['xe']
del save_kwargs['weights_pyramid']

attrs = {
'multiscales': multiscales_template(
Expand All @@ -173,21 +243,23 @@ def pyramid_regrid(
# pyramid data
for level in range(levels):
grid = target_pyramid[str(level)].ds.load()

# get the regridder object
if not weights_template:
if weights_pyramid is None:
regridder = xe.Regridder(ds, grid, method, **regridder_kws)
else:
fn = pathlib.PosixPath(weights_template.format(level=level))
if not fn.exists():
regridder = xe.Regridder(ds, grid, method, **regridder_kws)
regridder.to_netcdf(filename=fn)
else:
regridder = xe.Regridder(ds, grid, method, weights=fn, **regridder_kws)

# Reconstruct weights into format that xESMF understands
# this is a hack that assumes the weights were generated by
# the `generate_weights_pyramid` function

ds_w = weights_pyramid[str(level)].ds
weights = _reconstruct_xesmf_weights(ds_w)
regridder = xe.Regridder(
ds, grid, method, reuse_weights=True, weights=weights, **regridder_kws
)
# regrid
if regridder_apply_kws is None:
regridder_apply_kws = {}
regridder_apply_kws = {**{'keep_attrs': True}, **regridder_apply_kws}
pyramid[str(level)] = regridder(ds, **regridder_apply_kws)

pyramid = add_metadata_and_zarr_encoding(
Expand Down
28 changes: 25 additions & 3 deletions tests/test_pyramids.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from zarr.storage import MemoryStore

from ndpyramid import pyramid_coarsen, pyramid_regrid, pyramid_reproject
from ndpyramid.regrid import make_grid_ds
from ndpyramid.regrid import generate_weights_pyramid, make_grid_ds


@pytest.fixture
Expand Down Expand Up @@ -32,7 +32,7 @@ def test_reprojected_pyramid(temperature):
pyramid.to_zarr(MemoryStore())


@pytest.mark.parametrize('regridder_apply_kws', [None, {'keep_attrs': True}])
@pytest.mark.parametrize('regridder_apply_kws', [None, {'keep_attrs': False}])
def test_regridded_pyramid(temperature, regridder_apply_kws):
pytest.importorskip('xesmf')
pyramid = pyramid_regrid(
Expand All @@ -41,16 +41,38 @@ def test_regridded_pyramid(temperature, regridder_apply_kws):
assert pyramid.ds.attrs['multiscales']
expected_attrs = (
temperature['air'].attrs
if regridder_apply_kws is not None and regridder_apply_kws['keep_attrs']
if not regridder_apply_kws or regridder_apply_kws.get('keep_attrs')
else {}
)
assert pyramid['0'].ds.air.attrs == expected_attrs
assert pyramid['1'].ds.air.attrs == expected_attrs
pyramid.to_zarr(MemoryStore())


def test_regridded_pyramid_with_weights(temperature):
    """Regrid with pre-generated weights and check the multiscales metadata."""
    pytest.importorskip('xesmf')
    n_levels = 2

    # Weights are generated from a single time step of the fixture.
    first_step = temperature.isel(time=0)
    weights = generate_weights_pyramid(first_step, n_levels)

    result = pyramid_regrid(
        temperature,
        levels=n_levels,
        weights_pyramid=weights,
        other_chunks={'time': 2},
    )

    multiscales = result.ds.attrs['multiscales']
    assert multiscales
    assert len(multiscales[0]['datasets']) == n_levels
    result.to_zarr(MemoryStore())


def test_make_grid_ds():
    """Check the last-minus-first row difference of ``lon_b`` stays below 0.001."""
    lon_bounds = make_grid_ds(0, pixels_per_tile=8).lon_b.values
    first_row = lon_bounds[0, :]
    last_row = lon_bounds[-1, :]
    assert np.all((last_row - first_row) < 0.001)


@pytest.mark.parametrize('levels', [1, 2])
@pytest.mark.parametrize('method', ['bilinear', 'conservative'])
def test_generate_weights_pyramid(temperature, levels, method):
    """Check the metadata and per-level layout of a generated weights pyramid."""
    # generate_weights_pyramid imports xesmf; skip (rather than error) when
    # the optional dependency is missing, like the other regridding tests.
    pytest.importorskip('xesmf')

    weights_pyramid = generate_weights_pyramid(temperature.isel(time=0), levels, method=method)

    root_attrs = weights_pyramid.ds.attrs
    assert root_attrs['levels'] == levels
    assert root_attrs['regrid_method'] == method

    level_zero = weights_pyramid['0'].ds
    assert set(level_zero.data_vars) == {'S', 'col', 'row'}
    assert 'n_in' in level_zero.attrs and 'n_out' in level_zero.attrs

0 comments on commit 935878f

Please sign in to comment.