Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TH2-5181] Add rabbitmq connection and consumer recovery #16

Merged
merged 10 commits into from
Apr 28, 2024
1 change: 1 addition & 0 deletions pkg/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type Module interface {
type ModuleKey string

type ConfigProvider interface {
GetBoxConfig() BoxConfig
GetConfig(resourceName string, target any) error
}

Expand Down
6 changes: 1 addition & 5 deletions pkg/factory/commonFactory.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,10 @@ func NewFromConfig(config Config) (common.Factory, error) {
config.FileExtension,
log.ForComponent("file_provider"),
)
var boxConfig common.BoxConfig
if err := provider.GetConfig("box", &boxConfig); err != nil {
log.Global().Warn().Err(err).Msg("cannot read box configuration")
}
cf := &commonFactory{
modules: make(map[common.ModuleKey]common.Module),
cfgProvider: provider,
boxConfig: boxConfig,
boxConfig: provider.GetBoxConfig(),
}
err := cf.Register(prometheus.NewModule)
if err != nil {
Expand Down
14 changes: 13 additions & 1 deletion pkg/factory/configProvider.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"errors"
"github.com/rs/zerolog"
"github.com/th2-net/th2-common-go/pkg/common"
"github.com/th2-net/th2-common-go/pkg/log"
"io/fs"
"os"
)
Expand All @@ -37,17 +38,28 @@ func NewFileProvider(configPath string, extension string, logger zerolog.Logger)
}

