From 935878f35f585fa42db3e50c9f130636fc74996d Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe Date: Mon, 11 Apr 2022 17:58:50 -0600 Subject: [PATCH] Allow reusing weights saved in a pyramid during xESMF regridding (#34) --- ci/environment.yml | 1 + ndpyramid/regrid.py | 102 +++++++++++++++++++++++++++++++++++------ tests/test_pyramids.py | 28 +++++++++-- 3 files changed, 113 insertions(+), 18 deletions(-) diff --git a/ci/environment.yml b/ci/environment.yml index ccf67ea..a040881 100644 --- a/ci/environment.yml +++ b/ci/environment.yml @@ -16,6 +16,7 @@ dependencies: - rasterio - rioxarray - scipy + - sparse>=0.13.0 - xarray - xarray-datatree>=0.0.4 - xesmf diff --git a/ndpyramid/regrid.py b/ndpyramid/regrid.py index 33b0cd7..9fc6f14 100644 --- a/ndpyramid/regrid.py +++ b/ndpyramid/regrid.py @@ -1,7 +1,6 @@ from __future__ import annotations # noqa: F401 import itertools -import pathlib import datatree as dt import numpy as np @@ -10,6 +9,35 @@ from .utils import add_metadata_and_zarr_encoding, get_version, multiscales_template +def xesmf_weights_to_xarray(regridder) -> xr.Dataset: + w = regridder.weights.data + dim = 'n_s' + ds = xr.Dataset( + { + 'S': (dim, w.data), + 'col': (dim, w.coords[1, :] + 1), + 'row': (dim, w.coords[0, :] + 1), + } + ) + ds.attrs = {'n_in': regridder.n_in, 'n_out': regridder.n_out} + return ds + + +def _reconstruct_xesmf_weights(ds_w): + """Reconstruct weights into format that xESMF understands""" + import sparse + import xarray as xr + + col = ds_w['col'].values - 1 + row = ds_w['row'].values - 1 + s = ds_w['S'].values + n_out, n_in = ds_w.attrs['n_out'], ds_w.attrs['n_in'] + crds = np.stack([row, col]) + return xr.DataArray( + sparse.COO(crds, s, (n_out, n_in)), dims=('out_dim', 'in_dim'), name='weights' + ) + + def make_grid_ds(level: int, pixels_per_tile: int = 128) -> xr.Dataset: """Make a dataset representing a target grid @@ -97,11 +125,52 @@ def make_grid_pyramid(levels: int = 6) -> dt.DataTree: return data +def 
generate_weights_pyramid( + ds_in: xr.Dataset, levels: int, method: str = 'bilinear', regridder_kws: dict = None +) -> dt.DataTree: + """helper function to generate weights for a multiscale regridder + + Parameters + ---------- + ds_in : xr.Dataset + Input dataset to regrid + levels : int + Number of levels in the pyramid + method : str, optional + Regridding method. See :py:class:`~xesmf.Regridder` for valid options, by default 'bilinear' + regridder_kws : dict + Keyword arguments to pass to :py:class:`~xesmf.Regridder`. Default is `{'periodic': True}` + + Returns + ------- + weights : dt.DataTree + Multiscale weights + """ + import datatree + import xesmf as xe + + regridder_kws = {} if regridder_kws is None else regridder_kws + regridder_kws = {'periodic': True, **regridder_kws} + + weights_pyramid = datatree.DataTree() + for level in range(levels): + ds_out = make_grid_ds(level=level) + regridder = xe.Regridder(ds_in, ds_out, method, **regridder_kws) + ds = xesmf_weights_to_xarray(regridder) + + weights_pyramid[str(level)] = ds + + weights_pyramid.ds.attrs['levels'] = levels + weights_pyramid.ds.attrs['regrid_method'] = method + + return weights_pyramid + + def pyramid_regrid( ds: xr.Dataset, target_pyramid: dt.DataTree = None, levels: int = None, - weights_template: str = None, + weights_pyramid: dt.DataTree = None, method: str = 'bilinear', regridder_kws: dict = None, regridder_apply_kws: dict = None, @@ -118,8 +187,8 @@ def pyramid_regrid( Target grids, if not provided, they will be generated, by default None levels : int, optional Number of levels in pyramid, by default None - weights_template : str, optional - Filepath to write generated weights to, e.g. `'weights_{level}'`, by default None + weights_pyramid : dt.DataTree, optional + pyramid containing pregenerated weights method : str, optional Regridding method. 
See :py:class:`~xesmf.Regridder` for valid options, by default 'bilinear' regridder_kws : dict @@ -147,14 +216,15 @@ def pyramid_regrid( if levels is None: levels = len(target_pyramid.keys()) # TODO: get levels from the pyramid metadata - if regridder_kws is None: - regridder_kws = {'periodic': True} + regridder_kws = {} if regridder_kws is None else regridder_kws + regridder_kws = {'periodic': True, **regridder_kws} # multiscales spec save_kwargs = locals() del save_kwargs['ds'] del save_kwargs['target_pyramid'] del save_kwargs['xe'] + del save_kwargs['weights_pyramid'] attrs = { 'multiscales': multiscales_template( @@ -173,21 +243,23 @@ def pyramid_regrid( # pyramid data for level in range(levels): grid = target_pyramid[str(level)].ds.load() - # get the regridder object - if not weights_template: + if weights_pyramid is None: regridder = xe.Regridder(ds, grid, method, **regridder_kws) else: - fn = pathlib.PosixPath(weights_template.format(level=level)) - if not fn.exists(): - regridder = xe.Regridder(ds, grid, method, **regridder_kws) - regridder.to_netcdf(filename=fn) - else: - regridder = xe.Regridder(ds, grid, method, weights=fn, **regridder_kws) - + # Reconstruct weights into format that xESMF understands + # this is a hack that assumes the weights were generated by + # the `generate_weights_pyramid` function + + ds_w = weights_pyramid[str(level)].ds + weights = _reconstruct_xesmf_weights(ds_w) + regridder = xe.Regridder( + ds, grid, method, reuse_weights=True, weights=weights, **regridder_kws + ) # regrid if regridder_apply_kws is None: regridder_apply_kws = {} + regridder_apply_kws = {**{'keep_attrs': True}, **regridder_apply_kws} pyramid[str(level)] = regridder(ds, **regridder_apply_kws) pyramid = add_metadata_and_zarr_encoding( diff --git a/tests/test_pyramids.py b/tests/test_pyramids.py index 284eab0..49447b6 100644 --- a/tests/test_pyramids.py +++ b/tests/test_pyramids.py @@ -4,7 +4,7 @@ from zarr.storage import MemoryStore from ndpyramid import 
pyramid_coarsen, pyramid_regrid, pyramid_reproject -from ndpyramid.regrid import make_grid_ds +from ndpyramid.regrid import generate_weights_pyramid, make_grid_ds @pytest.fixture @@ -32,7 +32,7 @@ def test_reprojected_pyramid(temperature): pyramid.to_zarr(MemoryStore()) -@pytest.mark.parametrize('regridder_apply_kws', [None, {'keep_attrs': True}]) +@pytest.mark.parametrize('regridder_apply_kws', [None, {'keep_attrs': False}]) def test_regridded_pyramid(temperature, regridder_apply_kws): pytest.importorskip('xesmf') pyramid = pyramid_regrid( @@ -41,7 +41,7 @@ def test_regridded_pyramid(temperature, regridder_apply_kws): assert pyramid.ds.attrs['multiscales'] expected_attrs = ( temperature['air'].attrs - if regridder_apply_kws is not None and regridder_apply_kws['keep_attrs'] + if not regridder_apply_kws or regridder_apply_kws.get('keep_attrs') else {} ) assert pyramid['0'].ds.air.attrs == expected_attrs @@ -49,8 +49,30 @@ def test_regridded_pyramid(temperature, regridder_apply_kws): pyramid.to_zarr(MemoryStore()) +def test_regridded_pyramid_with_weights(temperature): + pytest.importorskip('xesmf') + levels = 2 + weights_pyramid = generate_weights_pyramid(temperature.isel(time=0), levels) + pyramid = pyramid_regrid( + temperature, levels=levels, weights_pyramid=weights_pyramid, other_chunks={'time': 2} + ) + assert pyramid.ds.attrs['multiscales'] + assert len(pyramid.ds.attrs['multiscales'][0]['datasets']) == levels + pyramid.to_zarr(MemoryStore()) + + def test_make_grid_ds(): grid = make_grid_ds(0, pixels_per_tile=8) lon_vals = grid.lon_b.values assert np.all((lon_vals[-1, :] - lon_vals[0, :]) < 0.001) + + +@pytest.mark.parametrize('levels', [1, 2]) +@pytest.mark.parametrize('method', ['bilinear', 'conservative']) +def test_generate_weights_pyramid(temperature, levels, method): + weights_pyramid = generate_weights_pyramid(temperature.isel(time=0), levels, method=method) + assert weights_pyramid.ds.attrs['levels'] == levels + assert 
weights_pyramid.ds.attrs['regrid_method'] == method + assert set(weights_pyramid['0'].ds.data_vars) == {'S', 'col', 'row'} + assert 'n_in' in weights_pyramid['0'].ds.attrs and 'n_out' in weights_pyramid['0'].ds.attrs