From 64d8d5277dad710d55681e02c723bed9eb876252 Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Thu, 25 Apr 2024 17:13:39 -0400 Subject: [PATCH] Add InsertLog Query (#383) * Add insertlog query * validation service * insert log * revocation for removed members * lint * remove unnecessary log * change test to use query from sqlc * remove comments * fix tests --- pkg/identity/api/v1/identity_service.go | 6 +- pkg/identity/api/v1/identity_service_test.go | 22 ++++- pkg/mls/store/queries.sql | 24 ++++++ pkg/mls/store/queries/queries.sql.go | 89 ++++++++++++++++---- pkg/mls/store/store.go | 45 ++++++++-- pkg/mls/store/store_test.go | 39 ++------- 6 files changed, 168 insertions(+), 57 deletions(-) diff --git a/pkg/identity/api/v1/identity_service.go b/pkg/identity/api/v1/identity_service.go index 7f36b31d..8c1459be 100644 --- a/pkg/identity/api/v1/identity_service.go +++ b/pkg/identity/api/v1/identity_service.go @@ -7,8 +7,6 @@ import ( "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" api "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1" "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) type Service struct { @@ -78,7 +76,7 @@ Start transaction (SERIALIZABLE isolation level) End transaction */ func (s *Service) PublishIdentityUpdate(ctx context.Context, req *api.PublishIdentityUpdateRequest) (*api.PublishIdentityUpdateResponse, error) { - return s.store.PublishIdentityUpdate(ctx, req) + return s.store.PublishIdentityUpdate(ctx, req, s.validationService) } func (s *Service) GetIdentityUpdates(ctx context.Context, req *api.GetIdentityUpdatesRequest) (*api.GetIdentityUpdatesResponse, error) { @@ -97,5 +95,5 @@ func (s *Service) GetInboxIds(ctx context.Context, req *api.GetInboxIdsRequest) for the address where revocation_sequence_id is lower or NULL 2. Return the value of the 'inbox_id' column */ - return nil, status.Errorf(codes.Unimplemented, "unimplemented") + return s.store.GetInboxIds(ctx, req) } diff --git a/pkg/identity/api/v1/identity_service_test.go b/pkg/identity/api/v1/identity_service_test.go index d0a7d2c4..24d49a4c 100644 --- a/pkg/identity/api/v1/identity_service_test.go +++ b/pkg/identity/api/v1/identity_service_test.go @@ -20,7 +20,27 @@ type mockedMLSValidationService struct { } func (m *mockedMLSValidationService) GetAssociationState(ctx context.Context, oldUpdates []*associations.IdentityUpdate, newUpdates []*associations.IdentityUpdate) (*mlsvalidate.AssociationStateResult, error) { - return nil, nil + + member_map := make([]*associations.MemberMap, 0) + member_map = append(member_map, &associations.MemberMap{ + Key: &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "key_address"}}, + Value: &associations.Member{ + Identifier: &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "ident"}}, + AddedByEntity: &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "added_by_entity"}}, + }, + }) + + new_members := make([]*associations.MemberIdentifier, 0) + + new_members = append(new_members, &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "0x01"}}) + new_members = append(new_members, &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "0x02"}}) + new_members = append(new_members, &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "0x03"}}) + + out := mlsvalidate.AssociationStateResult{ + AssociationState: &associations.AssociationState{InboxId: "test_inbox", Members: member_map, RecoveryAddress: "recovery", SeenSignatures: [][]byte{[]byte("seen"), []byte("sig")}}, + StateDiff: &associations.AssociationStateDiff{NewMembers: new_members, RemovedMembers: nil}, + } + return &out, nil } func (m *mockedMLSValidationService) ValidateKeyPackages(ctx context.Context, keyPackages [][]byte) ([]mlsvalidate.IdentityValidationResult, error) { diff --git a/pkg/mls/store/queries.sql b/pkg/mls/store/queries.sql index 8b1e6a4d..a63937ce 100644 --- a/pkg/mls/store/queries.sql +++ b/pkg/mls/store/queries.sql @@ -30,6 +30,16 @@ FROM address_log a ) b ON a.address = b.address AND a.association_sequence_id = b.max_association_sequence_id; +-- name: InsertAddressLog :one +INSERT INTO address_log ( + address, + inbox_id, + association_sequence_id, + revocation_sequence_id + ) +VALUES ($1, $2, $3, $4) +RETURNING *; + -- name: InsertInboxLog :one INSERT INTO inbox_log ( inbox_id, @@ -39,6 +49,20 @@ INSERT INTO inbox_log ( VALUES ($1, $2, $3) RETURNING sequence_id; +-- name: RevokeAddressFromLog :exec +UPDATE address_log +SET revocation_sequence_id = $1 +WHERE (address, inbox_id, association_sequence_id) = ( + SELECT address, + inbox_id, + MAX(association_sequence_id) + FROM address_log AS a + WHERE a.address = $2 + AND a.inbox_id = $3 + GROUP BY address, + inbox_id + ); + -- name: CreateInstallation :exec INSERT INTO installations ( id, diff --git a/pkg/mls/store/queries/queries.sql.go b/pkg/mls/store/queries/queries.sql.go index b2ef7b4d..36d40044 100644 --- a/pkg/mls/store/queries/queries.sql.go +++ b/pkg/mls/store/queries/queries.sql.go @@ -73,25 +73,19 @@ func (q *Queries) FetchKeyPackages(ctx context.Context, installationIds [][]byte } const getAddressLogs = `-- name: GetAddressLogs :many -SELECT - a.address, +SELECT a.address, a.inbox_id, a.association_sequence_id -FROM - address_log a -INNER JOIN ( - SELECT - address, - MAX(association_sequence_id) AS max_association_sequence_id - FROM - address_log - WHERE - address = ANY ($1::text[]) - AND - revocation_sequence_id IS NULL - GROUP BY - address -) b ON a.address = b.address AND a.association_sequence_id = b.max_association_sequence_id +FROM address_log a + INNER JOIN ( + SELECT address, + MAX(association_sequence_id) AS max_association_sequence_id + FROM address_log + WHERE address = ANY ($1::text []) + AND revocation_sequence_id IS NULL + GROUP BY address + ) b ON a.address = b.address + AND a.association_sequence_id = b.max_association_sequence_id ` type GetAddressLogsRow struct { @@ -237,6 +231,41 @@ func (q *Queries) GetInboxLogFiltered(ctx context.Context, filters json.RawMessa return items, nil } +const insertAddressLog = `-- name: InsertAddressLog :one +INSERT INTO address_log ( + address, + inbox_id, + association_sequence_id, + revocation_sequence_id + ) +VALUES ($1, $2, $3, $4) +RETURNING address, inbox_id, association_sequence_id, revocation_sequence_id +` + +type InsertAddressLogParams struct { + Address string + InboxID string + AssociationSequenceID sql.NullInt64 + RevocationSequenceID sql.NullInt64 +} + +func (q *Queries) InsertAddressLog(ctx context.Context, arg InsertAddressLogParams) (AddressLog, error) { + row := q.db.QueryRowContext(ctx, insertAddressLog, + arg.Address, + arg.InboxID, + arg.AssociationSequenceID, + arg.RevocationSequenceID, + ) + var i AddressLog + err := row.Scan( + &i.Address, + &i.InboxID, + &i.AssociationSequenceID, + &i.RevocationSequenceID, + ) + return i, err +} + const insertGroupMessage = `-- name: InsertGroupMessage :one INSERT INTO group_messages (group_id, data, group_id_data_hash) VALUES ($1, $2, $3) @@ -481,6 +510,32 @@ func (q *Queries) QueryGroupMessagesWithCursorDesc(ctx context.Context, arg Quer return items, nil } +const revokeAddressFromLog = `-- name: RevokeAddressFromLog :exec +UPDATE address_log +SET revocation_sequence_id = $1 +WHERE (address, inbox_id, association_sequence_id) = ( + SELECT address, + inbox_id, + MAX(association_sequence_id) + FROM address_log AS a + WHERE a.address = $2 + AND a.inbox_id = $3 + GROUP BY address, + inbox_id + ) +` + +type RevokeAddressFromLogParams struct { + RevocationSequenceID sql.NullInt64 + Address string + InboxID string +} + +func (q *Queries) RevokeAddressFromLog(ctx context.Context, arg RevokeAddressFromLogParams) error { + _, err := q.db.ExecContext(ctx, revokeAddressFromLog, arg.RevocationSequenceID, arg.Address, arg.InboxID) + return err +} + const revokeInstallation = `-- name: RevokeInstallation :exec UPDATE installations SET revoked_at = $1 diff --git a/pkg/mls/store/store.go b/pkg/mls/store/store.go index db2300da..172579fd 100644 --- a/pkg/mls/store/store.go +++ b/pkg/mls/store/store.go @@ -13,6 +13,7 @@ import ( "github.com/uptrace/bun/migrate" migrations "github.com/xmtp/xmtp-node-go/pkg/migrations/mls" queries "github.com/xmtp/xmtp-node-go/pkg/mls/store/queries" + "github.com/xmtp/xmtp-node-go/pkg/mlsvalidate" identity "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1" "github.com/xmtp/xmtp-node-go/pkg/proto/identity/associations" mlsv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1" @@ -30,7 +31,7 @@ type Store struct { } type IdentityStore interface { - PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest) (*identity.PublishIdentityUpdateResponse, error) + PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest, validationService mlsvalidate.MLSValidationService) (*identity.PublishIdentityUpdateResponse, error) GetInboxLogs(ctx context.Context, req *identity.GetIdentityUpdatesRequest) (*identity.GetIdentityUpdatesResponse, error) GetInboxIds(ctx context.Context, req *identity.GetInboxIdsRequest) (*identity.GetInboxIdsResponse, error) } @@ -99,7 +100,7 @@ func (s *Store) GetInboxIds(ctx context.Context, req *identity.GetInboxIdsReques }, nil } -func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest) (*identity.PublishIdentityUpdateResponse, error) { +func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest, validationService mlsvalidate.MLSValidationService) (*identity.PublishIdentityUpdateResponse, error) { new_update := req.GetIdentityUpdate() if new_update == nil { return nil, errors.New("IdentityUpdate is required") @@ -125,23 +126,57 @@ func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.Publish } _ = append(updates, new_update) - // TODO: Validate the updates, and abort transaction if failed + state, err := validationService.GetAssociationState(ctx, updates, []*associations.IdentityUpdate{new_update}) + if err != nil { + return err + } + s.log.Info("Got association state", zap.Any("state", state)) protoBytes, err := proto.Marshal(new_update) if err != nil { return err } - _, err = txQueries.InsertInboxLog(ctx, queries.InsertInboxLogParams{ + sequence_id, err := txQueries.InsertInboxLog(ctx, queries.InsertInboxLogParams{ InboxID: new_update.GetInboxId(), ServerTimestampNs: nowNs(), IdentityUpdateProto: protoBytes, }) + s.log.Info("Inserted inbox log", zap.Any("sequence_id", sequence_id)) + if err != nil { return err } - // TODO: Insert or update the address_log table using sequence_id + + for _, new_member := range state.StateDiff.NewMembers { + s.log.Info("New member", zap.Any("member", new_member)) + if address, ok := new_member.Kind.(*associations.MemberIdentifier_Address); ok { + _, err = txQueries.InsertAddressLog(ctx, queries.InsertAddressLogParams{ + Address: address.Address, + InboxID: state.AssociationState.InboxId, + AssociationSequenceID: sql.NullInt64{Valid: true, Int64: sequence_id}, + RevocationSequenceID: sql.NullInt64{Valid: false}, + }) + if err != nil { + return err + } + } + } + + for _, removed_member := range state.StateDiff.RemovedMembers { + s.log.Info("New member", zap.Any("member", removed_member)) + if address, ok := removed_member.Kind.(*associations.MemberIdentifier_Address); ok { + err = txQueries.RevokeAddressFromLog(ctx, queries.RevokeAddressFromLogParams{ + Address: address.Address, + InboxID: state.AssociationState.InboxId, + RevocationSequenceID: sql.NullInt64{Valid: true, Int64: sequence_id}, + }) + if err != nil { + return err + } + } + } return nil }); err != nil { diff --git a/pkg/mls/store/store_test.go b/pkg/mls/store/store_test.go index 0f74996e..112776b9 100644 --- a/pkg/mls/store/store_test.go +++ b/pkg/mls/store/store_test.go @@ -2,11 +2,13 @@ package store import ( "context" + "database/sql" "sort" "testing" "time" "github.com/stretchr/testify/require" + queries "github.com/xmtp/xmtp-node-go/pkg/mls/store/queries" identity "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1" mlsv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1" test "github.com/xmtp/xmtp-node-go/pkg/testing" @@ -27,38 +29,18 @@ func NewTestStore(t *testing.T) (*Store, func()) { return store, dbCleanup } -func InsertAddressLog(store *Store, address string, inboxId string, associationSequenceId *uint64, revocationSequenceId *uint64) error { - - entry := AddressLogEntry{ - Address: address, - InboxId: inboxId, - AssociationSequenceId: associationSequenceId, - RevocationSequenceId: nil, - } - ctx := context.Background() - - _, err := store.db.NewInsert(). - Model(&entry). - Exec(ctx) - - return err -} - func TestInboxIds(t *testing.T) { store, cleanup := NewTestStore(t) defer cleanup() + ctx := context.Background() - seq, rev := uint64(1), uint64(5) - err := InsertAddressLog(store, "address", "inbox1", &seq, &rev) + _, err := store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "inbox1", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 1}, RevocationSequenceID: sql.NullInt64{Valid: false}}) require.NoError(t, err) - seq, rev = uint64(2), uint64(8) - err = InsertAddressLog(store, "address", "inbox1", &seq, &rev) + _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "inbox1", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 2}, RevocationSequenceID: sql.NullInt64{Valid: false}}) require.NoError(t, err) - seq, rev = uint64(3), uint64(9) - err = InsertAddressLog(store, "address", "inbox1", &seq, &rev) + _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "inbox1", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 3}, RevocationSequenceID: sql.NullInt64{Valid: false}}) require.NoError(t, err) - seq, rev = uint64(4), uint64(1) - err = InsertAddressLog(store, "address", "correct", &seq, &rev) + _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "correct", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 4}, RevocationSequenceID: sql.NullInt64{Valid: false}}) require.NoError(t, err) reqs := make([]*identity.GetInboxIdsRequest_Request, 0) @@ -69,12 +51,10 @@ func TestInboxIds(t *testing.T) { Requests: reqs, } resp, _ := store.GetInboxIds(context.Background(), req) - t.Log(resp) require.Equal(t, "correct", *resp.Responses[0].InboxId) - seq = uint64(5) - err = InsertAddressLog(store, "address", "correct_inbox2", &seq, nil) + _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "correct_inbox2", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 5}, RevocationSequenceID: sql.NullInt64{Valid: false}}) require.NoError(t, err) resp, _ = store.GetInboxIds(context.Background(), req) require.Equal(t, "correct_inbox2", *resp.Responses[0].InboxId) @@ -83,8 +63,7 @@ func TestInboxIds(t *testing.T) { req = &identity.GetInboxIdsRequest{ Requests: reqs, } - seq, rev = uint64(8), uint64(2) - err = InsertAddressLog(store, "address2", "inbox2", &seq, &rev) + _, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address2", InboxID: "inbox2", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 8}, RevocationSequenceID: sql.NullInt64{Valid: false}}) require.NoError(t, err) resp, _ = store.GetInboxIds(context.Background(), req) require.Equal(t, "inbox2", *resp.Responses[1].InboxId)