Skip to content

Commit

Permalink
Merge pull request #99 from github/hm/backport-vtexplain-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
hmaurer authored Mar 21, 2024
2 parents ea0e19c + bab4912 commit a8bc7c0
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 16 deletions.
18 changes: 6 additions & 12 deletions go/vt/mysqlctl/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,6 @@ func (mysqld *Mysqld) executeSchemaCommands(sql string) error {
return mysqld.executeMysqlScript(params, strings.NewReader(sql))
}

func encodeEntityName(name string) string {
var buf strings.Builder
sqltypes.NewVarChar(name).EncodeSQL(&buf)
return buf.String()
}

// tableListSQL returns an IN clause "('t1', 't2'...) for a list of tables."
func tableListSQL(tables []string) (string, error) {
if len(tables) == 0 {
Expand All @@ -74,7 +68,7 @@ func tableListSQL(tables []string) (string, error) {

encodedTables := make([]string, len(tables))
for i, tableName := range tables {
encodedTables[i] = encodeEntityName(tableName)
encodedTables[i] = sqltypes.EncodeStringSQL(tableName)
}

return "(" + strings.Join(encodedTables, ", ") + ")", nil
Expand Down Expand Up @@ -304,9 +298,9 @@ func GetColumnsList(dbName, tableName string, exec func(string, int, bool) (*sql
if dbName == "" {
dbName2 = "database()"
} else {
dbName2 = encodeEntityName(dbName)
dbName2 = sqltypes.EncodeStringSQL(dbName)
}
query := fmt.Sprintf(GetColumnNamesQuery, dbName2, encodeEntityName(sqlescape.UnescapeID(tableName)))
query := fmt.Sprintf(GetColumnNamesQuery, dbName2, sqltypes.EncodeStringSQL(sqlescape.UnescapeID(tableName)))
qr, err := exec(query, -1, true)
if err != nil {
return "", err
Expand Down Expand Up @@ -393,7 +387,7 @@ func (mysqld *Mysqld) getPrimaryKeyColumns(ctx context.Context, dbName string, t
FROM information_schema.STATISTICS
WHERE TABLE_SCHEMA = %s AND TABLE_NAME IN %s AND LOWER(INDEX_NAME) = 'primary'
ORDER BY table_name, SEQ_IN_INDEX`
sql = fmt.Sprintf(sql, encodeEntityName(dbName), tableList)
sql = fmt.Sprintf(sql, sqltypes.EncodeStringSQL(dbName), tableList)
qr, err := conn.ExecuteFetch(sql, len(tables)*100, true)
if err != nil {
return nil, err
Expand Down Expand Up @@ -622,8 +616,8 @@ func (mysqld *Mysqld) GetPrimaryKeyEquivalentColumns(ctx context.Context, dbName
) AS pke ON index_cols.INDEX_NAME = pke.INDEX_NAME
WHERE index_cols.TABLE_SCHEMA = %s AND index_cols.TABLE_NAME = %s AND NON_UNIQUE = 0 AND NULLABLE != 'YES'
ORDER BY SEQ_IN_INDEX ASC`
encodedDbName := encodeEntityName(dbName)
encodedTable := encodeEntityName(table)
encodedDbName := sqltypes.EncodeStringSQL(dbName)
encodedTable := sqltypes.EncodeStringSQL(table)
sql = fmt.Sprintf(sql, encodedDbName, encodedTable, encodedDbName, encodedTable, encodedDbName, encodedTable)
qr, err := conn.ExecuteFetch(sql, 1000, true)
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions go/vt/vtexplain/vtexplain_vttablet.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,8 @@ func newTabletEnvironment(ddls []sqlparser.DDLStatement, opts *Options) (*tablet
}
tEnv.addResult(query, tEnv.getResult(likeQuery))

likeQuery = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqlescape.UnescapeID(likeTable))
query = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqlescape.UnescapeID(table))
likeQuery = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqltypes.EncodeStringSQL(likeTable))
query = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqltypes.EncodeStringSQL(table))
if tEnv.getResult(likeQuery) == nil {
return nil, fmt.Errorf("check your schema, table[%s] doesn't exist", likeTable)
}
Expand Down Expand Up @@ -477,7 +477,7 @@ func newTabletEnvironment(ddls []sqlparser.DDLStatement, opts *Options) (*tablet
tEnv.addResult("SELECT * FROM "+backtickedTable+" WHERE 1 != 1", &sqltypes.Result{
Fields: rowTypes,
})
query := fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqlescape.UnescapeID(table))
query := fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqltypes.EncodeStringSQL(table))
tEnv.addResult(query, &sqltypes.Result{
Fields: colTypes,
Rows: colValues,
Expand Down Expand Up @@ -558,7 +558,7 @@ func (t *explainTablet) handleSelect(query string) (*sqltypes.Result, error) {

// Gen4 supports more complex queries so we now need to
// handle multiple FROM clauses
tables := make([]*sqlparser.AliasedTableExpr, len(selStmt.From))
tables := make([]*sqlparser.AliasedTableExpr, 0, len(selStmt.From))
for _, from := range selStmt.From {
tables = append(tables, getTables(from)...)
}
Expand Down

0 comments on commit a8bc7c0

Please sign in to comment.