Skip to content

Commit

Permalink
Merge pull request #3880 from NeilGirdhar:fix_mappings
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 630013257
  • Loading branch information
Flax Authors committed May 2, 2024
2 parents ae5d66d + d1935c7 commit d0e080d
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 12 deletions.
7 changes: 4 additions & 3 deletions flax/experimental/nnx/nnx/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# limitations under the License.
from __future__ import annotations

from collections.abc import Mapping
import typing as tp
import typing_extensions as tpe

Expand All @@ -42,7 +43,7 @@
A = tp.TypeVar('A')

StateLeaf = tp.Union[VariableState[tp.Any], np.ndarray, jax.Array]
FlatState = dict[PathParts, StateLeaf]
FlatState = Mapping[PathParts, StateLeaf]


def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
Expand All @@ -66,8 +67,8 @@ class State(tp.MutableMapping[Key, tp.Any], reprlib.Representable):
def __init__(
self,
mapping: tp.Union[
tp.Mapping[Key, tp.Mapping | StateLeaf],
tp.Iterator[tuple[Key, tp.Mapping | StateLeaf]],
Mapping[Key, Mapping | StateLeaf],
tp.Iterator[tuple[Key, Mapping | StateLeaf]],
],
/,
):
Expand Down
49 changes: 40 additions & 9 deletions flax/traverse_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,19 @@
Traversals never mutate the original data. Therefore, an update essentially
returns a copy of the data including the provided updates.
"""
from __future__ import annotations

import abc
from collections.abc import Callable, Mapping
import copy
import dataclasses
import warnings
from typing import Any, Callable
from typing import Any, Union, overload

import jax

import flax
from flax.core.scope import VariableDict
from flax.typing import PathParts
from flax.typing import PathParts, VariableDict

from . import struct

Expand All @@ -77,7 +78,37 @@ class _EmptyNode:
empty_node = _EmptyNode()


def flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None):
# TODO: In Python 3.10, use TypeAlias.
IsLeafCallable = Callable[[tuple[Any, ...], Mapping[Any, Any]], bool]


@overload
def flatten_dict(xs: Mapping[Any, Any],
/,
*,
keep_empty_nodes: bool = False,
is_leaf: Union[None, IsLeafCallable] = None,
sep: None = None
) -> dict[tuple[Any, ...], Any]:
...

@overload
def flatten_dict(xs: Mapping[Any, Any],
/,
*,
keep_empty_nodes: bool = False,
is_leaf: Union[None, IsLeafCallable] = None,
sep: str,
) -> dict[str, Any]:
...

def flatten_dict(xs: Mapping[Any, Any],
/,
*,
keep_empty_nodes: bool = False,
is_leaf: Union[None, IsLeafCallable] = None,
sep: Union[None, str] = None
) -> dict[Any, Any]:
"""Flatten a nested dictionary.
The nested keys are flattened to a tuple.
Expand Down Expand Up @@ -111,16 +142,16 @@ def flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None):
The flattened dictionary.
"""
assert isinstance(
xs, (flax.core.FrozenDict, dict)
), f'expected (frozen)dict; got {type(xs)}'
xs, Mapping
), f'expected Mapping; got {type(xs).__qualname__}'

def _key(path):
def _key(path: tuple[Any, ...]) -> Union[tuple[Any, ...], str]:
if sep is None:
return path
return sep.join(path)

def _flatten(xs, prefix):
if not isinstance(xs, (flax.core.FrozenDict, dict)) or (
def _flatten(xs: Any, prefix: tuple[Any, ...]) -> dict[Any, Any]:
if not isinstance(xs, Mapping) or (
is_leaf and is_leaf(prefix, xs)
):
return {_key(prefix): xs}
Expand Down

0 comments on commit d0e080d

Please sign in to comment.