Skip to content

Commit

Permalink
Fix transaction isolation issues (#393)
Browse files Browse the repository at this point in the history
* Deal with transaction serialization issues

* Revert schema change

* Use pg_advisory_xact_lock

* Check return value of lock

* Add random sleep

* Remove unused line

* Shorten array
  • Loading branch information
neekolas authored May 29, 2024
1 parent e8c5c1b commit b0064d6
Show file tree
Hide file tree
Showing 13 changed files with 618 additions and 130 deletions.
1 change: 1 addition & 0 deletions dev/generate
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ set -e
go generate ./...

mockgen -package api github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1 MlsApi_SubscribeGroupMessagesServer,MlsApi_SubscribeWelcomeMessagesServer > pkg/mls/api/v1/mock.gen.go
mockgen -package mocks -source ./pkg/mlsvalidate/service.go MLSValidationService > pkg/mlsvalidate/mocks/mock.gen.go
rm -rf pkg/proto/**/*.pb.go pkg/proto/**/*.pb.gw.go pkg/proto/**/*.swagger.json
if ! buf generate https://github.com/xmtp/proto.git#branch=main,subdir=proto; then
echo "Failed to generate protobuf definitions"
Expand Down
4 changes: 2 additions & 2 deletions pkg/authn/authn.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions pkg/mls/store/queries.sql
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
-- name: LockInboxLog :exec
SELECT
pg_advisory_xact_lock(hashtext(@inbox_id));

-- name: GetAllInboxLogs :many
SELECT
*
Expand All @@ -6,8 +10,7 @@ FROM
WHERE
inbox_id = $1
ORDER BY
sequence_id ASC
FOR UPDATE;
sequence_id ASC;

-- name: GetInboxLogFiltered :many
SELECT
Expand Down
2 changes: 1 addition & 1 deletion pkg/mls/store/queries/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/mls/store/queries/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 11 additions & 2 deletions pkg/mls/store/queries/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 23 additions & 15 deletions pkg/mls/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
identity "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1"
"github.com/xmtp/xmtp-node-go/pkg/proto/identity/associations"
mlsv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1"
"github.com/xmtp/xmtp-node-go/pkg/utils"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
)
Expand Down Expand Up @@ -102,13 +103,21 @@ func (s *Store) GetInboxIds(ctx context.Context, req *identity.GetInboxIdsReques
}

