diff --git a/src/SuperSocket.ProtoBase/ProxyProtocol/ProxyProtocolV2PartReader.cs b/src/SuperSocket.ProtoBase/ProxyProtocol/ProxyProtocolV2PartReader.cs index a1ab4af50..a95c424f0 100644 --- a/src/SuperSocket.ProtoBase/ProxyProtocol/ProxyProtocolV2PartReader.cs +++ b/src/SuperSocket.ProtoBase/ProxyProtocol/ProxyProtocolV2PartReader.cs @@ -10,7 +10,7 @@ class ProxyProtocolV2PartReader : ProxyProtocolPackagePartReader _bufferPool = ArrayPool.Shared; @@ -97,17 +97,17 @@ public override bool Process(TPackageInfo package, object filterContext, ref Seq try { - var addressBufferSpan = addressBuffer.AsSpan()[..IPV6_ADDRESS_LEN]; + var addressBufferSpan = addressBuffer.AsSpan().Slice(0, IPV6_ADDRESS_LEN); - reader.Sequence.Slice(0, IPV6_ADDRESS_LEN).CopyTo(addressBufferSpan); - reader.Advance(IPV6_ADDRESS_LEN); + var sequenceToRead = reader.UnreadSequence; + sequenceToRead.Slice(0, IPV6_ADDRESS_LEN).CopyTo(addressBufferSpan); proxyInfo.SourceIPAddress = new IPAddress(addressBufferSpan); - reader.Sequence.Slice(0, IPV6_ADDRESS_LEN).CopyTo(addressBufferSpan); - reader.Advance(IPV6_ADDRESS_LEN); - + sequenceToRead.Slice(IPV6_ADDRESS_LEN, IPV6_ADDRESS_LEN).CopyTo(addressBufferSpan); proxyInfo.DestinationIPAddress = new IPAddress(addressBufferSpan); + + reader.Advance(IPV6_ADDRESS_LEN * 2); } finally { diff --git a/test/SuperSocket.Tests/FixedHeaderProtocolTest.cs b/test/SuperSocket.Tests/FixedHeaderProtocolTest.cs index 6e1440ad8..32276e9f3 100644 --- a/test/SuperSocket.Tests/FixedHeaderProtocolTest.cs +++ b/test/SuperSocket.Tests/FixedHeaderProtocolTest.cs @@ -3,8 +3,6 @@ using System.Collections.Generic; using System.Text; using Microsoft.Extensions.Configuration; -using Microsoft.Extensions.Hosting; -using SuperSocket; using SuperSocket.ProtoBase; using SuperSocket.Server.Host; using SuperSocket.Server.Abstractions; @@ -21,7 +19,7 @@ public FixedHeaderProtocolTest(ITestOutputHelper outputHelper) : base(outputHelp } - class MyFixedHeaderPipelineFilter : FixedHeaderPipelineFilter + internal class MyFixedHeaderPipelineFilter : FixedHeaderPipelineFilter { public MyFixedHeaderPipelineFilter() : base(4) diff --git a/test/SuperSocket.Tests/ProxyProtocolHostConfigurator.cs b/test/SuperSocket.Tests/ProxyProtocolHostConfigurator.cs index 990e7fb69..051dc2850 100644 --- a/test/SuperSocket.Tests/ProxyProtocolHostConfigurator.cs +++ b/test/SuperSocket.Tests/ProxyProtocolHostConfigurator.cs @@ -1,4 +1,5 @@ using System; +using System.Buffers.Binary; using System.IO; using System.Net; using System.Net.Sockets; @@ -9,6 +10,7 @@ using SuperSocket.ProtoBase; using SuperSocket.Server.Abstractions; using SuperSocket.Server.Abstractions.Host; +using Xunit; namespace SuperSocket.Tests { @@ -16,20 +18,65 @@ public class ProxyProtocolHostConfigurator : IHostConfigurator { private IHostConfigurator _innerHostConfigurator; - private static readonly byte[] _proxyProtocolV2_IPV4_SampleData = new byte[] + private static readonly byte[] _proxyProtocolV2_SIGNATURE = new byte[] { + // Signature 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, - 0x55, 0x49, 0x54, 0x0A, - 0x21, 0x11, 0x00, 0x0c, - 0xac, 0x13, 0x00, 0x01, - 0xac, 0x13, 0x00, 0x03, - 0xa6, 0x52, 0x00, 0x50 + 0x55, 0x49, 0x54, 0x0A }; - public ProxyProtocolHostConfigurator(IHostConfigurator hostConfigurator) + private IPEndPoint _sourceIPEndPoint; + private IPEndPoint _destinationIPEndPoint; + + private ReadOnlySpan CreateProxyProtocolData(IPEndPoint sourceIPEndPoint, IPEndPoint destinationIPEndPoint) + { + var isIpV4 = sourceIPEndPoint.Address.AddressFamily == AddressFamily.InterNetwork; + var ipAddressLength = isIpV4 ? 4 : 16; + + var addressLength = isIpV4 + ? (ipAddressLength * 2 + 4) + : (ipAddressLength * 2 + 4); + + var data = new byte[4 + addressLength]; + + data[0] = 0x21; + data[1] = (byte)((_innerHostConfigurator is UdpHostConfigurator ? 0x02 : 0x01) | (isIpV4 ? 0x10 : 0x20)); + + var span = data.AsSpan(); + + BinaryPrimitives.WriteUInt16BigEndian(span.Slice(2, 2), (ushort)addressLength); + + var spanToWrite = span.Slice(4); + + var addressSpan = spanToWrite.Slice(0, ipAddressLength); + + var written = 0; + + sourceIPEndPoint.Address.TryWriteBytes(addressSpan, out written); + + Assert.Equal(ipAddressLength, written); + + spanToWrite = spanToWrite.Slice(ipAddressLength); + + addressSpan = spanToWrite.Slice(0, ipAddressLength); + destinationIPEndPoint.Address.TryWriteBytes(addressSpan, out written); + + Assert.Equal(ipAddressLength, written); + + spanToWrite = spanToWrite.Slice(ipAddressLength); + + BinaryPrimitives.WriteUInt16BigEndian(spanToWrite.Slice(0, 2), (ushort)sourceIPEndPoint.Port); + BinaryPrimitives.WriteUInt16BigEndian(spanToWrite.Slice(2, 2), (ushort)destinationIPEndPoint.Port); + + return span; + } + + public ProxyProtocolHostConfigurator(IHostConfigurator hostConfigurator, IPEndPoint sourceIPEndPoint, IPEndPoint destinationIPEndPoint) { _innerHostConfigurator = hostConfigurator; + _sourceIPEndPoint = sourceIPEndPoint; + _destinationIPEndPoint = destinationIPEndPoint; } public string WebSocketSchema => _innerHostConfigurator.WebSocketSchema; @@ -56,8 +103,9 @@ public Socket CreateClient() public async ValueTask GetClientStream(Socket socket) { var stream = await _innerHostConfigurator.GetClientStream(socket); - - await stream.WriteAsync(_proxyProtocolV2_IPV4_SampleData, 0, _proxyProtocolV2_IPV4_SampleData.Length); + + stream.Write(_proxyProtocolV2_SIGNATURE, 0, _proxyProtocolV2_SIGNATURE.Length); + stream.Write(CreateProxyProtocolData(_sourceIPEndPoint, _destinationIPEndPoint)); await stream.FlushAsync(); return stream; diff --git a/test/SuperSocket.Tests/ProxyProtocolTest.cs b/test/SuperSocket.Tests/ProxyProtocolTest.cs index cf5ff6c99..9e5347e52 100644 --- a/test/SuperSocket.Tests/ProxyProtocolTest.cs +++ b/test/SuperSocket.Tests/ProxyProtocolTest.cs @@ -10,12 +10,47 @@ using SuperSocket.Server.Abstractions; using Xunit; using Xunit.Abstractions; +using System.Net; +using System.Linq; +using System.Threading.Tasks; +using SuperSocket.Server.Abstractions.Session; +using SuperSocket.Server; namespace SuperSocket.Tests { [Trait("Category", "ProxyProtocol")] public class ProxyProtocolTest : FixedHeaderProtocolTest { + private static readonly IPAddress[] _ipv4AddressPool = (new [] + { + "247.47.227.3", + "112.207.192.91", + "123.193.169.24", + "191.213.152.251", + "7.132.159.148", + "214.75.171.159", + "170.103.166.188", + "228.111.89.87", + "4.122.43.89", + "206.222.157.16" + }).Select(ip => IPAddress.Parse(ip)).ToArray(); + + private static readonly IPAddress[] _ipv6AddressPool = (new[] + { + "4466:a5cd:cacc:faa1:4522:055e:9094:f1a3", + "f438:4e9c:0d38:6ae5:ef44:4b0e:c160:a254", + "529f:bc3e:2e56:8365:dc06:5772:87a5:e658", + "4b19:5c07:d470:018a:36a0:b31a:59aa:cd48", + "2a61:61ff:8a01:5473:0091:e416:aeda:f924", + "0f29:faaa:c984:c0fd:d0d2:36ca:7132:933e", + "1598:7240:de55:0803:305a:b7f4:4eab:fd00", + "6431:494f:9a92:4ea7:5645:a3ab:945a:ca72", + "48d4:c5d6:b3e8:9859:dc0f:a8d0:e085:3518", + "8e72:3b78:2b3e:33ad:7b12:4b14:37de:e9f1" + }).Select(ip => IPAddress.Parse(ip)).ToArray(); + + private static readonly Random _rd = new Random(); + public ProxyProtocolTest(ITestOutputHelper outputHelper) : base(outputHelper) { @@ -23,7 +58,13 @@ public ProxyProtocolTest(ITestOutputHelper outputHelper) protected override IHostConfigurator CreateHostConfigurator(Type hostConfiguratorType) { - return new ProxyProtocolHostConfigurator(base.CreateHostConfigurator(hostConfiguratorType)); + var radomNumber = _rd.Next(0, 1000); + var addressPool = radomNumber % 2 == 0 ? _ipv4AddressPool : _ipv6AddressPool; + + var sourceIPEndPoint = new IPEndPoint(addressPool[_rd.Next(0, addressPool.Length)], _rd.Next(100, 9999)); + var destinationIPEndPoint = new IPEndPoint(addressPool[_rd.Next(0, addressPool.Length)], _rd.Next(100, 9999)); + + return new ProxyProtocolHostConfigurator(base.CreateHostConfigurator(hostConfiguratorType), sourceIPEndPoint, destinationIPEndPoint); } protected override Dictionary LoadMemoryConfig(Dictionary configSettings) @@ -32,5 +73,58 @@ protected override Dictionary LoadMemoryConfig(Dictionary(); + + using (var server = CreateSocketServerBuilder(hostConfigurator) + .UsePackageHandler(async (s, p) => + { + taskCompletionSource.SetResult(s); + await Task.CompletedTask; + }) + .ConfigureAppConfiguration((HostBuilder, configBuilder) => + { + configBuilder.AddInMemoryCollection(LoadMemoryConfig(new Dictionary())); + }).BuildAsServer() as IServer) + { + await server.StartAsync(); + + using (var socket = CreateClient(hostConfigurator)) + { + using (var socketStream = await hostConfigurator.GetClientStream(socket)) + using (var reader = hostConfigurator.GetStreamReader(socketStream, Utf8Encoding)) + using (var writer = new ConsoleWriter(socketStream, Utf8Encoding, 1024 * 8)) + { + var line = Guid.NewGuid().ToString(); + writer.Write(CreateRequest(line)); + writer.Flush(); + + var session = await taskCompletionSource.Task as AppSession; + + Assert.NotNull(session.Connection.ProxyInfo); + + Assert.Equal(sourceIPAddress, session.Connection.ProxyInfo.SourceIPAddress); + Assert.Equal(sourcePort, session.Connection.ProxyInfo.SourcePort); + + Assert.Equal(destinationIPAddress, session.Connection.ProxyInfo.DestinationIPAddress); + Assert.Equal(destinationPort, session.Connection.ProxyInfo.DestinationPort); + } + } + + await server.StopAsync(); + } + } } }