Skip to content

Commit

Permalink
Merge pull request #2 from 2manymws/exclude-contries
Browse files Browse the repository at this point in the history
Apply rate limiting to countries other than the target country
  • Loading branch information
Kazuhiko Yamashita authored Dec 20, 2023
2 parents 5f3a38b + 28be9f8 commit 26df451
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 47 deletions.
62 changes: 49 additions & 13 deletions country_limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package rlutils
// limit from ip with maxMindDB

import (
"fmt"
"net"
"net/http"
"strings"
Expand All @@ -13,8 +14,9 @@ import (
)

type CountryLimiter struct {
db *maxminddb.Reader
countries []string
db *maxminddb.Reader
countries map[string]struct{}
skipCountries map[string]struct{}
BaseLimiter
}

Expand All @@ -23,6 +25,7 @@ type CountryLimiter struct {
func NewCountryLimiter(
dbPath string,
countries []string,
skipCountries []string,
reqLimit int,
windowLen time.Duration,
targetExtensions []string,
Expand All @@ -32,9 +35,23 @@ func NewCountryLimiter(
if err != nil {
return nil, err
}
cm := map[string]struct{}{}
scm := map[string]struct{}{}

for _, c := range countries {
cm[c] = struct{}{}
}

for _, c := range skipCountries {
if c == "*" {
return nil, fmt.Errorf("invalid skip country: %s", c)
}
scm[c] = struct{}{}
}
return &CountryLimiter{
db: db,
countries: countries,
db: db,
countries: cm,
skipCountries: scm,
BaseLimiter: NewBaseLimiter(
reqLimit,
windowLen,
Expand All @@ -58,16 +75,35 @@ func (l *CountryLimiter) Rule(r *http.Request) (*rl.Rule, error) {
if err != nil {
return nil, err
}
for _, c := range l.countries {
if country == c {
return &rl.Rule{
Key: remoteAddr,
ReqLimit: l.reqLimit,
WindowLen: l.windowLen,
}, nil
}

limit := &rl.Rule{
Key: remoteAddr,
ReqLimit: l.reqLimit,
WindowLen: l.windowLen,
}
noLimit := &rl.Rule{ReqLimit: -1}

if country == "" {
return noLimit, nil
}

if _, ok := l.skipCountries[country]; ok {
return noLimit, nil
}

if _, ok := l.countries["*"]; ok {
return limit, nil

}

if _, ok := l.countries[country]; ok {
return &rl.Rule{
Key: remoteAddr,
ReqLimit: l.reqLimit,
WindowLen: l.windowLen,
}, nil
}
return &rl.Rule{ReqLimit: -1}, nil
return noLimit, nil
}

func (l *CountryLimiter) country(remoteAddr string) (string, error) {
Expand Down
52 changes: 38 additions & 14 deletions country_limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,49 +18,73 @@ func testHTTPRequest(remoteAddr string) *http.Request {
func TestCountryLimiter(t *testing.T) {
abspath, _ := filepath.Abs("./testdata/GeoIP2-Country-Test.mmdb")
reqLimit := 10
cl, err := NewCountryLimiter(abspath, []string{"US"}, reqLimit, 1*time.Hour, nil, nil)
if err != nil {
t.Fatal(err)
}

// Define your test cases
testCases := []struct {
name string
request *http.Request
countries []string
skipCountries []string
expectedCountry string
expectedError bool
allowed bool
shouldLimit bool
}{
{
name: "Valid IP from United States With Port",
request: testHTTPRequest("50.114.0.1:1234"),
expectedCountry: "US",
countries: []string{"US"},
allowed: true,
shouldLimit: true,
expectedError: false,
},
{
name: "Valid IP from United States",
request: testHTTPRequest("50.114.0.1"),
expectedCountry: "US",
allowed: true,
shouldLimit: true,
expectedError: false,
},

{
name: "Invalid IP format",
request: testHTTPRequest("invalid-ip"),
expectedCountry: "",
allowed: false,
shouldLimit: false,
expectedError: true,
},
{
name: "Valid IP from United States With WildCard",
request: testHTTPRequest("50.114.0.1"),
expectedCountry: "US",
countries: []string{"*"},
shouldLimit: true,
expectedError: false,
},
{
name: "Valid IP from United States With Skip country",
request: testHTTPRequest("50.114.0.1"),
expectedCountry: "US",
countries: []string{"*"},
skipCountries: []string{"US"},
shouldLimit: false,
expectedError: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Run the country function to get the ISO country code

cl, err := NewCountryLimiter(
abspath,
[]string{"US"},
tc.skipCountries,
reqLimit,
1*time.Hour,
nil,
nil,
)
if err != nil {
t.Fatal(err)
}
remoteAddr := strings.Split(tc.request.RemoteAddr, ":")[0]
country, err := cl.country(remoteAddr)
if tc.expectedError {
Expand All @@ -82,11 +106,11 @@ func TestCountryLimiter(t *testing.T) {
t.Errorf("Rule method returned an unexpected error: %v", ruleErr)
}
} else {
if tc.allowed && (rule == nil || rule.ReqLimit != reqLimit) {
t.Errorf("Expected allowed rule with limit %d, but got %+v", reqLimit, rule)
if tc.shouldLimit && (rule == nil || rule.ReqLimit != reqLimit) {
t.Errorf("Expected shouldLimit rule with limit %d, but got %+v", reqLimit, rule)
}
if !tc.allowed && (rule == nil || rule.ReqLimit != -1) {
t.Errorf("Expected disallowed rule with no limiting, but got %+v", rule)
if !tc.shouldLimit && (rule == nil || rule.ReqLimit != -1) {
t.Errorf("Expected disshouldLimit rule with no limiting, but got %+v", rule)
}
}

Expand Down
51 changes: 31 additions & 20 deletions coverage.out
Original file line number Diff line number Diff line change
@@ -1,30 +1,41 @@
mode: count
github.com/2manymws/rlutils/base_limiter.go:26.15,35.2 2 15
github.com/2manymws/rlutils/base_limiter.go:26.15,35.2 2 19
github.com/2manymws/rlutils/base_limiter.go:37.70,39.2 1 0
github.com/2manymws/rlutils/base_limiter.go:41.37,43.2 1 0
github.com/2manymws/rlutils/base_limiter.go:45.61,47.2 1 18
github.com/2manymws/rlutils/base_limiter.go:49.64,50.34 1 24
github.com/2manymws/rlutils/base_limiter.go:50.34,52.3 1 19
github.com/2manymws/rlutils/base_limiter.go:45.61,47.2 1 20
github.com/2manymws/rlutils/base_limiter.go:49.64,50.34 1 26
github.com/2manymws/rlutils/base_limiter.go:50.34,52.3 1 21
github.com/2manymws/rlutils/base_limiter.go:53.2,54.41 2 5
github.com/2manymws/rlutils/base_limiter.go:54.41,55.40 1 8
github.com/2manymws/rlutils/base_limiter.go:55.40,57.4 1 3
github.com/2manymws/rlutils/base_limiter.go:59.2,59.14 1 2
github.com/2manymws/rlutils/country_limiter.go:30.28,32.16 2 1
github.com/2manymws/rlutils/country_limiter.go:32.16,34.3 1 0
github.com/2manymws/rlutils/country_limiter.go:35.2,44.8 1 1
github.com/2manymws/rlutils/country_limiter.go:47.40,49.2 1 0
github.com/2manymws/rlutils/country_limiter.go:51.66,52.27 1 3
github.com/2manymws/rlutils/country_limiter.go:52.27,54.3 1 0
github.com/2manymws/rlutils/country_limiter.go:56.2,58.16 3 3
github.com/2manymws/rlutils/country_limiter.go:58.16,60.3 1 1
github.com/2manymws/rlutils/country_limiter.go:61.2,61.32 1 2
github.com/2manymws/rlutils/country_limiter.go:61.32,62.19 1 2
github.com/2manymws/rlutils/country_limiter.go:62.19,68.4 1 2
github.com/2manymws/rlutils/country_limiter.go:70.2,70.36 1 0
github.com/2manymws/rlutils/country_limiter.go:73.69,81.16 3 6
github.com/2manymws/rlutils/country_limiter.go:81.16,83.3 1 2
github.com/2manymws/rlutils/country_limiter.go:85.2,85.36 1 4
github.com/2manymws/rlutils/country_limiter.go:88.73,90.2 1 0
github.com/2manymws/rlutils/country_limiter.go:33.28,35.16 2 5
github.com/2manymws/rlutils/country_limiter.go:35.16,37.3 1 0
github.com/2manymws/rlutils/country_limiter.go:38.2,41.30 3 5
github.com/2manymws/rlutils/country_limiter.go:41.30,43.3 1 5
github.com/2manymws/rlutils/country_limiter.go:45.2,45.34 1 5
github.com/2manymws/rlutils/country_limiter.go:45.34,46.15 1 1
github.com/2manymws/rlutils/country_limiter.go:46.15,48.4 1 0
github.com/2manymws/rlutils/country_limiter.go:49.3,49.22 1 1
github.com/2manymws/rlutils/country_limiter.go:51.2,61.8 1 5
github.com/2manymws/rlutils/country_limiter.go:64.40,66.2 1 0
github.com/2manymws/rlutils/country_limiter.go:68.66,69.27 1 5
github.com/2manymws/rlutils/country_limiter.go:69.27,71.3 1 0
github.com/2manymws/rlutils/country_limiter.go:73.2,75.16 3 5
github.com/2manymws/rlutils/country_limiter.go:75.16,77.3 1 1
github.com/2manymws/rlutils/country_limiter.go:79.2,86.19 3 4
github.com/2manymws/rlutils/country_limiter.go:86.19,88.3 1 0
github.com/2manymws/rlutils/country_limiter.go:90.2,90.43 1 4
github.com/2manymws/rlutils/country_limiter.go:90.43,92.3 1 1
github.com/2manymws/rlutils/country_limiter.go:94.2,94.35 1 3
github.com/2manymws/rlutils/country_limiter.go:94.35,97.3 1 0
github.com/2manymws/rlutils/country_limiter.go:99.2,99.39 1 3
github.com/2manymws/rlutils/country_limiter.go:99.39,105.3 1 3
github.com/2manymws/rlutils/country_limiter.go:106.2,106.21 1 0
github.com/2manymws/rlutils/country_limiter.go:109.69,117.16 3 10
github.com/2manymws/rlutils/country_limiter.go:117.16,119.3 1 2
github.com/2manymws/rlutils/country_limiter.go:121.2,121.36 1 8
github.com/2manymws/rlutils/country_limiter.go:124.73,126.2 1 0
github.com/2manymws/rlutils/get_parameter_limiter.go:23.24,33.2 1 3
github.com/2manymws/rlutils/get_parameter_limiter.go:35.45,37.2 1 0
github.com/2manymws/rlutils/get_parameter_limiter.go:39.71,40.27 1 3
Expand Down

0 comments on commit 26df451

Please sign in to comment.