Skip to content

Commit

Permalink
Merge pull request #41 from Snawoot/no_sni
Browse files Browse the repository at this point in the history
Major network stack rework
  • Loading branch information
Snawoot authored Mar 23, 2021
2 parents eb2ae46 + af12cee commit f2c7f73
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 291 deletions.
272 changes: 61 additions & 211 deletions handler.go
Original file line number Diff line number Diff line change
@@ -1,152 +1,48 @@
package main

import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"strings"
"time"
)

const BAD_REQ_MSG = "Bad Request\n"

type AuthProvider func() string

type ProxyHandler struct {
auth AuthProvider
upstreamAddr string
tlsName string
logger *CondLogger
dialer *net.Dialer
dialer ContextDialer
httptransport http.RoundTripper
resolver *Resolver
}

func NewProxyHandler(upstream *Endpoint, auth AuthProvider, resolver *Resolver, logger *CondLogger) *ProxyHandler {
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
netaddr := net.JoinHostPort(upstream.Host, fmt.Sprintf("%d", upstream.Port))
func NewProxyHandler(dialer ContextDialer, resolver *Resolver, logger *CondLogger) *ProxyHandler {
dialer = NewRetryDialer(dialer, resolver, logger)
httptransport := &http.Transport{
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
Proxy: http.ProxyURL(upstream.URL()),
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", netaddr)
},
DialContext: dialer.DialContext,
}
return &ProxyHandler{
auth: auth,
upstreamAddr: netaddr,
tlsName: upstream.TLSName,
logger: logger,
dialer: dialer,
httptransport: httptransport,
resolver: resolver,
}
}

