Skip to content

Commit

Permalink
Merge pull request #111 from amityahav/issue-103
Browse files Browse the repository at this point in the history
  • Loading branch information
amityahav authored Apr 12, 2023
2 parents 32477f6 + 0c6c74f commit 54985ba
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 230 deletions.
10 changes: 5 additions & 5 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,23 +186,23 @@ func validateService(service *service.BackendService) error {
return nil
}

func validateLoadBalancerPolicy(s *service.BackendService) error {
switch s.LoadBalancerPolicy.Type {
func validateLoadBalancerPolicy(bs *service.BackendService) error {
switch bs.LoadBalancerPolicy.Type {
case loadbalancer.Random:
case loadbalancer.LeastConnection:
case loadbalancer.RoundRobin:
case loadbalancer.WeightedRoundRobin, loadbalancer.WeightedLeastConnection:
if len(s.LoadBalancerPolicy.Options.Weights) != len(s.UpstreamTargets) {
if len(bs.LoadBalancerPolicy.Options.Weights) != len(bs.UpstreamTargets) {
return fmt.Errorf("mismatched lengths of weights and targets")
}

for _, w := range s.LoadBalancerPolicy.Options.Weights {
for _, w := range bs.LoadBalancerPolicy.Options.Weights {
if w <= 0 {
return fmt.Errorf("weights must be greater than zero")
}
}
default:
return fmt.Errorf("unknown load-balancer policy: %s", s.LoadBalancerPolicy.Type)
return fmt.Errorf("unknown load-balancer policy: %s", bs.LoadBalancerPolicy.Type)
}

return nil
Expand Down
10 changes: 2 additions & 8 deletions frontman.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@ import (
"context"
"crypto/tls"
"fmt"
"net/http"
"sync"

"github.com/Frontman-Labs/frontman/api"
"github.com/julienschmidt/httprouter"
"net/http"

"github.com/Frontman-Labs/frontman/config"
"github.com/Frontman-Labs/frontman/gateway"
Expand Down Expand Up @@ -52,11 +50,7 @@ func NewFrontman(conf *config.Config, log log.Logger) (*Frontman, error) {
}

// Create new APIGateway instance
clients := make(map[string]*http.Client)
lock := sync.Mutex{}
apiGateway := gateway.NewAPIGateway(serviceRegistry, plug, conf, clients, log, &lock)

go gateway.RefreshConnections(serviceRegistry, clients, &lock)
apiGateway := gateway.NewAPIGateway(serviceRegistry, plug, conf, log)

// Create the Frontman instance
return &Frontman{
Expand Down
134 changes: 17 additions & 117 deletions gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,29 @@ package gateway

import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"

"github.com/Frontman-Labs/frontman/config"
"github.com/Frontman-Labs/frontman/log"
"github.com/Frontman-Labs/frontman/plugins"
"github.com/Frontman-Labs/frontman/service"
"io"
"net/http"
"net/url"
"strings"
)

type APIGateway struct {
reg service.ServiceRegistry
plugs []plugins.FrontmanPlugin
conf *config.Config
clients map[string]*http.Client
clientLock *sync.Mutex
log log.Logger
reg service.ServiceRegistry
plugs []plugins.FrontmanPlugin
conf *config.Config
log log.Logger
}

func NewAPIGateway(bs service.ServiceRegistry, plugs []plugins.FrontmanPlugin, conf *config.Config, clients map[string]*http.Client, logger log.Logger, lock *sync.Mutex) *APIGateway {
func NewAPIGateway(bs service.ServiceRegistry, plugs []plugins.FrontmanPlugin, conf *config.Config, logger log.Logger) *APIGateway {
return &APIGateway{
reg: bs,
plugs: plugs,
conf: conf,
clients: clients,
clientLock: lock,
log: logger,
reg: bs,
plugs: plugs,
conf: conf,
log: logger,
}
}

Expand All @@ -57,15 +49,12 @@ func (g *APIGateway) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Get the upstream target URL for this request
upstreamTarget := backendService.GetLoadBalancer().ChooseTarget(backendService.UpstreamTargets)

var urlPath string

urlPath := req.URL.Path
if backendService.StripPath {
urlPath = strings.TrimPrefix(req.URL.Path, backendService.Path)
} else {
urlPath = req.URL.Path
}

// Use the compiledRegex field in the backendService struct to apply the rewrite
// Use the compiledRewriteMatch field in the backendService struct to apply the rewrite
if backendService.GetCompiledRewriteMatch() != nil {
urlPath = backendService.GetCompiledRewriteMatch().ReplaceAllString(urlPath, backendService.RewriteReplace)
}
Expand All @@ -82,8 +71,8 @@ func (g *APIGateway) ServeHTTP(w http.ResponseWriter, req *http.Request) {
targetURL.RawQuery = req.URL.RawQuery
}

// Get or create a new client for this backend service
client, err := getClientForBackendService(*backendService, backendService.Name, g.clients, g.clientLock)
// Get client for backend service
client := backendService.GetHttpClient()

// Copy the headers from the original request
headers := make(http.Header)
Expand Down Expand Up @@ -156,95 +145,6 @@ func (g *APIGateway) ServeHTTP(w http.ResponseWriter, req *http.Request) {

}


func RefreshConnections(bs service.ServiceRegistry, clients map[string]*http.Client, clientLock *sync.Mutex) {

ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
services := bs.GetServices()

// Remove clients that are no longer needed
clientLock.Lock()
for k := range clients {
found := false
for _, s := range services {
for _, t := range s.UpstreamTargets {
key := fmt.Sprintf("%s_%s", s.Name, t)
if key == k {
found = true
break
}
}
if found {
break
}
}
if !found {
delete(clients, k)
}
}
clientLock.Unlock()

// Add or update clients for each service
for _, s := range services {
for _, t := range s.UpstreamTargets {
clientLock.Lock()
key := fmt.Sprintf("%s_%s", s.Name, t)
_, ok := clients[key]
if !ok {
transport := &http.Transport{
MaxIdleConns: s.MaxIdleConns,
IdleConnTimeout: s.MaxIdleTime * time.Second,
TLSHandshakeTimeout: s.Timeout * time.Second,
}
client := &http.Client{
Transport: transport,
}
clients[key] = client
} else {
clients[key].Transport.(*http.Transport).MaxIdleConns = s.MaxIdleConns
clients[key].Transport.(*http.Transport).IdleConnTimeout = s.MaxIdleTime * time.Second
clients[key].Transport.(*http.Transport).TLSHandshakeTimeout = s.Timeout * time.Second
}
clientLock.Unlock()
}
}

}
}
}

func getClientForBackendService(bs service.BackendService, target string, clients map[string]*http.Client, clientLock *sync.Mutex) (*http.Client, error) {
clientLock.Lock()
defer clientLock.Unlock()

// Check if the client for this target already exists
if client, ok := clients[target]; ok {
return client, nil
}

// Create a new transport with the specified settings
transport := &http.Transport{
MaxIdleConns: bs.MaxIdleConns,
IdleConnTimeout: bs.MaxIdleTime * time.Second,
TLSHandshakeTimeout: bs.Timeout * time.Second,
}

// Create a new HTTP client with the transport
client := &http.Client{
Transport: transport,
}

// Add the client to the map of clients
key := fmt.Sprintf("%s_%s", bs.Name, target)
clients[key] = client

return client, nil
}

func copyHeaders(dst, src http.Header) {
for k, v := range src {
dst[k] = v
Expand Down
92 changes: 3 additions & 89 deletions gateway/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@ package gateway

import (
"context"
"fmt"
"github.com/Frontman-Labs/frontman/loadbalancer"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -255,8 +253,6 @@ func TestGatewayHandler(t *testing.T) {

bs.Init()

clients := make(map[string]*http.Client)

mockClient := &mockHTTPClient{
mockResponse: &http.Response{
StatusCode: tc.expectedStatusCode,
Expand All @@ -265,7 +261,7 @@ func TestGatewayHandler(t *testing.T) {
mockErr: nil,
}

clients[bs.Name] = &http.Client{Transport: mockClient}
bs.GetHttpClient().Transport = mockClient

reg, _ := service.NewServiceRegistry(context.Background(), "memory", nil)
reg.AddService(bs)
Expand All @@ -288,7 +284,7 @@ func TestGatewayHandler(t *testing.T) {
if err != nil {
t.Errorf("could not create logger due to: %s", err)
}
handler := NewAPIGateway(reg, []plugins.FrontmanPlugin{plugin}, &config.Config{}, clients, logger, &sync.Mutex{})
handler := NewAPIGateway(reg, []plugins.FrontmanPlugin{plugin}, &config.Config{}, logger)
handler.ServeHTTP(w, req)

// Check the response status code
Expand Down Expand Up @@ -397,78 +393,6 @@ func TestFindBackendService(t *testing.T) {
}
}

func TestGetClientForBackendService(t *testing.T) {
testCases := []struct {
name string
backendService service.BackendService
target string
existingClients map[string]*http.Client
expectedTransport *http.Transport
}{
{
name: "New client for target",
backendService: service.BackendService{
MaxIdleConns: 100,
MaxIdleTime: 10,
Timeout: 5,
UpstreamTargets: []string{"httpbin.org"},
},
target: "httpbin.org",
existingClients: map[string]*http.Client{},
expectedTransport: &http.Transport{
MaxIdleConns: 100,
IdleConnTimeout: 10 * time.Second,
TLSHandshakeTimeout: 5 * time.Second,
},
},
{
name: "Existing client for target",
backendService: service.BackendService{
MaxIdleConns: 50,
MaxIdleTime: 5,
Timeout: 2,
UpstreamTargets: []string{"httpbin.org"},
},
target: "httpbin.org",
existingClients: map[string]*http.Client{
"httpbin.org": &http.Client{
Transport: &http.Transport{
MaxIdleConns: 100,
IdleConnTimeout: 10 * time.Second,
TLSHandshakeTimeout: 5 * time.Second,
},
},
},
expectedTransport: &http.Transport{
MaxIdleConns: 100,
IdleConnTimeout: 10 * time.Second,
TLSHandshakeTimeout: 5 * time.Second,
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
clients := tc.existingClients
clientLock := &sync.Mutex{}

client, err := getClientForBackendService(tc.backendService, tc.target, clients, clientLock)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

transport, ok := client.Transport.(*http.Transport)
if !ok {
t.Errorf("Unexpected transport type: %T", client.Transport)
}

if !reflect.DeepEqual(transport, tc.expectedTransport) {
t.Errorf("Unexpected transport: got %v, want %v", transport, tc.expectedTransport)
}
})
}
}

func BenchmarkGatewayHandler(b *testing.B) {
bs := &service.BackendService{
Name: "test",
Expand Down Expand Up @@ -501,21 +425,11 @@ func BenchmarkGatewayHandler(b *testing.B) {
},
}

clients := make(map[string]*http.Client)
key := fmt.Sprintf("%s_%s", bs.Name, bs.UpstreamTargets[0])
clients[key] = &http.Client{Transport: &mockHTTPClient{
mockResponse: &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
},
mockErr: nil,
}}

logger, err := log.NewZapLogger("info")
if err != nil {
b.Errorf("could not create logger due to: %s", err)
}
handler := NewAPIGateway(reg, []plugins.FrontmanPlugin{plugin}, &config.Config{}, clients, logger, &sync.Mutex{})
handler := NewAPIGateway(reg, []plugins.FrontmanPlugin{plugin}, &config.Config{}, logger)

for i := 0; i < b.N; i++ {
handler.ServeHTTP(w, req)
Expand Down
Loading

0 comments on commit 54985ba

Please sign in to comment.