From 750f7dc107304c7d0e1f0a9b3f76e89ba00c8d93 Mon Sep 17 00:00:00 2001 From: Dustin Deus Date: Fri, 3 Jan 2025 13:41:57 +0100 Subject: [PATCH] feat: allow to conditionally block mutation via expressions (#1480) Co-authored-by: Ludwig --- router-tests/authentication_test.go | 24 +- router-tests/block_operations_test.go | 537 ++++++++++++++++++ router-tests/events/kafka_events_test.go | 12 +- router-tests/events/nats_events_test.go | 12 +- router-tests/go.mod | 1 + router-tests/go.sum | 2 + router-tests/integration_test.go | 47 -- router-tests/persisted_operations_test.go | 20 - router-tests/testenv/testenv.go | 12 +- router-tests/utils.go | 23 + router-tests/websocket_test.go | 25 - router/core/context.go | 11 +- router/core/graph_server.go | 21 +- router/core/graphql_prehandler.go | 43 +- router/core/operation_blocker.go | 130 ++++- router/core/websocket.go | 21 +- router/go.mod | 1 + router/go.sum | 2 + router/internal/expr/expr.go | 181 ++++++ router/pkg/authentication/authentication.go | 16 +- .../http_header_authenticator.go | 6 + router/pkg/config/config.go | 11 +- router/pkg/config/config.schema.json | 178 +++--- .../pkg/config/testdata/config_defaults.json | 15 +- router/pkg/config/testdata/config_full.json | 15 +- router/pkg/watcher/watcher_test.go | 11 +- 26 files changed, 1112 insertions(+), 265 deletions(-) create mode 100644 router-tests/block_operations_test.go create mode 100644 router/internal/expr/expr.go diff --git a/router-tests/authentication_test.go b/router-tests/authentication_test.go index 0d183cb610..fde9856ac1 100644 --- a/router-tests/authentication_test.go +++ b/router-tests/authentication_test.go @@ -18,7 +18,6 @@ import ( ) const ( - jwksName = "my-jwks-server" employeesQuery = `{"query":"{ employees { id } }"}` employeesQueryRequiringClaims = `{"query":"{ employees { id startDate } }"}` employeesExpectedData = `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}` @@ -26,21 +25,6 @@ const ( xAuthenticatedByHeader = "X-Authenticated-By" ) -func configureAuth(t *testing.T) ([]authentication.Authenticator, *jwks.Server) { - authServer, err := jwks.NewServer(t) - require.NoError(t, err) - t.Cleanup(authServer.Close) - tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5) - authOptions := authentication.HttpHeaderAuthenticatorOptions{ - Name: jwksName, - URL: authServer.JWKSURL(), - TokenDecoder: tokenDecoder, - } - authenticator, err := authentication.NewHttpHeaderAuthenticator(authOptions) - require.NoError(t, err) - return []authentication.Authenticator{authenticator}, authServer -} - func TestAuthentication(t *testing.T) { t.Parallel() @@ -750,7 +734,7 @@ func TestAuthenticationMultipleProviders(t *testing.T) { t.Cleanup(authServer2.Close) tokenDecoder1, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer1.JWKSURL(), time.Second*5) - authenticator1HeaderValuePrefixes := []string{"Bearer"} + authenticator1HeaderValuePrefixes := []string{"Provider1"} authenticator1, err := authentication.NewHttpHeaderAuthenticator(authentication.HttpHeaderAuthenticatorOptions{ Name: "1", HeaderValuePrefixes: authenticator1HeaderValuePrefixes, @@ -760,7 +744,7 @@ func TestAuthenticationMultipleProviders(t *testing.T) { require.NoError(t, err) tokenDecoder2, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer2.JWKSURL(), time.Second*5) - authenticator2HeaderValuePrefixes := []string{"", "Bearer", "Token"} + authenticator2HeaderValuePrefixes := []string{"", "Provider2"} authenticator2, err := authentication.NewHttpHeaderAuthenticator(authentication.HttpHeaderAuthenticatorOptions{ Name: "2", HeaderValuePrefixes: authenticator2HeaderValuePrefixes, @@ -771,7 +755,7 @@ func TestAuthenticationMultipleProviders(t *testing.T) { authenticators := []authentication.Authenticator{authenticator1, authenticator2} accessController := core.NewAccessController(authenticators, false) - t.Run("authenticate with first provider", func(t *testing.T) { + t.Run("authenticate with first provider due to matching prefix", func(t *testing.T) { t.Parallel() testenv.Run(t, &testenv.Config{ @@ -800,7 +784,7 @@ func TestAuthenticationMultipleProviders(t *testing.T) { }) }) - t.Run("authenticate with second provider", func(t *testing.T) { + t.Run("authenticate with second provider due to matching prefix", func(t *testing.T) { t.Parallel() testenv.Run(t, &testenv.Config{ diff --git a/router-tests/block_operations_test.go b/router-tests/block_operations_test.go new file mode 100644 index 0000000000..3dffafa80f --- /dev/null +++ b/router-tests/block_operations_test.go @@ -0,0 +1,537 @@ +package integration + +import ( + "bytes" + "encoding/json" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" + "io" + "net/http" + "strings" + "testing" +) + +func TestBlockOperations(t *testing.T) { + t.Parallel() + + t.Run("block mutations", func(t *testing.T) { + t.Parallel() + + t.Run("should allow all operations", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeTag(id: 1, tag: "test") { id tag } }`, + }) + require.Equal(t, http.StatusOK, res.Response.StatusCode) + require.Equal(t, `{"data":{"updateEmployeeTag":{"id":1,"tag":"test"}}}`, res.Body) + }) + }) + + t.Run("should block all operations", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { + securityConfiguration.BlockMutations = config.BlockOperationConfiguration{ + Enabled: true, + } + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeTag(id: 1, tag: "test") { id tag } }`, + }) + require.Equal(t, http.StatusOK, res.Response.StatusCode) + require.Equal(t, `{"errors":[{"message":"operation type 'mutation' is blocked"}]}`, res.Body) + }) + }) + + t.Run("should block operations by header match expression", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { + securityConfiguration.BlockMutations = config.BlockOperationConfiguration{ + Enabled: true, + Condition: "request.header.Get('graphql-client-name') == 'my-client'", + } + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + // Positive test + + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Header: map[string][]string{ + "graphql-client-name": {"my-client-different"}, + }, + Query: `mutation { updateEmployeeTag(id: 1, tag: "test") { id tag } }`, + }) + require.Equal(t, http.StatusOK, res.Response.StatusCode) + require.Equal(t, `{"data":{"updateEmployeeTag":{"id":1,"tag":"test"}}}`, res.Body) + + // Negative test + + res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Header: map[string][]string{ + "graphql-client-name": {"my-client"}, + }, + Query: `mutation { updateEmployeeTag(id: 1, tag: "test") { id tag } }`, + }) + require.Equal(t, http.StatusOK, res.Response.StatusCode) + require.Equal(t, `{"errors":[{"message":"operation type 'mutation' is blocked"}]}`, res.Body) + }) + }) + + t.Run("should block operations by query match expression", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { + securityConfiguration.BlockMutations = config.BlockOperationConfiguration{ + Enabled: true, + Condition: "request.url.query.foo == 'bar'", + } + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + // Negative test + + data, err := json.Marshal(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeTag(id: 1, tag: "test") { id tag } }`, + }) + + require.NoError(t, err) + req, err := http.NewRequestWithContext(xEnv.Context, http.MethodPost, xEnv.GraphQLRequestURL(), bytes.NewReader(data)) + require.NoError(t, err) + + res, err := xEnv.MakeGraphQLRequestRaw(req) + require.NoError(t, err) + + require.Equal(t, http.StatusOK, res.Response.StatusCode) + require.Equal(t, `{"data":{"updateEmployeeTag":{"id":1,"tag":"test"}}}`, res.Body) + + // Positive test + + req, err = http.NewRequestWithContext(xEnv.Context, http.MethodPost, xEnv.GraphQLRequestURL()+"?foo=bar", bytes.NewReader(data)) + require.NoError(t, err) + + res, err = xEnv.MakeGraphQLRequestRaw(req) + require.NoError(t, err) + + require.Equal(t, http.StatusOK, res.Response.StatusCode) + require.Equal(t, `{"errors":[{"message":"operation type 'mutation' is blocked"}]}`, res.Body) + }) + }) + + t.Run("should block operation by scope expression condition", func(t *testing.T) { + t.Parallel() + + authenticators, authServer := configureAuth(t) + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, false)), + }, + ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { + securityConfiguration.BlockMutations = config.BlockOperationConfiguration{ + Enabled: true, + Condition: "'read:miscellaneous' in request.auth.scopes && request.auth.isAuthenticated", + } + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + token, err := authServer.Token(map[string]any{ + "scope": "write:fact read:miscellaneous read:all", + }) + require.NoError(t, err) + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err := xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(` + {"query":"mutation { addFact(fact: { title: \"title\", description: \"description\", factType: MISCELLANEOUS }) { ... on MiscellaneousFact { title description } } }"} + `)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + data, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, `{"errors":[{"message":"operation type 'mutation' is blocked"}]}`, string(data)) + + // Negative test + + token, err = authServer.Token(map[string]any{ + "scope": "write:fact read:all", + }) + require.NoError(t, err) + header = http.Header{ + "Authorization": []string{"Bearer " + token}, + } + res, err = xEnv.MakeRequest(http.MethodPost, "/graphql", header, strings.NewReader(` + {"query":"mutation { addFact(fact: { title: \"title\", description: \"description\", factType: DIRECTIVE }) { description } }"} + `)) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + data, err = io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, `{"data":{"addFact":{"description":"description"}}}`, string(data)) + }) + }) + }) + + t.Run("block subscriptions", func(t *testing.T) { + + t.Run("should block all subscriptions", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { + securityConfiguration.BlockSubscriptions = config.BlockOperationConfiguration{ + Enabled: true, + } + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) + err := conn.WriteJSON(&testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), + }) + require.NoError(t, err) + + var msg testenv.WebSocketMessage + err = conn.ReadJSON(&msg) + require.NoError(t, err) + require.Equal(t, "1", msg.ID) + require.Equal(t, "error", msg.Type) + require.Equal(t, `[{"message":"operation type 'subscription' is blocked"}]`, string(msg.Payload)) + }) + }) + + t.Run("should block subscriptions by header match expression", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { + securityConfiguration.BlockSubscriptions = config.BlockOperationConfiguration{ + Enabled: true, + Condition: "request.header.Get('graphql-client-name') == 'my-client'", + } + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + type currentTimePayload struct { + Data struct { + CurrentTime struct { + UnixTime float64 `json:"unixTime"` + Timestamp string `json:"timestamp"` + } `json:"currentTime"` + } `json:"data"` + } + + // Positive test + + conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) + err := conn.WriteJSON(&testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), + }) + require.NoError(t, err) + + var msg testenv.WebSocketMessage + var payload currentTimePayload + + err = conn.ReadJSON(&msg) + require.NoError(t, err) + require.Equal(t, "1", msg.ID) + require.Equal(t, "next", msg.Type) + + err = json.Unmarshal(msg.Payload, &payload) + require.NoError(t, err) + + require.NotEmpty(t, payload.Data.CurrentTime.UnixTime) + require.NotEmpty(t, payload.Data.CurrentTime.Timestamp) + + _ = conn.Close() + + // Negative test + + header := make(http.Header) + header.Add("graphql-client-name", "my-client") + conn = xEnv.InitGraphQLWebSocketConnection(header, nil, nil) + err = conn.WriteJSON(&testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), + }) + require.NoError(t, err) + + msg = testenv.WebSocketMessage{} + err = conn.ReadJSON(&msg) + require.NoError(t, err) + require.Equal(t, "1", msg.ID) + require.Equal(t, "error", msg.Type) + require.Equal(t, `[{"message":"operation type 'subscription' is blocked"}]`, string(msg.Payload)) + }) + }) + + t.Run("should block subscriptions by scope match expression", func(t *testing.T) { + t.Parallel() + + authenticators, authServer := configureAuth(t) + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, false)), + core.WithAuthorizationConfig(&config.AuthorizationConfiguration{ + RejectOperationIfUnauthorized: false, + }), + }, + ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { + securityConfiguration.BlockSubscriptions = config.BlockOperationConfiguration{ + Enabled: true, + Condition: "'read:block' in request.auth.scopes && request.auth.isAuthenticated", + } + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + type currentTimePayload struct { + Data struct { + CurrentTime struct { + UnixTime float64 `json:"unixTime"` + Timestamp string `json:"timestamp"` + } `json:"currentTime"` + } `json:"data"` + } + + // Positive test + + token, err := authServer.Token(map[string]any{ + "scope": "read:all", + }) + require.NoError(t, err) + + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + + conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) + err = conn.WriteJSON(&testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), + }) + require.NoError(t, err) + + var msg testenv.WebSocketMessage + var payload currentTimePayload + + err = conn.ReadJSON(&msg) + require.NoError(t, err) + require.Equal(t, "1", msg.ID) + require.Equal(t, "next", msg.Type) + + err = json.Unmarshal(msg.Payload, &payload) + require.NoError(t, err) + + require.NotEmpty(t, payload.Data.CurrentTime.UnixTime) + require.NotEmpty(t, payload.Data.CurrentTime.Timestamp) + + // Negative test + + token, err = authServer.Token(map[string]any{ + "scope": "read:block", + }) + require.NoError(t, err) + + header = http.Header{ + "Authorization": []string{"Bearer " + token}, + } + + conn = xEnv.InitGraphQLWebSocketConnection(header, nil, nil) + err = conn.WriteJSON(&testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), + }) + require.NoError(t, err) + + msg = testenv.WebSocketMessage{} + + err = conn.ReadJSON(&msg) + require.NoError(t, err) + require.Equal(t, "1", msg.ID) + require.Equal(t, "error", msg.Type) + require.Equal(t, `[{"message":"operation type 'subscription' is blocked"}]`, string(msg.Payload)) + + _ = conn.Close() + }) + }) + + t.Run("should block subscriptions by scope match expression and from initial payload enabled", func(t *testing.T) { + t.Parallel() + + authenticators, authServer := configureAuth(t) + + testenv.Run(t, &testenv.Config{ + ModifyWebsocketConfiguration: func(cfg *config.WebSocketConfiguration) { + cfg.Authentication.FromInitialPayload.Enabled = true + cfg.Enabled = true + }, + RouterOptions: []core.Option{ + core.WithAccessController(core.NewAccessController(authenticators, false)), + core.WithAuthorizationConfig(&config.AuthorizationConfiguration{ + RejectOperationIfUnauthorized: false, + }), + }, + ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { + securityConfiguration.BlockSubscriptions = config.BlockOperationConfiguration{ + Enabled: true, + Condition: "'read:block' in request.auth.scopes && request.auth.isAuthenticated", + } + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + type currentTimePayload struct { + Data struct { + CurrentTime struct { + UnixTime float64 `json:"unixTime"` + Timestamp string `json:"timestamp"` + } `json:"currentTime"` + } `json:"data"` + } + + // Positive test + + token, err := authServer.Token(map[string]any{ + "scope": "read:all", + }) + require.NoError(t, err) + + header := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + + conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) + err = conn.WriteJSON(&testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), + }) + require.NoError(t, err) + + var msg testenv.WebSocketMessage + var payload currentTimePayload + + err = conn.ReadJSON(&msg) + require.NoError(t, err) + require.Equal(t, "1", msg.ID) + require.Equal(t, "next", msg.Type) + + err = json.Unmarshal(msg.Payload, &payload) + require.NoError(t, err) + + require.NotEmpty(t, payload.Data.CurrentTime.UnixTime) + require.NotEmpty(t, payload.Data.CurrentTime.Timestamp) + + // Negative test + + token, err = authServer.Token(map[string]any{ + "scope": "read:block", + }) + require.NoError(t, err) + + header = http.Header{ + "Authorization": []string{"Bearer " + token}, + } + + conn = xEnv.InitGraphQLWebSocketConnection(header, nil, nil) + err = conn.WriteJSON(&testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), + }) + require.NoError(t, err) + + msg = testenv.WebSocketMessage{} + + err = conn.ReadJSON(&msg) + require.NoError(t, err) + require.Equal(t, "1", msg.ID) + require.Equal(t, "error", msg.Type) + require.Equal(t, `[{"message":"operation type 'subscription' is blocked"}]`, string(msg.Payload)) + + _ = conn.Close() + }) + }) + + }) + + t.Run("block non-persisted operations", func(t *testing.T) { + t.Parallel() + + t.Run("should allow operations", func(t *testing.T) { + testenv.Run(t, &testenv.Config{ + ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { + securityConfiguration.BlockNonPersistedOperations = config.BlockOperationConfiguration{ + Enabled: true, + } + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + // Negative test + + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeTag(id: 1, tag: "test") { id tag } }`, + }) + require.Equal(t, http.StatusOK, res.Response.StatusCode) + require.Equal(t, res.Response.Header.Get("Content-Type"), "application/json") + require.Equal(t, `{"errors":[{"message":"non-persisted operation is blocked"}]}`, res.Body) + + // Positive test + + header := make(http.Header) + header.Add("graphql-client-name", "my-client") + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + OperationName: []byte(`"Employees"`), + Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "dc67510fb4289672bea757e862d6b00e83db5d3cbbcfb15260601b6f29bb2b8f"}}`), + Header: header, + }) + require.NoError(t, err) + require.Equal(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, res.Body) + }) + }) + + t.Run("should block operation by header match expression", func(t *testing.T) { + t.Parallel() + testenv.Run(t, &testenv.Config{ + ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { + securityConfiguration.BlockNonPersistedOperations = config.BlockOperationConfiguration{ + Enabled: true, + Condition: "request.header.Get('graphql-client-name') == 'my-client'", + } + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + // Negative test + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Header: map[string][]string{ + "graphql-client-name": {"my-client-different"}, + }, + Query: `mutation { updateEmployeeTag(id: 1, tag: "test") { id tag } }`, + }) + require.Equal(t, http.StatusOK, res.Response.StatusCode) + require.Equal(t, `{"data":{"updateEmployeeTag":{"id":1,"tag":"test"}}}`, res.Body) + + // Positive test + res = xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Header: map[string][]string{ + "graphql-client-name": {"my-client"}, + }, + Query: `mutation { updateEmployeeTag(id: 1, tag: "test") { id tag } }`, + }) + require.Equal(t, http.StatusOK, res.Response.StatusCode) + require.Equal(t, res.Response.Header.Get("Content-Type"), "application/json") + require.Equal(t, `{"errors":[{"message":"non-persisted operation is blocked"}]}`, res.Body) + }) + }) + }) +} diff --git a/router-tests/events/kafka_events_test.go b/router-tests/events/kafka_events_test.go index 340a376b2c..6d49659b77 100644 --- a/router-tests/events/kafka_events_test.go +++ b/router-tests/events/kafka_events_test.go @@ -495,7 +495,7 @@ func TestKafkaEvents(t *testing.T) { }) }) - t.Run("subscribe sync with block", func(t *testing.T) { + t.Run("Should block subscribe sync operation", func(t *testing.T) { t.Parallel() subscribePayload := []byte(`{"query":"subscription { employeeUpdatedMyKafka(employeeID: 1) { id details { forename surname } }}"}`) @@ -503,7 +503,9 @@ func TestKafkaEvents(t *testing.T) { testenv.Run(t, &testenv.Config{ EnableKafka: true, ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { - securityConfiguration.BlockSubscriptions = true + securityConfiguration.BlockSubscriptions = config.BlockOperationConfiguration{ + Enabled: true, + } }, }, func(t *testing.T, xEnv *testenv.Environment) { client := http.Client{ @@ -645,7 +647,7 @@ func TestKafkaEvents(t *testing.T) { }) }) - t.Run("subscribe sync sse with block", func(t *testing.T) { + t.Run("should block subscribe sync sse operation", func(t *testing.T) { t.Parallel() subscribePayload := []byte(`{"query":"subscription { employeeUpdatedMyKafka(employeeID: 1) { id details { forename surname } }}"}`) @@ -653,7 +655,9 @@ func TestKafkaEvents(t *testing.T) { testenv.Run(t, &testenv.Config{ EnableKafka: true, ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { - securityConfiguration.BlockSubscriptions = true + securityConfiguration.BlockSubscriptions = config.BlockOperationConfiguration{ + Enabled: true, + } }, }, func(t *testing.T, xEnv *testenv.Environment) { client := http.Client{ diff --git a/router-tests/events/nats_events_test.go b/router-tests/events/nats_events_test.go index 5d06c688bc..3d89a9d980 100644 --- a/router-tests/events/nats_events_test.go +++ b/router-tests/events/nats_events_test.go @@ -458,13 +458,15 @@ func TestNatsEvents(t *testing.T) { }) }) - t.Run("subscribe sync multipart with block", func(t *testing.T) { + t.Run("should block subscribe sync multipart operation", func(t *testing.T) { t.Parallel() testenv.Run(t, &testenv.Config{ EnableNats: true, ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { - securityConfiguration.BlockSubscriptions = true + securityConfiguration.BlockSubscriptions = config.BlockOperationConfiguration{ + Enabled: true, + } }, }, func(t *testing.T, xEnv *testenv.Environment) { queries := [][]byte{ @@ -630,13 +632,15 @@ func TestNatsEvents(t *testing.T) { }) }) - t.Run("subscribe sync sse with block", func(t *testing.T) { + t.Run("should block subscribe sync sse operation", func(t *testing.T) { t.Parallel() testenv.Run(t, &testenv.Config{ EnableNats: true, ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { - securityConfiguration.BlockSubscriptions = true + securityConfiguration.BlockSubscriptions = config.BlockOperationConfiguration{ + Enabled: true, + } }, }, func(t *testing.T, xEnv *testenv.Environment) { subscribePayloadOne := []byte(`{"query":"subscription { employeeUpdated(employeeID: 3) { id details { forename surname } }}"}`) diff --git a/router-tests/go.mod b/router-tests/go.mod index c01ffee6d9..8f5caed3e9 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -65,6 +65,7 @@ require ( github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/expr-lang/expr v1.16.9 // indirect github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect diff --git a/router-tests/go.sum b/router-tests/go.sum index bbaa324742..387f55253f 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -91,6 +91,8 @@ github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3 h1:Oy0F4A github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3/go.mod h1:YvSRo5mw33fLEx1+DlK6L2VV43tJt5Eyel9n9XBcR+0= github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= +github.com/expr-lang/expr v1.16.9 h1:WUAzmR0JNI9JCiF0/ewwHB1gmcGw5wW7nWt8gc6PpCI= +github.com/expr-lang/expr v1.16.9/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= diff --git a/router-tests/integration_test.go b/router-tests/integration_test.go index 49a86c1af6..23b0a15ae8 100644 --- a/router-tests/integration_test.go +++ b/router-tests/integration_test.go @@ -925,53 +925,6 @@ func TestConcurrentQueriesWithDelay(t *testing.T) { }) } -func TestBlockMutations(t *testing.T) { - t.Parallel() - t.Run("allow", func(t *testing.T) { - t.Parallel() - testenv.Run(t, &testenv.Config{}, func(t *testing.T, xEnv *testenv.Environment) { - res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ - Query: `mutation { updateEmployeeTag(id: 1, tag: "test") { id tag } }`, - }) - require.Equal(t, http.StatusOK, res.Response.StatusCode) - require.Equal(t, `{"data":{"updateEmployeeTag":{"id":1,"tag":"test"}}}`, res.Body) - }) - }) - t.Run("block", func(t *testing.T) { - t.Parallel() - testenv.Run(t, &testenv.Config{ - ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { - securityConfiguration.BlockMutations = true - }, - }, func(t *testing.T, xEnv *testenv.Environment) { - res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ - Query: `mutation { updateEmployeeTag(id: 1, tag: "test") { id tag } }`, - }) - require.Equal(t, http.StatusOK, res.Response.StatusCode) - require.Equal(t, `{"errors":[{"message":"operation type 'mutation' is blocked"}]}`, res.Body) - }) - }) -} - -func TestBlockNonPersistedOperations(t *testing.T) { - t.Parallel() - t.Run("block", func(t *testing.T) { - t.Parallel() - testenv.Run(t, &testenv.Config{ - ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { - securityConfiguration.BlockNonPersistedOperations = true - }, - }, func(t *testing.T, xEnv *testenv.Environment) { - res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ - Query: `mutation { updateEmployeeTag(id: 1, tag: "test") { id tag } }`, - }) - require.Equal(t, http.StatusOK, res.Response.StatusCode) - require.Equal(t, res.Response.Header.Get("Content-Type"), "application/json") - require.Equal(t, `{"errors":[{"message":"non-persisted operation is blocked"}]}`, res.Body) - }) - }) -} - func TestRequestBodySizeLimit(t *testing.T) { t.Parallel() testenv.Run(t, &testenv.Config{ diff --git a/router-tests/persisted_operations_test.go b/router-tests/persisted_operations_test.go index 7230a907e0..b3d9f14661 100644 --- a/router-tests/persisted_operations_test.go +++ b/router-tests/persisted_operations_test.go @@ -41,26 +41,6 @@ func TestPersistedOperation(t *testing.T) { }) } -func TestPersistedOperationWithBlock(t *testing.T) { - t.Parallel() - - testenv.Run(t, &testenv.Config{ - ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { - securityConfiguration.BlockNonPersistedOperations = true - }, - }, func(t *testing.T, xEnv *testenv.Environment) { - header := make(http.Header) - header.Add("graphql-client-name", "my-client") - res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ - OperationName: []byte(`"Employees"`), - Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "dc67510fb4289672bea757e862d6b00e83db5d3cbbcfb15260601b6f29bb2b8f"}}`), - Header: header, - }) - require.NoError(t, err) - require.Equal(t, `{"data":{"employees":[{"id":1},{"id":2},{"id":3},{"id":4},{"id":5},{"id":7},{"id":8},{"id":10},{"id":11},{"id":12}]}}`, res.Body) - }) -} - func TestPersistedOperationPOExtensionNotTransmittedToSubgraph(t *testing.T) { t.Parallel() diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index a8216054d2..f5eb21db2c 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -582,7 +582,7 @@ func createTestEnv(t testing.TB, cfg *Config) (*Environment, error) { go func() { if err := rr.Start(ctx); err != nil { - t.Fatal("Could not start router", zap.Error(err)) + require.Failf(t, "Could not start router", "error: %s", err) } }() @@ -1180,7 +1180,7 @@ func (e *Environment) MakeGraphQLRequestWithContext(ctx context.Context, request req.Header = request.Header } req.Header.Set("Accept-Encoding", "identity") - return e.makeGraphQLRequest(req) + return e.MakeGraphQLRequestRaw(req) } func (e *Environment) MakeGraphQLRequestWithHeaders(request GraphQLRequest, headers map[string]string) (*TestResponse, error) { @@ -1193,11 +1193,10 @@ func (e *Environment) MakeGraphQLRequestWithHeaders(request GraphQLRequest, head if request.Header != nil { req.Header = request.Header } - req.Header.Set("Accept-Encoding", "identity") for k, v := range headers { req.Header.Set(k, v) } - return e.makeGraphQLRequest(req) + return e.MakeGraphQLRequestRaw(req) } func (e *Environment) MakeGraphQLRequestOverGET(request GraphQLRequest) (*TestResponse, error) { @@ -1206,7 +1205,7 @@ func (e *Environment) MakeGraphQLRequestOverGET(request GraphQLRequest) (*TestRe return nil, err } - return e.makeGraphQLRequest(req) + return e.MakeGraphQLRequestRaw(req) } func (e *Environment) newGraphQLRequestOverGET(baseURL string, request GraphQLRequest) (*http.Request, error) { @@ -1237,7 +1236,8 @@ func (e *Environment) newGraphQLRequestOverGET(baseURL string, request GraphQLRe return req, nil } -func (e *Environment) makeGraphQLRequest(request *http.Request) (*TestResponse, error) { +func (e *Environment) MakeGraphQLRequestRaw(request *http.Request) (*TestResponse, error) { + request.Header.Set("Accept-Encoding", "identity") resp, err := e.RouterClient.Do(request) if err != nil { return nil, err diff --git a/router-tests/utils.go b/router-tests/utils.go index 0e806c5998..874c7052fc 100644 --- a/router-tests/utils.go +++ b/router-tests/utils.go @@ -2,9 +2,17 @@ package integration import ( "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/jwks" + "github.com/wundergraph/cosmo/router/pkg/authentication" "go.opentelemetry.io/otel/sdk/trace" tracetest2 "go.opentelemetry.io/otel/sdk/trace/tracetest" + "go.uber.org/zap" "testing" + "time" +) + +const ( + jwksName = "my-jwks-server" ) func RequireSpanWithName(t *testing.T, exporter *tracetest2.InMemoryExporter, name string) trace.ReadOnlySpan { @@ -22,3 +30,18 @@ func RequireSpanWithName(t *testing.T, exporter *tracetest2.InMemoryExporter, na require.NotNil(t, testSpan) return testSpan } + +func configureAuth(t *testing.T) ([]authentication.Authenticator, *jwks.Server) { + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + t.Cleanup(authServer.Close) + tokenDecoder, _ := authentication.NewJwksTokenDecoder(zap.NewNop(), authServer.JWKSURL(), time.Second*5) + authOptions := authentication.HttpHeaderAuthenticatorOptions{ + Name: jwksName, + URL: authServer.JWKSURL(), + TokenDecoder: tokenDecoder, + } + authenticator, err := authentication.NewHttpHeaderAuthenticator(authOptions) + require.NoError(t, err) + return []authentication.Authenticator{authenticator}, authServer +} diff --git a/router-tests/websocket_test.go b/router-tests/websocket_test.go index a1bbf676ac..d5fd5c1ca0 100644 --- a/router-tests/websocket_test.go +++ b/router-tests/websocket_test.go @@ -1190,31 +1190,6 @@ func TestWebSockets(t *testing.T) { require.Equal(t, `[{"message":"Unable to subscribe"}]`, string(msg.Payload)) }) }) - t.Run("subscription blocked", func(t *testing.T) { - t.Parallel() - - testenv.Run(t, &testenv.Config{ - ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) { - securityConfiguration.BlockSubscriptions = true - }, - }, func(t *testing.T, xEnv *testenv.Environment) { - - conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) - err := conn.WriteJSON(&testenv.WebSocketMessage{ - ID: "1", - Type: "subscribe", - Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), - }) - require.NoError(t, err) - - var msg testenv.WebSocketMessage - err = conn.ReadJSON(&msg) - require.NoError(t, err) - require.Equal(t, "1", msg.ID) - require.Equal(t, "error", msg.Type) - require.Equal(t, `[{"message":"operation type 'subscription' is blocked"}]`, string(msg.Payload)) - }) - }) t.Run("multiple subscriptions one connection", func(t *testing.T) { t.Parallel() diff --git a/router/core/context.go b/router/core/context.go index eb37acb8d7..04e967c23f 100644 --- a/router/core/context.go +++ b/router/core/context.go @@ -2,6 +2,7 @@ package core import ( "context" + "github.com/wundergraph/cosmo/router/internal/expr" "net/http" "net/url" "strings" @@ -245,6 +246,8 @@ type requestContext struct { graphQLErrorCodes []string // telemetry are the base telemetry information of the request telemetry *requestTelemetryAttributes + // expressionContext is the context that will be provided to a compiled expression in order to retrieve data via dynamic expressions + expressionContext expr.Context } func (c *requestContext) Operation() OperationContext { @@ -598,6 +601,11 @@ type requestContextOptions struct { } func buildRequestContext(opts requestContextOptions) *requestContext { + + rootCtx := expr.Context{ + Request: expr.LoadRequest(opts.r), + } + return &requestContext{ logger: opts.requestLogger, keys: map[string]any{}, @@ -609,6 +617,7 @@ func buildRequestContext(opts requestContextOptions) *requestContext { metricsEnabled: opts.metricsEnabled, traceEnabled: opts.traceEnabled, }, - subgraphResolver: subgraphResolverFromContext(opts.r.Context()), + expressionContext: rootCtx, + subgraphResolver: subgraphResolverFromContext(opts.r.Context()), } } diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 7591c78abf..1a345c778d 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -927,12 +927,25 @@ func (s *graphServer) buildGraphMux(ctx context.Context, graphqlHandler := NewGraphQLHandler(handlerOpts) executor.Resolver.SetAsyncErrorWriter(graphqlHandler) - operationBlocker := NewOperationBlocker(&OperationBlockerOptions{ - BlockMutations: s.securityConfiguration.BlockMutations, - BlockSubscriptions: s.securityConfiguration.BlockSubscriptions, - BlockNonPersisted: s.securityConfiguration.BlockNonPersistedOperations, + operationBlocker, err := NewOperationBlocker(&OperationBlockerOptions{ + BlockMutations: BlockMutationOptions{ + Enabled: s.securityConfiguration.BlockMutations.Enabled, + Condition: s.securityConfiguration.BlockMutations.Condition, + }, + BlockSubscriptions: BlockSubscriptionOptions{ + Enabled: s.securityConfiguration.BlockSubscriptions.Enabled, + Condition: s.securityConfiguration.BlockSubscriptions.Condition, + }, + BlockNonPersisted: BlockNonPersistedOptions{ + Enabled: s.securityConfiguration.BlockNonPersistedOperations.Enabled, + Condition: s.securityConfiguration.BlockNonPersistedOperations.Condition, + }, }) + if err != nil { + return nil, fmt.Errorf("failed to create operation blocker: %w", err) + } + graphqlPreHandler := NewPreHandler(&PreHandlerOptions{ Logger: s.logger, Executor: executor, diff --git a/router/core/graphql_prehandler.go b/router/core/graphql_prehandler.go index 40544d6b3f..0af93f71a7 100644 --- a/router/core/graphql_prehandler.go +++ b/router/core/graphql_prehandler.go @@ -5,6 +5,7 @@ import ( "context" "crypto/ecdsa" "fmt" + "github.com/wundergraph/cosmo/router/internal/expr" "net/http" "strconv" "strings" @@ -265,7 +266,7 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler { readMultiPartSpan.End() - // Cleanup all files. Needs to be called in the pre_handler function to ensure that the + // Cleanup all files. Needs to be called in the pre_handler function to ensure that // defer is called after the response is written defer func() { if err := multipartParser.RemoveAll(); err != nil { @@ -298,24 +299,6 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler { variablesParser := h.variableParsePool.Get() defer h.variableParsePool.Put(variablesParser) - err = h.handleOperation(r, variablesParser, &httpOperation{ - requestContext: requestContext, - requestLogger: requestLogger, - routerSpan: routerSpan, - operationMetrics: metrics, - traceTimings: traceTimings, - files: files, - body: body, - }) - if err != nil { - requestContext.error = err - // Mark the root span of the router as failed, so we can easily identify failed requests - rtrace.AttachErrToSpan(routerSpan, err) - - writeOperationError(r, w, requestLogger, err) - return - } - // If we have authenticators, we try to authenticate the request if h.accessController != nil { _, authenticateSpan := h.tracer.Start(r.Context(), "Authenticate", @@ -344,6 +327,26 @@ func (h *PreHandler) Handler(next http.Handler) http.Handler { authenticateSpan.End() r = validatedReq + + requestContext.expressionContext.Request.Auth = expr.LoadAuth(r.Context()) + } + + err = h.handleOperation(r, variablesParser, &httpOperation{ + requestContext: requestContext, + requestLogger: requestLogger, + routerSpan: routerSpan, + operationMetrics: metrics, + traceTimings: traceTimings, + files: files, + body: body, + }) + if err != nil { + requestContext.error = err + // Mark the root span of the router as failed, so we can easily identify failed requests + rtrace.AttachErrToSpan(routerSpan, err) + + writeOperationError(r, w, requestLogger, err) + return } art.SetRequestTracingStats(r.Context(), traceOptions, traceTimings) @@ -516,7 +519,7 @@ func (h *PreHandler) handleOperation(req *http.Request, variablesParser *astjson // Set the operation name and type to the operation metrics and the router span as early as possible httpOperation.routerSpan.SetAttributes(attributesAfterParse...) - if err := h.operationBlocker.OperationIsBlocked(operationKit.parsedOperation); err != nil { + if err := h.operationBlocker.OperationIsBlocked(requestContext.logger, requestContext.expressionContext, operationKit.parsedOperation); err != nil { return &httpGraphqlError{ message: err.Error(), statusCode: http.StatusOK, diff --git a/router/core/operation_blocker.go b/router/core/operation_blocker.go index 6871a54b0e..98613b446f 100644 --- a/router/core/operation_blocker.go +++ b/router/core/operation_blocker.go @@ -2,6 +2,10 @@ package core import ( "errors" + "fmt" + "github.com/expr-lang/expr/vm" + "github.com/wundergraph/cosmo/router/internal/expr" + "go.uber.org/zap" ) var ( @@ -11,39 +15,135 @@ var ( ) type OperationBlocker struct { - blockMutations bool - blockSubscriptions bool - blockNonPersisted bool + blockMutations BlockMutationOptions + blockSubscriptions BlockSubscriptionOptions + blockNonPersisted BlockNonPersistedOptions + + mutationExpr *vm.Program + subscriptionExpr *vm.Program + nonPersistedExpr *vm.Program +} + +type BlockMutationOptions struct { + Enabled bool + Condition string +} + +type BlockSubscriptionOptions struct { + Enabled bool + Condition string +} + +type BlockNonPersistedOptions struct { + Enabled bool + Condition string } type OperationBlockerOptions struct { - BlockMutations bool - BlockSubscriptions bool - BlockNonPersisted bool + BlockMutations BlockMutationOptions + BlockSubscriptions BlockSubscriptionOptions + BlockNonPersisted BlockNonPersistedOptions } -func NewOperationBlocker(opts *OperationBlockerOptions) *OperationBlocker { - return &OperationBlocker{ +func NewOperationBlocker(opts *OperationBlockerOptions) (*OperationBlocker, error) { + ob := &OperationBlocker{ blockMutations: opts.BlockMutations, blockSubscriptions: opts.BlockSubscriptions, blockNonPersisted: opts.BlockNonPersisted, } + + if err := ob.compileExpressions(); err != nil { + return nil, err + } + + return ob, nil } -func (o *OperationBlocker) OperationIsBlocked(operation *ParsedOperation) error { +func (o *OperationBlocker) compileExpressions() error { + if o.blockMutations.Enabled && o.blockMutations.Condition != "" { - if !operation.IsPersistedOperation && o.blockNonPersisted { - return ErrNonPersistedOperationBlocked + v, err := expr.CompileBoolExpression(o.blockMutations.Condition) + if err != nil { + return fmt.Errorf("failed to compile mutation expression: %w", err) + } + o.mutationExpr = v + } + + if o.blockSubscriptions.Enabled && o.blockSubscriptions.Condition != "" { + v, err := expr.CompileBoolExpression(o.blockSubscriptions.Condition) + if err != nil { + return fmt.Errorf("failed to compile subscription expression: %w", err) + } + o.subscriptionExpr = v + } + + if o.blockNonPersisted.Enabled && o.blockNonPersisted.Condition != "" { + v, err := expr.CompileBoolExpression(o.blockNonPersisted.Condition) + if err != nil { + return fmt.Errorf("failed to compile non-persisted expression: %w", err) + } + o.nonPersistedExpr = v + } + + return nil +} + +func (o *OperationBlocker) OperationIsBlocked(requestLogger *zap.Logger, exprContext expr.Context, operation *ParsedOperation) error { + + if !operation.IsPersistedOperation && o.blockNonPersisted.Enabled { + + // Block all non-persisted operations when no expression is provided + if o.nonPersistedExpr == nil { + return ErrNonPersistedOperationBlocked + } + + ok, err := expr.ResolveBoolExpression(o.nonPersistedExpr, exprContext) + if err != nil { + requestLogger.Error("failed to resolve non-persisted block expression", zap.Error(err)) + return ErrNonPersistedOperationBlocked + } + + if ok { + return ErrNonPersistedOperationBlocked + } } switch operation.Type { case "mutation": - if o.blockMutations { - return ErrMutationOperationBlocked + if o.blockMutations.Enabled { + + // Block all mutations when no expression is provided + if o.mutationExpr == nil { + return ErrMutationOperationBlocked + } + + ok, err := expr.ResolveBoolExpression(o.mutationExpr, exprContext) + if err != nil { + requestLogger.Error("failed to resolve mutation block expression", zap.Error(err)) + return ErrMutationOperationBlocked + } + + if ok { + return ErrMutationOperationBlocked + } } case "subscription": - if o.blockSubscriptions { - return ErrSubscriptionOperationBlocked + if o.blockSubscriptions.Enabled { + + // Block all subscriptions when no expression is provided + if o.subscriptionExpr == nil { + return ErrSubscriptionOperationBlocked + } + + ok, err := expr.ResolveBoolExpression(o.subscriptionExpr, exprContext) + if err != nil { + requestLogger.Error("failed to resolve subscription block expression", zap.Error(err)) + return ErrSubscriptionOperationBlocked + } + + if ok { + return ErrSubscriptionOperationBlocked + } } } return nil diff --git a/router/core/websocket.go b/router/core/websocket.go index 1e67fba7ed..6b4a82079e 100644 --- a/router/core/websocket.go +++ b/router/core/websocket.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/wundergraph/cosmo/router/internal/expr" "net" "net/http" "regexp" @@ -224,6 +225,7 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R ) requestID := middleware.GetReqID(r.Context()) + requestContext := getRequestContext(r.Context()) requestLogger := h.logger.With(logging.WithRequestID(requestID), logging.WithTraceID(rtrace.GetTraceID(r.Context()))) clientInfo := NewClientInfoFromRequest(r, h.clientHeader) @@ -240,6 +242,8 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R return } r = validatedReq + + requestContext.expressionContext.Request.Auth = expr.LoadAuth(r.Context()) } upgrader := ws.HTTPUpgrader{ @@ -369,6 +373,8 @@ func (h *WebsocketHandler) handleUpgradeRequest(w http.ResponseWriter, r *http.R } handler.request.Header.Set(fromInitialPayloadConfig.ExportToken.HeaderKey, jwtToken) } + + requestContext.expressionContext.Request.Auth = expr.LoadAuth(handler.request.Context()) } // Only when epoll/kqueue is available. On Windows, epoll is not available @@ -747,7 +753,7 @@ func (h *WebSocketConnectionHandler) writeErrorMessage(operationID string, err e return h.protocol.WriteGraphQLErrors(operationID, payload, nil) } -func (h *WebSocketConnectionHandler) parseAndPlan(payload []byte) (*ParsedOperation, *operationContext, error) { +func (h *WebSocketConnectionHandler) parseAndPlan(registration *SubscriptionRegistration) (*ParsedOperation, *operationContext, error) { operationKit, err := h.operationProcessor.NewKit() if err != nil { @@ -759,7 +765,7 @@ func (h *WebSocketConnectionHandler) parseAndPlan(payload []byte) (*ParsedOperat clientInfo: h.plannerOptions.ClientInfo, } - if err := operationKit.UnmarshalOperationFromBody(payload); err != nil { + if err := operationKit.UnmarshalOperationFromBody(registration.msg.Payload); err != nil { return nil, nil, err } @@ -792,7 +798,12 @@ func (h *WebSocketConnectionHandler) parseAndPlan(payload []byte) (*ParsedOperat opContext.name = operationKit.parsedOperation.Request.OperationName opContext.opType = operationKit.parsedOperation.Type - if blocked := h.operationBlocker.OperationIsBlocked(operationKit.parsedOperation); blocked != nil { + reqCtx := getRequestContext(registration.clientRequest.Context()) + if reqCtx == nil { + return nil, nil, fmt.Errorf("request context not found") + } + + if blocked := h.operationBlocker.OperationIsBlocked(h.logger, reqCtx.expressionContext, operationKit.parsedOperation); blocked != nil { return nil, nil, blocked } @@ -848,7 +859,7 @@ func (h *WebSocketConnectionHandler) executeSubscription(registration *Subscript rw := newWebsocketResponseWriter(registration.msg.ID, h.protocol, h.graphqlHandler.subgraphErrorPropagation.Enabled, h.logger, h.stats) - _, operationCtx, err := h.parseAndPlan(registration.msg.Payload) + _, operationCtx, err := h.parseAndPlan(registration) if err != nil { wErr := h.writeErrorMessage(registration.msg.ID, err) if wErr != nil { @@ -919,7 +930,7 @@ func (h *WebSocketConnectionHandler) executeSubscription(registration *Subscript } resolveCtx = h.graphqlHandler.configureRateLimiting(resolveCtx) - // Put in a closure to evaluate err after the defer + // Put in a closure to evaluate err after defer defer func() { // StatusCode has no meaning here. We set it to 0 but set the error. h.metrics.ExportSchemaUsageInfo(operationCtx, 0, err != nil, false) diff --git a/router/go.mod b/router/go.mod index dc9a1a5185..c8271b6cff 100644 --- a/router/go.mod +++ b/router/go.mod @@ -64,6 +64,7 @@ require ( github.com/KimMachineGun/automemlimit v0.6.1 github.com/bep/debounce v1.2.1 github.com/caarlos0/env/v11 v11.1.0 + github.com/expr-lang/expr v1.16.9 github.com/fsnotify/fsnotify v1.7.0 github.com/klauspost/compress v1.17.9 github.com/minio/minio-go/v7 v7.0.74 diff --git a/router/go.sum b/router/go.sum index ccc2866a06..f187fb498d 100644 --- a/router/go.sum +++ b/router/go.sum @@ -54,6 +54,8 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4 github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/expr-lang/expr v1.16.9 h1:WUAzmR0JNI9JCiF0/ewwHB1gmcGw5wW7nWt8gc6PpCI= +github.com/expr-lang/expr v1.16.9/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= diff --git a/router/internal/expr/expr.go b/router/internal/expr/expr.go new file mode 100644 index 0000000000..a2d1348167 --- /dev/null +++ b/router/internal/expr/expr.go @@ -0,0 +1,181 @@ +package expr + +import ( + "context" + "errors" + "fmt" + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/file" + "github.com/expr-lang/expr/vm" + "github.com/wundergraph/cosmo/router/pkg/authentication" + "net/http" + "net/url" + "reflect" +) + +/** +* Naming conventions: +* - Fields are named using camelCase +* - Methods are named using PascalCase (Required to be exported) +* - Methods should be exported through the interface to make the contract clear +* +* Principles: +* The Expr package is used to evaluate expressions in the context of the request or router. +* The user should never be able to mutate the context or any other application state. +* +* Recommendations: +* If possible function calls should be avoided in the expressions as they are much more expensive. +* See https://github.com/expr-lang/expr/issues/734 + */ + +// Context is the context for expressions parser when evaluating dynamic expressions +type Context struct { + Request Request `expr:"request"` +} + +// Request is the context for the request object in expressions. Be aware, that only value receiver methods +// are exported in the expr environment. This is because the expressions are evaluated in a read-only context. +type Request struct { + Auth RequestAuth `expr:"auth"` + URL RequestURL `expr:"url"` + Header RequestHeaders `expr:"header"` +} + +// RequestHeaders is the interface available for the headers object in expressions. +type RequestHeaders interface { + // Get returns the value of the header with the given key. If the header is not present, an empty string is returned. + // The key is case-insensitive and transformed to the canonical format. + Get(key string) string +} + +// RequestURL is the context for the URL object in expressions +// it is limited in scope to the URL object and its components. For convenience, the query parameters are parsed. +type RequestURL struct { + Method string `expr:"method"` + // Scheme is the scheme of the URL + Scheme string `expr:"scheme"` + // Host is the host of the URL + Host string `expr:"host"` + // Path is the path of the URL + Path string `expr:"path"` + // Query is the parsed query parameters + Query map[string]string `expr:"query"` +} + +// LoadRequest loads the request object into the context. +func LoadRequest(req *http.Request) Request { + r := Request{ + Header: req.Header, + } + + m, _ := url.ParseQuery(req.URL.RawQuery) + qv := make(map[string]string, len(m)) + + for k := range m { + qv[k] = m.Get(k) + } + + r.URL = RequestURL{ + Method: req.Method, + Scheme: req.URL.Scheme, + Host: req.URL.Host, + Path: req.URL.Path, + Query: qv, + } + + return r +} + +type RequestAuth struct { + IsAuthenticated bool `expr:"isAuthenticated"` + Type string `expr:"type"` + Claims map[string]any `expr:"claims"` + Scopes []string `expr:"scopes"` +} + +// LoadAuth loads the authentication context into the request object. +// Must only be called when the authentication was successful. +func LoadAuth(ctx context.Context) RequestAuth { + authCtx := authentication.FromContext(ctx) + if authCtx == nil { + return RequestAuth{} + } + + return RequestAuth{ + Type: authCtx.Authenticator(), + IsAuthenticated: true, + Claims: authCtx.Claims(), + Scopes: authCtx.Scopes(), + } +} + +// CompileBoolExpression compiles an expression and returns the program. It is used for expressions that return bool. +// The exprContext is used to provide the context for the expression evaluation. Not safe for concurrent use. +func CompileBoolExpression(s string) (*vm.Program, error) { + v, err := expr.Compile(s, expr.Env(Context{}), expr.AsBool()) + if err != nil { + return nil, handleExpressionError(err) + } + + return v, nil +} + +// CompileStringExpression compiles an expression and returns the program. It is used for expressions that return strings +// The exprContext is used to provide the context for the expression evaluation. Not safe for concurrent use. +func CompileStringExpression(s string) (*vm.Program, error) { + v, err := expr.Compile(s, expr.Env(Context{}), expr.AsKind(reflect.String)) + if err != nil { + return nil, handleExpressionError(err) + } + + return v, nil +} + +// ResolveStringExpression evaluates the expression and returns the result as a string. The exprContext is used to +// provide the context for the expression evaluation. Not safe for concurrent use. +func ResolveStringExpression(vm *vm.Program, ctx Context) (string, error) { + r, err := expr.Run(vm, ctx) + if err != nil { + return "", handleExpressionError(err) + } + + switch v := r.(type) { + case string: + return v, nil + default: + return "", fmt.Errorf("expected string, got %T", r) + } +} + +// ResolveBoolExpression evaluates the expression and returns the result as a bool. The exprContext is used to +// provide the context for the expression evaluation. Not safe for concurrent use. +func ResolveBoolExpression(vm *vm.Program, ctx Context) (bool, error) { + if vm == nil { + return false, nil + } + + r, err := expr.Run(vm, ctx) + if err != nil { + return false, handleExpressionError(err) + } + + switch v := r.(type) { + case bool: + return v, nil + default: + return false, fmt.Errorf("failed to run expression: expected bool, got %T", r) + } +} + +func handleExpressionError(err error) error { + if err == nil { + return nil + } + + var fileError *file.Error + if errors.As(err, &fileError) { + return fmt.Errorf("line %d, column %d: %s", fileError.Line, fileError.Column, fileError.Message) + } + + return err +} diff --git a/router/pkg/authentication/authentication.go b/router/pkg/authentication/authentication.go index 956cedea58..9c19cade57 100644 --- a/router/pkg/authentication/authentication.go +++ b/router/pkg/authentication/authentication.go @@ -93,12 +93,18 @@ func Authenticate(ctx context.Context, authenticators []Authenticator, p Provide joinedErrors = errors.Join(joinedErrors, err) continue } - if claims != nil { - return &authentication{ - authenticator: auth.Name(), - claims: claims, - }, nil + + // Claims is nil when no authentication information matched the authenticator. + // In that case, we continue to the next authenticator. + if claims == nil { + continue } + + // If authentication succeeds, we return the authentication for the first provider. + return &authentication{ + authenticator: auth.Name(), + claims: claims, + }, nil } // If no authentication failed error will be nil here, // even if to claims were found. diff --git a/router/pkg/authentication/http_header_authenticator.go b/router/pkg/authentication/http_header_authenticator.go index 4fd5bde794..7ee493ea8e 100644 --- a/router/pkg/authentication/http_header_authenticator.go +++ b/router/pkg/authentication/http_header_authenticator.go @@ -32,6 +32,7 @@ func (a *httpHeaderAuthenticator) Name() string { func (a *httpHeaderAuthenticator) Authenticate(ctx context.Context, p Provider) (Claims, error) { headers := p.AuthenticationHeaders() var errs error + for _, header := range a.headerNames { authorization := headers.Get(header) for _, prefix := range a.headerValuePrefixes { @@ -42,6 +43,11 @@ func (a *httpHeaderAuthenticator) Authenticate(ctx context.Context, p Provider) errs = errors.Join(errs, fmt.Errorf("could not validate token: %w", err)) continue } + // If claims is nil, we should return an empty Claims map to signal that the + // authentication was successful, but no claims were found. + if claims == nil { + claims = make(Claims) + } return claims, nil } } diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 44288fb2e9..a2318cc01d 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -334,10 +334,15 @@ type EngineExecutionConfiguration struct { EnableSubgraphFetchOperationName bool `envDefault:"false" env:"ENGINE_ENABLE_SUBGRAPH_FETCH_OPERATION_NAME" yaml:"enable_subgraph_fetch_operation_name"` } +type BlockOperationConfiguration struct { + Enabled bool `yaml:"enabled" envDefault:"false" env:"ENABLED"` + Condition string `yaml:"condition" env:"CONDITION"` +} + type SecurityConfiguration struct { - BlockMutations bool `yaml:"block_mutations" envDefault:"false" env:"SECURITY_BLOCK_MUTATIONS"` - BlockSubscriptions bool `yaml:"block_subscriptions" envDefault:"false" env:"SECURITY_BLOCK_SUBSCRIPTIONS"` - BlockNonPersistedOperations bool `yaml:"block_non_persisted_operations" envDefault:"false" env:"SECURITY_BLOCK_NON_PERSISTED_OPERATIONS"` + BlockMutations BlockOperationConfiguration `yaml:"block_mutations" envPrefix:"SECURITY_BLOCK_MUTATIONS_"` + BlockSubscriptions BlockOperationConfiguration `yaml:"block_subscriptions" envPrefix:"SECURITY_BLOCK_SUBSCRIPTIONS_"` + BlockNonPersistedOperations BlockOperationConfiguration `yaml:"block_non_persisted_operations" envPrefix:"SECURITY_BLOCK_NON_PERSISTED_OPERATIONS_"` ComplexityCalculationCache *ComplexityCalculationCache `yaml:"complexity_calculation_cache"` ComplexityLimits *ComplexityLimits `yaml:"complexity_limits"` DepthLimit *QueryDepthConfiguration `yaml:"depth_limit"` diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index cc7cd0fab7..2919806a37 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -824,7 +824,12 @@ "context_field": { "type": "string", "description": "The field name of the context from which to extract the value. The value is only extracted when a context is available otherwise the default value is used.", - "enum": ["operation_service_names", "graphql_error_codes", "graphql_error_service_names", "operation_sha256"] + "enum": [ + "operation_service_names", + "graphql_error_codes", + "graphql_error_service_names", + "operation_sha256" + ] } } } @@ -1668,54 +1673,54 @@ }, "cache_warmup": { "type": "object", - "description": "Cache Warmup pre-warms all caches (e.g. normalization, validation, planning) before accepting traffic.", - "additionalProperties": false, - "properties": { - "enabled": { - "type": "boolean", - "description": "Enable the cache warmup.", - "default": false - }, - "source": { - "type": "string", - "description": "The source of the cache warmup items can be filesystem, cdn (Cosmo), or s3.", - "enum": ["filesystem"] - }, - "workers": { - "type": "integer", - "description": "The number of workers for the cache warmup to run in parallel. Higher numbers decrease the time to warm up the cache but increase the load on the system.", - "default": 8 - }, - "items_per_second": { - "type": "integer", - "description": "The number of cache warmup items to process per second. Higher numbers decrease the time to warm up the cache but increase the load on the system.", - "default": 50 - }, - "timeout": { - "type": "string", - "description": "The timeout for warming up the cache. This can be used to limit the amount of time cache warming will block deploying a new config. The period is specified as a string with a number and a unit, e.g. 10ms, 1s, 1m, 1h. The supported units are 'ms', 's', 'm', 'h'.", - "default": "30s", - "duration": { - "minimum": "1s" - } - }, - "if": { - "properties": { - "source": { - "const": "filesystem" - } + "description": "Cache Warmup pre-warms all caches (e.g. normalization, validation, planning) before accepting traffic.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Enable the cache warmup.", + "default": false + }, + "source": { + "type": "string", + "description": "The source of the cache warmup items can be filesystem, cdn (Cosmo), or s3.", + "enum": ["filesystem"] + }, + "workers": { + "type": "integer", + "description": "The number of workers for the cache warmup to run in parallel. Higher numbers decrease the time to warm up the cache but increase the load on the system.", + "default": 8 + }, + "items_per_second": { + "type": "integer", + "description": "The number of cache warmup items to process per second. Higher numbers decrease the time to warm up the cache but increase the load on the system.", + "default": 50 + }, + "timeout": { + "type": "string", + "description": "The timeout for warming up the cache. This can be used to limit the amount of time cache warming will block deploying a new config. The period is specified as a string with a number and a unit, e.g. 10ms, 1s, 1m, 1h. The supported units are 'ms', 's', 'm', 'h'.", + "default": "30s", + "duration": { + "minimum": "1s" + } + }, + "if": { + "properties": { + "source": { + "const": "filesystem" } - }, - "then": { - "properties": { - "path": { - "type": "string", - "description": "The path to the directory containing the cache warmup items.", - "format": "file-path" - } + } + }, + "then": { + "properties": { + "path": { + "type": "string", + "description": "The path to the directory containing the cache warmup items.", + "format": "file-path" } } } + } }, "router_config_path": { "type": "string", @@ -1787,35 +1792,66 @@ "additionalProperties": false, "properties": { "block_mutations": { - "type": "boolean", - "default": false, - "description": "Block mutation Operations. If the value is true, the mutations are blocked." + "type": "object", + "description": "The configuration for blocking mutations.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Block mutation Operations. If the value is true, all operations are blocked. You can also specify a condition that is evaluated to determine if the mutation should be blocked." + }, + "condition": { + "type": "string", + "description": "The expression to evaluate if the mutation should be blocked. The expression is specified as a string and needs to evaluate to a boolean. Please see https://expr-lang.org/ for more information." + } + } }, "block_subscriptions": { - "type": "boolean", - "description": "Block subscription Operations. If the value is true, the subscriptions are blocked." + "type": "object", + "description": "The configuration for blocking subscriptions.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Block subscription Operations. If the value is true, all operations are blocked. You can also specify a condition that is evaluated to determine if the subscription should be blocked." + }, + "condition": { + "type": "string", + "description": "The expression to evaluate if the subscription should be blocked. The expression is specified as a string and needs to evaluate to a boolean. Please see https://expr-lang.org/ for more information." + } + } }, "block_non_persisted_operations": { - "type": "boolean", - "default": false, - "description": "Block non-persisted Operations. If the value is true, the non-persisted operations are blocked." + "type": "object", + "description": "The configuration for blocking non-persisted operations.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "description": "Block non-persisted operations. If the value is true, all Operations are blocked. You can also specify a condition that is evaluated to determine if the non-persisted operation should be blocked." + }, + "condition": { + "type": "string", + "description": "The expression to evaluate if the non-persisted operation should be blocked. The expression is specified as a string and needs to evaluate to a boolean. Please see https://expr-lang.org/ for more information." + } + } }, "complexity_calculation_cache": { "type": "object", - "description": "The configuration for the complexity calculation cache. The complexity calculation cache is used to cache the complexity calculation for the queries.", - "additionalProperties": false, - "properties": { - "enabled": { - "type": "boolean", - "default": true, - "description": "Enable the complexity calculation cache. If the value is true, the complexity calculation cache is enabled." - }, - "size": { - "type": "integer", - "default": 1024, - "description": "The size of the cache for the complexity calculation." - } + "description": "The configuration for the complexity calculation cache. The complexity calculation cache is used to cache the complexity calculation for the queries.", + "additionalProperties": false, + "properties": { + "enabled": { + "type": "boolean", + "default": true, + "description": "Enable the complexity calculation cache. If the value is true, the complexity calculation cache is enabled." + }, + "size": { + "type": "integer", + "default": 1024, + "description": "The size of the cache for the complexity calculation." } + } }, "complexity_limits": { "type": "object", @@ -2113,9 +2149,9 @@ "description": "The size of the validation cache." }, "enable_subgraph_fetch_operation_name": { - "type": "boolean", - "default": false, - "description": "Enable appending the operation name to subgraph fetches. This will ensure that the operation name will be included in the corresponding subgraph requests using the following format: $operationName__$subgraphName__$sequenceID." + "type": "boolean", + "default": false, + "description": "Enable appending the operation name to subgraph fetches. This will ensure that the operation name will be included in the corresponding subgraph requests using the following format: $operationName__$subgraphName__$sequenceID." } } }, @@ -2372,9 +2408,7 @@ "context_field": { "type": "string", "description": "The field name of the context from which to extract the value. The value is only extracted when a context is available otherwise the default value is used.", - "enum": [ - "operation_name" - ] + "enum": ["operation_name"] } } } diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index 4668bd6163..e1b863c623 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -218,9 +218,18 @@ "Subgraphs": null }, "SecurityConfiguration": { - "BlockMutations": false, - "BlockSubscriptions": false, - "BlockNonPersistedOperations": false, + "BlockMutations": { + "Enabled": false, + "Condition": "" + }, + "BlockSubscriptions": { + "Enabled": false, + "Condition": "" + }, + "BlockNonPersistedOperations": { + "Enabled": false, + "Condition": "" + }, "ComplexityCalculationCache": null, "ComplexityLimits": null, "DepthLimit": null diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index aad027e306..9aeca91f63 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -435,9 +435,18 @@ } }, "SecurityConfiguration": { - "BlockMutations": false, - "BlockSubscriptions": false, - "BlockNonPersistedOperations": false, + "BlockMutations": { + "Enabled": false, + "Condition": "" + }, + "BlockSubscriptions": { + "Enabled": false, + "Condition": "" + }, + "BlockNonPersistedOperations": { + "Enabled": false, + "Condition": "" + }, "ComplexityCalculationCache": { "Enabled": true, "CacheSize": 1024 diff --git a/router/pkg/watcher/watcher_test.go b/router/pkg/watcher/watcher_test.go index 0716b9c0a5..219449c57d 100644 --- a/router/pkg/watcher/watcher_test.go +++ b/router/pkg/watcher/watcher_test.go @@ -3,7 +3,6 @@ package watcher_test import ( "context" "errors" - "fmt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router/pkg/watcher" @@ -91,13 +90,9 @@ func TestCreate(t *testing.T) { events, err := getEvent(eventCh) require.NoError(t, err) - // For debugging - if len(events) > 1 { - fmt.Printf("event-1: op: %v, path: %v\n", events[0].Op, events[0].Path) - fmt.Printf("events-2: op: %v, path: %v\n", events[1].Op, events[1].Path) - } - - assert.Len(t, events, 1) + // In rare circumstances (Depends on the OS and the watcher implementation) + // we might get multiple events for the same file (create, update) + assert.True(t, len(events) == 1 || len(events) == 2, "expected 1 or 2 events, got %d", len(events)) assert.Equal(t, events[0].Path, tempFile) assert.Equal(t, events[0].Op, watcher.OpCreate) return true