Skip to content

Commit

Permalink
Improve efficiency by (1) passing in span directly, and (2) getting t…
Browse files Browse the repository at this point in the history
…race context bytes only once and reusing it for batch requests
  • Loading branch information
nhulston committed Oct 10, 2024
1 parent 71447c6 commit 9be0a75
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 104 deletions.
6 changes: 3 additions & 3 deletions contrib/aws/aws-sdk-go-v2/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ func (mw *traceMiddleware) startTraceMiddleware(stack *middleware.Stack) error {
// Inject trace context
switch serviceID {
case "SQS":
sqsTracer.EnrichOperation(spanctx, in, operation)
sqsTracer.EnrichOperation(span, in, operation)
case "SNS":
snsTracer.EnrichOperation(spanctx, in, operation)
snsTracer.EnrichOperation(span, in, operation)
case "EventBridge":
eventBridgeTracer.EnrichOperation(spanctx, in, operation)
eventBridgeTracer.EnrichOperation(span, in, operation)
}

// Handle initialize and continue through the middleware chain.
Expand Down
17 changes: 5 additions & 12 deletions contrib/aws/internal/eventbridge/eventbridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package eventbridge

import (
"context"
"encoding/json"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/eventbridge"
Expand All @@ -25,36 +24,30 @@ const (
maxSizeBytes = 256 * 1024 // 256 KB
)

func EnrichOperation(ctx context.Context, in middleware.InitializeInput, operation string) {
func EnrichOperation(span tracer.Span, in middleware.InitializeInput, operation string) {
switch operation {
case "PutEvents":
handlePutEvents(ctx, in)
handlePutEvents(span, in)
}
}

func handlePutEvents(ctx context.Context, in middleware.InitializeInput) {
func handlePutEvents(span tracer.Span, in middleware.InitializeInput) {
params, ok := in.Parameters.(*eventbridge.PutEventsInput)
if !ok {
log.Debug("Unable to read PutEvents params")
return
}

for i := range params.Entries {
injectTraceContext(ctx, &params.Entries[i])
injectTraceContext(span, &params.Entries[i])
}
}

func injectTraceContext(ctx context.Context, entryPtr *types.PutEventsRequestEntry) {
func injectTraceContext(span tracer.Span, entryPtr *types.PutEventsRequestEntry) {
if entryPtr == nil {
return
}

span, ok := tracer.SpanFromContext(ctx)
if !ok || span == nil {
log.Debug("Unable to find span from context")
return
}

carrier := tracer.TextMapCarrier{}
err := tracer.Inject(span.Context(), carrier)
if err != nil {
Expand Down
18 changes: 9 additions & 9 deletions contrib/aws/internal/eventbridge/eventbridge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,24 @@ func TestEnrichOperation(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

ctx := context.Background()
_, ctx = tracer.StartSpanFromContext(ctx, "test-span")
span := tracer.StartSpan("test-span")

input := middleware.InitializeInput{
Parameters: &eventbridge.PutEventsInput{
Entries: []types.PutEventsRequestEntry{
{
Detail: aws.String(`{"key": "value"}`),
Detail: aws.String(`{"@123": "value", "_foo": "bar"}`),
EventBusName: aws.String("test-bus"),
},
{
Detail: aws.String(`{"another": "data"}`),
Detail: aws.String(`{"@123": "data", "_foo": "bar"}`),
EventBusName: aws.String("test-bus-2"),
},
},
},
}

EnrichOperation(ctx, input, "PutEvents")
EnrichOperation(span, input, "PutEvents")

params, ok := input.Parameters.(*eventbridge.PutEventsInput)
require.True(t, ok)
Expand All @@ -55,6 +54,8 @@ func TestEnrichOperation(t *testing.T) {
err := json.Unmarshal([]byte(*entry.Detail), &detail)
require.NoError(t, err)

assert.Contains(t, detail, "@123") // make sure user data still exists
assert.Contains(t, detail, "_foo")
assert.Contains(t, detail, datadogKey)
ddData, ok := detail[datadogKey].(map[string]interface{})
require.True(t, ok)
Expand Down Expand Up @@ -109,7 +110,7 @@ func TestInjectTraceContext(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
injectTraceContext(ctx, &tt.entry)
injectTraceContext(span, &tt.entry)
tt.expected(t, &tt.entry)

var detail map[string]interface{}
Expand Down Expand Up @@ -147,8 +148,7 @@ func TestInjectTraceContextSizeLimit(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

ctx := context.Background()
_, ctx = tracer.StartSpanFromContext(ctx, "test-span")
span := tracer.StartSpan("test-span")

tests := []struct {
name string
Expand Down Expand Up @@ -187,7 +187,7 @@ func TestInjectTraceContextSizeLimit(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
injectTraceContext(ctx, &tt.entry)
injectTraceContext(span, &tt.entry)
tt.expected(t, &tt.entry)
})
}
Expand Down
59 changes: 33 additions & 26 deletions contrib/aws/internal/sns/sns.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package sns

import (
"context"
"encoding/json"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sns"
Expand All @@ -21,69 +20,77 @@ const (
maxMessageAttributes = 10
)

func EnrichOperation(ctx context.Context, in middleware.InitializeInput, operation string) {
func EnrichOperation(span tracer.Span, in middleware.InitializeInput, operation string) {
switch operation {
case "Publish":
handlePublish(ctx, in)
handlePublish(span, in)
case "PublishBatch":
handlePublishBatch(ctx, in)
handlePublishBatch(span, in)
}
}

func handlePublish(ctx context.Context, in middleware.InitializeInput) {
func handlePublish(span tracer.Span, in middleware.InitializeInput) {
params, ok := in.Parameters.(*sns.PublishInput)
if !ok {
log.Debug("Unable to read PublishInput params")
return
}

traceContext, err := getTraceContextBytes(span)
if err != nil {
log.Debug("Unable to get trace context: %s", err.Error())
return
}

if params.MessageAttributes == nil {
params.MessageAttributes = make(map[string]types.MessageAttributeValue)
}

injectTraceContext(ctx, params.MessageAttributes)
injectTraceContext(traceContext, params.MessageAttributes)
}

func handlePublishBatch(ctx context.Context, in middleware.InitializeInput) {
func handlePublishBatch(span tracer.Span, in middleware.InitializeInput) {
params, ok := in.Parameters.(*sns.PublishBatchInput)
if !ok {
log.Debug("Unable to read PublishBatch params")
return
}

traceContext, err := getTraceContextBytes(span)
if err != nil {
log.Debug("Unable to get trace context: %s", err.Error())
return
}

for i := range params.PublishBatchRequestEntries {
if params.PublishBatchRequestEntries[i].MessageAttributes == nil {
params.PublishBatchRequestEntries[i].MessageAttributes = make(map[string]types.MessageAttributeValue)
}
injectTraceContext(ctx, params.PublishBatchRequestEntries[i].MessageAttributes)
injectTraceContext(traceContext, params.PublishBatchRequestEntries[i].MessageAttributes)
}
}

func injectTraceContext(ctx context.Context, messageAttributes map[string]types.MessageAttributeValue) {
span, ok := tracer.SpanFromContext(ctx)
if !ok || span == nil {
log.Debug("Unable to find span from context")
return
}

// SNS only allows a maximum of 10 message attributes.
// https://docs.aws.amazon.com/sns/latest/dg/sns-message-attributes.html
// Only inject if there's room.
if len(messageAttributes) >= maxMessageAttributes {
log.Info("Cannot inject trace context: message already has maximum allowed attributes")
return
}

func getTraceContextBytes(span tracer.Span) ([]byte, error) {
carrier := tracer.TextMapCarrier{}
err := tracer.Inject(span.Context(), carrier)
if err != nil {
log.Debug("Unable to inject trace context: %s", err.Error())
return
return nil, err
}

jsonBytes, err := json.Marshal(carrier)
if err != nil {
log.Debug("Unable to marshal trace context: %s", err.Error())
return nil, err
}

return jsonBytes, nil
}

func injectTraceContext(jsonBytes []byte, messageAttributes map[string]types.MessageAttributeValue) {
// SNS only allows a maximum of 10 message attributes.
// https://docs.aws.amazon.com/sns/latest/dg/sns-message-attributes.html
// Only inject if there's room.
if len(messageAttributes) >= maxMessageAttributes {
log.Info("Cannot inject trace context: message already has maximum allowed attributes")
return
}

Expand Down
27 changes: 13 additions & 14 deletions contrib/aws/internal/sns/sns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestEnrichOperation(t *testing.T) {
name string
operation string
input middleware.InitializeInput
setup func(context.Context) context.Context
setup func(context.Context) tracer.Span
check func(*testing.T, middleware.InitializeInput)
}{
{
Expand All @@ -38,9 +38,9 @@ func TestEnrichOperation(t *testing.T) {
TopicArn: aws.String("arn:aws:sns:us-east-1:123456789012:test-topic"),
},
},
setup: func(ctx context.Context) context.Context {
_, ctx = tracer.StartSpanFromContext(ctx, "test-span")
return ctx
setup: func(ctx context.Context) tracer.Span {
span, _ := tracer.StartSpanFromContext(ctx, "test-span")
return span
},
check: func(t *testing.T, in middleware.InitializeInput) {
params, ok := in.Parameters.(*sns.PublishInput)
Expand Down Expand Up @@ -72,9 +72,9 @@ func TestEnrichOperation(t *testing.T) {
},
},
},
setup: func(ctx context.Context) context.Context {
_, ctx = tracer.StartSpanFromContext(ctx, "test-span")
return ctx
setup: func(ctx context.Context) tracer.Span {
span, _ := tracer.StartSpanFromContext(ctx, "test-span")
return span
},
check: func(t *testing.T, in middleware.InitializeInput) {
params, ok := in.Parameters.(*sns.PublishBatchInput)
Expand All @@ -101,11 +101,9 @@ func TestEnrichOperation(t *testing.T) {
defer mt.Stop()

ctx := context.Background()
if tt.setup != nil {
ctx = tt.setup(ctx)
}
span := tt.setup(ctx)

EnrichOperation(ctx, tt.input, tt.operation)
EnrichOperation(span, tt.input, tt.operation)

if tt.check != nil {
tt.check(t, tt.input)
Expand Down Expand Up @@ -142,8 +140,7 @@ func TestInjectTraceContext(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

ctx := context.Background()
span, ctx := tracer.StartSpanFromContext(ctx, "test-span")
span := tracer.StartSpan("test-span")

messageAttributes := make(map[string]types.MessageAttributeValue)
for i := 0; i < tt.existingAttributes; i++ {
Expand All @@ -153,7 +150,9 @@ func TestInjectTraceContext(t *testing.T) {
}
}

injectTraceContext(ctx, messageAttributes)
traceContext, err := getTraceContextBytes(span)
assert.NoError(t, err)
injectTraceContext(traceContext, messageAttributes)

if tt.expectInjection {
assert.Contains(t, messageAttributes, datadogKey)
Expand Down
Loading

0 comments on commit 9be0a75

Please sign in to comment.