From deb30a4728fcc1f3029e0f5b26eddf1df1873beb Mon Sep 17 00:00:00 2001 From: Joshua Carroll Date: Wed, 4 Dec 2024 21:44:02 -0800 Subject: [PATCH 1/9] Define types and lay the groundwork for service metadata --- src/agents/__init__.py | 4 ++-- src/agents/agents.py | 29 ++++++++++++++++++++++++---- src/agents/research_assistant.py | 6 +++++- src/core/settings.py | 32 +++++++++++++++++++++---------- src/schema/__init__.py | 4 ++++ src/schema/schema.py | 20 +++++++++++++++++++ src/service/service.py | 14 ++++++++------ tests/core/test_settings.py | 22 +++++++++++++++++++++ tests/service/conftest.py | 3 +-- tests/service/test_service.py | 10 +++++++--- tests/service/test_service_e2e.py | 8 +++++++- 11 files changed, 123 insertions(+), 29 deletions(-) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index e7d276d..3daf6e1 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -1,3 +1,3 @@ -from agents.agents import DEFAULT_AGENT, agents +from agents.agents import DEFAULT_AGENT, get_agent, get_all_agent_info -__all__ = ["agents", "DEFAULT_AGENT"] +__all__ = ["get_agent", "get_all_agent_info", "DEFAULT_AGENT"] diff --git a/src/agents/agents.py b/src/agents/agents.py index 6d5db41..2404be4 100644 --- a/src/agents/agents.py +++ b/src/agents/agents.py @@ -1,14 +1,35 @@ +from dataclasses import dataclass + from langgraph.graph.state import CompiledStateGraph from agents.bg_task_agent.bg_task_agent import bg_task_agent from agents.chatbot import chatbot from agents.research_assistant import research_assistant +from schema import AgentInfo DEFAULT_AGENT = "research-assistant" -agents: dict[str, CompiledStateGraph] = { - "chatbot": chatbot, - "research-assistant": research_assistant, - "bg-task-agent": bg_task_agent, +@dataclass +class Agent: + description: str + graph: CompiledStateGraph + + +agents: dict[str, Agent] = { + "chatbot": Agent(description="A simple chatbot.", graph=chatbot), + "research-assistant": Agent( + description="A research assistant with web search and calculator.", graph=research_assistant + ), + "bg-task-agent": Agent(description="A background task agent.", graph=bg_task_agent), } + + +def get_agent(agent_id: str) -> CompiledStateGraph: + return agents[agent_id].graph + + +def get_all_agent_info() -> list[AgentInfo]: + return [ + AgentInfo(key=agent_id, description=agent.description) for agent_id, agent in agents.items() + ] diff --git a/src/agents/research_assistant.py b/src/agents/research_assistant.py index 6a4a149..cdcc9af 100644 --- a/src/agents/research_assistant.py +++ b/src/agents/research_assistant.py @@ -2,6 +2,7 @@ from typing import Literal from langchain_community.tools import DuckDuckGoSearchResults, OpenWeatherMapQueryRun +from langchain_community.utilities import OpenWeatherMapAPIWrapper from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import AIMessage, SystemMessage from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable @@ -31,7 +32,10 @@ class AgentState(MessagesState, total=False): # Add weather tool if API key is set # Register for an API key at https://openweathermap.org/api/ if settings.OPENWEATHERMAP_API_KEY: - tools.append(OpenWeatherMapQueryRun(name="Weather")) + wrapper = OpenWeatherMapAPIWrapper( + openweathermap_api_key=settings.OPENWEATHERMAP_API_KEY.get_secret_value() + ) + tools.append(OpenWeatherMapQueryRun(name="Weather", api_wrapper=wrapper)) current_date = datetime.now().strftime("%B %d, %Y") instructions = f""" diff --git 
a/src/core/settings.py b/src/core/settings.py index c4d9b0b..4ad0a36 100644 --- a/src/core/settings.py +++ b/src/core/settings.py @@ -45,6 +45,7 @@ class Settings(BaseSettings): # If DEFAULT_MODEL is None, it will be set in model_post_init DEFAULT_MODEL: AllModelEnum | None = None # type: ignore[assignment] + AVAILABLE_MODELS: set[AllModelEnum] = set() # type: ignore[assignment] OPENWEATHERMAP_API_KEY: SecretStr | None = None @@ -68,23 +69,34 @@ def model_post_init(self, __context: Any) -> None: if not active_keys: raise ValueError("At least one LLM API key must be provided.") - if self.DEFAULT_MODEL is None: - first_provider = next(iter(active_keys)) - match first_provider: + for provider in active_keys: + match provider: case Provider.OPENAI: - self.DEFAULT_MODEL = OpenAIModelName.GPT_4O_MINI + if self.DEFAULT_MODEL is None: + self.DEFAULT_MODEL = OpenAIModelName.GPT_4O_MINI + self.AVAILABLE_MODELS.update(set(OpenAIModelName)) case Provider.ANTHROPIC: - self.DEFAULT_MODEL = AnthropicModelName.HAIKU_3 + if self.DEFAULT_MODEL is None: + self.DEFAULT_MODEL = AnthropicModelName.HAIKU_3 + self.AVAILABLE_MODELS.update(set(AnthropicModelName)) case Provider.GOOGLE: - self.DEFAULT_MODEL = GoogleModelName.GEMINI_15_FLASH + if self.DEFAULT_MODEL is None: + self.DEFAULT_MODEL = GoogleModelName.GEMINI_15_FLASH + self.AVAILABLE_MODELS.update(set(GoogleModelName)) case Provider.GROQ: - self.DEFAULT_MODEL = GroqModelName.LLAMA_31_8B + if self.DEFAULT_MODEL is None: + self.DEFAULT_MODEL = GroqModelName.LLAMA_31_8B + self.AVAILABLE_MODELS.update(set(GroqModelName)) case Provider.AWS: - self.DEFAULT_MODEL = AWSModelName.BEDROCK_HAIKU + if self.DEFAULT_MODEL is None: + self.DEFAULT_MODEL = AWSModelName.BEDROCK_HAIKU + self.AVAILABLE_MODELS.update(set(AWSModelName)) case Provider.FAKE: - self.DEFAULT_MODEL = FakeModelName.FAKE + if self.DEFAULT_MODEL is None: + self.DEFAULT_MODEL = FakeModelName.FAKE + self.AVAILABLE_MODELS.update(set(FakeModelName)) case _: - raise ValueError(f"Unknown provider: {first_provider}") + raise ValueError(f"Unknown provider: {provider}") @computed_field @property diff --git a/src/schema/__init__.py b/src/schema/__init__.py index a873398..3dc70f9 100644 --- a/src/schema/__init__.py +++ b/src/schema/__init__.py @@ -1,18 +1,22 @@ from schema.models import AllModelEnum from schema.schema import ( + AgentInfo, ChatHistory, ChatHistoryInput, ChatMessage, Feedback, FeedbackResponse, + ServiceMetadata, StreamInput, UserInput, ) __all__ = [ + "AgentInfo", "AllModelEnum", "UserInput", "ChatMessage", + "ServiceMetadata", "StreamInput", "Feedback", "FeedbackResponse", diff --git a/src/schema/schema.py b/src/schema/schema.py index bc14f27..4e6f57e 100644 --- a/src/schema/schema.py +++ b/src/schema/schema.py @@ -6,6 +6,26 @@ from schema.models import AllModelEnum +class AgentInfo(BaseModel): + """Info about an available agent.""" + + key: str = Field( + description="Agent key.", + examples=["research-assistant"], + ) + description: str = Field( + description="Description of the agent.", + examples=["A research assistant for generating research papers."], + ) + + +class ServiceMetadata(BaseModel): + """Metadata about the service including available agents and models.""" + + agents: list[AgentInfo] + models: set[AllModelEnum] + + class UserInput(BaseModel): """Basic user input for the agent.""" diff --git a/src/service/service.py b/src/service/service.py index 130b45d..4e862b4 100644 --- a/src/service/service.py +++ b/src/service/service.py @@ -16,7 +16,7 @@ from langgraph.graph.state import 
CompiledStateGraph from langsmith import Client as LangsmithClient -from agents import DEFAULT_AGENT, agents +from agents import DEFAULT_AGENT, get_agent, get_all_agent_info from core import settings from schema import ( ChatHistory, @@ -55,8 +55,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # Construct agent with Sqlite checkpointer # TODO: It's probably dangerous to share the same checkpointer on multiple agents async with AsyncSqliteSaver.from_conn_string("checkpoints.db") as saver: - for a in agents.values(): - a.checkpointer = saver + agents = get_all_agent_info() + for a in agents: + agent = get_agent(a.key) + agent.checkpointer = saver yield # context manager will clean up the AsyncSqliteSaver on exit @@ -87,7 +89,7 @@ async def invoke(user_input: UserInput, agent_id: str = DEFAULT_AGENT) -> ChatMe Use thread_id to persist and continue a multi-turn conversation. run_id kwarg is also attached to messages for recording feedback. """ - agent: CompiledStateGraph = agents[agent_id] + agent: CompiledStateGraph = get_agent(agent_id) kwargs, run_id = _parse_input(user_input) try: response = await agent.ainvoke(**kwargs) @@ -107,7 +109,7 @@ async def message_generator( This is the workhorse method for the /stream endpoint. """ - agent: CompiledStateGraph = agents[agent_id] + agent: CompiledStateGraph = get_agent(agent_id) kwargs, run_id = _parse_input(user_input) # Process streamed events from the graph and yield messages over the SSE stream. @@ -220,7 +222,7 @@ def history(input: ChatHistoryInput) -> ChatHistory: Get chat history. """ # TODO: Hard-coding DEFAULT_AGENT here is wonky - agent: CompiledStateGraph = agents[DEFAULT_AGENT] + agent: CompiledStateGraph = get_agent(DEFAULT_AGENT) try: state_snapshot = agent.get_state( config=RunnableConfig( diff --git a/tests/core/test_settings.py b/tests/core/test_settings.py index 5145e49..1125ccd 100644 --- a/tests/core/test_settings.py +++ b/tests/core/test_settings.py @@ -40,6 +40,7 @@ def test_settings_with_openai_key(): settings = Settings(_env_file=None) assert settings.OPENAI_API_KEY == SecretStr("test_key") assert settings.DEFAULT_MODEL == OpenAIModelName.GPT_4O_MINI + assert settings.AVAILABLE_MODELS == set(OpenAIModelName) def test_settings_with_anthropic_key(): @@ -47,6 +48,27 @@ def test_settings_with_anthropic_key(): settings = Settings(_env_file=None) assert settings.ANTHROPIC_API_KEY == SecretStr("test_key") assert settings.DEFAULT_MODEL == AnthropicModelName.HAIKU_3 + assert settings.AVAILABLE_MODELS == set(AnthropicModelName) + + +def test_settings_with_multiple_api_keys(): + with patch.dict( + os.environ, + { + "OPENAI_API_KEY": "test_openai_key", + "ANTHROPIC_API_KEY": "test_anthropic_key", + }, + clear=True, + ): + settings = Settings(_env_file=None) + assert settings.OPENAI_API_KEY == SecretStr("test_openai_key") + assert settings.ANTHROPIC_API_KEY == SecretStr("test_anthropic_key") + # When multiple providers are available, OpenAI should be the default + assert settings.DEFAULT_MODEL == OpenAIModelName.GPT_4O_MINI + # Available models should include exactly all OpenAI and Anthropic models + expected_models = set(OpenAIModelName) + expected_models.update(set(AnthropicModelName)) + assert settings.AVAILABLE_MODELS == expected_models def test_settings_base_url(): diff --git a/tests/service/conftest.py b/tests/service/conftest.py index a1cf893..ffb9139 100644 --- a/tests/service/conftest.py +++ b/tests/service/conftest.py @@ -4,7 +4,6 @@ from fastapi.testclient import TestClient from langchain_core.messages 
import AIMessage -from agents import DEFAULT_AGENT from service import app @@ -20,7 +19,7 @@ def mock_agent(): agent_mock = AsyncMock() agent_mock.ainvoke = AsyncMock(return_value={"messages": [AIMessage(content="Test response")]}) agent_mock.get_state = Mock() # Default empty mock for get_state - with patch.dict("service.service.agents", {DEFAULT_AGENT: agent_mock}): + with patch("service.service.get_agent", Mock(return_value=agent_mock)): yield agent_mock diff --git a/tests/service/test_service.py b/tests/service/test_service.py index a6bbe4e..e9fb8e2 100644 --- a/tests/service/test_service.py +++ b/tests/service/test_service.py @@ -30,7 +30,6 @@ def test_invoke(test_client, mock_agent) -> None: def test_invoke_custom_agent(test_client, mock_agent) -> None: """Test that /invoke works with a custom agent_id path parameter.""" CUSTOM_AGENT = "custom_agent" - DEFAULT_AGENT = "default_agent" QUESTION = "What is the weather in Tokyo?" CUSTOM_ANSWER = "The weather in Tokyo is sunny." DEFAULT_ANSWER = "This is from the default agent." @@ -42,8 +41,13 @@ def test_invoke_custom_agent(test_client, mock_agent) -> None: # Configure our custom mock agent mock_agent.ainvoke.return_value = {"messages": [AIMessage(content=CUSTOM_ANSWER)]} - # Patch the agents dictionary to include both agents - with patch("service.service.agents", {CUSTOM_AGENT: mock_agent, DEFAULT_AGENT: default_mock}): + # Patch get_agent to return the correct agent based on the provided agent_id + def agent_lookup(agent_id): + if agent_id == CUSTOM_AGENT: + return mock_agent + return default_mock + + with patch("service.service.get_agent", side_effect=agent_lookup): response = test_client.post(f"/{CUSTOM_AGENT}/invoke", json={"message": QUESTION}) assert response.status_code == 200 diff --git a/tests/service/test_service_e2e.py b/tests/service/test_service_e2e.py index 7f28e63..ac0942f 100644 --- a/tests/service/test_service_e2e.py +++ b/tests/service/test_service_e2e.py @@ -84,7 +84,13 @@ def test_agent_stream(mock_httpx_stream): # Use stream to get intermediate responses messages = [] - with patch("service.service.agents", {"static-agent": static_agent}): + + def agent_lookup(agent_id): + if agent_id == "static-agent": + return static_agent + return None + + with patch("service.service.get_agent", side_effect=agent_lookup): for response in client.stream("Test message", stream_tokens=False): if isinstance(response, ChatMessage): messages.append(response) From f900f5b3a2b41a6592e13d036a391a4b6e3d1fda Mon Sep 17 00:00:00 2001 From: Joshua Carroll Date: Wed, 4 Dec 2024 21:51:04 -0800 Subject: [PATCH 2/9] Fix bug with settings active_keys being unordered --- src/core/settings.py | 2 +- src/schema/schema.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/core/settings.py b/src/core/settings.py index 4ad0a36..612ff2a 100644 --- a/src/core/settings.py +++ b/src/core/settings.py @@ -65,7 +65,7 @@ def model_post_init(self, __context: Any) -> None: Provider.AWS: self.USE_AWS_BEDROCK, Provider.FAKE: self.USE_FAKE_MODEL, } - active_keys = {k for k, v in api_keys.items() if v} + active_keys = [k for k, v in api_keys.items() if v] if not active_keys: raise ValueError("At least one LLM API key must be provided.") diff --git a/src/schema/schema.py b/src/schema/schema.py index 4e6f57e..59b7db3 100644 --- a/src/schema/schema.py +++ b/src/schema/schema.py @@ -24,6 +24,8 @@ class ServiceMetadata(BaseModel): agents: list[AgentInfo] models: set[AllModelEnum] + default_agent: str + default_model: AllModelEnum class 
UserInput(BaseModel): From e683bb864b6616aa017b6a60be09499247702589 Mon Sep 17 00:00:00 2001 From: Joshua Carroll Date: Sun, 8 Dec 2024 23:10:23 -0800 Subject: [PATCH 3/9] Add /info endpoint for ServiceMetadata --- src/schema/schema.py | 2 +- src/service/service.py | 13 +++++++++++++ tests/service/test_service.py | 25 ++++++++++++++++++++++++- 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/src/schema/schema.py b/src/schema/schema.py index 59b7db3..9a96db7 100644 --- a/src/schema/schema.py +++ b/src/schema/schema.py @@ -23,7 +23,7 @@ class ServiceMetadata(BaseModel): """Metadata about the service including available agents and models.""" agents: list[AgentInfo] - models: set[AllModelEnum] + models: list[AllModelEnum] default_agent: str default_model: AllModelEnum diff --git a/src/service/service.py b/src/service/service.py index 4e862b4..1b8cff4 100644 --- a/src/service/service.py +++ b/src/service/service.py @@ -24,6 +24,7 @@ ChatMessage, Feedback, FeedbackResponse, + ServiceMetadata, StreamInput, UserInput, ) @@ -67,6 +68,18 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: router = APIRouter(dependencies=[Depends(verify_bearer)]) +@router.get("/info") +async def info() -> ServiceMetadata: + models = list(settings.AVAILABLE_MODELS) + models.sort() + return ServiceMetadata( + agents=get_all_agent_info(), + models=models, + default_agent=DEFAULT_AGENT, + default_model=settings.DEFAULT_MODEL, + ) + + def _parse_input(user_input: UserInput) -> tuple[dict[str, Any], UUID]: run_id = uuid4() thread_id = user_input.thread_id or str(uuid4()) diff --git a/tests/service/test_service.py b/tests/service/test_service.py index e9fb8e2..4598f71 100644 --- a/tests/service/test_service.py +++ b/tests/service/test_service.py @@ -7,7 +7,9 @@ from langchain_core.messages import AIMessage, HumanMessage from langgraph.pregel.types import StateSnapshot -from schema import ChatHistory, ChatMessage +from agents.agents import Agent +from schema import ChatHistory, ChatMessage, ServiceMetadata +from schema.models import OpenAIModelName def test_invoke(test_client, mock_agent) -> None: @@ -238,3 +240,24 @@ async def mock_astream_events(**kwargs): assert len(final_messages) == 1 assert final_messages[0]["content"]["content"] == FINAL_ANSWER assert final_messages[0]["content"]["type"] == "ai" + + +def test_info(test_client, mock_settings) -> None: + """Test that /info returns the correct service metadata.""" + + base_agent = Agent(description="A base agent.", graph=None) + mock_settings.AUTH_SECRET = None + mock_settings.DEFAULT_MODEL = OpenAIModelName.GPT_4O_MINI + mock_settings.AVAILABLE_MODELS = {OpenAIModelName.GPT_4O_MINI, OpenAIModelName.GPT_4O} + with patch.dict("agents.agents.agents", {"base-agent": base_agent}, clear=True): + response = test_client.get("/info") + assert response.status_code == 200 + output = ServiceMetadata.model_validate(response.json()) + + assert output.default_agent == "research-assistant" + assert len(output.agents) == 1 + assert output.agents[0].key == "base-agent" + assert output.agents[0].description == "A base agent." 
+ + assert output.default_model == OpenAIModelName.GPT_4O_MINI + assert output.models == [OpenAIModelName.GPT_4O, OpenAIModelName.GPT_4O_MINI] From f9d2bb7ff558e61fc7425871ca424d762a162b35 Mon Sep 17 00:00:00 2001 From: Joshua Carroll Date: Sun, 8 Dec 2024 23:57:11 -0800 Subject: [PATCH 4/9] Add client support for /info endpoint --- src/client/client.py | 62 +++++++++++++++++++++++++++++-- src/run_client.py | 6 +++ tests/client/conftest.py | 4 +- tests/client/test_client.py | 53 ++++++++++++++++++++++---- tests/service/conftest.py | 12 ++++-- tests/service/test_service_e2e.py | 7 +++- 6 files changed, 128 insertions(+), 16 deletions(-) diff --git a/src/client/client.py b/src/client/client.py index be6f91a..66e318e 100644 --- a/src/client/client.py +++ b/src/client/client.py @@ -5,7 +5,15 @@ import httpx -from schema import ChatHistory, ChatHistoryInput, ChatMessage, Feedback, StreamInput, UserInput +from schema import ( + ChatHistory, + ChatHistoryInput, + ChatMessage, + Feedback, + ServiceMetadata, + StreamInput, + UserInput, +) class AgentClient: @@ -14,19 +22,29 @@ class AgentClient: def __init__( self, base_url: str = "http://localhost", - agent: str = "research-assistant", + agent: str = None, timeout: float | None = None, + get_info: bool = True, ) -> None: """ Initialize the client. Args: base_url (str): The base URL of the agent service. + agent (str): The name of the default agent to use. + timeout (float, optional): The timeout for requests. + get_info (bool, optional): Whether to fetch agent information on init. + Default: True """ self.base_url = base_url - self.agent = agent self.auth_secret = os.getenv("AUTH_SECRET") self.timeout = timeout + self.info: ServiceMetadata | None = None + self.agent: str | None = None + if get_info: + self.retrieve_info() + if agent: + self.update_agent(agent) @property def _headers(self) -> dict[str, str]: @@ -35,6 +53,36 @@ def _headers(self) -> dict[str, str]: headers["Authorization"] = f"Bearer {self.auth_secret}" return headers + def retrieve_info(self) -> None: + try: + response = httpx.get( + f"{self.base_url}/info", + headers=self._headers, + timeout=self.timeout, + ) + if response.status_code == 200: + self.info: ServiceMetadata = ServiceMetadata.model_validate(response.json()) + else: + raise Exception( + f"Error getting service info: {response.status_code} - {response.text}" + ) + except Exception as e: + raise Exception(f"Error getting service info: {e}") + + if not self.agent or self.agent not in [a.key for a in self.info.agents]: + self.agent = self.info.default_agent + + def update_agent(self, agent: str, verify: bool = True) -> None: + if verify: + if not self.info: + self.retrieve_info() + agent_keys = [a.key for a in self.info.agents] + if agent not in agent_keys: + raise Exception( + f"Agent {agent} not found in available agents: {', '.join(agent_keys)}" + ) + self.agent = agent + async def ainvoke( self, message: str, model: str | None = None, thread_id: str | None = None ) -> ChatMessage: @@ -49,6 +97,8 @@ async def ainvoke( Returns: AnyMessage: The response from the agent """ + if not self.agent: + raise Exception("No agent selected. Use update_agent() to select an agent.") request = UserInput(message=message, thread_id=thread_id, model=model) async with httpx.AsyncClient() as client: response = await client.post( @@ -75,6 +125,8 @@ def invoke( Returns: ChatMessage: The response from the agent """ + if not self.agent: + raise Exception("No agent selected. 
Use update_agent() to select an agent.") request = UserInput(message=message) if thread_id: request.thread_id = thread_id @@ -138,6 +190,8 @@ def stream( Returns: Generator[ChatMessage | str, None, None]: The response from the agent """ + if not self.agent: + raise Exception("No agent selected. Use update_agent() to select an agent.") request = StreamInput(message=message, stream_tokens=stream_tokens) if thread_id: request.thread_id = thread_id @@ -183,6 +237,8 @@ async def astream( Returns: AsyncGenerator[ChatMessage | str, None]: The response from the agent """ + if not self.agent: + raise Exception("No agent selected. Use update_agent() to select an agent.") request = StreamInput(message=message, stream_tokens=stream_tokens) if thread_id: request.thread_id = thread_id diff --git a/src/run_client.py b/src/run_client.py index 6dafac0..7186849 100644 --- a/src/run_client.py +++ b/src/run_client.py @@ -9,6 +9,9 @@ async def amain() -> None: #### ASYNC #### client = AgentClient(settings.BASE_URL) + print("Agent info:") + print(client.info) + print("Chat example:") response = await client.ainvoke("Tell me a brief joke?", model="gpt-4o") response.pretty_print() @@ -28,6 +31,9 @@ def main() -> None: #### SYNC #### client = AgentClient(settings.BASE_URL) + print("Agent info:") + print(client.info) + print("Chat example:") response = client.invoke("Tell me a brief joke?", model="gpt-4o") response.pretty_print() diff --git a/tests/client/conftest.py b/tests/client/conftest.py index a9bbbb4..0055e6b 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -6,4 +6,6 @@ @pytest.fixture def agent_client(mock_env): """Fixture for creating a test client with a clean environment.""" - return AgentClient(base_url="http://test", agent="test-agent") + ac = AgentClient(base_url="http://test", get_info=False) + ac.update_agent("test-agent", verify=False) + return ac diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 817ffd7..45646ec 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -6,37 +6,38 @@ from httpx import Response from client import AgentClient -from schema import ChatHistory, ChatMessage +from schema import AgentInfo, ChatHistory, ChatMessage, ServiceMetadata +from schema.models import OpenAIModelName def test_init(mock_env): """Test client initialization with different parameters.""" # Test default values - client = AgentClient() + client = AgentClient(get_info=False) assert client.base_url == "http://localhost" - assert client.agent == "research-assistant" assert client.timeout is None # Test custom values client = AgentClient( base_url="http://test", - agent="custom-agent", timeout=30.0, + get_info=False, ) assert client.base_url == "http://test" - assert client.agent == "custom-agent" assert client.timeout == 30.0 + client.update_agent("test-agent", verify=False) + assert client.agent == "test-agent" def test_headers(mock_env): """Test header generation with and without auth.""" # Test without auth - client = AgentClient() + client = AgentClient(get_info=False) assert client._headers == {} # Test with auth with patch.dict(os.environ, {"AUTH_SECRET": "test-secret"}, clear=True): - client = AgentClient() + client = AgentClient(get_info=False) assert client._headers == {"Authorization": "Bearer test-secret"} @@ -277,3 +278,41 @@ def test_get_history(agent_client): with pytest.raises(Exception) as exc: agent_client.get_history(THREAD_ID) assert "Error: 500" in str(exc.value) + + +def test_info(agent_client): + assert 
agent_client.info is None + assert agent_client.agent == "test-agent" + + # Mock info response + test_info = ServiceMetadata( + default_agent="custom-agent", + agents=[AgentInfo(key="custom-agent", description="Custom agent")], + default_model=OpenAIModelName.GPT_4O, + models=[OpenAIModelName.GPT_4O, OpenAIModelName.GPT_4O_MINI], + ) + test_response = Response(200, json=test_info.model_dump()) + + # Update an existing client with info + with patch("httpx.get", return_value=test_response): + agent_client.retrieve_info() + + assert agent_client.info == test_info + assert agent_client.agent == "custom-agent" + + # Test invalid update_agent + with pytest.raises(Exception) as exc: + agent_client.update_agent("unknown-agent") + assert "Agent unknown-agent not found in available agents: custom-agent" in str(exc.value) + + # Test a fresh client with info + with patch("httpx.get", return_value=test_response): + agent_client = AgentClient(base_url="http://test") + assert agent_client.info == test_info + assert agent_client.agent == "custom-agent" + + # Test error on invoke if no agent set + agent_client = AgentClient(base_url="http://test", get_info=False) + with pytest.raises(Exception) as exc: + agent_client.invoke("test") + assert "No agent selected. Use update_agent() to select an agent." in str(exc.value) diff --git a/tests/service/conftest.py b/tests/service/conftest.py index ffb9139..83908fa 100644 --- a/tests/service/conftest.py +++ b/tests/service/conftest.py @@ -31,8 +31,8 @@ def mock_settings(mock_env): @pytest.fixture -def mock_httpx_stream(): - """Patch httpx.stream to use our test client.""" +def mock_httpx(): + """Patch httpx.stream and httpx.get to use our test client.""" with TestClient(app) as client: @@ -41,5 +41,11 @@ def mock_stream(method: str, url: str, **kwargs): path = url.replace("http://localhost", "") return client.stream(method, path, **kwargs) + def mock_get(url: str, **kwargs): + # Strip the base URL since TestClient expects just the path + path = url.replace("http://localhost", "") + return client.get(path, **kwargs) + with patch("httpx.stream", mock_stream): - yield + with patch("httpx.get", mock_get): + yield diff --git a/tests/service/test_service_e2e.py b/tests/service/test_service_e2e.py index ac0942f..efffde3 100644 --- a/tests/service/test_service_e2e.py +++ b/tests/service/test_service_e2e.py @@ -5,6 +5,7 @@ from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, MessagesState, StateGraph +from agents.agents import Agent from agents.utils import CustomData from client import AgentClient from schema.schema import ChatMessage @@ -78,9 +79,11 @@ async def static_messages(state: MessagesState, config: RunnableConfig) -> Messa static_agent = agent.compile(checkpointer=MemorySaver()) -def test_agent_stream(mock_httpx_stream): +def test_agent_stream(mock_httpx): """Test that streaming from our static agent works correctly with token streaming.""" - client = AgentClient(agent="static-agent") + agent_meta = Agent(description="A static agent.", graph=static_agent) + with patch.dict("agents.agents.agents", {"static-agent": agent_meta}, clear=True): + client = AgentClient(agent="static-agent") # Use stream to get intermediate responses messages = [] From bc89fb64d371d0b736ca9eb0fae2adf41b52c88d Mon Sep 17 00:00:00 2001 From: Joshua Carroll Date: Mon, 9 Dec 2024 21:57:07 -0800 Subject: [PATCH 5/9] Update app to use metadata endpoint for dynamic settings --- src/streamlit_app.py | 36 ++++++++++++--------------------- tests/app/conftest.py | 14 
+++++++++++++ tests/app/test_streamlit_app.py | 11 +++++----- 3 files changed, 33 insertions(+), 28 deletions(-) diff --git a/src/streamlit_app.py b/src/streamlit_app.py index a991eb7..ce322b3 100644 --- a/src/streamlit_app.py +++ b/src/streamlit_app.py @@ -4,18 +4,12 @@ import streamlit as st from dotenv import load_dotenv +from httpx import ConnectError, ConnectTimeout from pydantic import ValidationError from streamlit.runtime.scriptrunner import get_script_run_ctx from client import AgentClient from schema import ChatHistory, ChatMessage -from schema.models import ( - AnthropicModelName, - AWSModelName, - GoogleModelName, - GroqModelName, - OpenAIModelName, -) from schema.task_data import TaskData, TaskDataStatus # A Streamlit app for interacting with the langgraph agent via a simple chat interface. @@ -64,7 +58,11 @@ async def main() -> None: host = os.getenv("HOST", "0.0.0.0") port = os.getenv("PORT", 80) agent_url = f"http://{host}:{port}" - st.session_state.agent_client = AgentClient(base_url=agent_url) + try: + st.session_state.agent_client = AgentClient(base_url=agent_url) + except (ConnectError, ConnectTimeout) as e: + st.error(f"Error connecting to agent service: {e}") + st.stop() agent_client: AgentClient = st.session_state.agent_client if "thread_id" not in st.session_state: @@ -78,28 +76,20 @@ async def main() -> None: st.session_state.messages = messages st.session_state.thread_id = thread_id - models = { - "OpenAI GPT-4o-mini (streaming)": OpenAIModelName.GPT_4O_MINI, - "Gemini 1.5 Flash (streaming)": GoogleModelName.GEMINI_15_FLASH, - "Claude 3 Haiku (streaming)": AnthropicModelName.HAIKU_3, - "llama-3.1-70b on Groq": GroqModelName.LLAMA_31_70B, - "AWS Bedrock Haiku (streaming)": AWSModelName.BEDROCK_HAIKU, - } # Config options with st.sidebar: st.header(f"{APP_ICON} {APP_TITLE}") "" "Full toolkit for running an AI agent service built with LangGraph, FastAPI and Streamlit" with st.popover(":material/settings: Settings", use_container_width=True): - m = st.radio("LLM to use", options=models.keys()) - model = models[m] + model_idx = agent_client.info.models.index(agent_client.info.default_model) + model = st.selectbox("LLM to use", options=agent_client.info.models, index=model_idx) + agent_list = [a.key for a in agent_client.info.agents] + agent_idx = agent_list.index(agent_client.info.default_agent) agent_client.agent = st.selectbox( "Agent to use", - options=[ - "research-assistant", - "chatbot", - "bg-task-agent", - ], + options=agent_list, + index=agent_idx, ) use_streaming = st.toggle("Stream results", value=True) @@ -135,7 +125,7 @@ def architecture_dialog() -> None: messages: list[ChatMessage] = st.session_state.messages if len(messages) == 0: - WELCOME = "Hello! I'm an AI-powered research assistant with web search and a calculator. I may take a few seconds to boot up when you send your first message. Ask me anything!" + WELCOME = "Hello! I'm an AI-powered research assistant with web search and a calculator. Ask me anything!" 
with st.chat_message("ai"): st.write(WELCOME) diff --git a/tests/app/conftest.py b/tests/app/conftest.py index 1e020a4..433be83 100644 --- a/tests/app/conftest.py +++ b/tests/app/conftest.py @@ -2,11 +2,25 @@ import pytest +from schema import AgentInfo, ServiceMetadata +from schema.models import OpenAIModelName + @pytest.fixture def mock_agent_client(mock_env): """Fixture for creating a mock AgentClient with a clean environment.""" + mock_info = ServiceMetadata( + default_agent="test-agent", + agents=[ + AgentInfo(key="test-agent", description="Test agent"), + AgentInfo(key="chatbot", description="Chatbot"), + ], + default_model=OpenAIModelName.GPT_4O, + models=[OpenAIModelName.GPT_4O, OpenAIModelName.GPT_4O_MINI], + ) + with patch("client.AgentClient") as mock_agent_client: mock_agent_client_instance = mock_agent_client.return_value + mock_agent_client_instance.info = mock_info yield mock_agent_client_instance diff --git a/tests/app/test_streamlit_app.py b/tests/app/test_streamlit_app.py index 711bc59..d5134ce 100644 --- a/tests/app/test_streamlit_app.py +++ b/tests/app/test_streamlit_app.py @@ -5,7 +5,7 @@ from streamlit.testing.v1 import AppTest from schema import ChatHistory, ChatMessage -from schema.models import AnthropicModelName +from schema.models import OpenAIModelName def test_app_simple_non_streaming(mock_agent_client): @@ -45,9 +45,10 @@ def test_app_settings(mock_agent_client): ) at.sidebar.toggle[0].set_value(False) # Use Streaming = False - at.sidebar.radio[0].set_value("Claude 3 Haiku (streaming)") - assert mock_agent_client.agent == "research-assistant" - at.sidebar.selectbox[0].set_value("chatbot") + assert at.sidebar.selectbox[0].value == "gpt-4o" + assert mock_agent_client.agent == "test-agent" + at.sidebar.selectbox[0].set_value("gpt-4o-mini") + at.sidebar.selectbox[1].set_value("chatbot") at.chat_input[0].set_value(PROMPT).run() print(at) @@ -61,7 +62,7 @@ def test_app_settings(mock_agent_client): assert mock_agent_client.agent == "chatbot" mock_agent_client.ainvoke.assert_called_with( message=PROMPT, - model=AnthropicModelName.HAIKU_3, + model=OpenAIModelName.GPT_4O_MINI, thread_id="test session id", ) assert not at.exception From 08bb9b55351080020d5294b0092fbeb3ea815dcd Mon Sep 17 00:00:00 2001 From: Joshua Carroll Date: Mon, 9 Dec 2024 22:19:16 -0800 Subject: [PATCH 6/9] Introduce AgentClientError and clean up agent error handling --- src/client/__init__.py | 4 +- src/client/client.py | 175 +++++++++++++++++++----------------- tests/client/test_client.py | 86 ++++++++++-------- 3 files changed, 147 insertions(+), 118 deletions(-) diff --git a/src/client/__init__.py b/src/client/__init__.py index 2ecf86b..9c2963a 100644 --- a/src/client/__init__.py +++ b/src/client/__init__.py @@ -1,3 +1,3 @@ -from client.client import AgentClient +from client.client import AgentClient, AgentClientError -__all__ = ["AgentClient"] +__all__ = ["AgentClient", "AgentClientError"] diff --git a/src/client/client.py b/src/client/client.py index 66e318e..bf61f5d 100644 --- a/src/client/client.py +++ b/src/client/client.py @@ -16,6 +16,10 @@ ) +class AgentClientError(Exception): + pass + + class AgentClient: """Client for interacting with the agent service.""" @@ -60,15 +64,11 @@ def retrieve_info(self) -> None: headers=self._headers, timeout=self.timeout, ) - if response.status_code == 200: - self.info: ServiceMetadata = ServiceMetadata.model_validate(response.json()) - else: - raise Exception( - f"Error getting service info: {response.status_code} - {response.text}" - ) - except Exception 
as e: - raise Exception(f"Error getting service info: {e}") + response.raise_for_status() + except httpx.HTTPError as e: + raise AgentClientError(f"Error getting service info: {e}") + self.info: ServiceMetadata = ServiceMetadata.model_validate(response.json()) if not self.agent or self.agent not in [a.key for a in self.info.agents]: self.agent = self.info.default_agent @@ -78,7 +78,7 @@ def update_agent(self, agent: str, verify: bool = True) -> None: self.retrieve_info() agent_keys = [a.key for a in self.info.agents] if agent not in agent_keys: - raise Exception( + raise AgentClientError( f"Agent {agent} not found in available agents: {', '.join(agent_keys)}" ) self.agent = agent @@ -98,18 +98,21 @@ async def ainvoke( AnyMessage: The response from the agent """ if not self.agent: - raise Exception("No agent selected. Use update_agent() to select an agent.") + raise AgentClientError("No agent selected. Use update_agent() to select an agent.") request = UserInput(message=message, thread_id=thread_id, model=model) async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.base_url}/{self.agent}/invoke", - json=request.model_dump(), - headers=self._headers, - timeout=self.timeout, - ) - if response.status_code == 200: - return ChatMessage.model_validate(response.json()) - raise Exception(f"Error: {response.status_code} - {response.text}") + try: + response = await client.post( + f"{self.base_url}/{self.agent}/invoke", + json=request.model_dump(), + headers=self._headers, + timeout=self.timeout, + ) + response.raise_for_status() + except httpx.HTTPError as e: + raise AgentClientError(f"Error: {e}") + + return ChatMessage.model_validate(response.json()) def invoke( self, message: str, model: str | None = None, thread_id: str | None = None @@ -126,21 +129,24 @@ def invoke( ChatMessage: The response from the agent """ if not self.agent: - raise Exception("No agent selected. Use update_agent() to select an agent.") + raise AgentClientError("No agent selected. Use update_agent() to select an agent.") request = UserInput(message=message) if thread_id: request.thread_id = thread_id if model: request.model = model - response = httpx.post( - f"{self.base_url}/{self.agent}/invoke", - json=request.model_dump(), - headers=self._headers, - timeout=self.timeout, - ) - if response.status_code == 200: - return ChatMessage.model_validate(response.json()) - raise Exception(f"Error: {response.status_code} - {response.text}") + try: + response = httpx.post( + f"{self.base_url}/{self.agent}/invoke", + json=request.model_dump(), + headers=self._headers, + timeout=self.timeout, + ) + response.raise_for_status() + except httpx.HTTPError as e: + raise AgentClientError(f"Error: {e}") + + return ChatMessage.model_validate(response.json()) def _parse_stream_line(self, line: str) -> ChatMessage | str | None: line = line.strip() @@ -191,27 +197,29 @@ def stream( Generator[ChatMessage | str, None, None]: The response from the agent """ if not self.agent: - raise Exception("No agent selected. Use update_agent() to select an agent.") + raise AgentClientError("No agent selected. 
Use update_agent() to select an agent.") request = StreamInput(message=message, stream_tokens=stream_tokens) if thread_id: request.thread_id = thread_id if model: request.model = model - with httpx.stream( - "POST", - f"{self.base_url}/{self.agent}/stream", - json=request.model_dump(), - headers=self._headers, - timeout=self.timeout, - ) as response: - if response.status_code != 200: - raise Exception(f"Error: {response.status_code} - {response.text}") - for line in response.iter_lines(): - if line.strip(): - parsed = self._parse_stream_line(line) - if parsed is None: - break - yield parsed + try: + with httpx.stream( + "POST", + f"{self.base_url}/{self.agent}/stream", + json=request.model_dump(), + headers=self._headers, + timeout=self.timeout, + ) as response: + response.raise_for_status() + for line in response.iter_lines(): + if line.strip(): + parsed = self._parse_stream_line(line) + if parsed is None: + break + yield parsed + except httpx.HTTPError as e: + raise AgentClientError(f"Error: {e}") async def astream( self, @@ -238,28 +246,30 @@ async def astream( AsyncGenerator[ChatMessage | str, None]: The response from the agent """ if not self.agent: - raise Exception("No agent selected. Use update_agent() to select an agent.") + raise AgentClientError("No agent selected. Use update_agent() to select an agent.") request = StreamInput(message=message, stream_tokens=stream_tokens) if thread_id: request.thread_id = thread_id if model: request.model = model async with httpx.AsyncClient() as client: - async with client.stream( - "POST", - f"{self.base_url}/{self.agent}/stream", - json=request.model_dump(), - headers=self._headers, - timeout=self.timeout, - ) as response: - if response.status_code != 200: - raise Exception(f"Error: {response.status_code} - {response.text}") - async for line in response.aiter_lines(): - if line.strip(): - parsed = self._parse_stream_line(line) - if parsed is None: - break - yield parsed + try: + async with client.stream( + "POST", + f"{self.base_url}/{self.agent}/stream", + json=request.model_dump(), + headers=self._headers, + timeout=self.timeout, + ) as response: + response.raise_for_status() + async for line in response.aiter_lines(): + if line.strip(): + parsed = self._parse_stream_line(line) + if parsed is None: + break + yield parsed + except httpx.HTTPError as e: + raise AgentClientError(f"Error: {e}") async def acreate_feedback( self, run_id: str, key: str, score: float, kwargs: dict[str, Any] = {} @@ -273,15 +283,17 @@ async def acreate_feedback( """ request = Feedback(run_id=run_id, key=key, score=score, kwargs=kwargs) async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.base_url}/feedback", - json=request.model_dump(), - headers=self._headers, - timeout=self.timeout, - ) - if response.status_code != 200: - raise Exception(f"Error: {response.status_code} - {response.text}") - response.json() + try: + response = await client.post( + f"{self.base_url}/feedback", + json=request.model_dump(), + headers=self._headers, + timeout=self.timeout, + ) + response.raise_for_status() + response.json() + except httpx.HTTPError as e: + raise AgentClientError(f"Error: {e}") def get_history( self, @@ -294,14 +306,15 @@ def get_history( thread_id (str, optional): Thread ID for identifying a conversation """ request = ChatHistoryInput(thread_id=thread_id) - response = httpx.post( - f"{self.base_url}/history", - json=request.model_dump(), - headers=self._headers, - timeout=self.timeout, - ) - if response.status_code == 200: - 
response_object = response.json() - return ChatHistory.model_validate(response_object) - else: - raise Exception(f"Error: {response.status_code} - {response.text}") + try: + response = httpx.post( + f"{self.base_url}/history", + json=request.model_dump(), + headers=self._headers, + timeout=self.timeout, + ) + response.raise_for_status() + except httpx.HTTPError as e: + raise AgentClientError(f"Error: {e}") + + return ChatHistory.model_validate(response.json()) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 45646ec..2d56c30 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -3,9 +3,9 @@ from unittest.mock import AsyncMock, Mock, patch import pytest -from httpx import Response +from httpx import Request, Response -from client import AgentClient +from client import AgentClient, AgentClientError from schema import AgentInfo, ChatHistory, ChatMessage, ServiceMetadata from schema.models import OpenAIModelName @@ -47,9 +47,11 @@ def test_invoke(agent_client): ANSWER = "The weather is sunny." # Mock successful response + mock_request = Request("POST", "http://test/invoke") mock_response = Response( 200, json={"type": "ai", "content": ANSWER}, + request=mock_request, ) with patch("httpx.post", return_value=mock_response): response = agent_client.invoke(QUESTION) @@ -72,11 +74,11 @@ def test_invoke(agent_client): assert kwargs["json"]["thread_id"] == "test-thread" # Test error response - error_response = Response(500, text="Internal Server Error") + error_response = Response(500, text="Internal Server Error", request=mock_request) with patch("httpx.post", return_value=error_response): - with pytest.raises(Exception) as exc: + with pytest.raises(AgentClientError) as exc: agent_client.invoke(QUESTION) - assert "Error: 500" in str(exc.value) + assert "500 Internal Server Error" in str(exc.value) @pytest.mark.asyncio @@ -86,7 +88,8 @@ async def test_ainvoke(agent_client): ANSWER = "The weather is sunny." 
# Test successful response - mock_response = Response(200, json={"type": "ai", "content": ANSWER}) + mock_request = Request("POST", "http://test/invoke") + mock_response = Response(200, json={"type": "ai", "content": ANSWER}, request=mock_request) with patch("httpx.AsyncClient.post", return_value=mock_response): response = await agent_client.ainvoke(QUESTION) assert isinstance(response, ChatMessage) @@ -110,10 +113,11 @@ async def test_ainvoke(agent_client): assert kwargs["json"]["thread_id"] == "test-thread" # Test error response - with patch("httpx.AsyncClient.post", return_value=Response(500, text="Internal Server Error")): - with pytest.raises(Exception) as exc: + error_response = Response(500, text="Internal Server Error", request=mock_request) + with patch("httpx.AsyncClient.post", return_value=error_response): + with pytest.raises(AgentClientError) as exc: await agent_client.ainvoke(QUESTION) - assert "Error: 500" in str(exc.value) + assert "500 Internal Server Error" in str(exc.value) def test_stream(agent_client): @@ -135,6 +139,7 @@ def test_stream(agent_client): mock_response = Mock() mock_response.status_code = 200 mock_response.iter_lines.return_value = events + mock_response.request = Request("POST", "http://test/stream") mock_response.__enter__ = Mock(return_value=mock_response) mock_response.__exit__ = Mock(return_value=None) @@ -154,15 +159,16 @@ def test_stream(agent_client): assert final_message.content == FINAL_ANSWER # Test error response - error_response = Mock() - error_response.status_code = 500 - error_response.text = "Internal Server Error" - error_response.__enter__ = Mock(return_value=error_response) - error_response.__exit__ = Mock(return_value=None) - with patch("httpx.stream", return_value=error_response): - with pytest.raises(Exception) as exc: + error_response = Response( + 500, text="Internal Server Error", request=Request("POST", "http://test/stream") + ) + error_response_mock = Mock() + error_response_mock.__enter__ = Mock(return_value=error_response) + error_response_mock.__exit__ = Mock(return_value=None) + with patch("httpx.stream", return_value=error_response_mock): + with pytest.raises(AgentClientError) as exc: list(agent_client.stream(QUESTION)) - assert "Error: 500" in str(exc.value) + assert "500 Internal Server Error" in str(exc.value) @pytest.mark.asyncio @@ -189,6 +195,7 @@ async def async_events(): # Mock the streaming response mock_response = AsyncMock() mock_response.status_code = 200 + mock_response.request = Request("POST", "http://test/stream") mock_response.aiter_lines = Mock(return_value=async_events()) mock_response.__aenter__ = AsyncMock(return_value=mock_response) @@ -214,18 +221,19 @@ async def async_events(): assert final_message.content == FINAL_ANSWER # Test error response - error_response = AsyncMock() - error_response.status_code = 500 - error_response.text = "Internal Server Error" - error_response.__aenter__ = AsyncMock(return_value=error_response) + error_response = Response( + 500, text="Internal Server Error", request=Request("POST", "http://test/stream") + ) + error_response_mock = AsyncMock() + error_response_mock.__aenter__ = AsyncMock(return_value=error_response) - mock_client.stream.return_value = error_response + mock_client.stream.return_value = error_response_mock with patch("httpx.AsyncClient", return_value=mock_client): - with pytest.raises(Exception) as exc: + with pytest.raises(AgentClientError) as exc: async for _ in agent_client.astream(QUESTION): pass - assert "Error: 500" in str(exc.value) + assert "500 
Internal Server Error" in str(exc.value) @pytest.mark.asyncio @@ -237,7 +245,8 @@ async def test_acreate_feedback(agent_client): KWARGS = {"comment": "Great response!"} # Test successful response - with patch("httpx.AsyncClient.post", return_value=Response(200, json={})) as mock_post: + mock_response = Response(200, json={}, request=Request("POST", "http://test/feedback")) + with patch("httpx.AsyncClient.post", return_value=mock_response) as mock_post: await agent_client.acreate_feedback(RUN_ID, KEY, SCORE, KWARGS) # Verify request args, kwargs = mock_post.call_args @@ -247,10 +256,13 @@ async def test_acreate_feedback(agent_client): assert kwargs["json"]["kwargs"] == KWARGS # Test error response - with patch("httpx.AsyncClient.post", return_value=Response(500, text="Internal Server Error")): - with pytest.raises(Exception) as exc: + error_response = Response( + 500, text="Internal Server Error", request=Request("POST", "http://test/feedback") + ) + with patch("httpx.AsyncClient.post", return_value=error_response): + with pytest.raises(AgentClientError) as exc: await agent_client.acreate_feedback(RUN_ID, KEY, SCORE) - assert "Error: 500" in str(exc.value) + assert "500 Internal Server Error" in str(exc.value) def test_get_history(agent_client): @@ -264,7 +276,7 @@ def test_get_history(agent_client): } # Mock successful response - mock_response = Response(200, json=HISTORY) + mock_response = Response(200, json=HISTORY, request=Request("POST", "http://test/history")) with patch("httpx.post", return_value=mock_response): history = agent_client.get_history(THREAD_ID) assert isinstance(history, ChatHistory) @@ -273,11 +285,13 @@ def test_get_history(agent_client): assert history.messages[1].type == "ai" # Test error response - error_response = Response(500, text="Internal Server Error") + error_response = Response( + 500, text="Internal Server Error", request=Request("POST", "http://test/history") + ) with patch("httpx.post", return_value=error_response): - with pytest.raises(Exception) as exc: + with pytest.raises(AgentClientError) as exc: agent_client.get_history(THREAD_ID) - assert "Error: 500" in str(exc.value) + assert "500 Internal Server Error" in str(exc.value) def test_info(agent_client): @@ -291,7 +305,9 @@ def test_info(agent_client): default_model=OpenAIModelName.GPT_4O, models=[OpenAIModelName.GPT_4O, OpenAIModelName.GPT_4O_MINI], ) - test_response = Response(200, json=test_info.model_dump()) + test_response = Response( + 200, json=test_info.model_dump(), request=Request("GET", "http://test/info") + ) # Update an existing client with info with patch("httpx.get", return_value=test_response): @@ -301,7 +317,7 @@ def test_info(agent_client): assert agent_client.agent == "custom-agent" # Test invalid update_agent - with pytest.raises(Exception) as exc: + with pytest.raises(AgentClientError) as exc: agent_client.update_agent("unknown-agent") assert "Agent unknown-agent not found in available agents: custom-agent" in str(exc.value) @@ -313,6 +329,6 @@ def test_info(agent_client): # Test error on invoke if no agent set agent_client = AgentClient(base_url="http://test", get_info=False) - with pytest.raises(Exception) as exc: + with pytest.raises(AgentClientError) as exc: agent_client.invoke("test") assert "No agent selected. Use update_agent() to select an agent." 
in str(exc.value) From 1fb349d36dc26adc1f3ce3217fa8bf7c2bd380a0 Mon Sep 17 00:00:00 2001 From: Joshua Carroll Date: Mon, 9 Dec 2024 22:43:47 -0800 Subject: [PATCH 7/9] Update app error handling --- src/streamlit_app.py | 63 +++++++++++++++++++-------------- tests/app/test_streamlit_app.py | 21 +++++++++++ 2 files changed, 57 insertions(+), 27 deletions(-) diff --git a/src/streamlit_app.py b/src/streamlit_app.py index ce322b3..112109b 100644 --- a/src/streamlit_app.py +++ b/src/streamlit_app.py @@ -4,11 +4,10 @@ import streamlit as st from dotenv import load_dotenv -from httpx import ConnectError, ConnectTimeout from pydantic import ValidationError from streamlit.runtime.scriptrunner import get_script_run_ctx -from client import AgentClient +from client import AgentClient, AgentClientError from schema import ChatHistory, ChatMessage from schema.task_data import TaskData, TaskDataStatus @@ -59,9 +58,11 @@ async def main() -> None: port = os.getenv("PORT", 80) agent_url = f"http://{host}:{port}" try: - st.session_state.agent_client = AgentClient(base_url=agent_url) - except (ConnectError, ConnectTimeout) as e: + with st.spinner("Connecting to agent service..."): + st.session_state.agent_client = AgentClient(base_url=agent_url) + except AgentClientError as e: st.error(f"Error connecting to agent service: {e}") + st.markdown("The service might be booting up. Try again in a few seconds.") st.stop() agent_client: AgentClient = st.session_state.agent_client @@ -140,25 +141,29 @@ async def amessage_iter() -> AsyncGenerator[ChatMessage, None]: if user_input := st.chat_input(): messages.append(ChatMessage(type="human", content=user_input)) st.chat_message("human").write(user_input) - if use_streaming: - stream = agent_client.astream( - message=user_input, - model=model, - thread_id=st.session_state.thread_id, - ) - await draw_messages(stream, is_new=True) - else: - response = await agent_client.ainvoke( - message=user_input, - model=model, - thread_id=st.session_state.thread_id, - ) - messages.append(response) - st.chat_message("ai").write(response.content) - st.rerun() # Clear stale containers + try: + if use_streaming: + stream = agent_client.astream( + message=user_input, + model=model, + thread_id=st.session_state.thread_id, + ) + await draw_messages(stream, is_new=True) + else: + response = await agent_client.ainvoke( + message=user_input, + model=model, + thread_id=st.session_state.thread_id, + ) + messages.append(response) + st.chat_message("ai").write(response.content) + st.rerun() # Clear stale containers + except AgentClientError as e: + st.error(f"Error generating response: {e}") + st.stop() # If messages have been generated, show feedback widget - if len(messages) > 0: + if len(messages) > 0 and st.session_state.last_message: with st.session_state.last_message: await handle_feedback() @@ -321,12 +326,16 @@ async def handle_feedback() -> None: normalized_score = (feedback + 1) / 5.0 agent_client: AgentClient = st.session_state.agent_client - await agent_client.acreate_feedback( - run_id=latest_run_id, - key="human-feedback-stars", - score=normalized_score, - kwargs={"comment": "In-line human feedback"}, - ) + try: + await agent_client.acreate_feedback( + run_id=latest_run_id, + key="human-feedback-stars", + score=normalized_score, + kwargs={"comment": "In-line human feedback"}, + ) + except AgentClientError as e: + st.error(f"Error recording feedback: {e}") + st.stop() st.session_state.last_feedback = (latest_run_id, feedback) st.toast("Feedback recorded", icon=":material/reviews:") diff 
--git a/tests/app/test_streamlit_app.py b/tests/app/test_streamlit_app.py
index d5134ce..1d6a0ad 100644
--- a/tests/app/test_streamlit_app.py
+++ b/tests/app/test_streamlit_app.py
@@ -4,6 +4,7 @@
 import pytest
 from streamlit.testing.v1 import AppTest
 
+from client import AgentClientError
 from schema import ChatHistory, ChatMessage
 from schema.models import OpenAIModelName
 
@@ -139,3 +140,23 @@ async def amessage_iter() -> AsyncGenerator[ChatMessage, None]:
     assert tool_status.markdown[2].value == "42"
     assert response.markdown[-1].value == "The answer is 42"
     assert not at.exception
+
+
+@pytest.mark.asyncio
+async def test_app_init_error(mock_agent_client):
+    """Test the app when the agent client raises an error while generating a response"""
+    at = AppTest.from_file("../../src/streamlit_app.py").run()
+
+    # Make the mocked agent's stream raise an error when called
+    PROMPT = "What is 6 * 7?"
+    mock_agent_client.astream.side_effect = AgentClientError("Error connecting to agent")
+
+    at.toggle[0].set_value(True)  # Use Streaming = True
+    at.chat_input[0].set_value(PROMPT).run()
+    print(at)
+
+    assert at.chat_message[0].avatar == "assistant"
+    assert at.chat_message[1].avatar == "user"
+    assert at.chat_message[1].markdown[0].value == PROMPT
+    assert at.error[0].value == "Error generating response: Error connecting to agent"
+    assert not at.exception

From c758e75014b8a9c9ec54eb061b4dcc11b4b2156e Mon Sep 17 00:00:00 2001
From: Joshua Carroll
Date: Mon, 9 Dec 2024 22:51:02 -0800
Subject: [PATCH 8/9] Add an integration test with streamlit app

---
 tests/integration/test_docker_e2e.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/tests/integration/test_docker_e2e.py b/tests/integration/test_docker_e2e.py
index 55008ab..851a856 100644
--- a/tests/integration/test_docker_e2e.py
+++ b/tests/integration/test_docker_e2e.py
@@ -1,4 +1,5 @@
 import pytest
+from streamlit.testing.v1 import AppTest
 
 from client import AgentClient
 
@@ -13,3 +14,24 @@ def test_service_with_fake_model():
     response = client.invoke("Tell me a joke?", model="fake")
     assert response.type == "ai"
     assert response.content == "This is a test response from the fake model."
+
+
+@pytest.mark.docker
+def test_service_with_app():
+    """Test the service using the app.
+
+    This test requires the service container to be running with USE_FAKE_MODEL=true
+    """
+    at = AppTest.from_file("../../src/streamlit_app.py").run()
+    assert at.chat_message[0].avatar == "assistant"
+    welcome = at.chat_message[0].markdown[0].value
+    assert welcome.startswith("Hello! I'm an AI-powered research assistant")
+    assert not at.exception
+
+    at.sidebar.selectbox[1].set_value("chatbot")
+    at.chat_input[0].set_value("What is the weather in Tokyo?").run()
+    assert at.chat_message[0].avatar == "user"
+    assert at.chat_message[0].markdown[0].value == "What is the weather in Tokyo?"
+    assert at.chat_message[1].avatar == "assistant"
+    assert at.chat_message[1].markdown[0].value == "This is a test response from the fake model."
+    assert not at.exception

From b7ce21120c3878f10b3d2b0d5d6f5a38c29657d1 Mon Sep 17 00:00:00 2001
From: Joshua Carroll
Date: Mon, 9 Dec 2024 23:08:05 -0800
Subject: [PATCH 9/9] Update README.md

---
 README.md | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 0b9c317..6e365ce 100644
--- a/README.md
+++ b/README.md
@@ -57,20 +57,21 @@ docker compose watch
 1. **Multiple Agent Support**: Run multiple agents in the service and call by URL path
 1. 
**Asynchronous Design**: Utilizes async/await for efficient handling of concurrent requests. 1. **Feedback Mechanism**: Includes a star-based feedback system integrated with LangSmith. +1. **Dynamic Metadata**: `/info` endpoint provides dynamically configured metadata about the service and available agents and models. 1. **Docker Support**: Includes Dockerfiles and a docker compose file for easy development and deployment. +1. **Testing**: Includes robust unit and integration tests for the full repo. ### Key Files The repository is structured as follows: -- `src/agents/research_assistant.py`: Defines the main LangGraph agent -- `src/agents/llama_guard.py`: Defines the LlamaGuard content moderation -- `src/agents/models.py`: Configures available models based on ENV -- `src/agents/agents.py`: Mapping of all agents provided by the service -- `src/schema/schema.py`: Defines the protocol schema +- `src/agents/`: Defines several agents with different capabilities +- `src/schema/`: Defines the protocol schema +- `src/core/`: Core modules including LLM definition and settings - `src/service/service.py`: FastAPI service to serve the agents - `src/client/client.py`: Client to interact with the agent service - `src/streamlit_app.py`: Streamlit app providing a chat interface +- `tests/`: Unit and integration tests ## Why LangGraph? @@ -213,7 +214,7 @@ Contributions are welcome! Please feel free to submit a Pull Request. - [x] Add more sophisticated tools for the research assistant - [x] Increase test coverage and add CI pipeline - [x] Add support for multiple agents running on the same service, including non-chat agent -- [ ] Deployment instructions and configuration for cloud providers +- [x] Service metadata endpoint `/info` and dynamic app configuration - [ ] More ideas? File an issue or create a discussion! ## License
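
Taken together, the series leaves the client able to bootstrap itself from the service's metadata. A minimal usage sketch of that flow (the agent key and prompt are illustrative; the base URL is the client's default):

    from client import AgentClient, AgentClientError

    try:
        # The constructor fetches ServiceMetadata from GET /info (get_info=True by default)
        client = AgentClient(base_url="http://localhost")
    except AgentClientError as e:
        raise SystemExit(f"Service unavailable: {e}")

    # Discovered metadata: available agents and models, plus server-side defaults
    print([a.key for a in client.info.agents])
    print(client.info.default_agent, client.info.default_model)

    # update_agent() validates the key against info.agents before switching
    client.update_agent("chatbot")
    response = client.invoke("Tell me a brief joke?")
    response.pretty_print()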