diff --git a/client/client.go b/client/client.go index 98ab66733e..a2bcaef1e5 100644 --- a/client/client.go +++ b/client/client.go @@ -122,6 +122,7 @@ func (cli *Client) dial(interfaceName string, info *ClientInfo, opts ...Referenc return &Connection{refOpts: newRefOpts}, nil } + func generateInvocation(methodName string, reqs []interface{}, resp interface{}, callType string, opts *CallOptions) (protocol.Invocation, error) { var paramsRawVals []interface{} for _, req := range reqs { diff --git a/protocol/triple/triple_invoker.go b/protocol/triple/triple_invoker.go index e087784298..4c1eb6868b 100644 --- a/protocol/triple/triple_invoker.go +++ b/protocol/triple/triple_invoker.go @@ -81,11 +81,18 @@ func (ti *TripleInvoker) Invoke(ctx context.Context, invocation protocol.Invocat return &result } - ctx, callType, inRaw, method, err := parseInvocation(ctx, ti.GetURL(), invocation) + callType, inRaw, method, err := parseInvocation(ctx, ti.GetURL(), invocation) if err != nil { result.SetError(err) return &result } + + ctx, err = mergeAttachmentToOutgoing(ctx, invocation) + if err != nil { + result.SetError(err) + return &result + } + inRawLen := len(inRaw) if !ti.clientManager.isIDL { @@ -136,16 +143,33 @@ func (ti *TripleInvoker) Invoke(ctx context.Context, invocation protocol.Invocat return &result } +func mergeAttachmentToOutgoing(ctx context.Context, inv protocol.Invocation) (context.Context, error) { + for key, valRaw := range inv.Attachments() { + if str, ok := valRaw.(string); ok { + ctx = tri.AppendToOutgoingContext(ctx, key, str) + continue + } + if strs, ok := valRaw.([]string); ok { + for _, str := range strs { + ctx = tri.AppendToOutgoingContext(ctx, key, str) + } + continue + } + return ctx, fmt.Errorf("triple attachments value with key = %s is invalid, which should be string or []string", key) + } + return ctx, nil +} + // parseInvocation retrieves information from invocation. // it returns ctx, callType, inRaw, method, error -func parseInvocation(ctx context.Context, url *common.URL, invocation protocol.Invocation) (context.Context, string, []interface{}, string, error) { +func parseInvocation(ctx context.Context, url *common.URL, invocation protocol.Invocation) (string, []interface{}, string, error) { callTypeRaw, ok := invocation.GetAttribute(constant.CallTypeKey) if !ok { - return nil, "", nil, "", errors.New("miss CallType in invocation to invoke TripleInvoker") + return "", nil, "", errors.New("miss CallType in invocation to invoke TripleInvoker") } callType, ok := callTypeRaw.(string) if !ok { - return nil, "", nil, "", fmt.Errorf("CallType should be string, but got %v", callTypeRaw) + return "", nil, "", fmt.Errorf("CallType should be string, but got %v", callTypeRaw) } // please refer to methods of client.Client or code generated by new triple for the usage of inRaw and inRawLen // e.g. Client.CallUnary(... req, resp []interface, ...) @@ -153,19 +177,16 @@ func parseInvocation(ctx context.Context, url *common.URL, invocation protocol.I inRaw := invocation.ParameterRawValues() method := invocation.MethodName() if method == "" { - return nil, "", nil, "", errors.New("miss MethodName in invocation to invoke TripleInvoker") + return "", nil, "", errors.New("miss MethodName in invocation to invoke TripleInvoker") } - ctx, err := parseAttachments(ctx, url, invocation) - if err != nil { - return nil, "", nil, "", err - } + parseAttachments(ctx, url, invocation) - return ctx, callType, inRaw, method, nil + return callType, inRaw, method, nil } // parseAttachments retrieves attachments from users passed-in and URL, then injects them into ctx -func parseAttachments(ctx context.Context, url *common.URL, invocation protocol.Invocation) (context.Context, error) { +func parseAttachments(ctx context.Context, url *common.URL, invocation protocol.Invocation) { // retrieve users passed-in attachment attaRaw := ctx.Value(constant.AttachmentKey) if attaRaw != nil { @@ -181,22 +202,6 @@ func parseAttachments(ctx context.Context, url *common.URL, invocation protocol. invocation.SetAttachment(key, val) } } - // inject attachments - for key, valRaw := range invocation.Attachments() { - if str, ok := valRaw.(string); ok { - ctx = tri.AppendToOutgoingContext(ctx, key, str) - continue - } - if strs, ok := valRaw.([]string); ok { - for _, str := range strs { - ctx = tri.AppendToOutgoingContext(ctx, key, str) - } - continue - } - return nil, fmt.Errorf("triple attachments value with key = %s is invalid, which should be string or []string", key) - } - - return ctx, nil } // IsAvailable get available status diff --git a/protocol/triple/triple_invoker_test.go b/protocol/triple/triple_invoker_test.go index 7d14d42394..e9dcc968cc 100644 --- a/protocol/triple/triple_invoker_test.go +++ b/protocol/triple/triple_invoker_test.go @@ -19,6 +19,7 @@ package triple import ( "context" + "net/http" "testing" "dubbo.apache.org/dubbo-go/v3/common" @@ -35,7 +36,7 @@ func Test_parseInvocation(t *testing.T) { ctx func() context.Context url *common.URL invo func() protocol.Invocation - expect func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) + expect func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) }{ { desc: "miss callType", @@ -46,7 +47,7 @@ func Test_parseInvocation(t *testing.T) { invo: func() protocol.Invocation { return invocation.NewRPCInvocationWithOptions() }, - expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) { + expect: func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) { assert.NotNil(t, err) }, }, @@ -61,7 +62,7 @@ func Test_parseInvocation(t *testing.T) { iv.SetAttribute(constant.CallTypeKey, 1) return iv }, - expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) { + expect: func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) { assert.NotNil(t, err) }, }, @@ -76,7 +77,7 @@ func Test_parseInvocation(t *testing.T) { iv.SetAttribute(constant.CallTypeKey, constant.CallUnary) return iv }, - expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) { + expect: func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) { assert.NotNil(t, err) }, }, @@ -84,8 +85,8 @@ func Test_parseInvocation(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - ctx, callType, inRaw, methodName, err := parseInvocation(test.ctx(), test.url, test.invo()) - test.expect(t, ctx, callType, inRaw, methodName, err) + callType, inRaw, methodName, err := parseInvocation(test.ctx(), test.url, test.invo()) + test.expect(t, callType, inRaw, methodName, err) }) } } @@ -112,7 +113,7 @@ func Test_parseAttachments(t *testing.T) { }, expect: func(t *testing.T, ctx context.Context, err error) { assert.Nil(t, err) - header := tri.ExtractFromOutgoingContext(ctx) + header := http.Header(tri.ExtractFromOutgoingContext(ctx)) assert.NotNil(t, header) assert.Equal(t, "interface", header.Get(constant.InterfaceKey)) assert.Equal(t, "token", header.Get(constant.TokenKey)) @@ -132,7 +133,7 @@ func Test_parseAttachments(t *testing.T) { }, expect: func(t *testing.T, ctx context.Context, err error) { assert.Nil(t, err) - header := tri.ExtractFromOutgoingContext(ctx) + header := http.Header(tri.ExtractFromOutgoingContext(ctx)) assert.NotNil(t, header) assert.Equal(t, "val1", header.Get("key1")) assert.Equal(t, []string{"key2_1", "key2_2"}, header.Values("key2")) @@ -157,7 +158,10 @@ func Test_parseAttachments(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - ctx, err := parseAttachments(test.ctx(), test.url, test.invo()) + ctx := test.ctx() + inv := test.invo() + parseAttachments(ctx, test.url, inv) + ctx, err := mergeAttachmentToOutgoing(ctx, inv) test.expect(t, ctx, err) }) } diff --git a/protocol/triple/triple_protocol/duplex_http_call.go b/protocol/triple/triple_protocol/duplex_http_call.go index 3865a56ab7..6c4c021764 100644 --- a/protocol/triple/triple_protocol/duplex_http_call.go +++ b/protocol/triple/triple_protocol/duplex_http_call.go @@ -184,6 +184,10 @@ func (d *duplexHTTPCall) CloseRead() error { if err := discard(d.response.Body); err != nil { return wrapIfRSTError(err) } + // Return incoming data via context, if set outgoing data. + if ExtractFromOutgoingContext(d.ctx) != nil { + newIncomingContext(d.ctx, d.ResponseTrailer()) + } return wrapIfRSTError(d.response.Body.Close()) } diff --git a/protocol/triple/triple_protocol/handler.go b/protocol/triple/triple_protocol/handler.go index 56f83b3fec..7a44f85354 100644 --- a/protocol/triple/triple_protocol/handler.go +++ b/protocol/triple/triple_protocol/handler.go @@ -112,6 +112,10 @@ func generateUnaryHandlerFunc( // merge headers mergeHeaders(conn.ResponseHeader(), response.Header()) mergeHeaders(conn.ResponseTrailer(), response.Trailer()) + //Write the server-side return-attachment-data in the tailer to send to the caller + if data := ExtractFromOutgoingContext(ctx); data != nil { + mergeHeaders(conn.ResponseTrailer(), data) + } return conn.Send(response.Any()) } @@ -160,6 +164,9 @@ func generateClientStreamHandlerFunc( } mergeHeaders(conn.ResponseHeader(), res.header) mergeHeaders(conn.ResponseTrailer(), res.trailer) + if outgoingData := ExtractFromOutgoingContext(ctx); outgoingData != nil { + mergeHeaders(conn.ResponseTrailer(), outgoingData) + } return conn.Send(res.Msg) } if interceptor != nil { @@ -205,7 +212,7 @@ func generateServerStreamHandlerFunc( } // embed header in context so that user logic could process them via FromIncomingContext ctx = newIncomingContext(ctx, conn.RequestHeader()) - return streamFunc( + err := streamFunc( ctx, &Request{ Msg: req, @@ -215,6 +222,13 @@ func generateServerStreamHandlerFunc( }, &ServerStream{conn: conn}, ) + if err != nil { + return err + } + if outgoingData := ExtractFromOutgoingContext(ctx); outgoingData != nil { + mergeHeaders(conn.ResponseTrailer(), outgoingData) + } + return nil } if interceptor != nil { implementation = interceptor.WrapStreamingHandler(implementation) @@ -253,10 +267,14 @@ func generateBidiStreamHandlerFunc( implementation := func(ctx context.Context, conn StreamingHandlerConn) error { // embed header in context so that user logic could process them via FromIncomingContext ctx = newIncomingContext(ctx, conn.RequestHeader()) - return streamFunc( - ctx, - &BidiStream{conn: conn}, - ) + err := streamFunc(ctx, &BidiStream{conn: conn}) + if err != nil { + return err + } + if outgoingData := ExtractFromOutgoingContext(ctx); outgoingData != nil { + mergeHeaders(conn.ResponseTrailer(), outgoingData) + } + return nil } if interceptor != nil { implementation = interceptor.WrapStreamingHandler(implementation) diff --git a/protocol/triple/triple_protocol/header.go b/protocol/triple/triple_protocol/header.go index 28618e0dcc..791cf4a30d 100644 --- a/protocol/triple/triple_protocol/header.go +++ b/protocol/triple/triple_protocol/header.go @@ -19,6 +19,7 @@ import ( "encoding/base64" "fmt" "net/http" + "strings" ) // EncodeBinaryHeader base64-encodes the data. It always emits unpadded values. @@ -88,20 +89,45 @@ func addHeaderCanonical(h http.Header, key, value string) { h[key] = append(h[key], value) } -type headerIncomingKey struct{} -type headerOutgoingKey struct{} +type extraDataKey struct{} + +const headerIncomingKey string = "headerIncomingKey" +const headerOutgoingKey string = "headerOutgoingKey" + type handlerOutgoingKey struct{} -func newIncomingContext(ctx context.Context, header http.Header) context.Context { - return context.WithValue(ctx, headerIncomingKey{}, header) +func newIncomingContext(ctx context.Context, data http.Header) context.Context { + var header = http.Header{} + extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) + if !ok { + extraData = map[string]http.Header{} + } + if data != nil { + for key, vals := range data { + header[strings.ToLower(key)] = vals + } + } + extraData[headerIncomingKey] = header + return context.WithValue(ctx, extraDataKey{}, extraData) } // NewOutgoingContext sets headers entirely. If there are existing headers, they would be replaced. // It is used for passing headers to server-side. // It is like grpc.NewOutgoingContext. // Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#sending-metadata. -func NewOutgoingContext(ctx context.Context, header http.Header) context.Context { - return context.WithValue(ctx, headerOutgoingKey{}, header) +func NewOutgoingContext(ctx context.Context, data http.Header) context.Context { + var header = http.Header{} + if data != nil { + for key, vals := range data { + header[strings.ToLower(key)] = vals + } + } + extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) + if !ok { + extraData = map[string]http.Header{} + } + extraData[headerOutgoingKey] = header + return context.WithValue(ctx, extraDataKey{}, extraData) } // AppendToOutgoingContext merges kv pairs from user and existing headers. @@ -112,37 +138,47 @@ func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context if len(kv)%2 == 1 { panic(fmt.Sprintf("AppendToOutgoingContext got an odd number of input pairs for header: %d", len(kv))) } - var header http.Header - headerRaw := ctx.Value(headerOutgoingKey{}) - if headerRaw == nil { + extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) + if !ok { + extraData = map[string]http.Header{} + ctx = context.WithValue(ctx, extraDataKey{}, extraData) + } + header, ok := extraData[headerOutgoingKey] + if !ok { header = make(http.Header) - } else { - header = headerRaw.(http.Header) + extraData[headerOutgoingKey] = header } for i := 0; i < len(kv); i += 2 { // todo(DMwangnima): think about lowering - header.Add(kv[i], kv[i+1]) + header.Add(strings.ToLower(kv[i]), kv[i+1]) } - return context.WithValue(ctx, headerOutgoingKey{}, header) + return ctx } func ExtractFromOutgoingContext(ctx context.Context) http.Header { - headerRaw := ctx.Value(headerOutgoingKey{}) - if headerRaw == nil { + extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) + if !ok { return nil } - // since headerOutgoingKey is only used in triple_protocol package, we need not verify the type - return headerRaw.(http.Header) + if outGoingDataHeader, ok := extraData[headerOutgoingKey]; !ok { + return nil + } else { + return outGoingDataHeader + } } // FromIncomingContext retrieves headers passed by client-side. It is like grpc.FromIncomingContext. +// it must call after append/setOutgoingContext to return current value // Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#receiving-metadata-1. func FromIncomingContext(ctx context.Context) (http.Header, bool) { - header, ok := ctx.Value(headerIncomingKey{}).(http.Header) + data, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) if !ok { return nil, false + } else if incomingDataHeader, ok := data[headerIncomingKey]; !ok { + return nil, false + } else { + return incomingDataHeader, true } - return header, true } // SetHeader is used for setting response header in server-side. It is like grpc.SendHeader(ctx, header) but