func NewFileProviderForFS(fs fs.FS, extension string, logger zerolog.Logger) common.ConfigProvider {
return &fileConfigProvider{
provider := fileConfigProvider{
configFS: fs,
fileExtension: extension,
zLogger: &logger,
}
boxConfig := common.BoxConfig{}
if err := provider.GetConfig("box", &boxConfig); err != nil {
log.Global().Warn().Err(err).Msg("cannot read box configuration. user default values")
}
provider.boxConfig = boxConfig
return &provider
}

type fileConfigProvider struct {
configFS fs.FS
fileExtension string
zLogger *zerolog.Logger
boxConfig common.BoxConfig
}

func (cfd *fileConfigProvider) GetBoxConfig() common.BoxConfig {
return cfd.boxConfig
}

func (cfd *fileConfigProvider) GetConfig(resourceName string, target any) error {
Expand Down
6 changes: 4 additions & 2 deletions pkg/modules/queue/rabbitmq.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,21 @@ func newRabbitMq(
provider common.ConfigProvider,
queueConfiguration queue.RouterConfig,
) (Module, error) {
boxConfig := provider.GetBoxConfig()
connConfiguration := connection.Config{}
configErr := provider.GetConfig(connectionConfigFilename, &connConfiguration)
if configErr != nil {
return nil, configErr
}
return NewRabbitMq(connConfiguration, queueConfiguration)
return NewRabbitMq(boxConfig, connConfiguration, queueConfiguration)
}

func NewRabbitMq(
boxConfig common.BoxConfig,
connConfiguration connection.Config,
queueConfiguration queue.RouterConfig,
) (Module, error) {
messageRouter, eventRouter, manager, err := rabbitmq.NewRouters(connConfiguration, &queueConfiguration)
messageRouter, eventRouter, manager, err := rabbitmq.NewRouters(boxConfig, connConfiguration, &queueConfiguration)
if err != nil {
return nil, err
}
Expand Down
5 changes: 4 additions & 1 deletion pkg/queue/rabbitmq/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package rabbitmq

import (
"github.com/rs/zerolog"
"github.com/th2-net/th2-common-go/pkg/common"
"github.com/th2-net/th2-common-go/pkg/log"
"github.com/th2-net/th2-common-go/pkg/queue"
"github.com/th2-net/th2-common-go/pkg/queue/event"
Expand All @@ -29,13 +30,15 @@ import (
)

func NewRouters(
boxConfig common.BoxConfig,
connection connection.Config,
config *queue.RouterConfig,
) (messageRouter message.Router, eventRouter event.Router, closer io.Closer, err error) {
manager, err := internal.NewConnectionManager(connection, log.ForComponent("connection_manager"))
manager, err := internal.NewConnectionManager(connection, boxConfig.Name, log.ForComponent("connection_manager"))
if err != nil {
return
}
go manager.ListenForBlockingNotifications()
messageRouter = newMessageRouter(&manager, config, log.ForComponent("message_router"))
eventRouter = newEventRouter(&manager, config, log.ForComponent("event_router"))
closer = &manager
Expand Down
257 changes: 257 additions & 0 deletions pkg/queue/rabbitmq/internal/connection/connection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
/*
* Copyright 2024 Exactpro (Exactpro Systems Limited)
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package connection

import (
"errors"
amqp "github.com/rabbitmq/amqp091-go"
"github.com/rs/zerolog"
"github.com/th2-net/th2-common-go/pkg/queue/rabbitmq/connection"
"sync"
"time"
)

const (
defaultMinRecoveryTimeout = 1 * time.Second
defaultMaxRecoveryTimeout = 60 * time.Second
// defaultMaxRecoveryAttempts used in case an error with status NOT_FOUND is returned from channel
defaultMaxRecoveryAttempts = 5
)

type connectionHolder struct {
connMutex sync.RWMutex
conn *amqp.Connection
channels map[string]*amqp.Channel
done chan struct{}
reconnectToMq func() (*amqp.Connection, error)
onConnectionRecovered func()
onChannelRecovered func(channelKey string)
logger zerolog.Logger
notifyMutex sync.Mutex
notifyRecovered []chan struct{}
minRecoveryTimeout time.Duration
maxRecoveryTimeout time.Duration
}

func newConnection(url string, name string, logger zerolog.Logger,
configuration connection.Config,
onConnectionRecovered func(), onChannelRecovered func(channelKey string)) (*connectionHolder, error) {
if configuration.MinConnectionRecoveryTimeout > configuration.MaxConnectionRecoveryTimeout {
return nil, errors.New("min connection recovery timeout is greater than max connection recovery timeout")
}
var minRecoveryTimeout time.Duration
var maxRecoveryTimeout time.Duration
if configuration.MinConnectionRecoveryTimeout > 0 {
minRecoveryTimeout = time.Duration(configuration.MinConnectionRecoveryTimeout) * time.Millisecond
} else {
minRecoveryTimeout = defaultMinRecoveryTimeout
OptimumCode marked this conversation as resolved.
Show resolved Hide resolved
}
if configuration.MaxConnectionRecoveryTimeout > 0 {
maxRecoveryTimeout = time.Duration(configuration.MaxConnectionRecoveryTimeout) * time.Millisecond
} else {
maxRecoveryTimeout = defaultMaxRecoveryTimeout
}
conn, err := dial(url, name)
if err != nil {
return nil, err
}
return &connectionHolder{
connMutex: sync.RWMutex{},
conn: conn,
channels: make(map[string]*amqp.Channel),
done: make(chan struct{}),
reconnectToMq: func() (*amqp.Connection, error) {
return dial(url, name)
},
onConnectionRecovered: onConnectionRecovered,
onChannelRecovered: onChannelRecovered,
logger: logger,
notifyMutex: sync.Mutex{},
notifyRecovered: make([]chan struct{}, 0),
minRecoveryTimeout: minRecoveryTimeout,
maxRecoveryTimeout: maxRecoveryTimeout,
}, nil
}

func (c *connectionHolder) runConnectionRoutine() {
run := true
connectionClosed := true
var connectionErrors chan *amqp.Error
for run {
if connectionClosed {
connectionClosed = false
c.connMutex.RLock()
connectionErrors = c.conn.NotifyClose(make(chan *amqp.Error))
c.connMutex.RUnlock()
}
select {
case <-c.done:
c.logger.Info().
Msg("stopping connection routine")
run = false
break
case connErr, ok := <-connectionErrors:
if !ok {
// normal close
run = false
break
}
connectionClosed = true
c.logger.Error().
Err(connErr).
Msg("received connection error. reconnecting")
c.tryToReconnect()
if c.onConnectionRecovered != nil {
c.onConnectionRecovered()
}
c.notifyMutex.Lock()
for _, ch := range c.notifyRecovered {
close(ch)
}
c.notifyRecovered = c.notifyRecovered[:0]
c.notifyMutex.Unlock()
}
}
}

func (c *connectionHolder) tryToReconnect() {
var delay = c.minRecoveryTimeout
for {
err := c.reconnect()
if err == nil {
c.logger.Info().
Msg("connection to rabbitmq restored")
break
}
c.logger.Error().
Err(err).
Dur("timeout", delay).
Msg("reconnect failed. retrying after timeout")
time.Sleep(delay)
delay *= 2
if delay > c.maxRecoveryTimeout {
delay = c.maxRecoveryTimeout
}
}
}

func (c *connectionHolder) reconnect() (err error) {
c.connMutex.Lock()
defer c.connMutex.Unlock()
conn := c.conn
if conn != nil {
_ = conn.Close()
// clear map with channels
c.channels = make(map[string]*amqp.Channel)
}
conn, err = c.reconnectToMq()
if err == nil {
c.conn = conn
}
return
}

func (c *connectionHolder) registerBlockingListener(blocking chan amqp.Blocking) <-chan amqp.Blocking {
c.connMutex.RLock()
defer c.connMutex.RUnlock()
return c.conn.NotifyBlocked(blocking)
}

func dial(url string, name string) (*amqp.Connection, error) {
properties := amqp.NewConnectionProperties()
properties.SetClientConnectionName(name)
conn, err := amqp.DialConfig(url, amqp.Config{
Heartbeat: 30 * time.Second,
Locale: "en_US",
Properties: properties,
})
return conn, err
}

func (c *connectionHolder) Close() error {
close(c.done)
c.connMutex.RLock()
defer c.connMutex.RUnlock()
return c.conn.Close()
}

func (c *connectionHolder) waitRecovered(ch chan struct{}) <-chan struct{} {
c.connMutex.RLock()
if !c.conn.IsClosed() {
close(ch)
c.connMutex.RUnlock()
return ch
}
c.connMutex.RUnlock()

c.notifyMutex.Lock()
c.notifyRecovered = append(c.notifyRecovered, ch)
c.notifyMutex.Unlock()
return ch
}

func (c *connectionHolder) getChannel(key string) (*amqp.Channel, error) {
var ch *amqp.Channel
var err error
var exists bool
<-c.waitRecovered(make(chan struct{}))
c.connMutex.RLock()
ch, exists = c.channels[key]
c.connMutex.RUnlock()
if !exists {
ch, err = c.getOrCreateChannel(key)
}

return ch, err
}

func (c *connectionHolder) getOrCreateChannel(key string) (*amqp.Channel, error) {
c.connMutex.Lock()
defer c.connMutex.Unlock()
var ch *amqp.Channel
var err error
var exists bool
ch, exists = c.channels[key]
if !exists {
ch, err = c.conn.Channel()
if err != nil {
return nil, err
}
c.channels[key] = ch
go func(ch *amqp.Channel) {
select {
case err, ok := <-ch.NotifyClose(make(chan *amqp.Error)):
if !ok {
break
}
c.connMutex.Lock()
c.logger.Warn().
Err(err).
Str("channelKey", key).
Msg("removing cached channel")
delete(c.channels, key)
c.connMutex.Unlock()
if c.onChannelRecovered != nil {
c.onChannelRecovered(key)
}
case <-c.done:
// closed. do nothing
}

}(ch)
}
return ch, nil
}
Loading
Loading