diff --git a/README.md b/README.md index 98d635e..ef36fcd 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,57 @@ You can add, update, and remove backend services using the following REST endpoi - PUT /services/{name} - Updates an existing backend service - DELETE /services/{name} - Removes a backend service +## Adding authentication to backend services +Frontman currently supports two methods of authentication: JWT tokens and Basic Auth. Authentication can be configured for each backend service separately using the `auth` configuration +option: + +- Basic Auth with Username and Password In config: +```yaml + # .. backend config + auth: + type: "basic" + basic: + username: "test" + password: "test" +``` + +- Basic Auth with username and password environment variable: +```yaml + # .. backend config + auth: + type: "basic" + basic: + usernameEnvVariable: "API_USERNAME" + passwordEnvVariable: "API_PASSWORD" +``` + +- Basic Auth with username and password stored in credentials file: +```yaml + # .. backend config + auth: + type: "basic" + basic: + credentialsFile: "credentials.yaml" +``` + +credentials.yaml: +```yaml +username: "filetest" +password: "filetest" +``` + +- JWT Auth: +```yaml + # .. backend config + auth: + type: "jwt" + userDataContextKey: "user" # Header for storing user claims + jwt: + audience: + issuer: + keysUrl: +``` + ## Frontman Plugins Frontman allows you to create custom plugins that can be used to extend its functionality. Plugins are implemented using the FrontmanPlugin interface, which consists of three methods: @@ -310,4 +361,4 @@ Once you have updated the configuration file, restart Frontman to load the new p If you'd like to contribute to Frontman, please fork the repository and submit a pull request. We welcome bug reports, feature requests, and code contributions. ## License -Frontman is released under the GNU General Public License. See LICENSE for details. \ No newline at end of file +Frontman is released under the GNU General Public License. See LICENSE for details. diff --git a/auth/auth.go b/auth/auth.go index fea8cd8..822598a 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -3,16 +3,19 @@ package auth import ( "errors" "github.com/Frontman-Labs/frontman/config" + "net/http" ) type TokenValidator interface { - ValidateToken(tokenString string) (map[string]interface{}, error) + ValidateToken(request *http.Request) (map[string]interface{}, error) } func GetTokenValidator(conf config.AuthConfig) (TokenValidator, error) { switch conf.AuthType { case "jwt": - return NewJWTValidator(conf.JWT.Audience, conf.JWT.Issuer, conf.JWT.KeysUrl), nil + return NewJWTValidator(conf.JWT.Audience, conf.JWT.Issuer, conf.JWT.KeysUrl) + case "basic": + return NewBasicAuthValidator(conf.BasicAuthConfig) default: return nil, errors.New("Unrecognized auth type specified") } diff --git a/auth/basic.go b/auth/basic.go new file mode 100644 index 0000000..dc9bb90 --- /dev/null +++ b/auth/basic.go @@ -0,0 +1,70 @@ +package auth + +import ( + "errors" + "io/ioutil" + "log" + "net/http" + "os" + + "github.com/Frontman-Labs/frontman/config" + "gopkg.in/yaml.v3" +) + +type BasicAuthValidator struct { + Username string `yaml:"username"` + Password string `yaml:"password"` +} + +func getCredentialsFromConfig(conf *config.BasicAuthConfig) (string, string) { + var username, password string + if conf.Username != "" { + username = conf.Username + } else { + username = os.Getenv(conf.UsernameEnv) + } + + if conf.Password != "" { + password = conf.Password + } else { + password = os.Getenv(conf.PasswordEnv) + } + + return username, password +} + +func NewBasicAuthValidator(conf *config.BasicAuthConfig) (*BasicAuthValidator, error) { + if conf.CredentialsFile != "" { + // Read credentials file to build validator + yamlData, err := ioutil.ReadFile(conf.CredentialsFile) + if err != nil { + log.Printf("Failed to read credentials file: %s", err) + return nil, err + } + validator := &BasicAuthValidator{} + err = yaml.Unmarshal(yamlData, validator) + if err != nil { + log.Printf("Failed to unmarshal credentials data: %s", err) + return nil, err + } + return validator, nil + } + username, password := getCredentialsFromConfig(conf) + return &BasicAuthValidator{ + Username: username, + Password: password, + }, nil +} + +func (v BasicAuthValidator) ValidateToken(request *http.Request) (map[string]interface{}, error) { + username, password, ok := request.BasicAuth() + if !ok { + return nil, errors.New("Error parsing authentication token") + } + + if username != v.Username || password != v.Password { + return nil, errors.New("Invalid credentials") + } + + return nil, nil +} diff --git a/auth/basic_auth_validator_test.go b/auth/basic_auth_validator_test.go new file mode 100644 index 0000000..42e6a77 --- /dev/null +++ b/auth/basic_auth_validator_test.go @@ -0,0 +1,102 @@ +package auth + +import ( + "github.com/Frontman-Labs/frontman/config" + "net/http" + "os" + "testing" +) + +func TestNewBasicAuthValidatorFromHardcodedCredentials(t *testing.T) { + conf := &config.BasicAuthConfig{ + Username: "username", + Password: "password", + } + + validator, err := NewBasicAuthValidator(conf) + if err != nil { + t.Errorf("Failed to create basic validator: %s\n", err) + } + if validator.Username != "username" { + t.Errorf("NewBasicAuthValidator failed to parse username from username config variable\n") + } + if validator.Password != "password" { + t.Errorf("NewBasicAuthValidator failed to parse password from password config variable\n") + } +} + +func TestNewBasicAuthValidatorFromEnvVariables(t *testing.T) { + conf := &config.BasicAuthConfig{ + UsernameEnv: "FRONTMAN_TEST_BACKEND_USERNAME", + PasswordEnv: "FRONTMAN_TEST_BACKEND_PASSWORD", + } + + os.Setenv("FRONTMAN_TEST_BACKEND_USERNAME", "username_from_env") + os.Setenv("FRONTMAN_TEST_BACKEND_PASSWORD", "password_from_env") + + validator, err := NewBasicAuthValidator(conf) + if err != nil { + t.Errorf("Failed to create basic validator: %s\n", err) + } + if validator.Username != "username_from_env" { + t.Errorf("NewBasicAuthValidator failed to parse username from username environment variable\n") + } + if validator.Password != "password_from_env" { + t.Errorf("NewBasicAuthValidator failed to parse password from password environment variable\n") + } +} + +func TestBasicAuthValidCredentials(t *testing.T) { + validator := &BasicAuthValidator{ + Username: "test", + Password: "test", + } + + req := &http.Request{ + Header: make(http.Header), + } + req.SetBasicAuth("test", "test") + _, err := validator.ValidateToken(req) + if err != nil { + t.Errorf("Failed to validate correct basic auth: %s\n", err) + } +} + +func TestBasicAuthInvalidCredentials(t *testing.T) { + validator := &BasicAuthValidator{ + Username: "test", + Password: "test", + } + + req := &http.Request{ + Header: make(http.Header), + } + req.SetBasicAuth("blah", "blah") + _, err := validator.ValidateToken(req) + if err == nil { + t.Errorf("Failed to validate correctly identify invalid basic auth credentials\n") + } + + if err.Error() != "Invalid credentials" { + t.Errorf("Invalid error message returned when parsing invalid credentials: %s\n", err) + } +} + +func TestBasicAuthMissingCredentials(t *testing.T) { + validator := &BasicAuthValidator{ + Username: "test", + Password: "test", + } + + req := &http.Request{ + Header: make(http.Header), + } + _, err := validator.ValidateToken(req) + if err == nil { + t.Errorf("Failed to validate correctly identify missing basic auth credentials\n") + } + + if err.Error() != "Error parsing authentication token" { + t.Errorf("Invalid error message returned when parsing invalid credentials: %s\n", err) + } +} diff --git a/auth/jwt.go b/auth/jwt.go index b78e150..f259775 100644 --- a/auth/jwt.go +++ b/auth/jwt.go @@ -8,6 +8,7 @@ import ( "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jws" "github.com/lestrrat-go/jwx/v2/jwt" + "net/http" ) type JWTValidator struct { @@ -16,20 +17,21 @@ type JWTValidator struct { JWKS jwk.Set } -func NewJWTValidator(issuer string, audience string, jwkUrl string) *JWTValidator { +func NewJWTValidator(issuer string, audience string, jwkUrl string) (*JWTValidator, error) { jwks, err := jwk.Fetch(context.Background(), jwkUrl) if err != nil { log.Printf("Error loading jwks from %s: %s", jwkUrl, err.Error()) - return nil + return nil, err } return &JWTValidator{ issuer: issuer, audience: audience, JWKS: jwks, - } + }, nil } -func (v JWTValidator) ValidateToken(tokenString string) (map[string]interface{}, error) { +func (v JWTValidator) ValidateToken(request *http.Request) (map[string]interface{}, error) { + tokenString := request.Header.Get("Authorization") splitToken := strings.Fields(tokenString) // Remove leading "Bearer " token := splitToken[len(splitToken)-1] diff --git a/jwt_validator_test.go b/auth/jwt_validator_test.go similarity index 79% rename from jwt_validator_test.go rename to auth/jwt_validator_test.go index d68bcb7..831a4ca 100644 --- a/jwt_validator_test.go +++ b/auth/jwt_validator_test.go @@ -1,19 +1,19 @@ -package frontman +package auth import ( "crypto/rand" "crypto/rsa" - "testing" - "time" - - "github.com/Frontman-Labs/frontman/auth" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" + "net/http" + "testing" + "time" ) -func TestValidateToken(t *testing.T) { - validator := auth.JWTValidator{ +// TestGetServicesHandler tests the getServicesHandler function +func TestValidateJWTToken(t *testing.T) { + validator := JWTValidator{ JWKS: jwk.NewSet(), } privKey, err := rsa.GenerateKey(rand.Reader, 2048) @@ -41,7 +41,11 @@ func TestValidateToken(t *testing.T) { if err != nil { t.Errorf("failed to generate signed serialized: %s\n", err) } - result, err := validator.ValidateToken(string(signed)) + headers := make(http.Header) + headers.Add("Authorization", string(signed)) + result, err := validator.ValidateToken(&http.Request{ + Header: headers, + }) if err != nil { t.Errorf("Failed to validate signed token: %s", err) } @@ -50,8 +54,8 @@ func TestValidateToken(t *testing.T) { } } -func TestValidateTokenInvalidSignature(t *testing.T) { - validator := auth.JWTValidator{ +func TestValidateJWTTokenInvalidSignature(t *testing.T) { + validator := JWTValidator{ JWKS: jwk.NewSet(), } privKey, err := rsa.GenerateKey(rand.Reader, 2048) @@ -83,14 +87,18 @@ func TestValidateTokenInvalidSignature(t *testing.T) { if err != nil { t.Errorf("failed to generate signed serialized: %s\n", err) } - _, err = validator.ValidateToken(string(signed)) + headers := make(http.Header) + headers.Add("Authorization", string(signed)) + _, err = validator.ValidateToken(&http.Request{ + Header: headers, + }) if err == nil { t.Errorf("Failed to detect invalid key") } } -func TestValidateExpiredToken(t *testing.T) { - validator := auth.JWTValidator{ +func TestValidateJWTExpiredToken(t *testing.T) { + validator := JWTValidator{ JWKS: jwk.NewSet(), } privKey, err := rsa.GenerateKey(rand.Reader, 2048) @@ -124,9 +132,13 @@ func TestValidateExpiredToken(t *testing.T) { if err != nil { t.Errorf("failed to generate signed serialized: %s\n", err) } - _, err = validator.ValidateToken(string(signed)) + headers := make(http.Header) + headers.Add("Authorization", string(signed)) + _, err = validator.ValidateToken(&http.Request{ + Header: headers, + }) if err == nil { - t.Errorf("Failed to detect invalid key") + t.Errorf("Failed to detect invalid key: %s", err) } if err.Error() != "\"exp\" not satisfied" { diff --git a/config/config.go b/config/config.go index 13d0c89..b1cd4b6 100644 --- a/config/config.go +++ b/config/config.go @@ -32,11 +32,20 @@ type JWTConfig struct { KeysUrl string `json:"keysUrl" yaml:"keysUrl"` } +type BasicAuthConfig struct { + Username string `json:"username" yaml:"username"` + Password string `json:"password" yaml:"password"` + UsernameEnv string `json:"usernameEnvVariable" yaml:"usernameEnvVariable"` + PasswordEnv string `json:"passwordEnvVariable" yaml:"passwordEnvVariable"` + CredentialsFile string `json:"credentialsFile" yaml:"credentialsFile"` +} + // Auth config type AuthConfig struct { - AuthType string `json:"type" yaml:"type"` - UserDataHeader string `json:"userDataHeader" yaml:"userDataHeader"` - JWT *JWTConfig `json:"jwt" yaml:"jwt"` + AuthType string `json:"type" yaml:"type"` + UserDataHeader string `json:"userDataHeader" yaml:"userDataHeader"` + JWT *JWTConfig `json:"jwt" yaml:"jwt"` + BasicAuthConfig *BasicAuthConfig `json:"basic" yaml:"basic"` } // APIConfig holds the API server configuration diff --git a/handlers.go b/handlers.go index d5c8751..f35437b 100644 --- a/handlers.go +++ b/handlers.go @@ -265,7 +265,6 @@ func gatewayHandler(bs service.ServiceRegistry, plugs []plugins.FrontmanPlugin, // Get the upstream target URL for this request upstreamTarget := backendService.GetLoadBalancer().ChooseTarget(backendService.UpstreamTargets) - var urlPath string if backendService.StripPath { urlPath = strings.TrimPrefix(r.URL.Path, backendService.Path) @@ -289,17 +288,20 @@ func gatewayHandler(bs service.ServiceRegistry, plugs []plugins.FrontmanPlugin, if backendService.AuthConfig != nil { tokenValidator := backendService.GetTokenValidator() // Backend service has auth config specified - claims, err := tokenValidator.ValidateToken(headers.Get("Authorization")) + claims, err := tokenValidator.ValidateToken(r) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } - data, err := json.Marshal(claims) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + + if claims != nil { + data, err := json.Marshal(claims) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + headers.Add(backendService.GetUserDataHeader(), string(data)) } - headers.Add(backendService.GetUserDataHeader(), string(data)) } // Remove the X-Forwarded-For header to prevent spoofing diff --git a/service/service.go b/service/service.go index 8942176..919e5b2 100644 --- a/service/service.go +++ b/service/service.go @@ -140,6 +140,8 @@ func (bs *BackendService) setLoadBalancer() { bs.loadBalancer = loadbalancer.NewWRoundRobinLoadBalancer(bs.LoadBalancerPolicy.Options.Weights) case loadbalancer.LeastConnection: bs.loadBalancer = loadbalancer.NewLeastConnLoadBalancer(bs.UpstreamTargets) + default: + bs.loadBalancer = loadbalancer.NewRoundRobinLoadBalancer() } }