diff --git a/src/ArtemisNetCoreClient/Connection.cs b/src/ArtemisNetCoreClient/Connection.cs index 2ce0723..e8cb4ee 100644 --- a/src/ArtemisNetCoreClient/Connection.cs +++ b/src/ArtemisNetCoreClient/Connection.cs @@ -1,5 +1,6 @@ using System.Buffers; using System.Collections.Concurrent; +using System.Diagnostics; using ActiveMQ.Artemis.Core.Client.Framing; using Microsoft.Extensions.Logging; @@ -16,7 +17,7 @@ internal class Connection : IConnection, IChannel private readonly ConcurrentDictionary> _completionSources = new(); private readonly SemaphoreSlim _lock = new(1, 1); private readonly IdGenerator _sessionChannelIdGenerator = new(10); - private volatile bool _disposed; + private readonly CancellationTokenSource _receiveLoopCancellationToken; public Connection(ILoggerFactory loggerFactory, Transport2 transport, Endpoint endpoint) { @@ -26,16 +27,17 @@ public Connection(ILoggerFactory loggerFactory, Transport2 transport, Endpoint e _endpoint = endpoint; _channels.TryAdd(1, this); - _receiveLoopTask = Task.Factory.StartNew(ReceiveLoop, TaskCreationOptions.LongRunning); + _receiveLoopCancellationToken = new CancellationTokenSource(); + _receiveLoopTask = Task.Run(ReceiveLoop); } - private void ReceiveLoop() + private async Task ReceiveLoop() { - while (_disposed == false) + while (_receiveLoopCancellationToken.IsCancellationRequested == false) { try { - var inboundPacket = _transport.ReceivePacket(); + var inboundPacket = await _transport.ReceivePacketAsync(_receiveLoopCancellationToken.Token); try { if (_channels.TryGetValue(inboundPacket.ChannelId, out var channel)) @@ -49,15 +51,19 @@ private void ReceiveLoop() } finally { - ArrayPool.Shared.Return(inboundPacket.Payload.Array!); + if (inboundPacket.Payload.Array is { } array) + { + ArrayPool.Shared.Return(array); + } } } + catch (OperationCanceledException) + { + // Ignore + } catch (IOException e) { - if (!_disposed) - { - _logger.LogError(e, "Error in network communication"); - } + _logger.LogError(e, "Error in network communication"); } } } @@ -147,7 +153,8 @@ internal void RemoveChannel(long channelId) public async ValueTask DisposeAsync() { - _disposed = true; + await _receiveLoopCancellationToken.CancelAsync(); + _receiveLoopCancellationToken.Dispose(); await _transport.DisposeAsync(); await _receiveLoopTask; } diff --git a/src/ArtemisNetCoreClient/Transport2.cs b/src/ArtemisNetCoreClient/Transport2.cs index 09cc449..4611d51 100644 --- a/src/ArtemisNetCoreClient/Transport2.cs +++ b/src/ArtemisNetCoreClient/Transport2.cs @@ -75,19 +75,19 @@ public async ValueTask DisposeAsync() _socket.Dispose(); } - internal InboundPacket ReceivePacket() + internal async ValueTask ReceivePacketAsync(CancellationToken cancellationToken) { - var (frameSize, packetType, channelId) = ReadHeader(); - var payloadSize = frameSize - sizeof(byte) - sizeof(long); + var header = await ReadHeaderAsync(cancellationToken); + var payloadSize = header.FrameSize - sizeof(byte) - sizeof(long); var buffer = ArrayPool.Shared.Rent(payloadSize); try { - _reader.ReadExactly(buffer.AsSpan(0, payloadSize)); + await _reader.ReadExactlyAsync(buffer, 0, payloadSize, cancellationToken); return new InboundPacket { - PacketType = packetType, - ChannelId = channelId, + PacketType = header.PacketType, + ChannelId = header.ChannelId, Payload = new ArraySegment(buffer, 0, payloadSize) }; } @@ -97,23 +97,43 @@ internal InboundPacket ReceivePacket() throw; } } - - private (int frameSize, PacketType packetType, long channelId) ReadHeader() + + private const int HeaderSize = sizeof(int) + sizeof(byte) + sizeof(long); + private async ValueTask
ReadHeaderAsync(CancellationToken cancellationToken) { - Span headerBuffer = stackalloc byte[sizeof(int) + sizeof(byte) + sizeof(long)]; - _reader.ReadExactly(headerBuffer); - - var readBytes = ArtemisBinaryConverter.ReadInt32(headerBuffer, out var frameSize); - readBytes += ArtemisBinaryConverter.ReadByte(headerBuffer[readBytes..], out var packetType); - readBytes += ArtemisBinaryConverter.ReadInt64(headerBuffer[readBytes..], out var channelId); - - Debug.Assert(readBytes == headerBuffer.Length, $"Expected to read {headerBuffer.Length} bytes but got {readBytes}"); + var buffer = ArrayPool.Shared.Rent(Header.HeaderSize); + try + { + await _reader.ReadExactlyAsync(buffer, 0, HeaderSize, cancellationToken); + return new Header(buffer); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } +} + +internal readonly struct Header +{ + public const int HeaderSize = sizeof(int) + sizeof(byte) + sizeof(long); - return (frameSize, (PacketType) packetType, channelId); + public Header(ReadOnlySpan buffer) + { + var readBytes = ArtemisBinaryConverter.ReadInt32(buffer, out FrameSize); + readBytes += ArtemisBinaryConverter.ReadByte(buffer[readBytes..], out var packetType); + PacketType = (PacketType) packetType; + readBytes += ArtemisBinaryConverter.ReadInt64(buffer[readBytes..], out ChannelId); + + Debug.Assert(readBytes == HeaderSize, $"Expected to read {HeaderSize} bytes but got {readBytes}"); } + + public readonly int FrameSize; + public readonly long ChannelId; + public readonly PacketType PacketType; } -internal readonly ref struct InboundPacket +internal readonly struct InboundPacket { public long ChannelId { get; init; } public PacketType PacketType { get; init; }