Skip to content

Commit

Permalink
add expectations for rewritten query
Browse files Browse the repository at this point in the history
  • Loading branch information
pashagolub committed Oct 15, 2023
1 parent 81da36d commit 03412d6
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 13 deletions.
3 changes: 2 additions & 1 deletion argument_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ func TestByteSliceNamedArgument(t *testing.T) {
}

username := []byte("user")
mock.ExpectExec("INSERT INTO users").
mock.ExpectExec(`INSERT INTO users\(username\) VALUES \(\@user\)`).
WithArgs(pgx.NamedArgs{"user": username}).
WithRewrittenSQL("INSERT INTO users(username) VALUES ($1)").
WillReturnResult(NewResult("INSERT", 1))

_, err = mock.Exec(context.Background(),
Expand Down
36 changes: 26 additions & 10 deletions expectations.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,12 @@ func (e *commonExpectation) String() string {

// queryBasedExpectation is a base class that adds a query matching logic
type queryBasedExpectation struct {
expectSQL string
args []interface{}
expectSQL string
expectRewrittenSQL string
args []interface{}
}

func (e *queryBasedExpectation) argsMatches(sql string, args []interface{}) error {
func (e *queryBasedExpectation) argsMatches(sql string, args []interface{}) (rewrittenSQL string, err error) {
eargs := e.args
// check for any query rewriters
// according to current pgx docs, a QueryRewriter is only supported as the first
Expand All @@ -143,9 +144,9 @@ func (e *queryBasedExpectation) argsMatches(sql string, args []interface{}) erro
if qrw, ok := args[0].(pgx.QueryRewriter); ok {
// note: pgx.Conn is not currently used by the query rewriter, but is part
// of the method signature, so just create an empty pointer for now.
_, newArgs, err := qrw.RewriteQuery(context.Background(), new(pgx.Conn), sql, args)
rewrittenSQL, newArgs, err := qrw.RewriteQuery(context.Background(), new(pgx.Conn), sql, args)
if err != nil {
return fmt.Errorf("error rewriting query: %w", err)
return rewrittenSQL, fmt.Errorf("error rewriting query: %w", err)
}
args = newArgs
}
Expand All @@ -156,30 +157,31 @@ func (e *queryBasedExpectation) argsMatches(sql string, args []interface{}) erro
// of the method signature, so just create an empty pointer for now.
_, newArgs, err := qrw.RewriteQuery(context.Background(), new(pgx.Conn), sql, eargs)
if err != nil {
return fmt.Errorf("error rewriting query expectation: %w", err)
return "", fmt.Errorf("error rewriting query expectation: %w", err)
}
e.expectRewrittenSQL = rewrittenSQL
eargs = newArgs
}
}
}

if len(args) != len(eargs) {
return fmt.Errorf("expected %d, but got %d arguments", len(eargs), len(args))
return rewrittenSQL, fmt.Errorf("expected %d, but got %d arguments", len(eargs), len(args))
}
for k, v := range args {
// custom argument matcher
if matcher, ok := eargs[k].(Argument); ok {
if !matcher.Match(v) {
return fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
return rewrittenSQL, fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k])
}
continue
}

if darg := eargs[k]; !reflect.DeepEqual(darg, v) {
return fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v, v)
return rewrittenSQL, fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v, v)
}
}
return nil
return
}

// ExpectedClose is used to manage pgx.Close expectation
Expand Down Expand Up @@ -236,6 +238,13 @@ func (e *ExpectedExec) WithArgs(args ...interface{}) *ExpectedExec {
return e
}

// WithRewrittenSQL will match given expected expression to a rewritten SQL statement by
// an pgx.QueryRewriter argument
func (e *ExpectedExec) WithRewrittenSQL(sql string) *ExpectedExec {
e.expectRewrittenSQL = sql
return e
}

// String returns string representation
func (e *ExpectedExec) String() string {
msg := "ExpectedExec => expecting call to Exec():\n"
Expand Down Expand Up @@ -352,6 +361,13 @@ func (e *ExpectedQuery) WithArgs(args ...interface{}) *ExpectedQuery {
return e
}

// WithRewrittenSQL will match given expected expression to a rewritten SQL statement by
// an pgx.QueryRewriter argument
func (e *ExpectedQuery) WithRewrittenSQL(sql string) *ExpectedQuery {
e.expectRewrittenSQL = sql
return e
}

// RowsWillBeClosed expects this query rows to be closed.
func (e *ExpectedQuery) RowsWillBeClosed() *ExpectedQuery {
e.rowsMustBeClosed = true
Expand Down
19 changes: 17 additions & 2 deletions pgxmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ func (c *pgxmock) ExpectationsWereMet() error {
func (c *pgxmock) ExpectQuery(expectedSQL string) *ExpectedQuery {
e := &ExpectedQuery{}
e.expectSQL = expectedSQL
e.expectRewrittenSQL = expectedSQL
c.expectations = append(c.expectations, e)
return e
}
Expand Down Expand Up @@ -234,6 +235,7 @@ func (c *pgxmock) ExpectBeginTx(txOptions pgx.TxOptions) *ExpectedBegin {
func (c *pgxmock) ExpectExec(expectedSQL string) *ExpectedExec {
e := &ExpectedExec{}
e.expectSQL = expectedSQL
e.expectRewrittenSQL = expectedSQL
c.expectations = append(c.expectations, e)
return e
}
Expand Down Expand Up @@ -430,7 +432,15 @@ func (c *pgxmock) Query(ctx context.Context, sql string, args ...interface{}) (p
if err := c.queryMatcher.Match(queryExp.expectSQL, sql); err != nil {
return err
}
if err := queryExp.argsMatches(sql, args); err != nil {
if rewrittenSQL, err := queryExp.argsMatches(sql, args); err != nil {
return err
} else if rewrittenSQL > "" {
if err := c.queryMatcher.Match(queryExp.expectRewrittenSQL, rewrittenSQL); err != nil {
//pgx support QueryRewriter for arguments, now we can check if the query was actually rewriten
return err
}
}
if err := c.queryMatcher.Match(queryExp.expectSQL, sql); err != nil {
return err
}
if queryExp.err == nil && queryExp.rows == nil {
Expand Down Expand Up @@ -466,8 +476,13 @@ func (c *pgxmock) Exec(ctx context.Context, query string, args ...interface{}) (
if err := c.queryMatcher.Match(execExp.expectSQL, query); err != nil {
return err
}
if err := execExp.argsMatches(query, args); err != nil {
if rewrittenSQL, err := execExp.argsMatches(query, args); err != nil {
return err
} else if rewrittenSQL > "" {
if err := c.queryMatcher.Match(execExp.expectRewrittenSQL, rewrittenSQL); err != nil {
//pgx support QueryRewriter for arguments, now we can check if the query was actually rewriten
return err
}
}
if execExp.result.String() == "" && execExp.err == nil {
return fmt.Errorf("Exec must return a result or raise an error: %s", execExp)
Expand Down

0 comments on commit 03412d6

Please sign in to comment.