diff --git a/queryer.go b/queryer.go index 3d3587a..1f6ca06 100755 --- a/queryer.go +++ b/queryer.go @@ -93,16 +93,12 @@ type QueryerFunc func(*QueryInput) (interface{}, error) // Query invokes the provided function and writes the response to the receiver func (q QueryerFunc) Query(ctx context.Context, input *QueryInput, receiver interface{}) error { // invoke the handler - response, err := q(input) - if err != nil { - return err + response, responseErr := q(input) + if response != nil { + // assume the mock is writing the same kind as the receiver + reflect.ValueOf(receiver).Elem().Set(reflect.ValueOf(response)) } - - // assume the mock is writing the same kind as the receiver - reflect.ValueOf(receiver).Elem().Set(reflect.ValueOf(response)) - - // no errors - return nil + return responseErr // support partial success: always return the queryer error after setting the return data } type NetworkQueryer struct { diff --git a/queryerMultiOp.go b/queryerMultiOp.go index 4384b10..bef9cc4 100755 --- a/queryerMultiOp.go +++ b/queryerMultiOp.go @@ -73,11 +73,6 @@ func (q *MultiOpQueryer) Query(ctx context.Context, input *QueryInput, receiver } // format the result as needed - err = q.queryer.ExtractErrors(unmarshaled) - if err != nil { - return err - } - // assign the result under the data key to the receiver decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ TagName: "json", @@ -86,9 +81,11 @@ func (q *MultiOpQueryer) Query(ctx context.Context, input *QueryInput, receiver if err != nil { return err } + if err := decoder.Decode(unmarshaled["data"]); err != nil { + return err + } - // the only way for things to go wrong now happen while decoding - return decoder.Decode(unmarshaled["data"]) + return q.queryer.ExtractErrors(unmarshaled) } func (q *MultiOpQueryer) loadQuery(ctx context.Context, keys dataloader.Keys) []*dataloader.Result { diff --git a/queryerMultiOp_test.go b/queryerMultiOp_test.go index 188b169..8e33e47 100755 --- a/queryerMultiOp_test.go +++ b/queryerMultiOp_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io/ioutil" "net/http" + "net/http/httptest" "sync" "testing" "time" @@ -101,3 +102,31 @@ func TestMultiOpQueryer_batchesRequests(t *testing.T) { assert.Equal(t, map[string]interface{}{"nCalled": "2:2"}, result2) assert.Equal(t, map[string]interface{}{"nCalled": "2:2"}, result3) } + +func TestMultiOpQueryer_partial_success(t *testing.T) { + t.Parallel() + queryer := NewMultiOpQueryer("someURL", 1*time.Millisecond, 10).WithHTTPClient(&http.Client{ + Transport: roundTripFunc(func(*http.Request) *http.Response { + w := httptest.NewRecorder() + fmt.Fprint(w, ` + [ + { + "data": { + "foo": "bar" + }, + "errors": [ + {"message": "baz"} + ] + } + ] + `) + return w.Result() + }), + }) + var result any + err := queryer.Query(context.Background(), &QueryInput{Query: "query { hello }"}, &result) + assert.Equal(t, map[string]any{ + "foo": "bar", + }, result) + assert.EqualError(t, err, "baz") +} diff --git a/queryerNetwork.go b/queryerNetwork.go index 692f930..e2f5719 100644 --- a/queryerNetwork.go +++ b/queryerNetwork.go @@ -3,8 +3,9 @@ package graphql import ( "context" "encoding/json" - "github.com/mitchellh/mapstructure" "net/http" + + "github.com/mitchellh/mapstructure" ) // SingleRequestQueryer sends the query to a url and returns the response @@ -80,11 +81,6 @@ func (q *SingleRequestQueryer) Query(ctx context.Context, input *QueryInput, rec return err } - // otherwise we have to copy the response onto the receiver - if err = q.queryer.ExtractErrors(result); err != nil { - return err - } - // assign the result under the data key to the receiver decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ TagName: "json", @@ -93,7 +89,10 @@ func (q *SingleRequestQueryer) Query(ctx context.Context, input *QueryInput, rec if err != nil { return err } + if err = decoder.Decode(result["data"]); err != nil { + return err + } - // the only way for things to go wrong now happen while decoding - return decoder.Decode(result["data"]) + // finally extract errors, if any, and return them + return q.queryer.ExtractErrors(result) // TODO add unit tests! } diff --git a/queryer_test.go b/queryer_test.go index 7b574ce..5f8e1f1 100755 --- a/queryer_test.go +++ b/queryer_test.go @@ -8,6 +8,7 @@ import ( "fmt" "io/ioutil" "net/http" + "net/http/httptest" "testing" "time" @@ -58,6 +59,22 @@ func TestQueryerFunc_failure(t *testing.T) { assert.Equal(t, expected, err) } +func TestQueryerFunc_partial_success(t *testing.T) { + t.Parallel() + someData := map[string]interface{}{"foo": "bar"} + someError := errors.New("baz") + + queryer := QueryerFunc(func(*QueryInput) (interface{}, error) { + return someData, someError + }) + + result := map[string]interface{}{} + + err := queryer.Query(context.Background(), &QueryInput{}, &result) + assert.ErrorIs(t, err, someError) + assert.Equal(t, someData, result) +} + func TestHTTPQueryerBasicCases(t *testing.T) { // this test run a suite of tests for every queryer in the table queryerTable := []struct { @@ -115,7 +132,6 @@ func TestHTTPQueryerBasicCases(t *testing.T) { // serialize the json we want to send back marshaled, err := json.Marshal(result) - // if something went wrong if err != nil { return &http.Response{ @@ -250,7 +266,6 @@ func TestHTTPQueryerBasicCases(t *testing.T) { assert.Nil(t, err) } }) - } }) @@ -353,9 +368,10 @@ func TestQueryerWithMiddlewares(t *testing.T) { for _, row := range queryerTable { t.Run(row.name, func(t *testing.T) { t.Run("Middleware Failures", func(t *testing.T) { + someErr := errors.New("This One") queryer := row.queryer.WithMiddlewares([]NetworkMiddleware{ func(r *http.Request) error { - return errors.New("This One") + return someErr }, }) @@ -366,13 +382,7 @@ func TestQueryerWithMiddlewares(t *testing.T) { // fire the query err := queryer.Query(context.Background(), input, &map[string]interface{}{}) - if err == nil { - t.Error("Did not enounter an error when we should have") - return - } - if err.Error() != "This One" { - t.Errorf("Did not encountered expected error message: Expected 'This One', found %v", err.Error()) - } + assert.ErrorIs(t, err, someErr) }) t.Run("Middlware success", func(t *testing.T) { @@ -434,3 +444,29 @@ func TestQueryerWithMiddlewares(t *testing.T) { }) } } + +func TestNetworkQueryer_partial_success(t *testing.T) { + t.Parallel() + queryer := NewSingleRequestQueryer("someURL").WithHTTPClient(&http.Client{ + Transport: roundTripFunc(func(*http.Request) *http.Response { + w := httptest.NewRecorder() + fmt.Fprint(w, ` + { + "data": { + "foo": "bar" + }, + "errors": [ + {"message": "baz"} + ] + } + `) + return w.Result() + }), + }) + var result any + err := queryer.Query(context.Background(), &QueryInput{Query: "query { hello }"}, &result) + assert.Equal(t, map[string]any{ + "foo": "bar", + }, result) + assert.EqualError(t, err, "baz") +}