Skip to content

Commit

Permalink
feat: mls subscribe via db
Browse files Browse the repository at this point in the history
  • Loading branch information
Steven Normore committed Jan 20, 2024
1 parent b54d575 commit 798df89
Show file tree
Hide file tree
Showing 4 changed files with 806 additions and 225 deletions.
336 changes: 300 additions & 36 deletions pkg/mls/api/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,69 @@ func (s *Service) SendWelcomeMessages(ctx context.Context, req *mlsv1.SendWelcom
}

func (s *Service) QueryGroupMessages(ctx context.Context, req *mlsv1.QueryGroupMessagesRequest) (*mlsv1.QueryGroupMessagesResponse, error) {
return s.store.QueryGroupMessagesV1(ctx, req)
if req.PagingInfo == nil {
req.PagingInfo = &mlsv1.PagingInfo{}
}
if req.PagingInfo.Direction == mlsv1.SortDirection_SORT_DIRECTION_UNSPECIFIED {
req.PagingInfo.Direction = mlsv1.SortDirection_SORT_DIRECTION_DESCENDING
}
if req.PagingInfo.Limit == 0 || req.PagingInfo.Limit > mlsstore.MaxQueryPageSize {
req.PagingInfo.Limit = mlsstore.MaxQueryPageSize
}

msgs, err := s.store.QueryGroupMessagesV1(ctx, req)
if err != nil {
return nil, err
}

pbMsgs := make([]*mlsv1.GroupMessage, 0, len(msgs))
for _, msg := range msgs {
pbMsgs = append(pbMsgs, toProtoGroupMessage(msg))
}

pagingInfo := &mlsv1.PagingInfo{Limit: uint32(req.PagingInfo.Limit), IdCursor: 0, Direction: req.PagingInfo.Direction}
if len(pbMsgs) >= int(req.PagingInfo.Limit) {
lastMsg := msgs[len(pbMsgs)-1]
pagingInfo.IdCursor = lastMsg.Id
}

return &mlsv1.QueryGroupMessagesResponse{
Messages: pbMsgs,
PagingInfo: pagingInfo,
}, nil
}

func (s *Service) QueryWelcomeMessages(ctx context.Context, req *mlsv1.QueryWelcomeMessagesRequest) (*mlsv1.QueryWelcomeMessagesResponse, error) {
return s.store.QueryWelcomeMessagesV1(ctx, req)
if req.PagingInfo == nil {
req.PagingInfo = &mlsv1.PagingInfo{}
}
if req.PagingInfo.Direction == mlsv1.SortDirection_SORT_DIRECTION_UNSPECIFIED {
req.PagingInfo.Direction = mlsv1.SortDirection_SORT_DIRECTION_DESCENDING
}
if req.PagingInfo.Limit == 0 || req.PagingInfo.Limit > mlsstore.MaxQueryPageSize {
req.PagingInfo.Limit = mlsstore.MaxQueryPageSize
}

msgs, err := s.store.QueryWelcomeMessagesV1(ctx, req)
if err != nil {
return nil, err
}

pbMsgs := make([]*mlsv1.WelcomeMessage, 0, len(msgs))
for _, msg := range msgs {
pbMsgs = append(pbMsgs, toProtoWelcomeMessage(msg))
}

pagingInfo := &mlsv1.PagingInfo{Limit: uint32(req.PagingInfo.Limit), IdCursor: 0, Direction: req.PagingInfo.Direction}
if len(pbMsgs) >= int(req.PagingInfo.Limit) {
lastMsg := msgs[len(pbMsgs)-1]
pagingInfo.IdCursor = lastMsg.Id
}

return &mlsv1.QueryWelcomeMessagesResponse{
Messages: pbMsgs,
PagingInfo: pagingInfo,
}, nil
}

