Skip to content

Commit

Permalink
IP based rate limiting (#289)
Browse files Browse the repository at this point in the history
* IP based rate limiting

* deploy to dev

* lint

* metric tag fix

* SIGUSR debug level toggle

* fix debug level toggling

* more debug info

* Fix rebase issues

* Log at info level when rate limited

* Force new deploy

* Reset deployments to main

---------

Co-authored-by: Steven Normore <[email protected]>
Co-authored-by: Nicholas Molnar <[email protected]>
  • Loading branch information
3 people authored Sep 8, 2023
1 parent dc4719c commit a017b67
Show file tree
Hide file tree
Showing 17 changed files with 611 additions and 187 deletions.
24 changes: 19 additions & 5 deletions cmd/xmtpd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func main() {
return
}

log, err := buildLogger(options)
log, logCfg, err := buildLogger(options)
if err != nil {
fatal("Could not build logger: %s", err)
}
Expand Down Expand Up @@ -150,6 +150,20 @@ func main() {
doneC <- true
})

// Toggle debug level on SIGUSR1
sigToggleC := make(chan os.Signal, 1)
signal.Notify(sigToggleC, syscall.SIGUSR1)
go func() {
for range sigToggleC {
log.Info("toggling debug level")
newLevel := zapcore.DebugLevel
if logCfg.Level.Enabled(zapcore.DebugLevel) {
newLevel = zapcore.InfoLevel
}
logCfg.Level.SetLevel(newLevel)
}
}()

