Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes bugs around expression precedence and LIKE #16934

Merged
merged 15 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# This file contains queries that test expressions in Vitess.
# We've found a number of bugs around precedences that we want to test.
CREATE TABLE t0
(
c1 BIT,
INDEX idx_c1 (c1)
);

INSERT INTO t0(c1)
VALUES ('');


SELECT *
FROM t0;

SELECT ((t0.c1 = 'a'))
FROM t0;

SELECT *
FROM t0
WHERE ((t0.c1 = 'a'));


SELECT (1 LIKE ('a' IS NULL));
SELECT (NOT (1 LIKE ('a' IS NULL)));

SELECT (~ (1 || 0)) IS NULL;

SELECT 1
WHERE (~ (1 || 0)) IS NULL;
2 changes: 1 addition & 1 deletion go/vt/sqlparser/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func ASTToStatementType(stmt Statement) StatementType {
// CanNormalize takes Statement and returns if the statement can be normalized.
func CanNormalize(stmt Statement) bool {
switch stmt.(type) {
case *Select, *Union, *Insert, *Update, *Delete, *Set, *CallProc, *Stream: // TODO: we could merge this logic into ASTrewriter
case *Select, *Union, *Insert, *Update, *Delete, *Set, *CallProc, *Stream, *VExplainStmt: // TODO: we could merge this logic into ASTrewriter
return true
}
return false
Expand Down
4 changes: 0 additions & 4 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1666,10 +1666,6 @@ func (op BinaryExprOperator) ToString() string {
return ShiftLeftStr
case ShiftRightOp:
return ShiftRightStr
case JSONExtractOp:
return JSONExtractOpStr
case JSONUnquoteExtractOp:
return JSONUnquoteExtractOpStr
default:
return "Unknown BinaryExprOperator"
}
Expand Down
26 changes: 11 additions & 15 deletions go/vt/sqlparser/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,19 +169,17 @@ const (
IsNotFalseStr = "is not false"

// BinaryExpr.Operator
BitAndStr = "&"
BitOrStr = "|"
BitXorStr = "^"
PlusStr = "+"
MinusStr = "-"
MultStr = "*"
DivStr = "/"
IntDivStr = "div"
ModStr = "%"
ShiftLeftStr = "<<"
ShiftRightStr = ">>"
JSONExtractOpStr = "->"
JSONUnquoteExtractOpStr = "->>"
BitAndStr = "&"
BitOrStr = "|"
BitXorStr = "^"
PlusStr = "+"
MinusStr = "-"
MultStr = "*"
DivStr = "/"
IntDivStr = "div"
ModStr = "%"
ShiftLeftStr = "<<"
ShiftRightStr = ">>"

// UnaryExpr.Operator
UPlusStr = "+"
Expand Down Expand Up @@ -715,8 +713,6 @@ const (
ModOp
ShiftLeftOp
ShiftRightOp
JSONExtractOp
JSONUnquoteExtractOp
)

// Constant for Enum Type - UnaryExprOperator
Expand Down
8 changes: 5 additions & 3 deletions go/vt/sqlparser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1005,9 +1005,11 @@ var (
}, {
input: "select /* u~ */ 1 from t where a = ~b",
}, {
input: "select /* -> */ a.b -> 'ab' from t",
input: "select /* -> */ a.b -> 'ab' from t",
output: "select /* -> */ json_extract(a.b, 'ab') from t",
}, {
input: "select /* -> */ a.b ->> 'ab' from t",
input: "select /* -> */ a.b ->> 'ab' from t",
output: "select /* -> */ json_unquote(json_extract(a.b, 'ab')) from t",
}, {
input: "select /* empty function */ 1 from t where a = b()",
}, {
Expand Down Expand Up @@ -5937,7 +5939,7 @@ partition by range (YEAR(purchased)) subpartition by hash (TO_DAYS(purchased))
},
{
input: "create table t (id int, info JSON, INDEX zips((CAST(info->'$.field' AS unsigned ARRAY))))",
output: "create table t (\n\tid int,\n\tinfo JSON,\n\tkey zips ((cast(info -> '$.field' as unsigned array)))\n)",
output: "create table t (\n\tid int,\n\tinfo JSON,\n\tkey zips ((cast(json_extract(info, '$.field') as unsigned array)))\n)",
},
{
input: "create table t (id int, s varchar(255) default 'foo\"bar')",
Expand Down
5 changes: 1 addition & 4 deletions go/vt/sqlparser/precedence.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,7 @@ func precedenceFor(in Expr) Precendence {
case *BetweenExpr:
return P12
case *ComparisonExpr:
switch node.Operator {
case EqualOp, NotEqualOp, GreaterThanOp, GreaterEqualOp, LessThanOp, LessEqualOp, LikeOp, InOp, RegexpOp, NullSafeEqualOp:
return P11
}
return P11
case *IsExpr:
return P11
case *BinaryExpr:
Expand Down
2 changes: 2 additions & 0 deletions go/vt/sqlparser/precedence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ func TestParens(t *testing.T) {
{in: "10 - (2 - 1)", expected: "10 - (2 - 1)"},
{in: "0 <=> (1 and 0)", expected: "0 <=> (1 and 0)"},
{in: "(~ (1||0)) IS NULL", expected: "~(1 or 0) is null"},
{in: "1 not like ('a' is null)", expected: "1 not like ('a' is null)"},
{in: ":vtg1 not like (:vtg2 is null)", expected: ":vtg1 not like (:vtg2 is null)"},
}

parser := NewTestParser()
Expand Down
4 changes: 2 additions & 2 deletions go/vt/sqlparser/sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions go/vt/sqlparser/sql.y
Original file line number Diff line number Diff line change
Expand Up @@ -5591,11 +5591,11 @@ function_call_keyword
}
| column_name_or_offset JSON_EXTRACT_OP text_literal_or_arg
{
$$ = &BinaryExpr{Left: $1, Operator: JSONExtractOp, Right: $3}
$$ = &JSONExtractExpr{JSONDoc: $1, PathList: []Expr{$3}}
}
| column_name_or_offset JSON_UNQUOTE_EXTRACT_OP text_literal_or_arg
{
$$ = &BinaryExpr{Left: $1, Operator: JSONUnquoteExtractOp, Right: $3}
$$ = &JSONUnquoteExpr{JSONValue: &JSONExtractExpr{JSONDoc: $1, PathList: []Expr{$3}}}
}

column_names_opt_paren:
Expand Down
2 changes: 1 addition & 1 deletion go/vt/sqlparser/tracked_buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ func TestCanonicalOutput(t *testing.T) {
},
{
"create table t (id int, info JSON, INDEX zips((CAST(info->'$.field' AS unsigned array))))",
"CREATE TABLE `t` (\n\t`id` int,\n\t`info` JSON,\n\tKEY `zips` ((CAST(`info` -> '$.field' AS unsigned array)))\n)",
"CREATE TABLE `t` (\n\t`id` int,\n\t`info` JSON,\n\tKEY `zips` ((CAST(JSON_EXTRACT(`info`, '$.field') AS unsigned array)))\n)",
},
{
"select 1 from t1 into outfile 'test/t1.txt'",
Expand Down
105 changes: 61 additions & 44 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/olekukonko/tablewriter"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -94,7 +96,18 @@ func (s *Tracker) String() string {
return s.buf.String()
}

func TestOneCase(t *testing.T) {
query := ``
if query == "" {
t.Skip("no query to test")
}
venv := vtenv.NewTestEnv()
env := evalengine.EmptyExpressionEnv(venv)
testCompilerCase(t, query, venv, nil, env)
}

func TestCompilerReference(t *testing.T) {
// This test runs a lot of queries and compares the results of the evalengine in eval mode to the results of the compiler.
now := time.Now()
evalengine.SystemTime = func() time.Time { return now }
defer func() { evalengine.SystemTime = time.Now }()
Expand All @@ -108,52 +121,11 @@ func TestCompilerReference(t *testing.T) {

tc.Run(func(query string, row []sqltypes.Value) {
env.Row = row

stmt, err := venv.Parser().ParseExpr(query)
if err != nil {
// no need to test un-parseable queries
return
}

fields := evalengine.FieldResolver(tc.Schema)
cfg := &evalengine.Config{
ResolveColumn: fields.Column,
ResolveType: fields.Type,
Collation: collations.CollationUtf8mb4ID,
Environment: venv,
NoConstantFolding: true,
}

converted, err := evalengine.Translate(stmt, cfg)
if err != nil {
return
}

expected, evalErr := env.EvaluateAST(converted)
total++

res, vmErr := env.Evaluate(converted)
if vmErr != nil {
switch {
case evalErr == nil:
t.Errorf("failed evaluation from compiler:\nSQL: %s\nError: %s", query, vmErr)
case evalErr.Error() != vmErr.Error():
t.Errorf("error mismatch:\nSQL: %s\nError eval: %s\nError comp: %s", query, evalErr, vmErr)
default:
supported++
}
return
testCompilerCase(t, query, venv, tc.Schema, env)
if !t.Failed() {
supported++
}

eval := expected.String()
comp := res.String()

if eval != comp {
t.Errorf("bad evaluation from compiler:\nSQL: %s\nEval: %s\nComp: %s", query, eval, comp)
return
}

supported++
})

track.Add(tc.Name(), supported, total)
Expand All @@ -163,6 +135,51 @@ func TestCompilerReference(t *testing.T) {
t.Logf("\n%s", track.String())
}

func testCompilerCase(t *testing.T, query string, venv *vtenv.Environment, schema []*querypb.Field, env *evalengine.ExpressionEnv) {
stmt, err := venv.Parser().ParseExpr(query)
if err != nil {
// no need to test un-parseable queries
return
}

fields := evalengine.FieldResolver(schema)
cfg := &evalengine.Config{
ResolveColumn: fields.Column,
ResolveType: fields.Type,
Collation: collations.CollationUtf8mb4ID,
Environment: venv,
NoConstantFolding: true,
}

converted, err := evalengine.Translate(stmt, cfg)
if err != nil {
return
}

var expected evalengine.EvalResult
var evalErr error
assert.NotPanics(t, func() {
expected, evalErr = env.EvaluateAST(converted)
})
dbussink marked this conversation as resolved.
Show resolved Hide resolved
var res evalengine.EvalResult
var vmErr error
assert.NotPanics(t, func() {
res, vmErr = env.Evaluate(converted)
})
switch {
case vmErr == nil && evalErr == nil:
eval := expected.String()
comp := res.String()
assert.Equalf(t, eval, comp, "bad evaluation from compiler:\nSQL: %s\nEval: %s\nComp: %s", query, eval, comp)
case vmErr == nil:
t.Errorf("failed evaluation from evalengine:\nSQL: %s\nError: %s", query, evalErr)
case evalErr == nil:
t.Errorf("failed evaluation from compiler:\nSQL: %s\nError: %s", query, vmErr)
case evalErr.Error() != vmErr.Error():
t.Errorf("error mismatch:\nSQL: %s\nError eval: %s\nError comp: %s", query, evalErr, vmErr)
}
}

func TestCompilerSingle(t *testing.T) {
var testCases = []struct {
expression string
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (expr *BinaryExpr) arguments(env *ExpressionEnv) (eval, eval, error) {
}
right, err := expr.Right.eval(env)
if err != nil {
return nil, nil, err
return left, nil, err
}
return left, right, nil
}
34 changes: 16 additions & 18 deletions go/vt/vtgate/evalengine/expr_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,13 +592,18 @@ func (l *LikeExpr) matchWildcard(left, right []byte, coll collations.ID) bool {
}
fullColl := colldata.Lookup(coll)
wc := fullColl.Wildcard(right, 0, 0, 0)
return wc.Match(left)
return wc.Match(left) == !l.Negate
}

func (l *LikeExpr) eval(env *ExpressionEnv) (eval, error) {
left, right, err := l.arguments(env)
if left == nil || right == nil || err != nil {
return nil, err
left, err := l.Left.eval(env)
if err != nil || left == nil {
return left, err
}

right, err := l.Right.eval(env)
if err != nil || right == nil {
return right, err
}

var col collations.TypedCollation
Expand All @@ -607,18 +612,9 @@ func (l *LikeExpr) eval(env *ExpressionEnv) (eval, error) {
return nil, err
}

var matched bool
switch {
case typeIsTextual(left.SQLType()) && typeIsTextual(right.SQLType()):
matched = l.matchWildcard(left.(*evalBytes).bytes, right.(*evalBytes).bytes, col.Collation)
case typeIsTextual(right.SQLType()):
matched = l.matchWildcard(left.ToRawBytes(), right.(*evalBytes).bytes, col.Collation)
case typeIsTextual(left.SQLType()):
matched = l.matchWildcard(left.(*evalBytes).bytes, right.ToRawBytes(), col.Collation)
default:
matched = l.matchWildcard(left.ToRawBytes(), right.ToRawBytes(), collations.CollationBinaryID)
}
return newEvalBool(matched == !l.Negate), nil
matched := l.matchWildcard(left.ToRawBytes(), right.ToRawBytes(), col.Collation)

return newEvalBool(matched), nil
}

func (expr *LikeExpr) compile(c *compiler) (ctype, error) {
Expand All @@ -627,12 +623,14 @@ func (expr *LikeExpr) compile(c *compiler) (ctype, error) {
return ctype{}, err
}

skip1 := c.compileNullCheck1(lt)

rt, err := expr.Right.compile(c)
if err != nil {
return ctype{}, err
}

skip := c.compileNullCheck2(lt, rt)
skip2 := c.compileNullCheck1(rt)

if !lt.isTextual() {
c.asm.Convert_xc(2, sqltypes.VarChar, c.collation, nil)
Expand Down Expand Up @@ -684,6 +682,6 @@ func (expr *LikeExpr) compile(c *compiler) (ctype, error) {
})
}

c.asm.jumpDestination(skip)
c.asm.jumpDestination(skip1, skip2)
return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagIsBoolean | flagNullable}, nil
}
Loading
Loading