Skip to content

Commit

Permalink
Merge branch 'main' into filesurfer
Browse files Browse the repository at this point in the history
  • Loading branch information
afourney committed Nov 26, 2024
2 parents 2caf13c + cf80b1b commit acea6e1
Show file tree
Hide file tree
Showing 21 changed files with 348 additions and 63 deletions.
1 change: 1 addition & 0 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ jobs:
run: |
source ${{ github.workspace }}/python/.venv/bin/activate
poe gen-proto
poe gen-test-proto
working-directory: ./python
- name: Check if there are uncommited changes
id: changes
Expand Down
2 changes: 1 addition & 1 deletion docs/design/04 - Agent and Topic ID Specs.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ This document describes the structure, constraints, and behavior of Agent IDs an

- Type: `string`
- Description: Topic type is usually defined by application code to mark the type of messages the topic is for.
- Constraints: UTF8 and only contain alphanumeric letters (a-z) and (0-9), or underscores (\_). A valid identifier cannot start with a number, or contain any spaces.
- Constraints: UTF8 and only contain alphanumeric letters (a-z) and (0-9), ':', '=', or underscores (\_). A valid identifier cannot start with a number, or contain any spaces.
- Examples:
- `GitHub_Issues`

Expand Down
6 changes: 6 additions & 0 deletions protos/agent_worker.proto
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,15 @@ message TypeSubscription {
string agent_type = 2;
}

message TypePrefixSubscription {
string topic_type_prefix = 1;
string agent_type = 2;
}

message Subscription {
oneof subscription {
TypeSubscription typeSubscription = 1;
TypePrefixSubscription typePrefixSubscription = 2;
}
}

