Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement subscriptions #412

Open
wants to merge 1 commit into
base: 12-30-stub_out_new_service
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/api/message/v1/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,5 +153,5 @@ func (sub *subscription) Unsubscribe() {
}

func isValidSubscribeAllTopic(contentTopic string) bool {
return strings.HasPrefix(contentTopic, validXMTPTopicPrefix) || topic.IsMLSV1(contentTopic)
return strings.HasPrefix(contentTopic, validXMTPTopicPrefix) || (topic.IsMLSV1(contentTopic) && !topic.IsAssociationChanged(contentTopic))
}
2 changes: 1 addition & 1 deletion pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func (s *Server) startGRPC() error {
}
mlsv1pb.RegisterMlsApiServer(grpcServer, s.mlsv1)

s.identityv1, err = identityv1.NewService(s.Log, s.Config.MLSStore, s.Config.MLSValidator)
s.identityv1, err = identityv1.NewService(s.Log, s.Config.MLSStore, s.Config.MLSValidator, s.natsServer)
if err != nil {
return errors.Wrap(err, "creating identity service")
}
Expand Down
61 changes: 57 additions & 4 deletions pkg/identity/api/v1/identity_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@ package api
import (
"context"

"github.com/nats-io/nats-server/v2/server"
"github.com/nats-io/nats.go"
"github.com/xmtp/xmtp-node-go/pkg/envelopes"
mlsstore "github.com/xmtp/xmtp-node-go/pkg/mls/store"
"github.com/xmtp/xmtp-node-go/pkg/mlsvalidate"
api "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1"
identity "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1"
v1proto "github.com/xmtp/xmtp-node-go/pkg/proto/message_api/v1"
"github.com/xmtp/xmtp-node-go/pkg/topic"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/grpc/metadata"
pb "google.golang.org/protobuf/proto"
)