func (s *Service) SubscribeGroupMessages(req *mlsv1.SubscribeGroupMessagesRequest, stream mlsv1.MlsApi_SubscribeGroupMessagesServer) error {
Expand All @@ -344,24 +402,22 @@ func (s *Service) SubscribeGroupMessages(req *mlsv1.SubscribeGroupMessagesReques
// See: https://github.com/xmtp/libxmtp/pull/58
_ = stream.SendHeader(metadata.Pairs("subscribed", "true"))

var streamLock sync.Mutex
var hasMessagesLock sync.Mutex
var hasMessages bool
setHasMessages := func() {
hasMessagesLock.Lock()
defer hasMessagesLock.Unlock()
hasMessages = true
}

var retErr error

for _, filter := range req.Filters {
filter := filter

natsSubject := buildNatsSubjectForGroupMessages(filter.GroupId)
sub, err := s.nc.Subscribe(natsSubject, func(natsMsg *nats.Msg) {
var msg mlsv1.GroupMessage
err := pb.Unmarshal(natsMsg.Data, &msg)
if err != nil {
log.Error("parsing group message from bytes", zap.Error(err))
return
}
func() {
streamLock.Lock()
defer streamLock.Unlock()
err := stream.Send(&msg)
if err != nil {
log.Error("sending group message to subscribe", zap.Error(err))
}
}()
setHasMessages()
})
if err != nil {
log.Error("error subscribing to group messages", zap.Error(err))
Expand All @@ -370,14 +426,106 @@ func (s *Service) SubscribeGroupMessages(req *mlsv1.SubscribeGroupMessagesReques
defer func() {
_ = sub.Unsubscribe()
}()

go func() {
pagingInfo := &mlsv1.PagingInfo{
Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING,
}
if filter.IdCursor > 0 {
pagingInfo.IdCursor = filter.IdCursor
} else {
latestMsg, err := s.store.GetLatestGroupMessage(stream.Context(), filter.GroupId)
if err != nil && !mlsstore.IsNotFoundError(err) {
log.Error("error getting latest group message", zap.Error(err))
retErr = err
return
}
if latestMsg != nil {
pagingInfo.IdCursor = latestMsg.Id
}
}

activeTicker := time.NewTicker(100 * time.Millisecond)
defer activeTicker.Stop()
passiveTicker := time.NewTicker(5 * time.Second)
defer passiveTicker.Stop()
for {
select {
case <-stream.Context().Done():
return
case <-s.ctx.Done():
return
case <-passiveTicker.C:
setHasMessages()
case <-activeTicker.C:
var skip bool
func() {
hasMessagesLock.Lock()
defer hasMessagesLock.Unlock()
if !hasMessages {
skip = true
}
hasMessages = false
}()
if skip {
continue
}

for {
select {
case <-stream.Context().Done():
return
case <-s.ctx.Done():
return
default:
}

msgs, err := s.store.QueryGroupMessagesV1(stream.Context(), &mlsv1.QueryGroupMessagesRequest{
GroupId: filter.GroupId,
PagingInfo: pagingInfo,
})
if err != nil {
if err == context.Canceled {
return
}
log.Error("error querying for subscription cursor messages", zap.Error(err))
// Break out and try again during the next ticker period.
break
}

for _, msg := range msgs {
pbMsg := toProtoGroupMessage(msg)
err := stream.Send(pbMsg)
if err != nil {
log.Error("error streaming group message", zap.Error(err))
}
}

// We can't just use resp.PagingInfo since we always
// want the cursor from the last message even if it's
// the last page.
if len(msgs) > 0 {
lastMsg := msgs[len(msgs)-1]
pagingInfo.IdCursor = lastMsg.Id
}

if len(msgs) == 0 {
break
}
}
}
}
}()
}

select {
case <-stream.Context().Done():
return nil
break
case <-s.ctx.Done():
return nil
break
}

return retErr
}

func (s *Service) SubscribeWelcomeMessages(req *mlsv1.SubscribeWelcomeMessagesRequest, stream mlsv1.MlsApi_SubscribeWelcomeMessagesServer) error {
Expand All @@ -387,24 +535,22 @@ func (s *Service) SubscribeWelcomeMessages(req *mlsv1.SubscribeWelcomeMessagesRe
// See: https://github.com/xmtp/libxmtp/pull/58
_ = stream.SendHeader(metadata.Pairs("subscribed", "true"))

var streamLock sync.Mutex
var hasMessagesLock sync.Mutex
var hasMessages bool
setHasMessages := func() {
hasMessagesLock.Lock()
defer hasMessagesLock.Unlock()
hasMessages = true
}

var retErr error

for _, filter := range req.Filters {
filter := filter

natsSubject := buildNatsSubjectForWelcomeMessages(filter.InstallationKey)
sub, err := s.nc.Subscribe(natsSubject, func(natsMsg *nats.Msg) {
var msg mlsv1.WelcomeMessage
err := pb.Unmarshal(natsMsg.Data, &msg)
if err != nil {
log.Error("parsing welcome message from bytes", zap.Error(err))
return
}
func() {
streamLock.Lock()
defer streamLock.Unlock()
err := stream.Send(&msg)
if err != nil {
log.Error("sending welcome message to subscribe", zap.Error(err))
}
}()
setHasMessages()
})
if err != nil {
log.Error("error subscribing to welcome messages", zap.Error(err))
Expand All @@ -413,14 +559,106 @@ func (s *Service) SubscribeWelcomeMessages(req *mlsv1.SubscribeWelcomeMessagesRe
defer func() {
_ = sub.Unsubscribe()
}()

go func() {
pagingInfo := &mlsv1.PagingInfo{
Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING,
}
if filter.IdCursor > 0 {
pagingInfo.IdCursor = filter.IdCursor
} else {
latestMsg, err := s.store.GetLatestWelcomeMessage(stream.Context(), filter.InstallationKey)
if err != nil && !mlsstore.IsNotFoundError(err) {
log.Error("error getting latest welcome message", zap.Error(err))
retErr = err
return
}
if latestMsg != nil {
pagingInfo.IdCursor = latestMsg.Id
}
}

activeTicker := time.NewTicker(200 * time.Millisecond)
defer activeTicker.Stop()
passiveTicker := time.NewTicker(5 * time.Second)
defer passiveTicker.Stop()
for {
select {
case <-stream.Context().Done():
return
case <-s.ctx.Done():
return
case <-passiveTicker.C:
setHasMessages()
case <-activeTicker.C:
var skip bool
func() {
hasMessagesLock.Lock()
defer hasMessagesLock.Unlock()
if !hasMessages {
skip = true
}
hasMessages = false
}()
if skip {
continue
}

for {
select {
case <-stream.Context().Done():
return
case <-s.ctx.Done():
return
default:
}

msgs, err := s.store.QueryWelcomeMessagesV1(stream.Context(), &mlsv1.QueryWelcomeMessagesRequest{
InstallationKey: filter.InstallationKey,
PagingInfo: pagingInfo,
})
if err != nil {
if err == context.Canceled {
return
}
log.Error("error querying for subscription cursor messages", zap.Error(err))
// Break out and try again during the next ticker period.
break
}

for _, msg := range msgs {
pbMsg := toProtoWelcomeMessage(msg)
err := stream.Send(pbMsg)
if err != nil {
log.Error("error streaming welcome message", zap.Error(err))
}
}

// We can't just use resp.PagingInfo since we always
// want the cursor from the last message even if it's
// the last page.
if len(msgs) > 0 {
lastMsg := msgs[len(msgs)-1]
pagingInfo.IdCursor = lastMsg.Id
}

if len(msgs) == 0 {
break
}
}
}
}
}()
}

select {
case <-stream.Context().Done():
return nil
break
case <-s.ctx.Done():
return nil
break
}

return retErr
}

func buildNatsSubjectForGroupMessages(groupId []byte) string {
Expand Down Expand Up @@ -512,3 +750,29 @@ func requireReadyToSend(groupId string, message []byte) error {
}
return nil
}

func toProtoGroupMessage(msg *mlsstore.GroupMessage) *mlsv1.GroupMessage {
return &mlsv1.GroupMessage{
Version: &mlsv1.GroupMessage_V1_{
V1: &mlsv1.GroupMessage_V1{
Id: msg.Id,
GroupId: msg.GroupId,
CreatedNs: uint64(msg.CreatedAt.UnixNano()),
Data: msg.Data,
},
},
}
}

func toProtoWelcomeMessage(msg *mlsstore.WelcomeMessage) *mlsv1.WelcomeMessage {
return &mlsv1.WelcomeMessage{
Version: &mlsv1.WelcomeMessage_V1_{
V1: &mlsv1.WelcomeMessage_V1{
Id: msg.Id,
InstallationKey: msg.InstallationKey,
CreatedNs: uint64(msg.CreatedAt.UnixNano()),
Data: msg.Data,
},
},
}
}
Loading

0 comments on commit 798df89

Please sign in to comment.