Expand Down
6 changes: 3 additions & 3 deletions python/packages/autogen-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ dev-dependencies = [

[tool.ruff]
extend = "../../pyproject.toml"
exclude = ["build", "dist", "src/autogen_core/application/protos"]
exclude = ["build", "dist", "src/autogen_core/application/protos", "tests/protos"]
include = ["src/**", "samples/*.py", "docs/**/*.ipynb", "tests/**"]

[tool.ruff.lint.per-file-ignores]
Expand All @@ -91,7 +91,7 @@ include = ["src/**", "samples/*.py", "docs/**/*.ipynb", "tests/**"]
[tool.pyright]
extends = "../../pyproject.toml"
include = ["src", "tests", "samples"]
exclude = ["src/autogen_core/application/protos"]
exclude = ["src/autogen_core/application/protos", "tests/protos"]
reportDeprecated = false

[tool.pytest.ini_options]
Expand All @@ -111,7 +111,7 @@ include = "../../shared_tasks.toml"
test = "pytest -n auto"
mypy.default_item_type = "cmd"
mypy.sequence = [
"mypy --config-file ../../pyproject.toml --exclude src/autogen_core/application/protos src tests",
"mypy --config-file ../../pyproject.toml --exclude src/autogen_core/application/protos --exclude tests/protos src tests",
"nbqa mypy docs/src --config-file ../../pyproject.toml",
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import inspect
import logging
import threading
import uuid
import warnings
from asyncio import CancelledError, Future, Task
from collections.abc import Sequence
Expand Down Expand Up @@ -53,6 +54,7 @@ class PublishMessageEnvelope:
sender: AgentId | None
topic_id: TopicId
metadata: EnvelopeMetadata | None = None
message_id: str


@dataclass(kw_only=True)
Expand Down Expand Up @@ -256,6 +258,7 @@ async def publish_message(
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> None:
with self._tracer_helper.trace_block(
"create",
Expand All @@ -268,6 +271,9 @@ async def publish_message(
content = message.__dict__ if hasattr(message, "__dict__") else message
logger.info(f"Publishing message of type {type(message).__name__} to all subscribers: {content}")

if message_id is None:
message_id = str(uuid.uuid4())

# event_logger.info(
# MessageEvent(
# payload=message,
Expand All @@ -285,6 +291,7 @@ async def publish_message(
sender=sender,
topic_id=topic_id,
metadata=get_telemetry_envelope_metadata(),
message_id=message_id,
)
)

Expand Down Expand Up @@ -327,6 +334,8 @@ async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
topic_id=None,
is_rpc=True,
cancellation_token=message_envelope.cancellation_token,
# Will be fixed when send API removed
message_id="NOT_DEFINED_TODO_FIX",
)
with MessageHandlerContext.populate_context(recipient_agent.id):
response = await recipient_agent.on_message(
Expand Down Expand Up @@ -385,6 +394,7 @@ async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> No
topic_id=message_envelope.topic_id,
is_rpc=False,
cancellation_token=message_envelope.cancellation_token,
message_id=message_envelope.message_id,
)
agent = await self._get_agent(agent_id)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import signal
import uuid
import warnings
from asyncio import Future, Task
from collections import defaultdict
Expand Down Expand Up @@ -47,7 +48,7 @@
)
from ..base._serialization import MessageSerializer, SerializationRegistry
from ..base._type_helpers import ChannelArgumentType
from ..components import TypeSubscription
from ..components import TypePrefixSubscription, TypeSubscription
from ._helpers import SubscriptionManager, get_impl
from ._utils import GRPC_IMPORT_ERROR_STR
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
Expand Down Expand Up @@ -371,11 +372,17 @@ async def publish_message(
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
message_id: str | None = None,
) -> None:
if not self._running:
raise ValueError("Runtime must be running when publishing message.")
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
if message_id is None:
message_id = str(uuid.uuid4())

# TODO: consume message_id

message_type = self._serialization_registry.type_name(message)
with self._trace_helper.trace_block(
"create", topic_id, parent=None, extraAttributes={"message_type": message_type}
Expand Down Expand Up @@ -447,6 +454,7 @@ async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
topic_id=None,
is_rpc=True,
cancellation_token=CancellationToken(),
message_id=request.request_id,
)

# Call the receiving agent.
Expand Down Expand Up @@ -530,11 +538,13 @@ async def _process_event(self, event: agent_worker_pb2.Event) -> None:
for agent_id in recipients:
if agent_id == sender:
continue
# TODO: consume message_id
message_context = MessageContext(
sender=sender,
topic_id=topic_id,
is_rpc=False,
cancellation_token=CancellationToken(),
message_id="NOT_DEFINED_TODO_FIX",
)
agent = await self._get_agent(agent_id)
with MessageHandlerContext.populate_context(agent.id):
Expand Down Expand Up @@ -705,27 +715,44 @@ async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = A
async def add_subscription(self, subscription: Subscription) -> None:
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
if not isinstance(subscription, TypeSubscription):
raise ValueError("Only TypeSubscription is supported.")
# Add to local subscription manager.
await self._subscription_manager.add_subscription(subscription)

# Create a future for the subscription response.
future = asyncio.get_event_loop().create_future()
request_id = await self._get_new_request_id()

match subscription:
case TypeSubscription(topic_type=topic_type, agent_type=agent_type):
message = agent_worker_pb2.Message(
addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest(
request_id=request_id,
subscription=agent_worker_pb2.Subscription(
typeSubscription=agent_worker_pb2.TypeSubscription(
topic_type=topic_type, agent_type=agent_type
)
),
)
)
case TypePrefixSubscription(topic_type_prefix=topic_type_prefix, agent_type=agent_type):
message = agent_worker_pb2.Message(
addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest(
request_id=request_id,
subscription=agent_worker_pb2.Subscription(
typePrefixSubscription=agent_worker_pb2.TypePrefixSubscription(
topic_type_prefix=topic_type_prefix, agent_type=agent_type
)
),
)
)
case _:
raise ValueError("Unsupported subscription type.")

# Add the future to the pending requests.
self._pending_requests[request_id] = future

# Add to local subscription manager.
await self._subscription_manager.add_subscription(subscription)

# Send the subscription to the host.
message = agent_worker_pb2.Message(
addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest(
request_id=request_id,
subscription=agent_worker_pb2.Subscription(
typeSubscription=agent_worker_pb2.TypeSubscription(
topic_type=subscription.topic_type, agent_type=subscription.agent_type
)
),
)
)
await self._host_connection.send(message)

# Wait for the subscription response.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from asyncio import Future, Task
from typing import Any, Dict, Set

from ..base import TopicId
from autogen_core.components._type_prefix_subscription import TypePrefixSubscription

from ..base import Subscription, TopicId
from ..components import TypeSubscription
from ._helpers import SubscriptionManager
from ._utils import GRPC_IMPORT_ERROR_STR
Expand Down Expand Up @@ -221,34 +223,46 @@ async def _process_add_subscription_request(
self, add_subscription_req: agent_worker_pb2.AddSubscriptionRequest, client_id: int
) -> None:
oneofcase = add_subscription_req.subscription.WhichOneof("subscription")
subscription: Subscription | None = None
match oneofcase:
case "typeSubscription":
type_subscription_msg: agent_worker_pb2.TypeSubscription = (
add_subscription_req.subscription.typeSubscription
)
type_subscription = TypeSubscription(
subscription = TypeSubscription(
topic_type=type_subscription_msg.topic_type, agent_type=type_subscription_msg.agent_type
)
try:
await self._subscription_manager.add_subscription(type_subscription)
subscription_ids = self._client_id_to_subscription_id_mapping.setdefault(client_id, set())
subscription_ids.add(type_subscription.id)
success = True
error = None
except ValueError as e:
success = False
error = str(e)
# Send a response back to the client.
await self._send_queues[client_id].put(
agent_worker_pb2.Message(
addSubscriptionResponse=agent_worker_pb2.AddSubscriptionResponse(
request_id=add_subscription_req.request_id, success=success, error=error
)
)

case "typePrefixSubscription":
type_prefix_subscription_msg: agent_worker_pb2.TypePrefixSubscription = (
add_subscription_req.subscription.typePrefixSubscription
)
subscription = TypePrefixSubscription(
topic_type_prefix=type_prefix_subscription_msg.topic_type_prefix,
agent_type=type_prefix_subscription_msg.agent_type,
)
case None:
logger.warning("Received empty subscription message")

if subscription is not None:
try:
await self._subscription_manager.add_subscription(subscription)
subscription_ids = self._client_id_to_subscription_id_mapping.setdefault(client_id, set())
subscription_ids.add(subscription.id)
success = True
error = None
except ValueError as e:
success = False
error = str(e)
# Send a response back to the client.
await self._send_queues[client_id].put(
agent_worker_pb2.Message(
addSubscriptionResponse=agent_worker_pb2.AddSubscriptionResponse(
request_id=add_subscription_req.request_id, success=success, error=error
)
)
)

async def GetState( # type: ignore
self,
request: agent_worker_pb2.AgentId,
Expand Down
Loading

0 comments on commit acea6e1

Please sign in to comment.