type Service struct {
Expand All @@ -20,17 +25,23 @@ type Service struct {
validationService mlsvalidate.MLSValidationService

ctx context.Context
nc *nats.Conn
ctxCancel func()
}

func NewService(log *zap.Logger, store mlsstore.MlsStore, validationService mlsvalidate.MLSValidationService) (s *Service, err error) {
func NewService(log *zap.Logger, store mlsstore.MlsStore, validationService mlsvalidate.MLSValidationService, natsServer *server.Server) (s *Service, err error) {
s = &Service{
log: log.Named("identity"),
store: store,
validationService: validationService,
}
s.ctx, s.ctxCancel = context.WithCancel(context.Background())

s.nc, err = nats.Connect(natsServer.ClientURL())
if err != nil {
return nil, err
}

s.log.Info("Starting identity service")
return s, nil
}
Expand Down Expand Up @@ -109,5 +120,47 @@ func (s *Service) SubscribeAssociationChanges(req *identity.SubscribeAssociation
log := s.log.Named("subscribe-association-changes")
log.Info("subscription started")

return status.Errorf(codes.Unimplemented, "method SubscribeAssociationChanges not implemented")
_ = stream.SendHeader(metadata.Pairs("subscribed", "true"))

natsSubject := buildNatsSubjectForAssociationChanges()
sub, err := s.nc.Subscribe(natsSubject, func(natsMsg *nats.Msg) {
msg, err := getAssociationChangedMessageFromNats(natsMsg)
if err != nil {
log.Error("parsing message", zap.Error(err))
}
if err = stream.Send(msg); err != nil {
log.Warn("sending message to stream", zap.Error(err))
}
})

if err != nil {
log.Error("error subscribing to nats", zap.Error(err))
return err
}

defer func() {
_ = sub.Unsubscribe()
}()

return nil
}

func buildNatsSubjectForAssociationChanges() string {
return envelopes.BuildNatsSubject(topic.BuildAssociationChangedTopic())
}

func getAssociationChangedMessageFromNats(natsMsg *nats.Msg) (*identity.SubscribeAssociationChangesResponse, error) {
var env v1proto.Envelope
err := pb.Unmarshal(natsMsg.Data, &env)
if err != nil {
return nil, err
}

var msg identity.SubscribeAssociationChangesResponse
err = pb.Unmarshal(env.Message, &msg)
if err != nil {
return nil, err
}

return &msg, nil
}
13 changes: 12 additions & 1 deletion pkg/identity/api/v1/identity_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package api
import (
"context"
"testing"
"time"

"github.com/nats-io/nats-server/v2/server"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/uptrace/bun"
Expand Down Expand Up @@ -73,9 +75,18 @@ func newTestService(t *testing.T, ctx context.Context) (*Service, *bun.DB, func(
DB: db,
})
require.NoError(t, err)
natsServer, err := server.NewServer(&server.Options{
Port: server.RANDOM_PORT,
})
require.NoError(t, err)
go natsServer.Start()
if !natsServer.ReadyForConnections(4 * time.Second) {
t.Fail()
}
require.NoError(t, err)
mlsValidationService := newMockedValidationService()

svc, err := NewService(log, store, mlsValidationService)
svc, err := NewService(log, store, mlsValidationService, natsServer)
require.NoError(t, err)

return svc, db, func() {
Expand Down
29 changes: 11 additions & 18 deletions pkg/mls/api/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,26 +73,19 @@ func (s *Service) Close() {
}

func (s *Service) HandleIncomingWakuRelayMessage(wakuMsg *wakupb.WakuMessage) error {
if topic.IsMLSV1Group(wakuMsg.ContentTopic) {
s.log.Info("received group message from waku relay", zap.String("topic", wakuMsg.ContentTopic))

// Build the nats subject from the topic
natsSubject := envelopes.BuildNatsSubject(wakuMsg.ContentTopic)
s.log.Info("publishing to nats subject from relay", zap.String("subject", natsSubject))
env := envelopes.BuildEnvelope(wakuMsg)
envB, err := pb.Marshal(env)
if err != nil {
return err
isMlsGroup := topic.IsMLSV1Group(wakuMsg.ContentTopic)
isMlsWelcome := topic.IsMLSV1Welcome(wakuMsg.ContentTopic)
isAssociationChanged := topic.IsAssociationChanged(wakuMsg.ContentTopic)
if isMlsGroup || isMlsWelcome || isAssociationChanged {
if isMlsGroup {
s.log.Info("received group message from waku relay", zap.String("topic", wakuMsg.ContentTopic))
} else if isMlsWelcome {
s.log.Info("received welcome message from waku relay", zap.String("topic", wakuMsg.ContentTopic))
} else if isAssociationChanged {
s.log.Info("received association changed message from waku relay", zap.String("topic", wakuMsg.ContentTopic))
}

err = s.nc.Publish(natsSubject, envB)
if err != nil {
s.log.Error("error publishing to nats", zap.Error(err))
return err
}
} else if topic.IsMLSV1Welcome(wakuMsg.ContentTopic) {
s.log.Info("received welcome message from waku relay", zap.String("topic", wakuMsg.ContentTopic))

// Build the nats subject from the topic
natsSubject := envelopes.BuildNatsSubject(wakuMsg.ContentTopic)
s.log.Info("publishing to nats subject from relay", zap.String("subject", natsSubject))
env := envelopes.BuildEnvelope(wakuMsg)
Expand Down
12 changes: 12 additions & 0 deletions pkg/topic/mls.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ import (

const mlsv1Prefix = "/xmtp/mls/1/"

var (
AssociationChangedTopic = BuildAssociationChangedTopic()
)

func IsMLSV1(topic string) bool {
return strings.HasPrefix(topic, mlsv1Prefix)
}
Expand All @@ -19,10 +23,18 @@ func IsMLSV1Welcome(topic string) bool {
return strings.HasPrefix(topic, mlsv1Prefix+"w-")
}

func IsAssociationChanged(topic string) bool {
return topic == AssociationChangedTopic
}

func BuildMLSV1GroupTopic(groupId []byte) string {
return fmt.Sprintf("%sg-%x/proto", mlsv1Prefix, groupId)
}

func BuildMLSV1WelcomeTopic(installationId []byte) string {
return fmt.Sprintf("%sw-%x/proto", mlsv1Prefix, installationId)
}

func BuildAssociationChangedTopic() string {
return fmt.Sprintf("%sassociations/proto", mlsv1Prefix)
}
Loading