Skip to content

Commit

Permalink
Fix deprecated usages (#4374)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits authored Nov 27, 2024
1 parent fe96f7d commit 45f16f5
Show file tree
Hide file tree
Showing 10 changed files with 46 additions and 113 deletions.
1 change: 1 addition & 0 deletions python/packages/autogen-agentchat/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ include = ["src/**", "tests/*.py"]
[tool.pyright]
extends = "../../pyproject.toml"
include = ["src", "tests"]
reportDeprecated = true

[tool.pytest.ini_options]
minversion = "6.0"
Expand Down
2 changes: 1 addition & 1 deletion python/packages/autogen-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ include = ["src/**", "samples/*.py", "docs/**/*.ipynb", "tests/**"]
extends = "../../pyproject.toml"
include = ["src", "tests", "samples"]
exclude = ["src/autogen_core/application/protos", "tests/protos"]
reportDeprecated = false
reportDeprecated = true

[tool.pytest.ini_options]
minversion = "6.0"
Expand Down
20 changes: 12 additions & 8 deletions python/packages/autogen-core/samples/slow_human_in_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@
from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, CancellationToken, MessageContext
from autogen_core.base.intervention import DefaultInterventionHandler
from autogen_core.components import DefaultSubscription, DefaultTopicId, FunctionCall, RoutedAgent, message_handler
from autogen_core.components import (
DefaultTopicId,
FunctionCall,
RoutedAgent,
message_handler,
type_subscription,
)
from autogen_core.components.model_context import BufferedChatCompletionContext
from autogen_core.components.models import (
AssistantMessage,
Expand Down Expand Up @@ -81,6 +87,7 @@ def save_content(self, content: Mapping[str, Any]) -> None:
state_persister = MockPersistence()


@type_subscription("scheduling_assistant_conversation")
class SlowUserProxyAgent(RoutedAgent):
def __init__(
self,
Expand Down Expand Up @@ -132,6 +139,7 @@ async def run(self, args: ScheduleMeetingInput, cancellation_token: Cancellation
return ScheduleMeetingOutput()


@type_subscription("scheduling_assistant_conversation")
class SchedulingAssistantAgent(RoutedAgent):
def __init__(
self,
Expand Down Expand Up @@ -256,24 +264,20 @@ async def main(latest_user_input: Optional[str] = None) -> None | str:
needs_user_input_handler = NeedsUserInputHandler()
runtime = SingleThreadedAgentRuntime(intervention_handlers=[needs_user_input_handler, termination_handler])

await runtime.register(
"User",
lambda: SlowUserProxyAgent("User", "I am a user"),
subscriptions=lambda: [DefaultSubscription("scheduling_assistant_conversation")],
)
await SlowUserProxyAgent.register(runtime, "User", lambda: SlowUserProxyAgent("User", "I am a user"))

initial_schedule_assistant_message = AssistantTextMessage(
content="Hi! How can I help you? I can help schedule meetings", source="User"
)
await runtime.register(
await SchedulingAssistantAgent.register(
runtime,
"SchedulingAssistant",
lambda: SchedulingAssistantAgent(
"SchedulingAssistant",
description="AI that helps you schedule meetings",
model_client=get_chat_completion_client_from_envs(model="gpt-4o-mini"),
initial_message=initial_schedule_assistant_message,
),
subscriptions=lambda: [DefaultSubscription("scheduling_assistant_conversation")],
)

if latest_user_input is not None:
Expand Down
12 changes: 7 additions & 5 deletions python/packages/autogen-core/tests/test_cancellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def on_new_message(self, message: MessageType, ctx: MessageContext) -> Mes
async def test_cancellation_with_token() -> None:
runtime = SingleThreadedAgentRuntime()

await runtime.register("long_running", LongRunningAgent)
await LongRunningAgent.register(runtime, "long_running", LongRunningAgent)
agent_id = AgentId("long_running", key="default")
token = CancellationToken()
response = asyncio.create_task(runtime.send_message(MessageType(), recipient=agent_id, cancellation_token=token))
Expand All @@ -85,8 +85,9 @@ async def test_cancellation_with_token() -> None:
async def test_nested_cancellation_only_outer_called() -> None:
runtime = SingleThreadedAgentRuntime()

await runtime.register("long_running", LongRunningAgent)
await runtime.register(
await LongRunningAgent.register(runtime, "long_running", LongRunningAgent)
await NestingLongRunningAgent.register(
runtime,
"nested",
lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)),
)
Expand Down Expand Up @@ -119,8 +120,9 @@ async def test_nested_cancellation_only_outer_called() -> None:
async def test_nested_cancellation_inner_called() -> None:
runtime = SingleThreadedAgentRuntime()

await runtime.register("long_running", LongRunningAgent)
await runtime.register(
await LongRunningAgent.register(runtime, "long_running", LongRunningAgent)
await NestingLongRunningAgent.register(
runtime,
"nested",
lambda: NestingLongRunningAgent(AgentId("long_running", key=AgentInstantiationContext.current_agent_id().key)),
)
Expand Down
10 changes: 5 additions & 5 deletions python/packages/autogen-core/tests/test_intervention.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async def on_send(self, message: MessageType, *, sender: AgentId | None, recipie

handler = DebugInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler])
await runtime.register("name", LoopbackAgent)
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
loopback = AgentId("name", key="default")
runtime.start()

Expand All @@ -42,7 +42,7 @@ async def on_send(
handler = DropSendInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler])

await runtime.register("name", LoopbackAgent)
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
loopback = AgentId("name", key="default")
runtime.start()

Expand All @@ -66,7 +66,7 @@ async def on_response(
handler = DropResponseInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler])

await runtime.register("name", LoopbackAgent)
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
loopback = AgentId("name", key="default")
runtime.start()

Expand All @@ -90,7 +90,7 @@ async def on_send(
handler = ExceptionInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler])

await runtime.register("name", LoopbackAgent)
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
loopback = AgentId("name", key="default")
runtime.start()

Expand All @@ -117,7 +117,7 @@ async def on_response(
handler = ExceptionInterventionHandler()
runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler])

await runtime.register("name", LoopbackAgent)
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
loopback = AgentId("name", key="default")
runtime.start()
with pytest.raises(InterventionException):
Expand Down
14 changes: 9 additions & 5 deletions python/packages/autogen-core/tests/test_routed_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ async def on_broadcast_message(self, message: MessageType, ctx: MessageContext)
async def test_routed_agent(caplog: pytest.LogCaptureFixture) -> None:
runtime = SingleThreadedAgentRuntime()
with caplog.at_level(logging.INFO):
await runtime.register("loopback", LoopbackAgent, lambda: [TypeSubscription("default", "loopback")])
await LoopbackAgent.register(runtime, "loopback", LoopbackAgent)
await runtime.add_subscription(TypeSubscription("default", "loopback"))
runtime.start()
await runtime.publish_message(UnhandledMessageType(), topic_id=TopicId("default", "default"))
await runtime.stop_when_idle()
Expand All @@ -47,7 +48,8 @@ async def test_routed_agent(caplog: pytest.LogCaptureFixture) -> None:
@pytest.mark.asyncio
async def test_message_handler_router() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("counter", CounterAgent, lambda: [TypeSubscription("default", "counter")])
await CounterAgent.register(runtime, "counter", CounterAgent)
await runtime.add_subscription(TypeSubscription("default", "counter"))
agent_id = AgentId(type="counter", key="default")

# Send a broadcast message.
Expand Down Expand Up @@ -94,7 +96,7 @@ async def handler_two(self, message: TestMessage, ctx: MessageContext) -> None:
@pytest.mark.asyncio
async def test_routed_agent_message_matching() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("message_match", RoutedAgentMessageCustomMatch)
await RoutedAgentMessageCustomMatch.register(runtime, "message_match", RoutedAgentMessageCustomMatch)
agent_id = AgentId(type="message_match", key="default")

agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch)
Expand Down Expand Up @@ -134,7 +136,8 @@ async def on_event_two(self, message: TestMessage, ctx: MessageContext) -> None:
@pytest.mark.asyncio
async def test_event() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("counter", EventAgent, lambda: [TypeSubscription("default", "counter")])
await EventAgent.register(runtime, "counter", EventAgent)
await runtime.add_subscription(TypeSubscription("default", "counter"))
agent_id = AgentId(type="counter", key="default")

