Skip to content

Commit

Permalink
Serialize to Proto.Any for python serializer (#4404)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits authored Nov 27, 2024
1 parent a4e6d0d commit bd77ccb
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import asdict, dataclass, fields
from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable

from google.protobuf import any_pb2
from google.protobuf.message import Message
from pydantic import BaseModel

Expand Down Expand Up @@ -149,29 +150,35 @@ def serialize(self, message: PydanticT) -> bytes:
ProtobufT = TypeVar("ProtobufT", bound=Message)


# This class serializes to and from a google.protobuf.Any message that has been serialized to a string
class ProtobufMessageSerializer(MessageSerializer[ProtobufT]):
def __init__(self, cls: type[ProtobufT]) -> None:
self.cls = cls

@property
def data_content_type(self) -> str:
# TODO: This should be PROTOBUF_DATA_CONTENT_TYPE. There are currently
# a couple of hard coded places where the system assumes the
# content is JSON_DATA_CONTENT_TYPE which will need to be fixed
# first.
return JSON_DATA_CONTENT_TYPE
return PROTOBUF_DATA_CONTENT_TYPE

@property
def type_name(self) -> str:
return _type_name(self.cls)

def deserialize(self, payload: bytes) -> ProtobufT:
ret = self.cls()
ret.ParseFromString(payload)
return ret
# Parse payload into a proto any
any_proto = any_pb2.Any()
any_proto.ParseFromString(payload)

destination_message = self.cls()

if not any_proto.Unpack(destination_message): # type: ignore
raise ValueError(f"Failed to unpack payload into {self.cls}")

return destination_message

def serialize(self, message: ProtobufT) -> bytes:
return message.SerializeToString()
any_proto = any_pb2.Any()
any_proto.Pack(message) # type: ignore
return any_proto.SerializeToString()


@dataclass
Expand Down
26 changes: 12 additions & 14 deletions python/packages/autogen-core/tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
SerializationRegistry,
try_get_known_serializers_for_type,
)
from autogen_core.base._serialization import DataclassJsonMessageSerializer, PydanticJsonMessageSerializer
from autogen_core.base._serialization import (
PROTOBUF_DATA_CONTENT_TYPE,
DataclassJsonMessageSerializer,
PydanticJsonMessageSerializer,
)
from autogen_core.components import Image
from PIL import Image as PILImage
from protos.serialization_test_pb2 import NestingProtoMessage, ProtoMessage
Expand Down Expand Up @@ -90,12 +94,10 @@ def test_proto() -> None:

message = ProtoMessage(message="hello")
name = serde.type_name(message)
# TODO: should be PROTO_DATA_CONTENT_TYPE
data = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
data = serde.serialize(message, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
assert name == "ProtoMessage"
# TODO: assert data == stuff
deserialized = serde.deserialize(data, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
assert deserialized == message
deserialized = serde.deserialize(data, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
assert deserialized.message == message.message


def test_nested_proto() -> None:
Expand All @@ -104,14 +106,10 @@ def test_nested_proto() -> None:

message = NestingProtoMessage(message="hello", nested=ProtoMessage(message="world"))
name = serde.type_name(message)

# TODO: should be PROTO_DATA_CONTENT_TYPE
data = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)

# TODO: assert data == stuff

deserialized = serde.deserialize(data, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE)
assert deserialized == message
data = serde.serialize(message, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
deserialized = serde.deserialize(data, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE)
assert deserialized.message == message.message
assert deserialized.nested.message == message.nested.message


@dataclass
Expand Down

0 comments on commit bd77ccb

Please sign in to comment.