Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate datatree assertions/extensions/formatting #8967

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,17 @@ Bug fixes

Internal Changes
~~~~~~~~~~~~~~~~
- Migrates ``formatting_html`` functionality for `DataTree` into ``xarray/core`` (:pull: `8930`)
- Migrates ``formatting_html`` functionality for ``DataTree`` into ``xarray/core`` (:pull:`8930`)
By `Eni Awowale <https://github.com/eni-awowale>`_, `Julia Signell <https://github.com/jsignell>`_
and `Tom Nicholas <https://github.com/TomNicholas>`_.
- Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`)
By `Matt Savoie <https://github.com/flamingbear>`_, `Owen Littlejohns
<https://github.com/owenlittlejohns>` and `Tom Nicholas <https://github.com/TomNicholas>`_.
<https://github.com/owenlittlejohns>`_ and `Tom Nicholas <https://github.com/TomNicholas>`_.
- Migrates ``extensions``, ``formatting`` and ``datatree_render`` functionality for
``DataTree`` into ``xarray/core``. Also migrates ``testing`` functionality into
``xarray/testing/assertions`` for ``DataTree``. (:pull:`8967`)
By `Owen Littlejohns <https://github.com/owenlittlejohns>`_ and
`Tom Nicholas <https://github.com/TomNicholas>`_.


.. _whats-new.2024.03.0:
Expand Down
6 changes: 3 additions & 3 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
check_isomorphic,
map_over_subtree,
)
from xarray.core.datatree_render import RenderDataTree
from xarray.core.formatting import datatree_repr
from xarray.core.formatting_html import (
datatree_repr as datatree_repr_html,
)
Expand All @@ -40,13 +42,11 @@
)
from xarray.core.variable import Variable
from xarray.datatree_.datatree.common import TreeAttrAccessMixin
from xarray.datatree_.datatree.formatting import datatree_repr
from xarray.datatree_.datatree.ops import (
DataTreeArithmeticMixin,
MappedDatasetMethodsMixin,
MappedDataWithCoords,
)
from xarray.datatree_.datatree.render import RenderTree

