diff --git a/cmd/traffic/cmd/manager/manager.go b/cmd/traffic/cmd/manager/manager.go index f7fc7c6a1c..f6e279fc11 100644 --- a/cmd/traffic/cmd/manager/manager.go +++ b/cmd/traffic/cmd/manager/manager.go @@ -79,11 +79,11 @@ func MainWithEnv(ctx context.Context) error { } ctx = k8sapi.WithK8sInterface(ctx, ki) - ctx, imgRetErr := WithAgentImageRetrieverFunc(ctx, mutator.RegenerateAgentMaps) mgr, ctx, err := NewServiceFunc(ctx) if err != nil { return fmt.Errorf("unable to initialize traffic manager: %w", err) } + ctx, imgRetErr := WithAgentImageRetrieverFunc(ctx, mutator.RegenerateAgentMaps) g := dgroup.NewGroup(ctx, dgroup.GroupConfig{ EnableSignalHandling: true, diff --git a/cmd/traffic/cmd/manager/state/state.go b/cmd/traffic/cmd/manager/state/state.go index 9cf0adffd7..44736ee7ab 100644 --- a/cmd/traffic/cmd/manager/state/state.go +++ b/cmd/traffic/cmd/manager/state/state.go @@ -75,6 +75,7 @@ type interceptFinalizerCall struct { state *interceptState info *rpc.InterceptInfo errCh chan error + ctx context.Context } // state is the total state of the Traffic Manager. A zero state is invalid; you must call @@ -231,13 +232,13 @@ func (s *state) GetSession(sessionID string) SessionState { func (s *state) RemoveSession(ctx context.Context, sessionID string) error { s.mu.Lock() dlog.Debugf(ctx, "Session %s removed. Explicit removal", sessionID) - wait := s.unlockedRemoveSession(sessionID) + wait := s.unlockedRemoveSession(ctx, sessionID) s.mu.Unlock() return wait() } -func (s *state) gcSessionIntercepts(sessionID string) cleanupWaiter { +func (s *state) gcSessionIntercepts(ctx context.Context, sessionID string) cleanupWaiter { agent, isAgent := s.agents.Load(sessionID) wait := func() error { return nil } @@ -250,7 +251,7 @@ func (s *state) gcSessionIntercepts(sessionID string) cleanupWaiter { if intercept.ClientSession.SessionId == sessionID { // Client went away: // Delete it. - _, iceptWait := s.unlockedRemoveIntercept(interceptID) + _, iceptWait := s.unlockedRemoveIntercept(ctx, interceptID) newWait := func() error { err := wait() err = multierror.Append(err, iceptWait()) @@ -274,13 +275,13 @@ func (s *state) gcSessionIntercepts(sessionID string) cleanupWaiter { return wait } -func (s *state) unlockedRemoveSession(sessionID string) cleanupWaiter { +func (s *state) unlockedRemoveSession(ctx context.Context, sessionID string) cleanupWaiter { wait := func() error { return nil } if sess, ok := s.sessions[sessionID]; ok { // kill the session defer sess.Cancel() - wait = s.gcSessionIntercepts(sessionID) + wait = s.gcSessionIntercepts(ctx, sessionID) agent, isAgent := s.agents.Load(sessionID) if isAgent { @@ -317,13 +318,13 @@ func (s *state) ExpireSessions(ctx context.Context, clientMoment, agentMoment ti if _, ok := sess.(*clientSessionState); ok { if sess.LastMarked().Before(clientMoment) { dlog.Debugf(ctx, "Client Session %s removed. It has expired", id) - wait := s.unlockedRemoveSession(id) + wait := s.unlockedRemoveSession(ctx, id) go reportErr(id, wait) } } else { if sess.LastMarked().Before(agentMoment) { dlog.Debugf(ctx, "Agent Session %s removed. It has expired", id) - wait := s.unlockedRemoveSession(id) + wait := s.unlockedRemoveSession(ctx, id) go reportErr(id, wait) } } @@ -609,12 +610,12 @@ func (s *state) UpdateClient(sessionID string, apply func(*rpc.ClientInfo)) *rpc func (s *state) RemoveIntercept(ctx context.Context, interceptID string) (bool, error) { s.mu.Lock() - removed, wait := s.unlockedRemoveIntercept(interceptID) + removed, wait := s.unlockedRemoveIntercept(ctx, interceptID) s.mu.Unlock() return removed, wait() } -func (s *state) unlockedRemoveIntercept(interceptID string) (bool, cleanupWaiter) { +func (s *state) unlockedRemoveIntercept(ctx context.Context, interceptID string) (bool, cleanupWaiter) { intercept, didDelete := s.intercepts.LoadAndDelete(interceptID) wait := func() error { return nil } if state, ok := s.interceptStates[interceptID]; ok && didDelete { @@ -623,6 +624,7 @@ func (s *state) unlockedRemoveIntercept(interceptID string) (bool, cleanupWaiter state: state, info: intercept, errCh: make(chan error), + ctx: ctx, } s.interceptFinalizerCh <- call wait = func() error { @@ -637,7 +639,7 @@ func (s *state) runInterceptFinalizerQueue() { for { select { case call := <-s.interceptFinalizerCh: - call.errCh <- call.state.terminate(s.backgroundCtx, call.info) + call.errCh <- call.state.terminate(call.ctx, call.info) case <-s.backgroundCtx.Done(): return }