From fb4f2fafc280fa018c90f4714846bd2dfa95fedb Mon Sep 17 00:00:00 2001 From: Steven Normore Date: Fri, 12 Jan 2024 16:13:16 -0500 Subject: [PATCH] feat: implement mls subscribe group/welcome messages --- go.mod | 2 +- go.sum | 4 +- pkg/api/message/v1/service.go | 67 +++------ pkg/api/server.go | 59 +++++++- pkg/mls/api/v1/service.go | 264 +++++++++++++++++++++++++++------ pkg/mls/api/v1/service_test.go | 39 ++++- pkg/topic/mls.go | 28 ++++ pkg/topic/topic.go | 9 -- 8 files changed, 363 insertions(+), 109 deletions(-) create mode 100644 pkg/topic/mls.go diff --git a/go.mod b/go.mod index 50656702..357928a7 100644 --- a/go.mod +++ b/go.mod @@ -30,7 +30,7 @@ require ( github.com/uptrace/bun/driver/pgdriver v1.1.16 github.com/waku-org/go-waku v0.8.0 github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3 - github.com/xmtp/proto/v3 v3.37.1-0.20240112031043-fd75b4bf81f8 + github.com/xmtp/proto/v3 v3.37.1-0.20240112125235-f02fe8d0f1a0 github.com/yoheimuta/protolint v0.39.0 go.uber.org/zap v1.24.0 golang.org/x/sync v0.3.0 diff --git a/go.sum b/go.sum index b87f10e7..bd5d139b 100644 --- a/go.sum +++ b/go.sum @@ -1146,8 +1146,8 @@ github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0 github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6/go.mod h1:ce1O1j6UtZfjr22oyGxGLbauSBp2YVXpARAosm7dHBg= github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3 h1:wzUffJGCTBGXIDyNU+1UBu1fn2Nzo+OQzM1pLrheh58= github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3/go.mod h1:bJREWk+NDnZYjgLQdAi8SUWuq/5pkMme4GqiffEhUF4= -github.com/xmtp/proto/v3 v3.37.1-0.20240112031043-fd75b4bf81f8 h1:r7KYIg8OtDLDHGwlEHo5SiOkEM86C33DHtZsQ9B0pKM= -github.com/xmtp/proto/v3 v3.37.1-0.20240112031043-fd75b4bf81f8/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= +github.com/xmtp/proto/v3 v3.37.1-0.20240112125235-f02fe8d0f1a0 h1:eGNiXDTiXcXTf5ne4HACbqbHaQrVlRz2hwcn05E7v8U= +github.com/xmtp/proto/v3 v3.37.1-0.20240112125235-f02fe8d0f1a0/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= github.com/yoheimuta/go-protoparser/v4 v4.6.0 h1:uvz1e9/5Ihsm4Ku8AJeDImTpirKmIxubZdSn0QJNdnw= github.com/yoheimuta/go-protoparser/v4 v4.6.0/go.mod h1:AHNNnSWnb0UoL4QgHPiOAg2BniQceFscPI5X/BZNHl8= diff --git a/pkg/api/message/v1/service.go b/pkg/api/message/v1/service.go index bb1b256a..984a36a7 100644 --- a/pkg/api/message/v1/service.go +++ b/pkg/api/message/v1/service.go @@ -13,16 +13,13 @@ import ( "github.com/nats-io/nats-server/v2/server" "github.com/nats-io/nats.go" "github.com/pkg/errors" - wakunode "github.com/waku-org/go-waku/waku/v2/node" wakupb "github.com/waku-org/go-waku/waku/v2/protocol/pb" - wakurelay "github.com/waku-org/go-waku/waku/v2/protocol/relay" proto "github.com/xmtp/proto/v3/go/message_api/v1" apicontext "github.com/xmtp/xmtp-node-go/pkg/api/message/v1/context" "github.com/xmtp/xmtp-node-go/pkg/logging" "github.com/xmtp/xmtp-node-go/pkg/metrics" "github.com/xmtp/xmtp-node-go/pkg/store" "github.com/xmtp/xmtp-node-go/pkg/topic" - "github.com/xmtp/xmtp-node-go/pkg/tracing" "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -45,24 +42,24 @@ type Service struct { // Configured as constructor options. log *zap.Logger - waku *wakunode.WakuNode store *store.Store + publishToWakuRelay func(context.Context, *wakupb.WakuMessage) error + // Configured internally. ctx context.Context ctxCancel func() wg sync.WaitGroup - relaySub *wakurelay.Subscription ns *server.Server nc *nats.Conn } -func NewService(node *wakunode.WakuNode, logger *zap.Logger, store *store.Store) (s *Service, err error) { +func NewService(log *zap.Logger, store *store.Store, publishToWakuRelay func(context.Context, *wakupb.WakuMessage) error) (s *Service, err error) { s = &Service{ - waku: node, - log: logger.Named("message/v1"), - store: store, + log: log.Named("message/v1"), + store: store, + publishToWakuRelay: publishToWakuRelay, } s.ctx, s.ctxCancel = context.WithCancel(context.Background()) @@ -82,44 +79,11 @@ func NewService(node *wakunode.WakuNode, logger *zap.Logger, store *store.Store) return nil, err } - // Initialize waku relay subscription. - s.relaySub, err = s.waku.Relay().Subscribe(s.ctx) - if err != nil { - return nil, errors.Wrap(err, "subscribing to relay") - } - tracing.GoPanicWrap(s.ctx, &s.wg, "broadcast", func(ctx context.Context) { - for { - select { - case <-ctx.Done(): - return - case wakuEnv := <-s.relaySub.Ch: - if wakuEnv == nil { - continue - } - env := buildEnvelope(wakuEnv.Message()) - - envB, err := pb.Marshal(env) - if err != nil { - s.log.Error("marshalling envelope", zap.Error(err)) - continue - } - err = s.nc.Publish(buildNatsSubject(env.ContentTopic), envB) - if err != nil { - s.log.Error("publishing envelope to local nats", zap.Error(err)) - continue - } - } - } - }) - return s, nil } func (s *Service) Close() { s.log.Info("closing") - if s.relaySub != nil { - s.relaySub.Unsubscribe() - } if s.ctxCancel != nil { s.ctxCancel() @@ -136,6 +100,22 @@ func (s *Service) Close() { s.log.Info("closed") } +func (s *Service) HandleIncomingWakuRelayMessage(msg *wakupb.WakuMessage) error { + env := buildEnvelope(msg) + + envB, err := pb.Marshal(env) + if err != nil { + return err + } + + err = s.nc.Publish(buildNatsSubject(env.ContentTopic), envB) + if err != nil { + return err + } + + return nil +} + func (s *Service) Publish(ctx context.Context, req *proto.PublishRequest) (*proto.PublishResponse, error) { for _, env := range req.Envelopes { log := s.log.Named("publish").With(zap.String("content_topic", env.ContentTopic)) @@ -156,7 +136,7 @@ func (s *Service) Publish(ctx context.Context, req *proto.PublishRequest) (*prot } } - _, err := s.waku.Relay().Publish(ctx, &wakupb.WakuMessage{ + err := s.publishToWakuRelay(ctx, &wakupb.WakuMessage{ ContentTopic: env.ContentTopic, Timestamp: int64(env.TimestampNs), Payload: env.Message, @@ -164,6 +144,7 @@ func (s *Service) Publish(ctx context.Context, req *proto.PublishRequest) (*prot if err != nil { return nil, status.Errorf(codes.Internal, err.Error()) } + metrics.EmitPublishedEnvelope(ctx, log, env) } return &proto.PublishResponse{}, nil diff --git a/pkg/api/server.go b/pkg/api/server.go index fd63ae3b..0d6fdbc3 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -14,10 +14,13 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/pkg/errors" swgui "github.com/swaggest/swgui/v3" + wakupb "github.com/waku-org/go-waku/waku/v2/protocol/pb" + wakurelay "github.com/waku-org/go-waku/waku/v2/protocol/relay" proto "github.com/xmtp/proto/v3/go/message_api/v1" mlsv1pb "github.com/xmtp/proto/v3/go/mls/api/v1" messagev1openapi "github.com/xmtp/proto/v3/openapi/message_api/v1" "github.com/xmtp/xmtp-node-go/pkg/ratelimiter" + "github.com/xmtp/xmtp-node-go/pkg/topic" "github.com/xmtp/xmtp-node-go/pkg/tracing" "google.golang.org/grpc/health" healthgrpc "google.golang.org/grpc/health/grpc_health_v1" @@ -48,6 +51,8 @@ type Server struct { mlsv1 *mlsv1.Service wg sync.WaitGroup ctx context.Context + ctxCancel func() + wakuRelaySub *wakurelay.Subscription authorizer *WalletAuthorizer } @@ -61,7 +66,7 @@ func New(config *Config) (*Server, error) { Config: config, } - s.ctx = context.Background() + s.ctx, s.ctxCancel = context.WithCancel(context.Background()) // Start gRPC services. err := s.startGRPC() @@ -123,7 +128,12 @@ func (s *Server) startGRPC() error { healthcheck := health.NewServer() healthgrpc.RegisterHealthServer(grpcServer, healthcheck) - s.messagev1, err = messagev1.NewService(s.Waku, s.Log, s.Store) + publishToWakuRelay := func(ctx context.Context, msg *wakupb.WakuMessage) error { + _, err := s.Waku.Relay().Publish(ctx, msg) + return err + } + + s.messagev1, err = messagev1.NewService(s.Log, s.Store, publishToWakuRelay) if err != nil { return errors.Wrap(err, "creating message service") } @@ -131,12 +141,43 @@ func (s *Server) startGRPC() error { // Enable the MLS server if a store is provided if s.Config.MLSStore != nil && s.Config.MLSValidator != nil && s.Config.EnableMls { - s.mlsv1, err = mlsv1.NewService(s.Waku, s.Log, s.Config.MLSStore, s.Config.MLSValidator) + s.mlsv1, err = mlsv1.NewService(s.Log, s.Config.MLSStore, s.Config.MLSValidator, publishToWakuRelay) if err != nil { return errors.Wrap(err, "creating mls service") } mlsv1pb.RegisterMlsApiServer(grpcServer, s.mlsv1) } + + // Initialize waku relay subscription. + s.wakuRelaySub, err = s.Waku.Relay().Subscribe(s.ctx) + if err != nil { + return errors.Wrap(err, "subscribing to relay") + } + tracing.GoPanicWrap(s.ctx, &s.wg, "broadcast", func(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case wakuEnv := <-s.wakuRelaySub.Ch: + if wakuEnv == nil || wakuEnv.Message() == nil { + continue + } + wakuMsg := wakuEnv.Message() + + if topic.IsMLSV1(wakuMsg.ContentTopic) { + if s.mlsv1 != nil { + s.mlsv1.HandleIncomingWakuRelayMessage(wakuEnv.Message()) + } + } else { + if s.messagev1 != nil { + s.messagev1.HandleIncomingWakuRelayMessage(wakuEnv.Message()) + } + } + + } + } + }) + prometheus.Register(grpcServer) tracing.GoPanicWrap(s.ctx, &s.wg, "grpc", func(ctx context.Context) { @@ -215,9 +256,21 @@ func (s *Server) startHTTP() error { func (s *Server) Close() { s.Log.Info("closing") + + if s.ctxCancel != nil { + s.ctxCancel() + } + + if s.wakuRelaySub != nil { + s.wakuRelaySub.Unsubscribe() + } + if s.messagev1 != nil { s.messagev1.Close() } + if s.mlsv1 != nil { + s.mlsv1.Close() + } if s.httpListener != nil { err := s.httpListener.Close() diff --git a/pkg/mls/api/v1/service.go b/pkg/mls/api/v1/service.go index 61c229e6..5040d094 100644 --- a/pkg/mls/api/v1/service.go +++ b/pkg/mls/api/v1/service.go @@ -3,42 +3,125 @@ package api import ( "context" "encoding/hex" - - wakunode "github.com/waku-org/go-waku/waku/v2/node" + "errors" + "fmt" + "hash/fnv" + "sync" + "time" + + "github.com/nats-io/nats-server/v2/server" + "github.com/nats-io/nats.go" wakupb "github.com/waku-org/go-waku/waku/v2/protocol/pb" - proto "github.com/xmtp/proto/v3/go/mls/api/v1" + mlsv1 "github.com/xmtp/proto/v3/go/mls/api/v1" "github.com/xmtp/xmtp-node-go/pkg/mls/store" mlsstore "github.com/xmtp/xmtp-node-go/pkg/mls/store" "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" "github.com/xmtp/xmtp-node-go/pkg/topic" "go.uber.org/zap" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + pb "google.golang.org/protobuf/proto" emptypb "google.golang.org/protobuf/types/known/emptypb" ) type Service struct { - proto.UnimplementedMlsApiServer + mlsv1.UnimplementedMlsApiServer log *zap.Logger - waku *wakunode.WakuNode store mlsstore.MlsStore validationService mlsvalidate.MLSValidationService + + publishToWakuRelay func(context.Context, *wakupb.WakuMessage) error + + ns *server.Server + nc *nats.Conn + + ctx context.Context + ctxCancel func() } -func NewService(node *wakunode.WakuNode, logger *zap.Logger, store mlsstore.MlsStore, validationService mlsvalidate.MLSValidationService) (s *Service, err error) { +func NewService(log *zap.Logger, store mlsstore.MlsStore, validationService mlsvalidate.MLSValidationService, publishToWakuRelay func(context.Context, *wakupb.WakuMessage) error) (s *Service, err error) { s = &Service{ - log: logger.Named("mls/v1"), - waku: node, - store: store, - validationService: validationService, + log: log.Named("mls/v1"), + store: store, + validationService: validationService, + publishToWakuRelay: publishToWakuRelay, + } + s.ctx, s.ctxCancel = context.WithCancel(context.Background()) + + // Initialize nats for subscriptions. + s.ns, err = server.NewServer(&server.Options{ + Port: server.RANDOM_PORT, + }) + if err != nil { + return nil, err + } + go s.ns.Start() + if !s.ns.ReadyForConnections(4 * time.Second) { + return nil, errors.New("nats not ready") + } + s.nc, err = nats.Connect(s.ns.ClientURL()) + if err != nil { + return nil, err } s.log.Info("Starting MLS service") return s, nil } -func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterInstallationRequest) (*proto.RegisterInstallationResponse, error) { +func (s *Service) Close() { + s.log.Info("closing") + + if s.ctxCancel != nil { + s.ctxCancel() + } + + if s.nc != nil { + s.nc.Close() + } + if s.ns != nil { + s.ns.Shutdown() + } + + s.log.Info("closed") +} + +func (s *Service) HandleIncomingWakuRelayMessage(wakuMsg *wakupb.WakuMessage) error { + if topic.IsMLSV1Group(wakuMsg.ContentTopic) { + var msg mlsv1.GroupMessage + err := pb.Unmarshal(wakuMsg.Payload, &msg) + if err != nil { + return err + } + if msg.GetV1() == nil { + return nil + } + err = s.nc.Publish(buildNatsSubjectForGroupMessages(msg.GetV1().GroupId), wakuMsg.Payload) + if err != nil { + return err + } + } else if topic.IsMLSV1Welcome(wakuMsg.ContentTopic) { + var msg mlsv1.WelcomeMessage + err := pb.Unmarshal(wakuMsg.Payload, &msg) + if err != nil { + return err + } + if msg.GetV1() == nil { + return nil + } + err = s.nc.Publish(buildNatsSubjectForWelcomeMessages(msg.GetV1().InstallationKey), wakuMsg.Payload) + if err != nil { + return err + } + } else { + s.log.Info("received unknown mls message type from waku relay", zap.String("topic", wakuMsg.ContentTopic)) + } + + return errors.New("not implemented") +} + +func (s *Service) RegisterInstallation(ctx context.Context, req *mlsv1.RegisterInstallationRequest) (*mlsv1.RegisterInstallationResponse, error) { if err := validateRegisterInstallationRequest(req); err != nil { return nil, err } @@ -59,12 +142,12 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterI return nil, err } - return &proto.RegisterInstallationResponse{ + return &mlsv1.RegisterInstallationResponse{ InstallationKey: installationId, }, nil } -func (s *Service) FetchKeyPackages(ctx context.Context, req *proto.FetchKeyPackagesRequest) (*proto.FetchKeyPackagesResponse, error) { +func (s *Service) FetchKeyPackages(ctx context.Context, req *mlsv1.FetchKeyPackagesRequest) (*mlsv1.FetchKeyPackagesResponse, error) { ids := req.InstallationKeys installations, err := s.store.FetchKeyPackages(ctx, ids) if err != nil { @@ -75,7 +158,7 @@ func (s *Service) FetchKeyPackages(ctx context.Context, req *proto.FetchKeyPacka keyPackageMap[string(id)] = idx } - resPackages := make([]*proto.FetchKeyPackagesResponse_KeyPackage, len(ids)) + resPackages := make([]*mlsv1.FetchKeyPackagesResponse_KeyPackage, len(ids)) for _, installation := range installations { idx, ok := keyPackageMap[string(installation.ID)] @@ -83,17 +166,17 @@ func (s *Service) FetchKeyPackages(ctx context.Context, req *proto.FetchKeyPacka return nil, status.Errorf(codes.Internal, "could not find key package for installation") } - resPackages[idx] = &proto.FetchKeyPackagesResponse_KeyPackage{ + resPackages[idx] = &mlsv1.FetchKeyPackagesResponse_KeyPackage{ KeyPackageTlsSerialized: installation.KeyPackage, } } - return &proto.FetchKeyPackagesResponse{ + return &mlsv1.FetchKeyPackagesResponse{ KeyPackages: resPackages, }, nil } -func (s *Service) UploadKeyPackage(ctx context.Context, req *proto.UploadKeyPackageRequest) (res *emptypb.Empty, err error) { +func (s *Service) UploadKeyPackage(ctx context.Context, req *mlsv1.UploadKeyPackageRequest) (res *emptypb.Empty, err error) { if err = validateUploadKeyPackageRequest(req); err != nil { return nil, err } @@ -115,11 +198,11 @@ func (s *Service) UploadKeyPackage(ctx context.Context, req *proto.UploadKeyPack return &emptypb.Empty{}, nil } -func (s *Service) RevokeInstallation(ctx context.Context, req *proto.RevokeInstallationRequest) (*emptypb.Empty, error) { +func (s *Service) RevokeInstallation(ctx context.Context, req *mlsv1.RevokeInstallationRequest) (*emptypb.Empty, error) { return nil, status.Errorf(codes.Unimplemented, "unimplemented") } -func (s *Service) GetIdentityUpdates(ctx context.Context, req *proto.GetIdentityUpdatesRequest) (res *proto.GetIdentityUpdatesResponse, err error) { +func (s *Service) GetIdentityUpdates(ctx context.Context, req *mlsv1.GetIdentityUpdatesRequest) (res *mlsv1.GetIdentityUpdatesResponse, err error) { if err = validateGetIdentityUpdatesRequest(req); err != nil { return nil, err } @@ -130,12 +213,12 @@ func (s *Service) GetIdentityUpdates(ctx context.Context, req *proto.GetIdentity return nil, status.Errorf(codes.Internal, "failed to get identity updates: %s", err) } - resUpdates := make([]*proto.GetIdentityUpdatesResponse_WalletUpdates, len(accountAddresses)) + resUpdates := make([]*mlsv1.GetIdentityUpdatesResponse_WalletUpdates, len(accountAddresses)) for i, accountAddress := range accountAddresses { walletUpdates := updates[accountAddress] - resUpdates[i] = &proto.GetIdentityUpdatesResponse_WalletUpdates{ - Updates: []*proto.GetIdentityUpdatesResponse_Update{}, + resUpdates[i] = &mlsv1.GetIdentityUpdatesResponse_WalletUpdates{ + Updates: []*mlsv1.GetIdentityUpdatesResponse_Update{}, } for _, walletUpdate := range walletUpdates { @@ -143,12 +226,12 @@ func (s *Service) GetIdentityUpdates(ctx context.Context, req *proto.GetIdentity } } - return &proto.GetIdentityUpdatesResponse{ + return &mlsv1.GetIdentityUpdatesResponse{ Updates: resUpdates, }, nil } -func (s *Service) SendGroupMessages(ctx context.Context, req *proto.SendGroupMessagesRequest) (res *emptypb.Empty, err error) { +func (s *Service) SendGroupMessages(ctx context.Context, req *mlsv1.SendGroupMessagesRequest) (res *emptypb.Empty, err error) { if err = validateSendGroupMessagesRequest(req); err != nil { return nil, err } @@ -179,21 +262,34 @@ func (s *Service) SendGroupMessages(ctx context.Context, req *proto.SendGroupMes return nil, status.Errorf(codes.Internal, "failed to insert message: %s", err) } - wakuTopic := topic.BuildGroupTopic(decodedGroupId) - _, err = s.waku.Relay().Publish(ctx, &wakupb.WakuMessage{ - ContentTopic: wakuTopic, + msgB, err := pb.Marshal(&mlsv1.GroupMessage{ + Version: &mlsv1.GroupMessage_V1_{ + V1: &mlsv1.GroupMessage_V1{ + Id: msg.Id, + CreatedNs: uint64(msg.CreatedAt.UnixNano()), + GroupId: msg.GroupId, + Data: msg.Data, + }, + }, + }) + if err != nil { + return nil, err + } + + err = s.publishToWakuRelay(ctx, &wakupb.WakuMessage{ + ContentTopic: topic.BuildMLSV1GroupTopic(decodedGroupId), Timestamp: msg.CreatedAt.UnixNano(), - Payload: msg.Data, + Payload: msgB, }) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to publish message: %s", err) + return nil, err } } return &emptypb.Empty{}, nil } -func (s *Service) SendWelcomeMessages(ctx context.Context, req *proto.SendWelcomeMessagesRequest) (res *emptypb.Empty, err error) { +func (s *Service) SendWelcomeMessages(ctx context.Context, req *mlsv1.SendWelcomeMessagesRequest) (res *emptypb.Empty, err error) { if err = validateSendWelcomeMessagesRequest(req); err != nil { return nil, err } @@ -208,42 +304,114 @@ func (s *Service) SendWelcomeMessages(ctx context.Context, req *proto.SendWelcom return nil, status.Errorf(codes.Internal, "failed to insert message: %s", err) } - wakuTopic := topic.BuildWelcomeTopic(input.GetV1().InstallationKey) - _, err = s.waku.Relay().Publish(ctx, &wakupb.WakuMessage{ - ContentTopic: wakuTopic, + msgB, err := pb.Marshal(&mlsv1.WelcomeMessage{ + Version: &mlsv1.WelcomeMessage_V1_{ + V1: &mlsv1.WelcomeMessage_V1{ + Id: msg.Id, + CreatedNs: uint64(msg.CreatedAt.UnixNano()), + InstallationKey: msg.InstallationKey, + Data: msg.Data, + }, + }, + }) + if err != nil { + return nil, err + } + + err = s.publishToWakuRelay(ctx, &wakupb.WakuMessage{ + ContentTopic: topic.BuildMLSV1WelcomeTopic(input.GetV1().InstallationKey), Timestamp: msg.CreatedAt.UnixNano(), - Payload: msg.Data, + Payload: msgB, }) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to publish message: %s", err) + return nil, err } } return &emptypb.Empty{}, nil } -func (s *Service) QueryGroupMessages(ctx context.Context, req *proto.QueryGroupMessagesRequest) (*proto.QueryGroupMessagesResponse, error) { +func (s *Service) QueryGroupMessages(ctx context.Context, req *mlsv1.QueryGroupMessagesRequest) (*mlsv1.QueryGroupMessagesResponse, error) { return s.store.QueryGroupMessagesV1(ctx, req) } -func (s *Service) QueryWelcomeMessages(ctx context.Context, req *proto.QueryWelcomeMessagesRequest) (*proto.QueryWelcomeMessagesResponse, error) { +func (s *Service) QueryWelcomeMessages(ctx context.Context, req *mlsv1.QueryWelcomeMessagesRequest) (*mlsv1.QueryWelcomeMessagesResponse, error) { return s.store.QueryWelcomeMessagesV1(ctx, req) } -func buildIdentityUpdate(update mlsstore.IdentityUpdate) *proto.GetIdentityUpdatesResponse_Update { - base := proto.GetIdentityUpdatesResponse_Update{ +func (s *Service) SubscribeGroupMessages(req *mlsv1.SubscribeGroupMessagesRequest, stream mlsv1.MlsApi_SubscribeGroupMessagesServer) error { + log := s.log.Named("subscribe-group-messages").With(zap.Int("filters", len(req.Filters))) + + // Send a header (any header) to fix an issue with Tonic based GRPC clients. + // See: https://github.com/xmtp/libxmtp/pull/58 + _ = stream.SendHeader(metadata.Pairs("subscribed", "true")) + + var streamLock sync.Mutex + for _, filter := range req.Filters { + natsSubject := buildNatsSubjectForGroupMessages(filter.GroupId) + sub, err := s.nc.Subscribe(natsSubject, func(natsMsg *nats.Msg) { + var msg mlsv1.GroupMessage + err := pb.Unmarshal(natsMsg.Data, &msg) + if err != nil { + log.Error("parsing group message from bytes", zap.Error(err)) + return + } + func() { + streamLock.Lock() + defer streamLock.Unlock() + err := stream.Send(&msg) + if err != nil { + log.Error("sending group message to subscribe", zap.Error(err)) + } + }() + }) + if err != nil { + log.Error("error subscribing to group messages", zap.Error(err)) + return err + } + defer func() { + _ = sub.Unsubscribe() + }() + } + + select { + case <-stream.Context().Done(): + return nil + case <-s.ctx.Done(): + return nil + } +} + +func (s *Service) SubscribeWelcomeMessages(req *mlsv1.SubscribeWelcomeMessagesRequest, stream mlsv1.MlsApi_SubscribeWelcomeMessagesServer) error { + return status.Errorf(codes.Unimplemented, "method SubscribeWelcomeMessages not implemented") +} + +func buildNatsSubjectForGroupMessages(groupId []byte) string { + hasher := fnv.New64a() + hasher.Write(groupId) + return fmt.Sprintf("gm-%x", hasher.Sum64()) +} + +func buildNatsSubjectForWelcomeMessages(installationId []byte) string { + hasher := fnv.New64a() + hasher.Write(installationId) + return fmt.Sprintf("wm-%x", hasher.Sum64()) +} + +func buildIdentityUpdate(update mlsstore.IdentityUpdate) *mlsv1.GetIdentityUpdatesResponse_Update { + base := mlsv1.GetIdentityUpdatesResponse_Update{ TimestampNs: update.TimestampNs, } switch update.Kind { case mlsstore.Create: - base.Kind = &proto.GetIdentityUpdatesResponse_Update_NewInstallation{ - NewInstallation: &proto.GetIdentityUpdatesResponse_NewInstallationUpdate{ + base.Kind = &mlsv1.GetIdentityUpdatesResponse_Update_NewInstallation{ + NewInstallation: &mlsv1.GetIdentityUpdatesResponse_NewInstallationUpdate{ InstallationKey: update.InstallationKey, CredentialIdentity: update.CredentialIdentity, }, } case mlsstore.Revoke: - base.Kind = &proto.GetIdentityUpdatesResponse_Update_RevokedInstallation{ - RevokedInstallation: &proto.GetIdentityUpdatesResponse_RevokedInstallationUpdate{ + base.Kind = &mlsv1.GetIdentityUpdatesResponse_Update_RevokedInstallation{ + RevokedInstallation: &mlsv1.GetIdentityUpdatesResponse_RevokedInstallationUpdate{ InstallationKey: update.InstallationKey, }, } @@ -252,7 +420,7 @@ func buildIdentityUpdate(update mlsstore.IdentityUpdate) *proto.GetIdentityUpdat return &base } -func validateSendGroupMessagesRequest(req *proto.SendGroupMessagesRequest) error { +func validateSendGroupMessagesRequest(req *mlsv1.SendGroupMessagesRequest) error { if req == nil || len(req.Messages) == 0 { return status.Errorf(codes.InvalidArgument, "no group messages to send") } @@ -264,7 +432,7 @@ func validateSendGroupMessagesRequest(req *proto.SendGroupMessagesRequest) error return nil } -func validateSendWelcomeMessagesRequest(req *proto.SendWelcomeMessagesRequest) error { +func validateSendWelcomeMessagesRequest(req *mlsv1.SendWelcomeMessagesRequest) error { if req == nil || len(req.Messages) == 0 { return status.Errorf(codes.InvalidArgument, "no welcome messages to send") } @@ -276,21 +444,21 @@ func validateSendWelcomeMessagesRequest(req *proto.SendWelcomeMessagesRequest) e return nil } -func validateRegisterInstallationRequest(req *proto.RegisterInstallationRequest) error { +func validateRegisterInstallationRequest(req *mlsv1.RegisterInstallationRequest) error { if req == nil || req.KeyPackage == nil { return status.Errorf(codes.InvalidArgument, "no key package") } return nil } -func validateUploadKeyPackageRequest(req *proto.UploadKeyPackageRequest) error { +func validateUploadKeyPackageRequest(req *mlsv1.UploadKeyPackageRequest) error { if req == nil || req.KeyPackage == nil { return status.Errorf(codes.InvalidArgument, "no key package") } return nil } -func validateGetIdentityUpdatesRequest(req *proto.GetIdentityUpdatesRequest) error { +func validateGetIdentityUpdatesRequest(req *mlsv1.GetIdentityUpdatesRequest) error { if req == nil || len(req.AccountAddresses) == 0 { return status.Errorf(codes.InvalidArgument, "no wallet addresses to get updates for") } diff --git a/pkg/mls/api/v1/service_test.go b/pkg/mls/api/v1/service_test.go index ae8f85d7..0a8057c0 100644 --- a/pkg/mls/api/v1/service_test.go +++ b/pkg/mls/api/v1/service_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/uptrace/bun" + wakupb "github.com/waku-org/go-waku/waku/v2/protocol/pb" mlsv1 "github.com/xmtp/proto/v3/go/mls/api/v1" mlsstore "github.com/xmtp/xmtp-node-go/pkg/mls/store" "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" @@ -67,15 +68,16 @@ func newTestService(t *testing.T, ctx context.Context) (*Service, *bun.DB, *mock DB: db, }) require.NoError(t, err) - node, nodeCleanup := test.NewNode(t) mlsValidationService := newMockedValidationService() - svc, err := NewService(node, log, store, mlsValidationService) + svc, err := NewService(log, store, mlsValidationService, func(ctx context.Context, wm *wakupb.WakuMessage) error { + return nil + }) require.NoError(t, err) return svc, db, mlsValidationService, func() { + svc.Close() mlsDbCleanup() - nodeCleanup() } } @@ -329,3 +331,34 @@ func TestGetIdentityUpdates(t *testing.T) { require.Len(t, identityUpdates.Updates, 1) require.Len(t, identityUpdates.Updates[0].Updates, 2) } + +func TestSubscribeGroupMessages(t *testing.T) { + ctx := context.Background() + svc, _, mlsValidationService, cleanup := newTestService(t, ctx) + defer cleanup() + + groupId := []byte(test.RandomString(32)) + + mlsValidationService.mockValidateGroupMessages(groupId) + + _, err := svc.SendGroupMessages(ctx, &mlsv1.SendGroupMessagesRequest{ + Messages: []*mlsv1.GroupMessageInput{ + { + Version: &mlsv1.GroupMessageInput_V1_{ + V1: &mlsv1.GroupMessageInput_V1{ + Data: []byte("data"), + }, + }, + }, + }, + }) + require.NoError(t, err) + + resp, err := svc.store.QueryGroupMessagesV1(ctx, &mlsv1.QueryGroupMessagesRequest{ + GroupId: groupId, + }) + require.NoError(t, err) + require.Len(t, resp.Messages, 1) + require.Equal(t, resp.Messages[0].GetV1().Data, []byte("data")) + require.NotEmpty(t, resp.Messages[0].GetV1().CreatedNs) +} diff --git a/pkg/topic/mls.go b/pkg/topic/mls.go new file mode 100644 index 00000000..26169041 --- /dev/null +++ b/pkg/topic/mls.go @@ -0,0 +1,28 @@ +package topic + +import ( + "fmt" + "strings" +) + +const mlsv1Prefix = "/xmtp/mls/1/" + +func IsMLSV1(topic string) bool { + return strings.HasPrefix(topic, mlsv1Prefix) +} + +func IsMLSV1Group(topic string) bool { + return strings.HasPrefix(topic, mlsv1Prefix+"g-") +} + +func IsMLSV1Welcome(topic string) bool { + return strings.HasPrefix(topic, mlsv1Prefix+"w-") +} + +func BuildMLSV1GroupTopic(groupId []byte) string { + return fmt.Sprintf("%sg-%s/proto", mlsv1Prefix, groupId) +} + +func BuildMLSV1WelcomeTopic(installationId []byte) string { + return fmt.Sprintf("%sw-%x/proto", mlsv1Prefix, installationId) +} diff --git a/pkg/topic/topic.go b/pkg/topic/topic.go index 554e8c8a..f2460045 100644 --- a/pkg/topic/topic.go +++ b/pkg/topic/topic.go @@ -1,7 +1,6 @@ package topic import ( - "fmt" "strings" ) @@ -38,11 +37,3 @@ func Category(contentTopic string) string { } return "invalid" } - -func BuildGroupTopic(groupId []byte) string { - return fmt.Sprintf("/xmtp/mls/1/g-%s/proto", groupId) -} - -func BuildWelcomeTopic(installationId []byte) string { - return fmt.Sprintf("/xmtp/mls/1/w-%x/proto", installationId) -}