Skip to content

Commit

Permalink
Allow custom error handling via HandleErrResponse (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
vearutop authored Feb 25, 2021
1 parent 88927c7 commit 6f71d14
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 4 deletions.
18 changes: 14 additions & 4 deletions nethttp/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"net/http"
"reflect"

"github.com/swaggest/openapi-go/openapi3"
"github.com/swaggest/rest"
"github.com/swaggest/usecase"
"github.com/swaggest/usecase/status"
Expand All @@ -22,6 +21,7 @@ func NewHandler(useCase usecase.Interactor, options ...func(h *Handler)) *Handle
h := &Handler{
options: options,
}
h.HandleErrResponse = h.handleErrResponseDefault
h.SetUseCase(useCase)

return h
Expand Down Expand Up @@ -50,8 +50,8 @@ func (h *Handler) SetUseCase(useCase usecase.Interactor) {
type Handler struct {
rest.HandlerTrait

// OperationAnnotations are called after operation setup and before adding operation to documentation.
OperationAnnotations []func(op *openapi3.Operation) error
// HandleErrResponse allows control of error response processing.
HandleErrResponse func(w http.ResponseWriter, r *http.Request, err error)

// requestDecoder maps data from http.Request into structured Go input value.
requestDecoder RequestDecoder
Expand Down Expand Up @@ -123,7 +123,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.responseEncoder.WriteSuccessfulResponse(w, r, output, h.HandlerTrait)
}

func (h *Handler) handleErrResponse(w http.ResponseWriter, r *http.Request, err error) {
func (h *Handler) handleErrResponseDefault(w http.ResponseWriter, r *http.Request, err error) {
var (
code int
er interface{}
Expand All @@ -138,6 +138,16 @@ func (h *Handler) handleErrResponse(w http.ResponseWriter, r *http.Request, err
h.responseEncoder.WriteErrResponse(w, r, code, er)
}

func (h *Handler) handleErrResponse(w http.ResponseWriter, r *http.Request, err error) {
if h.HandleErrResponse != nil {
h.HandleErrResponse(w, r, err)

return
}

h.handleErrResponseDefault(w, r, err)
}

func closeMultipartForm(r *http.Request) {
if err := r.MultipartForm.RemoveAll(); err != nil {
log.Println(err)
Expand Down
34 changes: 34 additions & 0 deletions nethttp/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,37 @@ func TestHandler_ServeHTTP_customMapping(t *testing.T) {
assert.Equal(t, http.StatusNoContent, rw.Code)
assert.Equal(t, "", rw.Body.String())
}

func TestOptionsMiddleware(t *testing.T) {
u := usecase.NewIOI(nil, nil, func(ctx context.Context, input, output interface{}) error {
return errors.New("failed")
})
h := nethttp.NewHandler(u, func(h *nethttp.Handler) {
h.MakeErrResp = func(ctx context.Context, err error) (int, interface{}) {
return http.StatusExpectationFailed, struct {
Foo string `json:"foo"`
}{Foo: err.Error()}
}
})
h.SetResponseEncoder(&response.Encoder{})

var loggedErr error

rw := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodGet, "/", nil)
require.NoError(t, err)

oh := nethttp.OptionsMiddleware(func(h *nethttp.Handler) {
handleErrResponse := h.HandleErrResponse
h.HandleErrResponse = func(w http.ResponseWriter, r *http.Request, err error) {
assert.Equal(t, req, r)
loggedErr = err
handleErrResponse(w, r, err)
}
})(h)

oh.ServeHTTP(rw, req)

assert.EqualError(t, loggedErr, "failed")
assert.Equal(t, `{"foo":"failed"}`+"\n", rw.Body.String())
}
18 changes: 18 additions & 0 deletions nethttp/options.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
package nethttp

import (
"net/http"
"reflect"

"github.com/swaggest/openapi-go/openapi3"
"github.com/swaggest/refl"
"github.com/swaggest/rest"
)

// OptionsMiddleware applies options to encountered nethttp.Handler.
func OptionsMiddleware(options ...func(h *Handler)) func(h http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
var rh *Handler

if HandlerAs(h, &rh) {
rh.options = append(rh.options, options...)

for _, option := range options {
option(rh)
}
}

return h
}
}

// AnnotateOperation allows customizations of prepared operations.
func AnnotateOperation(annotations ...func(operation *openapi3.Operation) error) func(h *Handler) {
return func(h *Handler) {
Expand Down
4 changes: 4 additions & 0 deletions trait.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"reflect"

"github.com/swaggest/openapi-go/openapi3"
"github.com/swaggest/refl"
"github.com/swaggest/usecase"
)
Expand Down Expand Up @@ -33,6 +34,9 @@ type HandlerTrait struct {

// RespValidator validates decoded response data.
RespValidator Validator

// OperationAnnotations are called after operation setup and before adding operation to documentation.
OperationAnnotations []func(op *openapi3.Operation) error
}

// RestHandler is a an accessor.
Expand Down

0 comments on commit 6f71d14

Please sign in to comment.