From 0b52c35f01861da486f175b84c17ea3bfa4e1ebe Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Fri, 24 Nov 2023 12:25:37 +0100 Subject: [PATCH 1/3] Refactor HTTP client --- apierr/errors.go | 2 +- client/client.go | 537 +++------------- client/client_test.go | 691 +++------------------ config/auth_azure_cli.go | 2 + config/auth_azure_client_secret.go | 2 + config/auth_azure_msi.go | 2 + config/azure.go | 20 +- httpclient/api_client.go | 386 ++++++++++++ httpclient/api_client_test.go | 589 ++++++++++++++++++ {client => httpclient}/body_logger.go | 12 +- {client => httpclient}/body_logger_test.go | 2 +- httpclient/errors.go | 53 ++ httpclient/request.go | 182 ++++++ httpclient/response.go | 83 +++ retries/retries.go | 16 +- 15 files changed, 1482 insertions(+), 1097 deletions(-) create mode 100644 httpclient/api_client.go create mode 100644 httpclient/api_client_test.go rename {client => httpclient}/body_logger.go (94%) rename {client => httpclient}/body_logger_test.go (98%) create mode 100644 httpclient/errors.go create mode 100644 httpclient/request.go create mode 100644 httpclient/response.go diff --git a/apierr/errors.go b/apierr/errors.go index e520c2118..c21526363 100644 --- a/apierr/errors.go +++ b/apierr/errors.go @@ -154,7 +154,7 @@ func GenericIOError(ue *url.Error) *APIError { } // GetAPIError inspects HTTP errors from the Databricks API for known transient errors. -func GetAPIError(ctx context.Context, resp *http.Response, body io.ReadCloser) *APIError { +func GetAPIError(ctx context.Context, resp *http.Response, body io.ReadCloser) error { if resp.StatusCode == 429 { return TooManyRequests() } diff --git a/client/client.go b/client/client.go index 10417b6de..55cb53bc3 100644 --- a/client/client.go +++ b/client/client.go @@ -1,29 +1,17 @@ package client import ( - "bytes" "context" - "crypto/tls" - "encoding/json" "errors" "fmt" - "io" - "net" "net/http" "net/url" - "reflect" - "runtime" - "sort" - "strings" "time" "github.com/databricks/databricks-sdk-go/apierr" "github.com/databricks/databricks-sdk-go/config" - "github.com/databricks/databricks-sdk-go/logger" - "github.com/databricks/databricks-sdk-go/retries" + "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/useragent" - "github.com/google/go-querystring/query" - "golang.org/x/time/rate" ) func New(cfg *config.Config) (*DatabricksClient, error) { @@ -31,142 +19,80 @@ func New(cfg *config.Config) (*DatabricksClient, error) { if err != nil { return nil, err } + return newWithTransport(cfg, cfg.HTTPTransport), nil +} + +func newWithTransport(cfg *config.Config, transport http.RoundTripper) *DatabricksClient { retryTimeout := time.Duration(orDefault(cfg.RetryTimeoutSeconds, 300)) * time.Second httpTimeout := time.Duration(orDefault(cfg.HTTPTimeoutSeconds, 60)) * time.Second - rateLimiter := rate.NewLimiter(rate.Limit(orDefault(cfg.RateLimitPerSecond, 15)), 1) - debugTruncateBytes := orDefault(cfg.DebugTruncateBytes, 96) - httpTransport := cfg.HTTPTransport - if httpTransport == nil { - httpTransport = &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - MaxIdleConnsPerHost: runtime.GOMAXPROCS(0) + 1, - IdleConnTimeout: 180 * time.Second, - TLSHandshakeTimeout: 30 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: cfg.InsecureSkipVerify, - }, - } - } return &DatabricksClient{ - Config: cfg, - debugHeaders: cfg.DebugHeaders, - debugTruncateBytes: debugTruncateBytes, - retryTimeout: retryTimeout, - rateLimiter: rateLimiter, - httpClient: &http.Client{ - Timeout: httpTimeout, - Transport: httpTransport, - }, - }, nil -} - -type httpClient interface { - Do(req *http.Request) (*http.Response, error) - CloseIdleConnections() -} - -// Represents a request body. -// -// If the provided request data is an io.Reader, DebugBytes is set to -// "". Otherwise, DebugBytes is set to the marshaled JSON -// representation of the request data, and ReadCloser is set to a new -// io.ReadCloser that reads from DebugBytes. -// -// Request bodies are never closed by the client, hence only accepting -// io.Reader. -type requestBody struct { - Reader io.Reader - DebugBytes []byte -} - -func newRequestBody(data any) (requestBody, error) { - switch v := data.(type) { - case io.Reader: - return requestBody{ - Reader: v, - DebugBytes: []byte(""), - }, nil - case string: - return requestBody{ - Reader: strings.NewReader(v), - DebugBytes: []byte(v), - }, nil - case []byte: - return requestBody{ - Reader: bytes.NewReader(v), - DebugBytes: v, - }, nil - default: - bs, err := json.Marshal(data) - if err != nil { - return requestBody{}, fmt.Errorf("request marshal failure: %w", err) - } - return requestBody{ - Reader: bytes.NewReader(bs), - DebugBytes: bs, - }, nil - } -} - -// Reset a request body to its initial state. -// -// This is used to retry requests with a body that has already been read. -// If the request body is not resettable (i.e. not nil and of type other than -// strings.Reader or bytes.Reader), this will return an error. -func (r requestBody) reset() error { - if r.Reader == nil { - return nil - } - if v, ok := r.Reader.(io.Seeker); ok { - _, err := v.Seek(0, io.SeekStart) - return err - } else { - return fmt.Errorf("cannot reset reader of type %T", r.Reader) - } -} - -// Represents a response body. -// -// Responses must always be closed. For non-streaming responses, they are closed -// during deserialization in the client (see unmarshall()). For streaming -// responses, they are returned to the caller, who is responsible for closing -// them. -type responseBody struct { - ReadCloser io.ReadCloser - DebugBytes []byte -} - -func newResponseBody(data any) (responseBody, error) { - switch v := data.(type) { - case io.ReadCloser: - return responseBody{ - ReadCloser: v, - DebugBytes: []byte(""), - }, nil - case []byte: - return responseBody{ - ReadCloser: io.NopCloser(bytes.NewReader(v)), - DebugBytes: v, - }, nil - default: - return responseBody{}, errors.New("newResponseBody can only be called with io.ReadCloser or []byte") + Config: cfg, + client: httpclient.NewApiClient(httpclient.ClientConfig{ + RetryTimeout: retryTimeout, + HTTPTimeout: httpTimeout, + RateLimitPerSecond: orDefault(cfg.RateLimitPerSecond, 15), + DebugHeaders: cfg.DebugHeaders, + DebugTruncateBytes: cfg.DebugTruncateBytes, + InsecureSkipVerify: cfg.InsecureSkipVerify, + Transport: transport, + Visitors: []httpclient.RequestVisitor{ + cfg.Authenticate, + func(r *http.Request) error { + if r.URL == nil { + return fmt.Errorf("no URL found in request") + } + url, err := url.Parse(cfg.Host) + if err != nil { + return err + } + r.URL.Host = url.Host + r.URL.Scheme = url.Scheme + return nil + }, + func(r *http.Request) error { + ctx := useragent.InContext(r.Context(), "auth", cfg.AuthType) + *r = *r.WithContext(ctx) // replace request + return nil + }, + func(r *http.Request) error { + // Detect if we are running in a CI/CD environment + provider := useragent.CiCdProvider() + if provider == "" { + return nil + } + // Add the detected CI/CD provider to the user agent + ctx := useragent.InContext(r.Context(), "cicd", provider) + *r = *r.WithContext(ctx) // replace request + return nil + }, + }, + TransientErrors: []string{ + "com.databricks.backend.manager.util.UnknownWorkerEnvironmentException", + "does not have any associated worker environments", + "There is no worker environment with id", + "Unknown worker environment", + "ClusterNotReadyException", + "connection reset by peer", + "TLS handshake timeout", + "connection refused", + "Unexpected error", + "i/o timeout", + }, + ErrorMapper: apierr.GetAPIError, + ErrorRetriable: func(ctx context.Context, err error) bool { + var apiErr *apierr.APIError + if errors.As(err, &apiErr) { + return apiErr.IsRetriable(ctx) + } + return false + }, + }), } } type DatabricksClient struct { - Config *config.Config - rateLimiter *rate.Limiter - retryTimeout time.Duration - httpClient httpClient - debugHeaders bool - debugTruncateBytes int + Config *config.Config + client *httpclient.ApiClient } // ConfiguredAccountID returns Databricks Account ID if it's provided in config, @@ -179,320 +105,14 @@ func (c *DatabricksClient) ConfiguredAccountID() string { func (c *DatabricksClient) Do(ctx context.Context, method, path string, headers map[string]string, request, response any, visitors ...func(*http.Request) error) error { - body, err := c.perform(ctx, method, path, headers, request, visitors...) - if err != nil { - return err - } - return c.unmarshal(body, response) -} - -func (c *DatabricksClient) unmarshal(body *responseBody, response any) error { - if response == nil { - return nil - } - // If the destination is bytes.Buffer, write the body over there - if raw, ok := response.(*io.ReadCloser); ok { - *raw = body.ReadCloser - return nil - } - defer body.ReadCloser.Close() - bs, err := io.ReadAll(body.ReadCloser) - if err != nil { - return fmt.Errorf("failed to read response body: %w", err) - } - if len(bs) == 0 { - return nil - } - // If the destination is a byte slice or buffer, pass the body verbatim. - if raw, ok := response.(*[]byte); ok { - *raw = bs - return nil - } - if raw, ok := response.(*bytes.Buffer); ok { - _, err := raw.Write(bs) - return err - } - return json.Unmarshal(bs, &response) -} - -func (c *DatabricksClient) addHostToRequestUrl(r *http.Request) error { - if r.URL == nil { - return fmt.Errorf("no URL found in request") - } - url, err := url.Parse(c.Config.Host) - if err != nil { - return err - } - r.URL.Host = url.Host - r.URL.Scheme = url.Scheme - return nil -} - -func (c *DatabricksClient) fromResponse(r *http.Response) (responseBody, error) { - if r == nil { - return responseBody{}, fmt.Errorf("nil response") - } - if r.Request == nil { - return responseBody{}, fmt.Errorf("nil request") - } - streamResponse := r.Request.Header.Get("Accept") != "application/json" && r.Header.Get("Content-Type") != "application/json" - if streamResponse { - return newResponseBody(r.Body) - } - defer r.Body.Close() - bs, err := io.ReadAll(r.Body) - if err != nil { - return responseBody{}, fmt.Errorf("response body: %w", err) - } - return newResponseBody(bs) -} - -func (c *DatabricksClient) redactedDump(prefix string, body []byte) (res string) { - return bodyLogger{ - debugTruncateBytes: c.debugTruncateBytes, - }.redactedDump(prefix, body) -} - -// Common error-handling logic for all responses that may need to be retried. -// -// If the error is retriable, return a retries.Err to retry the request. However, as the request body will have been consumed -// by the first attempt, the body must be reset before retrying. If the body cannot be reset, return a retries.Err to halt. -// -// Always returns nil for the first parameter as there is no meaningful response body to return in the error case. -// -// If it is certain that an error should not be retried, use failRequest() instead. -func (c *DatabricksClient) handleError(ctx context.Context, err *apierr.APIError, body requestBody) (*responseBody, *retries.Err) { - if !err.IsRetriable(ctx) { - return c.failRequest(ctx, "non-retriable error", err) - } - if resetErr := body.reset(); resetErr != nil { - return nil, retries.Halt(resetErr) - } - return nil, retries.Continue(err) -} - -// Fails the request with a retries.Err to halt future retries. -func (c *DatabricksClient) failRequest(ctx context.Context, msg string, err error) (*responseBody, *retries.Err) { - logger.Debugf(ctx, "%s: %s", msg, err) - return nil, retries.Halt(err) -} - -func (c *DatabricksClient) attempt( - ctx context.Context, - method string, - requestURL string, - headers map[string]string, - requestBody requestBody, - visitors ...func(*http.Request) error, -) func() (*responseBody, *retries.Err) { - return func() (*responseBody, *retries.Err) { - err := c.rateLimiter.Wait(ctx) - if err != nil { - return c.failRequest(ctx, "failed in rate limiter", err) - } - request, err := http.NewRequestWithContext(ctx, method, requestURL, requestBody.Reader) - if err != nil { - return c.failRequest(ctx, "failed creating new request", err) - } - for k, v := range headers { - request.Header.Set(k, v) - } - for _, requestVisitor := range visitors { - err = requestVisitor(request) - if err != nil { - return c.failRequest(ctx, "failed during request visitor", err) - } - } - // request.Context() holds context potentially enhanced by visitors - request.Header.Set("User-Agent", useragent.FromContext(request.Context())) - - // attempt the actual request - response, err := c.httpClient.Do(request) - - // After this point, the request body has (probably) been consumed. handleError() must be called to reset it if - // possible. - if ue, ok := err.(*url.Error); ok { - return c.handleError(ctx, apierr.GenericIOError(ue), requestBody) - } - - // By this point, the request body has certainly been consumed. - responseBody, responseBodyErr := c.fromResponse(response) - if responseBodyErr != nil { - return c.failRequest(ctx, "failed while reading response", apierr.ReadError(response.StatusCode, responseBodyErr)) - } - - apiErr := apierr.GetAPIError(ctx, response, responseBody.ReadCloser) - defer c.recordRequestLog(ctx, request, response, apiErr, requestBody.DebugBytes, responseBody.DebugBytes) - - if apiErr == nil { - return &responseBody, nil - } - - // proactively release the connections in HTTP connection pool - c.httpClient.CloseIdleConnections() - return c.handleError(ctx, apiErr, requestBody) - } -} - -func (c *DatabricksClient) recordRequestLog( - ctx context.Context, - request *http.Request, - response *http.Response, - err error, - requestBody, responseBody []byte, -) { - // Don't compute expensive debug message if debug logging is not enabled. - if !logger.Get(ctx).Enabled(ctx, logger.LevelDebug) { - return - } - sb := strings.Builder{} - sb.WriteString(fmt.Sprintf("%s %s", request.Method, - escapeNewLines(request.URL.Path))) - if request.URL.RawQuery != "" { - sb.WriteString("?") - q, _ := url.QueryUnescape(request.URL.RawQuery) - sb.WriteString(q) - } - sb.WriteString("\n") - if c.debugHeaders { - if c.Config.Host != "" { - sb.WriteString("> * Host: ") - sb.WriteString(escapeNewLines(c.Config.Host)) - sb.WriteString("\n") - } - for k, v := range request.Header { - trunc := onlyNBytes(strings.Join(v, ""), c.debugTruncateBytes) - sb.WriteString(fmt.Sprintf("> * %s: %s\n", k, escapeNewLines(trunc))) - } - } - if len(requestBody) > 0 { - sb.WriteString(c.redactedDump("> ", requestBody)) - sb.WriteString("\n") - } - sb.WriteString("< ") - if response != nil { - sb.WriteString(fmt.Sprintf("%s %s", response.Proto, response.Status)) - // Only display error on this line if the response body is empty. - // Otherwise the response body will include details about the error. - if len(responseBody) == 0 && err != nil { - sb.WriteString(fmt.Sprintf(" (Error: %s)", err)) - } - } else { - sb.WriteString(fmt.Sprintf("Error: %s", err)) - } - sb.WriteString("\n") - if len(responseBody) > 0 { - sb.WriteString(c.redactedDump("< ", responseBody)) + opts := []httpclient.DoOption{} + for _, v := range visitors { + opts = append(opts, httpclient.WithVisitor(v)) } - logger.Debugf(ctx, sb.String()) // lgtm [go/log-injection] lgtm [go/clear-text-logging] -} - -func (c *DatabricksClient) addAuthHeaderToUserAgent(r *http.Request) error { - ctx := useragent.InContext(r.Context(), "auth", c.Config.AuthType) - *r = *r.WithContext(ctx) // replace request - return nil -} - -func (c *DatabricksClient) addCiCdProviderToUserAgent(r *http.Request) error { - // Detect if we are running in a CI/CD environment - provider := useragent.CiCdProvider() - if provider == "" { - return nil - } - - // Add the detected CI/CD provider to the user agent - ctx := useragent.InContext(r.Context(), "cicd", provider) - *r = *r.WithContext(ctx) // replace request - return nil -} - -func (c *DatabricksClient) perform( - ctx context.Context, - method, - requestURL string, - headers map[string]string, - data interface{}, - visitors ...func(*http.Request) error, -) (*responseBody, error) { - // replace double slash in the request URL with a single slash - requestURL = strings.Replace(requestURL, "//", "/", -1) - requestBody, err := makeRequestBody(method, &requestURL, data) - if err != nil { - return nil, fmt.Errorf("request marshal: %w", err) - } - visitors = append([]func(*http.Request) error{ - c.Config.Authenticate, - c.addHostToRequestUrl, - c.addAuthHeaderToUserAgent, - c.addCiCdProviderToUserAgent, - }, visitors...) - resp, err := retries.Poll(ctx, c.retryTimeout, - c.attempt(ctx, method, requestURL, headers, requestBody, visitors...)) - if err != nil { - // Don't re-wrap, as upper layers may depend on handling apierr.APIError. - return nil, err - } - return resp, nil -} - -func makeQueryString(data interface{}) (string, error) { - inputVal := reflect.ValueOf(data) - inputType := reflect.TypeOf(data) - if inputType.Kind() == reflect.Map { - s := []string{} - keys := inputVal.MapKeys() - // sort map keys by their string repr, so that tests can be deterministic - sort.Slice(keys, func(i, j int) bool { - return keys[i].String() < keys[j].String() - }) - for _, k := range keys { - v := inputVal.MapIndex(k) - if v.IsZero() { - continue - } - s = append(s, fmt.Sprintf("%s=%s", - strings.Replace(url.QueryEscape(fmt.Sprintf("%v", k.Interface())), "+", "%20", -1), - strings.Replace(url.QueryEscape(fmt.Sprintf("%v", v.Interface())), "+", "%20", -1))) - } - return "?" + strings.Join(s, "&"), nil - } - if inputType.Kind() == reflect.Struct { - params, err := query.Values(data) - if err != nil { - return "", fmt.Errorf("cannot create query string: %w", err) - } - // Query parameters may be nested, but the keys generated by - // query.Values use the "[" and "]" characters to represent nesting. - // Replace all instances of "[" with "." and "]" with empty string - // to make the query string compatible with the proto API. - // See the following for more information: - // https://cloud.google.com/endpoints/docs/grpc-service-config/reference/rpc/google.api#google.api.HttpRule - protoCompatibleParams := make(url.Values) - for k, vs := range params { - newK := strings.Replace(k, "[", ".", -1) - newK = strings.Replace(newK, "]", "", -1) - for _, v := range vs { - protoCompatibleParams.Add(newK, v) - } - } - return "?" + protoCompatibleParams.Encode(), nil - } - return "", fmt.Errorf("unsupported query string data: %#v", data) -} - -func makeRequestBody(method string, requestURL *string, data interface{}) (requestBody, error) { - if data == nil { - return requestBody{}, nil - } - if method == "GET" || method == "DELETE" { - qs, err := makeQueryString(data) - if err != nil { - return requestBody{}, err - } - *requestURL += qs - return newRequestBody([]byte{}) - } - return newRequestBody(data) + opts = append(opts, httpclient.WithHeaders(headers)) + opts = append(opts, httpclient.WithData(request)) + opts = append(opts, httpclient.WithUnmarshal(response)) + return c.client.Do(ctx, method, path, opts...) } func orDefault(configured, _default int) int { @@ -501,10 +121,3 @@ func orDefault(configured, _default int) int { } return configured } - -// CWE-117 prevention -func escapeNewLines(in string) string { - in = strings.Replace(in, "\n", "", -1) - in = strings.Replace(in, "\r", "", -1) - return in -} diff --git a/client/client_test.go b/client/client_test.go index 348e32317..843f01163 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -7,103 +7,68 @@ import ( "io" "net/http" "net/url" - "os" "strings" "testing" - "time" "github.com/databricks/databricks-sdk-go/apierr" "github.com/databricks/databricks-sdk-go/config" - "github.com/databricks/databricks-sdk-go/logger" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/time/rate" ) -// errReader(true) will also fail on Close -type errReader bool - -func (errReader) Read(p []byte) (n int, err error) { - return 0, fmt.Errorf("test error") -} - -func (i errReader) Close() error { - if i { - return fmt.Errorf("test error") - } - return nil -} - type hc func(r *http.Request) (*http.Response, error) -func (cb hc) Do(r *http.Request) (*http.Response, error) { +func (cb hc) RoundTrip(r *http.Request) (*http.Response, error) { return cb(r) } -func (cb hc) CloseIdleConnections() {} - func TestNew(t *testing.T) { - c, err := New(&config.Config{ + _, err := New(&config.Config{ ConfigFile: "/dev/null", }) assert.NoError(t, err) - - assert.Equal(t, 96, c.debugTruncateBytes) - assert.Equal(t, 5*time.Minute, c.retryTimeout) } func TestSimpleRequestFailsURLError(t *testing.T) { - c := &DatabricksClient{ - Config: config.NewMockConfig(func(r *http.Request) error { - r.Header.Add("Authenticated", "yes") - return nil - }), - httpClient: hc(func(r *http.Request) (*http.Response, error) { - assert.Equal(t, "GET", r.Method) - assert.Equal(t, "/a/b", r.URL.Path) - assert.Equal(t, "c=d", r.URL.RawQuery) - assert.Equal(t, "f", r.Header.Get("e")) - auth := r.Header.Get("Authenticated") - assert.Equal(t, "yes", auth) - return &http.Response{ - Request: r, - }, &url.Error{ - Op: "GET", - URL: "/a/b", - Err: fmt.Errorf("nope"), - } - }), - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } + cfg := config.NewMockConfig(func(r *http.Request) error { + r.Header.Add("Authenticated", "yes") + return nil + }) + cfg.RetryTimeoutSeconds = 1 + c := newWithTransport(cfg, hc(func(r *http.Request) (*http.Response, error) { + assert.Equal(t, "GET", r.Method) + assert.Equal(t, "/a/b", r.URL.Path) + assert.Equal(t, "c=d", r.URL.RawQuery) + assert.Equal(t, "f", r.Header.Get("e")) + auth := r.Header.Get("Authenticated") + assert.Equal(t, "yes", auth) + return nil, fmt.Errorf("nope") + })) err := c.Do(context.Background(), "GET", "/a/b", map[string]string{ "e": "f", }, map[string]string{ "c": "d", }, nil) - assert.EqualError(t, err, "GET \"/a/b\": nope") + assert.EqualError(t, err, `Get "/a/b?c=d": nope`) } func TestSimpleRequestFailsAPIError(t *testing.T) { - c := &DatabricksClient{ - Config: config.NewMockConfig(func(r *http.Request) error { - r.Header.Add("Authenticated", "yes") - return nil - }), - httpClient: hc(func(r *http.Request) (*http.Response, error) { - assert.Equal(t, "GET", r.Method) - assert.Equal(t, "/a/b", r.URL.Path) - assert.Equal(t, "c=d", r.URL.RawQuery) - assert.Equal(t, "f", r.Header.Get("e")) - auth := r.Header.Get("Authenticated") - assert.Equal(t, "yes", auth) - return &http.Response{ - StatusCode: 400, - Request: r, - Body: io.NopCloser(strings.NewReader(`{"error_code": "INVALID_PARAMETER_VALUE", "message": "nope"}`)), - }, nil - }), - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } + c := *newWithTransport(config.NewMockConfig(func(r *http.Request) error { + r.Header.Add("Authenticated", "yes") + return nil + }), hc(func(r *http.Request) (*http.Response, error) { + assert.Equal(t, "GET", r.Method) + assert.Equal(t, "/a/b", r.URL.Path) + assert.Equal(t, "c=d", r.URL.RawQuery) + assert.Equal(t, "f", r.Header.Get("e")) + auth := r.Header.Get("Authenticated") + assert.Equal(t, "yes", auth) + return &http.Response{ + StatusCode: 400, + Request: r, + Body: io.NopCloser(strings.NewReader(`{"error_code": "INVALID_PARAMETER_VALUE", "message": "nope"}`)), + }, nil + })) err := c.Do(context.Background(), "GET", "/a/b", map[string]string{ "e": "f", }, map[string]string{ @@ -116,15 +81,13 @@ func TestETag(t *testing.T) { reason := "some_reason" domain := "a_domain" eTag := "sample_etag" - c := &DatabricksClient{ - Config: config.NewMockConfig(func(r *http.Request) error { - return nil - }), - httpClient: hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 400, - Request: r, - Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{ + c := newWithTransport(config.NewMockConfig(func(r *http.Request) error { + return nil + }), hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 400, + Request: r, + Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{ "error_code": "RESOURCE_CONFLICT", "message": "test_public_workspace_setting", "stack_trace": "java.io.PrintWriter@329e4ed3", @@ -147,10 +110,8 @@ func TestETag(t *testing.T) { } ] }`, "type.googleapis.com/google.rpc.ErrorInfo", reason, domain, eTag))), - }, nil - }), - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } + }, nil + })) err := c.Do(context.Background(), "GET", "/a/b", map[string]string{ "e": "f", }, map[string]string{ @@ -170,19 +131,15 @@ func TestSimpleRequestSucceeds(t *testing.T) { type Dummy struct { Foo int `json:"foo"` } - c := &DatabricksClient{ - Config: config.NewMockConfig(func(r *http.Request) error { - return nil - }), - httpClient: hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(`{"foo": 2}`)), - Request: r, - }, nil - }), - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } + c := newWithTransport(config.NewMockConfig(func(r *http.Request) error { + return nil + }), hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"foo": 2}`)), + Request: r, + }, nil + })) var resp Dummy err := c.Do(context.Background(), "POST", "/c", nil, Dummy{1}, &resp) assert.NoError(t, err) @@ -194,28 +151,23 @@ func TestSimpleRequestRetried(t *testing.T) { Foo int `json:"foo"` } var retried [1]bool - c := &DatabricksClient{ - Config: config.NewMockConfig(func(r *http.Request) error { - return nil - }), - httpClient: hc(func(r *http.Request) (*http.Response, error) { - if !retried[0] { - retried[0] = true - return nil, &url.Error{ - Op: "open", - URL: "/a/b", - Err: fmt.Errorf("connection refused"), - } + c := newWithTransport(config.NewMockConfig(func(r *http.Request) error { + return nil + }), hc(func(r *http.Request) (*http.Response, error) { + if !retried[0] { + retried[0] = true + return nil, &url.Error{ + Op: "open", + URL: "/a/b", + Err: fmt.Errorf("connection refused"), } - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(`{"foo": 2}`)), - Request: r, - }, nil - }), - retryTimeout: 1 * time.Minute, - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } + } + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"foo": 2}`)), + Request: r, + }, nil + })) var resp Dummy err := c.Do(context.Background(), "PATCH", "/a", nil, Dummy{1}, &resp) assert.NoError(t, err) @@ -223,514 +175,25 @@ func TestSimpleRequestRetried(t *testing.T) { assert.True(t, retried[0], "request was not retried") } -func TestHaltAttemptForLimit(t *testing.T) { - ctx := context.Background() - c := &DatabricksClient{ - rateLimiter: &rate.Limiter{}, - } - req, err := newRequestBody([]byte{}) - assert.NoError(t, err) - _, rerr := c.attempt(ctx, "GET", "foo", nil, req)() - assert.NotNil(t, rerr) - assert.Equal(t, true, rerr.Halt) - assert.EqualError(t, rerr.Err, "rate: Wait(n=1) exceeds limiter's burst 0") -} - -func TestHaltAttemptForNewRequest(t *testing.T) { - ctx := context.Background() - c := &DatabricksClient{ - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } - req, err := newRequestBody([]byte{}) - assert.NoError(t, err) - _, rerr := c.attempt(ctx, "🥱", "/", nil, req)() - assert.NotNil(t, rerr) - assert.Equal(t, true, rerr.Halt) - assert.EqualError(t, rerr.Err, `net/http: invalid method "🥱"`) -} - -func TestHaltAttemptForVisitor(t *testing.T) { - ctx := context.Background() - c := &DatabricksClient{ - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } - req, err := newRequestBody([]byte{}) - assert.NoError(t, err) - _, rerr := c.attempt(ctx, "GET", "/", nil, req, - func(r *http.Request) error { - return fmt.Errorf("🥱") - })() - assert.NotNil(t, rerr) - assert.Equal(t, true, rerr.Halt) - assert.EqualError(t, rerr.Err, "🥱") -} - -func TestMakeRequestBody(t *testing.T) { - type x struct { - Scope string `json:"scope" url:"scope"` - } - requestURL := "/a/b/c" - body, err := makeRequestBody("GET", &requestURL, x{"test"}) - assert.NoError(t, err) - bodyBytes, err := io.ReadAll(body.Reader) - assert.NoError(t, err) - assert.Equal(t, "/a/b/c?scope=test", requestURL) - assert.Equal(t, 0, len(bodyBytes)) - - requestURL = "/a/b/c" - body, err = makeRequestBody("POST", &requestURL, x{"test"}) - assert.NoError(t, err) - bodyBytes, err = io.ReadAll(body.Reader) - assert.NoError(t, err) - assert.Equal(t, "/a/b/c", requestURL) - x1 := `{"scope":"test"}` - assert.Equal(t, []byte(x1), bodyBytes) -} - -func TestMakeRequestBodyFromReader(t *testing.T) { - requestURL := "/a/b/c" - body, err := makeRequestBody("PUT", &requestURL, strings.NewReader("abc")) - assert.NoError(t, err) - bodyBytes, err := io.ReadAll(body.Reader) - assert.NoError(t, err) - assert.Equal(t, []byte("abc"), bodyBytes) -} - -func TestMakeRequestBodyReaderError(t *testing.T) { - requestURL := "/a/b/c" - _, err := makeRequestBody("POST", &requestURL, errReader(false)) - // The request body is only read once the request is sent, so no error - // should be returned until then. - assert.NoError(t, err, "request body reader error should be ignored") -} - -func TestMakeRequestBodyJsonError(t *testing.T) { - requestURL := "/a/b/c" - type x struct { - Foo chan string `json:"foo"` - } - _, err := makeRequestBody("POST", &requestURL, x{make(chan string)}) - assert.EqualError(t, err, "request marshal failure: json: unsupported type: chan string") -} - -type failingUrlEncode string - -func (fue failingUrlEncode) EncodeValues(key string, v *url.Values) error { - return fmt.Errorf(string(fue)) -} - -func TestMakeRequestBodyQueryFailingEncode(t *testing.T) { - requestURL := "/a/b/c" - type x struct { - Foo failingUrlEncode `url:"foo"` - } - _, err := makeRequestBody("GET", &requestURL, x{failingUrlEncode("always failing")}) - assert.EqualError(t, err, "cannot create query string: always failing") -} - -func TestMakeRequestBodyQueryUnsupported(t *testing.T) { - requestURL := "/a/b/c" - _, err := makeRequestBody("GET", &requestURL, true) - assert.EqualError(t, err, "unsupported query string data: true") -} - -func TestFailPerformChannel(t *testing.T) { - ctx := context.Background() - c := &DatabricksClient{ - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } - _, err := c.perform(ctx, "GET", "/", nil, true) - assert.EqualError(t, err, "request marshal: unsupported query string data: true") -} - func TestSimpleRequestAPIError(t *testing.T) { - c := &DatabricksClient{ - Config: config.NewMockConfig(func(r *http.Request) error { - return nil - }), - httpClient: hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 400, - Body: io.NopCloser(strings.NewReader(`{ + c := newWithTransport(config.NewMockConfig(func(r *http.Request) error { + return nil + }), hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 404, + Body: io.NopCloser(strings.NewReader(`{ "error_code": "NOT_FOUND", "message": "Something was not found" }`)), - Request: r, - }, nil - }), - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } + Request: r, + }, nil + })) err := c.Do(context.Background(), "PATCH", "/a", nil, map[string]any{}, nil) var aerr *apierr.APIError if assert.ErrorAs(t, err, &aerr) { assert.Equal(t, "NOT_FOUND", aerr.ErrorCode) } -} - -func TestSimpleRequestErrReaderBody(t *testing.T) { - c := &DatabricksClient{ - Config: config.NewMockConfig(func(r *http.Request) error { - return nil - }), - httpClient: hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: errReader(false), - Request: r, - }, nil - }), - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } - headers := map[string]string{"Accept": "application/json"} - err := c.Do(context.Background(), "PATCH", "/a", headers, map[string]any{}, nil) - assert.EqualError(t, err, "response body: test error") -} - -func TestSimpleRequestErrReaderBodyStreamResponse(t *testing.T) { - c := &DatabricksClient{ - Config: config.NewMockConfig(func(r *http.Request) error { - return nil - }), - httpClient: hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: errReader(false), - Request: r, - }, nil - }), - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } - headers := map[string]string{"Accept": "application/octet-stream"} - err := c.Do(context.Background(), "PATCH", "/a", headers, map[string]any{}, nil) - assert.NoError(t, err, "streaming response bodies are not read") -} - -func TestSimpleRequestErrReaderCloseBody(t *testing.T) { - c := &DatabricksClient{ - Config: config.NewMockConfig(func(r *http.Request) error { - return nil - }), - httpClient: hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: errReader(true), - Request: r, - }, nil - }), - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } - headers := map[string]string{"Accept": "application/json"} - err := c.Do(context.Background(), "PATCH", "/a", headers, map[string]any{}, nil) - assert.EqualError(t, err, "response body: test error") -} - -func TestSimpleRequestErrReaderCloseBody_StreamResponse(t *testing.T) { - c := &DatabricksClient{ - Config: config.NewMockConfig(func(r *http.Request) error { - return nil - }), - httpClient: hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: errReader(true), - Request: r, - }, nil - }), - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } - headers := map[string]string{"Accept": "application/octet-stream"} - err := c.Do(context.Background(), "PATCH", "/a", headers, map[string]any{}, nil) - assert.NoError(t, err, "response body should not be closed for streaming responses") -} - -func TestSimpleRequestRawResponse(t *testing.T) { - c := &DatabricksClient{ - Config: config.NewMockConfig(func(r *http.Request) error { - return nil - }), - httpClient: hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader("Hello, world!")), - Request: r, - }, nil - }), - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } - var raw []byte - err := c.Do(context.Background(), "GET", "/a", nil, nil, &raw) - assert.NoError(t, err) - assert.Equal(t, "Hello, world!", string(raw)) -} - -type BufferLogger struct { - strings.Builder -} - -func (l *BufferLogger) Enabled(_ context.Context, level logger.Level) bool { - return true -} - -func (l *BufferLogger) Tracef(_ context.Context, format string, v ...interface{}) { - l.WriteString(fmt.Sprintf("[TRACE] "+format, v...)) -} - -func (l *BufferLogger) Debugf(_ context.Context, format string, v ...interface{}) { - l.WriteString(fmt.Sprintf("[DEBUG] "+format, v...)) -} - -func (l *BufferLogger) Infof(_ context.Context, format string, v ...interface{}) { - l.WriteString(fmt.Sprintf("[INFO] "+format, v...)) -} - -func (l *BufferLogger) Warnf(_ context.Context, format string, v ...interface{}) { - l.WriteString(fmt.Sprintf("[WARN] "+format, v...)) -} - -func (l *BufferLogger) Errorf(_ context.Context, format string, v ...interface{}) { - l.WriteString(fmt.Sprintf("[ERROR] "+format, v...)) -} - -func TestSimpleResponseRedaction(t *testing.T) { - cfg := config.NewMockConfig(func(r *http.Request) error { - r.Header.Add("X-For-Logging", "yes") - return nil - }) - cfg.Host = "http://localhost:12345" - - prevLogger := logger.DefaultLogger - bufLogger := &BufferLogger{} - logger.DefaultLogger = bufLogger - defer func() { - logger.DefaultLogger = prevLogger - }() - - c := &DatabricksClient{ - Config: cfg, - httpClient: hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Proto: "HTTP/3.4", - Status: "200 Fine", - Body: io.NopCloser(strings.NewReader(`{ - "string_value": "__SENSITIVE01__", - "inner": { - "token_value": "__SENSITIVE02__", - "content": "__SENSITIVE03__" - }, - "list": [ - { - "token_value": "__SENSITIVE04__" - } - ], - "longer": "12345678901234567890qwerty" - }`)), - Request: r, - }, nil - }), - debugTruncateBytes: 16, - debugHeaders: true, - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } - err := c.Do(context.Background(), "GET", "/a", nil, map[string]any{ - "b": 0, - "a": 3, - "c": 23, - }, nil) - assert.NoError(t, err) - // not testing for exact logged lines, as header order is not deterministic - assert.NotContains(t, bufLogger.String(), "__SENSITIVE01__") - assert.NotContains(t, bufLogger.String(), "__SENSITIVE02__") - assert.NotContains(t, bufLogger.String(), "__SENSITIVE03__") - assert.NotContains(t, bufLogger.String(), "__SENSITIVE04__") - assert.NotContains(t, bufLogger.String(), "12345678901234567890qwerty") -} - -func TestInlineArrayDebugging(t *testing.T) { - prevLogger := logger.DefaultLogger - bufLogger := &BufferLogger{} - logger.DefaultLogger = bufLogger - defer func() { - logger.DefaultLogger = prevLogger - }() - - c := &DatabricksClient{ - Config: config.NewMockConfig(func(r *http.Request) error { - return nil - }), - httpClient: hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(`[ - {"foo": "bar"} - ]`)), - Request: r, - }, nil - }), - debugTruncateBytes: 2048, - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } - headers := map[string]string{"Accept": "application/json"} - err := c.Do(context.Background(), "GET", "/a", headers, map[string]any{ - "b": 0, - "a": 3, - "c": 23, - }, nil) - assert.NoError(t, err) - - assert.Equal(t, `[DEBUG] GET /a?a=3&b=0&c=23 -< -< [ -< { -< "foo": "bar" -< } -< ]`, bufLogger.String()) -} - -func TestInlineArrayDebugging_StreamResponse(t *testing.T) { - prevLogger := logger.DefaultLogger - bufLogger := &BufferLogger{} - logger.DefaultLogger = bufLogger - defer func() { - logger.DefaultLogger = prevLogger - }() - - c := &DatabricksClient{ - Config: config.NewMockConfig(func(r *http.Request) error { - return nil - }), - httpClient: hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(`lots of bytes`)), - Request: r, - }, nil - }), - debugTruncateBytes: 2048, - rateLimiter: rate.NewLimiter(rate.Inf, 1), - } - headers := map[string]string{"Accept": "application/octet-stream"} - err := c.Do(context.Background(), "GET", "/a", headers, map[string]any{ - "b": 0, - "a": 3, - "c": 23, - }, nil) - assert.NoError(t, err) - - assert.Equal(t, `[DEBUG] GET /a?a=3&b=0&c=23 -< -< [non-JSON document of 15 bytes]. `, bufLogger.String()) -} - -func TestStreamRequestFromFileWithReset(t *testing.T) { - // make a temporary file with some content - f, err := os.CreateTemp("", "databricks-client-test") - assert.NoError(t, err) - defer os.Remove(f.Name()) - _, err = f.WriteString("hello world") - assert.NoError(t, err) - assert.NoError(t, f.Close()) - - // Make a reader that reads this file - r, err := os.Open(f.Name()) - assert.NoError(t, err) - defer r.Close() - - succeed := false - handler := func(req *http.Request) (*http.Response, error) { - bytes, err := io.ReadAll(req.Body) - assert.NoError(t, err) - assert.Equal(t, "hello world", string(bytes)) - if succeed { - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader("succeeded")), - Request: req, - }, nil - } - succeed = true - return &http.Response{ - StatusCode: 429, - Body: io.NopCloser(strings.NewReader("failed")), - Request: req, - }, nil - } - - client := &DatabricksClient{ - httpClient: hc(handler), - rateLimiter: rate.NewLimiter(rate.Limit(1), 1), - Config: config.NewMockConfig(func(r *http.Request) error { return nil }), - retryTimeout: time.Hour, - } - - respBytes := bytes.Buffer{} - err = client.Do(context.Background(), "POST", "/a", nil, r, &respBytes) - assert.NoError(t, err) - assert.Equal(t, "succeeded", respBytes.String()) - assert.True(t, succeed) -} - -type customReader struct{} - -func (c customReader) Read(p []byte) (n int, err error) { - return 0, nil -} - -func TestCannotRetryArbitraryReader(t *testing.T) { - client := &DatabricksClient{ - httpClient: hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 429, - Request: r, - Body: io.NopCloser(strings.NewReader("")), - }, nil - }), - rateLimiter: rate.NewLimiter(rate.Limit(1), 1), - Config: config.NewMockConfig(func(r *http.Request) error { return nil }), - retryTimeout: time.Hour, - } - err := client.Do(context.Background(), "POST", "/a", nil, customReader{}, nil) - assert.ErrorContains(t, err, "cannot reset reader of type client.customReader") -} - -func TestRetryGetRequest(t *testing.T) { - // This test was added in response to https://github.com/databricks/terraform-provider-databricks/issues/2675. - succeed := false - handler := func(req *http.Request) (*http.Response, error) { - assert.Nil(t, req.Body) - - if succeed { - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader("succeeded")), - Request: req, - }, nil - } - - succeed = true - return &http.Response{ - StatusCode: 429, - Body: io.NopCloser(strings.NewReader("failed")), - Request: req, - }, nil - } - - client := &DatabricksClient{ - httpClient: hc(handler), - rateLimiter: rate.NewLimiter(rate.Limit(1), 1), - Config: config.NewMockConfig(func(r *http.Request) error { return nil }), - retryTimeout: time.Hour, - } - - respBytes := bytes.Buffer{} - err := client.Do(context.Background(), "GET", "/a", nil, nil, &respBytes) - assert.NoError(t, err) - assert.Equal(t, "succeeded", respBytes.String()) - assert.True(t, succeed) -} - -func (cb hc) RoundTrip(r *http.Request) (*http.Response, error) { - return cb(r) + assert.ErrorIs(t, err, apierr.ErrNotFound) } func TestHttpTransport(t *testing.T) { diff --git a/config/auth_azure_cli.go b/config/auth_azure_cli.go index c2a3fc356..ff8762748 100644 --- a/config/auth_azure_cli.go +++ b/config/auth_azure_cli.go @@ -11,6 +11,7 @@ import ( "golang.org/x/oauth2" + "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" ) @@ -76,6 +77,7 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(* } return nil, err } + ctx = httpclient.DefaultClient.InContextForOAuth2(ctx) err = cfg.azureEnsureWorkspaceUrl(ctx, c) if err != nil { return nil, fmt.Errorf("resolve host: %w", err) diff --git a/config/auth_azure_client_secret.go b/config/auth_azure_client_secret.go index eb5e4c3f8..bc2c55592 100644 --- a/config/auth_azure_client_secret.go +++ b/config/auth_azure_client_secret.go @@ -9,6 +9,7 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/clientcredentials" + "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" ) @@ -46,6 +47,7 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config if err != nil { return nil, err } + ctx = httpclient.DefaultClient.InContextForOAuth2(ctx) err = cfg.azureEnsureWorkspaceUrl(ctx, c) if err != nil { return nil, fmt.Errorf("resolve host: %w", err) diff --git a/config/auth_azure_msi.go b/config/auth_azure_msi.go index 02a0ba229..9f2218ec2 100644 --- a/config/auth_azure_msi.go +++ b/config/auth_azure_msi.go @@ -8,6 +8,7 @@ import ( "net/http" "time" + "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" "golang.org/x/oauth2" ) @@ -34,6 +35,7 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(* if err != nil { return nil, err } + ctx = httpclient.DefaultClient.InContextForOAuth2(ctx) if !cfg.IsAccountClient() { err = cfg.azureEnsureWorkspaceUrl(ctx, c) if err != nil { diff --git a/config/azure.go b/config/azure.go index 9a5ab044a..0b61eefa9 100644 --- a/config/azure.go +++ b/config/azure.go @@ -2,11 +2,10 @@ package config import ( "context" - "encoding/json" "fmt" - "io" "strings" + "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" "golang.org/x/oauth2" ) @@ -81,23 +80,18 @@ func (c *Config) azureEnsureWorkspaceUrl(ctx context.Context, ahr azureHostResol } // azure resource ID can also be used in lieu of host by some of the clients, like Terraform management := ahr.tokenSourceFor(ctx, c, env, env.ResourceManagerEndpoint) - resourceManager := oauth2.NewClient(ctx, management) - resp, err := resourceManager.Get(env.ResourceManagerEndpoint + c.AzureResourceID + "?api-version=2018-04-01") - if err != nil { - return fmt.Errorf("cannot resolve workspace: %w", err) - } - raw, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("cannot read: %w", err) - } var workspaceMetadata struct { Properties struct { WorkspaceURL string `json:"workspaceUrl"` } `json:"properties"` } - err = json.Unmarshal(raw, &workspaceMetadata) + requestURL := env.ResourceManagerEndpoint + c.AzureResourceID + "?api-version=2018-04-01" + err = httpclient.DefaultClient.Do(ctx, "GET", requestURL, + httpclient.WithUnmarshal(&workspaceMetadata), + httpclient.WithTokenSource(management), + ) if err != nil { - return fmt.Errorf("cannot unmarshal: %w", err) + return fmt.Errorf("resolve workspace: %w", err) } c.Host = fmt.Sprintf("https://%s", workspaceMetadata.Properties.WorkspaceURL) logger.Debugf(ctx, "Discovered workspace url: %s", c.Host) diff --git a/httpclient/api_client.go b/httpclient/api_client.go new file mode 100644 index 000000000..d25447205 --- /dev/null +++ b/httpclient/api_client.go @@ -0,0 +1,386 @@ +package httpclient // has to be a separate package than client, otherwise a circular dependency + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "runtime" + "strings" + "time" + + "github.com/databricks/databricks-sdk-go/logger" + "github.com/databricks/databricks-sdk-go/retries" + "github.com/databricks/databricks-sdk-go/useragent" + "golang.org/x/oauth2" + "golang.org/x/time/rate" +) + +type RequestVisitor func(*http.Request) error + +type ClientConfig struct { + Visitors []RequestVisitor + + RetryTimeout time.Duration + HTTPTimeout time.Duration + InsecureSkipVerify bool + DebugHeaders bool + DebugTruncateBytes int + RateLimitPerSecond int + + ErrorMapper func(ctx context.Context, resp *http.Response, body io.ReadCloser) error + ErrorRetriable func(ctx context.Context, err error) bool + TransientErrors []string + + Transport http.RoundTripper +} + +func (cfg ClientConfig) httpTransport() http.RoundTripper { + if cfg.Transport != nil { + return cfg.Transport + } + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + MaxIdleConnsPerHost: runtime.GOMAXPROCS(0) + 1, + IdleConnTimeout: 180 * time.Second, + TLSHandshakeTimeout: 30 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: cfg.InsecureSkipVerify, + }, + } +} + +var DefaultClient = NewApiClient(ClientConfig{ + ErrorRetriable: DefaultErrorRetriable, + ErrorMapper: DefaultErrorMapper, + HTTPTimeout: 30 * time.Second, + RetryTimeout: 5 * time.Minute, + RateLimitPerSecond: 30, +}) + +func NewApiClient(cfg ClientConfig) *ApiClient { + cfg.HTTPTimeout = time.Duration(orDefault(int(cfg.HTTPTimeout), int(30*time.Second))) + cfg.DebugTruncateBytes = orDefault(cfg.DebugTruncateBytes, 96) + cfg.RetryTimeout = time.Duration(orDefault(int(cfg.RetryTimeout), int(5*time.Minute))) + cfg.HTTPTimeout = time.Duration(orDefault(int(cfg.HTTPTimeout), int(30*time.Second))) + rateLimiter := rate.NewLimiter(rate.Limit(orDefault(cfg.RateLimitPerSecond, 15)), 1) + if cfg.ErrorMapper == nil { + // default generic error mapper + cfg.ErrorMapper = DefaultErrorMapper + } + if cfg.ErrorRetriable == nil { + // by default, we just retry on HTTP 429/504 + cfg.ErrorRetriable = DefaultErrorRetriable + } + return &ApiClient{ + config: cfg, + rateLimiter: rateLimiter, + httpClient: &http.Client{ + Timeout: cfg.HTTPTimeout, + Transport: cfg.httpTransport(), + }, + } +} + +type ApiClient struct { + config ClientConfig + rateLimiter *rate.Limiter + httpClient *http.Client +} + +type DoOption struct { + in RequestVisitor + out func(body *responseBody) error + body any +} + +// Do sends an HTTP request against path. +func (c *ApiClient) Do(ctx context.Context, method, path string, opts ...DoOption) error { + visitors := c.config.Visitors[:] + for _, o := range opts { + if o.in == nil { + continue + } + // merge client-wide and request-specific visitors + visitors = append(visitors, o.in) + } + var requestBody any + for _, o := range opts { + if o.body == nil { + continue + } + requestBody = o.body + } + responseBody, err := c.perform(ctx, method, path, requestBody, visitors...) + if err != nil { + return err + } + for _, o := range opts { + if o.out == nil { + continue + } + err = o.out(responseBody) + if err != nil { + return err + } + } + return nil +} + +func (c *ApiClient) fromResponse(r *http.Response) (responseBody, error) { + if r == nil { + return responseBody{}, fmt.Errorf("nil response") + } + if r.Request == nil { + return responseBody{}, fmt.Errorf("nil request") + } + streamResponse := r.Request.Header.Get("Accept") != "application/json" && r.Header.Get("Content-Type") != "application/json" + if streamResponse { + return newResponseBody(r.Body, r.Header) + } + defer r.Body.Close() + bs, err := io.ReadAll(r.Body) + if err != nil { + return responseBody{}, fmt.Errorf("response body: %w", err) + } + return newResponseBody(bs, r.Header) +} + +func (c *ApiClient) redactedDump(prefix string, body []byte) (res string) { + return bodyLogger{ + debugTruncateBytes: c.config.DebugTruncateBytes, + }.redactedDump(prefix, body) +} + +func (c *ApiClient) isRetriable(ctx context.Context, err error) bool { + if c.config.ErrorRetriable(ctx, err) { + return true + } + _, isIO := err.(*url.Error) + if isIO { + // all IO errors are retriable + logger.Debugf(ctx, "Attempting retry because of IO error: %s", err) + return true + } + message := err.Error() + // Handle transient errors for retries + for _, substring := range c.config.TransientErrors { + if strings.Contains(message, substring) { + logger.Debugf(ctx, "Attempting retry because of %#v", substring) + return true + } + } + // some API's recommend retries on HTTP 500, but we'll add that later + return false +} + +// Common error-handling logic for all responses that may need to be retried. +// +// If the error is retriable, return a retries.Err to retry the request. However, as the request body will have been consumed +// by the first attempt, the body must be reset before retrying. If the body cannot be reset, return a retries.Err to halt. +// +// Always returns nil for the first parameter as there is no meaningful response body to return in the error case. +// +// If it is certain that an error should not be retried, use failRequest() instead. +func (c *ApiClient) handleError(ctx context.Context, err error, body requestBody) (*responseBody, *retries.Err) { + if !c.isRetriable(ctx, err) { + return c.failRequest(ctx, "non-retriable error", err) + } + if resetErr := body.reset(); resetErr != nil { + return nil, retries.Halt(resetErr) + } + return nil, retries.Continue(err) +} + +// Fails the request with a retries.Err to halt future retries. +func (c *ApiClient) failRequest(ctx context.Context, msg string, err error) (*responseBody, *retries.Err) { + logger.Debugf(ctx, "%s: %s", msg, err) + return nil, retries.Halt(err) +} + +func (c *ApiClient) attempt( + ctx context.Context, + method string, + requestURL string, + requestBody requestBody, + visitors ...RequestVisitor, +) func() (*responseBody, *retries.Err) { + return func() (*responseBody, *retries.Err) { + err := c.rateLimiter.Wait(ctx) + if err != nil { + return c.failRequest(ctx, "failed in rate limiter", err) + } + request, err := http.NewRequestWithContext(ctx, method, requestURL, requestBody.Reader) + if err != nil { + return c.failRequest(ctx, "failed creating new request", err) + } + for _, requestVisitor := range visitors { + err = requestVisitor(request) + if err != nil { + return c.failRequest(ctx, "failed during request visitor", err) + } + } + // request.Context() holds context potentially enhanced by visitors + request.Header.Set("User-Agent", useragent.FromContext(request.Context())) + + // attempt the actual request + response, err := c.httpClient.Do(request) + + // After this point, the request body has (probably) been consumed. handleError() must be called to reset it if + // possible. + if _, ok := err.(*url.Error); ok { + return c.handleError(ctx, err, requestBody) + } + + // By this point, the request body has certainly been consumed. + responseBody, responseBodyErr := c.fromResponse(response) + if responseBodyErr != nil { + return c.failRequest(ctx, "failed while reading response", responseBodyErr) + } + + mappedError := c.config.ErrorMapper(ctx, response, responseBody.ReadCloser) + defer c.recordRequestLog(ctx, request, response, mappedError, requestBody.DebugBytes, responseBody.DebugBytes) + + if mappedError == nil { + return &responseBody, nil + } + + // proactively release the connections in HTTP connection pool + c.httpClient.CloseIdleConnections() + return c.handleError(ctx, mappedError, requestBody) + } +} + +func (c *ApiClient) recordRequestLog( + ctx context.Context, + request *http.Request, + response *http.Response, + err error, + requestBody, responseBody []byte, +) { + // Don't compute expensive debug message if debug logging is not enabled. + if !logger.Get(ctx).Enabled(ctx, logger.LevelDebug) { + return + } + sb := strings.Builder{} + sb.WriteString(fmt.Sprintf("%s %s", request.Method, + escapeNewLines(request.URL.Path))) + if request.URL.RawQuery != "" { + sb.WriteString("?") + q, _ := url.QueryUnescape(request.URL.RawQuery) + sb.WriteString(q) + } + sb.WriteString("\n") + if c.config.DebugHeaders { + sb.WriteString("> * Host: ") + sb.WriteString(escapeNewLines(request.Host)) + sb.WriteString("\n") + for k, v := range request.Header { + trunc := onlyNBytes(strings.Join(v, ""), c.config.DebugTruncateBytes) + sb.WriteString(fmt.Sprintf("> * %s: %s\n", k, escapeNewLines(trunc))) + } + } + if len(requestBody) > 0 { + sb.WriteString(c.redactedDump("> ", requestBody)) + sb.WriteString("\n") + } + sb.WriteString("< ") + if response != nil { + sb.WriteString(fmt.Sprintf("%s %s", response.Proto, response.Status)) + // Only display error on this line if the response body is empty. + // Otherwise the response body will include details about the error. + if len(responseBody) == 0 && err != nil { + sb.WriteString(fmt.Sprintf(" (Error: %s)", err)) + } + } else { + sb.WriteString(fmt.Sprintf("Error: %s", err)) + } + sb.WriteString("\n") + if len(responseBody) > 0 { + sb.WriteString(c.redactedDump("< ", responseBody)) + } + logger.Debugf(ctx, sb.String()) // lgtm [go/log-injection] lgtm [go/clear-text-logging] +} + +// RoundTrip implements http.RoundTripper to integrate with golang.org/x/oauth2 +func (c *ApiClient) RoundTrip(request *http.Request) (*http.Response, error) { + ctx := request.Context() + requestURL := request.URL.String() + resp, err := retries.Poll(ctx, c.config.RetryTimeout, + c.attempt(ctx, request.Method, requestURL, requestBody{ + Reader: request.Body, + // DO NOT DECODE BODY, because it may contain sensitive payload, + // like Azure Service Principal in a multipart/form-data body. + DebugBytes: []byte(""), + })) + if err != nil { + return nil, err + } + // here we assume only successful responses, as HTTP 4XX and 5XX are mapped + // to Go's error implementations. + return &http.Response{ + Status: "OK", + StatusCode: 200, + Request: request, + Header: resp.Header, + Body: resp.ReadCloser, + }, nil +} + +// InContextForOAuth2 returns a context with a custom *http.Client to be used +// for only for token acquisition through golang.org/x/oauth2 package +func (c *ApiClient) InContextForOAuth2(ctx context.Context) context.Context { + return context.WithValue(ctx, oauth2.HTTPClient, &http.Client{ + Timeout: c.config.HTTPTimeout, + Transport: c, + }) +} + +func (c *ApiClient) perform( + ctx context.Context, + method, + requestURL string, + data interface{}, + visitors ...RequestVisitor, +) (*responseBody, error) { + requestBody, err := makeRequestBody(method, &requestURL, data) + if err != nil { + return nil, fmt.Errorf("request marshal: %w", err) + } + resp, err := retries.Poll(ctx, c.config.RetryTimeout, + c.attempt(ctx, method, requestURL, requestBody, visitors...)) + var timedOut *retries.ErrTimedOut + if errors.As(err, &timedOut) { + // TODO: check if we want to unwrap this error here + return nil, timedOut.Unwrap() + } else if err != nil { + // Don't re-wrap, as upper layers may depend on handling apierr.APIError. + return nil, err + } + return resp, nil +} + +func orDefault(configured, _default int) int { + if configured == 0 { + return _default + } + return configured +} + +// CWE-117 prevention +func escapeNewLines(in string) string { + in = strings.Replace(in, "\n", "", -1) + in = strings.Replace(in, "\r", "", -1) + return in +} diff --git a/httpclient/api_client_test.go b/httpclient/api_client_test.go new file mode 100644 index 000000000..1b943b985 --- /dev/null +++ b/httpclient/api_client_test.go @@ -0,0 +1,589 @@ +package httpclient + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "testing" + "time" + + "github.com/databricks/databricks-sdk-go/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/time/rate" +) + +// errReader(true) will also fail on Close +type errReader bool + +func (errReader) Read(p []byte) (n int, err error) { + return 0, fmt.Errorf("test error") +} + +func (i errReader) Close() error { + if i { + return fmt.Errorf("test error") + } + return nil +} + +type hc func(r *http.Request) (*http.Response, error) + +func (cb hc) RoundTrip(r *http.Request) (*http.Response, error) { + return cb(r) +} + +func TestNew(t *testing.T) { + c := NewApiClient(ClientConfig{}) + + require.Equal(t, 96, c.config.DebugTruncateBytes) + require.Equal(t, 5*time.Minute, c.config.RetryTimeout) +} + +func TestSimpleRequestFailsURLError(t *testing.T) { + c := NewApiClient(ClientConfig{ + RetryTimeout: 1 * time.Millisecond, + Transport: hc(func(r *http.Request) (*http.Response, error) { + require.Equal(t, "GET", r.Method) + require.Equal(t, "/a/b", r.URL.Path) + require.Equal(t, "c=d", r.URL.RawQuery) + require.Equal(t, "f", r.Header.Get("e")) + return nil, fmt.Errorf("nope") + }), + }) + err := c.Do(context.Background(), "GET", "/a/b", WithHeaders(map[string]string{ + "e": "f", + }), WithData(map[string]string{ + "c": "d", + })) + require.EqualError(t, err, `Get "/a/b?c=d": nope`) +} + +func TestSimpleRequestFailsAPIError(t *testing.T) { + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + require.Equal(t, "GET", r.Method) + require.Equal(t, "/a/b", r.URL.Path) + require.Equal(t, "c=d", r.URL.RawQuery) + require.Equal(t, "f", r.Header.Get("e")) + return &http.Response{ + StatusCode: 400, + Request: r, + Body: io.NopCloser(strings.NewReader(`nope`)), + }, nil + }), + }) + err := c.Do(context.Background(), "GET", "/a/b", WithHeaders(map[string]string{ + "e": "f", + }), WithData(map[string]string{ + "c": "d", + })) + require.EqualError(t, err, "http 400: nope") +} + +func TestSimpleRequestSucceeds(t *testing.T) { + type Dummy struct { + Foo int `json:"foo"` + } + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"foo": 2}`)), + Request: r, + }, nil + }), + }) + var resp Dummy + err := c.Do(context.Background(), "POST", "/c", WithData(Dummy{1}), WithUnmarshal(&resp)) + require.NoError(t, err) + require.Equal(t, 2, resp.Foo) +} + +func TestSimpleRequestRetried(t *testing.T) { + type Dummy struct { + Foo int `json:"foo"` + } + var retried [1]bool + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + if !retried[0] { + retried[0] = true + return nil, &url.Error{ + Op: "open", + URL: "/a/b", + Err: fmt.Errorf("connection refused"), + } + } + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"foo": 2}`)), + Request: r, + }, nil + }), + }) + var resp Dummy + err := c.Do(context.Background(), "PATCH", "/a", WithData(Dummy{1}), WithUnmarshal(&resp)) + require.NoError(t, err) + require.Equal(t, 2, resp.Foo) + require.True(t, retried[0], "request was not retried") +} + +func TestHaltAttemptForLimit(t *testing.T) { + ctx := context.Background() + c := &ApiClient{ + config: ClientConfig{}, + rateLimiter: &rate.Limiter{}, + } + req, err := newRequestBody([]byte{}) + require.NoError(t, err) + _, rerr := c.attempt(ctx, "GET", "foo", req)() + require.NotNil(t, rerr) + require.Equal(t, true, rerr.Halt) + require.EqualError(t, rerr.Err, "rate: Wait(n=1) exceeds limiter's burst 0") +} + +func TestHaltAttemptForNewRequest(t *testing.T) { + ctx := context.Background() + c := NewApiClient(ClientConfig{}) + req, err := newRequestBody([]byte{}) + require.NoError(t, err) + _, rerr := c.attempt(ctx, "🥱", "/", req)() + require.NotNil(t, rerr) + require.Equal(t, true, rerr.Halt) + require.EqualError(t, rerr.Err, `net/http: invalid method "🥱"`) +} + +func TestHaltAttemptForVisitor(t *testing.T) { + ctx := context.Background() + c := NewApiClient(ClientConfig{}) + req, err := newRequestBody([]byte{}) + require.NoError(t, err) + _, rerr := c.attempt(ctx, "GET", "/", req, + func(r *http.Request) error { + return fmt.Errorf("🥱") + })() + require.NotNil(t, rerr) + require.Equal(t, true, rerr.Halt) + require.EqualError(t, rerr.Err, "🥱") +} + +func TestMakeRequestBody(t *testing.T) { + type x struct { + Scope string `json:"scope" url:"scope"` + } + requestURL := "/a/b/c" + body, err := makeRequestBody("GET", &requestURL, x{"test"}) + require.NoError(t, err) + bodyBytes, err := io.ReadAll(body.Reader) + require.NoError(t, err) + require.Equal(t, "/a/b/c?scope=test", requestURL) + require.Equal(t, 0, len(bodyBytes)) + + requestURL = "/a/b/c" + body, err = makeRequestBody("POST", &requestURL, x{"test"}) + require.NoError(t, err) + bodyBytes, err = io.ReadAll(body.Reader) + require.NoError(t, err) + require.Equal(t, "/a/b/c", requestURL) + x1 := `{"scope":"test"}` + require.Equal(t, []byte(x1), bodyBytes) +} + +func TestMakeRequestBodyFromReader(t *testing.T) { + requestURL := "/a/b/c" + body, err := makeRequestBody("PUT", &requestURL, strings.NewReader("abc")) + require.NoError(t, err) + bodyBytes, err := io.ReadAll(body.Reader) + require.NoError(t, err) + require.Equal(t, []byte("abc"), bodyBytes) +} + +func TestMakeRequestBodyReaderError(t *testing.T) { + requestURL := "/a/b/c" + _, err := makeRequestBody("POST", &requestURL, errReader(false)) + // The request body is only read once the request is sent, so no error + // should be returned until then. + require.NoError(t, err, "request body reader error should be ignored") +} + +func TestMakeRequestBodyJsonError(t *testing.T) { + requestURL := "/a/b/c" + type x struct { + Foo chan string `json:"foo"` + } + _, err := makeRequestBody("POST", &requestURL, x{make(chan string)}) + require.EqualError(t, err, "request marshal failure: json: unsupported type: chan string") +} + +type failingUrlEncode string + +func (fue failingUrlEncode) EncodeValues(key string, v *url.Values) error { + return fmt.Errorf(string(fue)) +} + +func TestMakeRequestBodyQueryFailingEncode(t *testing.T) { + requestURL := "/a/b/c" + type x struct { + Foo failingUrlEncode `url:"foo"` + } + _, err := makeRequestBody("GET", &requestURL, x{failingUrlEncode("always failing")}) + require.EqualError(t, err, "cannot create query string: always failing") +} + +func TestMakeRequestBodyQueryUnsupported(t *testing.T) { + requestURL := "/a/b/c" + _, err := makeRequestBody("GET", &requestURL, true) + require.EqualError(t, err, "unsupported query string data: true") +} + +func TestFailPerformChannel(t *testing.T) { + ctx := context.Background() + c := &ApiClient{ + rateLimiter: rate.NewLimiter(rate.Inf, 1), + } + _, err := c.perform(ctx, "GET", "/", true) + require.EqualError(t, err, "request marshal: unsupported query string data: true") +} + +func TestSimpleRequestAPIError(t *testing.T) { + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 400, + Body: io.NopCloser(strings.NewReader(`{ + "error_code": "NOT_FOUND", + "message": "Something was not found" + }`)), + Request: r, + }, nil + }), + }) + err := c.Do(context.Background(), "PATCH", "/a", WithData(map[string]any{})) + var httpErr *HttpError + if assert.ErrorAs(t, err, &httpErr) { + require.Equal(t, 400, httpErr.StatusCode) + } +} + +func TestSimpleRequestErrReaderBody(t *testing.T) { + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: errReader(false), + Request: r, + }, nil + }), + }) + headers := map[string]string{"Accept": "application/json"} + err := c.Do(context.Background(), "PATCH", "/a", WithHeaders(headers), WithData(map[string]any{})) + require.EqualError(t, err, "response body: test error") +} + +func TestSimpleRequestErrReaderBodyStreamResponse(t *testing.T) { + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: errReader(false), + Request: r, + }, nil + }), + }) + headers := map[string]string{"Accept": "application/octet-stream"} + err := c.Do(context.Background(), "PATCH", "/a", WithHeaders(headers), WithData(map[string]any{})) + require.NoError(t, err, "streaming response bodies are not read") +} + +func TestSimpleRequestErrReaderCloseBody(t *testing.T) { + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: errReader(true), + Request: r, + }, nil + }), + }) + headers := map[string]string{"Accept": "application/json"} + err := c.Do(context.Background(), "PATCH", "/a", WithHeaders(headers), WithData(map[string]any{})) + require.EqualError(t, err, "response body: test error") +} + +func TestSimpleRequestErrReaderCloseBody_StreamResponse(t *testing.T) { + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: errReader(true), + Request: r, + }, nil + }), + }) + headers := map[string]string{"Accept": "application/octet-stream"} + err := c.Do(context.Background(), "PATCH", "/a", WithHeaders(headers), WithData(map[string]any{})) + require.NoError(t, err, "response body should not be closed for streaming responses") +} + +func TestSimpleRequestRawResponse(t *testing.T) { + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("Hello, world!")), + Request: r, + }, nil + }), + }) + var raw []byte + err := c.Do(context.Background(), "GET", "/a", WithUnmarshal(&raw)) + require.NoError(t, err) + require.Equal(t, "Hello, world!", string(raw)) +} + +type BufferLogger struct { + strings.Builder +} + +func (l *BufferLogger) Enabled(_ context.Context, level logger.Level) bool { + return true +} + +func (l *BufferLogger) Tracef(_ context.Context, format string, v ...interface{}) { + l.WriteString(fmt.Sprintf("[TRACE] "+format, v...)) +} + +func (l *BufferLogger) Debugf(_ context.Context, format string, v ...interface{}) { + l.WriteString(fmt.Sprintf("[DEBUG] "+format, v...)) +} + +func (l *BufferLogger) Infof(_ context.Context, format string, v ...interface{}) { + l.WriteString(fmt.Sprintf("[INFO] "+format, v...)) +} + +func (l *BufferLogger) Warnf(_ context.Context, format string, v ...interface{}) { + l.WriteString(fmt.Sprintf("[WARN] "+format, v...)) +} + +func (l *BufferLogger) Errorf(_ context.Context, format string, v ...interface{}) { + l.WriteString(fmt.Sprintf("[ERROR] "+format, v...)) +} + +func TestSimpleResponseRedaction(t *testing.T) { + prevLogger := logger.DefaultLogger + bufLogger := &BufferLogger{} + logger.DefaultLogger = bufLogger + defer func() { + logger.DefaultLogger = prevLogger + }() + + c := NewApiClient(ClientConfig{ + DebugTruncateBytes: 16, + DebugHeaders: true, + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Proto: "HTTP/3.4", + Status: "200 Fine", + Body: io.NopCloser(strings.NewReader(`{ + "string_value": "__SENSITIVE01__", + "inner": { + "token_value": "__SENSITIVE02__", + "content": "__SENSITIVE03__" + }, + "list": [ + { + "token_value": "__SENSITIVE04__" + } + ], + "longer": "12345678901234567890qwerty" + }`)), + Request: r, + }, nil + }), + }) + err := c.Do(context.Background(), "GET", "/a", WithData(map[string]any{ + "b": 0, + "a": 3, + "c": 23, + })) + require.NoError(t, err) + // not testing for exact logged lines, as header order is not deterministic + require.NotContains(t, bufLogger.String(), "__SENSITIVE01__") + require.NotContains(t, bufLogger.String(), "__SENSITIVE02__") + require.NotContains(t, bufLogger.String(), "__SENSITIVE03__") + require.NotContains(t, bufLogger.String(), "__SENSITIVE04__") + require.NotContains(t, bufLogger.String(), "12345678901234567890qwerty") +} + +func TestInlineArrayDebugging(t *testing.T) { + prevLogger := logger.DefaultLogger + bufLogger := &BufferLogger{} + logger.DefaultLogger = bufLogger + defer func() { + logger.DefaultLogger = prevLogger + }() + + c := NewApiClient(ClientConfig{ + DebugTruncateBytes: 2048, + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`[ + {"foo": "bar"} + ]`)), + Request: r, + }, nil + }), + }) + headers := map[string]string{"Accept": "application/json"} + err := c.Do(context.Background(), "GET", "/a", WithHeaders(headers), WithData(map[string]any{ + "b": 0, + "a": 3, + "c": 23, + })) + require.NoError(t, err) + + require.Equal(t, `[DEBUG] GET /a?a=3&b=0&c=23 +< +< [ +< { +< "foo": "bar" +< } +< ]`, bufLogger.String()) +} + +func TestInlineArrayDebugging_StreamResponse(t *testing.T) { + prevLogger := logger.DefaultLogger + bufLogger := &BufferLogger{} + logger.DefaultLogger = bufLogger + defer func() { + logger.DefaultLogger = prevLogger + }() + + c := NewApiClient(ClientConfig{ + DebugTruncateBytes: 2048, + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`lots of bytes`)), + Request: r, + }, nil + }), + }) + headers := map[string]string{"Accept": "application/octet-stream"} + err := c.Do(context.Background(), "GET", "/a", WithHeaders(headers), WithData(map[string]any{ + "b": 0, + "a": 3, + "c": 23, + })) + require.NoError(t, err) + + require.Equal(t, `[DEBUG] GET /a?a=3&b=0&c=23 +< +< [non-JSON document of 15 bytes]. `, bufLogger.String()) +} + +func TestStreamRequestFromFileWithReset(t *testing.T) { + // make a temporary file with some content + f, err := os.CreateTemp("", "databricks-client-test") + require.NoError(t, err) + defer os.Remove(f.Name()) + _, err = f.WriteString("hello world") + require.NoError(t, err) + require.NoError(t, f.Close()) + + // Make a reader that reads this file + r, err := os.Open(f.Name()) + require.NoError(t, err) + defer r.Close() + + succeed := false + handler := func(req *http.Request) (*http.Response, error) { + bytes, err := io.ReadAll(req.Body) + require.NoError(t, err) + require.Equal(t, "hello world", string(bytes)) + if succeed { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("succeeded")), + Request: req, + }, nil + } + succeed = true + return &http.Response{ + StatusCode: 429, + Body: io.NopCloser(strings.NewReader("failed")), + Request: req, + }, nil + } + + client := NewApiClient(ClientConfig{ + Transport: hc(handler), + }) + + respBytes := bytes.Buffer{} + err = client.Do(context.Background(), "POST", "/a", WithData(r), WithUnmarshal(&respBytes)) + require.NoError(t, err) + require.Equal(t, "succeeded", respBytes.String()) + require.True(t, succeed) +} + +type customReader struct{} + +func (c customReader) Read(p []byte) (n int, err error) { + return 0, nil +} + +func TestCannotRetryArbitraryReader(t *testing.T) { + client := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 429, + Request: r, + Body: io.NopCloser(strings.NewReader("")), + }, nil + }), + }) + err := client.Do(context.Background(), "POST", "/a", WithData(customReader{})) + require.ErrorContains(t, err, "cannot reset reader of type httpclient.customReader") +} + +func TestRetryGetRequest(t *testing.T) { + // This test was added in response to https://github.com/databricks/terraform-provider-databricks/issues/2675. + succeed := false + handler := func(req *http.Request) (*http.Response, error) { + require.Nil(t, req.Body) + + if succeed { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("succeeded")), + Request: req, + }, nil + } + + succeed = true + return &http.Response{ + StatusCode: 429, + Body: io.NopCloser(strings.NewReader("failed")), + Request: req, + }, nil + } + + client := NewApiClient(ClientConfig{ + Transport: hc(handler), + }) + + respBytes := bytes.Buffer{} + err := client.Do(context.Background(), "GET", "/a", WithUnmarshal(&respBytes)) + require.NoError(t, err) + require.Equal(t, "succeeded", respBytes.String()) + require.True(t, succeed) +} diff --git a/client/body_logger.go b/httpclient/body_logger.go similarity index 94% rename from client/body_logger.go rename to httpclient/body_logger.go index de858c614..1a99474ee 100644 --- a/client/body_logger.go +++ b/httpclient/body_logger.go @@ -1,4 +1,4 @@ -package client +package httpclient import ( "bytes" @@ -13,9 +13,13 @@ type bodyLogger struct { } var redactKeys = map[string]bool{ - "string_value": true, - "token_value": true, - "content": true, + "string_value": true, + "token_value": true, + "content": true, + "access_token": true, + "refresh_token": true, + "token": true, + "password": true, } func (b bodyLogger) mask(m map[string]any) { diff --git a/client/body_logger_test.go b/httpclient/body_logger_test.go similarity index 98% rename from client/body_logger_test.go rename to httpclient/body_logger_test.go index 80f00afa4..f08395ef5 100644 --- a/client/body_logger_test.go +++ b/httpclient/body_logger_test.go @@ -1,4 +1,4 @@ -package client +package httpclient import ( "encoding/json" diff --git a/httpclient/errors.go b/httpclient/errors.go new file mode 100644 index 000000000..d966c4f22 --- /dev/null +++ b/httpclient/errors.go @@ -0,0 +1,53 @@ +package httpclient + +import ( + "context" + "fmt" + "io" + "net/http" +) + +type HttpError struct { + *http.Response + Message string + err error +} + +func (r *HttpError) Unwrap() error { + return r.err +} + +func (r *HttpError) Error() string { + return fmt.Sprintf("http %d: %s", r.StatusCode, r.Message) +} + +func DefaultErrorMapper(ctx context.Context, resp *http.Response, body io.ReadCloser) error { + if resp.StatusCode < 400 { + return nil + } + raw, err := io.ReadAll(body) + if err != nil { + return &HttpError{ + Response: resp, + Message: "failed to read response", + err: err, + } + } + return &HttpError{ + Response: resp, + Message: string(raw), + } +} + +func DefaultErrorRetriable(ctx context.Context, err error) bool { + switch some := err.(type) { + case *HttpError: + if some.StatusCode == 429 { + return true + } + if some.StatusCode == 504 { + return true + } + } + return false +} diff --git a/httpclient/request.go b/httpclient/request.go new file mode 100644 index 000000000..0684bd490 --- /dev/null +++ b/httpclient/request.go @@ -0,0 +1,182 @@ +package httpclient + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "reflect" + "sort" + "strings" + + "github.com/google/go-querystring/query" + "golang.org/x/oauth2" +) + +// Represents a request body. +// +// If the provided request data is an io.Reader, DebugBytes is set to +// "". Otherwise, DebugBytes is set to the marshaled JSON +// representation of the request data, and ReadCloser is set to a new +// io.ReadCloser that reads from DebugBytes. +// +// Request bodies are never closed by the client, hence only accepting +// io.Reader. +type requestBody struct { + Reader io.Reader + DebugBytes []byte +} + +func newRequestBody(data any) (requestBody, error) { + switch v := data.(type) { + case io.Reader: + return requestBody{ + Reader: v, + DebugBytes: []byte(""), + }, nil + case string: + return requestBody{ + Reader: strings.NewReader(v), + DebugBytes: []byte(v), + }, nil + case []byte: + return requestBody{ + Reader: bytes.NewReader(v), + DebugBytes: v, + }, nil + default: + bs, err := json.Marshal(data) + if err != nil { + return requestBody{}, fmt.Errorf("request marshal failure: %w", err) + } + return requestBody{ + Reader: bytes.NewReader(bs), + DebugBytes: bs, + }, nil + } +} + +// Reset a request body to its initial state. +// +// This is used to retry requests with a body that has already been read. +// If the request body is not resettable (i.e. not nil and of type other than +// strings.Reader or bytes.Reader), this will return an error. +func (r requestBody) reset() error { + if r.Reader == nil { + return nil + } + if v, ok := r.Reader.(io.Seeker); ok { + _, err := v.Seek(0, io.SeekStart) + return err + } else { + return fmt.Errorf("cannot reset reader of type %T", r.Reader) + } +} + +func WithHeader(k, v string) DoOption { + return DoOption{ + in: func(r *http.Request) error { + r.Header.Set(k, v) + return nil + }, + } +} + +func WithHeaders(headers map[string]string) DoOption { + return DoOption{ + in: func(r *http.Request) error { + for k, v := range headers { + r.Header.Set(k, v) + } + return nil + }, + } +} + +func WithTokenSource(ts oauth2.TokenSource) DoOption { + return DoOption{ + in: func(r *http.Request) error { + token, err := ts.Token() + if err != nil { + return fmt.Errorf("token: %w", err) + } + auth := fmt.Sprintf("%s %s", token.TokenType, token.AccessToken) + r.Header.Set("Authorization", auth) + return nil + }, + } +} + +func WithVisitor(visitor func(r *http.Request) error) DoOption { + return DoOption{ + in: visitor, + } +} + +func WithData(body any) DoOption { + return DoOption{ + body: body, + } +} + +func makeQueryString(data interface{}) (string, error) { + inputVal := reflect.ValueOf(data) + inputType := reflect.TypeOf(data) + if inputType.Kind() == reflect.Map { + s := []string{} + keys := inputVal.MapKeys() + // sort map keys by their string repr, so that tests can be deterministic + sort.Slice(keys, func(i, j int) bool { + return keys[i].String() < keys[j].String() + }) + for _, k := range keys { + v := inputVal.MapIndex(k) + if v.IsZero() { + continue + } + s = append(s, fmt.Sprintf("%s=%s", + strings.Replace(url.QueryEscape(fmt.Sprintf("%v", k.Interface())), "+", "%20", -1), + strings.Replace(url.QueryEscape(fmt.Sprintf("%v", v.Interface())), "+", "%20", -1))) + } + return "?" + strings.Join(s, "&"), nil + } + if inputType.Kind() == reflect.Struct { + params, err := query.Values(data) + if err != nil { + return "", fmt.Errorf("cannot create query string: %w", err) + } + // Query parameters may be nested, but the keys generated by + // query.Values use the "[" and "]" characters to represent nesting. + // Replace all instances of "[" with "." and "]" with empty string + // to make the query string compatible with the proto API. + // See the following for more information: + // https://cloud.google.com/endpoints/docs/grpc-service-config/reference/rpc/google.api#google.api.HttpRule + protoCompatibleParams := make(url.Values) + for k, vs := range params { + newK := strings.Replace(k, "[", ".", -1) + newK = strings.Replace(newK, "]", "", -1) + for _, v := range vs { + protoCompatibleParams.Add(newK, v) + } + } + return "?" + protoCompatibleParams.Encode(), nil + } + return "", fmt.Errorf("unsupported query string data: %#v", data) +} + +func makeRequestBody(method string, requestURL *string, data interface{}) (requestBody, error) { + if data == nil { + return requestBody{}, nil + } + if method == "GET" || method == "DELETE" { + qs, err := makeQueryString(data) + if err != nil { + return requestBody{}, err + } + *requestURL += qs + return newRequestBody([]byte{}) + } + return newRequestBody(data) +} diff --git a/httpclient/response.go b/httpclient/response.go new file mode 100644 index 000000000..54b625b47 --- /dev/null +++ b/httpclient/response.go @@ -0,0 +1,83 @@ +package httpclient + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" +) + +// Represents a response body. +// +// Responses must always be closed. For non-streaming responses, they are closed +// during deserialization in the client (see unmarshall()). For streaming +// responses, they are returned to the caller, who is responsible for closing +// them. +type responseBody struct { + ReadCloser io.ReadCloser + DebugBytes []byte + Header http.Header +} + +func newResponseBody(data any, header http.Header) (responseBody, error) { + switch v := data.(type) { + case io.ReadCloser: + return responseBody{ + ReadCloser: v, + DebugBytes: []byte(""), + Header: header, + }, nil + case []byte: + return responseBody{ + ReadCloser: io.NopCloser(bytes.NewReader(v)), + DebugBytes: v, + Header: header, + }, nil + default: + return responseBody{}, errors.New("newResponseBody can only be called with io.ReadCloser or []byte") + } +} + +func WithCaptureHeader(key string, value *string) DoOption { + return DoOption{ + out: func(body *responseBody) error { + *value = body.Header.Get(key) + return nil + }, + } +} + +func WithUnmarshal(response any) DoOption { + return DoOption{ + out: func(body *responseBody) error { + if response == nil { + return nil + } + // If the destination is bytes.Buffer, write the body over there + if raw, ok := response.(*io.ReadCloser); ok { + *raw = body.ReadCloser + return nil + } + defer body.ReadCloser.Close() + bs, err := io.ReadAll(body.ReadCloser) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + if len(bs) == 0 { + return nil + } + // If the destination is a byte slice or buffer, pass the body verbatim. + if raw, ok := response.(*[]byte); ok { + *raw = bs + return nil + } + if raw, ok := response.(*bytes.Buffer); ok { + _, err := raw.Write(bs) + return err + } + return json.Unmarshal(bs, &response) + }, + } +} diff --git a/retries/retries.go b/retries/retries.go index 77d0ae08f..73f624ddf 100644 --- a/retries/retries.go +++ b/retries/retries.go @@ -76,6 +76,18 @@ func Backoff(attempt int) time.Duration { return wait } +type ErrTimedOut struct { + err error +} + +func (et *ErrTimedOut) Error() string { + return fmt.Sprintf("timed out: %s", et.err) +} + +func (et *ErrTimedOut) Unwrap() error { + return et.err +} + func Wait(pctx context.Context, timeout time.Duration, fn WaitFn) error { ctx, cancel := context.WithTimeout(pctx, timeout) defer cancel() @@ -100,7 +112,7 @@ func Wait(pctx context.Context, timeout time.Duration, fn WaitFn) error { // stop when either this or parent context times out case <-ctx.Done(): timer.Stop() - return fmt.Errorf("timed out: %w", lastErr) + return &ErrTimedOut{lastErr} case <-timer.C: } } @@ -130,7 +142,7 @@ func Poll[T any](pctx context.Context, timeout time.Duration, fn func() (*T, *Er // stop when either this or parent context times out case <-ctx.Done(): timer.Stop() - return nil, fmt.Errorf("timed out: %w", lastErr) + return nil, &ErrTimedOut{lastErr} case <-timer.C: } } From fa7c2fb789992458e6d10435b3770abfc613cd32 Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Fri, 24 Nov 2023 13:01:07 +0100 Subject: [PATCH 2/3] refactor tests --- client/client.go | 8 +- client/client_test.go | 264 ++++++++++++++++++++++-------------------- 2 files changed, 138 insertions(+), 134 deletions(-) diff --git a/client/client.go b/client/client.go index 55cb53bc3..e60c1bdc9 100644 --- a/client/client.go +++ b/client/client.go @@ -19,10 +19,6 @@ func New(cfg *config.Config) (*DatabricksClient, error) { if err != nil { return nil, err } - return newWithTransport(cfg, cfg.HTTPTransport), nil -} - -func newWithTransport(cfg *config.Config, transport http.RoundTripper) *DatabricksClient { retryTimeout := time.Duration(orDefault(cfg.RetryTimeoutSeconds, 300)) * time.Second httpTimeout := time.Duration(orDefault(cfg.HTTPTimeoutSeconds, 60)) * time.Second return &DatabricksClient{ @@ -34,7 +30,7 @@ func newWithTransport(cfg *config.Config, transport http.RoundTripper) *Databric DebugHeaders: cfg.DebugHeaders, DebugTruncateBytes: cfg.DebugTruncateBytes, InsecureSkipVerify: cfg.InsecureSkipVerify, - Transport: transport, + Transport: cfg.HTTPTransport, Visitors: []httpclient.RequestVisitor{ cfg.Authenticate, func(r *http.Request) error { @@ -87,7 +83,7 @@ func newWithTransport(cfg *config.Config, transport http.RoundTripper) *Databric return false }, }), - } + }, nil } type DatabricksClient struct { diff --git a/client/client_test.go b/client/client_test.go index 843f01163..3660dec43 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -12,7 +12,6 @@ import ( "github.com/databricks/databricks-sdk-go/apierr" "github.com/databricks/databricks-sdk-go/config" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -22,107 +21,106 @@ func (cb hc) RoundTrip(r *http.Request) (*http.Response, error) { return cb(r) } -func TestNew(t *testing.T) { - _, err := New(&config.Config{ - ConfigFile: "/dev/null", - }) - assert.NoError(t, err) -} - func TestSimpleRequestFailsURLError(t *testing.T) { - cfg := config.NewMockConfig(func(r *http.Request) error { - r.Header.Add("Authenticated", "yes") - return nil + c, err := New(&config.Config{ + Host: "some", + Token: "token", + ConfigFile: "/dev/null", + RetryTimeoutSeconds: 1, + HTTPTransport: hc(func(r *http.Request) (*http.Response, error) { + require.Equal(t, "GET", r.Method) + require.Equal(t, "/a/b", r.URL.Path) + require.Equal(t, "c=d", r.URL.RawQuery) + require.Equal(t, "f", r.Header.Get("e")) + auth := r.Header.Get("Authorization") + require.Equal(t, "Bearer token", auth) + return nil, fmt.Errorf("nope") + }), }) - cfg.RetryTimeoutSeconds = 1 - c := newWithTransport(cfg, hc(func(r *http.Request) (*http.Response, error) { - assert.Equal(t, "GET", r.Method) - assert.Equal(t, "/a/b", r.URL.Path) - assert.Equal(t, "c=d", r.URL.RawQuery) - assert.Equal(t, "f", r.Header.Get("e")) - auth := r.Header.Get("Authenticated") - assert.Equal(t, "yes", auth) - return nil, fmt.Errorf("nope") - })) - err := c.Do(context.Background(), "GET", "/a/b", map[string]string{ + require.NoError(t, err) + err = c.Do(context.Background(), "GET", "/a/b", map[string]string{ "e": "f", }, map[string]string{ "c": "d", }, nil) - assert.EqualError(t, err, `Get "/a/b?c=d": nope`) + require.EqualError(t, err, `Get "https://some/a/b?c=d": nope`) } func TestSimpleRequestFailsAPIError(t *testing.T) { - c := *newWithTransport(config.NewMockConfig(func(r *http.Request) error { - r.Header.Add("Authenticated", "yes") - return nil - }), hc(func(r *http.Request) (*http.Response, error) { - assert.Equal(t, "GET", r.Method) - assert.Equal(t, "/a/b", r.URL.Path) - assert.Equal(t, "c=d", r.URL.RawQuery) - assert.Equal(t, "f", r.Header.Get("e")) - auth := r.Header.Get("Authenticated") - assert.Equal(t, "yes", auth) - return &http.Response{ - StatusCode: 400, - Request: r, - Body: io.NopCloser(strings.NewReader(`{"error_code": "INVALID_PARAMETER_VALUE", "message": "nope"}`)), - }, nil - })) - err := c.Do(context.Background(), "GET", "/a/b", map[string]string{ + c, err := New(&config.Config{ + Host: "some", + Token: "token", + ConfigFile: "/dev/null", + RetryTimeoutSeconds: 1, + HTTPTransport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 400, + Request: r, + Body: io.NopCloser(strings.NewReader(`{"error_code": "INVALID_PARAMETER_VALUE", "message": "nope"}`)), + }, nil + }), + }) + require.NoError(t, err) + err = c.Do(context.Background(), "GET", "/a/b", map[string]string{ "e": "f", }, map[string]string{ "c": "d", }, nil) - assert.EqualError(t, err, "nope") + require.EqualError(t, err, "nope") + require.ErrorIs(t, err, apierr.ErrInvalidParameterValue) } func TestETag(t *testing.T) { reason := "some_reason" domain := "a_domain" eTag := "sample_etag" - c := newWithTransport(config.NewMockConfig(func(r *http.Request) error { - return nil - }), hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 400, - Request: r, - Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{ - "error_code": "RESOURCE_CONFLICT", - "message": "test_public_workspace_setting", - "stack_trace": "java.io.PrintWriter@329e4ed3", - "details": [ - { - "@type": "%s", - "reason": "%s", - "domain": "%s", - "metadata": { - "etag": "%s" - } - }, - { - "@type": "anotherType", - "reason": "", - "domain": "", - "metadata": { - "etag": "anotherTag" - } - } - ] - }`, "type.googleapis.com/google.rpc.ErrorInfo", reason, domain, eTag))), - }, nil - })) - err := c.Do(context.Background(), "GET", "/a/b", map[string]string{ + c, err := New(&config.Config{ + Host: "some", + Token: "token", + ConfigFile: "/dev/null", + RetryTimeoutSeconds: 1, + HTTPTransport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 400, + Request: r, + Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{ + "error_code": "RESOURCE_CONFLICT", + "message": "test_public_workspace_setting", + "stack_trace": "java.io.PrintWriter@329e4ed3", + "details": [ + { + "@type": "%s", + "reason": "%s", + "domain": "%s", + "metadata": { + "etag": "%s" + } + }, + { + "@type": "anotherType", + "reason": "", + "domain": "", + "metadata": { + "etag": "anotherTag" + } + } + ] + }`, "type.googleapis.com/google.rpc.ErrorInfo", reason, domain, eTag))), + }, nil + }), + }) + require.NoError(t, err) + err = c.Do(context.Background(), "GET", "/a/b", map[string]string{ "e": "f", }, map[string]string{ "c": "d", }, nil) details := apierr.GetErrorInfo(err) - assert.Equal(t, 1, len(details)) + require.Equal(t, 1, len(details)) errorDetails := details[0] - assert.Equal(t, reason, errorDetails.Reason) - assert.Equal(t, domain, errorDetails.Domain) - assert.Equal(t, map[string]string{ + require.Equal(t, reason, errorDetails.Reason) + require.Equal(t, domain, errorDetails.Domain) + require.Equal(t, map[string]string{ "etag": eTag, }, errorDetails.Metadata) } @@ -131,19 +129,23 @@ func TestSimpleRequestSucceeds(t *testing.T) { type Dummy struct { Foo int `json:"foo"` } - c := newWithTransport(config.NewMockConfig(func(r *http.Request) error { - return nil - }), hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(`{"foo": 2}`)), - Request: r, - }, nil - })) + c, err := New(&config.Config{ + Host: "some", + Token: "token", + ConfigFile: "/dev/null", + HTTPTransport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"foo": 2}`)), + Request: r, + }, nil + }), + }) + require.NoError(t, err) var resp Dummy - err := c.Do(context.Background(), "POST", "/c", nil, Dummy{1}, &resp) - assert.NoError(t, err) - assert.Equal(t, 2, resp.Foo) + err = c.Do(context.Background(), "POST", "/c", nil, Dummy{1}, &resp) + require.NoError(t, err) + require.Equal(t, 2, resp.Foo) } func TestSimpleRequestRetried(t *testing.T) { @@ -151,49 +153,56 @@ func TestSimpleRequestRetried(t *testing.T) { Foo int `json:"foo"` } var retried [1]bool - c := newWithTransport(config.NewMockConfig(func(r *http.Request) error { - return nil - }), hc(func(r *http.Request) (*http.Response, error) { - if !retried[0] { - retried[0] = true - return nil, &url.Error{ - Op: "open", - URL: "/a/b", - Err: fmt.Errorf("connection refused"), + c, err := New(&config.Config{ + Host: "some", + Token: "token", + ConfigFile: "/dev/null", + HTTPTransport: hc(func(r *http.Request) (*http.Response, error) { + if !retried[0] { + retried[0] = true + return nil, &url.Error{ + Op: "open", + URL: "/a/b", + Err: fmt.Errorf("connection refused"), + } } - } - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader(`{"foo": 2}`)), - Request: r, - }, nil - })) + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"foo": 2}`)), + Request: r, + }, nil + }), + }) + require.NoError(t, err) var resp Dummy - err := c.Do(context.Background(), "PATCH", "/a", nil, Dummy{1}, &resp) - assert.NoError(t, err) - assert.Equal(t, 2, resp.Foo) - assert.True(t, retried[0], "request was not retried") + err = c.Do(context.Background(), "PATCH", "/a", nil, Dummy{1}, &resp) + require.NoError(t, err) + require.Equal(t, 2, resp.Foo) + require.True(t, retried[0], "request was not retried") } func TestSimpleRequestAPIError(t *testing.T) { - c := newWithTransport(config.NewMockConfig(func(r *http.Request) error { - return nil - }), hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 404, - Body: io.NopCloser(strings.NewReader(`{ - "error_code": "NOT_FOUND", - "message": "Something was not found" - }`)), - Request: r, - }, nil - })) - err := c.Do(context.Background(), "PATCH", "/a", nil, map[string]any{}, nil) + c, err := New(&config.Config{ + Host: "some", + Token: "token", + ConfigFile: "/dev/null", + HTTPTransport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 404, + Body: io.NopCloser(strings.NewReader(`{ + "error_code": "NOT_FOUND", + "message": "Something was not found" + }`)), + Request: r, + }, nil + }), + }) + require.NoError(t, err) + err = c.Do(context.Background(), "PATCH", "/a", nil, map[string]any{}, nil) var aerr *apierr.APIError - if assert.ErrorAs(t, err, &aerr) { - assert.Equal(t, "NOT_FOUND", aerr.ErrorCode) - } - assert.ErrorIs(t, err, apierr.ErrNotFound) + require.ErrorAs(t, err, &aerr) + require.Equal(t, "NOT_FOUND", aerr.ErrorCode) + require.ErrorIs(t, err, apierr.ErrNotFound) } func TestHttpTransport(t *testing.T) { @@ -208,6 +217,5 @@ func TestHttpTransport(t *testing.T) { err = client.Do(context.Background(), "GET", "/a", nil, nil, bytes.Buffer{}) require.NoError(t, err) - - assert.True(t, calledMock) + require.True(t, calledMock) } From c3188f42c1913f1f1fc30d8bac82fc9944418809 Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Fri, 24 Nov 2023 13:49:04 +0100 Subject: [PATCH 3/3] cleanup --- client/client.go | 8 +- config/azure.go | 2 +- httpclient/api_client.go | 14 ++- httpclient/api_client_test.go | 231 ++++++++++++---------------------- httpclient/errors_test.go | 47 +++++++ httpclient/request.go | 63 +++++----- httpclient/request_test.go | 112 +++++++++++++++++ httpclient/response.go | 10 +- httpclient/response_test.go | 49 ++++++++ 9 files changed, 340 insertions(+), 196 deletions(-) create mode 100644 httpclient/errors_test.go create mode 100644 httpclient/request_test.go create mode 100644 httpclient/response_test.go diff --git a/client/client.go b/client/client.go index e60c1bdc9..aa86b4106 100644 --- a/client/client.go +++ b/client/client.go @@ -103,11 +103,11 @@ func (c *DatabricksClient) Do(ctx context.Context, method, path string, visitors ...func(*http.Request) error) error { opts := []httpclient.DoOption{} for _, v := range visitors { - opts = append(opts, httpclient.WithVisitor(v)) + opts = append(opts, httpclient.WithRequestVisitor(v)) } - opts = append(opts, httpclient.WithHeaders(headers)) - opts = append(opts, httpclient.WithData(request)) - opts = append(opts, httpclient.WithUnmarshal(response)) + opts = append(opts, httpclient.WithRequestHeaders(headers)) + opts = append(opts, httpclient.WithRequestData(request)) + opts = append(opts, httpclient.WithResponseUnmarshal(response)) return c.client.Do(ctx, method, path, opts...) } diff --git a/config/azure.go b/config/azure.go index 0b61eefa9..aa93f8140 100644 --- a/config/azure.go +++ b/config/azure.go @@ -87,7 +87,7 @@ func (c *Config) azureEnsureWorkspaceUrl(ctx context.Context, ahr azureHostResol } requestURL := env.ResourceManagerEndpoint + c.AzureResourceID + "?api-version=2018-04-01" err = httpclient.DefaultClient.Do(ctx, "GET", requestURL, - httpclient.WithUnmarshal(&workspaceMetadata), + httpclient.WithResponseUnmarshal(&workspaceMetadata), httpclient.WithTokenSource(management), ) if err != nil { diff --git a/httpclient/api_client.go b/httpclient/api_client.go index d25447205..468c1ce49 100644 --- a/httpclient/api_client.go +++ b/httpclient/api_client.go @@ -1,4 +1,4 @@ -package httpclient // has to be a separate package than client, otherwise a circular dependency +package httpclient import ( "context" @@ -145,16 +145,20 @@ func (c *ApiClient) fromResponse(r *http.Response) (responseBody, error) { if r.Request == nil { return responseBody{}, fmt.Errorf("nil request") } + // SDK only supports using JSON for non-streaming requests/responses, as that + // is the only supported serde in the SDK. If you need to use any other content + // type, the SDK will just hand you an io.ReadCloser and you will be responsible + // for consuming the request body yourself. streamResponse := r.Request.Header.Get("Accept") != "application/json" && r.Header.Get("Content-Type") != "application/json" if streamResponse { - return newResponseBody(r.Body, r.Header) + return newResponseBody(r.Body, r.Header, r.StatusCode, r.Status) } defer r.Body.Close() bs, err := io.ReadAll(r.Body) if err != nil { return responseBody{}, fmt.Errorf("response body: %w", err) } - return newResponseBody(bs, r.Header) + return newResponseBody(bs, r.Header, r.StatusCode, r.Status) } func (c *ApiClient) redactedDump(prefix string, body []byte) (res string) { @@ -330,9 +334,9 @@ func (c *ApiClient) RoundTrip(request *http.Request) (*http.Response, error) { // here we assume only successful responses, as HTTP 4XX and 5XX are mapped // to Go's error implementations. return &http.Response{ - Status: "OK", - StatusCode: 200, Request: request, + Status: resp.Status, + StatusCode: resp.StatusCode, Header: resp.Header, Body: resp.ReadCloser, }, nil diff --git a/httpclient/api_client_test.go b/httpclient/api_client_test.go index 1b943b985..83819b585 100644 --- a/httpclient/api_client_test.go +++ b/httpclient/api_client_test.go @@ -13,8 +13,8 @@ import ( "time" "github.com/databricks/databricks-sdk-go/logger" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" "golang.org/x/time/rate" ) @@ -56,11 +56,12 @@ func TestSimpleRequestFailsURLError(t *testing.T) { return nil, fmt.Errorf("nope") }), }) - err := c.Do(context.Background(), "GET", "/a/b", WithHeaders(map[string]string{ - "e": "f", - }), WithData(map[string]string{ - "c": "d", - })) + err := c.Do(context.Background(), "GET", "/a/b", + WithRequestHeaders(map[string]string{ + "e": "f", + }), WithRequestData(map[string]string{ + "c": "d", + })) require.EqualError(t, err, `Get "/a/b?c=d": nope`) } @@ -78,11 +79,12 @@ func TestSimpleRequestFailsAPIError(t *testing.T) { }, nil }), }) - err := c.Do(context.Background(), "GET", "/a/b", WithHeaders(map[string]string{ - "e": "f", - }), WithData(map[string]string{ - "c": "d", - })) + err := c.Do(context.Background(), "GET", "/a/b", + WithRequestHeaders(map[string]string{ + "e": "f", + }), WithRequestData(map[string]string{ + "c": "d", + })) require.EqualError(t, err, "http 400: nope") } @@ -100,7 +102,9 @@ func TestSimpleRequestSucceeds(t *testing.T) { }), }) var resp Dummy - err := c.Do(context.Background(), "POST", "/c", WithData(Dummy{1}), WithUnmarshal(&resp)) + err := c.Do(context.Background(), "POST", "/c", + WithRequestData(Dummy{1}), + WithResponseUnmarshal(&resp)) require.NoError(t, err) require.Equal(t, 2, resp.Foo) } @@ -128,7 +132,9 @@ func TestSimpleRequestRetried(t *testing.T) { }), }) var resp Dummy - err := c.Do(context.Background(), "PATCH", "/a", WithData(Dummy{1}), WithUnmarshal(&resp)) + err := c.Do(context.Background(), "PATCH", "/a", + WithRequestData(Dummy{1}), + WithResponseUnmarshal(&resp)) require.NoError(t, err) require.Equal(t, 2, resp.Foo) require.True(t, retried[0], "request was not retried") @@ -173,75 +179,6 @@ func TestHaltAttemptForVisitor(t *testing.T) { require.EqualError(t, rerr.Err, "🥱") } -func TestMakeRequestBody(t *testing.T) { - type x struct { - Scope string `json:"scope" url:"scope"` - } - requestURL := "/a/b/c" - body, err := makeRequestBody("GET", &requestURL, x{"test"}) - require.NoError(t, err) - bodyBytes, err := io.ReadAll(body.Reader) - require.NoError(t, err) - require.Equal(t, "/a/b/c?scope=test", requestURL) - require.Equal(t, 0, len(bodyBytes)) - - requestURL = "/a/b/c" - body, err = makeRequestBody("POST", &requestURL, x{"test"}) - require.NoError(t, err) - bodyBytes, err = io.ReadAll(body.Reader) - require.NoError(t, err) - require.Equal(t, "/a/b/c", requestURL) - x1 := `{"scope":"test"}` - require.Equal(t, []byte(x1), bodyBytes) -} - -func TestMakeRequestBodyFromReader(t *testing.T) { - requestURL := "/a/b/c" - body, err := makeRequestBody("PUT", &requestURL, strings.NewReader("abc")) - require.NoError(t, err) - bodyBytes, err := io.ReadAll(body.Reader) - require.NoError(t, err) - require.Equal(t, []byte("abc"), bodyBytes) -} - -func TestMakeRequestBodyReaderError(t *testing.T) { - requestURL := "/a/b/c" - _, err := makeRequestBody("POST", &requestURL, errReader(false)) - // The request body is only read once the request is sent, so no error - // should be returned until then. - require.NoError(t, err, "request body reader error should be ignored") -} - -func TestMakeRequestBodyJsonError(t *testing.T) { - requestURL := "/a/b/c" - type x struct { - Foo chan string `json:"foo"` - } - _, err := makeRequestBody("POST", &requestURL, x{make(chan string)}) - require.EqualError(t, err, "request marshal failure: json: unsupported type: chan string") -} - -type failingUrlEncode string - -func (fue failingUrlEncode) EncodeValues(key string, v *url.Values) error { - return fmt.Errorf(string(fue)) -} - -func TestMakeRequestBodyQueryFailingEncode(t *testing.T) { - requestURL := "/a/b/c" - type x struct { - Foo failingUrlEncode `url:"foo"` - } - _, err := makeRequestBody("GET", &requestURL, x{failingUrlEncode("always failing")}) - require.EqualError(t, err, "cannot create query string: always failing") -} - -func TestMakeRequestBodyQueryUnsupported(t *testing.T) { - requestURL := "/a/b/c" - _, err := makeRequestBody("GET", &requestURL, true) - require.EqualError(t, err, "unsupported query string data: true") -} - func TestFailPerformChannel(t *testing.T) { ctx := context.Background() c := &ApiClient{ @@ -251,41 +188,6 @@ func TestFailPerformChannel(t *testing.T) { require.EqualError(t, err, "request marshal: unsupported query string data: true") } -func TestSimpleRequestAPIError(t *testing.T) { - c := NewApiClient(ClientConfig{ - Transport: hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 400, - Body: io.NopCloser(strings.NewReader(`{ - "error_code": "NOT_FOUND", - "message": "Something was not found" - }`)), - Request: r, - }, nil - }), - }) - err := c.Do(context.Background(), "PATCH", "/a", WithData(map[string]any{})) - var httpErr *HttpError - if assert.ErrorAs(t, err, &httpErr) { - require.Equal(t, 400, httpErr.StatusCode) - } -} - -func TestSimpleRequestErrReaderBody(t *testing.T) { - c := NewApiClient(ClientConfig{ - Transport: hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: errReader(false), - Request: r, - }, nil - }), - }) - headers := map[string]string{"Accept": "application/json"} - err := c.Do(context.Background(), "PATCH", "/a", WithHeaders(headers), WithData(map[string]any{})) - require.EqualError(t, err, "response body: test error") -} - func TestSimpleRequestErrReaderBodyStreamResponse(t *testing.T) { c := NewApiClient(ClientConfig{ Transport: hc(func(r *http.Request) (*http.Response, error) { @@ -297,7 +199,9 @@ func TestSimpleRequestErrReaderBodyStreamResponse(t *testing.T) { }), }) headers := map[string]string{"Accept": "application/octet-stream"} - err := c.Do(context.Background(), "PATCH", "/a", WithHeaders(headers), WithData(map[string]any{})) + err := c.Do(context.Background(), "PATCH", "/a", + WithRequestHeaders(headers), + WithRequestData(map[string]any{})) require.NoError(t, err, "streaming response bodies are not read") } @@ -312,7 +216,9 @@ func TestSimpleRequestErrReaderCloseBody(t *testing.T) { }), }) headers := map[string]string{"Accept": "application/json"} - err := c.Do(context.Background(), "PATCH", "/a", WithHeaders(headers), WithData(map[string]any{})) + err := c.Do(context.Background(), "PATCH", "/a", + WithRequestHeaders(headers), + WithRequestData(map[string]any{})) require.EqualError(t, err, "response body: test error") } @@ -327,26 +233,12 @@ func TestSimpleRequestErrReaderCloseBody_StreamResponse(t *testing.T) { }), }) headers := map[string]string{"Accept": "application/octet-stream"} - err := c.Do(context.Background(), "PATCH", "/a", WithHeaders(headers), WithData(map[string]any{})) + err := c.Do(context.Background(), "PATCH", "/a", + WithRequestHeaders(headers), + WithRequestData(map[string]any{})) require.NoError(t, err, "response body should not be closed for streaming responses") } -func TestSimpleRequestRawResponse(t *testing.T) { - c := NewApiClient(ClientConfig{ - Transport: hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader("Hello, world!")), - Request: r, - }, nil - }), - }) - var raw []byte - err := c.Do(context.Background(), "GET", "/a", WithUnmarshal(&raw)) - require.NoError(t, err) - require.Equal(t, "Hello, world!", string(raw)) -} - type BufferLogger struct { strings.Builder } @@ -408,11 +300,12 @@ func TestSimpleResponseRedaction(t *testing.T) { }, nil }), }) - err := c.Do(context.Background(), "GET", "/a", WithData(map[string]any{ - "b": 0, - "a": 3, - "c": 23, - })) + err := c.Do(context.Background(), "GET", "/a", + WithRequestData(map[string]any{ + "b": 0, + "a": 3, + "c": 23, + })) require.NoError(t, err) // not testing for exact logged lines, as header order is not deterministic require.NotContains(t, bufLogger.String(), "__SENSITIVE01__") @@ -443,11 +336,13 @@ func TestInlineArrayDebugging(t *testing.T) { }), }) headers := map[string]string{"Accept": "application/json"} - err := c.Do(context.Background(), "GET", "/a", WithHeaders(headers), WithData(map[string]any{ - "b": 0, - "a": 3, - "c": 23, - })) + err := c.Do(context.Background(), "GET", "/a", + WithRequestHeaders(headers), + WithRequestData(map[string]any{ + "b": 0, + "a": 3, + "c": 23, + })) require.NoError(t, err) require.Equal(t, `[DEBUG] GET /a?a=3&b=0&c=23 @@ -478,11 +373,13 @@ func TestInlineArrayDebugging_StreamResponse(t *testing.T) { }), }) headers := map[string]string{"Accept": "application/octet-stream"} - err := c.Do(context.Background(), "GET", "/a", WithHeaders(headers), WithData(map[string]any{ - "b": 0, - "a": 3, - "c": 23, - })) + err := c.Do(context.Background(), "GET", "/a", + WithRequestHeaders(headers), + WithRequestData(map[string]any{ + "b": 0, + "a": 3, + "c": 23, + })) require.NoError(t, err) require.Equal(t, `[DEBUG] GET /a?a=3&b=0&c=23 @@ -529,7 +426,9 @@ func TestStreamRequestFromFileWithReset(t *testing.T) { }) respBytes := bytes.Buffer{} - err = client.Do(context.Background(), "POST", "/a", WithData(r), WithUnmarshal(&respBytes)) + err = client.Do(context.Background(), "POST", "/a", + WithRequestData(r), + WithResponseUnmarshal(&respBytes)) require.NoError(t, err) require.Equal(t, "succeeded", respBytes.String()) require.True(t, succeed) @@ -551,7 +450,8 @@ func TestCannotRetryArbitraryReader(t *testing.T) { }, nil }), }) - err := client.Do(context.Background(), "POST", "/a", WithData(customReader{})) + err := client.Do(context.Background(), "POST", "/a", + WithRequestData(customReader{})) require.ErrorContains(t, err, "cannot reset reader of type httpclient.customReader") } @@ -582,8 +482,31 @@ func TestRetryGetRequest(t *testing.T) { }) respBytes := bytes.Buffer{} - err := client.Do(context.Background(), "GET", "/a", WithUnmarshal(&respBytes)) + err := client.Do(context.Background(), "GET", "/a", + WithResponseUnmarshal(&respBytes)) require.NoError(t, err) require.Equal(t, "succeeded", respBytes.String()) require.True(t, succeed) } + +func TestOAuth2Integration(t *testing.T) { + inner := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 204, + Request: r, + Body: io.NopCloser(strings.NewReader("")), + }, nil + }), + }) + + ctx := context.Background() + ctx = inner.InContextForOAuth2(ctx) + + outer := oauth2.NewClient(ctx, oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: "abc", + })) + res, err := outer.Get("abc") + require.NoError(t, err) + require.Equal(t, 204, res.StatusCode) +} diff --git a/httpclient/errors_test.go b/httpclient/errors_test.go new file mode 100644 index 000000000..a9227ced5 --- /dev/null +++ b/httpclient/errors_test.go @@ -0,0 +1,47 @@ +package httpclient + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSimpleRequestAPIError(t *testing.T) { + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 400, + Body: io.NopCloser(strings.NewReader(`{ + "error_code": "NOT_FOUND", + "message": "Something was not found" + }`)), + Request: r, + }, nil + }), + }) + err := c.Do(context.Background(), "PATCH", "/a", WithRequestData(map[string]any{})) + var httpErr *HttpError + if assert.ErrorAs(t, err, &httpErr) { + require.Equal(t, 400, httpErr.StatusCode) + } +} + +func TestSimpleRequestErrReaderBody(t *testing.T) { + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: errReader(false), + Request: r, + }, nil + }), + }) + headers := map[string]string{"Accept": "application/json"} + err := c.Do(context.Background(), "PATCH", "/a", WithRequestHeaders(headers), WithRequestData(map[string]any{})) + require.EqualError(t, err, "response body: test error") +} diff --git a/httpclient/request.go b/httpclient/request.go index 0684bd490..88b4eb945 100644 --- a/httpclient/request.go +++ b/httpclient/request.go @@ -75,47 +75,52 @@ func (r requestBody) reset() error { } } -func WithHeader(k, v string) DoOption { - return DoOption{ - in: func(r *http.Request) error { - r.Header.Set(k, v) - return nil - }, - } +// WithRequestHeader adds a request visitor, that sets a header on a request +func WithRequestHeader(k, v string) DoOption { + return WithRequestVisitor(func(r *http.Request) error { + r.Header.Set(k, v) + return nil + }) } -func WithHeaders(headers map[string]string) DoOption { - return DoOption{ - in: func(r *http.Request) error { - for k, v := range headers { - r.Header.Set(k, v) - } - return nil - }, - } +// WithRequestHeaders adds a request visitor, that set all headers from a map +func WithRequestHeaders(headers map[string]string) DoOption { + return WithRequestVisitor(func(r *http.Request) error { + for k, v := range headers { + r.Header.Set(k, v) + } + return nil + }) } +// WithTokenSource uses the specified golang.org/x/oauth2 token source on a request func WithTokenSource(ts oauth2.TokenSource) DoOption { - return DoOption{ - in: func(r *http.Request) error { - token, err := ts.Token() - if err != nil { - return fmt.Errorf("token: %w", err) - } - auth := fmt.Sprintf("%s %s", token.TokenType, token.AccessToken) - r.Header.Set("Authorization", auth) - return nil - }, - } + return WithRequestVisitor(func(r *http.Request) error { + token, err := ts.Token() + if err != nil { + return fmt.Errorf("token: %w", err) + } + auth := fmt.Sprintf("%s %s", token.TokenType, token.AccessToken) + r.Header.Set("Authorization", auth) + return nil + }) } -func WithVisitor(visitor func(r *http.Request) error) DoOption { +// WithRequestVisitor applies given function on a request +func WithRequestVisitor(visitor func(r *http.Request) error) DoOption { return DoOption{ in: visitor, } } -func WithData(body any) DoOption { +// WithRequestData takes either a struct instance, map, string, bytes, or io.Reader +// and sends it either as query string for GET and DELETE calls, or as request body +// for POST, PUT, and PATCH calls. +// +// Experimental: this method may eventually be split into more granular options. +func WithRequestData(body any) DoOption { + // refactor this, so that we split JSON/query string serialization and make + // separate request visitors internally. return DoOption{ body: body, } diff --git a/httpclient/request_test.go b/httpclient/request_test.go new file mode 100644 index 000000000..c65d56dbe --- /dev/null +++ b/httpclient/request_test.go @@ -0,0 +1,112 @@ +package httpclient + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestMakeRequestBody(t *testing.T) { + type x struct { + Scope string `json:"scope" url:"scope"` + } + requestURL := "/a/b/c" + body, err := makeRequestBody("GET", &requestURL, x{"test"}) + require.NoError(t, err) + bodyBytes, err := io.ReadAll(body.Reader) + require.NoError(t, err) + require.Equal(t, "/a/b/c?scope=test", requestURL) + require.Equal(t, 0, len(bodyBytes)) + + requestURL = "/a/b/c" + body, err = makeRequestBody("POST", &requestURL, x{"test"}) + require.NoError(t, err) + bodyBytes, err = io.ReadAll(body.Reader) + require.NoError(t, err) + require.Equal(t, "/a/b/c", requestURL) + x1 := `{"scope":"test"}` + require.Equal(t, []byte(x1), bodyBytes) +} + +func TestMakeRequestBodyFromReader(t *testing.T) { + requestURL := "/a/b/c" + body, err := makeRequestBody("PUT", &requestURL, strings.NewReader("abc")) + require.NoError(t, err) + bodyBytes, err := io.ReadAll(body.Reader) + require.NoError(t, err) + require.Equal(t, []byte("abc"), bodyBytes) +} + +func TestMakeRequestBodyReaderError(t *testing.T) { + requestURL := "/a/b/c" + _, err := makeRequestBody("POST", &requestURL, errReader(false)) + // The request body is only read once the request is sent, so no error + // should be returned until then. + require.NoError(t, err, "request body reader error should be ignored") +} + +func TestMakeRequestBodyJsonError(t *testing.T) { + requestURL := "/a/b/c" + type x struct { + Foo chan string `json:"foo"` + } + _, err := makeRequestBody("POST", &requestURL, x{make(chan string)}) + require.EqualError(t, err, "request marshal failure: json: unsupported type: chan string") +} + +type failingUrlEncode string + +func (fue failingUrlEncode) EncodeValues(key string, v *url.Values) error { + return fmt.Errorf(string(fue)) +} + +func TestMakeRequestBodyQueryFailingEncode(t *testing.T) { + requestURL := "/a/b/c" + type x struct { + Foo failingUrlEncode `url:"foo"` + } + _, err := makeRequestBody("GET", &requestURL, x{failingUrlEncode("always failing")}) + require.EqualError(t, err, "cannot create query string: always failing") +} + +func TestMakeRequestBodyQueryUnsupported(t *testing.T) { + requestURL := "/a/b/c" + _, err := makeRequestBody("GET", &requestURL, true) + require.EqualError(t, err, "unsupported query string data: true") +} + +func TestWithTokenSource(t *testing.T) { + client := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + foo := r.Header.Get("Foo") + require.Equal(t, "bar", foo) + token := r.Header.Get("Authorization") + reader := strings.NewReader(token) + return &http.Response{ + StatusCode: 204, + Request: r, + Body: io.NopCloser(reader), + }, nil + }), + }) + + var buf bytes.Buffer + ctx := context.Background() + err := client.Do(ctx, "GET", "abc", + WithResponseUnmarshal(&buf), + WithRequestHeader("Foo", "bar"), + WithTokenSource(oauth2.StaticTokenSource(&oauth2.Token{ + TokenType: "awesome", + AccessToken: "token", + }))) + require.NoError(t, err) + require.Equal(t, "awesome token", buf.String()) +} diff --git a/httpclient/response.go b/httpclient/response.go index 54b625b47..c4321614c 100644 --- a/httpclient/response.go +++ b/httpclient/response.go @@ -19,15 +19,19 @@ type responseBody struct { ReadCloser io.ReadCloser DebugBytes []byte Header http.Header + Status string + StatusCode int } -func newResponseBody(data any, header http.Header) (responseBody, error) { +func newResponseBody(data any, header http.Header, statusCode int, status string) (responseBody, error) { switch v := data.(type) { case io.ReadCloser: return responseBody{ ReadCloser: v, DebugBytes: []byte(""), Header: header, + StatusCode: statusCode, + Status: status, }, nil case []byte: return responseBody{ @@ -40,7 +44,7 @@ func newResponseBody(data any, header http.Header) (responseBody, error) { } } -func WithCaptureHeader(key string, value *string) DoOption { +func WithResponseHeader(key string, value *string) DoOption { return DoOption{ out: func(body *responseBody) error { *value = body.Header.Get(key) @@ -49,7 +53,7 @@ func WithCaptureHeader(key string, value *string) DoOption { } } -func WithUnmarshal(response any) DoOption { +func WithResponseUnmarshal(response any) DoOption { return DoOption{ out: func(body *responseBody) error { if response == nil { diff --git a/httpclient/response_test.go b/httpclient/response_test.go new file mode 100644 index 000000000..e4aab5798 --- /dev/null +++ b/httpclient/response_test.go @@ -0,0 +1,49 @@ +package httpclient + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSimpleRequestRawResponse(t *testing.T) { + c := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("Hello, world!")), + Request: r, + }, nil + }), + }) + var raw []byte + err := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&raw)) + require.NoError(t, err) + require.Equal(t, "Hello, world!", string(raw)) +} + +func TestWithResponseHeader(t *testing.T) { + client := NewApiClient(ClientConfig{ + Transport: hc(func(r *http.Request) (*http.Response, error) { + return &http.Response{ + Request: r, + StatusCode: 204, + Header: http.Header{ + "Foo": []string{"some"}, + }, + Body: io.NopCloser(strings.NewReader("")), + }, nil + }), + }) + + var out string + ctx := context.Background() + err := client.Do(ctx, "GET", "abc", + WithResponseHeader("Foo", &out)) + require.NoError(t, err) + require.Equal(t, "some", out) +}