Skip to content

Commit

Permalink
tried to support websocket encoding with mask
Browse files Browse the repository at this point in the history
  • Loading branch information
kerryjiang committed Aug 8, 2024
1 parent e3ff071 commit 91433fc
Show file tree
Hide file tree
Showing 2 changed files with 288 additions and 100 deletions.
267 changes: 167 additions & 100 deletions src/SuperSocket.WebSocket/WebSocketEncoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,74 @@ namespace SuperSocket.WebSocket
public class WebSocketEncoder : IPackageEncoder<WebSocketPackage>
{
private static readonly Encoding _textEncoding = new UTF8Encoding(false);

private const int _size0 = 126;
private const int _size1 = 65536;

private readonly int[] _fragmentSizes;

private readonly ArrayPool<byte> _bufferPool;

protected ArrayPool<byte> BufferPool => _bufferPool;

public IReadOnlyList<IWebSocketExtension> Extensions { get; set; }

private int WriteHead(ref Span<byte> head, byte opCode, long length)
private static int[] _defaultFragmentSizes = new int []
{
1024,
1024 * 4,
1024 * 8,
1024 * 16,
1024 * 32,
1024 * 64
};

public WebSocketEncoder()
: this(ArrayPool<byte>.Shared, _defaultFragmentSizes)
{
}

public WebSocketEncoder(ArrayPool<byte> bufferPool, int[] fragmentSizes)
{
_bufferPool = bufferPool;
_fragmentSizes = fragmentSizes;
}

protected virtual int WriteHead(IBufferWriter<byte> writer, byte opCode, long length)
{
var head = WriteHead(writer, opCode, length, out var headLen);
head[0] = opCode;
writer.Advance(headLen);
return headLen;
}

private Span<byte> WriteHead(IBufferWriter<byte> writer, byte opCode, long length, out int headLen)
{
if (length < _size0)
{
headLen = 2;
var head = writer.GetSpan(headLen);
head[1] = (byte)length;
return 2;

return head;
}
else if (length < _size1)
{
headLen = 4;

var head = writer.GetSpan(headLen);

head[1] = (byte)_size0;
head[2] = (byte)(length / 256);
head[3] = (byte)(length % 256);
return 4;
return head;
}
else
{
headLen = 10;

var head = writer.GetSpan(headLen);

head[1] = (byte)127;

long left = length;
Expand All @@ -49,142 +95,165 @@ private int WriteHead(ref Span<byte> head, byte opCode, long length)
left = left / unit;
}

return 10;
return head;
}
}

private int EncodeEmptyFragment(IBufferWriter<byte> writer, byte opCode, int expectedHeadLength)
private int EncodeEmptyFragment(IBufferWriter<byte> writer, byte opCode)
{
return EncodeSingleFragment(writer, opCode, expectedHeadLength, default);
return EncodeFinalFragment(writer, opCode, ReadOnlySpan<char>.Empty, null, out var _);
}

private int EncodeFragment(IBufferWriter<byte> writer, byte opCode, int expectedHeadLength, int fragmentSize, ReadOnlySpan<char> text, Encoder encoder, out int charsUsed)
private int EncodeFragment(IBufferWriter<byte> writer, byte opCode, int fragmentSize, ReadOnlySpan<char> text, Encoder encoder, out int charsUsed)
{
charsUsed = 0;

var head = writer.GetSpan(expectedHeadLength);
var headLen = WriteHead(writer, opCode, fragmentSize);

writer.Advance(expectedHeadLength);
var encodingContext = CreateDataEncodingContext(writer);

var buffer = writer.GetSpan(fragmentSize).Slice(0, fragmentSize);

encoder.Convert(text, buffer, false, out charsUsed, out int bytesUsed, out bool completed);
writer.Advance(bytesUsed);
OnHeadEncoded(writer, encodingContext);

var totalBytes = bytesUsed;
var isFinal = completed && text.Length == charsUsed;
var totalBytes = 0;
var dataSizeToWrite = fragmentSize;

if (isFinal)
opCode = (byte)(opCode | 0x80);
while (dataSizeToWrite > 0)
{
var buffer = writer.GetSpan(dataSizeToWrite);
var bufferToWrite = Math.Min(buffer.Length, dataSizeToWrite);
buffer = buffer[..bufferToWrite];

encoder.Convert(text, buffer, false, out charsUsed, out var bytesUsed, out var completed);

buffer = buffer[..bytesUsed];

OnDataEncoded(buffer, encodingContext, totalBytes);
writer.Advance(bytesUsed);

WriteHead(ref head, opCode, totalBytes);
totalBytes += bytesUsed;

return totalBytes + expectedHeadLength;
dataSizeToWrite -= charsUsed;

if (totalBytes > fragmentSize)
{
throw new Exception("Size of the data from the decoding must be equal to the fragment size.");
}
}

return GetFragmentTotalLength(headLen, totalBytes);
}

