From ac0f6c9ea86f9c8ce66346e4bfb1a3109f79f2ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96mer=20Faruk=20Irmak?= Date: Tue, 19 Dec 2023 15:04:03 +0300 Subject: [PATCH] Allow using an API key when querying feeder/gateway (#1579) --- clients/feeder/feeder.go | 25 +++++++++++++++++++------ clients/gateway/gateway.go | 20 +++++++++++++++++--- cmd/juno/juno.go | 4 ++++ cmd/juno/juno_test.go | 3 ++- node/node.go | 7 +++++-- 5 files changed, 47 insertions(+), 12 deletions(-) diff --git a/clients/feeder/feeder.go b/clients/feeder/feeder.go index 8bfb297745..af3c519608 100644 --- a/clients/feeder/feeder.go +++ b/clients/feeder/feeder.go @@ -17,6 +17,8 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/starknet" "github.com/NethermindEth/juno/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type Backoff func(wait time.Duration) time.Duration @@ -30,6 +32,7 @@ type Client struct { minWait time.Duration log utils.SimpleLogger userAgent string + apiKey string listener EventListener } @@ -73,6 +76,11 @@ func (c *Client) WithTimeout(t time.Duration) *Client { return c } +func (c *Client) WithAPIKey(key string) *Client { + c.apiKey = key + return c +} + func ExponentialBackoff(wait time.Duration) time.Duration { return wait * 2 } @@ -83,11 +91,12 @@ func NopBackoff(d time.Duration) time.Duration { // NewTestClient returns a client and a function to close a test server. func NewTestClient(t *testing.T, network utils.Network) *Client { - srv := newTestServer(network) + srv := newTestServer(t, network) t.Cleanup(srv.Close) ua := "Juno/v0.0.1-test Starknet Implementation" + apiKey := "API_KEY" - c := NewClient(srv.URL).WithBackoff(NopBackoff).WithMaxRetries(0).WithUserAgent(ua) + c := NewClient(srv.URL).WithBackoff(NopBackoff).WithMaxRetries(0).WithUserAgent(ua).WithAPIKey(apiKey) c.client = &http.Client{ Transport: &http.Transport{ // On macOS tests often fail with the following error: @@ -106,7 +115,7 @@ func NewTestClient(t *testing.T, network utils.Network) *Client { return c } -func newTestServer(network utils.Network) *httptest.Server { +func newTestServer(t *testing.T, network utils.Network) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { queryMap, err := url.ParseQuery(r.URL.RawQuery) if err != nil { @@ -114,10 +123,11 @@ func newTestServer(network utils.Network) *httptest.Server { return } + assert.Equal(t, []string{"API_KEY"}, r.Header["X-Throttling-Bypass"]) + assert.Equal(t, []string{"Juno/v0.0.1-test Starknet Implementation"}, r.Header["User-Agent"]) + wd, err := os.Getwd() - if err != nil { - panic(err) - } + require.NoError(t, err) base := wd[:strings.LastIndex(wd, "juno")+4] queryArg := "" @@ -231,6 +241,9 @@ func (c *Client) get(ctx context.Context, queryURL string) (io.ReadCloser, error if c.userAgent != "" { req.Header.Set("User-Agent", c.userAgent) } + if c.apiKey != "" { + req.Header.Set("X-Throttling-Bypass", c.apiKey) + } reqTimer := time.Now() res, err = c.client.Do(req) diff --git a/clients/gateway/gateway.go b/clients/gateway/gateway.go index 5b0475fea5..425d5face6 100644 --- a/clients/gateway/gateway.go +++ b/clients/gateway/gateway.go @@ -15,6 +15,7 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/utils" + "github.com/stretchr/testify/assert" ) var ( @@ -40,6 +41,7 @@ type Client struct { timeout time.Duration log utils.SimpleLogger userAgent string + apiKey string } func (c *Client) WithUserAgent(ua string) *Client { @@ -47,18 +49,27 @@ func (c *Client) WithUserAgent(ua string) *Client { return c } +func (c *Client) WithAPIKey(key string) *Client { + c.apiKey = key + return c +} + // NewTestClient returns a client and a function to close a test server. func NewTestClient(t *testing.T) *Client { - srv := newTestServer() + srv := newTestServer(t) ua := "Juno/v0.0.1-test Starknet Implementation" + apiKey := "API_KEY" t.Cleanup(srv.Close) - return NewClient(srv.URL, utils.NewNopZapLogger()).WithUserAgent(ua) + return NewClient(srv.URL, utils.NewNopZapLogger()).WithUserAgent(ua).WithAPIKey(apiKey) } -func newTestServer() *httptest.Server { +func newTestServer(t *testing.T) *httptest.Server { // As this is a test sever we are mimic response for one good and one bad request. return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, []string{"API_KEY"}, r.Header["X-Throttling-Bypass"]) + assert.Equal(t, []string{"Juno/v0.0.1-test Starknet Implementation"}, r.Header["User-Agent"]) + b, err := io.ReadAll(r.Body) if err != nil { w.WriteHeader(http.StatusBadRequest) @@ -142,6 +153,9 @@ func (c *Client) doPost(ctx context.Context, url string, data any) (*http.Respon if c.userAgent != "" { req.Header.Set("User-Agent", c.userAgent) } + if c.apiKey != "" { + req.Header.Set("X-Throttling-Bypass", c.apiKey) + } return c.client.Do(req) } diff --git a/cmd/juno/juno.go b/cmd/juno/juno.go index 551c06d96a..f498c27ec1 100644 --- a/cmd/juno/juno.go +++ b/cmd/juno/juno.go @@ -63,6 +63,7 @@ const ( remoteDBF = "remote-db" rpcMaxBlockScanF = "rpc-max-block-scan" dbCacheSizeF = "db-cache-size" + gwAPIKeyF = "gw-api-key" //nolint: gosec defaultConfig = "" defaulHost = "localhost" @@ -85,6 +86,7 @@ const ( defaultRemoteDB = "" defaultRPCMaxBlockScan = math.MaxUint defaultCacheSizeMb = 8 + defaultGwAPIKey = "" configFlagUsage = "The yaml configuration file." logLevelFlagUsage = "Options: debug, info, warn, error." @@ -117,6 +119,7 @@ const ( remoteDBUsage = "gRPC URL of a remote Juno node" rpcMaxBlockScanUsage = "Maximum number of blocks scanned in single starknet_getEvents call" dbCacheSizeUsage = "Determines the amount of memory (in megabytes) allocated for caching data in the database." + gwAPIKeyUsage = "API key for gateway endpoints to avoid throttling" //nolint: gosec ) var Version string @@ -245,6 +248,7 @@ func NewCmd(config *node.Config, run func(*cobra.Command, []string) error) *cobr junoCmd.Flags().String(remoteDBF, defaultRemoteDB, remoteDBUsage) junoCmd.Flags().Uint(rpcMaxBlockScanF, defaultRPCMaxBlockScan, rpcMaxBlockScanUsage) junoCmd.Flags().Uint(dbCacheSizeF, defaultCacheSizeMb, dbCacheSizeUsage) + junoCmd.Flags().String(gwAPIKeyF, defaultGwAPIKey, gwAPIKeyUsage) return junoCmd } diff --git a/cmd/juno/juno_test.go b/cmd/juno/juno_test.go index 0386c84adf..d788dfc094 100644 --- a/cmd/juno/juno_test.go +++ b/cmd/juno/juno_test.go @@ -462,7 +462,7 @@ network: goerli }, "some setting set in both env variables and config file": { cfgFileContents: `db-path: /home/file/.juno`, - env: []string{"JUNO_DB_PATH", "/home/env/.juno"}, + env: []string{"JUNO_DB_PATH", "/home/env/.juno", "JUNO_GW_API_KEY", "apikey"}, expectedConfig: &node.Config{ LogLevel: defaultLogLevel, HTTP: defaultHTTP, @@ -488,6 +488,7 @@ network: goerli MaxVMQueue: 2 * defaultMaxVMs, RPCMaxBlockScan: defaultRPCMaxBlockScan, DBCacheSize: defaultMaxCacheSize, + GatewayAPIKey: "apikey", }, }, } diff --git a/node/node.go b/node/node.go index 29deb387d5..d6d3369f19 100644 --- a/node/node.go +++ b/node/node.go @@ -77,6 +77,8 @@ type Config struct { RPCMaxBlockScan uint `mapstructure:"rpc-max-block-scan"` DBCacheSize uint `mapstructure:"db-cache-size"` + + GatewayAPIKey string `mapstructure:"gw-api-key"` } type Node struct { @@ -119,10 +121,11 @@ func New(cfg *Config, version string) (*Node, error) { //nolint:gocyclo,funlen chain := blockchain.New(database, cfg.Network, log) feederClientTimeout := 5 * time.Second - client := feeder.NewClient(cfg.Network.FeederURL()).WithUserAgent(ua).WithLogger(log).WithTimeout(feederClientTimeout) + client := feeder.NewClient(cfg.Network.FeederURL()).WithUserAgent(ua).WithLogger(log). + WithTimeout(feederClientTimeout).WithAPIKey(cfg.GatewayAPIKey) synchronizer := sync.New(chain, adaptfeeder.New(client), log, cfg.PendingPollInterval, dbIsRemote) services = append(services, synchronizer) - gatewayClient := gateway.NewClient(cfg.Network.GatewayURL(), log).WithUserAgent(ua) + gatewayClient := gateway.NewClient(cfg.Network.GatewayURL(), log).WithUserAgent(ua).WithAPIKey(cfg.GatewayAPIKey) throttledVM := NewThrottledVM(vm.New(log), cfg.MaxVMs, int32(cfg.MaxVMQueue)) rpcHandler := rpc.New(chain, synchronizer, cfg.Network, gatewayClient, client, throttledVM, version, log)