Skip to content

Commit

Permalink
Allow metadata pass-through in flax.struct.field
Browse files Browse the repository at this point in the history
  • Loading branch information
cool-RR committed Jul 10, 2024
1 parent 0c3c74c commit f2f76d9
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
5 changes: 3 additions & 2 deletions flax/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
_T = TypeVar('_T')


def field(pytree_node=True, **kwargs):
return dataclasses.field(metadata={'pytree_node': pytree_node}, **kwargs)
def field(pytree_node=True, *, metadata=None, **kwargs):
return dataclasses.field(metadata=(metadata or {}) | {'pytree_node': pytree_node},
**kwargs)


@dataclass_transform(field_specifiers=(field,)) # type: ignore[literal-required]
Expand Down
6 changes: 6 additions & 0 deletions tests/struct_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ class B(A):
class B(A, struct.PyTreeNode):
b: int

def test_metadata_pass_through(self):
@struct.dataclass
class A:
foo: int = struct.field(default=9, metadata={'baz': 9})
assert A.__dataclass_fields__['foo'].metadata == {'baz': 9, 'pytree_node': True}

@parameterized.parameters(
{'mode': 'dataclass'},
{'mode': 'pytreenode'},
Expand Down

0 comments on commit f2f76d9

Please sign in to comment.