diff --git a/apierr/errors.go b/apierr/errors.go index d2dc9e1c3..38925a314 100644 --- a/apierr/errors.go +++ b/apierr/errors.go @@ -223,7 +223,12 @@ func parseUnknownError(resp *http.Response, requestBody, responseBody []byte, er } func MakeUnexpectedError(resp *http.Response, err error, requestBody, responseBody []byte) error { + var req *http.Request + if resp != nil { + req = resp.Request + } rts := httplog.RoundTripStringer{ + Request: req, Response: resp, Err: err, RequestBody: requestBody, diff --git a/httpclient/api_client.go b/httpclient/api_client.go index ad4c4c63b..ff3163fca 100644 --- a/httpclient/api_client.go +++ b/httpclient/api_client.go @@ -194,7 +194,7 @@ func (c *ApiClient) isRetriable(ctx context.Context, err error) bool { // If it is certain that an error should not be retried, use failRequest() instead. func (c *ApiClient) handleError(ctx context.Context, err error, body common.RequestBody) (*common.ResponseWrapper, *retries.Err) { if !c.isRetriable(ctx, err) { - return c.failRequest(ctx, "non-retriable error", err) + return nil, retries.Halt(err) } if resetErr := body.Reset(); resetErr != nil { return nil, retries.Halt(resetErr) @@ -203,8 +203,8 @@ func (c *ApiClient) handleError(ctx context.Context, err error, body common.Requ } // Fails the request with a retries.Err to halt future retries. -func (c *ApiClient) failRequest(ctx context.Context, msg string, err error) (*common.ResponseWrapper, *retries.Err) { - logger.Debugf(ctx, "%s: %s", msg, err) +func (c *ApiClient) failRequest(msg string, err error) (*common.ResponseWrapper, *retries.Err) { + err = fmt.Errorf("%s: %w", msg, err) return nil, retries.Halt(err) } @@ -218,7 +218,7 @@ func (c *ApiClient) attempt( return func() (*common.ResponseWrapper, *retries.Err) { err := c.rateLimiter.Wait(ctx) if err != nil { - return c.failRequest(ctx, "failed in rate limiter", err) + return c.failRequest("failed in rate limiter", err) } pctx := ctx @@ -229,12 +229,12 @@ func (c *ApiClient) attempt( ctx, ticker := newTimeoutContext(pctx, c.config.HTTPTimeout) request, err := http.NewRequestWithContext(ctx, method, requestURL, requestBody.Reader) if err != nil { - return c.failRequest(ctx, "failed creating new request", err) + return c.failRequest("failed creating new request", err) } for _, requestVisitor := range visitors { err = requestVisitor(request) if err != nil { - return c.failRequest(ctx, "failed during request visitor", err) + return c.failRequest("failed during request visitor", err) } } // Set traceparent for distributed tracing. @@ -263,27 +263,28 @@ func (c *ApiClient) attempt( if pctx.Err() == nil && uerr.Err == context.Canceled { uerr.Err = fmt.Errorf("request timed out after %s of inactivity", c.config.HTTPTimeout) } - return c.handleError(ctx, err, requestBody) } // If there is a response body, wrap it to extend the request timeout while it is being read. if response != nil && response.Body != nil { response.Body = newResponseBodyTicker(ticker, response.Body) } else { - // If there is no response body, the request has completed and there - // is no need to extend the timeout. Cancel the context to clean up - // the underlying goroutine. + // If there is no response body or an error is returned, the request + // has completed and there is no need to extend the timeout. Cancel + // the context to clean up the underlying goroutine. ticker.Cancel() } // By this point, the request body has certainly been consumed. - responseWrapper, err := common.NewResponseWrapper(response, requestBody) - if err != nil { - return c.failRequest(ctx, "failed while reading response", err) + var responseWrapper common.ResponseWrapper + if err == nil { + responseWrapper, err = common.NewResponseWrapper(response, requestBody) + } + if err == nil { + err = c.config.ErrorMapper(ctx, responseWrapper) } - err = c.config.ErrorMapper(ctx, responseWrapper) - defer c.recordRequestLog(ctx, request, response, err, requestBody.DebugBytes, responseWrapper.DebugBytes) + c.recordRequestLog(ctx, request, response, err, requestBody.DebugBytes, responseWrapper.DebugBytes) if err == nil { return &responseWrapper, nil @@ -307,6 +308,7 @@ func (c *ApiClient) recordRequestLog( return } message := httplog.RoundTripStringer{ + Request: request, Response: response, Err: err, RequestBody: requestBody, diff --git a/httpclient/api_client_test.go b/httpclient/api_client_test.go index ba0ab9d8e..2c25ba556 100644 --- a/httpclient/api_client_test.go +++ b/httpclient/api_client_test.go @@ -188,7 +188,7 @@ func TestHaltAttemptForLimit(t *testing.T) { _, rerr := c.attempt(ctx, "GET", "foo", req)() require.NotNil(t, rerr) require.Equal(t, true, rerr.Halt) - require.EqualError(t, rerr.Err, "rate: Wait(n=1) exceeds limiter's burst 0") + require.EqualError(t, rerr.Err, "failed in rate limiter: rate: Wait(n=1) exceeds limiter's burst 0") } func TestHaltAttemptForNewRequest(t *testing.T) { @@ -199,7 +199,7 @@ func TestHaltAttemptForNewRequest(t *testing.T) { _, rerr := c.attempt(ctx, "🥱", "/", req)() require.NotNil(t, rerr) require.Equal(t, true, rerr.Halt) - require.EqualError(t, rerr.Err, `net/http: invalid method "🥱"`) + require.EqualError(t, rerr.Err, `failed creating new request: net/http: invalid method "🥱"`) } func TestHaltAttemptForVisitor(t *testing.T) { @@ -213,7 +213,7 @@ func TestHaltAttemptForVisitor(t *testing.T) { })() require.NotNil(t, rerr) require.Equal(t, true, rerr.Halt) - require.EqualError(t, rerr.Err, "🥱") + require.EqualError(t, rerr.Err, "failed during request visitor: 🥱") } func TestFailPerformChannel(t *testing.T) { @@ -334,32 +334,37 @@ func (l *BufferLogger) Enabled(_ context.Context, level logger.Level) bool { } func (l *BufferLogger) Tracef(_ context.Context, format string, v ...interface{}) { - l.WriteString(fmt.Sprintf("[TRACE] "+format, v...)) + l.WriteString(fmt.Sprintf("[TRACE] "+format+"\n", v...)) } func (l *BufferLogger) Debugf(_ context.Context, format string, v ...interface{}) { - l.WriteString(fmt.Sprintf("[DEBUG] "+format, v...)) + l.WriteString(fmt.Sprintf("[DEBUG] "+format+"\n", v...)) } func (l *BufferLogger) Infof(_ context.Context, format string, v ...interface{}) { - l.WriteString(fmt.Sprintf("[INFO] "+format, v...)) + l.WriteString(fmt.Sprintf("[INFO] "+format+"\n", v...)) } func (l *BufferLogger) Warnf(_ context.Context, format string, v ...interface{}) { - l.WriteString(fmt.Sprintf("[WARN] "+format, v...)) + l.WriteString(fmt.Sprintf("[WARN] "+format+"\n", v...)) } func (l *BufferLogger) Errorf(_ context.Context, format string, v ...interface{}) { - l.WriteString(fmt.Sprintf("[ERROR] "+format, v...)) + l.WriteString(fmt.Sprintf("[ERROR] "+format+"\n", v...)) } -func TestSimpleResponseRedaction(t *testing.T) { +func configureBufferedLogger(t *testing.T) *BufferLogger { prevLogger := logger.DefaultLogger bufLogger := &BufferLogger{} logger.DefaultLogger = bufLogger - defer func() { + t.Cleanup(func() { logger.DefaultLogger = prevLogger - }() + }) + return bufLogger +} + +func TestSimpleResponseRedaction(t *testing.T) { + bufLogger := configureBufferedLogger(t) c := NewApiClient(ClientConfig{ DebugTruncateBytes: 16, @@ -402,12 +407,7 @@ func TestSimpleResponseRedaction(t *testing.T) { } func TestInlineArrayDebugging(t *testing.T) { - prevLogger := logger.DefaultLogger - bufLogger := &BufferLogger{} - logger.DefaultLogger = bufLogger - defer func() { - logger.DefaultLogger = prevLogger - }() + bufLogger := configureBufferedLogger(t) c := NewApiClient(ClientConfig{ DebugTruncateBytes: 2048, @@ -437,16 +437,12 @@ func TestInlineArrayDebugging(t *testing.T) { < { < "foo": "bar" < } -< ]`, bufLogger.String()) +< ] +`, bufLogger.String()) } func TestInlineArrayDebugging_StreamResponse(t *testing.T) { - prevLogger := logger.DefaultLogger - bufLogger := &BufferLogger{} - logger.DefaultLogger = bufLogger - defer func() { - logger.DefaultLogger = prevLogger - }() + bufLogger := configureBufferedLogger(t) c := NewApiClient(ClientConfig{ DebugTruncateBytes: 2048, @@ -470,16 +466,12 @@ func TestInlineArrayDebugging_StreamResponse(t *testing.T) { require.Equal(t, `[DEBUG] GET /a?a=3&b=0&c=23 < -< `, bufLogger.String()) +< +`, bufLogger.String()) } func TestLogQueryParametersWithPercent(t *testing.T) { - prevLogger := logger.DefaultLogger - bufLogger := &BufferLogger{} - logger.DefaultLogger = bufLogger - defer func() { - logger.DefaultLogger = prevLogger - }() + bufLogger := configureBufferedLogger(t) c := NewApiClient(ClientConfig{ DebugTruncateBytes: 2048, @@ -502,7 +494,27 @@ func TestLogQueryParametersWithPercent(t *testing.T) { < < { < "foo": "bar" -< }`, bufLogger.String()) +< } +`, bufLogger.String()) +} + +func TestLogCancelledRequest(t *testing.T) { + bufLogger := configureBufferedLogger(t) + + ctx, cancel := context.WithCancel(context.Background()) + c := NewApiClient(ClientConfig{ + DebugTruncateBytes: 2048, + Transport: hc(func(r *http.Request) (*http.Response, error) { + cancel() + return nil, ctx.Err() + }), + }) + err := c.Do(context.Background(), "GET", "/a") + assert.Error(t, err) + assert.Equal(t, `[DEBUG] GET /a +< Error: Get "/a": request timed out after 30s of inactivity +[DEBUG] non-retriable error: Get "/a": request timed out after 30s of inactivity +`, bufLogger.String()) } func TestStreamRequestFromFileWithReset(t *testing.T) { diff --git a/logger/httplog/round_trip_stringer.go b/logger/httplog/round_trip_stringer.go index d160f1af6..1cd23ba32 100644 --- a/logger/httplog/round_trip_stringer.go +++ b/logger/httplog/round_trip_stringer.go @@ -10,6 +10,7 @@ import ( ) type RoundTripStringer struct { + Request *http.Request Response *http.Response Err error RequestBody []byte @@ -45,22 +46,21 @@ func (r RoundTripStringer) writeHeaders(sb *strings.Builder, prefix string, head } func (r RoundTripStringer) String() string { - request := r.Response.Request sb := strings.Builder{} - sb.WriteString(fmt.Sprintf("%s %s", request.Method, - escapeNewLines(request.URL.Path))) - if request.URL.RawQuery != "" { + sb.WriteString(fmt.Sprintf("%s %s", r.Request.Method, + escapeNewLines(r.Request.URL.Path))) + if r.Request.URL.RawQuery != "" { sb.WriteString("?") - q, _ := url.QueryUnescape(request.URL.RawQuery) + q, _ := url.QueryUnescape(r.Request.URL.RawQuery) sb.WriteString(q) } sb.WriteString("\n") if r.DebugHeaders { sb.WriteString("> * Host: ") - sb.WriteString(escapeNewLines(request.Host)) + sb.WriteString(escapeNewLines(r.Request.Host)) sb.WriteString("\n") - if len(request.Header) > 0 { - r.writeHeaders(&sb, "> ", request.Header) + if len(r.Request.Header) > 0 { + r.writeHeaders(&sb, "> ", r.Request.Header) sb.WriteString("\n") } } @@ -69,15 +69,17 @@ func (r RoundTripStringer) String() string { sb.WriteString("\n") } sb.WriteString("< ") - if r.Response != nil { - sb.WriteString(fmt.Sprintf("%s %s", r.Response.Proto, r.Response.Status)) - // Only display error on this line if the response body is empty. - // Otherwise the response body will include details about the error. - if len(r.ResponseBody) == 0 && r.Err != nil { - sb.WriteString(fmt.Sprintf(" (Error: %s)", r.Err)) - } - } else { + if r.Response == nil { sb.WriteString(fmt.Sprintf("Error: %s", r.Err)) + return sb.String() + } + + sb.WriteString(fmt.Sprintf("%s %s", r.Response.Proto, r.Response.Status)) + // Only display error on this line if the response body is empty or the + // client failed to read the response body. + // Otherwise the response body will include details about the error. + if len(r.ResponseBody) == 0 && r.Err != nil { + sb.WriteString(fmt.Sprintf(" (Error: %s)", r.Err)) } if r.DebugHeaders && len(r.Response.Header) > 0 { sb.WriteString("\n") diff --git a/logger/httplog/round_trip_stringer_test.go b/logger/httplog/round_trip_stringer_test.go index 5d1731ec2..45faad3cd 100644 --- a/logger/httplog/round_trip_stringer_test.go +++ b/logger/httplog/round_trip_stringer_test.go @@ -1,6 +1,7 @@ package httplog import ( + "errors" "net/http" "net/url" "testing" @@ -10,14 +11,14 @@ import ( func TestNoHeadersNoBody(t *testing.T) { res := RoundTripStringer{ - Response: &http.Response{ - Request: &http.Request{ - Method: "GET", - URL: &url.URL{ - Path: "/", - RawQuery: "foo=bar&baz=qux", - }, + Request: &http.Request{ + Method: "GET", + URL: &url.URL{ + Path: "/", + RawQuery: "foo=bar&baz=qux", }, + }, + Response: &http.Response{ Status: "200 OK", Proto: "HTTP/1.1", }, @@ -28,15 +29,15 @@ func TestNoHeadersNoBody(t *testing.T) { func TestRequestAndResponseHaveHeadersAndBody(t *testing.T) { res := RoundTripStringer{ - Response: &http.Response{ - Request: &http.Request{ - Method: "GET", - URL: &url.URL{Path: "/"}, - Header: http.Header{ - "Foo": []string{"bar"}, - "Bar": []string{"baz"}, - }, + Request: &http.Request{ + Method: "GET", + URL: &url.URL{Path: "/"}, + Header: http.Header{ + "Foo": []string{"bar"}, + "Bar": []string{"baz"}, }, + }, + Response: &http.Response{ Status: "200 OK", Proto: "HTTP/1.1", Header: http.Header{ @@ -62,15 +63,15 @@ func TestRequestAndResponseHaveHeadersAndBody(t *testing.T) { func TestDoNotPrintHeadersWhenNotConfigured(t *testing.T) { res := RoundTripStringer{ - Response: &http.Response{ - Request: &http.Request{ - Method: "GET", - URL: &url.URL{Path: "/"}, - Header: http.Header{ - "Foo": []string{"bar"}, - "Bar": []string{"baz"}, - }, + Request: &http.Request{ + Method: "GET", + URL: &url.URL{Path: "/"}, + Header: http.Header{ + "Foo": []string{"bar"}, + "Bar": []string{"baz"}, }, + }, + Response: &http.Response{ Status: "200 OK", Proto: "HTTP/1.1", Header: http.Header{ @@ -91,17 +92,17 @@ func TestDoNotPrintHeadersWhenNotConfigured(t *testing.T) { func TestHideAuthorizationHeaderWhenConfigured(t *testing.T) { res := RoundTripStringer{ - Response: &http.Response{ - Request: &http.Request{ - Method: "GET", - URL: &url.URL{Path: "/"}, - Header: http.Header{ - "Foo": []string{"bar"}, - "Authorization": []string{"baz"}, - "X-Databricks-Azure-SP-Management-Token": []string{"open sesame"}, - "X-Databricks-GCP-SA-Access-Token": []string{"alohamora"}, - }, + Request: &http.Request{ + Method: "GET", + URL: &url.URL{Path: "/"}, + Header: http.Header{ + "Foo": []string{"bar"}, + "Authorization": []string{"baz"}, + "X-Databricks-Azure-SP-Management-Token": []string{"open sesame"}, + "X-Databricks-GCP-SA-Access-Token": []string{"alohamora"}, }, + }, + Response: &http.Response{ Status: "200 OK", Proto: "HTTP/1.1", }, @@ -121,3 +122,41 @@ func TestHideAuthorizationHeaderWhenConfigured(t *testing.T) { < HTTP/1.1 200 OK < response-hello`, res) } + +func TestNilResponse(t *testing.T) { + res := RoundTripStringer{ + Request: &http.Request{ + Method: "GET", + URL: &url.URL{Path: "/"}, + }, + Err: &url.Error{ + Op: "Get", + URL: "http://example.com", + Err: errors.New("request timed out after 1m0s of inactivity"), + }, + }.String() + assert.Equal(t, `GET / +< Error: Get "http://example.com": request timed out after 1m0s of inactivity`, res) +} + +func TestFailureToConsumeResponse(t *testing.T) { + res := RoundTripStringer{ + Request: &http.Request{ + Method: "GET", + URL: &url.URL{Path: "/"}, + }, + Response: &http.Response{ + Status: "200 OK", + Proto: "HTTP/1.1", + Header: http.Header{ + "Foo": []string{"bar"}, + }, + }, + Err: errors.New("failed to read response body"), + DebugHeaders: true, + }.String() + assert.Equal(t, `GET / +> * Host: +< HTTP/1.1 200 OK (Error: failed to read response body) +< * Foo: ... (3 more bytes)`, res) +} diff --git a/retries/retries.go b/retries/retries.go index 47572f388..8517e0e07 100644 --- a/retries/retries.go +++ b/retries/retries.go @@ -210,6 +210,7 @@ func (r Retrier[T]) Run(ctx context.Context, fn func(context.Context) (*T, error return entity, nil } if !r.config.ShouldRetry(err) { + logger.Debugf(ctx, "non-retriable error: %s", err) return nil, err } lastErr = err