Skip to content

Commit

Permalink
Add InsertLog Query (#383)
Browse files Browse the repository at this point in the history
* Add insertlog query

* validation service

* insert log

* revocation for removed members

* lint

* remove unnecessary log

* change test to use query from sqlc

* remove comments

* fix tests
  • Loading branch information
insipx authored Apr 25, 2024
1 parent dc19304 commit 065a73c
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 43 deletions.
6 changes: 2 additions & 4 deletions pkg/identity/api/v1/identity_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
"github.com/xmtp/xmtp-node-go/pkg/mlsvalidate"
api "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

type Service struct {
Expand Down Expand Up @@ -78,7 +76,7 @@ Start transaction (SERIALIZABLE isolation level)
End transaction
*/
func (s *Service) PublishIdentityUpdate(ctx context.Context, req *api.PublishIdentityUpdateRequest) (*api.PublishIdentityUpdateResponse, error) {
return s.store.PublishIdentityUpdate(ctx, req)
return s.store.PublishIdentityUpdate(ctx, req, s.validationService)
}

func (s *Service) GetIdentityUpdates(ctx context.Context, req *api.GetIdentityUpdatesRequest) (*api.GetIdentityUpdatesResponse, error) {
Expand All @@ -97,5 +95,5 @@ func (s *Service) GetInboxIds(ctx context.Context, req *api.GetInboxIdsRequest)
for the address where revocation_sequence_id is lower or NULL
2. Return the value of the 'inbox_id' column
*/
return nil, status.Errorf(codes.Unimplemented, "unimplemented")
return s.store.GetInboxIds(ctx, req)
}
22 changes: 21 additions & 1 deletion pkg/identity/api/v1/identity_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,27 @@ type mockedMLSValidationService struct {
}

func (m *mockedMLSValidationService) GetAssociationState(ctx context.Context, oldUpdates []*associations.IdentityUpdate, newUpdates []*associations.IdentityUpdate) (*mlsvalidate.AssociationStateResult, error) {
return nil, nil

member_map := make([]*associations.MemberMap, 0)
member_map = append(member_map, &associations.MemberMap{
Key: &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "key_address"}},
Value: &associations.Member{
Identifier: &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "ident"}},
AddedByEntity: &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "added_by_entity"}},
},
})

new_members := make([]*associations.MemberIdentifier, 0)

new_members = append(new_members, &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "0x01"}})
new_members = append(new_members, &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "0x02"}})
new_members = append(new_members, &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "0x03"}})

out := mlsvalidate.AssociationStateResult{
AssociationState: &associations.AssociationState{InboxId: "test_inbox", Members: member_map, RecoveryAddress: "recovery", SeenSignatures: [][]byte{[]byte("seen"), []byte("sig")}},
StateDiff: &associations.AssociationStateDiff{NewMembers: new_members, RemovedMembers: nil},
}
return &out, nil
}

func (m *mockedMLSValidationService) ValidateKeyPackages(ctx context.Context, keyPackages [][]byte) ([]mlsvalidate.IdentityValidationResult, error) {
Expand Down
15 changes: 15 additions & 0 deletions pkg/mls/store/queries.sql
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ FROM address_log a
) b ON a.address = b.address
AND a.association_sequence_id = b.max_association_sequence_id;

-- name: InsertAddressLog :one
INSERT INTO address_log (address, inbox_id, association_sequence_id, revocation_sequence_id)
VALUES ($1, $2, $3, $4)
RETURNING *;

