Skip to content

Commit

Permalink
Make installation_id bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas committed Oct 25, 2023
1 parent cd38308 commit cfffd9b
Show file tree
Hide file tree
Showing 13 changed files with 113 additions and 86 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ require (
github.com/uptrace/bun/driver/pgdriver v1.1.16
github.com/waku-org/go-waku v0.8.0
github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3
github.com/xmtp/proto/v3 v3.29.1-0.20231023182354-832c8d572ed4
github.com/xmtp/proto/v3 v3.29.1-0.20231025220423-87413e63f3ab
github.com/yoheimuta/protolint v0.39.0
go.uber.org/zap v1.24.0
golang.org/x/sync v0.3.0
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1158,6 +1158,10 @@ github.com/xmtp/proto/v3 v3.29.1-0.20231019225839-328520e94f34 h1:rR10cJ5RTlw7OW
github.com/xmtp/proto/v3 v3.29.1-0.20231019225839-328520e94f34/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY=
github.com/xmtp/proto/v3 v3.29.1-0.20231023182354-832c8d572ed4 h1:Qc2ed8NrlosJnPMNxVriugcFB21d4V90HKZdO83yV2M=
github.com/xmtp/proto/v3 v3.29.1-0.20231023182354-832c8d572ed4/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY=
github.com/xmtp/proto/v3 v3.29.1-0.20231025193535-9760f07c3401 h1:ctStMkU5570kEgpVnsklo7PUOuTridm8bkXcZ1i6sXI=
github.com/xmtp/proto/v3 v3.29.1-0.20231025193535-9760f07c3401/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY=
github.com/xmtp/proto/v3 v3.29.1-0.20231025220423-87413e63f3ab h1:hWBftgxB7QWXDOOv1Wah6VZ6mwSFZX8e8rEGJQHm8zA=
github.com/xmtp/proto/v3 v3.29.1-0.20231025220423-87413e63f3ab/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
github.com/yoheimuta/go-protoparser/v4 v4.6.0 h1:uvz1e9/5Ihsm4Ku8AJeDImTpirKmIxubZdSn0QJNdnw=
github.com/yoheimuta/go-protoparser/v4 v4.6.0/go.mod h1:AHNNnSWnb0UoL4QgHPiOAg2BniQceFscPI5X/BZNHl8=
Expand Down
12 changes: 6 additions & 6 deletions pkg/api/message/v3/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,20 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterI
}

func (s *Service) ConsumeKeyPackages(ctx context.Context, req *proto.ConsumeKeyPackagesRequest) (*proto.ConsumeKeyPackagesResponse, error) {
ids := req.InstallationIds
ids := mlsstore.InstallationIdArray(req.InstallationIds)
keyPackages, err := s.mlsStore.ConsumeKeyPackages(ctx, ids)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to consume key packages: %s", err)
}
keyPackageMap := make(map[string]int)
for idx, id := range ids {
keyPackageMap[id] = idx
keyPackageMap[id.String()] = idx
}

