Skip to content

Commit

Permalink
Standardize on httpx for requests (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshuaC215 authored Sep 2, 2024
1 parent 7f51d76 commit fee1333
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 61 deletions.
1 change: 0 additions & 1 deletion agent/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
## NOTE: THIS REQUIREMENTS FILE IS JUST FOR LANGGRAPH STUDIO
## IT IS NOT INTENDED TO BE USED FOR LOCAL DEVELOPMENT

aiohttp~=3.10.0
duckduckgo-search~=6.2.6
langchain-community~=0.2.11
langchain-core~=0.2.26
Expand Down
118 changes: 61 additions & 57 deletions client/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import aiohttp
import json
import os
import httpx
from typing import AsyncGenerator, Dict, Any, Generator
import requests
from schema import ChatMessage, UserInput, StreamInput, Feedback


Expand All @@ -19,6 +18,10 @@ def __init__(self, base_url: str = "http://localhost:80"):
self.base_url = base_url
self.auth_secret = os.getenv("AUTH_SECRET")

# Use a shared async client to get the most benefit from connection pooling
# See: https://www.python-httpx.org/async/#opening-and-closing-clients
self.async_client = httpx.AsyncClient(timeout=None)

@property
def _headers(self):
headers = {}
Expand All @@ -40,20 +43,22 @@ async def ainvoke(
Returns:
AnyMessage: The response from the agent
"""
async with aiohttp.ClientSession() as session:
request = UserInput(message=message)
if thread_id:
request.thread_id = thread_id
if model:
request.model = model
async with session.post(
f"{self.base_url}/invoke", json=request.dict(), headers=self._headers
) as response:
if response.status == 200:
result = await response.json()
return ChatMessage.parse_obj(result)
else:
raise Exception(f"Error: {response.status} - {await response.text()}")
request = UserInput(message=message)
if thread_id:
request.thread_id = thread_id
if model:
request.model = model
response = await self.async_client.post(
f"{self.base_url}/invoke",
json=request.dict(),
headers=self._headers,
timeout=None,
)
if response.status_code == 200:
result = response.json()
return ChatMessage.parse_obj(result)
else:
raise Exception(f"Error: {response.status_code} - {response.text}")

def invoke(
self, message: str, model: str | None = None, thread_id: str | None = None
Expand All @@ -74,16 +79,19 @@ def invoke(
request.thread_id = thread_id
if model:
request.model = model
response = requests.post(
f"{self.base_url}/invoke", json=request.dict(), headers=self._headers
response = httpx.post(
f"{self.base_url}/invoke",
json=request.dict(),
headers=self._headers,
timeout=None,
)
if response.status_code == 200:
return ChatMessage.parse_obj(response.json())
else:
raise Exception(f"Error: {response.status_code} - {response.text}")

def _parse_stream_line(self, line: str) -> ChatMessage | str | None:
line = line.decode("utf-8").strip()
line = line.strip()
if line.startswith("data: "):
data = line[6:]
if data == "[DONE]":
Expand Down Expand Up @@ -134,18 +142,17 @@ def stream(
request.thread_id = thread_id
if model:
request.model = model
response = requests.post(
f"{self.base_url}/stream", json=request.dict(), headers=self._headers, stream=True
)
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} - {response.text}")

for line in response.iter_lines():
if line:
parsed = self._parse_stream_line(line)
if parsed is None:
break
yield parsed
with httpx.stream(
"POST", f"{self.base_url}/stream", json=request.dict(), headers=self._headers
) as response:
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} - {response.text}")
for line in response.iter_lines():
if line.strip():
parsed = self._parse_stream_line(line)
if parsed is None:
break
yield parsed

async def astream(
self,
Expand All @@ -171,24 +178,22 @@ async def astream(
Returns:
AsyncGenerator[ChatMessage | str, None]: The response from the agent
"""
async with aiohttp.ClientSession() as session:
request = StreamInput(message=message, stream_tokens=stream_tokens)
if thread_id:
request.thread_id = thread_id
if model:
request.model = model
async with session.post(
f"{self.base_url}/stream", json=request.dict(), headers=self._headers
) as response:
if response.status != 200:
raise Exception(f"Error: {response.status} - {await response.text()}")
# Parse incoming events with the SSE protocol
async for line in response.content:
if line.decode("utf-8").strip():
parsed = self._parse_stream_line(line)
if parsed is None:
break
yield parsed
request = StreamInput(message=message, stream_tokens=stream_tokens)
if thread_id:
request.thread_id = thread_id
if model:
request.model = model
async with self.async_client.stream(
"POST", f"{self.base_url}/stream", json=request.dict(), headers=self._headers
) as response:
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} - {response.text}")
async for line in response.aiter_lines():
if line.strip():
parsed = self._parse_stream_line(line)
if parsed is None:
break
yield parsed

async def acreate_feedback(
self, run_id: str, key: str, score: float, kwargs: Dict[str, Any] = {}
Expand All @@ -200,11 +205,10 @@ async def acreate_feedback(
credentials can be stored and managed in the service rather than the client.
See: https://api.smith.langchain.com/redoc#tag/feedback/operation/create_feedback_api_v1_feedback_post
"""
async with aiohttp.ClientSession() as session:
request = Feedback(run_id=run_id, key=key, score=score, kwargs=kwargs)
async with session.post(
f"{self.base_url}/feedback", json=request.dict(), headers=self._headers
) as response:
if response.status != 200:
raise Exception(f"Error: {response.status} - {await response.text()}")
await response.json()
request = Feedback(run_id=run_id, key=key, score=score, kwargs=kwargs)
response = await self.async_client.post(
f"{self.base_url}/feedback", json=request.dict(), headers=self._headers
)
if response.status_code != 200:
raise Exception(f"Error: {response.status_code} - {response.text}")
response.json()
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ requires-python = ">=3.9, <= 3.12.3"
# https://github.com/langchain-ai/langchain/discussions/9337
# IMPORTANT: This also requires using python < 3.12.4
dependencies = [
"aiohttp ~=3.10.0",
"duckduckgo-search ~=6.2.6",
"fastapi <0.100.0",
"httpx~=0.26.0",
"langchain-core ~=0.2.26",
"langchain-community ~=0.2.11",
"langchain-openai ~=0.1.20",
Expand All @@ -42,7 +42,6 @@ dependencies = [

[project.optional-dependencies]
dev = [
"httpx~=0.26.0",
"pre-commit",
"pytest",
"pytest-env",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
# https://github.com/langchain-ai/langchain/discussions/9337
# IMPORTANT: This also requires using python < 3.12.4

aiohttp~=3.10.0
duckduckgo-search~=6.2.6
fastapi<0.100.0
httpx~=0.26.0
langchain-core~=0.2.26
langchain-community~=0.2.11
langchain-openai~=0.1.20
Expand Down

0 comments on commit fee1333

Please sign in to comment.