Skip to content

Commit

Permalink
Merge branch 'main' into agentchat-cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
rysweet authored Nov 27, 2024
2 parents 40b962f + a4067f6 commit b506c3b
Show file tree
Hide file tree
Showing 15 changed files with 247 additions and 93 deletions.
3 changes: 1 addition & 2 deletions protos/agent_worker.proto
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,11 @@ message Message {
oneof message {
RpcRequest request = 1;
RpcResponse response = 2;
Event event = 3;
cloudevent.CloudEvent cloudEvent = 3;
RegisterAgentTypeRequest registerAgentTypeRequest = 4;
RegisterAgentTypeResponse registerAgentTypeResponse = 5;
AddSubscriptionRequest addSubscriptionRequest = 6;
AddSubscriptionResponse addSubscriptionResponse = 7;
cloudevent.CloudEvent cloudEvent = 8;
}
}

1 change: 1 addition & 0 deletions python/packages/autogen-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"opentelemetry-api~=1.27.0",
"asyncio_atexit",
"jsonref~=1.1.0",
"grpcio~=1.62.0", # TODO: update this once we have a stable version.
]

[project.optional-dependencies]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
GRPC_IMPORT_ERROR_STR = (
"Distributed runtime features require additional dependencies. Install them with: pip install autogen-core[grpc]"
)

DATA_CONTENT_TYPE_ATTR = "datacontenttype"
DATA_SCHEMA_ATTR = "dataschema"
AGENT_SENDER_TYPE_ATTR = "agagentsendertype"
AGENT_SENDER_KEY_ATTR = "agagentsenderkey"
MESSAGE_KIND_ATTR = "agmsgkind"
MESSAGE_KIND_VALUE_PUBLISH = "publish"
MESSAGE_KIND_VALUE_RPC_REQUEST = "rpc_request"
MESSAGE_KIND_VALUE_RPC_RESPONSE = "rpc_response"
MESSAGE_KIND_VALUE_RPC_ERROR = "error"

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@
cast,
)

from google.protobuf import any_pb2
from opentelemetry.trace import TracerProvider
from typing_extensions import Self, deprecated

from autogen_core.application.protos import cloudevent_pb2

from ..base import (
JSON_DATA_CONTENT_TYPE,
PROTOBUF_DATA_CONTENT_TYPE,
Agent,
AgentId,
AgentInstantiationContext,
Expand All @@ -49,8 +53,9 @@
from ..base._serialization import MessageSerializer, SerializationRegistry
from ..base._type_helpers import ChannelArgumentType
from ..components import TypePrefixSubscription, TypeSubscription
from . import _constants
from ._constants import GRPC_IMPORT_ERROR_STR
from ._helpers import SubscriptionManager, get_impl
from ._utils import GRPC_IMPORT_ERROR_STR
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
from .telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata

Expand Down Expand Up @@ -178,6 +183,7 @@ def __init__(
host_address: str,
tracer_provider: TracerProvider | None = None,
extra_grpc_config: ChannelArgumentType | None = None,
payload_serialization_format: str = JSON_DATA_CONTENT_TYPE,
) -> None:
self._host_address = host_address
self._trace_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("Worker Runtime"))
Expand All @@ -198,6 +204,11 @@ def __init__(
self._serialization_registry = SerializationRegistry()
self._extra_grpc_config = extra_grpc_config or []

if payload_serialization_format not in {JSON_DATA_CONTENT_TYPE, PROTOBUF_DATA_CONTENT_TYPE}:
raise ValueError(f"Unsupported payload serialization format: {payload_serialization_format}")

self._payload_serialization_format = payload_serialization_format

def start(self) -> None:
"""Start the runtime in a background task."""
if self._running:
Expand Down Expand Up @@ -236,8 +247,10 @@ async def _run_read_loop(self) -> None:
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "event":
task = asyncio.create_task(self._process_event(message.event))
case "cloudEvent":
# The proto typing doesnt resolve this one
cloud_event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore
task = asyncio.create_task(self._process_event(cloud_event))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
Expand All @@ -257,8 +270,6 @@ async def _run_read_loop(self) -> None:
task.add_done_callback(self._background_tasks.discard)
case None:
logger.warning("No message")
case other:
logger.error(f"Unknown message type: {other}")
except Exception as e:
logger.error("Error in read loop", exc_info=e)

Expand Down Expand Up @@ -381,30 +392,64 @@ async def publish_message(
if message_id is None:
message_id = str(uuid.uuid4())

# TODO: consume message_id

message_type = self._serialization_registry.type_name(message)
with self._trace_helper.trace_block(
"create", topic_id, parent=None, extraAttributes={"message_type": message_type}
):
serialized_message = self._serialization_registry.serialize(
message, type_name=message_type, data_content_type=JSON_DATA_CONTENT_TYPE
message, type_name=message_type, data_content_type=self._payload_serialization_format
)
telemetry_metadata = get_telemetry_grpc_metadata()
runtime_message = agent_worker_pb2.Message(
event=agent_worker_pb2.Event(
topic_type=topic_id.type,
topic_source=topic_id.source,
source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None,
metadata=telemetry_metadata,
payload=agent_worker_pb2.Payload(
data_type=message_type,
data=serialized_message,
data_content_type=JSON_DATA_CONTENT_TYPE,
),

sender_id = sender or AgentId("unknown", "unknown")
attributes = {
_constants.DATA_CONTENT_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=self._payload_serialization_format
),
_constants.DATA_SCHEMA_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(ce_string=message_type),
_constants.AGENT_SENDER_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=sender_id.type
),
_constants.AGENT_SENDER_KEY_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=sender_id.key
),
_constants.MESSAGE_KIND_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
ce_string=_constants.MESSAGE_KIND_VALUE_PUBLISH
),
}

