diff --git a/expectations_test.go b/expectations_test.go index 0f6703c..3b48920 100644 --- a/expectations_test.go +++ b/expectations_test.go @@ -10,6 +10,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/assert" ) @@ -250,21 +251,37 @@ func TestMissingWithArgs(t *testing.T) { } } +type user struct { + ID int64 + name string + email pgtype.Text +} + +func (u *user) RewriteQuery(_ context.Context, _ *pgx.Conn, sql string, _ []any) (newSQL string, newArgs []any, err error) { + switch sql { + case "INSERT": + return `INSERT INTO users (username, email) VALUES ($1, $2) RETURNING id`, []any{u.name, u.email}, nil + case "UPDATE": + return `UPDATE users SET username = $1, email = $2 WHERE id = $1`, []any{u.ID, u.name, u.email}, nil + case "DELETE": + return `DELETE FROM users WHERE id = $1`, []any{u.ID}, nil + } + return +} + func TestWithRewrittenSQL(t *testing.T) { t.Parallel() mock, err := NewConn(QueryMatcherOption(QueryMatcherEqual)) a := assert.New(t) a.NoError(err) - mock.ExpectQuery(`INSERT INTO users(username) VALUES (@user)`). - WithArgs(pgx.NamedArgs{"user": "John"}). - WithRewrittenSQL(`INSERT INTO users(username) VALUES ($1)`). + u := user{name: "John", email: pgtype.Text{String: "john@example.com", Valid: true}} + mock.ExpectQuery(`INSERT`). + WithArgs(&u). + WithRewrittenSQL(`INSERT INTO users (username, email) VALUES ($1, $2) RETURNING id`). WillReturnRows() - _, err = mock.Query(context.Background(), - "INSERT INTO users(username) VALUES (@user)", - pgx.NamedArgs{"user": "John"}, - ) + _, err = mock.Query(context.Background(), "INSERT", &u) a.NoError(err) a.NoError(mock.ExpectationsWereMet()) @@ -280,3 +297,28 @@ func TestWithRewrittenSQL(t *testing.T) { a.Error(err) a.Error(mock.ExpectationsWereMet()) } + +func TestQueryRewriter(t *testing.T) { + t.Parallel() + mock, err := NewConn(QueryMatcherOption(QueryMatcherEqual)) + a := assert.New(t) + a.NoError(err) + + update := `UPDATE "user" SET email = @email, password = @password, updated_utc = @updated_utc WHERE id = @id` + + mock.ExpectExec(update).WithArgs(pgx.NamedArgs{ + "id": "mockUser.ID", + "email": "mockUser.Email", + "password": "mockUser.Password", + "updated_utc": AnyArg(), + }).WillReturnError(errPanic) + + _, err = mock.Exec(context.Background(), update, pgx.NamedArgs{ + "id": "mockUser.ID", + "email": "mockUser.Email", + "password": "mockUser.Password", + "updated_utc": time.Now().UTC(), + }) + a.Error(err) + a.NoError(mock.ExpectationsWereMet()) +} diff --git a/pgxmock.go b/pgxmock.go index db70c08..5e0257f 100644 --- a/pgxmock.go +++ b/pgxmock.go @@ -203,7 +203,6 @@ func (c *pgxmock) ExpectationsWereMet() error { func (c *pgxmock) ExpectQuery(expectedSQL string) *ExpectedQuery { e := &ExpectedQuery{} e.expectSQL = expectedSQL - e.expectRewrittenSQL = expectedSQL c.expectations = append(c.expectations, e) return e } @@ -235,7 +234,6 @@ func (c *pgxmock) ExpectBeginTx(txOptions pgx.TxOptions) *ExpectedBegin { func (c *pgxmock) ExpectExec(expectedSQL string) *ExpectedExec { e := &ExpectedExec{} e.expectSQL = expectedSQL - e.expectRewrittenSQL = expectedSQL c.expectations = append(c.expectations, e) return e } @@ -371,7 +369,7 @@ func (c *pgxmock) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, } func (c *pgxmock) Prepare(ctx context.Context, name, query string) (*pgconn.StatementDescription, error) { - ex, err := findExpectationFunc[*ExpectedPrepare](c, "Exec()", func(prepareExp *ExpectedPrepare) error { + ex, err := findExpectationFunc[*ExpectedPrepare](c, "Prepare()", func(prepareExp *ExpectedPrepare) error { if err := c.queryMatcher.Match(prepareExp.expectSQL, query); err != nil { return err } @@ -434,7 +432,7 @@ func (c *pgxmock) Query(ctx context.Context, sql string, args ...interface{}) (p } if rewrittenSQL, err := queryExp.argsMatches(sql, args); err != nil { return err - } else if rewrittenSQL != "" { + } else if rewrittenSQL != "" && queryExp.expectRewrittenSQL != "" { if err := c.queryMatcher.Match(queryExp.expectRewrittenSQL, rewrittenSQL); err != nil { return err } @@ -474,9 +472,8 @@ func (c *pgxmock) Exec(ctx context.Context, query string, args ...interface{}) ( } if rewrittenSQL, err := execExp.argsMatches(query, args); err != nil { return err - } else if rewrittenSQL != "" { + } else if rewrittenSQL != "" && execExp.expectRewrittenSQL != "" { if err := c.queryMatcher.Match(execExp.expectRewrittenSQL, rewrittenSQL); err != nil { - //pgx support QueryRewriter for arguments, now we can check if the query was actually rewriten return err } }