diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 7ddb559ebf9..9d01e87f3c1 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -720,6 +720,11 @@ type ( // IndexType is the type of index in a DDL statement IndexType int8 + + WhereAble interface { + AddWhere(e Expr) + GetWherePredicate() Expr + } ) var _ OrderAndLimit = (*Select)(nil) diff --git a/go/vt/vtctl/workflow/materializer.go b/go/vt/vtctl/workflow/materializer.go index f0171f31cab..7719bc81318 100644 --- a/go/vt/vtctl/workflow/materializer.go +++ b/go/vt/vtctl/workflow/materializer.go @@ -234,10 +234,10 @@ func (mz *materializer) generateBinlogSources(ctx context.Context, targetShard * Name: sqlparser.NewIdentifierCI("in_keyrange"), Exprs: subExprs, } - addFilter(sel, inKeyRange) + sel.AddWhere(inKeyRange) } if tenantClause != nil { - addFilter(sel, *tenantClause) + sel.AddWhere(*tenantClause) } rule.Filter = sqlparser.String(sel) bls.Filter.Rules = append(bls.Filter.Rules, rule) diff --git a/go/vt/vtctl/workflow/traffic_switcher.go b/go/vt/vtctl/workflow/traffic_switcher.go index 34e1e4e4329..b10f7a9645e 100644 --- a/go/vt/vtctl/workflow/traffic_switcher.go +++ b/go/vt/vtctl/workflow/traffic_switcher.go @@ -968,7 +968,7 @@ func (ts *trafficSwitcher) addTenantFilter(ctx context.Context, filter string) ( if !ok { return "", fmt.Errorf("unrecognized statement: %s", filter) } - addFilter(sel, *tenantClause) + sel.AddWhere(*tenantClause) filter = sqlparser.String(sel) return filter, nil } diff --git a/go/vt/vtctl/workflow/utils.go b/go/vt/vtctl/workflow/utils.go index 374d96396f2..5abe1689084 100644 --- a/go/vt/vtctl/workflow/utils.go +++ b/go/vt/vtctl/workflow/utils.go @@ -793,23 +793,6 @@ func LegacyBuildTargets(ctx context.Context, ts *topo.Server, tmc tmclient.Table }, nil } -func addFilter(sel *sqlparser.Select, filter sqlparser.Expr) { - if sel.Where != nil { - sel.Where = &sqlparser.Where{ - Type: sqlparser.WhereClause, - Expr: &sqlparser.AndExpr{ - Left: filter, - Right: sel.Where.Expr, - }, - } - } else { - sel.Where = &sqlparser.Where{ - Type: sqlparser.WhereClause, - Expr: filter, - } - } -} - func getTenantClause(vrOptions *vtctldatapb.WorkflowOptions, targetVSchema *vindexes.KeyspaceSchema, parser *sqlparser.Parser) (*sqlparser.Expr, error) { if vrOptions.TenantId == "" { diff --git a/go/vt/vtctl/workflow/vexec/query_planner.go b/go/vt/vtctl/workflow/vexec/query_planner.go index de052efce8c..9d16dc72f55 100644 --- a/go/vt/vtctl/workflow/vexec/query_planner.go +++ b/go/vt/vtctl/workflow/vexec/query_planner.go @@ -354,10 +354,7 @@ func (planner *VReplicationLogQueryPlanner) planSelect(sel *sqlparser.Select) (Q case nil: targetWhere.Expr = expr default: - targetWhere.Expr = &sqlparser.AndExpr{ - Left: expr, - Right: where.Expr, - } + targetWhere.Expr = sqlparser.CreateAndExpr(expr, where.Expr) } sel.Where = targetWhere @@ -408,10 +405,7 @@ func addDefaultWheres(planner QueryPlanner, where *sqlparser.Where) *sqlparser.W Expr: expr, } default: - newWhere.Expr = &sqlparser.AndExpr{ - Left: newWhere.Expr, - Right: expr, - } + newWhere.Expr = sqlparser.CreateAndExpr(newWhere.Expr, expr) } } @@ -424,10 +418,7 @@ func addDefaultWheres(planner QueryPlanner, where *sqlparser.Where) *sqlparser.W Right: sqlparser.NewStrLiteral(params.Workflow), } - newWhere.Expr = &sqlparser.AndExpr{ - Left: newWhere.Expr, - Right: expr, - } + newWhere.Expr = sqlparser.CreateAndExpr(newWhere.Expr, expr) } return newWhere diff --git a/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go b/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go index 9bbc98ca2bd..55988640fc3 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go +++ b/go/vt/vttablet/tabletserver/vstreamer/planbuilder.go @@ -518,7 +518,7 @@ func (plan *Plan) analyzeWhere(vschema *localVSchema, where *sqlparser.Where) er if where == nil { return nil } - exprs := splitAndExpression(nil, where.Expr) + exprs := sqlparser.SplitAndExpression(nil, where.Expr) for _, expr := range exprs { switch expr := expr.(type) { case *sqlparser.ComparisonExpr: @@ -595,21 +595,6 @@ func (plan *Plan) analyzeWhere(vschema *localVSchema, where *sqlparser.Where) er return nil } -// splitAndExpression breaks up the Expr into AND-separated conditions -// and appends them to filters, which can be shuffled and recombined -// as needed. -func splitAndExpression(filters []sqlparser.Expr, node sqlparser.Expr) []sqlparser.Expr { - if node == nil { - return filters - } - switch node := node.(type) { - case *sqlparser.AndExpr: - filters = splitAndExpression(filters, node.Left) - return splitAndExpression(filters, node.Right) - } - return append(filters, node) -} - func (plan *Plan) analyzeExprs(vschema *localVSchema, selExprs sqlparser.SelectExprs) error { if _, ok := selExprs[0].(*sqlparser.StarExpr); !ok { for _, expr := range selExprs { diff --git a/go/vt/wrangler/materializer.go b/go/vt/wrangler/materializer.go index cc7ba3f1603..1974b35e863 100644 --- a/go/vt/wrangler/materializer.go +++ b/go/vt/wrangler/materializer.go @@ -1405,21 +1405,7 @@ func (mz *materializer) generateInserts(ctx context.Context, sourceShards []*top Name: sqlparser.NewIdentifierCI("in_keyrange"), Exprs: subExprs, } - if sel.Where != nil { - sel.Where = &sqlparser.Where{ - Type: sqlparser.WhereClause, - Expr: &sqlparser.AndExpr{ - Left: inKeyRange, - Right: sel.Where.Expr, - }, - } - } else { - sel.Where = &sqlparser.Where{ - Type: sqlparser.WhereClause, - Expr: inKeyRange, - } - } - + sel.AddWhere(inKeyRange) filter = sqlparser.String(sel) } diff --git a/go/vt/wrangler/vdiff.go b/go/vt/wrangler/vdiff.go index 4caad42ce1f..ca83e8c57fb 100644 --- a/go/vt/wrangler/vdiff.go +++ b/go/vt/wrangler/vdiff.go @@ -1439,16 +1439,13 @@ func removeKeyrange(where *sqlparser.Where) *sqlparser.Where { func removeExprKeyrange(node sqlparser.Expr) sqlparser.Expr { switch node := node.(type) { case *sqlparser.AndExpr: - if isFuncKeyrange(node.Left) { - return removeExprKeyrange(node.Right) - } - if isFuncKeyrange(node.Right) { - return removeExprKeyrange(node.Left) - } - return &sqlparser.AndExpr{ - Left: removeExprKeyrange(node.Left), - Right: removeExprKeyrange(node.Right), + var keep sqlparser.Exprs + for _, p := range node.Predicates { + if !isFuncKeyrange(p) { + keep = append(keep, removeExprKeyrange(p)) + } } + return sqlparser.CreateAndExpr(keep...) } return node } diff --git a/go/vt/wrangler/vexec_plan.go b/go/vt/wrangler/vexec_plan.go index 76b2d0fe732..1165add5d4b 100644 --- a/go/vt/wrangler/vexec_plan.go +++ b/go/vt/wrangler/vexec_plan.go @@ -158,7 +158,7 @@ func (vx *vexec) buildPlan(ctx context.Context) (plan *vexecPlan, err error) { case *sqlparser.Insert: plan, err = vx.buildInsertPlan(ctx, vx.planner, stmt) case *sqlparser.Select: - plan, err = vx.buildSelectPlan(ctx, vx.planner, stmt) + plan, err = vx.buildSelectPlan(vx.planner, stmt) default: return nil, fmt.Errorf("query not supported by vexec: %s", sqlparser.String(stmt)) } @@ -166,12 +166,12 @@ func (vx *vexec) buildPlan(ctx context.Context) (plan *vexecPlan, err error) { } // analyzeWhereEqualsColumns identifies column names in a WHERE clause that have a comparison expression -func (vx *vexec) analyzeWhereEqualsColumns(where *sqlparser.Where) []string { +func (vx *vexec) analyzeWhereEqualsColumns(expr sqlparser.Expr) []string { var cols []string - if where == nil { + if expr == nil { return cols } - exprs := sqlparser.SplitAndExpression(nil, where.Expr) + exprs := sqlparser.SplitAndExpression(nil, expr) for _, expr := range exprs { switch expr := expr.(type) { case *sqlparser.ComparisonExpr: @@ -185,8 +185,8 @@ func (vx *vexec) analyzeWhereEqualsColumns(where *sqlparser.Where) []string { } // addDefaultWheres modifies the query to add, if appropriate, the workflow and DB-name column modifiers -func (vx *vexec) addDefaultWheres(planner vexecPlanner, where *sqlparser.Where) *sqlparser.Where { - cols := vx.analyzeWhereEqualsColumns(where) +func (vx *vexec) addDefaultWheres(planner vexecPlanner, stmt sqlparser.WhereAble) { + cols := vx.analyzeWhereEqualsColumns(stmt.GetWherePredicate()) var hasDBName, hasWorkflow bool plannerParams := planner.params() for _, col := range cols { @@ -196,24 +196,13 @@ func (vx *vexec) addDefaultWheres(planner vexecPlanner, where *sqlparser.Where) hasWorkflow = true } } - newWhere := where if !hasDBName { expr := &sqlparser.ComparisonExpr{ Left: &sqlparser.ColName{Name: sqlparser.NewIdentifierCI(plannerParams.dbNameColumn)}, Operator: sqlparser.EqualOp, Right: sqlparser.NewStrLiteral(vx.primaries[0].DbName()), } - if newWhere == nil { - newWhere = &sqlparser.Where{ - Type: sqlparser.WhereClause, - Expr: expr, - } - } else { - newWhere.Expr = &sqlparser.AndExpr{ - Left: newWhere.Expr, - Right: expr, - } - } + stmt.AddWhere(expr) } if !hasWorkflow && vx.workflow != "" { expr := &sqlparser.ComparisonExpr{ @@ -221,12 +210,8 @@ func (vx *vexec) addDefaultWheres(planner vexecPlanner, where *sqlparser.Where) Operator: sqlparser.EqualOp, Right: sqlparser.NewStrLiteral(vx.workflow), } - newWhere.Expr = &sqlparser.AndExpr{ - Left: newWhere.Expr, - Right: expr, - } + stmt.AddWhere(expr) } - return newWhere } // buildUpdatePlan builds a plan for an UPDATE query @@ -262,7 +247,7 @@ func (vx *vexec) buildUpdatePlan(ctx context.Context, planner vexecPlanner, upd return nil, fmt.Errorf("Query must match one of these templates: %s", strings.Join(templates, "; ")) } } - upd.Where = vx.addDefaultWheres(planner, upd.Where) + vx.addDefaultWheres(planner, upd) buf := sqlparser.NewTrackedBuffer(nil) buf.Myprintf("%v", upd) @@ -285,7 +270,7 @@ func (vx *vexec) buildDeletePlan(ctx context.Context, planner vexecPlanner, del return nil, fmt.Errorf("unsupported construct: %v", sqlparser.String(del)) } - del.Where = vx.addDefaultWheres(planner, del.Where) + vx.addDefaultWheres(planner, del) buf := sqlparser.NewTrackedBuffer(nil) buf.Myprintf("%v", del) @@ -325,8 +310,8 @@ func (vx *vexec) buildInsertPlan(ctx context.Context, planner vexecPlanner, ins } // buildSelectPlan builds a plan for a SELECT query -func (vx *vexec) buildSelectPlan(ctx context.Context, planner vexecPlanner, sel *sqlparser.Select) (*vexecPlan, error) { - sel.Where = vx.addDefaultWheres(planner, sel.Where) +func (vx *vexec) buildSelectPlan(planner vexecPlanner, sel *sqlparser.Select) (*vexecPlan, error) { + vx.addDefaultWheres(planner, sel) buf := sqlparser.NewTrackedBuffer(nil) buf.Myprintf("%v", sel)