Skip to content

Commit

Permalink
Add protocol methods to Penzai variables for library interop.
Browse files Browse the repository at this point in the history
These methods can be used by other JAX libraries such as NNX to
support passing Penzai layers through function transformations.

PiperOrigin-RevId: 673081778
  • Loading branch information
danieldjohnson authored and Penzai Developers committed Sep 10, 2024
1 parent 766eafc commit e78349f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
15 changes: 15 additions & 0 deletions penzai/core/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,16 @@ def update(self, new_frozen_value: AbstractVariableValue):
"""Updates the value of this variable to match a frozen variable."""
raise NotImplementedError("update must be overridden by subclasses.")

@typing.final
def __get_state_as_jax_pytree__(self) -> AbstractVariableValue:
"""Mutable variable protocol method, for JAX ecosystem interoperability."""
return self.freeze()

@typing.final
def __set_state_from_jax_pytree__(self, value: AbstractVariableValue):
"""Mutable variable protocol method, for JAX ecosystem interoperability."""
return self.update(value)


class AbstractVariableValue(struct.Struct, abc.ABC):
"""Base class for all frozen variables."""
Expand All @@ -141,6 +151,11 @@ def get_slot(self) -> AbstractVariableSlot:
"""
raise NotImplementedError("get_slot must be overridden by subclasses.")

@typing.final
def __jax_pytree_state_to_new_variable__(self):
"""Mutable variable protocol method, for JAX ecosystem interoperability."""
return self.unfreeze_as_copy()


class AbstractVariableSlot(struct.Struct, abc.ABC):
"""Base class for all variable slots.
Expand Down
19 changes: 19 additions & 0 deletions tests/core/variables_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,25 @@ def bad(x):
):
variables.variable_jit(bad)(10)

def test_variable_protocol_methods(self):
var1 = variables.StateVariable(value=1, label="var1")
var1_value = var1.__get_state_as_jax_pytree__()
self.assertEqual(
var1_value,
variables.StateVariableValue(value=1, label="var1"),
)

var2 = var1_value.__jax_pytree_state_to_new_variable__()
self.assertIsInstance(var2, variables.StateVariable)
self.assertEqual(var2.value, 1)
self.assertEqual(var2.label, "var1")
self.assertIsNot(var1, var2)

var1.__set_state_from_jax_pytree__(
variables.StateVariableValue(value=2, label="var1")
)
self.assertEqual(var1.value, 2)


if __name__ == "__main__":
absltest.main()

0 comments on commit e78349f

Please sign in to comment.