Skip to content

Commit

Permalink
grpc/appsec: fix rpc message blocking
Browse files Browse the repository at this point in the history
  • Loading branch information
Julio-Guerra committed Jun 3, 2024
1 parent b0aa1b8 commit fa01bc0
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 239 deletions.
66 changes: 42 additions & 24 deletions contrib/google.golang.org/grpc/appsec.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func appsecUnaryHandlerMiddleware(method string, span ddtrace.Span, handler grpc
}
ctx, op := grpcsec.StartHandlerOperation(ctx, args, nil, func(op *types.HandlerOperation) {
dyngo.OnData(op, func(a *sharedsec.GRPCAction) {
code, e := a.GRPCWrapper(md)
code, e := a.GRPCWrapper()
blocked = a.Blocking()
err = status.Error(codes.Code(code), e.Error())
})
Expand All @@ -61,7 +61,12 @@ func appsecUnaryHandlerMiddleware(method string, span ddtrace.Span, handler grpc
if err != nil {
return nil, err
}
defer grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, op).Finish(types.ReceiveOperationRes{Message: req})

defer grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{Message: req}, op).Finish(types.ReceiveOperationRes{})
if err != nil {
return nil, err
}

rv, err := handler(ctx, req)
if e, ok := err.(*types.MonitoringError); ok {
err = status.Error(codes.Code(e.GRPCStatus()), e.Error())
Expand All @@ -74,33 +79,38 @@ func appsecUnaryHandlerMiddleware(method string, span ddtrace.Span, handler grpc
func appsecStreamHandlerMiddleware(method string, span ddtrace.Span, handler grpc.StreamHandler) grpc.StreamHandler {
trace.SetAppSecEnabledTags(span)
return func(srv interface{}, stream grpc.ServerStream) error {
var err error
var blocked bool
appsecStream := &appsecServerStream{
ServerStream: stream,
}

ctx := stream.Context()
md, _ := metadata.FromIncomingContext(ctx)
clientIP := setClientIP(ctx, span, md)
grpctrace.SetRequestMetadataTags(span, md)

// Create the handler operation and listen to blocking gRPC actions to detect a blocking condition
args := types.HandlerOperationArgs{
Method: method,
Metadata: md,
ClientIP: clientIP,
}
ctx, op := grpcsec.StartHandlerOperation(ctx, args, nil, func(op *types.HandlerOperation) {
dyngo.OnData(op, func(a *sharedsec.GRPCAction) {
code, e := a.GRPCWrapper(md)
blocked = a.Blocking()
err = status.Error(codes.Code(code), e.Error())
if a.Blocking() {
code, e := a.GRPCWrapper()
appsecStream.blockedErr = status.Error(codes.Code(code), e.Error())

}
})
})
stream = appsecServerStream{
ServerStream: stream,
handlerOperation: op,
ctx: ctx,
}

appsecStream.handlerOperation = op
appsecStream.ctx = ctx
stream = appsecStream

defer func() {
events := op.Finish(types.HandlerOperationRes{})
if blocked {
if appsecStream.blockedErr != nil {
op.SetTag(trace.BlockedRequestTag, true)
}
trace.SetTags(span, op.Tags())
Expand All @@ -109,14 +119,16 @@ func appsecStreamHandlerMiddleware(method string, span ddtrace.Span, handler grp
}
}()

if err != nil {
return err
// Check if a blocking condition was detected so far with the start operation event (ip blocking, metadata blocking, etc.)
if appsecStream.blockedErr != nil {
return appsecStream.blockedErr
}

err = handler(srv, stream)
if e, ok := err.(*types.MonitoringError); ok {
err = status.Error(codes.Code(e.GRPCStatus()), e.Error())
}
// Call the original handler
err := handler(srv, stream)
//if e, ok := err.(*types.MonitoringError); ok {
// err = status.Error(codes.Code(e.GRPCStatus()), e.Error())
//}
return err
}
}
Expand All @@ -125,15 +137,21 @@ type appsecServerStream struct {
grpc.ServerStream
handlerOperation *types.HandlerOperation
ctx context.Context

// blockedErr is used to store the error to return when a blocking sec event is detected.
blockedErr error
}

// RecvMsg implements grpc.ServerStream interface method to monitor its
// execution with AppSec.
func (ss appsecServerStream) RecvMsg(m interface{}) error {
op := grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{}, ss.handlerOperation)
defer func() {
op.Finish(types.ReceiveOperationRes{Message: m})
}()
func (ss appsecServerStream) RecvMsg(m interface{}) (err error) {
op := grpcsec.StartReceiveOperation(types.ReceiveOperationArgs{Message: m}, ss.handlerOperation)
defer op.Finish(types.ReceiveOperationRes{})

if ss.blockedErr != nil {
return ss.blockedErr
}

return ss.ServerStream.RecvMsg(m)
}

Expand Down
Loading

0 comments on commit fa01bc0

Please sign in to comment.