diff --git a/cmd/profile-async/main.go b/cmd/profile-async/main.go index ddacc0b..e0c911a 100644 --- a/cmd/profile-async/main.go +++ b/cmd/profile-async/main.go @@ -2,11 +2,13 @@ package main import ( "context" + "encoding/json" "fmt" - "io" "log" "net/http" + "time" + "github.com/TheLab-ms/profile/internal/chatbot" "github.com/TheLab-ms/profile/internal/conf" "github.com/TheLab-ms/profile/internal/keycloak" "github.com/TheLab-ms/profile/internal/reporting" @@ -22,8 +24,14 @@ func main() { log.Fatal(err) } + q := NewQueue() kc := keycloak.New(env) + bot, err := chatbot.NewBot(env) + if err != nil { + log.Fatal(err) + } + // Webhook registration if env.KeycloakRegisterWebhook { err = kc.EnsureWebhook(context.TODO(), fmt.Sprintf("%s/webhooks/keycloak", env.SelfURL)) if err != nil { @@ -31,21 +39,80 @@ func main() { } } - svr := &Server{Keycloak: kc} - log.Fatal(http.ListenAndServe(":8080", svr.NewHandler())) -} + // Resync loop + go func() { + ticker := time.NewTicker(time.Hour) + for range ticker.C { + ids, err := kc.ListUserIDs(context.TODO()) + if err != nil { + log.Printf("error while listing members for resync: %s", err) + continue + } + for _, id := range ids { + q.Add(id) + } + } + }() -type Server struct { - Keycloak *keycloak.Keycloak -} + // Message processor loop + go func() { + for { + item := q.Get() + start := time.Now() + log.Printf("syncing user %s", item) + err := syncUser(context.TODO(), kc, bot, item) + if err == nil { + q.Done(item) + log.Printf("sync'd user %s in %s", item, time.Since(start)) + continue + } + log.Printf("error while syncing user %q: %s", item, err) + q.Retry(item) + } + }() -func (s *Server) NewHandler() http.Handler { + // Webhook server mux := http.NewServeMux() - + mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(204) }) mux.HandleFunc("/webhooks/keycloak", func(w http.ResponseWriter, r *http.Request) { - body, _ := io.ReadAll(r.Body) - log.Printf("(TODO don't log entire request body) got 
// QueueItem is the bookkeeping record for one key tracked by a Queue.
type QueueItem struct {
	key       string
	attempts  int       // failed attempts reported through Retry
	nextRetry time.Time // earliest time Get may return the key; zero means "now"
	index     int       // position in the heap, or -1 while a consumer holds the key
	dirty     bool      // Add was called while a consumer held the key
}

// Queue is a deduplicating work queue with per-key exponential backoff.
//
// A key is in one of two states while it is tracked in items: queued (present
// in the heap, waiting for Get) or in flight (returned by Get, waiting for the
// consumer to call Done or Retry). Keeping in-flight keys in items is what
// lets Done and Retry find them again.
type Queue struct {
	mu    sync.Mutex
	cond  *sync.Cond
	items map[string]*QueueItem
	heap  *priorityQueue
}

// NewQueue returns an empty, ready to use Queue.
func NewQueue() *Queue {
	q := &Queue{
		items: make(map[string]*QueueItem),
		heap:  &priorityQueue{},
	}
	heap.Init(q.heap)
	q.cond = sync.NewCond(&q.mu)
	return q
}

// Run periodically wakes a blocked Get when the earliest queued item is due,
// until ctx is cancelled.
//
// Get arms its own wake-up timer, so calling Run is optional; it is kept for
// existing callers and acts only as a safety net.
func (q *Queue) Run(ctx context.Context) {
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			q.mu.Lock()
			if q.heap.Len() > 0 && !(*q.heap)[0].nextRetry.After(time.Now()) {
				q.cond.Signal()
			}
			q.mu.Unlock()
		}
	}
}