-- name: InsertInboxLog :one
INSERT INTO inbox_log (
inbox_id,
Expand All @@ -39,6 +44,16 @@ INSERT INTO inbox_log (
VALUES ($1, $2, $3)
RETURNING sequence_id;

-- name: RevokeAddressFromLog :exec
UPDATE address_log
SET revocation_sequence_id = $1
WHERE (address, inbox_id, association_sequence_id) = (
SELECT address, inbox_id, MAX(association_sequence_id)
FROM address_log AS a
WHERE a.address = $2 AND a.inbox_id = $3
GROUP BY address, inbox_id
);

-- name: CreateInstallation :exec
INSERT INTO installations (
id,
Expand Down
2 changes: 1 addition & 1 deletion pkg/mls/store/queries/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/mls/store/queries/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

54 changes: 53 additions & 1 deletion pkg/mls/store/queries/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 40 additions & 5 deletions pkg/mls/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/uptrace/bun/migrate"
migrations "github.com/xmtp/xmtp-node-go/pkg/migrations/mls"
queries "github.com/xmtp/xmtp-node-go/pkg/mls/store/queries"
"github.com/xmtp/xmtp-node-go/pkg/mlsvalidate"
identity "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1"
"github.com/xmtp/xmtp-node-go/pkg/proto/identity/associations"
mlsv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1"
Expand All @@ -30,7 +31,7 @@ type Store struct {
}

type IdentityStore interface {
PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest) (*identity.PublishIdentityUpdateResponse, error)
PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest, validationService mlsvalidate.MLSValidationService) (*identity.PublishIdentityUpdateResponse, error)
GetInboxLogs(ctx context.Context, req *identity.GetIdentityUpdatesRequest) (*identity.GetIdentityUpdatesResponse, error)
GetInboxIds(ctx context.Context, req *identity.GetInboxIdsRequest) (*identity.GetInboxIdsResponse, error)
}
Expand Down Expand Up @@ -99,7 +100,7 @@ func (s *Store) GetInboxIds(ctx context.Context, req *identity.GetInboxIdsReques
}, nil
}

