From 27d08a67df5b0d35544c663f9d47577f85ffc6e4 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Wed, 8 May 2024 15:44:15 +0200 Subject: [PATCH] Add traceparent header to enable distributed tracing. (#914) ## Changes Currently, it is difficult to correlate a request made by a client with the corresponding request received by a backend service, especially if the same request is made multiple times. The REST API supports the [Trace Context standard](https://www.w3.org/TR/trace-context/), which defines a standard way of propagating tracing information through HTTP headers. This PR implements a traceparent generator to construct new traceparent headers in accordance with this standard. A newly generated traceparent is attached to the headers of each individual request, unless the caller has already set a traceparent header, in which case that value is left untouched. The resulting header is visible when debug logs are enabled and DATABRICKS_DEBUG_HEADERS is set. ## Tests Added unit tests to ensure that a new traceparent is set for every request and that a user-supplied traceparent header is not overridden. It's hard to test that a debug log actually contains this header without rewriting a lot of the matching logic for debug logs, but I did need to remove traceparent from those unit tests, suggesting that debug logging does include traceparent (like any other header). Manually tested that the traceparent set by the client is visible to the server by checking our internal access logs.
- [ ] `make test` passing - [ ] `make fmt` applied - [ ] relevant integration tests applied --- client/client_test.go | 2 + config/config.go | 2 +- httpclient/api_client.go | 5 + httpclient/api_client_test.go | 41 ++++++++ httpclient/fixtures/fixture.go | 16 +++- httpclient/traceparent/traceparent.go | 104 +++++++++++++++++++++ httpclient/traceparent/traceparent_test.go | 53 +++++++++++ 7 files changed, 218 insertions(+), 5 deletions(-) create mode 100644 httpclient/traceparent/traceparent.go create mode 100644 httpclient/traceparent/traceparent_test.go diff --git a/client/client_test.go b/client/client_test.go index 987d8ab9b..4a86130e6 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -435,6 +435,8 @@ func testNonJSONResponseIncludedInError(t *testing.T, statusCode int, status, er Token: "token", ConfigFile: "/dev/null", HTTPTransport: hc(func(r *http.Request) (*http.Response, error) { + // Clear traceparent header which is nondeterministic. + r.Header.Del("traceparent") return &http.Response{ Proto: "HTTP/2.0", Status: status, diff --git a/config/config.go b/config/config.go index d5bfaeb9b..5da39cbbe 100644 --- a/config/config.go +++ b/config/config.go @@ -99,7 +99,7 @@ type Config struct { // Use at your own risk or for unit testing purposes. InsecureSkipVerify bool `name:"skip_verify" auth:"-"` - // Number of seconds for HTTP timeout. Default is 300 (5 minutes). + // Number of seconds for HTTP timeout. Default is 60 (1 minute). HTTPTimeoutSeconds int `name:"http_timeout_seconds" auth:"-"` // Truncate JSON fields in JSON above this limit. Default is 96. 
diff --git a/httpclient/api_client.go b/httpclient/api_client.go index 903c79696..ad4c4c63b 100644 --- a/httpclient/api_client.go +++ b/httpclient/api_client.go @@ -13,6 +13,7 @@ import ( "time" "github.com/databricks/databricks-sdk-go/common" + "github.com/databricks/databricks-sdk-go/httpclient/traceparent" "github.com/databricks/databricks-sdk-go/logger" "github.com/databricks/databricks-sdk-go/logger/httplog" "github.com/databricks/databricks-sdk-go/retries" @@ -236,6 +237,10 @@ func (c *ApiClient) attempt( return c.failRequest(ctx, "failed during request visitor", err) } } + // Set traceparent for distributed tracing. + // This must be done after all visitors have run, as they may modify the request. + traceparent.AddTraceparent(request) + // request.Context() holds context potentially enhanced by visitors request.Header.Set("User-Agent", useragent.FromContext(request.Context())) if request.Header.Get("Content-Type") == "" && requestBody.ContentType != "" { diff --git a/httpclient/api_client_test.go b/httpclient/api_client_test.go index 429056b3a..ba0ab9d8e 100644 --- a/httpclient/api_client_test.go +++ b/httpclient/api_client_test.go @@ -714,3 +714,44 @@ func TestDefaultAuthVisitor(t *testing.T) { err := c.Do(context.Background(), "GET", "/a/b", WithRequestData(map[string]any{}), authOption) require.NoError(t, err) } + +func TestTraceparentHeader(t *testing.T) { + seenTraceparents := []string{} + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + tp := r.Header.Get("Traceparent") + assert.NotEmpty(t, tp) + assert.NotContains(t, seenTraceparents, tp) + seenTraceparents = append(seenTraceparents, tp) + return &http.Response{ + StatusCode: 200, + Request: r, + }, nil + }), + }) + + for i := 0; i < 10; i++ { + err := c.Do(context.Background(), "GET", "/a/b") + assert.NoError(t, err) + } +} + +func TestTraceparentHeaderDoesNotOverrideUserHeader(t *testing.T) { + userTraceparent := "00-thetraceid-theparentid-00" + c := 
NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + tp := r.Header.Get("Traceparent") + assert.NotEmpty(t, tp) + assert.Equal(t, userTraceparent, tp) + return &http.Response{ + StatusCode: 200, + Request: r, + }, nil + }), + }) + + err := c.Do(context.Background(), "GET", "/a/b", WithRequestHeaders(map[string]string{ + "Traceparent": userTraceparent, + })) + assert.NoError(t, err) +} diff --git a/httpclient/fixtures/fixture.go b/httpclient/fixtures/fixture.go index 3d5fafe65..32b00fcbd 100644 --- a/httpclient/fixtures/fixture.go +++ b/httpclient/fixtures/fixture.go @@ -6,10 +6,13 @@ import ( "fmt" "io" "net/http" + "net/textproto" "net/url" "os" "reflect" "strings" + + "github.com/databricks/databricks-sdk-go/httpclient/traceparent" ) // HTTPFixture defines request structure for test @@ -38,12 +41,17 @@ func (f HTTPFixture) AssertHeaders(req *http.Request) error { return nil } actualHeaders := map[string]string{} + // remove user agent & traceparent from comparison, as it'll make fixtures too difficult + // to maintain in the long term + toSkip := map[string]struct{}{ + textproto.CanonicalMIMEHeaderKey("User-Agent"): {}, + textproto.CanonicalMIMEHeaderKey(traceparent.TRACEPARENT_HEADER): {}, + } for k := range req.Header { - actualHeaders[k] = req.Header.Get(k) + if _, skip := toSkip[k]; !skip { + actualHeaders[k] = req.Header.Get(k) + } } - // remove user agent from comparison, as it'll make fixtures too difficult - // to maintain in the long term - delete(actualHeaders, "User-Agent") if !reflect.DeepEqual(f.ExpectedHeaders, actualHeaders) { expectedJSON, _ := json.MarshalIndent(f.ExpectedHeaders, "", " ") actualJSON, _ := json.MarshalIndent(actualHeaders, "", " ") diff --git a/httpclient/traceparent/traceparent.go b/httpclient/traceparent/traceparent.go new file mode 100644 index 000000000..b7cfe61db --- /dev/null +++ b/httpclient/traceparent/traceparent.go @@ -0,0 +1,104 @@ +package traceparent + +import ( + "crypto/rand" 
+ "encoding/hex" + "fmt" + "net/http" + "strings" +) + +type traceparentFlag byte + +const ( + traceparentFlagSampled traceparentFlag = 1 << iota +) + +const TRACEPARENT_HEADER = "traceparent" + +type Traceparent struct { + version byte + traceId [16]byte + parentId [8]byte + flags traceparentFlag +} + +func (t *Traceparent) String() string { + b := strings.Builder{} + b.WriteString(hex.EncodeToString([]byte{t.version})) + b.WriteRune('-') + b.WriteString(hex.EncodeToString(t.traceId[:])) + b.WriteRune('-') + b.WriteString(hex.EncodeToString(t.parentId[:])) + b.WriteRune('-') + b.WriteString(hex.EncodeToString([]byte{byte(t.flags)})) + return b.String() +} + +func (t *Traceparent) Equals(other *Traceparent) bool { + return t.version == other.version && + t.flags == other.flags && + string(t.traceId[:]) == string(other.traceId[:]) && + string(t.parentId[:]) == string(other.parentId[:]) +} + +func NewTraceparent() *Traceparent { + traceId := [16]byte{} + parentId := [8]byte{} + rand.Read(traceId[:]) + rand.Read(parentId[:]) + return &Traceparent{ + version: 0, + traceId: traceId, + parentId: parentId, + flags: traceparentFlagSampled, + } +} + +func FromString(s string) (*Traceparent, error) { + parts := strings.Split(s, "-") + if len(parts) != 4 { + return nil, fmt.Errorf("invalid traceparent string: %s", s) + } + t := &Traceparent{ + traceId: [16]byte{}, + parentId: [8]byte{}, + } + version, err := hex.DecodeString(parts[0]) + if err != nil { + return nil, err + } + if len(version) != 1 { + return nil, fmt.Errorf("invalid version: %s, expected 1 byte", parts[0]) + } + t.version = version[0] + n, err := hex.Decode(t.traceId[:], []byte(parts[1])) + if err != nil { + return nil, err + } + if n != 16 { + return nil, fmt.Errorf("invalid traceId: %s, expected 16 bytes", parts[1]) + } + n, err = hex.Decode(t.parentId[:], []byte(parts[2])) + if err != nil { + return nil, err + } + if n != 8 { + return nil, fmt.Errorf("invalid parentId: %s, expected 8 bytes", parts[2]) + } + 
flags, err := hex.DecodeString(parts[3]) + if err != nil { + return nil, err + } + if len(flags) != 1 { + return nil, fmt.Errorf("invalid flags: %s, expected 1 byte", parts[3]) + } + t.flags = traceparentFlag(flags[0]) + return t, nil +} + +func AddTraceparent(r *http.Request) { + if r.Header.Get(TRACEPARENT_HEADER) == "" { + r.Header.Set(TRACEPARENT_HEADER, NewTraceparent().String()) + } +} diff --git a/httpclient/traceparent/traceparent_test.go b/httpclient/traceparent/traceparent_test.go new file mode 100644 index 000000000..9eb5cafa7 --- /dev/null +++ b/httpclient/traceparent/traceparent_test.go @@ -0,0 +1,53 @@ +package traceparent + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNew(t *testing.T) { + tp := NewTraceparent() + assert.Equal(t, byte(0), tp.version) + assert.Equal(t, byte(1), byte(tp.flags)) +} + +func TestEqual(t *testing.T) { + tp1 := NewTraceparent() + tp2 := &Traceparent{ + version: tp1.version, + traceId: tp1.traceId, + parentId: tp1.parentId, + flags: tp1.flags, + } + assert.True(t, tp1.Equals(tp2)) +} + +func TestTwoNewTraceparentsAreNotEqual(t *testing.T) { + tp1 := NewTraceparent() + tp2 := NewTraceparent() + assert.False(t, tp1.Equals(tp2)) +} + +var testTraceId = [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} +var testParentId = [8]byte{0, 1, 2, 3, 4, 5, 6, 7} + +func TestString(t *testing.T) { + tp := &Traceparent{ + version: 0, + traceId: testTraceId, + parentId: testParentId, + flags: 1, + } + res := tp.String() + assert.Equal(t, "00-000102030405060708090a0b0c0d0e0f-0001020304050607-01", res) +} + +func TestFromString(t *testing.T) { + tp, err := FromString("00-000102030405060708090a0b0c0d0e0f-0001020304050607-01") + assert.NoError(t, err) + assert.Equal(t, byte(0), tp.version) + assert.Equal(t, testTraceId, tp.traceId) + assert.Equal(t, testParentId, tp.parentId) + assert.Equal(t, byte(1), byte(tp.flags)) +}