diff --git a/pkg/db/queries.sql b/pkg/db/queries.sql index 250e7907..b6431cd6 100644 --- a/pkg/db/queries.sql +++ b/pkg/db/queries.sql @@ -46,13 +46,15 @@ DELETE FROM staged_originator_envelopes WHERE id = @id; -- name: SelectVectorClock :many -SELECT +SELECT DISTINCT ON (originator_node_id) originator_node_id, - max(originator_sequence_id)::BIGINT AS originator_sequence_id + originator_sequence_id, + originator_envelope FROM gateway_envelopes -GROUP BY - originator_node_id; +ORDER BY + originator_node_id, + originator_sequence_id DESC; -- name: GetAddressLogs :many SELECT diff --git a/pkg/db/queries/queries.sql.go b/pkg/db/queries/queries.sql.go index 72515e50..e9c2c443 100644 --- a/pkg/db/queries/queries.sql.go +++ b/pkg/db/queries/queries.sql.go @@ -339,18 +339,21 @@ func (q *Queries) SelectStagedOriginatorEnvelopes(ctx context.Context, arg Selec } const selectVectorClock = `-- name: SelectVectorClock :many -SELECT +SELECT DISTINCT ON (originator_node_id) originator_node_id, - max(originator_sequence_id)::BIGINT AS originator_sequence_id + originator_sequence_id, + originator_envelope FROM gateway_envelopes -GROUP BY - originator_node_id +ORDER BY + originator_node_id, + originator_sequence_id DESC ` type SelectVectorClockRow struct { OriginatorNodeID int32 OriginatorSequenceID int64 + OriginatorEnvelope []byte } func (q *Queries) SelectVectorClock(ctx context.Context) ([]SelectVectorClockRow, error) { @@ -362,7 +365,7 @@ func (q *Queries) SelectVectorClock(ctx context.Context) ([]SelectVectorClockRow var items []SelectVectorClockRow for rows.Next() { var i SelectVectorClockRow - if err := rows.Scan(&i.OriginatorNodeID, &i.OriginatorSequenceID); err != nil { + if err := rows.Scan(&i.OriginatorNodeID, &i.OriginatorSequenceID, &i.OriginatorEnvelope); err != nil { return nil, err } items = append(items, i) diff --git a/pkg/envelopes/originator.go b/pkg/envelopes/originator.go index ad2335c0..b86b9f94 100644 --- a/pkg/envelopes/originator.go +++ b/pkg/envelopes/originator.go @@ -58,6 +58,10 @@ func (o *OriginatorEnvelope) OriginatorSequenceID() uint64 { return o.UnsignedOriginatorEnvelope.OriginatorSequenceID() } +func (o *OriginatorEnvelope) OriginatorNs() int64 { + return o.UnsignedOriginatorEnvelope.OriginatorNs() +} + func (o *OriginatorEnvelope) TargetTopic() topic.Topic { return o.UnsignedOriginatorEnvelope.TargetTopic() } diff --git a/pkg/sync/syncWorker.go b/pkg/sync/syncWorker.go index 65059760..79a3299e 100644 --- a/pkg/sync/syncWorker.go +++ b/pkg/sync/syncWorker.go @@ -10,6 +10,7 @@ import ( "github.com/xmtp/xmtpd/pkg/db" "github.com/xmtp/xmtpd/pkg/db/queries" + envUtils "github.com/xmtp/xmtpd/pkg/envelopes" clientInterceptors "github.com/xmtp/xmtpd/pkg/interceptors/client" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/envelopes" "github.com/xmtp/xmtpd/pkg/proto/xmtpv4/message_api" @@ -18,7 +19,6 @@ import ( "github.com/xmtp/xmtpd/pkg/tracing" "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/protobuf/proto" ) type syncWorker struct { @@ -33,6 +33,12 @@ type syncWorker struct { cancel context.CancelFunc } +type originatorStream struct { + nodeID uint32 + lastEnvelope *envUtils.OriginatorEnvelope + stream message_api.ReplicationApi_SubscribeEnvelopesClient +} + type ExitLoopError struct { Message string } @@ -167,7 +173,7 @@ func (s *syncWorker) subscribeToNodeRegistration( } var conn *grpc.ClientConn - var stream message_api.ReplicationApi_SubscribeEnvelopesClient + var stream *originatorStream err = nil // TODO(mkysel) we should eventually implement a better backoff strategy @@ -270,7 +276,7 @@ func (s *syncWorker) setupStream( ctx context.Context, node registry.Node, conn *grpc.ClientConn, -) (message_api.ReplicationApi_SubscribeEnvelopesClient, error) { +) (*originatorStream, error) { result, err := queries.New(s.store).SelectVectorClock(ctx) if err != nil { return nil, err @@ -282,11 +288,12 @@ func (s *syncWorker) setupStream( zap.Any("vc", vc), ) client := message_api.NewReplicationApiClient(conn) + nodeID := node.NodeID stream, err := client.SubscribeEnvelopes( ctx, &message_api.SubscribeEnvelopesRequest{ Query: &message_api.EnvelopesQuery{ - OriginatorNodeIds: []uint32{node.NodeID}, + OriginatorNodeIds: []uint32{nodeID}, LastSeen: &envelopes.VectorClock{ NodeIdToSequenceId: vc, }, @@ -299,15 +306,25 @@ func (s *syncWorker) setupStream( err, ) } - return stream, nil + originatorStream := &originatorStream{nodeID: nodeID, stream: stream} + for _, row := range result { + if uint32(row.OriginatorNodeID) == nodeID { + lastEnvelope, err := envUtils.NewOriginatorEnvelopeFromBytes(row.OriginatorEnvelope) + if err != nil { + return nil, err + } + originatorStream.lastEnvelope = lastEnvelope + } + } + return originatorStream, nil } func (s *syncWorker) listenToStream( - stream message_api.ReplicationApi_SubscribeEnvelopesClient, + originatorStream *originatorStream, ) error { for { // Recv() is a blocking operation that can only be interrupted by cancelling ctx - envs, err := stream.Recv() + envs, err := originatorStream.stream.Recv() if err == io.EOF { return fmt.Errorf("Stream closed with EOF") } @@ -318,55 +335,62 @@ func (s *syncWorker) listenToStream( } s.log.Debug("Received envelopes", zap.Any("numEnvelopes", len(envs.Envelopes))) for _, env := range envs.Envelopes { - s.insertEnvelope(env) + s.validateAndInsertEnvelope(originatorStream, env) } } - } -func (s *syncWorker) insertEnvelope(env *envelopes.OriginatorEnvelope) { - s.log.Debug("Replication server received envelope", zap.Any("envelope", env)) - // TODO(nm) Validation logic - share code with API service and publish worker - originatorBytes, err := proto.Marshal(env) +func (s *syncWorker) validateAndInsertEnvelope( + stream *originatorStream, + envProto *envelopes.OriginatorEnvelope, +) { + env, err := envUtils.NewOriginatorEnvelope(envProto) if err != nil { - s.log.Error("Failed to marshal originator envelope", zap.Error(err)) + s.log.Error("Failed to unmarshal originator envelope", zap.Error(err)) return } - unsignedEnvelope := &envelopes.UnsignedOriginatorEnvelope{} - err = proto.Unmarshal(env.GetUnsignedOriginatorEnvelope(), unsignedEnvelope) - if err != nil { - s.log.Error( - "Failed to unmarshal unsigned originator envelope", - zap.Error(err), - ) + if env.OriginatorNodeID() != stream.nodeID { + s.log.Error("Received envelope from wrong node", zap.Any("nodeID", env.OriginatorNodeID())) return } - clientEnvelope := &envelopes.ClientEnvelope{} - err = proto.Unmarshal( - unsignedEnvelope.GetPayerEnvelope().GetUnsignedClientEnvelope(), - clientEnvelope, - ) + var lastSequenceID uint64 = 0 + var lastNs int64 = 0 + if stream.lastEnvelope != nil { + lastSequenceID = stream.lastEnvelope.OriginatorSequenceID() + lastNs = stream.lastEnvelope.OriginatorNs() + } + if env.OriginatorSequenceID() != lastSequenceID+1 || env.OriginatorNs() < lastNs { + // TODO(rich) Submit misbehavior report and continue + s.log.Error("Received out of order envelope") + } + + if env.OriginatorSequenceID() > lastSequenceID { + stream.lastEnvelope = env + } + + // TODO Validation logic - share code with API service and publish worker + // Signatures, topic type, etc + s.insertEnvelope(env) +} + +func (s *syncWorker) insertEnvelope(env *envUtils.OriginatorEnvelope) { + s.log.Debug("Replication server received envelope", zap.Any("envelope", env)) + originatorBytes, err := env.Bytes() if err != nil { - s.log.Error( - "Failed to unmarshal client envelope", - zap.Error(err), - ) + s.log.Error("Failed to marshal originator envelope", zap.Error(err)) return } q := queries.New(s.store) - inserted, err := q.InsertGatewayEnvelope( s.ctx, queries.InsertGatewayEnvelopeParams{ - OriginatorNodeID: int32(unsignedEnvelope.GetOriginatorNodeId()), - OriginatorSequenceID: int64( - unsignedEnvelope.GetOriginatorSequenceId(), - ), - Topic: clientEnvelope.GetAad().GetTargetTopic(), - OriginatorEnvelope: originatorBytes, + OriginatorNodeID: int32(env.OriginatorNodeID()), + OriginatorSequenceID: int64(env.OriginatorSequenceID()), + Topic: env.TargetTopic().Bytes(), + OriginatorEnvelope: originatorBytes, }, ) if err != nil {