diff --git a/agent/agent_test.go b/agent/agent_test.go index 89877bee..27f8ded0 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -22,6 +22,7 @@ package agent import ( "context" + "encoding/json" "errors" "fmt" "math/rand" @@ -33,6 +34,8 @@ import ( "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/topfreegames/pitaya/v2/conn/codec" codecmocks "github.com/topfreegames/pitaya/v2/conn/codec/mocks" "github.com/topfreegames/pitaya/v2/conn/message" messagemocks "github.com/topfreegames/pitaya/v2/conn/message/mocks" @@ -45,6 +48,7 @@ import ( metricsmocks "github.com/topfreegames/pitaya/v2/metrics/mocks" "github.com/topfreegames/pitaya/v2/mocks" "github.com/topfreegames/pitaya/v2/protos" + "github.com/topfreegames/pitaya/v2/serialize" serializemocks "github.com/topfreegames/pitaya/v2/serialize/mocks" "github.com/topfreegames/pitaya/v2/session" ) @@ -520,6 +524,7 @@ func TestAgentResponseMID(t *testing.T) { expected := pendingWrite{ctx: ctx, data: []byte("ok!"), err: nil} var err error if table.msgErr { + mockSerializer.EXPECT().Unmarshal(gomock.Any(), gomock.Any()).Return(nil) err = ag.ResponseMID(ctx, table.mid, table.data, table.msgErr) } else { err = ag.ResponseMID(ctx, table.mid, table.data) @@ -837,19 +842,39 @@ func TestAgentSendHandshakeResponse(t *testing.T) { } func TestAnswerWithError(t *testing.T) { - tables := []struct { + unknownError := e.NewError(errors.New(""), e.ErrUnknownCode) + table := []struct { name string + answeredErr error + encoderErr error getPayloadErr error - resErr error - err error + expectedErr error }{ - {"success", nil, nil, nil}, - {"failure_get_payload", errors.New("serialize err"), nil, errors.New("serialize err")}, - {"failure_response_mid", nil, errors.New("responsemid err"), errors.New("responsemid err")}, + { + name: "should succeed with unknown error", + answeredErr: assert.AnError, + encoderErr: nil, + getPayloadErr: nil, + expectedErr: unknownError, + }, + { + name: "should not answer if fails to get payload", + answeredErr: assert.AnError, + encoderErr: nil, + getPayloadErr: errors.New("serialize err"), + expectedErr: nil, + }, + { + name: "should not answer if fails to send", + answeredErr: assert.AnError, + encoderErr: assert.AnError, + getPayloadErr: nil, + expectedErr: nil, + }, } - for _, table := range tables { - t.Run(table.name, func(t *testing.T) { + for _, row := range table { + t.Run(row.name, func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -862,18 +887,97 @@ func TestAnswerWithError(t *testing.T) { ag := newAgent(nil, nil, mockEncoder, mockSerializer, time.Second, 1, nil, messageEncoder, nil, sessionPool).(*agentImpl) assert.NotNil(t, ag) - mockSerializer.EXPECT().Marshal(gomock.Any()).Return(nil, table.getPayloadErr) - if table.getPayloadErr == nil { - mockEncoder.EXPECT().Encode(packet.Type(packet.Data), gomock.Any()) - } - ag.AnswerWithError(nil, uint(rand.Int()), errors.New("something went wrong")) - if table.err == nil { - helpers.ShouldEventuallyReceive(t, ag.chSend) + mockSerializer.EXPECT().Marshal(gomock.Any()).Return(nil, row.getPayloadErr).AnyTimes() + mockSerializer.EXPECT().Unmarshal(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + mockEncoder.EXPECT().Encode(packet.Type(packet.Data), gomock.Any()).Return(nil, row.encoderErr).AnyTimes() + + ag.AnswerWithError(nil, uint(rand.Int()), row.answeredErr) + if row.expectedErr != nil { + pWrite := helpers.ShouldEventuallyReceive(t, ag.chSend) + assert.Equal(t, pendingWrite{err: row.expectedErr}, pWrite) } }) } } +type customSerializer struct{} + +func (*customSerializer) Marshal(obj interface{}) ([]byte, error) { return json.Marshal(obj) } +func (*customSerializer) Unmarshal(data []byte, obj interface{}) error { + return json.Unmarshal(data, obj) +} +func (*customSerializer) GetName() string { return "custom" } + +func TestAgentAnswerWithError(t *testing.T) { + jsonSerializer, err := serialize.NewSerializer(serialize.JSON) + require.NoError(t, err) + + protobufSerializer, err := serialize.NewSerializer(serialize.PROTOBUF) + require.NoError(t, err) + + customSerializer := &customSerializer{} + + table := []struct { + name string + answeredErr error + serializer serialize.Serializer + expectedErr error + }{ + { + name: "should return unknown code for generic error and JSON serializer", + answeredErr: assert.AnError, + serializer: jsonSerializer, + expectedErr: e.NewError(assert.AnError, e.ErrUnknownCode), + }, + { + name: "should return custom code for pitaya error and JSON serializer", + answeredErr: e.NewError(assert.AnError, "CUSTOM-123"), + serializer: jsonSerializer, + expectedErr: e.NewError(assert.AnError, "CUSTOM-123"), + }, + { + name: "should return unknown code for generic error and Protobuf serializer", + answeredErr: assert.AnError, + serializer: protobufSerializer, + expectedErr: e.NewError(assert.AnError, e.ErrUnknownCode), + }, + { + name: "should return custom code for pitaya error and Protobuf serializer", + answeredErr: e.NewError(assert.AnError, "CUSTOM-123"), + serializer: protobufSerializer, + expectedErr: e.NewError(assert.AnError, "CUSTOM-123"), + }, + { + name: "should return unknown code for generic error and custom serializer", + answeredErr: assert.AnError, + serializer: customSerializer, + expectedErr: e.NewError(assert.AnError, e.ErrUnknownCode), + }, + { + name: "should return custom code for pitaya error and custom serializer", + answeredErr: e.NewError(assert.AnError, "CUSTOM-123"), + serializer: customSerializer, + expectedErr: e.NewError(assert.AnError, "CUSTOM-123"), + }, + } + + for _, row := range table { + t.Run(row.name, func(t *testing.T) { + encoder := codec.NewPomeloPacketEncoder() + + messageEncoder := message.NewMessagesEncoder(false) + sessionPool := session.NewSessionPool() + ag := newAgent(nil, nil, encoder, row.serializer, time.Second, 1, nil, messageEncoder, nil, sessionPool).(*agentImpl) + assert.NotNil(t, ag) + + ag.AnswerWithError(nil, uint(rand.Int()), row.answeredErr) + + pWrite := helpers.ShouldEventuallyReceive(t, ag.chSend) + assert.Equal(t, row.expectedErr, pWrite.(pendingWrite).err) + }) + } +} + func TestAgentHeartbeat(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() diff --git a/errors/errors.go b/errors/errors.go index da52cf48..a4171ab9 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -20,6 +20,8 @@ package errors +import "errors" + // ErrUnknownCode is a string code representing an unknown error // This will be used when no error code is sent by the handler const ErrUnknownCode = "PIT-000" @@ -46,9 +48,10 @@ type Error struct { Metadata map[string]string } -//NewError ctor +// NewError ctor func NewError(err error, code string, metadata ...map[string]string) *Error { - if pitayaErr, ok := err.(*Error); ok { + var pitayaErr *Error + if ok := errors.As(err, &pitayaErr); ok { if len(metadata) > 0 { mergeMetadatas(pitayaErr, metadata[0]) } diff --git a/router/router_test.go b/router/router_test.go index b8ade8f2..2814a1a3 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -10,6 +10,7 @@ import ( "github.com/topfreegames/pitaya/v2/cluster" "github.com/topfreegames/pitaya/v2/cluster/mocks" "github.com/topfreegames/pitaya/v2/conn/message" + "github.com/topfreegames/pitaya/v2/constants" "github.com/topfreegames/pitaya/v2/protos" "github.com/topfreegames/pitaya/v2/route" ) @@ -108,3 +109,13 @@ func TestAddRoute(t *testing.T) { }) } } + +func TestRouteFailIfNullServiceDiscovery(t *testing.T) { + t.Parallel() + + router := New() + _, err := router.Route(context.Background(), protos.RPCType_Sys, serverType, route.NewRoute(serverType, "service", "method"), &message.Message{ + Data: []byte{0x01}, + }) + assert.Equal(t, constants.ErrServiceDiscoveryNotInitialized, err) +} diff --git a/service/remote_test.go b/service/remote_test.go index 24c9a46f..d6ab0fbd 100644 --- a/service/remote_test.go +++ b/service/remote_test.go @@ -266,18 +266,114 @@ func TestRemoteServiceRegisterFailsIfNoRemoteMethods(t *testing.T) { assert.Equal(t, errors.New("type NoHandlerRemoteComp has no exported methods of remote type"), err) } +func TestRemoteServiceRemoteCallWithDifferentServerArguments(t *testing.T) { + route := route.NewRoute("sv", "svc", "method") + table := []struct { + name string + serverArg *cluster.Server + routeServer *cluster.Server + expectedServer *cluster.Server + }{ + { + name: "should use server argument if provided", + serverArg: &cluster.Server{Type: "sv"}, + routeServer: &cluster.Server{Type: "sv2"}, + expectedServer: &cluster.Server{Type: "sv"}, + }, + { + name: "should use route's returned server if server argument is nil", + serverArg: nil, + routeServer: &cluster.Server{Type: "sv"}, + expectedServer: &cluster.Server{Type: "sv"}, + }, + } + + for _, row := range table { + t.Run(row.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockSession := sessionmocks.NewMockSession(ctrl) + mockRPCClient := clustermocks.NewMockRPCClient(ctrl) + sessionPool := sessionmocks.NewMockSessionPool(ctrl) + mockServiceDiscovery := clustermocks.NewMockServiceDiscovery(ctrl) + router := router.New() + router.SetServiceDiscovery(mockServiceDiscovery) + mockServiceDiscovery.EXPECT().GetServersByType(gomock.Any()).Return(map[string]*cluster.Server{row.routeServer.Type: row.routeServer}, nil).AnyTimes() + + msg := &message.Message{} + ctx := context.Background() + mockRPCClient.EXPECT().Call(ctx, protos.RPCType_Sys, gomock.Any(), mockSession, msg, row.expectedServer).Return(nil, nil).AnyTimes() + + svc := NewRemoteService(mockRPCClient, nil, nil, nil, nil, router, nil, nil, sessionPool, nil, pipeline.NewHandlerHooks(), nil) + assert.NotNil(t, svc) + + _, err := svc.remoteCall(ctx, row.serverArg, protos.RPCType_Sys, route, mockSession, msg) + assert.NoError(t, err) + }) + } +} + func TestRemoteServiceRemoteCall(t *testing.T) { - rt := route.NewRoute("sv", "svc", "method") - sv := &cluster.Server{} tables := []struct { - name string - server *cluster.Server - res *protos.Response - err error + name string + route route.Route + serverArg *cluster.Server + routeErr error + callRes *protos.Response + callErr error + expectedRes *protos.Response + expectedErr error }{ - {"no_target_route_error", nil, nil, e.NewError(constants.ErrServiceDiscoveryNotInitialized, e.ErrInternalCode)}, - {"error", sv, nil, errors.New("ble")}, - {"success", sv, &protos.Response{Data: []byte("ok")}, nil}, + { + name: "should return internal error for routing generic error", + route: *route.NewRoute("sv", "svc", "method"), + serverArg: nil, + routeErr: assert.AnError, + callRes: nil, + callErr: nil, + expectedRes: nil, + expectedErr: e.NewError(assert.AnError, e.ErrInternalCode), + }, + { + name: "should propagate error for routing pitaya error", + route: *route.NewRoute("sv", "svc", "method"), + serverArg: nil, + routeErr: e.NewError(assert.AnError, "CUSTOM-123"), + callRes: nil, + callErr: nil, + expectedRes: nil, + expectedErr: e.NewError(assert.AnError, "CUSTOM-123"), + }, + { + name: "should propagate error for routing wrapped pitaya error", + route: *route.NewRoute("sv", "svc", "method"), + serverArg: nil, + routeErr: fmt.Errorf("wrapper error: %w", e.NewError(assert.AnError, "CUSTOM-123")), + callRes: nil, + callErr: nil, + expectedRes: nil, + expectedErr: e.NewError(assert.AnError, "CUSTOM-123"), + }, + { + name: "should return error for rpc call error", + route: *route.NewRoute("sv", "svc", "method"), + serverArg: &cluster.Server{Type: "sv"}, + routeErr: nil, + callRes: nil, + callErr: assert.AnError, + expectedRes: nil, + expectedErr: assert.AnError, + }, + { + name: "should succeed", + route: *route.NewRoute("sv", "svc", "method"), + serverArg: &cluster.Server{Type: "sv"}, + routeErr: nil, + callRes: &protos.Response{Data: []byte("ok")}, + callErr: nil, + expectedRes: &protos.Response{Data: []byte("ok")}, + expectedErr: nil, + }, } for _, table := range tables { @@ -287,18 +383,24 @@ func TestRemoteServiceRemoteCall(t *testing.T) { mockSession := sessionmocks.NewMockSession(ctrl) mockRPCClient := clustermocks.NewMockRPCClient(ctrl) sessionPool := sessionmocks.NewMockSessionPool(ctrl) + mockServiceDiscovery := clustermocks.NewMockServiceDiscovery(ctrl) router := router.New() + router.SetServiceDiscovery(mockServiceDiscovery) + mockServiceDiscovery.EXPECT().GetServersByType(table.route.SvType).Return(map[string]*cluster.Server{"sv": {Type: "sv"}}, nil).AnyTimes() + + router.AddRoute(table.route.SvType, func(ctx context.Context, route *route.Route, payload []byte, servers map[string]*cluster.Server) (*cluster.Server, error) { + return &cluster.Server{}, table.routeErr + }) svc := NewRemoteService(mockRPCClient, nil, nil, nil, nil, router, nil, nil, sessionPool, nil, pipeline.NewHandlerHooks(), nil) assert.NotNil(t, svc) - msg := &message.Message{} + mockRPCClient.EXPECT().Call(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(table.callRes, table.callErr).AnyTimes() + ctx := context.Background() - if table.server != nil { - mockRPCClient.EXPECT().Call(ctx, protos.RPCType_Sys, rt, mockSession, msg, sv).Return(table.res, table.err) - } - res, err := svc.remoteCall(ctx, table.server, protos.RPCType_Sys, rt, mockSession, msg) - assert.Equal(t, table.err, err) - assert.Equal(t, table.res, res) + msg := &message.Message{} + res, err := svc.remoteCall(ctx, table.serverArg, protos.RPCType_Sys, &table.route, mockSession, msg) + assert.Equal(t, table.expectedErr, err) + assert.Equal(t, table.expectedRes, res) }) } } diff --git a/util/util.go b/util/util.go index 5fa24fb4..be8bdba5 100644 --- a/util/util.go +++ b/util/util.go @@ -39,8 +39,6 @@ import ( "github.com/topfreegames/pitaya/v2/logger/interfaces" "github.com/topfreegames/pitaya/v2/protos" "github.com/topfreegames/pitaya/v2/serialize" - "github.com/topfreegames/pitaya/v2/serialize/json" - "github.com/topfreegames/pitaya/v2/serialize/protobuf" "github.com/topfreegames/pitaya/v2/tracing" opentracing "github.com/opentracing/opentracing-go" @@ -125,16 +123,9 @@ func FileExists(filename string) bool { // GetErrorFromPayload gets the error from payload func GetErrorFromPayload(serializer serialize.Serializer, payload []byte) error { - err := &e.Error{Code: e.ErrUnknownCode} - switch serializer.(type) { - case *json.Serializer: - _ = serializer.Unmarshal(payload, err) - case *protobuf.Serializer: - pErr := &protos.Error{Code: e.ErrUnknownCode} - _ = serializer.Unmarshal(payload, pErr) - err = &e.Error{Code: pErr.Code, Message: pErr.Msg, Metadata: pErr.Metadata} - } - return err + pErr := &protos.Error{Code: e.ErrUnknownCode} + _ = serializer.Unmarshal(payload, pErr) + return &e.Error{Code: pErr.Code, Message: pErr.Msg, Metadata: pErr.Metadata} } // GetErrorPayload creates and serializes an error payload