Skip to content

Commit

Permalink
refactor: refactor to allow hash joins in aggregation pushing
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <[email protected]>
  • Loading branch information
GuptaManan100 authored and systay committed Dec 7, 2023
1 parent 2dfe734 commit c292fe5
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 60 deletions.
28 changes: 16 additions & 12 deletions go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,9 @@ func pushAggregationThroughApplyJoin(ctx *plancontext.PlanningContext, rootAggr
tableID: TableID(join.RHS),
}

joinColumns, output, err := splitAggrColumnsToLeftAndRight(ctx, rootAggr, join, lhs, rhs)
columns := &applyJoinColumns{}
output, err := splitAggrColumnsToLeftAndRight(ctx, rootAggr, join, join.LeftJoin, columns, lhs, rhs)
join.JoinColumns = columns
if err != nil {
// if we get this error, we just abort the splitting and fall back on simpler ways of solving the same query
if errors.Is(err, errAbortAggrPushing) {
Expand All @@ -371,14 +373,15 @@ func pushAggregationThroughApplyJoin(ctx *plancontext.PlanningContext, rootAggr
}

groupingJCs := splitGroupingToLeftAndRight(ctx, rootAggr, lhs, rhs)
joinColumns = append(joinColumns, groupingJCs...)
for _, col := range groupingJCs {
join.JoinColumns.add(col)
}

// We need to add any columns coming from the lhs of the join to the group by on that side
// If we don't, the LHS will not be able to return the column, and it can't be used to send down to the RHS
addColumnsFromLHSInJoinPredicates(ctx, rootAggr, join, lhs)

join.LHS, join.RHS = lhs.pushed, rhs.pushed
join.JoinColumns = joinColumns

if !rootAggr.Original {
// we only keep the root aggregation, if this aggregator was created
Expand All @@ -394,7 +397,7 @@ func pushAggregationThroughApplyJoin(ctx *plancontext.PlanningContext, rootAggr
var errAbortAggrPushing = fmt.Errorf("abort aggregation pushing")

func addColumnsFromLHSInJoinPredicates(ctx *plancontext.PlanningContext, rootAggr *Aggregator, join *ApplyJoin, lhs *joinPusher) {
for _, pred := range join.JoinPredicates {
for _, pred := range join.JoinPredicates.columns {
for _, bve := range pred.LHSExprs {
expr := bve.Expr
wexpr, err := rootAggr.QP.GetSimplifiedExpr(ctx, expr)
Expand Down Expand Up @@ -466,17 +469,19 @@ func splitGroupingToLeftAndRight(ctx *plancontext.PlanningContext, rootAggr *Agg
func splitAggrColumnsToLeftAndRight(
ctx *plancontext.PlanningContext,
aggregator *Aggregator,
join *ApplyJoin,
join Operator,
leftJoin bool,
columns joinColumns,
lhs, rhs *joinPusher,
) ([]applyJoinColumn, Operator, error) {
) (Operator, error) {
proj := newAliasedProjection(join)
proj.FromAggr = true
builder := &aggBuilder{
lhs: lhs,
rhs: rhs,
joinColumns: &applyJoinColumns{},
joinColumns: columns,
proj: proj,
outerJoin: join.LeftJoin,
outerJoin: leftJoin,
}

canPushDistinctAggr, distinctExpr := checkIfWeCanPush(ctx, aggregator)
Expand All @@ -485,7 +490,7 @@ func splitAggrColumnsToLeftAndRight(
// We keep note of the distinct aggregation expression to be used later for ordering.
if !canPushDistinctAggr {
aggregator.DistinctExpr = distinctExpr
return nil, nil, errAbortAggrPushing
return nil, errAbortAggrPushing
}

outer:
Expand All @@ -495,16 +500,15 @@ outer:
if aggr.ColOffset == colIdx {
err := builder.handleAggr(ctx, aggr)
if err != nil {
return nil, nil, err
return nil, err
}
continue outer
}
}
builder.proj.addUnexploredExpr(col, col.Expr)
}
columns := builder.joinColumns.(*applyJoinColumns)

return columns.columns, builder.proj, nil
return builder.proj, nil
}

func coalesceFunc(e sqlparser.Expr) sqlparser.Expr {
Expand Down
22 changes: 22 additions & 0 deletions go/vt/vtgate/planbuilder/operators/aggregation_pushing_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ type (
applyJoinColumns struct {
columns []applyJoinColumn
}

hashJoinColumns struct {
columns []hashJoinColumn
}
)

func (jc *applyJoinColumns) addLeft(expr sqlparser.Expr) {
Expand All @@ -70,6 +74,24 @@ func (jc *applyJoinColumns) addRight(expr sqlparser.Expr) {
})
}

// add appends col to the end of the wrapped columns slice.
func (jc *applyJoinColumns) add(col applyJoinColumn) {
jc.columns = append(jc.columns, col)
}

// addLeft appends a hashJoinColumn for expr whose side is set to Left.
func (jc *hashJoinColumns) addLeft(expr sqlparser.Expr) {
jc.columns = append(jc.columns, hashJoinColumn{
expr: expr,
side: Left,
})
}

// addRight appends a hashJoinColumn for expr whose side is set to Right.
func (jc *hashJoinColumns) addRight(expr sqlparser.Expr) {
jc.columns = append(jc.columns, hashJoinColumn{
expr: expr,
side: Right,
})
}

func (ab *aggBuilder) leftCountStar(ctx *plancontext.PlanningContext) *sqlparser.AliasedExpr {
ae, created := ab.lhs.countStar(ctx)
if created {
Expand Down
66 changes: 34 additions & 32 deletions go/vt/vtgate/planbuilder/operators/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ type (
Predicate sqlparser.Expr

// JoinColumns keeps track of what AST expression is represented in the Columns array
JoinColumns []applyJoinColumn
JoinColumns *applyJoinColumns

// JoinPredicates are join predicates that have been broken up into left hand side and right hand side parts.
JoinPredicates []applyJoinColumn
JoinPredicates *applyJoinColumns

// ExtraVars are columns we need to copy from left to right not needed by any predicates or projections,
// these are needed by other operators further down the right hand side of the join
Expand Down Expand Up @@ -87,11 +87,13 @@ type (

func NewApplyJoin(lhs, rhs Operator, predicate sqlparser.Expr, leftOuterJoin bool) *ApplyJoin {
return &ApplyJoin{
LHS: lhs,
RHS: rhs,
Vars: map[string]int{},
Predicate: predicate,
LeftJoin: leftOuterJoin,
LHS: lhs,
RHS: rhs,
Vars: map[string]int{},
Predicate: predicate,
LeftJoin: leftOuterJoin,
JoinColumns: &applyJoinColumns{},
JoinPredicates: &applyJoinColumns{},
}
}

Expand All @@ -101,8 +103,8 @@ func (aj *ApplyJoin) Clone(inputs []Operator) Operator {
kopy.LHS = inputs[0]
kopy.RHS = inputs[1]
kopy.Columns = slices.Clone(aj.Columns)
kopy.JoinColumns = slices.Clone(aj.JoinColumns)
kopy.JoinPredicates = slices.Clone(aj.JoinPredicates)
kopy.JoinColumns = &applyJoinColumns{columns: slices.Clone(aj.JoinColumns.columns)}
kopy.JoinPredicates = &applyJoinColumns{columns: slices.Clone(aj.JoinPredicates.columns)}
kopy.Vars = maps.Clone(aj.Vars)
kopy.Predicate = sqlparser.CloneExpr(aj.Predicate)
kopy.ExtraLHSVars = slices.Clone(aj.ExtraLHSVars)
Expand Down Expand Up @@ -151,7 +153,7 @@ func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sql
aj.Predicate = ctx.SemTable.AndExpressions(expr, aj.Predicate)

col := breakExpressionInLHSandRHSForApplyJoin(ctx, expr, TableID(aj.LHS))
aj.JoinPredicates = append(aj.JoinPredicates, col)
aj.JoinPredicates.add(col)
rhs := aj.RHS.AddPredicate(ctx, col.RHSExpr)
aj.RHS = rhs
}
Expand All @@ -162,7 +164,9 @@ func (aj *ApplyJoin) pushColRight(ctx *plancontext.PlanningContext, e *sqlparser
}

func (aj *ApplyJoin) GetColumns(*plancontext.PlanningContext) []*sqlparser.AliasedExpr {
return slice.Map(aj.JoinColumns, joinColumnToAliasedExpr)
return slice.Map(aj.JoinColumns.columns, func(from applyJoinColumn) *sqlparser.AliasedExpr {
return aeWrap(from.Original)
})
}

func (aj *ApplyJoin) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.SelectExprs {
Expand All @@ -173,10 +177,6 @@ func (aj *ApplyJoin) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy {
return aj.LHS.GetOrdering(ctx)
}

func joinColumnToAliasedExpr(c applyJoinColumn) *sqlparser.AliasedExpr {
return aeWrap(c.Original)
}

// joinColumnToExpr returns the Original expression stored on the given join column.
func joinColumnToExpr(column applyJoinColumn) sqlparser.Expr {
return column.Original
}
Expand Down Expand Up @@ -205,12 +205,14 @@ func (aj *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, orig *sq
return
}

func (aj *ApplyJoin) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, _ bool) int {
offset, found := canReuseColumn(ctx, aj.JoinColumns, expr, joinColumnToExpr)
if !found {
return -1
func applyJoinCompare(ctx *plancontext.PlanningContext, expr sqlparser.Expr) func(e applyJoinColumn) bool {
return func(e applyJoinColumn) bool {
return ctx.SemTable.EqualsExprWithDeps(e.Original, expr)
}
return offset
}

// FindCol returns the index of the first entry in aj.JoinColumns.columns that
// satisfies the predicate built by applyJoinCompare for expr, or -1 when no
// entry matches (the slices.IndexFunc contract). The trailing bool parameter
// is ignored.
func (aj *ApplyJoin) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, _ bool) int {
return slices.IndexFunc(aj.JoinColumns.columns, applyJoinCompare(ctx, expr))
}

func (aj *ApplyJoin) AddColumn(
Expand All @@ -226,13 +228,13 @@ func (aj *ApplyJoin) AddColumn(
}
}
col := aj.getJoinColumnFor(ctx, expr, expr.Expr, groupBy)
offset := len(aj.JoinColumns)
aj.JoinColumns = append(aj.JoinColumns, col)
offset := len(aj.JoinColumns.columns)
aj.JoinColumns.add(col)
return offset
}

func (aj *ApplyJoin) planOffsets(ctx *plancontext.PlanningContext) Operator {
for _, col := range aj.JoinColumns {
for _, col := range aj.JoinColumns.columns {
// Read the type description for applyJoinColumn to understand the following code
for _, lhsExpr := range col.LHSExprs {
offset := aj.LHS.AddColumn(ctx, true, col.GroupBy, aeWrap(lhsExpr.Expr))
Expand All @@ -249,7 +251,7 @@ func (aj *ApplyJoin) planOffsets(ctx *plancontext.PlanningContext) Operator {
}
}

for _, col := range aj.JoinPredicates {
for _, col := range aj.JoinPredicates.columns {
for _, lhsExpr := range col.LHSExprs {
offset := aj.LHS.AddColumn(ctx, true, false, aeWrap(lhsExpr.Expr))
aj.Vars[lhsExpr.Name] = offset
Expand All @@ -270,7 +272,7 @@ func (aj *ApplyJoin) addOffset(offset int) {

func (aj *ApplyJoin) ShortDescription() string {
pred := sqlparser.String(aj.Predicate)
columns := slice.Map(aj.JoinColumns, func(from applyJoinColumn) string {
columns := slice.Map(aj.JoinColumns.columns, func(from applyJoinColumn) string {
return sqlparser.String(from.Original)
})
firstPart := fmt.Sprintf("on %s columns: %s", pred, strings.Join(columns, ", "))
Expand All @@ -283,14 +285,14 @@ func (aj *ApplyJoin) ShortDescription() string {
}

func (aj *ApplyJoin) isColNameMovedFromL2R(bindVarName string) bool {
for _, jc := range aj.JoinColumns {
for _, jc := range aj.JoinColumns.columns {
for _, bve := range jc.LHSExprs {
if bve.Name == bindVarName {
return true
}
}
}
for _, jp := range aj.JoinPredicates {
for _, jp := range aj.JoinPredicates.columns {
for _, bve := range jp.LHSExprs {
if bve.Name == bindVarName {
return true
Expand All @@ -308,7 +310,7 @@ func (aj *ApplyJoin) isColNameMovedFromL2R(bindVarName string) bool {
// findOrAddColNameBindVarName goes through the JoinColumns and looks for the given colName coming from the LHS of the join
// and returns the argument name if found. if it's not found, a new applyJoinColumn passing this through will be added
func (aj *ApplyJoin) findOrAddColNameBindVarName(ctx *plancontext.PlanningContext, col *sqlparser.ColName) (string, error) {
for i, thisCol := range aj.JoinColumns {
for i, thisCol := range aj.JoinColumns.columns {
idx := slices.IndexFunc(thisCol.LHSExprs, func(e BindVarExpr) bool {
return ctx.SemTable.EqualsExpr(e.Expr, col)
})
Expand All @@ -320,12 +322,12 @@ func (aj *ApplyJoin) findOrAddColNameBindVarName(ctx *plancontext.PlanningContex
expr := thisCol.LHSExprs[idx]
bvname := ctx.GetReservedArgumentFor(expr.Expr)
expr.Name = bvname
aj.JoinColumns[i].LHSExprs[idx] = expr
aj.JoinColumns.columns[i].LHSExprs[idx] = expr
}
return thisCol.LHSExprs[idx].Name, nil
}
}
for _, thisCol := range aj.JoinPredicates {
for _, thisCol := range aj.JoinPredicates.columns {
idx := slices.IndexFunc(thisCol.LHSExprs, func(e BindVarExpr) bool {
return ctx.SemTable.EqualsExpr(e.Expr, col)
})
Expand Down Expand Up @@ -354,10 +356,10 @@ func (a *ApplyJoin) LHSColumnsNeeded(ctx *plancontext.PlanningContext) (needed s
f := func(from BindVarExpr) sqlparser.Expr {
return from.Expr
}
for _, jc := range a.JoinColumns {
for _, jc := range a.JoinColumns.columns {
needed = append(needed, slice.Map(jc.LHSExprs, f)...)
}
for _, jc := range a.JoinPredicates {
for _, jc := range a.JoinPredicates.columns {
needed = append(needed, slice.Map(jc.LHSExprs, f)...)
}
needed = append(needed, slice.Map(a.ExtraLHSVars, f)...)
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtgate/planbuilder/operators/hash_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ type (
}

hashJoinColumn struct {
typ joinSide
side joinSide
expr sqlparser.Expr
}

Expand Down Expand Up @@ -136,7 +136,7 @@ func (hj *HashJoin) planOffsets(ctx *plancontext.PlanningContext) Operator {
var column *ProjExpr
var pureOffset bool

switch in.typ {
switch in.side {
case Unknown:
column, pureOffset = hj.addColumn(ctx, in.expr)
case Left:
Expand Down Expand Up @@ -187,7 +187,7 @@ func (hj *HashJoin) ShortDescription() string {

if len(hj.columns) > 0 {
cols := slice.Map(hj.columns, func(from hashJoinColumn) (result string) {
switch from.typ {
switch from.side {
case Unknown:
result = "U"
case Left:
Expand Down
10 changes: 4 additions & 6 deletions go/vt/vtgate/planbuilder/operators/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,27 +473,25 @@ func (p *Projection) compactWithJoin(ctx *plancontext.PlanningContext, join *App
}

var newColumns []int
var newColumnsAST []applyJoinColumn
newColumnsAST := &applyJoinColumns{}
for _, col := range ap {
switch colInfo := col.Info.(type) {
case Offset:
newColumns = append(newColumns, join.Columns[colInfo])
newColumnsAST = append(newColumnsAST, join.JoinColumns[colInfo])
newColumnsAST.add(join.JoinColumns.columns[colInfo])
case nil:
if !ctx.SemTable.EqualsExprWithDeps(col.EvalExpr, col.ColExpr) {
// the inner expression is different from what we are presenting to the outside - this means we need to evaluate
return p, NoRewrite
}
offset := slices.IndexFunc(join.JoinColumns, func(jc applyJoinColumn) bool {
return ctx.SemTable.EqualsExprWithDeps(jc.Original, col.ColExpr)
})
offset := slices.IndexFunc(join.JoinColumns.columns, applyJoinCompare(ctx, col.ColExpr))
if offset < 0 {
return p, NoRewrite
}
if len(join.Columns) > 0 {
newColumns = append(newColumns, join.Columns[offset])
}
newColumnsAST = append(newColumnsAST, join.JoinColumns[offset])
newColumnsAST.add(join.JoinColumns.columns[offset])
default:
return p, NoRewrite
}
Expand Down
Loading

0 comments on commit c292fe5

Please sign in to comment.