Skip to content

Commit

Permalink
Add more MLS methods (#321)
Browse files Browse the repository at this point in the history
* Scaffold MLS server

* Update go.mod

* Fix missing argument

* Add unsaved file

* Lint

* Working end-to-end

* Lint

* Add new push action

* Add a bunch of new endpoints

* Address review comments

* Change method casing

* Change casing of server options

* Change casing of validation options

* Remove unused function

* Remove double pointer

* Make private again

* Fix pointer to key package

* Capitalize more things

* Update server fields

* Add test for sort methods

* Save change to capitalization

* Fix lint warnings

* Fix problem with mocks

* Fix index name

* Move sorting to the store

* Fix ciphertext validation

* Make installation_id bytes

* Add missing credential identity

* Hack sql in query

* Revert "Hack sql in query"

This reverts commit 168b78a.

* Remove custom type

* Update to latest protos

* Add CredentialIdentity
  • Loading branch information
neekolas authored and Steven Normore committed Jan 23, 2024
1 parent 3d22b24 commit 8f6f633
Show file tree
Hide file tree
Showing 16 changed files with 872 additions and 110 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ require (
github.com/uptrace/bun/driver/pgdriver v1.1.16
github.com/waku-org/go-waku v0.8.0
github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3
github.com/xmtp/proto/v3 v3.29.1-0.20231023182354-832c8d572ed4
github.com/xmtp/proto/v3 v3.32.1-0.20231026053711-5efc208e3135
github.com/yoheimuta/protolint v0.39.0
go.uber.org/zap v1.24.0
golang.org/x/sync v0.3.0
Expand Down
213 changes: 190 additions & 23 deletions pkg/api/message/v3/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ import (
"context"

wakunode "github.com/waku-org/go-waku/waku/v2/node"
wakupb "github.com/waku-org/go-waku/waku/v2/protocol/pb"
proto "github.com/xmtp/proto/v3/go/message_api/v3"
"github.com/xmtp/xmtp-node-go/pkg/metrics"
"github.com/xmtp/xmtp-node-go/pkg/mlsstore"
"github.com/xmtp/xmtp-node-go/pkg/mlsvalidate"
"github.com/xmtp/xmtp-node-go/pkg/store"
"github.com/xmtp/xmtp-node-go/pkg/topic"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
Expand All @@ -22,9 +25,6 @@ type Service struct {
messageStore *store.Store
mlsStore mlsstore.MlsStore
validationService mlsvalidate.MLSValidationService

ctx context.Context
ctxCancel func()
}

func NewService(node *wakunode.WakuNode, logger *zap.Logger, messageStore *store.Store, mlsStore mlsstore.MlsStore, validationService mlsvalidate.MLSValidationService) (s *Service, err error) {
Expand All @@ -36,18 +36,15 @@ func NewService(node *wakunode.WakuNode, logger *zap.Logger, messageStore *store
validationService: validationService,
}

s.ctx, s.ctxCancel = context.WithCancel(context.Background())

s.log.Info("Starting MLS service")
return s, nil
}

func (s *Service) Close() {
if s.ctxCancel != nil {
s.ctxCancel()
func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterInstallationRequest) (*proto.RegisterInstallationResponse, error) {
if err := validateRegisterInstallationRequest(req); err != nil {
return nil, err
}
}

func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterInstallationRequest) (*proto.RegisterInstallationResponse, error) {
results, err := s.validationService.ValidateKeyPackages(ctx, [][]byte{req.LastResortKeyPackage.KeyPackageTlsSerialized})
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err)
Expand All @@ -58,9 +55,9 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterI

installationId := results[0].InstallationId
walletAddress := results[0].WalletAddress
credentialIdentity := results[0].CredentialIdentity

err = s.mlsStore.CreateInstallation(ctx, installationId, walletAddress, req.LastResortKeyPackage.KeyPackageTlsSerialized)
if err != nil {
if err = s.mlsStore.CreateInstallation(ctx, installationId, walletAddress, req.LastResortKeyPackage.KeyPackageTlsSerialized, credentialIdentity); err != nil {
return nil, err
}

Expand All @@ -77,13 +74,13 @@ func (s *Service) ConsumeKeyPackages(ctx context.Context, req *proto.ConsumeKeyP
}
keyPackageMap := make(map[string]int)
for idx, id := range ids {
keyPackageMap[id] = idx
keyPackageMap[string(id)] = idx
}

resPackages := make([]*proto.ConsumeKeyPackagesResponse_KeyPackage, len(keyPackages))
for _, keyPackage := range keyPackages {

idx, ok := keyPackageMap[keyPackage.InstallationId]
idx, ok := keyPackageMap[string(keyPackage.InstallationId)]
if !ok {
return nil, status.Errorf(codes.Internal, "could not find key package for installation")
}
Expand All @@ -98,15 +95,82 @@ func (s *Service) ConsumeKeyPackages(ctx context.Context, req *proto.ConsumeKeyP
}, nil
}

func (s *Service) PublishToGroup(ctx context.Context, req *proto.PublishToGroupRequest) (*emptypb.Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "unimplemented")
func (s *Service) PublishToGroup(ctx context.Context, req *proto.PublishToGroupRequest) (res *emptypb.Empty, err error) {
if err = validatePublishToGroupRequest(req); err != nil {
return nil, err
}

messages := make([][]byte, len(req.Messages))
for i, message := range req.Messages {
v1 := message.GetV1()
if v1 == nil {
return nil, status.Errorf(codes.InvalidArgument, "message must be v1")
}
messages[i] = v1.MlsMessageTlsSerialized
}

validationResults, err := s.validationService.ValidateGroupMessages(ctx, messages)
if err != nil {
// TODO: Separate validation errors from internal errors
return nil, status.Errorf(codes.InvalidArgument, "invalid group message: %s", err)
}

for i, result := range validationResults {
message := messages[i]

if err = requireReadyToSend(result.GroupId, message); err != nil {
return nil, err
}

// TODO: Wrap this in a transaction so publishing is all or nothing
if err = s.publishMessage(ctx, topic.BuildGroupTopic(result.GroupId), message); err != nil {
return nil, status.Errorf(codes.Internal, "failed to publish message: %s", err)
}
}

return &emptypb.Empty{}, nil
}

func (s *Service) PublishWelcomes(ctx context.Context, req *proto.PublishWelcomesRequest) (*emptypb.Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "unimplemented")
func (s *Service) publishMessage(ctx context.Context, contentTopic string, message []byte) error {
log := s.log.Named("publish-mls").With(zap.String("content_topic", contentTopic))
env, err := s.messageStore.InsertMLSMessage(ctx, contentTopic, message)
if err != nil {
return status.Errorf(codes.Internal, "failed to insert message: %s", err)
}

if _, err = s.waku.Relay().Publish(ctx, &wakupb.WakuMessage{
ContentTopic: contentTopic,
Timestamp: int64(env.TimestampNs),
Payload: message,
}); err != nil {
return status.Errorf(codes.Internal, "failed to publish message: %s", err)
}

metrics.EmitPublishedEnvelope(ctx, log, env)

return nil
}

func (s *Service) PublishWelcomes(ctx context.Context, req *proto.PublishWelcomesRequest) (res *emptypb.Empty, err error) {
if err = validatePublishWelcomesRequest(req); err != nil {
return nil, err
}

// TODO: Wrap this in a transaction so publishing is all or nothing
for _, welcome := range req.WelcomeMessages {
contentTopic := topic.BuildWelcomeTopic(welcome.InstallationId)
if err = s.publishMessage(ctx, contentTopic, welcome.WelcomeMessage.GetV1().WelcomeMessageTlsSerialized); err != nil {
return nil, status.Errorf(codes.Internal, "failed to publish welcome message: %s", err)
}
}
return &emptypb.Empty{}, nil
}

func (s *Service) UploadKeyPackages(ctx context.Context, req *proto.UploadKeyPackagesRequest) (*emptypb.Empty, error) {
func (s *Service) UploadKeyPackages(ctx context.Context, req *proto.UploadKeyPackagesRequest) (res *emptypb.Empty, err error) {
if err = validateUploadKeyPackagesRequest(req); err != nil {
return nil, err
}
// Extract the key packages from the request
keyPackageBytes := make([][]byte, len(req.KeyPackages))
for i, keyPackage := range req.KeyPackages {
keyPackageBytes[i] = keyPackage.KeyPackageTlsSerialized
Expand All @@ -122,8 +186,8 @@ func (s *Service) UploadKeyPackages(ctx context.Context, req *proto.UploadKeyPac
kp := mlsstore.NewKeyPackage(validationResult.InstallationId, keyPackageBytes[i], false)
keyPackageModels[i] = kp
}
err = s.mlsStore.InsertKeyPackages(ctx, keyPackageModels)
if err != nil {

if err = s.mlsStore.InsertKeyPackages(ctx, keyPackageModels); err != nil {
return nil, status.Errorf(codes.Internal, "failed to insert key packages: %s", err)
}

Expand All @@ -134,6 +198,109 @@ func (s *Service) RevokeInstallation(ctx context.Context, req *proto.RevokeInsta
return nil, status.Errorf(codes.Unimplemented, "unimplemented")
}

func (s *Service) GetIdentityUpdates(ctx context.Context, req *proto.GetIdentityUpdatesRequest) (*proto.GetIdentityUpdatesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "unimplemented")
func (s *Service) GetIdentityUpdates(ctx context.Context, req *proto.GetIdentityUpdatesRequest) (res *proto.GetIdentityUpdatesResponse, err error) {
if err = validateGetIdentityUpdatesRequest(req); err != nil {
return nil, err
}

walletAddresses := req.WalletAddresses
updates, err := s.mlsStore.GetIdentityUpdates(ctx, req.WalletAddresses, int64(req.StartTimeNs))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity updates: %s", err)
}

resUpdates := make([]*proto.GetIdentityUpdatesResponse_WalletUpdates, len(walletAddresses))
for i, walletAddress := range walletAddresses {
walletUpdates := updates[walletAddress]

resUpdates[i] = &proto.GetIdentityUpdatesResponse_WalletUpdates{
Updates: []*proto.GetIdentityUpdatesResponse_Update{},
}

for _, walletUpdate := range walletUpdates {
resUpdates[i].Updates = append(resUpdates[i].Updates, buildIdentityUpdate(walletUpdate))
}
}

return &proto.GetIdentityUpdatesResponse{
Updates: resUpdates,
}, nil
}

func buildIdentityUpdate(update mlsstore.IdentityUpdate) *proto.GetIdentityUpdatesResponse_Update {
base := proto.GetIdentityUpdatesResponse_Update{
TimestampNs: update.TimestampNs,
}
switch update.Kind {
case mlsstore.Create:
base.Kind = &proto.GetIdentityUpdatesResponse_Update_NewInstallation{
NewInstallation: &proto.GetIdentityUpdatesResponse_NewInstallationUpdate{
InstallationId: update.InstallationId,
CredentialIdentity: update.CredentialIdentity,
},
}
case mlsstore.Revoke:
base.Kind = &proto.GetIdentityUpdatesResponse_Update_RevokedInstallation{
RevokedInstallation: &proto.GetIdentityUpdatesResponse_RevokedInstallationUpdate{
InstallationId: update.InstallationId,
},
}
}

return &base
}

func validatePublishToGroupRequest(req *proto.PublishToGroupRequest) error {
if req == nil || len(req.Messages) == 0 {
return status.Errorf(codes.InvalidArgument, "no messages to publish")
}
return nil
}

func validatePublishWelcomesRequest(req *proto.PublishWelcomesRequest) error {
if req == nil || len(req.WelcomeMessages) == 0 {
return status.Errorf(codes.InvalidArgument, "no welcome messages to publish")
}
for _, welcome := range req.WelcomeMessages {
if welcome == nil || welcome.WelcomeMessage == nil {
return status.Errorf(codes.InvalidArgument, "invalid welcome message")
}

v1 := welcome.WelcomeMessage.GetV1()
if v1 == nil || len(v1.WelcomeMessageTlsSerialized) == 0 {
return status.Errorf(codes.InvalidArgument, "invalid welcome message")
}
}
return nil
}

func validateRegisterInstallationRequest(req *proto.RegisterInstallationRequest) error {
if req == nil || req.LastResortKeyPackage == nil {
return status.Errorf(codes.InvalidArgument, "no last resort key package")
}
return nil
}

func validateUploadKeyPackagesRequest(req *proto.UploadKeyPackagesRequest) error {
if req == nil || len(req.KeyPackages) == 0 {
return status.Errorf(codes.InvalidArgument, "no key packages to upload")
}
return nil
}

func validateGetIdentityUpdatesRequest(req *proto.GetIdentityUpdatesRequest) error {
if req == nil || len(req.WalletAddresses) == 0 {
return status.Errorf(codes.InvalidArgument, "no wallet addresses to get updates for")
}
return nil
}

func requireReadyToSend(groupId string, message []byte) error {
if groupId == "" {
return status.Errorf(codes.InvalidArgument, "group id is empty")
}
if len(message) == 0 {
return status.Errorf(codes.InvalidArgument, "message is empty")
}
return nil
}
Loading

0 comments on commit 8f6f633

Please sign in to comment.