Skip to content

Commit

Permalink
conform to newest versions of xarray-datatree (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
andersy005 authored Mar 2, 2023
1 parent 6746e88 commit 03058b3
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 77 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
ndpyramid/_version.py

# PyInstaller
# Usually these files are written by a python script from a template
Expand Down
2 changes: 1 addition & 1 deletion ndpyramid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# flake8: noqa

from ._version import __version__
from .core import pyramid_coarsen, pyramid_reproject
from .regrid import pyramid_regrid
from ._version import __version__
10 changes: 0 additions & 10 deletions ndpyramid/_version.py

This file was deleted.

24 changes: 15 additions & 9 deletions ndpyramid/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,16 @@ def pyramid_coarsen(
}

# set up pyramid
levels = {}
plevels = {}

# pyramid data
for key, factor in enumerate(factors):
# merge dictionary via union operator
kwargs |= {d: factor for d in dims}
plevels[str(key)] = ds.coarsen(**kwargs).mean()

kwargs.update({d: factor for d in dims})
levels[str(key)] = ds.coarsen(**kwargs).mean()

pyramid = dt.DataTree.from_dict(levels)
pyramid.ds = xr.Dataset(attrs=attrs)

return pyramid
plevels['/'] = xr.Dataset(attrs=attrs)
return dt.DataTree.from_dict(plevels)


def pyramid_reproject(
Expand All @@ -63,6 +61,7 @@ def pyramid_reproject(
resampling: str | dict = 'average',
extra_dim: str = None,
) -> dt.DataTree:

"""Create a multiscale pyramid of a dataset via reprojection.
Parameters
Expand All @@ -88,6 +87,7 @@ def pyramid_reproject(
The multiscale pyramid.
"""

import rioxarray # noqa: F401
from rasterio.transform import Affine
from rasterio.warp import Resampling
Expand All @@ -104,6 +104,7 @@ def pyramid_reproject(
)
}

# Convert resampling from string to dictionary if necessary
if isinstance(resampling, str):
resampling_dict = defaultdict(lambda: resampling)
else:
Expand All @@ -128,21 +129,26 @@ def reproject(da, var):
transform=dst_transform,
)

# create the data array for each level
plevels[lkey] = xr.Dataset(attrs=ds.attrs)
for k, da in ds.items():
if len(da.shape) == 4:
# if extra_dim is not specified, raise an error
if extra_dim is None:
raise ValueError("must specify 'extra_dim' to iterate over 4d data")
da_all = []
for index in ds[extra_dim]:
# reproject each index of the 4th dimension
da_reprojected = reproject(da.sel({extra_dim: index}), k)
da_all.append(da_reprojected)
plevels[lkey][k] = xr.concat(da_all, ds[extra_dim])
else:
# if the data array is not 4D, just reproject it
plevels[lkey][k] = reproject(da, k)

# create the final multiscale pyramid
plevels['/'] = xr.Dataset(attrs=attrs)
pyramid = dt.DataTree.from_dict(plevels)
pyramid.ds = xr.Dataset(attrs=attrs)

pyramid = add_metadata_and_zarr_encoding(
pyramid, levels=levels, pixels_per_tile=pixels_per_tile, other_chunks=other_chunks
Expand Down
15 changes: 5 additions & 10 deletions ndpyramid/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,10 @@ def make_grid_pyramid(levels: int = 6) -> dt.DataTree:
pyramid : dt.DataTree
Multiscale grid definition
"""
plevels = {}
for level in range(levels):
plevels[str(level)] = make_grid_ds(level).chunk(-1)
data = dt.DataTree.from_dict(plevels)

return data
plevels = {
str(level): make_grid_ds(level).chunk(-1) for level in range(levels)
}
return dt.DataTree.from_dict(plevels)


def generate_weights_pyramid(
Expand Down Expand Up @@ -163,10 +161,7 @@ def generate_weights_pyramid(

root = xr.Dataset(attrs={'levels': levels, 'regrid_method': method})
plevels['/'] = root
weights_pyramid = dt.DataTree.from_dict(plevels)


return weights_pyramid
return dt.DataTree.from_dict(plevels)


def pyramid_regrid(
Expand Down
2 changes: 1 addition & 1 deletion ndpyramid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def add_metadata_and_zarr_encoding(
'''
chunks = {'x': pixels_per_tile, 'y': pixels_per_tile}
if other_chunks is not None:
chunks.update(other_chunks)
chunks |= other_chunks

for level in range(levels):
slevel = str(level)
Expand Down
55 changes: 46 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,54 @@

[build-system]
requires = ["setuptools>=64", "setuptools-scm[toml]>=6.2", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "ndpyramid"
description = "A small utility for generating ND array pyramids using Xarray and Zarr"
readme = "README.md"
license = { text = "MIT" }
authors = [{ name = "CarbonPlan", email = "[email protected]" }]
requires-python = ">=3.9"
classifiers = [
"Development Status :: 4 - Beta",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Intended Audience :: Science/Research",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Topic :: Scientific/Engineering",
]
dynamic = ["version"]

dependencies = ["cf_xarray", "xarray-datatree >= 0.0.11", "zarr"]

[project.urls]
repository = "https://github.com/carbonplan/ndpyramid"

[tool.setuptools.packages.find]
include = ["ndpyramid*"]

[tool.setuptools_scm]
version_scheme = "post-release"
local_scheme = "node-and-date"
fallback_version = "999"
write_to = "ndpyramid/_version.py"
write_to_template = '__version__ = "{version}"'


# [tool.setuptools.dynamic]
# version = { attr = "ndpyramid.__version__" }


[tool.black]
line-length = 100
target-version = ['py39']
skip-string-normalization = true


[build-system]
requires = ["setuptools>=45", "wheel", "setuptools_scm>=6.2"]


[tool.ruff]
line-length = 100
target-version = "py39"
Expand Down Expand Up @@ -38,11 +79,7 @@ per-file-ignores = {}
# E402: module level import not at top of file
# E501: line too long - let black worry about that
# E731: do not assign a lambda expression, use a def
ignore = [
"E402",
"E501",
"E731",
]
ignore = ["E402", "E501", "E731"]
select = [
# Pyflakes
"F",
Expand Down
3 changes: 0 additions & 3 deletions requirements.txt

This file was deleted.

34 changes: 0 additions & 34 deletions setup.py

This file was deleted.

0 comments on commit 03058b3

Please sign in to comment.