protected virtual object CreateDataEncodingContext(IBufferWriter<byte> writer)
{
return null;
}

private int EncodeFragmentWithBuffer(IBufferWriter<byte> writer, byte opCode, int fragmentSize, ReadOnlySpan<char> text, Encoder encoder, out int charsUsed)
protected virtual void OnHeadEncoded(IBufferWriter<byte> writer, object encodingContext)
{
}

protected virtual void OnDataEncoded(Span<byte> encodedData, object encodingContext, int previusEncodedDataSize)
{
}

protected virtual void CleanupEncodingContext(object encodingContext)
{
}

protected virtual int GetFragmentTotalLength(int headLen, int bodyLen)
{
return headLen + bodyLen;
}

private int EncodeFinalFragment(IBufferWriter<byte> writer, byte opCode, ReadOnlySpan<char> text, Encoder encoder, out int charsUsed)
{
charsUsed = 0;

var bufferPool = ArrayPool<byte>.Shared;
var buffer = bufferPool.Rent(fragmentSize);
byte[] buffer = default;

object encodingContext = default;

try
{
var bufferSpan = buffer.AsSpan().Slice(0, fragmentSize);

encoder.Convert(text, bufferSpan, false, out charsUsed, out int bytesUsed, out bool completed);
// writer should not be touched for now, because head has not been written yet.
encodingContext = CreateDataEncodingContext(null);

var totalBytes = bytesUsed;
var isFinal = completed && text.Length == charsUsed;
var bytesUsed = 0;
Span<byte> bufferSpan = default;

if (isFinal)
opCode = (byte)(opCode | 0x80);
if (text.Length > 0)
{
var bufferSize = text.Length <= _size0 ? _size0 : _fragmentSizes[0];

var headLen = bytesUsed < _size0 ? 2 : 4;
buffer = _bufferPool.Rent(bufferSize);

var head = writer.GetSpan(headLen);
bufferSpan = buffer.AsSpan();

WriteHead(ref head, opCode, totalBytes);
writer.Advance(headLen);
encoder.Convert(text, bufferSpan, false, out charsUsed, out bytesUsed, out bool completed);

var pipelineBuffer = writer.GetSpan(totalBytes).Slice(0, totalBytes);
OnDataEncoded(bufferSpan[..bytesUsed], encodingContext, 0);

bufferSpan.Slice(0, totalBytes).CopyTo(pipelineBuffer);
writer.Advance(totalBytes);
var isFinal = completed && text.Length == charsUsed;

return totalBytes + headLen;
}
finally
{
bufferPool.Return(buffer);
}
}

private int EncodeSingleFragment(IBufferWriter<byte> writer, byte opCode, int expectedHeadLength, ReadOnlySpan<char> text)
{
var head = writer.GetSpan(expectedHeadLength);
if (isFinal)
opCode = (byte)(opCode | 0x80);
}
else
{
opCode = (byte)(opCode | 0x80);
bytesUsed = 0;
}

writer.Advance(expectedHeadLength);
var headLen = WriteHead(writer, opCode, bytesUsed);

var totalBytes = text.Length > 0 ? writer.Write(text, _textEncoding) : 0;
OnHeadEncoded(writer, encodingContext);

WriteHead(ref head, (byte)(opCode | 0x80), totalBytes);
if (bytesUsed > 0)
{
writer.Write(bufferSpan[..bytesUsed]);
}

return totalBytes + expectedHeadLength;
return GetFragmentTotalLength(headLen, bytesUsed);
}
finally
{
if (buffer != null)
_bufferPool.Return(buffer);

CleanupEncodingContext(encodingContext);
}
}

