From 334306700da93615902ff77b1ba9b8c67381a08e Mon Sep 17 00:00:00 2001 From: Kerry Jiang Date: Fri, 30 Aug 2024 16:16:30 -0700 Subject: [PATCH] fixed bugs about empty package encoding --- src/SuperSocket.WebSocket/WebSocketEncoder.cs | 45 ++++++++++--------- .../WebSocketMaskedEncoder.cs | 9 +++- .../WebSocket/WebSocketBasicTest.cs | 1 + 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/src/SuperSocket.WebSocket/WebSocketEncoder.cs b/src/SuperSocket.WebSocket/WebSocketEncoder.cs index 3e9a9861a..feb821634 100644 --- a/src/SuperSocket.WebSocket/WebSocketEncoder.cs +++ b/src/SuperSocket.WebSocket/WebSocketEncoder.cs @@ -69,6 +69,7 @@ protected virtual Span WriteHead(IBufferWriter writer, long length, if (length < _size0) { headLen = 2; + var head = writer.GetSpan(headLen); head[1] = (byte)length; @@ -229,43 +230,45 @@ protected virtual int GetFragmentTotalLength(int headLen, int bodyLen) private int EncodeFinalFragment(IBufferWriter writer, byte opCode, ReadOnlySpan text, Encoder encoder, ArraySegment unwrittenBytes) { byte[] buffer = default; + Span bufferSpan = default; object encodingContext = default; try { - // writer should not be touched for now, because head has not been written yet. - encodingContext = CreateDataEncodingContext(null); - var totalWritten = 0; - Span bufferSpan = default; + if (encoder != null) + { + // writer should not be touched for now, because head has not been written yet. + encodingContext = CreateDataEncodingContext(null); - var fragementSize = (text.Length > 0 ? encoder.GetByteCount(text, true) : 0) + unwrittenBytes.Count; + var fragementSize = (text.Length > 0 ? encoder.GetByteCount(text, true) : 0) + unwrittenBytes.Count; - if (fragementSize == 0) - fragementSize = _minEncodeBufferSize; + if (fragementSize == 0) + fragementSize = _minEncodeBufferSize; - buffer = _bufferPool.Rent(fragementSize); + buffer = _bufferPool.Rent(fragementSize); - bufferSpan = buffer.AsSpan(); + bufferSpan = buffer.AsSpan(); - if (unwrittenBytes.Count > 0) - { - unwrittenBytes.AsSpan().CopyTo(bufferSpan); - totalWritten += unwrittenBytes.Count; - OnDataEncoded(bufferSpan.Slice(0, unwrittenBytes.Count), encodingContext, 0); - } + if (unwrittenBytes.Count > 0) + { + unwrittenBytes.AsSpan().CopyTo(bufferSpan); + totalWritten += unwrittenBytes.Count; + OnDataEncoded(bufferSpan.Slice(0, unwrittenBytes.Count), encodingContext, 0); + } - encoder.Convert(text, totalWritten == 0 ? bufferSpan : bufferSpan.Slice(totalWritten), true, out var charsUsed, out var bytesUsed, out bool completed); + encoder.Convert(text, totalWritten == 0 ? bufferSpan : bufferSpan.Slice(totalWritten), true, out var charsUsed, out var bytesUsed, out bool completed); - OnDataEncoded(bufferSpan.Slice(totalWritten, bytesUsed), encodingContext, totalWritten); + OnDataEncoded(bufferSpan.Slice(totalWritten, bytesUsed), encodingContext, totalWritten); - totalWritten += bytesUsed; + totalWritten += bytesUsed; - if (!completed || text.Length != charsUsed) - { - throw new ProtocolException("Unexpected encoding behavior: the text encoding didn't complete with enough buffer."); + if (!completed || text.Length != charsUsed) + { + throw new ProtocolException("Unexpected encoding behavior: the text encoding didn't complete with enough buffer."); + } } opCode = (byte)(opCode | 0x80); diff --git a/src/SuperSocket.WebSocket/WebSocketMaskedEncoder.cs b/src/SuperSocket.WebSocket/WebSocketMaskedEncoder.cs index 529ec4aa9..d3bd6c669 100644 --- a/src/SuperSocket.WebSocket/WebSocketMaskedEncoder.cs +++ b/src/SuperSocket.WebSocket/WebSocketMaskedEncoder.cs @@ -42,7 +42,11 @@ protected override object CreateDataEncodingContext(IBufferWriter writer) protected override Span WriteHead(IBufferWriter writer, long length, out int headLen) { var head = base.WriteHead(writer, length, out headLen); - head[1] = (byte)(head[1] | 0x80); + + // We don't mask data for empty package + if (length > 0) + head[1] = (byte)(head[1] | 0x80); + return head; } @@ -50,6 +54,9 @@ protected override void OnHeadEncoded(IBufferWriter writer, object encodin { var maskingContext = encodingContext as MaskingContext; + if (maskingContext == null) + return; + // Means mask buffer was allocated from writter if (maskingContext.MaskBuffer == null) { diff --git a/test/SuperSocket.Tests/WebSocket/WebSocketBasicTest.cs b/test/SuperSocket.Tests/WebSocket/WebSocketBasicTest.cs index 1f45c901a..2fb5ebdab 100644 --- a/test/SuperSocket.Tests/WebSocket/WebSocketBasicTest.cs +++ b/test/SuperSocket.Tests/WebSocket/WebSocketBasicTest.cs @@ -84,6 +84,7 @@ private string GetTestMessage(int messageSize, bool nonAsciiText) [InlineData(10, true, 8)] [InlineData(16, true, 8)] [InlineData(17, true, 8)] + [InlineData(0, false, 8)] public void TestWebSocketMaskEncoder(int messageLength, bool nonAsciiText, int fragmentSize) { var websocketEncoder = new WebSocketMaskedEncoder(ArrayPool.Shared, new int[] { fragmentSize });