# Send a broadcast message.
Expand Down Expand Up @@ -181,7 +184,8 @@ async def on_rpc_two(self, message: TestMessage, ctx: MessageContext) -> TestMes
@pytest.mark.asyncio
async def test_rpc() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register("counter", RPCAgent, lambda: [TypeSubscription("default", "counter")])
await RPCAgent.register(runtime, "counter", RPCAgent)
await runtime.add_subscription(TypeSubscription("default", "counter"))
agent_id = AgentId(type="counter", key="default")

# Send an RPC message.
Expand Down
86 changes: 3 additions & 83 deletions python/packages/autogen-core/tests/test_runtime.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import logging

import pytest
Expand All @@ -7,16 +6,10 @@
AgentId,
AgentInstantiationContext,
AgentType,
Subscription,
SubscriptionInstantiationContext,
TopicId,
try_get_known_serializers_for_type,
)
from autogen_core.components import (
DefaultTopicId,
TypeSubscription,
type_subscription,
)
from autogen_core.components import DefaultTopicId, TypeSubscription, type_subscription
from opentelemetry.sdk.trace import TracerProvider
from test_utils import (
CascadingAgent,
Expand Down Expand Up @@ -146,82 +139,9 @@ async def test_register_receives_publish_cascade() -> None:
async def test_register_factory_explicit_name() -> None:
runtime = SingleThreadedAgentRuntime()

await runtime.register("name", LoopbackAgent, lambda: [TypeSubscription("default", "name")])
runtime.start()
agent_id = AgentId("name", key="default")
topic_id = TopicId("default", "default")
await runtime.publish_message(MessageType(), topic_id=topic_id)

await runtime.stop_when_idle()

# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1

# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgent
)
assert other_long_running_agent.num_calls == 0