# If sending JSON we fill text_data with the serialized message
# If sending Protobuf we fill proto_data with the serialized message
# TODO: add an encoding field for serializer

if self._payload_serialization_format == JSON_DATA_CONTENT_TYPE:
runtime_message = agent_worker_pb2.Message(
cloudEvent=cloudevent_pb2.CloudEvent(
id=message_id,
spec_version="1.0",
type=topic_id.type,
source=topic_id.source,
attributes=attributes,
# TODO: use text, or proto fields appropriately
binary_data=serialized_message,
)
)
else:
# We need to unpack the serialized proto back into an Any
# TODO: find a way to prevent the roundtrip serialization
any_proto = any_pb2.Any()
any_proto.ParseFromString(serialized_message)
runtime_message = agent_worker_pb2.Message(
cloudEvent=cloudevent_pb2.CloudEvent(
id=message_id,
spec_version="1.0",
type=topic_id.type,
source=topic_id.source,
attributes=attributes,
proto_data=any_proto,
)
)
)

telemetry_metadata = get_telemetry_grpc_metadata()
task = asyncio.create_task(self._send_message(runtime_message, "publish", topic_id, telemetry_metadata))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
Expand Down Expand Up @@ -523,28 +568,58 @@ async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> Non
else:
future.set_result(result)

async def _process_event(self, event: agent_worker_pb2.Event) -> None:
message = self._serialization_registry.deserialize(
event.payload.data, type_name=event.payload.data_type, data_content_type=event.payload.data_content_type
)
async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None:
event_attributes = event.attributes
sender: AgentId | None = None
if event.HasField("source"):
sender = AgentId(event.source.type, event.source.key)
topic_id = TopicId(event.topic_type, event.topic_source)
if (
_constants.AGENT_SENDER_TYPE_ATTR in event_attributes
and _constants.AGENT_SENDER_KEY_ATTR in event_attributes
):
sender = AgentId(
event_attributes[_constants.AGENT_SENDER_TYPE_ATTR].ce_string,
event_attributes[_constants.AGENT_SENDER_KEY_ATTR].ce_string,
)
topic_id = TopicId(event.type, event.source)
# Get the recipients for the topic.
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)

message_content_type = event_attributes[_constants.DATA_CONTENT_TYPE_ATTR].ce_string
message_type = event_attributes[_constants.DATA_SCHEMA_ATTR].ce_string

if message_content_type == JSON_DATA_CONTENT_TYPE:
message = self._serialization_registry.deserialize(
event.binary_data, type_name=message_type, data_content_type=message_content_type
)
elif message_content_type == PROTOBUF_DATA_CONTENT_TYPE:
# TODO: find a way to prevent the roundtrip serialization
proto_binary_data = event.proto_data.SerializeToString()
message = self._serialization_registry.deserialize(
proto_binary_data, type_name=message_type, data_content_type=message_content_type
)
else:
raise ValueError(f"Unsupported message content type: {message_content_type}")

