diff --git a/pkg/mls/store/queries.sql b/pkg/mls/store/queries.sql index 6436fd3b..57b04b4b 100644 --- a/pkg/mls/store/queries.sql +++ b/pkg/mls/store/queries.sql @@ -26,11 +26,11 @@ WHERE id = @id; -- name: FetchKeyPackages :many SELECT id, key_package FROM installations -WHERE ID IN (@ids); +WHERE id = ANY (sqlc.arg(installation_ids)::bytea[]); -- name: GetIdentityUpdates :many SELECT * FROM installations -WHERE wallet_address IN (@wallet_addresses) +WHERE wallet_address = ANY (@wallet_addresses::text[]) AND (created_at > @start_time OR revoked_at > @start_time) ORDER BY created_at ASC; diff --git a/pkg/mls/store/queries/filters.go b/pkg/mls/store/queries/filters.go new file mode 100644 index 00000000..38ed8dd6 --- /dev/null +++ b/pkg/mls/store/queries/filters.go @@ -0,0 +1,18 @@ +package queries + +import "encoding/json" + +type InboxLogFilter struct { + InboxId string `json:"inbox_id"` + SequenceId int64 `json:"sequence_id"` +} + +type InboxLogFilterList []InboxLogFilter + +func (f *InboxLogFilterList) ToSql() (json.RawMessage, error) { + jsonBytes, err := json.Marshal(f) + if err != nil { + return nil, err + } + return jsonBytes, nil +} diff --git a/pkg/mls/store/queries/queries.sql.go b/pkg/mls/store/queries/queries.sql.go index 337eebd8..c9556056 100644 --- a/pkg/mls/store/queries/queries.sql.go +++ b/pkg/mls/store/queries/queries.sql.go @@ -9,6 +9,8 @@ import ( "context" "database/sql" "encoding/json" + + "github.com/lib/pq" ) const createInstallation = `-- name: CreateInstallation :exec @@ -39,7 +41,7 @@ func (q *Queries) CreateInstallation(ctx context.Context, arg CreateInstallation const fetchKeyPackages = `-- name: FetchKeyPackages :many SELECT id, key_package FROM installations -WHERE ID IN ($1) +WHERE id = ANY ($1::bytea[]) ` type FetchKeyPackagesRow struct { @@ -47,8 +49,8 @@ type FetchKeyPackagesRow struct { KeyPackage []byte } -func (q *Queries) FetchKeyPackages(ctx context.Context, ids []byte) ([]FetchKeyPackagesRow, error) { - rows, err := q.db.QueryContext(ctx, fetchKeyPackages, ids) +func (q *Queries) FetchKeyPackages(ctx context.Context, installationIds [][]byte) ([]FetchKeyPackagesRow, error) { + rows, err := q.db.QueryContext(ctx, fetchKeyPackages, pq.Array(installationIds)) if err != nil { return nil, err } @@ -106,18 +108,18 @@ func (q *Queries) GetAllInboxLogs(ctx context.Context, inboxID string) ([]InboxL const getIdentityUpdates = `-- name: GetIdentityUpdates :many SELECT id, wallet_address, created_at, updated_at, credential_identity, revoked_at, key_package, expiration FROM installations -WHERE wallet_address IN ($1) +WHERE wallet_address = ANY ($1::text[]) AND (created_at > $2 OR revoked_at > $2) ORDER BY created_at ASC ` type GetIdentityUpdatesParams struct { - WalletAddresses string + WalletAddresses []string StartTime int64 } func (q *Queries) GetIdentityUpdates(ctx context.Context, arg GetIdentityUpdatesParams) ([]Installation, error) { - rows, err := q.db.QueryContext(ctx, getIdentityUpdates, arg.WalletAddresses, arg.StartTime) + rows, err := q.db.QueryContext(ctx, getIdentityUpdates, pq.Array(arg.WalletAddresses), arg.StartTime) if err != nil { return nil, err } diff --git a/pkg/mls/store/store.go b/pkg/mls/store/store.go index f26a42d3..9e5f6da5 100644 --- a/pkg/mls/store/store.go +++ b/pkg/mls/store/store.go @@ -12,6 +12,7 @@ import ( "github.com/uptrace/bun" "github.com/uptrace/bun/migrate" migrations "github.com/xmtp/xmtp-node-go/pkg/migrations/mls" + queries "github.com/xmtp/xmtp-node-go/pkg/mls/store/queries" identity "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1" "github.com/xmtp/xmtp-node-go/pkg/proto/identity/associations" mlsv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1" @@ -22,9 +23,10 @@ import ( const maxPageSize = 100 type Store struct { - config Config - log *zap.Logger - db *bun.DB + config Config + log *zap.Logger + db *bun.DB + queries *queries.Queries } type IdentityStore interface { @@ -37,7 +39,7 @@ type MlsStore interface { CreateInstallation(ctx context.Context, installationId []byte, walletAddress string, credentialIdentity, keyPackage []byte, expiration uint64) error UpdateKeyPackage(ctx context.Context, installationId, keyPackage []byte, expiration uint64) error - FetchKeyPackages(ctx context.Context, installationIds [][]byte) ([]*Installation, error) + FetchKeyPackages(ctx context.Context, installationIds [][]byte) ([]queries.FetchKeyPackagesRow, error) GetIdentityUpdates(ctx context.Context, walletAddresses []string, startTimeNs int64) (map[string]IdentityUpdateList, error) InsertGroupMessage(ctx context.Context, groupId []byte, data []byte) (*GroupMessage, error) InsertWelcomeMessage(ctx context.Context, installationId []byte, data []byte, hpkePublicKey []byte) (*WelcomeMessage, error) @@ -52,9 +54,10 @@ func New(ctx context.Context, config Config) (*Store, error) { } } s := &Store{ - log: config.Log.Named("mlsstore"), - db: config.DB, - config: config, + log: config.Log.Named("mlsstore"), + db: config.DB, + config: config, + queries: queries.New(config.DB.DB), } if err := s.migrate(ctx); err != nil { @@ -70,54 +73,42 @@ func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.Publish return nil, errors.New("IdentityUpdate is required") } - // TODO: Implement serializable isolation level once supported - if err := s.db.RunInTx(ctx, &sql.TxOptions{ /*Isolation: sql.LevelSerializable*/ }, func(ctx context.Context, tx bun.Tx) error { - inbox_log_entries := make([]*InboxLogEntry, 0) - - if err := s.db.NewSelect(). - Model(&inbox_log_entries). - Where("inbox_id = ?", new_update.GetInboxId()). - Order("sequence_id ASC"). - For("UPDATE"). - Scan(ctx); err != nil { + if err := s.RunInTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable}, func(ctx context.Context, txQueries *queries.Queries) error { + inboxLogEntries, err := txQueries.GetAllInboxLogs(ctx, new_update.GetInboxId()) + if err != nil { return err } - if len(inbox_log_entries) >= 256 { + if len(inboxLogEntries) >= 256 { return errors.New("inbox log is full") } - updates := make([]*associations.IdentityUpdate, 0, len(inbox_log_entries)+1) - for _, log := range inbox_log_entries { - identity_update := &associations.IdentityUpdate{} - if err := proto.Unmarshal(log.IdentityUpdateProto, identity_update); err != nil { + updates := make([]*associations.IdentityUpdate, 0, len(inboxLogEntries)+1) + for _, log := range inboxLogEntries { + identityUpdate := &associations.IdentityUpdate{} + if err := proto.Unmarshal(log.IdentityUpdateProto, identityUpdate); err != nil { return err } - updates = append(updates, identity_update) + updates = append(updates, identityUpdate) } _ = append(updates, new_update) // TODO: Validate the updates, and abort transaction if failed - proto_bytes, err := proto.Marshal(new_update) + protoBytes, err := proto.Marshal(new_update) if err != nil { return err } - new_entry := InboxLogEntry{ - InboxId: new_update.GetInboxId(), + _, err = txQueries.InsertInboxLog(ctx, queries.InsertInboxLogParams{ + InboxID: new_update.GetInboxId(), ServerTimestampNs: nowNs(), - IdentityUpdateProto: proto_bytes, - } + IdentityUpdateProto: protoBytes, + }) - _, err = s.db.NewInsert(). - Model(&new_entry). - Returning("sequence_id"). - Exec(ctx) if err != nil { return err } - // TODO: Insert or update the address_log table using sequence_id return nil @@ -130,36 +121,47 @@ func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.Publish func (s *Store) GetInboxLogs(ctx context.Context, batched_req *identity.GetIdentityUpdatesRequest) (*identity.GetIdentityUpdatesResponse, error) { reqs := batched_req.GetRequests() - resps := make([]*identity.GetIdentityUpdatesResponse_Response, len(reqs)) + filters := make(queries.InboxLogFilterList, len(reqs)) for i, req := range reqs { - inbox_log_entries := make([]*InboxLogEntry, 0) - - err := s.db.NewSelect(). - Model(&inbox_log_entries). - Where("sequence_id > ?", req.GetSequenceId()). - Where("inbox_id = ?", req.GetInboxId()). - Order("sequence_id ASC"). - Scan(ctx) - if err != nil { - return nil, err + filters[i] = queries.InboxLogFilter{ + InboxId: req.InboxId, + SequenceId: int64(req.SequenceId), } + } + filterBytes, err := filters.ToSql() + if err != nil { + return nil, err + } - updates := make([]*identity.GetIdentityUpdatesResponse_IdentityUpdateLog, len(inbox_log_entries)) - for j, entry := range inbox_log_entries { + results, err := s.queries.GetInboxLogFiltered(ctx, filterBytes) + if err != nil { + return nil, err + } + + // Organize the results by inbox ID + resultMap := make(map[string][]queries.InboxLog) + for _, result := range results { + resultMap[result.InboxID] = append(resultMap[result.InboxID], result) + } + + resps := make([]*identity.GetIdentityUpdatesResponse_Response, len(reqs)) + for i, req := range reqs { + logEntries := resultMap[req.InboxId] + updates := make([]*identity.GetIdentityUpdatesResponse_IdentityUpdateLog, len(logEntries)) + for j, entry := range logEntries { identity_update := &associations.IdentityUpdate{} if err := proto.Unmarshal(entry.IdentityUpdateProto, identity_update); err != nil { return nil, err } updates[j] = &identity.GetIdentityUpdatesResponse_IdentityUpdateLog{ - SequenceId: entry.SequenceId, + SequenceId: uint64(entry.SequenceID), ServerTimestampNs: uint64(entry.ServerTimestampNs), Update: identity_update, } } - resps[i] = &identity.GetIdentityUpdatesResponse_Response{ - InboxId: req.GetInboxId(), + InboxId: req.InboxId, Updates: updates, } } @@ -173,78 +175,45 @@ func (s *Store) GetInboxLogs(ctx context.Context, batched_req *identity.GetIdent func (s *Store) CreateInstallation(ctx context.Context, installationId []byte, walletAddress string, credentialIdentity, keyPackage []byte, expiration uint64) error { createdAt := nowNs() - installation := Installation{ + return s.queries.CreateInstallation(ctx, queries.CreateInstallationParams{ ID: installationId, WalletAddress: walletAddress, CreatedAt: createdAt, - UpdatedAt: createdAt, CredentialIdentity: credentialIdentity, - - KeyPackage: keyPackage, - Expiration: expiration, - } - - _, err := s.db.NewInsert(). - Model(&installation). - Ignore(). - Exec(ctx) - return err + KeyPackage: keyPackage, + Expiration: int64(expiration), + }) } // Insert a new key package, ignoring any that may already exist func (s *Store) UpdateKeyPackage(ctx context.Context, installationId, keyPackage []byte, expiration uint64) error { - installation := Installation{ - ID: installationId, - UpdatedAt: nowNs(), - + rowsUpdated, err := s.queries.UpdateKeyPackage(ctx, queries.UpdateKeyPackageParams{ + ID: installationId, + UpdatedAt: nowNs(), KeyPackage: keyPackage, - Expiration: expiration, - } + Expiration: int64(expiration), + }) - res, err := s.db.NewUpdate(). - Model(&installation). - OmitZero(). - WherePK(). - Exec(ctx) if err != nil { return err } - rows, err := res.RowsAffected() - if err != nil { - return err - } - if rows == 0 { + + if rowsUpdated == 0 { return errors.New("installation id unknown") } + return nil } -func (s *Store) FetchKeyPackages(ctx context.Context, installationIds [][]byte) ([]*Installation, error) { - installations := make([]*Installation, 0) - - err := s.db.NewSelect(). - Model(&installations). - Where("ID IN (?)", bun.In(installationIds)). - Scan(ctx, &installations) - if err != nil { - return nil, err - } - - return installations, nil +func (s *Store) FetchKeyPackages(ctx context.Context, installationIds [][]byte) ([]queries.FetchKeyPackagesRow, error) { + return s.queries.FetchKeyPackages(ctx, installationIds) } func (s *Store) GetIdentityUpdates(ctx context.Context, walletAddresses []string, startTimeNs int64) (map[string]IdentityUpdateList, error) { - updated := make([]*Installation, 0) - // Find all installations that were changed since the startTimeNs - err := s.db.NewSelect(). - Model(&updated). - Where("wallet_address IN (?)", bun.In(walletAddresses)). - WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery { - return q.Where("created_at > ?", startTimeNs).WhereOr("revoked_at > ?", startTimeNs) - }). - Order("created_at ASC"). - Scan(ctx) - + updated, err := s.queries.GetIdentityUpdates(ctx, queries.GetIdentityUpdatesParams{ + WalletAddresses: walletAddresses, + StartTime: startTimeNs, + }) if err != nil { return nil, err } @@ -260,11 +229,11 @@ func (s *Store) GetIdentityUpdates(ctx context.Context, walletAddresses []string TimestampNs: uint64(installation.CreatedAt), }) } - if installation.RevokedAt != nil && *installation.RevokedAt > startTimeNs { + if installation.RevokedAt.Valid && installation.RevokedAt.Int64 > startTimeNs { out[installation.WalletAddress] = append(out[installation.WalletAddress], IdentityUpdate{ Kind: Revoke, InstallationKey: installation.ID, - TimestampNs: uint64(*installation.RevokedAt), + TimestampNs: uint64(installation.RevokedAt.Int64), }) } } @@ -277,14 +246,10 @@ func (s *Store) GetIdentityUpdates(ctx context.Context, walletAddresses []string } func (s *Store) RevokeInstallation(ctx context.Context, installationId []byte) error { - _, err := s.db.NewUpdate(). - Model(&Installation{}). - Set("revoked_at = ?", nowNs()). - Where("id = ?", installationId). - Where("revoked_at IS NULL"). - Exec(ctx) - - return err + return s.queries.RevokeInstallation(ctx, queries.RevokeInstallationParams{ + RevokedAt: sql.NullInt64{Valid: true, Int64: nowNs()}, + InstallationID: installationId, + }) } func (s *Store) InsertGroupMessage(ctx context.Context, groupId []byte, data []byte) (*GroupMessage, error) { @@ -534,3 +499,27 @@ func IsAlreadyExistsError(err error) bool { _, ok := err.(*AlreadyExistsError) return ok } + +func (s *Store) RunInTx( + ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, txQueries *queries.Queries) error, +) error { + tx, err := s.db.DB.BeginTx(ctx, opts) + if err != nil { + return err + } + + var done bool + + defer func() { + if !done { + _ = tx.Rollback() + } + }() + + if err := fn(ctx, s.queries.WithTx(tx)); err != nil { + return err + } + + done = true + return tx.Commit() +}