Skip to content

Commit

Permalink
Support pickle protocol for DataclassArray.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675107466
  • Loading branch information
Jan Hosang authored and The dataclass_array Authors committed Sep 16, 2024
1 parent ebe7df4 commit b40cf75
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 30 deletions.
84 changes: 54 additions & 30 deletions dataclass_array/array_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,40 +798,28 @@ def tree_unflatten(
array_field_values: list[DcOrArray],
) -> _DcT:
"""`jax.tree_utils` support."""
array_field_kwargs = dict(
zip(
metadata.array_field_names,
array_field_values,
)
)
init_fields = {}
non_init_fields = {}
fields = {f.name: f for f in dataclasses.fields(cls)} # pytype: disable=wrong-arg-types # re-none
for k, v in metadata.non_array_field_kwargs.items():
if fields[k].init:
init_fields[k] = v
else:
non_init_fields[k] = v

self = cls(**array_field_kwargs, **init_fields)
# Currently it's not clear how to handle non-init fields so raise an error
if non_init_fields:
if set(non_init_fields) - self.__dca_non_init_fields__:
raise ValueError(
'`dca.DataclassArray` field with init=False should be explicitly '
'specified in `__dca_non_init_fields__` for them to be '
'propagated by `tree_map`.'
)
# TODO(py310): Delete once dataclass supports `kw_only=True`
for k, v in non_init_fields.items():
self._setattr(k, v) # pylint: disable=protected-access
return self
return _tree_unflatten(metadata, array_field_values, cls, cls)

def _tree_flatten(self) -> tuple[tuple[DcOrArray, ...], _TreeMetadata]:
components, metadata = self.tree_flatten_with_keys()
components = tuple(v for _, v in components)
return components, metadata

def __getstate__(self) -> tuple[_TreeMetadata, tuple[DcOrArray, ...]]:
components, metadata = self._tree_flatten()
return metadata, components

def __setstate__(
self, state: tuple[_TreeMetadata, list[DcOrArray]]
) -> None:

def constructor(**kwargs):
self.__init__(**kwargs)
return self

metadata, components = state
_tree_unflatten(metadata, components, constructor, type(self))

def __tf_flatten__(self) -> tuple[_TreeMetadata, tuple[DcOrArray, ...]]:
components, metadata = self._tree_flatten()
return metadata, components
Expand All @@ -858,6 +846,42 @@ def assert_same_xnp(self, x: Union[Array[...], DataclassArray]) -> None:
)


def _tree_unflatten(
metadata: _TreeMetadata,
array_field_values: list[DcOrArray],
constructor: Callable,
cls: Type[_DcT],
) -> _DcT:
array_field_kwargs = dict(
zip(
metadata.array_field_names,
array_field_values,
)
)
init_fields = {}
non_init_fields = {}
fields = {f.name: f for f in dataclasses.fields(cls)} # pytype: disable=wrong-arg-types # re-none
for k, v in metadata.non_array_field_kwargs.items():
if fields[k].init:
init_fields[k] = v
else:
non_init_fields[k] = v

self = constructor(**array_field_kwargs, **init_fields)
# Currently it's not clear how to handle non-init fields so raise an error
if non_init_fields:
if set(non_init_fields) - self.__dca_non_init_fields__:
raise ValueError(
'`dca.DataclassArray` field with init=False should be explicitly '
'specified in `__dca_non_init_fields__` for them to be '
'propagated by `tree_map`.'
)
# TODO(py310): Delete once dataclass supports `kw_only=True`
for k, v in non_init_fields.items():
self._setattr(k, v) # pylint: disable=protected-access
return self


def _init_cls(self: DataclassArray) -> None:
"""Setup the class the first time the instance is called.
Expand Down Expand Up @@ -1023,8 +1047,8 @@ class _ArrayFieldMetadata:
Attributes:
inner_shape_non_static: Inner shape. Can contain non-static dims (e.g.
`(None, 3)`)
dtype: Type of the array. Can be `enp.dtypes.DType` or
`dca.DataclassArray` for nested arrays.
dtype: Type of the array. Can be `enp.dtypes.DType` or `dca.DataclassArray`
for nested arrays.
"""

inner_shape_non_static: DynamicShape
Expand Down
19 changes: 19 additions & 0 deletions dataclass_array/array_dataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

from __future__ import annotations

import io
import pickle
from typing import Any

import dataclass_array as dca
Expand Down Expand Up @@ -769,3 +771,20 @@ def test_class_getitem():
assert Point[''] != Point
assert Point[''] != Point['h w']
assert Point[''] != Isometrie['']


@enp.testing.parametrize_xnp()
@pytest.mark.parametrize('batch_shape', [(), (1, 3)])
def test_dataclass_pickle_unpickle(xnp: enp.NpModule, batch_shape: Shape):
expected = Point(
x=xnp.zeros(batch_shape, dtype=xnp.float32),
y=xnp.zeros(batch_shape, dtype=xnp.float32),
)
buffer = io.BytesIO()
pickle.dump(expected, buffer)
buffer.seek(0)
actual = pickle.load(buffer)
assert actual.xnp is expected.xnp
xnp = actual.xnp
assert (actual.x == expected.x))
assert xnp.all(xnp.equal(actual.y == expected.y))

0 comments on commit b40cf75

Please sign in to comment.