Skip to content

Commit

Permalink
Cherry-pick 9a78e7d with conflicts
Browse files Browse the repository at this point in the history
Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
vitess-bot[bot] authored and dbussink committed Feb 19, 2024
1 parent fb28fc0 commit 169a3b9
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 16 deletions.
18 changes: 6 additions & 12 deletions go/vt/mysqlctl/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,6 @@ func (mysqld *Mysqld) executeSchemaCommands(sql string) error {
return mysqld.executeMysqlScript(params, strings.NewReader(sql))
}

func encodeEntityName(name string) string {
var buf strings.Builder
sqltypes.NewVarChar(name).EncodeSQL(&buf)
return buf.String()
}

// tableListSQL returns an IN clause "('t1', 't2'...) for a list of tables."
func tableListSQL(tables []string) (string, error) {
if len(tables) == 0 {
Expand All @@ -74,7 +68,7 @@ func tableListSQL(tables []string) (string, error) {

encodedTables := make([]string, len(tables))
for i, tableName := range tables {
encodedTables[i] = encodeEntityName(tableName)
encodedTables[i] = sqltypes.EncodeStringSQL(tableName)
}

return "(" + strings.Join(encodedTables, ", ") + ")", nil
Expand Down Expand Up @@ -302,9 +296,9 @@ func GetColumnsList(dbName, tableName string, exec func(string, int, bool) (*sql
if dbName == "" {
dbName2 = "database()"
} else {
dbName2 = encodeEntityName(dbName)
dbName2 = sqltypes.EncodeStringSQL(dbName)
}
query := fmt.Sprintf(GetColumnNamesQuery, dbName2, encodeEntityName(sqlescape.UnescapeID(tableName)))
query := fmt.Sprintf(GetColumnNamesQuery, dbName2, sqltypes.EncodeStringSQL(sqlescape.UnescapeID(tableName)))
qr, err := exec(query, -1, true)
if err != nil {
return "", err
Expand Down Expand Up @@ -391,7 +385,7 @@ func (mysqld *Mysqld) getPrimaryKeyColumns(ctx context.Context, dbName string, t
FROM information_schema.STATISTICS
WHERE TABLE_SCHEMA = %s AND TABLE_NAME IN %s AND LOWER(INDEX_NAME) = 'primary'
ORDER BY table_name, SEQ_IN_INDEX`
sql = fmt.Sprintf(sql, encodeEntityName(dbName), tableList)
sql = fmt.Sprintf(sql, sqltypes.EncodeStringSQL(dbName), tableList)
qr, err := conn.ExecuteFetch(sql, len(tables)*100, true)
if err != nil {
return nil, err
Expand Down Expand Up @@ -620,8 +614,8 @@ func (mysqld *Mysqld) GetPrimaryKeyEquivalentColumns(ctx context.Context, dbName
) AS pke ON index_cols.INDEX_NAME = pke.INDEX_NAME
WHERE index_cols.TABLE_SCHEMA = %s AND index_cols.TABLE_NAME = %s AND NON_UNIQUE = 0 AND NULLABLE != 'YES'
ORDER BY SEQ_IN_INDEX ASC`
encodedDbName := encodeEntityName(dbName)
encodedTable := encodeEntityName(table)
encodedDbName := sqltypes.EncodeStringSQL(dbName)
encodedTable := sqltypes.EncodeStringSQL(table)
sql = fmt.Sprintf(sql, encodedDbName, encodedTable, encodedDbName, encodedTable, encodedDbName, encodedTable)
qr, err := conn.ExecuteFetch(sql, 1000, true)
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions go/vt/vtexplain/vtexplain_vttablet.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,8 +454,8 @@ func newTabletEnvironment(ddls []sqlparser.DDLStatement, opts *Options) (*tablet
}
tEnv.addResult(query, tEnv.getResult(likeQuery))

likeQuery = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqlescape.UnescapeID(likeTable))
query = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqlescape.UnescapeID(table))
likeQuery = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqltypes.EncodeStringSQL(sqlescape.UnescapeID(likeTable)))
query = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqltypes.EncodeStringSQL(sqlescape.UnescapeID(table)))
if tEnv.getResult(likeQuery) == nil {
return nil, fmt.Errorf("check your schema, table[%s] doesn't exist", likeTable)
}
Expand Down Expand Up @@ -496,7 +496,7 @@ func newTabletEnvironment(ddls []sqlparser.DDLStatement, opts *Options) (*tablet
tEnv.addResult("SELECT * FROM "+backtickedTable+" WHERE 1 != 1", &sqltypes.Result{
Fields: rowTypes,
})
query := fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqlescape.UnescapeID(table))
query := fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqltypes.EncodeStringSQL(sqlescape.UnescapeID(table)))
tEnv.addResult(query, &sqltypes.Result{
Fields: colTypes,
Rows: colValues,
Expand Down Expand Up @@ -598,7 +598,7 @@ func (t *explainTablet) handleSelect(query string) (*sqltypes.Result, error) {

// Gen4 supports more complex queries so we now need to
// handle multiple FROM clauses
tables := make([]*sqlparser.AliasedTableExpr, len(selStmt.From))
tables := make([]*sqlparser.AliasedTableExpr, 0, len(selStmt.From))
for _, from := range selStmt.From {
tables = append(tables, getTables(from)...)
}
Expand Down
4 changes: 4 additions & 0 deletions go/vt/vtexplain/vtexplain_vttablet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ create table t2 (
require.NoError(t, err)
defer vte.Stop()

// Check if the correct schema query is registered.
_, found := vte.globalTabletEnv.schemaQueries["SELECT COLUMN_NAME as column_name\n\t\tFROM INFORMATION_SCHEMA.COLUMNS\n\t\tWHERE TABLE_SCHEMA = database() AND TABLE_NAME = 't1'\n\t\tORDER BY ORDINAL_POSITION"]
assert.True(t, found)

sql := "SELECT * FROM t1 INNER JOIN t2 ON t1.id = t2.id"

_, err = vte.Run(sql)
Expand Down

0 comments on commit 169a3b9

Please sign in to comment.