try:
from xarray.core.variable import calculate_dimensions
Expand Down Expand Up @@ -1451,7 +1451,7 @@ def pipe(

def render(self):
    """
    Print tree structure, including any data stored at each node.

    Every node reachable from ``self`` is printed as ``DataTree('<name>')``
    behind its tree-drawing prefix. The lines of that node's dataset repr,
    except the first one, are printed underneath using the fill prefix so
    that they line up with the node label.
    """
    for pre, fill, node in RenderDataTree(self):
        # Label the row with the node being visited; using ``self.name``
        # here would print the root's name on every row.
        print(f"{pre}DataTree('{node.name}')")
        # Split the repr into lines before slicing: slicing the string
        # itself would iterate over single characters. The first line is
        # skipped (presumably the type header of the dataset repr -- confirm).
        for ds_line in repr(node.ds).splitlines()[1:]:
            print(f"{fill}{ds_line}")
Expand Down
37 changes: 3 additions & 34 deletions xarray/core/datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import functools
import sys
from itertools import repeat
from textwrap import dedent
from typing import TYPE_CHECKING, Callable

from xarray import DataArray, Dataset
from xarray.core.iterators import LevelOrderIter
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.formatting import diff_treestructure
from xarray.core.treenode import NodePath, TreeNode

if TYPE_CHECKING:
Expand Down Expand Up @@ -71,37 +71,6 @@ def check_isomorphic(
raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + diff)


def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str:
    """
    Describe the first structural difference found between two trees.

    The nodes of ``a`` and ``b`` are visited pairwise in level order. For each
    pair, the node names are compared (only when ``require_names_equal`` is
    true) and then the number of children. The first mismatch is reported as
    a two-line message; if no mismatch is found an empty string is returned.

    Note (from the review discussion): this function was moved into
    ``xarray/core/formatting.py`` to avoid a circular dependency issue.
    """
    # "Level-order" means breadth-first from the root. Pairing nodes this way
    # implicitly assumes an ordered tree, which holds as long as children are
    # stored in a tuple or list rather than in a set.
    for left, right in zip(LevelOrderIter(a), LevelOrderIter(b)):
        left_path, right_path = left.path, right.path

        if require_names_equal:
            if left.name != right.name:
                return dedent(
                    f"""\
                    Node '{left_path}' in the left object has name '{left.name}'
                    Node '{right_path}' in the right object has name '{right.name}'"""
                )

        n_left, n_right = len(left.children), len(right.children)
        if n_left != n_right:
            return dedent(
                f"""\
                Number of children on node '{left_path}' of the left object: {n_left}
                Number of children on node '{right_path}' of the right object: {n_right}"""
            )

    return ""


def map_over_subtree(func: Callable) -> Callable:
"""
Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees.
Expand Down
266 changes: 266 additions & 0 deletions xarray/core/datatree_render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
"""
String Tree Rendering. Copied from anytree.

Minor changes to `RenderDataTree` include accessing `children.values()`, and
type hints.

"""

from __future__ import annotations

from collections import namedtuple
from collections.abc import Iterable, Iterator
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from xarray.core.datatree import DataTree

# One rendered row of the tree: `pre` is the prefix drawn in front of the
# node's first line, `fill` is the prefix for any further lines belonging to
# the same node, and `node` is the tree node itself.
Row = namedtuple("Row", ("pre", "fill", "node"))


class AbstractStyle:
    """Holds the three equally wide strings used to draw a tree."""

    def __init__(self, vertical: str, cont: str, end: str):
        """
        Tree Render Style.

        Args:
            vertical: Sign for vertical line.
            cont: Chars for a continued branch.
            end: Chars for the last branch.

        The three strings are stored as attributes of the same name and are
        asserted to have equal length.
        """
        super().__init__()
        self.vertical, self.cont, self.end = vertical, cont, end
        assert not (
            len(cont) != len(vertical) or len(vertical) != len(end)
        ), f"'{vertical}', '{cont}' and '{end}' need to have equal length"

    @property
    def empty(self) -> str:
        """Blank placeholder that is as wide as the `end` string."""
        width = len(self.end)
        return width * " "

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}()"


class ContStyle(AbstractStyle):
    def __init__(self):
        """
        Continued style, without gaps.

        >>> from xarray.core.datatree import DataTree
        >>> from xarray.core.datatree_render import RenderDataTree
        >>> root = DataTree(name="root")
        >>> s0 = DataTree(name="sub0", parent=root)
        >>> s0b = DataTree(name="sub0B", parent=s0)
        >>> s0a = DataTree(name="sub0A", parent=s0)
        >>> s1 = DataTree(name="sub1", parent=root)
        >>> print(RenderDataTree(root))
        DataTree('root', parent=None)
        ├── DataTree('sub0')
        │   ├── DataTree('sub0B')
        │   └── DataTree('sub0A')
        └── DataTree('sub1')
        """
        # AbstractStyle asserts that all three pieces have the same length,
        # so the vertical bar is padded with three spaces to be four
        # characters wide like the two branch pieces.
        super().__init__("\u2502   ", "\u251c\u2500\u2500 ", "\u2514\u2500\u2500 ")


class RenderDataTree:
    def __init__(
        self,
        node: DataTree,
        style=ContStyle(),
        childiter: type = list,
        maxlevel: int | None = None,
    ):
        """
        Render tree starting at `node`.

        Keyword Args:
            style (AbstractStyle): Render Style. A value that is not an
                `AbstractStyle` instance is called with no arguments to
                obtain one.
            childiter: Child iterator. Note, due to the use of
                node.children.values(), Iterables that change the order of
                children cannot be used (e.g., `reversed`).
            maxlevel: Limit rendering to this depth.

        :any:`RenderDataTree` is an iterator, returning a tuple with 3 items:

        `pre`
            tree prefix.
        `fill`
            filling for multiline entries.
        `node`
            :any:`NodeMixin` object.

        It is up to the user to assemble these parts to a whole.

        Note: the examples below are a bit shorter than the originals from
        anytree. That is because using node.children.values() gets a
        ValuesView, which isn't compatible with iterables like `reversed`
        that alter the order of the items in the iterable.

        Examples
        --------

        >>> from xarray import Dataset
        >>> from xarray.core.datatree import DataTree
        >>> from xarray.core.datatree_render import RenderDataTree
        >>> root = DataTree(name="root", data=Dataset({"a": 0, "b": 1}))
        >>> s0 = DataTree(name="sub0", parent=root, data=Dataset({"c": 2, "d": 3}))
        >>> s0b = DataTree(name="sub0B", parent=s0, data=Dataset({"e": 4}))
        >>> s0a = DataTree(name="sub0A", parent=s0, data=Dataset({"f": 5, "g": 6}))
        >>> s1 = DataTree(name="sub1", parent=root, data=Dataset({"h": 7}))

        # Simple one line:

        >>> for pre, _, node in RenderDataTree(root):
        ...     print(f"{pre}{node.name}")
        ...
        root
        ├── sub0
        │   ├── sub0B
        │   └── sub0A
        └── sub1

        # Multiline:

        >>> for pre, fill, node in RenderDataTree(root):
        ...     print(f"{pre}{node.name}")
        ...     for variable in node.variables:
        ...         print(f"{fill}{variable}")
        ...
        root
        a
        b
        ├── sub0
        │   c
        │   d
        │   ├── sub0B
        │   │   e
        │   └── sub0A
        │       f
        │       g
        └── sub1
            h

        :any:`by_attr` simplifies attribute rendering and supports multiline:

        >>> print(RenderDataTree(root).by_attr())
        root
        ├── sub0
        │   ├── sub0B
        │   └── sub0A
        └── sub1

        # `maxlevel` limits the depth of the tree:

        >>> print(RenderDataTree(root, maxlevel=2).by_attr("name"))
        root
        ├── sub0
        └── sub1
        """
        # Accept either a ready-made style object or something callable that
        # produces one (e.g. a style class).
        resolved_style = style if isinstance(style, AbstractStyle) else style()
        self.node, self.style = node, resolved_style
        self.childiter, self.maxlevel = childiter, maxlevel

    def __iter__(self) -> Iterator[Row]:
        # The root has no ancestors, hence the empty `continues` tuple.
        return self.__next(self.node, ())

    def __next(
        self, node: DataTree, continues: tuple[bool, ...], level: int = 0
    ) -> Iterator[Row]:
        # Emit the row for this node first, then descend depth-first.
        yield RenderDataTree.__item(node, continues, self.style)
        child_nodes = node.children.values()
        child_level = level + 1
        if not child_nodes:
            return
        if self.maxlevel is not None and child_level >= self.maxlevel:
            return
        for child, is_last in _is_last(self.childiter(child_nodes)):
            # Each ancestor level records whether more siblings follow, which
            # decides between a vertical line and blank space in the prefix.
            yield from self.__next(
                child, (*continues, not is_last), level=child_level
            )

    @staticmethod
    def __item(
        node: DataTree, continues: tuple[bool, ...], style: AbstractStyle
    ) -> Row:
        # The root row carries no prefix at all.
        if not continues:
            return Row("", "", node)
        columns = [style.vertical if cont else style.empty for cont in continues]
        # The last column is replaced by a branch piece on the node's own
        # line, but kept as-is for the node's continuation lines.
        branch = style.cont if continues[-1] else style.end
        return Row("".join(columns[:-1]) + branch, "".join(columns), node)

    def __str__(self) -> str:
        return str(self.node)

    def __repr__(self) -> str:
        parts = ", ".join(
            (
                repr(self.node),
                f"style={self.style!r}",
                f"childiter={self.childiter!r}",
            )
        )
        return f"{self.__class__.__name__}({parts})"

    def by_attr(self, attrname: str = "name") -> str:
        """
        Return rendered tree with node attribute `attrname`.

        If `attrname` is callable it is called with each node; otherwise the
        attribute of that name is read from the node, falling back to an
        empty string. A list or tuple value is used as the node's lines, any
        other value is converted to a string and split on newlines.

        Examples
        --------

        >>> from xarray import Dataset
        >>> from xarray.core.datatree import DataTree
        >>> from xarray.core.datatree_render import RenderDataTree
        >>> root = DataTree(name="root")
        >>> s0 = DataTree(name="sub0", parent=root)
        >>> s0b = DataTree(
        ...     name="sub0B", parent=s0, data=Dataset({"foo": 4, "bar": 109})
        ... )
        >>> s0a = DataTree(name="sub0A", parent=s0)
        >>> s1 = DataTree(name="sub1", parent=root)
        >>> s1a = DataTree(name="sub1A", parent=s1)
        >>> s1b = DataTree(name="sub1B", parent=s1, data=Dataset({"bar": 8}))
        >>> s1c = DataTree(name="sub1C", parent=s1)
        >>> s1ca = DataTree(name="sub1Ca", parent=s1c)
        >>> print(RenderDataTree(root).by_attr("name"))
        root
        ├── sub0
        │   ├── sub0B
        │   └── sub0A
        └── sub1
            ├── sub1A
            ├── sub1B
            └── sub1C
                └── sub1Ca
        """

        def render_rows() -> Iterator[str]:
            for pre, fill, node in self:
                if callable(attrname):
                    value = attrname(node)
                else:
                    value = getattr(node, attrname, "")
                if isinstance(value, (list, tuple)):
                    lines = value
                else:
                    lines = str(value).split("\n")
                # First line gets the branch prefix, the rest the fill prefix.
                yield f"{pre}{lines[0]}"
                yield from (f"{fill}{extra}" for extra in lines[1:])

        return "\n".join(render_rows())


def _is_last(iterable: Iterable) -> Iterator[tuple[DataTree, bool]]:
iter_ = iter(iterable)
try:
nextitem = next(iter_)
except StopIteration:
pass
else:
item = nextitem
while True:
try:
nextitem = next(iter_)
yield item, False
except StopIteration:
yield nextitem, True
break
item = nextitem
Loading
Loading