diff --git a/country_limiter.go b/country_limiter.go index 0563250..fc65691 100644 --- a/country_limiter.go +++ b/country_limiter.go @@ -3,6 +3,7 @@ package rlutils // limit from ip with maxMindDB import ( + "fmt" "net" "net/http" "strings" @@ -13,8 +14,9 @@ import ( ) type CountryLimiter struct { - db *maxminddb.Reader - countries []string + db *maxminddb.Reader + countries map[string]struct{} + skipCountries map[string]struct{} BaseLimiter } @@ -23,6 +25,7 @@ type CountryLimiter struct { func NewCountryLimiter( dbPath string, countries []string, + skipCountries []string, reqLimit int, windowLen time.Duration, targetExtensions []string, @@ -32,9 +35,23 @@ func NewCountryLimiter( if err != nil { return nil, err } + cm := map[string]struct{}{} + scm := map[string]struct{}{} + + for _, c := range countries { + cm[c] = struct{}{} + } + + for _, c := range skipCountries { + if c == "*" { + return nil, fmt.Errorf("invalid skip country: %s", c) + } + scm[c] = struct{}{} + } return &CountryLimiter{ - db: db, - countries: countries, + db: db, + countries: cm, + skipCountries: scm, BaseLimiter: NewBaseLimiter( reqLimit, windowLen, @@ -58,16 +75,35 @@ func (l *CountryLimiter) Rule(r *http.Request) (*rl.Rule, error) { if err != nil { return nil, err } - for _, c := range l.countries { - if country == c { - return &rl.Rule{ - Key: remoteAddr, - ReqLimit: l.reqLimit, - WindowLen: l.windowLen, - }, nil - } + + limit := &rl.Rule{ + Key: remoteAddr, + ReqLimit: l.reqLimit, + WindowLen: l.windowLen, + } + noLimit := &rl.Rule{ReqLimit: -1} + + if country == "" { + return noLimit, nil + } + + if _, ok := l.skipCountries[country]; ok { + return noLimit, nil + } + + if _, ok := l.countries["*"]; ok { + return limit, nil + + } + + if _, ok := l.countries[country]; ok { + return &rl.Rule{ + Key: remoteAddr, + ReqLimit: l.reqLimit, + WindowLen: l.windowLen, + }, nil } - return &rl.Rule{ReqLimit: -1}, nil + return noLimit, nil } func (l *CountryLimiter) country(remoteAddr string) (string, error) { diff --git a/country_limiter_test.go b/country_limiter_test.go index 3d58935..82ba098 100644 --- a/country_limiter_test.go +++ b/country_limiter_test.go @@ -18,49 +18,73 @@ func testHTTPRequest(remoteAddr string) *http.Request { func TestCountryLimiter(t *testing.T) { abspath, _ := filepath.Abs("./testdata/GeoIP2-Country-Test.mmdb") reqLimit := 10 - cl, err := NewCountryLimiter(abspath, []string{"US"}, reqLimit, 1*time.Hour, nil, nil) - if err != nil { - t.Fatal(err) - } // Define your test cases testCases := []struct { name string request *http.Request countries []string + skipCountries []string expectedCountry string expectedError bool - allowed bool + shouldLimit bool }{ { name: "Valid IP from United States With Port", request: testHTTPRequest("50.114.0.1:1234"), expectedCountry: "US", countries: []string{"US"}, - allowed: true, + shouldLimit: true, expectedError: false, }, { name: "Valid IP from United States", request: testHTTPRequest("50.114.0.1"), expectedCountry: "US", - allowed: true, + shouldLimit: true, expectedError: false, }, - { name: "Invalid IP format", request: testHTTPRequest("invalid-ip"), expectedCountry: "", - allowed: false, + shouldLimit: false, expectedError: true, }, + { + name: "Valid IP from United States With WildCard", + request: testHTTPRequest("50.114.0.1"), + expectedCountry: "US", + countries: []string{"*"}, + shouldLimit: true, + expectedError: false, + }, + { + name: "Valid IP from United States With Skip country", + request: testHTTPRequest("50.114.0.1"), + expectedCountry: "US", + countries: []string{"*"}, + skipCountries: []string{"US"}, + shouldLimit: false, + expectedError: false, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Run the country function to get the ISO country code - + cl, err := NewCountryLimiter( + abspath, + []string{"US"}, + tc.skipCountries, + reqLimit, + 1*time.Hour, + nil, + nil, + ) + if err != nil { + t.Fatal(err) + } remoteAddr := strings.Split(tc.request.RemoteAddr, ":")[0] country, err := cl.country(remoteAddr) if tc.expectedError { @@ -82,11 +106,11 @@ func TestCountryLimiter(t *testing.T) { t.Errorf("Rule method returned an unexpected error: %v", ruleErr) } } else { - if tc.allowed && (rule == nil || rule.ReqLimit != reqLimit) { - t.Errorf("Expected allowed rule with limit %d, but got %+v", reqLimit, rule) + if tc.shouldLimit && (rule == nil || rule.ReqLimit != reqLimit) { + t.Errorf("Expected shouldLimit rule with limit %d, but got %+v", reqLimit, rule) } - if !tc.allowed && (rule == nil || rule.ReqLimit != -1) { - t.Errorf("Expected disallowed rule with no limiting, but got %+v", rule) + if !tc.shouldLimit && (rule == nil || rule.ReqLimit != -1) { + t.Errorf("Expected disshouldLimit rule with no limiting, but got %+v", rule) } } diff --git a/coverage.out b/coverage.out index 17c014e..be9af2a 100644 --- a/coverage.out +++ b/coverage.out @@ -1,30 +1,41 @@ mode: count -github.com/2manymws/rlutils/base_limiter.go:26.15,35.2 2 15 +github.com/2manymws/rlutils/base_limiter.go:26.15,35.2 2 19 github.com/2manymws/rlutils/base_limiter.go:37.70,39.2 1 0 github.com/2manymws/rlutils/base_limiter.go:41.37,43.2 1 0 -github.com/2manymws/rlutils/base_limiter.go:45.61,47.2 1 18 -github.com/2manymws/rlutils/base_limiter.go:49.64,50.34 1 24 -github.com/2manymws/rlutils/base_limiter.go:50.34,52.3 1 19 +github.com/2manymws/rlutils/base_limiter.go:45.61,47.2 1 20 +github.com/2manymws/rlutils/base_limiter.go:49.64,50.34 1 26 +github.com/2manymws/rlutils/base_limiter.go:50.34,52.3 1 21 github.com/2manymws/rlutils/base_limiter.go:53.2,54.41 2 5 github.com/2manymws/rlutils/base_limiter.go:54.41,55.40 1 8 github.com/2manymws/rlutils/base_limiter.go:55.40,57.4 1 3 github.com/2manymws/rlutils/base_limiter.go:59.2,59.14 1 2 -github.com/2manymws/rlutils/country_limiter.go:30.28,32.16 2 1 -github.com/2manymws/rlutils/country_limiter.go:32.16,34.3 1 0 -github.com/2manymws/rlutils/country_limiter.go:35.2,44.8 1 1 -github.com/2manymws/rlutils/country_limiter.go:47.40,49.2 1 0 -github.com/2manymws/rlutils/country_limiter.go:51.66,52.27 1 3 -github.com/2manymws/rlutils/country_limiter.go:52.27,54.3 1 0 -github.com/2manymws/rlutils/country_limiter.go:56.2,58.16 3 3 -github.com/2manymws/rlutils/country_limiter.go:58.16,60.3 1 1 -github.com/2manymws/rlutils/country_limiter.go:61.2,61.32 1 2 -github.com/2manymws/rlutils/country_limiter.go:61.32,62.19 1 2 -github.com/2manymws/rlutils/country_limiter.go:62.19,68.4 1 2 -github.com/2manymws/rlutils/country_limiter.go:70.2,70.36 1 0 -github.com/2manymws/rlutils/country_limiter.go:73.69,81.16 3 6 -github.com/2manymws/rlutils/country_limiter.go:81.16,83.3 1 2 -github.com/2manymws/rlutils/country_limiter.go:85.2,85.36 1 4 -github.com/2manymws/rlutils/country_limiter.go:88.73,90.2 1 0 +github.com/2manymws/rlutils/country_limiter.go:33.28,35.16 2 5 +github.com/2manymws/rlutils/country_limiter.go:35.16,37.3 1 0 +github.com/2manymws/rlutils/country_limiter.go:38.2,41.30 3 5 +github.com/2manymws/rlutils/country_limiter.go:41.30,43.3 1 5 +github.com/2manymws/rlutils/country_limiter.go:45.2,45.34 1 5 +github.com/2manymws/rlutils/country_limiter.go:45.34,46.15 1 1 +github.com/2manymws/rlutils/country_limiter.go:46.15,48.4 1 0 +github.com/2manymws/rlutils/country_limiter.go:49.3,49.22 1 1 +github.com/2manymws/rlutils/country_limiter.go:51.2,61.8 1 5 +github.com/2manymws/rlutils/country_limiter.go:64.40,66.2 1 0 +github.com/2manymws/rlutils/country_limiter.go:68.66,69.27 1 5 +github.com/2manymws/rlutils/country_limiter.go:69.27,71.3 1 0 +github.com/2manymws/rlutils/country_limiter.go:73.2,75.16 3 5 +github.com/2manymws/rlutils/country_limiter.go:75.16,77.3 1 1 +github.com/2manymws/rlutils/country_limiter.go:79.2,86.19 3 4 +github.com/2manymws/rlutils/country_limiter.go:86.19,88.3 1 0 +github.com/2manymws/rlutils/country_limiter.go:90.2,90.43 1 4 +github.com/2manymws/rlutils/country_limiter.go:90.43,92.3 1 1 +github.com/2manymws/rlutils/country_limiter.go:94.2,94.35 1 3 +github.com/2manymws/rlutils/country_limiter.go:94.35,97.3 1 0 +github.com/2manymws/rlutils/country_limiter.go:99.2,99.39 1 3 +github.com/2manymws/rlutils/country_limiter.go:99.39,105.3 1 3 +github.com/2manymws/rlutils/country_limiter.go:106.2,106.21 1 0 +github.com/2manymws/rlutils/country_limiter.go:109.69,117.16 3 10 +github.com/2manymws/rlutils/country_limiter.go:117.16,119.3 1 2 +github.com/2manymws/rlutils/country_limiter.go:121.2,121.36 1 8 +github.com/2manymws/rlutils/country_limiter.go:124.73,126.2 1 0 github.com/2manymws/rlutils/get_parameter_limiter.go:23.24,33.2 1 3 github.com/2manymws/rlutils/get_parameter_limiter.go:35.45,37.2 1 0 github.com/2manymws/rlutils/get_parameter_limiter.go:39.71,40.27 1 3