diff --git a/pkg/api/message/v1/service.go b/pkg/api/message/v1/service.go index 0f92ae3f..0f18504f 100644 --- a/pkg/api/message/v1/service.go +++ b/pkg/api/message/v1/service.go @@ -5,7 +5,6 @@ import ( "fmt" "hash/fnv" "io" - "strings" "sync" "time" @@ -35,6 +34,20 @@ const ( // 1048576 - 300 - 62 = 1048214 MaxMessageSize = pubsub.DefaultMaxMessageSize - MaxContentTopicNameSize - 62 + + // maxQueriesPerBatch defines the maximum number of queries we can support per batch. + maxQueriesPerBatch = 50 + + // maxTopicsPerQueryRequest defines the maximum number of topics that can be queried in a single request. + // the number is likely to be more than we want it to be, but would be a safe place to put it - + // per Test_LargeQueryTesting, the request decoding already failing before it reaches th handler. + maxTopicsPerQueryRequest = 157733 + + // maxTopicsPerBatchQueryRequest defines the maximum number of topics that can be queried in a batch query. This + // limit is imposed in additional to the per-query limit maxTopicsPerRequest. + // as a starting value, we've using the same value as above, since the entire request would be tossed + // away before this is reached. + maxTopicsPerBatchQueryRequest = maxTopicsPerQueryRequest ) type Service struct { @@ -53,6 +66,8 @@ type Service struct { ns *server.Server nc *nats.Conn + + subDispatcher *subscriptionDispatcher } func NewService(log *zap.Logger, store *store.Store, publishToWakuRelay func(context.Context, *wakupb.WakuMessage) error) (s *Service, err error) { @@ -74,11 +89,16 @@ func NewService(log *zap.Logger, store *store.Store, publishToWakuRelay func(con if !s.ns.ReadyForConnections(4 * time.Second) { return nil, errors.New("nats not ready") } + s.nc, err = nats.Connect(s.ns.ClientURL()) if err != nil { return nil, err } + s.subDispatcher, err = newSubscriptionDispatcher(s.nc, s.log) + if err != nil { + return nil, err + } return s, nil } @@ -88,6 +108,7 @@ func (s *Service) Close() { if s.ctxCancel != nil { s.ctxCancel() } + s.subDispatcher.Shutdown() if s.nc != nil { s.nc.Close() @@ -161,50 +182,45 @@ func (s *Service) Subscribe(req *proto.SubscribeRequest, stream proto.MessageApi metrics.EmitSubscribeTopics(stream.Context(), log, len(req.ContentTopics)) - var streamLock sync.Mutex + // create a topics map. + topics := make(map[string]bool, len(req.ContentTopics)) for _, topic := range req.ContentTopics { - subject := topic - if subject != natsWildcardTopic { - subject = buildNatsSubject(topic) + topics[topic] = true + } + sub := s.subDispatcher.Subscribe(topics) + defer func() { + if sub != nil { + sub.Unsubscribe() } - sub, err := s.nc.Subscribe(subject, func(msg *nats.Msg) { - var env proto.Envelope - err := pb.Unmarshal(msg.Data, &env) - if err != nil { - log.Error("parsing envelope from bytes", zap.Error(err)) - return - } - if topic == natsWildcardTopic && !isValidSubscribeAllTopic(env.ContentTopic) { - return + metrics.EmitUnsubscribeTopics(stream.Context(), log, len(req.ContentTopics)) + }() + + var streamLock sync.Mutex + for exit := false; !exit; { + select { + case msg, open := <-sub.messagesCh: + if open { + func() { + streamLock.Lock() + defer streamLock.Unlock() + err := stream.Send(msg) + if err != nil { + log.Error("sending envelope to subscribe", zap.Error(err)) + } + }() + } else { + // channel got closed; likely due to backpressure of the sending channel. + log.Debug("stream closed due to backpressure") + exit = true } - func() { - streamLock.Lock() - defer streamLock.Unlock() - err := stream.Send(&env) - if err != nil { - log.Error("sending envelope to subscribe", zap.Error(err)) - } - }() - }) - if err != nil { - log.Error("error subscribing", zap.Error(err), zap.Int("topics", len(req.ContentTopics))) - return err + case <-stream.Context().Done(): + log.Debug("stream closed") + exit = true + case <-s.ctx.Done(): + log.Info("service closed") + exit = true } - defer func() { - _ = sub.Unsubscribe() - metrics.EmitUnsubscribeTopics(stream.Context(), log, 1) - }() } - - select { - case <-stream.Context().Done(): - log.Debug("stream closed") - break - case <-s.ctx.Done(): - log.Info("service closed") - break - } - return nil } @@ -243,14 +259,16 @@ func (s *Service) Subscribe2(stream proto.MessageApi_Subscribe2Server) error { } }() - subs := map[string]*nats.Subscription{} + var streamLock sync.Mutex + subscribedTopicCount := 0 + var currentSubscription *subscription defer func() { - for _, sub := range subs { - _ = sub.Unsubscribe() + if currentSubscription != nil { + currentSubscription.Unsubscribe() + metrics.EmitUnsubscribeTopics(stream.Context(), log, subscribedTopicCount) } - metrics.EmitUnsubscribeTopics(stream.Context(), log, len(subs)) }() - var streamLock sync.Mutex + subscriptionChannel := make(chan *proto.Envelope, 1) for { select { case <-stream.Context().Done(): @@ -263,52 +281,45 @@ func (s *Service) Subscribe2(stream proto.MessageApi_Subscribe2Server) error { if req == nil { continue } + + // unsubscribe first. + if currentSubscription != nil { + currentSubscription.Unsubscribe() + currentSubscription = nil + } log.Info("updating subscription", zap.Int("num_content_topics", len(req.ContentTopics))) topics := map[string]bool{} - numSubscribes := 0 for _, topic := range req.ContentTopics { topics[topic] = true + } - // If topic not in existing subscriptions, then subscribe. - if _, ok := subs[topic]; !ok { - sub, err := s.nc.Subscribe(buildNatsSubject(topic), func(msg *nats.Msg) { - var env proto.Envelope - err := pb.Unmarshal(msg.Data, &env) - if err != nil { - log.Info("unmarshaling envelope", zap.Error(err)) - return - } - func() { - streamLock.Lock() - defer streamLock.Unlock() - - err = stream.Send(&env) - if err != nil { - log.Error("sending envelope to subscriber", zap.Error(err)) - } - }() - }) + nextSubscription := s.subDispatcher.Subscribe(topics) + if currentSubscription == nil { + // on the first time, emit subscription + metrics.EmitSubscribeTopics(stream.Context(), log, len(topics)) + } else { + // otherwise, emit the change. + metrics.EmitSubscriptionChange(stream.Context(), log, len(topics)-subscribedTopicCount) + } + subscribedTopicCount = len(topics) + subscriptionChannel = nextSubscription.messagesCh + currentSubscription = nextSubscription + case msg, open := <-subscriptionChannel: + if open { + func() { + streamLock.Lock() + defer streamLock.Unlock() + err := stream.Send(msg) if err != nil { - log.Error("error subscribing", zap.Error(err), zap.Int("topics", len(req.ContentTopics))) - return err + log.Error("sending envelope to subscribe", zap.Error(err)) } - subs[topic] = sub - numSubscribes++ - } - } - - // If subscription not in topic, then unsubscribe. - var numUnsubscribes int - for topic, sub := range subs { - if topics[topic] { - continue - } - _ = sub.Unsubscribe() - delete(subs, topic) - numUnsubscribes++ + }() + } else { + // channel got closed; likely due to backpressure of the sending channel. + log.Debug("stream closed due to backpressure") + return nil } - metrics.EmitSubscriptionChange(stream.Context(), log, numSubscribes-numUnsubscribes) } } } @@ -334,6 +345,9 @@ func (s *Service) Query(ctx context.Context, req *proto.QueryRequest) (*proto.Qu } if len(req.ContentTopics) > 1 { + if len(req.ContentTopics) > maxTopicsPerQueryRequest { + return nil, status.Errorf(codes.InvalidArgument, "the number of content topics(%d) exceed the maximum topics per query request (%d)", len(req.ContentTopics), maxTopicsPerQueryRequest) + } ri := apicontext.NewRequesterInfo(ctx) log.Info("query with multiple topics", ri.ZapFields()...) } else { @@ -366,13 +380,33 @@ func (s *Service) BatchQuery(ctx context.Context, req *proto.BatchQueryRequest) logFunc = log.Info } logFunc("large batch query", zap.Int("num_queries", len(req.Requests))) - // NOTE: in our implementation, we implicitly limit batch size to 50 requests - if len(req.Requests) > 50 { - return nil, status.Errorf(codes.InvalidArgument, "cannot exceed 50 requests in single batch") + + // NOTE: in our implementation, we implicitly limit batch size to 50 requests (maxQueriesPerBatch = 50) + if len(req.Requests) > maxQueriesPerBatch { + return nil, status.Errorf(codes.InvalidArgument, "cannot exceed %d requests in single batch", maxQueriesPerBatch) + } + + // calculate the total number of topics being requested in this batch request. + totalRequestedTopicsCount := 0 + for _, query := range req.Requests { + totalRequestedTopicsCount += len(query.ContentTopics) + } + + if totalRequestedTopicsCount == 0 { + return nil, status.Errorf(codes.InvalidArgument, "content topics required") + } + + // are we still within limits ? + if totalRequestedTopicsCount > maxTopicsPerBatchQueryRequest { + return nil, status.Errorf(codes.InvalidArgument, "the total number of content topics(%d) exceed the maximum topics per batch query request(%d)", totalRequestedTopicsCount, maxTopicsPerBatchQueryRequest) } + // Naive implementation, perform all sub query requests sequentially responses := make([]*proto.QueryResponse, 0) for _, query := range req.Requests { + if len(query.ContentTopics) > maxTopicsPerQueryRequest { + return nil, status.Errorf(codes.InvalidArgument, "the number of content topics(%d) exceed the maximum topics per query request (%d)", len(query.ContentTopics), maxTopicsPerQueryRequest) + } // We execute the query using the existing Query API resp, err := s.Query(ctx, query) if err != nil { @@ -394,10 +428,6 @@ func buildEnvelope(msg *wakupb.WakuMessage) *proto.Envelope { } } -func isValidSubscribeAllTopic(contentTopic string) bool { - return strings.HasPrefix(contentTopic, validXMTPTopicPrefix) || topic.IsMLSV1(contentTopic) -} - func fromWakuTimestamp(ts int64) uint64 { if ts < 0 { return 0 diff --git a/pkg/api/message/v1/subscription.go b/pkg/api/message/v1/subscription.go new file mode 100644 index 00000000..7fc276f1 --- /dev/null +++ b/pkg/api/message/v1/subscription.go @@ -0,0 +1,155 @@ +package api + +import ( + "strings" + "sync" + + "github.com/nats-io/nats.go" // NATS messaging system + proto "github.com/xmtp/xmtp-node-go/pkg/proto/message_api/v1" // Custom XMTP Protocol Buffers definition + "go.uber.org/zap" // Logging library + pb "google.golang.org/protobuf/proto" // Protocol Buffers for serialization +) + +const ( + // allTopicsBacklogLength defines the buffer size for subscriptions that listen to all topics. + allTopicsBacklogLength = 1024 + + // minBacklogBufferLength defines the minimal length used for backlog buffer. + minBacklogBufferLength +) + +// subscriptionDispatcher manages subscriptions and message dispatching. +type subscriptionDispatcher struct { + natsConn *nats.Conn // Connection to NATS server + natsSub *nats.Subscription // Subscription to NATS topics + log *zap.Logger // Logger instance + subscriptions map[*subscription]interface{} // Active subscriptions + mu sync.Mutex // Mutex for concurrency control +} + +// newSubscriptionDispatcher creates a new dispatcher for managing subscriptions. +func newSubscriptionDispatcher(conn *nats.Conn, log *zap.Logger) (*subscriptionDispatcher, error) { + dispatcher := &subscriptionDispatcher{ + natsConn: conn, + log: log, + subscriptions: make(map[*subscription]interface{}), + } + + // Subscribe to NATS wildcard topic and assign message handler + var err error + dispatcher.natsSub, err = conn.Subscribe(natsWildcardTopic, dispatcher.messageHandler) + if err != nil { + return nil, err + } + return dispatcher, nil +} + +// Shutdown gracefully shuts down the dispatcher, unsubscribing from all topics. +func (d *subscriptionDispatcher) Shutdown() { + _ = d.natsSub.Unsubscribe() + // the lock/unlock ensures that there is no in-process dispatching. + d.mu.Lock() + defer d.mu.Unlock() + d.natsSub = nil + d.natsConn = nil + d.subscriptions = nil + +} + +// messageHandler processes incoming messages, dispatching them to the correct subscription. +func (d *subscriptionDispatcher) messageHandler(msg *nats.Msg) { + var env proto.Envelope + err := pb.Unmarshal(msg.Data, &env) + if err != nil { + d.log.Info("unmarshaling envelope", zap.Error(err)) + return + } + + xmtpTopic := isValidSubscribeAllTopic(env.ContentTopic) + + d.mu.Lock() + defer d.mu.Unlock() + for subscription := range d.subscriptions { + if subscription.all && !xmtpTopic { + continue + } + if subscription.all || subscription.topics[env.ContentTopic] { + select { + case subscription.messagesCh <- &env: + default: + // we got here since the message channel was full. This happens when the client cannot + // consume the data fast enough. In that case, we don't want to block further since it migth + // slow down other users. Instead, we're going to close the channel and let the + // consumer re-establish the connection if needed. + close(subscription.messagesCh) + delete(d.subscriptions, subscription) + } + } + } +} + +// subscription represents a single subscription, including its message channel and topics. +type subscription struct { + messagesCh chan *proto.Envelope // Channel for receiving messages + topics map[string]bool // Map of topics to subscribe to + all bool // Flag indicating subscription to all topics + dispatcher *subscriptionDispatcher // Parent dispatcher +} + +// log2 calculates the base-2 logarithm of an integer using bitwise operations. +// It returns the floor of the actual base-2 logarithm. +func log2(n uint) (log2 uint) { + if n == 0 { + return 0 + } + + // Keep shifting n right until it becomes 0. + // The number of shifts needed is the floor of log2(n). + for n > 1 { + n >>= 1 + log2++ + } + return log2 +} + +// Subscribe creates a new subscription for the given topics. +func (d *subscriptionDispatcher) Subscribe(topics map[string]bool) *subscription { + sub := &subscription{ + dispatcher: d, + } + + // Determine if subscribing to all topics or specific ones + for topic := range topics { + if natsWildcardTopic == topic { + sub.all = true + break + } + } + if !sub.all { + sub.topics = topics + // use a log2(length) as a backbuffer + backlogBufferSize := log2(uint(len(topics))) + 1 + if backlogBufferSize < minBacklogBufferLength { + backlogBufferSize = minBacklogBufferLength + } + sub.messagesCh = make(chan *proto.Envelope, backlogBufferSize) + } else { + sub.messagesCh = make(chan *proto.Envelope, allTopicsBacklogLength) + } + + d.mu.Lock() + defer d.mu.Unlock() + d.subscriptions[sub] = true + return sub +} + +// Unsubscribe removes the subscription from its dispatcher. +func (sub *subscription) Unsubscribe() { + sub.dispatcher.mu.Lock() + defer sub.dispatcher.mu.Unlock() + delete(sub.dispatcher.subscriptions, sub) +} + +func isValidSubscribeAllTopic(topic string) bool { + return strings.HasPrefix(topic, validXMTPTopicPrefix) +} diff --git a/pkg/api/server.go b/pkg/api/server.go index 267804a2..96e12636 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -283,6 +283,7 @@ func (s *Server) Close() { if err != nil { s.Log.Error("closing http listener", zap.Error(err)) } + s.httpListener = nil } if s.grpcListener != nil { @@ -290,6 +291,7 @@ func (s *Server) Close() { if err != nil { s.Log.Error("closing grpc listener", zap.Error(err)) } + s.grpcListener = nil } s.wg.Wait() diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go index 5273ce1e..abbd8ee4 100644 --- a/pkg/api/server_test.go +++ b/pkg/api/server_test.go @@ -5,8 +5,10 @@ import ( "encoding/json" "fmt" "io" + "math/rand" "net/http" "strings" + "sync" "testing" "time" @@ -16,6 +18,7 @@ import ( messageV1 "github.com/xmtp/xmtp-node-go/pkg/proto/message_api/v1" "github.com/xmtp/xmtp-node-go/pkg/ratelimiter" test "github.com/xmtp/xmtp-node-go/pkg/testing" + "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -60,6 +63,24 @@ func Test_HTTPRootPath(t *testing.T) { require.NotEmpty(t, body) } +type deferedPublishResponseFunc func(tb testing.TB) + +func publishTestEnvelopes(ctx context.Context, client messageclient.Client, msgs *messageV1.PublishRequest) deferedPublishResponseFunc { + var waitGroup sync.WaitGroup + waitGroup.Add(1) + var publishErr error + var publishRes *messageV1.PublishResponse + go func() { + defer waitGroup.Done() + publishRes, publishErr = client.Publish(ctx, msgs) + }() + return func(tb testing.TB) { + waitGroup.Wait() + require.NoError(tb, publishErr) + require.NotNil(tb, publishRes) + } +} + func Test_SubscribePublishQuery(t *testing.T) { ctx := withAuth(t, context.Background()) testGRPCAndHTTP(t, ctx, func(t *testing.T, client messageclient.Client, _ *Server) { @@ -73,9 +94,8 @@ func Test_SubscribePublishQuery(t *testing.T) { // publish 10 messages envs := makeEnvelopes(10) - publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs}) - require.NoError(t, err) - require.NotNil(t, publishRes) + deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs}) + defer deferedPublishResult(t) // read subscription subscribeExpect(t, stream, envs) @@ -247,9 +267,8 @@ func Test_GRPCMaxMessageSize(t *testing.T) { TimestampNs: 3, }, } - publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs}) - require.NoError(t, err) - require.NotNil(t, publishRes) + deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs}) + defer deferedPublishResult(t) subscribeExpect(t, stream, envs) requireEventuallyStored(t, ctx, client, envs) @@ -324,9 +343,8 @@ func Test_SubscribeClientClose(t *testing.T) { // publish 5 messages envs := makeEnvelopes(10) - publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[:5]}) - require.NoError(t, err) - require.NotNil(t, publishRes) + deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[:5]}) + defer deferedPublishResult(t) // receive 5 and close the stream subscribeExpect(t, stream, envs[:5]) @@ -334,9 +352,9 @@ func Test_SubscribeClientClose(t *testing.T) { require.NoError(t, err) // publish another 5 - publishRes, err = client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[5:]}) - require.NoError(t, err) - require.NotNil(t, publishRes) + deferedPublishResult = publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[5:]}) + defer deferedPublishResult(t) + time.Sleep(50 * time.Millisecond) ctx, cancel := context.WithTimeout(ctx, 1*time.Second) @@ -359,9 +377,8 @@ func Test_Subscribe2ClientClose(t *testing.T) { // publish 5 messages envs := makeEnvelopes(10) - publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[:5]}) - require.NoError(t, err) - require.NotNil(t, publishRes) + deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[:5]}) + defer deferedPublishResult(t) // receive 5 and close the stream subscribeExpect(t, stream, envs[:5]) @@ -369,9 +386,8 @@ func Test_Subscribe2ClientClose(t *testing.T) { require.NoError(t, err) // publish another 5 - publishRes, err = client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[5:]}) - require.NoError(t, err) - require.NotNil(t, publishRes) + deferedPublishResult = publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[5:]}) + defer deferedPublishResult(t) time.Sleep(50 * time.Millisecond) ctx, cancel := context.WithTimeout(ctx, 1*time.Second) @@ -394,9 +410,9 @@ func Test_Subscribe2UpdateTopics(t *testing.T) { // publish 5 messages envs := makeEnvelopes(10) - publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[:5]}) - require.NoError(t, err) - require.NotNil(t, publishRes) + deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[:5]}) + defer deferedPublishResult(t) + // receive 5 and close the stream subscribeExpect(t, stream, envs[:5]) @@ -443,9 +459,8 @@ func Test_SubscribeAllClientClose(t *testing.T) { for i, env := range envs { envs[i].ContentTopic = "/xmtp/0/" + env.ContentTopic } - publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[:5]}) - require.NoError(t, err) - require.NotNil(t, publishRes) + deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[:5]}) + defer deferedPublishResult(t) // receive 5 and close the stream subscribeExpect(t, stream, envs[:5]) @@ -453,9 +468,8 @@ func Test_SubscribeAllClientClose(t *testing.T) { require.NoError(t, err) // publish another 5 - publishRes, err = client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[5:]}) - require.NoError(t, err) - require.NotNil(t, publishRes) + deferedPublishResult = publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[5:]}) + defer deferedPublishResult(t) time.Sleep(50 * time.Millisecond) ctx, cancel := context.WithTimeout(ctx, 1*time.Second) @@ -478,12 +492,11 @@ func Test_SubscribeServerClose(t *testing.T) { // Publish 5 messages. envs := makeEnvelopes(5) - publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs}) - require.NoError(t, err) - require.NotNil(t, publishRes) + deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs}) + defer deferedPublishResult(t) // Receive 5 - subscribeExpect(t, stream, envs[:5]) + subscribeExpect(t, stream, envs) // stop Server server.Close() @@ -509,9 +522,8 @@ func Test_SubscribeAllServerClose(t *testing.T) { for i, env := range envs { envs[i].ContentTopic = "/xmtp/0/" + env.ContentTopic } - publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs}) - require.NoError(t, err) - require.NotNil(t, publishRes) + deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs}) + defer deferedPublishResult(t) // Receive 5 subscribeExpect(t, stream, envs[:5]) @@ -581,9 +593,8 @@ func Test_MultipleSubscriptions(t *testing.T) { // publish 5 envelopes envs := makeEnvelopes(10) - publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[:5]}) - require.NoError(t, err) - require.NotNil(t, publishRes) + deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[:5]}) + defer deferedPublishResult(t) // receive 5 envelopes on both streams subscribeExpect(t, stream1, envs[:5]) @@ -600,9 +611,8 @@ func Test_MultipleSubscriptions(t *testing.T) { time.Sleep(50 * time.Millisecond) // publish another 5 envelopes - publishRes, err = client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[5:]}) - require.NoError(t, err) - require.NotNil(t, publishRes) + deferedPublishResult = publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[5:]}) + defer deferedPublishResult(t) // receive 5 on stream 2 and 3 subscribeExpect(t, stream2, envs[5:]) @@ -615,9 +625,8 @@ func Test_QueryPaging(t *testing.T) { testGRPCAndHTTP(t, ctx, func(t *testing.T, client messageclient.Client, _ *Server) { // Store 10 envelopes with increasing SenderTimestamp envs := makeEnvelopes(10) - publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs}) - require.NoError(t, err) - require.NotNil(t, publishRes) + deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs}) + defer deferedPublishResult(t) time.Sleep(50 * time.Millisecond) requireEventuallyStored(t, ctx, client, envs) @@ -664,9 +673,8 @@ func Test_BatchQuery(t *testing.T) { testGRPCAndHTTP(t, ctx, func(t *testing.T, client messageclient.Client, _ *Server) { // Store 10 envelopes with increasing SenderTimestamp envs := makeEnvelopes(10) - publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs}) - require.NoError(t, err) - require.NotNil(t, publishRes) + deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs}) + defer deferedPublishResult(t) requireEventuallyStored(t, ctx, client, envs) batchSize := 50 @@ -714,9 +722,8 @@ func Test_BatchQueryOverLimitError(t *testing.T) { testGRPCAndHTTP(t, ctx, func(t *testing.T, client messageclient.Client, _ *Server) { // Store 10 envelopes with increasing SenderTimestamp envs := makeEnvelopes(10) - publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs}) - require.NoError(t, err) - require.NotNil(t, publishRes) + deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs}) + defer deferedPublishResult(t) requireEventuallyStored(t, ctx, client, envs) // Limit is 50 queries implicitly so 100 should result in an error @@ -731,7 +738,7 @@ func Test_BatchQueryOverLimitError(t *testing.T) { } repeatedQueries = append(repeatedQueries, query) } - _, err = client.BatchQuery(ctx, &messageV1.BatchQueryRequest{ + _, err := client.BatchQuery(ctx, &messageV1.BatchQueryRequest{ Requests: repeatedQueries, }) grpcErr, ok := status.FromError(err) @@ -841,3 +848,109 @@ func requireErrorEqual(t *testing.T, err error, code codes.Code, msg string, det require.ElementsMatch(t, details, httpErr["details"]) } } + +func Benchmark_SubscribePublishQuery(b *testing.B) { + server, cleanup := newTestServerWithLog(b, zap.NewNop()) + defer cleanup() + + ctx := withAuth(b, context.Background()) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + client, err := messageclient.NewGRPCClient(ctx, server.dialGRPC) + require.NoError(b, err) + + // create topics large topics for 10 streams. Topic should be interleaved. + const chunkSize = 1000 + const streamsCount = 10 + topics := [streamsCount][]string{} + + maxTopic := (len(topics)-1)*chunkSize*3/4 + chunkSize + // create a random order of topics. + topicsOrder := rand.Perm(maxTopic) + envs := make([]*messageV1.Envelope, len(topicsOrder)) + for i, topicID := range topicsOrder { + envs[i] = &messageV1.Envelope{ + ContentTopic: fmt.Sprintf("/xmtp/0/topic/%d", topicID), + Message: []byte{1, 2, 3}, + TimestampNs: uint64(time.Second), + } + } + + for j := range topics { + topics[j] = make([]string, chunkSize) + for k := range topics[j] { + topics[j][k] = fmt.Sprintf("/xmtp/0/topic/%d", (j*chunkSize*3/4 + k)) + } + } + + streams := [10]messageclient.Stream{} + b.ResetTimer() + for i := range streams { + // start subscribe streams + var err error + streams[i], err = client.Subscribe(ctx, &messageV1.SubscribeRequest{ + ContentTopics: topics[i], + }) + require.NoError(b, err) + defer streams[i].Close() + } + + for n := 0; n < b.N; n++ { + // publish messages + publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs}) + require.NoError(b, err) + require.NotNil(b, publishRes) + + readCtx, readCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer readCancel() + // read subscription + for _, stream := range streams { + for k := 0; k < chunkSize; k++ { + _, err := stream.Next(readCtx) + require.NoError(b, err) + } + } + } +} + +func Test_LargeQueryTesting(t *testing.T) { + ctx := withAuth(t, context.Background()) + testGRPCAndHTTP(t, ctx, func(t *testing.T, client messageclient.Client, _ *Server) { + // Store 10 envelopes with increasing SenderTimestamp + envs := makeEnvelopes(10) + deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs}) + defer deferedPublishResult(t) + time.Sleep(50 * time.Millisecond) + requireEventuallyStored(t, ctx, client, envs) + + // create a large set of query topics. + topics := make([]string, 512*1024) + for i := range topics { + topics[i] = fmt.Sprintf("topic/%d", i) + } + size := 16 + step := 16 + prevSize := 16 + for { + query := &messageV1.QueryRequest{ + ContentTopics: topics[:size], + } + _, err := client.Query(ctx, query) + if err != nil { + // go back, and cut the step by half. + size = prevSize + step /= 2 + size += step + if step == 0 { + break + } + continue + } + prevSize = size + step *= 2 + size += step + } + t.Logf("max number of topics without any error was %d", size) + }) +} diff --git a/pkg/api/setup_test.go b/pkg/api/setup_test.go index 806f9c1d..3a234202 100644 --- a/pkg/api/setup_test.go +++ b/pkg/api/setup_test.go @@ -19,9 +19,8 @@ const ( testMaxMsgSize = 2 * 1024 * 1024 ) -func newTestServer(t *testing.T) (*Server, func()) { - log := test.NewLog(t) - waku, wakuCleanup := test.NewNode(t) +func newTestServerWithLog(t testing.TB, log *zap.Logger) (*Server, func()) { + waku, wakuCleanup := test.NewNode(t, log) store, storeCleanup := newTestStore(t, log) authzDB, _, authzDBCleanup := test.NewAuthzDB(t) allowLister := authz.NewDatabaseWalletAllowLister(authzDB, log) @@ -38,7 +37,7 @@ func newTestServer(t *testing.T) (*Server, func()) { MaxMsgSize: testMaxMsgSize, }, Waku: waku, - Log: test.NewLog(t), + Log: log, Store: store, AllowLister: allowLister, }) @@ -51,7 +50,12 @@ func newTestServer(t *testing.T) (*Server, func()) { } } -func newTestStore(t *testing.T, log *zap.Logger) (*store.Store, func()) { +func newTestServer(t testing.TB) (*Server, func()) { + log := test.NewLog(t) + return newTestServerWithLog(t, log) +} + +func newTestStore(t testing.TB, log *zap.Logger) (*store.Store, func()) { db, _, dbCleanup := test.NewDB(t) store, err := store.New(&store.Config{ Log: log, @@ -109,7 +113,7 @@ func testGRPC(t *testing.T, ctx context.Context, f func(*testing.T, messageclien f(t, c, server) } -func withAuth(t *testing.T, ctx context.Context) context.Context { +func withAuth(t testing.TB, ctx context.Context) context.Context { ctx, _ = withAuthWithDetails(t, ctx, time.Now()) return ctx } @@ -139,7 +143,7 @@ func withMissingIdentityKey(t *testing.T, ctx context.Context) context.Context { return metadata.AppendToOutgoingContext(ctx, authorizationMetadataKey, "Bearer "+et) } -func withAuthWithDetails(t *testing.T, ctx context.Context, when time.Time) (context.Context, *v1.AuthData) { +func withAuthWithDetails(t testing.TB, ctx context.Context, when time.Time) (context.Context, *v1.AuthData) { token, data, err := generateV2AuthToken(when) require.NoError(t, err) et, err := EncodeAuthToken(token) diff --git a/pkg/api/utils_test.go b/pkg/api/utils_test.go index 3cc3f9ad..b8108fbf 100644 --- a/pkg/api/utils_test.go +++ b/pkg/api/utils_test.go @@ -24,7 +24,7 @@ func makeEnvelopes(count int) (envs []*messageV1.Envelope) { return envs } -func subscribeExpect(t *testing.T, stream messageclient.Stream, expected []*messageV1.Envelope) { +func subscribeExpect(t testing.TB, stream messageclient.Stream, expected []*messageV1.Envelope) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() received := []*messageV1.Envelope{} @@ -38,7 +38,7 @@ func subscribeExpect(t *testing.T, stream messageclient.Stream, expected []*mess requireEnvelopesEqual(t, expected, received) } -func requireEventuallyStored(t *testing.T, ctx context.Context, client messageclient.Client, expected []*messageV1.Envelope) { +func requireEventuallyStored(t testing.TB, ctx context.Context, client messageclient.Client, expected []*messageV1.Envelope) { var queryRes *messageV1.QueryResponse require.Eventually(t, func() bool { var err error @@ -54,14 +54,14 @@ func requireEventuallyStored(t *testing.T, ctx context.Context, client messagecl requireEnvelopesEqual(t, expected, queryRes.Envelopes) } -func requireEnvelopesEqual(t *testing.T, expected, received []*messageV1.Envelope) { +func requireEnvelopesEqual(t testing.TB, expected, received []*messageV1.Envelope) { require.Equal(t, len(expected), len(received), "length mismatch") for i, env := range received { requireEnvelopeEqual(t, expected[i], env, "mismatched message[%d]", i) } } -func requireEnvelopeEqual(t *testing.T, expected, actual *messageV1.Envelope, msgAndArgs ...interface{}) { +func requireEnvelopeEqual(t testing.TB, expected, actual *messageV1.Envelope, msgAndArgs ...interface{}) { require.Equal(t, expected.ContentTopic, actual.ContentTopic, msgAndArgs...) require.Equal(t, expected.Message, actual.Message, msgAndArgs...) if expected.TimestampNs != 0 { diff --git a/pkg/e2e/test_messagev1.go b/pkg/e2e/test_messagev1.go index 3fb71cc0..5a69f879 100644 --- a/pkg/e2e/test_messagev1.go +++ b/pkg/e2e/test_messagev1.go @@ -114,6 +114,32 @@ syncLoop: } } + // start listeners + streamsCh := make([]chan *messagev1.Envelope, len(clients)) + for i := 0; i < clientCount; i++ { + envC := make(chan *messagev1.Envelope, 100) + go func(stream messageclient.Stream, envC chan *messagev1.Envelope) { + for { + env, err := stream.Next(ctx) + if err != nil { + if isErrClosedConnection(err) || err.Error() == "context canceled" { + break + } + s.log.Error("getting next", zap.Error(err)) + break + } + if env == nil { + continue + } + envC <- env + } + }(streams[i], envC) + streamsCh[i] = envC + } + + // wait until all the listeners are up and ready. + time.Sleep(100 * time.Millisecond) + // Publish messages. envs := []*messagev1.Envelope{} for i, client := range clients { @@ -126,6 +152,7 @@ syncLoop: } } envs = append(envs, clientEnvs...) + _, err = client.Publish(ctx, &messagev1.PublishRequest{ Envelopes: clientEnvs, }) @@ -136,25 +163,7 @@ syncLoop: // Expect them to be relayed to each subscription. for i := 0; i < clientCount; i++ { - stream := streams[i] - envC := make(chan *messagev1.Envelope, 100) - go func() { - for { - env, err := stream.Next(ctx) - if err != nil { - if isErrClosedConnection(err) || err.Error() == "context canceled" { - break - } - s.log.Error("getting next", zap.Error(err)) - break - } - if env == nil { - continue - } - envC <- env - } - }() - err = subscribeExpect(envC, envs) + err = subscribeExpect(streamsCh[i], envs) if err != nil { return err } diff --git a/pkg/server/node_test.go b/pkg/server/node_test.go index 223918c0..14259b2f 100644 --- a/pkg/server/node_test.go +++ b/pkg/server/node_test.go @@ -91,9 +91,9 @@ func TestNodes_Deployment(t *testing.T) { n2PrivKey := test.NewPrivateKey(t) // Spin up initial instances of the nodes. - n1, cleanup := test.NewNode(t, wakunode.WithPrivateKey(n1PrivKey)) + n1, cleanup := test.NewNode(t, test.NewLog(t), wakunode.WithPrivateKey(n1PrivKey)) defer cleanup() - n2, cleanup := test.NewNode(t, wakunode.WithPrivateKey(n2PrivKey)) + n2, cleanup := test.NewNode(t, test.NewLog(t), wakunode.WithPrivateKey(n2PrivKey)) defer cleanup() // Connect the nodes. @@ -101,9 +101,9 @@ func TestNodes_Deployment(t *testing.T) { test.Connect(t, n2, n1) // Spin up new instances of the nodes. - newN1, cleanup := test.NewNode(t, wakunode.WithPrivateKey(n1PrivKey)) + newN1, cleanup := test.NewNode(t, test.NewLog(t), wakunode.WithPrivateKey(n1PrivKey)) defer cleanup() - newN2, cleanup := test.NewNode(t, wakunode.WithPrivateKey(n2PrivKey)) + newN2, cleanup := test.NewNode(t, test.NewLog(t), wakunode.WithPrivateKey(n2PrivKey)) defer cleanup() // Expect matching peer IDs for new and old instances. diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 63a1d137..aed1368f 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -21,11 +21,11 @@ func TestServer_NewShutdown(t *testing.T) { func TestServer_StaticNodesReconnect(t *testing.T) { t.Parallel() - n1, cleanup := test.NewNode(t) + n1, cleanup := test.NewNode(t, test.NewLog(t)) defer cleanup() n1ID := n1.Host().ID() - n2, cleanup := test.NewNode(t) + n2, cleanup := test.NewNode(t, test.NewLog(t)) defer cleanup() n2ID := n2.Host().ID() diff --git a/pkg/testing/log.go b/pkg/testing/log.go index ec885634..cbfd3fd0 100644 --- a/pkg/testing/log.go +++ b/pkg/testing/log.go @@ -14,7 +14,7 @@ func init() { flag.BoolVar(&debug, "debug", false, "debug level logging in tests") } -func NewLog(t *testing.T) *zap.Logger { +func NewLog(t testing.TB) *zap.Logger { cfg := zap.NewDevelopmentConfig() if !debug { cfg.Level = zap.NewAtomicLevelAt(zap.InfoLevel) diff --git a/pkg/testing/node.go b/pkg/testing/node.go index 68f425a0..acdce878 100644 --- a/pkg/testing/node.go +++ b/pkg/testing/node.go @@ -18,6 +18,7 @@ import ( "github.com/waku-org/go-waku/tests" wakunode "github.com/waku-org/go-waku/waku/v2/node" "github.com/waku-org/go-waku/waku/v2/peerstore" + "go.uber.org/zap" ) func Connect(t *testing.T, n1 *wakunode.WakuNode, n2 *wakunode.WakuNode, protocols ...protocol.ID) { @@ -65,11 +66,10 @@ func Disconnect(t *testing.T, n1 *wakunode.WakuNode, n2 *wakunode.WakuNode) { }, 3*time.Second, 50*time.Millisecond) } -func NewNode(t *testing.T, opts ...wakunode.WakuNodeOption) (*wakunode.WakuNode, func()) { +func NewNode(t testing.TB, log *zap.Logger, opts ...wakunode.WakuNodeOption) (*wakunode.WakuNode, func()) { hostAddr, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:0") prvKey := NewPrivateKey(t) ctx := context.Background() - log := NewLog(t) opts = append([]wakunode.WakuNodeOption{ wakunode.WithLogger(log), wakunode.WithPrivateKey(prvKey), @@ -94,7 +94,7 @@ func NewPeer(t *testing.T) host.Host { return host } -func NewPrivateKey(t *testing.T) *ecdsa.PrivateKey { +func NewPrivateKey(t testing.TB) *ecdsa.PrivateKey { key, err := tests.RandomHex(32) require.NoError(t, err) prvKey, err := crypto.HexToECDSA(key) diff --git a/pkg/testing/store.go b/pkg/testing/store.go index 3727fe27..bb4c4aff 100644 --- a/pkg/testing/store.go +++ b/pkg/testing/store.go @@ -19,7 +19,7 @@ const ( localTestDBDSNSuffix = "?sslmode=disable" ) -func NewDB(t *testing.T) (*sql.DB, string, func()) { +func NewDB(t testing.TB) (*sql.DB, string, func()) { dsn := localTestDBDSNPrefix + localTestDBDSNSuffix ctlDB := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn))) dbName := "test_" + RandomStringLower(12) @@ -36,7 +36,7 @@ func NewDB(t *testing.T) (*sql.DB, string, func()) { } } -func NewAuthzDB(t *testing.T) (*bun.DB, string, func()) { +func NewAuthzDB(t testing.TB) (*bun.DB, string, func()) { db, dsn, cleanup := NewDB(t) bunDB := bun.NewDB(db, pgdialect.New())