@pytest.mark.asyncio
async def test_register_factory_context_var_name() -> None:
runtime = SingleThreadedAgentRuntime()

await runtime.register(
"name", LoopbackAgent, lambda: [TypeSubscription("default", SubscriptionInstantiationContext.agent_type().type)]
)
runtime.start()
agent_id = AgentId("name", key="default")
topic_id = TopicId("default", "default")
await runtime.publish_message(MessageType(), topic_id=topic_id)

await runtime.stop_when_idle()

# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1

# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgent
)
assert other_long_running_agent.num_calls == 0


@pytest.mark.asyncio
async def test_register_factory_async() -> None:
runtime = SingleThreadedAgentRuntime()

async def sub_factory() -> list[Subscription]:
await asyncio.sleep(0.1)
return [TypeSubscription("default", SubscriptionInstantiationContext.agent_type().type)]

await runtime.register("name", LoopbackAgent, sub_factory)
runtime.start()
agent_id = AgentId("name", key="default")
topic_id = TopicId("default", "default")
await runtime.publish_message(MessageType(), topic_id=topic_id)

await runtime.stop_when_idle()

# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1

# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgent
)
assert other_long_running_agent.num_calls == 0


@pytest.mark.asyncio
async def test_register_factory_direct_list() -> None:
runtime = SingleThreadedAgentRuntime()
await LoopbackAgent.register(runtime, "name", LoopbackAgent)
await runtime.add_subscription(TypeSubscription("default", "name"))

await runtime.register("name", LoopbackAgent, [TypeSubscription("default", "name")])
runtime.start()
agent_id = AgentId("name", key="default")
topic_id = TopicId("default", "default")
Expand Down
6 changes: 3 additions & 3 deletions python/packages/autogen-core/tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def load_state(self, state: Mapping[str, Any]) -> None:
async def test_agent_can_save_state() -> None:
runtime = SingleThreadedAgentRuntime()

await runtime.register("name1", StatefulAgent)
await StatefulAgent.register(runtime, "name1", StatefulAgent)
agent1_id = AgentId("name1", key="default")
agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent)
assert agent1.state == 0
Expand All @@ -44,7 +44,7 @@ async def test_agent_can_save_state() -> None:
async def test_runtime_can_save_state() -> None:
runtime = SingleThreadedAgentRuntime()

await runtime.register("name1", StatefulAgent)
await StatefulAgent.register(runtime, "name1", StatefulAgent)
agent1_id = AgentId("name1", key="default")
agent1: StatefulAgent = await runtime.try_get_underlying_agent_instance(agent1_id, type=StatefulAgent)
assert agent1.state == 0
Expand All @@ -54,7 +54,7 @@ async def test_runtime_can_save_state() -> None:
runtime_state = await runtime.save_state()

runtime2 = SingleThreadedAgentRuntime()
await runtime2.register("name1", StatefulAgent)
await StatefulAgent.register(runtime2, "name1", StatefulAgent)
agent2_id = AgentId("name1", key="default")
agent2: StatefulAgent = await runtime2.try_get_underlying_agent_instance(agent2_id, type=StatefulAgent)

Expand Down
2 changes: 1 addition & 1 deletion python/packages/autogen-core/tests/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_type_subscription_map() -> None:
async def test_non_default_default_subscription() -> None:
runtime = SingleThreadedAgentRuntime()

await runtime.register("MyAgent", LoopbackAgent)
await LoopbackAgent.register(runtime, "MyAgent", LoopbackAgent, skip_class_subscriptions=True)
runtime.start()
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
await runtime.stop_when_idle()
Expand Down
6 changes: 4 additions & 2 deletions python/packages/autogen-core/tests/test_tool_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ async def _async_sleep_function(input: str) -> str:
@pytest.mark.asyncio
async def test_tool_agent() -> None:
runtime = SingleThreadedAgentRuntime()
await runtime.register(
await ToolAgent.register(
runtime,
"tool_agent",
lambda: ToolAgent(
description="Tool agent",
Expand Down Expand Up @@ -143,7 +144,8 @@ def capabilities(self) -> ModelCapabilities:
client = MockChatCompletionClient()
tools: List[Tool] = [FunctionTool(_pass_function, name="pass", description="Pass function")]
runtime = SingleThreadedAgentRuntime()
await runtime.register(
await ToolAgent.register(
runtime,
"tool_agent",
lambda: ToolAgent(
description="Tool agent",
Expand Down

0 comments on commit 45f16f5

Please sign in to comment.