From 858f4f8da4dfe1496e504a085ae2e77215a866c4 Mon Sep 17 00:00:00 2001 From: Florent Poinsard <35779988+frouioui@users.noreply.github.com> Date: Mon, 27 Nov 2023 10:19:07 -0600 Subject: [PATCH] Support unlimited number of ORs in `ExtractINFromOR` (#14566) Signed-off-by: Florent Poinsard --- go/vt/sqlparser/predicate_rewriting.go | 140 ++++++++++++++++---- go/vt/sqlparser/predicate_rewriting_test.go | 35 +++-- 2 files changed, 133 insertions(+), 42 deletions(-) diff --git a/go/vt/sqlparser/predicate_rewriting.go b/go/vt/sqlparser/predicate_rewriting.go index 7bad1b3b82f..234a2f4acd5 100644 --- a/go/vt/sqlparser/predicate_rewriting.go +++ b/go/vt/sqlparser/predicate_rewriting.go @@ -16,6 +16,8 @@ limitations under the License. package sqlparser +import "slices" + // RewritePredicate walks the input AST and rewrites any boolean logic into a simpler form // This simpler form is CNF plus logic for extracting predicates from OR, plus logic for turning ORs into IN func RewritePredicate(ast SQLNode) SQLNode { @@ -204,36 +206,128 @@ func simplifyAnd(expr *AndExpr) (Expr, bool) { return expr, false } -// ExtractINFromOR will add additional predicated to an OR. -// this rewriter should not be used in a fixed point way, since it returns the original expression with additions, -// and it will therefor OOM before it stops rewriting +// ExtractINFromOR rewrites the OR expression into an IN clause. +// Each side of each ORs has to be an equality comparison expression and the column names have to +// match for all sides of each comparison. +// This rewriter takes a query that looks like this WHERE a = 1 and b = 11 or a = 2 and b = 12 or a = 3 and b = 13 +// And rewrite that to WHERE (a, b) IN ((1,11), (2,12), (3,13)) func ExtractINFromOR(expr *OrExpr) []Expr { - // we check if we have two comparisons on either side of the OR - // that we can add as an ANDed comparison. - // WHERE (a = 5 and B) or (a = 6 AND C) => - // WHERE (a = 5 AND B) OR (a = 6 AND C) AND a IN (5,6) - // This rewrite makes it possible to find a better route than Scatter if the `a` column has a helpful vindex - lftPredicates := SplitAndExpression(nil, expr.Left) - rgtPredicates := SplitAndExpression(nil, expr.Right) - var ins []Expr - for _, lft := range lftPredicates { - l, ok := lft.(*ComparisonExpr) - if !ok { - continue + var varNames []*ColName + var values []Exprs + orSlice := orToSlice(expr) + for _, expr := range orSlice { + andSlice := andToSlice(expr) + if len(andSlice) == 0 { + return nil } - for _, rgt := range rgtPredicates { - r, ok := rgt.(*ComparisonExpr) - if !ok { - continue + + var currentVarNames []*ColName + var currentValues []Expr + for _, comparisonExpr := range andSlice { + if comparisonExpr.Operator != EqualOp { + return nil } - in, changed := tryTurningOrIntoIn(l, r) - if changed { - ins = append(ins, in) + + var colName *ColName + if left, ok := comparisonExpr.Left.(*ColName); ok { + colName = left + currentValues = append(currentValues, comparisonExpr.Right) + } + + if right, ok := comparisonExpr.Right.(*ColName); ok { + if colName != nil { + return nil + } + colName = right + currentValues = append(currentValues, comparisonExpr.Left) + } + + if colName == nil { + return nil } + + currentVarNames = append(currentVarNames, colName) + } + + if len(varNames) == 0 { + varNames = currentVarNames + } else if !slices.EqualFunc(varNames, currentVarNames, func(col1, col2 *ColName) bool { return col1.Equal(col2) }) { + return nil } + + values = append(values, currentValues) + } + + var nameTuple ValTuple + for _, name := range varNames { + nameTuple = append(nameTuple, name) + } + + var valueTuple ValTuple + for _, value := range values { + valueTuple = append(valueTuple, ValTuple(value)) + } + + return []Expr{&ComparisonExpr{ + Operator: InOp, + Left: nameTuple, + Right: valueTuple, + }} +} + +func orToSlice(expr *OrExpr) []Expr { + var exprs []Expr + + handleOrSide := func(e Expr) { + switch e := e.(type) { + case *OrExpr: + exprs = append(exprs, orToSlice(e)...) + default: + exprs = append(exprs, e) + } + } + + handleOrSide(expr.Left) + handleOrSide(expr.Right) + return exprs +} + +func andToSlice(expr Expr) []*ComparisonExpr { + var andExpr *AndExpr + switch expr := expr.(type) { + case *AndExpr: + andExpr = expr + case *ComparisonExpr: + return []*ComparisonExpr{expr} + default: + return nil + } + + var exprs []*ComparisonExpr + handleAndSide := func(e Expr) bool { + switch e := e.(type) { + case *AndExpr: + slice := andToSlice(e) + if slice == nil { + return false + } + exprs = append(exprs, slice...) + case *ComparisonExpr: + exprs = append(exprs, e) + default: + return false + } + return true + } + + if !handleAndSide(andExpr.Left) { + return nil + } + if !handleAndSide(andExpr.Right) { + return nil } - return uniquefy(ins) + return exprs } func tryTurningOrIntoIn(l, r *ComparisonExpr) (Expr, bool) { diff --git a/go/vt/sqlparser/predicate_rewriting_test.go b/go/vt/sqlparser/predicate_rewriting_test.go index e106a56f1aa..a4bbb5f7b5c 100644 --- a/go/vt/sqlparser/predicate_rewriting_test.go +++ b/go/vt/sqlparser/predicate_rewriting_test.go @@ -140,6 +140,18 @@ func TestRewritePredicate(in *testing.T) { // the following two tests show some pathological cases that would grow too much, and so we abort the rewriting in: "a = 1 and b = 41 or a = 2 and b = 42 or a = 3 and b = 43 or a = 4 and b = 44 or a = 5 and b = 45 or a = 6 and b = 46", expected: "a = 1 and b = 41 or a = 2 and b = 42 or a = 3 and b = 43 or a = 4 and b = 44 or a = 5 and b = 45 or a = 6 and b = 46", + }, { + in: "a = 5 and B or a = 6 and C", + expected: "a in (5, 6) and (a = 5 or C) and ((B or a = 6) and (B or C))", + }, { + in: "(a = 5 and b = 1 or b = 2 and a = 6)", + expected: "(a = 5 or b = 2) and a in (5, 6) and (b in (1, 2) and (b = 1 or a = 6))", + }, { + in: "(a in (1,5) and B or C and a = 6)", + expected: "(a in (1, 5) or C) and a in (1, 5, 6) and ((B or C) and (B or a = 6))", + }, { + in: "(a in (1, 5) and B or C and a in (5, 7))", + expected: "(a in (1, 5) or C) and a in (1, 5, 7) and ((B or C) and (B or a in (5, 7)))", }, { in: "not n0 xor not (n2 and n3) xor (not n2 and (n1 xor n1) xor (n0 xor n0 xor n2))", expected: "not n0 xor not (n2 and n3) xor (not n2 and (n1 xor n1) xor (n0 xor n0 xor n2))", @@ -161,26 +173,11 @@ func TestExtractINFromOR(in *testing.T) { in string expected string }{{ - in: "(A and B) or (B and A)", - expected: "", - }, { - in: "(a = 5 and B) or A", - expected: "", - }, { - in: "a = 5 and B or a = 6 and C", - expected: "a in (5, 6)", - }, { - in: "(a = 5 and b = 1 or b = 2 and a = 6)", - expected: "a in (5, 6) and b in (1, 2)", - }, { - in: "(a in (1,5) and B or C and a = 6)", - expected: "a in (1, 5, 6)", - }, { - in: "(a in (1, 5) and B or C and a in (5, 7))", - expected: "a in (1, 5, 7)", + in: "a = 1 and b = 41 or a = 2 and b = 42 or a = 3 and b = 43 or a = 4 and b = 44 or a = 5 and b = 45 or a = 6 and b = 46", + expected: "(a, b) in ((1, 41), (2, 42), (3, 43), (4, 44), (5, 45), (6, 46))", }, { - in: "(a = 5 and b = 1 or b = 2 and a = 6 or b = 3 and a = 4)", - expected: "", + in: "a = 1 or a = 2 or a = 3 or a = 4 or a = 5 or a = 6", + expected: "(a) in ((1), (2), (3), (4), (5), (6))", }} for _, tc := range tests {