Skip to content

Commit

Permalink
Split localCounter into its own file
Browse files Browse the repository at this point in the history
  • Loading branch information
VojtechVitek committed Jul 24, 2024
1 parent 6352918 commit 56b390b
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 83 deletions.
83 changes: 0 additions & 83 deletions limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"net/http"
"sync"
"time"

"github.com/cespare/xxhash/v2"
)

type LimitCounter interface {
Expand Down Expand Up @@ -153,87 +151,6 @@ func (l *rateLimiter) calculateRate(key string, requestLimit int) (bool, float64
return true, rate, nil
}

type localCounter struct {
counters map[uint64]*count
windowLength time.Duration
lastEvict time.Time
mu sync.Mutex
}

var _ LimitCounter = &localCounter{}

type count struct {
value int
updatedAt time.Time
}

func (c *localCounter) Config(requestLimit int, windowLength time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.windowLength = windowLength
}

func (c *localCounter) Increment(key string, currentWindow time.Time) error {
return c.IncrementBy(key, currentWindow, 1)
}

func (c *localCounter) IncrementBy(key string, currentWindow time.Time, amount int) error {
c.mu.Lock()
defer c.mu.Unlock()

c.evict()

hkey := LimitCounterKey(key, currentWindow)

v, ok := c.counters[hkey]
if !ok {
v = &count{}
c.counters[hkey] = v
}
v.value += amount
v.updatedAt = time.Now()

return nil
}

func (c *localCounter) Get(key string, currentWindow, previousWindow time.Time) (int, int, error) {
c.mu.Lock()
defer c.mu.Unlock()

curr, ok := c.counters[LimitCounterKey(key, currentWindow)]
if !ok {
curr = &count{value: 0, updatedAt: time.Now()}
}
prev, ok := c.counters[LimitCounterKey(key, previousWindow)]
if !ok {
prev = &count{value: 0, updatedAt: time.Now()}
}

return curr.value, prev.value, nil
}

func (c *localCounter) evict() {
d := c.windowLength * 3

if time.Since(c.lastEvict) < d {
return
}
c.lastEvict = time.Now()

for k, v := range c.counters {
if time.Since(v.updatedAt) >= d {
delete(c.counters, k)
}
}
}

func LimitCounterKey(key string, window time.Time) uint64 {
h := xxhash.New()
h.WriteString(key)
h.WriteString(fmt.Sprintf("%d", window.Unix()))
return h.Sum64()
}

func setHeader(w http.ResponseWriter, key string, value string) {
if key != "" {
w.Header().Set(key, value)
Expand Down
90 changes: 90 additions & 0 deletions local_counter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package httprate

import (
"fmt"
"sync"
"time"

"github.com/cespare/xxhash/v2"
)

var _ LimitCounter = &localCounter{}

type localCounter struct {
counters map[uint64]*count
windowLength time.Duration
lastEvict time.Time
mu sync.Mutex
}

type count struct {
value int
updatedAt time.Time
}

func (c *localCounter) Config(requestLimit int, windowLength time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.windowLength = windowLength
}

func (c *localCounter) Increment(key string, currentWindow time.Time) error {
return c.IncrementBy(key, currentWindow, 1)
}

func (c *localCounter) IncrementBy(key string, currentWindow time.Time, amount int) error {
c.mu.Lock()
defer c.mu.Unlock()

c.evict()

hkey := LimitCounterKey(key, currentWindow)

v, ok := c.counters[hkey]
if !ok {
v = &count{}
c.counters[hkey] = v
}
v.value += amount
v.updatedAt = time.Now()

return nil
}

func (c *localCounter) Get(key string, currentWindow, previousWindow time.Time) (int, int, error) {
c.mu.Lock()
defer c.mu.Unlock()

curr, ok := c.counters[LimitCounterKey(key, currentWindow)]
if !ok {
curr = &count{value: 0, updatedAt: time.Now()}
}
prev, ok := c.counters[LimitCounterKey(key, previousWindow)]
if !ok {
prev = &count{value: 0, updatedAt: time.Now()}
}

return curr.value, prev.value, nil
}

func (c *localCounter) evict() {
d := c.windowLength * 3

if time.Since(c.lastEvict) < d {
return
}
c.lastEvict = time.Now()

for k, v := range c.counters {
if time.Since(v.updatedAt) >= d {
delete(c.counters, k)
}
}
}

func LimitCounterKey(key string, window time.Time) uint64 {
h := xxhash.New()
h.WriteString(key)
h.WriteString(fmt.Sprintf("%d", window.Unix()))
return h.Sum64()
}

0 comments on commit 56b390b

Please sign in to comment.