diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 84ce392ad32..932cab2fa23 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -5,11 +5,12 @@ from pathlib import PurePosixPath from typing import ( TYPE_CHECKING, + Any, Generic, TypeVar, ) -from xarray.core.utils import Frozen, is_dict_like +from xarray.core.utils import is_dict_like if TYPE_CHECKING: from xarray.core.types import T_DataArray @@ -44,6 +45,73 @@ def __init__(self, *pathsegments): Tree = TypeVar("Tree", bound="TreeNode") +class Children(Mapping[str, Tree], Generic[Tree]): + """ + Dictionary-like container for the immediate children of a single DataTree node. + + This collection can be passed directly to the :py:class:`~xarray.DataTree` constructor via its `children` argument. + """ + + _treenode: Tree + + # TODO add slots? + # __slots__ = ("_data",) + + def __init__(self, treenode: Tree): + self._treenode = treenode + + @property + def _names(self) -> list[str]: + return list(self._treenode._children.keys()) + + def __iter__(self) -> Iterator[str]: + return iter(self._names) + + def __len__(self) -> int: + return len(self._names) + + def __contains__(self, key: str) -> bool: + return key in self._names + + def __repr__(self) -> str: + return "\n".join(["Children:"] + [f" {name}" for name in self._names]) + + def __getitem__(self, key: str) -> Tree: + return self._treenode._children[key] + + def __delitem__(self, key: str) -> None: + if key in self._names: + child = self._treenode._children[key] + del self._treenode._children[key] + child.orphan() + else: + raise KeyError(key) + + def __setitem__(self, key: str, value: Any) -> None: + self.update({key: value}) + + def update(self, other: Mapping[str, Tree]) -> None: + """Update with other child nodes.""" + + # TODO forbid strings with `/` in here? + + if not len(other): + return + + children = self._treenode._children.copy() + children.update(other) + self._treenode.children = children + + def _ipython_key_completions_(self): + """Provide method for the key-autocompletions in IPython.""" + # TODO + return [ + key + for key in self._treenode._ipython_key_completions_() + if key not in self._treenode.variables + ] + + class TreeNode(Generic[Tree]): """ Base class representing a node of a tree, with methods for traversing and altering the tree. @@ -160,9 +228,9 @@ def orphan(self) -> None: self._set_parent(new_parent=None) @property - def children(self: Tree) -> Mapping[str, Tree]: + def children(self: Tree) -> Children[str, Tree]: """Child nodes of this node, stored under a mapping via their names.""" - return Frozen(self._children) + return Children(self) @children.setter def children(self: Tree, children: Mapping[str, Tree]) -> None: diff --git a/xarray/tests/test_treenode.py b/xarray/tests/test_treenode.py index d9d581cc314..02d478c24de 100644 --- a/xarray/tests/test_treenode.py +++ b/xarray/tests/test_treenode.py @@ -1,12 +1,19 @@ from __future__ import annotations from collections.abc import Iterator +from textwrap import dedent from typing import cast import pytest from xarray.core.iterators import LevelOrderIter -from xarray.core.treenode import InvalidTreeError, NamedNode, NodePath, TreeNode +from xarray.core.treenode import ( + Children, + InvalidTreeError, + NamedNode, + NodePath, + TreeNode, +) class TestFamilyTree: @@ -224,6 +231,74 @@ def test_overwrite_child(self): assert marys_evil_twin.parent is john +class TestChildren: + def test_properties(self): + sue: TreeNode = TreeNode() + mary: TreeNode = TreeNode(children={"Sue": sue}) + kate: TreeNode = TreeNode() + john = TreeNode(children={"Mary": mary, "Kate": kate}) + + children = john.children + assert isinstance(children, Children) + + # len + assert len(children) == 2 + + # iter + assert list(children) == ["Mary", "Kate"] + + assert john.children["Mary"] is mary + assert john.children["Kate"] is kate + + assert "Mary" in john.children + assert "Kate" in john.children + assert 0 not in john.children + assert "foo" not in john.children + + # only immediate children should be accessible + assert "sue" not in john.children + + with pytest.raises(KeyError): + children["foo"] + with pytest.raises(KeyError): + children[0] + + # repr + expected = dedent( + """\ + Children: + Mary + Kate""" + ) + actual = repr(children) + assert expected == actual + + def test_modify(self): + sue: TreeNode = TreeNode() + mary: TreeNode = TreeNode(children={"Sue": sue}) + kate: TreeNode = TreeNode() + john = TreeNode(children={"Mary": mary, "Kate": kate}) + + children = john.children + + # test assignment + ashley: TreeNode = TreeNode() + children["Ashley"] = ashley + assert john.children["Ashley"] is ashley + + # test deletion + del children["Ashley"] + assert "Ashley" not in john.children + + # test constructor + john2 = TreeNode(children=children) + assert john2.children == children + + def test_modify_below_root(self): + # TODO test that modifying .children doesn't affect grandparent + ... + + class TestPruning: def test_del_child(self): john: TreeNode = TreeNode()