diff --git a/client.go b/client.go index 7938919..3afb6b7 100644 --- a/client.go +++ b/client.go @@ -39,7 +39,7 @@ import ( "sync" "time" - "github.com/hashicorp/go-cleanhttp" + cleanhttp "github.com/hashicorp/go-cleanhttp" ) var ( @@ -276,8 +276,10 @@ type Logger interface { Printf(string, ...interface{}) } -// LeveledFormatLogger interface allows to use logger libraries that use formatted -// leveled methods for loging +// LeveledFormatLogger is an interface that can be implemented by any logger or a +// logger wrapper to provide leveled logging in a log.Printf style formatting. +// The methods accept a log.Printf format string and a variadic number of elements +// to substitute in the format. type LeveledFormatLogger interface { Infof(string, ...interface{}) Debugf(string, ...interface{}) @@ -285,12 +287,16 @@ type LeveledFormatLogger interface { Errorf(string, ...interface{}) } -// LeveledLogger interface implements the basic methods that a logger library needs +// LeveledLogger is an interface that can be implemented by any logger or a +// logger wrapper to provide leveled logging. The methods accept a message +// string and a variadic number of key-value pairs. For log.Printf style +// formatting where message string contains a format specifier, use Logger +// or LeveledFormatLogger interfaces. type LeveledLogger interface { - Error(string, ...interface{}) - Info(string, ...interface{}) - Debug(string, ...interface{}) - Warn(string, ...interface{}) + Error(msg string, keysAndValues ...interface{}) + Info(msg string, keysAndValues ...interface{}) + Debug(msg string, keysAndValues ...interface{}) + Warn(msg string, keysAndValues ...interface{}) } // hookLogger adapts an LeveledLogger to Logger for use by the existing hook functions @@ -451,6 +457,48 @@ func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bo return false, nil } +// ErrorPropagatedRetryPolicy is the same as DefaultRetryPolicy, except it +// propagates errors back instead of returning nil. This allows you to inspect +// why it decided to retry or not. +func ErrorPropagatedRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) { + // do not retry on context.Canceled or context.DeadlineExceeded + if ctx.Err() != nil { + return false, ctx.Err() + } + + if err != nil { + if v, ok := err.(*url.Error); ok { + // Don't retry if the error was due to too many redirects. + if redirectsErrorRe.MatchString(v.Error()) { + return false, v + } + + // Don't retry if the error was due to an invalid protocol scheme. + if schemeErrorRe.MatchString(v.Error()) { + return false, v + } + + // Don't retry if the error was due to TLS cert verification failure. + if _, ok := v.Err.(x509.UnknownAuthorityError); ok { + return false, v + } + } + + // The error is likely recoverable so retry. + return true, nil + } + + // Check the response code. We retry on 500-range responses to allow + // the server time to recover, as 500's are typically not permanent + // errors and may relate to outages on the server side. This will catch + // invalid response codes as well, like 0 and 999. + if resp.StatusCode == 0 || (resp.StatusCode >= 500 && resp.StatusCode != 501) { + return true, fmt.Errorf("unexpected HTTP status %s", resp.Status) + } + + return false, nil +} + // DefaultBackoff provides a default callback for Client.Backoff which // will perform exponential backoff based on the attempt number and limited // by the provided minimum and maximum durations. @@ -522,17 +570,21 @@ func (c *Client) Do(req *Request) (*http.Response, error) { switch v := logger.(type) { case LeveledFormatLogger: v.Debugf("%s %s: performing request", req.Method, req.URL) - case Logger: - v.Printf("[DEBUG] %s %s", req.Method, req.URL) case LeveledLogger: v.Debug("performing request", "method", req.Method, "url", req.URL) + case Logger: + v.Printf("[DEBUG] %s %s", req.Method, req.URL) } } var resp *http.Response - var err error + var attempt int + var shouldRetry bool + var doErr, checkErr error for i := 0; ; i++ { + attempt++ + var code int // HTTP response code // Always rewind the request body when non-nil. @@ -553,32 +605,32 @@ func (c *Client) Do(req *Request) (*http.Response, error) { switch v := logger.(type) { case LeveledFormatLogger: c.RequestLogHook(hookFormatLogger{v}, req.Request, i) - case Logger: - c.RequestLogHook(v, req.Request, i) case LeveledLogger: c.RequestLogHook(hookLogger{v}, req.Request, i) + case Logger: + c.RequestLogHook(v, req.Request, i) default: c.RequestLogHook(nil, req.Request, i) } } // Attempt the request - resp, err = c.HTTPClient.Do(req.Request) + resp, doErr = c.HTTPClient.Do(req.Request) if resp != nil { code = resp.StatusCode } // Check if we should continue with retries. - checkOK, checkErr := c.CheckRetry(req.Context(), resp, err) + shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr) - if err != nil { + if doErr != nil { switch v := logger.(type) { case LeveledFormatLogger: v.Errorf("%s %s: request failed: %s", req.Method, req.URL, err) - case Logger: - v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, err) case LeveledLogger: - v.Error("request failed", "error", err, "method", req.Method, "url", req.URL) + v.Error("request failed", "error", doErr, "method", req.Method, "url", req.URL) + case Logger: + v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, doErr) } } else { // Call this here to maintain the behavior of logging all requests, @@ -588,23 +640,18 @@ func (c *Client) Do(req *Request) (*http.Response, error) { switch v := logger.(type) { case LeveledFormatLogger: c.ResponseLogHook(hookFormatLogger{v}, resp) - case Logger: - c.ResponseLogHook(v, resp) case LeveledLogger: c.ResponseLogHook(hookLogger{v}, resp) + case Logger: + c.ResponseLogHook(v, resp) default: c.ResponseLogHook(nil, resp) } } } - // Now decide if we should continue. - if !checkOK { - if checkErr != nil { - err = checkErr - } - c.HTTPClient.CloseIdleConnections() - return resp, err + if !shouldRetry { + break } // We do this before drainBody because there's no need for the I/O if @@ -615,7 +662,7 @@ func (c *Client) Do(req *Request) (*http.Response, error) { } // We're going to retry, consume any response to reuse the connection. - if err == nil && resp != nil { + if doErr == nil { c.drainBody(resp.Body) } @@ -628,10 +675,10 @@ func (c *Client) Do(req *Request) (*http.Response, error) { switch v := logger.(type) { case LeveledFormatLogger: v.Debugf("%s: retrying request in %s (%d left)", desc, wait, remain) - case Logger: - v.Printf("[DEBUG] %s: retrying in %s (%d left)", desc, wait, remain) case LeveledLogger: v.Debug("retrying request", "request", desc, "timeout", wait, "remaining", remain) + case Logger: + v.Printf("[DEBUG] %s: retrying in %s (%d left)", desc, wait, remain) } } select { @@ -642,19 +689,37 @@ func (c *Client) Do(req *Request) (*http.Response, error) { } } + // this is the closest we have to success criteria + if doErr == nil && checkErr == nil && !shouldRetry { + return resp, nil + } + + defer c.HTTPClient.CloseIdleConnections() + + err := doErr + if checkErr != nil { + err = checkErr + } + if c.ErrorHandler != nil { - c.HTTPClient.CloseIdleConnections() - return c.ErrorHandler(resp, err, c.RetryMax+1) + return c.ErrorHandler(resp, err, attempt) } // By default, we close the response body and return an error without // returning the response if resp != nil { - resp.Body.Close() + c.drainBody(resp.Body) + } + + // this means CheckRetry thought the request was a failure, but didn't + // communicate why + if err == nil { + return nil, fmt.Errorf("%s %s giving up after %d attempt(s)", + req.Method, req.URL, attempt) } - c.HTTPClient.CloseIdleConnections() - return nil, fmt.Errorf("%s %s giving up after %d attempts", - req.Method, req.URL, c.RetryMax+1) + + return nil, fmt.Errorf("%s %s giving up after %d attempt(s): %w", + req.Method, req.URL, attempt, err) } // Try to read the response body so we can reuse this connection. @@ -666,10 +731,10 @@ func (c *Client) drainBody(body io.ReadCloser) { switch v := c.logger().(type) { case LeveledFormatLogger: v.Errorf("error reading response body: %s", err) - case Logger: - v.Printf("[ERR] error reading response body: %v", err) case LeveledLogger: v.Error("error reading response body", "error", err) + case Logger: + v.Printf("[ERR] error reading response body: %v", err) } } } diff --git a/client_test.go b/client_test.go index dcacb01..27442e0 100644 --- a/client_test.go +++ b/client_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "fmt" "io" "io/ioutil" "net" @@ -255,22 +256,44 @@ func TestClient_Do_fails(t *testing.T) { })) defer ts.Close() - // Create the client. Use short retry windows so we fail faster. - client := NewClient() - client.RetryWaitMin = 10 * time.Millisecond - client.RetryWaitMax = 10 * time.Millisecond - client.RetryMax = 2 - - // Create the request - req, err := NewRequest("POST", ts.URL, nil) - if err != nil { - t.Fatalf("err: %v", err) + tests := []struct { + name string + cr CheckRetry + err string + }{ + { + name: "default_retry_policy", + cr: DefaultRetryPolicy, + err: "giving up after 3 attempt(s)", + }, + { + name: "error_propagated_retry_policy", + cr: ErrorPropagatedRetryPolicy, + err: "giving up after 3 attempt(s): unexpected HTTP status 500 Internal Server Error", + }, } - // Send the request. - _, err = client.Do(req) - if err == nil || !strings.Contains(err.Error(), "giving up") { - t.Fatalf("expected giving up error, got: %#v", err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create the client. Use short retry windows so we fail faster. + client := NewClient() + client.RetryWaitMin = 10 * time.Millisecond + client.RetryWaitMax = 10 * time.Millisecond + client.CheckRetry = tt.cr + client.RetryMax = 2 + + // Create the request + req, err := NewRequest("POST", ts.URL, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Send the request. + _, err = client.Do(req) + if err == nil || !strings.HasSuffix(err.Error(), tt.err) { + t.Fatalf("expected giving up error, got: %#v", err) + } + }) } } @@ -462,8 +485,10 @@ func TestClient_RequestWithContext(t *testing.T) { t.Fatalf("CheckRetry called %d times, expected 1", called) } - if err != context.Canceled { - t.Fatalf("Expected context.Canceled err, got: %v", err) + e := fmt.Sprintf("GET %s giving up after 1 attempt(s): %s", ts.URL, context.Canceled.Error()) + + if err.Error() != e { + t.Fatalf("Expected err to contain %s, got: %v", e, err) } } @@ -493,10 +518,9 @@ func TestClient_CheckRetry(t *testing.T) { t.Fatalf("CheckRetry called %d times, expected 1", called) } - if err != retryErr { + if err.Error() != fmt.Sprintf("GET %s giving up after 2 attempt(s): retryError", ts.URL) { t.Fatalf("Expected retryError, got:%v", err) } - } func TestClient_DefaultRetryPolicy_TLS(t *testing.T) {