diff --git a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs
index cc91d53ee..4be33c706 100644
--- a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs
+++ b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs
@@ -15,11 +15,12 @@ internal interface ICallerClientResultsManager : IClientResultsManager
/// Add a invocation which is directly called by current server
///
///
+ ///
///
///
///
///
- Task AddInvocation(string connectionId, string invocationId, CancellationToken cancellationToken);
+ Task AddInvocation(string hub, string connectionId, string invocationId, CancellationToken cancellationToken);
void AddServiceMapping(ServiceMappingMessage serviceMappingMessage);
diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs
index 9b6cc0c79..48b2180d0 100644
--- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs
+++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs
@@ -231,13 +231,12 @@ internal IEnumerable GetRoutedEndpoints(ServiceMessage message)
return _router.GetEndpointsForConnection(closeConnectionMessage.ConnectionId, endpoints);
case ClientInvocationMessage clientInvocationMessage:
- return SingleOrNotSupported(_router.GetEndpointsForConnection(clientInvocationMessage.ConnectionId, endpoints), clientInvocationMessage);
-
- case ServiceMappingMessage serviceMappingMessage:
- return SingleOrNotSupported(_router.GetEndpointsForConnection(serviceMappingMessage.ConnectionId, endpoints), serviceMappingMessage);
+ return _router.GetEndpointsForConnection(clientInvocationMessage.ConnectionId, endpoints);
case ServiceCompletionMessage serviceCompletionMessage:
- return SingleOrNotSupported(_router.GetEndpointsForConnection(serviceCompletionMessage.ConnectionId, endpoints), serviceCompletionMessage);
+ return _router.GetEndpointsForConnection(serviceCompletionMessage.ConnectionId, endpoints);
+
+ // ServiceMappingMessage should never be sent to the service
default:
throw new NotSupportedException(message.GetType().Name);
diff --git a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs
index 330dd76e1..95a28488a 100644
--- a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs
+++ b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs
@@ -10,6 +10,7 @@
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.AspNetCore.SignalR;
+using System.Linq;
namespace Microsoft.Azure.SignalR
{
@@ -20,10 +21,15 @@ internal sealed class CallerClientResultsManager : ICallerClientResultsManager,
private long _lastInvocationId = 0;
private readonly IHubProtocolResolver _hubProtocolResolver;
+ private IEndpointRouter _endpointRouter { get; }
+ private IServiceEndpointManager _serviceEndpointManager { get; }
+ private readonly AckHandler _ackHandler = new();
- public CallerClientResultsManager(IHubProtocolResolver hubProtocolResolver)
+ public CallerClientResultsManager(IHubProtocolResolver hubProtocolResolver, IServiceEndpointManager serviceEndpointManager, IEndpointRouter endpointRouter)
{
_hubProtocolResolver = hubProtocolResolver ?? throw new ArgumentNullException(nameof(hubProtocolResolver));
+ _serviceEndpointManager = serviceEndpointManager ?? throw new ArgumentNullException(nameof(serviceEndpointManager));
+ _endpointRouter = endpointRouter ?? throw new ArgumentNullException(nameof(endpointRouter));
}
public string GenerateInvocationId(string connectionId)
@@ -31,17 +37,26 @@ public string GenerateInvocationId(string connectionId)
return $"{connectionId}-{_clientResultManagerId}-{Interlocked.Increment(ref _lastInvocationId)}";
}
- public Task AddInvocation(string connectionId, string invocationId, CancellationToken cancellationToken)
+ public Task AddInvocation(string hub, string connectionId, string invocationId, CancellationToken cancellationToken)
{
var tcs = new TaskCompletionSourceWithCancellation(
cancellationToken,
() => TryCompleteResult(connectionId, CompletionMessage.WithError(invocationId, "Canceled")));
+ var serviceEndpoints = _serviceEndpointManager.GetEndpoints(hub);
+ var ackNumber = _endpointRouter.GetEndpointsForConnection(connectionId, serviceEndpoints).Count();
+
+ var multiAck = _ackHandler.CreateMultiAck(out var ackId);
+
+ _ackHandler.SetExpectedCount(ackId, ackNumber);
+
// When the caller server is also the client router, Azure SignalR service won't send a ServiceMappingMessage to server.
// To handle this condition, CallerClientResultsManager itself should record this mapping information rather than waiting for a ServiceMappingMessage sent by service. Only in this condition, this method is called with instanceId != null.
var result = _pendingInvocations.TryAdd(invocationId,
new PendingInvocation(
typeof(T), connectionId, tcs,
+ ackId,
+ multiAck,
static (state, completionMessage) =>
{
var tcs = (TaskCompletionSourceWithCancellation)state;
@@ -108,14 +123,22 @@ public bool TryCompleteResult(string connectionId, CompletionMessage message)
// Follow https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/common/Shared/ClientResultsManager.cs#L58
throw new InvalidOperationException($"Connection ID '{connectionId}' is not valid for invocation ID '{message.InvocationId}'.");
}
-
- // if false the connection disconnected right after the above TryGetValue
- // or someone else completed the invocation (likely a bad client)
- // we'll ignore both cases
- if (_pendingInvocations.TryRemove(message.InvocationId, out _))
+
+ // Considering multiple endpoints, wait until
+ // 1. Received a non-error CompletionMessage
+ // or 2. Received messages from all endpoints
+ _ackHandler.TriggerAck(item.AckId);
+ if (message.HasResult || item.ackTask.IsCompletedSuccessfully)
{
- item.Complete(item.Tcs, message);
- return true;
+ // if false the connection disconnected right after the above TryGetValue
+ // or someone else completed the invocation (likely a bad client)
+ // we'll ignore both cases
+ if (_pendingInvocations.TryRemove(message.InvocationId, out _))
+ {
+ item.Complete(item.Tcs, message);
+ return true;
+ }
+ return false;
}
return false;
}
@@ -189,7 +212,7 @@ public void RemoveInvocation(string invocationId)
// Unused, here to honor the IInvocationBinder interface but should never be called
public Type GetStreamItemType(string streamId) => throw new NotImplementedException();
- private record PendingInvocation(Type Type, string ConnectionId, object Tcs, Action Complete)
+ private record PendingInvocation(Type Type, string ConnectionId, object Tcs, int AckId, Task ackTask, Action Complete)
{
public string RouterInstanceId { get; set; }
}
diff --git a/src/Microsoft.Azure.SignalR/ClientInvocation/ClientInvocationManager.cs b/src/Microsoft.Azure.SignalR/ClientInvocation/ClientInvocationManager.cs
index a25979a3e..5c79642fc 100644
--- a/src/Microsoft.Azure.SignalR/ClientInvocation/ClientInvocationManager.cs
+++ b/src/Microsoft.Azure.SignalR/ClientInvocation/ClientInvocationManager.cs
@@ -4,6 +4,8 @@
using System;
using Microsoft.AspNetCore.SignalR;
+#nullable enable
+
namespace Microsoft.Azure.SignalR
{
internal sealed class ClientInvocationManager : IClientInvocationManager
@@ -11,9 +13,13 @@ internal sealed class ClientInvocationManager : IClientInvocationManager
public ICallerClientResultsManager Caller { get; }
public IRoutedClientResultsManager Router { get; }
- public ClientInvocationManager(IHubProtocolResolver hubProtocolResolver)
+ public ClientInvocationManager(IHubProtocolResolver hubProtocolResolver, IServiceEndpointManager serviceEndpointManager, IEndpointRouter endpointRouter)
{
- Caller = new CallerClientResultsManager(hubProtocolResolver ?? throw new ArgumentNullException(nameof(hubProtocolResolver)));
+ Caller = new CallerClientResultsManager(
+ hubProtocolResolver ?? throw new ArgumentNullException(nameof(hubProtocolResolver)),
+ serviceEndpointManager ?? throw new ArgumentNullException(nameof(serviceEndpointManager)),
+ endpointRouter ?? throw new ArgumentNullException(nameof(endpointRouter))
+ );
Router = new RoutedClientResultsManager();
}
diff --git a/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs b/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs
index f35660863..f8af10f96 100644
--- a/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs
+++ b/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs
@@ -88,11 +88,6 @@ private static ISignalRServerBuilder AddAzureSignalRCore(this ISignalRServerBuil
.AddSingleton()
.AddSingleton()
.AddSingleton()
-#if NET7_0_OR_GREATER
- .AddSingleton()
-#else
- .AddSingleton()
-#endif
.AddSingleton(typeof(NegotiateHandler<>));
// If a custom router is added, do not add the default router
@@ -102,6 +97,14 @@ private static ISignalRServerBuilder AddAzureSignalRCore(this ISignalRServerBuil
// If a custom service event handler is added, do not add the default handler.
builder.Services.TryAddSingleton();
+ // IEndpointRouter and IAccessKeySynchronizer is required to build ClientInvocationManager.
+ builder.Services
+#if NET7_0_OR_GREATER
+ .AddSingleton();
+#else
+ .AddSingleton();
+#endif
+
#if !NETSTANDARD2_0
builder.Services.TryAddSingleton();
builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton());
diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs
index 05992ee09..175c2e6ed 100644
--- a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs
+++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs
@@ -23,6 +23,7 @@ internal class ServiceLifetimeManager : ServiceLifetimeManagerBase w
private readonly IClientInvocationManager _clientInvocationManager;
private readonly IClientConnectionManager _clientConnectionManager;
private readonly string _callerId;
+ private readonly string _hub;
public ServiceLifetimeManager(
IServiceConnectionManager serviceConnectionManager,
@@ -49,9 +50,10 @@ public ServiceLifetimeManager(
throw new InvalidOperationException(MarkerNotConfiguredError);
}
#endif
+ _hub = typeof(THub).Name;
if (hubOptions.Value.SupportedProtocols != null && hubOptions.Value.SupportedProtocols.Any(x => x.Equals(Constants.Protocol.BlazorPack, StringComparison.OrdinalIgnoreCase)))
{
- blazorDetector?.TrySetBlazor(typeof(THub).Name, true);
+ blazorDetector?.TrySetBlazor(_hub, true);
}
_callerId = nameProvider?.GetName() ?? throw new ArgumentNullException(nameof(nameProvider));
@@ -128,7 +130,7 @@ public override async Task InvokeConnectionAsync(string connectionId, stri
var invocationId = _clientInvocationManager.Caller.GenerateInvocationId(connectionId);
var message = AppendMessageTracingId(new ClientInvocationMessage(invocationId, connectionId, _callerId, SerializeAllProtocols(methodName, args, invocationId)));
await WriteAsync(message);
- var task = _clientInvocationManager.Caller.AddInvocation(connectionId, invocationId, cancellationToken);
+ var task = _clientInvocationManager.Caller.AddInvocation(_hub, connectionId, invocationId, cancellationToken);
// Exception handling follows https://source.dot.net/#Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs,349
try
diff --git a/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs b/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs
index 3be639312..208c71851 100644
--- a/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs
+++ b/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs
@@ -3,14 +3,16 @@
#if NET7_0_OR_GREATER
using System;
using System.Collections.Generic;
-using System.Dynamic;
using System.Linq;
using System.Threading;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Azure.SignalR.Protocol;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Extensions.Options;
using Xunit;
namespace Microsoft.Azure.SignalR
@@ -29,6 +31,30 @@ public class ClientInvocationManagerTests
private static readonly List TestConnectionIds = new() { "conn0", "conn1" };
private static readonly List TestInstanceIds = new() { "instance0", "instance1" };
private static readonly List TestServerIds = new() { "server1", "server2" };
+ private static readonly string SuccessCompleteResult = "success-result";
+ private static readonly string ErrorCompleteResult = "error-result";
+
+ private static ClientInvocationManager GetTestClientInvocationManager(int endpointCount = 1)
+ {
+ var services = new ServiceCollection();
+ var endpoints = Enumerable.Range(0, endpointCount)
+ .Select(i => new ServiceEndpoint($"Endpoint=https://test{i}connectionstring;AccessKey=1"))
+ .ToArray();
+
+ var config = new ConfigurationBuilder().Build();
+
+ var serviceProvider = services.AddLogging()
+ .AddSignalR().AddAzureSignalR(o => o.Endpoints = endpoints)
+ .Services
+ .AddSingleton(config)
+ .BuildServiceProvider();
+
+ var manager = serviceProvider.GetService();
+ var endpointRouter = serviceProvider.GetService();
+
+ var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver, manager, endpointRouter);
+ return clientInvocationManager;
+ }
[Theory]
[InlineData(true)]
@@ -42,15 +68,13 @@ public class ClientInvocationManagerTests
*/
public async void TestCompleteWithoutRouterServer(bool isCompletionWithResult)
{
+ var clientInvocationManager = GetTestClientInvocationManager();
var connectionId = TestConnectionIds[0];
- var targetClientInstanceId = TestInstanceIds[0];
- var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver);
var invocationId = clientInvocationManager.Caller.GenerateInvocationId(connectionId);
- var invocationResult = "invocation-correct-result";
CancellationToken cancellationToken = new CancellationToken();
// Server A knows the InstanceId of Client 2, so `instaceId` in `AddInvocation` is `targetClientInstanceId`
- var task = clientInvocationManager.Caller.AddInvocation(connectionId, invocationId, cancellationToken);
+ var task = clientInvocationManager.Caller.AddInvocation("TestHub", connectionId, invocationId, cancellationToken);
var ret = clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out var t);
@@ -58,8 +82,8 @@ public async void TestCompleteWithoutRouterServer(bool isCompletionWithResult)
Assert.Equal(typeof(string), t);
var completionMessage = isCompletionWithResult
- ? CompletionMessage.WithResult(invocationId, invocationResult)
- : CompletionMessage.WithError(invocationId, invocationResult);
+ ? CompletionMessage.WithResult(invocationId, SuccessCompleteResult)
+ : CompletionMessage.WithError(invocationId, ErrorCompleteResult);
ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, completionMessage);
Assert.True(ret);
@@ -68,12 +92,12 @@ public async void TestCompleteWithoutRouterServer(bool isCompletionWithResult)
{
await task;
Assert.True(isCompletionWithResult);
- Assert.Equal(invocationResult, task.Result);
+ Assert.Equal(SuccessCompleteResult, task.Result);
}
catch (Exception e)
{
Assert.False(isCompletionWithResult);
- Assert.Equal(invocationResult, e.Message);
+ Assert.Equal(ErrorCompleteResult, e.Message);
}
}
@@ -91,23 +115,21 @@ public async void TestCompleteWithoutRouterServer(bool isCompletionWithResult)
public async void TestCompleteWithRouterServer(string protocol, bool isCompletionWithResult)
{
var serverIds = new string[] { TestServerIds[0], TestServerIds[1] };
- var invocationResult = "invocation-correct-result";
- var ciManagers = new ClientInvocationManager[]
- {
- new ClientInvocationManager(HubProtocolResolver),
- new ClientInvocationManager(HubProtocolResolver),
+ var ciManagers = new ClientInvocationManager[] {
+ GetTestClientInvocationManager(),
+ GetTestClientInvocationManager()
};
var invocationId = ciManagers[0].Caller.GenerateInvocationId(TestConnectionIds[0]);
CancellationToken cancellationToken = new CancellationToken();
// Server 1 doesn't know the InstanceId of Client 2, so `instaceId` is null for `AddInvocation`
- var task = ciManagers[0].Caller.AddInvocation(TestConnectionIds[0], invocationId, cancellationToken);
+ var task = ciManagers[0].Caller.AddInvocation("TestHub", TestConnectionIds[0], invocationId, cancellationToken);
ciManagers[0].Caller.AddServiceMapping(new ServiceMappingMessage(invocationId, TestConnectionIds[1], TestInstanceIds[1]));
ciManagers[1].Router.AddInvocation(TestConnectionIds[1], invocationId, serverIds[0], new CancellationToken());
var completionMessage = isCompletionWithResult
- ? CompletionMessage.WithResult(invocationId, invocationResult)
- : CompletionMessage.WithError(invocationId, invocationResult);
+ ? CompletionMessage.WithResult(invocationId, SuccessCompleteResult)
+ : CompletionMessage.WithError(invocationId, ErrorCompleteResult);
var ret = ciManagers[1].Router.TryCompleteResult(TestConnectionIds[1], completionMessage);
Assert.True(ret);
@@ -122,22 +144,22 @@ public async void TestCompleteWithRouterServer(string protocol, bool isCompletio
{
await task;
Assert.True(isCompletionWithResult);
- Assert.Equal(invocationResult, task.Result);
+ Assert.Equal(SuccessCompleteResult, task.Result);
}
catch (Exception e)
{
Assert.False(isCompletionWithResult);
- Assert.Equal(invocationResult, e.Message);
+ Assert.Equal(ErrorCompleteResult, e.Message);
}
}
[Fact]
public void TestCallerManagerCancellation()
{
- var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver);
+ var clientInvocationManager = GetTestClientInvocationManager();
var invocationId = clientInvocationManager.Caller.GenerateInvocationId(TestConnectionIds[0]);
var cts = new CancellationTokenSource();
- var task = clientInvocationManager.Caller.AddInvocation(TestConnectionIds[0], invocationId, cts.Token);
+ var task = clientInvocationManager.Caller.AddInvocation("TestHub", TestConnectionIds[0], invocationId, cts.Token);
// Check if the invocation is existing
Assert.True(clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out _));
@@ -150,6 +172,106 @@ public void TestCallerManagerCancellation()
Assert.False(clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out _));
}
+
+ [Theory]
+ [InlineData(true, 2)]
+ [InlineData(false, 2)]
+ [InlineData(true, 3)]
+ [InlineData(false, 3)]
+ // isCompletionWithResult: the invocation is completed with result or error
+ public async void TestCompleteWithMultiEndpointAtLast(bool isCompletionWithResult, int endpointsCount)
+ {
+ Assert.True(endpointsCount > 1);
+ var clientInvocationManager = GetTestClientInvocationManager(endpointsCount);
+ var connectionId = TestConnectionIds[0];
+ var invocationId = clientInvocationManager.Caller.GenerateInvocationId(connectionId);
+
+ var cancellationToken = new CancellationToken();
+ // Server A knows the InstanceId of Client 2, so `instaceId` in `AddInvocation` is `targetClientInstanceId`
+ var task = clientInvocationManager.Caller.AddInvocation("TestHub", connectionId, invocationId, cancellationToken);
+
+ var ret = clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out var t);
+
+ Assert.True(ret);
+ Assert.Equal(typeof(string), t);
+
+ var completionMessage = CompletionMessage.WithResult(invocationId, SuccessCompleteResult);
+ var errorCompletionMessage = CompletionMessage.WithError(invocationId, ErrorCompleteResult);
+
+ // The first `endpointsCount - 1` CompletionMessage complete the invocation with error
+ // The last one completes the invocation according to `isCompletionWithResult`
+ // The invocation should be uncompleted until the last one CompletionMessage
+ for (var i = 0; i < endpointsCount - 1; i++)
+ {
+ var currentCompletionMessage = errorCompletionMessage;
+ ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, currentCompletionMessage);
+ Assert.False(ret);
+ }
+
+ ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, isCompletionWithResult ? completionMessage : errorCompletionMessage);
+ Assert.True(ret);
+
+ try
+ {
+ await task;
+ Assert.True(isCompletionWithResult);
+ Assert.Equal(SuccessCompleteResult, task.Result);
+ }
+ catch (Exception e)
+ {
+ Assert.False(isCompletionWithResult);
+ Assert.Equal(ErrorCompleteResult, e.Message);
+ }
+ }
+
+ [Theory]
+ [InlineData(2)]
+ [InlineData(3)]
+ public async void TestCompleteWithMultiEndpointAtMiddle(int endpointsCount)
+ {
+ Assert.True(endpointsCount > 1);
+ var clientInvocationManager = GetTestClientInvocationManager(endpointsCount);
+ var connectionId = TestConnectionIds[0];
+ var invocationId = clientInvocationManager.Caller.GenerateInvocationId(connectionId);
+
+ var cancellationToken = new CancellationToken();
+ // Server A knows the InstanceId of Client 2, so `instaceId` in `AddInvocation` is `targetClientInstanceId`
+ var task = clientInvocationManager.Caller.AddInvocation("TestHub", connectionId, invocationId, cancellationToken);
+
+ var ret = clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out var t);
+
+ Assert.True(ret);
+ Assert.Equal(typeof(string), t);
+
+ var successCompletionMessage = CompletionMessage.WithResult(invocationId, SuccessCompleteResult);
+ var errorCompletionMessage = CompletionMessage.WithError(invocationId, ErrorCompleteResult);
+
+ // The first `endpointsCount - 2` CompletionMessage complete the invocation with error
+ // The next one completes the invocation with result
+ // The last one completes the invocation with error and it shouldn't change the invocation result
+ for (var i = 0; i < endpointsCount - 2; i++)
+ {
+ ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, errorCompletionMessage);
+ Assert.False(ret);
+ }
+
+ ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, successCompletionMessage);
+ Assert.True(ret);
+
+ ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, errorCompletionMessage);
+ Assert.False(ret);
+
+ try
+ {
+ await task;
+ Assert.Equal(SuccessCompleteResult, task.Result);
+ }
+ catch (Exception)
+ {
+ Assert.True(false);
+ }
+ }
+
internal static ReadOnlyMemory GetBytes(string proto, HubMessage message)
{
IHubProtocol hubProtocol = proto == "json" ? new JsonHubProtocol() : new MessagePackHubProtocol();
diff --git a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs
index 7a08b65f0..dd445fa60 100644
--- a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs
+++ b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs
@@ -2,13 +2,9 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;
-using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
+using Microsoft.Azure.SignalR.Tests;
using Microsoft.Extensions.Logging.Abstractions;
namespace Microsoft.Azure.SignalR
@@ -26,8 +22,13 @@ public DefaultClientInvocationManager()
new MessagePackHubProtocol()
},
NullLogger.Instance);
-
- Caller = new CallerClientResultsManager(hubProtocolResolver);
+ var loggerFactory = new NullLoggerFactory();
+ var serviceEndpointManager = new ServiceEndpointManager(
+ new AccessKeySynchronizer(loggerFactory),
+ new TestOptionsMonitor(),
+ loggerFactory
+ );
+ Caller = new CallerClientResultsManager(hubProtocolResolver, serviceEndpointManager, new DefaultEndpointRouter());
Router = new RoutedClientResultsManager();
}
diff --git a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/ServiceConnectionProxy.cs b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/ServiceConnectionProxy.cs
index e9bfbc2f6..9b9ea2749 100644
--- a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/ServiceConnectionProxy.cs
+++ b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/ServiceConnectionProxy.cs
@@ -59,11 +59,8 @@ public ServiceConnectionProxy(
{
ConnectionFactory = connectionFactoryCallback?.Invoke(ConnectionFactoryCallbackAsync) ?? new TestConnectionFactory(ConnectionFactoryCallbackAsync);
ClientConnectionManager = new ClientConnectionManager();
- ClientInvocationManager = new ClientInvocationManager(new DefaultHubProtocolResolver(new IHubProtocol[]
- {
- new JsonHubProtocol(),
- new MessagePackHubProtocol(),
- }, NullLogger.Instance));
+
+ ClientInvocationManager = new DefaultClientInvocationManager();
_clientPipeOptions = clientPipeOptions;
ConnectionDelegateCallback = callback ?? OnConnectionAsync;
diff --git a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestOptionsMonitor.cs b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestOptionsMonitor.cs
new file mode 100644
index 000000000..d5cf6f787
--- /dev/null
+++ b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestOptionsMonitor.cs
@@ -0,0 +1,36 @@
+// Copyright (c) Microsoft. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Options;
+
+namespace Microsoft.Azure.SignalR.Tests
+{
+ internal class TestOptionsMonitor : IOptionsMonitor
+ {
+ private readonly IOptionsMonitor _monitor;
+
+ public TestOptionsMonitor()
+ {
+ var config = new ConfigurationBuilder().Build();
+
+ var services = new ServiceCollection();
+ var endpoints = new List() { new ServiceEndpoint($"Endpoint=https://testconnectionstring;AccessKey=1") };
+ var serviceProvider = services.AddLogging()
+ .AddSignalR().AddAzureSignalR(o => o.Endpoints = endpoints.ToArray())
+ .Services
+ .AddSingleton(config)
+ .BuildServiceProvider();
+ _monitor = serviceProvider.GetRequiredService>();
+ }
+
+ public ServiceOptions CurrentValue => _monitor.CurrentValue;
+
+ public ServiceOptions Get(string name) => _monitor.Get(name);
+
+ public IDisposable OnChange(Action listener) => _monitor.OnChange(listener);
+ }
+}
diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs
index 9d9001079..f91d0f556 100644
--- a/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs
+++ b/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs
@@ -65,7 +65,7 @@ public async void ServiceLifetimeManagerTest(string functionName, Type type)
{
var serviceConnectionManager = new TestServiceConnectionManager();
var blazorDetector = new DefaultBlazorDetector();
- var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver);
+ var clientInvocationManager = new DefaultClientInvocationManager();
var serviceLifetimeManager = new ServiceLifetimeManager(serviceConnectionManager,
new ClientConnectionManager(), HubProtocolResolver, Logger, Marker, _globalHubOptions, _localHubOptions, blazorDetector, new DefaultServerNameProvider(), clientInvocationManager);
@@ -86,7 +86,7 @@ public async void ServiceLifetimeManagerGroupTest(string functionName, Type type
{
var serviceConnectionManager = new TestServiceConnectionManager();
var blazorDetector = new DefaultBlazorDetector();
- var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver);
+ var clientInvocationManager = new DefaultClientInvocationManager();
var serviceLifetimeManager = new ServiceLifetimeManager(
serviceConnectionManager,
new ClientConnectionManager(),
@@ -126,7 +126,7 @@ public async void ServiceLifetimeManagerIntegrationTest(string methodName, Type
var serviceConnectionManager = new ServiceConnectionManager();
serviceConnectionManager.SetServiceConnection(proxy.ServiceConnectionContainer);
- var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver);
+ var clientInvocationManager = new DefaultClientInvocationManager();
var serviceLifetimeManager = new ServiceLifetimeManager(serviceConnectionManager,
proxy.ClientConnectionManager, HubProtocolResolver, Logger, Marker, _globalHubOptions, _localHubOptions, blazorDetector, new DefaultServerNameProvider(), clientInvocationManager);
@@ -173,7 +173,7 @@ public async void ServiceLifetimeManagerIgnoreBlazorHubProtocolTest(string funct
IOptions globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { "json", "messagepack", MockProtocol, "json" } });
IOptions> localHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { "json", "messagepack", MockProtocol } });
var serviceConnectionManager = new TestServiceConnectionManager();
- var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver);
+ var clientInvocationManager = new DefaultClientInvocationManager();
var serviceLifetimeManager = new ServiceLifetimeManager(serviceConnectionManager,
new ClientConnectionManager(), protocolResolver, Logger, Marker, globalHubOptions, localHubOptions, blazorDetector, new DefaultServerNameProvider(), clientInvocationManager);
@@ -273,7 +273,7 @@ private HubLifetimeManager MockLifetimeManager(IServiceConnectionManage
IOptions globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { MockProtocol } });
IOptions> localHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { MockProtocol } });
- var clientInvocationManager = new ClientInvocationManager(protocolResolver);
+ var clientInvocationManager = new DefaultClientInvocationManager();
return new ServiceLifetimeManager(
serviceConnectionManager,
@@ -294,7 +294,7 @@ private static ServiceLifetimeManager GetTestClientInvocationServiceLif
ServiceConnectionBase serviceConnection,
IServiceConnectionManager serviceConnectionManager,
ClientConnectionManager clientConnectionManager,
- ClientInvocationManager clientInvocationManager = null,
+ IClientInvocationManager clientInvocationManager = null,
ClientConnectionContext clientConnectionContext = null,
string protocol = "json"
)
@@ -308,7 +308,7 @@ private static ServiceLifetimeManager GetTestClientInvocationServiceLif
// Create ServiceLifetimeManager
return new ServiceLifetimeManager(serviceConnectionManager,
- clientConnectionManager, HubProtocolResolver, Logger, Marker, _globalHubOptions, _localHubOptions, null, new DefaultServerNameProvider(), clientInvocationManager ?? new ClientInvocationManager(HubProtocolResolver));
+ clientConnectionManager, HubProtocolResolver, Logger, Marker, _globalHubOptions, _localHubOptions, null, new DefaultServerNameProvider(), clientInvocationManager ?? new DefaultClientInvocationManager());
}
private static ClientConnectionContext GetClientConnectionContextWithConnection(string connectionId = null, string protocol = null)
@@ -329,7 +329,7 @@ public async void TestClientInvocationOneService(string protocol, bool isComplet
var serviceConnection = new TestServiceConnection();
var serviceConnectionManager = new TestServiceConnectionManager();
- var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver);
+ var clientInvocationManager = new DefaultClientInvocationManager();
var clientConnectionContext = GetClientConnectionContextWithConnection(TestConnectionIds[1], protocol);
var serviceLifetimeManager = GetTestClientInvocationServiceLifetimeManager(serviceConnection, serviceConnectionManager, new ClientConnectionManager(), clientInvocationManager, clientConnectionContext, protocol);
@@ -380,9 +380,9 @@ public async void TestMultiClientInvocationsMultipleService(string protocol, boo
var clientConnectionManager = new ClientConnectionManager();
var serviceConnectionManager = new TestServiceConnectionManager();
- var clientInvocationManagers = new List() {
- new ClientInvocationManager(HubProtocolResolver),
- new ClientInvocationManager(HubProtocolResolver)
+ var clientInvocationManagers = new List() {
+ new DefaultClientInvocationManager(),
+ new DefaultClientInvocationManager()
};
var serviceLifetimeManagers = new List>() {