diff --git a/contrib/google.golang.org/grpc/appsec.go b/contrib/google.golang.org/grpc/appsec.go index ffea569bad..0b03cdc6fe 100644 --- a/contrib/google.golang.org/grpc/appsec.go +++ b/contrib/google.golang.org/grpc/appsec.go @@ -29,9 +29,8 @@ import ( // UnaryHandler wrapper to use when AppSec is enabled to monitor its execution. func appsecUnaryHandlerMiddleware(method string, span ddtrace.Span, handler grpc.UnaryHandler) grpc.UnaryHandler { trace.SetAppSecEnabledTags(span) - return func(ctx context.Context, req interface{}) (interface{}, error) { - var err error - var blocked bool + return func(ctx context.Context, req interface{}) (res interface{}, rpcErr error) { + var blockedErr error md, _ := metadata.FromIncomingContext(ctx) clientIP := setClientIP(ctx, span, md) args := types.HandlerOperationArgs{ @@ -41,46 +40,50 @@ func appsecUnaryHandlerMiddleware(method string, span ddtrace.Span, handler grpc } ctx, op := grpcsec.StartHandlerOperation(ctx, args, nil, func(op *types.HandlerOperation) { dyngo.OnData(op, func(a *sharedsec.GRPCAction) { - code, e := a.GRPCWrapper() - blocked = a.Blocking() - err = status.Error(codes.Code(code), e.Error()) + if a.Blocking() { + code, err := a.GRPCWrapper() + blockedErr = status.Error(codes.Code(code), err.Error()) + } }) }) defer func() { events := op.Finish(types.HandlerOperationRes{}) - if blocked { + if len(events) > 0 { + grpctrace.SetSecurityEventsTags(span, events) + } + if blockedErr != nil { op.SetTag(trace.BlockedRequestTag, true) + rpcErr = blockedErr } grpctrace.SetRequestMetadataTags(span, md) trace.SetTags(span, op.Tags()) - if len(events) > 0 { - grpctrace.SetSecurityEventsTags(span, events) - } }() - if err != nil { - return nil, err + // Check if a blocking condition was detected so far with the start operation event (ip blocking, metadata blocking, etc.) + if blockedErr != nil { + return nil, blockedErr } - defer grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{Message: req}, op).Finish(types.ReceiveOperationRes{}) - if err != nil { - return nil, err + // As of our gRPC abstract operation definition, we must fake a receive operation for unary RPCs (the same model fits both unary and streaming RPCs) + grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, op).Finish(types.ReceiveOperationRes{Message: req}) + // Check if a blocking condition was detected so far with the receive operation events + if blockedErr != nil { + return nil, blockedErr } - rv, err := handler(ctx, req) - if e, ok := err.(*types.MonitoringError); ok { - err = status.Error(codes.Code(e.GRPCStatus()), e.Error()) - } - return rv, err + // Call the original handler - let the deferred function above handle the blocking condition and return error + return handler(ctx, req) } } // StreamHandler wrapper to use when AppSec is enabled to monitor its execution. func appsecStreamHandlerMiddleware(method string, span ddtrace.Span, handler grpc.StreamHandler) grpc.StreamHandler { trace.SetAppSecEnabledTags(span) - return func(srv interface{}, stream grpc.ServerStream) error { - appsecStream := &appsecServerStream{ - ServerStream: stream, + return func(srv interface{}, stream grpc.ServerStream) (rpcErr error) { + // Create a ServerStream wrapper with appsec RPC handler operation and the Go context (to implement the ServerStream interface) + appsecStream := &appsecServerStream{ + ServerStream: stream, + // note: the blockedErr field is captured by the RPC handler's OnData closure below } ctx := stream.Context() @@ -99,24 +102,28 @@ func appsecStreamHandlerMiddleware(method string, span ddtrace.Span, handler grp if a.Blocking() { code, e := a.GRPCWrapper() appsecStream.blockedErr = status.Error(codes.Code(code), e.Error()) - } }) }) + // Finish constructing the appsec stream wrapper and replace the original one appsecStream.handlerOperation = op appsecStream.ctx = ctx - stream = appsecStream defer func() { events := op.Finish(types.HandlerOperationRes{}) + + if len(events) > 0 { + grpctrace.SetSecurityEventsTags(span, events) + } + if appsecStream.blockedErr != nil { op.SetTag(trace.BlockedRequestTag, true) + // Change the RPC return error with appsec's + rpcErr = appsecStream.blockedErr } + trace.SetTags(span, op.Tags()) - if len(events) > 0 { - grpctrace.SetSecurityEventsTags(span, events) - } }() // Check if a blocking condition was detected so far with the start operation event (ip blocking, metadata blocking, etc.) @@ -124,12 +131,8 @@ func appsecStreamHandlerMiddleware(method string, span ddtrace.Span, handler grp return appsecStream.blockedErr } - // Call the original handler - err := handler(srv, stream) - //if e, ok := err.(*types.MonitoringError); ok { - // err = status.Error(codes.Code(e.GRPCStatus()), e.Error()) - //} - return err + // Call the original handler - let the deferred function above handle the blocking condition and return error + return handler(srv, appsecStream) } } @@ -144,18 +147,19 @@ type appsecServerStream struct { // RecvMsg implements grpc.ServerStream interface method to monitor its // execution with AppSec. -func (ss appsecServerStream) RecvMsg(m interface{}) (err error) { - op := grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{Message: m}, ss.handlerOperation) - defer op.Finish(types.ReceiveOperationRes{}) - - if ss.blockedErr != nil { - return ss.blockedErr - } - +func (ss *appsecServerStream) RecvMsg(m interface{}) (err error) { + op := grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, ss.handlerOperation) + defer func() { + op.Finish(types.ReceiveOperationRes{Message: m}) + if ss.blockedErr != nil { + // Change the function call return error with appsec's + err = ss.blockedErr + } + }() return ss.ServerStream.RecvMsg(m) } -func (ss appsecServerStream) Context() context.Context { +func (ss *appsecServerStream) Context() context.Context { return ss.ctx } diff --git a/contrib/google.golang.org/grpc/appsec_test.go b/contrib/google.golang.org/grpc/appsec_test.go index d6b8280719..911efa7c5d 100644 --- a/contrib/google.golang.org/grpc/appsec_test.go +++ b/contrib/google.golang.org/grpc/appsec_test.go @@ -10,7 +10,6 @@ import ( "encoding/json" "fmt" "net" - "strings" "testing" pappsec "gopkg.in/DataDog/dd-trace-go.v1/appsec" @@ -169,16 +168,17 @@ func TestBlocking(t *testing.T) { }, { name: "message blocking", - md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.5", "user-id", "blocked-user-1"), + md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.5", "user-id", "legit-user-1"), message: "$globals", expectedMatchedRules: []string{"crs-933-130-block"}, // message blocking alone as it comes before user blocking expectedNotMatchedRules: []string{"blk-001-002"}, // no user blocking }, { - name: "user blocking", - md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.5", "user-id", "blocked-user-1"), - message: "", - expectedMatchedRules: []string{"crs-941-110", "blk-001-002"}, // monitoring event + user blocking + name: "user blocking", + md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.5", "user-id", "blocked-user-1"), + message: "", + expectedMatchedRules: []string{"blk-001-002"}, // user blocking alone as it comes first in our test handler + expectedNotMatchedRules: []string{"crs-933-130-block"}, // message blocking alone as it comes before user blocking }, } { t.Run(tc.name, func(t *testing.T) { @@ -190,10 +190,10 @@ func TestBlocking(t *testing.T) { do(client) finished := mt.FinishedSpans() - require.Len(t, finished, 1) + require.True(t, len(finished) >= 1) // streaming RPCs will have two spans, unary RPCs will have one // The request should have the security events - events, _ := finished[0].Tag("_dd.appsec.json").(string) + events, _ := finished[len(finished)-1 /* root span */].Tag("_dd.appsec.json").(string) require.NotEmpty(t, events) for _, rule := range tc.expectedMatchedRules { require.Contains(t, events, rule) @@ -236,28 +236,6 @@ func TestBlocking(t *testing.T) { }) } }) - - t.Run("stream-block", func(t *testing.T) { - client, mt, cleanup := setup() - defer cleanup() - - ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.4")) - stream, err := client.StreamPing(ctx) - require.NoError(t, err) - reply, err := stream.Recv() - err = stream.CloseSend() - require.NoError(t, err) - - require.Equal(t, codes.Aborted, status.Code(err)) - require.Nil(t, reply) - - finished := mt.FinishedSpans() - require.Len(t, finished, 1) - // The request should have the attack attempts - event, _ := finished[0].Tag("_dd.appsec.json").(string) - require.NotNil(t, event) - require.True(t, strings.Contains(event, "blk-001-001")) - }) } func TestPasslist(t *testing.T) { @@ -406,17 +384,20 @@ func (s *appsecFixtureServer) StreamPing(stream Fixture_StreamPingServer) (err e ctx := stream.Context() md, _ := metadata.FromIncomingContext(ctx) ids := md.Get("user-id") - if err := pappsec.SetUser(ctx, ids[0]); err != nil { - return err + if len(ids) > 0 { + if err := pappsec.SetUser(ctx, ids[0]); err != nil { + return err + } } return s.s.StreamPing(stream) } func (s *appsecFixtureServer) Ping(ctx context.Context, in *FixtureRequest) (*FixtureReply, error) { md, _ := metadata.FromIncomingContext(ctx) ids := md.Get("user-id") - if err := pappsec.SetUser(ctx, ids[0]); err != nil { - return nil, err + if len(ids) > 0 { + if err := pappsec.SetUser(ctx, ids[0]); err != nil { + return nil, err + } } - return s.s.Ping(ctx, in) } diff --git a/internal/appsec/emitter/grpcsec/types/types.go b/internal/appsec/emitter/grpcsec/types/types.go index f41a832b3a..6e94ef9495 100644 --- a/internal/appsec/emitter/grpcsec/types/types.go +++ b/internal/appsec/emitter/grpcsec/types/types.go @@ -63,15 +63,14 @@ type ( // ReceiveOperationArgs is the gRPC handler receive operation arguments // Empty as of today. - ReceiveOperationArgs struct{ - // Message received by the gRPC handler. - // Corresponds to the address `grpc.server.request.message`. - Message interface{} - } + ReceiveOperationArgs struct{} // ReceiveOperationRes is the gRPC handler receive operation results which // contains the message the gRPC handler received. ReceiveOperationRes struct { + // Message received by the gRPC handler. + // Corresponds to the address `grpc.server.request.message`. + Message interface{} } // MonitoringError is used to vehicle a gRPC error that also embeds a request status code diff --git a/internal/appsec/listener/grpcsec/grpc.go b/internal/appsec/listener/grpcsec/grpc.go index 0429b9adb4..da8d14c0e9 100644 --- a/internal/appsec/listener/grpcsec/grpc.go +++ b/internal/appsec/listener/grpcsec/grpc.go @@ -127,14 +127,7 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types addEvents(wafResult.Events) } if wafResult.HasActions() { - for aType, params := range wafResult.Actions { - for _, action := range shared.ActionsFromEntry(aType, params) { - if grpcAction, ok := action.(*sharedsec.GRPCAction); ok { - code, err := grpcAction.GRPCWrapper() - dyngo.EmitData(op, types.NewMonitoringError(err.Error(), code)) - } - } - } + shared.ProcessActions(op, wafResult.Actions, &types.MonitoringError{}) log.Debug("appsec: WAF detected an authenticated user attack: %s", args.UserID) } }) @@ -163,7 +156,7 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types } // When the gRPC handler receives a message - dyngo.On(op, func(_ types.ReceiveOperation, res types.ReceiveOperationArgs) { + dyngo.OnFinish(op, func(_ types.ReceiveOperation, res types.ReceiveOperationRes) { // Run the WAF on the rule addresses available and listened to by the sec rules var values waf.RunAddressData // Add the gRPC message to the values if the WAF rules are using it. @@ -184,7 +177,6 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types wafResult := shared.RunWAF(wafCtx, values) if wafResult.HasEvents() { log.Debug("appsec: attack detected by the grpc waf") - addEvents(wafResult.Events) } if wafResult.HasActions() {