diff --git a/cli.go b/cli.go index b3c1aa757..dea33ccd8 100644 --- a/cli.go +++ b/cli.go @@ -16,10 +16,12 @@ limitations under the License. package main import ( + "errors" "fmt" "os" "os/signal" "reflect" + "strings" "syscall" "time" @@ -220,6 +222,42 @@ func parseCLIOptions(cx *cli.Context, config *Config) (err error) { config.Resources = append(config.Resources, resource) } } + if cx.IsSet("upstream-url-paths") { + for _, x := range cx.StringSlice("upstream-url-paths") { + path, err := cliParseUpstreamURLPath(x) + if err != nil { + return fmt.Errorf("invalid upstream-url-paths %s, %s", x, err) + } + config.UpstreamPaths = append(config.UpstreamPaths, path) + } + } return nil } + +func cliParseUpstreamURLPath(resource string) (r UpstreamURLPath, err error) { + if resource == "" { + return r, errors.New("no value given") + } + for _, x := range strings.Split(resource, "|") { + kp := strings.Split(x, "=") + if len(kp) != 2 { + return r, errors.New("config pair, should be (uri|upstream-url)=value") + } + switch kp[0] { + case "uri": + r.URL = kp[1] + case "upstream-url": + r.Upstream = kp[1] + default: + return r, fmt.Errorf("invalid identifier '%s', should be uri or upstream-url", kp[0]) + } + } + if r.URL == "" { + return r, errors.New("uri config missing") + } + if r.Upstream == "" { + return r, errors.New("upstream-url config missing") + } + return r, err +} diff --git a/config.go b/config.go index 080af6ad3..6372c9d98 100644 --- a/config.go +++ b/config.go @@ -142,6 +142,15 @@ func (r *Config) isValid() error { if r.Upstream == "" { return errors.New("you have not specified an upstream endpoint to proxy to") } + + if len(r.UpstreamPaths) > 0 { + for _, p := range r.UpstreamPaths { + if _, err := url.Parse(p.Upstream); err != nil { + return fmt.Errorf("the upstream endpoint `%s` is invalid, %s", p, err) + } + } + } + if _, err := url.Parse(r.Upstream); err != nil { return fmt.Errorf("the upstream endpoint is invalid, %s", err) } diff --git a/doc.go b/doc.go index 97ea81d54..ce06853f0 100644 --- a/doc.go +++ b/doc.go @@ -136,6 +136,13 @@ var ( ErrDecryption = errors.New("failed to decrypt token") ) +type UpstreamURLPath struct { + // URL the url for the resource + URL string `json:"uri" yaml:"uri"` + // Upstream is the upstream endpoint i.e whom were proxying to + Upstream string `json:"upstream-url" yaml:"upstream-url"` +} + // Resource represents a url resource to protect type Resource struct { // URL the url for the resource @@ -184,6 +191,8 @@ type Config struct { Scopes []string `json:"scopes" yaml:"scopes" usage:"list of scopes requested when authenticating the user"` // Upstream is the upstream endpoint i.e whom were proxying to Upstream string `json:"upstream-url" yaml:"upstream-url" usage:"url for the upstream endpoint you wish to proxy" env:"UPSTREAM_URL"` + // Resources is a list of protected resources + UpstreamPaths []UpstreamURLPath `json:"upstream-url-paths" yaml:"upstream-url-paths" usage:"list of upstream url paths 'uri=/admin*|upstream-url=http://server1|uri=/data*|upstream-url=http://server2:8080'"` // UpstreamCA is the path to a CA certificate in PEM format to validate the upstream certificate UpstreamCA string `json:"upstream-ca" yaml:"upstream-ca" usage:"the path to a file container a CA certificate to validate the upstream tls endpoint"` // Resources is a list of protected resources diff --git a/forwarding.go b/forwarding.go index 928498575..473eff9ac 100644 --- a/forwarding.go +++ b/forwarding.go @@ -18,6 +18,7 @@ package main import ( "fmt" "net/http" + "net/url" "time" "github.com/coreos/go-oidc/jose" @@ -53,19 +54,25 @@ func (r *oauthProxy) proxyMiddleware(next http.Handler) http.Handler { req.Header.Set(k, v) } + r.upstream.ServeHTTP(w, req) + }) +} + +func (r *oauthProxy) forwardToUpstream(upstreamUrl *url.URL, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { // @note: by default goproxy only provides a forwarding proxy, thus all requests have to be absolute and we must update the host headers - req.URL.Host = r.endpoint.Host - req.URL.Scheme = r.endpoint.Scheme + req.URL.Host = upstreamUrl.Host + req.URL.Scheme = upstreamUrl.Scheme if v := req.Header.Get("Host"); v != "" { req.Host = v req.Header.Del("Host") } else if !r.config.PreserveHost { - req.Host = r.endpoint.Host + req.Host = upstreamUrl.Host } if isUpgradedConnection(req) { r.log.Debug("upgrading the connnection", zap.String("client_ip", req.RemoteAddr)) - if err := tryUpdateConnection(req, w, r.endpoint); err != nil { + if err := tryUpdateConnection(req, w, upstreamUrl); err != nil { r.log.Error("failed to upgrade connection", zap.Error(err)) w.WriteHeader(http.StatusInternalServerError) return @@ -73,7 +80,7 @@ func (r *oauthProxy) proxyMiddleware(next http.Handler) http.Handler { return } - r.upstream.ServeHTTP(w, req) + next.ServeHTTP(w, req) }) } diff --git a/server.go b/server.go index b3c0c4874..0416f16a3 100644 --- a/server.go +++ b/server.go @@ -155,7 +155,7 @@ func createLogger(config *Config) (*zap.Logger, error) { // createReverseProxy creates a reverse proxy func (r *oauthProxy) createReverseProxy() error { r.log.Info("enabled reverse proxy mode, upstream url", zap.String("url", r.config.Upstream)) - if err := r.createUpstreamProxy(r.endpoint); err != nil { + if err := r.createDefaultUpstreamProxy(); err != nil { return err } engine := chi.NewRouter() @@ -293,14 +293,15 @@ func (r *oauthProxy) createForwardingProxy() error { if r.config.SkipUpstreamTLSVerify { r.log.Warn("tls verification switched off. In forward signing mode it's recommended you verify! (--skip-upstream-tls-verify=false)") } - if err := r.createUpstreamProxy(nil); err != nil { + + proxy, err := r.createUpstreamProxy(nil) + if err != nil { return err } //nolint:bodyclose forwardingHandler := r.forwardProxyHandler() // set the http handler - proxy := r.upstream.(*goproxy.ProxyHttpServer) r.router = proxy // setup the tls configuration @@ -553,8 +554,47 @@ func (r *oauthProxy) createHTTPListener(config listenerConfig) (net.Listener, er return listener, nil } +func (r *oauthProxy) createDefaultUpstreamProxy() error { + defaultUpstream, err := r.createUpstreamProxy(r.endpoint) + if err != nil { + return err + } + + if len(r.config.UpstreamPaths) > 0 { + engine := chi.NewRouter() + + for _, x := range r.config.UpstreamPaths { + path := x + fmt.Printf("%s => %s\n", path.URL, path.Upstream) + upstreamUrl, err := url.Parse(path.Upstream) + if err != nil { + return err + } + + proxy, err := r.createUpstreamProxy(upstreamUrl) + if err != nil { + return err + } + + engine.Mount(path.URL, r.forwardToUpstream(upstreamUrl, proxy)) + + //engine.Mount(path.URL, http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + // fmt.Printf("hit %s => %s\n", path.URL, u) + // proxy.ServeHTTP(writer, request) + //})) + } + + engine.NotFound(r.forwardToUpstream(r.endpoint, defaultUpstream).ServeHTTP) + + r.upstream = engine + } else { + r.upstream = r.forwardToUpstream(r.endpoint, defaultUpstream) + } + return nil +} + // createUpstreamProxy create a reverse http proxy from the upstream -func (r *oauthProxy) createUpstreamProxy(upstream *url.URL) error { +func (r *oauthProxy) createUpstreamProxy(upstream *url.URL) (*goproxy.ProxyHttpServer, error) { dialer := (&net.Dialer{ KeepAlive: r.config.UpstreamKeepaliveTimeout, Timeout: r.config.UpstreamTimeout, @@ -583,7 +623,7 @@ func (r *oauthProxy) createUpstreamProxy(upstream *url.URL) error { cert, err := ioutil.ReadFile(r.config.TLSClientCertificate) if err != nil { r.log.Error("unable to read client certificate", zap.String("path", r.config.TLSClientCertificate), zap.Error(err)) - return err + return nil, err } pool := x509.NewCertPool() pool.AppendCertsFromPEM(cert) @@ -597,7 +637,7 @@ func (r *oauthProxy) createUpstreamProxy(upstream *url.URL) error { r.log.Info("loading the upstream ca", zap.String("path", r.config.UpstreamCA)) ca, err := ioutil.ReadFile(r.config.UpstreamCA) if err != nil { - return err + return nil, err } pool := x509.NewCertPool() pool.AppendCertsFromPEM(ca) @@ -614,10 +654,9 @@ func (r *oauthProxy) createUpstreamProxy(upstream *url.URL) error { proxy.KeepDestinationHeaders = true proxy.Logger = httplog.New(ioutil.Discard, "", 0) proxy.KeepDestinationHeaders = true - r.upstream = proxy // update the tls configuration of the reverse proxy - r.upstream.(*goproxy.ProxyHttpServer).Tr = &http.Transport{ + proxy.Tr = &http.Transport{ Dial: dialer, DisableKeepAlives: !r.config.UpstreamKeepalives, ExpectContinueTimeout: r.config.UpstreamExpectContinueTimeout, @@ -627,8 +666,7 @@ func (r *oauthProxy) createUpstreamProxy(upstream *url.URL) error { MaxIdleConns: r.config.MaxIdleConns, MaxIdleConnsPerHost: r.config.MaxIdleConnsPerHost, } - - return nil + return proxy, nil } // createTemplates loads the custom template diff --git a/server_upstream_paths_test.go b/server_upstream_paths_test.go new file mode 100644 index 000000000..2b36eda3b --- /dev/null +++ b/server_upstream_paths_test.go @@ -0,0 +1,99 @@ +/* +Copyright 2015 All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "fmt" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" +) + +// fakeUpstreamService acts as a fake upstream service, returns the headers and request +type counterService struct { + Name string + HitCounter int64 +} + +func (f *counterService) ServeHTTP(w http.ResponseWriter, r *http.Request) { + fmt.Printf("counter %s\n", f.Name) + atomic.AddInt64(&f.HitCounter, 1) +} + +//TestWebSocket is used to validate that the proxy reverse proxy WebSocket connections. +func TestUpstreamPaths(t *testing.T) { + + // Setup an upstream service. + defaultSvc := &counterService{Name: "default"} + defaultSvcServer := httptest.NewServer(defaultSvc) + defer defaultSvcServer.Close() + + adminSvc := &counterService{Name: "admin"} + adminSvcServer := httptest.NewServer(adminSvc) + defer adminSvcServer.Close() + + dataSvc := &counterService{Name: "data"} + dataSvcServer := httptest.NewServer(dataSvc) + defer dataSvcServer.Close() + + counters := func() []int64 { + return []int64{ + atomic.AddInt64(&defaultSvc.HitCounter, 0), + atomic.AddInt64(&adminSvc.HitCounter, 0), + atomic.AddInt64(&dataSvc.HitCounter, 0), + } + } + + // Setup the proxy. + config := newFakeKeycloakConfig() + config.Upstream = defaultSvcServer.URL + config.UpstreamPaths = []UpstreamURLPath{ + { + URL: "/auth_all/white_listed/admin", + Upstream: adminSvcServer.URL, + }, + { + URL: "/auth_all/white_listed/data", + Upstream: dataSvcServer.URL, + }, + } + + auth := newFakeAuthServer() + if config == nil { + config = newFakeKeycloakConfig() + } + config.DiscoveryURL = auth.getLocation() + config.RevocationEndpoint = auth.getRevocationURL() + + proxy, err := newProxy(config) + require.NoError(t, err) + + proxyServer := httptest.NewServer(proxy.router) + defer proxyServer.Close() + + http.Get(proxyServer.URL + "/auth_all/white_listed/admin") + require.Equal(t, []int64{0, 1, 0}, counters()) + + http.Get(proxyServer.URL + "/auth_all/white_listed/other") + require.Equal(t, []int64{1, 1, 0}, counters()) + + http.Get(proxyServer.URL + "/auth_all/white_listed/data") + require.Equal(t, []int64{1, 1, 1}, counters()) + +}