From 3b20a40c6d6b0abd2a356fd3a80bef5778595a9b Mon Sep 17 00:00:00 2001 From: "vitess-bot[bot]" <108069721+vitess-bot[bot]@users.noreply.github.com> Date: Fri, 22 Sep 2023 16:09:54 -0400 Subject: [PATCH] [release-17.0] Rewrite `USING` to `ON` condition for joins (#13931) (#13941) Signed-off-by: Florent Poinsard Co-authored-by: Florent Poinsard <35779988+frouioui@users.noreply.github.com> Co-authored-by: Andres Taylor Co-authored-by: Florent Poinsard --- go/test/endtoend/cluster/cluster_process.go | 16 +- go/test/endtoend/cluster/vttablet_process.go | 15 +- .../tabletmanager/custom_rule_topo_test.go | 2 +- .../tabletmanager/primary/tablet_test.go | 2 +- .../tabletmanager/tablet_health_test.go | 8 +- .../tablet_security_policy_test.go | 6 +- go/test/endtoend/tabletmanager/tablet_test.go | 4 +- .../vtgate/queries/aggregation/fuzz_test.go | 10 +- .../endtoend/vtgate/queries/misc/misc_test.go | 8 + go/vt/schemadiff/schema_test.go | 2 +- go/vt/vterrors/code.go | 2 + .../planbuilder/testdata/from_cases.json | 22 +++ .../testdata/unsupported_cases.json | 4 +- go/vt/vtgate/semantics/analyzer.go | 5 + go/vt/vtgate/semantics/binder.go | 7 - go/vt/vtgate/semantics/early_rewriter.go | 151 ++++++++++++------ go/vt/vtgate/semantics/early_rewriter_test.go | 23 ++- 17 files changed, 194 insertions(+), 93 deletions(-) diff --git a/go/test/endtoend/cluster/cluster_process.go b/go/test/endtoend/cluster/cluster_process.go index dfcfcbc4947..075bec7c868 100644 --- a/go/test/endtoend/cluster/cluster_process.go +++ b/go/test/endtoend/cluster/cluster_process.go @@ -334,13 +334,12 @@ func (cluster *LocalProcessCluster) startKeyspace(keyspace Keyspace, shardNames } // Create the keyspace if it doesn't already exist. _ = cluster.VtctlProcess.CreateKeyspace(keyspace.Name, keyspace.SidecarDBName) - var mysqlctlProcessList []*exec.Cmd for _, shardName := range shardNames { shard := &Shard{ Name: shardName, } log.Infof("Starting shard: %v", shardName) - mysqlctlProcessList = []*exec.Cmd{} + var mysqlctlProcessList []*exec.Cmd for i := 0; i < totalTabletsRequired; i++ { // instantiate vttablet object with reserved ports tabletUID := cluster.GetAndReserveTabletUID() @@ -1276,8 +1275,16 @@ func (cluster *LocalProcessCluster) VtprocessInstanceFromVttablet(tablet *Vttabl } // StartVttablet starts a new tablet -func (cluster *LocalProcessCluster) StartVttablet(tablet *Vttablet, servingStatus string, - supportBackup bool, cell string, keyspaceName string, hostname string, shardName string) error { +func (cluster *LocalProcessCluster) StartVttablet( + tablet *Vttablet, + explicitServingStatus bool, + servingStatus string, + supportBackup bool, + cell string, + keyspaceName string, + hostname string, + shardName string, +) error { tablet.VttabletProcess = VttabletProcessInstance( tablet.HTTPPort, tablet.GrpcPort, @@ -1295,6 +1302,7 @@ func (cluster *LocalProcessCluster) StartVttablet(tablet *Vttablet, servingStatu tablet.VttabletProcess.SupportsBackup = supportBackup tablet.VttabletProcess.ServingStatus = servingStatus + tablet.VttabletProcess.ExplicitServingStatus = explicitServingStatus return tablet.VttabletProcess.Setup() } diff --git a/go/test/endtoend/cluster/vttablet_process.go b/go/test/endtoend/cluster/vttablet_process.go index 522c205114a..96d6dd04ef0 100644 --- a/go/test/endtoend/cluster/vttablet_process.go +++ b/go/test/endtoend/cluster/vttablet_process.go @@ -42,7 +42,7 @@ import ( "vitess.io/vitess/go/vt/sqlparser" ) -const vttabletStateTimeout = 30 * time.Second +const vttabletStateTimeout = 60 * time.Second // VttabletProcess is a generic handle for a running vttablet . // It can be spawned manually @@ -71,6 +71,7 @@ type VttabletProcess struct { QueryzURL string StatusDetailsURL string SupportsBackup bool + ExplicitServingStatus bool ServingStatus string DbPassword string DbPort int @@ -79,7 +80,7 @@ type VttabletProcess struct { Charset string ConsolidationsURL string - //Extra Args to be set before starting the vttablet process + // Extra Args to be set before starting the vttablet process ExtraArgs []string proc *exec.Cmd @@ -149,7 +150,15 @@ func (vttablet *VttabletProcess) Setup() (err error) { }() if vttablet.ServingStatus != "" { - if err = vttablet.WaitForTabletStatus(vttablet.ServingStatus); err != nil { + // If the tablet has an explicit serving status we use the serving status + // otherwise we wait for any serving status to show up in the healthcheck. + var servingStatus []string + if vttablet.ExplicitServingStatus { + servingStatus = append(servingStatus, vttablet.ServingStatus) + } else { + servingStatus = append(servingStatus, "SERVING", "NOT_SERVING") + } + if err = vttablet.WaitForTabletStatuses(servingStatus); err != nil { errFileContent, _ := os.ReadFile(fname) if errFileContent != nil { log.Infof("vttablet error:\n%s\n", string(errFileContent)) diff --git a/go/test/endtoend/tabletmanager/custom_rule_topo_test.go b/go/test/endtoend/tabletmanager/custom_rule_topo_test.go index fb6a64efef3..aa09a99e0fe 100644 --- a/go/test/endtoend/tabletmanager/custom_rule_topo_test.go +++ b/go/test/endtoend/tabletmanager/custom_rule_topo_test.go @@ -71,7 +71,7 @@ func TestTopoCustomRule(t *testing.T) { require.Nil(t, err, "error should be Nil") // Start Vttablet - err = clusterInstance.StartVttablet(rTablet, "SERVING", false, cell, keyspaceName, hostname, shardName) + err = clusterInstance.StartVttablet(rTablet, false, "SERVING", false, cell, keyspaceName, hostname, shardName) require.Nil(t, err, "error should be Nil") err = clusterInstance.VtctlclientProcess.ExecuteCommand("Validate") diff --git a/go/test/endtoend/tabletmanager/primary/tablet_test.go b/go/test/endtoend/tabletmanager/primary/tablet_test.go index 3db692694b5..28b238883ab 100644 --- a/go/test/endtoend/tabletmanager/primary/tablet_test.go +++ b/go/test/endtoend/tabletmanager/primary/tablet_test.go @@ -189,7 +189,7 @@ func TestPrimaryRestartSetsTERTimestamp(t *testing.T) { require.NoError(t, err) // Start Vttablet - err = clusterInstance.StartVttablet(&replicaTablet, "SERVING", false, cell, keyspaceName, hostname, shardName) + err = clusterInstance.StartVttablet(&replicaTablet, false, "SERVING", false, cell, keyspaceName, hostname, shardName) require.NoError(t, err) // Make sure that the TER did not change diff --git a/go/test/endtoend/tabletmanager/tablet_health_test.go b/go/test/endtoend/tabletmanager/tablet_health_test.go index 17017e8b807..8e70af9f566 100644 --- a/go/test/endtoend/tabletmanager/tablet_health_test.go +++ b/go/test/endtoend/tabletmanager/tablet_health_test.go @@ -73,7 +73,7 @@ func TestTabletReshuffle(t *testing.T) { // SupportsBackup=False prevents vttablet from trying to restore // Start vttablet process - err = clusterInstance.StartVttablet(rTablet, "SERVING", false, cell, keyspaceName, hostname, shardName) + err = clusterInstance.StartVttablet(rTablet, false, "SERVING", false, cell, keyspaceName, hostname, shardName) require.NoError(t, err) sql := "select value from t1" @@ -106,7 +106,7 @@ func TestHealthCheck(t *testing.T) { defer replicaConn.Close() // start vttablet process, should be in SERVING state as we already have a primary - err = clusterInstance.StartVttablet(rTablet, "SERVING", false, cell, keyspaceName, hostname, shardName) + err = clusterInstance.StartVttablet(rTablet, true, "SERVING", false, cell, keyspaceName, hostname, shardName) require.NoError(t, err) conn, err := mysql.Connect(ctx, &primaryTabletParams) @@ -227,7 +227,7 @@ func TestHealthCheckSchemaChangeSignal(t *testing.T) { clusterInstance.VtTabletExtraArgs = oldArgs }() // start vttablet process, should be in SERVING state as we already have a primary. - err = clusterInstance.StartVttablet(tempTablet, "SERVING", false, cell, keyspaceName, hostname, shardName) + err = clusterInstance.StartVttablet(tempTablet, false, "SERVING", false, cell, keyspaceName, hostname, shardName) require.NoError(t, err) defer func() { @@ -381,7 +381,7 @@ func TestHealthCheckDrainedStateDoesNotShutdownQueryService(t *testing.T) { // - the second tablet will be set to 'drained' and we expect that // - the query service won't be shutdown - //Wait if tablet is not in service state + // Wait if tablet is not in service state defer cluster.PanicHandler(t) clusterInstance.DisableVTOrcRecoveries(t) defer clusterInstance.EnableVTOrcRecoveries(t) diff --git a/go/test/endtoend/tabletmanager/tablet_security_policy_test.go b/go/test/endtoend/tabletmanager/tablet_security_policy_test.go index 2ad907ec7b8..b3b11405abb 100644 --- a/go/test/endtoend/tabletmanager/tablet_security_policy_test.go +++ b/go/test/endtoend/tabletmanager/tablet_security_policy_test.go @@ -39,7 +39,7 @@ func TestFallbackSecurityPolicy(t *testing.T) { // Requesting an unregistered security_policy should fallback to deny-all. clusterInstance.VtTabletExtraArgs = []string{"--security_policy", "bogus"} - err = clusterInstance.StartVttablet(mTablet, "SERVING", false, cell, keyspaceName, hostname, shardName) + err = clusterInstance.StartVttablet(mTablet, false, "SERVING", false, cell, keyspaceName, hostname, shardName) require.NoError(t, err) // It should deny ADMIN role. @@ -94,7 +94,7 @@ func TestDenyAllSecurityPolicy(t *testing.T) { // Requesting a deny-all security_policy. clusterInstance.VtTabletExtraArgs = []string{"--security_policy", "deny-all"} - err = clusterInstance.StartVttablet(mTablet, "SERVING", false, cell, keyspaceName, hostname, shardName) + err = clusterInstance.StartVttablet(mTablet, false, "SERVING", false, cell, keyspaceName, hostname, shardName) require.NoError(t, err) // It should deny ADMIN role. @@ -126,7 +126,7 @@ func TestReadOnlySecurityPolicy(t *testing.T) { // Requesting a read-only security_policy. clusterInstance.VtTabletExtraArgs = []string{"--security_policy", "read-only"} - err = clusterInstance.StartVttablet(mTablet, "SERVING", false, cell, keyspaceName, hostname, shardName) + err = clusterInstance.StartVttablet(mTablet, false, "SERVING", false, cell, keyspaceName, hostname, shardName) require.NoError(t, err) // It should deny ADMIN role. diff --git a/go/test/endtoend/tabletmanager/tablet_test.go b/go/test/endtoend/tabletmanager/tablet_test.go index 97715d39a58..4fe5a70d125 100644 --- a/go/test/endtoend/tabletmanager/tablet_test.go +++ b/go/test/endtoend/tabletmanager/tablet_test.go @@ -43,7 +43,7 @@ func TestEnsureDB(t *testing.T) { log.Info(fmt.Sprintf("Started vttablet %v", tablet)) // Start vttablet process as replica. It won't be able to serve because there's no db. - err = clusterInstance.StartVttablet(tablet, "NOT_SERVING", false, cell, "dbtest", hostname, "0") + err = clusterInstance.StartVttablet(tablet, false, "NOT_SERVING", false, cell, "dbtest", hostname, "0") require.NoError(t, err) // Make it the primary. @@ -78,7 +78,7 @@ func TestResetReplicationParameters(t *testing.T) { log.Info(fmt.Sprintf("Started vttablet %v", tablet)) // Start vttablet process as replica. It won't be able to serve because there's no db. - err = clusterInstance.StartVttablet(tablet, "NOT_SERVING", false, cell, "dbtest", hostname, "0") + err = clusterInstance.StartVttablet(tablet, false, "NOT_SERVING", false, cell, "dbtest", hostname, "0") require.NoError(t, err) // Set a replication source on the tablet and start replication diff --git a/go/test/endtoend/vtgate/queries/aggregation/fuzz_test.go b/go/test/endtoend/vtgate/queries/aggregation/fuzz_test.go index 25bec1a39b4..985116c488b 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/fuzz_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/fuzz_test.go @@ -44,7 +44,7 @@ func TestFuzzAggregations(t *testing.T) { mcmp, closer := start(t) defer closer() - noOfRows := rand.Intn(20) + noOfRows := rand.Intn(20) + 1 var values []string for i := 0; i < noOfRows; i++ { values = append(values, fmt.Sprintf("(%d, 'name%d', 'value%d', %d)", i, i, i, i)) @@ -160,10 +160,10 @@ func createAggregations(tables []tableT, maxAggrs int, randomCol func(tblIdx int aggregations := []func(string) string{ func(_ string) string { return "count(*)" }, func(e string) string { return fmt.Sprintf("count(%s)", e) }, - //func(e string) string { return fmt.Sprintf("sum(%s)", e) }, - //func(e string) string { return fmt.Sprintf("avg(%s)", e) }, - //func(e string) string { return fmt.Sprintf("min(%s)", e) }, - //func(e string) string { return fmt.Sprintf("max(%s)", e) }, + // func(e string) string { return fmt.Sprintf("sum(%s)", e) }, + // func(e string) string { return fmt.Sprintf("avg(%s)", e) }, + // func(e string) string { return fmt.Sprintf("min(%s)", e) }, + // func(e string) string { return fmt.Sprintf("max(%s)", e) }, } noOfAggrs := rand.Intn(maxAggrs) + 1 diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index afe30e9024d..df6f69f97be 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -244,3 +244,11 @@ func TestBuggyOuterJoin(t *testing.T) { mcmp.Exec("select t1.id1, t2.id1 from t1 left join t1 as t2 on t2.id1 = t2.id2") } + +func TestLeftJoinUsingUnsharded(t *testing.T) { + mcmp, closer := start(t) + defer closer() + + utils.Exec(t, mcmp.VtConn, "insert /*vt+ QUERY_TIMEOUT_MS=2000 */ into uks.unsharded(id1) values (1),(2),(3),(4),(5)") + utils.Exec(t, mcmp.VtConn, "select /*vt+ QUERY_TIMEOUT_MS=2000 */ * from uks.unsharded as A left join uks.unsharded as B using(id1)") +} diff --git a/go/vt/schemadiff/schema_test.go b/go/vt/schemadiff/schema_test.go index 1fe6b0f2d86..89543531ec6 100644 --- a/go/vt/schemadiff/schema_test.go +++ b/go/vt/schemadiff/schema_test.go @@ -628,7 +628,7 @@ func TestViewReferences(t *testing.T) { "create table t2(id int primary key, n int, info int)", "create view v1 as select id, c as ch from t1 where id > 0", "create view v2 as select n as num, info from t2", - "create view v3 as select num, v1.id, ch from v1 join v2 using (id) where info > 5", + "create view v3 as select num, v1.id, ch from v1 join v2 on v1.id = v2.num where info > 5", }, }, { diff --git a/go/vt/vterrors/code.go b/go/vt/vterrors/code.go index 26abd85e49e..2aa9827c084 100644 --- a/go/vt/vterrors/code.go +++ b/go/vt/vterrors/code.go @@ -73,6 +73,7 @@ var ( VT09012 = errorWithoutState("VT09012", vtrpcpb.Code_FAILED_PRECONDITION, "%s statement with %s tablet not allowed", "This type of statement is not allowed on the given tablet.") VT09013 = errorWithoutState("VT09013", vtrpcpb.Code_FAILED_PRECONDITION, "semi-sync plugins are not loaded", "Durability policy wants Vitess to use semi-sync, but the MySQL instances don't have the semi-sync plugin loaded.") VT09014 = errorWithoutState("VT09014", vtrpcpb.Code_FAILED_PRECONDITION, "vindex cannot be modified", "The vindex cannot be used as table in DML statement") + VT09015 = errorWithoutState("VT09015", vtrpcpb.Code_FAILED_PRECONDITION, "schema tracking required", "This query cannot be planned without more information on the SQL schema. Please turn on schema tracking or add authoritative columns information to your VSchema.") VT10001 = errorWithoutState("VT10001", vtrpcpb.Code_ABORTED, "foreign key constraints are not allowed", "Foreign key constraints are not allowed, see https://vitess.io/blog/2021-06-15-online-ddl-why-no-fk/.") @@ -136,6 +137,7 @@ var ( VT09012, VT09013, VT09014, + VT09015, VT10001, VT12001, VT13001, diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index c776443e490..cb233af88b8 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -6533,5 +6533,27 @@ "user.multicol_tbl" ] } + }, + { + "comment": "left join with using has to be transformed into inner join with on condition", + "query": "SELECT * FROM unsharded_authoritative as A LEFT JOIN unsharded_authoritative as B USING(col1)", + "plan": { + "QueryType": "SELECT", + "Original": "SELECT * FROM unsharded_authoritative as A LEFT JOIN unsharded_authoritative as B USING(col1)", + "Instructions": { + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select A.col1 as col1, A.col2 as col2, B.col2 as col2 from unsharded_authoritative as A left join unsharded_authoritative as B on A.col1 = B.col1 where 1 != 1", + "Query": "select A.col1 as col1, A.col2 as col2, B.col2 as col2 from unsharded_authoritative as A left join unsharded_authoritative as B on A.col1 = B.col1", + "Table": "unsharded_authoritative" + }, + "TablesUsed": [ + "main.unsharded_authoritative" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index 3676c09ead9..67af9a3f9fa 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -39,13 +39,13 @@ "comment": "join with USING construct", "query": "select * from user join user_extra using(id)", "v3-plan": "VT12001: unsupported: JOIN with USING(column_list) clause for complex queries", - "gen4-plan": "can't handle JOIN USING without authoritative tables" + "gen4-plan": "VT09015: schema tracking required" }, { "comment": "join with USING construct with 3 tables", "query": "select user.id from user join user_extra using(id) join music using(id2)", "v3-plan": "VT12001: unsupported: JOIN with USING(column_list) clause for complex queries", - "gen4-plan": "can't handle JOIN USING without authoritative tables" + "gen4-plan": "VT09015: schema tracking required" }, { "comment": "natural left join", diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index d08212d0ad0..2a6fb2464d4 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -186,6 +186,11 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool { return false } + if err := a.rewriter.up(cursor); err != nil { + a.setError(err) + return true + } + a.leaveProjection(cursor) return a.shouldContinue() } diff --git a/go/vt/vtgate/semantics/binder.go b/go/vt/vtgate/semantics/binder.go index b9239fae69f..c84e432a561 100644 --- a/go/vt/vtgate/semantics/binder.go +++ b/go/vt/vtgate/semantics/binder.go @@ -84,13 +84,6 @@ func (b *binder) up(cursor *sqlparser.Cursor) error { } currScope.joinUsing[ident.Lowered()] = deps.direct } - if len(node.Using) > 0 { - err := rewriteJoinUsing(currScope, node.Using, b.org) - if err != nil { - return err - } - node.Using = nil - } case *sqlparser.ColName: currentScope := b.scoper.currentScope() deps, err := b.resolveColumn(node, currentScope, false) diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index b3553a2de73..77ec5a775e0 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -17,8 +17,8 @@ limitations under the License. package semantics import ( + "fmt" "strconv" - "strings" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/vt/vtgate/evalengine" @@ -60,6 +60,33 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { return nil } +func (r *earlyRewriter) up(cursor *sqlparser.Cursor) error { + // this rewriting is done in the `up` phase, because we need the scope to have been + // filled in with the available tables + node, ok := cursor.Node().(*sqlparser.JoinTableExpr) + if !ok || len(node.Condition.Using) == 0 { + return nil + } + + err := rewriteJoinUsing(r.binder, node) + if err != nil { + return err + } + + // since the binder has already been over the join, we need to invoke it again so it + // can bind columns to the right tables + sqlparser.Rewrite(node.Condition.On, nil, func(cursor *sqlparser.Cursor) bool { + innerErr := r.binder.up(cursor) + if innerErr == nil { + return true + } + + err = innerErr + return false + }) + return err +} + // handleWhereClause processes WHERE clauses, specifically the HAVING clause. func handleWhereClause(node *sqlparser.Where, parent sqlparser.SQLNode) { if node.Type != sqlparser.HavingClause { @@ -344,44 +371,25 @@ func rewriteOrFalse(orExpr sqlparser.OrExpr) sqlparser.Expr { // // This function returns an error if it encounters a non-authoritative table or // if it cannot find a SELECT statement to add the WHERE predicate to. -func rewriteJoinUsing( - current *scope, - using sqlparser.Columns, - org originable, -) error { - predicates, err := buildJoinPredicates(current, using, org) +func rewriteJoinUsing(b *binder, join *sqlparser.JoinTableExpr) error { + predicates, err := buildJoinPredicates(b, join) if err != nil { return err } - // now, we go up the scope until we find a SELECT - // with a where clause we can add this predicate to - for current != nil { - sel, found := current.stmt.(*sqlparser.Select) - if !found { - current = current.parent - continue - } - if sel.Where != nil { - predicates = append(predicates, sel.Where.Expr) - sel.Where = nil - } - sel.Where = &sqlparser.Where{ - Type: sqlparser.WhereClause, - Expr: sqlparser.AndExpressions(predicates...), - } - return nil + if len(predicates) > 0 { + join.Condition.On = sqlparser.AndExpressions(predicates...) + join.Condition.Using = nil } - return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "did not find WHERE clause") + return nil } // buildJoinPredicates constructs the join predicates for a given set of USING columns. // It returns a slice of sqlparser.Expr, each representing a join predicate for the given columns. -func buildJoinPredicates(current *scope, using sqlparser.Columns, org originable) ([]sqlparser.Expr, error) { - joinUsing := current.prepareUsingMap() +func buildJoinPredicates(b *binder, join *sqlparser.JoinTableExpr) ([]sqlparser.Expr, error) { var predicates []sqlparser.Expr - for _, column := range using { - foundTables, err := findTablesWithColumn(current, joinUsing, org, column) + for _, column := range join.Condition.Using { + foundTables, err := findTablesWithColumn(b, join, column) if err != nil { return nil, err } @@ -392,42 +400,79 @@ func buildJoinPredicates(current *scope, using sqlparser.Columns, org originable return predicates, nil } -// findTablesWithColumn finds the tables with the specified column in the current scope. -func findTablesWithColumn(current *scope, joinUsing map[TableSet]map[string]TableSet, org originable, column sqlparser.IdentifierCI) ([]sqlparser.TableName, error) { - var foundTables []sqlparser.TableName - - for _, tbl := range current.tables { - if !tbl.authoritative() { - return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "can't handle JOIN USING without authoritative tables") +func findOnlyOneTableInfoThatHasColumn(b *binder, tbl sqlparser.TableExpr, column sqlparser.IdentifierCI) ([]TableInfo, error) { + switch tbl := tbl.(type) { + case *sqlparser.AliasedTableExpr: + ts := b.tc.tableSetFor(tbl) + tblInfo := b.tc.Tables[ts.TableOffset()] + for _, info := range tblInfo.getColumns() { + if column.EqualString(info.Name) { + return []TableInfo{tblInfo}, nil + } } - - currTable := tbl.getTableSet(org) - usingCols := joinUsing[currTable] - if usingCols == nil { - usingCols = map[string]TableSet{} + return nil, nil + case *sqlparser.JoinTableExpr: + tblInfoR, err := findOnlyOneTableInfoThatHasColumn(b, tbl.RightExpr, column) + if err != nil { + return nil, err + } + tblInfoL, err := findOnlyOneTableInfoThatHasColumn(b, tbl.LeftExpr, column) + if err != nil { + return nil, err } - if hasColumnInTable(tbl, usingCols) { - tblName, err := tbl.Name() + return append(tblInfoL, tblInfoR...), nil + case *sqlparser.ParenTableExpr: + var tblInfo []TableInfo + for _, parenTable := range tbl.Exprs { + newTblInfo, err := findOnlyOneTableInfoThatHasColumn(b, parenTable, column) if err != nil { return nil, err } - foundTables = append(foundTables, tblName) + if tblInfo != nil && newTblInfo != nil { + return nil, vterrors.VT03021(column.String()) + } + if newTblInfo != nil { + tblInfo = newTblInfo + } } + return tblInfo, nil + default: + panic(fmt.Sprintf("unsupported TableExpr type in JOIN: %T", tbl)) } - - return foundTables, nil } -// hasColumnInTable checks if the specified table has the given column. -func hasColumnInTable(tbl TableInfo, usingCols map[string]TableSet) bool { - for _, col := range tbl.getColumns() { - _, found := usingCols[strings.ToLower(col.Name)] - if found { - return true +// findTablesWithColumn finds the tables with the specified column in the current scope. +func findTablesWithColumn(b *binder, join *sqlparser.JoinTableExpr, column sqlparser.IdentifierCI) ([]sqlparser.TableName, error) { + leftTableInfo, err := findOnlyOneTableInfoThatHasColumn(b, join.LeftExpr, column) + if err != nil { + return nil, err + } + + rightTableInfo, err := findOnlyOneTableInfoThatHasColumn(b, join.RightExpr, column) + if err != nil { + return nil, err + } + + if leftTableInfo == nil || rightTableInfo == nil { + return nil, ShardedError{Inner: vterrors.VT09015()} + } + var tableNames []sqlparser.TableName + for _, info := range leftTableInfo { + nm, err := info.Name() + if err != nil { + return nil, err + } + tableNames = append(tableNames, nm) + } + for _, info := range rightTableInfo { + nm, err := info.Name() + if err != nil { + return nil, err } + tableNames = append(tableNames, nm) } - return false + return tableNames, nil } // createComparisonPredicates creates a list of comparison predicates between the given column and foundTables. diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index f1b16853cfc..2846bfd9366 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -144,26 +144,32 @@ func TestExpandStar(t *testing.T) { }, { sql: "select * from t1 join t2 on t1.a = t2.c1", expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1 join t2 on t1.a = t2.c1", + }, { + sql: "select * from t1 left join t2 on t1.a = t2.c1", + expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1 left join t2 on t1.a = t2.c1", + }, { + sql: "select * from t1 right join t2 on t1.a = t2.c1", + expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1 right join t2 on t1.a = t2.c1", }, { sql: "select * from t2 join t4 using (c1)", - expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4 from t2 join t4 where t2.c1 = t4.c1", + expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4 from t2 join t4 on t2.c1 = t4.c1", expanded: "main.t2.c1, main.t2.c2, main.t4.c4", }, { sql: "select * from t2 join t4 using (c1) join t2 as X using (c1)", - expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, X.c2 as c2 from t2 join t4 join t2 as X where t2.c1 = t4.c1 and t2.c1 = X.c1 and t4.c1 = X.c1", + expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, X.c2 as c2 from t2 join t4 on t2.c1 = t4.c1 join t2 as X on t2.c1 = t4.c1 and t2.c1 = X.c1 and t4.c1 = X.c1", }, { sql: "select * from t2 join t4 using (c1), t2 as t2b join t4 as t4b using (c1)", - expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, t2b.c1 as c1, t2b.c2 as c2, t4b.c4 as c4 from t2 join t4, t2 as t2b join t4 as t4b where t2b.c1 = t4b.c1 and t2.c1 = t4.c1", + expSQL: "select t2.c1 as c1, t2.c2 as c2, t4.c4 as c4, t2b.c1 as c1, t2b.c2 as c2, t4b.c4 as c4 from t2 join t4 on t2.c1 = t4.c1, t2 as t2b join t4 as t4b on t2b.c1 = t4b.c1", }, { sql: "select * from t1 join t5 using (b)", - expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 where t1.b = t5.b", + expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 on t1.b = t5.b", expanded: "main.t1.a, main.t1.b, main.t1.c, main.t5.a", }, { sql: "select * from t1 join t5 using (b) having b = 12", - expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 where t1.b = t5.b having b = 12", + expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 on t1.b = t5.b having b = 12", }, { sql: "select 1 from t1 join t5 using (b) having b = 12", - expSQL: "select 1 from t1 join t5 where t1.b = t5.b having t1.b = 12", + expSQL: "select 1 from t1 join t5 on t1.b = t5.b having t1.b = 12", }, { sql: "select * from (select 12) as t", expSQL: "select t.`12` from (select 12 from dual) as t", @@ -265,13 +271,16 @@ func TestRewriteJoinUsingColumns(t *testing.T) { expErr string }{{ sql: "select 1 from t1 join t2 using (a) where a = 42", - expSQL: "select 1 from t1 join t2 where t1.a = t2.a and t1.a = 42", + expSQL: "select 1 from t1 join t2 on t1.a = t2.a where t1.a = 42", }, { sql: "select 1 from t1 join t2 using (a), t3 where a = 42", expErr: "Column 'a' in field list is ambiguous", }, { sql: "select 1 from t1 join t2 using (a), t1 as b join t3 on (a) where a = 42", expErr: "Column 'a' in field list is ambiguous", + }, { + sql: "select 1 from t1 left join t2 using (a) where a = 42", + expSQL: "select 1 from t1 left join t2 on t1.a = t2.a where t1.a = 42", }} for _, tcase := range tcases { t.Run(tcase.sql, func(t *testing.T) {