diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 68bb2e155..8cf85b3fc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -51,7 +51,7 @@ jobs: creds: ${{ secrets.AZURE_ACI_CREDENTIALS }} enable-AzPSSession: true - name: Setup RabbitMQ - uses: Particular/setup-rabbitmq-action@v1.6.0 + uses: Particular/setup-rabbitmq-action@v1.7.0 with: connection-string-name: RabbitMQTransport_ConnectionString tag: RabbitMQTransport diff --git a/src/NServiceBus.Transport.RabbitMQ.CommandLine/NServiceBus.Transport.RabbitMQ.CommandLine.csproj b/src/NServiceBus.Transport.RabbitMQ.CommandLine/NServiceBus.Transport.RabbitMQ.CommandLine.csproj index ca29a6252..14392b34f 100644 --- a/src/NServiceBus.Transport.RabbitMQ.CommandLine/NServiceBus.Transport.RabbitMQ.CommandLine.csproj +++ b/src/NServiceBus.Transport.RabbitMQ.CommandLine/NServiceBus.Transport.RabbitMQ.CommandLine.csproj @@ -23,6 +23,10 @@ + + + + diff --git a/src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ChannelProviderTests.cs b/src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ChannelProviderTests.cs new file mode 100644 index 000000000..4916f7a28 --- /dev/null +++ b/src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ChannelProviderTests.cs @@ -0,0 +1,174 @@ +namespace NServiceBus.Transport.RabbitMQ.Tests.ConnectionString +{ + using System; + using System.Collections.Generic; + using System.Threading; + using System.Threading.Tasks; + using global::RabbitMQ.Client; + using global::RabbitMQ.Client.Events; + using NUnit.Framework; + + [TestFixture] + public class ChannelProviderTests + { + [Test] + public async Task Should_recover_connection_and_dispose_old_one_when_connection_shutdown() + { + var channelProvider = new TestableChannelProvider(); + channelProvider.CreateConnection(); + + var publishConnection = channelProvider.PublishConnections.Dequeue(); + publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test")); + + channelProvider.DelayTaskCompletionSource.SetResult(true); + + await channelProvider.FireAndForgetAction(CancellationToken.None); + + var recoveredConnection = channelProvider.PublishConnections.Dequeue(); + + Assert.That(publishConnection.WasDisposed, Is.True); + Assert.That(recoveredConnection.WasDisposed, Is.False); + } + + [Test] + public void Should_dispose_connection_when_disposed() + { + var channelProvider = new TestableChannelProvider(); + channelProvider.CreateConnection(); + + var publishConnection = channelProvider.PublishConnections.Dequeue(); + channelProvider.Dispose(); + + Assert.That(publishConnection.WasDisposed, Is.True); + } + + [Test] + public async Task Should_not_attempt_to_recover_during_dispose_when_retry_delay_still_pending() + { + var channelProvider = new TestableChannelProvider(); + channelProvider.CreateConnection(); + + var publishConnection = channelProvider.PublishConnections.Dequeue(); + publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test")); + + // Deliberately not completing the delay task with channelProvider.DelayTaskCompletionSource.SetResult(); before disposing + // to simulate a pending delay task + channelProvider.Dispose(); + + await channelProvider.FireAndForgetAction(CancellationToken.None); + + Assert.That(publishConnection.WasDisposed, Is.True); + Assert.That(channelProvider.PublishConnections, Has.Count.Zero); + } + + [Test] + public async Task Should_dispose_newly_established_connection() + { + var channelProvider = new TestableChannelProvider(); + channelProvider.CreateConnection(); + + var publishConnection = channelProvider.PublishConnections.Dequeue(); + publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test")); + + // This simulates the race of the reconnection loop being fired off with the delay task completed during + // the disposal of the channel provider. To achieve that it is necessary to kick off the reconnection loop + // and await its completion after the channel provider has been disposed. + var fireAndForgetTask = channelProvider.FireAndForgetAction(CancellationToken.None); + channelProvider.DelayTaskCompletionSource.SetResult(true); + channelProvider.Dispose(); + + await fireAndForgetTask; + + var recoveredConnection = channelProvider.PublishConnections.Dequeue(); + + Assert.That(publishConnection.WasDisposed, Is.True); + Assert.That(recoveredConnection.WasDisposed, Is.True); + } + + class TestableChannelProvider : ChannelProvider + { + public TestableChannelProvider() : base(null, TimeSpan.Zero, null) + { + } + + public Queue PublishConnections { get; } = new Queue(); + + public TaskCompletionSource DelayTaskCompletionSource { get; } = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + public Func FireAndForgetAction { get; private set; } + + protected override IConnection CreatePublishConnection() + { + var connection = new FakeConnection(); + PublishConnections.Enqueue(connection); + return connection; + } + + protected override void FireAndForget(Func action, CancellationToken cancellationToken = default) + => FireAndForgetAction = _ => action(cancellationToken); + + protected override async Task DelayReconnect(CancellationToken cancellationToken = default) + { + using (var _ = cancellationToken.Register(() => DelayTaskCompletionSource.TrySetCanceled(cancellationToken))) + { + await DelayTaskCompletionSource.Task; + } + } + } + + class FakeConnection : IConnection + { + public int LocalPort { get; } + public int RemotePort { get; } + + public void Dispose() => WasDisposed = true; + + public bool WasDisposed { get; private set; } + + public void UpdateSecret(string newSecret, string reason) => throw new NotImplementedException(); + + public void Abort() => throw new NotImplementedException(); + + public void Abort(ushort reasonCode, string reasonText) => throw new NotImplementedException(); + + public void Abort(TimeSpan timeout) => throw new NotImplementedException(); + + public void Abort(ushort reasonCode, string reasonText, TimeSpan timeout) => throw new NotImplementedException(); + + public void Close() => throw new NotImplementedException(); + + public void Close(ushort reasonCode, string reasonText) => throw new NotImplementedException(); + + public void Close(TimeSpan timeout) => throw new NotImplementedException(); + + public void Close(ushort reasonCode, string reasonText, TimeSpan timeout) => throw new NotImplementedException(); + + public IModel CreateModel() => throw new NotImplementedException(); + + public void HandleConnectionBlocked(string reason) => throw new NotImplementedException(); + + public void HandleConnectionUnblocked() => throw new NotImplementedException(); + + public ushort ChannelMax { get; } + public IDictionary ClientProperties { get; } + public ShutdownEventArgs CloseReason { get; } + public AmqpTcpEndpoint Endpoint { get; } + public uint FrameMax { get; } + public TimeSpan Heartbeat { get; } + public bool IsOpen { get; } + public AmqpTcpEndpoint[] KnownHosts { get; } + public IProtocol Protocol { get; } + public IDictionary ServerProperties { get; } + public IList ShutdownReport { get; } + public string ClientProvidedName { get; } = $"FakeConnection{Interlocked.Increment(ref connectionCounter)}"; + public event EventHandler CallbackException = (sender, args) => { }; + public event EventHandler ConnectionBlocked = (sender, args) => { }; + public event EventHandler ConnectionShutdown = (sender, args) => { }; + public event EventHandler ConnectionUnblocked = (sender, args) => { }; + + public void RaiseConnectionShutdown(ShutdownEventArgs args) => ConnectionShutdown?.Invoke(this, args); + + static int connectionCounter; + } + } +} \ No newline at end of file diff --git a/src/NServiceBus.Transport.RabbitMQ.Tests/ConnectionString/ConnectionConfigurationTests.cs b/src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ConnectionConfigurationTests.cs similarity index 100% rename from src/NServiceBus.Transport.RabbitMQ.Tests/ConnectionString/ConnectionConfigurationTests.cs rename to src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ConnectionConfigurationTests.cs diff --git a/src/NServiceBus.Transport.RabbitMQ.Tests/ConnectionString/ConnectionConfigurationWithAmqpTests.cs b/src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ConnectionConfigurationWithAmqpTests.cs similarity index 100% rename from src/NServiceBus.Transport.RabbitMQ.Tests/ConnectionString/ConnectionConfigurationWithAmqpTests.cs rename to src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ConnectionConfigurationWithAmqpTests.cs diff --git a/src/NServiceBus.Transport.RabbitMQ/Connection/ChannelProvider.cs b/src/NServiceBus.Transport.RabbitMQ/Connection/ChannelProvider.cs index 4400cd55a..6f7851902 100644 --- a/src/NServiceBus.Transport.RabbitMQ/Connection/ChannelProvider.cs +++ b/src/NServiceBus.Transport.RabbitMQ/Connection/ChannelProvider.cs @@ -2,11 +2,12 @@ namespace NServiceBus.Transport.RabbitMQ { using System; using System.Collections.Concurrent; + using System.Threading; using System.Threading.Tasks; using global::RabbitMQ.Client; using Logging; - sealed class ChannelProvider : IDisposable + class ChannelProvider : IDisposable { public ChannelProvider(ConnectionFactory connectionFactory, TimeSpan retryDelay, IRoutingTopology routingTopology) { @@ -18,46 +19,73 @@ public ChannelProvider(ConnectionFactory connectionFactory, TimeSpan retryDelay, channels = new ConcurrentQueue(); } - public void CreateConnection() + public void CreateConnection() => connection = CreateConnectionWithShutdownListener(); + + protected virtual IConnection CreatePublishConnection() => connectionFactory.CreatePublishConnection(); + + IConnection CreateConnectionWithShutdownListener() { - connection = connectionFactory.CreatePublishConnection(); - connection.ConnectionShutdown += Connection_ConnectionShutdown; + var newConnection = CreatePublishConnection(); + newConnection.ConnectionShutdown += Connection_ConnectionShutdown; + return newConnection; } void Connection_ConnectionShutdown(object sender, ShutdownEventArgs e) { - if (e.Initiator != ShutdownInitiator.Application) + if (e.Initiator == ShutdownInitiator.Application || sender is null) { - var connection = (IConnection)sender; - - _ = Task.Run(() => Reconnect(connection.ClientProvidedName)); + return; } + + var connectionThatWasShutdown = (IConnection)sender; + + FireAndForget(cancellationToken => ReconnectSwallowingExceptions(connectionThatWasShutdown.ClientProvidedName, cancellationToken), stoppingTokenSource.Token); } - async Task Reconnect(string connectionName) + async Task ReconnectSwallowingExceptions(string connectionName, CancellationToken cancellationToken) { - var reconnected = false; - - while (!reconnected) + while (!cancellationToken.IsCancellationRequested) { Logger.InfoFormat("'{0}': Attempting to reconnect in {1} seconds.", connectionName, retryDelay.TotalSeconds); - await Task.Delay(retryDelay).ConfigureAwait(false); - try { - CreateConnection(); - reconnected = true; + await DelayReconnect(cancellationToken).ConfigureAwait(false); + + var newConnection = CreateConnectionWithShutdownListener(); - Logger.InfoFormat("'{0}': Connection to the broker reestablished successfully.", connectionName); + // A race condition is possible where CreatePublishConnection is invoked during Dispose + // where the returned connection isn't disposed so invoking Dispose to be sure + if (cancellationToken.IsCancellationRequested) + { + newConnection.Dispose(); + break; + } + + var oldConnection = Interlocked.Exchange(ref connection, newConnection); + oldConnection?.Dispose(); + break; } - catch (Exception e) + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) { - Logger.InfoFormat("'{0}': Reconnecting to the broker failed: {1}", connectionName, e); + Logger.InfoFormat("'{0}': Stopped trying to reconnecting to the broker due to shutdown", connectionName); + break; + } + catch (Exception ex) + { + Logger.InfoFormat("'{0}': Reconnecting to the broker failed: {1}", connectionName, ex); } } + + Logger.InfoFormat("'{0}': Connection to the broker reestablished successfully.", connectionName); } + protected virtual void FireAndForget(Func action, CancellationToken cancellationToken = default) => + // Task.Run() so the call returns immediately instead of waiting for the first await or return down the call stack + _ = Task.Run(() => action(cancellationToken), CancellationToken.None); + + protected virtual Task DelayReconnect(CancellationToken cancellationToken = default) => Task.Delay(retryDelay, cancellationToken); + public ConfirmsAwareChannel GetPublishChannel() { if (!channels.TryDequeue(out var channel) || channel.IsClosed) @@ -84,19 +112,32 @@ public void ReturnPublishChannel(ConfirmsAwareChannel channel) public void Dispose() { - connection?.Dispose(); + if (disposed) + { + return; + } + + stoppingTokenSource.Cancel(); + stoppingTokenSource.Dispose(); + + var oldConnection = Interlocked.Exchange(ref connection, null); + oldConnection?.Dispose(); foreach (var channel in channels) { channel.Dispose(); } + + disposed = true; } readonly ConnectionFactory connectionFactory; readonly TimeSpan retryDelay; readonly IRoutingTopology routingTopology; readonly ConcurrentQueue channels; - IConnection connection; + readonly CancellationTokenSource stoppingTokenSource = new CancellationTokenSource(); + volatile IConnection connection; + bool disposed; static readonly ILog Logger = LogManager.GetLogger(typeof(ChannelProvider)); }