Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Request caching + add ratelimit-over-408 + endpoint rewrite support #13

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
all:
CGO_ENABLED=0 go build -v
cheesycod marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
nirn-proxy
57 changes: 57 additions & 0 deletions lib/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package lib

import (
"net/http"
"time"
)

type CacheEntry struct {
Data []byte
CreatedAt time.Time
ExpiresIn *time.Duration
Headers http.Header
}

func (c *CacheEntry) Expired() bool {
if c.ExpiresIn == nil {
return false
cheesycod marked this conversation as resolved.
Show resolved Hide resolved
}
return time.Since(c.CreatedAt) > *c.ExpiresIn
}

type Cache struct {
entries map[string]*CacheEntry
}

func NewCache() *Cache {
return &Cache{
entries: make(map[string]*CacheEntry),
}
}

func (c *Cache) Get(key string) *CacheEntry {
entry, ok := c.entries[key]

if !ok {
return nil
}

if entry.Expired() {
c.Delete(key)
return nil
}

return entry
}

func (c *Cache) Set(key string, entry *CacheEntry) {
c.entries[key] = entry
}

func (c *Cache) Delete(key string) {
delete(c.entries, key)
}

func (c *Cache) Clear() {
c.entries = make(map[string]*CacheEntry)
}
171 changes: 160 additions & 11 deletions lib/discord.go
cheesycod marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
package lib

import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"github.com/sirupsen/logrus"
"io"
"io/ioutil"
"math"
"net"
"net/http"
"os"
"strconv"
"strings"
"time"

"github.com/sirupsen/logrus"
)

var client *http.Client
Expand All @@ -22,8 +24,58 @@ var contextTimeout time.Duration

var globalOverrideMap = make(map[string]uint)

var endpointCache = make(map[string]*Cache)

var disableRestLimitDetection = false

var cacheEndpoints = map[string]time.Duration{
"/api/users/@me": 10 * time.Minute,
"/api/v9/users/@me": 10 * time.Minute,
"/api/v10/users/@me": 10 * time.Minute,
"/api/gateway": 60 * time.Minute,
"/api/v9/gateway": 60 * time.Minute,
"/api/v10/gateway": 60 * time.Minute,
"/api/gateway/bot": 30 * time.Minute,
"/api/v9/gateway/bot": 30 * time.Minute,
"/api/v10/gateway/bot": 30 * time.Minute,
"/api/v9/applications/@me": 5 * time.Minute,
"/api/v10/applications/@me": 5 * time.Minute,
}

var wsProxy string
var ratelimitOver408 bool

func init() {
if len(os.Args) > 1 {
for _, arg := range os.Args[1:] {
argSplit := strings.SplitN(arg, "=", 2)

if len(argSplit) < 2 {
argSplit = append(argSplit, "")
}

switch argSplit[0] {
case "ws-proxy":
wsProxy = argSplit[1]
case "port":
os.Setenv("PORT", argSplit[1])
case "ratelimit-over-408":
ratelimitOver408 = true
default:
logrus.Fatal("Unknown argument: ", argSplit[0])
}
}
}

if os.Getenv("WS_PROXY") != "" {
wsProxy = os.Getenv("WS_PROXY")
}
cheesycod marked this conversation as resolved.
Show resolved Hide resolved

if os.Getenv("RATELIMIT_OVER_408") != "" {
ratelimitOver408 = os.Getenv("RATELIMIT_OVER_408") == "true"
}
cheesycod marked this conversation as resolved.
Show resolved Hide resolved
}

type BotGatewayResponse struct {
SessionStartLimit map[string]int `json:"session_start_limit"`
}
Expand Down Expand Up @@ -161,7 +213,7 @@ func GetBotGlobalLimit(token string, user *BotUserResponse) (uint, error) {
return 0, errors.New("500 on gateway/bot")
}

body, _ := ioutil.ReadAll(bot.Body)
body, _ := io.ReadAll(bot.Body)

var s BotGatewayResponse

Expand Down Expand Up @@ -200,7 +252,7 @@ func GetBotUser(token string) (*BotUserResponse, error) {
return nil, errors.New("500 on users/@me")
}

body, _ := ioutil.ReadAll(bot.Body)
body, _ := io.ReadAll(bot.Body)

var s BotUserResponse

Expand All @@ -213,6 +265,48 @@ func GetBotUser(token string) (*BotUserResponse, error) {
}

