diff --git a/client.go b/client.go index 3e4a71c..a2852c3 100644 --- a/client.go +++ b/client.go @@ -15,8 +15,10 @@ const ( // Client defines a HTTP client type Client struct { - Client *http.Client - Retryer retry.Retryer + Client *http.Client + Retryer retry.Retryer + Prehooks []Prehook + Posthooks []Posthook } // NewClient initialises a new `Client` @@ -28,6 +30,8 @@ func NewClient(opts ...Option) *Client { Retryer: retry.Retryer{ Attempts: defaultRetryAttempts, }, + Prehooks: []Prehook{}, + Posthooks: []Posthook{}, } for _, opt := range opts { @@ -41,20 +45,28 @@ func NewClient(opts ...Option) *Client { func (client *Client) Do(request *http.Request) (*http.Response, error) { var ( response *http.Response - ) - success, errs := client.Retryer.Do(func() error { - var err error + success, errs = client.Retryer.Do(func() error { + var err error + + for _, prehook := range client.Prehooks { + prehook(request) + } - response, err = client.Client.Do(request) + response, err = client.Client.Do(request) - // Retry only on 5xx status codes - if response != nil && response.StatusCode >= http.StatusInternalServerError { - return fmt.Errorf("retrying on %s", response.Status) - } + for _, posthook := range client.Posthooks { + posthook(response, err) + } - return err - }) + // Retry only on 5xx status codes + if response != nil && response.StatusCode >= http.StatusInternalServerError { + return fmt.Errorf("retrying on %s", response.Status) + } + + return err + }) + ) if !success { return response, fmt.Errorf("httpclient: request occurred with errors: %s", errs) diff --git a/client_test.go b/client_test.go index d330bc0..025e51a 100644 --- a/client_test.go +++ b/client_test.go @@ -88,3 +88,26 @@ func TestClient_DoWithInternalServerError(t *testing.T) { assert.EqualError(t, err, "httpclient: request occurred with errors: retrying on 500 Internal Server Error; retrying on 500 Internal Server Error") assert.Equal(t, http.StatusInternalServerError, response.StatusCode) } + +func TestClient_DoWithHook(t *testing.T) { + var ( + url = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).URL + client = httpclient.NewClient( + httpclient.OptionAttempts(2), + httpclient.OptionAddPrehook(func(req *http.Request) { + assert.Equal(t, url, req.URL.String()) + }), + httpclient.OptionAddPosthook(func(resp *http.Response, err error) { + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + }), + ) + request, _ = http.NewRequest(http.MethodPost, url, strings.NewReader(`{"foo": "bar"}`)) + response, err = client.Do(request) + ) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, response.StatusCode) +} diff --git a/go.sum b/go.sum index d16c539..d8c601e 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,7 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= diff --git a/option.go b/option.go index 902e1a4..4bb81b0 100644 --- a/option.go +++ b/option.go @@ -49,3 +49,17 @@ func OptionJitter(fn func(backoff time.Duration) time.Duration) Option { client.Retryer.Jitter = fn } } + +// OptionAddPrehook adds a prehook to `Client` +func OptionAddPrehook(prehook func(request *http.Request)) Option { + return func(client *Client) { + client.Prehooks = append(client.Prehooks, prehook) + } +} + +// OptionAddPosthook adds a posthook to `Client` +func OptionAddPosthook(posthook func(response *http.Response, err error)) Option { + return func(client *Client) { + client.Posthooks = append(client.Posthooks, posthook) + } +} diff --git a/posthook.go b/posthook.go new file mode 100644 index 0000000..0dfb671 --- /dev/null +++ b/posthook.go @@ -0,0 +1,6 @@ +package httpclient + +import "net/http" + +// Posthook defines a hook called after a HTTP request +type Posthook func(response *http.Response, err error) diff --git a/prehook.go b/prehook.go new file mode 100644 index 0000000..377280d --- /dev/null +++ b/prehook.go @@ -0,0 +1,6 @@ +package httpclient + +import "net/http" + +// Prehook defines a hook called before a HTTP request +type Prehook func(request *http.Request)