Skip to content

Commit

Permalink
feat(chatbot): Add streaming support and enhanced tool call handling
Browse files Browse the repository at this point in the history
fix #41
  • Loading branch information
yufeikang committed Jul 5, 2024
1 parent 257c229 commit 0d00bda
Showing 1 changed file with 46 additions and 26 deletions.
72 changes: 46 additions & 26 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,13 @@ async def chat_completions(self, raycast_data: dict):
):
tools.append(self.__build_openai_function_img_tool(raycast_data))
async for i in self.__warp_chat(
openai_messages, model, temperature, tools=tools
openai_messages, model, temperature, tools=tools, stream=True
):
yield i

async def __warp_chat(self, messages, model, temperature, **kwargs):
functions = {}
current_function_id = None
async for choice, error in self.__chat(messages, model, temperature, **kwargs):
if error:
error_message = (
Expand All @@ -169,34 +171,50 @@ async def __warp_chat(self, messages, model, temperature, **kwargs):
if choice.delta and choice.delta.content:
yield f'data: {json_dumps({"text": choice.delta.content})}\n\n'
if choice.delta.tool_calls:
has_valid_tool = False
messages.append(choice.delta) # add the tool call to messages
logger.debug(f"Tool calls: {choice.delta}")
for tool_call in choice.delta.tool_calls:
tool_call_id = tool_call.id
tool_function_name = tool_call.function.name
logger.debug(f"Tool call: {tool_function_name}")
if tool_function_name == "generate_image":
if not tool_call.function.arguments:
continue
function_args = json.loads(tool_call.function.arguments)
yield f"data: {json_dumps({'text': 'Generating image...'})}\n\n"
fun_res = await self.__generate_image(**function_args)
# add to messages
has_valid_tool = True
messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"name": tool_function_name,
"content": fun_res,
logger.debug(f"Tool call: {tool_call}")
if tool_call.id and tool_call.type == "function":
current_function_id = tool_call.id
if current_function_id not in functions:
functions[current_function_id] = {
"delta": choice.delta,
"name": tool_call.function.name,
"args": "",
}
)
if has_valid_tool:
async for i in self.__warp_chat(messages, model, temperature):
yield i
# add arguments stream string to the current function
functions[current_function_id][
"args"
] += tool_call.function.arguments
continue
if choice.finish_reason is not None:
logger.debug(f"Finish reason: {choice.finish_reason}")
if choice.finish_reason == "tool_calls":
continue
yield f'data: {json_dumps({"text": "", "finish_reason": choice.finish_reason})}\n\n'
if functions:
logger.debug(f"Tool functions: {functions}")
for tool_call_id, tool in functions.items():
delta, name, args = tool["delta"], tool["name"], tool["args"]
logger.debug(f"Tool call: {name} with args: {args}")
args = json.loads(args)
messages.append(delta) # add the tool call to messages
tool_res = None
if name == "generate_image":
yield f'data: {json_dumps({"text": "Generating image..."})}\n\n'
tool_res = await self.__generate_image(**args)
else:
logger.error(f"Unknown tool function: {name}")
messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"name": tool["name"],
"content": tool_res,
}
)
async for i in self.__warp_chat(messages, model, temperature, **kwargs):
yield i

def __build_openai_function_img_tool(self, raycast_data: dict):
return {
Expand All @@ -218,6 +236,7 @@ def __build_openai_function_img_tool(self, raycast_data: dict):
}

async def __generate_image(self, prompt, model="dall-e-3"):
# return '{"url": "https://images.ctfassets.net/kftzwdyauwt9/1ZTOGp7opuUflFmI2CsATh/df5da4be74f62c70d35e2f5518bf2660/ChatGPT_Carousel1.png?w=828&q=90&fm=webp"}' # debug image
try:
res = await self.openai_client.images.generate(
model=model,
Expand Down Expand Up @@ -255,7 +274,8 @@ async def __chat(self, messages, model, temperature, **kwargs):
if "tools" in kwargs and not kwargs["tools"]:
# pop tools from kwargs, empty tools will cause error
kwargs.pop("tools")
stream = "tools" not in kwargs
# stream = "tools" not in kwargs
stream = "stream" in kwargs and kwargs["stream"]
try:
logger.debug(f"openai chat stream: {stream}")
res = await self.openai_client.chat.completions.create(
Expand All @@ -264,7 +284,7 @@ async def __chat(self, messages, model, temperature, **kwargs):
max_tokens=MAX_TOKENS,
n=1,
temperature=temperature,
stream=stream,
# stream=stream,
**kwargs,
)
except openai.OpenAIError as e:
Expand Down

0 comments on commit 0d00bda

Please sign in to comment.