func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest, validationService mlsvalidate.MLSValidationService) (*identity.PublishIdentityUpdateResponse, error) {
new_update := req.GetIdentityUpdate()
if new_update == nil {
newUpdate := req.GetIdentityUpdate()
if newUpdate == nil {
return nil, errors.New("IdentityUpdate is required")
}

if err := s.RunInSerializableTx(ctx, 3, func(ctx context.Context, txQueries *queries.Queries) error {
inboxLogEntries, err := txQueries.GetAllInboxLogs(ctx, new_update.GetInboxId())
if err := s.RunInRepeatableReadTx(ctx, 3, func(ctx context.Context, txQueries *queries.Queries) error {
inboxId := newUpdate.GetInboxId()
// We use a pg_advisory_lock to lock the inbox_id instead of SELECT FOR UPDATE
// This allows the lock to be enforced even when there are no existing `inbox_log`s
if err := txQueries.LockInboxLog(ctx, inboxId); err != nil {
return err
}

log := s.log.With(zap.String("inbox_id", inboxId))
inboxLogEntries, err := txQueries.GetAllInboxLogs(ctx, inboxId)
if err != nil {
return err
}
Expand All @@ -117,41 +126,39 @@ func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.Publish
return errors.New("inbox log is full")
}

updates := make([]*associations.IdentityUpdate, 0, len(inboxLogEntries)+1)
updates := make([]*associations.IdentityUpdate, 0, len(inboxLogEntries))
for _, log := range inboxLogEntries {
identityUpdate := &associations.IdentityUpdate{}
if err := proto.Unmarshal(log.IdentityUpdateProto, identityUpdate); err != nil {
return err
}
updates = append(updates, identityUpdate)
}
_ = append(updates, new_update)

state, err := validationService.GetAssociationState(ctx, updates, []*associations.IdentityUpdate{new_update})
state, err := validationService.GetAssociationState(ctx, updates, []*associations.IdentityUpdate{newUpdate})
if err != nil {
return err
}

s.log.Info("Got association state", zap.Any("state", state))
protoBytes, err := proto.Marshal(new_update)
protoBytes, err := proto.Marshal(newUpdate)
if err != nil {
return err
}

sequence_id, err := txQueries.InsertInboxLog(ctx, queries.InsertInboxLogParams{
InboxID: new_update.GetInboxId(),
InboxID: inboxId,
ServerTimestampNs: nowNs(),
IdentityUpdateProto: protoBytes,
})

s.log.Info("Inserted inbox log", zap.Any("sequence_id", sequence_id))
log.Info("Inserted inbox log", zap.Any("sequence_id", sequence_id))

if err != nil {
return err
}

for _, new_member := range state.StateDiff.NewMembers {
s.log.Info("New member", zap.Any("member", new_member))
log.Info("New member", zap.Any("member", new_member))
if address, ok := new_member.Kind.(*associations.MemberIdentifier_Address); ok {
_, err = txQueries.InsertAddressLog(ctx, queries.InsertAddressLogParams{
Address: address.Address,
Expand All @@ -166,7 +173,7 @@ func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.Publish
}

for _, removed_member := range state.StateDiff.RemovedMembers {
s.log.Info("New member", zap.Any("member", removed_member))
log.Info("Removed member", zap.Any("member", removed_member))
if address, ok := removed_member.Kind.(*associations.MemberIdentifier_Address); ok {
err = txQueries.RevokeAddressFromLog(ctx, queries.RevokeAddressFromLogParams{
Address: address.Address,
Expand Down Expand Up @@ -607,18 +614,19 @@ func (s *Store) RunInTx(
return tx.Commit()
}

func (s *Store) RunInSerializableTx(ctx context.Context, numRetries int, fn func(ctx context.Context, txQueries *queries.Queries) error) error {
func (s *Store) RunInRepeatableReadTx(ctx context.Context, numRetries int, fn func(ctx context.Context, txQueries *queries.Queries) error) error {
var err error
for i := 0; i < numRetries; i++ {
select {
case <-ctx.Done():
return ctx.Err()
default:
err = s.RunInTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable}, fn)
err = s.RunInTx(ctx, &sql.TxOptions{Isolation: sql.LevelRepeatableRead}, fn)
if err == nil {
return nil
}
s.log.Warn("Error in serializable tx", zap.Error(err))
utils.RandomSleep(20)
}
}
return err
Expand Down
61 changes: 61 additions & 0 deletions pkg/mls/store/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,22 @@ package store
import (
"context"
"database/sql"
"errors"
"fmt"
"sort"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
queries "github.com/xmtp/xmtp-node-go/pkg/mls/store/queries"
"github.com/xmtp/xmtp-node-go/pkg/mlsvalidate"
"github.com/xmtp/xmtp-node-go/pkg/mlsvalidate/mocks"
identity "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1"
"github.com/xmtp/xmtp-node-go/pkg/proto/identity/associations"
mlsv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1"
test "github.com/xmtp/xmtp-node-go/pkg/testing"
"go.uber.org/mock/gomock"
)

func NewTestStore(t *testing.T) (*Store, func()) {
Expand All @@ -29,6 +36,60 @@ func NewTestStore(t *testing.T) (*Store, func()) {
return store, dbCleanup
}

func TestPublishIdentityUpdateParallel(t *testing.T) {
store, cleanup := NewTestStore(t)
defer cleanup()
ctx := context.Background()

// Create a mapping of inboxes to addresses
inboxes := make(map[string]string)
for i := 0; i < 50; i++ {
inboxes[fmt.Sprintf("inbox_%d", i)] = fmt.Sprintf("address_%d", i)
}

mockController := gomock.NewController(t)
mockMlsValidation := mocks.NewMockMLSValidationService(mockController)

// For each inbox_id in the map, return an AssociationStateDiff that adds the corresponding address
mockMlsValidation.EXPECT().GetAssociationState(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ any, _ any, updates []*associations.IdentityUpdate) (*mlsvalidate.AssociationStateResult, error) {
inboxId := updates[0].InboxId
address, ok := inboxes[inboxId]

if !ok {
return nil, errors.New("inbox id not found")
}

return &mlsvalidate.AssociationStateResult{
AssociationState: &associations.AssociationState{
InboxId: inboxId,
},
StateDiff: &associations.AssociationStateDiff{
NewMembers: []*associations.MemberIdentifier{{
Kind: &associations.MemberIdentifier_Address{
Address: address,
},
}},
},
}, nil
}).AnyTimes()

var wg sync.WaitGroup
for inboxId := range inboxes {
inboxId := inboxId
wg.Add(1)
go func() {
defer wg.Done()
_, err := store.PublishIdentityUpdate(ctx, &identity.PublishIdentityUpdateRequest{
IdentityUpdate: &associations.IdentityUpdate{
InboxId: inboxId,
},
}, mockMlsValidation)
require.NoError(t, err)
}()
}
wg.Wait()
}

func TestInboxIds(t *testing.T) {
store, cleanup := NewTestStore(t)
defer cleanup()
Expand Down
Loading

0 comments on commit b0064d6

Please sign in to comment.