Skip to content

Commit

Permalink
Add type prefix subscription
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Nov 26, 2024
1 parent 5aecb56 commit ecce47e
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 59 deletions.
2 changes: 1 addition & 1 deletion docs/design/04 - Agent and Topic ID Specs.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ This document describes the structure, constraints, and behavior of Agent IDs an

- Type: `string`
- Description: Topic type is usually defined by application code to mark the type of messages the topic is for.
- Constraints: UTF8 and only contain alphanumeric letters (a-z) and (0-9), or underscores (\_). A valid identifier cannot start with a number, or contain any spaces.
- Constraints: UTF8 and only contain alphanumeric letters (a-z) and (0-9), ':', '=', or underscores (\_). A valid identifier cannot start with a number, or contain any spaces.
- Examples:
- `GitHub_Issues`

Expand Down
6 changes: 6 additions & 0 deletions protos/agent_worker.proto
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,15 @@ message TypeSubscription {
string agent_type = 2;
}

message TypePrefixSubscription {
string topic_type_prefix = 1;
string agent_type = 2;
}

message Subscription {
oneof subscription {
TypeSubscription typeSubscription = 1;
TypePrefixSubscription typePrefixSubscription = 2;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
)
from ..base._serialization import MessageSerializer, SerializationRegistry
from ..base._type_helpers import ChannelArgumentType
from ..components import TypeSubscription
from ..components import TypePrefixSubscription, TypeSubscription
from ._helpers import SubscriptionManager, get_impl
from ._utils import GRPC_IMPORT_ERROR_STR
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
Expand Down Expand Up @@ -705,27 +705,44 @@ async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = A
async def add_subscription(self, subscription: Subscription) -> None:
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
if not isinstance(subscription, TypeSubscription):
raise ValueError("Only TypeSubscription is supported.")
# Add to local subscription manager.
await self._subscription_manager.add_subscription(subscription)

# Create a future for the subscription response.
future = asyncio.get_event_loop().create_future()
request_id = await self._get_new_request_id()

match subscription:
case TypeSubscription(topic_type=topic_type, agent_type=agent_type):
message = agent_worker_pb2.Message(
addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest(
request_id=request_id,
subscription=agent_worker_pb2.Subscription(
typeSubscription=agent_worker_pb2.TypeSubscription(
topic_type=topic_type, agent_type=agent_type
)
),
)
)
case TypePrefixSubscription(topic_type_prefix=topic_type_prefix, agent_type=agent_type):
message = agent_worker_pb2.Message(
addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest(
request_id=request_id,
subscription=agent_worker_pb2.Subscription(
typePrefixSubscription=agent_worker_pb2.TypePrefixSubscription(
topic_type_prefix=topic_type_prefix, agent_type=agent_type
)
),
)
)
case _:
raise ValueError("Unsupported subscription type.")

# Add the future to the pending requests.
self._pending_requests[request_id] = future

# Add to local subscription manager.
await self._subscription_manager.add_subscription(subscription)

# Send the subscription to the host.
message = agent_worker_pb2.Message(
addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest(
request_id=request_id,
subscription=agent_worker_pb2.Subscription(
typeSubscription=agent_worker_pb2.TypeSubscription(
topic_type=subscription.topic_type, agent_type=subscription.agent_type
)
),
)
)
await self._host_connection.send(message)

# Wait for the subscription response.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from asyncio import Future, Task
from typing import Any, Dict, Set

from ..base import TopicId
from autogen_core.components._type_prefix_subscription import TypePrefixSubscription

from ..base import Subscription, TopicId
from ..components import TypeSubscription
from ._helpers import SubscriptionManager
from ._utils import GRPC_IMPORT_ERROR_STR
Expand Down Expand Up @@ -221,34 +223,46 @@ async def _process_add_subscription_request(
self, add_subscription_req: agent_worker_pb2.AddSubscriptionRequest, client_id: int
) -> None:
oneofcase = add_subscription_req.subscription.WhichOneof("subscription")
subscription: Subscription | None = None
match oneofcase:
case "typeSubscription":
type_subscription_msg: agent_worker_pb2.TypeSubscription = (
add_subscription_req.subscription.typeSubscription
)
type_subscription = TypeSubscription(
subscription = TypeSubscription(
topic_type=type_subscription_msg.topic_type, agent_type=type_subscription_msg.agent_type
)
try:
await self._subscription_manager.add_subscription(type_subscription)
subscription_ids = self._client_id_to_subscription_id_mapping.setdefault(client_id, set())
subscription_ids.add(type_subscription.id)
success = True
error = None
except ValueError as e:
success = False
error = str(e)
# Send a response back to the client.
await self._send_queues[client_id].put(
agent_worker_pb2.Message(
addSubscriptionResponse=agent_worker_pb2.AddSubscriptionResponse(
request_id=add_subscription_req.request_id, success=success, error=error
)
)

case "typePrefixSubscription":
type_prefix_subscription_msg: agent_worker_pb2.TypePrefixSubscription = (
add_subscription_req.subscription.typePrefixSubscription
)
subscription = TypePrefixSubscription(
topic_type_prefix=type_prefix_subscription_msg.topic_type_prefix,
agent_type=type_prefix_subscription_msg.agent_type,
)
case None:
logger.warning("Received empty subscription message")

if subscription is not None:
try:
await self._subscription_manager.add_subscription(subscription)
subscription_ids = self._client_id_to_subscription_id_mapping.setdefault(client_id, set())
subscription_ids.add(subscription.id)
success = True
error = None
except ValueError as e:
success = False
error = str(e)
# Send a response back to the client.
await self._send_queues[client_id].put(
agent_worker_pb2.Message(
addSubscriptionResponse=agent_worker_pb2.AddSubscriptionResponse(
request_id=add_subscription_req.request_id, success=success, error=error
)
)
)

async def GetState( # type: ignore
self,
request: agent_worker_pb2.AgentId,
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit ecce47e

Please sign in to comment.