Skip to content

Commit

Permalink
UnknownPackageHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
kerryjiang committed Oct 30, 2024
1 parent 5d74143 commit 699ef0f
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/SuperSocket.Command/CommandMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ public class CommandMiddleware<TKey, TNetPackageInfo, TPackageInfo> : Middleware
{
private Dictionary<TKey, ICommandSet> _commands;

private Func<IAppSession, TPackageInfo, CancellationToken, ValueTask> _unknownPackageHandler;

private ILogger _logger;

protected IPackageMapper<TNetPackageInfo, TPackageInfo> PackageMapper { get; private set; }
Expand Down Expand Up @@ -136,6 +138,18 @@ public CommandMiddleware(IServiceProvider serviceProvider, IOptions<CommandOptio
_commands = commandDict;

PackageMapper = packageMapper != null ? packageMapper : CreatePackageMapper(serviceProvider);

var unknownPackageHandler = commandOptions.Value.UnknownPackageHandler;

if (unknownPackageHandler != null)
{
_unknownPackageHandler = unknownPackageHandler as Func<IAppSession, TPackageInfo, CancellationToken, ValueTask>;

if (_unknownPackageHandler == null)
{
_logger.LogError($"{nameof(commandOptions.Value.UnknownPackageHandler)} was registered with incorrectly. The expected typew is {typeof(Func<IAppSession, TPackageInfo, ValueTask>).Name}.");
}
}
}

private void RegisterCommandInterfaces(List<CommandTypeInfo> commandInterfaces, List<ICommandSetFactory> commandSetFactories, IServiceProvider serviceProvider, Type sessionType, Type packageType, bool wrapRequired = false)
Expand Down Expand Up @@ -200,6 +214,13 @@ protected virtual async ValueTask HandlePackage(IAppSession session, TPackageInf
{
if (!_commands.TryGetValue(package.Key, out ICommandSet commandSet))
{
var unknownPackageHandler = _unknownPackageHandler;

if (unknownPackageHandler != null)
{
await unknownPackageHandler.Invoke(session, package, cancellationToken);
}

return;
}

Expand Down
11 changes: 11 additions & 0 deletions src/SuperSocket.Command/CommandOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
using System.Linq;
using System.Collections.Generic;
using System.Reflection;
using System.Threading.Tasks;
using SuperSocket.ProtoBase;
using SuperSocket.Server.Abstractions.Session;
using System.Threading;

namespace SuperSocket.Command
{
Expand All @@ -13,6 +17,13 @@ public CommandOptions()
_globalCommandFilterTypes = new List<Type>();
}

internal object UnknownPackageHandler { get; private set; }

public void RegisterUnknownPackageHandler<TPackageInfo>(Func<IAppSession, TPackageInfo, CancellationToken, ValueTask> unknownPackageHandler)
{
UnknownPackageHandler = unknownPackageHandler;
}

public CommandAssemblyConfig[] Assemblies { get; set; }

public List<ICommandSource> CommandSources { get; set; }
Expand Down
51 changes: 51 additions & 0 deletions test/SuperSocket.Tests/CommandTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,57 @@ public async Task TestCommands(Type hostConfiguratorType)
}
}

[Theory]
[InlineData(typeof(RegularHostConfigurator))]
[InlineData(typeof(SecureHostConfigurator))]
public async Task TestUnknownCommands(Type hostConfiguratorType)
{
var hostConfigurator = CreateObject<IHostConfigurator>(hostConfiguratorType);
using (var server = CreateSocketServerBuilder<StringPackageInfo, CommandLinePipelineFilter>(hostConfigurator)
.UseCommand(commandOptions =>
{
// register commands one by one
commandOptions.AddCommand<ADD>();
commandOptions.RegisterUnknownPackageHandler<StringPackageInfo>(async (session, package, cancellationToken) =>
{
await session.SendAsync(Encoding.UTF8.GetBytes("X\r\n"));
});

// register all commands in one assembly
//commandOptions.AddCommandAssembly(typeof(SUB).GetTypeInfo().Assembly);
})
.BuildAsServer())
{

Assert.Equal("TestServer", server.Name);

Assert.True(await server.StartAsync());
OutputHelper.WriteLine("Server started.");


var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
await client.ConnectAsync(hostConfigurator.GetServerEndPoint());
OutputHelper.WriteLine("Connected.");

using (var stream = await hostConfigurator.GetClientStream(client))
using (var streamReader = new StreamReader(stream, Utf8Encoding, true))
using (var streamWriter = new StreamWriter(stream, Utf8Encoding, 1024 * 1024 * 4))
{
await streamWriter.WriteAsync("ADD 1 2 3\r\n");
await streamWriter.FlushAsync();
var line = await streamReader.ReadLineAsync();
Assert.Equal("6", line);

await streamWriter.WriteAsync("MULT 2 5\r\n");
await streamWriter.FlushAsync();
line = await streamReader.ReadLineAsync();
Assert.Equal("X", line);
}

await server.StopAsync();
}
}

[Theory]
[InlineData(typeof(RegularHostConfigurator))]
[InlineData(typeof(SecureHostConfigurator))]
Expand Down

0 comments on commit 699ef0f

Please sign in to comment.