From 8294544c3e8408a522dff1d8d04e20d11ca19fab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Lambert?= Date: Thu, 20 Jul 2023 16:05:31 -0400 Subject: [PATCH 01/11] Add consumption methods & attributes to traffic manager state MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kévin Lambert --- cmd/traffic/cmd/manager/service.go | 5 +- cmd/traffic/cmd/manager/state/consumption.go | 59 ++++++++++++++++++++ cmd/traffic/cmd/manager/state/state.go | 54 ++++++++++++------ 3 files changed, 100 insertions(+), 18 deletions(-) create mode 100644 cmd/traffic/cmd/manager/state/consumption.go diff --git a/cmd/traffic/cmd/manager/service.go b/cmd/traffic/cmd/manager/service.go index c78fc0391e..c24ae1e614 100644 --- a/cmd/traffic/cmd/manager/service.go +++ b/cmd/traffic/cmd/manager/service.go @@ -183,10 +183,13 @@ func (s *service) GetClientConfig(ctx context.Context, _ *empty.Empty) (*rpc.CLI func (s *service) Remain(ctx context.Context, req *rpc.RemainRequest) (*empty.Empty, error) { // ctx = WithSessionInfo(ctx, req.GetSession()) // dlog.Debug(ctx, "Remain called") + sessionID := req.GetSession().GetSessionId() if ok := s.state.MarkSession(req, s.clock.Now()); !ok { - return nil, status.Errorf(codes.NotFound, "Session %q not found", req.GetSession().GetSessionId()) + return nil, status.Errorf(codes.NotFound, "Session %q not found", sessionID) } + s.state.RefreshSessionConsumptionMetrics(sessionID) + return &empty.Empty{}, nil } diff --git a/cmd/traffic/cmd/manager/state/consumption.go b/cmd/traffic/cmd/manager/state/consumption.go new file mode 100644 index 0000000000..ac6b43ced9 --- /dev/null +++ b/cmd/traffic/cmd/manager/state/consumption.go @@ -0,0 +1,59 @@ +package state + +import ( + "time" +) + +// SessionConsumptionMetricsStaleTTL is the duration after which we consider the metrics to be staled, meaning +// that they should not be updated anymore since the user doesn't really use Telepresence at the moment. +const SessionConsumptionMetricsStaleTTL = 1 * time.Minute // TODO: Increase. + +type SessionConsumptionMetrics struct { + Duration uint32 + LastUpdate time.Time +} + +func (s *state) unlockedAddSessionConsumption(sessionID string) { + s.sessionConsumptionMetrics[sessionID] = &SessionConsumptionMetrics{ + Duration: 0, + LastUpdate: time.Now(), + } +} + +func (s *state) unlockedRemoveSessionConsumption(sessionID string) { + delete(s.sessionConsumptionMetrics, sessionID) +} + +func (s *state) GetSessionConsumptionMetrics(sessionID string) *SessionConsumptionMetrics { + s.mu.RLock() + defer s.mu.RUnlock() + return s.sessionConsumptionMetrics[sessionID] +} + +func (c *state) GetAllSessionConsumptionMetrics() map[string]*SessionConsumptionMetrics { + scmCopy := make(map[string]*SessionConsumptionMetrics) + c.mu.RLock() + defer c.mu.RUnlock() + for sessionID, val := range c.sessionConsumptionMetrics { + valCopy := *val + scmCopy[sessionID] = &valCopy + } + return scmCopy +} + +// RefreshSessionConsumptionMetrics refreshes the metrics associated to a specific session. +func (s *state) RefreshSessionConsumptionMetrics(sessionID string) { + lastMark := s.GetSession(sessionID).LastMarked() + s.mu.Lock() + defer s.mu.Unlock() + consumption := s.sessionConsumptionMetrics[sessionID] + + // if last mark is more than SessionConsumptionMetricsStaleTTL old, it means the duration metric should stop being + // updated since the user machine is maybe in standby. + isStale := time.Now().After(lastMark.Add(SessionConsumptionMetricsStaleTTL)) + if !isStale { + consumption.Duration += uint32(time.Since(consumption.LastUpdate).Seconds()) + } + + consumption.LastUpdate = time.Now() +} diff --git a/cmd/traffic/cmd/manager/state/state.go b/cmd/traffic/cmd/manager/state/state.go index b295266039..4c7e01c138 100644 --- a/cmd/traffic/cmd/manager/state/state.go +++ b/cmd/traffic/cmd/manager/state/state.go @@ -42,6 +42,9 @@ type State interface { GetAgent(string) *rpc.AgentInfo GetAllClients() map[string]*rpc.ClientInfo GetClient(string) *rpc.ClientInfo + GetSession(string) SessionState + GetSessionConsumptionMetrics(string) *SessionConsumptionMetrics + GetAllSessionConsumptionMetrics() map[string]*SessionConsumptionMetrics GetIntercept(string) (*rpc.InterceptInfo, bool) MarkSession(*rpc.RemainRequest, time.Time) bool NewInterceptInfo(string, *rpc.SessionInfo, *rpc.CreateInterceptRequest) *rpc.InterceptInfo @@ -54,6 +57,7 @@ type State interface { Tunnel(context.Context, tunnel.Stream) error UpdateIntercept(string, func(*rpc.InterceptInfo)) *rpc.InterceptInfo UpdateClient(sessionID string, apply func(*rpc.ClientInfo)) *rpc.ClientInfo + RefreshSessionConsumptionMetrics(sessionID string) ValidateAgentImage(string, bool) error WaitForTempLogLevel(rpc.Manager_WatchLogLevelServer) error WatchAgents(context.Context, func(sessionID string, agent *rpc.AgentInfo) bool) <-chan watchable.Snapshot[*rpc.AgentInfo] @@ -83,16 +87,17 @@ type state struct { // 7. `cfgMapLocks` access must be concurrency protected // 8. `cachedAgentImage` access must be concurrency protected // 9. `interceptState` must be concurrency protected and updated/deleted in sync with intercepts - intercepts watchable.Map[*rpc.InterceptInfo] // info for intercepts, keyed by intercept id - agents watchable.Map[*rpc.AgentInfo] // info for agent sessions, keyed by session id - clients watchable.Map[*rpc.ClientInfo] // info for client sessions, keyed by session id - sessions map[string]SessionState // info for all sessions, keyed by session id - agentsByName map[string]map[string]*rpc.AgentInfo // indexed copy of `agents` - interceptStates map[string]*interceptState - timedLogLevel log.TimedLevel - llSubs *loglevelSubscribers - cfgMapLocks map[string]*sync.Mutex - tunnelCounter int32 + intercepts watchable.Map[*rpc.InterceptInfo] // info for intercepts, keyed by intercept id + agents watchable.Map[*rpc.AgentInfo] // info for agent sessions, keyed by session id + clients watchable.Map[*rpc.ClientInfo] // info for client sessions, keyed by session id + sessions map[string]SessionState // info for all sessions, keyed by session id + agentsByName map[string]map[string]*rpc.AgentInfo // indexed copy of `agents` + interceptStates map[string]*interceptState + sessionConsumptionMetrics map[string]*SessionConsumptionMetrics // TODO: For stale, use LastMark from session ? + timedLogLevel log.TimedLevel + llSubs *loglevelSubscribers + cfgMapLocks map[string]*sync.Mutex + tunnelCounter int32 // Possibly extended version of the state. Use when calling interface methods. self State @@ -103,13 +108,14 @@ var NewStateFunc = NewState //nolint:gochecknoglobals // extension point func NewState(ctx context.Context) State { loglevel := os.Getenv("LOG_LEVEL") s := &state{ - ctx: ctx, - sessions: make(map[string]SessionState), - agentsByName: make(map[string]map[string]*rpc.AgentInfo), - cfgMapLocks: make(map[string]*sync.Mutex), - interceptStates: make(map[string]*interceptState), - timedLogLevel: log.NewTimedLevel(loglevel, log.SetLevel), - llSubs: newLoglevelSubscribers(), + ctx: ctx, + sessions: make(map[string]SessionState), + sessionConsumptionMetrics: make(map[string]*SessionConsumptionMetrics), + agentsByName: make(map[string]map[string]*rpc.AgentInfo), + cfgMapLocks: make(map[string]*sync.Mutex), + interceptStates: make(map[string]*interceptState), + timedLogLevel: log.NewTimedLevel(loglevel, log.SetLevel), + llSubs: newLoglevelSubscribers(), } s.self = s return s @@ -200,12 +206,19 @@ func (s *state) MarkSession(req *rpc.RemainRequest, now time.Time) (ok bool) { return false } +func (s *state) GetSession(sessionID string) SessionState { + s.mu.RLock() + defer s.mu.RUnlock() + return s.sessions[sessionID] +} + // RemoveSession removes a session from the set of present session IDs. func (s *state) RemoveSession(ctx context.Context, sessionID string) { s.mu.Lock() defer s.mu.Unlock() dlog.Debugf(ctx, "Session %s removed. Explicit removal", sessionID) + s.unlockedRemoveSessionConsumption(sessionID) s.unlockedRemoveSession(sessionID) } @@ -271,6 +284,7 @@ func (s *state) ExpireSessions(ctx context.Context, clientMoment, agentMoment ti if _, ok := sess.(*clientSessionState); ok { if sess.LastMarked().Before(clientMoment) { dlog.Debugf(ctx, "Client Session %s removed. It has expired", id) + s.unlockedRemoveSessionConsumption(id) s.unlockedRemoveSession(id) } } else { @@ -315,6 +329,12 @@ func (s *state) addClient(sessionID string, client *rpc.ClientInfo, now time.Tim panic(fmt.Errorf("duplicate id %q, existing %+v, new %+v", sessionID, oldClient, client)) } s.sessions[sessionID] = newClientSessionState(s.ctx, now) + + // Only store consumption for clientSession states. + if _, ok := s.sessions[sessionID].(*clientSessionState); ok { + s.unlockedAddSessionConsumption(sessionID) + } + return sessionID } From 732c2602d1fb647ad167b709b2d946c309c99579 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Lambert?= Date: Mon, 24 Jul 2023 14:44:44 -0400 Subject: [PATCH 02/11] Fix panic by ensuring only client session consumption is stored MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kévin Lambert --- cmd/traffic/cmd/manager/state/consumption.go | 10 ++++++++-- cmd/traffic/cmd/manager/state/state.go | 5 ++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/cmd/traffic/cmd/manager/state/consumption.go b/cmd/traffic/cmd/manager/state/consumption.go index ac6b43ced9..0bb273c569 100644 --- a/cmd/traffic/cmd/manager/state/consumption.go +++ b/cmd/traffic/cmd/manager/state/consumption.go @@ -43,14 +43,20 @@ func (c *state) GetAllSessionConsumptionMetrics() map[string]*SessionConsumption // RefreshSessionConsumptionMetrics refreshes the metrics associated to a specific session. func (s *state) RefreshSessionConsumptionMetrics(sessionID string) { - lastMark := s.GetSession(sessionID).LastMarked() s.mu.Lock() defer s.mu.Unlock() + + session := s.sessions[sessionID] + if _, isClientSession := session.(*clientSessionState); !isClientSession { + return + } + + lastMarked := session.LastMarked() consumption := s.sessionConsumptionMetrics[sessionID] // if last mark is more than SessionConsumptionMetricsStaleTTL old, it means the duration metric should stop being // updated since the user machine is maybe in standby. - isStale := time.Now().After(lastMark.Add(SessionConsumptionMetricsStaleTTL)) + isStale := time.Now().After(lastMarked.Add(SessionConsumptionMetricsStaleTTL)) if !isStale { consumption.Duration += uint32(time.Since(consumption.LastUpdate).Seconds()) } diff --git a/cmd/traffic/cmd/manager/state/state.go b/cmd/traffic/cmd/manager/state/state.go index 4c7e01c138..7f24f0aacb 100644 --- a/cmd/traffic/cmd/manager/state/state.go +++ b/cmd/traffic/cmd/manager/state/state.go @@ -218,7 +218,10 @@ func (s *state) RemoveSession(ctx context.Context, sessionID string) { defer s.mu.Unlock() dlog.Debugf(ctx, "Session %s removed. Explicit removal", sessionID) - s.unlockedRemoveSessionConsumption(sessionID) + if _, isClientSession := s.sessions[sessionID].(*clientSessionState); isClientSession { + s.unlockedRemoveSessionConsumption(sessionID) + } + s.unlockedRemoveSession(sessionID) } From adfdfa7a3d538973c7622118c5ec873bd85f46b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Lambert?= Date: Tue, 25 Jul 2023 10:12:24 -0400 Subject: [PATCH 03/11] Add unit tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kévin Lambert --- .../cmd/manager/state/consumption_test.go | 38 +++++++++ cmd/traffic/cmd/manager/state/state.go | 5 +- cmd/traffic/cmd/manager/state/state_test.go | 79 +++++++++++++++++-- 3 files changed, 113 insertions(+), 9 deletions(-) create mode 100644 cmd/traffic/cmd/manager/state/consumption_test.go diff --git a/cmd/traffic/cmd/manager/state/consumption_test.go b/cmd/traffic/cmd/manager/state/consumption_test.go new file mode 100644 index 0000000000..ac75a13afa --- /dev/null +++ b/cmd/traffic/cmd/manager/state/consumption_test.go @@ -0,0 +1,38 @@ +package state + +import ( + "time" + + "github.com/stretchr/testify/assert" +) + +func (s *suiteState) TestRefreshSessionConsumptionMetrics() { + // given + now := time.Now() + session1 := &clientSessionState{} + session1.SetLastMarked(now) + session3 := &clientSessionState{} + session3.SetLastMarked(now.Add(-24 * time.Hour * 30)) + s.state.sessions["session-1"] = session1 + s.state.sessions["session-2"] = &agentSessionState{} + s.state.sessions["session-3"] = session3 + s.state.sessionConsumptionMetrics["session-1"] = &SessionConsumptionMetrics{ + Duration: 42, + LastUpdate: now.Add(-time.Minute), + } + // staled metric + s.state.sessionConsumptionMetrics["session-3"] = &SessionConsumptionMetrics{ + Duration: 36, + LastUpdate: session3.lastMarked, + } + + // when + s.state.RefreshSessionConsumptionMetrics("session-1") + s.state.RefreshSessionConsumptionMetrics("session-2") // should not fail. + s.state.RefreshSessionConsumptionMetrics("session-3") // should not refresh a stale metric. + + // then + assert.Len(s.T(), s.state.GetAllSessionConsumptionMetrics(), 2) + assert.True(s.T(), (s.state.sessionConsumptionMetrics["session-1"].Duration) > 42) + assert.Equal(s.T(), 36, int(s.state.sessionConsumptionMetrics["session-3"].Duration)) +} diff --git a/cmd/traffic/cmd/manager/state/state.go b/cmd/traffic/cmd/manager/state/state.go index 7f24f0aacb..b4c08c52c3 100644 --- a/cmd/traffic/cmd/manager/state/state.go +++ b/cmd/traffic/cmd/manager/state/state.go @@ -333,10 +333,7 @@ func (s *state) addClient(sessionID string, client *rpc.ClientInfo, now time.Tim } s.sessions[sessionID] = newClientSessionState(s.ctx, now) - // Only store consumption for clientSession states. - if _, ok := s.sessions[sessionID].(*clientSessionState); ok { - s.unlockedAddSessionConsumption(sessionID) - } + s.unlockedAddSessionConsumption(sessionID) return sessionID } diff --git a/cmd/traffic/cmd/manager/state/state_test.go b/cmd/traffic/cmd/manager/state/state_test.go index feb0b7cf6d..e0b39d6660 100644 --- a/cmd/traffic/cmd/manager/state/state_test.go +++ b/cmd/traffic/cmd/manager/state/state_test.go @@ -2,13 +2,41 @@ package state import ( "context" + "sync" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + + "github.com/datawire/dlib/dlog" "github.com/telepresenceio/telepresence/rpc/v2/manager" + rpc "github.com/telepresenceio/telepresence/rpc/v2/manager" testdata "github.com/telepresenceio/telepresence/v2/cmd/traffic/cmd/manager/test" + "github.com/telepresenceio/telepresence/v2/pkg/log" ) +type suiteState struct { + suite.Suite + + ctx context.Context + state *state +} + +func (s *suiteState) SetupTest() { + s.ctx = dlog.NewTestContext(s.T(), false) + s.state = &state{ + ctx: s.ctx, + sessions: make(map[string]SessionState), + sessionConsumptionMetrics: make(map[string]*SessionConsumptionMetrics), + agentsByName: make(map[string]map[string]*rpc.AgentInfo), + cfgMapLocks: make(map[string]*sync.Mutex), + interceptStates: make(map[string]*interceptState), + timedLogLevel: log.NewTimedLevel("debug", log.SetLevel), + llSubs: newLoglevelSubscribers(), + } +} + type FakeClock struct { When int } @@ -19,13 +47,13 @@ func (fc *FakeClock) Now() time.Time { return base.Add(offset) } -func TestStateInternal(topT *testing.T) { +func (s *suiteState) TestStateInternal() { ctx := context.Background() - testAgents := testdata.GetTestAgents(topT) - testClients := testdata.GetTestClients(topT) + testAgents := testdata.GetTestAgents(s.T()) + testClients := testdata.GetTestClients(s.T()) - topT.Run("agents", func(t *testing.T) { + s.T().Run("agents", func(t *testing.T) { a := assertNew(t) helloAgent := testAgents["hello"] @@ -73,7 +101,7 @@ func TestStateInternal(topT *testing.T) { a.Len(agents, 0) }) - topT.Run("presence-redundant", func(t *testing.T) { + s.T().Run("presence-redundant", func(t *testing.T) { a := assertNew(t) clock := &FakeClock{} @@ -128,3 +156,44 @@ func TestStateInternal(topT *testing.T) { a.False(s.MarkSession(&manager.RemainRequest{Session: &manager.SessionInfo{SessionId: c3}}, clock.Now())) }) } + +func (s *suiteState) TestAddClient() { + // given + now := time.Now() + + // when + s.state.AddClient(&rpc.ClientInfo{ + Name: "my-client", + InstallId: "1234", + Product: "5668", + Version: "2.14.2", + ApiKey: "xxxx", + }, now) + + // then + assert.Len(s.T(), s.state.sessions, 1) + assert.Len(s.T(), s.state.sessionConsumptionMetrics, 1) +} + +func (s *suiteState) TestRemoveSession() { + // given + now := time.Now() + s.state.sessions["session-1"] = newClientSessionState(s.ctx, now) + s.state.sessions["session-2"] = newAgentSessionState(s.ctx, now) + s.state.sessionConsumptionMetrics["session-1"] = &SessionConsumptionMetrics{ + Duration: 42, + LastUpdate: now.Add(-time.Minute), + } + + // when + s.state.RemoveSession(s.ctx, "session-1") + s.state.RemoveSession(s.ctx, "session-2") // won't fail trying to delete consumption. + + // then + assert.Len(s.T(), s.state.sessions, 0) + assert.Len(s.T(), s.state.sessionConsumptionMetrics, 0) +} + +func TestSuiteState(testing *testing.T) { + suite.Run(testing, new(suiteState)) +} From d11f6662f7754269683c9c4a9a35ef2be8187b93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Lambert?= Date: Thu, 27 Jul 2023 17:00:57 -0400 Subject: [PATCH 04/11] Move consumption to session state & add data usage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kévin Lambert --- cmd/traffic/cmd/manager/state/consumption.go | 75 ++++++++++++++----- .../cmd/manager/state/consumption_test.go | 18 ++--- cmd/traffic/cmd/manager/state/session.go | 45 +++++++++-- cmd/traffic/cmd/manager/state/state.go | 64 +++++++++------- cmd/traffic/cmd/manager/state/state_test.go | 21 ++---- pkg/client/remotefs/bridge.go | 2 +- pkg/client/rootd/stream_creator.go | 2 +- pkg/forwarder/tcp.go | 2 +- pkg/tunnel/bidipipe.go | 28 +++++-- pkg/tunnel/dialer.go | 47 ++++++++---- pkg/tunnel/stream.go | 17 ++++- pkg/tunnel/stream_test.go | 10 +-- pkg/tunnel/udplistener.go | 2 +- pkg/vif/stack.go | 2 +- 14 files changed, 227 insertions(+), 108 deletions(-) diff --git a/cmd/traffic/cmd/manager/state/consumption.go b/cmd/traffic/cmd/manager/state/consumption.go index 0bb273c569..2b2589b4fd 100644 --- a/cmd/traffic/cmd/manager/state/consumption.go +++ b/cmd/traffic/cmd/manager/state/consumption.go @@ -1,6 +1,7 @@ package state import ( + "context" "time" ) @@ -8,37 +9,73 @@ import ( // that they should not be updated anymore since the user doesn't really use Telepresence at the moment. const SessionConsumptionMetricsStaleTTL = 1 * time.Minute // TODO: Increase. +func NewSessionConsumptionMetrics() *SessionConsumptionMetrics { + return &SessionConsumptionMetrics{ + ConnectDuration: 0, + FromClientBytesChan: make(chan uint64), + ToClientBytesChan: make(chan uint64), + + LastUpdate: time.Now(), + } +} + type SessionConsumptionMetrics struct { - Duration uint32 - LastUpdate time.Time + ConnectDuration uint32 + LastUpdate time.Time + + // data from client to the traffic manager. + fromClientBytes uint64 + // data from the traffic manager to the client. + toClientBytes uint64 + + FromClientBytesChan chan uint64 + ToClientBytesChan chan uint64 } -func (s *state) unlockedAddSessionConsumption(sessionID string) { - s.sessionConsumptionMetrics[sessionID] = &SessionConsumptionMetrics{ - Duration: 0, - LastUpdate: time.Now(), +func (sc *SessionConsumptionMetrics) RunCollect(ctx context.Context) { + for { + select { + case <-ctx.Done(): + sc.closeChannels() + return + case b, ok := <-sc.FromClientBytesChan: + if !ok { + return + } + sc.fromClientBytes += b + case b, ok := <-sc.ToClientBytesChan: + if !ok { + return + } + sc.toClientBytes += b + } } } -func (s *state) unlockedRemoveSessionConsumption(sessionID string) { - delete(s.sessionConsumptionMetrics, sessionID) +func (sc *SessionConsumptionMetrics) closeChannels() { + close(sc.FromClientBytesChan) + close(sc.ToClientBytesChan) } func (s *state) GetSessionConsumptionMetrics(sessionID string) *SessionConsumptionMetrics { s.mu.RLock() defer s.mu.RUnlock() - return s.sessionConsumptionMetrics[sessionID] + for i := range s.sessions { + if i == sessionID { + return s.sessions[i].ConsumptionMetrics() + } + } + return nil } -func (c *state) GetAllSessionConsumptionMetrics() map[string]*SessionConsumptionMetrics { - scmCopy := make(map[string]*SessionConsumptionMetrics) - c.mu.RLock() - defer c.mu.RUnlock() - for sessionID, val := range c.sessionConsumptionMetrics { - valCopy := *val - scmCopy[sessionID] = &valCopy +func (s *state) GetAllSessionConsumptionMetrics() map[string]*SessionConsumptionMetrics { + allSCM := make(map[string]*SessionConsumptionMetrics) + s.mu.RLock() + defer s.mu.RUnlock() + for sessionID := range s.sessions { + allSCM[sessionID] = s.sessions[sessionID].ConsumptionMetrics() } - return scmCopy + return allSCM } // RefreshSessionConsumptionMetrics refreshes the metrics associated to a specific session. @@ -52,13 +89,13 @@ func (s *state) RefreshSessionConsumptionMetrics(sessionID string) { } lastMarked := session.LastMarked() - consumption := s.sessionConsumptionMetrics[sessionID] + consumption := s.sessions[sessionID].ConsumptionMetrics() // if last mark is more than SessionConsumptionMetricsStaleTTL old, it means the duration metric should stop being // updated since the user machine is maybe in standby. isStale := time.Now().After(lastMarked.Add(SessionConsumptionMetricsStaleTTL)) if !isStale { - consumption.Duration += uint32(time.Since(consumption.LastUpdate).Seconds()) + consumption.ConnectDuration += uint32(time.Since(consumption.LastUpdate).Seconds()) } consumption.LastUpdate = time.Now() diff --git a/cmd/traffic/cmd/manager/state/consumption_test.go b/cmd/traffic/cmd/manager/state/consumption_test.go index ac75a13afa..72819f3e3e 100644 --- a/cmd/traffic/cmd/manager/state/consumption_test.go +++ b/cmd/traffic/cmd/manager/state/consumption_test.go @@ -16,14 +16,14 @@ func (s *suiteState) TestRefreshSessionConsumptionMetrics() { s.state.sessions["session-1"] = session1 s.state.sessions["session-2"] = &agentSessionState{} s.state.sessions["session-3"] = session3 - s.state.sessionConsumptionMetrics["session-1"] = &SessionConsumptionMetrics{ - Duration: 42, - LastUpdate: now.Add(-time.Minute), + session1.consumptionMetrics = &SessionConsumptionMetrics{ + ConnectDuration: 42, + LastUpdate: now.Add(-time.Minute), } // staled metric - s.state.sessionConsumptionMetrics["session-3"] = &SessionConsumptionMetrics{ - Duration: 36, - LastUpdate: session3.lastMarked, + session3.consumptionMetrics = &SessionConsumptionMetrics{ + ConnectDuration: 36, + LastUpdate: session3.lastMarked, } // when @@ -32,7 +32,7 @@ func (s *suiteState) TestRefreshSessionConsumptionMetrics() { s.state.RefreshSessionConsumptionMetrics("session-3") // should not refresh a stale metric. // then - assert.Len(s.T(), s.state.GetAllSessionConsumptionMetrics(), 2) - assert.True(s.T(), (s.state.sessionConsumptionMetrics["session-1"].Duration) > 42) - assert.Equal(s.T(), 36, int(s.state.sessionConsumptionMetrics["session-3"].Duration)) + assert.Len(s.T(), s.state.GetAllSessionConsumptionMetrics(), 3) + assert.True(s.T(), (s.state.sessions["session-1"].ConsumptionMetrics().ConnectDuration) > 42) + assert.Equal(s.T(), 36, int(s.state.sessions["session-3"].ConsumptionMetrics().ConnectDuration)) } diff --git a/cmd/traffic/cmd/manager/state/session.go b/cmd/traffic/cmd/manager/state/session.go index 9f10daaaf2..f0a95da23b 100644 --- a/cmd/traffic/cmd/manager/state/session.go +++ b/cmd/traffic/cmd/manager/state/session.go @@ -15,14 +15,18 @@ import ( "github.com/telepresenceio/telepresence/v2/pkg/tunnel" ) +const agentSessionIDPrefix = "agent:" + type SessionState interface { Cancel() + AwaitingBidiMapOwnerSessionID(stream tunnel.Stream) string + ConsumptionMetrics() *SessionConsumptionMetrics Done() <-chan struct{} LastMarked() time.Time SetLastMarked(lastMarked time.Time) Dials() <-chan *rpc.DialRequest EstablishBidiPipe(context.Context, tunnel.Stream) (tunnel.Endpoint, error) - OnConnect(context.Context, tunnel.Stream, *int32) (tunnel.Endpoint, error) + OnConnect(context.Context, tunnel.Stream, *int32, *SessionConsumptionMetrics) (tunnel.Endpoint, error) } type awaitingBidiPipe struct { @@ -38,6 +42,11 @@ type sessionState struct { lastMarked time.Time awaitingBidiPipeMap map[tunnel.ConnID]*awaitingBidiPipe dials chan *rpc.DialRequest + consumptionMetrics *SessionConsumptionMetrics +} + +func (ss *sessionState) ConsumptionMetrics() *SessionConsumptionMetrics { + return ss.consumptionMetrics } // EstablishBidiPipe registers the given stream as waiting for a matching stream to arrive in a call @@ -86,12 +95,27 @@ func (ss *sessionState) EstablishBidiPipe(ctx context.Context, stream tunnel.Str } } +func (ss *sessionState) AwaitingBidiMapOwnerSessionID(stream tunnel.Stream) string { + ss.Lock() + defer ss.Unlock() + if abp, ok := ss.awaitingBidiPipeMap[stream.ID()]; ok { + return abp.stream.SessionID() + } + return "" +} + // OnConnect checks if a stream is waiting for the given stream to arrive in order to create a BidiPipe. // If that's the case, the BidiPipe is created, started, and returned by both this method and the EstablishBidiPipe // method that registered the waiting stream. Otherwise, this method returns nil. -func (ss *sessionState) OnConnect(_ context.Context, stream tunnel.Stream, counter *int32) (tunnel.Endpoint, error) { +func (ss *sessionState) OnConnect( + ctx context.Context, + stream tunnel.Stream, + counter *int32, + consumptionMetrics *SessionConsumptionMetrics, +) (tunnel.Endpoint, error) { id := stream.ID() ss.Lock() + // abp is a session corresponding to an end user machine abp, ok := ss.awaitingBidiPipeMap[id] if ok { delete(ss.awaitingBidiPipeMap, id) @@ -102,7 +126,13 @@ func (ss *sessionState) OnConnect(_ context.Context, stream tunnel.Stream, count return nil, nil } name := fmt.Sprintf("%s: session %s -> %s", id, abp.stream.SessionID(), stream.SessionID()) - bidiPipe := tunnel.NewBidiPipe(abp.stream, stream, name, counter) + tunnelProbes := &tunnel.BidiPipeProbes{} + if consumptionMetrics != nil { + tunnelProbes.BytesProbeA = consumptionMetrics.FromClientBytesChan + tunnelProbes.BytesProbeB = consumptionMetrics.ToClientBytesChan + } + + bidiPipe := tunnel.NewBidiPipe(abp.stream, stream, name, counter, tunnelProbes) bidiPipe.Start(abp.ctx) defer close(abp.bidiPipeCh) @@ -138,10 +168,11 @@ func (ss *sessionState) SetLastMarked(lastMarked time.Time) { func newSessionState(ctx context.Context, now time.Time) sessionState { ctx, cancel := context.WithCancel(ctx) return sessionState{ - doneCh: ctx.Done(), - cancel: cancel, - lastMarked: now, - dials: make(chan *rpc.DialRequest), + doneCh: ctx.Done(), + cancel: cancel, + lastMarked: now, + dials: make(chan *rpc.DialRequest), + consumptionMetrics: NewSessionConsumptionMetrics(), } } diff --git a/cmd/traffic/cmd/manager/state/state.go b/cmd/traffic/cmd/manager/state/state.go index b4c08c52c3..10c9ab5bd4 100644 --- a/cmd/traffic/cmd/manager/state/state.go +++ b/cmd/traffic/cmd/manager/state/state.go @@ -87,17 +87,16 @@ type state struct { // 7. `cfgMapLocks` access must be concurrency protected // 8. `cachedAgentImage` access must be concurrency protected // 9. `interceptState` must be concurrency protected and updated/deleted in sync with intercepts - intercepts watchable.Map[*rpc.InterceptInfo] // info for intercepts, keyed by intercept id - agents watchable.Map[*rpc.AgentInfo] // info for agent sessions, keyed by session id - clients watchable.Map[*rpc.ClientInfo] // info for client sessions, keyed by session id - sessions map[string]SessionState // info for all sessions, keyed by session id - agentsByName map[string]map[string]*rpc.AgentInfo // indexed copy of `agents` - interceptStates map[string]*interceptState - sessionConsumptionMetrics map[string]*SessionConsumptionMetrics // TODO: For stale, use LastMark from session ? - timedLogLevel log.TimedLevel - llSubs *loglevelSubscribers - cfgMapLocks map[string]*sync.Mutex - tunnelCounter int32 + intercepts watchable.Map[*rpc.InterceptInfo] // info for intercepts, keyed by intercept id + agents watchable.Map[*rpc.AgentInfo] // info for agent sessions, keyed by session id + clients watchable.Map[*rpc.ClientInfo] // info for client sessions, keyed by session id + sessions map[string]SessionState // info for all sessions, keyed by session id + agentsByName map[string]map[string]*rpc.AgentInfo // indexed copy of `agents` + interceptStates map[string]*interceptState + timedLogLevel log.TimedLevel + llSubs *loglevelSubscribers + cfgMapLocks map[string]*sync.Mutex + tunnelCounter int32 // Possibly extended version of the state. Use when calling interface methods. self State @@ -108,14 +107,13 @@ var NewStateFunc = NewState //nolint:gochecknoglobals // extension point func NewState(ctx context.Context) State { loglevel := os.Getenv("LOG_LEVEL") s := &state{ - ctx: ctx, - sessions: make(map[string]SessionState), - sessionConsumptionMetrics: make(map[string]*SessionConsumptionMetrics), - agentsByName: make(map[string]map[string]*rpc.AgentInfo), - cfgMapLocks: make(map[string]*sync.Mutex), - interceptStates: make(map[string]*interceptState), - timedLogLevel: log.NewTimedLevel(loglevel, log.SetLevel), - llSubs: newLoglevelSubscribers(), + ctx: ctx, + sessions: make(map[string]SessionState), + agentsByName: make(map[string]map[string]*rpc.AgentInfo), + cfgMapLocks: make(map[string]*sync.Mutex), + interceptStates: make(map[string]*interceptState), + timedLogLevel: log.NewTimedLevel(loglevel, log.SetLevel), + llSubs: newLoglevelSubscribers(), } s.self = s return s @@ -218,10 +216,6 @@ func (s *state) RemoveSession(ctx context.Context, sessionID string) { defer s.mu.Unlock() dlog.Debugf(ctx, "Session %s removed. Explicit removal", sessionID) - if _, isClientSession := s.sessions[sessionID].(*clientSessionState); isClientSession { - s.unlockedRemoveSessionConsumption(sessionID) - } - s.unlockedRemoveSession(sessionID) } @@ -287,7 +281,6 @@ func (s *state) ExpireSessions(ctx context.Context, clientMoment, agentMoment ti if _, ok := sess.(*clientSessionState); ok { if sess.LastMarked().Before(clientMoment) { dlog.Debugf(ctx, "Client Session %s removed. It has expired", id) - s.unlockedRemoveSessionConsumption(id) s.unlockedRemoveSession(id) } } else { @@ -331,9 +324,10 @@ func (s *state) addClient(sessionID string, client *rpc.ClientInfo, now time.Tim if oldClient, hasConflict := s.clients.LoadOrStore(sessionID, client); hasConflict { panic(fmt.Errorf("duplicate id %q, existing %+v, new %+v", sessionID, oldClient, client)) } + s.sessions[sessionID] = newClientSessionState(s.ctx, now) - s.unlockedAddSessionConsumption(sessionID) + go s.sessions[sessionID].ConsumptionMetrics().RunCollect(s.ctx) return sessionID } @@ -376,7 +370,7 @@ func (s *state) AddAgent(agent *rpc.AgentInfo, now time.Time) string { s.mu.Lock() defer s.mu.Unlock() - sessionID := "agent:" + uuid.New().String() + sessionID := agentSessionIDPrefix + uuid.New().String() if oldAgent, hasConflict := s.agents.LoadOrStore(sessionID, agent); hasConflict { panic(fmt.Errorf("duplicate id %q, existing %+v, new %+v", sessionID, oldAgent, agent)) } @@ -624,7 +618,16 @@ func (s *state) Tunnel(ctx context.Context, stream tunnel.Stream) error { return status.Errorf(codes.NotFound, "Session %q not found", sessionID) } - bidiPipe, err := ss.OnConnect(ctx, stream, &s.tunnelCounter) + var scm *SessionConsumptionMetrics + if clientSessionID := ss.AwaitingBidiMapOwnerSessionID(stream); clientSessionID != "" { + s.mu.RLock() + if _, ok := s.sessions[clientSessionID]; ok { + scm = s.sessions[clientSessionID].ConsumptionMetrics() + } + s.mu.RUnlock() + } + + bidiPipe, err := ss.OnConnect(ctx, stream, &s.tunnelCounter, scm) if err != nil { return err } @@ -671,7 +674,12 @@ func (s *state) Tunnel(ctx context.Context, stream tunnel.Stream) error { return err } } else { - endPoint = tunnel.NewDialer(stream, func() {}) + s.mu.RLock() + // When no intercept is running, a new dialer is opened to talk with resources from the traffic manager. + scm = s.sessions[sessionID].ConsumptionMetrics() + s.mu.RUnlock() + + endPoint = tunnel.NewDialer(stream, func() {}, scm.FromClientBytesChan, scm.ToClientBytesChan) endPoint.Start(ctx) } <-endPoint.Done() diff --git a/cmd/traffic/cmd/manager/state/state_test.go b/cmd/traffic/cmd/manager/state/state_test.go index e0b39d6660..69f1d2002c 100644 --- a/cmd/traffic/cmd/manager/state/state_test.go +++ b/cmd/traffic/cmd/manager/state/state_test.go @@ -26,14 +26,13 @@ type suiteState struct { func (s *suiteState) SetupTest() { s.ctx = dlog.NewTestContext(s.T(), false) s.state = &state{ - ctx: s.ctx, - sessions: make(map[string]SessionState), - sessionConsumptionMetrics: make(map[string]*SessionConsumptionMetrics), - agentsByName: make(map[string]map[string]*rpc.AgentInfo), - cfgMapLocks: make(map[string]*sync.Mutex), - interceptStates: make(map[string]*interceptState), - timedLogLevel: log.NewTimedLevel("debug", log.SetLevel), - llSubs: newLoglevelSubscribers(), + ctx: s.ctx, + sessions: make(map[string]SessionState), + agentsByName: make(map[string]map[string]*rpc.AgentInfo), + cfgMapLocks: make(map[string]*sync.Mutex), + interceptStates: make(map[string]*interceptState), + timedLogLevel: log.NewTimedLevel("debug", log.SetLevel), + llSubs: newLoglevelSubscribers(), } } @@ -172,7 +171,6 @@ func (s *suiteState) TestAddClient() { // then assert.Len(s.T(), s.state.sessions, 1) - assert.Len(s.T(), s.state.sessionConsumptionMetrics, 1) } func (s *suiteState) TestRemoveSession() { @@ -180,10 +178,6 @@ func (s *suiteState) TestRemoveSession() { now := time.Now() s.state.sessions["session-1"] = newClientSessionState(s.ctx, now) s.state.sessions["session-2"] = newAgentSessionState(s.ctx, now) - s.state.sessionConsumptionMetrics["session-1"] = &SessionConsumptionMetrics{ - Duration: 42, - LastUpdate: now.Add(-time.Minute), - } // when s.state.RemoveSession(s.ctx, "session-1") @@ -191,7 +185,6 @@ func (s *suiteState) TestRemoveSession() { // then assert.Len(s.T(), s.state.sessions, 0) - assert.Len(s.T(), s.state.sessionConsumptionMetrics, 0) } func TestSuiteState(testing *testing.T) { diff --git a/pkg/client/remotefs/bridge.go b/pkg/client/remotefs/bridge.go index 09c914a844..3db4b54195 100644 --- a/pkg/client/remotefs/bridge.go +++ b/pkg/client/remotefs/bridge.go @@ -73,7 +73,7 @@ func (m *bridgeMounter) dispatchToTunnel(ctx context.Context, conn net.Conn, pod cancel() return fmt.Errorf("failed to create stream: %v", err) } - d := tunnel.NewConnEndpoint(s, conn, cancel) + d := tunnel.NewConnEndpoint(s, conn, cancel, nil, nil) d.Start(ctx) <-d.Done() return nil diff --git a/pkg/client/rootd/stream_creator.go b/pkg/client/rootd/stream_creator.go index 86e4668441..e46b8f9a91 100644 --- a/pkg/client/rootd/stream_creator.go +++ b/pkg/client/rootd/stream_creator.go @@ -24,7 +24,7 @@ func (s *Session) streamCreator() tunnel.StreamCreator { pipeId := tunnel.NewConnID(p, id.Source(), s.dnsLocalAddr.IP, id.SourcePort(), uint16(s.dnsLocalAddr.Port)) dlog.Tracef(c, "Intercept DNS %s to %s", id, pipeId.DestinationAddr()) from, to := tunnel.NewPipe(pipeId, s.session.SessionId) - tunnel.NewDialerTTL(to, func() {}, dnsConnTTL).Start(c) + tunnel.NewDialerTTL(to, func() {}, dnsConnTTL, nil, nil).Start(c) return from, nil } dlog.Debugf(c, "Opening tunnel for id %s", id) diff --git a/pkg/forwarder/tcp.go b/pkg/forwarder/tcp.go index 3c6cd49135..cb7833db1e 100644 --- a/pkg/forwarder/tcp.go +++ b/pkg/forwarder/tcp.go @@ -188,7 +188,7 @@ func (f *interceptor) interceptConn(ctx context.Context, conn net.Conn, iCept *m cancel() return fmt.Errorf("unable to send client session id. Id %s: %v", id, err) } - d := tunnel.NewConnEndpoint(s, conn, cancel) + d := tunnel.NewConnEndpoint(s, conn, cancel, nil, nil) d.Start(ctx) <-d.Done() return nil diff --git a/pkg/tunnel/bidipipe.go b/pkg/tunnel/bidipipe.go index 4d86823216..1d746cca1c 100644 --- a/pkg/tunnel/bidipipe.go +++ b/pkg/tunnel/bidipipe.go @@ -14,16 +14,28 @@ type bidiPipe struct { name string counter *int32 done chan struct{} + + probes *BidiPipeProbes +} + +type BidiPipeProbes struct { + BytesProbeA, BytesProbeB chan uint64 } // NewBidiPipe creates a bidirectional pipe between the two given streams. -func NewBidiPipe(a, b Stream, name string, counter *int32) Endpoint { +func NewBidiPipe(a, b Stream, name string, counter *int32, probes *BidiPipeProbes) Endpoint { + if probes == nil { + probes = &BidiPipeProbes{} + } + return &bidiPipe{ a: a, b: b, name: name, counter: counter, done: make(chan struct{}), + + probes: probes, } } @@ -40,8 +52,9 @@ func (p *bidiPipe) Start(ctx context.Context) { wg.Add(2) dlog.Debugf(ctx, " FWD connect %s", p.name) atomic.AddInt32(p.counter, 1) - go doPipe(ctx, p.a, p.b, &wg) - go doPipe(ctx, p.b, p.a, &wg) + // p.pm collects metrics only for one stream (since the same data is going through both streams) + go p.doPipe(ctx, p.a, p.b, &wg, nil, nil) + go p.doPipe(ctx, p.b, p.a, &wg, nil, nil) wg.Wait() }() } @@ -51,13 +64,16 @@ func (p *bidiPipe) Done() <-chan struct{} { } // doPipe reads from a and writes to b. -func doPipe(ctx context.Context, a, b Stream, wg *sync.WaitGroup) { +func (p *bidiPipe) doPipe( + ctx context.Context, a, b Stream, wg *sync.WaitGroup, + readBytesProbe, writeBytesProbe chan uint64, +) { defer wg.Done() wrCh := make(chan Message, 50) defer close(wrCh) wg.Add(1) - WriteLoop(ctx, b, wrCh, wg) - rdCh, errCh := ReadLoop(ctx, a) + WriteLoop(ctx, b, wrCh, wg, writeBytesProbe) + rdCh, errCh := ReadLoop(ctx, a, readBytesProbe) for { select { case <-ctx.Done(): diff --git a/pkg/tunnel/dialer.go b/pkg/tunnel/dialer.go index 60bc7e7738..e75754ddf9 100644 --- a/pkg/tunnel/dialer.go +++ b/pkg/tunnel/dialer.go @@ -51,12 +51,19 @@ type dialer struct { conn net.Conn connected int32 done chan struct{} + + ingressBytesProbe chan uint64 + egressBytesProbe chan uint64 } // NewDialer creates a new handler that dispatches messages in both directions between the given gRPC stream // and the given connection. -func NewDialer(stream Stream, cancel context.CancelFunc) Endpoint { - return NewConnEndpoint(stream, nil, cancel) +func NewDialer( + stream Stream, + cancel context.CancelFunc, + ingressBytesProbe, egressBytesProbe chan uint64, +) Endpoint { + return NewConnEndpoint(stream, nil, cancel, ingressBytesProbe, egressBytesProbe) } // NewDialerTTL creates a new handler that dispatches messages in both directions between the given gRPC stream @@ -64,19 +71,25 @@ func NewDialer(stream Stream, cancel context.CancelFunc) Endpoint { // // The handler remains active until it's been idle for the ttl duration, at which time it will automatically close // and call the release function it got from the tunnel.Pool to ensure that it gets properly released. -func NewDialerTTL(stream Stream, cancel context.CancelFunc, ttl time.Duration) Endpoint { - return NewConnEndpointTTL(stream, nil, cancel, ttl) +func NewDialerTTL(stream Stream, cancel context.CancelFunc, ttl time.Duration, ingressBytesProbe, egressBytesProbe chan uint64) Endpoint { + return NewConnEndpointTTL(stream, nil, cancel, ttl, ingressBytesProbe, egressBytesProbe) } -func NewConnEndpoint(stream Stream, conn net.Conn, cancel context.CancelFunc) Endpoint { +func NewConnEndpoint(stream Stream, conn net.Conn, cancel context.CancelFunc, ingressBytesProbe, egressBytesProbe chan uint64) Endpoint { ttl := tcpConnTTL if stream.ID().Protocol() == ipproto.UDP { ttl = udpConnTTL } - return NewConnEndpointTTL(stream, conn, cancel, ttl) + return NewConnEndpointTTL(stream, conn, cancel, ttl, ingressBytesProbe, egressBytesProbe) } -func NewConnEndpointTTL(stream Stream, conn net.Conn, cancel context.CancelFunc, ttl time.Duration) Endpoint { +func NewConnEndpointTTL( + stream Stream, + conn net.Conn, + cancel context.CancelFunc, + ttl time.Duration, + ingressBytesProbe, egressBytesProbe chan uint64, +) Endpoint { state := notConnected if conn != nil { state = connecting @@ -88,6 +101,9 @@ func NewConnEndpointTTL(stream Stream, conn net.Conn, cancel context.CancelFunc, conn: conn, connected: state, done: make(chan struct{}), + + ingressBytesProbe: ingressBytesProbe, + egressBytesProbe: egressBytesProbe, } } @@ -189,7 +205,7 @@ func (h *dialer) connToStreamLoop(ctx context.Context, wg *sync.WaitGroup) { }() wg.Add(1) - WriteLoop(ctx, h.stream, outgoing, wg) + WriteLoop(ctx, h.stream, outgoing, wg, h.egressBytesProbe) buf := make([]byte, 0x100000) dlog.Tracef(ctx, " CONN %s conn-to-stream loop started", id) @@ -238,7 +254,7 @@ func (h *dialer) streamToConnLoop(ctx context.Context, wg *sync.WaitGroup) { defer func() { wg.Done() }() - readLoop(ctx, h) + readLoop(ctx, h, h.ingressBytesProbe) } func handleControl(ctx context.Context, h streamReader, cm Message) { @@ -260,7 +276,7 @@ func handleControl(ctx context.Context, h streamReader, cm Message) { } } -func readLoop(ctx context.Context, h streamReader) { +func readLoop(ctx context.Context, h streamReader, trafficProbe chan uint64) { var endReason string endLevel := dlog.LogLevelTrace id := h.getStream().ID() @@ -269,7 +285,7 @@ func readLoop(ctx context.Context, h streamReader) { dlog.Logf(ctx, endLevel, " CONN %s stream-to-conn loop ended because %s", id, endReason) }() - incoming, errCh := ReadLoop(ctx, h.getStream()) + incoming, errCh := ReadLoop(ctx, h.getStream(), trafficProbe) dlog.Tracef(ctx, " CONN %s stream-to-conn loop started", id) for { select { @@ -316,7 +332,12 @@ func readLoop(ctx context.Context, h streamReader) { // DialWaitLoop reads from the given dialStream. A new goroutine that creates a Tunnel to the manager and then // attaches a dialer Endpoint to that tunnel is spawned for each request that arrives. The method blocks until // the dialStream is closed. -func DialWaitLoop(ctx context.Context, manager rpc.ManagerClient, dialStream rpc.Manager_WatchDialClient, sessionID string) error { +func DialWaitLoop( + ctx context.Context, + manager rpc.ManagerClient, + dialStream rpc.Manager_WatchDialClient, + sessionID string, +) error { // create ctx to cleanup leftover dialRespond if waitloop dies ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -355,7 +376,7 @@ func dialRespond(ctx context.Context, manager rpc.ManagerClient, dr *rpc.DialReq cancel() return } - d := NewDialer(s, cancel) + d := NewDialer(s, cancel, nil, nil) d.Start(ctx) <-d.Done() } diff --git a/pkg/tunnel/stream.go b/pkg/tunnel/stream.go index 99fb8525cd..4467b44fcf 100644 --- a/pkg/tunnel/stream.go +++ b/pkg/tunnel/stream.go @@ -74,7 +74,7 @@ type StreamCreator func(context.Context, ConnID) (Stream, error) // ReadLoop reads from the Stream and dispatches messages and error to the give channels. There // will be max one error since the error also terminates the loop. -func ReadLoop(ctx context.Context, s Stream) (<-chan Message, <-chan error) { +func ReadLoop(ctx context.Context, s Stream, b chan uint64) (<-chan Message, <-chan error) { msgCh := make(chan Message, 50) errCh := make(chan error, 1) // Max one message will be sent on this channel dlog.Tracef(ctx, " %s %s, ReadLoop starting", s.Tag(), s.ID()) @@ -91,6 +91,10 @@ func ReadLoop(ctx context.Context, s Stream) (<-chan Message, <-chan error) { for { m, err := s.Receive(ctx) + if m != nil && b != nil { + b <- uint64(len(m.Payload())) + } + switch { case err == nil: select { @@ -122,7 +126,12 @@ func ReadLoop(ctx context.Context, s Stream) (<-chan Message, <-chan error) { // WriteLoop reads messages from the channel and writes them to the Stream. It will call CloseSend() on the // stream when the channel is closed. -func WriteLoop(ctx context.Context, s Stream, msgCh <-chan Message, wg *sync.WaitGroup) { +func WriteLoop( + ctx context.Context, + s Stream, msgCh <-chan Message, + wg *sync.WaitGroup, + b chan uint64, +) { dlog.Tracef(ctx, " %s %s, WriteLoop starting", s.Tag(), s.ID()) go func() { ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "WriteLoop") @@ -145,7 +154,11 @@ func WriteLoop(ctx context.Context, s Stream, msgCh <-chan Message, wg *sync.Wai endReason = "input channel is closed" break } + err := s.Send(ctx, m) + if m != nil && b != nil { + b <- uint64(len(m.Payload())) + } switch { case err == nil: continue diff --git a/pkg/tunnel/stream_test.go b/pkg/tunnel/stream_test.go index 69b2094ab1..c0fdb62704 100644 --- a/pkg/tunnel/stream_test.go +++ b/pkg/tunnel/stream_test.go @@ -140,7 +140,7 @@ func produce(ctx context.Context, s Stream, msg Message, errs chan<- error) { wrCh := make(chan Message) wg := sync.WaitGroup{} wg.Add(1) - WriteLoop(ctx, s, wrCh, &wg) + WriteLoop(ctx, s, wrCh, &wg, nil) go func() { for i := 0; i < 100; i++ { wrCh <- msg @@ -149,7 +149,7 @@ func produce(ctx context.Context, s Stream, msg Message, errs chan<- error) { wg.Wait() }() - rdCh, errCh := ReadLoop(ctx, s) + rdCh, errCh := ReadLoop(ctx, s, nil) select { case <-ctx.Done(): errs <- ctx.Err() @@ -169,9 +169,9 @@ func consume(ctx context.Context, s Stream, expectedPayload []byte, errs chan<- wrCh := make(chan Message) wg := sync.WaitGroup{} wg.Add(1) - WriteLoop(ctx, s, wrCh, &wg) + WriteLoop(ctx, s, wrCh, &wg, nil) defer close(wrCh) - rdCh, errCh := ReadLoop(ctx, s) + rdCh, errCh := ReadLoop(ctx, s, nil) for { select { case <-ctx.Done(): @@ -326,7 +326,7 @@ func TestStream_Xfer(t *testing.T) { case b = <-bCh: } } - fwd := NewBidiPipe(a, b, "pipe", &counter) + fwd := NewBidiPipe(a, b, "pipe", &counter, nil) fwd.Start(ctx) select { case <-ctx.Done(): diff --git a/pkg/tunnel/udplistener.go b/pkg/tunnel/udplistener.go index 39af6b072a..42223f781c 100644 --- a/pkg/tunnel/udplistener.go +++ b/pkg/tunnel/udplistener.go @@ -119,7 +119,7 @@ func (p *udpStream) Stop(ctx context.Context) { func (p *udpStream) Start(ctx context.Context) { p.TimedHandler.Start(ctx) - go readLoop(ctx, p) + go readLoop(ctx, p, nil) } type UdpReadResult struct { diff --git a/pkg/vif/stack.go b/pkg/vif/stack.go index 2bcffa82f3..b011347214 100644 --- a/pkg/vif/stack.go +++ b/pkg/vif/stack.go @@ -252,6 +252,6 @@ func dispatchToStream(ctx context.Context, id tunnel.ConnID, conn net.Conn, stre cancel() return } - ep := tunnel.NewConnEndpoint(stream, conn, cancel) + ep := tunnel.NewConnEndpoint(stream, conn, cancel, nil, nil) ep.Start(ctx) } From 5a6c1f076fd3d739a7625a0b2e04d009a73a5c61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Lambert?= Date: Fri, 28 Jul 2023 08:10:57 -0400 Subject: [PATCH 05/11] Update docstring & cleanup TODO MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kévin Lambert --- cmd/traffic/cmd/manager/state/consumption.go | 6 +++--- cmd/traffic/cmd/manager/state/state.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cmd/traffic/cmd/manager/state/consumption.go b/cmd/traffic/cmd/manager/state/consumption.go index 2b2589b4fd..33fd343d34 100644 --- a/cmd/traffic/cmd/manager/state/consumption.go +++ b/cmd/traffic/cmd/manager/state/consumption.go @@ -7,7 +7,7 @@ import ( // SessionConsumptionMetricsStaleTTL is the duration after which we consider the metrics to be staled, meaning // that they should not be updated anymore since the user doesn't really use Telepresence at the moment. -const SessionConsumptionMetricsStaleTTL = 1 * time.Minute // TODO: Increase. +const SessionConsumptionMetricsStaleTTL = 15 * time.Minute func NewSessionConsumptionMetrics() *SessionConsumptionMetrics { return &SessionConsumptionMetrics{ @@ -91,8 +91,8 @@ func (s *state) RefreshSessionConsumptionMetrics(sessionID string) { lastMarked := session.LastMarked() consumption := s.sessions[sessionID].ConsumptionMetrics() - // if last mark is more than SessionConsumptionMetricsStaleTTL old, it means the duration metric should stop being - // updated since the user machine is maybe in standby. + // If the last mark is older than the SessionConsumptionMetricsStaleTTL, it indicates that the duration + // metric should no longer be updated, as the user's machine may be in standby. isStale := time.Now().After(lastMarked.Add(SessionConsumptionMetricsStaleTTL)) if !isStale { consumption.ConnectDuration += uint32(time.Since(consumption.LastUpdate).Seconds()) diff --git a/cmd/traffic/cmd/manager/state/state.go b/cmd/traffic/cmd/manager/state/state.go index 10c9ab5bd4..4e2b503456 100644 --- a/cmd/traffic/cmd/manager/state/state.go +++ b/cmd/traffic/cmd/manager/state/state.go @@ -675,7 +675,7 @@ func (s *state) Tunnel(ctx context.Context, stream tunnel.Stream) error { } } else { s.mu.RLock() - // When no intercept is running, a new dialer is opened to talk with resources from the traffic manager. + // When no intercept is active, a new dialer is opened to communicate with resources from the traffic manager. scm = s.sessions[sessionID].ConsumptionMetrics() s.mu.RUnlock() From d33fe35836ab7c6b3ca6a8448d9e3e711058a909 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Lambert?= Date: Fri, 28 Jul 2023 08:22:18 -0400 Subject: [PATCH 06/11] Make clientBytes metrics public MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kévin Lambert --- cmd/traffic/cmd/manager/state/consumption.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cmd/traffic/cmd/manager/state/consumption.go b/cmd/traffic/cmd/manager/state/consumption.go index 33fd343d34..3948bfd0ae 100644 --- a/cmd/traffic/cmd/manager/state/consumption.go +++ b/cmd/traffic/cmd/manager/state/consumption.go @@ -24,9 +24,9 @@ type SessionConsumptionMetrics struct { LastUpdate time.Time // data from client to the traffic manager. - fromClientBytes uint64 + FromClientBytes uint64 // data from the traffic manager to the client. - toClientBytes uint64 + ToClientBytes uint64 FromClientBytesChan chan uint64 ToClientBytesChan chan uint64 @@ -42,12 +42,12 @@ func (sc *SessionConsumptionMetrics) RunCollect(ctx context.Context) { if !ok { return } - sc.fromClientBytes += b + sc.FromClientBytes += b case b, ok := <-sc.ToClientBytesChan: if !ok { return } - sc.toClientBytes += b + sc.ToClientBytes += b } } } From 3080f7c5055687a9f4836af30ce7069aff76116d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Lambert?= Date: Fri, 28 Jul 2023 09:15:26 -0400 Subject: [PATCH 07/11] Update rpc go.mod MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kévin Lambert --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index e7de7082d3..c1cfdaeb5a 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,7 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.2 - github.com/telepresenceio/telepresence/rpc/v2 v2.14.3-0.20230726203957-2d275fd44a77 + github.com/telepresenceio/telepresence/rpc/v2 v2.14.3-0.20230728122223-d33fe35836ab github.com/vishvananda/netlink v1.2.1-beta.2 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.41.1 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.41.1 From 606635d2b07c924d8f88271f441428d629df793d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Lambert?= Date: Fri, 28 Jul 2023 10:17:08 -0400 Subject: [PATCH 08/11] Fix dependencies MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kévin Lambert --- pkg/vif/testdata/router/go.mod | 2 +- tools/src/test-report/go.mod | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/vif/testdata/router/go.mod b/pkg/vif/testdata/router/go.mod index f3f4f4f91b..ec9b842a74 100644 --- a/pkg/vif/testdata/router/go.mod +++ b/pkg/vif/testdata/router/go.mod @@ -47,7 +47,7 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/spf13/cobra v1.7.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/telepresenceio/telepresence/rpc/v2 v2.14.3-0.20230726203957-2d275fd44a77 // indirect + github.com/telepresenceio/telepresence/rpc/v2 v2.14.3-0.20230728122223-d33fe35836ab // indirect github.com/vishvananda/netlink v1.2.1-beta.2 // indirect github.com/vishvananda/netns v0.0.4 // indirect github.com/xlab/treeprint v1.2.0 // indirect diff --git a/tools/src/test-report/go.mod b/tools/src/test-report/go.mod index 278037f695..ba679c388a 100644 --- a/tools/src/test-report/go.mod +++ b/tools/src/test-report/go.mod @@ -59,7 +59,7 @@ require ( github.com/sirupsen/logrus v1.9.2 // indirect github.com/spf13/cobra v1.7.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/telepresenceio/telepresence/rpc/v2 v2.14.3-0.20230726203957-2d275fd44a77 // indirect + github.com/telepresenceio/telepresence/rpc/v2 v2.14.3-0.20230728122223-d33fe35836ab // indirect github.com/xlab/treeprint v1.2.0 // indirect go.starlark.net v0.0.0-20230302034142-4b1e35fe2254 // indirect golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc // indirect From 7986f563c3b88ff1f3f54011934d466d8f7d5d35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Lambert?= Date: Fri, 28 Jul 2023 14:17:17 -0400 Subject: [PATCH 09/11] Fix metrics not processed when request comes from cluster + review feedbacks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kévin Lambert --- build-aux/main.mk | 3 +- cmd/traffic/cmd/manager/state/consumption.go | 42 ++++------- cmd/traffic/cmd/manager/state/session.go | 6 +- cmd/traffic/cmd/manager/state/state.go | 27 +++++--- pkg/tunnel/bidipipe.go | 10 +-- pkg/tunnel/dialer.go | 14 ++-- pkg/tunnel/probe.go | 73 ++++++++++++++++++++ pkg/tunnel/stream.go | 19 +++-- 8 files changed, 133 insertions(+), 61 deletions(-) create mode 100644 pkg/tunnel/probe.go diff --git a/build-aux/main.mk b/build-aux/main.mk index 3d862ec01c..375db3789a 100644 --- a/build-aux/main.mk +++ b/build-aux/main.mk @@ -178,7 +178,8 @@ release-binary: $(TELEPRESENCE) tel2-image: build-deps mkdir -p $(BUILDDIR) printf $(TELEPRESENCE_VERSION) > $(BUILDDIR)/version.txt ## Pass version in a file instead of a --build-arg to maximize cache usage - docker build --target tel2 --tag tel2 --tag $(TELEPRESENCE_REGISTRY)/tel2:$(patsubst v%,%,$(TELEPRESENCE_VERSION)) -f build-aux/docker/images/Dockerfile.traffic . + $(eval PLATFORM_ARG := $(if $(TELEPRESENCE_TEL2_IMAGE_PLATFORM), --platform=$(TELEPRESENCE_TEL2_IMAGE_PLATFORM),)) + docker build $(PLATFORM_ARG) --target tel2 --tag tel2 --tag $(TELEPRESENCE_REGISTRY)/tel2:$(patsubst v%,%,$(TELEPRESENCE_VERSION)) -f build-aux/docker/images/Dockerfile.traffic . .PHONY: client-image client-image: build-deps diff --git a/cmd/traffic/cmd/manager/state/consumption.go b/cmd/traffic/cmd/manager/state/consumption.go index 3948bfd0ae..582babfff0 100644 --- a/cmd/traffic/cmd/manager/state/consumption.go +++ b/cmd/traffic/cmd/manager/state/consumption.go @@ -3,6 +3,8 @@ package state import ( "context" "time" + + "github.com/telepresenceio/telepresence/v2/pkg/tunnel" ) // SessionConsumptionMetricsStaleTTL is the duration after which we consider the metrics to be staled, meaning @@ -11,9 +13,9 @@ const SessionConsumptionMetricsStaleTTL = 15 * time.Minute func NewSessionConsumptionMetrics() *SessionConsumptionMetrics { return &SessionConsumptionMetrics{ - ConnectDuration: 0, - FromClientBytesChan: make(chan uint64), - ToClientBytesChan: make(chan uint64), + ConnectDuration: 0, + FromClientBytes: tunnel.NewCounterProbe("FromClientBytes"), + ToClientBytes: tunnel.NewCounterProbe("ToClientBytes"), LastUpdate: time.Now(), } @@ -24,37 +26,19 @@ type SessionConsumptionMetrics struct { LastUpdate time.Time // data from client to the traffic manager. - FromClientBytes uint64 + FromClientBytes *tunnel.CounterProbe // data from the traffic manager to the client. - ToClientBytes uint64 - - FromClientBytesChan chan uint64 - ToClientBytesChan chan uint64 + ToClientBytes *tunnel.CounterProbe } -func (sc *SessionConsumptionMetrics) RunCollect(ctx context.Context) { - for { - select { - case <-ctx.Done(): - sc.closeChannels() - return - case b, ok := <-sc.FromClientBytesChan: - if !ok { - return - } - sc.FromClientBytes += b - case b, ok := <-sc.ToClientBytesChan: - if !ok { - return - } - sc.ToClientBytes += b - } - } +func (s *SessionConsumptionMetrics) RunCollect(ctx context.Context) { + go s.FromClientBytes.RunCollect(ctx) + go s.ToClientBytes.RunCollect(ctx) } -func (sc *SessionConsumptionMetrics) closeChannels() { - close(sc.FromClientBytesChan) - close(sc.ToClientBytesChan) +func (s *SessionConsumptionMetrics) Close() { + s.FromClientBytes.Close() + s.ToClientBytes.Close() } func (s *state) GetSessionConsumptionMetrics(sessionID string) *SessionConsumptionMetrics { diff --git a/cmd/traffic/cmd/manager/state/session.go b/cmd/traffic/cmd/manager/state/session.go index f0a95da23b..c81105ef88 100644 --- a/cmd/traffic/cmd/manager/state/session.go +++ b/cmd/traffic/cmd/manager/state/session.go @@ -15,7 +15,7 @@ import ( "github.com/telepresenceio/telepresence/v2/pkg/tunnel" ) -const agentSessionIDPrefix = "agent:" +const AgentSessionIDPrefix = "agent:" type SessionState interface { Cancel() @@ -128,8 +128,8 @@ func (ss *sessionState) OnConnect( name := fmt.Sprintf("%s: session %s -> %s", id, abp.stream.SessionID(), stream.SessionID()) tunnelProbes := &tunnel.BidiPipeProbes{} if consumptionMetrics != nil { - tunnelProbes.BytesProbeA = consumptionMetrics.FromClientBytesChan - tunnelProbes.BytesProbeB = consumptionMetrics.ToClientBytesChan + tunnelProbes.BytesProbeA = consumptionMetrics.FromClientBytes + tunnelProbes.BytesProbeB = consumptionMetrics.ToClientBytes } bidiPipe := tunnel.NewBidiPipe(abp.stream, stream, name, counter, tunnelProbes) diff --git a/cmd/traffic/cmd/manager/state/state.go b/cmd/traffic/cmd/manager/state/state.go index 4e2b503456..fa3583fba4 100644 --- a/cmd/traffic/cmd/manager/state/state.go +++ b/cmd/traffic/cmd/manager/state/state.go @@ -267,6 +267,7 @@ func (s *state) unlockedRemoveSession(sessionID string) { s.clients.Delete(sessionID) } + defer sess.ConsumptionMetrics().Close() delete(s.sessions, sessionID) } } @@ -327,7 +328,7 @@ func (s *state) addClient(sessionID string, client *rpc.ClientInfo, now time.Tim s.sessions[sessionID] = newClientSessionState(s.ctx, now) - go s.sessions[sessionID].ConsumptionMetrics().RunCollect(s.ctx) + s.sessions[sessionID].ConsumptionMetrics().RunCollect(s.ctx) return sessionID } @@ -370,7 +371,7 @@ func (s *state) AddAgent(agent *rpc.AgentInfo, now time.Time) string { s.mu.Lock() defer s.mu.Unlock() - sessionID := agentSessionIDPrefix + uuid.New().String() + sessionID := AgentSessionIDPrefix + uuid.New().String() if oldAgent, hasConflict := s.agents.LoadOrStore(sessionID, agent); hasConflict { panic(fmt.Errorf("duplicate id %q, existing %+v, new %+v", sessionID, oldAgent, agent)) } @@ -619,12 +620,20 @@ func (s *state) Tunnel(ctx context.Context, stream tunnel.Stream) error { } var scm *SessionConsumptionMetrics - if clientSessionID := ss.AwaitingBidiMapOwnerSessionID(stream); clientSessionID != "" { + + _, isAgent := ss.(*agentSessionState) + clientSessionID := ss.AwaitingBidiMapOwnerSessionID(stream) + // If there is a bidipipe owner (a client) waiting for an agent, use the metrics from the first one. + if isAgent && clientSessionID != "" { s.mu.RLock() - if _, ok := s.sessions[clientSessionID]; ok { - scm = s.sessions[clientSessionID].ConsumptionMetrics() - } + css, ok := s.sessions[clientSessionID] s.mu.RUnlock() + if ok { + scm = css.ConsumptionMetrics() + } + } else { + // otherwise, by default, use the session consumption metrics. + scm = ss.ConsumptionMetrics() } bidiPipe, err := ss.OnConnect(ctx, stream, &s.tunnelCounter, scm) @@ -642,7 +651,7 @@ func (s *state) Tunnel(ctx context.Context, stream tunnel.Stream) error { // The session is either the telepresence client or a traffic-agent. // // A client will want to extend the tunnel to a dialer in an intercepted traffic-agent or, if no - // intercept is active, to a dialer here in the traffic-agent. + // intercept is active, to a dialer here in the traffic-manager. // // A traffic-agent must always extend the tunnel to the client that it is currently intercepted // by, and hence, start by sending the sessionID of that client on the tunnel. @@ -675,11 +684,9 @@ func (s *state) Tunnel(ctx context.Context, stream tunnel.Stream) error { } } else { s.mu.RLock() - // When no intercept is active, a new dialer is opened to communicate with resources from the traffic manager. scm = s.sessions[sessionID].ConsumptionMetrics() s.mu.RUnlock() - - endPoint = tunnel.NewDialer(stream, func() {}, scm.FromClientBytesChan, scm.ToClientBytesChan) + endPoint = tunnel.NewDialer(stream, func() {}, scm.FromClientBytes, scm.ToClientBytes) endPoint.Start(ctx) } <-endPoint.Done() diff --git a/pkg/tunnel/bidipipe.go b/pkg/tunnel/bidipipe.go index 1d746cca1c..ab7ee07926 100644 --- a/pkg/tunnel/bidipipe.go +++ b/pkg/tunnel/bidipipe.go @@ -19,7 +19,7 @@ type bidiPipe struct { } type BidiPipeProbes struct { - BytesProbeA, BytesProbeB chan uint64 + BytesProbeA, BytesProbeB *CounterProbe } // NewBidiPipe creates a bidirectional pipe between the two given streams. @@ -52,9 +52,9 @@ func (p *bidiPipe) Start(ctx context.Context) { wg.Add(2) dlog.Debugf(ctx, " FWD connect %s", p.name) atomic.AddInt32(p.counter, 1) - // p.pm collects metrics only for one stream (since the same data is going through both streams) - go p.doPipe(ctx, p.a, p.b, &wg, nil, nil) - go p.doPipe(ctx, p.b, p.a, &wg, nil, nil) + // Only one probe per bidipipe since each one represents a direction. + go p.doPipe(ctx, p.a, p.b, &wg, p.probes.BytesProbeA, nil) + go p.doPipe(ctx, p.b, p.a, &wg, p.probes.BytesProbeB, nil) wg.Wait() }() } @@ -66,7 +66,7 @@ func (p *bidiPipe) Done() <-chan struct{} { // doPipe reads from a and writes to b. func (p *bidiPipe) doPipe( ctx context.Context, a, b Stream, wg *sync.WaitGroup, - readBytesProbe, writeBytesProbe chan uint64, + readBytesProbe, writeBytesProbe *CounterProbe, ) { defer wg.Done() wrCh := make(chan Message, 50) diff --git a/pkg/tunnel/dialer.go b/pkg/tunnel/dialer.go index e75754ddf9..34633e4221 100644 --- a/pkg/tunnel/dialer.go +++ b/pkg/tunnel/dialer.go @@ -52,8 +52,8 @@ type dialer struct { connected int32 done chan struct{} - ingressBytesProbe chan uint64 - egressBytesProbe chan uint64 + ingressBytesProbe *CounterProbe + egressBytesProbe *CounterProbe } // NewDialer creates a new handler that dispatches messages in both directions between the given gRPC stream @@ -61,7 +61,7 @@ type dialer struct { func NewDialer( stream Stream, cancel context.CancelFunc, - ingressBytesProbe, egressBytesProbe chan uint64, + ingressBytesProbe, egressBytesProbe *CounterProbe, ) Endpoint { return NewConnEndpoint(stream, nil, cancel, ingressBytesProbe, egressBytesProbe) } @@ -71,11 +71,11 @@ func NewDialer( // // The handler remains active until it's been idle for the ttl duration, at which time it will automatically close // and call the release function it got from the tunnel.Pool to ensure that it gets properly released. -func NewDialerTTL(stream Stream, cancel context.CancelFunc, ttl time.Duration, ingressBytesProbe, egressBytesProbe chan uint64) Endpoint { +func NewDialerTTL(stream Stream, cancel context.CancelFunc, ttl time.Duration, ingressBytesProbe, egressBytesProbe *CounterProbe) Endpoint { return NewConnEndpointTTL(stream, nil, cancel, ttl, ingressBytesProbe, egressBytesProbe) } -func NewConnEndpoint(stream Stream, conn net.Conn, cancel context.CancelFunc, ingressBytesProbe, egressBytesProbe chan uint64) Endpoint { +func NewConnEndpoint(stream Stream, conn net.Conn, cancel context.CancelFunc, ingressBytesProbe, egressBytesProbe *CounterProbe) Endpoint { ttl := tcpConnTTL if stream.ID().Protocol() == ipproto.UDP { ttl = udpConnTTL @@ -88,7 +88,7 @@ func NewConnEndpointTTL( conn net.Conn, cancel context.CancelFunc, ttl time.Duration, - ingressBytesProbe, egressBytesProbe chan uint64, + ingressBytesProbe, egressBytesProbe *CounterProbe, ) Endpoint { state := notConnected if conn != nil { @@ -276,7 +276,7 @@ func handleControl(ctx context.Context, h streamReader, cm Message) { } } -func readLoop(ctx context.Context, h streamReader, trafficProbe chan uint64) { +func readLoop(ctx context.Context, h streamReader, trafficProbe *CounterProbe) { var endReason string endLevel := dlog.LogLevelTrace id := h.getStream().ID() diff --git a/pkg/tunnel/probe.go b/pkg/tunnel/probe.go new file mode 100644 index 0000000000..f16a631635 --- /dev/null +++ b/pkg/tunnel/probe.go @@ -0,0 +1,73 @@ +package tunnel + +import ( + "context" + "fmt" + "sync" + "time" +) + +type CounterProbe struct { + lock sync.Mutex + + name string + channel chan uint64 + timeout time.Duration + value uint64 +} + +const ( + probeChannelTimeout = 100 * time.Millisecond + probeChannelBufferSize = 1024 +) + +func NewCounterProbe(name string) *CounterProbe { + return &CounterProbe{ + lock: sync.Mutex{}, + name: name, + channel: make(chan uint64, probeChannelBufferSize), + timeout: probeChannelTimeout, + } +} + +func (p *CounterProbe) Increment(v uint64) error { + select { + case p.channel <- v: + case <-time.After(p.timeout): + return fmt.Errorf("timeout trying to increment probe channel") + } + return nil +} + +func (p *CounterProbe) RunCollect(ctx context.Context) { + defer p.Close() + for { + select { + case <-ctx.Done(): + return + case b, ok := <-p.channel: + if !ok { + p.Close() + return + } + p.value += b + } + } +} + +func (p *CounterProbe) Close() { + p.lock.Lock() + defer p.lock.Unlock() + if p.channel != nil { + close(p.channel) + p.channel = nil + } +} + +func (p *CounterProbe) GetName() string { + return p.name +} + +func (p *CounterProbe) GetValue() uint64 { + return p.value +} diff --git a/pkg/tunnel/stream.go b/pkg/tunnel/stream.go index 4467b44fcf..e461226ea5 100644 --- a/pkg/tunnel/stream.go +++ b/pkg/tunnel/stream.go @@ -74,7 +74,7 @@ type StreamCreator func(context.Context, ConnID) (Stream, error) // ReadLoop reads from the Stream and dispatches messages and error to the give channels. There // will be max one error since the error also terminates the loop. -func ReadLoop(ctx context.Context, s Stream, b chan uint64) (<-chan Message, <-chan error) { +func ReadLoop(ctx context.Context, s Stream, p *CounterProbe) (<-chan Message, <-chan error) { msgCh := make(chan Message, 50) errCh := make(chan error, 1) // Max one message will be sent on this channel dlog.Tracef(ctx, " %s %s, ReadLoop starting", s.Tag(), s.ID()) @@ -91,8 +91,11 @@ func ReadLoop(ctx context.Context, s Stream, b chan uint64) (<-chan Message, <-c for { m, err := s.Receive(ctx) - if m != nil && b != nil { - b <- uint64(len(m.Payload())) + if m != nil && p != nil { + errInc := p.Increment(uint64(len(m.Payload()))) + if errInc != nil { + dlog.Error(ctx, errInc) + } } switch { @@ -130,7 +133,7 @@ func WriteLoop( ctx context.Context, s Stream, msgCh <-chan Message, wg *sync.WaitGroup, - b chan uint64, + p *CounterProbe, ) { dlog.Tracef(ctx, " %s %s, WriteLoop starting", s.Tag(), s.ID()) go func() { @@ -156,9 +159,13 @@ func WriteLoop( } err := s.Send(ctx, m) - if m != nil && b != nil { - b <- uint64(len(m.Payload())) + if m != nil && p != nil { + errInc := p.Increment(uint64(len(m.Payload()))) + if errInc != nil { + dlog.Error(ctx, errInc) + } } + switch { case err == nil: continue From 08f46ebea591e36d9a4096143d76070a5a1ae279 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Lambert?= Date: Fri, 28 Jul 2023 15:53:32 -0400 Subject: [PATCH 10/11] Attach consumption metrics to the client session state only MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kévin Lambert --- cmd/traffic/cmd/manager/state/consumption.go | 21 ++++++--- .../cmd/manager/state/consumption_test.go | 9 ++-- cmd/traffic/cmd/manager/state/session.go | 23 +++++----- cmd/traffic/cmd/manager/state/state.go | 43 +++++++++++-------- 4 files changed, 57 insertions(+), 39 deletions(-) diff --git a/cmd/traffic/cmd/manager/state/consumption.go b/cmd/traffic/cmd/manager/state/consumption.go index 582babfff0..1718d62f6b 100644 --- a/cmd/traffic/cmd/manager/state/consumption.go +++ b/cmd/traffic/cmd/manager/state/consumption.go @@ -9,7 +9,7 @@ import ( // SessionConsumptionMetricsStaleTTL is the duration after which we consider the metrics to be staled, meaning // that they should not be updated anymore since the user doesn't really use Telepresence at the moment. -const SessionConsumptionMetricsStaleTTL = 15 * time.Minute +const SessionConsumptionMetricsStaleTTL = 60 * time.Minute func NewSessionConsumptionMetrics() *SessionConsumptionMetrics { return &SessionConsumptionMetrics{ @@ -45,8 +45,8 @@ func (s *state) GetSessionConsumptionMetrics(sessionID string) *SessionConsumpti s.mu.RLock() defer s.mu.RUnlock() for i := range s.sessions { - if i == sessionID { - return s.sessions[i].ConsumptionMetrics() + if css, ok := s.sessions[i].(*clientSessionState); i == sessionID && ok { + return css.ConsumptionMetrics() } } return nil @@ -57,7 +57,9 @@ func (s *state) GetAllSessionConsumptionMetrics() map[string]*SessionConsumption s.mu.RLock() defer s.mu.RUnlock() for sessionID := range s.sessions { - allSCM[sessionID] = s.sessions[sessionID].ConsumptionMetrics() + if css, ok := s.sessions[sessionID].(*clientSessionState); ok { + allSCM[sessionID] = css.ConsumptionMetrics() + } } return allSCM } @@ -73,14 +75,19 @@ func (s *state) RefreshSessionConsumptionMetrics(sessionID string) { } lastMarked := session.LastMarked() - consumption := s.sessions[sessionID].ConsumptionMetrics() + var scm *SessionConsumptionMetrics + if css, ok := s.sessions[sessionID].(*clientSessionState); ok { + scm = css.ConsumptionMetrics() + } else { + return + } // If the last mark is older than the SessionConsumptionMetricsStaleTTL, it indicates that the duration // metric should no longer be updated, as the user's machine may be in standby. isStale := time.Now().After(lastMarked.Add(SessionConsumptionMetricsStaleTTL)) if !isStale { - consumption.ConnectDuration += uint32(time.Since(consumption.LastUpdate).Seconds()) + scm.ConnectDuration += uint32(time.Since(scm.LastUpdate).Seconds()) } - consumption.LastUpdate = time.Now() + scm.LastUpdate = time.Now() } diff --git a/cmd/traffic/cmd/manager/state/consumption_test.go b/cmd/traffic/cmd/manager/state/consumption_test.go index 72819f3e3e..56128db48e 100644 --- a/cmd/traffic/cmd/manager/state/consumption_test.go +++ b/cmd/traffic/cmd/manager/state/consumption_test.go @@ -32,7 +32,10 @@ func (s *suiteState) TestRefreshSessionConsumptionMetrics() { s.state.RefreshSessionConsumptionMetrics("session-3") // should not refresh a stale metric. // then - assert.Len(s.T(), s.state.GetAllSessionConsumptionMetrics(), 3) - assert.True(s.T(), (s.state.sessions["session-1"].ConsumptionMetrics().ConnectDuration) > 42) - assert.Equal(s.T(), 36, int(s.state.sessions["session-3"].ConsumptionMetrics().ConnectDuration)) + ccs1 := s.state.sessions["session-1"].(*clientSessionState) + ccs3 := s.state.sessions["session-3"].(*clientSessionState) + + assert.Len(s.T(), s.state.GetAllSessionConsumptionMetrics(), 2) + assert.True(s.T(), (ccs1.ConsumptionMetrics().ConnectDuration) > 42) + assert.Equal(s.T(), 36, int(ccs3.ConsumptionMetrics().ConnectDuration)) } diff --git a/cmd/traffic/cmd/manager/state/session.go b/cmd/traffic/cmd/manager/state/session.go index c81105ef88..d246e4cd5c 100644 --- a/cmd/traffic/cmd/manager/state/session.go +++ b/cmd/traffic/cmd/manager/state/session.go @@ -20,7 +20,6 @@ const AgentSessionIDPrefix = "agent:" type SessionState interface { Cancel() AwaitingBidiMapOwnerSessionID(stream tunnel.Stream) string - ConsumptionMetrics() *SessionConsumptionMetrics Done() <-chan struct{} LastMarked() time.Time SetLastMarked(lastMarked time.Time) @@ -42,11 +41,6 @@ type sessionState struct { lastMarked time.Time awaitingBidiPipeMap map[tunnel.ConnID]*awaitingBidiPipe dials chan *rpc.DialRequest - consumptionMetrics *SessionConsumptionMetrics -} - -func (ss *sessionState) ConsumptionMetrics() *SessionConsumptionMetrics { - return ss.consumptionMetrics } // EstablishBidiPipe registers the given stream as waiting for a matching stream to arrive in a call @@ -168,23 +162,30 @@ func (ss *sessionState) SetLastMarked(lastMarked time.Time) { func newSessionState(ctx context.Context, now time.Time) sessionState { ctx, cancel := context.WithCancel(ctx) return sessionState{ - doneCh: ctx.Done(), - cancel: cancel, - lastMarked: now, - dials: make(chan *rpc.DialRequest), - consumptionMetrics: NewSessionConsumptionMetrics(), + doneCh: ctx.Done(), + cancel: cancel, + lastMarked: now, + dials: make(chan *rpc.DialRequest), } } type clientSessionState struct { sessionState pool *tunnel.Pool + + consumptionMetrics *SessionConsumptionMetrics +} + +func (css *clientSessionState) ConsumptionMetrics() *SessionConsumptionMetrics { + return css.consumptionMetrics } func newClientSessionState(ctx context.Context, ts time.Time) *clientSessionState { return &clientSessionState{ sessionState: newSessionState(ctx, ts), pool: tunnel.NewPool(), + + consumptionMetrics: NewSessionConsumptionMetrics(), } } diff --git a/cmd/traffic/cmd/manager/state/state.go b/cmd/traffic/cmd/manager/state/state.go index fa3583fba4..9ff688364d 100644 --- a/cmd/traffic/cmd/manager/state/state.go +++ b/cmd/traffic/cmd/manager/state/state.go @@ -267,7 +267,10 @@ func (s *state) unlockedRemoveSession(sessionID string) { s.clients.Delete(sessionID) } - defer sess.ConsumptionMetrics().Close() + if css, ok := sess.(*clientSessionState); ok { + defer css.ConsumptionMetrics().Close() + } + delete(s.sessions, sessionID) } } @@ -328,7 +331,9 @@ func (s *state) addClient(sessionID string, client *rpc.ClientInfo, now time.Tim s.sessions[sessionID] = newClientSessionState(s.ctx, now) - s.sessions[sessionID].ConsumptionMetrics().RunCollect(s.ctx) + if css, ok := s.sessions[sessionID].(*clientSessionState); ok { + css.ConsumptionMetrics().RunCollect(s.ctx) + } return sessionID } @@ -620,20 +625,22 @@ func (s *state) Tunnel(ctx context.Context, stream tunnel.Stream) error { } var scm *SessionConsumptionMetrics - - _, isAgent := ss.(*agentSessionState) - clientSessionID := ss.AwaitingBidiMapOwnerSessionID(stream) - // If there is a bidipipe owner (a client) waiting for an agent, use the metrics from the first one. - if isAgent && clientSessionID != "" { - s.mu.RLock() - css, ok := s.sessions[clientSessionID] - s.mu.RUnlock() - if ok { - scm = css.ConsumptionMetrics() + switch sst := ss.(type) { + case *agentSessionState: + // If it's an agent, find the associated clientSessionState. + if clientSessionID := sst.AwaitingBidiMapOwnerSessionID(stream); clientSessionID != "" { + s.mu.RLock() + as := s.sessions[clientSessionID] // get awaiting state + s.mu.RUnlock() + if as != nil { // if found + if css, isClient := as.(*clientSessionState); isClient { + scm = css.ConsumptionMetrics() + } + } } - } else { - // otherwise, by default, use the session consumption metrics. - scm = ss.ConsumptionMetrics() + case *clientSessionState: + scm = sst.ConsumptionMetrics() + default: } bidiPipe, err := ss.OnConnect(ctx, stream, &s.tunnelCounter, scm) @@ -683,9 +690,9 @@ func (s *state) Tunnel(ctx context.Context, stream tunnel.Stream) error { return err } } else { - s.mu.RLock() - scm = s.sessions[sessionID].ConsumptionMetrics() - s.mu.RUnlock() + if css, isClient := ss.(*clientSessionState); isClient { + scm = css.ConsumptionMetrics() + } endPoint = tunnel.NewDialer(stream, func() {}, scm.FromClientBytes, scm.ToClientBytes) endPoint.Start(ctx) } From de55ed4d82770cc8daca09d2f8cd323e60a9dfa4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?K=C3=A9vin=20Lambert?= Date: Mon, 31 Jul 2023 15:28:22 -0400 Subject: [PATCH 11/11] Make probe value atomic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kévin Lambert --- pkg/tunnel/probe.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pkg/tunnel/probe.go b/pkg/tunnel/probe.go index f16a631635..7ec78d4143 100644 --- a/pkg/tunnel/probe.go +++ b/pkg/tunnel/probe.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sync" + "sync/atomic" "time" ) @@ -13,7 +14,7 @@ type CounterProbe struct { name string channel chan uint64 timeout time.Duration - value uint64 + value atomic.Uint64 } const ( @@ -50,7 +51,7 @@ func (p *CounterProbe) RunCollect(ctx context.Context) { p.Close() return } - p.value += b + p.value.Add(b) } } } @@ -69,5 +70,5 @@ func (p *CounterProbe) GetName() string { } func (p *CounterProbe) GetValue() uint64 { - return p.value + return p.value.Load() }