diff --git a/go.mod b/go.mod index a19af8f3..e8439dcf 100644 --- a/go.mod +++ b/go.mod @@ -30,7 +30,7 @@ require ( github.com/uptrace/bun/driver/pgdriver v1.1.16 github.com/waku-org/go-waku v0.8.0 github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3 - github.com/xmtp/proto/v3 v3.36.1-0.20231219054634-2ff03b7d5090 + github.com/xmtp/proto/v3 v3.36.1 github.com/yoheimuta/protolint v0.39.0 go.uber.org/zap v1.24.0 golang.org/x/sync v0.3.0 diff --git a/go.sum b/go.sum index f71cdb51..4000f59d 100644 --- a/go.sum +++ b/go.sum @@ -1146,8 +1146,8 @@ github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0 github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6/go.mod h1:ce1O1j6UtZfjr22oyGxGLbauSBp2YVXpARAosm7dHBg= github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3 h1:wzUffJGCTBGXIDyNU+1UBu1fn2Nzo+OQzM1pLrheh58= github.com/xmtp/go-msgio v0.2.1-0.20220510223757-25a701b79cd3/go.mod h1:bJREWk+NDnZYjgLQdAi8SUWuq/5pkMme4GqiffEhUF4= -github.com/xmtp/proto/v3 v3.36.1-0.20231219054634-2ff03b7d5090 h1:+0KTgQiUfu5UxgLjP18VL4BtG6hJMJYL0n1mVXtf3Ss= -github.com/xmtp/proto/v3 v3.36.1-0.20231219054634-2ff03b7d5090/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= +github.com/xmtp/proto/v3 v3.36.1 h1:eBUsWlA/jbfhxfbDtbAofBpi8Q+TIXPMz84e64o1XTE= +github.com/xmtp/proto/v3 v3.36.1/go.mod h1:NF2zAjtNpVIhS4tFG19g4L1tJcPZHm81oeDFXltmOiY= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= github.com/yoheimuta/go-protoparser/v4 v4.6.0 h1:uvz1e9/5Ihsm4Ku8AJeDImTpirKmIxubZdSn0QJNdnw= github.com/yoheimuta/go-protoparser/v4 v4.6.0/go.mod h1:AHNNnSWnb0UoL4QgHPiOAg2BniQceFscPI5X/BZNHl8= diff --git a/pkg/api/message/v3/service.go b/pkg/api/message/v3/service.go index 3c1a6154..dbdb33c7 100644 --- a/pkg/api/message/v3/service.go +++ b/pkg/api/message/v3/service.go @@ -45,7 +45,7 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterI return nil, err } - results, err := s.validationService.ValidateKeyPackages(ctx, [][]byte{req.LastResortKeyPackage.KeyPackageTlsSerialized}) + results, err := s.validationService.ValidateKeyPackages(ctx, [][]byte{req.KeyPackage.KeyPackageTlsSerialized}) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err) } @@ -57,7 +57,7 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterI accountAddress := results[0].AccountAddress credentialIdentity := results[0].CredentialIdentity - if err = s.mlsStore.CreateInstallation(ctx, installationId, accountAddress, req.LastResortKeyPackage.KeyPackageTlsSerialized, credentialIdentity); err != nil { + if err = s.mlsStore.CreateInstallation(ctx, installationId, accountAddress, credentialIdentity, req.KeyPackage.KeyPackageTlsSerialized, results[0].Expiration); err != nil { return nil, err } @@ -66,31 +66,31 @@ func (s *Service) RegisterInstallation(ctx context.Context, req *proto.RegisterI }, nil } -func (s *Service) ConsumeKeyPackages(ctx context.Context, req *proto.ConsumeKeyPackagesRequest) (*proto.ConsumeKeyPackagesResponse, error) { +func (s *Service) FetchKeyPackages(ctx context.Context, req *proto.FetchKeyPackagesRequest) (*proto.FetchKeyPackagesResponse, error) { ids := req.InstallationIds - keyPackages, err := s.mlsStore.ConsumeKeyPackages(ctx, ids) + installations, err := s.mlsStore.FetchKeyPackages(ctx, ids) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to consume key packages: %s", err) + return nil, status.Errorf(codes.Internal, "failed to fetch key packages: %s", err) } keyPackageMap := make(map[string]int) for idx, id := range ids { keyPackageMap[string(id)] = idx } - resPackages := make([]*proto.ConsumeKeyPackagesResponse_KeyPackage, len(keyPackages)) - for _, keyPackage := range keyPackages { + resPackages := make([]*proto.FetchKeyPackagesResponse_KeyPackage, len(ids)) + for _, installation := range installations { - idx, ok := keyPackageMap[string(keyPackage.InstallationId)] + idx, ok := keyPackageMap[string(installation.ID)] if !ok { return nil, status.Errorf(codes.Internal, "could not find key package for installation") } - resPackages[idx] = &proto.ConsumeKeyPackagesResponse_KeyPackage{ - KeyPackageTlsSerialized: keyPackage.Data, + resPackages[idx] = &proto.FetchKeyPackagesResponse_KeyPackage{ + KeyPackageTlsSerialized: installation.KeyPackage, } } - return &proto.ConsumeKeyPackagesResponse{ + return &proto.FetchKeyPackagesResponse{ KeyPackages: resPackages, }, nil } @@ -166,28 +166,22 @@ func (s *Service) PublishWelcomes(ctx context.Context, req *proto.PublishWelcome return &emptypb.Empty{}, nil } -func (s *Service) UploadKeyPackages(ctx context.Context, req *proto.UploadKeyPackagesRequest) (res *emptypb.Empty, err error) { - if err = validateUploadKeyPackagesRequest(req); err != nil { +func (s *Service) UploadKeyPackage(ctx context.Context, req *proto.UploadKeyPackageRequest) (res *emptypb.Empty, err error) { + if err = validateUploadKeyPackageRequest(req); err != nil { return nil, err } // Extract the key packages from the request - keyPackageBytes := make([][]byte, len(req.KeyPackages)) - for i, keyPackage := range req.KeyPackages { - keyPackageBytes[i] = keyPackage.KeyPackageTlsSerialized - } - validationResults, err := s.validationService.ValidateKeyPackages(ctx, keyPackageBytes) + keyPackageBytes := req.KeyPackage.KeyPackageTlsSerialized + + validationResults, err := s.validationService.ValidateKeyPackages(ctx, [][]byte{keyPackageBytes}) if err != nil { // TODO: Differentiate between validation errors and internal errors return nil, status.Errorf(codes.InvalidArgument, "invalid identity: %s", err) } + installationId := validationResults[0].InstallationId + expiration := validationResults[0].Expiration - keyPackageModels := make([]*mlsstore.KeyPackage, len(validationResults)) - for i, validationResult := range validationResults { - kp := mlsstore.NewKeyPackage(validationResult.InstallationId, keyPackageBytes[i], false) - keyPackageModels[i] = kp - } - - if err = s.mlsStore.InsertKeyPackages(ctx, keyPackageModels); err != nil { + if err = s.mlsStore.UpdateKeyPackage(ctx, installationId, keyPackageBytes, expiration); err != nil { return nil, status.Errorf(codes.Internal, "failed to insert key packages: %s", err) } @@ -275,15 +269,15 @@ func validatePublishWelcomesRequest(req *proto.PublishWelcomesRequest) error { } func validateRegisterInstallationRequest(req *proto.RegisterInstallationRequest) error { - if req == nil || req.LastResortKeyPackage == nil { - return status.Errorf(codes.InvalidArgument, "no last resort key package") + if req == nil || req.KeyPackage == nil { + return status.Errorf(codes.InvalidArgument, "no key package") } return nil } -func validateUploadKeyPackagesRequest(req *proto.UploadKeyPackagesRequest) error { - if req == nil || len(req.KeyPackages) == 0 { - return status.Errorf(codes.InvalidArgument, "no key packages to upload") +func validateUploadKeyPackageRequest(req *proto.UploadKeyPackageRequest) error { + if req == nil || req.KeyPackage == nil { + return status.Errorf(codes.InvalidArgument, "no key package") } return nil } diff --git a/pkg/api/message/v3/service_test.go b/pkg/api/message/v3/service_test.go index ec068399..1134af5d 100644 --- a/pkg/api/message/v3/service_test.go +++ b/pkg/api/message/v3/service_test.go @@ -49,6 +49,7 @@ func (m *mockedMLSValidationService) mockValidateKeyPackages(installationId []by InstallationId: installationId, AccountAddress: accountAddress, CredentialIdentity: []byte("test"), + Expiration: 0, }, }, nil) } @@ -102,7 +103,7 @@ func TestRegisterInstallation(t *testing.T) { mlsValidationService.mockValidateKeyPackages(installationId, accountAddress) res, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ - LastResortKeyPackage: &proto.KeyPackageUpload{ + KeyPackage: &proto.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test"), }, }) @@ -126,7 +127,7 @@ func TestRegisterInstallationError(t *testing.T) { mlsValidationService.On("ValidateKeyPackages", ctx, mock.Anything).Return(nil, errors.New("error validating")) res, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ - LastResortKeyPackage: &proto.KeyPackageUpload{ + KeyPackage: &proto.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test"), }, }) @@ -134,7 +135,7 @@ func TestRegisterInstallationError(t *testing.T) { require.Nil(t, res) } -func TestUploadKeyPackages(t *testing.T) { +func TestUploadKeyPackage(t *testing.T) { ctx := context.Background() svc, mlsDb, mlsValidationService, cleanup := newTestService(t, ctx) defer cleanup() @@ -145,28 +146,27 @@ func TestUploadKeyPackages(t *testing.T) { mlsValidationService.mockValidateKeyPackages(installationId, accountAddress) res, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ - LastResortKeyPackage: &proto.KeyPackageUpload{ + KeyPackage: &proto.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test"), }, }) require.NoError(t, err) require.NotNil(t, res) - uploadRes, err := svc.UploadKeyPackages(ctx, &proto.UploadKeyPackagesRequest{ - KeyPackages: []*proto.KeyPackageUpload{ - {KeyPackageTlsSerialized: []byte("test2")}, + uploadRes, err := svc.UploadKeyPackage(ctx, &proto.UploadKeyPackageRequest{ + KeyPackage: &proto.KeyPackageUpload{ + KeyPackageTlsSerialized: []byte("test2"), }, }) require.NoError(t, err) require.NotNil(t, uploadRes) - keyPackages := []mlsstore.KeyPackage{} - err = mlsDb.NewSelect().Model(&keyPackages).Where("installation_id = ?", installationId).Scan(ctx) + installation := &mlsstore.Installation{} + err = mlsDb.NewSelect().Model(installation).Where("id = ?", installationId).Scan(ctx) require.NoError(t, err) - require.Len(t, keyPackages, 2) } -func TestConsumeKeyPackages(t *testing.T) { +func TestFetchKeyPackages(t *testing.T) { ctx := context.Background() svc, _, mlsValidationService, cleanup := newTestService(t, ctx) defer cleanup() @@ -177,7 +177,7 @@ func TestConsumeKeyPackages(t *testing.T) { mockCall := mlsValidationService.mockValidateKeyPackages(installationId1, accountAddress1) res, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ - LastResortKeyPackage: &proto.KeyPackageUpload{ + KeyPackage: &proto.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test"), }, }) @@ -192,14 +192,14 @@ func TestConsumeKeyPackages(t *testing.T) { mlsValidationService.mockValidateKeyPackages(installationId2, accountAddress2) res, err = svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ - LastResortKeyPackage: &proto.KeyPackageUpload{ + KeyPackage: &proto.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test2"), }, }) require.NoError(t, err) require.NotNil(t, res) - consumeRes, err := svc.ConsumeKeyPackages(ctx, &proto.ConsumeKeyPackagesRequest{ + consumeRes, err := svc.FetchKeyPackages(ctx, &proto.FetchKeyPackagesRequest{ InstallationIds: [][]byte{installationId1, installationId2}, }) require.NoError(t, err) @@ -209,7 +209,7 @@ func TestConsumeKeyPackages(t *testing.T) { require.Equal(t, []byte("test2"), consumeRes.KeyPackages[1].KeyPackageTlsSerialized) // Now do it with the installationIds reversed - consumeRes, err = svc.ConsumeKeyPackages(ctx, &proto.ConsumeKeyPackagesRequest{ + consumeRes, err = svc.FetchKeyPackages(ctx, &proto.FetchKeyPackagesRequest{ InstallationIds: [][]byte{installationId2, installationId1}, }) @@ -220,17 +220,17 @@ func TestConsumeKeyPackages(t *testing.T) { require.Equal(t, []byte("test"), consumeRes.KeyPackages[1].KeyPackageTlsSerialized) } -// Trying to consume key packages that don't exist should fail -func TestConsumeKeyPackagesFail(t *testing.T) { +// Trying to fetch key packages that don't exist should return nil +func TestFetchKeyPackagesFail(t *testing.T) { ctx := context.Background() svc, _, _, cleanup := newTestService(t, ctx) defer cleanup() - consumeRes, err := svc.ConsumeKeyPackages(ctx, &proto.ConsumeKeyPackagesRequest{ + consumeRes, err := svc.FetchKeyPackages(ctx, &proto.FetchKeyPackagesRequest{ InstallationIds: [][]byte{test.RandomBytes(32)}, }) - require.Error(t, err) - require.Nil(t, consumeRes) + require.Nil(t, err) + require.Equal(t, []*proto.FetchKeyPackagesResponse_KeyPackage{nil}, consumeRes.KeyPackages) } func TestPublishToGroup(t *testing.T) { @@ -273,7 +273,7 @@ func TestGetIdentityUpdates(t *testing.T) { mockCall := mlsValidationService.mockValidateKeyPackages(installationId, accountAddress) _, err := svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ - LastResortKeyPackage: &proto.KeyPackageUpload{ + KeyPackage: &proto.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test"), }, }) @@ -297,7 +297,7 @@ func TestGetIdentityUpdates(t *testing.T) { mockCall.Unset() mlsValidationService.mockValidateKeyPackages(test.RandomBytes(32), accountAddress) _, err = svc.RegisterInstallation(ctx, &proto.RegisterInstallationRequest{ - LastResortKeyPackage: &proto.KeyPackageUpload{ + KeyPackage: &proto.KeyPackageUpload{ KeyPackageTlsSerialized: []byte("test"), }, }) diff --git a/pkg/authn/authn.pb.go b/pkg/authn/authn.pb.go index 5a3d5733..e7e88e3f 100644 --- a/pkg/authn/authn.pb.go +++ b/pkg/authn/authn.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.28.1 -// protoc v4.24.3 +// protoc v3.21.8 // source: authn.proto package authn @@ -28,6 +28,7 @@ type Signature struct { unknownFields protoimpl.UnknownFields // Types that are assignable to Union: + // // *Signature_EcdsaCompact Union isSignature_Union `protobuf_oneof:"union"` } @@ -146,6 +147,7 @@ type PublicKey struct { Timestamp uint64 `protobuf:"varint,1,opt,name=timestamp,proto3" json:"timestamp,omitempty"` Signature *Signature `protobuf:"bytes,2,opt,name=signature,proto3,oneof" json:"signature,omitempty"` // Types that are assignable to Union: + // // *PublicKey_Secp256K1Uncompressed Union isPublicKey_Union `protobuf_oneof:"union"` } @@ -297,6 +299,7 @@ type ClientAuthRequest struct { unknownFields protoimpl.UnknownFields // Types that are assignable to Version: + // // *ClientAuthRequest_V1 Version isClientAuthRequest_Version `protobuf_oneof:"version"` } @@ -418,6 +421,7 @@ type ClientAuthResponse struct { unknownFields protoimpl.UnknownFields // Types that are assignable to Version: + // // *ClientAuthResponse_V1 Version isClientAuthResponse_Version `protobuf_oneof:"version"` } diff --git a/pkg/migrations/mls/20231023050806_init-schema.down.sql b/pkg/migrations/mls/20231023050806_init-schema.down.sql index b1077270..bcbc86f0 100644 --- a/pkg/migrations/mls/20231023050806_init-schema.down.sql +++ b/pkg/migrations/mls/20231023050806_init-schema.down.sql @@ -3,6 +3,3 @@ SET --bun:split DROP TABLE IF EXISTS installations; - ---bun:split -DROP TABLE IF EXISTS key_packages; \ No newline at end of file diff --git a/pkg/migrations/mls/20231023050806_init-schema.up.sql b/pkg/migrations/mls/20231023050806_init-schema.up.sql index e77152e9..a0ebe97c 100644 --- a/pkg/migrations/mls/20231023050806_init-schema.up.sql +++ b/pkg/migrations/mls/20231023050806_init-schema.up.sql @@ -6,21 +6,12 @@ CREATE TABLE installations ( id BYTEA PRIMARY KEY, wallet_address TEXT NOT NULL, created_at BIGINT NOT NULL, + updated_at BIGINT NOT NULL, credential_identity BYTEA NOT NULL, - revoked_at BIGINT -); + revoked_at BIGINT, ---bun:split -CREATE TABLE key_packages ( - id TEXT PRIMARY KEY, - installation_id BYTEA NOT NULL, - created_at BIGINT NOT NULL, - consumed_at BIGINT, - not_consumed BOOLEAN DEFAULT TRUE NOT NULL, - is_last_resort BOOLEAN NOT NULL, - data BYTEA NOT NULL, - -- Add a foreign key constraint to ensure key packages cannot be added for unregistered installations - CONSTRAINT fk_installation_id FOREIGN KEY (installation_id) REFERENCES installations (id) + key_package BYTEA NOT NULL, + expiration BIGINT NOT NULL ); --bun:split @@ -31,15 +22,3 @@ CREATE INDEX idx_installations_created_at ON installations(created_at); --bun:split CREATE INDEX idx_installations_revoked_at ON installations(revoked_at); - ---bun:split --- Adding indexes for the key_packages table -CREATE INDEX idx_key_packages_installation_id_not_consumed_is_last_resort_created_at ON key_packages( - installation_id, - not_consumed, - is_last_resort, - created_at -); - ---bun:split -CREATE INDEX idx_key_packages_is_last_resort_id ON key_packages(is_last_resort, id); \ No newline at end of file diff --git a/pkg/mlsstore/models.go b/pkg/mlsstore/models.go index 226db843..cdfd0da0 100644 --- a/pkg/mlsstore/models.go +++ b/pkg/mlsstore/models.go @@ -8,18 +8,10 @@ type Installation struct { ID []byte `bun:",pk,type:bytea"` WalletAddress string `bun:"wallet_address,notnull"` CreatedAt int64 `bun:"created_at,notnull"` + UpdatedAt int64 `bun:"updated_at,notnull"` RevokedAt *int64 `bun:"revoked_at"` CredentialIdentity []byte `bun:"credential_identity,notnull,type:bytea"` -} - -type KeyPackage struct { - bun.BaseModel `bun:"table:key_packages"` - ID string `bun:",pk"` // ID is the hash of the data field - InstallationId []byte `bun:"installation_id,notnull,type:bytea"` - CreatedAt int64 `bun:"created_at,notnull"` - ConsumedAt *int64 `bun:"consumed_at"` - NotConsumed bool `bun:"not_consumed,default:true"` - IsLastResort bool `bun:"is_last_resort,notnull"` - Data []byte `bun:"data,notnull,type:bytea"` + KeyPackage []byte `bun:"key_package,notnull,type:bytea"` + Expiration uint64 `bun:"expiration,notnull"` } diff --git a/pkg/mlsstore/store.go b/pkg/mlsstore/store.go index 41e74b41..17f17bf2 100644 --- a/pkg/mlsstore/store.go +++ b/pkg/mlsstore/store.go @@ -2,9 +2,6 @@ package mlsstore import ( "context" - "crypto/sha256" - "database/sql" - "encoding/hex" "errors" "sort" "time" @@ -22,9 +19,9 @@ type Store struct { } type MlsStore interface { - CreateInstallation(ctx context.Context, installationId []byte, walletAddress string, lastResortKeyPackage []byte, credentialIdentity []byte) error - InsertKeyPackages(ctx context.Context, keyPackages []*KeyPackage) error - ConsumeKeyPackages(ctx context.Context, installationIds [][]byte) ([]*KeyPackage, error) + CreateInstallation(ctx context.Context, installationId []byte, walletAddress string, credentialIdentity, keyPackage []byte, expiration uint64) error + UpdateKeyPackage(ctx context.Context, installationId, keyPackage []byte, expiration uint64) error + FetchKeyPackages(ctx context.Context, installationIds [][]byte) ([]*Installation, error) GetIdentityUpdates(ctx context.Context, walletAddresses []string, startTimeNs int64) (map[string]IdentityUpdateList, error) } @@ -43,83 +40,67 @@ func New(ctx context.Context, config Config) (*Store, error) { } // Creates the installation and last resort key package -func (s *Store) CreateInstallation(ctx context.Context, installationId []byte, walletAddress string, lastResortKeyPackage []byte, credentialIdentity []byte) error { +func (s *Store) CreateInstallation(ctx context.Context, installationId []byte, walletAddress string, credentialIdentity, keyPackage []byte, expiration uint64) error { createdAt := nowNs() installation := Installation{ ID: installationId, WalletAddress: walletAddress, CreatedAt: createdAt, + UpdatedAt: createdAt, CredentialIdentity: credentialIdentity, - } - - keyPackage := NewKeyPackage(installationId, lastResortKeyPackage, true) - - return s.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { - _, err := tx.NewInsert(). - Model(&installation). - Ignore(). - Exec(ctx) - - if err != nil { - return err - } - _, err = tx.NewInsert(). - Model(keyPackage). - Ignore(). - Exec(ctx) - - if err != nil { - return err - } - - return nil - }) -} + KeyPackage: keyPackage, + Expiration: expiration, + } -// Insert a batch of key packages, ignoring any that may already exist -func (s *Store) InsertKeyPackages(ctx context.Context, keyPackages []*KeyPackage) error { - _, err := s.db.NewInsert().Model(&keyPackages).Ignore().Exec(ctx) + _, err := s.db.NewInsert(). + Model(&installation). + Ignore(). + Exec(ctx) return err } -func (s *Store) ConsumeKeyPackages(ctx context.Context, installationIds [][]byte) ([]*KeyPackage, error) { - keyPackages := make([]*KeyPackage, 0) - err := s.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { - err := tx.NewRaw(` - SELECT DISTINCT ON(installation_id) * FROM key_packages - WHERE "installation_id" IN (?) - AND not_consumed = TRUE - ORDER BY installation_id ASC, is_last_resort ASC, created_at ASC - `, - bun.In(installationIds)). - Scan(ctx, &keyPackages) - - if err != nil { - return err - } - - if len(keyPackages) < len(installationIds) { - return errors.New("key packages not found") - } +// Insert a new key package, ignoring any that may already exist +func (s *Store) UpdateKeyPackage(ctx context.Context, installationId, keyPackage []byte, expiration uint64) error { + installation := Installation{ + ID: installationId, + UpdatedAt: nowNs(), - _, err = tx.NewUpdate(). - Table("key_packages"). - Set("consumed_at = ?", nowNs()). - Set("not_consumed = FALSE"). - Where("is_last_resort = FALSE"). - Where("id IN (?)", bun.In(extractIds(keyPackages))). - Exec(ctx) + KeyPackage: keyPackage, + Expiration: expiration, + } + res, err := s.db.NewUpdate(). + Model(&installation). + OmitZero(). + WherePK(). + Exec(ctx) + if err != nil { return err - }) + } + rows, err := res.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return errors.New("installation id unknown") + } + return nil +} +func (s *Store) FetchKeyPackages(ctx context.Context, installationIds [][]byte) ([]*Installation, error) { + installations := make([]*Installation, 0) + + err := s.db.NewSelect(). + Model(&installations). + Where("ID IN (?)", bun.In(installationIds)). + Scan(ctx, &installations) if err != nil { return nil, err } - return keyPackages, nil + return installations, nil } func (s *Store) GetIdentityUpdates(ctx context.Context, walletAddresses []string, startTimeNs int64) (map[string]IdentityUpdateList, error) { @@ -176,25 +157,6 @@ func (s *Store) RevokeInstallation(ctx context.Context, installationId []byte) e return err } -func NewKeyPackage(installationId []byte, data []byte, isLastResort bool) *KeyPackage { - return &KeyPackage{ - ID: buildKeyPackageId(data), - InstallationId: installationId, - CreatedAt: nowNs(), - IsLastResort: isLastResort, - NotConsumed: true, - Data: data, - } -} - -func extractIds(keyPackages []*KeyPackage) []string { - out := make([]string, len(keyPackages)) - for i, keyPackage := range keyPackages { - out[i] = keyPackage.ID - } - return out -} - func (s *Store) migrate(ctx context.Context) error { migrator := migrate.NewMigrator(s.db, mlsMigrations.Migrations) err := migrator.Init(ctx) @@ -218,11 +180,6 @@ func nowNs() int64 { return time.Now().UTC().UnixNano() } -func buildKeyPackageId(keyPackageData []byte) string { - digest := sha256.Sum256(keyPackageData) - return hex.EncodeToString(digest[:]) -} - type IdentityUpdateKind int const ( diff --git a/pkg/mlsstore/store_test.go b/pkg/mlsstore/store_test.go index 50673cec..95c21fb9 100644 --- a/pkg/mlsstore/store_test.go +++ b/pkg/mlsstore/store_test.go @@ -32,19 +32,15 @@ func TestCreateInstallation(t *testing.T) { installationId := test.RandomBytes(32) walletAddress := test.RandomString(32) - err := store.CreateInstallation(ctx, installationId, walletAddress, test.RandomBytes(32), test.RandomBytes(32)) + err := store.CreateInstallation(ctx, installationId, walletAddress, test.RandomBytes(32), test.RandomBytes(32), 0) require.NoError(t, err) installationFromDb := &Installation{} require.NoError(t, store.db.NewSelect().Model(installationFromDb).Where("id = ?", installationId).Scan(ctx)) require.Equal(t, walletAddress, installationFromDb.WalletAddress) - - keyPackageFromDB := &KeyPackage{} - require.NoError(t, store.db.NewSelect().Model(keyPackageFromDB).Where("installation_id = ?", installationId).Scan(ctx)) - require.Equal(t, installationId, keyPackageFromDB.InstallationId) } -func TestCreateInstallationIdempotent(t *testing.T) { +func TestUpdateKeyPackage(t *testing.T) { store, cleanup := NewTestStore(t) defer cleanup() @@ -53,57 +49,18 @@ func TestCreateInstallationIdempotent(t *testing.T) { walletAddress := test.RandomString(32) keyPackage := test.RandomBytes(32) - err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage) - require.NoError(t, err) - err = store.CreateInstallation(ctx, installationId, walletAddress, test.RandomBytes(32), test.RandomBytes(32)) - require.NoError(t, err) - - keyPackageFromDb := &KeyPackage{} - require.NoError(t, store.db.NewSelect().Model(keyPackageFromDb).Where("installation_id = ?", installationId).Scan(ctx)) - require.Equal(t, keyPackage, keyPackageFromDb.Data) -} - -func TestInsertKeyPackages(t *testing.T) { - store, cleanup := NewTestStore(t) - defer cleanup() - - ctx := context.Background() - installationId := test.RandomBytes(32) - walletAddress := test.RandomString(32) - keyPackage := test.RandomBytes(32) - - err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage) + err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage, 0) require.NoError(t, err) keyPackage2 := test.RandomBytes(32) - err = store.InsertKeyPackages(ctx, []*KeyPackage{{ - ID: buildKeyPackageId(keyPackage2), - InstallationId: installationId, - CreatedAt: nowNs(), - IsLastResort: false, - Data: keyPackage2, - }}) + err = store.UpdateKeyPackage(ctx, installationId, keyPackage2, 1) require.NoError(t, err) - keyPackagesFromDb := []*KeyPackage{} - require.NoError(t, store.db.NewSelect().Model(&keyPackagesFromDb).Where("installation_id = ?", installationId).Scan(ctx)) - require.Len(t, keyPackagesFromDb, 2) - - hasLastResort := false - hasRegular := false - for _, keyPackageFromDb := range keyPackagesFromDb { - require.Equal(t, installationId, keyPackageFromDb.InstallationId) - if keyPackageFromDb.IsLastResort { - hasLastResort = true - } - if !keyPackageFromDb.IsLastResort { - hasRegular = true - require.Equal(t, keyPackage2, keyPackageFromDb.Data) - } - } + installationFromDb := &Installation{} + require.NoError(t, store.db.NewSelect().Model(installationFromDb).Where("id = ?", installationId).Scan(ctx)) - require.True(t, hasLastResort) - require.True(t, hasRegular) + require.Equal(t, keyPackage2, installationFromDb.KeyPackage) + require.Equal(t, uint64(1), installationFromDb.Expiration) } func TestConsumeLastResortKeyPackage(t *testing.T) { @@ -115,48 +72,14 @@ func TestConsumeLastResortKeyPackage(t *testing.T) { walletAddress := test.RandomString(32) keyPackage := test.RandomBytes(32) - err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage) - require.NoError(t, err) - - consumeResult, err := store.ConsumeKeyPackages(ctx, [][]byte{installationId}) - require.NoError(t, err) - require.Len(t, consumeResult, 1) - require.Equal(t, keyPackage, consumeResult[0].Data) - require.Equal(t, installationId, consumeResult[0].InstallationId) -} - -func TestConsumeMultipleKeyPackages(t *testing.T) { - store, cleanup := NewTestStore(t) - defer cleanup() - - ctx := context.Background() - installationId := test.RandomBytes(32) - walletAddress := test.RandomString(32) - keyPackage := test.RandomBytes(32) - - err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage) - require.NoError(t, err) - - keyPackage2 := test.RandomBytes(32) - require.NoError(t, store.InsertKeyPackages(ctx, []*KeyPackage{{ - ID: buildKeyPackageId(keyPackage2), - InstallationId: installationId, - CreatedAt: nowNs(), - IsLastResort: false, - Data: keyPackage2, - }})) - - consumeResult, err := store.ConsumeKeyPackages(ctx, [][]byte{installationId}) + err := store.CreateInstallation(ctx, installationId, walletAddress, keyPackage, keyPackage, 0) require.NoError(t, err) - require.Len(t, consumeResult, 1) - require.Equal(t, keyPackage2, consumeResult[0].Data) - require.Equal(t, installationId, consumeResult[0].InstallationId) - consumeResult, err = store.ConsumeKeyPackages(ctx, [][]byte{installationId}) + fetchResult, err := store.FetchKeyPackages(ctx, [][]byte{installationId}) require.NoError(t, err) - require.Len(t, consumeResult, 1) - // Now we are out of regular key packages. Expect to consume the last resort - require.Equal(t, keyPackage, consumeResult[0].Data) + require.Len(t, fetchResult, 1) + require.Equal(t, keyPackage, fetchResult[0].KeyPackage) + require.Equal(t, installationId, fetchResult[0].ID) } func TestGetIdentityUpdates(t *testing.T) { @@ -169,13 +92,13 @@ func TestGetIdentityUpdates(t *testing.T) { installationId1 := test.RandomBytes(32) keyPackage1 := test.RandomBytes(32) - err := store.CreateInstallation(ctx, installationId1, walletAddress, keyPackage1, keyPackage1) + err := store.CreateInstallation(ctx, installationId1, walletAddress, keyPackage1, keyPackage1, 0) require.NoError(t, err) installationId2 := test.RandomBytes(32) keyPackage2 := test.RandomBytes(32) - err = store.CreateInstallation(ctx, installationId2, walletAddress, keyPackage2, keyPackage2) + err = store.CreateInstallation(ctx, installationId2, walletAddress, keyPackage2, keyPackage2, 0) require.NoError(t, err) identityUpdates, err := store.GetIdentityUpdates(ctx, []string{walletAddress}, 0) @@ -200,14 +123,14 @@ func TestGetIdentityUpdatesMultipleWallets(t *testing.T) { installationId1 := test.RandomBytes(32) keyPackage1 := test.RandomBytes(32) - err := store.CreateInstallation(ctx, installationId1, walletAddress1, keyPackage1, keyPackage1) + err := store.CreateInstallation(ctx, installationId1, walletAddress1, keyPackage1, keyPackage1, 0) require.NoError(t, err) walletAddress2 := test.RandomString(32) installationId2 := test.RandomBytes(32) keyPackage2 := test.RandomBytes(32) - err = store.CreateInstallation(ctx, installationId2, walletAddress2, keyPackage2, keyPackage2) + err = store.CreateInstallation(ctx, installationId2, walletAddress2, keyPackage2, keyPackage2, 0) require.NoError(t, err) identityUpdates, err := store.GetIdentityUpdates(ctx, []string{walletAddress1, walletAddress2}, 0) diff --git a/pkg/mlsvalidate/service.go b/pkg/mlsvalidate/service.go index bdf0679b..0391abb1 100644 --- a/pkg/mlsvalidate/service.go +++ b/pkg/mlsvalidate/service.go @@ -13,6 +13,7 @@ type IdentityValidationResult struct { AccountAddress string InstallationId []byte CredentialIdentity []byte + Expiration uint64 } type GroupMessageValidationResult struct { @@ -58,6 +59,7 @@ func (s *MLSValidationServiceImpl) ValidateKeyPackages(ctx context.Context, keyP AccountAddress: response.AccountAddress, InstallationId: response.InstallationId, CredentialIdentity: response.CredentialIdentityBytes, + Expiration: response.Expiration, } } return out, nil