diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3e9a76d..ccf4f22 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -19,7 +19,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.19 + go-version: 1.22 - name: Build run: go build -v ./... diff --git a/README.md b/README.md index b538e74..0a812f2 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,6 @@ The module provides an HTTP router that integrates with Go's `net/http` package. That means the application code uses `http.ResponseWriter` and `http.Request` from Go's standard library. It supports -- path variables - standard middleware functions - custom middleware functions @@ -11,11 +10,14 @@ It supports There are already a lot of mux implementations. But after a brief search I only found implementations that did not match my requirements oder that overfullfill them. The things I wanted were - path variables +- routes for certain methods - built-in middleware support - I wanted to stay as close to Go's `net/http` package as possible And last but certainly not least: **It's a lot of fun to implement such a thing yourself!** +Since Go 1.22 the standard http package supports path variables and the possibility to specify a handler for a method-path combination. This HTTP router was migrated to make use of the new standard features. So now it just provides a different (but in my opinion clearer) API and the possibility to define middleware functions. + ## Usage A simple server could look like this: @@ -33,11 +35,13 @@ func main() { httpRouter.Post("/books", createBookHandler) httpRouter.Get("/books/:bookId", getSingleBookHandler) + httpRouter.FinishSetup() + log.Fatal(http.ListenAndServe(":8080", httpRouter)) } ``` -The code creates two `GET` and one `POST` route to retrieve and create books. The first parameter is the path, that may contain path variables. Path variables start with a `:`. The second parameter is the handler function that handles the request. A handler function must be of the following type: `type HttpHandler func(http.ResponseWriter, *http.Request, router.Context)` -The first and second parameter are the `ResponseWriter` and the `Request` of Go's `http` package. The third parameter is a `map` containing the path variables. The key is the name the way it was used in the route's path. In this example the third route would contain a value for the key `bookId`. +The code creates two `GET` and one `POST` route to retrieve and create books. The first parameter is the path, that may contain path variables. Path variables start with a `:`. The second parameter is the handler function that handles the request. A handler function must be of the following type: `type HttpHandler func(http.ResponseWriter, *http.Request)` +The first and second parameter are the `ResponseWriter` and the `Request` of Go's `http` package. ## Middleware @@ -52,13 +56,13 @@ import ( ) func middleware1(handler router.HttpHandler) router.HttpHandler { - return func(w http.ResponseWriter, r *http.Request, ctx router.Context) { + return func(w http.ResponseWriter, r *http.Request) { // ... } } func middleware2(handler router.HttpHandler) router.HttpHandler { - return func(w http.ResponseWriter, r *http.Request, ctx router.Context) { + return func(w http.ResponseWriter, r *http.Request) { // ... } } @@ -71,46 +75,17 @@ func main() { httpRouter.Get("/test1", publicHandler) httpRouter.Post("/test2", protectedHanlder).Use(middleware2) - log.Fatal(http.ListenAndServe(":8080", httpRouter)) -} -``` - -There is a third way to add a middleware function. It is possible to define a middleware function for a certain path and HTTP method. - -```go -import ( - "net/http" - - "github.com/gossie/router" -) - -func middleware(handler router.HttpHandler) router.HttpHandler { - return func(w http.ResponseWriter, r *http.Request, ctx router.Context) { - // ... - } -} - -func main() { - httpRouter := router.New() - - testRouter.UseRecursively(router.GET, "/tests", middleware) - - httpRouter.Get("/tests", testsHandler) - httpRouter.Get("/tests/:testId", singleTestHandler) - httpRouter.Get("/tests/:testId/assertions", assertionsHandler) - httpRouter.Get("/other", otherHandler) + httpRouter.FinishSetup() log.Fatal(http.ListenAndServe(":8080", httpRouter)) } ``` -The code makes sure that the middleware function is executed for `GET` request targeting `/tests`, `/tests/:testId` and `/tests/:testId/assertions`. It won't be executed when `/other` is called. - ### Standard middleware functions #### Basic auth -The module provides a standard middleware function for basic authentication. The line `testRouter.Use(router.BasicAuth(userChecker))` adds basic auth to the router. The `userChecker` is a function that checks if the authentication data is correct. +The module provides a standard middleware function for basic authentication. The line `testRouter.Use(router.BasicAuth(userChecker))` adds basic auth to the router. The `userChecker` is a function that checks if the authentication data is correct. If the user was authenticated, the username will be added to the `context` of the request under the key `router.UsernameKey`. ```go import ( @@ -132,6 +107,8 @@ func main() { httpRouter.Post("/books", createBookHandler) httpRouter.Get("/books/:bookId", getSingleBookHandler) + httpRouter.FinishSetup() + log.Fatal(http.ListenAndServe(":8080", httpRouter)) } ``` @@ -156,6 +133,8 @@ func main() { httpRouter.Post("/books", createBookHandler) httpRouter.Get("/books/:bookId", getSingleBookHandler) + httpRouter.FinishSetup() + log.Fatal(http.ListenAndServe(":8080", httpRouter)) } ``` @@ -170,7 +149,7 @@ import ( ) func logRequestTime(handler router.HttpHandler) router.HttpHandler { - return func(w http.ResponseWriter, r *http.Request, ctx router.Context) { + return func(w http.ResponseWriter, r *http.Request) { start := time.Now() defer func() { log.Default().Println("request took", time.Since(start).Milliseconds(), "ms") @@ -189,6 +168,8 @@ func main() { httpRouter.Post("/books", createBookHandler) httpRouter.Get("/books/:bookId", getSingleBookHandler) + httpRouter.FinishSetup() + log.Fatal(http.ListenAndServe(":8080", httpRouter)) } ``` diff --git a/basicauth.go b/basicauth.go index b5cad2e..145fe45 100644 --- a/basicauth.go +++ b/basicauth.go @@ -1,9 +1,14 @@ package router import ( + "context" "net/http" ) +type usernamekey string + +const UsernameKey = usernamekey("username") + type UserData struct { username, password string } @@ -24,16 +29,16 @@ type UserChecker = func(*UserData) bool func BasicAuth(userChecker UserChecker) Middleware { return func(next HttpHandler) HttpHandler { - return func(w http.ResponseWriter, r *http.Request, ctx Context) { - performBasicAuth(w, r, ctx, userChecker, next) + return func(w http.ResponseWriter, r *http.Request) { + performBasicAuth(w, r, userChecker, next) } } } -func performBasicAuth(w http.ResponseWriter, r *http.Request, ctx Context, userChecker UserChecker, next HttpHandler) { +func performBasicAuth(w http.ResponseWriter, r *http.Request, userChecker UserChecker, next HttpHandler) { if user, pass, ok := r.BasicAuth(); ok && userChecker(newUserData(user, pass)) { - ctx.username = user - next(w, r, ctx) + + next(w, r.WithContext(context.WithValue(r.Context(), UsernameKey, user))) return } http.Error(w, "", http.StatusUnauthorized) diff --git a/basicauth_test.go b/basicauth_test.go index 3975640..c3bc3cd 100644 --- a/basicauth_test.go +++ b/basicauth_test.go @@ -16,11 +16,13 @@ func TestBasicAuth_noAuthData(t *testing.T) { } testRouter := router.New() - testRouter.Get("/protected", func(_ http.ResponseWriter, _ *http.Request, _ router.Context) { + testRouter.Get("/protected", func(_ http.ResponseWriter, _ *http.Request) { assert.Fail(t, "handler must not be called") }) testRouter.Use(router.BasicAuth(userChecker)) + testRouter.FinishSetup() + w := &TestResponseWriter{} r := &http.Request{ Method: "GET", @@ -37,11 +39,13 @@ func TestBasicAuth_wrongAuthData(t *testing.T) { } testRouter := router.New() - testRouter.Get("/protected", func(_ http.ResponseWriter, _ *http.Request, _ router.Context) { + testRouter.Get("/protected", func(_ http.ResponseWriter, _ *http.Request) { assert.Fail(t, "handler must not be called") }) testRouter.Use(router.BasicAuth(userChecker)) + testRouter.FinishSetup() + userStr := base64.StdEncoding.EncodeToString([]byte("user2:wrong")) w := &TestResponseWriter{} @@ -61,12 +65,14 @@ func TestBasicAuth_correctAuthData(t *testing.T) { } testRouter := router.New() - testRouter.Get("/protected", func(w http.ResponseWriter, r *http.Request, ctx router.Context) { - assert.Equal(t, "user2", ctx.Username()) + testRouter.Get("/protected", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "user2", r.Context().Value(router.UsernameKey)) w.WriteHeader(200) }) testRouter.Use(router.BasicAuth(userChecker)) + testRouter.FinishSetup() + userStr := base64.StdEncoding.EncodeToString([]byte("user2:password2")) w := &TestResponseWriter{} diff --git a/cache.go b/cache.go index 17e7ab6..b8c292a 100644 --- a/cache.go +++ b/cache.go @@ -8,10 +8,10 @@ import ( func Cache(duration time.Duration) Middleware { return func(next HttpHandler) HttpHandler { - return func(w http.ResponseWriter, r *http.Request, ctx Context) { + return func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", fmt.Sprintf("public, maxage=%v, s-maxage=%v, immutable", duration.Seconds(), duration.Seconds())) w.Header().Set("Expires", time.Now().Add(duration).Local().Format("Mon, 02 Jan 2006 15:04:05 MST")) - next(w, r, ctx) + next(w, r) } } } diff --git a/cache_test.go b/cache_test.go index 69070c7..ee7f697 100644 --- a/cache_test.go +++ b/cache_test.go @@ -12,10 +12,12 @@ import ( func TestCache_noCache(t *testing.T) { testRouter := router.New() - testRouter.Get("/route", func(w http.ResponseWriter, _ *http.Request, _ router.Context) { + testRouter.Get("/route", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) + testRouter.FinishSetup() + w := &TestResponseWriter{} r := &http.Request{ Method: "GET", @@ -29,11 +31,13 @@ func TestCache_noCache(t *testing.T) { func TestCache_cache(t *testing.T) { testRouter := router.New() - testRouter.Get("/route", func(w http.ResponseWriter, _ *http.Request, _ router.Context) { + testRouter.Get("/route", func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) }) testRouter.Use(router.Cache(1 * time.Hour)) + testRouter.FinishSetup() + w := &TestResponseWriter{} r := &http.Request{ Method: "GET", diff --git a/context.go b/context.go deleted file mode 100644 index 09e5674..0000000 --- a/context.go +++ /dev/null @@ -1,29 +0,0 @@ -package router - -type pathParam struct { - name, value string -} - -type Context struct { - pathParameters []pathParam - username string -} - -func newContext(pathParameters []pathParam) Context { - return Context{ - pathParameters: pathParameters, - } -} - -func (ctx *Context) PathParameter(name string) string { - for _, p := range ctx.pathParameters { - if name == p.name { - return p.value - } - } - return "" -} - -func (ctx *Context) Username() string { - return ctx.username -} diff --git a/go.mod b/go.mod index 5ee10fd..d968fda 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/gossie/router -go 1.19 +go 1.22 require github.com/stretchr/testify v1.8.4 diff --git a/pathtree.go b/pathtree.go deleted file mode 100644 index c9e1dad..0000000 --- a/pathtree.go +++ /dev/null @@ -1,98 +0,0 @@ -package router - -import ( - "errors" - "log" -) - -const ( - NodeTypeRoot = iota - NodeTypeStatic - NodeTypeVar -) - -type node struct { - nodeType int - pathElement string - route *route - children []*node - middleware []Middleware -} - -func contains(nodes []*node, el string) (*node, bool) { - for _, n := range nodes { - if n.pathElement == el { - return n, true - } - } - return nil, false -} - -func (n *node) createOrGetStaticChild(el string) (*node, error) { - if n.children == nil { - n.children = make([]*node, 0, 1) - } - - foundVariable := false - for _, child := range n.children { - if child.nodeType == NodeTypeVar { - foundVariable = true - } - } - - if foundVariable { - return nil, errors.New("a static path element cannot be added, if there is already a path variable at that position") - } - - pathElement := el - if child, found := contains(n.children, pathElement); found && child.nodeType == NodeTypeStatic && child.pathElement == pathElement { - log.Default().Println("found static path element", pathElement) - return child, nil - } - - log.Default().Println("creating static path element", pathElement) - newNode := &node{nodeType: NodeTypeStatic, pathElement: pathElement} - n.children = append(n.children, newNode) - return newNode, nil -} - -func (n *node) createOrGetVarChild(el string) (*node, error) { - if n.children == nil { - n.children = make([]*node, 0, 1) - } - - if child, found := contains(n.children, el); found && child.nodeType == NodeTypeVar && child.pathElement == el { - log.Default().Println("found variable path element", el) - return child, nil - } - - if len(n.children) != 0 { - return nil, errors.New("a path variable cannot be added as a path element that is already present") - } - - log.Default().Println("creating variable path element", el) - newNode := &node{nodeType: NodeTypeVar, pathElement: el} - n.children = append(n.children, newNode) - return newNode, nil -} - -func (n *node) childNode(el string) *node { - if len(n.children) == 1 && n.children[0].nodeType == NodeTypeVar { - return n.children[0] - } - - if child, found := contains(n.children, el); found { - return child - } - - log.Default().Println("could not find node for path element", el) - return nil -} - -type pathTree struct { - root *node -} - -func newPathTree() *pathTree { - return &pathTree{&node{nodeType: NodeTypeRoot}} -} diff --git a/route.go b/route.go index 9bbfb7f..3d01a58 100644 --- a/route.go +++ b/route.go @@ -1,12 +1,18 @@ package router type route struct { + method string + path string handler HttpHandler middleware []Middleware } -func newRoute(handler HttpHandler) *route { - return &route{handler: handler} +func newRoute(method, path string, handler HttpHandler) *route { + return &route{ + method: method, + path: path, + handler: handler, + } } func (r *route) Use(middleware Middleware) *route { diff --git a/router.go b/router.go index 08adb87..477c0fe 100644 --- a/router.go +++ b/router.go @@ -1,10 +1,8 @@ package router import ( - "log" - "math" + "fmt" "net/http" - "strings" "sync" "time" ) @@ -19,39 +17,33 @@ const ( PATH_VARIABLE_PREFIX = ":" ) -type HttpHandler = func(http.ResponseWriter, *http.Request, Context) +type HttpHandler = func(w http.ResponseWriter, r *http.Request) type Middleware = func(HttpHandler) HttpHandler type HttpRouter struct { - mutex sync.RWMutex - routes map[string]*pathTree - middleware []Middleware - pathVariableCount uint + mutex sync.RWMutex + routes []*route + middleware []Middleware } func New() *HttpRouter { - return &HttpRouter{routes: make(map[string]*pathTree)} + http.DefaultServeMux = &http.ServeMux{} + return &HttpRouter{routes: make([]*route, 0)} } func (hr *HttpRouter) addRoute(path string, method string, handler HttpHandler) *route { hr.mutex.Lock() defer hr.mutex.Unlock() - hr.pathVariableCount = uint(math.Max(float64(hr.pathVariableCount), float64(strings.Count(path, PATH_VARIABLE_PREFIX)))) + newRoute := newRoute(method, path, handler) + hr.routes = append(hr.routes, newRoute) - rootHandler := func() { - hr.routes[method].root.route = newRoute(handler) - } - - currentNode := hr.getCreateOrGetNode(path, method, rootHandler) - - currentNode.route = newRoute(handler) - return currentNode.route + return newRoute } func (hr *HttpRouter) Handle(path string, handler http.Handler) { - hr.addRoute(path, GET, func(w http.ResponseWriter, r *http.Request, _ Context) { + hr.addRoute(path, GET, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "public, maxage=86400, s-maxage=86400, immutable") w.Header().Set("Expires", time.Now().Add(86400*time.Second).Local().Format("Mon, 02 Jan 2006 15:04:05 MST")) handler.ServeHTTP(w, r) @@ -85,106 +77,22 @@ func (hr *HttpRouter) Use(middleware Middleware) { hr.middleware = append(hr.middleware, middleware) } -func (hr *HttpRouter) UseRecursively(method, path string, middleware Middleware) { - hr.mutex.Lock() - defer hr.mutex.Unlock() - - rootHandler := func() { - panic("use the Use() method") - } - - currentNode := hr.getCreateOrGetNode(path, method, rootHandler) - - currentNode.middleware = append(currentNode.middleware, middleware) -} - -func (hr *HttpRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { - hr.mutex.RLock() - defer hr.mutex.RUnlock() - - var pathVariables []pathParam +func (hr *HttpRouter) FinishSetup() { + for _, r := range hr.routes { + handler := r.handler - if tree, present := hr.routes[r.Method]; present { - currentNode := tree.root - if r.URL.Path == SEPARATOR { - currentNode.route.handler(w, r, newContext(nil)) - return + for i := len(r.middleware) - 1; i >= 0; i-- { + handler = r.middleware[i](handler) } - middlewareToExecute := appendMiddlewareIfNeeded(nil, hr.middleware) - - currentPath := r.URL.Path[1:] - index := strings.Index(currentPath, SEPARATOR) - for index > 0 || currentPath != "" { - var el string - if index < 0 { - el = currentPath - currentPath = "" - } else { - el = currentPath[0:index] - currentPath = currentPath[index+1:] - } - - if currentNode != nil { - middlewareToExecute = appendMiddlewareIfNeeded(middlewareToExecute, currentNode.middleware) - currentNode = currentNode.childNode(el) - if currentNode != nil && currentNode.nodeType == NodeTypeVar { - if pathVariables == nil { - pathVariables = make([]pathParam, 0, hr.pathVariableCount) - } - pathVariables = append(pathVariables, pathParam{name: currentNode.pathElement, value: el}) - } - } - - index = strings.Index(currentPath, SEPARATOR) + for i := len(hr.middleware) - 1; i >= 0; i-- { + handler = hr.middleware[i](handler) } - if currentNode == nil || currentNode.route == nil { - log.Default().Println("no", r.Method, "pattern matched", r.URL.Path, "-> returning 404") - http.NotFound(w, r) - return - } - - handlerToExceute := currentNode.route.handler - middlewareToExecute = appendMiddlewareIfNeeded(middlewareToExecute, currentNode.route.middleware) - for i := len(middlewareToExecute) - 1; i >= 0; i-- { - handlerToExceute = middlewareToExecute[i](handlerToExceute) - } - - handlerToExceute(w, r, newContext(pathVariables)) + http.Handle(fmt.Sprintf("%v %v", r.method, r.path), http.HandlerFunc(handler)) } } -func (hr *HttpRouter) getCreateOrGetNode(path string, method string, rootHandler func()) *node { - if _, present := hr.routes[method]; !present { - hr.routes[method] = newPathTree() - } - - if path == SEPARATOR { - rootHandler() - } - - currentNode := hr.routes[method].root - var err error - for _, el := range strings.Split(path, SEPARATOR) { - if el != "" { - if strings.HasPrefix(el, PATH_VARIABLE_PREFIX) { - currentNode, err = currentNode.createOrGetVarChild(el[1:]) - } else { - currentNode, err = currentNode.createOrGetStaticChild(el) - } - - if err != nil { - panic(err.Error()) - } - } - } - return currentNode -} - -func appendMiddlewareIfNeeded(current []Middleware, source []Middleware) []Middleware { - if len(source) > 0 { - return append(current, source...) - } - return current +func (hr *HttpRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + http.DefaultServeMux.ServeHTTP(w, r) } diff --git a/router_test.go b/router_test.go index 8d25a30..0b8f248 100644 --- a/router_test.go +++ b/router_test.go @@ -35,40 +35,42 @@ func TestRouting(t *testing.T) { testRouter := router.New() - testRouter.Post("/tests", func(w http.ResponseWriter, r *http.Request, _ router.Context) { + testRouter.Post("/tests", func(w http.ResponseWriter, r *http.Request) { testString = "post-was-called" }) - testRouter.Get("/tests/:testString", func(w http.ResponseWriter, r *http.Request, ctx router.Context) { - if testString != ctx.PathParameter("testString") { - t.Fatalf("%s != %s", testString, ctx.PathParameter("testString")) + testRouter.Get("/tests/{testString}", func(w http.ResponseWriter, r *http.Request) { + if testString != r.PathValue("testString") { + t.Fatalf("%s != %s", testString, r.PathValue("testString")) } }) - testRouter.Get("/tests/:testString/:detailId", func(w http.ResponseWriter, r *http.Request, ctx router.Context) { - if testString != ctx.PathParameter("testString") { - t.Fatalf("%s != %s", testString, ctx.PathParameter("testString")) + testRouter.Get("/tests/{testString}/{detailId}", func(w http.ResponseWriter, r *http.Request) { + if testString != r.PathValue("testString") { + t.Fatalf("%s != %s", testString, r.PathValue("testString")) } }) - testRouter.Put("/tests/:testString", func(w http.ResponseWriter, r *http.Request, ctx router.Context) { - if testString == ctx.PathParameter("testString") { + testRouter.Put("/tests/{testString}", func(w http.ResponseWriter, r *http.Request) { + if testString == r.PathValue("testString") { testString = "put-was-called" } }) - testRouter.Patch("/tests/:testString", func(w http.ResponseWriter, r *http.Request, ctx router.Context) { - if testString == ctx.PathParameter("testString") { + testRouter.Patch("/tests/{testString}", func(w http.ResponseWriter, r *http.Request) { + if testString == r.PathValue("testString") { testString = "patch-was-called" } }) - testRouter.Delete("/tests/:testString", func(w http.ResponseWriter, r *http.Request, ctx router.Context) { - if testString == ctx.PathParameter("testString") { + testRouter.Delete("/tests/{testString}", func(w http.ResponseWriter, r *http.Request) { + if testString == r.PathValue("testString") { testString = "" } }) + testRouter.FinishSetup() + w1 := &TestResponseWriter{} r1 := &http.Request{ Method: "POST", @@ -132,12 +134,13 @@ func TestRouting(t *testing.T) { } func TestHasRootRoute(t *testing.T) { - emptyHandler := func(w http.ResponseWriter, _ *http.Request, _ router.Context) { + emptyHandler := func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(200) } testRouter := router.New() testRouter.Get("/", emptyHandler) + testRouter.FinishSetup() w := &TestResponseWriter{} r := &http.Request{ @@ -151,10 +154,11 @@ func TestHasRootRoute(t *testing.T) { } func TestReturnsStatus404(t *testing.T) { - emptyHandler := func(_ http.ResponseWriter, _ *http.Request, _ router.Context) {} + emptyHandler := func(_ http.ResponseWriter, _ *http.Request) {} testRouter := router.New() testRouter.Get("/tests/:id", emptyHandler) + testRouter.FinishSetup() w := &TestResponseWriter{} r := &http.Request{ @@ -167,53 +171,20 @@ func TestReturnsStatus404(t *testing.T) { assert.Equal(t, 404, w.statusCode) } -func TestCreatesVariableAndStaticElementAtTheSamePosition(t *testing.T) { - emptyHandler := func(_ http.ResponseWriter, _ *http.Request, _ router.Context) {} - - testRouter := router.New() - - assert.Panics(t, func() { - testRouter.Get("/tests/:id", emptyHandler) - testRouter.Get("/tests/green", emptyHandler) - }) -} - -func TestCreatesStaticElementAndVariableAtTheSamePosition(t *testing.T) { - emptyHandler := func(_ http.ResponseWriter, _ *http.Request, _ router.Context) {} - - testRouter := router.New() - - assert.Panics(t, func() { - testRouter.Get("/tests/green", emptyHandler) - testRouter.Get("/tests/:id", emptyHandler) - }) -} - -func TestCreatesTwoVariablesAtTheSamePosition(t *testing.T) { - emptyHandler := func(_ http.ResponseWriter, _ *http.Request, _ router.Context) {} - - testRouter := router.New() - - assert.Panics(t, func() { - testRouter.Get("/tests/:testId", emptyHandler) - testRouter.Get("/tests/:id", emptyHandler) - }) -} - func TestMiddleware(t *testing.T) { executed := make([]string, 0) middleware1 := func(in router.HttpHandler) router.HttpHandler { - return func(w http.ResponseWriter, r *http.Request, ctx router.Context) { + return func(w http.ResponseWriter, r *http.Request) { executed = append(executed, "middleware1") - in(w, r, ctx) + in(w, r) } } middleware2 := func(in router.HttpHandler) router.HttpHandler { - return func(w http.ResponseWriter, r *http.Request, ctx router.Context) { + return func(w http.ResponseWriter, r *http.Request) { executed = append(executed, "middleware2") - in(w, r, ctx) + in(w, r) } } @@ -222,10 +193,12 @@ func TestMiddleware(t *testing.T) { testRouter.Use(middleware1) testRouter.Use(middleware2) - testRouter.Get("/test", func(w http.ResponseWriter, r *http.Request, _ router.Context) { + testRouter.Get("/test", func(w http.ResponseWriter, r *http.Request) { executed = append(executed, "get") }) + testRouter.FinishSetup() + w := &TestResponseWriter{} r := &http.Request{ Method: "GET", @@ -243,16 +216,16 @@ func TestMiddlewareForSingleRoute(t *testing.T) { executed := make([]string, 0) middleware1 := func(in router.HttpHandler) router.HttpHandler { - return func(w http.ResponseWriter, r *http.Request, ctx router.Context) { + return func(w http.ResponseWriter, r *http.Request) { executed = append(executed, "middleware1") - in(w, r, ctx) + in(w, r) } } middleware2 := func(in router.HttpHandler) router.HttpHandler { - return func(w http.ResponseWriter, r *http.Request, ctx router.Context) { + return func(w http.ResponseWriter, r *http.Request) { executed = append(executed, "middleware2") - in(w, r, ctx) + in(w, r) } } @@ -260,14 +233,16 @@ func TestMiddlewareForSingleRoute(t *testing.T) { testRouter.Use(middleware1) - testRouter.Get("/test1", func(w http.ResponseWriter, r *http.Request, _ router.Context) { + testRouter.Get("/test1", func(w http.ResponseWriter, r *http.Request) { executed = append(executed, "test1") }) - testRouter.Get("/test2", func(w http.ResponseWriter, r *http.Request, _ router.Context) { + testRouter.Get("/test2", func(w http.ResponseWriter, r *http.Request) { executed = append(executed, "test2") }).Use(middleware2) + testRouter.FinishSetup() + w := &TestResponseWriter{} r1 := &http.Request{ Method: "GET", @@ -293,133 +268,3 @@ func TestMiddlewareForSingleRoute(t *testing.T) { assert.Equal(t, "middleware2", executed[3]) assert.Equal(t, "test2", executed[4]) } - -func TestMiddlewareForGroupOfRoutes(t *testing.T) { - executed := make([]string, 0) - - middleware1 := func(in router.HttpHandler) router.HttpHandler { - return func(w http.ResponseWriter, r *http.Request, ctx router.Context) { - executed = append(executed, "middleware1") - in(w, r, ctx) - } - } - - middleware2 := func(in router.HttpHandler) router.HttpHandler { - return func(w http.ResponseWriter, r *http.Request, ctx router.Context) { - executed = append(executed, "middleware2") - in(w, r, ctx) - } - } - - middleware3 := func(in router.HttpHandler) router.HttpHandler { - return func(w http.ResponseWriter, r *http.Request, ctx router.Context) { - executed = append(executed, "middleware3") - in(w, r, ctx) - } - } - - testRouter := router.New() - - testRouter.Use(middleware1) - testRouter.UseRecursively(router.GET, "/tests", middleware2) - - testRouter.Get("/tests/test1", func(w http.ResponseWriter, r *http.Request, _ router.Context) { - executed = append(executed, "GET test1") - }) - - testRouter.Get("/tests/test2", func(w http.ResponseWriter, r *http.Request, _ router.Context) { - executed = append(executed, "GET test2") - }).Use(middleware3) - - testRouter.Post("/tests/test2", func(w http.ResponseWriter, r *http.Request, _ router.Context) { - executed = append(executed, "POST test2") - }) - - testRouter.Get("/other", func(w http.ResponseWriter, r *http.Request, _ router.Context) { - executed = append(executed, "GET other") - }) - - w := &TestResponseWriter{} - r1 := &http.Request{ - Method: "GET", - URL: &url.URL{Path: "/tests/test1"}, - } - testRouter.ServeHTTP(w, r1) - - assert.Equal(t, 3, len(executed)) - assert.Equal(t, "middleware1", executed[0]) - assert.Equal(t, "middleware2", executed[1]) - assert.Equal(t, "GET test1", executed[2]) - - r2 := &http.Request{ - Method: "GET", - URL: &url.URL{Path: "/tests/test2"}, - } - testRouter.ServeHTTP(w, r2) - - assert.Equal(t, 7, len(executed)) - assert.Equal(t, "middleware1", executed[0]) - assert.Equal(t, "middleware2", executed[1]) - assert.Equal(t, "GET test1", executed[2]) - assert.Equal(t, "middleware1", executed[3]) - assert.Equal(t, "middleware2", executed[4]) - assert.Equal(t, "middleware3", executed[5]) - assert.Equal(t, "GET test2", executed[6]) - - r3 := &http.Request{ - Method: "GET", - URL: &url.URL{Path: "/other"}, - } - testRouter.ServeHTTP(w, r3) - - assert.Equal(t, 9, len(executed)) - assert.Equal(t, "middleware1", executed[0]) - assert.Equal(t, "middleware2", executed[1]) - assert.Equal(t, "GET test1", executed[2]) - assert.Equal(t, "middleware1", executed[3]) - assert.Equal(t, "middleware2", executed[4]) - assert.Equal(t, "middleware3", executed[5]) - assert.Equal(t, "GET test2", executed[6]) - assert.Equal(t, "middleware1", executed[7]) - assert.Equal(t, "GET other", executed[8]) - - r4 := &http.Request{ - Method: "POST", - URL: &url.URL{Path: "/tests/test2"}, - } - testRouter.ServeHTTP(w, r4) - - assert.Equal(t, 11, len(executed)) - assert.Equal(t, "middleware1", executed[0]) - assert.Equal(t, "middleware2", executed[1]) - assert.Equal(t, "GET test1", executed[2]) - assert.Equal(t, "middleware1", executed[3]) - assert.Equal(t, "middleware2", executed[4]) - assert.Equal(t, "middleware3", executed[5]) - assert.Equal(t, "GET test2", executed[6]) - assert.Equal(t, "middleware1", executed[7]) - assert.Equal(t, "GET other", executed[8]) - assert.Equal(t, "middleware1", executed[9]) - assert.Equal(t, "POST test2", executed[10]) -} - -// func TestRouteCaseInsensitivity(t *testing.T) { -// executed := false - -// testRouter := router.New() - -// testRouter.Get("/TEST1/:id/test2", func(w http.ResponseWriter, r *http.Request, ctx router.Context) { -// assert.Equal(t, "aBc", ctx.PathParameter("id")) -// executed = true -// }) - -// w := &TestResponseWriter{} -// r := &http.Request{ -// Method: "GET", -// URL: &url.URL{Path: "/tEsT1/aBc/TeSt2"}, -// } - -// testRouter.ServeHTTP(w, r) - -// assert.True(t, executed) -// }