func doDiscordReq(ctx context.Context, path string, method string, body io.ReadCloser, header http.Header, query string) (*http.Response, error) {
identifier := ctx.Value("identifier")
if identifier == nil {
// Queues always have an identifier, if there's none in the context, we called the method from outside a queue
identifier = "Internal"
}

logger.Info(method, path+"?"+query)

identifierStr, ok := identifier.(string)

if ok {
cheesycod marked this conversation as resolved.
Show resolved Hide resolved
// Check endpoint cache
if endpointCache[identifierStr] != nil {
cheesycod marked this conversation as resolved.
Show resolved Hide resolved
cacheEntry := endpointCache[identifierStr].Get(path)

if cacheEntry != nil {
// Send cached response
logger.WithFields(logrus.Fields{
"method": method,
"path": path,
"status": "200 (cached)",
}).Debug("Discord request")

headers := cacheEntry.Headers.Clone()
headers.Set("X-Cached", "true")

// Set rl headers so bot won't be perpetually stuck
headers.Set("X-RateLimit-Limit", "5")
headers.Set("X-RateLimit-Remaining", "5")
headers.Set("X-RateLimit-Bucket", "cache")

return &http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewBuffer(cacheEntry.Data)),
Header: headers,
}, nil
}
} else {
endpointCache[identifierStr] = NewCache()
}
}

discordReq, err := http.NewRequestWithContext(ctx, method, "https://discord.com"+path+"?"+query, body)
if err != nil {
return nil, err
Expand All @@ -222,12 +316,6 @@ func doDiscordReq(ctx context.Context, path string, method string, body io.ReadC
startTime := time.Now()
discordResp, err := client.Do(discordReq)

identifier := ctx.Value("identifier")
if identifier == nil {
// Queues always have an identifier, if there's none in the context, we called the method from outside a queue
identifier = "Internal"
}

if err == nil {
route := GetMetricsPath(path)
status := discordResp.Status
Expand All @@ -242,6 +330,44 @@ func doDiscordReq(ctx context.Context, path string, method string, body io.ReadC

RequestHistogram.With(map[string]string{"route": route, "status": status, "method": method, "clientId": identifier.(string)}).Observe(elapsed)
}

if wsProxy != "" {
if path == "/api/gateway" || path == "/api/v9/gateway" || path == "/api/gateway/bot" || path == "/api/v10/gateway/bot" {
cheesycod marked this conversation as resolved.
Show resolved Hide resolved
var data map[string]any

err := json.NewDecoder(discordResp.Body).Decode(&data)

if err != nil {
return nil, err
}

data["url"] = wsProxy

bytes, err := json.Marshal(data)

if err != nil {
return nil, err
}

discordResp.Body = io.NopCloser(strings.NewReader(string(bytes)))
}
}

if expiry, ok := cacheEndpoints[path]; ok {
if discordResp.StatusCode == 200 {
body, _ := io.ReadAll(discordResp.Body)
endpointCache[identifierStr].Set(path, &CacheEntry{
Data: body,
CreatedAt: time.Now(),
ExpiresIn: &expiry,
Headers: discordResp.Header,
})

// Put body back into response
discordResp.Body = io.NopCloser(bytes.NewBuffer(body))
}
}

return discordResp, err
}

Expand All @@ -255,7 +381,30 @@ func ProcessRequest(ctx context.Context, item *QueueItem) (*http.Response, error

if err != nil {
if ctx.Err() == context.DeadlineExceeded {
res.WriteHeader(408)
if ratelimitOver408 {
res.WriteHeader(429)
res.Header().Add("Reset-After", "3")

// Set rl headers so bot won't be perpetually stuck
if res.Header().Get("X-RateLimit-Limit") == "" {
res.Header().Set("X-RateLimit-Limit", "5")
}
if res.Header().Get("X-RateLimit-Remaining") == "" {
res.Header().Set("X-RateLimit-Remaining", "0")
}

if res.Header().Get("X-RateLimit-Bucket") == "" {
res.Header().Set("X-RateLimit-Bucket", "proxyTimeout")
}

// Default to 'shared' so the bot doesn't think its
// against them
if res.Header().Get("X-RateLimit-Scope") == "" {
res.Header().Set("X-RateLimit-Scope", "shared")
}
} else {
res.WriteHeader(408)
}
} else {
res.WriteHeader(500)
}
Expand Down
9 changes: 5 additions & 4 deletions lib/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ package lib
import (
"context"
"errors"
"github.com/Clever/leakybucket"
"github.com/Clever/leakybucket/memory"
"github.com/sirupsen/logrus"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/Clever/leakybucket"
"github.com/Clever/leakybucket/memory"
"github.com/sirupsen/logrus"
)

type QueueItem struct {
Expand Down Expand Up @@ -82,7 +83,7 @@ func NewRequestQueue(processor func(ctx context.Context, item *QueueItem) (*http
identifier := "NoAuth"
if user != nil {
queueType = Bot
identifier = user.Username + "#" + user.Discrim
identifier = user.Id
cheesycod marked this conversation as resolved.
Show resolved Hide resolved
}

if queueType == Bearer {
Expand Down