diff --git a/argument_test.go b/argument_test.go index ab8b77e..081b536 100644 --- a/argument_test.go +++ b/argument_test.go @@ -2,10 +2,12 @@ package pgxmock import ( "context" + "errors" "testing" "time" pgx "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/assert" ) type AnyTime struct{} @@ -81,6 +83,41 @@ func TestByteSliceArgument(t *testing.T) { } } +type failQryRW struct { + pgx.QueryRewriter +} + +func (fqrw failQryRW) RewriteQuery(_ context.Context, _ *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return "", nil, errors.New("cannot rewrite query " + sql) +} + +func TestExpectQueryRewriterFail(t *testing.T) { + + t.Parallel() + mock, err := NewConn() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + + mock.ExpectExec(`INSERT INTO users\(username\) VALUES \(\@user\)`). + WithArgs(failQryRW{}) + _, err = mock.Exec(context.Background(), "INSERT INTO users(username) VALUES (@user)", "baz") + assert.Error(t, err) +} + +func TestQueryRewriterFail(t *testing.T) { + + t.Parallel() + mock, err := NewConn() + if err != nil { + t.Errorf("an error '%s' was not expected when opening a stub database connection", err) + } + mock.ExpectExec(`INSERT INTO .+`).WithArgs("foo") + _, err = mock.Exec(context.Background(), "INSERT INTO users(username) VALUES (@user)", failQryRW{}) + assert.Error(t, err) + +} + func TestByteSliceNamedArgument(t *testing.T) { t.Parallel() mock, err := NewConn()