Skip to content

Commit

Permalink
fix(proxy): correct proxy application
Browse files Browse the repository at this point in the history
This commit primarily addresses issues with proxy functionality,
ensuring that proxies are correctly applied to requests.
  • Loading branch information
jaxron committed Oct 3, 2024
1 parent c27bafc commit 90d06f7
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ run:
concurrency: 0
allow-parallel-runners: true
allow-serial-runners: true
tests: true
tests: false
go: '1.23'

linters:
Expand Down
28 changes: 14 additions & 14 deletions middleware/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,26 +61,26 @@ func (m *ProxyMiddleware) Process(ctx context.Context, httpClient *http.Client,

m.logger.WithFields(logger.String("proxy", proxy.Host)).Debug("Using Proxy")

// Clone the client to avoid modifying the original because the
// client is shared across requests and unsafe for concurrent use
clonedClient := &http.Client{
Transport: httpClient.Transport,
CheckRedirect: httpClient.CheckRedirect,
Jar: httpClient.Jar,
Timeout: httpClient.Timeout,
}
*httpClient = *clonedClient

// Apply the proxy to the request
transport, ok := http.DefaultTransport.(*http.Transport)
if !ok {
return nil, ErrInvalidTransport
}
transportCopy := transport.Clone()
transportCopy.Proxy = http.ProxyURL(proxy)
transport = transport.Clone()
transport.Proxy = http.ProxyURL(proxy)
transport.OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *http.Request, connectRes *http.Response) error {
m.logger.WithFields(logger.String("proxy", proxyURL.Host)).Debug("Proxy connection established")
return nil
}

// Use the modified transport for this request
ctx = context.WithValue(ctx, http.DefaultTransport, transportCopy)
// Shallow copy the client to avoid modifying the original because
// it's shared across requests and is unsafe for concurrent use
httpClient = &http.Client{
Transport: transport,
CheckRedirect: httpClient.CheckRedirect,
Jar: httpClient.Jar,
Timeout: httpClient.Timeout,
}
}

return next(ctx, httpClient, req)
Expand Down
38 changes: 28 additions & 10 deletions middleware/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ func TestProxyMiddleware(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)

handler := func(ctx context.Context, httpClient *http.Client, req *http.Request) (*http.Response, error) {
transport := ctx.Value(http.DefaultTransport).(*http.Transport)
transport, ok := httpClient.Transport.(*http.Transport)
require.True(t, ok)
assert.NotNil(t, transport.Proxy)
proxyURL, err := transport.Proxy(req)
require.NoError(t, err)
Expand Down Expand Up @@ -58,7 +59,8 @@ func TestProxyMiddleware(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)

handler := func(ctx context.Context, httpClient *http.Client, req *http.Request) (*http.Response, error) {
transport := ctx.Value(http.DefaultTransport).(*http.Transport)
transport, ok := httpClient.Transport.(*http.Transport)
require.True(t, ok)
assert.NotNil(t, transport.Proxy)
proxyURL, err := transport.Proxy(req)
require.NoError(t, err)
Expand All @@ -76,7 +78,8 @@ func TestProxyMiddleware(t *testing.T) {

// Next request should use the new proxy
newHandler := func(ctx context.Context, httpClient *http.Client, req *http.Request) (*http.Response, error) {
transport := ctx.Value(http.DefaultTransport).(*http.Transport)
transport, ok := httpClient.Transport.(*http.Transport)
require.True(t, ok)
assert.NotNil(t, transport.Proxy)
proxyURL, err := transport.Proxy(req)
require.NoError(t, err)
Expand Down Expand Up @@ -110,13 +113,20 @@ func TestProxyMiddleware(t *testing.T) {

req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)

originalClient := &http.Client{}
handler := func(ctx context.Context, httpClient *http.Client, req *http.Request) (*http.Response, error) {
_, ok := ctx.Value(http.DefaultTransport).(*http.Transport)
assert.False(t, ok, "Expected no transport to be set when no proxies are configured")
// When no proxies are set, the client should remain unchanged
assert.Equal(t, originalClient, httpClient)

// If Transport is not nil, ensure it doesn't have a Proxy set
if transport, ok := httpClient.Transport.(*http.Transport); ok {
assert.Nil(t, transport.Proxy)
}

return &http.Response{StatusCode: http.StatusOK}, nil
}

resp, err := middleware.Process(context.Background(), &http.Client{}, req, handler)
resp, err := middleware.Process(context.Background(), originalClient, req, handler)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
})
Expand All @@ -137,7 +147,8 @@ func TestProxyMiddlewareDisable(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)

handler := func(ctx context.Context, httpClient *http.Client, req *http.Request) (*http.Response, error) {
transport := ctx.Value(http.DefaultTransport).(*http.Transport)
transport, ok := httpClient.Transport.(*http.Transport)
require.True(t, ok)
assert.NotNil(t, transport.Proxy)
actualProxyURL, err := transport.Proxy(req)
require.NoError(t, err)
Expand All @@ -155,13 +166,20 @@ func TestProxyMiddlewareDisable(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
ctx := context.WithValue(context.Background(), proxy.KeySkipProxy, true)

originalClient := &http.Client{}
handler := func(ctx context.Context, httpClient *http.Client, req *http.Request) (*http.Response, error) {
_, ok := ctx.Value(http.DefaultTransport).(*http.Transport)
assert.False(t, ok, "Expected no transport to be set when proxy is disabled")
// When proxy is disabled, the client should remain unchanged
assert.Equal(t, originalClient, httpClient)

// If Transport is not nil, ensure it doesn't have a Proxy set
if transport, ok := httpClient.Transport.(*http.Transport); ok {
assert.Nil(t, transport.Proxy)
}

return &http.Response{StatusCode: http.StatusOK}, nil
}

_, err := middleware.Process(ctx, &http.Client{}, req, handler)
_, err := middleware.Process(ctx, originalClient, req, handler)
require.NoError(t, err)
})
}
11 changes: 0 additions & 11 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"reflect"

"github.com/jaxron/axonet/pkg/client/errors"
Expand Down Expand Up @@ -40,16 +39,6 @@ func NewClient(opts ...Option) *Client {
opt(client)
}

// Set up proxy connection logging
if transport, ok := client.defaultHTTPClient.Transport.(*http.Transport); ok {
transport.OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *http.Request, connectRes *http.Response) error {
client.Logger.WithFields(logger.String("proxy", proxyURL.Host)).Debug("Proxy connection established")
return nil
}
} else {
client.Logger.Debug("HTTP client transport is not of type *http.Transport, proxy connection logging not set up")
}

return client
}

Expand Down

0 comments on commit 90d06f7

Please sign in to comment.