Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Server return with Attachment #2648

Merged
merged 11 commits into from
Apr 26, 2024
1 change: 1 addition & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ func (cli *Client) dial(interfaceName string, info *ClientInfo, opts ...Referenc

return &Connection{refOpts: newRefOpts}, nil
}

func generateInvocation(methodName string, reqs []interface{}, resp interface{}, callType string, opts *CallOptions) (protocol.Invocation, error) {
var paramsRawVals []interface{}
for _, req := range reqs {
Expand Down
59 changes: 32 additions & 27 deletions protocol/triple/triple_invoker.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,18 @@ func (ti *TripleInvoker) Invoke(ctx context.Context, invocation protocol.Invocat
return &result
}

ctx, callType, inRaw, method, err := parseInvocation(ctx, ti.GetURL(), invocation)
callType, inRaw, method, err := parseInvocation(ctx, ti.GetURL(), invocation)
if err != nil {
result.SetError(err)
return &result
}

ctx, err = mergeAttachmentToOutgoing(ctx, invocation)
if err != nil {
result.SetError(err)
return &result
}

inRawLen := len(inRaw)

if !ti.clientManager.isIDL {
Expand Down Expand Up @@ -136,36 +143,50 @@ func (ti *TripleInvoker) Invoke(ctx context.Context, invocation protocol.Invocat
return &result
}

func mergeAttachmentToOutgoing(ctx context.Context, inv protocol.Invocation) (context.Context, error) {
for key, valRaw := range inv.Attachments() {
if str, ok := valRaw.(string); ok {
ctx = tri.AppendToOutgoingContext(ctx, key, str)
continue
}
if strs, ok := valRaw.([]string); ok {
for _, str := range strs {
ctx = tri.AppendToOutgoingContext(ctx, key, str)
}
continue
}
return ctx, fmt.Errorf("triple attachments value with key = %s is invalid, which should be string or []string", key)
}
return ctx, nil
}

// parseInvocation retrieves information from invocation.
// it returns ctx, callType, inRaw, method, error
func parseInvocation(ctx context.Context, url *common.URL, invocation protocol.Invocation) (context.Context, string, []interface{}, string, error) {
func parseInvocation(ctx context.Context, url *common.URL, invocation protocol.Invocation) (string, []interface{}, string, error) {
callTypeRaw, ok := invocation.GetAttribute(constant.CallTypeKey)
if !ok {
return nil, "", nil, "", errors.New("miss CallType in invocation to invoke TripleInvoker")
return "", nil, "", errors.New("miss CallType in invocation to invoke TripleInvoker")
}
callType, ok := callTypeRaw.(string)
if !ok {
return nil, "", nil, "", fmt.Errorf("CallType should be string, but got %v", callTypeRaw)
return "", nil, "", fmt.Errorf("CallType should be string, but got %v", callTypeRaw)
}
// please refer to methods of client.Client or code generated by new triple for the usage of inRaw and inRawLen
// e.g. Client.CallUnary(... req, resp []interface, ...)
// inRaw represents req and resp
inRaw := invocation.ParameterRawValues()
method := invocation.MethodName()
if method == "" {
return nil, "", nil, "", errors.New("miss MethodName in invocation to invoke TripleInvoker")
return "", nil, "", errors.New("miss MethodName in invocation to invoke TripleInvoker")
}

ctx, err := parseAttachments(ctx, url, invocation)
if err != nil {
return nil, "", nil, "", err
}
parseAttachments(ctx, url, invocation)

return ctx, callType, inRaw, method, nil
return callType, inRaw, method, nil
}

// parseAttachments retrieves attachments from users passed-in and URL, then injects them into ctx
func parseAttachments(ctx context.Context, url *common.URL, invocation protocol.Invocation) (context.Context, error) {
func parseAttachments(ctx context.Context, url *common.URL, invocation protocol.Invocation) {
// retrieve users passed-in attachment
attaRaw := ctx.Value(constant.AttachmentKey)
if attaRaw != nil {
Expand All @@ -181,22 +202,6 @@ func parseAttachments(ctx context.Context, url *common.URL, invocation protocol.
invocation.SetAttachment(key, val)
}
}
// inject attachments
for key, valRaw := range invocation.Attachments() {
if str, ok := valRaw.(string); ok {
ctx = tri.AppendToOutgoingContext(ctx, key, str)
continue
}
if strs, ok := valRaw.([]string); ok {
for _, str := range strs {
ctx = tri.AppendToOutgoingContext(ctx, key, str)
}
continue
}
return nil, fmt.Errorf("triple attachments value with key = %s is invalid, which should be string or []string", key)
}

return ctx, nil
}

// IsAvailable get available status
Expand Down
22 changes: 13 additions & 9 deletions protocol/triple/triple_invoker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package triple

import (
"context"
"net/http"
"testing"

"dubbo.apache.org/dubbo-go/v3/common"
Expand All @@ -35,7 +36,7 @@ func Test_parseInvocation(t *testing.T) {
ctx func() context.Context
url *common.URL
invo func() protocol.Invocation
expect func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error)
expect func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error)
}{
{
desc: "miss callType",
Expand All @@ -46,7 +47,7 @@ func Test_parseInvocation(t *testing.T) {
invo: func() protocol.Invocation {
return invocation.NewRPCInvocationWithOptions()
},
expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) {
expect: func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) {
assert.NotNil(t, err)
},
},
Expand All @@ -61,7 +62,7 @@ func Test_parseInvocation(t *testing.T) {
iv.SetAttribute(constant.CallTypeKey, 1)
return iv
},
expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) {
expect: func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) {
assert.NotNil(t, err)
},
},
Expand All @@ -76,16 +77,16 @@ func Test_parseInvocation(t *testing.T) {
iv.SetAttribute(constant.CallTypeKey, constant.CallUnary)
return iv
},
expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) {
expect: func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) {
assert.NotNil(t, err)
},
},
}

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
ctx, callType, inRaw, methodName, err := parseInvocation(test.ctx(), test.url, test.invo())
test.expect(t, ctx, callType, inRaw, methodName, err)
callType, inRaw, methodName, err := parseInvocation(test.ctx(), test.url, test.invo())
test.expect(t, callType, inRaw, methodName, err)
})
}
}
Expand All @@ -112,7 +113,7 @@ func Test_parseAttachments(t *testing.T) {
},
expect: func(t *testing.T, ctx context.Context, err error) {
assert.Nil(t, err)
header := tri.ExtractFromOutgoingContext(ctx)
header := http.Header(tri.ExtractFromOutgoingContext(ctx))
assert.NotNil(t, header)
assert.Equal(t, "interface", header.Get(constant.InterfaceKey))
assert.Equal(t, "token", header.Get(constant.TokenKey))
Expand All @@ -132,7 +133,7 @@ func Test_parseAttachments(t *testing.T) {
},
expect: func(t *testing.T, ctx context.Context, err error) {
assert.Nil(t, err)
header := tri.ExtractFromOutgoingContext(ctx)
header := http.Header(tri.ExtractFromOutgoingContext(ctx))
assert.NotNil(t, header)
assert.Equal(t, "val1", header.Get("key1"))
assert.Equal(t, []string{"key2_1", "key2_2"}, header.Values("key2"))
Expand All @@ -157,7 +158,10 @@ func Test_parseAttachments(t *testing.T) {

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
ctx, err := parseAttachments(test.ctx(), test.url, test.invo())
ctx := test.ctx()
inv := test.invo()
parseAttachments(ctx, test.url, inv)
ctx, err := mergeAttachmentToOutgoing(ctx, inv)
test.expect(t, ctx, err)
})
}
Expand Down
4 changes: 4 additions & 0 deletions protocol/triple/triple_protocol/duplex_http_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ func (d *duplexHTTPCall) CloseRead() error {
if err := discard(d.response.Body); err != nil {
return wrapIfRSTError(err)
}
// Return incoming data via context, if set outgoing data.
if ExtractFromOutgoingContext(d.ctx) != nil {
newIncomingContext(d.ctx, d.ResponseTrailer())
}
return wrapIfRSTError(d.response.Body.Close())
}

