diff --git a/penzai/core/tree_util.py b/penzai/core/tree_util.py index 2b10837..0047eb5 100644 --- a/penzai/core/tree_util.py +++ b/penzai/core/tree_util.py @@ -16,6 +16,7 @@ from __future__ import annotations +import dataclasses from typing import Any, Optional import jax @@ -23,6 +24,16 @@ PyTreeDef = jax.tree_util.PyTreeDef +@dataclasses.dataclass(frozen=True) +class CustomGetAttrKey: + """Subclass-friendly variant of jax.tree_util.GetAttrKey.""" + + name: str + + def __str__(self): + return f".{self.name}" + + def tree_flatten_exactly_one_level( tree: Any, ) -> Optional[tuple[list[tuple[Any, Any]], PyTreeDef]]: @@ -66,7 +77,10 @@ def pretty_keystr(keypath: tuple[Any, ...], tree: Any) -> str: parts = [] for key in keypath: if isinstance( - key, jax.tree_util.GetAttrKey | jax.tree_util.FlattenedIndexKey + key, + jax.tree_util.GetAttrKey + | jax.tree_util.FlattenedIndexKey + | CustomGetAttrKey, ): parts.extend(("/", type(tree).__name__)) split = tree_flatten_exactly_one_level(tree) diff --git a/penzai/nn/layer_stack.py b/penzai/nn/layer_stack.py index f511ca9..d2686de 100644 --- a/penzai/nn/layer_stack.py +++ b/penzai/nn/layer_stack.py @@ -17,10 +17,11 @@ from __future__ import annotations import collections +from collections.abc import Hashable import copy import dataclasses import enum -from typing import Any, Callable, Hashable +from typing import Any, Callable import jax from penzai.core import named_axes @@ -39,7 +40,7 @@ class LayerStackVarBehavior(enum.Enum): @dataclasses.dataclass(frozen=True) -class LayerStackGetAttrKey(jax.tree_util.GetAttrKey): +class LayerStackGetAttrKey(pz_tree_util.CustomGetAttrKey): """GetAttrKey for LayerStack with extra metadata. This allows us to identify whether a given PyTree leaf is contained inside a diff --git a/tests/nn/layer_stack_test.py b/tests/nn/layer_stack_test.py index 34250e7..421cacc 100644 --- a/tests/nn/layer_stack_test.py +++ b/tests/nn/layer_stack_test.py @@ -17,6 +17,7 @@ from typing import Any from absl.testing import absltest import chex +import collections import jax from penzai import pz @@ -155,7 +156,18 @@ def builder(init_base_rng, some_value): unbound_layer, layer_vars = pz.unbind_variables(layer) unbound_slot_layer, slot_layer_vars = pz.unbind_variables(slot_layer) - chex.assert_trees_all_equal(unbound_layer, unbound_slot_layer) + # Check as dictionaries to avoid limitations of chex: + unbound_layer_leaves, unbound_layer_treedef = ( + jax.tree_util.tree_flatten_with_path(unbound_layer) + ) + unbound_slot_layer_leaves, unbound_slot_layer_treedef = ( + jax.tree_util.tree_flatten_with_path(unbound_slot_layer) + ) + self.assertEqual(unbound_layer_treedef, unbound_slot_layer_treedef) + chex.assert_trees_all_equal( + collections.OrderedDict(unbound_layer_leaves), + collections.OrderedDict(unbound_slot_layer_leaves), + ) slot_layer_vars_dict = {var.label: var for var in slot_layer_vars} for var in layer_vars: