diff --git a/internal/mocks/RateLimiter.go b/internal/mocks/RateLimiter.go
index f6f497aed..7df1d9e84 100644
--- a/internal/mocks/RateLimiter.go
+++ b/internal/mocks/RateLimiter.go
@@ -5,7 +5,13 @@ package mocks
import (
context "context"
+ action "github.com/bangumi/server/internal/web/rate/action"
+
mock "github.com/stretchr/testify/mock"
+
+ model "github.com/bangumi/server/internal/model"
+
+ rate "github.com/bangumi/server/internal/web/rate"
)
// RateLimiter is an autogenerated mock type for the Manager type
@@ -21,8 +27,62 @@ func (_m *RateLimiter) EXPECT() *RateLimiter_Expecter {
return &RateLimiter_Expecter{mock: &_m.Mock}
}
-// Allowed provides a mock function with given fields: ctx, ip
-func (_m *RateLimiter) Allowed(ctx context.Context, ip string) (bool, int, error) {
+// AllowAction provides a mock function with given fields: ctx, u, _a2, limit
+func (_m *RateLimiter) AllowAction(ctx context.Context, u model.UserID, _a2 action.Action, limit rate.Limit) (bool, int, error) {
+ ret := _m.Called(ctx, u, _a2, limit)
+
+ var r0 bool
+ if rf, ok := ret.Get(0).(func(context.Context, model.UserID, action.Action, rate.Limit) bool); ok {
+ r0 = rf(ctx, u, _a2, limit)
+ } else {
+ r0 = ret.Get(0).(bool)
+ }
+
+ var r1 int
+ if rf, ok := ret.Get(1).(func(context.Context, model.UserID, action.Action, rate.Limit) int); ok {
+ r1 = rf(ctx, u, _a2, limit)
+ } else {
+ r1 = ret.Get(1).(int)
+ }
+
+ var r2 error
+ if rf, ok := ret.Get(2).(func(context.Context, model.UserID, action.Action, rate.Limit) error); ok {
+ r2 = rf(ctx, u, _a2, limit)
+ } else {
+ r2 = ret.Error(2)
+ }
+
+ return r0, r1, r2
+}
+
+// RateLimiter_AllowAction_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllowAction'
+type RateLimiter_AllowAction_Call struct {
+ *mock.Call
+}
+
+// AllowAction is a helper method to define mock.On call
+// - ctx context.Context
+// - u model.UserID
+// - _a2 action.Action
+// - limit rate.Limit
+func (_e *RateLimiter_Expecter) AllowAction(ctx interface{}, u interface{}, _a2 interface{}, limit interface{}) *RateLimiter_AllowAction_Call {
+ return &RateLimiter_AllowAction_Call{Call: _e.mock.On("AllowAction", ctx, u, _a2, limit)}
+}
+
+func (_c *RateLimiter_AllowAction_Call) Run(run func(ctx context.Context, u model.UserID, _a2 action.Action, limit rate.Limit)) *RateLimiter_AllowAction_Call {
+ _c.Call.Run(func(args mock.Arguments) {
+ run(args[0].(context.Context), args[1].(model.UserID), args[2].(action.Action), args[3].(rate.Limit))
+ })
+ return _c
+}
+
+func (_c *RateLimiter_AllowAction_Call) Return(allowed bool, remain int, err error) *RateLimiter_AllowAction_Call {
+ _c.Call.Return(allowed, remain, err)
+ return _c
+}
+
+// Login provides a mock function with given fields: ctx, ip
+func (_m *RateLimiter) Login(ctx context.Context, ip string) (bool, int, error) {
ret := _m.Called(ctx, ip)
var r0 bool
@@ -49,26 +109,26 @@ func (_m *RateLimiter) Allowed(ctx context.Context, ip string) (bool, int, error
return r0, r1, r2
}
-// RateLimiter_Allowed_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Allowed'
-type RateLimiter_Allowed_Call struct {
+// RateLimiter_Login_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Login'
+type RateLimiter_Login_Call struct {
*mock.Call
}
-// Allowed is a helper method to define mock.On call
+// Login is a helper method to define mock.On call
// - ctx context.Context
// - ip string
-func (_e *RateLimiter_Expecter) Allowed(ctx interface{}, ip interface{}) *RateLimiter_Allowed_Call {
- return &RateLimiter_Allowed_Call{Call: _e.mock.On("Allowed", ctx, ip)}
+func (_e *RateLimiter_Expecter) Login(ctx interface{}, ip interface{}) *RateLimiter_Login_Call {
+ return &RateLimiter_Login_Call{Call: _e.mock.On("Login", ctx, ip)}
}
-func (_c *RateLimiter_Allowed_Call) Run(run func(ctx context.Context, ip string)) *RateLimiter_Allowed_Call {
+func (_c *RateLimiter_Login_Call) Run(run func(ctx context.Context, ip string)) *RateLimiter_Login_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string))
})
return _c
}
-func (_c *RateLimiter_Allowed_Call) Return(allowed bool, remain int, err error) *RateLimiter_Allowed_Call {
+func (_c *RateLimiter_Login_Call) Return(allowed bool, remain int, err error) *RateLimiter_Login_Call {
_c.Call.Return(allowed, remain, err)
return _c
}
diff --git a/internal/pkg/test/web.go b/internal/pkg/test/web.go
index 7521993c1..0e6dfa24e 100644
--- a/internal/pkg/test/web.go
+++ b/internal/pkg/test/web.go
@@ -152,7 +152,7 @@ func MockIndexRepo(repo domain.IndexRepo) fx.Option {
func MockRateLimiter(repo rate.Manager) fx.Option {
if repo == nil {
mocker := &mocks.RateLimiter{}
- mocker.EXPECT().Allowed(mock.Anything, mock.Anything).Return(true, 5, nil) //nolint:gomnd
+ mocker.EXPECT().Login(mock.Anything, mock.Anything).Return(true, 5, nil) //nolint:gomnd
mocker.EXPECT().Reset(mock.Anything, mock.Anything).Return(nil)
repo = mocker
diff --git a/internal/web/handler/auth.go b/internal/web/handler/auth.go
index b1e28b7d7..7e0250a0e 100644
--- a/internal/web/handler/auth.go
+++ b/internal/web/handler/auth.go
@@ -69,7 +69,7 @@ func (h Handler) PrivateLogin(c *fiber.Ctx) error {
}
a := h.GetHTTPAccessor(c)
- allowed, remain, err := h.rateLimit.Allowed(c.Context(), a.IP.String())
+ allowed, remain, err := h.rateLimit.Login(c.Context(), a.IP.String())
if err != nil {
return h.InternalError(c, err, "failed to apply rate limit", a.Log())
}
diff --git a/internal/web/rate.go b/internal/web/rate.go
new file mode 100644
index 000000000..0bc102007
--- /dev/null
+++ b/internal/web/rate.go
@@ -0,0 +1,51 @@
+// SPDX-License-Identifier: AGPL-3.0-only
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+// See the GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see
+
+package web
+
+import (
+ "net/http"
+
+ "github.com/gofiber/fiber/v2"
+
+ "github.com/bangumi/server/internal/pkg/errgo"
+ "github.com/bangumi/server/internal/web/accessor"
+ "github.com/bangumi/server/internal/web/rate"
+ "github.com/bangumi/server/internal/web/rate/action"
+ "github.com/bangumi/server/internal/web/res"
+)
+
+type baseHandler interface {
+ GetHTTPAccessor(c *fiber.Ctx) *accessor.Accessor
+}
+
+// rateMiddleware require Handler.NeedLogin before this middleware.
+func rateMiddleware(r rate.Manager, h baseHandler, action action.Action, limit rate.Limit) fiber.Handler {
+ return func(c *fiber.Ctx) error {
+ a := h.GetHTTPAccessor(c)
+ if !a.Login {
+ return res.Unauthorized("login required")
+ }
+
+ allowed, _, err := r.AllowAction(c.Context(), a.ID, action, limit)
+ if err != nil {
+ return errgo.Wrap(err, "rate.Manager.AllowAction")
+ }
+ if !allowed {
+ return c.SendStatus(http.StatusTooManyRequests)
+ }
+
+ return c.Next()
+ }
+}
diff --git a/internal/web/rate/action/action.go b/internal/web/rate/action/action.go
new file mode 100644
index 000000000..24814f2d2
--- /dev/null
+++ b/internal/web/rate/action/action.go
@@ -0,0 +1,22 @@
+// SPDX-License-Identifier: AGPL-3.0-only
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+// See the GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see
+
+package action
+
+type Action uint8
+
+const (
+ Unknown Action = 0
+ Login Action = 1
+)
diff --git a/internal/web/rate/allow.lua b/internal/web/rate/allow.lua
index 0d64cbfa0..547a2e6f9 100644
--- a/internal/web/rate/allow.lua
+++ b/internal/web/rate/allow.lua
@@ -65,7 +65,7 @@ local diff = now - allow_at
local remaining = diff / emission_interval
if remaining < 0 then
- redis.call('SET', ban_key, "1", "EX", ban_expire) -- ban key
+ redis.call('SET', ban_key, 1, "EX", ban_expire) -- ban key
local reset_after = tat - now
local retry_after = diff * -1
diff --git a/internal/web/rate/new.go b/internal/web/rate/new.go
index 38b8f3268..ce804dd02 100644
--- a/internal/web/rate/new.go
+++ b/internal/web/rate/new.go
@@ -23,8 +23,10 @@ import (
"github.com/go-redis/redis/v8"
+ "github.com/bangumi/server/internal/model"
"github.com/bangumi/server/internal/pkg/errgo"
"github.com/bangumi/server/internal/pkg/gtime"
+ "github.com/bangumi/server/internal/web/rate/action"
)
const defaultAllowPerHour = 5
@@ -35,8 +37,15 @@ var allowLua string
var allowScript = redis.NewScript(allowLua) //nolint:gochecknoglobals
type Manager interface {
- // Allowed 检查是否允许登录。
- Allowed(ctx context.Context, ip string) (allowed bool, remain int, err error)
+ // Login 检查是登录限流。
+ Login(ctx context.Context, ip string) (allowed bool, remain int, err error)
+
+ AllowAction(
+ ctx context.Context,
+ u model.UserID,
+ action action.Action,
+ limit Limit,
+ ) (allowed bool, remain int, err error)
// Reset 登录成功时应该重置计数。
Reset(ctx context.Context, ip string) error
}
@@ -54,28 +63,33 @@ type manager struct {
r *redis.Client
}
-func (m manager) Allowed(ctx context.Context, ip string) (bool, int, error) {
- var banKey = RedisBanKeyPrefix + ip
- result, err := m.r.Exists(ctx, banKey, "1").Result()
+func (m manager) AllowAction(
+ ctx context.Context,
+ u model.UserID,
+ action action.Action,
+ limit Limit,
+) (bool, int, error) {
+ rateKey := fmt.Sprintf("chii:rate:%d:%d", action, u)
+ banKey := fmt.Sprintf("chii:rate:ban:%d:%d", action, u)
+
+ res, err := m.allow(ctx, rateKey, banKey, limit)
if err != nil {
- return false, 0, errgo.Wrap(err, "redis.Exists")
+ return false, 0, errgo.Wrap(err, "Limiter.Allow")
}
- if result == 1 {
- return false, 0, nil
- }
+ return res.Allowed > 0, res.Remaining, nil
+}
+
+func (m manager) Login(ctx context.Context, ip string) (bool, int, error) {
+ var rateKey = RedisRateKeyPrefix + ip
+ var banKey = RedisBanKeyPrefix + ip
- res, err := m.allow(ctx, RedisRateKeyPrefix+ip, PerHour(defaultAllowPerHour))
+ res, err := m.allow(ctx, rateKey, banKey, PerHour(defaultAllowPerHour))
if err != nil {
return false, 0, errgo.Wrap(err, "Limiter.Allow")
}
- if res.Allowed <= 0 {
- err := m.r.Set(ctx, banKey, "1", gtime.OneWeek).Err()
- return false, 0, errgo.Wrap(err, "redis.Set")
- }
-
- return true, res.Remaining, nil
+ return res.Allowed > 0, res.Remaining, nil
}
func (m manager) Reset(ctx context.Context, ip string) error {
@@ -87,11 +101,12 @@ func (m manager) Reset(ctx context.Context, ip string) error {
// AllowN reports whether n events may happen at time now.
func (m manager) allow(
ctx context.Context,
- ip string,
+ rateKey string,
+ banKey string,
limit Limit,
) (Result, error) {
now := time.Now()
- var keys = []string{RedisRateKeyPrefix + ip, RedisBanKeyPrefix + ip}
+ var keys = []string{rateKey, banKey}
var values = []any{
limit.Burst, limit.Rate, limit.Period.Seconds(), now.Unix(), now.Nanosecond() / 1000, gtime.OneWeekSec,
}
diff --git a/internal/web/rate/new_test.go b/internal/web/rate/new_test.go
index 3b3adccd7..d8567cfa3 100644
--- a/internal/web/rate/new_test.go
+++ b/internal/web/rate/new_test.go
@@ -18,22 +18,44 @@ import (
"context"
"testing"
+ "github.com/go-redis/redis/v8"
"github.com/stretchr/testify/require"
+ "github.com/bangumi/server/internal/model"
"github.com/bangumi/server/internal/pkg/test"
"github.com/bangumi/server/internal/web/rate"
+ "github.com/bangumi/server/internal/web/rate/action"
)
-func TestManager_Allowed(t *testing.T) { //nolint:paralleltest
+func flushDB(t *testing.T, db *redis.Client) {
+ t.Helper()
+ test.RunAndCleanup(t, func() { require.NoError(t, db.FlushDB(context.Background()).Err()) })
+}
+
+//nolint:paralleltest
+func TestRateLimitManager_action(t *testing.T) {
+ test.RequireEnv(t, "redis")
+ db := test.GetRedis(t)
+ flushDB(t, db)
+
+ const uid model.UserID = 6
+ r := rate.New(db)
+
+ allowed, remain, err := r.AllowAction(context.TODO(), uid, action.Unknown, rate.PerHour(10))
+ require.NoError(t, err)
+ require.True(t, allowed)
+ require.EqualValues(t, 9, remain)
+}
+
+//nolint:paralleltest
+func TestRateLimitManager_Allowed(t *testing.T) {
test.RequireEnv(t, "redis")
db := test.GetRedis(t)
+ flushDB(t, db)
const ip = "0.0.0.-0"
- require.NoError(t, db.FlushDB(context.Background()).Err())
- t.Cleanup(func() { db.FlushDB(context.Background()) })
-
a, err := db.Exists(context.TODO(), rate.RedisRateKeyPrefix+ip).Result()
require.NoError(t, err)
require.Equal(t, int64(0), a)
@@ -44,32 +66,32 @@ func TestManager_Allowed(t *testing.T) { //nolint:paralleltest
rateLimiter := rate.New(db)
- allowed, remain, err := rateLimiter.Allowed(context.TODO(), ip)
+ allowed, remain, err := rateLimiter.Login(context.TODO(), ip)
require.NoError(t, err)
require.True(t, allowed)
require.Equal(t, 4, remain)
- allowed, remain, err = rateLimiter.Allowed(context.TODO(), ip)
+ allowed, remain, err = rateLimiter.Login(context.TODO(), ip)
require.NoError(t, err)
require.True(t, allowed)
require.Equal(t, 3, remain)
- allowed, remain, err = rateLimiter.Allowed(context.TODO(), ip)
+ allowed, remain, err = rateLimiter.Login(context.TODO(), ip)
require.NoError(t, err)
require.True(t, allowed)
require.Equal(t, 2, remain)
- allowed, remain, err = rateLimiter.Allowed(context.TODO(), ip)
+ allowed, remain, err = rateLimiter.Login(context.TODO(), ip)
require.NoError(t, err)
require.True(t, allowed)
require.Equal(t, 1, remain)
- allowed, remain, err = rateLimiter.Allowed(context.TODO(), ip)
+ allowed, remain, err = rateLimiter.Login(context.TODO(), ip)
require.NoError(t, err)
require.True(t, allowed)
require.Equal(t, 0, remain)
- allowed, remain, err = rateLimiter.Allowed(context.TODO(), ip)
+ allowed, remain, err = rateLimiter.Login(context.TODO(), ip)
require.NoError(t, err)
require.False(t, allowed)
require.Equal(t, 0, remain)
diff --git a/internal/web/rate_internal_test.go b/internal/web/rate_internal_test.go
new file mode 100644
index 000000000..449d573b2
--- /dev/null
+++ b/internal/web/rate_internal_test.go
@@ -0,0 +1,109 @@
+// SPDX-License-Identifier: AGPL-3.0-only
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published
+// by the Free Software Foundation, version 3.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+// See the GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see
+
+package web
+
+import (
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gofiber/fiber/v2"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+
+ "github.com/bangumi/server/internal/domain"
+ "github.com/bangumi/server/internal/mocks"
+ "github.com/bangumi/server/internal/model"
+ "github.com/bangumi/server/internal/web/accessor"
+ "github.com/bangumi/server/internal/web/rate"
+ "github.com/bangumi/server/internal/web/rate/action"
+)
+
+func Test_rateMiddleware(t *testing.T) {
+ t.Parallel()
+ app := fiber.New()
+
+ r := mocks.NewRateLimiter(t)
+ r.EXPECT().AllowAction(mock.Anything, model.UserID(1), mock.Anything, mock.Anything).
+ Return(false, 1, nil)
+
+ app.Use(rateMiddleware(r, mockBaseHandler{
+ a: &accessor.Accessor{
+ RequestID: "fake-request-id", IP: net.IPv4(1, 1, 1, 1), Auth: domain.Auth{ID: 1}, Login: true,
+ },
+ }, action.Action(0), rate.PerHour(10)))
+
+ req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
+ res, err := app.Test(req)
+ require.NoError(t, err)
+ defer res.Body.Close()
+
+ require.Equal(t, http.StatusTooManyRequests, res.StatusCode)
+}
+
+func Test_rateMiddleware_allow(t *testing.T) {
+ t.Parallel()
+ app := fiber.New()
+
+ r := mocks.NewRateLimiter(t)
+ r.EXPECT().AllowAction(mock.Anything, model.UserID(1), mock.Anything, mock.Anything).
+ Return(true, 1, nil)
+
+ app.Use(rateMiddleware(r, mockBaseHandler{
+ a: &accessor.Accessor{
+ RequestID: "fake-request-id", IP: net.IPv4(1, 1, 1, 1), Auth: domain.Auth{ID: 1}, Login: true,
+ },
+ }, action.Action(0), rate.PerHour(10)))
+
+ app.Get("/", func(c *fiber.Ctx) error {
+ return c.SendString("")
+ })
+
+ req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
+ res, err := app.Test(req)
+ require.NoError(t, err)
+ defer res.Body.Close()
+
+ require.Equal(t, http.StatusOK, res.StatusCode)
+}
+
+func Test_rateMiddleware_not_login(t *testing.T) {
+ t.Parallel()
+ app := fiber.New(fiber.Config{ErrorHandler: getDefaultErrorHandler()})
+
+ r := mocks.NewRateLimiter(t)
+
+ app.Use(rateMiddleware(r, mockBaseHandler{
+ a: &accessor.Accessor{
+ RequestID: "fake-request-id", IP: net.IPv4(1, 1, 1, 1), Auth: domain.Auth{ID: 1}, Login: false,
+ },
+ }, action.Action(0), rate.PerHour(10)))
+
+ req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
+ res, err := app.Test(req)
+ require.NoError(t, err)
+ defer res.Body.Close()
+
+ require.Equal(t, http.StatusUnauthorized, res.StatusCode)
+}
+
+type mockBaseHandler struct {
+ a *accessor.Accessor
+}
+
+func (h mockBaseHandler) GetHTTPAccessor(c *fiber.Ctx) *accessor.Accessor {
+ return h.a
+}