diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 68bb2e155..8cf85b3fc 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -51,7 +51,7 @@ jobs:
creds: ${{ secrets.AZURE_ACI_CREDENTIALS }}
enable-AzPSSession: true
- name: Setup RabbitMQ
- uses: Particular/setup-rabbitmq-action@v1.6.0
+ uses: Particular/setup-rabbitmq-action@v1.7.0
with:
connection-string-name: RabbitMQTransport_ConnectionString
tag: RabbitMQTransport
diff --git a/src/NServiceBus.Transport.RabbitMQ.CommandLine/NServiceBus.Transport.RabbitMQ.CommandLine.csproj b/src/NServiceBus.Transport.RabbitMQ.CommandLine/NServiceBus.Transport.RabbitMQ.CommandLine.csproj
index ca29a6252..14392b34f 100644
--- a/src/NServiceBus.Transport.RabbitMQ.CommandLine/NServiceBus.Transport.RabbitMQ.CommandLine.csproj
+++ b/src/NServiceBus.Transport.RabbitMQ.CommandLine/NServiceBus.Transport.RabbitMQ.CommandLine.csproj
@@ -23,6 +23,10 @@
+
+
+
+
diff --git a/src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ChannelProviderTests.cs b/src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ChannelProviderTests.cs
new file mode 100644
index 000000000..4916f7a28
--- /dev/null
+++ b/src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ChannelProviderTests.cs
@@ -0,0 +1,174 @@
+namespace NServiceBus.Transport.RabbitMQ.Tests.ConnectionString
+{
+ using System;
+ using System.Collections.Generic;
+ using System.Threading;
+ using System.Threading.Tasks;
+ using global::RabbitMQ.Client;
+ using global::RabbitMQ.Client.Events;
+ using NUnit.Framework;
+
+ [TestFixture]
+ public class ChannelProviderTests
+ {
+ [Test] // A shutdown that was not initiated by the application must replace the publish connection and dispose the one that was shut down.
+ public async Task Should_recover_connection_and_dispose_old_one_when_connection_shutdown()
+ {
+ var channelProvider = new TestableChannelProvider();
+ channelProvider.CreateConnection();
+
+ var publishConnection = channelProvider.PublishConnections.Dequeue(); // the connection created by CreateConnection()
+ publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test")); // initiator is not Application, so the provider hands a reconnect loop to FireAndForget
+
+ channelProvider.DelayTaskCompletionSource.SetResult(true); // lets the faked retry delay complete
+
+ await channelProvider.FireAndForgetAction(CancellationToken.None); // runs the captured reconnect loop; the fake ignores this argument and uses the token it captured
+
+ var recoveredConnection = channelProvider.PublishConnections.Dequeue(); // the connection created by the reconnect loop
+
+ Assert.That(publishConnection.WasDisposed, Is.True);
+ Assert.That(recoveredConnection.WasDisposed, Is.False);
+ }
+
+ [Test] // Disposing the provider must dispose the publish connection it currently holds.
+ public void Should_dispose_connection_when_disposed()
+ {
+ var channelProvider = new TestableChannelProvider();
+ channelProvider.CreateConnection();
+
+ var publishConnection = channelProvider.PublishConnections.Dequeue(); // the connection created by CreateConnection()
+ channelProvider.Dispose();
+
+ Assert.That(publishConnection.WasDisposed, Is.True);
+ }
+
+ [Test] // Once the provider has been disposed, a scheduled reconnect must not create a replacement connection.
+ public async Task Should_not_attempt_to_recover_during_dispose_when_retry_delay_still_pending()
+ {
+ var channelProvider = new TestableChannelProvider();
+ channelProvider.CreateConnection();
+
+ var publishConnection = channelProvider.PublishConnections.Dequeue();
+ publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test")); // schedules (captures) the reconnect loop
+
+ // Deliberately NOT calling channelProvider.DelayTaskCompletionSource.SetResult(true) before disposing,
+ // so the (faked) retry delay is still pending when the provider is disposed.
+ channelProvider.Dispose(); // cancels the stopping token that the captured reconnect loop observes
+
+ await channelProvider.FireAndForgetAction(CancellationToken.None); // the loop sees the cancelled token and ends without creating a connection
+
+ Assert.That(publishConnection.WasDisposed, Is.True);
+ Assert.That(channelProvider.PublishConnections, Has.Count.Zero); // the only connection ever created was dequeued above
+ }
+
+ [Test] // A connection that the reconnect loop creates while the provider is being disposed must not be leaked.
+ public async Task Should_dispose_newly_established_connection()
+ {
+ var channelProvider = new TestableChannelProvider();
+ channelProvider.CreateConnection();
+
+ var publishConnection = channelProvider.PublishConnections.Dequeue();
+ publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test"));
+
+ // This simulates the race of the reconnection loop being fired off with the delay task completed during
+ // the disposal of the channel provider. To achieve that it is necessary to kick off the reconnection loop
+ // and await its completion after the channel provider has been disposed.
+ var fireAndForgetTask = channelProvider.FireAndForgetAction(CancellationToken.None); // starts the loop; it runs until it awaits the pending delay
+ channelProvider.DelayTaskCompletionSource.SetResult(true); // continuation is queued asynchronously (RunContinuationsAsynchronously), so the loop does not resume inline here
+ channelProvider.Dispose();
+
+ await fireAndForgetTask;
+
+ var recoveredConnection = channelProvider.PublishConnections.Dequeue(); // the connection created by the reconnect loop
+
+ Assert.That(publishConnection.WasDisposed, Is.True);
+ Assert.That(recoveredConnection.WasDisposed, Is.True); // disposed either by the loop's cancellation check or by Dispose(), whichever wins the race
+ }
+
+ class TestableChannelProvider : ChannelProvider // ChannelProvider with its connection creation, background scheduling and retry delay replaced by test-controlled fakes
+ {
+ public TestableChannelProvider() : base(null, TimeSpan.Zero, null) // null factory and topology: these tests only call CreateConnection/Dispose, and CreatePublishConnection is overridden below
+ {
+ }
+
+ public Queue PublishConnections { get; } = new Queue(); // every fake connection handed out, in creation order
+
+ public TaskCompletionSource DelayTaskCompletionSource { get; } = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); // completing this ends the faked retry delay
+
+ public Func FireAndForgetAction { get; private set; } // the most recently scheduled background action; tests invoke it explicitly
+
+ protected override IConnection CreatePublishConnection()
+ {
+ var connection = new FakeConnection();
+ PublishConnections.Enqueue(connection); // record it so tests can inspect WasDisposed later
+ return connection;
+ }
+
+ protected override void FireAndForget(Func action, CancellationToken cancellationToken = default)
+ => FireAndForgetAction = _ => action(cancellationToken); // captures instead of running; the lambda ignores its own argument and passes the token given to FireAndForget
+
+ protected override async Task DelayReconnect(CancellationToken cancellationToken = default)
+ {
+ using (var _ = cancellationToken.Register(() => DelayTaskCompletionSource.TrySetCanceled(cancellationToken))) // cancellation ends the pending delay, like a cancelled Task.Delay
+ {
+ await DelayTaskCompletionSource.Task;
+ }
+ }
+ }
+
+ class FakeConnection : IConnection // minimal stub: only Dispose, ClientProvidedName and ConnectionShutdown do anything; the other methods throw NotImplementedException
+ {
+ public int LocalPort { get; }
+ public int RemotePort { get; }
+
+ public void Dispose() => WasDisposed = true; // records the call only; nothing is released
+
+ public bool WasDisposed { get; private set; } // true once Dispose() has been called
+
+ public void UpdateSecret(string newSecret, string reason) => throw new NotImplementedException();
+
+ public void Abort() => throw new NotImplementedException();
+
+ public void Abort(ushort reasonCode, string reasonText) => throw new NotImplementedException();
+
+ public void Abort(TimeSpan timeout) => throw new NotImplementedException();
+
+ public void Abort(ushort reasonCode, string reasonText, TimeSpan timeout) => throw new NotImplementedException();
+
+ public void Close() => throw new NotImplementedException();
+
+ public void Close(ushort reasonCode, string reasonText) => throw new NotImplementedException();
+
+ public void Close(TimeSpan timeout) => throw new NotImplementedException();
+
+ public void Close(ushort reasonCode, string reasonText, TimeSpan timeout) => throw new NotImplementedException();
+
+ public IModel CreateModel() => throw new NotImplementedException();
+
+ public void HandleConnectionBlocked(string reason) => throw new NotImplementedException();
+
+ public void HandleConnectionUnblocked() => throw new NotImplementedException();
+
+ public ushort ChannelMax { get; } // this and the following get-only properties are never assigned, so they return their type's default value
+ public IDictionary ClientProperties { get; }
+ public ShutdownEventArgs CloseReason { get; }
+ public AmqpTcpEndpoint Endpoint { get; }
+ public uint FrameMax { get; }
+ public TimeSpan Heartbeat { get; }
+ public bool IsOpen { get; }
+ public AmqpTcpEndpoint[] KnownHosts { get; }
+ public IProtocol Protocol { get; }
+ public IDictionary ServerProperties { get; }
+ public IList ShutdownReport { get; }
+ public string ClientProvidedName { get; } = $"FakeConnection{Interlocked.Increment(ref connectionCounter)}"; // distinct name per instance, taken from the shared counter
+ public event EventHandler CallbackException = (sender, args) => { }; // events start with a no-op handler, so they are non-null before anyone subscribes
+ public event EventHandler ConnectionBlocked = (sender, args) => { };
+ public event EventHandler ConnectionShutdown = (sender, args) => { };
+ public event EventHandler ConnectionUnblocked = (sender, args) => { };
+
+ public void RaiseConnectionShutdown(ShutdownEventArgs args) => ConnectionShutdown?.Invoke(this, args); // test hook: raises the event with this fake as sender
+
+ static int connectionCounter; // shared across all instances; incremented atomically for ClientProvidedName
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/NServiceBus.Transport.RabbitMQ.Tests/ConnectionString/ConnectionConfigurationTests.cs b/src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ConnectionConfigurationTests.cs
similarity index 100%
rename from src/NServiceBus.Transport.RabbitMQ.Tests/ConnectionString/ConnectionConfigurationTests.cs
rename to src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ConnectionConfigurationTests.cs
diff --git a/src/NServiceBus.Transport.RabbitMQ.Tests/ConnectionString/ConnectionConfigurationWithAmqpTests.cs b/src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ConnectionConfigurationWithAmqpTests.cs
similarity index 100%
rename from src/NServiceBus.Transport.RabbitMQ.Tests/ConnectionString/ConnectionConfigurationWithAmqpTests.cs
rename to src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ConnectionConfigurationWithAmqpTests.cs
diff --git a/src/NServiceBus.Transport.RabbitMQ/Connection/ChannelProvider.cs b/src/NServiceBus.Transport.RabbitMQ/Connection/ChannelProvider.cs
index 4400cd55a..6f7851902 100644
--- a/src/NServiceBus.Transport.RabbitMQ/Connection/ChannelProvider.cs
+++ b/src/NServiceBus.Transport.RabbitMQ/Connection/ChannelProvider.cs
@@ -2,11 +2,12 @@ namespace NServiceBus.Transport.RabbitMQ
{
using System;
using System.Collections.Concurrent;
+ using System.Threading;
using System.Threading.Tasks;
using global::RabbitMQ.Client;
using Logging;
- sealed class ChannelProvider : IDisposable
+ class ChannelProvider : IDisposable
{
public ChannelProvider(ConnectionFactory connectionFactory, TimeSpan retryDelay, IRoutingTopology routingTopology)
{
@@ -18,46 +19,73 @@ public ChannelProvider(ConnectionFactory connectionFactory, TimeSpan retryDelay,
channels = new ConcurrentQueue();
}
- public void CreateConnection()
+ public void CreateConnection() => connection = CreateConnectionWithShutdownListener(); // creates the publish connection and stores it in the connection field
+
+ protected virtual IConnection CreatePublishConnection() => connectionFactory.CreatePublishConnection(); // virtual so a subclass can supply its own connection (the unit tests return a fake)
+
+ IConnection CreateConnectionWithShutdownListener() // creates a publish connection and subscribes the shutdown handler; does not touch the connection field
{
- connection = connectionFactory.CreatePublishConnection();
- connection.ConnectionShutdown += Connection_ConnectionShutdown;
+ var newConnection = CreatePublishConnection();
+ newConnection.ConnectionShutdown += Connection_ConnectionShutdown; // a later shutdown of this connection can trigger the reconnect loop
+ return newConnection;
}
void Connection_ConnectionShutdown(object sender, ShutdownEventArgs e) // ConnectionShutdown handler: schedules a background reconnect unless the application itself shut the connection down
{
- if (e.Initiator != ShutdownInitiator.Application)
+ if (e.Initiator == ShutdownInitiator.Application || sender is null) // nothing to do for an application-initiated shutdown, or when there is no sender to read the connection name from
{
- var connection = (IConnection)sender;
-
- _ = Task.Run(() => Reconnect(connection.ClientProvidedName));
+ return;
}
+
+ var connectionThatWasShutdown = (IConnection)sender;
+
+ FireAndForget(cancellationToken => ReconnectSwallowingExceptions(connectionThatWasShutdown.ClientProvidedName, cancellationToken), stoppingTokenSource.Token); // Dispose() cancels this token, which ends the loop. NOTE(review): .Token throws ObjectDisposedException once the source is disposed - confirm no non-application shutdown can arrive after Dispose()
}
- async Task Reconnect(string connectionName)
+ // Reconnect loop, run in the background after a connection was shut down by something other than the application.
+ // Each attempt waits for DelayReconnect and then creates a replacement publish connection. A failed attempt is
+ // logged and retried; exceptions raised by an attempt are swallowed because the caller fires this and forgets it.
+ // The loop ends when a replacement connection is in place or when cancellationToken is cancelled (Dispose does that).
+ // connectionName is the name of the connection that was shut down and is only used in log messages.
+ async Task ReconnectSwallowingExceptions(string connectionName, CancellationToken cancellationToken)
{
- var reconnected = false;
-
- while (!reconnected)
+ while (!cancellationToken.IsCancellationRequested)
{
Logger.InfoFormat("'{0}': Attempting to reconnect in {1} seconds.", connectionName, retryDelay.TotalSeconds);
- await Task.Delay(retryDelay).ConfigureAwait(false);
-
try
{
- CreateConnection();
- reconnected = true;
+ await DelayReconnect(cancellationToken).ConfigureAwait(false);
+
+ var newConnection = CreateConnectionWithShutdownListener();
- Logger.InfoFormat("'{0}': Connection to the broker reestablished successfully.", connectionName);
+ // Dispose may have run while the new connection was being created. In that case Dispose never saw
+ // this connection and will not dispose it, so it has to be disposed here.
+ if (cancellationToken.IsCancellationRequested)
+ {
+ newConnection.Dispose();
+ break;
+ }
+
+ // Swap atomically so that exactly one party (this loop or Dispose) gets hold of, and disposes, the old connection.
+ var oldConnection = Interlocked.Exchange(ref connection, newConnection);
+ oldConnection?.Dispose();
+
+ // Report success only here, where a replacement connection is actually in place. Logging it after the
+ // loop would also announce a successful reconnect whenever the loop ends because of cancellation.
+ Logger.InfoFormat("'{0}': Connection to the broker reestablished successfully.", connectionName);
+ break;
}
- catch (Exception e)
+ catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
- Logger.InfoFormat("'{0}': Reconnecting to the broker failed: {1}", connectionName, e);
+ Logger.InfoFormat("'{0}': Stopped trying to reconnect to the broker due to shutdown", connectionName);
+ break;
+ }
+ catch (Exception ex)
+ {
+ // Any other failure: log it and go around the loop for another delayed attempt.
+ Logger.InfoFormat("'{0}': Reconnecting to the broker failed: {1}", connectionName, ex);
}
}
}
+ protected virtual void FireAndForget(Func action, CancellationToken cancellationToken = default) => // starts action in the background and discards the resulting task; virtual so tests can capture the action instead
+ // Task.Run() makes this call return immediately; invoking action directly would run it synchronously up to its first await (or its return)
+ _ = Task.Run(() => action(cancellationToken), CancellationToken.None); // Task.Run itself gets CancellationToken.None; cancellationToken is only passed on to action
+
+ protected virtual Task DelayReconnect(CancellationToken cancellationToken = default) => Task.Delay(retryDelay, cancellationToken); // wait between reconnect attempts; cancellable, and virtual so tests can control when it completes
+
public ConfirmsAwareChannel GetPublishChannel()
{
if (!channels.TryDequeue(out var channel) || channel.IsClosed)
@@ -84,19 +112,32 @@ public void ReturnPublishChannel(ConfirmsAwareChannel channel)
public void Dispose() // cancels reconnect attempts, then disposes the current connection and the queued channels
{
- connection?.Dispose();
+ if (disposed) // a later call after a completed Dispose() does nothing; NOTE(review): the flag is a plain bool set at the end, so concurrent calls are not guarded - confirm that is acceptable
+ {
+ return;
+ }
+
+ stoppingTokenSource.Cancel(); // tells a running reconnect loop to stop
+ stoppingTokenSource.Dispose();
+
+ var oldConnection = Interlocked.Exchange(ref connection, null); // atomically take the connection out of the field; the reconnect loop swaps the same field
+ oldConnection?.Dispose();
foreach (var channel in channels)
{
channel.Dispose();
}
+
+ disposed = true;
}
readonly ConnectionFactory connectionFactory;
readonly TimeSpan retryDelay;
readonly IRoutingTopology routingTopology;
readonly ConcurrentQueue channels;
- IConnection connection;
+ readonly CancellationTokenSource stoppingTokenSource = new CancellationTokenSource(); // cancelled and disposed by Dispose(); its token is handed to the reconnect loop
+ volatile IConnection connection; // current publish connection; swapped with Interlocked.Exchange by the reconnect loop and by Dispose()
+ bool disposed; // set at the end of Dispose() so that later calls return early
static readonly ILog Logger = LogManager.GetLogger(typeof(ChannelProvider));
}