diff --git a/pkg/api/message/v1/service.go b/pkg/api/message/v1/service.go index 0f18504f..0f92ae3f 100644 --- a/pkg/api/message/v1/service.go +++ b/pkg/api/message/v1/service.go @@ -5,6 +5,7 @@ import ( "fmt" "hash/fnv" "io" + "strings" "sync" "time" @@ -34,20 +35,6 @@ const ( // 1048576 - 300 - 62 = 1048214 MaxMessageSize = pubsub.DefaultMaxMessageSize - MaxContentTopicNameSize - 62 - - // maxQueriesPerBatch defines the maximum number of queries we can support per batch. - maxQueriesPerBatch = 50 - - // maxTopicsPerQueryRequest defines the maximum number of topics that can be queried in a single request. - // the number is likely to be more than we want it to be, but would be a safe place to put it - - // per Test_LargeQueryTesting, the request decoding already failing before it reaches th handler. - maxTopicsPerQueryRequest = 157733 - - // maxTopicsPerBatchQueryRequest defines the maximum number of topics that can be queried in a batch query. This - // limit is imposed in additional to the per-query limit maxTopicsPerRequest. - // as a starting value, we've using the same value as above, since the entire request would be tossed - // away before this is reached. - maxTopicsPerBatchQueryRequest = maxTopicsPerQueryRequest ) type Service struct { @@ -66,8 +53,6 @@ type Service struct { ns *server.Server nc *nats.Conn - - subDispatcher *subscriptionDispatcher } func NewService(log *zap.Logger, store *store.Store, publishToWakuRelay func(context.Context, *wakupb.WakuMessage) error) (s *Service, err error) { @@ -89,16 +74,11 @@ func NewService(log *zap.Logger, store *store.Store, publishToWakuRelay func(con if !s.ns.ReadyForConnections(4 * time.Second) { return nil, errors.New("nats not ready") } - s.nc, err = nats.Connect(s.ns.ClientURL()) if err != nil { return nil, err } - s.subDispatcher, err = newSubscriptionDispatcher(s.nc, s.log) - if err != nil { - return nil, err - } return s, nil } @@ -108,7 +88,6 @@ func (s *Service) Close() { if s.ctxCancel != nil { s.ctxCancel() } - s.subDispatcher.Shutdown() if s.nc != nil { s.nc.Close() @@ -182,45 +161,50 @@ func (s *Service) Subscribe(req *proto.SubscribeRequest, stream proto.MessageApi metrics.EmitSubscribeTopics(stream.Context(), log, len(req.ContentTopics)) - // create a topics map. - topics := make(map[string]bool, len(req.ContentTopics)) + var streamLock sync.Mutex for _, topic := range req.ContentTopics { - topics[topic] = true - } - sub := s.subDispatcher.Subscribe(topics) - defer func() { - if sub != nil { - sub.Unsubscribe() + subject := topic + if subject != natsWildcardTopic { + subject = buildNatsSubject(topic) } - metrics.EmitUnsubscribeTopics(stream.Context(), log, len(req.ContentTopics)) - }() - - var streamLock sync.Mutex - for exit := false; !exit; { - select { - case msg, open := <-sub.messagesCh: - if open { - func() { - streamLock.Lock() - defer streamLock.Unlock() - err := stream.Send(msg) - if err != nil { - log.Error("sending envelope to subscribe", zap.Error(err)) - } - }() - } else { - // channel got closed; likely due to backpressure of the sending channel. 
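The rewritten Subscribe maps each content topic to a NATS subject via buildNatsSubject, which this diff references but does not define. Given the hash/fnv import kept at the top of service.go, a plausible sketch of such a helper (an assumption, not the repository's actual implementation) is:

```go
import (
	"fmt"
	"hash/fnv"
)

// buildNatsSubject is referenced but not shown in this diff; this is a sketch
// only. It hashes an arbitrary content-topic string so the result is a single,
// valid NATS subject token.
func buildNatsSubject(topic string) string {
	h := fnv.New64a()
	_, _ = h.Write([]byte(topic)) // fnv's Write never returns an error
	return fmt.Sprintf("%x", h.Sum64())
}
```

Hashing keeps arbitrary topic strings within NATS subject syntax; the wildcard topic is passed through unhashed, as the hunk above shows.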
- log.Debug("stream closed due to backpressure") - exit = true + sub, err := s.nc.Subscribe(subject, func(msg *nats.Msg) { + var env proto.Envelope + err := pb.Unmarshal(msg.Data, &env) + if err != nil { + log.Error("parsing envelope from bytes", zap.Error(err)) + return } - case <-stream.Context().Done(): - log.Debug("stream closed") - exit = true - case <-s.ctx.Done(): - log.Info("service closed") - exit = true + if topic == natsWildcardTopic && !isValidSubscribeAllTopic(env.ContentTopic) { + return + } + func() { + streamLock.Lock() + defer streamLock.Unlock() + err := stream.Send(&env) + if err != nil { + log.Error("sending envelope to subscribe", zap.Error(err)) + } + }() + }) + if err != nil { + log.Error("error subscribing", zap.Error(err), zap.Int("topics", len(req.ContentTopics))) + return err } + defer func() { + _ = sub.Unsubscribe() + metrics.EmitUnsubscribeTopics(stream.Context(), log, 1) + }() } + + select { + case <-stream.Context().Done(): + log.Debug("stream closed") + break + case <-s.ctx.Done(): + log.Info("service closed") + break + } + return nil } @@ -259,16 +243,14 @@ func (s *Service) Subscribe2(stream proto.MessageApi_Subscribe2Server) error { } }() - var streamLock sync.Mutex - subscribedTopicCount := 0 - var currentSubscription *subscription + subs := map[string]*nats.Subscription{} defer func() { - if currentSubscription != nil { - currentSubscription.Unsubscribe() - metrics.EmitUnsubscribeTopics(stream.Context(), log, subscribedTopicCount) + for _, sub := range subs { + _ = sub.Unsubscribe() } + metrics.EmitUnsubscribeTopics(stream.Context(), log, len(subs)) }() - subscriptionChannel := make(chan *proto.Envelope, 1) + var streamLock sync.Mutex for { select { case <-stream.Context().Done(): @@ -281,45 +263,52 @@ func (s *Service) Subscribe2(stream proto.MessageApi_Subscribe2Server) error { if req == nil { continue } - - // unsubscribe first. - if currentSubscription != nil { - currentSubscription.Unsubscribe() - currentSubscription = nil - } log.Info("updating subscription", zap.Int("num_content_topics", len(req.ContentTopics))) topics := map[string]bool{} + numSubscribes := 0 for _, topic := range req.ContentTopics { topics[topic] = true - } - nextSubscription := s.subDispatcher.Subscribe(topics) - if currentSubscription == nil { - // on the first time, emit subscription - metrics.EmitSubscribeTopics(stream.Context(), log, len(topics)) - } else { - // otherwise, emit the change. - metrics.EmitSubscriptionChange(stream.Context(), log, len(topics)-subscribedTopicCount) - } - subscribedTopicCount = len(topics) - subscriptionChannel = nextSubscription.messagesCh - currentSubscription = nextSubscription - case msg, open := <-subscriptionChannel: - if open { - func() { - streamLock.Lock() - defer streamLock.Unlock() - err := stream.Send(msg) + // If topic not in existing subscriptions, then subscribe. 
+ if _, ok := subs[topic]; !ok { + sub, err := s.nc.Subscribe(buildNatsSubject(topic), func(msg *nats.Msg) { + var env proto.Envelope + err := pb.Unmarshal(msg.Data, &env) + if err != nil { + log.Info("unmarshaling envelope", zap.Error(err)) + return + } + func() { + streamLock.Lock() + defer streamLock.Unlock() + + err = stream.Send(&env) + if err != nil { + log.Error("sending envelope to subscriber", zap.Error(err)) + } + }() + }) if err != nil { - log.Error("sending envelope to subscribe", zap.Error(err)) + log.Error("error subscribing", zap.Error(err), zap.Int("topics", len(req.ContentTopics))) + return err } - }() - } else { - // channel got closed; likely due to backpressure of the sending channel. - log.Debug("stream closed due to backpressure") - return nil + subs[topic] = sub + numSubscribes++ + } + } + + // If subscription not in topic, then unsubscribe. + var numUnsubscribes int + for topic, sub := range subs { + if topics[topic] { + continue + } + _ = sub.Unsubscribe() + delete(subs, topic) + numUnsubscribes++ } + metrics.EmitSubscriptionChange(stream.Context(), log, numSubscribes-numUnsubscribes) } } } @@ -345,9 +334,6 @@ func (s *Service) Query(ctx context.Context, req *proto.QueryRequest) (*proto.Qu } if len(req.ContentTopics) > 1 { - if len(req.ContentTopics) > maxTopicsPerQueryRequest { - return nil, status.Errorf(codes.InvalidArgument, "the number of content topics(%d) exceed the maximum topics per query request (%d)", len(req.ContentTopics), maxTopicsPerQueryRequest) - } ri := apicontext.NewRequesterInfo(ctx) log.Info("query with multiple topics", ri.ZapFields()...) } else { @@ -380,33 +366,13 @@ func (s *Service) BatchQuery(ctx context.Context, req *proto.BatchQueryRequest) logFunc = log.Info } logFunc("large batch query", zap.Int("num_queries", len(req.Requests))) - - // NOTE: in our implementation, we implicitly limit batch size to 50 requests (maxQueriesPerBatch = 50) - if len(req.Requests) > maxQueriesPerBatch { - return nil, status.Errorf(codes.InvalidArgument, "cannot exceed %d requests in single batch", maxQueriesPerBatch) - } - - // calculate the total number of topics being requested in this batch request. - totalRequestedTopicsCount := 0 - for _, query := range req.Requests { - totalRequestedTopicsCount += len(query.ContentTopics) - } - - if totalRequestedTopicsCount == 0 { - return nil, status.Errorf(codes.InvalidArgument, "content topics required") - } - - // are we still within limits ? 
- if totalRequestedTopicsCount > maxTopicsPerBatchQueryRequest { - return nil, status.Errorf(codes.InvalidArgument, "the total number of content topics(%d) exceed the maximum topics per batch query request(%d)", totalRequestedTopicsCount, maxTopicsPerBatchQueryRequest) + // NOTE: in our implementation, we implicitly limit batch size to 50 requests + if len(req.Requests) > 50 { + return nil, status.Errorf(codes.InvalidArgument, "cannot exceed 50 requests in single batch") } - // Naive implementation, perform all sub query requests sequentially responses := make([]*proto.QueryResponse, 0) for _, query := range req.Requests { - if len(query.ContentTopics) > maxTopicsPerQueryRequest { - return nil, status.Errorf(codes.InvalidArgument, "the number of content topics(%d) exceed the maximum topics per query request (%d)", len(query.ContentTopics), maxTopicsPerQueryRequest) - } // We execute the query using the existing Query API resp, err := s.Query(ctx, query) if err != nil { @@ -428,6 +394,10 @@ func buildEnvelope(msg *wakupb.WakuMessage) *proto.Envelope { } } +func isValidSubscribeAllTopic(contentTopic string) bool { + return strings.HasPrefix(contentTopic, validXMTPTopicPrefix) || topic.IsMLSV1(contentTopic) +} + func fromWakuTimestamp(ts int64) uint64 { if ts < 0 { return 0 diff --git a/pkg/api/message/v1/subscription.go b/pkg/api/message/v1/subscription.go deleted file mode 100644 index 7fc276f1..00000000 --- a/pkg/api/message/v1/subscription.go +++ /dev/null @@ -1,155 +0,0 @@ -package api - -import ( - "strings" - "sync" - - "github.com/nats-io/nats.go" // NATS messaging system - proto "github.com/xmtp/xmtp-node-go/pkg/proto/message_api/v1" // Custom XMTP Protocol Buffers definition - "go.uber.org/zap" // Logging library - pb "google.golang.org/protobuf/proto" // Protocol Buffers for serialization -) - -const ( - // allTopicsBacklogLength defines the buffer size for subscriptions that listen to all topics. - allTopicsBacklogLength = 1024 - - // minBacklogBufferLength defines the minimal length used for backlog buffer. - minBacklogBufferLength -) - -// subscriptionDispatcher manages subscriptions and message dispatching. -type subscriptionDispatcher struct { - natsConn *nats.Conn // Connection to NATS server - natsSub *nats.Subscription // Subscription to NATS topics - log *zap.Logger // Logger instance - subscriptions map[*subscription]interface{} // Active subscriptions - mu sync.Mutex // Mutex for concurrency control -} - -// newSubscriptionDispatcher creates a new dispatcher for managing subscriptions. -func newSubscriptionDispatcher(conn *nats.Conn, log *zap.Logger) (*subscriptionDispatcher, error) { - dispatcher := &subscriptionDispatcher{ - natsConn: conn, - log: log, - subscriptions: make(map[*subscription]interface{}), - } - - // Subscribe to NATS wildcard topic and assign message handler - var err error - dispatcher.natsSub, err = conn.Subscribe(natsWildcardTopic, dispatcher.messageHandler) - if err != nil { - return nil, err - } - return dispatcher, nil -} - -// Shutdown gracefully shuts down the dispatcher, unsubscribing from all topics. -func (d *subscriptionDispatcher) Shutdown() { - _ = d.natsSub.Unsubscribe() - // the lock/unlock ensures that there is no in-process dispatching. - d.mu.Lock() - defer d.mu.Unlock() - d.natsSub = nil - d.natsConn = nil - d.subscriptions = nil - -} - -// messageHandler processes incoming messages, dispatching them to the correct subscription. 
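The replacement isValidSubscribeAllTopic added to service.go above widens the old prefix-only check to also admit MLS v1 topics via topic.IsMLSV1. A hypothetical table-driven test (assuming validXMTPTopicPrefix is the "/xmtp/0/" prefix the Subscribe-all tests later in this diff rely on, and leaving the exact MLS v1 topic format to topic.IsMLSV1) would pin that behavior down:

```go
import (
	"testing"

	"github.com/stretchr/testify/require"
)

// Hypothetical sketch, not part of this PR.
func Test_IsValidSubscribeAllTopic(t *testing.T) {
	cases := map[string]bool{
		"/xmtp/0/m-abc123/proto": true,  // standard XMTP v0 topic prefix
		"some-random-subject":    false, // no recognized prefix, not MLS v1
		"":                       false, // empty topic
	}
	for topic, want := range cases {
		require.Equal(t, want, isValidSubscribeAllTopic(topic), topic)
	}
}
```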
-func (d *subscriptionDispatcher) messageHandler(msg *nats.Msg) { - var env proto.Envelope - err := pb.Unmarshal(msg.Data, &env) - if err != nil { - d.log.Info("unmarshaling envelope", zap.Error(err)) - return - } - - xmtpTopic := isValidSubscribeAllTopic(env.ContentTopic) - - d.mu.Lock() - defer d.mu.Unlock() - for subscription := range d.subscriptions { - if subscription.all && !xmtpTopic { - continue - } - if subscription.all || subscription.topics[env.ContentTopic] { - select { - case subscription.messagesCh <- &env: - default: - // we got here since the message channel was full. This happens when the client cannot - // consume the data fast enough. In that case, we don't want to block further since it migth - // slow down other users. Instead, we're going to close the channel and let the - // consumer re-establish the connection if needed. - close(subscription.messagesCh) - delete(d.subscriptions, subscription) - } - } - } -} - -// subscription represents a single subscription, including its message channel and topics. -type subscription struct { - messagesCh chan *proto.Envelope // Channel for receiving messages - topics map[string]bool // Map of topics to subscribe to - all bool // Flag indicating subscription to all topics - dispatcher *subscriptionDispatcher // Parent dispatcher -} - -// log2 calculates the base-2 logarithm of an integer using bitwise operations. -// It returns the floor of the actual base-2 logarithm. -func log2(n uint) (log2 uint) { - if n == 0 { - return 0 - } - - // Keep shifting n right until it becomes 0. - // The number of shifts needed is the floor of log2(n). - for n > 1 { - n >>= 1 - log2++ - } - return log2 -} - -// Subscribe creates a new subscription for the given topics. -func (d *subscriptionDispatcher) Subscribe(topics map[string]bool) *subscription { - sub := &subscription{ - dispatcher: d, - } - - // Determine if subscribing to all topics or specific ones - for topic := range topics { - if natsWildcardTopic == topic { - sub.all = true - break - } - } - if !sub.all { - sub.topics = topics - // use a log2(length) as a backbuffer - backlogBufferSize := log2(uint(len(topics))) + 1 - if backlogBufferSize < minBacklogBufferLength { - backlogBufferSize = minBacklogBufferLength - } - sub.messagesCh = make(chan *proto.Envelope, backlogBufferSize) - } else { - sub.messagesCh = make(chan *proto.Envelope, allTopicsBacklogLength) - } - - d.mu.Lock() - defer d.mu.Unlock() - d.subscriptions[sub] = true - return sub -} - -// Unsubscribe removes the subscription from its dispatcher. 
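The dispatcher being deleted here sized each subscription's buffer as log2(number of topics) + 1 using the hand-rolled log2 helper below. For reference only, the standard library expresses the same floor-of-log2 via math/bits:

```go
import "math/bits"

// floorLog2 matches what the removed log2 helper computed: ⌊log₂(n)⌋ for
// n >= 1, and 0 for n == 0. bits.Len returns the number of bits needed to
// represent n, so the highest set bit sits at position bits.Len(n)-1.
func floorLog2(n uint) uint {
	if n == 0 {
		return 0
	}
	return uint(bits.Len(n) - 1)
}
```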
-func (sub *subscription) Unsubscribe() { - sub.dispatcher.mu.Lock() - defer sub.dispatcher.mu.Unlock() - delete(sub.dispatcher.subscriptions, sub) -} - -func isValidSubscribeAllTopic(topic string) bool { - return strings.HasPrefix(topic, validXMTPTopicPrefix) -} diff --git a/pkg/api/server.go b/pkg/api/server.go index 96e12636..267804a2 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -283,7 +283,6 @@ func (s *Server) Close() { if err != nil { s.Log.Error("closing http listener", zap.Error(err)) } - s.httpListener = nil } if s.grpcListener != nil { @@ -291,7 +290,6 @@ func (s *Server) Close() { if err != nil { s.Log.Error("closing grpc listener", zap.Error(err)) } - s.grpcListener = nil } s.wg.Wait() diff --git a/pkg/api/server_test.go b/pkg/api/server_test.go index abbd8ee4..5273ce1e 100644 --- a/pkg/api/server_test.go +++ b/pkg/api/server_test.go @@ -5,10 +5,8 @@ import ( "encoding/json" "fmt" "io" - "math/rand" "net/http" "strings" - "sync" "testing" "time" @@ -18,7 +16,6 @@ import ( messageV1 "github.com/xmtp/xmtp-node-go/pkg/proto/message_api/v1" "github.com/xmtp/xmtp-node-go/pkg/ratelimiter" test "github.com/xmtp/xmtp-node-go/pkg/testing" - "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -63,24 +60,6 @@ func Test_HTTPRootPath(t *testing.T) { require.NotEmpty(t, body) } -type deferedPublishResponseFunc func(tb testing.TB) - -func publishTestEnvelopes(ctx context.Context, client messageclient.Client, msgs *messageV1.PublishRequest) deferedPublishResponseFunc { - var waitGroup sync.WaitGroup - waitGroup.Add(1) - var publishErr error - var publishRes *messageV1.PublishResponse - go func() { - defer waitGroup.Done() - publishRes, publishErr = client.Publish(ctx, msgs) - }() - return func(tb testing.TB) { - waitGroup.Wait() - require.NoError(tb, publishErr) - require.NotNil(tb, publishRes) - } -} - func Test_SubscribePublishQuery(t *testing.T) { ctx := withAuth(t, context.Background()) testGRPCAndHTTP(t, ctx, func(t *testing.T, client messageclient.Client, _ *Server) { @@ -94,8 +73,9 @@ func Test_SubscribePublishQuery(t *testing.T) { // publish 10 messages envs := makeEnvelopes(10) - deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs}) - defer deferedPublishResult(t) + publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs}) + require.NoError(t, err) + require.NotNil(t, publishRes) // read subscription subscribeExpect(t, stream, envs) @@ -267,8 +247,9 @@ func Test_GRPCMaxMessageSize(t *testing.T) { TimestampNs: 3, }, } - deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs}) - defer deferedPublishResult(t) + publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs}) + require.NoError(t, err) + require.NotNil(t, publishRes) subscribeExpect(t, stream, envs) requireEventuallyStored(t, ctx, client, envs) @@ -343,8 +324,9 @@ func Test_SubscribeClientClose(t *testing.T) { // publish 5 messages envs := makeEnvelopes(10) - deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[:5]}) - defer deferedPublishResult(t) + publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[:5]}) + require.NoError(t, err) + require.NotNil(t, publishRes) // receive 5 and close the stream subscribeExpect(t, stream, envs[:5]) @@ -352,9 +334,9 @@ func Test_SubscribeClientClose(t *testing.T) { require.NoError(t, err) // publish 
another 5 - deferedPublishResult = publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[5:]}) - defer deferedPublishResult(t) - + publishRes, err = client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[5:]}) + require.NoError(t, err) + require.NotNil(t, publishRes) time.Sleep(50 * time.Millisecond) ctx, cancel := context.WithTimeout(ctx, 1*time.Second) @@ -377,8 +359,9 @@ func Test_Subscribe2ClientClose(t *testing.T) { // publish 5 messages envs := makeEnvelopes(10) - deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[:5]}) - defer deferedPublishResult(t) + publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[:5]}) + require.NoError(t, err) + require.NotNil(t, publishRes) // receive 5 and close the stream subscribeExpect(t, stream, envs[:5]) @@ -386,8 +369,9 @@ func Test_Subscribe2ClientClose(t *testing.T) { require.NoError(t, err) // publish another 5 - deferedPublishResult = publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[5:]}) - defer deferedPublishResult(t) + publishRes, err = client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[5:]}) + require.NoError(t, err) + require.NotNil(t, publishRes) time.Sleep(50 * time.Millisecond) ctx, cancel := context.WithTimeout(ctx, 1*time.Second) @@ -410,9 +394,9 @@ func Test_Subscribe2UpdateTopics(t *testing.T) { // publish 5 messages envs := makeEnvelopes(10) - deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[:5]}) - defer deferedPublishResult(t) - + publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[:5]}) + require.NoError(t, err) + require.NotNil(t, publishRes) // receive 5 and close the stream subscribeExpect(t, stream, envs[:5]) @@ -459,8 +443,9 @@ func Test_SubscribeAllClientClose(t *testing.T) { for i, env := range envs { envs[i].ContentTopic = "/xmtp/0/" + env.ContentTopic } - deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[:5]}) - defer deferedPublishResult(t) + publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[:5]}) + require.NoError(t, err) + require.NotNil(t, publishRes) // receive 5 and close the stream subscribeExpect(t, stream, envs[:5]) @@ -468,8 +453,9 @@ func Test_SubscribeAllClientClose(t *testing.T) { require.NoError(t, err) // publish another 5 - deferedPublishResult = publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[5:]}) - defer deferedPublishResult(t) + publishRes, err = client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[5:]}) + require.NoError(t, err) + require.NotNil(t, publishRes) time.Sleep(50 * time.Millisecond) ctx, cancel := context.WithTimeout(ctx, 1*time.Second) @@ -492,11 +478,12 @@ func Test_SubscribeServerClose(t *testing.T) { // Publish 5 messages. 
envs := makeEnvelopes(5) - deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs}) - defer deferedPublishResult(t) + publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs}) + require.NoError(t, err) + require.NotNil(t, publishRes) // Receive 5 - subscribeExpect(t, stream, envs) + subscribeExpect(t, stream, envs[:5]) // stop Server server.Close() @@ -522,8 +509,9 @@ func Test_SubscribeAllServerClose(t *testing.T) { for i, env := range envs { envs[i].ContentTopic = "/xmtp/0/" + env.ContentTopic } - deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs}) - defer deferedPublishResult(t) + publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs}) + require.NoError(t, err) + require.NotNil(t, publishRes) // Receive 5 subscribeExpect(t, stream, envs[:5]) @@ -593,8 +581,9 @@ func Test_MultipleSubscriptions(t *testing.T) { // publish 5 envelopes envs := makeEnvelopes(10) - deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[:5]}) - defer deferedPublishResult(t) + publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[:5]}) + require.NoError(t, err) + require.NotNil(t, publishRes) // receive 5 envelopes on both streams subscribeExpect(t, stream1, envs[:5]) @@ -611,8 +600,9 @@ func Test_MultipleSubscriptions(t *testing.T) { time.Sleep(50 * time.Millisecond) // publish another 5 envelopes - deferedPublishResult = publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs[5:]}) - defer deferedPublishResult(t) + publishRes, err = client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs[5:]}) + require.NoError(t, err) + require.NotNil(t, publishRes) // receive 5 on stream 2 and 3 subscribeExpect(t, stream2, envs[5:]) @@ -625,8 +615,9 @@ func Test_QueryPaging(t *testing.T) { testGRPCAndHTTP(t, ctx, func(t *testing.T, client messageclient.Client, _ *Server) { // Store 10 envelopes with increasing SenderTimestamp envs := makeEnvelopes(10) - deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs}) - defer deferedPublishResult(t) + publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs}) + require.NoError(t, err) + require.NotNil(t, publishRes) time.Sleep(50 * time.Millisecond) requireEventuallyStored(t, ctx, client, envs) @@ -673,8 +664,9 @@ func Test_BatchQuery(t *testing.T) { testGRPCAndHTTP(t, ctx, func(t *testing.T, client messageclient.Client, _ *Server) { // Store 10 envelopes with increasing SenderTimestamp envs := makeEnvelopes(10) - deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs}) - defer deferedPublishResult(t) + publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs}) + require.NoError(t, err) + require.NotNil(t, publishRes) requireEventuallyStored(t, ctx, client, envs) batchSize := 50 @@ -722,8 +714,9 @@ func Test_BatchQueryOverLimitError(t *testing.T) { testGRPCAndHTTP(t, ctx, func(t *testing.T, client messageclient.Client, _ *Server) { // Store 10 envelopes with increasing SenderTimestamp envs := makeEnvelopes(10) - deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs}) - defer deferedPublishResult(t) + publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs}) + require.NoError(t, err) + require.NotNil(t, publishRes) 
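With the asynchronous publishTestEnvelopes helper removed, the publish/NoError/NotNil triplet now repeats across most of these tests. A small synchronous helper, sketched here only as a suggestion (it relies on the file's existing imports and is not part of this PR), would keep that repetition down without reintroducing goroutine bookkeeping:

```go
// Sketch of a possible shared test helper: publish synchronously and fail the
// calling test immediately on error.
func mustPublish(t *testing.T, ctx context.Context, client messageclient.Client, envs ...*messageV1.Envelope) *messageV1.PublishResponse {
	t.Helper()
	res, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs})
	require.NoError(t, err)
	require.NotNil(t, res)
	return res
}
```

Call sites would then collapse to mustPublish(t, ctx, client, envs[:5]...) or publishRes := mustPublish(t, ctx, client, envs...).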
requireEventuallyStored(t, ctx, client, envs) // Limit is 50 queries implicitly so 100 should result in an error @@ -738,7 +731,7 @@ func Test_BatchQueryOverLimitError(t *testing.T) { } repeatedQueries = append(repeatedQueries, query) } - _, err := client.BatchQuery(ctx, &messageV1.BatchQueryRequest{ + _, err = client.BatchQuery(ctx, &messageV1.BatchQueryRequest{ Requests: repeatedQueries, }) grpcErr, ok := status.FromError(err) @@ -848,109 +841,3 @@ func requireErrorEqual(t *testing.T, err error, code codes.Code, msg string, det require.ElementsMatch(t, details, httpErr["details"]) } } - -func Benchmark_SubscribePublishQuery(b *testing.B) { - server, cleanup := newTestServerWithLog(b, zap.NewNop()) - defer cleanup() - - ctx := withAuth(b, context.Background()) - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - client, err := messageclient.NewGRPCClient(ctx, server.dialGRPC) - require.NoError(b, err) - - // create topics large topics for 10 streams. Topic should be interleaved. - const chunkSize = 1000 - const streamsCount = 10 - topics := [streamsCount][]string{} - - maxTopic := (len(topics)-1)*chunkSize*3/4 + chunkSize - // create a random order of topics. - topicsOrder := rand.Perm(maxTopic) - envs := make([]*messageV1.Envelope, len(topicsOrder)) - for i, topicID := range topicsOrder { - envs[i] = &messageV1.Envelope{ - ContentTopic: fmt.Sprintf("/xmtp/0/topic/%d", topicID), - Message: []byte{1, 2, 3}, - TimestampNs: uint64(time.Second), - } - } - - for j := range topics { - topics[j] = make([]string, chunkSize) - for k := range topics[j] { - topics[j][k] = fmt.Sprintf("/xmtp/0/topic/%d", (j*chunkSize*3/4 + k)) - } - } - - streams := [10]messageclient.Stream{} - b.ResetTimer() - for i := range streams { - // start subscribe streams - var err error - streams[i], err = client.Subscribe(ctx, &messageV1.SubscribeRequest{ - ContentTopics: topics[i], - }) - require.NoError(b, err) - defer streams[i].Close() - } - - for n := 0; n < b.N; n++ { - // publish messages - publishRes, err := client.Publish(ctx, &messageV1.PublishRequest{Envelopes: envs}) - require.NoError(b, err) - require.NotNil(b, publishRes) - - readCtx, readCancel := context.WithTimeout(context.Background(), 30*time.Second) - defer readCancel() - // read subscription - for _, stream := range streams { - for k := 0; k < chunkSize; k++ { - _, err := stream.Next(readCtx) - require.NoError(b, err) - } - } - } -} - -func Test_LargeQueryTesting(t *testing.T) { - ctx := withAuth(t, context.Background()) - testGRPCAndHTTP(t, ctx, func(t *testing.T, client messageclient.Client, _ *Server) { - // Store 10 envelopes with increasing SenderTimestamp - envs := makeEnvelopes(10) - deferedPublishResult := publishTestEnvelopes(ctx, client, &messageV1.PublishRequest{Envelopes: envs}) - defer deferedPublishResult(t) - time.Sleep(50 * time.Millisecond) - requireEventuallyStored(t, ctx, client, envs) - - // create a large set of query topics. - topics := make([]string, 512*1024) - for i := range topics { - topics[i] = fmt.Sprintf("topic/%d", i) - } - size := 16 - step := 16 - prevSize := 16 - for { - query := &messageV1.QueryRequest{ - ContentTopics: topics[:size], - } - _, err := client.Query(ctx, query) - if err != nil { - // go back, and cut the step by half. 
- size = prevSize - step /= 2 - size += step - if step == 0 { - break - } - continue - } - prevSize = size - step *= 2 - size += step - } - t.Logf("max number of topics without any error was %d", size) - }) -} diff --git a/pkg/api/setup_test.go b/pkg/api/setup_test.go index 3a234202..806f9c1d 100644 --- a/pkg/api/setup_test.go +++ b/pkg/api/setup_test.go @@ -19,8 +19,9 @@ const ( testMaxMsgSize = 2 * 1024 * 1024 ) -func newTestServerWithLog(t testing.TB, log *zap.Logger) (*Server, func()) { - waku, wakuCleanup := test.NewNode(t, log) +func newTestServer(t *testing.T) (*Server, func()) { + log := test.NewLog(t) + waku, wakuCleanup := test.NewNode(t) store, storeCleanup := newTestStore(t, log) authzDB, _, authzDBCleanup := test.NewAuthzDB(t) allowLister := authz.NewDatabaseWalletAllowLister(authzDB, log) @@ -37,7 +38,7 @@ func newTestServerWithLog(t testing.TB, log *zap.Logger) (*Server, func()) { MaxMsgSize: testMaxMsgSize, }, Waku: waku, - Log: log, + Log: test.NewLog(t), Store: store, AllowLister: allowLister, }) @@ -50,12 +51,7 @@ func newTestServerWithLog(t testing.TB, log *zap.Logger) (*Server, func()) { } } -func newTestServer(t testing.TB) (*Server, func()) { - log := test.NewLog(t) - return newTestServerWithLog(t, log) -} - -func newTestStore(t testing.TB, log *zap.Logger) (*store.Store, func()) { +func newTestStore(t *testing.T, log *zap.Logger) (*store.Store, func()) { db, _, dbCleanup := test.NewDB(t) store, err := store.New(&store.Config{ Log: log, @@ -113,7 +109,7 @@ func testGRPC(t *testing.T, ctx context.Context, f func(*testing.T, messageclien f(t, c, server) } -func withAuth(t testing.TB, ctx context.Context) context.Context { +func withAuth(t *testing.T, ctx context.Context) context.Context { ctx, _ = withAuthWithDetails(t, ctx, time.Now()) return ctx } @@ -143,7 +139,7 @@ func withMissingIdentityKey(t *testing.T, ctx context.Context) context.Context { return metadata.AppendToOutgoingContext(ctx, authorizationMetadataKey, "Bearer "+et) } -func withAuthWithDetails(t testing.TB, ctx context.Context, when time.Time) (context.Context, *v1.AuthData) { +func withAuthWithDetails(t *testing.T, ctx context.Context, when time.Time) (context.Context, *v1.AuthData) { token, data, err := generateV2AuthToken(when) require.NoError(t, err) et, err := EncodeAuthToken(token) diff --git a/pkg/api/utils_test.go b/pkg/api/utils_test.go index b8108fbf..3cc3f9ad 100644 --- a/pkg/api/utils_test.go +++ b/pkg/api/utils_test.go @@ -24,7 +24,7 @@ func makeEnvelopes(count int) (envs []*messageV1.Envelope) { return envs } -func subscribeExpect(t testing.TB, stream messageclient.Stream, expected []*messageV1.Envelope) { +func subscribeExpect(t *testing.T, stream messageclient.Stream, expected []*messageV1.Envelope) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() received := []*messageV1.Envelope{} @@ -38,7 +38,7 @@ func subscribeExpect(t testing.TB, stream messageclient.Stream, expected []*mess requireEnvelopesEqual(t, expected, received) } -func requireEventuallyStored(t testing.TB, ctx context.Context, client messageclient.Client, expected []*messageV1.Envelope) { +func requireEventuallyStored(t *testing.T, ctx context.Context, client messageclient.Client, expected []*messageV1.Envelope) { var queryRes *messageV1.QueryResponse require.Eventually(t, func() bool { var err error @@ -54,14 +54,14 @@ func requireEventuallyStored(t testing.TB, ctx context.Context, client messagecl requireEnvelopesEqual(t, expected, queryRes.Envelopes) } -func requireEnvelopesEqual(t 
testing.TB, expected, received []*messageV1.Envelope) { +func requireEnvelopesEqual(t *testing.T, expected, received []*messageV1.Envelope) { require.Equal(t, len(expected), len(received), "length mismatch") for i, env := range received { requireEnvelopeEqual(t, expected[i], env, "mismatched message[%d]", i) } } -func requireEnvelopeEqual(t testing.TB, expected, actual *messageV1.Envelope, msgAndArgs ...interface{}) { +func requireEnvelopeEqual(t *testing.T, expected, actual *messageV1.Envelope, msgAndArgs ...interface{}) { require.Equal(t, expected.ContentTopic, actual.ContentTopic, msgAndArgs...) require.Equal(t, expected.Message, actual.Message, msgAndArgs...) if expected.TimestampNs != 0 { diff --git a/pkg/e2e/test_messagev1.go b/pkg/e2e/test_messagev1.go index 5a69f879..3fb71cc0 100644 --- a/pkg/e2e/test_messagev1.go +++ b/pkg/e2e/test_messagev1.go @@ -114,32 +114,6 @@ syncLoop: } } - // start listeners - streamsCh := make([]chan *messagev1.Envelope, len(clients)) - for i := 0; i < clientCount; i++ { - envC := make(chan *messagev1.Envelope, 100) - go func(stream messageclient.Stream, envC chan *messagev1.Envelope) { - for { - env, err := stream.Next(ctx) - if err != nil { - if isErrClosedConnection(err) || err.Error() == "context canceled" { - break - } - s.log.Error("getting next", zap.Error(err)) - break - } - if env == nil { - continue - } - envC <- env - } - }(streams[i], envC) - streamsCh[i] = envC - } - - // wait until all the listeners are up and ready. - time.Sleep(100 * time.Millisecond) - // Publish messages. envs := []*messagev1.Envelope{} for i, client := range clients { @@ -152,7 +126,6 @@ syncLoop: } } envs = append(envs, clientEnvs...) - _, err = client.Publish(ctx, &messagev1.PublishRequest{ Envelopes: clientEnvs, }) @@ -163,7 +136,25 @@ syncLoop: // Expect them to be relayed to each subscription. for i := 0; i < clientCount; i++ { - err = subscribeExpect(streamsCh[i], envs) + stream := streams[i] + envC := make(chan *messagev1.Envelope, 100) + go func() { + for { + env, err := stream.Next(ctx) + if err != nil { + if isErrClosedConnection(err) || err.Error() == "context canceled" { + break + } + s.log.Error("getting next", zap.Error(err)) + break + } + if env == nil { + continue + } + envC <- env + } + }() + err = subscribeExpect(envC, envs) if err != nil { return err } diff --git a/pkg/server/node_test.go b/pkg/server/node_test.go index 14259b2f..223918c0 100644 --- a/pkg/server/node_test.go +++ b/pkg/server/node_test.go @@ -91,9 +91,9 @@ func TestNodes_Deployment(t *testing.T) { n2PrivKey := test.NewPrivateKey(t) // Spin up initial instances of the nodes. - n1, cleanup := test.NewNode(t, test.NewLog(t), wakunode.WithPrivateKey(n1PrivKey)) + n1, cleanup := test.NewNode(t, wakunode.WithPrivateKey(n1PrivKey)) defer cleanup() - n2, cleanup := test.NewNode(t, test.NewLog(t), wakunode.WithPrivateKey(n2PrivKey)) + n2, cleanup := test.NewNode(t, wakunode.WithPrivateKey(n2PrivKey)) defer cleanup() // Connect the nodes. @@ -101,9 +101,9 @@ func TestNodes_Deployment(t *testing.T) { test.Connect(t, n2, n1) // Spin up new instances of the nodes. - newN1, cleanup := test.NewNode(t, test.NewLog(t), wakunode.WithPrivateKey(n1PrivKey)) + newN1, cleanup := test.NewNode(t, wakunode.WithPrivateKey(n1PrivKey)) defer cleanup() - newN2, cleanup := test.NewNode(t, test.NewLog(t), wakunode.WithPrivateKey(n2PrivKey)) + newN2, cleanup := test.NewNode(t, wakunode.WithPrivateKey(n2PrivKey)) defer cleanup() // Expect matching peer IDs for new and old instances. 
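In the e2e messagev1 test above, each stream's reader goroutine is now created inline just before subscribeExpect instead of ahead of publishing. The same pattern expressed as a helper, shown purely as an illustrative refactor that reuses this file's existing isErrClosedConnection and logger (not part of this PR):

```go
// Illustrative only: drain a subscription stream into a buffered channel
// until the stream ends or the context is canceled.
func readEnvelopes(ctx context.Context, log *zap.Logger, stream messageclient.Stream) chan *messagev1.Envelope {
	envC := make(chan *messagev1.Envelope, 100)
	go func() {
		for {
			env, err := stream.Next(ctx)
			if err != nil {
				if !isErrClosedConnection(err) && err.Error() != "context canceled" {
					log.Error("getting next", zap.Error(err))
				}
				return
			}
			if env == nil {
				continue
			}
			envC <- env
		}
	}()
	return envC
}
```

With that, the expectation loop would read err = subscribeExpect(readEnvelopes(ctx, s.log, streams[i]), envs).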
diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index aed1368f..63a1d137 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -21,11 +21,11 @@ func TestServer_NewShutdown(t *testing.T) { func TestServer_StaticNodesReconnect(t *testing.T) { t.Parallel() - n1, cleanup := test.NewNode(t, test.NewLog(t)) + n1, cleanup := test.NewNode(t) defer cleanup() n1ID := n1.Host().ID() - n2, cleanup := test.NewNode(t, test.NewLog(t)) + n2, cleanup := test.NewNode(t) defer cleanup() n2ID := n2.Host().ID() diff --git a/pkg/testing/log.go b/pkg/testing/log.go index cbfd3fd0..ec885634 100644 --- a/pkg/testing/log.go +++ b/pkg/testing/log.go @@ -14,7 +14,7 @@ func init() { flag.BoolVar(&debug, "debug", false, "debug level logging in tests") } -func NewLog(t testing.TB) *zap.Logger { +func NewLog(t *testing.T) *zap.Logger { cfg := zap.NewDevelopmentConfig() if !debug { cfg.Level = zap.NewAtomicLevelAt(zap.InfoLevel) diff --git a/pkg/testing/node.go b/pkg/testing/node.go index acdce878..68f425a0 100644 --- a/pkg/testing/node.go +++ b/pkg/testing/node.go @@ -18,7 +18,6 @@ import ( "github.com/waku-org/go-waku/tests" wakunode "github.com/waku-org/go-waku/waku/v2/node" "github.com/waku-org/go-waku/waku/v2/peerstore" - "go.uber.org/zap" ) func Connect(t *testing.T, n1 *wakunode.WakuNode, n2 *wakunode.WakuNode, protocols ...protocol.ID) { @@ -66,10 +65,11 @@ func Disconnect(t *testing.T, n1 *wakunode.WakuNode, n2 *wakunode.WakuNode) { }, 3*time.Second, 50*time.Millisecond) } -func NewNode(t testing.TB, log *zap.Logger, opts ...wakunode.WakuNodeOption) (*wakunode.WakuNode, func()) { +func NewNode(t *testing.T, opts ...wakunode.WakuNodeOption) (*wakunode.WakuNode, func()) { hostAddr, _ := net.ResolveTCPAddr("tcp", "0.0.0.0:0") prvKey := NewPrivateKey(t) ctx := context.Background() + log := NewLog(t) opts = append([]wakunode.WakuNodeOption{ wakunode.WithLogger(log), wakunode.WithPrivateKey(prvKey), @@ -94,7 +94,7 @@ func NewPeer(t *testing.T) host.Host { return host } -func NewPrivateKey(t testing.TB) *ecdsa.PrivateKey { +func NewPrivateKey(t *testing.T) *ecdsa.PrivateKey { key, err := tests.RandomHex(32) require.NoError(t, err) prvKey, err := crypto.HexToECDSA(key) diff --git a/pkg/testing/store.go b/pkg/testing/store.go index bb4c4aff..3727fe27 100644 --- a/pkg/testing/store.go +++ b/pkg/testing/store.go @@ -19,7 +19,7 @@ const ( localTestDBDSNSuffix = "?sslmode=disable" ) -func NewDB(t testing.TB) (*sql.DB, string, func()) { +func NewDB(t *testing.T) (*sql.DB, string, func()) { dsn := localTestDBDSNPrefix + localTestDBDSNSuffix ctlDB := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn))) dbName := "test_" + RandomStringLower(12) @@ -36,7 +36,7 @@ func NewDB(t testing.TB) (*sql.DB, string, func()) { } } -func NewAuthzDB(t testing.TB) (*bun.DB, string, func()) { +func NewAuthzDB(t *testing.T) (*bun.DB, string, func()) { db, dsn, cleanup := NewDB(t) bunDB := bun.NewDB(db, pgdialect.New())