diff --git a/strategy/exception/default_exception_formatting_strategy.go b/strategy/exception/default_exception_formatting_strategy.go index e947af33..3927fce7 100644 --- a/strategy/exception/default_exception_formatting_strategy.go +++ b/strategy/exception/default_exception_formatting_strategy.go @@ -11,6 +11,7 @@ package exception import ( "bytes" "crypto/rand" + goerrors "errors" "fmt" "runtime" "strings" @@ -115,7 +116,8 @@ func (dEFS *DefaultFormattingStrategy) Panicf(formatString string, args ...inter // ExceptionFromError takes an error and returns value of Exception func (dEFS *DefaultFormattingStrategy) ExceptionFromError(err error) Exception { var isRemote bool - if reqErr, ok := err.(awserr.RequestFailure); ok { + var reqErr awserr.RequestFailure + if goerrors.As(err, &reqErr) { // A service error occurs if reqErr.RequestID() != "" { isRemote = true @@ -133,22 +135,25 @@ func (dEFS *DefaultFormattingStrategy) ExceptionFromError(err error) Exception { Remote: isRemote, } - if err, ok := err.(*XRayError); ok { - e.Type = err.Type + xRayErr := &XRayError{} + if goerrors.As(err, &xRayErr) { + e.Type = xRayErr.Type } var s []uintptr // This is our publicly supported interface for passing along stack traces - if err, ok := err.(StackTracer); ok { - s = err.StackTrace() + var st StackTracer + if goerrors.As(err, &st) { + s = st.StackTrace() } // We also accept github.com/pkg/errors style stack traces for ease of use - if err, ok := err.(interface { + var est interface { StackTrace() errors.StackTrace - }); ok { - for _, frame := range err.StackTrace() { + } + if goerrors.As(err, &est) { + for _, frame := range est.StackTrace() { s = append(s, uintptr(frame)) } } diff --git a/strategy/exception/exception_test.go b/strategy/exception/exception_test.go index b002a42d..24f50e13 100644 --- a/strategy/exception/exception_test.go +++ b/strategy/exception/exception_test.go @@ -12,6 +12,7 @@ import ( "errors" "testing" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/stretchr/testify/assert" ) @@ -105,6 +106,30 @@ func TestExceptionFromError(t *testing.T) { assert.Equal(t, "errors.errorString", err.Type) } +func TestExceptionFromErrorRequestFailure(t *testing.T) { + defaultStrategy := &DefaultFormattingStrategy{} + reqErr := awserr.NewRequestFailure(awserr.New("error code", "error message", errors.New("new error")), 400, "1234") + + err := defaultStrategy.ExceptionFromError(reqErr) + + assert.NotNil(t, err.ID) + assert.Contains(t, err.Message, "new error") + assert.Contains(t, err.Message, "1234") + assert.Equal(t, "awserr.requestError", err.Type) + assert.Equal(t, true, err.Remote) +} + +func TestExceptionFromErrorXRayError(t *testing.T) { + defaultStrategy := &DefaultFormattingStrategy{} + xRayErr := defaultStrategy.Error("new XRayError") + + err := defaultStrategy.ExceptionFromError(xRayErr) + + assert.NotNil(t, err.ID) + assert.Equal(t, "new XRayError", err.Message) + assert.Equal(t, "error", err.Type) +} + // Benchmarks func BenchmarkDefaultFormattingStrategy_Error(b *testing.B) { defs, _ := NewDefaultFormattingStrategy()