diff --git a/gitignore b/gitignore new file mode 100644 index 0000000..4bddeaf --- /dev/null +++ b/gitignore @@ -0,0 +1 @@ +nirn-proxy diff --git a/lib/cache.go b/lib/cache.go new file mode 100644 index 0000000..378105c --- /dev/null +++ b/lib/cache.go @@ -0,0 +1,54 @@ +package lib + +import ( + "net/http" + "time" +) + +type CacheEntry struct { + Data []byte + CreatedAt time.Time + ExpiresIn time.Duration + Headers http.Header +} + +func (c *CacheEntry) Expired() bool { + return time.Since(c.CreatedAt) > c.ExpiresIn +} + +type Cache struct { + entries map[string]*CacheEntry +} + +func NewCache() *Cache { + return &Cache{ + entries: make(map[string]*CacheEntry), + } +} + +func (c *Cache) Get(key string) *CacheEntry { + entry, ok := c.entries[key] + + if !ok { + return nil + } + + if entry.Expired() { + c.Delete(key) + return nil + } + + return entry +} + +func (c *Cache) Set(key string, entry *CacheEntry) { + c.entries[key] = entry +} + +func (c *Cache) Delete(key string) { + delete(c.entries, key) +} + +func (c *Cache) Clear() { + c.entries = make(map[string]*CacheEntry) +} diff --git a/lib/discord.go b/lib/discord.go index cf99f86..74b0904 100644 --- a/lib/discord.go +++ b/lib/discord.go @@ -1,19 +1,22 @@ package lib import ( + "bytes" "context" "crypto/tls" "encoding/json" "errors" - "github.com/sirupsen/logrus" "io" - "io/ioutil" "math" "net" "net/http" + "os" + "path/filepath" "strconv" "strings" "time" + + "github.com/sirupsen/logrus" ) var client *http.Client @@ -22,8 +25,98 @@ var contextTimeout time.Duration var globalOverrideMap = make(map[string]uint) +var endpointCache = make(map[string]*Cache) + var disableRestLimitDetection = false +// List of endpoints to cache and their expiry times +var useEndpointCache bool +var cacheEndpoints = map[string]time.Duration{ + "/api/users/@me": 10 * time.Minute, + "/api/v*/users/@me": 10 * time.Minute, + "/api/gateway": 60 * time.Minute, + "/api/v*/gateway": 60 * time.Minute, + "/api/gateway/*": 30 * time.Minute, + "/api/v*/gateway/*": 30 * time.Minute, + "/api/v*/applications/@me": 5 * time.Minute, +} + +// In some cases, we may want to transparently rewrite endpoints +// +// For example, when using a gateway proxy, the proxy may provide its own /api/gateway/bot endpoint +// +// This allows transparently rewriting the endpoint to the proxy's +var endpointRewrite = map[string]string{} + +var wsProxy string +var ratelimitOver408 bool + +func init() { + if len(os.Args) > 1 { + for _, arg := range os.Args[1:] { + argSplit := strings.SplitN(arg, "=", 2) + + if len(argSplit) < 2 { + argSplit = append(argSplit, "") + } + + switch argSplit[0] { + case "ws-proxy": + wsProxy = argSplit[1] + case "port": + os.Setenv("PORT", argSplit[1]) + case "ratelimit-over-408": + ratelimitOver408 = true + case "use-endpoint-cache": + useEndpointCache = true + case "cache-endpoints": + if argSplit[1] == "" { + continue + } + + if argSplit[1] == "false" { + cacheEndpoints = make(map[string]time.Duration) + } else { + var endpoints map[string]time.Duration + + err := json.Unmarshal([]byte(argSplit[1]), &endpoints) + + if err != nil { + logrus.Fatal("Failed to parse cache-endpoints: ", err) + } + + cacheEndpoints = endpoints + } + case "endpoint-rewrite": + for _, rewrite := range strings.Split(argSplit[1], ",") { + // split by '->' + rewriteSplit := strings.Split(rewrite, "@") + + if len(rewriteSplit) != 2 { + logrus.Fatal("Invalid endpoint rewrite: ", rewrite) + } + + endpointRewrite[rewriteSplit[0]] = rewriteSplit[1] + } + default: + logrus.Fatal("Unknown argument: ", argSplit[0]) + } + } + } + + if wsProxy == "" { + wsProxy = EnvGet("WS_PROXY", "") + } + + if !ratelimitOver408 { + ratelimitOver408 = EnvGetBool("RATELIMIT_OVER_408", false) + } + + if !useEndpointCache { + useEndpointCache = EnvGetBool("USE_ENDPOINT_CACHE", false) + } +} + type BotGatewayResponse struct { SessionStartLimit map[string]int `json:"session_start_limit"` } @@ -161,7 +254,7 @@ func GetBotGlobalLimit(token string, user *BotUserResponse) (uint, error) { return 0, errors.New("500 on gateway/bot") } - body, _ := ioutil.ReadAll(bot.Body) + body, _ := io.ReadAll(bot.Body) var s BotGatewayResponse @@ -200,7 +293,7 @@ func GetBotUser(token string) (*BotUserResponse, error) { return nil, errors.New("500 on users/@me") } - body, _ := ioutil.ReadAll(bot.Body) + body, _ := io.ReadAll(bot.Body) var s BotUserResponse @@ -213,7 +306,63 @@ func GetBotUser(token string) (*BotUserResponse, error) { } func doDiscordReq(ctx context.Context, path string, method string, body io.ReadCloser, header http.Header, query string) (*http.Response, error) { - discordReq, err := http.NewRequestWithContext(ctx, method, "https://discord.com"+path+"?"+query, body) + identifier := ctx.Value("identifier") + if identifier == nil { + identifier = "Internal" + } + + logger.Info(method, " ", path+"?"+query) + + identifierStr, ok := identifier.(string) + + if ok { + if useEndpointCache && identifier != "internal" { + cache, ok := endpointCache[identifierStr] + + if !ok { + endpointCache[identifierStr] = NewCache() + cache = endpointCache[identifierStr] + } + + // Check endpoint cache + cacheEntry := cache.Get(path) + + if cacheEntry != nil { + // Send cached response + logger.WithFields(logrus.Fields{ + "method": method, + "path": path, + "status": "200 (cached)", + }).Debug("Discord request") + + headers := cacheEntry.Headers.Clone() + headers.Set("X-Cached", "true") + + // Set rl headers so bot won't be perpetually stuck + headers.Set("X-RateLimit-Limit", "5") + headers.Set("X-RateLimit-Remaining", "5") + headers.Set("X-RateLimit-Bucket", "cache") + + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewBuffer(cacheEntry.Data)), + Header: headers, + }, nil + } + } + } + + // Check for a rewrite + var urlBase = "https://discord.com" + for rw := range endpointRewrite { + if ok, _ := filepath.Match(rw, path); ok { + urlBase = endpointRewrite[rw] + break + + } + } + + discordReq, err := http.NewRequestWithContext(ctx, method, urlBase+path+"?"+query, body) if err != nil { return nil, err } @@ -222,12 +371,6 @@ func doDiscordReq(ctx context.Context, path string, method string, body io.ReadC startTime := time.Now() discordResp, err := client.Do(discordReq) - identifier := ctx.Value("identifier") - if identifier == nil { - // Queues always have an identifier, if there's none in the context, we called the method from outside a queue - identifier = "Internal" - } - if err == nil { route := GetMetricsPath(path) status := discordResp.Status @@ -242,6 +385,59 @@ func doDiscordReq(ctx context.Context, path string, method string, body io.ReadC RequestHistogram.With(map[string]string{"route": route, "status": status, "method": method, "clientId": identifier.(string)}).Observe(elapsed) } + + if wsProxy != "" && discordResp.StatusCode == 200 { + var isGwProxyUrl bool + + if strings.HasSuffix(path, "/gateway") || strings.HasSuffix(path, "/gateway/bot") { + isGwProxyUrl = true + } + + if isGwProxyUrl { + var data map[string]any + + err := json.NewDecoder(discordResp.Body).Decode(&data) + + if err != nil { + return nil, err + } + + data["url"] = wsProxy + + bytes, err := json.Marshal(data) + + if err != nil { + return nil, err + } + + discordResp.Body = io.NopCloser(strings.NewReader(string(bytes))) + } + } + + if useEndpointCache { + var expiry *time.Duration + + for endpoint, exp := range cacheEndpoints { + if ok, _ := filepath.Match(endpoint, path); ok { + expiry = &exp + break + } + } + + if expiry != nil && discordResp.StatusCode == 200 { + body, _ := io.ReadAll(discordResp.Body) + endpointCache[identifierStr].Set(path, &CacheEntry{ + Data: body, + CreatedAt: time.Now(), + ExpiresIn: *expiry, + Headers: discordResp.Header, + }) + + // Put body back into response + discordResp.Body = io.NopCloser(bytes.NewBuffer(body)) + } + } + return discordResp, err } @@ -255,7 +451,30 @@ func ProcessRequest(ctx context.Context, item *QueueItem) (*http.Response, error if err != nil { if ctx.Err() == context.DeadlineExceeded { - res.WriteHeader(408) + if ratelimitOver408 { + res.WriteHeader(429) + res.Header().Add("Reset-After", "3") + + // Set rl headers so bot won't be perpetually stuck + if res.Header().Get("X-RateLimit-Limit") == "" { + res.Header().Set("X-RateLimit-Limit", "5") + } + if res.Header().Get("X-RateLimit-Remaining") == "" { + res.Header().Set("X-RateLimit-Remaining", "0") + } + + if res.Header().Get("X-RateLimit-Bucket") == "" { + res.Header().Set("X-RateLimit-Bucket", "proxyTimeout") + } + + // Default to 'shared' so the bot doesn't think its + // against them + if res.Header().Get("X-RateLimit-Scope") == "" { + res.Header().Set("X-RateLimit-Scope", "shared") + } + } else { + res.WriteHeader(408) + } } else { res.WriteHeader(500) } @@ -279,3 +498,4 @@ func ProcessRequest(ctx context.Context, item *QueueItem) (*http.Response, error return discordResp, nil } + diff --git a/lib/queue.go b/lib/queue.go index 203383a..df0be2c 100644 --- a/lib/queue.go +++ b/lib/queue.go @@ -3,15 +3,16 @@ package lib import ( "context" "errors" - "github.com/Clever/leakybucket" - "github.com/Clever/leakybucket/memory" - "github.com/sirupsen/logrus" "net/http" "strconv" "strings" "sync" "sync/atomic" "time" + + "github.com/Clever/leakybucket" + "github.com/Clever/leakybucket/memory" + "github.com/sirupsen/logrus" ) type QueueItem struct {