diff --git a/httpclient/response.go b/httpclient/response.go index c1aa29e99..4ece40e84 100644 --- a/httpclient/response.go +++ b/httpclient/response.go @@ -25,6 +25,17 @@ func WithResponseHeader(key string, value *string) DoOption { } } +// WithResponseUnmarshal unmarshals the response body into response. The +// supported response types are the following: +// - *bytes.Buffer, +// - *io.ReadCloser, +// - *[]byte, +// - a pointer to a struct with a Contents io.ReadCloser field, +// - a pointer to a struct representing a JSON object. +// +// If response is a pointer to a io.ReadCloser or a struct with a io.ReadCloser +// field name "Contents", then the response io.ReadCloser is set to the value of +// the body's reader without actually reading it. func WithResponseUnmarshal(response any) DoOption { return DoOption{ in: func(r *http.Request) error { @@ -50,45 +61,45 @@ func WithResponseUnmarshal(response any) DoOption { if err != nil { return err } - // If the field contains a "Content" field of type bytes.Buffer, write the body over there and return. - if field, ok := findContentsField(response, body); ok { - // If so, set the value + + if field, ok := findContentsField(response); ok { field.Set(reflect.ValueOf(body.ReadCloser)) return nil } - - // If the destination is bytes.Buffer, write the body over there - if raw, ok := response.(*io.ReadCloser); ok { - *raw = body.ReadCloser + if reader, ok := response.(*io.ReadCloser); ok { + *reader = body.ReadCloser return nil } + if buffer, ok := response.(*bytes.Buffer); ok { + defer body.ReadCloser.Close() + _, err := buffer.ReadFrom(body.ReadCloser) + return err + } + + // At this point, fully read the content of the body and use it + // to populate the response object (whether it is a slice of bytes + // or a JSON object). defer body.ReadCloser.Close() - bs, err := io.ReadAll(body.ReadCloser) + bodyBytes, err := io.ReadAll(body.ReadCloser) if err != nil { return fmt.Errorf("failed to read response body: %w", err) } - if len(bs) == 0 { + if len(bodyBytes) == 0 { return nil } - // If the destination is a byte slice or buffer, pass the body verbatim. - if raw, ok := response.(*[]byte); ok { - *raw = bs + if bs, ok := response.(*[]byte); ok { + *bs = bodyBytes return nil } - if raw, ok := response.(*bytes.Buffer); ok { - _, err := raw.Write(bs) - return err - } - err = json.Unmarshal(bs, &response) - if err != nil { - return apierr.MakeUnexpectedError(body.Response, err, body.RequestBody.DebugBytes, bs) + if err = json.Unmarshal(bodyBytes, &response); err != nil { + return apierr.MakeUnexpectedError(body.Response, err, body.RequestBody.DebugBytes, bodyBytes) } return nil }, } } -func findContentsField(response any, body *common.ResponseWrapper) (*reflect.Value, bool) { +func findContentsField(response any) (*reflect.Value, bool) { value := reflect.ValueOf(response) value = reflect.Indirect(value) if value.Kind() != reflect.Struct { diff --git a/httpclient/response_test.go b/httpclient/response_test.go index e4aab5798..ef36d2f32 100644 --- a/httpclient/response_test.go +++ b/httpclient/response_test.go @@ -1,29 +1,97 @@ package httpclient import ( + "bytes" "context" "io" "net/http" "strings" "testing" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" ) -func TestSimpleRequestRawResponse(t *testing.T) { - c := NewApiClient(ClientConfig{ +func make200Response(body string) *http.Response { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func mockClient(resp *http.Response) *ApiClient { + return NewApiClient(ClientConfig{ Transport: hc(func(r *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: io.NopCloser(strings.NewReader("Hello, world!")), - Request: r, - }, nil + resp.Request = r + return resp, nil }), }) - var raw []byte - err := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&raw)) - require.NoError(t, err) - require.Equal(t, "Hello, world!", string(raw)) +} + +func TestWithResponseUnmarshal_structWithContent(t *testing.T) { + c := mockClient(make200Response("foo bar")) + type structWithContents = struct { + Contents io.ReadCloser + } + want := structWithContents{ + Contents: io.NopCloser(strings.NewReader("foo bar")), + } + + var got structWithContents + gotErr := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&got)) + + assert.NoError(t, gotErr) + wantBytes, _ := io.ReadAll(want.Contents) + gotBytes, _ := io.ReadAll(got.Contents) + assert.Equal(t, wantBytes, gotBytes) +} + +func TestWithResponseUnmarshal_readCloser(t *testing.T) { + c := mockClient(make200Response("foo bar")) + want := io.NopCloser(strings.NewReader("foo bar")) + + var got io.ReadCloser + gotErr := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&got)) + + assert.NoError(t, gotErr) + wantBytes, _ := io.ReadAll(want) + gotBytes, _ := io.ReadAll(got) + assert.Equal(t, wantBytes, gotBytes) +} + +func TestWithResponseUnmarshal_byteBuffer(t *testing.T) { + c := mockClient(make200Response("foo bar")) + want := bytes.NewBuffer([]byte("foo bar")) + + var got bytes.Buffer + gotErr := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&got)) + + assert.NoError(t, gotErr) + assert.Equal(t, want.Bytes(), got.Bytes()) +} + +func TestWithResponseUnmarshal_bytes(t *testing.T) { + c := mockClient(make200Response("foo bar")) + want := []byte("foo bar") + + var got []byte + gotErr := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&got)) + + assert.NoError(t, gotErr) + assert.Equal(t, want, got) +} + +func TestWithResponseUnmarshal_json(t *testing.T) { + c := mockClient(make200Response(`{"foo": "bar"}`)) + type jsonStruct struct { + Foo string `json:"foo"` + } + want := jsonStruct{Foo: "bar"} + + var got jsonStruct + gotErr := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&got)) + + assert.NoError(t, gotErr) + assert.Equal(t, want, got) } func TestWithResponseHeader(t *testing.T) { @@ -40,10 +108,9 @@ func TestWithResponseHeader(t *testing.T) { }), }) - var out string - ctx := context.Background() - err := client.Do(ctx, "GET", "abc", - WithResponseHeader("Foo", &out)) - require.NoError(t, err) - require.Equal(t, "some", out) + var got string + gotErr := client.Do(context.Background(), "GET", "abc", WithResponseHeader("Foo", &got)) + + assert.NoError(t, gotErr) + assert.Equal(t, "some", got) }