diff --git a/go/vt/vtorc/logic/keyspace_shard_discovery.go b/go/vt/vtorc/logic/keyspace_shard_discovery.go index b1e93fe2a01..0dd17cb65fd 100644 --- a/go/vt/vtorc/logic/keyspace_shard_discovery.go +++ b/go/vt/vtorc/logic/keyspace_shard_discovery.go @@ -29,17 +29,16 @@ import ( ) // RefreshAllKeyspacesAndShards reloads the keyspace and shard information for the keyspaces that vtorc is concerned with. -func RefreshAllKeyspacesAndShards() { +func RefreshAllKeyspacesAndShards(ctx context.Context) error { var keyspaces []string if len(clustersToWatch) == 0 { // all known keyspaces - ctx, cancel := context.WithTimeout(context.Background(), topo.RemoteOperationTimeout) + ctx, cancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout) defer cancel() var err error // Get all the keyspaces keyspaces, err = ts.GetKeyspaces(ctx) if err != nil { - log.Error(err) - return + return err } } else { // Parse input and build list of keyspaces @@ -55,14 +54,14 @@ func RefreshAllKeyspacesAndShards() { } if len(keyspaces) == 0 { log.Errorf("Found no keyspaces for input: %+v", clustersToWatch) - return + return nil } } // Sort the list of keyspaces. // The list can have duplicates because the input to clusters to watch may have multiple shards of the same keyspace sort.Strings(keyspaces) - refreshCtx, refreshCancel := context.WithTimeout(context.Background(), topo.RemoteOperationTimeout) + refreshCtx, refreshCancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout) defer refreshCancel() var wg sync.WaitGroup for idx, keyspace := range keyspaces { @@ -83,6 +82,8 @@ func RefreshAllKeyspacesAndShards() { }(keyspace) } wg.Wait() + + return nil } // RefreshKeyspaceAndShard refreshes the keyspace record and shard record for the given keyspace and shard. diff --git a/go/vt/vtorc/logic/keyspace_shard_discovery_test.go b/go/vt/vtorc/logic/keyspace_shard_discovery_test.go index 2911b3d29c2..ecf59998417 100644 --- a/go/vt/vtorc/logic/keyspace_shard_discovery_test.go +++ b/go/vt/vtorc/logic/keyspace_shard_discovery_test.go @@ -93,7 +93,7 @@ func TestRefreshAllKeyspaces(t *testing.T) { // Set clusters to watch to only watch ks1 and ks3 onlyKs1and3 := []string{"ks1/-80", "ks3/-80", "ks3/80-"} clustersToWatch = onlyKs1and3 - RefreshAllKeyspacesAndShards() + require.NoError(t, RefreshAllKeyspacesAndShards(context.Background())) // Verify that we only have ks1 and ks3 in vtorc's db. verifyKeyspaceInfo(t, "ks1", keyspaceDurabilityNone, "") @@ -108,7 +108,7 @@ func TestRefreshAllKeyspaces(t *testing.T) { clustersToWatch = nil // Change the durability policy of ks1 reparenttestutil.SetKeyspaceDurability(ctx, t, ts, "ks1", "semi_sync") - RefreshAllKeyspacesAndShards() + require.NoError(t, RefreshAllKeyspacesAndShards(context.Background())) // Verify that all the keyspaces are correctly reloaded verifyKeyspaceInfo(t, "ks1", keyspaceDurabilitySemiSync, "") diff --git a/go/vt/vtorc/logic/tablet_discovery.go b/go/vt/vtorc/logic/tablet_discovery.go index 6914ebd546d..3bb56c8cb51 100644 --- a/go/vt/vtorc/logic/tablet_discovery.go +++ b/go/vt/vtorc/logic/tablet_discovery.go @@ -27,7 +27,6 @@ import ( "time" "github.com/spf13/pflag" - "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" @@ -71,30 +70,36 @@ func OpenTabletDiscovery() <-chan time.Time { if _, err := db.ExecVTOrc("delete from vitess_tablet"); err != nil { log.Error(err) } + // We refresh all information from the topo once before we start the ticks to do + // it on a timer. + ctx, cancel := context.WithTimeout(context.Background(), topo.RemoteOperationTimeout) + defer cancel() + if err := refreshAllInformation(ctx); err != nil { + log.Errorf("failed to initialize topo information: %+v", err) + } return time.Tick(time.Second * time.Duration(config.Config.TopoInformationRefreshSeconds)) //nolint SA1015: using time.Tick leaks the underlying ticker } // refreshAllTablets reloads the tablets from topo and discovers the ones which haven't been refreshed in a while -func refreshAllTablets() { - refreshTabletsUsing(func(tabletAlias string) { +func refreshAllTablets(ctx context.Context) error { + return refreshTabletsUsing(ctx, func(tabletAlias string) { DiscoverInstance(tabletAlias, false /* forceDiscovery */) }, false /* forceRefresh */) } -func refreshTabletsUsing(loader func(tabletAlias string), forceRefresh bool) { +func refreshTabletsUsing(ctx context.Context, loader func(tabletAlias string), forceRefresh bool) error { if !IsLeaderOrActive() { - return + return nil } if len(clustersToWatch) == 0 { // all known clusters - ctx, cancel := context.WithTimeout(context.Background(), topo.RemoteOperationTimeout) + ctx, cancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout) defer cancel() cells, err := ts.GetKnownCells(ctx) if err != nil { - log.Error(err) - return + return err } - refreshCtx, refreshCancel := context.WithTimeout(context.Background(), topo.RemoteOperationTimeout) + refreshCtx, refreshCancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout) defer refreshCancel() var wg sync.WaitGroup for _, cell := range cells { @@ -115,7 +120,7 @@ func refreshTabletsUsing(loader func(tabletAlias string), forceRefresh bool) { keyspaceShards = append(keyspaceShards, &topo.KeyspaceShard{Keyspace: input[0], Shard: input[1]}) } else { // Assume this is a keyspace and find all shards in keyspace - ctx, cancel := context.WithTimeout(context.Background(), topo.RemoteOperationTimeout) + ctx, cancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout) defer cancel() shards, err := ts.GetShardNames(ctx, ks) if err != nil { @@ -134,9 +139,9 @@ func refreshTabletsUsing(loader func(tabletAlias string), forceRefresh bool) { } if len(keyspaceShards) == 0 { log.Errorf("Found no keyspaceShards for input: %+v", clustersToWatch) - return + return nil } - refreshCtx, refreshCancel := context.WithTimeout(context.Background(), topo.RemoteOperationTimeout) + refreshCtx, refreshCancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout) defer refreshCancel() var wg sync.WaitGroup for _, ks := range keyspaceShards { @@ -148,6 +153,7 @@ func refreshTabletsUsing(loader func(tabletAlias string), forceRefresh bool) { } wg.Wait() } + return nil } func refreshTabletsInCell(ctx context.Context, cell string, loader func(tabletAlias string), forceRefresh bool) { diff --git a/go/vt/vtorc/logic/vtorc.go b/go/vt/vtorc/logic/vtorc.go index f637956fbfd..4115de3c7b3 100644 --- a/go/vt/vtorc/logic/vtorc.go +++ b/go/vt/vtorc/logic/vtorc.go @@ -17,6 +17,7 @@ package logic import ( + "context" "os" "os/signal" "sync" @@ -27,6 +28,7 @@ import ( "github.com/patrickmn/go-cache" "github.com/rcrowley/go-metrics" "github.com/sjmudd/stopwatch" + "golang.org/x/sync/errgroup" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/servenv" @@ -415,27 +417,34 @@ func ContinuousDiscovery() { } }() case <-tabletTopoTick: - // Create a wait group - var wg sync.WaitGroup + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(config.Config.TopoInformationRefreshSeconds)) + if err := refreshAllInformation(ctx); err != nil { + log.Errorf("failed to refresh topo information: %+v", err) + } + cancel() + } + } +} - // Refresh all keyspace information. - wg.Add(1) - go func() { - defer wg.Done() - RefreshAllKeyspacesAndShards() - }() +// refreshAllInformation refreshes both shard and tablet information. This is meant to be run on tablet topo ticks. +func refreshAllInformation(ctx context.Context) error { + // Create an errgroup + eg, ctx := errgroup.WithContext(ctx) - // Refresh all tablets. - wg.Add(1) - go func() { - defer wg.Done() - refreshAllTablets() - }() + // Refresh all keyspace information. + eg.Go(func() error { + return RefreshAllKeyspacesAndShards(ctx) + }) - // Wait for both the refreshes to complete - wg.Wait() - // We have completed one discovery cycle in the entirety of it. We should update the process health. - process.FirstDiscoveryCycleComplete.Store(true) - } + // Refresh all tablets. + eg.Go(func() error { + return refreshAllTablets(ctx) + }) + + // Wait for both the refreshes to complete + err := eg.Wait() + if err == nil { + process.FirstDiscoveryCycleComplete.Store(true) } + return err } diff --git a/go/vt/vtorc/logic/vtorc_test.go b/go/vt/vtorc/logic/vtorc_test.go index c8f2ac3bfdc..7ee2f0e253b 100644 --- a/go/vt/vtorc/logic/vtorc_test.go +++ b/go/vt/vtorc/logic/vtorc_test.go @@ -1,11 +1,17 @@ package logic import ( + "context" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/topo/memorytopo" + "vitess.io/vitess/go/vt/vtorc/db" + "vitess.io/vitess/go/vt/vtorc/process" ) func TestWaitForLocksRelease(t *testing.T) { @@ -54,3 +60,49 @@ func waitForLocksReleaseAndGetTimeWaitedFor() time.Duration { waitForLocksRelease() return time.Since(start) } + +func TestRefreshAllInformation(t *testing.T) { + defer process.ResetLastHealthCheckCache() + + // Store the old flags and restore on test completion + oldTs := ts + defer func() { + ts = oldTs + }() + + // Clear the database after the test. The easiest way to do that is to run all the initialization commands again. + defer func() { + db.ClearVTOrcDatabase() + }() + + // Verify in the beginning, we have the first DiscoveredOnce field false. + _, err := process.HealthTest() + require.NoError(t, err) + + // Create a memory topo-server and create the keyspace and shard records + ts = memorytopo.NewServer(context.Background(), cell1) + _, err = ts.GetOrCreateShard(context.Background(), keyspace, shard) + require.NoError(t, err) + + // Test error + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel context to simulate timeout + require.Error(t, refreshAllInformation(ctx)) + require.False(t, process.FirstDiscoveryCycleComplete.Load()) + health, err := process.HealthTest() + require.NoError(t, err) + require.False(t, health.DiscoveredOnce) + require.False(t, health.Healthy) + process.ResetLastHealthCheckCache() + + // Test success + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + require.NoError(t, refreshAllInformation(ctx2)) + require.True(t, process.FirstDiscoveryCycleComplete.Load()) + health, err = process.HealthTest() + require.NoError(t, err) + require.True(t, health.DiscoveredOnce) + require.True(t, health.Healthy) + process.ResetLastHealthCheckCache() +} diff --git a/go/vt/vtorc/process/health.go b/go/vt/vtorc/process/health.go index 22db89e1d56..7f8ab83b39b 100644 --- a/go/vt/vtorc/process/health.go +++ b/go/vt/vtorc/process/health.go @@ -36,6 +36,8 @@ var FirstDiscoveryCycleComplete atomic.Bool var lastHealthCheckCache = cache.New(config.HealthPollSeconds*time.Second, time.Second) +func ResetLastHealthCheckCache() { lastHealthCheckCache.Flush() } + type NodeHealth struct { Hostname string Token string @@ -120,8 +122,8 @@ func HealthTest() (health *HealthStatus, err error) { log.Error(err) return health, err } - health.Healthy = healthy health.DiscoveredOnce = FirstDiscoveryCycleComplete.Load() + health.Healthy = healthy && health.DiscoveredOnce if health.ActiveNode, health.IsActiveNode, err = ElectedNode(); err != nil { health.Error = err diff --git a/go/vt/vtorc/process/health_test.go b/go/vt/vtorc/process/health_test.go new file mode 100644 index 00000000000..85317530ac4 --- /dev/null +++ b/go/vt/vtorc/process/health_test.go @@ -0,0 +1,51 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package process + +import ( + "testing" + + "github.com/stretchr/testify/require" + _ "modernc.org/sqlite" +) + +func TestHealthTest(t *testing.T) { + defer func() { + FirstDiscoveryCycleComplete.Store(false) + ThisNodeHealth = &NodeHealth{} + ResetLastHealthCheckCache() + }() + + require.Zero(t, ThisNodeHealth.LastReported) + + ThisNodeHealth = &NodeHealth{} + health, err := HealthTest() + require.NoError(t, err) + require.False(t, health.Healthy) + require.False(t, health.DiscoveredOnce) + require.NotZero(t, ThisNodeHealth.LastReported) + ResetLastHealthCheckCache() + + ThisNodeHealth = &NodeHealth{} + FirstDiscoveryCycleComplete.Store(true) + health, err = HealthTest() + require.NoError(t, err) + require.True(t, health.Healthy) + require.True(t, health.DiscoveredOnce) + require.NotZero(t, ThisNodeHealth.LastReported) + ResetLastHealthCheckCache() +}