Skip to content

Commit

Permalink
feat: fix locate null check escaping
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <[email protected]>
  • Loading branch information
GuptaManan100 committed Oct 14, 2024
1 parent 9efcf48 commit bd67c87
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 57 deletions.
101 changes: 57 additions & 44 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ func (s *Tracker) String() string {
return s.buf.String()
}

func TestOneCase(t *testing.T) {
query := ``
if query == "" {
t.Skip("no query to test")
}
venv := vtenv.NewTestEnv()
env := evalengine.EmptyExpressionEnv(venv)
testCompilerCase(t, query, venv, nil, env)
}

func TestCompilerReference(t *testing.T) {
// This test runs a lot of queries and compares the results of the evalengine in eval mode to the results of the compiler.
now := time.Now()
Expand All @@ -111,52 +121,10 @@ func TestCompilerReference(t *testing.T) {

tc.Run(func(query string, row []sqltypes.Value) {
env.Row = row

stmt, err := venv.Parser().ParseExpr(query)
if err != nil {
// no need to test un-parseable queries
return
}

fields := evalengine.FieldResolver(tc.Schema)
cfg := &evalengine.Config{
ResolveColumn: fields.Column,
ResolveType: fields.Type,
Collation: collations.CollationUtf8mb4ID,
Environment: venv,
NoConstantFolding: true,
}

converted, err := evalengine.Translate(stmt, cfg)
if err != nil {
return
}

var expected evalengine.EvalResult
var evalErr error
assert.NotPanics(t, func() {
expected, evalErr = env.EvaluateAST(converted)
})
total++

var res evalengine.EvalResult
var vmErr error
assert.NotPanics(t, func() {
res, vmErr = env.Evaluate(converted)
})

switch {
case vmErr == nil && evalErr == nil:
eval := expected.String()
comp := res.String()
assert.Equalf(t, eval, comp, "bad evaluation from compiler:\nSQL: %s\nEval: %s\nComp: %s", query, eval, comp)
testCompilerCase(t, query, venv, tc.Schema, env)
if !t.Failed() {
supported++
case vmErr == nil:
t.Errorf("failed evaluation from evalengine:\nSQL: %s\nError: %s", query, evalErr)
case evalErr == nil:
t.Errorf("failed evaluation from compiler:\nSQL: %s\nError: %s", query, vmErr)
case evalErr.Error() != vmErr.Error():
t.Errorf("error mismatch:\nSQL: %s\nError eval: %s\nError comp: %s", query, evalErr, vmErr)
}
})

Expand All @@ -167,6 +135,51 @@ func TestCompilerReference(t *testing.T) {
t.Logf("\n%s", track.String())
}

func testCompilerCase(t *testing.T, query string, venv *vtenv.Environment, schema []*querypb.Field, env *evalengine.ExpressionEnv) {
stmt, err := venv.Parser().ParseExpr(query)
if err != nil {
// no need to test un-parseable queries
return
}

fields := evalengine.FieldResolver(schema)
cfg := &evalengine.Config{
ResolveColumn: fields.Column,
ResolveType: fields.Type,
Collation: collations.CollationUtf8mb4ID,
Environment: venv,
NoConstantFolding: true,
}

converted, err := evalengine.Translate(stmt, cfg)
if err != nil {
return
}

var expected evalengine.EvalResult
var evalErr error
assert.NotPanics(t, func() {
expected, evalErr = env.EvaluateAST(converted)
})
var res evalengine.EvalResult
var vmErr error
assert.NotPanics(t, func() {
res, vmErr = env.Evaluate(converted)
})
switch {
case vmErr == nil && evalErr == nil:
eval := expected.String()
comp := res.String()
assert.Equalf(t, eval, comp, "bad evaluation from compiler:\nSQL: %s\nEval: %s\nComp: %s", query, eval, comp)
case vmErr == nil:
t.Errorf("failed evaluation from evalengine:\nSQL: %s\nError: %s", query, evalErr)
case evalErr == nil:
t.Errorf("failed evaluation from compiler:\nSQL: %s\nError: %s", query, vmErr)
case evalErr.Error() != vmErr.Error():
t.Errorf("error mismatch:\nSQL: %s\nError eval: %s\nError comp: %s", query, evalErr, vmErr)
}
}

func TestCompilerSingle(t *testing.T) {
var testCases = []struct {
expression string
Expand Down
28 changes: 15 additions & 13 deletions go/vt/vtgate/evalengine/fn_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -1685,24 +1685,17 @@ func (call *builtinLocate) compile(c *compiler) (ctype, error) {
return ctype{}, err
}

skip1 := c.compileNullCheck1(substr)
str, err := call.Arguments[1].compile(c)
if err != nil {
return ctype{}, err
}

skip1 := c.compileNullCheck2(substr, str)
var skip2 *jump
if len(call.Arguments) > 2 {
l, err := call.Arguments[2].compile(c)
if err != nil {
return ctype{}, err
}
skip2 = c.compileNullCheck2(str, l)
_ = c.compileToInt64(l, 1)
}
skip2 := c.compileNullCheck1(str)
var skip3 *jump

if !str.isTextual() {
c.asm.Convert_xce(len(call.Arguments)-1, sqltypes.VarChar, c.collation)
c.asm.Convert_xce(1, sqltypes.VarChar, c.collation)
str.Col = collations.TypedCollation{
Collation: c.collation,
Coercibility: collations.CoerceCoercible,
Expand All @@ -1713,14 +1706,23 @@ func (call *builtinLocate) compile(c *compiler) (ctype, error) {
fromCharset := colldata.Lookup(substr.Col.Collation).Charset()
toCharset := colldata.Lookup(str.Col.Collation).Charset()
if !substr.isTextual() || (fromCharset != toCharset && !toCharset.IsSuperset(fromCharset)) {
c.asm.Convert_xce(len(call.Arguments), sqltypes.VarChar, str.Col.Collation)
c.asm.Convert_xce(2, sqltypes.VarChar, str.Col.Collation)
substr.Col = collations.TypedCollation{
Collation: str.Col.Collation,
Coercibility: collations.CoerceCoercible,
Repertoire: collations.RepertoireASCII,
}
}

if len(call.Arguments) > 2 {
l, err := call.Arguments[2].compile(c)
if err != nil {
return ctype{}, err
}
skip3 = c.compileNullCheck1(l)
_ = c.compileToInt64(l, 1)
}

var coll colldata.Collation
if typeIsTextual(substr.Type) && typeIsTextual(str.Type) {
coll = colldata.Lookup(str.Col.Collation)
Expand All @@ -1734,7 +1736,7 @@ func (call *builtinLocate) compile(c *compiler) (ctype, error) {
c.asm.Locate2(coll)
}

c.asm.jumpDestination(skip1, skip2)
c.asm.jumpDestination(skip1, skip2, skip3)
return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable}, nil
}

Expand Down

0 comments on commit bd67c87

Please sign in to comment.