Skip to content

Commit

Permalink
Properly unescape keyspace name in FindAllShardsInKeyspace
Browse files Browse the repository at this point in the history
This is needed because the input to this function comes from
e.g. the vschema, where a keyspace name that contains special
characters such as a dash needs to be escaped with backticks.

Signed-off-by: Matt Lord <[email protected]>
  • Loading branch information
mattlord committed Apr 21, 2024
1 parent 14b36d0 commit 407c841
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
9 changes: 9 additions & 0 deletions go/vt/topo/keyspace.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"golang.org/x/sync/errgroup"

"vitess.io/vitess/go/constants/sidecar"
"vitess.io/vitess/go/sqlescape"
"vitess.io/vitess/go/vt/key"
"vitess.io/vitess/go/vt/servenv"
"vitess.io/vitess/go/vt/vterrors"
Expand Down Expand Up @@ -213,6 +214,14 @@ func (ts *Server) FindAllShardsInKeyspace(ctx context.Context, keyspace string,
opt.Concurrency = DefaultConcurrency
}

// Unescape the keyspace name as this can e.g. come from the VSchema where
// a keyspace/database name will need to be SQL escaped if it has special
// characters such as a dash.
keyspace, err := sqlescape.UnescapeID(keyspace)
if err != nil {
return nil, vterrors.Wrapf(err, "FindAllShardsInKeyspace(%s) invalid keyspace name: %v", keyspace, err)
}

// First try to get all shards using List if we can.
buildResultFromList := func(kvpairs []KVInfo) (map[string]*ShardInfo, error) {
result := make(map[string]*ShardInfo, len(kvpairs))
Expand Down
28 changes: 23 additions & 5 deletions go/vt/topo/keyspace_external_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqlescape"
"vitess.io/vitess/go/vt/key"
"vitess.io/vitess/go/vt/topo"
"vitess.io/vitess/go/vt/topo/memorytopo"
Expand All @@ -32,10 +33,12 @@ import (
)

func TestServerFindAllShardsInKeyspace(t *testing.T) {
const defaultKeyspace = "keyspace"
tests := []struct {
name string
shards int
opt *topo.FindAllShardsInKeyspaceOptions
name string
shards int
keyspace string // If you want to override the default
opt *topo.FindAllShardsInKeyspaceOptions
}{
{
name: "negative concurrency",
Expand All @@ -54,9 +57,25 @@ func TestServerFindAllShardsInKeyspace(t *testing.T) {
shards: 32,
opt: &topo.FindAllShardsInKeyspaceOptions{Concurrency: 8},
},
{
name: "backtick'd keyspace",
shards: 32,
keyspace: "`my-keyspace`",
opt: &topo.FindAllShardsInKeyspaceOptions{Concurrency: 8},
},
}

for _, tt := range tests {
keyspace := defaultKeyspace
if tt.keyspace != "" {
// Most calls such as CreateKeyspace will not accept invalid characters
// in the value so we'll only use the original test case value in
// FindAllShardsInKeyspace. This allows us to test and confirm that
// FindAllShardsInKeyspace can handle SQL escaped or backtick'd names.
keyspace, _ = sqlescape.UnescapeID(tt.keyspace)
} else {
tt.keyspace = defaultKeyspace
}
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand All @@ -66,7 +85,6 @@ func TestServerFindAllShardsInKeyspace(t *testing.T) {

// Create an ephemeral keyspace and generate shard records within
// the keyspace to fetch later.
const keyspace = "keyspace"
require.NoError(t, ts.CreateKeyspace(ctx, keyspace, &topodatapb.Keyspace{}))

shards, err := key.GenerateShardRanges(tt.shards)
Expand All @@ -78,7 +96,7 @@ func TestServerFindAllShardsInKeyspace(t *testing.T) {

// Verify that we return a complete list of shards and that each
// key range is present in the output.
out, err := ts.FindAllShardsInKeyspace(ctx, keyspace, tt.opt)
out, err := ts.FindAllShardsInKeyspace(ctx, tt.keyspace, tt.opt)
require.NoError(t, err)
require.Len(t, out, tt.shards)

Expand Down

0 comments on commit 407c841

Please sign in to comment.