Skip to content

Commit

Permalink
Merge pull request #262 from buzzfeed/graceful-shutdown
Browse files Browse the repository at this point in the history
cmd: ensure http servers shut down gracefully
  • Loading branch information
mccutchen authored Nov 4, 2019
2 parents b76969f + 5e4a232 commit a29ba90
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 21 deletions.
11 changes: 7 additions & 4 deletions cmd/sso-auth/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@ import (
"os"

"github.com/buzzfeed/sso/internal/auth"
log "github.com/buzzfeed/sso/internal/pkg/logging"
"github.com/buzzfeed/sso/internal/pkg/httpserver"
"github.com/buzzfeed/sso/internal/pkg/logging"
)

func init() {
log.SetServiceName("sso-authenticator")
logging.SetServiceName("sso-authenticator")
}

func main() {
logger := log.NewLogEntry()
logger := logging.NewLogEntry()

config, err := auth.LoadConfig()
if err != nil {
Expand Down Expand Up @@ -53,5 +54,7 @@ func main() {
Handler: auth.NewLoggingHandler(os.Stdout, timeoutHandler, config.LoggingConfig.Enable, statsdClient),
}

logger.Fatal(s.ListenAndServe())
if err := httpserver.Run(s, config.ServerConfig.TimeoutConfig.Shutdown, logger); err != nil {
logger.WithError(err).Fatal("error running server")
}
}
11 changes: 7 additions & 4 deletions cmd/sso-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@ import (

"github.com/kelseyhightower/envconfig"

log "github.com/buzzfeed/sso/internal/pkg/logging"
"github.com/buzzfeed/sso/internal/pkg/httpserver"
"github.com/buzzfeed/sso/internal/pkg/logging"
"github.com/buzzfeed/sso/internal/proxy"
"github.com/buzzfeed/sso/internal/proxy/collector"
)

func init() {
log.SetServiceName("sso-proxy")
logging.SetServiceName("sso-proxy")
}

func main() {
logger := log.NewLogEntry()
logger := logging.NewLogEntry()

opts := proxy.NewOptions()
err := envconfig.Process("", opts)
Expand Down Expand Up @@ -58,5 +59,7 @@ func main() {
Handler: loggingHandler,
}

logger.Fatal(s.ListenAndServe())
if err := httpserver.Run(s, opts.ShutdownTimeout, logger); err != nil {
logger.WithError(err).Fatal("error running server")
}
}
15 changes: 8 additions & 7 deletions docs/sso_authenticator_config.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ CLIENT_PROXY_SECRET - string - Client secret matching the SSO Proxy client secre

### Server
```
SERVER_SCHEME - string - scheme the server will use, e.g. `https`
SERVER_HOST - string - host header that's required on incoming requests
SERVER_PORT - string - port the http server listens on
SERVER_TIMEOUT_REQUEST - time.Duration - overall request timeout
SERVER_TIMEOUT_WRITE - time.Duration - write request timeout
SERVER_TIMEOUT_READ - time.Duration - read request timeout
SERVER_SCHEME - string - scheme the server will use, e.g. `https`
SERVER_HOST - string - host header that's required on incoming requests
SERVER_PORT - string - port the http server listens on
SERVER_TIMEOUT_REQUEST - time.Duration - overall request timeout
SERVER_TIMEOUT_WRITE - time.Duration - write request timeout
SERVER_TIMEOUT_READ - time.Duration - read request timeout
SERVER_TIMEOUT_SHUTDOWN - time.Duration - time to allow in-flight requests to complete before server shutdown
```


Expand Down Expand Up @@ -69,7 +70,7 @@ PROVIDER_*_TYPE - string - determines the type of provider (supported o
PROVIDER_*_SLUG - string - unique provider 'slug' that is used to separate and create routes to individual providers.
PROVIDER_*_CLIENT_ID - string - OAuth Client ID
PROVIDER_*_CLIENT_SECRET - string - OAuth Client secret
PROVIDER_*_SCOPE - string - OAuth scopes the provider will use. Default standard set of scopes pre-set in individual provider
PROVIDER_*_SCOPE - string - OAuth scopes the provider will use. Default standard set of scopes pre-set in individual provider
files; which this configuration variable overrides.
```

Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OI
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4 h1:hU4mGcQI4DaAYW+IbTun+2qEZVFxK0ySjQLTbS0VQKc=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/gorilla/handlers v1.4.0/go.mod h1:Qkdc/uu4tH4g6mTK6auzZ766c4CA0Ng8+o/OAirnOIQ=
github.com/gorilla/mux v1.7.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
Expand Down
15 changes: 9 additions & 6 deletions internal/auth/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import (
// SERVER_TIMEOUT_REQUEST
// SERVER_TIMEOUT_WRITE
// SERVER_TIMEOUT_READ
// SERVER_TIMEOUT_SHUTDOWN
//
// AUTHORIZE_PROXY_DOMAINS
// AUTHORIZE_EMAIL_DOMAINS
Expand Down Expand Up @@ -79,9 +80,10 @@ func DefaultAuthConfig() Configuration {
Port: 4180,
Scheme: "https",
TimeoutConfig: TimeoutConfig{
Write: 30 * time.Second,
Read: 30 * time.Second,
Request: 45 * time.Second,
Write: 30 * time.Second,
Read: 30 * time.Second,
Request: 45 * time.Second,
Shutdown: 46 * time.Second, // by default, shutdown timeout matches request timeout + a little headroom
},
},
SessionConfig: SessionConfig{
Expand Down Expand Up @@ -384,9 +386,10 @@ func (sc ServerConfig) Validate() error {
}

type TimeoutConfig struct {
Write time.Duration `mapstructure:"write"`
Read time.Duration `mapstructure:"read"`
Request time.Duration `mapstructure:"request"`
Write time.Duration `mapstructure:"write"`
Read time.Duration `mapstructure:"read"`
Request time.Duration `mapstructure:"request"`
Shutdown time.Duration `mapstructure:"shutdown"`
}

func (tc TimeoutConfig) Validate() error {
Expand Down
80 changes: 80 additions & 0 deletions internal/pkg/httpserver/httpserver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package httpserver

import (
"context"
"net"
"net/http"
"os"
"os/signal"
"syscall"
"time"

"github.com/buzzfeed/sso/internal/pkg/logging"
)

// OS signals that will initiate graceful shutdown of the http server.
//
// NOTE: defined in a variable so that they may be overridden by tests.
var shutdownSignals = []os.Signal{
syscall.SIGINT,
syscall.SIGTERM,
}

// Run runs an http server and ensures that it is shut down gracefully within
// the given shutdown timeout, allowing all in-flight requests to complete.
//
// Returns an error if a) the server fails to listen on its port or b) the
// shutdown timeout elapses before all in-flight requests are finished.
func Run(srv *http.Server, shutdownTimeout time.Duration, logger *logging.LogEntry) error {
// Logic below copied from the stdlib http.Server ListenAndServe() method:
// https://github.com/golang/go/blob/release-branch.go1.13/src/net/http/server.go#L2805-L2826
addr := srv.Addr
if addr == "" {
addr = ":http"
}
ln, err := net.Listen("tcp", addr)
if err != nil {
return err
}
return runWithListener(ln, srv, shutdownTimeout, logger)
}

// runWithListener does the heavy lifting for Run() above, and is decoupled
// only for testing purposes
func runWithListener(ln net.Listener, srv *http.Server, shutdownTimeout time.Duration, logger *logging.LogEntry) error {
var (
// shutdownCh triggers graceful shutdown on SIGINT or SIGTERM
shutdownCh = make(chan os.Signal, 1)

// exitCh will be closed when it is safe to exit, after graceful shutdown
exitCh = make(chan struct{})

// shutdownErr allows any error from srv.Shutdown to propagate out up
// from the goroutine
shutdownErr error
)

signal.Notify(shutdownCh, shutdownSignals...)

go func() {
sig := <-shutdownCh
logger.Info("shutdown started by signal: ", sig)
signal.Stop(shutdownCh)

logger.Info("waiting for server to shut down in ", shutdownTimeout)
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer cancel()

shutdownErr = srv.Shutdown(ctx)
close(exitCh)
}()

if serveErr := srv.Serve(ln); serveErr != nil && serveErr != http.ErrServerClosed {
return serveErr
}

<-exitCh
logger.Info("shutdown finished")

return shutdownErr
}
148 changes: 148 additions & 0 deletions internal/pkg/httpserver/httpserver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package httpserver

import (
"fmt"
"net"
"net/http"
"os"
"sync"
"syscall"
"testing"
"time"

"github.com/buzzfeed/sso/internal/pkg/logging"
)

func newLocalListener(t *testing.T) net.Listener {
t.Helper()

l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("failed to listen on a port: %v", err)
}
return l
}

func TestGracefulShutdown(t *testing.T) {
proc, err := os.FindProcess(os.Getpid())
if err != nil {
t.Fatal(err)
}

// override shutdown signals used by Run for testing purposes
shutdownSignals = []os.Signal{syscall.SIGUSR1}

logger := logging.NewLogEntry()

testCases := map[string]struct {
shutdownTimeout time.Duration
requestDelay time.Duration
expectShutdownErr bool
expectRequestErr bool
}{
"clean shutdown": {
shutdownTimeout: 1 * time.Second,
requestDelay: 250 * time.Millisecond,
expectShutdownErr: false,
expectRequestErr: false,
},
"timeout elapsed": {
shutdownTimeout: 50 * time.Millisecond,
requestDelay: 250 * time.Millisecond,
expectShutdownErr: true,

// In real usage, we would expect the request to be aborted when
// the server is shut down and its process exits before it can
// finish responding.
//
// But because we're running the server within the test process,
// which does not exit after shutdown, the goroutine handling the
// long-running request does not seem to get canceled and the
// request ends up completing successfully even after the server
// has shut down.
//
// Properly testing this would require something like re-running
// the test binary as a subprocess to which we can send SIGTERM,
// but doing that would add a lot more complexity (e.g. having it
// bind to a random available port and then running a separate
// subprocess to figure out the port to which it is bound, all in a
// cross-platform way).
//
// If we wanted to go that route, some examples of the general
// approach can be seen here:
//
// - http://cs-guy.com/blog/2015/01/test-main/#toc_3
// - https://talks.golang.org/2014/testing.slide#23
expectRequestErr: false,
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
var (
ln = newLocalListener(t)
addr = ln.Addr().String()
url = fmt.Sprintf("http://%s", addr)
)

srv := &http.Server{
Addr: addr,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(tc.requestDelay)
}),
}

var (
wg sync.WaitGroup
shutdownErr error
requestErr error
)

// Run our server and wait for a stop signal
wg.Add(1)
go func() {
defer wg.Done()
shutdownErr = runWithListener(ln, srv, tc.shutdownTimeout, logger)
}()

// give the server time to start listening
<-time.After(50 * time.Millisecond)

// make a request
wg.Add(1)
go func() {
defer wg.Done()
_, requestErr = http.Get(url)
}()

// give the request some time to connect
<-time.After(1 * time.Millisecond)

// tell server to shut down gracefully
proc.Signal(syscall.SIGUSR1)

// wait for server to shut down and requests to complete
wg.Wait()

if tc.expectShutdownErr {
if shutdownErr == nil {
t.Fatalf("did not get expected shutdown error")
}
} else {
if shutdownErr != nil {
t.Fatalf("got unexpected shutdown error: %s", shutdownErr)
}
}

if tc.expectRequestErr && requestErr == nil {
if requestErr == nil {
t.Fatalf("did not get expected request error")
}
} else {
if requestErr != nil {
t.Fatalf("got unexpected request error: %s", requestErr)
}
}
})
}
}
3 changes: 3 additions & 0 deletions internal/proxy/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import (
// RequestLoging - boolean whether or not to log requests
// StatsdHost - host addr for statsd client to listen on
// StatsdPort - port for statsdclient to listen on
// ShutdownTimeout - maximum time to wait for in-flight HTTP requests to complete before shutdown
type Options struct {
Port int `envconfig:"PORT" default:"4180"`

Expand Down Expand Up @@ -97,6 +98,8 @@ type Options struct {

RequestSigningKey string `envconfig:"REQUEST_SIGNATURE_KEY"`

ShutdownTimeout time.Duration `envconfig:"SHUTDOWN_TIMEOUT" default:"30s"`

StatsdClient *statsd.Client

// This is an override for supplying template vars at test time
Expand Down

0 comments on commit a29ba90

Please sign in to comment.