Skip to content

Commit

Permalink
add pgx.Namedargs support to WithArgs arg matcher
Browse files Browse the repository at this point in the history
Example usage:

```go
mock.ExpectExec("INSERT INTO users").
    WithArgs(pgx.NamedArgs{"name": "john", "created": AnyArg()}).
    WillReturnResult(NewResult("INSERT", 1))
```

closes #164
  • Loading branch information
dropwhile committed Oct 12, 2023
1 parent 1af56ff commit fb5e292
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 4 deletions.
74 changes: 74 additions & 0 deletions argument_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"testing"
"time"

pgx "github.com/jackc/pgx/v5"
)

type AnyTime struct{}
Expand Down Expand Up @@ -35,6 +37,30 @@ func TestAnyTimeArgument(t *testing.T) {
}
}

func TestAnyTimeNamedArgument(t *testing.T) {
t.Parallel()
mock, err := NewConn()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}

mock.ExpectExec("INSERT INTO users").
WithArgs(pgx.NamedArgs{"name": "john", "time": AnyTime{}}).
WillReturnResult(NewResult("INSERT", 1))

_, err = mock.Exec(context.Background(),
"INSERT INTO users(name, created_at) VALUES (@name, @time)",
pgx.NamedArgs{"name": "john", "time": time.Now()},
)
if err != nil {
t.Errorf("error '%s' was not expected, while inserting a row", err)
}

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

func TestByteSliceArgument(t *testing.T) {
t.Parallel()
mock, err := NewConn()
Expand All @@ -55,6 +81,31 @@ func TestByteSliceArgument(t *testing.T) {
}
}

func TestByteSliceNamedArgument(t *testing.T) {
t.Parallel()
mock, err := NewConn()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}

username := []byte("user")
mock.ExpectExec("INSERT INTO users").
WithArgs(pgx.NamedArgs{"user": username}).
WillReturnResult(NewResult("INSERT", 1))

_, err = mock.Exec(context.Background(),
"INSERT INTO users(username) VALUES (@user)",
pgx.NamedArgs{"user": username},
)
if err != nil {
t.Errorf("error '%s' was not expected, while inserting a row", err)
}

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

func TestAnyArgument(t *testing.T) {
t.Parallel()
mock, err := NewConn()
Expand All @@ -75,3 +126,26 @@ func TestAnyArgument(t *testing.T) {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}

func TestAnyNamedArgument(t *testing.T) {
t.Parallel()
mock, err := NewConn()
if err != nil {
t.Errorf("an error '%s' was not expected when opening a stub database connection", err)
}

mock.ExpectExec("INSERT INTO users").
WithArgs(pgx.NamedArgs{"name": "john", "created": AnyArg()}).
WillReturnResult(NewResult("INSERT", 1))

_, err = mock.Exec(context.Background(), "INSERT INTO users(name, created_at) VALUES (@name, @created)",
pgx.NamedArgs{"name": "john", "created": time.Now()},
)
if err != nil {
t.Errorf("error '%s' was not expected, while inserting a row", err)
}

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
45 changes: 45 additions & 0 deletions expectations.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package pgxmock

import (
"cmp"
"context"
"errors"
"fmt"
"reflect"
"slices"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -138,6 +140,15 @@ func (e *queryBasedExpectation) argsMatches(args []interface{}) error {
if len(args) != len(e.args) {
return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args))
}
if len(e.args) == 1 && len(args) == 1 {
if nargs, ok := e.args[0].(pgx.NamedArgs); ok {
vargs, ok := args[0].(pgx.NamedArgs)
if !ok {
return fmt.Errorf("expected named arguments, but instead got %T - %+v", args[0], args[0])
}
return e.namedArgsMatches(nargs, vargs)
}
}
for k, v := range args {
// custom argument matcher
if matcher, ok := e.args[k].(Argument); ok {
Expand All @@ -154,6 +165,40 @@ func (e *queryBasedExpectation) argsMatches(args []interface{}) error {
return nil
}

// Keys returns the keys of the map m.
// The keys will be a sorted order.
func keys[M ~map[K]V, K cmp.Ordered, V any](m M) []K {
r := make([]K, 0, len(m))
for k := range m {
r = append(r, k)
}
slices.Sort(r)
return r
}

func (e *queryBasedExpectation) namedArgsMatches(nargs, vargs pgx.NamedArgs) error {
vargsKeys := keys(vargs)
nargsKeys := keys(nargs)
if !slices.Equal(nargsKeys, vargsKeys) {
return fmt.Errorf("named argument keys mismatch: expected %#v, got %#v", nargsKeys, vargsKeys)
}
// check values
for k := range nargs {
darg := nargs[k]
v := vargs[k]
if matcher, ok := darg.(Argument); ok {
if !matcher.Match(v) {
return fmt.Errorf("matcher %T could not match named argument %s value %T - %+v", matcher, k, v, v)
}
continue
}
if !reflect.DeepEqual(darg, v) {
return fmt.Errorf("named argument %s expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v, v)
}
}
return nil
}

// ExpectedClose is used to manage pgx.Close expectation
// returned by pgxmock.ExpectClose
type ExpectedClose struct {
Expand Down
8 changes: 4 additions & 4 deletions expectations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestMaybe(t *testing.T) {
a := assert.New(t)
mock.ExpectPing().Maybe()
mock.ExpectBegin().Maybe()
mock.ExpectQuery("SET TIME ZONE 'Europe/Rome'").Maybe() //only if we're in Italy
mock.ExpectQuery("SET TIME ZONE 'Europe/Rome'").Maybe() // only if we're in Italy
cmdtag := pgconn.NewCommandTag("SELECT 1")
mock.ExpectExec("select").WillReturnResult(cmdtag)
mock.ExpectCommit().Maybe()
Expand Down Expand Up @@ -71,10 +71,10 @@ func TestCallModifier(t *testing.T) {
f()
a.Error(mock.Ping(c), "should raise error for cancelled context")

a.NoError(mock.ExpectationsWereMet()) //should produce no error since Ping() call is optional
a.NoError(mock.ExpectationsWereMet()) // should produce no error since Ping() call is optional

a.NoError(mock.Ping(ctx))
a.NoError(mock.ExpectationsWereMet()) //should produce no error since Ping() was called actually
a.NoError(mock.ExpectationsWereMet()) // should produce no error since Ping() was called actually
}

func TestCopyFromBug(t *testing.T) {
Expand Down Expand Up @@ -197,7 +197,7 @@ func TestBuildQuery(t *testing.T) {
}

func TestQueryRowScan(t *testing.T) {
mock, _ := NewConn() //TODO New(ValueConverterOption(CustomConverter{}))
mock, _ := NewConn() // TODO New(ValueConverterOption(CustomConverter{}))
query := `
SELECT
name,
Expand Down

0 comments on commit fb5e292

Please sign in to comment.