Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly unescape keyspace name in FindAllShardsInKeyspace #15765

Merged
merged 2 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions go/vt/topo/keyspace.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"golang.org/x/sync/errgroup"

"vitess.io/vitess/go/constants/sidecar"
"vitess.io/vitess/go/sqlescape"
"vitess.io/vitess/go/vt/key"
"vitess.io/vitess/go/vt/servenv"
"vitess.io/vitess/go/vt/vterrors"
Expand Down Expand Up @@ -213,6 +214,14 @@ func (ts *Server) FindAllShardsInKeyspace(ctx context.Context, keyspace string,
opt.Concurrency = DefaultConcurrency
}

// Unescape the keyspace name as this can e.g. come from the VSchema where
// a keyspace/database name will need to be SQL escaped if it has special
// characters such as a dash.
keyspace, err := sqlescape.UnescapeID(keyspace)
if err != nil {
return nil, vterrors.Wrapf(err, "FindAllShardsInKeyspace(%s) invalid keyspace name", keyspace)
}

// First try to get all shards using List if we can.
buildResultFromList := func(kvpairs []KVInfo) (map[string]*ShardInfo, error) {
result := make(map[string]*ShardInfo, len(kvpairs))
Expand Down
28 changes: 23 additions & 5 deletions go/vt/topo/keyspace_external_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqlescape"
"vitess.io/vitess/go/vt/key"
"vitess.io/vitess/go/vt/topo"
"vitess.io/vitess/go/vt/topo/memorytopo"
Expand All @@ -32,10 +33,12 @@ import (
)

func TestServerFindAllShardsInKeyspace(t *testing.T) {
const defaultKeyspace = "keyspace"
tests := []struct {
name string
shards int
opt *topo.FindAllShardsInKeyspaceOptions
name string
shards int
keyspace string // If you want to override the default
opt *topo.FindAllShardsInKeyspaceOptions
}{
{
name: "negative concurrency",
Expand All @@ -54,9 +57,25 @@ func TestServerFindAllShardsInKeyspace(t *testing.T) {
shards: 32,
opt: &topo.FindAllShardsInKeyspaceOptions{Concurrency: 8},
},
{
name: "SQL escaped keyspace",
shards: 32,
keyspace: "`my-keyspace`",
opt: &topo.FindAllShardsInKeyspaceOptions{Concurrency: 8},
},
}

for _, tt := range tests {
keyspace := defaultKeyspace
if tt.keyspace != "" {
// Most calls such as CreateKeyspace will not accept invalid characters
// in the value so we'll only use the original test case value in
// FindAllShardsInKeyspace. This allows us to test and confirm that
// FindAllShardsInKeyspace can handle SQL escaped or backtick'd names.
keyspace, _ = sqlescape.UnescapeID(tt.keyspace)
} else {
tt.keyspace = defaultKeyspace
}
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand All @@ -66,7 +85,6 @@ func TestServerFindAllShardsInKeyspace(t *testing.T) {

// Create an ephemeral keyspace and generate shard records within
// the keyspace to fetch later.
const keyspace = "keyspace"
require.NoError(t, ts.CreateKeyspace(ctx, keyspace, &topodatapb.Keyspace{}))

shards, err := key.GenerateShardRanges(tt.shards)
Expand All @@ -78,7 +96,7 @@ func TestServerFindAllShardsInKeyspace(t *testing.T) {

// Verify that we return a complete list of shards and that each
// key range is present in the output.
out, err := ts.FindAllShardsInKeyspace(ctx, keyspace, tt.opt)
out, err := ts.FindAllShardsInKeyspace(ctx, tt.keyspace, tt.opt)
require.NoError(t, err)
require.Len(t, out, tt.shards)

Expand Down
7 changes: 4 additions & 3 deletions go/vt/vtctl/workflow/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,10 @@ func TestVDiffCreate(t *testing.T) {
wantErr string
}{
{
name: "no values",
req: &vtctldatapb.VDiffCreateRequest{},
wantErr: "FindAllShardsInKeyspace(): List: node doesn't exist: keyspaces/shards", // We did not provide any keyspace or shard
name: "no values",
req: &vtctldatapb.VDiffCreateRequest{},
// We did not provide any keyspace or shard.
wantErr: "FindAllShardsInKeyspace() invalid keyspace name: UnescapeID err: invalid input identifier ''",
},
}
for _, tt := range tests {
Expand Down
Loading