Skip to content

Commit

Permalink
grpc/appsec: fix rpc message blocking
Browse files Browse the repository at this point in the history
  • Loading branch information
Julio-Guerra committed Jun 3, 2024
1 parent fa01bc0 commit 89ecea1
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 93 deletions.
90 changes: 47 additions & 43 deletions contrib/google.golang.org/grpc/appsec.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ import (
// UnaryHandler wrapper to use when AppSec is enabled to monitor its execution.
func appsecUnaryHandlerMiddleware(method string, span ddtrace.Span, handler grpc.UnaryHandler) grpc.UnaryHandler {
trace.SetAppSecEnabledTags(span)
return func(ctx context.Context, req interface{}) (interface{}, error) {
var err error
var blocked bool
return func(ctx context.Context, req interface{}) (res interface{}, rpcErr error) {
var blockedErr error
md, _ := metadata.FromIncomingContext(ctx)
clientIP := setClientIP(ctx, span, md)
args := types.HandlerOperationArgs{
Expand All @@ -41,46 +40,50 @@ func appsecUnaryHandlerMiddleware(method string, span ddtrace.Span, handler grpc
}
ctx, op := grpcsec.StartHandlerOperation(ctx, args, nil, func(op *types.HandlerOperation) {
dyngo.OnData(op, func(a *sharedsec.GRPCAction) {
code, e := a.GRPCWrapper()
blocked = a.Blocking()
err = status.Error(codes.Code(code), e.Error())
if a.Blocking() {
code, err := a.GRPCWrapper()
blockedErr = status.Error(codes.Code(code), err.Error())
}
})
})
defer func() {
events := op.Finish(types.HandlerOperationRes{})
if blocked {
if len(events) > 0 {
grpctrace.SetSecurityEventsTags(span, events)
}
if blockedErr != nil {
op.SetTag(trace.BlockedRequestTag, true)
rpcErr = blockedErr
}
grpctrace.SetRequestMetadataTags(span, md)
trace.SetTags(span, op.Tags())
if len(events) > 0 {
grpctrace.SetSecurityEventsTags(span, events)
}
}()

if err != nil {
return nil, err
// Check if a blocking condition was detected so far with the start operation event (ip blocking, metadata blocking, etc.)
if blockedErr != nil {
return nil, blockedErr
}

defer grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{Message: req}, op).Finish(types.ReceiveOperationRes{})
if err != nil {
return nil, err
// As of our gRPC abstract operation definition, we must fake a receive operation for unary RPCs (the same model fits both unary and streaming RPCs)
grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, op).Finish(types.ReceiveOperationRes{Message: req})
// Check if a blocking condition was detected so far with the receive operation events
if blockedErr != nil {
return nil, blockedErr
}

rv, err := handler(ctx, req)
if e, ok := err.(*types.MonitoringError); ok {
err = status.Error(codes.Code(e.GRPCStatus()), e.Error())
}
return rv, err
// Call the original handler - let the deferred function above handle the blocking condition and return error
return handler(ctx, req)
}
}

// StreamHandler wrapper to use when AppSec is enabled to monitor its execution.
func appsecStreamHandlerMiddleware(method string, span ddtrace.Span, handler grpc.StreamHandler) grpc.StreamHandler {
trace.SetAppSecEnabledTags(span)
return func(srv interface{}, stream grpc.ServerStream) error {
appsecStream := &appsecServerStream{
ServerStream: stream,
return func(srv interface{}, stream grpc.ServerStream) (rpcErr error) {
// Create a ServerStream wrapper with appsec RPC handler operation and the Go context (to implement the ServerStream interface)
appsecStream := &appsecServerStream{
ServerStream: stream,
// note: the blockedErr field is captured by the RPC handler's OnData closure below
}

ctx := stream.Context()
Expand All @@ -99,37 +102,37 @@ func appsecStreamHandlerMiddleware(method string, span ddtrace.Span, handler grp
if a.Blocking() {
code, e := a.GRPCWrapper()
appsecStream.blockedErr = status.Error(codes.Code(code), e.Error())

}
})
})

// Finish constructing the appsec stream wrapper and replace the original one
appsecStream.handlerOperation = op
appsecStream.ctx = ctx
stream = appsecStream

defer func() {
events := op.Finish(types.HandlerOperationRes{})

if len(events) > 0 {
grpctrace.SetSecurityEventsTags(span, events)
}

if appsecStream.blockedErr != nil {
op.SetTag(trace.BlockedRequestTag, true)
// Change the RPC return error with appsec's
rpcErr = appsecStream.blockedErr
}

trace.SetTags(span, op.Tags())
if len(events) > 0 {
grpctrace.SetSecurityEventsTags(span, events)
}
}()

// Check if a blocking condition was detected so far with the start operation event (ip blocking, metadata blocking, etc.)
if appsecStream.blockedErr != nil {
return appsecStream.blockedErr
}

// Call the original handler
err := handler(srv, stream)
//if e, ok := err.(*types.MonitoringError); ok {
// err = status.Error(codes.Code(e.GRPCStatus()), e.Error())
//}
return err
// Call the original handler - let the deferred function above handle the blocking condition and return error
return handler(srv, appsecStream)
}
}

Expand All @@ -144,18 +147,19 @@ type appsecServerStream struct {

// RecvMsg implements grpc.ServerStream interface method to monitor its
// execution with AppSec.
func (ss appsecServerStream) RecvMsg(m interface{}) (err error) {
op := grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{Message: m}, ss.handlerOperation)
defer op.Finish(types.ReceiveOperationRes{})

if ss.blockedErr != nil {
return ss.blockedErr
}

func (ss *appsecServerStream) RecvMsg(m interface{}) (err error) {
op := grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, ss.handlerOperation)
defer func() {
op.Finish(types.ReceiveOperationRes{Message: m})
if ss.blockedErr != nil {
// Change the function call return error with appsec's
err = ss.blockedErr
}
}()
return ss.ServerStream.RecvMsg(m)
}

func (ss appsecServerStream) Context() context.Context {
func (ss *appsecServerStream) Context() context.Context {
return ss.ctx
}

Expand Down
51 changes: 16 additions & 35 deletions contrib/google.golang.org/grpc/appsec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"encoding/json"
"fmt"
"net"
"strings"
"testing"

pappsec "gopkg.in/DataDog/dd-trace-go.v1/appsec"
Expand Down Expand Up @@ -169,16 +168,17 @@ func TestBlocking(t *testing.T) {
},
{
name: "message blocking",
md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.5", "user-id", "blocked-user-1"),
md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.5", "user-id", "legit-user-1"),
message: "$globals",
expectedMatchedRules: []string{"crs-933-130-block"}, // message blocking alone as it comes before user blocking
expectedNotMatchedRules: []string{"blk-001-002"}, // no user blocking
},
{
name: "user blocking",
md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.5", "user-id", "blocked-user-1"),
message: "<script>alert('xss');</script>",
expectedMatchedRules: []string{"crs-941-110", "blk-001-002"}, // monitoring event + user blocking
name: "user blocking",
md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.5", "user-id", "blocked-user-1"),
message: "<script>alert('xss');</script>",
expectedMatchedRules: []string{"blk-001-002"}, // user blocking alone as it comes first in our test handler
expectedNotMatchedRules: []string{"crs-933-130-block"}, // message blocking alone as it comes before user blocking
},
} {
t.Run(tc.name, func(t *testing.T) {
Expand All @@ -190,10 +190,10 @@ func TestBlocking(t *testing.T) {
do(client)

finished := mt.FinishedSpans()
require.Len(t, finished, 1)
require.True(t, len(finished) >= 1) // streaming RPCs will have two spans, unary RPCs will have one

// The request should have the security events
events, _ := finished[0].Tag("_dd.appsec.json").(string)
events, _ := finished[len(finished)-1 /* root span */].Tag("_dd.appsec.json").(string)
require.NotEmpty(t, events)
for _, rule := range tc.expectedMatchedRules {
require.Contains(t, events, rule)
Expand Down Expand Up @@ -236,28 +236,6 @@ func TestBlocking(t *testing.T) {
})
}
})

t.Run("stream-block", func(t *testing.T) {
client, mt, cleanup := setup()
defer cleanup()

ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.4"))
stream, err := client.StreamPing(ctx)
require.NoError(t, err)
reply, err := stream.Recv()
err = stream.CloseSend()
require.NoError(t, err)

require.Equal(t, codes.Aborted, status.Code(err))
require.Nil(t, reply)

finished := mt.FinishedSpans()
require.Len(t, finished, 1)
// The request should have the attack attempts
event, _ := finished[0].Tag("_dd.appsec.json").(string)
require.NotNil(t, event)
require.True(t, strings.Contains(event, "blk-001-001"))
})
}

func TestPasslist(t *testing.T) {
Expand Down Expand Up @@ -406,17 +384,20 @@ func (s *appsecFixtureServer) StreamPing(stream Fixture_StreamPingServer) (err e
ctx := stream.Context()
md, _ := metadata.FromIncomingContext(ctx)
ids := md.Get("user-id")
if err := pappsec.SetUser(ctx, ids[0]); err != nil {
return err
if len(ids) > 0 {
if err := pappsec.SetUser(ctx, ids[0]); err != nil {
return err
}
}
return s.s.StreamPing(stream)
}
func (s *appsecFixtureServer) Ping(ctx context.Context, in *FixtureRequest) (*FixtureReply, error) {
md, _ := metadata.FromIncomingContext(ctx)
ids := md.Get("user-id")
if err := pappsec.SetUser(ctx, ids[0]); err != nil {
return nil, err
if len(ids) > 0 {
if err := pappsec.SetUser(ctx, ids[0]); err != nil {
return nil, err
}
}

return s.s.Ping(ctx, in)
}
9 changes: 4 additions & 5 deletions internal/appsec/emitter/grpcsec/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,14 @@ type (

// ReceiveOperationArgs is the gRPC handler receive operation arguments
// Empty as of today.
ReceiveOperationArgs struct{
// Message received by the gRPC handler.
// Corresponds to the address `grpc.server.request.message`.
Message interface{}
}
ReceiveOperationArgs struct{}

// ReceiveOperationRes is the gRPC handler receive operation results which
// contains the message the gRPC handler received.
ReceiveOperationRes struct {
// Message received by the gRPC handler.
// Corresponds to the address `grpc.server.request.message`.
Message interface{}
}

// MonitoringError is used to vehicle a gRPC error that also embeds a request status code
Expand Down
12 changes: 2 additions & 10 deletions internal/appsec/listener/grpcsec/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,7 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types
addEvents(wafResult.Events)
}
if wafResult.HasActions() {
for aType, params := range wafResult.Actions {
for _, action := range shared.ActionsFromEntry(aType, params) {
if grpcAction, ok := action.(*sharedsec.GRPCAction); ok {
code, err := grpcAction.GRPCWrapper()
dyngo.EmitData(op, types.NewMonitoringError(err.Error(), code))
}
}
}
shared.ProcessActions(op, wafResult.Actions, &types.MonitoringError{})
log.Debug("appsec: WAF detected an authenticated user attack: %s", args.UserID)
}
})
Expand Down Expand Up @@ -163,7 +156,7 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types
}

// When the gRPC handler receives a message
dyngo.On(op, func(_ types.ReceiveOperation, res types.ReceiveOperationArgs) {
dyngo.OnFinish(op, func(_ types.ReceiveOperation, res types.ReceiveOperationRes) {
// Run the WAF on the rule addresses available and listened to by the sec rules
var values waf.RunAddressData
// Add the gRPC message to the values if the WAF rules are using it.
Expand All @@ -184,7 +177,6 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types
wafResult := shared.RunWAF(wafCtx, values)
if wafResult.HasEvents() {
log.Debug("appsec: attack detected by the grpc waf")

addEvents(wafResult.Events)
}
if wafResult.HasActions() {
Expand Down

0 comments on commit 89ecea1

Please sign in to comment.