Skip to content

Commit

Permalink
MLS subscribe from cursor (#351)
Browse files Browse the repository at this point in the history
* feat: mls subscribe from cursor

* Fix lint errors 👮
  • Loading branch information
Steven Normore authored Feb 21, 2024
1 parent dcd9beb commit db7c40f
Show file tree
Hide file tree
Showing 3 changed files with 381 additions and 24 deletions.
144 changes: 126 additions & 18 deletions pkg/mls/api/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,32 @@ func (s *Service) SubscribeGroupMessages(req *mlsv1.SubscribeGroupMessagesReques
// See: https://github.com/xmtp/libxmtp/pull/58
_ = stream.SendHeader(metadata.Pairs("subscribed", "true"))

var streamLock sync.Mutex
streamed := map[string]*mlsv1.GroupMessage{}
var streamingLock sync.Mutex
streamMessages := func(msgs []*mlsv1.GroupMessage) {
streamingLock.Lock()
defer streamingLock.Unlock()

for _, msg := range msgs {
if msg.GetV1() == nil {
continue
}
encodedId := fmt.Sprintf("%x", msg.GetV1().Id)
if _, ok := streamed[encodedId]; ok {
log.Debug("skipping already streamed message", zap.String("id", encodedId))
continue
}
err := stream.Send(msg)
if err != nil {
log.Error("error streaming group message", zap.Error(err))
}
streamed[encodedId] = msg
}
}

for _, filter := range req.Filters {
filter := filter

natsSubject := buildNatsSubjectForGroupMessages(filter.GroupId)
sub, err := s.nc.Subscribe(natsSubject, func(natsMsg *nats.Msg) {
var msg mlsv1.GroupMessage
Expand All @@ -377,14 +401,7 @@ func (s *Service) SubscribeGroupMessages(req *mlsv1.SubscribeGroupMessagesReques
log.Error("parsing group message from bytes", zap.Error(err))
return
}
func() {
streamLock.Lock()
defer streamLock.Unlock()
err := stream.Send(&msg)
if err != nil {
log.Error("sending group message to subscribe", zap.Error(err))
}
}()
streamMessages([]*mlsv1.GroupMessage{&msg})
})
if err != nil {
log.Error("error subscribing to group messages", zap.Error(err))
Expand All @@ -393,6 +410,43 @@ func (s *Service) SubscribeGroupMessages(req *mlsv1.SubscribeGroupMessagesReques
defer func() {
_ = sub.Unsubscribe()
}()

if filter.IdCursor > 0 {
go func() {
pagingInfo := &mlsv1.PagingInfo{
IdCursor: filter.IdCursor,
Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING,
}
for {
select {
case <-stream.Context().Done():
return
case <-s.ctx.Done():
return
default:
}

resp, err := s.store.QueryGroupMessagesV1(stream.Context(), &mlsv1.QueryGroupMessagesRequest{
GroupId: filter.GroupId,
PagingInfo: pagingInfo,
})
if err != nil {
if err == context.Canceled {
return
}
log.Error("error querying for subscription cursor messages", zap.Error(err))
return
}

streamMessages(resp.Messages)

if len(resp.Messages) == 0 || resp.PagingInfo == nil || resp.PagingInfo.IdCursor == 0 {
break
}
pagingInfo = resp.PagingInfo
}
}()
}
}

select {
Expand All @@ -411,8 +465,32 @@ func (s *Service) SubscribeWelcomeMessages(req *mlsv1.SubscribeWelcomeMessagesRe
// See: https://github.com/xmtp/libxmtp/pull/58
_ = stream.SendHeader(metadata.Pairs("subscribed", "true"))

var streamLock sync.Mutex
streamed := map[string]*mlsv1.WelcomeMessage{}
var streamingLock sync.Mutex
streamMessages := func(msgs []*mlsv1.WelcomeMessage) {
streamingLock.Lock()
defer streamingLock.Unlock()

for _, msg := range msgs {
if msg.GetV1() == nil {
continue
}
encodedId := fmt.Sprintf("%x", msg.GetV1().Id)
if _, ok := streamed[encodedId]; ok {
log.Debug("skipping already streamed message", zap.String("id", encodedId))
continue
}
err := stream.Send(msg)
if err != nil {
log.Error("error streaming welcome message", zap.Error(err))
}
streamed[encodedId] = msg
}
}

for _, filter := range req.Filters {
filter := filter

natsSubject := buildNatsSubjectForWelcomeMessages(filter.InstallationKey)
sub, err := s.nc.Subscribe(natsSubject, func(natsMsg *nats.Msg) {
var msg mlsv1.WelcomeMessage
Expand All @@ -421,14 +499,7 @@ func (s *Service) SubscribeWelcomeMessages(req *mlsv1.SubscribeWelcomeMessagesRe
log.Error("parsing welcome message from bytes", zap.Error(err))
return
}
func() {
streamLock.Lock()
defer streamLock.Unlock()
err := stream.Send(&msg)
if err != nil {
log.Error("sending welcome message to subscribe", zap.Error(err))
}
}()
streamMessages([]*mlsv1.WelcomeMessage{&msg})
})
if err != nil {
log.Error("error subscribing to welcome messages", zap.Error(err))
Expand All @@ -437,6 +508,43 @@ func (s *Service) SubscribeWelcomeMessages(req *mlsv1.SubscribeWelcomeMessagesRe
defer func() {
_ = sub.Unsubscribe()
}()

if filter.IdCursor > 0 {
go func() {
pagingInfo := &mlsv1.PagingInfo{
IdCursor: filter.IdCursor,
Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING,
}
for {
select {
case <-stream.Context().Done():
return
case <-s.ctx.Done():
return
default:
}

resp, err := s.store.QueryWelcomeMessagesV1(stream.Context(), &mlsv1.QueryWelcomeMessagesRequest{
InstallationKey: filter.InstallationKey,
PagingInfo: pagingInfo,
})
if err != nil {
if err == context.Canceled {
return
}
log.Error("error querying for subscription cursor messages", zap.Error(err))
return
}

streamMessages(resp.Messages)

if len(resp.Messages) == 0 || resp.PagingInfo == nil || resp.PagingInfo.IdCursor == 0 {
break
}
pagingInfo = resp.PagingInfo
}
}()
}
}

select {
Expand Down
Loading

0 comments on commit db7c40f

Please sign in to comment.