diff --git a/pkg/mls/store/models.go b/pkg/mls/store/models.go index 4ac6342f..47d7145b 100644 --- a/pkg/mls/store/models.go +++ b/pkg/mls/store/models.go @@ -6,6 +6,15 @@ import ( "github.com/uptrace/bun" ) +type AddressLogEntry struct { + bun.BaseModel `bun:"table:address_log"` + + Address string `bun:",notnull"` + InboxId string `bun:",notnull"` + AssociationSequenceId *uint64 `bun:","` + RevocationSequenceId *uint64 `bun:","` +} + type InboxLogEntry struct { bun.BaseModel `bun:"table:inbox_log"` diff --git a/pkg/mls/store/queries.sql b/pkg/mls/store/queries.sql index 57b04b4b..76bc1375 100644 --- a/pkg/mls/store/queries.sql +++ b/pkg/mls/store/queries.sql @@ -10,6 +10,27 @@ JOIN ( ) as b on b.inbox_id = a.inbox_id AND a.sequence_id > b.sequence_id ORDER BY a.sequence_id ASC; +-- name: GetAddressLogs :many +SELECT + a.address, + a.inbox_id, + a.association_sequence_id +FROM + address_log a +INNER JOIN ( + SELECT + address, + MAX(association_sequence_id) AS max_association_sequence_id + FROM + address_log + WHERE + address = ANY (@addresses::text[]) + AND + revocation_sequence_id IS NULL + GROUP BY + address +) b ON a.address = b.address AND a.association_sequence_id = b.max_association_sequence_id; + -- name: InsertInboxLog :one INSERT INTO inbox_log (inbox_id, server_timestamp_ns, identity_update_proto) VALUES ($1, $2, $3) diff --git a/pkg/mls/store/queries/queries.sql.go b/pkg/mls/store/queries/queries.sql.go index c9556056..b2ef7b4d 100644 --- a/pkg/mls/store/queries/queries.sql.go +++ b/pkg/mls/store/queries/queries.sql.go @@ -72,6 +72,57 @@ func (q *Queries) FetchKeyPackages(ctx context.Context, installationIds [][]byte return items, nil } +const getAddressLogs = `-- name: GetAddressLogs :many +SELECT + a.address, + a.inbox_id, + a.association_sequence_id +FROM + address_log a +INNER JOIN ( + SELECT + address, + MAX(association_sequence_id) AS max_association_sequence_id + FROM + address_log + WHERE + address = ANY ($1::text[]) + AND + revocation_sequence_id IS NULL + GROUP BY + address +) b ON a.address = b.address AND a.association_sequence_id = b.max_association_sequence_id +` + +type GetAddressLogsRow struct { + Address string + InboxID string + AssociationSequenceID sql.NullInt64 +} + +func (q *Queries) GetAddressLogs(ctx context.Context, addresses []string) ([]GetAddressLogsRow, error) { + rows, err := q.db.QueryContext(ctx, getAddressLogs, pq.Array(addresses)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetAddressLogsRow + for rows.Next() { + var i GetAddressLogsRow + if err := rows.Scan(&i.Address, &i.InboxID, &i.AssociationSequenceID); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getAllInboxLogs = `-- name: GetAllInboxLogs :many SELECT sequence_id, inbox_id, server_timestamp_ns, identity_update_proto FROM inbox_log WHERE inbox_id = $1 diff --git a/pkg/mls/store/store.go b/pkg/mls/store/store.go index 9e5f6da5..db2300da 100644 --- a/pkg/mls/store/store.go +++ b/pkg/mls/store/store.go @@ -32,6 +32,7 @@ type Store struct { type IdentityStore interface { PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest) (*identity.PublishIdentityUpdateResponse, error) GetInboxLogs(ctx context.Context, req *identity.GetIdentityUpdatesRequest) (*identity.GetIdentityUpdatesResponse, error) + GetInboxIds(ctx context.Context, req *identity.GetInboxIdsRequest) (*identity.GetInboxIdsResponse, error) } type MlsStore interface { @@ -67,6 +68,37 @@ func New(ctx context.Context, config Config) (*Store, error) { return s, nil } +func (s *Store) GetInboxIds(ctx context.Context, req *identity.GetInboxIdsRequest) (*identity.GetInboxIdsResponse, error) { + + addresses := []string{} + for _, request := range req.Requests { + addresses = append(addresses, request.GetAddress()) + } + + addressLogEntries, err := s.queries.GetAddressLogs(ctx, addresses) + if err != nil { + return nil, err + } + + out := make([]*identity.GetInboxIdsResponse_Response, len(addresses)) + + for index, address := range addresses { + resp := identity.GetInboxIdsResponse_Response{} + resp.Address = address + + for _, log_entry := range addressLogEntries { + if log_entry.Address == address { + resp.InboxId = &log_entry.InboxID + } + } + out[index] = &resp + } + + return &identity.GetInboxIdsResponse{ + Responses: out, + }, nil +} + func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest) (*identity.PublishIdentityUpdateResponse, error) { new_update := req.GetIdentityUpdate() if new_update == nil { diff --git a/pkg/mls/store/store_test.go b/pkg/mls/store/store_test.go index 379584f0..0f74996e 100644 --- a/pkg/mls/store/store_test.go +++ b/pkg/mls/store/store_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/stretchr/testify/require" + identity "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1" mlsv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1" test "github.com/xmtp/xmtp-node-go/pkg/testing" ) @@ -26,6 +27,69 @@ func NewTestStore(t *testing.T) (*Store, func()) { return store, dbCleanup } +func InsertAddressLog(store *Store, address string, inboxId string, associationSequenceId *uint64, revocationSequenceId *uint64) error { + + entry := AddressLogEntry{ + Address: address, + InboxId: inboxId, + AssociationSequenceId: associationSequenceId, + RevocationSequenceId: nil, + } + ctx := context.Background() + + _, err := store.db.NewInsert(). + Model(&entry). + Exec(ctx) + + return err +} + +func TestInboxIds(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + + seq, rev := uint64(1), uint64(5) + err := InsertAddressLog(store, "address", "inbox1", &seq, &rev) + require.NoError(t, err) + seq, rev = uint64(2), uint64(8) + err = InsertAddressLog(store, "address", "inbox1", &seq, &rev) + require.NoError(t, err) + seq, rev = uint64(3), uint64(9) + err = InsertAddressLog(store, "address", "inbox1", &seq, &rev) + require.NoError(t, err) + seq, rev = uint64(4), uint64(1) + err = InsertAddressLog(store, "address", "correct", &seq, &rev) + require.NoError(t, err) + + reqs := make([]*identity.GetInboxIdsRequest_Request, 0) + reqs = append(reqs, &identity.GetInboxIdsRequest_Request{ + Address: "address", + }) + req := &identity.GetInboxIdsRequest{ + Requests: reqs, + } + resp, _ := store.GetInboxIds(context.Background(), req) + t.Log(resp) + + require.Equal(t, "correct", *resp.Responses[0].InboxId) + + seq = uint64(5) + err = InsertAddressLog(store, "address", "correct_inbox2", &seq, nil) + require.NoError(t, err) + resp, _ = store.GetInboxIds(context.Background(), req) + require.Equal(t, "correct_inbox2", *resp.Responses[0].InboxId) + + reqs = append(reqs, &identity.GetInboxIdsRequest_Request{Address: "address2"}) + req = &identity.GetInboxIdsRequest{ + Requests: reqs, + } + seq, rev = uint64(8), uint64(2) + err = InsertAddressLog(store, "address2", "inbox2", &seq, &rev) + require.NoError(t, err) + resp, _ = store.GetInboxIds(context.Background(), req) + require.Equal(t, "inbox2", *resp.Responses[1].InboxId) +} + func TestCreateInstallation(t *testing.T) { store, cleanup := NewTestStore(t) defer cleanup()