From 171d0c90ea3fe19417437c478bf0aedc589ea919 Mon Sep 17 00:00:00 2001
From: Tim Heckman
Date: Mon, 15 Jun 2020 11:20:49 -0700
Subject: [PATCH 1/2] Propagate Do() HTTP error up when retries are exceeded (#70)

Add propagation of the HTTP request/response error up the call stack when
the number of retries is exceeded. Because the ErrorHandler does not
receive a copy of the request, it's impossible to use the ErrorHandler to
generate an enhanced version of the default Retries Exceeded error message.

This change takes a non-API-breaking approach to solving this by including
the error string in the returned error (if an error has occurred).

Because the error string returned from `Do()` uses `%w`, this change
requires Go 1.13+.

Fixes #69
---
 client.go      | 99 ++++++++++++++++++++++++++++++++++++++++----------
 client_test.go | 60 +++++++++++++++++++++---------
 2 files changed, 121 insertions(+), 38 deletions(-)

diff --git a/client.go b/client.go
index cda0210..9ecab4f 100644
--- a/client.go
+++ b/client.go
@@ -432,6 +432,48 @@ func DefaultRetryPolicy(ctx context.Context, resp *http.Response, err error) (bo
 	return false, nil
 }
 
+// ErrorPropagatedRetryPolicy is the same as DefaultRetryPolicy, except it
+// propagates errors back instead of returning nil. This allows you to inspect
+// why it decided to retry or not.
+func ErrorPropagatedRetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) {
+	// do not retry on context.Canceled or context.DeadlineExceeded
+	if ctx.Err() != nil {
+		return false, ctx.Err()
+	}
+
+	if err != nil {
+		if v, ok := err.(*url.Error); ok {
+			// Don't retry if the error was due to too many redirects.
+			if redirectsErrorRe.MatchString(v.Error()) {
+				return false, v
+			}
+
+			// Don't retry if the error was due to an invalid protocol scheme.
+			if schemeErrorRe.MatchString(v.Error()) {
+				return false, v
+			}
+
+			// Don't retry if the error was due to TLS cert verification failure.
+			if _, ok := v.Err.(x509.UnknownAuthorityError); ok {
+				return false, v
+			}
+		}
+
+		// The error is likely recoverable so retry.
+		return true, nil
+	}
+
+	// Check the response code. We retry on 500-range responses to allow
+	// the server time to recover, as 500's are typically not permanent
+	// errors and may relate to outages on the server side. This will catch
+	// invalid response codes as well, like 0 and 999.
+	if resp.StatusCode == 0 || (resp.StatusCode >= 500 && resp.StatusCode != 501) {
+		return true, fmt.Errorf("unexpected HTTP status %s", resp.Status)
+	}
+
+	return false, nil
+}
+
 // DefaultBackoff provides a default callback for Client.Backoff which
 // will perform exponential backoff based on the attempt number and limited
 // by the provided minimum and maximum durations.
@@ -509,9 +551,13 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
 	}
 
 	var resp *http.Response
-	var err error
+	var attempt int
+	var shouldRetry bool
+	var doErr, checkErr error
 
 	for i := 0; ; i++ {
+		attempt++
+
 		var code int // HTTP response code
 
 		// Always rewind the request body when non-nil.
@@ -540,20 +586,20 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
 		}
 
 		// Attempt the request
-		resp, err = c.HTTPClient.Do(req.Request)
+		resp, doErr = c.HTTPClient.Do(req.Request)
 		if resp != nil {
 			code = resp.StatusCode
 		}
 
 		// Check if we should continue with retries.
-		checkOK, checkErr := c.CheckRetry(req.Context(), resp, err)
+		shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr)
 
-		if err != nil {
+		if doErr != nil {
 			switch v := logger.(type) {
 			case Logger:
-				v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, err)
+				v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, doErr)
 			case LeveledLogger:
-				v.Error("request failed", "error", err, "method", req.Method, "url", req.URL)
+				v.Error("request failed", "error", doErr, "method", req.Method, "url", req.URL)
 			}
 		} else {
 			// Call this here to maintain the behavior of logging all requests,
@@ -571,13 +617,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
 			}
 		}
 
-		// Now decide if we should continue.
-		if !checkOK {
-			if checkErr != nil {
-				err = checkErr
-			}
-			c.HTTPClient.CloseIdleConnections()
-			return resp, err
+		if !shouldRetry {
+			break
 		}
 
 		// We do this before drainBody because there's no need for the I/O if
@@ -588,7 +629,7 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
 		}
 
 		// We're going to retry, consume any response to reuse the connection.
-		if err == nil && resp != nil {
+		if doErr == nil {
 			c.drainBody(resp.Body)
 		}
 
@@ -613,19 +654,37 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
 		}
 	}
 
+	// this is the closest we have to success criteria
+	if doErr == nil && checkErr == nil && !shouldRetry {
+		return resp, nil
+	}
+
+	defer c.HTTPClient.CloseIdleConnections()
+
+	err := doErr
+	if checkErr != nil {
+		err = checkErr
+	}
+
 	if c.ErrorHandler != nil {
-		c.HTTPClient.CloseIdleConnections()
-		return c.ErrorHandler(resp, err, c.RetryMax+1)
+		return c.ErrorHandler(resp, err, attempt)
 	}
 
 	// By default, we close the response body and return an error without
 	// returning the response
 	if resp != nil {
-		resp.Body.Close()
+		c.drainBody(resp.Body)
 	}
-	c.HTTPClient.CloseIdleConnections()
-	return nil, fmt.Errorf("%s %s giving up after %d attempts",
-		req.Method, req.URL, c.RetryMax+1)
+
+	// this means CheckRetry thought the request was a failure, but didn't
+	// communicate why
+	if err == nil {
+		return nil, fmt.Errorf("%s %s giving up after %d attempt(s)",
+			req.Method, req.URL, attempt)
+	}
+
+	return nil, fmt.Errorf("%s %s giving up after %d attempt(s): %w",
+		req.Method, req.URL, attempt, err)
 }
 
 // Try to read the response body so we can reuse this connection.
diff --git a/client_test.go b/client_test.go
index dcacb01..27442e0 100644
--- a/client_test.go
+++ b/client_test.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"errors"
+	"fmt"
 	"io"
 	"io/ioutil"
 	"net"
@@ -255,22 +256,44 @@ func TestClient_Do_fails(t *testing.T) {
 	}))
 	defer ts.Close()
 
-	// Create the client. Use short retry windows so we fail faster.
-	client := NewClient()
-	client.RetryWaitMin = 10 * time.Millisecond
-	client.RetryWaitMax = 10 * time.Millisecond
-	client.RetryMax = 2
-
-	// Create the request
-	req, err := NewRequest("POST", ts.URL, nil)
-	if err != nil {
-		t.Fatalf("err: %v", err)
+	tests := []struct {
+		name string
+		cr   CheckRetry
+		err  string
+	}{
+		{
+			name: "default_retry_policy",
+			cr:   DefaultRetryPolicy,
+			err:  "giving up after 3 attempt(s)",
+		},
+		{
+			name: "error_propagated_retry_policy",
+			cr:   ErrorPropagatedRetryPolicy,
+			err:  "giving up after 3 attempt(s): unexpected HTTP status 500 Internal Server Error",
+		},
 	}
 
-	// Send the request.
-	_, err = client.Do(req)
-	if err == nil || !strings.Contains(err.Error(), "giving up") {
-		t.Fatalf("expected giving up error, got: %#v", err)
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Create the client. Use short retry windows so we fail faster.
+			client := NewClient()
+			client.RetryWaitMin = 10 * time.Millisecond
+			client.RetryWaitMax = 10 * time.Millisecond
+			client.CheckRetry = tt.cr
+			client.RetryMax = 2
+
+			// Create the request
+			req, err := NewRequest("POST", ts.URL, nil)
+			if err != nil {
+				t.Fatalf("err: %v", err)
+			}
+
+			// Send the request.
+			_, err = client.Do(req)
+			if err == nil || !strings.HasSuffix(err.Error(), tt.err) {
+				t.Fatalf("expected giving up error, got: %#v", err)
+			}
+		})
 	}
 }
 
@@ -462,8 +485,10 @@ func TestClient_RequestWithContext(t *testing.T) {
 		t.Fatalf("CheckRetry called %d times, expected 1", called)
 	}
 
-	if err != context.Canceled {
-		t.Fatalf("Expected context.Canceled err, got: %v", err)
+	e := fmt.Sprintf("GET %s giving up after 1 attempt(s): %s", ts.URL, context.Canceled.Error())
+
+	if err.Error() != e {
+		t.Fatalf("Expected err to contain %s, got: %v", e, err)
 	}
 }
 
@@ -493,10 +518,9 @@ func TestClient_CheckRetry(t *testing.T) {
 		t.Fatalf("CheckRetry called %d times, expected 1", called)
 	}
 
-	if err != retryErr {
+	if err.Error() != fmt.Sprintf("GET %s giving up after 2 attempt(s): retryError", ts.URL) {
 		t.Fatalf("Expected retryError, got:%v", err)
 	}
-
 }
 
 func TestClient_DefaultRetryPolicy_TLS(t *testing.T) {

From cf855b1d14d561f66108aa47dc4b92493d8fcb37 Mon Sep 17 00:00:00 2001
From: Prashanth Pai
Date: Mon, 15 Jun 2020 23:54:55 +0530
Subject: [PATCH 2/2] LeveledLogger: Add doc and prioritize it over Logger (#97)

* Document LeveledLogger better to make it clear to developers that the
  second variadic argument accepts key-value pairs, like the hclog or zap
  loggers, and not Printf-style formatting.

* When the passed logger implements both the Logger and LeveledLogger
  interfaces, prioritize LeveledLogger over Logger. Switch statements
  evaluate cases from top to bottom.

Co-authored-by: Jeff Mitchell
---
 client.go | 40 ++++++++++++++++++++++------------------
 1 file changed, 22 insertions(+), 18 deletions(-)

diff --git a/client.go b/client.go
index 9ecab4f..458e578 100644
--- a/client.go
+++ b/client.go
@@ -39,7 +39,7 @@ import (
 	"sync"
 	"time"
 
-	"github.com/hashicorp/go-cleanhttp"
+	cleanhttp "github.com/hashicorp/go-cleanhttp"
 )
 
 var (
@@ -276,12 +276,16 @@ type Logger interface {
 	Printf(string, ...interface{})
 }
 
-// LeveledLogger interface implements the basic methods that a logger library needs
+// LeveledLogger is an interface that can be implemented by any logger or a
+// logger wrapper to provide leveled logging. The methods accept a message
+// string and a variadic number of key-value pairs. For log.Printf style
+// formatting where message string contains a format specifier, use Logger
+// interface.
 type LeveledLogger interface {
-	Error(string, ...interface{})
-	Info(string, ...interface{})
-	Debug(string, ...interface{})
-	Warn(string, ...interface{})
+	Error(msg string, keysAndValues ...interface{})
+	Info(msg string, keysAndValues ...interface{})
+	Debug(msg string, keysAndValues ...interface{})
+	Warn(msg string, keysAndValues ...interface{})
 }
 
 // hookLogger adapts an LeveledLogger to Logger for use by the existing hook functions
@@ -543,10 +547,10 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
 
 	if logger != nil {
 		switch v := logger.(type) {
-		case Logger:
-			v.Printf("[DEBUG] %s %s", req.Method, req.URL)
 		case LeveledLogger:
 			v.Debug("performing request", "method", req.Method, "url", req.URL)
+		case Logger:
+			v.Printf("[DEBUG] %s %s", req.Method, req.URL)
 		}
 	}
 
@@ -576,10 +580,10 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
 		if c.RequestLogHook != nil {
 			switch v := logger.(type) {
-			case Logger:
-				c.RequestLogHook(v, req.Request, i)
 			case LeveledLogger:
 				c.RequestLogHook(hookLogger{v}, req.Request, i)
+			case Logger:
+				c.RequestLogHook(v, req.Request, i)
 			default:
 				c.RequestLogHook(nil, req.Request, i)
 			}
 		}
@@ -596,10 +600,10 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
 
 		if doErr != nil {
 			switch v := logger.(type) {
-			case Logger:
-				v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, doErr)
 			case LeveledLogger:
 				v.Error("request failed", "error", doErr, "method", req.Method, "url", req.URL)
+			case Logger:
+				v.Printf("[ERR] %s %s request failed: %v", req.Method, req.URL, doErr)
 			}
 		} else {
 			// Call this here to maintain the behavior of logging all requests,
@@ -607,10 +611,10 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
 			if c.ResponseLogHook != nil {
 				// Call the response logger function if provided.
 				switch v := logger.(type) {
-				case Logger:
-					c.ResponseLogHook(v, resp)
 				case LeveledLogger:
 					c.ResponseLogHook(hookLogger{v}, resp)
+				case Logger:
+					c.ResponseLogHook(v, resp)
 				default:
 					c.ResponseLogHook(nil, resp)
 				}
 			}
@@ -640,10 +644,10 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
 		}
 		if logger != nil {
 			switch v := logger.(type) {
-			case Logger:
-				v.Printf("[DEBUG] %s: retrying in %s (%d left)", desc, wait, remain)
 			case LeveledLogger:
 				v.Debug("retrying request", "request", desc, "timeout", wait, "remaining", remain)
+			case Logger:
+				v.Printf("[DEBUG] %s: retrying in %s (%d left)", desc, wait, remain)
 			}
 		}
 		select {
@@ -694,10 +698,10 @@ func (c *Client) drainBody(body io.ReadCloser) {
 	if err != nil {
 		if c.logger() != nil {
 			switch v := c.logger().(type) {
-			case Logger:
-				v.Printf("[ERR] error reading response body: %v", err)
 			case LeveledLogger:
 				v.Error("error reading response body", "error", err)
+			case Logger:
+				v.Printf("[ERR] error reading response body: %v", err)
 			}
 		}
 	}
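
A minimal caller-side sketch of what the first patch enables, assuming Go 1.13+
and the ErrorPropagatedRetryPolicy added above; the URL and names used here are
illustrative only, not part of either patch:

package main

import (
	"errors"
	"log"
	"net/url"

	retryablehttp "github.com/hashicorp/go-retryablehttp"
)

func main() {
	client := retryablehttp.NewClient()
	client.RetryMax = 2
	client.CheckRetry = retryablehttp.ErrorPropagatedRetryPolicy

	resp, err := client.Get("http://127.0.0.1:0/flaky") // hypothetical endpoint
	if err != nil {
		// Do() now wraps the last transport/CheckRetry error with %w, so the
		// cause survives the "giving up after N attempt(s)" message.
		var uerr *url.Error
		if errors.As(err, &uerr) {
			log.Printf("transport failure after retries: %v", uerr)
		} else {
			log.Printf("gave up: %v", err) // e.g. "... unexpected HTTP status 500 ..."
		}
		return
	}
	resp.Body.Close()
}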
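
Similarly, a minimal sketch of the key-value contract the second patch
documents; kvLogger is a hypothetical adapter written for illustration, not
part of the library:

package main

import (
	"log"
	"os"

	retryablehttp "github.com/hashicorp/go-retryablehttp"
)

// kvLogger satisfies LeveledLogger: each method receives a message plus
// alternating key/value pairs, not a Printf-style format string.
type kvLogger struct{ l *log.Logger }

func (k kvLogger) Error(msg string, kv ...interface{}) { k.l.Println("ERROR", msg, kv) }
func (k kvLogger) Warn(msg string, kv ...interface{})  { k.l.Println("WARN", msg, kv) }
func (k kvLogger) Info(msg string, kv ...interface{})  { k.l.Println("INFO", msg, kv) }
func (k kvLogger) Debug(msg string, kv ...interface{}) { k.l.Println("DEBUG", msg, kv) }

func main() {
	client := retryablehttp.NewClient()
	// Because kvLogger implements LeveledLogger, the reordered switch in Do()
	// takes the leveled path and logs pairs such as "method" and "url".
	client.Logger = kvLogger{l: log.New(os.Stderr, "", log.LstdFlags)}

	if resp, err := client.Get("http://127.0.0.1:0/"); err == nil { // hypothetical endpoint
		resp.Body.Close()
	}
}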