// Add enqueues key for immediate processing unless it is already tracked.
//
// If the key is currently in flight it is flagged instead, so that Done
// re-queues it once: an event arriving mid-processing is not lost.
func (q *Queue) Add(key string) {
	q.mu.Lock()
	defer q.mu.Unlock()
	if item, exists := q.items[key]; exists {
		if item.index < 0 {
			item.dirty = true
		}
		return
	}
	item := &QueueItem{key: key, index: -1}
	q.items[key] = item
	heap.Push(q.heap, item)
	q.cond.Signal()
}

// Done reports that key needs no further processing and forgets it.
//
// When Add was called for the key while it was in flight, the key is instead
// re-queued once with its attempt counter reset. Calling Done for a key that
// is still queued removes it from the queue. Unknown keys are ignored.
func (q *Queue) Done(key string) {
	q.mu.Lock()
	defer q.mu.Unlock()
	item, exists := q.items[key]
	if !exists {
		return
	}
	if item.index < 0 && item.dirty {
		item.dirty = false
		item.attempts = 0
		item.nextRetry = time.Time{}
		heap.Push(q.heap, item)
		q.cond.Signal()
		return
	}
	delete(q.items, key)
	q.removeFromHeap(item)
}

// Get blocks until a queued key is due and returns it. The key stays tracked
// as in flight until the caller reports the outcome with Done or Retry.
func (q *Queue) Get() string {
	q.mu.Lock()
	defer q.mu.Unlock()
	for {
		if q.heap.Len() == 0 {
			q.cond.Wait()
			continue
		}
		head := (*q.heap)[0]
		wait := time.Until(head.nextRetry)
		if wait <= 0 {
			heap.Pop(q.heap) // sets head.index to -1, marking it in flight
			return head.key
		}
		// The head is backing off. Arm a timer so this call wakes up by
		// itself when the item becomes due, whether or not Run is running.
		timer := time.AfterFunc(wait, q.cond.Broadcast)
		q.cond.Wait()
		timer.Stop()
	}
}

// Retry schedules key to be returned by Get again after an exponentially
// growing delay. It works both for in-flight keys (the normal case after a
// failed attempt) and for keys that are still queued. Unknown keys are
// ignored.
func (q *Queue) Retry(key string) {
	q.mu.Lock()
	defer q.mu.Unlock()
	item, exists := q.items[key]
	if !exists {
		return
	}
	item.attempts++
	item.nextRetry = time.Now().Add(q.exponentialBackoff(item.attempts))
	item.dirty = false // the retry itself will observe the latest state
	if item.index >= 0 {
		// Already queued: reposition instead of pushing a duplicate.
		heap.Fix(q.heap, item.index)
	} else {
		heap.Push(q.heap, item)
	}
	q.cond.Signal()
}

// exponentialBackoff returns 2^attempts seconds plus up to 5% random jitter,
// capped at one hour. The cap also guards against overflowing time.Duration
// when attempts is large.
func (q *Queue) exponentialBackoff(attempts int) time.Duration {
	const maxBackoff = time.Hour
	backoff := float64(time.Second)
	jitter := backoff * 0.1
	factor := math.Pow(2, float64(attempts))
	delay := backoff*factor + jitter*factor*0.5*rand.Float64()
	if delay >= float64(maxBackoff) {
		return maxBackoff
	}
	return time.Duration(delay)
}

// removeFromHeap removes item from the heap if it is currently queued.
func (q *Queue) removeFromHeap(item *QueueItem) {
	i := item.index
	if i >= 0 && i < q.heap.Len() && (*q.heap)[i] == item {
		heap.Remove(q.heap, i)
	}
}

// priorityQueue implements heap.Interface ordered by nextRetry (earliest
// first) and keeps each item's index field in sync with its heap position.
type priorityQueue []*QueueItem

func (pq priorityQueue) Len() int { return len(pq) }

func (pq priorityQueue) Less(i, j int) bool {
	return pq[i].nextRetry.Before(pq[j].nextRetry)
}

func (pq priorityQueue) Swap(i, j int) {
	pq[i], pq[j] = pq[j], pq[i]
	pq[i].index = i
	pq[j].index = j
}

