Skip to content

Commit

Permalink
Adapts most Penzai objects to render using __penzai_repr__ extension …
Browse files Browse the repository at this point in the history
…method.

Also makes some other small changes to the rendering intermediate representation.

PiperOrigin-RevId: 641747917
  • Loading branch information
danieldjohnson authored and Penzai Developers committed Jul 17, 2024
1 parent e27debf commit 215d5ec
Show file tree
Hide file tree
Showing 21 changed files with 323 additions and 288 deletions.
1 change: 0 additions & 1 deletion penzai/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from . import context
from . import dataclass_util
from . import formatting_util
from . import layer
from . import named_axes
from . import partitioning
Expand Down
73 changes: 0 additions & 73 deletions penzai/core/formatting_util.py

This file was deleted.

5 changes: 5 additions & 0 deletions penzai/core/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ def __init_subclass__(cls, **kwargs):
" decorate with `penzai.core.layer.unchecked_layer_call`.)"
)

def __penzai_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.treescope.handlers.penzai import layer_handler # pylint: disable=g-import-not-at-top

return layer_handler.handle_layers(self, path, subtree_renderer)


# Type alias for an arbitrary callable object with the expected signature.
LayerLike: typing.TypeAlias = Callable[[Any], Any]
11 changes: 9 additions & 2 deletions penzai/core/named_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,13 @@ def __iter__(self):
for i in range(self.positional_shape[0]):
yield self[i]

# Rendering
def __penzai_repr__(self, path: str | None, subtree_renderer: Any):
"""Treescope handler for named arrays."""
from penzai.treescope.handlers.penzai import named_axes_handlers # pylint: disable=g-import-not-at-top

return named_axes_handlers.handle_named_arrays(self, path, subtree_renderer)

# Convenience wrappers: Elementwise infix operators.
__lt__ = _nmap_with_doc(operator.lt, "jax.Array.__lt__")
__le__ = _nmap_with_doc(operator.le, "jax.Array.__le__")
Expand Down Expand Up @@ -1316,7 +1323,7 @@ def __iter__(self):


@struct.pytree_dataclass
class NamedArray(struct.Struct, NamedArrayBase):
class NamedArray(NamedArrayBase, struct.Struct):
r"""A multidimensional array with a combination of positional and named axes.
Conceptually, ``NamedArray``\ s can have positional axes like an ordinary
Expand Down Expand Up @@ -1522,7 +1529,7 @@ def tag(self, *names) -> NamedArray:


@struct.pytree_dataclass
class NamedArrayView(struct.Struct, NamedArrayBase):
class NamedArrayView(NamedArrayBase, struct.Struct):
"""A possibly-transposed view of an array with positional and named axes.
This view identifies a particular set of axes in a data array as "virtual
Expand Down
14 changes: 14 additions & 0 deletions penzai/core/shapecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ class Wildcard(struct.Struct):
default=None, metadata={"pytree_node": False}
)

def __penzai_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.treescope.handlers.penzai import shapecheck_handlers # pylint: disable=g-import-not-at-top

return shapecheck_handlers.handle_arraystructures(
self, path, subtree_renderer
)


ANY = Wildcard()

Expand Down Expand Up @@ -350,6 +357,13 @@ def into_pytree(self) -> jax.ShapeDtypeStruct | named_axes.NamedArray:
else:
return jax.ShapeDtypeStruct(self.shape, self.dtype)

def __penzai_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.treescope.handlers.penzai import shapecheck_handlers # pylint: disable=g-import-not-at-top

return shapecheck_handlers.handle_arraystructures(
self, path, subtree_renderer
)


def _abstract_leaf(value: Any) -> ArraySpec | Wildcard:
"""Helper function to get an `ArraySpec` view of a leaf."""
Expand Down
8 changes: 7 additions & 1 deletion penzai/core/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

import jax
from penzai.core import dataclass_util
from penzai.core import formatting_util
from typing_extensions import dataclass_transform

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -667,6 +666,8 @@ def treescope_color(self) -> str | tuple[str, str]:
"""
# By default, we render structs in color if they define __call__.
if hasattr(self, "__call__"):
from penzai.treescope import formatting_util # pylint: disable=g-import-not-at-top

type_string = type(self).__module__ + "." + type(self).__qualname__
return formatting_util.color_from_string(type_string)
else:
Expand All @@ -691,3 +692,8 @@ def _repr_pretty_(self, p, cycle):
if i:
p.break_()
p.text(line)

def __penzai_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.treescope.handlers.penzai import struct_handler # pylint: disable=g-import-not-at-top

return struct_handler.handle_structs(self, path, subtree_renderer)
26 changes: 25 additions & 1 deletion penzai/data_effects/effect_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@

import jax
import ordered_set
from penzai.core import formatting_util
from penzai.core import layer as layer_base
from penzai.core import selectors
from penzai.core import struct
Expand All @@ -100,6 +99,8 @@ def get_effect_color(effect_protocol: type[Any]) -> str:
"""Gets the default color for a given effect (for treescope rendering)."""
if effect_protocol in _EFFECT_COLORS:
return _EFFECT_COLORS[effect_protocol]
from penzai.treescope import formatting_util # pylint: disable=g-import-not-at-top

return formatting_util.color_from_string(effect_protocol.__qualname__)


Expand Down Expand Up @@ -388,6 +389,13 @@ def unhandled_effect_stub(*args, **kwargs):
def treescope_color(self):
return get_effect_color(self.effect_protocol())

def __penzai_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.treescope.handlers.penzai import data_effects_handlers # pylint: disable=g-import-not-at-top

return data_effects_handlers.handle_data_effects_objects(
self, path, subtree_renderer
)


class EffectRuntimeImpl(abc.ABC):
"""Base class for runtime effect implementations.
Expand Down Expand Up @@ -435,6 +443,13 @@ def handler_id(self) -> HandlerId:
def treescope_color(self):
return get_effect_color(self.effect_protocol())