Expand Down
28 changes: 23 additions & 5 deletions protocol/triple/triple_protocol/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ func generateUnaryHandlerFunc(
// merge headers
mergeHeaders(conn.ResponseHeader(), response.Header())
mergeHeaders(conn.ResponseTrailer(), response.Trailer())
//Write the server-side return-attachment-data in the tailer to send to the caller
if data := ExtractFromOutgoingContext(ctx); data != nil {
mergeHeaders(conn.ResponseTrailer(), data)
}
return conn.Send(response.Any())
}

Expand Down Expand Up @@ -160,6 +164,9 @@ func generateClientStreamHandlerFunc(
}
mergeHeaders(conn.ResponseHeader(), res.header)
mergeHeaders(conn.ResponseTrailer(), res.trailer)
if outgoingData := ExtractFromOutgoingContext(ctx); outgoingData != nil {
mergeHeaders(conn.ResponseTrailer(), outgoingData)
}
return conn.Send(res.Msg)
}
if interceptor != nil {
Expand Down Expand Up @@ -205,7 +212,7 @@ func generateServerStreamHandlerFunc(
}
// embed header in context so that user logic could process them via FromIncomingContext
ctx = newIncomingContext(ctx, conn.RequestHeader())
return streamFunc(
err := streamFunc(
ctx,
&Request{
Msg: req,
Expand All @@ -215,6 +222,13 @@ func generateServerStreamHandlerFunc(
},
&ServerStream{conn: conn},
)
if err != nil {
return err
}
if outgoingData := ExtractFromOutgoingContext(ctx); outgoingData != nil {
mergeHeaders(conn.ResponseTrailer(), outgoingData)
}
return nil
}
if interceptor != nil {
implementation = interceptor.WrapStreamingHandler(implementation)
Expand Down Expand Up @@ -253,10 +267,14 @@ func generateBidiStreamHandlerFunc(
implementation := func(ctx context.Context, conn StreamingHandlerConn) error {
// embed header in context so that user logic could process them via FromIncomingContext
ctx = newIncomingContext(ctx, conn.RequestHeader())
return streamFunc(
ctx,
&BidiStream{conn: conn},
)
err := streamFunc(ctx, &BidiStream{conn: conn})
if err != nil {
return err
}
if outgoingData := ExtractFromOutgoingContext(ctx); outgoingData != nil {
mergeHeaders(conn.ResponseTrailer(), outgoingData)
}
return nil
}
if interceptor != nil {
implementation = interceptor.WrapStreamingHandler(implementation)
Expand Down
74 changes: 55 additions & 19 deletions protocol/triple/triple_protocol/header.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/base64"
"fmt"
"net/http"
"strings"
)

// EncodeBinaryHeader base64-encodes the data. It always emits unpadded values.
Expand Down Expand Up @@ -88,20 +89,45 @@ func addHeaderCanonical(h http.Header, key, value string) {
h[key] = append(h[key], value)
}

type headerIncomingKey struct{}
type headerOutgoingKey struct{}
type extraDataKey struct{}

const headerIncomingKey string = "headerIncomingKey"
const headerOutgoingKey string = "headerOutgoingKey"

type handlerOutgoingKey struct{}

func newIncomingContext(ctx context.Context, header http.Header) context.Context {
return context.WithValue(ctx, headerIncomingKey{}, header)
func newIncomingContext(ctx context.Context, data http.Header) context.Context {
var header = http.Header{}
extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
if !ok {
extraData = map[string]http.Header{}
}
if data != nil {
for key, vals := range data {
header[strings.ToLower(key)] = vals
}
}
extraData[headerIncomingKey] = header
return context.WithValue(ctx, extraDataKey{}, extraData)
}

// NewOutgoingContext sets headers entirely. If there are existing headers, they would be replaced.
// It is used for passing headers to server-side.
// It is like grpc.NewOutgoingContext.
// Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#sending-metadata.
func NewOutgoingContext(ctx context.Context, header http.Header) context.Context {
return context.WithValue(ctx, headerOutgoingKey{}, header)
func NewOutgoingContext(ctx context.Context, data http.Header) context.Context {
var header = http.Header{}
if data != nil {
for key, vals := range data {
header[strings.ToLower(key)] = vals
}
}
extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
if !ok {
extraData = map[string]http.Header{}
}
extraData[headerOutgoingKey] = header
return context.WithValue(ctx, extraDataKey{}, extraData)
}

// AppendToOutgoingContext merges kv pairs from user and existing headers.
Expand All @@ -112,37 +138,47 @@ func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context
if len(kv)%2 == 1 {
panic(fmt.Sprintf("AppendToOutgoingContext got an odd number of input pairs for header: %d", len(kv)))
}
var header http.Header
headerRaw := ctx.Value(headerOutgoingKey{})
if headerRaw == nil {
extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
if !ok {
extraData = map[string]http.Header{}
ctx = context.WithValue(ctx, extraDataKey{}, extraData)
}
header, ok := extraData[headerOutgoingKey]
if !ok {
header = make(http.Header)
} else {
header = headerRaw.(http.Header)
extraData[headerOutgoingKey] = header
}
for i := 0; i < len(kv); i += 2 {
// todo(DMwangnima): think about lowering
header.Add(kv[i], kv[i+1])
header.Add(strings.ToLower(kv[i]), kv[i+1])
}
return context.WithValue(ctx, headerOutgoingKey{}, header)
return ctx
}

func ExtractFromOutgoingContext(ctx context.Context) http.Header {
headerRaw := ctx.Value(headerOutgoingKey{})
if headerRaw == nil {
extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
if !ok {
return nil
}
// since headerOutgoingKey is only used in triple_protocol package, we need not verify the type
return headerRaw.(http.Header)
if outGoingDataHeader, ok := extraData[headerOutgoingKey]; !ok {
return nil
} else {
return outGoingDataHeader
}
}

// FromIncomingContext retrieves headers passed by client-side. It is like grpc.FromIncomingContext.
// it must call after append/setOutgoingContext to return current value
// Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#receiving-metadata-1.
func FromIncomingContext(ctx context.Context) (http.Header, bool) {
header, ok := ctx.Value(headerIncomingKey{}).(http.Header)
data, ok := ctx.Value(extraDataKey{}).(map[string]http.Header)
if !ok {
return nil, false
} else if incomingDataHeader, ok := data[headerIncomingKey]; !ok {
return nil, false
} else {
return incomingDataHeader, true
}
return header, true
}

// SetHeader is used for setting response header in server-side. It is like grpc.SendHeader(ctx, header) but
Expand Down
Loading