diff --git a/CHANGELOG.md b/CHANGELOG.md index ef1ec46a4b..c96941788b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ All notable changes to `src-cli` are documented in this file. ## Unreleased +### Added + +- Support HTTP(S), SOCKS5, and UNIX Domain Socket proxies via SRC_PROXY environment variable. [#1120](https://github.com/sourcegraph/src-cli/pull/1120) + ## 5.8.1 ### Fixed diff --git a/cmd/src/login.go b/cmd/src/login.go index 5fec84b593..6915cada7f 100644 --- a/cmd/src/login.go +++ b/cmd/src/login.go @@ -81,7 +81,7 @@ func loginCmd(ctx context.Context, cfg *config, client api.Client, endpointArg s if cfg.ConfigFilePath != "" { fmt.Fprintln(out) - fmt.Fprintf(out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT and SRC_ACCESS_TOKEN instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", cfg.ConfigFilePath) + fmt.Fprintf(out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", cfg.ConfigFilePath) } noToken := cfg.AccessToken == "" diff --git a/cmd/src/login_test.go b/cmd/src/login_test.go index 4423d335d8..45403e293b 100644 --- a/cmd/src/login_test.go +++ b/cmd/src/login_test.go @@ -49,7 +49,7 @@ func TestLogin(t *testing.T) { if err != cmderrors.ExitCode1 { t.Fatal(err) } - wantOut := "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT and SRC_ACCESS_TOKEN instead, and then remove f. See https://github.com/sourcegraph/src-cli#readme for more information.\n\n❌ Problem: No access token is configured.\n\n🛠 To fix: Create an access token at https://example.com/user/settings/tokens, then set the following environment variables:\n\n SRC_ENDPOINT=https://example.com\n SRC_ACCESS_TOKEN=(the access token you just created)\n\n To verify that it's working, run this command again." + wantOut := "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove f. See https://github.com/sourcegraph/src-cli#readme for more information.\n\n❌ Problem: No access token is configured.\n\n🛠 To fix: Create an access token at https://example.com/user/settings/tokens, then set the following environment variables:\n\n SRC_ENDPOINT=https://example.com\n SRC_ACCESS_TOKEN=(the access token you just created)\n\n To verify that it's working, run this command again." if out != wantOut { t.Errorf("got output %q, want %q", out, wantOut) } diff --git a/cmd/src/main.go b/cmd/src/main.go index c646824e38..f2dfe270e4 100644 --- a/cmd/src/main.go +++ b/cmd/src/main.go @@ -5,8 +5,9 @@ import ( "flag" "io" "log" + "net" + "net/url" "os" - "os/user" "path/filepath" "strings" @@ -25,6 +26,20 @@ Usage: Environment variables SRC_ACCESS_TOKEN Sourcegraph access token SRC_ENDPOINT endpoint to use, if unset will default to "https://sourcegraph.com" + SRC_PROXY A proxy to use for proxying requests to the Sourcegraph endpoint. + Supports HTTP(S), SOCKS5/5h, and UNIX Domain Socket proxies. + If a UNIX Domain Socket, the path can be either an absolute path, + or can start with ~/ or %USERPROFILE%\ for a path in the user's home directory. + Examples: + - https://localhost:3080 + - https://:localhost:8080 + - socks5h://localhost:1080 + - socks5://:@localhost:1080 + - unix://~/src-proxy.sock + - unix://%USERPROFILE%\src-proxy.sock + - ~/src-proxy.sock + - %USERPROFILE%\src-proxy.sock + - C:\some\path\src-proxy.sock The options are: @@ -83,8 +98,10 @@ type config struct { Endpoint string `json:"endpoint"` AccessToken string `json:"accessToken"` AdditionalHeaders map[string]string `json:"additionalHeaders"` - - ConfigFilePath string + Proxy string `json:"proxy"` + ProxyURL *url.URL + ProxyPath string + ConfigFilePath string } // apiClient returns an api.Client built from the configuration. @@ -95,32 +112,25 @@ func (c *config) apiClient(flags *api.Flags, out io.Writer) api.Client { AdditionalHeaders: c.AdditionalHeaders, Flags: flags, Out: out, + ProxyURL: c.ProxyURL, + ProxyPath: c.ProxyPath, }) } -var testHomeDir string // used by tests to mock the user's $HOME - // readConfig reads the config file from the given path. func readConfig() (*config, error) { - cfgPath := *configPath + cfgFile := *configPath userSpecified := *configPath != "" - var homeDir string - if testHomeDir != "" { - homeDir = testHomeDir - } else { - u, err := user.Current() - if err != nil { - return nil, err - } - homeDir = u.HomeDir + if !userSpecified { + cfgFile = "~/src-config.json" } - if !userSpecified { - cfgPath = filepath.Join(homeDir, "src-config.json") - } else if strings.HasPrefix(cfgPath, "~/") { - cfgPath = filepath.Join(homeDir, cfgPath[2:]) + cfgPath, err := expandHomeDir(cfgFile) + if err != nil { + return nil, err } + data, err := os.ReadFile(os.ExpandEnv(cfgPath)) if err != nil && (!os.IsNotExist(err) || userSpecified) { return nil, err @@ -135,10 +145,12 @@ func readConfig() (*config, error) { envToken := os.Getenv("SRC_ACCESS_TOKEN") envEndpoint := os.Getenv("SRC_ENDPOINT") + envProxy := os.Getenv("SRC_PROXY") if userSpecified { - // If a config file is present, either zero or both environment variables must be present. + // If a config file is present, either zero or both required environment variables must be present. // We don't want to partially apply environment variables. + // Note that SRC_PROXY is optional so we don't test for it. if envToken == "" && envEndpoint != "" { return nil, errConfigMerge } @@ -157,6 +169,60 @@ func readConfig() (*config, error) { if cfg.Endpoint == "" { cfg.Endpoint = "https://sourcegraph.com" } + if envProxy != "" { + cfg.Proxy = envProxy + } + + if cfg.Proxy != "" { + + parseEndpoint := func(endpoint string) (scheme string, address string) { + parts := strings.SplitN(endpoint, "://", 2) + if len(parts) == 2 { + return parts[0], parts[1] + } + return "", endpoint + } + + urlSchemes := []string{"http", "https", "socks", "socks5", "socks5h"} + + isURLScheme := func(scheme string) bool { + for _, s := range urlSchemes { + if scheme == s { + return true + } + } + return false + } + + scheme, address := parseEndpoint(cfg.Proxy) + + if isURLScheme(scheme) { + endpoint := cfg.Proxy + // assume socks means socks5, because that's all we support + if scheme == "socks" { + endpoint = "socks5://" + address + } + cfg.ProxyURL, err = url.Parse(endpoint) + if err != nil { + return nil, err + } + } else if scheme == "" || scheme == "unix" { + path, err := expandHomeDir(address) + if err != nil { + return nil, err + } + isValidUDS, err := isValidUnixSocket(path) + if err != nil { + return nil, errors.Newf("Invalid proxy configuration: %w", err) + } + if !isValidUDS { + return nil, errors.Newf("invalid proxy socket: %s", path) + } + cfg.ProxyPath = path + } else { + return nil, errors.Newf("invalid proxy endpoint: %s", cfg.Proxy) + } + } cfg.AdditionalHeaders = parseAdditionalHeaders() // Ensure that we're not clashing additonal headers @@ -178,3 +244,65 @@ func readConfig() (*config, error) { func cleanEndpoint(urlStr string) string { return strings.TrimSuffix(urlStr, "/") } + +// isValidUnixSocket checks if the given path is a valid Unix socket. +// +// Parameters: +// - path: A string representing the file path to check. +// +// Returns: +// - bool: true if the path is a valid Unix socket, false otherwise. +// - error: nil if the check was successful, or an error if an unexpected issue occurred. +// +// The function attempts to establish a connection to the Unix socket at the given path. +// If the connection succeeds, it's considered a valid Unix socket. +// If the file doesn't exist, it returns false without an error. +// For any other errors, it returns false and the encountered error. +func isValidUnixSocket(path string) (bool, error) { + conn, err := net.Dial("unix", path) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, errors.Newf("Not a UNIX Domain Socket: %v: %w", path, err) + } + defer conn.Close() + + return true, nil +} + +var testHomeDir string // used by tests to mock the user's $HOME + +// expandHomeDir expands to the user's home directory a tilde (~) or %USERPROFILE% at the beginning of a file path. +// +// Parameters: +// - filePath: A string representing the file path that may start with "~/" or "%USERPROFILE%\". +// +// Returns: +// - string: The expanded file path with the home directory resolved. +// - error: An error if the user's home directory cannot be determined. +// +// The function handles both Unix-style paths starting with "~/" and Windows-style paths starting with "%USERPROFILE%\". +// It uses the testHomeDir variable for testing purposes if set, otherwise it uses os.UserHomeDir() to get the user's home directory. +// If the input path doesn't start with either prefix, it returns the original path unchanged. +func expandHomeDir(filePath string) (string, error) { + if strings.HasPrefix(filePath, "~/") || strings.HasPrefix(filePath, "%USERPROFILE%\\") { + var homeDir string + if testHomeDir != "" { + homeDir = testHomeDir + } else { + hd, err := os.UserHomeDir() + if err != nil { + return "", err + } + homeDir = hd + } + + if strings.HasPrefix(filePath, "~/") { + return filepath.Join(homeDir, filePath[2:]), nil + } + return filepath.Join(homeDir, filePath[14:]), nil + } + + return filePath, nil +} diff --git a/cmd/src/main_test.go b/cmd/src/main_test.go index 6b25654e7f..c37c36792a 100644 --- a/cmd/src/main_test.go +++ b/cmd/src/main_test.go @@ -2,14 +2,31 @@ package main import ( "encoding/json" + "net/url" "os" "path/filepath" "testing" "github.com/google/go-cmp/cmp" + + "github.com/sourcegraph/src-cli/internal/api" ) func TestReadConfig(t *testing.T) { + // UNIX Domain Sockets have a max path length: 104 on BSD/macOS, 108 on Linux. + // Including a prefix and suffix was causing the path to be too long + // with t.TempDir() (os.TempDir() is a shorter path) so we don't use them. + socketPath, err := api.CreateTempFile(t.TempDir(), "", "") + if err != nil { + t.Fatal(err) + } + socketServer, err := api.StartUnixSocketServer(socketPath) + if err != nil { + t.Fatal(err) + } + defer socketServer.Stop() + defer os.Remove(socketPath) + tests := []struct { name string fileContents *config @@ -17,6 +34,7 @@ func TestReadConfig(t *testing.T) { envFooHeader string envHeaders string envEndpoint string + envProxy string flagEndpoint string want *config wantErr string @@ -33,11 +51,18 @@ func TestReadConfig(t *testing.T) { fileContents: &config{ Endpoint: "https://example.com/", AccessToken: "deadbeef", + Proxy: "https://proxy.com:8080", }, want: &config{ Endpoint: "https://example.com", AccessToken: "deadbeef", AdditionalHeaders: map[string]string{}, + Proxy: "https://proxy.com:8080", + ProxyPath: "", + ProxyURL: &url.URL{ + Scheme: "https", + Host: "proxy.com:8080", + }, }, }, { @@ -61,16 +86,44 @@ func TestReadConfig(t *testing.T) { wantErr: errConfigMerge.Error(), }, { - name: "config file, both override", + name: "config file, proxy override only (allow)", fileContents: &config{ Endpoint: "https://example.com/", AccessToken: "deadbeef", + Proxy: "https://proxy.com:8080", + }, + envProxy: "socks5://other.proxy.com:9999", + want: &config{ + Endpoint: "https://example.com", + AccessToken: "deadbeef", + Proxy: "socks5://other.proxy.com:9999", + ProxyPath: "", + ProxyURL: &url.URL{ + Scheme: "socks5", + Host: "other.proxy.com:9999", + }, + AdditionalHeaders: map[string]string{}, + }, + }, + { + name: "config file, all override", + fileContents: &config{ + Endpoint: "https://example.com/", + AccessToken: "deadbeef", + Proxy: "https://proxy.com:8080", }, envToken: "abc", envEndpoint: "https://override.com", + envProxy: "socks5://other.proxy.com:9999", want: &config{ - Endpoint: "https://override.com", - AccessToken: "abc", + Endpoint: "https://override.com", + AccessToken: "abc", + Proxy: "socks5://other.proxy.com:9999", + ProxyPath: "", + ProxyURL: &url.URL{ + Scheme: "socks5", + Host: "other.proxy.com:9999", + }, AdditionalHeaders: map[string]string{}, }, }, @@ -93,12 +146,84 @@ func TestReadConfig(t *testing.T) { }, }, { - name: "no config file, both variables", + name: "no config file, proxy from environment", + envProxy: "https://proxy.com:8080", + want: &config{ + Endpoint: "https://sourcegraph.com", + AccessToken: "", + Proxy: "https://proxy.com:8080", + ProxyPath: "", + ProxyURL: &url.URL{ + Scheme: "https", + Host: "proxy.com:8080", + }, + AdditionalHeaders: map[string]string{}, + }, + }, + { + name: "no config file, all variables", envEndpoint: "https://example.com", envToken: "abc", + envProxy: "https://proxy.com:8080", want: &config{ - Endpoint: "https://example.com", - AccessToken: "abc", + Endpoint: "https://example.com", + AccessToken: "abc", + Proxy: "https://proxy.com:8080", + ProxyPath: "", + ProxyURL: &url.URL{ + Scheme: "https", + Host: "proxy.com:8080", + }, + AdditionalHeaders: map[string]string{}, + }, + }, + { + name: "UNIX Domain Socket proxy using scheme and absolute path", + envProxy: "unix://" + socketPath, + want: &config{ + Endpoint: "https://sourcegraph.com", + Proxy: "unix://" + socketPath, + ProxyPath: socketPath, + ProxyURL: nil, + AdditionalHeaders: map[string]string{}, + }, + }, + { + name: "UNIX Domain Socket proxy with absolute path", + envProxy: socketPath, + want: &config{ + Endpoint: "https://sourcegraph.com", + Proxy: socketPath, + ProxyPath: socketPath, + ProxyURL: nil, + AdditionalHeaders: map[string]string{}, + }, + }, + { + name: "socks --> socks5", + envProxy: "socks://localhost:1080", + want: &config{ + Endpoint: "https://sourcegraph.com", + Proxy: "socks://localhost:1080", + ProxyPath: "", + ProxyURL: &url.URL{ + Scheme: "socks5", + Host: "localhost:1080", + }, + AdditionalHeaders: map[string]string{}, + }, + }, + { + name: "socks5h", + envProxy: "socks5h://localhost:1080", + want: &config{ + Endpoint: "https://sourcegraph.com", + Proxy: "socks5h://localhost:1080", + ProxyPath: "", + ProxyURL: &url.URL{ + Scheme: "socks5h", + Host: "localhost:1080", + }, AdditionalHeaders: map[string]string{}, }, }, @@ -171,6 +296,7 @@ func TestReadConfig(t *testing.T) { } setEnv("SRC_ACCESS_TOKEN", test.envToken) setEnv("SRC_ENDPOINT", test.envEndpoint) + setEnv("SRC_PROXY", test.envProxy) tmpDir := t.TempDir() testHomeDir = tmpDir diff --git a/dev/test-proxies.sh b/dev/test-proxies.sh new file mode 100755 index 0000000000..0241f4e0c6 --- /dev/null +++ b/dev/test-proxies.sh @@ -0,0 +1,94 @@ +#!/usr/bin/env bash + +# A test script to help manually test permutations of the proxy settings. +# Before running this script,you'll need five terminal windows: +# One for `socat`, two for `mitmproxy` (non-auth and auth), +# and two for a socks proxy (none-auth and auth). +# If you're on macOS, run `brew install socat mitmproxy`. +# To avoid the need to use insecure TLS settings, import mitmproxy's CA certificate into your keychain. +# On macOS: `sudo security add-trusted-cert -d -p ssl -p basic -k /Library/Keychains/System.keychain ~/.mitmproxy/mitmproxy-ca-cert.pem` +# For a socks proxy, I use the Docker image serjs/go-socks5-proxy. +# Note that `mitmproxy` is not a tunneling proxy - it uses a self-signed certificate +# to authenticate clients, which is good and bad. Good because we can inspect requests coming through +# if we want, and bad because we need to use insecure TLS settings to use it. +# It also works with `socat`, where `tinyproxy` and even `devproxy.sgdev.org` don't. +# Terminal 1: +# `socat -d -d unix-listen:${HOME}/socat-proxy.sock,fork tcp:localhost:8080` +# Terminal 2: +# `mitmproxy -v -p 8080` +# Terminal 3: +# `mitmproxy -v -p 8081 --proxyauth user:pass` +# Terminal 4: +# `docker run --rm -p 1080:1080 serjs/go-socks5-proxy` +# Terminal 5: +# `docker run --rm -p 1081:1080 -e PROXY_USER=user -e PROXY_PASSWORD=pass serjs/go-socks5-proxy` +# when you kick off this script, keep all of the windows in view so you can see that the output +# shows successful connections being made. + +SRC_PATH=${SRC_PATH:-~/go/bin/src} + +export SRC_ENDPOINT=${SRC_ENDPOINT:-https://sourcegraph.com} +export SRC_ACCESS_TOKEN=${SRC_ACCESS_TOKEN} + +socket=~/socat-proxy.sock + +# UNIX Domain Socket test +# You should see connection output in both the `socat` and `mitmproxy` terminals. +echo "UNIX Domain Socket test" +SRC_PROXY=${socket} \ +${SRC_PATH} login + +# HTTP test +# You should see connection output in the `mitmproxy` terminal. +echo "HTTP proxy test" +SRC_PROXY=http://localhost:8080 \ +${SRC_PATH} login + +# HTTPS with auth test +# You should see connection output in the `mitmproxy` with auth terminal. +echo "HTTP proxy with auth test" +SRC_PROXY=http://user:pass@localhost:8081 \ +${SRC_PATH} login + +# HTTPS test +# You should see connection output in the `mitmproxy` terminal. +echo "HTTPS proxy test" +SRC_PROXY=https://localhost:8080 \ +${SRC_PATH} login + +# HTTPS with auth test +# You should see connection output in the `mitmproxy` with auth terminal. +echo "HTTPS proxy with auth test" +SRC_PROXY=https://user:pass@localhost:8081 \ +${SRC_PATH} login + +# SOCKS test +# You should see connection output in the socks terminal. +echo "SOCKS proxy test" +SRC_PROXY=socks5://localhost:1080 \ +${SRC_PATH} login + +# SOCKS with auth test +# You should see connection output in the socks terminal. +echo "SOCKS proxy with auth test" +SRC_PROXY=socks5://user:pass@localhost:1081 \ +${SRC_PATH} login + +# HTTPS using insecure TLS code path +# You should see connection output in the `mitmproxy` terminal. +echo "HTTPS proxy insecure TLS path test" +SRC_PROXY=https://localhost:8080 \ +${SRC_PATH} login --insecure-skip-verify=true + + +# test a search using a proxy +echo "Search test" +SRC_PROXY=https://localhost:8080 \ +${SRC_PATH} search -json 'repo:github.com/sourcegraph/src-cli foobar' + + +# test with the system proxy set to something else +echo "Ignoring system proxy test" +https_proxy=http://localhost:12345 \ +SRC_PROXY=https://localhost:8080 \ +${SRC_PATH} login diff --git a/internal/api/BUILD.bazel b/internal/api/BUILD.bazel index a292172e57..4441a7b3ba 100644 --- a/internal/api/BUILD.bazel +++ b/internal/api/BUILD.bazel @@ -8,6 +8,8 @@ go_library( "flags.go", "gzip.go", "nullable.go", + "proxy.go", + "test_unix_socket_server.go", ], importpath = "github.com/sourcegraph/src-cli/internal/api", visibility = ["//:__subpackages__"], diff --git a/internal/api/api.go b/internal/api/api.go index 25e87f814e..63cef92b50 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "net/http" + "net/url" "os" "runtime" "strings" @@ -81,6 +82,9 @@ type ClientOpts struct { // Out is the writer that will be used when outputting diagnostics, such as // curl commands when -get-curl is enabled. Out io.Writer + + ProxyURL *url.URL + ProxyPath string } // NewClient creates a new API client. @@ -95,9 +99,26 @@ func NewClient(opts ClientOpts) Client { } httpClient := http.DefaultClient + + transport := http.DefaultTransport.(*http.Transport).Clone() + customTransport := false + if flags.insecureSkipVerify != nil && *flags.insecureSkipVerify { + customTransport = true + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{} + } + + if applyProxy(transport, opts.ProxyURL, opts.ProxyPath) { + customTransport = true + } + + if customTransport { httpClient = &http.Client{ - Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + Transport: transport, } } @@ -112,7 +133,6 @@ func NewClient(opts ClientOpts) Client { httpClient: httpClient, } } - func (c *client) NewQuery(query string) Request { return c.NewRequest(query, nil) } diff --git a/internal/api/proxy.go b/internal/api/proxy.go new file mode 100644 index 0000000000..5e5fbf30ef --- /dev/null +++ b/internal/api/proxy.go @@ -0,0 +1,138 @@ +package api + +import ( + "bufio" + "context" + "crypto/tls" + "encoding/base64" + "fmt" + "net" + "net/http" + "net/url" +) + +func applyProxy(transport *http.Transport, proxyURL *url.URL, proxyPath string) (applied bool) { + if proxyURL == nil && proxyPath == "" { + return false + } + + handshakeTLS := func(ctx context.Context, conn net.Conn, addr string) (net.Conn, error) { + // Extract the hostname (without the port) for TLS SNI + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + tlsConn := tls.Client(conn, &tls.Config{ + ServerName: host, + // Pull InsecureSkipVerify from the target host transport + // so that insecure-skip-verify flag settings are honored for the proxy server + InsecureSkipVerify: transport.TLSClientConfig.InsecureSkipVerify, + }) + if err := tlsConn.HandshakeContext(ctx); err != nil { + return nil, err + } + return tlsConn, nil + } + + proxyApplied := false + + if proxyPath != "" { + dial := func(ctx context.Context, _, _ string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, "unix", proxyPath) + } + dialTLS := func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := dial(ctx, network, addr) + if err != nil { + return nil, err + } + return handshakeTLS(ctx, conn, addr) + } + transport.DialContext = dial + transport.DialTLSContext = dialTLS + // clear out any system proxy settings + transport.Proxy = nil + proxyApplied = true + } else if proxyURL != nil { + if proxyURL.Scheme == "socks5" || + proxyURL.Scheme == "socks5h" { + // SOCKS proxies work out of the box - no need to manually dial + transport.Proxy = http.ProxyURL(proxyURL) + proxyApplied = true + } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { + dial := func(ctx context.Context, network, addr string) (net.Conn, error) { + // Dial the proxy + d := net.Dialer{} + conn, err := d.DialContext(ctx, "tcp", proxyURL.Host) + if err != nil { + return nil, err + } + + // this is the whole point of manually dialing the HTTP(S) proxy: + // being able to force HTTP/1. + // When relying on Transport.Proxy, the protocol is always HTTP/2, + // but many proxy servers don't support HTTP/2. + // We don't want to disable HTTP/2 in general because we want to use it when + // connecting to the Sourcegraph API, using HTTP/1 for the proxy connection only. + protocol := "HTTP/1.1" + + // CONNECT is the HTTP method used to set up a tunneling connection with a proxy + method := "CONNECT" + + // Manually writing out the HTTP commands because it's not complicated, + // and http.Request has some janky behavior: + // - ignores the Proto field and hard-codes the protocol to HTTP/1.1 + // - ignores the Host Header (Header.Set("Host", host)) and uses URL.Host instead. + // - When the Host field is set, overrides the URL field + connectReq := fmt.Sprintf("%s %s %s\r\n", method, addr, protocol) + + // A Host header is required per RFC 2616, section 14.23 + connectReq += fmt.Sprintf("Host: %s\r\n", addr) + + // use authentication if proxy credentials are present + if proxyURL.User != nil { + password, _ := proxyURL.User.Password() + auth := base64.StdEncoding.EncodeToString([]byte(proxyURL.User.Username() + ":" + password)) + connectReq += fmt.Sprintf("Proxy-Authorization: Basic %s\r\n", auth) + } + + // finish up with an extra carriage return + newline, as per RFC 7230, section 3 + connectReq += "\r\n" + + // Send the CONNECT request to the proxy to establish the tunnel + if _, err := conn.Write([]byte(connectReq)); err != nil { + conn.Close() + return nil, err + } + + // Read and check the response from the proxy + resp, err := http.ReadResponse(bufio.NewReader(conn), nil) + if err != nil { + conn.Close() + return nil, err + } + if resp.StatusCode != http.StatusOK { + conn.Close() + return nil, fmt.Errorf("failed to connect to proxy %v: %v", proxyURL, resp.Status) + } + resp.Body.Close() + return conn, nil + } + dialTLS := func(ctx context.Context, network, addr string) (net.Conn, error) { + // Dial the underlying connection through the proxy + conn, err := dial(ctx, network, addr) + if err != nil { + return nil, err + } + return handshakeTLS(ctx, conn, addr) + } + transport.DialContext = dial + transport.DialTLSContext = dialTLS + // clear out any system proxy settings + transport.Proxy = nil + proxyApplied = true + } + } + + return proxyApplied +} diff --git a/internal/api/test_unix_socket_server.go b/internal/api/test_unix_socket_server.go new file mode 100644 index 0000000000..c637aed0b9 --- /dev/null +++ b/internal/api/test_unix_socket_server.go @@ -0,0 +1,109 @@ +package api + +import ( + "fmt" + "net" + "os" + "sync" +) + +type Server struct { + Listener net.Listener + StopChan chan struct{} + Wg sync.WaitGroup +} + +func CreateTempFile(dir, prefix, suffix string) (string, error) { + // Get the default temporary directory for the OS + if dir == "" { + dir = os.TempDir() + } + + // Create a temporary file with the specified prefix and suffix + tempFile, err := os.CreateTemp(dir, prefix+"*"+suffix) + if err != nil { + return "", fmt.Errorf("failed to create temp file: %w", err) + } + + // Close the file to release the file descriptor + tempFile.Close() + + // Return the name of the created file + return tempFile.Name(), nil +} + +func createUnixSocket(socketPath string) (net.Listener, error) { + // Clean up the socket file if it already exists + _ = os.Remove(socketPath) + + // Create a Unix domain socket + listener, err := net.Listen("unix", socketPath) + if err != nil { + return nil, fmt.Errorf("failed to create Unix domain socket: %w", err) + } + + return listener, nil +} + +// StartServer starts the server in a goroutine and returns a stop function +func StartUnixSocketServer(socketPath string) (*Server, error) { + listener, err := createUnixSocket(socketPath) + if err != nil { + return nil, err + } + + server := &Server{ + Listener: listener, + StopChan: make(chan struct{}), + } + + server.Wg.Add(1) + go func() { + defer server.Wg.Done() + for { + conn, err := server.Listener.Accept() + if err != nil { + select { + case <-server.StopChan: + return + default: + fmt.Println("Error accepting connection:", err) + continue + } + } + + // Handle each connection in a separate goroutine + server.Wg.Add(1) + go func(conn net.Conn) { + defer server.Wg.Done() + defer conn.Close() + + // Handle incoming data from the connection + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + fmt.Println("Error reading from connection:", err) + return + } + + fmt.Printf("Received: %s\n", string(buf[:n])) + + // Example: Write a response back + _, err = conn.Write([]byte("Hello from server\n")) + if err != nil { + fmt.Println("Error writing to connection:", err) + return + } + }(conn) + } + }() + + return server, nil +} + +// StopServer stops the server and waits for all goroutines to finish +func (s *Server) Stop() { + close(s.StopChan) + s.Listener.Close() + s.Wg.Wait() +}