Skip to content

Commit

Permalink
Remove custom type
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas committed Oct 26, 2023
1 parent 730065d commit e7a4f18
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 39 deletions.
6 changes: 3 additions & 3 deletions pkg/api/message/v3/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,20 +67,20 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterI
}

func (s *Service) ConsumeKeyPackages(ctx context.Context, req *proto.ConsumeKeyPackagesRequest) (*proto.ConsumeKeyPackagesResponse, error) {
ids := mlsstore.InstallationIdArray(req.InstallationIds)
ids := req.InstallationIds
keyPackages, err := s.mlsStore.ConsumeKeyPackages(ctx, ids)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to consume key packages: %s", err)
}
keyPackageMap := make(map[string]int)
for idx, id := range ids {
keyPackageMap[id.String()] = idx
keyPackageMap[string(id)] = idx
}

resPackages := make([]*proto.ConsumeKeyPackagesResponse_KeyPackage, len(keyPackages))
for _, keyPackage := range keyPackages {

idx, ok := keyPackageMap[keyPackage.InstallationId.String()]
idx, ok := keyPackageMap[string(keyPackage.InstallationId)]
if !ok {
return nil, status.Errorf(codes.Internal, "could not find key package for installation")
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/message/v3/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func newMockedValidationService() *mockedMLSValidationService {
return new(mockedMLSValidationService)
}

func (m *mockedMLSValidationService) mockValidateKeyPackages(installationId mlsstore.InstallationId, walletAddress string) *mock.Call {
func (m *mockedMLSValidationService) mockValidateKeyPackages(installationId []byte, walletAddress string) *mock.Call {
return m.On("ValidateKeyPackages", mock.Anything, mock.Anything).Return([]mlsvalidate.IdentityValidationResult{
{
InstallationId: installationId,
Expand Down
38 changes: 12 additions & 26 deletions pkg/mlsstore/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,24 @@ package mlsstore

import "github.com/uptrace/bun"

type InstallationId []byte

func (id InstallationId) String() string {
return string(id)
}

type Installation struct {
bun.BaseModel `bun:"table:installations"`

ID InstallationId `bun:",pk,type:bytea"`
WalletAddress string `bun:"wallet_address,notnull"`
CreatedAt int64 `bun:"created_at,notnull"`
RevokedAt *int64 `bun:"revoked_at"`
CredentialIdentity []byte `bun:"credential_identity,notnull,type:bytea"`
ID []byte `bun:",pk,type:bytea"`
WalletAddress string `bun:"wallet_address,notnull"`
CreatedAt int64 `bun:"created_at,notnull"`
RevokedAt *int64 `bun:"revoked_at"`
CredentialIdentity []byte `bun:"credential_identity,notnull,type:bytea"`
}

type KeyPackage struct {
bun.BaseModel `bun:"table:key_packages"`

ID string `bun:",pk"` // ID is the hash of the data field
InstallationId InstallationId `bun:"installation_id,notnull,type:bytea"`
CreatedAt int64 `bun:"created_at,notnull"`
ConsumedAt *int64 `bun:"consumed_at"`
NotConsumed bool `bun:"not_consumed,default:true"`
IsLastResort bool `bun:"is_last_resort,notnull"`
Data []byte `bun:"data,notnull,type:bytea"`
}

func InstallationIdArray(data [][]byte) []InstallationId {
result := make([]InstallationId, len(data))
for i, d := range data {
result[i] = InstallationId(d)
}
return result
ID string `bun:",pk"` // ID is the hash of the data field
InstallationId []byte `bun:"installation_id,notnull,type:bytea"`
CreatedAt int64 `bun:"created_at,notnull"`
ConsumedAt *int64 `bun:"consumed_at"`
NotConsumed bool `bun:"not_consumed,default:true"`
IsLastResort bool `bun:"is_last_resort,notnull"`
Data []byte `bun:"data,notnull,type:bytea"`
}
12 changes: 6 additions & 6 deletions pkg/mlsstore/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ type Store struct {
}

type MlsStore interface {
CreateInstallation(ctx context.Context, installationId InstallationId, walletAddress string, lastResortKeyPackage []byte, credentialIdentity []byte) error
CreateInstallation(ctx context.Context, installationId []byte, walletAddress string, lastResortKeyPackage []byte, credentialIdentity []byte) error
InsertKeyPackages(ctx context.Context, keyPackages []*KeyPackage) error
ConsumeKeyPackages(ctx context.Context, installationIds []InstallationId) ([]*KeyPackage, error)
ConsumeKeyPackages(ctx context.Context, installationIds [][]byte) ([]*KeyPackage, error)
GetIdentityUpdates(ctx context.Context, walletAddresses []string, startTimeNs int64) (map[string]IdentityUpdateList, error)
}

Expand All @@ -43,7 +43,7 @@ func New(ctx context.Context, config Config) (*Store, error) {
}

// Creates the installation and last resort key package
func (s *Store) CreateInstallation(ctx context.Context, installationId InstallationId, walletAddress string, lastResortKeyPackage []byte, credentialIdentity []byte) error {
func (s *Store) CreateInstallation(ctx context.Context, installationId []byte, walletAddress string, lastResortKeyPackage []byte, credentialIdentity []byte) error {
createdAt := nowNs()

installation := Installation{
Expand Down Expand Up @@ -84,7 +84,7 @@ func (s *Store) InsertKeyPackages(ctx context.Context, keyPackages []*KeyPackage
return err
}

func (s *Store) ConsumeKeyPackages(ctx context.Context, installationIds []InstallationId) ([]*KeyPackage, error) {
func (s *Store) ConsumeKeyPackages(ctx context.Context, installationIds [][]byte) ([]*KeyPackage, error) {
keyPackages := make([]*KeyPackage, 0)
err := s.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
err := tx.NewRaw(`
Expand Down Expand Up @@ -164,7 +164,7 @@ func (s *Store) GetIdentityUpdates(ctx context.Context, walletAddresses []string
return out, nil
}

func (s *Store) RevokeInstallation(ctx context.Context, installationId InstallationId) error {
func (s *Store) RevokeInstallation(ctx context.Context, installationId []byte) error {
_, err := s.db.NewUpdate().
Model(&Installation{}).
Set("revoked_at = ?", nowNs()).
Expand All @@ -175,7 +175,7 @@ func (s *Store) RevokeInstallation(ctx context.Context, installationId Installat
return err
}

func NewKeyPackage(installationId InstallationId, data []byte, isLastResort bool) *KeyPackage {
func NewKeyPackage(installationId []byte, data []byte, isLastResort bool) *KeyPackage {
return &KeyPackage{
ID: buildKeyPackageId(data),
InstallationId: installationId,
Expand Down
6 changes: 3 additions & 3 deletions pkg/mlsstore/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func TestConsumeLastResortKeyPackage(t *testing.T) {
err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage)
require.NoError(t, err)

consumeResult, err := store.ConsumeKeyPackages(ctx, []InstallationId{installationId})
consumeResult, err := store.ConsumeKeyPackages(ctx, [][]byte{installationId})
require.NoError(t, err)
require.Len(t, consumeResult, 1)
require.Equal(t, keyPackage, consumeResult[0].Data)
Expand Down Expand Up @@ -146,13 +146,13 @@ func TestConsumeMultipleKeyPackages(t *testing.T) {
Data: keyPackage2,
}}))

consumeResult, err := store.ConsumeKeyPackages(ctx, []InstallationId{installationId})
consumeResult, err := store.ConsumeKeyPackages(ctx, [][]byte{installationId})
require.NoError(t, err)
require.Len(t, consumeResult, 1)
require.Equal(t, keyPackage2, consumeResult[0].Data)
require.Equal(t, installationId, consumeResult[0].InstallationId)

consumeResult, err = store.ConsumeKeyPackages(ctx, []InstallationId{installationId})
consumeResult, err = store.ConsumeKeyPackages(ctx, [][]byte{installationId})
require.NoError(t, err)
require.Len(t, consumeResult, 1)
// Now we are out of regular key packages. Expect to consume the last resort
Expand Down

0 comments on commit e7a4f18

Please sign in to comment.