diff --git a/cmd/sso-auth/main.go b/cmd/sso-auth/main.go index 72a4a485..ee5df9d6 100644 --- a/cmd/sso-auth/main.go +++ b/cmd/sso-auth/main.go @@ -6,15 +6,16 @@ import ( "os" "github.com/buzzfeed/sso/internal/auth" - log "github.com/buzzfeed/sso/internal/pkg/logging" + "github.com/buzzfeed/sso/internal/pkg/httpserver" + "github.com/buzzfeed/sso/internal/pkg/logging" ) func init() { - log.SetServiceName("sso-authenticator") + logging.SetServiceName("sso-authenticator") } func main() { - logger := log.NewLogEntry() + logger := logging.NewLogEntry() config, err := auth.LoadConfig() if err != nil { @@ -53,5 +54,7 @@ func main() { Handler: auth.NewLoggingHandler(os.Stdout, timeoutHandler, config.LoggingConfig.Enable, statsdClient), } - logger.Fatal(s.ListenAndServe()) + if err := httpserver.Run(s, config.ServerConfig.TimeoutConfig.Shutdown, logger); err != nil { + logger.WithError(err).Fatal("error running server") + } } diff --git a/cmd/sso-proxy/main.go b/cmd/sso-proxy/main.go index c50e1361..1d924b23 100755 --- a/cmd/sso-proxy/main.go +++ b/cmd/sso-proxy/main.go @@ -8,17 +8,18 @@ import ( "github.com/kelseyhightower/envconfig" - log "github.com/buzzfeed/sso/internal/pkg/logging" + "github.com/buzzfeed/sso/internal/pkg/httpserver" + "github.com/buzzfeed/sso/internal/pkg/logging" "github.com/buzzfeed/sso/internal/proxy" "github.com/buzzfeed/sso/internal/proxy/collector" ) func init() { - log.SetServiceName("sso-proxy") + logging.SetServiceName("sso-proxy") } func main() { - logger := log.NewLogEntry() + logger := logging.NewLogEntry() opts := proxy.NewOptions() err := envconfig.Process("", opts) @@ -58,5 +59,7 @@ func main() { Handler: loggingHandler, } - logger.Fatal(s.ListenAndServe()) + if err := httpserver.Run(s, opts.ShutdownTimeout, logger); err != nil { + logger.WithError(err).Fatal("error running server") + } } diff --git a/docs/sso_authenticator_config.md b/docs/sso_authenticator_config.md index 081c0527..b87fb25f 100644 --- a/docs/sso_authenticator_config.md +++ b/docs/sso_authenticator_config.md @@ -31,12 +31,13 @@ CLIENT_PROXY_SECRET - string - Client secret matching the SSO Proxy client secre ### Server ``` -SERVER_SCHEME - string - scheme the server will use, e.g. `https` -SERVER_HOST - string - host header that's required on incoming requests -SERVER_PORT - string - port the http server listens on -SERVER_TIMEOUT_REQUEST - time.Duration - overall request timeout -SERVER_TIMEOUT_WRITE - time.Duration - write request timeout -SERVER_TIMEOUT_READ - time.Duration - read request timeout +SERVER_SCHEME - string - scheme the server will use, e.g. `https` +SERVER_HOST - string - host header that's required on incoming requests +SERVER_PORT - string - port the http server listens on +SERVER_TIMEOUT_REQUEST - time.Duration - overall request timeout +SERVER_TIMEOUT_WRITE - time.Duration - write request timeout +SERVER_TIMEOUT_READ - time.Duration - read request timeout +SERVER_TIMEOUT_SHUTDOWN - time.Duration - time to allow in-flight requests to complete before server shutdown ``` @@ -69,7 +70,7 @@ PROVIDER_*_TYPE - string - determines the type of provider (supported o PROVIDER_*_SLUG - string - unique provider 'slug' that is used to separate and create routes to individual providers. PROVIDER_*_CLIENT_ID - string - OAuth Client ID PROVIDER_*_CLIENT_SECRET - string - OAuth Client secret -PROVIDER_*_SCOPE - string - OAuth scopes the provider will use. Default standard set of scopes pre-set in individual provider +PROVIDER_*_SCOPE - string - OAuth scopes the provider will use. Default standard set of scopes pre-set in individual provider files; which this configuration variable overrides. ``` diff --git a/go.sum b/go.sum index fb4e90de..b097befb 100644 --- a/go.sum +++ b/go.sum @@ -97,6 +97,7 @@ github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OI github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/gax-go/v2 v2.0.4 h1:hU4mGcQI4DaAYW+IbTun+2qEZVFxK0ySjQLTbS0VQKc= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/gorilla/handlers v1.4.0/go.mod h1:Qkdc/uu4tH4g6mTK6auzZ766c4CA0Ng8+o/OAirnOIQ= github.com/gorilla/mux v1.7.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= diff --git a/internal/auth/configuration.go b/internal/auth/configuration.go index bd111917..bec6c198 100644 --- a/internal/auth/configuration.go +++ b/internal/auth/configuration.go @@ -52,6 +52,7 @@ import ( // SERVER_TIMEOUT_REQUEST // SERVER_TIMEOUT_WRITE // SERVER_TIMEOUT_READ +// SERVER_TIMEOUT_SHUTDOWN // // AUTHORIZE_PROXY_DOMAINS // AUTHORIZE_EMAIL_DOMAINS @@ -79,9 +80,10 @@ func DefaultAuthConfig() Configuration { Port: 4180, Scheme: "https", TimeoutConfig: TimeoutConfig{ - Write: 30 * time.Second, - Read: 30 * time.Second, - Request: 45 * time.Second, + Write: 30 * time.Second, + Read: 30 * time.Second, + Request: 45 * time.Second, + Shutdown: 46 * time.Second, // by default, shutdown timeout matches request timeout + a little headroom }, }, SessionConfig: SessionConfig{ @@ -384,9 +386,10 @@ func (sc ServerConfig) Validate() error { } type TimeoutConfig struct { - Write time.Duration `mapstructure:"write"` - Read time.Duration `mapstructure:"read"` - Request time.Duration `mapstructure:"request"` + Write time.Duration `mapstructure:"write"` + Read time.Duration `mapstructure:"read"` + Request time.Duration `mapstructure:"request"` + Shutdown time.Duration `mapstructure:"shutdown"` } func (tc TimeoutConfig) Validate() error { diff --git a/internal/pkg/httpserver/httpserver.go b/internal/pkg/httpserver/httpserver.go new file mode 100644 index 00000000..c5b4dec8 --- /dev/null +++ b/internal/pkg/httpserver/httpserver.go @@ -0,0 +1,80 @@ +package httpserver + +import ( + "context" + "net" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/buzzfeed/sso/internal/pkg/logging" +) + +// OS signals that will initiate graceful shutdown of the http server. +// +// NOTE: defined in a variable so that they may be overridden by tests. +var shutdownSignals = []os.Signal{ + syscall.SIGINT, + syscall.SIGTERM, +} + +// Run runs an http server and ensures that it is shut down gracefully within +// the given shutdown timeout, allowing all in-flight requests to complete. +// +// Returns an error if a) the server fails to listen on its port or b) the +// shutdown timeout elapses before all in-flight requests are finished. +func Run(srv *http.Server, shutdownTimeout time.Duration, logger *logging.LogEntry) error { + // Logic below copied from the stdlib http.Server ListenAndServe() method: + // https://github.com/golang/go/blob/release-branch.go1.13/src/net/http/server.go#L2805-L2826 + addr := srv.Addr + if addr == "" { + addr = ":http" + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + return runWithListener(ln, srv, shutdownTimeout, logger) +} + +// runWithListener does the heavy lifting for Run() above, and is decoupled +// only for testing purposes +func runWithListener(ln net.Listener, srv *http.Server, shutdownTimeout time.Duration, logger *logging.LogEntry) error { + var ( + // shutdownCh triggers graceful shutdown on SIGINT or SIGTERM + shutdownCh = make(chan os.Signal, 1) + + // exitCh will be closed when it is safe to exit, after graceful shutdown + exitCh = make(chan struct{}) + + // shutdownErr allows any error from srv.Shutdown to propagate out up + // from the goroutine + shutdownErr error + ) + + signal.Notify(shutdownCh, shutdownSignals...) + + go func() { + sig := <-shutdownCh + logger.Info("shutdown started by signal: ", sig) + signal.Stop(shutdownCh) + + logger.Info("waiting for server to shut down in ", shutdownTimeout) + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer cancel() + + shutdownErr = srv.Shutdown(ctx) + close(exitCh) + }() + + if serveErr := srv.Serve(ln); serveErr != nil && serveErr != http.ErrServerClosed { + return serveErr + } + + <-exitCh + logger.Info("shutdown finished") + + return shutdownErr +} diff --git a/internal/pkg/httpserver/httpserver_test.go b/internal/pkg/httpserver/httpserver_test.go new file mode 100644 index 00000000..3ae69151 --- /dev/null +++ b/internal/pkg/httpserver/httpserver_test.go @@ -0,0 +1,148 @@ +package httpserver + +import ( + "fmt" + "net" + "net/http" + "os" + "sync" + "syscall" + "testing" + "time" + + "github.com/buzzfeed/sso/internal/pkg/logging" +) + +func newLocalListener(t *testing.T) net.Listener { + t.Helper() + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to listen on a port: %v", err) + } + return l +} + +func TestGracefulShutdown(t *testing.T) { + proc, err := os.FindProcess(os.Getpid()) + if err != nil { + t.Fatal(err) + } + + // override shutdown signals used by Run for testing purposes + shutdownSignals = []os.Signal{syscall.SIGUSR1} + + logger := logging.NewLogEntry() + + testCases := map[string]struct { + shutdownTimeout time.Duration + requestDelay time.Duration + expectShutdownErr bool + expectRequestErr bool + }{ + "clean shutdown": { + shutdownTimeout: 1 * time.Second, + requestDelay: 250 * time.Millisecond, + expectShutdownErr: false, + expectRequestErr: false, + }, + "timeout elapsed": { + shutdownTimeout: 50 * time.Millisecond, + requestDelay: 250 * time.Millisecond, + expectShutdownErr: true, + + // In real usage, we would expect the request to be aborted when + // the server is shut down and its process exits before it can + // finish responding. + // + // But because we're running the server within the test process, + // which does not exit after shutdown, the goroutine handling the + // long-running request does not seem to get canceled and the + // request ends up completing successfully even after the server + // has shut down. + // + // Properly testing this would require something like re-running + // the test binary as a subprocess to which we can send SIGTERM, + // but doing that would add a lot more complexity (e.g. having it + // bind to a random available port and then running a separate + // subprocess to figure out the port to which it is bound, all in a + // cross-platform way). + // + // If we wanted to go that route, some examples of the general + // approach can be seen here: + // + // - http://cs-guy.com/blog/2015/01/test-main/#toc_3 + // - https://talks.golang.org/2014/testing.slide#23 + expectRequestErr: false, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + var ( + ln = newLocalListener(t) + addr = ln.Addr().String() + url = fmt.Sprintf("http://%s", addr) + ) + + srv := &http.Server{ + Addr: addr, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(tc.requestDelay) + }), + } + + var ( + wg sync.WaitGroup + shutdownErr error + requestErr error + ) + + // Run our server and wait for a stop signal + wg.Add(1) + go func() { + defer wg.Done() + shutdownErr = runWithListener(ln, srv, tc.shutdownTimeout, logger) + }() + + // give the server time to start listening + <-time.After(50 * time.Millisecond) + + // make a request + wg.Add(1) + go func() { + defer wg.Done() + _, requestErr = http.Get(url) + }() + + // give the request some time to connect + <-time.After(1 * time.Millisecond) + + // tell server to shut down gracefully + proc.Signal(syscall.SIGUSR1) + + // wait for server to shut down and requests to complete + wg.Wait() + + if tc.expectShutdownErr { + if shutdownErr == nil { + t.Fatalf("did not get expected shutdown error") + } + } else { + if shutdownErr != nil { + t.Fatalf("got unexpected shutdown error: %s", shutdownErr) + } + } + + if tc.expectRequestErr && requestErr == nil { + if requestErr == nil { + t.Fatalf("did not get expected request error") + } + } else { + if requestErr != nil { + t.Fatalf("got unexpected request error: %s", requestErr) + } + } + }) + } +} diff --git a/internal/proxy/options.go b/internal/proxy/options.go index 06da8533..dab86d1a 100644 --- a/internal/proxy/options.go +++ b/internal/proxy/options.go @@ -49,6 +49,7 @@ import ( // RequestLoging - boolean whether or not to log requests // StatsdHost - host addr for statsd client to listen on // StatsdPort - port for statsdclient to listen on +// ShutdownTimeout - maximum time to wait for in-flight HTTP requests to complete before shutdown type Options struct { Port int `envconfig:"PORT" default:"4180"` @@ -97,6 +98,8 @@ type Options struct { RequestSigningKey string `envconfig:"REQUEST_SIGNATURE_KEY"` + ShutdownTimeout time.Duration `envconfig:"SHUTDOWN_TIMEOUT" default:"30s"` + StatsdClient *statsd.Client // This is an override for supplying template vars at test time