From db7c40fddd40fd89d2209d33c4e261730ca384c2 Mon Sep 17 00:00:00 2001 From: Steven Normore Date: Wed, 21 Feb 2024 12:24:50 -0500 Subject: [PATCH] MLS subscribe from cursor (#351) * feat: mls subscribe from cursor * Fix lint errors :cop: --- pkg/mls/api/v1/service.go | 144 ++++++++++++++++--- pkg/mls/api/v1/service_test.go | 252 ++++++++++++++++++++++++++++++++- pkg/mls/store/store.go | 9 +- 3 files changed, 381 insertions(+), 24 deletions(-) diff --git a/pkg/mls/api/v1/service.go b/pkg/mls/api/v1/service.go index c6779139..dba96243 100644 --- a/pkg/mls/api/v1/service.go +++ b/pkg/mls/api/v1/service.go @@ -367,8 +367,32 @@ func (s *Service) SubscribeGroupMessages(req *mlsv1.SubscribeGroupMessagesReques // See: https://github.com/xmtp/libxmtp/pull/58 _ = stream.SendHeader(metadata.Pairs("subscribed", "true")) - var streamLock sync.Mutex + streamed := map[string]*mlsv1.GroupMessage{} + var streamingLock sync.Mutex + streamMessages := func(msgs []*mlsv1.GroupMessage) { + streamingLock.Lock() + defer streamingLock.Unlock() + + for _, msg := range msgs { + if msg.GetV1() == nil { + continue + } + encodedId := fmt.Sprintf("%x", msg.GetV1().Id) + if _, ok := streamed[encodedId]; ok { + log.Debug("skipping already streamed message", zap.String("id", encodedId)) + continue + } + err := stream.Send(msg) + if err != nil { + log.Error("error streaming group message", zap.Error(err)) + } + streamed[encodedId] = msg + } + } + for _, filter := range req.Filters { + filter := filter + natsSubject := buildNatsSubjectForGroupMessages(filter.GroupId) sub, err := s.nc.Subscribe(natsSubject, func(natsMsg *nats.Msg) { var msg mlsv1.GroupMessage @@ -377,14 +401,7 @@ func (s *Service) SubscribeGroupMessages(req *mlsv1.SubscribeGroupMessagesReques log.Error("parsing group message from bytes", zap.Error(err)) return } - func() { - streamLock.Lock() - defer streamLock.Unlock() - err := stream.Send(&msg) - if err != nil { - log.Error("sending group message to subscribe", zap.Error(err)) - } - }() + streamMessages([]*mlsv1.GroupMessage{&msg}) }) if err != nil { log.Error("error subscribing to group messages", zap.Error(err)) @@ -393,6 +410,43 @@ func (s *Service) SubscribeGroupMessages(req *mlsv1.SubscribeGroupMessagesReques defer func() { _ = sub.Unsubscribe() }() + + if filter.IdCursor > 0 { + go func() { + pagingInfo := &mlsv1.PagingInfo{ + IdCursor: filter.IdCursor, + Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING, + } + for { + select { + case <-stream.Context().Done(): + return + case <-s.ctx.Done(): + return + default: + } + + resp, err := s.store.QueryGroupMessagesV1(stream.Context(), &mlsv1.QueryGroupMessagesRequest{ + GroupId: filter.GroupId, + PagingInfo: pagingInfo, + }) + if err != nil { + if err == context.Canceled { + return + } + log.Error("error querying for subscription cursor messages", zap.Error(err)) + return + } + + streamMessages(resp.Messages) + + if len(resp.Messages) == 0 || resp.PagingInfo == nil || resp.PagingInfo.IdCursor == 0 { + break + } + pagingInfo = resp.PagingInfo + } + }() + } } select { @@ -411,8 +465,32 @@ func (s *Service) SubscribeWelcomeMessages(req *mlsv1.SubscribeWelcomeMessagesRe // See: https://github.com/xmtp/libxmtp/pull/58 _ = stream.SendHeader(metadata.Pairs("subscribed", "true")) - var streamLock sync.Mutex + streamed := map[string]*mlsv1.WelcomeMessage{} + var streamingLock sync.Mutex + streamMessages := func(msgs []*mlsv1.WelcomeMessage) { + streamingLock.Lock() + defer streamingLock.Unlock() + + for _, msg := range msgs { + if msg.GetV1() == nil { + continue + } + encodedId := fmt.Sprintf("%x", msg.GetV1().Id) + if _, ok := streamed[encodedId]; ok { + log.Debug("skipping already streamed message", zap.String("id", encodedId)) + continue + } + err := stream.Send(msg) + if err != nil { + log.Error("error streaming welcome message", zap.Error(err)) + } + streamed[encodedId] = msg + } + } + for _, filter := range req.Filters { + filter := filter + natsSubject := buildNatsSubjectForWelcomeMessages(filter.InstallationKey) sub, err := s.nc.Subscribe(natsSubject, func(natsMsg *nats.Msg) { var msg mlsv1.WelcomeMessage @@ -421,14 +499,7 @@ func (s *Service) SubscribeWelcomeMessages(req *mlsv1.SubscribeWelcomeMessagesRe log.Error("parsing welcome message from bytes", zap.Error(err)) return } - func() { - streamLock.Lock() - defer streamLock.Unlock() - err := stream.Send(&msg) - if err != nil { - log.Error("sending welcome message to subscribe", zap.Error(err)) - } - }() + streamMessages([]*mlsv1.WelcomeMessage{&msg}) }) if err != nil { log.Error("error subscribing to welcome messages", zap.Error(err)) @@ -437,6 +508,43 @@ func (s *Service) SubscribeWelcomeMessages(req *mlsv1.SubscribeWelcomeMessagesRe defer func() { _ = sub.Unsubscribe() }() + + if filter.IdCursor > 0 { + go func() { + pagingInfo := &mlsv1.PagingInfo{ + IdCursor: filter.IdCursor, + Direction: mlsv1.SortDirection_SORT_DIRECTION_ASCENDING, + } + for { + select { + case <-stream.Context().Done(): + return + case <-s.ctx.Done(): + return + default: + } + + resp, err := s.store.QueryWelcomeMessagesV1(stream.Context(), &mlsv1.QueryWelcomeMessagesRequest{ + InstallationKey: filter.InstallationKey, + PagingInfo: pagingInfo, + }) + if err != nil { + if err == context.Canceled { + return + } + log.Error("error querying for subscription cursor messages", zap.Error(err)) + return + } + + streamMessages(resp.Messages) + + if len(resp.Messages) == 0 || resp.PagingInfo == nil || resp.PagingInfo.IdCursor == 0 { + break + } + pagingInfo = resp.PagingInfo + } + }() + } } select { diff --git a/pkg/mls/api/v1/service_test.go b/pkg/mls/api/v1/service_test.go index 4d4565ad..d098af6c 100644 --- a/pkg/mls/api/v1/service_test.go +++ b/pkg/mls/api/v1/service_test.go @@ -1,6 +1,7 @@ package api import ( + "bytes" "context" "errors" "fmt" @@ -337,7 +338,7 @@ func TestGetIdentityUpdates(t *testing.T) { require.Len(t, identityUpdates.Updates[0].Updates, 2) } -func TestSubscribeGroupMessages(t *testing.T) { +func TestSubscribeGroupMessages_WithoutCursor(t *testing.T) { ctx := context.Background() svc, _, _, cleanup := newTestService(t, ctx) defer cleanup() @@ -393,7 +394,108 @@ func TestSubscribeGroupMessages(t *testing.T) { require.Eventually(t, ctrl.Satisfied, 5*time.Second, 100*time.Millisecond) } -func TestSubscribeWelcomeMessages(t *testing.T) { +func TestSubscribeGroupMessages_WithCursor(t *testing.T) { + ctx := context.Background() + svc, _, mlsValidationService, cleanup := newTestService(t, ctx) + defer cleanup() + + groupId := []byte(test.RandomString(32)) + + // Initial message before stream starts. + mlsValidationService.mockValidateGroupMessages(groupId) + initialMsgs := []*mlsv1.GroupMessageInput{ + { + Version: &mlsv1.GroupMessageInput_V1_{ + V1: &mlsv1.GroupMessageInput_V1{ + Data: []byte("data1"), + }, + }, + }, + { + Version: &mlsv1.GroupMessageInput_V1_{ + V1: &mlsv1.GroupMessageInput_V1{ + Data: []byte("data2"), + }, + }, + }, + { + Version: &mlsv1.GroupMessageInput_V1_{ + V1: &mlsv1.GroupMessageInput_V1{ + Data: []byte("data3"), + }, + }, + }, + } + for _, msg := range initialMsgs { + _, err := svc.SendGroupMessages(ctx, &mlsv1.SendGroupMessagesRequest{ + Messages: []*mlsv1.GroupMessageInput{msg}, + }) + require.NoError(t, err) + } + + // Set of 10 messages that are included in the stream. + msgs := make([]*mlsv1.GroupMessage, 10) + for i := 0; i < 10; i++ { + msgs[i] = &mlsv1.GroupMessage{ + Version: &mlsv1.GroupMessage_V1_{ + V1: &mlsv1.GroupMessage_V1{ + Id: uint64(i + 4), + CreatedNs: uint64(i + 4), + GroupId: groupId, + Data: []byte(fmt.Sprintf("data%d", i+4)), + }, + }, + } + } + + // Set up expectations of streaming the 11 messages from cursor. + ctrl := gomock.NewController(t) + stream := NewMockMlsApi_SubscribeGroupMessagesServer(ctrl) + stream.EXPECT().SendHeader(map[string][]string{"subscribed": {"true"}}) + stream.EXPECT().Send(newGroupMessageIdAndDataEqualsMatcher(&mlsv1.GroupMessage{ + Version: &mlsv1.GroupMessage_V1_{ + V1: &mlsv1.GroupMessage_V1{ + Id: 3, + Data: []byte("data3"), + }, + }, + })).Return(nil).Times(1) + for _, msg := range msgs { + stream.EXPECT().Send(newGroupMessageEqualsMatcher(msg)).Return(nil).Times(1) + } + stream.EXPECT().Context().Return(ctx).AnyTimes() + + go func() { + err := svc.SubscribeGroupMessages(&mlsv1.SubscribeGroupMessagesRequest{ + Filters: []*mlsv1.SubscribeGroupMessagesRequest_Filter{ + { + GroupId: groupId, + IdCursor: 2, + }, + }, + }, stream) + require.NoError(t, err) + }() + time.Sleep(50 * time.Millisecond) + + // Send the 10 real-time messages. + for _, msg := range msgs { + msgB, err := proto.Marshal(msg) + require.NoError(t, err) + + err = svc.HandleIncomingWakuRelayMessage(&wakupb.WakuMessage{ + ContentTopic: topic.BuildMLSV1GroupTopic(msg.GetV1().GroupId), + Timestamp: int64(msg.GetV1().CreatedNs), + Payload: msgB, + }) + require.NoError(t, err) + } + + // Expectations should eventually be satisfied. + require.Eventually(t, ctrl.Satisfied, 5*time.Second, 100*time.Millisecond) +} + +func TestSubscribeWelcomeMessages_WithoutCursor(t *testing.T) { ctx := context.Background() svc, _, _, cleanup := newTestService(t, ctx) defer cleanup() @@ -450,6 +552,116 @@ func TestSubscribeWelcomeMessages(t *testing.T) { require.Eventually(t, ctrl.Satisfied, 5*time.Second, 100*time.Millisecond) } +func TestSubscribeWelcomeMessages_WithCursor(t *testing.T) { + ctx := context.Background() + svc, _, _, cleanup := newTestService(t, ctx) + defer cleanup() + + installationKey := []byte(test.RandomString(32)) + hpkePublicKey := []byte(test.RandomString(32)) + + // Initial message before stream starts. + initialMsgs := []*mlsv1.WelcomeMessageInput{ + { + Version: &mlsv1.WelcomeMessageInput_V1_{ + V1: &mlsv1.WelcomeMessageInput_V1{ + InstallationKey: installationKey, + HpkePublicKey: hpkePublicKey, + Data: []byte("data1"), + }, + }, + }, + { + Version: &mlsv1.WelcomeMessageInput_V1_{ + V1: &mlsv1.WelcomeMessageInput_V1{ + InstallationKey: installationKey, + HpkePublicKey: hpkePublicKey, + Data: []byte("data2"), + }, + }, + }, + { + Version: &mlsv1.WelcomeMessageInput_V1_{ + V1: &mlsv1.WelcomeMessageInput_V1{ + InstallationKey: installationKey, + HpkePublicKey: hpkePublicKey, + Data: []byte("data3"), + }, + }, + }, + } + for _, msg := range initialMsgs { + _, err := svc.SendWelcomeMessages(ctx, &mlsv1.SendWelcomeMessagesRequest{ + Messages: []*mlsv1.WelcomeMessageInput{msg}, + }) + require.NoError(t, err) + } + + // Set of 10 messages that are included in the stream. + msgs := make([]*mlsv1.WelcomeMessage, 10) + for i := 0; i < 10; i++ { + msgs[i] = &mlsv1.WelcomeMessage{ + Version: &mlsv1.WelcomeMessage_V1_{ + V1: &mlsv1.WelcomeMessage_V1{ + Id: uint64(i + 4), + CreatedNs: uint64(i + 4), + InstallationKey: installationKey, + HpkePublicKey: hpkePublicKey, + Data: []byte(fmt.Sprintf("data%d", i+4)), + }, + }, + } + } + + // Set up expectations of streaming the 11 messages from cursor. + ctrl := gomock.NewController(t) + stream := NewMockMlsApi_SubscribeWelcomeMessagesServer(ctrl) + stream.EXPECT().SendHeader(map[string][]string{"subscribed": {"true"}}) + stream.EXPECT().Send(newWelcomeMessageEqualsMatcherWithoutTimestamp(&mlsv1.WelcomeMessage{ + Version: &mlsv1.WelcomeMessage_V1_{ + V1: &mlsv1.WelcomeMessage_V1{ + Id: 3, + InstallationKey: installationKey, + HpkePublicKey: hpkePublicKey, + Data: []byte("data3"), + }, + }, + })).Return(nil).Times(1) + for _, msg := range msgs { + stream.EXPECT().Send(newWelcomeMessageEqualsMatcher(msg)).Return(nil).Times(1) + } + stream.EXPECT().Context().Return(ctx).AnyTimes() + + go func() { + err := svc.SubscribeWelcomeMessages(&mlsv1.SubscribeWelcomeMessagesRequest{ + Filters: []*mlsv1.SubscribeWelcomeMessagesRequest_Filter{ + { + InstallationKey: installationKey, + IdCursor: 2, + }, + }, + }, stream) + require.NoError(t, err) + }() + time.Sleep(50 * time.Millisecond) + + // Send the 10 real-time messages. + for _, msg := range msgs { + msgB, err := proto.Marshal(msg) + require.NoError(t, err) + + err = svc.HandleIncomingWakuRelayMessage(&wakupb.WakuMessage{ + ContentTopic: topic.BuildMLSV1WelcomeTopic(msg.GetV1().InstallationKey), + Timestamp: int64(msg.GetV1().CreatedNs), + Payload: msgB, + }) + require.NoError(t, err) + } + + // Expectations should eventually be satisfied. + require.Eventually(t, ctrl.Satisfied, 5*time.Second, 100*time.Millisecond) +} + type groupMessageEqualsMatcher struct { obj *mlsv1.GroupMessage } @@ -466,6 +678,23 @@ func (m *groupMessageEqualsMatcher) String() string { return m.obj.String() } +type groupMessageIdAndDataEqualsMatcher struct { + obj *mlsv1.GroupMessage +} + +func newGroupMessageIdAndDataEqualsMatcher(obj *mlsv1.GroupMessage) *groupMessageIdAndDataEqualsMatcher { + return &groupMessageIdAndDataEqualsMatcher{obj} +} + +func (m *groupMessageIdAndDataEqualsMatcher) Matches(obj interface{}) bool { + return m.obj.GetV1().Id == obj.(*mlsv1.GroupMessage).GetV1().Id && + bytes.Equal(m.obj.GetV1().Data, obj.(*mlsv1.GroupMessage).GetV1().Data) +} + +func (m *groupMessageIdAndDataEqualsMatcher) String() string { + return m.obj.String() +} + type welcomeMessageEqualsMatcher struct { obj *mlsv1.WelcomeMessage } @@ -481,3 +710,22 @@ func (m *welcomeMessageEqualsMatcher) Matches(obj interface{}) bool { func (m *welcomeMessageEqualsMatcher) String() string { return m.obj.String() } + +type welcomeMessageEqualsMatcherWithoutTimestamp struct { + obj *mlsv1.WelcomeMessage +} + +func newWelcomeMessageEqualsMatcherWithoutTimestamp(obj *mlsv1.WelcomeMessage) *welcomeMessageEqualsMatcherWithoutTimestamp { + return &welcomeMessageEqualsMatcherWithoutTimestamp{obj} +} + +func (m *welcomeMessageEqualsMatcherWithoutTimestamp) Matches(obj interface{}) bool { + return m.obj.GetV1().Id == obj.(*mlsv1.WelcomeMessage).GetV1().Id && + bytes.Equal(m.obj.GetV1().InstallationKey, obj.(*mlsv1.WelcomeMessage).GetV1().InstallationKey) && + bytes.Equal(m.obj.GetV1().HpkePublicKey, obj.(*mlsv1.WelcomeMessage).GetV1().HpkePublicKey) && + bytes.Equal(m.obj.GetV1().Data, obj.(*mlsv1.WelcomeMessage).GetV1().Data) +} + +func (m *welcomeMessageEqualsMatcherWithoutTimestamp) String() string { + return m.obj.String() +} diff --git a/pkg/mls/store/store.go b/pkg/mls/store/store.go index 8810f703..64c94b32 100644 --- a/pkg/mls/store/store.go +++ b/pkg/mls/store/store.go @@ -328,10 +328,11 @@ func (s *Store) QueryWelcomeMessagesV1(ctx context.Context, req *mlsv1.QueryWelc messages = append(messages, &mlsv1.WelcomeMessage{ Version: &mlsv1.WelcomeMessage_V1_{ V1: &mlsv1.WelcomeMessage_V1{ - Id: msg.Id, - CreatedNs: uint64(msg.CreatedAt.UnixNano()), - Data: msg.Data, - HpkePublicKey: msg.HpkePublicKey, + Id: msg.Id, + CreatedNs: uint64(msg.CreatedAt.UnixNano()), + Data: msg.Data, + InstallationKey: msg.InstallationKey, + HpkePublicKey: msg.HpkePublicKey, }, }, })