Skip to content

Commit

Permalink
web: add user level rate limit method (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
trim21 authored Aug 7, 2022
1 parent dc5aebc commit aa4031e
Show file tree
Hide file tree
Showing 9 changed files with 319 additions and 40 deletions.
78 changes: 69 additions & 9 deletions internal/mocks/RateLimiter.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion internal/pkg/test/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func MockIndexRepo(repo domain.IndexRepo) fx.Option {
func MockRateLimiter(repo rate.Manager) fx.Option {
if repo == nil {
mocker := &mocks.RateLimiter{}
mocker.EXPECT().Allowed(mock.Anything, mock.Anything).Return(true, 5, nil) //nolint:gomnd
mocker.EXPECT().Login(mock.Anything, mock.Anything).Return(true, 5, nil) //nolint:gomnd
mocker.EXPECT().Reset(mock.Anything, mock.Anything).Return(nil)

repo = mocker
Expand Down
2 changes: 1 addition & 1 deletion internal/web/handler/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (h Handler) PrivateLogin(c *fiber.Ctx) error {
}

a := h.GetHTTPAccessor(c)
allowed, remain, err := h.rateLimit.Allowed(c.Context(), a.IP.String())
allowed, remain, err := h.rateLimit.Login(c.Context(), a.IP.String())
if err != nil {
return h.InternalError(c, err, "failed to apply rate limit", a.Log())
}
Expand Down
51 changes: 51 additions & 0 deletions internal/web/rate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// SPDX-License-Identifier: AGPL-3.0-only
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>

package web

import (
"net/http"

"github.com/gofiber/fiber/v2"

"github.com/bangumi/server/internal/pkg/errgo"
"github.com/bangumi/server/internal/web/accessor"
"github.com/bangumi/server/internal/web/rate"
"github.com/bangumi/server/internal/web/rate/action"
"github.com/bangumi/server/internal/web/res"
)

type baseHandler interface {
GetHTTPAccessor(c *fiber.Ctx) *accessor.Accessor
}

// rateMiddleware require Handler.NeedLogin before this middleware.
func rateMiddleware(r rate.Manager, h baseHandler, action action.Action, limit rate.Limit) fiber.Handler {
return func(c *fiber.Ctx) error {
a := h.GetHTTPAccessor(c)
if !a.Login {
return res.Unauthorized("login required")
}

allowed, _, err := r.AllowAction(c.Context(), a.ID, action, limit)
if err != nil {
return errgo.Wrap(err, "rate.Manager.AllowAction")
}
if !allowed {
return c.SendStatus(http.StatusTooManyRequests)
}

return c.Next()
}
}
22 changes: 22 additions & 0 deletions internal/web/rate/action/action.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// SPDX-License-Identifier: AGPL-3.0-only
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>

package action

type Action uint8

const (
Unknown Action = 0
Login Action = 1
)
2 changes: 1 addition & 1 deletion internal/web/rate/allow.lua
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ local diff = now - allow_at
local remaining = diff / emission_interval

if remaining < 0 then
redis.call('SET', ban_key, "1", "EX", ban_expire) -- ban key
redis.call('SET', ban_key, 1, "EX", ban_expire) -- ban key

local reset_after = tat - now
local retry_after = diff * -1
Expand Down
51 changes: 33 additions & 18 deletions internal/web/rate/new.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ import (

"github.com/go-redis/redis/v8"

"github.com/bangumi/server/internal/model"
"github.com/bangumi/server/internal/pkg/errgo"
"github.com/bangumi/server/internal/pkg/gtime"
"github.com/bangumi/server/internal/web/rate/action"
)

const defaultAllowPerHour = 5
Expand All @@ -35,8 +37,15 @@ var allowLua string
var allowScript = redis.NewScript(allowLua) //nolint:gochecknoglobals

type Manager interface {
// Allowed 检查是否允许登录。
Allowed(ctx context.Context, ip string) (allowed bool, remain int, err error)
// Login 检查是登录限流。
Login(ctx context.Context, ip string) (allowed bool, remain int, err error)

AllowAction(
ctx context.Context,
u model.UserID,
action action.Action,
limit Limit,
) (allowed bool, remain int, err error)
// Reset 登录成功时应该重置计数。
Reset(ctx context.Context, ip string) error
}
Expand All @@ -54,28 +63,33 @@ type manager struct {
r *redis.Client
}

func (m manager) Allowed(ctx context.Context, ip string) (bool, int, error) {
var banKey = RedisBanKeyPrefix + ip
result, err := m.r.Exists(ctx, banKey, "1").Result()
func (m manager) AllowAction(
ctx context.Context,
u model.UserID,
action action.Action,
limit Limit,
) (bool, int, error) {
rateKey := fmt.Sprintf("chii:rate:%d:%d", action, u)
banKey := fmt.Sprintf("chii:rate:ban:%d:%d", action, u)

res, err := m.allow(ctx, rateKey, banKey, limit)
if err != nil {
return false, 0, errgo.Wrap(err, "redis.Exists")
return false, 0, errgo.Wrap(err, "Limiter.Allow")
}

if result == 1 {
return false, 0, nil
}
return res.Allowed > 0, res.Remaining, nil
}

func (m manager) Login(ctx context.Context, ip string) (bool, int, error) {
var rateKey = RedisRateKeyPrefix + ip
var banKey = RedisBanKeyPrefix + ip

res, err := m.allow(ctx, RedisRateKeyPrefix+ip, PerHour(defaultAllowPerHour))
res, err := m.allow(ctx, rateKey, banKey, PerHour(defaultAllowPerHour))
if err != nil {
return false, 0, errgo.Wrap(err, "Limiter.Allow")
}

if res.Allowed <= 0 {
err := m.r.Set(ctx, banKey, "1", gtime.OneWeek).Err()
return false, 0, errgo.Wrap(err, "redis.Set")
}

return true, res.Remaining, nil
return res.Allowed > 0, res.Remaining, nil
}

func (m manager) Reset(ctx context.Context, ip string) error {
Expand All @@ -87,11 +101,12 @@ func (m manager) Reset(ctx context.Context, ip string) error {
// AllowN reports whether n events may happen at time now.
func (m manager) allow(
ctx context.Context,
ip string,
rateKey string,
banKey string,
limit Limit,
) (Result, error) {
now := time.Now()
var keys = []string{RedisRateKeyPrefix + ip, RedisBanKeyPrefix + ip}
var keys = []string{rateKey, banKey}
var values = []any{
limit.Burst, limit.Rate, limit.Period.Seconds(), now.Unix(), now.Nanosecond() / 1000, gtime.OneWeekSec,
}
Expand Down
42 changes: 32 additions & 10 deletions internal/web/rate/new_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,44 @@ import (
"context"
"testing"

"github.com/go-redis/redis/v8"
"github.com/stretchr/testify/require"

"github.com/bangumi/server/internal/model"
"github.com/bangumi/server/internal/pkg/test"
"github.com/bangumi/server/internal/web/rate"
"github.com/bangumi/server/internal/web/rate/action"
)

func TestManager_Allowed(t *testing.T) { //nolint:paralleltest
func flushDB(t *testing.T, db *redis.Client) {
t.Helper()
test.RunAndCleanup(t, func() { require.NoError(t, db.FlushDB(context.Background()).Err()) })
}

//nolint:paralleltest
func TestRateLimitManager_action(t *testing.T) {
test.RequireEnv(t, "redis")
db := test.GetRedis(t)
flushDB(t, db)

const uid model.UserID = 6
r := rate.New(db)

allowed, remain, err := r.AllowAction(context.TODO(), uid, action.Unknown, rate.PerHour(10))
require.NoError(t, err)
require.True(t, allowed)
require.EqualValues(t, 9, remain)
}

//nolint:paralleltest
func TestRateLimitManager_Allowed(t *testing.T) {
test.RequireEnv(t, "redis")

db := test.GetRedis(t)
flushDB(t, db)

const ip = "0.0.0.-0"

require.NoError(t, db.FlushDB(context.Background()).Err())
t.Cleanup(func() { db.FlushDB(context.Background()) })

a, err := db.Exists(context.TODO(), rate.RedisRateKeyPrefix+ip).Result()
require.NoError(t, err)
require.Equal(t, int64(0), a)
Expand All @@ -44,32 +66,32 @@ func TestManager_Allowed(t *testing.T) { //nolint:paralleltest

rateLimiter := rate.New(db)

allowed, remain, err := rateLimiter.Allowed(context.TODO(), ip)
allowed, remain, err := rateLimiter.Login(context.TODO(), ip)
require.NoError(t, err)
require.True(t, allowed)
require.Equal(t, 4, remain)

allowed, remain, err = rateLimiter.Allowed(context.TODO(), ip)
allowed, remain, err = rateLimiter.Login(context.TODO(), ip)
require.NoError(t, err)
require.True(t, allowed)
require.Equal(t, 3, remain)

allowed, remain, err = rateLimiter.Allowed(context.TODO(), ip)
allowed, remain, err = rateLimiter.Login(context.TODO(), ip)
require.NoError(t, err)
require.True(t, allowed)
require.Equal(t, 2, remain)

allowed, remain, err = rateLimiter.Allowed(context.TODO(), ip)
allowed, remain, err = rateLimiter.Login(context.TODO(), ip)
require.NoError(t, err)
require.True(t, allowed)
require.Equal(t, 1, remain)

allowed, remain, err = rateLimiter.Allowed(context.TODO(), ip)
allowed, remain, err = rateLimiter.Login(context.TODO(), ip)
require.NoError(t, err)
require.True(t, allowed)
require.Equal(t, 0, remain)

allowed, remain, err = rateLimiter.Allowed(context.TODO(), ip)
allowed, remain, err = rateLimiter.Login(context.TODO(), ip)
require.NoError(t, err)
require.False(t, allowed)
require.Equal(t, 0, remain)
Expand Down
Loading

0 comments on commit aa4031e

Please sign in to comment.