Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Avoid loading the response bodies twice in memory when parsing bytes.Buffer #984

Merged
merged 6 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 31 additions & 20 deletions httpclient/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ func WithResponseHeader(key string, value *string) DoOption {
}
}

// WithResponseUnmarshal unmarshals the response body into response. The
// supported response types are the following:
// - *bytes.Buffer,
// - *io.ReadCloser,
// - *[]byte,
// - a pointer to a struct with a Contents io.ReadCloser field,
// - a pointer to a struct representing a JSON object.
//
// If response is a pointer to a io.ReadCloser or a struct with a io.ReadCloser
// field name "Contents", then the response io.ReadCloser is set to the value of
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The key thing to note here is that currently, our autogeneration only includes a Contents io.ReadCloser field for APIs that return non-JSON responses. We don't attempt to parse these at all and simply return them to the caller.

// the body's reader without actually reading it.
func WithResponseUnmarshal(response any) DoOption {
return DoOption{
in: func(r *http.Request) error {
Expand All @@ -50,45 +61,45 @@ func WithResponseUnmarshal(response any) DoOption {
if err != nil {
return err
}
// If the field contains a "Content" field of type bytes.Buffer, write the body over there and return.
if field, ok := findContentsField(response, body); ok {
// If so, set the value

if field, ok := findContentsField(response); ok {
field.Set(reflect.ValueOf(body.ReadCloser))
return nil
}

// If the destination is bytes.Buffer, write the body over there
if raw, ok := response.(*io.ReadCloser); ok {
*raw = body.ReadCloser
if reader, ok := response.(*io.ReadCloser); ok {
*reader = body.ReadCloser
return nil
}
if buffer, ok := response.(*bytes.Buffer); ok {
defer body.ReadCloser.Close()
_, err := buffer.ReadFrom(body.ReadCloser)
return err
}

// At this point, fully read the content of the body and use it
// to populate the response object (whether it is a slice of bytes
// or a JSON object).
defer body.ReadCloser.Close()
bs, err := io.ReadAll(body.ReadCloser)
bodyBytes, err := io.ReadAll(body.ReadCloser)
if err != nil {
return fmt.Errorf("failed to read response body: %w", err)
}
if len(bs) == 0 {
if len(bodyBytes) == 0 {
return nil
}
// If the destination is a byte slice or buffer, pass the body verbatim.
if raw, ok := response.(*[]byte); ok {
*raw = bs
if bs, ok := response.(*[]byte); ok {
*bs = bodyBytes
return nil
}
if raw, ok := response.(*bytes.Buffer); ok {
_, err := raw.Write(bs)
return err
}
err = json.Unmarshal(bs, &response)
if err != nil {
return apierr.MakeUnexpectedError(body.Response, err, body.RequestBody.DebugBytes, bs)
if err = json.Unmarshal(bodyBytes, &response); err != nil {
return apierr.MakeUnexpectedError(body.Response, err, body.RequestBody.DebugBytes, bodyBytes)
}
return nil
},
}
}

func findContentsField(response any, body *common.ResponseWrapper) (*reflect.Value, bool) {
func findContentsField(response any) (*reflect.Value, bool) {
value := reflect.ValueOf(response)
value = reflect.Indirect(value)
if value.Kind() != reflect.Struct {
Expand Down
113 changes: 102 additions & 11 deletions httpclient/response_test.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,120 @@
package httpclient

import (
"bytes"
"context"
"io"
"net/http"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
)

func TestSimpleRequestRawResponse(t *testing.T) {
c := NewApiClient(ClientConfig{
func make200Response(body string) *http.Response {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader(body)),
}
}

func mockClient(resp *http.Response) *ApiClient {
return NewApiClient(ClientConfig{
Transport: hc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader("Hello, world!")),
Request: r,
}, nil
resp.Request = r
return resp, nil
}),
})
var raw []byte
err := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&raw))
require.NoError(t, err)
require.Equal(t, "Hello, world!", string(raw))
}

func TestWithResponseUnmarshal_structWithContent(t *testing.T) {
type structWithContents = struct {
Contents io.ReadCloser
}
want := structWithContents{
Contents: io.NopCloser(strings.NewReader("foo bar")),
}

var got structWithContents
c := mockClient(make200Response("foo bar"))
gotErr := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&got))

if gotErr != nil {
t.Errorf("WithResponseUnmarshal(): want no error, got: %s", gotErr)
}
renaudhartert-db marked this conversation as resolved.
Show resolved Hide resolved

wantBytes, _ := io.ReadAll(want.Contents)
gotBytes, _ := io.ReadAll(got.Contents)
if diff := cmp.Diff(wantBytes, gotBytes); diff != "" {
t.Errorf("WithResponseUnmarshal(): want != got: (-want +got):\n%s", diff)
}
}

func TestWithResponseUnmarshal_readCloser(t *testing.T) {
want := io.NopCloser(strings.NewReader("foo bar"))

var got io.ReadCloser
c := mockClient(make200Response("foo bar"))
gotErr := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&got))

if gotErr != nil {
t.Errorf("WithResponseUnmarshal(): want no error, got: %s", gotErr)
}

wantBytes, _ := io.ReadAll(want)
gotBytes, _ := io.ReadAll(got)
if diff := cmp.Diff(wantBytes, gotBytes); diff != "" {
t.Errorf("WithResponseUnmarshal(): want != got: (-want +got):\n%s", diff)
}
}

func TestWithResponseUnmarshal_byteBuffer(t *testing.T) {
want := bytes.NewBuffer([]byte("foo bar"))

var got bytes.Buffer
c := mockClient(make200Response("foo bar"))
gotErr := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&got))

if gotErr != nil {
t.Errorf("WithResponseUnmarshal(): want no error, got: %s", gotErr)
}
if diff := cmp.Diff(want.Bytes(), got.Bytes()); diff != "" {
t.Errorf("WithResponseUnmarshal(): want != got: (-want +got):\n%s", diff)
}
}

func TestWithResponseUnmarshal_bytes(t *testing.T) {
want := []byte("foo bar")

var got []byte
c := mockClient(make200Response("foo bar"))
gotErr := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&got))

if gotErr != nil {
t.Errorf("WithResponseUnmarshal(): want no error, got: %s", gotErr)
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("WithResponseUnmarshal(): want != got: (-want +got):\n%s", diff)
}
}

func TestWithResponseUnmarshal_json(t *testing.T) {
type jsonStruct struct {
Foo string `json:"foo"`
}
want := jsonStruct{Foo: "bar"}

var got jsonStruct
c := mockClient(make200Response(`{"foo": "bar"}`))
gotErr := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&got))

if gotErr != nil {
t.Errorf("WithResponseUnmarshal(): want no error, got: %s", gotErr)
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("WithResponseUnmarshal(): want != got: (-want +got):\n%s", diff)
}
}

func TestWithResponseHeader(t *testing.T) {
Expand Down
Loading