From 390a1b345ec4cd8d6baa1e2c3720a3f6e4ae2371 Mon Sep 17 00:00:00 2001
From: Chris Poirier
Date: Tue, 7 May 2024 00:55:40 -0400
Subject: [PATCH] openai: add toolcalls support when streaming (#763)

* add toolcalls support during stream

* only set the chunk if we have data

* add spaces to comments

* simplify code to meet linter requirements
---
 llms/openai/internal/openaiclient/chat.go | 106 +++++++++++++++-------
 1 file changed, 73 insertions(+), 33 deletions(-)

diff --git a/llms/openai/internal/openaiclient/chat.go b/llms/openai/internal/openaiclient/chat.go
index 20fbf6f0a..780bfccf1 100644
--- a/llms/openai/internal/openaiclient/chat.go
+++ b/llms/openai/internal/openaiclient/chat.go
@@ -278,7 +278,8 @@ type StreamedChatResponsePayload struct {
 			Role         string        `json:"role,omitempty"`
 			Content      string        `json:"content,omitempty"`
 			FunctionCall *FunctionCall `json:"function_call,omitempty"`
-			ToolCalls    []*ToolCall   `json:"tool_calls,omitempty"`
+			// ToolCalls is a list of tools that were called in the message.
+			ToolCalls []*ToolCall `json:"tool_calls,omitempty"`
 		} `json:"delta,omitempty"`
 		FinishReason FinishReason `json:"finish_reason,omitempty"`
 	} `json:"choices,omitempty"`
@@ -403,43 +404,21 @@ func combineStreamingChatResponse(ctx context.Context, payload *ChatRequest, res
 		},
 	}
 
-	var (
-		currentTool  ToolCall
-		currentIndex = -1
-	)
 	for streamResponse := range responseChan {
 		if len(streamResponse.Choices) == 0 {
 			continue
 		}
-		chunk := []byte(streamResponse.Choices[0].Delta.Content)
-		response.Choices[0].Message.Content += streamResponse.Choices[0].Delta.Content
-		response.Choices[0].FinishReason = streamResponse.Choices[0].FinishReason
-		if streamResponse.Choices[0].Delta.FunctionCall != nil {
-			if response.Choices[0].Message.FunctionCall == nil {
-				response.Choices[0].Message.FunctionCall = streamResponse.Choices[0].Delta.FunctionCall
-			} else {
-				response.Choices[0].Message.FunctionCall.Arguments += streamResponse.Choices[0].Delta.FunctionCall.Arguments
-			}
-			chunk, _ = json.Marshal(response.Choices[0].Message.FunctionCall) // nolint:errchkjson
+		choice := streamResponse.Choices[0]
+		chunk := []byte(choice.Delta.Content)
+		response.Choices[0].Message.Content += choice.Delta.Content
+		response.Choices[0].FinishReason = choice.FinishReason
+
+		if choice.Delta.FunctionCall != nil {
+			chunk = updateFunctionCall(response.Choices[0].Message, choice.Delta.FunctionCall)
 		}
-		if len(streamResponse.Choices[0].Delta.ToolCalls) > 0 {
-			for _, streamTool := range streamResponse.Choices[0].Delta.ToolCalls {
-				if streamTool.ID != "" {
-					currentTool = ToolCall{
-						ID:   streamTool.ID,
-						Type: streamTool.Type,
-						Function: ToolFunction{
-							Name:      streamTool.Function.Name,
-							Arguments: streamTool.Function.Arguments,
-						},
-					}
-					response.Choices[0].Message.ToolCalls = append(response.Choices[0].Message.ToolCalls, currentTool)
-					currentIndex++
-				} else {
-					response.Choices[0].Message.ToolCalls[currentIndex].Function.Arguments += streamTool.Function.Arguments
-				}
-			}
-			chunk, _ = json.Marshal(response.Choices[0].Message.ToolCalls) // nolint:errchkjson
+
+		if len(choice.Delta.ToolCalls) > 0 {
+			chunk, response.Choices[0].Message.ToolCalls = updateToolCalls(response.Choices[0].Message.ToolCalls, choice.Delta.ToolCalls)
 		}
 
 		if payload.StreamingFunc != nil {
@@ -451,3 +430,64 @@ func combineStreamingChatResponse(ctx context.Context, payload *ChatRequest, res
 	}
 	return &response, nil
 }
+
+func updateFunctionCall(message ChatMessage, functionCall *FunctionCall) []byte {
+	if message.FunctionCall == nil {
+		message.FunctionCall = functionCall
+	} else {
+		message.FunctionCall.Arguments += functionCall.Arguments
+	}
+	chunk, _ := json.Marshal(message.FunctionCall) // nolint:errchkjson
+	return chunk
+}
+
+func updateToolCalls(tools []ToolCall, delta []*ToolCall) ([]byte, []ToolCall) {
+	if len(delta) == 0 {
+		return []byte{}, tools
+	}
+	for _, t := range delta {
+		// if we have arguments append to the last Tool call
+		if t.Type == `` && t.Function.Arguments != `` {
+			lindex := len(tools) - 1
+			if lindex < 0 {
+				continue
+			}
+
+			tools[lindex].Function.Arguments += t.Function.Arguments
+			continue
+		}
+
+		// Otherwise, this is a new tool call, append that to the stack
+		tools = append(tools, *t)
+	}
+
+	chunk, _ := json.Marshal(delta) // nolint:errchkjson
+
+	return chunk, tools
+}
+
+// StreamingChatResponseTools is a helper function to append tool calls to the stack.
+func StreamingChatResponseTools(tools []ToolCall, delta []*ToolCall) ([]byte, []ToolCall) {
+	if len(delta) == 0 {
+		return []byte{}, tools
+	}
+	for _, t := range delta {
+		// if we have arguments append to the last Tool call
+		if t.Type == `` && t.Function.Arguments != `` {
+			lindex := len(tools) - 1
+			if lindex < 0 {
+				continue
+			}
+
+			tools[lindex].Function.Arguments += t.Function.Arguments
+			continue
+		}
+
+		// Otherwise, this is a new tool call, append that to the stack
+		tools = append(tools, *t)
+	}
+
+	chunk, _ := json.Marshal(delta) // nolint:errchkjson
+
+	return chunk, tools
+}
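
For a sense of how the accumulation rule in this patch behaves over a stream of deltas, below is a minimal, self-contained sketch. It mirrors the logic of the updateToolCalls helper above (a delta with an empty Type and only argument text extends the last open tool call; anything else starts a new one), but the ToolCall and ToolFunction types, the accumulate function, and the getWeather call are simplified local stand-ins for illustration only; the real types live in the internal openaiclient package and are not imported here.

// Standalone sketch of the tool-call accumulation rule from the patch.
// Types and values here are illustrative stand-ins, not the library's API.
package main

import (
	"encoding/json"
	"fmt"
)

// ToolFunction mirrors the function part of a streamed tool call.
type ToolFunction struct {
	Name      string `json:"name,omitempty"`
	Arguments string `json:"arguments,omitempty"`
}

// ToolCall mirrors a single (possibly partial) tool-call delta.
type ToolCall struct {
	ID       string       `json:"id,omitempty"`
	Type     string       `json:"type,omitempty"`
	Function ToolFunction `json:"function"`
}

// accumulate applies the same rule as updateToolCalls: a delta carrying only
// argument text extends the most recent tool call, anything else opens a new one.
func accumulate(tools []ToolCall, delta []*ToolCall) []ToolCall {
	for _, t := range delta {
		if t.Type == "" && t.Function.Arguments != "" {
			if len(tools) == 0 {
				continue // argument fragment with no open tool call; drop it
			}
			tools[len(tools)-1].Function.Arguments += t.Function.Arguments
			continue
		}
		tools = append(tools, *t)
	}
	return tools
}

func main() {
	// Two streamed chunks: the first opens a tool call, the second only
	// carries the rest of the JSON arguments. The values are made up.
	chunks := [][]*ToolCall{
		{{ID: "call_0", Type: "function", Function: ToolFunction{Name: "getWeather", Arguments: `{"city":`}}},
		{{Function: ToolFunction{Arguments: `"Paris"}`}}},
	}

	var tools []ToolCall
	for _, delta := range chunks {
		tools = accumulate(tools, delta)
	}

	out, _ := json.MarshalIndent(tools, "", "  ")
	fmt.Println(string(out))
	// Expected: a single tool call whose arguments concatenate to {"city":"Paris"}
}

Run as an ordinary main package, the two chunks merge into one getWeather call, which is the behaviour the new streaming path relies on when it marshals response.Choices[0].Message.ToolCalls into each streamed chunk.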