From 843ff2bd6ec8e683ea32355f96e4404a85d17a1a Mon Sep 17 00:00:00 2001 From: Aaron Parfitt Date: Thu, 23 Mar 2023 08:29:31 +0000 Subject: [PATCH] removing gorilla mux for gateway (#68) * removing gorilla mux for gateway * removing gorilla mux for gateway * Update README.md Co-authored-by: Dec <20417324+DeWarner@users.noreply.github.com> * the gateway is now detached from mux * remove unused code and .vscode launch.json --------- Co-authored-by: Dec <20417324+DeWarner@users.noreply.github.com> --- .vscode/launch.json | 39 ---- Makefile | 2 +- README.md | 1 + cmd/frontman/main.go | 4 +- frontman.go | 189 +++++++++++++++++++ frontman_test.go | 35 ++++ gateway.go | 191 -------------------- gateway/gateway.go | 144 +++++++++++++++ handlers_test.go => gateway/gateway_test.go | 19 +- handlers.go => gateway/routes.go | 138 +------------- gateway_test.go | 36 ---- ssl/tls.go | 14 ++ test.yaml | 8 +- 13 files changed, 404 insertions(+), 416 deletions(-) delete mode 100644 .vscode/launch.json create mode 100644 frontman.go create mode 100644 frontman_test.go delete mode 100644 gateway.go create mode 100644 gateway/gateway.go rename handlers_test.go => gateway/gateway_test.go (95%) rename handlers.go => gateway/routes.go (53%) delete mode 100644 gateway_test.go create mode 100644 ssl/tls.go diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index badf1a2..0000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,39 +0,0 @@ -{ - "version": "0.2.0", - "configurations": [ - { - "name": "Debug Go tests", - "type": "go", - "request": "launch", - "mode": "test", - "program": "${workspaceFolder}", - "args": [], - "env": {}, - "showLog": true, - "debugAdapter": "legacy", - "port": 2345 - }, - { - "name": "Debug frontman", - "type": "go", - "request": "launch", - "mode": "debug", - "program": "${workspaceFolder}/cmd/frontman", - "args": [ - "-config", - "test.yaml" - ], - "env": {}, - "showLog": true, - "trace": true, - "dlvLoadConfig": { - "followPointers": true, - "maxVariableRecurse": 3, - "maxStringLen": 64, - "maxArrayValues": 64, - "maxStructFields": -1 - }, - "cwd": "${workspaceFolder}" - } - ] -} diff --git a/Makefile b/Makefile index 7079733..2540094 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ clean: # Test target (run all tests) test: - $(GO) test -v + $(GO) test -v ./... bench: go test -v -bench=. -benchmem diff --git a/README.md b/README.md index a28e04e..580cfee 100644 --- a/README.md +++ b/README.md @@ -238,6 +238,7 @@ Supported Load Balancer types: "options": { "weights": [1, 2, 3] } + } ``` - Least Connection diff --git a/cmd/frontman/main.go b/cmd/frontman/main.go index 493aa0d..5ba94ca 100644 --- a/cmd/frontman/main.go +++ b/cmd/frontman/main.go @@ -43,9 +43,9 @@ func main() { fmt.Println("failed to initialize logger") os.Exit(1) } - + // Create a new Gateway instance - gateway, err := frontman.NewGateway(config, logger) + gateway, err := frontman.NewFrontman(config, logger) if err != nil { logger.Fatalf("failed to create gateway: %v", err) } diff --git a/frontman.go b/frontman.go new file mode 100644 index 0000000..08d903a --- /dev/null +++ b/frontman.go @@ -0,0 +1,189 @@ +package frontman + +import ( + "context" + "crypto/tls" + "fmt" + "net/http" + "sync" + + "github.com/Frontman-Labs/frontman/config" + "github.com/Frontman-Labs/frontman/gateway" + "github.com/Frontman-Labs/frontman/log" + "github.com/Frontman-Labs/frontman/plugins" + "github.com/Frontman-Labs/frontman/service" + "github.com/Frontman-Labs/frontman/ssl" + "github.com/gorilla/mux" +) + +// Frontman contains the backend services and the router +type Frontman struct { + router *gateway.APIGateway + service *mux.Router + backendServices service.ServiceRegistry + conf *config.Config + log log.Logger +} + +func NewServicesRouter(backendServices service.ServiceRegistry) *mux.Router { + router := mux.NewRouter() + + router.HandleFunc("/api/services", getServicesHandler(backendServices)).Methods("GET") + router.HandleFunc("/api/services", addServiceHandler(backendServices)).Methods("POST") + router.HandleFunc("/api/services/{name}", removeServiceHandler(backendServices)).Methods("DELETE") + router.HandleFunc("/api/services/{name}", updateServiceHandler(backendServices)).Methods("PUT") + router.HandleFunc("/api/health", getHealthHandler(backendServices)).Methods("GET") + + return router +} + +// NewGateway creates a new Frontman instance with a Redis client connection factory +func NewFrontman(conf *config.Config, log log.Logger) (*Frontman, error) { + + // Retrieve the Redis client connection from the factory + ctx := context.Background() + + // Create a new BackendServices instance + backendServices, err := service.NewServiceRegistry(ctx, conf.GlobalConfig.ServiceType, conf) + if err != nil { + return nil, err + } + + servicesRouter := NewServicesRouter(backendServices) + + // Load plugins + var plug []plugins.FrontmanPlugin + + // Create new APIGateway instance + + if conf.PluginConfig.Enabled { + plug, err = plugins.LoadPlugins(conf.PluginConfig.Order) + if err != nil { + return nil, err + } + + } + + // Create new APIGateway instance + clients := make(map[string]*http.Client) + lock := sync.Mutex{} + apiGateway := gateway.NewAPIGateway(backendServices, plug, conf, clients, log, &lock) + go gateway.RefreshConnections(backendServices, clients, &lock) + + // Create the Frontman instance + return &Frontman{ + router: apiGateway, + service: servicesRouter, + backendServices: backendServices, + conf: conf, + log: log, + }, nil +} + +func (gw *Frontman) Start() error { + apiAddr := gw.conf.APIConfig.Addr + if apiAddr == "" { + apiAddr = "0.0.0.0:8080" + } + gatewayAddr := gw.conf.GatewayConfig.Addr + if gatewayAddr == "" { + gatewayAddr = "0.0.0.0:8000" + } + + var apiHandler http.Handler + var gatewayHandler http.Handler + + if gw.conf.APIConfig.SSL.Enabled { + apiHandler = gw.service + cert, err := ssl.LoadCert(gw.conf.APIConfig.SSL.Cert, gw.conf.APIConfig.SSL.Key) + if err != nil { + return err + } + apiServer := createServer(apiAddr, apiHandler, &cert) + gw.log.Infof("Started Frontman API with SSL on %s", apiAddr) + go func() { + if err := startServer(apiServer); err != nil { + gw.log.Fatal(err) + } + }() + } else { + apiHandler = gw.service + api := createServer(apiAddr, apiHandler, nil) + gw.log.Infof("Started Frontman API on %s", apiAddr) + go func() { + if err := startServer(api); err != nil { + gw.log.Fatal(err) + } + }() + } + + if gw.conf.GatewayConfig.SSL.Enabled { + gatewayHandler = gw.router + cert, err := ssl.LoadCert(gw.conf.GatewayConfig.SSL.Cert, gw.conf.GatewayConfig.SSL.Key) + if err != nil { + return err + } + + // Redirect HTTP traffic to HTTPS + httpAddr := "0.0.0.0:80" + httpRedirect := createRedirectServer(httpAddr, gatewayAddr) + gw.log.Infof("Started HTTP redirect server on %s", httpAddr) + go func() { + if err := startServer(httpRedirect); err != nil { + gw.log.Fatal(err) + } + }() + + gatewayServer := createServer(gatewayAddr, gatewayHandler, &cert) + gw.log.Infof("Started Frontman Frontman with SSL on %s", gatewayAddr) + if err := startServer(gatewayServer); err != nil { + return err + } + } else { + gatewayHandler = gw.router + gateway := createServer(gatewayAddr, gatewayHandler, nil) + gw.log.Infof("Started Frontman Frontman on %s", gatewayAddr) + if err := startServer(gateway); err != nil { + return err + } + } + + return nil +} + +func createRedirectServer(addr string, redirectAddr string) *http.Server { + redirect := func(w http.ResponseWriter, req *http.Request) { + httpsURL := "https://" + req.Host + req.URL.Path + http.Redirect(w, req, httpsURL, http.StatusMovedPermanently) + } + return &http.Server{ + Addr: addr, + Handler: http.HandlerFunc(redirect), + } +} + +func createServer(addr string, handler http.Handler, cert *tls.Certificate) *http.Server { + server := &http.Server{ + Addr: addr, + Handler: handler, + } + if cert != nil { + server.TLSConfig = &tls.Config{ + Certificates: []tls.Certificate{*cert}, + } + } + return server +} + +func startServer(server *http.Server) error { + if server.TLSConfig != nil { + if err := server.ListenAndServeTLS("", ""); err != nil { + return fmt.Errorf("Failed to start server with TLS: %w", err) + } + } else { + if err := server.ListenAndServe(); err != nil { + return fmt.Errorf("Failed to start server without TLS: %w", err) + } + } + return nil +} diff --git a/frontman_test.go b/frontman_test.go new file mode 100644 index 0000000..8df765b --- /dev/null +++ b/frontman_test.go @@ -0,0 +1,35 @@ +package frontman + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestCreateRedirectServer(t *testing.T) { + redirectAddr := "0.0.0.0:8000" + redirectServer := createRedirectServer("0.0.0.0:80", redirectAddr) + + // Create a test request to the redirect server + req, err := http.NewRequest("GET", "http://example.com/foo", nil) + if err != nil { + t.Fatal(err) + } + + // Create a test response recorder + rr := httptest.NewRecorder() + + // Call the redirect server's handler function + redirectServer.Handler.ServeHTTP(rr, req) + + // Check that the response has a 301 status code + if status := rr.Code; status != http.StatusMovedPermanently { + t.Errorf("Unexpected status code: got %v, expected %v", status, http.StatusMovedPermanently) + } + + // Check that the response includes a "Location" header with the expected value + expectedURL := "https://example.com/foo" + if location := rr.Header().Get("Location"); location != expectedURL { + t.Errorf("Unexpected Location header value: got %v, expected %v", location, expectedURL) + } +} diff --git a/gateway.go b/gateway.go deleted file mode 100644 index c8c6f21..0000000 --- a/gateway.go +++ /dev/null @@ -1,191 +0,0 @@ -package frontman - -import ( - "context" - "crypto/tls" - "fmt" - "net/http" - "strings" - - "github.com/Frontman-Labs/frontman/config" - "github.com/Frontman-Labs/frontman/log" - "github.com/Frontman-Labs/frontman/plugins" - "github.com/Frontman-Labs/frontman/service" - "github.com/gorilla/mux" -) - -// Gateway contains the backend services and the router -type Gateway struct { - router *mux.Router - service *mux.Router - backendServices service.ServiceRegistry - conf *config.Config - log log.Logger -} - -func NewServicesRouter(backendServices service.ServiceRegistry) *mux.Router { - router := mux.NewRouter() - - router.HandleFunc("/api/services", getServicesHandler(backendServices)).Methods("GET") - router.HandleFunc("/api/services", addServiceHandler(backendServices)).Methods("POST") - router.HandleFunc("/api/services/{name}", removeServiceHandler(backendServices)).Methods("DELETE") - router.HandleFunc("/api/services/{name}", updateServiceHandler(backendServices)).Methods("PUT") - router.HandleFunc("/api/health", getHealthHandler(backendServices)).Methods("GET") - - return router -} - -// NewGateway creates a new Gateway instance with a Redis client connection factory -func NewGateway(conf *config.Config, log log.Logger) (*Gateway, error) { - - // Retrieve the Redis client connection from the factory - ctx := context.Background() - - // Create a new BackendServices instance - backendServices, err := service.NewServiceRegistry(ctx, conf.GlobalConfig.ServiceType, conf) - if err != nil { - return nil, err - } - - servicesRouter := NewServicesRouter(backendServices) - - // Create a new router instance - proxyRouter := mux.NewRouter() - - // Load plugins - var plug []plugins.FrontmanPlugin - - if conf.PluginConfig.Enabled { - plug, err = plugins.LoadPlugins(conf.PluginConfig.Order) - if err != nil { - return nil, err - } - - } - - proxyRouter.HandleFunc("/{proxyPath:.+}", gatewayHandler(backendServices, plug, conf, make(map[string]*http.Client))).Methods("GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS").MatcherFunc(func(r *http.Request, rm *mux.RouteMatch) bool { - vars := mux.Vars(r) - proxyPath := vars["proxyPath"] - for _, prefix := range []string{"/api/"} { - if strings.HasPrefix(proxyPath, prefix) { - return false - } - } - return true - }) - - // Create the Gateway instance - return &Gateway{ - router: proxyRouter, - service: servicesRouter, - backendServices: backendServices, - conf: conf, - log : log, - }, nil -} - -func (gw *Gateway) Start() error { - apiAddr := gw.conf.APIConfig.Addr - if apiAddr == "" { - apiAddr = "0.0.0.0:8080" - } - gatewayAddr := gw.conf.GatewayConfig.Addr - if gatewayAddr == "" { - gatewayAddr = "0.0.0.0:8000" - } - - var apiHandler http.Handler - var gatewayHandler http.Handler - - if gw.conf.APIConfig.SSL.Enabled { - apiHandler = gw.service - cert, err := loadCert(gw.conf.APIConfig.SSL.Cert, gw.conf.APIConfig.SSL.Key) - if err != nil { - return err - } - apiServer := createServer(apiAddr, apiHandler, &cert) - gw.log.Infof("Started Frontman API with SSL on %s", apiAddr) - go func(){if err := startServer(apiServer); err!=nil{ gw.log.Fatal(err)}}() - } else { - apiHandler = gw.service - api := createServer(apiAddr, apiHandler, nil) - gw.log.Infof("Started Frontman API on %s", apiAddr) - go func(){if err := startServer(api); err!=nil{ gw.log.Fatal(err)}}() - } - - if gw.conf.GatewayConfig.SSL.Enabled { - gatewayHandler = gw.router - cert, err := loadCert(gw.conf.GatewayConfig.SSL.Cert, gw.conf.GatewayConfig.SSL.Key) - if err != nil { - return err - } - - // Redirect HTTP traffic to HTTPS - httpAddr := "0.0.0.0:80" - httpRedirect := createRedirectServer(httpAddr, gatewayAddr) - gw.log.Infof("Started HTTP redirect server on %s", httpAddr) - go func(){if err := startServer(httpRedirect); err!=nil{ gw.log.Fatal(err)}}() - - - gatewayServer := createServer(gatewayAddr, gatewayHandler, &cert) - gw.log.Infof("Started Frontman Gateway with SSL on %s", gatewayAddr) - if err := startServer(gatewayServer); err!=nil { - return err - } - } else { - gatewayHandler = gw.router - gateway := createServer(gatewayAddr, gatewayHandler, nil) - gw.log.Infof("Started Frontman Gateway on %s", gatewayAddr) - if err := startServer(gateway); err!=nil { - return err - } - } - - return nil -} - -func createRedirectServer(addr string, redirectAddr string) *http.Server { - redirect := func(w http.ResponseWriter, req *http.Request) { - httpsURL := "https://" + req.Host + req.URL.Path - http.Redirect(w, req, httpsURL, http.StatusMovedPermanently) - } - return &http.Server{ - Addr: addr, - Handler: http.HandlerFunc(redirect), - } -} - - -func loadCert(certFile, keyFile string) (tls.Certificate, error) { - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return tls.Certificate{}, fmt.Errorf("Failed to load certificate: %w", err) - } - return cert, nil -} - -func createServer(addr string, handler http.Handler, cert *tls.Certificate) *http.Server { - server := &http.Server{ - Addr: addr, - Handler: handler, - } - if cert != nil { - server.TLSConfig = &tls.Config{ - Certificates: []tls.Certificate{*cert}, - } - } - return server -} - -func startServer(server *http.Server) error { - if server.TLSConfig != nil { - if err := server.ListenAndServeTLS("", ""); err != nil { - return fmt.Errorf("Failed to start server with TLS: %w", err) - } - } else { - if err := server.ListenAndServe(); err != nil { - return fmt.Errorf("Failed to start server without TLS: %w", err) - } - } - return nil -} diff --git a/gateway/gateway.go b/gateway/gateway.go new file mode 100644 index 0000000..d498223 --- /dev/null +++ b/gateway/gateway.go @@ -0,0 +1,144 @@ +package gateway + +import ( + "encoding/json" + "io" + "net/http" + "net/url" + "strings" + "sync" + + "github.com/Frontman-Labs/frontman/config" + "github.com/Frontman-Labs/frontman/log" + "github.com/Frontman-Labs/frontman/plugins" + "github.com/Frontman-Labs/frontman/service" +) + +type APIGateway struct { + bs service.ServiceRegistry + plugs []plugins.FrontmanPlugin + conf *config.Config + clients map[string]*http.Client + clientLock *sync.Mutex + currentTargetIndex int + log log.Logger +} + +func NewAPIGateway(bs service.ServiceRegistry, plugs []plugins.FrontmanPlugin, conf *config.Config, clients map[string]*http.Client, logger log.Logger, lock *sync.Mutex) *APIGateway { + return &APIGateway{ + bs: bs, + plugs: plugs, + conf: conf, + clients: clients, + clientLock: lock, + log: logger, + } +} + +func (g *APIGateway) ServeHTTP(w http.ResponseWriter, req *http.Request) { + root := buildRoutes(g.bs.GetServices()) + for _, plugin := range g.plugs { + if err := plugin.PreRequest(req, g.bs, g.conf); err != nil { + g.log.Errorf("Plugin error: %v", err) + http.Error(w, err.Error(), err.StatusCode()) + return + } + } + + // Find the backend service that matches the request + backendService := findBackendService(root, req) + + // If the backend service was not found, return a 404 error + if backendService == nil { + http.NotFound(w, req) + return + } + + // Get the upstream target URL for this request + upstreamTarget := backendService.GetLoadBalancer().ChooseTarget(backendService.UpstreamTargets) + var urlPath string + if backendService.StripPath { + urlPath = strings.TrimPrefix(req.URL.Path, backendService.Path) + } else { + urlPath = backendService.Path + } + + // Create a new target URL with the service path and scheme + + targetURL, err := url.Parse(upstreamTarget + urlPath) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Get or create a new client for this backend service + client, err := getClientForBackendService(*backendService, backendService.Name, g.clients, g.clientLock) + headers := make(http.Header) + // Copy the headers from the original request + copyHeaders(headers, req.Header) + if backendService.AuthConfig != nil { + tokenValidator := backendService.GetTokenValidator() + // Backend service has auth config specified + claims, err := tokenValidator.ValidateToken(req) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + if claims != nil { + data, err := json.Marshal(claims) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + headers.Add(backendService.GetUserDataHeader(), string(data)) + } + + } + // Remove the X-Forwarded-For header to prevent spoofing + headers.Del("X-Forwarded-For") + + // Log a message indicating that the request is being sent to the target service + g.log.Infof("Sending request to %s: %s %s", upstreamTarget, req.Method, urlPath) + + // Send the request to the target service using the client with the specified transport + resp, err := client.Do(&http.Request{ + Method: req.Method, + URL: targetURL, + Proto: req.Proto, + ProtoMajor: req.ProtoMajor, + ProtoMinor: req.ProtoMinor, + Header: headers, + Body: req.Body, + ContentLength: req.ContentLength, + Host: targetURL.Host, + }) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + g.log.Infof("Error sending request: %v\n", err.Error()) + return + } + + backendService.GetLoadBalancer().Done(upstreamTarget) + + defer resp.Body.Close() + + for _, plugin := range g.plugs { + if err := plugin.PostResponse(resp, g.bs, g.conf); err != nil { + g.log.Infof("Plugin error: %v", err) + http.Error(w, err.Error(), err.StatusCode()) + return + } + } + + // Log a message indicating that the response has been received from the target service + g.log.Infof("Response received from %s: %d %s", upstreamTarget, resp.StatusCode, resp.Status) + + // Copy the response headers back to the client + copyHeaders(w.Header(), resp.Header) + + // Set the status code and body of the response + w.WriteHeader(resp.StatusCode) + io.Copy(w, resp.Body) + +} diff --git a/handlers_test.go b/gateway/gateway_test.go similarity index 95% rename from handlers_test.go rename to gateway/gateway_test.go index 561a217..4865fc5 100644 --- a/handlers_test.go +++ b/gateway/gateway_test.go @@ -1,4 +1,4 @@ -package frontman +package gateway import ( "github.com/Frontman-Labs/frontman/loadbalancer" @@ -11,6 +11,7 @@ import ( "time" "github.com/Frontman-Labs/frontman/config" + "github.com/Frontman-Labs/frontman/log" "github.com/Frontman-Labs/frontman/plugins" "github.com/Frontman-Labs/frontman/service" ) @@ -217,8 +218,12 @@ func TestGatewayHandler(t *testing.T) { }, } - handler := gatewayHandler(reg, []plugins.FrontmanPlugin{plugin}, &config.Config{}, clients) - handler(w, req) + logger, err := log.NewZapLogger("info") + if err != nil { + t.Errorf("could not create logger due to: %s", err) + } + handler := NewAPIGateway(reg, []plugins.FrontmanPlugin{plugin}, &config.Config{}, clients, logger, &sync.Mutex{}) + handler.ServeHTTP(w, req) // Check the response status code if w.Code != tc.expectedStatusCode { @@ -425,9 +430,13 @@ func BenchmarkGatewayHandler(b *testing.B) { mockErr: nil, }} - handler := gatewayHandler(reg, []plugins.FrontmanPlugin{plugin}, &config.Config{}, clients) + logger, err := log.NewZapLogger("info") + if err != nil { + b.Errorf("could not create logger due to: %s", err) + } + handler := NewAPIGateway(reg, []plugins.FrontmanPlugin{plugin}, &config.Config{}, clients, logger, &sync.Mutex{}) for i := 0; i < b.N; i++ { - handler(w, req) + handler.ServeHTTP(w, req) } } diff --git a/handlers.go b/gateway/routes.go similarity index 53% rename from handlers.go rename to gateway/routes.go index f35437b..b287465 100644 --- a/handlers.go +++ b/gateway/routes.go @@ -1,18 +1,11 @@ -package frontman +package gateway import ( - "encoding/json" - - "io" - "log" "net/http" - "net/url" "strings" "sync" "time" - "github.com/Frontman-Labs/frontman/config" - "github.com/Frontman-Labs/frontman/plugins" "github.com/Frontman-Labs/frontman/service" ) @@ -45,7 +38,7 @@ func refreshClients(bs *service.BackendService, clients map[string]*http.Client, } } -func refreshConnections(bs service.ServiceRegistry, clients map[string]*http.Client, clientLock *sync.Mutex) { +func RefreshConnections(bs service.ServiceRegistry, clients map[string]*http.Client, clientLock *sync.Mutex) { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() @@ -103,17 +96,6 @@ func refreshConnections(bs service.ServiceRegistry, clients map[string]*http.Cli } } -func getNextTargetIndex(backendService *service.BackendService, currentIndex int) int { - numTargets := len(backendService.UpstreamTargets) - if numTargets == 0 { - return -1 - } - if currentIndex >= numTargets-1 { - return 0 - } - return currentIndex + 1 -} - func getClientForBackendService(bs service.BackendService, target string, clients map[string]*http.Client, clientLock *sync.Mutex) (*http.Client, error) { clientLock.Lock() defer clientLock.Unlock() @@ -235,119 +217,3 @@ func findBackendService(root *Route, r *http.Request) *service.BackendService { return nil } - -func gatewayHandler(bs service.ServiceRegistry, plugs []plugins.FrontmanPlugin, conf *config.Config, clients map[string]*http.Client) http.HandlerFunc { - // Create a map to store HTTP clients for each backend service - var clientLock sync.Mutex - - // Start a goroutine to refresh HTTP connections to each backend service - go refreshConnections(bs, clients, &clientLock) - - return func(w http.ResponseWriter, r *http.Request) { - - root := buildRoutes(bs.GetServices()) - for _, plugin := range plugs { - if err := plugin.PreRequest(r, bs, conf); err != nil { - log.Printf("Plugin error: %v", err) - http.Error(w, err.Error(), err.StatusCode()) - return - } - } - - // Find the backend service that matches the request - backendService := findBackendService(root, r) - - // If the backend service was not found, return a 404 error - if backendService == nil { - http.NotFound(w, r) - return - } - - // Get the upstream target URL for this request - upstreamTarget := backendService.GetLoadBalancer().ChooseTarget(backendService.UpstreamTargets) - var urlPath string - if backendService.StripPath { - urlPath = strings.TrimPrefix(r.URL.Path, backendService.Path) - } else { - urlPath = backendService.Path - } - - // Create a new target URL with the service path and scheme - - targetURL, err := url.Parse(upstreamTarget + urlPath) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - // Get or create a new client for this backend service - client, err := getClientForBackendService(*backendService, backendService.Name, clients, &clientLock) - headers := make(http.Header) - // Copy the headers from the original request - copyHeaders(headers, r.Header) - if backendService.AuthConfig != nil { - tokenValidator := backendService.GetTokenValidator() - // Backend service has auth config specified - claims, err := tokenValidator.ValidateToken(r) - if err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) - return - } - - if claims != nil { - data, err := json.Marshal(claims) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - headers.Add(backendService.GetUserDataHeader(), string(data)) - } - - } - // Remove the X-Forwarded-For header to prevent spoofing - headers.Del("X-Forwarded-For") - - // Log a message indicating that the request is being sent to the target service - log.Printf("Sending request to %s: %s %s", upstreamTarget, r.Method, urlPath) - - // Send the request to the target service using the client with the specified transport - resp, err := client.Do(&http.Request{ - Method: r.Method, - URL: targetURL, - Proto: r.Proto, - ProtoMajor: r.ProtoMajor, - ProtoMinor: r.ProtoMinor, - Header: headers, - Body: r.Body, - ContentLength: r.ContentLength, - Host: targetURL.Host, - }) - if err != nil { - http.Error(w, err.Error(), http.StatusBadGateway) - log.Printf("Error sending request: %v\n", err.Error()) - return - } - - backendService.GetLoadBalancer().Done(upstreamTarget) - - defer resp.Body.Close() - - for _, plugin := range plugs { - if err := plugin.PostResponse(resp, bs, conf); err != nil { - log.Printf("Plugin error: %v", err) - http.Error(w, err.Error(), err.StatusCode()) - return - } - } - - // Log a message indicating that the response has been received from the target service - log.Printf("Response received from %s: %d %s", upstreamTarget, resp.StatusCode, resp.Status) - - // Copy the response headers back to the client - copyHeaders(w.Header(), resp.Header) - - // Set the status code and body of the response - w.WriteHeader(resp.StatusCode) - io.Copy(w, resp.Body) - } -} diff --git a/gateway_test.go b/gateway_test.go deleted file mode 100644 index 7ed0ef0..0000000 --- a/gateway_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package frontman - -import ( - "net/http" - "net/http/httptest" - "testing" -) - - -func TestCreateRedirectServer(t *testing.T) { - redirectAddr := "0.0.0.0:8000" - redirectServer := createRedirectServer("0.0.0.0:80", redirectAddr) - - // Create a test request to the redirect server - req, err := http.NewRequest("GET", "http://example.com/foo", nil) - if err != nil { - t.Fatal(err) - } - - // Create a test response recorder - rr := httptest.NewRecorder() - - // Call the redirect server's handler function - redirectServer.Handler.ServeHTTP(rr, req) - - // Check that the response has a 301 status code - if status := rr.Code; status != http.StatusMovedPermanently { - t.Errorf("Unexpected status code: got %v, expected %v", status, http.StatusMovedPermanently) - } - - // Check that the response includes a "Location" header with the expected value - expectedURL := "https://example.com/foo" - if location := rr.Header().Get("Location"); location != expectedURL { - t.Errorf("Unexpected Location header value: got %v, expected %v", location, expectedURL) - } -} \ No newline at end of file diff --git a/ssl/tls.go b/ssl/tls.go new file mode 100644 index 0000000..a8bfc98 --- /dev/null +++ b/ssl/tls.go @@ -0,0 +1,14 @@ +package ssl + +import ( + "crypto/tls" + "fmt" +) + +func LoadCert(certFile, keyFile string) (tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return tls.Certificate{}, fmt.Errorf("Failed to load certificate: %w", err) + } + return cert, nil +} diff --git a/test.yaml b/test.yaml index 84ff610..e0e3010 100644 --- a/test.yaml +++ b/test.yaml @@ -2,13 +2,9 @@ global: service_type: "yaml" services_file: "services.yaml" api: - addr: "0.0.0.0:8080" - ssl: - enabled: false + addr: "0.0.0.0:8082" gateway: - addr: "0.0.0.0:8000" - ssl: - enabled: false + addr: "0.0.0.0:8002" logging: level: "debug" plugins: