From 8dadd5e82ae0f95dc5762f15e1e806b3abb2916d Mon Sep 17 00:00:00 2001 From: Ronny Haryanto Date: Sat, 27 Apr 2024 22:49:02 +1000 Subject: [PATCH] Support tracing Fixes 3 --- Makefile | 5 +- README.md | 60 +++++++++++++++++--- handler.go | 67 +++++++++++++++++++---- handler_test.go | 22 ++++++-- keys.go | 8 ++- trace/cloud_trace.go | 87 +++++++++++++++++++++++++++++ trace/cloud_trace_test.go | 21 +++++++ trace/context.go | 22 ++++++++ trace/doc.go | 3 + trace/errors.go | 8 +++ trace/middleware.go | 30 ++++++++++ trace/record.go | 38 +++++++++++++ trace/record_test.go | 101 ++++++++++++++++++++++++++++++++++ trace/trace.go | 7 +++ trace/w3c_trace.go | 112 ++++++++++++++++++++++++++++++++++++++ trace/w3c_trace_test.go | 106 ++++++++++++++++++++++++++++++++++++ 16 files changed, 670 insertions(+), 27 deletions(-) create mode 100644 trace/cloud_trace.go create mode 100644 trace/cloud_trace_test.go create mode 100644 trace/context.go create mode 100644 trace/doc.go create mode 100644 trace/errors.go create mode 100644 trace/middleware.go create mode 100644 trace/record.go create mode 100644 trace/record_test.go create mode 100644 trace/trace.go create mode 100644 trace/w3c_trace.go create mode 100644 trace/w3c_trace_test.go diff --git a/Makefile b/Makefile index 84cab97..7270ebf 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,10 @@ all: test test: - go test -v -shuffle=on -coverprofile=coverage.txt -count=1 ./... + go test -shuffle=on -coverprofile=coverage.txt -count=1 ./... + +bench: + go test -run=XXX -bench=. ./... vet: go vet ./... diff --git a/README.md b/README.md index 9272847..423bcef 100644 --- a/README.md +++ b/README.md @@ -6,35 +6,81 @@ A lightweight [`log/slog.JSONHandler`](https://pkg.go.dev/log/slog#JSONHandler) wrapper that adapts the fields to the [Google Cloud Logging structured log format](https://cloud.google.com/logging/docs/structured-logging#structured_logging_special_fields). -The handler merely reformats/renames the structured JSON log fields. It's -still `JSONHandler` under the hood. It does NOT send logs to Cloud Logging -directly (e.g. using the Cloud SDK). - The intended use case is Cloud Run, but it should work in similar environments where logs are emitted to stdout/stderr and automatically picked up by Cloud Logging (e.g. App Engine, Cloud Functions, GKE). +## Features + +- Lightweight. The handler merely reformats/renames the structured JSON log + fields. It's still [`log/slog.JSONHandler`](https://pkg.go.dev/log/slog#JSONHandler) + under the hood. It does NOT send logs to Cloud Logging directly (e.g. using + the Cloud SDK). + +- Tracing. A tracing middleware is provided to automatically extract tracing + information from `traceparent` or `X-Cloud-Trace-Context` HTTP request header, + and attaches it to the request context. The Handler automatically includes any + tracing information as log attributes. + +- Custom levels as supported by Google Cloud Logging, e.g. CRITICAL and NOTICE. + ## Usage ```go import ( "log/slog" + "cloud.google.com/go/compute/metadata" "github.com/ronny/clog" ) func main() { + projectID, err := metadata.ProjectID() + if err != nil { + panic(err) + } + logger := slog.New( clog.NewHandler(os.Stderr, clog.HandlerOptions{ Level: clog.LevelInfo, + GoogleProjectID: projectID, }), ) - logger.Warn("flux capacitor is too warm", "tempCelsius", 42) - logger.Log(ctx, clog.LevelCritical, "flux capacitor is on fire") + slog.SetDefault(logger) + + mux := http.NewServeMux() + mux.Handle("POST /warn", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + // ⚠️ The log will have tracing attrs since we're using + // trace.Middleware below (assuming trace header is in the request): + // "logging.googleapis.com/trace" + // "logging.googleapis.com/spanId" + // "logging.googleapis.com/traceSampled" + slog.WarnContext(ctx, "flux capacitor is too warm", + "mycount", 42, + ) + })) + mux.Handle("POST /critical", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + // ⚠️ Custom level CRITICAL + slog.Log(ctx, clog.LevelCritical, "flux capacitor is on fire") + })) + + port := os.Getenv("PORT") + if port == "" { + port = "8080" + } + log.Printf("listening on port %s", port) + + // ⚠️ `trace.Middleware` to make tracing information available in ctx in mux + // handlers. + if err := http.ListenAndServe(":"+port, trace.Middleware(mux)); err != nil { + log.Fatal(err) + } } ``` ## Credits and acknowledgements -Thank you to Remko Tronçon for doing all the hard work in +Thank you to Remko Tronçon for doing most of the hard work in https://github.com/remko/cloudrun-slog which is the basis for this library. diff --git a/handler.go b/handler.go index c22bd15..7351688 100644 --- a/handler.go +++ b/handler.go @@ -1,20 +1,67 @@ package clog import ( + "context" + "errors" + "fmt" "io" "log/slog" + + "github.com/ronny/clog/trace" ) -type HandlerOptions = slog.HandlerOptions +var ErrInvalidHandlerOptions = errors.New("invalid HandlerOptions") -// NewHandler returns a [log/slog.JSONHandler] pre-configured for Google Cloud -// Logging. -// -// Basically it replaces `opts.ReplaceAttr` with [ReplaceAttr]. -func NewHandler(w io.Writer, opts *HandlerOptions) *slog.JSONHandler { - if opts == nil { - opts = &HandlerOptions{} - } +type HandlerOptions struct { + AddSource bool + Level slog.Leveler + ReplaceAttr func(groups []string, a slog.Attr) slog.Attr + GoogleProjectID string +} + +var _ slog.Handler = (*Handler)(nil) + +// Handler is a [log/slog.JSONHandler] preconfigured for Google Cloud Logging. +type Handler struct { + opts HandlerOptions + handler slog.Handler +} + +func NewHandler(w io.Writer, opts HandlerOptions) (*Handler, error) { opts.ReplaceAttr = ReplaceAttr - return slog.NewJSONHandler(w, opts) + if opts.GoogleProjectID == "" { + return nil, fmt.Errorf("%w: missing GoogleProjectID", ErrInvalidHandlerOptions) + } + return &Handler{ + opts: opts, + handler: slog.NewJSONHandler(w, &slog.HandlerOptions{ + AddSource: opts.AddSource, + Level: opts.Level, + ReplaceAttr: opts.ReplaceAttr, + }), + }, nil +} + +// Handle implements [log/slog.Handler]. +func (h *Handler) Handle(ctx context.Context, record slog.Record) error { + return h.handler.Handle(ctx, + trace.NewRecord(ctx, record, h.opts.GoogleProjectID), + ) +} + +// Enabled implements [log/slog.Handler]. +func (h *Handler) Enabled(ctx context.Context, level slog.Level) bool { + return h.handler.Enabled(ctx, level) +} + +// WithAttrs implements [log/slog.Handler]. +func (h *Handler) WithAttrs(attrs []slog.Attr) slog.Handler { + h.handler = h.handler.WithAttrs(attrs) + return h +} + +// WithGroup implements [log/slog.Handler]. +func (h *Handler) WithGroup(name string) slog.Handler { + h.handler = h.handler.WithGroup(name) + return h } diff --git a/handler_test.go b/handler_test.go index df71142..14ed00d 100644 --- a/handler_test.go +++ b/handler_test.go @@ -3,7 +3,6 @@ package clog_test import ( "context" "encoding/json" - "fmt" "log/slog" "strings" "testing" @@ -17,10 +16,21 @@ func TestNewHandler(t *testing.T) { buf := strings.Builder{} - handler := clog.NewHandler(&buf, &clog.HandlerOptions{ - AddSource: true, - Level: clog.LevelInfo, + _, err := clog.NewHandler(&buf, clog.HandlerOptions{}) + if !assert.NotNil(t, err) { + return + } + assert.ErrorIs(t, err, clog.ErrInvalidHandlerOptions) + assert.ErrorContains(t, err, "missing GoogleProjectID") + + handler, err := clog.NewHandler(&buf, clog.HandlerOptions{ + AddSource: true, + Level: clog.LevelInfo, + GoogleProjectID: "my-project-id", }) + if !assert.Nil(t, err) { + return + } logger := slog.New(handler) @@ -37,7 +47,7 @@ func TestNewHandler(t *testing.T) { clog.SpanIDKey, "banana", ) - fmt.Printf("%q\n", buf.String()) + // fmt.Printf("%q\n", buf.String()) lines := strings.Split(buf.String(), "\n") // the log lines are \n terminated, so the last line will always be empty since we split on \n @@ -47,7 +57,7 @@ func TestNewHandler(t *testing.T) { assert.Equal(t, 2, len(lines)) var warnEntry Entry - err := json.Unmarshal([]byte(lines[0]), &warnEntry) + err = json.Unmarshal([]byte(lines[0]), &warnEntry) if err != nil { t.Fatal(err) } diff --git a/keys.go b/keys.go index a812072..d9258d8 100644 --- a/keys.go +++ b/keys.go @@ -1,5 +1,7 @@ package clog +import "github.com/ronny/clog/trace" + // Standard JSON log fields as per // https://cloud.google.com/logging/docs/structured-logging#structured_logging_special_fields const ( @@ -15,7 +17,7 @@ const ( LabelsKey = "logging.googleapis.com/labels" OperationKey = "logging.googleapis.com/operation" SourceLocationKey = "logging.googleapis.com/sourceLocation" - SpanIDKey = "logging.googleapis.com/spanId" - TraceKey = "logging.googleapis.com/trace" - TraceSampledKey = "logging.googleapis.com/traceSampled" + TraceKey = trace.TraceKey + SpanIDKey = trace.SpanIDKey + TraceSampledKey = trace.TraceSampledKey ) diff --git a/trace/cloud_trace.go b/trace/cloud_trace.go new file mode 100644 index 0000000..1b275b2 --- /dev/null +++ b/trace/cloud_trace.go @@ -0,0 +1,87 @@ +package trace + +import ( + "fmt" + "regexp" +) + +var _ Trace = (*CloudTraceContext)(nil) + +// TRACE_ID/SPAN_ID();o=OPTIONS +// Example: 70e0091f6f5d4643bb4eca9d81320c76/97123319527522;o=1 +var cloudTraceContextRegex = regexp.MustCompile(`^(?P[0-9a-fA-F]{32})/(?P[0-9]+);o=(?P.+)$`) + +// Deprecated: use [ParseW3CTraceParent] instead. +func ParseCloudTraceContext(s string) (*CloudTraceContext, error) { + match := cloudTraceContextRegex.FindStringSubmatch(s) + result := make(map[string]string) + for i, name := range cloudTraceContextRegex.SubexpNames() { + if i != 0 && name != "" { + result[name] = match[i] + } + } + t := &CloudTraceContext{ + TraceID: result["TraceID"], + SpanID: result["SpanID"], + Options: result["Options"], + } + err := t.Validate() + if err != nil { + return nil, err + } + return t, nil +} + +// Based on https://cloud.google.com/trace/docs/trace-context#http-requests. +// See also https://cloud.google.com/run/docs/trace for Cloud Run specific info. +// +// Deprecated: use [W3CTraceParent] instead. +type CloudTraceContext struct { + TraceID string + SpanID string + Options string +} + +func (t *CloudTraceContext) Validate() error { + if t == nil { + return fmt.Errorf("%w: nil CloudTraceContext", ErrInvalidTrace) + } + if len(t.TraceID) != 32 { + return fmt.Errorf("%w: invalid TraceID", ErrInvalidTrace) + } + if t.SpanID == "" { + return fmt.Errorf("%w: invalid SpanID", ErrInvalidTrace) + } + if t.Options == "" { + return fmt.Errorf("%w: invalid Options", ErrInvalidTrace) + } + return nil +} + +func (t *CloudTraceContext) GetTraceID() string { + if t == nil { + return "" + } + return t.TraceID +} + +func (t *CloudTraceContext) GetSpanID() string { + if t == nil { + return "" + } + return t.SpanID +} + +func (t *CloudTraceContext) GetOptions() string { + if t == nil { + return "" + } + return t.Options +} + +func (t *CloudTraceContext) Sampled() bool { + if t == nil { + return false + } + return t.Options == "1" +} diff --git a/trace/cloud_trace_test.go b/trace/cloud_trace_test.go new file mode 100644 index 0000000..46b6bd5 --- /dev/null +++ b/trace/cloud_trace_test.go @@ -0,0 +1,21 @@ +package trace_test + +import ( + "testing" + + "github.com/ronny/clog/trace" + "github.com/stretchr/testify/assert" +) + +func TestCloudTraceContext(t *testing.T) { + validExample := "70e0091f6f5d4643bb4eca9d81320c76/97123319527522;o=1" + + tr, err := trace.ParseCloudTraceContext(validExample) + if !assert.Nil(t, err) { + return + } + + assert.Equal(t, "70e0091f6f5d4643bb4eca9d81320c76", tr.GetTraceID()) + assert.Equal(t, "97123319527522", tr.GetSpanID()) + assert.Equal(t, true, tr.Sampled()) +} diff --git a/trace/context.go b/trace/context.go new file mode 100644 index 0000000..fe9341e --- /dev/null +++ b/trace/context.go @@ -0,0 +1,22 @@ +package trace + +import "context" + +type ctxKeyType int + +const ctxKey ctxKeyType = iota + +// NewContext returns a new context derived from ctx with trace attached. +func NewContext(ctx context.Context, t Trace) context.Context { + return context.WithValue(ctx, ctxKey, t) +} + +// FromContext extracts any [Trace] information from ctx and +// returns it if found, otherwise returns nil. +func FromContext(ctx context.Context) Trace { + t, ok := ctx.Value(ctxKey).(Trace) + if !ok { + return nil + } + return t +} diff --git a/trace/doc.go b/trace/doc.go new file mode 100644 index 0000000..fd30af0 --- /dev/null +++ b/trace/doc.go @@ -0,0 +1,3 @@ +// Package trace contains utilities to extract and use Cloud Trace context in +// logs. +package trace diff --git a/trace/errors.go b/trace/errors.go new file mode 100644 index 0000000..9f2c5b3 --- /dev/null +++ b/trace/errors.go @@ -0,0 +1,8 @@ +package trace + +import "errors" + +var ( + ErrUnparseable = errors.New("unparseable") + ErrInvalidTrace = errors.New("invalid trace") +) diff --git a/trace/middleware.go b/trace/middleware.go new file mode 100644 index 0000000..f98b20c --- /dev/null +++ b/trace/middleware.go @@ -0,0 +1,30 @@ +package trace + +import "net/http" + +func Middleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // https://cloud.google.com/trace/docs/trace-context#http-requests + var t Trace + + traceparent := r.Header.Get("traceparent") + if traceparent != "" { + // Intentionally ignoring err since there's nothing useful we can do it. + t, _ = ParseW3CTraceParent(traceparent) + } + + if t == nil { + cloudTraceContext := r.Header.Get("X-Cloud-Trace-Context") + if cloudTraceContext != "" { + // Intentionally ignoring err since there's nothing useful we can do it. + t, _ = ParseCloudTraceContext(cloudTraceContext) + } + } + + if t != nil { + r = r.WithContext(NewContext(r.Context(), t)) + } + + h.ServeHTTP(w, r) + }) +} diff --git a/trace/record.go b/trace/record.go new file mode 100644 index 0000000..67d1fce --- /dev/null +++ b/trace/record.go @@ -0,0 +1,38 @@ +package trace + +import ( + "context" + "log/slog" + "strings" +) + +const ( + TraceKey = "logging.googleapis.com/trace" + SpanIDKey = "logging.googleapis.com/spanId" + TraceSampledKey = "logging.googleapis.com/traceSampled" +) + +// NewRecord extracts trace information from ctx, then returns a clone of record +// with the trace attrs as per +// https://cloud.google.com/logging/docs/structured-logging#structured_logging_special_fields +// added to it. +func NewRecord(ctx context.Context, record slog.Record, projectID string) slog.Record { + t := FromContext(ctx) + if t == nil { + return record + } + + record = record.Clone() + + traceValue := strings.Builder{} + traceValue.WriteString("projects/") + traceValue.WriteString(projectID) + traceValue.WriteString("/traces/") + traceValue.WriteString(t.GetTraceID()) + + record.Add(TraceKey, slog.StringValue(traceValue.String())) + record.Add(SpanIDKey, slog.StringValue(t.GetSpanID())) + record.Add(TraceSampledKey, slog.BoolValue(t.Sampled())) + + return record +} diff --git a/trace/record_test.go b/trace/record_test.go new file mode 100644 index 0000000..39d6440 --- /dev/null +++ b/trace/record_test.go @@ -0,0 +1,101 @@ +package trace_test + +import ( + "context" + "log/slog" + "testing" + "time" + + "github.com/ronny/clog/trace" + "github.com/stretchr/testify/assert" +) + +func TestNewRecord(t *testing.T) { + now := time.Now() + + record := slog.NewRecord( + now, + slog.LevelInfo, + "hello fellow kids", + 0, + ) + + tr, err := trace.ParseW3CTraceParent( + "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", + ) + if err != nil { + t.Fatal(err) + } + + ctx := trace.NewContext(context.Background(), tr) + + projectID := "your-project-id" + + returnedRecord := trace.NewRecord(ctx, record, projectID) + + hasTrace := false + hasSpan := false + hasSampled := false + + returnedRecord.Attrs(func(attr slog.Attr) bool { + switch attr.Key { + case trace.TraceKey: + hasTrace = true + assert.Equal(t, slog.KindString, attr.Value.Kind()) + assert.Equal(t, + "projects/your-project-id/traces/4bf92f3577b34da6a3ce929d0e0e4736", + attr.Value.String(), + ) + case trace.SpanIDKey: + hasSpan = true + assert.Equal(t, slog.KindString, attr.Value.Kind()) + assert.Equal(t, "00f067aa0ba902b7", attr.Value.String()) + case trace.TraceSampledKey: + hasSampled = true + assert.Equal(t, slog.KindBool, attr.Value.Kind()) + assert.Equal(t, true, attr.Value.Bool()) + } + return true + }) + + assert.True(t, hasTrace) + assert.True(t, hasSpan) + assert.True(t, hasSampled) +} + +var benchmarkResult slog.Record + +func BenchmarkNewRecord(b *testing.B) { + now := time.Now() + + record := slog.NewRecord( + now, + slog.LevelInfo, + "hello fellow kids", + 0, + ) + + tr, err := trace.ParseW3CTraceParent( + "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01", + ) + if err != nil { + b.Fatal(err) + } + + ctx := trace.NewContext(context.Background(), tr) + + projectID := "your-project-id" + + b.ResetTimer() + + var r slog.Record + for n := 0; n < b.N; n++ { + // Always store the result to avoid the compiler eliminating the function + // call. + r = trace.NewRecord(ctx, record, projectID) + } + + // Always store the result to a package level variable + // so the compiler cannot eliminate the Benchmark itself. + benchmarkResult = r +} diff --git a/trace/trace.go b/trace/trace.go new file mode 100644 index 0000000..86a721f --- /dev/null +++ b/trace/trace.go @@ -0,0 +1,7 @@ +package trace + +type Trace interface { + GetTraceID() string + GetSpanID() string + Sampled() bool +} diff --git a/trace/w3c_trace.go b/trace/w3c_trace.go new file mode 100644 index 0000000..30f7b5e --- /dev/null +++ b/trace/w3c_trace.go @@ -0,0 +1,112 @@ +package trace + +import ( + "encoding/hex" + "fmt" + "strings" +) + +var _ Trace = (*W3CTraceParent)(nil) + +func ParseW3CTraceParent(s string) (*W3CTraceParent, error) { + parts := strings.Split(s, "-") + if len(parts) != 4 { + return nil, fmt.Errorf("%w: incorrect number of parts %d, expected 4", ErrUnparseable, len(parts)) + } + + versionBytes, err := hex.DecodeString(parts[0]) + if err != nil { + return nil, fmt.Errorf("%w: Version: hex.DecodeString: %w", ErrUnparseable, err) + } + if len(versionBytes) != 1 { + return nil, fmt.Errorf("%w: Version: expected 1 byte, got %d", ErrUnparseable, len(versionBytes)) + } + + traceIDBytes, err := hex.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("%w: TraceID: hex.DecodeString: %w", ErrUnparseable, err) + } + if len(traceIDBytes) != 16 { + return nil, fmt.Errorf("%w: TraceID: expected 16 bytes, got %d", ErrUnparseable, len(traceIDBytes)) + } + + parentIDBytes, err := hex.DecodeString(parts[2]) + if err != nil { + return nil, fmt.Errorf("%w: ParentID: hex.DecodeString: %w", ErrUnparseable, err) + } + if len(parentIDBytes) != 8 { + return nil, fmt.Errorf("%w: ParentID: expected 8 bytes, got %d", ErrUnparseable, len(parentIDBytes)) + } + + flagsBytes, err := hex.DecodeString(parts[3]) + if err != nil { + return nil, fmt.Errorf("%w: TraceFlags: hex.DecodeString: %w", ErrUnparseable, err) + } + if len(flagsBytes) != 1 { + return nil, fmt.Errorf("%w: TraceFlags: expected 1 byte, got %d", ErrUnparseable, len(flagsBytes)) + } + + return &W3CTraceParent{ + Version: versionBytes[0], + TraceID: parts[1], + ParentID: parts[2], + TraceFlags: flagsBytes[0], + }, nil +} + +// Based on https://www.w3.org/TR/trace-context/#traceparent-header +type W3CTraceParent struct { + Version byte + TraceID string + ParentID string + TraceFlags byte +} + +func (t *W3CTraceParent) GetVersion() byte { + if t == nil { + return 0 + } + return t.Version +} + +func (t *W3CTraceParent) GetTraceID() string { + if t == nil { + return "" + } + return t.TraceID +} + +func (t *W3CTraceParent) GetParentID() string { + if t == nil { + return "" + } + return t.ParentID +} + +func (t *W3CTraceParent) GetSpanID() string { + return t.GetParentID() +} + +func (t *W3CTraceParent) GetTraceFlags() byte { + if t == nil { + return 0 + } + return t.TraceFlags +} + +// https://www.w3.org/TR/trace-context/#sampled-flag +const W3CFlagSampled byte = 1 + +func (t *W3CTraceParent) Sampled() bool { + if t == nil { + return false + } + + // Need to mask because TraceFlags is a bit field, we need to check + // whether the least significant bit (the rightmost one) is on or not, + // regardless of the other bits. Another way to check is to test if the + // number is odd or not, but the mask way works with any bit position. + // + // https://www.w3.org/TR/trace-context/#trace-flags + return t.TraceFlags&W3CFlagSampled == W3CFlagSampled +} diff --git a/trace/w3c_trace_test.go b/trace/w3c_trace_test.go new file mode 100644 index 0000000..dbd0752 --- /dev/null +++ b/trace/w3c_trace_test.go @@ -0,0 +1,106 @@ +package trace_test + +import ( + "fmt" + "testing" + + "github.com/ronny/clog/trace" + "github.com/stretchr/testify/assert" +) + +func TestParseW3CTraceParent(t *testing.T) { + testCases := []struct { + desc string + input string + expected *trace.W3CTraceParent + expectedErr error + }{ + { + desc: "empty string -> ErrUnparseable", + input: "", + expectedErr: trace.ErrUnparseable, + }, + { + desc: "valid looking -> expected", + input: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-00", + expected: &trace.W3CTraceParent{ + Version: 0, + TraceID: "4bf92f3577b34da6a3ce929d0e0e4736", + ParentID: "00f067aa0ba902b7", + TraceFlags: 0, + }, + }, + { + desc: "valid looking but wrong number of bytes -> ErrUnparseable", + input: "00-abc-def-01", + expectedErr: trace.ErrUnparseable, + }, + } + for i, tc := range testCases { + tc := tc + t.Run(fmt.Sprintf("%d: %s", i, tc.desc), func(t *testing.T) { + tp, err := trace.ParseW3CTraceParent(tc.input) + if tc.expectedErr != nil { + assert.NotNil(t, err) + assert.ErrorIs(t, err, tc.expectedErr) + return + } + assert.Nil(t, err) + assert.Equal(t, tc.expected.GetVersion(), tp.GetVersion()) + assert.Equal(t, tc.expected.GetTraceID(), tp.GetTraceID()) + assert.Equal(t, tc.expected.GetParentID(), tp.GetParentID()) + assert.Equal(t, tc.expected.GetSpanID(), tp.GetSpanID()) + assert.Equal(t, tc.expected.GetTraceFlags(), tp.GetTraceFlags()) + assert.Equal(t, tc.expected.Sampled(), tp.Sampled()) + }) + } +} + +func TestW3CTraceContext_Sampled(t *testing.T) { + testCases := []struct { + desc string + trace *trace.W3CTraceParent + expected bool + }{ + { + desc: "nil trace -> false", + trace: nil, + expected: false, + }, + { + desc: "0x00 -> false", + trace: &trace.W3CTraceParent{ + TraceFlags: 0, + }, + expected: false, + }, + { + desc: "0x01 -> true", + trace: &trace.W3CTraceParent{ + TraceFlags: 1, + }, + expected: true, + }, + { + desc: "0xf3 -> true", + trace: &trace.W3CTraceParent{ + TraceFlags: 0xf3, + }, + expected: true, + }, + { + desc: "0xa0 -> false", + trace: &trace.W3CTraceParent{ + TraceFlags: 0xa0, + }, + expected: false, + }, + } + + for i, tc := range testCases { + tc := tc + t.Run(fmt.Sprintf("%d: %s", i, tc.desc), func(t *testing.T) { + assert.Equal(t, tc.expected, tc.trace.Sampled()) + }) + } +}