Skip to content

Commit

Permalink
fix: Unescape(Escape(str)) now returns the original string (#15009)
Browse files Browse the repository at this point in the history
Signed-off-by: Manik Rana <[email protected]>
Signed-off-by: Manik Rana <[email protected]>
Co-authored-by: Shlomi Noach <[email protected]>
  • Loading branch information
Maniktherana and shlomi-noach authored Jan 27, 2024
1 parent 44d6a6b commit f751c83
Show file tree
Hide file tree
Showing 8 changed files with 329 additions and 50 deletions.
65 changes: 60 additions & 5 deletions go/sqlescape/ids.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License.
package sqlescape

import (
"fmt"
"strings"
)

Expand Down Expand Up @@ -52,11 +53,65 @@ func EscapeIDs(identifiers []string) []string {
return result
}

// UnescapeID reverses any backticking in the input string.
func UnescapeID(in string) string {
// UnescapeID reverses any backticking in the input string by EscapeID.
func UnescapeID(in string) (string, error) {
l := len(in)
if l >= 2 && in[0] == '`' && in[l-1] == '`' {
return in[1 : l-1]

if l == 0 || in == "``" {
return "", fmt.Errorf("UnescapeID err: invalid input identifier '%s'", in)

}

if l == 1 {
if in[0] == '`' {
return "", fmt.Errorf("UnescapeID err: invalid input identifier '`'")
}
return in, nil
}

first, last := in[0], in[l-1]

if first == '`' && last != '`' {
return "", fmt.Errorf("UnescapeID err: unexpected single backtick at position %d in '%s'", 0, in)
}
if first != '`' && last == '`' {
return "", fmt.Errorf("UnescapeID err: unexpected single backtick at position %d in '%s'", l, in)
}
if first != '`' && last != '`' {
if idx := strings.IndexByte(in, '`'); idx != -1 {
return "", fmt.Errorf("UnescapeID err: no outer backticks found in the identifier '%s'", in)
}
return in, nil
}

in = in[1 : l-1]

if idx := strings.IndexByte(in, '`'); idx == -1 {
return in, nil
}

var buf strings.Builder
buf.Grow(len(in))

for i := 0; i < len(in); i++ {
buf.WriteByte(in[i])

if i < len(in)-1 && in[i] == '`' {
if in[i+1] == '`' {
i++ // halves the number of backticks
} else {
return "", fmt.Errorf("UnescapeID err: unexpected single backtick at position %d in '%s'", i, in)
}
}
}

return buf.String(), nil
}

func EnsureEscaped(in string) (string, error) {
out, err := UnescapeID(in)
if err != nil {
return "", err
}
return in
return EscapeID(out), nil
}
181 changes: 177 additions & 4 deletions go/sqlescape/ids_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ package sqlescape

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestEscapeID(t *testing.T) {
Expand All @@ -26,12 +29,182 @@ func TestEscapeID(t *testing.T) {
}, {
in: "a`a",
out: "`a``a`",
}, {
in: "`fo`o`",
out: "```fo``o```",
}, {
in: "",
out: "``",
}}
for _, tc := range testcases {
out := EscapeID(tc.in)
if out != tc.out {
t.Errorf("EscapeID(%s): %s, want %s", tc.in, out, tc.out)
}
t.Run(tc.in, func(t *testing.T) {
out := EscapeID(tc.in)
assert.Equal(t, out, tc.out)
})
}
}

func TestUnescapeID(t *testing.T) {
testcases := []struct {
in, out string
err bool
}{
{
in: "``",
out: "",
err: true,
},
{
in: "a",
out: "a",
err: false,
},
{
in: "`aa`",
out: "aa",
err: false,
},
{
in: "`a``a`",
out: "a`a",
err: false,
},
{
in: "`foo",
out: "",
err: true,
},
{
in: "foo`",
out: "",
err: true,
},
{
in: "`fo`o",
out: "",
err: true,
},
{
in: "`fo`o`",
out: "",
err: true,
},
{
in: "``fo``o``",
out: "",
err: true,
},
{
in: "```fo``o```",
out: "`fo`o`",
err: false,
},
{
in: "```fo`o```",
out: "",
err: true,
},
{
in: "foo",
out: "foo",
err: false,
},
{
in: "f`oo",
out: "",
err: true,
},
{
in: "",
out: "",
err: true,
},
{
in: "`",
out: "",
err: true,
},
}
for _, tc := range testcases {
t.Run(tc.in, func(t *testing.T) {
out, err := UnescapeID(tc.in)
if tc.err {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tc.out, out, "output mismatch")
}
})
}
}

func TestEnsureEscaped(t *testing.T) {
tt := []struct {
in string
out string
err bool
}{
{
in: "",
out: "",
err: true,
},
{
in: "foo",
out: "`foo`",
err: false,
},
{
in: "`foo`",
out: "`foo`",
err: false,
},
{
in: "```fo``o```",
out: "```fo``o```",
err: false,
},
{
in: "`fo``o`",
out: "`fo``o`",
err: false,
},
{
in: "f`oo",
out: "",
err: true,
},
{
in: "`fo`o",
out: "",
err: true,
},
{
in: "`foo",
out: "",
err: true,
},
{
in: "foo`",
out: "",
err: true,
},
{
in: "`fo`o`",
out: "",
err: true,
},
}
for _, tc := range tt {
t.Run(tc.in, func(t *testing.T) {
out, err := EnsureEscaped(tc.in)
if tc.err {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tc.out, out, "output mismatch")
}
})
}
}

