diff --git a/pkg/api/message/v3/service.go b/pkg/api/message/v3/service.go index bb3945fb..8e2e7ffc 100644 --- a/pkg/api/message/v3/service.go +++ b/pkg/api/message/v3/service.go @@ -67,20 +67,20 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterI } func (s *Service) ConsumeKeyPackages(ctx context.Context, req *proto.ConsumeKeyPackagesRequest) (*proto.ConsumeKeyPackagesResponse, error) { - ids := mlsstore.InstallationIdArray(req.InstallationIds) + ids := req.InstallationIds keyPackages, err := s.mlsStore.ConsumeKeyPackages(ctx, ids) if err != nil { return nil, status.Errorf(codes.Internal, "failed to consume key packages: %s", err) } keyPackageMap := make(map[string]int) for idx, id := range ids { - keyPackageMap[id.String()] = idx + keyPackageMap[string(id)] = idx } resPackages := make([]*proto.ConsumeKeyPackagesResponse_KeyPackage, len(keyPackages)) for _, keyPackage := range keyPackages { - idx, ok := keyPackageMap[keyPackage.InstallationId.String()] + idx, ok := keyPackageMap[string(keyPackage.InstallationId)] if !ok { return nil, status.Errorf(codes.Internal, "could not find key package for installation") } diff --git a/pkg/api/message/v3/service_test.go b/pkg/api/message/v3/service_test.go index c353f617..cc8bf02e 100644 --- a/pkg/api/message/v3/service_test.go +++ b/pkg/api/message/v3/service_test.go @@ -43,7 +43,7 @@ func newMockedValidationService() *mockedMLSValidationService { return new(mockedMLSValidationService) } -func (m *mockedMLSValidationService) mockValidateKeyPackages(installationId mlsstore.InstallationId, walletAddress string) *mock.Call { +func (m *mockedMLSValidationService) mockValidateKeyPackages(installationId []byte, walletAddress string) *mock.Call { return m.On("ValidateKeyPackages", mock.Anything, mock.Anything).Return([]mlsvalidate.IdentityValidationResult{ { InstallationId: installationId, diff --git a/pkg/mlsstore/models.go b/pkg/mlsstore/models.go index 216d0e8f..226db843 100644 --- a/pkg/mlsstore/models.go +++ b/pkg/mlsstore/models.go @@ -2,38 +2,24 @@ package mlsstore import "github.com/uptrace/bun" -type InstallationId []byte - -func (id InstallationId) String() string { - return string(id) -} - type Installation struct { bun.BaseModel `bun:"table:installations"` - ID InstallationId `bun:",pk,type:bytea"` - WalletAddress string `bun:"wallet_address,notnull"` - CreatedAt int64 `bun:"created_at,notnull"` - RevokedAt *int64 `bun:"revoked_at"` - CredentialIdentity []byte `bun:"credential_identity,notnull,type:bytea"` + ID []byte `bun:",pk,type:bytea"` + WalletAddress string `bun:"wallet_address,notnull"` + CreatedAt int64 `bun:"created_at,notnull"` + RevokedAt *int64 `bun:"revoked_at"` + CredentialIdentity []byte `bun:"credential_identity,notnull,type:bytea"` } type KeyPackage struct { bun.BaseModel `bun:"table:key_packages"` - ID string `bun:",pk"` // ID is the hash of the data field - InstallationId InstallationId `bun:"installation_id,notnull,type:bytea"` - CreatedAt int64 `bun:"created_at,notnull"` - ConsumedAt *int64 `bun:"consumed_at"` - NotConsumed bool `bun:"not_consumed,default:true"` - IsLastResort bool `bun:"is_last_resort,notnull"` - Data []byte `bun:"data,notnull,type:bytea"` -} - -func InstallationIdArray(data [][]byte) []InstallationId { - result := make([]InstallationId, len(data)) - for i, d := range data { - result[i] = InstallationId(d) - } - return result + ID string `bun:",pk"` // ID is the hash of the data field + InstallationId []byte `bun:"installation_id,notnull,type:bytea"` + CreatedAt int64 `bun:"created_at,notnull"` + ConsumedAt *int64 `bun:"consumed_at"` + NotConsumed bool `bun:"not_consumed,default:true"` + IsLastResort bool `bun:"is_last_resort,notnull"` + Data []byte `bun:"data,notnull,type:bytea"` } diff --git a/pkg/mlsstore/store.go b/pkg/mlsstore/store.go index 420bc4b3..b0f83523 100644 --- a/pkg/mlsstore/store.go +++ b/pkg/mlsstore/store.go @@ -22,9 +22,9 @@ type Store struct { } type MlsStore interface { - CreateInstallation(ctx context.Context, installationId InstallationId, walletAddress string, lastResortKeyPackage []byte, credentialIdentity []byte) error + CreateInstallation(ctx context.Context, installationId []byte, walletAddress string, lastResortKeyPackage []byte, credentialIdentity []byte) error InsertKeyPackages(ctx context.Context, keyPackages []*KeyPackage) error - ConsumeKeyPackages(ctx context.Context, installationIds []InstallationId) ([]*KeyPackage, error) + ConsumeKeyPackages(ctx context.Context, installationIds [][]byte) ([]*KeyPackage, error) GetIdentityUpdates(ctx context.Context, walletAddresses []string, startTimeNs int64) (map[string]IdentityUpdateList, error) } @@ -43,7 +43,7 @@ func New(ctx context.Context, config Config) (*Store, error) { } // Creates the installation and last resort key package -func (s *Store) CreateInstallation(ctx context.Context, installationId InstallationId, walletAddress string, lastResortKeyPackage []byte, credentialIdentity []byte) error { +func (s *Store) CreateInstallation(ctx context.Context, installationId []byte, walletAddress string, lastResortKeyPackage []byte, credentialIdentity []byte) error { createdAt := nowNs() installation := Installation{ @@ -84,7 +84,7 @@ func (s *Store) InsertKeyPackages(ctx context.Context, keyPackages []*KeyPackage return err } -func (s *Store) ConsumeKeyPackages(ctx context.Context, installationIds []InstallationId) ([]*KeyPackage, error) { +func (s *Store) ConsumeKeyPackages(ctx context.Context, installationIds [][]byte) ([]*KeyPackage, error) { keyPackages := make([]*KeyPackage, 0) err := s.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { err := tx.NewRaw(` @@ -164,7 +164,7 @@ func (s *Store) GetIdentityUpdates(ctx context.Context, walletAddresses []string return out, nil } -func (s *Store) RevokeInstallation(ctx context.Context, installationId InstallationId) error { +func (s *Store) RevokeInstallation(ctx context.Context, installationId []byte) error { _, err := s.db.NewUpdate(). Model(&Installation{}). Set("revoked_at = ?", nowNs()). @@ -175,7 +175,7 @@ func (s *Store) RevokeInstallation(ctx context.Context, installationId Installat return err } -func NewKeyPackage(installationId InstallationId, data []byte, isLastResort bool) *KeyPackage { +func NewKeyPackage(installationId []byte, data []byte, isLastResort bool) *KeyPackage { return &KeyPackage{ ID: buildKeyPackageId(data), InstallationId: installationId, diff --git a/pkg/mlsstore/store_test.go b/pkg/mlsstore/store_test.go index c610659f..50673cec 100644 --- a/pkg/mlsstore/store_test.go +++ b/pkg/mlsstore/store_test.go @@ -118,7 +118,7 @@ func TestConsumeLastResortKeyPackage(t *testing.T) { err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage) require.NoError(t, err) - consumeResult, err := store.ConsumeKeyPackages(ctx, []InstallationId{installationId}) + consumeResult, err := store.ConsumeKeyPackages(ctx, [][]byte{installationId}) require.NoError(t, err) require.Len(t, consumeResult, 1) require.Equal(t, keyPackage, consumeResult[0].Data) @@ -146,13 +146,13 @@ func TestConsumeMultipleKeyPackages(t *testing.T) { Data: keyPackage2, }})) - consumeResult, err := store.ConsumeKeyPackages(ctx, []InstallationId{installationId}) + consumeResult, err := store.ConsumeKeyPackages(ctx, [][]byte{installationId}) require.NoError(t, err) require.Len(t, consumeResult, 1) require.Equal(t, keyPackage2, consumeResult[0].Data) require.Equal(t, installationId, consumeResult[0].InstallationId) - consumeResult, err = store.ConsumeKeyPackages(ctx, []InstallationId{installationId}) + consumeResult, err = store.ConsumeKeyPackages(ctx, [][]byte{installationId}) require.NoError(t, err) require.Len(t, consumeResult, 1) // Now we are out of regular key packages. Expect to consume the last resort