diff --git a/base_limiter.go b/base_limiter.go index f4c2dda..1a5b950 100644 --- a/base_limiter.go +++ b/base_limiter.go @@ -1,6 +1,7 @@ package rlutils import ( + "fmt" "net/http" "path/filepath" "strings" @@ -58,3 +59,19 @@ func (l *BaseLimiter) isTargetExtensions(r *http.Request) bool { } return false } +func validateKey(key string) error { + for _, k := range []string{"remote_addr", "host"} { + if k == key { + return nil + } + } + return fmt.Errorf("invalid key: %s", key) +} + +func fillKey(r *http.Request, key string) string { + if key == "remote_addr" { + remoteAddr := strings.Split(r.RemoteAddr, ":")[0] + return remoteAddr + } + return r.Host +} diff --git a/coverage.out b/coverage.out index 072c076..b0b120b 100644 --- a/coverage.out +++ b/coverage.out @@ -1,14 +1,21 @@ mode: count -github.com/2manymws/rlutils/base_limiter.go:26.15,35.2 2 20 -github.com/2manymws/rlutils/base_limiter.go:37.70,39.2 1 0 -github.com/2manymws/rlutils/base_limiter.go:41.37,43.2 1 0 -github.com/2manymws/rlutils/base_limiter.go:45.61,47.2 1 21 -github.com/2manymws/rlutils/base_limiter.go:49.64,50.34 1 27 -github.com/2manymws/rlutils/base_limiter.go:50.34,52.3 1 22 -github.com/2manymws/rlutils/base_limiter.go:53.2,54.41 2 5 -github.com/2manymws/rlutils/base_limiter.go:54.41,55.40 1 8 -github.com/2manymws/rlutils/base_limiter.go:55.40,57.4 1 3 -github.com/2manymws/rlutils/base_limiter.go:59.2,59.14 1 2 +github.com/2manymws/rlutils/base_limiter.go:27.15,36.2 2 22 +github.com/2manymws/rlutils/base_limiter.go:38.70,40.2 1 0 +github.com/2manymws/rlutils/base_limiter.go:42.37,44.2 1 0 +github.com/2manymws/rlutils/base_limiter.go:46.61,48.2 1 23 +github.com/2manymws/rlutils/base_limiter.go:50.64,51.34 1 29 +github.com/2manymws/rlutils/base_limiter.go:51.34,53.3 1 24 +github.com/2manymws/rlutils/base_limiter.go:54.2,55.41 2 5 +github.com/2manymws/rlutils/base_limiter.go:55.41,56.40 1 8 +github.com/2manymws/rlutils/base_limiter.go:56.40,58.4 1 3 +github.com/2manymws/rlutils/base_limiter.go:60.2,60.14 1 2 +github.com/2manymws/rlutils/base_limiter.go:62.36,63.52 1 12 +github.com/2manymws/rlutils/base_limiter.go:63.52,64.15 1 22 +github.com/2manymws/rlutils/base_limiter.go:64.15,66.4 1 12 +github.com/2manymws/rlutils/base_limiter.go:68.2,68.43 1 0 +github.com/2manymws/rlutils/base_limiter.go:71.50,72.26 1 7 +github.com/2manymws/rlutils/base_limiter.go:72.26,75.3 2 2 +github.com/2manymws/rlutils/base_limiter.go:76.2,76.15 1 5 github.com/2manymws/rlutils/country_limiter.go:36.28,38.16 2 6 github.com/2manymws/rlutils/country_limiter.go:38.16,40.3 1 0 github.com/2manymws/rlutils/country_limiter.go:41.2,44.30 3 6 @@ -39,15 +46,17 @@ github.com/2manymws/rlutils/country_limiter.go:118.69,126.16 3 10 github.com/2manymws/rlutils/country_limiter.go:126.16,128.3 1 2 github.com/2manymws/rlutils/country_limiter.go:130.2,130.36 1 8 github.com/2manymws/rlutils/country_limiter.go:133.73,135.2 1 0 -github.com/2manymws/rlutils/get_parameter_limiter.go:23.24,33.2 1 3 -github.com/2manymws/rlutils/get_parameter_limiter.go:35.45,37.2 1 0 -github.com/2manymws/rlutils/get_parameter_limiter.go:39.71,40.27 1 3 -github.com/2manymws/rlutils/get_parameter_limiter.go:40.27,42.3 1 0 -github.com/2manymws/rlutils/get_parameter_limiter.go:43.2,43.36 1 3 -github.com/2manymws/rlutils/get_parameter_limiter.go:43.36,44.32 1 3 -github.com/2manymws/rlutils/get_parameter_limiter.go:44.32,50.4 1 2 -github.com/2manymws/rlutils/get_parameter_limiter.go:53.2,53.36 1 1 -github.com/2manymws/rlutils/get_parameter_limiter.go:56.78,58.2 1 0 +github.com/2manymws/rlutils/get_parameter_limiter.go:25.33,27.16 2 4 +github.com/2manymws/rlutils/get_parameter_limiter.go:27.16,29.3 1 0 +github.com/2manymws/rlutils/get_parameter_limiter.go:30.2,39.8 1 4 +github.com/2manymws/rlutils/get_parameter_limiter.go:42.45,44.2 1 0 +github.com/2manymws/rlutils/get_parameter_limiter.go:46.71,47.27 1 4 +github.com/2manymws/rlutils/get_parameter_limiter.go:47.27,49.3 1 0 +github.com/2manymws/rlutils/get_parameter_limiter.go:50.2,50.36 1 4 +github.com/2manymws/rlutils/get_parameter_limiter.go:50.36,51.32 1 4 +github.com/2manymws/rlutils/get_parameter_limiter.go:51.32,57.4 1 3 +github.com/2manymws/rlutils/get_parameter_limiter.go:60.2,60.36 1 1 +github.com/2manymws/rlutils/get_parameter_limiter.go:63.78,65.2 1 0 github.com/2manymws/rlutils/host_limiter.go:21.16,30.2 1 1 github.com/2manymws/rlutils/host_limiter.go:32.37,34.2 1 0 github.com/2manymws/rlutils/host_limiter.go:36.63,37.27 1 1 @@ -60,25 +69,27 @@ github.com/2manymws/rlutils/ip_limiter.go:37.61,38.27 1 2 github.com/2manymws/rlutils/ip_limiter.go:38.27,40.3 1 0 github.com/2manymws/rlutils/ip_limiter.go:41.2,46.8 2 2 github.com/2manymws/rlutils/ip_limiter.go:49.68,51.2 1 0 -github.com/2manymws/rlutils/request_path_limiter.go:34.23,49.2 1 7 -github.com/2manymws/rlutils/request_path_limiter.go:51.44,53.2 1 0 -github.com/2manymws/rlutils/request_path_limiter.go:55.70,56.27 1 7 -github.com/2manymws/rlutils/request_path_limiter.go:56.27,58.3 1 0 -github.com/2manymws/rlutils/request_path_limiter.go:59.2,66.4 1 7 -github.com/2manymws/rlutils/request_path_limiter.go:66.4,67.23 1 18 -github.com/2manymws/rlutils/request_path_limiter.go:67.23,68.33 1 18 -github.com/2manymws/rlutils/request_path_limiter.go:68.33,69.31 1 18 -github.com/2manymws/rlutils/request_path_limiter.go:69.31,80.8 2 9 -github.com/2manymws/rlutils/request_path_limiter.go:80.8,81.31 1 21 -github.com/2manymws/rlutils/request_path_limiter.go:81.31,82.42 1 6 -github.com/2manymws/rlutils/request_path_limiter.go:82.42,83.40 1 6 -github.com/2manymws/rlutils/request_path_limiter.go:83.40,85.15 2 6 -github.com/2manymws/rlutils/request_path_limiter.go:88.8,88.19 1 6 -github.com/2manymws/rlutils/request_path_limiter.go:88.19,89.14 1 6 -github.com/2manymws/rlutils/request_path_limiter.go:93.6,93.18 1 9 -github.com/2manymws/rlutils/request_path_limiter.go:93.18,99.7 1 3 -github.com/2manymws/rlutils/request_path_limiter.go:105.2,105.36 1 4 -github.com/2manymws/rlutils/request_path_limiter.go:108.77,110.2 1 0 +github.com/2manymws/rlutils/request_path_limiter.go:36.32,38.16 2 8 +github.com/2manymws/rlutils/request_path_limiter.go:38.16,40.3 1 0 +github.com/2manymws/rlutils/request_path_limiter.go:42.2,56.8 1 8 +github.com/2manymws/rlutils/request_path_limiter.go:59.44,61.2 1 0 +github.com/2manymws/rlutils/request_path_limiter.go:63.70,64.27 1 8 +github.com/2manymws/rlutils/request_path_limiter.go:64.27,66.3 1 0 +github.com/2manymws/rlutils/request_path_limiter.go:67.2,74.4 1 8 +github.com/2manymws/rlutils/request_path_limiter.go:74.4,75.23 1 21 +github.com/2manymws/rlutils/request_path_limiter.go:75.23,76.33 1 21 +github.com/2manymws/rlutils/request_path_limiter.go:76.33,77.31 1 21 +github.com/2manymws/rlutils/request_path_limiter.go:77.31,88.8 2 10 +github.com/2manymws/rlutils/request_path_limiter.go:88.8,89.31 1 24 +github.com/2manymws/rlutils/request_path_limiter.go:89.31,90.42 1 6 +github.com/2manymws/rlutils/request_path_limiter.go:90.42,91.40 1 6 +github.com/2manymws/rlutils/request_path_limiter.go:91.40,93.15 2 6 +github.com/2manymws/rlutils/request_path_limiter.go:96.8,96.19 1 6 +github.com/2manymws/rlutils/request_path_limiter.go:96.19,97.14 1 6 +github.com/2manymws/rlutils/request_path_limiter.go:101.6,101.18 1 10 +github.com/2manymws/rlutils/request_path_limiter.go:101.18,107.7 1 4 +github.com/2manymws/rlutils/request_path_limiter.go:113.2,113.36 1 4 +github.com/2manymws/rlutils/request_path_limiter.go:116.77,118.2 1 0 github.com/2manymws/rlutils/user_agent_limiter.go:24.21,34.2 1 1 github.com/2manymws/rlutils/user_agent_limiter.go:36.42,38.2 1 0 github.com/2manymws/rlutils/user_agent_limiter.go:40.68,41.27 1 2 diff --git a/get_parameter_limiter.go b/get_parameter_limiter.go index 8090210..b2f4004 100644 --- a/get_parameter_limiter.go +++ b/get_parameter_limiter.go @@ -9,6 +9,7 @@ import ( type GetParameterLimiter struct { getParameters map[string]string + key string BaseLimiter } @@ -18,18 +19,24 @@ func NewGetParameterLimiter( getParameters map[string]string, reqLimit int, windowLen time.Duration, + key string, targetExtensions []string, onRequestLimit func(*rl.Context, string) http.HandlerFunc, -) *GetParameterLimiter { +) (*GetParameterLimiter, error) { + err := validateKey(key) + if err != nil { + return nil, err + } return &GetParameterLimiter{ getParameters: getParameters, + key: key, BaseLimiter: NewBaseLimiter( reqLimit, windowLen, targetExtensions, onRequestLimit, ), - } + }, nil } func (l *GetParameterLimiter) Name() string { @@ -43,7 +50,7 @@ func (l *GetParameterLimiter) Rule(r *http.Request) (*rl.Rule, error) { for k, v := range l.getParameters { if r.URL.Query().Get(k) == v { return &rl.Rule{ - Key: r.Host + "/" + k + "=" + v, + Key: fillKey(r, l.key) + "/" + k + "=" + v, ReqLimit: l.reqLimit, WindowLen: l.windowLen, }, nil diff --git a/get_parameter_limiter_test.go b/get_parameter_limiter_test.go index ef7e3c7..b50a561 100644 --- a/get_parameter_limiter_test.go +++ b/get_parameter_limiter_test.go @@ -17,6 +17,7 @@ func TestGetParameterLimiter(t *testing.T) { name string getParameters map[string]string queryString string + key string expectedToBeLimited bool expectedKey string }{ @@ -26,6 +27,7 @@ func TestGetParameterLimiter(t *testing.T) { "token": "123456", }, queryString: "?token=123456", + key: "host", expectedToBeLimited: true, expectedKey: "example.com/token=123456", }, @@ -35,6 +37,7 @@ func TestGetParameterLimiter(t *testing.T) { "token": "123456", }, queryString: "?token=abcdef", + key: "host", expectedToBeLimited: false, }, { @@ -44,9 +47,20 @@ func TestGetParameterLimiter(t *testing.T) { "sessionId": "ABCDEF", }, queryString: "?token=123456&sessionId=XYZ", + key: "host", expectedToBeLimited: true, expectedKey: "example.com/token=123456", }, + { + name: "Request with matching get parameter should be limited with remote ip", + getParameters: map[string]string{ + "token": "123456", + }, + queryString: "?token=123456", + key: "remote_addr", + expectedToBeLimited: true, + expectedKey: "127.0.0.1/token=123456", + }, } for _, tc := range cases { @@ -58,10 +72,11 @@ func TestGetParameterLimiter(t *testing.T) { reqLimit := 5 windowLen := time.Minute - limiter := NewGetParameterLimiter( + limiter, _ := NewGetParameterLimiter( tc.getParameters, reqLimit, windowLen, + tc.key, nil, nil, ) @@ -70,6 +85,7 @@ func TestGetParameterLimiter(t *testing.T) { limiter.Counter = mockCounter req := httptest.NewRequest(http.MethodGet, "http://example.com"+tc.queryString, nil) + req.RemoteAddr = "127.0.0.1:12345" rule, err := limiter.Rule(req) assert.NoError(t, err) diff --git a/request_path_limiter.go b/request_path_limiter.go index 36741c9..a2884a1 100644 --- a/request_path_limiter.go +++ b/request_path_limiter.go @@ -15,6 +15,7 @@ type RequestPathLimiter struct { ignorePathContains []string ignorePathPrefixes []string ignorePathSuffixes []string + key string BaseLimiter } @@ -29,9 +30,15 @@ func NewRequestPathLimiter( ignorePathSuffixes []string, reqLimit int, windowLen time.Duration, + key string, targetExtensions []string, onRequestLimit func(*rl.Context, string) http.HandlerFunc, -) *RequestPathLimiter { +) (*RequestPathLimiter, error) { + err := validateKey(key) + if err != nil { + return nil, err + } + return &RequestPathLimiter{ requestPathContains: requestPathContains, requestPathPrefixes: requestPathPrefixes, @@ -39,13 +46,14 @@ func NewRequestPathLimiter( ignorePathContains: ignorePathContains, ignorePathPrefixes: ignorePathPrefixes, ignorePathSuffixes: ignorePathSuffixes, + key: key, BaseLimiter: NewBaseLimiter( reqLimit, windowLen, targetExtensions, onRequestLimit, ), - } + }, nil } func (l *RequestPathLimiter) Name() string { @@ -92,7 +100,7 @@ func (l *RequestPathLimiter) Rule(r *http.Request) (*rl.Rule, error) { } if !ignored { return &rl.Rule{ - Key: r.Host + path, + Key: fillKey(r, l.key) + path, ReqLimit: l.reqLimit, WindowLen: l.windowLen, }, nil diff --git a/request_path_limiter_test.go b/request_path_limiter_test.go index efe4779..82bf89e 100644 --- a/request_path_limiter_test.go +++ b/request_path_limiter_test.go @@ -20,6 +20,7 @@ func TestRequestPathLimiter(t *testing.T) { ignorePrefixes []string ignoreSuffixes []string path string + key string expectedToBeLimited bool expectedKey string }{ @@ -29,6 +30,7 @@ func TestRequestPathLimiter(t *testing.T) { prefixes: []string{"/api/"}, suffixes: []string{"/details"}, path: "/accounts/user/profile", + key: "host", expectedToBeLimited: true, expectedKey: "example.com/user", }, @@ -38,6 +40,7 @@ func TestRequestPathLimiter(t *testing.T) { prefixes: []string{"/api/"}, suffixes: []string{"/details"}, path: "/api/users/1", + key: "host", expectedToBeLimited: true, expectedKey: "example.com/api/", }, @@ -47,6 +50,7 @@ func TestRequestPathLimiter(t *testing.T) { prefixes: []string{"/api/"}, suffixes: []string{"/details"}, path: "/users/1/details", + key: "host", expectedToBeLimited: true, expectedKey: "example.com/details", }, @@ -56,6 +60,7 @@ func TestRequestPathLimiter(t *testing.T) { prefixes: []string{"/api/"}, suffixes: []string{"/details"}, path: "/about", + key: "host", expectedToBeLimited: false, }, { @@ -64,6 +69,7 @@ func TestRequestPathLimiter(t *testing.T) { prefixes: []string{"/abcd"}, suffixes: []string{"/abcde"}, path: "/abcdefg", + key: "host", ignoreContains: []string{"ab"}, expectedToBeLimited: false, expectedKey: "", @@ -74,6 +80,7 @@ func TestRequestPathLimiter(t *testing.T) { prefixes: []string{"/abcd"}, suffixes: []string{"/abcde"}, path: "/abcdefg", + key: "host", ignorePrefixes: []string{"/a"}, expectedToBeLimited: false, expectedKey: "", @@ -84,10 +91,21 @@ func TestRequestPathLimiter(t *testing.T) { prefixes: []string{"/abcd"}, suffixes: []string{"/abcde"}, path: "/abcdefg", + key: "host", ignoreSuffixes: []string{"g"}, expectedToBeLimited: false, expectedKey: "", }, + { + name: "Path contains limited segment with remote ip", + contains: []string{"/user"}, + prefixes: []string{"/api/"}, + suffixes: []string{"/details"}, + path: "/accounts/user/profile", + key: "remote_addr", + expectedToBeLimited: true, + expectedKey: "127.0.0.1/user", + }, } for _, tc := range cases { @@ -99,7 +117,7 @@ func TestRequestPathLimiter(t *testing.T) { reqLimit := 5 windowLen := time.Minute - limiter := NewRequestPathLimiter( + limiter, _ := NewRequestPathLimiter( tc.contains, tc.prefixes, tc.suffixes, @@ -108,6 +126,7 @@ func TestRequestPathLimiter(t *testing.T) { tc.ignoreSuffixes, reqLimit, windowLen, + tc.key, nil, nil, ) @@ -116,6 +135,7 @@ func TestRequestPathLimiter(t *testing.T) { limiter.Counter = mockCounter req := httptest.NewRequest(http.MethodGet, tc.path, nil) + req.RemoteAddr = "127.0.0.1" rule, err := limiter.Rule(req) assert.NoError(t, err)