diff --git a/src/SuperSocket.Command/CommandMiddleware.cs b/src/SuperSocket.Command/CommandMiddleware.cs index 9b90717e9..5451dc827 100644 --- a/src/SuperSocket.Command/CommandMiddleware.cs +++ b/src/SuperSocket.Command/CommandMiddleware.cs @@ -45,6 +45,8 @@ public class CommandMiddleware : Middleware { private Dictionary _commands; + private Func _unknownPackageHandler; + private ILogger _logger; protected IPackageMapper PackageMapper { get; private set; } @@ -136,6 +138,18 @@ public CommandMiddleware(IServiceProvider serviceProvider, IOptions; + + if (_unknownPackageHandler == null) + { + _logger.LogError($"{nameof(commandOptions.Value.UnknownPackageHandler)} was registered with incorrectly. The expected typew is {typeof(Func).Name}."); + } + } } private void RegisterCommandInterfaces(List commandInterfaces, List commandSetFactories, IServiceProvider serviceProvider, Type sessionType, Type packageType, bool wrapRequired = false) @@ -200,6 +214,13 @@ protected virtual async ValueTask HandlePackage(IAppSession session, TPackageInf { if (!_commands.TryGetValue(package.Key, out ICommandSet commandSet)) { + var unknownPackageHandler = _unknownPackageHandler; + + if (unknownPackageHandler != null) + { + await unknownPackageHandler.Invoke(session, package, cancellationToken); + } + return; } diff --git a/src/SuperSocket.Command/CommandOptions.cs b/src/SuperSocket.Command/CommandOptions.cs index 1d759292c..857a51407 100644 --- a/src/SuperSocket.Command/CommandOptions.cs +++ b/src/SuperSocket.Command/CommandOptions.cs @@ -2,6 +2,10 @@ using System.Linq; using System.Collections.Generic; using System.Reflection; +using System.Threading.Tasks; +using SuperSocket.ProtoBase; +using SuperSocket.Server.Abstractions.Session; +using System.Threading; namespace SuperSocket.Command { @@ -13,6 +17,13 @@ public CommandOptions() _globalCommandFilterTypes = new List(); } + internal object UnknownPackageHandler { get; private set; } + + public void RegisterUnknownPackageHandler(Func unknownPackageHandler) + { + UnknownPackageHandler = unknownPackageHandler; + } + public CommandAssemblyConfig[] Assemblies { get; set; } public List CommandSources { get; set; } diff --git a/test/SuperSocket.Tests/CommandTest.cs b/test/SuperSocket.Tests/CommandTest.cs index 8f0a61577..229b292ba 100644 --- a/test/SuperSocket.Tests/CommandTest.cs +++ b/test/SuperSocket.Tests/CommandTest.cs @@ -89,6 +89,57 @@ public async Task TestCommands(Type hostConfiguratorType) } } + [Theory] + [InlineData(typeof(RegularHostConfigurator))] + [InlineData(typeof(SecureHostConfigurator))] + public async Task TestUnknownCommands(Type hostConfiguratorType) + { + var hostConfigurator = CreateObject(hostConfiguratorType); + using (var server = CreateSocketServerBuilder(hostConfigurator) + .UseCommand(commandOptions => + { + // register commands one by one + commandOptions.AddCommand(); + commandOptions.RegisterUnknownPackageHandler(async (session, package, cancellationToken) => + { + await session.SendAsync(Encoding.UTF8.GetBytes("X\r\n")); + }); + + // register all commands in one assembly + //commandOptions.AddCommandAssembly(typeof(SUB).GetTypeInfo().Assembly); + }) + .BuildAsServer()) + { + + Assert.Equal("TestServer", server.Name); + + Assert.True(await server.StartAsync()); + OutputHelper.WriteLine("Server started."); + + + var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); + await client.ConnectAsync(hostConfigurator.GetServerEndPoint()); + OutputHelper.WriteLine("Connected."); + + using (var stream = await hostConfigurator.GetClientStream(client)) + using (var streamReader = new StreamReader(stream, Utf8Encoding, true)) + using (var streamWriter = new StreamWriter(stream, Utf8Encoding, 1024 * 1024 * 4)) + { + await streamWriter.WriteAsync("ADD 1 2 3\r\n"); + await streamWriter.FlushAsync(); + var line = await streamReader.ReadLineAsync(); + Assert.Equal("6", line); + + await streamWriter.WriteAsync("MULT 2 5\r\n"); + await streamWriter.FlushAsync(); + line = await streamReader.ReadLineAsync(); + Assert.Equal("X", line); + } + + await server.StopAsync(); + } + } + [Theory] [InlineData(typeof(RegularHostConfigurator))] [InlineData(typeof(SecureHostConfigurator))]