diff --git a/dev/e2e/run b/dev/e2e/run index a8356918..d83f8a2c 100755 --- a/dev/e2e/run +++ b/dev/e2e/run @@ -3,7 +3,8 @@ set -eou pipefail GIT_COMMIT="$(git rev-parse HEAD)" +# shellcheck disable=SC2068 go run \ -ldflags="-X 'main.GitCommit=${GIT_COMMIT}'" \ cmd/xmtpd-e2e/main.go \ - "$@" + $@ diff --git a/dev/e2e/run-spray-local b/dev/e2e/run-spray-local index ed00af7f..798479d9 100755 --- a/dev/e2e/run-spray-local +++ b/dev/e2e/run-spray-local @@ -3,10 +3,11 @@ set -eou pipefail . dev/net/k8s-env -nodes="$(kubectl -n xmtp-nodes get pods -l "app.kubernetes.io/part-of=xmtp-nodes" -o=json | jq -r '.items[].metadata.labels["app.kubernetes.io/name"]')" -opts="" +nodes="$(dev/terraform/tf output -json | jq -r '.nodes.value[].name')" +opts=() while read -r node; do - opts="${opts} --api-url=${node}.localhost" + opts+=("--api-url=http://${node}.localhost") done <<< "$(echo -e "$nodes")" +opts+=("$@") -dev/e2e/run "${opts}" "$@" +dev/e2e/run "${opts[*]}" diff --git a/dev/terraform/plans/devnet-local/_variables.tf b/dev/terraform/plans/devnet-local/_variables.tf index 2600c01b..88b22465 100644 --- a/dev/terraform/plans/devnet-local/_variables.tf +++ b/dev/terraform/plans/devnet-local/_variables.tf @@ -14,6 +14,7 @@ variable "node_keys" { sensitive = true } variable "enable_chat_app" { default = true } +variable "enable_e2e" { default = true } variable "enable_monitoring" { default = true } variable "e2e_delay" { default = "" } variable "node_container_cpu_limit" { default = "500m" } diff --git a/dev/terraform/plans/devnet-local/main.tf b/dev/terraform/plans/devnet-local/main.tf index f6d5ed3c..11b94bb1 100644 --- a/dev/terraform/plans/devnet-local/main.tf +++ b/dev/terraform/plans/devnet-local/main.tf @@ -1,5 +1,5 @@ module "cluster" { - source = "git@github.com:xmtp-labs/xmtpd-terraform.git//modules/xmtp-cluster-kind?ref=12d1e46" + source = "git@github.com:xmtp-labs/xmtpd-terraform.git//modules/xmtp-cluster-kind?ref=0d6dce9" # Uncomment this line and comment out the previous source line to use a # local instance of xmtpd-modules living in the parent directory of xmtpd. @@ -12,6 +12,7 @@ module "cluster" { e2e_container_image = var.e2e_container_image e2e_delay = var.e2e_delay enable_chat_app = var.enable_chat_app + enable_e2e = var.enable_e2e enable_monitoring = var.enable_monitoring node_container_cpu_limit = var.node_container_cpu_limit node_container_memory_limit = var.node_container_memory_limit diff --git a/pkg/e2e/e2e.go b/pkg/e2e/e2e.go index 870b560b..d0c4f0ab 100644 --- a/pkg/e2e/e2e.go +++ b/pkg/e2e/e2e.go @@ -27,18 +27,19 @@ type E2E struct { } type Options struct { - APIURLs []string `long:"api-url" env:"XMTP_API_URLS" description:"XMTP node API URLs" default:"http://localhost"` - ClientsPerURL int `long:"clients-per-url" description:"Number of clients for each API URL" default:"1"` - MessagePerClient int `long:"messages-per-client" description:"Number of messages to publish for each client" default:"3"` - Continuous bool `long:"continuous" description:"Run continuously"` - ExitOnError bool `long:"exit-on-error" description:"Exit on error if running continuously"` - RunDelay time.Duration `long:"delay" description:"Delay between runs (in seconds)" default:"5s"` - AdminPort uint `long:"admin-port" description:"Admin HTTP server listen port" default:"7777"` + APIURLs []string `long:"api-url" env:"XMTP_API_URLS" description:"XMTP node API URLs" default:"http://localhost"` + ClientsPerURL int `long:"clients-per-url" description:"Number of clients for each API URL" default:"1"` + MessagePerClient int `long:"messages-per-client" description:"Number of messages to publish for each client" default:"3"` + Continuous bool `long:"continuous" description:"Run continuously"` + ExitOnError bool `long:"exit-on-error" description:"Exit on error if running continuously"` + RunDelay time.Duration `long:"delay" description:"Delay between runs" default:"5s"` + AdminPort uint `long:"admin-port" description:"Admin HTTP server listen port" default:"7777"` + QueryConvergenceDelay time.Duration `long:"query-convergence-delay" description:"Delay between query convergence checks" default:"10ms"` GitCommit string } -type testRunFunc func() error +type testRunFunc func(name string) error type Test struct { Name string @@ -47,7 +48,7 @@ type Test struct { func (e *E2E) Tests() []*Test { return []*Test{ - e.newTest("messagev1 publish subscribe query", e.testMessageV1PublishSubscribeQuery), + e.newTest("convergence", e.testConvergence), } } @@ -59,7 +60,7 @@ func New(ctx context.Context, opts *Options) (*E2E, error) { rand: rand.New(rand.NewSource(time.Now().UTC().UnixNano())), opts: opts, } - e.log.Info("running", zap.String("git-commit", opts.GitCommit)) + e.log.Info("running", zap.String("git-commit", opts.GitCommit), zap.Strings("nodes", opts.APIURLs)) if e.opts.Continuous { go func() { @@ -102,7 +103,7 @@ func (e *E2E) runTest(test *Test) error { started := time.Now().UTC() log := e.log.With(zap.String("test", test.Name)) - err := test.Run() + err := test.Run(test.Name) duration := time.Since(started) log = log.With(zap.Duration("duration", duration)) if err != nil { diff --git a/pkg/e2e/e2e_test.go b/pkg/e2e/e2e_test.go index 40580ee4..fa3c7969 100644 --- a/pkg/e2e/e2e_test.go +++ b/pkg/e2e/e2e_test.go @@ -42,7 +42,7 @@ func TestE2E(t *testing.T) { t.Run(test.Name, func(t *testing.T) { t.Parallel() - err := test.Run() + err := test.Run(test.Name) require.NoError(t, err) }) } diff --git a/pkg/e2e/metrics.go b/pkg/e2e/metrics.go index fc2a4619..c8f70b5c 100644 --- a/pkg/e2e/metrics.go +++ b/pkg/e2e/metrics.go @@ -10,7 +10,11 @@ import ( ) type Metrics struct { - runDuration *prometheus.HistogramVec + runDuration *prometheus.HistogramVec + subscribeDuration *prometheus.HistogramVec + publishDuration *prometheus.HistogramVec + subscribeConvergenceDuration *prometheus.HistogramVec + queryConvergenceDuration *prometheus.HistogramVec } func newMetrics() *Metrics { @@ -25,6 +29,46 @@ func newMetrics() *Metrics { }, []string{"test", "status"}, ), + subscribeDuration: promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "xmtpd", + Subsystem: "e2e", + Name: "subscribe_duration_us", + Help: "duration of test case subscribe (microseconds)", + Buckets: prometheus.ExponentialBuckets(10, 10, 10), + }, + []string{"test", "node", "status"}, + ), + publishDuration: promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "xmtpd", + Subsystem: "e2e", + Name: "publish_duration_us", + Help: "duration of test case publish (microseconds)", + Buckets: prometheus.ExponentialBuckets(10, 10, 10), + }, + []string{"test", "node", "status"}, + ), + subscribeConvergenceDuration: promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "xmtpd", + Subsystem: "e2e", + Name: "subscribe_convergence_duration_us", + Help: "duration of test case subscribe convergence (microseconds)", + Buckets: prometheus.ExponentialBuckets(10, 10, 10), + }, + []string{"test", "node", "status"}, + ), + queryConvergenceDuration: promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "xmtpd", + Subsystem: "e2e", + Name: "query_convergence_duration_us", + Help: "duration of test case query convergence (microseconds)", + Buckets: prometheus.ExponentialBuckets(10, 10, 10), + }, + []string{"test", "node", "status"}, + ), } } @@ -44,3 +88,75 @@ func (m *Metrics) recordRun(ctx context.Context, test, status string, duration t } met.Observe(float64(duration.Microseconds())) } + +func (m *Metrics) recordSubscribe(ctx context.Context, test, node, status string, duration time.Duration) { + if m == nil || m.subscribeDuration == nil { + return + } + met, err := m.subscribeDuration.GetMetricWithLabelValues(test, node, status) + if err != nil { + ctx.Logger().Warn("error observing metric", + zap.Error(err), + zap.String("metric", "subscribe_duration_us"), + zap.String("test", test), + zap.String("node", node), + zap.String("status", status), + ) + return + } + met.Observe(float64(duration.Microseconds())) +} + +func (m *Metrics) recordPublish(ctx context.Context, test, node, status string, duration time.Duration) { + if m == nil || m.publishDuration == nil { + return + } + met, err := m.publishDuration.GetMetricWithLabelValues(test, node, status) + if err != nil { + ctx.Logger().Warn("error observing metric", + zap.Error(err), + zap.String("metric", "publish_duration_us"), + zap.String("test", test), + zap.String("node", node), + zap.String("status", status), + ) + return + } + met.Observe(float64(duration.Microseconds())) +} + +func (m *Metrics) recordSubscribeConvergence(ctx context.Context, test, node, status string, duration time.Duration) { + if m == nil || m.subscribeConvergenceDuration == nil { + return + } + met, err := m.subscribeConvergenceDuration.GetMetricWithLabelValues(test, node, status) + if err != nil { + ctx.Logger().Warn("error observing metric", + zap.Error(err), + zap.String("metric", "subscribe_convergence_duration_us"), + zap.String("test", test), + zap.String("node", node), + zap.String("status", status), + ) + return + } + met.Observe(float64(duration.Microseconds())) +} + +func (m *Metrics) recordQueryConvergence(ctx context.Context, test, node, status string, duration time.Duration) { + if m == nil || m.queryConvergenceDuration == nil { + return + } + met, err := m.queryConvergenceDuration.GetMetricWithLabelValues(test, node, status) + if err != nil { + ctx.Logger().Warn("error observing metric", + zap.Error(err), + zap.String("metric", "query_convergence_duration_us"), + zap.String("test", test), + zap.String("node", node), + zap.String("status", status), + ) + return + } + met.Observe(float64(duration.Microseconds())) +} diff --git a/pkg/e2e/test_convergence.go b/pkg/e2e/test_convergence.go new file mode 100644 index 00000000..7d4a3371 --- /dev/null +++ b/pkg/e2e/test_convergence.go @@ -0,0 +1,327 @@ +package e2e + +import ( + "bytes" + "fmt" + "io" + "net/url" + "strings" + "sync" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/pkg/errors" + messagev1 "github.com/xmtp/proto/v3/go/message_api/v1" + apiclient "github.com/xmtp/xmtpd/pkg/api/client" + "github.com/xmtp/xmtpd/pkg/context" + "github.com/xmtp/xmtpd/pkg/zap" + "google.golang.org/protobuf/proto" +) + +func (e *E2E) testConvergence(name string) error { + nodeHosts := make([]string, len(e.opts.APIURLs)) + for i, apiURL := range e.opts.APIURLs { + url, err := url.Parse(apiURL) + if err != nil { + return err + } + nodeHosts[i] = url.Hostname() + } + + // Initialize clients for each node. + clients := make([]apiclient.Client, len(e.opts.APIURLs)) + for i, apiURL := range e.opts.APIURLs { + appVersion := "xmtpd-e2e/" + if len(e.opts.GitCommit) > 0 { + appVersion += e.opts.GitCommit[:7] + } + apiURL, clientOpts, err := parseAPIURL(apiURL) + if err != nil { + return err + } + clients[i] = apiclient.NewHTTPClient(e.log, apiURL, e.opts.GitCommit, appVersion, clientOpts...) + defer clients[i].Close() + } + + topic := "test-" + e.randomStringLower(12) + + ctx := context.WithTimeout(e.ctx, 30*time.Second) + defer ctx.Close() + + failedNodes := make([]bool, len(clients)) + + // Subscribe across nodes. + subscribeStart := time.Now().UTC() + subs := make([]apiclient.Stream, len(clients)) + for nodeIndex, client := range clients { + sub, err := client.Subscribe(ctx, &messagev1.SubscribeRequest{ + ContentTopics: []string{ + topic, + }, + }) + if err != nil { + if err == context.Canceled { + e.log.Debug("context canceled", zap.Error(err)) + return nil + } + duration := time.Since(subscribeStart) + e.log.Error("error subscribing", zap.Error(err)) + e.metrics.recordSubscribe(e.ctx, name, nodeHosts[nodeIndex], "failed", duration) + clients[nodeIndex] = nil + failedNodes[nodeIndex] = true + } else { + subs[nodeIndex] = sub + defer sub.Close() + duration := time.Since(subscribeStart) + e.metrics.recordSubscribe(e.ctx, name, nodeHosts[nodeIndex], "passed", duration) + } + } + + // Publish messages. + publishStart := time.Now().UTC() + envs := []*messagev1.Envelope{} + var envsLock sync.Mutex + var publishGroup sync.WaitGroup + for nodeIndex, client := range clients { + if client == nil { + continue + } + client := client + nodeIndex := nodeIndex + publishGroup.Add(1) + go func() { + defer publishGroup.Done() + + clientEnvs := make([]*messagev1.Envelope, e.opts.MessagePerClient) + for j := 0; j < e.opts.MessagePerClient; j++ { + clientEnvs[j] = &messagev1.Envelope{ + ContentTopic: topic, + TimestampNs: uint64(j + 1), + Message: []byte(fmt.Sprintf("msg%d-%d", nodeIndex+1, j+1)), + } + } + func() { + envsLock.Lock() + defer envsLock.Unlock() + envs = append(envs, clientEnvs...) + }() + _, err := client.Publish(ctx, &messagev1.PublishRequest{ + Envelopes: clientEnvs, + }) + if err != nil { + duration := time.Since(publishStart) + e.log.Error("error publishing", zap.Error(err), zap.Duration("duration", time.Since(publishStart))) + e.metrics.recordPublish(e.ctx, name, nodeHosts[nodeIndex], "failed", duration) + failedNodes[nodeIndex] = true + return + } + + duration := time.Since(publishStart) + e.log.Info("published", zap.Duration("duration", duration), zap.String("node", nodeHosts[nodeIndex])) + e.metrics.recordPublish(e.ctx, name, nodeHosts[nodeIndex], "passed", duration) + }() + } + publishGroup.Wait() + + // Expect them to be relayed to each subscription. + var subscribeGroup sync.WaitGroup + for nodeIndex, sub := range subs { + if sub == nil { + continue + } + sub := sub + nodeIndex := nodeIndex + subscribeGroup.Add(1) + go func() { + defer subscribeGroup.Done() + + envC := make(chan *messagev1.Envelope, 100) + go func(sub apiclient.Stream) { + for { + env, err := sub.Next(ctx) + if err != nil { + if isErrClosedConnection(err) || err == context.Canceled { + break + } + e.log.Error("getting next", zap.Error(err)) + break + } + if env == nil { + continue + } + envC <- env + } + }(sub) + err := subscribeExpect(envC, envs) + if err != nil { + duration := time.Since(publishStart) + e.log.Error("error checking subscription", zap.Error(err), zap.Duration("duration", time.Since(publishStart))) + e.metrics.recordSubscribeConvergence(e.ctx, name, nodeHosts[nodeIndex], "failed", duration) + failedNodes[nodeIndex] = true + return + } + + subscribeConvergenceDuration := time.Since(publishStart) + e.log.Info("subscribe converged", zap.Duration("duration", subscribeConvergenceDuration), zap.String("node", nodeHosts[nodeIndex])) + e.metrics.recordSubscribeConvergence(e.ctx, name, nodeHosts[nodeIndex], "passed", subscribeConvergenceDuration) + }() + } + + // Expect that they're stored. + var queryGroup sync.WaitGroup + for nodeIndex, client := range clients { + if client == nil { + continue + } + client := client + nodeIndex := nodeIndex + queryGroup.Add(1) + go func() { + defer queryGroup.Done() + + err := e.expectQueryMessagesEventually(ctx, client, []string{topic}, envs) + if err != nil { + duration := time.Since(publishStart) + e.log.Error("error querying", zap.Error(err), zap.Duration("duration", duration)) + e.metrics.recordQueryConvergence(e.ctx, name, nodeHosts[nodeIndex], "failed", duration) + failedNodes[nodeIndex] = true + return + } + + duration := time.Since(publishStart) + e.log.Info("query converged", zap.Duration("duration", duration), zap.String("node", nodeHosts[nodeIndex])) + e.metrics.recordQueryConvergence(e.ctx, name, nodeHosts[nodeIndex], "passed", duration) + }() + } + + subscribeGroup.Wait() + queryGroup.Wait() + + var failedCount int + for _, failed := range failedNodes { + if !failed { + continue + } + failedCount++ + } + if failedCount > 0 { + return errors.New("some nodes were unavailable or failed to converge") + } + + return nil +} + +func subscribeExpect(envC chan *messagev1.Envelope, envs []*messagev1.Envelope) error { + receivedEnvs := []*messagev1.Envelope{} + waitC := time.After(5 * time.Second) + var done bool + for !done { + select { + case env := <-envC: + receivedEnvs = append(receivedEnvs, env) + if len(receivedEnvs) == len(envs) { + done = true + } + case <-waitC: + done = true + } + } + err := envsDiff(envs, receivedEnvs) + if err != nil { + return errors.Wrap(err, "expected subscribe envelopes") + } + return nil +} + +func isErrClosedConnection(err error) bool { + return errors.Is(err, io.EOF) || strings.Contains(err.Error(), "closed network connection") || strings.Contains(err.Error(), "response body closed") +} + +func (e *E2E) expectQueryMessagesEventually(ctx context.Context, client apiclient.Client, contentTopics []string, expectedEnvs []*messagev1.Envelope) error { + timeout := 10 * time.Second + started := time.Now() + for { + envs, err := query(ctx, client, contentTopics) + if err != nil { + return errors.Wrap(err, "querying") + } + if len(envs) == len(expectedEnvs) { + err := envsDiff(envs, expectedEnvs) + if err != nil { + return errors.Wrap(err, "expected query envelopes") + } + break + } + if time.Since(started) > timeout { + err := envsDiff(envs, expectedEnvs) + if err != nil { + return errors.Wrap(err, "expected query envelopes") + } + return fmt.Errorf("timeout waiting for query expectation with no diff") + } + time.Sleep(e.opts.QueryConvergenceDelay) + } + return nil +} + +func query(ctx context.Context, client apiclient.Client, contentTopics []string) ([]*messagev1.Envelope, error) { + var envs []*messagev1.Envelope + var pagingInfo *messagev1.PagingInfo + for { + res, err := client.Query(ctx, &messagev1.QueryRequest{ + ContentTopics: contentTopics, + PagingInfo: pagingInfo, + }) + if err != nil { + return nil, err + } + envs = append(envs, res.Envelopes...) + if len(res.Envelopes) == 0 || res.PagingInfo == nil || res.PagingInfo.Cursor == nil { + break + } + pagingInfo = res.PagingInfo + } + return envs, nil +} + +func envsDiff(a, b []*messagev1.Envelope) error { + diff := cmp.Diff(a, b, + cmpopts.SortSlices(func(a, b *messagev1.Envelope) bool { + if a.ContentTopic != b.ContentTopic { + return a.ContentTopic < b.ContentTopic + } + if a.TimestampNs != b.TimestampNs { + return a.TimestampNs < b.TimestampNs + } + return bytes.Compare(a.Message, b.Message) < 0 + }), + cmp.Comparer(proto.Equal), + ) + if diff != "" { + return fmt.Errorf("expected equal, diff: %s", diff) + } + return nil +} + +func parseAPIURL(apiURL string) (string, []apiclient.Option, error) { + // If the API URL is a subdomain of localhost, then replace it with + // localhost and include a Host header, since Go doesn't resolve this + // DNS properly. + opts := []apiclient.Option{} + url, err := url.Parse(apiURL) + if err != nil { + return "", nil, err + } + urlParts := strings.Split(url.Hostname(), ".") + if len(urlParts) == 2 && urlParts[1] == "localhost" { + opts = append(opts, apiclient.WithHeader("Host", url.Hostname())) + url.Host = "localhost" + port := url.Port() + if port != "80" { + url.Host += ":" + port + } + apiURL = url.String() + } + return apiURL, opts, nil +} diff --git a/pkg/e2e/test_messagev1.go b/pkg/e2e/test_messagev1.go deleted file mode 100644 index 7501d8fc..00000000 --- a/pkg/e2e/test_messagev1.go +++ /dev/null @@ -1,229 +0,0 @@ -package e2e - -import ( - "bytes" - "context" - "fmt" - "io" - "net/url" - "strings" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "github.com/pkg/errors" - messagev1 "github.com/xmtp/proto/v3/go/message_api/v1" - apiclient "github.com/xmtp/xmtpd/pkg/api/client" - "github.com/xmtp/xmtpd/pkg/zap" - "google.golang.org/protobuf/proto" -) - -func (e *E2E) testMessageV1PublishSubscribeQuery() error { - clients := make([]apiclient.Client, len(e.opts.APIURLs)) - for i, apiURL := range e.opts.APIURLs { - appVersion := "xmtpd-e2e/" - if len(e.opts.GitCommit) > 0 { - appVersion += e.opts.GitCommit[:7] - } - apiURL, clientOpts, err := parseAPIURL(apiURL) - if err != nil { - return err - } - clients[i] = apiclient.NewHTTPClient(e.log, apiURL, e.opts.GitCommit, appVersion, clientOpts...) - defer clients[i].Close() - } - - contentTopic := "test-" + e.randomStringLower(12) - - ctx, cancel := context.WithTimeout(e.ctx, 30*time.Second) - defer cancel() - - // Subscribe across nodes. - streams := make([]apiclient.Stream, len(clients)) - for i, client := range clients { - stream, err := client.Subscribe(ctx, &messagev1.SubscribeRequest{ - ContentTopics: []string{ - contentTopic, - }, - }) - if err != nil { - if err == context.Canceled { - e.log.Debug("context canceled", zap.Error(err)) - return nil - } - return errors.Wrap(err, "subscribing") - } - streams[i] = stream - defer stream.Close() - } - - // Publish messages. - envs := []*messagev1.Envelope{} - for i, client := range clients { - clientEnvs := make([]*messagev1.Envelope, e.opts.MessagePerClient) - for j := 0; j < e.opts.MessagePerClient; j++ { - clientEnvs[j] = &messagev1.Envelope{ - ContentTopic: contentTopic, - TimestampNs: uint64(j + 1), - Message: []byte(fmt.Sprintf("msg%d-%d", i+1, j+1)), - } - } - envs = append(envs, clientEnvs...) - _, err := client.Publish(ctx, &messagev1.PublishRequest{ - Envelopes: clientEnvs, - }) - if err != nil { - return errors.Wrap(err, "publishing") - } - } - - // Expect them to be relayed to each subscription. - for _, stream := range streams { - envC := make(chan *messagev1.Envelope, 100) - go func(stream apiclient.Stream) { - for { - env, err := stream.Next(ctx) - if err != nil { - if isErrClosedConnection(err) || err == context.Canceled { - break - } - e.log.Error("getting next", zap.Error(err)) - break - } - if env == nil { - continue - } - envC <- env - } - }(stream) - err := subscribeExpect(envC, envs) - if err != nil { - return err - } - } - - // Expect that they're stored. - for _, client := range clients { - err := expectQueryMessagesEventually(ctx, client, []string{contentTopic}, envs) - if err != nil { - return err - } - } - - return nil -} - -func subscribeExpect(envC chan *messagev1.Envelope, envs []*messagev1.Envelope) error { - receivedEnvs := []*messagev1.Envelope{} - waitC := time.After(5 * time.Second) - var done bool - for !done { - select { - case env := <-envC: - receivedEnvs = append(receivedEnvs, env) - if len(receivedEnvs) == len(envs) { - done = true - } - case <-waitC: - done = true - } - } - err := envsDiff(envs, receivedEnvs) - if err != nil { - return errors.Wrap(err, "expected subscribe envelopes") - } - return nil -} - -func isErrClosedConnection(err error) bool { - return errors.Is(err, io.EOF) || strings.Contains(err.Error(), "closed network connection") || strings.Contains(err.Error(), "response body closed") -} - -func expectQueryMessagesEventually(ctx context.Context, client apiclient.Client, contentTopics []string, expectedEnvs []*messagev1.Envelope) error { - timeout := 10 * time.Second - delay := 500 * time.Millisecond - started := time.Now() - for { - envs, err := query(ctx, client, contentTopics) - if err != nil { - return errors.Wrap(err, "querying") - } - if len(envs) == len(expectedEnvs) { - err := envsDiff(envs, expectedEnvs) - if err != nil { - return errors.Wrap(err, "expected query envelopes") - } - break - } - if time.Since(started) > timeout { - err := envsDiff(envs, expectedEnvs) - if err != nil { - return errors.Wrap(err, "expected query envelopes") - } - return fmt.Errorf("timeout waiting for query expectation with no diff") - } - time.Sleep(delay) - } - return nil -} - -func query(ctx context.Context, client apiclient.Client, contentTopics []string) ([]*messagev1.Envelope, error) { - var envs []*messagev1.Envelope - var pagingInfo *messagev1.PagingInfo - for { - res, err := client.Query(ctx, &messagev1.QueryRequest{ - ContentTopics: contentTopics, - PagingInfo: pagingInfo, - }) - if err != nil { - return nil, err - } - envs = append(envs, res.Envelopes...) - if len(res.Envelopes) == 0 || res.PagingInfo == nil || res.PagingInfo.Cursor == nil { - break - } - pagingInfo = res.PagingInfo - } - return envs, nil -} - -func envsDiff(a, b []*messagev1.Envelope) error { - diff := cmp.Diff(a, b, - cmpopts.SortSlices(func(a, b *messagev1.Envelope) bool { - if a.ContentTopic != b.ContentTopic { - return a.ContentTopic < b.ContentTopic - } - if a.TimestampNs != b.TimestampNs { - return a.TimestampNs < b.TimestampNs - } - return bytes.Compare(a.Message, b.Message) < 0 - }), - cmp.Comparer(proto.Equal), - ) - if diff != "" { - return fmt.Errorf("expected equal, diff: %s", diff) - } - return nil -} - -func parseAPIURL(apiURL string) (string, []apiclient.Option, error) { - // If the API URL is a subdomain of localhost, then replace it with - // localhost and include a Host header, since Go doesn't resolve this - // DNS properly. - opts := []apiclient.Option{} - url, err := url.Parse(apiURL) - if err != nil { - return "", nil, err - } - urlParts := strings.Split(url.Hostname(), ".") - if len(urlParts) == 2 && urlParts[1] == "localhost" { - opts = append(opts, apiclient.WithHeader("Host", url.Hostname())) - url.Host = "localhost" - port := url.Port() - if port != "80" { - url.Host += ":" + port - } - apiURL = url.String() - } - return apiURL, opts, nil -}