From 597f77b10c15b9556aefa6d59db12a495ff213cc Mon Sep 17 00:00:00 2001 From: Yen-Ming Lee Date: Sun, 3 Nov 2024 11:38:55 -0800 Subject: [PATCH] Add reset (#91) * feat: add Reset to leaky bucket and token bucket --- leakybucket.go | 48 ++++++++++++++++++++++++++++++++++++++++ leakybucket_test.go | 30 +++++++++++++++++++++++++ tokenbucket.go | 54 +++++++++++++++++++++++++++++++++++++++++++++ tokenbucket_test.go | 26 ++++++++++++++++++++++ 4 files changed, 158 insertions(+) diff --git a/leakybucket.go b/leakybucket.go index 124b3c6..2b9ff8c 100644 --- a/leakybucket.go +++ b/leakybucket.go @@ -38,6 +38,8 @@ type LeakyBucketStateBackend interface { State(ctx context.Context) (LeakyBucketState, error) // SetState sets (persists) the current state of the LeakyBucket. SetState(ctx context.Context, state LeakyBucketState) error + // Reset resets (persists) the current state of the LeakyBucket. + Reset(ctx context.Context) error } // LeakyBucket implements the https://en.wikipedia.org/wiki/Leaky_bucket#As_a_queue algorithm. @@ -108,6 +110,11 @@ func (t *LeakyBucket) Limit(ctx context.Context) (time.Duration, error) { return time.Duration(wait), nil } +// Reset resets the bucket. +func (t *LeakyBucket) Reset(ctx context.Context) error { + return t.backend.Reset(ctx) +} + // LeakyBucketInMemory is an in-memory implementation of LeakyBucketStateBackend. type LeakyBucketInMemory struct { state LeakyBucketState @@ -129,6 +136,14 @@ func (l *LeakyBucketInMemory) SetState(ctx context.Context, state LeakyBucketSta return ctx.Err() } +// Reset resets the current state of the bucket. +func (l *LeakyBucketInMemory) Reset(ctx context.Context) error { + state := LeakyBucketState{ + Last: 0, + } + return l.SetState(ctx, state) +} + const ( etcdKeyLBLease = "lease" etcdKeyLBLast = "last" @@ -264,6 +279,14 @@ func (l *LeakyBucketEtcd) SetState(ctx context.Context, state LeakyBucketState) return l.save(ctx, state) } +// Reset resets the state of the bucket in etcd. +func (l *LeakyBucketEtcd) Reset(ctx context.Context) error { + state := LeakyBucketState{ + Last: 0, + } + return l.SetState(ctx, state) +} + const ( redisKeyLBLast = "last" redisKeyLBVersion = "version" @@ -399,6 +422,14 @@ func (t *LeakyBucketRedis) SetState(ctx context.Context, state LeakyBucketState) return errors.Wrap(err, "failed to save keys to redis") } +// Reset resets the state in Redis. +func (t *LeakyBucketRedis) Reset(ctx context.Context) error { + state := LeakyBucketState{ + Last: 0, + } + return t.SetState(ctx, state) +} + // LeakyBucketMemcached is a Memcached implementation of a LeakyBucketStateBackend. type LeakyBucketMemcached struct { cli *memcache.Client @@ -489,6 +520,15 @@ func (t *LeakyBucketMemcached) SetState(ctx context.Context, state LeakyBucketSt return errors.Wrap(err, "failed to save keys to memcached") } +// Reset resets the state in Memcached. +func (t *LeakyBucketMemcached) Reset(ctx context.Context) error { + state := LeakyBucketState{ + Last: 0, + } + t.casId = 0 + return t.SetState(ctx, state) +} + // LeakyBucketDynamoDB is a DyanamoDB implementation of a LeakyBucketStateBackend. type LeakyBucketDynamoDB struct { client *dynamodb.Client @@ -560,6 +600,14 @@ func (t *LeakyBucketDynamoDB) SetState(ctx context.Context, state LeakyBucketSta return err } +// Reset resets the state in DynamoDB. +func (t *LeakyBucketDynamoDB) Reset(ctx context.Context) error { + state := LeakyBucketState{ + Last: 0, + } + return t.SetState(ctx, state) +} + const ( dynamodbBucketRaceConditionExpression = "Version <= :version" dynamoDBBucketLastKey = "Last" diff --git a/leakybucket_test.go b/leakybucket_test.go index 5d1e972..96410fc 100644 --- a/leakybucket_test.go +++ b/leakybucket_test.go @@ -139,6 +139,36 @@ func (s *LimitersTestSuite) TestLeakyBucketOverflow() { } } +func (s *LimitersTestSuite) TestLeakyBucketReset() { + rate := time.Second + capacity := int64(2) + clock := newFakeClock() + for name, bucket := range s.leakyBuckets(capacity, rate, clock) { + s.Run(name, func() { + clock.reset() + // The first call has no wait since there were no calls before. + wait, err := bucket.Limit(context.TODO()) + s.Require().NoError(err) + s.Equal(time.Duration(0), wait) + // The second call increments the queue size by 1. + wait, err = bucket.Limit(context.TODO()) + s.Require().NoError(err) + s.Equal(rate, wait) + // The third call overflows the bucket capacity. + wait, err = bucket.Limit(context.TODO()) + s.Require().Equal(l.ErrLimitExhausted, err) + s.Equal(rate*2, wait) + // Reset the bucket + err = bucket.Reset(context.TODO()) + s.Require().NoError(err) + // Retry the last call. This time it should succeed. + wait, err = bucket.Limit(context.TODO()) + s.Require().NoError(err) + s.Equal(time.Duration(0), wait) + }) + } +} + func TestLeakyBucket_ZeroCapacity_ReturnsError(t *testing.T) { capacity := int64(0) rate := time.Hour diff --git a/tokenbucket.go b/tokenbucket.go index 16fc8ec..2ee2cbf 100644 --- a/tokenbucket.go +++ b/tokenbucket.go @@ -40,6 +40,8 @@ type TokenBucketStateBackend interface { State(ctx context.Context) (TokenBucketState, error) // SetState sets (persists) the current state of the TokenBucket. SetState(ctx context.Context, state TokenBucketState) error + // Reset resets (persists) the current state of the TokenBucket. + Reset(ctx context.Context) error } // TokenBucket implements the https://en.wikipedia.org/wiki/Token_bucket algorithm. @@ -122,6 +124,11 @@ func (t *TokenBucket) Limit(ctx context.Context) (time.Duration, error) { return t.Take(ctx, 1) } +// Reset resets the bucket. +func (t *TokenBucket) Reset(ctx context.Context) error { + return t.backend.Reset(ctx) +} + // TokenBucketInMemory is an in-memory implementation of TokenBucketStateBackend. // // The state is not shared nor persisted so it won't survive restarts or failures. @@ -149,6 +156,15 @@ func (t *TokenBucketInMemory) SetState(ctx context.Context, state TokenBucketSta return ctx.Err() } +// Reset resets the current bucket's state. +func (t *TokenBucketInMemory) Reset(ctx context.Context) error { + state := TokenBucketState{ + Last: 0, + Available: 0, + } + return t.SetState(ctx, state) +} + const ( etcdKeyTBLease = "lease" etcdKeyTBAvailable = "available" @@ -325,6 +341,15 @@ func (t *TokenBucketEtcd) SetState(ctx context.Context, state TokenBucketState) return t.save(ctx, state) } +// Reset resets the state of the bucket. +func (t *TokenBucketEtcd) Reset(ctx context.Context) error { + state := TokenBucketState{ + Last: 0, + Available: 0, + } + return t.SetState(ctx, state) +} + const ( redisKeyTBAvailable = "available" redisKeyTBLast = "last" @@ -487,6 +512,15 @@ func (t *TokenBucketRedis) SetState(ctx context.Context, state TokenBucketState) return errors.Wrap(err, "failed to save keys to redis") } +// Reset resets the state in Redis. +func (t *TokenBucketRedis) Reset(ctx context.Context) error { + state := TokenBucketState{ + Last: 0, + Available: 0, + } + return t.SetState(ctx, state) +} + // TokenBucketMemcached is a Memcached implementation of a TokenBucketStateBackend. // // Memcached is a distributed memory object caching system. @@ -579,6 +613,17 @@ func (t *TokenBucketMemcached) SetState(ctx context.Context, state TokenBucketSt return errors.Wrap(err, "failed to save keys to memcached") } +// Reset resets the state in Memcached. +func (t *TokenBucketMemcached) Reset(ctx context.Context) error { + state := TokenBucketState{ + Last: 0, + Available: 0, + } + // Override casId to 0 to Set instead of CompareAndSwap in SetState + t.casId = 0 + return t.SetState(ctx, state) +} + // TokenBucketDynamoDB is a DynamoDB implementation of a TokenBucketStateBackend. type TokenBucketDynamoDB struct { client *dynamodb.Client @@ -650,6 +695,15 @@ func (t *TokenBucketDynamoDB) SetState(ctx context.Context, state TokenBucketSta return err } +// Reset resets the state in DynamoDB. +func (t *TokenBucketDynamoDB) Reset(ctx context.Context) error { + state := TokenBucketState{ + Last: 0, + Available: 0, + } + return t.SetState(ctx, state) +} + const dynamoDBBucketAvailableKey = "Available" func (t *TokenBucketDynamoDB) getGetItemInput() *dynamodb.GetItemInput { diff --git a/tokenbucket_test.go b/tokenbucket_test.go index a286ea8..2f8cf36 100644 --- a/tokenbucket_test.go +++ b/tokenbucket_test.go @@ -175,6 +175,32 @@ func (s *LimitersTestSuite) TestTokenBucketOverflow() { } } +func (s *LimitersTestSuite) TestTokenBucketReset() { + clock := newFakeClock() + rate := time.Second + for name, bucket := range s.tokenBuckets(2, rate, clock) { + s.Run(name, func() { + clock.reset() + wait, err := bucket.Limit(context.TODO()) + s.Require().NoError(err) + s.Equal(time.Duration(0), wait) + wait, err = bucket.Limit(context.TODO()) + s.Require().NoError(err) + s.Equal(time.Duration(0), wait) + // The third call should fail. + wait, err = bucket.Limit(context.TODO()) + s.Require().Equal(l.ErrLimitExhausted, err) + s.Equal(rate, wait) + err = bucket.Reset(context.TODO()) + s.Require().NoError(err) + // Retry the last call. + wait, err = bucket.Limit(context.TODO()) + s.Require().NoError(err) + s.Equal(time.Duration(0), wait) + }) + } +} + func (s *LimitersTestSuite) TestTokenBucketRefill() { for name, backend := range s.tokenBucketBackends() { s.Run(name, func() {