diff --git a/contrib/google.golang.org/grpc/appsec.go b/contrib/google.golang.org/grpc/appsec.go
index ffea569bad..0b03cdc6fe 100644
--- a/contrib/google.golang.org/grpc/appsec.go
+++ b/contrib/google.golang.org/grpc/appsec.go
@@ -29,9 +29,8 @@ import (
// UnaryHandler wrapper to use when AppSec is enabled to monitor its execution.
func appsecUnaryHandlerMiddleware(method string, span ddtrace.Span, handler grpc.UnaryHandler) grpc.UnaryHandler {
trace.SetAppSecEnabledTags(span)
- return func(ctx context.Context, req interface{}) (interface{}, error) {
- var err error
- var blocked bool
+ return func(ctx context.Context, req interface{}) (res interface{}, rpcErr error) {
+ var blockedErr error
md, _ := metadata.FromIncomingContext(ctx)
clientIP := setClientIP(ctx, span, md)
args := types.HandlerOperationArgs{
@@ -41,46 +40,50 @@ func appsecUnaryHandlerMiddleware(method string, span ddtrace.Span, handler grpc
}
ctx, op := grpcsec.StartHandlerOperation(ctx, args, nil, func(op *types.HandlerOperation) {
dyngo.OnData(op, func(a *sharedsec.GRPCAction) {
- code, e := a.GRPCWrapper()
- blocked = a.Blocking()
- err = status.Error(codes.Code(code), e.Error())
+ if a.Blocking() {
+ code, err := a.GRPCWrapper()
+ blockedErr = status.Error(codes.Code(code), err.Error())
+ }
})
})
defer func() {
events := op.Finish(types.HandlerOperationRes{})
- if blocked {
+ if len(events) > 0 {
+ grpctrace.SetSecurityEventsTags(span, events)
+ }
+ if blockedErr != nil {
op.SetTag(trace.BlockedRequestTag, true)
+ rpcErr = blockedErr
}
grpctrace.SetRequestMetadataTags(span, md)
trace.SetTags(span, op.Tags())
- if len(events) > 0 {
- grpctrace.SetSecurityEventsTags(span, events)
- }
}()
- if err != nil {
- return nil, err
+ // Check if a blocking condition was detected so far with the start operation event (ip blocking, metadata blocking, etc.)
+ if blockedErr != nil {
+ return nil, blockedErr
}
- defer grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{Message: req}, op).Finish(types.ReceiveOperationRes{})
- if err != nil {
- return nil, err
+ // As of our gRPC abstract operation definition, we must fake a receive operation for unary RPCs (the same model fits both unary and streaming RPCs)
+ grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, op).Finish(types.ReceiveOperationRes{Message: req})
+ // Check if a blocking condition was detected so far with the receive operation events
+ if blockedErr != nil {
+ return nil, blockedErr
}
- rv, err := handler(ctx, req)
- if e, ok := err.(*types.MonitoringError); ok {
- err = status.Error(codes.Code(e.GRPCStatus()), e.Error())
- }
- return rv, err
+ // Call the original handler - let the deferred function above handle the blocking condition and return error
+ return handler(ctx, req)
}
}
// StreamHandler wrapper to use when AppSec is enabled to monitor its execution.
func appsecStreamHandlerMiddleware(method string, span ddtrace.Span, handler grpc.StreamHandler) grpc.StreamHandler {
trace.SetAppSecEnabledTags(span)
- return func(srv interface{}, stream grpc.ServerStream) error {
- appsecStream := &appsecServerStream{
- ServerStream: stream,
+ return func(srv interface{}, stream grpc.ServerStream) (rpcErr error) {
+ // Create a ServerStream wrapper with appsec RPC handler operation and the Go context (to implement the ServerStream interface)
+ appsecStream := &appsecServerStream{
+ ServerStream: stream,
+ // note: the blockedErr field is captured by the RPC handler's OnData closure below
}
ctx := stream.Context()
@@ -99,24 +102,28 @@ func appsecStreamHandlerMiddleware(method string, span ddtrace.Span, handler grp
if a.Blocking() {
code, e := a.GRPCWrapper()
appsecStream.blockedErr = status.Error(codes.Code(code), e.Error())
-
}
})
})
+ // Finish constructing the appsec stream wrapper and replace the original one
appsecStream.handlerOperation = op
appsecStream.ctx = ctx
- stream = appsecStream
defer func() {
events := op.Finish(types.HandlerOperationRes{})
+
+ if len(events) > 0 {
+ grpctrace.SetSecurityEventsTags(span, events)
+ }
+
if appsecStream.blockedErr != nil {
op.SetTag(trace.BlockedRequestTag, true)
+ // Change the RPC return error with appsec's
+ rpcErr = appsecStream.blockedErr
}
+
trace.SetTags(span, op.Tags())
- if len(events) > 0 {
- grpctrace.SetSecurityEventsTags(span, events)
- }
}()
// Check if a blocking condition was detected so far with the start operation event (ip blocking, metadata blocking, etc.)
@@ -124,12 +131,8 @@ func appsecStreamHandlerMiddleware(method string, span ddtrace.Span, handler grp
return appsecStream.blockedErr
}
- // Call the original handler
- err := handler(srv, stream)
- //if e, ok := err.(*types.MonitoringError); ok {
- // err = status.Error(codes.Code(e.GRPCStatus()), e.Error())
- //}
- return err
+ // Call the original handler - let the deferred function above handle the blocking condition and return error
+ return handler(srv, appsecStream)
}
}
@@ -144,18 +147,19 @@ type appsecServerStream struct {
// RecvMsg implements grpc.ServerStream interface method to monitor its
// execution with AppSec.
-func (ss appsecServerStream) RecvMsg(m interface{}) (err error) {
- op := grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{Message: m}, ss.handlerOperation)
- defer op.Finish(types.ReceiveOperationRes{})
-
- if ss.blockedErr != nil {
- return ss.blockedErr
- }
-
+func (ss *appsecServerStream) RecvMsg(m interface{}) (err error) {
+ op := grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, ss.handlerOperation)
+ defer func() {
+ op.Finish(types.ReceiveOperationRes{Message: m})
+ if ss.blockedErr != nil {
+ // Change the function call return error with appsec's
+ err = ss.blockedErr
+ }
+ }()
return ss.ServerStream.RecvMsg(m)
}
-func (ss appsecServerStream) Context() context.Context {
+func (ss *appsecServerStream) Context() context.Context {
return ss.ctx
}
diff --git a/contrib/google.golang.org/grpc/appsec_test.go b/contrib/google.golang.org/grpc/appsec_test.go
index d6b8280719..911efa7c5d 100644
--- a/contrib/google.golang.org/grpc/appsec_test.go
+++ b/contrib/google.golang.org/grpc/appsec_test.go
@@ -10,7 +10,6 @@ import (
"encoding/json"
"fmt"
"net"
- "strings"
"testing"
pappsec "gopkg.in/DataDog/dd-trace-go.v1/appsec"
@@ -169,16 +168,17 @@ func TestBlocking(t *testing.T) {
},
{
name: "message blocking",
- md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.5", "user-id", "blocked-user-1"),
+ md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.5", "user-id", "legit-user-1"),
message: "$globals",
expectedMatchedRules: []string{"crs-933-130-block"}, // message blocking alone as it comes before user blocking
expectedNotMatchedRules: []string{"blk-001-002"}, // no user blocking
},
{
- name: "user blocking",
- md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.5", "user-id", "blocked-user-1"),
- message: "",
- expectedMatchedRules: []string{"crs-941-110", "blk-001-002"}, // monitoring event + user blocking
+ name: "user blocking",
+ md: metadata.Pairs("m1", "v1", "x-client-ip", "1.2.3.5", "user-id", "blocked-user-1"),
+ message: "",
+ expectedMatchedRules: []string{"blk-001-002"}, // user blocking alone as it comes first in our test handler
+ expectedNotMatchedRules: []string{"crs-933-130-block"}, // message blocking alone as it comes before user blocking
},
} {
t.Run(tc.name, func(t *testing.T) {
@@ -190,10 +190,10 @@ func TestBlocking(t *testing.T) {
do(client)
finished := mt.FinishedSpans()
- require.Len(t, finished, 1)
+ require.True(t, len(finished) >= 1) // streaming RPCs will have two spans, unary RPCs will have one
// The request should have the security events
- events, _ := finished[0].Tag("_dd.appsec.json").(string)
+ events, _ := finished[len(finished)-1 /* root span */].Tag("_dd.appsec.json").(string)
require.NotEmpty(t, events)
for _, rule := range tc.expectedMatchedRules {
require.Contains(t, events, rule)
@@ -236,28 +236,6 @@ func TestBlocking(t *testing.T) {
})
}
})
-
- t.Run("stream-block", func(t *testing.T) {
- client, mt, cleanup := setup()
- defer cleanup()
-
- ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("dd-canary", "dd-test-scanner-log", "x-client-ip", "1.2.3.4"))
- stream, err := client.StreamPing(ctx)
- require.NoError(t, err)
- reply, err := stream.Recv()
- err = stream.CloseSend()
- require.NoError(t, err)
-
- require.Equal(t, codes.Aborted, status.Code(err))
- require.Nil(t, reply)
-
- finished := mt.FinishedSpans()
- require.Len(t, finished, 1)
- // The request should have the attack attempts
- event, _ := finished[0].Tag("_dd.appsec.json").(string)
- require.NotNil(t, event)
- require.True(t, strings.Contains(event, "blk-001-001"))
- })
}
func TestPasslist(t *testing.T) {
@@ -406,17 +384,20 @@ func (s *appsecFixtureServer) StreamPing(stream Fixture_StreamPingServer) (err e
ctx := stream.Context()
md, _ := metadata.FromIncomingContext(ctx)
ids := md.Get("user-id")
- if err := pappsec.SetUser(ctx, ids[0]); err != nil {
- return err
+ if len(ids) > 0 {
+ if err := pappsec.SetUser(ctx, ids[0]); err != nil {
+ return err
+ }
}
return s.s.StreamPing(stream)
}
func (s *appsecFixtureServer) Ping(ctx context.Context, in *FixtureRequest) (*FixtureReply, error) {
md, _ := metadata.FromIncomingContext(ctx)
ids := md.Get("user-id")
- if err := pappsec.SetUser(ctx, ids[0]); err != nil {
- return nil, err
+ if len(ids) > 0 {
+ if err := pappsec.SetUser(ctx, ids[0]); err != nil {
+ return nil, err
+ }
}
-
return s.s.Ping(ctx, in)
}
diff --git a/internal/appsec/emitter/grpcsec/types/types.go b/internal/appsec/emitter/grpcsec/types/types.go
index f41a832b3a..6e94ef9495 100644
--- a/internal/appsec/emitter/grpcsec/types/types.go
+++ b/internal/appsec/emitter/grpcsec/types/types.go
@@ -63,15 +63,14 @@ type (
// ReceiveOperationArgs is the gRPC handler receive operation arguments
// Empty as of today.
- ReceiveOperationArgs struct{
- // Message received by the gRPC handler.
- // Corresponds to the address `grpc.server.request.message`.
- Message interface{}
- }
+ ReceiveOperationArgs struct{}
// ReceiveOperationRes is the gRPC handler receive operation results which
// contains the message the gRPC handler received.
ReceiveOperationRes struct {
+ // Message received by the gRPC handler.
+ // Corresponds to the address `grpc.server.request.message`.
+ Message interface{}
}
// MonitoringError is used to vehicle a gRPC error that also embeds a request status code
diff --git a/internal/appsec/listener/grpcsec/grpc.go b/internal/appsec/listener/grpcsec/grpc.go
index 0429b9adb4..da8d14c0e9 100644
--- a/internal/appsec/listener/grpcsec/grpc.go
+++ b/internal/appsec/listener/grpcsec/grpc.go
@@ -127,14 +127,7 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types
addEvents(wafResult.Events)
}
if wafResult.HasActions() {
- for aType, params := range wafResult.Actions {
- for _, action := range shared.ActionsFromEntry(aType, params) {
- if grpcAction, ok := action.(*sharedsec.GRPCAction); ok {
- code, err := grpcAction.GRPCWrapper()
- dyngo.EmitData(op, types.NewMonitoringError(err.Error(), code))
- }
- }
- }
+ shared.ProcessActions(op, wafResult.Actions, &types.MonitoringError{})
log.Debug("appsec: WAF detected an authenticated user attack: %s", args.UserID)
}
})
@@ -163,7 +156,7 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types
}
// When the gRPC handler receives a message
- dyngo.On(op, func(_ types.ReceiveOperation, res types.ReceiveOperationArgs) {
+ dyngo.OnFinish(op, func(_ types.ReceiveOperation, res types.ReceiveOperationRes) {
// Run the WAF on the rule addresses available and listened to by the sec rules
var values waf.RunAddressData
// Add the gRPC message to the values if the WAF rules are using it.
@@ -184,7 +177,6 @@ func (l *wafEventListener) onEvent(op *types.HandlerOperation, handlerArgs types
wafResult := shared.RunWAF(wafCtx, values)
if wafResult.HasEvents() {
log.Debug("appsec: attack detected by the grpc waf")
-
addEvents(wafResult.Events)
}
if wafResult.HasActions() {