From 81da36d807cb83085aa685c25a6db07824a99c89 Mon Sep 17 00:00:00 2001 From: elij Date: Thu, 12 Oct 2023 13:14:14 -0700 Subject: [PATCH] add pgx.Namedargs support to WithArgs arg matcher Example usage: mock.ExpectExec("INSERT INTO users"). WithArgs(pgx.NamedArgs{"name": "john", "created": AnyArg()}). WillReturnResult(NewResult("INSERT", 1)) closes #164 --- argument_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++++++++ expectations.go | 38 +++++++++++++++++++++---- pgxmock.go | 4 +-- 3 files changed, 109 insertions(+), 7 deletions(-) diff --git a/argument_test.go b/argument_test.go index e760d93..75b473e 100644 --- a/argument_test.go +++ b/argument_test.go @@ -4,6 +4,8 @@ import ( "context" "testing" "time" + + pgx "github.com/jackc/pgx/v5" ) type AnyTime struct{} @@ -35,6 +37,30 @@ func TestAnyTimeArgument(t *testing.T) { } } +func TestAnyTimeNamedArgument(t *testing.T) { + t.Parallel() + mock, err := NewConn() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + + mock.ExpectExec("INSERT INTO users"). + WithArgs(pgx.NamedArgs{"name": "john", "time": AnyTime{}}). + WillReturnResult(NewResult("INSERT", 1)) + + _, err = mock.Exec(context.Background(), + "INSERT INTO users(name, created_at) VALUES (@name, @time)", + pgx.NamedArgs{"name": "john", "time": time.Now()}, + ) + if err != nil { + t.Errorf("error '%s' was not expected, while inserting a row", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + func TestByteSliceArgument(t *testing.T) { t.Parallel() mock, err := NewConn() @@ -55,6 +81,31 @@ func TestByteSliceArgument(t *testing.T) { } } +func TestByteSliceNamedArgument(t *testing.T) { + t.Parallel() + mock, err := NewConn() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + + username := []byte("user") + mock.ExpectExec("INSERT INTO users"). + WithArgs(pgx.NamedArgs{"user": username}). + WillReturnResult(NewResult("INSERT", 1)) + + _, err = mock.Exec(context.Background(), + "INSERT INTO users(username) VALUES (@user)", + pgx.NamedArgs{"user": username}, + ) + if err != nil { + t.Errorf("error '%s' was not expected, while inserting a row", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + func TestAnyArgument(t *testing.T) { t.Parallel() mock, err := NewConn() @@ -75,3 +126,26 @@ func TestAnyArgument(t *testing.T) { t.Errorf("there were unfulfilled expectations: %s", err) } } + +func TestAnyNamedArgument(t *testing.T) { + t.Parallel() + mock, err := NewConn() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + + mock.ExpectExec("INSERT INTO users"). + WithArgs("john", AnyArg()). + WillReturnResult(NewResult("INSERT", 1)) + + _, err = mock.Exec(context.Background(), "INSERT INTO users(name, created_at) VALUES (@name, @created)", + pgx.NamedArgs{"name": "john", "created": time.Now()}, + ) + if err != nil { + t.Errorf("error '%s' was not expected, while inserting a row", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/expectations.go b/expectations.go index 884e3fc..691e85d 100644 --- a/expectations.go +++ b/expectations.go @@ -134,20 +134,48 @@ type queryBasedExpectation struct { args []interface{} } -func (e *queryBasedExpectation) argsMatches(args []interface{}) error { - if len(args) != len(e.args) { - return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) +func (e *queryBasedExpectation) argsMatches(sql string, args []interface{}) error { + eargs := e.args + // check for any query rewriters + // according to current pgx docs, a QueryRewriter is only supported as the first + // argument. + if len(args) == 1 { + if qrw, ok := args[0].(pgx.QueryRewriter); ok { + // note: pgx.Conn is not currently used by the query rewriter, but is part + // of the method signature, so just create an empty pointer for now. + _, newArgs, err := qrw.RewriteQuery(context.Background(), new(pgx.Conn), sql, args) + if err != nil { + return fmt.Errorf("error rewriting query: %w", err) + } + args = newArgs + } + // also do rewriting on the expected args if a QueryRewriter is present + if len(eargs) == 1 { + if qrw, ok := eargs[0].(pgx.QueryRewriter); ok { + // note: pgx.Conn is not currently used by the query rewriter, but is part + // of the method signature, so just create an empty pointer for now. + _, newArgs, err := qrw.RewriteQuery(context.Background(), new(pgx.Conn), sql, eargs) + if err != nil { + return fmt.Errorf("error rewriting query expectation: %w", err) + } + eargs = newArgs + } + } + } + + if len(args) != len(eargs) { + return fmt.Errorf("expected %d, but got %d arguments", len(eargs), len(args)) } for k, v := range args { // custom argument matcher - if matcher, ok := e.args[k].(Argument); ok { + if matcher, ok := eargs[k].(Argument); ok { if !matcher.Match(v) { return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) } continue } - if darg := e.args[k]; !reflect.DeepEqual(darg, v) { + if darg := eargs[k]; !reflect.DeepEqual(darg, v) { return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v, v) } } diff --git a/pgxmock.go b/pgxmock.go index 6051301..f2a46d9 100644 --- a/pgxmock.go +++ b/pgxmock.go @@ -430,7 +430,7 @@ func (c *pgxmock) Query(ctx context.Context, sql string, args ...interface{}) (p if err := c.queryMatcher.Match(queryExp.expectSQL, sql); err != nil { return err } - if err := queryExp.argsMatches(args); err != nil { + if err := queryExp.argsMatches(sql, args); err != nil { return err } if queryExp.err == nil && queryExp.rows == nil { @@ -466,7 +466,7 @@ func (c *pgxmock) Exec(ctx context.Context, query string, args ...interface{}) ( if err := c.queryMatcher.Match(execExp.expectSQL, query); err != nil { return err } - if err := execExp.argsMatches(args); err != nil { + if err := execExp.argsMatches(query, args); err != nil { return err } if execExp.result.String() == "" && execExp.err == nil {