"
@@ -410,26 +399,84 @@
{
"data": {
"text/html": [
- "
"
+ " MLP Summary \n",
+ "┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃ path ┃ type ┃ BatchStat ┃ Param ┃ RngState ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ bn │ BatchNorm │ mean: float32[5,32] │ bias: float32[5,32] │ │\n",
+ "│ │ │ var: float32[5,32] │ scale: float32[5,32] │ │\n",
+ "│ │ │ │ │ │\n",
+ "│ │ │ 320 (1.3 KB) │ 320 (1.3 KB) │ │\n",
+ "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
+ "│ dropout/rngs/default │ RngStream │ │ │ count: │\n",
+ "│ │ │ │ │ tag: default │\n",
+ "│ │ │ │ │ value: uint32[5] │\n",
+ "│ │ │ │ │ key: │\n",
+ "│ │ │ │ │ tag: default │\n",
+ "│ │ │ │ │ value: key<fry>[5] │\n",
+ "│ │ │ │ │ │\n",
+ "│ │ │ │ │ 10 (60 B) │\n",
+ "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
+ "│ linear1 │ Linear │ │ b: float32[5,32] │ │\n",
+ "│ │ │ │ w: float32[5,10,32] │ │\n",
+ "│ │ │ │ │ │\n",
+ "│ │ │ │ 1,760 (7.0 KB) │ │\n",
+ "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
+ "│ linear2 │ Linear │ │ b: float32[5,10] │ │\n",
+ "│ │ │ │ w: float32[5,32,10] │ │\n",
+ "│ │ │ │ │ │\n",
+ "│ │ │ │ 1,650 (6.6 KB) │ │\n",
+ "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
+ "│ │ Total │ 320 (1.3 KB) │ 3,730 (14.9 KB) │ 10 (60 B) │\n",
+ "└──────────────────────┴───────────┴─────────────────────┴──────────────────────┴──────────────────────┘\n",
+ " \n",
+ " Total Parameters: 4,060 (16.3 KB) \n",
+ "
\n"
],
"text/plain": [
- ""
+ "\u001b[3m MLP Summary \u001b[0m\n",
+ "┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1mpath \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mtype \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mBatchStat \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mParam \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mRngState \u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ bn │ BatchNorm │ mean: \u001b[2mfloat32\u001b[0m[5,32] │ bias: \u001b[2mfloat32\u001b[0m[5,32] │ │\n",
+ "│ │ │ var: \u001b[2mfloat32\u001b[0m[5,32] │ scale: \u001b[2mfloat32\u001b[0m[5,32] │ │\n",
+ "│ │ │ │ │ │\n",
+ "│ │ │ \u001b[1m320 \u001b[0m\u001b[1;2m(1.3 KB)\u001b[0m │ \u001b[1m320 \u001b[0m\u001b[1;2m(1.3 KB)\u001b[0m │ │\n",
+ "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
+ "│ dropout/rngs/default │ RngStream │ │ │ count: │\n",
+ "│ │ │ │ │ tag: default │\n",
+ "│ │ │ │ │ value: \u001b[2muint32\u001b[0m[5] │\n",
+ "│ │ │ │ │ key: │\n",
+ "│ │ │ │ │ tag: default │\n",
+ "│ │ │ │ │ value: \u001b[2mkey\u001b[0m[5] │\n",
+ "│ │ │ │ │ │\n",
+ "│ │ │ │ │ \u001b[1m10 \u001b[0m\u001b[1;2m(60 B)\u001b[0m │\n",
+ "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
+ "│ linear1 │ Linear │ │ b: \u001b[2mfloat32\u001b[0m[5,32] │ │\n",
+ "│ │ │ │ w: \u001b[2mfloat32\u001b[0m[5,10,32] │ │\n",
+ "│ │ │ │ │ │\n",
+ "│ │ │ │ \u001b[1m1,760 \u001b[0m\u001b[1;2m(7.0 KB)\u001b[0m │ │\n",
+ "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
+ "│ linear2 │ Linear │ │ b: \u001b[2mfloat32\u001b[0m[5,10] │ │\n",
+ "│ │ │ │ w: \u001b[2mfloat32\u001b[0m[5,32,10] │ │\n",
+ "│ │ │ │ │ │\n",
+ "│ │ │ │ \u001b[1m1,650 \u001b[0m\u001b[1;2m(6.6 KB)\u001b[0m │ │\n",
+ "├──────────────────────┼───────────┼─────────────────────┼──────────────────────┼──────────────────────┤\n",
+ "│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m Total\u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m320 \u001b[0m\u001b[1;2m(1.3 KB)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m3,730 \u001b[0m\u001b[1;2m(14.9 KB)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m10 \u001b[0m\u001b[1;2m(60 B)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\n",
+ "└──────────────────────┴───────────┴─────────────────────┴──────────────────────┴──────────────────────┘\n",
+ "\u001b[1m \u001b[0m\n",
+ "\u001b[1m Total Parameters: 4,060 \u001b[0m\u001b[1;2m(16.3 KB)\u001b[0m\u001b[1m \u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
- "data": {
- "text/html": [
- "
"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
}
],
"source": [
@@ -481,7 +528,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -493,7 +540,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -542,7 +589,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -554,7 +601,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -566,7 +613,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -667,7 +714,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -679,7 +726,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -691,7 +738,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -703,7 +750,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
diff --git a/docs_nnx/nnx_basics.md b/docs_nnx/nnx_basics.md
index 61b96e2d34..51e0cda53f 100644
--- a/docs_nnx/nnx_basics.md
+++ b/docs_nnx/nnx_basics.md
@@ -12,18 +12,7 @@ jupytext:
Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in [JAX](https://jax.readthedocs.io/). It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home.
-In this guide you will learn about:
-
-- The Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) system: An example of creating and initializing a custom `Linear` layer.
- - Stateful computation: An example of creating a Flax [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and updating its value (such as state updates needed during the forward pass).
- - Nested [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s: An MLP example with `Linear`, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout), and [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layers.
- - Model surgery: An example of replacing custom `Linear` layers inside a model with custom `LoraLinear` layers.
-- Flax transformations: An example of using [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) for automatic state management.
- - [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) over layers.
-- The Flax NNX Functional API: An example of a custom `StatefulLinear` layer with [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s with fine-grained control over the state.
- - [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef).
- - [`split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge), and `update`
- - Fine-grained [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) control: An example of using [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type `Filter`s ([`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/guides/filters_guide.html)) to split into multiple [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s.
+To begin, install Flax with `pip` and import necessary dependencies:
## Setup
@@ -106,7 +95,7 @@ to handle them, as demonstrated in later sections of this guide.
Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s can be used to compose other `Module`s in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.
-The example below shows how to define a simple `MLP` by subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The model consists of two `Linear` layers, an [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer:
+The example below shows how to define a simple `MLP` Module consisting of two `Linear` layers, an [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer:
```{code-cell} ipython3
class MLP(nnx.Module):
diff --git a/flax/linen/summary.py b/flax/linen/summary.py
index d6676729f0..5d1b214249 100644
--- a/flax/linen/summary.py
+++ b/flax/linen/summary.py
@@ -48,6 +48,13 @@
LogicalNames,
)
+# Detect once, at import time, whether an IPython shell is active. The result
+# is used further down as the default for the ``force_jupyter`` console option.
+try:
+ from IPython import get_ipython
+
+ # True only when get_ipython() hands back a shell object rather than None.
+ in_ipython = get_ipython() is not None
+except ImportError:
+ # IPython could not be imported, so no IPython shell can be active.
+ in_ipython = False
+
class _ValueRepresentation(ABC):
"""A class that represents a value in the summary table."""
@@ -242,11 +249,6 @@ def tabulate(
Total Parameters: 50 (200 B)
-
- **Note**: rows order in the table does not represent execution order,
- instead it aligns with the order of keys in `variables` which are sorted
- alphabetically.
-
**Note**: `vjp_flops` returns `0` if the module is not differentiable.
Args:
@@ -267,7 +269,9 @@ def tabulate(
mutable.
console_kwargs: An optional dictionary with additional keyword arguments
that are passed to `rich.console.Console` when rendering the table.
- Default arguments are `{'force_terminal': True, 'force_jupyter': False}`.
+ By default ``'force_terminal'`` is ``True``, and ``'force_jupyter'`` is
+ ``True`` when an IPython shell is active (for example in a Jupyter
+ notebook) and ``False`` otherwise.
table_kwargs: An optional dictionary with additional keyword arguments that
are passed to `rich.table.Table` constructor.
column_kwargs: An optional dictionary with additional keyword arguments that
@@ -564,7 +568,7 @@ def _render_table(
non_params_cols: list[str],
) -> str:
"""A function that renders a Table to a string representation using rich."""
- console_kwargs = {'force_terminal': True, 'force_jupyter': False}
+ console_kwargs = {'force_terminal': True, 'force_jupyter': in_ipython}
if console_extras is not None:
console_kwargs.update(console_extras)
diff --git a/flax/nnx/filterlib.py b/flax/nnx/filterlib.py
index 63ed371be9..1028efb2b1 100644
--- a/flax/nnx/filterlib.py
+++ b/flax/nnx/filterlib.py
@@ -54,7 +54,9 @@ def to_predicate(filter: Filter) -> Predicate:
else:
raise TypeError(f'Invalid collection filter: {filter:!r}. ')
-def filters_to_predicates(filters: tuple[Filter, ...]) -> tuple[Predicate, ...]:
+def filters_to_predicates(
+ filters: tp.Sequence[Filter],
+) -> tuple[Predicate, ...]:
for i, filter_ in enumerate(filters):
if filter_ in (..., True) and i != len(filters) - 1:
remaining_filters = filters[i + 1 :]
diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py
index a29999d34f..8cc272f8eb 100644
--- a/flax/nnx/graph.py
+++ b/flax/nnx/graph.py
@@ -24,7 +24,7 @@
import numpy as np
import typing_extensions as tpe
-from flax.nnx import filterlib, reprlib
+from flax.nnx import filterlib, reprlib, visualization
from flax.nnx.proxy_caller import (
ApplyCaller,
CallableProxy,
@@ -63,7 +63,7 @@ def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
return isinstance(x, Variable)
-class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin[A, B]):
+class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin):
"""A mapping that uses object id as the hash for the keys."""
def __init__(
@@ -248,8 +248,7 @@ def __nnx_repr__(self):
yield reprlib.Attr('index', self.index)
def __treescope_repr__(self, path, subtree_renderer):
- import treescope # type: ignore[import-not-found,import-untyped]
- return treescope.repr_lib.render_object_constructor(
+ return visualization.render_object_constructor(
object_type=type(self),
attributes={'type': self.type, 'index': self.index},
path=path,
@@ -272,9 +271,7 @@ def __nnx_repr__(self):
yield reprlib.Attr('metadata', reprlib.PrettyMapping(self.metadata))
def __treescope_repr__(self, path, subtree_renderer):
- import treescope # type: ignore[import-not-found,import-untyped]
-
- return treescope.repr_lib.render_object_constructor(
+ return visualization.render_object_constructor(
object_type=type(self),
attributes={
'type': self.type,
@@ -353,8 +350,7 @@ def __nnx_repr__(self):
)
def __treescope_repr__(self, path, subtree_renderer):
- import treescope # type: ignore[import-not-found,import-untyped]
- return treescope.repr_lib.render_object_constructor(
+ return visualization.render_object_constructor(
object_type=type(self),
attributes={
'type': self.type,
diff --git a/flax/nnx/module.py b/flax/nnx/module.py
index 795bb9a088..b07efa7711 100644
--- a/flax/nnx/module.py
+++ b/flax/nnx/module.py
@@ -403,23 +403,6 @@ def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
flatten_func=partial(_module_flatten, with_keys=False),
)
- def __treescope_repr__(self, path, subtree_renderer):
- import treescope # type: ignore[import-not-found,import-untyped]
- children = {}
- for name, value in vars(self).items():
- if name.startswith('_'):
- continue
- children[name] = value
- return treescope.repr_lib.render_object_constructor(
- object_type=type(self),
- attributes=children,
- path=path,
- subtree_renderer=subtree_renderer,
- color=treescope.formatting_util.color_from_string(
- type(self).__qualname__
- )
- )
-
# -------------------------
# Pytree Definition
# -------------------------
diff --git a/flax/nnx/nn/stochastic.py b/flax/nnx/nn/stochastic.py
index 2a495826a4..add545634a 100644
--- a/flax/nnx/nn/stochastic.py
+++ b/flax/nnx/nn/stochastic.py
@@ -24,7 +24,7 @@
from flax.nnx.module import Module, first_from
-@dataclasses.dataclass
+@dataclasses.dataclass(repr=False)
class Dropout(Module):
"""Create a dropout layer.
diff --git a/flax/nnx/object.py b/flax/nnx/object.py
index afa41cdb7b..3ff35023d7 100644
--- a/flax/nnx/object.py
+++ b/flax/nnx/object.py
@@ -20,27 +20,67 @@
from abc import ABCMeta
from copy import deepcopy
-
import jax
import numpy as np
+import treescope
+from treescope import rendering_parts
+from flax.nnx import visualization
+from flax import errors
from flax.nnx import (
+ graph,
reprlib,
tracers,
)
-from flax.nnx import graph
+from flax import nnx
from flax.nnx.variablelib import Variable, VariableState
-from flax import errors
+from flax.typing import SizeBytes, value_stats
G = tp.TypeVar('G', bound='Object')
+# Walks `node` and fills `node_stats` with per-Variable-type size statistics.
+# Entries are keyed by id() of each visited graph node / Variable; nothing is
+# returned. Raises ValueError when `node` is neither a graph node nor a Variable.
+def _collect_stats(
+ node: tp.Any, node_stats: dict[int, dict[type[Variable], SizeBytes]]
+):
+ if not graph.is_node(node) and not isinstance(node, Variable):
+ raise ValueError(f'Expected a graph node or Variable, got {type(node)!r}.')
+
+ # Already visited: its entry is in place, nothing more to do.
+ if id(node) in node_stats:
+ return
+
+ # Register the (still empty) entry before recursing, so that a node reached
+ # again further down the traversal is treated as visited.
+ stats: dict[type[Variable], SizeBytes] = {}
+ node_stats[id(node)] = stats
+
+ if isinstance(node, Variable):
+ var_type = type(node)
+ # Every RngState subclass is reported under the RngState base type.
+ if issubclass(var_type, nnx.RngState):
+ var_type = nnx.RngState
+ size_bytes = value_stats(node.value)
+ # A falsy result from value_stats leaves this Variable's entry empty.
+ if size_bytes:
+ stats[var_type] = size_bytes
+
+ else:
+ node_dict = graph.get_node_impl(node).node_dict(node)
+ for key, value in node_dict.items():
+ # Children that already have an entry are skipped and therefore are not
+ # added to this node's totals a second time.
+ if id(value) in node_stats:
+ continue
+ if graph.is_node(value) or isinstance(value, Variable):
+ _collect_stats(value, node_stats)
+ child_stats = node_stats[id(value)]
+ # Fold the child's per-type sizes into this node's totals.
+ for var_type, size_bytes in child_stats.items():
+ if var_type in stats:
+ stats[var_type] += size_bytes
+ else:
+ stats[var_type] = size_bytes
+
+
@dataclasses.dataclass
-class GraphUtilsContext(threading.local):
+class ObjectContext(threading.local):
seen_modules_repr: set[int] | None = None
+ node_stats: dict[int, dict[type[Variable], SizeBytes]] | None = None
-CONTEXT = GraphUtilsContext()
+OBJECT_CONTEXT = ObjectContext()
class ObjectState(reprlib.Representable):
@@ -63,14 +103,14 @@ def __nnx_repr__(self):
yield reprlib.Attr('trace_state', self._trace_state)
def __treescope_repr__(self, path, subtree_renderer):
- import treescope # type: ignore[import-not-found,import-untyped]
- return treescope.repr_lib.render_object_constructor(
- object_type=type(self),
- attributes={'trace_state': self._trace_state},
- path=path,
- subtree_renderer=subtree_renderer,
+ return visualization.render_object_constructor(
+ object_type=type(self),
+ attributes={'trace_state': self._trace_state},
+ path=path,
+ subtree_renderer=subtree_renderer,
)
+
class ObjectMeta(ABCMeta):
if not tp.TYPE_CHECKING:
@@ -90,12 +130,14 @@ def _graph_node_meta_call(cls: tp.Type[G], *args, **kwargs) -> G:
@dataclasses.dataclass(frozen=True, repr=False)
-class Array:
+class Array(reprlib.Representable):
shape: tp.Tuple[int, ...]
dtype: tp.Any
- def __repr__(self):
- return f'Array(shape={self.shape}, dtype={self.dtype.name})'
+ def __nnx_repr__(self):
+ yield reprlib.Object(type='Array', same_line=True)
+ yield reprlib.Attr('shape', self.shape)
+ yield reprlib.Attr('dtype', self.dtype)
class Object(reprlib.Representable, metaclass=ObjectMeta):
@@ -137,20 +179,41 @@ def __deepcopy__(self: G, memo=None) -> G:
return graph.merge(graphdef, state)
def __nnx_repr__(self):
- if CONTEXT.seen_modules_repr is None:
- CONTEXT.seen_modules_repr = set()
+ if OBJECT_CONTEXT.node_stats is None:
+ node_stats: dict[int, dict[type[Variable], SizeBytes]] = {}
+ _collect_stats(self, node_stats)
+ OBJECT_CONTEXT.node_stats = node_stats
+ stats = node_stats[id(self)]
+ clear_node_stats = True
+ else:
+ stats = OBJECT_CONTEXT.node_stats[id(self)]
+ clear_node_stats = False
+
+ if OBJECT_CONTEXT.seen_modules_repr is None:
+ OBJECT_CONTEXT.seen_modules_repr = set()
clear_seen = True
else:
clear_seen = False
- if id(self) in CONTEXT.seen_modules_repr:
+ if id(self) in OBJECT_CONTEXT.seen_modules_repr:
yield reprlib.Object(type=type(self), empty_repr='...')
return
- yield reprlib.Object(type=type(self))
- CONTEXT.seen_modules_repr.add(id(self))
-
try:
+ if stats:
+ stats_repr = ' # ' + ', '.join(
+ f'{var_type.__name__}: {size_bytes}'
+ for var_type, size_bytes in stats.items()
+ )
+ if len(stats) > 1:
+ total_bytes = sum(stats.values(), SizeBytes(0, 0))
+ stats_repr += f', Total: {total_bytes}'
+ else:
+ stats_repr = ''
+
+ yield reprlib.Object(type=type(self), comment=stats_repr)
+ OBJECT_CONTEXT.seen_modules_repr.add(id(self))
+
for name, value in vars(self).items():
if name.startswith('_'):
continue
@@ -168,24 +231,64 @@ def to_shape_dtype(value):
return value
value = jax.tree.map(to_shape_dtype, value)
- yield reprlib.Attr(name, repr(value))
+ yield reprlib.Attr(name, value)
finally:
if clear_seen:
- CONTEXT.seen_modules_repr = None
+ OBJECT_CONTEXT.seen_modules_repr = None
+ if clear_node_stats:
+ OBJECT_CONTEXT.node_stats = None
def __treescope_repr__(self, path, subtree_renderer):
- import treescope # type: ignore[import-not-found,import-untyped]
- children = {}
- for name, value in vars(self).items():
- if name.startswith('_'):
- continue
- children[name] = value
- return treescope.repr_lib.render_object_constructor(
+ from flax import nnx
+
+ if OBJECT_CONTEXT.node_stats is None:
+ node_stats: dict[int, dict[type[Variable], SizeBytes]] = {}
+ _collect_stats(self, node_stats)
+ OBJECT_CONTEXT.node_stats = node_stats
+ stats = node_stats[id(self)]
+ clear_node_stats = True
+ else:
+ stats = OBJECT_CONTEXT.node_stats[id(self)]
+ clear_node_stats = False
+
+ try:
+ if stats:
+ stats_repr = ' # ' + ', '.join(
+ f'{var_type.__name__}: {size_bytes}'
+ for var_type, size_bytes in stats.items()
+ )
+ if len(stats) > 1:
+ total_bytes = sum(stats.values(), SizeBytes(0, 0))
+ stats_repr += f', Total: {total_bytes}'
+
+ first_line_annotation = rendering_parts.comment_color(
+ rendering_parts.text(f'{stats_repr}')
+ )
+ else:
+ first_line_annotation = None
+ children = {}
+ for name, value in vars(self).items():
+ if name.startswith('_'):
+ continue
+ children[name] = value
+
+ if isinstance(self, nnx.Module):
+ color = treescope.formatting_util.color_from_string(
+ type(self).__qualname__
+ )
+ else:
+ color = None
+ return visualization.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
- )
+ first_line_annotation=first_line_annotation,
+ color=color,
+ )
+ finally:
+ if clear_node_stats:
+ OBJECT_CONTEXT.node_stats = None
# Graph Definition
def _graph_node_flatten(self):
@@ -225,4 +328,13 @@ def _graph_node_clear(self):
module_vars['_object__state'] = module_state
def _graph_node_init(self, attributes: tp.Iterable[tuple[str, tp.Any]]):
- vars(self).update(attributes)
\ No newline at end of file
+ vars(self).update(attributes)
+
+
+def supports_color() -> bool:
+ """
+ Returns True if the running system's terminal supports color, and False otherwise.
+ """
+ supported_platform = sys.platform != 'win32' or 'ANSICON' in os.environ
+ is_a_tty = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
+ return supported_platform and is_a_tty
\ No newline at end of file
diff --git a/flax/nnx/reprlib.py b/flax/nnx/reprlib.py
index 6ed7660cdf..155c2e7e90 100644
--- a/flax/nnx/reprlib.py
+++ b/flax/nnx/reprlib.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import contextlib
import dataclasses
+import os
+import sys
import threading
import typing as tp
@@ -21,22 +22,125 @@
B = tp.TypeVar('B')
+# Color is reported as supported in two cases: stdout is a TTY on a platform
+# that is not win32 (or has ANSICON set), or an IPython shell is active.
+def supports_color() -> bool:
+ """
+ Returns True if the running system's terminal supports color, and False otherwise.
+ """
+ try:
+ from IPython import get_ipython
+
+ # True only when get_ipython() hands back a shell object rather than None.
+ ipython_available = get_ipython() is not None
+ except ImportError:
+ ipython_available = False
+
+ supported_platform = sys.platform != 'win32' or 'ANSICON' in os.environ
+ # Some stdout replacements have no `isatty`; treat them as non-TTY.
+ is_a_tty = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
+ # An active IPython shell enables color even when stdout is not a TTY.
+ return (supported_platform and is_a_tty) or ipython_available
+
+
+# Palette of strings used when rendering reprs: one prefix per syntax category
+# (type names, attribute names, separators, brackets, comments, literals) and
+# END, the string appended after each colored span.
+class Color(tp.NamedTuple):
+ TYPE: str
+ ATTRIBUTE: str
+ SEP: str
+ PAREN: str
+ COMMENT: str
+ INT: str
+ STRING: str
+ FLOAT: str
+ BOOL: str
+ NONE: str
+ END: str
+
+
+# Palette with every field empty, so text formatted with it carries no escape
+# codes at all.
+NO_COLOR = Color(
+ TYPE='',
+ ATTRIBUTE='',
+ SEP='',
+ PAREN='',
+ COMMENT='',
+ INT='',
+ STRING='',
+ FLOAT='',
+ BOOL='',
+ NONE='',
+ END='',
+)
+
+
+# Use python vscode theme colors when supports_color() reports color support,
+# otherwise fall back to NO_COLOR. The choice is made once, at import time.
+if supports_color():
+ COLOR = Color(
+ TYPE='\x1b[38;2;79;201;177m',
+ ATTRIBUTE='\033[38;2;156;220;254m',
+ SEP='\x1b[38;2;212;212;212m',
+ PAREN='\x1b[38;2;255;213;3m',
+ # COMMENT='\033[38;2;87;166;74m',
+ COMMENT='\033[38;2;105;105;105m', # Dark gray
+ INT='\x1b[38;2;182;207;169m',
+ STRING='\x1b[38;2;207;144;120m',
+ FLOAT='\x1b[38;2;182;207;169m',
+ BOOL='\x1b[38;2;86;156;214m',
+ NONE='\x1b[38;2;86;156;214m',
+ END='\x1b[0m',
+ )
+else:
+ COLOR = NO_COLOR
+
+
@dataclasses.dataclass
class ReprContext(threading.local):
- indent_stack: tp.List[str] = dataclasses.field(default_factory=lambda: [''])
+ current_color: Color = COLOR
REPR_CONTEXT = ReprContext()
+def colorized(x, /):
+ c = REPR_CONTEXT.current_color
+ if isinstance(x, list):
+ return f'{c.PAREN}[{c.END}{", ".join(map(lambda i: colorized(i), x))}{c.PAREN}]{c.END}'
+ elif isinstance(x, tuple):
+ if len(x) == 1:
+ return f'{c.PAREN}({c.END}{colorized(x[0])},{c.PAREN}){c.END}'
+ return f'{c.PAREN}({c.END}{", ".join(map(lambda i: colorized(i), x))}{c.PAREN}){c.END}'
+ elif isinstance(x, dict):
+ open, close = '{', '}'
+ return f'{c.PAREN}{open}{c.END}{", ".join(f"{c.STRING}{k!r}{c.END}: {colorized(v)}" for k, v in x.items())}{c.PAREN}{close}{c.END}'
+ elif isinstance(x, set):
+ open, close = '{', '}'
+ return f'{c.PAREN}{open}{c.END}{", ".join(map(lambda i: colorized(i), x))}{c.PAREN}{close}{c.END}'
+ elif isinstance(x, type):
+ return f'{c.TYPE}{x.__name__}{c.END}'
+ elif isinstance(x, bool):
+ return f'{c.BOOL}{x}{c.END}'
+ elif isinstance(x, int):
+ return f'{c.INT}{x}{c.END}'
+ elif isinstance(x, str):
+ return f'{c.STRING}{x!r}{c.END}'
+ elif isinstance(x, float):
+ return f'{c.FLOAT}{x}{c.END}'
+ elif x is None:
+ return f'{c.NONE}{x}{c.END}'
+ elif isinstance(x, Representable):
+ return get_repr(x)
+ else:
+ return repr(x)
+
+
@dataclasses.dataclass
class Object:
type: tp.Union[str, type]
start: str = '('
end: str = ')'
- value_sep: str = '='
- elem_indent: str = ' '
+ kv_sep: str = '='
+ indent: str = ' '
empty_repr: str = ''
+ comment: str = ''
+ same_line: bool = False
+
+ @property
+ def elem_sep(self):
+ return ', ' if self.same_line else ',\n'
@dataclasses.dataclass
@@ -45,6 +149,8 @@ class Attr:
value: tp.Union[str, tp.Any]
start: str = ''
end: str = ''
+ use_raw_value: bool = False
+ use_raw_key: bool = False
class Representable:
@@ -54,79 +160,96 @@ def __nnx_repr__(self) -> tp.Iterator[tp.Union[Object, Attr]]:
raise NotImplementedError
def __repr__(self) -> str:
+ current_color = REPR_CONTEXT.current_color
+ REPR_CONTEXT.current_color = NO_COLOR
+ try:
+ return get_repr(self)
+ finally:
+ REPR_CONTEXT.current_color = current_color
+
+ def __str__(self) -> str:
return get_repr(self)
-@contextlib.contextmanager
-def add_indent(indent: str) -> tp.Iterator[None]:
- REPR_CONTEXT.indent_stack.append(REPR_CONTEXT.indent_stack[-1] + indent)
-
- try:
- yield
- finally:
- REPR_CONTEXT.indent_stack.pop()
-
-
-def get_indent() -> str:
- return REPR_CONTEXT.indent_stack[-1]
-
-
def get_repr(obj: Representable) -> str:
if not isinstance(obj, Representable):
raise TypeError(f'Object {obj!r} is not representable')
+ c = REPR_CONTEXT.current_color
iterator = obj.__nnx_repr__()
config = next(iterator)
+
if not isinstance(config, Object):
raise TypeError(f'First item must be Config, got {type(config).__name__}')
+ kv_sep = f'{c.SEP}{config.kv_sep}{c.END}'
+
def _repr_elem(elem: tp.Any) -> str:
if not isinstance(elem, Attr):
raise TypeError(f'Item must be Elem, got {type(elem).__name__}')
- value = elem.value if isinstance(elem.value, str) else repr(elem.value)
-
- value = value.replace('\n', '\n' + config.elem_indent)
+ value_repr = elem.value if elem.use_raw_value else colorized(elem.value)
+ value_repr = value_repr.replace('\n', '\n' + config.indent)
+ key = elem.key if elem.use_raw_key else f'{c.ATTRIBUTE}{elem.key}{c.END}'
+ indent = '' if config.same_line else config.indent
- return f'{config.elem_indent}{elem.start}{elem.key}{config.value_sep}{value}{elem.end}'
+ return f'{indent}{elem.start}{key}{kv_sep}{value_repr}{elem.end}'
- with add_indent(config.elem_indent):
- elems = ',\n'.join(map(_repr_elem, iterator))
+ elems = config.elem_sep.join(map(_repr_elem, iterator))
if elems:
- elems = '\n' + elems + '\n'
+ if config.same_line:
+ elems_repr = elems
+ comment = ''
+ else:
+ elems_repr = '\n' + elems + '\n'
+ comment = f'{c.COMMENT}{config.comment}{c.END}'
else:
- elems = config.empty_repr
+ elems_repr = config.empty_repr
+ comment = ''
type_repr = (
config.type if isinstance(config.type, str) else config.type.__name__
)
+ type_repr = f'{c.TYPE}{type_repr}{c.END}' if type_repr else ''
+ start = f'{c.PAREN}{config.start}{c.END}' if config.start else ''
+ end = f'{c.PAREN}{config.end}{c.END}' if config.end else ''
- return f'{type_repr}{config.start}{elems}{config.end}'
+ out = f'{type_repr}{start}{comment}{elems_repr}{end}'
+ return out
-class MappingReprMixin(tp.Mapping[A, B]):
+class MappingReprMixin(Representable):
def __nnx_repr__(self):
- yield Object(type='', value_sep=': ', start='{', end='}')
+ yield Object(type='', kv_sep=': ', start='{', end='}')
- for key, value in self.items():
- yield Attr(repr(key), value)
+ for key, value in self.items(): # type: ignore
+ yield Attr(colorized(key), value, use_raw_key=True)
@dataclasses.dataclass(repr=False)
class PrettyMapping(Representable):
mapping: tp.Mapping
def __nnx_repr__(self):
- yield Object(type='', value_sep=': ', start='{', end='}')
+ yield Object(type=type(self), kv_sep=': ', start='({', end='})')
for key, value in self.mapping.items():
- yield Attr(repr(key), value)
+ yield Attr(colorized(key), value, use_raw_key=True)
+
+# Mixin that renders an iterable object as `TypeName([...])`, emitting one
+# key-less element per item obtained by iterating over `self`.
+@dataclasses.dataclass(repr=False)
+class SequenceReprMixin(Representable):
+ def __nnx_repr__(self):
+ # Empty kv_sep: elements are rendered as bare values, without `key=`.
+ yield Object(type=type(self), kv_sep='', start='([', end='])')
+
+ for value in self: # type: ignore
+ yield Attr('', value, use_raw_key=True)
+
@dataclasses.dataclass(repr=False)
class PrettySequence(Representable):
- list: tp.Sequence
+ sequence: tp.Sequence
def __nnx_repr__(self):
- yield Object(type='', value_sep='', start='[', end=']')
+ yield Object(type=type(self), kv_sep='', start='([', end='])')
- for value in self.list:
- yield Attr('', value)
\ No newline at end of file
+ for value in self.sequence:
+ yield Attr('', value, use_raw_key=True)
\ No newline at end of file
diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py
index 42a2604042..38cb3da759 100644
--- a/flax/nnx/statelib.py
+++ b/flax/nnx/statelib.py
@@ -38,7 +38,7 @@ def __init__(self, state: State):
self.state = state
def __nnx_repr__(self):
- yield reprlib.Object('', value_sep=': ', start='{', end='}')
+ yield reprlib.Object('', kv_sep=': ', start='{', end='}')
for r in self.state.__nnx_repr__():
if isinstance(r, reprlib.Object):
@@ -54,7 +54,7 @@ def __treescope_repr__(self, path, subtree_renderer):
# Render as the dictionary itself at the same path.
return subtree_renderer(children, path=path)
-class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.PrettySequence):
+class FlatState(tp.Sequence[tuple[PathParts, V]], reprlib.SequenceReprMixin):
_keys: tuple[PathParts, ...]
_values: list[V]
@@ -66,6 +66,14 @@ def __init__(self, items: tp.Iterable[tuple[PathParts, V]]):
self._keys = tuple(keys)
self._values = values
+ # Read-only access to the stored path keys (`_keys`).
+ @property
+ def paths(self) -> tp.Sequence[PathParts]:
+ return self._keys
+
+ # Access to the stored values (`_values`); the underlying list is returned
+ # as-is, not copied.
+ @property
+ def leaves(self) -> tp.Sequence[V]:
+ return self._values
+
@tp.overload
def __getitem__(self, index: int) -> tuple[PathParts, V]: ...
@tp.overload
@@ -173,7 +181,7 @@ def __len__(self) -> int:
return len(self._mapping)
def __nnx_repr__(self):
- yield reprlib.Object(type(self), value_sep=': ', start='({', end='})')
+ yield reprlib.Object(type(self), kv_sep=': ', start='({', end='})')
for k, v in self.items():
if isinstance(v, State):
diff --git a/flax/nnx/tracers.py b/flax/nnx/tracers.py
index c53bbd5c4d..a7b72b1540 100644
--- a/flax/nnx/tracers.py
+++ b/flax/nnx/tracers.py
@@ -18,7 +18,7 @@
import jax
import jax.core
-from flax.nnx import reprlib
+from flax.nnx import reprlib, visualization
def current_jax_trace():
@@ -47,12 +47,11 @@ def __nnx_repr__(self):
yield reprlib.Attr('jax_trace', self._jax_trace)
def __treescope_repr__(self, path, subtree_renderer):
- import treescope # type: ignore[import-not-found,import-untyped]
- return treescope.repr_lib.render_object_constructor(
- object_type=type(self),
- attributes={'jax_trace': self._jax_trace},
- path=path,
- subtree_renderer=subtree_renderer,
+ return visualization.render_object_constructor(
+ object_type=type(self),
+ attributes={'jax_trace': self._jax_trace},
+ path=path,
+ subtree_renderer=subtree_renderer,
)
def __eq__(self, other):
diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py
index 4752a9b7bd..1a9d5d4f03 100644
--- a/flax/nnx/variablelib.py
+++ b/flax/nnx/variablelib.py
@@ -21,10 +21,15 @@
from typing import Any
import jax
+import treescope
from flax import errors
-from flax.nnx import filterlib, reprlib, tracers
-from flax.typing import Missing, PathParts
+from flax.nnx import filterlib, reprlib, tracers, visualization
+from flax.typing import (
+ Missing,
+ PathParts,
+ value_stats,
+)
import jax.tree_util as jtu
A = tp.TypeVar('A')
@@ -42,6 +47,7 @@
VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {}
+
@dataclasses.dataclass
class VariableMetadata(tp.Generic[A]):
raw_value: A
@@ -311,20 +317,34 @@ def to_state(self: Variable[A]) -> VariableState[A]:
return VariableState(type(self), self.raw_value, **self._var_metadata)
def __nnx_repr__(self):
- yield reprlib.Object(type=type(self))
+ stats = value_stats(self.value)
+ if stats:
+ comment = f' # {stats}'
+ else:
+ comment = ''
+
+ yield reprlib.Object(type=type(self).__name__, comment=comment)
yield reprlib.Attr('value', self.raw_value)
for name, value in self._var_metadata.items():
yield reprlib.Attr(name, repr(value))
def __treescope_repr__(self, path, subtree_renderer):
- import treescope # type: ignore[import-not-found,import-untyped]
+ size_bytes = value_stats(self.value)
+ if size_bytes:
+ stats_repr = f' # {size_bytes}'
+ first_line_annotation = treescope.rendering_parts.comment_color(
+ treescope.rendering_parts.text(f'{stats_repr}')
+ )
+ else:
+ first_line_annotation = None
children = {'value': self.raw_value, **self._var_metadata}
- return treescope.repr_lib.render_object_constructor(
+ return visualization.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
+ first_line_annotation=first_line_annotation,
)
# hooks API
@@ -764,22 +784,35 @@ def __delattr__(self, name: str) -> None:
del self._var_metadata[name]
def __nnx_repr__(self):
- yield reprlib.Object(type=type(self))
- yield reprlib.Attr('type', self.type.__name__)
+ stats = value_stats(self.value)
+ if stats:
+ comment = f' # {stats}'
+ else:
+ comment = ''
+
+ yield reprlib.Object(type=type(self), comment=comment)
+ yield reprlib.Attr('type', self.type)
yield reprlib.Attr('value', self.value)
for name, value in self._var_metadata.items():
- yield reprlib.Attr(name, repr(value))
+ yield reprlib.Attr(name, value)
def __treescope_repr__(self, path, subtree_renderer):
- import treescope # type: ignore[import-not-found,import-untyped]
-
+ size_bytes = value_stats(self.value)
+ if size_bytes:
+ stats_repr = f' # {size_bytes}'
+ first_line_annotation = treescope.rendering_parts.comment_color(
+ treescope.rendering_parts.text(f'{stats_repr}')
+ )
+ else:
+ first_line_annotation = None
children = {'type': self.type, 'value': self.value, **self._var_metadata}
- return treescope.repr_lib.render_object_constructor(
+ return visualization.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
+ first_line_annotation=first_line_annotation,
)
def replace(self, value: B) -> VariableState[B]:
@@ -911,7 +944,7 @@ def wrapper(*args):
def split_flat_state(
flat_state: tp.Iterable[tuple[PathParts, Variable | VariableState]],
- filters: tuple[filterlib.Filter, ...],
+ filters: tp.Sequence[filterlib.Filter],
) -> tuple[list[tuple[PathParts, Variable | VariableState]], ...]:
predicates = filterlib.filters_to_predicates(filters)
# we have n + 1 states, where n is the number of predicates
diff --git a/flax/nnx/visualization.py b/flax/nnx/visualization.py
index d49eed7cf7..63de76cae3 100644
--- a/flax/nnx/visualization.py
+++ b/flax/nnx/visualization.py
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import importlib.util
+import typing as tp
+
+import treescope
+from treescope import rendering_parts, renderers
-treescope_installed = importlib.util.find_spec('treescope') is not None
try:
from IPython import get_ipython
@@ -29,12 +31,112 @@ def display(*args):
If treescope is not installed or the code is not running in IPython,
``display`` will print the objects instead.
"""
- if not treescope_installed or not in_ipython:
+ if not in_ipython:
for x in args:
print(x)
return
- import treescope # type: ignore[import-not-found,import-untyped]
-
for x in args:
treescope.display(x, ignore_exceptions=True, autovisualize=True)
+
+
+def render_object_constructor(
+ object_type: type[tp.Any],
+ attributes: tp.Mapping[str, tp.Any],
+ path: str | None,
+ subtree_renderer: renderers.TreescopeSubtreeRenderer,
+ roundtrippable: bool = False,
+ color: str | None = None,
+ first_line_annotation: rendering_parts.RenderableTreePart | None = None,
+) -> rendering_parts.Rendering:
+ """Renders an object in "constructor format", similar to a dataclass.
+
+ This produces a rendering like `Foo(bar=1, baz=2)`, where Foo identifies the
+ type of the object, and bar and baz are the names of the attributes of the
+ object. It is a *requirement* that these are the actual attributes of the
+ object, which can be accessed via `obj.bar` or similar; otherwise, the
+ path renderings will break.
+
+ This can be used from within a `__treescope_repr__` implementation via ::
+
+ def __treescope_repr__(self, path, subtree_renderer):
+ return repr_lib.render_object_constructor(
+ object_type=type(self),
+ attributes=<dict of attributes here>,
+ path=path,
+ subtree_renderer=subtree_renderer,
+ )
+
+ Args:
+ object_type: The type of the object.
+ attributes: The attributes of the object, which will be rendered as keyword
+ arguments to the constructor.
+ path: The path to the object. When `render_object_constructor` is called
+ from `__treescope_repr__`, this should come from the `path` argument to
+ `__treescope_repr__`.
+ subtree_renderer: The renderer to use to render subtrees. When
+ `render_object_constructor` is called from `__treescope_repr__`, this
+ should come from the `subtree_renderer` argument to `__treescope_repr__`.
+ roundtrippable: Whether evaluating the rendering as Python code will produce
+ an object that is equal to the original object. This implies that the
+ keyword arguments are actually the keyword arguments to the constructor,
+ and not some other attributes of the object.
+ color: The background color to use for the object rendering. If None, does
+ not use a background color. A utility for assigning a random color based
+ on a string key is given in `treescope.formatting_util`.
+ first_line_annotation: An annotation for the first line of the node when it
+ is expanded.
+
+ Returns:
+ A rendering of the object, suitable for returning from `__treescope_repr__`.
+ """
+ if roundtrippable:
+ constructor = rendering_parts.siblings(
+ rendering_parts.maybe_qualified_type_name(object_type), '('
+ )
+ closing_suffix = rendering_parts.text(')')
+ else:
+ constructor = rendering_parts.siblings(
+ rendering_parts.roundtrip_condition(roundtrip=rendering_parts.text('<')),
+ rendering_parts.maybe_qualified_type_name(object_type),
+ '(',
+ )
+ closing_suffix = rendering_parts.siblings(
+ ')',
+ rendering_parts.roundtrip_condition(roundtrip=rendering_parts.text('>')),
+ )
+
+ children = []
+ for i, (name, value) in enumerate(attributes.items()):
+ child_path = None if path is None else f'{path}.{name}'
+
+ if i < len(attributes) - 1:
+ # Not the last child. Always show a comma, and add a space when
+ # collapsed.
+ comma_after = rendering_parts.siblings(
+ ',',
+ rendering_parts.fold_condition(collapsed=rendering_parts.text(' ')),
+ )
+ else:
+ # Last child: only show the comma when the node is expanded.
+ comma_after = rendering_parts.fold_condition(
+ expanded=rendering_parts.text(',')
+ )
+
+ child_line = rendering_parts.build_full_line_with_annotations(
+ rendering_parts.siblings_with_annotations(
+ f'{name}=',
+ subtree_renderer(value, path=child_path),
+ ),
+ comma_after,
+ )
+ children.append(child_line)
+
+ return rendering_parts.build_foldable_tree_node_from_children(
+ prefix=constructor,
+ children=children,
+ suffix=closing_suffix,
+ path=path,
+ background_color=color,
+ first_line_annotation=first_line_annotation,
+ )
\ No newline at end of file
diff --git a/flax/typing.py b/flax/typing.py
index a630a3571e..af0ef679b3 100644
--- a/flax/typing.py
+++ b/flax/typing.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations
from collections import deque
from functools import partial
@@ -26,6 +27,8 @@
from collections.abc import Callable, Hashable, Mapping, Sequence
import jax
+import jax.numpy as jnp
+import numpy as np
from flax.core import FrozenDict
import dataclasses
@@ -161,3 +164,62 @@ class Missing:
MISSING = Missing()
+
+
+def _bytes_repr(num_bytes):
+ count, units = (
+ (f'{num_bytes / 1e9 :,.1f}', 'GB')
+ if num_bytes > 1e9
+ else (f'{num_bytes / 1e6 :,.1f}', 'MB')
+ if num_bytes > 1e6
+ else (f'{num_bytes / 1e3 :,.1f}', 'KB')
+ if num_bytes > 1e3
+ else (f'{num_bytes:,}', 'B')
+ )
+
+ return f'{count} {units}'
+
+
+class ShapeDtype(Protocol):
+ shape: Shape
+ dtype: Dtype
+
+
+def has_shape_dtype(x: Any) -> TypeGuard[ShapeDtype]:
+ return hasattr(x, 'shape') and hasattr(x, 'dtype')
+
+
+@dataclasses.dataclass(frozen=True, slots=True)
+class SizeBytes:
+ size: int
+ bytes: int
+
+ @staticmethod
+ def from_array(x: ShapeDtype) -> SizeBytes:
+ size = int(np.prod(x.shape))
+ if isinstance(x.dtype, str):
+ dtype = jnp.dtype(x.dtype)
+ else:
+ dtype = x.dtype
+ bytes = size * dtype.itemsize # type: ignore
+ return SizeBytes(size, bytes)
+
+ def __add__(self, other: SizeBytes) -> SizeBytes:
+ return SizeBytes(self.size + other.size, self.bytes + other.bytes)
+
+ def __bool__(self) -> bool:
+ return bool(self.size)
+
+ def __repr__(self) -> str:
+ bytes_repr = _bytes_repr(self.bytes)
+ return f'{self.size:,} ({bytes_repr})'
+
+
+def value_stats(x):
+ leaves = jax.tree.leaves(x)
+ size_bytes = SizeBytes(0, 0)
+ for leaf in leaves:
+ if has_shape_dtype(leaf):
+ size_bytes += SizeBytes.from_array(leaf)
+
+ return size_bytes
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 658b2f15d5..f7a890fad0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
"rich>=11.1",
"typing_extensions>=4.2",
"PyYAML>=5.4.1",
- "treescope>=0.1.2",
+ "treescope>=0.1.7",
]
classifiers = [
"Development Status :: 3 - Alpha",
diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py
index ce65186dd2..64928f46b8 100644
--- a/tests/nnx/module_test.py
+++ b/tests/nnx/module_test.py
@@ -25,6 +25,7 @@
import jax.numpy as jnp
import numpy as np
+
A = TypeVar('A')
class List(nnx.Module):
@@ -550,6 +551,46 @@ def __call__(self, x):
y2 = model(jnp.ones((5, 2)))
np.testing.assert_allclose(y1, y2)
+ def test_repr(self):
+ class Block(nnx.Module):
+ def __init__(self, din, dout, rngs: nnx.Rngs):
+ self.linear = nnx.Linear(din, dout, rngs=rngs)
+ self.bn = nnx.BatchNorm(dout, rngs=rngs)
+ self.dropout = nnx.Dropout(0.2, rngs=rngs)
+
+ def __call__(self, x):
+ return nnx.relu(self.dropout(self.bn(self.linear(x))))
+
+ class Foo(nnx.Module):
+ def __init__(self, rngs: nnx.Rngs):
+ self.block1 = Block(32, 128, rngs=rngs)
+ self.block2 = Block(128, 10, rngs=rngs)
+
+ def __call__(self, x):
+ return self.block2(self.block1(x))
+
+ obj = Foo(nnx.Rngs(0))
+
+ leaves = nnx.state(obj).flat_state().leaves
+
+ expected_total = sum(int(np.prod(x.value.shape)) for x in leaves)
+ expected_total_params = sum(
+ int(np.prod(x.value.shape)) for x in leaves if x.type is nnx.Param
+ )
+ expected_total_batch_stats = sum(
+ int(np.prod(x.value.shape)) for x in leaves if x.type is nnx.BatchStat
+ )
+ expected_total_rng_states = sum(
+ int(np.prod(x.value.shape)) for x in leaves if x.type is nnx.RngState
+ )
+
+ foo_repr = repr(obj).replace(',', '').splitlines()
+
+ self.assertIn(str(expected_total), foo_repr[0])
+ self.assertIn(str(expected_total_params), foo_repr[0])
+ self.assertIn(str(expected_total_batch_stats), foo_repr[0])
+ self.assertIn(str(expected_total_rng_states), foo_repr[0])
+
class TestModulePytree:
def test_tree_map(self):
diff --git a/uv.lock b/uv.lock
index e08e2dbf53..48bda4f756 100644
--- a/uv.lock
+++ b/uv.lock
@@ -3,13 +3,13 @@ requires-python = ">=3.10"
resolution-markers = [
"python_full_version < '3.11' and platform_system == 'Darwin'",
"python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
"python_full_version == '3.11.*' and platform_system == 'Darwin'",
"python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
"python_full_version >= '3.12' and platform_system == 'Darwin'",
"python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
]
[[package]]
@@ -641,7 +641,7 @@ source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version < '3.11' and platform_system == 'Darwin'",
"python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
]
sdist = { url = "https://files.pythonhosted.org/packages/99/bc/cfb52b9e8531526604afe8666185d207e4f0cb9c6d90bc76f62fb8746804/etils-1.7.0.tar.gz", hash = "sha256:97b68fd25e185683215286ef3a54e38199b6245f5fe8be6bedc1189be4256350", size = 95695 }
wheels = [
@@ -676,10 +676,10 @@ source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version == '3.11.*' and platform_system == 'Darwin'",
"python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
"python_full_version >= '3.12' and platform_system == 'Darwin'",
"python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
]
sdist = { url = "https://files.pythonhosted.org/packages/ba/49/d480aeb4fc441d933acce97261bea002234a45fb847599c9a93c31e51b2e/etils-1.9.2.tar.gz", hash = "sha256:15dcd35ac0c0cc2404b46ac0846af3cc4e876fd3d80f36f57951e27e8b9d6379", size = 101506 }
wheels = [
@@ -890,7 +890,7 @@ requires-dist = [
{ name = "tensorflow-text", marker = "platform_system != 'Darwin' and extra == 'testing'", specifier = ">=2.11.0" },
{ name = "tensorstore" },
{ name = "torch", marker = "extra == 'testing'" },
- { name = "treescope", specifier = ">=0.1.2" },
+ { name = "treescope", specifier = ">=0.1.7" },
{ name = "treescope", marker = "python_full_version >= '3.10' and extra == 'testing'", specifier = ">=0.1.1" },
{ name = "typing-extensions", specifier = ">=4.2" },
]
@@ -1202,7 +1202,7 @@ name = "ipython"
version = "8.26.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "colorama", marker = "sys_platform == 'win32'" },
+ { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" },
{ name = "decorator" },
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
{ name = "jedi" },
@@ -1246,7 +1246,7 @@ wheels = [
[[package]]
name = "jax"
-version = "0.4.37"
+version = "0.4.38"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jaxlib" },
@@ -1255,14 +1255,14 @@ dependencies = [
{ name = "opt-einsum" },
{ name = "scipy" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/50/30/ad7617a960c86782587540a179cef676962322d1e5411415b1aa24f02ce0/jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b", size = 1915966 }
+sdist = { url = "https://files.pythonhosted.org/packages/fb/e5/c4aa9644bb96b7f6747bd7c9f8cda7665ca5e194fa2542b2dea3ff730701/jax-0.4.38.tar.gz", hash = "sha256:43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8", size = 1930034 }
wheels = [
- { url = "https://files.pythonhosted.org/packages/5f/3f/6c5553baaa7faa3fa8bae8279b1e46cb54c7ce52360139eae53498786ea5/jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e", size = 2221192 },
+ { url = "https://files.pythonhosted.org/packages/22/49/b4418a7a892c0dd64442bbbeef54e1cdfe722dfc5a7bf0d611d3f5f90e99/jax-0.4.38-py3-none-any.whl", hash = "sha256:78987306f7041ea8500d99df1a17c33ed92620c2268c4c3677fb24e06712be64", size = 2236864 },
]
[[package]]
name = "jaxlib"
-version = "0.4.36"
+version = "0.4.38"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "ml-dtypes" },
@@ -1270,26 +1270,26 @@ dependencies = [
{ name = "scipy" },
]
wheels = [
- { url = "https://files.pythonhosted.org/packages/23/8d/8a44618f3493f29d769b2b40778d24075689cc8697b98e2c43bafbe50edf/jaxlib-0.4.36-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:d69f991833b6dca794767049843462805936c89553b136a8ebb8485334204457", size = 98648230 },
- { url = "https://files.pythonhosted.org/packages/78/b8/207485eab566dcfbc29bb833714ac1ca47a1665ca605b1ff7d3d5dd2afbe/jaxlib-0.4.36-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:807814c1ba3ec69cffaa93d3f90651c694a9b8a750b43832cc167ed590c821dd", size = 78553787 },
- { url = "https://files.pythonhosted.org/packages/26/42/3c2b0dc86a17aafd8f46ba0e4388f39f55706ee25f6c463c3dadea7a71e2/jaxlib-0.4.36-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:1bc27d9ae09549d7652eafe1fdb10c21546cd2fd02bb24a49a7e6208b69163b0", size = 84008742 },
- { url = "https://files.pythonhosted.org/packages/b9/b2/29be712098342df10075fe085c0b39d783a579bd3325fb0d69c22712cf27/jaxlib-0.4.36-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:3379f03a794d6a30b75765d2786f6e31052f364196fcd49aaae292a3c16f12ec", size = 100263041 },
- { url = "https://files.pythonhosted.org/packages/63/a9/93404a2f1d59647749d4d6dbab7bee9f5a7bfaeb9ade25b7e66c0ca0949a/jaxlib-0.4.36-cp310-cp310-win_amd64.whl", hash = "sha256:63e575ac8a515dee8171dd4a88c460d538bbcc9d959cabc9781e961763678f84", size = 63270658 },
- { url = "https://files.pythonhosted.org/packages/e4/7d/9394ff39af5c23bb98a241c33742a328df5a43c21d569855ea7e096aaf5e/jaxlib-0.4.36-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:213792db3b876206b45f6a9fbea15e4dd22a9e80be25b03136f20c94784fecfa", size = 98669744 },
- { url = "https://files.pythonhosted.org/packages/34/5a/9f3c9e5cec23e60f78bb3c3da108a5ef664601862dbc4e84fc4be3654f5d/jaxlib-0.4.36-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6d7a89adf4c9d3cddd20482931dedc7a9e2669e904196a9599d9a605b3d9e552", size = 78574312 },
- { url = "https://files.pythonhosted.org/packages/ff/5c/bf78ed9b8d0f174a562f6496049a4872e14a3bb3a80de09c4292d04be5f0/jaxlib-0.4.36-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:c395fe8cc5bd6558dd2fbce78e24172b6f27762e17628720ae03d693001283f3", size = 84038323 },
- { url = "https://files.pythonhosted.org/packages/67/af/6a9dd26e8a6bedd4c9fe702059767256b0d9ed18c29a180a4598d5795bb4/jaxlib-0.4.36-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:bc324c6b1c64fe68400934c653e4e622f12576120dcdb451c3b4ea4dcaba2ae9", size = 100285487 },
- { url = "https://files.pythonhosted.org/packages/b7/46/31c3a519a94e84c672ca264c4151998e3e3fd11c481d8fa5af5885b91a1e/jaxlib-0.4.36-cp311-cp311-win_amd64.whl", hash = "sha256:c9e0c45a79e63aea65447f82bd0fa21c17b9afe884aa18dd5362b9965abe9d72", size = 63308064 },
- { url = "https://files.pythonhosted.org/packages/e3/0e/3b4a99c09431ee5820624d4dcf4efa7becd3c83b56ff0f09a078f4c421a2/jaxlib-0.4.36-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:5972aa85f6d771ecc8cc72148c1fa64250ca33cbdf2bf24407cdee8a5299d25d", size = 98718357 },
- { url = "https://files.pythonhosted.org/packages/d3/46/05e70a1236ec3782333b3e9469f971c9d45af2aa0aebf602acd9d76292eb/jaxlib-0.4.36-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5597908cd10418c0b42e9af807fc8112036703533cf501a5255a8fbf4011867e", size = 78596060 },
- { url = "https://files.pythonhosted.org/packages/8e/76/6b969cbf197b8c53c84c2642069722e84a3a260af084a8acbbf90ca444ea/jaxlib-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:fbbabaa287378a78a3cf9cbe4de30a1f6f19a99116feb4bd687ff256415cd442", size = 84053202 },
- { url = "https://files.pythonhosted.org/packages/fe/f2/7624a304426daa7b135b85caf1b8eccf879e7cb10bc074656ce628309cb0/jaxlib-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:be295abc209c980817db0488f21f1fbc0644f87326522895e2b9b64729106357", size = 100325610 },
- { url = "https://files.pythonhosted.org/packages/bb/8b/ded8420cd9198eb677869ffd557d9880af5833c7bf39e604e80b56550e09/jaxlib-0.4.36-cp312-cp312-win_amd64.whl", hash = "sha256:d4bbb5d2970628dcd3dabc28a5b97a1125ad3e06a1be822d340fd9f06f7449b3", size = 63338518 },
- { url = "https://files.pythonhosted.org/packages/5d/22/b72811c61e8b594951d3ee03245cb0932c723ac35e75569005c3c976eec2/jaxlib-0.4.36-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:02df9c0e1323dde01e966c22eb12432905d2d4de8aac7b603cad2083101b0e6b", size = 98719384 },
- { url = "https://files.pythonhosted.org/packages/f1/66/3f4a97097983914899100db9e5312493fe1d6adc924e47a0e47e15c553f5/jaxlib-0.4.36-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:16ec980e85983f41999c4dc84137dec70507d958e23d7eefa104da93053d135f", size = 78596150 },
- { url = "https://files.pythonhosted.org/packages/3a/6f/cf02f56d1532962d8ca77a6548acab8204294b96b5a153ca4a2caf4971fc/jaxlib-0.4.36-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7ce9368515348d869d6c59d9904c3cb3c81f22ff3e9e969eae0e3563fe472080", size = 84055851 },
- { url = "https://files.pythonhosted.org/packages/28/10/4fc4e9719c065c6455491730011e87fe4b5120a9a008161cc32663feb9ce/jaxlib-0.4.36-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:93f1c502d08e517f842fe7b18428bb086cfd077db0ea9a2418fb21e5b4e06d3d", size = 100325986 },
- { url = "https://files.pythonhosted.org/packages/ba/28/fece5385e736ef2f1b5bed133f8001f0fc66dd0104707381343e047b341a/jaxlib-0.4.36-cp313-cp313-win_amd64.whl", hash = "sha256:bddf436a243e83ec6bc16bcbb74d15b1960a69318c9ea796fb2109492bc52575", size = 63338694 },
+ { url = "https://files.pythonhosted.org/packages/ee/d4/e6a0881a88b8f17491c2ee271fd77c348b0221d9e2ec92dad23a2c9e41bc/jaxlib-0.4.38-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:55c19b9d3f33a6fc59f644aa5a21fba02639ccdd776cb4a9b5526625f57839ff", size = 99663603 },
+ { url = "https://files.pythonhosted.org/packages/b6/6d/11569ce873f04c82ec22e58d822f4187dccae1d400c0d6dd05ed314d5328/jaxlib-0.4.38-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:30b2f52cb50d74734af2f477c2533a7a583e3bb7b2c8acdeb361ee77d940577a", size = 79475708 },
+ { url = "https://files.pythonhosted.org/packages/72/61/1de2405d13089c83b1ad87ec0266479c9d00080659dae2474892ae356306/jaxlib-0.4.38-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ee19c163a8fdf0839d4c18b88a5fbfb4e731ba7c437416d3e5483e570bb764e4", size = 93219045 },
+ { url = "https://files.pythonhosted.org/packages/9c/24/0829decf233c6af9efe7c53888ae8ac72395e0979869cd9cee487e35dac3/jaxlib-0.4.38-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:61aeccb9a27c67fdb8450f6357240019cd4511cb9d62a44e4764756d384853ad", size = 101732107 },
+ { url = "https://files.pythonhosted.org/packages/0d/04/120c4caac6151f7297fedf9dd776362aa2d417d3f87bda826050b4da45e8/jaxlib-0.4.38-cp310-cp310-win_amd64.whl", hash = "sha256:d6ab745a89d0fb737a36fe1d8b86659e3fffe6ee8303b20651b26193d5edc0ef", size = 64223924 },
+ { url = "https://files.pythonhosted.org/packages/b0/6a/b9fba73eb5e758e40a514919e096a039d27dc0ab4776a6cc977f5153a55f/jaxlib-0.4.38-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:b67fdeabd6dfed08b7768f3bdffb521160085f8305669bd197beef61d08de08b", size = 99679916 },
+ { url = "https://files.pythonhosted.org/packages/44/2a/3458130d44d44038fd6974e7c43948f68408f685063203b82229b9b72c1a/jaxlib-0.4.38-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fb0eaae7369157afecbead50aaf29e73ffddfa77a2335d721bd9794f3c510e4", size = 79488377 },
+ { url = "https://files.pythonhosted.org/packages/94/96/7d9a0b9f35af4727df44b68ade4c6f15163840727d1cb47251b1ea515e30/jaxlib-0.4.38-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:43db58c4c427627296366a56c10318e1f00f503690e17f94bb4344293e1995e0", size = 93241543 },
+ { url = "https://files.pythonhosted.org/packages/a3/2d/68f85037e60c981b37b18b23ace458c677199dea4722ddce541b48ddfc63/jaxlib-0.4.38-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:2751ff7037d6a997d0be0e77cc4be381c5a9f9bb8b314edb755c13a6fd969f45", size = 101751923 },
+ { url = "https://files.pythonhosted.org/packages/cc/24/a9c571c8a189f58e0b54b14d53fc7f5a0a06e4f1d7ab9edcf8d1d91d07e7/jaxlib-0.4.38-cp311-cp311-win_amd64.whl", hash = "sha256:35226968fc9de6873d1571670eac4117f5ed80e955f7a1775204d1044abe16c6", size = 64255189 },
+ { url = "https://files.pythonhosted.org/packages/49/df/08b94c593c0867c7eaa334592807ba74495de4be90580f360db8b96221dc/jaxlib-0.4.38-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:3fefea985f0415816f3bbafd3f03a437050275ef9bac9a72c1314e1644ac57c1", size = 99737849 },
+ { url = "https://files.pythonhosted.org/packages/ab/b1/c9d2a7ba9ebeabb7ac37082f4c466364f475dc7550a79358c0f0aa89fdf2/jaxlib-0.4.38-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f33bcafe32c97a562ecf6894d7c41674c80c0acdedfa5423d49af51147149874", size = 79509242 },
+ { url = "https://files.pythonhosted.org/packages/53/25/dd670d8bdf3799ece76d12cfe6a6a250ea256057aa4b0fcace4753a99d2d/jaxlib-0.4.38-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:496f45b0e001a2341309cd0c74af0b670537dced79c168cb230cfcc773f0aa86", size = 93251503 },
+ { url = "https://files.pythonhosted.org/packages/f9/cc/37fce5162f6b9070203fd76cc0f298d9b3bfdf01939a78935a6078d63621/jaxlib-0.4.38-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:dad6c0a96567c06d083c0469fec40f201210b099365bd698be31a6d2ec88fd59", size = 101792792 },
+ { url = "https://files.pythonhosted.org/packages/6f/7a/8515950a60a4ea5b13cc98fc0a42e36553b2db5a6eedc00d3bd7836f77b5/jaxlib-0.4.38-cp312-cp312-win_amd64.whl", hash = "sha256:966cdec36cfa978f5b4582bcb4147fe511725b94c1a752dac3a5f52ce46b6fa3", size = 64288223 },
+ { url = "https://files.pythonhosted.org/packages/91/03/aee503c7077c6dbbd568842303426c6ec1cef9bff330c418c9e71906cccd/jaxlib-0.4.38-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:41e55ae5818a882e5789e848f6f16687ac132bcfbb5a5fa114a5d18b78d05f2d", size = 99739026 },
+ { url = "https://files.pythonhosted.org/packages/cb/bf/fbbf61da319611d88e11c691d5a2077039208ded05e1731dea940f824a59/jaxlib-0.4.38-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6fe326b8af366387dd47ccf312583b2b17fed12712c9b74a648b18a13cbdbabf", size = 79508735 },
+ { url = "https://files.pythonhosted.org/packages/e4/0b/8cbff0b6d62a4694351c49baf53b7ed8deb8a6854d129408c38158e11676/jaxlib-0.4.38-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:248cca3771ebf24b070f49701364ceada33e6139445b06c782cca5ac5ad92bf4", size = 93251882 },
+ { url = "https://files.pythonhosted.org/packages/15/57/7f0283273b69c417071bcd2f4c2ed076479ec5ffc22a647f13c21da8d071/jaxlib-0.4.38-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:2ce77ba8cda9259a4bca97afc1c722e4291a6c463a63f8d372c6edc85117d625", size = 101791137 },
+ { url = "https://files.pythonhosted.org/packages/de/de/d6c4d234cd426b97459cb070af90792b48643967a0d28641379ee9e10fc9/jaxlib-0.4.38-cp313-cp313-win_amd64.whl", hash = "sha256:4103db0b3a38a5dc132741237453c24d8547290a22079ba1b577d6c88c95300a", size = 64288459 },
]
[[package]]
@@ -1431,7 +1431,7 @@ version = "5.7.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "platformdirs" },
- { name = "pywin32", marker = "platform_python_implementation != 'PyPy' and sys_platform == 'win32'" },
+ { name = "pywin32", marker = "(platform_machine != 'aarch64' and platform_python_implementation != 'PyPy' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_python_implementation != 'PyPy' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" },
{ name = "traitlets" },
]
sdist = { url = "https://files.pythonhosted.org/packages/00/11/b56381fa6c3f4cc5d2cf54a7dbf98ad9aa0b339ef7a601d6053538b079a7/jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9", size = 87629 }
@@ -2095,7 +2095,7 @@ name = "nvidia-cudnn-cu12"
version = "9.1.0.70"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+ { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
@@ -2122,9 +2122,9 @@ name = "nvidia-cusolver-cu12"
version = "11.4.5.107"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
- { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
- { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+ { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+ { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+ { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 },
@@ -2135,7 +2135,7 @@ name = "nvidia-cusparse-cu12"
version = "12.1.0.106"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+ { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 },
@@ -2262,7 +2262,7 @@ wheels = [
[[package]]
name = "orbax-checkpoint"
-version = "0.10.2"
+version = "0.11.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "absl-py" },
@@ -2280,9 +2280,9 @@ dependencies = [
{ name = "tensorstore" },
{ name = "typing-extensions" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/d1/06/c42e2f1563dbaaf5ed1464d7b634324fb9a2da04021073c45777e61af78d/orbax_checkpoint-0.10.2.tar.gz", hash = "sha256:e575ebe1f94e5cb6353ab8c9df81de0ca7cddc118645c3bfc17b8344f19d42f1", size = 248170 }
+sdist = { url = "https://files.pythonhosted.org/packages/de/b3/a9a8a6bc08ded7634a9d85ba440400172f0a11f9341897b8fd3389fad245/orbax_checkpoint-0.11.0.tar.gz", hash = "sha256:d4a0dcc81edd29191cf5a4feb9cf2a4edd31fc5da79d7be616a04f11f2a4d484", size = 253035 }
wheels = [
- { url = "https://files.pythonhosted.org/packages/61/19/ed366f8894923f3c8db0370e4bdd57ef843d68011dafa00d8175f4a66e1a/orbax_checkpoint-0.10.2-py3-none-any.whl", hash = "sha256:dcfc425674bd8d4934986143bd22a37cd634d034652c5d30d83c539ef8587941", size = 354306 },
+ { url = "https://files.pythonhosted.org/packages/87/32/3779fa524a2272f408ab51d869fde9ff1c0ca731eedd01e40436bcf7ba2c/orbax_checkpoint-0.11.0-py3-none-any.whl", hash = "sha256:892a124fce71f3e7c71451a2b2090c0251db1097803a119a00baa377113bc9ba", size = 360423 },
]
[[package]]
@@ -2436,7 +2436,7 @@ source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version < '3.11' and platform_system == 'Darwin'",
"python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux')",
]
sdist = { url = "https://files.pythonhosted.org/packages/55/5b/e3d951e34f8356e5feecacd12a8e3b258a1da6d9a03ad1770f28925f29bc/protobuf-3.20.3.tar.gz", hash = "sha256:2e3427429c9cffebf259491be0af70189607f365c2f41c7c3764af6f337105f2", size = 216768 }
wheels = [
@@ -2454,10 +2454,10 @@ source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version == '3.11.*' and platform_system == 'Darwin'",
"python_full_version == '3.11.*' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version == '3.11.*' and platform_system != 'Darwin' and platform_system != 'Linux')",
"python_full_version >= '3.12' and platform_system == 'Darwin'",
"python_full_version >= '3.12' and platform_machine == 'aarch64' and platform_system == 'Linux'",
- "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
+ "(python_full_version >= '3.12' and platform_machine != 'aarch64' and platform_system == 'Linux') or (python_full_version >= '3.12' and platform_system != 'Darwin' and platform_system != 'Linux')",
]
sdist = { url = "https://files.pythonhosted.org/packages/e8/ab/cb61a4b87b2e7e6c312dce33602bd5884797fd054e0e53205f1c27cf0f66/protobuf-4.25.4.tar.gz", hash = "sha256:0dc4a62cc4052a036ee2204d26fe4d835c62827c855c8a03f29fe6da146b380d", size = 380283 }
wheels = [
@@ -2606,7 +2606,7 @@ name = "pytest"
version = "8.3.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "colorama", marker = "sys_platform == 'win32'" },
+ { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" },
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
{ name = "iniconfig" },
{ name = "packaging" },
@@ -3195,7 +3195,7 @@ source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "alabaster" },
{ name = "babel" },
- { name = "colorama", marker = "sys_platform == 'win32'" },
+ { name = "colorama", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux' and sys_platform == 'win32') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32')" },
{ name = "docutils" },
{ name = "imagesize" },
{ name = "jinja2" },
@@ -3669,14 +3669,14 @@ wheels = [
[[package]]
name = "treescope"
-version = "0.1.2"
+version = "0.1.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/2f/5d/ecb176971c78d90a3f74b7878ab9d013995fed285e3386a503ca008c9b03/treescope-0.1.2.tar.gz", hash = "sha256:2e4b35780884dfdbdcf44315d1c1c98fcf41daa0ea48a5b45ecc716920f88c86", size = 402255 }
+sdist = { url = "https://files.pythonhosted.org/packages/40/34/8ad5475c26837ca400c77951bcc0788b5f291d1509ae2eda5f97b042c24a/treescope-0.1.7.tar.gz", hash = "sha256:2c82ecb633f18d50e5809dd473703cf05aa074a4f3d1add74de7cf7ccdf81ae3", size = 530052 }
wheels = [
- { url = "https://files.pythonhosted.org/packages/af/11/1a4d1877e5f7202bb3d0778a77b6ca222848b9b36fa65cbbc1fe12cb82b7/treescope-0.1.2-py3-none-any.whl", hash = "sha256:1811df6fbf79a5f54804e3ce2230b100547dc6350c99d973a6b9ba2bcd932e57", size = 172154 },
+ { url = "https://files.pythonhosted.org/packages/59/7d/f6da2b223749c58ec8ff95c87319196765fed05bd44dd86fb9bc4bf35f77/treescope-0.1.7-py3-none-any.whl", hash = "sha256:14e6527d4bfe6770ac9cbb8058e49b6685444d7cd0d3f85fd10c42491848b102", size = 175566 },
]
[[package]]
@@ -3684,7 +3684,7 @@ name = "triton"
version = "3.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
- { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
+ { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system == 'Linux') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },