diff --git a/stats/opentelemetry/internal/grpc_trace_bin_propagator.go b/stats/opentelemetry/internal/grpc_trace_bin_propagator.go new file mode 100644 index 000000000000..901ee444f600 --- /dev/null +++ b/stats/opentelemetry/internal/grpc_trace_bin_propagator.go @@ -0,0 +1,132 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// TODO: Move out of internal as part of open telemetry API +package internal + +import ( + "context" + "encoding/base64" + + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/trace" + otelinternaltracing "google.golang.org/grpc/stats/opentelemetry/internal/tracing" +) + +// GRPCTraceBinPropagator is TextMapPropagator to propagate cross-cutting +// concerns as both text and binary key-value pairs within a carrier that +// travels in-band across process boundaries. +type GRPCTraceBinPropagator struct{} + +// Inject set cross-cutting concerns from the Context into the carrier. +// +// If carrier is carrier.CustomMapCarrier then SetBinary (fast path) is used, +// otherwise Set (slow path) with encoding is used. +func (p GRPCTraceBinPropagator) Inject(ctx context.Context, carrier propagation.TextMapCarrier) { + span := trace.SpanFromContext(ctx) + if !span.SpanContext().IsValid() { + return + } + + binaryData := Binary(span.SpanContext()) + if binaryData == nil { + return + } + + if customCarrier, ok := carrier.(otelinternaltracing.CustomCarrier); ok { + customCarrier.SetBinary(binaryData) // fast path: set the binary data without encoding + return + } else { + carrier.Set(otelinternaltracing.GRPCTraceBinHeaderKey, base64.StdEncoding.EncodeToString(binaryData)) // slow path: set the binary data with encoding + } +} + +// Extract reads cross-cutting concerns from the carrier into a Context. +// +// If carrier is carrier.CustomCarrier then GetBinary (fast path) is used, +// otherwise Get (slow path) with decoding is used. +func (p GRPCTraceBinPropagator) Extract(ctx context.Context, carrier propagation.TextMapCarrier) context.Context { + var binaryData []byte + + if customCarrier, ok := carrier.(otelinternaltracing.CustomCarrier); ok { + binaryData, _ = customCarrier.GetBinary() + } else { + binaryData, _ = base64.StdEncoding.DecodeString(carrier.Get(otelinternaltracing.GRPCTraceBinHeaderKey)) + } + if binaryData == nil { + return ctx + } + + spanContext, ok := FromBinary([]byte(binaryData)) + if !ok { + return ctx + } + + return trace.ContextWithRemoteSpanContext(ctx, spanContext) +} + +// Fields returns the keys whose values are set with Inject. +// +// GRPCTraceBinPropagator will only have `grpc-trace-bin` field. +func (p GRPCTraceBinPropagator) Fields() []string { + return []string{otelinternaltracing.GRPCTraceBinHeaderKey} +} + +// Binary returns the binary format representation of a SpanContext. +// +// If sc is the zero value, Binary returns nil. +func Binary(sc trace.SpanContext) []byte { + if sc.Equal(trace.SpanContext{}) { + return nil + } + var b [29]byte + traceID := trace.TraceID(sc.TraceID()) + copy(b[2:18], traceID[:]) + b[18] = 1 + spanID := trace.SpanID(sc.SpanID()) + copy(b[19:27], spanID[:]) + b[27] = 2 + b[28] = uint8(trace.TraceFlags(sc.TraceFlags())) + return b[:] +} + +// FromBinary returns the SpanContext represented by b. +// +// If b has an unsupported version ID or contains no TraceID, FromBinary +// returns with ok==false. +func FromBinary(b []byte) (sc trace.SpanContext, ok bool) { + if len(b) == 0 || b[0] != 0 { + return trace.SpanContext{}, false + } + b = b[1:] + + if len(b) >= 17 && b[0] == 0 { + sc = sc.WithTraceID(trace.TraceID(b[1:17])) + b = b[17:] + } else { + return trace.SpanContext{}, false + } + if len(b) >= 9 && b[0] == 1 { + sc = sc.WithSpanID(trace.SpanID(b[1:9])) + b = b[9:] + } + if len(b) >= 2 && b[0] == 2 { + sc = sc.WithTraceFlags(trace.TraceFlags(b[1])) + } + return sc, true +} diff --git a/stats/opentelemetry/internal/grpc_trace_bin_propagator_test.go b/stats/opentelemetry/internal/grpc_trace_bin_propagator_test.go new file mode 100644 index 000000000000..d7777bec32ee --- /dev/null +++ b/stats/opentelemetry/internal/grpc_trace_bin_propagator_test.go @@ -0,0 +1,116 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// TODO: Move out of internal as part of open telemetry API +package internal + +import ( + "context" + "encoding/base64" + "testing" + + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/trace" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" + "google.golang.org/grpc/stats/opentelemetry/internal/tracing" + otelinternaltracing "google.golang.org/grpc/stats/opentelemetry/internal/tracing" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +func (s) TestInject(t *testing.T) { + propagator := GRPCTraceBinPropagator{} + spanContext := trace.NewSpanContext(trace.SpanContextConfig{ + TraceID: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + SpanID: [8]byte{17, 18, 19, 20, 21, 22, 23, 24}, + TraceFlags: trace.FlagsSampled, + }) + traceCtx, traceCancel := context.WithCancel(context.Background()) + traceCtx = trace.ContextWithSpanContext(traceCtx, spanContext) + + t.Run("Fast path with CustomCarrier", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + carrier := tracing.NewCustomCarrier(metadata.NewOutgoingContext(ctx, metadata.MD{})) + propagator.Inject(traceCtx, carrier) + + got := stats.OutgoingTrace(*carrier.Ctx) + want := Binary(spanContext) + if string(got) != string(want) { + t.Fatalf("got = %v, want %v", got, want) + } + cancel() + }) + + t.Run("Slow path with TextMapCarrier", func(t *testing.T) { + carrier := propagation.MapCarrier{} + propagator.Inject(traceCtx, carrier) + + got := carrier.Get(otelinternaltracing.GRPCTraceBinHeaderKey) + want := base64.StdEncoding.EncodeToString(Binary(spanContext)) + if got != want { + t.Fatalf("got = %v, want %v", got, want) + } + }) + + traceCancel() +} + +func (s) TestExtract(t *testing.T) { + propagator := GRPCTraceBinPropagator{} + spanContext := trace.NewSpanContext(trace.SpanContextConfig{ + TraceID: [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + SpanID: [8]byte{17, 18, 19, 20, 21, 22, 23, 24}, + TraceFlags: trace.FlagsSampled, + Remote: true, + }) + binaryData := Binary(spanContext) + + t.Run("Fast path with CustomCarrier", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + carrier := tracing.NewCustomCarrier(stats.SetIncomingTrace(ctx, binaryData)) + traceCtx := propagator.Extract(ctx, carrier) + got := trace.SpanContextFromContext(traceCtx) + + if !got.Equal(spanContext) { + t.Fatalf("got = %v, want %v", got, spanContext) + } + cancel() + }) + + t.Run("Slow path with TextMapCarrier", func(t *testing.T) { + carrier := propagation.MapCarrier{ + otelinternaltracing.GRPCTraceBinHeaderKey: base64.StdEncoding.EncodeToString(binaryData), + } + ctx, cancel := context.WithCancel(context.Background()) + traceCtx := propagator.Extract(ctx, carrier) + got := trace.SpanContextFromContext(traceCtx) + + if !got.Equal(spanContext) { + t.Fatalf("got = %v, want %v", got, spanContext) + } + cancel() + }) +} diff --git a/stats/opentelemetry/internal/tracing/custom_map_carrier.go b/stats/opentelemetry/internal/tracing/custom_map_carrier.go new file mode 100644 index 000000000000..1343455f88af --- /dev/null +++ b/stats/opentelemetry/internal/tracing/custom_map_carrier.go @@ -0,0 +1,101 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package tracing implements the OpenTelemetry carrier for context propagation +// in gRPC tracing. +package tracing + +import ( + "context" + "errors" + + "go.opentelemetry.io/otel/propagation" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" +) + +const GRPCTraceBinHeaderKey = "grpc-trace-bin" + +// CustomMapCarrier is a TextMapCarrier that uses gRPC context to store and +// retrieve any propagated key-value pairs in text format along with binary +// format for `grpc-trace-bin` header +type CustomCarrier struct { + propagation.TextMapCarrier + + Ctx *context.Context +} + +// NewCustomCarrier creates a new CustomMapCarrier with +// the given context. +func NewCustomCarrier(ctx context.Context) CustomCarrier { + return CustomCarrier{ + Ctx: &ctx, + } +} + +// Get returns the string value associated with the passed key from the gRPC +// context. +func (c CustomCarrier) Get(key string) string { + md, ok := metadata.FromIncomingContext(*c.Ctx) + if !ok { + return "" + } + values := md.Get(key) + if len(values) == 0 { + return "" + } + return values[0] +} + +// Set stores the key-value pair in string format in the gRPC context. +// If the key already exists, its value will be overwritten. +func (c CustomCarrier) Set(key, value string) { + md, ok := metadata.FromOutgoingContext(*c.Ctx) + if !ok { + md = metadata.MD{} + } + md.Set(key, value) + *c.Ctx = metadata.NewOutgoingContext(*c.Ctx, md) +} + +// GetBinary returns the binary value from the gRPC context in the incoming RPC, +// associated with the header `grpc-trace-bin`. +func (c CustomCarrier) GetBinary() ([]byte, error) { + values := stats.Trace(*c.Ctx) + if len(values) == 0 { + return nil, errors.New("`grpc-trace-bin` header not found") + } + + return values, nil +} + +// SetBinary sets the binary value to the gRPC context, which will be sent in +// the outgoing RPC with the header grpc-trace-bin. +func (c CustomCarrier) SetBinary(value []byte) { + *c.Ctx = stats.SetTrace(*c.Ctx, value) +} + +// Keys lists the keys stored in the gRPC context for the outgoing RPC. +func (c CustomCarrier) Keys() []string { + md, _ := metadata.FromOutgoingContext(*c.Ctx) + keys := make([]string, 0, len(md)) + for k := range md { + keys = append(keys, k) + } + return keys +} diff --git a/stats/opentelemetry/internal/tracing/custom_map_carrier_test.go b/stats/opentelemetry/internal/tracing/custom_map_carrier_test.go new file mode 100644 index 000000000000..932051328703 --- /dev/null +++ b/stats/opentelemetry/internal/tracing/custom_map_carrier_test.go @@ -0,0 +1,202 @@ +/* + * + * Copyright 2024 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package tracing + +import ( + "context" + "reflect" + "testing" + + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +// sameElements checks if two string slices have the same elements, +// ignoring order. +func sameElements(a, b []string) bool { + if len(a) != len(b) { + return false + } + + countA := make(map[string]int) + countB := make(map[string]int) + for _, s := range a { + countA[s]++ + } + for _, s := range b { + countB[s]++ + } + + for k, v := range countA { + if countB[k] != v { + return false + } + } + return true +} + +func (s) TestGet(t *testing.T) { + tests := []struct { + name string + md metadata.MD + key string + want string + }{ + { + name: "existing key", + md: metadata.Pairs("key1", "value1"), + key: "key1", + want: "value1", + }, + { + name: "non-existing key", + md: metadata.Pairs("key1", "value1"), + key: "key2", + want: "", + }, + { + name: "empty key", + md: metadata.MD{}, + key: "key1", + want: "", + }, + } + + for _, tt := range tests { + ctx, cancel := context.WithCancel(context.Background()) + t.Run(tt.name, func(t *testing.T) { + c := NewCustomCarrier(metadata.NewIncomingContext(ctx, tt.md)) + got := c.Get(tt.key) + if got != tt.want { + t.Fatalf("got %s, want %s", got, tt.want) + } + cancel() + }) + } +} + +func (s) TestSet(t *testing.T) { + tests := []struct { + name string + initialMD metadata.MD // Metadata to initialize the context with + setKey string // Key to set using c.Set() + setValue string // Value to set using c.Set() + wantKeys []string // Expected keys returned by c.Keys() + }{ + { + name: "set new key", + initialMD: metadata.MD{}, + setKey: "key1", + setValue: "value1", + wantKeys: []string{"key1"}, + }, + { + name: "override existing key", + initialMD: metadata.MD{"key1": []string{"oldvalue"}}, + setKey: "key1", + setValue: "newvalue", + wantKeys: []string{"key1"}, + }, + { + name: "set key with existing unrelated key", + initialMD: metadata.MD{"key2": []string{"value2"}}, + setKey: "key1", + setValue: "value1", + wantKeys: []string{"key2", "key1"}, // Order matters here! + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := NewCustomCarrier(metadata.NewOutgoingContext(ctx, tt.initialMD)) + + c.Set(tt.setKey, tt.setValue) + + gotKeys := c.Keys() + if !sameElements(gotKeys, tt.wantKeys) { + t.Fatalf("got keys %v, want %v", gotKeys, tt.wantKeys) + } + gotMD, _ := metadata.FromOutgoingContext(*c.Ctx) + if gotMD.Get(tt.setKey)[0] != tt.setValue { + t.Fatalf("got value %s, want %s, for key %s", gotMD.Get(tt.setKey)[0], tt.setValue, tt.setKey) + } + cancel() + }) + } +} + +func (s) TestGetBinary(t *testing.T) { + t.Run("get grpc-trace-bin header", func(t *testing.T) { + want := []byte{0x01, 0x02, 0x03} + ctx, cancel := context.WithCancel(context.Background()) + c := NewCustomCarrier(stats.SetIncomingTrace(ctx, want)) + got, err := c.GetBinary() + if err != nil { + t.Fatalf("got error %v, want nil", err) + } + if string(got) != string(want) { + t.Fatalf("got %s, want %s", got, want) + } + cancel() + }) + + t.Run("get non grpc-trace-bin header", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := NewCustomCarrier(metadata.NewIncomingContext(ctx, metadata.Pairs("non-trace-bin", "\x01\x02\x03"))) + _, err := c.GetBinary() + if err == nil { + t.Fatalf("got nil error, want error") + } + cancel() + }) +} + +func (s) TestSetBinary(t *testing.T) { + t.Run("set grpc-trace-bin header", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + want := []byte{0x01, 0x02, 0x03} + c := NewCustomCarrier(stats.SetIncomingTrace(ctx, want)) + c.SetBinary(want) + got := stats.OutgoingTrace(*c.Ctx) + if !reflect.DeepEqual(got, want) { + t.Fatalf("got %v, want %v", got, want) + } + cancel() + }) + + t.Run("set non grpc-trace-bin header", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := NewCustomCarrier(metadata.NewOutgoingContext(ctx, metadata.MD{"non-trace-bin": []string{"value"}})) + got := stats.OutgoingTrace(*c.Ctx) + if got != nil { + t.Fatalf("got %v, want nil", got) + } + cancel() + }) +}