diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 124833bc91d..7ddb559ebf9 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -2263,7 +2263,7 @@ type ( } // AndExpr represents an AND expression. - AndExpr struct{ Predicates []Expr } + AndExpr struct{ Predicates Exprs } // OrExpr represents an OR expression. OrExpr struct { diff --git a/go/vt/sqlparser/ast_clone.go b/go/vt/sqlparser/ast_clone.go index 00330ff1972..99122454bba 100644 --- a/go/vt/sqlparser/ast_clone.go +++ b/go/vt/sqlparser/ast_clone.go @@ -744,7 +744,7 @@ func CloneRefOfAndExpr(n *AndExpr) *AndExpr { return nil } out := *n - out.Predicates = CloneSliceOfExpr(n.Predicates) + out.Predicates = CloneExprs(n.Predicates) return &out } @@ -4368,18 +4368,6 @@ func CloneSliceOfIdentifierCI(n []IdentifierCI) []IdentifierCI { return res } -// CloneSliceOfExpr creates a deep clone of the input. -func CloneSliceOfExpr(n []Expr) []Expr { - if n == nil { - return nil - } - res := make([]Expr, len(n)) - for i, x := range n { - res[i] = CloneExpr(x) - } - return res -} - // CloneSliceOfTxAccessMode creates a deep clone of the input. func CloneSliceOfTxAccessMode(n []TxAccessMode) []TxAccessMode { if n == nil { @@ -4469,6 +4457,18 @@ func CloneSliceOfRefOfVariable(n []*Variable) []*Variable { return res } +// CloneSliceOfExpr creates a deep clone of the input. +func CloneSliceOfExpr(n []Expr) []Expr { + if n == nil { + return nil + } + res := make([]Expr, len(n)) + for i, x := range n { + res[i] = CloneExpr(x) + } + return res +} + // CloneRefOfIdentifierCI creates a deep clone of the input. func CloneRefOfIdentifierCI(n *IdentifierCI) *IdentifierCI { if n == nil { diff --git a/go/vt/sqlparser/ast_copy_on_rewrite.go b/go/vt/sqlparser/ast_copy_on_rewrite.go index 2caea4bd978..594cc21df8c 100644 --- a/go/vt/sqlparser/ast_copy_on_rewrite.go +++ b/go/vt/sqlparser/ast_copy_on_rewrite.go @@ -953,18 +953,10 @@ func (c *cow) copyOnRewriteRefOfAndExpr(n *AndExpr, parent SQLNode) (out SQLNode } out = n if c.pre == nil || c.pre(n, parent) { - var changedPredicates bool - _Predicates := make([]Expr, len(n.Predicates)) - for x, el := range n.Predicates { - this, changed := c.copyOnRewriteExpr(el, n) - _Predicates[x] = this.(Expr) - if changed { - changedPredicates = true - } - } + _Predicates, changedPredicates := c.copyOnRewriteExprs(n.Predicates, n) if changedPredicates { res := *n - res.Predicates = _Predicates + res.Predicates, _ = _Predicates.(Exprs) out = &res if c.cloned != nil { c.cloned(n, out) diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index 66ec6190159..4d5335da125 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -1863,7 +1863,7 @@ func (cmp *Comparator) RefOfAndExpr(a, b *AndExpr) bool { if a == nil || b == nil { return false } - return cmp.SliceOfExpr(a.Predicates, b.Predicates) + return cmp.Exprs(a.Predicates, b.Predicates) } // RefOfAnyValue does deep equals between the two objects. @@ -7245,19 +7245,6 @@ func (cmp *Comparator) SliceOfIdentifierCI(a, b []IdentifierCI) bool { return true } -// SliceOfExpr does deep equals between the two objects. -func (cmp *Comparator) SliceOfExpr(a, b []Expr) bool { - if len(a) != len(b) { - return false - } - for i := 0; i < len(a); i++ { - if !cmp.Expr(a[i], b[i]) { - return false - } - } - return true -} - // SliceOfTxAccessMode does deep equals between the two objects. func (cmp *Comparator) SliceOfTxAccessMode(a, b []TxAccessMode) bool { if len(a) != len(b) { @@ -7366,6 +7353,19 @@ func (cmp *Comparator) SliceOfRefOfVariable(a, b []*Variable) bool { return true } +// SliceOfExpr does deep equals between the two objects. +func (cmp *Comparator) SliceOfExpr(a, b []Expr) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !cmp.Expr(a[i], b[i]) { + return false + } + } + return true +} + // RefOfIdentifierCI does deep equals between the two objects. func (cmp *Comparator) RefOfIdentifierCI(a, b *IdentifierCI) bool { if a == b { diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index db6965de1f3..73247d9a322 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -2360,6 +2360,17 @@ func SplitAndExpression(filters []Expr, node Expr) []Expr { return append(filters, node) } +func CreateAndExpr(exprs ...Expr) Expr { + switch len(exprs) { + case 0: + return nil + case 1: + return exprs[0] + default: + return &AndExpr{Predicates: exprs} + } +} + // AndExpressions ands together two or more expressions, minimising the expr when possible func AndExpressions(exprs ...Expr) Expr { switch len(exprs) { diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index df144e4fad8..02e4d9411fd 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -1088,14 +1088,10 @@ func (a *application) rewriteRefOfAndExpr(parent SQLNode, node *AndExpr, replace return true } } - for x, el := range node.Predicates { - if !a.rewriteExpr(node, el, func(idx int) replacerFunc { - return func(newNode, parent SQLNode) { - parent.(*AndExpr).Predicates[idx] = newNode.(Expr) - } - }(x)) { - return false - } + if !a.rewriteExprs(node, node.Predicates, func(newNode, parent SQLNode) { + parent.(*AndExpr).Predicates = newNode.(Exprs) + }) { + return false } if a.post != nil { a.cur.replacer = replacer diff --git a/go/vt/sqlparser/ast_visit.go b/go/vt/sqlparser/ast_visit.go index 04608d84ec2..f299a131a9f 100644 --- a/go/vt/sqlparser/ast_visit.go +++ b/go/vt/sqlparser/ast_visit.go @@ -811,10 +811,8 @@ func VisitRefOfAndExpr(in *AndExpr, f Visit) error { if cont, err := f(in); err != nil || !cont { return err } - for _, el := range in.Predicates { - if err := VisitExpr(el, f); err != nil { - return err - } + if err := VisitExprs(in.Predicates, f); err != nil { + return err } return nil } diff --git a/go/vt/sqlparser/predicate_rewriting.go b/go/vt/sqlparser/predicate_rewriting.go index e4e859b2ba2..a6da91c8c7f 100644 --- a/go/vt/sqlparser/predicate_rewriting.go +++ b/go/vt/sqlparser/predicate_rewriting.go @@ -181,10 +181,13 @@ func simplifyOr(or *OrExpr) (Expr, bool) { // Distribution Law var distributedPredicates []Expr for _, lp := range and.Predicates { - distributedPredicates = append(distributedPredicates, &OrExpr{ - Left: lp, - Right: other, - }) + var or *OrExpr + if lok { + or = &OrExpr{Left: lp, Right: other} + } else { + or = &OrExpr{Left: other, Right: lp} + } + distributedPredicates = append(distributedPredicates, or) } return AndExpressions(distributedPredicates...), true } diff --git a/go/vt/vtgate/evalengine/translate.go b/go/vt/vtgate/evalengine/translate.go index 99e1508cc04..a887f187e16 100644 --- a/go/vt/vtgate/evalengine/translate.go +++ b/go/vt/vtgate/evalengine/translate.go @@ -99,15 +99,35 @@ func (ast *astCompiler) translateLogicalNot(node *sqlparser.NotExpr) (IR, error) return &NotExpr{UnaryExpr{inner}}, nil } +func (ast *astCompiler) translateLogicalAnd(node *sqlparser.AndExpr) (IR, error) { + var acc IR + for i, pred := range node.Predicates { + ir, err := ast.translateExpr(pred) + if err != nil { + return nil, err + } + if i == 0 { + acc = ir + continue + } + + acc = &LogicalExpr{ + BinaryExpr: BinaryExpr{ + Left: acc, + Right: ir, + }, + op: opLogicalAnd{}, + } + } + + return acc, nil +} + func (ast *astCompiler) translateLogicalExpr(node sqlparser.Expr) (IR, error) { var left, right sqlparser.Expr var logic opLogical switch n := node.(type) { - case *sqlparser.AndExpr: - left = n.Left - right = n.Right - logic = opLogicalAnd{} case *sqlparser.OrExpr: left = n.Left right = n.Right @@ -521,7 +541,7 @@ func (ast *astCompiler) translateExpr(e sqlparser.Expr) (IR, error) { case *sqlparser.Literal: return translateLiteral(node, ast.cfg.Collation) case *sqlparser.AndExpr: - return ast.translateLogicalExpr(node) + return ast.translateLogicalAnd(node) case *sqlparser.OrExpr: return ast.translateLogicalExpr(node) case *sqlparser.XorExpr: diff --git a/go/vt/vtgate/planbuilder/operators/querygraph.go b/go/vt/vtgate/planbuilder/operators/querygraph.go index 8e8572f7dfa..e0231e26b3b 100644 --- a/go/vt/vtgate/planbuilder/operators/querygraph.go +++ b/go/vt/vtgate/planbuilder/operators/querygraph.go @@ -141,10 +141,7 @@ func (qg *QueryGraph) addNoDepsPredicate(predicate sqlparser.Expr) { if qg.NoDeps == nil { qg.NoDeps = predicate } else { - qg.NoDeps = &sqlparser.AndExpr{ - Left: qg.NoDeps, - Right: predicate, - } + qg.NoDeps = sqlparser.AndExpressions(qg.NoDeps, predicate) } } diff --git a/go/vt/vtgate/planbuilder/operators/subquery.go b/go/vt/vtgate/planbuilder/operators/subquery.go index b919bbfaed9..7b5e7b0515d 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery.go +++ b/go/vt/vtgate/planbuilder/operators/subquery.go @@ -298,7 +298,7 @@ func (sq *SubQuery) settleFilter(ctx *plancontext.PlanningContext, outer Operato // lead to better routing. This however might not always be true for example we can have the rhsPred to be something like // `user.id = 2 OR (:__sq_has_values AND user.id IN ::sql1)` if andExpr, isAndExpr := rhsPred.(*sqlparser.AndExpr); isAndExpr { - predicates = append(predicates, andExpr.Left, andExpr.Right) + predicates = append(predicates, andExpr.Predicates...) } else { predicates = append(predicates, rhsPred) } diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index b4f0a37914e..56d516332a0 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -743,10 +743,7 @@ func buildChildUpdOpForSetNull( updatedTable, updateExprs, fk, updatedTable.GetTableName(), nonLiteralUpdateInfo, false /* appendQualifier */) if compExpr != nil { - childWhereExpr = &sqlparser.AndExpr{ - Left: childWhereExpr, - Right: compExpr, - } + childWhereExpr = sqlparser.AndExpressions(childWhereExpr, compExpr) } parsedComments := getParsedCommentsForFkChecks(ctx) childUpdStmt := &sqlparser.Update{ @@ -847,13 +844,12 @@ func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, upda var predicate sqlparser.Expr = parentIsNullExpr var joinExpr sqlparser.Expr if matchedExpr == nil { - predicate = &sqlparser.AndExpr{ - Left: predicate, - Right: &sqlparser.IsExpr{ + predicate = sqlparser.AndExpressions( + predicate, + &sqlparser.IsExpr{ Left: sqlparser.NewColNameWithQualifier(pFK.ChildColumns[idx].String(), childTbl), Right: sqlparser.IsNotNullOp, - }, - } + }) joinExpr = &sqlparser.ComparisonExpr{ Operator: sqlparser.EqualOp, Left: sqlparser.NewColNameWithQualifier(pFK.ParentColumns[idx].String(), parentTbl), @@ -868,35 +864,33 @@ func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, upda Left: sqlparser.NewColNameWithQualifier(pFK.ParentColumns[idx].String(), parentTbl), Right: prefixedMatchExpr, } - predicate = &sqlparser.AndExpr{ - Left: predicate, - Right: &sqlparser.IsExpr{ + predicate = sqlparser.AndExpressions( + predicate, + &sqlparser.IsExpr{ Left: prefixedMatchExpr, Right: sqlparser.IsNotNullOp, - }, - } + }) } if idx == 0 { joinCond, whereCond = joinExpr, predicate continue } - joinCond = &sqlparser.AndExpr{Left: joinCond, Right: joinExpr} - whereCond = &sqlparser.AndExpr{Left: whereCond, Right: predicate} + joinCond = sqlparser.AndExpressions(joinCond, joinExpr) + whereCond = sqlparser.AndExpressions(whereCond, predicate) } - whereCond = &sqlparser.AndExpr{ - Left: whereCond, - Right: &sqlparser.NotExpr{ + whereCond = sqlparser.AndExpressions( + whereCond, + &sqlparser.NotExpr{ Expr: &sqlparser.ComparisonExpr{ Operator: sqlparser.NullSafeEqualOp, Left: notEqualColNames, Right: notEqualExprs, }, - }, - } + }) // add existing where condition on the update statement if updStmt.Where != nil { - whereCond = &sqlparser.AndExpr{Left: whereCond, Right: prefixColNames(ctx, childTbl, updStmt.Where.Expr)} + whereCond = sqlparser.AndExpressions(whereCond, prefixColNames(ctx, childTbl, updStmt.Where.Expr)) } return createSelectionOp(ctx, sqlparser.SelectExprs{sqlparser.NewAliasedExpr(sqlparser.NewIntLiteral("1"), "")}, @@ -959,7 +953,7 @@ func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updat joinCond = joinExpr continue } - joinCond = &sqlparser.AndExpr{Left: joinCond, Right: joinExpr} + joinCond = sqlparser.AndExpressions(joinCond, joinExpr) } var whereCond sqlparser.Expr diff --git a/go/vt/vtgate/planbuilder/predicate_rewrite_test.go b/go/vt/vtgate/planbuilder/predicate_rewrite_test.go index 4945c2bb7ff..2f262f75a7f 100644 --- a/go/vt/vtgate/planbuilder/predicate_rewrite_test.go +++ b/go/vt/vtgate/planbuilder/predicate_rewrite_test.go @@ -64,10 +64,10 @@ func (tc testCase) createPredicate(lvl int) sqlparser.Expr { Expr: tc.createPredicate(lvl + 1), } case AND: - return &sqlparser.AndExpr{ - Left: tc.createPredicate(lvl + 1), - Right: tc.createPredicate(lvl + 1), - } + return sqlparser.AndExpressions( + tc.createPredicate(lvl+1), + tc.createPredicate(lvl+1), + ) case OR: return &sqlparser.OrExpr{ Left: tc.createPredicate(lvl + 1), diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index ee12765e984..f38259735c5 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -864,13 +864,13 @@ func rewriteOrExpr(env *vtenv.Environment, cursor *sqlparser.Cursor, node *sqlpa // rewriteAndExpr rewrites AND expressions when either side is TRUE. func rewriteAndExpr(env *vtenv.Environment, cursor *sqlparser.Cursor, node *sqlparser.AndExpr) { - newNode := rewriteAndTrue(env, *node) + newNode := rewriteAndTrue(env, node) if newNode != nil { cursor.ReplaceAndRevisit(newNode) } } -func rewriteAndTrue(env *vtenv.Environment, andExpr sqlparser.AndExpr) sqlparser.Expr { +func rewriteAndTrue(env *vtenv.Environment, andExpr *sqlparser.AndExpr) sqlparser.Expr { // we are looking for the pattern `WHERE c = 1 AND 1 = 1` isTrue := func(subExpr sqlparser.Expr) bool { coll := env.CollationEnv().DefaultConnectionCharset() @@ -896,13 +896,18 @@ func rewriteAndTrue(env *vtenv.Environment, andExpr sqlparser.AndExpr) sqlparser return boolValue } - if isTrue(andExpr.Left) { - return andExpr.Right - } else if isTrue(andExpr.Right) { - return andExpr.Left + var remaining sqlparser.Exprs + for _, p := range andExpr.Predicates { + if !isTrue(p) { + remaining = append(remaining, p) + } } - return nil + if len(remaining) == len(andExpr.Predicates) { + return nil + } + + return sqlparser.AndExpressions(remaining...) } // handleComparisonExpr processes Comparison expressions, specifically for tuples with equal length and EqualOp operator. diff --git a/go/vt/vtgate/semantics/scoper.go b/go/vt/vtgate/semantics/scoper.go index 9d596d9ecd1..0c0dfe1fa34 100644 --- a/go/vt/vtgate/semantics/scoper.go +++ b/go/vt/vtgate/semantics/scoper.go @@ -309,10 +309,7 @@ func (s *scoper) createSpecialScopePostProjection(parent sqlparser.SQLNode) erro // at this stage, we don't store the actual dependencies, we only store the expressions. // only later will we walk the expression tree and figure out the deps. so, we need to create a // composite expression that contains all the expressions in the SELECTs that this UNION consists of - tableInfo.cols[i] = &sqlparser.AndExpr{ - Left: col, - Right: thisTableInfo.cols[i], - } + tableInfo.cols[i] = sqlparser.CreateAndExpr(col, thisTableInfo.cols[i]) } } diff --git a/go/vt/vtgate/semantics/semantic_table.go b/go/vt/vtgate/semantics/semantic_table.go index 6738546fe37..81667b2d898 100644 --- a/go/vt/vtgate/semantics/semantic_table.go +++ b/go/vt/vtgate/semantics/semantic_table.go @@ -911,7 +911,7 @@ func (st *SemTable) AndExpressions(exprs ...sqlparser.Expr) sqlparser.Expr { continue outer } } - result = &sqlparser.AndExpr{Left: result, Right: expr} + result = sqlparser.AndExpressions(result, expr) } return result } diff --git a/go/vt/vtgate/simplifier/expression_simplifier.go b/go/vt/vtgate/simplifier/expression_simplifier.go index b64402cfaac..2a3312533bc 100644 --- a/go/vt/vtgate/simplifier/expression_simplifier.go +++ b/go/vt/vtgate/simplifier/expression_simplifier.go @@ -20,6 +20,8 @@ import ( "fmt" "strconv" + "vitess.io/vitess/go/slice" + "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/sqlparser" ) @@ -97,7 +99,10 @@ func (s *shrinker) fillQueue() bool { before := len(s.queue) switch e := s.orig.(type) { case *sqlparser.AndExpr: - s.queue = append(s.queue, e.Left, e.Right) + addThese := slice.Map(e.Predicates, func(e sqlparser.Expr) sqlparser.SQLNode { + return e + }) + s.queue = append(s.queue, addThese...) case *sqlparser.OrExpr: s.queue = append(s.queue, e.Left, e.Right) case *sqlparser.ComparisonExpr: