diff --git a/go/cmd/vtctldclient/command/root.go b/go/cmd/vtctldclient/command/root.go index 8a03eadc764..742944315f1 100644 --- a/go/cmd/vtctldclient/command/root.go +++ b/go/cmd/vtctldclient/command/root.go @@ -23,6 +23,7 @@ import ( "io" "strconv" "strings" + "sync" "time" "github.com/spf13/cobra" @@ -70,6 +71,11 @@ var ( // Register functions to be called when the command completes. onTerm = []func(){} + // Register our nil tmclient grpc handler only one time. + // This is primarily for tests where we execute the root + // command multiple times. + once = sync.Once{} + server string actionTimeout time.Duration compactOutput bool @@ -194,9 +200,12 @@ func getClientForCommand(cmd *cobra.Command) (vtctldclient.VtctldClient, error) onTerm = append(onTerm, ts.Close) // Use internal vtcltd server implementation. - // Register a nil grpc handler -- we will not use tmclient at all. - tmclient.RegisterTabletManagerClientFactory("grpc", func() tmclient.TabletManagerClient { - return nil + // Register a nil grpc handler -- we will not use tmclient at all but + // a factory still needs to be registered. + once.Do(func() { + tmclient.RegisterTabletManagerClientFactory("grpc", func() tmclient.TabletManagerClient { + return nil + }) }) vtctld := grpcvtctldserver.NewVtctldServer(ts) localvtctldclient.SetServer(vtctld) diff --git a/go/cmd/vtctldclient/command/root_test.go b/go/cmd/vtctldclient/command/root_test.go index 155fac78705..5efe844e1a1 100644 --- a/go/cmd/vtctldclient/command/root_test.go +++ b/go/cmd/vtctldclient/command/root_test.go @@ -17,13 +17,19 @@ limitations under the License. package command_test import ( + "context" + "fmt" "os" + "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/cmd/vtctldclient/command" + "vitess.io/vitess/go/vt/topo" + "vitess.io/vitess/go/vt/topo/memorytopo" "vitess.io/vitess/go/vt/vtctl/localvtctldclient" vtctlservicepb "vitess.io/vitess/go/vt/proto/vtctlservice" @@ -52,3 +58,64 @@ func TestRoot(t *testing.T) { assert.Contains(t, err.Error(), "unknown command") }) } + +// TestRootWithInternalVtctld tests that the internal VtctldServer +// implementation -- used with --server=internal -- works for +// commands as expected. +func TestRootWithInternalVtctld(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + cell := "zone1" + ts, factory := memorytopo.NewServerAndFactory(ctx, cell) + topo.RegisterFactory("test", factory) + command.VtctldClientProtocol = "local" + baseArgs := []string{"vtctldclient", "--server", "internal", "--topo-implementation", "test"} + + args := append([]string{}, os.Args...) + protocol := command.VtctldClientProtocol + t.Cleanup(func() { + ts.Close() + os.Args = append([]string{}, args...) + command.VtctldClientProtocol = protocol + }) + + testCases := []struct { + command string + args []string + expectErr string + }{ + { + command: "AddCellInfo", + args: []string{"--root", fmt.Sprintf("/vitess/%s", cell), "--server-address", "", cell}, + expectErr: "node already exists", // Cell already exists + }, + { + command: "GetTablets", + }, + { + command: "NoCommandDrJones", + expectErr: "unknown command", // Invalid command + }, + } + + for _, tc := range testCases { + t.Run(tc.command, func(t *testing.T) { + defer func() { + // Reset the OS args. + os.Args = append([]string{}, args...) + }() + + os.Args = append(baseArgs, tc.command) + os.Args = append(os.Args, tc.args...) + + err := command.Root.Execute() + if tc.expectErr != "" { + if !strings.Contains(err.Error(), tc.expectErr) { + t.Errorf(fmt.Sprintf("%s error = %v, expectErr = %v", tc.command, err, tc.expectErr)) + } + } else { + require.NoError(t, err, "unexpected error: %v", err) + } + }) + } +}