From de1262e1903a47b5b3f4f8fb96a3e5ece3a09731 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Tue, 27 Aug 2024 12:59:34 +0200 Subject: [PATCH 01/14] wip - change AndExpr to contain arbitrary many predicates Signed-off-by: Andres Taylor --- go/vt/sqlparser/analyzer_test.go | 8 +- go/vt/sqlparser/ast.go | 4 +- go/vt/sqlparser/ast_clone.go | 27 ++- go/vt/sqlparser/ast_copy_on_rewrite.go | 16 +- go/vt/sqlparser/ast_equals.go | 29 ++- go/vt/sqlparser/ast_format.go | 8 +- go/vt/sqlparser/ast_format_fast.go | 11 +- go/vt/sqlparser/ast_funcs.go | 62 +++-- go/vt/sqlparser/ast_rewrite.go | 17 +- go/vt/sqlparser/ast_visit.go | 9 +- go/vt/sqlparser/cached_size.go | 16 +- go/vt/sqlparser/precedence_test.go | 12 +- go/vt/sqlparser/predicate_rewriting.go | 243 +++++++++++--------- go/vt/sqlparser/predicate_rewriting_test.go | 23 +- go/vt/sqlparser/random_expr.go | 17 +- go/vt/sqlparser/sql.go | 2 +- go/vt/sqlparser/sql.y | 2 +- go/vt/sqlparser/testdata/select_cases.txt | 10 +- go/vt/sqlparser/utils.go | 5 +- 19 files changed, 292 insertions(+), 229 deletions(-) diff --git a/go/vt/sqlparser/analyzer_test.go b/go/vt/sqlparser/analyzer_test.go index 0a2de52ef19..87a5cb81dea 100644 --- a/go/vt/sqlparser/analyzer_test.go +++ b/go/vt/sqlparser/analyzer_test.go @@ -209,10 +209,10 @@ func TestAndExpressions(t *testing.T) { equalExpr, equalExpr, }, - expectedOutput: &AndExpr{ - Left: greaterThanExpr, - Right: equalExpr, - }, + expectedOutput: &AndExpr{Predicates: Exprs{ + greaterThanExpr, + equalExpr, + }}, }, { name: "two equal inputs", diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 938b9063011..124833bc91d 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -2263,9 +2263,7 @@ type ( } // AndExpr represents an AND expression. - AndExpr struct { - Left, Right Expr - } + AndExpr struct{ Predicates []Expr } // OrExpr represents an OR expression. 
OrExpr struct { diff --git a/go/vt/sqlparser/ast_clone.go b/go/vt/sqlparser/ast_clone.go index f22a1790232..00330ff1972 100644 --- a/go/vt/sqlparser/ast_clone.go +++ b/go/vt/sqlparser/ast_clone.go @@ -744,8 +744,7 @@ func CloneRefOfAndExpr(n *AndExpr) *AndExpr { return nil } out := *n - out.Left = CloneExpr(n.Left) - out.Right = CloneExpr(n.Right) + out.Predicates = CloneSliceOfExpr(n.Predicates) return &out } @@ -4369,6 +4368,18 @@ func CloneSliceOfIdentifierCI(n []IdentifierCI) []IdentifierCI { return res } +// CloneSliceOfExpr creates a deep clone of the input. +func CloneSliceOfExpr(n []Expr) []Expr { + if n == nil { + return nil + } + res := make([]Expr, len(n)) + for i, x := range n { + res[i] = CloneExpr(x) + } + return res +} + // CloneSliceOfTxAccessMode creates a deep clone of the input. func CloneSliceOfTxAccessMode(n []TxAccessMode) []TxAccessMode { if n == nil { @@ -4458,18 +4469,6 @@ func CloneSliceOfRefOfVariable(n []*Variable) []*Variable { return res } -// CloneSliceOfExpr creates a deep clone of the input. -func CloneSliceOfExpr(n []Expr) []Expr { - if n == nil { - return nil - } - res := make([]Expr, len(n)) - for i, x := range n { - res[i] = CloneExpr(x) - } - return res -} - // CloneRefOfIdentifierCI creates a deep clone of the input. 
func CloneRefOfIdentifierCI(n *IdentifierCI) *IdentifierCI { if n == nil { diff --git a/go/vt/sqlparser/ast_copy_on_rewrite.go b/go/vt/sqlparser/ast_copy_on_rewrite.go index 0e329e24f31..2caea4bd978 100644 --- a/go/vt/sqlparser/ast_copy_on_rewrite.go +++ b/go/vt/sqlparser/ast_copy_on_rewrite.go @@ -953,12 +953,18 @@ func (c *cow) copyOnRewriteRefOfAndExpr(n *AndExpr, parent SQLNode) (out SQLNode } out = n if c.pre == nil || c.pre(n, parent) { - _Left, changedLeft := c.copyOnRewriteExpr(n.Left, n) - _Right, changedRight := c.copyOnRewriteExpr(n.Right, n) - if changedLeft || changedRight { + var changedPredicates bool + _Predicates := make([]Expr, len(n.Predicates)) + for x, el := range n.Predicates { + this, changed := c.copyOnRewriteExpr(el, n) + _Predicates[x] = this.(Expr) + if changed { + changedPredicates = true + } + } + if changedPredicates { res := *n - res.Left, _ = _Left.(Expr) - res.Right, _ = _Right.(Expr) + res.Predicates = _Predicates out = &res if c.cloned != nil { c.cloned(n, out) diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index cf076d706e7..66ec6190159 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -1863,8 +1863,7 @@ func (cmp *Comparator) RefOfAndExpr(a, b *AndExpr) bool { if a == nil || b == nil { return false } - return cmp.Expr(a.Left, b.Left) && - cmp.Expr(a.Right, b.Right) + return cmp.SliceOfExpr(a.Predicates, b.Predicates) } // RefOfAnyValue does deep equals between the two objects. @@ -7246,6 +7245,19 @@ func (cmp *Comparator) SliceOfIdentifierCI(a, b []IdentifierCI) bool { return true } +// SliceOfExpr does deep equals between the two objects. +func (cmp *Comparator) SliceOfExpr(a, b []Expr) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !cmp.Expr(a[i], b[i]) { + return false + } + } + return true +} + // SliceOfTxAccessMode does deep equals between the two objects. 
func (cmp *Comparator) SliceOfTxAccessMode(a, b []TxAccessMode) bool { if len(a) != len(b) { @@ -7354,19 +7366,6 @@ func (cmp *Comparator) SliceOfRefOfVariable(a, b []*Variable) bool { return true } -// SliceOfExpr does deep equals between the two objects. -func (cmp *Comparator) SliceOfExpr(a, b []Expr) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !cmp.Expr(a[i], b[i]) { - return false - } - } - return true -} - // RefOfIdentifierCI does deep equals between the two objects. func (cmp *Comparator) RefOfIdentifierCI(a, b *IdentifierCI) bool { if a == b { diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index 587b32d4afe..f8fea84ab4f 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -1297,7 +1297,13 @@ func (node Exprs) Format(buf *TrackedBuffer) { // Format formats the node. func (node *AndExpr) Format(buf *TrackedBuffer) { - buf.astPrintf(node, "%l and %r", node.Left, node.Right) + for idx, expr := range node.Predicates { + if idx == len(node.Predicates)-1 { + buf.astPrintf(node, "%r", expr) + continue + } + buf.astPrintf(node, "%l and ", expr) + } } // Format formats the node. diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index c2b02711398..66cedf45f6d 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -1684,9 +1684,14 @@ func (node Exprs) FormatFast(buf *TrackedBuffer) { // FormatFast formats the node. func (node *AndExpr) FormatFast(buf *TrackedBuffer) { - buf.printExpr(node, node.Left, true) - buf.WriteString(" and ") - buf.printExpr(node, node.Right, false) + for idx, expr := range node.Predicates { + if idx == len(node.Predicates)-1 { + buf.printExpr(node, expr, false) + continue + } + buf.printExpr(node, expr, true) + buf.WriteString(" and ") + } } // FormatFast formats the node. 
diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index ea7a6e93e0e..db6965de1f3 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -66,6 +66,27 @@ func Append(buf *strings.Builder, node SQLNode) { node.FormatFast(tbuf) } +func createAndExpr(exprL, exprR Expr) *AndExpr { + leftAnd, isLeftAnd := exprL.(*AndExpr) + rightAnd, isRightAnd := exprR.(*AndExpr) + if isLeftAnd && isRightAnd { + return &AndExpr{ + Predicates: append(leftAnd.Predicates, rightAnd.Predicates...), + } + } + if isLeftAnd { + leftAnd.Predicates = append(leftAnd.Predicates, exprR) + return leftAnd + } + if isRightAnd { + rightAnd.Predicates = append([]Expr{exprL}, rightAnd.Predicates...) + return rightAnd + } + return &AndExpr{ + Predicates: Exprs{exprL, exprR}, + } +} + // IndexColumn describes a column or expression in an index definition with optional length (for column) type IndexColumn struct { // Only one of Column or Expression can be specified @@ -1239,10 +1260,8 @@ func addPredicate(where *Where, pred Expr) *Where { Expr: pred, } } - where.Expr = &AndExpr{ - Left: where.Expr, - Right: pred, - } + + where.Expr = createAndExpr(where.Expr, pred) return where } @@ -2336,8 +2355,7 @@ func SplitAndExpression(filters []Expr, node Expr) []Expr { } switch node := node.(type) { case *AndExpr: - filters = SplitAndExpression(filters, node.Left) - return SplitAndExpression(filters, node.Right) + return append(filters, node.Predicates...) 
} return append(filters, node) } @@ -2350,26 +2368,30 @@ func AndExpressions(exprs ...Expr) Expr { case 1: return exprs[0] default: - result := (Expr)(nil) - outer: + var unique Exprs // we'll loop and remove any duplicates - for i, expr := range exprs { - if expr == nil { - continue - } - if result == nil { - result = expr - continue outer + uniqueAdd := func(e Expr) { + for _, existing := range unique { + if Equals.Expr(e, existing) { + return + } } + unique = append(unique, e) + } - for j := 0; j < i; j++ { - if Equals.Expr(expr, exprs[j]) { - continue outer + for _, expr := range exprs { + switch expr := expr.(type) { + case *AndExpr: + for _, p := range expr.Predicates { + uniqueAdd(p) } + case nil: + continue + default: + uniqueAdd(expr) } - result = &AndExpr{Left: result, Right: expr} } - return result + return &AndExpr{Predicates: unique} } } diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index 015c27a2cbd..df144e4fad8 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -1088,15 +1088,14 @@ func (a *application) rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replace return true } } - if !a.rewriteExpr(node, node.Left, func(newNode, parent SQLNode) { - parent.(*AndExpr).Left = newNode.(Expr) - }) { - return false - } - if !a.rewriteExpr(node, node.Right, func(newNode, parent SQLNode) { - parent.(*AndExpr).Right = newNode.(Expr) - }) { - return false + for x, el := range node.Predicates { + if !a.rewriteExpr(node, el, func(idx int) replacerFunc { + return func(newNode, parent SQLNode) { + parent.(*AndExpr).Predicates[idx] = newNode.(Expr) + } + }(x)) { + return false + } } if a.post != nil { a.cur.replacer = replacer diff --git a/go/vt/sqlparser/ast_visit.go b/go/vt/sqlparser/ast_visit.go index d33c2d1e055..04608d84ec2 100644 --- a/go/vt/sqlparser/ast_visit.go +++ b/go/vt/sqlparser/ast_visit.go @@ -811,11 +811,10 @@ func VisitRefOfAndExpr(in *AndExpr, f Visit) error { if cont, err := f(in); err 
!= nil || !cont { return err } - if err := VisitExpr(in.Left, f); err != nil { - return err - } - if err := VisitExpr(in.Right, f); err != nil { - return err + for _, el := range in.Predicates { + if err := VisitExpr(el, f); err != nil { + return err + } } return nil } diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index 391e9a84ad3..aa731d4799d 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -319,15 +319,17 @@ func (cached *AndExpr) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(32) - } - // field Left vitess.io/vitess/go/vt/sqlparser.Expr - if cc, ok := cached.Left.(cachedObject); ok { - size += cc.CachedSize(true) + size += int64(64) } // field Right vitess.io/vitess/go/vt/sqlparser.Expr - if cc, ok := cached.Right.(cachedObject); ok { - size += cc.CachedSize(true) + // field Predicates vitess.io/vitess/go/vt/sqlparser.Exprs + { + size += hack.RuntimeAllocSize(int64(cap(cached.Predicates)) * int64(16)) + for _, elem := range cached.Predicates { + if cc, ok := elem.(cachedObject); ok { + size += cc.CachedSize(true) + } + } } return size } diff --git a/go/vt/sqlparser/precedence_test.go b/go/vt/sqlparser/precedence_test.go index 0a14df5a2c1..f19a6c7696f 100644 --- a/go/vt/sqlparser/precedence_test.go +++ b/go/vt/sqlparser/precedence_test.go @@ -18,10 +18,13 @@ package sqlparser import ( "fmt" + "strings" "testing" "time" "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/slice" ) func readable(node Expr) string { @@ -29,7 +32,10 @@ func readable(node Expr) string { case *OrExpr: return fmt.Sprintf("(%s or %s)", readable(node.Left), readable(node.Right)) case *AndExpr: - return fmt.Sprintf("(%s and %s)", readable(node.Left), readable(node.Right)) + predicates := slice.Map(node.Predicates, func(from Expr) string { + return readable(from) + }) + return fmt.Sprintf("(%s)", strings.Join(predicates, " and ")) case *XorExpr: return fmt.Sprintf("(%s xor %s)", 
readable(node.Left), readable(node.Right)) case *BinaryExpr: @@ -153,7 +159,7 @@ func TestParens(t *testing.T) { {in: "((((((1000))))))", expected: "1000"}, {in: "100 - (50 + 10)", expected: "100 - (50 + 10)"}, {in: "100 - 50 + 10", expected: "100 - 50 + 10"}, - {in: "true and (true and true)", expected: "true and (true and true)"}, + {in: "true and (true and true)", expected: "true and true and true"}, {in: "10 - 2 - 1", expected: "10 - 2 - 1"}, {in: "(10 - 2) - 1", expected: "10 - 2 - 1"}, {in: "10 - (2 - 1)", expected: "10 - (2 - 1)"}, @@ -193,6 +199,6 @@ func TestRandom(t *testing.T) { // Then the unparsing should be the same as the input query outputOfParseResult := String(parsedInput) - require.Equal(t, outputOfParseResult, inputQ) + require.Equal(t, inputQ, outputOfParseResult) } } diff --git a/go/vt/sqlparser/predicate_rewriting.go b/go/vt/sqlparser/predicate_rewriting.go index 234a2f4acd5..e4e859b2ba2 100644 --- a/go/vt/sqlparser/predicate_rewriting.go +++ b/go/vt/sqlparser/predicate_rewriting.go @@ -16,7 +16,9 @@ limitations under the License. 
package sqlparser -import "slices" +import ( + "slices" +) // RewritePredicate walks the input AST and rewrites any boolean logic into a simpler form // This simpler form is CNF plus logic for extracting predicates from OR, plus logic for turning ORs into IN @@ -72,109 +74,127 @@ func simplifyNot(expr *NotExpr) (Expr, bool) { return child.Expr, true case *OrExpr: // not(or(a,b)) => and(not(a),not(b)) - return &AndExpr{Right: &NotExpr{Expr: child.Right}, Left: &NotExpr{Expr: child.Left}}, true + return AndExpressions(&NotExpr{Expr: child.Left}, &NotExpr{Expr: child.Right}), true case *AndExpr: // not(and(a,b)) => or(not(a), not(b)) - return &OrExpr{Right: &NotExpr{Expr: child.Right}, Left: &NotExpr{Expr: child.Left}}, true + var curr Expr + for i, p := range child.Predicates { + if i == 0 { + curr = &NotExpr{Expr: p} + } else { + curr = &OrExpr{Left: curr, Right: &NotExpr{Expr: p}} + } + } + return curr, true } return expr, false } -func simplifyOr(expr *OrExpr) (Expr, bool) { - res, rewritten := distinctOr(expr) +func simplifyOr(or *OrExpr) (Expr, bool) { + res, rewritten := distinctOr(or) if rewritten { return res, true } - or := expr - - // first we search for ANDs and see how they can be simplified land, lok := or.Left.(*AndExpr) rand, rok := or.Right.(*AndExpr) if lok && rok { - // (<> AND <>) OR (<> AND <>) - // or(and(T1,T2), and(T2, T3)) => and(T1, or(T2, T2)) - var a, b, c Expr - switch { - case Equals.Expr(land.Left, rand.Left): - a, b, c = land.Left, land.Right, rand.Right - return &AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, true - case Equals.Expr(land.Left, rand.Right): - a, b, c = land.Left, land.Right, rand.Left - return &AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, true - case Equals.Expr(land.Right, rand.Left): - a, b, c = land.Right, land.Left, rand.Right - return &AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, true - case Equals.Expr(land.Right, rand.Right): - a, b, c = land.Right, land.Left, rand.Left - return 
&AndExpr{Left: a, Right: &OrExpr{Left: b, Right: c}}, true + // (A AND B) OR (A AND C) => A AND (B OR C) + var commonPredicates []Expr + var leftRemainder, rightRemainder []Expr + + // Find all matching predicates and separate the remainder + rightRemainder = rand.Predicates + for _, lp := range land.Predicates { + rhs := rightRemainder + rightRemainder = nil + isCommon := false + for _, rp := range rhs { + if Equals.Expr(lp, rp) { + commonPredicates = append(commonPredicates, lp) + isCommon = true + } else { + rightRemainder = append(rightRemainder, rp) + } + } + if !isCommon { + leftRemainder = append(leftRemainder, lp) + } } - } - // (<> AND <>) OR <> - if lok { - // Simplification - if Equals.Expr(or.Right, land.Left) || Equals.Expr(or.Right, land.Right) { - // or(and(a,b), c) => c where c=a or c=b - return or.Right, true + if len(commonPredicates) > 0 { + // Build the final AndExpr with common predicates and the OrExpr of remainders + var notCommon Expr + switch { + case len(leftRemainder) == 0 && len(rightRemainder) == 0: + // all expressions were common + return AndExpressions(commonPredicates...), true + case len(leftRemainder) == 0: + notCommon = AndExpressions(rightRemainder...) + case len(rightRemainder) == 0: + notCommon = AndExpressions(leftRemainder...)
+ default: + notCommon = &OrExpr{ + Left: AndExpressions(leftRemainder...), + Right: AndExpressions(rightRemainder...), + } + } + return AndExpressions(append(commonPredicates, notCommon)...), true } - - // Distribution Law - // or(c, and(a,b)) => and(or(c,a), or(c,b)) - return &AndExpr{ - Left: &OrExpr{ - Left: land.Left, - Right: or.Right, - }, - Right: &OrExpr{ - Left: land.Right, - Right: or.Right, - }, - }, true } - - // <> OR (<> AND <>) - if rok { - // Simplification - if Equals.Expr(or.Left, rand.Left) || Equals.Expr(or.Left, rand.Right) { - // or(a,and(b,c)) => a - return or.Left, true + if !lok && !rok { + lftCmp, lok := or.Left.(*ComparisonExpr) + rgtCmp, rok := or.Right.(*ComparisonExpr) + if lok && rok { + newExpr, rewritten := tryTurningOrIntoIn(lftCmp, rgtCmp) + if rewritten { + // or(a=x,a=y) => in(a,[x,y]) + return newExpr, true + } } - // Distribution Law - // or(and(a,b), c) => and(or(c,a), or(c,b)) - return &AndExpr{ - Left: &OrExpr{Left: or.Left, Right: rand.Left}, - Right: &OrExpr{Left: or.Left, Right: rand.Right}, - }, true + return or, false } - // next, we want to try to turn multiple ORs into an IN when possible - lftCmp, lok := or.Left.(*ComparisonExpr) - rgtCmp, rok := or.Right.(*ComparisonExpr) - if lok && rok { - newExpr, rewritten := tryTurningOrIntoIn(lftCmp, rgtCmp) - if rewritten { - // or(a=x,a=y) => in(a,[x,y]) - return newExpr, true + // if we get here, one side is an AND + var and *AndExpr + var other Expr + if lok { + and = land + other = or.Right + } else { + and = rand + other = or.Left + } + + for _, lp := range and.Predicates { + if Equals.Expr(other, lp) { + // if we have the same predicate on both sides of the OR, we can simplify + // (A AND B) OR A => A + // because if A is true, the OR is true, no matter what B is, + // and if A is false, the AND is false, and again we don't care about B + return other, true } } - // Try to make distinct - result, changed := distinctOr(expr) - if changed { - return result, true + //
Distribution Law + var distributedPredicates []Expr + for _, lp := range and.Predicates { + distributedPredicates = append(distributedPredicates, &OrExpr{ + Left: lp, + Right: other, + }) } - return result, false + return AndExpressions(distributedPredicates...), true } -func simplifyXor(expr *XorExpr) (Expr, bool) { +func simplifyXor(xor *XorExpr) (Expr, bool) { // xor(a,b) => and(or(a,b), not(and(a,b)) - return &AndExpr{ - Left: &OrExpr{Left: expr.Left, Right: expr.Right}, - Right: &NotExpr{Expr: &AndExpr{Left: expr.Left, Right: expr.Right}}, - }, true + return AndExpressions( + &OrExpr{Left: xor.Left, Right: xor.Right}, + &NotExpr{Expr: AndExpressions(xor.Left, xor.Right)}, + ), true } func simplifyAnd(expr *AndExpr) (Expr, bool) { @@ -182,25 +202,36 @@ func simplifyAnd(expr *AndExpr) (Expr, bool) { if rewritten { return res, true } - and := expr - if or, ok := and.Left.(*OrExpr); ok { - // Simplification - // and(or(a,b),c) => c when c=a or c=b - if Equals.Expr(or.Left, and.Right) { - return and.Right, true - } - if Equals.Expr(or.Right, and.Right) { - return and.Right, true + + var simplifiedPredicates []Expr + simplified := false + + // Loop over all predicates in the AndExpr + for i, andPred := range expr.Predicates { + if or, ok := andPred.(*OrExpr); ok { + // Check if we can simplify by matching with another predicate in the AndExpr + for j, otherPred := range expr.Predicates { + if i == j { + continue // Skip the same predicate + } + + // Simplification: and(or(a,b), a) => a + if Equals.Expr(or.Left, otherPred) || Equals.Expr(or.Right, otherPred) { + // Found a match, keep the simpler expression (otherPred) + simplifiedPredicates = append(simplifiedPredicates, otherPred) + simplified = true + break + } + } + } else { + // No simplification possible, keep the original predicate + simplifiedPredicates = append(simplifiedPredicates, andPred) } } - if or, ok := and.Right.(*OrExpr); ok { - // Simplification - if Equals.Expr(or.Left, and.Left) { - return 
and.Left, true - } - if Equals.Expr(or.Right, and.Left) { - return and.Left, true - } + + if simplified { + // Return a new AndExpr with the simplified predicates + return AndExpressions(simplifiedPredicates...), true } return expr, false @@ -292,6 +323,7 @@ func orToSlice(expr *OrExpr) []Expr { return exprs } +// andToSlice will return a slice of comparisons, containing all the comparison expressions in the AND expression func andToSlice(expr Expr) []*ComparisonExpr { var andExpr *AndExpr switch expr := expr.(type) { @@ -320,11 +352,10 @@ func andToSlice(expr Expr) []*ComparisonExpr { return true } - if !handleAndSide(andExpr.Left) { - return nil - } - if !handleAndSide(andExpr.Right) { - return nil + for _, p := range andExpr.Predicates { + if !handleAndSide(p) { + return nil + } } return exprs @@ -438,16 +469,17 @@ func distinctAnd(in *AndExpr) (result Expr, changed bool) { for len(todo) > 0 { curr := todo[0] todo = todo[1:] - addExpr := func(in Expr) { - if and, ok := in.(*AndExpr); ok { + for _, p := range curr.Predicates { + if and, ok := p.(*AndExpr); ok { + // we will flatten the ANDs + changed = true todo = append(todo, and) } else { - leaves = append(leaves, in) + leaves = append(leaves, p) } } - addExpr(curr.Left) - addExpr(curr.Right) } + var predicates []Expr outer1: @@ -465,12 +497,5 @@ outer1: return in, false } - for i, curr := range predicates { - if i == 0 { - result = curr - continue - } - result = &AndExpr{Left: result, Right: curr} - } return AndExpressions(leaves...), true } diff --git a/go/vt/sqlparser/predicate_rewriting_test.go b/go/vt/sqlparser/predicate_rewriting_test.go index ceb4b276017..ef7799611bc 100644 --- a/go/vt/sqlparser/predicate_rewriting_test.go +++ b/go/vt/sqlparser/predicate_rewriting_test.go @@ -44,7 +44,7 @@ func TestSimplifyExpression(in *testing.T) { expected: "(A or C) and (B or C)", }, { in: "C or (A and B)", - expected: "(C or A) and (C or B)", + expected: "(A or C) and (B or C)", }, { in: "A and A", expected: "A", @@ 
-111,10 +111,13 @@ func TestRewritePredicate(in *testing.T) { expected: "A and B", }, { in: "((A and B) OR (A and C) OR (A and D)) and E and F", - expected: "A and (B or C or D) and E and F", + expected: "E and F and A and (B or C or D)", }, { - in: "(A and B) OR (A and C)", - expected: "A and (B or C)", + in: "(A and B) OR (A and C) OR (A and D)", + expected: "A and (B or C or D)", + }, { + in: "((A and B) OR (A and C)) and E", + expected: "E and A and (B or C)", }, { in: "(A and B) OR (C and A)", expected: "A and (B or C)", @@ -133,26 +136,26 @@ func TestRewritePredicate(in *testing.T) { }, { in: "(a = 1 and b = 41) or (a = 2 and b = 42)", // this might look weird, but it allows the planner to either a or b in a vindex operation - expected: "a in (1, 2) and (a = 1 or b = 42) and ((b = 41 or a = 2) and b in (41, 42))", + expected: "a in (2, 1) and (b = 42 or a = 1) and (a = 2 or b = 41) and b in (42, 41)", }, { in: "(a = 1 and b = 41) or (a = 2 and b = 42) or (a = 3 and b = 43)", - expected: "a in (1, 2, 3) and (a in (1, 2) or b = 43) and ((a = 1 or b = 42 or a = 3) and (a = 1 or b = 42 or b = 43)) and ((b = 41 or a = 2 or a = 3) and (b = 41 or a = 2 or b = 43) and ((b in (41, 42) or a = 3) and b in (41, 42, 43)))", + expected: "a in (3, 2, 1) and (b = 43 or a in (2, 1)) and (a = 3 or (b = 42 or a = 1)) and (b = 43 or (b = 42 or a = 1)) and (a = 3 or (a = 2 or b = 41)) and (b = 43 or (a = 2 or b = 41)) and (a = 3 or b in (42, 41)) and b in (43, 42, 41)", }, { // the following two tests show some pathological cases that would grow too much, and so we abort the rewriting in: "a = 1 and b = 41 or a = 2 and b = 42 or a = 3 and b = 43 or a = 4 and b = 44 or a = 5 and b = 45 or a = 6 and b = 46", expected: "a = 1 and b = 41 or a = 2 and b = 42 or a = 3 and b = 43 or a = 4 and b = 44 or a = 5 and b = 45 or a = 6 and b = 46", }, { in: "a = 5 and B or a = 6 and C", - expected: "a in (5, 6) and (a = 5 or C) and ((B or a = 6) and (B or C))", + expected: "a in (6, 5) and (C 
or a = 5) and (a = 6 or B) and (C or B)", }, { in: "(a = 5 and b = 1 or b = 2 and a = 6)", - expected: "(a = 5 or b = 2) and a in (5, 6) and (b in (1, 2) and (b = 1 or a = 6))", + expected: "(b = 2 or a = 5) and a in (6, 5) and b in (2, 1) and (a = 6 or b = 1)", }, { in: "(a in (1,5) and B or C and a = 6)", - expected: "(a in (1, 5) or C) and a in (1, 5, 6) and ((B or C) and (B or a = 6))", + expected: "(C or a in (1, 5)) and a in (6, 1, 5) and (C or B) and (a = 6 or B)", }, { in: "(a in (1, 5) and B or C and a in (5, 7))", - expected: "(a in (1, 5) or C) and a in (1, 5, 7) and ((B or C) and (B or a in (5, 7)))", + expected: "(C or a in (1, 5)) and a in (5, 7, 1) and (C or B) and (a in (5, 7) or B)", }, { in: "not n0 xor not (n2 and n3) xor (not n2 and (n1 xor n1) xor (n0 xor n0 xor n2))", expected: "not n0 xor not (n2 and n3) xor (not n2 and (n1 xor n1) xor (n0 xor n0 xor n2))", diff --git a/go/vt/sqlparser/random_expr.go b/go/vt/sqlparser/random_expr.go index f5b394b36fe..a6b7757c045 100644 --- a/go/vt/sqlparser/random_expr.go +++ b/go/vt/sqlparser/random_expr.go @@ -231,10 +231,7 @@ func (g *Generator) makeAggregateIfNecessary(genConfig ExprGeneratorConfig, expr // if the generated expression must be an aggregate, and it is not, // tack on an extra "and count(*)" to make it aggregate if genConfig.AggrRule == IsAggregate && !g.isAggregate && g.depth == 0 { - expr = &AndExpr{ - Left: expr, - Right: &CountStar{}, - } + expr = createAndExpr(expr, &CountStar{}) g.isAggregate = true } @@ -269,7 +266,7 @@ func (g *Generator) booleanExpr(genConfig ExprGeneratorConfig) Expr { func() Expr { return g.orExpr(genConfig) }, func() Expr { return g.comparison(genConfig.intTypeConfig()) }, func() Expr { return g.comparison(genConfig.stringTypeConfig()) }, - //func() Expr { return g.comparison(genConfig) }, // this is not accepted by the parser + // func() Expr { return g.comparison(genConfig) }, // this is not accepted by the parser func() Expr { return g.inExpr(genConfig) }, 
func() Expr { return g.existsExpr(genConfig) }, func() Expr { return g.between(genConfig.intTypeConfig()) }, @@ -374,7 +371,7 @@ func (g *Generator) randomBool(prob float32) bool { } func (g *Generator) intLiteral() Expr { - t := fmt.Sprintf("%d", rand.IntN(100)-rand.IntN(100)) //nolint SA4000 + t := fmt.Sprintf("%d", rand.IntN(100)-rand.IntN(100)) // nolint SA4000 return NewIntLiteral(t) } @@ -477,10 +474,10 @@ func (g *Generator) randomOfS(options []string) string { func (g *Generator) andExpr(genConfig ExprGeneratorConfig) Expr { g.enter() defer g.exit() - return &AndExpr{ - Left: g.Expression(genConfig), - Right: g.Expression(genConfig), - } + return createAndExpr( + g.Expression(genConfig), + g.Expression(genConfig), + ) } func (g *Generator) orExpr(genConfig ExprGeneratorConfig) Expr { diff --git a/go/vt/sqlparser/sql.go b/go/vt/sqlparser/sql.go index 9912b19f323..09068dd02f9 100644 --- a/go/vt/sqlparser/sql.go +++ b/go/vt/sqlparser/sql.go @@ -17811,7 +17811,7 @@ yydefault: var yyLOCAL Expr //line sql.y:5308 { - yyLOCAL = &AndExpr{Left: yyDollar[1].exprUnion(), Right: yyDollar[3].exprUnion()} + yyLOCAL = createAndExpr(yyDollar[1].exprUnion(), yyDollar[3].exprUnion()) } yyVAL.union = yyLOCAL case 1013: diff --git a/go/vt/sqlparser/sql.y b/go/vt/sqlparser/sql.y index 64ce957d2dd..880f00939cd 100644 --- a/go/vt/sqlparser/sql.y +++ b/go/vt/sqlparser/sql.y @@ -5306,7 +5306,7 @@ expression: } | expression AND expression %prec AND { - $$ = &AndExpr{Left: $1, Right: $3} + $$ = createAndExpr($1,$3) } | NOT expression %prec NOT { diff --git a/go/vt/sqlparser/testdata/select_cases.txt b/go/vt/sqlparser/testdata/select_cases.txt index 157b2ebfe99..804868796d4 100644 --- a/go/vt/sqlparser/testdata/select_cases.txt +++ b/go/vt/sqlparser/testdata/select_cases.txt @@ -3872,7 +3872,7 @@ INPUT select a1,a2,b,min(c) from t1 where ((a1 > 'a') or (a1 < '9')) and ((a2 >= 'b') and (a2 < 'z')) and (b = 'a') and ((c < 'h112') or (c = 'j121') or (c > 'k121' and c < 'm122') or (c > 
'o122')) group by a1,a2,b; END OUTPUT -select a1, a2, b, min(c) from t1 where (a1 > 'a' or a1 < '9') and (a2 >= 'b' and a2 < 'z') and b = 'a' and (c < 'h112' or c = 'j121' or c > 'k121' and c < 'm122' or c > 'o122') group by a1, a2, b +select a1, a2, b, min(c) from t1 where (a1 > 'a' or a1 < '9') and a2 >= 'b' and a2 < 'z' and b = 'a' and (c < 'h112' or c = 'j121' or c > 'k121' and c < 'm122' or c > 'o122') group by a1, a2, b END INPUT select a, quote(a), isnull(quote(a)), quote(a) is null, ifnull(quote(a), 'n') from t1; @@ -6434,7 +6434,7 @@ INPUT select distinct t1.project_id as project_id, t1.project_name as project_name, t1.client_ptr as client_ptr, t1.comments as comments, sum( t3.amount_received ) + sum( t3.adjustment ) as total_budget from t2 as client_period , t2 as project_period, t3 left join t1 on (t3.project_ptr = t1.project_id and t3.date_received <= '2001-03-22 14:15:09') left join t4 on t4.client_id = t1.client_ptr where 1 and ( client_period.period_type = 'client_table' and client_period.period_key = t4.client_id and ( client_period.start_date <= '2001-03-22 14:15:09' or isnull( client_period.start_date )) and ( client_period.end_date > '2001-03-21 14:15:09' or isnull( client_period.end_date )) ) and ( project_period.period_type = 'project_table' and project_period.period_key = t1.project_id and ( project_period.start_date <= '2001-03-22 14:15:09' or isnull( project_period.start_date )) and ( project_period.end_date > '2001-03-21 14:15:09' or isnull( project_period.end_date )) ) group by client_id, project_id , client_period.period_id , project_period.period_id order by client_name asc, project_name asc; END OUTPUT -select distinct t1.project_id as project_id, t1.project_name as project_name, t1.client_ptr as client_ptr, t1.comments as comments, sum(t3.amount_received) + sum(t3.adjustment) as total_budget from t2 as client_period, t2 as project_period, t3 left join t1 on t3.project_ptr = t1.project_id and t3.date_received <= '2001-03-22 14:15:09' 
left join t4 on t4.client_id = t1.client_ptr where 1 and (client_period.period_type = 'client_table' and client_period.period_key = t4.client_id and (client_period.start_date <= '2001-03-22 14:15:09' or isnull(client_period.start_date)) and (client_period.end_date > '2001-03-21 14:15:09' or isnull(client_period.end_date))) and (project_period.period_type = 'project_table' and project_period.period_key = t1.project_id and (project_period.start_date <= '2001-03-22 14:15:09' or isnull(project_period.start_date)) and (project_period.end_date > '2001-03-21 14:15:09' or isnull(project_period.end_date))) group by client_id, project_id, client_period.period_id, project_period.period_id order by client_name asc, project_name asc +select distinct t1.project_id as project_id, t1.project_name as project_name, t1.client_ptr as client_ptr, t1.comments as comments, sum(t3.amount_received) + sum(t3.adjustment) as total_budget from t2 as client_period, t2 as project_period, t3 left join t1 on t3.project_ptr = t1.project_id and t3.date_received <= '2001-03-22 14:15:09' left join t4 on t4.client_id = t1.client_ptr where 1 and client_period.period_type = 'client_table' and client_period.period_key = t4.client_id and (client_period.start_date <= '2001-03-22 14:15:09' or isnull(client_period.start_date)) and (client_period.end_date > '2001-03-21 14:15:09' or isnull(client_period.end_date)) and project_period.period_type = 'project_table' and project_period.period_key = t1.project_id and (project_period.start_date <= '2001-03-22 14:15:09' or isnull(project_period.start_date)) and (project_period.end_date > '2001-03-21 14:15:09' or isnull(project_period.end_date)) group by client_id, project_id, client_period.period_id, project_period.period_id order by client_name asc, project_name asc END INPUT select a1,a2,b,min(c) from t2 where b is NULL group by a1,a2; @@ -9368,7 +9368,7 @@ INPUT select a1,a2,b,min(c) from t1 where ((a1 > 'a') or (a1 < '9')) and ((a2 >= 'b') and (a2 < 'z')) and (b = 
'a') and ((c = 'j121') or (c > 'k121' and c < 'm122') or (c > 'o122') or (c < 'h112') or (c = 'c111')) group by a1,a2,b; END OUTPUT -select a1, a2, b, min(c) from t1 where (a1 > 'a' or a1 < '9') and (a2 >= 'b' and a2 < 'z') and b = 'a' and (c = 'j121' or c > 'k121' and c < 'm122' or c > 'o122' or c < 'h112' or c = 'c111') group by a1, a2, b +select a1, a2, b, min(c) from t1 where (a1 > 'a' or a1 < '9') and a2 >= 'b' and a2 < 'z' and b = 'a' and (c = 'j121' or c > 'k121' and c < 'm122' or c > 'o122' or c < 'h112' or c = 'c111') group by a1, a2, b END INPUT select count(distinct n1) from t1; @@ -10886,7 +10886,7 @@ INPUT select a1,a2,b,min(c) from t2 where ((a1 > 'a') or (a1 < '9')) and ((a2 >= 'b') and (a2 < 'z')) and (b = 'a') and ((c = 'j121') or (c > 'k121' and c < 'm122') or (c > 'o122') or (c < 'h112') or (c = 'c111')) group by a1,a2,b; END OUTPUT -select a1, a2, b, min(c) from t2 where (a1 > 'a' or a1 < '9') and (a2 >= 'b' and a2 < 'z') and b = 'a' and (c = 'j121' or c > 'k121' and c < 'm122' or c > 'o122' or c < 'h112' or c = 'c111') group by a1, a2, b +select a1, a2, b, min(c) from t2 where (a1 > 'a' or a1 < '9') and a2 >= 'b' and a2 < 'z' and b = 'a' and (c = 'j121' or c > 'k121' and c < 'm122' or c > 'o122' or c < 'h112' or c = 'c111') group by a1, a2, b END INPUT select * from t1 where a > 5 xor a < 10; @@ -20882,7 +20882,7 @@ INPUT select a1,a2,b,min(c) from t2 where ((a1 > 'a') or (a1 < '9')) and ((a2 >= 'b') and (a2 < 'z')) and (b = 'a') and ((c < 'h112') or (c = 'j121') or (c > 'k121' and c < 'm122') or (c > 'o122')) group by a1,a2,b; END OUTPUT -select a1, a2, b, min(c) from t2 where (a1 > 'a' or a1 < '9') and (a2 >= 'b' and a2 < 'z') and b = 'a' and (c < 'h112' or c = 'j121' or c > 'k121' and c < 'm122' or c > 'o122') group by a1, a2, b +select a1, a2, b, min(c) from t2 where (a1 > 'a' or a1 < '9') and a2 >= 'b' and a2 < 'z' and b = 'a' and (c < 'h112' or c = 'j121' or c > 'k121' and c < 'm122' or c > 'o122') group by a1, a2, b END INPUT select 
group_concat(c1 order by c1) from t1 group by c1 collate utf8_latvian_ci; diff --git a/go/vt/sqlparser/utils.go b/go/vt/sqlparser/utils.go index b785128917f..62484027d00 100644 --- a/go/vt/sqlparser/utils.go +++ b/go/vt/sqlparser/utils.go @@ -97,10 +97,7 @@ func (p *Parser) NormalizeAlphabetically(query string) (normalized string, err e Expr: expr, } } else { - newWhere.Expr = &AndExpr{ - Left: newWhere.Expr, - Right: expr, - } + newWhere.Expr = AndExpressions(newWhere.Expr, expr) } } switch stmt := stmt.(type) { From cf7430e398e988067b19d06f3223f4d4e5efc750 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Tue, 27 Aug 2024 17:03:23 +0200 Subject: [PATCH 02/14] change semantics and planner to handle the new AndExpr Signed-off-by: Andres Taylor --- go/vt/sqlparser/ast.go | 2 +- go/vt/sqlparser/ast_clone.go | 26 ++++++------ go/vt/sqlparser/ast_copy_on_rewrite.go | 12 +----- go/vt/sqlparser/ast_equals.go | 28 ++++++------- go/vt/sqlparser/ast_funcs.go | 11 +++++ go/vt/sqlparser/ast_rewrite.go | 12 ++---- go/vt/sqlparser/ast_visit.go | 6 +-- go/vt/sqlparser/predicate_rewriting.go | 11 +++-- go/vt/vtgate/evalengine/translate.go | 30 +++++++++++--- .../planbuilder/operators/querygraph.go | 5 +-- .../vtgate/planbuilder/operators/subquery.go | 2 +- go/vt/vtgate/planbuilder/operators/update.go | 40 ++++++++----------- .../planbuilder/predicate_rewrite_test.go | 8 ++-- go/vt/vtgate/semantics/early_rewriter.go | 19 +++++---- go/vt/vtgate/semantics/scoper.go | 5 +-- go/vt/vtgate/semantics/semantic_table.go | 2 +- .../simplifier/expression_simplifier.go | 7 +++- 17 files changed, 122 insertions(+), 104 deletions(-) diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 124833bc91d..7ddb559ebf9 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -2263,7 +2263,7 @@ type ( } // AndExpr represents an AND expression. - AndExpr struct{ Predicates []Expr } + AndExpr struct{ Predicates Exprs } // OrExpr represents an OR expression. 
OrExpr struct { diff --git a/go/vt/sqlparser/ast_clone.go b/go/vt/sqlparser/ast_clone.go index 00330ff1972..99122454bba 100644 --- a/go/vt/sqlparser/ast_clone.go +++ b/go/vt/sqlparser/ast_clone.go @@ -744,7 +744,7 @@ func CloneRefOfAndExpr(n *AndExpr) *AndExpr { return nil } out := *n - out.Predicates = CloneSliceOfExpr(n.Predicates) + out.Predicates = CloneExprs(n.Predicates) return &out } @@ -4368,18 +4368,6 @@ func CloneSliceOfIdentifierCI(n []IdentifierCI) []IdentifierCI { return res } -// CloneSliceOfExpr creates a deep clone of the input. -func CloneSliceOfExpr(n []Expr) []Expr { - if n == nil { - return nil - } - res := make([]Expr, len(n)) - for i, x := range n { - res[i] = CloneExpr(x) - } - return res -} - // CloneSliceOfTxAccessMode creates a deep clone of the input. func CloneSliceOfTxAccessMode(n []TxAccessMode) []TxAccessMode { if n == nil { @@ -4469,6 +4457,18 @@ func CloneSliceOfRefOfVariable(n []*Variable) []*Variable { return res } +// CloneSliceOfExpr creates a deep clone of the input. +func CloneSliceOfExpr(n []Expr) []Expr { + if n == nil { + return nil + } + res := make([]Expr, len(n)) + for i, x := range n { + res[i] = CloneExpr(x) + } + return res +} + // CloneRefOfIdentifierCI creates a deep clone of the input. 
func CloneRefOfIdentifierCI(n *IdentifierCI) *IdentifierCI { if n == nil { diff --git a/go/vt/sqlparser/ast_copy_on_rewrite.go b/go/vt/sqlparser/ast_copy_on_rewrite.go index 2caea4bd978..594cc21df8c 100644 --- a/go/vt/sqlparser/ast_copy_on_rewrite.go +++ b/go/vt/sqlparser/ast_copy_on_rewrite.go @@ -953,18 +953,10 @@ func (c *cow) copyOnRewriteRefOfAndExpr(n *AndExpr, parent SQLNode) (out SQLNode } out = n if c.pre == nil || c.pre(n, parent) { - var changedPredicates bool - _Predicates := make([]Expr, len(n.Predicates)) - for x, el := range n.Predicates { - this, changed := c.copyOnRewriteExpr(el, n) - _Predicates[x] = this.(Expr) - if changed { - changedPredicates = true - } - } + _Predicates, changedPredicates := c.copyOnRewriteExprs(n.Predicates, n) if changedPredicates { res := *n - res.Predicates = _Predicates + res.Predicates, _ = _Predicates.(Exprs) out = &res if c.cloned != nil { c.cloned(n, out) diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index 66ec6190159..4d5335da125 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -1863,7 +1863,7 @@ func (cmp *Comparator) RefOfAndExpr(a, b *AndExpr) bool { if a == nil || b == nil { return false } - return cmp.SliceOfExpr(a.Predicates, b.Predicates) + return cmp.Exprs(a.Predicates, b.Predicates) } // RefOfAnyValue does deep equals between the two objects. @@ -7245,19 +7245,6 @@ func (cmp *Comparator) SliceOfIdentifierCI(a, b []IdentifierCI) bool { return true } -// SliceOfExpr does deep equals between the two objects. -func (cmp *Comparator) SliceOfExpr(a, b []Expr) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !cmp.Expr(a[i], b[i]) { - return false - } - } - return true -} - // SliceOfTxAccessMode does deep equals between the two objects. 
func (cmp *Comparator) SliceOfTxAccessMode(a, b []TxAccessMode) bool { if len(a) != len(b) { @@ -7366,6 +7353,19 @@ func (cmp *Comparator) SliceOfRefOfVariable(a, b []*Variable) bool { return true } +// SliceOfExpr does deep equals between the two objects. +func (cmp *Comparator) SliceOfExpr(a, b []Expr) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !cmp.Expr(a[i], b[i]) { + return false + } + } + return true +} + // RefOfIdentifierCI does deep equals between the two objects. func (cmp *Comparator) RefOfIdentifierCI(a, b *IdentifierCI) bool { if a == b { diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index db6965de1f3..73247d9a322 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -2360,6 +2360,17 @@ func SplitAndExpression(filters []Expr, node Expr) []Expr { return append(filters, node) } +func CreateAndExpr(exprs ...Expr) Expr { + switch len(exprs) { + case 0: + return nil + case 1: + return exprs[0] + default: + return &AndExpr{Predicates: exprs} + } +} + // AndExpressions ands together two or more expressions, minimising the expr when possible func AndExpressions(exprs ...Expr) Expr { switch len(exprs) { diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index df144e4fad8..02e4d9411fd 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -1088,14 +1088,10 @@ func (a *application) rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replace return true } } - for x, el := range node.Predicates { - if !a.rewriteExpr(node, el, func(idx int) replacerFunc { - return func(newNode, parent SQLNode) { - parent.(*AndExpr).Predicates[idx] = newNode.(Expr) - } - }(x)) { - return false - } + if !a.rewriteExprs(node, node.Predicates, func(newNode, parent SQLNode) { + parent.(*AndExpr).Predicates = newNode.(Exprs) + }) { + return false } if a.post != nil { a.cur.replacer = replacer diff --git a/go/vt/sqlparser/ast_visit.go 
b/go/vt/sqlparser/ast_visit.go index 04608d84ec2..f299a131a9f 100644 --- a/go/vt/sqlparser/ast_visit.go +++ b/go/vt/sqlparser/ast_visit.go @@ -811,10 +811,8 @@ func VisitRefOfAndExpr(in *AndExpr, f Visit) error { if cont, err := f(in); err != nil || !cont { return err } - for _, el := range in.Predicates { - if err := VisitExpr(el, f); err != nil { - return err - } + if err := VisitExprs(in.Predicates, f); err != nil { + return err } return nil } diff --git a/go/vt/sqlparser/predicate_rewriting.go b/go/vt/sqlparser/predicate_rewriting.go index e4e859b2ba2..a6da91c8c7f 100644 --- a/go/vt/sqlparser/predicate_rewriting.go +++ b/go/vt/sqlparser/predicate_rewriting.go @@ -181,10 +181,13 @@ func simplifyOr(or *OrExpr) (Expr, bool) { // Distribution Law var distributedPredicates []Expr for _, lp := range and.Predicates { - distributedPredicates = append(distributedPredicates, &OrExpr{ - Left: lp, - Right: other, - }) + var or *OrExpr + if lok { + or = &OrExpr{Left: lp, Right: other} + } else { + or = &OrExpr{Left: other, Right: lp} + } + distributedPredicates = append(distributedPredicates, or) } return AndExpressions(distributedPredicates...), true } diff --git a/go/vt/vtgate/evalengine/translate.go b/go/vt/vtgate/evalengine/translate.go index 99e1508cc04..a887f187e16 100644 --- a/go/vt/vtgate/evalengine/translate.go +++ b/go/vt/vtgate/evalengine/translate.go @@ -99,15 +99,35 @@ func (ast *astCompiler) translateLogicalNot(node *sqlparser.NotExpr) (IR, error) return &NotExpr{UnaryExpr{inner}}, nil } +func (ast *astCompiler) translateLogicalAnd(node *sqlparser.AndExpr) (IR, error) { + var acc IR + for i, pred := range node.Predicates { + ir, err := ast.translateExpr(pred) + if err != nil { + return nil, err + } + if i == 0 { + acc = ir + continue + } + + acc = &LogicalExpr{ + BinaryExpr: BinaryExpr{ + Left: acc, + Right: ir, + }, + op: opLogicalAnd{}, + } + } + + return acc, nil +} + func (ast *astCompiler) translateLogicalExpr(node sqlparser.Expr) (IR, error) { var left, 
right sqlparser.Expr var logic opLogical switch n := node.(type) { - case *sqlparser.AndExpr: - left = n.Left - right = n.Right - logic = opLogicalAnd{} case *sqlparser.OrExpr: left = n.Left right = n.Right @@ -521,7 +541,7 @@ func (ast *astCompiler) translateExpr(e sqlparser.Expr) (IR, error) { case *sqlparser.Literal: return translateLiteral(node, ast.cfg.Collation) case *sqlparser.AndExpr: - return ast.translateLogicalExpr(node) + return ast.translateLogicalAnd(node) case *sqlparser.OrExpr: return ast.translateLogicalExpr(node) case *sqlparser.XorExpr: diff --git a/go/vt/vtgate/planbuilder/operators/querygraph.go b/go/vt/vtgate/planbuilder/operators/querygraph.go index 8e8572f7dfa..e0231e26b3b 100644 --- a/go/vt/vtgate/planbuilder/operators/querygraph.go +++ b/go/vt/vtgate/planbuilder/operators/querygraph.go @@ -141,10 +141,7 @@ func (qg *QueryGraph) addNoDepsPredicate(predicate sqlparser.Expr) { if qg.NoDeps == nil { qg.NoDeps = predicate } else { - qg.NoDeps = &sqlparser.AndExpr{ - Left: qg.NoDeps, - Right: predicate, - } + qg.NoDeps = sqlparser.AndExpressions(qg.NoDeps, predicate) } } diff --git a/go/vt/vtgate/planbuilder/operators/subquery.go b/go/vt/vtgate/planbuilder/operators/subquery.go index b919bbfaed9..7b5e7b0515d 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery.go +++ b/go/vt/vtgate/planbuilder/operators/subquery.go @@ -298,7 +298,7 @@ func (sq *SubQuery) settleFilter(ctx *plancontext.PlanningContext, outer Operato // lead to better routing. This however might not always be true for example we can have the rhsPred to be something like // `user.id = 2 OR (:__sq_has_values AND user.id IN ::sql1)` if andExpr, isAndExpr := rhsPred.(*sqlparser.AndExpr); isAndExpr { - predicates = append(predicates, andExpr.Left, andExpr.Right) + predicates = append(predicates, andExpr.Predicates...) 
} else { predicates = append(predicates, rhsPred) } diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index b4f0a37914e..56d516332a0 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -743,10 +743,7 @@ func buildChildUpdOpForSetNull( updatedTable, updateExprs, fk, updatedTable.GetTableName(), nonLiteralUpdateInfo, false /* appendQualifier */) if compExpr != nil { - childWhereExpr = &sqlparser.AndExpr{ - Left: childWhereExpr, - Right: compExpr, - } + childWhereExpr = sqlparser.AndExpressions(childWhereExpr, compExpr) } parsedComments := getParsedCommentsForFkChecks(ctx) childUpdStmt := &sqlparser.Update{ @@ -847,13 +844,12 @@ func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, upda var predicate sqlparser.Expr = parentIsNullExpr var joinExpr sqlparser.Expr if matchedExpr == nil { - predicate = &sqlparser.AndExpr{ - Left: predicate, - Right: &sqlparser.IsExpr{ + predicate = sqlparser.AndExpressions( + predicate, + &sqlparser.IsExpr{ Left: sqlparser.NewColNameWithQualifier(pFK.ChildColumns[idx].String(), childTbl), Right: sqlparser.IsNotNullOp, - }, - } + }) joinExpr = &sqlparser.ComparisonExpr{ Operator: sqlparser.EqualOp, Left: sqlparser.NewColNameWithQualifier(pFK.ParentColumns[idx].String(), parentTbl), @@ -868,35 +864,33 @@ func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, upda Left: sqlparser.NewColNameWithQualifier(pFK.ParentColumns[idx].String(), parentTbl), Right: prefixedMatchExpr, } - predicate = &sqlparser.AndExpr{ - Left: predicate, - Right: &sqlparser.IsExpr{ + predicate = sqlparser.AndExpressions( + predicate, + &sqlparser.IsExpr{ Left: prefixedMatchExpr, Right: sqlparser.IsNotNullOp, - }, - } + }) } if idx == 0 { joinCond, whereCond = joinExpr, predicate continue } - joinCond = &sqlparser.AndExpr{Left: joinCond, Right: joinExpr} - whereCond = &sqlparser.AndExpr{Left: whereCond, Right: 
predicate} + joinCond = sqlparser.AndExpressions(joinCond, joinExpr) + whereCond = sqlparser.AndExpressions(whereCond, predicate) } - whereCond = &sqlparser.AndExpr{ - Left: whereCond, - Right: &sqlparser.NotExpr{ + whereCond = sqlparser.AndExpressions( + whereCond, + &sqlparser.NotExpr{ Expr: &sqlparser.ComparisonExpr{ Operator: sqlparser.NullSafeEqualOp, Left: notEqualColNames, Right: notEqualExprs, }, - }, - } + }) // add existing where condition on the update statement if updStmt.Where != nil { - whereCond = &sqlparser.AndExpr{Left: whereCond, Right: prefixColNames(ctx, childTbl, updStmt.Where.Expr)} + whereCond = sqlparser.AndExpressions(whereCond, prefixColNames(ctx, childTbl, updStmt.Where.Expr)) } return createSelectionOp(ctx, sqlparser.SelectExprs{sqlparser.NewAliasedExpr(sqlparser.NewIntLiteral("1"), "")}, @@ -959,7 +953,7 @@ func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updat joinCond = joinExpr continue } - joinCond = &sqlparser.AndExpr{Left: joinCond, Right: joinExpr} + joinCond = sqlparser.AndExpressions(joinCond, joinExpr) } var whereCond sqlparser.Expr diff --git a/go/vt/vtgate/planbuilder/predicate_rewrite_test.go b/go/vt/vtgate/planbuilder/predicate_rewrite_test.go index 4945c2bb7ff..2f262f75a7f 100644 --- a/go/vt/vtgate/planbuilder/predicate_rewrite_test.go +++ b/go/vt/vtgate/planbuilder/predicate_rewrite_test.go @@ -64,10 +64,10 @@ func (tc testCase) createPredicate(lvl int) sqlparser.Expr { Expr: tc.createPredicate(lvl + 1), } case AND: - return &sqlparser.AndExpr{ - Left: tc.createPredicate(lvl + 1), - Right: tc.createPredicate(lvl + 1), - } + return sqlparser.AndExpressions( + tc.createPredicate(lvl+1), + tc.createPredicate(lvl+1), + ) case OR: return &sqlparser.OrExpr{ Left: tc.createPredicate(lvl + 1), diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index ee12765e984..f38259735c5 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ 
b/go/vt/vtgate/semantics/early_rewriter.go @@ -864,13 +864,13 @@ func rewriteOrExpr(env *vtenv.Environment, cursor *sqlparser.Cursor, node *sqlpa // rewriteAndExpr rewrites AND expressions when either side is TRUE. func rewriteAndExpr(env *vtenv.Environment, cursor *sqlparser.Cursor, node *sqlparser.AndExpr) { - newNode := rewriteAndTrue(env, *node) + newNode := rewriteAndTrue(env, node) if newNode != nil { cursor.ReplaceAndRevisit(newNode) } } -func rewriteAndTrue(env *vtenv.Environment, andExpr sqlparser.AndExpr) sqlparser.Expr { +func rewriteAndTrue(env *vtenv.Environment, andExpr *sqlparser.AndExpr) sqlparser.Expr { // we are looking for the pattern `WHERE c = 1 AND 1 = 1` isTrue := func(subExpr sqlparser.Expr) bool { coll := env.CollationEnv().DefaultConnectionCharset() @@ -896,13 +896,18 @@ func rewriteAndTrue(env *vtenv.Environment, andExpr sqlparser.AndExpr) sqlparser return boolValue } - if isTrue(andExpr.Left) { - return andExpr.Right - } else if isTrue(andExpr.Right) { - return andExpr.Left + var remaining sqlparser.Exprs + for _, p := range andExpr.Predicates { + if !isTrue(p) { + remaining = append(remaining, p) + } } - return nil + if len(remaining) == len(andExpr.Predicates) { + return nil + } + + return sqlparser.AndExpressions(remaining...) } // handleComparisonExpr processes Comparison expressions, specifically for tuples with equal length and EqualOp operator. diff --git a/go/vt/vtgate/semantics/scoper.go b/go/vt/vtgate/semantics/scoper.go index 9d596d9ecd1..0c0dfe1fa34 100644 --- a/go/vt/vtgate/semantics/scoper.go +++ b/go/vt/vtgate/semantics/scoper.go @@ -309,10 +309,7 @@ func (s *scoper) createSpecialScopePostProjection(parent sqlparser.SQLNode) erro // at this stage, we don't store the actual dependencies, we only store the expressions. // only later will we walk the expression tree and figure out the deps. 
so, we need to create a // composite expression that contains all the expressions in the SELECTs that this UNION consists of - tableInfo.cols[i] = &sqlparser.AndExpr{ - Left: col, - Right: thisTableInfo.cols[i], - } + tableInfo.cols[i] = sqlparser.CreateAndExpr(col, thisTableInfo.cols[i]) } } diff --git a/go/vt/vtgate/semantics/semantic_table.go b/go/vt/vtgate/semantics/semantic_table.go index 6738546fe37..81667b2d898 100644 --- a/go/vt/vtgate/semantics/semantic_table.go +++ b/go/vt/vtgate/semantics/semantic_table.go @@ -911,7 +911,7 @@ func (st *SemTable) AndExpressions(exprs ...sqlparser.Expr) sqlparser.Expr { continue outer } } - result = &sqlparser.AndExpr{Left: result, Right: expr} + result = sqlparser.AndExpressions(result, expr) } return result } diff --git a/go/vt/vtgate/simplifier/expression_simplifier.go b/go/vt/vtgate/simplifier/expression_simplifier.go index b64402cfaac..2a3312533bc 100644 --- a/go/vt/vtgate/simplifier/expression_simplifier.go +++ b/go/vt/vtgate/simplifier/expression_simplifier.go @@ -20,6 +20,8 @@ import ( "fmt" "strconv" + "vitess.io/vitess/go/slice" + "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/sqlparser" ) @@ -97,7 +99,10 @@ func (s *shrinker) fillQueue() bool { before := len(s.queue) switch e := s.orig.(type) { case *sqlparser.AndExpr: - s.queue = append(s.queue, e.Left, e.Right) + addThese := slice.Map(e.Predicates, func(e sqlparser.Expr) sqlparser.SQLNode { + return e + }) + s.queue = append(s.queue, addThese...) 
case *sqlparser.OrExpr: s.queue = append(s.queue, e.Left, e.Right) case *sqlparser.ComparisonExpr: From b0c519296ee598d07a6a769856a37c8456aa67bb Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 28 Aug 2024 07:47:17 +0200 Subject: [PATCH 03/14] last few uses of Left/Right Signed-off-by: Andres Taylor --- go/vt/sqlparser/ast.go | 5 +++ go/vt/vtctl/workflow/materializer.go | 4 +- go/vt/vtctl/workflow/traffic_switcher.go | 2 +- go/vt/vtctl/workflow/utils.go | 17 -------- go/vt/vtctl/workflow/vexec/query_planner.go | 15 ++----- .../tabletserver/vstreamer/planbuilder.go | 17 +------- go/vt/wrangler/materializer.go | 16 +------- go/vt/wrangler/vdiff.go | 15 +++---- go/vt/wrangler/vexec_plan.go | 39 ++++++------------- 9 files changed, 31 insertions(+), 99 deletions(-) diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 7ddb559ebf9..9d01e87f3c1 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -720,6 +720,11 @@ type ( // IndexType is the type of index in a DDL statement IndexType int8 + + WhereAble interface { + AddWhere(e Expr) + GetWherePredicate() Expr + } ) var _ OrderAndLimit = (*Select)(nil) diff --git a/go/vt/vtctl/workflow/materializer.go b/go/vt/vtctl/workflow/materializer.go index f0171f31cab..7719bc81318 100644 --- a/go/vt/vtctl/workflow/materializer.go +++ b/go/vt/vtctl/workflow/materializer.go @@ -234,10 +234,10 @@ func (mz *materializer) generateBinlogSources(ctx context.Context, targetShard * Name: sqlparser.NewIdentifierCI("in_keyrange"), Exprs: subExprs, } - addFilter(sel, inKeyRange) + sel.AddWhere(inKeyRange) } if tenantClause != nil { - addFilter(sel, *tenantClause) + sel.AddWhere(*tenantClause) } rule.Filter = sqlparser.String(sel) bls.Filter.Rules = append(bls.Filter.Rules, rule) diff --git a/go/vt/vtctl/workflow/traffic_switcher.go b/go/vt/vtctl/workflow/traffic_switcher.go index 34e1e4e4329..b10f7a9645e 100644 --- a/go/vt/vtctl/workflow/traffic_switcher.go +++ b/go/vt/vtctl/workflow/traffic_switcher.go @@ 
-968,7 +968,7 @@ func (ts *trafficSwitcher) addTenantFilter(ctx context.Context, filter string) ( if !ok { return "", fmt.Errorf("unrecognized statement: %s", filter) } - addFilter(sel, *tenantClause) + sel.AddWhere(*tenantClause) filter = sqlparser.String(sel) return filter, nil } diff --git a/go/vt/vtctl/workflow/utils.go b/go/vt/vtctl/workflow/utils.go index 374d96396f2..5abe1689084 100644 --- a/go/vt/vtctl/workflow/utils.go +++ b/go/vt/vtctl/workflow/utils.go @@ -793,23 +793,6 @@ func LegacyBuildTargets(ctx context.Context, ts *topo.Server, tmc tmclient.Table }, nil } -func addFilter(sel *sqlparser.Select, filter sqlparser.Expr) { - if sel.Where != nil { - sel.Where = &sqlparser.Where{ - Type: sqlparser.WhereClause, - Expr: &sqlparser.AndExpr{ - Left: filter, - Right: sel.Where.Expr, - }, - } - } else { - sel.Where = &sqlparser.Where{ - Type: sqlparser.WhereClause, - Expr: filter, - } - } -} - func getTenantClause(vrOptions *vtctldatapb.WorkflowOptions, targetVSchema *vindexes.KeyspaceSchema, parser *sqlparser.Parser) (*sqlparser.Expr, error) { if vrOptions.TenantId == "" { diff --git a/go/vt/vtctl/workflow/vexec/query_planner.go b/go/vt/vtctl/workflow/vexec/query_planner.go index de052efce8c..9d16dc72f55 100644 --- a/go/vt/vtctl/workflow/vexec/query_planner.go +++ b/go/vt/vtctl/workflow/vexec/query_planner.go @@ -354,10 +354,7 @@ func (planner *VReplicationLogQueryPlanner) planSelect(sel *sqlparser.Select) (Q case nil: targetWhere.Expr = expr default: - targetWhere.Expr = &sqlparser.AndExpr{ - Left: expr, - Right: where.Expr, - } + targetWhere.Expr = sqlparser.CreateAndExpr(expr, where.Expr) } sel.Where = targetWhere @@ -408,10 +405,7 @@ func addDefaultWheres(planner QueryPlanner, where *sqlparser.Where) *sqlparser.W Expr: expr, } default: - newWhere.Expr = &sqlparser.AndExpr{ - Left: newWhere.Expr, - Right: expr, - } + newWhere.Expr = sqlparser.CreateAndExpr(newWhere.Expr, expr) } } @@ -424,10 +418,7 @@ func addDefaultWheres(planner QueryPlanner, where 
*sqlparser.Where) *sqlparser.W Right: sqlparser.NewStrLiteral(params.Workflow), } - newWhere.Expr = &sqlparser.AndExpr{ - Left: newWhere.Expr, - Right: expr, - } + newWhere.Expr = sqlparser.CreateAndExpr(newWhere.Expr, expr) } return newWhere diff --git a/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go b/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go index 9bbc98ca2bd..55988640fc3 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go +++ b/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go @@ -518,7 +518,7 @@ func (plan *Plan) analyzeWhere(vschema *localVSchema, where *sqlparser.Where) er if where == nil { return nil } - exprs := splitAndExpression(nil, where.Expr) + exprs := sqlparser.SplitAndExpression(nil, where.Expr) for _, expr := range exprs { switch expr := expr.(type) { case *sqlparser.ComparisonExpr: @@ -595,21 +595,6 @@ func (plan *Plan) analyzeWhere(vschema *localVSchema, where *sqlparser.Where) er return nil } -// splitAndExpression breaks up the Expr into AND-separated conditions -// and appends them to filters, which can be shuffled and recombined -// as needed. 
-func splitAndExpression(filters []sqlparser.Expr, node sqlparser.Expr) []sqlparser.Expr { - if node == nil { - return filters - } - switch node := node.(type) { - case *sqlparser.AndExpr: - filters = splitAndExpression(filters, node.Left) - return splitAndExpression(filters, node.Right) - } - return append(filters, node) -} - func (plan *Plan) analyzeExprs(vschema *localVSchema, selExprs sqlparser.SelectExprs) error { if _, ok := selExprs[0].(*sqlparser.StarExpr); !ok { for _, expr := range selExprs { diff --git a/go/vt/wrangler/materializer.go b/go/vt/wrangler/materializer.go index cc7ba3f1603..1974b35e863 100644 --- a/go/vt/wrangler/materializer.go +++ b/go/vt/wrangler/materializer.go @@ -1405,21 +1405,7 @@ func (mz *materializer) generateInserts(ctx context.Context, sourceShards []*top Name: sqlparser.NewIdentifierCI("in_keyrange"), Exprs: subExprs, } - if sel.Where != nil { - sel.Where = &sqlparser.Where{ - Type: sqlparser.WhereClause, - Expr: &sqlparser.AndExpr{ - Left: inKeyRange, - Right: sel.Where.Expr, - }, - } - } else { - sel.Where = &sqlparser.Where{ - Type: sqlparser.WhereClause, - Expr: inKeyRange, - } - } - + sel.AddWhere(inKeyRange) filter = sqlparser.String(sel) } diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index 4caad42ce1f..ca83e8c57fb 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -1439,16 +1439,13 @@ func removeKeyrange(where *sqlparser.Where) *sqlparser.Where { func removeExprKeyrange(node sqlparser.Expr) sqlparser.Expr { switch node := node.(type) { case *sqlparser.AndExpr: - if isFuncKeyrange(node.Left) { - return removeExprKeyrange(node.Right) - } - if isFuncKeyrange(node.Right) { - return removeExprKeyrange(node.Left) - } - return &sqlparser.AndExpr{ - Left: removeExprKeyrange(node.Left), - Right: removeExprKeyrange(node.Right), + var keep sqlparser.Exprs + for _, p := range node.Predicates { + if !isFuncKeyrange(p) { + keep = append(keep, removeExprKeyrange(p)) + } } + return 
sqlparser.CreateAndExpr(keep...) } return node } diff --git a/go/vt/wrangler/vexec_plan.go b/go/vt/wrangler/vexec_plan.go index 76b2d0fe732..1165add5d4b 100644 --- a/go/vt/wrangler/vexec_plan.go +++ b/go/vt/wrangler/vexec_plan.go @@ -158,7 +158,7 @@ func (vx *vexec) buildPlan(ctx context.Context) (plan *vexecPlan, err error) { case *sqlparser.Insert: plan, err = vx.buildInsertPlan(ctx, vx.planner, stmt) case *sqlparser.Select: - plan, err = vx.buildSelectPlan(ctx, vx.planner, stmt) + plan, err = vx.buildSelectPlan(vx.planner, stmt) default: return nil, fmt.Errorf("query not supported by vexec: %s", sqlparser.String(stmt)) } @@ -166,12 +166,12 @@ func (vx *vexec) buildPlan(ctx context.Context) (plan *vexecPlan, err error) { } // analyzeWhereEqualsColumns identifies column names in a WHERE clause that have a comparison expression -func (vx *vexec) analyzeWhereEqualsColumns(where *sqlparser.Where) []string { +func (vx *vexec) analyzeWhereEqualsColumns(expr sqlparser.Expr) []string { var cols []string - if where == nil { + if expr == nil { return cols } - exprs := sqlparser.SplitAndExpression(nil, where.Expr) + exprs := sqlparser.SplitAndExpression(nil, expr) for _, expr := range exprs { switch expr := expr.(type) { case *sqlparser.ComparisonExpr: @@ -185,8 +185,8 @@ func (vx *vexec) analyzeWhereEqualsColumns(where *sqlparser.Where) []string { } // addDefaultWheres modifies the query to add, if appropriate, the workflow and DB-name column modifiers -func (vx *vexec) addDefaultWheres(planner vexecPlanner, where *sqlparser.Where) *sqlparser.Where { - cols := vx.analyzeWhereEqualsColumns(where) +func (vx *vexec) addDefaultWheres(planner vexecPlanner, stmt sqlparser.WhereAble) { + cols := vx.analyzeWhereEqualsColumns(stmt.GetWherePredicate()) var hasDBName, hasWorkflow bool plannerParams := planner.params() for _, col := range cols { @@ -196,24 +196,13 @@ func (vx *vexec) addDefaultWheres(planner vexecPlanner, where *sqlparser.Where) hasWorkflow = true } } - newWhere := 
where if !hasDBName { expr := &sqlparser.ComparisonExpr{ Left: &sqlparser.ColName{Name: sqlparser.NewIdentifierCI(plannerParams.dbNameColumn)}, Operator: sqlparser.EqualOp, Right: sqlparser.NewStrLiteral(vx.primaries[0].DbName()), } - if newWhere == nil { - newWhere = &sqlparser.Where{ - Type: sqlparser.WhereClause, - Expr: expr, - } - } else { - newWhere.Expr = &sqlparser.AndExpr{ - Left: newWhere.Expr, - Right: expr, - } - } + stmt.AddWhere(expr) } if !hasWorkflow && vx.workflow != "" { expr := &sqlparser.ComparisonExpr{ @@ -221,12 +210,8 @@ func (vx *vexec) addDefaultWheres(planner vexecPlanner, where *sqlparser.Where) Operator: sqlparser.EqualOp, Right: sqlparser.NewStrLiteral(vx.workflow), } - newWhere.Expr = &sqlparser.AndExpr{ - Left: newWhere.Expr, - Right: expr, - } + stmt.AddWhere(expr) } - return newWhere } // buildUpdatePlan builds a plan for an UPDATE query @@ -262,7 +247,7 @@ func (vx *vexec) buildUpdatePlan(ctx context.Context, planner vexecPlanner, upd return nil, fmt.Errorf("Query must match one of these templates: %s", strings.Join(templates, "; ")) } } - upd.Where = vx.addDefaultWheres(planner, upd.Where) + vx.addDefaultWheres(planner, upd) buf := sqlparser.NewTrackedBuffer(nil) buf.Myprintf("%v", upd) @@ -285,7 +270,7 @@ func (vx *vexec) buildDeletePlan(ctx context.Context, planner vexecPlanner, del return nil, fmt.Errorf("unsupported construct: %v", sqlparser.String(del)) } - del.Where = vx.addDefaultWheres(planner, del.Where) + vx.addDefaultWheres(planner, del) buf := sqlparser.NewTrackedBuffer(nil) buf.Myprintf("%v", del) @@ -325,8 +310,8 @@ func (vx *vexec) buildInsertPlan(ctx context.Context, planner vexecPlanner, ins } // buildSelectPlan builds a plan for a SELECT query -func (vx *vexec) buildSelectPlan(ctx context.Context, planner vexecPlanner, sel *sqlparser.Select) (*vexecPlan, error) { - sel.Where = vx.addDefaultWheres(planner, sel.Where) +func (vx *vexec) buildSelectPlan(planner vexecPlanner, sel *sqlparser.Select) (*vexecPlan, error) 
{ + vx.addDefaultWheres(planner, sel) buf := sqlparser.NewTrackedBuffer(nil) buf.Myprintf("%v", sel) From aa788bdda14345f67c54410512b47d83913dff05 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 28 Aug 2024 07:58:15 +0200 Subject: [PATCH 04/14] refactoring Signed-off-by: Andres Taylor --- go/vt/vtctl/workflow/stream_migrator.go | 11 ++--- go/vt/vtctl/workflow/vexec/query_planner.go | 48 +++++-------------- .../vtgate/planbuilder/operators/ast_to_op.go | 4 +- go/vt/vtgate/planbuilder/operators/phases.go | 2 +- 4 files changed, 17 insertions(+), 48 deletions(-) diff --git a/go/vt/vtctl/workflow/stream_migrator.go b/go/vt/vtctl/workflow/stream_migrator.go index b294ba1fcd0..724413e6030 100644 --- a/go/vt/vtctl/workflow/stream_migrator.go +++ b/go/vt/vtctl/workflow/stream_migrator.go @@ -1130,7 +1130,7 @@ func (sm *StreamMigrator) templatizeRule(ctx context.Context, rule *binlogdatapb case rule.Filter == vreplication.ExcludeStr: return StreamTypeUnknown, fmt.Errorf("unexpected rule in vreplication: %v", rule) default: - if err := sm.templatizeKeyRange(ctx, rule); err != nil { + if err := sm.templatizeKeyRange(rule); err != nil { return StreamTypeUnknown, err } @@ -1138,7 +1138,7 @@ func (sm *StreamMigrator) templatizeRule(ctx context.Context, rule *binlogdatapb } } -func (sm *StreamMigrator) templatizeKeyRange(ctx context.Context, rule *binlogdatapb.Rule) error { +func (sm *StreamMigrator) templatizeKeyRange(rule *binlogdatapb.Rule) error { statement, err := sm.parser.Parse(rule.Filter) if err != nil { return err @@ -1149,12 +1149,7 @@ func (sm *StreamMigrator) templatizeKeyRange(ctx context.Context, rule *binlogda return fmt.Errorf("unexpected query: %v", rule.Filter) } - var expr sqlparser.Expr - if sel.Where != nil { - expr = sel.Where.Expr - } - - exprs := sqlparser.SplitAndExpression(nil, expr) + exprs := sqlparser.SplitAndExpression(nil, sel.GetWherePredicate()) for _, subexpr := range exprs { funcExpr, ok := subexpr.(*sqlparser.FuncExpr) if !ok || 
!funcExpr.Name.EqualString("in_keyrange") { diff --git a/go/vt/vtctl/workflow/vexec/query_planner.go b/go/vt/vtctl/workflow/vexec/query_planner.go index 9d16dc72f55..3d3541fafce 100644 --- a/go/vt/vtctl/workflow/vexec/query_planner.go +++ b/go/vt/vtctl/workflow/vexec/query_planner.go @@ -181,7 +181,7 @@ func (planner *VReplicationQueryPlanner) planDelete(del *sqlparser.Delete) (*Fix ) } - del.Where = addDefaultWheres(planner, del.Where) + addDefaultWheres(planner, del) buf := sqlparser.NewTrackedBuffer(nil) buf.Myprintf("%v", del) @@ -194,7 +194,7 @@ func (planner *VReplicationQueryPlanner) planDelete(del *sqlparser.Delete) (*Fix } func (planner *VReplicationQueryPlanner) planSelect(sel *sqlparser.Select) (*FixedQueryPlan, error) { - sel.Where = addDefaultWheres(planner, sel.Where) + addDefaultWheres(planner, sel) buf := sqlparser.NewTrackedBuffer(nil) buf.Myprintf("%v", sel) @@ -230,7 +230,7 @@ func (planner *VReplicationQueryPlanner) planUpdate(upd *sqlparser.Update) (*Fix } } - upd.Where = addDefaultWheres(planner, upd.Where) + addDefaultWheres(planner, upd) buf := sqlparser.NewTrackedBuffer(nil) buf.Myprintf("%v", upd) @@ -289,8 +289,7 @@ func (planner *VReplicationLogQueryPlanner) QueryParams() QueryParams { } func (planner *VReplicationLogQueryPlanner) planSelect(sel *sqlparser.Select) (QueryPlan, error) { - where := sel.Where - cols := extractWhereComparisonColumns(where) + cols := extractWhereComparisonColumns(sel.GetWherePredicate()) hasVReplIDCol := false for _, col := range cols { @@ -313,10 +312,6 @@ func (planner *VReplicationLogQueryPlanner) planSelect(sel *sqlparser.Select) (Q // streamIDs. 
queriesByTarget := make(map[string]*sqlparser.ParsedQuery, len(planner.tabletStreamIDs)) for target, streamIDs := range planner.tabletStreamIDs { - targetWhere := &sqlparser.Where{ - Type: sqlparser.WhereClause, - } - var expr sqlparser.Expr switch len(streamIDs) { case 0: // WHERE vreplication_log.vrepl_id IN () => WHERE 1 != 1 @@ -349,15 +344,7 @@ func (planner *VReplicationLogQueryPlanner) planSelect(sel *sqlparser.Select) (Q Right: tuple, } } - - switch where { - case nil: - targetWhere.Expr = expr - default: - targetWhere.Expr = sqlparser.CreateAndExpr(expr, where.Expr) - } - - sel.Where = targetWhere + sel.AddWhere(expr) buf := sqlparser.NewTrackedBuffer(nil) buf.Myprintf("%v", sel) @@ -371,8 +358,8 @@ func (planner *VReplicationLogQueryPlanner) planSelect(sel *sqlparser.Select) (Q }, nil } -func addDefaultWheres(planner QueryPlanner, where *sqlparser.Where) *sqlparser.Where { - cols := extractWhereComparisonColumns(where) +func addDefaultWheres(planner QueryPlanner, stmt sqlparser.WhereAble) { + cols := extractWhereComparisonColumns(stmt.GetWherePredicate()) params := planner.QueryParams() hasDBNameCol := false @@ -387,8 +374,6 @@ func addDefaultWheres(planner QueryPlanner, where *sqlparser.Where) *sqlparser.W } } - newWhere := where - if !hasDBNameCol { expr := &sqlparser.ComparisonExpr{ Left: &sqlparser.ColName{ @@ -398,15 +383,7 @@ func addDefaultWheres(planner QueryPlanner, where *sqlparser.Where) *sqlparser.W Right: sqlparser.NewStrLiteral(params.DBName), } - switch newWhere { - case nil: - newWhere = &sqlparser.Where{ - Type: sqlparser.WhereClause, - Expr: expr, - } - default: - newWhere.Expr = sqlparser.CreateAndExpr(newWhere.Expr, expr) - } + stmt.AddWhere(expr) } if !hasWorkflowCol && params.Workflow != "" { @@ -417,23 +394,20 @@ func addDefaultWheres(planner QueryPlanner, where *sqlparser.Where) *sqlparser.W Operator: sqlparser.EqualOp, Right: sqlparser.NewStrLiteral(params.Workflow), } - - newWhere.Expr = sqlparser.CreateAndExpr(newWhere.Expr, 
expr) + stmt.AddWhere(expr) } - - return newWhere } // extractWhereComparisonColumns extracts the column names used in AND-ed // comparison expressions in a where clause, given the following assumptions: // - (1) The column name is always the left-hand side of the comparison. // - (2) There are no compound expressions within the where clause involving OR. -func extractWhereComparisonColumns(where *sqlparser.Where) []string { +func extractWhereComparisonColumns(where sqlparser.Expr) []string { if where == nil { return nil } - exprs := sqlparser.SplitAndExpression(nil, where.Expr) + exprs := sqlparser.SplitAndExpression(nil, where) cols := make([]string, 0, len(exprs)) for _, expr := range exprs { diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 4c075f480d3..a9903edcc79 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -50,8 +50,8 @@ func translateQueryToOp(ctx *plancontext.PlanningContext, selStmt sqlparser.Stat func createOperatorFromSelect(ctx *plancontext.PlanningContext, sel *sqlparser.Select) Operator { op := crossJoin(ctx, sel.From) - if sel.Where != nil { - op = addWherePredicates(ctx, sel.Where.Expr, op) + if expr := sel.GetWherePredicate(); expr != nil { + op = addWherePredicates(ctx, expr, op) } if sel.Comments != nil || sel.Lock != sqlparser.NoLock { diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index d5354e9548f..cf126236a74 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -193,7 +193,7 @@ func createDMLWithInput(ctx *plancontext.PlanningContext, op, src Operator, in * if in.OwnedVindexQuery != nil { in.OwnedVindexQuery.From = sqlparser.TableExprs{targetQT.Alias} - in.OwnedVindexQuery.Where = sqlparser.NewWhere(sqlparser.WhereClause, compExpr) + in.OwnedVindexQuery.AddWhere(compExpr) 
in.OwnedVindexQuery.OrderBy = nil in.OwnedVindexQuery.Limit = nil } From f1fe39aa5b89db1a7bde7ae57874beee72c023cf Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 28 Aug 2024 08:06:06 +0200 Subject: [PATCH 05/14] codegen Signed-off-by: Andres Taylor --- go/vt/sqlparser/ast_funcs.go | 15 ++------------- go/vt/sqlparser/cached_size.go | 3 +-- go/vt/sqlparser/random_expr.go | 4 ++-- go/vt/sqlparser/sql.go | 2 +- go/vt/sqlparser/sql.y | 2 +- go/vt/wrangler/vdiff.go | 2 +- 6 files changed, 8 insertions(+), 20 deletions(-) diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 73247d9a322..5135ce9014f 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -66,7 +66,7 @@ func Append(buf *strings.Builder, node SQLNode) { node.FormatFast(tbuf) } -func createAndExpr(exprL, exprR Expr) *AndExpr { +func CreateAndExpr(exprL, exprR Expr) *AndExpr { leftAnd, isLeftAnd := exprL.(*AndExpr) rightAnd, isRightAnd := exprR.(*AndExpr) if isLeftAnd && isRightAnd { @@ -1261,7 +1261,7 @@ func addPredicate(where *Where, pred Expr) *Where { } } - where.Expr = createAndExpr(where.Expr, pred) + where.Expr = CreateAndExpr(where.Expr, pred) return where } @@ -2360,17 +2360,6 @@ func SplitAndExpression(filters []Expr, node Expr) []Expr { return append(filters, node) } -func CreateAndExpr(exprs ...Expr) Expr { - switch len(exprs) { - case 0: - return nil - case 1: - return exprs[0] - default: - return &AndExpr{Predicates: exprs} - } -} - // AndExpressions ands together two or more expressions, minimising the expr when possible func AndExpressions(exprs ...Expr) Expr { switch len(exprs) { diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index aa731d4799d..6f907cb738b 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -319,9 +319,8 @@ func (cached *AndExpr) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(64) + size += int64(24) } - // field Right 
vitess.io/vitess/go/vt/sqlparser.Expr // field Predicates vitess.io/vitess/go/vt/sqlparser.Exprs { size += hack.RuntimeAllocSize(int64(cap(cached.Predicates)) * int64(16)) diff --git a/go/vt/sqlparser/random_expr.go b/go/vt/sqlparser/random_expr.go index a6b7757c045..ff431687ebf 100644 --- a/go/vt/sqlparser/random_expr.go +++ b/go/vt/sqlparser/random_expr.go @@ -231,7 +231,7 @@ func (g *Generator) makeAggregateIfNecessary(genConfig ExprGeneratorConfig, expr // if the generated expression must be an aggregate, and it is not, // tack on an extra "and count(*)" to make it aggregate if genConfig.AggrRule == IsAggregate && !g.isAggregate && g.depth == 0 { - expr = createAndExpr(expr, &CountStar{}) + expr = CreateAndExpr(expr, &CountStar{}) g.isAggregate = true } @@ -474,7 +474,7 @@ func (g *Generator) randomOfS(options []string) string { func (g *Generator) andExpr(genConfig ExprGeneratorConfig) Expr { g.enter() defer g.exit() - return createAndExpr( + return CreateAndExpr( g.Expression(genConfig), g.Expression(genConfig), ) diff --git a/go/vt/sqlparser/sql.go b/go/vt/sqlparser/sql.go index 09068dd02f9..f74dbcc4167 100644 --- a/go/vt/sqlparser/sql.go +++ b/go/vt/sqlparser/sql.go @@ -17811,7 +17811,7 @@ yydefault: var yyLOCAL Expr //line sql.y:5308 { - yyLOCAL = createAndExpr(yyDollar[1].exprUnion(), yyDollar[3].exprUnion()) + yyLOCAL = CreateAndExpr(yyDollar[1].exprUnion(), yyDollar[3].exprUnion()) } yyVAL.union = yyLOCAL case 1013: diff --git a/go/vt/sqlparser/sql.y b/go/vt/sqlparser/sql.y index 880f00939cd..c1c1cf7500e 100644 --- a/go/vt/sqlparser/sql.y +++ b/go/vt/sqlparser/sql.y @@ -5306,7 +5306,7 @@ expression: } | expression AND expression %prec AND { - $$ = createAndExpr($1,$3) + $$ = CreateAndExpr($1, $3) } | NOT expression %prec NOT { diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index ca83e8c57fb..47ade05700c 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -1445,7 +1445,7 @@ func removeExprKeyrange(node sqlparser.Expr) 
sqlparser.Expr { keep = append(keep, removeExprKeyrange(p)) } } - return sqlparser.CreateAndExpr(keep...) + return sqlparser.AndExpressions(keep...) } return node } From bceda6fcb7b7c1112db15a5e63fee1583819d572 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 28 Aug 2024 08:49:02 +0200 Subject: [PATCH 06/14] move predicate simplification to later planner stages Signed-off-by: Andres Taylor --- .../vtgate/planbuilder/operators/ast_to_op.go | 47 +++++++++++++++-- .../planbuilder/testdata/filter_cases.json | 6 +-- .../testdata/foreignkey_cases.json | 8 +-- .../testdata/foreignkey_checks_on_cases.json | 4 +- .../planbuilder/testdata/select_cases.json | 2 +- go/vt/vtgate/semantics/early_rewriter.go | 50 ------------------- 6 files changed, 52 insertions(+), 65 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index a9903edcc79..5f83132ba48 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -19,6 +19,8 @@ package operators import ( "fmt" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" @@ -50,9 +52,7 @@ func translateQueryToOp(ctx *plancontext.PlanningContext, selStmt sqlparser.Stat func createOperatorFromSelect(ctx *plancontext.PlanningContext, sel *sqlparser.Select) Operator { op := crossJoin(ctx, sel.From) - if expr := sel.GetWherePredicate(); expr != nil { - op = addWherePredicates(ctx, expr, op) - } + op = addWherePredicates(ctx, sel.GetWherePredicate(), op) if sel.Comments != nil || sel.Lock != sqlparser.NoLock { op = &LockAndComment{ @@ -75,19 +75,56 @@ func addWherePredicates(ctx *plancontext.PlanningContext, expr sqlparser.Expr, o func addWherePredsToSubQueryBuilder(ctx *plancontext.PlanningContext, expr sqlparser.Expr, op Operator, sqc *SubQueryBuilder) Operator { outerID := TableID(op) - exprs 
:= sqlparser.SplitAndExpression(nil, expr) - for _, expr := range exprs { + for _, expr := range sqlparser.SplitAndExpression(nil, expr) { sqlparser.RemoveKeyspaceInCol(expr) subq := sqc.handleSubquery(ctx, expr, outerID) if subq != nil { continue } + b := constantPredicate(ctx, expr) + if b != nil { + if *b { + // If the predicate is true, we can ignore it. + continue + } + + // If the predicate is false, we push down a false predicate to influence routing + expr = sqlparser.NewIntLiteral("0") + } + op = op.AddPredicate(ctx, expr) addColumnEquality(ctx, expr) } return op } +// constantPredicate evaluates the given expression and returns the result if it is a constant, +// in other words - can it be evaluated without any table data. +func constantPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) *bool { + env := ctx.VSchema.Environment() + coll := env.CollationEnv().DefaultConnectionCharset() + evalEnginePred, err := evalengine.Translate(expr, &evalengine.Config{ + Environment: env, + Collation: coll, + }) + if err != nil { + return nil + } + + evalEnv := evalengine.EmptyExpressionEnv(env) + res, err := evalEnv.Evaluate(evalEnginePred) + if err != nil { + return nil + } + + boolValue, err := res.Value(coll).ToBool() + if err != nil { + return nil + } + + return &boolValue +} + // cloneASTAndSemState clones the AST and the semantic state of the input node. 
func cloneASTAndSemState[T sqlparser.SQLNode](ctx *plancontext.PlanningContext, original T) T { return sqlparser.CopyOnRewrite(original, nil, func(cursor *sqlparser.CopyOnWriteCursor) { diff --git a/go/vt/vtgate/planbuilder/testdata/filter_cases.json b/go/vt/vtgate/planbuilder/testdata/filter_cases.json index b60e8812dda..3dc379b9aae 100644 --- a/go/vt/vtgate/planbuilder/testdata/filter_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/filter_cases.json @@ -789,7 +789,7 @@ "Sharded": true }, "FieldQuery": "select Id from `user` where 1 != 1", - "Query": "select Id from `user` where 1 in ('aa', 'bb')", + "Query": "select Id from `user` where 0", "Table": "`user`" }, "TablesUsed": [ @@ -1251,7 +1251,7 @@ "Sharded": true }, "FieldQuery": "select `user`.col from `user` where 1 != 1", - "Query": "select `user`.col from `user` where 1 = 1", + "Query": "select `user`.col from `user`", "Table": "`user`" }, { @@ -1262,7 +1262,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.id from user_extra where 1 != 1", - "Query": "select user_extra.id from user_extra where user_extra.col = :user_col /* INT16 */ and 1 = 1", + "Query": "select user_extra.id from user_extra where user_extra.col = :user_col /* INT16 */", "Table": "user_extra" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json index 47f10cd273b..e8560cb04bf 100644 --- a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json @@ -1518,7 +1518,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where u_tbl9.col9 is null and cast('foo' as CHAR) is not null and not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals limit 1 for share nowait", + "Query": "select 1 from u_tbl8 left join u_tbl9 on 
u_tbl9.col9 = cast('foo' as CHAR) where u_tbl9.col9 is null and not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals limit 1 for share nowait", "Table": "u_tbl8, u_tbl9" }, { @@ -1594,7 +1594,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where u_tbl3.col3 is null and cast('foo' as CHAR) is not null and not (u_tbl4.col4) <=> (cast('foo' as CHAR)) and (u_tbl4.col4) in ::fkc_vals limit 1 for share", + "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where u_tbl3.col3 is null and not (u_tbl4.col4) <=> (cast('foo' as CHAR)) and (u_tbl4.col4) in ::fkc_vals limit 1 for share", "Table": "u_tbl3, u_tbl4" }, { @@ -2532,7 +2532,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_multicol_tbl2 left join u_multicol_tbl1 on u_multicol_tbl1.cola = 2 and u_multicol_tbl1.colb = u_multicol_tbl2.colc - 2 where 1 != 1", - "Query": "select 1 from u_multicol_tbl2 left join u_multicol_tbl1 on u_multicol_tbl1.cola = 2 and u_multicol_tbl1.colb = u_multicol_tbl2.colc - 2 where u_multicol_tbl1.cola is null and 2 is not null and u_multicol_tbl1.colb is null and u_multicol_tbl2.colc - 2 is not null and not (u_multicol_tbl2.cola, u_multicol_tbl2.colb) <=> (2, u_multicol_tbl2.colc - 2) and u_multicol_tbl2.id = 7 limit 1 for share", + "Query": "select 1 from u_multicol_tbl2 left join u_multicol_tbl1 on u_multicol_tbl1.cola = 2 and u_multicol_tbl1.colb = u_multicol_tbl2.colc - 2 where u_multicol_tbl1.cola is null and u_multicol_tbl1.colb is null and u_multicol_tbl2.colc - 2 is not null and not (u_multicol_tbl2.cola, u_multicol_tbl2.colb) <=> (2, u_multicol_tbl2.colc - 2) and u_multicol_tbl2.id = 7 limit 1 for share", "Table": "u_multicol_tbl1, u_multicol_tbl2" }, { @@ -4110,7 +4110,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl8 left join u_tbl9 on 
u_tbl9.col9 = cast('foo' as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where u_tbl9.col9 is null and cast('foo' as CHAR) is not null and not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals limit 1 for share nowait", + "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where u_tbl9.col9 is null and not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals limit 1 for share nowait", "Table": "u_tbl8, u_tbl9" }, { diff --git a/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json b/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json index 7b525b2dcc9..6f2e3d26f66 100644 --- a/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json @@ -1595,7 +1595,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where u_tbl9.col9 is null and cast('foo' as CHAR) is not null and not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals limit 1 for share nowait", + "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where u_tbl9.col9 is null and not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals limit 1 for share nowait", "Table": "u_tbl8, u_tbl9" }, { @@ -1671,7 +1671,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where u_tbl3.col3 is null and cast('foo' as CHAR) is not null and not (u_tbl4.col4) <=> (cast('foo' as CHAR)) and (u_tbl4.col4) in ::fkc_vals limit 1 for share", + "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 
= cast('foo' as CHAR) where u_tbl3.col3 is null and not (u_tbl4.col4) <=> (cast('foo' as CHAR)) and (u_tbl4.col4) in ::fkc_vals limit 1 for share", "Table": "u_tbl3, u_tbl4" }, { diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index f06a6a50d45..856e56265ca 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -1874,7 +1874,7 @@ "Sharded": false }, "FieldQuery": "select 42 from dual where 1 != 1", - "Query": "select 42 from dual where false", + "Query": "select 42 from dual where 0", "Table": "dual" }, "TablesUsed": [ diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index f38259735c5..f58d1cbd32c 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -50,8 +50,6 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { return r.handleSelectExprs(cursor, node) case *sqlparser.OrExpr: rewriteOrExpr(r.env, cursor, node) - case *sqlparser.AndExpr: - rewriteAndExpr(r.env, cursor, node) case *sqlparser.NotExpr: rewriteNotExpr(cursor, node) case *sqlparser.ComparisonExpr: @@ -862,54 +860,6 @@ func rewriteOrExpr(env *vtenv.Environment, cursor *sqlparser.Cursor, node *sqlpa } } -// rewriteAndExpr rewrites AND expressions when either side is TRUE. 
-func rewriteAndExpr(env *vtenv.Environment, cursor *sqlparser.Cursor, node *sqlparser.AndExpr) { - newNode := rewriteAndTrue(env, node) - if newNode != nil { - cursor.ReplaceAndRevisit(newNode) - } -} - -func rewriteAndTrue(env *vtenv.Environment, andExpr *sqlparser.AndExpr) sqlparser.Expr { - // we are looking for the pattern `WHERE c = 1 AND 1 = 1` - isTrue := func(subExpr sqlparser.Expr) bool { - coll := env.CollationEnv().DefaultConnectionCharset() - evalEnginePred, err := evalengine.Translate(subExpr, &evalengine.Config{ - Environment: env, - Collation: coll, - }) - if err != nil { - return false - } - - env := evalengine.EmptyExpressionEnv(env) - res, err := env.Evaluate(evalEnginePred) - if err != nil { - return false - } - - boolValue, err := res.Value(coll).ToBool() - if err != nil { - return false - } - - return boolValue - } - - var remaining sqlparser.Exprs - for _, p := range andExpr.Predicates { - if !isTrue(p) { - remaining = append(remaining, p) - } - } - - if len(remaining) == len(andExpr.Predicates) { - return nil - } - - return sqlparser.AndExpressions(remaining...) -} - // handleComparisonExpr processes Comparison expressions, specifically for tuples with equal length and EqualOp operator. 
func handleComparisonExpr(cursor *sqlparser.Cursor, node *sqlparser.ComparisonExpr) error { lft, lftOK := node.Left.(sqlparser.ValTuple) From 5d1776c8311a9d120eef0950cb06421ef031e1e1 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 28 Aug 2024 08:53:50 +0200 Subject: [PATCH 07/14] test: update expectations Signed-off-by: Andres Taylor --- go/vt/sqlparser/predicate_rewriting_test.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/go/vt/sqlparser/predicate_rewriting_test.go b/go/vt/sqlparser/predicate_rewriting_test.go index ef7799611bc..96cf0e54095 100644 --- a/go/vt/sqlparser/predicate_rewriting_test.go +++ b/go/vt/sqlparser/predicate_rewriting_test.go @@ -44,7 +44,7 @@ func TestSimplifyExpression(in *testing.T) { expected: "(A or C) and (B or C)", }, { in: "C or (A and B)", - expected: "(A or C) and (B or C)", + expected: "(C or A) and (C or B)", }, { in: "A and A", expected: "A", @@ -136,26 +136,26 @@ func TestRewritePredicate(in *testing.T) { }, { in: "(a = 1 and b = 41) or (a = 2 and b = 42)", // this might look weird, but it allows the planner to either a or b in a vindex operation - expected: "a in (2, 1) and (b = 42 or a = 1) and (a = 2 or b = 41) and b in (42, 41)", + expected: "a in (1, 2) and (a = 1 or b = 42) and (b = 41 or a = 2) and b in (41, 42)", }, { in: "(a = 1 and b = 41) or (a = 2 and b = 42) or (a = 3 and b = 43)", - expected: "a in (3, 2, 1) and (b = 43 or a in (2, 1)) and (a = 3 or (b = 42 or a = 1)) and (b = 43 or (b = 42 or a = 1)) and (a = 3 or (a = 2 or b = 41)) and (b = 43 or (a = 2 or b = 41)) and (a = 3 or b in (42, 41)) and b in (43, 42, 41)", + expected: "a in (1, 2, 3) and (a in (1, 2) or b = 43) and (a = 1 or b = 42 or a = 3) and (a = 1 or b = 42 or b = 43) and (b = 41 or a = 2 or a = 3) and (b = 41 or a = 2 or b = 43) and (b in (41, 42) or a = 3) and b in (41, 42, 43)", }, { // the following two tests show some pathological cases that would grow too much, and so we abort the rewriting in: "a 
= 1 and b = 41 or a = 2 and b = 42 or a = 3 and b = 43 or a = 4 and b = 44 or a = 5 and b = 45 or a = 6 and b = 46", expected: "a = 1 and b = 41 or a = 2 and b = 42 or a = 3 and b = 43 or a = 4 and b = 44 or a = 5 and b = 45 or a = 6 and b = 46", }, { in: "a = 5 and B or a = 6 and C", - expected: "a in (6, 5) and (C or a = 5) and (a = 6 or B) and (C or B)", + expected: "a in (5, 6) and (a = 5 or C) and (B or a = 6) and (B or C)", }, { in: "(a = 5 and b = 1 or b = 2 and a = 6)", - expected: "(b = 2 or a = 5) and a in (6, 5) and b in (2, 1) and (a = 6 or b = 1)", + expected: "(a = 5 or b = 2) and a in (5, 6) and b in (1, 2) and (b = 1 or a = 6)", }, { in: "(a in (1,5) and B or C and a = 6)", - expected: "(C or a in (1, 5)) and a in (6, 1, 5) and (C or B) and (a = 6 or B)", + expected: "(a in (1, 5) or C) and a in (1, 5, 6) and (B or C) and (B or a = 6)", }, { in: "(a in (1, 5) and B or C and a in (5, 7))", - expected: "(C or a in (1, 5)) and a in (5, 7, 1) and (C or B) and (a in (5, 7) or B)", + expected: "(a in (1, 5) or C) and a in (1, 5, 7) and (B or C) and (B or a in (5, 7))", }, { in: "not n0 xor not (n2 and n3) xor (not n2 and (n1 xor n1) xor (n0 xor n0 xor n2))", expected: "not n0 xor not (n2 and n3) xor (not n2 and (n1 xor n1) xor (n0 xor n0 xor n2))", From 07cbbe5226f1b8336c13f28915205756d0504277 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 28 Aug 2024 09:22:29 +0200 Subject: [PATCH 08/14] refactor: move OR simplifying to later planner stage Signed-off-by: Andres Taylor --- .../vtgate/planbuilder/operators/ast_to_op.go | 45 ++++++++++++++---- .../testdata/foreignkey_cases.json | 2 +- .../testdata/foreignkey_checks_on_cases.json | 2 +- go/vt/vtgate/semantics/early_rewriter.go | 46 ------------------- 4 files changed, 37 insertions(+), 58 deletions(-) diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 5f83132ba48..e7b180d0dd2 100644 --- 
a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -73,34 +73,59 @@ func addWherePredicates(ctx *plancontext.PlanningContext, expr sqlparser.Expr, o return sqc.getRootOperator(op, nil) } -func addWherePredsToSubQueryBuilder(ctx *plancontext.PlanningContext, expr sqlparser.Expr, op Operator, sqc *SubQueryBuilder) Operator { +func addWherePredsToSubQueryBuilder(ctx *plancontext.PlanningContext, in sqlparser.Expr, op Operator, sqc *SubQueryBuilder) Operator { outerID := TableID(op) - for _, expr := range sqlparser.SplitAndExpression(nil, expr) { + for _, expr := range sqlparser.SplitAndExpression(nil, in) { sqlparser.RemoveKeyspaceInCol(expr) subq := sqc.handleSubquery(ctx, expr, outerID) if subq != nil { continue } - b := constantPredicate(ctx, expr) - if b != nil { - if *b { + constant, simplified := simplifyPredicate(ctx, expr) + if constant != nil { + if *constant { // If the predicate is true, we can ignore it. continue } // If the predicate is false, we push down a false predicate to influence routing - expr = sqlparser.NewIntLiteral("0") + simplified = sqlparser.NewIntLiteral("0") } - op = op.AddPredicate(ctx, expr) - addColumnEquality(ctx, expr) + op = op.AddPredicate(ctx, simplified) + addColumnEquality(ctx, simplified) } return op } -// constantPredicate evaluates the given expression and returns the result if it is a constant, +func simplifyPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) (*bool, sqlparser.Expr) { + switch expr := expr.(type) { + case *sqlparser.OrExpr: + if lhs := isConstantValue(ctx, expr.Left); lhs != nil { + if *lhs { + // if the LHS of an OR is true, we can ignore the RHS + return lhs, expr.Left + } + // if the LHS of an OR is false, we can simplify the OR to just the RHS + return simplifyPredicate(ctx, expr.Right) + } + if rhs := isConstantValue(ctx, expr.Right); rhs != nil { + if *rhs { + // if the RHS of an OR is true, we can ignore the LHS + return rhs, 
expr.Right + } + // if the RHS of an OR is false, we can simplify the OR to just the LHS + return simplifyPredicate(ctx, expr.Left) + } + return nil, expr + default: + return isConstantValue(ctx, expr), expr + } +} + +// isConstantValue evaluates the given expression and returns the result if it is a constant, // in other words - can it be evaluated without any table data. -func constantPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) *bool { +func isConstantValue(ctx *plancontext.PlanningContext, expr sqlparser.Expr) *bool { env := ctx.VSchema.Environment() coll := env.CollationEnv().DefaultConnectionCharset() evalEnginePred, err := evalengine.Translate(expr, &evalengine.Config{ diff --git a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json index e8560cb04bf..799c9bd4420 100644 --- a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json @@ -1606,7 +1606,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4, u_tbl9 where 1 != 1", - "Query": "select 1 from u_tbl4, u_tbl9 where u_tbl4.col4 = u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and (cast('foo' as CHAR) is null or (u_tbl9.col9) not in ((cast('foo' as CHAR)))) limit 1 for share", + "Query": "select 1 from u_tbl4, u_tbl9 where u_tbl4.col4 = u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and (u_tbl9.col9) not in ((cast('foo' as CHAR))) limit 1 for share", "Table": "u_tbl4, u_tbl9" }, { diff --git a/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json b/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json index 6f2e3d26f66..5464ccbd619 100644 --- a/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json @@ -1683,7 +1683,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4, u_tbl9 where 1 != 1", - "Query": "select 1 from u_tbl4, u_tbl9 where u_tbl4.col4 = 
u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and (cast('foo' as CHAR) is null or (u_tbl9.col9) not in ((cast('foo' as CHAR)))) limit 1 for share", + "Query": "select 1 from u_tbl4, u_tbl9 where u_tbl4.col4 = u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and (u_tbl9.col9) not in ((cast('foo' as CHAR))) limit 1 for share", "Table": "u_tbl4, u_tbl9" }, { diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index f58d1cbd32c..3e53ed0816a 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -24,7 +24,6 @@ import ( "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vtgate/evalengine" ) type earlyRewriter struct { @@ -48,8 +47,6 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { switch node := cursor.Node().(type) { case sqlparser.SelectExprs: return r.handleSelectExprs(cursor, node) - case *sqlparser.OrExpr: - rewriteOrExpr(r.env, cursor, node) case *sqlparser.NotExpr: rewriteNotExpr(cursor, node) case *sqlparser.ComparisonExpr: @@ -852,14 +849,6 @@ func (r *earlyRewriter) rewriteGroupByExpr(node *sqlparser.Literal) (sqlparser.E return realCloneOfColNames(aliasedExpr.Expr, false), nil } -// rewriteOrExpr rewrites OR expressions when the right side is FALSE. -func rewriteOrExpr(env *vtenv.Environment, cursor *sqlparser.Cursor, node *sqlparser.OrExpr) { - newNode := rewriteOrFalse(env, *node) - if newNode != nil { - cursor.ReplaceAndRevisit(newNode) - } -} - // handleComparisonExpr processes Comparison expressions, specifically for tuples with equal length and EqualOp operator. 
func handleComparisonExpr(cursor *sqlparser.Cursor, node *sqlparser.ComparisonExpr) error { lft, lftOK := node.Left.(sqlparser.ValTuple) @@ -925,41 +914,6 @@ func realCloneOfColNames(expr sqlparser.Expr, union bool) sqlparser.Expr { }, nil).(sqlparser.Expr) } -func rewriteOrFalse(env *vtenv.Environment, orExpr sqlparser.OrExpr) sqlparser.Expr { - // we are looking for the pattern `WHERE c = 1 OR 1 = 0` - isFalse := func(subExpr sqlparser.Expr) bool { - coll := env.CollationEnv().DefaultConnectionCharset() - evalEnginePred, err := evalengine.Translate(subExpr, &evalengine.Config{ - Environment: env, - Collation: coll, - }) - if err != nil { - return false - } - - env := evalengine.EmptyExpressionEnv(env) - res, err := env.Evaluate(evalEnginePred) - if err != nil { - return false - } - - boolValue, err := res.Value(coll).ToBool() - if err != nil { - return false - } - - return !boolValue - } - - if isFalse(orExpr.Left) { - return orExpr.Right - } else if isFalse(orExpr.Right) { - return orExpr.Left - } - - return nil -} - // rewriteJoinUsing rewrites SQL JOINs that use the USING clause to their equivalent // JOINs with the ON condition. 
This function finds all the tables that have the // specified columns in the USING clause, constructs an equality predicate for From fc653a1591ebe8ab81405d30c5ed78e43483d5c5 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 28 Aug 2024 11:33:51 +0200 Subject: [PATCH 09/14] refactor: minor code tweaks Signed-off-by: Andres Taylor --- go/vt/vtctl/workflow/vexec/query_plan.go | 12 +- go/vt/vtctl/workflow/vexec/query_planner.go | 5 +- .../workflow/vexec/query_planner_test.go | 207 +++++++++--------- 3 files changed, 106 insertions(+), 118 deletions(-) diff --git a/go/vt/vtctl/workflow/vexec/query_plan.go b/go/vt/vtctl/workflow/vexec/query_plan.go index 52e7ee00b61..bbed2277af9 100644 --- a/go/vt/vtctl/workflow/vexec/query_plan.go +++ b/go/vt/vtctl/workflow/vexec/query_plan.go @@ -52,22 +52,16 @@ type FixedQueryPlan struct { } // Execute is part of the QueryPlan interface. -func (qp *FixedQueryPlan) Execute(ctx context.Context, target *topo.TabletInfo) (qr *querypb.QueryResult, err error) { +func (qp *FixedQueryPlan) Execute(ctx context.Context, target *topo.TabletInfo) (*querypb.QueryResult, error) { if qp.ParsedQuery == nil { return nil, fmt.Errorf("%w: call PlanQuery on a query planner first", ErrUnpreparedQuery) } targetAliasStr := target.AliasString() - defer func() { - if err != nil { - log.Warningf("Result on %v: %v", targetAliasStr, err) - return - } - }() - - qr, err = qp.tmc.VReplicationExec(ctx, target.Tablet, qp.ParsedQuery.Query) + qr, err := qp.tmc.VReplicationExec(ctx, target.Tablet, qp.ParsedQuery.Query) if err != nil { + log.Warningf("Result on %v: %v", targetAliasStr, err) return nil, err } return qr, nil diff --git a/go/vt/vtctl/workflow/vexec/query_planner.go b/go/vt/vtctl/workflow/vexec/query_planner.go index 3d3541fafce..b53407bf223 100644 --- a/go/vt/vtctl/workflow/vexec/query_planner.go +++ b/go/vt/vtctl/workflow/vexec/query_planner.go @@ -330,12 +330,11 @@ func (planner *VReplicationLogQueryPlanner) planSelect(sel *sqlparser.Select) (Q 
Right: sqlparser.NewIntLiteral(fmt.Sprintf("%d", streamIDs[0])), } default: // WHERE vreplication_log.vrepl_id IN (?) - vals := []sqlparser.Expr{} + var tuple sqlparser.ValTuple for _, streamID := range streamIDs { - vals = append(vals, sqlparser.NewIntLiteral(fmt.Sprintf("%d", streamID))) + tuple = append(tuple, sqlparser.NewIntLiteral(fmt.Sprintf("%d", streamID))) } - var tuple sqlparser.ValTuple = vals expr = &sqlparser.ComparisonExpr{ Operator: sqlparser.InOp, Left: &sqlparser.ColName{ diff --git a/go/vt/vtctl/workflow/vexec/query_planner_test.go b/go/vt/vtctl/workflow/vexec/query_planner_test.go index ec162ebc4c7..cc55a41688e 100644 --- a/go/vt/vtctl/workflow/vexec/query_planner_test.go +++ b/go/vt/vtctl/workflow/vexec/query_planner_test.go @@ -246,119 +246,114 @@ func TestVReplicationQueryPlanner_planDelete(t *testing.T) { func TestVReplicationLogQueryPlanner(t *testing.T) { t.Parallel() - t.Run("planSelect", func(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - targetStreamIDs map[string][]int64 - query string - assertion func(t *testing.T, plan QueryPlan) - shouldErr bool - }{ - { - targetStreamIDs: map[string][]int64{ - "a": {1, 2}, - }, - query: "select * from _vt.vreplication_log", - assertion: func(t *testing.T, plan QueryPlan) { - t.Helper() - qp, ok := plan.(*PerTargetQueryPlan) - if !ok { - require.FailNow(t, "failed type check", "expected plan to be PerTargetQueryPlan, got %T: %v", plan, plan) - } - - expected := map[string]string{ - "a": "select * from _vt.vreplication_log where vrepl_id in (1, 2)", - } - assertQueryMapsMatch(t, expected, qp.ParsedQueries) - }, + tests := []struct { + targetStreamIDs map[string][]int64 + query string + assertion func(t *testing.T, plan QueryPlan) + shouldErr bool + }{ + { + targetStreamIDs: map[string][]int64{ + "a": {1, 2}, + }, + query: "select * from _vt.vreplication_log", + assertion: func(t *testing.T, plan QueryPlan) { + t.Helper() + qp, ok := plan.(*PerTargetQueryPlan) + if !ok { + 
require.FailNow(t, "failed type check", "expected plan to be PerTargetQueryPlan, got %T: %v", plan, plan) + } + + expected := map[string]string{ + "a": "select * from _vt.vreplication_log where vrepl_id in (1, 2)", + } + assertQueryMapsMatch(t, expected, qp.ParsedQueries) }, - { - targetStreamIDs: map[string][]int64{ - "a": nil, - }, - query: "select * from _vt.vreplication_log", - assertion: func(t *testing.T, plan QueryPlan) { - t.Helper() - qp, ok := plan.(*PerTargetQueryPlan) - if !ok { - require.FailNow(t, "failed type check", "expected plan to be PerTargetQueryPlan, got %T: %v", plan, plan) - } - - expected := map[string]string{ - "a": "select * from _vt.vreplication_log where 1 != 1", - } - assertQueryMapsMatch(t, expected, qp.ParsedQueries) - }, + }, + { + targetStreamIDs: map[string][]int64{ + "a": nil, }, - { - targetStreamIDs: map[string][]int64{ - "a": {1}, - }, - query: "select * from _vt.vreplication_log", - assertion: func(t *testing.T, plan QueryPlan) { - t.Helper() - qp, ok := plan.(*PerTargetQueryPlan) - if !ok { - require.FailNow(t, "failed type check", "expected plan to be PerTargetQueryPlan, got %T: %v", plan, plan) - } - - expected := map[string]string{ - "a": "select * from _vt.vreplication_log where vrepl_id = 1", - } - assertQueryMapsMatch(t, expected, qp.ParsedQueries) - }, + query: "select * from _vt.vreplication_log", + assertion: func(t *testing.T, plan QueryPlan) { + t.Helper() + qp, ok := plan.(*PerTargetQueryPlan) + if !ok { + require.FailNow(t, "failed type check", "expected plan to be PerTargetQueryPlan, got %T: %v", plan, plan) + } + + expected := map[string]string{ + "a": "select * from _vt.vreplication_log where 1 != 1", + } + assertQueryMapsMatch(t, expected, qp.ParsedQueries) + }, + }, + { + targetStreamIDs: map[string][]int64{ + "a": {1}, + }, + query: "select * from _vt.vreplication_log", + assertion: func(t *testing.T, plan QueryPlan) { + t.Helper() + qp, ok := plan.(*PerTargetQueryPlan) + if !ok { + require.FailNow(t, 
"failed type check", "expected plan to be PerTargetQueryPlan, got %T: %v", plan, plan) + } + + expected := map[string]string{ + "a": "select * from _vt.vreplication_log where vrepl_id = 1", + } + assertQueryMapsMatch(t, expected, qp.ParsedQueries) }, - { - query: "select * from _vt.vreplication_log where vrepl_id = 1", - assertion: func(t *testing.T, plan QueryPlan) { - t.Helper() - qp, ok := plan.(*FixedQueryPlan) - if !ok { - require.FailNow(t, "failed type check", "expected plan to be FixedQueryPlan, got %T: %v", plan, plan) - } - - assert.Equal(t, "select * from _vt.vreplication_log where vrepl_id = 1", qp.ParsedQuery.Query) - }, + }, + { + query: "select * from _vt.vreplication_log where vrepl_id = 1", + assertion: func(t *testing.T, plan QueryPlan) { + t.Helper() + qp, ok := plan.(*FixedQueryPlan) + if !ok { + require.FailNow(t, "failed type check", "expected plan to be FixedQueryPlan, got %T: %v", plan, plan) + } + + assert.Equal(t, "select * from _vt.vreplication_log where vrepl_id = 1", qp.ParsedQuery.Query) }, - { - targetStreamIDs: map[string][]int64{ - "a": {1, 2}, - }, - query: "select * from _vt.vreplication_log where foo = 'bar'", - assertion: func(t *testing.T, plan QueryPlan) { - t.Helper() - qp, ok := plan.(*PerTargetQueryPlan) - if !ok { - require.FailNow(t, "failed type check", "expected plan to be PerTargetQueryPlan, got %T: %v", plan, plan) - } - - expected := map[string]string{ - "a": "select * from _vt.vreplication_log where vrepl_id in (1, 2) and foo = 'bar'", - } - assertQueryMapsMatch(t, expected, qp.ParsedQueries) - }, + }, + { + targetStreamIDs: map[string][]int64{ + "a": {1, 2}, }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - planner := NewVReplicationLogQueryPlanner(nil, tt.targetStreamIDs) - stmt, err := sqlparser.NewTestParser().Parse(tt.query) - require.NoError(t, err, "could not parse query %q", tt.query) - qp, err := planner.planSelect(stmt.(*sqlparser.Select)) - if tt.shouldErr { - 
assert.Error(t, err) - return + query: "select * from _vt.vreplication_log where foo = 'bar'", + assertion: func(t *testing.T, plan QueryPlan) { + t.Helper() + qp, ok := plan.(*PerTargetQueryPlan) + if !ok { + require.FailNow(t, "failed type check", "expected plan to be PerTargetQueryPlan, got %T: %v", plan, plan) } - tt.assertion(t, qp) - }) - } - }) + expected := map[string]string{ + "a": "select * from _vt.vreplication_log where foo = 'bar' and vrepl_id in (1, 2)", + } + assertQueryMapsMatch(t, expected, qp.ParsedQueries) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.query, func(t *testing.T) { + t.Parallel() + + planner := NewVReplicationLogQueryPlanner(nil, tt.targetStreamIDs) + stmt, err := sqlparser.NewTestParser().Parse(tt.query) + require.NoError(t, err, "could not parse query %q", tt.query) + qp, err := planner.planSelect(stmt.(*sqlparser.Select)) + if tt.shouldErr { + assert.Error(t, err) + return + } + + tt.assertion(t, qp) + }) + } } func assertQueryMapsMatch(t *testing.T, expected map[string]string, actual map[string]*sqlparser.ParsedQuery, msgAndArgs ...any) { From d8c214ba340b283423cdc55f44c3907e2df241bd Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 29 Aug 2024 10:37:53 +0200 Subject: [PATCH 10/14] fix some of the faulty rewrites Signed-off-by: Andres Taylor --- go/vt/sqlparser/ast_funcs.go | 3 + go/vt/sqlparser/predicate_rewriting.go | 43 +++++------ .../planbuilder/predicate_rewrite_test.go | 71 +++++++++++++++---- 3 files changed, 82 insertions(+), 35 deletions(-) diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 5135ce9014f..c94ccde2894 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -2391,6 +2391,9 @@ func AndExpressions(exprs ...Expr) Expr { uniqueAdd(expr) } } + if len(unique) == 1 { + return unique[0] + } return &AndExpr{Predicates: unique} } } diff --git a/go/vt/sqlparser/predicate_rewriting.go b/go/vt/sqlparser/predicate_rewriting.go index 
a6da91c8c7f..f29f640c5e6 100644 --- a/go/vt/sqlparser/predicate_rewriting.go +++ b/go/vt/sqlparser/predicate_rewriting.go @@ -90,6 +90,13 @@ func simplifyNot(expr *NotExpr) (Expr, bool) { return expr, false } +func createOrs(exprs ...Expr) Expr { + if len(exprs) == 1 { + return exprs[0] + } + return &OrExpr{Left: exprs[0], Right: createOrs(exprs[1:]...)} +} + func simplifyOr(or *OrExpr) (Expr, bool) { res, rewritten := distinctOr(or) if rewritten { @@ -100,7 +107,7 @@ func simplifyOr(or *OrExpr) (Expr, bool) { rand, rok := or.Right.(*AndExpr) if lok && rok { - // (A AND B) OR (A AND C) => A OR (B AND C) + // (A AND B AND D) OR (A AND C AND D) => (A AND D) AND (B OR C) var commonPredicates []Expr var leftRemainder, rightRemainder []Expr @@ -125,23 +132,14 @@ func simplifyOr(or *OrExpr) (Expr, bool) { if len(commonPredicates) > 0 { // Build the final AndExpr with common predicates and the OrExpr of remainders - var notCommon Expr - switch { - case len(leftRemainder) == 0 && len(rightRemainder) == 0: - // all expressions were common - return AndExpressions(commonPredicates...), true - case len(leftRemainder) == 0: - notCommon = AndExpressions(rightRemainder...) - case len(rightRemainder) == 0: - notCommon = AndExpressions(leftRemainder...) - default: - notCommon = &OrExpr{ - Left: AndExpressions(leftRemainder...), - Right: AndExpressions(rightRemainder...), - } + nonCommonPredicates := append(leftRemainder, rightRemainder...) + commonPred := AndExpressions(commonPredicates...) 
+ if len(nonCommonPredicates) == 0 { + return commonPred, true } - return AndExpressions(append(commonPredicates, notCommon)...), true + return AndExpressions(commonPred, createOrs(nonCommonPredicates...)), true } + return or, false } if !lok && !rok { lftCmp, lok := or.Left.(*ComparisonExpr) @@ -201,6 +199,10 @@ func simplifyXor(xor *XorExpr) (Expr, bool) { } func simplifyAnd(expr *AndExpr) (Expr, bool) { + if len(expr.Predicates) == 1 { + return expr.Predicates[0], true + } + res, rewritten := distinctAnd(expr) if rewritten { return res, true @@ -210,6 +212,7 @@ func simplifyAnd(expr *AndExpr) (Expr, bool) { simplified := false // Loop over all predicates in the AndExpr +outer: for i, andPred := range expr.Predicates { if or, ok := andPred.(*OrExpr); ok { // Check if we can simplify by matching with another predicate in the AndExpr @@ -223,13 +226,13 @@ func simplifyAnd(expr *AndExpr) (Expr, bool) { // Found a match, keep the simpler expression (otherPred) simplifiedPredicates = append(simplifiedPredicates, otherPred) simplified = true - break + continue outer } } - } else { - // No simplification possible, keep the original predicate - simplifiedPredicates = append(simplifiedPredicates, andPred) } + + // No simplification possible, keep the original predicate + simplifiedPredicates = append(simplifiedPredicates, andPred) } if simplified { diff --git a/go/vt/vtgate/planbuilder/predicate_rewrite_test.go b/go/vt/vtgate/planbuilder/predicate_rewrite_test.go index 2f262f75a7f..9c1e25624ce 100644 --- a/go/vt/vtgate/planbuilder/predicate_rewrite_test.go +++ b/go/vt/vtgate/planbuilder/predicate_rewrite_test.go @@ -20,10 +20,12 @@ import ( "fmt" "math/rand/v2" "strconv" + "strings" "testing" "time" - "github.com/stretchr/testify/assert" + "vitess.io/vitess/go/slice" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/collations" @@ -82,6 +84,37 @@ func (tc testCase) createPredicate(lvl int) sqlparser.Expr { panic("unexpected nodeType") } +func 
TestOneRewriting(t *testing.T) { + venv := vtenv.NewTestEnv() + + // Modify these + const numberOfColumns = 2 + const expr = "n1 and n0 or n1 xor n1" + + predicate, err := sqlparser.NewTestParser().ParseExpr(expr) + require.NoError(t, err) + + simplified := sqlparser.RewritePredicate(predicate) + + cfg := &evalengine.Config{ + Environment: venv, + Collation: collations.MySQL8().DefaultConnectionCharset(), + ResolveColumn: resolveForFuzz, + } + original, err := evalengine.Translate(predicate, cfg) + require.NoError(t, err) + simpler, err := evalengine.Translate(simplified.(sqlparser.Expr), cfg) + require.NoError(t, err) + + env := evalengine.EmptyExpressionEnv(venv) + env.Row = make([]sqltypes.Value, numberOfColumns) + for i := range env.Row { + env.Row[i] = sqltypes.NULL + } + + testValues(t, env, 0, original, simpler) +} + func TestFuzzRewriting(t *testing.T) { // This test, that runs for one second only, will produce lots of random boolean expressions, // mixing AND, NOT, OR, XOR and column expressions. @@ -89,31 +122,29 @@ func TestFuzzRewriting(t *testing.T) { // Finally, it runs both the original and simplified predicate with all combinations of column // values - trying TRUE, FALSE and NULL. If the two expressions do not return the same value, // this is considered a test failure. 
- - venv := vtenv.NewTestEnv() start := time.Now() for time.Since(start) < 1*time.Second { tc := testCase{ - nodes: rand.IntN(4) + 1, + nodes: 2, depth: rand.IntN(4) + 1, } predicate := tc.createPredicate(0) name := sqlparser.String(predicate) t.Run(name, func(t *testing.T) { + venv := vtenv.NewTestEnv() simplified := sqlparser.RewritePredicate(predicate) - original, err := evalengine.Translate(predicate, &evalengine.Config{ - Environment: venv, - Collation: collations.MySQL8().DefaultConnectionCharset(), - ResolveColumn: resolveForFuzz, - }) + cfg := &evalengine.Config{ + Environment: venv, + Collation: collations.MySQL8().DefaultConnectionCharset(), + ResolveColumn: resolveForFuzz, + NoConstantFolding: true, + NoCompilation: true, + } + original, err := evalengine.Translate(predicate, cfg) require.NoError(t, err) - simpler, err := evalengine.Translate(simplified.(sqlparser.Expr), &evalengine.Config{ - Environment: venv, - Collation: collations.MySQL8().DefaultConnectionCharset(), - ResolveColumn: resolveForFuzz, - }) + simpler, err := evalengine.Translate(simplified.(sqlparser.Expr), cfg) require.NoError(t, err) env := evalengine.EmptyExpressionEnv(venv) @@ -142,7 +173,17 @@ func testValues(t *testing.T, env *evalengine.ExpressionEnv, i int, original, si require.NoError(t, err) v2, err := env.Evaluate(simpler) require.NoError(t, err) - assert.Equal(t, v1.Value(collations.MySQL8().DefaultConnectionCharset()), v2.Value(collations.MySQL8().DefaultConnectionCharset())) + v1Value := v1.Value(collations.MySQL8().DefaultConnectionCharset()) + v2Value := v2.Value(collations.MySQL8().DefaultConnectionCharset()) + row := strings.Join(slice.Map(env.Row, func(i sqltypes.Value) string { + return i.String() + }), " | ") + msg := fmt.Sprintf("original: %v (%s)\nsimplified: %v (%s)\nrow: %v", sqlparser.String(original), v1Value.String(), sqlparser.String(simpler), v2Value.String(), row) + require.True( + t, + v1Value.Equal(v2Value), + msg, + ) if len(env.Row) > i+1 { 
testValues(t, env, i+1, original, simpler) } From 1de8b0887d49e7265c28d23dc469b21cedd2a6b8 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 29 Aug 2024 11:49:39 +0200 Subject: [PATCH 11/14] fixed remaining rewriter bugs Signed-off-by: Andres Taylor --- go/vt/sqlparser/ast_funcs.go | 70 +++++ go/vt/sqlparser/predicate_rewriting.go | 268 +++++++++--------- .../planbuilder/predicate_rewrite_test.go | 33 ++- 3 files changed, 229 insertions(+), 142 deletions(-) diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index c94ccde2894..bce28502631 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -21,6 +21,7 @@ import ( "encoding/json" "fmt" "io" + "slices" "strconv" "strings" @@ -2864,3 +2865,72 @@ func ExtractAllTables(stmt Statement) []string { }, stmt) return tables } + +// ExtractINFromOR rewrites the OR expression into an IN clause. +// Each side of each ORs has to be an equality comparison expression and the column names have to +// match for all sides of each comparison. 
+// This rewriter takes a query that looks like this WHERE a = 1 and b = 11 or a = 2 and b = 12 or a = 3 and b = 13 +// And rewrite that to WHERE (a, b) IN ((1,11), (2,12), (3,13)) +func ExtractINFromOR(expr *OrExpr) []Expr { + var varNames []*ColName + var values []Exprs + orSlice := orToSlice(expr) + for _, expr := range orSlice { + andSlice := andToSlice(expr) + if len(andSlice) == 0 { + return nil + } + + var currentVarNames []*ColName + var currentValues []Expr + for _, comparisonExpr := range andSlice { + if comparisonExpr.Operator != EqualOp { + return nil + } + + var colName *ColName + if left, ok := comparisonExpr.Left.(*ColName); ok { + colName = left + currentValues = append(currentValues, comparisonExpr.Right) + } + + if right, ok := comparisonExpr.Right.(*ColName); ok { + if colName != nil { + return nil + } + colName = right + currentValues = append(currentValues, comparisonExpr.Left) + } + + if colName == nil { + return nil + } + + currentVarNames = append(currentVarNames, colName) + } + + if len(varNames) == 0 { + varNames = currentVarNames + } else if !slices.EqualFunc(varNames, currentVarNames, func(col1, col2 *ColName) bool { return col1.Equal(col2) }) { + return nil + } + + values = append(values, currentValues) + } + + var nameTuple ValTuple + for _, name := range varNames { + nameTuple = append(nameTuple, name) + } + + var valueTuple ValTuple + for _, value := range values { + valueTuple = append(valueTuple, ValTuple(value)) + } + + return []Expr{&ComparisonExpr{ + Operator: InOp, + Left: nameTuple, + Right: valueTuple, + }} +} diff --git a/go/vt/sqlparser/predicate_rewriting.go b/go/vt/sqlparser/predicate_rewriting.go index f29f640c5e6..7635553ec6d 100644 --- a/go/vt/sqlparser/predicate_rewriting.go +++ b/go/vt/sqlparser/predicate_rewriting.go @@ -17,13 +17,19 @@ limitations under the License. 
package sqlparser import ( - "slices" + "fmt" ) +var DebugRewrite = false + // RewritePredicate walks the input AST and rewrites any boolean logic into a simpler form // This simpler form is CNF plus logic for extracting predicates from OR, plus logic for turning ORs into IN func RewritePredicate(ast SQLNode) SQLNode { - original := CloneSQLNode(ast) + return RewritePredicateInternal(ast, nil) +} + +func RewritePredicateInternal(ast SQLNode, view func(SQLNode)) SQLNode { + original := Clone(ast) // Beware: converting to CNF in this loop might cause exponential formula growth. // We bail out early to prevent going overboard. @@ -32,6 +38,10 @@ func RewritePredicate(ast SQLNode) SQLNode { stopOnChange := func(SQLNode, SQLNode) bool { return !exprChanged } + if DebugRewrite { + fmt.Println(String(ast)) + } + ast = SafeRewrite(ast, stopOnChange, func(cursor *Cursor) bool { e, isExpr := cursor.node.(Expr) if !isExpr { @@ -46,6 +56,10 @@ func RewritePredicate(ast SQLNode) SQLNode { return !exprChanged }) + if view != nil { + view(Clone(ast)) + } + if !exprChanged { return ast } @@ -74,6 +88,9 @@ func simplifyNot(expr *NotExpr) (Expr, bool) { return child.Expr, true case *OrExpr: // not(or(a,b)) => and(not(a),not(b)) + if DebugRewrite { + fmt.Println(" >> not (a or b) => not a and not b") + } return AndExpressions(&NotExpr{Expr: child.Left}, &NotExpr{Expr: child.Right}), true case *AndExpr: // not(and(a,b)) => or(not(a), not(b)) @@ -85,6 +102,9 @@ func simplifyNot(expr *NotExpr) (Expr, bool) { curr = &OrExpr{Left: curr, Right: &NotExpr{Expr: p}} } } + if DebugRewrite { + fmt.Println(" >> not (a and b) => not a or not b") + } return curr, true } return expr, false @@ -97,81 +117,68 @@ func createOrs(exprs ...Expr) Expr { return &OrExpr{Left: exprs[0], Right: createOrs(exprs[1:]...)} } -func simplifyOr(or *OrExpr) (Expr, bool) { - res, rewritten := distinctOr(or) - if rewritten { - return res, true - } - - land, lok := or.Left.(*AndExpr) - rand, rok := or.Right.(*AndExpr) - 
- if lok && rok { - // (A AND B AND D) OR (A AND C AND D) => (A AND D) AND (B OR C) - var commonPredicates []Expr - var leftRemainder, rightRemainder []Expr - - // Find all matching predicates and separate the remainder - rightRemainder = rand.Predicates - for _, lp := range land.Predicates { - rhs := rightRemainder - rightRemainder = nil - isCommon := false - for _, rp := range rhs { - if Equals.Expr(lp, rp) { - commonPredicates = append(commonPredicates, lp) - isCommon = true - } else { - rightRemainder = append(rightRemainder, rp) - } - } - if !isCommon { - leftRemainder = append(leftRemainder, lp) +func simplifyOredAnds(or *OrExpr, lhs, rhs *AndExpr) (Expr, bool) { + // (A AND B AND D) OR (A AND C AND D) => (A AND D) AND (B OR C) + var commonPredicates []Expr + var leftRemainder, rightRemainder []Expr + + // Find all matching predicates and separate the remainder + rightRemainder = rhs.Predicates + for _, lp := range lhs.Predicates { + rhs := rightRemainder + rightRemainder = nil + isCommon := false + for _, rp := range rhs { + if Equals.Expr(lp, rp) { + commonPredicates = append(commonPredicates, lp) + isCommon = true + } else { + rightRemainder = append(rightRemainder, rp) } } - - if len(commonPredicates) > 0 { - // Build the final AndExpr with common predicates and the OrExpr of remainders - nonCommonPredicates := append(leftRemainder, rightRemainder...) - commonPred := AndExpressions(commonPredicates...) 
- if len(nonCommonPredicates) == 0 { - return commonPred, true - } - return AndExpressions(commonPred, createOrs(nonCommonPredicates...)), true + if !isCommon { + leftRemainder = append(leftRemainder, lp) } - return or, false } - if !lok && !rok { - lftCmp, lok := or.Left.(*ComparisonExpr) - rgtCmp, rok := or.Right.(*ComparisonExpr) - if lok && rok { - newExpr, rewritten := tryTurningOrIntoIn(lftCmp, rgtCmp) - if rewritten { - // or(a=x,a=y) => in(a,[x,y]) - return newExpr, true - } - } + if len(commonPredicates) == 0 { return or, false } - // if we get here, one side is an AND - var and *AndExpr - var other Expr - if lok { - and = land - other = or.Right - } else { - and = rand - other = or.Left + // Build the final AndExpr with common predicates and the OrExpr of remainders + commonPred := AndExpressions(commonPredicates...) + if len(leftRemainder) == 0 && len(rightRemainder) == 0 { + if DebugRewrite { + fmt.Println(" >> remove duplicate predicates across ANDs") + } + return commonPred, true + } + + switch { + case len(rightRemainder) == 0 || len(leftRemainder) == 0: + if DebugRewrite { + fmt.Println(" >> (A and B and A and D) or (A and D) => A and D") + } + return commonPred, true + default: + if DebugRewrite { + fmt.Println(" >> (A and B and D) or (A and D and C) => (A and D) and (B or C)") + } + + return AndExpressions(commonPred, createOrs(AndExpressions(rightRemainder...), AndExpressions(leftRemainder...))), true } +} +func simplifyOrWithAnAND(and *AndExpr, other Expr, left bool) (Expr, bool) { for _, lp := range and.Predicates { if Equals.Expr(other, lp) { // if we have the same predicate on both sides of the OR, we can simplify // (A AND B) OR A => A // because if A is true, the OR is true, not matter what B is, // and if A is false, the AND is false, and again we don't care about B + if DebugRewrite { + fmt.Println(" >> (A AND B) OR A => A") + } return other, true } } @@ -180,18 +187,73 @@ func simplifyOr(or *OrExpr) (Expr, bool) { var 
distributedPredicates []Expr for _, lp := range and.Predicates { var or *OrExpr - if lok { + if left { or = &OrExpr{Left: lp, Right: other} } else { or = &OrExpr{Left: other, Right: lp} } distributedPredicates = append(distributedPredicates, or) } + if DebugRewrite { + fmt.Println(" >> (A and B) or C => (A or C) and (B or C)") + } return AndExpressions(distributedPredicates...), true } +func simplifyOr(or *OrExpr) (Expr, bool) { + res, rewritten := distinctOr(or) + if rewritten { + if DebugRewrite { + fmt.Println(" >> distinct or elements") + } + + return res, true + } + + land, lok := or.Left.(*AndExpr) + rand, rok := or.Right.(*AndExpr) + + switch { + case lok && rok: + return simplifyOredAnds(or, land, rand) + case !lok && !rok: + return simplifyOrToIn(or) + default: + // if we get here, one side is an AND + var and *AndExpr + var other Expr + if lok { + and = land + other = or.Right + } else { + and = rand + other = or.Left + } + + return simplifyOrWithAnAND(and, other, lok) + } +} + +func simplifyOrToIn(or *OrExpr) (Expr, bool) { + lftCmp, lok := or.Left.(*ComparisonExpr) + rgtCmp, rok := or.Right.(*ComparisonExpr) + if lok && rok { + newExpr, rewritten := tryTurningOrIntoIn(lftCmp, rgtCmp) + if rewritten { + if DebugRewrite { + fmt.Println(" >> turning OR into IN") + } + return newExpr, true + } + } + + return or, false +} + func simplifyXor(xor *XorExpr) (Expr, bool) { - // xor(a,b) => and(or(a,b), not(and(a,b)) + if DebugRewrite { + fmt.Println(" >> a xor b => (a or b) and not(a and b)") + } return AndExpressions( &OrExpr{Left: xor.Left, Right: xor.Right}, &NotExpr{Expr: AndExpressions(xor.Left, xor.Right)}, @@ -200,11 +262,18 @@ func simplifyXor(xor *XorExpr) (Expr, bool) { func simplifyAnd(expr *AndExpr) (Expr, bool) { if len(expr.Predicates) == 1 { + if DebugRewrite { + fmt.Println(" >> single predicate in AND") + } return expr.Predicates[0], true } res, rewritten := distinctAnd(expr) if rewritten { + if DebugRewrite { + fmt.Println(" >> distinct and 
elements") + } + return res, true } @@ -236,6 +305,10 @@ outer: } if simplified { + if DebugRewrite { + fmt.Println(" >> (a or b) and a => a") + } + // Return a new AndExpr with the simplified predicates return AndExpressions(simplifiedPredicates...), true } @@ -243,75 +316,6 @@ outer: return expr, false } -// ExtractINFromOR rewrites the OR expression into an IN clause. -// Each side of each ORs has to be an equality comparison expression and the column names have to -// match for all sides of each comparison. -// This rewriter takes a query that looks like this WHERE a = 1 and b = 11 or a = 2 and b = 12 or a = 3 and b = 13 -// And rewrite that to WHERE (a, b) IN ((1,11), (2,12), (3,13)) -func ExtractINFromOR(expr *OrExpr) []Expr { - var varNames []*ColName - var values []Exprs - orSlice := orToSlice(expr) - for _, expr := range orSlice { - andSlice := andToSlice(expr) - if len(andSlice) == 0 { - return nil - } - - var currentVarNames []*ColName - var currentValues []Expr - for _, comparisonExpr := range andSlice { - if comparisonExpr.Operator != EqualOp { - return nil - } - - var colName *ColName - if left, ok := comparisonExpr.Left.(*ColName); ok { - colName = left - currentValues = append(currentValues, comparisonExpr.Right) - } - - if right, ok := comparisonExpr.Right.(*ColName); ok { - if colName != nil { - return nil - } - colName = right - currentValues = append(currentValues, comparisonExpr.Left) - } - - if colName == nil { - return nil - } - - currentVarNames = append(currentVarNames, colName) - } - - if len(varNames) == 0 { - varNames = currentVarNames - } else if !slices.EqualFunc(varNames, currentVarNames, func(col1, col2 *ColName) bool { return col1.Equal(col2) }) { - return nil - } - - values = append(values, currentValues) - } - - var nameTuple ValTuple - for _, name := range varNames { - nameTuple = append(nameTuple, name) - } - - var valueTuple ValTuple - for _, value := range values { - valueTuple = append(valueTuple, ValTuple(value)) - } - - 
return []Expr{&ComparisonExpr{ - Operator: InOp, - Left: nameTuple, - Right: valueTuple, - }} -} - func orToSlice(expr *OrExpr) []Expr { var exprs []Expr diff --git a/go/vt/vtgate/planbuilder/predicate_rewrite_test.go b/go/vt/vtgate/planbuilder/predicate_rewrite_test.go index 9c1e25624ce..ce56b5e85bc 100644 --- a/go/vt/vtgate/planbuilder/predicate_rewrite_test.go +++ b/go/vt/vtgate/planbuilder/predicate_rewrite_test.go @@ -86,15 +86,21 @@ func (tc testCase) createPredicate(lvl int) sqlparser.Expr { func TestOneRewriting(t *testing.T) { venv := vtenv.NewTestEnv() + sqlparser.DebugRewrite = true // Modify these const numberOfColumns = 2 - const expr = "n1 and n0 or n1 xor n1" + const expr = "not (n1 xor n0 or n1)" predicate, err := sqlparser.NewTestParser().ParseExpr(expr) require.NoError(t, err) - simplified := sqlparser.RewritePredicate(predicate) + var steps []sqlparser.SQLNode + + simplified := sqlparser.RewritePredicateInternal(predicate, func(n sqlparser.SQLNode) { + steps = append(steps, n) + }) + fmt.Println(sqlparser.String(simplified)) cfg := &evalengine.Config{ Environment: venv, @@ -103,16 +109,23 @@ func TestOneRewriting(t *testing.T) { } original, err := evalengine.Translate(predicate, cfg) require.NoError(t, err) - simpler, err := evalengine.Translate(simplified.(sqlparser.Expr), cfg) - require.NoError(t, err) - env := evalengine.EmptyExpressionEnv(venv) - env.Row = make([]sqltypes.Value, numberOfColumns) - for i := range env.Row { - env.Row[i] = sqltypes.NULL + for _, step := range steps { + name := sqlparser.String(step) + t.Run(name, func(t *testing.T) { + simpler, err := evalengine.Translate(step.(sqlparser.Expr), cfg) + require.NoError(t, err) + + env := evalengine.EmptyExpressionEnv(venv) + env.Row = make([]sqltypes.Value, numberOfColumns) + for i := range env.Row { + env.Row[i] = sqltypes.NULL + } + + testValues(t, env, 0, original, simpler) + }) } - testValues(t, env, 0, original, simpler) } func TestFuzzRewriting(t *testing.T) { @@ -125,7 
+138,7 @@ func TestFuzzRewriting(t *testing.T) { start := time.Now() for time.Since(start) < 1*time.Second { tc := testCase{ - nodes: 2, + nodes: rand.IntN(4) + 1, depth: rand.IntN(4) + 1, } From 80815e71c952ccb1d6ef885df0654e161011b6d5 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 29 Aug 2024 12:00:45 +0200 Subject: [PATCH 12/14] minor cleanups Signed-off-by: Andres Taylor --- go/vt/sqlparser/predicate_rewriting.go | 30 +++++++++---------- .../planbuilder/predicate_rewrite_test.go | 5 ++++ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/go/vt/sqlparser/predicate_rewriting.go b/go/vt/sqlparser/predicate_rewriting.go index 7635553ec6d..04c4a7a9f32 100644 --- a/go/vt/sqlparser/predicate_rewriting.go +++ b/go/vt/sqlparser/predicate_rewriting.go @@ -21,6 +21,7 @@ import ( ) var DebugRewrite = false +var prefix = " >> " // RewritePredicate walks the input AST and rewrites any boolean logic into a simpler form // This simpler form is CNF plus logic for extracting predicates from OR, plus logic for turning ORs into IN @@ -87,13 +88,11 @@ func simplifyNot(expr *NotExpr) (Expr, bool) { case *NotExpr: return child.Expr, true case *OrExpr: - // not(or(a,b)) => and(not(a),not(b)) if DebugRewrite { - fmt.Println(" >> not (a or b) => not a and not b") + fmt.Println(prefix, "not (a or b) => not a and not b") } return AndExpressions(&NotExpr{Expr: child.Left}, &NotExpr{Expr: child.Right}), true case *AndExpr: - // not(and(a,b)) => or(not(a), not(b)) var curr Expr for i, p := range child.Predicates { if i == 0 { @@ -103,7 +102,7 @@ func simplifyNot(expr *NotExpr) (Expr, bool) { } } if DebugRewrite { - fmt.Println(" >> not (a and b) => not a or not b") + fmt.Println(prefix, "not (a and b) => not a or not b") } return curr, true } @@ -118,7 +117,6 @@ func createOrs(exprs ...Expr) Expr { } func simplifyOredAnds(or *OrExpr, lhs, rhs *AndExpr) (Expr, bool) { - // (A AND B AND D) OR (A AND C AND D) => (A AND D) AND (B OR C) var commonPredicates []Expr var 
leftRemainder, rightRemainder []Expr @@ -149,7 +147,7 @@ func simplifyOredAnds(or *OrExpr, lhs, rhs *AndExpr) (Expr, bool) { commonPred := AndExpressions(commonPredicates...) if len(leftRemainder) == 0 && len(rightRemainder) == 0 { if DebugRewrite { - fmt.Println(" >> remove duplicate predicates across ANDs") + fmt.Println(prefix, "remove duplicate predicates across ANDs") } return commonPred, true } @@ -157,12 +155,12 @@ func simplifyOredAnds(or *OrExpr, lhs, rhs *AndExpr) (Expr, bool) { switch { case len(rightRemainder) == 0 || len(leftRemainder) == 0: if DebugRewrite { - fmt.Println(" >> (A and B and A and D) or (A and D) => A and D") + fmt.Println(prefix, "(A and B and A and D) or (A and D) => A and D") } return commonPred, true default: if DebugRewrite { - fmt.Println(" >> (A and B and D) or (A and D and C) => (A and D) and (B or C)") + fmt.Println(prefix, "(A and B and C) or (A and B and D) => (A and B) and (C or D)") } return AndExpressions(commonPred, createOrs(AndExpressions(rightRemainder...), AndExpressions(leftRemainder...))), true @@ -177,7 +175,7 @@ func simplifyOrWithAnAND(and *AndExpr, other Expr, left bool) (Expr, bool) { // because if A is true, the OR is true, not matter what B is, // and if A is false, the AND is false, and again we don't care about B if DebugRewrite { - fmt.Println(" >> (A AND B) OR A => A") + fmt.Println(prefix, "(A AND B) OR A => A") } return other, true } @@ -195,7 +193,7 @@ func simplifyOrWithAnAND(and *AndExpr, other Expr, left bool) (Expr, bool) { distributedPredicates = append(distributedPredicates, or) } if DebugRewrite { - fmt.Println(" >> (A and B) or C => (A or C) and (B or C)") + fmt.Println(prefix, "(A and B) or C => (A or C) and (B or C)") } return AndExpressions(distributedPredicates...), true } @@ -204,7 +202,7 @@ func simplifyOr(or *OrExpr) (Expr, bool) { res, rewritten := distinctOr(or) if rewritten { if DebugRewrite { - fmt.Println(" >> distinct or elements") + fmt.Println(prefix, "distinct or elements") } 
return res, true @@ -241,7 +239,7 @@ func simplifyOrToIn(or *OrExpr) (Expr, bool) { newExpr, rewritten := tryTurningOrIntoIn(lftCmp, rgtCmp) if rewritten { if DebugRewrite { - fmt.Println(" >> turning OR into IN") + fmt.Println(prefix, "turning OR into IN") } return newExpr, true } @@ -252,7 +250,7 @@ func simplifyOrToIn(or *OrExpr) (Expr, bool) { func simplifyXor(xor *XorExpr) (Expr, bool) { if DebugRewrite { - fmt.Println(" >> a xor b => (a or b) and not(a and b)") + fmt.Println(prefix, "a xor b => (a or b) and not(a and b)") } return AndExpressions( &OrExpr{Left: xor.Left, Right: xor.Right}, @@ -263,7 +261,7 @@ func simplifyXor(xor *XorExpr) (Expr, bool) { func simplifyAnd(expr *AndExpr) (Expr, bool) { if len(expr.Predicates) == 1 { if DebugRewrite { - fmt.Println(" >> single predicate in AND") + fmt.Println(prefix, "single predicate in AND") } return expr.Predicates[0], true } @@ -271,7 +269,7 @@ func simplifyAnd(expr *AndExpr) (Expr, bool) { res, rewritten := distinctAnd(expr) if rewritten { if DebugRewrite { - fmt.Println(" >> distinct and elements") + fmt.Println(prefix, "distinct and elements") } return res, true @@ -306,7 +304,7 @@ outer: if simplified { if DebugRewrite { - fmt.Println(" >> (a or b) and a => a") + fmt.Println(prefix, "(a or b) and a => a") } // Return a new AndExpr with the simplified predicates diff --git a/go/vt/vtgate/planbuilder/predicate_rewrite_test.go b/go/vt/vtgate/planbuilder/predicate_rewrite_test.go index ce56b5e85bc..1ce07f16b98 100644 --- a/go/vt/vtgate/planbuilder/predicate_rewrite_test.go +++ b/go/vt/vtgate/planbuilder/predicate_rewrite_test.go @@ -85,6 +85,11 @@ func (tc testCase) createPredicate(lvl int) sqlparser.Expr { } func TestOneRewriting(t *testing.T) { + // This test is a simple test that takes a single expression and simplifies it. 
+ // While simplifying, it also collects all the steps that were taken to simplify the expression, + // and then runs both the original and simplified expressions with all possible values for the columns. + // If the two expressions do not return the same value, this is considered a test failure. + // This test is useful for debugging and understanding how the simplification works. venv := vtenv.NewTestEnv() sqlparser.DebugRewrite = true From 8d2a401196e2a990c561202d22f1413481679f60 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 29 Aug 2024 13:36:47 +0200 Subject: [PATCH 13/14] dont stop too early when rewriting OR Signed-off-by: Andres Taylor --- go/vt/sqlparser/predicate_rewriting.go | 37 ++++++++++++++------------ 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/go/vt/sqlparser/predicate_rewriting.go b/go/vt/sqlparser/predicate_rewriting.go index 04c4a7a9f32..9cafaccb3b4 100644 --- a/go/vt/sqlparser/predicate_rewriting.go +++ b/go/vt/sqlparser/predicate_rewriting.go @@ -163,7 +163,7 @@ func simplifyOredAnds(or *OrExpr, lhs, rhs *AndExpr) (Expr, bool) { fmt.Println(prefix, "(A and B and C) or (A and B and D) => (A and B) and (C or D)") } - return AndExpressions(commonPred, createOrs(AndExpressions(rightRemainder...), AndExpressions(leftRemainder...))), true + return AndExpressions(commonPred, createOrs(AndExpressions(leftRemainder...), AndExpressions(rightRemainder...))), true } } @@ -211,25 +211,28 @@ func simplifyOr(or *OrExpr) (Expr, bool) { land, lok := or.Left.(*AndExpr) rand, rok := or.Right.(*AndExpr) - switch { - case lok && rok: - return simplifyOredAnds(or, land, rand) - case !lok && !rok: - return simplifyOrToIn(or) - default: - // if we get here, one side is an AND - var and *AndExpr - var other Expr - if lok { - and = land - other = or.Right - } else { - and = rand - other = or.Left + if lok && rok { + res, success := simplifyOredAnds(or, land, rand) + if success { + return res, true } + } - return simplifyOrWithAnAND(and, 
other, lok) + if !lok && !rok { + return simplifyOrToIn(or) } + + var and *AndExpr + var other Expr + if lok { + and = land + other = or.Right + } else { + and = rand + other = or.Left + } + + return simplifyOrWithAnAND(and, other, lok) } func simplifyOrToIn(or *OrExpr) (Expr, bool) { From 31b8f71b8998548d8209121fb5dcf6d4a7c4e6c1 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 29 Aug 2024 13:49:04 +0200 Subject: [PATCH 14/14] remove test not used any more Signed-off-by: Andres Taylor --- go/vt/vtgate/semantics/early_rewriter_test.go | 47 ------------------- 1 file changed, 47 deletions(-) diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index fab8211f74e..4f550d46392 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -905,53 +905,6 @@ func TestOrderByDerivedTable(t *testing.T) { } } -// TestConstantFolding tests that the rewriter is able to do various constant foldings properly. 
-func TestConstantFolding(t *testing.T) { - ks := &vindexes.Keyspace{ - Name: "main", - Sharded: true, - } - schemaInfo := &FakeSI{ - Tables: map[string]*vindexes.Table{ - "t1": { - Keyspace: ks, - Name: sqlparser.NewIdentifierCS("t1"), - Columns: []vindexes.Column{{ - Name: sqlparser.NewIdentifierCI("a"), - Type: sqltypes.VarChar, - }, { - Name: sqlparser.NewIdentifierCI("b"), - Type: sqltypes.VarChar, - }, { - Name: sqlparser.NewIdentifierCI("c"), - Type: sqltypes.VarChar, - }}, - ColumnListAuthoritative: true, - }, - }, - } - cDB := "db" - tcases := []struct { - sql string - expSQL string - }{{ - sql: "select 1 from t1 where (a, b) in ::fkc_vals and (2 is null or (1 is null or a in (1)))", - expSQL: "select 1 from t1 where (a, b) in ::fkc_vals and a in (1)", - }, { - sql: "select 1 from t1 where (false or (false or a in (1)))", - expSQL: "select 1 from t1 where a in (1)", - }} - for _, tcase := range tcases { - t.Run(tcase.sql, func(t *testing.T) { - ast, err := sqlparser.NewTestParser().Parse(tcase.sql) - require.NoError(t, err) - _, err = Analyze(ast, cDB, schemaInfo) - require.NoError(t, err) - require.Equal(t, tcase.expSQL, sqlparser.String(ast)) - }) - } -} - // TestCTEToDerivedTableRewrite checks that CTEs are correctly rewritten to derived tables func TestCTEToDerivedTableRewrite(t *testing.T) { cDB := "db"