diff --git a/go/vt/topo/keyspace.go b/go/vt/topo/keyspace.go index 844b4eb4454..dced769ca78 100755 --- a/go/vt/topo/keyspace.go +++ b/go/vt/topo/keyspace.go @@ -26,6 +26,7 @@ import ( "golang.org/x/sync/errgroup" "vitess.io/vitess/go/constants/sidecar" + "vitess.io/vitess/go/sqlescape" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vterrors" @@ -213,6 +214,14 @@ func (ts *Server) FindAllShardsInKeyspace(ctx context.Context, keyspace string, opt.Concurrency = DefaultConcurrency } + // Unescape the keyspace name as this can e.g. come from the VSchema where + // a keyspace/database name will need to be SQL escaped if it has special + // characters such as a dash. + keyspace, err := sqlescape.UnescapeID(keyspace) + if err != nil { + return nil, vterrors.Wrapf(err, "FindAllShardsInKeyspace(%s) invalid keyspace name", keyspace) + } + // First try to get all shards using List if we can. buildResultFromList := func(kvpairs []KVInfo) (map[string]*ShardInfo, error) { result := make(map[string]*ShardInfo, len(kvpairs)) diff --git a/go/vt/topo/keyspace_external_test.go b/go/vt/topo/keyspace_external_test.go index 4edb45a411d..bfcb2f591a9 100644 --- a/go/vt/topo/keyspace_external_test.go +++ b/go/vt/topo/keyspace_external_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/require" + "vitess.io/vitess/go/sqlescape" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/topo/memorytopo" @@ -32,10 +33,12 @@ import ( ) func TestServerFindAllShardsInKeyspace(t *testing.T) { + const defaultKeyspace = "keyspace" tests := []struct { - name string - shards int - opt *topo.FindAllShardsInKeyspaceOptions + name string + shards int + keyspace string // If you want to override the default + opt *topo.FindAllShardsInKeyspaceOptions }{ { name: "negative concurrency", @@ -54,9 +57,25 @@ func TestServerFindAllShardsInKeyspace(t *testing.T) { shards: 32, opt: &topo.FindAllShardsInKeyspaceOptions{Concurrency: 8}, }, + { + name: "SQL escaped keyspace", + shards: 32, + keyspace: "`my-keyspace`", + opt: &topo.FindAllShardsInKeyspaceOptions{Concurrency: 8}, + }, } for _, tt := range tests { + keyspace := defaultKeyspace + if tt.keyspace != "" { + // Most calls such as CreateKeyspace will not accept invalid characters + // in the value so we'll only use the original test case value in + // FindAllShardsInKeyspace. This allows us to test and confirm that + // FindAllShardsInKeyspace can handle SQL escaped or backtick'd names. + keyspace, _ = sqlescape.UnescapeID(tt.keyspace) + } else { + tt.keyspace = defaultKeyspace + } t.Run(tt.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -66,7 +85,6 @@ func TestServerFindAllShardsInKeyspace(t *testing.T) { // Create an ephemeral keyspace and generate shard records within // the keyspace to fetch later. - const keyspace = "keyspace" require.NoError(t, ts.CreateKeyspace(ctx, keyspace, &topodatapb.Keyspace{})) shards, err := key.GenerateShardRanges(tt.shards) @@ -78,7 +96,7 @@ func TestServerFindAllShardsInKeyspace(t *testing.T) { // Verify that we return a complete list of shards and that each // key range is present in the output. - out, err := ts.FindAllShardsInKeyspace(ctx, keyspace, tt.opt) + out, err := ts.FindAllShardsInKeyspace(ctx, tt.keyspace, tt.opt) require.NoError(t, err) require.Len(t, out, tt.shards) diff --git a/go/vt/vtctl/workflow/server_test.go b/go/vt/vtctl/workflow/server_test.go index 6bb3993fe1c..174cc2aaf6a 100644 --- a/go/vt/vtctl/workflow/server_test.go +++ b/go/vt/vtctl/workflow/server_test.go @@ -180,9 +180,10 @@ func TestVDiffCreate(t *testing.T) { wantErr string }{ { - name: "no values", - req: &vtctldatapb.VDiffCreateRequest{}, - wantErr: "FindAllShardsInKeyspace(): List: node doesn't exist: keyspaces/shards", // We did not provide any keyspace or shard + name: "no values", + req: &vtctldatapb.VDiffCreateRequest{}, + // We did not provide any keyspace or shard. + wantErr: "FindAllShardsInKeyspace() invalid keyspace name: UnescapeID err: invalid input identifier ''", }, } for _, tt := range tests {