# Patch summary (free text ahead of the first "diff --git" header):
#
# * linen_vars_to_nnx_attrs: gains a one-line docstring; no code change.
# * nnx_attrs_to_linen_vars: gains a one-line docstring, and:
#     - the is_leaf predicate passed to traversals.flatten_mapping now also
#       matches variableslib.VariableState, alongside variableslib.Variable
#       and GraphDef;
#     - each leaf is converted inside the loop: a Variable through
#       to_linen_var(v.to_state()), a VariableState through to_linen_var(v);
#       for a VariableState the collection name comes from
#       variable_type_name(v.type);
#     - any other leaf is stored unchanged under the 'nnx' collection;
#     - the jax.tree.map conversion pass that used to run after
#       traversals.unflatten_mapping is removed, since conversion now
#       happens per leaf before unflattening;
#     - one blank line following the function is dropped.
#
# NOTE(review): this patch was recovered from a copy whose newlines had been
# collapsed, so per-line indentation is reconstructed -- confirm it against
# the target file before applying.
diff --git a/flax/nnx/bridge/variables.py b/flax/nnx/bridge/variables.py
index 93531bb485..121bb98eb8 100644
--- a/flax/nnx/bridge/variables.py
+++ b/flax/nnx/bridge/variables.py
@@ -174,6 +174,7 @@ def _recursive_merge(dict1, dict2):
 
 
 def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]:
+  """Convert a dict of Linen-style variables to NNX variables."""
   nnx_vars = jax.tree_util.tree_map_with_path(
     lambda kp, x: to_nnx_var(get_col_name(kp), x),
     variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
@@ -190,19 +191,22 @@ def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]:
 
 
 def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
+  """Convert a dict of NNX variables (or variable states) to Linen-style variables."""
   linen_structured = {}
   for kp, v in traversals.flatten_mapping(
       nnx_attrs,
-      is_leaf=lambda _, x: isinstance(x, variableslib.Variable | GraphDef),
-  ).items():
+      is_leaf=lambda _, x: isinstance(
+          x, variableslib.Variable | variableslib.VariableState | GraphDef
+      ),
+  ).items():
     if isinstance(v, variableslib.Variable):
       col_name = variable_type_name(type(v))
+      v = to_linen_var(v.to_state())
+    elif isinstance(v, variableslib.VariableState):
+      col_name = variable_type_name(v.type)
+      v = to_linen_var(v)
     else:
       col_name = 'nnx'  # it must be an nnx.GraphDef, for some ToLinen submodule
     linen_structured[(col_name, *kp)] = v
   variables = traversals.unflatten_mapping(linen_structured)
-  variables = jax.tree.map(lambda x: to_linen_var(x.to_state()),
-                           variables,
-                           is_leaf=lambda x: isinstance(x, variableslib.Variable))
   return variables
-