diff --git a/pkg/identity/api/v1/identity_service.go b/pkg/identity/api/v1/identity_service.go index 7f36b31d..8c1459be 100644 --- a/pkg/identity/api/v1/identity_service.go +++ b/pkg/identity/api/v1/identity_service.go @@ -7,8 +7,6 @@ import ( "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" api "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1" "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) type Service struct { @@ -78,7 +76,7 @@ Start transaction (SERIALIZABLE isolation level) End transaction */ func (s *Service) PublishIdentityUpdate(ctx context.Context, req *api.PublishIdentityUpdateRequest) (*api.PublishIdentityUpdateResponse, error) { - return s.store.PublishIdentityUpdate(ctx, req) + return s.store.PublishIdentityUpdate(ctx, req, s.validationService) } func (s *Service) GetIdentityUpdates(ctx context.Context, req *api.GetIdentityUpdatesRequest) (*api.GetIdentityUpdatesResponse, error) { @@ -97,5 +95,5 @@ func (s *Service) GetInboxIds(ctx context.Context, req *api.GetInboxIdsRequest) for the address where revocation_sequence_id is lower or NULL 2. Return the value of the 'inbox_id' column */ - return nil, status.Errorf(codes.Unimplemented, "unimplemented") + return s.store.GetInboxIds(ctx, req) } diff --git a/pkg/identity/api/v1/identity_service_test.go b/pkg/identity/api/v1/identity_service_test.go index d0a7d2c4..24d49a4c 100644 --- a/pkg/identity/api/v1/identity_service_test.go +++ b/pkg/identity/api/v1/identity_service_test.go @@ -20,7 +20,27 @@ type mockedMLSValidationService struct { } func (m *mockedMLSValidationService) GetAssociationState(ctx context.Context, oldUpdates []*associations.IdentityUpdate, newUpdates []*associations.IdentityUpdate) (*mlsvalidate.AssociationStateResult, error) { - return nil, nil + + member_map := make([]*associations.MemberMap, 0) + member_map = append(member_map, &associations.MemberMap{ + Key: &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "key_address"}}, + Value: &associations.Member{ + Identifier: &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "ident"}}, + AddedByEntity: &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "added_by_entity"}}, + }, + }) + + new_members := make([]*associations.MemberIdentifier, 0) + + new_members = append(new_members, &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "0x01"}}) + new_members = append(new_members, &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "0x02"}}) + new_members = append(new_members, &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "0x03"}}) + + out := mlsvalidate.AssociationStateResult{ + AssociationState: &associations.AssociationState{InboxId: "test_inbox", Members: member_map, RecoveryAddress: "recovery", SeenSignatures: [][]byte{[]byte("seen"), []byte("sig")}}, + StateDiff: &associations.AssociationStateDiff{NewMembers: new_members, RemovedMembers: nil}, + } + return &out, nil } func (m *mockedMLSValidationService) ValidateKeyPackages(ctx context.Context, keyPackages [][]byte) ([]mlsvalidate.IdentityValidationResult, error) { diff --git a/pkg/migrations/mls/20240411200242_init-identity.up.sql b/pkg/migrations/mls/20240411200242_init-identity.up.sql index cc56762e..d3399c0b 100644 --- a/pkg/migrations/mls/20240411200242_init-identity.up.sql +++ b/pkg/migrations/mls/20240411200242_init-identity.up.sql @@ -1,7 +1,6 @@ SET statement_timeout = 0; --bun:split - CREATE TABLE inbox_log ( sequence_id BIGSERIAL PRIMARY KEY, inbox_id TEXT NOT NULL, @@ -10,11 +9,9 @@ CREATE TABLE inbox_log ( ); --bun:split - CREATE INDEX idx_inbox_log_inbox_id ON inbox_log(inbox_id); --bun:split - CREATE TABLE address_log ( address TEXT NOT NULL, inbox_id TEXT NOT NULL, @@ -23,5 +20,4 @@ CREATE TABLE address_log ( ); --bun:split - CREATE INDEX idx_address_log_address_inbox_id ON address_log(address, inbox_id); \ No newline at end of file diff --git a/pkg/mls/store/models.go b/pkg/mls/store/models.go index 4ac6342f..47d7145b 100644 --- a/pkg/mls/store/models.go +++ b/pkg/mls/store/models.go @@ -6,6 +6,15 @@ import ( "github.com/uptrace/bun" ) +type AddressLogEntry struct { + bun.BaseModel `bun:"table:address_log"` + + Address string `bun:",notnull"` + InboxId string `bun:",notnull"` + AssociationSequenceId *uint64 `bun:","` + RevocationSequenceId *uint64 `bun:","` +} + type InboxLogEntry struct { bun.BaseModel `bun:"table:inbox_log"` diff --git a/pkg/mls/store/queries.sql b/pkg/mls/store/queries.sql index 7bd43e35..a63937ce 100644 --- a/pkg/mls/store/queries.sql +++ b/pkg/mls/store/queries.sql @@ -15,6 +15,31 @@ FROM inbox_log AS a AND a.sequence_id > b.sequence_id ORDER BY a.sequence_id ASC; +-- name: GetAddressLogs :many +SELECT a.address, + a.inbox_id, + a.association_sequence_id +FROM address_log a + INNER JOIN ( + SELECT address, + MAX(association_sequence_id) AS max_association_sequence_id + FROM address_log + WHERE address = ANY (@addresses::text []) + AND revocation_sequence_id IS NULL + GROUP BY address + ) b ON a.address = b.address + AND a.association_sequence_id = b.max_association_sequence_id; + +-- name: InsertAddressLog :one +INSERT INTO address_log ( + address, + inbox_id, + association_sequence_id, + revocation_sequence_id + ) +VALUES ($1, $2, $3, $4) +RETURNING *; + -- name: InsertInboxLog :one INSERT INTO inbox_log ( inbox_id, @@ -24,6 +49,20 @@ INSERT INTO inbox_log ( VALUES ($1, $2, $3) RETURNING sequence_id; +-- name: RevokeAddressFromLog :exec +UPDATE address_log +SET revocation_sequence_id = $1 +WHERE (address, inbox_id, association_sequence_id) = ( + SELECT address, + inbox_id, + MAX(association_sequence_id) + FROM address_log AS a + WHERE a.address = $2 + AND a.inbox_id = $3 + GROUP BY address, + inbox_id + ); + -- name: CreateInstallation :exec INSERT INTO installations ( id, diff --git a/pkg/mls/store/queries/queries.sql.go b/pkg/mls/store/queries/queries.sql.go index 8a2d795d..db14c010 100644 --- a/pkg/mls/store/queries/queries.sql.go +++ b/pkg/mls/store/queries/queries.sql.go @@ -82,6 +82,51 @@ func (q *Queries) FetchKeyPackages(ctx context.Context, installationIds [][]byte return items, nil } +const getAddressLogs = `-- name: GetAddressLogs :many +SELECT a.address, + a.inbox_id, + a.association_sequence_id +FROM address_log a + INNER JOIN ( + SELECT address, + MAX(association_sequence_id) AS max_association_sequence_id + FROM address_log + WHERE address = ANY ($1::text []) + AND revocation_sequence_id IS NULL + GROUP BY address + ) b ON a.address = b.address + AND a.association_sequence_id = b.max_association_sequence_id +` + +type GetAddressLogsRow struct { + Address string + InboxID string + AssociationSequenceID sql.NullInt64 +} + +func (q *Queries) GetAddressLogs(ctx context.Context, addresses []string) ([]GetAddressLogsRow, error) { + rows, err := q.db.QueryContext(ctx, getAddressLogs, pq.Array(addresses)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetAddressLogsRow + for rows.Next() { + var i GetAddressLogsRow + if err := rows.Scan(&i.Address, &i.InboxID, &i.AssociationSequenceID); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getAllInboxLogs = `-- name: GetAllInboxLogs :many SELECT sequence_id, inbox_id, server_timestamp_ns, identity_update_proto FROM inbox_log @@ -205,6 +250,41 @@ func (q *Queries) GetInboxLogFiltered(ctx context.Context, filters json.RawMessa return items, nil } +const insertAddressLog = `-- name: InsertAddressLog :one +INSERT INTO address_log ( + address, + inbox_id, + association_sequence_id, + revocation_sequence_id + ) +VALUES ($1, $2, $3, $4) +RETURNING address, inbox_id, association_sequence_id, revocation_sequence_id +` + +type InsertAddressLogParams struct { + Address string + InboxID string + AssociationSequenceID sql.NullInt64 + RevocationSequenceID sql.NullInt64 +} + +func (q *Queries) InsertAddressLog(ctx context.Context, arg InsertAddressLogParams) (AddressLog, error) { + row := q.db.QueryRowContext(ctx, insertAddressLog, + arg.Address, + arg.InboxID, + arg.AssociationSequenceID, + arg.RevocationSequenceID, + ) + var i AddressLog + err := row.Scan( + &i.Address, + &i.InboxID, + &i.AssociationSequenceID, + &i.RevocationSequenceID, + ) + return i, err +} + const insertGroupMessage = `-- name: InsertGroupMessage :one INSERT INTO group_messages (group_id, data, group_id_data_hash) VALUES ($1, $2, $3) @@ -462,6 +542,32 @@ func (q *Queries) QueryGroupMessagesWithCursorDesc(ctx context.Context, arg Quer return items, nil } +const revokeAddressFromLog = `-- name: RevokeAddressFromLog :exec +UPDATE address_log +SET revocation_sequence_id = $1 +WHERE (address, inbox_id, association_sequence_id) = ( + SELECT address, + inbox_id, + MAX(association_sequence_id) + FROM address_log AS a + WHERE a.address = $2 + AND a.inbox_id = $3 + GROUP BY address, + inbox_id + ) +` + +type RevokeAddressFromLogParams struct { + RevocationSequenceID sql.NullInt64 + Address string + InboxID string +} + +func (q *Queries) RevokeAddressFromLog(ctx context.Context, arg RevokeAddressFromLogParams) error { + _, err := q.db.ExecContext(ctx, revokeAddressFromLog, arg.RevocationSequenceID, arg.Address, arg.InboxID) + return err +} + const revokeInstallation = `-- name: RevokeInstallation :exec UPDATE installations SET revoked_at = $1 diff --git a/pkg/mls/store/store.go b/pkg/mls/store/store.go index 9e5f6da5..172579fd 100644 --- a/pkg/mls/store/store.go +++ b/pkg/mls/store/store.go @@ -13,6 +13,7 @@ import ( "github.com/uptrace/bun/migrate" migrations "github.com/xmtp/xmtp-node-go/pkg/migrations/mls" queries "github.com/xmtp/xmtp-node-go/pkg/mls/store/queries" + "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" identity "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1" "github.com/xmtp/xmtp-node-go/pkg/proto/identity/associations" mlsv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1" @@ -30,8 +31,9 @@ type Store struct { } type IdentityStore interface { - PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest) (*identity.PublishIdentityUpdateResponse, error) + PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest, validationService mlsvalidate.MLSValidationService) (*identity.PublishIdentityUpdateResponse, error) GetInboxLogs(ctx context.Context, req *identity.GetIdentityUpdatesRequest) (*identity.GetIdentityUpdatesResponse, error) + GetInboxIds(ctx context.Context, req *identity.GetInboxIdsRequest) (*identity.GetInboxIdsResponse, error) } type MlsStore interface { @@ -67,7 +69,38 @@ func New(ctx context.Context, config Config) (*Store, error) { return s, nil } -func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest) (*identity.PublishIdentityUpdateResponse, error) { +func (s *Store) GetInboxIds(ctx context.Context, req *identity.GetInboxIdsRequest) (*identity.GetInboxIdsResponse, error) { + + addresses := []string{} + for _, request := range req.Requests { + addresses = append(addresses, request.GetAddress()) + } + + addressLogEntries, err := s.queries.GetAddressLogs(ctx, addresses) + if err != nil { + return nil, err + } + + out := make([]*identity.GetInboxIdsResponse_Response, len(addresses)) + + for index, address := range addresses { + resp := identity.GetInboxIdsResponse_Response{} + resp.Address = address + + for _, log_entry := range addressLogEntries { + if log_entry.Address == address { + resp.InboxId = &log_entry.InboxID + } + } + out[index] = &resp + } + + return &identity.GetInboxIdsResponse{ + Responses: out, + }, nil +} + +func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest, validationService mlsvalidate.MLSValidationService) (*identity.PublishIdentityUpdateResponse, error) { new_update := req.GetIdentityUpdate() if new_update == nil { return nil, errors.New("IdentityUpdate is required") @@ -93,23 +126,57 @@ func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.Publish } _ = append(updates, new_update) - // TODO: Validate the updates, and abort transaction if failed + state, err := validationService.GetAssociationState(ctx, updates, []*associations.IdentityUpdate{new_update}) + if err != nil { + return err + } + s.log.Info("Got association state", zap.Any("state", state)) protoBytes, err := proto.Marshal(new_update) if err != nil { return err } - _, err = txQueries.InsertInboxLog(ctx, queries.InsertInboxLogParams{ + sequence_id, err := txQueries.InsertInboxLog(ctx, queries.InsertInboxLogParams{ InboxID: new_update.GetInboxId(), ServerTimestampNs: nowNs(), IdentityUpdateProto: protoBytes, }) + s.log.Info("Inserted inbox log", zap.Any("sequence_id", sequence_id)) + if err != nil { return err } - // TODO: Insert or update the address_log table using sequence_id + + for _, new_member := range state.StateDiff.NewMembers { + s.log.Info("New member", zap.Any("member", new_member)) + if address, ok := new_member.Kind.(*associations.MemberIdentifier_Address); ok { + _, err = txQueries.InsertAddressLog(ctx, queries.InsertAddressLogParams{ + Address: address.Address, + InboxID: state.AssociationState.InboxId, + AssociationSequenceID: sql.NullInt64{Valid: true, Int64: sequence_id}, + RevocationSequenceID: sql.NullInt64{Valid: false}, + }) + if err != nil { + return err + } + } + } + + for _, removed_member := range state.StateDiff.RemovedMembers { + s.log.Info("New member", zap.Any("member", removed_member)) + if address, ok := removed_member.Kind.(*associations.MemberIdentifier_Address); ok { + err = txQueries.RevokeAddressFromLog(ctx, queries.RevokeAddressFromLogParams{ + Address: address.Address, + InboxID: state.AssociationState.InboxId, + RevocationSequenceID: sql.NullInt64{Valid: true, Int64: sequence_id}, + }) + if err != nil { + return err + } + } + } return nil }); err != nil { diff --git a/pkg/mls/store/store_test.go b/pkg/mls/store/store_test.go index 379584f0..112776b9 100644 --- a/pkg/mls/store/store_test.go +++ b/pkg/mls/store/store_test.go @@ -2,11 +2,14 @@ package store import ( "context" + "database/sql" "sort" "testing" "time" "github.com/stretchr/testify/require" + queries "github.com/xmtp/xmtp-node-go/pkg/mls/store/queries" + identity "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1" mlsv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1" test "github.com/xmtp/xmtp-node-go/pkg/testing" ) @@ -26,6 +29,46 @@ func NewTestStore(t *testing.T) (*Store, func()) { return store, dbCleanup } +func TestInboxIds(t *testing.T) { + store, cleanup := NewTestStore(t) + defer cleanup() + ctx := context.Background() + + _, err := store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "inbox1", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 1}, RevocationSequenceID: sql.NullInt64{Valid: false}}) + require.NoError(t, err) + _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "inbox1", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 2}, RevocationSequenceID: sql.NullInt64{Valid: false}}) + require.NoError(t, err) + _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "inbox1", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 3}, RevocationSequenceID: sql.NullInt64{Valid: false}}) + require.NoError(t, err) + _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "correct", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 4}, RevocationSequenceID: sql.NullInt64{Valid: false}}) + require.NoError(t, err) + + reqs := make([]*identity.GetInboxIdsRequest_Request, 0) + reqs = append(reqs, &identity.GetInboxIdsRequest_Request{ + Address: "address", + }) + req := &identity.GetInboxIdsRequest{ + Requests: reqs, + } + resp, _ := store.GetInboxIds(context.Background(), req) + + require.Equal(t, "correct", *resp.Responses[0].InboxId) + + _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "correct_inbox2", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 5}, RevocationSequenceID: sql.NullInt64{Valid: false}}) + require.NoError(t, err) + resp, _ = store.GetInboxIds(context.Background(), req) + require.Equal(t, "correct_inbox2", *resp.Responses[0].InboxId) + + reqs = append(reqs, &identity.GetInboxIdsRequest_Request{Address: "address2"}) + req = &identity.GetInboxIdsRequest{ + Requests: reqs, + } + _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address2", InboxID: "inbox2", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 8}, RevocationSequenceID: sql.NullInt64{Valid: false}}) + require.NoError(t, err) + resp, _ = store.GetInboxIds(context.Background(), req) + require.Equal(t, "inbox2", *resp.Responses[1].InboxId) +} + func TestCreateInstallation(t *testing.T) { store, cleanup := NewTestStore(t) defer cleanup()