From 650bf2463d38f682a815aab12c54511443f59b92 Mon Sep 17 00:00:00 2001 From: Richard Hua Date: Tue, 27 Aug 2024 10:52:37 -0700 Subject: [PATCH] Add publish worker (#115) - Adds a publish worker that performs in-order insertion into the `gateway_envelopes` table - Adds basic validation of the client envelope on publish - Stores the topic on the `staged_originator_envelopes` table - we extract this during the API call, so that the publish worker doesn't need to do any additional unmarshaling or validation. Would particularly love feedback on the error handling in the worker, and if there are any test cases I should add to `service_test.go`. --- pkg/api/publishWorker.go | 147 ++++++++++++++++++++++++ pkg/api/service.go | 84 +++++++++++--- pkg/api/service_test.go | 102 ++++++++++++++-- pkg/db/queries.sql | 2 +- pkg/db/queries/models.go | 1 + pkg/db/queries/queries.sql.go | 29 +++-- pkg/migrations/00001_init-schema.up.sql | 18 +-- pkg/registrant/registrant.go | 4 + 8 files changed, 346 insertions(+), 41 deletions(-) create mode 100644 pkg/api/publishWorker.go diff --git a/pkg/api/publishWorker.go b/pkg/api/publishWorker.go new file mode 100644 index 00000000..78adda84 --- /dev/null +++ b/pkg/api/publishWorker.go @@ -0,0 +1,147 @@ +package api + +import ( + "context" + "database/sql" + "time" + + "github.com/xmtp/xmtpd/pkg/db" + "github.com/xmtp/xmtpd/pkg/db/queries" + "github.com/xmtp/xmtpd/pkg/registrant" + "go.uber.org/zap" + "google.golang.org/protobuf/proto" +) + +type PublishWorker struct { + ctx context.Context + log *zap.Logger + listener <-chan []queries.StagedOriginatorEnvelope + notifier chan<- bool + registrant *registrant.Registrant + store *sql.DB + subscription db.DBSubscription[queries.StagedOriginatorEnvelope] +} + +func StartPublishWorker( + ctx context.Context, + log *zap.Logger, + reg *registrant.Registrant, + store *sql.DB, +) (*PublishWorker, error) { + q := queries.New(store) + query := func(ctx context.Context, lastSeenID int64, numRows int32) 
([]queries.StagedOriginatorEnvelope, int64, error) { + results, err := q.SelectStagedOriginatorEnvelopes( + ctx, + queries.SelectStagedOriginatorEnvelopesParams{ + LastSeenID: lastSeenID, + NumRows: numRows, + }, + ) + if err != nil { + return nil, 0, err + } + if len(results) > 0 { + lastSeenID = results[len(results)-1].ID + } + return results, lastSeenID, nil + } + notifier := make(chan bool, 1) + subscription := db.NewDBSubscription( + ctx, + log, + query, + 0, // lastSeenID + db.PollingOptions{Interval: time.Second, Notifier: notifier, NumRows: 100}, + ) + listener, err := subscription.Start() + if err != nil { + return nil, err + } + + worker := &PublishWorker{ + ctx: ctx, + log: log, + notifier: notifier, + subscription: *subscription, + listener: listener, + registrant: reg, + store: store, + } + go worker.start() + + return worker, nil +} + +func (p *PublishWorker) NotifyStagedPublish() { + select { + case p.notifier <- true: + default: + } +} + +func (p *PublishWorker) start() { + for { + select { + case <-p.ctx.Done(): + return + case new_batch := <-p.listener: + for _, stagedEnv := range new_batch { + for !p.publishStagedEnvelope(stagedEnv) { + // Infinite retry on failure to publish; we cannot + // continue to the next envelope until this one is processed + time.Sleep(time.Second) + } + } + } + } +} + +func (p *PublishWorker) publishStagedEnvelope(stagedEnv queries.StagedOriginatorEnvelope) bool { + logger := p.log.With(zap.Int64("sequenceID", stagedEnv.ID)) + originatorEnv, err := p.registrant.SignStagedEnvelope(stagedEnv) + if err != nil { + logger.Error( + "Failed to sign staged envelope", + zap.Error(err), + ) + return false + } + originatorBytes, err := proto.Marshal(originatorEnv) + if err != nil { + logger.Error("Failed to marshal originator envelope", zap.Error(err)) + return false + } + + q := queries.New(p.store) + + // On unique constraint conflicts, no error is thrown, but numRows is 0 + inserted, err := q.InsertGatewayEnvelope( + p.ctx, + 
queries.InsertGatewayEnvelopeParams{ + OriginatorID: int32(p.registrant.NodeID()), + OriginatorSequenceID: stagedEnv.ID, + Topic: stagedEnv.Topic, + OriginatorEnvelope: originatorBytes, + }, + ) + if err != nil { + logger.Error("Failed to insert gateway envelope", zap.Error(err)) + return false + } else if inserted == 0 { + // Envelope was already inserted by another worker + logger.Debug("Envelope already inserted") + } + + // Try to delete the row regardless of if the gateway envelope was inserted elsewhere + deleted, err := q.DeleteStagedOriginatorEnvelope(context.Background(), stagedEnv.ID) + if err != nil { + logger.Error("Failed to delete staged envelope", zap.Error(err)) + // Envelope is already inserted, so it is safe to continue + return true + } else if deleted == 0 { + // Envelope was already deleted by another worker + logger.Debug("Envelope already deleted") + } + + return true +} diff --git a/pkg/api/service.go b/pkg/api/service.go index 32367558..7cf02ef6 100644 --- a/pkg/api/service.go +++ b/pkg/api/service.go @@ -20,16 +20,27 @@ type Service struct { ctx context.Context log *zap.Logger registrant *registrant.Registrant - queries *queries.Queries + store *sql.DB + worker *PublishWorker } func NewReplicationApiService( ctx context.Context, log *zap.Logger, registrant *registrant.Registrant, - writerDB *sql.DB, + store *sql.DB, ) (*Service, error) { - return &Service{ctx: ctx, log: log, registrant: registrant, queries: queries.New(writerDB)}, nil + worker, err := StartPublishWorker(ctx, log, registrant, store) + if err != nil { + return nil, err + } + return &Service{ + ctx: ctx, + log: log, + registrant: registrant, + store: store, + worker: worker, + }, nil } func (s *Service) Close() { @@ -54,27 +65,32 @@ func (s *Service) PublishEnvelope( ctx context.Context, req *message_api.PublishEnvelopeRequest, ) (*message_api.PublishEnvelopeResponse, error) { - payerEnv := req.GetPayerEnvelope() - clientBytes := payerEnv.GetUnsignedClientEnvelope() - sig := 
payerEnv.GetPayerSignature() - if (clientBytes == nil) || (sig == nil) { - return nil, status.Errorf(codes.InvalidArgument, "missing envelope or signature") + clientEnv, err := s.validatePayerInfo(req.GetPayerEnvelope()) + if err != nil { + return nil, err } - // TODO(rich): Verify payer signature - // TODO(rich): Verify all originators have synced past `last_originator_sids` - // TODO(rich): Check that the blockchain sequence ID is equal to the latest on the group - // TODO(rich): Perform any payload-specific validation (e.g. identity updates) + + topic, err := s.validateClientInfo(clientEnv) + if err != nil { + return nil, err + } + // TODO(rich): If it is a commit, publish it to blockchain instead - payerBytes, err := proto.Marshal(payerEnv) + payerBytes, err := proto.Marshal(req.GetPayerEnvelope()) if err != nil { return nil, status.Errorf(codes.Internal, "could not marshal envelope: %v", err) } - stagedEnv, err := s.queries.InsertStagedOriginatorEnvelope(ctx, payerBytes) + stagedEnv, err := queries.New(s.store). 
+ InsertStagedOriginatorEnvelope(ctx, queries.InsertStagedOriginatorEnvelopeParams{ + Topic: topic, + PayerEnvelope: payerBytes, + }) if err != nil { return nil, status.Errorf(codes.Internal, "could not insert staged envelope: %v", err) } + s.worker.NotifyStagedPublish() originatorEnv, err := s.registrant.SignStagedEnvelope(stagedEnv) if err != nil { @@ -83,3 +99,43 @@ func (s *Service) PublishEnvelope( return &message_api.PublishEnvelopeResponse{OriginatorEnvelope: originatorEnv}, nil } + +func (s *Service) validatePayerInfo( + payerEnv *message_api.PayerEnvelope, +) (*message_api.ClientEnvelope, error) { + clientBytes := payerEnv.GetUnsignedClientEnvelope() + sig := payerEnv.GetPayerSignature() + if (clientBytes == nil) || (sig == nil) { + return nil, status.Errorf(codes.InvalidArgument, "missing envelope or signature") + } + // TODO(rich): Verify payer signature + + clientEnv := &message_api.ClientEnvelope{} + err := proto.Unmarshal(clientBytes, clientEnv) + if err != nil { + return nil, status.Errorf( + codes.InvalidArgument, + "could not unmarshal client envelope: %v", + err, + ) + } + + return clientEnv, nil +} + +func (s *Service) validateClientInfo(clientEnv *message_api.ClientEnvelope) ([]byte, error) { + if clientEnv.GetAad().GetTargetOriginator() != uint32(s.registrant.NodeID()) { + return nil, status.Errorf(codes.InvalidArgument, "invalid target originator") + } + + topic := clientEnv.GetAad().GetTargetTopic() + if len(topic) == 0 { + return nil, status.Errorf(codes.InvalidArgument, "missing target topic") + } + + // TODO(rich): Verify all originators have synced past `last_originator_sids` + // TODO(rich): Check that the blockchain sequence ID is equal to the latest on the group + // TODO(rich): Perform any payload-specific validation (e.g. 
identity updates) + + return topic, nil +} diff --git a/pkg/api/service_test.go b/pkg/api/service_test.go index e33fa4b8..c91ba4cc 100644 --- a/pkg/api/service_test.go +++ b/pkg/api/service_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "testing" + "time" "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/require" @@ -40,17 +41,41 @@ func newTestService(t *testing.T) (*Service, *sql.DB, func()) { } } +func createClientEnvelope() *message_api.ClientEnvelope { + return &message_api.ClientEnvelope{ + Payload: nil, + Aad: &message_api.AuthenticatedData{ + TargetOriginator: 1, + TargetTopic: []byte{0x5}, + LastOriginatorSids: []uint64{}, + }, + } +} + +func createPayerEnvelope( + t *testing.T, + clientEnv ...*message_api.ClientEnvelope, +) *message_api.PayerEnvelope { + if len(clientEnv) == 0 { + clientEnv = append(clientEnv, createClientEnvelope()) + } + clientEnvBytes, err := proto.Marshal(clientEnv[0]) + require.NoError(t, err) + + return &message_api.PayerEnvelope{ + UnsignedClientEnvelope: clientEnvBytes, + PayerSignature: &associations.RecoverableEcdsaSignature{}, + } +} + func TestSimplePublish(t *testing.T) { - svc, _, cleanup := newTestService(t) + svc, db, cleanup := newTestService(t) defer cleanup() resp, err := svc.PublishEnvelope( context.Background(), &message_api.PublishEnvelopeRequest{ - PayerEnvelope: &message_api.PayerEnvelope{ - UnsignedClientEnvelope: []byte{0x5}, - PayerSignature: &associations.RecoverableEcdsaSignature{}, - }, + PayerEnvelope: createPayerEnvelope(t), }, ) require.NoError(t, err) @@ -61,7 +86,70 @@ func TestSimplePublish(t *testing.T) { t, proto.Unmarshal(resp.GetOriginatorEnvelope().GetUnsignedOriginatorEnvelope(), unsignedEnv), ) - require.Equal(t, uint8(0x5), unsignedEnv.GetPayerEnvelope().GetUnsignedClientEnvelope()[0]) + clientEnv := &message_api.ClientEnvelope{} + require.NoError( + t, + proto.Unmarshal(unsignedEnv.GetPayerEnvelope().GetUnsignedClientEnvelope(), clientEnv), + ) + require.Equal(t, 
uint8(0x5), clientEnv.Aad.GetTargetTopic()[0]) - // TODO(rich) Test that the published envelope is retrievable via the query API + // Check that the envelope was published to the database after a delay + require.Eventually(t, func() bool { + envs, err := queries.New(db). + SelectGatewayEnvelopes(context.Background(), queries.SelectGatewayEnvelopesParams{}) + require.NoError(t, err) + + if len(envs) != 1 { + return false + } + + originatorEnv := &message_api.OriginatorEnvelope{} + require.NoError(t, proto.Unmarshal(envs[0].OriginatorEnvelope, originatorEnv)) + return proto.Equal(originatorEnv, resp.GetOriginatorEnvelope()) + }, 500*time.Millisecond, 50*time.Millisecond) +} + +func TestUnmarshalError(t *testing.T) { + svc, _, cleanup := newTestService(t) + defer cleanup() + + envelope := createPayerEnvelope(t) + envelope.UnsignedClientEnvelope = []byte("invalidbytes") + _, err := svc.PublishEnvelope( + context.Background(), + &message_api.PublishEnvelopeRequest{ + PayerEnvelope: envelope, + }, + ) + require.ErrorContains(t, err, "unmarshal") +} + +func TestMismatchingOriginator(t *testing.T) { + svc, _, cleanup := newTestService(t) + defer cleanup() + + clientEnv := createClientEnvelope() + clientEnv.Aad.TargetOriginator = 2 + _, err := svc.PublishEnvelope( + context.Background(), + &message_api.PublishEnvelopeRequest{ + PayerEnvelope: createPayerEnvelope(t, clientEnv), + }, + ) + require.ErrorContains(t, err, "originator") +} + +func TestMissingTopic(t *testing.T) { + svc, _, cleanup := newTestService(t) + defer cleanup() + + clientEnv := createClientEnvelope() + clientEnv.Aad.TargetTopic = nil + _, err := svc.PublishEnvelope( + context.Background(), + &message_api.PublishEnvelopeRequest{ + PayerEnvelope: createPayerEnvelope(t, clientEnv), + }, + ) + require.ErrorContains(t, err, "topic") } diff --git a/pkg/db/queries.sql b/pkg/db/queries.sql index 45f18a05..cc380f7f 100644 --- a/pkg/db/queries.sql +++ b/pkg/db/queries.sql @@ -35,7 +35,7 @@ LIMIT 
sqlc.narg('row_limit')::INT; SELECT * FROM - insert_staged_originator_envelope(@payer_envelope); + insert_staged_originator_envelope(@topic, @payer_envelope); -- name: SelectStagedOriginatorEnvelopes :many SELECT diff --git a/pkg/db/queries/models.go b/pkg/db/queries/models.go index c6c3d4d6..36178292 100644 --- a/pkg/db/queries/models.go +++ b/pkg/db/queries/models.go @@ -33,5 +33,6 @@ type NodeInfo struct { type StagedOriginatorEnvelope struct { ID int64 OriginatorTime time.Time + Topic []byte PayerEnvelope []byte } diff --git a/pkg/db/queries/queries.sql.go b/pkg/db/queries/queries.sql.go index 7244bb48..f68153fd 100644 --- a/pkg/db/queries/queries.sql.go +++ b/pkg/db/queries/queries.sql.go @@ -70,15 +70,25 @@ func (q *Queries) InsertNodeInfo(ctx context.Context, arg InsertNodeInfoParams) const insertStagedOriginatorEnvelope = `-- name: InsertStagedOriginatorEnvelope :one SELECT - id, originator_time, payer_envelope + id, originator_time, topic, payer_envelope FROM - insert_staged_originator_envelope($1) + insert_staged_originator_envelope($1, $2) ` -func (q *Queries) InsertStagedOriginatorEnvelope(ctx context.Context, payerEnvelope []byte) (StagedOriginatorEnvelope, error) { - row := q.db.QueryRowContext(ctx, insertStagedOriginatorEnvelope, payerEnvelope) +type InsertStagedOriginatorEnvelopeParams struct { + Topic []byte + PayerEnvelope []byte +} + +func (q *Queries) InsertStagedOriginatorEnvelope(ctx context.Context, arg InsertStagedOriginatorEnvelopeParams) (StagedOriginatorEnvelope, error) { + row := q.db.QueryRowContext(ctx, insertStagedOriginatorEnvelope, arg.Topic, arg.PayerEnvelope) var i StagedOriginatorEnvelope - err := row.Scan(&i.ID, &i.OriginatorTime, &i.PayerEnvelope) + err := row.Scan( + &i.ID, + &i.OriginatorTime, + &i.Topic, + &i.PayerEnvelope, + ) return i, err } @@ -159,7 +169,7 @@ func (q *Queries) SelectNodeInfo(ctx context.Context) (NodeInfo, error) { const selectStagedOriginatorEnvelopes = `-- name: SelectStagedOriginatorEnvelopes :many 
SELECT - id, originator_time, payer_envelope + id, originator_time, topic, payer_envelope FROM staged_originator_envelopes WHERE @@ -183,7 +193,12 @@ func (q *Queries) SelectStagedOriginatorEnvelopes(ctx context.Context, arg Selec var items []StagedOriginatorEnvelope for rows.Next() { var i StagedOriginatorEnvelope - if err := rows.Scan(&i.ID, &i.OriginatorTime, &i.PayerEnvelope); err != nil { + if err := rows.Scan( + &i.ID, + &i.OriginatorTime, + &i.Topic, + &i.PayerEnvelope, + ); err != nil { return nil, err } items = append(items, i) diff --git a/pkg/migrations/00001_init-schema.up.sql b/pkg/migrations/00001_init-schema.up.sql index b08e8d45..21cf35d0 100644 --- a/pkg/migrations/00001_init-schema.up.sql +++ b/pkg/migrations/00001_init-schema.up.sql @@ -41,30 +41,24 @@ END; $$ LANGUAGE plpgsql; --- Process for originating envelopes: --- 1. Perform any necessary validation --- 2. Insert into originated_envelopes --- 3. Singleton background task will continuously query (or subscribe to) --- staged_originated_envelopes, and for each envelope in order of ID: --- 2.1. Construct and sign OriginatorEnvelope proto --- 2.2. Atomically insert into all_envelopes and delete from originated_envelopes, --- ignoring unique index violations on originator_sid --- This preserves total ordering, while avoiding gaps in sequence ID's. +-- Newly published envelopes will be queued here first (and assigned an originator +-- sequence ID), before being inserted in-order into the gateway_envelopes table. 
CREATE TABLE staged_originator_envelopes( -- used to construct originator_sid id BIGSERIAL PRIMARY KEY, originator_time TIMESTAMP NOT NULL DEFAULT now(), + topic BYTEA NOT NULL, payer_envelope BYTEA NOT NULL ); -CREATE FUNCTION insert_staged_originator_envelope(payer_envelope BYTEA) +CREATE FUNCTION insert_staged_originator_envelope(topic BYTEA, payer_envelope BYTEA) RETURNS SETOF staged_originator_envelopes AS $$ BEGIN PERFORM pg_advisory_xact_lock(hashtext('staged_originator_envelopes_sequence')); - RETURN QUERY INSERT INTO staged_originator_envelopes(payer_envelope) - VALUES(payer_envelope) + RETURN QUERY INSERT INTO staged_originator_envelopes(topic, payer_envelope) + VALUES(topic, payer_envelope) ON CONFLICT DO NOTHING RETURNING diff --git a/pkg/registrant/registrant.go b/pkg/registrant/registrant.go index e642a7e1..38fdb458 100644 --- a/pkg/registrant/registrant.go +++ b/pkg/registrant/registrant.go @@ -60,6 +60,10 @@ func (r *Registrant) signKeccak256(data []byte) ([]byte, error) { return crypto.Sign(hash, r.privateKey) } +func (r *Registrant) NodeID() uint16 { + return r.record.NodeID +} + func (r *Registrant) SignStagedEnvelope( stagedEnv queries.StagedOriginatorEnvelope, ) (*message_api.OriginatorEnvelope, error) {