# TODO: dont read these values in the runtime
topic_type_suffix = topic_id.type.split(":", maxsplit=1)[1] if ":" in topic_id.type else ""
is_rpc = topic_type_suffix == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST
is_marked_rpc_type = (
_constants.MESSAGE_KIND_ATTR in event_attributes
and event_attributes[_constants.MESSAGE_KIND_ATTR].ce_string == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST
)
if is_rpc and not is_marked_rpc_type:
warnings.warn("Received RPC request with topic type suffix but not marked as RPC request.", stacklevel=2)

# Send the message to each recipient.
responses: List[Awaitable[Any]] = []
for agent_id in recipients:
if agent_id == sender:
continue
# TODO: consume message_id
message_context = MessageContext(
sender=sender,
topic_id=topic_id,
is_rpc=False,
is_rpc=is_rpc,
cancellation_token=CancellationToken(),
message_id="NOT_DEFINED_TODO_FIX",
message_id=event.id,
)
agent = await self._get_agent(agent_id)
with MessageHandlerContext.populate_context(agent.id):
Expand All @@ -554,7 +629,7 @@ async def send_message(agent: Agent, message_context: MessageContext) -> Any:
"process",
agent.id,
parent=event.metadata,
extraAttributes={"message_type": event.payload.data_type},
extraAttributes={"message_type": message_type},
):
await agent.on_message(message, ctx=message_context)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Optional, Sequence

from ..base._type_helpers import ChannelArgumentType
from ._utils import GRPC_IMPORT_ERROR_STR
from ._constants import GRPC_IMPORT_ERROR_STR
from ._worker_runtime_host_servicer import WorkerAgentRuntimeHostServicer

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@
import logging
from _collections_abc import AsyncIterator, Iterator
from asyncio import Future, Task
from typing import Any, Dict, Set
from typing import Any, Dict, Set, cast

from autogen_core.base._type_prefix_subscription import TypePrefixSubscription

from ..base import Subscription, TopicId
from ..components import TypeSubscription
from ._constants import GRPC_IMPORT_ERROR_STR
from ._helpers import SubscriptionManager
from ._utils import GRPC_IMPORT_ERROR_STR

try:
import grpc
except ImportError as e:
raise ImportError(GRPC_IMPORT_ERROR_STR) from e

from .protos import agent_worker_pb2, agent_worker_pb2_grpc
from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2

logger = logging.getLogger("autogen_core")
event_logger = logging.getLogger("autogen_core.events")
Expand Down Expand Up @@ -84,7 +84,7 @@ async def _on_client_disconnect(self, client_id: int) -> None:
for agent_type in agent_types:
logger.info(f"Removing agent type {agent_type} from agent type to client id mapping")
del self._agent_type_to_client_id[agent_type]
for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, []):
for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, set()):
logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}")
await self._subscription_manager.remove_subscription(sub_id)
logger.info(f"Client {client_id} disconnected successfully")
Expand Down Expand Up @@ -114,8 +114,9 @@ async def _receive_messages(
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "event":
event: agent_worker_pb2.Event = message.event
case "cloudEvent":
# The proto typing doesnt resolve this one
event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore
task = asyncio.create_task(self._process_event(event))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
Expand All @@ -138,8 +139,6 @@ async def _receive_messages(
logger.warning(f"Received unexpected message type: {oneofcase}")
case None:
logger.warning("Received empty message")
case other:
logger.error(f"Received unexpected message: {other}")

async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None:
# Deliver the message to a client given the target agent type.
Expand Down Expand Up @@ -178,8 +177,8 @@ async def _process_response(self, response: agent_worker_pb2.RpcResponse, client
future = self._pending_responses[client_id].pop(response.request_id)
future.set_result(response)

async def _process_event(self, event: agent_worker_pb2.Event) -> None:
topic_id = TopicId(type=event.topic_type, source=event.topic_source)
async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None:
topic_id = TopicId(type=event.type, source=event.source)
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
# Get the client ids of the recipients.
async with self._agent_type_to_client_id_lock:
Expand All @@ -192,7 +191,7 @@ async def _process_event(self, event: agent_worker_pb2.Event) -> None:
logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.")
# Deliver the event to clients.
for client_id in client_ids:
await self._send_queues[client_id].put(agent_worker_pb2.Message(event=event))
await self._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event))

async def _process_register_agent_type_request(
self, register_agent_type_req: agent_worker_pb2.RegisterAgentTypeRequest, client_id: int
Expand Down
Loading

0 comments on commit b506c3b

Please sign in to comment.