From 8ce359f2c6ca6afafe1a2f6a33b73fda5ae44ca6 Mon Sep 17 00:00:00 2001 From: YarBor Date: Wed, 10 Apr 2024 10:52:12 +0800 Subject: [PATCH 1/9] Add : add extraData api in rpc calling Signed-off-by: YarBor --- protocol/extra_data.go | 108 ++++++++++++++++++++++++++++++++++++ protocol/extra_data_test.go | 77 +++++++++++++++++++++++++ 2 files changed, 185 insertions(+) create mode 100644 protocol/extra_data.go create mode 100644 protocol/extra_data_test.go diff --git a/protocol/extra_data.go b/protocol/extra_data.go new file mode 100644 index 000000000..1fbf05c70 --- /dev/null +++ b/protocol/extra_data.go @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package protocol + +import ( + "context" + "github.com/pkg/errors" +) + +type rpcExtraDataKey struct{} + +const outgoingKey = "outgoingKey" +const incomingKey = "incomingKey" + +var _ *rpcExtraDataKey = (*rpcExtraDataKey)(nil) + +func SetIncomingData(ctx context.Context, extraData map[string]interface{}) context.Context { + data, ok := ctx.Value(rpcExtraDataKey{}).(map[string]interface{}) + if !ok { + data = map[string]interface{}{} + ctx = context.WithValue(ctx, rpcExtraDataKey{}, data) + } + data[incomingKey] = extraData + return ctx +} + +func GetIncomingData(ctx context.Context) (map[string]interface{}, bool) { + if data, ok := ctx.Value(rpcExtraDataKey{}).(map[string]interface{}); !ok { + return nil, false + } else { + if incomingDataInterface, ok := data[incomingKey]; !ok { + return nil, false + } else if incomingData, ok := incomingDataInterface.(map[string]interface{}); !ok { + return nil, false + } else { + // may need copy here ? + return incomingData, true + } + } +} + +func AppendIncomingData(ctx context.Context, kv ...string) (context.Context, error) { + if kv == nil || len(kv)%2 != 0 { + return ctx, errors.New("kv must be a non-nil slice with an even number of elements") + } + incomingData, ok := GetIncomingData(ctx) + if !ok { + incomingData = map[string]interface{}{} + } + for i := 0; i < len(kv); i += 2 { + incomingData[kv[i]] = []string{kv[i+1]} + } + return SetIncomingData(ctx, incomingData), nil +} + +func SetOutgoingData(ctx context.Context, extraData map[string]interface{}) context.Context { + data, ok := ctx.Value(rpcExtraDataKey{}).(map[string]interface{}) + if !ok { + data = map[string]interface{}{} + ctx = context.WithValue(ctx, rpcExtraDataKey{}, data) + } + data[outgoingKey] = extraData + return ctx +} + +func GetOutgoingData(ctx context.Context) (map[string]interface{}, bool) { + if data, ok := ctx.Value(rpcExtraDataKey{}).(map[string]interface{}); !ok { + return nil, false + } else { + if OutgoingDataInterface, ok := data[outgoingKey]; !ok { + return nil, false + } else if OutgoingData, ok := OutgoingDataInterface.(map[string]interface{}); !ok { + return nil, false + } else { + // may need copy here ? + return OutgoingData, true + } + } +} + +func AppendOutgoingData(ctx context.Context, kv ...string) (context.Context, error) { + if kv == nil || len(kv)%2 != 0 { + return ctx, errors.New("kv must be a non-nil slice with an even number of elements") + } + OutgoingData, ok := GetOutgoingData(ctx) + if !ok { + OutgoingData = map[string]interface{}{} + } + for i := 0; i < len(kv); i += 2 { + OutgoingData[kv[i]] = []string{kv[i+1]} + } + return SetOutgoingData(ctx, OutgoingData), nil +} diff --git a/protocol/extra_data_test.go b/protocol/extra_data_test.go new file mode 100644 index 000000000..b08cf47b2 --- /dev/null +++ b/protocol/extra_data_test.go @@ -0,0 +1,77 @@ +package protocol + +import ( + "context" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestIncomingData(t *testing.T) { + ctx := context.Background() + data := map[string]interface{}{"key": []string{"value"}, "key2": []string{"value2"}} + // set check + ctx = SetIncomingData(ctx, data) + testData := ctx.Value(rpcExtraDataKey{}).(map[string]interface{})[incomingKey].(map[string]interface{}) + assert.Equal(t, data, testData) + // append check + ctx, err := AppendIncomingData(ctx, "hello", "world") + assert.Equal(t, nil, err) + data["hello"] = "world" + testData = ctx.Value(rpcExtraDataKey{}).(map[string]interface{})[incomingKey].(map[string]interface{}) + assert.Equal(t, data, testData) + // get check + testData, ok := GetIncomingData(ctx) + assert.Equal(t, true, ok) + assert.Equal(t, data, testData) + + // test err + ctx = context.Background() + testData, ok = GetIncomingData(ctx) + assert.Equal(t, false, ok) + assert.Nil(t, testData) + + ctx, err = AppendIncomingData(ctx, "hello", "world") + assert.Equal(t, nil, err) + assert.Equal(t, map[string]interface{}{ + "hello": []string{"world"}, + }, ctx.Value(rpcExtraDataKey{}).(map[string]interface{})[incomingKey].(map[string]interface{})) + + ctx1, err := AppendIncomingData(ctx, "hello", "world", "err input") + assert.Equal(t, ctx, ctx1) + assert.NotEqual(t, nil, err) +} + +func TestOutgoingData(t *testing.T) { + ctx := context.Background() + data := map[string]interface{}{"key": []string{"value"}, "key2": []string{"value2"}} + // set check + ctx = SetOutgoingData(ctx, data) + testData := ctx.Value(rpcExtraDataKey{}).(map[string]interface{})[outgoingKey].(map[string]interface{}) + assert.Equal(t, data, testData) + // append check + ctx, err := AppendOutgoingData(ctx, "hello", "world") + assert.Equal(t, nil, err) + data["hello"] = "world" + testData = ctx.Value(rpcExtraDataKey{}).(map[string]interface{})[outgoingKey].(map[string]interface{}) + assert.Equal(t, data, testData) + // get check + testData, ok := GetOutgoingData(ctx) + assert.Equal(t, true, ok) + assert.Equal(t, data, testData) + + // test err + ctx = context.Background() + testData, ok = GetOutgoingData(ctx) + assert.Equal(t, false, ok) + assert.Nil(t, testData) + + ctx, err = AppendOutgoingData(ctx, "hello", "world") + assert.Equal(t, nil, err) + assert.Equal(t, map[string]interface{}{ + "hello": []string{"world"}, + }, ctx.Value(rpcExtraDataKey{}).(map[string]interface{})[outgoingKey].(map[string]interface{})) + + ctx1, err := AppendOutgoingData(ctx, "hello", "world", "err input") + assert.Equal(t, ctx, ctx1) + assert.NotEqual(t, nil, err) +} From ee885964c0fbb2003aa0dc8e40ffb3645c8e5646 Mon Sep 17 00:00:00 2001 From: YarBor Date: Wed, 10 Apr 2024 14:08:28 +0800 Subject: [PATCH 2/9] Add :Client side unray extradata send/return logic Signed-off-by: YarBor --- client/client.go | 24 ++++++++++++- protocol/extra_data_test.go | 16 +++++++++ protocol/triple/client.go | 8 +++++ protocol/triple/triple_invoker.go | 59 ++++++++++++++++++------------- 4 files changed, 81 insertions(+), 26 deletions(-) diff --git a/client/client.go b/client/client.go index 98ab66733..7752d27a2 100644 --- a/client/client.go +++ b/client/client.go @@ -20,6 +20,7 @@ package client import ( "context" + "errors" ) import ( @@ -59,7 +60,14 @@ func (conn *Connection) call(ctx context.Context, reqs []interface{}, resp inter if err != nil { return nil, err } - return conn.refOpts.invoker.Invoke(ctx, inv), nil + if userData, ok := protocol.GetOutgoingData(ctx); ok { + err = addClientExtraDataToInvocation(inv, userData) + if err != nil { + return nil, err + } + } + res := conn.refOpts.invoker.Invoke(ctx, inv) + return res, nil } func (conn *Connection) CallUnary(ctx context.Context, reqs []interface{}, resp interface{}, methodName string, opts ...CallOption) error { @@ -122,6 +130,7 @@ func (cli *Client) dial(interfaceName string, info *ClientInfo, opts ...Referenc return &Connection{refOpts: newRefOpts}, nil } + func generateInvocation(methodName string, reqs []interface{}, resp interface{}, callType string, opts *CallOptions) (protocol.Invocation, error) { var paramsRawVals []interface{} for _, req := range reqs { @@ -152,3 +161,16 @@ func NewClient(opts ...ClientOption) (*Client, error) { cliOpts: newCliOpts, }, nil } + +func addClientExtraDataToInvocation(inv protocol.Invocation, data map[string]interface{}) error { + for k, v := range data { + if str, ok := v.(string); ok { + inv.SetAttachment(k, []string{str}) + } else if strs, ok := v.([]string); ok { + inv.SetAttachment(k, strs) + } else { + return errors.New("ExtraData's type needs to be string or []string") + } + } + return nil +} diff --git a/protocol/extra_data_test.go b/protocol/extra_data_test.go index b08cf47b2..7a260f6f1 100644 --- a/protocol/extra_data_test.go +++ b/protocol/extra_data_test.go @@ -1,3 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package protocol import ( diff --git a/protocol/triple/client.go b/protocol/triple/client.go index 66a3e5331..6b37aff70 100644 --- a/protocol/triple/client.go +++ b/protocol/triple/client.go @@ -20,6 +20,7 @@ package triple import ( "context" "crypto/tls" + "dubbo.apache.org/dubbo-go/v3/protocol" "fmt" "net" "net/http" @@ -70,6 +71,13 @@ func (cm *clientManager) callUnary(ctx context.Context, method string, req, resp if err := triClient.CallUnary(ctx, triReq, triResp); err != nil { return err } + if _, ok := protocol.GetOutgoingData(ctx); ok { + rtExtraData := map[string]interface{}{} + for k, v := range triResp.Trailer() { + rtExtraData[k] = v + } + protocol.SetIncomingData(ctx, rtExtraData) + } return nil } diff --git a/protocol/triple/triple_invoker.go b/protocol/triple/triple_invoker.go index e08778429..0c0847cce 100644 --- a/protocol/triple/triple_invoker.go +++ b/protocol/triple/triple_invoker.go @@ -81,11 +81,18 @@ func (ti *TripleInvoker) Invoke(ctx context.Context, invocation protocol.Invocat return &result } - ctx, callType, inRaw, method, err := parseInvocation(ctx, ti.GetURL(), invocation) + callType, inRaw, method, err := parseInvocation(ctx, ti.GetURL(), invocation) if err != nil { result.SetError(err) return &result } + + ctx, err = mergeAttachmentToOutgoing(ctx, invocation.Attachments()) + if err != nil { + result.SetError(err) + return &result + } + inRawLen := len(inRaw) if !ti.clientManager.isIDL { @@ -136,16 +143,33 @@ func (ti *TripleInvoker) Invoke(ctx context.Context, invocation protocol.Invocat return &result } +func mergeAttachmentToOutgoing(ctx context.Context, attachments map[string]interface{}) (context.Context, error) { + for key, valRaw := range attachments { + if str, ok := valRaw.(string); ok { + ctx = tri.AppendToOutgoingContext(ctx, key, str) + continue + } + if strs, ok := valRaw.([]string); ok { + for _, str := range strs { + ctx = tri.AppendToOutgoingContext(ctx, key, str) + } + continue + } + return ctx, fmt.Errorf("triple attachments value with key = %s is invalid, which should be string or []string", key) + } + return ctx, nil +} + // parseInvocation retrieves information from invocation. // it returns ctx, callType, inRaw, method, error -func parseInvocation(ctx context.Context, url *common.URL, invocation protocol.Invocation) (context.Context, string, []interface{}, string, error) { +func parseInvocation(ctx context.Context, url *common.URL, invocation protocol.Invocation) (string, []interface{}, string, error) { callTypeRaw, ok := invocation.GetAttribute(constant.CallTypeKey) if !ok { - return nil, "", nil, "", errors.New("miss CallType in invocation to invoke TripleInvoker") + return "", nil, "", errors.New("miss CallType in invocation to invoke TripleInvoker") } callType, ok := callTypeRaw.(string) if !ok { - return nil, "", nil, "", fmt.Errorf("CallType should be string, but got %v", callTypeRaw) + return "", nil, "", fmt.Errorf("CallType should be string, but got %v", callTypeRaw) } // please refer to methods of client.Client or code generated by new triple for the usage of inRaw and inRawLen // e.g. Client.CallUnary(... req, resp []interface, ...) @@ -153,19 +177,19 @@ func parseInvocation(ctx context.Context, url *common.URL, invocation protocol.I inRaw := invocation.ParameterRawValues() method := invocation.MethodName() if method == "" { - return nil, "", nil, "", errors.New("miss MethodName in invocation to invoke TripleInvoker") + return "", nil, "", errors.New("miss MethodName in invocation to invoke TripleInvoker") } - ctx, err := parseAttachments(ctx, url, invocation) + err := parseAttachments(ctx, url, invocation) if err != nil { - return nil, "", nil, "", err + return "", nil, "", err } - return ctx, callType, inRaw, method, nil + return callType, inRaw, method, nil } // parseAttachments retrieves attachments from users passed-in and URL, then injects them into ctx -func parseAttachments(ctx context.Context, url *common.URL, invocation protocol.Invocation) (context.Context, error) { +func parseAttachments(ctx context.Context, url *common.URL, invocation protocol.Invocation) error { // retrieve users passed-in attachment attaRaw := ctx.Value(constant.AttachmentKey) if attaRaw != nil { @@ -181,22 +205,7 @@ func parseAttachments(ctx context.Context, url *common.URL, invocation protocol. invocation.SetAttachment(key, val) } } - // inject attachments - for key, valRaw := range invocation.Attachments() { - if str, ok := valRaw.(string); ok { - ctx = tri.AppendToOutgoingContext(ctx, key, str) - continue - } - if strs, ok := valRaw.([]string); ok { - for _, str := range strs { - ctx = tri.AppendToOutgoingContext(ctx, key, str) - } - continue - } - return nil, fmt.Errorf("triple attachments value with key = %s is invalid, which should be string or []string", key) - } - - return ctx, nil + return nil } // IsAvailable get available status From f45ff57ea32154cd0dd36c22fc263b6686fba903 Mon Sep 17 00:00:00 2001 From: YarBor Date: Wed, 10 Apr 2024 14:31:24 +0800 Subject: [PATCH 3/9] fic ci Signed-off-by: YarBor --- protocol/triple/triple_invoker_test.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/protocol/triple/triple_invoker_test.go b/protocol/triple/triple_invoker_test.go index 7d14d4239..98751ff18 100644 --- a/protocol/triple/triple_invoker_test.go +++ b/protocol/triple/triple_invoker_test.go @@ -19,13 +19,13 @@ package triple import ( "context" + tri "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol" "testing" "dubbo.apache.org/dubbo-go/v3/common" "dubbo.apache.org/dubbo-go/v3/common/constant" "dubbo.apache.org/dubbo-go/v3/protocol" "dubbo.apache.org/dubbo-go/v3/protocol/invocation" - tri "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol" "github.com/stretchr/testify/assert" ) @@ -35,7 +35,7 @@ func Test_parseInvocation(t *testing.T) { ctx func() context.Context url *common.URL invo func() protocol.Invocation - expect func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) + expect func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) }{ { desc: "miss callType", @@ -46,7 +46,7 @@ func Test_parseInvocation(t *testing.T) { invo: func() protocol.Invocation { return invocation.NewRPCInvocationWithOptions() }, - expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) { + expect: func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) { assert.NotNil(t, err) }, }, @@ -61,7 +61,7 @@ func Test_parseInvocation(t *testing.T) { iv.SetAttribute(constant.CallTypeKey, 1) return iv }, - expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) { + expect: func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) { assert.NotNil(t, err) }, }, @@ -76,7 +76,7 @@ func Test_parseInvocation(t *testing.T) { iv.SetAttribute(constant.CallTypeKey, constant.CallUnary) return iv }, - expect: func(t *testing.T, ctx context.Context, callType string, inRaw []interface{}, methodName string, err error) { + expect: func(t *testing.T, callType string, inRaw []interface{}, methodName string, err error) { assert.NotNil(t, err) }, }, @@ -84,8 +84,8 @@ func Test_parseInvocation(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - ctx, callType, inRaw, methodName, err := parseInvocation(test.ctx(), test.url, test.invo()) - test.expect(t, ctx, callType, inRaw, methodName, err) + callType, inRaw, methodName, err := parseInvocation(test.ctx(), test.url, test.invo()) + test.expect(t, callType, inRaw, methodName, err) }) } } @@ -157,7 +157,10 @@ func Test_parseAttachments(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - ctx, err := parseAttachments(test.ctx(), test.url, test.invo()) + ctx := test.ctx() + inv := test.invo() + err := parseAttachments(ctx, test.url, inv) + ctx, err = mergeAttachmentToOutgoing(ctx, inv.Attachments()) test.expect(t, ctx, err) }) } From 6d789670db20e66bb3ea574ed0159e47ab61c6c3 Mon Sep 17 00:00:00 2001 From: YarBor Date: Wed, 10 Apr 2024 17:54:22 +0800 Subject: [PATCH 4/9] feat:Implemented data transfer and return rewrite logic,Put the api into protocol/triple/triple_protocol layer used triple/triple_protocol/header.go's logic Signed-off-by: YarBor --- client/client.go | 20 ---- protocol/extra_data.go | 108 --------------------- protocol/extra_data_test.go | 93 ------------------ protocol/triple/client.go | 8 -- protocol/triple/triple_invoker.go | 4 +- protocol/triple/triple_invoker_test.go | 5 +- protocol/triple/triple_protocol/client.go | 6 +- protocol/triple/triple_protocol/handler.go | 4 + protocol/triple/triple_protocol/header.go | 89 +++++++++++++---- 9 files changed, 83 insertions(+), 254 deletions(-) delete mode 100644 protocol/extra_data.go delete mode 100644 protocol/extra_data_test.go diff --git a/client/client.go b/client/client.go index 7752d27a2..c7547d640 100644 --- a/client/client.go +++ b/client/client.go @@ -20,7 +20,6 @@ package client import ( "context" - "errors" ) import ( @@ -60,12 +59,6 @@ func (conn *Connection) call(ctx context.Context, reqs []interface{}, resp inter if err != nil { return nil, err } - if userData, ok := protocol.GetOutgoingData(ctx); ok { - err = addClientExtraDataToInvocation(inv, userData) - if err != nil { - return nil, err - } - } res := conn.refOpts.invoker.Invoke(ctx, inv) return res, nil } @@ -161,16 +154,3 @@ func NewClient(opts ...ClientOption) (*Client, error) { cliOpts: newCliOpts, }, nil } - -func addClientExtraDataToInvocation(inv protocol.Invocation, data map[string]interface{}) error { - for k, v := range data { - if str, ok := v.(string); ok { - inv.SetAttachment(k, []string{str}) - } else if strs, ok := v.([]string); ok { - inv.SetAttachment(k, strs) - } else { - return errors.New("ExtraData's type needs to be string or []string") - } - } - return nil -} diff --git a/protocol/extra_data.go b/protocol/extra_data.go deleted file mode 100644 index 1fbf05c70..000000000 --- a/protocol/extra_data.go +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package protocol - -import ( - "context" - "github.com/pkg/errors" -) - -type rpcExtraDataKey struct{} - -const outgoingKey = "outgoingKey" -const incomingKey = "incomingKey" - -var _ *rpcExtraDataKey = (*rpcExtraDataKey)(nil) - -func SetIncomingData(ctx context.Context, extraData map[string]interface{}) context.Context { - data, ok := ctx.Value(rpcExtraDataKey{}).(map[string]interface{}) - if !ok { - data = map[string]interface{}{} - ctx = context.WithValue(ctx, rpcExtraDataKey{}, data) - } - data[incomingKey] = extraData - return ctx -} - -func GetIncomingData(ctx context.Context) (map[string]interface{}, bool) { - if data, ok := ctx.Value(rpcExtraDataKey{}).(map[string]interface{}); !ok { - return nil, false - } else { - if incomingDataInterface, ok := data[incomingKey]; !ok { - return nil, false - } else if incomingData, ok := incomingDataInterface.(map[string]interface{}); !ok { - return nil, false - } else { - // may need copy here ? - return incomingData, true - } - } -} - -func AppendIncomingData(ctx context.Context, kv ...string) (context.Context, error) { - if kv == nil || len(kv)%2 != 0 { - return ctx, errors.New("kv must be a non-nil slice with an even number of elements") - } - incomingData, ok := GetIncomingData(ctx) - if !ok { - incomingData = map[string]interface{}{} - } - for i := 0; i < len(kv); i += 2 { - incomingData[kv[i]] = []string{kv[i+1]} - } - return SetIncomingData(ctx, incomingData), nil -} - -func SetOutgoingData(ctx context.Context, extraData map[string]interface{}) context.Context { - data, ok := ctx.Value(rpcExtraDataKey{}).(map[string]interface{}) - if !ok { - data = map[string]interface{}{} - ctx = context.WithValue(ctx, rpcExtraDataKey{}, data) - } - data[outgoingKey] = extraData - return ctx -} - -func GetOutgoingData(ctx context.Context) (map[string]interface{}, bool) { - if data, ok := ctx.Value(rpcExtraDataKey{}).(map[string]interface{}); !ok { - return nil, false - } else { - if OutgoingDataInterface, ok := data[outgoingKey]; !ok { - return nil, false - } else if OutgoingData, ok := OutgoingDataInterface.(map[string]interface{}); !ok { - return nil, false - } else { - // may need copy here ? - return OutgoingData, true - } - } -} - -func AppendOutgoingData(ctx context.Context, kv ...string) (context.Context, error) { - if kv == nil || len(kv)%2 != 0 { - return ctx, errors.New("kv must be a non-nil slice with an even number of elements") - } - OutgoingData, ok := GetOutgoingData(ctx) - if !ok { - OutgoingData = map[string]interface{}{} - } - for i := 0; i < len(kv); i += 2 { - OutgoingData[kv[i]] = []string{kv[i+1]} - } - return SetOutgoingData(ctx, OutgoingData), nil -} diff --git a/protocol/extra_data_test.go b/protocol/extra_data_test.go deleted file mode 100644 index 7a260f6f1..000000000 --- a/protocol/extra_data_test.go +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package protocol - -import ( - "context" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestIncomingData(t *testing.T) { - ctx := context.Background() - data := map[string]interface{}{"key": []string{"value"}, "key2": []string{"value2"}} - // set check - ctx = SetIncomingData(ctx, data) - testData := ctx.Value(rpcExtraDataKey{}).(map[string]interface{})[incomingKey].(map[string]interface{}) - assert.Equal(t, data, testData) - // append check - ctx, err := AppendIncomingData(ctx, "hello", "world") - assert.Equal(t, nil, err) - data["hello"] = "world" - testData = ctx.Value(rpcExtraDataKey{}).(map[string]interface{})[incomingKey].(map[string]interface{}) - assert.Equal(t, data, testData) - // get check - testData, ok := GetIncomingData(ctx) - assert.Equal(t, true, ok) - assert.Equal(t, data, testData) - - // test err - ctx = context.Background() - testData, ok = GetIncomingData(ctx) - assert.Equal(t, false, ok) - assert.Nil(t, testData) - - ctx, err = AppendIncomingData(ctx, "hello", "world") - assert.Equal(t, nil, err) - assert.Equal(t, map[string]interface{}{ - "hello": []string{"world"}, - }, ctx.Value(rpcExtraDataKey{}).(map[string]interface{})[incomingKey].(map[string]interface{})) - - ctx1, err := AppendIncomingData(ctx, "hello", "world", "err input") - assert.Equal(t, ctx, ctx1) - assert.NotEqual(t, nil, err) -} - -func TestOutgoingData(t *testing.T) { - ctx := context.Background() - data := map[string]interface{}{"key": []string{"value"}, "key2": []string{"value2"}} - // set check - ctx = SetOutgoingData(ctx, data) - testData := ctx.Value(rpcExtraDataKey{}).(map[string]interface{})[outgoingKey].(map[string]interface{}) - assert.Equal(t, data, testData) - // append check - ctx, err := AppendOutgoingData(ctx, "hello", "world") - assert.Equal(t, nil, err) - data["hello"] = "world" - testData = ctx.Value(rpcExtraDataKey{}).(map[string]interface{})[outgoingKey].(map[string]interface{}) - assert.Equal(t, data, testData) - // get check - testData, ok := GetOutgoingData(ctx) - assert.Equal(t, true, ok) - assert.Equal(t, data, testData) - - // test err - ctx = context.Background() - testData, ok = GetOutgoingData(ctx) - assert.Equal(t, false, ok) - assert.Nil(t, testData) - - ctx, err = AppendOutgoingData(ctx, "hello", "world") - assert.Equal(t, nil, err) - assert.Equal(t, map[string]interface{}{ - "hello": []string{"world"}, - }, ctx.Value(rpcExtraDataKey{}).(map[string]interface{})[outgoingKey].(map[string]interface{})) - - ctx1, err := AppendOutgoingData(ctx, "hello", "world", "err input") - assert.Equal(t, ctx, ctx1) - assert.NotEqual(t, nil, err) -} diff --git a/protocol/triple/client.go b/protocol/triple/client.go index 6b37aff70..66a3e5331 100644 --- a/protocol/triple/client.go +++ b/protocol/triple/client.go @@ -20,7 +20,6 @@ package triple import ( "context" "crypto/tls" - "dubbo.apache.org/dubbo-go/v3/protocol" "fmt" "net" "net/http" @@ -71,13 +70,6 @@ func (cm *clientManager) callUnary(ctx context.Context, method string, req, resp if err := triClient.CallUnary(ctx, triReq, triResp); err != nil { return err } - if _, ok := protocol.GetOutgoingData(ctx); ok { - rtExtraData := map[string]interface{}{} - for k, v := range triResp.Trailer() { - rtExtraData[k] = v - } - protocol.SetIncomingData(ctx, rtExtraData) - } return nil } diff --git a/protocol/triple/triple_invoker.go b/protocol/triple/triple_invoker.go index 0c0847cce..5bab45d80 100644 --- a/protocol/triple/triple_invoker.go +++ b/protocol/triple/triple_invoker.go @@ -146,12 +146,12 @@ func (ti *TripleInvoker) Invoke(ctx context.Context, invocation protocol.Invocat func mergeAttachmentToOutgoing(ctx context.Context, attachments map[string]interface{}) (context.Context, error) { for key, valRaw := range attachments { if str, ok := valRaw.(string); ok { - ctx = tri.AppendToOutgoingContext(ctx, key, str) + ctx, _ = tri.AppendToOutgoingContext(ctx, key, str) continue } if strs, ok := valRaw.([]string); ok { for _, str := range strs { - ctx = tri.AppendToOutgoingContext(ctx, key, str) + ctx, _ = tri.AppendToOutgoingContext(ctx, key, str) } continue } diff --git a/protocol/triple/triple_invoker_test.go b/protocol/triple/triple_invoker_test.go index 98751ff18..ad1782d6d 100644 --- a/protocol/triple/triple_invoker_test.go +++ b/protocol/triple/triple_invoker_test.go @@ -20,6 +20,7 @@ package triple import ( "context" tri "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol" + "net/http" "testing" "dubbo.apache.org/dubbo-go/v3/common" @@ -112,7 +113,7 @@ func Test_parseAttachments(t *testing.T) { }, expect: func(t *testing.T, ctx context.Context, err error) { assert.Nil(t, err) - header := tri.ExtractFromOutgoingContext(ctx) + header := http.Header(tri.ExtractFromOutgoingContext(ctx)) assert.NotNil(t, header) assert.Equal(t, "interface", header.Get(constant.InterfaceKey)) assert.Equal(t, "token", header.Get(constant.TokenKey)) @@ -132,7 +133,7 @@ func Test_parseAttachments(t *testing.T) { }, expect: func(t *testing.T, ctx context.Context, err error) { assert.Nil(t, err) - header := tri.ExtractFromOutgoingContext(ctx) + header := http.Header(tri.ExtractFromOutgoingContext(ctx)) assert.NotNil(t, header) assert.Equal(t, "val1", header.Get("key1")) assert.Equal(t, []string{"key2_1", "key2_2"}, header.Values("key2")) diff --git a/protocol/triple/triple_protocol/client.go b/protocol/triple/triple_protocol/client.go index 8ba67eec2..66d87037c 100644 --- a/protocol/triple/triple_protocol/client.go +++ b/protocol/triple/triple_protocol/client.go @@ -94,7 +94,11 @@ func NewClient(httpClient HTTPClient, url string, options ...ClientOption) *Clie _ = conn.CloseResponse() return err } - return conn.CloseResponse() + err := conn.CloseResponse() + if err == nil && ExtractFromOutgoingContext(ctx) != nil { + newIncomingContext(ctx, conn.ResponseTrailer()) + } + return err }) if interceptor := config.Interceptor; interceptor != nil { unaryFunc = interceptor.WrapUnary(unaryFunc) diff --git a/protocol/triple/triple_protocol/handler.go b/protocol/triple/triple_protocol/handler.go index 56f83b3fe..998019ec5 100644 --- a/protocol/triple/triple_protocol/handler.go +++ b/protocol/triple/triple_protocol/handler.go @@ -112,6 +112,10 @@ func generateUnaryHandlerFunc( // merge headers mergeHeaders(conn.ResponseHeader(), response.Header()) mergeHeaders(conn.ResponseTrailer(), response.Trailer()) + //Write the server-side return-attachment-data in the tailer to send to the caller + if data := ExtractFromOutgoingContext(ctx); data != nil { + mergeHeaders(conn.ResponseTrailer(), data) + } return conn.Send(response.Any()) } diff --git a/protocol/triple/triple_protocol/header.go b/protocol/triple/triple_protocol/header.go index 28618e0dc..d79662bac 100644 --- a/protocol/triple/triple_protocol/header.go +++ b/protocol/triple/triple_protocol/header.go @@ -18,7 +18,9 @@ import ( "context" "encoding/base64" "fmt" + "github.com/pkg/errors" "net/http" + "strings" ) // EncodeBinaryHeader base64-encodes the data. It always emits unpadded values. @@ -88,61 +90,100 @@ func addHeaderCanonical(h http.Header, key, value string) { h[key] = append(h[key], value) } -type headerIncomingKey struct{} -type headerOutgoingKey struct{} +type extraDataKey struct{} + +const headerIncomingKey string = "headerIncomingKey" +const headerOutgoingKey string = "headerOutgoingKey" + type handlerOutgoingKey struct{} func newIncomingContext(ctx context.Context, header http.Header) context.Context { - return context.WithValue(ctx, headerIncomingKey{}, header) + extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) + if !ok { + extraData = map[string]http.Header{} + } + if header == nil { + header = make(http.Header) + } + extraData[headerIncomingKey] = header + return context.WithValue(ctx, extraDataKey{}, extraData) } // NewOutgoingContext sets headers entirely. If there are existing headers, they would be replaced. // It is used for passing headers to server-side. // It is like grpc.NewOutgoingContext. // Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#sending-metadata. -func NewOutgoingContext(ctx context.Context, header http.Header) context.Context { - return context.WithValue(ctx, headerOutgoingKey{}, header) +func NewOutgoingContext(ctx context.Context, data interface{}) (context.Context, error) { + header := make(http.Header) + if inputData, ok := data.(map[string]string); ok { + for k, v := range inputData { + + header.Add(k, v) + } + } else if inputData, ok := data.(map[string][]string); ok { + header = inputData + } else if inputData, ok := data.(http.Header); ok { + header = inputData + } else { + return ctx, errors.New("IncomingContext data must be map[string]string or map[string][]string") + } + + extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) + if !ok { + extraData = map[string]http.Header{} + } + extraData[headerOutgoingKey] = header + return context.WithValue(ctx, extraDataKey{}, extraData), nil } // AppendToOutgoingContext merges kv pairs from user and existing headers. // It is used for passing headers to server-side. // It is like grpc.AppendToOutgoingContext. // Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#sending-metadata. -func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context { +func AppendToOutgoingContext(ctx context.Context, kv ...string) (context.Context, error) { if len(kv)%2 == 1 { panic(fmt.Sprintf("AppendToOutgoingContext got an odd number of input pairs for header: %d", len(kv))) } - var header http.Header - headerRaw := ctx.Value(headerOutgoingKey{}) - if headerRaw == nil { + extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) + if !ok { + extraData = map[string]http.Header{} + ctx = context.WithValue(ctx, extraDataKey{}, extraData) + } + header, ok := extraData[headerOutgoingKey] + if !ok { header = make(http.Header) - } else { - header = headerRaw.(http.Header) + extraData[headerOutgoingKey] = header } for i := 0; i < len(kv); i += 2 { // todo(DMwangnima): think about lowering header.Add(kv[i], kv[i+1]) } - return context.WithValue(ctx, headerOutgoingKey{}, header) + return ctx, nil } -func ExtractFromOutgoingContext(ctx context.Context) http.Header { - headerRaw := ctx.Value(headerOutgoingKey{}) - if headerRaw == nil { +func ExtractFromOutgoingContext(ctx context.Context) map[string][]string { + extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) + if !ok { + return nil + } else if outGoingDataHeader, ok := extraData[headerOutgoingKey]; !ok { return nil + } else { + return outGoingDataHeader } - // since headerOutgoingKey is only used in triple_protocol package, we need not verify the type - return headerRaw.(http.Header) } // FromIncomingContext retrieves headers passed by client-side. It is like grpc.FromIncomingContext. +// it must call after append/setOutgoingContext to return current value // Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#receiving-metadata-1. -func FromIncomingContext(ctx context.Context) (http.Header, bool) { - header, ok := ctx.Value(headerIncomingKey{}).(http.Header) +func FromIncomingContext(ctx context.Context) (map[string][]string, bool) { + data, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) if !ok { return nil, false + } else if incomingDataHeader, ok := data[headerIncomingKey]; !ok { + return nil, false + } else { + return incomingDataHeader, true } - return header, true } // SetHeader is used for setting response header in server-side. It is like grpc.SendHeader(ctx, header) but @@ -181,3 +222,11 @@ func SendHeader(ctx context.Context, header http.Header) error { mergeHeaders(conn.RequestHeader(), header) return conn.Send(nil) } + +func outGoingKeyCheck(key string) bool { + if len(key) < 1 { + return false + } + firstLetter := string(key[0]) + return firstLetter == strings.ToUpper(firstLetter) +} From 4a1a2204fff8cefdd626fc296fbb14898a52be3b Mon Sep 17 00:00:00 2001 From: YarBor Date: Wed, 10 Apr 2024 19:00:10 +0800 Subject: [PATCH 5/9] fix ci and format import Signed-off-by: YarBor --- protocol/triple/triple_invoker.go | 8 ++------ protocol/triple/triple_invoker_test.go | 6 +++--- protocol/triple/triple_protocol/header.go | 11 ++++++++++- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/protocol/triple/triple_invoker.go b/protocol/triple/triple_invoker.go index 5bab45d80..b6ae119fc 100644 --- a/protocol/triple/triple_invoker.go +++ b/protocol/triple/triple_invoker.go @@ -180,16 +180,13 @@ func parseInvocation(ctx context.Context, url *common.URL, invocation protocol.I return "", nil, "", errors.New("miss MethodName in invocation to invoke TripleInvoker") } - err := parseAttachments(ctx, url, invocation) - if err != nil { - return "", nil, "", err - } + parseAttachments(ctx, url, invocation) return callType, inRaw, method, nil } // parseAttachments retrieves attachments from users passed-in and URL, then injects them into ctx -func parseAttachments(ctx context.Context, url *common.URL, invocation protocol.Invocation) error { +func parseAttachments(ctx context.Context, url *common.URL, invocation protocol.Invocation) { // retrieve users passed-in attachment attaRaw := ctx.Value(constant.AttachmentKey) if attaRaw != nil { @@ -205,7 +202,6 @@ func parseAttachments(ctx context.Context, url *common.URL, invocation protocol. invocation.SetAttachment(key, val) } } - return nil } // IsAvailable get available status diff --git a/protocol/triple/triple_invoker_test.go b/protocol/triple/triple_invoker_test.go index ad1782d6d..d08344eff 100644 --- a/protocol/triple/triple_invoker_test.go +++ b/protocol/triple/triple_invoker_test.go @@ -19,7 +19,6 @@ package triple import ( "context" - tri "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol" "net/http" "testing" @@ -27,6 +26,7 @@ import ( "dubbo.apache.org/dubbo-go/v3/common/constant" "dubbo.apache.org/dubbo-go/v3/protocol" "dubbo.apache.org/dubbo-go/v3/protocol/invocation" + tri "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol" "github.com/stretchr/testify/assert" ) @@ -160,8 +160,8 @@ func Test_parseAttachments(t *testing.T) { t.Run(test.desc, func(t *testing.T) { ctx := test.ctx() inv := test.invo() - err := parseAttachments(ctx, test.url, inv) - ctx, err = mergeAttachmentToOutgoing(ctx, inv.Attachments()) + parseAttachments(ctx, test.url, inv) + ctx, err := mergeAttachmentToOutgoing(ctx, inv.Attachments()) test.expect(t, ctx, err) }) } diff --git a/protocol/triple/triple_protocol/header.go b/protocol/triple/triple_protocol/header.go index d79662bac..83a864f36 100644 --- a/protocol/triple/triple_protocol/header.go +++ b/protocol/triple/triple_protocol/header.go @@ -117,13 +117,22 @@ func NewOutgoingContext(ctx context.Context, data interface{}) (context.Context, header := make(http.Header) if inputData, ok := data.(map[string]string); ok { for k, v := range inputData { - header.Add(k, v) } } else if inputData, ok := data.(map[string][]string); ok { header = inputData } else if inputData, ok := data.(http.Header); ok { header = inputData + } else if inputData, ok := data.(map[string]interface{}); ok { + for k, v := range inputData { + if val, ok := v.(string); ok { + header[k] = []string{val} + } else if val, ok := v.([]string); ok { + header[k] = val + } else { + return ctx, errors.New("IncomingContext data must be map[string]string or map[string][]string") + } + } } else { return ctx, errors.New("IncomingContext data must be map[string]string or map[string][]string") } From 0d1eb75b17f809f93c73e504d76be40c5f5a4aef Mon Sep 17 00:00:00 2001 From: YarBor Date: Mon, 15 Apr 2024 09:33:57 +0800 Subject: [PATCH 6/9] change api back Signed-off-by: YarBor --- client/client.go | 3 +-- protocol/triple/triple_protocol/header.go | 32 +++++------------------ 2 files changed, 8 insertions(+), 27 deletions(-) diff --git a/client/client.go b/client/client.go index c7547d640..a2bcaef1e 100644 --- a/client/client.go +++ b/client/client.go @@ -59,8 +59,7 @@ func (conn *Connection) call(ctx context.Context, reqs []interface{}, resp inter if err != nil { return nil, err } - res := conn.refOpts.invoker.Invoke(ctx, inv) - return res, nil + return conn.refOpts.invoker.Invoke(ctx, inv), nil } func (conn *Connection) CallUnary(ctx context.Context, reqs []interface{}, resp interface{}, methodName string, opts ...CallOption) error { diff --git a/protocol/triple/triple_protocol/header.go b/protocol/triple/triple_protocol/header.go index 83a864f36..044792a55 100644 --- a/protocol/triple/triple_protocol/header.go +++ b/protocol/triple/triple_protocol/header.go @@ -18,7 +18,6 @@ import ( "context" "encoding/base64" "fmt" - "github.com/pkg/errors" "net/http" "strings" ) @@ -113,30 +112,13 @@ func newIncomingContext(ctx context.Context, header http.Header) context.Context // It is used for passing headers to server-side. // It is like grpc.NewOutgoingContext. // Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#sending-metadata. -func NewOutgoingContext(ctx context.Context, data interface{}) (context.Context, error) { - header := make(http.Header) - if inputData, ok := data.(map[string]string); ok { - for k, v := range inputData { - header.Add(k, v) - } - } else if inputData, ok := data.(map[string][]string); ok { - header = inputData - } else if inputData, ok := data.(http.Header); ok { - header = inputData - } else if inputData, ok := data.(map[string]interface{}); ok { - for k, v := range inputData { - if val, ok := v.(string); ok { - header[k] = []string{val} - } else if val, ok := v.([]string); ok { - header[k] = val - } else { - return ctx, errors.New("IncomingContext data must be map[string]string or map[string][]string") - } - } +func NewOutgoingContext(ctx context.Context, data http.Header) (context.Context, error) { + var header http.Header + if data == nil { + header = make(http.Header) } else { - return ctx, errors.New("IncomingContext data must be map[string]string or map[string][]string") + header = data.Clone() } - extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) if !ok { extraData = map[string]http.Header{} @@ -170,7 +152,7 @@ func AppendToOutgoingContext(ctx context.Context, kv ...string) (context.Context return ctx, nil } -func ExtractFromOutgoingContext(ctx context.Context) map[string][]string { +func ExtractFromOutgoingContext(ctx context.Context) http.Header { extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) if !ok { return nil @@ -184,7 +166,7 @@ func ExtractFromOutgoingContext(ctx context.Context) map[string][]string { // FromIncomingContext retrieves headers passed by client-side. It is like grpc.FromIncomingContext. // it must call after append/setOutgoingContext to return current value // Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#receiving-metadata-1. -func FromIncomingContext(ctx context.Context) (map[string][]string, bool) { +func FromIncomingContext(ctx context.Context) (http.Header, bool) { data, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) if !ok { return nil, false From 731373265401ef43658e0d7d748002691fe0bb3f Mon Sep 17 00:00:00 2001 From: YarBor Date: Tue, 16 Apr 2024 09:48:12 +0800 Subject: [PATCH 7/9] Change the writing rules of incoming/outgoing key-value pairs Signed-off-by: YarBor --- protocol/triple/triple_protocol/header.go | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/protocol/triple/triple_protocol/header.go b/protocol/triple/triple_protocol/header.go index 044792a55..0f7a87c11 100644 --- a/protocol/triple/triple_protocol/header.go +++ b/protocol/triple/triple_protocol/header.go @@ -96,13 +96,14 @@ const headerOutgoingKey string = "headerOutgoingKey" type handlerOutgoingKey struct{} -func newIncomingContext(ctx context.Context, header http.Header) context.Context { +func newIncomingContext(ctx context.Context, data http.Header) context.Context { + var header = http.Header{} extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) if !ok { extraData = map[string]http.Header{} } - if header == nil { - header = make(http.Header) + for key, vals := range data { + header[strings.ToLower(key)] = vals } extraData[headerIncomingKey] = header return context.WithValue(ctx, extraDataKey{}, extraData) @@ -113,11 +114,9 @@ func newIncomingContext(ctx context.Context, header http.Header) context.Context // It is like grpc.NewOutgoingContext. // Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#sending-metadata. func NewOutgoingContext(ctx context.Context, data http.Header) (context.Context, error) { - var header http.Header - if data == nil { - header = make(http.Header) - } else { - header = data.Clone() + var header = http.Header{} + for key, vals := range data { + header[strings.ToLower(key)] = vals } extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) if !ok { @@ -147,7 +146,7 @@ func AppendToOutgoingContext(ctx context.Context, kv ...string) (context.Context } for i := 0; i < len(kv); i += 2 { // todo(DMwangnima): think about lowering - header.Add(kv[i], kv[i+1]) + header.Add(strings.ToLower(kv[i]), kv[i+1]) } return ctx, nil } From d09d80f775e28d37008666fc34c1f362eaab30cb Mon Sep 17 00:00:00 2001 From: YarBor Date: Wed, 17 Apr 2024 19:22:14 +0800 Subject: [PATCH 8/9] Rewrite the logic of returning ExtraData, now the new api applies to all 4 calling methods. Signed-off-by: YarBor --- protocol/triple/triple_protocol/client.go | 6 +---- .../triple_protocol/duplex_http_call.go | 4 ++++ protocol/triple/triple_protocol/handler.go | 24 +++++++++++++++---- protocol/triple/triple_protocol/header.go | 23 ++++++++---------- 4 files changed, 34 insertions(+), 23 deletions(-) diff --git a/protocol/triple/triple_protocol/client.go b/protocol/triple/triple_protocol/client.go index 66d87037c..8ba67eec2 100644 --- a/protocol/triple/triple_protocol/client.go +++ b/protocol/triple/triple_protocol/client.go @@ -94,11 +94,7 @@ func NewClient(httpClient HTTPClient, url string, options ...ClientOption) *Clie _ = conn.CloseResponse() return err } - err := conn.CloseResponse() - if err == nil && ExtractFromOutgoingContext(ctx) != nil { - newIncomingContext(ctx, conn.ResponseTrailer()) - } - return err + return conn.CloseResponse() }) if interceptor := config.Interceptor; interceptor != nil { unaryFunc = interceptor.WrapUnary(unaryFunc) diff --git a/protocol/triple/triple_protocol/duplex_http_call.go b/protocol/triple/triple_protocol/duplex_http_call.go index 3865a56ab..6c4c02176 100644 --- a/protocol/triple/triple_protocol/duplex_http_call.go +++ b/protocol/triple/triple_protocol/duplex_http_call.go @@ -184,6 +184,10 @@ func (d *duplexHTTPCall) CloseRead() error { if err := discard(d.response.Body); err != nil { return wrapIfRSTError(err) } + // Return incoming data via context, if set outgoing data. + if ExtractFromOutgoingContext(d.ctx) != nil { + newIncomingContext(d.ctx, d.ResponseTrailer()) + } return wrapIfRSTError(d.response.Body.Close()) } diff --git a/protocol/triple/triple_protocol/handler.go b/protocol/triple/triple_protocol/handler.go index 998019ec5..7a44f8535 100644 --- a/protocol/triple/triple_protocol/handler.go +++ b/protocol/triple/triple_protocol/handler.go @@ -164,6 +164,9 @@ func generateClientStreamHandlerFunc( } mergeHeaders(conn.ResponseHeader(), res.header) mergeHeaders(conn.ResponseTrailer(), res.trailer) + if outgoingData := ExtractFromOutgoingContext(ctx); outgoingData != nil { + mergeHeaders(conn.ResponseTrailer(), outgoingData) + } return conn.Send(res.Msg) } if interceptor != nil { @@ -209,7 +212,7 @@ func generateServerStreamHandlerFunc( } // embed header in context so that user logic could process them via FromIncomingContext ctx = newIncomingContext(ctx, conn.RequestHeader()) - return streamFunc( + err := streamFunc( ctx, &Request{ Msg: req, @@ -219,6 +222,13 @@ func generateServerStreamHandlerFunc( }, &ServerStream{conn: conn}, ) + if err != nil { + return err + } + if outgoingData := ExtractFromOutgoingContext(ctx); outgoingData != nil { + mergeHeaders(conn.ResponseTrailer(), outgoingData) + } + return nil } if interceptor != nil { implementation = interceptor.WrapStreamingHandler(implementation) @@ -257,10 +267,14 @@ func generateBidiStreamHandlerFunc( implementation := func(ctx context.Context, conn StreamingHandlerConn) error { // embed header in context so that user logic could process them via FromIncomingContext ctx = newIncomingContext(ctx, conn.RequestHeader()) - return streamFunc( - ctx, - &BidiStream{conn: conn}, - ) + err := streamFunc(ctx, &BidiStream{conn: conn}) + if err != nil { + return err + } + if outgoingData := ExtractFromOutgoingContext(ctx); outgoingData != nil { + mergeHeaders(conn.ResponseTrailer(), outgoingData) + } + return nil } if interceptor != nil { implementation = interceptor.WrapStreamingHandler(implementation) diff --git a/protocol/triple/triple_protocol/header.go b/protocol/triple/triple_protocol/header.go index 0f7a87c11..f5ec28ccb 100644 --- a/protocol/triple/triple_protocol/header.go +++ b/protocol/triple/triple_protocol/header.go @@ -102,8 +102,10 @@ func newIncomingContext(ctx context.Context, data http.Header) context.Context { if !ok { extraData = map[string]http.Header{} } - for key, vals := range data { - header[strings.ToLower(key)] = vals + if data != nil { + for key, vals := range data { + header[strings.ToLower(key)] = vals + } } extraData[headerIncomingKey] = header return context.WithValue(ctx, extraDataKey{}, extraData) @@ -115,8 +117,10 @@ func newIncomingContext(ctx context.Context, data http.Header) context.Context { // Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#sending-metadata. func NewOutgoingContext(ctx context.Context, data http.Header) (context.Context, error) { var header = http.Header{} - for key, vals := range data { - header[strings.ToLower(key)] = vals + if data != nil { + for key, vals := range data { + header[strings.ToLower(key)] = vals + } } extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) if !ok { @@ -155,7 +159,8 @@ func ExtractFromOutgoingContext(ctx context.Context) http.Header { extraData, ok := ctx.Value(extraDataKey{}).(map[string]http.Header) if !ok { return nil - } else if outGoingDataHeader, ok := extraData[headerOutgoingKey]; !ok { + } + if outGoingDataHeader, ok := extraData[headerOutgoingKey]; !ok { return nil } else { return outGoingDataHeader @@ -212,11 +217,3 @@ func SendHeader(ctx context.Context, header http.Header) error { mergeHeaders(conn.RequestHeader(), header) return conn.Send(nil) } - -func outGoingKeyCheck(key string) bool { - if len(key) < 1 { - return false - } - firstLetter := string(key[0]) - return firstLetter == strings.ToUpper(firstLetter) -} From 44965e470e3c705c06c7221ff21f7408212d99d3 Mon Sep 17 00:00:00 2001 From: YarBor Date: Wed, 17 Apr 2024 19:50:30 +0800 Subject: [PATCH 9/9] fix comment Signed-off-by: YarBor --- protocol/triple/triple_invoker.go | 10 +++++----- protocol/triple/triple_invoker_test.go | 2 +- protocol/triple/triple_protocol/header.go | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/protocol/triple/triple_invoker.go b/protocol/triple/triple_invoker.go index b6ae119fc..4c1eb6868 100644 --- a/protocol/triple/triple_invoker.go +++ b/protocol/triple/triple_invoker.go @@ -87,7 +87,7 @@ func (ti *TripleInvoker) Invoke(ctx context.Context, invocation protocol.Invocat return &result } - ctx, err = mergeAttachmentToOutgoing(ctx, invocation.Attachments()) + ctx, err = mergeAttachmentToOutgoing(ctx, invocation) if err != nil { result.SetError(err) return &result @@ -143,15 +143,15 @@ func (ti *TripleInvoker) Invoke(ctx context.Context, invocation protocol.Invocat return &result } -func mergeAttachmentToOutgoing(ctx context.Context, attachments map[string]interface{}) (context.Context, error) { - for key, valRaw := range attachments { +func mergeAttachmentToOutgoing(ctx context.Context, inv protocol.Invocation) (context.Context, error) { + for key, valRaw := range inv.Attachments() { if str, ok := valRaw.(string); ok { - ctx, _ = tri.AppendToOutgoingContext(ctx, key, str) + ctx = tri.AppendToOutgoingContext(ctx, key, str) continue } if strs, ok := valRaw.([]string); ok { for _, str := range strs { - ctx, _ = tri.AppendToOutgoingContext(ctx, key, str) + ctx = tri.AppendToOutgoingContext(ctx, key, str) } continue } diff --git a/protocol/triple/triple_invoker_test.go b/protocol/triple/triple_invoker_test.go index d08344eff..e9dcc968c 100644 --- a/protocol/triple/triple_invoker_test.go +++ b/protocol/triple/triple_invoker_test.go @@ -161,7 +161,7 @@ func Test_parseAttachments(t *testing.T) { ctx := test.ctx() inv := test.invo() parseAttachments(ctx, test.url, inv) - ctx, err := mergeAttachmentToOutgoing(ctx, inv.Attachments()) + ctx, err := mergeAttachmentToOutgoing(ctx, inv) test.expect(t, ctx, err) }) } diff --git a/protocol/triple/triple_protocol/header.go b/protocol/triple/triple_protocol/header.go index f5ec28ccb..791cf4a30 100644 --- a/protocol/triple/triple_protocol/header.go +++ b/protocol/triple/triple_protocol/header.go @@ -115,7 +115,7 @@ func newIncomingContext(ctx context.Context, data http.Header) context.Context { // It is used for passing headers to server-side. // It is like grpc.NewOutgoingContext. // Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#sending-metadata. -func NewOutgoingContext(ctx context.Context, data http.Header) (context.Context, error) { +func NewOutgoingContext(ctx context.Context, data http.Header) context.Context { var header = http.Header{} if data != nil { for key, vals := range data { @@ -127,14 +127,14 @@ func NewOutgoingContext(ctx context.Context, data http.Header) (context.Context, extraData = map[string]http.Header{} } extraData[headerOutgoingKey] = header - return context.WithValue(ctx, extraDataKey{}, extraData), nil + return context.WithValue(ctx, extraDataKey{}, extraData) } // AppendToOutgoingContext merges kv pairs from user and existing headers. // It is used for passing headers to server-side. // It is like grpc.AppendToOutgoingContext. // Please refer to https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md#sending-metadata. -func AppendToOutgoingContext(ctx context.Context, kv ...string) (context.Context, error) { +func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context { if len(kv)%2 == 1 { panic(fmt.Sprintf("AppendToOutgoingContext got an odd number of input pairs for header: %d", len(kv))) } @@ -152,7 +152,7 @@ func AppendToOutgoingContext(ctx context.Context, kv ...string) (context.Context // todo(DMwangnima): think about lowering header.Add(strings.ToLower(kv[i]), kv[i+1]) } - return ctx, nil + return ctx } func ExtractFromOutgoingContext(ctx context.Context) http.Header {