From 1de8b0887d49e7265c28d23dc469b21cedd2a6b8 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 29 Aug 2024 11:49:39 +0200 Subject: [PATCH] fixed remaining rewriter bugs Signed-off-by: Andres Taylor --- go/vt/sqlparser/ast_funcs.go | 70 +++++ go/vt/sqlparser/predicate_rewriting.go | 268 +++++++++--------- .../planbuilder/predicate_rewrite_test.go | 33 ++- 3 files changed, 229 insertions(+), 142 deletions(-) diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index c94ccde2894..bce28502631 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -21,6 +21,7 @@ import ( "encoding/json" "fmt" "io" + "slices" "strconv" "strings" @@ -2864,3 +2865,72 @@ func ExtractAllTables(stmt Statement) []string { }, stmt) return tables } + +// ExtractINFromOR rewrites the OR expression into an IN clause. +// Each side of each ORs has to be an equality comparison expression and the column names have to +// match for all sides of each comparison. +// This rewriter takes a query that looks like this WHERE a = 1 and b = 11 or a = 2 and b = 12 or a = 3 and b = 13 +// And rewrite that to WHERE (a, b) IN ((1,11), (2,12), (3,13)) +func ExtractINFromOR(expr *OrExpr) []Expr { + var varNames []*ColName + var values []Exprs + orSlice := orToSlice(expr) + for _, expr := range orSlice { + andSlice := andToSlice(expr) + if len(andSlice) == 0 { + return nil + } + + var currentVarNames []*ColName + var currentValues []Expr + for _, comparisonExpr := range andSlice { + if comparisonExpr.Operator != EqualOp { + return nil + } + + var colName *ColName + if left, ok := comparisonExpr.Left.(*ColName); ok { + colName = left + currentValues = append(currentValues, comparisonExpr.Right) + } + + if right, ok := comparisonExpr.Right.(*ColName); ok { + if colName != nil { + return nil + } + colName = right + currentValues = append(currentValues, comparisonExpr.Left) + } + + if colName == nil { + return nil + } + + currentVarNames = append(currentVarNames, colName) + } + + if len(varNames) == 0 { + varNames = currentVarNames + } else if !slices.EqualFunc(varNames, currentVarNames, func(col1, col2 *ColName) bool { return col1.Equal(col2) }) { + return nil + } + + values = append(values, currentValues) + } + + var nameTuple ValTuple + for _, name := range varNames { + nameTuple = append(nameTuple, name) + } + + var valueTuple ValTuple + for _, value := range values { + valueTuple = append(valueTuple, ValTuple(value)) + } + + return []Expr{&ComparisonExpr{ + Operator: InOp, + Left: nameTuple, + Right: valueTuple, + }} +} diff --git a/go/vt/sqlparser/predicate_rewriting.go b/go/vt/sqlparser/predicate_rewriting.go index f29f640c5e6..7635553ec6d 100644 --- a/go/vt/sqlparser/predicate_rewriting.go +++ b/go/vt/sqlparser/predicate_rewriting.go @@ -17,13 +17,19 @@ limitations under the License. package sqlparser import ( - "slices" + "fmt" ) +var DebugRewrite = false + // RewritePredicate walks the input AST and rewrites any boolean logic into a simpler form // This simpler form is CNF plus logic for extracting predicates from OR, plus logic for turning ORs into IN func RewritePredicate(ast SQLNode) SQLNode { - original := CloneSQLNode(ast) + return RewritePredicateInternal(ast, nil) +} + +func RewritePredicateInternal(ast SQLNode, view func(SQLNode)) SQLNode { + original := Clone(ast) // Beware: converting to CNF in this loop might cause exponential formula growth. // We bail out early to prevent going overboard. @@ -32,6 +38,10 @@ func RewritePredicate(ast SQLNode) SQLNode { stopOnChange := func(SQLNode, SQLNode) bool { return !exprChanged } + if DebugRewrite { + fmt.Println(String(ast)) + } + ast = SafeRewrite(ast, stopOnChange, func(cursor *Cursor) bool { e, isExpr := cursor.node.(Expr) if !isExpr { @@ -46,6 +56,10 @@ func RewritePredicate(ast SQLNode) SQLNode { return !exprChanged }) + if view != nil { + view(Clone(ast)) + } + if !exprChanged { return ast } @@ -74,6 +88,9 @@ func simplifyNot(expr *NotExpr) (Expr, bool) { return child.Expr, true case *OrExpr: // not(or(a,b)) => and(not(a),not(b)) + if DebugRewrite { + fmt.Println(" >> not (a or b) => not a and not b") + } return AndExpressions(&NotExpr{Expr: child.Left}, &NotExpr{Expr: child.Right}), true case *AndExpr: // not(and(a,b)) => or(not(a), not(b)) @@ -85,6 +102,9 @@ func simplifyNot(expr *NotExpr) (Expr, bool) { curr = &OrExpr{Left: curr, Right: &NotExpr{Expr: p}} } } + if DebugRewrite { + fmt.Println(" >> not (a and b) => not a or not b") + } return curr, true } return expr, false @@ -97,81 +117,68 @@ func createOrs(exprs ...Expr) Expr { return &OrExpr{Left: exprs[0], Right: createOrs(exprs[1:]...)} } -func simplifyOr(or *OrExpr) (Expr, bool) { - res, rewritten := distinctOr(or) - if rewritten { - return res, true - } - - land, lok := or.Left.(*AndExpr) - rand, rok := or.Right.(*AndExpr) - - if lok && rok { - // (A AND B AND D) OR (A AND C AND D) => (A AND D) AND (B OR C) - var commonPredicates []Expr - var leftRemainder, rightRemainder []Expr - - // Find all matching predicates and separate the remainder - rightRemainder = rand.Predicates - for _, lp := range land.Predicates { - rhs := rightRemainder - rightRemainder = nil - isCommon := false - for _, rp := range rhs { - if Equals.Expr(lp, rp) { - commonPredicates = append(commonPredicates, lp) - isCommon = true - } else { - rightRemainder = append(rightRemainder, rp) - } - } - if !isCommon { - leftRemainder = append(leftRemainder, lp) +func simplifyOredAnds(or *OrExpr, lhs, rhs *AndExpr) (Expr, bool) { + // (A AND B AND D) OR (A AND C AND D) => (A AND D) AND (B OR C) + var commonPredicates []Expr + var leftRemainder, rightRemainder []Expr + + // Find all matching predicates and separate the remainder + rightRemainder = rhs.Predicates + for _, lp := range lhs.Predicates { + rhs := rightRemainder + rightRemainder = nil + isCommon := false + for _, rp := range rhs { + if Equals.Expr(lp, rp) { + commonPredicates = append(commonPredicates, lp) + isCommon = true + } else { + rightRemainder = append(rightRemainder, rp) } } - - if len(commonPredicates) > 0 { - // Build the final AndExpr with common predicates and the OrExpr of remainders - nonCommonPredicates := append(leftRemainder, rightRemainder...) - commonPred := AndExpressions(commonPredicates...) - if len(nonCommonPredicates) == 0 { - return commonPred, true - } - return AndExpressions(commonPred, createOrs(nonCommonPredicates...)), true + if !isCommon { + leftRemainder = append(leftRemainder, lp) } - return or, false } - if !lok && !rok { - lftCmp, lok := or.Left.(*ComparisonExpr) - rgtCmp, rok := or.Right.(*ComparisonExpr) - if lok && rok { - newExpr, rewritten := tryTurningOrIntoIn(lftCmp, rgtCmp) - if rewritten { - // or(a=x,a=y) => in(a,[x,y]) - return newExpr, true - } - } + if len(commonPredicates) == 0 { return or, false } - // if we get here, one side is an AND - var and *AndExpr - var other Expr - if lok { - and = land - other = or.Right - } else { - and = rand - other = or.Left + // Build the final AndExpr with common predicates and the OrExpr of remainders + commonPred := AndExpressions(commonPredicates...) + if len(leftRemainder) == 0 && len(rightRemainder) == 0 { + if DebugRewrite { + fmt.Println(" >> remove duplicate predicates across ANDs") + } + return commonPred, true + } + + switch { + case len(rightRemainder) == 0 || len(leftRemainder) == 0: + if DebugRewrite { + fmt.Println(" >> (A and B and A and D) or (A and D) => A and D") + } + return commonPred, true + default: + if DebugRewrite { + fmt.Println(" >> (A and B and D) or (A and D and C) => (A and D) and (B or C)") + } + + return AndExpressions(commonPred, createOrs(AndExpressions(rightRemainder...), AndExpressions(leftRemainder...))), true } +} +func simplifyOrWithAnAND(and *AndExpr, other Expr, left bool) (Expr, bool) { for _, lp := range and.Predicates { if Equals.Expr(other, lp) { // if we have the same predicate on both sides of the OR, we can simplify // (A AND B) OR A => A // because if A is true, the OR is true, not matter what B is, // and if A is false, the AND is false, and again we don't care about B + if DebugRewrite { + fmt.Println(" >> (A AND B) OR A => A") + } return other, true } } @@ -180,18 +187,73 @@ func simplifyOr(or *OrExpr) (Expr, bool) { var distributedPredicates []Expr for _, lp := range and.Predicates { var or *OrExpr - if lok { + if left { or = &OrExpr{Left: lp, Right: other} } else { or = &OrExpr{Left: other, Right: lp} } distributedPredicates = append(distributedPredicates, or) } + if DebugRewrite { + fmt.Println(" >> (A and B) or C => (A or C) and (B or C)") + } return AndExpressions(distributedPredicates...), true } +func simplifyOr(or *OrExpr) (Expr, bool) { + res, rewritten := distinctOr(or) + if rewritten { + if DebugRewrite { + fmt.Println(" >> distinct or elements") + } + + return res, true + } + + land, lok := or.Left.(*AndExpr) + rand, rok := or.Right.(*AndExpr) + + switch { + case lok && rok: + return simplifyOredAnds(or, land, rand) + case !lok && !rok: + return simplifyOrToIn(or) + default: + // if we get here, one side is an AND + var and *AndExpr + var other Expr + if lok { + and = land + other = or.Right + } else { + and = rand + other = or.Left + } + + return simplifyOrWithAnAND(and, other, lok) + } +} + +func simplifyOrToIn(or *OrExpr) (Expr, bool) { + lftCmp, lok := or.Left.(*ComparisonExpr) + rgtCmp, rok := or.Right.(*ComparisonExpr) + if lok && rok { + newExpr, rewritten := tryTurningOrIntoIn(lftCmp, rgtCmp) + if rewritten { + if DebugRewrite { + fmt.Println(" >> turning OR into IN") + } + return newExpr, true + } + } + + return or, false +} + func simplifyXor(xor *XorExpr) (Expr, bool) { - // xor(a,b) => and(or(a,b), not(and(a,b)) + if DebugRewrite { + fmt.Println(" >> a xor b => (a or b) and not(a and b)") + } return AndExpressions( &OrExpr{Left: xor.Left, Right: xor.Right}, &NotExpr{Expr: AndExpressions(xor.Left, xor.Right)}, @@ -200,11 +262,18 @@ func simplifyXor(xor *XorExpr) (Expr, bool) { func simplifyAnd(expr *AndExpr) (Expr, bool) { if len(expr.Predicates) == 1 { + if DebugRewrite { + fmt.Println(" >> single predicate in AND") + } return expr.Predicates[0], true } res, rewritten := distinctAnd(expr) if rewritten { + if DebugRewrite { + fmt.Println(" >> distinct and elements") + } + return res, true } @@ -236,6 +305,10 @@ outer: } if simplified { + if DebugRewrite { + fmt.Println(" >> (a or b) and a => a") + } + // Return a new AndExpr with the simplified predicates return AndExpressions(simplifiedPredicates...), true } @@ -243,75 +316,6 @@ outer: return expr, false } -// ExtractINFromOR rewrites the OR expression into an IN clause. -// Each side of each ORs has to be an equality comparison expression and the column names have to -// match for all sides of each comparison. -// This rewriter takes a query that looks like this WHERE a = 1 and b = 11 or a = 2 and b = 12 or a = 3 and b = 13 -// And rewrite that to WHERE (a, b) IN ((1,11), (2,12), (3,13)) -func ExtractINFromOR(expr *OrExpr) []Expr { - var varNames []*ColName - var values []Exprs - orSlice := orToSlice(expr) - for _, expr := range orSlice { - andSlice := andToSlice(expr) - if len(andSlice) == 0 { - return nil - } - - var currentVarNames []*ColName - var currentValues []Expr - for _, comparisonExpr := range andSlice { - if comparisonExpr.Operator != EqualOp { - return nil - } - - var colName *ColName - if left, ok := comparisonExpr.Left.(*ColName); ok { - colName = left - currentValues = append(currentValues, comparisonExpr.Right) - } - - if right, ok := comparisonExpr.Right.(*ColName); ok { - if colName != nil { - return nil - } - colName = right - currentValues = append(currentValues, comparisonExpr.Left) - } - - if colName == nil { - return nil - } - - currentVarNames = append(currentVarNames, colName) - } - - if len(varNames) == 0 { - varNames = currentVarNames - } else if !slices.EqualFunc(varNames, currentVarNames, func(col1, col2 *ColName) bool { return col1.Equal(col2) }) { - return nil - } - - values = append(values, currentValues) - } - - var nameTuple ValTuple - for _, name := range varNames { - nameTuple = append(nameTuple, name) - } - - var valueTuple ValTuple - for _, value := range values { - valueTuple = append(valueTuple, ValTuple(value)) - } - - return []Expr{&ComparisonExpr{ - Operator: InOp, - Left: nameTuple, - Right: valueTuple, - }} -} - func orToSlice(expr *OrExpr) []Expr { var exprs []Expr diff --git a/go/vt/vtgate/planbuilder/predicate_rewrite_test.go b/go/vt/vtgate/planbuilder/predicate_rewrite_test.go index 9c1e25624ce..ce56b5e85bc 100644 --- a/go/vt/vtgate/planbuilder/predicate_rewrite_test.go +++ b/go/vt/vtgate/planbuilder/predicate_rewrite_test.go @@ -86,15 +86,21 @@ func (tc testCase) createPredicate(lvl int) sqlparser.Expr { func TestOneRewriting(t *testing.T) { venv := vtenv.NewTestEnv() + sqlparser.DebugRewrite = true // Modify these const numberOfColumns = 2 - const expr = "n1 and n0 or n1 xor n1" + const expr = "not (n1 xor n0 or n1)" predicate, err := sqlparser.NewTestParser().ParseExpr(expr) require.NoError(t, err) - simplified := sqlparser.RewritePredicate(predicate) + var steps []sqlparser.SQLNode + + simplified := sqlparser.RewritePredicateInternal(predicate, func(n sqlparser.SQLNode) { + steps = append(steps, n) + }) + fmt.Println(sqlparser.String(simplified)) cfg := &evalengine.Config{ Environment: venv, @@ -103,16 +109,23 @@ func TestOneRewriting(t *testing.T) { } original, err := evalengine.Translate(predicate, cfg) require.NoError(t, err) - simpler, err := evalengine.Translate(simplified.(sqlparser.Expr), cfg) - require.NoError(t, err) - env := evalengine.EmptyExpressionEnv(venv) - env.Row = make([]sqltypes.Value, numberOfColumns) - for i := range env.Row { - env.Row[i] = sqltypes.NULL + for _, step := range steps { + name := sqlparser.String(step) + t.Run(name, func(t *testing.T) { + simpler, err := evalengine.Translate(step.(sqlparser.Expr), cfg) + require.NoError(t, err) + + env := evalengine.EmptyExpressionEnv(venv) + env.Row = make([]sqltypes.Value, numberOfColumns) + for i := range env.Row { + env.Row[i] = sqltypes.NULL + } + + testValues(t, env, 0, original, simpler) + }) } - testValues(t, env, 0, original, simpler) } func TestFuzzRewriting(t *testing.T) { @@ -125,7 +138,7 @@ func TestFuzzRewriting(t *testing.T) { start := time.Now() for time.Since(start) < 1*time.Second { tc := testCase{ - nodes: 2, + nodes: rand.IntN(4) + 1, depth: rand.IntN(4) + 1, }