func (pq *priorityQueue) Push(x interface{}) {
	item := x.(*QueueItem)
	item.index = len(*pq)
	*pq = append(*pq, item)
}

func (pq *priorityQueue) Pop() interface{} {
	old := *pq
	n := len(old)
	item := old[n-1]
	old[n-1] = nil // drop the reference so the slot does not pin the item
	item.index = -1
	*pq = old[0 : n-1]
	return item
}
a/cmd/profile-async/workqueue_test.go b/cmd/profile-async/workqueue_test.go new file mode 100644 index 0000000..b05cd74 --- /dev/null +++ b/cmd/profile-async/workqueue_test.go @@ -0,0 +1,114 @@ +package main + +import ( + "context" + "math" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestAddingSingleItem(t *testing.T) { + q := NewQueue() + q.Add("item1") + + key := q.Get() + if key != "item1" { + t.Errorf("Expected 'item1', got %s", key) + } +} + +func TestAddMultipleItems(t *testing.T) { + q := NewQueue() + q.Add("item1") + q.Add("item2") + + key1 := q.Get() + key2 := q.Get() + + if key1 == key2 { + t.Errorf("Expected different items, got the same item twice: %s", key1) + } + if (key1 != "item1" && key1 != "item2") || (key2 != "item1" && key2 != "item2") { + t.Errorf("Unexpected items: %s, %s", key1, key2) + } +} + +func TestRetryWithBackoff(t *testing.T) { + q := NewQueue() + go q.Run(context.TODO()) + q.Add("item1") + q.Retry("item1") + + // Get should wait for the retry backoff duration before returning the item. 
// approxDuration reports whether d1 (the observed duration) is within a
// relative tolerance of d2 (the expected duration). The allowed deviation is
// tolerance*|d2|, i.e. it is measured against the expected value so the check
// does not depend on how far off the observation happens to be.
func approxDuration(d1, d2 time.Duration, tolerance float64) bool {
	return math.Abs(float64(d1-d2)) <= tolerance*math.Abs(float64(d2))
}
API to feed the calendar API. eventsCache := events.NewCache(env) diff --git a/internal/chatbot/discord.go b/internal/chatbot/discord.go index dc9ee95..4ac8b34 100644 --- a/internal/chatbot/discord.go +++ b/internal/chatbot/discord.go @@ -7,23 +7,39 @@ import ( "encoding/hex" "fmt" "log" + "strconv" "github.com/TheLab-ms/profile/internal/conf" "github.com/bwmarrin/discordgo" ) -func Start(ctx context.Context, env *conf.Env) error { +type Bot struct { + client *discordgo.Session + env *conf.Env +} + +func NewBot(env *conf.Env) (*Bot, error) { + b := &Bot{env: env} if env.DiscordAppID == "" { - log.Printf("not starting discord bot because it isn't configured") - return nil + return b, nil } s, err := discordgo.New("Bot " + env.DiscordBotToken) if err != nil { - return err + return nil, err + } + b.client = s + + return b, nil +} + +func (b *Bot) Start(ctx context.Context) error { + if b.client == nil { + log.Printf("not starting discord bot because it isn't configured") + return nil } - _, err = s.ApplicationCommandCreate(env.DiscordAppID, env.DiscordGuildID, &discordgo.ApplicationCommand{ + _, err := b.client.ApplicationCommandCreate(b.env.DiscordAppID, b.env.DiscordGuildID, &discordgo.ApplicationCommand{ Name: "link", Description: "Link your membership to Discord", Type: discordgo.ChatApplicationCommand, @@ -32,7 +48,7 @@ func Start(ctx context.Context, env *conf.Env) error { return err } - s.AddHandler(func(s *discordgo.Session, i *discordgo.InteractionCreate) { + b.client.AddHandler(func(s *discordgo.Session, i *discordgo.InteractionCreate) { member := i.Member if member == nil || member.User == nil { return @@ -40,22 +56,61 @@ func Start(ctx context.Context, env *conf.Env) error { id := member.User.ID log.Printf("got link request for discord user %q", id) - signature := GenerateHMAC(id, env.DiscordBotToken) + signature := GenerateHMAC(id, b.env.DiscordBotToken) s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ Type: 
discordgo.InteractionResponseChannelMessageWithSource, Data: &discordgo.InteractionResponseData{ Flags: discordgo.MessageFlagsEphemeral, - Content: fmt.Sprintf("[Go to the profile app to finish the process!](%s/link-discord?user=%s&sig=%s)", env.SelfURL, id, signature), + Content: fmt.Sprintf("[Go to the profile app to finish the process!](%s/link-discord?user=%s&sig=%s)", b.env.SelfURL, id, signature), }, }) }) go func() { <-ctx.Done() - s.Close() + b.client.Close() }() - return s.Open() + return b.client.Open() +} + +func (b *Bot) SyncUser(ctx context.Context, user *UserStatus) error { + member, err := b.client.GuildMember(b.env.DiscordGuildID, strconv.FormatInt(user.ID, 10), discordgo.WithContext(ctx)) + if err != nil { + return fmt.Errorf("getting guild member: %w", err) + } + + var exists bool + for _, role := range member.Roles { + if role == b.env.DiscordMemberRoleID { + exists = true + break + } + } + if exists == user.ActiveMember { + return nil // already in sync + } + + if user.ActiveMember { + err = b.client.GuildMemberRoleAdd(b.env.DiscordGuildID, strconv.FormatInt(user.ID, 10), b.env.DiscordMemberRoleID, discordgo.WithContext(ctx)) + if err != nil { + return fmt.Errorf("adding role to guild member: %w", err) + } + log.Printf("added member role to discord user %d", user.ID) + return nil + } + + err = b.client.GuildMemberRoleRemove(b.env.DiscordGuildID, strconv.FormatInt(user.ID, 10), b.env.DiscordMemberRoleID, discordgo.WithContext(ctx)) + if err != nil { + return fmt.Errorf("removing role from guild member: %w", err) + } + log.Printf("removed member role from discord user %d", user.ID) + return nil +} + +type UserStatus struct { + ID int64 + ActiveMember bool } func GenerateHMAC(message, key string) string { diff --git a/internal/conf/env.go b/internal/conf/env.go index 0d9c1d1..5cbb59e 100644 --- a/internal/conf/env.go +++ b/internal/conf/env.go @@ -40,6 +40,7 @@ type Env struct { DiscordBotToken string `split_words:"true"` DiscordEventBotToken 
string `split_words:"true"` DiscordInterval time.Duration `split_words:"true" default:"60s"` + DiscordMemberRoleID string `split_words:"true"` // Age (secrets encrpytion) AgePublicKey string `split_words:"true"` diff --git a/internal/keycloak/keycloak.go b/internal/keycloak/keycloak.go index 9a506f2..fd93d3c 100644 --- a/internal/keycloak/keycloak.go +++ b/internal/keycloak/keycloak.go @@ -398,6 +398,32 @@ func (k *Keycloak) ListUsers(ctx context.Context) ([]*ExtendedUser, error) { } } +func (k *Keycloak) ListUserIDs(ctx context.Context) ([]string, error) { + token, err := k.GetToken(ctx) + if err != nil { + return nil, fmt.Errorf("getting token: %w", err) + } + + var ( + max = 50 + first = 0 + ) + ids := []string{} + for { + users, err := k.Client.GetUsers(ctx, token.AccessToken, k.env.KeycloakRealm, gocloak.GetUsersParams{Max: &max, First: &first, BriefRepresentation: gocloak.BoolP(true)}) + if err != nil { + return nil, fmt.Errorf("getting token: %w", err) + } + if len(users) == 0 { + return ids, nil + } + first += len(users) + for _, kcuser := range users { + ids = append(ids, *kcuser.ID) + } + } +} + // For whatever reason the Keycloak client doesn't support token rotation func (k *Keycloak) GetToken(ctx context.Context) (*gocloak.JWT, error) { k.tokenLock.Lock()