From 25b773618241086223d971f019d1690f800e773b Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 7 May 2024 11:43:16 +0200 Subject: [PATCH 1/3] Add traceparent --- httpclient/api_client.go | 4 + httpclient/api_client_test.go | 21 +++++ httpclient/traceparent/traceparent.go | 104 +++++++++++++++++++++ httpclient/traceparent/traceparent_test.go | 53 +++++++++++ 4 files changed, 182 insertions(+) create mode 100644 httpclient/traceparent/traceparent.go create mode 100644 httpclient/traceparent/traceparent_test.go diff --git a/httpclient/api_client.go b/httpclient/api_client.go index 903c79696..a7fd861e0 100644 --- a/httpclient/api_client.go +++ b/httpclient/api_client.go @@ -13,6 +13,7 @@ import ( "time" "github.com/databricks/databricks-sdk-go/common" + "github.com/databricks/databricks-sdk-go/httpclient/traceparent" "github.com/databricks/databricks-sdk-go/logger" "github.com/databricks/databricks-sdk-go/logger/httplog" "github.com/databricks/databricks-sdk-go/retries" @@ -236,6 +237,9 @@ func (c *ApiClient) attempt( return c.failRequest(ctx, "failed during request visitor", err) } } + // Set traceparent for distributed tracing. + traceparent.AddTraceparent(request) + // request.Context() holds context potentially enhanced by visitors request.Header.Set("User-Agent", useragent.FromContext(request.Context())) if request.Header.Get("Content-Type") == "" && requestBody.ContentType != "" { diff --git a/httpclient/api_client_test.go b/httpclient/api_client_test.go index 429056b3a..af21ce90c 100644 --- a/httpclient/api_client_test.go +++ b/httpclient/api_client_test.go @@ -714,3 +714,24 @@ func TestDefaultAuthVisitor(t *testing.T) { err := c.Do(context.Background(), "GET", "/a/b", WithRequestData(map[string]any{}), authOption) require.NoError(t, err) } + +func TestTraceparentHeader(t *testing.T) { + seenTraceparents := []string{} + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + tp := r.Header.Get("Traceparent") + assert.NotEmpty(t, tp) + assert.NotContains(t, seenTraceparents, tp) + seenTraceparents = append(seenTraceparents, tp) + return &http.Response{ + StatusCode: 200, + Request: r, + }, nil + }), + }) + + for i := 0; i < 10; i++ { + err := c.Do(context.Background(), "GET", "/a/b") + assert.NoError(t, err) + } +} diff --git a/httpclient/traceparent/traceparent.go b/httpclient/traceparent/traceparent.go new file mode 100644 index 000000000..b7cfe61db --- /dev/null +++ b/httpclient/traceparent/traceparent.go @@ -0,0 +1,104 @@ +package traceparent + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "net/http" + "strings" +) + +type traceparentFlag byte + +const ( + traceparentFlagSampled traceparentFlag = 1 << iota +) + +const TRACEPARENT_HEADER = "traceparent" + +type Traceparent struct { + version byte + traceId [16]byte + parentId [8]byte + flags traceparentFlag +} + +func (t *Traceparent) String() string { + b := strings.Builder{} + b.WriteString(hex.EncodeToString([]byte{t.version})) + b.WriteRune('-') + b.WriteString(hex.EncodeToString(t.traceId[:])) + b.WriteRune('-') + b.WriteString(hex.EncodeToString(t.parentId[:])) + b.WriteRune('-') + b.WriteString(hex.EncodeToString([]byte{byte(t.flags)})) + return b.String() +} + +func (t *Traceparent) Equals(other *Traceparent) bool { + return t.version == other.version && + t.flags == other.flags && + string(t.traceId[:]) == string(other.traceId[:]) && + string(t.parentId[:]) == string(other.parentId[:]) +} + +func NewTraceparent() *Traceparent { + traceId := [16]byte{} + parentId := [8]byte{} + rand.Read(traceId[:]) + rand.Read(parentId[:]) + return &Traceparent{ + version: 0, + traceId: traceId, + parentId: parentId, + flags: traceparentFlagSampled, + } +} + +func FromString(s string) (*Traceparent, error) { + parts := strings.Split(s, "-") + if len(parts) != 4 { + return nil, fmt.Errorf("invalid traceparent string: %s", s) + } + t := &Traceparent{ + traceId: [16]byte{}, + parentId: [8]byte{}, + } + version, err := hex.DecodeString(parts[0]) + if err != nil { + return nil, err + } + if len(version) != 1 { + return nil, fmt.Errorf("invalid version: %s, expected 1 byte", parts[0]) + } + t.version = version[0] + n, err := hex.Decode(t.traceId[:], []byte(parts[1])) + if err != nil { + return nil, err + } + if n != 16 { + return nil, fmt.Errorf("invalid traceId: %s, expected 16 bytes", parts[1]) + } + n, err = hex.Decode(t.parentId[:], []byte(parts[2])) + if err != nil { + return nil, err + } + if n != 8 { + return nil, fmt.Errorf("invalid parentId: %s, expected 8 bytes", parts[2]) + } + flags, err := hex.DecodeString(parts[3]) + if err != nil { + return nil, err + } + if len(flags) != 1 { + return nil, fmt.Errorf("invalid flags: %s, expected 1 byte", parts[3]) + } + t.flags = traceparentFlag(flags[0]) + return t, nil +} + +func AddTraceparent(r *http.Request) { + if r.Header.Get(TRACEPARENT_HEADER) == "" { + r.Header.Set(TRACEPARENT_HEADER, NewTraceparent().String()) + } +} diff --git a/httpclient/traceparent/traceparent_test.go b/httpclient/traceparent/traceparent_test.go new file mode 100644 index 000000000..9eb5cafa7 --- /dev/null +++ b/httpclient/traceparent/traceparent_test.go @@ -0,0 +1,53 @@ +package traceparent + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNew(t *testing.T) { + tp := NewTraceparent() + assert.Equal(t, byte(0), tp.version) + assert.Equal(t, byte(1), byte(tp.flags)) +} + +func TestEqual(t *testing.T) { + tp1 := NewTraceparent() + tp2 := &Traceparent{ + version: tp1.version, + traceId: tp1.traceId, + parentId: tp1.parentId, + flags: tp1.flags, + } + assert.True(t, tp1.Equals(tp2)) +} + +func TestTwoNewTraceparentsAreNotEqual(t *testing.T) { + tp1 := NewTraceparent() + tp2 := NewTraceparent() + assert.False(t, tp1.Equals(tp2)) +} + +var testTraceId = [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} +var testParentId = [8]byte{0, 1, 2, 3, 4, 5, 6, 7} + +func TestString(t *testing.T) { + tp := &Traceparent{ + version: 0, + traceId: testTraceId, + parentId: testParentId, + flags: 1, + } + res := tp.String() + assert.Equal(t, "00-000102030405060708090a0b0c0d0e0f-0001020304050607-01", res) +} + +func TestFromString(t *testing.T) { + tp, err := FromString("00-000102030405060708090a0b0c0d0e0f-0001020304050607-01") + assert.NoError(t, err) + assert.Equal(t, byte(0), tp.version) + assert.Equal(t, testTraceId, tp.traceId) + assert.Equal(t, testParentId, tp.parentId) + assert.Equal(t, byte(1), byte(tp.flags)) +} From c9030682d555caf70983983e2185714c5348d101 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 7 May 2024 11:58:16 +0200 Subject: [PATCH 2/3] fix tests --- client/client_test.go | 2 ++ httpclient/fixtures/fixture.go | 16 ++++++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index 987d8ab9b..4a86130e6 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -435,6 +435,8 @@ func testNonJSONResponseIncludedInError(t *testing.T, statusCode int, status, er Token: "token", ConfigFile: "/dev/null", HTTPTransport: hc(func(r *http.Request) (*http.Response, error) { + // Clear traceparent header which is nondeterministic. + r.Header.Del("traceparent") return &http.Response{ Proto: "HTTP/2.0", Status: status, diff --git a/httpclient/fixtures/fixture.go b/httpclient/fixtures/fixture.go index 3d5fafe65..32b00fcbd 100644 --- a/httpclient/fixtures/fixture.go +++ b/httpclient/fixtures/fixture.go @@ -6,10 +6,13 @@ import ( "fmt" "io" "net/http" + "net/textproto" "net/url" "os" "reflect" "strings" + + "github.com/databricks/databricks-sdk-go/httpclient/traceparent" ) // HTTPFixture defines request structure for test @@ -38,12 +41,17 @@ func (f HTTPFixture) AssertHeaders(req *http.Request) error { return nil } actualHeaders := map[string]string{} + // remove user agent & traceparent from comparison, as it'll make fixtures too difficult + // to maintain in the long term + toSkip := map[string]struct{}{ + textproto.CanonicalMIMEHeaderKey("User-Agent"): {}, + textproto.CanonicalMIMEHeaderKey(traceparent.TRACEPARENT_HEADER): {}, + } for k := range req.Header { - actualHeaders[k] = req.Header.Get(k) + if _, skip := toSkip[k]; !skip { + actualHeaders[k] = req.Header.Get(k) + } } - // remove user agent from comparison, as it'll make fixtures too difficult - // to maintain in the long term - delete(actualHeaders, "User-Agent") if !reflect.DeepEqual(f.ExpectedHeaders, actualHeaders) { expectedJSON, _ := json.MarshalIndent(f.ExpectedHeaders, "", " ") actualJSON, _ := json.MarshalIndent(actualHeaders, "", " ") From dd502d7ce5c47e1a9ff95469f34124e1ce5f5037 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Tue, 7 May 2024 17:23:35 +0200 Subject: [PATCH 3/3] test & comments --- config/config.go | 2 +- httpclient/api_client.go | 1 + httpclient/api_client_test.go | 20 ++++++++++++++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/config/config.go b/config/config.go index d5bfaeb9b..5da39cbbe 100644 --- a/config/config.go +++ b/config/config.go @@ -99,7 +99,7 @@ type Config struct { // Use at your own risk or for unit testing purposes. InsecureSkipVerify bool `name:"skip_verify" auth:"-"` - // Number of seconds for HTTP timeout. Default is 300 (5 minutes). + // Number of seconds for HTTP timeout. Default is 60 (1 minute). HTTPTimeoutSeconds int `name:"http_timeout_seconds" auth:"-"` // Truncate JSON fields in JSON above this limit. Default is 96. diff --git a/httpclient/api_client.go b/httpclient/api_client.go index a7fd861e0..ad4c4c63b 100644 --- a/httpclient/api_client.go +++ b/httpclient/api_client.go @@ -238,6 +238,7 @@ func (c *ApiClient) attempt( } } // Set traceparent for distributed tracing. + // This must be done after all visitors have run, as they may modify the request. traceparent.AddTraceparent(request) // request.Context() holds context potentially enhanced by visitors diff --git a/httpclient/api_client_test.go b/httpclient/api_client_test.go index af21ce90c..ba0ab9d8e 100644 --- a/httpclient/api_client_test.go +++ b/httpclient/api_client_test.go @@ -735,3 +735,23 @@ func TestTraceparentHeader(t *testing.T) { assert.NoError(t, err) } } + +func TestTraceparentHeaderDoesNotOverrideUserHeader(t *testing.T) { + userTraceparent := "00-thetraceid-theparentid-00" + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + tp := r.Header.Get("Traceparent") + assert.NotEmpty(t, tp) + assert.Equal(t, userTraceparent, tp) + return &http.Response{ + StatusCode: 200, + Request: r, + }, nil + }), + }) + + err := c.Do(context.Background(), "GET", "/a/b", WithRequestHeaders(map[string]string{ + "Traceparent": userTraceparent, + })) + assert.NoError(t, err) +}