diff --git a/internal/mocks/RateLimiter.go b/internal/mocks/RateLimiter.go index f6f497aed..7df1d9e84 100644 --- a/internal/mocks/RateLimiter.go +++ b/internal/mocks/RateLimiter.go @@ -5,7 +5,13 @@ package mocks import ( context "context" + action "github.com/bangumi/server/internal/web/rate/action" + mock "github.com/stretchr/testify/mock" + + model "github.com/bangumi/server/internal/model" + + rate "github.com/bangumi/server/internal/web/rate" ) // RateLimiter is an autogenerated mock type for the Manager type @@ -21,8 +27,62 @@ func (_m *RateLimiter) EXPECT() *RateLimiter_Expecter { return &RateLimiter_Expecter{mock: &_m.Mock} } -// Allowed provides a mock function with given fields: ctx, ip -func (_m *RateLimiter) Allowed(ctx context.Context, ip string) (bool, int, error) { +// AllowAction provides a mock function with given fields: ctx, u, _a2, limit +func (_m *RateLimiter) AllowAction(ctx context.Context, u model.UserID, _a2 action.Action, limit rate.Limit) (bool, int, error) { + ret := _m.Called(ctx, u, _a2, limit) + + var r0 bool + if rf, ok := ret.Get(0).(func(context.Context, model.UserID, action.Action, rate.Limit) bool); ok { + r0 = rf(ctx, u, _a2, limit) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 int + if rf, ok := ret.Get(1).(func(context.Context, model.UserID, action.Action, rate.Limit) int); ok { + r1 = rf(ctx, u, _a2, limit) + } else { + r1 = ret.Get(1).(int) + } + + var r2 error + if rf, ok := ret.Get(2).(func(context.Context, model.UserID, action.Action, rate.Limit) error); ok { + r2 = rf(ctx, u, _a2, limit) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// RateLimiter_AllowAction_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AllowAction' +type RateLimiter_AllowAction_Call struct { + *mock.Call +} + +// AllowAction is a helper method to define mock.On call +// - ctx context.Context +// - u model.UserID +// - _a2 action.Action +// - limit rate.Limit +func (_e *RateLimiter_Expecter) AllowAction(ctx interface{}, u interface{}, _a2 interface{}, limit interface{}) *RateLimiter_AllowAction_Call { + return &RateLimiter_AllowAction_Call{Call: _e.mock.On("AllowAction", ctx, u, _a2, limit)} +} + +func (_c *RateLimiter_AllowAction_Call) Run(run func(ctx context.Context, u model.UserID, _a2 action.Action, limit rate.Limit)) *RateLimiter_AllowAction_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(model.UserID), args[2].(action.Action), args[3].(rate.Limit)) + }) + return _c +} + +func (_c *RateLimiter_AllowAction_Call) Return(allowed bool, remain int, err error) *RateLimiter_AllowAction_Call { + _c.Call.Return(allowed, remain, err) + return _c +} + +// Login provides a mock function with given fields: ctx, ip +func (_m *RateLimiter) Login(ctx context.Context, ip string) (bool, int, error) { ret := _m.Called(ctx, ip) var r0 bool @@ -49,26 +109,26 @@ func (_m *RateLimiter) Allowed(ctx context.Context, ip string) (bool, int, error return r0, r1, r2 } -// RateLimiter_Allowed_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Allowed' -type RateLimiter_Allowed_Call struct { +// RateLimiter_Login_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Login' +type RateLimiter_Login_Call struct { *mock.Call } -// Allowed is a helper method to define mock.On call +// Login is a helper method to define mock.On call // - ctx context.Context // - ip string -func (_e *RateLimiter_Expecter) Allowed(ctx interface{}, ip interface{}) *RateLimiter_Allowed_Call { - return &RateLimiter_Allowed_Call{Call: _e.mock.On("Allowed", ctx, ip)} +func (_e *RateLimiter_Expecter) Login(ctx interface{}, ip interface{}) *RateLimiter_Login_Call { + return &RateLimiter_Login_Call{Call: _e.mock.On("Login", ctx, ip)} } -func (_c *RateLimiter_Allowed_Call) Run(run func(ctx context.Context, ip string)) *RateLimiter_Allowed_Call { +func (_c *RateLimiter_Login_Call) Run(run func(ctx context.Context, ip string)) *RateLimiter_Login_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(string)) }) return _c } -func (_c *RateLimiter_Allowed_Call) Return(allowed bool, remain int, err error) *RateLimiter_Allowed_Call { +func (_c *RateLimiter_Login_Call) Return(allowed bool, remain int, err error) *RateLimiter_Login_Call { _c.Call.Return(allowed, remain, err) return _c } diff --git a/internal/pkg/test/web.go b/internal/pkg/test/web.go index 7521993c1..0e6dfa24e 100644 --- a/internal/pkg/test/web.go +++ b/internal/pkg/test/web.go @@ -152,7 +152,7 @@ func MockIndexRepo(repo domain.IndexRepo) fx.Option { func MockRateLimiter(repo rate.Manager) fx.Option { if repo == nil { mocker := &mocks.RateLimiter{} - mocker.EXPECT().Allowed(mock.Anything, mock.Anything).Return(true, 5, nil) //nolint:gomnd + mocker.EXPECT().Login(mock.Anything, mock.Anything).Return(true, 5, nil) //nolint:gomnd mocker.EXPECT().Reset(mock.Anything, mock.Anything).Return(nil) repo = mocker diff --git a/internal/web/handler/auth.go b/internal/web/handler/auth.go index b1e28b7d7..7e0250a0e 100644 --- a/internal/web/handler/auth.go +++ b/internal/web/handler/auth.go @@ -69,7 +69,7 @@ func (h Handler) PrivateLogin(c *fiber.Ctx) error { } a := h.GetHTTPAccessor(c) - allowed, remain, err := h.rateLimit.Allowed(c.Context(), a.IP.String()) + allowed, remain, err := h.rateLimit.Login(c.Context(), a.IP.String()) if err != nil { return h.InternalError(c, err, "failed to apply rate limit", a.Log()) } diff --git a/internal/web/rate.go b/internal/web/rate.go new file mode 100644 index 000000000..0bc102007 --- /dev/null +++ b/internal/web/rate.go @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see + +package web + +import ( + "net/http" + + "github.com/gofiber/fiber/v2" + + "github.com/bangumi/server/internal/pkg/errgo" + "github.com/bangumi/server/internal/web/accessor" + "github.com/bangumi/server/internal/web/rate" + "github.com/bangumi/server/internal/web/rate/action" + "github.com/bangumi/server/internal/web/res" +) + +type baseHandler interface { + GetHTTPAccessor(c *fiber.Ctx) *accessor.Accessor +} + +// rateMiddleware require Handler.NeedLogin before this middleware. +func rateMiddleware(r rate.Manager, h baseHandler, action action.Action, limit rate.Limit) fiber.Handler { + return func(c *fiber.Ctx) error { + a := h.GetHTTPAccessor(c) + if !a.Login { + return res.Unauthorized("login required") + } + + allowed, _, err := r.AllowAction(c.Context(), a.ID, action, limit) + if err != nil { + return errgo.Wrap(err, "rate.Manager.AllowAction") + } + if !allowed { + return c.SendStatus(http.StatusTooManyRequests) + } + + return c.Next() + } +} diff --git a/internal/web/rate/action/action.go b/internal/web/rate/action/action.go new file mode 100644 index 000000000..24814f2d2 --- /dev/null +++ b/internal/web/rate/action/action.go @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see + +package action + +type Action uint8 + +const ( + Unknown Action = 0 + Login Action = 1 +) diff --git a/internal/web/rate/allow.lua b/internal/web/rate/allow.lua index 0d64cbfa0..547a2e6f9 100644 --- a/internal/web/rate/allow.lua +++ b/internal/web/rate/allow.lua @@ -65,7 +65,7 @@ local diff = now - allow_at local remaining = diff / emission_interval if remaining < 0 then - redis.call('SET', ban_key, "1", "EX", ban_expire) -- ban key + redis.call('SET', ban_key, 1, "EX", ban_expire) -- ban key local reset_after = tat - now local retry_after = diff * -1 diff --git a/internal/web/rate/new.go b/internal/web/rate/new.go index 38b8f3268..ce804dd02 100644 --- a/internal/web/rate/new.go +++ b/internal/web/rate/new.go @@ -23,8 +23,10 @@ import ( "github.com/go-redis/redis/v8" + "github.com/bangumi/server/internal/model" "github.com/bangumi/server/internal/pkg/errgo" "github.com/bangumi/server/internal/pkg/gtime" + "github.com/bangumi/server/internal/web/rate/action" ) const defaultAllowPerHour = 5 @@ -35,8 +37,15 @@ var allowLua string var allowScript = redis.NewScript(allowLua) //nolint:gochecknoglobals type Manager interface { - // Allowed 检查是否允许登录。 - Allowed(ctx context.Context, ip string) (allowed bool, remain int, err error) + // Login 检查是登录限流。 + Login(ctx context.Context, ip string) (allowed bool, remain int, err error) + + AllowAction( + ctx context.Context, + u model.UserID, + action action.Action, + limit Limit, + ) (allowed bool, remain int, err error) // Reset 登录成功时应该重置计数。 Reset(ctx context.Context, ip string) error } @@ -54,28 +63,33 @@ type manager struct { r *redis.Client } -func (m manager) Allowed(ctx context.Context, ip string) (bool, int, error) { - var banKey = RedisBanKeyPrefix + ip - result, err := m.r.Exists(ctx, banKey, "1").Result() +func (m manager) AllowAction( + ctx context.Context, + u model.UserID, + action action.Action, + limit Limit, +) (bool, int, error) { + rateKey := fmt.Sprintf("chii:rate:%d:%d", action, u) + banKey := fmt.Sprintf("chii:rate:ban:%d:%d", action, u) + + res, err := m.allow(ctx, rateKey, banKey, limit) if err != nil { - return false, 0, errgo.Wrap(err, "redis.Exists") + return false, 0, errgo.Wrap(err, "Limiter.Allow") } - if result == 1 { - return false, 0, nil - } + return res.Allowed > 0, res.Remaining, nil +} + +func (m manager) Login(ctx context.Context, ip string) (bool, int, error) { + var rateKey = RedisRateKeyPrefix + ip + var banKey = RedisBanKeyPrefix + ip - res, err := m.allow(ctx, RedisRateKeyPrefix+ip, PerHour(defaultAllowPerHour)) + res, err := m.allow(ctx, rateKey, banKey, PerHour(defaultAllowPerHour)) if err != nil { return false, 0, errgo.Wrap(err, "Limiter.Allow") } - if res.Allowed <= 0 { - err := m.r.Set(ctx, banKey, "1", gtime.OneWeek).Err() - return false, 0, errgo.Wrap(err, "redis.Set") - } - - return true, res.Remaining, nil + return res.Allowed > 0, res.Remaining, nil } func (m manager) Reset(ctx context.Context, ip string) error { @@ -87,11 +101,12 @@ func (m manager) Reset(ctx context.Context, ip string) error { // AllowN reports whether n events may happen at time now. func (m manager) allow( ctx context.Context, - ip string, + rateKey string, + banKey string, limit Limit, ) (Result, error) { now := time.Now() - var keys = []string{RedisRateKeyPrefix + ip, RedisBanKeyPrefix + ip} + var keys = []string{rateKey, banKey} var values = []any{ limit.Burst, limit.Rate, limit.Period.Seconds(), now.Unix(), now.Nanosecond() / 1000, gtime.OneWeekSec, } diff --git a/internal/web/rate/new_test.go b/internal/web/rate/new_test.go index 3b3adccd7..d8567cfa3 100644 --- a/internal/web/rate/new_test.go +++ b/internal/web/rate/new_test.go @@ -18,22 +18,44 @@ import ( "context" "testing" + "github.com/go-redis/redis/v8" "github.com/stretchr/testify/require" + "github.com/bangumi/server/internal/model" "github.com/bangumi/server/internal/pkg/test" "github.com/bangumi/server/internal/web/rate" + "github.com/bangumi/server/internal/web/rate/action" ) -func TestManager_Allowed(t *testing.T) { //nolint:paralleltest +func flushDB(t *testing.T, db *redis.Client) { + t.Helper() + test.RunAndCleanup(t, func() { require.NoError(t, db.FlushDB(context.Background()).Err()) }) +} + +//nolint:paralleltest +func TestRateLimitManager_action(t *testing.T) { + test.RequireEnv(t, "redis") + db := test.GetRedis(t) + flushDB(t, db) + + const uid model.UserID = 6 + r := rate.New(db) + + allowed, remain, err := r.AllowAction(context.TODO(), uid, action.Unknown, rate.PerHour(10)) + require.NoError(t, err) + require.True(t, allowed) + require.EqualValues(t, 9, remain) +} + +//nolint:paralleltest +func TestRateLimitManager_Allowed(t *testing.T) { test.RequireEnv(t, "redis") db := test.GetRedis(t) + flushDB(t, db) const ip = "0.0.0.-0" - require.NoError(t, db.FlushDB(context.Background()).Err()) - t.Cleanup(func() { db.FlushDB(context.Background()) }) - a, err := db.Exists(context.TODO(), rate.RedisRateKeyPrefix+ip).Result() require.NoError(t, err) require.Equal(t, int64(0), a) @@ -44,32 +66,32 @@ func TestManager_Allowed(t *testing.T) { //nolint:paralleltest rateLimiter := rate.New(db) - allowed, remain, err := rateLimiter.Allowed(context.TODO(), ip) + allowed, remain, err := rateLimiter.Login(context.TODO(), ip) require.NoError(t, err) require.True(t, allowed) require.Equal(t, 4, remain) - allowed, remain, err = rateLimiter.Allowed(context.TODO(), ip) + allowed, remain, err = rateLimiter.Login(context.TODO(), ip) require.NoError(t, err) require.True(t, allowed) require.Equal(t, 3, remain) - allowed, remain, err = rateLimiter.Allowed(context.TODO(), ip) + allowed, remain, err = rateLimiter.Login(context.TODO(), ip) require.NoError(t, err) require.True(t, allowed) require.Equal(t, 2, remain) - allowed, remain, err = rateLimiter.Allowed(context.TODO(), ip) + allowed, remain, err = rateLimiter.Login(context.TODO(), ip) require.NoError(t, err) require.True(t, allowed) require.Equal(t, 1, remain) - allowed, remain, err = rateLimiter.Allowed(context.TODO(), ip) + allowed, remain, err = rateLimiter.Login(context.TODO(), ip) require.NoError(t, err) require.True(t, allowed) require.Equal(t, 0, remain) - allowed, remain, err = rateLimiter.Allowed(context.TODO(), ip) + allowed, remain, err = rateLimiter.Login(context.TODO(), ip) require.NoError(t, err) require.False(t, allowed) require.Equal(t, 0, remain) diff --git a/internal/web/rate_internal_test.go b/internal/web/rate_internal_test.go new file mode 100644 index 000000000..449d573b2 --- /dev/null +++ b/internal/web/rate_internal_test.go @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: AGPL-3.0-only +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published +// by the Free Software Foundation, version 3. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see + +package web + +import ( + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gofiber/fiber/v2" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/bangumi/server/internal/domain" + "github.com/bangumi/server/internal/mocks" + "github.com/bangumi/server/internal/model" + "github.com/bangumi/server/internal/web/accessor" + "github.com/bangumi/server/internal/web/rate" + "github.com/bangumi/server/internal/web/rate/action" +) + +func Test_rateMiddleware(t *testing.T) { + t.Parallel() + app := fiber.New() + + r := mocks.NewRateLimiter(t) + r.EXPECT().AllowAction(mock.Anything, model.UserID(1), mock.Anything, mock.Anything). + Return(false, 1, nil) + + app.Use(rateMiddleware(r, mockBaseHandler{ + a: &accessor.Accessor{ + RequestID: "fake-request-id", IP: net.IPv4(1, 1, 1, 1), Auth: domain.Auth{ID: 1}, Login: true, + }, + }, action.Action(0), rate.PerHour(10))) + + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + res, err := app.Test(req) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusTooManyRequests, res.StatusCode) +} + +func Test_rateMiddleware_allow(t *testing.T) { + t.Parallel() + app := fiber.New() + + r := mocks.NewRateLimiter(t) + r.EXPECT().AllowAction(mock.Anything, model.UserID(1), mock.Anything, mock.Anything). + Return(true, 1, nil) + + app.Use(rateMiddleware(r, mockBaseHandler{ + a: &accessor.Accessor{ + RequestID: "fake-request-id", IP: net.IPv4(1, 1, 1, 1), Auth: domain.Auth{ID: 1}, Login: true, + }, + }, action.Action(0), rate.PerHour(10))) + + app.Get("/", func(c *fiber.Ctx) error { + return c.SendString("") + }) + + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + res, err := app.Test(req) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusOK, res.StatusCode) +} + +func Test_rateMiddleware_not_login(t *testing.T) { + t.Parallel() + app := fiber.New(fiber.Config{ErrorHandler: getDefaultErrorHandler()}) + + r := mocks.NewRateLimiter(t) + + app.Use(rateMiddleware(r, mockBaseHandler{ + a: &accessor.Accessor{ + RequestID: "fake-request-id", IP: net.IPv4(1, 1, 1, 1), Auth: domain.Auth{ID: 1}, Login: false, + }, + }, action.Action(0), rate.PerHour(10))) + + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + res, err := app.Test(req) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusUnauthorized, res.StatusCode) +} + +type mockBaseHandler struct { + a *accessor.Accessor +} + +func (h mockBaseHandler) GetHTTPAccessor(c *fiber.Ctx) *accessor.Accessor { + return h.a +}