From f751c8323ff52c90f288481b0bd92192f1734973 Mon Sep 17 00:00:00 2001 From: Manik Rana Date: Sat, 27 Jan 2024 19:07:50 +0530 Subject: [PATCH] fix: `Unescape(Escape(str))` now returns the original string (#15009) Signed-off-by: Manik Rana Signed-off-by: Manik Rana Co-authored-by: Shlomi Noach <2607934+shlomi-noach@users.noreply.github.com> --- go/sqlescape/ids.go | 65 ++++++- go/sqlescape/ids_test.go | 181 +++++++++++++++++- go/vt/mysqlctl/schema.go | 17 +- go/vt/vtctl/workflow/traffic_switcher.go | 38 ++-- go/vt/vtexplain/vtexplain_vttablet.go | 18 +- go/vt/vtgate/vindexes/vschema.go | 14 +- .../tabletserver/vstreamer/rowstreamer.go | 7 +- go/vt/wrangler/traffic_switcher.go | 39 ++-- 8 files changed, 329 insertions(+), 50 deletions(-) diff --git a/go/sqlescape/ids.go b/go/sqlescape/ids.go index 3983db13362..a70d48c1cd2 100644 --- a/go/sqlescape/ids.go +++ b/go/sqlescape/ids.go @@ -14,6 +14,7 @@ limitations under the License. package sqlescape import ( + "fmt" "strings" ) @@ -52,11 +53,65 @@ func EscapeIDs(identifiers []string) []string { return result } -// UnescapeID reverses any backticking in the input string. -func UnescapeID(in string) string { +// UnescapeID reverses any backticking in the input string by EscapeID. +func UnescapeID(in string) (string, error) { l := len(in) - if l >= 2 && in[0] == '`' && in[l-1] == '`' { - return in[1 : l-1] + + if l == 0 || in == "``" { + return "", fmt.Errorf("UnescapeID err: invalid input identifier '%s'", in) + + } + + if l == 1 { + if in[0] == '`' { + return "", fmt.Errorf("UnescapeID err: invalid input identifier '`'") + } + return in, nil + } + + first, last := in[0], in[l-1] + + if first == '`' && last != '`' { + return "", fmt.Errorf("UnescapeID err: unexpected single backtick at position %d in '%s'", 0, in) + } + if first != '`' && last == '`' { + return "", fmt.Errorf("UnescapeID err: unexpected single backtick at position %d in '%s'", l, in) + } + if first != '`' && last != '`' { + if idx := strings.IndexByte(in, '`'); idx != -1 { + return "", fmt.Errorf("UnescapeID err: no outer backticks found in the identifier '%s'", in) + } + return in, nil + } + + in = in[1 : l-1] + + if idx := strings.IndexByte(in, '`'); idx == -1 { + return in, nil + } + + var buf strings.Builder + buf.Grow(len(in)) + + for i := 0; i < len(in); i++ { + buf.WriteByte(in[i]) + + if i < len(in)-1 && in[i] == '`' { + if in[i+1] == '`' { + i++ // halves the number of backticks + } else { + return "", fmt.Errorf("UnescapeID err: unexpected single backtick at position %d in '%s'", i, in) + } + } + } + + return buf.String(), nil +} + +func EnsureEscaped(in string) (string, error) { + out, err := UnescapeID(in) + if err != nil { + return "", err } - return in + return EscapeID(out), nil } diff --git a/go/sqlescape/ids_test.go b/go/sqlescape/ids_test.go index a2d2e69be6f..37b14206416 100644 --- a/go/sqlescape/ids_test.go +++ b/go/sqlescape/ids_test.go @@ -15,6 +15,9 @@ package sqlescape import ( "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestEscapeID(t *testing.T) { @@ -26,12 +29,182 @@ func TestEscapeID(t *testing.T) { }, { in: "a`a", out: "`a``a`", + }, { + in: "`fo`o`", + out: "```fo``o```", + }, { + in: "", + out: "``", }} for _, tc := range testcases { - out := EscapeID(tc.in) - if out != tc.out { - t.Errorf("EscapeID(%s): %s, want %s", tc.in, out, tc.out) - } + t.Run(tc.in, func(t *testing.T) { + out := EscapeID(tc.in) + assert.Equal(t, out, tc.out) + }) + } +} + +func TestUnescapeID(t *testing.T) { + testcases := []struct { + in, out string + err bool + }{ + { + in: "``", + out: "", + err: true, + }, + { + in: "a", + out: "a", + err: false, + }, + { + in: "`aa`", + out: "aa", + err: false, + }, + { + in: "`a``a`", + out: "a`a", + err: false, + }, + { + in: "`foo", + out: "", + err: true, + }, + { + in: "foo`", + out: "", + err: true, + }, + { + in: "`fo`o", + out: "", + err: true, + }, + { + in: "`fo`o`", + out: "", + err: true, + }, + { + in: "``fo``o``", + out: "", + err: true, + }, + { + in: "```fo``o```", + out: "`fo`o`", + err: false, + }, + { + in: "```fo`o```", + out: "", + err: true, + }, + { + in: "foo", + out: "foo", + err: false, + }, + { + in: "f`oo", + out: "", + err: true, + }, + { + in: "", + out: "", + err: true, + }, + { + in: "`", + out: "", + err: true, + }, + } + for _, tc := range testcases { + t.Run(tc.in, func(t *testing.T) { + out, err := UnescapeID(tc.in) + if tc.err { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tc.out, out, "output mismatch") + } + }) + } +} + +func TestEnsureEscaped(t *testing.T) { + tt := []struct { + in string + out string + err bool + }{ + { + in: "", + out: "", + err: true, + }, + { + in: "foo", + out: "`foo`", + err: false, + }, + { + in: "`foo`", + out: "`foo`", + err: false, + }, + { + in: "```fo``o```", + out: "```fo``o```", + err: false, + }, + { + in: "`fo``o`", + out: "`fo``o`", + err: false, + }, + { + in: "f`oo", + out: "", + err: true, + }, + { + in: "`fo`o", + out: "", + err: true, + }, + { + in: "`foo", + out: "", + err: true, + }, + { + in: "foo`", + out: "", + err: true, + }, + { + in: "`fo`o`", + out: "", + err: true, + }, + } + for _, tc := range tt { + t.Run(tc.in, func(t *testing.T) { + out, err := EnsureEscaped(tc.in) + if tc.err { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tc.out, out, "output mismatch") + } + }) } } diff --git a/go/vt/mysqlctl/schema.go b/go/vt/mysqlctl/schema.go index f3325827ab9..c7ca98c4917 100644 --- a/go/vt/mysqlctl/schema.go +++ b/go/vt/mysqlctl/schema.go @@ -309,7 +309,11 @@ func GetColumnsList(dbName, tableName string, exec func(string, int, bool) (*sql } else { dbName2 = encodeEntityName(dbName) } - query := fmt.Sprintf(GetColumnNamesQuery, dbName2, encodeEntityName(sqlescape.UnescapeID(tableName))) + sanitizedTableName, err := sqlescape.UnescapeID(tableName) + if err != nil { + return "", err + } + query := fmt.Sprintf(GetColumnNamesQuery, dbName2, encodeEntityName(sanitizedTableName)) qr, err := exec(query, -1, true) if err != nil { return "", err @@ -342,9 +346,16 @@ func GetColumns(dbName, table string, exec func(string, int, bool) (*sqltypes.Re if selectColumns == "" { selectColumns = "*" } - tableSpec := sqlescape.EscapeID(sqlescape.UnescapeID(table)) + tableSpec, err := sqlescape.EnsureEscaped(table) + if err != nil { + return nil, nil, err + } if dbName != "" { - tableSpec = fmt.Sprintf("%s.%s", sqlescape.EscapeID(sqlescape.UnescapeID(dbName)), tableSpec) + dbName, err := sqlescape.EnsureEscaped(dbName) + if err != nil { + return nil, nil, err + } + tableSpec = fmt.Sprintf("%s.%s", dbName, tableSpec) } query := fmt.Sprintf(GetFieldsQuery, selectColumns, tableSpec) qr, err := exec(query, 0, true) diff --git a/go/vt/vtctl/workflow/traffic_switcher.go b/go/vt/vtctl/workflow/traffic_switcher.go index 47a29100cd5..890ffd098a0 100644 --- a/go/vt/vtctl/workflow/traffic_switcher.go +++ b/go/vt/vtctl/workflow/traffic_switcher.go @@ -487,23 +487,29 @@ func (ts *trafficSwitcher) dropParticipatingTablesFromKeyspace(ctx context.Conte func (ts *trafficSwitcher) removeSourceTables(ctx context.Context, removalType TableRemovalType) error { err := ts.ForAllSources(func(source *MigrationSource) error { for _, tableName := range ts.Tables() { - query := fmt.Sprintf("drop table %s.%s", - sqlescape.EscapeID(sqlescape.UnescapeID(source.GetPrimary().DbName())), - sqlescape.EscapeID(sqlescape.UnescapeID(tableName))) + primaryDbName, err := sqlescape.EnsureEscaped(source.GetPrimary().DbName()) + if err != nil { + return err + } + tableNameEscaped, err := sqlescape.EnsureEscaped(tableName) + if err != nil { + return err + } + + query := fmt.Sprintf("drop table %s.%s", primaryDbName, tableNameEscaped) if removalType == DropTable { ts.Logger().Infof("%s: Dropping table %s.%s\n", source.GetPrimary().String(), source.GetPrimary().DbName(), tableName) } else { - renameName := getRenameFileName(tableName) + renameName, err := sqlescape.EnsureEscaped(getRenameFileName(tableName)) + if err != nil { + return err + } ts.Logger().Infof("%s: Renaming table %s.%s to %s.%s\n", source.GetPrimary().String(), source.GetPrimary().DbName(), tableName, source.GetPrimary().DbName(), renameName) - query = fmt.Sprintf("rename table %s.%s TO %s.%s", - sqlescape.EscapeID(sqlescape.UnescapeID(source.GetPrimary().DbName())), - sqlescape.EscapeID(sqlescape.UnescapeID(tableName)), - sqlescape.EscapeID(sqlescape.UnescapeID(source.GetPrimary().DbName())), - sqlescape.EscapeID(sqlescape.UnescapeID(renameName))) + query = fmt.Sprintf("rename table %s.%s TO %s.%s", primaryDbName, tableNameEscaped, primaryDbName, renameName) } - _, err := ts.ws.tmc.ExecuteFetchAsDba(ctx, source.GetPrimary().Tablet, false, &tabletmanagerdatapb.ExecuteFetchAsDbaRequest{ + _, err = ts.ws.tmc.ExecuteFetchAsDba(ctx, source.GetPrimary().Tablet, false, &tabletmanagerdatapb.ExecuteFetchAsDbaRequest{ Query: []byte(query), MaxRows: 1, ReloadSchema: true, @@ -1065,9 +1071,15 @@ func (ts *trafficSwitcher) removeTargetTables(ctx context.Context) error { err := ts.ForAllTargets(func(target *MigrationTarget) error { log.Infof("ForAllTargets: %+v", target) for _, tableName := range ts.Tables() { - query := fmt.Sprintf("drop table %s.%s", - sqlescape.EscapeID(sqlescape.UnescapeID(target.GetPrimary().DbName())), - sqlescape.EscapeID(sqlescape.UnescapeID(tableName))) + primaryDbName, err := sqlescape.EnsureEscaped(target.GetPrimary().DbName()) + if err != nil { + return err + } + tableName, err := sqlescape.EnsureEscaped(tableName) + if err != nil { + return err + } + query := fmt.Sprintf("drop table %s.%s", primaryDbName, tableName) ts.Logger().Infof("%s: Dropping table %s.%s\n", target.GetPrimary().String(), target.GetPrimary().DbName(), tableName) res, err := ts.ws.tmc.ExecuteFetchAsDba(ctx, target.GetPrimary().Tablet, false, &tabletmanagerdatapb.ExecuteFetchAsDbaRequest{ diff --git a/go/vt/vtexplain/vtexplain_vttablet.go b/go/vt/vtexplain/vtexplain_vttablet.go index a227488925d..b04365a3d0a 100644 --- a/go/vt/vtexplain/vtexplain_vttablet.go +++ b/go/vt/vtexplain/vtexplain_vttablet.go @@ -454,10 +454,18 @@ func newTabletEnvironment(ddls []sqlparser.DDLStatement, opts *Options, collatio indexRows := make([][]sqltypes.Value, 0, 4) for _, ddl := range ddls { table := sqlparser.String(ddl.GetTable().Name) - backtickedTable := sqlescape.EscapeID(sqlescape.UnescapeID(table)) + sanitizedTable, err := sqlescape.UnescapeID(table) + if err != nil { + return nil, err + } + backtickedTable := sqlescape.EscapeID(sanitizedTable) if ddl.GetOptLike() != nil { likeTable := ddl.GetOptLike().LikeTable.Name.String() - backtickedLikeTable := sqlescape.EscapeID(sqlescape.UnescapeID(likeTable)) + sanitizedLikeTable, err := sqlescape.UnescapeID(likeTable) + if err != nil { + return nil, err + } + backtickedLikeTable := sqlescape.EscapeID(sanitizedLikeTable) likeQuery := "SELECT * FROM " + backtickedLikeTable + " WHERE 1 != 1" query := "SELECT * FROM " + backtickedTable + " WHERE 1 != 1" @@ -466,8 +474,8 @@ func newTabletEnvironment(ddls []sqlparser.DDLStatement, opts *Options, collatio } tEnv.addResult(query, tEnv.getResult(likeQuery)) - likeQuery = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqlescape.UnescapeID(likeTable)) - query = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqlescape.UnescapeID(table)) + likeQuery = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sanitizedLikeTable) + query = fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sanitizedTable) if tEnv.getResult(likeQuery) == nil { return nil, fmt.Errorf("check your schema, table[%s] doesn't exist", likeTable) } @@ -508,7 +516,7 @@ func newTabletEnvironment(ddls []sqlparser.DDLStatement, opts *Options, collatio tEnv.addResult("SELECT * FROM "+backtickedTable+" WHERE 1 != 1", &sqltypes.Result{ Fields: rowTypes, }) - query := fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sqlescape.UnescapeID(table)) + query := fmt.Sprintf(mysqlctl.GetColumnNamesQuery, "database()", sanitizedTable) tEnv.addResult(query, &sqltypes.Result{ Fields: colTypes, Rows: colValues, diff --git a/go/vt/vtgate/vindexes/vschema.go b/go/vt/vtgate/vindexes/vschema.go index c20f5561566..e1044d0136d 100644 --- a/go/vt/vtgate/vindexes/vschema.go +++ b/go/vt/vtgate/vindexes/vschema.go @@ -854,10 +854,16 @@ func escapeQualifiedTable(qualifiedTableName string) (string, error) { if err != nil { return "", err } - return fmt.Sprintf("%s.%s", - // unescape() first in case an already escaped string was passed - sqlescape.EscapeID(sqlescape.UnescapeID(keyspace)), - sqlescape.EscapeID(sqlescape.UnescapeID(tableName))), nil + // unescape() first in case an already escaped string was passed + keyspace, err = sqlescape.EnsureEscaped(keyspace) + if err != nil { + return "", err + } + tableName, err = sqlescape.EnsureEscaped(tableName) + if err != nil { + return "", err + } + return fmt.Sprintf("%s.%s", keyspace, tableName), nil } func extractTableParts(tableName string, allowUnqualified bool) (string, string, error) { diff --git a/go/vt/vttablet/tabletserver/vstreamer/rowstreamer.go b/go/vt/vttablet/tabletserver/vstreamer/rowstreamer.go index 712248d7470..c1685c61d13 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/rowstreamer.go +++ b/go/vt/vttablet/tabletserver/vstreamer/rowstreamer.go @@ -273,8 +273,11 @@ func (rs *rowStreamer) buildSelect(st *binlogdatapb.MinimalTable) (string, error // of the PK columns which are used in the ORDER BY clause below. var indexHint string if st.PKIndexName != "" { - indexHint = fmt.Sprintf(" force index (%s)", - sqlescape.EscapeID(sqlescape.UnescapeID(st.PKIndexName))) + escapedPKIndexName, err := sqlescape.EnsureEscaped(st.PKIndexName) + if err != nil { + return "", err + } + indexHint = fmt.Sprintf(" force index (%s)", escapedPKIndexName) } buf.Myprintf(" from %v%s", sqlparser.NewIdentifierCS(rs.plan.Table.Name), indexHint) if len(rs.lastpk) != 0 { diff --git a/go/vt/wrangler/traffic_switcher.go b/go/vt/wrangler/traffic_switcher.go index e25b94f99e2..d2204c09cab 100644 --- a/go/vt/wrangler/traffic_switcher.go +++ b/go/vt/wrangler/traffic_switcher.go @@ -1769,23 +1769,28 @@ func getRenameFileName(tableName string) string { func (ts *trafficSwitcher) removeSourceTables(ctx context.Context, removalType workflow.TableRemovalType) error { err := ts.ForAllSources(func(source *workflow.MigrationSource) error { for _, tableName := range ts.Tables() { - query := fmt.Sprintf("drop table %s.%s", - sqlescape.EscapeID(sqlescape.UnescapeID(source.GetPrimary().DbName())), - sqlescape.EscapeID(sqlescape.UnescapeID(tableName))) + primaryDbName, err := sqlescape.EnsureEscaped(source.GetPrimary().DbName()) + if err != nil { + return err + } + tableNameEscaped, err := sqlescape.EnsureEscaped(tableName) + if err != nil { + return err + } + query := fmt.Sprintf("drop table %s.%s", primaryDbName, tableNameEscaped) if removalType == workflow.DropTable { ts.Logger().Infof("%s: Dropping table %s.%s\n", source.GetPrimary().String(), source.GetPrimary().DbName(), tableName) } else { - renameName := getRenameFileName(tableName) + renameName, err := sqlescape.EnsureEscaped(getRenameFileName(tableName)) + if err != nil { + return err + } ts.Logger().Infof("%s: Renaming table %s.%s to %s.%s\n", source.GetPrimary().String(), source.GetPrimary().DbName(), tableName, source.GetPrimary().DbName(), renameName) - query = fmt.Sprintf("rename table %s.%s TO %s.%s", - sqlescape.EscapeID(sqlescape.UnescapeID(source.GetPrimary().DbName())), - sqlescape.EscapeID(sqlescape.UnescapeID(tableName)), - sqlescape.EscapeID(sqlescape.UnescapeID(source.GetPrimary().DbName())), - sqlescape.EscapeID(sqlescape.UnescapeID(renameName))) + query = fmt.Sprintf("rename table %s.%s TO %s.%s", primaryDbName, tableNameEscaped, primaryDbName, renameName) } - _, err := ts.wr.ExecuteFetchAsDba(ctx, source.GetPrimary().Alias, query, 1, false, true) + _, err = ts.wr.ExecuteFetchAsDba(ctx, source.GetPrimary().Alias, query, 1, false, true) if err != nil { ts.Logger().Errorf("%s: Error removing table %s: %v", source.GetPrimary().String(), tableName, err) return err @@ -1880,12 +1885,18 @@ func (ts *trafficSwitcher) removeTargetTables(ctx context.Context) error { log.Infof("removeTargetTables") err := ts.ForAllTargets(func(target *workflow.MigrationTarget) error { for _, tableName := range ts.Tables() { - query := fmt.Sprintf("drop table %s.%s", - sqlescape.EscapeID(sqlescape.UnescapeID(target.GetPrimary().DbName())), - sqlescape.EscapeID(sqlescape.UnescapeID(tableName))) + primaryDbName, err := sqlescape.EnsureEscaped(target.GetPrimary().DbName()) + if err != nil { + return err + } + tableName, err := sqlescape.EnsureEscaped(tableName) + if err != nil { + return err + } + query := fmt.Sprintf("drop table %s.%s", primaryDbName, tableName) ts.Logger().Infof("%s: Dropping table %s.%s\n", target.GetPrimary().String(), target.GetPrimary().DbName(), tableName) - _, err := ts.wr.ExecuteFetchAsDba(ctx, target.GetPrimary().Alias, query, 1, false, true) + _, err = ts.wr.ExecuteFetchAsDba(ctx, target.GetPrimary().Alias, query, 1, false, true) if err != nil { ts.Logger().Errorf("%s: Error removing table %s: %v", target.GetPrimary().String(), tableName, err)