From 1163957efbe11b5c2fa76be3d78a9165c6f28ddd Mon Sep 17 00:00:00 2001 From: Nicholas Molnar <65710+neekolas@users.noreply.github.com> Date: Thu, 25 Apr 2024 17:19:09 -0700 Subject: [PATCH] Continue sqlc migration --- pkg/metrics/mls.go | 6 +- pkg/mls/api/v1/service.go | 6 +- pkg/mls/api/v1/service_test.go | 11 +- pkg/mls/store/models.go | 58 ------ pkg/mls/store/queries.sql | 86 ++++++-- pkg/mls/store/queries/queries.sql.go | 285 +++++++++++++++++++++++---- pkg/mls/store/store.go | 187 ++++++++++-------- pkg/mls/store/store_test.go | 42 ++-- pkg/server/pgxdb.go | 3 +- 9 files changed, 454 insertions(+), 230 deletions(-) delete mode 100644 pkg/mls/store/models.go diff --git a/pkg/metrics/mls.go b/pkg/metrics/mls.go index 82acbbf2..fe27690a 100644 --- a/pkg/metrics/mls.go +++ b/pkg/metrics/mls.go @@ -4,7 +4,7 @@ import ( "context" "github.com/prometheus/client_golang/prometheus" - mlsstore "github.com/xmtp/xmtp-node-go/pkg/mls/store" + "github.com/xmtp/xmtp-node-go/pkg/mls/store/queries" "go.uber.org/zap" ) @@ -25,7 +25,7 @@ var mlsSentGroupMessageCount = prometheus.NewCounterVec( appClientVersionTagKeys, ) -func EmitMLSSentGroupMessage(ctx context.Context, log *zap.Logger, msg *mlsstore.GroupMessage) { +func EmitMLSSentGroupMessage(ctx context.Context, log *zap.Logger, msg *queries.GroupMessage) { labels := contextLabels(ctx) mlsSentGroupMessageSize.With(labels).Observe(float64(len(msg.Data))) mlsSentGroupMessageCount.With(labels).Inc() @@ -48,7 +48,7 @@ var mlsSentWelcomeMessageCount = prometheus.NewCounterVec( appClientVersionTagKeys, ) -func EmitMLSSentWelcomeMessage(ctx context.Context, log *zap.Logger, msg *mlsstore.WelcomeMessage) { +func EmitMLSSentWelcomeMessage(ctx context.Context, log *zap.Logger, msg *queries.WelcomeMessage) { labels := contextLabels(ctx) mlsSentWelcomeMessageSize.With(labels).Observe(float64(len(msg.Data))) mlsSentWelcomeMessageCount.With(labels).Inc() diff --git a/pkg/mls/api/v1/service.go b/pkg/mls/api/v1/service.go index 3d79c522..44d13d08 100644 --- a/pkg/mls/api/v1/service.go +++ b/pkg/mls/api/v1/service.go @@ -262,9 +262,9 @@ func (s *Service) SendGroupMessages(ctx context.Context, req *mlsv1.SendGroupMes msgB, err := pb.Marshal(&mlsv1.GroupMessage{ Version: &mlsv1.GroupMessage_V1_{ V1: &mlsv1.GroupMessage_V1{ - Id: msg.Id, + Id: uint64(msg.ID), CreatedNs: uint64(msg.CreatedAt.UnixNano()), - GroupId: msg.GroupId, + GroupId: msg.GroupID, Data: msg.Data, }, }, @@ -312,7 +312,7 @@ func (s *Service) SendWelcomeMessages(ctx context.Context, req *mlsv1.SendWelcom msgB, err := pb.Marshal(&mlsv1.WelcomeMessage{ Version: &mlsv1.WelcomeMessage_V1_{ V1: &mlsv1.WelcomeMessage_V1{ - Id: msg.Id, + Id: uint64(msg.ID), CreatedNs: uint64(msg.CreatedAt.UnixNano()), InstallationKey: msg.InstallationKey, Data: msg.Data, diff --git a/pkg/mls/api/v1/service_test.go b/pkg/mls/api/v1/service_test.go index fc002047..bb02ee88 100644 --- a/pkg/mls/api/v1/service_test.go +++ b/pkg/mls/api/v1/service_test.go @@ -14,6 +14,7 @@ import ( "github.com/uptrace/bun" wakupb "github.com/waku-org/go-waku/waku/v2/protocol/pb" mlsstore "github.com/xmtp/xmtp-node-go/pkg/mls/store" + "github.com/xmtp/xmtp-node-go/pkg/mls/store/queries" "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" "github.com/xmtp/xmtp-node-go/pkg/proto/identity/associations" mlsv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1" @@ -120,12 +121,10 @@ func TestRegisterInstallation(t *testing.T) { require.NoError(t, err) require.Equal(t, installationId, res.InstallationKey) - installations := []mlsstore.Installation{} - err = mlsDb.NewSelect().Model(&installations).Where("id = ?", installationId).Scan(ctx) + installation, err := queries.New(mlsDb.DB).GetInstallation(ctx, installationId) require.NoError(t, err) - require.Len(t, installations, 1) - require.Equal(t, accountAddress, installations[0].WalletAddress) + require.Equal(t, accountAddress, installation.WalletAddress) } func TestRegisterInstallationError(t *testing.T) { @@ -170,9 +169,9 @@ func TestUploadKeyPackage(t *testing.T) { require.NoError(t, err) require.NotNil(t, uploadRes) - installation := &mlsstore.Installation{} - err = mlsDb.NewSelect().Model(installation).Where("id = ?", installationId).Scan(ctx) + installation, err := queries.New(mlsDb.DB).GetInstallation(ctx, installationId) require.NoError(t, err) + require.Equal(t, accountAddress, installation.WalletAddress) } func TestFetchKeyPackages(t *testing.T) { diff --git a/pkg/mls/store/models.go b/pkg/mls/store/models.go deleted file mode 100644 index 47d7145b..00000000 --- a/pkg/mls/store/models.go +++ /dev/null @@ -1,58 +0,0 @@ -package store - -import ( - "time" - - "github.com/uptrace/bun" -) - -type AddressLogEntry struct { - bun.BaseModel `bun:"table:address_log"` - - Address string `bun:",notnull"` - InboxId string `bun:",notnull"` - AssociationSequenceId *uint64 `bun:","` - RevocationSequenceId *uint64 `bun:","` -} - -type InboxLogEntry struct { - bun.BaseModel `bun:"table:inbox_log"` - - SequenceId uint64 `bun:",autoincrement"` - InboxId string - ServerTimestampNs int64 - IdentityUpdateProto []byte -} - -type Installation struct { - bun.BaseModel `bun:"table:installations"` - - ID []byte `bun:",pk,type:bytea"` - WalletAddress string `bun:"wallet_address,notnull"` - CreatedAt int64 `bun:"created_at,notnull"` - UpdatedAt int64 `bun:"updated_at,notnull"` - RevokedAt *int64 `bun:"revoked_at"` - CredentialIdentity []byte `bun:"credential_identity,notnull,type:bytea"` - - KeyPackage []byte `bun:"key_package,notnull,type:bytea"` - Expiration uint64 `bun:"expiration,notnull"` -} - -type GroupMessage struct { - bun.BaseModel `bun:"table:group_messages"` - - Id uint64 `bun:",pk,notnull"` - CreatedAt time.Time `bun:",notnull"` - GroupId []byte `bun:",notnull,type:bytea"` - Data []byte `bun:",notnull,type:bytea"` -} - -type WelcomeMessage struct { - bun.BaseModel `bun:"table:welcome_messages"` - - Id uint64 `bun:",pk,notnull"` - CreatedAt time.Time `bun:",notnull"` - InstallationKey []byte `bun:",notnull,type:bytea"` - Data []byte `bun:",notnull,type:bytea"` - HpkePublicKey []byte `bun:"hpke_public_key,notnull,type:bytea"` -} diff --git a/pkg/mls/store/queries.sql b/pkg/mls/store/queries.sql index d18f4f8e..29ad2e84 100644 --- a/pkg/mls/store/queries.sql +++ b/pkg/mls/store/queries.sql @@ -18,7 +18,7 @@ FROM SELECT * FROM - json_populate_recordset(NULL::inbox_filter, sqlc.arg(filters)) AS b(inbox_id, + json_populate_recordset(NULL::inbox_filter, @filters) AS b(inbox_id, sequence_id)) AS b ON b.inbox_id = a.inbox_id AND a.sequence_id > b.sequence_id ORDER BY @@ -79,6 +79,14 @@ WHERE (address, inbox_id, association_sequence_id) =( INSERT INTO installations(id, wallet_address, created_at, updated_at, credential_identity, key_package, expiration) VALUES ($1, $2, $3, $3, $4, $5, $6); +-- name: GetInstallation :one +SELECT + * +FROM + installations +WHERE + id = $1; + -- name: UpdateKeyPackage :execrows UPDATE installations @@ -96,7 +104,7 @@ SELECT FROM installations WHERE - id = ANY (sqlc.arg(installation_ids)::BYTEA[]); + id = ANY (@installation_ids::BYTEA[]); -- name: GetIdentityUpdates :many SELECT @@ -131,49 +139,99 @@ INSERT INTO welcome_messages(installation_key, data, installation_key_data_hash, RETURNING *; --- name: QueryGroupMessagesAsc :many +-- name: GetAllGroupMessages :many +SELECT + * +FROM + group_messages +ORDER BY + id ASC; + +-- name: QueryGroupMessages :many SELECT * FROM group_messages WHERE group_id = @group_id +ORDER BY + CASE WHEN @sort_desc::BOOL THEN + id + END DESC, + CASE WHEN @sort_desc::BOOL = FALSE THEN + id + END ASC +LIMIT @numrows; + +-- name: QueryGroupMessagesWithCursorAsc :many +SELECT + * +FROM + group_messages +WHERE + group_id = @group_id + AND id > @cursor ORDER BY id ASC LIMIT @numrows; --- name: QueryGroupMessagesDesc :many +-- name: QueryGroupMessagesWithCursorDesc :many SELECT * FROM group_messages WHERE group_id = @group_id + AND id < @cursor ORDER BY id DESC LIMIT @numrows; --- name: QueryGroupMessagesWithCursorAsc :many +-- name: GetAllWelcomeMessages :many SELECT * FROM - group_messages + welcome_messages +ORDER BY + id ASC; + +-- name: QueryWelcomeMessages :many +SELECT + * +FROM + welcome_messages +WHERE + installation_key = @installation_key +ORDER BY + CASE WHEN @sort_desc::BOOL THEN + id + END DESC, + CASE WHEN @sort_desc::BOOL = FALSE THEN + id + END ASC +LIMIT @numrows; + +-- name: QueryWelcomeMessagesWithCursorAsc :many +SELECT + * +FROM + welcome_messages WHERE - group_id = $1 - AND id > $2 + installation_key = @installation_key + AND id > @cursor ORDER BY id ASC -LIMIT $3; +LIMIT @numrows; --- name: QueryGroupMessagesWithCursorDesc :many +-- name: QueryWelcomeMessagesWithCursorDesc :many SELECT * FROM - group_messages + welcome_messages WHERE - group_id = $1 - AND id < $2 + installation_key = @installation_key + AND id < @cursor ORDER BY id DESC -LIMIT $3; +LIMIT @numrows; diff --git a/pkg/mls/store/queries/queries.sql.go b/pkg/mls/store/queries/queries.sql.go index f7e3311e..99a27987 100644 --- a/pkg/mls/store/queries/queries.sql.go +++ b/pkg/mls/store/queries/queries.sql.go @@ -127,6 +127,44 @@ func (q *Queries) GetAddressLogs(ctx context.Context, addresses []string) ([]Get return items, nil } +const getAllGroupMessages = `-- name: GetAllGroupMessages :many +SELECT + id, created_at, group_id, data, group_id_data_hash +FROM + group_messages +ORDER BY + id ASC +` + +func (q *Queries) GetAllGroupMessages(ctx context.Context) ([]GroupMessage, error) { + rows, err := q.db.QueryContext(ctx, getAllGroupMessages) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GroupMessage + for rows.Next() { + var i GroupMessage + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.GroupID, + &i.Data, + &i.GroupIDDataHash, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getAllInboxLogs = `-- name: GetAllInboxLogs :many SELECT sequence_id, inbox_id, server_timestamp_ns, identity_update_proto @@ -167,6 +205,45 @@ func (q *Queries) GetAllInboxLogs(ctx context.Context, inboxID string) ([]InboxL return items, nil } +const getAllWelcomeMessages = `-- name: GetAllWelcomeMessages :many +SELECT + id, created_at, installation_key, data, installation_key_data_hash, hpke_public_key +FROM + welcome_messages +ORDER BY + id ASC +` + +func (q *Queries) GetAllWelcomeMessages(ctx context.Context) ([]WelcomeMessage, error) { + rows, err := q.db.QueryContext(ctx, getAllWelcomeMessages) + if err != nil { + return nil, err + } + defer rows.Close() + var items []WelcomeMessage + for rows.Next() { + var i WelcomeMessage + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.InstallationKey, + &i.Data, + &i.InstallationKeyDataHash, + &i.HpkePublicKey, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getIdentityUpdates = `-- name: GetIdentityUpdates :many SELECT id, wallet_address, created_at, updated_at, credential_identity, revoked_at, key_package, expiration @@ -261,6 +338,31 @@ func (q *Queries) GetInboxLogFiltered(ctx context.Context, filters json.RawMessa return items, nil } +const getInstallation = `-- name: GetInstallation :one +SELECT + id, wallet_address, created_at, updated_at, credential_identity, revoked_at, key_package, expiration +FROM + installations +WHERE + id = $1 +` + +func (q *Queries) GetInstallation(ctx context.Context, id []byte) (Installation, error) { + row := q.db.QueryRowContext(ctx, getInstallation, id) + var i Installation + err := row.Scan( + &i.ID, + &i.WalletAddress, + &i.CreatedAt, + &i.UpdatedAt, + &i.CredentialIdentity, + &i.RevokedAt, + &i.KeyPackage, + &i.Expiration, + ) + return i, err +} + const insertAddressLog = `-- name: InsertAddressLog :one INSERT INTO address_log(address, inbox_id, association_sequence_id, revocation_sequence_id) VALUES ($1, $2, $3, $4) @@ -371,7 +473,7 @@ func (q *Queries) InsertWelcomeMessage(ctx context.Context, arg InsertWelcomeMes return i, err } -const queryGroupMessagesAsc = `-- name: QueryGroupMessagesAsc :many +const queryGroupMessages = `-- name: QueryGroupMessages :many SELECT id, created_at, group_id, data, group_id_data_hash FROM @@ -379,17 +481,23 @@ FROM WHERE group_id = $1 ORDER BY - id ASC -LIMIT $2 + CASE WHEN $2::BOOL THEN + id + END DESC, + CASE WHEN $2::BOOL = FALSE THEN + id + END ASC +LIMIT $3 ` -type QueryGroupMessagesAscParams struct { - GroupID []byte - Numrows int32 +type QueryGroupMessagesParams struct { + GroupID []byte + SortDesc bool + Numrows int32 } -func (q *Queries) QueryGroupMessagesAsc(ctx context.Context, arg QueryGroupMessagesAscParams) ([]GroupMessage, error) { - rows, err := q.db.QueryContext(ctx, queryGroupMessagesAsc, arg.GroupID, arg.Numrows) +func (q *Queries) QueryGroupMessages(ctx context.Context, arg QueryGroupMessagesParams) ([]GroupMessage, error) { + rows, err := q.db.QueryContext(ctx, queryGroupMessages, arg.GroupID, arg.SortDesc, arg.Numrows) if err != nil { return nil, err } @@ -417,25 +525,27 @@ func (q *Queries) QueryGroupMessagesAsc(ctx context.Context, arg QueryGroupMessa return items, nil } -const queryGroupMessagesDesc = `-- name: QueryGroupMessagesDesc :many +const queryGroupMessagesWithCursorAsc = `-- name: QueryGroupMessagesWithCursorAsc :many SELECT id, created_at, group_id, data, group_id_data_hash FROM group_messages WHERE group_id = $1 + AND id > $2 ORDER BY - id DESC -LIMIT $2 + id ASC +LIMIT $3 ` -type QueryGroupMessagesDescParams struct { +type QueryGroupMessagesWithCursorAscParams struct { GroupID []byte + Cursor int64 Numrows int32 } -func (q *Queries) QueryGroupMessagesDesc(ctx context.Context, arg QueryGroupMessagesDescParams) ([]GroupMessage, error) { - rows, err := q.db.QueryContext(ctx, queryGroupMessagesDesc, arg.GroupID, arg.Numrows) +func (q *Queries) QueryGroupMessagesWithCursorAsc(ctx context.Context, arg QueryGroupMessagesWithCursorAscParams) ([]GroupMessage, error) { + rows, err := q.db.QueryContext(ctx, queryGroupMessagesWithCursorAsc, arg.GroupID, arg.Cursor, arg.Numrows) if err != nil { return nil, err } @@ -463,27 +573,27 @@ func (q *Queries) QueryGroupMessagesDesc(ctx context.Context, arg QueryGroupMess return items, nil } -const queryGroupMessagesWithCursorAsc = `-- name: QueryGroupMessagesWithCursorAsc :many +const queryGroupMessagesWithCursorDesc = `-- name: QueryGroupMessagesWithCursorDesc :many SELECT id, created_at, group_id, data, group_id_data_hash FROM group_messages WHERE group_id = $1 - AND id > $2 + AND id < $2 ORDER BY - id ASC + id DESC LIMIT $3 ` -type QueryGroupMessagesWithCursorAscParams struct { +type QueryGroupMessagesWithCursorDescParams struct { GroupID []byte - ID int64 - Limit int32 + Cursor int64 + Numrows int32 } -func (q *Queries) QueryGroupMessagesWithCursorAsc(ctx context.Context, arg QueryGroupMessagesWithCursorAscParams) ([]GroupMessage, error) { - rows, err := q.db.QueryContext(ctx, queryGroupMessagesWithCursorAsc, arg.GroupID, arg.ID, arg.Limit) +func (q *Queries) QueryGroupMessagesWithCursorDesc(ctx context.Context, arg QueryGroupMessagesWithCursorDescParams) ([]GroupMessage, error) { + rows, err := q.db.QueryContext(ctx, queryGroupMessagesWithCursorDesc, arg.GroupID, arg.Cursor, arg.Numrows) if err != nil { return nil, err } @@ -511,40 +621,143 @@ func (q *Queries) QueryGroupMessagesWithCursorAsc(ctx context.Context, arg Query return items, nil } -const queryGroupMessagesWithCursorDesc = `-- name: QueryGroupMessagesWithCursorDesc :many +const queryWelcomeMessages = `-- name: QueryWelcomeMessages :many SELECT - id, created_at, group_id, data, group_id_data_hash + id, created_at, installation_key, data, installation_key_data_hash, hpke_public_key FROM - group_messages + welcome_messages WHERE - group_id = $1 + installation_key = $1 +ORDER BY + CASE WHEN $2::BOOL THEN + id + END DESC, + CASE WHEN $2::BOOL = FALSE THEN + id + END ASC +LIMIT $3 +` + +type QueryWelcomeMessagesParams struct { + InstallationKey []byte + SortDesc bool + Numrows int32 +} + +func (q *Queries) QueryWelcomeMessages(ctx context.Context, arg QueryWelcomeMessagesParams) ([]WelcomeMessage, error) { + rows, err := q.db.QueryContext(ctx, queryWelcomeMessages, arg.InstallationKey, arg.SortDesc, arg.Numrows) + if err != nil { + return nil, err + } + defer rows.Close() + var items []WelcomeMessage + for rows.Next() { + var i WelcomeMessage + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.InstallationKey, + &i.Data, + &i.InstallationKeyDataHash, + &i.HpkePublicKey, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const queryWelcomeMessagesWithCursorAsc = `-- name: QueryWelcomeMessagesWithCursorAsc :many +SELECT + id, created_at, installation_key, data, installation_key_data_hash, hpke_public_key +FROM + welcome_messages +WHERE + installation_key = $1 + AND id > $2 +ORDER BY + id ASC +LIMIT $3 +` + +type QueryWelcomeMessagesWithCursorAscParams struct { + InstallationKey []byte + Cursor int64 + Numrows int32 +} + +func (q *Queries) QueryWelcomeMessagesWithCursorAsc(ctx context.Context, arg QueryWelcomeMessagesWithCursorAscParams) ([]WelcomeMessage, error) { + rows, err := q.db.QueryContext(ctx, queryWelcomeMessagesWithCursorAsc, arg.InstallationKey, arg.Cursor, arg.Numrows) + if err != nil { + return nil, err + } + defer rows.Close() + var items []WelcomeMessage + for rows.Next() { + var i WelcomeMessage + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.InstallationKey, + &i.Data, + &i.InstallationKeyDataHash, + &i.HpkePublicKey, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const queryWelcomeMessagesWithCursorDesc = `-- name: QueryWelcomeMessagesWithCursorDesc :many +SELECT + id, created_at, installation_key, data, installation_key_data_hash, hpke_public_key +FROM + welcome_messages +WHERE + installation_key = $1 AND id < $2 ORDER BY id DESC LIMIT $3 ` -type QueryGroupMessagesWithCursorDescParams struct { - GroupID []byte - ID int64 - Limit int32 +type QueryWelcomeMessagesWithCursorDescParams struct { + InstallationKey []byte + Cursor int64 + Numrows int32 } -func (q *Queries) QueryGroupMessagesWithCursorDesc(ctx context.Context, arg QueryGroupMessagesWithCursorDescParams) ([]GroupMessage, error) { - rows, err := q.db.QueryContext(ctx, queryGroupMessagesWithCursorDesc, arg.GroupID, arg.ID, arg.Limit) +func (q *Queries) QueryWelcomeMessagesWithCursorDesc(ctx context.Context, arg QueryWelcomeMessagesWithCursorDescParams) ([]WelcomeMessage, error) { + rows, err := q.db.QueryContext(ctx, queryWelcomeMessagesWithCursorDesc, arg.InstallationKey, arg.Cursor, arg.Numrows) if err != nil { return nil, err } defer rows.Close() - var items []GroupMessage + var items []WelcomeMessage for rows.Next() { - var i GroupMessage + var i WelcomeMessage if err := rows.Scan( &i.ID, &i.CreatedAt, - &i.GroupID, + &i.InstallationKey, &i.Data, - &i.GroupIDDataHash, + &i.InstallationKeyDataHash, + &i.HpkePublicKey, ); err != nil { return nil, err } diff --git a/pkg/mls/store/store.go b/pkg/mls/store/store.go index 172579fd..bf202994 100644 --- a/pkg/mls/store/store.go +++ b/pkg/mls/store/store.go @@ -43,8 +43,8 @@ type MlsStore interface { UpdateKeyPackage(ctx context.Context, installationId, keyPackage []byte, expiration uint64) error FetchKeyPackages(ctx context.Context, installationIds [][]byte) ([]queries.FetchKeyPackagesRow, error) GetIdentityUpdates(ctx context.Context, walletAddresses []string, startTimeNs int64) (map[string]IdentityUpdateList, error) - InsertGroupMessage(ctx context.Context, groupId []byte, data []byte) (*GroupMessage, error) - InsertWelcomeMessage(ctx context.Context, installationId []byte, data []byte, hpkePublicKey []byte) (*WelcomeMessage, error) + InsertGroupMessage(ctx context.Context, groupId []byte, data []byte) (*queries.GroupMessage, error) + InsertWelcomeMessage(ctx context.Context, installationId []byte, data []byte, hpkePublicKey []byte) (*queries.WelcomeMessage, error) QueryGroupMessagesV1(ctx context.Context, query *mlsv1.QueryGroupMessagesRequest) (*mlsv1.QueryGroupMessagesResponse, error) QueryWelcomeMessagesV1(ctx context.Context, query *mlsv1.QueryWelcomeMessagesRequest) (*mlsv1.QueryWelcomeMessagesResponse, error) } @@ -319,13 +319,14 @@ func (s *Store) RevokeInstallation(ctx context.Context, installationId []byte) e }) } -func (s *Store) InsertGroupMessage(ctx context.Context, groupId []byte, data []byte) (*GroupMessage, error) { - message := GroupMessage{ - Data: data, - } +func (s *Store) InsertGroupMessage(ctx context.Context, groupId []byte, data []byte) (*queries.GroupMessage, error) { + dataHash := sha256.Sum256(append(groupId, data...)) + message, err := s.queries.InsertGroupMessage(ctx, queries.InsertGroupMessageParams{ + GroupID: groupId, + Data: data, + GroupIDDataHash: dataHash[:], + }) - var id uint64 - err := s.db.QueryRow("INSERT INTO group_messages (group_id, data, group_id_data_hash) VALUES (?, ?, ?) RETURNING id", groupId, data, sha256.Sum256(append(groupId, data...))).Scan(&id) if err != nil { if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { return nil, NewAlreadyExistsError(err) @@ -333,21 +334,17 @@ func (s *Store) InsertGroupMessage(ctx context.Context, groupId []byte, data []b return nil, err } - err = s.db.NewSelect().Model(&message).Where("id = ?", id).Scan(ctx) - if err != nil { - return nil, err - } - return &message, nil } -func (s *Store) InsertWelcomeMessage(ctx context.Context, installationId []byte, data []byte, hpkePublicKey []byte) (*WelcomeMessage, error) { - message := WelcomeMessage{ - Data: data, - } - - var id uint64 - err := s.db.QueryRow("INSERT INTO welcome_messages (installation_key, data, installation_key_data_hash, hpke_public_key) VALUES (?, ?, ?, ?) RETURNING id", installationId, data, sha256.Sum256(append(installationId, data...)), hpkePublicKey).Scan(&id) +func (s *Store) InsertWelcomeMessage(ctx context.Context, installationId []byte, data []byte, hpkePublicKey []byte) (*queries.WelcomeMessage, error) { + dataHash := sha256.Sum256(append(installationId, data...)) + message, err := s.queries.InsertWelcomeMessage(ctx, queries.InsertWelcomeMessageParams{ + InstallationKey: installationId, + Data: data, + InstallationKeyDataHash: dataHash[:], + HpkePublicKey: hpkePublicKey, + }) if err != nil { if strings.Contains(err.Error(), "duplicate key value violates unique constraint") { return nil, NewAlreadyExistsError(err) @@ -355,145 +352,163 @@ func (s *Store) InsertWelcomeMessage(ctx context.Context, installationId []byte, return nil, err } - err = s.db.NewSelect().Model(&message).Where("id = ?", id).Scan(ctx) - if err != nil { - return nil, err - } - return &message, nil } func (s *Store) QueryGroupMessagesV1(ctx context.Context, req *mlsv1.QueryGroupMessagesRequest) (*mlsv1.QueryGroupMessagesResponse, error) { - msgs := make([]*GroupMessage, 0) - if len(req.GroupId) == 0 { return nil, errors.New("group is required") } - q := s.db.NewSelect(). - Model(&msgs). - Where("group_id = ?", req.GroupId) + sortDesc := true + var idCursor int64 + var err error + var messages []queries.GroupMessage + pageSize := int32(maxPageSize) - direction := mlsv1.SortDirection_SORT_DIRECTION_DESCENDING - if req.PagingInfo != nil && req.PagingInfo.Direction != mlsv1.SortDirection_SORT_DIRECTION_UNSPECIFIED { - direction = req.PagingInfo.Direction - } - switch direction { - case mlsv1.SortDirection_SORT_DIRECTION_DESCENDING: - q = q.Order("id DESC") - case mlsv1.SortDirection_SORT_DIRECTION_ASCENDING: - q = q.Order("id ASC") + if req.PagingInfo != nil && req.PagingInfo.Direction == mlsv1.SortDirection_SORT_DIRECTION_ASCENDING { + sortDesc = false } - pageSize := maxPageSize if req.PagingInfo != nil && req.PagingInfo.Limit > 0 && req.PagingInfo.Limit <= maxPageSize { - pageSize = int(req.PagingInfo.Limit) + pageSize = int32(req.PagingInfo.Limit) } - q = q.Limit(pageSize) if req.PagingInfo != nil && req.PagingInfo.IdCursor != 0 { - if direction == mlsv1.SortDirection_SORT_DIRECTION_ASCENDING { - q = q.Where("id > ?", req.PagingInfo.IdCursor) + idCursor = int64(req.PagingInfo.IdCursor) + } + + if idCursor > 0 { + if sortDesc { + messages, err = s.queries.QueryGroupMessagesWithCursorDesc(ctx, queries.QueryGroupMessagesWithCursorDescParams{ + GroupID: req.GroupId, + Cursor: idCursor, + Numrows: pageSize, + }) } else { - q = q.Where("id < ?", req.PagingInfo.IdCursor) + messages, err = s.queries.QueryGroupMessagesWithCursorAsc(ctx, queries.QueryGroupMessagesWithCursorAscParams{ + GroupID: req.GroupId, + Cursor: idCursor, + Numrows: pageSize, + }) } + } else { + messages, err = s.queries.QueryGroupMessages(ctx, queries.QueryGroupMessagesParams{ + GroupID: req.GroupId, + Numrows: pageSize, + SortDesc: sortDesc, + }) } - err := q.Scan(ctx) if err != nil { return nil, err } - messages := make([]*mlsv1.GroupMessage, 0, len(msgs)) - for _, msg := range msgs { - messages = append(messages, &mlsv1.GroupMessage{ + out := make([]*mlsv1.GroupMessage, len(messages)) + for idx, msg := range messages { + out[idx] = &mlsv1.GroupMessage{ Version: &mlsv1.GroupMessage_V1_{ V1: &mlsv1.GroupMessage_V1{ - Id: msg.Id, + Id: uint64(msg.ID), CreatedNs: uint64(msg.CreatedAt.UnixNano()), - GroupId: msg.GroupId, + GroupId: msg.GroupID, Data: msg.Data, }, }, - }) + } + } + + direction := mlsv1.SortDirection_SORT_DIRECTION_ASCENDING + if sortDesc { + direction = mlsv1.SortDirection_SORT_DIRECTION_DESCENDING } pagingInfo := &mlsv1.PagingInfo{Limit: uint32(pageSize), IdCursor: 0, Direction: direction} - if len(messages) >= pageSize { - lastMsg := msgs[len(messages)-1] - pagingInfo.IdCursor = lastMsg.Id + if len(messages) >= int(pageSize) { + lastMsg := messages[len(messages)-1] + pagingInfo.IdCursor = uint64(lastMsg.ID) } return &mlsv1.QueryGroupMessagesResponse{ - Messages: messages, + Messages: out, PagingInfo: pagingInfo, }, nil } func (s *Store) QueryWelcomeMessagesV1(ctx context.Context, req *mlsv1.QueryWelcomeMessagesRequest) (*mlsv1.QueryWelcomeMessagesResponse, error) { - msgs := make([]*WelcomeMessage, 0) - if len(req.InstallationKey) == 0 { return nil, errors.New("installation is required") } - q := s.db.NewSelect(). - Model(&msgs). - Where("installation_key = ?", req.InstallationKey) - + sortDesc := true direction := mlsv1.SortDirection_SORT_DIRECTION_DESCENDING - if req.PagingInfo != nil && req.PagingInfo.Direction != mlsv1.SortDirection_SORT_DIRECTION_UNSPECIFIED { - direction = req.PagingInfo.Direction - } - switch direction { - case mlsv1.SortDirection_SORT_DIRECTION_DESCENDING: - q = q.Order("id DESC") - case mlsv1.SortDirection_SORT_DIRECTION_ASCENDING: - q = q.Order("id ASC") + pageSize := int32(maxPageSize) + var idCursor int64 + var err error + var messages []queries.WelcomeMessage + + if req.PagingInfo != nil && req.PagingInfo.Direction == mlsv1.SortDirection_SORT_DIRECTION_ASCENDING { + sortDesc = false + direction = mlsv1.SortDirection_SORT_DIRECTION_ASCENDING } - pageSize := maxPageSize if req.PagingInfo != nil && req.PagingInfo.Limit > 0 && req.PagingInfo.Limit <= maxPageSize { - pageSize = int(req.PagingInfo.Limit) + pageSize = int32(req.PagingInfo.Limit) } - q = q.Limit(pageSize) if req.PagingInfo != nil && req.PagingInfo.IdCursor != 0 { - if direction == mlsv1.SortDirection_SORT_DIRECTION_ASCENDING { - q = q.Where("id > ?", req.PagingInfo.IdCursor) + idCursor = int64(req.PagingInfo.IdCursor) + } + + if idCursor > 0 { + if sortDesc { + messages, err = s.queries.QueryWelcomeMessagesWithCursorDesc(ctx, queries.QueryWelcomeMessagesWithCursorDescParams{ + InstallationKey: req.InstallationKey, + Cursor: idCursor, + Numrows: pageSize, + }) } else { - q = q.Where("id < ?", req.PagingInfo.IdCursor) + messages, err = s.queries.QueryWelcomeMessagesWithCursorAsc(ctx, queries.QueryWelcomeMessagesWithCursorAscParams{ + InstallationKey: req.InstallationKey, + Cursor: idCursor, + Numrows: pageSize, + }) } + } else { + messages, err = s.queries.QueryWelcomeMessages(ctx, queries.QueryWelcomeMessagesParams{ + InstallationKey: req.InstallationKey, + Numrows: pageSize, + SortDesc: sortDesc, + }) } - err := q.Scan(ctx) if err != nil { return nil, err } - messages := make([]*mlsv1.WelcomeMessage, 0, len(msgs)) - for _, msg := range msgs { - messages = append(messages, &mlsv1.WelcomeMessage{ + out := make([]*mlsv1.WelcomeMessage, len(messages)) + for idx, msg := range messages { + out[idx] = &mlsv1.WelcomeMessage{ Version: &mlsv1.WelcomeMessage_V1_{ V1: &mlsv1.WelcomeMessage_V1{ - Id: msg.Id, + Id: uint64(msg.ID), CreatedNs: uint64(msg.CreatedAt.UnixNano()), Data: msg.Data, InstallationKey: msg.InstallationKey, HpkePublicKey: msg.HpkePublicKey, }, }, - }) + } } pagingInfo := &mlsv1.PagingInfo{Limit: uint32(pageSize), IdCursor: 0, Direction: direction} - if len(messages) >= pageSize { - lastMsg := msgs[len(messages)-1] - pagingInfo.IdCursor = lastMsg.Id + if len(messages) >= int(pageSize) { + lastMsg := messages[len(messages)-1] + pagingInfo.IdCursor = uint64(lastMsg.ID) } return &mlsv1.QueryWelcomeMessagesResponse{ - Messages: messages, + Messages: out, PagingInfo: pagingInfo, }, nil } diff --git a/pkg/mls/store/store_test.go b/pkg/mls/store/store_test.go index 112776b9..970cecaf 100644 --- a/pkg/mls/store/store_test.go +++ b/pkg/mls/store/store_test.go @@ -80,8 +80,8 @@ func TestCreateInstallation(t *testing.T) { err := store.CreateInstallation(ctx, installationId, walletAddress, test.RandomBytes(32), test.RandomBytes(32), 0) require.NoError(t, err) - installationFromDb := &Installation{} - require.NoError(t, store.db.NewSelect().Model(installationFromDb).Where("id = ?", installationId).Scan(ctx)) + installationFromDb, err := store.queries.GetInstallation(ctx, installationId) + require.NoError(t, err) require.Equal(t, walletAddress, installationFromDb.WalletAddress) } @@ -101,8 +101,8 @@ func TestUpdateKeyPackage(t *testing.T) { err = store.UpdateKeyPackage(ctx, installationId, keyPackage2, 1) require.NoError(t, err) - installationFromDb := &Installation{} - require.NoError(t, store.db.NewSelect().Model(installationFromDb).Where("id = ?", installationId).Scan(ctx)) + installationFromDb, err := store.queries.GetInstallation(ctx, installationId) + require.NoError(t, err) require.Equal(t, keyPackage2, installationFromDb.KeyPackage) require.Equal(t, uint64(1), installationFromDb.Expiration) @@ -226,16 +226,15 @@ func TestInsertGroupMessage_Single(t *testing.T) { msg, err := store.InsertGroupMessage(ctx, []byte("group"), []byte("data")) require.NoError(t, err) require.NotNil(t, msg) - require.Equal(t, uint64(1), msg.Id) + require.Equal(t, int64(1), msg.ID) require.True(t, msg.CreatedAt.Before(time.Now().UTC()) && msg.CreatedAt.After(started)) - require.Equal(t, []byte("group"), msg.GroupId) + require.Equal(t, []byte("group"), msg.GroupID) require.Equal(t, []byte("data"), msg.Data) - msgs := make([]*GroupMessage, 0) - err = store.db.NewSelect().Model(&msgs).Scan(ctx) + msgs, err := store.queries.GetAllGroupMessages(ctx) require.NoError(t, err) require.Len(t, msgs, 1) - require.Equal(t, msg, msgs[0]) + require.Equal(t, *msg, msgs[0]) } func TestInsertGroupMessage_Duplicate(t *testing.T) { @@ -265,13 +264,12 @@ func TestInsertGroupMessage_ManyOrderedByTime(t *testing.T) { _, err = store.InsertGroupMessage(ctx, []byte("group"), []byte("data3")) require.NoError(t, err) - msgs := make([]*GroupMessage, 0) - err = store.db.NewSelect().Model(&msgs).Order("created_at DESC").Scan(ctx) + msgs, err := store.queries.GetAllGroupMessages(ctx) require.NoError(t, err) require.Len(t, msgs, 3) - require.Equal(t, []byte("data3"), msgs[0].Data) + require.Equal(t, []byte("data1"), msgs[0].Data) require.Equal(t, []byte("data2"), msgs[1].Data) - require.Equal(t, []byte("data1"), msgs[2].Data) + require.Equal(t, []byte("data3"), msgs[2].Data) } func TestInsertWelcomeMessage_Single(t *testing.T) { @@ -283,17 +281,16 @@ func TestInsertWelcomeMessage_Single(t *testing.T) { msg, err := store.InsertWelcomeMessage(ctx, []byte("installation"), []byte("data"), []byte("hpke")) require.NoError(t, err) require.NotNil(t, msg) - require.Equal(t, uint64(1), msg.Id) - require.True(t, msg.CreatedAt.Before(time.Now().UTC()) && msg.CreatedAt.After(started)) + require.Equal(t, int64(1), msg.ID) + require.True(t, msg.CreatedAt.Before(time.Now().UTC().Add(1*time.Minute)) && msg.CreatedAt.After(started)) require.Equal(t, []byte("installation"), msg.InstallationKey) require.Equal(t, []byte("data"), msg.Data) require.Equal(t, []byte("hpke"), msg.HpkePublicKey) - msgs := make([]*WelcomeMessage, 0) - err = store.db.NewSelect().Model(&msgs).Scan(ctx) + msgs, err := store.queries.GetAllWelcomeMessages(ctx) require.NoError(t, err) require.Len(t, msgs, 1) - require.Equal(t, msg, msgs[0]) + require.Equal(t, *msg, msgs[0]) } func TestInsertWelcomeMessage_Duplicate(t *testing.T) { @@ -323,13 +320,14 @@ func TestInsertWelcomeMessage_ManyOrderedByTime(t *testing.T) { _, err = store.InsertWelcomeMessage(ctx, []byte("installation"), []byte("data3"), []byte("hpke")) require.NoError(t, err) - msgs := make([]*WelcomeMessage, 0) - err = store.db.NewSelect().Model(&msgs).Order("created_at DESC").Scan(ctx) + msgs, err := store.queries.GetAllWelcomeMessages(ctx) require.NoError(t, err) require.Len(t, msgs, 3) - require.Equal(t, []byte("data3"), msgs[0].Data) + require.Equal(t, []byte("data1"), msgs[0].Data) require.Equal(t, []byte("data2"), msgs[1].Data) - require.Equal(t, []byte("data1"), msgs[2].Data) + require.Equal(t, []byte("data3"), msgs[2].Data) + require.Greater(t, msgs[1].CreatedAt, msgs[0].CreatedAt) + require.Greater(t, msgs[2].CreatedAt, msgs[1].CreatedAt) } func TestQueryGroupMessagesV1_MissingGroup(t *testing.T) { diff --git a/pkg/server/pgxdb.go b/pkg/server/pgxdb.go index 2c33c905..e36c663b 100644 --- a/pkg/server/pgxdb.go +++ b/pkg/server/pgxdb.go @@ -6,7 +6,6 @@ import ( "fmt" "time" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/stdlib" _ "github.com/jackc/pgx/v5/stdlib" @@ -19,7 +18,7 @@ func newPGXDB(dsn string, waitForDB, statementTimeout time.Duration) (*sql.DB, e if err != nil { return nil, err } - config.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol + config.ConnConfig.RuntimeParams["statement_timeout"] = fmt.Sprint(statementTimeout.Milliseconds()) dbpool, err := pgxpool.NewWithConfig(context.Background(), config)