Skip to content

Commit

Permalink
Change treescope handler interface to use string paths.
Browse files Browse the repository at this point in the history
Treescope's handlers originally used the same path format as JAX keypaths, but the
rendering utilities only need to render them to strings. This changes the format so
that they are always strings.

This is the first step toward decoupling treescope from the rest of Penzai and from JAX.

PiperOrigin-RevId: 653389922
  • Loading branch information
danieldjohnson authored and Penzai Developers committed Jul 17, 2024
1 parent 4ef341b commit e27debf
Show file tree
Hide file tree
Showing 25 changed files with 99 additions and 104 deletions.
4 changes: 2 additions & 2 deletions penzai/treescope/arrayviz/array_autovisualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class ArrayAutovisualizer:
def _autovisualize_namedarray(
self,
named_array: named_axes.NamedArrayBase,
path: tuple[Any, ...] | None,
path: str | None,
label: str,
expand_state: part_interface.ExpandState,
) -> part_interface.RenderableTreePart:
Expand Down Expand Up @@ -213,7 +213,7 @@ def _autovisualize_namedarray(
return custom_rendering.renderable

def __call__(
self, value: Any, path: tuple[Any, ...] | None
self, value: Any, path: str | None
) -> autovisualize.CustomTreescopeVisualization | None:
"""Implementation of an autovisualizer, visualizing arrays."""
with jax.core.ensure_compile_time_eval():
Expand Down
7 changes: 3 additions & 4 deletions penzai/treescope/autovisualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class Autovisualizer(Protocol):

@abc.abstractmethod
def __call__(
self, value: Any, path: tuple[Any, ...] | None
self, value: Any, path: str | None
) -> (
IPythonVisualization
| CustomTreescopeVisualization
Expand All @@ -113,9 +113,8 @@ def __call__(
Args:
value: A value being rendered in treescope.
path: Path to this value from the root, as a JAX keypath. May be None if
this object isn't part of the root PyTree and so treescope doesn't know
how to access it.
path: Optionally, a path to this node, represented as a string that can be
used to reach this node from the root (e.g. ".foo.bar['baz']").
Returns:
A visualization for this subtree, a child autovisualizer to use while
Expand Down
13 changes: 6 additions & 7 deletions penzai/treescope/foldable_representation/common_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,20 @@
RenderableAndLineAnnotations = basic_parts.RenderableAndLineAnnotations


def build_copy_button(path: tuple[Any, ...] | None) -> RenderableTreePart:
def build_copy_button(path: str | None) -> RenderableTreePart:
"""Builds a copy-path button, if `path` is provided and not empty."""
if not path:
return basic_parts.EmptyPart()
else:
return foldable_impl.StringCopyButton(
annotation="Copy path: ",
copy_string=(
"(lambda root: root" + "".join(str(key) for key in path) + ")"
),
copy_string=f"(lambda root: root{path})",
)


def build_custom_foldable_tree_node(
contents: RenderableTreePart,
path: tuple[Any, ...] | None = None,
path: str | None = None,
label: RenderableTreePart = basic_parts.EmptyPart(),
expand_state: part_interface.ExpandState = part_interface.ExpandState.WEAKLY_COLLAPSED,
) -> RenderableAndLineAnnotations:
Expand Down Expand Up @@ -85,7 +83,7 @@ def build_custom_foldable_tree_node(

def build_one_line_tree_node(
line: RenderableAndLineAnnotations | RenderableTreePart | str,
path: tuple[Any, ...] | None = None,
path: str | None = None,
background_color: str | None = None,
background_pattern: str | None = None,
) -> RenderableAndLineAnnotations:
Expand Down Expand Up @@ -149,7 +147,7 @@ def build_foldable_tree_node_from_children(
suffix: RenderableTreePart | str,
comma_separated: bool = False,
force_trailing_comma: bool = False,
path: tuple[Any, ...] | None = None,
path: str | None = None,
background_color: str | None = None,
background_pattern: str | None = None,
first_line_annotation: RenderableTreePart | None = None,
Expand Down Expand Up @@ -188,6 +186,7 @@ def build_foldable_tree_node_from_children(
)

maybe_copy_button = build_copy_button(path)

if isinstance(prefix, str):
prefix = basic_parts.Text(prefix)

Expand Down
4 changes: 2 additions & 2 deletions penzai/treescope/foldable_representation/foldable_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ class HyperlinkTarget(basic_parts.DeferringToChild):
"""

child: RenderableTreePart
keypath: tuple[Any, ...] | None
keypath: str | None

def render_to_html(
self,
Expand Down Expand Up @@ -299,7 +299,7 @@ class NodeHyperlink(basic_parts.DeferringToChild):
"""

child: RenderableTreePart
target_keypath: tuple[Any, ...] | None
target_keypath: str | None

def html_setup_parts(
self, setup_context: HtmlContextForSetup
Expand Down
2 changes: 1 addition & 1 deletion penzai/treescope/foldable_representation/part_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,5 +373,5 @@ class RenderableAndLineAnnotations:
annotations: RenderableTreePart | None = None


NodePath: TypeAlias = tuple[Any, ...]
NodePath: TypeAlias = str
Rendering: TypeAlias = RenderableTreePart | RenderableAndLineAnnotations
2 changes: 1 addition & 1 deletion penzai/treescope/handlers/builtin_atom_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _span_css_rule(self, context: HtmlContextForSetup) -> CSSStyleRule:

def handle_builtin_atoms(
node: Any,
path: tuple[Any, ...] | None,
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
) -> (
part_interface.RenderableTreePart
Expand Down
24 changes: 7 additions & 17 deletions penzai/treescope/handlers/builtin_structure_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from typing import Any, Callable, Optional, Sequence
import warnings

import jax
from penzai.core import dataclass_util
from penzai.treescope import renderer
from penzai.treescope.foldable_representation import basic_parts
Expand All @@ -37,7 +36,7 @@

def _dict_to_foldable(
node: dict[Any, Any],
path: tuple[Any, ...] | None,
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
) -> part_interface.RenderableAndLineAnnotations:
"""Renders a dictionary."""
Expand All @@ -54,9 +53,7 @@ def _dict_to_foldable(
# Last child: only show the comma when the node is expanded.
comma_after = basic_parts.FoldCondition(expanded=basic_parts.Text(","))

child_path = (
None if path is None else (path + (jax.tree_util.DictKey(key),))
)
child_path = None if path is None else f"{path}[{repr(key)}]"
# Figure out whether this key is simple enough to render inline with
# its value.
key_rendering = subtree_renderer(key)
Expand Down Expand Up @@ -117,16 +114,14 @@ def _dict_to_foldable(

def _sequence_or_set_to_foldable(
sequence: dict[Any, Any],
path: tuple[Any, ...] | None,
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
) -> part_interface.RenderableAndLineAnnotations:
"""Renders a sequence or set to a foldable."""

children = []
for i, child in enumerate(sequence):
child_path = (
None if path is None else (path + (jax.tree_util.SequenceKey(i),))
)
child_path = None if path is None else f"{path}[{repr(i)}]"
children.append(subtree_renderer(child, path=child_path))

force_trailing_comma = False
Expand Down Expand Up @@ -197,10 +192,9 @@ def _sequence_or_set_to_foldable(

def build_field_children(
node: dict[Any, Any],
path: tuple[Any, ...] | None,
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
fields_or_attribute_names: Sequence[dataclasses.Field[Any] | str],
key_path_fn: Callable[[str], Any] = jax.tree_util.GetAttrKey,
attr_style_fn: (
Callable[[str], part_interface.RenderableTreePart] | None
) = None,
Expand Down Expand Up @@ -228,10 +222,6 @@ def build_field_children(
fields_or_attribute_names: Sequence of fields or attribute names to render.
Any field with the metadata key "treescope_always_collapse" set to True
will always render collapsed.
key_path_fn: Optional function which maps field names to their JAX keys, if
applicable. This should match their registered keypaths in the PyTree
registry when applicable (although it will also be called for fields that
are not necessarily PyTree children).
attr_style_fn: Optional function which makes attributes to a part that
should render them. If not provided, all parts are rendered as plain text.
Expand All @@ -255,7 +245,7 @@ def build_field_children(

children = []
for i, (field_name, maybe_field) in enumerate(zip(field_names, fields)):
child_path = None if path is None else (path + (key_path_fn(field_name),))
child_path = None if path is None else f"{path}.{field_name}"

if i < len(fields) - 1:
# Not the last child. Always show a comma, and add a space when
Expand Down Expand Up @@ -354,7 +344,7 @@ def parse_color_and_pattern(

def handle_builtin_structures(
node: Any,
path: tuple[Any, ...] | None,
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
) -> (
part_interface.RenderableTreePart
Expand Down
2 changes: 1 addition & 1 deletion penzai/treescope/handlers/canonical_alias_postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

def replace_with_canonical_aliases(
node: Any,
path: tuple[Any, ...] | None,
path: str | None,
node_renderer: renderer.TreescopeSubtreeRenderer,
summarization_threshold: int = 20,
) -> (
Expand Down
2 changes: 1 addition & 1 deletion penzai/treescope/handlers/extension_method_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

def handle_via_penzai_repr_method(
node: Any,
path: tuple[Any, ...] | None,
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
) -> (
part_interface.RenderableTreePart
Expand Down
2 changes: 1 addition & 1 deletion penzai/treescope/handlers/function_reflection_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def format_source_location(

def handle_code_objects_with_reflection(
node: Any,
path: tuple[Any, ...] | None,
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
show_closure_vars: bool = False,
) -> (
Expand Down
4 changes: 2 additions & 2 deletions penzai/treescope/handlers/generic_pytree_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

def handle_arbitrary_pytrees(
node: Any,
path: tuple[Any, ...] | None,
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
) -> (
part_interface.RenderableTreePart
Expand Down Expand Up @@ -55,7 +55,7 @@ def handle_arbitrary_pytrees(
# Then add an extra block that pretty-prints its children.
list_items = []
for key, child in subtrees_with_paths:
child_path = None if path is None else (path + (key,))
child_path = None if path is None else path + str(key)
list_items.append(
basic_parts.siblings_with_annotations(
subtree_renderer(key, path=None),
Expand Down
2 changes: 1 addition & 1 deletion penzai/treescope/handlers/generic_repr_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

def handle_anything_with_repr(
node: Any,
path: tuple[Any, ...] | None,
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
) -> (
part_interface.RenderableTreePart
Expand Down
6 changes: 3 additions & 3 deletions penzai/treescope/handlers/hardcoded_structure_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class IsEnumLike:
def _dataclass_like(
fields: Sequence[str],
node: Any,
path: tuple[Any, ...],
path: str,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
):
"""Renders a dataclass-like object."""
Expand All @@ -113,7 +113,7 @@ def _dataclass_like(

def _enum_like(
node: Any,
path: tuple[Any, ...],
path: str,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
):
"""Renders a enum-like object."""
Expand Down Expand Up @@ -161,7 +161,7 @@ class HardcodedStructureHandler:
def __call__(
self,
node: Any,
path: tuple[Any, ...] | None,
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
) -> (
part_interface.RenderableTreePart
Expand Down
4 changes: 2 additions & 2 deletions penzai/treescope/handlers/ndarray_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

def handle_ndarrays(
node: Any,
path: tuple[Any, ...] | None,
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
) -> (
part_interface.RenderableTreePart
Expand Down Expand Up @@ -137,7 +137,7 @@ def _thunk(placeholder):

def handle_dtype_instances(
node: Any,
path: tuple[Any, ...] | None,
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
) -> (
part_interface.RenderableTreePart
Expand Down
7 changes: 3 additions & 4 deletions penzai/treescope/handlers/penzai/data_effects_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from penzai.treescope.handlers.penzai import struct_handler

_known_handlers: context.ContextualValue[
dict[str, tuple[effect_base.EffectHandler, tuple[Any, ...] | None]] | None
dict[str, tuple[effect_base.EffectHandler, str | None]] | None
] = context.ContextualValue(
module=__name__, qualname="_known_handlers", initial_value=None
)
Expand All @@ -49,7 +49,7 @@

def handle_data_effects_objects(
node: Any,
path: tuple[Any, ...] | None,
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
) -> (
part_interface.RenderableTreePart
Expand All @@ -60,7 +60,7 @@ def handle_data_effects_objects(

def handler_id_interceptor(
node: Any,
path: tuple[Any, ...] | None = None,
path: str | None = None,
*,
handler_id: str,
hyperlink_path=None,
Expand Down Expand Up @@ -132,7 +132,6 @@ def handler_id_interceptor(
hyperlink_path=handler_path,
),
fields_or_attribute_names=fields,
key_path_fn=node.key_for_field,
attr_style_fn=struct_handler.struct_attr_style_fn_for_fields(fields),
)
background_color = node.treescope_color()
Expand Down
3 changes: 1 addition & 2 deletions penzai/treescope/handlers/penzai/layer_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

def handle_layers(
node: Any,
path: tuple[Any, ...] | None,
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
obvious_input_output_structure_types: tuple[type[Any], ...] = (
grouping.CheckStructure,
Expand Down Expand Up @@ -194,7 +194,6 @@ def handle_layers(
path,
subtree_renderer,
fields_or_attribute_names=fields,
key_path_fn=node.key_for_field,
attr_style_fn=struct_handler.struct_attr_style_fn_for_fields(fields),
)

Expand Down
3 changes: 1 addition & 2 deletions penzai/treescope/handlers/penzai/named_axes_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def named_array_and_contained_type_summary(

def handle_named_arrays(
node: Any,
path: tuple[Any, ...] | None,
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
) -> (
part_interface.RenderableTreePart
Expand Down Expand Up @@ -137,7 +137,6 @@ def _make_label(inspect_device_data):
path,
subtree_renderer,
fields_or_attribute_names=fields,
key_path_fn=node.key_for_field,
attr_style_fn=struct_handler.struct_attr_style_fn_for_fields(fields),
)

Expand Down
3 changes: 1 addition & 2 deletions penzai/treescope/handlers/penzai/shapecheck_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _arraystructure_summary(

def handle_arraystructures(
node: Any,
path: tuple[Any, ...] | None,
path: str | None,
subtree_renderer: renderer.TreescopeSubtreeRenderer,
) -> (
part_interface.RenderableTreePart
Expand Down Expand Up @@ -161,7 +161,6 @@ def handle_arraystructures(
path,
subtree_renderer,
fields_or_attribute_names=dataclasses.fields(node),
key_path_fn=node.key_for_field,
)
indented_children = basic_parts.IndentedChildren.build(children)

Expand Down
Loading

0 comments on commit e27debf

Please sign in to comment.