diff --git a/api/api.go b/api/api.go index 24cb1de..8945c8b 100644 --- a/api/api.go +++ b/api/api.go @@ -186,23 +186,23 @@ func validateService(service *service.BackendService) error { return nil } -func validateLoadBalancerPolicy(s *service.BackendService) error { - switch s.LoadBalancerPolicy.Type { +func validateLoadBalancerPolicy(bs *service.BackendService) error { + switch bs.LoadBalancerPolicy.Type { case loadbalancer.Random: case loadbalancer.LeastConnection: case loadbalancer.RoundRobin: case loadbalancer.WeightedRoundRobin, loadbalancer.WeightedLeastConnection: - if len(s.LoadBalancerPolicy.Options.Weights) != len(s.UpstreamTargets) { + if len(bs.LoadBalancerPolicy.Options.Weights) != len(bs.UpstreamTargets) { return fmt.Errorf("mismatched lengths of weights and targets") } - for _, w := range s.LoadBalancerPolicy.Options.Weights { + for _, w := range bs.LoadBalancerPolicy.Options.Weights { if w <= 0 { return fmt.Errorf("weights must be greater than zero") } } default: - return fmt.Errorf("unknown load-balancer policy: %s", s.LoadBalancerPolicy.Type) + return fmt.Errorf("unknown load-balancer policy: %s", bs.LoadBalancerPolicy.Type) } return nil diff --git a/frontman.go b/frontman.go index 5f9cb12..9ef42eb 100644 --- a/frontman.go +++ b/frontman.go @@ -4,11 +4,9 @@ import ( "context" "crypto/tls" "fmt" - "net/http" - "sync" - "github.com/Frontman-Labs/frontman/api" "github.com/julienschmidt/httprouter" + "net/http" "github.com/Frontman-Labs/frontman/config" "github.com/Frontman-Labs/frontman/gateway" @@ -52,11 +50,7 @@ func NewFrontman(conf *config.Config, log log.Logger) (*Frontman, error) { } // Create new APIGateway instance - clients := make(map[string]*http.Client) - lock := sync.Mutex{} - apiGateway := gateway.NewAPIGateway(serviceRegistry, plug, conf, clients, log, &lock) - - go gateway.RefreshConnections(serviceRegistry, clients, &lock) + apiGateway := gateway.NewAPIGateway(serviceRegistry, plug, conf, log) // Create the Frontman instance return &Frontman{ diff --git a/gateway/gateway.go b/gateway/gateway.go index f73115a..69091f9 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -2,37 +2,29 @@ package gateway import ( "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "strings" - "sync" - "time" - "github.com/Frontman-Labs/frontman/config" "github.com/Frontman-Labs/frontman/log" "github.com/Frontman-Labs/frontman/plugins" "github.com/Frontman-Labs/frontman/service" + "io" + "net/http" + "net/url" + "strings" ) type APIGateway struct { - reg service.ServiceRegistry - plugs []plugins.FrontmanPlugin - conf *config.Config - clients map[string]*http.Client - clientLock *sync.Mutex - log log.Logger + reg service.ServiceRegistry + plugs []plugins.FrontmanPlugin + conf *config.Config + log log.Logger } -func NewAPIGateway(bs service.ServiceRegistry, plugs []plugins.FrontmanPlugin, conf *config.Config, clients map[string]*http.Client, logger log.Logger, lock *sync.Mutex) *APIGateway { +func NewAPIGateway(bs service.ServiceRegistry, plugs []plugins.FrontmanPlugin, conf *config.Config, logger log.Logger) *APIGateway { return &APIGateway{ - reg: bs, - plugs: plugs, - conf: conf, - clients: clients, - clientLock: lock, - log: logger, + reg: bs, + plugs: plugs, + conf: conf, + log: logger, } } @@ -57,15 +49,12 @@ func (g *APIGateway) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Get the upstream target URL for this request upstreamTarget := backendService.GetLoadBalancer().ChooseTarget(backendService.UpstreamTargets) - var urlPath string - + urlPath := req.URL.Path if backendService.StripPath { urlPath = strings.TrimPrefix(req.URL.Path, backendService.Path) - } else { - urlPath = req.URL.Path } - // Use the compiledRegex field in the backendService struct to apply the rewrite + // Use the compiledRewriteMatch field in the backendService struct to apply the rewrite if backendService.GetCompiledRewriteMatch() != nil { urlPath = backendService.GetCompiledRewriteMatch().ReplaceAllString(urlPath, backendService.RewriteReplace) } @@ -82,8 +71,8 @@ func (g *APIGateway) ServeHTTP(w http.ResponseWriter, req *http.Request) { targetURL.RawQuery = req.URL.RawQuery } - // Get or create a new client for this backend service - client, err := getClientForBackendService(*backendService, backendService.Name, g.clients, g.clientLock) + // Get client for backend service + client := backendService.GetHttpClient() // Copy the headers from the original request headers := make(http.Header) @@ -156,95 +145,6 @@ func (g *APIGateway) ServeHTTP(w http.ResponseWriter, req *http.Request) { } - -func RefreshConnections(bs service.ServiceRegistry, clients map[string]*http.Client, clientLock *sync.Mutex) { - - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - for { - select { - case <-ticker.C: - services := bs.GetServices() - - // Remove clients that are no longer needed - clientLock.Lock() - for k := range clients { - found := false - for _, s := range services { - for _, t := range s.UpstreamTargets { - key := fmt.Sprintf("%s_%s", s.Name, t) - if key == k { - found = true - break - } - } - if found { - break - } - } - if !found { - delete(clients, k) - } - } - clientLock.Unlock() - - // Add or update clients for each service - for _, s := range services { - for _, t := range s.UpstreamTargets { - clientLock.Lock() - key := fmt.Sprintf("%s_%s", s.Name, t) - _, ok := clients[key] - if !ok { - transport := &http.Transport{ - MaxIdleConns: s.MaxIdleConns, - IdleConnTimeout: s.MaxIdleTime * time.Second, - TLSHandshakeTimeout: s.Timeout * time.Second, - } - client := &http.Client{ - Transport: transport, - } - clients[key] = client - } else { - clients[key].Transport.(*http.Transport).MaxIdleConns = s.MaxIdleConns - clients[key].Transport.(*http.Transport).IdleConnTimeout = s.MaxIdleTime * time.Second - clients[key].Transport.(*http.Transport).TLSHandshakeTimeout = s.Timeout * time.Second - } - clientLock.Unlock() - } - } - - } - } -} - -func getClientForBackendService(bs service.BackendService, target string, clients map[string]*http.Client, clientLock *sync.Mutex) (*http.Client, error) { - clientLock.Lock() - defer clientLock.Unlock() - - // Check if the client for this target already exists - if client, ok := clients[target]; ok { - return client, nil - } - - // Create a new transport with the specified settings - transport := &http.Transport{ - MaxIdleConns: bs.MaxIdleConns, - IdleConnTimeout: bs.MaxIdleTime * time.Second, - TLSHandshakeTimeout: bs.Timeout * time.Second, - } - - // Create a new HTTP client with the transport - client := &http.Client{ - Transport: transport, - } - - // Add the client to the map of clients - key := fmt.Sprintf("%s_%s", bs.Name, target) - clients[key] = client - - return client, nil -} - func copyHeaders(dst, src http.Header) { for k, v := range src { dst[k] = v diff --git a/gateway/gateway_test.go b/gateway/gateway_test.go index d9255de..4f87914 100644 --- a/gateway/gateway_test.go +++ b/gateway/gateway_test.go @@ -2,13 +2,11 @@ package gateway import ( "context" - "fmt" "github.com/Frontman-Labs/frontman/loadbalancer" "net/http" "net/http/httptest" "net/url" "reflect" - "sync" "testing" "time" @@ -255,8 +253,6 @@ func TestGatewayHandler(t *testing.T) { bs.Init() - clients := make(map[string]*http.Client) - mockClient := &mockHTTPClient{ mockResponse: &http.Response{ StatusCode: tc.expectedStatusCode, @@ -265,7 +261,7 @@ func TestGatewayHandler(t *testing.T) { mockErr: nil, } - clients[bs.Name] = &http.Client{Transport: mockClient} + bs.GetHttpClient().Transport = mockClient reg, _ := service.NewServiceRegistry(context.Background(), "memory", nil) reg.AddService(bs) @@ -288,7 +284,7 @@ func TestGatewayHandler(t *testing.T) { if err != nil { t.Errorf("could not create logger due to: %s", err) } - handler := NewAPIGateway(reg, []plugins.FrontmanPlugin{plugin}, &config.Config{}, clients, logger, &sync.Mutex{}) + handler := NewAPIGateway(reg, []plugins.FrontmanPlugin{plugin}, &config.Config{}, logger) handler.ServeHTTP(w, req) // Check the response status code @@ -397,78 +393,6 @@ func TestFindBackendService(t *testing.T) { } } -func TestGetClientForBackendService(t *testing.T) { - testCases := []struct { - name string - backendService service.BackendService - target string - existingClients map[string]*http.Client - expectedTransport *http.Transport - }{ - { - name: "New client for target", - backendService: service.BackendService{ - MaxIdleConns: 100, - MaxIdleTime: 10, - Timeout: 5, - UpstreamTargets: []string{"httpbin.org"}, - }, - target: "httpbin.org", - existingClients: map[string]*http.Client{}, - expectedTransport: &http.Transport{ - MaxIdleConns: 100, - IdleConnTimeout: 10 * time.Second, - TLSHandshakeTimeout: 5 * time.Second, - }, - }, - { - name: "Existing client for target", - backendService: service.BackendService{ - MaxIdleConns: 50, - MaxIdleTime: 5, - Timeout: 2, - UpstreamTargets: []string{"httpbin.org"}, - }, - target: "httpbin.org", - existingClients: map[string]*http.Client{ - "httpbin.org": &http.Client{ - Transport: &http.Transport{ - MaxIdleConns: 100, - IdleConnTimeout: 10 * time.Second, - TLSHandshakeTimeout: 5 * time.Second, - }, - }, - }, - expectedTransport: &http.Transport{ - MaxIdleConns: 100, - IdleConnTimeout: 10 * time.Second, - TLSHandshakeTimeout: 5 * time.Second, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - clients := tc.existingClients - clientLock := &sync.Mutex{} - - client, err := getClientForBackendService(tc.backendService, tc.target, clients, clientLock) - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - - transport, ok := client.Transport.(*http.Transport) - if !ok { - t.Errorf("Unexpected transport type: %T", client.Transport) - } - - if !reflect.DeepEqual(transport, tc.expectedTransport) { - t.Errorf("Unexpected transport: got %v, want %v", transport, tc.expectedTransport) - } - }) - } -} - func BenchmarkGatewayHandler(b *testing.B) { bs := &service.BackendService{ Name: "test", @@ -501,21 +425,11 @@ func BenchmarkGatewayHandler(b *testing.B) { }, } - clients := make(map[string]*http.Client) - key := fmt.Sprintf("%s_%s", bs.Name, bs.UpstreamTargets[0]) - clients[key] = &http.Client{Transport: &mockHTTPClient{ - mockResponse: &http.Response{ - StatusCode: http.StatusOK, - Header: make(http.Header), - }, - mockErr: nil, - }} - logger, err := log.NewZapLogger("info") if err != nil { b.Errorf("could not create logger due to: %s", err) } - handler := NewAPIGateway(reg, []plugins.FrontmanPlugin{plugin}, &config.Config{}, clients, logger, &sync.Mutex{}) + handler := NewAPIGateway(reg, []plugins.FrontmanPlugin{plugin}, &config.Config{}, logger) for i := 0; i < b.N; i++ { handler.ServeHTTP(w, req) diff --git a/service/service.go b/service/service.go index f8ab2e7..536560b 100644 --- a/service/service.go +++ b/service/service.go @@ -3,8 +3,8 @@ package service import ( "log" "net/http" - "time" "regexp" + "time" "github.com/Frontman-Labs/frontman/auth" "github.com/Frontman-Labs/frontman/config" @@ -28,13 +28,13 @@ type BackendService struct { AuthConfig *config.AuthConfig `json:"auth,omitempty" yaml:"auth,omitempty"` LoadBalancerPolicy LoadBalancerPolicy `json:"loadBalancerPolicy,omitempty" yaml:"loadBalancerPolicy,omitempty"` RewriteMatch string `json:"rewriteMatch,omitempty" yaml:"rewriteMatch,omitempty"` - RewriteReplace string `json:"rewriteReplace,omitempty" yaml:"rewriteReplace,omitempty"` + RewriteReplace string `json:"rewriteReplace,omitempty" yaml:"rewriteReplace,omitempty"` - httpClient *http.Client - compiledRewriteMatch *regexp.Regexp - loadBalancer loadbalancer.LoadBalancer - provider oauth.OAuthProvider - tokenValidator *auth.TokenValidator + httpClient *http.Client + compiledRewriteMatch *regexp.Regexp + loadBalancer loadbalancer.LoadBalancer + provider oauth.OAuthProvider + tokenValidator *auth.TokenValidator } type LoadBalancerPolicy struct { @@ -66,6 +66,7 @@ func (bs *BackendService) setTokenValidator() { if bs.AuthConfig == nil { return } + validator, err := auth.GetTokenValidator(*bs.AuthConfig) if err != nil { log.Printf("Error adding auth to backend service: %s: %s", bs.Name, err.Error()) @@ -99,6 +100,9 @@ func (bs *BackendService) GetCompiledRewriteMatch() *regexp.Regexp { return bs.compiledRewriteMatch } +func (bs *BackendService) GetHttpClient() *http.Client { + return bs.httpClient +} func (bs *BackendService) setLoadBalancer() { switch bs.LoadBalancerPolicy.Type { @@ -120,24 +124,32 @@ func (bs *BackendService) setLoadBalancer() { // CompilePath compiles the rewrite match regular expression for the backend service and // stores it in the compiledRewriteMatch field. If there's an error while compiling, // the error is returned. -func (bs *BackendService) compilePath() { +func (bs *BackendService) compilePath() { if bs.RewriteMatch == "" || bs.RewriteReplace == "" { - return + return } compiled, err := regexp.Compile(bs.RewriteMatch) if err != nil { - return + return } bs.compiledRewriteMatch = compiled - return } +func (bs *BackendService) setHttpClient() { + transport := &http.Transport{ + MaxIdleConns: bs.MaxIdleConns, + IdleConnTimeout: bs.MaxIdleTime * time.Second, + TLSHandshakeTimeout: bs.Timeout * time.Second, + } + bs.httpClient = &http.Client{Transport: transport} +} func (bs *BackendService) Init() { bs.setTokenValidator() bs.setLoadBalancer() + bs.setHttpClient() bs.compilePath() }