diff --git a/internal/auth/authenticator.go b/internal/auth/authenticator.go index 169709d1..f449d893 100644 --- a/internal/auth/authenticator.go +++ b/internal/auth/authenticator.go @@ -12,16 +12,16 @@ import ( "github.com/buzzfeed/sso/internal/auth/providers" "github.com/buzzfeed/sso/internal/pkg/aead" log "github.com/buzzfeed/sso/internal/pkg/logging" - "github.com/buzzfeed/sso/internal/pkg/options" "github.com/buzzfeed/sso/internal/pkg/sessions" "github.com/buzzfeed/sso/internal/pkg/templates" + "github.com/buzzfeed/sso/internal/pkg/validators" "github.com/datadog/datadog-go/statsd" ) // Authenticator stores all the information associated with proxying the request. type Authenticator struct { - Validators []options.Validator + Validators []validators.Validator EmailDomains []string ProxyRootDomains []string Host string @@ -226,7 +226,7 @@ func (p *Authenticator) authenticate(rw http.ResponseWriter, req *http.Request) } } - errors := options.RunValidators(p.Validators, session) + errors := validators.RunValidators(p.Validators, session) if len(errors) == len(p.Validators) { logger.WithUser(session.Email).Info( fmt.Sprintf("permission denied: unauthorized: %q", errors)) @@ -582,7 +582,7 @@ func (p *Authenticator) getOAuthCallback(rw http.ResponseWriter, req *http.Reque // - for p.Validator see validator.go#newValidatorImpl for more info // - for p.provider.ValidateGroup see providers/google.go#ValidateGroup for more info - errors := options.RunValidators(p.Validators, session) + errors := validators.RunValidators(p.Validators, session) if len(errors) == len(p.Validators) { tags := append(tags, "error:invalid_email") p.StatsdClient.Incr("application_error", tags, 1.0) diff --git a/internal/auth/authenticator_test.go b/internal/auth/authenticator_test.go index 998cb998..3375bb83 100644 --- a/internal/auth/authenticator_test.go +++ b/internal/auth/authenticator_test.go @@ -17,10 +17,10 @@ import ( "github.com/buzzfeed/sso/internal/auth/providers" "github.com/buzzfeed/sso/internal/pkg/aead" - "github.com/buzzfeed/sso/internal/pkg/options" "github.com/buzzfeed/sso/internal/pkg/sessions" "github.com/buzzfeed/sso/internal/pkg/templates" "github.com/buzzfeed/sso/internal/pkg/testutil" + "github.com/buzzfeed/sso/internal/pkg/validators" ) func init() { @@ -418,7 +418,7 @@ func TestSignIn(t *testing.T) { t.Run(tc.name, func(t *testing.T) { config := testConfiguration(t) auth, err := NewAuthenticator(config, - SetValidators([]options.Validator{options.NewMockValidator(tc.validEmail)}), + SetValidators([]validators.Validator{validators.NewMockValidator(tc.validEmail, nil)}), setMockSessionStore(tc.mockSessionStore), setMockTempl(), setMockRedirectURL(), @@ -565,7 +565,7 @@ func TestSignOutPage(t *testing.T) { provider.RevokeError = tc.RevokeError p, _ := NewAuthenticator(config, - SetValidators([]options.Validator{options.NewMockValidator(true)}), + SetValidators([]validators.Validator{validators.NewMockValidator(true, nil)}), setMockSessionStore(tc.mockSessionStore), setMockTempl(), setTestProvider(provider), @@ -942,7 +942,7 @@ func TestGetProfile(t *testing.T) { t.Run(tc.name, func(t *testing.T) { config := testConfiguration(t) p, _ := NewAuthenticator(config, - SetValidators([]options.Validator{options.NewMockValidator(true)}), + SetValidators([]validators.Validator{validators.NewMockValidator(true, nil)}), ) u, _ := url.Parse("http://example.com") testProvider := providers.NewTestProvider(u) @@ -1044,7 +1044,7 @@ func TestRedeemCode(t *testing.T) { config := testConfiguration(t) proxy, _ := NewAuthenticator(config, - SetValidators([]options.Validator{options.NewMockValidator(true)}), + SetValidators([]validators.Validator{validators.NewMockValidator(true, nil)}), ) testURL, err := url.Parse("example.com") @@ -1433,7 +1433,7 @@ func TestOAuthCallback(t *testing.T) { t.Run(tc.name, func(t *testing.T) { config := testConfiguration(t) proxy, _ := NewAuthenticator(config, - SetValidators([]options.Validator{options.NewMockValidator(tc.validEmail)}), + SetValidators([]validators.Validator{validators.NewMockValidator(tc.validEmail, nil)}), setMockCSRFStore(tc.csrfResp), setMockSessionStore(tc.sessionStore), ) @@ -1554,7 +1554,7 @@ func TestOAuthStart(t *testing.T) { provider := providers.NewTestProvider(nil) proxy, _ := NewAuthenticator(config, setTestProvider(provider), - SetValidators([]options.Validator{options.NewMockValidator(true)}), + SetValidators([]validators.Validator{validators.NewMockValidator(true, nil)}), setMockRedirectURL(), setMockCSRFStore(&sessions.MockCSRFStore{}), ) diff --git a/internal/auth/mux.go b/internal/auth/mux.go index 5fe0c639..64585f67 100644 --- a/internal/auth/mux.go +++ b/internal/auth/mux.go @@ -6,7 +6,7 @@ import ( "github.com/buzzfeed/sso/internal/pkg/hostmux" log "github.com/buzzfeed/sso/internal/pkg/logging" - "github.com/buzzfeed/sso/internal/pkg/options" + "github.com/buzzfeed/sso/internal/pkg/validators" "github.com/datadog/datadog-go/statsd" ) @@ -18,11 +18,11 @@ type AuthenticatorMux struct { func NewAuthenticatorMux(config Configuration, statsdClient *statsd.Client) (*AuthenticatorMux, error) { logger := log.NewLogEntry() - validators := []options.Validator{} + v := []validators.Validator{} if len(config.AuthorizeConfig.EmailConfig.Addresses) != 0 { - validators = append(validators, options.NewEmailAddressValidator(config.AuthorizeConfig.EmailConfig.Addresses)) + v = append(v, validators.NewEmailAddressValidator(config.AuthorizeConfig.EmailConfig.Addresses)) } else { - validators = append(validators, options.NewEmailDomainValidator(config.AuthorizeConfig.EmailConfig.Domains)) + v = append(v, validators.NewEmailDomainValidator(config.AuthorizeConfig.EmailConfig.Domains)) } authenticators := []*Authenticator{} @@ -37,7 +37,7 @@ func NewAuthenticatorMux(config Configuration, statsdClient *statsd.Client) (*Au idpSlug := idp.Data().ProviderSlug authenticator, err := NewAuthenticator(config, - SetValidators(validators), + SetValidators(v), SetProvider(idp), SetCookieStore(config.SessionConfig, idpSlug), SetStatsdClient(statsdClient), diff --git a/internal/auth/options.go b/internal/auth/options.go index a7d5f0d9..59e20712 100644 --- a/internal/auth/options.go +++ b/internal/auth/options.go @@ -9,8 +9,8 @@ import ( "github.com/buzzfeed/sso/internal/auth/providers" "github.com/buzzfeed/sso/internal/pkg/aead" "github.com/buzzfeed/sso/internal/pkg/groups" - "github.com/buzzfeed/sso/internal/pkg/options" "github.com/buzzfeed/sso/internal/pkg/sessions" + "github.com/buzzfeed/sso/internal/pkg/validators" "github.com/datadog/datadog-go/statsd" ) @@ -97,7 +97,7 @@ func SetRedirectURL(serverConfig ServerConfig, slug string) func(*Authenticator) } // SetValidator sets the email validator -func SetValidators(validators []options.Validator) func(*Authenticator) error { +func SetValidators(validators []validators.Validator) func(*Authenticator) error { return func(a *Authenticator) error { a.Validators = validators return nil diff --git a/internal/pkg/logging/logging.go b/internal/pkg/logging/logging.go index 003d5820..cafc4cdd 100644 --- a/internal/pkg/logging/logging.go +++ b/internal/pkg/logging/logging.go @@ -144,6 +144,11 @@ func (l *LogEntry) WithEndpoint(endpoint string) *LogEntry { return l.withField("endpoint", endpoint) } +// WithAuthorizedUpstream appends an `authorized_upstream` tag to a LogEntry. +func (l *LogEntry) WithAuthorizedUpstream(upstream string) *LogEntry { + return l.withField("authorized_upstream", upstream) +} + // WithError appends an `error` tag to a LogEntry. Useful for annotating non-Error log // entries (e.g. Fatal messages) with an `error` object. func (l *LogEntry) WithError(err error) *LogEntry { diff --git a/internal/pkg/sessions/session_state.go b/internal/pkg/sessions/session_state.go index f6044d7e..be254a16 100644 --- a/internal/pkg/sessions/session_state.go +++ b/internal/pkg/sessions/session_state.go @@ -25,9 +25,10 @@ type SessionState struct { ValidDeadline time.Time `json:"valid_deadline"` GracePeriodStart time.Time `json:"grace_period_start"` - Email string `json:"email"` - User string `json:"user"` - Groups []string `json:"groups"` + Email string `json:"email"` + User string `json:"user"` + Groups []string `json:"groups"` + AuthorizedUpstream string `json:"authorized_upstream"` } // LifetimePeriodExpired returns true if the lifetime has expired @@ -45,6 +46,14 @@ func (s *SessionState) ValidationPeriodExpired() bool { return isExpired(s.ValidDeadline) } +// IsWithinGracePeriod returns true if the session is still within the grace period +func (s *SessionState) IsWithinGracePeriod(gracePeriodTTL time.Duration) bool { + if s.GracePeriodStart.IsZero() { + s.GracePeriodStart = time.Now() + } + return s.GracePeriodStart.Add(gracePeriodTTL).After(time.Now()) +} + func isExpired(t time.Time) bool { if t.Before(time.Now()) { return true diff --git a/internal/pkg/sessions/session_state_test.go b/internal/pkg/sessions/session_state_test.go index 9db8d607..adcdfe58 100644 --- a/internal/pkg/sessions/session_state_test.go +++ b/internal/pkg/sessions/session_state_test.go @@ -55,20 +55,29 @@ func TestSessionStateExpirations(t *testing.T) { LifetimeDeadline: time.Now().Add(-1 * time.Hour), RefreshDeadline: time.Now().Add(-1 * time.Hour), ValidDeadline: time.Now().Add(-1 * time.Minute), + GracePeriodStart: time.Now().Add(-2 * time.Minute), Email: "user@domain.com", User: "user", } if !session.LifetimePeriodExpired() { - t.Errorf("expcted lifetime period to be expired") + t.Errorf("expected lifetime period to be expired") } if !session.RefreshPeriodExpired() { - t.Errorf("expcted lifetime period to be expired") + t.Errorf("expected lifetime period to be expired") } if !session.ValidationPeriodExpired() { - t.Errorf("expcted lifetime period to be expired") + t.Errorf("expected lifetime period to be expired") + } + + if session.IsWithinGracePeriod(1 * time.Minute) { + t.Errorf("expected session to be outside of grace period") + } + + if !session.IsWithinGracePeriod(3 * time.Minute) { + t.Errorf("expected session to be inside of grace period") } } diff --git a/internal/pkg/options/email_address_validator.go b/internal/pkg/validators/email_address_validator.go similarity index 98% rename from internal/pkg/options/email_address_validator.go rename to internal/pkg/validators/email_address_validator.go index 02c6195e..050f5025 100644 --- a/internal/pkg/options/email_address_validator.go +++ b/internal/pkg/validators/email_address_validator.go @@ -1,4 +1,4 @@ -package options +package validators import ( "errors" diff --git a/internal/pkg/options/email_address_validator_test.go b/internal/pkg/validators/email_address_validator_test.go similarity index 99% rename from internal/pkg/options/email_address_validator_test.go rename to internal/pkg/validators/email_address_validator_test.go index aef92f75..d4a8af24 100644 --- a/internal/pkg/options/email_address_validator_test.go +++ b/internal/pkg/validators/email_address_validator_test.go @@ -1,4 +1,4 @@ -package options +package validators import ( "testing" diff --git a/internal/pkg/options/email_domain_validator.go b/internal/pkg/validators/email_domain_validator.go similarity index 99% rename from internal/pkg/options/email_domain_validator.go rename to internal/pkg/validators/email_domain_validator.go index 4afdf372..c5201059 100644 --- a/internal/pkg/options/email_domain_validator.go +++ b/internal/pkg/validators/email_domain_validator.go @@ -1,4 +1,4 @@ -package options +package validators import ( "errors" diff --git a/internal/pkg/options/email_domain_validator_test.go b/internal/pkg/validators/email_domain_validator_test.go similarity index 99% rename from internal/pkg/options/email_domain_validator_test.go rename to internal/pkg/validators/email_domain_validator_test.go index e2eb407f..31a6d1d4 100644 --- a/internal/pkg/options/email_domain_validator_test.go +++ b/internal/pkg/validators/email_domain_validator_test.go @@ -1,4 +1,4 @@ -package options +package validators import ( "testing" diff --git a/internal/pkg/options/email_group_validator.go b/internal/pkg/validators/email_group_validator.go similarity index 84% rename from internal/pkg/options/email_group_validator.go rename to internal/pkg/validators/email_group_validator.go index 14d1c6dd..b67be1b6 100644 --- a/internal/pkg/options/email_group_validator.go +++ b/internal/pkg/validators/email_group_validator.go @@ -1,7 +1,9 @@ -package options +package validators import ( "errors" + "fmt" + "strings" "github.com/buzzfeed/sso/internal/pkg/sessions" "github.com/buzzfeed/sso/internal/proxy/providers" @@ -10,7 +12,7 @@ import ( var ( _ Validator = EmailGroupValidator{} - // These error message should be formatted in such a way that is appropriate + // These error messages should be formatted in such a way that is appropriate // for display to the end user. ErrGroupMembership = errors.New("Invalid Group Membership") ) @@ -42,7 +44,7 @@ func (v EmailGroupValidator) Validate(session *sessions.SessionState) error { func (v EmailGroupValidator) validate(session *sessions.SessionState) error { matchedGroups, valid, err := v.Provider.ValidateGroup(session.Email, v.AllowedGroups, session.AccessToken) if err != nil { - return ErrValidationError + return err } if valid { @@ -50,5 +52,5 @@ func (v EmailGroupValidator) validate(session *sessions.SessionState) error { return nil } - return ErrGroupMembership + return fmt.Errorf("%v - Allowed Groups: %s", ErrGroupMembership, strings.Join(v.AllowedGroups, ", ")) } diff --git a/internal/pkg/options/mock_validator.go b/internal/pkg/validators/mock_validator.go similarity index 56% rename from internal/pkg/options/mock_validator.go rename to internal/pkg/validators/mock_validator.go index 87739be0..41bebf11 100644 --- a/internal/pkg/options/mock_validator.go +++ b/internal/pkg/validators/mock_validator.go @@ -1,4 +1,4 @@ -package options +package validators import ( "errors" @@ -12,18 +12,26 @@ var ( type MockValidator struct { Result bool + Err error } -func NewMockValidator(result bool) MockValidator { +func NewMockValidator(result bool, err error) MockValidator { return MockValidator{ Result: result, + Err: err, } } func (v MockValidator) Validate(session *sessions.SessionState) error { + // if we pass in a specific error, return it + if v.Err != nil { + return v.Err + } + // if result is true, return nil if v.Result { return nil } + // otherwise, return generic mock validator error return errors.New("MockValidator error") } diff --git a/internal/pkg/options/validators.go b/internal/pkg/validators/validators.go similarity index 91% rename from internal/pkg/options/validators.go rename to internal/pkg/validators/validators.go index 0b14e48f..87e9c1de 100644 --- a/internal/pkg/options/validators.go +++ b/internal/pkg/validators/validators.go @@ -1,4 +1,4 @@ -package options +package validators import ( "errors" @@ -10,7 +10,6 @@ var ( // These error message should be formatted in such a way that is appropriate // for display to the end user. ErrInvalidEmailAddress = errors.New("Invalid Email Address In Session State") - ErrValidationError = errors.New("Error during validation") ) type Validator interface { diff --git a/internal/proxy/oauthproxy.go b/internal/proxy/oauthproxy.go index 0944c95e..d0c2515f 100644 --- a/internal/proxy/oauthproxy.go +++ b/internal/proxy/oauthproxy.go @@ -13,8 +13,8 @@ import ( "github.com/buzzfeed/sso/internal/pkg/aead" log "github.com/buzzfeed/sso/internal/pkg/logging" - "github.com/buzzfeed/sso/internal/pkg/options" "github.com/buzzfeed/sso/internal/pkg/sessions" + "github.com/buzzfeed/sso/internal/pkg/validators" "github.com/buzzfeed/sso/internal/proxy/providers" "github.com/datadog/datadog-go/statsd" @@ -39,9 +39,10 @@ var SignatureHeaders = []string{ // Errors var ( - ErrLifetimeExpired = errors.New("user lifetime expired") - ErrUserNotAuthorized = errors.New("user not authorized") - ErrWrongIdentityProvider = errors.New("user authenticated with wrong identity provider") + ErrLifetimeExpired = errors.New("user lifetime expired") + ErrUserNotAuthorized = errors.New("user not authorized") + ErrWrongIdentityProvider = errors.New("user authenticated with wrong identity provider") + ErrUnauthorizedUpstreamRequested = errors.New("user session authorized with different upstream") ) type ErrOAuthProxyMisconfigured struct { @@ -57,7 +58,7 @@ const statusInvalidHost = 421 // OAuthProxy stores all the information associated with proxying the request. type OAuthProxy struct { cookieSecure bool - Validators []options.Validator + Validators []validators.Validator redirectURL *url.URL // the url to receive requests at templates *template.Template @@ -150,7 +151,7 @@ func SetProxyHandler(handler http.Handler) func(*OAuthProxy) error { } // SetValidator sets the email validator as a functional option -func SetValidators(validators []options.Validator) func(*OAuthProxy) error { +func SetValidators(validators []validators.Validator) func(*OAuthProxy) error { return func(op *OAuthProxy) error { op.Validators = validators return nil @@ -176,7 +177,7 @@ func NewOAuthProxy(opts *Options, optFuncs ...func(*OAuthProxy) error) (*OAuthPr p := &OAuthProxy{ cookieSecure: opts.CookieSecure, StatsdClient: opts.StatsdClient, - Validators: []options.Validator{}, + Validators: []validators.Validator{}, redirectURL: &url.URL{Path: "/oauth2/callback"}, templates: getTemplates(), @@ -359,7 +360,7 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code i p.templates.ExecuteTemplate(rw, "error.html", t) } -// IsWhitelistedRequest cheks that proxy host exists and checks the SkipAuthRegex +// IsWhitelistedRequest checks that proxy host exists and checks the SkipAuthRegex func (p *OAuthProxy) IsWhitelistedRequest(req *http.Request) bool { if p.skipAuthPreflight && req.Method == "OPTIONS" { return true @@ -375,6 +376,26 @@ func (p *OAuthProxy) IsWhitelistedRequest(req *http.Request) bool { return false } +// runValidatorsWithGracePeriod runs all validators and upon finding errors, checks to see if the +// auth provider is explicity denying authentication or if it's merely unavailable. If it's unavailable, +// we check whether the session is within the grace period or not to determine the specific error we return. +func (p *OAuthProxy) runValidatorsWithGracePeriod(session *sessions.SessionState) (err error) { + logger := log.NewLogEntry() + errors := validators.RunValidators(p.Validators, session) + if len(errors) == len(p.Validators) { + for _, err := range errors { + // Check to see if the auth provider is explicity denying authentication, or if it is merely unavailable. + if err == providers.ErrAuthProviderUnavailable && session.IsWithinGracePeriod(p.provider.Data().GracePeriodTTL) { + return err + } + } + logger.WithUser(session.Email).Error(errors, + "no longer authorized after validation period") + return ErrUserNotAuthorized + } + return nil +} + func (p *OAuthProxy) isXHR(req *http.Request) bool { return req.Header.Get("X-Requested-With") == "XMLHttpRequest" } @@ -569,26 +590,28 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { // // set cookie, or deny - errors := options.RunValidators(p.Validators, session) + errors := validators.RunValidators(p.Validators, session) if len(errors) == len(p.Validators) { tags = append(tags, "error:validation_failed") p.StatsdClient.Incr("application_error", tags, 1.0) logger.WithRemoteAddress(remoteAddr).WithUser(session.Email).Info( fmt.Sprintf("permission denied: unauthorized: %q", errors)) - formattedErrors := make([]string, 0, len(errors)) - for _, err := range errors { - formattedErrors = append(formattedErrors, err.Error()) - } - errorMsg := fmt.Sprintf("We ran into some issues while validating your account: \"%s\"", - strings.Join(formattedErrors, ", ")) - p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", errorMsg) + p.ErrorPage(rw, req, http.StatusForbidden, "Permission Denied", + fmt.Sprintf("We ran into some issues while validating your account: %q", errors)) return } logger.WithRemoteAddress(remoteAddr).WithUser(session.Email).WithInGroups(session.Groups).Info( fmt.Sprintf("oauth callback: user validated ")) + // We add the request host into the session to allow us to validate that each request has + // been authorized for the upstream it's requesting. + // e.g. if a request is authenticated while trying to reach 'foo' upstream, it should not + // automatically be seen as authorized with 'bar' upstream. Each upstream may set different + // validators, so the request should be reauthenticated. + session.AuthorizedUpstream = req.Host + // We store the session in a cookie and redirect the user back to the application err = p.sessionStore.SaveSession(rw, req, session) if err != nil { @@ -653,6 +676,13 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { // the user has a stale sesssion. p.OAuthStart(rw, req, tags) return + case ErrUnauthorizedUpstreamRequested: + // The users session has been authorised for use with a different upstream than the one + // that is being requested, so we trigger the start of the oauth flow. + // This exists primarily to implement some form of grace period while this additional session + // check is being introduced. + p.OAuthStart(rw, req, tags) + return case sessions.ErrInvalidSession: // The user session is invalid and we can't decode it. // This can happen for a variety of reasons but the most common non-malicious @@ -693,8 +723,6 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er remoteAddr := getRemoteAddr(req) tags := []string{"action:authenticate"} - allowedGroups := p.upstreamConfig.AllowedGroups - // Clear the session cookie if anything goes wrong. defer func() { if err != nil { @@ -705,7 +733,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er session, err := p.sessionStore.LoadSession(req) if err != nil { // We loaded a cookie but it wasn't valid, clear it, and reject the request - logger.Error(err, "error authenticating user") + logger.Error(err, "invalid session loaded") return err } @@ -718,6 +746,15 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er return ErrWrongIdentityProvider } + // check that the user has been authorized against the requested upstream + // this is primarily to combat against a user authorizing with one upstream and attempting to use + // the session cookie for a different upstream. + if req.Host != session.AuthorizedUpstream { + logger.WithProxyHost(req.Host).WithAuthorizedUpstream(session.AuthorizedUpstream).WithUser(session.Email).Warn( + "session authorized against different upstream; restarting authentication") + return ErrUnauthorizedUpstreamRequested + } + // Lifetime period is the entire duration in which the session is valid. // This should be set to something like 14 to 30 days. if session.LifetimePeriodExpired() { @@ -728,14 +765,17 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er } else if session.RefreshPeriodExpired() { // Refresh period is the period in which the access token is valid. This is ultimately // controlled by the upstream provider and tends to be around 1 hour. - ok, err := p.provider.RefreshSession(session, allowedGroups) + // If it has expired we: + // - attempt to refresh the session + // - run email domain, email address, and email group validations against the session (if defined). + + ok, err := p.provider.RefreshSessionToken(session) // We failed to refresh the session successfully // clear the cookie and reject the request if err != nil { logger.WithUser(session.Email).Error(err, "refreshing session failed") return err } - if !ok { // User is not authorized after refresh // clear the cookie and reject the request @@ -744,6 +784,18 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er return ErrUserNotAuthorized } + err = p.runValidatorsWithGracePeriod(session) + if err != nil { + switch err { + case providers.ErrAuthProviderUnavailable: + tags = append(tags, "action:refresh_session", "error:validation_failed") + p.StatsdClient.Incr("provider_error_fallback", tags, 1.0) + session.RefreshDeadline = sessions.ExtendDeadline(p.provider.Data().SessionValidTTL) + default: + return ErrUserNotAuthorized + } + } + err = p.sessionStore.SaveSession(rw, req, session) if err != nil { // We refreshed the session successfully, but failed to save it. @@ -757,9 +809,13 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er } else if session.ValidationPeriodExpired() { // Validation period has expired, this is the shortest interval we use to // check for valid requests. This should be set to something like a minute. - // This calls up the provider chain to validate this user is still active - // and hasn't been de-authorized. - ok := p.provider.ValidateSessionState(session, allowedGroups) + // In this case we: + // - call up the provider chain to validate this user is still active and hasn't been de-authorized. + // - run any defined email domain, email address, and email group validators against the session + + //TODO: change this to match the RefreshSessionToken method + // (https://github.com/buzzfeed/sso/pull/275#discussion_r366448883) + ok := p.provider.ValidateSessionToken(session) if !ok { // This user is now no longer authorized, or we failed to // validate the user. @@ -769,6 +825,18 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er return ErrUserNotAuthorized } + err = p.runValidatorsWithGracePeriod(session) + if err != nil { + switch err { + case providers.ErrAuthProviderUnavailable: + tags = append(tags, "action:validate_session", "error:validation_failed") + p.StatsdClient.Incr("provider_error_fallback", tags, 1.0) + session.ValidDeadline = sessions.ExtendDeadline(p.provider.Data().SessionValidTTL) + default: + return ErrUserNotAuthorized + } + } + err = p.sessionStore.SaveSession(rw, req, session) if err != nil { // We validated the session successfully, but failed to save it. @@ -781,25 +849,6 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) (er } } - // We revalidate group membership whenever the session is refreshed or revalidated - // just above in the call to ValidateSessionState and RefreshSession. - // To reduce strain on upstream identity providers we only revalidate email domains and - // addresses on each request here. - for _, v := range p.Validators { - _, EmailGroupValidator := v.(options.EmailGroupValidator) - - if !EmailGroupValidator { - err := v.Validate(session) - if err != nil { - tags = append(tags, "error:validation_failed") - p.StatsdClient.Incr("application_error", tags, 1.0) - logger.WithRemoteAddress(remoteAddr).WithUser(session.Email).Info( - fmt.Sprintf("permission denied: unauthorized: %q", err)) - return ErrUserNotAuthorized - } - } - } - logger.WithRemoteAddress(remoteAddr).WithUser(session.Email).Info( fmt.Sprintf("authentication: user validated")) diff --git a/internal/proxy/oauthproxy_test.go b/internal/proxy/oauthproxy_test.go index d21aced8..56c66163 100644 --- a/internal/proxy/oauthproxy_test.go +++ b/internal/proxy/oauthproxy_test.go @@ -20,9 +20,9 @@ import ( "github.com/mccutchen/go-httpbin/httpbin" "github.com/buzzfeed/sso/internal/pkg/aead" - "github.com/buzzfeed/sso/internal/pkg/options" "github.com/buzzfeed/sso/internal/pkg/sessions" "github.com/buzzfeed/sso/internal/pkg/testutil" + "github.com/buzzfeed/sso/internal/pkg/validators" "github.com/buzzfeed/sso/internal/proxy/providers" ) @@ -73,9 +73,10 @@ func testSession() *sessions.SessionState { theFuture := time.Now().AddDate(100, 100, 100) return &sessions.SessionState{ - Email: "michael.bland@gsa.gov", - AccessToken: "my_access_token", - Groups: []string{"foo", "bar"}, + Email: "michael.bland@gsa.gov", + AccessToken: "my_access_token", + Groups: []string{"foo", "bar"}, + AuthorizedUpstream: "localhost", RefreshDeadline: theFuture, LifetimeDeadline: theFuture, @@ -124,7 +125,7 @@ func testNewOAuthProxy(t *testing.T, optFuncs ...func(*OAuthProxy) error) (*OAut } standardOptFuncs := []func(*OAuthProxy) error{ - SetValidators([]options.Validator{options.NewMockValidator(true)}), + SetValidators([]validators.Validator{validators.NewMockValidator(true, nil)}), SetProvider(provider), setSessionStore(&sessions.MockSessionStore{Session: testSession()}), SetUpstreamConfig(upstreamConfig), @@ -203,27 +204,28 @@ func TestFavicon(t *testing.T) { func TestAuthOnlyEndpoint(t *testing.T) { testCases := []struct { - name string - validEmail bool - sessionStore *sessions.MockSessionStore - expectedBody string - expectedCode int + name string + validatorResult bool + validatorErr error + sessionStore *sessions.MockSessionStore + expectedBody string + expectedCode int }{ { name: "accepted", sessionStore: &sessions.MockSessionStore{ Session: testSession(), }, - validEmail: true, - expectedBody: "", - expectedCode: http.StatusAccepted, + validatorResult: true, + expectedBody: "", + expectedCode: http.StatusAccepted, }, { - name: "unauthorized on no cookie set", - expectedBody: "unauthorized request\n", - sessionStore: &sessions.MockSessionStore{}, - validEmail: true, - expectedCode: http.StatusUnauthorized, + name: "unauthorized on no cookie set", + expectedBody: "unauthorized request\n", + sessionStore: &sessions.MockSessionStore{}, + validatorResult: true, + expectedCode: http.StatusUnauthorized, }, { name: "unauthorized on expiration", @@ -232,30 +234,167 @@ func TestAuthOnlyEndpoint(t *testing.T) { LifetimeDeadline: time.Now().Add(-1 * time.Hour), }, }, - validEmail: true, + // it should error before running the validator + validatorResult: true, + expectedBody: "unauthorized request\n", + expectedCode: http.StatusUnauthorized, + }, + { + name: "authorized: refresh period expired, validations pass", + sessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + //refresh period expired, but within grace period + LifetimeDeadline: time.Now().Add(1 * time.Hour), + RefreshDeadline: time.Now().Add(-1 * time.Hour), + AuthorizedUpstream: "localhost", + }, + }, + validatorResult: true, + expectedBody: "", + expectedCode: http.StatusAccepted, + }, + { + name: "authorized: refresh period expired, idp unavailable, within grace period", + sessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + //refresh period expired, but within grace period + LifetimeDeadline: time.Now().Add(1 * time.Hour), + RefreshDeadline: time.Now().Add(-1 * time.Hour), + GracePeriodStart: time.Now().Add(1 * time.Hour), + AuthorizedUpstream: "localhost", + }, + }, + //provider unavailable instead of hard valiation fail + validatorErr: providers.ErrAuthProviderUnavailable, + expectedBody: "", + expectedCode: http.StatusAccepted, + }, + { + name: "unauthorized: refresh period expired, idp unavailable, outside of grace period", + sessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + //refresh period expired, but within grace period + LifetimeDeadline: time.Now().Add(1 * time.Hour), + RefreshDeadline: time.Now().Add(-1 * time.Hour), + GracePeriodStart: time.Now().Add(-1 * time.Hour), + }, + }, + //provider unavailable instead of hard valiation fail + validatorErr: providers.ErrAuthProviderUnavailable, expectedBody: "unauthorized request\n", expectedCode: http.StatusUnauthorized, }, { - name: "unauthorized on email validation failure", + name: "unauthorized: refresh period expired, idp available, outside of grace period", sessionStore: &sessions.MockSessionStore{ - Session: testSession(), + Session: &sessions.SessionState{ + //refresh period expired, but within grace period + LifetimeDeadline: time.Now().Add(1 * time.Hour), + RefreshDeadline: time.Now().Add(-1 * time.Hour), + GracePeriodStart: time.Now().Add(-1 * time.Hour), + }, + }, + validatorResult: false, + expectedBody: "unauthorized request\n", + expectedCode: http.StatusUnauthorized, + }, + { + name: "unauthorized: refresh period expired, idp available, validator hard fail", + sessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + //refresh period expired, but within grace period + LifetimeDeadline: time.Now().Add(1 * time.Hour), + RefreshDeadline: time.Now().Add(-1 * time.Hour), + }, + }, + validatorResult: false, + expectedBody: "unauthorized request\n", + expectedCode: http.StatusUnauthorized, + }, + { + name: "authorized: validation period expired, validations pass", + sessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + //refresh period expired, but within grace period + LifetimeDeadline: time.Now().Add(1 * time.Hour), + ValidDeadline: time.Now().Add(-1 * time.Hour), + AuthorizedUpstream: "localhost", + }, + }, + validatorResult: true, + expectedBody: "", + expectedCode: http.StatusAccepted, + }, + { + name: "authorized: validation period expired, idp unavailable, within grace period", + sessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + //refresh period expired, but within grace period + LifetimeDeadline: time.Now().Add(1 * time.Hour), + ValidDeadline: time.Now().Add(-1 * time.Hour), + GracePeriodStart: time.Now().Add(1 * time.Hour), + AuthorizedUpstream: "localhost", + }, + }, + //provider unavailable instead of hard valiation fail + validatorErr: providers.ErrAuthProviderUnavailable, + expectedBody: "", + expectedCode: http.StatusAccepted, + }, + { + name: "unauthorized: validatoion period expired, idp unavailable, outside of grace period", + sessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + //refresh period expired, but within grace period + LifetimeDeadline: time.Now().Add(1 * time.Hour), + ValidDeadline: time.Now().Add(-1 * time.Hour), + GracePeriodStart: time.Now().Add(-1 * time.Hour), + }, }, + //provider unavailable instead of hard valiation fail + validatorErr: providers.ErrAuthProviderUnavailable, expectedBody: "unauthorized request\n", expectedCode: http.StatusUnauthorized, }, + { + name: "unauthorized: validation period expired, idp available, outside of grace period", + sessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + //refresh period expired, but within grace period + LifetimeDeadline: time.Now().Add(1 * time.Hour), + RefreshDeadline: time.Now().Add(-1 * time.Hour), + GracePeriodStart: time.Now().Add(-1 * time.Hour), + }, + }, + validatorResult: false, + expectedBody: "unauthorized request\n", + expectedCode: http.StatusUnauthorized, + }, + { + name: "unauthorized: validation period expired, idp available, validator hard fail", + sessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + //refresh period expired, but within grace period + LifetimeDeadline: time.Now().Add(1 * time.Hour), + ValidDeadline: time.Now().Add(-1 * time.Hour), + }, + }, + validatorResult: false, + expectedBody: "unauthorized request\n", + expectedCode: http.StatusUnauthorized, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { providerURL, _ := url.Parse("http://localhost/") tp := providers.NewTestProvider(providerURL, "") - tp.RefreshSessionFunc = func(*sessions.SessionState, []string) (bool, error) { return true, nil } - tp.ValidateSessionFunc = func(*sessions.SessionState, []string) bool { return true } + tp.RefreshSessionTokenFunc = func(*sessions.SessionState) (bool, error) { return true, nil } + tp.ValidateSessionTokenFunc = func(*sessions.SessionState) bool { return true } proxy, close := testNewOAuthProxy(t, setSessionStore(tc.sessionStore), - SetValidators([]options.Validator{options.NewMockValidator(tc.validEmail)}), + SetValidators([]validators.Validator{validators.NewMockValidator(tc.validatorResult, tc.validatorErr)}), SetProvider(tp), ) defer close() @@ -412,19 +551,20 @@ func TestAuthenticate(t *testing.T) { testCases := []struct { Name string - SessionStore *sessions.MockSessionStore - ExpectedErr error - CookieExpectation int // One of: {NewCookie, ClearCookie, KeepCookie} - RefreshSessionFunc func(*sessions.SessionState, []string) (bool, error) - ValidateSessionFunc func(*sessions.SessionState, []string) bool + SessionStore *sessions.MockSessionStore + ExpectedErr error + CookieExpectation int // One of: {NewCookie, ClearCookie, KeepCookie} + RefreshSessionTokenFunc func(*sessions.SessionState) (bool, error) + ValidateSessionTokenFunc func(*sessions.SessionState) bool }{ { Name: "redirect if deadlines are blank", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", + Email: "email1@example.com", + AuthorizedUpstream: "localhost", + AccessToken: "my_access_token", }, }, ExpectedErr: ErrLifetimeExpired, @@ -443,11 +583,12 @@ func TestAuthenticate(t *testing.T) { Name: "authenticate successfully", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + Email: "email1@example.com", + AuthorizedUpstream: "localhost", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), }, }, ExpectedErr: nil, @@ -457,11 +598,12 @@ func TestAuthenticate(t *testing.T) { Name: "lifetime expired, do not authenticate", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(-24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + Email: "email1@example.com", + AuthorizedUpstream: "localhost", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(-24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), }, }, ExpectedErr: ErrLifetimeExpired, @@ -471,115 +613,137 @@ func TestAuthenticate(t *testing.T) { Name: "refresh expired, refresh fails, do not authenticate", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + Email: "email1@example.com", + AuthorizedUpstream: "localhost", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), }, }, - ExpectedErr: ErrRefreshFailed, - CookieExpectation: ClearCookie, - RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return false, ErrRefreshFailed }, + ExpectedErr: ErrRefreshFailed, + CookieExpectation: ClearCookie, + RefreshSessionTokenFunc: func(s *sessions.SessionState) (bool, error) { return false, ErrRefreshFailed }, }, { Name: "refresh expired, user not OK, do not authenticate", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + Email: "email1@example.com", + AuthorizedUpstream: "localhost", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), }, }, - ExpectedErr: ErrUserNotAuthorized, - CookieExpectation: ClearCookie, - RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return false, nil }, + ExpectedErr: ErrUserNotAuthorized, + CookieExpectation: ClearCookie, + RefreshSessionTokenFunc: func(s *sessions.SessionState) (bool, error) { return false, nil }, }, { Name: "refresh expired, user OK, authenticate", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + Email: "email1@example.com", + AuthorizedUpstream: "localhost", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), }, }, - ExpectedErr: nil, - CookieExpectation: NewCookie, - RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return true, nil }, + ExpectedErr: nil, + CookieExpectation: NewCookie, + RefreshSessionTokenFunc: func(s *sessions.SessionState) (bool, error) { return true, nil }, }, { Name: "refresh expired, refresh and user OK, error saving session", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + Email: "email1@example.com", + AuthorizedUpstream: "localhost", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), }, SaveError: SaveCookieFailed, }, - ExpectedErr: SaveCookieFailed, - CookieExpectation: ClearCookie, - RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return true, nil }, + ExpectedErr: SaveCookieFailed, + CookieExpectation: ClearCookie, + RefreshSessionTokenFunc: func(s *sessions.SessionState) (bool, error) { return true, nil }, }, { Name: "validation expired, user not OK, do not authenticate", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(-1) * time.Minute), + Email: "email1@example.com", + AuthorizedUpstream: "localhost", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(-1) * time.Minute), }, }, - ExpectedErr: ErrUserNotAuthorized, - CookieExpectation: ClearCookie, - ValidateSessionFunc: func(s *sessions.SessionState, g []string) bool { return false }, + ExpectedErr: ErrUserNotAuthorized, + CookieExpectation: ClearCookie, + ValidateSessionTokenFunc: func(s *sessions.SessionState) bool { return false }, }, { Name: "validation expired, user OK, authenticate", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(-1) * time.Minute), + Email: "email1@example.com", + AuthorizedUpstream: "localhost", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(-1) * time.Minute), }, }, - ExpectedErr: nil, - CookieExpectation: NewCookie, - ValidateSessionFunc: func(s *sessions.SessionState, g []string) bool { return true }, + ExpectedErr: nil, + CookieExpectation: NewCookie, + ValidateSessionTokenFunc: func(s *sessions.SessionState) bool { return true }, }, { Name: "wrong identity provider, user OK, do not authenticate", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - ProviderSlug: "example", - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + ProviderSlug: "example", + Email: "email1@example.com", + AuthorizedUpstream: "localhost", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), }, }, ExpectedErr: ErrWrongIdentityProvider, CookieExpectation: ClearCookie, }, + { + Name: "authorized against different upstream, user OK, do not authenticate", + SessionStore: &sessions.MockSessionStore{ + Session: &sessions.SessionState{ + AuthorizedUpstream: "foo", + Email: "email1@example.com", + AccessToken: "my_access_token", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + }, + }, + ExpectedErr: ErrUnauthorizedUpstreamRequested, + CookieExpectation: ClearCookie, + }, } for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { providerURL, _ := url.Parse("http://localhost/") tp := providers.NewTestProvider(providerURL, "") - tp.RefreshSessionFunc = tc.RefreshSessionFunc - tp.ValidateSessionFunc = tc.ValidateSessionFunc + tp.RefreshSessionTokenFunc = tc.RefreshSessionTokenFunc + tp.ValidateSessionTokenFunc = tc.ValidateSessionTokenFunc proxy, close := testNewOAuthProxy(t, SetProvider(tp), @@ -619,9 +783,9 @@ func TestAuthenticationUXFlows(t *testing.T) { testCases := []struct { Name string - SessionStore *sessions.MockSessionStore - RefreshSessionFunc func(*sessions.SessionState, []string) (bool, error) - ValidateSessionFunc func(*sessions.SessionState, []string) bool + SessionStore *sessions.MockSessionStore + RefreshSessionTokenFunc func(*sessions.SessionState) (bool, error) + ValidateSessionTokenFunc func(*sessions.SessionState) bool ExpectStatusCode int }{ @@ -629,8 +793,9 @@ func TestAuthenticationUXFlows(t *testing.T) { Name: "missing deadlines, redirect to sign-in", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", + Email: "email1@example.com", + AccessToken: "my_access_token", + AuthorizedUpstream: "localhost", }, }, ExpectStatusCode: http.StatusFound, @@ -647,11 +812,12 @@ func TestAuthenticationUXFlows(t *testing.T) { Name: "authenticate successfully, expect ok", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + Email: "email1@example.com", + AccessToken: "my_access_token", + AuthorizedUpstream: "localhost", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), }, }, ExpectStatusCode: http.StatusOK, @@ -660,11 +826,12 @@ func TestAuthenticationUXFlows(t *testing.T) { Name: "lifetime expired, redirect to sign-in", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(-24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + Email: "email1@example.com", + AccessToken: "my_access_token", + AuthorizedUpstream: "localhost", + LifetimeDeadline: time.Now().Add(time.Duration(-24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), }, }, ExpectStatusCode: http.StatusFound, @@ -673,97 +840,104 @@ func TestAuthenticationUXFlows(t *testing.T) { Name: "refresh expired, refresh fails, show error", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + Email: "email1@example.com", + AccessToken: "my_access_token", + AuthorizedUpstream: "localhost", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), }, }, - RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return false, ErrRefreshFailed }, - ExpectStatusCode: http.StatusInternalServerError, + RefreshSessionTokenFunc: func(s *sessions.SessionState) (bool, error) { return false, ErrRefreshFailed }, + ExpectStatusCode: http.StatusInternalServerError, }, { Name: "refresh expired, user not OK, deny", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + Email: "email1@example.com", + AccessToken: "my_access_token", + AuthorizedUpstream: "localhost", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), }, }, - RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return false, nil }, - ExpectStatusCode: http.StatusForbidden, + RefreshSessionTokenFunc: func(s *sessions.SessionState) (bool, error) { return false, nil }, + ExpectStatusCode: http.StatusForbidden, }, { Name: "refresh expired, user OK, expect ok", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + Email: "email1@example.com", + AccessToken: "my_access_token", + AuthorizedUpstream: "localhost", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), }, }, - RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return true, nil }, - ExpectStatusCode: http.StatusOK, + RefreshSessionTokenFunc: func(s *sessions.SessionState) (bool, error) { return true, nil }, + ExpectStatusCode: http.StatusOK, }, { Name: "refresh expired, refresh and user OK, error saving session, show error", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + Email: "email1@example.com", + AccessToken: "my_access_token", + AuthorizedUpstream: "localhost", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), }, SaveError: SaveCookieFailed, }, - RefreshSessionFunc: func(s *sessions.SessionState, g []string) (bool, error) { return true, nil }, - ExpectStatusCode: http.StatusInternalServerError, + RefreshSessionTokenFunc: func(s *sessions.SessionState) (bool, error) { return true, nil }, + ExpectStatusCode: http.StatusInternalServerError, }, { Name: "validation expired, user not OK, deny", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(-1) * time.Minute), + Email: "email1@example.com", + AccessToken: "my_access_token", + AuthorizedUpstream: "localhost", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(-1) * time.Minute), }, }, - ValidateSessionFunc: func(s *sessions.SessionState, g []string) bool { return false }, - ExpectStatusCode: http.StatusForbidden, + ValidateSessionTokenFunc: func(s *sessions.SessionState) bool { return false }, + ExpectStatusCode: http.StatusForbidden, }, { Name: "validation expired, user OK, expect ok", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(-1) * time.Minute), + Email: "email1@example.com", + AccessToken: "my_access_token", + AuthorizedUpstream: "localhost", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(-1) * time.Minute), }, }, - ValidateSessionFunc: func(s *sessions.SessionState, g []string) bool { return true }, - ExpectStatusCode: http.StatusOK, + ValidateSessionTokenFunc: func(s *sessions.SessionState) bool { return true }, + ExpectStatusCode: http.StatusOK, }, { Name: "wrong identity provider, redirect to sign-in", SessionStore: &sessions.MockSessionStore{ Session: &sessions.SessionState{ - ProviderSlug: "example", - Email: "email1@example.com", - AccessToken: "my_access_token", - LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), + ProviderSlug: "example", + Email: "email1@example.com", + AccessToken: "my_access_token", + AuthorizedUpstream: "localhost", + LifetimeDeadline: time.Now().Add(time.Duration(24) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Minute), }, }, ExpectStatusCode: http.StatusFound, @@ -773,8 +947,8 @@ func TestAuthenticationUXFlows(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { providerURL, _ := url.Parse("http://localhost/") tp := providers.NewTestProvider(providerURL, "") - tp.RefreshSessionFunc = tc.RefreshSessionFunc - tp.ValidateSessionFunc = tc.ValidateSessionFunc + tp.RefreshSessionTokenFunc = tc.RefreshSessionTokenFunc + tp.ValidateSessionTokenFunc = tc.ValidateSessionTokenFunc proxy, close := testNewOAuthProxy(t, SetProvider(tp), @@ -809,7 +983,8 @@ func TestProxyXHRErrorHandling(t *testing.T) { { Name: "expired session should redirect on normal request (GET)", Session: &sessions.SessionState{ - LifetimeDeadline: time.Now().Add(time.Duration(-24) * time.Hour), + LifetimeDeadline: time.Now().Add(time.Duration(-24) * time.Hour), + AuthorizedUpstream: "localhost", }, Method: "GET", ExpectedCode: http.StatusFound, @@ -817,7 +992,8 @@ func TestProxyXHRErrorHandling(t *testing.T) { { Name: "expired session should proxy preflight request (OPTIONS)", Session: &sessions.SessionState{ - LifetimeDeadline: time.Now().Add(time.Duration(-24) * time.Hour), + LifetimeDeadline: time.Now().Add(time.Duration(-24) * time.Hour), + AuthorizedUpstream: "localhost", }, Method: "OPTIONS", ExpectedCode: http.StatusFound, @@ -825,7 +1001,8 @@ func TestProxyXHRErrorHandling(t *testing.T) { { Name: "expired session should return error code when XMLHttpRequest", Session: &sessions.SessionState{ - LifetimeDeadline: time.Now().Add(time.Duration(-24) * time.Hour), + LifetimeDeadline: time.Now().Add(time.Duration(-24) * time.Hour), + AuthorizedUpstream: "localhost", }, Method: "GET", Header: map[string]string{ @@ -845,9 +1022,10 @@ func TestProxyXHRErrorHandling(t *testing.T) { { Name: "valid session should proxy as normal when XMLHttpRequest", Session: &sessions.SessionState{ - LifetimeDeadline: time.Now().Add(time.Duration(1) * time.Hour), - RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), - ValidDeadline: time.Now().Add(time.Duration(1) * time.Hour), + LifetimeDeadline: time.Now().Add(time.Duration(1) * time.Hour), + RefreshDeadline: time.Now().Add(time.Duration(1) * time.Hour), + ValidDeadline: time.Now().Add(time.Duration(1) * time.Hour), + AuthorizedUpstream: "localhost", }, Method: "GET", Header: map[string]string{ diff --git a/internal/proxy/providers/providers.go b/internal/proxy/providers/providers.go index 226dc16f..f068dbc0 100644 --- a/internal/proxy/providers/providers.go +++ b/internal/proxy/providers/providers.go @@ -13,10 +13,10 @@ type Provider interface { Redeem(string, string) (*sessions.SessionState, error) ValidateGroup(string, []string, string) ([]string, bool, error) UserGroups(string, []string, string) ([]string, error) - ValidateSessionState(*sessions.SessionState, []string) bool + ValidateSessionToken(*sessions.SessionState) bool + RefreshSessionToken(*sessions.SessionState) (bool, error) GetSignInURL(redirectURL *url.URL, finalRedirect string) *url.URL GetSignOutURL(redirectURL *url.URL) *url.URL - RefreshSession(*sessions.SessionState, []string) (bool, error) } // New returns a new sso Provider diff --git a/internal/proxy/providers/singleflight_middleware.go b/internal/proxy/providers/singleflight_middleware.go index 40e6fcf7..3c6ec2bd 100644 --- a/internal/proxy/providers/singleflight_middleware.go +++ b/internal/proxy/providers/singleflight_middleware.go @@ -94,10 +94,10 @@ func (p *SingleFlightProvider) UserGroups(email string, groups []string, accessT return groups, nil } -// ValidateSessionState calls the provider's ValidateSessionState function and returns the response -func (p *SingleFlightProvider) ValidateSessionState(s *sessions.SessionState, allowedGroups []string) bool { - response, err := p.do("ValidateSessionState", s.AccessToken, func() (interface{}, error) { - valid := p.provider.ValidateSessionState(s, allowedGroups) +// ValidateSessionToken calls the provider's ValidateSessionToken function and returns the response +func (p *SingleFlightProvider) ValidateSessionToken(s *sessions.SessionState) bool { + response, err := p.do("ValidateSessionToken", s.AccessToken, func() (interface{}, error) { + valid := p.provider.ValidateSessionToken(s) return valid, nil }) if err != nil { @@ -112,11 +112,11 @@ func (p *SingleFlightProvider) ValidateSessionState(s *sessions.SessionState, al return valid } -// RefreshSession takes in a SessionState and allowedGroups and +// RefreshSessionToken takes in a SessionState and // returns false if the session is not refreshed and true if it is. -func (p *SingleFlightProvider) RefreshSession(s *sessions.SessionState, allowedGroups []string) (bool, error) { - response, err := p.do("RefreshSession", s.RefreshToken, func() (interface{}, error) { - return p.provider.RefreshSession(s, allowedGroups) +func (p *SingleFlightProvider) RefreshSessionToken(s *sessions.SessionState) (bool, error) { + response, err := p.do("RefreshSessionToken", s.RefreshToken, func() (interface{}, error) { + return p.provider.RefreshSessionToken(s) }) if err != nil { return false, err diff --git a/internal/proxy/providers/sso.go b/internal/proxy/providers/sso.go index 1aca089c..bd17ba9a 100644 --- a/internal/proxy/providers/sso.go +++ b/internal/proxy/providers/sso.go @@ -98,17 +98,6 @@ func isProviderUnavailable(statusCode int) bool { return statusCode == http.StatusTooManyRequests || statusCode == http.StatusServiceUnavailable } -func extendDeadline(ttl time.Duration) time.Time { - return time.Now().Add(ttl).Truncate(time.Second) -} - -func (p *SSOProvider) withinGracePeriod(s *sessions.SessionState) bool { - if s.GracePeriodStart.IsZero() { - s.GracePeriodStart = time.Now() - } - return s.GracePeriodStart.Add(p.GracePeriodTTL).After(time.Now()) -} - // Redeem takes a redirectURL and code and redeems the SessionState func (p *SSOProvider) Redeem(redirectURL, code string) (*sessions.SessionState, error) { if code == "" { @@ -165,9 +154,9 @@ func (p *SSOProvider) Redeem(redirectURL, code string) (*sessions.SessionState, AccessToken: jsonResponse.AccessToken, RefreshToken: jsonResponse.RefreshToken, - RefreshDeadline: extendDeadline(time.Duration(jsonResponse.ExpiresIn) * time.Second), - LifetimeDeadline: extendDeadline(p.SessionLifetimeTTL), - ValidDeadline: extendDeadline(p.SessionValidTTL), + RefreshDeadline: sessions.ExtendDeadline(time.Duration(jsonResponse.ExpiresIn) * time.Second), + LifetimeDeadline: sessions.ExtendDeadline(p.SessionLifetimeTTL), + ValidDeadline: sessions.ExtendDeadline(p.SessionValidTTL), Email: jsonResponse.Email, User: user, @@ -245,9 +234,9 @@ func (p *SSOProvider) UserGroups(email string, groups []string, accessToken stri return jsonResponse.Groups, nil } -// RefreshSession takes a SessionState and allowedGroups and refreshes the session access token, +// RefreshSessionToken takes a SessionState and refreshes the session access token, // returns `true` on success, and `false` on error -func (p *SSOProvider) RefreshSession(s *sessions.SessionState, allowedGroups []string) (bool, error) { +func (p *SSOProvider) RefreshSessionToken(s *sessions.SessionState) (bool, error) { logger := log.NewLogEntry() if s.RefreshToken == "" { @@ -259,35 +248,17 @@ func (p *SSOProvider) RefreshSession(s *sessions.SessionState, allowedGroups []s // When we detect that the auth provider is not explicitly denying // authentication, and is merely unavailable, we refresh and continue // as normal during the "grace period" - if err == ErrAuthProviderUnavailable && p.withinGracePeriod(s) { + if err == ErrAuthProviderUnavailable && s.IsWithinGracePeriod(p.GracePeriodTTL) { tags := []string{"action:refresh_session", "error:redeem_token_failed"} p.StatsdClient.Incr("provider_error_fallback", tags, 1.0) - s.RefreshDeadline = extendDeadline(p.SessionValidTTL) + s.RefreshDeadline = sessions.ExtendDeadline(p.SessionValidTTL) return true, nil } return false, err } - inGroups, validGroup, err := p.ValidateGroup(s.Email, allowedGroups, newToken) - if err != nil { - // When we detect that the auth provider is not explicitly denying - // authentication, and is merely unavailable, we refresh and continue - // as normal during the "grace period" - if err == ErrAuthProviderUnavailable && p.withinGracePeriod(s) { - tags := []string{"action:refresh_session", "error:user_groups_failed"} - p.StatsdClient.Incr("provider_error_fallback", tags, 1.0) - s.RefreshDeadline = extendDeadline(p.SessionValidTTL) - return true, nil - } - return false, err - } - if !validGroup { - return false, errors.New("Group membership revoked") - } - s.Groups = inGroups - s.AccessToken = newToken - s.RefreshDeadline = extendDeadline(duration) + s.RefreshDeadline = sessions.ExtendDeadline(duration) s.GracePeriodStart = time.Time{} logger.WithUser(s.Email).WithRefreshDeadline(s.RefreshDeadline).Info("refreshed session access token") return true, nil @@ -340,8 +311,8 @@ func (p *SSOProvider) redeemRefreshToken(refreshToken string) (token string, exp return } -// ValidateSessionState takes a sessionState and allowedGroups and validates the session state -func (p *SSOProvider) ValidateSessionState(s *sessions.SessionState, allowedGroups []string) bool { +// ValidateSessionToken takes a sessionState and validates the session token +func (p *SSOProvider) ValidateSessionToken(s *sessions.SessionState) bool { logger := log.NewLogEntry() // we validate the user's access token is valid @@ -349,7 +320,7 @@ func (p *SSOProvider) ValidateSessionState(s *sessions.SessionState, allowedGrou params.Add("client_id", p.ClientID) req, err := p.newRequest("GET", fmt.Sprintf("%s?%s", p.ValidateURL.String(), params.Encode()), nil) if err != nil { - logger.WithUser(s.Email).Error(err, "error validating session state") + logger.WithUser(s.Email).Error(err, "error validating session access token") return false } @@ -358,7 +329,7 @@ func (p *SSOProvider) ValidateSessionState(s *sessions.SessionState, allowedGrou resp, err := httpClient.Do(req) if err != nil { - logger.WithUser(s.Email).Error("error making request to validate access token") + logger.WithUser(s.Email).Error("error making request to validate session access token") return false } @@ -366,44 +337,21 @@ func (p *SSOProvider) ValidateSessionState(s *sessions.SessionState, allowedGrou // When we detect that the auth provider is not explicitly denying // authentication, and is merely unavailable, we validate and continue // as normal during the "grace period" - if isProviderUnavailable(resp.StatusCode) && p.withinGracePeriod(s) { + if isProviderUnavailable(resp.StatusCode) && s.IsWithinGracePeriod(p.GracePeriodTTL) { tags := []string{"action:validate_session", "error:validation_failed"} p.StatsdClient.Incr("provider_error_fallback", tags, 1.0) - s.ValidDeadline = extendDeadline(p.SessionValidTTL) + s.ValidDeadline = sessions.ExtendDeadline(p.SessionValidTTL) return true } logger.WithUser(s.Email).WithHTTPStatus(resp.StatusCode).Info( - "could not validate user access token") - return false - } - - // check the user is in the proper group(s) - inGroups, validGroup, err := p.ValidateGroup(s.Email, allowedGroups, s.AccessToken) - if err != nil { - // When we detect that the auth provider is not explicitly denying - // authentication, and is merely unavailable, we validate and continue - // as normal during the "grace period" - if err == ErrAuthProviderUnavailable && p.withinGracePeriod(s) { - tags := []string{"action:validate_session", "error:user_groups_failed"} - p.StatsdClient.Incr("provider_error_fallback", tags, 1.0) - s.ValidDeadline = extendDeadline(p.SessionValidTTL) - return true - } - logger.WithUser(s.Email).Error(err, "error fetching group memberships") - return false - } - - if !validGroup { - logger.WithUser(s.Email).WithAllowedGroups(allowedGroups).Info( - "user is no longer in valid groups") + "could not validate session access token") return false } - s.Groups = inGroups - s.ValidDeadline = extendDeadline(p.SessionValidTTL) + s.ValidDeadline = sessions.ExtendDeadline(p.SessionValidTTL) s.GracePeriodStart = time.Time{} - logger.WithUser(s.Email).WithSessionValid(s.ValidDeadline).Info("validated session") + logger.WithUser(s.Email).WithSessionValid(s.ValidDeadline).Info("validated session access token") return true } diff --git a/internal/proxy/providers/sso_test.go b/internal/proxy/providers/sso_test.go index 1b9a2244..f300529e 100644 --- a/internal/proxy/providers/sso_test.go +++ b/internal/proxy/providers/sso_test.go @@ -301,26 +301,13 @@ func TestSSOProviderRedeem(t *testing.T) { }) } } -func TestSSOProviderValidateSessionState(t *testing.T) { +func TestSSOProviderValidateSessionToken(t *testing.T) { testCases := []struct { Name string SessionState *sessions.SessionState ProviderResponse int - Groups []string - ProxyGroupIds []string ExpectedValid bool }{ - { - Name: "invalid when no group id set", - SessionState: &sessions.SessionState{ - AccessToken: "abc", - Email: "michael.bland@gsa.gov", - }, - ProviderResponse: http.StatusOK, - Groups: []string{}, - ProxyGroupIds: []string{}, - ExpectedValid: false, - }, { Name: "invalid when response is is not 200", SessionState: &sessions.SessionState{ @@ -328,30 +315,6 @@ func TestSSOProviderValidateSessionState(t *testing.T) { Email: "michael.bland@gsa.gov", }, ProviderResponse: http.StatusForbidden, - Groups: []string{}, - ProxyGroupIds: []string{}, - ExpectedValid: false, - }, - { - Name: "valid when the group id exists", - SessionState: &sessions.SessionState{ - AccessToken: "abc", - Email: "michael.bland@gsa.gov", - }, - ProviderResponse: http.StatusOK, - Groups: []string{"test1", "test2"}, - ProxyGroupIds: []string{"test1"}, - ExpectedValid: true, - }, - { - Name: "invalid when the group id isn't in user groups", - SessionState: &sessions.SessionState{ - AccessToken: "abc", - Email: "michael.bland@gsa.gov", - }, - ProviderResponse: http.StatusOK, - Groups: []string{}, - ProxyGroupIds: []string{"test1"}, ExpectedValid: false, }, { @@ -376,16 +339,6 @@ func TestSSOProviderValidateSessionState(t *testing.T) { p := newSSOProvider() p.GracePeriodTTL = time.Duration(3) * time.Hour - // setup group endpoint - body, err := json.Marshal(profileResponse{ - Email: tc.SessionState.Email, - Groups: tc.Groups, - }) - testutil.Equal(t, nil, err) - var profileServer *httptest.Server - p.ProfileURL, profileServer = newTestServer(http.StatusOK, body) - defer profileServer.Close() - validateServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { accessToken := r.Header.Get("X-Access-Token") if accessToken != tc.SessionState.AccessToken { @@ -398,7 +351,7 @@ func TestSSOProviderValidateSessionState(t *testing.T) { p.ValidateURL, _ = url.Parse(validateServer.URL) defer validateServer.Close() - valid := p.ValidateSessionState(tc.SessionState, tc.ProxyGroupIds) + valid := p.ValidateSessionToken(tc.SessionState) if valid != tc.ExpectedValid { t.Errorf("got unexpected result. want=%v got=%v", tc.ExpectedValid, valid) } @@ -406,12 +359,10 @@ func TestSSOProviderValidateSessionState(t *testing.T) { } } -func TestSSOProviderRefreshSession(t *testing.T) { +func TestSSOProviderRefreshSessionToken(t *testing.T) { testCases := []struct { Name string SessionState *sessions.SessionState - UserGroups []string - ProxyGroups []string RefreshResponse *refreshResponse ExpectedRefresh bool ExpectedError string @@ -456,50 +407,13 @@ func TestSSOProviderRefreshSession(t *testing.T) { ExpectedError: "got 400", }, { - Name: "no refresh if profile not responding", - SessionState: &sessions.SessionState{ - Email: "user@domain.com", - AccessToken: "token1234", - RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), - RefreshToken: "refresh1234", - }, - RefreshResponse: &refreshResponse{ - Code: http.StatusCreated, - ExpiresIn: 10, - AccessToken: "newToken1234", - }, - ProxyGroups: []string{"test1"}, - ExpectedRefresh: false, - ExpectedError: "got 500", - }, - { - Name: "no refresh if user no longer in group", + Name: "successful refresh if can redeem", SessionState: &sessions.SessionState{ Email: "user@domain.com", AccessToken: "token1234", RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), RefreshToken: "refresh1234", }, - UserGroups: []string{"useless"}, - ProxyGroups: []string{"test1"}, - RefreshResponse: &refreshResponse{ - Code: http.StatusCreated, - ExpiresIn: 10, - AccessToken: "newToken1234", - }, - ExpectedRefresh: false, - ExpectedError: "Group membership revoked", - }, - { - Name: "successful refresh if can redeem and user in group", - SessionState: &sessions.SessionState{ - Email: "user@domain.com", - AccessToken: "token1234", - RefreshDeadline: time.Now().Add(time.Duration(-1) * time.Hour), - RefreshToken: "refresh1234", - }, - UserGroups: []string{"test1"}, - ProxyGroups: []string{"test1"}, RefreshResponse: &refreshResponse{ Code: http.StatusCreated, ExpiresIn: 10, @@ -536,11 +450,6 @@ func TestSSOProviderRefreshSession(t *testing.T) { p := newSSOProvider() p.GracePeriodTTL = time.Duration(3) * time.Hour - groups := []string{} - if tc.ProxyGroups != nil { - groups = tc.ProxyGroups - } - // set up redeem resource var refreshServer *httptest.Server body, err := json.Marshal(tc.RefreshResponse) @@ -548,22 +457,8 @@ func TestSSOProviderRefreshSession(t *testing.T) { p.RefreshURL, refreshServer = newTestServer(tc.RefreshResponse.Code, body) defer refreshServer.Close() - // set up groups resource - var groupsServer *httptest.Server - if tc.UserGroups != nil { - body, err := json.Marshal(profileResponse{ - Email: tc.SessionState.Email, - Groups: tc.UserGroups, - }) - testutil.Equal(t, nil, err) - p.ProfileURL, groupsServer = newTestServer(http.StatusOK, body) - } else { - p.ProfileURL, groupsServer = newTestServer(http.StatusInternalServerError, []byte{}) - } - defer groupsServer.Close() - // run the endpoint - actualRefresh, err := p.RefreshSession(tc.SessionState, groups) + actualRefresh, err := p.RefreshSessionToken(tc.SessionState) if tc.ExpectedRefresh != actualRefresh { t.Fatalf("got unexpected refresh behavior. want=%v got=%v", tc.ExpectedRefresh, actualRefresh) } diff --git a/internal/proxy/providers/test_provider.go b/internal/proxy/providers/test_provider.go index 35494786..e3f60569 100644 --- a/internal/proxy/providers/test_provider.go +++ b/internal/proxy/providers/test_provider.go @@ -8,11 +8,11 @@ import ( // TestProvider is a mock provider type TestProvider struct { - RefreshSessionFunc func(*sessions.SessionState, []string) (bool, error) - ValidateSessionFunc func(*sessions.SessionState, []string) bool - RedeemFunc func(string, string) (*sessions.SessionState, error) - UserGroupsFunc func(string, []string, string) ([]string, error) - ValidateGroupsFunc func(string, []string, string) ([]string, bool, error) + RefreshSessionTokenFunc func(*sessions.SessionState) (bool, error) + ValidateSessionTokenFunc func(*sessions.SessionState) bool + RedeemFunc func(string, string) (*sessions.SessionState, error) + UserGroupsFunc func(string, []string, string) ([]string, error) + ValidateGroupsFunc func(string, []string, string) ([]string, bool, error) *ProviderData } @@ -47,8 +47,8 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { } // ValidateSessionState mocks the ValidateSessionState function -func (tp *TestProvider) ValidateSessionState(s *sessions.SessionState, groups []string) bool { - return tp.ValidateSessionFunc(s, groups) +func (tp *TestProvider) ValidateSessionToken(s *sessions.SessionState) bool { + return tp.ValidateSessionTokenFunc(s) } // Redeem mocks the provider Redeem function @@ -56,9 +56,9 @@ func (tp *TestProvider) Redeem(redirectURL string, token string) (*sessions.Sess return tp.RedeemFunc(redirectURL, token) } -// RefreshSession mocks the RefreshSession function -func (tp *TestProvider) RefreshSession(s *sessions.SessionState, g []string) (bool, error) { - return tp.RefreshSessionFunc(s, g) +// RefreshSessionToken mocks the RefreshSessionToken function +func (tp *TestProvider) RefreshSessionToken(s *sessions.SessionState) (bool, error) { + return tp.RefreshSessionTokenFunc(s) } // UserGroups mocks the UserGroups function diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 5cc5c1c0..2c5f35b7 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -5,7 +5,7 @@ import ( "net/http" "github.com/buzzfeed/sso/internal/pkg/hostmux" - "github.com/buzzfeed/sso/internal/pkg/options" + "github.com/buzzfeed/sso/internal/pkg/validators" ) type SSOProxy struct { @@ -38,17 +38,17 @@ func New(opts *Options) (*SSOProxy, error) { return nil, err } - validators := []options.Validator{} + v := []validators.Validator{} if len(upstreamConfig.AllowedEmailAddresses) != 0 { - validators = append(validators, options.NewEmailAddressValidator(upstreamConfig.AllowedEmailAddresses)) + v = append(v, validators.NewEmailAddressValidator(upstreamConfig.AllowedEmailAddresses)) } if len(upstreamConfig.AllowedEmailDomains) != 0 { - validators = append(validators, options.NewEmailDomainValidator(upstreamConfig.AllowedEmailDomains)) + v = append(v, validators.NewEmailDomainValidator(upstreamConfig.AllowedEmailDomains)) } if len(upstreamConfig.AllowedGroups) != 0 { - validators = append(validators, options.NewEmailGroupValidator(provider, upstreamConfig.AllowedGroups)) + v = append(v, validators.NewEmailGroupValidator(provider, upstreamConfig.AllowedGroups)) } optFuncs = append(optFuncs, @@ -56,7 +56,7 @@ func New(opts *Options) (*SSOProxy, error) { SetCookieStore(opts), SetUpstreamConfig(upstreamConfig), SetProxyHandler(handler), - SetValidators(validators), + SetValidators(v), ) oauthproxy, err := NewOAuthProxy(opts, optFuncs...)