resPackages := make([]*proto.ConsumeKeyPackagesResponse_KeyPackage, len(keyPackages))
for _, keyPackage := range keyPackages {

idx, ok := keyPackageMap[keyPackage.InstallationId]
idx, ok := keyPackageMap[keyPackage.InstallationId.String()]
if !ok {
return nil, status.Errorf(codes.Internal, "could not find key package for installation")
}
Expand Down Expand Up @@ -117,7 +117,7 @@ func (s *Service) PublishToGroup(ctx context.Context, req *proto.PublishToGroupR
for i, result := range validationResults {
message := messages[i]

if err = isReadyToSend(result.GroupId, message); err != nil {
if err = requireReadyToSend(result.GroupId, message); err != nil {
return nil, err
}

Expand All @@ -132,7 +132,7 @@ func (s *Service) PublishToGroup(ctx context.Context, req *proto.PublishToGroupR

func (s *Service) publishMessage(ctx context.Context, contentTopic string, message []byte) error {
log := s.log.Named("publish-mls").With(zap.String("content_topic", contentTopic))
env, err := s.messageStore.InsertMlsMessage(ctx, contentTopic, message)
env, err := s.messageStore.InsertMLSMessage(ctx, contentTopic, message)
if err != nil {
return status.Errorf(codes.Internal, "failed to insert message: %s", err)
}
Expand Down Expand Up @@ -293,7 +293,7 @@ func validateGetIdentityUpdatesRequest(req *proto.GetIdentityUpdatesRequest) err
return nil
}

func isReadyToSend(groupId string, message []byte) error {
func requireReadyToSend(groupId string, message []byte) error {
if groupId == "" {
return status.Errorf(codes.InvalidArgument, "group id is empty")
}
Expand Down
39 changes: 20 additions & 19 deletions pkg/api/message/v3/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ import (
test "github.com/xmtp/xmtp-node-go/pkg/testing"
)

type mockedMlsValidationService struct {
type mockedMLSValidationService struct {
mock.Mock
}

func (m *mockedMlsValidationService) ValidateKeyPackages(ctx context.Context, keyPackages [][]byte) ([]mlsvalidate.IdentityValidationResult, error) {
func (m *mockedMLSValidationService) ValidateKeyPackages(ctx context.Context, keyPackages [][]byte) ([]mlsvalidate.IdentityValidationResult, error) {
args := m.Called(ctx, keyPackages)

response := args.Get(0)
Expand All @@ -33,34 +33,35 @@ func (m *mockedMlsValidationService) ValidateKeyPackages(ctx context.Context, ke
return response.([]mlsvalidate.IdentityValidationResult), args.Error(1)
}

func (m *mockedMlsValidationService) ValidateGroupMessages(ctx context.Context, groupMessages [][]byte) ([]mlsvalidate.GroupMessageValidationResult, error) {
func (m *mockedMLSValidationService) ValidateGroupMessages(ctx context.Context, groupMessages [][]byte) ([]mlsvalidate.GroupMessageValidationResult, error) {
args := m.Called(ctx, groupMessages)

return args.Get(0).([]mlsvalidate.GroupMessageValidationResult), args.Error(1)
}

func newMockedValidationService() *mockedMlsValidationService {
return new(mockedMlsValidationService)
func newMockedValidationService() *mockedMLSValidationService {
return new(mockedMLSValidationService)
}

func (m *mockedMlsValidationService) mockValidateKeyPackages(installationId, walletAddress string) *mock.Call {
func (m *mockedMLSValidationService) mockValidateKeyPackages(installationId mlsstore.InstallationId, walletAddress string) *mock.Call {
return m.On("ValidateKeyPackages", mock.Anything, mock.Anything).Return([]mlsvalidate.IdentityValidationResult{
{
InstallationId: installationId,
WalletAddress: walletAddress,
InstallationId: installationId,
WalletAddress: walletAddress,
CredentialIdentity: []byte("test"),
},
}, nil)
}

func (m *mockedMlsValidationService) mockValidateGroupMessages(groupId string) *mock.Call {
func (m *mockedMLSValidationService) mockValidateGroupMessages(groupId string) *mock.Call {
return m.On("ValidateGroupMessages", mock.Anything, mock.Anything).Return([]mlsvalidate.GroupMessageValidationResult{
{
GroupId: groupId,
},
}, nil)
}

func newTestService(t *testing.T, ctx context.Context) (*Service, *bun.DB, *mockedMlsValidationService, func()) {
func newTestService(t *testing.T, ctx context.Context) (*Service, *bun.DB, *mockedMLSValidationService, func()) {
log := test.NewLog(t)
mlsDb, _, mlsDbCleanup := test.NewMLSDB(t)
mlsStore, err := mlsstore.New(ctx, mlsstore.Config{
Expand Down Expand Up @@ -95,7 +96,7 @@ func TestRegisterInstallation(t *testing.T) {
svc, mlsDb, mlsValidationService, cleanup := newTestService(t, ctx)
defer cleanup()

installationId := test.RandomString(32)
installationId := test.RandomBytes(32)
walletAddress := test.RandomString(32)

mlsValidationService.mockValidateKeyPackages(installationId, walletAddress)
Expand Down Expand Up @@ -138,7 +139,7 @@ func TestUploadKeyPackages(t *testing.T) {
svc, mlsDb, mlsValidationService, cleanup := newTestService(t, ctx)
defer cleanup()

installationId := test.RandomString(32)
installationId := test.RandomBytes(32)
walletAddress := test.RandomString(32)

mlsValidationService.mockValidateKeyPackages(installationId, walletAddress)
Expand Down Expand Up @@ -170,7 +171,7 @@ func TestConsumeKeyPackages(t *testing.T) {
svc, _, mlsValidationService, cleanup := newTestService(t, ctx)
defer cleanup()

installationId1 := test.RandomString(32)
installationId1 := test.RandomBytes(32)
walletAddress1 := test.RandomString(32)

mockCall := mlsValidationService.mockValidateKeyPackages(installationId1, walletAddress1)
Expand All @@ -184,7 +185,7 @@ func TestConsumeKeyPackages(t *testing.T) {
require.NotNil(t, res)

// Add a second key package
installationId2 := test.RandomString(32)
installationId2 := test.RandomBytes(32)
walletAddress2 := test.RandomString(32)
// Unset the original mock so we can set a new one
mockCall.Unset()
Expand All @@ -199,7 +200,7 @@ func TestConsumeKeyPackages(t *testing.T) {
require.NotNil(t, res)

consumeRes, err := svc.ConsumeKeyPackages(ctx, &proto.ConsumeKeyPackagesRequest{
InstallationIds: []string{installationId1, installationId2},
InstallationIds: [][]byte{installationId1, installationId2},
})
require.NoError(t, err)
require.NotNil(t, consumeRes)
Expand All @@ -209,7 +210,7 @@ func TestConsumeKeyPackages(t *testing.T) {

// Now do it with the installationIds reversed
consumeRes, err = svc.ConsumeKeyPackages(ctx, &proto.ConsumeKeyPackagesRequest{
InstallationIds: []string{installationId2, installationId1},
InstallationIds: [][]byte{installationId2, installationId1},
})

require.NoError(t, err)
Expand All @@ -226,7 +227,7 @@ func TestConsumeKeyPackagesFail(t *testing.T) {
defer cleanup()

consumeRes, err := svc.ConsumeKeyPackages(ctx, &proto.ConsumeKeyPackagesRequest{
InstallationIds: []string{test.RandomString(32)},
InstallationIds: [][]byte{test.RandomBytes(32)},
})
require.Error(t, err)
require.Nil(t, consumeRes)
Expand Down Expand Up @@ -266,7 +267,7 @@ func TestGetIdentityUpdates(t *testing.T) {
svc, _, mlsValidationService, cleanup := newTestService(t, ctx)
defer cleanup()

installationId := test.RandomString(32)
installationId := test.RandomBytes(32)
walletAddress := test.RandomString(32)

mockCall := mlsValidationService.mockValidateKeyPackages(installationId, walletAddress)
Expand All @@ -292,7 +293,7 @@ func TestGetIdentityUpdates(t *testing.T) {
}

mockCall.Unset()
mlsValidationService.mockValidateKeyPackages(test.RandomString(32), walletAddress)
mlsValidationService.mockValidateKeyPackages(test.RandomBytes(32), walletAddress)
_, err = svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{
LastResortKeyPackage: &proto.KeyPackageUpload{
KeyPackageTlsSerialized: []byte("test"),
Expand Down
5 changes: 3 additions & 2 deletions pkg/migrations/mls/20231023050806_init-schema.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@ SET

--bun:split
CREATE TABLE installations (
id TEXT PRIMARY KEY,
id BYTEA PRIMARY KEY,
wallet_address TEXT NOT NULL,
created_at BIGINT NOT NULL,
credential_identity BYTEA NOT NULL,
revoked_at BIGINT
);

--bun:split
CREATE TABLE key_packages (
id TEXT PRIMARY KEY,
installation_id TEXT NOT NULL,
installation_id BYTEA NOT NULL,
created_at BIGINT NOT NULL,
consumed_at BIGINT,
not_consumed BOOLEAN DEFAULT TRUE NOT NULL,
Expand Down
37 changes: 26 additions & 11 deletions pkg/mlsstore/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,38 @@ package mlsstore

import "github.com/uptrace/bun"

type InstallationId []byte

func (id InstallationId) String() string {
return string(id)
}

type Installation struct {
bun.BaseModel `bun:"table:installations"`

ID string `bun:",pk"`
WalletAddress string `bun:"wallet_address,notnull"`
CreatedAt int64 `bun:"created_at,notnull"`
RevokedAt *int64 `bun:"revoked_at"`
ID InstallationId `bun:",pk,type:bytea"`
WalletAddress string `bun:"wallet_address,notnull"`
CreatedAt int64 `bun:"created_at,notnull"`
RevokedAt *int64 `bun:"revoked_at"`
CredentialIdentity []byte `bun:"credential_identity,notnull,type:bytea"`
}

type KeyPackage struct {
bun.BaseModel `bun:"table:key_packages"`

ID string `bun:",pk"` // ID is the hash of the data field
InstallationId string `bun:"installation_id,notnull"`
CreatedAt int64 `bun:"created_at,notnull"`
ConsumedAt *int64 `bun:"consumed_at"`
NotConsumed bool `bun:"not_consumed,default:true"`
IsLastResort bool `bun:"is_last_resort,notnull"`
Data []byte `bun:"data,notnull,type:bytea"`
ID string `bun:",pk"` // ID is the hash of the data field
InstallationId InstallationId `bun:"installation_id,notnull,type:bytea"`
CreatedAt int64 `bun:"created_at,notnull"`
ConsumedAt *int64 `bun:"consumed_at"`
NotConsumed bool `bun:"not_consumed,default:true"`
IsLastResort bool `bun:"is_last_resort,notnull"`
Data []byte `bun:"data,notnull,type:bytea"`
}

func InstallationIdArray(data [][]byte) []InstallationId {
result := make([]InstallationId, len(data))
for i, d := range data {
result[i] = InstallationId(d)
}
return result
}
26 changes: 14 additions & 12 deletions pkg/mlsstore/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ type Store struct {
}

type MlsStore interface {
CreateInstallation(ctx context.Context, installationId string, walletAddress string, lastResortKeyPackage []byte) error
CreateInstallation(ctx context.Context, installationId InstallationId, walletAddress string, lastResortKeyPackage []byte, credentialIdentity []byte) error
InsertKeyPackages(ctx context.Context, keyPackages []*KeyPackage) error
ConsumeKeyPackages(ctx context.Context, installationIds []string) ([]*KeyPackage, error)
ConsumeKeyPackages(ctx context.Context, installationIds []InstallationId) ([]*KeyPackage, error)
GetIdentityUpdates(ctx context.Context, walletAddresses []string, startTimeNs int64) (map[string]IdentityUpdateList, error)
}

Expand All @@ -43,13 +43,14 @@ func New(ctx context.Context, config Config) (*Store, error) {
}

// Creates the installation and last resort key package
func (s *Store) CreateInstallation(ctx context.Context, installationId string, walletAddress string, lastResortKeyPackage []byte) error {
func (s *Store) CreateInstallation(ctx context.Context, installationId InstallationId, walletAddress string, lastResortKeyPackage []byte, credentialIdentity []byte) error {
createdAt := nowNs()

installation := Installation{
ID: installationId,
WalletAddress: walletAddress,
CreatedAt: createdAt,
ID: installationId,
WalletAddress: walletAddress,
CreatedAt: createdAt,
CredentialIdentity: credentialIdentity,
}

keyPackage := NewKeyPackage(installationId, lastResortKeyPackage, true)
Expand Down Expand Up @@ -83,7 +84,7 @@ func (s *Store) InsertKeyPackages(ctx context.Context, keyPackages []*KeyPackage
return err
}

func (s *Store) ConsumeKeyPackages(ctx context.Context, installationIds []string) ([]*KeyPackage, error) {
func (s *Store) ConsumeKeyPackages(ctx context.Context, installationIds []InstallationId) ([]*KeyPackage, error) {
keyPackages := make([]*KeyPackage, 0)
err := s.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
err := tx.NewRaw(`
Expand Down Expand Up @@ -163,7 +164,7 @@ func (s *Store) GetIdentityUpdates(ctx context.Context, walletAddresses []string
return out, nil
}

func (s *Store) RevokeInstallation(ctx context.Context, installationId string) error {
func (s *Store) RevokeInstallation(ctx context.Context, installationId InstallationId) error {
_, err := s.db.NewUpdate().
Model(&Installation{}).
Set("revoked_at = ?", nowNs()).
Expand All @@ -174,7 +175,7 @@ func (s *Store) RevokeInstallation(ctx context.Context, installationId string) e
return err
}

func NewKeyPackage(installationId string, data []byte, isLastResort bool) *KeyPackage {
func NewKeyPackage(installationId InstallationId, data []byte, isLastResort bool) *KeyPackage {
return &KeyPackage{
ID: buildKeyPackageId(data),
InstallationId: installationId,
Expand Down Expand Up @@ -229,9 +230,10 @@ const (
)

type IdentityUpdate struct {
Kind IdentityUpdateKind
InstallationId string
TimestampNs uint64
Kind IdentityUpdateKind
InstallationId []byte
CredentialIdentity []byte
TimestampNs uint64
}

// Add the required methods to make a valid sort.Sort interface
Expand Down
Loading

0 comments on commit cfffd9b

Please sign in to comment.