From a4e6d0d9771f400739b259ae3cd22f28b4095930 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Wed, 27 Nov 2024 06:08:14 -0800 Subject: [PATCH 1/3] Add grpcio back to deps (#4402) * Add grpcio back to deps * Update uv lock --- python/packages/autogen-core/pyproject.toml | 1 + python/uv.lock | 13 ++++++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/python/packages/autogen-core/pyproject.toml b/python/packages/autogen-core/pyproject.toml index 8f70a64ebeb4..759137ddf2f9 100644 --- a/python/packages/autogen-core/pyproject.toml +++ b/python/packages/autogen-core/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "opentelemetry-api~=1.27.0", "asyncio_atexit", "jsonref~=1.1.0", + "grpcio~=1.62.0", # TODO: update this once we have a stable version. ] [project.optional-dependencies] diff --git a/python/uv.lock b/python/uv.lock index 2301cfcbdd1a..c17a7a0a56b9 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -7,7 +7,8 @@ resolution-markers = [ "python_full_version < '3.11'", "python_full_version == '3.11.*'", "python_full_version >= '3.12' and python_full_version < '3.12.4'", - "python_full_version >= '3.12.4'", + "python_full_version < '3.13'", + "python_full_version >= '3.13'", ] [manifest] @@ -346,6 +347,7 @@ source = { editable = "packages/autogen-core" } dependencies = [ { name = "aiohttp" }, { name = "asyncio-atexit" }, + { name = "grpcio" }, { name = "jsonref" }, { name = "openai" }, { name = "opentelemetry-api" }, @@ -407,13 +409,14 @@ dev = [ requires-dist = [ { name = "aiohttp" }, { name = "asyncio-atexit" }, + { name = "grpcio", specifier = "~=1.62.0" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = "~=1.62.0" }, { name = "jsonref", specifier = "~=1.1.0" }, { name = "openai", specifier = ">=1.3" }, { name = "opentelemetry-api", specifier = "~=1.27.0" }, { name = "pillow" }, { name = "protobuf", specifier = "~=4.25.1" }, - { name = "pydantic", specifier = ">=2.0.0,<3.0.0" }, + { name = "pydantic", specifier = "<3.0.0,>=2.0.0" }, { name = "tiktoken" }, { name = "typing-extensions" }, ] @@ -562,7 +565,7 @@ requires-dist = [ { name = "pdfminer-six" }, { name = "playwright" }, { name = "puremagic" }, - { name = "pydantic", specifier = ">=2.0.0,<3.0.0" }, + { name = "pydantic", specifier = "<3.0.0,>=2.0.0" }, { name = "pydub" }, { name = "python-pptx" }, { name = "requests" }, @@ -3313,7 +3316,7 @@ name = "psycopg" version = "3.2.3" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, { name = "tzdata", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d1/ad/7ce016ae63e231575df0498d2395d15f005f05e32d3a2d439038e1bd0851/psycopg-3.2.3.tar.gz", hash = "sha256:a5764f67c27bec8bfac85764d23c534af2c27b893550377e37ce59c12aac47a2", size = 155550 } @@ -4327,7 +4330,7 @@ name = "sqlalchemy" version = "2.0.36" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "greenlet", marker = "platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'" }, + { name = "greenlet", marker = "(python_full_version < '3.13' and platform_machine == 'AMD64') or (python_full_version < '3.13' and platform_machine == 'WIN32') or (python_full_version < '3.13' and platform_machine == 'aarch64') or (python_full_version < '3.13' and platform_machine == 'amd64') or (python_full_version < '3.13' and platform_machine == 'ppc64le') or (python_full_version < '3.13' and platform_machine == 'win32') or (python_full_version < '3.13' and platform_machine == 'x86_64')" }, { name = "typing-extensions" }, ] sdist = { url = "https://files.pythonhosted.org/packages/50/65/9cbc9c4c3287bed2499e05033e207473504dc4df999ce49385fb1f8b058a/sqlalchemy-2.0.36.tar.gz", hash = "sha256:7f2767680b6d2398aea7082e45a774b2b0767b5c8d8ffb9c8b683088ea9b29c5", size = 9574485 } From bd77ccbd7b96f1a7f63a41f44e5c6ea0c99d75bb Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Wed, 27 Nov 2024 10:32:01 -0500 Subject: [PATCH 2/3] Serialize to Proto.Any for python serializer (#4404) --- .../src/autogen_core/base/_serialization.py | 25 +++++++++++------- .../autogen-core/tests/test_serialization.py | 26 +++++++++---------- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/base/_serialization.py b/python/packages/autogen-core/src/autogen_core/base/_serialization.py index 51fd531feac5..74e028641126 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_serialization.py +++ b/python/packages/autogen-core/src/autogen_core/base/_serialization.py @@ -2,6 +2,7 @@ from dataclasses import asdict, dataclass, fields from typing import Any, ClassVar, Dict, List, Protocol, Sequence, TypeVar, cast, get_args, get_origin, runtime_checkable +from google.protobuf import any_pb2 from google.protobuf.message import Message from pydantic import BaseModel @@ -149,29 +150,35 @@ def serialize(self, message: PydanticT) -> bytes: ProtobufT = TypeVar("ProtobufT", bound=Message) +# This class serializes to and from a google.protobuf.Any message that has been serialized to a string class ProtobufMessageSerializer(MessageSerializer[ProtobufT]): def __init__(self, cls: type[ProtobufT]) -> None: self.cls = cls @property def data_content_type(self) -> str: - # TODO: This should be PROTOBUF_DATA_CONTENT_TYPE. There are currently - # a couple of hard coded places where the system assumes the - # content is JSON_DATA_CONTENT_TYPE which will need to be fixed - # first. - return JSON_DATA_CONTENT_TYPE + return PROTOBUF_DATA_CONTENT_TYPE @property def type_name(self) -> str: return _type_name(self.cls) def deserialize(self, payload: bytes) -> ProtobufT: - ret = self.cls() - ret.ParseFromString(payload) - return ret + # Parse payload into a proto any + any_proto = any_pb2.Any() + any_proto.ParseFromString(payload) + + destination_message = self.cls() + + if not any_proto.Unpack(destination_message): # type: ignore + raise ValueError(f"Failed to unpack payload into {self.cls}") + + return destination_message def serialize(self, message: ProtobufT) -> bytes: - return message.SerializeToString() + any_proto = any_pb2.Any() + any_proto.Pack(message) # type: ignore + return any_proto.SerializeToString() @dataclass diff --git a/python/packages/autogen-core/tests/test_serialization.py b/python/packages/autogen-core/tests/test_serialization.py index 6b5568411f6f..f6ab2067c4d8 100644 --- a/python/packages/autogen-core/tests/test_serialization.py +++ b/python/packages/autogen-core/tests/test_serialization.py @@ -8,7 +8,11 @@ SerializationRegistry, try_get_known_serializers_for_type, ) -from autogen_core.base._serialization import DataclassJsonMessageSerializer, PydanticJsonMessageSerializer +from autogen_core.base._serialization import ( + PROTOBUF_DATA_CONTENT_TYPE, + DataclassJsonMessageSerializer, + PydanticJsonMessageSerializer, +) from autogen_core.components import Image from PIL import Image as PILImage from protos.serialization_test_pb2 import NestingProtoMessage, ProtoMessage @@ -90,12 +94,10 @@ def test_proto() -> None: message = ProtoMessage(message="hello") name = serde.type_name(message) - # TODO: should be PROTO_DATA_CONTENT_TYPE - data = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) + data = serde.serialize(message, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE) assert name == "ProtoMessage" - # TODO: assert data == stuff - deserialized = serde.deserialize(data, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) - assert deserialized == message + deserialized = serde.deserialize(data, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE) + assert deserialized.message == message.message def test_nested_proto() -> None: @@ -104,14 +106,10 @@ def test_nested_proto() -> None: message = NestingProtoMessage(message="hello", nested=ProtoMessage(message="world")) name = serde.type_name(message) - - # TODO: should be PROTO_DATA_CONTENT_TYPE - data = serde.serialize(message, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) - - # TODO: assert data == stuff - - deserialized = serde.deserialize(data, type_name=name, data_content_type=JSON_DATA_CONTENT_TYPE) - assert deserialized == message + data = serde.serialize(message, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE) + deserialized = serde.deserialize(data, type_name=name, data_content_type=PROTOBUF_DATA_CONTENT_TYPE) + assert deserialized.message == message.message + assert deserialized.nested.message == message.nested.message @dataclass From a4067f6c0af6f485e7c952b2e048586357b0b310 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Wed, 27 Nov 2024 11:32:03 -0500 Subject: [PATCH 3/3] Migrate Python distributed runtime to use cloud events for event (#4407) * Cloud event publishing * Implement cloud event receiving * impl host servicer and --- protos/agent_worker.proto | 3 +- .../autogen_core/application/_constants.py | 13 ++ .../src/autogen_core/application/_utils.py | 3 - .../application/_worker_runtime.py | 139 ++++++++++++++---- .../application/_worker_runtime_host.py | 2 +- .../_worker_runtime_host_servicer.py | 21 ++- .../application/protos/agent_worker_pb2.py | 8 +- .../application/protos/agent_worker_pb2.pyi | 16 +- .../src/autogen_core/base/__init__.py | 2 + .../autogen-core/tests/test_utils/__init__.py | 2 + .../autogen-core/tests/test_worker_runtime.py | 66 ++++++++- 11 files changed, 210 insertions(+), 65 deletions(-) create mode 100644 python/packages/autogen-core/src/autogen_core/application/_constants.py delete mode 100644 python/packages/autogen-core/src/autogen_core/application/_utils.py diff --git a/protos/agent_worker.proto b/protos/agent_worker.proto index 61b00333cd24..4d346dfecd63 100644 --- a/protos/agent_worker.proto +++ b/protos/agent_worker.proto @@ -117,12 +117,11 @@ message Message { oneof message { RpcRequest request = 1; RpcResponse response = 2; - Event event = 3; + cloudevent.CloudEvent cloudEvent = 3; RegisterAgentTypeRequest registerAgentTypeRequest = 4; RegisterAgentTypeResponse registerAgentTypeResponse = 5; AddSubscriptionRequest addSubscriptionRequest = 6; AddSubscriptionResponse addSubscriptionResponse = 7; - cloudevent.CloudEvent cloudEvent = 8; } } diff --git a/python/packages/autogen-core/src/autogen_core/application/_constants.py b/python/packages/autogen-core/src/autogen_core/application/_constants.py new file mode 100644 index 000000000000..6dab3fffdb44 --- /dev/null +++ b/python/packages/autogen-core/src/autogen_core/application/_constants.py @@ -0,0 +1,13 @@ +GRPC_IMPORT_ERROR_STR = ( + "Distributed runtime features require additional dependencies. Install them with: pip install autogen-core[grpc]" +) + +DATA_CONTENT_TYPE_ATTR = "datacontenttype" +DATA_SCHEMA_ATTR = "dataschema" +AGENT_SENDER_TYPE_ATTR = "agagentsendertype" +AGENT_SENDER_KEY_ATTR = "agagentsenderkey" +MESSAGE_KIND_ATTR = "agmsgkind" +MESSAGE_KIND_VALUE_PUBLISH = "publish" +MESSAGE_KIND_VALUE_RPC_REQUEST = "rpc_request" +MESSAGE_KIND_VALUE_RPC_RESPONSE = "rpc_response" +MESSAGE_KIND_VALUE_RPC_ERROR = "error" diff --git a/python/packages/autogen-core/src/autogen_core/application/_utils.py b/python/packages/autogen-core/src/autogen_core/application/_utils.py deleted file mode 100644 index 10fbfd1b8c8a..000000000000 --- a/python/packages/autogen-core/src/autogen_core/application/_utils.py +++ /dev/null @@ -1,3 +0,0 @@ -GRPC_IMPORT_ERROR_STR = ( - "Distributed runtime features require additional dependencies. Install them with: pip install autogen-core[grpc]" -) diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py index 0e5fb933a08e..24007fadfc7d 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py @@ -28,11 +28,15 @@ cast, ) +from google.protobuf import any_pb2 from opentelemetry.trace import TracerProvider from typing_extensions import Self, deprecated +from autogen_core.application.protos import cloudevent_pb2 + from ..base import ( JSON_DATA_CONTENT_TYPE, + PROTOBUF_DATA_CONTENT_TYPE, Agent, AgentId, AgentInstantiationContext, @@ -49,8 +53,9 @@ from ..base._serialization import MessageSerializer, SerializationRegistry from ..base._type_helpers import ChannelArgumentType from ..components import TypePrefixSubscription, TypeSubscription +from . import _constants +from ._constants import GRPC_IMPORT_ERROR_STR from ._helpers import SubscriptionManager, get_impl -from ._utils import GRPC_IMPORT_ERROR_STR from .protos import agent_worker_pb2, agent_worker_pb2_grpc from .telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata @@ -178,6 +183,7 @@ def __init__( host_address: str, tracer_provider: TracerProvider | None = None, extra_grpc_config: ChannelArgumentType | None = None, + payload_serialization_format: str = JSON_DATA_CONTENT_TYPE, ) -> None: self._host_address = host_address self._trace_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("Worker Runtime")) @@ -198,6 +204,11 @@ def __init__( self._serialization_registry = SerializationRegistry() self._extra_grpc_config = extra_grpc_config or [] + if payload_serialization_format not in {JSON_DATA_CONTENT_TYPE, PROTOBUF_DATA_CONTENT_TYPE}: + raise ValueError(f"Unsupported payload serialization format: {payload_serialization_format}") + + self._payload_serialization_format = payload_serialization_format + def start(self) -> None: """Start the runtime in a background task.""" if self._running: @@ -236,8 +247,10 @@ async def _run_read_loop(self) -> None: self._background_tasks.add(task) task.add_done_callback(self._raise_on_exception) task.add_done_callback(self._background_tasks.discard) - case "event": - task = asyncio.create_task(self._process_event(message.event)) + case "cloudEvent": + # The proto typing doesnt resolve this one + cloud_event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore + task = asyncio.create_task(self._process_event(cloud_event)) self._background_tasks.add(task) task.add_done_callback(self._raise_on_exception) task.add_done_callback(self._background_tasks.discard) @@ -257,8 +270,6 @@ async def _run_read_loop(self) -> None: task.add_done_callback(self._background_tasks.discard) case None: logger.warning("No message") - case other: - logger.error(f"Unknown message type: {other}") except Exception as e: logger.error("Error in read loop", exc_info=e) @@ -381,30 +392,64 @@ async def publish_message( if message_id is None: message_id = str(uuid.uuid4()) - # TODO: consume message_id - message_type = self._serialization_registry.type_name(message) with self._trace_helper.trace_block( "create", topic_id, parent=None, extraAttributes={"message_type": message_type} ): serialized_message = self._serialization_registry.serialize( - message, type_name=message_type, data_content_type=JSON_DATA_CONTENT_TYPE + message, type_name=message_type, data_content_type=self._payload_serialization_format ) - telemetry_metadata = get_telemetry_grpc_metadata() - runtime_message = agent_worker_pb2.Message( - event=agent_worker_pb2.Event( - topic_type=topic_id.type, - topic_source=topic_id.source, - source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None, - metadata=telemetry_metadata, - payload=agent_worker_pb2.Payload( - data_type=message_type, - data=serialized_message, - data_content_type=JSON_DATA_CONTENT_TYPE, - ), + + sender_id = sender or AgentId("unknown", "unknown") + attributes = { + _constants.DATA_CONTENT_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( + ce_string=self._payload_serialization_format + ), + _constants.DATA_SCHEMA_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(ce_string=message_type), + _constants.AGENT_SENDER_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( + ce_string=sender_id.type + ), + _constants.AGENT_SENDER_KEY_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( + ce_string=sender_id.key + ), + _constants.MESSAGE_KIND_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue( + ce_string=_constants.MESSAGE_KIND_VALUE_PUBLISH + ), + } + + # If sending JSON we fill text_data with the serialized message + # If sending Protobuf we fill proto_data with the serialized message + # TODO: add an encoding field for serializer + + if self._payload_serialization_format == JSON_DATA_CONTENT_TYPE: + runtime_message = agent_worker_pb2.Message( + cloudEvent=cloudevent_pb2.CloudEvent( + id=message_id, + spec_version="1.0", + type=topic_id.type, + source=topic_id.source, + attributes=attributes, + # TODO: use text, or proto fields appropriately + binary_data=serialized_message, + ) + ) + else: + # We need to unpack the serialized proto back into an Any + # TODO: find a way to prevent the roundtrip serialization + any_proto = any_pb2.Any() + any_proto.ParseFromString(serialized_message) + runtime_message = agent_worker_pb2.Message( + cloudEvent=cloudevent_pb2.CloudEvent( + id=message_id, + spec_version="1.0", + type=topic_id.type, + source=topic_id.source, + attributes=attributes, + proto_data=any_proto, + ) ) - ) + telemetry_metadata = get_telemetry_grpc_metadata() task = asyncio.create_task(self._send_message(runtime_message, "publish", topic_id, telemetry_metadata)) self._background_tasks.add(task) task.add_done_callback(self._raise_on_exception) @@ -523,28 +568,58 @@ async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> Non else: future.set_result(result) - async def _process_event(self, event: agent_worker_pb2.Event) -> None: - message = self._serialization_registry.deserialize( - event.payload.data, type_name=event.payload.data_type, data_content_type=event.payload.data_content_type - ) + async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: + event_attributes = event.attributes sender: AgentId | None = None - if event.HasField("source"): - sender = AgentId(event.source.type, event.source.key) - topic_id = TopicId(event.topic_type, event.topic_source) + if ( + _constants.AGENT_SENDER_TYPE_ATTR in event_attributes + and _constants.AGENT_SENDER_KEY_ATTR in event_attributes + ): + sender = AgentId( + event_attributes[_constants.AGENT_SENDER_TYPE_ATTR].ce_string, + event_attributes[_constants.AGENT_SENDER_KEY_ATTR].ce_string, + ) + topic_id = TopicId(event.type, event.source) # Get the recipients for the topic. recipients = await self._subscription_manager.get_subscribed_recipients(topic_id) + + message_content_type = event_attributes[_constants.DATA_CONTENT_TYPE_ATTR].ce_string + message_type = event_attributes[_constants.DATA_SCHEMA_ATTR].ce_string + + if message_content_type == JSON_DATA_CONTENT_TYPE: + message = self._serialization_registry.deserialize( + event.binary_data, type_name=message_type, data_content_type=message_content_type + ) + elif message_content_type == PROTOBUF_DATA_CONTENT_TYPE: + # TODO: find a way to prevent the roundtrip serialization + proto_binary_data = event.proto_data.SerializeToString() + message = self._serialization_registry.deserialize( + proto_binary_data, type_name=message_type, data_content_type=message_content_type + ) + else: + raise ValueError(f"Unsupported message content type: {message_content_type}") + + # TODO: dont read these values in the runtime + topic_type_suffix = topic_id.type.split(":", maxsplit=1)[1] if ":" in topic_id.type else "" + is_rpc = topic_type_suffix == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST + is_marked_rpc_type = ( + _constants.MESSAGE_KIND_ATTR in event_attributes + and event_attributes[_constants.MESSAGE_KIND_ATTR].ce_string == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST + ) + if is_rpc and not is_marked_rpc_type: + warnings.warn("Received RPC request with topic type suffix but not marked as RPC request.", stacklevel=2) + # Send the message to each recipient. responses: List[Awaitable[Any]] = [] for agent_id in recipients: if agent_id == sender: continue - # TODO: consume message_id message_context = MessageContext( sender=sender, topic_id=topic_id, - is_rpc=False, + is_rpc=is_rpc, cancellation_token=CancellationToken(), - message_id="NOT_DEFINED_TODO_FIX", + message_id=event.id, ) agent = await self._get_agent(agent_id) with MessageHandlerContext.populate_context(agent.id): @@ -554,7 +629,7 @@ async def send_message(agent: Agent, message_context: MessageContext) -> Any: "process", agent.id, parent=event.metadata, - extraAttributes={"message_type": event.payload.data_type}, + extraAttributes={"message_type": message_type}, ): await agent.on_message(message, ctx=message_context) diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host.py index d7fee07ff1f8..b9befce585d1 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host.py +++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host.py @@ -4,7 +4,7 @@ from typing import Optional, Sequence from ..base._type_helpers import ChannelArgumentType -from ._utils import GRPC_IMPORT_ERROR_STR +from ._constants import GRPC_IMPORT_ERROR_STR from ._worker_runtime_host_servicer import WorkerAgentRuntimeHostServicer try: diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py index 7c597bd07a8f..e24a7db3f30a 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py +++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py @@ -2,21 +2,21 @@ import logging from _collections_abc import AsyncIterator, Iterator from asyncio import Future, Task -from typing import Any, Dict, Set +from typing import Any, Dict, Set, cast from autogen_core.base._type_prefix_subscription import TypePrefixSubscription from ..base import Subscription, TopicId from ..components import TypeSubscription +from ._constants import GRPC_IMPORT_ERROR_STR from ._helpers import SubscriptionManager -from ._utils import GRPC_IMPORT_ERROR_STR try: import grpc except ImportError as e: raise ImportError(GRPC_IMPORT_ERROR_STR) from e -from .protos import agent_worker_pb2, agent_worker_pb2_grpc +from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2 logger = logging.getLogger("autogen_core") event_logger = logging.getLogger("autogen_core.events") @@ -84,7 +84,7 @@ async def _on_client_disconnect(self, client_id: int) -> None: for agent_type in agent_types: logger.info(f"Removing agent type {agent_type} from agent type to client id mapping") del self._agent_type_to_client_id[agent_type] - for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, []): + for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, set()): logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}") await self._subscription_manager.remove_subscription(sub_id) logger.info(f"Client {client_id} disconnected successfully") @@ -114,8 +114,9 @@ async def _receive_messages( self._background_tasks.add(task) task.add_done_callback(self._raise_on_exception) task.add_done_callback(self._background_tasks.discard) - case "event": - event: agent_worker_pb2.Event = message.event + case "cloudEvent": + # The proto typing doesnt resolve this one + event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore task = asyncio.create_task(self._process_event(event)) self._background_tasks.add(task) task.add_done_callback(self._raise_on_exception) @@ -138,8 +139,6 @@ async def _receive_messages( logger.warning(f"Received unexpected message type: {oneofcase}") case None: logger.warning("Received empty message") - case other: - logger.error(f"Received unexpected message: {other}") async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None: # Deliver the message to a client given the target agent type. @@ -178,8 +177,8 @@ async def _process_response(self, response: agent_worker_pb2.RpcResponse, client future = self._pending_responses[client_id].pop(response.request_id) future.set_result(response) - async def _process_event(self, event: agent_worker_pb2.Event) -> None: - topic_id = TopicId(type=event.topic_type, source=event.topic_source) + async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: + topic_id = TopicId(type=event.type, source=event.source) recipients = await self._subscription_manager.get_subscribed_recipients(topic_id) # Get the client ids of the recipients. async with self._agent_type_to_client_id_lock: @@ -192,7 +191,7 @@ async def _process_event(self, event: agent_worker_pb2.Event) -> None: logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.") # Deliver the event to clients. for client_id in client_ids: - await self._send_queues[client_id].put(agent_worker_pb2.Message(event=event)) + await self._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event)) async def _process_register_agent_type_request( self, register_agent_type_req: agent_worker_pb2.RegisterAgentTypeRequest, client_id: int diff --git a/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.py b/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.py index 8f143d770aef..319ee2c6365d 100644 --- a/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.py +++ b/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.py @@ -16,7 +16,7 @@ from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"\'\n\x07TopicId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xe4\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12$\n\x06source\x18\x03 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12 \n\x07payload\x18\x04 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"<\n\x18RegisterAgentTypeRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\"^\n\x19RegisterAgentTypeResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"\x96\x01\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x02 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"X\n\x16\x41\x64\x64SubscriptionRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12*\n\x0csubscription\x18\x02 \x01(\x0b\x32\x14.agents.Subscription\"\\\n\x17\x41\x64\x64SubscriptionResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\x9d\x01\n\nAgentState\x12!\n\x08\x61gent_id\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0c\n\x04\x65Tag\x18\x02 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x03 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x04 \x01(\tH\x00\x12*\n\nproto_data\x18\x05 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x42\x06\n\x04\x64\x61ta\"j\n\x10GetStateResponse\x12\'\n\x0b\x61gent_state\x18\x01 \x01(\x0b\x32\x12.agents.AgentState\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"B\n\x11SaveStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\xc6\x03\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x1e\n\x05\x65vent\x18\x03 \x01(\x0b\x32\r.agents.EventH\x00\x12\x44\n\x18registerAgentTypeRequest\x18\x04 \x01(\x0b\x32 .agents.RegisterAgentTypeRequestH\x00\x12\x46\n\x19registerAgentTypeResponse\x18\x05 \x01(\x0b\x32!.agents.RegisterAgentTypeResponseH\x00\x12@\n\x16\x61\x64\x64SubscriptionRequest\x18\x06 \x01(\x0b\x32\x1e.agents.AddSubscriptionRequestH\x00\x12\x42\n\x17\x61\x64\x64SubscriptionResponse\x18\x07 \x01(\x0b\x32\x1f.agents.AddSubscriptionResponseH\x00\x12,\n\ncloudEvent\x18\x08 \x01(\x0b\x32\x16.cloudevent.CloudEventH\x00\x42\t\n\x07message2\xb2\x01\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12\x35\n\x08GetState\x12\x0f.agents.AgentId\x1a\x18.agents.GetStateResponse\x12:\n\tSaveState\x12\x12.agents.AgentState\x1a\x19.agents.SaveStateResponseB!\xaa\x02\x1eMicrosoft.AutoGen.Abstractionsb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"\'\n\x07TopicId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xe4\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12$\n\x06source\x18\x03 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12 \n\x07payload\x18\x04 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"<\n\x18RegisterAgentTypeRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\"^\n\x19RegisterAgentTypeResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"\x96\x01\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x02 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"X\n\x16\x41\x64\x64SubscriptionRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12*\n\x0csubscription\x18\x02 \x01(\x0b\x32\x14.agents.Subscription\"\\\n\x17\x41\x64\x64SubscriptionResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\x9d\x01\n\nAgentState\x12!\n\x08\x61gent_id\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0c\n\x04\x65Tag\x18\x02 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x03 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x04 \x01(\tH\x00\x12*\n\nproto_data\x18\x05 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x42\x06\n\x04\x64\x61ta\"j\n\x10GetStateResponse\x12\'\n\x0b\x61gent_state\x18\x01 \x01(\x0b\x32\x12.agents.AgentState\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"B\n\x11SaveStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\xa6\x03\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12,\n\ncloudEvent\x18\x03 \x01(\x0b\x32\x16.cloudevent.CloudEventH\x00\x12\x44\n\x18registerAgentTypeRequest\x18\x04 \x01(\x0b\x32 .agents.RegisterAgentTypeRequestH\x00\x12\x46\n\x19registerAgentTypeResponse\x18\x05 \x01(\x0b\x32!.agents.RegisterAgentTypeResponseH\x00\x12@\n\x16\x61\x64\x64SubscriptionRequest\x18\x06 \x01(\x0b\x32\x1e.agents.AddSubscriptionRequestH\x00\x12\x42\n\x17\x61\x64\x64SubscriptionResponse\x18\x07 \x01(\x0b\x32\x1f.agents.AddSubscriptionResponseH\x00\x42\t\n\x07message2\xb2\x01\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12\x35\n\x08GetState\x12\x0f.agents.AgentId\x1a\x18.agents.GetStateResponse\x12:\n\tSaveState\x12\x12.agents.AgentState\x1a\x19.agents.SaveStateResponseB!\xaa\x02\x1eMicrosoft.AutoGen.Abstractionsb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -69,7 +69,7 @@ _globals['_SAVESTATERESPONSE']._serialized_start=1807 _globals['_SAVESTATERESPONSE']._serialized_end=1873 _globals['_MESSAGE']._serialized_start=1876 - _globals['_MESSAGE']._serialized_end=2330 - _globals['_AGENTRPC']._serialized_start=2333 - _globals['_AGENTRPC']._serialized_end=2511 + _globals['_MESSAGE']._serialized_end=2298 + _globals['_AGENTRPC']._serialized_start=2301 + _globals['_AGENTRPC']._serialized_end=2479 # @@protoc_insertion_point(module_scope) diff --git a/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.pyi b/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.pyi index 728bfafcc81a..79e384ab948b 100644 --- a/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.pyi +++ b/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.pyi @@ -437,18 +437,17 @@ class Message(google.protobuf.message.Message): REQUEST_FIELD_NUMBER: builtins.int RESPONSE_FIELD_NUMBER: builtins.int - EVENT_FIELD_NUMBER: builtins.int + CLOUDEVENT_FIELD_NUMBER: builtins.int REGISTERAGENTTYPEREQUEST_FIELD_NUMBER: builtins.int REGISTERAGENTTYPERESPONSE_FIELD_NUMBER: builtins.int ADDSUBSCRIPTIONREQUEST_FIELD_NUMBER: builtins.int ADDSUBSCRIPTIONRESPONSE_FIELD_NUMBER: builtins.int - CLOUDEVENT_FIELD_NUMBER: builtins.int @property def request(self) -> global___RpcRequest: ... @property def response(self) -> global___RpcResponse: ... @property - def event(self) -> global___Event: ... + def cloudEvent(self) -> cloudevent_pb2.CloudEvent: ... @property def registerAgentTypeRequest(self) -> global___RegisterAgentTypeRequest: ... @property @@ -457,22 +456,19 @@ class Message(google.protobuf.message.Message): def addSubscriptionRequest(self) -> global___AddSubscriptionRequest: ... @property def addSubscriptionResponse(self) -> global___AddSubscriptionResponse: ... - @property - def cloudEvent(self) -> cloudevent_pb2.CloudEvent: ... def __init__( self, *, request: global___RpcRequest | None = ..., response: global___RpcResponse | None = ..., - event: global___Event | None = ..., + cloudEvent: cloudevent_pb2.CloudEvent | None = ..., registerAgentTypeRequest: global___RegisterAgentTypeRequest | None = ..., registerAgentTypeResponse: global___RegisterAgentTypeResponse | None = ..., addSubscriptionRequest: global___AddSubscriptionRequest | None = ..., addSubscriptionResponse: global___AddSubscriptionResponse | None = ..., - cloudEvent: cloudevent_pb2.CloudEvent | None = ..., ) -> None: ... - def HasField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "event", b"event", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "event", b"event", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "event", "registerAgentTypeRequest", "registerAgentTypeResponse", "addSubscriptionRequest", "addSubscriptionResponse", "cloudEvent"] | None: ... + def HasField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "cloudEvent", "registerAgentTypeRequest", "registerAgentTypeResponse", "addSubscriptionRequest", "addSubscriptionResponse"] | None: ... global___Message = Message diff --git a/python/packages/autogen-core/src/autogen_core/base/__init__.py b/python/packages/autogen-core/src/autogen_core/base/__init__.py index e4463e7eeccf..8d95083ec883 100644 --- a/python/packages/autogen-core/src/autogen_core/base/__init__.py +++ b/python/packages/autogen-core/src/autogen_core/base/__init__.py @@ -16,6 +16,7 @@ from ._message_handler_context import MessageHandlerContext from ._serialization import ( JSON_DATA_CONTENT_TYPE, + PROTOBUF_DATA_CONTENT_TYPE, MessageSerializer, SerializationRegistry, UnknownPayload, @@ -43,6 +44,7 @@ "SubscriptionInstantiationContext", "MessageHandlerContext", "JSON_DATA_CONTENT_TYPE", + "PROTOBUF_DATA_CONTENT_TYPE", "MessageSerializer", "try_get_known_serializers_for_type", "UnknownPayload", diff --git a/python/packages/autogen-core/tests/test_utils/__init__.py b/python/packages/autogen-core/tests/test_utils/__init__.py index 92096550cd98..5de7519fc49b 100644 --- a/python/packages/autogen-core/tests/test_utils/__init__.py +++ b/python/packages/autogen-core/tests/test_utils/__init__.py @@ -23,12 +23,14 @@ class LoopbackAgent(RoutedAgent): def __init__(self) -> None: super().__init__("A loop back agent.") self.num_calls = 0 + self.received_messages: list[Any] = [] @message_handler async def on_new_message( self, message: MessageType | ContentMessage, ctx: MessageContext ) -> MessageType | ContentMessage: self.num_calls += 1 + self.received_messages.append(message) return message diff --git a/python/packages/autogen-core/tests/test_worker_runtime.py b/python/packages/autogen-core/tests/test_worker_runtime.py index 26c95dc01860..2a7d3acdc38f 100644 --- a/python/packages/autogen-core/tests/test_worker_runtime.py +++ b/python/packages/autogen-core/tests/test_worker_runtime.py @@ -1,22 +1,28 @@ import asyncio import logging import os -from typing import List +from typing import Any, List import pytest from autogen_core.application import WorkerAgentRuntime, WorkerAgentRuntimeHost from autogen_core.base import ( + PROTOBUF_DATA_CONTENT_TYPE, AgentId, AgentType, + MessageContext, + Subscription, TopicId, try_get_known_serializers_for_type, ) -from autogen_core.base._subscription import Subscription from autogen_core.components import ( DefaultTopicId, + RoutedAgent, TypeSubscription, + default_subscription, + event, type_subscription, ) +from protos.serialization_test_pb2 import ProtoMessage from test_utils import ( CascadingAgent, CascadingMessageType, @@ -401,6 +407,62 @@ async def get_subscribed_recipients() -> List[AgentId]: await worker1_2.stop() +@default_subscription +class ProtoReceivingAgent(RoutedAgent): + def __init__(self) -> None: + super().__init__("A loop back agent.") + self.num_calls = 0 + self.received_messages: list[Any] = [] + + @event + async def on_new_message(self, message: ProtoMessage, ctx: MessageContext) -> None: + self.num_calls += 1 + self.received_messages.append(message) + + +@pytest.mark.asyncio +async def test_proto_payloads() -> None: + host_address = "localhost:50057" + host = WorkerAgentRuntimeHost(address=host_address) + host.start() + receiver_runtime = WorkerAgentRuntime( + host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE + ) + receiver_runtime.start() + publisher_runtime = WorkerAgentRuntime( + host_address=host_address, payload_serialization_format=PROTOBUF_DATA_CONTENT_TYPE + ) + publisher_runtime.add_message_serializer(try_get_known_serializers_for_type(ProtoMessage)) + publisher_runtime.start() + + await ProtoReceivingAgent.register(receiver_runtime, "name", ProtoReceivingAgent) + + await publisher_runtime.publish_message(ProtoMessage(message="Hello!"), topic_id=DefaultTopicId()) + + await asyncio.sleep(2) + + # Agent in default namespace should have received the message + long_running_agent = await receiver_runtime.try_get_underlying_agent_instance( + AgentId("name", "default"), type=ProtoReceivingAgent + ) + assert long_running_agent.num_calls == 1 + assert long_running_agent.received_messages[0].message == "Hello!" + + # Agent in other namespace should not have received the message + other_long_running_agent = await receiver_runtime.try_get_underlying_agent_instance( + AgentId("name", key="other"), type=ProtoReceivingAgent + ) + assert other_long_running_agent.num_calls == 0 + assert len(other_long_running_agent.received_messages) == 0 + + await receiver_runtime.stop() + await publisher_runtime.stop() + await host.stop() + + +# TODO add tests for failure to deserialize + + @pytest.mark.asyncio async def test_grpc_max_message_size() -> None: default_max_size = 2**22