def __penzai_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.treescope.handlers.penzai import data_effects_handlers # pylint: disable=g-import-not-at-top

return data_effects_handlers.handle_data_effects_objects(
self, path, subtree_renderer
)


@struct.pytree_dataclass
class EffectHandler(layer_base.Layer, abc.ABC):
Expand Down Expand Up @@ -494,8 +509,17 @@ def effect_protocol(cls) -> type[Any] | Collection[type[Any]] | None:
)

def treescope_color(self):
from penzai.treescope import formatting_util # pylint: disable=g-import-not-at-top

protocol = self.effect_protocol()
if isinstance(protocol, type):
return get_effect_color(protocol)
else:
return formatting_util.color_from_string(type(self).__qualname__)

def __penzai_repr__(self, path: str | None, subtree_renderer: Any):
from penzai.treescope.handlers.penzai import data_effects_handlers # pylint: disable=g-import-not-at-top

return data_effects_handlers.handle_data_effects_objects(
self, path, subtree_renderer
)
3 changes: 2 additions & 1 deletion penzai/nn/basic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import jax
import jax.numpy as jnp
from penzai.core import formatting_util
from penzai.core import layer
from penzai.core import named_axes
from penzai.core import shapecheck
Expand Down Expand Up @@ -60,6 +59,8 @@ def __call__(
return self.fn(value)

def treescope_color(self) -> str:
from penzai.treescope import formatting_util # pylint: disable=g-import-not-at-top

return formatting_util.color_from_string(repr(self.fn))


Expand Down
7 changes: 6 additions & 1 deletion penzai/nn/grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from typing import Any, Callable, Sequence

import jax
from penzai.core import formatting_util
from penzai.core import layer as layer_base
from penzai.core import selectors
from penzai.core import shapecheck
Expand Down Expand Up @@ -80,6 +79,8 @@ def __call__(self, value: Any) -> Any:
return value

def treescope_color(self) -> str | tuple[str, str]:
from penzai.treescope import formatting_util # pylint: disable=g-import-not-at-top

if type(self) is Sequential: # pylint: disable=unidiomatic-typecheck
return "#cdcdcd", "color-mix(in oklab, #cdcdcd 25%, white)"
else:
Expand Down Expand Up @@ -144,6 +145,8 @@ def __call__(self, value: Any) -> Any:
return value

def treescope_color(self) -> str | tuple[str, str]:
from penzai.treescope import formatting_util # pylint: disable=g-import-not-at-top

accent = formatting_util.color_from_string(self.name)
return accent, f"color-mix(in oklab, {accent} 25%, white)"

Expand Down Expand Up @@ -193,6 +196,8 @@ def __call__(self, value: Any) -> Any:
return value

def treescope_color(self) -> str | tuple[str, str]:
from penzai.treescope import formatting_util # pylint: disable=g-import-not-at-top

if type(self) is CheckedSequential: # pylint: disable=unidiomatic-typecheck
return "#cdcdcd"
else:
Expand Down
4 changes: 0 additions & 4 deletions penzai/pz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@
dataclass_from_attributes,
init_takes_fields,
)
from penzai.core.formatting_util import (
oklch_color,
color_from_string,
)
from penzai.core.layer import (
Layer,
LayerLike,
Expand Down
3 changes: 2 additions & 1 deletion penzai/toolshed/unflaxify.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import flax.typing
import jax
from penzai import pz
from penzai.treescope import formatting_util


@pz.pytree_dataclass
Expand Down Expand Up @@ -194,7 +195,7 @@ def redirecting_interceptor(next_fun, args, kwargs, context):
return output

def treescope_color(self) -> str:
return pz.color_from_string(type(self.module).__name__)
return formatting_util.color_from_string(type(self.module).__name__)


@dataclasses.dataclass
Expand Down
1 change: 1 addition & 0 deletions penzai/treescope/arrayviz/array_autovisualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""An automatic NDArray visualizer using arrayviz."""
from __future__ import annotations

import dataclasses
from typing import Any, Callable, Collection
Expand Down
Loading

0 comments on commit 215d5ec

Please sign in to comment.