From 909ce3805526bd825bc4f228f2edf7696e4ff2a5 Mon Sep 17 00:00:00 2001 From: Nicholas Molnar <65710+neekolas@users.noreply.github.com> Date: Mon, 30 Dec 2024 18:42:02 -0300 Subject: [PATCH] Implement subscriptions --- pkg/api/message/v1/subscription.go | 2 +- pkg/api/server.go | 2 +- pkg/identity/api/v1/identity_service.go | 61 ++++++++++++++++++-- pkg/identity/api/v1/identity_service_test.go | 13 ++++- pkg/mls/api/v1/service.go | 29 ++++------ pkg/topic/mls.go | 12 ++++ 6 files changed, 94 insertions(+), 25 deletions(-) diff --git a/pkg/api/message/v1/subscription.go b/pkg/api/message/v1/subscription.go index 74ea42a8..2f8308e5 100644 --- a/pkg/api/message/v1/subscription.go +++ b/pkg/api/message/v1/subscription.go @@ -153,5 +153,5 @@ func (sub *subscription) Unsubscribe() { } func isValidSubscribeAllTopic(contentTopic string) bool { - return strings.HasPrefix(contentTopic, validXMTPTopicPrefix) || topic.IsMLSV1(contentTopic) + return strings.HasPrefix(contentTopic, validXMTPTopicPrefix) || (topic.IsMLSV1(contentTopic) && !topic.IsAssociationChanged(contentTopic)) } diff --git a/pkg/api/server.go b/pkg/api/server.go index be575332..93f60f30 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -175,7 +175,7 @@ func (s *Server) startGRPC() error { } mlsv1pb.RegisterMlsApiServer(grpcServer, s.mlsv1) - s.identityv1, err = identityv1.NewService(s.Log, s.Config.MLSStore, s.Config.MLSValidator) + s.identityv1, err = identityv1.NewService(s.Log, s.Config.MLSStore, s.Config.MLSValidator, s.natsServer) if err != nil { return errors.Wrap(err, "creating identity service") } diff --git a/pkg/identity/api/v1/identity_service.go b/pkg/identity/api/v1/identity_service.go index 98d992cc..2deff489 100644 --- a/pkg/identity/api/v1/identity_service.go +++ b/pkg/identity/api/v1/identity_service.go @@ -3,13 +3,18 @@ package api import ( "context" + "github.com/nats-io/nats-server/v2/server" + "github.com/nats-io/nats.go" + "github.com/xmtp/xmtp-node-go/pkg/envelopes" mlsstore "github.com/xmtp/xmtp-node-go/pkg/mls/store" "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" api "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1" identity "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1" + v1proto "github.com/xmtp/xmtp-node-go/pkg/proto/message_api/v1" + "github.com/xmtp/xmtp-node-go/pkg/topic" "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "google.golang.org/grpc/metadata" + pb "google.golang.org/protobuf/proto" ) type Service struct { @@ -20,10 +25,11 @@ type Service struct { validationService mlsvalidate.MLSValidationService ctx context.Context + nc *nats.Conn ctxCancel func() } -func NewService(log *zap.Logger, store mlsstore.MlsStore, validationService mlsvalidate.MLSValidationService) (s *Service, err error) { +func NewService(log *zap.Logger, store mlsstore.MlsStore, validationService mlsvalidate.MLSValidationService, natsServer *server.Server) (s *Service, err error) { s = &Service{ log: log.Named("identity"), store: store, @@ -31,6 +37,11 @@ func NewService(log *zap.Logger, store mlsstore.MlsStore, validationService mlsv } s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + s.nc, err = nats.Connect(natsServer.ClientURL()) + if err != nil { + return nil, err + } + s.log.Info("Starting identity service") return s, nil } @@ -109,5 +120,47 @@ func (s *Service) SubscribeAssociationChanges(req *identity.SubscribeAssociation log := s.log.Named("subscribe-association-changes") log.Info("subscription started") - return status.Errorf(codes.Unimplemented, "method SubscribeAssociationChanges not implemented") + _ = stream.SendHeader(metadata.Pairs("subscribed", "true")) + + natsSubject := buildNatsSubjectForAssociationChanges() + sub, err := s.nc.Subscribe(natsSubject, func(natsMsg *nats.Msg) { + msg, err := getAssociationChangedMessageFromNats(natsMsg) + if err != nil { + log.Error("parsing message", zap.Error(err)) + } + if err = stream.Send(msg); err != nil { + log.Warn("sending message to stream", zap.Error(err)) + } + }) + + if err != nil { + log.Error("error subscribing to nats", zap.Error(err)) + return err + } + + defer func() { + _ = sub.Unsubscribe() + }() + + return nil +} + +func buildNatsSubjectForAssociationChanges() string { + return envelopes.BuildNatsSubject(topic.BuildAssociationChangedTopic()) +} + +func getAssociationChangedMessageFromNats(natsMsg *nats.Msg) (*identity.SubscribeAssociationChangesResponse, error) { + var env v1proto.Envelope + err := pb.Unmarshal(natsMsg.Data, &env) + if err != nil { + return nil, err + } + + var msg identity.SubscribeAssociationChangesResponse + err = pb.Unmarshal(env.Message, &msg) + if err != nil { + return nil, err + } + + return &msg, nil } diff --git a/pkg/identity/api/v1/identity_service_test.go b/pkg/identity/api/v1/identity_service_test.go index a3b9797f..bbede814 100644 --- a/pkg/identity/api/v1/identity_service_test.go +++ b/pkg/identity/api/v1/identity_service_test.go @@ -3,7 +3,9 @@ package api import ( "context" "testing" + "time" + "github.com/nats-io/nats-server/v2/server" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/uptrace/bun" @@ -73,9 +75,18 @@ func newTestService(t *testing.T, ctx context.Context) (*Service, *bun.DB, func( DB: db, }) require.NoError(t, err) + natsServer, err := server.NewServer(&server.Options{ + Port: server.RANDOM_PORT, + }) + require.NoError(t, err) + go natsServer.Start() + if !natsServer.ReadyForConnections(4 * time.Second) { + t.Fail() + } + require.NoError(t, err) mlsValidationService := newMockedValidationService() - svc, err := NewService(log, store, mlsValidationService) + svc, err := NewService(log, store, mlsValidationService, natsServer) require.NoError(t, err) return svc, db, func() { diff --git a/pkg/mls/api/v1/service.go b/pkg/mls/api/v1/service.go index 8dfafb3e..4e1a1015 100644 --- a/pkg/mls/api/v1/service.go +++ b/pkg/mls/api/v1/service.go @@ -73,26 +73,19 @@ func (s *Service) Close() { } func (s *Service) HandleIncomingWakuRelayMessage(wakuMsg *wakupb.WakuMessage) error { - if topic.IsMLSV1Group(wakuMsg.ContentTopic) { - s.log.Info("received group message from waku relay", zap.String("topic", wakuMsg.ContentTopic)) - - // Build the nats subject from the topic - natsSubject := envelopes.BuildNatsSubject(wakuMsg.ContentTopic) - s.log.Info("publishing to nats subject from relay", zap.String("subject", natsSubject)) - env := envelopes.BuildEnvelope(wakuMsg) - envB, err := pb.Marshal(env) - if err != nil { - return err + isMlsGroup := topic.IsMLSV1Group(wakuMsg.ContentTopic) + isMlsWelcome := topic.IsMLSV1Welcome(wakuMsg.ContentTopic) + isAssociationChanged := topic.IsAssociationChanged(wakuMsg.ContentTopic) + if isMlsGroup || isMlsWelcome || isAssociationChanged { + if isMlsGroup { + s.log.Info("received group message from waku relay", zap.String("topic", wakuMsg.ContentTopic)) + } else if isMlsWelcome { + s.log.Info("received welcome message from waku relay", zap.String("topic", wakuMsg.ContentTopic)) + } else if isAssociationChanged { + s.log.Info("received association changed message from waku relay", zap.String("topic", wakuMsg.ContentTopic)) } - err = s.nc.Publish(natsSubject, envB) - if err != nil { - s.log.Error("error publishing to nats", zap.Error(err)) - return err - } - } else if topic.IsMLSV1Welcome(wakuMsg.ContentTopic) { - s.log.Info("received welcome message from waku relay", zap.String("topic", wakuMsg.ContentTopic)) - + // Build the nats subject from the topic natsSubject := envelopes.BuildNatsSubject(wakuMsg.ContentTopic) s.log.Info("publishing to nats subject from relay", zap.String("subject", natsSubject)) env := envelopes.BuildEnvelope(wakuMsg) diff --git a/pkg/topic/mls.go b/pkg/topic/mls.go index 33f3ecbf..eeee6d6b 100644 --- a/pkg/topic/mls.go +++ b/pkg/topic/mls.go @@ -7,6 +7,10 @@ import ( const mlsv1Prefix = "/xmtp/mls/1/" +var ( + AssociationChangedTopic = BuildAssociationChangedTopic() +) + func IsMLSV1(topic string) bool { return strings.HasPrefix(topic, mlsv1Prefix) } @@ -19,6 +23,10 @@ func IsMLSV1Welcome(topic string) bool { return strings.HasPrefix(topic, mlsv1Prefix+"w-") } +func IsAssociationChanged(topic string) bool { + return topic == AssociationChangedTopic +} + func BuildMLSV1GroupTopic(groupId []byte) string { return fmt.Sprintf("%sg-%x/proto", mlsv1Prefix, groupId) } @@ -26,3 +34,7 @@ func BuildMLSV1GroupTopic(groupId []byte) string { func BuildMLSV1WelcomeTopic(installationId []byte) string { return fmt.Sprintf("%sw-%x/proto", mlsv1Prefix, installationId) } + +func BuildAssociationChangedTopic() string { + return fmt.Sprintf("%sassociations/proto", mlsv1Prefix) +}