Skip to content

Commit

Permalink
openai: add toolcalls support when streaming (#763)
Browse files Browse the repository at this point in the history
* add toolcalls support during stream

* only set the chunk if we have data

* add spaces to comments

* simplify code to meet linter requirements
  • Loading branch information
ChrisCPoirier authored May 7, 2024
1 parent d161fc2 commit 390a1b3
Showing 1 changed file with 73 additions and 33 deletions.
106 changes: 73 additions & 33 deletions llms/openai/internal/openaiclient/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ type StreamedChatResponsePayload struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []*ToolCall `json:"tool_calls,omitempty"`
// ToolCalls is a list of tools that were called in the message.
ToolCalls []*ToolCall `json:"tool_calls,omitempty"`
} `json:"delta,omitempty"`
FinishReason FinishReason `json:"finish_reason,omitempty"`
} `json:"choices,omitempty"`
Expand Down Expand Up @@ -403,43 +404,21 @@ func combineStreamingChatResponse(ctx context.Context, payload *ChatRequest, res
},
}

var (
currentTool ToolCall
currentIndex = -1
)
for streamResponse := range responseChan {
if len(streamResponse.Choices) == 0 {
continue
}
chunk := []byte(streamResponse.Choices[0].Delta.Content)
response.Choices[0].Message.Content += streamResponse.Choices[0].Delta.Content
response.Choices[0].FinishReason = streamResponse.Choices[0].FinishReason
if streamResponse.Choices[0].Delta.FunctionCall != nil {
if response.Choices[0].Message.FunctionCall == nil {
response.Choices[0].Message.FunctionCall = streamResponse.Choices[0].Delta.FunctionCall
} else {
response.Choices[0].Message.FunctionCall.Arguments += streamResponse.Choices[0].Delta.FunctionCall.Arguments
}
chunk, _ = json.Marshal(response.Choices[0].Message.FunctionCall) // nolint:errchkjson
choice := streamResponse.Choices[0]
chunk := []byte(choice.Delta.Content)
response.Choices[0].Message.Content += choice.Delta.Content
response.Choices[0].FinishReason = choice.FinishReason

if choice.Delta.FunctionCall != nil {
chunk = updateFunctionCall(response.Choices[0].Message, choice.Delta.FunctionCall)
}
if len(streamResponse.Choices[0].Delta.ToolCalls) > 0 {
for _, streamTool := range streamResponse.Choices[0].Delta.ToolCalls {
if streamTool.ID != "" {
currentTool = ToolCall{
ID: streamTool.ID,
Type: streamTool.Type,
Function: ToolFunction{
Name: streamTool.Function.Name,
Arguments: streamTool.Function.Arguments,
},
}
response.Choices[0].Message.ToolCalls = append(response.Choices[0].Message.ToolCalls, currentTool)
currentIndex++
} else {
response.Choices[0].Message.ToolCalls[currentIndex].Function.Arguments += streamTool.Function.Arguments
}
}
chunk, _ = json.Marshal(response.Choices[0].Message.ToolCalls) // nolint:errchkjson

if len(choice.Delta.ToolCalls) > 0 {
chunk, response.Choices[0].Message.ToolCalls = updateToolCalls(response.Choices[0].Message.ToolCalls, choice.Delta.ToolCalls)
}

if payload.StreamingFunc != nil {
Expand All @@ -451,3 +430,64 @@ func combineStreamingChatResponse(ctx context.Context, payload *ChatRequest, res
}
return &response, nil
}

func updateFunctionCall(message ChatMessage, functionCall *FunctionCall) []byte {
if message.FunctionCall == nil {
message.FunctionCall = functionCall
} else {
message.FunctionCall.Arguments += functionCall.Arguments
}
chunk, _ := json.Marshal(message.FunctionCall) // nolint:errchkjson
return chunk
}

func updateToolCalls(tools []ToolCall, delta []*ToolCall) ([]byte, []ToolCall) {
if len(delta) == 0 {
return []byte{}, tools
}
for _, t := range delta {
// if we have arguments append to the last Tool call
if t.Type == `` && t.Function.Arguments != `` {
lindex := len(tools) - 1
if lindex < 0 {
continue
}

tools[lindex].Function.Arguments += t.Function.Arguments
continue
}

// Otherwise, this is a new tool call, append that to the stack
tools = append(tools, *t)
}

chunk, _ := json.Marshal(delta) // nolint:errchkjson

return chunk, tools
}

// StreamingChatResponseTools is a helper function to append tool calls to the stack.
func StreamingChatResponseTools(tools []ToolCall, delta []*ToolCall) ([]byte, []ToolCall) {
if len(delta) == 0 {
return []byte{}, tools
}
for _, t := range delta {
// if we have arguments append to the last Tool call
if t.Type == `` && t.Function.Arguments != `` {
lindex := len(tools) - 1
if lindex < 0 {
continue
}

tools[lindex].Function.Arguments += t.Function.Arguments
continue
}

// Otherwise, this is a new tool call, append that to the stack
tools = append(tools, *t)
}

chunk, _ := json.Marshal(delta) // nolint:errchkjson

return chunk, tools
}

0 comments on commit 390a1b3

Please sign in to comment.