diff --git a/go.mod b/go.mod index 96cfa6e9..5cea2de7 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,8 @@ module github.com/envoyproxy/ai-gateway go 1.23.2 require ( + github.com/aws/aws-sdk-go v1.55.5 + github.com/envoyproxy/go-control-plane v0.13.1 github.com/stretchr/testify v1.10.0 k8s.io/apimachinery v0.31.3 sigs.k8s.io/controller-runtime v0.19.3 @@ -12,8 +14,10 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/emicklei/go-restful/v3 v3.12.0 // indirect + github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect github.com/evanphx/json-patch/v5 v5.9.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-logr/logr v1.4.2 // indirect @@ -36,6 +40,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/client_golang v1.19.1 // indirect github.com/prometheus/client_model v0.6.1 // indirect @@ -53,6 +58,8 @@ require ( golang.org/x/text v0.20.0 // indirect golang.org/x/time v0.5.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect + google.golang.org/grpc v1.66.2 // indirect google.golang.org/protobuf v1.34.2 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index a1cd395d..50aa14d7 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,21 @@ +github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= +github.com/aws/aws-sdk-go v1.55.5/go.mod 
h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b h1:ga8SEFjZ60pxLcmhnThWgvH2wg8376yUJmPhEH4H3kw= +github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/emicklei/go-restful/v3 v3.12.0 h1:y2DdzBAURM29NFF94q6RaY4vjIH1rtwDapwQtU84iWk= github.com/emicklei/go-restful/v3 v3.12.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/envoyproxy/go-control-plane v0.13.1 h1:vPfJZCkob6yTMEgS+0TwfTUfbHjfy/6vOJ8hUWX/uXE= +github.com/envoyproxy/go-control-plane v0.13.1/go.mod h1:X45hY0mufo6Fd0KW3rqsGvQMw58jvjymeCzBU3mWyHw= +github.com/envoyproxy/protoc-gen-validate v1.1.0 h1:tntQDh69XqOCOZsDz0lVJQez/2L6Uu2PdjCQwWCJ3bM= +github.com/envoyproxy/protoc-gen-validate v1.1.0/go.mod h1:sXRDRVmzEbkM7CVcM06s9shE/m23dg3wzjl0UWqJ2q4= github.com/evanphx/json-patch v5.7.0+incompatible h1:vgGkfT/9f8zE6tvSCe74nfpAVDQ2tG6yudJd8LBksgI= github.com/evanphx/json-patch v5.7.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0/FOJfg= @@ -46,6 +54,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -71,6 +81,8 @@ github.com/onsi/gomega v1.33.1 h1:dsYjIxxSR755MDmKVsaFQTE22ChNBcuuTWgkUDSubOk= github.com/onsi/gomega v1.33.1/go.mod h1:U4R44UsT+9eLIaYRB2a5qajjtQYn0hauxvRm16AVYg0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -143,6 +155,10 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gomodules.xyz/jsonpatch/v2 v2.4.0 h1:Ci3iUJyx9UeRx7CeFN8ARgGbkESwJK+KB9lLcWxY/Zw= gomodules.xyz/jsonpatch/v2 v2.4.0/go.mod h1:AH3dM2RI6uoBZxn3LVrfvJ3E0/9dG4cSrbuBJT4moAY= 
+google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 h1:BwIjyKYGsK9dMCBOorzRri8MQwmi7mT9rGHsCEinZkA= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= +google.golang.org/grpc v1.66.2 h1:3QdXkuq3Bkh7w+ywLdLvM56cmGvQHUMZpiCzt6Rqaoo= +google.golang.org/grpc v1.66.2/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/apischema/awsbedrock/awsbedrock.go b/internal/apischema/awsbedrock/awsbedrock.go new file mode 100644 index 00000000..b868a03f --- /dev/null +++ b/internal/apischema/awsbedrock/awsbedrock.go @@ -0,0 +1,68 @@ +package awsbedrock + +import ( + "encoding/json" + "strings" +) + +// ConverseRequest is defined in the AWS Bedrock API: +// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_RequestBody +type ConverseRequest struct { + Messages []Message `json:"messages,omitempty"` +} + +// Message is defined in the AWS Bedrock API: +// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Message.html#bedrock-Type-runtime_Message-content +type Message struct { + Role string `json:"role,omitempty"` + Content []ContentBlock `json:"content,omitempty"` +} + +// ContentBlock is defined in the AWS Bedrock API: +// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ContentBlock.html +type ContentBlock struct { + Text string `json:"text,omitempty"` +} + +// ConverseResponse is defined in the AWS Bedrock API: +// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html#API_runtime_Converse_ResponseElements +type ConverseResponse struct { + Output 
ConverseResponseOutput `json:"output,omitempty"` + Usage TokenUsage `json:"usage,omitempty"` +} + +// ConverseResponseOutput is defined in the AWS Bedrock API: +// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseOutput.html +type ConverseResponseOutput struct { + Message Message `json:"message,omitempty"` +} + +// TokenUsage is defined in the AWS Bedrock API: +// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_TokenUsage.html +type TokenUsage struct { + InputTokens int `json:"inputTokens,omitempty"` + OutputTokens int `json:"outputTokens,omitempty"` + TotalTokens int `json:"totalTokens,omitempty"` +} + +// ConverseStreamEvent is the union of all possible event types in the AWS Bedrock API: +// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ConverseStream.html +type ConverseStreamEvent struct { + ContentBlockIndex int `json:"contentBlockIndex,omitempty"` + Delta *ConverseStreamEventContentBlockDelta `json:"delta,omitempty"` + Role *string `json:"role,omitempty"` + StopReason *string `json:"stopReason,omitempty"` + Usage *TokenUsage `json:"usage,omitempty"` +} + +// String implements fmt.Stringer. +func (c ConverseStreamEvent) String() string { + buf, _ := json.Marshal(c) + return strings.ReplaceAll(string(buf), ",", ", ") +} + +// ConverseStreamEventContentBlockDelta is defined in the AWS Bedrock API: +// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ContentBlockDelta.html +type ConverseStreamEventContentBlockDelta struct { + Text string `json:"text,omitempty"` +} diff --git a/internal/apischema/openai/openai.go b/internal/apischema/openai/openai.go new file mode 100644 index 00000000..2db7ddd2 --- /dev/null +++ b/internal/apischema/openai/openai.go @@ -0,0 +1,108 @@ +// Package openai contains the followings are the OpenAI API schema definitions. 
+// Note that we intentionally do not use the code generation tools like OpenAPI Generator not only to keep the code simple +// but also because the OpenAI's OpenAPI definition is not compliant with the spec and the existing tools do not work well. +package openai + +import ( + "encoding/json" + "strings" +) + +// ChatCompletionRequest represents a request to /v1/chat/completions. +// https://platform.openai.com/docs/api-reference/chat/create +type ChatCompletionRequest struct { + // Model is described in the OpenAI API documentation: + // https://platform.openai.com/docs/api-reference/chat/create#chat-create-model + Model string `json:"model"` + + // Messages is described in the OpenAI API documentation: + // https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages + Messages []ChatCompletionRequestMessage `json:"messages"` + + // Stream is described in the OpenAI API documentation: + // https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream + Stream bool `json:"stream,omitempty"` +} + +// ChatCompletionRequestMessage represents a message in a ChatCompletionRequest. +// https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages +type ChatCompletionRequestMessage struct { + // Role is the role of the message. The role of the message (whether it represents the user or the AI). + Role string `json:"role,omitempty"` + // Content is the content of the message. Mainly this is a string, but it can be more complex. + Content any `json:"content,omitempty"` +} + +// ChatCompletionResponse represents a response from /v1/chat/completions. +// https://platform.openai.com/docs/api-reference/chat/object +type ChatCompletionResponse struct { + // Choices are described in the OpenAI API documentation: + // https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices + Choices []ChatCompletionResponseChoice `json:"choices,omitempty"` + + // Object is always "chat.completion" for completions. 
+ // https://platform.openai.com/docs/api-reference/chat/object#chat/object-object + Object string `json:"object,omitempty"` + + // Usage is described in the OpenAI API documentation: + // https://platform.openai.com/docs/api-reference/chat/object#chat/object-usage + Usage ChatCompletionResponseUsage `json:"usage,omitempty"` +} + +// ChatCompletionResponseChoice is described in the OpenAI API documentation: +// https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices +type ChatCompletionResponseChoice struct { + // Message is described in the OpenAI API documentation: + // https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices + Message ChatCompletionResponseChoiceMessage `json:"message,omitempty"` +} + +// ChatCompletionResponseChoiceMessage is described in the OpenAI API documentation: +// https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices +type ChatCompletionResponseChoiceMessage struct { + Content *string `json:"content,omitempty"` + Role string `json:"role,omitempty"` +} + +// ChatCompletionResponseUsage is described in the OpenAI API documentation: +// https://platform.openai.com/docs/api-reference/chat/object#chat/object-usage +type ChatCompletionResponseUsage struct { + CompletionTokens int `json:"completion_tokens,omitempty"` + PromptTokens int `json:"prompt_tokens,omitempty"` + TotalTokens int `json:"total_tokens,omitempty"` +} + +// ChatCompletionResponseChunk is described in the OpenAI API documentation: +// https://platform.openai.com/docs/api-reference/chat/streaming#chat-create-messages +type ChatCompletionResponseChunk struct { + // Choices are described in the OpenAI API documentation: + // https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-choices + Choices []ChatCompletionResponseChunkChoice `json:"choices,omitempty"` + + // Object is always "chat.completion.chunk" for completions. 
+ // https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-object + Object string `json:"object,omitempty"` + + // Usage is described in the OpenAI API documentation: + // https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-usage + Usage *ChatCompletionResponseUsage `json:"usage,omitempty"` +} + +// String implements fmt.Stringer. +func (c *ChatCompletionResponseChunk) String() string { + buf, _ := json.Marshal(c) + return strings.ReplaceAll(string(buf), ",", ", ") +} + +// ChatCompletionResponseChunkChoice is described in the OpenAI API documentation: +// https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-choices +type ChatCompletionResponseChunkChoice struct { + Delta *ChatCompletionResponseChunkChoiceDelta `json:"delta,omitempty"` +} + +// ChatCompletionResponseChunkChoiceDelta is described in the OpenAI API documentation: +// https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-choices +type ChatCompletionResponseChunkChoiceDelta struct { + Content *string `json:"content,omitempty"` + Role *string `json:"role,omitempty"` +} diff --git a/internal/extproc/translator/openai_awsbedrock.go b/internal/extproc/translator/openai_awsbedrock.go new file mode 100644 index 00000000..fd2558db --- /dev/null +++ b/internal/extproc/translator/openai_awsbedrock.go @@ -0,0 +1,251 @@ +package translator + +import ( + "bytes" + "encoding/json" + "fmt" + "log/slog" + + "github.com/aws/aws-sdk-go/private/protocol/eventstream" + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extprocv3http "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3" + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + + "github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock" + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" +) + +// newOpenAIToAWSBedrockTranslator implements [TranslatorFactory] for OpenAI to 
AWS Bedrock translation. +func newOpenAIToAWSBedrockTranslator(path string, l *slog.Logger) (Translator, error) { + if path == "/v1/chat/completions" { + return &openAIToAWSBedrockTranslatorV1ChatCompletion{l: l}, nil + } else { + return nil, fmt.Errorf("unsupported path: %s", path) + } +} + +// openAIToAWSBedrockTranslator implements [Translator] for /v1/chat/completions. +type openAIToAWSBedrockTranslatorV1ChatCompletion struct { + defaultTranslator + l *slog.Logger + stream bool + bufferedBody []byte + events []awsbedrock.ConverseStreamEvent +} + +// RequestBody implements [Translator.RequestBody]. +func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) RequestBody(body *extprocv3.HttpBody) ( + headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, override *extprocv3http.ProcessingMode, modelName string, err error, +) { + var openAIReq openai.ChatCompletionRequest + if err := json.Unmarshal(body.Body, &openAIReq); err != nil { + return nil, nil, nil, "", fmt.Errorf("failed to unmarshal body: %w", err) + } + + var pathTemplate string + if openAIReq.Stream { + o.stream = true + // We need to change the processing mode for streaming requests. 
+ override = &extprocv3http.ProcessingMode{ + ResponseHeaderMode: extprocv3http.ProcessingMode_SEND, + ResponseBodyMode: extprocv3http.ProcessingMode_STREAMED, + } + pathTemplate = "/model/%s/converse-stream" + } else { + pathTemplate = "/model/%s/converse" + } + + headerMutation = &extprocv3.HeaderMutation{ + SetHeaders: []*corev3.HeaderValueOption{ + {Header: &corev3.HeaderValue{ + Key: ":path", + RawValue: []byte(fmt.Sprintf(pathTemplate, openAIReq.Model)), + }}, + }, + } + + var awsReq awsbedrock.ConverseRequest + awsReq.Messages = make([]awsbedrock.Message, 0, len(openAIReq.Messages)) + for _, msg := range openAIReq.Messages { + var role string + switch msg.Role { + case "user", "assistant": + role = msg.Role + case "system": + role = "assistant" + default: + return nil, nil, nil, "", fmt.Errorf("unexpected role: %s", msg.Role) + } + + text, ok := msg.Content.(string) + if ok { + awsReq.Messages = append(awsReq.Messages, awsbedrock.Message{ + Role: role, + Content: []awsbedrock.ContentBlock{{Text: text}}, + }) + } else { + return nil, nil, nil, "", fmt.Errorf("unexpected content: %v", msg.Content) + } + } + + mut := &extprocv3.BodyMutation_Body{} + if body, err := json.Marshal(awsReq); err != nil { + return nil, nil, nil, "", fmt.Errorf("failed to marshal body: %w", err) + } else { + mut.Body = body + } + setContentLength(headerMutation, mut.Body) + return headerMutation, &extprocv3.BodyMutation{Mutation: mut}, override, openAIReq.Model, nil +} + +// ResponseHeaders implements [Translator.ResponseHeaders]. +func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseHeaders(headers map[string]string) ( + headerMutation *extprocv3.HeaderMutation, err error, +) { + if o.stream { + contentType := headers["content-type"] + if contentType != "application/vnd.amazon.eventstream" { + return nil, fmt.Errorf("unexpected content-type for streaming: %s", contentType) + } + + // We need to change the content-type to text/event-stream for streaming responses. 
+ return &extprocv3.HeaderMutation{ + SetHeaders: []*corev3.HeaderValueOption{ + {Header: &corev3.HeaderValue{Key: "content-type", Value: "text/event-stream"}}, + }, + }, nil + } + return nil, nil +} + +// ResponseBody implements [Translator.ResponseBody]. +func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(body *extprocv3.HttpBody) ( + headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, usedToken uint32, err error, +) { + mut := &extprocv3.BodyMutation_Body{} + if o.stream { + o.bufferedBody = append(o.bufferedBody, body.Body...) + o.extractAmazonEventStreamEvents() + + for i := range o.events { + event := &o.events[i] + o.l.Debug("processing event", slog.Any("event", event)) + if usage := event.Usage; usage != nil { + usedToken = uint32(usage.TotalTokens) + } + + oaiEvent, ok := o.convertEvent(event) + if !ok { + continue + } + oaiEventBytes, err := json.Marshal(oaiEvent) + if err != nil { + panic(fmt.Errorf("failed to marshal event: %w", err)) + } + mut.Body = append(mut.Body, []byte("data: ")...) + mut.Body = append(mut.Body, oaiEventBytes...) + mut.Body = append(mut.Body, []byte("\n\n")...) + } + + if body.EndOfStream { + mut.Body = append(mut.Body, []byte("data: [DONE]\n")...) 
+ } + return headerMutation, &extprocv3.BodyMutation{Mutation: mut}, usedToken, nil + } + + var awsResp awsbedrock.ConverseResponse + if err := json.Unmarshal(body.Body, &awsResp); err != nil { + return nil, nil, 0, fmt.Errorf("failed to unmarshal body: %w", err) + } + + usedToken = uint32(awsResp.Usage.TotalTokens) + + openAIResp := openai.ChatCompletionResponse{ + Usage: openai.ChatCompletionResponseUsage{ + TotalTokens: awsResp.Usage.TotalTokens, + PromptTokens: awsResp.Usage.InputTokens, + CompletionTokens: awsResp.Usage.OutputTokens, + }, + Object: "chat.completion", + Choices: make([]openai.ChatCompletionResponseChoice, 0, len(awsResp.Output.Message.Content)), + } + + for _, output := range awsResp.Output.Message.Content { + t := output.Text + openAIResp.Choices = append(openAIResp.Choices, openai.ChatCompletionResponseChoice{Message: openai.ChatCompletionResponseChoiceMessage{ + Content: &t, + Role: awsResp.Output.Message.Role, + }}) + } + + if body, err := json.Marshal(openAIResp); err != nil { + return nil, nil, 0, fmt.Errorf("failed to marshal body: %w", err) + } else { + mut.Body = body + } + headerMutation = &extprocv3.HeaderMutation{} + setContentLength(headerMutation, mut.Body) + return headerMutation, &extprocv3.BodyMutation{Mutation: mut}, usedToken, nil +} + +// extractAmazonEventStreamEvents extracts [awsbedrock.ConverseStreamEvent] from the buffered body. +// The extracted events are stored in the processor's events field. +func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) extractAmazonEventStreamEvents() { + // TODO: Maybe reuse the reader and decoder. + r := bytes.NewReader(o.bufferedBody) + dec := eventstream.NewDecoder(r) + o.events = o.events[:0] + var lastRead int64 + for { + msg, err := dec.Decode(nil) + if err != nil { + // When failed, we stop processing the events. + // Copy the unread bytes to the beginning of the buffer. 
+ copy(o.bufferedBody, o.bufferedBody[lastRead:]) + o.bufferedBody = o.bufferedBody[:len(o.bufferedBody)-int(lastRead)] + return + } + var event awsbedrock.ConverseStreamEvent + if err := json.Unmarshal(msg.Payload, &event); err != nil { + // When failed to parse the event, we skip it while logging the error. + o.l.Error("failed to parse event: %v", slog.Any("event", msg)) + } else { + o.events = append(o.events, event) + } + lastRead = r.Size() - int64(r.Len()) + } +} + +var emptyString = "" + +// convertEvent converts an [awsbedrock.ConverseStreamEvent] to an [openai.ChatCompletionResponseChunk]. +// This is a static method and does not require a receiver, but defined as a method for namespacing. +func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) convertEvent(event *awsbedrock.ConverseStreamEvent) (openai.ChatCompletionResponseChunk, bool) { + const object = "chat.completion.chunk" + chunk := openai.ChatCompletionResponseChunk{Object: object} + + switch { + case event.Usage != nil: + chunk.Usage = &openai.ChatCompletionResponseUsage{ + TotalTokens: event.Usage.TotalTokens, + PromptTokens: event.Usage.InputTokens, + CompletionTokens: event.Usage.OutputTokens, + } + case event.Role != nil: + chunk.Choices = append(chunk.Choices, openai.ChatCompletionResponseChunkChoice{ + Delta: &openai.ChatCompletionResponseChunkChoiceDelta{ + Role: event.Role, + Content: &emptyString, + }, + }) + case event.Delta != nil: + chunk.Choices = append(chunk.Choices, openai.ChatCompletionResponseChunkChoice{ + Delta: &openai.ChatCompletionResponseChunkChoiceDelta{ + Content: &event.Delta.Text, + }, + }) + default: + return chunk, false + } + return chunk, true +} diff --git a/internal/extproc/translator/openai_awsbedrock_test.go b/internal/extproc/translator/openai_awsbedrock_test.go new file mode 100644 index 00000000..925a82e3 --- /dev/null +++ b/internal/extproc/translator/openai_awsbedrock_test.go @@ -0,0 +1,374 @@ +package translator + +import ( + "bytes" + "encoding/base64" + 
"encoding/json" + "fmt" + "log/slog" + "strconv" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/private/protocol/eventstream" + extprocv3http "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3" + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + "github.com/stretchr/testify/require" + + "github.com/envoyproxy/ai-gateway/internal/apischema/awsbedrock" + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" +) + +func TestNewOpenAIToAWSBedrockTranslator(t *testing.T) { + t.Run("unsupported path", func(t *testing.T) { + _, err := newOpenAIToAWSBedrockTranslator("unsupported-path", slog.Default()) + require.Error(t, err) + }) + t.Run("v1/chat/completions", func(t *testing.T) { + translator, err := newOpenAIToAWSBedrockTranslator("/v1/chat/completions", slog.Default()) + require.NoError(t, err) + require.NotNil(t, translator) + }) +} + +func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_RequestBody(t *testing.T) { + t.Run("invalid body", func(t *testing.T) { + o := &openAIToAWSBedrockTranslatorV1ChatCompletion{} + _, _, _, _, err := o.RequestBody(&extprocv3.HttpBody{Body: []byte("invalid")}) + require.Error(t, err) + }) + t.Run("valid body", func(t *testing.T) { + for _, stream := range []bool{true, false} { + t.Run(fmt.Sprintf("stream=%t", stream), func(t *testing.T) { + o := &openAIToAWSBedrockTranslatorV1ChatCompletion{} + originalReq := openai.ChatCompletionRequest{ + Stream: stream, + Model: "gpt-4o", + Messages: []openai.ChatCompletionRequestMessage{ + {Content: "from-system", Role: "system"}, + {Content: "from-user", Role: "user"}, + {Content: "part1", Role: "user"}, + {Content: "part2", Role: "user"}, + }, + } + + body, err := json.Marshal(originalReq) + require.NoError(t, err) + + hm, bm, mode, modelName, err := o.RequestBody(&extprocv3.HttpBody{Body: body}) + var expPath string + if stream { + expPath = "/model/gpt-4o/converse-stream" + require.True(t, o.stream) + require.NotNil(t, 
mode) + require.Equal(t, extprocv3http.ProcessingMode_STREAMED, mode.ResponseBodyMode) + require.Equal(t, extprocv3http.ProcessingMode_SEND, mode.ResponseHeaderMode) + } else { + expPath = "/model/gpt-4o/converse" + require.False(t, o.stream) + require.Nil(t, mode) + } + require.NoError(t, err) + require.Equal(t, "gpt-4o", modelName) + require.NotNil(t, hm) + require.NotNil(t, hm.SetHeaders) + require.Len(t, hm.SetHeaders, 2) + require.Equal(t, ":path", hm.SetHeaders[0].Header.Key) + require.Equal(t, expPath, string(hm.SetHeaders[0].Header.RawValue)) + require.Equal(t, "content-length", hm.SetHeaders[1].Header.Key) + newBody := bm.Mutation.(*extprocv3.BodyMutation_Body).Body + require.Equal(t, strconv.Itoa(len(newBody)), string(hm.SetHeaders[1].Header.RawValue)) + + var awsReq awsbedrock.ConverseRequest + err = json.Unmarshal(newBody, &awsReq) + require.NoError(t, err) + require.NotNil(t, awsReq.Messages) + require.Len(t, awsReq.Messages, 4) + for _, msg := range awsReq.Messages { + t.Log(msg) + } + require.Equal(t, "assistant", awsReq.Messages[0].Role) + require.Equal(t, "from-system", awsReq.Messages[0].Content[0].Text) + require.Equal(t, "user", awsReq.Messages[1].Role) + require.Equal(t, "from-user", awsReq.Messages[1].Content[0].Text) + require.Equal(t, "user", awsReq.Messages[2].Role) + require.Equal(t, "part1", awsReq.Messages[2].Content[0].Text) + require.Equal(t, "user", awsReq.Messages[3].Role) + require.Equal(t, "part2", awsReq.Messages[3].Content[0].Text) + }) + } + }) +} + +func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseHeaders(t *testing.T) { + t.Run("streaming", func(t *testing.T) { + o := &openAIToAWSBedrockTranslatorV1ChatCompletion{stream: true} + hm, err := o.ResponseHeaders(map[string]string{ + "content-type": "application/vnd.amazon.eventstream", + }) + require.NoError(t, err) + require.NotNil(t, hm) + require.NotNil(t, hm.SetHeaders) + require.Len(t, hm.SetHeaders, 1) + require.Equal(t, "content-type", 
hm.SetHeaders[0].Header.Key) + require.Equal(t, "text/event-stream", hm.SetHeaders[0].Header.Value) + }) + t.Run("non-streaming", func(t *testing.T) { + o := &openAIToAWSBedrockTranslatorV1ChatCompletion{} + hm, err := o.ResponseHeaders(nil) + require.NoError(t, err) + require.Nil(t, hm) + }) +} + +func TestOpenAIToAWSBedrockTranslatorV1ChatCompletion_ResponseBody(t *testing.T) { + t.Run("streaming", func(t *testing.T) { + o := &openAIToAWSBedrockTranslatorV1ChatCompletion{stream: true, l: slog.Default()} + buf, err := base64.StdEncoding.DecodeString(base64RealStreamingEvents) + require.NoError(t, err) + + var results []string + for i := 0; i < len(buf); i++ { + hm, bm, usedToken, err := o.ResponseBody(&extprocv3.HttpBody{Body: []byte{buf[i]}, EndOfStream: i == len(buf)-1}) + require.NoError(t, err) + require.Nil(t, hm) + require.NotNil(t, bm) + require.NotNil(t, bm.Mutation) + newBody := bm.Mutation.(*extprocv3.BodyMutation_Body).Body + if len(newBody) > 0 { + results = append(results, string(newBody)) + } + if usedToken > 0 { + require.Equal(t, uint32(77), usedToken) + } + } + + result := strings.Join(results, "") + + require.Equal(t, `data: {"choices":[{"delta":{"content":"","role":"assistant"}}],"object":"chat.completion.chunk"} + +data: {"choices":[{"delta":{"content":"Don"}}],"object":"chat.completion.chunk"} + +data: {"choices":[{"delta":{"content":"'t worry, I'm here to help. It"}}],"object":"chat.completion.chunk"} + +data: {"choices":[{"delta":{"content":" seems like you're testing my ability to respond appropriately"}}],"object":"chat.completion.chunk"} + +data: {"choices":[{"delta":{"content":". 
If you'd like to continue the test,"}}],"object":"chat.completion.chunk"} + +data: {"choices":[{"delta":{"content":" I'm ready."}}],"object":"chat.completion.chunk"} + +data: {"object":"chat.completion.chunk","usage":{"completion_tokens":36,"prompt_tokens":41,"total_tokens":77}} + +data: [DONE] +`, result) + }) + t.Run("non-streaming", func(t *testing.T) { + t.Run("invalid body", func(t *testing.T) { + o := &openAIToAWSBedrockTranslatorV1ChatCompletion{} + _, _, _, err := o.ResponseBody(&extprocv3.HttpBody{Body: []byte("invalid")}) + require.Error(t, err) + }) + t.Run("valid body", func(t *testing.T) { + originalAWSResp := awsbedrock.ConverseResponse{ + Usage: awsbedrock.TokenUsage{ + InputTokens: 10, + OutputTokens: 20, + TotalTokens: 30, + }, + Output: awsbedrock.ConverseResponseOutput{ + Message: awsbedrock.Message{ + Role: "assistant", + Content: []awsbedrock.ContentBlock{ + {Text: "response"}, + {Text: "from"}, + {Text: "assistant"}, + }, + }, + }, + } + body, err := json.Marshal(originalAWSResp) + require.NoError(t, err) + + o := &openAIToAWSBedrockTranslatorV1ChatCompletion{} + hm, bm, usedToken, err := o.ResponseBody(&extprocv3.HttpBody{Body: body}) + require.NoError(t, err) + require.NotNil(t, bm) + require.NotNil(t, bm.Mutation) + require.NotNil(t, bm.Mutation.(*extprocv3.BodyMutation_Body)) + newBody := bm.Mutation.(*extprocv3.BodyMutation_Body).Body + require.NotNil(t, newBody) + require.NotNil(t, hm) + require.NotNil(t, hm.SetHeaders) + require.Len(t, hm.SetHeaders, 1) + require.Equal(t, "content-length", hm.SetHeaders[0].Header.Key) + require.Equal(t, strconv.Itoa(len(newBody)), string(hm.SetHeaders[0].Header.RawValue)) + + var openAIResp openai.ChatCompletionResponse + err = json.Unmarshal(newBody, &openAIResp) + require.NoError(t, err) + require.NotNil(t, openAIResp.Usage) + require.Equal(t, uint32(30), usedToken) + require.Equal(t, 30, openAIResp.Usage.TotalTokens) + require.Equal(t, 10, openAIResp.Usage.PromptTokens) + require.Equal(t, 20, 
openAIResp.Usage.CompletionTokens) + + require.NotNil(t, openAIResp.Choices) + require.Len(t, openAIResp.Choices, 3) + + require.Equal(t, "response", *openAIResp.Choices[0].Message.Content) + require.Equal(t, "from", *openAIResp.Choices[1].Message.Content) + require.Equal(t, "assistant", *openAIResp.Choices[2].Message.Content) + }) + }) +} + +const base64RealStreamingEvents = "AAAAnwAAAFKzEV9wCzpldmVudC10eXBlBwAMbWVzc2FnZVN0YXJ0DTpjb250ZW50LXR5cGUHABBhcHBsaWNhdGlvbi9qc29uDTptZXNzYWdlLXR5cGUHAAVldmVudHsicCI6ImFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6QUJDREVGR0giLCJyb2xlIjoiYXNzaXN0YW50In0i9wVBAAAAxQAAAFex2HyVCzpldmVudC10eXBlBwARY29udGVudEJsb2NrRGVsdGENOmNvbnRlbnQtdHlwZQcAEGFwcGxpY2F0aW9uL2pzb24NOm1lc3NhZ2UtdHlwZQcABWV2ZW50eyJjb250ZW50QmxvY2tJbmRleCI6MCwiZGVsdGEiOnsidGV4dCI6IkRvbiJ9LCJwIjoiYWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNERUZHSElKS0xNTk8ifb/whawAAADAAAAAV3k48+ULOmV2ZW50LXR5cGUHABFjb250ZW50QmxvY2tEZWx0YQ06Y29udGVudC10eXBlBwAQYXBwbGljYXRpb24vanNvbg06bWVzc2FnZS10eXBlBwAFZXZlbnR7ImNvbnRlbnRCbG9ja0luZGV4IjowLCJkZWx0YSI6eyJ0ZXh0IjoiJ3Qgd29ycnksIEknbSBoZXJlIHRvIGhlbHAuIEl0In0sInAiOiJhYmNkZWZnaGkifenahv0AAADgAAAAV7j53OELOmV2ZW50LXR5cGUHABFjb250ZW50QmxvY2tEZWx0YQ06Y29udGVudC10eXBlBwAQYXBwbGljYXRpb24vanNvbg06bWVzc2FnZS10eXBlBwAFZXZlbnR7ImNvbnRlbnRCbG9ja0luZGV4IjowLCJkZWx0YSI6eyJ0ZXh0IjoiIHNlZW1zIGxpa2UgeW91J3JlIHRlc3RpbmcgbXkgYWJpbGl0eSB0byByZXNwb25kIGFwcHJvcHJpYXRlbHkifSwicCI6ImFiY2RlZmdoaSJ9dNZCqAAAAM8AAABX+2hkNAs6ZXZlbnQtdHlwZQcAEWNvbnRlbnRCbG9ja0RlbHRhDTpjb250ZW50LXR5cGUHABBhcHBsaWNhdGlvbi9qc29uDTptZXNzYWdlLXR5cGUHAAVldmVudHsiY29udGVudEJsb2NrSW5kZXgiOjAsImRlbHRhIjp7InRleHQiOiIuIElmIHlvdSdkIGxpa2UgdG8gY29udGludWUgdGhlIHRlc3QsIn0sInAiOiJhYmNkZWZnaGlqa2xtbm9wcSJ9xQJqAgAAALUAAABXSAqcWgs6ZXZlbnQtdHlwZQcAEWNvbnRlbnRCbG9ja0RlbHRhDTpjb250ZW50LXR5cGUHABBhcHBsaWNhdGlvbi9qc29uDTptZXNzYWdlLXR5cGUHAAVldmVudHsiY29udGVudEJsb2NrSW5kZXgiOjAsImRlbHRhIjp7InRleHQiOiIgSSdtIHJlYWR5LiJ9LCJwIjoiYWJjZGVmZ2hpamtsbW5vcHEifTOb7esAAAC5AAAAVvr9Qc0LOmV2ZW50LXR5cGUHABBjb250ZW50QmxvY2tTdG9wDTpjb250ZW50LXR5
cGUHABBhcHBsaWNhdGlvbi9qc29uDTptZXNzYWdlLXR5cGUHAAVldmVudHsiY29udGVudEJsb2NrSW5kZXgiOjAsInAiOiJhYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ekFCQ0RFRkdISUpLTE1OT1BRUlNUVVZXWFlaMCJ9iABE1AAAAI0AAABRMDjKKAs6ZXZlbnQtdHlwZQcAC21lc3NhZ2VTdG9wDTpjb250ZW50LXR5cGUHABBhcHBsaWNhdGlvbi9qc29uDTptZXNzYWdlLXR5cGUHAAVldmVudHsicCI6ImFiY2RlZmdoaWprbCIsInN0b3BSZWFzb24iOiJlbmRfdHVybiJ9LttU3QAAAPoAAABO9sL7Ags6ZXZlbnQtdHlwZQcACG1ldGFkYXRhDTpjb250ZW50LXR5cGUHABBhcHBsaWNhdGlvbi9qc29uDTptZXNzYWdlLXR5cGUHAAVldmVudHsibWV0cmljcyI6eyJsYXRlbmN5TXMiOjQ1Mn0sInAiOiJhYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ekFCQ0RFRkdISUpLTE1OT1BRUlNUVVZXWFlaMDEyMzQ1IiwidXNhZ2UiOnsiaW5wdXRUb2tlbnMiOjQxLCJvdXRwdXRUb2tlbnMiOjM2LCJ0b3RhbFRva2VucyI6Nzd9fX96gYI=" + +func TestOpenAIToAWSBedrockTranslatorExtractAmazonEventStreamEvents(t *testing.T) { + buf := bytes.NewBuffer(nil) + e := eventstream.NewEncoder(buf) + var offsets []int + for _, data := range []awsbedrock.ConverseStreamEvent{ + {Delta: &awsbedrock.ConverseStreamEventContentBlockDelta{Text: "1"}}, + {Delta: &awsbedrock.ConverseStreamEventContentBlockDelta{Text: "2"}}, + {Delta: &awsbedrock.ConverseStreamEventContentBlockDelta{Text: "3"}}, + } { + offsets = append(offsets, buf.Len()) + eventPayload, err := json.Marshal(data) + require.NoError(t, err) + err = e.Encode(eventstream.Message{ + Headers: eventstream.Headers{{Name: "event-type", Value: eventstream.StringValue("content")}}, + Payload: eventPayload, + }) + require.NoError(t, err) + } + + eventBytes := buf.Bytes() + + t.Run("all-at-once", func(t *testing.T) { + o := &openAIToAWSBedrockTranslatorV1ChatCompletion{} + o.bufferedBody = eventBytes + o.extractAmazonEventStreamEvents() + require.Len(t, o.events, 3) + require.Empty(t, o.bufferedBody) + for i, text := range []string{"1", "2", "3"} { + require.Equal(t, text, o.events[i].Delta.Text) + } + }) + + t.Run("in-chunks", func(t *testing.T) { + o := &openAIToAWSBedrockTranslatorV1ChatCompletion{} + o.bufferedBody = eventBytes[0:1] + o.extractAmazonEventStreamEvents() + 
require.Empty(t, o.events) + require.Len(t, o.bufferedBody, 1) + + o.bufferedBody = eventBytes[0 : offsets[1]+5] + o.extractAmazonEventStreamEvents() + require.Len(t, o.events, 1) + require.Equal(t, eventBytes[offsets[1]:offsets[1]+5], o.bufferedBody) + + o.events = o.events[:0] + o.bufferedBody = eventBytes[0 : offsets[2]+5] + o.extractAmazonEventStreamEvents() + require.Len(t, o.events, 2) + require.Equal(t, eventBytes[offsets[2]:offsets[2]+5], o.bufferedBody) + }) + + t.Run("real events", func(t *testing.T) { + o := &openAIToAWSBedrockTranslatorV1ChatCompletion{l: slog.Default()} + var err error + o.bufferedBody, err = base64.StdEncoding.DecodeString(base64RealStreamingEvents) + require.NoError(t, err) + o.extractAmazonEventStreamEvents() + + var texts []string + var usage *awsbedrock.TokenUsage + for _, event := range o.events { + t.Log(event.String()) + if delta := event.Delta; delta != nil && delta.Text != "" { + texts = append(texts, event.Delta.Text) + } + if u := event.Usage; u != nil { + usage = u + } + } + require.Equal(t, + "Don't worry, I'm here to help. It seems like you're testing my ability to respond appropriately. 
If you'd like to continue the test, I'm ready.", + strings.Join(texts, ""), + ) + require.NotNil(t, usage) + require.Equal(t, 77, usage.TotalTokens) + }) +} + +func TestOpenAIToAWSBedrockTranslator_convertEvent(t *testing.T) { + ptrOf := func(s string) *string { return &s } + for _, tc := range []struct { + name string + in awsbedrock.ConverseStreamEvent + out *openai.ChatCompletionResponseChunk + }{ + { + name: "usage", + in: awsbedrock.ConverseStreamEvent{ + Usage: &awsbedrock.TokenUsage{ + InputTokens: 10, + OutputTokens: 20, + TotalTokens: 30, + }, + }, + out: &openai.ChatCompletionResponseChunk{ + Object: "chat.completion.chunk", + Usage: &openai.ChatCompletionResponseUsage{ + TotalTokens: 30, + PromptTokens: 10, + CompletionTokens: 20, + }, + }, + }, + { + name: "role", + in: awsbedrock.ConverseStreamEvent{ + Role: ptrOf("assistant"), + }, + out: &openai.ChatCompletionResponseChunk{ + Object: "chat.completion.chunk", + Choices: []openai.ChatCompletionResponseChunkChoice{ + { + Delta: &openai.ChatCompletionResponseChunkChoiceDelta{ + Role: ptrOf("assistant"), + Content: &emptyString, + }, + }, + }, + }, + }, + { + name: "delta", + in: awsbedrock.ConverseStreamEvent{ + Delta: &awsbedrock.ConverseStreamEventContentBlockDelta{Text: "response"}, + }, + out: &openai.ChatCompletionResponseChunk{ + Object: "chat.completion.chunk", + Choices: []openai.ChatCompletionResponseChunkChoice{ + { + Delta: &openai.ChatCompletionResponseChunkChoiceDelta{ + Content: ptrOf("response"), + }, + }, + }, + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + o := &openAIToAWSBedrockTranslatorV1ChatCompletion{} + chunk, ok := o.convertEvent(&tc.in) + if tc.out == nil { + require.False(t, ok) + } else { + require.Equal(t, *tc.out, chunk) + } + }) + } +} diff --git a/internal/extproc/translator/openai_openai.go b/internal/extproc/translator/openai_openai.go new file mode 100644 index 00000000..d97bb5c6 --- /dev/null +++ b/internal/extproc/translator/openai_openai.go @@ -0,0 +1,98 @@ 
+package translator + +import ( + "bytes" + "encoding/json" + "fmt" + "log/slog" + + extprocv3http "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3" + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" +) + +// newOpenAIToOpenAITranslator implements [Factory] for OpenAI to OpenAI translation. +func newOpenAIToOpenAITranslator(path string, l *slog.Logger) (Translator, error) { + if path == "/v1/chat/completions" { + return &openAIToOpenAITranslatorV1ChatCompletion{l: l}, nil + } else { + return nil, fmt.Errorf("unsupported path: %s", path) + } +} + +// openAIToOpenAITranslatorV1ChatCompletion implements [Translator] for /v1/chat/completions. +type openAIToOpenAITranslatorV1ChatCompletion struct { + defaultTranslator + l *slog.Logger + stream bool + buffered []byte + bufferingDone bool +} + +// RequestBody implements [Translator.RequestBody]. +func (o *openAIToOpenAITranslatorV1ChatCompletion) RequestBody(body *extprocv3.HttpBody) ( + headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, override *extprocv3http.ProcessingMode, modelName string, err error, +) { + var req openai.ChatCompletionRequest + if err := json.Unmarshal(body.Body, &req); err != nil { + return nil, nil, nil, "", fmt.Errorf("failed to unmarshal body: %w", err) + } + + if req.Stream { + o.stream = true + override = &extprocv3http.ProcessingMode{ + ResponseHeaderMode: extprocv3http.ProcessingMode_SEND, + ResponseBodyMode: extprocv3http.ProcessingMode_STREAMED, + } + } + return nil, nil, override, req.Model, nil +} + +// ResponseBody implements [Translator.ResponseBody]. 
+func (o *openAIToOpenAITranslatorV1ChatCompletion) ResponseBody(body *extprocv3.HttpBody) ( + headerMutation *extprocv3.HeaderMutation, bodyMutation *extprocv3.BodyMutation, usedToken uint32, err error, +) { + if o.stream { + if !o.bufferingDone { + o.buffered = append(o.buffered, body.Body...) + usedToken = o.extractUsageFromBufferEvent() + } + return + } + var resp openai.ChatCompletionResponse + if err := json.Unmarshal(body.Body, &resp); err != nil { + return nil, nil, 0, fmt.Errorf("failed to unmarshal body: %w", err) + } + usedToken = uint32(resp.Usage.TotalTokens) + return +} + +var dataPrefix = []byte("data: ") + +// extractUsageFromBufferEvent extracts the token usage from the buffered event. +// Once the usage is extracted, it returns the number of tokens used, and bufferingDone is set to true. +func (o *openAIToOpenAITranslatorV1ChatCompletion) extractUsageFromBufferEvent() (usedToken uint32) { + for { + i := bytes.IndexByte(o.buffered, '\n') + if i == -1 { + return 0 + } + line := o.buffered[:i] + o.buffered = o.buffered[i+1:] + if !bytes.HasPrefix(line, dataPrefix) { + continue + } + var event openai.ChatCompletionResponseChunk + if err := json.Unmarshal(bytes.TrimPrefix(line, dataPrefix), &event); err != nil { + o.l.Warn("failed to unmarshal the event", slog.Any("error", err)) + continue + } + if usage := event.Usage; usage != nil { + usedToken = uint32(usage.TotalTokens) + o.bufferingDone = true + o.buffered = nil + return + } + } +} diff --git a/internal/extproc/translator/openai_openai_test.go b/internal/extproc/translator/openai_openai_test.go new file mode 100644 index 00000000..12eab552 --- /dev/null +++ b/internal/extproc/translator/openai_openai_test.go @@ -0,0 +1,180 @@ +package translator + +import ( + "encoding/json" + "fmt" + "log/slog" + "testing" + + extprocv3http "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3" + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + 
"github.com/stretchr/testify/require" + + "github.com/envoyproxy/ai-gateway/internal/apischema/openai" +) + +func TestNewOpenAIToOpenAITranslator(t *testing.T) { + t.Run("unsupported path", func(t *testing.T) { + _, err := newOpenAIToOpenAITranslator("/v1/foo/bar", slog.Default()) + require.Error(t, err) + }) + t.Run("/v1/chat/completions", func(t *testing.T) { + translator, err := newOpenAIToOpenAITranslator("/v1/chat/completions", slog.Default()) + require.NoError(t, err) + require.NotNil(t, translator) + }) +} + +func TestOpenAIToOpenAITranslatorV1ChatCompletionRequestBody(t *testing.T) { + t.Run("invalid body", func(t *testing.T) { + o := &openAIToOpenAITranslatorV1ChatCompletion{} + _, _, _, _, err := o.RequestBody(&extprocv3.HttpBody{Body: []byte("invalid")}) + require.Error(t, err) + }) + t.Run("valid body", func(t *testing.T) { + for _, stream := range []bool{true, false} { + t.Run(fmt.Sprintf("stream=%t", stream), func(t *testing.T) { + originalReq := openai.ChatCompletionRequest{Model: "foo-bar-ai", Stream: stream} + body, err := json.Marshal(originalReq) + require.NoError(t, err) + + o := &openAIToOpenAITranslatorV1ChatCompletion{} + hm, bm, mode, modelName, err := o.RequestBody(&extprocv3.HttpBody{Body: body}) + require.Nil(t, bm) + require.NoError(t, err) + require.Equal(t, "foo-bar-ai", modelName) + require.Equal(t, stream, o.stream) + if stream { + require.NotNil(t, mode) + require.Equal(t, extprocv3http.ProcessingMode_SEND, mode.ResponseHeaderMode) + require.Equal(t, extprocv3http.ProcessingMode_STREAMED, mode.ResponseBodyMode) + } else { + require.Nil(t, mode) + } + + require.Nil(t, hm) + }) + } + }) +} + +func TestOpenAIToOpenAITranslatorV1ChatCompletionResponseBody(t *testing.T) { + t.Run("streaming", func(t *testing.T) { + // This is the real event stream from OpenAI. 
+ wholeBody := []byte(` +data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"content":"This"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"content":" is"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"content":" a"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"content":" test"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: 
{"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null} + +data: 
{"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[],"usage":{"prompt_tokens":13,"completion_tokens":12,"total_tokens":25,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} + +data: [DONE] + +`) + + o := &openAIToOpenAITranslatorV1ChatCompletion{stream: true, l: slog.Default()} + var usedToken uint32 + for i := 0; i < len(wholeBody); i++ { + hm, bm, _usedToken, err := o.ResponseBody(&extprocv3.HttpBody{Body: wholeBody[i : i+1]}) + require.NoError(t, err) + require.Nil(t, hm) + require.Nil(t, bm) + if _usedToken > 0 { + usedToken = _usedToken + } + if usedToken > 0 { + require.True(t, o.bufferingDone) + } else { + require.False(t, o.bufferingDone) + } + } + require.Equal(t, uint32(25), usedToken) + }) + t.Run("non-streaming", func(t *testing.T) { + t.Run("invalid body", func(t *testing.T) { + o := &openAIToOpenAITranslatorV1ChatCompletion{} + _, _, _, err := o.ResponseBody(&extprocv3.HttpBody{Body: []byte("invalid")}) + require.Error(t, err) + }) + t.Run("valid body", func(t *testing.T) { + var resp openai.ChatCompletionResponse + resp.Usage.TotalTokens = 42 + body, err := json.Marshal(resp) + require.NoError(t, err) + o := &openAIToOpenAITranslatorV1ChatCompletion{} + _, _, usedToken, err := o.ResponseBody(&extprocv3.HttpBody{Body: body}) + require.NoError(t, err) + require.Equal(t, uint32(42), usedToken) + }) + }) +} + +func TestExtractUsageFromBufferEvent(t *testing.T) { + logger := slog.Default() + + t.Run("valid usage data", func(t *testing.T) { + o := &openAIToOpenAITranslatorV1ChatCompletion{l: logger} + o.buffered = []byte("data: {\"usage\": {\"total_tokens\": 42}}\n") + usedToken := o.extractUsageFromBufferEvent() + require.Equal(t, uint32(42), usedToken) + require.True(t, o.bufferingDone) + require.Nil(t, 
o.buffered) + }) + + t.Run("valid usage data after invalid", func(t *testing.T) { + o := &openAIToOpenAITranslatorV1ChatCompletion{l: logger} + o.buffered = []byte("data: invalid\ndata: {\"usage\": {\"total_tokens\": 42}}\n") + usedToken := o.extractUsageFromBufferEvent() + require.Equal(t, uint32(42), usedToken) + require.True(t, o.bufferingDone) + require.Nil(t, o.buffered) + }) + + t.Run("no usage data and then become valid", func(t *testing.T) { + o := &openAIToOpenAITranslatorV1ChatCompletion{l: logger} + o.buffered = []byte("data: {}\n\ndata: ") + usedToken := o.extractUsageFromBufferEvent() + require.Equal(t, uint32(0), usedToken) + require.False(t, o.bufferingDone) + require.NotNil(t, o.buffered) + + o.buffered = append(o.buffered, []byte("{\"usage\": {\"total_tokens\": 42}}\n")...) + usedToken = o.extractUsageFromBufferEvent() + require.Equal(t, uint32(42), usedToken) + require.True(t, o.bufferingDone) + require.Nil(t, o.buffered) + }) + + t.Run("invalid JSON", func(t *testing.T) { + o := &openAIToOpenAITranslatorV1ChatCompletion{l: logger} + o.buffered = []byte("data: invalid\n") + usedToken := o.extractUsageFromBufferEvent() + require.Equal(t, uint32(0), usedToken) + require.False(t, o.bufferingDone) + require.NotNil(t, o.buffered) + }) +} diff --git a/internal/extproc/translator/translator.go b/internal/extproc/translator/translator.go new file mode 100644 index 00000000..9cdb603c --- /dev/null +++ b/internal/extproc/translator/translator.go @@ -0,0 +1,112 @@ +package translator + +import ( + "fmt" + "log/slog" + + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + extprocv3http "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3" + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" + + aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1" +) + +// Factory creates a [Translator] for the given API schema combination and request path. +// +// - `path`: the path of the request. 
+// - `l`: the logger. +type Factory func(path string, l *slog.Logger) (Translator, error) + +// NewFactory returns a callback function that creates a translator for the given API schema combination. +func NewFactory(in, out aigv1a1.LLMAPISchema) (Factory, error) { + if in.Schema == aigv1a1.APISchemaOpenAI { + // TODO: currently, we ignore the LLMAPISchema."Version" field. + switch out.Schema { + case aigv1a1.APISchemaOpenAI: + return newOpenAIToOpenAITranslator, nil + case aigv1a1.APISchemaAWSBedrock: + return newOpenAIToAWSBedrockTranslator, nil + } + } + return nil, fmt.Errorf("unsupported API schema combination: client=%s, backend=%s", in, out) +} + +// Translator translates the request and response messages between the client and the backend API schemas for a specific path. +// The implementation can embed [defaultTranslator] to avoid implementing all methods. +// +// The instance of [Translator] is created by a [Factory]. +// +// This is created per request and is not thread-safe. +type Translator interface { + // RequestHeaders translates the request headers. + // - `headers` is the request headers. + // - This returns `headerMutation` that can be nil to indicate no mutation. + RequestHeaders(headers map[string]string) ( + headerMutation *extprocv3.HeaderMutation, + err error, + ) + + // RequestBody translates the request body. + // - `body` is the request body either chunk or the entire body, depending on the context. + // - This returns `headerMutation` and `bodyMutation` that can be nil to indicate no mutation. + // - This returns `override` that can be used to change the processing mode. This is used to process streaming requests properly. + // - This returns `modelName` that is extracted from the body. + RequestBody(body *extprocv3.HttpBody) ( + headerMutation *extprocv3.HeaderMutation, + bodyMutation *extprocv3.BodyMutation, + override *extprocv3http.ProcessingMode, + modelName string, + err error, + ) + + // ResponseHeaders translates the response headers. 
+ // - `headers` is the response headers. + // - This returns `headerMutation` that can be nil to indicate no mutation. + ResponseHeaders(headers map[string]string) ( + headerMutation *extprocv3.HeaderMutation, + err error, + ) + + // ResponseBody translates the response body. + // - `body` is the response body either chunk or the entire body, depending on the context. + // - This returns `headerMutation` and `bodyMutation` that can be nil to indicate no mutation. + // - This returns `usedToken` that is extracted from the body and will be used to do token rate limiting. + ResponseBody(body *extprocv3.HttpBody) ( + headerMutation *extprocv3.HeaderMutation, + bodyMutation *extprocv3.BodyMutation, + usedToken uint32, + err error, + ) +} + +// defaultTranslator is a no-op translator that implements [Translator]. +type defaultTranslator struct{} + +// RequestHeaders implements [Translator.RequestHeaders]. +func (d *defaultTranslator) RequestHeaders(map[string]string) (*extprocv3.HeaderMutation, error) { + return nil, nil +} + +// RequestBody implements [Translator.RequestBody]. +func (d *defaultTranslator) RequestBody(*extprocv3.HttpBody) (*extprocv3.HeaderMutation, *extprocv3.BodyMutation, *extprocv3http.ProcessingMode, string, error) { + return nil, nil, nil, "", nil +} + +// ResponseHeaders implements [Translator.ResponseHeaders]. +func (d *defaultTranslator) ResponseHeaders(map[string]string) (*extprocv3.HeaderMutation, error) { + return nil, nil +} + +// ResponseBody implements [Translator.ResponseBody]. 
+func (d *defaultTranslator) ResponseBody(*extprocv3.HttpBody) (*extprocv3.HeaderMutation, *extprocv3.BodyMutation, uint32, error) { + return nil, nil, 0, nil +} + +func setContentLength(headers *extprocv3.HeaderMutation, body []byte) { + headers.SetHeaders = append(headers.SetHeaders, &corev3.HeaderValueOption{ + Header: &corev3.HeaderValue{ + Key: "content-length", + RawValue: []byte(fmt.Sprintf("%d", len(body))), + }, + }) +} diff --git a/internal/extproc/translator/translator_test.go b/internal/extproc/translator/translator_test.go new file mode 100644 index 00000000..2388351d --- /dev/null +++ b/internal/extproc/translator/translator_test.go @@ -0,0 +1,39 @@ +package translator + +import ( + "log/slog" + "testing" + + "github.com/stretchr/testify/require" + + aigv1a1 "github.com/envoyproxy/ai-gateway/api/v1alpha1" +) + +func TestNewFactory(t *testing.T) { + t.Run("error", func(t *testing.T) { + _, err := NewFactory(aigv1a1.LLMAPISchema{Schema: "Foo", Version: "v100"}, aigv1a1.LLMAPISchema{Schema: "Bar", Version: "v123"}) + require.ErrorContains(t, err, "unsupported API schema combination: client={Foo v100}, backend={Bar v123}") + }) + t.Run("openai to openai", func(t *testing.T) { + f, err := NewFactory(aigv1a1.LLMAPISchema{Schema: aigv1a1.APISchemaOpenAI}, aigv1a1.LLMAPISchema{Schema: aigv1a1.APISchemaOpenAI}) + require.NoError(t, err) + require.NotNil(t, f) + + tl, err := f("/v1/chat/completions", slog.Default()) + require.NoError(t, err) + require.NotNil(t, tl) + _, ok := tl.(*openAIToOpenAITranslatorV1ChatCompletion) + require.True(t, ok) + }) + t.Run("openai to aws bedrock", func(t *testing.T) { + f, err := NewFactory(aigv1a1.LLMAPISchema{Schema: aigv1a1.APISchemaOpenAI}, aigv1a1.LLMAPISchema{Schema: aigv1a1.APISchemaAWSBedrock}) + require.NoError(t, err) + require.NotNil(t, f) + + tl, err := f("/v1/chat/completions", slog.Default()) + require.NoError(t, err) + require.NotNil(t, tl) + _, ok := tl.(*openAIToAWSBedrockTranslatorV1ChatCompletion) + 
require.True(t, ok) + }) +}