Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add WithNamedArgs arg matcher to make working with pgx.NamedArgs easier #165

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions argument_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"testing"
"time"

pgx "github.com/jackc/pgx/v5"
)

type AnyTime struct{}
Expand Down Expand Up @@ -35,6 +37,30 @@ func TestAnyTimeArgument(t *testing.T) {
}
}

func TestAnyTimeNamedArgument(t *testing.T) {
t.Parallel()
mock, err := NewConn()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}

mock.ExpectExec("INSERT INTO users").
WithArgs(pgx.NamedArgs{"name": "john", "time": AnyTime{}}).
WillReturnResult(NewResult("INSERT", 1))

_, err = mock.Exec(context.Background(),
"INSERT INTO users(name, created_at) VALUES (@name, @time)",
pgx.NamedArgs{"name": "john", "time": time.Now()},
)
if err != nil {
t.Errorf("error '%s' was not expected, while inserting a row", err)
}

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

func TestByteSliceArgument(t *testing.T) {
t.Parallel()
mock, err := NewConn()
Expand All @@ -55,6 +81,31 @@ func TestByteSliceArgument(t *testing.T) {
}
}

func TestByteSliceNamedArgument(t *testing.T) {
t.Parallel()
mock, err := NewConn()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}

username := []byte("user")
mock.ExpectExec("INSERT INTO users").
WithArgs(pgx.NamedArgs{"user": username}).
WillReturnResult(NewResult("INSERT", 1))

_, err = mock.Exec(context.Background(),
"INSERT INTO users(username) VALUES (@user)",
pgx.NamedArgs{"user": username},
)
if err != nil {
t.Errorf("error '%s' was not expected, while inserting a row", err)
}

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

func TestAnyArgument(t *testing.T) {
t.Parallel()
mock, err := NewConn()
Expand All @@ -75,3 +126,26 @@ func TestAnyArgument(t *testing.T) {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

func TestAnyNamedArgument(t *testing.T) {
t.Parallel()
mock, err := NewConn()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}

mock.ExpectExec("INSERT INTO users").
WithArgs("john", AnyArg()).
WillReturnResult(NewResult("INSERT", 1))

_, err = mock.Exec(context.Background(), "INSERT INTO users(name, created_at) VALUES (@name, @created)",
pgx.NamedArgs{"name": "john", "created": time.Now()},
)
if err != nil {
t.Errorf("error '%s' was not expected, while inserting a row", err)
}

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
38 changes: 33 additions & 5 deletions expectations.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,20 +134,48 @@ type queryBasedExpectation struct {
args []interface{}
}

func (e *queryBasedExpectation) argsMatches(args []interface{}) error {
if len(args) != len(e.args) {
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
func (e *queryBasedExpectation) argsMatches(sql string, args []interface{}) error {
eargs := e.args
// check for any query rewriters
// according to current pgx docs, a QueryRewriter is only supported as the first
// argument.
if len(args) == 1 {
if qrw, ok := args[0].(pgx.QueryRewriter); ok {
// note: pgx.Conn is not currently used by the query rewriter, but is part
// of the method signature, so just create an empty pointer for now.
_, newArgs, err := qrw.RewriteQuery(context.Background(), new(pgx.Conn), sql, args)
pashagolub marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return fmt.Errorf("error rewriting query: %w", err)
}
args = newArgs
}
// also do rewriting on the expected args if a QueryRewriter is present
if len(eargs) == 1 {
if qrw, ok := eargs[0].(pgx.QueryRewriter); ok {
// note: pgx.Conn is not currently used by the query rewriter, but is part
// of the method signature, so just create an empty pointer for now.
_, newArgs, err := qrw.RewriteQuery(context.Background(), new(pgx.Conn), sql, eargs)
if err != nil {
return fmt.Errorf("error rewriting query expectation: %w", err)
}
eargs = newArgs
}
}
}

if len(args) != len(eargs) {
return fmt.Errorf("expected %d, but got %d arguments", len(eargs), len(args))
}
for k, v := range args {
// custom argument matcher
if matcher, ok := e.args[k].(Argument); ok {
if matcher, ok := eargs[k].(Argument); ok {
if !matcher.Match(v) {
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
}
continue
}

if darg := e.args[k]; !reflect.DeepEqual(darg, v) {
if darg := eargs[k]; !reflect.DeepEqual(darg, v) {
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v, v)
}
}
Expand Down
4 changes: 2 additions & 2 deletions pgxmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ func (c *pgxmock) Query(ctx context.Context, sql string, args ...interface{}) (p
if err := c.queryMatcher.Match(queryExp.expectSQL, sql); err != nil {
return err
}
if err := queryExp.argsMatches(args); err != nil {
if err := queryExp.argsMatches(sql, args); err != nil {
pashagolub marked this conversation as resolved.
Show resolved Hide resolved
return err
}
if queryExp.err == nil && queryExp.rows == nil {
Expand Down Expand Up @@ -466,7 +466,7 @@ func (c *pgxmock) Exec(ctx context.Context, query string, args ...interface{}) (
if err := c.queryMatcher.Match(execExp.expectSQL, query); err != nil {
return err
}
if err := execExp.argsMatches(args); err != nil {
if err := execExp.argsMatches(query, args); err != nil {
return err
}
if execExp.result.String() == "" && execExp.err == nil {
Expand Down
Loading