Skip to content

Commit

Permalink
feat: add universal 60 sec event counter to livestream (#24380)
Browse files Browse the repository at this point in the history
  • Loading branch information
fuziontech authored Aug 14, 2024
1 parent 826f5a8 commit 8cecfbf
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 47 deletions.
27 changes: 22 additions & 5 deletions livestream/live_stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,37 @@ import (
"github.com/hashicorp/golang-lru/v2/expirable"
)

type TeamStats struct {
Store map[string]*expirable.LRU[string, string]
const (
COUNTER_TTL = 60
)

type Stats struct {
Store map[string]*expirable.LRU[string, string]
GlobalStore *expirable.LRU[string, string]
Counter *SlidingWindowCounter
}

func newStatsKeeper() *Stats {
return &Stats{
Store: make(map[string]*expirable.LRU[string, string]),
GlobalStore: expirable.NewLRU[string, string](0, nil, time.Second*COUNTER_TTL),
Counter: NewSlidingWindowCounter(COUNTER_TTL),
}
}

func (ts *TeamStats) keepStats(statsChan chan PostHogEvent) {
func (ts *Stats) keepStats(statsChan chan PostHogEvent) {
log.Println("starting stats keeper...")

for { // ignore the range warning here - it's wrong
select {
case event := <-statsChan:
ts.Counter.Increment()
token := event.Token
if _, ok := ts.Store[token]; !ok {
ts.Store[token] = expirable.NewLRU[string, string](1000000, nil, time.Second*30)
ts.Store[token] = expirable.NewLRU[string, string](0, nil, time.Second*COUNTER_TTL)
}
ts.Store[token].Add(event.DistinctId, "much wow")
ts.Store[token].Add(event.DistinctId, "1")
ts.GlobalStore.Add(event.DistinctId, "1")
}
}
}
42 changes: 5 additions & 37 deletions livestream/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"time"

"github.com/getsentry/sentry-go"
"github.com/hashicorp/golang-lru/v2/expirable"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/spf13/viper"
Expand Down Expand Up @@ -63,16 +62,14 @@ func main() {
log.Fatalf("Failed to open MMDB: %v", err)
}

teamStats := &TeamStats{
Store: make(map[string]*expirable.LRU[string, string]),
}
stats := newStatsKeeper()

phEventChan := make(chan PostHogEvent)
statsChan := make(chan PostHogEvent)
subChan := make(chan Subscription)
unSubChan := make(chan Subscription)

go teamStats.keepStats(statsChan)
go stats.keepStats(statsChan)

kafkaSecurityProtocol := "SSL"
if !isProd {
Expand Down Expand Up @@ -109,43 +106,14 @@ func main() {
// Routes
e.GET("/", index)

e.GET("/stats", func(c echo.Context) error {

type stats struct {
UsersOnProduct int `json:"users_on_product,omitempty"`
Error string `json:"error,omitempty"`
}

authHeader := c.Request().Header.Get("Authorization")
if authHeader == "" {
return errors.New("authorization header is required")
}

claims, err := decodeAuthToken(authHeader)
if err != nil {
return err
}
token := fmt.Sprint(claims["api_token"])
e.GET("/served", servedHandler(stats))

var hash *expirable.LRU[string, string]
var ok bool
if hash, ok = teamStats.Store[token]; !ok {
resp := stats{
Error: "no stats",
}
return c.JSON(http.StatusOK, resp)
}

siteStats := stats{
UsersOnProduct: hash.Len(),
}
return c.JSON(http.StatusOK, siteStats)
})
e.GET("/stats", statsHandler(stats))

e.GET("/events", func(c echo.Context) error {
e.Logger.Printf("SSE client connected, ip: %v", c.RealIP())

teamId := c.QueryParam("teamId")
var teamId string
eventType := c.QueryParam("eventType")
distinctId := c.QueryParam("distinctId")
geo := c.QueryParam("geo")
Expand Down
10 changes: 5 additions & 5 deletions livestream/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,18 @@ func TestStatsHandler(t *testing.T) {
req.Header.Set("Authorization", "Bearer mock_token")

// Create a mock TeamStats
teamStats := &TeamStats{
stats := &Stats{
Store: make(map[string]*expirable.LRU[string, string]),
}
teamStats.Store["mock_token"] = expirable.NewLRU[string, string](100, nil, time.Minute)
teamStats.Store["mock_token"].Add("user1", "data1")
stats.Store["mock_token"] = expirable.NewLRU[string, string](100, nil, time.Minute)
stats.Store["mock_token"].Add("user1", "data1")

// Add the teamStats to the context
c.Set("teamStats", teamStats)
c.Set("teamStats", stats)

handler := func(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]interface{}{
"users_on_product": teamStats.Store["mock_token"].Len(),
"users_on_product": stats.Store["mock_token"].Len(),
})
}

Expand Down
62 changes: 62 additions & 0 deletions livestream/served.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package main

import (
"errors"
"fmt"
"net/http"

"github.com/hashicorp/golang-lru/v2/expirable"
"github.com/labstack/echo/v4"
)

type Counter struct {
EventCount uint32
UserCount uint32
}

func servedHandler(stats *Stats) func(c echo.Context) error {
return func(c echo.Context) error {
userCount := stats.GlobalStore.Len()
count := stats.Counter.Count()
resp := Counter{
EventCount: uint32(count),
UserCount: uint32(userCount),
}
return c.JSON(http.StatusOK, resp)
}
}

func statsHandler(stats *Stats) func(c echo.Context) error {
return func(c echo.Context) error {

type resp struct {
UsersOnProduct int `json:"users_on_product,omitempty"`
Error string `json:"error,omitempty"`
}

authHeader := c.Request().Header.Get("Authorization")
if authHeader == "" {
return errors.New("authorization header is required")
}

claims, err := decodeAuthToken(authHeader)
if err != nil {
return err
}
token := fmt.Sprint(claims["api_token"])

var hash *expirable.LRU[string, string]
var ok bool
if hash, ok = stats.Store[token]; !ok {
resp := resp{
Error: "no stats",
}
return c.JSON(http.StatusNotFound, resp)
}

siteStats := resp{
UsersOnProduct: hash.Len(),
}
return c.JSON(http.StatusOK, siteStats)
}
}
48 changes: 48 additions & 0 deletions livestream/ttl_counter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package main

import (
"sync"
"time"
)

type SlidingWindowCounter struct {
mu sync.Mutex
events []time.Time
windowSize time.Duration
}

func NewSlidingWindowCounter(windowSize time.Duration) *SlidingWindowCounter {
return &SlidingWindowCounter{
events: make([]time.Time, 0),
windowSize: windowSize,
}
}

func (swc *SlidingWindowCounter) Increment() {
swc.mu.Lock()
defer swc.mu.Unlock()

now := time.Now()
swc.events = append(swc.events, now)
swc.removeOldEvents(now)
}

func (swc *SlidingWindowCounter) Count() int {
swc.mu.Lock()
defer swc.mu.Unlock()

now := time.Now()
swc.removeOldEvents(now)
return len(swc.events)
}

func (swc *SlidingWindowCounter) removeOldEvents(now time.Time) {
cutoff := now.Add(-swc.windowSize)
i := 0
for ; i < len(swc.events); i++ {
if swc.events[i].After(cutoff) {
break
}
}
swc.events = swc.events[i:]
}
78 changes: 78 additions & 0 deletions livestream/ttl_counter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package main

import (
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNewSlidingWindowCounter(t *testing.T) {
windowSize := time.Minute
swc := NewSlidingWindowCounter(windowSize)

assert.Equal(t, windowSize, swc.windowSize, "Window size should match")
assert.Empty(t, swc.events, "Events slice should be empty")
}

func TestIncrement(t *testing.T) {
swc := NewSlidingWindowCounter(time.Minute)

swc.Increment()
assert.Equal(t, 1, swc.Count(), "Count should be 1 after first increment")

swc.Increment()
assert.Equal(t, 2, swc.Count(), "Count should be 2 after second increment")
}

func TestCount(t *testing.T) {
swc := NewSlidingWindowCounter(time.Second)

swc.Increment()
time.Sleep(500 * time.Millisecond)
swc.Increment()

assert.Equal(t, 2, swc.Count(), "Count should be 2 within the time window")

time.Sleep(600 * time.Millisecond)

assert.Equal(t, 1, swc.Count(), "Count should be 1 after oldest event expires")
}

func TestRemoveOldEvents(t *testing.T) {
swc := NewSlidingWindowCounter(time.Second)

now := time.Now()
swc.events = []time.Time{
now.Add(-2 * time.Second),
now.Add(-1500 * time.Millisecond),
now.Add(-500 * time.Millisecond),
now,
}

swc.removeOldEvents(now)

require.Len(t, swc.events, 2, "Should have 2 events after removal")
assert.Equal(t, now.Add(-500*time.Millisecond), swc.events[0], "First event should be 500ms ago")
assert.Equal(t, now, swc.events[1], "Second event should be now")
}

func TestConcurrency(t *testing.T) {
swc := NewSlidingWindowCounter(time.Minute)
iterations := 1000
var wg sync.WaitGroup

wg.Add(iterations)
for i := 0; i < iterations; i++ {
go func() {
defer wg.Done()
swc.Increment()
}()
}

wg.Wait()

assert.Equal(t, iterations, swc.Count(), "Count should match the number of increments")
}

0 comments on commit 8cecfbf

Please sign in to comment.