From a017b679f8de6ca24f2e8c3a90db39095a7fc171 Mon Sep 17 00:00:00 2001 From: Martin Kobetic Date: Fri, 8 Sep 2023 10:48:43 -0400 Subject: [PATCH] IP based rate limiting (#289) * IP based rate limiting * deploy to dev * lint * metric tag fix * SIGUSR debug level toggle * fix debug level toggling * more debug info * Fix rebase issues * Log at info level when rate limited * Force new deploy * Reset deployments to main --------- Co-authored-by: Steven Normore Co-authored-by: Nicholas Molnar <65710+neekolas@users.noreply.github.com> --- cmd/xmtpd/main.go | 24 +++- dev/aws-shell | 2 +- pkg/api/config.go | 36 +++++- pkg/api/interceptor.go | 110 ++++++++++++---- pkg/api/server.go | 9 +- pkg/api/server_test.go | 94 +++++++++++--- pkg/api/setup_test.go | 4 +- pkg/api/telemetry.go | 8 +- pkg/authz/wallet_allow_lister.go | 23 +++- pkg/e2e/test_messagev1.go | 6 +- pkg/logging/debug.go | 4 +- pkg/logging/logging.go | 3 + pkg/metrics/api-limits.go | 49 +++++++ pkg/metrics/metrics.go | 2 + pkg/ratelimiter/buckets.go | 83 ++++++++++++ pkg/ratelimiter/rate_limiter.go | 184 +++++++++++++++++---------- pkg/ratelimiter/rate_limiter_test.go | 157 +++++++++++++++-------- 17 files changed, 611 insertions(+), 187 deletions(-) create mode 100644 pkg/metrics/api-limits.go create mode 100644 pkg/ratelimiter/buckets.go diff --git a/cmd/xmtpd/main.go b/cmd/xmtpd/main.go index a0b5b346..7ade57aa 100644 --- a/cmd/xmtpd/main.go +++ b/cmd/xmtpd/main.go @@ -51,7 +51,7 @@ func main() { return } - log, err := buildLogger(options) + log, logCfg, err := buildLogger(options) if err != nil { fatal("Could not build logger: %s", err) } @@ -150,6 +150,20 @@ func main() { doneC <- true }) + // Toggle debug level on SIGUSR1 + sigToggleC := make(chan os.Signal, 1) + signal.Notify(sigToggleC, syscall.SIGUSR1) + go func() { + for range sigToggleC { + log.Info("toggling debug level") + newLevel := zapcore.DebugLevel + if logCfg.Level.Enabled(zapcore.DebugLevel) { + newLevel = zapcore.InfoLevel + } + 
logCfg.Level.SetLevel(newLevel) + } + }() + sigC := make(chan os.Signal, 1) signal.Notify(sigC, syscall.SIGHUP, @@ -193,12 +207,12 @@ func initWakuLogging(options server.Options) (func(), error) { return cleanup, nil } -func buildLogger(options server.Options) (*zap.Logger, error) { +func buildLogger(options server.Options) (*zap.Logger, *zap.Config, error) { atom := zap.NewAtomicLevel() level := zapcore.InfoLevel err := level.Set(options.LogLevel) if err != nil { - return nil, err + return nil, nil, err } atom.SetLevel(level) @@ -219,10 +233,10 @@ func buildLogger(options server.Options) (*zap.Logger, error) { } log, err := cfg.Build() if err != nil { - return nil, err + return nil, nil, err } log = log.Named("xmtpd") - return log, nil + return log, &cfg, nil } diff --git a/dev/aws-shell b/dev/aws-shell index 6c35973c..6850141e 100755 --- a/dev/aws-shell +++ b/dev/aws-shell @@ -8,7 +8,7 @@ region="${REGION:-us-east-2}" cluster="${ENV:-dev}" -task="${TASK:-node-0}" +task="${TASK:-group1-node-0}" container="${CONTAINER:-$task}" task=$(aws --region "$region" \ --query 'taskArns[0]' \ diff --git a/pkg/api/config.go b/pkg/api/config.go index d20ac9d0..68c6e528 100644 --- a/pkg/api/config.go +++ b/pkg/api/config.go @@ -32,11 +32,39 @@ type Config struct { Log *zap.Logger } -// Options bundle command line options associated with the authn package. +// AuthnOptions bundle command line options associated with the authn package. type AuthnOptions struct { - Enable bool `long:"enable" description:"require client authentication via wallet tokens"` - EnableV3 bool `long:"enable-v3" description:"require client authentication for V3"` - Ratelimits bool `long:"ratelimits" description:"apply rate limits per wallet"` + /* + Enable is the master switch for the authentication module. + If it is false then the other options in this group are ignored. + + The module enforces authentication for requests that require it (currently Publish only). 
+ Authenticated requests will be permitted according to the rules of the request type, + (i.e. you can't publish into other wallets' contact and private topics). + */ + Enable bool `long:"enable" description:"require client authentication via wallet tokens"` + EnableV3 bool `long:"enable-v3" description:"require client authentication for V3"` + /* + Ratelimits enables request rate limiting. + + Requests are bucketed by client IP address and request type (there is one bucket for all requests without IPs). + Each bucket is allocated a number of tokens that are refilled at a fixed rate per minute + up to a given maximum number of tokens. + Requests cost 1 token by default, except Publish requests cost the number of Envelopes carried + and BatchQuery requests cost the number of queries carried. + The limits depend on request type, e.g. Publish requests get lower limits than other types of request. + If Allowlists is also true then requests with Bearer tokens from wallets explicitly Allowed get priority, + i.e. a predefined multiple of the configured limit. + Priority wallets get separate IP buckets from regular wallets. + */ + Ratelimits bool `long:"ratelimits" description:"apply rate limits per client IP address"` + /* + Allowlists enables wallet allow lists. + + All requests that require authentication (currently Publish only) will be rejected + for wallets that are set as Denied in the allow list. + Wallets that are explicitly Allowed will get priority rate limits if Ratelimits is true. 
+ */ AllowLists bool `long:"allowlists" description:"apply higher limits for allow listed wallets (requires authz and ratelimits)"` PrivilegedAddresses []string `long:"privileged-address" description:"allow this address to publish into other user's topics"` } diff --git a/pkg/api/interceptor.go b/pkg/api/interceptor.go index 170a7b1e..81d7c2d3 100644 --- a/pkg/api/interceptor.go +++ b/pkg/api/interceptor.go @@ -13,6 +13,7 @@ import ( messagev1 "github.com/xmtp/proto/v3/go/message_api/v1" "github.com/xmtp/xmtp-node-go/pkg/logging" + "github.com/xmtp/xmtp-node-go/pkg/ratelimiter" "github.com/xmtp/xmtp-node-go/pkg/types" ) @@ -45,9 +46,22 @@ func (wa *WalletAuthorizer) Unary() grpc.UnaryServerInterceptor { info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, ) (interface{}, error) { + if !wa.Ratelimits && !wa.requiresAuthorization(req) { + return handler(ctx, req) + } + wallet, authErr := wa.getWallet(ctx) + + if wa.Ratelimits { + if err := wa.applyLimits(ctx, info.FullMethod, req, wallet); err != nil { + return nil, err + } + } if wa.requiresAuthorization(req) { - if err := wa.authorize(ctx, req); err != nil { + if authErr != nil { + return nil, status.Error(codes.Unauthenticated, authErr.Error()) + } + if err := wa.authorize(ctx, req, wallet); err != nil { return nil, err } } @@ -62,14 +76,20 @@ func (wa *WalletAuthorizer) Stream() grpc.StreamServerInterceptor { info *grpc.StreamServerInfo, handler grpc.StreamHandler, ) error { - // TODO(mk): Add metrics + if wa.Ratelimits { + ctx := stream.Context() + wallet, _ := wa.getWallet(ctx) + if err := wa.applyLimits(ctx, info.FullMethod, nil, wallet); err != nil { + return err + } + } return handler(srv, stream) } } func (wa *WalletAuthorizer) isProtocolVersion3(request *messagev1.PublishRequest) bool { envelopes := request.Envelopes - if envelopes == nil || len(envelopes) == 0 { + if len(envelopes) == 0 { return false } // If any of the envelopes are not for a v3 topic, then we treat the request as non-v3 @@ -86,34 
+106,37 @@ func (wa *WalletAuthorizer) requiresAuthorization(req interface{}) bool { return isPublish && (!wa.isProtocolVersion3(publishRequest) || wa.AuthnConfig.EnableV3) } -func (wa *WalletAuthorizer) authorize(ctx context.Context, req interface{}) error { +func (wa *WalletAuthorizer) getWallet(ctx context.Context) (types.WalletAddr, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { - return status.Errorf(codes.Unauthenticated, "metadata is not provided") + return "", status.Errorf(codes.Unauthenticated, "metadata is not provided") } values := md.Get(authorizationMetadataKey) if len(values) == 0 { - return status.Errorf(codes.Unauthenticated, "authorization token is not provided") + return "", status.Errorf(codes.Unauthenticated, "authorization token is not provided") } words := strings.SplitN(values[0], " ", 2) if len(words) != 2 { - return status.Errorf(codes.Unauthenticated, "invalid authorization header") + return "", status.Errorf(codes.Unauthenticated, "invalid authorization header") } if scheme := strings.TrimSpace(words[0]); scheme != "Bearer" { - return status.Errorf(codes.Unauthenticated, "unrecognized authorization scheme %s", scheme) + return "", status.Errorf(codes.Unauthenticated, "unrecognized authorization scheme %s", scheme) } token, err := decodeAuthToken(strings.TrimSpace(words[1])) if err != nil { - return status.Errorf(codes.Unauthenticated, "extracting token: %s", err) + return "", status.Errorf(codes.Unauthenticated, "extracting token: %s", err) } wallet, err := validateToken(ctx, wa.Log, token, time.Now()) if err != nil { - return status.Errorf(codes.Unauthenticated, "validating token: %s", err) + return "", status.Errorf(codes.Unauthenticated, "validating token: %s", err) } + return wallet, nil +} +func (wa *WalletAuthorizer) authorize(ctx context.Context, req interface{}, wallet types.WalletAddr) error { if pub, isPublish := req.(*messagev1.PublishRequest); isPublish { for _, env := range pub.Envelopes { if 
!wa.privilegedAddresses[wallet] && !allowedToPublish(env.ContentTopic, wallet) { @@ -121,31 +144,59 @@ func (wa *WalletAuthorizer) authorize(ctx context.Context, req interface{}) erro } } } - - return wa.authorizeWallet(ctx, wallet) -} - -func (wa *WalletAuthorizer) authorizeWallet(ctx context.Context, wallet types.WalletAddr) error { - // * for limit exhaustion return status.Errorf(codes.ResourceExhausted, ...) - // * for other authorization failure return status.Errorf(codes.PermissionDenied, ...) - - var allowListed bool if wa.AllowLists { if wa.AllowLister.IsDenyListed(wallet.String()) { wa.Log.Debug("wallet deny listed", logging.WalletAddress(wallet.String())) return status.Errorf(codes.PermissionDenied, ErrDenyListed.Error()) } - allowListed = wa.AllowLister.IsAllowListed(wallet.String()) } + return nil +} - if !wa.Ratelimits { - return nil +func (wa *WalletAuthorizer) applyLimits(ctx context.Context, fullMethod string, req interface{}, wallet types.WalletAddr) error { + // * for limit exhaustion return status.Errorf(codes.ResourceExhausted, ...) + // * for other authorization failure return status.Errorf(codes.PermissionDenied, ...) 
+ _, method := splitMethodName(fullMethod) + + ip := clientIPFromContext(ctx) + if len(ip) == 0 { + // requests without an IP address are bucketed together as "ip_unknown" + ip = "ip_unknown" + } + + // with no wallet apply regular limits + var isPriority bool + if len(wallet) > 0 && wa.AllowLists { + isPriority = wa.AllowLister.IsAllowListed(wallet.String()) + } + cost := 1 + limitType := ratelimiter.DEFAULT + switch req := req.(type) { + case *messagev1.PublishRequest: + cost = len(req.Envelopes) + limitType = ratelimiter.PUBLISH + case *messagev1.BatchQueryRequest: + cost = len(req.Requests) } - err := wa.Limiter.Spend(wallet.String(), allowListed) + // need to separate the IP buckets between priority and regular wallets + var bucket string + if isPriority { + bucket = "P" + ip + string(limitType) + } else { + bucket = "R" + ip + string(limitType) + } + err := wa.Limiter.Spend(limitType, bucket, uint16(cost), isPriority) if err == nil { return nil } - wa.Log.Debug("wallet rate limited", logging.WalletAddress(wallet.String())) + + wa.Log.Info("rate limited", + logging.String("client_ip", ip), + logging.WalletAddress(wallet.String()), + logging.Bool("priority", isPriority), + logging.String("method", method), + logging.String("limit", string(limitType)), + logging.Int("cost", cost)) return status.Errorf(codes.ResourceExhausted, err.Error()) } @@ -196,3 +247,14 @@ func allowedToPublish(topic string, wallet types.WalletAddr) bool { return true } + +func clientIPFromContext(ctx context.Context) string { + md, _ := metadata.FromIncomingContext(ctx) + vals := md.Get("x-forwarded-for") + if len(vals) == 0 { + return "" + } + // There are potentially multiple comma separated IPs bundled in that first value + ips := strings.Split(vals[0], ",") + return strings.TrimSpace(ips[0]) +} diff --git a/pkg/api/server.go b/pkg/api/server.go index 937db7ad..f6b7948e 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -7,6 +7,7 @@ import ( "net/http" "strings" "sync" + "time" 
middleware "github.com/grpc-ecosystem/go-grpc-middleware" prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" @@ -87,9 +88,15 @@ func (s *Server) startGRPC() error { stream = append(stream, telemetryInterceptor.Stream()) if s.Config.Authn.Enable { + limiter := ratelimiter.NewTokenBucketRateLimiter(s.ctx, s.Log) + // Expire buckets after 1 hour of inactivity, + // sweep for expired buckets every 10 minutes. + // Note: entry expiration should be at least some multiple of + // maximum (limit max / limit rate) minutes. + go limiter.Janitor(10*time.Minute, 1*time.Hour) s.authorizer = NewWalletAuthorizer(&AuthnConfig{ AuthnOptions: s.Config.Authn, - Limiter: ratelimiter.NewTokenBucketRateLimiter(s.Log), + Limiter: limiter, AllowLister: s.Config.AllowLister, Log: s.Log.Named("authn"), }) diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go index 4af685fe..9ac580a4 100644 --- a/pkg/api/server_test.go +++ b/pkg/api/server_test.go @@ -14,6 +14,7 @@ import ( messageV1 "github.com/xmtp/proto/v3/go/message_api/v1" messagev1api "github.com/xmtp/xmtp-node-go/pkg/api/message/v1" messageclient "github.com/xmtp/xmtp-node-go/pkg/api/message/v1/client" + "github.com/xmtp/xmtp-node-go/pkg/ratelimiter" test "github.com/xmtp/xmtp-node-go/pkg/testing" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -755,27 +756,88 @@ func Test_Publish_DenyListed(t *testing.T) { require.NoError(t, err) publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{}) - expectWalletDenied(t, err) + requireErrorEqual(t, err, codes.PermissionDenied, "wallet is deny listed") require.Nil(t, publishRes) }) } -func expectWalletDenied(t *testing.T, err error) { +func Test_Ratelimits_Regular(t *testing.T) { + ctx := withAuth(t, context.Background()) + testGRPCAndHTTP(t, ctx, func(t *testing.T, client messageclient.Client, server *Server) { + server.authorizer.Ratelimits = true + limiter, ok := server.authorizer.Limiter.(*ratelimiter.TokenBucketRateLimiter) + require.True(t, ok) + 
limiter.Limits[ratelimiter.PUBLISH] = &ratelimiter.Limit{MaxTokens: 1, RatePerMinute: 0} + envs := makeEnvelopes(2) + _, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[0:1]}) + require.NoError(t, err) + _, err = client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[1:2]}) + require.Error(t, err) + errMsg := "1 exceeds rate limit R" + if _, ok := status.FromError(err); ok { + // GRPC + errMsg += "ip_unknownPUB" + } else { + // HTTP + errMsg += "127.0.0.1PUB" + } + requireErrorEqual(t, err, codes.ResourceExhausted, errMsg) + // check that Query is not affected by publish quota + _, err = client.Query(ctx, &messageV1.QueryRequest{ContentTopics: []string{"topic"}}) + require.NoError(t, err) + }) +} + +func Test_Ratelimits_Priority(t *testing.T) { + token, data, err := generateV2AuthToken(time.Now()) + require.NoError(t, err) + et, err := EncodeAuthToken(token) + require.NoError(t, err) + ctx := metadata.AppendToOutgoingContext(context.Background(), authorizationMetadataKey, "Bearer "+et) + + testGRPCAndHTTP(t, ctx, func(t *testing.T, client messageclient.Client, server *Server) { + err := server.AllowLister.Allow(ctx, data.WalletAddr) + require.NoError(t, err) + server.authorizer.Ratelimits = true + limiter, ok := server.authorizer.Limiter.(*ratelimiter.TokenBucketRateLimiter) + require.True(t, ok) + limiter.Limits[ratelimiter.PUBLISH] = &ratelimiter.Limit{MaxTokens: 1, RatePerMinute: 0} + limiter.PriorityMultiplier = 2 + envs := makeEnvelopes(3) + _, err = client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[0:2]}) + require.NoError(t, err) + _, err = client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[2:3]}) + require.Error(t, err) + errMsg := "1 exceeds rate limit P" + if _, ok := status.FromError(err); ok { + // GRPC + errMsg += "ip_unknownPUB" + } else { + // HTTP + errMsg += "127.0.0.1PUB" + } + requireErrorEqual(t, err, codes.ResourceExhausted, errMsg) + // check that query is not affected by publish quota + _, 
err = client.Query(ctx, &messageV1.QueryRequest{ContentTopics: []string{"topic"}}) + require.NoError(t, err) + }) +} + +func requireErrorEqual(t *testing.T, err error, code codes.Code, msg string, details ...interface{}) { + require.Error(t, err) grpcErr, ok := status.FromError(err) - if ok { - require.Equal(t, codes.PermissionDenied, grpcErr.Code()) - require.Equal(t, "wallet is deny listed", grpcErr.Message()) - } else { + if ok { // GRPC + require.Equal(t, code, grpcErr.Code()) + require.Equal(t, msg, grpcErr.Message()) + require.ElementsMatch(t, details, grpcErr.Details()) + } else { // HTTP parts := strings.SplitN(err.Error(), ": ", 2) - reason, msgJSON := parts[0], parts[1] - require.Equal(t, "403 Forbidden", reason) - var msg map[string]interface{} - err := json.Unmarshal([]byte(msgJSON), &msg) - require.NoError(t, err) - require.Equal(t, map[string]interface{}{ - "code": float64(codes.PermissionDenied), - "message": "wallet is deny listed", - "details": []interface{}{}, - }, msg) + _, errJSON := parts[0], parts[1] + var httpErr map[string]interface{} + err := json.Unmarshal([]byte(errJSON), &httpErr) + require.NoError(t, err) + require.Equal(t, float64(code), httpErr["code"]) + require.Contains(t, msg, httpErr["message"]) + require.ElementsMatch(t, details, httpErr["details"]) } } diff --git a/pkg/api/setup_test.go b/pkg/api/setup_test.go index ea6ab8cb..5dd87da9 100644 --- a/pkg/api/setup_test.go +++ b/pkg/api/setup_test.go @@ -100,7 +100,7 @@ func testGRPCAndHTTP(t *testing.T, ctx context.Context, f func(*testing.T, messa t.Run("grpc", func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) defer cancel() server, cleanup := newTestServer(t) defer cleanup() @@ -127,7 +127,7 @@ func testGRPCAndHTTP(t *testing.T, ctx context.Context, f func(*testing.T, messa func testGRPC(t *testing.T, ctx context.Context, f func(*testing.T, messageclient.Client, *Server)) { t.Parallel() - ctx, 
cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) defer cancel() server, cleanup := newTestServer(t) defer cleanup() diff --git a/pkg/api/telemetry.go b/pkg/api/telemetry.go index ec5ea13c..59aeab1d 100644 --- a/pkg/api/telemetry.go +++ b/pkg/api/telemetry.go @@ -10,7 +10,6 @@ import ( "go.uber.org/zap/zapcore" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) @@ -61,11 +60,8 @@ func (ti *TelemetryInterceptor) record(ctx context.Context, fullMethod string, e ri.ZapFields()..., ) - md, _ := metadata.FromIncomingContext(ctx) - if ips := md.Get("x-forwarded-for"); len(ips) > 0 { - // There are potentially multiple comma separated IPs bundled in that first value - ips := strings.Split(ips[0], ",") - fields = append(fields, zap.String("client_ip", strings.TrimSpace(ips[0]))) + if ip := clientIPFromContext(ctx); len(ip) > 0 { + fields = append(fields, zap.String("client_ip", ip)) } logFn := ti.log.Debug diff --git a/pkg/authz/wallet_allow_lister.go b/pkg/authz/wallet_allow_lister.go index dbe696fd..4476ff18 100644 --- a/pkg/authz/wallet_allow_lister.go +++ b/pkg/authz/wallet_allow_lister.go @@ -43,6 +43,7 @@ type WalletAllowLister interface { IsDenyListed(walletAddress string) bool GetPermissions(walletAddress string) Permission Deny(ctx context.Context, WalletAddress string) error + Allow(ctx context.Context, WalletAddress string) error } // DatabaseWalletAllowLister implements database backed allow list. @@ -87,9 +88,18 @@ func (d *DatabaseWalletAllowLister) IsDenyListed(walletAddress string) bool { // Add an address to the deny list. func (d *DatabaseWalletAllowLister) Deny(ctx context.Context, walletAddress string) error { + return d.Apply(ctx, walletAddress, Denied) +} + +// Add an address to the allow list. 
+func (d *DatabaseWalletAllowLister) Allow(ctx context.Context, walletAddress string) error { + return d.Apply(ctx, walletAddress, Allowed) +} + +func (d *DatabaseWalletAllowLister) Apply(ctx context.Context, walletAddress string, permission Permission) error { wallet := WalletAddress{ WalletAddress: walletAddress, - Permission: "deny", + Permission: unmapPermission(permission), } _, err := d.db.NewInsert().Model(&wallet).Exec(ctx) if err != nil { @@ -167,6 +177,17 @@ func mapPermission(permission string) Permission { } } +func unmapPermission(permission Permission) string { + switch permission { + case Allowed: + return "allow" + case Denied: + return "deny" + default: + return "unspecified" + } +} + func (d *DatabaseWalletAllowLister) listenForChanges() { ticker := time.NewTicker(d.refreshInterval) diff --git a/pkg/e2e/test_messagev1.go b/pkg/e2e/test_messagev1.go index d2aad1d0..307b2622 100644 --- a/pkg/e2e/test_messagev1.go +++ b/pkg/e2e/test_messagev1.go @@ -36,7 +36,7 @@ func (s *Suite) testMessageV1PublishSubscribeQuery(log *zap.Logger) error { defer cancel() ctx, err := s.withAuth(ctx) if err != nil { - return err + return errors.Wrap(err, "adding auth token") } // Subscribe across nodes. @@ -70,7 +70,7 @@ syncLoop: }, }) if err != nil { - return errors.Wrap(err, "publishing") + return errors.Wrap(err, "publishing sync envelope") } syncEnvs = append(syncEnvs, syncEnv) @@ -87,7 +87,7 @@ syncLoop: prevSyncEnvs[string(syncEnv.Message)] = true continue syncLoop } - return err + return errors.Wrap(err, "reading sync envelope") } if prevSyncEnvs[string(env.Message)] { s.log.Info("skipping previous sync envelope", zap.String("value", string(env.Message))) diff --git a/pkg/logging/debug.go b/pkg/logging/debug.go index 6052697c..3bfff342 100644 --- a/pkg/logging/debug.go +++ b/pkg/logging/debug.go @@ -27,13 +27,13 @@ func IfDebug(field zap.Field) zap.Field { } // ToggleDebugLevel toggles the log level between DEBUG and INFO level. 
-func ToggleDebugLevel() { +func ToggleDebugLevel() error { levelWaku := "DEBUG" levelLibP2P := libp2p.LevelDebug if IsDebugLevel() { levelWaku = "INFO" levelLibP2P = libp2p.LevelInfo } - _ = waku.SetLogLevel(levelWaku) libp2p.SetAllLoggers(levelLibP2P) + return waku.SetLogLevel(levelWaku) } diff --git a/pkg/logging/logging.go b/pkg/logging/logging.go index 12673d9c..c1a1cf05 100644 --- a/pkg/logging/logging.go +++ b/pkg/logging/logging.go @@ -27,6 +27,9 @@ var ( ENode = logging.ENode TCPAddr = logging.TCPAddr UDPAddr = logging.UDPAddr + String = zap.String + Bool = zap.Bool + Int = zap.Int ) // WalletAddress creates a field for a wallet address. diff --git a/pkg/metrics/api-limits.go b/pkg/metrics/api-limits.go new file mode 100644 index 00000000..37866a16 --- /dev/null +++ b/pkg/metrics/api-limits.go @@ -0,0 +1,49 @@ +package metrics + +import ( + "context" + + "github.com/xmtp/xmtp-node-go/pkg/logging" + "go.opencensus.io/stats" + "go.opencensus.io/stats/view" + "go.opencensus.io/tag" + "go.uber.org/zap" +) + +var bucketsNameKey = newTagKey("name") + +var ratelimiterBucketsGaugeMeasure = stats.Int64("ratelimiter_buckets", "size of ratelimiter buckets map", stats.UnitDimensionless) +var ratelimiterBucketsGaugeView = &view.View{ + Name: "xmtp_ratelimiter_buckets", + Measure: ratelimiterBucketsGaugeMeasure, + Description: "Size of rate-limiter buckets maps", + Aggregation: view.LastValue(), + TagKeys: []tag.Key{bucketsNameKey}, +} + +func EmitRatelimiterBucketsSize(ctx context.Context, name string, size int) { + err := recordWithTags(ctx, []tag.Mutator{tag.Insert(bucketsNameKey, name)}, ratelimiterBucketsGaugeMeasure.M(int64(size))) + if err != nil { + logging.From(ctx).Warn("recording metric", + zap.String("metric", ratelimiterBucketsGaugeMeasure.Name()), + zap.Error(err)) + } +} + +var ratelimiterBucketsDeletedCounterMeasure = stats.Int64("xmtp_ratelimiter_entries_deleted", "Count of deleted entries from ratelimiter buckets map", stats.UnitDimensionless) +var 
ratelimiterBucketsDeletedCounterView = &view.View{ + Name: "xmtp_ratelimiter_entries_deleted", + Measure: ratelimiterBucketsDeletedCounterMeasure, + Description: "Count of deleted entries from rate-limiter buckets maps", + Aggregation: view.Count(), + TagKeys: []tag.Key{bucketsNameKey}, +} + +func EmitRatelimiterDeletedEntries(ctx context.Context, name string, count int) { + err := recordWithTags(ctx, []tag.Mutator{tag.Insert(bucketsNameKey, name)}, ratelimiterBucketsDeletedCounterMeasure.M(int64(count))) + if err != nil { + logging.From(ctx).Warn("recording metric", + zap.String("metric", ratelimiterBucketsDeletedCounterMeasure.Name()), + zap.Error(err)) + } +} diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 7f9dac5d..2a38d4a7 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -59,6 +59,8 @@ func RegisterViews(logger *zap.Logger) { publishedEnvelopeCounterView, queryDurationView, queryResultView, + ratelimiterBucketsGaugeView, + ratelimiterBucketsDeletedCounterView, ); err != nil { logger.Fatal("registering metrics views", zap.Error(err)) } diff --git a/pkg/ratelimiter/buckets.go b/pkg/ratelimiter/buckets.go new file mode 100644 index 00000000..96ea21fa --- /dev/null +++ b/pkg/ratelimiter/buckets.go @@ -0,0 +1,83 @@ +package ratelimiter + +import ( + "sync" + "time" + + "go.uber.org/zap" +) + +type Buckets struct { + name string + log *zap.Logger + mutex sync.RWMutex + buckets map[string]*Entry +} + +func NewBuckets(log *zap.Logger, name string) *Buckets { + return &Buckets{ + name: name, + log: log.Named(name), + buckets: make(map[string]*Entry), + mutex: sync.RWMutex{}, + } +} + +func (b *Buckets) getAndRefill(bucket string, limit *Limit, multiplier uint16, createIfMissing bool) *Entry { + // The locking strategy is adapted from the following blog post: https://misfra.me/optimizing-concurrent-map-access-in-go/ + b.mutex.RLock() + currentVal, exists := b.buckets[bucket] + b.mutex.RUnlock() + if !exists { + if !createIfMissing { 
+ return nil + } + b.mutex.Lock() + currentVal, exists = b.buckets[bucket] + if !exists { + currentVal = &Entry{ + tokens: uint16(limit.MaxTokens * multiplier), + lastSeen: time.Now(), + mutex: sync.Mutex{}, + } + b.buckets[bucket] = currentVal + b.mutex.Unlock() + + return currentVal + } + b.mutex.Unlock() + } + + limit.Refill(currentVal, multiplier) + return currentVal +} + +func (b *Buckets) deleteExpired(expiresAfter time.Duration) (deleted int) { + // Use RLock to iterate over the map + // to allow concurrent reads + b.mutex.RLock() + var expired []string + for bucket, entry := range b.buckets { + if time.Since(entry.lastSeen) > expiresAfter { + expired = append(expired, bucket) + } + } + b.mutex.RUnlock() + if len(expired) == 0 { + return deleted + } + b.log.Info("found expired buckets", zap.Int("count", len(expired))) + // Use Lock for individual deletes to avoid prolonged + // lockout for readers. + for _, bucket := range expired { + b.mutex.Lock() + // check lastSeen again in case it was updated in the meantime. 
+ if entry, exists := b.buckets[bucket]; exists && time.Since(entry.lastSeen) > expiresAfter { + delete(b.buckets, bucket) + deleted++ + } + b.mutex.Unlock() + } + b.log.Info("deleted expired buckets", zap.Int("count", deleted)) + return deleted +} diff --git a/pkg/ratelimiter/rate_limiter.go b/pkg/ratelimiter/rate_limiter.go index 2ae137a4..00c865d0 100644 --- a/pkg/ratelimiter/rate_limiter.go +++ b/pkg/ratelimiter/rate_limiter.go @@ -1,24 +1,32 @@ package ratelimiter import ( - "errors" + "context" + "fmt" "sync" "time" "github.com/xmtp/xmtp-node-go/pkg/logging" + "github.com/xmtp/xmtp-node-go/pkg/metrics" "go.uber.org/zap" ) +type LimitType string + const ( - ALLOW_LISTED_RATE_PER_MINUTE = uint16(10) - ALLOW_LISTED_MAX_TOKENS = uint16(10000) - REGULAR_RATE_PER_MINUTE = uint16(1) - REGULAR_MAX_TOKENS = uint16(100) - MAX_UINT_16 = 65535 + PRIORITY_MULTIPLIER = uint16(5) + DEFAULT_RATE_PER_MINUTE = uint16(2000) + DEFAULT_MAX_TOKENS = uint16(10000) + PUBLISH_RATE_PER_MINUTE = uint16(200) + PUBLISH_MAX_TOKENS = uint16(1000) + MAX_UINT_16 = 65535 + + DEFAULT LimitType = "DEF" + PUBLISH LimitType = "PUB" ) type RateLimiter interface { - Spend(walletAddress string, isAllowListed bool) error + Spend(limitType LimitType, bucket string, cost uint16, isPriority bool) error } // Entry represents a single wallet entry in the rate limiter @@ -31,89 +39,131 @@ type Entry struct { mutex sync.Mutex } +// Limit controls token refilling for bucket entries +type Limit struct { + // Maximum number of tokens that can be accumulated + MaxTokens uint16 + // Number of tokens to refill per minute + RatePerMinute uint16 +} + +func (l Limit) Refill(entry *Entry, multiplier uint16) { + now := time.Now() + ratePerMinute := l.RatePerMinute * multiplier + maxTokens := l.MaxTokens * multiplier + entry.mutex.Lock() + defer entry.mutex.Unlock() + minutesSinceLastSeen := now.Sub(entry.lastSeen).Minutes() + if minutesSinceLastSeen > 0 { + // Only update the lastSeen if it has been >= 1 minute + // 
This allows for continuously sending nodes to still get credits + entry.lastSeen = now + // Convert to ints so that we can check if above MAX_UINT_16 + additionalTokens := int(ratePerMinute) * int(minutesSinceLastSeen) + // Avoid overflows of UINT16 when new balance is above limit + if additionalTokens+int(entry.tokens) > MAX_UINT_16 { + additionalTokens = MAX_UINT_16 - int(entry.tokens) + } + entry.tokens = minUint16(entry.tokens+uint16(additionalTokens), maxTokens) + } +} + // TokenBucketRateLimiter implements the RateLimiter interface type TokenBucketRateLimiter struct { - log *zap.Logger - wallets map[string]*Entry - mutex sync.RWMutex + log *zap.Logger + ctx context.Context + mutex sync.RWMutex + newBuckets *Buckets // buckets that can be added to + oldBuckets *Buckets // buckets to be swept for expired entries + PriorityMultiplier uint16 + Limits map[LimitType]*Limit } -func NewTokenBucketRateLimiter(log *zap.Logger) *TokenBucketRateLimiter { +func NewTokenBucketRateLimiter(ctx context.Context, log *zap.Logger) *TokenBucketRateLimiter { tb := new(TokenBucketRateLimiter) tb.log = log.Named("ratelimiter") - tb.wallets = make(map[string]*Entry) - tb.mutex = sync.RWMutex{} - + tb.ctx = ctx + // TODO: need to periodically clear out expired items to avoid unlimited growth of the map. 
+ tb.newBuckets = NewBuckets(log, "buckets1") + tb.oldBuckets = NewBuckets(log, "buckets2") + tb.PriorityMultiplier = PRIORITY_MULTIPLIER + tb.Limits = map[LimitType]*Limit{ + DEFAULT: {DEFAULT_MAX_TOKENS, DEFAULT_RATE_PER_MINUTE}, + PUBLISH: {PUBLISH_MAX_TOKENS, PUBLISH_RATE_PER_MINUTE}, + } return tb } -func getRates(isAllowListed bool) (ratePerMinute uint16, maxTokens uint16) { - if isAllowListed { - ratePerMinute = ALLOW_LISTED_RATE_PER_MINUTE - maxTokens = ALLOW_LISTED_MAX_TOKENS - } else { - ratePerMinute = REGULAR_RATE_PER_MINUTE - maxTokens = REGULAR_MAX_TOKENS +func (rl *TokenBucketRateLimiter) getLimit(limitType LimitType) *Limit { + if l := rl.Limits[limitType]; l != nil { + return l } - return + return rl.Limits["default"] } // Will return the entry, with items filled based on the time since last access -func (rl *TokenBucketRateLimiter) fillAndReturnEntry(walletAddress string, isAllowListed bool) *Entry { - ratePerMinute, maxTokens := getRates(isAllowListed) - // The locking strategy is adapted from the following blog post: https://misfra.me/optimizing-concurrent-map-access-in-go/ - rl.mutex.RLock() - currentVal, exists := rl.wallets[walletAddress] - rl.mutex.RUnlock() - if !exists { - rl.mutex.Lock() - currentVal = &Entry{ - tokens: uint16(maxTokens), - lastSeen: time.Now(), - mutex: sync.Mutex{}, - } - rl.wallets[walletAddress] = currentVal - rl.mutex.Unlock() - - return currentVal +func (rl *TokenBucketRateLimiter) fillAndReturnEntry(limitType LimitType, bucket string, isPriority bool) *Entry { + limit := rl.getLimit(limitType) + multiplier := uint16(1) + if isPriority { + multiplier = rl.PriorityMultiplier } - - currentVal.mutex.Lock() - defer currentVal.mutex.Unlock() - now := time.Now() - minutesSinceLastSeen := now.Sub(currentVal.lastSeen).Minutes() - if minutesSinceLastSeen > 0 { - // Only update the lastSeen if it has been >= 1 minute - // This allows for continuously sending nodes to still get credits - currentVal.lastSeen = now - // Convert 
to ints so that we can check if above MAX_UINT_16 - additionalTokens := int(ratePerMinute) * int(minutesSinceLastSeen) - // Avoid overflows of UINT16 when new balance is above limit - if additionalTokens+int(currentVal.tokens) > MAX_UINT_16 { - additionalTokens = MAX_UINT_16 - int(currentVal.tokens) - } - currentVal.tokens = minUint16(currentVal.tokens+uint16(additionalTokens), maxTokens) + rl.mutex.RLock() + if entry := rl.oldBuckets.getAndRefill(bucket, limit, multiplier, false); entry != nil { + rl.mutex.RUnlock() + return entry } - - return currentVal + entry := rl.newBuckets.getAndRefill(bucket, limit, multiplier, true) + rl.mutex.RUnlock() + return entry } -// The Spend function takes a WalletAddress and a boolean asserting whether to apply the AllowListed rate limits or the regular rate limits -func (rl *TokenBucketRateLimiter) Spend(walletAddress string, isAllowListed bool) error { - entry := rl.fillAndReturnEntry(walletAddress, isAllowListed) +// The Spend function takes a bucket and a boolean asserting whether to apply the PRIORITY or the REGULAR rate limits. +func (rl *TokenBucketRateLimiter) Spend(limitType LimitType, bucket string, cost uint16, isPriority bool) error { + entry := rl.fillAndReturnEntry(limitType, bucket, isPriority) entry.mutex.Lock() defer entry.mutex.Unlock() - log := rl.log.With(logging.WalletAddress(walletAddress)) - if entry.tokens == 0 { - log.Info("Rate limit exceeded") - return errors.New("rate_limit_exceeded") + log := rl.log.With( + logging.String("bucket", bucket), + logging.String("limitType", string(limitType)), + logging.Bool("isPriority", isPriority), + logging.Int("cost", int(cost))) + if entry.tokens < cost { + // Normally error strings should be fixed, but this error gets passed down to clients, + // so we want to include more information for debugging purposes. + // grpc Status has details in theory, but it seems messy to use, we may want to reconsider. 
+ return fmt.Errorf("%d exceeds rate limit %s", cost, bucket) } - log.Debug("Spend allowed. Wallet is under threshold", zap.Int("tokens_remaining", int(entry.tokens))) + entry.tokens = entry.tokens - cost + log.Debug("Spend allowed. bucket is under threshold", zap.Int("tokens_remaining", int(entry.tokens))) + return nil +} - entry.tokens = entry.tokens - 1 +func (rl *TokenBucketRateLimiter) Janitor(sweepInterval, expiresAfter time.Duration) { + ticker := time.NewTicker(sweepInterval) + defer ticker.Stop() + for { + select { + case <-rl.ctx.Done(): + return + case <-ticker.C: + rl.sweepAndSwap(expiresAfter) + } + } +} - return nil +func (rl *TokenBucketRateLimiter) sweepAndSwap(expiresAfter time.Duration) (deletedEntries int) { + // Only the janitor writes to oldBuckets (the swap below), so we shouldn't need to rlock it here. + deletedEntries = rl.oldBuckets.deleteExpired(expiresAfter) + metrics.EmitRatelimiterDeletedEntries(rl.ctx, rl.oldBuckets.name, deletedEntries) + metrics.EmitRatelimiterBucketsSize(rl.ctx, rl.oldBuckets.name, len(rl.oldBuckets.buckets)) + rl.mutex.Lock() + rl.newBuckets, rl.oldBuckets = rl.oldBuckets, rl.newBuckets + rl.mutex.Unlock() + rl.newBuckets.log.Info("became new buckets") + rl.oldBuckets.log.Info("became old buckets") + return deletedEntries } func minUint16(x, y uint16) uint16 { diff --git a/pkg/ratelimiter/rate_limiter_test.go b/pkg/ratelimiter/rate_limiter_test.go index 4af073b2..c9046184 100644 --- a/pkg/ratelimiter/rate_limiter_test.go +++ b/pkg/ratelimiter/rate_limiter_test.go @@ -1,6 +1,7 @@ package ratelimiter import ( + "context" "sync" "testing" "time" @@ -13,103 +14,149 @@ const walletAddress = "0x1234" func TestSpend(t *testing.T) { logger, _ := zap.NewDevelopment() - rl := NewTokenBucketRateLimiter(logger) - rl.wallets[walletAddress] = &Entry{ - lastSeen: time.Now(), - tokens: uint16(1), - mutex: sync.Mutex{}, - } + rl := NewTokenBucketRateLimiter(context.Background(), logger) + rl.newBuckets.getAndRefill(walletAddress, 
&Limit{1, 0}, 1, true) - err1 := rl.Spend(walletAddress, false) + err1 := rl.Spend(DEFAULT, walletAddress, 1, false) require.NoError(t, err1) - err2 := rl.Spend(walletAddress, false) + err2 := rl.Spend(DEFAULT, walletAddress, 1, false) require.Error(t, err2) - if err2.Error() != "rate_limit_exceeded" { - t.Error("Incorrect error") - } + require.Equal(t, "1 exceeds rate limit 0x1234", err2.Error()) } // Ensure that new entries are created for previously unseen wallets func TestSpendInitialize(t *testing.T) { logger, _ := zap.NewDevelopment() - rl := NewTokenBucketRateLimiter(logger) - entry := rl.fillAndReturnEntry(walletAddress, false) - require.Equal(t, entry.tokens, REGULAR_MAX_TOKENS) + rl := NewTokenBucketRateLimiter(context.Background(), logger) + entry := rl.fillAndReturnEntry(DEFAULT, walletAddress, false) + require.Equal(t, entry.tokens, DEFAULT_MAX_TOKENS) } // Set the clock back 1 minute and ensure that 1 item has been added to the bucket func TestSpendWithTime(t *testing.T) { logger, _ := zap.NewDevelopment() - rl := NewTokenBucketRateLimiter(logger) - rl.wallets[walletAddress] = &Entry{ - // Set the last seen to 1 minute ago - lastSeen: time.Now().Add(-1 * time.Minute), - tokens: uint16(0), - mutex: sync.Mutex{}, - } - err1 := rl.Spend(walletAddress, false) + rl := NewTokenBucketRateLimiter(context.Background(), logger) + rl.Limits[DEFAULT] = &Limit{100, 1} + entry := rl.newBuckets.getAndRefill(walletAddress, &Limit{0, 0}, 1, true) + // Set the last seen to 1 minute ago + entry.lastSeen = time.Now().Add(-1 * time.Minute) + err1 := rl.Spend(DEFAULT, walletAddress, 1, false) require.NoError(t, err1) - err2 := rl.Spend(walletAddress, false) + err2 := rl.Spend(DEFAULT, walletAddress, 1, false) require.Error(t, err2) } // Ensure that the token balance cannot go above the max bucket size func TestSpendMaxBucket(t *testing.T) { logger, _ := zap.NewDevelopment() - rl := NewTokenBucketRateLimiter(logger) - rl.wallets[walletAddress] = &Entry{ - // Set last seen 
to 500 minutes ago - lastSeen: time.Now().Add(-500 * time.Minute), - tokens: uint16(0), - mutex: sync.Mutex{}, - } - entry := rl.fillAndReturnEntry(walletAddress, false) - require.Equal(t, entry.tokens, REGULAR_MAX_TOKENS) + rl := NewTokenBucketRateLimiter(context.Background(), logger) + entry := rl.newBuckets.getAndRefill(walletAddress, &Limit{0, 0}, 1, true) + // Set last seen to 500 minutes ago + entry.lastSeen = time.Now().Add(-500 * time.Minute) + entry = rl.fillAndReturnEntry(DEFAULT, walletAddress, false) + require.Equal(t, entry.tokens, DEFAULT_MAX_TOKENS) } // Ensure that the allow list is being correctly applied func TestSpendAllowListed(t *testing.T) { logger, _ := zap.NewDevelopment() - rl := NewTokenBucketRateLimiter(logger) - rl.wallets[walletAddress] = &Entry{ - // Set last seen to 500 minutes ago - lastSeen: time.Now().Add(-500 * time.Minute), - tokens: uint16(0), - mutex: sync.Mutex{}, - } - entry := rl.fillAndReturnEntry(walletAddress, true) - require.Equal(t, entry.tokens, uint16(500*ALLOW_LISTED_RATE_PER_MINUTE)) + rl := NewTokenBucketRateLimiter(context.Background(), logger) + entry := rl.newBuckets.getAndRefill(walletAddress, &Limit{0, 0}, 1, true) + // Set last seen to 5 minutes ago + entry.lastSeen = time.Now().Add(-5 * time.Minute) + entry = rl.fillAndReturnEntry(DEFAULT, walletAddress, true) + require.Equal(t, entry.tokens, uint16(5*DEFAULT_RATE_PER_MINUTE*PRIORITY_MULTIPLIER)) } func TestMaxUint16(t *testing.T) { logger, _ := zap.NewDevelopment() - rl := NewTokenBucketRateLimiter(logger) - rl.wallets[walletAddress] = &Entry{ - // Set last seen to 1 million minutes ago - lastSeen: time.Now().Add(-1000000 * time.Minute), - tokens: uint16(0), - mutex: sync.Mutex{}, - } - - entry := rl.fillAndReturnEntry(walletAddress, true) - require.Equal(t, entry.tokens, uint16(ALLOW_LISTED_MAX_TOKENS)) + rl := NewTokenBucketRateLimiter(context.Background(), logger) + entry := rl.newBuckets.getAndRefill(walletAddress, &Limit{0, 0}, 1, true) + // Set last 
seen to 1 million minutes ago + entry.lastSeen = time.Now().Add(-1000000 * time.Minute) + entry = rl.fillAndReturnEntry(DEFAULT, walletAddress, true) + require.Equal(t, entry.tokens, DEFAULT_MAX_TOKENS*PRIORITY_MULTIPLIER) } // Ensures that the map can be accessed concurrently func TestSpendConcurrent(t *testing.T) { logger, _ := zap.NewDevelopment() - rl := NewTokenBucketRateLimiter(logger) + rl := NewTokenBucketRateLimiter(context.Background(), logger) wg := sync.WaitGroup{} - for i := 0; i < 100; i++ { + for i := 0; i < int(PUBLISH_MAX_TOKENS); i++ { wg.Add(1) go func() { defer wg.Done() - _ = rl.Spend(walletAddress, false) + _ = rl.Spend(PUBLISH, walletAddress, 1, false) }() } wg.Wait() - entry := rl.fillAndReturnEntry(walletAddress, false) + entry := rl.fillAndReturnEntry(PUBLISH, walletAddress, false) require.Equal(t, entry.tokens, uint16(0)) } + +func TestBucketExpiration(t *testing.T) { + // Set things up so that entries are expired after two sweep intervals + expiresAfter := 100 * time.Millisecond + sweepInterval := 60 * time.Millisecond + + logger, _ := zap.NewDevelopment() + rl := NewTokenBucketRateLimiter(context.Background(), logger) + rl.Limits[DEFAULT] = &Limit{2, 0} // 2 tokens, no refill + + require.NoError(t, rl.Spend(DEFAULT, "ip1", 1, false)) // bucket1 add + require.NoError(t, rl.Spend(DEFAULT, "ip2", 1, false)) // bucket1 add + + time.Sleep(sweepInterval) + require.Equal(t, 0, rl.sweepAndSwap(expiresAfter)) // sweep bucket2 and swap + + require.NoError(t, rl.Spend(DEFAULT, "ip2", 1, false)) // bucket1 refresh + require.NoError(t, rl.Spend(DEFAULT, "ip3", 1, false)) // bucket2 add + + time.Sleep(sweepInterval) + require.Equal(t, 1, rl.sweepAndSwap(expiresAfter)) // sweep bucket1 and swap, delete ip1 + + // ip2 has been refreshed every 60ms so it should still be out of tokens + require.Error(t, rl.Spend(DEFAULT, "ip2", 1, false)) // bucket1 refresh + // ip1 entry should have expired by now, so we should have 2 tokens again + require.NoError(t, 
rl.Spend(DEFAULT, "ip1", 1, false)) // bucket1 add + require.NoError(t, rl.Spend(DEFAULT, "ip1", 1, false)) // bucket1 refresh + require.Error(t, rl.Spend(DEFAULT, "ip1", 1, false)) // bucket1 refresh + + time.Sleep(sweepInterval) + require.Equal(t, 1, rl.sweepAndSwap(expiresAfter)) // sweep bucket2 and swap, delete ip3 + + // ip2 should still be out of tokens + require.Error(t, rl.Spend(DEFAULT, "ip2", 1, false)) // bucket1 refresh + // ip3 should have expired now and we should have 2 tokens again + require.NoError(t, rl.Spend(DEFAULT, "ip3", 1, false)) // bucket2 add + require.NoError(t, rl.Spend(DEFAULT, "ip3", 1, false)) // bucket2 refresh + require.Error(t, rl.Spend(DEFAULT, "ip3", 1, false)) // bucket2 refresh +} + +func TestBucketExpirationIntegrity(t *testing.T) { + expiresAfter := 10 * time.Millisecond + logger, _ := zap.NewDevelopment() + rl := NewTokenBucketRateLimiter(context.Background(), logger) + rl.Limits[DEFAULT] = &Limit{2, 0} // 2 tokens, no refill + + require.NoError(t, rl.Spend(DEFAULT, "ip1", 1, false)) // bucket1 add + + require.Equal(t, 0, rl.sweepAndSwap(expiresAfter)) // sweep bucket2 and swap + + require.NoError(t, rl.Spend(DEFAULT, "ip1", 1, false)) // bucket1 refresh + require.Error(t, rl.Spend(DEFAULT, "ip1", 1, false)) // should be out of tokens now + + require.Equal(t, 0, rl.sweepAndSwap(expiresAfter)) // sweep bucket1 and swap + + require.Error(t, rl.Spend(DEFAULT, "ip1", 1, false)) // should still be out of tokens + + require.Equal(t, 0, rl.sweepAndSwap(expiresAfter)) // sweep bucket2 and swap + + time.Sleep(2 * expiresAfter) // wait until ip1 expires + require.Equal(t, 1, rl.sweepAndSwap(expiresAfter)) // sweep bucket1 and swap, delete ip1 + + require.NoError(t, rl.Spend(DEFAULT, "ip1", 1, false)) // bucket1 add +}