diff --git a/go/vt/vtorc/logic/tablet_discovery.go b/go/vt/vtorc/logic/tablet_discovery.go index 7066229ab06..eb10bb2a667 100644 --- a/go/vt/vtorc/logic/tablet_discovery.go +++ b/go/vt/vtorc/logic/tablet_discovery.go @@ -27,19 +27,19 @@ import ( "time" "github.com/spf13/pflag" + "golang.org/x/sync/errgroup" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" "vitess.io/vitess/go/vt/external/golib/sqlutils" "vitess.io/vitess/go/vt/log" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/topo/topoproto" "vitess.io/vitess/go/vt/vtorc/config" "vitess.io/vitess/go/vt/vtorc/db" "vitess.io/vitess/go/vt/vtorc/inst" "vitess.io/vitess/go/vt/vttablet/tmclient" - - topodatapb "vitess.io/vitess/go/vt/proto/topodata" ) var ( @@ -48,6 +48,9 @@ var ( clustersToWatch []string shutdownWaitTime = 30 * time.Second shardsLockCounter int32 + shardsToWatch map[string][]string + shardsToWatchMu sync.Mutex + // ErrNoPrimaryTablet is a fixed error message. ErrNoPrimaryTablet = errors.New("no primary tablet found") ) @@ -58,6 +61,52 @@ func RegisterFlags(fs *pflag.FlagSet) { fs.DurationVar(&shutdownWaitTime, "shutdown_wait_time", shutdownWaitTime, "Maximum time to wait for VTOrc to release all the locks that it is holding before shutting down on SIGTERM") } +// updateShardsToWatch parses the --clusters_to_watch flag-value +// into a map of keyspace/shards. +func updateShardsToWatch() { + if len(clustersToWatch) == 0 { + return + } + + newShardsToWatch := make(map[string][]string, 0) + for _, ks := range clustersToWatch { + if strings.Contains(ks, "/") && !strings.HasSuffix(ks, "/") { + // Validate keyspace/shard parses. + k, s, err := topoproto.ParseKeyspaceShard(ks) + if err != nil { + log.Errorf("Could not parse keyspace/shard %q: %+v", ks, err) + continue + } + newShardsToWatch[k] = append(newShardsToWatch[k], s) + } else { + ctx, cancel := context.WithTimeout(context.Background(), topo.RemoteOperationTimeout) + defer cancel() + // Assume this is a keyspace and find all shards in keyspace. + // Remove trailing slash if exists. + ks = strings.TrimSuffix(ks, "/") + shards, err := ts.GetShardNames(ctx, ks) + if err != nil { + // Log the err and continue. + log.Errorf("Error fetching shards for keyspace: %v", ks) + continue + } + if len(shards) == 0 { + log.Errorf("Topo has no shards for ks: %v", ks) + continue + } + newShardsToWatch[ks] = shards + } + } + if len(newShardsToWatch) == 0 { + log.Error("No keyspace/shards to watch") + return + } + + shardsToWatchMu.Lock() + defer shardsToWatchMu.Unlock() + shardsToWatch = newShardsToWatch +} + // OpenTabletDiscovery opens the vitess topo if enables and returns a ticker // channel for polling. func OpenTabletDiscovery() <-chan time.Time { @@ -67,6 +116,8 @@ func OpenTabletDiscovery() <-chan time.Time { if _, err := db.ExecVTOrc("DELETE FROM vitess_tablet"); err != nil { log.Error(err) } + // Parse --clusters_to_watch into a filter. + updateShardsToWatch() // We refresh all information from the topo once before we start the ticks to do // it on a timer. ctx, cancel := context.WithTimeout(context.Background(), topo.RemoteOperationTimeout) @@ -77,6 +128,28 @@ func OpenTabletDiscovery() <-chan time.Time { return time.Tick(config.GetTopoInformationRefreshDuration()) //nolint SA1015: using time.Tick leaks the underlying ticker } +// getAllTablets gets all tablets from all cells using a goroutine per cell. +func getAllTablets(ctx context.Context, cells []string) []*topo.TabletInfo { + var tabletsMu sync.Mutex + tablets := make([]*topo.TabletInfo, 0) + eg, ctx := errgroup.WithContext(ctx) + for _, cell := range cells { + eg.Go(func() error { + t, err := ts.GetTabletsByCell(ctx, cell, nil) + if err != nil { + log.Errorf("Failed to load tablets from cell %s: %+v", cell, err) + return nil + } + tabletsMu.Lock() + defer tabletsMu.Unlock() + tablets = append(tablets, t...) + return nil + }) + } + _ = eg.Wait() // always nil + return tablets +} + // refreshAllTablets reloads the tablets from topo and discovers the ones which haven't been refreshed in a while func refreshAllTablets(ctx context.Context) error { return refreshTabletsUsing(ctx, func(tabletAlias string) { @@ -84,81 +157,45 @@ func refreshAllTablets(ctx context.Context) error { }, false /* forceRefresh */) } +// refreshTabletsUsing refreshes tablets using a provided loader. func refreshTabletsUsing(ctx context.Context, loader func(tabletAlias string), forceRefresh bool) error { - if len(clustersToWatch) == 0 { // all known clusters - ctx, cancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout) - defer cancel() - cells, err := ts.GetKnownCells(ctx) - if err != nil { - return err - } + // Get all cells. + ctx, cancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout) + defer cancel() + cells, err := ts.GetKnownCells(ctx) + if err != nil { + return err + } - refreshCtx, refreshCancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout) - defer refreshCancel() - var wg sync.WaitGroup - for _, cell := range cells { - wg.Add(1) - go func(cell string) { - defer wg.Done() - refreshTabletsInCell(refreshCtx, cell, loader, forceRefresh) - }(cell) - } - wg.Wait() - } else { - // Parse input and build list of keyspaces / shards - var keyspaceShards []*topo.KeyspaceShard - for _, ks := range clustersToWatch { - if strings.Contains(ks, "/") { - // This is a keyspace/shard specification - input := strings.Split(ks, "/") - keyspaceShards = append(keyspaceShards, &topo.KeyspaceShard{Keyspace: input[0], Shard: input[1]}) - } else { - // Assume this is a keyspace and find all shards in keyspace - ctx, cancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout) - defer cancel() - shards, err := ts.GetShardNames(ctx, ks) - if err != nil { - // Log the errr and continue - log.Errorf("Error fetching shards for keyspace: %v", ks) - continue - } - if len(shards) == 0 { - log.Errorf("Topo has no shards for ks: %v", ks) - continue - } - for _, s := range shards { - keyspaceShards = append(keyspaceShards, &topo.KeyspaceShard{Keyspace: ks, Shard: s}) + // Get all tablets from all cells. + getTabletsCtx, getTabletsCancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout) + defer getTabletsCancel() + tablets := getAllTablets(getTabletsCtx, cells) + if len(tablets) == 0 { + log.Error("Found no tablets") + return nil + } + + // Filter tablets that should not be watched using shardsToWatch map. + matchedTablets := make([]*topo.TabletInfo, 0, len(tablets)) + func() { + shardsToWatchMu.Lock() + defer shardsToWatchMu.Unlock() + for _, t := range tablets { + if len(shardsToWatch) > 0 { + _, ok := shardsToWatch[t.Tablet.Keyspace] + if !ok || !slices.Contains(shardsToWatch[t.Tablet.Keyspace], t.Tablet.Shard) { + continue // filter } } + matchedTablets = append(matchedTablets, t) } - if len(keyspaceShards) == 0 { - log.Errorf("Found no keyspaceShards for input: %+v", clustersToWatch) - return nil - } - refreshCtx, refreshCancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout) - defer refreshCancel() - var wg sync.WaitGroup - for _, ks := range keyspaceShards { - wg.Add(1) - go func(ks *topo.KeyspaceShard) { - defer wg.Done() - refreshTabletsInKeyspaceShard(refreshCtx, ks.Keyspace, ks.Shard, loader, forceRefresh, nil) - }(ks) - } - wg.Wait() - } - return nil -} + }() -func refreshTabletsInCell(ctx context.Context, cell string, loader func(tabletAlias string), forceRefresh bool) { - tablets, err := ts.GetTabletsByCell(ctx, cell, nil) - if err != nil { - log.Errorf("Error fetching topo info for cell %v: %v", cell, err) - return - } - query := "select alias from vitess_tablet where cell = ?" - args := sqlutils.Args(cell) - refreshTablets(tablets, query, args, loader, forceRefresh, nil) + // Refresh the filtered tablets. + query := "select alias from vitess_tablet" + refreshTablets(matchedTablets, query, nil, loader, forceRefresh, nil) + return nil } // forceRefreshAllTabletsInShard is used to refresh all the tablet's information (both MySQL information and topo records) diff --git a/go/vt/vtorc/logic/tablet_discovery_test.go b/go/vt/vtorc/logic/tablet_discovery_test.go index f6a7af38382..54284e8a017 100644 --- a/go/vt/vtorc/logic/tablet_discovery_test.go +++ b/go/vt/vtorc/logic/tablet_discovery_test.go @@ -19,6 +19,7 @@ package logic import ( "context" "fmt" + "strings" "sync/atomic" "testing" "time" @@ -101,6 +102,76 @@ var ( } ) +func TestUpdateShardsToWatch(t *testing.T) { + oldClustersToWatch := clustersToWatch + oldTs := ts + defer func() { + clustersToWatch = oldClustersToWatch + shardsToWatch = nil + ts = oldTs + }() + + // Create a memory topo-server and create the keyspace and shard records + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ts = memorytopo.NewServer(ctx, cell1) + _, err := ts.GetOrCreateShard(context.Background(), keyspace, shard) + require.NoError(t, err) + + testCases := []struct { + in []string + expected map[string][]string + }{ + { + in: []string{}, + expected: nil, + }, + { + in: []string{""}, + expected: map[string][]string{}, + }, + { + in: []string{"test/-"}, + expected: map[string][]string{ + "test": {"-"}, + }, + }, + { + in: []string{"test/-", "test2/-80", "test2/80-"}, + expected: map[string][]string{ + "test": {"-"}, + "test2": {"-80", "80-"}, + }, + }, + { + // confirm shards fetch from topo + in: []string{keyspace}, + expected: map[string][]string{ + keyspace: {shard}, + }, + }, + { + // confirm shards fetch from topo when keyspace has trailing-slash + in: []string{keyspace + "/"}, + expected: map[string][]string{ + keyspace: {shard}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(strings.Join(testCase.in, ","), func(t *testing.T) { + defer func() { + shardsToWatch = make(map[string][]string, 0) + }() + clustersToWatch = testCase.in + updateShardsToWatch() + require.Equal(t, testCase.expected, shardsToWatch) + }) + } +} + func TestRefreshTabletsInKeyspaceShard(t *testing.T) { // Store the old flags and restore on test completion oldTs := ts diff --git a/go/vt/vtorc/logic/vtorc.go b/go/vt/vtorc/logic/vtorc.go index 39326525ce2..c6a4a4ee46d 100644 --- a/go/vt/vtorc/logic/vtorc.go +++ b/go/vt/vtorc/logic/vtorc.go @@ -326,6 +326,12 @@ func refreshAllInformation(ctx context.Context) error { return RefreshAllKeyspacesAndShards(ctx) }) + // Refresh shards to watch. + eg.Go(func() error { + updateShardsToWatch() + return nil + }) + // Refresh all tablets. eg.Go(func() error { return refreshAllTablets(ctx)