Skip to content

Commit

Permalink
Support pickle protocol for DataclassArray.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675107466
  • Loading branch information
Jan Hosang authored and The dataclass_array Authors committed Sep 16, 2024
1 parent ebe7df4 commit a032514
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 3 deletions.
38 changes: 35 additions & 3 deletions dataclass_array/array_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,21 @@ def tree_unflatten(
array_field_values: list[DcOrArray],
) -> _DcT:
"""`jax.tree_utils` support."""
return cls._tree_unflatten(
metadata=metadata,
array_field_values=array_field_values,
constructor=cls,
)

@classmethod
def _tree_unflatten(
cls: Type[_DcT],
*,
metadata: _TreeMetadata,
array_field_values: list[DcOrArray],
constructor: Callable[..., _DcT],
) -> _DcT:
"""Initialize a model after serialization."""
array_field_kwargs = dict(
zip(
metadata.array_field_names,
Expand All @@ -813,7 +828,7 @@ def tree_unflatten(
else:
non_init_fields[k] = v

self = cls(**array_field_kwargs, **init_fields)
self = constructor(**array_field_kwargs, **init_fields)
# Currently it's not clear how to handle non-init fields so raise an error
if non_init_fields:
if set(non_init_fields) - self.__dca_non_init_fields__:
Expand Down Expand Up @@ -844,6 +859,23 @@ def __tf_unflatten__(
) -> _DcT:
return cls.tree_unflatten(metadata, components)

def __getstate__(self) -> tuple[_TreeMetadata, tuple[DcOrArray, ...]]:
components, metadata = self._tree_flatten()
return metadata, components

def __setstate__(self, state: tuple[_TreeMetadata, list[DcOrArray]]) -> None:

def constructor(**kwargs):
self.__init__(**kwargs)
return self

metadata, components = state
type(self)._tree_unflatten(
metadata=metadata,
array_field_values=components,
constructor=constructor,
)

def _setattr(self, name: str, value: Any) -> None:
"""Like setattr, but support `frozen` dataclasses."""
object.__setattr__(self, name, value)
Expand Down Expand Up @@ -1023,8 +1055,8 @@ class _ArrayFieldMetadata:
Attributes:
inner_shape_non_static: Inner shape. Can contain non-static dims (e.g.
`(None, 3)`)
dtype: Type of the array. Can be `enp.dtypes.DType` or
`dca.DataclassArray` for nested arrays.
dtype: Type of the array. Can be `enp.dtypes.DType` or `dca.DataclassArray`
for nested arrays.
"""

inner_shape_non_static: DynamicShape
Expand Down
13 changes: 13 additions & 0 deletions dataclass_array/array_dataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import pickle
from typing import Any

import dataclass_array as dca
Expand Down Expand Up @@ -769,3 +770,15 @@ def test_class_getitem():
assert Point[''] != Point
assert Point[''] != Point['h w']
assert Point[''] != Isometrie['']


@enp.testing.parametrize_xnp()
@pytest.mark.parametrize('batch_shape', [(), (1, 3)])
def test_dataclass_pickle_unpickle(xnp: enp.NpModule, batch_shape: Shape):
expected = Point(
x=xnp.zeros(batch_shape, dtype=xnp.float32),
y=xnp.zeros(batch_shape, dtype=xnp.float32),
)
buffer = pickle.dumps(expected)
actual = pickle.loads(buffer)
dca.testing.assert_array_equal(actual, expected)

0 comments on commit a032514

Please sign in to comment.