From fb5e2928cbdd3e091dd2ae9db5331bf72632bdb0 Mon Sep 17 00:00:00 2001 From: elij Date: Wed, 11 Oct 2023 19:47:32 -0700 Subject: [PATCH] add pgx.Namedargs support to WithArgs arg matcher Example usage: ```go mock.ExpectExec("INSERT INTO users"). WithArgs(pgx.NamedArgs{"name": "john", "created": AnyArg()}). WillReturnResult(NewResult("INSERT", 1)) ``` closes #164 --- argument_test.go | 74 ++++++++++++++++++++++++++++++++++++++++++++ expectations.go | 45 +++++++++++++++++++++++++++ expectations_test.go | 8 ++--- 3 files changed, 123 insertions(+), 4 deletions(-) diff --git a/argument_test.go b/argument_test.go index e760d93..a6e4c15 100644 --- a/argument_test.go +++ b/argument_test.go @@ -4,6 +4,8 @@ import ( "context" "testing" "time" + + pgx "github.com/jackc/pgx/v5" ) type AnyTime struct{} @@ -35,6 +37,30 @@ func TestAnyTimeArgument(t *testing.T) { } } +func TestAnyTimeNamedArgument(t *testing.T) { + t.Parallel() + mock, err := NewConn() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + + mock.ExpectExec("INSERT INTO users"). + WithArgs(pgx.NamedArgs{"name": "john", "time": AnyTime{}}). + WillReturnResult(NewResult("INSERT", 1)) + + _, err = mock.Exec(context.Background(), + "INSERT INTO users(name, created_at) VALUES (@name, @time)", + pgx.NamedArgs{"name": "john", "time": time.Now()}, + ) + if err != nil { + t.Errorf("error '%s' was not expected, while inserting a row", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + func TestByteSliceArgument(t *testing.T) { t.Parallel() mock, err := NewConn() @@ -55,6 +81,31 @@ func TestByteSliceArgument(t *testing.T) { } } +func TestByteSliceNamedArgument(t *testing.T) { + t.Parallel() + mock, err := NewConn() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + + username := []byte("user") + mock.ExpectExec("INSERT INTO users"). + WithArgs(pgx.NamedArgs{"user": username}). + WillReturnResult(NewResult("INSERT", 1)) + + _, err = mock.Exec(context.Background(), + "INSERT INTO users(username) VALUES (@user)", + pgx.NamedArgs{"user": username}, + ) + if err != nil { + t.Errorf("error '%s' was not expected, while inserting a row", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} + func TestAnyArgument(t *testing.T) { t.Parallel() mock, err := NewConn() @@ -75,3 +126,26 @@ func TestAnyArgument(t *testing.T) { t.Errorf("there were unfulfilled expectations: %s", err) } } + +func TestAnyNamedArgument(t *testing.T) { + t.Parallel() + mock, err := NewConn() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + + mock.ExpectExec("INSERT INTO users"). + WithArgs(pgx.NamedArgs{"name": "john", "created": AnyArg()}). + WillReturnResult(NewResult("INSERT", 1)) + + _, err = mock.Exec(context.Background(), "INSERT INTO users(name, created_at) VALUES (@name, @created)", + pgx.NamedArgs{"name": "john", "created": time.Now()}, + ) + if err != nil { + t.Errorf("error '%s' was not expected, while inserting a row", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/expectations.go b/expectations.go index 884e3fc..dd14bde 100644 --- a/expectations.go +++ b/expectations.go @@ -1,10 +1,12 @@ package pgxmock import ( + "cmp" "context" "errors" "fmt" "reflect" + "slices" "strings" "sync" "time" @@ -138,6 +140,15 @@ func (e *queryBasedExpectation) argsMatches(args []interface{}) error { if len(args) != len(e.args) { return fmt.Errorf("expected %d, but got %d arguments", len(e.args), len(args)) } + if len(e.args) == 1 && len(args) == 1 { + if nargs, ok := e.args[0].(pgx.NamedArgs); ok { + vargs, ok := args[0].(pgx.NamedArgs) + if !ok { + return fmt.Errorf("expected named arguments, but instead got %T - %+v", args[0], args[0]) + } + return e.namedArgsMatches(nargs, vargs) + } + } for k, v := range args { // custom argument matcher if matcher, ok := e.args[k].(Argument); ok { @@ -154,6 +165,40 @@ func (e *queryBasedExpectation) argsMatches(args []interface{}) error { return nil } +// Keys returns the keys of the map m. +// The keys will be a sorted order. +func keys[M ~map[K]V, K cmp.Ordered, V any](m M) []K { + r := make([]K, 0, len(m)) + for k := range m { + r = append(r, k) + } + slices.Sort(r) + return r +} + +func (e *queryBasedExpectation) namedArgsMatches(nargs, vargs pgx.NamedArgs) error { + vargsKeys := keys(vargs) + nargsKeys := keys(nargs) + if !slices.Equal(nargsKeys, vargsKeys) { + return fmt.Errorf("named argument keys mismatch: expected %#v, got %#v", nargsKeys, vargsKeys) + } + // check values + for k := range nargs { + darg := nargs[k] + v := vargs[k] + if matcher, ok := darg.(Argument); ok { + if !matcher.Match(v) { + return fmt.Errorf("matcher %T could not match named argument %s value %T - %+v", matcher, k, v, v) + } + continue + } + if !reflect.DeepEqual(darg, v) { + return fmt.Errorf("named argument %s expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v, v) + } + } + return nil +} + // ExpectedClose is used to manage pgx.Close expectation // returned by pgxmock.ExpectClose type ExpectedClose struct { diff --git a/expectations_test.go b/expectations_test.go index e8e9eac..f7ae59e 100644 --- a/expectations_test.go +++ b/expectations_test.go @@ -34,7 +34,7 @@ func TestMaybe(t *testing.T) { a := assert.New(t) mock.ExpectPing().Maybe() mock.ExpectBegin().Maybe() - mock.ExpectQuery("SET TIME ZONE 'Europe/Rome'").Maybe() //only if we're in Italy + mock.ExpectQuery("SET TIME ZONE 'Europe/Rome'").Maybe() // only if we're in Italy cmdtag := pgconn.NewCommandTag("SELECT 1") mock.ExpectExec("select").WillReturnResult(cmdtag) mock.ExpectCommit().Maybe() @@ -71,10 +71,10 @@ func TestCallModifier(t *testing.T) { f() a.Error(mock.Ping(c), "should raise error for cancelled context") - a.NoError(mock.ExpectationsWereMet()) //should produce no error since Ping() call is optional + a.NoError(mock.ExpectationsWereMet()) // should produce no error since Ping() call is optional a.NoError(mock.Ping(ctx)) - a.NoError(mock.ExpectationsWereMet()) //should produce no error since Ping() was called actually + a.NoError(mock.ExpectationsWereMet()) // should produce no error since Ping() was called actually } func TestCopyFromBug(t *testing.T) { @@ -197,7 +197,7 @@ func TestBuildQuery(t *testing.T) { } func TestQueryRowScan(t *testing.T) { - mock, _ := NewConn() //TODO New(ValueConverterOption(CustomConverter{})) + mock, _ := NewConn() // TODO New(ValueConverterOption(CustomConverter{})) query := ` SELECT name,