sigC := make(chan os.Signal, 1)
signal.Notify(sigC,
syscall.SIGHUP,
Expand Down Expand Up @@ -193,12 +207,12 @@ func initWakuLogging(options server.Options) (func(), error) {
return cleanup, nil
}

func buildLogger(options server.Options) (*zap.Logger, error) {
func buildLogger(options server.Options) (*zap.Logger, *zap.Config, error) {
atom := zap.NewAtomicLevel()
level := zapcore.InfoLevel
err := level.Set(options.LogLevel)
if err != nil {
return nil, err
return nil, nil, err
}
atom.SetLevel(level)

Expand All @@ -219,10 +233,10 @@ func buildLogger(options server.Options) (*zap.Logger, error) {
}
log, err := cfg.Build()
if err != nil {
return nil, err
return nil, nil, err
}

log = log.Named("xmtpd")

return log, nil
return log, &cfg, nil
}
2 changes: 1 addition & 1 deletion dev/aws-shell
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

region="${REGION:-us-east-2}"
cluster="${ENV:-dev}"
task="${TASK:-node-0}"
task="${TASK:-group1-node-0}"
container="${CONTAINER:-$task}"
task=$(aws --region "$region" \
--query 'taskArns[0]' \
Expand Down
36 changes: 32 additions & 4 deletions pkg/api/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,39 @@ type Config struct {
Log *zap.Logger
}

// Options bundle command line options associated with the authn package.
// AuthnOptions bundle command line options associated with the authn package.
type AuthnOptions struct {
Enable bool `long:"enable" description:"require client authentication via wallet tokens"`
EnableV3 bool `long:"enable-v3" description:"require client authentication for V3"`
Ratelimits bool `long:"ratelimits" description:"apply rate limits per wallet"`
/*
Enable is the master switch for the authentication module.
If it is false then the other options in this group are ignored.
The module enforces authentication for requests that require it (currently Publish only).
Authenticated requests will be permitted according to the rules of the request type,
(i.e. you can't publish into other wallets' contact and private topics).
*/
Enable bool `long:"enable" description:"require client authentication via wallet tokens"`
EnableV3 bool `long:"enable-v3" description:"require client authentication for V3"`
/*
Ratelimits enables request rate limiting.
Requests are bucketed by client IP address and request type (there is one bucket for all requests without IPs).
Each bucket is allocated a number of tokens that are refilled at a fixed rate per minute
up to a given maximum number of tokens.
Requests cost 1 token by default, except Publish requests cost the number of Envelopes carried
and BatchQuery requests cost the number of queries carried.
The limits depend on request type, e.g. Publish requests get lower limits than other types of request.
If Allowlists is also true then requests with Bearer tokens from wallets explicitly Allowed get priority,
i.e. a predefined multiple the configured limit.
Priority wallets get separate IP buckets from regular wallets.
*/
Ratelimits bool `long:"ratelimits" description:"apply rate limits per client IP address"`
/*
Allowlists enables wallet allow lists.
All requests that require authentication (currently Publish only) will be rejected
for wallets that are set as Denied in the allow list.
Wallets that are explicitly Allowed will get priority rate limits if Ratelimits is true.
*/
AllowLists bool `long:"allowlists" description:"apply higher limits for allow listed wallets (requires authz and ratelimits)"`
PrivilegedAddresses []string `long:"privileged-address" description:"allow this address to publish into other user's topics"`
}
Expand Down
110 changes: 86 additions & 24 deletions pkg/api/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

messagev1 "github.com/xmtp/proto/v3/go/message_api/v1"
"github.com/xmtp/xmtp-node-go/pkg/logging"
"github.com/xmtp/xmtp-node-go/pkg/ratelimiter"
"github.com/xmtp/xmtp-node-go/pkg/types"
)

Expand Down Expand Up @@ -45,9 +46,22 @@ func (wa *WalletAuthorizer) Unary() grpc.UnaryServerInterceptor {
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
if !wa.Ratelimits && !wa.requiresAuthorization(req) {
return handler(ctx, req)
}
wallet, authErr := wa.getWallet(ctx)

if wa.Ratelimits {
if err := wa.applyLimits(ctx, info.FullMethod, req, wallet); err != nil {
return nil, err
}
}

if wa.requiresAuthorization(req) {
if err := wa.authorize(ctx, req); err != nil {
if authErr != nil {
return nil, status.Error(codes.Unauthenticated, authErr.Error())
}
if err := wa.authorize(ctx, req, wallet); err != nil {
return nil, err
}
}
Expand All @@ -62,14 +76,20 @@ func (wa *WalletAuthorizer) Stream() grpc.StreamServerInterceptor {
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
// TODO(mk): Add metrics
if wa.Ratelimits {
ctx := stream.Context()
wallet, _ := wa.getWallet(ctx)
if err := wa.applyLimits(ctx, info.FullMethod, nil, wallet); err != nil {
return err
}
}
return handler(srv, stream)
}
}

func (wa *WalletAuthorizer) isProtocolVersion3(request *messagev1.PublishRequest) bool {
envelopes := request.Envelopes
if envelopes == nil || len(envelopes) == 0 {
if len(envelopes) == 0 {
return false
}
// If any of the envelopes are not for a v3 topic, then we treat the request as non-v3
Expand All @@ -86,66 +106,97 @@ func (wa *WalletAuthorizer) requiresAuthorization(req interface{}) bool {
return isPublish && (!wa.isProtocolVersion3(publishRequest) || wa.AuthnConfig.EnableV3)
}

func (wa *WalletAuthorizer) authorize(ctx context.Context, req interface{}) error {
func (wa *WalletAuthorizer) getWallet(ctx context.Context) (types.WalletAddr, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return status.Errorf(codes.Unauthenticated, "metadata is not provided")
return "", status.Errorf(codes.Unauthenticated, "metadata is not provided")
}

values := md.Get(authorizationMetadataKey)
if len(values) == 0 {
return status.Errorf(codes.Unauthenticated, "authorization token is not provided")
return "", status.Errorf(codes.Unauthenticated, "authorization token is not provided")
}

words := strings.SplitN(values[0], " ", 2)
if len(words) != 2 {
return status.Errorf(codes.Unauthenticated, "invalid authorization header")
return "", status.Errorf(codes.Unauthenticated, "invalid authorization header")
}
if scheme := strings.TrimSpace(words[0]); scheme != "Bearer" {
return status.Errorf(codes.Unauthenticated, "unrecognized authorization scheme %s", scheme)
return "", status.Errorf(codes.Unauthenticated, "unrecognized authorization scheme %s", scheme)
}
token, err := decodeAuthToken(strings.TrimSpace(words[1]))
if err != nil {
return status.Errorf(codes.Unauthenticated, "extracting token: %s", err)
return "", status.Errorf(codes.Unauthenticated, "extracting token: %s", err)
}

wallet, err := validateToken(ctx, wa.Log, token, time.Now())
if err != nil {
return status.Errorf(codes.Unauthenticated, "validating token: %s", err)
return "", status.Errorf(codes.Unauthenticated, "validating token: %s", err)
}
return wallet, nil
}

func (wa *WalletAuthorizer) authorize(ctx context.Context, req interface{}, wallet types.WalletAddr) error {
if pub, isPublish := req.(*messagev1.PublishRequest); isPublish {
for _, env := range pub.Envelopes {
if !wa.privilegedAddresses[wallet] && !allowedToPublish(env.ContentTopic, wallet) {
return status.Errorf(codes.PermissionDenied, "publishing to restricted topic")
}
}
}

return wa.authorizeWallet(ctx, wallet)
}

func (wa *WalletAuthorizer) authorizeWallet(ctx context.Context, wallet types.WalletAddr) error {
// * for limit exhaustion return status.Errorf(codes.ResourceExhausted, ...)
// * for other authorization failure return status.Errorf(codes.PermissionDenied, ...)

var allowListed bool
if wa.AllowLists {
if wa.AllowLister.IsDenyListed(wallet.String()) {
wa.Log.Debug("wallet deny listed", logging.WalletAddress(wallet.String()))
return status.Errorf(codes.PermissionDenied, ErrDenyListed.Error())
}
allowListed = wa.AllowLister.IsAllowListed(wallet.String())
}
return nil
}

if !wa.Ratelimits {
return nil
func (wa *WalletAuthorizer) applyLimits(ctx context.Context, fullMethod string, req interface{}, wallet types.WalletAddr) error {
// * for limit exhaustion return status.Errorf(codes.ResourceExhausted, ...)
// * for other authorization failure return status.Errorf(codes.PermissionDenied, ...)
_, method := splitMethodName(fullMethod)

ip := clientIPFromContext(ctx)
if len(ip) == 0 {
// requests without an IP address are bucketed together as "ip_unknown"
ip = "ip_unknown"
}

// with no wallet apply regular limits
var isPriority bool
if len(wallet) > 0 && wa.AllowLists {
isPriority = wa.AllowLister.IsAllowListed(wallet.String())
}
cost := 1
limitType := ratelimiter.DEFAULT
switch req := req.(type) {
case *messagev1.PublishRequest:
cost = len(req.Envelopes)
limitType = ratelimiter.PUBLISH
case *messagev1.BatchQueryRequest:
cost = len(req.Requests)
}
err := wa.Limiter.Spend(wallet.String(), allowListed)
// need to separate the IP buckets between priority and regular wallets
var bucket string
if isPriority {
bucket = "P" + ip + string(limitType)
} else {
bucket = "R" + ip + string(limitType)
}
err := wa.Limiter.Spend(limitType, bucket, uint16(cost), isPriority)
if err == nil {
return nil
}
wa.Log.Debug("wallet rate limited", logging.WalletAddress(wallet.String()))

wa.Log.Info("rate limited",
logging.String("client_ip", ip),
logging.WalletAddress(wallet.String()),
logging.Bool("priority", isPriority),
logging.String("method", method),
logging.String("limit", string(limitType)),
logging.Int("cost", cost))
return status.Errorf(codes.ResourceExhausted, err.Error())
}

Expand Down Expand Up @@ -196,3 +247,14 @@ func allowedToPublish(topic string, wallet types.WalletAddr) bool {

return true
}

func clientIPFromContext(ctx context.Context) string {
md, _ := metadata.FromIncomingContext(ctx)
vals := md.Get("x-forwarded-for")
if len(vals) == 0 {
return ""
}
// There are potentially multiple comma separated IPs bundled in that first value
ips := strings.Split(vals[0], ",")
return strings.TrimSpace(ips[0])
}
9 changes: 8 additions & 1 deletion pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"strings"
"sync"
"time"

middleware "github.com/grpc-ecosystem/go-grpc-middleware"
prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
Expand Down Expand Up @@ -87,9 +88,15 @@ func (s *Server) startGRPC() error {
stream = append(stream, telemetryInterceptor.Stream())

if s.Config.Authn.Enable {
limiter := ratelimiter.NewTokenBucketRateLimiter(s.ctx, s.Log)
// Expire buckets after 1 hour of inactivity,
// sweep for expired buckets every 10 minutes.
// Note: entry expiration should be at least some multiple of
// maximum (limit max / limit rate) minutes.
go limiter.Janitor(10*time.Minute, 1*time.Hour)
s.authorizer = NewWalletAuthorizer(&AuthnConfig{
AuthnOptions: s.Config.Authn,
Limiter: ratelimiter.NewTokenBucketRateLimiter(s.Log),
Limiter: limiter,
AllowLister: s.Config.AllowLister,
Log: s.Log.Named("authn"),
})
Expand Down
Loading

0 comments on commit a017b67

Please sign in to comment.