From 8cecfbfdadb5acea06ab5e2ef275272987f1b9c0 Mon Sep 17 00:00:00 2001 From: James Greenhill Date: Wed, 14 Aug 2024 14:28:07 -0700 Subject: [PATCH] feat: add universal 60 sec event counter to livestream (#24380) --- livestream/live_stats.go | 27 +++++++++--- livestream/main.go | 42 +++--------------- livestream/main_test.go | 10 ++--- livestream/served.go | 62 +++++++++++++++++++++++++++ livestream/ttl_counter.go | 48 +++++++++++++++++++++ livestream/ttl_counter_test.go | 78 ++++++++++++++++++++++++++++++++++ 6 files changed, 220 insertions(+), 47 deletions(-) create mode 100644 livestream/served.go create mode 100644 livestream/ttl_counter.go create mode 100644 livestream/ttl_counter_test.go diff --git a/livestream/live_stats.go b/livestream/live_stats.go index 0c5258220bf79..00d3acfa19dd6 100644 --- a/livestream/live_stats.go +++ b/livestream/live_stats.go @@ -7,20 +7,37 @@ import ( "github.com/hashicorp/golang-lru/v2/expirable" ) -type TeamStats struct { - Store map[string]*expirable.LRU[string, string] +const ( + COUNTER_TTL = 60 +) + +type Stats struct { + Store map[string]*expirable.LRU[string, string] + GlobalStore *expirable.LRU[string, string] + Counter *SlidingWindowCounter +} + +func newStatsKeeper() *Stats { + return &Stats{ + Store: make(map[string]*expirable.LRU[string, string]), + GlobalStore: expirable.NewLRU[string, string](0, nil, time.Second*COUNTER_TTL), + Counter: NewSlidingWindowCounter(COUNTER_TTL), + } } -func (ts *TeamStats) keepStats(statsChan chan PostHogEvent) { +func (ts *Stats) keepStats(statsChan chan PostHogEvent) { log.Println("starting stats keeper...") + for { // ignore the range warning here - it's wrong select { case event := <-statsChan: + ts.Counter.Increment() token := event.Token if _, ok := ts.Store[token]; !ok { - ts.Store[token] = expirable.NewLRU[string, string](1000000, nil, time.Second*30) + ts.Store[token] = expirable.NewLRU[string, string](0, nil, time.Second*COUNTER_TTL) } - ts.Store[token].Add(event.DistinctId, "much wow") + ts.Store[token].Add(event.DistinctId, "1") + ts.GlobalStore.Add(event.DistinctId, "1") } } } diff --git a/livestream/main.go b/livestream/main.go index 32c624cdd79e1..6a8e147c53140 100644 --- a/livestream/main.go +++ b/livestream/main.go @@ -12,7 +12,6 @@ import ( "time" "github.com/getsentry/sentry-go" - "github.com/hashicorp/golang-lru/v2/expirable" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" "github.com/spf13/viper" @@ -63,16 +62,14 @@ func main() { log.Fatalf("Failed to open MMDB: %v", err) } - teamStats := &TeamStats{ - Store: make(map[string]*expirable.LRU[string, string]), - } + stats := newStatsKeeper() phEventChan := make(chan PostHogEvent) statsChan := make(chan PostHogEvent) subChan := make(chan Subscription) unSubChan := make(chan Subscription) - go teamStats.keepStats(statsChan) + go stats.keepStats(statsChan) kafkaSecurityProtocol := "SSL" if !isProd { @@ -109,43 +106,14 @@ func main() { // Routes e.GET("/", index) - e.GET("/stats", func(c echo.Context) error { - - type stats struct { - UsersOnProduct int `json:"users_on_product,omitempty"` - Error string `json:"error,omitempty"` - } - - authHeader := c.Request().Header.Get("Authorization") - if authHeader == "" { - return errors.New("authorization header is required") - } - - claims, err := decodeAuthToken(authHeader) - if err != nil { - return err - } - token := fmt.Sprint(claims["api_token"]) + e.GET("/served", servedHandler(stats)) - var hash *expirable.LRU[string, string] - var ok bool - if hash, ok = teamStats.Store[token]; !ok { - resp := stats{ - Error: "no stats", - } - return c.JSON(http.StatusOK, resp) - } - - siteStats := stats{ - UsersOnProduct: hash.Len(), - } - return c.JSON(http.StatusOK, siteStats) - }) + e.GET("/stats", statsHandler(stats)) e.GET("/events", func(c echo.Context) error { e.Logger.Printf("SSE client connected, ip: %v", c.RealIP()) - teamId := c.QueryParam("teamId") + var teamId string eventType := c.QueryParam("eventType") distinctId := c.QueryParam("distinctId") geo := c.QueryParam("geo") diff --git a/livestream/main_test.go b/livestream/main_test.go index 4bc1296fbacb9..fe6a28118a943 100644 --- a/livestream/main_test.go +++ b/livestream/main_test.go @@ -36,18 +36,18 @@ func TestStatsHandler(t *testing.T) { req.Header.Set("Authorization", "Bearer mock_token") // Create a mock TeamStats - teamStats := &TeamStats{ + stats := &Stats{ Store: make(map[string]*expirable.LRU[string, string]), } - teamStats.Store["mock_token"] = expirable.NewLRU[string, string](100, nil, time.Minute) - teamStats.Store["mock_token"].Add("user1", "data1") + stats.Store["mock_token"] = expirable.NewLRU[string, string](100, nil, time.Minute) + stats.Store["mock_token"].Add("user1", "data1") // Add the teamStats to the context - c.Set("teamStats", teamStats) + c.Set("teamStats", stats) handler := func(c echo.Context) error { return c.JSON(http.StatusOK, map[string]interface{}{ - "users_on_product": teamStats.Store["mock_token"].Len(), + "users_on_product": stats.Store["mock_token"].Len(), }) } diff --git a/livestream/served.go b/livestream/served.go new file mode 100644 index 0000000000000..1c1cc4c03967b --- /dev/null +++ b/livestream/served.go @@ -0,0 +1,62 @@ +package main + +import ( + "errors" + "fmt" + "net/http" + + "github.com/hashicorp/golang-lru/v2/expirable" + "github.com/labstack/echo/v4" +) + +type Counter struct { + EventCount uint32 + UserCount uint32 +} + +func servedHandler(stats *Stats) func(c echo.Context) error { + return func(c echo.Context) error { + userCount := stats.GlobalStore.Len() + count := stats.Counter.Count() + resp := Counter{ + EventCount: uint32(count), + UserCount: uint32(userCount), + } + return c.JSON(http.StatusOK, resp) + } +} + +func statsHandler(stats *Stats) func(c echo.Context) error { + return func(c echo.Context) error { + + type resp struct { + UsersOnProduct int `json:"users_on_product,omitempty"` + Error string `json:"error,omitempty"` + } + + authHeader := c.Request().Header.Get("Authorization") + if authHeader == "" { + return errors.New("authorization header is required") + } + + claims, err := decodeAuthToken(authHeader) + if err != nil { + return err + } + token := fmt.Sprint(claims["api_token"]) + + var hash *expirable.LRU[string, string] + var ok bool + if hash, ok = stats.Store[token]; !ok { + resp := resp{ + Error: "no stats", + } + return c.JSON(http.StatusNotFound, resp) + } + + siteStats := resp{ + UsersOnProduct: hash.Len(), + } + return c.JSON(http.StatusOK, siteStats) + } +} diff --git a/livestream/ttl_counter.go b/livestream/ttl_counter.go new file mode 100644 index 0000000000000..4be95d9b20a4b --- /dev/null +++ b/livestream/ttl_counter.go @@ -0,0 +1,48 @@ +package main + +import ( + "sync" + "time" +) + +type SlidingWindowCounter struct { + mu sync.Mutex + events []time.Time + windowSize time.Duration +} + +func NewSlidingWindowCounter(windowSize time.Duration) *SlidingWindowCounter { + return &SlidingWindowCounter{ + events: make([]time.Time, 0), + windowSize: windowSize, + } +} + +func (swc *SlidingWindowCounter) Increment() { + swc.mu.Lock() + defer swc.mu.Unlock() + + now := time.Now() + swc.events = append(swc.events, now) + swc.removeOldEvents(now) +} + +func (swc *SlidingWindowCounter) Count() int { + swc.mu.Lock() + defer swc.mu.Unlock() + + now := time.Now() + swc.removeOldEvents(now) + return len(swc.events) +} + +func (swc *SlidingWindowCounter) removeOldEvents(now time.Time) { + cutoff := now.Add(-swc.windowSize) + i := 0 + for ; i < len(swc.events); i++ { + if swc.events[i].After(cutoff) { + break + } + } + swc.events = swc.events[i:] +} diff --git a/livestream/ttl_counter_test.go b/livestream/ttl_counter_test.go new file mode 100644 index 0000000000000..5138f6ad0e567 --- /dev/null +++ b/livestream/ttl_counter_test.go @@ -0,0 +1,78 @@ +package main + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewSlidingWindowCounter(t *testing.T) { + windowSize := time.Minute + swc := NewSlidingWindowCounter(windowSize) + + assert.Equal(t, windowSize, swc.windowSize, "Window size should match") + assert.Empty(t, swc.events, "Events slice should be empty") +} + +func TestIncrement(t *testing.T) { + swc := NewSlidingWindowCounter(time.Minute) + + swc.Increment() + assert.Equal(t, 1, swc.Count(), "Count should be 1 after first increment") + + swc.Increment() + assert.Equal(t, 2, swc.Count(), "Count should be 2 after second increment") +} + +func TestCount(t *testing.T) { + swc := NewSlidingWindowCounter(time.Second) + + swc.Increment() + time.Sleep(500 * time.Millisecond) + swc.Increment() + + assert.Equal(t, 2, swc.Count(), "Count should be 2 within the time window") + + time.Sleep(600 * time.Millisecond) + + assert.Equal(t, 1, swc.Count(), "Count should be 1 after oldest event expires") +} + +func TestRemoveOldEvents(t *testing.T) { + swc := NewSlidingWindowCounter(time.Second) + + now := time.Now() + swc.events = []time.Time{ + now.Add(-2 * time.Second), + now.Add(-1500 * time.Millisecond), + now.Add(-500 * time.Millisecond), + now, + } + + swc.removeOldEvents(now) + + require.Len(t, swc.events, 2, "Should have 2 events after removal") + assert.Equal(t, now.Add(-500*time.Millisecond), swc.events[0], "First event should be 500ms ago") + assert.Equal(t, now, swc.events[1], "Second event should be now") +} + +func TestConcurrency(t *testing.T) { + swc := NewSlidingWindowCounter(time.Minute) + iterations := 1000 + var wg sync.WaitGroup + + wg.Add(iterations) + for i := 0; i < iterations; i++ { + go func() { + defer wg.Done() + swc.Increment() + }() + } + + wg.Wait() + + assert.Equal(t, iterations, swc.Count(), "Count should match the number of increments") +}