From d492640dbd9e72b9a50259799aba76ce85b16313 Mon Sep 17 00:00:00 2001 From: Iaroslav Ciupin Date: Fri, 11 Aug 2023 22:52:30 +0300 Subject: [PATCH] Handle grpc error explicitly (#602) Signed-off-by: Iaroslav Ciupin --- pkg/rpc/adminservice/util/transformers.go | 38 ++++++++++++------- .../adminservice/util/transformers_test.go | 36 +++++++++++++----- 2 files changed, 51 insertions(+), 23 deletions(-) diff --git a/pkg/rpc/adminservice/util/transformers.go b/pkg/rpc/adminservice/util/transformers.go index b28cf13ec..aa45e1429 100644 --- a/pkg/rpc/adminservice/util/transformers.go +++ b/pkg/rpc/adminservice/util/transformers.go @@ -1,25 +1,35 @@ package util import ( + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "github.com/flyteorg/flyteadmin/pkg/common" "github.com/flyteorg/flyteadmin/pkg/errors" - - "google.golang.org/grpc/codes" ) -// Transforms errors to grpc-compatible error types and optionally truncates it if necessary. +// TransformAndRecordError transforms errors to grpc-compatible error types and optionally truncates it if necessary. func TransformAndRecordError(err error, metrics *RequestMetrics) error { - var errorMessage = err.Error() - concatenateErrMessage := false - if len(errorMessage) > common.MaxResponseStatusBytes { - errorMessage = err.Error()[:common.MaxResponseStatusBytes] - concatenateErrMessage = true + errMsg := err.Error() + shouldTruncate := false + if len(errMsg) > common.MaxResponseStatusBytes { + errMsg = errMsg[:common.MaxResponseStatusBytes] + shouldTruncate = true } - if flyteAdminError, ok := err.(errors.FlyteAdminError); !ok { - err = errors.NewFlyteAdminError(codes.Internal, errorMessage) - } else if concatenateErrMessage { - err = errors.NewFlyteAdminError(flyteAdminError.Code(), errorMessage) + + adminErr, isAdminErr := err.(errors.FlyteAdminError) + grpcStatus, isStatus := status.FromError(err) + switch { + case isAdminErr: + if shouldTruncate { + adminErr = errors.NewFlyteAdminError(adminErr.Code(), errMsg) + } + case isStatus: + adminErr = errors.NewFlyteAdminError(grpcStatus.Code(), errMsg) + default: + adminErr = errors.NewFlyteAdminError(codes.Internal, errMsg) } - metrics.Record(err.(errors.FlyteAdminError).Code()) - return err + + metrics.Record(adminErr.Code()) + return adminErr } diff --git a/pkg/rpc/adminservice/util/transformers_test.go b/pkg/rpc/adminservice/util/transformers_test.go index 862806094..52398dc76 100644 --- a/pkg/rpc/adminservice/util/transformers_test.go +++ b/pkg/rpc/adminservice/util/transformers_test.go @@ -3,22 +3,25 @@ package util import ( "context" "errors" + "strings" "testing" - "github.com/flyteorg/flyteadmin/pkg/common" - - adminErrors "github.com/flyteorg/flyteadmin/pkg/errors" mockScope "github.com/flyteorg/flytestdlib/promutils" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + + "github.com/flyteorg/flyteadmin/pkg/common" + adminErrors "github.com/flyteorg/flyteadmin/pkg/errors" ) var testRequestMetrics = NewRequestMetrics(mockScope.NewTestScope(), "foo") func TestTransformError_FlyteAdminError(t *testing.T) { invalidArgError := adminErrors.NewFlyteAdminError(codes.InvalidArgument, "invalid arg") + transformedError := TransformAndRecordError(invalidArgError, &testRequestMetrics) + transormerStatus, ok := status.FromError(transformedError) assert.True(t, ok) assert.Equal(t, codes.InvalidArgument, transormerStatus.Code()) @@ -26,28 +29,43 @@ func TestTransformError_FlyteAdminError(t *testing.T) { func TestTransformError_FlyteAdminErrorWithDetails(t *testing.T) { terminalStateError := adminErrors.NewAlreadyInTerminalStateError(context.Background(), "terminal state", "curPhase") + transformedError := TransformAndRecordError(terminalStateError, &testRequestMetrics) + transormerStatus, ok := status.FromError(transformedError) assert.True(t, ok) assert.Equal(t, codes.FailedPrecondition, transormerStatus.Code()) assert.Equal(t, 1, len(transormerStatus.Details())) } +func TestTransformError_GRPCError(t *testing.T) { + err := status.Error(codes.InvalidArgument, strings.Repeat("X", common.MaxResponseStatusBytes+1)) + + transformedError := TransformAndRecordError(err, &testRequestMetrics) + + transormerStatus, ok := status.FromError(transformedError) + assert.True(t, ok) + assert.Equal(t, codes.InvalidArgument, transormerStatus.Code()) + assert.Len(t, transormerStatus.Message(), common.MaxResponseStatusBytes) +} + func TestTransformError_BasicError(t *testing.T) { err := errors.New("some error") + transformedError := TransformAndRecordError(err, &testRequestMetrics) + transormerStatus, ok := status.FromError(transformedError) assert.True(t, ok) assert.Equal(t, codes.Internal, transormerStatus.Code()) } func TestTruncateErrorMessage(t *testing.T) { - errorMessage := make([]byte, common.MaxResponseStatusBytes+1) - for i := 0; i <= common.MaxResponseStatusBytes; i++ { - errorMessage[i] = byte('a') - } + err := adminErrors.NewFlyteAdminError(codes.InvalidArgument, strings.Repeat("X", common.MaxResponseStatusBytes+1)) - err := adminErrors.NewFlyteAdminError(codes.InvalidArgument, string(errorMessage)) transformedError := TransformAndRecordError(err, &testRequestMetrics) - assert.Len(t, transformedError.Error(), common.MaxResponseStatusBytes) + + transormerStatus, ok := status.FromError(transformedError) + assert.True(t, ok) + assert.Equal(t, codes.InvalidArgument, transormerStatus.Code()) + assert.Len(t, transormerStatus.Message(), common.MaxResponseStatusBytes) }