diff --git a/go/test/endtoend/vtgate/foreignkey/fk_fuzz_test.go b/go/test/endtoend/vtgate/foreignkey/fk_fuzz_test.go index 134b9cfa180..f7d7340d6a4 100644 --- a/go/test/endtoend/vtgate/foreignkey/fk_fuzz_test.go +++ b/go/test/endtoend/vtgate/foreignkey/fk_fuzz_test.go @@ -28,9 +28,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/stretchr/testify/require" - "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/test/endtoend/cluster" "vitess.io/vitess/go/test/endtoend/utils" "vitess.io/vitess/go/vt/log" ) @@ -39,6 +37,7 @@ type QueryFormat string const ( SQLQueries QueryFormat = "SQLQueries" + OlapSQLQueries QueryFormat = "OlapSQLQueries" PreparedStatmentQueries QueryFormat = "PreparedStatmentQueries" PreparedStatementPacket QueryFormat = "PreparedStatementPacket" ) @@ -95,7 +94,7 @@ func (fz *fuzzer) generateQuery() []string { val := rand.Intn(fz.insertShare + fz.updateShare + fz.deleteShare) if val < fz.insertShare { switch fz.queryFormat { - case SQLQueries: + case OlapSQLQueries, SQLQueries: return []string{fz.generateInsertDMLQuery()} case PreparedStatmentQueries: return fz.getPreparedInsertQueries() @@ -105,7 +104,7 @@ func (fz *fuzzer) generateQuery() []string { } if val < fz.insertShare+fz.updateShare { switch fz.queryFormat { - case SQLQueries: + case OlapSQLQueries, SQLQueries: return []string{fz.generateUpdateDMLQuery()} case PreparedStatmentQueries: return fz.getPreparedUpdateQueries() @@ -114,7 +113,7 @@ func (fz *fuzzer) generateQuery() []string { } } switch fz.queryFormat { - case SQLQueries: + case OlapSQLQueries, SQLQueries: return []string{fz.generateDeleteDMLQuery()} case PreparedStatmentQueries: return fz.getPreparedDeleteQueries() @@ -131,41 +130,43 @@ func (fz *fuzzer) generateInsertDMLQuery() string { if tableName == "fk_t20" { colValue := rand.Intn(1 + fz.maxValForCol) col2Value := rand.Intn(1 + fz.maxValForCol) - return fmt.Sprintf("insert into %v (id, col, col2) values (%v, %v, %v)", tableName, idValue, convertColValueToString(colValue), convertColValueToString(col2Value)) + return fmt.Sprintf("insert into %v (id, col, col2) values (%v, %v, %v)", tableName, idValue, convertIntValueToString(colValue), convertIntValueToString(col2Value)) } else if isMultiColFkTable(tableName) { colaValue := rand.Intn(1 + fz.maxValForCol) colbValue := rand.Intn(1 + fz.maxValForCol) - return fmt.Sprintf("insert into %v (id, cola, colb) values (%v, %v, %v)", tableName, idValue, convertColValueToString(colaValue), convertColValueToString(colbValue)) + return fmt.Sprintf("insert into %v (id, cola, colb) values (%v, %v, %v)", tableName, idValue, convertIntValueToString(colaValue), convertIntValueToString(colbValue)) } else { colValue := rand.Intn(1 + fz.maxValForCol) - return fmt.Sprintf("insert into %v (id, col) values (%v, %v)", tableName, idValue, convertColValueToString(colValue)) + return fmt.Sprintf("insert into %v (id, col) values (%v, %v)", tableName, idValue, convertIntValueToString(colValue)) } } -// convertColValueToString converts the given value to a string -func convertColValueToString(value int) string { - if value == 0 { - return "NULL" - } - return fmt.Sprintf("%d", value) -} - // generateUpdateDMLQuery generates an UPDATE query from the parameters for the fuzzer. func (fz *fuzzer) generateUpdateDMLQuery() string { tableId := rand.Intn(len(fkTables)) idValue := 1 + rand.Intn(fz.maxValForId) tableName := fkTables[tableId] if tableName == "fk_t20" { - colValue := rand.Intn(1 + fz.maxValForCol) - col2Value := rand.Intn(1 + fz.maxValForCol) - return fmt.Sprintf("update %v set col = %v, col2 = %v where id = %v", tableName, convertColValueToString(colValue), convertColValueToString(col2Value), idValue) + colValue := convertIntValueToString(rand.Intn(1 + fz.maxValForCol)) + col2Value := convertIntValueToString(rand.Intn(1 + fz.maxValForCol)) + return fmt.Sprintf("update %v set col = %v, col2 = %v where id = %v", tableName, colValue, col2Value, idValue) } else if isMultiColFkTable(tableName) { - colaValue := rand.Intn(1 + fz.maxValForCol) - colbValue := rand.Intn(1 + fz.maxValForCol) - return fmt.Sprintf("update %v set cola = %v, colb = %v where id = %v", tableName, convertColValueToString(colaValue), convertColValueToString(colbValue), idValue) + if rand.Intn(2) == 0 { + colaValue := convertIntValueToString(rand.Intn(1 + fz.maxValForCol)) + colbValue := convertIntValueToString(rand.Intn(1 + fz.maxValForCol)) + if fz.concurrency > 1 { + colaValue = fz.generateExpression(rand.Intn(4)+1, "cola", "colb", "id") + colbValue = fz.generateExpression(rand.Intn(4)+1, "cola", "colb", "id") + } + return fmt.Sprintf("update %v set cola = %v, colb = %v where id = %v", tableName, colaValue, colbValue, idValue) + } else { + colValue := fz.generateExpression(rand.Intn(4)+1, "cola", "colb", "id") + colToUpdate := []string{"cola", "colb"}[rand.Intn(2)] + return fmt.Sprintf("update %v set %v = %v where id = %v", tableName, colToUpdate, colValue, idValue) + } } else { - colValue := rand.Intn(1 + fz.maxValForCol) - return fmt.Sprintf("update %v set col = %v where id = %v", tableName, convertColValueToString(colValue), idValue) + colValue := fz.generateExpression(rand.Intn(4)+1, "col", "id") + return fmt.Sprintf("update %v set col = %v where id = %v", tableName, colValue, idValue) } } @@ -222,13 +223,16 @@ func (fz *fuzzer) runFuzzerThread(t *testing.T, sharded bool, fuzzerThreadId int _, _ = vitessDb.Exec("use `uks`") } } + if fz.queryFormat == OlapSQLQueries { + _ = utils.Exec(t, mcmp.VtConn, "set workload = olap") + } for { // If fuzzer thread is marked to be stopped, then we should exit this go routine. if fz.shouldStop.Load() == true { return } switch fz.queryFormat { - case SQLQueries, PreparedStatmentQueries: + case OlapSQLQueries, SQLQueries, PreparedStatmentQueries: if fz.generateAndExecuteStatementQuery(t, mcmp) { return } @@ -338,8 +342,8 @@ func (fz *fuzzer) getPreparedInsertQueries() []string { return []string{ "prepare stmt_insert from 'insert into fk_t20 (id, col, col2) values (?, ?, ?)'", fmt.Sprintf("SET @id = %v", idValue), - fmt.Sprintf("SET @col = %v", convertColValueToString(colValue)), - fmt.Sprintf("SET @col2 = %v", convertColValueToString(col2Value)), + fmt.Sprintf("SET @col = %v", convertIntValueToString(colValue)), + fmt.Sprintf("SET @col2 = %v", convertIntValueToString(col2Value)), "execute stmt_insert using @id, @col, @col2", } } else if isMultiColFkTable(tableName) { @@ -348,8 +352,8 @@ func (fz *fuzzer) getPreparedInsertQueries() []string { return []string{ fmt.Sprintf("prepare stmt_insert from 'insert into %v (id, cola, colb) values (?, ?, ?)'", tableName), fmt.Sprintf("SET @id = %v", idValue), - fmt.Sprintf("SET @cola = %v", convertColValueToString(colaValue)), - fmt.Sprintf("SET @colb = %v", convertColValueToString(colbValue)), + fmt.Sprintf("SET @cola = %v", convertIntValueToString(colaValue)), + fmt.Sprintf("SET @colb = %v", convertIntValueToString(colbValue)), "execute stmt_insert using @id, @cola, @colb", } } else { @@ -357,7 +361,7 @@ func (fz *fuzzer) getPreparedInsertQueries() []string { return []string{ fmt.Sprintf("prepare stmt_insert from 'insert into %v (id, col) values (?, ?)'", tableName), fmt.Sprintf("SET @id = %v", idValue), - fmt.Sprintf("SET @col = %v", convertColValueToString(colValue)), + fmt.Sprintf("SET @col = %v", convertIntValueToString(colValue)), "execute stmt_insert using @id, @col", } } @@ -374,8 +378,8 @@ func (fz *fuzzer) getPreparedUpdateQueries() []string { return []string{ "prepare stmt_update from 'update fk_t20 set col = ?, col2 = ? where id = ?'", fmt.Sprintf("SET @id = %v", idValue), - fmt.Sprintf("SET @col = %v", convertColValueToString(colValue)), - fmt.Sprintf("SET @col2 = %v", convertColValueToString(col2Value)), + fmt.Sprintf("SET @col = %v", convertIntValueToString(colValue)), + fmt.Sprintf("SET @col2 = %v", convertIntValueToString(col2Value)), "execute stmt_update using @col, @col2, @id", } } else if isMultiColFkTable(tableName) { @@ -384,8 +388,8 @@ func (fz *fuzzer) getPreparedUpdateQueries() []string { return []string{ fmt.Sprintf("prepare stmt_update from 'update %v set cola = ?, colb = ? where id = ?'", tableName), fmt.Sprintf("SET @id = %v", idValue), - fmt.Sprintf("SET @cola = %v", convertColValueToString(colaValue)), - fmt.Sprintf("SET @colb = %v", convertColValueToString(colbValue)), + fmt.Sprintf("SET @cola = %v", convertIntValueToString(colaValue)), + fmt.Sprintf("SET @colb = %v", convertIntValueToString(colbValue)), "execute stmt_update using @cola, @colb, @id", } } else { @@ -393,7 +397,7 @@ func (fz *fuzzer) getPreparedUpdateQueries() []string { return []string{ fmt.Sprintf("prepare stmt_update from 'update %v set col = ? where id = ?'", tableName), fmt.Sprintf("SET @id = %v", idValue), - fmt.Sprintf("SET @col = %v", convertColValueToString(colValue)), + fmt.Sprintf("SET @col = %v", convertIntValueToString(colValue)), "execute stmt_update using @col, @id", } } @@ -419,14 +423,14 @@ func (fz *fuzzer) generateParameterizedInsertQuery() (query string, params []any if tableName == "fk_t20" { colValue := rand.Intn(1 + fz.maxValForCol) col2Value := rand.Intn(1 + fz.maxValForCol) - return fmt.Sprintf("insert into %v (id, col, col2) values (?, ?, ?)", tableName), []any{idValue, convertColValueToString(colValue), convertColValueToString(col2Value)} + return fmt.Sprintf("insert into %v (id, col, col2) values (?, ?, ?)", tableName), []any{idValue, convertIntValueToString(colValue), convertIntValueToString(col2Value)} } else if isMultiColFkTable(tableName) { colaValue := rand.Intn(1 + fz.maxValForCol) colbValue := rand.Intn(1 + fz.maxValForCol) - return fmt.Sprintf("insert into %v (id, cola, colb) values (?, ?, ?)", tableName), []any{idValue, convertColValueToString(colaValue), convertColValueToString(colbValue)} + return fmt.Sprintf("insert into %v (id, cola, colb) values (?, ?, ?)", tableName), []any{idValue, convertIntValueToString(colaValue), convertIntValueToString(colbValue)} } else { colValue := rand.Intn(1 + fz.maxValForCol) - return fmt.Sprintf("insert into %v (id, col) values (?, ?)", tableName), []any{idValue, convertColValueToString(colValue)} + return fmt.Sprintf("insert into %v (id, col) values (?, ?)", tableName), []any{idValue, convertIntValueToString(colValue)} } } @@ -438,14 +442,14 @@ func (fz *fuzzer) generateParameterizedUpdateQuery() (query string, params []any if tableName == "fk_t20" { colValue := rand.Intn(1 + fz.maxValForCol) col2Value := rand.Intn(1 + fz.maxValForCol) - return fmt.Sprintf("update %v set col = ?, col2 = ? where id = ?", tableName), []any{convertColValueToString(colValue), convertColValueToString(col2Value), idValue} + return fmt.Sprintf("update %v set col = ?, col2 = ? where id = ?", tableName), []any{convertIntValueToString(colValue), convertIntValueToString(col2Value), idValue} } else if isMultiColFkTable(tableName) { colaValue := rand.Intn(1 + fz.maxValForCol) colbValue := rand.Intn(1 + fz.maxValForCol) - return fmt.Sprintf("update %v set cola = ?, colb = ? where id = ?", tableName), []any{convertColValueToString(colaValue), convertColValueToString(colbValue), idValue} + return fmt.Sprintf("update %v set cola = ?, colb = ? where id = ?", tableName), []any{convertIntValueToString(colaValue), convertIntValueToString(colbValue), idValue} } else { colValue := rand.Intn(1 + fz.maxValForCol) - return fmt.Sprintf("update %v set col = ? where id = ?", tableName), []any{convertColValueToString(colValue), idValue} + return fmt.Sprintf("update %v set col = ? where id = ?", tableName), []any{convertIntValueToString(colValue), idValue} } } @@ -606,7 +610,7 @@ func TestFkFuzzTest(t *testing.T) { for _, tt := range testcases { for _, testSharded := range []bool{false, true} { - for _, queryFormat := range []QueryFormat{SQLQueries, PreparedStatmentQueries, PreparedStatementPacket} { + for _, queryFormat := range []QueryFormat{OlapSQLQueries, SQLQueries, PreparedStatmentQueries, PreparedStatementPacket} { t.Run(getTestName(tt.name, testSharded)+fmt.Sprintf(" QueryFormat - %v", queryFormat), func(t *testing.T) { mcmp, closer := start(t) defer closer() @@ -650,85 +654,3 @@ func TestFkFuzzTest(t *testing.T) { } } } - -// ensureDatabaseState ensures that the database is either empty or not. -func ensureDatabaseState(t *testing.T, vtconn *mysql.Conn, empty bool) { - results := collectFkTablesState(vtconn) - isEmpty := true - for _, res := range results { - if len(res.Rows) > 0 { - isEmpty = false - } - } - require.Equal(t, isEmpty, empty) -} - -// verifyDataIsCorrect verifies that the data in MySQL database matches the data in the Vitess database. -func verifyDataIsCorrect(t *testing.T, mcmp utils.MySQLCompare, concurrency int) { - // For single concurrent thread, we run all the queries on both MySQL and Vitess, so we can verify correctness - // by just checking if the data in MySQL and Vitess match. - if concurrency == 1 { - for _, table := range fkTables { - query := fmt.Sprintf("SELECT * FROM %v ORDER BY id", table) - mcmp.Exec(query) - } - } else { - // For higher concurrency, we don't have MySQL data to verify everything is fine, - // so we'll have to do something different. - // We run LEFT JOIN queries on all the parent and child tables linked by foreign keys - // to make sure that nothing is broken in the database. - for _, reference := range fkReferences { - query := fmt.Sprintf("select %v.id from %v left join %v on (%v.col = %v.col) where %v.col is null and %v.col is not null", reference.childTable, reference.childTable, reference.parentTable, reference.parentTable, reference.childTable, reference.parentTable, reference.childTable) - if isMultiColFkTable(reference.childTable) { - query = fmt.Sprintf("select %v.id from %v left join %v on (%v.cola = %v.cola and %v.colb = %v.colb) where %v.cola is null and %v.cola is not null and %v.colb is not null", reference.childTable, reference.childTable, reference.parentTable, reference.parentTable, reference.childTable, reference.parentTable, reference.childTable, reference.parentTable, reference.childTable, reference.childTable) - } - res, err := mcmp.VtConn.ExecuteFetch(query, 1000, false) - require.NoError(t, err) - require.Zerof(t, len(res.Rows), "Query %v gave non-empty results", query) - } - } - // We also verify that the results in Primary and Replica table match as is. - for _, keyspace := range clusterInstance.Keyspaces { - for _, shard := range keyspace.Shards { - var primaryTab, replicaTab *cluster.Vttablet - for _, vttablet := range shard.Vttablets { - if vttablet.Type == "primary" { - primaryTab = vttablet - } else { - replicaTab = vttablet - } - } - require.NotNil(t, primaryTab) - require.NotNil(t, replicaTab) - checkReplicationHealthy(t, replicaTab) - cluster.WaitForReplicationPos(t, primaryTab, replicaTab, true, 60.0) - primaryConn, err := utils.GetMySQLConn(primaryTab, fmt.Sprintf("vt_%v", keyspace.Name)) - require.NoError(t, err) - replicaConn, err := utils.GetMySQLConn(replicaTab, fmt.Sprintf("vt_%v", keyspace.Name)) - require.NoError(t, err) - primaryRes := collectFkTablesState(primaryConn) - replicaRes := collectFkTablesState(replicaConn) - verifyDataMatches(t, primaryRes, replicaRes) - } - } -} - -// verifyDataMatches verifies that the two list of results are the same. -func verifyDataMatches(t *testing.T, resOne []*sqltypes.Result, resTwo []*sqltypes.Result) { - require.EqualValues(t, len(resTwo), len(resOne), "Res 1 - %v, Res 2 - %v", resOne, resTwo) - for idx, resultOne := range resOne { - resultTwo := resTwo[idx] - require.True(t, resultOne.Equal(resultTwo), "Data for %v doesn't match\nRows 1\n%v\nRows 2\n%v", fkTables[idx], resultOne.Rows, resultTwo.Rows) - } -} - -// collectFkTablesState collects the data stored in the foreign key tables for the given connection. -func collectFkTablesState(conn *mysql.Conn) []*sqltypes.Result { - var tablesData []*sqltypes.Result - for _, table := range fkTables { - query := fmt.Sprintf("SELECT * FROM %v ORDER BY id", table) - res, _ := conn.ExecuteFetch(query, 10000, true) - tablesData = append(tablesData, res) - } - return tablesData -} diff --git a/go/test/endtoend/vtgate/foreignkey/fk_test.go b/go/test/endtoend/vtgate/foreignkey/fk_test.go index 46bbc2ed433..ff7ddef66ff 100644 --- a/go/test/endtoend/vtgate/foreignkey/fk_test.go +++ b/go/test/endtoend/vtgate/foreignkey/fk_test.go @@ -27,6 +27,7 @@ import ( "vitess.io/vitess/go/test/endtoend/cluster" "vitess.io/vitess/go/test/endtoend/utils" + "vitess.io/vitess/go/vt/log" binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" topodatapb "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/vtgate/vtgateconn" @@ -775,6 +776,181 @@ func TestFkScenarios(t *testing.T) { } } +// TestFkQueries is for testing a specific set of queries one after the other. +func TestFkQueries(t *testing.T) { + // Wait for schema-tracking to be complete. + waitForSchemaTrackingForFkTables(t) + // Remove all the foreign key constraints for all the replicas. + // We can then verify that the replica, and the primary have the same data, to ensure + // that none of the queries ever lead to cascades/updates on MySQL level. + for _, ks := range []string{shardedKs, unshardedKs} { + replicas := getReplicaTablets(ks) + for _, replica := range replicas { + removeAllForeignKeyConstraints(t, replica, ks) + } + } + + testcases := []struct { + name string + queries []string + }{ + { + name: "Non-literal update", + queries: []string{ + "insert into fk_t10 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "insert into fk_t11 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "update fk_t10 set col = id + 3", + }, + }, { + name: "Non-literal update with order by", + queries: []string{ + "insert into fk_t10 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "insert into fk_t11 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "update fk_t10 set col = id + 3 order by id desc", + }, + }, { + name: "Non-literal update with order by that require parent and child foreign keys verification - success", + queries: []string{ + "insert into fk_t10 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5),(6,6),(7,7),(8,8)", + "insert into fk_t11 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "insert into fk_t12 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "insert into fk_t13 (id, col) values (1,1),(2,2)", + "update fk_t11 set col = id + 3 where id >= 3", + }, + }, { + name: "Non-literal update with order by that require parent and child foreign keys verification - parent fails", + queries: []string{ + "insert into fk_t10 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "insert into fk_t11 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "insert into fk_t12 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "update fk_t11 set col = id + 3", + }, + }, { + name: "Non-literal update with order by that require parent and child foreign keys verification - child fails", + queries: []string{ + "insert into fk_t10 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5),(6,6),(7,7),(8,8)", + "insert into fk_t11 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "insert into fk_t12 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "insert into fk_t13 (id, col) values (1,1),(2,2)", + "update fk_t11 set col = id + 3", + }, + }, { + name: "Single column update in a multi-col table - success", + queries: []string{ + "insert into fk_multicol_t1 (id, cola, colb) values (1, 1, 1), (2, 2, 2)", + "insert into fk_multicol_t2 (id, cola, colb) values (1, 1, 1)", + "update fk_multicol_t1 set colb = 4 + (colb) where id = 2", + }, + }, { + name: "Single column update in a multi-col table - restrict failure", + queries: []string{ + "insert into fk_multicol_t1 (id, cola, colb) values (1, 1, 1), (2, 2, 2)", + "insert into fk_multicol_t2 (id, cola, colb) values (1, 1, 1)", + "update fk_multicol_t1 set colb = 4 + (colb) where id = 1", + }, + }, { + name: "Single column update in multi-col table - cascade and set null", + queries: []string{ + "insert into fk_multicol_t15 (id, cola, colb) values (1, 1, 1), (2, 2, 2)", + "insert into fk_multicol_t16 (id, cola, colb) values (1, 1, 1), (2, 2, 2)", + "insert into fk_multicol_t17 (id, cola, colb) values (1, 1, 1), (2, 2, 2)", + "update fk_multicol_t15 set colb = 4 + (colb) where id = 1", + }, + }, { + name: "Non literal update that evaluates to NULL - restricted", + queries: []string{ + "insert into fk_t10 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "insert into fk_t11 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "insert into fk_t13 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "update fk_t10 set col = id + null where id = 1", + }, + }, { + name: "Non literal update that evaluates to NULL - success", + queries: []string{ + "insert into fk_t10 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "insert into fk_t11 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "insert into fk_t12 (id, col) values (1,1),(2,2),(3,3),(4,4),(5,5)", + "update fk_t10 set col = id + null where id = 1", + }, + }, { + name: "Multi column foreign key update with one literal and one non-literal update", + queries: []string{ + "insert into fk_multicol_t15 (id, cola, colb) values (1,1,1),(2,2,2)", + "insert into fk_multicol_t16 (id, cola, colb) values (1,1,1),(2,2,2)", + "update fk_multicol_t15 set cola = 3, colb = (id * 2) - 2", + }, + }, + } + + for _, testcase := range testcases { + t.Run(testcase.name, func(t *testing.T) { + mcmp, closer := start(t) + defer closer() + _ = utils.Exec(t, mcmp.VtConn, "use `uks`") + + // Ensure that the Vitess database is originally empty + ensureDatabaseState(t, mcmp.VtConn, true) + ensureDatabaseState(t, mcmp.MySQLConn, true) + + for _, query := range testcase.queries { + _, _ = mcmp.ExecAllowAndCompareError(query) + if t.Failed() { + break + } + } + + // ensure Vitess database has some data. This ensures not all the commands failed. + ensureDatabaseState(t, mcmp.VtConn, false) + // Verify the consistency of the data. + verifyDataIsCorrect(t, mcmp, 1) + }) + } +} + +// TestFkOneCase is for testing a specific set of queries. On the CI this test won't run since we'll keep the queries empty. +func TestFkOneCase(t *testing.T) { + queries := []string{} + if len(queries) == 0 { + t.Skip("No queries to test") + } + // Wait for schema-tracking to be complete. + waitForSchemaTrackingForFkTables(t) + // Remove all the foreign key constraints for all the replicas. + // We can then verify that the replica, and the primary have the same data, to ensure + // that none of the queries ever lead to cascades/updates on MySQL level. + for _, ks := range []string{shardedKs, unshardedKs} { + replicas := getReplicaTablets(ks) + for _, replica := range replicas { + removeAllForeignKeyConstraints(t, replica, ks) + } + } + + mcmp, closer := start(t) + defer closer() + _ = utils.Exec(t, mcmp.VtConn, "use `uks`") + + // Ensure that the Vitess database is originally empty + ensureDatabaseState(t, mcmp.VtConn, true) + ensureDatabaseState(t, mcmp.MySQLConn, true) + + for _, query := range queries { + _, _ = mcmp.ExecAllowAndCompareError(query) + if t.Failed() { + log.Errorf("Query failed - %v", query) + break + } + } + vitessData := collectFkTablesState(mcmp.VtConn) + for idx, table := range fkTables { + log.Errorf("Vitess data for %v -\n%v", table, vitessData[idx].Rows) + } + + // ensure Vitess database has some data. This ensures not all the commands failed. + ensureDatabaseState(t, mcmp.VtConn, false) + // Verify the consistency of the data. + verifyDataIsCorrect(t, mcmp, 1) +} + func TestCyclicFks(t *testing.T) { mcmp, closer := start(t) defer closer() diff --git a/go/test/endtoend/vtgate/foreignkey/utils_test.go b/go/test/endtoend/vtgate/foreignkey/utils_test.go index 5e0b4a8a3cc..b97e57d685c 100644 --- a/go/test/endtoend/vtgate/foreignkey/utils_test.go +++ b/go/test/endtoend/vtgate/foreignkey/utils_test.go @@ -19,15 +19,21 @@ package foreignkey import ( "database/sql" "fmt" + "math/rand" "strings" "testing" + "time" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/test/endtoend/cluster" "vitess.io/vitess/go/test/endtoend/utils" ) +var supportedOpps = []string{"*", "+", "-"} + // getTestName prepends whether the test is for a sharded keyspace or not to the test name. func getTestName(testName string, testSharded bool) string { if testSharded { @@ -41,6 +47,32 @@ func isMultiColFkTable(tableName string) bool { return strings.Contains(tableName, "multicol") } +func (fz *fuzzer) generateExpression(length int, cols ...string) string { + expr := fz.getColOrInt(cols...) + if length == 1 { + return expr + } + rhsExpr := fz.generateExpression(length-1, cols...) + op := supportedOpps[rand.Intn(len(supportedOpps))] + return fmt.Sprintf("%v %s (%v)", expr, op, rhsExpr) +} + +// getColOrInt gets a column or an integer/NULL literal with equal probability. +func (fz *fuzzer) getColOrInt(cols ...string) string { + if len(cols) == 0 || rand.Intn(2) == 0 { + return convertIntValueToString(rand.Intn(1 + fz.maxValForCol)) + } + return cols[rand.Intn(len(cols))] +} + +// convertIntValueToString converts the given value to a string +func convertIntValueToString(value int) string { + if value == 0 { + return "NULL" + } + return fmt.Sprintf("%d", value) +} + // waitForSchemaTrackingForFkTables waits for schema tracking to have run and seen the tables used // for foreign key tests. func waitForSchemaTrackingForFkTables(t *testing.T) { @@ -142,3 +174,85 @@ func compareVitessAndMySQLErrors(t *testing.T, vtErr, mysqlErr error) { out := fmt.Sprintf("Vitess and MySQL are not erroring the same way.\nVitess error: %v\nMySQL error: %v", vtErr, mysqlErr) t.Error(out) } + +// ensureDatabaseState ensures that the database is either empty or not. +func ensureDatabaseState(t *testing.T, vtconn *mysql.Conn, empty bool) { + results := collectFkTablesState(vtconn) + isEmpty := true + for _, res := range results { + if len(res.Rows) > 0 { + isEmpty = false + } + } + require.Equal(t, isEmpty, empty) +} + +// verifyDataIsCorrect verifies that the data in MySQL database matches the data in the Vitess database. +func verifyDataIsCorrect(t *testing.T, mcmp utils.MySQLCompare, concurrency int) { + // For single concurrent thread, we run all the queries on both MySQL and Vitess, so we can verify correctness + // by just checking if the data in MySQL and Vitess match. + if concurrency == 1 { + for _, table := range fkTables { + query := fmt.Sprintf("SELECT * FROM %v ORDER BY id", table) + mcmp.Exec(query) + } + } else { + // For higher concurrency, we don't have MySQL data to verify everything is fine, + // so we'll have to do something different. + // We run LEFT JOIN queries on all the parent and child tables linked by foreign keys + // to make sure that nothing is broken in the database. + for _, reference := range fkReferences { + query := fmt.Sprintf("select %v.id from %v left join %v on (%v.col = %v.col) where %v.col is null and %v.col is not null", reference.childTable, reference.childTable, reference.parentTable, reference.parentTable, reference.childTable, reference.parentTable, reference.childTable) + if isMultiColFkTable(reference.childTable) { + query = fmt.Sprintf("select %v.id from %v left join %v on (%v.cola = %v.cola and %v.colb = %v.colb) where %v.cola is null and %v.cola is not null and %v.colb is not null", reference.childTable, reference.childTable, reference.parentTable, reference.parentTable, reference.childTable, reference.parentTable, reference.childTable, reference.parentTable, reference.childTable, reference.childTable) + } + res, err := mcmp.VtConn.ExecuteFetch(query, 1000, false) + require.NoError(t, err) + require.Zerof(t, len(res.Rows), "Query %v gave non-empty results", query) + } + } + // We also verify that the results in Primary and Replica table match as is. + for _, keyspace := range clusterInstance.Keyspaces { + for _, shard := range keyspace.Shards { + var primaryTab, replicaTab *cluster.Vttablet + for _, vttablet := range shard.Vttablets { + if vttablet.Type == "primary" { + primaryTab = vttablet + } else { + replicaTab = vttablet + } + } + require.NotNil(t, primaryTab) + require.NotNil(t, replicaTab) + checkReplicationHealthy(t, replicaTab) + cluster.WaitForReplicationPos(t, primaryTab, replicaTab, true, 1*time.Minute) + primaryConn, err := utils.GetMySQLConn(primaryTab, fmt.Sprintf("vt_%v", keyspace.Name)) + require.NoError(t, err) + replicaConn, err := utils.GetMySQLConn(replicaTab, fmt.Sprintf("vt_%v", keyspace.Name)) + require.NoError(t, err) + primaryRes := collectFkTablesState(primaryConn) + replicaRes := collectFkTablesState(replicaConn) + verifyDataMatches(t, primaryRes, replicaRes) + } + } +} + +// verifyDataMatches verifies that the two list of results are the same. +func verifyDataMatches(t *testing.T, resOne []*sqltypes.Result, resTwo []*sqltypes.Result) { + require.EqualValues(t, len(resTwo), len(resOne), "Res 1 - %v, Res 2 - %v", resOne, resTwo) + for idx, resultOne := range resOne { + resultTwo := resTwo[idx] + require.True(t, resultOne.Equal(resultTwo), "Data for %v doesn't match\nRows 1\n%v\nRows 2\n%v", fkTables[idx], resultOne.Rows, resultTwo.Rows) + } +} + +// collectFkTablesState collects the data stored in the foreign key tables for the given connection. +func collectFkTablesState(conn *mysql.Conn) []*sqltypes.Result { + var tablesData []*sqltypes.Result + for _, table := range fkTables { + query := fmt.Sprintf("SELECT * FROM %v ORDER BY id", table) + res, _ := conn.ExecuteFetch(query, 10000, true) + tablesData = append(tablesData, res) + } + return tablesData +} diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index b70f83b192d..d878650c63e 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -280,7 +280,7 @@ func (cached *FkChild) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(64) + size += int64(80) } // field BVName string size += hack.RuntimeAllocSize(int64(len(cached.BVName))) @@ -288,6 +288,13 @@ func (cached *FkChild) CachedSize(alloc bool) int64 { { size += hack.RuntimeAllocSize(int64(cap(cached.Cols)) * int64(8)) } + // field NonLiteralInfo []vitess.io/vitess/go/vt/vtgate/engine.NonLiteralUpdateInfo + { + size += hack.RuntimeAllocSize(int64(cap(cached.NonLiteralInfo)) * int64(32)) + for _, elem := range cached.NonLiteralInfo { + size += elem.CachedSize(false) + } + } // field Exec vitess.io/vitess/go/vt/vtgate/engine.Primitive if cc, ok := cached.Exec.(cachedObject); ok { size += cc.CachedSize(true) @@ -612,6 +619,18 @@ func (cached *MergeSort) CachedSize(alloc bool) int64 { } return size } +func (cached *NonLiteralUpdateInfo) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(32) + } + // field UpdateExprBvName string + size += hack.RuntimeAllocSize(int64(len(cached.UpdateExprBvName))) + return size +} func (cached *OnlineDDL) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff --git a/go/vt/vtgate/engine/fk_cascade.go b/go/vt/vtgate/engine/fk_cascade.go index d0bddbea8f9..691e326fec7 100644 --- a/go/vt/vtgate/engine/fk_cascade.go +++ b/go/vt/vtgate/engine/fk_cascade.go @@ -19,6 +19,7 @@ package engine import ( "context" "fmt" + "maps" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -29,9 +30,24 @@ import ( // FkChild contains the Child Primitive to be executed collecting the values from the Selection Primitive using the column indexes. // BVName is used to pass the value as bind variable to the Child Primitive. type FkChild struct { + // BVName is the bind variable name for the tuple bind variable used in the primitive. BVName string - Cols []int // indexes - Exec Primitive + // Cols are the indexes of the column that need to be selected from the SELECT query to create the tuple bind variable. + Cols []int + // NonLiteralInfo stores the information that is needed to run an update query with non-literal values. + NonLiteralInfo []NonLiteralUpdateInfo + Exec Primitive +} + +// NonLiteralUpdateInfo stores the information required to process non-literal update queries. +// It stores 3 information- +// 1. UpdateExprCol- The index of the updated expression in the select query. +// 2. UpdateExprBvName- The bind variable name to store the updated expression into. +// 3. CompExprCol- The index of the comparison expression in the select query to know if the row value is actually being changed or not. +type NonLiteralUpdateInfo struct { + UpdateExprCol int + UpdateExprBvName string + CompExprCol int } // FkCascade is a primitive that implements foreign key cascading using Selection as values required to execute the FkChild Primitives. @@ -82,81 +98,110 @@ func (fkc *FkCascade) TryExecute(ctx context.Context, vcursor VCursor, bindVars } for _, child := range fkc.Children { - // We create a bindVariable for each Child - // that stores the tuple of columns involved in the fk constraint. - bv := &querypb.BindVariable{ - Type: querypb.Type_TUPLE, + // Having non-empty UpdateExprBvNames is an indication that we have an update query with non-literal expressions in it. + // We need to run this query differently because we need to run an update for each row we get back from the SELECT. + if len(child.NonLiteralInfo) > 0 { + err = fkc.executeNonLiteralExprFkChild(ctx, vcursor, bindVars, wantfields, selectionRes, child) + } else { + err = fkc.executeLiteralExprFkChild(ctx, vcursor, bindVars, wantfields, selectionRes, child, false) } - for _, row := range selectionRes.Rows { - var tupleValues []sqltypes.Value - for _, colIdx := range child.Cols { - tupleValues = append(tupleValues, row[colIdx]) - } - bv.Values = append(bv.Values, sqltypes.TupleToProto(tupleValues)) - } - // Execute the child primitive, and bail out incase of failure. - // Since this Primitive is always executed in a transaction, the changes should - // be rolled back incase of an error. - bindVars[child.BVName] = bv - _, err = vcursor.ExecutePrimitive(ctx, child.Exec, bindVars, wantfields) if err != nil { return nil, err } - delete(bindVars, child.BVName) } // All the children are modified successfully, we can now execute the Parent Primitive. return vcursor.ExecutePrimitive(ctx, fkc.Parent, bindVars, wantfields) } -// TryStreamExecute implements the Primitive interface. -func (fkc *FkCascade) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - // We create a bindVariable for each Child - // that stores the tuple of columns involved in the fk constraint. - var bindVariables []*querypb.BindVariable - for range fkc.Children { - bindVariables = append(bindVariables, &querypb.BindVariable{ - Type: querypb.Type_TUPLE, - }) +func (fkc *FkCascade) executeLiteralExprFkChild(ctx context.Context, vcursor VCursor, in map[string]*querypb.BindVariable, wantfields bool, selectionRes *sqltypes.Result, child *FkChild, isStreaming bool) error { + bindVars := maps.Clone(in) + // We create a bindVariable that stores the tuple of columns involved in the fk constraint. + bv := &querypb.BindVariable{ + Type: querypb.Type_TUPLE, } + for _, row := range selectionRes.Rows { + var tupleValues []sqltypes.Value - // Execute the Selection primitive to find the rows that are going to modified. - // This will be used to find the rows that need modification on the children. - err := vcursor.StreamExecutePrimitive(ctx, fkc.Selection, bindVars, wantfields, func(result *sqltypes.Result) error { - if len(result.Rows) == 0 { - return nil - } - for idx, child := range fkc.Children { - for _, row := range result.Rows { - var tupleValues []sqltypes.Value - for _, colIdx := range child.Cols { - tupleValues = append(tupleValues, row[colIdx]) - } - bindVariables[idx].Values = append(bindVariables[idx].Values, sqltypes.TupleToProto(tupleValues)) - } + for _, colIdx := range child.Cols { + tupleValues = append(tupleValues, row[colIdx]) } - return nil - }) - if err != nil { - return err + bv.Values = append(bv.Values, sqltypes.TupleToProto(tupleValues)) } - // Execute the child primitive, and bail out incase of failure. // Since this Primitive is always executed in a transaction, the changes should // be rolled back incase of an error. - for idx, child := range fkc.Children { - bindVars[child.BVName] = bindVariables[idx] - err = vcursor.StreamExecutePrimitive(ctx, child.Exec, bindVars, wantfields, func(result *sqltypes.Result) error { - return nil - }) + bindVars[child.BVName] = bv + var err error + if isStreaming { + err = vcursor.StreamExecutePrimitive(ctx, child.Exec, bindVars, wantfields, func(result *sqltypes.Result) error { return nil }) + } else { + _, err = vcursor.ExecutePrimitive(ctx, child.Exec, bindVars, wantfields) + } + if err != nil { + return err + } + return nil +} + +func (fkc *FkCascade) executeNonLiteralExprFkChild(ctx context.Context, vcursor VCursor, in map[string]*querypb.BindVariable, wantfields bool, selectionRes *sqltypes.Result, child *FkChild) error { + // For each row in the SELECT we need to run the child primitive. + for _, row := range selectionRes.Rows { + bindVars := maps.Clone(in) + // First we check if any of the columns is being updated at all. + skipRow := true + for _, info := range child.NonLiteralInfo { + // We use a null-safe comparison, so the value is guaranteed to be not null. + // We check if the column has updated or not. + isUnchanged, err := row[info.CompExprCol].ToBool() + if err != nil { + return err + } + if !isUnchanged { + // If any column has changed, then we can't skip this row. + // We need to execute the child primitive. + skipRow = false + break + } + } + // If none of the columns have changed, then there is no update to cascade, we can move on. + if skipRow { + continue + } + // We create a bindVariable that stores the tuple of columns involved in the fk constraint. + bv := &querypb.BindVariable{ + Type: querypb.Type_TUPLE, + } + // Create a tuple from the Row. + var tupleValues []sqltypes.Value + for _, colIdx := range child.Cols { + tupleValues = append(tupleValues, row[colIdx]) + } + bv.Values = append(bv.Values, sqltypes.TupleToProto(tupleValues)) + // Execute the child primitive, and bail out incase of failure. + // Since this Primitive is always executed in a transaction, the changes should + // be rolled back in case of an error. + bindVars[child.BVName] = bv + + // Next, we need to copy the updated expressions value into the bind variables map. + for _, info := range child.NonLiteralInfo { + bindVars[info.UpdateExprBvName] = sqltypes.ValueBindVariable(row[info.UpdateExprCol]) + } + _, err := vcursor.ExecutePrimitive(ctx, child.Exec, bindVars, wantfields) if err != nil { return err } - delete(bindVars, child.BVName) } + return nil +} - // All the children are modified successfully, we can now execute the Parent Primitive. - return vcursor.StreamExecutePrimitive(ctx, fkc.Parent, bindVars, wantfields, callback) +// TryStreamExecute implements the Primitive interface. +func (fkc *FkCascade) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { + res, err := fkc.TryExecute(ctx, vcursor, bindVars, wantfields) + if err != nil { + return err + } + return callback(res) } // Inputs implements the Primitive interface. @@ -168,11 +213,15 @@ func (fkc *FkCascade) Inputs() ([]Primitive, []map[string]any) { inputName: "Selection", }) for idx, child := range fkc.Children { - inputsMap = append(inputsMap, map[string]any{ + childInfoMap := map[string]any{ inputName: fmt.Sprintf("CascadeChild-%d", idx+1), "BvName": child.BVName, "Cols": child.Cols, - }) + } + if len(child.NonLiteralInfo) > 0 { + childInfoMap["NonLiteralUpdateInfo"] = child.NonLiteralInfo + } + inputsMap = append(inputsMap, childInfoMap) inputs = append(inputs, child.Exec) } inputs = append(inputs, fkc.Parent) diff --git a/go/vt/vtgate/engine/fk_cascade_test.go b/go/vt/vtgate/engine/fk_cascade_test.go index ddd381003b1..942fe44a709 100644 --- a/go/vt/vtgate/engine/fk_cascade_test.go +++ b/go/vt/vtgate/engine/fk_cascade_test.go @@ -80,7 +80,7 @@ func TestDeleteCascade(t *testing.T) { require.NoError(t, err) vc.ExpectLog(t, []string{ `ResolveDestinations ks [] Destinations:DestinationAllShards()`, - `StreamExecuteMulti select cola, colb from parent where foo = 48 ks.0: {} `, + `ExecuteMultiShard ks.0: select cola, colb from parent where foo = 48 {} false false`, `ResolveDestinations ks [] Destinations:DestinationAllShards()`, `ExecuteMultiShard ks.0: delete from child where (ca, cb) in ::__vals {__vals: type:TUPLE values:{type:TUPLE value:"\x89\x02\x011\x950\x01a"} values:{type:TUPLE value:"\x89\x02\x012\x950\x01b"}} true true`, `ResolveDestinations ks [] Destinations:DestinationAllShards()`, @@ -141,7 +141,7 @@ func TestUpdateCascade(t *testing.T) { require.NoError(t, err) vc.ExpectLog(t, []string{ `ResolveDestinations ks [] Destinations:DestinationAllShards()`, - `StreamExecuteMulti select cola, colb from parent where foo = 48 ks.0: {} `, + `ExecuteMultiShard ks.0: select cola, colb from parent where foo = 48 {} false false`, `ResolveDestinations ks [] Destinations:DestinationAllShards()`, `ExecuteMultiShard ks.0: update child set ca = :vtg1 where (ca, cb) in ::__vals {__vals: type:TUPLE values:{type:TUPLE value:"\x89\x02\x011\x950\x01a"} values:{type:TUPLE value:"\x89\x02\x012\x950\x01b"}} true true`, `ResolveDestinations ks [] Destinations:DestinationAllShards()`, @@ -149,6 +149,82 @@ func TestUpdateCascade(t *testing.T) { }) } +// TestNonLiteralUpdateCascade tests that FkCascade executes the child and parent primitives for a non-literal update cascade. +func TestNonLiteralUpdateCascade(t *testing.T) { + fakeRes := sqltypes.MakeTestResult(sqltypes.MakeTestFields("cola|cola <=> colb + 2|colb + 2", "int64|int64|int64"), "1|1|3", "2|0|5", "3|0|7") + + inputP := &Route{ + Query: "select cola, cola <=> colb + 2, colb + 2, from parent where foo = 48", + RoutingParameters: &RoutingParameters{ + Opcode: Unsharded, + Keyspace: &vindexes.Keyspace{Name: "ks"}, + }, + } + childP := &Update{ + DML: &DML{ + Query: "update child set ca = :fkc_upd where (ca) in ::__vals", + RoutingParameters: &RoutingParameters{ + Opcode: Unsharded, + Keyspace: &vindexes.Keyspace{Name: "ks"}, + }, + }, + } + parentP := &Update{ + DML: &DML{ + Query: "update parent set cola = colb + 2 where foo = 48", + RoutingParameters: &RoutingParameters{ + Opcode: Unsharded, + Keyspace: &vindexes.Keyspace{Name: "ks"}, + }, + }, + } + fkc := &FkCascade{ + Selection: inputP, + Children: []*FkChild{{ + BVName: "__vals", + Cols: []int{0}, + NonLiteralInfo: []NonLiteralUpdateInfo{ + { + UpdateExprBvName: "fkc_upd", + UpdateExprCol: 2, + CompExprCol: 1, + }, + }, + Exec: childP, + }}, + Parent: parentP, + } + + vc := newDMLTestVCursor("0") + vc.results = []*sqltypes.Result{fakeRes} + _, err := fkc.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true) + require.NoError(t, err) + vc.ExpectLog(t, []string{ + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `ExecuteMultiShard ks.0: select cola, cola <=> colb + 2, colb + 2, from parent where foo = 48 {} false false`, + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `ExecuteMultiShard ks.0: update child set ca = :fkc_upd where (ca) in ::__vals {__vals: type:TUPLE values:{type:TUPLE value:"\x89\x02\x012"} fkc_upd: type:INT64 value:"5"} true true`, + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `ExecuteMultiShard ks.0: update child set ca = :fkc_upd where (ca) in ::__vals {__vals: type:TUPLE values:{type:TUPLE value:"\x89\x02\x013"} fkc_upd: type:INT64 value:"7"} true true`, + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `ExecuteMultiShard ks.0: update parent set cola = colb + 2 where foo = 48 {} true true`, + }) + + vc.Rewind() + err = fkc.TryStreamExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true, func(result *sqltypes.Result) error { return nil }) + require.NoError(t, err) + vc.ExpectLog(t, []string{ + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `ExecuteMultiShard ks.0: select cola, cola <=> colb + 2, colb + 2, from parent where foo = 48 {} false false`, + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `ExecuteMultiShard ks.0: update child set ca = :fkc_upd where (ca) in ::__vals {__vals: type:TUPLE values:{type:TUPLE value:"\x89\x02\x012"} fkc_upd: type:INT64 value:"5"} true true`, + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `ExecuteMultiShard ks.0: update child set ca = :fkc_upd where (ca) in ::__vals {__vals: type:TUPLE values:{type:TUPLE value:"\x89\x02\x013"} fkc_upd: type:INT64 value:"7"} true true`, + `ResolveDestinations ks [] Destinations:DestinationAllShards()`, + `ExecuteMultiShard ks.0: update parent set cola = colb + 2 where foo = 48 {} true true`, + }) +} + // TestNeedsTransactionInExecPrepared tests that if we have a foreign key cascade inside an ExecStmt plan, then we do mark the plan to require a transaction. func TestNeedsTransactionInExecPrepared(t *testing.T) { // Even if FkCascade is wrapped in ExecStmt, the plan should be marked such that it requires a transaction. diff --git a/go/vt/vtgate/engine/fk_verify.go b/go/vt/vtgate/engine/fk_verify.go index 350aeec59e0..8625fbe2362 100644 --- a/go/vt/vtgate/engine/fk_verify.go +++ b/go/vt/vtgate/engine/fk_verify.go @@ -83,18 +83,11 @@ func (f *FkVerify) TryExecute(ctx context.Context, vcursor VCursor, bindVars map // TryStreamExecute implements the Primitive interface func (f *FkVerify) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - for _, v := range f.Verify { - err := vcursor.StreamExecutePrimitive(ctx, v.Exec, bindVars, wantfields, func(qr *sqltypes.Result) error { - if len(qr.Rows) > 0 { - return getError(v.Typ) - } - return nil - }) - if err != nil { - return err - } + res, err := f.TryExecute(ctx, vcursor, bindVars, wantfields) + if err != nil { + return err } - return vcursor.StreamExecutePrimitive(ctx, f.Exec, bindVars, wantfields, callback) + return callback(res) } // Inputs implements the Primitive interface diff --git a/go/vt/vtgate/engine/fk_verify_test.go b/go/vt/vtgate/engine/fk_verify_test.go index 5635a32bc2c..5c9ff83c2ec 100644 --- a/go/vt/vtgate/engine/fk_verify_test.go +++ b/go/vt/vtgate/engine/fk_verify_test.go @@ -74,7 +74,7 @@ func TestFKVerifyUpdate(t *testing.T) { require.NoError(t, err) vc.ExpectLog(t, []string{ `ResolveDestinations ks [] Destinations:DestinationAllShards()`, - `StreamExecuteMulti select 1 from child c left join parent p on p.cola = 1 and p.colb = 'a' where p.cola is null and p.colb is null ks.0: {} `, + `ExecuteMultiShard ks.0: select 1 from child c left join parent p on p.cola = 1 and p.colb = 'a' where p.cola is null and p.colb is null {} false false`, `ResolveDestinations ks [] Destinations:DestinationAllShards()`, `ExecuteMultiShard ks.0: update child set cola = 1, colb = 'a' where foo = 48 {} true true`, }) @@ -97,7 +97,7 @@ func TestFKVerifyUpdate(t *testing.T) { require.ErrorContains(t, err, "Cannot add or update a child row: a foreign key constraint fails") vc.ExpectLog(t, []string{ `ResolveDestinations ks [] Destinations:DestinationAllShards()`, - `StreamExecuteMulti select 1 from child c left join parent p on p.cola = 1 and p.colb = 'a' where p.cola is null and p.colb is null ks.0: {} `, + `ExecuteMultiShard ks.0: select 1 from child c left join parent p on p.cola = 1 and p.colb = 'a' where p.cola is null and p.colb is null {} false false`, }) }) @@ -119,7 +119,7 @@ func TestFKVerifyUpdate(t *testing.T) { require.ErrorContains(t, err, "Cannot delete or update a parent row: a foreign key constraint fails") vc.ExpectLog(t, []string{ `ResolveDestinations ks [] Destinations:DestinationAllShards()`, - `StreamExecuteMulti select 1 from grandchild g join child c on g.cola = c.cola and g.colb = c.colb where c.foo = 48 ks.0: {} `, + `ExecuteMultiShard ks.0: select 1 from grandchild g join child c on g.cola = c.cola and g.colb = c.colb where c.foo = 48 {} false false`, }) }) } diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index a04c4b00c2c..59200e92e2a 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -139,9 +139,10 @@ func transformFkCascade(ctx *plancontext.PlanningContext, fkc *operators.FkCasca childEngine := childLP.Primitive() children = append(children, &engine.FkChild{ - BVName: child.BVName, - Cols: child.Cols, - Exec: childEngine, + BVName: child.BVName, + Cols: child.Cols, + NonLiteralInfo: child.NonLiteralInfo, + Exec: childEngine, }) } diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index bc4aaf8a4e6..64d9826a80e 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -28,6 +28,7 @@ import ( ) const foreignKeyConstraintValues = "fkc_vals" +const foreignKeyUpdateExpr = "fkc_upd" // translateQueryToOp creates an operator tree that represents the input SELECT or UNION query func translateQueryToOp(ctx *plancontext.PlanningContext, selStmt sqlparser.Statement) (op ops.Operator, err error) { @@ -213,7 +214,10 @@ func createOpFromStmt(ctx *plancontext.PlanningContext, stmt sqlparser.Statement // we should augment the semantic analysis to also tell us whether the given query has any cross shard parent foreign keys to validate. // If there are, then we have to run the query with FOREIGN_KEY_CHECKS off because we can't be sure if the DML will succeed on MySQL with the checks on. // So, we should set VerifyAllFKs to true. i.e. we should add `|| ctx.SemTable.RequireForeignKeyChecksOff()` to the below condition. - ctx.VerifyAllFKs = verifyAllFKs + if verifyAllFKs { + // If ctx.VerifyAllFKs is already true we don't want to turn it off. + ctx.VerifyAllFKs = verifyAllFKs + } // From all the parent foreign keys involved, we should remove the one that we need to ignore. err = ctx.SemTable.RemoveParentForeignKey(fkToIgnore) @@ -397,6 +401,7 @@ func createSelectionOp( selectExprs []sqlparser.SelectExpr, tableExprs sqlparser.TableExprs, where *sqlparser.Where, + orderBy sqlparser.OrderBy, limit *sqlparser.Limit, lock sqlparser.Lock, ) (ops.Operator, error) { @@ -404,20 +409,10 @@ func createSelectionOp( SelectExprs: selectExprs, From: tableExprs, Where: where, + OrderBy: orderBy, Limit: limit, Lock: lock, } // There are no foreign keys to check for a select query, so we can pass anything for verifyAllFKs and fkToIgnore. return createOpFromStmt(ctx, selectionStmt, false /* verifyAllFKs */, "" /* fkToIgnore */) } - -func selectParentColumns(fk vindexes.ChildFKInfo, lastOffset int) ([]int, []sqlparser.SelectExpr) { - var cols []int - var exprs []sqlparser.SelectExpr - for _, column := range fk.ParentColumns { - cols = append(cols, lastOffset) - exprs = append(exprs, aeWrap(sqlparser.NewColName(column.String()))) - lastOffset++ - } - return cols, exprs -} diff --git a/go/vt/vtgate/planbuilder/operators/delete.go b/go/vt/vtgate/planbuilder/operators/delete.go index f02444671c1..dd37ed5db01 100644 --- a/go/vt/vtgate/planbuilder/operators/delete.go +++ b/go/vt/vtgate/planbuilder/operators/delete.go @@ -181,16 +181,16 @@ func createFkCascadeOpForDelete(ctx *plancontext.PlanningContext, parentOp ops.O } // We need to select all the parent columns for the foreign key constraint, to use in the update of the child table. - cols, exprs := selectParentColumns(fk, len(selectExprs)) - selectExprs = append(selectExprs, exprs...) + var offsets []int + offsets, selectExprs = addColumns(ctx, fk.ParentColumns, selectExprs) - fkChild, err := createFkChildForDelete(ctx, fk, cols) + fkChild, err := createFkChildForDelete(ctx, fk, offsets) if err != nil { return nil, err } fkChildren = append(fkChildren, fkChild) } - selectionOp, err := createSelectionOp(ctx, selectExprs, delStmt.TableExprs, delStmt.Where, nil, sqlparser.ForUpdateLock) + selectionOp, err := createSelectionOp(ctx, selectExprs, delStmt.TableExprs, delStmt.Where, nil, nil, sqlparser.ForUpdateLock) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/operators/fk_cascade.go b/go/vt/vtgate/planbuilder/operators/fk_cascade.go index 90c797d55e8..73b902a4980 100644 --- a/go/vt/vtgate/planbuilder/operators/fk_cascade.go +++ b/go/vt/vtgate/planbuilder/operators/fk_cascade.go @@ -19,15 +19,17 @@ package operators import ( "slices" + "vitess.io/vitess/go/vt/vtgate/engine" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" ) // FkChild is used to represent a foreign key child table operation type FkChild struct { - BVName string - Cols []int // indexes - Op ops.Operator + BVName string + Cols []int // indexes + NonLiteralInfo []engine.NonLiteralUpdateInfo + Op ops.Operator noColumns noPredicates @@ -88,9 +90,10 @@ func (fkc *FkCascade) Clone(inputs []ops.Operator) ops.Operator { } newFkc.Children = append(newFkc.Children, &FkChild{ - BVName: fkc.Children[idx-2].BVName, - Cols: slices.Clone(fkc.Children[idx-2].Cols), - Op: operator, + BVName: fkc.Children[idx-2].BVName, + Cols: slices.Clone(fkc.Children[idx-2].Cols), + NonLiteralInfo: slices.Clone(fkc.Children[idx-2].NonLiteralInfo), + Op: operator, }) } return newFkc diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index 743812f9dd7..b7a3a4da2d6 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -125,6 +125,12 @@ func createOperatorFromUpdate(ctx *plancontext.PlanningContext, updStmt *sqlpars return nil, vterrors.VT12001("update with limit with foreign key constraints") } + // Now we check if any of the foreign key columns that are being udpated have dependencies on other updated columns. + // This is unsafe, and we currently don't support this in Vitess. + if err = ctx.SemTable.ErrIfFkDependentColumnUpdated(updStmt.Exprs); err != nil { + return nil, err + } + return buildFkOperator(ctx, updOp, updClone, parentFks, childFks, vindexTable) } @@ -203,11 +209,6 @@ func createUpdateOperator(ctx *plancontext.PlanningContext, updStmt *sqlparser.U } func buildFkOperator(ctx *plancontext.PlanningContext, updOp ops.Operator, updClone *sqlparser.Update, parentFks []vindexes.ParentFKInfo, childFks []vindexes.ChildFKInfo, updatedTable *vindexes.Table) (ops.Operator, error) { - // We only support simple expressions in update queries for foreign key handling. - if isNonLiteral(updClone.Exprs, parentFks, childFks) { - return nil, vterrors.VT12001("update expression with non-literal values with foreign key constraints") - } - restrictChildFks, cascadeChildFks := splitChildFks(childFks) op, err := createFKCascadeOp(ctx, updOp, updClone, cascadeChildFks, updatedTable) @@ -218,25 +219,6 @@ func buildFkOperator(ctx *plancontext.PlanningContext, updOp ops.Operator, updCl return createFKVerifyOp(ctx, op, updClone, parentFks, restrictChildFks) } -func isNonLiteral(updExprs sqlparser.UpdateExprs, parentFks []vindexes.ParentFKInfo, childFks []vindexes.ChildFKInfo) bool { - for _, updateExpr := range updExprs { - if sqlparser.IsLiteral(updateExpr.Expr) { - continue - } - for _, parentFk := range parentFks { - if parentFk.ChildColumns.FindColumn(updateExpr.Name.Name) >= 0 { - return true - } - } - for _, childFk := range childFks { - if childFk.ParentColumns.FindColumn(updateExpr.Name.Name) >= 0 { - return true - } - } - } - return false -} - // splitChildFks splits the child foreign keys into restrict and cascade list as restrict is handled through Verify operator and cascade is handled through Cascade operator. func splitChildFks(fks []vindexes.ChildFKInfo) (restrictChildFks, cascadeChildFks []vindexes.ChildFKInfo) { for _, fk := range fks { @@ -271,17 +253,33 @@ func createFKCascadeOp(ctx *plancontext.PlanningContext, parentOp ops.Operator, } // We need to select all the parent columns for the foreign key constraint, to use in the update of the child table. - cols, exprs := selectParentColumns(fk, len(selectExprs)) - selectExprs = append(selectExprs, exprs...) + var selectOffsets []int + selectOffsets, selectExprs = addColumns(ctx, fk.ParentColumns, selectExprs) + + // If we are updating a foreign key column to a non-literal value then, need information about + // 1. whether the new value is different from the old value + // 2. the new value itself. + // 3. the bind variable to assign to this value. + var nonLiteralUpdateInfo []engine.NonLiteralUpdateInfo + ue := ctx.SemTable.GetUpdateExpressionsForFk(fk.String(updatedTable)) + // We only need to store these offsets and add these expressions to SELECT when there are non-literal updates present. + if hasNonLiteralUpdate(ue) { + for _, updExpr := range ue { + // We add the expression and a comparison expression to the SELECT exprssion while storing their offsets. + var info engine.NonLiteralUpdateInfo + info, selectExprs = addNonLiteralUpdExprToSelect(ctx, updExpr, selectExprs) + nonLiteralUpdateInfo = append(nonLiteralUpdateInfo, info) + } + } - fkChild, err := createFkChildForUpdate(ctx, fk, updStmt, cols, updatedTable) + fkChild, err := createFkChildForUpdate(ctx, fk, selectOffsets, nonLiteralUpdateInfo, updatedTable) if err != nil { return nil, err } fkChildren = append(fkChildren, fkChild) } - selectionOp, err := createSelectionOp(ctx, selectExprs, updStmt.TableExprs, updStmt.Where, nil, sqlparser.ForUpdateLock) + selectionOp, err := createSelectionOp(ctx, selectExprs, updStmt.TableExprs, updStmt.Where, updStmt.OrderBy, nil, sqlparser.ForUpdateLock) if err != nil { return nil, err } @@ -293,8 +291,73 @@ func createFKCascadeOp(ctx *plancontext.PlanningContext, parentOp ops.Operator, }, nil } +// hasNonLiteralUpdate checks if any of the update expressions have a non-literal update. +func hasNonLiteralUpdate(exprs sqlparser.UpdateExprs) bool { + for _, expr := range exprs { + if !sqlparser.IsLiteral(expr.Expr) { + return true + } + } + return false +} + +// addColumns adds the given set of columns to the select expressions provided. It tries to reuse the columns if already present in it. +// It returns the list of offsets for the columns and the updated select expressions. +func addColumns(ctx *plancontext.PlanningContext, columns sqlparser.Columns, exprs []sqlparser.SelectExpr) ([]int, []sqlparser.SelectExpr) { + var offsets []int + selectExprs := exprs + for _, column := range columns { + ae := aeWrap(sqlparser.NewColName(column.String())) + exists := false + for idx, expr := range exprs { + if ctx.SemTable.EqualsExpr(expr.(*sqlparser.AliasedExpr).Expr, ae.Expr) { + offsets = append(offsets, idx) + exists = true + break + } + } + if !exists { + offsets = append(offsets, len(selectExprs)) + selectExprs = append(selectExprs, ae) + + } + } + return offsets, selectExprs +} + +// For an update query having non-literal updates, we add the updated expression and a comparison expression to the select query. +// For example, for a query like `update fk_table set col = id * 100 + 1` +// We would add the expression `id * 100 + 1` and the comparison expression `col <=> id * 100 + 1` to the select query. +func addNonLiteralUpdExprToSelect(ctx *plancontext.PlanningContext, updExpr *sqlparser.UpdateExpr, exprs []sqlparser.SelectExpr) (engine.NonLiteralUpdateInfo, []sqlparser.SelectExpr) { + // Create the comparison expression. + compExpr := sqlparser.NewComparisonExpr(sqlparser.NullSafeEqualOp, updExpr.Name, updExpr.Expr, nil) + info := engine.NonLiteralUpdateInfo{ + CompExprCol: -1, + UpdateExprCol: -1, + } + // Add the expressions to the select expressions. We make sure to reuse the offset if it has already been added once. + for idx, selectExpr := range exprs { + if ctx.SemTable.EqualsExpr(selectExpr.(*sqlparser.AliasedExpr).Expr, compExpr) { + info.CompExprCol = idx + } + if ctx.SemTable.EqualsExpr(selectExpr.(*sqlparser.AliasedExpr).Expr, updExpr.Expr) { + info.UpdateExprCol = idx + } + } + // If the expression doesn't exist, then we add the expression and store the offset. + if info.CompExprCol == -1 { + info.CompExprCol = len(exprs) + exprs = append(exprs, aeWrap(compExpr)) + } + if info.UpdateExprCol == -1 { + info.UpdateExprCol = len(exprs) + exprs = append(exprs, aeWrap(updExpr.Expr)) + } + return info, exprs +} + // createFkChildForUpdate creates the update query operator for the child table based on the foreign key constraints. -func createFkChildForUpdate(ctx *plancontext.PlanningContext, fk vindexes.ChildFKInfo, updStmt *sqlparser.Update, cols []int, updatedTable *vindexes.Table) (*FkChild, error) { +func createFkChildForUpdate(ctx *plancontext.PlanningContext, fk vindexes.ChildFKInfo, selectOffsets []int, nonLiteralUpdateInfo []engine.NonLiteralUpdateInfo, updatedTable *vindexes.Table) (*FkChild, error) { // Create a ValTuple of child column names var valTuple sqlparser.ValTuple for _, column := range fk.ChildColumns { @@ -307,13 +370,22 @@ func createFkChildForUpdate(ctx *plancontext.PlanningContext, fk vindexes.ChildF compExpr := sqlparser.NewComparisonExpr(sqlparser.InOp, valTuple, sqlparser.NewListArg(bvName), nil) var childWhereExpr sqlparser.Expr = compExpr + // In the case of non-literal updates, we need to assign bindvariables for storing the updated value of the columns + // coming from the SELECT query. + if len(nonLiteralUpdateInfo) > 0 { + for idx, info := range nonLiteralUpdateInfo { + info.UpdateExprBvName = ctx.ReservedVars.ReserveVariable(foreignKeyUpdateExpr) + nonLiteralUpdateInfo[idx] = info + } + } + var childOp ops.Operator var err error switch fk.OnUpdate { case sqlparser.Cascade: - childOp, err = buildChildUpdOpForCascade(ctx, fk, updStmt, childWhereExpr, updatedTable) + childOp, err = buildChildUpdOpForCascade(ctx, fk, childWhereExpr, nonLiteralUpdateInfo, updatedTable) case sqlparser.SetNull: - childOp, err = buildChildUpdOpForSetNull(ctx, fk, updStmt, childWhereExpr) + childOp, err = buildChildUpdOpForSetNull(ctx, fk, childWhereExpr, nonLiteralUpdateInfo, updatedTable) case sqlparser.SetDefault: return nil, vterrors.VT09016() } @@ -322,9 +394,10 @@ func createFkChildForUpdate(ctx *plancontext.PlanningContext, fk vindexes.ChildF } return &FkChild{ - BVName: bvName, - Cols: cols, - Op: childOp, + BVName: bvName, + Cols: selectOffsets, + Op: childOp, + NonLiteralInfo: nonLiteralUpdateInfo, }, nil } @@ -332,11 +405,11 @@ func createFkChildForUpdate(ctx *plancontext.PlanningContext, fk vindexes.ChildF // The query looks like this - // // `UPDATE SET WHERE IN ()` -func buildChildUpdOpForCascade(ctx *plancontext.PlanningContext, fk vindexes.ChildFKInfo, updStmt *sqlparser.Update, childWhereExpr sqlparser.Expr, updatedTable *vindexes.Table) (ops.Operator, error) { +func buildChildUpdOpForCascade(ctx *plancontext.PlanningContext, fk vindexes.ChildFKInfo, childWhereExpr sqlparser.Expr, nonLiteralUpdateInfo []engine.NonLiteralUpdateInfo, updatedTable *vindexes.Table) (ops.Operator, error) { // The update expressions are the same as the update expressions in the parent update query // with the column names replaced with the child column names. var childUpdateExprs sqlparser.UpdateExprs - for _, updateExpr := range updStmt.Exprs { + for idx, updateExpr := range ctx.SemTable.GetUpdateExpressionsForFk(fk.String(updatedTable)) { colIdx := fk.ParentColumns.FindColumn(updateExpr.Name.Name) if colIdx == -1 { continue @@ -344,9 +417,13 @@ func buildChildUpdOpForCascade(ctx *plancontext.PlanningContext, fk vindexes.Chi // The where condition is the same as the comparison expression above // with the column names replaced with the child column names. + childUpdateExpr := updateExpr.Expr + if len(nonLiteralUpdateInfo) > 0 && nonLiteralUpdateInfo[idx].UpdateExprBvName != "" { + childUpdateExpr = sqlparser.NewArgument(nonLiteralUpdateInfo[idx].UpdateExprBvName) + } childUpdateExprs = append(childUpdateExprs, &sqlparser.UpdateExpr{ Name: sqlparser.NewColName(fk.ChildColumns[colIdx].String()), - Expr: updateExpr.Expr, + Expr: childUpdateExpr, }) } // Because we could be updating the child to a non-null value, @@ -373,7 +450,13 @@ func buildChildUpdOpForCascade(ctx *plancontext.PlanningContext, fk vindexes.Chi // `UPDATE SET // WHERE IN () // [AND ({ IS NULL OR}... NOT IN ())]` -func buildChildUpdOpForSetNull(ctx *plancontext.PlanningContext, fk vindexes.ChildFKInfo, updStmt *sqlparser.Update, childWhereExpr sqlparser.Expr) (ops.Operator, error) { +func buildChildUpdOpForSetNull( + ctx *plancontext.PlanningContext, + fk vindexes.ChildFKInfo, + childWhereExpr sqlparser.Expr, + nonLiteralUpdateInfo []engine.NonLiteralUpdateInfo, + updatedTable *vindexes.Table, +) (ops.Operator, error) { // For the SET NULL type constraint, we need to set all the child columns to NULL. var childUpdateExprs sqlparser.UpdateExprs for _, column := range fk.ChildColumns { @@ -392,7 +475,7 @@ func buildChildUpdOpForSetNull(ctx *plancontext.PlanningContext, fk vindexes.Chi // For example, if we are setting `update parent cola = :v1 and colb = :v2`, then on the child, the where condition would look something like this - // `:v1 IS NULL OR :v2 IS NULL OR (child_cola, child_colb) NOT IN ((:v1,:v2))` // So, if either of :v1 or :v2 is NULL, then the entire condition is true (which is the same as not having the condition when :v1 or :v2 is NULL). - compExpr := nullSafeNotInComparison(updStmt.Exprs, fk) + compExpr := nullSafeNotInComparison(ctx.SemTable.GetUpdateExpressionsForFk(fk.String(updatedTable)), fk, updatedTable.GetTableName(), nonLiteralUpdateInfo) if compExpr != nil { childWhereExpr = &sqlparser.AndExpr{ Left: childWhereExpr, @@ -408,7 +491,13 @@ func buildChildUpdOpForSetNull(ctx *plancontext.PlanningContext, fk vindexes.Chi } // createFKVerifyOp creates the verify operator for the parent foreign key constraints. -func createFKVerifyOp(ctx *plancontext.PlanningContext, childOp ops.Operator, updStmt *sqlparser.Update, parentFks []vindexes.ParentFKInfo, restrictChildFks []vindexes.ChildFKInfo) (ops.Operator, error) { +func createFKVerifyOp( + ctx *plancontext.PlanningContext, + childOp ops.Operator, + updStmt *sqlparser.Update, + parentFks []vindexes.ParentFKInfo, + restrictChildFks []vindexes.ChildFKInfo, +) (ops.Operator, error) { if len(parentFks) == 0 && len(restrictChildFks) == 0 { return childOp, nil } @@ -445,13 +534,13 @@ func createFKVerifyOp(ctx *plancontext.PlanningContext, childOp ops.Operator, up // Each parent foreign key constraint is verified by an anti join query of the form: // select 1 from child_tbl left join parent_tbl on -// where and and limit 1 +// where and and and limit 1 // E.g: // Child (c1, c2) references Parent (p1, p2) -// update Child set c1 = 1 where id = 1 +// update Child set c1 = c2 + 1 where id = 1 // verify query: -// select 1 from Child left join Parent on Parent.p1 = 1 and Parent.p2 = Child.c2 -// where Parent.p1 is null and Parent.p2 is null and Child.id = 1 +// select 1 from Child left join Parent on Parent.p1 = Child.c2 + 1 and Parent.p2 = Child.c2 +// where Parent.p1 is null and Parent.p2 is null and Child.id = 1 and Child.c2 + 1 is not null // and Child.c2 is not null // limit 1 func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, updStmt *sqlparser.Update, pFK vindexes.ParentFKInfo) (ops.Operator, error) { @@ -479,7 +568,7 @@ func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, updS var joinExpr sqlparser.Expr if matchedExpr == nil { predicate = &sqlparser.AndExpr{ - Left: parentIsNullExpr, + Left: predicate, Right: &sqlparser.IsExpr{ Left: sqlparser.NewColNameWithQualifier(pFK.ChildColumns[idx].String(), childTbl), Right: sqlparser.IsNotNullOp, @@ -491,10 +580,18 @@ func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, updS Right: sqlparser.NewColNameWithQualifier(pFK.ChildColumns[idx].String(), childTbl), } } else { + prefixedMatchExpr := prefixColNames(childTbl, matchedExpr.Expr) joinExpr = &sqlparser.ComparisonExpr{ Operator: sqlparser.EqualOp, Left: sqlparser.NewColNameWithQualifier(pFK.ParentColumns[idx].String(), parentTbl), - Right: prefixColNames(childTbl, matchedExpr.Expr), + Right: prefixedMatchExpr, + } + predicate = &sqlparser.AndExpr{ + Left: predicate, + Right: &sqlparser.IsExpr{ + Left: prefixedMatchExpr, + Right: sqlparser.IsNotNullOp, + }, } } @@ -519,6 +616,7 @@ func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, updS sqlparser.NewJoinCondition(joinCond, nil)), }, sqlparser.NewWhere(sqlparser.WhereClause, whereCond), + nil, sqlparser.NewLimitWithoutOffset(1), sqlparser.ShareModeLock) } @@ -527,10 +625,10 @@ func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, updS // select 1 from child_tbl join parent_tbl on where [AND ({ IS NULL OR}... NOT IN ())] limit 1 // E.g: // Child (c1, c2) references Parent (p1, p2) -// update Parent set p1 = 1 where id = 1 +// update Parent set p1 = col + 1 where id = 1 // verify query: // select 1 from Child join Parent on Parent.p1 = Child.c1 and Parent.p2 = Child.c2 -// where Parent.id = 1 and (1 IS NULL OR (child.c1) NOT IN ((1))) limit 1 +// where Parent.id = 1 and ((Parent.col + 1) IS NULL OR (child.c1) NOT IN ((Parent.col + 1))) limit 1 func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updStmt *sqlparser.Update, cFk vindexes.ChildFKInfo) (ops.Operator, error) { // ON UPDATE RESTRICT foreign keys that require validation, should only be allowed in the case where we // are verifying all the FKs on vtgate level. @@ -573,7 +671,7 @@ func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updSt // For example, if we are setting `update child cola = :v1 and colb = :v2`, then on the parent, the where condition would look something like this - // `:v1 IS NULL OR :v2 IS NULL OR (cola, colb) NOT IN ((:v1,:v2))` // So, if either of :v1 or :v2 is NULL, then the entire condition is true (which is the same as not having the condition when :v1 or :v2 is NULL). - compExpr := nullSafeNotInComparison(updStmt.Exprs, cFk) + compExpr := nullSafeNotInComparison(updStmt.Exprs, cFk, parentTbl, nil) if compExpr != nil { whereCond = sqlparser.AndExpressions(whereCond, compExpr) } @@ -588,6 +686,7 @@ func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updSt sqlparser.NewJoinCondition(joinCond, nil)), }, sqlparser.NewWhere(sqlparser.WhereClause, whereCond), + nil, sqlparser.NewLimitWithoutOffset(1), sqlparser.ShareModeLock) } @@ -597,22 +696,24 @@ func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updSt // `:v1 IS NULL OR :v2 IS NULL OR (cola, colb) NOT IN ((:v1,:v2))` // So, if either of :v1 or :v2 is NULL, then the entire condition is true (which is the same as not having the condition when :v1 or :v2 is NULL) // This expression is used in cascading SET NULLs and in verifying whether an update should be restricted. -func nullSafeNotInComparison(updateExprs sqlparser.UpdateExprs, cFk vindexes.ChildFKInfo) sqlparser.Expr { +func nullSafeNotInComparison(updateExprs sqlparser.UpdateExprs, cFk vindexes.ChildFKInfo, parentTbl sqlparser.TableName, nonLiteralUpdateInfo []engine.NonLiteralUpdateInfo) sqlparser.Expr { + var valTuple sqlparser.ValTuple var updateValues sqlparser.ValTuple - for _, updateExpr := range updateExprs { + for idx, updateExpr := range updateExprs { colIdx := cFk.ParentColumns.FindColumn(updateExpr.Name.Name) if colIdx >= 0 { if sqlparser.IsNull(updateExpr.Expr) { return nil } - updateValues = append(updateValues, updateExpr.Expr) + childUpdateExpr := prefixColNames(parentTbl, updateExpr.Expr) + if len(nonLiteralUpdateInfo) > 0 && nonLiteralUpdateInfo[idx].UpdateExprBvName != "" { + childUpdateExpr = sqlparser.NewArgument(nonLiteralUpdateInfo[idx].UpdateExprBvName) + } + updateValues = append(updateValues, childUpdateExpr) + valTuple = append(valTuple, sqlparser.NewColNameWithQualifier(cFk.ChildColumns[colIdx].String(), cFk.Table.GetTableName())) } } - // Create a ValTuple of child column names - var valTuple sqlparser.ValTuple - for _, column := range cFk.ChildColumns { - valTuple = append(valTuple, sqlparser.NewColNameWithQualifier(column.String(), cFk.Table.GetTableName())) - } + var finalExpr sqlparser.Expr = sqlparser.NewComparisonExpr(sqlparser.NotInOp, valTuple, sqlparser.ValTuple{updateValues}, nil) for _, value := range updateValues { finalExpr = &sqlparser.OrExpr{ diff --git a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json index 065691d2356..afe42a45720 100644 --- a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json @@ -656,8 +656,8 @@ "Name": "unsharded_fk_allow", "Sharded": false }, - "FieldQuery": "select col1, col1 from u_tbl1 where 1 != 1", - "Query": "select col1, col1 from u_tbl1 for update", + "FieldQuery": "select col1 from u_tbl1 where 1 != 1", + "Query": "select col1 from u_tbl1 for update", "Table": "u_tbl1" }, { @@ -715,7 +715,7 @@ "OperatorType": "FkCascade", "BvName": "fkc_vals2", "Cols": [ - 1 + 0 ], "Inputs": [ { @@ -789,17 +789,248 @@ "plan": "VT12001: unsupported: update with limit with foreign key constraints" }, { - "comment": "update in a table with non-literal value - set null fail due to child update where condition", + "comment": "update in a table with non-literal value - set null", "query": "update u_tbl2 set m = 2, col2 = col1 + 'bar' where id = 1", - "plan": "VT12001: unsupported: update expression with non-literal values with foreign key constraints" + "plan": { + "QueryType": "UPDATE", + "Original": "update u_tbl2 set m = 2, col2 = col1 + 'bar' where id = 1", + "Instructions": { + "OperatorType": "FKVerify", + "Inputs": [ + { + "InputName": "VerifyParent-1", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select 1 from u_tbl2 left join u_tbl1 on u_tbl1.col1 = u_tbl2.col1 + 'bar' where 1 != 1", + "Query": "select 1 from u_tbl2 left join u_tbl1 on u_tbl1.col1 = u_tbl2.col1 + 'bar' where u_tbl2.col1 + 'bar' is not null and u_tbl2.id = 1 and u_tbl1.col1 is null limit 1 lock in share mode", + "Table": "u_tbl1, u_tbl2" + }, + { + "InputName": "PostVerify", + "OperatorType": "FkCascade", + "Inputs": [ + { + "InputName": "Selection", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select col2, col2 <=> u_tbl2.col1 + 'bar', u_tbl2.col1 + 'bar' from u_tbl2 where 1 != 1", + "Query": "select col2, col2 <=> u_tbl2.col1 + 'bar', u_tbl2.col1 + 'bar' from u_tbl2 where u_tbl2.id = 1 for update", + "Table": "u_tbl2" + }, + { + "InputName": "CascadeChild-1", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "BvName": "fkc_vals", + "Cols": [ + 0 + ], + "NonLiteralUpdateInfo": [ + { + "UpdateExprCol": 2, + "UpdateExprBvName": "fkc_upd", + "CompExprCol": 1 + } + ], + "Query": "update u_tbl3 set col3 = null where (col3) in ::fkc_vals and (:fkc_upd is null or (u_tbl3.col3) not in ((:fkc_upd)))", + "Table": "u_tbl3" + }, + { + "InputName": "Parent", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_tbl2 set m = 2, col2 = u_tbl2.col1 + 'bar' where u_tbl2.id = 1", + "Table": "u_tbl2" + } + ] + } + ] + }, + "TablesUsed": [ + "unsharded_fk_allow.u_tbl1", + "unsharded_fk_allow.u_tbl2", + "unsharded_fk_allow.u_tbl3" + ] + } }, { - "comment": "update in a table with non-literal value - with cascade fail as the cascade value is not known", + "comment": "update in a table with non-literal value - with cascade", "query": "update u_tbl1 set m = 2, col1 = x + 'bar' where id = 1", - "plan": "VT12001: unsupported: update expression with non-literal values with foreign key constraints" + "plan": { + "QueryType": "UPDATE", + "Original": "update u_tbl1 set m = 2, col1 = x + 'bar' where id = 1", + "Instructions": { + "OperatorType": "FkCascade", + "Inputs": [ + { + "InputName": "Selection", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select col1, col1 <=> u_tbl1.x + 'bar', u_tbl1.x + 'bar' from u_tbl1 where 1 != 1", + "Query": "select col1, col1 <=> u_tbl1.x + 'bar', u_tbl1.x + 'bar' from u_tbl1 where id = 1 for update", + "Table": "u_tbl1" + }, + { + "InputName": "CascadeChild-1", + "OperatorType": "FkCascade", + "BvName": "fkc_vals", + "Cols": [ + 0 + ], + "NonLiteralUpdateInfo": [ + { + "UpdateExprCol": 2, + "UpdateExprBvName": "fkc_upd", + "CompExprCol": 1 + } + ], + "Inputs": [ + { + "InputName": "Selection", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select col2 from u_tbl2 where 1 != 1", + "Query": "select col2 from u_tbl2 where (col2) in ::fkc_vals for update", + "Table": "u_tbl2" + }, + { + "InputName": "CascadeChild-1", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "BvName": "fkc_vals1", + "Cols": [ + 0 + ], + "Query": "update u_tbl3 set col3 = null where (col3) in ::fkc_vals1 and (:fkc_upd is null or (u_tbl3.col3) not in ((:fkc_upd)))", + "Table": "u_tbl3" + }, + { + "InputName": "Parent", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_tbl2 set col2 = :fkc_upd where (col2) in ::fkc_vals", + "Table": "u_tbl2" + } + ] + }, + { + "InputName": "CascadeChild-2", + "OperatorType": "FkCascade", + "BvName": "fkc_vals2", + "Cols": [ + 0 + ], + "NonLiteralUpdateInfo": [ + { + "UpdateExprCol": 2, + "UpdateExprBvName": "fkc_upd1", + "CompExprCol": 1 + } + ], + "Inputs": [ + { + "InputName": "Selection", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select col9 from u_tbl9 where 1 != 1", + "Query": "select col9 from u_tbl9 where (col9) in ::fkc_vals2 and (:fkc_upd1 is null or (u_tbl9.col9) not in ((:fkc_upd1))) for update", + "Table": "u_tbl9" + }, + { + "InputName": "CascadeChild-1", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "BvName": "fkc_vals3", + "Cols": [ + 0 + ], + "Query": "update u_tbl8 set col8 = null where (col8) in ::fkc_vals3", + "Table": "u_tbl8" + }, + { + "InputName": "Parent", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update u_tbl9 set col9 = null where (col9) in ::fkc_vals2 and (:fkc_upd1 is null or (u_tbl9.col9) not in ((:fkc_upd1)))", + "Table": "u_tbl9" + } + ] + }, + { + "InputName": "Parent", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_tbl1 set m = 2, col1 = u_tbl1.x + 'bar' where id = 1", + "Table": "u_tbl1" + } + ] + }, + "TablesUsed": [ + "unsharded_fk_allow.u_tbl1", + "unsharded_fk_allow.u_tbl2", + "unsharded_fk_allow.u_tbl3", + "unsharded_fk_allow.u_tbl8", + "unsharded_fk_allow.u_tbl9" + ] + } }, { - "comment": "update in a table with set null, non-literal value on non-foreign key column - allowed", + "comment": "update in a table with set null, non-literal value on non-foreign key column", "query": "update u_tbl2 set m = col1 + 'bar', col2 = 2 where id = 1", "plan": { "QueryType": "UPDATE", @@ -856,7 +1087,7 @@ } }, { - "comment": "update in a table with cascade, non-literal value on non-foreign key column - allowed", + "comment": "update in a table with cascade, non-literal value on non-foreign key column", "query": "update u_tbl1 set m = x + 'bar', col1 = 2 where id = 1", "plan": { "QueryType": "UPDATE", @@ -872,8 +1103,8 @@ "Name": "unsharded_fk_allow", "Sharded": false }, - "FieldQuery": "select col1, col1 from u_tbl1 where 1 != 1", - "Query": "select col1, col1 from u_tbl1 where id = 1 for update", + "FieldQuery": "select col1 from u_tbl1 where 1 != 1", + "Query": "select col1 from u_tbl1 where id = 1 for update", "Table": "u_tbl1" }, { @@ -931,7 +1162,7 @@ "OperatorType": "FkCascade", "BvName": "fkc_vals2", "Cols": [ - 1 + 0 ], "Inputs": [ { @@ -1010,9 +1241,87 @@ "plan": "VT12001: unsupported: foreign keys management at vitess with limit" }, { - "comment": "update with fk on cross-shard with a where condition on non-literal value - disallowed", + "comment": "update with fk on cross-shard with a update condition on non-literal value", "query": "update tbl3 set coly = colx + 10 where coly = 10", - "plan": "VT12001: unsupported: update expression with non-literal values with foreign key constraints" + "plan": { + "QueryType": "UPDATE", + "Original": "update tbl3 set coly = colx + 10 where coly = 10", + "Instructions": { + "OperatorType": "FKVerify", + "Inputs": [ + { + "InputName": "VerifyParent-1", + "OperatorType": "Limit", + "Count": "1", + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "1 as 1" + ], + "Inputs": [ + { + "OperatorType": "Filter", + "Predicate": "tbl1.t1col1 is null", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "R:0,R:0", + "JoinVars": { + "tbl3_colx": 0 + }, + "TableName": "tbl3_tbl1", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "sharded_fk_allow", + "Sharded": true + }, + "FieldQuery": "select tbl3.colx from tbl3 where 1 != 1", + "Query": "select tbl3.colx from tbl3 where tbl3.colx + 10 is not null and tbl3.coly = 10 lock in share mode", + "Table": "tbl3" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "sharded_fk_allow", + "Sharded": true + }, + "FieldQuery": "select tbl1.t1col1 from tbl1 where 1 != 1", + "Query": "select tbl1.t1col1 from tbl1 where tbl1.t1col1 = :tbl3_colx + 10 lock in share mode", + "Table": "tbl1" + } + ] + } + ] + } + ] + } + ] + }, + { + "InputName": "PostVerify", + "OperatorType": "Update", + "Variant": "Scatter", + "Keyspace": { + "Name": "sharded_fk_allow", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ tbl3 set coly = tbl3.colx + 10 where tbl3.coly = 10", + "Table": "tbl3" + } + ] + }, + "TablesUsed": [ + "sharded_fk_allow.tbl1", + "sharded_fk_allow.tbl3" + ] + } }, { "comment": "update with fk on cross-shard with a where condition", @@ -1297,7 +1606,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = :v1 where 1 != 1", - "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = :v1 where (u_tbl4.col4) in ::fkc_vals and u_tbl3.col3 is null limit 1 lock in share mode", + "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = :v1 where (u_tbl4.col4) in ::fkc_vals and :v1 is not null and u_tbl3.col3 is null limit 1 lock in share mode", "Table": "u_tbl3, u_tbl4" }, { @@ -1653,5 +1962,301 @@ "sharded_fk_allow.tbl5" ] } + }, + { + "comment": "foreign key column updated by using a column which is also getting updated", + "query": "update u_tbl1 set foo = 100, col1 = baz + 1 + foo where bar = 42", + "plan": "VT12001: unsupported: foo column referenced in foreign key column col1 is itself updated" + }, + { + "comment": "foreign key column updated by using a column which is also getting updated - self reference column is allowed", + "query": "update u_tbl7 set foo = 100, col7 = baz + 1 + col7 where bar = 42", + "plan": { + "QueryType": "UPDATE", + "Original": "update u_tbl7 set foo = 100, col7 = baz + 1 + col7 where bar = 42", + "Instructions": { + "OperatorType": "FkCascade", + "Inputs": [ + { + "InputName": "Selection", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select col7, col7 <=> baz + 1 + col7, baz + 1 + col7 from u_tbl7 where 1 != 1", + "Query": "select col7, col7 <=> baz + 1 + col7, baz + 1 + col7 from u_tbl7 where bar = 42 for update", + "Table": "u_tbl7" + }, + { + "InputName": "CascadeChild-1", + "OperatorType": "FKVerify", + "BvName": "fkc_vals", + "Cols": [ + 0 + ], + "NonLiteralUpdateInfo": [ + { + "UpdateExprCol": 2, + "UpdateExprBvName": "fkc_upd", + "CompExprCol": 1 + } + ], + "Inputs": [ + { + "InputName": "VerifyParent-1", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = :fkc_upd where 1 != 1", + "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = :fkc_upd where (u_tbl4.col4) in ::fkc_vals and :fkc_upd is not null and u_tbl3.col3 is null limit 1 lock in share mode", + "Table": "u_tbl3, u_tbl4" + }, + { + "InputName": "VerifyChild-2", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select 1 from u_tbl4, u_tbl9 where 1 != 1", + "Query": "select 1 from u_tbl4, u_tbl9 where (u_tbl4.col4) in ::fkc_vals and (:fkc_upd is null or (u_tbl9.col9) not in ((:fkc_upd))) and u_tbl4.col4 = u_tbl9.col9 limit 1 lock in share mode", + "Table": "u_tbl4, u_tbl9" + }, + { + "InputName": "PostVerify", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_tbl4 set col4 = :fkc_upd where (u_tbl4.col4) in ::fkc_vals", + "Table": "u_tbl4" + } + ] + }, + { + "InputName": "Parent", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_tbl7 set foo = 100, col7 = baz + 1 + col7 where bar = 42", + "Table": "u_tbl7" + } + ] + }, + "TablesUsed": [ + "unsharded_fk_allow.u_tbl3", + "unsharded_fk_allow.u_tbl4", + "unsharded_fk_allow.u_tbl7", + "unsharded_fk_allow.u_tbl9" + ] + } + }, + { + "comment": "Single column updated in a multi-col table", + "query": "update u_multicol_tbl1 set cola = cola + 3 where id = 3", + "plan": { + "QueryType": "UPDATE", + "Original": "update u_multicol_tbl1 set cola = cola + 3 where id = 3", + "Instructions": { + "OperatorType": "FkCascade", + "Inputs": [ + { + "InputName": "Selection", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select cola, colb, cola <=> u_multicol_tbl1.cola + 3, u_multicol_tbl1.cola + 3 from u_multicol_tbl1 where 1 != 1", + "Query": "select cola, colb, cola <=> u_multicol_tbl1.cola + 3, u_multicol_tbl1.cola + 3 from u_multicol_tbl1 where id = 3 for update", + "Table": "u_multicol_tbl1" + }, + { + "InputName": "CascadeChild-1", + "OperatorType": "FkCascade", + "BvName": "fkc_vals", + "Cols": [ + 0, + 1 + ], + "NonLiteralUpdateInfo": [ + { + "UpdateExprCol": 3, + "UpdateExprBvName": "fkc_upd", + "CompExprCol": 2 + } + ], + "Inputs": [ + { + "InputName": "Selection", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select cola, colb from u_multicol_tbl2 where 1 != 1", + "Query": "select cola, colb from u_multicol_tbl2 where (cola, colb) in ::fkc_vals and (:fkc_upd is null or (u_multicol_tbl2.cola) not in ((:fkc_upd))) for update", + "Table": "u_multicol_tbl2" + }, + { + "InputName": "CascadeChild-1", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "BvName": "fkc_vals1", + "Cols": [ + 0, + 1 + ], + "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_multicol_tbl3 set cola = null, colb = null where (cola, colb) in ::fkc_vals1", + "Table": "u_multicol_tbl3" + }, + { + "InputName": "Parent", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update u_multicol_tbl2 set cola = null, colb = null where (cola, colb) in ::fkc_vals and (:fkc_upd is null or (u_multicol_tbl2.cola) not in ((:fkc_upd)))", + "Table": "u_multicol_tbl2" + } + ] + }, + { + "InputName": "Parent", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_multicol_tbl1 set cola = u_multicol_tbl1.cola + 3 where id = 3", + "Table": "u_multicol_tbl1" + } + ] + }, + "TablesUsed": [ + "unsharded_fk_allow.u_multicol_tbl1", + "unsharded_fk_allow.u_multicol_tbl2", + "unsharded_fk_allow.u_multicol_tbl3" + ] + } + }, + { + "comment": "updating multiple columns of a fk constraint such that one uses the other", + "query": "update u_multicol_tbl3 set cola = id, colb = 5 * (cola + (1 - (cola))) where id = 2", + "plan": "VT12001: unsupported: cola column referenced in foreign key column colb is itself updated" + }, + { + "comment": "multicol foreign key updates with one literal and one non-literal update", + "query": "update u_multicol_tbl2 set cola = 2, colb = colc - (2) where id = 7", + "plan": { + "QueryType": "UPDATE", + "Original": "update u_multicol_tbl2 set cola = 2, colb = colc - (2) where id = 7", + "Instructions": { + "OperatorType": "FKVerify", + "Inputs": [ + { + "InputName": "VerifyParent-1", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select 1 from u_multicol_tbl2 left join u_multicol_tbl1 on u_multicol_tbl1.cola = 2 and u_multicol_tbl1.colb = u_multicol_tbl2.colc - 2 where 1 != 1", + "Query": "select 1 from u_multicol_tbl2 left join u_multicol_tbl1 on u_multicol_tbl1.cola = 2 and u_multicol_tbl1.colb = u_multicol_tbl2.colc - 2 where u_multicol_tbl2.colc - 2 is not null and u_multicol_tbl2.id = 7 and u_multicol_tbl1.cola is null and u_multicol_tbl1.colb is null limit 1 lock in share mode", + "Table": "u_multicol_tbl1, u_multicol_tbl2" + }, + { + "InputName": "PostVerify", + "OperatorType": "FkCascade", + "Inputs": [ + { + "InputName": "Selection", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select cola, colb, cola <=> 2, 2, colb <=> u_multicol_tbl2.colc - 2, u_multicol_tbl2.colc - 2 from u_multicol_tbl2 where 1 != 1", + "Query": "select cola, colb, cola <=> 2, 2, colb <=> u_multicol_tbl2.colc - 2, u_multicol_tbl2.colc - 2 from u_multicol_tbl2 where u_multicol_tbl2.id = 7 for update", + "Table": "u_multicol_tbl2" + }, + { + "InputName": "CascadeChild-1", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "BvName": "fkc_vals", + "Cols": [ + 0, + 1 + ], + "NonLiteralUpdateInfo": [ + { + "UpdateExprCol": 3, + "UpdateExprBvName": "fkc_upd", + "CompExprCol": 2 + }, + { + "UpdateExprCol": 5, + "UpdateExprBvName": "fkc_upd1", + "CompExprCol": 4 + } + ], + "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_multicol_tbl3 set cola = :fkc_upd, colb = :fkc_upd1 where (cola, colb) in ::fkc_vals", + "Table": "u_multicol_tbl3" + }, + { + "InputName": "Parent", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_multicol_tbl2 set cola = 2, colb = u_multicol_tbl2.colc - 2 where u_multicol_tbl2.id = 7", + "Table": "u_multicol_tbl2" + } + ] + } + ] + }, + "TablesUsed": [ + "unsharded_fk_allow.u_multicol_tbl1", + "unsharded_fk_allow.u_multicol_tbl2", + "unsharded_fk_allow.u_multicol_tbl3" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json index cad8e9be1eb..9373ae35a29 100644 --- a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json @@ -88,7 +88,7 @@ "Sharded": true }, "FieldQuery": "select `user`.col1 as a, `user`.col2 from `user` where 1 != 1", - "Query": "select `user`.col1 as a, `user`.col2 from `user` where `user`.col1 = 1 and `user`.col1 = `user`.col2 and 1 = 1", + "Query": "select `user`.col1 as a, `user`.col2 from `user` where `user`.col1 = 1 and `user`.col1 = `user`.col2", "Table": "`user`" }, { @@ -99,7 +99,7 @@ "Sharded": true }, "FieldQuery": "select user_extra.col3 from user_extra where 1 != 1", - "Query": "select user_extra.col3 from user_extra where user_extra.col3 = 1 and 1 = 1", + "Query": "select user_extra.col3 from user_extra where user_extra.col3 = 1", "Table": "user_extra" } ] diff --git a/go/vt/vtgate/planbuilder/update.go b/go/vt/vtgate/planbuilder/update.go index eced4251ab3..dfb841a4d11 100644 --- a/go/vt/vtgate/planbuilder/update.go +++ b/go/vt/vtgate/planbuilder/update.go @@ -46,6 +46,13 @@ func gen4UpdateStmtPlanner( return nil, err } + // If there are non-literal foreign key updates, we have to run the query with foreign key checks off. + if ctx.SemTable.HasNonLiteralForeignKeyUpdate(updStmt.Exprs) { + // Since we are running the query with foreign key checks off, we have to verify all the foreign keys validity on vtgate. + ctx.VerifyAllFKs = true + updStmt.Comments = updStmt.Comments.Prepend("/*+ SET_VAR(foreign_key_checks=OFF) */").Parsed() + } + // Remove all the foreign keys that don't require any handling. err = ctx.SemTable.RemoveNonRequiredForeignKeys(ctx.VerifyAllFKs, vindexes.UpdateAction) if err != nil { diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index e524b1a33cf..96c604d76a1 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -108,7 +108,7 @@ func (a *analyzer) newSemTable(statement sqlparser.Statement, coll collations.ID columns[union] = info.exprs } - childFks, parentFks, err := a.getInvolvedForeignKeys(statement) + childFks, parentFks, childFkToUpdExprs, err := a.getInvolvedForeignKeys(statement) if err != nil { return nil, err } @@ -130,6 +130,7 @@ func (a *analyzer) newSemTable(statement sqlparser.Statement, coll collations.ID QuerySignature: a.sig, childForeignKeysInvolved: childFks, parentForeignKeysInvolved: parentFks, + childFkToUpdExprs: childFkToUpdExprs, }, nil } @@ -317,14 +318,14 @@ func (a *analyzer) noteQuerySignature(node sqlparser.SQLNode) { } // getInvolvedForeignKeys gets the foreign keys that might require taking care off when executing the given statement. -func (a *analyzer) getInvolvedForeignKeys(statement sqlparser.Statement) (map[TableSet][]vindexes.ChildFKInfo, map[TableSet][]vindexes.ParentFKInfo, error) { +func (a *analyzer) getInvolvedForeignKeys(statement sqlparser.Statement) (map[TableSet][]vindexes.ChildFKInfo, map[TableSet][]vindexes.ParentFKInfo, map[string]sqlparser.UpdateExprs, error) { // There are only the DML statements that require any foreign keys handling. switch stmt := statement.(type) { case *sqlparser.Delete: // For DELETE statements, none of the parent foreign keys require handling. // So we collect all the child foreign keys. allChildFks, _, err := a.getAllManagedForeignKeys() - return allChildFks, nil, err + return allChildFks, nil, nil, err case *sqlparser.Insert: // For INSERT statements, we have 3 different cases: // 1. REPLACE statement: REPLACE statements are essentially DELETEs and INSERTs rolled into one. @@ -334,35 +335,35 @@ func (a *analyzer) getInvolvedForeignKeys(statement sqlparser.Statement) (map[Ta // 3. INSERT with ON DUPLICATE KEY UPDATE: This might trigger an update on the columns specified in the ON DUPLICATE KEY UPDATE clause. allChildFks, allParentFKs, err := a.getAllManagedForeignKeys() if err != nil { - return nil, nil, err + return nil, nil, nil, err } if stmt.Action == sqlparser.ReplaceAct { - return allChildFks, allParentFKs, nil + return allChildFks, allParentFKs, nil, nil } if len(stmt.OnDup) == 0 { - return nil, allParentFKs, nil + return nil, allParentFKs, nil, nil } // If only a certain set of columns are being updated, then there might be some child foreign keys that don't need any consideration since their columns aren't being updated. // So, we filter these child foreign keys out. We can't filter any parent foreign keys because the statement will INSERT a row too, which requires validating all the parent foreign keys. - updatedChildFks, _ := a.filterForeignKeysUsingUpdateExpressions(allChildFks, nil, sqlparser.UpdateExprs(stmt.OnDup)) - return updatedChildFks, allParentFKs, nil + updatedChildFks, _, childFkToUpdExprs := a.filterForeignKeysUsingUpdateExpressions(allChildFks, nil, sqlparser.UpdateExprs(stmt.OnDup)) + return updatedChildFks, allParentFKs, childFkToUpdExprs, nil case *sqlparser.Update: // For UPDATE queries we get all the parent and child foreign keys, but we can filter some of them out if the columns that they consist off aren't being updated or are set to NULLs. allChildFks, allParentFks, err := a.getAllManagedForeignKeys() if err != nil { - return nil, nil, err + return nil, nil, nil, err } - childFks, parentFks := a.filterForeignKeysUsingUpdateExpressions(allChildFks, allParentFks, stmt.Exprs) - return childFks, parentFks, nil + childFks, parentFks, childFkToUpdExprs := a.filterForeignKeysUsingUpdateExpressions(allChildFks, allParentFks, stmt.Exprs) + return childFks, parentFks, childFkToUpdExprs, nil default: - return nil, nil, nil + return nil, nil, nil, nil } } // filterForeignKeysUsingUpdateExpressions filters the child and parent foreign key constraints that don't require any validations/cascades given the updated expressions. -func (a *analyzer) filterForeignKeysUsingUpdateExpressions(allChildFks map[TableSet][]vindexes.ChildFKInfo, allParentFks map[TableSet][]vindexes.ParentFKInfo, updExprs sqlparser.UpdateExprs) (map[TableSet][]vindexes.ChildFKInfo, map[TableSet][]vindexes.ParentFKInfo) { +func (a *analyzer) filterForeignKeysUsingUpdateExpressions(allChildFks map[TableSet][]vindexes.ChildFKInfo, allParentFks map[TableSet][]vindexes.ParentFKInfo, updExprs sqlparser.UpdateExprs) (map[TableSet][]vindexes.ChildFKInfo, map[TableSet][]vindexes.ParentFKInfo, map[string]sqlparser.UpdateExprs) { if len(allChildFks) == 0 && len(allParentFks) == 0 { - return nil, nil + return nil, nil, nil } pFksRequired := make(map[TableSet][]bool, len(allParentFks)) @@ -377,6 +378,9 @@ func (a *analyzer) filterForeignKeysUsingUpdateExpressions(allChildFks map[Table // updExprToTableSet stores the tables that the updated expressions are from. updExprToTableSet := make(map[*sqlparser.ColName]TableSet) + // childFKToUpdExprs stores child foreign key to update expressions mapping. + childFKToUpdExprs := map[string]sqlparser.UpdateExprs{} + // Go over all the update expressions for _, updateExpr := range updExprs { deps := a.binder.direct.dependencies(updateExpr.Name) @@ -393,6 +397,10 @@ func (a *analyzer) filterForeignKeysUsingUpdateExpressions(allChildFks map[Table for idx, childFk := range childFks { if childFk.ParentColumns.FindColumn(updateExpr.Name.Name) >= 0 { cFksRequired[deps][idx] = true + tbl, _ := a.tables.tableInfoFor(deps) + ue := childFKToUpdExprs[childFk.String(tbl.GetVindexTable())] + ue = append(ue, updateExpr) + childFKToUpdExprs[childFk.String(tbl.GetVindexTable())] = ue } } // If we are setting a column to NULL, then we don't need to verify the existance of an @@ -434,7 +442,6 @@ func (a *analyzer) filterForeignKeysUsingUpdateExpressions(allChildFks map[Table } } pFksNeedsHandling[ts] = pFKNeeded - } for ts, childFks := range allChildFks { var cFKNeeded []vindexes.ChildFKInfo @@ -444,9 +451,8 @@ func (a *analyzer) filterForeignKeysUsingUpdateExpressions(allChildFks map[Table } } cFksNeedsHandling[ts] = cFKNeeded - } - return cFksNeedsHandling, pFksNeedsHandling + return cFksNeedsHandling, pFksNeedsHandling, childFKToUpdExprs } // getAllManagedForeignKeys gets all the foreign keys for the query we are analyzing that Vitess is reposible for managing. diff --git a/go/vt/vtgate/semantics/analyzer_fk_test.go b/go/vt/vtgate/semantics/analyzer_fk_test.go new file mode 100644 index 00000000000..5ba6041ef5c --- /dev/null +++ b/go/vt/vtgate/semantics/analyzer_fk_test.go @@ -0,0 +1,582 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package semantics + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + vschemapb "vitess.io/vitess/go/vt/proto/vschema" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/vindexes" +) + +var parentTbl = &vindexes.Table{ + Name: sqlparser.NewIdentifierCS("parentt"), + Keyspace: &vindexes.Keyspace{ + Name: "ks", + }, +} + +var tbl = map[string]TableInfo{ + "t0": &RealTable{ + Table: &vindexes.Table{ + Name: sqlparser.NewIdentifierCS("t0"), + Keyspace: &vindexes.Keyspace{Name: "ks"}, + ChildForeignKeys: []vindexes.ChildFKInfo{ + ckInfo(parentTbl, []string{"col"}, []string{"col"}, sqlparser.Restrict), + ckInfo(parentTbl, []string{"col1", "col2"}, []string{"ccol1", "ccol2"}, sqlparser.SetNull), + }, + ParentForeignKeys: []vindexes.ParentFKInfo{ + pkInfo(parentTbl, []string{"colb"}, []string{"colb"}), + pkInfo(parentTbl, []string{"colb1", "colb2"}, []string{"ccolb1", "ccolb2"}), + }, + }, + }, + "t1": &RealTable{ + Table: &vindexes.Table{ + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: &vindexes.Keyspace{Name: "ks_unmanaged", Sharded: true}, + ChildForeignKeys: []vindexes.ChildFKInfo{ + ckInfo(parentTbl, []string{"cola"}, []string{"cola"}, sqlparser.Restrict), + ckInfo(parentTbl, []string{"cola1", "cola2"}, []string{"ccola1", "ccola2"}, sqlparser.SetNull), + }, + }, + }, + "t2": &RealTable{ + Table: &vindexes.Table{ + Name: sqlparser.NewIdentifierCS("t2"), + Keyspace: &vindexes.Keyspace{Name: "ks"}, + }, + }, + "t3": &RealTable{ + Table: &vindexes.Table{ + Name: sqlparser.NewIdentifierCS("t3"), + Keyspace: &vindexes.Keyspace{Name: "undefined_ks", Sharded: true}, + }, + }, + "t4": &RealTable{ + Table: &vindexes.Table{ + Name: sqlparser.NewIdentifierCS("t4"), + Keyspace: &vindexes.Keyspace{Name: "ks"}, + ChildForeignKeys: []vindexes.ChildFKInfo{ + ckInfo(parentTbl, []string{"colb"}, []string{"child_colb"}, sqlparser.Restrict), + ckInfo(parentTbl, []string{"cola", "colx"}, []string{"child_cola", "child_colx"}, sqlparser.SetNull), + ckInfo(parentTbl, []string{"colx", "coly"}, []string{"child_colx", "child_coly"}, sqlparser.Cascade), + ckInfo(parentTbl, []string{"cold"}, []string{"child_cold"}, sqlparser.Restrict), + }, + ParentForeignKeys: []vindexes.ParentFKInfo{ + pkInfo(parentTbl, []string{"pcola", "pcolx"}, []string{"cola", "colx"}), + pkInfo(parentTbl, []string{"pcolc"}, []string{"colc"}), + pkInfo(parentTbl, []string{"pcolb", "pcola"}, []string{"colb", "cola"}), + pkInfo(parentTbl, []string{"pcolb"}, []string{"colb"}), + pkInfo(parentTbl, []string{"pcola"}, []string{"cola"}), + pkInfo(parentTbl, []string{"pcolb", "pcolx"}, []string{"colb", "colx"}), + }, + }, + }, + "t5": &RealTable{ + Table: &vindexes.Table{ + Name: sqlparser.NewIdentifierCS("t5"), + Keyspace: &vindexes.Keyspace{Name: "ks"}, + ChildForeignKeys: []vindexes.ChildFKInfo{ + ckInfo(parentTbl, []string{"cold"}, []string{"child_cold"}, sqlparser.Restrict), + ckInfo(parentTbl, []string{"colc", "colx"}, []string{"child_colc", "child_colx"}, sqlparser.SetNull), + ckInfo(parentTbl, []string{"colx", "coly"}, []string{"child_colx", "child_coly"}, sqlparser.Cascade), + }, + ParentForeignKeys: []vindexes.ParentFKInfo{ + pkInfo(parentTbl, []string{"pcolc", "pcolx"}, []string{"colc", "colx"}), + pkInfo(parentTbl, []string{"pcola"}, []string{"cola"}), + pkInfo(parentTbl, []string{"pcold", "pcolc"}, []string{"cold", "colc"}), + pkInfo(parentTbl, []string{"pcold"}, []string{"cold"}), + pkInfo(parentTbl, []string{"pcold", "pcolx"}, []string{"cold", "colx"}), + }, + }, + }, + "t6": &RealTable{ + Table: &vindexes.Table{ + Name: sqlparser.NewIdentifierCS("t6"), + Keyspace: &vindexes.Keyspace{Name: "ks"}, + ChildForeignKeys: []vindexes.ChildFKInfo{ + ckInfo(parentTbl, []string{"col"}, []string{"col"}, sqlparser.Restrict), + ckInfo(parentTbl, []string{"col1", "col2"}, []string{"ccol1", "ccol2"}, sqlparser.SetNull), + ckInfo(parentTbl, []string{"colb"}, []string{"child_colb"}, sqlparser.Restrict), + ckInfo(parentTbl, []string{"cola", "colx"}, []string{"child_cola", "child_colx"}, sqlparser.SetNull), + ckInfo(parentTbl, []string{"colx", "coly"}, []string{"child_colx", "child_coly"}, sqlparser.Cascade), + ckInfo(parentTbl, []string{"cold"}, []string{"child_cold"}, sqlparser.Restrict), + }, + ParentForeignKeys: []vindexes.ParentFKInfo{ + pkInfo(parentTbl, []string{"colb"}, []string{"colb"}), + pkInfo(parentTbl, []string{"colb1", "colb2"}, []string{"ccolb1", "ccolb2"}), + }, + }, + }, +} + +// TestGetAllManagedForeignKeys tests the functionality of getAllManagedForeignKeys. +func TestGetAllManagedForeignKeys(t *testing.T) { + tests := []struct { + name string + analyzer *analyzer + childFkWanted map[TableSet][]vindexes.ChildFKInfo + parentFkWanted map[TableSet][]vindexes.ParentFKInfo + expectedErr string + }{ + { + name: "Collect all foreign key constraints", + analyzer: &analyzer{ + tables: &tableCollector{ + Tables: []TableInfo{ + tbl["t0"], + tbl["t1"], + &DerivedTable{}, + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + "ks_unmanaged": vschemapb.Keyspace_unmanaged, + }, + }, + }, + }, + childFkWanted: map[TableSet][]vindexes.ChildFKInfo{ + SingleTableSet(0): { + ckInfo(parentTbl, []string{"col"}, []string{"col"}, sqlparser.Restrict), + ckInfo(parentTbl, []string{"col1", "col2"}, []string{"ccol1", "ccol2"}, sqlparser.SetNull), + }, + }, + parentFkWanted: map[TableSet][]vindexes.ParentFKInfo{ + SingleTableSet(0): { + pkInfo(parentTbl, []string{"colb"}, []string{"colb"}), + pkInfo(parentTbl, []string{"colb1", "colb2"}, []string{"ccolb1", "ccolb2"}), + }, + }, + }, + { + name: "keyspace not found in schema information", + analyzer: &analyzer{ + tables: &tableCollector{ + Tables: []TableInfo{ + tbl["t2"], + tbl["t3"], + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + }, + }, + }, + }, + expectedErr: "undefined_ks keyspace not found", + }, + { + name: "Cyclic fk constraints error", + analyzer: &analyzer{ + tables: &tableCollector{ + Tables: []TableInfo{ + tbl["t0"], tbl["t1"], + &DerivedTable{}, + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + "ks_unmanaged": vschemapb.Keyspace_unmanaged, + }, + KsError: map[string]error{ + "ks": fmt.Errorf("VT09019: ks has cyclic foreign keys"), + }, + }, + }, + }, + expectedErr: "VT09019: ks has cyclic foreign keys", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + childFk, parentFk, err := tt.analyzer.getAllManagedForeignKeys() + if tt.expectedErr != "" { + require.EqualError(t, err, tt.expectedErr) + return + } + require.EqualValues(t, tt.childFkWanted, childFk) + require.EqualValues(t, tt.parentFkWanted, parentFk) + }) + } +} + +// TestFilterForeignKeysUsingUpdateExpressions tests the functionality of filterForeignKeysUsingUpdateExpressions. +func TestFilterForeignKeysUsingUpdateExpressions(t *testing.T) { + cola := sqlparser.NewColName("cola") + colb := sqlparser.NewColName("colb") + colc := sqlparser.NewColName("colc") + cold := sqlparser.NewColName("cold") + a := &analyzer{ + binder: &binder{ + direct: map[sqlparser.Expr]TableSet{ + cola: SingleTableSet(0), + colb: SingleTableSet(0), + colc: SingleTableSet(1), + cold: SingleTableSet(1), + }, + }, + tables: &tableCollector{ + Tables: []TableInfo{ + tbl["t4"], + tbl["t5"], + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + }, + }, + }, + } + updateExprs := sqlparser.UpdateExprs{ + &sqlparser.UpdateExpr{Name: cola, Expr: sqlparser.NewIntLiteral("1")}, + &sqlparser.UpdateExpr{Name: colb, Expr: &sqlparser.NullVal{}}, + &sqlparser.UpdateExpr{Name: colc, Expr: sqlparser.NewIntLiteral("1")}, + &sqlparser.UpdateExpr{Name: cold, Expr: &sqlparser.NullVal{}}, + } + tests := []struct { + name string + analyzer *analyzer + allChildFks map[TableSet][]vindexes.ChildFKInfo + allParentFks map[TableSet][]vindexes.ParentFKInfo + updExprs sqlparser.UpdateExprs + childFksWanted map[TableSet][]vindexes.ChildFKInfo + parentFksWanted map[TableSet][]vindexes.ParentFKInfo + }{ + { + name: "Child Foreign Keys Filtering", + analyzer: a, + allParentFks: nil, + allChildFks: map[TableSet][]vindexes.ChildFKInfo{ + SingleTableSet(0): tbl["t4"].(*RealTable).Table.ChildForeignKeys, + SingleTableSet(1): tbl["t5"].(*RealTable).Table.ChildForeignKeys, + }, + updExprs: updateExprs, + childFksWanted: map[TableSet][]vindexes.ChildFKInfo{ + SingleTableSet(0): { + ckInfo(parentTbl, []string{"colb"}, []string{"child_colb"}, sqlparser.Restrict), + ckInfo(parentTbl, []string{"cola", "colx"}, []string{"child_cola", "child_colx"}, sqlparser.SetNull), + }, + SingleTableSet(1): { + ckInfo(parentTbl, []string{"cold"}, []string{"child_cold"}, sqlparser.Restrict), + ckInfo(parentTbl, []string{"colc", "colx"}, []string{"child_colc", "child_colx"}, sqlparser.SetNull), + }, + }, + parentFksWanted: map[TableSet][]vindexes.ParentFKInfo{}, + }, { + name: "Parent Foreign Keys Filtering", + analyzer: a, + allParentFks: map[TableSet][]vindexes.ParentFKInfo{ + SingleTableSet(0): tbl["t4"].(*RealTable).Table.ParentForeignKeys, + SingleTableSet(1): tbl["t5"].(*RealTable).Table.ParentForeignKeys, + }, + allChildFks: nil, + updExprs: updateExprs, + childFksWanted: map[TableSet][]vindexes.ChildFKInfo{}, + parentFksWanted: map[TableSet][]vindexes.ParentFKInfo{ + SingleTableSet(0): { + pkInfo(parentTbl, []string{"pcola", "pcolx"}, []string{"cola", "colx"}), + pkInfo(parentTbl, []string{"pcola"}, []string{"cola"}), + }, + SingleTableSet(1): { + pkInfo(parentTbl, []string{"pcolc", "pcolx"}, []string{"colc", "colx"}), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + childFks, parentFks, _ := tt.analyzer.filterForeignKeysUsingUpdateExpressions(tt.allChildFks, tt.allParentFks, tt.updExprs) + require.EqualValues(t, tt.childFksWanted, childFks) + require.EqualValues(t, tt.parentFksWanted, parentFks) + }) + } +} + +// TestGetInvolvedForeignKeys tests the functionality of getInvolvedForeignKeys. +func TestGetInvolvedForeignKeys(t *testing.T) { + cola := sqlparser.NewColName("cola") + colb := sqlparser.NewColName("colb") + colc := sqlparser.NewColName("colc") + cold := sqlparser.NewColName("cold") + tests := []struct { + name string + stmt sqlparser.Statement + analyzer *analyzer + childFksWanted map[TableSet][]vindexes.ChildFKInfo + parentFksWanted map[TableSet][]vindexes.ParentFKInfo + childFkUpdateExprsWanted map[string]sqlparser.UpdateExprs + expectedErr string + }{ + { + name: "Delete Query", + stmt: &sqlparser.Delete{}, + analyzer: &analyzer{ + tables: &tableCollector{ + Tables: []TableInfo{ + tbl["t0"], + tbl["t1"], + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + "ks_unmanaged": vschemapb.Keyspace_unmanaged, + }, + }, + }, + }, + childFksWanted: map[TableSet][]vindexes.ChildFKInfo{ + SingleTableSet(0): { + ckInfo(parentTbl, []string{"col"}, []string{"col"}, sqlparser.Restrict), + ckInfo(parentTbl, []string{"col1", "col2"}, []string{"ccol1", "ccol2"}, sqlparser.SetNull), + }, + }, + }, + { + name: "Update statement", + stmt: &sqlparser.Update{ + Exprs: sqlparser.UpdateExprs{ + &sqlparser.UpdateExpr{Name: cola, Expr: sqlparser.NewIntLiteral("1")}, + &sqlparser.UpdateExpr{Name: colb, Expr: &sqlparser.NullVal{}}, + &sqlparser.UpdateExpr{Name: colc, Expr: sqlparser.NewIntLiteral("1")}, + &sqlparser.UpdateExpr{Name: cold, Expr: &sqlparser.NullVal{}}, + }, + }, + analyzer: &analyzer{ + binder: &binder{ + direct: map[sqlparser.Expr]TableSet{ + cola: SingleTableSet(0), + colb: SingleTableSet(0), + colc: SingleTableSet(1), + cold: SingleTableSet(1), + }, + }, + tables: &tableCollector{ + Tables: []TableInfo{ + tbl["t4"], + tbl["t5"], + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + }, + }, + }, + }, + childFksWanted: map[TableSet][]vindexes.ChildFKInfo{ + SingleTableSet(0): { + ckInfo(parentTbl, []string{"colb"}, []string{"child_colb"}, sqlparser.Restrict), + ckInfo(parentTbl, []string{"cola", "colx"}, []string{"child_cola", "child_colx"}, sqlparser.SetNull), + }, + SingleTableSet(1): { + ckInfo(parentTbl, []string{"cold"}, []string{"child_cold"}, sqlparser.Restrict), + ckInfo(parentTbl, []string{"colc", "colx"}, []string{"child_colc", "child_colx"}, sqlparser.SetNull), + }, + }, + parentFksWanted: map[TableSet][]vindexes.ParentFKInfo{ + SingleTableSet(0): { + pkInfo(parentTbl, []string{"pcola", "pcolx"}, []string{"cola", "colx"}), + pkInfo(parentTbl, []string{"pcola"}, []string{"cola"}), + }, + SingleTableSet(1): { + pkInfo(parentTbl, []string{"pcolc", "pcolx"}, []string{"colc", "colx"}), + }, + }, + childFkUpdateExprsWanted: map[string]sqlparser.UpdateExprs{ + "ks.parentt|child_cola|child_colx||ks.t4|cola|colx": {&sqlparser.UpdateExpr{Name: cola, Expr: sqlparser.NewIntLiteral("1")}}, + "ks.parentt|child_colb||ks.t4|colb": {&sqlparser.UpdateExpr{Name: colb, Expr: &sqlparser.NullVal{}}}, + "ks.parentt|child_colc|child_colx||ks.t5|colc|colx": {&sqlparser.UpdateExpr{Name: colc, Expr: sqlparser.NewIntLiteral("1")}}, + "ks.parentt|child_cold||ks.t5|cold": {&sqlparser.UpdateExpr{Name: cold, Expr: &sqlparser.NullVal{}}}, + }, + }, + { + name: "Replace Query", + stmt: &sqlparser.Insert{ + Action: sqlparser.ReplaceAct, + }, + analyzer: &analyzer{ + tables: &tableCollector{ + Tables: []TableInfo{ + tbl["t0"], + tbl["t1"], + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + "ks_unmanaged": vschemapb.Keyspace_unmanaged, + }, + }, + }, + }, + childFksWanted: map[TableSet][]vindexes.ChildFKInfo{ + SingleTableSet(0): { + ckInfo(parentTbl, []string{"col"}, []string{"col"}, sqlparser.Restrict), + ckInfo(parentTbl, []string{"col1", "col2"}, []string{"ccol1", "ccol2"}, sqlparser.SetNull), + }, + }, + parentFksWanted: map[TableSet][]vindexes.ParentFKInfo{ + SingleTableSet(0): { + pkInfo(parentTbl, []string{"colb"}, []string{"colb"}), + pkInfo(parentTbl, []string{"colb1", "colb2"}, []string{"ccolb1", "ccolb2"}), + }, + }, + }, + { + name: "Insert Query", + stmt: &sqlparser.Insert{ + Action: sqlparser.InsertAct, + }, + analyzer: &analyzer{ + tables: &tableCollector{ + Tables: []TableInfo{ + tbl["t0"], + tbl["t1"], + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + "ks_unmanaged": vschemapb.Keyspace_unmanaged, + }, + }, + }, + }, + childFksWanted: nil, + parentFksWanted: map[TableSet][]vindexes.ParentFKInfo{ + SingleTableSet(0): { + pkInfo(parentTbl, []string{"colb"}, []string{"colb"}), + pkInfo(parentTbl, []string{"colb1", "colb2"}, []string{"ccolb1", "ccolb2"}), + }, + }, + }, + { + name: "Insert Query with On Duplicate", + stmt: &sqlparser.Insert{ + Action: sqlparser.InsertAct, + OnDup: sqlparser.OnDup{ + &sqlparser.UpdateExpr{Name: cola, Expr: sqlparser.NewIntLiteral("1")}, + &sqlparser.UpdateExpr{Name: colb, Expr: &sqlparser.NullVal{}}, + }, + }, + analyzer: &analyzer{ + binder: &binder{ + direct: map[sqlparser.Expr]TableSet{ + cola: SingleTableSet(0), + colb: SingleTableSet(0), + }, + }, + tables: &tableCollector{ + Tables: []TableInfo{ + tbl["t6"], + tbl["t1"], + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + "ks_unmanaged": vschemapb.Keyspace_unmanaged, + }, + }, + }, + }, + childFksWanted: map[TableSet][]vindexes.ChildFKInfo{ + SingleTableSet(0): { + ckInfo(parentTbl, []string{"colb"}, []string{"child_colb"}, sqlparser.Restrict), + ckInfo(parentTbl, []string{"cola", "colx"}, []string{"child_cola", "child_colx"}, sqlparser.SetNull), + }, + }, + parentFksWanted: map[TableSet][]vindexes.ParentFKInfo{ + SingleTableSet(0): { + pkInfo(parentTbl, []string{"colb"}, []string{"colb"}), + pkInfo(parentTbl, []string{"colb1", "colb2"}, []string{"ccolb1", "ccolb2"}), + }, + }, + childFkUpdateExprsWanted: map[string]sqlparser.UpdateExprs{ + "ks.parentt|child_cola|child_colx||ks.t6|cola|colx": {&sqlparser.UpdateExpr{Name: cola, Expr: sqlparser.NewIntLiteral("1")}}, + "ks.parentt|child_colb||ks.t6|colb": {&sqlparser.UpdateExpr{Name: colb, Expr: &sqlparser.NullVal{}}}, + }, + }, + { + name: "Insert error", + stmt: &sqlparser.Insert{}, + analyzer: &analyzer{ + tables: &tableCollector{ + Tables: []TableInfo{ + tbl["t2"], + tbl["t3"], + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + }, + }, + }, + }, + expectedErr: "undefined_ks keyspace not found", + }, + { + name: "Update error", + stmt: &sqlparser.Update{}, + analyzer: &analyzer{ + tables: &tableCollector{ + Tables: []TableInfo{ + tbl["t2"], + tbl["t3"], + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + }, + }, + }, + }, + expectedErr: "undefined_ks keyspace not found", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + childFks, parentFks, childFkUpdateExprs, err := tt.analyzer.getInvolvedForeignKeys(tt.stmt) + if tt.expectedErr != "" { + require.EqualError(t, err, tt.expectedErr) + return + } + require.EqualValues(t, tt.childFksWanted, childFks) + require.EqualValues(t, tt.childFkUpdateExprsWanted, childFkUpdateExprs) + require.EqualValues(t, tt.parentFksWanted, parentFks) + }) + } +} + +func ckInfo(cTable *vindexes.Table, pCols []string, cCols []string, refAction sqlparser.ReferenceAction) vindexes.ChildFKInfo { + return vindexes.ChildFKInfo{ + Table: cTable, + ParentColumns: sqlparser.MakeColumns(pCols...), + ChildColumns: sqlparser.MakeColumns(cCols...), + OnDelete: refAction, + } +} + +func pkInfo(parentTable *vindexes.Table, pCols []string, cCols []string) vindexes.ParentFKInfo { + return vindexes.ParentFKInfo{ + Table: parentTable, + ParentColumns: sqlparser.MakeColumns(pCols...), + ChildColumns: sqlparser.MakeColumns(cCols...), + } +} diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index c8251dd36c3..cb19d5deafd 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -17,7 +17,6 @@ limitations under the License. package semantics import ( - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -25,7 +24,6 @@ import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" - vschemapb "vitess.io/vitess/go/vt/proto/vschema" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -1627,555 +1625,3 @@ func fakeSchemaInfo() *FakeSI { } return si } - -var tbl = map[string]TableInfo{ - "t0": &RealTable{ - Table: &vindexes.Table{ - Keyspace: &vindexes.Keyspace{Name: "ks"}, - ChildForeignKeys: []vindexes.ChildFKInfo{ - ckInfo(nil, []string{"col"}, []string{"col"}, sqlparser.Restrict), - ckInfo(nil, []string{"col1", "col2"}, []string{"ccol1", "ccol2"}, sqlparser.SetNull), - }, - ParentForeignKeys: []vindexes.ParentFKInfo{ - pkInfo(nil, []string{"colb"}, []string{"colb"}), - pkInfo(nil, []string{"colb1", "colb2"}, []string{"ccolb1", "ccolb2"}), - }, - }, - }, - "t1": &RealTable{ - Table: &vindexes.Table{ - Keyspace: &vindexes.Keyspace{Name: "ks_unmanaged", Sharded: true}, - ChildForeignKeys: []vindexes.ChildFKInfo{ - ckInfo(nil, []string{"cola"}, []string{"cola"}, sqlparser.Restrict), - ckInfo(nil, []string{"cola1", "cola2"}, []string{"ccola1", "ccola2"}, sqlparser.SetNull), - }, - }, - }, - "t2": &RealTable{ - Table: &vindexes.Table{ - Keyspace: &vindexes.Keyspace{Name: "ks"}, - }, - }, - "t3": &RealTable{ - Table: &vindexes.Table{ - Keyspace: &vindexes.Keyspace{Name: "undefined_ks", Sharded: true}, - }, - }, -} - -// TestGetAllManagedForeignKeys tests the functionality of getAllManagedForeignKeys. -func TestGetAllManagedForeignKeys(t *testing.T) { - tests := []struct { - name string - analyzer *analyzer - childFkWanted map[TableSet][]vindexes.ChildFKInfo - parentFkWanted map[TableSet][]vindexes.ParentFKInfo - expectedErr string - }{ - { - name: "Collect all foreign key constraints", - analyzer: &analyzer{ - tables: &tableCollector{ - Tables: []TableInfo{tbl["t0"], tbl["t1"], - &DerivedTable{}, - }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - "ks_unmanaged": vschemapb.Keyspace_unmanaged, - }, - }, - }, - }, - childFkWanted: map[TableSet][]vindexes.ChildFKInfo{ - SingleTableSet(0): { - ckInfo(nil, []string{"col"}, []string{"col"}, sqlparser.Restrict), - ckInfo(nil, []string{"col1", "col2"}, []string{"ccol1", "ccol2"}, sqlparser.SetNull), - }, - }, - parentFkWanted: map[TableSet][]vindexes.ParentFKInfo{ - SingleTableSet(0): { - pkInfo(nil, []string{"colb"}, []string{"colb"}), - pkInfo(nil, []string{"colb1", "colb2"}, []string{"ccolb1", "ccolb2"}), - }, - }, - }, - { - name: "keyspace not found in schema information", - analyzer: &analyzer{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t2"], - tbl["t3"], - }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - }, - }, - }, - }, - expectedErr: "undefined_ks keyspace not found", - }, - { - name: "Cyclic fk constraints error", - analyzer: &analyzer{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t0"], tbl["t1"], - &DerivedTable{}, - }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - "ks_unmanaged": vschemapb.Keyspace_unmanaged, - }, - KsError: map[string]error{ - "ks": fmt.Errorf("VT09019: ks has cyclic foreign keys"), - }, - }, - }, - }, - expectedErr: "VT09019: ks has cyclic foreign keys", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - childFk, parentFk, err := tt.analyzer.getAllManagedForeignKeys() - if tt.expectedErr != "" { - require.EqualError(t, err, tt.expectedErr) - return - } - require.EqualValues(t, tt.childFkWanted, childFk) - require.EqualValues(t, tt.parentFkWanted, parentFk) - }) - } -} - -// TestFilterForeignKeysUsingUpdateExpressions tests the functionality of filterForeignKeysUsingUpdateExpressions. -func TestFilterForeignKeysUsingUpdateExpressions(t *testing.T) { - cola := sqlparser.NewColName("cola") - colb := sqlparser.NewColName("colb") - colc := sqlparser.NewColName("colc") - cold := sqlparser.NewColName("cold") - a := &analyzer{ - binder: &binder{ - direct: map[sqlparser.Expr]TableSet{ - cola: SingleTableSet(0), - colb: SingleTableSet(0), - colc: SingleTableSet(1), - cold: SingleTableSet(1), - }, - }, - } - updateExprs := sqlparser.UpdateExprs{ - &sqlparser.UpdateExpr{Name: cola, Expr: sqlparser.NewIntLiteral("1")}, - &sqlparser.UpdateExpr{Name: colb, Expr: &sqlparser.NullVal{}}, - &sqlparser.UpdateExpr{Name: colc, Expr: sqlparser.NewIntLiteral("1")}, - &sqlparser.UpdateExpr{Name: cold, Expr: &sqlparser.NullVal{}}, - } - tests := []struct { - name string - analyzer *analyzer - allChildFks map[TableSet][]vindexes.ChildFKInfo - allParentFks map[TableSet][]vindexes.ParentFKInfo - updExprs sqlparser.UpdateExprs - childFksWanted map[TableSet][]vindexes.ChildFKInfo - parentFksWanted map[TableSet][]vindexes.ParentFKInfo - }{ - { - name: "Child Foreign Keys Filtering", - analyzer: a, - allParentFks: nil, - allChildFks: map[TableSet][]vindexes.ChildFKInfo{ - SingleTableSet(0): { - ckInfo(nil, []string{"colb"}, []string{"child_colb"}, sqlparser.Restrict), - ckInfo(nil, []string{"cola", "colx"}, []string{"child_cola", "child_colx"}, sqlparser.SetNull), - ckInfo(nil, []string{"colx", "coly"}, []string{"child_colx", "child_coly"}, sqlparser.Cascade), - ckInfo(nil, []string{"cold"}, []string{"child_cold"}, sqlparser.Restrict), - }, - SingleTableSet(1): { - ckInfo(nil, []string{"cold"}, []string{"child_cold"}, sqlparser.Restrict), - ckInfo(nil, []string{"colc", "colx"}, []string{"child_colc", "child_colx"}, sqlparser.SetNull), - ckInfo(nil, []string{"colx", "coly"}, []string{"child_colx", "child_coly"}, sqlparser.Cascade), - }, - }, - updExprs: updateExprs, - childFksWanted: map[TableSet][]vindexes.ChildFKInfo{ - SingleTableSet(0): { - ckInfo(nil, []string{"colb"}, []string{"child_colb"}, sqlparser.Restrict), - ckInfo(nil, []string{"cola", "colx"}, []string{"child_cola", "child_colx"}, sqlparser.SetNull), - }, - SingleTableSet(1): { - ckInfo(nil, []string{"cold"}, []string{"child_cold"}, sqlparser.Restrict), - ckInfo(nil, []string{"colc", "colx"}, []string{"child_colc", "child_colx"}, sqlparser.SetNull), - }, - }, - parentFksWanted: map[TableSet][]vindexes.ParentFKInfo{}, - }, { - name: "Parent Foreign Keys Filtering", - analyzer: a, - allParentFks: map[TableSet][]vindexes.ParentFKInfo{ - SingleTableSet(0): { - pkInfo(nil, []string{"pcola", "pcolx"}, []string{"cola", "colx"}), - pkInfo(nil, []string{"pcolc"}, []string{"colc"}), - pkInfo(nil, []string{"pcolb", "pcola"}, []string{"colb", "cola"}), - pkInfo(nil, []string{"pcolb"}, []string{"colb"}), - pkInfo(nil, []string{"pcola"}, []string{"cola"}), - pkInfo(nil, []string{"pcolb", "pcolx"}, []string{"colb", "colx"}), - }, - SingleTableSet(1): { - pkInfo(nil, []string{"pcolc", "pcolx"}, []string{"colc", "colx"}), - pkInfo(nil, []string{"pcola"}, []string{"cola"}), - pkInfo(nil, []string{"pcold", "pcolc"}, []string{"cold", "colc"}), - pkInfo(nil, []string{"pcold"}, []string{"cold"}), - pkInfo(nil, []string{"pcold", "pcolx"}, []string{"cold", "colx"}), - }, - }, - allChildFks: nil, - updExprs: updateExprs, - childFksWanted: map[TableSet][]vindexes.ChildFKInfo{}, - parentFksWanted: map[TableSet][]vindexes.ParentFKInfo{ - SingleTableSet(0): { - pkInfo(nil, []string{"pcola", "pcolx"}, []string{"cola", "colx"}), - pkInfo(nil, []string{"pcola"}, []string{"cola"}), - }, - SingleTableSet(1): { - pkInfo(nil, []string{"pcolc", "pcolx"}, []string{"colc", "colx"}), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - childFks, parentFks := tt.analyzer.filterForeignKeysUsingUpdateExpressions(tt.allChildFks, tt.allParentFks, tt.updExprs) - require.EqualValues(t, tt.childFksWanted, childFks) - require.EqualValues(t, tt.parentFksWanted, parentFks) - }) - } -} - -// TestGetInvolvedForeignKeys tests the functionality of getInvolvedForeignKeys. -func TestGetInvolvedForeignKeys(t *testing.T) { - cola := sqlparser.NewColName("cola") - colb := sqlparser.NewColName("colb") - colc := sqlparser.NewColName("colc") - cold := sqlparser.NewColName("cold") - tests := []struct { - name string - stmt sqlparser.Statement - analyzer *analyzer - childFksWanted map[TableSet][]vindexes.ChildFKInfo - parentFksWanted map[TableSet][]vindexes.ParentFKInfo - expectedErr string - }{ - { - name: "Delete Query", - stmt: &sqlparser.Delete{}, - analyzer: &analyzer{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t0"], - tbl["t1"], - }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - "ks_unmanaged": vschemapb.Keyspace_unmanaged, - }, - }, - }, - }, - childFksWanted: map[TableSet][]vindexes.ChildFKInfo{ - SingleTableSet(0): { - ckInfo(nil, []string{"col"}, []string{"col"}, sqlparser.Restrict), - ckInfo(nil, []string{"col1", "col2"}, []string{"ccol1", "ccol2"}, sqlparser.SetNull), - }, - }, - }, - { - name: "Update statement", - stmt: &sqlparser.Update{ - Exprs: sqlparser.UpdateExprs{ - &sqlparser.UpdateExpr{ - Name: cola, - Expr: sqlparser.NewIntLiteral("1"), - }, - &sqlparser.UpdateExpr{ - Name: colb, - Expr: &sqlparser.NullVal{}, - }, - &sqlparser.UpdateExpr{ - Name: colc, - Expr: sqlparser.NewIntLiteral("1"), - }, - &sqlparser.UpdateExpr{ - Name: cold, - Expr: &sqlparser.NullVal{}, - }, - }, - }, - analyzer: &analyzer{ - binder: &binder{ - direct: map[sqlparser.Expr]TableSet{ - cola: SingleTableSet(0), - colb: SingleTableSet(0), - colc: SingleTableSet(1), - cold: SingleTableSet(1), - }, - }, - tables: &tableCollector{ - Tables: []TableInfo{ - &RealTable{ - Table: &vindexes.Table{ - Keyspace: &vindexes.Keyspace{Name: "ks"}, - ChildForeignKeys: []vindexes.ChildFKInfo{ - ckInfo(nil, []string{"colb"}, []string{"child_colb"}, sqlparser.Restrict), - ckInfo(nil, []string{"cola", "colx"}, []string{"child_cola", "child_colx"}, sqlparser.SetNull), - ckInfo(nil, []string{"colx", "coly"}, []string{"child_colx", "child_coly"}, sqlparser.Cascade), - ckInfo(nil, []string{"cold"}, []string{"child_cold"}, sqlparser.Restrict), - }, - ParentForeignKeys: []vindexes.ParentFKInfo{ - pkInfo(nil, []string{"pcola", "pcolx"}, []string{"cola", "colx"}), - pkInfo(nil, []string{"pcolc"}, []string{"colc"}), - pkInfo(nil, []string{"pcolb", "pcola"}, []string{"colb", "cola"}), - pkInfo(nil, []string{"pcolb"}, []string{"colb"}), - pkInfo(nil, []string{"pcola"}, []string{"cola"}), - pkInfo(nil, []string{"pcolb", "pcolx"}, []string{"colb", "colx"}), - }, - }, - }, - &RealTable{ - Table: &vindexes.Table{ - Keyspace: &vindexes.Keyspace{Name: "ks"}, - ChildForeignKeys: []vindexes.ChildFKInfo{ - ckInfo(nil, []string{"cold"}, []string{"child_cold"}, sqlparser.Restrict), - ckInfo(nil, []string{"colc", "colx"}, []string{"child_colc", "child_colx"}, sqlparser.SetNull), - ckInfo(nil, []string{"colx", "coly"}, []string{"child_colx", "child_coly"}, sqlparser.Cascade), - }, - ParentForeignKeys: []vindexes.ParentFKInfo{ - pkInfo(nil, []string{"pcolc", "pcolx"}, []string{"colc", "colx"}), - pkInfo(nil, []string{"pcola"}, []string{"cola"}), - pkInfo(nil, []string{"pcold", "pcolc"}, []string{"cold", "colc"}), - pkInfo(nil, []string{"pcold"}, []string{"cold"}), - pkInfo(nil, []string{"pcold", "pcolx"}, []string{"cold", "colx"}), - }, - }, - }, - }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - }, - }, - }, - }, - childFksWanted: map[TableSet][]vindexes.ChildFKInfo{ - SingleTableSet(0): { - ckInfo(nil, []string{"colb"}, []string{"child_colb"}, sqlparser.Restrict), - ckInfo(nil, []string{"cola", "colx"}, []string{"child_cola", "child_colx"}, sqlparser.SetNull), - }, - SingleTableSet(1): { - ckInfo(nil, []string{"cold"}, []string{"child_cold"}, sqlparser.Restrict), - ckInfo(nil, []string{"colc", "colx"}, []string{"child_colc", "child_colx"}, sqlparser.SetNull), - }, - }, - parentFksWanted: map[TableSet][]vindexes.ParentFKInfo{ - SingleTableSet(0): { - pkInfo(nil, []string{"pcola", "pcolx"}, []string{"cola", "colx"}), - pkInfo(nil, []string{"pcola"}, []string{"cola"}), - }, - SingleTableSet(1): { - pkInfo(nil, []string{"pcolc", "pcolx"}, []string{"colc", "colx"}), - }, - }, - }, - { - name: "Replace Query", - stmt: &sqlparser.Insert{ - Action: sqlparser.ReplaceAct, - }, - analyzer: &analyzer{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t0"], - tbl["t1"], - }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - "ks_unmanaged": vschemapb.Keyspace_unmanaged, - }, - }, - }, - }, - childFksWanted: map[TableSet][]vindexes.ChildFKInfo{ - SingleTableSet(0): { - ckInfo(nil, []string{"col"}, []string{"col"}, sqlparser.Restrict), - ckInfo(nil, []string{"col1", "col2"}, []string{"ccol1", "ccol2"}, sqlparser.SetNull), - }, - }, - parentFksWanted: map[TableSet][]vindexes.ParentFKInfo{ - SingleTableSet(0): { - pkInfo(nil, []string{"colb"}, []string{"colb"}), - pkInfo(nil, []string{"colb1", "colb2"}, []string{"ccolb1", "ccolb2"}), - }, - }, - }, - { - name: "Insert Query", - stmt: &sqlparser.Insert{ - Action: sqlparser.InsertAct, - }, - analyzer: &analyzer{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t0"], - tbl["t1"], - }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - "ks_unmanaged": vschemapb.Keyspace_unmanaged, - }, - }, - }, - }, - childFksWanted: nil, - parentFksWanted: map[TableSet][]vindexes.ParentFKInfo{ - SingleTableSet(0): { - pkInfo(nil, []string{"colb"}, []string{"colb"}), - pkInfo(nil, []string{"colb1", "colb2"}, []string{"ccolb1", "ccolb2"}), - }, - }, - }, - { - name: "Insert Query with On Duplicate", - stmt: &sqlparser.Insert{ - Action: sqlparser.InsertAct, - OnDup: sqlparser.OnDup{ - &sqlparser.UpdateExpr{ - Name: cola, - Expr: sqlparser.NewIntLiteral("1"), - }, - &sqlparser.UpdateExpr{ - Name: colb, - Expr: &sqlparser.NullVal{}, - }, - }, - }, - analyzer: &analyzer{ - binder: &binder{ - direct: map[sqlparser.Expr]TableSet{ - cola: SingleTableSet(0), - colb: SingleTableSet(0), - }, - }, - tables: &tableCollector{ - Tables: []TableInfo{ - &RealTable{ - Table: &vindexes.Table{ - Keyspace: &vindexes.Keyspace{Name: "ks"}, - ChildForeignKeys: []vindexes.ChildFKInfo{ - ckInfo(nil, []string{"col"}, []string{"col"}, sqlparser.Restrict), - ckInfo(nil, []string{"col1", "col2"}, []string{"ccol1", "ccol2"}, sqlparser.SetNull), - ckInfo(nil, []string{"colb"}, []string{"child_colb"}, sqlparser.Restrict), - ckInfo(nil, []string{"cola", "colx"}, []string{"child_cola", "child_colx"}, sqlparser.SetNull), - ckInfo(nil, []string{"colx", "coly"}, []string{"child_colx", "child_coly"}, sqlparser.Cascade), - ckInfo(nil, []string{"cold"}, []string{"child_cold"}, sqlparser.Restrict), - }, - ParentForeignKeys: []vindexes.ParentFKInfo{ - pkInfo(nil, []string{"colb"}, []string{"colb"}), - pkInfo(nil, []string{"colb1", "colb2"}, []string{"ccolb1", "ccolb2"}), - }, - }, - }, - tbl["t1"], - }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - "ks_unmanaged": vschemapb.Keyspace_unmanaged, - }, - }, - }, - }, - childFksWanted: map[TableSet][]vindexes.ChildFKInfo{ - SingleTableSet(0): { - ckInfo(nil, []string{"colb"}, []string{"child_colb"}, sqlparser.Restrict), - ckInfo(nil, []string{"cola", "colx"}, []string{"child_cola", "child_colx"}, sqlparser.SetNull), - }, - }, - parentFksWanted: map[TableSet][]vindexes.ParentFKInfo{ - SingleTableSet(0): { - pkInfo(nil, []string{"colb"}, []string{"colb"}), - pkInfo(nil, []string{"colb1", "colb2"}, []string{"ccolb1", "ccolb2"}), - }, - }, - }, - { - name: "Insert error", - stmt: &sqlparser.Insert{}, - analyzer: &analyzer{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t2"], - tbl["t3"], - }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - }, - }, - }, - }, - expectedErr: "undefined_ks keyspace not found", - }, - { - name: "Update error", - stmt: &sqlparser.Update{}, - analyzer: &analyzer{ - tables: &tableCollector{ - Tables: []TableInfo{ - tbl["t2"], - tbl["t3"], - }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - }, - }, - }, - }, - expectedErr: "undefined_ks keyspace not found", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - childFks, parentFks, err := tt.analyzer.getInvolvedForeignKeys(tt.stmt) - if tt.expectedErr != "" { - require.EqualError(t, err, tt.expectedErr) - return - } - require.EqualValues(t, tt.childFksWanted, childFks) - require.EqualValues(t, tt.parentFksWanted, parentFks) - }) - } -} - -func ckInfo(cTable *vindexes.Table, pCols []string, cCols []string, refAction sqlparser.ReferenceAction) vindexes.ChildFKInfo { - return vindexes.ChildFKInfo{ - Table: cTable, - ParentColumns: sqlparser.MakeColumns(pCols...), - ChildColumns: sqlparser.MakeColumns(cCols...), - OnDelete: refAction, - } -} - -func pkInfo(parentTable *vindexes.Table, pCols []string, cCols []string) vindexes.ParentFKInfo { - return vindexes.ParentFKInfo{ - Table: parentTable, - ParentColumns: sqlparser.MakeColumns(pCols...), - ChildColumns: sqlparser.MakeColumns(cCols...), - } -} diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index d3c5f312426..0349ef3f79d 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -47,6 +47,8 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { handleOrderBy(r, cursor, node) case *sqlparser.OrExpr: rewriteOrExpr(cursor, node) + case *sqlparser.AndExpr: + rewriteAndExpr(cursor, node) case *sqlparser.NotExpr: rewriteNotExpr(cursor, node) case sqlparser.GroupBy: @@ -176,6 +178,45 @@ func rewriteOrExpr(cursor *sqlparser.Cursor, node *sqlparser.OrExpr) { } } +// rewriteAndExpr rewrites AND expressions when either side is TRUE. +func rewriteAndExpr(cursor *sqlparser.Cursor, node *sqlparser.AndExpr) { + newNode := rewriteAndTrue(*node) + if newNode != nil { + cursor.ReplaceAndRevisit(newNode) + } +} + +func rewriteAndTrue(andExpr sqlparser.AndExpr) sqlparser.Expr { + // we are looking for the pattern `WHERE c = 1 AND 1 = 1` + isTrue := func(subExpr sqlparser.Expr) bool { + evalEnginePred, err := evalengine.Translate(subExpr, nil) + if err != nil { + return false + } + + env := evalengine.EmptyExpressionEnv() + res, err := env.Evaluate(evalEnginePred) + if err != nil { + return false + } + + boolValue, err := res.Value(collations.Default()).ToBool() + if err != nil { + return false + } + + return boolValue + } + + if isTrue(andExpr.Left) { + return andExpr.Right + } else if isTrue(andExpr.Right) { + return andExpr.Left + } + + return nil +} + // handleLiteral processes literals within the context of ORDER BY expressions. func handleLiteral(r *earlyRewriter, cursor *sqlparser.Cursor, node *sqlparser.Literal) error { newNode, err := r.rewriteOrderByExpr(node) diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 94b1302b357..48ab4322bc8 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -137,6 +137,7 @@ type ( // The map is keyed by the tableset of the table that each of the foreign key belongs to. childForeignKeysInvolved map[TableSet][]vindexes.ChildFKInfo parentForeignKeysInvolved map[TableSet][]vindexes.ParentFKInfo + childFkToUpdExprs map[string]sqlparser.UpdateExprs } columnName struct { @@ -196,6 +197,11 @@ func (st *SemTable) GetParentForeignKeysList() []vindexes.ParentFKInfo { return parentFkInfos } +// GetUpdateExpressionsForFk gets the update expressions for the given serialized foreign key constraint. +func (st *SemTable) GetUpdateExpressionsForFk(foreignKey string) sqlparser.UpdateExprs { + return st.childFkToUpdExprs[foreignKey] +} + // RemoveParentForeignKey removes the given foreign key from the parent foreign keys that sem table stores. func (st *SemTable) RemoveParentForeignKey(fkToIgnore string) error { for ts, fkInfos := range st.parentForeignKeysInvolved { @@ -279,6 +285,89 @@ func (st *SemTable) RemoveNonRequiredForeignKeys(verifyAllFks bool, getAction fu return nil } +// ErrIfFkDependentColumnUpdated checks if a foreign key column that is being updated is dependent on another column which also being updated. +func (st *SemTable) ErrIfFkDependentColumnUpdated(updateExprs sqlparser.UpdateExprs) error { + // Go over all the update expressions + for _, updateExpr := range updateExprs { + deps := st.RecursiveDeps(updateExpr.Name) + if deps.NumberOfTables() != 1 { + panic("expected to have single table dependency") + } + // Get all the child and parent foreign keys for the given table that the update expression belongs to. + childFks := st.childForeignKeysInvolved[deps] + parentFKs := st.parentForeignKeysInvolved[deps] + + involvedInFk := false + // Check if this updated column is part of any child or parent foreign key. + for _, childFk := range childFks { + if childFk.ParentColumns.FindColumn(updateExpr.Name.Name) >= 0 { + involvedInFk = true + break + } + } + for _, parentFk := range parentFKs { + if parentFk.ChildColumns.FindColumn(updateExpr.Name.Name) >= 0 { + involvedInFk = true + break + } + } + + if !involvedInFk { + continue + } + + // We cannot support updating a foreign key column that is using a column which is also being updated for 2 reasons— + // 1. For the child foreign keys, we aren't sure what the final value of the updated foreign key column will be. So we don't know + // what to cascade to the child. The selection that we do isn't enough to know if the updated value, since one of the columns used in the update is also being updated. + // 2. For the parent foreign keys, we don't know if we need to reject this update. Because we don't know the final updated value, the update might need to be failed, + // but we can't say for certain. + var dependencyUpdatedErr error + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + col, ok := node.(*sqlparser.ColName) + if !ok { + return true, nil + } + // self reference column dependency is not considered a dependent column being updated. + if st.EqualsExpr(updateExpr.Name, col) { + return true, nil + } + for _, updExpr := range updateExprs { + if st.EqualsExpr(updExpr.Name, col) { + dependencyUpdatedErr = vterrors.VT12001(fmt.Sprintf("%v column referenced in foreign key column %v is itself updated", sqlparser.String(col), sqlparser.String(updateExpr.Name))) + return false, nil + } + } + return false, nil + }, updateExpr.Expr) + if dependencyUpdatedErr != nil { + return dependencyUpdatedErr + } + } + return nil +} + +// HasNonLiteralForeignKeyUpdate checks for non-literal updates in expressions linked to a foreign key. +func (st *SemTable) HasNonLiteralForeignKeyUpdate(updExprs sqlparser.UpdateExprs) bool { + for _, updateExpr := range updExprs { + if sqlparser.IsLiteral(updateExpr.Expr) { + continue + } + parentFks := st.parentForeignKeysInvolved[st.RecursiveDeps(updateExpr.Name)] + for _, parentFk := range parentFks { + if parentFk.ChildColumns.FindColumn(updateExpr.Name.Name) >= 0 { + return true + } + } + childFks := st.childForeignKeysInvolved[st.RecursiveDeps(updateExpr.Name)] + for _, childFk := range childFks { + if childFk.ParentColumns.FindColumn(updateExpr.Name.Name) >= 0 { + return true + } + } + } + return false +} + // isShardScoped checks if the foreign key constraint is shard-scoped or not. It uses the vindex information to make this call. func isShardScoped(pTable *vindexes.Table, cTable *vindexes.Table, pCols sqlparser.Columns, cCols sqlparser.Columns) bool { if !pTable.Keyspace.Sharded { diff --git a/go/vt/vtgate/semantics/semantic_state_test.go b/go/vt/vtgate/semantics/semantic_state_test.go index ab855322d76..b904f3656de 100644 --- a/go/vt/vtgate/semantics/semantic_state_test.go +++ b/go/vt/vtgate/semantics/semantic_state_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/require" querypb "vitess.io/vitess/go/vt/proto/query" + vschemapb "vitess.io/vitess/go/vt/proto/vschema" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -418,7 +419,7 @@ func TestRemoveParentForeignKey(t *testing.T) { }, }, }, - fkToIgnore: "ks.t2child_coldks.t3cold", + fkToIgnore: "ks.t2|child_cold||ks.t3|cold", parentFksWanted: []vindexes.ParentFKInfo{ pkInfo(t3Table, []string{"colb"}, []string{"child_colb"}), pkInfo(t3Table, []string{"cola", "colx"}, []string{"child_cola", "child_colx"}), @@ -748,3 +749,233 @@ func TestRemoveNonRequiredForeignKeys(t *testing.T) { }) } } + +func TestIsFkDependentColumnUpdated(t *testing.T) { + keyspaceName := "ks" + t3Table := &vindexes.Table{ + Keyspace: &vindexes.Keyspace{Name: keyspaceName}, + Name: sqlparser.NewIdentifierCS("t3"), + } + tests := []struct { + name string + query string + fakeSi *FakeSI + updatedErr string + }{ + { + name: "updated child foreign key column is dependent on another updated column", + query: "update t1 set col = id + 1, id = 6 where foo = 3", + fakeSi: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + keyspaceName: vschemapb.Keyspace_managed, + }, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: &vindexes.Keyspace{Name: keyspaceName}, + ChildForeignKeys: []vindexes.ChildFKInfo{ + ckInfo(t3Table, []string{"col"}, []string{"col"}, sqlparser.Cascade), + }, + }, + }, + }, + updatedErr: "VT12001: unsupported: id column referenced in foreign key column col is itself updated", + }, { + name: "updated parent foreign key column is dependent on another updated column", + query: "update t1 set col = id + 1, id = 6 where foo = 3", + fakeSi: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + keyspaceName: vschemapb.Keyspace_managed, + }, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: &vindexes.Keyspace{Name: keyspaceName}, + ParentForeignKeys: []vindexes.ParentFKInfo{ + pkInfo(t3Table, []string{"col"}, []string{"col"}), + }, + }, + }, + }, + updatedErr: "VT12001: unsupported: id column referenced in foreign key column col is itself updated", + }, { + name: "no foreign key column is dependent on a updated value", + query: "update t1 set col = id + 1 where foo = 3", + fakeSi: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + keyspaceName: vschemapb.Keyspace_managed, + }, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: &vindexes.Keyspace{Name: keyspaceName}, + ParentForeignKeys: []vindexes.ParentFKInfo{ + pkInfo(t3Table, []string{"col"}, []string{"col"}), + }, + }, + }, + }, + updatedErr: "", + }, { + name: "self-referenced foreign key", + query: "update t1 set col = col + 1 where foo = 3", + fakeSi: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + keyspaceName: vschemapb.Keyspace_managed, + }, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: &vindexes.Keyspace{Name: keyspaceName}, + ParentForeignKeys: []vindexes.ParentFKInfo{ + pkInfo(t3Table, []string{"col"}, []string{"col"}), + }, + }, + }, + }, + updatedErr: "", + }, { + name: "no foreign keys", + query: "update t1 set col = id + 1, id = 6 where foo = 3", + fakeSi: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + keyspaceName: vschemapb.Keyspace_managed, + }, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: &vindexes.Keyspace{Name: keyspaceName}, + }, + }, + }, + updatedErr: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmt, err := sqlparser.Parse(tt.query) + require.NoError(t, err) + semTable, err := Analyze(stmt, keyspaceName, tt.fakeSi) + require.NoError(t, err) + got := semTable.ErrIfFkDependentColumnUpdated(stmt.(*sqlparser.Update).Exprs) + if tt.updatedErr == "" { + require.NoError(t, got) + } else { + require.EqualError(t, got, tt.updatedErr) + } + }) + } +} + +func TestHasNonLiteralForeignKeyUpdate(t *testing.T) { + keyspaceName := "ks" + t3Table := &vindexes.Table{ + Keyspace: &vindexes.Keyspace{Name: keyspaceName}, + Name: sqlparser.NewIdentifierCS("t3"), + } + tests := []struct { + name string + query string + fakeSi *FakeSI + hasNonLiteral bool + }{ + { + name: "non literal child foreign key update", + query: "update t1 set col = id + 1 where foo = 3", + fakeSi: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + keyspaceName: vschemapb.Keyspace_managed, + }, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: &vindexes.Keyspace{Name: keyspaceName}, + ChildForeignKeys: []vindexes.ChildFKInfo{ + ckInfo(t3Table, []string{"col"}, []string{"col"}, sqlparser.Cascade), + }, + }, + }, + }, + hasNonLiteral: true, + }, { + name: "non literal parent foreign key update", + query: "update t1 set col = id + 1 where foo = 3", + fakeSi: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + keyspaceName: vschemapb.Keyspace_managed, + }, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: &vindexes.Keyspace{Name: keyspaceName}, + ParentForeignKeys: []vindexes.ParentFKInfo{ + pkInfo(t3Table, []string{"col"}, []string{"col"}), + }, + }, + }, + }, + hasNonLiteral: true, + }, { + name: "literal updates only", + query: "update t1 set col = 1 where foo = 3", + fakeSi: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + keyspaceName: vschemapb.Keyspace_managed, + }, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: &vindexes.Keyspace{Name: keyspaceName}, + ParentForeignKeys: []vindexes.ParentFKInfo{ + pkInfo(t3Table, []string{"col"}, []string{"col"}), + }, + }, + }, + }, + hasNonLiteral: false, + }, { + name: "self-referenced foreign key", + query: "update t1 set col = col + 1 where foo = 3", + fakeSi: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + keyspaceName: vschemapb.Keyspace_managed, + }, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: &vindexes.Keyspace{Name: keyspaceName}, + ParentForeignKeys: []vindexes.ParentFKInfo{ + pkInfo(t3Table, []string{"col"}, []string{"col"}), + }, + }, + }, + }, + hasNonLiteral: true, + }, { + name: "no foreign keys", + query: "update t1 set col = id + 1 where foo = 3", + fakeSi: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + keyspaceName: vschemapb.Keyspace_managed, + }, + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Keyspace: &vindexes.Keyspace{Name: keyspaceName}, + }, + }, + }, + hasNonLiteral: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stmt, err := sqlparser.Parse(tt.query) + require.NoError(t, err) + semTable, err := Analyze(stmt, keyspaceName, tt.fakeSi) + require.NoError(t, err) + got := semTable.HasNonLiteralForeignKeyUpdate(stmt.(*sqlparser.Update).Exprs) + require.EqualValues(t, tt.hasNonLiteral, got) + }) + } +} diff --git a/go/vt/vtgate/vindexes/foreign_keys.go b/go/vt/vtgate/vindexes/foreign_keys.go index db984462b25..74f9ce74844 100644 --- a/go/vt/vtgate/vindexes/foreign_keys.go +++ b/go/vt/vtgate/vindexes/foreign_keys.go @@ -46,13 +46,13 @@ func (fk *ParentFKInfo) MarshalJSON() ([]byte, error) { func (fk *ParentFKInfo) String(childTable *Table) string { var str strings.Builder - str.WriteString(childTable.String()) + str.WriteString(sqlparser.String(childTable.GetTableName())) for _, column := range fk.ChildColumns { - str.WriteString(column.String()) + str.WriteString("|" + sqlparser.String(column)) } - str.WriteString(fk.Table.String()) + str.WriteString("||" + sqlparser.String(fk.Table.GetTableName())) for _, column := range fk.ParentColumns { - str.WriteString(column.String()) + str.WriteString("|" + sqlparser.String(column)) } return str.String() } @@ -91,13 +91,13 @@ func (fk *ChildFKInfo) MarshalJSON() ([]byte, error) { func (fk *ChildFKInfo) String(parentTable *Table) string { var str strings.Builder - str.WriteString(fk.Table.String()) + str.WriteString(sqlparser.String(fk.Table.GetTableName())) for _, column := range fk.ChildColumns { - str.WriteString(column.String()) + str.WriteString("|" + sqlparser.String(column)) } - str.WriteString(parentTable.String()) + str.WriteString("||" + sqlparser.String(parentTable.GetTableName())) for _, column := range fk.ParentColumns { - str.WriteString(column.String()) + str.WriteString("|" + sqlparser.String(column)) } return str.String() }