func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
s.logger.Info("Request: %v %v %v", req.RemoteAddr, req.Method, req.URL)
if strings.ToUpper(req.Method) == "CONNECT" {
req.Header.Set("Proxy-Authorization", s.auth())
rawreq, err := httputil.DumpRequest(req, false)
if err != nil {
s.logger.Error("Can't dump request: %v", err)
http.Error(wr, "Can't dump request", http.StatusInternalServerError)
return
}

conn, err := s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
if err != nil {
s.logger.Error("Can't dial upstream: %v", err)
http.Error(wr, "Can't dial upstream", http.StatusBadGateway)
return
}
defer conn.Close()

if s.tlsName != "" {
conn = tls.Client(conn, &tls.Config{
ServerName: s.tlsName,
})
defer conn.Close()
}

_, err = conn.Write(rawreq)
if err != nil {
s.logger.Error("Can't write upstream: %v", err)
http.Error(wr, "Can't write upstream", http.StatusBadGateway)
return
}
bufrd := bufio.NewReader(conn)
proxyResp, err := http.ReadResponse(bufrd, req)
responseBytes := make([]byte, 0)
if err != nil {
s.logger.Error("Can't read response from upstream: %v", err)
http.Error(wr, "Can't read response from upstream", http.StatusBadGateway)
return
}

if proxyResp.StatusCode == http.StatusForbidden &&
proxyResp.Header.Get("X-Hola-Error") == "Forbidden Host" {
s.logger.Info("Request %s denied by upstream. Rescuing it with resolve&rewrite workaround.",
req.URL.String())

conn, err = s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
if err != nil {
s.logger.Error("Can't dial upstream: %v", err)
http.Error(wr, "Can't dial upstream", http.StatusBadGateway)
return
}
defer conn.Close()

if s.tlsName != "" {
conn = tls.Client(conn, &tls.Config{
ServerName: s.tlsName,
})
defer conn.Close()
}

err = rewriteConnectReq(req, s.resolver)
if err != nil {
s.logger.Error("Can't rewrite request: %v", err)
http.Error(wr, "Can't rewrite request", http.StatusInternalServerError)
return
}
rawreq, err = httputil.DumpRequest(req, false)
if err != nil {
s.logger.Error("Can't dump request: %v", err)
http.Error(wr, "Can't dump request", http.StatusInternalServerError)
return
}
_, err = conn.Write(rawreq)
if err != nil {
s.logger.Error("Can't write tls upstream: %v", err)
http.Error(wr, "Can't write tls upstream", http.StatusBadGateway)
return
}
} else {
responseBytes, err = httputil.DumpResponse(proxyResp, false)
if err != nil {
s.logger.Error("Can't dump response: %v", err)
http.Error(wr, "Can't dump response", http.StatusInternalServerError)
return
}
buffered := bufrd.Buffered()
if buffered > 0 {
trailer := make([]byte, buffered)
bufrd.Read(trailer)
responseBytes = append(responseBytes, trailer...)
}
}
bufrd = nil
func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) {
ctx := req.Context()
conn, err := s.dialer.DialContext(ctx, "tcp", req.RequestURI)
if err != nil {
s.logger.Error("Can't satisfy CONNECT request: %v", err)
http.Error(wr, "Can't satisfy CONNECT request", http.StatusBadGateway)
return
}

if req.ProtoMajor == 0 || req.ProtoMajor == 1 {
// Upgrade client connection
localconn, _, err := hijack(wr)
if err != nil {
Expand All @@ -156,102 +52,56 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
}
defer localconn.Close()

if len(responseBytes) > 0 {
_, err = localconn.Write(responseBytes)
if err != nil {
return
}
}
// Inform client connection is built
fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor)

proxy(req.Context(), localconn, conn)
} else if req.ProtoMajor == 2 {
wr.Header()["Date"] = nil
wr.WriteHeader(http.StatusOK)
flush(wr)
proxyh2(req.Context(), req.Body, wr, conn)
} else {
delHopHeaders(req.Header)
orig_req := req.Clone(req.Context())
req.RequestURI = ""
req.Header.Set("Proxy-Authorization", s.auth())
resp, err := s.httptransport.RoundTrip(req)
if err != nil {
s.logger.Error("HTTP fetch error: %v", err)
http.Error(wr, "Server Error", http.StatusInternalServerError)
return
}
if resp.StatusCode == http.StatusForbidden &&
resp.Header.Get("X-Hola-Error") == "Forbidden Host" {
s.logger.Info("Request %s denied by upstream. Rescuing it with resolve&tunnel workaround.",
req.URL.String())
resp.Body.Close()

// Prepare tunnel request
proxyReq, err := makeConnReq(orig_req.RequestURI, s.resolver)
if err != nil {
s.logger.Error("Can't rewrite request: %v", err)
http.Error(wr, "Can't rewrite request", http.StatusBadGateway)
return
}
proxyReq.Header.Set("Proxy-Authorization", s.auth())
rawreq, _ := httputil.DumpRequest(proxyReq, false)

// Prepare upstream conn
conn, err := s.dialer.DialContext(req.Context(), "tcp", s.upstreamAddr)
if err != nil {
s.logger.Error("Can't dial tls upstream: %v", err)
http.Error(wr, "Can't dial tls upstream", http.StatusBadGateway)
return
}
defer conn.Close()

if s.tlsName != "" {
conn = tls.Client(conn, &tls.Config{
ServerName: s.tlsName,
})
defer conn.Close()
}

// Send proxy request
_, err = conn.Write(rawreq)
if err != nil {
s.logger.Error("Can't write tls upstream: %v", err)
http.Error(wr, "Can't write tls upstream", http.StatusBadGateway)
return
}
s.logger.Error("Unsupported protocol version: %s", req.Proto)
http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest)
return
}
}

// Read proxy response
bufrd := bufio.NewReader(conn)
proxyResp, err := http.ReadResponse(bufrd, proxyReq)
if err != nil {
s.logger.Error("Can't read response from upstream: %v", err)
http.Error(wr, "Can't read response from upstream", http.StatusBadGateway)
return
}
if proxyResp.StatusCode != http.StatusOK {
delHopHeaders(proxyResp.Header)
copyHeader(wr.Header(), proxyResp.Header)
wr.WriteHeader(proxyResp.StatusCode)
}
func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request) {
req.RequestURI = ""
if req.ProtoMajor == 2 {
req.URL.Scheme = "http" // We can't access :scheme pseudo-header, so assume http
req.URL.Host = req.Host
}
resp, err := s.httptransport.RoundTrip(req)
if err != nil {
s.logger.Error("HTTP fetch error: %v", err)
http.Error(wr, "Server Error", http.StatusInternalServerError)
return
}
defer resp.Body.Close()
s.logger.Info("%v %v %v %v", req.RemoteAddr, req.Method, req.URL, resp.Status)
delHopHeaders(resp.Header)
copyHeader(wr.Header(), resp.Header)
wr.WriteHeader(resp.StatusCode)
flush(wr)
copyBody(wr, resp.Body)
}