public int EncodeDataMessage(IBufferWriter<byte> writer, WebSocketPackage pack)
protected virtual void EncodeDataMessageBody(IBufferWriter<byte> writer, WebSocketPackage pack)
{
var head = writer.GetSpan(10);

var headLen = WriteHead(ref head, (byte)(pack.OpCodeByte | 0x80), pack.Data.Length);

writer.Advance(headLen);

foreach (var dataPiece in pack.Data)
{
writer.Write(dataPiece.Span);
}

return (int)(pack.Data.Length + headLen);
}

private (int headLen, int fragmentSize, bool bufferWrite) GetEstimateFragmentation(int msgSize)
public int EncodeDataMessage(IBufferWriter<byte> writer, WebSocketPackage pack)
{
var minSize = msgSize;
var maxSize = _textEncoding.GetMaxByteCount(msgSize);
var headLen = WriteHead(writer, (byte)(pack.OpCodeByte | 0x80), pack.Data.Length);

var fragmentSize = 0;
var headLen = 0;
var bufferWrite = false;
EncodeDataMessageBody(writer, pack);

if (maxSize < _size0)
headLen = 2;
else if (minSize >= _size0 && maxSize < _size1)
headLen = 4;
else if (minSize >= _size1)
{
headLen = 4;
fragmentSize = _size1 - 1;
}
return (int)(pack.Data.Length + headLen);
}

if (headLen == 0)
private int GetFragmentSize(int msgSize)
{
for (var i = _fragmentSizes.Length - 1; i >= 0; i--)
{
if (minSize < _size0 && maxSize >= _size0)
{
headLen = 2;
fragmentSize = _size0 - 1;
}
else
var fragmentSize = _fragmentSizes[i];

if (msgSize >= fragmentSize)
{
headLen = 4;
fragmentSize = _size1 - 1;
bufferWrite = true;
return fragmentSize;
}
}

return (headLen, fragmentSize, bufferWrite);
return 0;
}

public int Encode(IBufferWriter<byte> writer, WebSocketPackage pack)
Expand All @@ -207,7 +276,7 @@ public int Encode(IBufferWriter<byte> writer, WebSocketPackage pack)
var msgSize = !string.IsNullOrEmpty(pack.Message) ? pack.Message.Length : 0;

if (msgSize == 0)
return EncodeEmptyFragment(writer, pack.OpCodeByte, 2);
return EncodeEmptyFragment(writer, pack.OpCodeByte);

var total = 0;
var text = pack.Message.AsSpan();
Expand All @@ -218,23 +287,21 @@ public int Encode(IBufferWriter<byte> writer, WebSocketPackage pack)

while (true)
{
(var headLen, var fragmentSize, var bufferWrite) = GetEstimateFragmentation(text.Length);

if (fragmentSize == 0)
{
total += EncodeSingleFragment(writer, isContinuation ? (byte)OpCode.Continuation : pack.OpCodeByte, headLen, text);
break;
}
var fragmentSize = GetFragmentSize(text.Length);

var charsUsed = 0;

if (!bufferWrite)
total += EncodeFragment(writer, isContinuation ? (byte)OpCode.Continuation : pack.OpCodeByte, headLen, fragmentSize, text, encoder, out charsUsed);
if (fragmentSize > 0)
{
total += EncodeFragment(writer, isContinuation ? (byte)OpCode.Continuation : pack.OpCodeByte, fragmentSize, text, encoder, out charsUsed);
}
else
total += EncodeFragmentWithBuffer(writer, isContinuation ? (byte)OpCode.Continuation : pack.OpCodeByte, fragmentSize, text, encoder, out charsUsed);

if (text.Length <= charsUsed)
break;
{
total += EncodeFinalFragment(writer, isContinuation ? (byte)OpCode.Continuation : pack.OpCodeByte, text, encoder, out charsUsed);

if (text.Length <= charsUsed)
break;
}

text = text.Slice(charsUsed);

Expand Down
Loading

0 comments on commit 91433fc

Please sign in to comment.