Expand Down
17 changes: 14 additions & 3 deletions go/vt/mysqlctl/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,11 @@ func GetColumnsList(dbName, tableName string, exec func(string, int, bool) (*sql
} else {
dbName2 = encodeEntityName(dbName)
}
query := fmt.Sprintf(GetColumnNamesQuery, dbName2, encodeEntityName(sqlescape.UnescapeID(tableName)))
sanitizedTableName, err := sqlescape.UnescapeID(tableName)
if err != nil {
return "", err
}
query := fmt.Sprintf(GetColumnNamesQuery, dbName2, encodeEntityName(sanitizedTableName))
qr, err := exec(query, -1, true)
if err != nil {
return "", err
Expand Down Expand Up @@ -342,9 +346,16 @@ func GetColumns(dbName, table string, exec func(string, int, bool) (*sqltypes.Re
if selectColumns == "" {
selectColumns = "*"
}
tableSpec := sqlescape.EscapeID(sqlescape.UnescapeID(table))
tableSpec, err := sqlescape.EnsureEscaped(table)
if err != nil {
return nil, nil, err
}
if dbName != "" {
tableSpec = fmt.Sprintf("%s.%s", sqlescape.EscapeID(sqlescape.UnescapeID(dbName)), tableSpec)
dbName, err := sqlescape.EnsureEscaped(dbName)
if err != nil {
return nil, nil, err
}
tableSpec = fmt.Sprintf("%s.%s", dbName, tableSpec)
}
query := fmt.Sprintf(GetFieldsQuery, selectColumns, tableSpec)
qr, err := exec(query, 0, true)
Expand Down
38 changes: 25 additions & 13 deletions go/vt/vtctl/workflow/traffic_switcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,23 +487,29 @@ func (ts *trafficSwitcher) dropParticipatingTablesFromKeyspace(ctx context.Conte
func (ts *trafficSwitcher) removeSourceTables(ctx context.Context, removalType TableRemovalType) error {
err := ts.ForAllSources(func(source *MigrationSource) error {
for _, tableName := range ts.Tables() {
query := fmt.Sprintf("drop table %s.%s",
sqlescape.EscapeID(sqlescape.UnescapeID(source.GetPrimary().DbName())),
sqlescape.EscapeID(sqlescape.UnescapeID(tableName)))
primaryDbName, err := sqlescape.EnsureEscaped(source.GetPrimary().DbName())
if err != nil {
return err
}
tableNameEscaped, err := sqlescape.EnsureEscaped(tableName)
if err != nil {
return err
}

query := fmt.Sprintf("drop table %s.%s", primaryDbName, tableNameEscaped)
if removalType == DropTable {
ts.Logger().Infof("%s: Dropping table %s.%s\n",
source.GetPrimary().String(), source.GetPrimary().DbName(), tableName)
} else {
renameName := getRenameFileName(tableName)
renameName, err := sqlescape.EnsureEscaped(getRenameFileName(tableName))
if err != nil {
return err
}
ts.Logger().Infof("%s: Renaming table %s.%s to %s.%s\n",
source.GetPrimary().String(), source.GetPrimary().DbName(), tableName, source.GetPrimary().DbName(), renameName)
query = fmt.Sprintf("rename table %s.%s TO %s.%s",
sqlescape.EscapeID(sqlescape.UnescapeID(source.GetPrimary().DbName())),
sqlescape.EscapeID(sqlescape.UnescapeID(tableName)),
sqlescape.EscapeID(sqlescape.UnescapeID(source.GetPrimary().DbName())),
sqlescape.EscapeID(sqlescape.UnescapeID(renameName)))
query = fmt.Sprintf("rename table %s.%s TO %s.%s", primaryDbName, tableNameEscaped, primaryDbName, renameName)
}
_, err := ts.ws.tmc.ExecuteFetchAsDba(ctx, source.GetPrimary().Tablet, false, &tabletmanagerdatapb.ExecuteFetchAsDbaRequest{
_, err = ts.ws.tmc.ExecuteFetchAsDba(ctx, source.GetPrimary().Tablet, false, &tabletmanagerdatapb.ExecuteFetchAsDbaRequest{
Query: []byte(query),
MaxRows: 1,
ReloadSchema: true,
Expand Down Expand Up @@ -1065,9 +1071,15 @@ func (ts *trafficSwitcher) removeTargetTables(ctx context.Context) error {
err := ts.ForAllTargets(func(target *MigrationTarget) error {
log.Infof("ForAllTargets: %+v", target)
for _, tableName := range ts.Tables() {
query := fmt.Sprintf("drop table %s.%s",
sqlescape.EscapeID(sqlescape.UnescapeID(target.GetPrimary().DbName())),
sqlescape.EscapeID(sqlescape.UnescapeID(tableName)))
primaryDbName, err := sqlescape.EnsureEscaped(target.GetPrimary().DbName())
if err != nil {
return err
}
tableName, err := sqlescape.EnsureEscaped(tableName)
if err != nil {
return err
}
query := fmt.Sprintf("drop table %s.%s", primaryDbName, tableName)
ts.Logger().Infof("%s: Dropping table %s.%s\n",
target.GetPrimary().String(), target.GetPrimary().DbName(), tableName)
res, err := ts.ws.tmc.ExecuteFetchAsDba(ctx, target.GetPrimary().Tablet, false, &tabletmanagerdatapb.ExecuteFetchAsDbaRequest{
Expand Down
Loading

0 comments on commit f751c83

Please sign in to comment.