diff --git a/pkg/runtime/ast/expression.go b/pkg/runtime/ast/expression.go index f2c245cf..6a40242f 100644 --- a/pkg/runtime/ast/expression.go +++ b/pkg/runtime/ast/expression.go @@ -48,6 +48,7 @@ type ExpressionNode interface { Node Restorer Mode() ExpressionMode + Clone() ExpressionNode } type LogicalExpressionNode struct { @@ -85,6 +86,14 @@ func (l *LogicalExpressionNode) Mode() ExpressionMode { return EmLogical } +func (l *LogicalExpressionNode) Clone() ExpressionNode { + return &LogicalExpressionNode{ + Op: l.Op, + Left: l.Left.Clone(), + Right: l.Right.Clone(), + } +} + type NotExpressionNode struct { E ExpressionNode } @@ -105,6 +114,12 @@ func (n *NotExpressionNode) Mode() ExpressionMode { return EmNot } +func (n *NotExpressionNode) Clone() ExpressionNode { + return &NotExpressionNode{ + E: n.E.Clone(), + } +} + type PredicateExpressionNode struct { P PredicateNode } @@ -123,3 +138,9 @@ func (a *PredicateExpressionNode) Restore(flag RestoreFlag, sb *strings.Builder, func (a *PredicateExpressionNode) Mode() ExpressionMode { return EmPredicate } + +func (a *PredicateExpressionNode) Clone() ExpressionNode { + return &PredicateExpressionNode{ + P: a.P.Clone(), + } +} diff --git a/pkg/runtime/ast/expression_atom.go b/pkg/runtime/ast/expression_atom.go index a3da7f44..8d14e114 100644 --- a/pkg/runtime/ast/expression_atom.go +++ b/pkg/runtime/ast/expression_atom.go @@ -61,6 +61,7 @@ type ExpressionAtom interface { Node Restorer phantom() expressionAtomPhantom + Clone() ExpressionAtom } type IntervalExpressionAtom struct { @@ -104,6 +105,13 @@ func (ie *IntervalExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } +func (ie *IntervalExpressionAtom) Clone() ExpressionAtom { + return &IntervalExpressionAtom{ + Unit: ie.Unit, + Value: ie.Value.Clone(), + } +} + type SystemVariableExpressionAtom struct { Name string System bool @@ -146,6 +154,14 @@ func (sy *SystemVariableExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } +func (sy *SystemVariableExpressionAtom) Clone() ExpressionAtom { + return &SystemVariableExpressionAtom{ + Name: sy.Name, + System: sy.System, + Global: sy.Global, + } +} + type UnaryExpressionAtom struct { Operator string Inner Node // ExpressionAtom or *BinaryComparisonPredicateNode @@ -185,6 +201,10 @@ func (u *UnaryExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } +func (u *UnaryExpressionAtom) Clone() ExpressionAtom { + panic("implement me") +} + type ConstantExpressionAtom struct { Inner interface{} } @@ -202,6 +222,12 @@ func (c *ConstantExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } +func (c *ConstantExpressionAtom) Clone() ExpressionAtom { + return &ConstantExpressionAtom{ + Inner: c.Inner, + } +} + func constant2string(value interface{}) string { switch v := value.(type) { case proto.Null: @@ -300,6 +326,12 @@ func (c ColumnNameExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } +func (c ColumnNameExpressionAtom) Clone() ExpressionAtom { + res := make(ColumnNameExpressionAtom, len(c)) + copy(res, c) + return res +} + type VariableExpressionAtom int func (v VariableExpressionAtom) Accept(visitor Visitor) (interface{}, error) { @@ -324,6 +356,10 @@ func (v VariableExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } +func (v VariableExpressionAtom) Clone() ExpressionAtom { + return v +} + type MathExpressionAtom struct { Left ExpressionAtom Operator string @@ -358,6 +394,14 @@ func (m *MathExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } +func (m *MathExpressionAtom) Clone() ExpressionAtom { + return &MathExpressionAtom{ + Left: m.Left.Clone(), + Operator: m.Operator, + Right: m.Right.Clone(), + } +} + type NestedExpressionAtom struct { First ExpressionNode } @@ -380,6 +424,12 @@ func (n *NestedExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } +func (n *NestedExpressionAtom) Clone() ExpressionAtom { + return &NestedExpressionAtom{ + First: n.First.Clone(), + } +} + type FunctionCallExpressionAtom struct { F Node // *Function OR *AggrFunction OR *CaseWhenElseFunction OR *CastFunction } @@ -413,3 +463,7 @@ func (f *FunctionCallExpressionAtom) Restore(flag RestoreFlag, sb *strings.Build func (f *FunctionCallExpressionAtom) phantom() expressionAtomPhantom { return expressionAtomPhantom{} } + +func (f *FunctionCallExpressionAtom) Clone() ExpressionAtom { + panic("implement me") +} diff --git a/pkg/runtime/ast/predicate.go b/pkg/runtime/ast/predicate.go index a79e4109..265cc3f1 100644 --- a/pkg/runtime/ast/predicate.go +++ b/pkg/runtime/ast/predicate.go @@ -44,6 +44,7 @@ type PredicateNode interface { Node Restorer phantom() predicateNodePhantom + Clone() PredicateNode } type LikePredicateNode struct { @@ -81,6 +82,14 @@ func (l *LikePredicateNode) phantom() predicateNodePhantom { return predicateNodePhantom{} } +func (l *LikePredicateNode) Clone() PredicateNode { + return &LikePredicateNode{ + Not: l.Not, + Left: l.Left.Clone(), + Right: l.Right.Clone(), + } +} + type RegexpPredicationNode struct { Left PredicateNode Right PredicateNode @@ -111,6 +120,14 @@ func (rp *RegexpPredicationNode) phantom() predicateNodePhantom { return predicateNodePhantom{} } +func (rp *RegexpPredicationNode) Clone() PredicateNode { + return &RegexpPredicationNode{ + Left: rp.Left.Clone(), + Right: rp.Right.Clone(), + Not: rp.Not, + } +} + type BinaryComparisonPredicateNode struct { Left PredicateNode Right PredicateNode @@ -158,6 +175,14 @@ func (b *BinaryComparisonPredicateNode) phantom() predicateNodePhantom { return predicateNodePhantom{} } +func (b *BinaryComparisonPredicateNode) Clone() PredicateNode { + return &BinaryComparisonPredicateNode{ + Left: b.Left.Clone(), + Right: b.Right.Clone(), + Op: b.Op, + } +} + type AtomPredicateNode struct { A ExpressionAtom } @@ -185,6 +210,12 @@ func (a *AtomPredicateNode) phantom() predicateNodePhantom { return predicateNodePhantom{} } +func (a *AtomPredicateNode) Clone() PredicateNode { + return &AtomPredicateNode{ + A: a.A.Clone(), + } +} + type BetweenPredicateNode struct { Not bool Key PredicateNode @@ -222,6 +253,15 @@ func (b *BetweenPredicateNode) phantom() predicateNodePhantom { return predicateNodePhantom{} } +func (b *BetweenPredicateNode) Clone() PredicateNode { + return &BetweenPredicateNode{ + Not: b.Not, + Key: b.Key.Clone(), + Left: b.Left.Clone(), + Right: b.Right.Clone(), + } +} + type InPredicateNode struct { Not bool P PredicateNode @@ -264,3 +304,16 @@ func (ip *InPredicateNode) Restore(flag RestoreFlag, sb *strings.Builder, args * func (ip *InPredicateNode) phantom() predicateNodePhantom { return predicateNodePhantom{} } + +func (ip *InPredicateNode) Clone() PredicateNode { + e := make([]ExpressionNode, 0, len(ip.E)) + for _, node := range ip.E { + e = append(e, node.Clone()) + } + + return &InPredicateNode{ + Not: ip.Not, + P: ip.P.Clone(), + E: e, + } +} diff --git a/pkg/runtime/ast/select_element.go b/pkg/runtime/ast/select_element.go index 97ba0814..de396341 100644 --- a/pkg/runtime/ast/select_element.go +++ b/pkg/runtime/ast/select_element.go @@ -267,6 +267,13 @@ func (s *SelectElementColumn) Suffix() string { return s.Name[len(s.Name)-1] } +func (s *SelectElementColumn) Prefix() string { + if len(s.Name) < 2 { + return "" + } + return s.Name[len(s.Name)-2] +} + func (s *SelectElementColumn) Restore(flag RestoreFlag, sb *strings.Builder, args *[]int) error { if err := ColumnNameExpressionAtom(s.Name).Restore(flag, sb, args); err != nil { return errors.WithStack(err) diff --git a/pkg/runtime/optimize/dml/select.go b/pkg/runtime/optimize/dml/select.go index 3608dafe..54bbbd01 100644 --- a/pkg/runtime/optimize/dml/select.go +++ b/pkg/runtime/optimize/dml/select.go @@ -35,6 +35,7 @@ import ( "github.com/arana-db/arana/pkg/proto/hint" "github.com/arana-db/arana/pkg/proto/rule" "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/cmp" rcontext "github.com/arana-db/arana/pkg/runtime/context" "github.com/arana-db/arana/pkg/runtime/misc/extvalue" "github.com/arana-db/arana/pkg/runtime/optimize" @@ -110,13 +111,14 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err } } + if stmt.HasJoin() { + return optimizeJoin(ctx, o, stmt) + } // overwrite stmt limit x offset y. eg `select * from student offset 100 limit 5` will be // `select * from student offset 0 limit 100+5` originOffset, newLimit := overwriteLimit(stmt, &o.Args) - if stmt.HasJoin() { - return optimizeJoin(ctx, o, stmt) - } + flag := getSelectFlag(o.Rule, stmt) if flag&_supported == 0 { return nil, errors.Errorf("unsupported sql: %s", rcontext.SQL(ctx)) @@ -124,7 +126,7 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err if flag&_bypass != 0 { if len(stmt.From) > 0 { - err := rewriteSelectStatement(ctx, stmt, stmt.From[0].Source.(ast.TableName).Suffix()) + err := rewriteSelectStatement(ctx, stmt, o) if err != nil { return nil, err } @@ -173,8 +175,7 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err } toSingle := func(db, tbl string) (proto.Plan, error) { - _, tb0, _ := vt.Topology().Smallest() - if err := rewriteSelectStatement(ctx, stmt, tb0); err != nil { + if err := rewriteSelectStatement(ctx, stmt, o); err != nil { return nil, err } ret := &dml.SimpleQueryPlan{ @@ -220,8 +221,7 @@ func optimizeSelect(ctx context.Context, o *optimize.Optimizer) (proto.Plan, err return toSingle(db, tbl) } - _, tb, _ := vt.Topology().Smallest() - if err = rewriteSelectStatement(ctx, stmt, tb); err != nil { + if err = rewriteSelectStatement(ctx, stmt, o); err != nil { return nil, errors.WithStack(err) } @@ -417,8 +417,8 @@ func handleGroupBy(parentPlan proto.Plan, stmt *ast.SelectStatement) (proto.Plan // optimizeJoin ony support a join b in one db. // DEPRECATED: reimplement in the future func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectStatement) (proto.Plan, error) { - compute := func(tableSource *ast.TableSourceItem) (database, alias string, shardList []string, err error) { - table := tableSource.Source.(ast.TableName) + compute := func(tableSource *ast.TableSourceItem) (database, alias string, table ast.TableName, shards rule.DatabaseTables, err error) { + table = tableSource.Source.(ast.TableName) if table == nil { err = errors.New("must table, not statement or join node") return @@ -426,63 +426,256 @@ func optimizeJoin(ctx context.Context, o *optimize.Optimizer, stmt *ast.SelectSt alias = tableSource.Alias database = table.Prefix() - shards, err := o.ComputeShards(ctx, table, nil, o.Args) + if alias == "" { + alias = table.Suffix() + } + + shards, err = o.ComputeShards(ctx, table, nil, o.Args) if err != nil { return } - // table no shard - if shards == nil { - shardList = append(shardList, table.Suffix()) - return + return + } + + from := stmt.From[0] + dbLeft, aliasLeft, tableLeft, shardsLeft, err := compute(&from.TableSourceItem) + if err != nil { + return nil, err + } + + join := from.Joins[0] + dbRight, aliasRight, tableRight, shardsRight, err := compute(join.Target) + if err != nil { + return nil, err + } + + // one db + if dbLeft == dbRight && shardsLeft == nil && shardsRight == nil { + joinPan := &dml.SimpleJoinPlan{ + Left: &dml.JoinTable{ + Tables: tableLeft, + Alias: aliasLeft, + }, + Join: from.Joins[0], + Right: &dml.JoinTable{ + Tables: tableRight, + Alias: aliasRight, + }, + Stmt: o.Stmt.(*ast.SelectStatement), } - // table shard more than one db - if len(shards) > 1 { - err = errors.New("not support more than one db") - return + joinPan.BindArgs(o.Args) + return joinPan, nil + } + + //multiple shards & do hash join + hashJoinPlan := &dml.HashJoinPlan{ + Stmt: stmt, + } + + onExpression, ok := from.Joins[0].On.(*ast.PredicateExpressionNode).P.(*ast.BinaryComparisonPredicateNode) + // todo support more 'ON' condition ast.LogicalExpressionNode + if !ok { + return nil, errors.New("not support more than one 'ON' condition") + } + + onLeft := onExpression.Left.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + onRight := onExpression.Right.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + + leftKey := "" + if onLeft.Prefix() == aliasLeft { + leftKey = onLeft.Suffix() + } + + rightKey := "" + if onRight.Prefix() == aliasRight { + rightKey = onRight.Suffix() + } + + if len(leftKey) == 0 || len(rightKey) == 0 { + return nil, errors.Errorf("not found buildKey or probeKey") + } + + rewriteToSingle := func(tableSource ast.TableSourceItem, shards map[string][]string, onKey string) (proto.Plan, error) { + selectStmt := &ast.SelectStatement{ + Select: stmt.Select, + From: ast.FromNode{ + &ast.TableSourceNode{ + TableSourceItem: tableSource, + }, + }, } + table := tableSource.Source.(ast.TableName) + actualTb := table.Suffix() + aliasTb := tableSource.Alias - for k, v := range shards { - database = k - shardList = v + tb0 := actualTb + if shards != nil { + vt := o.Rule.MustVTable(tb0) + _, tb0, _ = vt.Topology().Smallest() } + if _, ok = stmt.Select[0].(*ast.SelectElementAll); !ok && len(stmt.Select) > 1 { + metadata, err := loadMetadataByTable(ctx, tb0) + if err != nil { + return nil, err + } - if alias == "" { - alias = table.Suffix() + selectColumn := selectStmt.Select + var selectElements []ast.SelectElement + for _, element := range selectColumn { + e, ok := element.(*ast.SelectElementColumn) + if ok { + columnsMap := metadata.Columns + ColumnMeta, exist := columnsMap[e.Suffix()] + if (aliasTb == e.Prefix() || actualTb == e.Prefix()) && exist { + selectElements = append(selectElements, ast.NewSelectElementColumn([]string{ColumnMeta.Name}, "")) + } + } + } + selectElements = append(selectElements, ast.NewSelectElementColumn([]string{onKey}, "")) + selectStmt.Select = selectElements } - return - } + if stmt.Where != nil { + selectStmt.Where = stmt.Where.Clone() + err := filterWhereByTable(ctx, selectStmt.Where, tb0, aliasTb) + if err != nil { + return nil, err + } + } - from := stmt.From[0] + optimizer := &optimize.Optimizer{ + Rule: o.Rule, + Stmt: selectStmt, + } + if _, ok = selectStmt.Select[0].(*ast.SelectElementAll); ok && len(selectStmt.Select) == 1 { + if err = rewriteSelectStatement(ctx, selectStmt, optimizer); err != nil { + return nil, err + } + + selectStmt.Select = append(selectStmt.Select, ast.NewSelectElementColumn([]string{onKey}, "")) + } + + plan, err := optimizeSelect(ctx, optimizer) + if err != nil { + return nil, err + } + return plan, nil + } - dbLeft, aliasLeft, shardLeft, err := compute(&from.TableSourceItem) + leftPlan, err := rewriteToSingle(from.TableSourceItem, shardsLeft, leftKey) if err != nil { return nil, err } - dbRight, aliasRight, shardRight, err := compute(from.Joins[0].Target) + + rightPlan, err := rewriteToSingle(*from.Joins[0].Target, shardsRight, rightKey) if err != nil { return nil, err } - if dbLeft != "" && dbRight != "" && dbLeft != dbRight { - return nil, errors.New("not support more than one db") + setPlan := func(plan *dml.HashJoinPlan, buildPlan, probePlan proto.Plan, buildKey, probeKey string) { + plan.BuildKey = buildKey + plan.ProbeKey = probeKey + plan.BuildPlan = buildPlan + plan.ProbePlan = probePlan } - joinPan := &dml.SimpleJoinPlan{ - Left: &dml.JoinTable{ - Tables: shardLeft, - Alias: aliasLeft, - }, - Join: from.Joins[0], - Right: &dml.JoinTable{ - Tables: shardRight, - Alias: aliasRight, - }, - Stmt: o.Stmt.(*ast.SelectStatement), + if join.Typ == ast.InnerJoin { + setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey) + hashJoinPlan.IsFilterProbeRow = true + } else { + hashJoinPlan.IsFilterProbeRow = false + if join.Typ == ast.LeftJoin { + hashJoinPlan.IsReversedColumn = true + setPlan(hashJoinPlan, rightPlan, leftPlan, rightKey, leftKey) + } else if join.Typ == ast.RightJoin { + setPlan(hashJoinPlan, leftPlan, rightPlan, leftKey, rightKey) + } else { + return nil, errors.New("not support Join Type") + } + } + + var tmpPlan proto.Plan + tmpPlan = hashJoinPlan + + var ( + analysis selectResult + scanner = newSelectScanner(stmt, o.Args) + ) + + if err = rewriteSelectStatement(ctx, stmt, o); err != nil { + return nil, errors.WithStack(err) + } + + if err = scanner.scan(&analysis); err != nil { + return nil, errors.WithStack(err) } - joinPan.BindArgs(o.Args) - return joinPan, nil + // check if order-by exists + if len(analysis.orders) > 0 { + var ( + sb strings.Builder + orderByItems = make([]dataset.OrderByItem, 0, len(analysis.orders)) + ) + + for _, it := range analysis.orders { + var next dataset.OrderByItem + next.Desc = it.Desc + if alias := it.Alias(); len(alias) > 0 { + next.Column = alias + } else { + switch prev := it.Prev().(type) { + case *ast.SelectElementColumn: + next.Column = prev.Suffix() + default: + if err = it.Restore(ast.RestoreWithoutAlias, &sb, nil); err != nil { + return nil, errors.WithStack(err) + } + next.Column = sb.String() + sb.Reset() + } + } + orderByItems = append(orderByItems, next) + } + tmpPlan = &dml.OrderPlan{ + ParentPlan: tmpPlan, + OrderByItems: orderByItems, + } + } + + if stmt.GroupBy != nil { + if tmpPlan, err = handleGroupBy(tmpPlan, stmt); err != nil { + return nil, errors.WithStack(err) + } + } else if analysis.hasAggregate { + tmpPlan = &dml.AggregatePlan{ + Plan: tmpPlan, + Fields: stmt.Select, + } + } + + if stmt.Limit != nil { + // overwrite stmt limit x offset y. eg `select * from student offset 100 limit 5` will be + // `select * from student offset 0 limit 100+5` + originOffset, newLimit := overwriteLimit(stmt, &o.Args) + tmpPlan = &dml.LimitPlan{ + ParentPlan: tmpPlan, + OriginOffset: originOffset, + OverwriteLimit: newLimit, + } + } + + if analysis.hasMapping { + tmpPlan = &dml.MappingPlan{ + Plan: tmpPlan, + Fields: stmt.Select, + } + } + + tmpPlan = &dml.RenamePlan{ + Plan: tmpPlan, + RenameList: analysis.normalizedFields, + } + return tmpPlan, nil } func getSelectFlag(ru *rule.Rule, stmt *ast.SelectStatement) (flag uint32) { @@ -571,7 +764,7 @@ func overwriteLimit(stmt *ast.SelectStatement, args *[]proto.Value) (originOffse return } -func rewriteSelectStatement(ctx context.Context, stmt *ast.SelectStatement, tb string) error { +func rewriteSelectStatement(ctx context.Context, stmt *ast.SelectStatement, o *optimize.Optimizer) error { // todo db 计算逻辑&tb shard 的计算逻辑 starExpand := false if len(stmt.Select) == 1 { @@ -584,9 +777,39 @@ func rewriteSelectStatement(ctx context.Context, stmt *ast.SelectStatement, tb s return nil } - if len(tb) < 1 { - tb = stmt.From[0].Source.(ast.TableName).Suffix() + tbs := []ast.TableName{stmt.From[0].Source.(ast.TableName)} + for _, join := range stmt.From[0].Joins { + joinTable := join.Target.Source.(ast.TableName) + tbs = append(tbs, joinTable) } + + selectExpandElements := make([]ast.SelectElement, 0) + for _, t := range tbs { + shards, err := o.ComputeShards(ctx, t, nil, o.Args) + if err != nil { + return errors.WithStack(err) + } + + tb0 := t.Suffix() + if shards != nil { + vt := o.Rule.MustVTable(tb0) + _, tb0, _ = vt.Topology().Smallest() + } + + metadata, err := loadMetadataByTable(ctx, tb0) + if err != nil { + return errors.WithStack(err) + } + + for _, column := range metadata.ColumnNames { + selectExpandElements = append(selectExpandElements, ast.NewSelectElementColumn([]string{column}, "")) + } + } + stmt.Select = selectExpandElements + return nil +} + +func loadMetadataByTable(ctx context.Context, tb string) (*proto.TableMetadata, error) { metadatas, err := proto.LoadSchemaLoader().Load(ctx, rcontext.Schema(ctx), []string{tb}) if err != nil { if strings.Contains(err.Error(), "Table doesn't exist") { @@ -594,16 +817,156 @@ func rewriteSelectStatement(ctx context.Context, stmt *ast.SelectStatement, tb s } return errors.WithStack(err) } + metadata := metadatas[tb] if metadata == nil || len(metadata.ColumnNames) == 0 { - return errors.Errorf("optimize: cannot get metadata of `%s`.`%s`", rcontext.Schema(ctx), tb) + return nil, errors.Errorf("optimize: cannot get metadata of `%s`.`%s`", rcontext.Schema(ctx), tb) + } + return metadata, nil +} + +func filterWhereByTable(ctx context.Context, where ast.ExpressionNode, table string, alis string) error { + metadata, err := loadMetadataByTable(ctx, table) + if err != nil { + return errors.WithStack(err) + } + + if err = filterNodeByTable(where, metadata, alis); err != nil { + return errors.WithStack(err) + } + + return nil +} + +var replaceNode = &ast.BinaryComparisonPredicateNode{ + Left: &ast.AtomPredicateNode{A: &ast.ConstantExpressionAtom{Inner: 1}}, + Right: &ast.AtomPredicateNode{A: &ast.ConstantExpressionAtom{Inner: 1}}, + Op: cmp.Ceq, +} + +func filterNodeByTable(expNode ast.ExpressionNode, metadata *proto.TableMetadata, alis string) error { + predicateNode, ok := expNode.(*ast.PredicateExpressionNode) + if ok { + bcpn, bcOk := predicateNode.P.(*ast.BinaryComparisonPredicateNode) + if bcOk { + columnNode, ok := bcpn.Left.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + if !ok { + return errors.New("invalid node") + } + if columnNode.Prefix() != "" { + if columnNode.Prefix() != metadata.Name && columnNode.Prefix() != alis { + predicateNode.P = replaceNode + } + } else { + _, ok := metadata.Columns[columnNode.Suffix()] + if !ok { + predicateNode.P = replaceNode + } + } + rightColumn, ok := bcpn.Right.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + if ok { + if rightColumn.Prefix() != "" { + if rightColumn.Prefix() != metadata.Name && rightColumn.Prefix() != alis { + return errors.New("not support node") + } + } else { + _, ok := metadata.Columns[rightColumn.Suffix()] + if !ok { + return errors.New("not support node") + } + } + } + return nil + } + + lpn, likeOk := predicateNode.P.(*ast.LikePredicateNode) + if likeOk { + columnNode := lpn.Left.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + if columnNode.Prefix() != "" { + if columnNode.Prefix() != metadata.Name && columnNode.Prefix() != alis { + predicateNode.P = replaceNode + } + } else { + _, ok := metadata.Columns[columnNode.Suffix()] + if !ok { + predicateNode.P = replaceNode + } + } + return nil + } + + ipn, inOk := predicateNode.P.(*ast.InPredicateNode) + if inOk { + columnNode, ok := ipn.P.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + if !ok { + return errors.New("invalid node") + } + if columnNode.Prefix() != "" { + if columnNode.Prefix() != metadata.Name && columnNode.Prefix() != alis { + predicateNode.P = replaceNode + } + } else { + _, ok := metadata.Columns[columnNode.Suffix()] + if !ok { + predicateNode.P = replaceNode + } + } + return nil + } + + bpn, betweenOk := predicateNode.P.(*ast.BetweenPredicateNode) + if betweenOk { + columnNode, ok := bpn.Key.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + if !ok { + return errors.New("invalid node") + } + if columnNode.Prefix() != "" { + if columnNode.Prefix() != metadata.Name && columnNode.Prefix() != alis { + predicateNode.P = replaceNode + } + } else { + _, ok := metadata.Columns[columnNode.Suffix()] + if !ok { + predicateNode.P = replaceNode + } + } + + //columnNode := bpn.Right.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + return nil + } + + rpn, regexpOk := predicateNode.P.(*ast.RegexpPredicationNode) + if regexpOk { + columnNode, ok := rpn.Left.(*ast.AtomPredicateNode).A.(ast.ColumnNameExpressionAtom) + if !ok { + return errors.New("invalid node") + } + if columnNode.Prefix() != "" { + if columnNode.Prefix() != metadata.Name && columnNode.Prefix() != alis { + predicateNode.P = replaceNode + } + } else { + _, ok := metadata.Columns[columnNode.Suffix()] + if !ok { + predicateNode.P = replaceNode + } + } + return nil + } + + return errors.New("invalid node") } - selectElements := make([]ast.SelectElement, len(metadata.Columns)) - for i, column := range metadata.ColumnNames { - selectElements[i] = ast.NewSelectElementColumn([]string{column}, "") + node, ok := expNode.(*ast.LogicalExpressionNode) + if !ok { + return errors.New("invalid node") } - stmt.Select = selectElements + if err := filterNodeByTable(node.Left, metadata, alis); err != nil { + return err + } + if err := filterNodeByTable(node.Right, metadata, alis); err != nil { + return err + } return nil } diff --git a/pkg/runtime/optimize/optimizer_test.go b/pkg/runtime/optimize/optimizer_test.go index df506c96..87b93b5e 100644 --- a/pkg/runtime/optimize/optimizer_test.go +++ b/pkg/runtime/optimize/optimizer_test.go @@ -33,6 +33,10 @@ import ( ) import ( + consts "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/mysql/rows" "github.com/arana-db/arana/pkg/proto" "github.com/arana-db/arana/pkg/proto/rule" "github.com/arana-db/arana/pkg/resultx" @@ -64,7 +68,7 @@ func TestOptimizer_OptimizeSelect(t *testing.T) { var ( sql = "select id, uid from student where uid in (?,?,?)" ctx = context.WithValue(context.Background(), proto.ContextKeyEnableLocalComputation{}, true) - ru = makeFakeRule(ctrl, 8) + ru = makeFakeRule(ctrl, "student", 8, nil) ) p := parser.New() @@ -81,6 +85,96 @@ func TestOptimizer_OptimizeSelect(t *testing.T) { _, _ = plan.ExecIn(ctx, conn) } +func TestOptimizer_OptimizeHashJoin(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + studentFields := []proto.Field{ + mysql.NewField("uid", consts.FieldTypeLongLong), + } + + salariesFields := []proto.Field{ + mysql.NewField("uid", consts.FieldTypeLongLong), + } + + conn := testdata.NewMockVConn(ctrl) + buildPlan := true + conn.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, db string, sql string, args ...interface{}) (proto.Result, error) { + t.Logf("fake query: db=%s, sql=%s, args=%v\n", db, sql, args) + + result := testdata.NewMockResult(ctrl) + fakeData := &dataset.VirtualDataset{} + if buildPlan { + fakeData.Columns = append(studentFields, mysql.NewField("uid", consts.FieldTypeLongLong)) + for i := int64(0); i < 8; i++ { + fakeData.Rows = append(fakeData.Rows, rows.NewTextVirtualRow(fakeData.Columns, []proto.Value{ + proto.NewValueInt64(i), + proto.NewValueInt64(i), + })) + } + result.EXPECT().Dataset().Return(fakeData, nil).AnyTimes() + buildPlan = false + } else { + fakeData.Columns = append(salariesFields, mysql.NewField("uid", consts.FieldTypeLongLong)) + for i := int64(10); i > 3; i-- { + fakeData.Rows = append(fakeData.Rows, rows.NewTextVirtualRow(fakeData.Columns, []proto.Value{ + proto.NewValueInt64(i), + proto.NewValueInt64(i), + })) + } + result.EXPECT().Dataset().Return(fakeData, nil).AnyTimes() + } + + return result, nil + }). + AnyTimes() + + fakeData := make(map[string]*proto.TableMetadata) + // fake data + fakeData["student_0000"] = &proto.TableMetadata{ + Name: "student_0000", + Columns: map[string]*proto.ColumnMetadata{"uid": {}}, + ColumnNames: []string{"uid"}, + } + + fakeData["salaries_0000"] = &proto.TableMetadata{ + Name: "salaries_0000", + Columns: map[string]*proto.ColumnMetadata{"uid": {}}, + ColumnNames: []string{"uid"}, + } + loader := testdata.NewMockSchemaLoader(ctrl) + loader.EXPECT().Load(gomock.Any(), gomock.Any(), gomock.Any()).Return(fakeData, nil).AnyTimes() + + oldLoader := proto.LoadSchemaLoader() + proto.RegisterSchemaLoader(loader) + defer proto.RegisterSchemaLoader(oldLoader) + + var ( + sql = "select * from student join salaries on student.uid = salaries.uid" + ctx = context.WithValue(context.Background(), proto.ContextKeyEnableLocalComputation{}, true) + ru = makeFakeRule(ctrl, "student", 8, nil) + ) + + ru = makeFakeRule(ctrl, "salaries", 8, ru) + + p := parser.New() + stmt, _ := p.ParseOneStmt(sql, "", "") + opt, err := NewOptimizer(ru, nil, stmt, nil) + assert.NoError(t, err) + + vTable, _ := ru.VTable("student") + vTable.SetAllowFullScan(true) + + vTable2, _ := ru.VTable("salaries") + vTable2.SetAllowFullScan(true) + + plan, err := opt.Optimize(ctx) + assert.NoError(t, err) + + _, _ = plan.ExecIn(ctx, conn) +} + func TestOptimizer_OptimizeInsert(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -143,7 +237,7 @@ func TestOptimizer_OptimizeInsert(t *testing.T) { var ( ctx = context.Background() - ru = makeFakeRule(ctrl, 8) + ru = makeFakeRule(ctrl, "student", 8, nil) ) t.Run("sharding", func(t *testing.T) { diff --git a/pkg/runtime/optimize/shard_visitor_test.go b/pkg/runtime/optimize/shard_visitor_test.go index 04bd3d70..4753a7ea 100644 --- a/pkg/runtime/optimize/shard_visitor_test.go +++ b/pkg/runtime/optimize/shard_visitor_test.go @@ -44,7 +44,7 @@ func TestShardNG(t *testing.T) { defer ctrl.Finish() // test rule: student, uid % 8 - fakeRule := makeFakeRule(ctrl, 8) + fakeRule := makeFakeRule(ctrl, "student", 8, nil) type tt struct { sql string @@ -78,17 +78,20 @@ func TestShardNG(t *testing.T) { } } -func makeFakeRule(c *gomock.Controller, mod int) *rule.Rule { +func makeFakeRule(c *gomock.Controller, table string, mod int, ru *rule.Rule) *rule.Rule { var ( - ru rule.Rule tab rule.VTable topo rule.Topology ) + if ru == nil { + ru = &rule.Rule{} + } + topo.SetRender(func(_ int) string { return "fake_db" }, func(i int) string { - return fmt.Sprintf("student_%04d", i) + return fmt.Sprintf("%s_%04d", table, i) }) tables := make([]int, 0, mod) @@ -98,7 +101,7 @@ func makeFakeRule(c *gomock.Controller, mod int) *rule.Rule { topo.SetTopology(0, tables...) tab.SetTopology(&topo) - tab.SetName("student") + tab.SetName(table) computer := testdata.NewMockShardComputer(c) @@ -111,13 +114,13 @@ func makeFakeRule(c *gomock.Controller, mod int) *rule.Rule { } return n % mod, nil }). - MinTimes(1) + AnyTimes() var sm rule.ShardMetadata sm.Steps = 8 sm.Computer = computer tab.SetShardMetadata("uid", nil, &sm) - ru.SetVTable("student", &tab) - return &ru + ru.SetVTable(table, &tab) + return ru } diff --git a/pkg/runtime/plan/dml/hash_join.go b/pkg/runtime/plan/dml/hash_join.go new file mode 100644 index 00000000..80deedab --- /dev/null +++ b/pkg/runtime/plan/dml/hash_join.go @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dml + +import ( + "bytes" + "context" + "io" +) + +import ( + "github.com/cespare/xxhash/v2" + + "github.com/pkg/errors" +) + +import ( + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/resultx" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/pkg/runtime/plan" + "github.com/arana-db/arana/third_party/base58" +) + +type HashJoinPlan struct { + BuildPlan proto.Plan + ProbePlan proto.Plan + + BuildKey string + ProbeKey string + hashArea map[string]proto.Row + IsFilterProbeRow bool + IsReversedColumn bool + + Stmt *ast.SelectStatement +} + +func (h *HashJoinPlan) Type() proto.PlanType { + return proto.PlanTypeQuery +} + +func (h *HashJoinPlan) ExecIn(ctx context.Context, conn proto.VConn) (proto.Result, error) { + ctx, span := plan.Tracer.Start(ctx, "HashJoinPlan.ExecIn") + defer span.End() + + // build stage + buildDs, err := h.build(ctx, conn) + if err != nil { + return nil, errors.WithStack(err) + } + + // probe stage + probeDs, err := h.probe(ctx, conn, buildDs) + if err != nil { + return nil, errors.WithStack(err) + } + + return resultx.New(resultx.WithDataset(probeDs)), nil +} + +func (h *HashJoinPlan) queryAggregate(ctx context.Context, conn proto.VConn, plan proto.Plan) (proto.Result, error) { + result, err := plan.ExecIn(ctx, conn) + if err != nil { + return nil, err + } + return result, nil +} + +func (h *HashJoinPlan) build(ctx context.Context, conn proto.VConn) (proto.Dataset, error) { + res, err := h.queryAggregate(ctx, conn, h.BuildPlan) + if err != nil { + return nil, errors.WithStack(err) + } + + ds, err := res.Dataset() + if err != nil { + return nil, errors.WithStack(err) + } + cn := h.BuildKey + xh := xxhash.New() + h.hashArea = make(map[string]proto.Row) + // build map + for { + xh.Reset() + next, err := ds.Next() + if err == io.EOF { + break + } + + keyedRow := next.(proto.KeyedRow) + value, err := keyedRow.Get(cn) + if err != nil { + return nil, errors.WithStack(err) + } + + if value != nil { + _, _ = xh.WriteString(value.String()) + h.hashArea[base58.Encode(xh.Sum(nil))] = next + } + } + + return ds, nil +} + +func (h *HashJoinPlan) probe(ctx context.Context, conn proto.VConn, buildDataset proto.Dataset) (proto.Dataset, error) { + res, err := h.queryAggregate(ctx, conn, h.ProbePlan) + if err != nil { + return nil, errors.WithStack(err) + } + + ds, err := res.Dataset() + if err != nil { + return nil, errors.WithStack(err) + } + + probeMapFunc := func(row proto.Row, columnName string) proto.Row { + keyedRow := row.(proto.KeyedRow) + value, _ := keyedRow.Get(columnName) + if value != nil { + xh := xxhash.New() + _, _ = xh.WriteString(value.String()) + return h.hashArea[base58.Encode(xh.Sum(nil))] + } + return nil + } + + cn := h.ProbeKey + filterFunc := func(row proto.Row) bool { + findRow := probeMapFunc(row, cn) + if !h.IsFilterProbeRow { + return true + } + + return findRow != nil + } + + buildFields, err := buildDataset.Fields() + if err != nil { + return nil, errors.WithStack(err) + } + // aggregate fields + aggregateFieldsFunc := func(fields []proto.Field) []proto.Field { + if h.IsReversedColumn { + return append(fields[:len(fields)-1], buildFields[:len(buildFields)-1]...) + } + + return append(buildFields[:len(buildFields)-1], fields[:len(fields)-1]...) + } + + // aggregate row + fields, err := ds.Fields() + if err != nil { + return nil, errors.WithStack(err) + } + transformFunc := func(row proto.Row) (proto.Row, error) { + dest := make([]proto.Value, len(fields)) + _ = row.Scan(dest) + + matchRow := probeMapFunc(row, cn) + buildDest := make([]proto.Value, len(buildFields)) + if matchRow != nil { + _ = matchRow.Scan(buildDest) + } else { + // set null row + if row.IsBinary() { + matchRow = rows.NewBinaryVirtualRow(buildFields, buildDest) + } else { + matchRow = rows.NewTextVirtualRow(buildFields, buildDest) + } + } + + var ( + resFields []proto.Field + resDest []proto.Value + ) + + // remove 'ON' column + if h.IsReversedColumn { + resFields = append(fields[:len(fields)-1], buildFields[:len(buildFields)-1]...) + resDest = append(dest[:len(dest)-1], buildDest[:len(buildDest)-1]...) + } else { + resFields = append(buildFields[:len(buildFields)-1], fields[:len(fields)-1]...) + resDest = append(buildDest[:len(buildDest)-1], dest[:len(dest)-1]...) + } + + var b bytes.Buffer + if row.IsBinary() { + newRow := rows.NewBinaryVirtualRow(resFields, resDest) + _, err := newRow.WriteTo(&b) + if err != nil { + return nil, err + } + + br := mysql.NewBinaryRow(resFields, b.Bytes()) + return br, nil + } else { + newRow := rows.NewTextVirtualRow(resFields, resDest) + _, err := newRow.WriteTo(&b) + if err != nil { + return nil, err + } + + return mysql.NewTextRow(resFields, b.Bytes()), nil + } + } + + // filter match row & aggregate fields and row + return dataset.Pipe(ds, dataset.Filter(filterFunc), dataset.Map(aggregateFieldsFunc, transformFunc)), nil +} diff --git a/pkg/runtime/plan/dml/hash_join_test.go b/pkg/runtime/plan/dml/hash_join_test.go new file mode 100644 index 00000000..b0a30147 --- /dev/null +++ b/pkg/runtime/plan/dml/hash_join_test.go @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package dml + +import ( + "context" + "fmt" + "io" + "testing" +) + +import ( + "github.com/golang/mock/gomock" + + "github.com/stretchr/testify/assert" +) + +import ( + consts "github.com/arana-db/arana/pkg/constants/mysql" + "github.com/arana-db/arana/pkg/dataset" + "github.com/arana-db/arana/pkg/mysql" + "github.com/arana-db/arana/pkg/mysql/rows" + "github.com/arana-db/arana/pkg/proto" + "github.com/arana-db/arana/pkg/runtime/ast" + "github.com/arana-db/arana/testdata" +) + +func TestHashJoinPlan(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + studentFields := []proto.Field{ + mysql.NewField("uid", consts.FieldTypeLongLong), + mysql.NewField("name", consts.FieldTypeString), + } + + salariesFields := []proto.Field{ + mysql.NewField("emp_no", consts.FieldTypeLongLong), + mysql.NewField("name", consts.FieldTypeString), + } + + buildPlan := true + conn := testdata.NewMockVConn(ctrl) + conn.EXPECT().Query(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, db string, sql string, args ...interface{}) (proto.Result, error) { + t.Logf("fake query: db=%s, sql=%s, args=%v\n", db, sql, args) + + result := testdata.NewMockResult(ctrl) + fakeData := &dataset.VirtualDataset{} + if buildPlan { + fakeData.Columns = append(studentFields, mysql.NewField("uid", consts.FieldTypeLongLong)) + for i := int64(0); i < 8; i++ { + fakeData.Rows = append(fakeData.Rows, rows.NewTextVirtualRow(fakeData.Columns, []proto.Value{ + proto.NewValueInt64(i), + proto.NewValueString(fmt.Sprintf("fake-student-name-%d", i)), + proto.NewValueInt64(i), + })) + } + result.EXPECT().Dataset().Return(fakeData, nil).AnyTimes() + buildPlan = false + } else { + fakeData.Columns = append(salariesFields, mysql.NewField("emp_no", consts.FieldTypeLongLong)) + for i := int64(10); i > 3; i-- { + fakeData.Rows = append(fakeData.Rows, rows.NewTextVirtualRow(fakeData.Columns, []proto.Value{ + proto.NewValueInt64(i), + proto.NewValueString(fmt.Sprintf("fake-salaries-name-%d", i)), + proto.NewValueInt64(i), + })) + } + result.EXPECT().Dataset().Return(fakeData, nil).AnyTimes() + } + + return result, nil + }). + AnyTimes() + + var ( + sql1 = "SELECT * FROM student" // mock build plan + sql2 = "SELECT * FROM salaries" // mock probe plan + ctx = context.WithValue(context.Background(), proto.ContextKeyEnableLocalComputation{}, true) + ) + + _, stmt1, _ := ast.ParseSelect(sql1) + _, stmt2, _ := ast.ParseSelect(sql2) + + // sql: select * from student join salaries on uid = emp_no; + plan := &HashJoinPlan{ + BuildPlan: CompositePlan{ + []proto.Plan{ + &SimpleQueryPlan{ + Stmt: stmt1, + }, + }, + }, + ProbePlan: CompositePlan{ + []proto.Plan{ + &SimpleQueryPlan{ + Stmt: stmt2, + }, + }, + }, + IsFilterProbeRow: true, + BuildKey: "uid", + ProbeKey: "emp_no", + } + + res, err := plan.ExecIn(ctx, conn) + assert.NoError(t, err) + ds, _ := res.Dataset() + f, _ := ds.Fields() + + // expected field + assert.Equal(t, "uid", f[0].Name()) + assert.Equal(t, "name", f[1].Name()) + assert.Equal(t, "emp_no", f[2].Name()) + assert.Equal(t, "name", f[3].Name()) + for { + next, err := ds.Next() + if err == io.EOF { + break + } + dest := make([]proto.Value, len(f)) + _ = next.Scan(dest) + + // expected value: uid = emp_no + assert.Equal(t, dest[0], dest[2]) + } + +}