diff --git a/go.mod b/go.mod index e86b30f8..1481fb76 100644 --- a/go.mod +++ b/go.mod @@ -30,7 +30,7 @@ require ( github.com/uptrace/bun/driver/pgdriver v1.1.16 github.com/waku-org/go-waku v0.8.0 github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3 - github.com/xmtp/proto/v3 v3.29.1-0.20231023182354-832c8d572ed4 + github.com/xmtp/proto/v3 v3.29.1-0.20231025220423-87413e63f3ab github.com/yoheimuta/protolint v0.39.0 go.uber.org/zap v1.24.0 golang.org/x/sync v0.3.0 diff --git a/go.sum b/go.sum index 506a1663..9e9fac51 100644 --- a/go.sum +++ b/go.sum @@ -1158,6 +1158,10 @@ github.com/xmtp/proto/v3 v3.29.1-0.20231019225839-328520e94f34 h1:rR10cJ5RTlw7OW github.com/xmtp/proto/v3 v3.29.1-0.20231019225839-328520e94f34/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= github.com/xmtp/proto/v3 v3.29.1-0.20231023182354-832c8d572ed4 h1:Qc2ed8NrlosJnPMNxVriugcFB21d4V90HKZdO83yV2M= github.com/xmtp/proto/v3 v3.29.1-0.20231023182354-832c8d572ed4/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= +github.com/xmtp/proto/v3 v3.29.1-0.20231025193535-9760f07c3401 h1:ctStMkU5570kEgpVnsklo7PUOuTridm8bkXcZ1i6sXI= +github.com/xmtp/proto/v3 v3.29.1-0.20231025193535-9760f07c3401/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= +github.com/xmtp/proto/v3 v3.29.1-0.20231025220423-87413e63f3ab h1:hWBftgxB7QWXDOOv1Wah6VZ6mwSFZX8e8rEGJQHm8zA= +github.com/xmtp/proto/v3 v3.29.1-0.20231025220423-87413e63f3ab/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= github.com/yoheimuta/go-protoparser/v4 v4.6.0 h1:uvz1e9/5Ihsm4Ku8AJeDImTpirKmIxubZdSn0QJNdnw= github.com/yoheimuta/go-protoparser/v4 v4.6.0/go.mod h1:AHNNnSWnb0UoL4QgHPiOAg2BniQceFscPI5X/BZNHl8= diff --git a/pkg/api/message/v3/service.go b/pkg/api/message/v3/service.go index cf388f63..3f79bdec 100644 --- a/pkg/api/message/v3/service.go +++ b/pkg/api/message/v3/service.go @@ -66,20 +66,20 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterI } func (s *Service) ConsumeKeyPackages(ctx context.Context, req *proto.ConsumeKeyPackagesRequest) (*proto.ConsumeKeyPackagesResponse, error) { - ids := req.InstallationIds + ids := mlsstore.InstallationIdArray(req.InstallationIds) keyPackages, err := s.mlsStore.ConsumeKeyPackages(ctx, ids) if err != nil { return nil, status.Errorf(codes.Internal, "failed to consume key packages: %s", err) } keyPackageMap := make(map[string]int) for idx, id := range ids { - keyPackageMap[id] = idx + keyPackageMap[id.String()] = idx } resPackages := make([]*proto.ConsumeKeyPackagesResponse_KeyPackage, len(keyPackages)) for _, keyPackage := range keyPackages { - idx, ok := keyPackageMap[keyPackage.InstallationId] + idx, ok := keyPackageMap[keyPackage.InstallationId.String()] if !ok { return nil, status.Errorf(codes.Internal, "could not find key package for installation") } @@ -117,7 +117,7 @@ func (s *Service) PublishToGroup(ctx context.Context, req *proto.PublishToGroupR for i, result := range validationResults { message := messages[i] - if err = isReadyToSend(result.GroupId, message); err != nil { + if err = requireReadyToSend(result.GroupId, message); err != nil { return nil, err } @@ -132,7 +132,7 @@ func (s *Service) PublishToGroup(ctx context.Context, req *proto.PublishToGroupR func (s *Service) publishMessage(ctx context.Context, contentTopic string, message []byte) error { log := s.log.Named("publish-mls").With(zap.String("content_topic", contentTopic)) - env, err := s.messageStore.InsertMlsMessage(ctx, contentTopic, message) + env, err := s.messageStore.InsertMLSMessage(ctx, contentTopic, message) if err != nil { return status.Errorf(codes.Internal, "failed to insert message: %s", err) } @@ -293,7 +293,7 @@ func validateGetIdentityUpdatesRequest(req *proto.GetIdentityUpdatesRequest) err return nil } -func isReadyToSend(groupId string, message []byte) error { +func requireReadyToSend(groupId string, message []byte) error { if groupId == "" { return status.Errorf(codes.InvalidArgument, "group id is empty") } diff --git a/pkg/api/message/v3/service_test.go b/pkg/api/message/v3/service_test.go index 9bc9bef0..c353f617 100644 --- a/pkg/api/message/v3/service_test.go +++ b/pkg/api/message/v3/service_test.go @@ -18,11 +18,11 @@ import ( test "github.com/xmtp/xmtp-node-go/pkg/testing" ) -type mockedMlsValidationService struct { +type mockedMLSValidationService struct { mock.Mock } -func (m *mockedMlsValidationService) ValidateKeyPackages(ctx context.Context, keyPackages [][]byte) ([]mlsvalidate.IdentityValidationResult, error) { +func (m *mockedMLSValidationService) ValidateKeyPackages(ctx context.Context, keyPackages [][]byte) ([]mlsvalidate.IdentityValidationResult, error) { args := m.Called(ctx, keyPackages) response := args.Get(0) @@ -33,26 +33,27 @@ func (m *mockedMlsValidationService) ValidateKeyPackages(ctx context.Context, ke return response.([]mlsvalidate.IdentityValidationResult), args.Error(1) } -func (m *mockedMlsValidationService) ValidateGroupMessages(ctx context.Context, groupMessages [][]byte) ([]mlsvalidate.GroupMessageValidationResult, error) { +func (m *mockedMLSValidationService) ValidateGroupMessages(ctx context.Context, groupMessages [][]byte) ([]mlsvalidate.GroupMessageValidationResult, error) { args := m.Called(ctx, groupMessages) return args.Get(0).([]mlsvalidate.GroupMessageValidationResult), args.Error(1) } -func newMockedValidationService() *mockedMlsValidationService { - return new(mockedMlsValidationService) +func newMockedValidationService() *mockedMLSValidationService { + return new(mockedMLSValidationService) } -func (m *mockedMlsValidationService) mockValidateKeyPackages(installationId, walletAddress string) *mock.Call { +func (m *mockedMLSValidationService) mockValidateKeyPackages(installationId mlsstore.InstallationId, walletAddress string) *mock.Call { return m.On("ValidateKeyPackages", mock.Anything, mock.Anything).Return([]mlsvalidate.IdentityValidationResult{ { - InstallationId: installationId, - WalletAddress: walletAddress, + InstallationId: installationId, + WalletAddress: walletAddress, + CredentialIdentity: []byte("test"), }, }, nil) } -func (m *mockedMlsValidationService) mockValidateGroupMessages(groupId string) *mock.Call { +func (m *mockedMLSValidationService) mockValidateGroupMessages(groupId string) *mock.Call { return m.On("ValidateGroupMessages", mock.Anything, mock.Anything).Return([]mlsvalidate.GroupMessageValidationResult{ { GroupId: groupId, @@ -60,7 +61,7 @@ func (m *mockedMlsValidationService) mockValidateGroupMessages(groupId string) * }, nil) } -func newTestService(t *testing.T, ctx context.Context) (*Service, *bun.DB, *mockedMlsValidationService, func()) { +func newTestService(t *testing.T, ctx context.Context) (*Service, *bun.DB, *mockedMLSValidationService, func()) { log := test.NewLog(t) mlsDb, _, mlsDbCleanup := test.NewMLSDB(t) mlsStore, err := mlsstore.New(ctx, mlsstore.Config{ @@ -95,7 +96,7 @@ func TestRegisterInstallation(t *testing.T) { svc, mlsDb, mlsValidationService, cleanup := newTestService(t, ctx) defer cleanup() - installationId := test.RandomString(32) + installationId := test.RandomBytes(32) walletAddress := test.RandomString(32) mlsValidationService.mockValidateKeyPackages(installationId, walletAddress) @@ -138,7 +139,7 @@ func TestUploadKeyPackages(t *testing.T) { svc, mlsDb, mlsValidationService, cleanup := newTestService(t, ctx) defer cleanup() - installationId := test.RandomString(32) + installationId := test.RandomBytes(32) walletAddress := test.RandomString(32) mlsValidationService.mockValidateKeyPackages(installationId, walletAddress) @@ -170,7 +171,7 @@ func TestConsumeKeyPackages(t *testing.T) { svc, _, mlsValidationService, cleanup := newTestService(t, ctx) defer cleanup() - installationId1 := test.RandomString(32) + installationId1 := test.RandomBytes(32) walletAddress1 := test.RandomString(32) mockCall := mlsValidationService.mockValidateKeyPackages(installationId1, walletAddress1) @@ -184,7 +185,7 @@ func TestConsumeKeyPackages(t *testing.T) { require.NotNil(t, res) // Add a second key package - installationId2 := test.RandomString(32) + installationId2 := test.RandomBytes(32) walletAddress2 := test.RandomString(32) // Unset the original mock so we can set a new one mockCall.Unset() @@ -199,7 +200,7 @@ func TestConsumeKeyPackages(t *testing.T) { require.NotNil(t, res) consumeRes, err := svc.ConsumeKeyPackages(ctx, &proto.ConsumeKeyPackagesRequest{ - InstallationIds: []string{installationId1, installationId2}, + InstallationIds: [][]byte{installationId1, installationId2}, }) require.NoError(t, err) require.NotNil(t, consumeRes) @@ -209,7 +210,7 @@ func TestConsumeKeyPackages(t *testing.T) { // Now do it with the installationIds reversed consumeRes, err = svc.ConsumeKeyPackages(ctx, &proto.ConsumeKeyPackagesRequest{ - InstallationIds: []string{installationId2, installationId1}, + InstallationIds: [][]byte{installationId2, installationId1}, }) require.NoError(t, err) @@ -226,7 +227,7 @@ func TestConsumeKeyPackagesFail(t *testing.T) { defer cleanup() consumeRes, err := svc.ConsumeKeyPackages(ctx, &proto.ConsumeKeyPackagesRequest{ - InstallationIds: []string{test.RandomString(32)}, + InstallationIds: [][]byte{test.RandomBytes(32)}, }) require.Error(t, err) require.Nil(t, consumeRes) @@ -266,7 +267,7 @@ func TestGetIdentityUpdates(t *testing.T) { svc, _, mlsValidationService, cleanup := newTestService(t, ctx) defer cleanup() - installationId := test.RandomString(32) + installationId := test.RandomBytes(32) walletAddress := test.RandomString(32) mockCall := mlsValidationService.mockValidateKeyPackages(installationId, walletAddress) @@ -292,7 +293,7 @@ func TestGetIdentityUpdates(t *testing.T) { } mockCall.Unset() - mlsValidationService.mockValidateKeyPackages(test.RandomString(32), walletAddress) + mlsValidationService.mockValidateKeyPackages(test.RandomBytes(32), walletAddress) _, err = svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ LastResortKeyPackage: &proto.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test"), diff --git a/pkg/migrations/mls/20231023050806_init-schema.up.sql b/pkg/migrations/mls/20231023050806_init-schema.up.sql index 7b2c54ca..e77152e9 100644 --- a/pkg/migrations/mls/20231023050806_init-schema.up.sql +++ b/pkg/migrations/mls/20231023050806_init-schema.up.sql @@ -3,16 +3,17 @@ SET --bun:split CREATE TABLE installations ( - id TEXT PRIMARY KEY, + id BYTEA PRIMARY KEY, wallet_address TEXT NOT NULL, created_at BIGINT NOT NULL, + credential_identity BYTEA NOT NULL, revoked_at BIGINT ); --bun:split CREATE TABLE key_packages ( id TEXT PRIMARY KEY, - installation_id TEXT NOT NULL, + installation_id BYTEA NOT NULL, created_at BIGINT NOT NULL, consumed_at BIGINT, not_consumed BOOLEAN DEFAULT TRUE NOT NULL, diff --git a/pkg/mlsstore/models.go b/pkg/mlsstore/models.go index 533c595d..216d0e8f 100644 --- a/pkg/mlsstore/models.go +++ b/pkg/mlsstore/models.go @@ -2,23 +2,38 @@ package mlsstore import "github.com/uptrace/bun" +type InstallationId []byte + +func (id InstallationId) String() string { + return string(id) +} + type Installation struct { bun.BaseModel `bun:"table:installations"` - ID string `bun:",pk"` - WalletAddress string `bun:"wallet_address,notnull"` - CreatedAt int64 `bun:"created_at,notnull"` - RevokedAt *int64 `bun:"revoked_at"` + ID InstallationId `bun:",pk,type:bytea"` + WalletAddress string `bun:"wallet_address,notnull"` + CreatedAt int64 `bun:"created_at,notnull"` + RevokedAt *int64 `bun:"revoked_at"` + CredentialIdentity []byte `bun:"credential_identity,notnull,type:bytea"` } type KeyPackage struct { bun.BaseModel `bun:"table:key_packages"` - ID string `bun:",pk"` // ID is the hash of the data field - InstallationId string `bun:"installation_id,notnull"` - CreatedAt int64 `bun:"created_at,notnull"` - ConsumedAt *int64 `bun:"consumed_at"` - NotConsumed bool `bun:"not_consumed,default:true"` - IsLastResort bool `bun:"is_last_resort,notnull"` - Data []byte `bun:"data,notnull,type:bytea"` + ID string `bun:",pk"` // ID is the hash of the data field + InstallationId InstallationId `bun:"installation_id,notnull,type:bytea"` + CreatedAt int64 `bun:"created_at,notnull"` + ConsumedAt *int64 `bun:"consumed_at"` + NotConsumed bool `bun:"not_consumed,default:true"` + IsLastResort bool `bun:"is_last_resort,notnull"` + Data []byte `bun:"data,notnull,type:bytea"` +} + +func InstallationIdArray(data [][]byte) []InstallationId { + result := make([]InstallationId, len(data)) + for i, d := range data { + result[i] = InstallationId(d) + } + return result } diff --git a/pkg/mlsstore/store.go b/pkg/mlsstore/store.go index 7f05c5e5..420bc4b3 100644 --- a/pkg/mlsstore/store.go +++ b/pkg/mlsstore/store.go @@ -22,9 +22,9 @@ type Store struct { } type MlsStore interface { - CreateInstallation(ctx context.Context, installationId string, walletAddress string, lastResortKeyPackage []byte) error + CreateInstallation(ctx context.Context, installationId InstallationId, walletAddress string, lastResortKeyPackage []byte, credentialIdentity []byte) error InsertKeyPackages(ctx context.Context, keyPackages []*KeyPackage) error - ConsumeKeyPackages(ctx context.Context, installationIds []string) ([]*KeyPackage, error) + ConsumeKeyPackages(ctx context.Context, installationIds []InstallationId) ([]*KeyPackage, error) GetIdentityUpdates(ctx context.Context, walletAddresses []string, startTimeNs int64) (map[string]IdentityUpdateList, error) } @@ -43,13 +43,14 @@ func New(ctx context.Context, config Config) (*Store, error) { } // Creates the installation and last resort key package -func (s *Store) CreateInstallation(ctx context.Context, installationId string, walletAddress string, lastResortKeyPackage []byte) error { +func (s *Store) CreateInstallation(ctx context.Context, installationId InstallationId, walletAddress string, lastResortKeyPackage []byte, credentialIdentity []byte) error { createdAt := nowNs() installation := Installation{ - ID: installationId, - WalletAddress: walletAddress, - CreatedAt: createdAt, + ID: installationId, + WalletAddress: walletAddress, + CreatedAt: createdAt, + CredentialIdentity: credentialIdentity, } keyPackage := NewKeyPackage(installationId, lastResortKeyPackage, true) @@ -83,7 +84,7 @@ func (s *Store) InsertKeyPackages(ctx context.Context, keyPackages []*KeyPackage return err } -func (s *Store) ConsumeKeyPackages(ctx context.Context, installationIds []string) ([]*KeyPackage, error) { +func (s *Store) ConsumeKeyPackages(ctx context.Context, installationIds []InstallationId) ([]*KeyPackage, error) { keyPackages := make([]*KeyPackage, 0) err := s.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { err := tx.NewRaw(` @@ -163,7 +164,7 @@ func (s *Store) GetIdentityUpdates(ctx context.Context, walletAddresses []string return out, nil } -func (s *Store) RevokeInstallation(ctx context.Context, installationId string) error { +func (s *Store) RevokeInstallation(ctx context.Context, installationId InstallationId) error { _, err := s.db.NewUpdate(). Model(&Installation{}). Set("revoked_at = ?", nowNs()). @@ -174,7 +175,7 @@ func (s *Store) RevokeInstallation(ctx context.Context, installationId string) e return err } -func NewKeyPackage(installationId string, data []byte, isLastResort bool) *KeyPackage { +func NewKeyPackage(installationId InstallationId, data []byte, isLastResort bool) *KeyPackage { return &KeyPackage{ ID: buildKeyPackageId(data), InstallationId: installationId, @@ -229,9 +230,10 @@ const ( ) type IdentityUpdate struct { - Kind IdentityUpdateKind - InstallationId string - TimestampNs uint64 + Kind IdentityUpdateKind + InstallationId []byte + CredentialIdentity []byte + TimestampNs uint64 } // Add the required methods to make a valid sort.Sort interface diff --git a/pkg/mlsstore/store_test.go b/pkg/mlsstore/store_test.go index a1fa3eec..c610659f 100644 --- a/pkg/mlsstore/store_test.go +++ b/pkg/mlsstore/store_test.go @@ -29,10 +29,10 @@ func TestCreateInstallation(t *testing.T) { defer cleanup() ctx := context.Background() - installationId := test.RandomString(32) + installationId := test.RandomBytes(32) walletAddress := test.RandomString(32) - err := store.CreateInstallation(ctx, installationId, walletAddress, test.RandomBytes(32)) + err := store.CreateInstallation(ctx, installationId, walletAddress, test.RandomBytes(32), test.RandomBytes(32)) require.NoError(t, err) installationFromDb := &Installation{} @@ -49,13 +49,13 @@ func TestCreateInstallationIdempotent(t *testing.T) { defer cleanup() ctx := context.Background() - installationId := test.RandomString(32) + installationId := test.RandomBytes(32) walletAddress := test.RandomString(32) keyPackage := test.RandomBytes(32) - err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage) + err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage) require.NoError(t, err) - err = store.CreateInstallation(ctx, installationId, walletAddress, test.RandomBytes(32)) + err = store.CreateInstallation(ctx, installationId, walletAddress, test.RandomBytes(32), test.RandomBytes(32)) require.NoError(t, err) keyPackageFromDb := &KeyPackage{} @@ -68,11 +68,11 @@ func TestInsertKeyPackages(t *testing.T) { defer cleanup() ctx := context.Background() - installationId := test.RandomString(32) + installationId := test.RandomBytes(32) walletAddress := test.RandomString(32) keyPackage := test.RandomBytes(32) - err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage) + err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage) require.NoError(t, err) keyPackage2 := test.RandomBytes(32) @@ -111,14 +111,14 @@ func TestConsumeLastResortKeyPackage(t *testing.T) { defer cleanup() ctx := context.Background() - installationId := test.RandomString(32) + installationId := test.RandomBytes(32) walletAddress := test.RandomString(32) keyPackage := test.RandomBytes(32) - err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage) + err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage) require.NoError(t, err) - consumeResult, err := store.ConsumeKeyPackages(ctx, []string{installationId}) + consumeResult, err := store.ConsumeKeyPackages(ctx, []InstallationId{installationId}) require.NoError(t, err) require.Len(t, consumeResult, 1) require.Equal(t, keyPackage, consumeResult[0].Data) @@ -130,11 +130,11 @@ func TestConsumeMultipleKeyPackages(t *testing.T) { defer cleanup() ctx := context.Background() - installationId := test.RandomString(32) + installationId := test.RandomBytes(32) walletAddress := test.RandomString(32) keyPackage := test.RandomBytes(32) - err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage) + err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage) require.NoError(t, err) keyPackage2 := test.RandomBytes(32) @@ -146,13 +146,13 @@ func TestConsumeMultipleKeyPackages(t *testing.T) { Data: keyPackage2, }})) - consumeResult, err := store.ConsumeKeyPackages(ctx, []string{installationId}) + consumeResult, err := store.ConsumeKeyPackages(ctx, []InstallationId{installationId}) require.NoError(t, err) require.Len(t, consumeResult, 1) require.Equal(t, keyPackage2, consumeResult[0].Data) require.Equal(t, installationId, consumeResult[0].InstallationId) - consumeResult, err = store.ConsumeKeyPackages(ctx, []string{installationId}) + consumeResult, err = store.ConsumeKeyPackages(ctx, []InstallationId{installationId}) require.NoError(t, err) require.Len(t, consumeResult, 1) // Now we are out of regular key packages. Expect to consume the last resort @@ -166,16 +166,16 @@ func TestGetIdentityUpdates(t *testing.T) { ctx := context.Background() walletAddress := test.RandomString(32) - installationId1 := test.RandomString(32) + installationId1 := test.RandomBytes(32) keyPackage1 := test.RandomBytes(32) - err := store.CreateInstallation(ctx, installationId1, walletAddress, keyPackage1) + err := store.CreateInstallation(ctx, installationId1, walletAddress, keyPackage1, keyPackage1) require.NoError(t, err) - installationId2 := test.RandomString(32) + installationId2 := test.RandomBytes(32) keyPackage2 := test.RandomBytes(32) - err = store.CreateInstallation(ctx, installationId2, walletAddress, keyPackage2) + err = store.CreateInstallation(ctx, installationId2, walletAddress, keyPackage2, keyPackage2) require.NoError(t, err) identityUpdates, err := store.GetIdentityUpdates(ctx, []string{walletAddress}, 0) @@ -197,17 +197,17 @@ func TestGetIdentityUpdatesMultipleWallets(t *testing.T) { ctx := context.Background() walletAddress1 := test.RandomString(32) - installationId1 := test.RandomString(32) + installationId1 := test.RandomBytes(32) keyPackage1 := test.RandomBytes(32) - err := store.CreateInstallation(ctx, installationId1, walletAddress1, keyPackage1) + err := store.CreateInstallation(ctx, installationId1, walletAddress1, keyPackage1, keyPackage1) require.NoError(t, err) walletAddress2 := test.RandomString(32) - installationId2 := test.RandomString(32) + installationId2 := test.RandomBytes(32) keyPackage2 := test.RandomBytes(32) - err = store.CreateInstallation(ctx, installationId2, walletAddress2, keyPackage2) + err = store.CreateInstallation(ctx, installationId2, walletAddress2, keyPackage2, keyPackage2) require.NoError(t, err) identityUpdates, err := store.GetIdentityUpdates(ctx, []string{walletAddress1, walletAddress2}, 0) diff --git a/pkg/mlsvalidate/service.go b/pkg/mlsvalidate/service.go index 36756ee2..686e533e 100644 --- a/pkg/mlsvalidate/service.go +++ b/pkg/mlsvalidate/service.go @@ -10,8 +10,9 @@ import ( ) type IdentityValidationResult struct { - WalletAddress string - InstallationId string + WalletAddress string + InstallationId []byte + CredentialIdentity []byte } type GroupMessageValidationResult struct { @@ -54,8 +55,9 @@ func (s *MLSValidationServiceImpl) ValidateKeyPackages(ctx context.Context, keyP return nil, fmt.Errorf("validation failed with error %s", response.ErrorMessage) } out[i] = IdentityValidationResult{ - WalletAddress: response.WalletAddress, - InstallationId: response.InstallationId, + WalletAddress: response.WalletAddress, + InstallationId: response.InstallationId, + CredentialIdentity: response.CredentialIdentityBytes, } } return out, nil diff --git a/pkg/mlsvalidate/service_test.go b/pkg/mlsvalidate/service_test.go index 2fb2cca3..9613eb98 100644 --- a/pkg/mlsvalidate/service_test.go +++ b/pkg/mlsvalidate/service_test.go @@ -41,10 +41,11 @@ func TestValidateKeyPackages(t *testing.T) { ctx := context.Background() firstResponse := svc.ValidateKeyPackagesResponse_ValidationResponse{ - IsOk: true, - WalletAddress: "0x123", - InstallationId: "123", - ErrorMessage: "", + IsOk: true, + WalletAddress: "0x123", + InstallationId: []byte("123"), + CredentialIdentityBytes: []byte("456"), + ErrorMessage: "", } mockGrpc.On("ValidateKeyPackages", ctx, mock.Anything).Return(&svc.ValidateKeyPackagesResponse{ @@ -55,7 +56,8 @@ func TestValidateKeyPackages(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, len(res)) assert.Equal(t, "0x123", res[0].WalletAddress) - assert.Equal(t, "123", res[0].InstallationId) + assert.Equal(t, []byte("123"), res[0].InstallationId) + assert.Equal(t, []byte("456"), res[0].CredentialIdentity) } func TestValidateKeyPackagesError(t *testing.T) { diff --git a/pkg/store/query_test.go b/pkg/store/query_test.go index 0a57e96e..d3c2c028 100644 --- a/pkg/store/query_test.go +++ b/pkg/store/query_test.go @@ -294,7 +294,7 @@ func TestMlsMessagePublish(t *testing.T) { contentTopic := "foo" ctx := context.Background() - env, err := store.InsertMlsMessage(ctx, contentTopic, message) + env, err := store.InsertMLSMessage(ctx, contentTopic, message) require.NoError(t, err) require.Equal(t, env.ContentTopic, contentTopic) diff --git a/pkg/store/store.go b/pkg/store/store.go index b697b7d1..4a309044 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -144,7 +144,7 @@ func (s *Store) InsertMessage(env *messagev1.Envelope) (bool, error) { return stored, err } -func (s *Store) InsertMlsMessage(ctx context.Context, contentTopic string, data []byte) (*messagev1.Envelope, error) { +func (s *Store) InsertMLSMessage(ctx context.Context, contentTopic string, data []byte) (*messagev1.Envelope, error) { tmpEnvelope := &messagev1.Envelope{ ContentTopic: contentTopic, Message: data, diff --git a/pkg/topic/topic.go b/pkg/topic/topic.go index 473f141b..f0192bbf 100644 --- a/pkg/topic/topic.go +++ b/pkg/topic/topic.go @@ -43,6 +43,6 @@ func BuildGroupTopic(groupId string) string { return fmt.Sprintf("/xmtp/3/g-%s/proto", groupId) } -func BuildWelcomeTopic(installationId string) string { - return fmt.Sprintf("/xmtp/3/w-%s/proto", installationId) +func BuildWelcomeTopic(installationId []byte) string { + return fmt.Sprintf("/xmtp/3/w-%x/proto", installationId) }