Add agent_config field to request for custom configs (#149)
JoshuaC215 authored Jan 22, 2025
1 parent 4d14563 commit 5db28a3
Showing 4 changed files with 78 additions and 4 deletions.
32 changes: 29 additions & 3 deletions src/client/client.py
@@ -84,7 +84,11 @@ def update_agent(self, agent: str, verify: bool = True) -> None:
         self.agent = agent

     async def ainvoke(
-        self, message: str, model: str | None = None, thread_id: str | None = None
+        self,
+        message: str,
+        model: str | None = None,
+        thread_id: str | None = None,
+        agent_config: dict[str, Any] | None = None,
     ) -> ChatMessage:
         """
         Invoke the agent asynchronously. Only the final message is returned.
@@ -93,13 +97,20 @@ async def ainvoke(
             message (str): The message to send to the agent
             model (str, optional): LLM model to use for the agent
             thread_id (str, optional): Thread ID for continuing a conversation
+            agent_config (dict[str, Any], optional): Additional configuration to pass through to the agent

         Returns:
             AnyMessage: The response from the agent
         """
         if not self.agent:
             raise AgentClientError("No agent selected. Use update_agent() to select an agent.")
-        request = UserInput(message=message, thread_id=thread_id, model=model)
+        request = UserInput(message=message)
+        if thread_id:
+            request.thread_id = thread_id
+        if model:
+            request.model = model
+        if agent_config:
+            request.agent_config = agent_config
         async with httpx.AsyncClient() as client:
             try:
                 response = await client.post(
@@ -115,7 +126,11 @@
         return ChatMessage.model_validate(response.json())

     def invoke(
-        self, message: str, model: str | None = None, thread_id: str | None = None
+        self,
+        message: str,
+        model: str | None = None,
+        thread_id: str | None = None,
+        agent_config: dict[str, Any] | None = None,
     ) -> ChatMessage:
         """
         Invoke the agent synchronously. Only the final message is returned.
@@ -124,6 +139,7 @@ def invoke(
             message (str): The message to send to the agent
             model (str, optional): LLM model to use for the agent
             thread_id (str, optional): Thread ID for continuing a conversation
+            agent_config (dict[str, Any], optional): Additional configuration to pass through to the agent

         Returns:
             ChatMessage: The response from the agent
@@ -135,6 +151,8 @@ def invoke(
             request.thread_id = thread_id
         if model:
             request.model = model
+        if agent_config:
+            request.agent_config = agent_config
         try:
             response = httpx.post(
                 f"{self.base_url}/{self.agent}/invoke",
@@ -177,6 +195,7 @@ def stream(
         message: str,
         model: str | None = None,
         thread_id: str | None = None,
+        agent_config: dict[str, Any] | None = None,
         stream_tokens: bool = True,
     ) -> Generator[ChatMessage | str, None, None]:
         """
@@ -190,6 +209,7 @@
             message (str): The message to send to the agent
             model (str, optional): LLM model to use for the agent
             thread_id (str, optional): Thread ID for continuing a conversation
+            agent_config (dict[str, Any], optional): Additional configuration to pass through to the agent
             stream_tokens (bool, optional): Stream tokens as they are generated
                 Default: True
@@ -203,6 +223,8 @@
             request.thread_id = thread_id
         if model:
             request.model = model
+        if agent_config:
+            request.agent_config = agent_config
         try:
             with httpx.stream(
                 "POST",
@@ -226,6 +248,7 @@ async def astream(
         message: str,
         model: str | None = None,
         thread_id: str | None = None,
+        agent_config: dict[str, Any] | None = None,
         stream_tokens: bool = True,
     ) -> AsyncGenerator[ChatMessage | str, None]:
         """
@@ -239,6 +262,7 @@
             message (str): The message to send to the agent
             model (str, optional): LLM model to use for the agent
             thread_id (str, optional): Thread ID for continuing a conversation
+            agent_config (dict[str, Any], optional): Additional configuration to pass through to the agent
             stream_tokens (bool, optional): Stream tokens as they are generated
                 Default: True
@@ -252,6 +276,8 @@
             request.thread_id = thread_id
         if model:
             request.model = model
+        if agent_config:
+            request.agent_config = agent_config
         async with httpx.AsyncClient() as client:
             try:
                 async with client.stream(
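Taken together, the client changes let callers attach arbitrary per-run configuration to any of the four entry points (invoke, ainvoke, stream, astream). A minimal usage sketch, assuming the client class is named AgentClient, a local service URL, and a hypothetical agent name — none of which are shown in this diff:

from client.client import AgentClient  # import path assumed from src/client/client.py

client = AgentClient(base_url="http://localhost:8080")  # constructor signature assumed
client.update_agent("research-assistant")  # hypothetical agent name

# agent_config is only set on the request when non-empty, and is forwarded
# verbatim to the service, which merges it into the run's configurable dict.
response = client.invoke(
    "What is the weather in Tokyo?",
    agent_config={"spicy_level": 0.8},
)
print(response.content)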
5 changes: 5 additions & 0 deletions src/schema/schema.py
@@ -55,6 +55,11 @@ class UserInput(BaseModel):
         default=None,
         examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
     )
+    agent_config: dict[str, Any] = Field(
+        description="Additional configuration to pass through to the agent",
+        default={},
+        examples=[{"spicy_level": 0.8}],
+    )


 class StreamInput(UserInput):
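Because agent_config defaults to an empty dict, existing payloads that omit the field keep validating unchanged. A short sketch of both shapes — the import path is assumed from the file location:

from schema.schema import UserInput  # import path assumed

# Omitting the field falls back to the {} default ...
basic = UserInput(message="hello")
assert basic.agent_config == {}

# ... and any JSON-serializable dict passes through untouched.
custom = UserInput(message="hello", agent_config={"spicy_level": 0.8})
print(custom.model_dump())  # pydantic v2, matching model_validate used in the client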
13 changes: 12 additions & 1 deletion src/service/service.py
@@ -83,10 +83,21 @@ async def info() -> ServiceMetadata:
 def _parse_input(user_input: UserInput) -> tuple[dict[str, Any], UUID]:
     run_id = uuid4()
     thread_id = user_input.thread_id or str(uuid4())
+
+    configurable = {"thread_id": thread_id, "model": user_input.model}
+
+    if user_input.agent_config:
+        if overlap := configurable.keys() & user_input.agent_config.keys():
+            raise HTTPException(
+                status_code=422, detail=f"agent_config contains reserved keys: {overlap}"
+            )
+        configurable.update(user_input.agent_config)
+
     kwargs = {
         "input": {"messages": [HumanMessage(content=user_input.message)]},
         "config": RunnableConfig(
-            configurable={"thread_id": thread_id, "model": user_input.model}, run_id=run_id
+            configurable=configurable,
+            run_id=run_id,
         ),
     }
     return kwargs, run_id
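The reserved-key guard is plain dict-view set intersection: thread_id and model are always set by the service, so a user-supplied duplicate triggers the 422 before the configs are merged. A standalone restatement of the same check, with example values:

# Standalone illustration of the guard in _parse_input (example values only).
configurable = {"thread_id": "abc123", "model": "gpt-4o-mini"}
agent_config = {"model": "gpt-4o", "spicy_level": 0.8}

# dict.keys() returns a set-like view, so & yields the overlapping keys.
if overlap := configurable.keys() & agent_config.keys():
    print(f"agent_config contains reserved keys: {overlap}")  # -> {'model'}
else:
    configurable.update(agent_config)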
32 changes: 32 additions & 0 deletions tests/service/test_service.py
@@ -91,6 +91,38 @@ def test_invoke_model_param(test_client, mock_agent) -> None:
     assert response.status_code == 422


+def test_invoke_custom_agent_config(test_client, mock_agent) -> None:
+    """Test that the agent_config parameter is correctly passed to the agent."""
+    QUESTION = "What is the weather in Tokyo?"
+    ANSWER = "The weather in Tokyo is sunny."
+    CUSTOM_CONFIG = {"spicy_level": 0.1, "additional_param": "value_foo"}
+
+    mock_agent.ainvoke.return_value = {"messages": [AIMessage(content=ANSWER)]}
+
+    response = test_client.post(
+        "/invoke", json={"message": QUESTION, "agent_config": CUSTOM_CONFIG}
+    )
+    assert response.status_code == 200
+
+    # Verify the agent_config was passed correctly in the config
+    mock_agent.ainvoke.assert_awaited_once()
+    config = mock_agent.ainvoke.await_args.kwargs["config"]
+    assert config["configurable"]["spicy_level"] == 0.1
+    assert config["configurable"]["additional_param"] == "value_foo"
+
+    # Verify the response is still correct
+    output = ChatMessage.model_validate(response.json())
+    assert output.type == "ai"
+    assert output.content == ANSWER
+
+    # Verify a reserved key in agent_config throws a validation error
+    INVALID_CONFIG = {"model": "gpt-4o"}
+    response = test_client.post(
+        "/invoke", json={"message": QUESTION, "agent_config": INVALID_CONFIG}
+    )
+    assert response.status_code == 422
+
+
 @patch("service.service.LangsmithClient")
 def test_feedback(mock_client: langsmith.Client, test_client) -> None:
     ls_instance = mock_client.return_value
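At the wire level, the new field is just one more JSON key on the existing /invoke body, which the tests above exercise through the service's test client. A hedged httpx equivalent against a running service — host, port, and agent name are assumed, with the route shape following the client's f"{base_url}/{agent}/invoke" construction:

import httpx

response = httpx.post(
    "http://localhost:8080/research-assistant/invoke",  # host and agent name assumed
    json={
        "message": "What is the weather in Tokyo?",
        "agent_config": {"spicy_level": 0.1, "additional_param": "value_foo"},
    },
)
response.raise_for_status()
print(response.json())  # validates as a ChatMessage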
