Skip to content

Commit

Permalink
Merge pull request #3362 from telepresenceio/josecv/fix-container-ctx
Browse files Browse the repository at this point in the history
Context spaghetti fix.
  • Loading branch information
josecv authored Sep 29, 2023
2 parents a96b0ef + 9d69961 commit 8286379
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
2 changes: 1 addition & 1 deletion cmd/traffic/cmd/manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,11 @@ func MainWithEnv(ctx context.Context) error {
}
ctx = k8sapi.WithK8sInterface(ctx, ki)

ctx, imgRetErr := WithAgentImageRetrieverFunc(ctx, mutator.RegenerateAgentMaps)
mgr, ctx, err := NewServiceFunc(ctx)
if err != nil {
return fmt.Errorf("unable to initialize traffic manager: %w", err)
}
ctx, imgRetErr := WithAgentImageRetrieverFunc(ctx, mutator.RegenerateAgentMaps)

g := dgroup.NewGroup(ctx, dgroup.GroupConfig{
EnableSignalHandling: true,
Expand Down
22 changes: 12 additions & 10 deletions cmd/traffic/cmd/manager/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ type interceptFinalizerCall struct {
state *interceptState
info *rpc.InterceptInfo
errCh chan error
ctx context.Context
}

// state is the total state of the Traffic Manager. A zero state is invalid; you must call
Expand Down Expand Up @@ -231,13 +232,13 @@ func (s *state) GetSession(sessionID string) SessionState {
func (s *state) RemoveSession(ctx context.Context, sessionID string) error {
s.mu.Lock()
dlog.Debugf(ctx, "Session %s removed. Explicit removal", sessionID)
wait := s.unlockedRemoveSession(sessionID)
wait := s.unlockedRemoveSession(ctx, sessionID)
s.mu.Unlock()

return wait()
}

func (s *state) gcSessionIntercepts(sessionID string) cleanupWaiter {
func (s *state) gcSessionIntercepts(ctx context.Context, sessionID string) cleanupWaiter {
agent, isAgent := s.agents.Load(sessionID)

wait := func() error { return nil }
Expand All @@ -250,7 +251,7 @@ func (s *state) gcSessionIntercepts(sessionID string) cleanupWaiter {
if intercept.ClientSession.SessionId == sessionID {
// Client went away:
// Delete it.
_, iceptWait := s.unlockedRemoveIntercept(interceptID)
_, iceptWait := s.unlockedRemoveIntercept(ctx, interceptID)
newWait := func() error {
err := wait()
err = multierror.Append(err, iceptWait())
Expand All @@ -274,13 +275,13 @@ func (s *state) gcSessionIntercepts(sessionID string) cleanupWaiter {
return wait
}

func (s *state) unlockedRemoveSession(sessionID string) cleanupWaiter {
func (s *state) unlockedRemoveSession(ctx context.Context, sessionID string) cleanupWaiter {
wait := func() error { return nil }
if sess, ok := s.sessions[sessionID]; ok {
// kill the session
defer sess.Cancel()

wait = s.gcSessionIntercepts(sessionID)
wait = s.gcSessionIntercepts(ctx, sessionID)

agent, isAgent := s.agents.Load(sessionID)
if isAgent {
Expand Down Expand Up @@ -317,13 +318,13 @@ func (s *state) ExpireSessions(ctx context.Context, clientMoment, agentMoment ti
if _, ok := sess.(*clientSessionState); ok {
if sess.LastMarked().Before(clientMoment) {
dlog.Debugf(ctx, "Client Session %s removed. It has expired", id)
wait := s.unlockedRemoveSession(id)
wait := s.unlockedRemoveSession(ctx, id)
go reportErr(id, wait)
}
} else {
if sess.LastMarked().Before(agentMoment) {
dlog.Debugf(ctx, "Agent Session %s removed. It has expired", id)
wait := s.unlockedRemoveSession(id)
wait := s.unlockedRemoveSession(ctx, id)
go reportErr(id, wait)
}
}
Expand Down Expand Up @@ -609,12 +610,12 @@ func (s *state) UpdateClient(sessionID string, apply func(*rpc.ClientInfo)) *rpc

func (s *state) RemoveIntercept(ctx context.Context, interceptID string) (bool, error) {
s.mu.Lock()
removed, wait := s.unlockedRemoveIntercept(interceptID)
removed, wait := s.unlockedRemoveIntercept(ctx, interceptID)
s.mu.Unlock()
return removed, wait()
}

func (s *state) unlockedRemoveIntercept(interceptID string) (bool, cleanupWaiter) {
func (s *state) unlockedRemoveIntercept(ctx context.Context, interceptID string) (bool, cleanupWaiter) {
intercept, didDelete := s.intercepts.LoadAndDelete(interceptID)
wait := func() error { return nil }
if state, ok := s.interceptStates[interceptID]; ok && didDelete {
Expand All @@ -623,6 +624,7 @@ func (s *state) unlockedRemoveIntercept(interceptID string) (bool, cleanupWaiter
state: state,
info: intercept,
errCh: make(chan error),
ctx: ctx,
}
s.interceptFinalizerCh <- call
wait = func() error {
Expand All @@ -637,7 +639,7 @@ func (s *state) runInterceptFinalizerQueue() {
for {
select {
case call := <-s.interceptFinalizerCh:
call.errCh <- call.state.terminate(s.backgroundCtx, call.info)
call.errCh <- call.state.terminate(call.ctx, call.info)
case <-s.backgroundCtx.Done():
return
}
Expand Down

0 comments on commit 8286379

Please sign in to comment.