func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest) (*identity.PublishIdentityUpdateResponse, error) {
func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest, validationService mlsvalidate.MLSValidationService) (*identity.PublishIdentityUpdateResponse, error) {
new_update := req.GetIdentityUpdate()
if new_update == nil {
return nil, errors.New("IdentityUpdate is required")
Expand All @@ -125,23 +126,57 @@ func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.Publish
}
_ = append(updates, new_update)

// TODO: Validate the updates, and abort transaction if failed
state, err := validationService.GetAssociationState(ctx, updates, []*associations.IdentityUpdate{new_update})
if err != nil {
return err
}

s.log.Info("Got association state", zap.Any("state", state))
protoBytes, err := proto.Marshal(new_update)
if err != nil {
return err
}

_, err = txQueries.InsertInboxLog(ctx, queries.InsertInboxLogParams{
sequence_id, err := txQueries.InsertInboxLog(ctx, queries.InsertInboxLogParams{
InboxID: new_update.GetInboxId(),
ServerTimestampNs: nowNs(),
IdentityUpdateProto: protoBytes,
})

s.log.Info("Inserted inbox log", zap.Any("sequence_id", sequence_id))

if err != nil {
return err
}
// TODO: Insert or update the address_log table using sequence_id

for _, new_member := range state.StateDiff.NewMembers {
s.log.Info("New member", zap.Any("member", new_member))
if address, ok := new_member.Kind.(*associations.MemberIdentifier_Address); ok {
_, err = txQueries.InsertAddressLog(ctx, queries.InsertAddressLogParams{
Address: address.Address,
InboxID: state.AssociationState.InboxId,
AssociationSequenceID: sql.NullInt64{Valid: true, Int64: sequence_id},
RevocationSequenceID: sql.NullInt64{Valid: false},
})
if err != nil {
return err
}
}
}

for _, removed_member := range state.StateDiff.RemovedMembers {
s.log.Info("New member", zap.Any("member", removed_member))
if address, ok := removed_member.Kind.(*associations.MemberIdentifier_Address); ok {
err = txQueries.RevokeAddressFromLog(ctx, queries.RevokeAddressFromLogParams{
Address: address.Address,
InboxID: state.AssociationState.InboxId,
RevocationSequenceID: sql.NullInt64{Valid: true, Int64: sequence_id},
})
if err != nil {
return err
}
}
}

return nil
}); err != nil {
Expand Down
39 changes: 9 additions & 30 deletions pkg/mls/store/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package store

import (
"context"
"database/sql"
"sort"
"testing"
"time"

"github.com/stretchr/testify/require"
queries "github.com/xmtp/xmtp-node-go/pkg/mls/store/queries"
identity "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1"
mlsv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1"
test "github.com/xmtp/xmtp-node-go/pkg/testing"
Expand All @@ -27,38 +29,18 @@ func NewTestStore(t *testing.T) (*Store, func()) {
return store, dbCleanup
}

func InsertAddressLog(store *Store, address string, inboxId string, associationSequenceId *uint64, revocationSequenceId *uint64) error {

entry := AddressLogEntry{
Address: address,
InboxId: inboxId,
AssociationSequenceId: associationSequenceId,
RevocationSequenceId: nil,
}
ctx := context.Background()

_, err := store.db.NewInsert().
Model(&entry).
Exec(ctx)

return err
}

func TestInboxIds(t *testing.T) {
store, cleanup := NewTestStore(t)
defer cleanup()
ctx := context.Background()

seq, rev := uint64(1), uint64(5)
err := InsertAddressLog(store, "address", "inbox1", &seq, &rev)
_, err := store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "inbox1", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 1}, RevocationSequenceID: sql.NullInt64{Valid: false}})
require.NoError(t, err)
seq, rev = uint64(2), uint64(8)
err = InsertAddressLog(store, "address", "inbox1", &seq, &rev)
_, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "inbox1", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 2}, RevocationSequenceID: sql.NullInt64{Valid: false}})
require.NoError(t, err)
seq, rev = uint64(3), uint64(9)
err = InsertAddressLog(store, "address", "inbox1", &seq, &rev)
_, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "inbox1", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 3}, RevocationSequenceID: sql.NullInt64{Valid: false}})
require.NoError(t, err)
seq, rev = uint64(4), uint64(1)
err = InsertAddressLog(store, "address", "correct", &seq, &rev)
_, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "correct", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 4}, RevocationSequenceID: sql.NullInt64{Valid: false}})
require.NoError(t, err)

reqs := make([]*identity.GetInboxIdsRequest_Request, 0)
Expand All @@ -69,12 +51,10 @@ func TestInboxIds(t *testing.T) {
Requests: reqs,
}
resp, _ := store.GetInboxIds(context.Background(), req)
t.Log(resp)

require.Equal(t, "correct", *resp.Responses[0].InboxId)

seq = uint64(5)
err = InsertAddressLog(store, "address", "correct_inbox2", &seq, nil)
_, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address", InboxID: "correct_inbox2", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 5}, RevocationSequenceID: sql.NullInt64{Valid: false}})
require.NoError(t, err)
resp, _ = store.GetInboxIds(context.Background(), req)
require.Equal(t, "correct_inbox2", *resp.Responses[0].InboxId)
Expand All @@ -83,8 +63,7 @@ func TestInboxIds(t *testing.T) {
req = &identity.GetInboxIdsRequest{
Requests: reqs,
}
seq, rev = uint64(8), uint64(2)
err = InsertAddressLog(store, "address2", "inbox2", &seq, &rev)
_, err = store.queries.InsertAddressLog(ctx, queries.InsertAddressLogParams{Address: "address2", InboxID: "inbox2", AssociationSequenceID: sql.NullInt64{Valid: true, Int64: 8}, RevocationSequenceID: sql.NullInt64{Valid: false}})
require.NoError(t, err)
resp, _ = store.GetInboxIds(context.Background(), req)
require.Equal(t, "inbox2", *resp.Responses[1].InboxId)
Expand Down

0 comments on commit 065a73c

Please sign in to comment.