Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sso_proxy: reduce direct calls to ValidateGroup() and clean up logic #275

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions internal/auth/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ import (
"github.com/buzzfeed/sso/internal/auth/providers"
"github.com/buzzfeed/sso/internal/pkg/aead"
log "github.com/buzzfeed/sso/internal/pkg/logging"
"github.com/buzzfeed/sso/internal/pkg/options"
"github.com/buzzfeed/sso/internal/pkg/sessions"
"github.com/buzzfeed/sso/internal/pkg/templates"
"github.com/buzzfeed/sso/internal/pkg/validators"

"github.com/datadog/datadog-go/statsd"
)

// Authenticator stores all the information associated with proxying the request.
type Authenticator struct {
Validators []options.Validator
Validators []validators.Validator
EmailDomains []string
ProxyRootDomains []string
Host string
Expand Down Expand Up @@ -226,7 +226,7 @@ func (p *Authenticator) authenticate(rw http.ResponseWriter, req *http.Request)
}
}

errors := options.RunValidators(p.Validators, session)
errors := validators.RunValidators(p.Validators, session)
if len(errors) == len(p.Validators) {
logger.WithUser(session.Email).Info(
fmt.Sprintf("permission denied: unauthorized: %q", errors))
Expand Down Expand Up @@ -582,7 +582,7 @@ func (p *Authenticator) getOAuthCallback(rw http.ResponseWriter, req *http.Reque
// - for p.Validator see validator.go#newValidatorImpl for more info
// - for p.provider.ValidateGroup see providers/google.go#ValidateGroup for more info

errors := options.RunValidators(p.Validators, session)
errors := validators.RunValidators(p.Validators, session)
if len(errors) == len(p.Validators) {
tags := append(tags, "error:invalid_email")
p.StatsdClient.Incr("application_error", tags, 1.0)
Expand Down
14 changes: 7 additions & 7 deletions internal/auth/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ import (

"github.com/buzzfeed/sso/internal/auth/providers"
"github.com/buzzfeed/sso/internal/pkg/aead"
"github.com/buzzfeed/sso/internal/pkg/options"
"github.com/buzzfeed/sso/internal/pkg/sessions"
"github.com/buzzfeed/sso/internal/pkg/templates"
"github.com/buzzfeed/sso/internal/pkg/testutil"
"github.com/buzzfeed/sso/internal/pkg/validators"
)

func init() {
Expand Down Expand Up @@ -418,7 +418,7 @@ func TestSignIn(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
config := testConfiguration(t)
auth, err := NewAuthenticator(config,
SetValidators([]options.Validator{options.NewMockValidator(tc.validEmail)}),
SetValidators([]validators.Validator{validators.NewMockValidator(tc.validEmail, nil)}),
setMockSessionStore(tc.mockSessionStore),
setMockTempl(),
setMockRedirectURL(),
Expand Down Expand Up @@ -565,7 +565,7 @@ func TestSignOutPage(t *testing.T) {
provider.RevokeError = tc.RevokeError

p, _ := NewAuthenticator(config,
SetValidators([]options.Validator{options.NewMockValidator(true)}),
SetValidators([]validators.Validator{validators.NewMockValidator(true, nil)}),
setMockSessionStore(tc.mockSessionStore),
setMockTempl(),
setTestProvider(provider),
Expand Down Expand Up @@ -942,7 +942,7 @@ func TestGetProfile(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
config := testConfiguration(t)
p, _ := NewAuthenticator(config,
SetValidators([]options.Validator{options.NewMockValidator(true)}),
SetValidators([]validators.Validator{validators.NewMockValidator(true, nil)}),
)
u, _ := url.Parse("http://example.com")
testProvider := providers.NewTestProvider(u)
Expand Down Expand Up @@ -1044,7 +1044,7 @@ func TestRedeemCode(t *testing.T) {
config := testConfiguration(t)

proxy, _ := NewAuthenticator(config,
SetValidators([]options.Validator{options.NewMockValidator(true)}),
SetValidators([]validators.Validator{validators.NewMockValidator(true, nil)}),
)

testURL, err := url.Parse("example.com")
Expand Down Expand Up @@ -1433,7 +1433,7 @@ func TestOAuthCallback(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
config := testConfiguration(t)
proxy, _ := NewAuthenticator(config,
SetValidators([]options.Validator{options.NewMockValidator(tc.validEmail)}),
SetValidators([]validators.Validator{validators.NewMockValidator(tc.validEmail, nil)}),
setMockCSRFStore(tc.csrfResp),
setMockSessionStore(tc.sessionStore),
)
Expand Down Expand Up @@ -1554,7 +1554,7 @@ func TestOAuthStart(t *testing.T) {
provider := providers.NewTestProvider(nil)
proxy, _ := NewAuthenticator(config,
setTestProvider(provider),
SetValidators([]options.Validator{options.NewMockValidator(true)}),
SetValidators([]validators.Validator{validators.NewMockValidator(true, nil)}),
setMockRedirectURL(),
setMockCSRFStore(&sessions.MockCSRFStore{}),
)
Expand Down
10 changes: 5 additions & 5 deletions internal/auth/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

"github.com/buzzfeed/sso/internal/pkg/hostmux"
log "github.com/buzzfeed/sso/internal/pkg/logging"
"github.com/buzzfeed/sso/internal/pkg/options"
"github.com/buzzfeed/sso/internal/pkg/validators"

"github.com/datadog/datadog-go/statsd"
)
Expand All @@ -18,11 +18,11 @@ type AuthenticatorMux struct {

func NewAuthenticatorMux(config Configuration, statsdClient *statsd.Client) (*AuthenticatorMux, error) {
logger := log.NewLogEntry()
validators := []options.Validator{}
v := []validators.Validator{}
if len(config.AuthorizeConfig.EmailConfig.Addresses) != 0 {
validators = append(validators, options.NewEmailAddressValidator(config.AuthorizeConfig.EmailConfig.Addresses))
v = append(v, validators.NewEmailAddressValidator(config.AuthorizeConfig.EmailConfig.Addresses))
} else {
validators = append(validators, options.NewEmailDomainValidator(config.AuthorizeConfig.EmailConfig.Domains))
v = append(v, validators.NewEmailDomainValidator(config.AuthorizeConfig.EmailConfig.Domains))
}

authenticators := []*Authenticator{}
Expand All @@ -37,7 +37,7 @@ func NewAuthenticatorMux(config Configuration, statsdClient *statsd.Client) (*Au

idpSlug := idp.Data().ProviderSlug
authenticator, err := NewAuthenticator(config,
SetValidators(validators),
SetValidators(v),
SetProvider(idp),
SetCookieStore(config.SessionConfig, idpSlug),
SetStatsdClient(statsdClient),
Expand Down
4 changes: 2 additions & 2 deletions internal/auth/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"github.com/buzzfeed/sso/internal/auth/providers"
"github.com/buzzfeed/sso/internal/pkg/aead"
"github.com/buzzfeed/sso/internal/pkg/groups"
"github.com/buzzfeed/sso/internal/pkg/options"
"github.com/buzzfeed/sso/internal/pkg/sessions"
"github.com/buzzfeed/sso/internal/pkg/validators"

"github.com/datadog/datadog-go/statsd"
)
Expand Down Expand Up @@ -97,7 +97,7 @@ func SetRedirectURL(serverConfig ServerConfig, slug string) func(*Authenticator)
}

// SetValidator sets the email validator
func SetValidators(validators []options.Validator) func(*Authenticator) error {
func SetValidators(validators []validators.Validator) func(*Authenticator) error {
return func(a *Authenticator) error {
a.Validators = validators
return nil
Expand Down
5 changes: 5 additions & 0 deletions internal/pkg/logging/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ func (l *LogEntry) WithEndpoint(endpoint string) *LogEntry {
return l.withField("endpoint", endpoint)
}

// WithAuthorizedUpstream appends an `authorized_upstream` tag to a LogEntry.
func (l *LogEntry) WithAuthorizedUpstream(upstream string) *LogEntry {
return l.withField("authorized_upstream", upstream)
}

// WithError appends an `error` tag to a LogEntry. Useful for annotating non-Error log
// entries (e.g. Fatal messages) with an `error` object.
func (l *LogEntry) WithError(err error) *LogEntry {
Expand Down
15 changes: 12 additions & 3 deletions internal/pkg/sessions/session_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ type SessionState struct {
ValidDeadline time.Time `json:"valid_deadline"`
GracePeriodStart time.Time `json:"grace_period_start"`

Email string `json:"email"`
User string `json:"user"`
Groups []string `json:"groups"`
Email string `json:"email"`
User string `json:"user"`
Groups []string `json:"groups"`
AuthorizedUpstream string `json:"authorized_upstream"`
}

// LifetimePeriodExpired returns true if the lifetime has expired
Expand All @@ -45,6 +46,14 @@ func (s *SessionState) ValidationPeriodExpired() bool {
return isExpired(s.ValidDeadline)
}

// IsWithinGracePeriod returns true if the session is still within the grace period
func (s *SessionState) IsWithinGracePeriod(gracePeriodTTL time.Duration) bool {
Jusshersmith marked this conversation as resolved.
Show resolved Hide resolved
if s.GracePeriodStart.IsZero() {
s.GracePeriodStart = time.Now()
}
return s.GracePeriodStart.Add(gracePeriodTTL).After(time.Now())
}

func isExpired(t time.Time) bool {
if t.Before(time.Now()) {
return true
Expand Down
15 changes: 12 additions & 3 deletions internal/pkg/sessions/session_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,29 @@ func TestSessionStateExpirations(t *testing.T) {
LifetimeDeadline: time.Now().Add(-1 * time.Hour),
RefreshDeadline: time.Now().Add(-1 * time.Hour),
ValidDeadline: time.Now().Add(-1 * time.Minute),
GracePeriodStart: time.Now().Add(-2 * time.Minute),

Email: "user@domain.com",
User: "user",
}

if !session.LifetimePeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
t.Errorf("expected lifetime period to be expired")
}

if !session.RefreshPeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
t.Errorf("expected lifetime period to be expired")
}

if !session.ValidationPeriodExpired() {
t.Errorf("expcted lifetime period to be expired")
t.Errorf("expected lifetime period to be expired")
}

if session.IsWithinGracePeriod(1 * time.Minute) {
t.Errorf("expected session to be outside of grace period")
}

if !session.IsWithinGracePeriod(3 * time.Minute) {
t.Errorf("expected session to be inside of grace period")
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package options
package validators

import (
"errors"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package options
package validators

import (
"testing"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package options
package validators

import (
"errors"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package options
package validators

import (
"testing"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package options
package validators

import (
"errors"
"fmt"
"strings"

"github.com/buzzfeed/sso/internal/pkg/sessions"
"github.com/buzzfeed/sso/internal/proxy/providers"
Expand All @@ -10,7 +12,7 @@ import (
var (
_ Validator = EmailGroupValidator{}

// These error message should be formatted in such a way that is appropriate
// These error messages should be formatted in such a way that is appropriate
// for display to the end user.
ErrGroupMembership = errors.New("Invalid Group Membership")
)
Expand Down Expand Up @@ -42,13 +44,13 @@ func (v EmailGroupValidator) Validate(session *sessions.SessionState) error {
func (v EmailGroupValidator) validate(session *sessions.SessionState) error {
matchedGroups, valid, err := v.Provider.ValidateGroup(session.Email, v.AllowedGroups, session.AccessToken)
if err != nil {
return ErrValidationError
return err
}

if valid {
session.Groups = matchedGroups
return nil
}

return ErrGroupMembership
return fmt.Errorf("%v - Allowed Groups: %s", ErrGroupMembership, strings.Join(v.AllowedGroups, ", "))
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package options
package validators

import (
"errors"
Expand All @@ -12,18 +12,26 @@ var (

type MockValidator struct {
Result bool
Err error
}

func NewMockValidator(result bool) MockValidator {
func NewMockValidator(result bool, err error) MockValidator {
return MockValidator{
Result: result,
Err: err,
}
}

func (v MockValidator) Validate(session *sessions.SessionState) error {
// if we pass in a specific error, return it
if v.Err != nil {
return v.Err
}
// if result is true, return nil
if v.Result {
return nil
}

// otherwise, return generic mock validator error
return errors.New("MockValidator error")
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package options
package validators

import (
"errors"
Expand All @@ -10,7 +10,6 @@ var (
// These error message should be formatted in such a way that is appropriate
// for display to the end user.
ErrInvalidEmailAddress = errors.New("Invalid Email Address In Session State")
ErrValidationError = errors.New("Error during validation")
)

type Validator interface {
Expand Down
Loading