// Send tunneled request
orig_req.RequestURI = ""
orig_req.Header.Set("Connection", "close")
rawreq, _ = httputil.DumpRequest(orig_req, false)
_, err = conn.Write(rawreq)
if err != nil {
s.logger.Error("Can't write tls upstream: %v", err)
http.Error(wr, "Can't write tls upstream", http.StatusBadGateway)
return
}
func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
s.logger.Info("Request: %v %v %v %v", req.RemoteAddr, req.Proto, req.Method, req.URL)

// Read tunneled response
resp, err = http.ReadResponse(bufrd, orig_req)
if err != nil {
s.logger.Error("Can't read response from upstream: %v", err)
http.Error(wr, "Can't read response from upstream", http.StatusBadGateway)
return
}
}
defer resp.Body.Close()
s.logger.Info("%v %v %v %v", req.RemoteAddr, req.Method, req.URL, resp.Status)
delHopHeaders(resp.Header)
copyHeader(wr.Header(), resp.Header)
wr.WriteHeader(resp.StatusCode)
io.Copy(wr, resp.Body)
isConnect := strings.ToUpper(req.Method) == "CONNECT"
if (req.URL.Host == "" || req.URL.Scheme == "" && !isConnect) && req.ProtoMajor < 2 ||
req.Host == "" && req.ProtoMajor == 2 {
http.Error(wr, BAD_REQ_MSG, http.StatusBadRequest)
return
}
delHopHeaders(req.Header)
if isConnect {
s.HandleTunnel(wr, req)
} else {
s.HandleRequest(wr, req)
}
}
21 changes: 12 additions & 9 deletions holaapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,14 @@ func (a *FallbackAgent) ToProxy() *url.URL {
}
}

func (a *FallbackAgent) Hostname() string {
return a.Name + AGENT_SUFFIX
}

func (a *FallbackAgent) NetAddr() string {
return net.JoinHostPort(a.IP, fmt.Sprintf("%d", a.Port))
}

func do_req(ctx context.Context, client *http.Client, method, url string, query, data url.Values) ([]byte, error) {
var (
req *http.Request
Expand Down Expand Up @@ -328,19 +336,14 @@ func httpClientWithProxy(agent *FallbackAgent) *http.Client {
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
dialer := &net.Dialer{
var dialer ContextDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
if agent == nil {
t.DialContext = dialer.DialContext
} else {
t.Proxy = http.ProxyURL(agent.ToProxy())
addr := net.JoinHostPort(agent.IP, fmt.Sprintf("%d", agent.Port))
t.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", addr)
}
if agent != nil {
dialer = NewProxyDialer(agent.NetAddr(), agent.Hostname(), nil, dialer)
}
t.DialContext = dialer.DialContext
return &http.Client{
Transport: t,
}
Expand Down
11 changes: 7 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@ import (
"flag"
"fmt"
"log"
"os"
// "os/signal"
// "syscall"
"net"
"net/http"
"os"
"time"
)

Expand Down Expand Up @@ -126,9 +125,13 @@ func run() int {
logWriter.Close()
return 5
}
var dialer ContextDialer = NewProxyDialer(endpoint.NetAddr(), endpoint.TLSName, auth, &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
})
mainLogger.Info("Endpoint: %s", endpoint.URL().String())
mainLogger.Info("Starting proxy server...")
handler := NewProxyHandler(endpoint, auth, resolver, proxyLogger)
handler := NewProxyHandler(dialer, resolver, proxyLogger)
mainLogger.Info("Init complete.")
err = http.ListenAndServe(args.bind_address, handler)
mainLogger.Critical("Server terminated with a reason: %v", err)
Expand Down
Loading

0 comments on commit f2c7f73

Please sign in to comment.