From 4367d6449490a30118ea4cc030b7a0b218169f57 Mon Sep 17 00:00:00 2001 From: Nicholas Molnar <65710+neekolas@users.noreply.github.com> Date: Thu, 26 Oct 2023 18:44:31 -0700 Subject: [PATCH] Add more MLS methods (#321) * Scaffold MLS server * Update go.mod * Fix missing argument * Add unsaved file * Lint * Working end-to-end * Lint * Add new push action * Add a bunch of new endpoints * Address review comments * Change method casing * Change casing of server options * Change casing of validation options * Remove unused function * Remove double pointer * Make private again * Fix pointer to key package * Capitalize more things * Update server fields * Add test for sort methods * Save change to capitalization * Fix lint warnings * Fix problem with mocks * Fix index name * Move sorting to the store * Fix ciphertext validation * Make installation_id bytes * Add missing credential identity * Hack sql in query * Revert "Hack sql in query" This reverts commit 168b78a1111d135b175453b2e7a27cd25c8beef7. * Remove custom type * Update to latest protos * Add CredentialIdentity --- go.mod | 2 +- pkg/api/message/v3/service.go | 213 ++++++++++-- pkg/api/message/v3/service_test.go | 312 ++++++++++++++++++ .../mls/20231023050806_init-schema.up.sql | 7 +- pkg/mlsstore/models.go | 11 +- pkg/mlsstore/store.go | 108 +++++- pkg/mlsstore/store_test.go | 155 ++++++--- pkg/mlsvalidate/service.go | 10 +- pkg/mlsvalidate/service_test.go | 12 +- pkg/server/options.go | 2 +- pkg/server/server.go | 25 +- pkg/store/query_test.go | 31 ++ pkg/store/store.go | 72 ++++ pkg/testing/random.go | 7 + pkg/testing/store.go | 2 +- pkg/topic/topic.go | 13 +- 16 files changed, 872 insertions(+), 110 deletions(-) create mode 100644 pkg/api/message/v3/service_test.go diff --git a/go.mod b/go.mod index e86b30f8..e40e63bf 100644 --- a/go.mod +++ b/go.mod @@ -30,7 +30,7 @@ require ( github.com/uptrace/bun/driver/pgdriver v1.1.16 github.com/waku-org/go-waku v0.8.0 github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3 - github.com/xmtp/proto/v3 v3.29.1-0.20231023182354-832c8d572ed4 + github.com/xmtp/proto/v3 v3.32.1-0.20231026053711-5efc208e3135 github.com/yoheimuta/protolint v0.39.0 go.uber.org/zap v1.24.0 golang.org/x/sync v0.3.0 diff --git a/pkg/api/message/v3/service.go b/pkg/api/message/v3/service.go index 9ea61de4..e3d4140d 100644 --- a/pkg/api/message/v3/service.go +++ b/pkg/api/message/v3/service.go @@ -4,10 +4,13 @@ import ( "context" wakunode "github.com/waku-org/go-waku/waku/v2/node" + wakupb "github.com/waku-org/go-waku/waku/v2/protocol/pb" proto "github.com/xmtp/proto/v3/go/message_api/v3" + "github.com/xmtp/xmtp-node-go/pkg/metrics" "github.com/xmtp/xmtp-node-go/pkg/mlsstore" "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" "github.com/xmtp/xmtp-node-go/pkg/store" + "github.com/xmtp/xmtp-node-go/pkg/topic" "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -22,9 +25,6 @@ type Service struct { messageStore *store.Store mlsStore mlsstore.MlsStore validationService mlsvalidate.MLSValidationService - - ctx context.Context - ctxCancel func() } func NewService(node *wakunode.WakuNode, logger *zap.Logger, messageStore *store.Store, mlsStore mlsstore.MlsStore, validationService mlsvalidate.MLSValidationService) (s *Service, err error) { @@ -36,18 +36,15 @@ func NewService(node *wakunode.WakuNode, logger *zap.Logger, messageStore *store validationService: validationService, } - s.ctx, s.ctxCancel = context.WithCancel(context.Background()) - + s.log.Info("Starting MLS service") return s, nil } -func (s *Service) Close() { - if s.ctxCancel != nil { - s.ctxCancel() +func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterInstallationRequest) (*proto.RegisterInstallationResponse, error) { + if err := validateRegisterInstallationRequest(req); err != nil { + return nil, err } -} -func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterInstallationRequest) (*proto.RegisterInstallationResponse, error) { results, err := s.validationService.ValidateKeyPackages(ctx, [][]byte{req.LastResortKeyPackage.KeyPackageTlsSerialized}) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err) @@ -58,9 +55,9 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterI installationId := results[0].InstallationId walletAddress := results[0].WalletAddress + credentialIdentity := results[0].CredentialIdentity - err = s.mlsStore.CreateInstallation(ctx, installationId, walletAddress, req.LastResortKeyPackage.KeyPackageTlsSerialized) - if err != nil { + if err = s.mlsStore.CreateInstallation(ctx, installationId, walletAddress, req.LastResortKeyPackage.KeyPackageTlsSerialized, credentialIdentity); err != nil { return nil, err } @@ -77,13 +74,13 @@ func (s *Service) ConsumeKeyPackages(ctx context.Context, req *proto.ConsumeKeyP } keyPackageMap := make(map[string]int) for idx, id := range ids { - keyPackageMap[id] = idx + keyPackageMap[string(id)] = idx } resPackages := make([]*proto.ConsumeKeyPackagesResponse_KeyPackage, len(keyPackages)) for _, keyPackage := range keyPackages { - idx, ok := keyPackageMap[keyPackage.InstallationId] + idx, ok := keyPackageMap[string(keyPackage.InstallationId)] if !ok { return nil, status.Errorf(codes.Internal, "could not find key package for installation") } @@ -98,15 +95,82 @@ func (s *Service) ConsumeKeyPackages(ctx context.Context, req *proto.ConsumeKeyP }, nil } -func (s *Service) PublishToGroup(ctx context.Context, req *proto.PublishToGroupRequest) (*emptypb.Empty, error) { - return nil, status.Errorf(codes.Unimplemented, "unimplemented") +func (s *Service) PublishToGroup(ctx context.Context, req *proto.PublishToGroupRequest) (res *emptypb.Empty, err error) { + if err = validatePublishToGroupRequest(req); err != nil { + return nil, err + } + + messages := make([][]byte, len(req.Messages)) + for i, message := range req.Messages { + v1 := message.GetV1() + if v1 == nil { + return nil, status.Errorf(codes.InvalidArgument, "message must be v1") + } + messages[i] = v1.MlsMessageTlsSerialized + } + + validationResults, err := s.validationService.ValidateGroupMessages(ctx, messages) + if err != nil { + // TODO: Separate validation errors from internal errors + return nil, status.Errorf(codes.InvalidArgument, "invalid group message: %s", err) + } + + for i, result := range validationResults { + message := messages[i] + + if err = requireReadyToSend(result.GroupId, message); err != nil { + return nil, err + } + + // TODO: Wrap this in a transaction so publishing is all or nothing + if err = s.publishMessage(ctx, topic.BuildGroupTopic(result.GroupId), message); err != nil { + return nil, status.Errorf(codes.Internal, "failed to publish message: %s", err) + } + } + + return &emptypb.Empty{}, nil } -func (s *Service) PublishWelcomes(ctx context.Context, req *proto.PublishWelcomesRequest) (*emptypb.Empty, error) { - return nil, status.Errorf(codes.Unimplemented, "unimplemented") +func (s *Service) publishMessage(ctx context.Context, contentTopic string, message []byte) error { + log := s.log.Named("publish-mls").With(zap.String("content_topic", contentTopic)) + env, err := s.messageStore.InsertMLSMessage(ctx, contentTopic, message) + if err != nil { + return status.Errorf(codes.Internal, "failed to insert message: %s", err) + } + + if _, err = s.waku.Relay().Publish(ctx, &wakupb.WakuMessage{ + ContentTopic: contentTopic, + Timestamp: int64(env.TimestampNs), + Payload: message, + }); err != nil { + return status.Errorf(codes.Internal, "failed to publish message: %s", err) + } + + metrics.EmitPublishedEnvelope(ctx, log, env) + + return nil +} + +func (s *Service) PublishWelcomes(ctx context.Context, req *proto.PublishWelcomesRequest) (res *emptypb.Empty, err error) { + if err = validatePublishWelcomesRequest(req); err != nil { + return nil, err + } + + // TODO: Wrap this in a transaction so publishing is all or nothing + for _, welcome := range req.WelcomeMessages { + contentTopic := topic.BuildWelcomeTopic(welcome.InstallationId) + if err = s.publishMessage(ctx, contentTopic, welcome.WelcomeMessage.GetV1().WelcomeMessageTlsSerialized); err != nil { + return nil, status.Errorf(codes.Internal, "failed to publish welcome message: %s", err) + } + } + return &emptypb.Empty{}, nil } -func (s *Service) UploadKeyPackages(ctx context.Context, req *proto.UploadKeyPackagesRequest) (*emptypb.Empty, error) { +func (s *Service) UploadKeyPackages(ctx context.Context, req *proto.UploadKeyPackagesRequest) (res *emptypb.Empty, err error) { + if err = validateUploadKeyPackagesRequest(req); err != nil { + return nil, err + } + // Extract the key packages from the request keyPackageBytes := make([][]byte, len(req.KeyPackages)) for i, keyPackage := range req.KeyPackages { keyPackageBytes[i] = keyPackage.KeyPackageTlsSerialized @@ -122,8 +186,8 @@ func (s *Service) UploadKeyPackages(ctx context.Context, req *proto.UploadKeyPac kp := mlsstore.NewKeyPackage(validationResult.InstallationId, keyPackageBytes[i], false) keyPackageModels[i] = kp } - err = s.mlsStore.InsertKeyPackages(ctx, keyPackageModels) - if err != nil { + + if err = s.mlsStore.InsertKeyPackages(ctx, keyPackageModels); err != nil { return nil, status.Errorf(codes.Internal, "failed to insert key packages: %s", err) } @@ -134,6 +198,109 @@ func (s *Service) RevokeInstallation(ctx context.Context, req *proto.RevokeInsta return nil, status.Errorf(codes.Unimplemented, "unimplemented") } -func (s *Service) GetIdentityUpdates(ctx context.Context, req *proto.GetIdentityUpdatesRequest) (*proto.GetIdentityUpdatesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "unimplemented") +func (s *Service) GetIdentityUpdates(ctx context.Context, req *proto.GetIdentityUpdatesRequest) (res *proto.GetIdentityUpdatesResponse, err error) { + if err = validateGetIdentityUpdatesRequest(req); err != nil { + return nil, err + } + + walletAddresses := req.WalletAddresses + updates, err := s.mlsStore.GetIdentityUpdates(ctx, req.WalletAddresses, int64(req.StartTimeNs)) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get identity updates: %s", err) + } + + resUpdates := make([]*proto.GetIdentityUpdatesResponse_WalletUpdates, len(walletAddresses)) + for i, walletAddress := range walletAddresses { + walletUpdates := updates[walletAddress] + + resUpdates[i] = &proto.GetIdentityUpdatesResponse_WalletUpdates{ + Updates: []*proto.GetIdentityUpdatesResponse_Update{}, + } + + for _, walletUpdate := range walletUpdates { + resUpdates[i].Updates = append(resUpdates[i].Updates, buildIdentityUpdate(walletUpdate)) + } + } + + return &proto.GetIdentityUpdatesResponse{ + Updates: resUpdates, + }, nil +} + +func buildIdentityUpdate(update mlsstore.IdentityUpdate) *proto.GetIdentityUpdatesResponse_Update { + base := proto.GetIdentityUpdatesResponse_Update{ + TimestampNs: update.TimestampNs, + } + switch update.Kind { + case mlsstore.Create: + base.Kind = &proto.GetIdentityUpdatesResponse_Update_NewInstallation{ + NewInstallation: &proto.GetIdentityUpdatesResponse_NewInstallationUpdate{ + InstallationId: update.InstallationId, + CredentialIdentity: update.CredentialIdentity, + }, + } + case mlsstore.Revoke: + base.Kind = &proto.GetIdentityUpdatesResponse_Update_RevokedInstallation{ + RevokedInstallation: &proto.GetIdentityUpdatesResponse_RevokedInstallationUpdate{ + InstallationId: update.InstallationId, + }, + } + } + + return &base +} + +func validatePublishToGroupRequest(req *proto.PublishToGroupRequest) error { + if req == nil || len(req.Messages) == 0 { + return status.Errorf(codes.InvalidArgument, "no messages to publish") + } + return nil +} + +func validatePublishWelcomesRequest(req *proto.PublishWelcomesRequest) error { + if req == nil || len(req.WelcomeMessages) == 0 { + return status.Errorf(codes.InvalidArgument, "no welcome messages to publish") + } + for _, welcome := range req.WelcomeMessages { + if welcome == nil || welcome.WelcomeMessage == nil { + return status.Errorf(codes.InvalidArgument, "invalid welcome message") + } + + v1 := welcome.WelcomeMessage.GetV1() + if v1 == nil || len(v1.WelcomeMessageTlsSerialized) == 0 { + return status.Errorf(codes.InvalidArgument, "invalid welcome message") + } + } + return nil +} + +func validateRegisterInstallationRequest(req *proto.RegisterInstallationRequest) error { + if req == nil || req.LastResortKeyPackage == nil { + return status.Errorf(codes.InvalidArgument, "no last resort key package") + } + return nil +} + +func validateUploadKeyPackagesRequest(req *proto.UploadKeyPackagesRequest) error { + if req == nil || len(req.KeyPackages) == 0 { + return status.Errorf(codes.InvalidArgument, "no key packages to upload") + } + return nil +} + +func validateGetIdentityUpdatesRequest(req *proto.GetIdentityUpdatesRequest) error { + if req == nil || len(req.WalletAddresses) == 0 { + return status.Errorf(codes.InvalidArgument, "no wallet addresses to get updates for") + } + return nil +} + +func requireReadyToSend(groupId string, message []byte) error { + if groupId == "" { + return status.Errorf(codes.InvalidArgument, "group id is empty") + } + if len(message) == 0 { + return status.Errorf(codes.InvalidArgument, "message is empty") + } + return nil } diff --git a/pkg/api/message/v3/service_test.go b/pkg/api/message/v3/service_test.go new file mode 100644 index 00000000..a2d61679 --- /dev/null +++ b/pkg/api/message/v3/service_test.go @@ -0,0 +1,312 @@ +package api + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/uptrace/bun" + v1 "github.com/xmtp/proto/v3/go/message_api/v1" + proto "github.com/xmtp/proto/v3/go/message_api/v3" + messageContents "github.com/xmtp/proto/v3/go/mls/message_contents" + "github.com/xmtp/xmtp-node-go/pkg/mlsstore" + "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" + "github.com/xmtp/xmtp-node-go/pkg/store" + test "github.com/xmtp/xmtp-node-go/pkg/testing" +) + +type mockedMLSValidationService struct { + mock.Mock +} + +func (m *mockedMLSValidationService) ValidateKeyPackages(ctx context.Context, keyPackages [][]byte) ([]mlsvalidate.IdentityValidationResult, error) { + args := m.Called(ctx, keyPackages) + + response := args.Get(0) + if response == nil { + return nil, args.Error(1) + } + + return response.([]mlsvalidate.IdentityValidationResult), args.Error(1) +} + +func (m *mockedMLSValidationService) ValidateGroupMessages(ctx context.Context, groupMessages [][]byte) ([]mlsvalidate.GroupMessageValidationResult, error) { + args := m.Called(ctx, groupMessages) + + return args.Get(0).([]mlsvalidate.GroupMessageValidationResult), args.Error(1) +} + +func newMockedValidationService() *mockedMLSValidationService { + return new(mockedMLSValidationService) +} + +func (m *mockedMLSValidationService) mockValidateKeyPackages(installationId []byte, walletAddress string) *mock.Call { + return m.On("ValidateKeyPackages", mock.Anything, mock.Anything).Return([]mlsvalidate.IdentityValidationResult{ + { + InstallationId: installationId, + WalletAddress: walletAddress, + CredentialIdentity: []byte("test"), + }, + }, nil) +} + +func (m *mockedMLSValidationService) mockValidateGroupMessages(groupId string) *mock.Call { + return m.On("ValidateGroupMessages", mock.Anything, mock.Anything).Return([]mlsvalidate.GroupMessageValidationResult{ + { + GroupId: groupId, + }, + }, nil) +} + +func newTestService(t *testing.T, ctx context.Context) (*Service, *bun.DB, *mockedMLSValidationService, func()) { + log := test.NewLog(t) + mlsDb, _, mlsDbCleanup := test.NewMLSDB(t) + mlsStore, err := mlsstore.New(ctx, mlsstore.Config{ + Log: log, + DB: mlsDb, + }) + require.NoError(t, err) + messageDb, _, messageDbCleanup := test.NewDB(t) + messageStore, err := store.New(&store.Config{ + Log: log, + DB: messageDb, + ReaderDB: messageDb, + CleanerDB: messageDb, + }) + require.NoError(t, err) + node, nodeCleanup := test.NewNode(t) + mlsValidationService := newMockedValidationService() + + svc, err := NewService(node, log, messageStore, mlsStore, mlsValidationService) + require.NoError(t, err) + + return svc, mlsDb, mlsValidationService, func() { + messageStore.Close() + mlsDbCleanup() + messageDbCleanup() + nodeCleanup() + } +} + +func TestRegisterInstallation(t *testing.T) { + ctx := context.Background() + svc, mlsDb, mlsValidationService, cleanup := newTestService(t, ctx) + defer cleanup() + + installationId := test.RandomBytes(32) + walletAddress := test.RandomString(32) + + mlsValidationService.mockValidateKeyPackages(installationId, walletAddress) + + res, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ + LastResortKeyPackage: &proto.KeyPackageUpload{ + KeyPackageTlsSerialized: []byte("test"), + }, + }) + + require.NoError(t, err) + require.Equal(t, installationId, res.InstallationId) + + installations := []mlsstore.Installation{} + err = mlsDb.NewSelect().Model(&installations).Where("id = ?", installationId).Scan(ctx) + require.NoError(t, err) + + require.Len(t, installations, 1) + require.Equal(t, walletAddress, installations[0].WalletAddress) +} + +func TestRegisterInstallationError(t *testing.T) { + ctx := context.Background() + svc, _, mlsValidationService, cleanup := newTestService(t, ctx) + defer cleanup() + + mlsValidationService.On("ValidateKeyPackages", ctx, mock.Anything).Return(nil, errors.New("error validating")) + + res, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ + LastResortKeyPackage: &proto.KeyPackageUpload{ + KeyPackageTlsSerialized: []byte("test"), + }, + }) + require.Error(t, err) + require.Nil(t, res) +} + +func TestUploadKeyPackages(t *testing.T) { + ctx := context.Background() + svc, mlsDb, mlsValidationService, cleanup := newTestService(t, ctx) + defer cleanup() + + installationId := test.RandomBytes(32) + walletAddress := test.RandomString(32) + + mlsValidationService.mockValidateKeyPackages(installationId, walletAddress) + + res, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ + LastResortKeyPackage: &proto.KeyPackageUpload{ + KeyPackageTlsSerialized: []byte("test"), + }, + }) + require.NoError(t, err) + require.NotNil(t, res) + + uploadRes, err := svc.UploadKeyPackages(ctx, &proto.UploadKeyPackagesRequest{ + KeyPackages: []*proto.KeyPackageUpload{ + {KeyPackageTlsSerialized: []byte("test2")}, + }, + }) + require.NoError(t, err) + require.NotNil(t, uploadRes) + + keyPackages := []mlsstore.KeyPackage{} + err = mlsDb.NewSelect().Model(&keyPackages).Where("installation_id = ?", installationId).Scan(ctx) + require.NoError(t, err) + require.Len(t, keyPackages, 2) +} + +func TestConsumeKeyPackages(t *testing.T) { + ctx := context.Background() + svc, _, mlsValidationService, cleanup := newTestService(t, ctx) + defer cleanup() + + installationId1 := test.RandomBytes(32) + walletAddress1 := test.RandomString(32) + + mockCall := mlsValidationService.mockValidateKeyPackages(installationId1, walletAddress1) + + res, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ + LastResortKeyPackage: &proto.KeyPackageUpload{ + KeyPackageTlsSerialized: []byte("test"), + }, + }) + require.NoError(t, err) + require.NotNil(t, res) + + // Add a second key package + installationId2 := test.RandomBytes(32) + walletAddress2 := test.RandomString(32) + // Unset the original mock so we can set a new one + mockCall.Unset() + mlsValidationService.mockValidateKeyPackages(installationId2, walletAddress2) + + res, err = svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ + LastResortKeyPackage: &proto.KeyPackageUpload{ + KeyPackageTlsSerialized: []byte("test2"), + }, + }) + require.NoError(t, err) + require.NotNil(t, res) + + consumeRes, err := svc.ConsumeKeyPackages(ctx, &proto.ConsumeKeyPackagesRequest{ + InstallationIds: [][]byte{installationId1, installationId2}, + }) + require.NoError(t, err) + require.NotNil(t, consumeRes) + require.Len(t, consumeRes.KeyPackages, 2) + require.Equal(t, []byte("test"), consumeRes.KeyPackages[0].KeyPackageTlsSerialized) + require.Equal(t, []byte("test2"), consumeRes.KeyPackages[1].KeyPackageTlsSerialized) + + // Now do it with the installationIds reversed + consumeRes, err = svc.ConsumeKeyPackages(ctx, &proto.ConsumeKeyPackagesRequest{ + InstallationIds: [][]byte{installationId2, installationId1}, + }) + + require.NoError(t, err) + require.NotNil(t, consumeRes) + require.Len(t, consumeRes.KeyPackages, 2) + require.Equal(t, []byte("test2"), consumeRes.KeyPackages[0].KeyPackageTlsSerialized) + require.Equal(t, []byte("test"), consumeRes.KeyPackages[1].KeyPackageTlsSerialized) +} + +// Trying to consume key packages that don't exist should fail +func TestConsumeKeyPackagesFail(t *testing.T) { + ctx := context.Background() + svc, _, _, cleanup := newTestService(t, ctx) + defer cleanup() + + consumeRes, err := svc.ConsumeKeyPackages(ctx, &proto.ConsumeKeyPackagesRequest{ + InstallationIds: [][]byte{test.RandomBytes(32)}, + }) + require.Error(t, err) + require.Nil(t, consumeRes) +} + +func TestPublishToGroup(t *testing.T) { + ctx := context.Background() + svc, _, mlsValidationService, cleanup := newTestService(t, ctx) + defer cleanup() + + groupId := test.RandomString(32) + + mlsValidationService.mockValidateGroupMessages(groupId) + + _, err := svc.PublishToGroup(ctx, &proto.PublishToGroupRequest{ + Messages: []*messageContents.GroupMessage{{ + Version: &messageContents.GroupMessage_V1_{ + V1: &messageContents.GroupMessage_V1{ + MlsMessageTlsSerialized: []byte("test"), + }, + }, + }}, + }) + require.NoError(t, err) + + results, err := svc.messageStore.Query(&v1.QueryRequest{ + ContentTopics: []string{fmt.Sprintf("/xmtp/3/g-%s/proto", groupId)}, + }) + require.NoError(t, err) + require.Len(t, results.Envelopes, 1) + require.Equal(t, results.Envelopes[0].Message, []byte("test")) + require.NotNil(t, results.Envelopes[0].TimestampNs) +} + +func TestGetIdentityUpdates(t *testing.T) { + ctx := context.Background() + svc, _, mlsValidationService, cleanup := newTestService(t, ctx) + defer cleanup() + + installationId := test.RandomBytes(32) + walletAddress := test.RandomString(32) + + mockCall := mlsValidationService.mockValidateKeyPackages(installationId, walletAddress) + + _, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ + LastResortKeyPackage: &proto.KeyPackageUpload{ + KeyPackageTlsSerialized: []byte("test"), + }, + }) + require.NoError(t, err) + + identityUpdates, err := svc.GetIdentityUpdates(ctx, &proto.GetIdentityUpdatesRequest{ + WalletAddresses: []string{walletAddress}, + }) + require.NoError(t, err) + require.NotNil(t, identityUpdates) + require.Len(t, identityUpdates.Updates, 1) + require.Equal(t, identityUpdates.Updates[0].Updates[0].GetNewInstallation().InstallationId, installationId) + require.Equal(t, identityUpdates.Updates[0].Updates[0].GetNewInstallation().CredentialIdentity, []byte("test")) + + for _, walletUpdate := range identityUpdates.Updates { + for _, update := range walletUpdate.Updates { + require.Equal(t, installationId, update.GetNewInstallation().InstallationId) + } + } + + mockCall.Unset() + mlsValidationService.mockValidateKeyPackages(test.RandomBytes(32), walletAddress) + _, err = svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ + LastResortKeyPackage: &proto.KeyPackageUpload{ + KeyPackageTlsSerialized: []byte("test"), + }, + }) + require.NoError(t, err) + + identityUpdates, err = svc.GetIdentityUpdates(ctx, &proto.GetIdentityUpdatesRequest{ + WalletAddresses: []string{walletAddress}, + }) + require.NoError(t, err) + require.Len(t, identityUpdates.Updates, 1) + require.Len(t, identityUpdates.Updates[0].Updates, 2) +} diff --git a/pkg/migrations/mls/20231023050806_init-schema.up.sql b/pkg/migrations/mls/20231023050806_init-schema.up.sql index b6aa90d5..e77152e9 100644 --- a/pkg/migrations/mls/20231023050806_init-schema.up.sql +++ b/pkg/migrations/mls/20231023050806_init-schema.up.sql @@ -3,16 +3,17 @@ SET --bun:split CREATE TABLE installations ( - id TEXT PRIMARY KEY, + id BYTEA PRIMARY KEY, wallet_address TEXT NOT NULL, created_at BIGINT NOT NULL, + credential_identity BYTEA NOT NULL, revoked_at BIGINT ); --bun:split CREATE TABLE key_packages ( id TEXT PRIMARY KEY, - installation_id TEXT NOT NULL, + installation_id BYTEA NOT NULL, created_at BIGINT NOT NULL, consumed_at BIGINT, not_consumed BOOLEAN DEFAULT TRUE NOT NULL, @@ -33,7 +34,7 @@ CREATE INDEX idx_installations_revoked_at ON installations(revoked_at); --bun:split -- Adding indexes for the key_packages table -CREATE INDEX idx_key_packages_installation_id_not_is_last_resort_created_at ON key_packages( +CREATE INDEX idx_key_packages_installation_id_not_consumed_is_last_resort_created_at ON key_packages( installation_id, not_consumed, is_last_resort, diff --git a/pkg/mlsstore/models.go b/pkg/mlsstore/models.go index 533c595d..226db843 100644 --- a/pkg/mlsstore/models.go +++ b/pkg/mlsstore/models.go @@ -5,17 +5,18 @@ import "github.com/uptrace/bun" type Installation struct { bun.BaseModel `bun:"table:installations"` - ID string `bun:",pk"` - WalletAddress string `bun:"wallet_address,notnull"` - CreatedAt int64 `bun:"created_at,notnull"` - RevokedAt *int64 `bun:"revoked_at"` + ID []byte `bun:",pk,type:bytea"` + WalletAddress string `bun:"wallet_address,notnull"` + CreatedAt int64 `bun:"created_at,notnull"` + RevokedAt *int64 `bun:"revoked_at"` + CredentialIdentity []byte `bun:"credential_identity,notnull,type:bytea"` } type KeyPackage struct { bun.BaseModel `bun:"table:key_packages"` ID string `bun:",pk"` // ID is the hash of the data field - InstallationId string `bun:"installation_id,notnull"` + InstallationId []byte `bun:"installation_id,notnull,type:bytea"` CreatedAt int64 `bun:"created_at,notnull"` ConsumedAt *int64 `bun:"consumed_at"` NotConsumed bool `bun:"not_consumed,default:true"` diff --git a/pkg/mlsstore/store.go b/pkg/mlsstore/store.go index f528b78d..df2f392a 100644 --- a/pkg/mlsstore/store.go +++ b/pkg/mlsstore/store.go @@ -6,6 +6,7 @@ import ( "database/sql" "encoding/hex" "errors" + "sort" "time" "github.com/uptrace/bun" @@ -21,9 +22,10 @@ type Store struct { } type MlsStore interface { - CreateInstallation(ctx context.Context, installationId string, walletAddress string, lastResortKeyPackage []byte) error + CreateInstallation(ctx context.Context, installationId []byte, walletAddress string, lastResortKeyPackage []byte, credentialIdentity []byte) error InsertKeyPackages(ctx context.Context, keyPackages []*KeyPackage) error - ConsumeKeyPackages(ctx context.Context, installationIds []string) ([]*KeyPackage, error) + ConsumeKeyPackages(ctx context.Context, installationIds [][]byte) ([]*KeyPackage, error) + GetIdentityUpdates(ctx context.Context, walletAddresses []string, startTimeNs int64) (map[string]IdentityUpdateList, error) } func New(ctx context.Context, config Config) (*Store, error) { @@ -40,20 +42,15 @@ func New(ctx context.Context, config Config) (*Store, error) { return s, nil } -func (s *Store) Close() { - if s.db != nil { - s.db.Close() - } -} - // Creates the installation and last resort key package -func (s *Store) CreateInstallation(ctx context.Context, installationId string, walletAddress string, lastResortKeyPackage []byte) error { +func (s *Store) CreateInstallation(ctx context.Context, installationId []byte, walletAddress string, lastResortKeyPackage []byte, credentialIdentity []byte) error { createdAt := nowNs() installation := Installation{ - ID: installationId, - WalletAddress: walletAddress, - CreatedAt: createdAt, + ID: installationId, + WalletAddress: walletAddress, + CreatedAt: createdAt, + CredentialIdentity: credentialIdentity, } keyPackage := NewKeyPackage(installationId, lastResortKeyPackage, true) @@ -87,7 +84,7 @@ func (s *Store) InsertKeyPackages(ctx context.Context, keyPackages []*KeyPackage return err } -func (s *Store) ConsumeKeyPackages(ctx context.Context, installationIds []string) ([]*KeyPackage, error) { +func (s *Store) ConsumeKeyPackages(ctx context.Context, installationIds [][]byte) ([]*KeyPackage, error) { keyPackages := make([]*KeyPackage, 0) err := s.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { err := tx.NewRaw(` @@ -125,7 +122,61 @@ func (s *Store) ConsumeKeyPackages(ctx context.Context, installationIds []string return keyPackages, nil } -func NewKeyPackage(installationId string, data []byte, isLastResort bool) *KeyPackage { +func (s *Store) GetIdentityUpdates(ctx context.Context, walletAddresses []string, startTimeNs int64) (map[string]IdentityUpdateList, error) { + updated := make([]*Installation, 0) + // Find all installations that were changed since the startTimeNs + err := s.db.NewSelect(). + Model(&updated). + Where("wallet_address IN (?)", bun.In(walletAddresses)). + WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery { + return q.Where("created_at > ?", startTimeNs).WhereOr("revoked_at > ?", startTimeNs) + }). + Order("created_at ASC"). + Scan(ctx) + + if err != nil { + return nil, err + } + + // The returned list is only partially sorted + out := make(map[string]IdentityUpdateList) + for _, installation := range updated { + if installation.CreatedAt > startTimeNs { + out[installation.WalletAddress] = append(out[installation.WalletAddress], IdentityUpdate{ + Kind: Create, + InstallationId: installation.ID, + CredentialIdentity: installation.CredentialIdentity, + TimestampNs: uint64(installation.CreatedAt), + }) + } + if installation.RevokedAt != nil && *installation.RevokedAt > startTimeNs { + out[installation.WalletAddress] = append(out[installation.WalletAddress], IdentityUpdate{ + Kind: Revoke, + InstallationId: installation.ID, + TimestampNs: uint64(*installation.RevokedAt), + }) + } + } + // Sort the updates by timestamp now that the full list is assembled + for _, updates := range out { + sort.Sort(updates) + } + + return out, nil +} + +func (s *Store) RevokeInstallation(ctx context.Context, installationId []byte) error { + _, err := s.db.NewUpdate(). + Model(&Installation{}). + Set("revoked_at = ?", nowNs()). + Where("id = ?", installationId). + Where("revoked_at IS NULL"). + Exec(ctx) + + return err +} + +func NewKeyPackage(installationId []byte, data []byte, isLastResort bool) *KeyPackage { return &KeyPackage{ ID: buildKeyPackageId(data), InstallationId: installationId, @@ -171,3 +222,32 @@ func buildKeyPackageId(keyPackageData []byte) string { digest := sha256.Sum256(keyPackageData) return hex.EncodeToString(digest[:]) } + +type IdentityUpdateKind int + +const ( + Create IdentityUpdateKind = iota + Revoke +) + +type IdentityUpdate struct { + Kind IdentityUpdateKind + InstallationId []byte + CredentialIdentity []byte + TimestampNs uint64 +} + +// Add the required methods to make a valid sort.Sort interface +type IdentityUpdateList []IdentityUpdate + +func (a IdentityUpdateList) Len() int { + return len(a) +} + +func (a IdentityUpdateList) Swap(i, j int) { + a[i], a[j] = a[j], a[i] +} + +func (a IdentityUpdateList) Less(i, j int) bool { + return a[i].TimestampNs < a[j].TimestampNs +} diff --git a/pkg/mlsstore/store_test.go b/pkg/mlsstore/store_test.go index 026b3da9..50673cec 100644 --- a/pkg/mlsstore/store_test.go +++ b/pkg/mlsstore/store_test.go @@ -2,8 +2,7 @@ package mlsstore import ( "context" - "crypto/rand" - "fmt" + "sort" "testing" "github.com/stretchr/testify/require" @@ -12,7 +11,7 @@ import ( func NewTestStore(t *testing.T) (*Store, func()) { log := test.NewLog(t) - db, _, dbCleanup := test.NewMlsDB(t) + db, _, dbCleanup := test.NewMLSDB(t) ctx := context.Background() c := Config{ Log: log, @@ -25,25 +24,15 @@ func NewTestStore(t *testing.T) (*Store, func()) { return store, dbCleanup } -func randomBytes(n int) []byte { - b := make([]byte, n) - _, _ = rand.Reader.Read(b) - return b -} - -func randomString(n int) string { - return fmt.Sprintf("%x", randomBytes(n)) -} - func TestCreateInstallation(t *testing.T) { store, cleanup := NewTestStore(t) defer cleanup() ctx := context.Background() - installationId := randomString(32) - walletAddress := randomString(32) + installationId := test.RandomBytes(32) + walletAddress := test.RandomString(32) - err := store.CreateInstallation(ctx, installationId, walletAddress, randomBytes(32)) + err := store.CreateInstallation(ctx, installationId, walletAddress, test.RandomBytes(32), test.RandomBytes(32)) require.NoError(t, err) installationFromDb := &Installation{} @@ -60,13 +49,13 @@ func TestCreateInstallationIdempotent(t *testing.T) { defer cleanup() ctx := context.Background() - installationId := randomString(32) - walletAddress := randomString(32) - keyPackage := randomBytes(32) + installationId := test.RandomBytes(32) + walletAddress := test.RandomString(32) + keyPackage := test.RandomBytes(32) - err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage) + err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage) require.NoError(t, err) - err = store.CreateInstallation(ctx, installationId, walletAddress, randomBytes(32)) + err = store.CreateInstallation(ctx, installationId, walletAddress, test.RandomBytes(32), test.RandomBytes(32)) require.NoError(t, err) keyPackageFromDb := &KeyPackage{} @@ -79,14 +68,14 @@ func TestInsertKeyPackages(t *testing.T) { defer cleanup() ctx := context.Background() - installationId := randomString(32) - walletAddress := randomString(32) - keyPackage := randomBytes(32) + installationId := test.RandomBytes(32) + walletAddress := test.RandomString(32) + keyPackage := test.RandomBytes(32) - err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage) + err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage) require.NoError(t, err) - keyPackage2 := randomBytes(32) + keyPackage2 := test.RandomBytes(32) err = store.InsertKeyPackages(ctx, []*KeyPackage{{ ID: buildKeyPackageId(keyPackage2), InstallationId: installationId, @@ -122,14 +111,14 @@ func TestConsumeLastResortKeyPackage(t *testing.T) { defer cleanup() ctx := context.Background() - installationId := randomString(32) - walletAddress := randomString(32) - keyPackage := randomBytes(32) + installationId := test.RandomBytes(32) + walletAddress := test.RandomString(32) + keyPackage := test.RandomBytes(32) - err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage) + err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage) require.NoError(t, err) - consumeResult, err := store.ConsumeKeyPackages(ctx, []string{installationId}) + consumeResult, err := store.ConsumeKeyPackages(ctx, [][]byte{installationId}) require.NoError(t, err) require.Len(t, consumeResult, 1) require.Equal(t, keyPackage, consumeResult[0].Data) @@ -141,14 +130,14 @@ func TestConsumeMultipleKeyPackages(t *testing.T) { defer cleanup() ctx := context.Background() - installationId := randomString(32) - walletAddress := randomString(32) - keyPackage := randomBytes(32) + installationId := test.RandomBytes(32) + walletAddress := test.RandomString(32) + keyPackage := test.RandomBytes(32) - err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage) + err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage) require.NoError(t, err) - keyPackage2 := randomBytes(32) + keyPackage2 := test.RandomBytes(32) require.NoError(t, store.InsertKeyPackages(ctx, []*KeyPackage{{ ID: buildKeyPackageId(keyPackage2), InstallationId: installationId, @@ -157,15 +146,105 @@ func TestConsumeMultipleKeyPackages(t *testing.T) { Data: keyPackage2, }})) - consumeResult, err := store.ConsumeKeyPackages(ctx, []string{installationId}) + consumeResult, err := store.ConsumeKeyPackages(ctx, [][]byte{installationId}) require.NoError(t, err) require.Len(t, consumeResult, 1) require.Equal(t, keyPackage2, consumeResult[0].Data) require.Equal(t, installationId, consumeResult[0].InstallationId) - consumeResult, err = store.ConsumeKeyPackages(ctx, []string{installationId}) + consumeResult, err = store.ConsumeKeyPackages(ctx, [][]byte{installationId}) require.NoError(t, err) require.Len(t, consumeResult, 1) // Now we are out of regular key packages. Expect to consume the last resort require.Equal(t, keyPackage, consumeResult[0].Data) } + +func TestGetIdentityUpdates(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + walletAddress := test.RandomString(32) + + installationId1 := test.RandomBytes(32) + keyPackage1 := test.RandomBytes(32) + + err := store.CreateInstallation(ctx, installationId1, walletAddress, keyPackage1, keyPackage1) + require.NoError(t, err) + + installationId2 := test.RandomBytes(32) + keyPackage2 := test.RandomBytes(32) + + err = store.CreateInstallation(ctx, installationId2, walletAddress, keyPackage2, keyPackage2) + require.NoError(t, err) + + identityUpdates, err := store.GetIdentityUpdates(ctx, []string{walletAddress}, 0) + require.NoError(t, err) + require.Len(t, identityUpdates[walletAddress], 2) + require.Equal(t, identityUpdates[walletAddress][0].InstallationId, installationId1) + require.Equal(t, identityUpdates[walletAddress][0].Kind, Create) + require.Equal(t, identityUpdates[walletAddress][1].InstallationId, installationId2) + + // Make sure that date filtering works + identityUpdates, err = store.GetIdentityUpdates(ctx, []string{walletAddress}, nowNs()+1000000) + require.NoError(t, err) + require.Len(t, identityUpdates[walletAddress], 0) +} + +func TestGetIdentityUpdatesMultipleWallets(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + walletAddress1 := test.RandomString(32) + installationId1 := test.RandomBytes(32) + keyPackage1 := test.RandomBytes(32) + + err := store.CreateInstallation(ctx, installationId1, walletAddress1, keyPackage1, keyPackage1) + require.NoError(t, err) + + walletAddress2 := test.RandomString(32) + installationId2 := test.RandomBytes(32) + keyPackage2 := test.RandomBytes(32) + + err = store.CreateInstallation(ctx, installationId2, walletAddress2, keyPackage2, keyPackage2) + require.NoError(t, err) + + identityUpdates, err := store.GetIdentityUpdates(ctx, []string{walletAddress1, walletAddress2}, 0) + require.NoError(t, err) + require.Len(t, identityUpdates[walletAddress1], 1) + require.Len(t, identityUpdates[walletAddress2], 1) +} + +func TestGetIdentityUpdatesNoResult(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + ctx := context.Background() + walletAddress := test.RandomString(32) + + identityUpdates, err := store.GetIdentityUpdates(ctx, []string{walletAddress}, 0) + require.NoError(t, err) + require.Len(t, identityUpdates[walletAddress], 0) +} + +func TestIdentityUpdateSort(t *testing.T) { + updates := IdentityUpdateList([]IdentityUpdate{ + { + Kind: Create, + TimestampNs: 2, + }, + { + Kind: Create, + TimestampNs: 3, + }, + { + Kind: Create, + TimestampNs: 1, + }, + }) + sort.Sort(updates) + require.Equal(t, updates[0].TimestampNs, uint64(1)) + require.Equal(t, updates[1].TimestampNs, uint64(2)) + require.Equal(t, updates[2].TimestampNs, uint64(3)) +} diff --git a/pkg/mlsvalidate/service.go b/pkg/mlsvalidate/service.go index 36756ee2..686e533e 100644 --- a/pkg/mlsvalidate/service.go +++ b/pkg/mlsvalidate/service.go @@ -10,8 +10,9 @@ import ( ) type IdentityValidationResult struct { - WalletAddress string - InstallationId string + WalletAddress string + InstallationId []byte + CredentialIdentity []byte } type GroupMessageValidationResult struct { @@ -54,8 +55,9 @@ func (s *MLSValidationServiceImpl) ValidateKeyPackages(ctx context.Context, keyP return nil, fmt.Errorf("validation failed with error %s", response.ErrorMessage) } out[i] = IdentityValidationResult{ - WalletAddress: response.WalletAddress, - InstallationId: response.InstallationId, + WalletAddress: response.WalletAddress, + InstallationId: response.InstallationId, + CredentialIdentity: response.CredentialIdentityBytes, } } return out, nil diff --git a/pkg/mlsvalidate/service_test.go b/pkg/mlsvalidate/service_test.go index 2fb2cca3..9613eb98 100644 --- a/pkg/mlsvalidate/service_test.go +++ b/pkg/mlsvalidate/service_test.go @@ -41,10 +41,11 @@ func TestValidateKeyPackages(t *testing.T) { ctx := context.Background() firstResponse := svc.ValidateKeyPackagesResponse_ValidationResponse{ - IsOk: true, - WalletAddress: "0x123", - InstallationId: "123", - ErrorMessage: "", + IsOk: true, + WalletAddress: "0x123", + InstallationId: []byte("123"), + CredentialIdentityBytes: []byte("456"), + ErrorMessage: "", } mockGrpc.On("ValidateKeyPackages", ctx, mock.Anything).Return(&svc.ValidateKeyPackagesResponse{ @@ -55,7 +56,8 @@ func TestValidateKeyPackages(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, len(res)) assert.Equal(t, "0x123", res[0].WalletAddress) - assert.Equal(t, "123", res[0].InstallationId) + assert.Equal(t, []byte("123"), res[0].InstallationId) + assert.Equal(t, []byte("456"), res[0].CredentialIdentity) } func TestValidateKeyPackagesError(t *testing.T) { diff --git a/pkg/server/options.go b/pkg/server/options.go index e21e3efb..477c99e9 100644 --- a/pkg/server/options.go +++ b/pkg/server/options.go @@ -77,5 +77,5 @@ type Options struct { Tracing TracingOptions `group:"DD APM Tracing Options"` Profiling ProfilingOptions `group:"DD APM Profiling Options" namespace:"profiling"` MLSStore mlsstore.StoreOptions `group:"MLS Options" namespace:"mls-store"` - MlsValidation mlsvalidate.MLSValidationOptions `group:"MLS Validation Options" namespace:"mls-validation"` + MLSValidation mlsvalidate.MLSValidationOptions `group:"MLS Validation Options" namespace:"mls-validation"` } diff --git a/pkg/server/server.go b/pkg/server/server.go index 8887fb82..90f96e12 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -61,7 +61,7 @@ type Server struct { allowLister authz.WalletAllowLister authenticator *authn.XmtpAuthentication grpc *api.Server - MLSStore *mlsstore.Store + mlsDB *bun.DB } // Create a new Server @@ -229,27 +229,24 @@ func New(ctx context.Context, log *zap.Logger, options Options) (*Server, error) } s.log.With(logging.MultiAddrs("listen", maddrs...)).Info("got server") - var MLSStore mlsstore.MlsStore - + var MLSStore *mlsstore.Store if options.MLSStore.DbConnectionString != "" { - mlsDb, err := createBunDB(options.MLSStore.DbConnectionString, options.WaitForDB, options.MLSStore.ReadTimeout, options.MLSStore.WriteTimeout, options.MLSStore.MaxOpenConns) - if err != nil { + if s.mlsDB, err = createBunDB(options.MLSStore.DbConnectionString, options.WaitForDB, options.MLSStore.ReadTimeout, options.MLSStore.WriteTimeout, options.MLSStore.MaxOpenConns); err != nil { return nil, errors.Wrap(err, "creating mls db") } - s.MLSStore, err = mlsstore.New(s.ctx, mlsstore.Config{ + s.log.Info("creating mls store") + if MLSStore, err = mlsstore.New(s.ctx, mlsstore.Config{ Log: s.log, - DB: mlsDb, - }) - - if err != nil { + DB: s.mlsDB, + }); err != nil { return nil, errors.Wrap(err, "creating mls store") } } var MLSValidator mlsvalidate.MLSValidationService - if options.MlsValidation.GRPCAddress != "" { - MLSValidator, err = mlsvalidate.NewMlsValidationService(ctx, options.MlsValidation) + if options.MLSValidation.GRPCAddress != "" { + MLSValidator, err = mlsvalidate.NewMlsValidationService(ctx, options.MLSValidation) if err != nil { return nil, errors.Wrap(err, "creating mls validation service") } @@ -306,8 +303,8 @@ func (s *Server) Shutdown() { if s.store != nil { s.store.Close() } - if s.MLSStore != nil { - s.MLSStore.Close() + if s.mlsDB != nil { + s.mlsDB.Close() } // Close metrics server. diff --git a/pkg/store/query_test.go b/pkg/store/query_test.go index 5d084e4a..d3c2c028 100644 --- a/pkg/store/query_test.go +++ b/pkg/store/query_test.go @@ -1,6 +1,7 @@ package store import ( + "context" "testing" "time" @@ -284,3 +285,33 @@ func TestPageSizeOne(t *testing.T) { loops++ } } + +func TestMlsMessagePublish(t *testing.T) { + store, cleanup, _ := createAndFillDb(t) + defer cleanup() + + message := []byte{1, 2, 3} + contentTopic := "foo" + ctx := context.Background() + + env, err := store.InsertMLSMessage(ctx, contentTopic, message) + require.NoError(t, err) + + require.Equal(t, env.ContentTopic, contentTopic) + require.Equal(t, env.Message, message) + + response, err := store.Query(&messagev1.QueryRequest{ + ContentTopics: []string{contentTopic}, + }) + require.NoError(t, err) + require.Len(t, response.Envelopes, 1) + require.Equal(t, response.Envelopes[0].Message, message) + require.Equal(t, response.Envelopes[0].ContentTopic, contentTopic) + require.NotNil(t, response.Envelopes[0].TimestampNs) + + parsedTime := time.Unix(0, int64(response.Envelopes[0].TimestampNs)) + // Sanity check to ensure that the timestamps are reasonable + require.True(t, time.Since(parsedTime) < 10*time.Second || time.Since(parsedTime) > -10*time.Second) + + require.Equal(t, env.TimestampNs, response.Envelopes[0].TimestampNs) +} diff --git a/pkg/store/store.go b/pkg/store/store.go index 71e17d9a..4a309044 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -2,6 +2,7 @@ package store import ( "context" + "fmt" "strings" "sync" "time" @@ -20,6 +21,22 @@ import ( const maxPageSize = 100 +const timestampGeneratorSql = `( + ( + EXTRACT( + EPOCH + FROM + clock_timestamp() + ) :: bigint * 1000000000 + ) + ( + EXTRACT( + MICROSECONDS + FROM + clock_timestamp() + ) :: bigint * 1000 + ) +)` + type Store struct { config *Config ctx context.Context @@ -127,6 +144,61 @@ func (s *Store) InsertMessage(env *messagev1.Envelope) (bool, error) { return stored, err } +func (s *Store) InsertMLSMessage(ctx context.Context, contentTopic string, data []byte) (*messagev1.Envelope, error) { + tmpEnvelope := &messagev1.Envelope{ + ContentTopic: contentTopic, + Message: data, + } + digest := computeDigest(tmpEnvelope) + var envelope messagev1.Envelope + + err := tracing.Wrap(s.ctx, s.log, "storing mls message", func(ctx context.Context, log *zap.Logger, span tracing.Span) error { + tracing.SpanResource(span, "store") + tracing.SpanType(span, "db") + + stmnt := fmt.Sprintf(`INSERT INTO + message ( + id, + receiverTimestamp, + senderTimestamp, + contentTopic, + pubsubTopic, + payload, + version, + should_expire + ) + VALUES + ( + $1, + %s, + %s, + $2, + $3, + $4, + $5, + $6 + ) RETURNING senderTimestamp`, timestampGeneratorSql, timestampGeneratorSql) + + var senderTimestamp uint64 + err := s.config.DB.QueryRowContext(ctx, stmnt, digest, contentTopic, "", data, 0, false).Scan(&senderTimestamp) + if err != nil { + return err + } + envelope = messagev1.Envelope{ + ContentTopic: contentTopic, + TimestampNs: senderTimestamp, + Message: data, + } + return err + }) + + if err != nil { + return nil, err + } + + return &envelope, nil +} + func (s *Store) insertMessage(env *messagev1.Envelope, receiverTimestamp int64) error { digest := computeDigest(env) shouldExpire := !isXMTP(env.ContentTopic) diff --git a/pkg/testing/random.go b/pkg/testing/random.go index e96427c4..508331db 100644 --- a/pkg/testing/random.go +++ b/pkg/testing/random.go @@ -1,6 +1,7 @@ package testing import ( + cryptoRand "crypto/rand" "math/rand" "strings" ) @@ -18,3 +19,9 @@ func RandomString(n int) string { func RandomStringLower(n int) string { return strings.ToLower(RandomString(n)) } + +func RandomBytes(n int) []byte { + b := make([]byte, n) + _, _ = cryptoRand.Read(b) + return b +} diff --git a/pkg/testing/store.go b/pkg/testing/store.go index f9a2266c..3727fe27 100644 --- a/pkg/testing/store.go +++ b/pkg/testing/store.go @@ -50,7 +50,7 @@ func NewAuthzDB(t *testing.T) (*bun.DB, string, func()) { return bunDB, dsn, cleanup } -func NewMlsDB(t *testing.T) (*bun.DB, string, func()) { +func NewMLSDB(t *testing.T) (*bun.DB, string, func()) { db, dsn, cleanup := NewDB(t) bunDB := bun.NewDB(db, pgdialect.New()) diff --git a/pkg/topic/topic.go b/pkg/topic/topic.go index 0bc16df0..f0192bbf 100644 --- a/pkg/topic/topic.go +++ b/pkg/topic/topic.go @@ -1,6 +1,9 @@ package topic -import "strings" +import ( + "fmt" + "strings" +) var topicCategoryByPrefix = map[string]string{ "test": "test", @@ -35,3 +38,11 @@ func Category(contentTopic string) string { } return "invalid" } + +func BuildGroupTopic(groupId string) string { + return fmt.Sprintf("/xmtp/3/g-%s/proto", groupId) +} + +func BuildWelcomeTopic(installationId []byte) string { + return fmt.Sprintf("/xmtp/3/w-%x/proto", installationId) +}