diff --git a/go/test/endtoend/cluster/cluster_process.go b/go/test/endtoend/cluster/cluster_process.go index b0401f56324..21a21202811 100644 --- a/go/test/endtoend/cluster/cluster_process.go +++ b/go/test/endtoend/cluster/cluster_process.go @@ -327,13 +327,12 @@ func (cluster *LocalProcessCluster) startKeyspace(keyspace Keyspace, shardNames if !cluster.ReusingVTDATAROOT { _ = cluster.VtctlProcess.CreateKeyspace(keyspace.Name) } - var mysqlctlProcessList []*exec.Cmd for _, shardName := range shardNames { shard := &Shard{ Name: shardName, } log.Infof("Starting shard: %v", shardName) - mysqlctlProcessList = []*exec.Cmd{} + var mysqlctlProcessList []*exec.Cmd for i := 0; i < totalTabletsRequired; i++ { // instantiate vttablet object with reserved ports tabletUID := cluster.GetAndReserveTabletUID() @@ -1191,8 +1190,16 @@ func (cluster *LocalProcessCluster) VtprocessInstanceFromVttablet(tablet *Vttabl } // StartVttablet starts a new tablet -func (cluster *LocalProcessCluster) StartVttablet(tablet *Vttablet, servingStatus string, - supportBackup bool, cell string, keyspaceName string, hostname string, shardName string) error { +func (cluster *LocalProcessCluster) StartVttablet( + tablet *Vttablet, + explicitServingStatus bool, + servingStatus string, + supportBackup bool, + cell string, + keyspaceName string, + hostname string, + shardName string, +) error { tablet.VttabletProcess = VttabletProcessInstance( tablet.HTTPPort, tablet.GrpcPort, @@ -1210,6 +1217,7 @@ func (cluster *LocalProcessCluster) StartVttablet(tablet *Vttablet, servingStatu tablet.VttabletProcess.SupportsBackup = supportBackup tablet.VttabletProcess.ServingStatus = servingStatus + tablet.VttabletProcess.ExplicitServingStatus = explicitServingStatus return tablet.VttabletProcess.Setup() } diff --git a/go/test/endtoend/cluster/vttablet_process.go b/go/test/endtoend/cluster/vttablet_process.go index 68dfddb831e..747bdc8fb39 100644 --- a/go/test/endtoend/cluster/vttablet_process.go +++ b/go/test/endtoend/cluster/vttablet_process.go @@ -40,6 +40,8 @@ import ( "vitess.io/vitess/go/vt/log" ) +const vttabletStateTimeout = 30 * time.Second + // VttabletProcess is a generic handle for a running vttablet . // It can be spawned manually type VttabletProcess struct { @@ -67,6 +69,7 @@ type VttabletProcess struct { QueryzURL string StatusDetailsURL string SupportsBackup bool + ExplicitServingStatus bool ServingStatus string DbPassword string DbPort int @@ -75,7 +78,7 @@ type VttabletProcess struct { Charset string ConsolidationsURL string - //Extra Args to be set before starting the vttablet process + // Extra Args to be set before starting the vttablet process ExtraArgs []string proc *exec.Cmd @@ -146,7 +149,15 @@ func (vttablet *VttabletProcess) Setup() (err error) { }() if vttablet.ServingStatus != "" { - if err = vttablet.WaitForTabletStatus(vttablet.ServingStatus); err != nil { + // If the tablet has an explicit serving status we use the serving status + // otherwise we wait for any serving status to show up in the healthcheck. + var servingStatus []string + if vttablet.ExplicitServingStatus { + servingStatus = append(servingStatus, vttablet.ServingStatus) + } else { + servingStatus = append(servingStatus, "SERVING", "NOT_SERVING") + } + if err = vttablet.WaitForTabletStatuses(servingStatus); err != nil { errFileContent, _ := os.ReadFile(fname) if errFileContent != nil { log.Infof("vttablet error:\n%s\n", string(errFileContent)) diff --git a/go/test/endtoend/tabletmanager/custom_rule_topo_test.go b/go/test/endtoend/tabletmanager/custom_rule_topo_test.go index 0d68c7a2521..537633e175a 100644 --- a/go/test/endtoend/tabletmanager/custom_rule_topo_test.go +++ b/go/test/endtoend/tabletmanager/custom_rule_topo_test.go @@ -71,7 +71,7 @@ func TestTopoCustomRule(t *testing.T) { require.Nil(t, err, "error should be Nil") // Start Vttablet - err = clusterInstance.StartVttablet(rTablet, "SERVING", false, cell, keyspaceName, hostname, shardName) + err = clusterInstance.StartVttablet(rTablet, false, "SERVING", false, cell, keyspaceName, hostname, shardName) require.Nil(t, err, "error should be Nil") err = clusterInstance.VtctlclientProcess.ExecuteCommand("Validate") diff --git a/go/test/endtoend/tabletmanager/primary/tablet_test.go b/go/test/endtoend/tabletmanager/primary/tablet_test.go index 3db692694b5..28b238883ab 100644 --- a/go/test/endtoend/tabletmanager/primary/tablet_test.go +++ b/go/test/endtoend/tabletmanager/primary/tablet_test.go @@ -189,7 +189,7 @@ func TestPrimaryRestartSetsTERTimestamp(t *testing.T) { require.NoError(t, err) // Start Vttablet - err = clusterInstance.StartVttablet(&replicaTablet, "SERVING", false, cell, keyspaceName, hostname, shardName) + err = clusterInstance.StartVttablet(&replicaTablet, false, "SERVING", false, cell, keyspaceName, hostname, shardName) require.NoError(t, err) // Make sure that the TER did not change diff --git a/go/test/endtoend/tabletmanager/tablet_health_test.go b/go/test/endtoend/tabletmanager/tablet_health_test.go index 19359406607..5e5eaebca84 100644 --- a/go/test/endtoend/tabletmanager/tablet_health_test.go +++ b/go/test/endtoend/tabletmanager/tablet_health_test.go @@ -73,7 +73,7 @@ func TestTabletReshuffle(t *testing.T) { // SupportsBackup=False prevents vttablet from trying to restore // Start vttablet process - err = clusterInstance.StartVttablet(rTablet, "SERVING", false, cell, keyspaceName, hostname, shardName) + err = clusterInstance.StartVttablet(rTablet, false, "SERVING", false, cell, keyspaceName, hostname, shardName) require.NoError(t, err) sql := "select value from t1" @@ -107,7 +107,7 @@ func TestHealthCheck(t *testing.T) { utils.Exec(t, replicaConn, fmt.Sprintf("create database vt_%s", keyspaceName)) // start vttablet process, should be in SERVING state as we already have a primary - err = clusterInstance.StartVttablet(rTablet, "SERVING", false, cell, keyspaceName, hostname, shardName) + err = clusterInstance.StartVttablet(rTablet, true, "SERVING", false, cell, keyspaceName, hostname, shardName) require.NoError(t, err) conn, err := mysql.Connect(ctx, &primaryTabletParams) @@ -248,7 +248,7 @@ func TestHealthCheckDrainedStateDoesNotShutdownQueryService(t *testing.T) { // - the second tablet will be set to 'drained' and we expect that // - the query service won't be shutdown - //Wait if tablet is not in service state + // Wait if tablet is not in service state defer cluster.PanicHandler(t) err := rdonlyTablet.VttabletProcess.WaitForTabletStatus("SERVING") require.NoError(t, err) diff --git a/go/test/endtoend/tabletmanager/tablet_security_policy_test.go b/go/test/endtoend/tabletmanager/tablet_security_policy_test.go index a2e1e8bd987..e7aae5ef08c 100644 --- a/go/test/endtoend/tabletmanager/tablet_security_policy_test.go +++ b/go/test/endtoend/tabletmanager/tablet_security_policy_test.go @@ -39,7 +39,7 @@ func TestFallbackSecurityPolicy(t *testing.T) { // Requesting an unregistered security_policy should fallback to deny-all. clusterInstance.VtTabletExtraArgs = []string{"--security_policy", "bogus"} - err = clusterInstance.StartVttablet(mTablet, "SERVING", false, cell, keyspaceName, hostname, shardName) + err = clusterInstance.StartVttablet(mTablet, false, "SERVING", false, cell, keyspaceName, hostname, shardName) require.NoError(t, err) // It should deny ADMIN role. @@ -94,7 +94,7 @@ func TestDenyAllSecurityPolicy(t *testing.T) { // Requesting a deny-all security_policy. clusterInstance.VtTabletExtraArgs = []string{"--security_policy", "deny-all"} - err = clusterInstance.StartVttablet(mTablet, "SERVING", false, cell, keyspaceName, hostname, shardName) + err = clusterInstance.StartVttablet(mTablet, false, "SERVING", false, cell, keyspaceName, hostname, shardName) require.NoError(t, err) // It should deny ADMIN role. @@ -126,7 +126,7 @@ func TestReadOnlySecurityPolicy(t *testing.T) { // Requesting a read-only security_policy. clusterInstance.VtTabletExtraArgs = []string{"--security_policy", "read-only"} - err = clusterInstance.StartVttablet(mTablet, "SERVING", false, cell, keyspaceName, hostname, shardName) + err = clusterInstance.StartVttablet(mTablet, false, "SERVING", false, cell, keyspaceName, hostname, shardName) require.NoError(t, err) // It should deny ADMIN role. diff --git a/go/test/endtoend/tabletmanager/tablet_test.go b/go/test/endtoend/tabletmanager/tablet_test.go index 3c597e97981..750379c5468 100644 --- a/go/test/endtoend/tabletmanager/tablet_test.go +++ b/go/test/endtoend/tabletmanager/tablet_test.go @@ -40,7 +40,7 @@ func TestEnsureDB(t *testing.T) { log.Info(fmt.Sprintf("Started vttablet %v", tablet)) // Start vttablet process as replica. It won't be able to serve because there's no db. - err = clusterInstance.StartVttablet(tablet, "NOT_SERVING", false, cell, "dbtest", hostname, "0") + err = clusterInstance.StartVttablet(tablet, false, "NOT_SERVING", false, cell, "dbtest", hostname, "0") require.NoError(t, err) // Make it the primary. @@ -73,7 +73,7 @@ func TestResetReplicationParameters(t *testing.T) { log.Info(fmt.Sprintf("Started vttablet %v", tablet)) // Start vttablet process as replica. It won't be able to serve because there's no db. - err = clusterInstance.StartVttablet(tablet, "NOT_SERVING", false, cell, "dbtest", hostname, "0") + err = clusterInstance.StartVttablet(tablet, false, "NOT_SERVING", false, cell, "dbtest", hostname, "0") require.NoError(t, err) // Set a replication source on the tablet and start replication diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index 094ccc040bf..cdd9f97d671 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -185,3 +185,11 @@ func TestBuggyOuterJoin(t *testing.T) { mcmp.Exec("insert into t1(id1, id2) values (1,2), (42,5), (5, 42)") mcmp.Exec("select t1.id1, t2.id1 from t1 left join t1 as t2 on t2.id1 = t2.id2") } + +func TestLeftJoinUsingUnsharded(t *testing.T) { + mcmp, closer := start(t) + defer closer() + + utils.Exec(t, mcmp.VtConn, "insert /*vt+ QUERY_TIMEOUT_MS=1000 */ into uks.unsharded(id1) values (1),(2),(3),(4),(5)") + utils.Exec(t, mcmp.VtConn, "select /*vt+ QUERY_TIMEOUT_MS=1000 */ * from uks.unsharded as A left join uks.unsharded as B using(id1)") +} diff --git a/go/vt/vterrors/code.go b/go/vt/vterrors/code.go index 50238c38eee..831c45d1f95 100644 --- a/go/vt/vterrors/code.go +++ b/go/vt/vterrors/code.go @@ -68,6 +68,7 @@ var ( VT09009 = errorWithoutState("VT09009", vtrpcpb.Code_FAILED_PRECONDITION, "stream is supported only for primary tablet type, current type: %v", "Stream is only supported for primary tablets, please use a stream on those tablets.") VT09010 = errorWithoutState("VT09010", vtrpcpb.Code_FAILED_PRECONDITION, "SHOW VITESS_THROTTLER STATUS works only on primary tablet", "SHOW VITESS_THROTTLER STATUS works only on primary tablet.") VT09013 = errorWithoutState("VT09013", vtrpcpb.Code_FAILED_PRECONDITION, "semi-sync plugins are not loaded", "Durability policy wants Vitess to use semi-sync, but the MySQL instances don't have the semi-sync plugin loaded.") + VT09015 = errorWithoutState("VT09015", vtrpcpb.Code_FAILED_PRECONDITION, "schema tracking required", "This query cannot be planned without more information on the SQL schema. Please turn on schema tracking or add authoritative columns information to your VSchema.") VT10001 = errorWithoutState("VT10001", vtrpcpb.Code_ABORTED, "foreign key constraints are not allowed", "Foreign key constraints are not allowed, see https://vitess.io/blog/2021-06-15-online-ddl-why-no-fk/.") @@ -125,6 +126,7 @@ var ( VT09009, VT09010, VT09013, + VT09015, VT10001, VT12001, VT13001, diff --git a/go/vt/vterrors/errors_test.go b/go/vt/vterrors/errors_test.go index c115fb41686..5ecd97fbd00 100644 --- a/go/vt/vterrors/errors_test.go +++ b/go/vt/vterrors/errors_test.go @@ -269,7 +269,7 @@ func TestStackFormat(t *testing.T) { // but the change in errors#27 made them incomparable. Assert that // various kinds of errors have a functional equality operator, even // if the result of that equality is always false. -func TestErrorEquality(t *testing.T) { +func TestErrorEquality(_ *testing.T) { vals := []error{ nil, io.EOF, diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index cfc673ad3af..14b435a1a46 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -6500,6 +6500,28 @@ ] } }, + { + "comment": "left join with using has to be transformed into inner join with on condition", + "query": "SELECT * FROM unsharded_authoritative as A LEFT JOIN unsharded_authoritative as B USING(col1)", + "plan": { + "QueryType": "SELECT", + "Original": "SELECT * FROM unsharded_authoritative as A LEFT JOIN unsharded_authoritative as B USING(col1)", + "Instructions": { + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select A.col1 as col1, A.col2 as col2, B.col2 as col2 from unsharded_authoritative as A left join unsharded_authoritative as B on A.col1 = B.col1 where 1 != 1", + "Query": "select A.col1 as col1, A.col2 as col2, B.col2 as col2 from unsharded_authoritative as A left join unsharded_authoritative as B on A.col1 = B.col1", + "Table": "unsharded_authoritative" + }, + "TablesUsed": [ + "main.unsharded_authoritative" + ] + } + }, { "comment": "join query using table with muticolumn vindex", "query": "select 1 from multicol_tbl m1 join multicol_tbl m2 on m1.cola = m2.cola", diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index 5a2c92451d4..d2ff7278e71 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -45,13 +45,13 @@ "comment": "join with USING construct", "query": "select * from user join user_extra using(id)", "v3-plan": "VT12001: unsupported: JOIN with USING(column_list) clause for complex queries", - "gen4-plan": "can't handle JOIN USING without authoritative tables" + "gen4-plan": "VT09015: schema tracking required" }, { "comment": "join with USING construct with 3 tables", "query": "select user.id from user join user_extra using(id) join music using(id2)", "v3-plan": "VT12001: unsupported: JOIN with USING(column_list) clause for complex queries", - "gen4-plan": "can't handle JOIN USING without authoritative tables" + "gen4-plan": "VT09015: schema tracking required" }, { "comment": "natural left join", diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 33edf05b68f..8d80e2a6ed2 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -172,6 +172,11 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool { return false } + if err := a.rewriter.up(cursor); err != nil { + a.setError(err) + return true + } + a.leaveProjection(cursor) return a.shouldContinue() } diff --git a/go/vt/vtgate/semantics/binder.go b/go/vt/vtgate/semantics/binder.go index 2a1b8c2c0ac..e60a4aab11a 100644 --- a/go/vt/vtgate/semantics/binder.go +++ b/go/vt/vtgate/semantics/binder.go @@ -81,13 +81,6 @@ func (b *binder) up(cursor *sqlparser.Cursor) error { } currScope.joinUsing[ident.Lowered()] = deps.direct } - if len(node.Using) > 0 { - err := rewriteJoinUsing(currScope, node.Using, b.org) - if err != nil { - return err - } - node.Using = nil - } case *sqlparser.ColName: currentScope := b.scoper.currentScope() deps, err := b.resolveColumn(node, currentScope, false) diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 4071cfe3837..0a524807ff0 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -17,8 +17,8 @@ limitations under the License. package semantics import ( + "fmt" "strconv" - "strings" "vitess.io/vitess/go/vt/vtgate/evalengine" @@ -108,6 +108,33 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { return nil } +func (r *earlyRewriter) up(cursor *sqlparser.Cursor) error { + // this rewriting is done in the `up` phase, because we need the scope to have been + // filled in with the available tables + node, ok := cursor.Node().(*sqlparser.JoinTableExpr) + if !ok || len(node.Condition.Using) == 0 { + return nil + } + + err := rewriteJoinUsing(r.binder, node) + if err != nil { + return err + } + + // since the binder has already been over the join, we need to invoke it again so it + // can bind columns to the right tables + sqlparser.Rewrite(node.Condition.On, nil, func(cursor *sqlparser.Cursor) bool { + innerErr := r.binder.up(cursor) + if innerErr == nil { + return true + } + + err = innerErr + return false + }) + return err +} + func (r *earlyRewriter) expandStar(cursor *sqlparser.Cursor, node sqlparser.SelectExprs) error { currentScope := r.scoper.currentScope() var selExprs sqlparser.SelectExprs @@ -279,67 +306,144 @@ func rewriteOrFalse(orExpr sqlparser.OrExpr) sqlparser.Expr { return nil } -func rewriteJoinUsing( - current *scope, - using sqlparser.Columns, - org originable, -) error { - joinUsing := current.prepareUsingMap() - predicates := make([]sqlparser.Expr, 0, len(using)) - for _, column := range using { - var foundTables []sqlparser.TableName - for _, tbl := range current.tables { - if !tbl.authoritative() { - return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "can't handle JOIN USING without authoritative tables") - } +// rewriteJoinUsing rewrites SQL JOINs that use the USING clause to their equivalent +// JOINs with the ON condition. This function finds all the tables that have the +// specified columns in the USING clause, constructs an equality predicate for +// each pair of tables, and adds the resulting predicates to the WHERE clause +// of the outermost SELECT statement. +// +// For example, given the query: +// +// SELECT * FROM t1 JOIN t2 USING (col1, col2) +// +// The rewriteJoinUsing function will rewrite the query to: +// +// SELECT * FROM t1 JOIN t2 ON (t1.col1 = t2.col1 AND t1.col2 = t2.col2) +// +// This function returns an error if it encounters a non-authoritative table or +// if it cannot find a SELECT statement to add the WHERE predicate to. +func rewriteJoinUsing(b *binder, join *sqlparser.JoinTableExpr) error { + predicates, err := buildJoinPredicates(b, join) + if err != nil { + return err + } + if len(predicates) > 0 { + join.Condition.On = sqlparser.AndExpressions(predicates...) + join.Condition.Using = nil + } + return nil +} - currTable := tbl.getTableSet(org) - usingCols := joinUsing[currTable] - if usingCols == nil { - usingCols = map[string]TableSet{} - } - for _, col := range tbl.getColumns() { - _, found := usingCols[strings.ToLower(col.Name)] - if found { - tblName, err := tbl.Name() - if err != nil { - return err - } +// buildJoinPredicates constructs the join predicates for a given set of USING columns. +// It returns a slice of sqlparser.Expr, each representing a join predicate for the given columns. +func buildJoinPredicates(b *binder, join *sqlparser.JoinTableExpr) ([]sqlparser.Expr, error) { + var predicates []sqlparser.Expr - foundTables = append(foundTables, tblName) - break // no need to look at other columns in this table - } + for _, column := range join.Condition.Using { + foundTables, err := findTablesWithColumn(b, join, column) + if err != nil { + return nil, err + } + + predicates = append(predicates, createComparisonPredicates(column, foundTables)...) + } + + return predicates, nil +} + +func findOnlyOneTableInfoThatHasColumn(b *binder, tbl sqlparser.TableExpr, column sqlparser.IdentifierCI) ([]TableInfo, error) { + switch tbl := tbl.(type) { + case *sqlparser.AliasedTableExpr: + ts := b.tc.tableSetFor(tbl) + tblInfo := b.tc.Tables[ts.TableOffset()] + for _, info := range tblInfo.getColumns() { + if column.EqualString(info.Name) { + return []TableInfo{tblInfo}, nil } } - for i, lft := range foundTables { - for j := i + 1; j < len(foundTables); j++ { - rgt := foundTables[j] - predicates = append(predicates, &sqlparser.ComparisonExpr{ - Operator: sqlparser.EqualOp, - Left: sqlparser.NewColNameWithQualifier(column.String(), lft), - Right: sqlparser.NewColNameWithQualifier(column.String(), rgt), - }) + return nil, nil + case *sqlparser.JoinTableExpr: + tblInfoR, err := findOnlyOneTableInfoThatHasColumn(b, tbl.RightExpr, column) + if err != nil { + return nil, err + } + tblInfoL, err := findOnlyOneTableInfoThatHasColumn(b, tbl.LeftExpr, column) + if err != nil { + return nil, err + } + + return append(tblInfoL, tblInfoR...), nil + case *sqlparser.ParenTableExpr: + var tblInfo []TableInfo + for _, parenTable := range tbl.Exprs { + newTblInfo, err := findOnlyOneTableInfoThatHasColumn(b, parenTable, column) + if err != nil { + return nil, err + } + if tblInfo != nil && newTblInfo != nil { + return nil, vterrors.VT03021(column.String()) + } + if newTblInfo != nil { + tblInfo = newTblInfo } } + return tblInfo, nil + default: + panic(fmt.Sprintf("unsupported TableExpr type in JOIN: %T", tbl)) } +} - // now, we go up the scope until we find a SELECT with a where clause we can add this predicate to - for current != nil { - sel, found := current.stmt.(*sqlparser.Select) - if found { - if sel.Where == nil { - sel.Where = &sqlparser.Where{ - Type: sqlparser.WhereClause, - Expr: sqlparser.AndExpressions(predicates...), - } - } else { - sel.Where.Expr = sqlparser.AndExpressions(append(predicates, sel.Where.Expr)...) - } - return nil +// findTablesWithColumn finds the tables with the specified column in the current scope. +func findTablesWithColumn(b *binder, join *sqlparser.JoinTableExpr, column sqlparser.IdentifierCI) ([]sqlparser.TableName, error) { + leftTableInfo, err := findOnlyOneTableInfoThatHasColumn(b, join.LeftExpr, column) + if err != nil { + return nil, err + } + + rightTableInfo, err := findOnlyOneTableInfoThatHasColumn(b, join.RightExpr, column) + if err != nil { + return nil, err + } + + if leftTableInfo == nil || rightTableInfo == nil { + return nil, ShardedError{Inner: vterrors.VT09015()} + } + var tableNames []sqlparser.TableName + for _, info := range leftTableInfo { + nm, err := info.Name() + if err != nil { + return nil, err + } + tableNames = append(tableNames, nm) + } + for _, info := range rightTableInfo { + nm, err := info.Name() + if err != nil { + return nil, err + } + tableNames = append(tableNames, nm) + } + return tableNames, nil +} + +// createComparisonPredicates creates a list of comparison predicates between the given column and foundTables. +func createComparisonPredicates(column sqlparser.IdentifierCI, foundTables []sqlparser.TableName) []sqlparser.Expr { + var predicates []sqlparser.Expr + for i, lft := range foundTables { + for j := i + 1; j < len(foundTables); j++ { + rgt := foundTables[j] + predicates = append(predicates, createComparisonBetween(column, lft, rgt)) } - current = current.parent } - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "did not find WHERE clause") + return predicates +} + +func createComparisonBetween(column sqlparser.IdentifierCI, lft, rgt sqlparser.TableName) *sqlparser.ComparisonExpr { + return &sqlparser.ComparisonExpr{ + Operator: sqlparser.EqualOp, + Left: sqlparser.NewColNameWithQualifier(column.String(), lft), + Right: sqlparser.NewColNameWithQualifier(column.String(), rgt), + } } func (r *earlyRewriter) expandTableColumns( diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index 324172ab72e..3819d766100 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -143,24 +143,30 @@ func TestExpandStar(t *testing.T) { }, { sql: "select * from t1 join t2 on t1.a = t2.c1", expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1 join t2 on t1.a = t2.c1", + }, { + sql: "select * from t1 left join t2 on t1.a = t2.c1", + expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1 left join t2 on t1.a = t2.c1", + }, { + sql: "select * from t1 right join t2 on t1.a = t2.c1", + expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1 right join t2 on t1.a = t2.c1", }, { sql: "select * from t2 join t4 using (c1)", - expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4 from t2 join t4 where t2.c1 = t4.c1", + expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4 from t2 join t4 on t2.c1 = t4.c1", }, { sql: "select * from t2 join t4 using (c1) join t2 as X using (c1)", - expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, X.c2 as c2 from t2 join t4 join t2 as X where t2.c1 = t4.c1 and t2.c1 = X.c1 and t4.c1 = X.c1", + expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, X.c2 as c2 from t2 join t4 on t2.c1 = t4.c1 join t2 as X on t2.c1 = t4.c1 and t2.c1 = X.c1 and t4.c1 = X.c1", }, { sql: "select * from t2 join t4 using (c1), t2 as t2b join t4 as t4b using (c1)", - expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, t2b.c1 as c1, t2b.c2 as c2, t4b.c4 as c4 from t2 join t4, t2 as t2b join t4 as t4b where t2b.c1 = t4b.c1 and t2.c1 = t4.c1", + expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, t2b.c1 as c1, t2b.c2 as c2, t4b.c4 as c4 from t2 join t4 on t2.c1 = t4.c1, t2 as t2b join t4 as t4b on t2b.c1 = t4b.c1", }, { sql: "select * from t1 join t5 using (b)", - expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 where t1.b = t5.b", + expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 on t1.b = t5.b", }, { sql: "select * from t1 join t5 using (b) having b = 12", - expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 where t1.b = t5.b having b = 12", + expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 on t1.b = t5.b having b = 12", }, { sql: "select 1 from t1 join t5 using (b) having b = 12", - expSQL: "select 1 from t1 join t5 where t1.b = t5.b having t1.b = 12", + expSQL: "select 1 from t1 join t5 on t1.b = t5.b having t1.b = 12", }, { sql: "select * from (select 12) as t", expSQL: "select t.`12` from (select 12 from dual) as t", @@ -269,13 +275,16 @@ func TestRewriteJoinUsingColumns(t *testing.T) { expErr string }{{ sql: "select 1 from t1 join t2 using (a) where a = 42", - expSQL: "select 1 from t1 join t2 where t1.a = t2.a and t1.a = 42", + expSQL: "select 1 from t1 join t2 on t1.a = t2.a where t1.a = 42", }, { sql: "select 1 from t1 join t2 using (a), t3 where a = 42", expErr: "Column 'a' in field list is ambiguous", }, { sql: "select 1 from t1 join t2 using (a), t1 as b join t3 on (a) where a = 42", expErr: "Column 'a' in field list is ambiguous", + }, { + sql: "select 1 from t1 left join t2 using (a) where a = 42", + expSQL: "select 1 from t1 left join t2 on t1.a = t2.a where t1.a = 42", }} for _, tcase := range tcases { t.Run(tcase.sql, func(t *testing.T) {