From 32c2fb7ea9b98de7d3c827a9717f9245bae2b8af Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Sat, 18 Jan 2025 10:50:47 +0100 Subject: [PATCH] allow overriding the grid info (#63) * accessor method for decoding while possibly overriding the grid info * consolidate the treatment of `grid_info` * docstring for `decode` * don't modify the accessor * add a test for `decode` * expose DGGSInfo * add the rotation * Revert "add the rotation" This reverts commit df221522f5131bf86a3ab580d8fdc42a914c3e87. * expect `level` instead of `resolution` * forward to the accessor version * sync documentation between `decode` and the accessor's `decode` --- xdggs/accessor.py | 24 +++++++++++ xdggs/index.py | 27 +++++++----- xdggs/tests/test_accessor.py | 84 ++++++++++++++++++++++++++++++++++++ 3 files changed, 124 insertions(+), 11 deletions(-) diff --git a/xdggs/accessor.py b/xdggs/accessor.py index 6aac600..f7fcf83 100644 --- a/xdggs/accessor.py +++ b/xdggs/accessor.py @@ -29,6 +29,30 @@ def __init__(self, obj: xr.Dataset | xr.DataArray): self._name = name self._index = index + def decode(self, grid_info=None, *, name="cell_ids") -> xr.Dataset | xr.DataArray: + """decode the DGGS cell ids + + Parameters + ---------- + grid_info : dict or DGGSInfo, optional + Override the grid parameters on the dataset. Useful to set attributes on + the dataset. + name : str, default: "cell_ids" + The name of the coordinate containing the cell ids. + + Returns + ------- + obj : xarray.DataArray or xarray.Dataset + The object with a DGGS index on the cell id coordinate. + """ + var = self._obj[name] + if isinstance(grid_info, DGGSInfo): + grid_info = grid_info.to_dict() + if isinstance(grid_info, dict): + var.attrs = grid_info + + return self._obj.drop_indexes(name, errors="ignore").set_xindex(name, DGGSIndex) + @property def index(self) -> DGGSIndex: """The DGGSIndex instance for this Dataset or DataArray. diff --git a/xdggs/index.py b/xdggs/index.py index 13149d9..b0f71f2 100644 --- a/xdggs/index.py +++ b/xdggs/index.py @@ -9,27 +9,32 @@ from xdggs.utils import GRID_REGISTRY, _extract_cell_id_variable -def decode(ds): +def decode(ds, grid_info=None, *, name="cell_ids"): """ decode grid parameters and create a DGGS index Parameters ---------- ds : xarray.Dataset - The input dataset. Must contain a `"cell_ids"` coordinate with at least - the attributes `grid_name` and `resolution`. + The input dataset. Must contain a coordinate for the cell ids with at + least the attributes `grid_name` and `level`. + grid_info : dict or DGGSInfo, optional + Override the grid parameters on the dataset. Useful to set attributes on + the dataset. + name : str, default: "cell_ids" + The name of the coordinate containing the cell ids. Returns ------- - decoded : xarray.Dataset - The input dataset with a DGGS index on the ``"cell_ids"`` coordinate. - """ - - variable_name = "cell_ids" + decoded : xarray.DataArray or xarray.Dataset + The input dataset with a DGGS index on the cell id coordinate. - return ds.drop_indexes(variable_name, errors="ignore").set_xindex( - variable_name, DGGSIndex - ) + See Also + -------- + xarray.Dataset.dggs.decode + xarray.DataArray.dggs.decode + """ + return ds.dggs.decode(name=name, grid_info=grid_info) class DGGSIndex(Index): diff --git a/xdggs/tests/test_accessor.py b/xdggs/tests/test_accessor.py index 306e6d4..2083e48 100644 --- a/xdggs/tests/test_accessor.py +++ b/xdggs/tests/test_accessor.py @@ -4,6 +4,90 @@ import xdggs +@pytest.mark.parametrize( + ["obj", "grid_info", "name"], + ( + pytest.param( + xr.Dataset( + coords={ + "cell_ids": ( + "cells", + [1], + { + "grid_name": "healpix", + "level": 1, + "indexing_scheme": "ring", + }, + ) + } + ), + None, + None, + id="dataset-from attrs-standard name", + ), + pytest.param( + xr.DataArray( + [0.1], + coords={ + "cell_ids": ( + "cells", + [1], + { + "grid_name": "healpix", + "level": 1, + "indexing_scheme": "ring", + }, + ) + }, + dims="cells", + ), + None, + None, + id="dataarray-from attrs-standard name", + ), + pytest.param( + xr.Dataset( + coords={ + "zone_ids": ( + "zones", + [1], + { + "grid_name": "healpix", + "level": 1, + "indexing_scheme": "ring", + }, + ) + } + ), + None, + "zone_ids", + id="dataset-from attrs-custom name", + ), + pytest.param( + xr.Dataset(coords={"cell_ids": ("cells", [1])}), + {"grid_name": "healpix", "level": 1, "indexing_scheme": "ring"}, + None, + id="dataset-dict-standard name", + ), + ), +) +def test_decode(obj, grid_info, name) -> None: + kwargs = {} + if name is not None: + kwargs["name"] = name + + if isinstance(grid_info, dict): + expected_grid_info = grid_info + elif isinstance(grid_info, xdggs.DGGSInfo): + expected_grid_info = grid_info.to_dict() + else: + expected_grid_info = obj[name if name is not None else "cell_ids"].attrs + + actual = obj.dggs.decode(grid_info, **kwargs) + assert any(isinstance(index, xdggs.DGGSIndex) for index in actual.xindexes.values()) + assert actual.dggs.grid_info.to_dict() == expected_grid_info + + @pytest.mark.parametrize( ["obj", "expected"], (