diff --git a/rate.go b/rate.go index 058a822..0f7b059 100644 --- a/rate.go +++ b/rate.go @@ -73,14 +73,34 @@ func PerHour(rate int) Limit { // Limiter controls how frequently events are allowed to happen. type Limiter struct { - rdb rediser + rdb rediser + keyPrefix string } +type Option func(*Limiter) + // NewLimiter returns a new Limiter. -func NewLimiter(rdb rediser) *Limiter { - return &Limiter{ +func NewLimiter(rdb rediser, options ...Option) *Limiter { + limiter := &Limiter{ rdb: rdb, } + + for _, opt := range options { + opt(limiter) + } + + if limiter.keyPrefix == "" { + limiter.keyPrefix = redisPrefix + } + + return limiter +} + +// WithKeyPrefix is a functional option to set the redis key prefix. +func WithKeyPrefix(keyPrefix string) Option { + return func(l *Limiter) { + l.keyPrefix = keyPrefix + } } // Allow is a shortcut for AllowN(ctx, key, limit, 1). @@ -96,7 +116,7 @@ func (l Limiter) AllowN( n int, ) (*Result, error) { values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n} - v, err := allowN.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result() + v, err := allowN.Run(ctx, l.rdb, []string{l.keyPrefix + key}, values...).Result() if err != nil { return nil, err } @@ -132,7 +152,7 @@ func (l Limiter) AllowAtMost( n int, ) (*Result, error) { values := []interface{}{limit.Burst, limit.Rate, limit.Period.Seconds(), n} - v, err := allowAtMost.Run(ctx, l.rdb, []string{redisPrefix + key}, values...).Result() + v, err := allowAtMost.Run(ctx, l.rdb, []string{l.keyPrefix + key}, values...).Result() if err != nil { return nil, err } @@ -161,7 +181,7 @@ func (l Limiter) AllowAtMost( // Reset gets a key and reset all limitations and previous usages func (l *Limiter) Reset(ctx context.Context, key string) error { - return l.rdb.Del(ctx, redisPrefix+key).Err() + return l.rdb.Del(ctx, l.keyPrefix+key).Err() } func dur(f float64) time.Duration { diff --git a/rate_test.go b/rate_test.go index 6be2edf..e7c2c2e 100644 --- a/rate_test.go +++ b/rate_test.go @@ -21,10 +21,22 @@ func rateLimiter() *redis_rate.Limiter { return redis_rate.NewLimiter(ring) } +func TestAllow_WithKeyPrefix(t *testing.T) { + ring := redis.NewRing(&redis.RingOptions{ + Addrs: map[string]string{"server0": ":6379"}, + }) + if err := ring.FlushDB(context.TODO()).Err(); err != nil { + panic(err) + } + testAllow(t, redis_rate.NewLimiter(ring, redis_rate.WithKeyPrefix("redis_rate:"))) +} + func TestAllow(t *testing.T) { - ctx := context.Background() + testAllow(t, rateLimiter()) +} - l := rateLimiter() +func testAllow(t *testing.T, l *redis_rate.Limiter) { + ctx := context.Background() limit := redis_rate.PerSecond(10) require.Equal(t, limit.String(), "10 req/s (burst 10)")