From bd67c876e0391189e5f13ba859066fcd7f756e64 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Mon, 14 Oct 2024 12:07:11 +0530 Subject: [PATCH] feat: fix locate null check escaping Signed-off-by: Manan Gupta --- go/vt/vtgate/evalengine/compiler_test.go | 101 +++++++++++++---------- go/vt/vtgate/evalengine/fn_string.go | 28 ++++--- 2 files changed, 72 insertions(+), 57 deletions(-) diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 67c9467de48..cb9b99e7776 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -96,6 +96,16 @@ func (s *Tracker) String() string { return s.buf.String() } +func TestOneCase(t *testing.T) { + query := `` + if query == "" { + t.Skip("no query to test") + } + venv := vtenv.NewTestEnv() + env := evalengine.EmptyExpressionEnv(venv) + testCompilerCase(t, query, venv, nil, env) +} + func TestCompilerReference(t *testing.T) { // This test runs a lot of queries and compares the results of the evalengine in eval mode to the results of the compiler. now := time.Now() @@ -111,52 +121,10 @@ func TestCompilerReference(t *testing.T) { tc.Run(func(query string, row []sqltypes.Value) { env.Row = row - - stmt, err := venv.Parser().ParseExpr(query) - if err != nil { - // no need to test un-parseable queries - return - } - - fields := evalengine.FieldResolver(tc.Schema) - cfg := &evalengine.Config{ - ResolveColumn: fields.Column, - ResolveType: fields.Type, - Collation: collations.CollationUtf8mb4ID, - Environment: venv, - NoConstantFolding: true, - } - - converted, err := evalengine.Translate(stmt, cfg) - if err != nil { - return - } - - var expected evalengine.EvalResult - var evalErr error - assert.NotPanics(t, func() { - expected, evalErr = env.EvaluateAST(converted) - }) total++ - - var res evalengine.EvalResult - var vmErr error - assert.NotPanics(t, func() { - res, vmErr = env.Evaluate(converted) - }) - - switch { - case vmErr == nil && evalErr == nil: - eval := expected.String() - comp := res.String() - assert.Equalf(t, eval, comp, "bad evaluation from compiler:\nSQL: %s\nEval: %s\nComp: %s", query, eval, comp) + testCompilerCase(t, query, venv, tc.Schema, env) + if !t.Failed() { supported++ - case vmErr == nil: - t.Errorf("failed evaluation from evalengine:\nSQL: %s\nError: %s", query, evalErr) - case evalErr == nil: - t.Errorf("failed evaluation from compiler:\nSQL: %s\nError: %s", query, vmErr) - case evalErr.Error() != vmErr.Error(): - t.Errorf("error mismatch:\nSQL: %s\nError eval: %s\nError comp: %s", query, evalErr, vmErr) } }) @@ -167,6 +135,51 @@ func TestCompilerReference(t *testing.T) { t.Logf("\n%s", track.String()) } +func testCompilerCase(t *testing.T, query string, venv *vtenv.Environment, schema []*querypb.Field, env *evalengine.ExpressionEnv) { + stmt, err := venv.Parser().ParseExpr(query) + if err != nil { + // no need to test un-parseable queries + return + } + + fields := evalengine.FieldResolver(schema) + cfg := &evalengine.Config{ + ResolveColumn: fields.Column, + ResolveType: fields.Type, + Collation: collations.CollationUtf8mb4ID, + Environment: venv, + NoConstantFolding: true, + } + + converted, err := evalengine.Translate(stmt, cfg) + if err != nil { + return + } + + var expected evalengine.EvalResult + var evalErr error + assert.NotPanics(t, func() { + expected, evalErr = env.EvaluateAST(converted) + }) + var res evalengine.EvalResult + var vmErr error + assert.NotPanics(t, func() { + res, vmErr = env.Evaluate(converted) + }) + switch { + case vmErr == nil && evalErr == nil: + eval := expected.String() + comp := res.String() + assert.Equalf(t, eval, comp, "bad evaluation from compiler:\nSQL: %s\nEval: %s\nComp: %s", query, eval, comp) + case vmErr == nil: + t.Errorf("failed evaluation from evalengine:\nSQL: %s\nError: %s", query, evalErr) + case evalErr == nil: + t.Errorf("failed evaluation from compiler:\nSQL: %s\nError: %s", query, vmErr) + case evalErr.Error() != vmErr.Error(): + t.Errorf("error mismatch:\nSQL: %s\nError eval: %s\nError comp: %s", query, evalErr, vmErr) + } +} + func TestCompilerSingle(t *testing.T) { var testCases = []struct { expression string diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index 6d83d36412d..1cca7a94c8d 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -1685,24 +1685,17 @@ func (call *builtinLocate) compile(c *compiler) (ctype, error) { return ctype{}, err } + skip1 := c.compileNullCheck1(substr) str, err := call.Arguments[1].compile(c) if err != nil { return ctype{}, err } - skip1 := c.compileNullCheck2(substr, str) - var skip2 *jump - if len(call.Arguments) > 2 { - l, err := call.Arguments[2].compile(c) - if err != nil { - return ctype{}, err - } - skip2 = c.compileNullCheck2(str, l) - _ = c.compileToInt64(l, 1) - } + skip2 := c.compileNullCheck1(str) + var skip3 *jump if !str.isTextual() { - c.asm.Convert_xce(len(call.Arguments)-1, sqltypes.VarChar, c.collation) + c.asm.Convert_xce(1, sqltypes.VarChar, c.collation) str.Col = collations.TypedCollation{ Collation: c.collation, Coercibility: collations.CoerceCoercible, @@ -1713,7 +1706,7 @@ func (call *builtinLocate) compile(c *compiler) (ctype, error) { fromCharset := colldata.Lookup(substr.Col.Collation).Charset() toCharset := colldata.Lookup(str.Col.Collation).Charset() if !substr.isTextual() || (fromCharset != toCharset && !toCharset.IsSuperset(fromCharset)) { - c.asm.Convert_xce(len(call.Arguments), sqltypes.VarChar, str.Col.Collation) + c.asm.Convert_xce(2, sqltypes.VarChar, str.Col.Collation) substr.Col = collations.TypedCollation{ Collation: str.Col.Collation, Coercibility: collations.CoerceCoercible, @@ -1721,6 +1714,15 @@ func (call *builtinLocate) compile(c *compiler) (ctype, error) { } } + if len(call.Arguments) > 2 { + l, err := call.Arguments[2].compile(c) + if err != nil { + return ctype{}, err + } + skip3 = c.compileNullCheck1(l) + _ = c.compileToInt64(l, 1) + } + var coll colldata.Collation if typeIsTextual(substr.Type) && typeIsTextual(str.Type) { coll = colldata.Lookup(str.Col.Collation) @@ -1734,7 +1736,7 @@ func (call *builtinLocate) compile(c *compiler) (ctype, error) { c.asm.Locate2(coll) } - c.asm.jumpDestination(skip1, skip2) + c.asm.jumpDestination(skip1, skip2, skip3) return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable}, nil }