diff --git a/go/vt/vttablet/tabletserver/state_manager.go b/go/vt/vttablet/tabletserver/state_manager.go index 8aa7776957f..9c01610f770 100644 --- a/go/vt/vttablet/tabletserver/state_manager.go +++ b/go/vt/vttablet/tabletserver/state_manager.go @@ -99,6 +99,10 @@ type stateManager struct { alsoAllow []topodatapb.TabletType reason string transitionErr error + // requestsWaitCounter is the number of goroutines that are waiting for requests to be empty. + // If this value is greater than zero, then we have to ensure that we don't Add to the requests + // to avoid any panics in the wait. + requestsWaitCounter int requests sync.WaitGroup @@ -354,6 +358,20 @@ func (sm *stateManager) checkMySQL() { }() } +// addRequestsWaitCounter adds to the requestsWaitCounter while being protected by a mutex. +func (sm *stateManager) addRequestsWaitCounter(val int) { + sm.mu.Lock() + defer sm.mu.Unlock() + sm.requestsWaitCounter += val +} + +// waitForRequestsToBeEmpty waits for requests to be empty. It also increments and decrements the requestsWaitCounter as required. +func (sm *stateManager) waitForRequestsToBeEmpty() { + sm.addRequestsWaitCounter(1) + sm.requests.Wait() + sm.addRequestsWaitCounter(-1) +} + func (sm *stateManager) setWantState(stateWanted servingState) { sm.mu.Lock() defer sm.mu.Unlock() @@ -392,7 +410,9 @@ func (sm *stateManager) StartRequest(ctx context.Context, target *querypb.Target } shuttingDown := sm.wantState != StateServing - if shuttingDown && !allowOnShutdown { + // If requestsWaitCounter is not zero, then there are go-routines blocked on waiting for requests to be empty. + // We cannot allow adding to the requests to prevent any panics from happening. + if (shuttingDown && !allowOnShutdown) || sm.requestsWaitCounter > 0 { // This specific error string needs to be returned for vtgate buffering to work. return vterrors.New(vtrpcpb.Code_CLUSTER_EVENT, vterrors.ShuttingDown) } @@ -560,7 +580,7 @@ func (sm *stateManager) unserveCommon() { log.Info("Finished Killing all OLAP queries. Started tracker close") sm.tracker.Close() log.Infof("Finished tracker close. Started wait for requests") - sm.requests.Wait() + sm.waitForRequestsToBeEmpty() log.Infof("Finished wait for requests. Finished execution of unserveCommon") } diff --git a/go/vt/vttablet/tabletserver/state_manager_test.go b/go/vt/vttablet/tabletserver/state_manager_test.go index 4b88ce734d7..cd72c7232c8 100644 --- a/go/vt/vttablet/tabletserver/state_manager_test.go +++ b/go/vt/vttablet/tabletserver/state_manager_test.go @@ -701,6 +701,29 @@ func TestRefreshReplHealthLocked(t *testing.T) { assert.False(t, sm.replHealthy) } +// TestPanicInWait tests that we don't panic when we wait for requests if more StartRequest calls come up after we start waiting. +func TestPanicInWait(t *testing.T) { + sm := newTestStateManager(t) + sm.wantState = StateServing + sm.state = StateServing + sm.replHealthy = true + ctx := context.Background() + // Simulate an Execute RPC running + err := sm.StartRequest(ctx, sm.target, false) + require.NoError(t, err) + go func() { + time.Sleep(100 * time.Millisecond) + // Simulate the previous RPC finishing after some delay + sm.EndRequest() + // Simulate a COMMIT call arriving right afterwards + _ = sm.StartRequest(ctx, sm.target, true) + }() + + // Simulate going to a not serving state and calling unserveCommon that waits on requests. + sm.wantState = StateNotServing + sm.waitForRequestsToBeEmpty() +} + func verifySubcomponent(t *testing.T, order int64, component any, state testState) { tos := component.(orderState) assert.Equal(t, order, tos.Order())