diff --git a/pkg/common/common.go b/pkg/common/common.go index eccda56..9001048 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -36,6 +36,7 @@ type Module interface { type ModuleKey string type ConfigProvider interface { + GetBoxConfig() BoxConfig GetConfig(resourceName string, target any) error } diff --git a/pkg/factory/commonFactory.go b/pkg/factory/commonFactory.go index e5de84e..827bc27 100644 --- a/pkg/factory/commonFactory.go +++ b/pkg/factory/commonFactory.go @@ -75,14 +75,10 @@ func NewFromConfig(config Config) (common.Factory, error) { config.FileExtension, log.ForComponent("file_provider"), ) - var boxConfig common.BoxConfig - if err := provider.GetConfig("box", &boxConfig); err != nil { - log.Global().Warn().Err(err).Msg("cannot read box configuration") - } cf := &commonFactory{ modules: make(map[common.ModuleKey]common.Module), cfgProvider: provider, - boxConfig: boxConfig, + boxConfig: provider.GetBoxConfig(), } err := cf.Register(prometheus.NewModule) if err != nil { diff --git a/pkg/factory/configProvider.go b/pkg/factory/configProvider.go index 867bb8d..5517b41 100644 --- a/pkg/factory/configProvider.go +++ b/pkg/factory/configProvider.go @@ -20,6 +20,7 @@ import ( "errors" "github.com/rs/zerolog" "github.com/th2-net/th2-common-go/pkg/common" + "github.com/th2-net/th2-common-go/pkg/log" "io/fs" "os" ) @@ -37,17 +38,28 @@ func NewFileProvider(configPath string, extension string, logger zerolog.Logger) } func NewFileProviderForFS(fs fs.FS, extension string, logger zerolog.Logger) common.ConfigProvider { - return &fileConfigProvider{ + provider := fileConfigProvider{ configFS: fs, fileExtension: extension, zLogger: &logger, } + boxConfig := common.BoxConfig{} + if err := provider.GetConfig("box", &boxConfig); err != nil { + log.Global().Warn().Err(err).Msg("cannot read box configuration. user default values") + } + provider.boxConfig = boxConfig + return &provider } type fileConfigProvider struct { configFS fs.FS fileExtension string zLogger *zerolog.Logger + boxConfig common.BoxConfig +} + +func (cfd *fileConfigProvider) GetBoxConfig() common.BoxConfig { + return cfd.boxConfig } func (cfd *fileConfigProvider) GetConfig(resourceName string, target any) error { diff --git a/pkg/modules/queue/rabbitmq.go b/pkg/modules/queue/rabbitmq.go index 1d57d9b..f1c561e 100644 --- a/pkg/modules/queue/rabbitmq.go +++ b/pkg/modules/queue/rabbitmq.go @@ -42,19 +42,21 @@ func newRabbitMq( provider common.ConfigProvider, queueConfiguration queue.RouterConfig, ) (Module, error) { + boxConfig := provider.GetBoxConfig() connConfiguration := connection.Config{} configErr := provider.GetConfig(connectionConfigFilename, &connConfiguration) if configErr != nil { return nil, configErr } - return NewRabbitMq(connConfiguration, queueConfiguration) + return NewRabbitMq(boxConfig, connConfiguration, queueConfiguration) } func NewRabbitMq( + boxConfig common.BoxConfig, connConfiguration connection.Config, queueConfiguration queue.RouterConfig, ) (Module, error) { - messageRouter, eventRouter, manager, err := rabbitmq.NewRouters(connConfiguration, &queueConfiguration) + messageRouter, eventRouter, manager, err := rabbitmq.NewRouters(boxConfig, connConfiguration, &queueConfiguration) if err != nil { return nil, err } diff --git a/pkg/queue/rabbitmq/factory.go b/pkg/queue/rabbitmq/factory.go index 5ceddba..f47df44 100644 --- a/pkg/queue/rabbitmq/factory.go +++ b/pkg/queue/rabbitmq/factory.go @@ -17,6 +17,7 @@ package rabbitmq import ( "github.com/rs/zerolog" + "github.com/th2-net/th2-common-go/pkg/common" "github.com/th2-net/th2-common-go/pkg/log" "github.com/th2-net/th2-common-go/pkg/queue" "github.com/th2-net/th2-common-go/pkg/queue/event" @@ -29,13 +30,15 @@ import ( ) func NewRouters( + boxConfig common.BoxConfig, connection connection.Config, config *queue.RouterConfig, ) (messageRouter message.Router, eventRouter event.Router, closer io.Closer, err error) { - manager, err := internal.NewConnectionManager(connection, log.ForComponent("connection_manager")) + manager, err := internal.NewConnectionManager(connection, boxConfig.Name, log.ForComponent("connection_manager")) if err != nil { return } + go manager.ListenForBlockingNotifications() messageRouter = newMessageRouter(&manager, config, log.ForComponent("message_router")) eventRouter = newEventRouter(&manager, config, log.ForComponent("event_router")) closer = &manager diff --git a/pkg/queue/rabbitmq/internal/connection/connection.go b/pkg/queue/rabbitmq/internal/connection/connection.go new file mode 100644 index 0000000..91208ba --- /dev/null +++ b/pkg/queue/rabbitmq/internal/connection/connection.go @@ -0,0 +1,261 @@ +/* + * Copyright 2024 Exactpro (Exactpro Systems Limited) + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package connection + +import ( + "errors" + amqp "github.com/rabbitmq/amqp091-go" + "github.com/rs/zerolog" + "github.com/th2-net/th2-common-go/pkg/queue/rabbitmq/connection" + "sync" + "time" +) + +const ( + defaultMinRecoveryTimeout = 1 * time.Second + defaultMaxRecoveryTimeout = 60 * time.Second + // defaultMaxRecoveryAttempts used in case an error with status NOT_FOUND is returned from channel + defaultMaxRecoveryAttempts = 5 +) + +type connectionHolder struct { + connMutex sync.RWMutex + conn *amqp.Connection + channels map[string]*amqp.Channel + done chan struct{} + reconnectToMq func() (*amqp.Connection, error) + onConnectionRecovered func() + onChannelRecovered func(channelKey string) + logger zerolog.Logger + notifyMutex sync.Mutex + notifyRecovered []chan struct{} + minRecoveryTimeout time.Duration + maxRecoveryTimeout time.Duration +} + +func newConnection(url string, name string, logger zerolog.Logger, + configuration connection.Config, + onConnectionRecovered func(), onChannelRecovered func(channelKey string)) (*connectionHolder, error) { + if configuration.MinConnectionRecoveryTimeout > configuration.MaxConnectionRecoveryTimeout { + return nil, errors.New("min connection recovery timeout is greater than max connection recovery timeout") + } + var minRecoveryTimeout time.Duration + var maxRecoveryTimeout time.Duration + if configuration.MinConnectionRecoveryTimeout > 0 { + minRecoveryTimeout = time.Duration(configuration.MinConnectionRecoveryTimeout) * time.Millisecond + } else { + minRecoveryTimeout = defaultMinRecoveryTimeout + } + if configuration.MaxConnectionRecoveryTimeout > 0 { + maxRecoveryTimeout = time.Duration(configuration.MaxConnectionRecoveryTimeout) * time.Millisecond + } else { + maxRecoveryTimeout = defaultMaxRecoveryTimeout + } + logger.Info(). + Dur("minRecoveryTimeout", minRecoveryTimeout). + Dur("maxRecoveryTimeout", maxRecoveryTimeout). + Msg("recovery timeouts configured") + conn, err := dial(url, name) + if err != nil { + return nil, err + } + return &connectionHolder{ + connMutex: sync.RWMutex{}, + conn: conn, + channels: make(map[string]*amqp.Channel), + done: make(chan struct{}), + reconnectToMq: func() (*amqp.Connection, error) { + return dial(url, name) + }, + onConnectionRecovered: onConnectionRecovered, + onChannelRecovered: onChannelRecovered, + logger: logger, + notifyMutex: sync.Mutex{}, + notifyRecovered: make([]chan struct{}, 0), + minRecoveryTimeout: minRecoveryTimeout, + maxRecoveryTimeout: maxRecoveryTimeout, + }, nil +} + +func (c *connectionHolder) runConnectionRoutine() { + run := true + connectionClosed := true + var connectionErrors chan *amqp.Error + for run { + if connectionClosed { + connectionClosed = false + c.connMutex.RLock() + connectionErrors = c.conn.NotifyClose(make(chan *amqp.Error)) + c.connMutex.RUnlock() + } + select { + case <-c.done: + c.logger.Info(). + Msg("stopping connection routine") + run = false + break + case connErr, ok := <-connectionErrors: + if !ok { + // normal close + run = false + break + } + connectionClosed = true + c.logger.Error(). + Err(connErr). + Msg("received connection error. reconnecting") + c.tryToReconnect() + if c.onConnectionRecovered != nil { + c.onConnectionRecovered() + } + c.notifyMutex.Lock() + for _, ch := range c.notifyRecovered { + close(ch) + } + c.notifyRecovered = c.notifyRecovered[:0] + c.notifyMutex.Unlock() + } + } +} + +func (c *connectionHolder) tryToReconnect() { + var delay = c.minRecoveryTimeout + for { + err := c.reconnect() + if err == nil { + c.logger.Info(). + Msg("connection to rabbitmq restored") + break + } + c.logger.Error(). + Err(err). + Dur("timeout", delay). + Msg("reconnect failed. retrying after timeout") + time.Sleep(delay) + delay *= 2 + if delay > c.maxRecoveryTimeout { + delay = c.maxRecoveryTimeout + } + } +} + +func (c *connectionHolder) reconnect() (err error) { + c.connMutex.Lock() + defer c.connMutex.Unlock() + conn := c.conn + if conn != nil { + _ = conn.Close() + // clear map with channels + c.channels = make(map[string]*amqp.Channel) + } + conn, err = c.reconnectToMq() + if err == nil { + c.conn = conn + } + return +} + +func (c *connectionHolder) registerBlockingListener(blocking chan amqp.Blocking) <-chan amqp.Blocking { + c.connMutex.RLock() + defer c.connMutex.RUnlock() + return c.conn.NotifyBlocked(blocking) +} + +func dial(url string, name string) (*amqp.Connection, error) { + properties := amqp.NewConnectionProperties() + properties.SetClientConnectionName(name) + conn, err := amqp.DialConfig(url, amqp.Config{ + Heartbeat: 30 * time.Second, + Locale: "en_US", + Properties: properties, + }) + return conn, err +} + +func (c *connectionHolder) Close() error { + close(c.done) + c.connMutex.RLock() + defer c.connMutex.RUnlock() + return c.conn.Close() +} + +func (c *connectionHolder) waitRecovered(ch chan struct{}) <-chan struct{} { + c.connMutex.RLock() + if !c.conn.IsClosed() { + close(ch) + c.connMutex.RUnlock() + return ch + } + c.connMutex.RUnlock() + + c.notifyMutex.Lock() + c.notifyRecovered = append(c.notifyRecovered, ch) + c.notifyMutex.Unlock() + return ch +} + +func (c *connectionHolder) getChannel(key string) (*amqp.Channel, error) { + var ch *amqp.Channel + var err error + var exists bool + <-c.waitRecovered(make(chan struct{})) + c.connMutex.RLock() + ch, exists = c.channels[key] + c.connMutex.RUnlock() + if !exists { + ch, err = c.getOrCreateChannel(key) + } + + return ch, err +} + +func (c *connectionHolder) getOrCreateChannel(key string) (*amqp.Channel, error) { + c.connMutex.Lock() + defer c.connMutex.Unlock() + var ch *amqp.Channel + var err error + var exists bool + ch, exists = c.channels[key] + if !exists { + ch, err = c.conn.Channel() + if err != nil { + return nil, err + } + c.channels[key] = ch + go func(ch *amqp.Channel) { + select { + case err, ok := <-ch.NotifyClose(make(chan *amqp.Error)): + if !ok { + break + } + c.connMutex.Lock() + c.logger.Warn(). + Err(err). + Str("channelKey", key). + Msg("removing cached channel") + delete(c.channels, key) + c.connMutex.Unlock() + if c.onChannelRecovered != nil { + c.onChannelRecovered(key) + } + case <-c.done: + // closed. do nothing + } + + }(ch) + } + return ch, nil +} diff --git a/pkg/queue/rabbitmq/internal/connection/connectionManager.go b/pkg/queue/rabbitmq/internal/connection/connectionManager.go index 6cc170b..95ea930 100644 --- a/pkg/queue/rabbitmq/internal/connection/connectionManager.go +++ b/pkg/queue/rabbitmq/internal/connection/connectionManager.go @@ -17,6 +17,7 @@ package connection import ( "fmt" + amqp "github.com/rabbitmq/amqp091-go" "github.com/rs/zerolog" "github.com/th2-net/th2-common-go/pkg/log" "github.com/th2-net/th2-common-go/pkg/queue/rabbitmq/connection" @@ -27,20 +28,21 @@ type Manager struct { Consumer *Consumer Logger zerolog.Logger + closed chan struct{} } -func NewConnectionManager(connConfiguration connection.Config, logger zerolog.Logger) (Manager, error) { +func NewConnectionManager(connConfiguration connection.Config, componentName string, logger zerolog.Logger) (Manager, error) { url := fmt.Sprintf("amqp://%s:%s@%s:%d/%s", connConfiguration.Username, connConfiguration.Password, connConfiguration.Host, connConfiguration.Port, connConfiguration.VHost) - publisher, err := NewPublisher(url, log.ForComponent("publisher")) + publisher, err := NewPublisher(url, connConfiguration, componentName, log.ForComponent("publisher")) if err != nil { return Manager{}, err } - consumer, err := NewConsumer(url, log.ForComponent("consumer")) + consumer, err := NewConsumer(url, connConfiguration, componentName, log.ForComponent("consumer")) if err != nil { if pubErr := publisher.Close(); pubErr != nil { logger.Err(pubErr). @@ -48,14 +50,67 @@ func NewConnectionManager(connConfiguration connection.Config, logger zerolog.Lo } return Manager{}, err } + go publisher.runConnectionRoutine() + go consumer.runConnectionRoutine() return Manager{ Publisher: &publisher, Consumer: &consumer, Logger: logger, + // capacity is one to avoid blocking close call + closed: make(chan struct{}), }, nil } +func (manager *Manager) ListenForBlockingNotifications() { + var run = true + var consumerClosed = true + var publisherClosed = true + + var publisherNotifications <-chan amqp.Blocking + var consumerNotifications <-chan amqp.Blocking + for run { + // We try to reinitialize the listeners on each iteration + // because in case of connection problems + // the connections for publisher and consumer will be recreated. + // Old channels will be closed and never receive a new value + if publisherClosed { + publisherNotifications = manager.Publisher.registerBlockingListener(make(chan amqp.Blocking, 1)) + publisherClosed = false + } + if consumerClosed { + consumerNotifications = manager.Consumer.registerBlockingListener(make(chan amqp.Blocking, 1)) + consumerClosed = false + } + select { + case <-manager.closed: + manager.Logger.Info().Msg("stop listening for blocking notifications") + run = false + break + case consumerBlocked, ok := <-consumerNotifications: + if !ok { + consumerClosed = true + break + } + manager.Logger.Warn(). + Str("reason", consumerBlocked.Reason). + Bool("active", consumerBlocked.Active). + Msg("received blocked notification for consumer") + case publisherBlocked, ok := <-publisherNotifications: + if !ok { + publisherClosed = true + break + } + manager.Logger.Warn(). + Str("reason", publisherBlocked.Reason). + Bool("active", publisherBlocked.Active). + Msg("received blocked notification for publisher") + } + } +} + func (manager *Manager) Close() error { + close(manager.closed) + err := manager.Publisher.Close() if err != nil { manager.Logger.Error().Err(err).Msg("cannot close publisher") diff --git a/pkg/queue/rabbitmq/internal/connection/consumer.go b/pkg/queue/rabbitmq/internal/connection/consumer.go index 5064bb2..85ffc67 100644 --- a/pkg/queue/rabbitmq/internal/connection/consumer.go +++ b/pkg/queue/rabbitmq/internal/connection/consumer.go @@ -17,11 +17,14 @@ package connection import ( "errors" + "fmt" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" amqp "github.com/rabbitmq/amqp091-go" "github.com/rs/zerolog" "github.com/th2-net/th2-common-go/pkg/metrics" + "github.com/th2-net/th2-common-go/pkg/queue/rabbitmq/connection" + "time" ) var th2RabbitmqMessageSizeSubscribeBytes = promauto.NewCounterVec( @@ -42,76 +45,76 @@ var th2RabbitmqMessageProcessDurationSeconds = promauto.NewHistogramVec( ) type Consumer struct { - url string - conn *amqp.Connection - channels map[string]*amqp.Channel - Logger zerolog.Logger + *connectionHolder + Logger zerolog.Logger + maxMissingQueueRecoveryAttempts int } -func NewConsumer(url string, logger zerolog.Logger) (Consumer, error) { +func NewConsumer(url string, configuration connection.Config, componentName string, logger zerolog.Logger) (Consumer, error) { if url == "" { return Consumer{}, errors.New("url is empty") } - conn, err := amqp.Dial(url) + maxMissingQueueRecoveryAttempts := defaultMaxRecoveryAttempts + if configuration.MaxRecoveryAttempts > 0 { + maxMissingQueueRecoveryAttempts = configuration.MaxRecoveryAttempts + } + consumer := Consumer{ + Logger: logger, + maxMissingQueueRecoveryAttempts: maxMissingQueueRecoveryAttempts, + } + conn, err := newConnection(url, fmt.Sprintf("%s_consumer", componentName), + logger, configuration, nil, nil) if err != nil { - return Consumer{}, err + return consumer, err } + consumer.connectionHolder = conn logger.Debug().Msg("Consumer connected") - return Consumer{ - url: url, - conn: conn, - channels: make(map[string]*amqp.Channel), - Logger: logger, - }, nil + return consumer, nil } -func (cns *Consumer) Close() error { - return cns.conn.Close() +func (cns *Consumer) Consume(queueName string, th2Pin string, th2Type string, handler func(delivery amqp.Delivery) error) error { + return cns.consume( + queueName, th2Pin, th2Type, cns.consumeWithAutoAck, "consume", + func(delivery amqp.Delivery, timer *prometheus.Timer) error { + defer timer.ObserveDuration() + return handler(delivery) + }, + ) } -func (cns *Consumer) Consume(queueName string, th2Pin string, th2Type string, handler func(delivery amqp.Delivery) error) error { - ch, err := cns.conn.Channel() +func (cns *Consumer) ConsumeWithManualAck(queueName string, th2Pin string, th2Type string, handler func(msgDelivery amqp.Delivery, timer *prometheus.Timer) error) error { + return cns.consume( + queueName, th2Pin, th2Type, cns.consumeWithManualAck, "consumeWithManualAck", + func(delivery amqp.Delivery, timer *prometheus.Timer) error { + return handler(delivery, timer) + }, + ) +} + +func (cns *Consumer) consume(queueName string, th2Pin string, th2Type string, + subscriptionProducer func(queueName string, methodName string) (*amqp.Channel, <-chan amqp.Delivery, error), + methodName string, handler func(delivery amqp.Delivery, timer *prometheus.Timer) error) error { + ch, msgs, err := subscriptionProducer(queueName, methodName) if err != nil { - cns.Logger.Error(). - Err(err). - Str("queue", queueName). - Msg("cannot open channel") return err } - cns.channels[queueName] = ch - - msgs, consErr := ch.Consume( - queueName, // queue - // TODO: we need to provide a name that will help to identify the component - "", // consumer - true, // auto-ack - false, // exclusive - false, // no-local - false, // no-wait - nil, // args - ) - if consErr != nil { - cns.Logger.Error(). - Err(err). - Str("method", "consume"). - Str("queue", queueName). - Msg("Consuming error") - return consErr - } go func() { cns.Logger.Debug(). - Str("method", "consume"). + Str("method", methodName). Str("queue", queueName). Msg("start handling messages") - for d := range msgs { - timer := prometheus.NewTimer(th2RabbitmqMessageProcessDurationSeconds.WithLabelValues(th2Pin, th2Type, queueName)) + running := true + durationObserver := th2RabbitmqMessageProcessDurationSeconds.WithLabelValues(th2Pin, th2Type, queueName) + messageSizeObserver := th2RabbitmqMessageSizeSubscribeBytes.WithLabelValues(th2Pin, th2Type, queueName) + handleDelivery := func(d amqp.Delivery) { + timer := prometheus.NewTimer(durationObserver) cns.Logger.Trace(). Str("exchange", d.Exchange). Str("routing", d.RoutingKey). Int("bodySize", len(d.Body)). Msg("receive delivery") - if err := handler(d); err != nil { + if err := handler(d, timer); err != nil { cns.Logger.Error(). Err(err). Str("exchange", d.Exchange). @@ -119,11 +122,53 @@ func (cns *Consumer) Consume(queueName string, th2Pin string, th2Type string, ha Int("bodySize", len(d.Body)). Msg("Cannot handle delivery") } - timer.ObserveDuration() - th2RabbitmqMessageSizeSubscribeBytes.WithLabelValues(th2Pin, th2Type, queueName).Add(float64(len(d.Body))) + messageSizeObserver.Add(float64(len(d.Body))) + } + deliveries := msgs + chErrors := ch.NotifyClose(make(chan *amqp.Error)) + for running { + select { + case _, ok := <-cns.done: + if !ok { + running = false + // drain messages + for d := range deliveries { + handleDelivery(d) + } + } + case chErr, ok := <-chErrors: + if !ok { + break + } + cns.Logger.Error(). + Err(chErr). + Str("queue", queueName). + Msg("consumer error") + // drain messages + for d := range deliveries { + handleDelivery(d) + } + ch, deliveries, err = subscriptionProducer(queueName, methodName) + if err != nil { + if errors.Is(err, amqp.ErrClosed) { + break + } + // TODO: decide what to do in this case + panic(fmt.Sprintf("failure during consumer %s recovery: %v", methodName, err)) + } + chErrors = ch.NotifyClose(make(chan *amqp.Error)) + cns.Logger.Info(). + Str("queue", queueName). + Msg("consumer channel recovered") + case d, ok := <-deliveries: + if !ok { + break + } + handleDelivery(d) + } } cns.Logger.Debug(). - Str("method", "consume"). + Str("method", methodName). Str("queue", queueName). Msg("stop handling messages") }() @@ -131,55 +176,69 @@ func (cns *Consumer) Consume(queueName string, th2Pin string, th2Type string, ha return nil } -func (cns *Consumer) ConsumeWithManualAck(queueName string, th2Pin string, th2Type string, handler func(msgDelivery amqp.Delivery, timer *prometheus.Timer) error) error { - ch, err := cns.conn.Channel() - if err != nil { - cns.Logger.Error(). - Err(err). - Str("queue", queueName). - Msg("cannot open channel") - return err - } - cns.channels[queueName] = ch - msgs, consErr := ch.Consume( - queueName, // queue - "", // consumer - false, // auto-ack - false, // exclusive - false, // no-local - false, // no-wait - nil, // args - ) - if consErr != nil { - cns.Logger.Error(). - Err(err). - Str("method", "consumeWithManualAck"). - Str("queue", queueName). - Msg("Consuming error") - return consErr - } - go func() { - cns.Logger.Debug(). - Str("method", "consumeWithManualAck"). - Str("queue", queueName). - Msg("start handling messages") - for d := range msgs { - timer := prometheus.NewTimer(th2RabbitmqMessageProcessDurationSeconds.WithLabelValues(th2Pin, th2Type, queueName)) - if err := handler(d, timer); err != nil { +func (cns *Consumer) consumeWithManualAck(queueName string, methodName string) (*amqp.Channel, <-chan amqp.Delivery, error) { + return cns.consumeFromQueue(queueName, methodName, false) +} + +func (cns *Consumer) consumeWithAutoAck(queueName string, methodName string) (*amqp.Channel, <-chan amqp.Delivery, error) { + return cns.consumeFromQueue(queueName, methodName, true) +} + +func (cns *Consumer) consumeFromQueue(queueName string, methodName string, autoAck bool) (*amqp.Channel, <-chan amqp.Delivery, error) { + attempts := 0 + timeout := cns.minRecoveryTimeout + for { + ch, err := cns.getChannel(queueName) + if err != nil { + return nil, nil, err + } + + msgs, err := cns.startConsuming(ch, queueName, autoAck) + if err != nil { + var amqpErr *amqp.Error + isAmqpErr := errors.As(err, &amqpErr) + if !isAmqpErr || amqpErr.Code != amqp.NotFound { cns.Logger.Error(). Err(err). - Str("exchange", d.Exchange). - Str("routing", d.RoutingKey). - Int("bodySize", len(d.Body)). - Msg("cannot handle delivery") + Str("method", methodName). + Str("queue", queueName). + Msg("consuming error") + return nil, nil, err } - th2RabbitmqMessageSizeSubscribeBytes.WithLabelValues(th2Pin, th2Type, queueName).Add(float64(len(d.Body))) + if attempts >= cns.maxMissingQueueRecoveryAttempts { + return nil, nil, err + } + cns.Logger.Warn(). + Str("method", methodName). + Str("queue", queueName). + Int("attempts", attempts). + Dur("timeout", timeout). + Msg("queue is not found. Retry after timeout") + time.Sleep(timeout) + timeout *= 2 + if timeout > cns.maxRecoveryTimeout { + timeout = cns.maxRecoveryTimeout + } + attempts += 1 + continue } - cns.Logger.Debug(). - Str("method", "consumeWithManualAck"). - Str("queue", queueName). - Msg("stop handling messages") - }() + return ch, msgs, nil + } +} - return nil +func (cns *Consumer) startConsuming(ch *amqp.Channel, queueName string, autoAck bool) (<-chan amqp.Delivery, error) { + msgs, err := ch.Consume( + queueName, // queue + // TODO: we need to provide a name that will help to identify the component + "", // consumer + autoAck, // auto-ack + false, // exclusive + false, // no-local + false, // no-wait + nil, // args + ) + if err != nil { + return nil, err + } + return msgs, nil } diff --git a/pkg/queue/rabbitmq/internal/connection/consumer_test.go b/pkg/queue/rabbitmq/internal/connection/consumer_test.go index 9a27db4..b4e15f0 100644 --- a/pkg/queue/rabbitmq/internal/connection/consumer_test.go +++ b/pkg/queue/rabbitmq/internal/connection/consumer_test.go @@ -34,10 +34,11 @@ func TestConsumer_Consume(t *testing.T) { } config := rabbitmq.StartMq(t, "test") - manager, err := NewConnectionManager(config, consumerLogger) + manager, err := NewConnectionManager(config, "test", consumerLogger) if err != nil { t.Fatal(err) } + go manager.ListenForBlockingNotifications() defer manager.Close() conn, err := rabbitmq.RawAmqp(t, config, true) if err != nil { diff --git a/pkg/queue/rabbitmq/internal/connection/publisher.go b/pkg/queue/rabbitmq/internal/connection/publisher.go index ee32a4e..6ae7b18 100644 --- a/pkg/queue/rabbitmq/internal/connection/publisher.go +++ b/pkg/queue/rabbitmq/internal/connection/publisher.go @@ -16,14 +16,15 @@ package connection import ( + "context" "errors" - "sync" - + "fmt" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" amqp "github.com/rabbitmq/amqp091-go" "github.com/rs/zerolog" "github.com/th2-net/th2-common-go/pkg/metrics" + connCfg "github.com/th2-net/th2-common-go/pkg/queue/rabbitmq/connection" ) var th2RabbitmqMessageSizePublishBytes = promauto.NewCounterVec( @@ -43,37 +44,37 @@ var th2RabbitmqMessagePublishTotal = promauto.NewCounterVec( ) type Publisher struct { - url string - conn *amqp.Connection - channels map[string]*amqp.Channel - mutex *sync.Mutex - + *connectionHolder Logger zerolog.Logger } -func NewPublisher(url string, logger zerolog.Logger) (Publisher, error) { +func NewPublisher(url string, configuration connCfg.Config, componentName string, logger zerolog.Logger) (Publisher, error) { if url == "" { return Publisher{}, errors.New("url is not set") } - conn, err := amqp.Dial(url) + publisher := Publisher{ + Logger: logger, + } + c, err := newConnection(url, fmt.Sprintf("%s_publisher", componentName), + logger, configuration, nil, nil) if err != nil { - return Publisher{}, err + return publisher, err } + publisher.connectionHolder = c logger.Debug().Msg("Publisher connected") - return Publisher{ - url: url, - conn: conn, - channels: make(map[string]*amqp.Channel), - mutex: &sync.Mutex{}, - Logger: logger, - }, nil + return publisher, nil } func (pb *Publisher) Publish(body []byte, routingKey string, exchange string, th2Pin string, th2Type string) error { ch, err := pb.getChannel(routingKey) + if err != nil { + return err + } - publError := ch.Publish(exchange, routingKey, false, false, amqp.Publishing{Body: body}) + // Ideally, the context should be passed from outside + // but this is breaking API change and we cannot do that + publError := ch.PublishWithContext(context.Background(), exchange, routingKey, false, false, amqp.Publishing{Body: body}) if publError != nil { pb.Logger.Error().Err(publError).Send() return err @@ -85,36 +86,3 @@ func (pb *Publisher) Publish(body []byte, routingKey string, exchange string, th return nil } - -func (pb *Publisher) Close() error { - return pb.conn.Close() -} - -func (pb *Publisher) getChannel(routingKey string) (*amqp.Channel, error) { - var ch *amqp.Channel - var err error - var exists bool - ch, exists = pb.channels[routingKey] - if !exists { - ch, err = pb.getOrCreateChannel(routingKey) - } - - return ch, err -} - -func (pb *Publisher) getOrCreateChannel(routingKey string) (*amqp.Channel, error) { - pb.mutex.Lock() - defer pb.mutex.Unlock() - var ch *amqp.Channel - var err error - var exists bool - ch, exists = pb.channels[routingKey] - if !exists { - ch, err = pb.conn.Channel() - if err != nil { - return nil, err - } - pb.channels[routingKey] = ch - } - return ch, nil -} diff --git a/pkg/queue/rabbitmq/internal/connection/publisher_test.go b/pkg/queue/rabbitmq/internal/connection/publisher_test.go index eb6a458..bf7afcb 100644 --- a/pkg/queue/rabbitmq/internal/connection/publisher_test.go +++ b/pkg/queue/rabbitmq/internal/connection/publisher_test.go @@ -33,10 +33,11 @@ func TestPublisher(t *testing.T) { } config := rabbitmq.StartMq(t, "test") - manager, err := NewConnectionManager(config, publisherLogger) + manager, err := NewConnectionManager(config, "test", publisherLogger) if err != nil { t.Fatal(err) } + go manager.ListenForBlockingNotifications() defer manager.Close() conn, err := rabbitmq.RawAmqp(t, config, true) if err != nil { diff --git a/test/modules/internal/shared.go b/test/modules/internal/shared.go index f38ae35..033b490 100644 --- a/test/modules/internal/shared.go +++ b/test/modules/internal/shared.go @@ -35,6 +35,13 @@ type testProvider struct { fs fs.FS } +func (p testProvider) GetBoxConfig() common.BoxConfig { + return common.BoxConfig{ + Name: "test", + Book: "test_book", + } +} + func (p testProvider) GetConfig(resource string, target any) error { data, err := fs.ReadFile(p.fs, resource) if err != nil { diff --git a/test/modules/rabbitmq/event/router_test.go b/test/modules/rabbitmq/event/router_test.go index a77f84c..b01f18b 100644 --- a/test/modules/rabbitmq/event/router_test.go +++ b/test/modules/rabbitmq/event/router_test.go @@ -19,6 +19,7 @@ import ( "fmt" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + "github.com/th2-net/th2-common-go/pkg/common" grpcCommon "github.com/th2-net/th2-common-go/pkg/common/grpc/th2_grpc_common" "github.com/th2-net/th2-common-go/pkg/queue" "github.com/th2-net/th2-common-go/pkg/queue/rabbitmq" @@ -48,7 +49,7 @@ func TestEventRouterSendAll(t *testing.T) { conn.BindQueue(config, queue1, routingKey1) conn.BindQueue(config, queue2, routingKey2) - _, router, manager, err := rabbitmq.NewRouters(config, &queue.RouterConfig{ + _, router, manager, err := rabbitmq.NewRouters(common.BoxConfig{}, config, &queue.RouterConfig{ Queues: map[string]queue.DestinationConfig{ "publish-pin1": { Exchange: config.ExchangeName, @@ -96,7 +97,7 @@ func TestEventRouterSendAllReportErrorInNoPinMatch(t *testing.T) { } config := rabbitmqSupport.StartMq(t, "test") - _, router, manager, err := rabbitmq.NewRouters(config, &queue.RouterConfig{ + _, router, manager, err := rabbitmq.NewRouters(common.BoxConfig{}, config, &queue.RouterConfig{ Queues: map[string]queue.DestinationConfig{ "publish-pin1": { Exchange: config.ExchangeName, @@ -170,7 +171,7 @@ func TestEventRouterSubscribeAll(t *testing.T) { conn.BindQueue(config, queue2, key2) conn.BindQueue(config, queue3, key3) - _, router, manager, err := rabbitmq.NewRouters(config, &queue.RouterConfig{ + _, router, manager, err := rabbitmq.NewRouters(common.BoxConfig{}, config, &queue.RouterConfig{ Queues: map[string]queue.DestinationConfig{ "sub-pin1": { Exchange: config.ExchangeName, @@ -250,7 +251,7 @@ func TestEventRouterSubscribeAllWithManualAck(t *testing.T) { conn.BindQueue(config, queue2, key2) conn.BindQueue(config, queue3, key3) - _, router, manager, err := rabbitmq.NewRouters(config, &queue.RouterConfig{ + _, router, manager, err := rabbitmq.NewRouters(common.BoxConfig{}, config, &queue.RouterConfig{ Queues: map[string]queue.DestinationConfig{ "sub-pin1": { Exchange: config.ExchangeName, @@ -316,7 +317,7 @@ func TestEventRouterSubscribeAllReportErrorInNoPinMatch(t *testing.T) { } config := rabbitmqSupport.StartMq(t, "test") - _, router, manager, err := rabbitmq.NewRouters(config, &queue.RouterConfig{ + _, router, manager, err := rabbitmq.NewRouters(common.BoxConfig{}, config, &queue.RouterConfig{ Queues: map[string]queue.DestinationConfig{ "sub-pin": { Exchange: config.ExchangeName, @@ -372,7 +373,7 @@ func TestEventRouterSubscribeAllWithManualAckReportErrorInNoPinMatch(t *testing. } config := rabbitmqSupport.StartMq(t, "test") - _, router, manager, err := rabbitmq.NewRouters(config, &queue.RouterConfig{ + _, router, manager, err := rabbitmq.NewRouters(common.BoxConfig{}, config, &queue.RouterConfig{ Queues: map[string]queue.DestinationConfig{ "sub-pin": { Exchange: config.ExchangeName, diff --git a/test/modules/rabbitmq/message/reconnect_test.go b/test/modules/rabbitmq/message/reconnect_test.go new file mode 100644 index 0000000..8550709 --- /dev/null +++ b/test/modules/rabbitmq/message/reconnect_test.go @@ -0,0 +1,217 @@ +/* + * Copyright 2024 Exactpro (Exactpro Systems Limited) + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package message + +import ( + "context" + "fmt" + "github.com/stretchr/testify/assert" + "github.com/th2-net/th2-common-go/pkg/common" + "github.com/th2-net/th2-common-go/pkg/queue" + "github.com/th2-net/th2-common-go/pkg/queue/rabbitmq" + "github.com/th2-net/th2-common-go/pkg/queue/rabbitmq/connection" + rabbitmqSupport "github.com/th2-net/th2-common-go/test/modules/rabbitmq" + "io" + "testing" + "time" +) + +func TestPublisherReconnects(t *testing.T) { + if testing.Short() { + t.Skip("do not run containers in short run") + return + } + containerName := fmt.Sprintf("reconnect-test-%d", time.Now().UTC().UnixNano()) + ctx := context.Background() + port := fmt.Sprintf("%d:%s", 9900, rabbitmqSupport.MqPort) + container := rabbitmqSupport.CreateMqContainer(ctx, t, containerName, port) + err := container.Start(ctx) + if err != nil { + t.Fatal("cannot start container:", err) + } + t.Cleanup(func() { + err := container.Terminate(ctx) + if err != nil { + t.Logf("cannot terminate container: %v", err) + } + }) + config := rabbitmqSupport.GetConfigForContainer(ctx, t, container, "test") + + routingKey := setupMq(t, config) + + router, _, manager, err := rabbitmq.NewRouters(common.BoxConfig{}, config, &queue.RouterConfig{ + Queues: map[string]queue.DestinationConfig{ + "publish-pin1": { + Exchange: config.ExchangeName, + RoutingKey: routingKey, + Attributes: []string{"publish", "test", "unique"}, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + defer func(manager io.Closer) { + err := manager.Close() + if err != nil { + t.Logf("cannot close manager connection: %v", err) + } + }(manager) + + err = router.SendRawAll([]byte("hello"), "unique") + if err != nil { + t.Fatal("cannot send message:", err) + } + + err = container.Terminate(ctx) + if err != nil { + t.Fatal("cannot stop container:", err) + } + + // create new container with same port + // cannot Stop and Start container because of wait strategy + go func() { + time.Sleep(7 * time.Second) + container = rabbitmqSupport.CreateMqContainer(ctx, t, containerName, port) + err = container.Start(ctx) + if err != nil { + t.Error("cannot start container:", err) + return + } + _ = setupMq(t, config) + t.Log("rabbitmq container restarted") + }() + t.Log("sending messages after container restart") + err = router.SendRawAll([]byte("hello2"), "unique") + if err != nil { + t.Fatal("cannot send message:", err) + } + t.Log("message sent to queue after container restart") + + conn, err := rabbitmqSupport.RawAmqp(t, config, true) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + data := conn.Consume(conn.GetQueue(t, "test1")) + rabbitmqSupport.CheckReceiveBytes(t, data, []byte("hello2")) +} + +func TestConsumerReconnects(t *testing.T) { + if testing.Short() { + t.Skip("do not run containers in short run") + return + } + containerName := fmt.Sprintf("reconnect-test-%d", time.Now().UTC().UnixNano()) + ctx := context.Background() + port := fmt.Sprintf("%d:%s", 9900, rabbitmqSupport.MqPort) + container := rabbitmqSupport.CreateMqContainer(ctx, t, containerName, port) + err := container.Start(ctx) + if err != nil { + t.Fatal("cannot start container:", err) + } + t.Cleanup(func() { + err := container.Terminate(ctx) + if err != nil { + t.Logf("cannot terminate container: %v", err) + } + }) + config := rabbitmqSupport.GetConfigForContainer(ctx, t, container, "test") + + routingKey := setupMq(t, config) + + router, _, manager, err := rabbitmq.NewRouters(common.BoxConfig{}, config, &queue.RouterConfig{ + Queues: map[string]queue.DestinationConfig{ + "sub-pin1": { + Exchange: config.ExchangeName, + QueueName: "test1", + Attributes: []string{"subscribe", "test", "unique"}, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + defer func(manager io.Closer) { + err := manager.Close() + if err != nil { + t.Logf("cannot close manager connection: %v", err) + } + }(manager) + + channel := make(chan []byte, 1) + monitor, err := router.SubscribeRawAll(rabbitmqSupport.TestRawListener{ + Channel: channel, + }) + if err != nil { + t.Fatal("cannot subscribe message:", err) + } + defer func(monitor queue.Monitor) { + _ = monitor.Unsubscribe() + }(monitor) + rawConn, err := rabbitmqSupport.RawAmqp(t, config, false) + if err != nil { + t.Fatal("cannot get raw connection", err) + } + firstBytes := []byte("hello") + rawConn.Publish(config, routingKey, firstBytes) + actual := <-channel + assert.Equal(t, firstBytes, actual) + + err = container.Terminate(ctx) + if err != nil { + t.Fatal("cannot stop container:", err) + } + + // create new container with same port + // cannot Stop and Start container because of wait strategy + recovered := make(chan struct{}) + go func() { + container = rabbitmqSupport.CreateMqContainer(ctx, t, containerName, port) + err = container.Start(ctx) + if err != nil { + t.Error("cannot start container:", err) + return + } + // delay is added to check that consumer will be recovered even if queue does not yet exist + time.Sleep(5 * time.Second) + _ = setupMq(t, config) + t.Log("rabbitmq container restarted") + recovered <- struct{}{} + close(recovered) + }() + <-recovered + rawConn, err = rabbitmqSupport.RawAmqp(t, config, false) + if err != nil { + t.Fatal("cannot get raw connection", err) + } + secondBytes := []byte("hello2") + rawConn.Publish(config, routingKey, secondBytes) + actual = <-channel + assert.Equal(t, secondBytes, actual) +} + +func setupMq(t *testing.T, config connection.Config) string { + conn, err := rabbitmqSupport.RawAmqp(t, config, true) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + queue1 := conn.CreateQueue("test1") + routingKey1 := "test-publish1" + conn.BindQueue(config, queue1, routingKey1) + return routingKey1 +} diff --git a/test/modules/rabbitmq/message/router_test.go b/test/modules/rabbitmq/message/router_test.go index 6061e2a..751b449 100644 --- a/test/modules/rabbitmq/message/router_test.go +++ b/test/modules/rabbitmq/message/router_test.go @@ -19,6 +19,7 @@ import ( "fmt" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + "github.com/th2-net/th2-common-go/pkg/common" grpcCommon "github.com/th2-net/th2-common-go/pkg/common/grpc/th2_grpc_common" "github.com/th2-net/th2-common-go/pkg/queue" "github.com/th2-net/th2-common-go/pkg/queue/rabbitmq" @@ -49,7 +50,7 @@ func TestMessageRouterSendAll(t *testing.T) { conn.BindQueue(config, queue1, routingKey1) conn.BindQueue(config, queue2, routingKey2) - router, _, manager, err := rabbitmq.NewRouters(config, &queue.RouterConfig{ + router, _, manager, err := rabbitmq.NewRouters(common.BoxConfig{}, config, &queue.RouterConfig{ Queues: map[string]queue.DestinationConfig{ "publish-pin1": { Exchange: config.ExchangeName, @@ -97,7 +98,7 @@ func TestMessageRouterSendAllReportErrorInNoPinMatch(t *testing.T) { } config := rabbitmqSupport.StartMq(t, "test") - router, _, manager, err := rabbitmq.NewRouters(config, &queue.RouterConfig{ + router, _, manager, err := rabbitmq.NewRouters(common.BoxConfig{}, config, &queue.RouterConfig{ Queues: map[string]queue.DestinationConfig{ "publish-pin1": { Exchange: config.ExchangeName, @@ -158,7 +159,7 @@ func TestMessageRouterSendRaw(t *testing.T) { conn.BindQueue(config, queue1, routingKey1) conn.BindQueue(config, queue2, routingKey2) - router, _, manager, err := rabbitmq.NewRouters(config, &queue.RouterConfig{ + router, _, manager, err := rabbitmq.NewRouters(common.BoxConfig{}, config, &queue.RouterConfig{ Queues: map[string]queue.DestinationConfig{ "publish-pin1": { Exchange: config.ExchangeName, @@ -206,7 +207,7 @@ func TestMessageRouterSendRawReportErrorInNoPinMatch(t *testing.T) { } config := rabbitmqSupport.StartMq(t, "test") - router, _, manager, err := rabbitmq.NewRouters(config, &queue.RouterConfig{ + router, _, manager, err := rabbitmq.NewRouters(common.BoxConfig{}, config, &queue.RouterConfig{ Queues: map[string]queue.DestinationConfig{ "publish-pin1": { Exchange: config.ExchangeName, @@ -269,7 +270,7 @@ func TestMessageRouterSubscribeAll(t *testing.T) { conn.BindQueue(config, queue2, key2) conn.BindQueue(config, queue3, key3) - router, _, manager, err := rabbitmq.NewRouters(config, &queue.RouterConfig{ + router, _, manager, err := rabbitmq.NewRouters(common.BoxConfig{}, config, &queue.RouterConfig{ Queues: map[string]queue.DestinationConfig{ "sub-pin1": { Exchange: config.ExchangeName, @@ -349,7 +350,7 @@ func TestMessageRouterSubscribeAllWithAck(t *testing.T) { conn.BindQueue(config, queue2, key2) conn.BindQueue(config, queue3, key3) - router, _, manager, err := rabbitmq.NewRouters(config, &queue.RouterConfig{ + router, _, manager, err := rabbitmq.NewRouters(common.BoxConfig{}, config, &queue.RouterConfig{ Queues: map[string]queue.DestinationConfig{ "sub-pin1": { Exchange: config.ExchangeName, @@ -415,7 +416,7 @@ func TestMessageRouterSubscribeAllReportErrorInNoPinMatch(t *testing.T) { } config := rabbitmqSupport.StartMq(t, "test") - router, _, manager, err := rabbitmq.NewRouters(config, &queue.RouterConfig{ + router, _, manager, err := rabbitmq.NewRouters(common.BoxConfig{}, config, &queue.RouterConfig{ Queues: map[string]queue.DestinationConfig{ "sub-pin": { Exchange: config.ExchangeName, @@ -462,7 +463,7 @@ func TestMessageRouterSubscribeAllWithManualAckReportErrorInNoPinMatch(t *testin } config := rabbitmqSupport.StartMq(t, "test") - router, _, manager, err := rabbitmq.NewRouters(config, &queue.RouterConfig{ + router, _, manager, err := rabbitmq.NewRouters(common.BoxConfig{}, config, &queue.RouterConfig{ Queues: map[string]queue.DestinationConfig{ "sub-pin": { Exchange: config.ExchangeName, diff --git a/test/modules/rabbitmq/shared.go b/test/modules/rabbitmq/shared.go index 038cb0b..acbbe45 100644 --- a/test/modules/rabbitmq/shared.go +++ b/test/modules/rabbitmq/shared.go @@ -31,7 +31,7 @@ import ( const ( containerName = "rabbitmq-connection-test" - mqPort = "5672" + MqPort = "5672" TestBook = "test_book" TestScope = "test_scope" ) @@ -42,31 +42,22 @@ func StartMq(t *testing.T, exchange string) connection.Config { func StartMqWithContainerName(t *testing.T, containerName, exchange string) connection.Config { ctx := context.Background() - req := testcontainers.ContainerRequest{ - Name: containerName, - Image: "rabbitmq:3.10", - ExposedPorts: []string{mqPort}, - WaitingFor: wait.ForLog("Server startup complete"), - } - rabbit, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ - ContainerRequest: req, - Started: true, - Reuse: containerName != "", - }) - if err != nil { - t.Fatal("cannot create container", err) - } + rabbit := CreateMqContainer(ctx, t, containerName, MqPort) t.Cleanup(func() { err := rabbit.Terminate(ctx) if err != nil { t.Logf("cannot stop rabbitmq container: %v", err) } }) + return GetConfigForContainer(ctx, t, rabbit, exchange) +} + +func GetConfigForContainer(ctx context.Context, t *testing.T, rabbit testcontainers.Container, exchange string) connection.Config { host, err := rabbit.Host(ctx) if err != nil { t.Fatal(err) } - port, err := rabbit.MappedPort(ctx, mqPort) + port, err := rabbit.MappedPort(ctx, MqPort) if err != nil { t.Fatal(err) } @@ -79,6 +70,24 @@ func StartMqWithContainerName(t *testing.T, containerName, exchange string) conn } } +func CreateMqContainer(ctx context.Context, t *testing.T, containerName string, port string) testcontainers.Container { + req := testcontainers.ContainerRequest{ + Name: containerName, + Image: "rabbitmq:3.10", + ExposedPorts: []string{port}, + WaitingFor: wait.ForLog("Server startup complete"), + } + rabbit, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + Reuse: containerName != "", + }) + if err != nil { + t.Fatal("cannot create container", err) + } + return rabbit +} + type RawAmqpHolder struct { conn *amqp.Connection ch *amqp.Channel @@ -155,7 +164,7 @@ func (h RawAmqpHolder) BindQueue(connCfg connection.Config, queue amqp.Queue, bi func (h RawAmqpHolder) Publish(connCfg connection.Config, routingKey string, data []byte) { - err := h.ch.Publish(connCfg.ExchangeName, routingKey, true, false, amqp.Publishing{ + err := h.ch.PublishWithContext(context.Background(), connCfg.ExchangeName, routingKey, true, false, amqp.Publishing{ Body: data, }) if err != nil { @@ -163,6 +172,21 @@ func (h RawAmqpHolder) Publish(connCfg connection.Config, routingKey string, dat } } +func (h RawAmqpHolder) GetQueue(t *testing.T, name string) amqp.Queue { + q, err := h.ch.QueueDeclarePassive( + name, + false, // durable + false, // auto delete + false, // exclusive + false, + nil, + ) + if err != nil { + t.Fatal("cannot get queue", err) + } + return q +} + func (h RawAmqpHolder) Consume(queue amqp.Queue) <-chan amqp.Delivery { deliveries, err := h.ch.Consume( @@ -176,6 +200,19 @@ func (h RawAmqpHolder) Consume(queue amqp.Queue) <-chan amqp.Delivery { return deliveries } +type TestRawListener struct { + Channel chan []byte +} + +func (t TestRawListener) OnClose() error { + return nil +} + +func (t TestRawListener) Handle(delivery queue.Delivery, data []byte) error { + t.Channel <- data + return nil +} + type GenericListener[T any] struct { Channel chan *T }