diff --git a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs index cc91d53ee..4be33c706 100644 --- a/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs +++ b/src/Microsoft.Azure.SignalR.Common/ClientInvocation/ICallerClientResultsManager.cs @@ -15,11 +15,12 @@ internal interface ICallerClientResultsManager : IClientResultsManager /// Add a invocation which is directly called by current server /// /// + /// /// /// /// /// - Task AddInvocation(string connectionId, string invocationId, CancellationToken cancellationToken); + Task AddInvocation(string hub, string connectionId, string invocationId, CancellationToken cancellationToken); void AddServiceMapping(ServiceMappingMessage serviceMappingMessage); diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs index 9b6cc0c79..48b2180d0 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/MultiEndpointServiceConnectionContainer.cs @@ -231,13 +231,12 @@ internal IEnumerable GetRoutedEndpoints(ServiceMessage message) return _router.GetEndpointsForConnection(closeConnectionMessage.ConnectionId, endpoints); case ClientInvocationMessage clientInvocationMessage: - return SingleOrNotSupported(_router.GetEndpointsForConnection(clientInvocationMessage.ConnectionId, endpoints), clientInvocationMessage); - - case ServiceMappingMessage serviceMappingMessage: - return SingleOrNotSupported(_router.GetEndpointsForConnection(serviceMappingMessage.ConnectionId, endpoints), serviceMappingMessage); + return _router.GetEndpointsForConnection(clientInvocationMessage.ConnectionId, endpoints); case ServiceCompletionMessage serviceCompletionMessage: - return SingleOrNotSupported(_router.GetEndpointsForConnection(serviceCompletionMessage.ConnectionId, endpoints), serviceCompletionMessage); + return _router.GetEndpointsForConnection(serviceCompletionMessage.ConnectionId, endpoints); + + // ServiceMappingMessage should never be sent to the service default: throw new NotSupportedException(message.GetType().Name); diff --git a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs index 330dd76e1..95a28488a 100644 --- a/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs +++ b/src/Microsoft.Azure.SignalR/ClientInvocation/CallerClientResultsManager.cs @@ -10,6 +10,7 @@ using Microsoft.Azure.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR; +using System.Linq; namespace Microsoft.Azure.SignalR { @@ -20,10 +21,15 @@ internal sealed class CallerClientResultsManager : ICallerClientResultsManager, private long _lastInvocationId = 0; private readonly IHubProtocolResolver _hubProtocolResolver; + private IEndpointRouter _endpointRouter { get; } + private IServiceEndpointManager _serviceEndpointManager { get; } + private readonly AckHandler _ackHandler = new(); - public CallerClientResultsManager(IHubProtocolResolver hubProtocolResolver) + public CallerClientResultsManager(IHubProtocolResolver hubProtocolResolver, IServiceEndpointManager serviceEndpointManager, IEndpointRouter endpointRouter) { _hubProtocolResolver = hubProtocolResolver ?? throw new ArgumentNullException(nameof(hubProtocolResolver)); + _serviceEndpointManager = serviceEndpointManager ?? throw new ArgumentNullException(nameof(serviceEndpointManager)); + _endpointRouter = endpointRouter ?? throw new ArgumentNullException(nameof(endpointRouter)); } public string GenerateInvocationId(string connectionId) @@ -31,17 +37,26 @@ public string GenerateInvocationId(string connectionId) return $"{connectionId}-{_clientResultManagerId}-{Interlocked.Increment(ref _lastInvocationId)}"; } - public Task AddInvocation(string connectionId, string invocationId, CancellationToken cancellationToken) + public Task AddInvocation(string hub, string connectionId, string invocationId, CancellationToken cancellationToken) { var tcs = new TaskCompletionSourceWithCancellation( cancellationToken, () => TryCompleteResult(connectionId, CompletionMessage.WithError(invocationId, "Canceled"))); + var serviceEndpoints = _serviceEndpointManager.GetEndpoints(hub); + var ackNumber = _endpointRouter.GetEndpointsForConnection(connectionId, serviceEndpoints).Count(); + + var multiAck = _ackHandler.CreateMultiAck(out var ackId); + + _ackHandler.SetExpectedCount(ackId, ackNumber); + // When the caller server is also the client router, Azure SignalR service won't send a ServiceMappingMessage to server. // To handle this condition, CallerClientResultsManager itself should record this mapping information rather than waiting for a ServiceMappingMessage sent by service. Only in this condition, this method is called with instanceId != null. var result = _pendingInvocations.TryAdd(invocationId, new PendingInvocation( typeof(T), connectionId, tcs, + ackId, + multiAck, static (state, completionMessage) => { var tcs = (TaskCompletionSourceWithCancellation)state; @@ -108,14 +123,22 @@ public bool TryCompleteResult(string connectionId, CompletionMessage message) // Follow https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/common/Shared/ClientResultsManager.cs#L58 throw new InvalidOperationException($"Connection ID '{connectionId}' is not valid for invocation ID '{message.InvocationId}'."); } - - // if false the connection disconnected right after the above TryGetValue - // or someone else completed the invocation (likely a bad client) - // we'll ignore both cases - if (_pendingInvocations.TryRemove(message.InvocationId, out _)) + + // Considering multiple endpoints, wait until + // 1. Received a non-error CompletionMessage + // or 2. Received messages from all endpoints + _ackHandler.TriggerAck(item.AckId); + if (message.HasResult || item.ackTask.IsCompletedSuccessfully) { - item.Complete(item.Tcs, message); - return true; + // if false the connection disconnected right after the above TryGetValue + // or someone else completed the invocation (likely a bad client) + // we'll ignore both cases + if (_pendingInvocations.TryRemove(message.InvocationId, out _)) + { + item.Complete(item.Tcs, message); + return true; + } + return false; } return false; } @@ -189,7 +212,7 @@ public void RemoveInvocation(string invocationId) // Unused, here to honor the IInvocationBinder interface but should never be called public Type GetStreamItemType(string streamId) => throw new NotImplementedException(); - private record PendingInvocation(Type Type, string ConnectionId, object Tcs, Action Complete) + private record PendingInvocation(Type Type, string ConnectionId, object Tcs, int AckId, Task ackTask, Action Complete) { public string RouterInstanceId { get; set; } } diff --git a/src/Microsoft.Azure.SignalR/ClientInvocation/ClientInvocationManager.cs b/src/Microsoft.Azure.SignalR/ClientInvocation/ClientInvocationManager.cs index a25979a3e..5c79642fc 100644 --- a/src/Microsoft.Azure.SignalR/ClientInvocation/ClientInvocationManager.cs +++ b/src/Microsoft.Azure.SignalR/ClientInvocation/ClientInvocationManager.cs @@ -4,6 +4,8 @@ using System; using Microsoft.AspNetCore.SignalR; +#nullable enable + namespace Microsoft.Azure.SignalR { internal sealed class ClientInvocationManager : IClientInvocationManager @@ -11,9 +13,13 @@ internal sealed class ClientInvocationManager : IClientInvocationManager public ICallerClientResultsManager Caller { get; } public IRoutedClientResultsManager Router { get; } - public ClientInvocationManager(IHubProtocolResolver hubProtocolResolver) + public ClientInvocationManager(IHubProtocolResolver hubProtocolResolver, IServiceEndpointManager serviceEndpointManager, IEndpointRouter endpointRouter) { - Caller = new CallerClientResultsManager(hubProtocolResolver ?? throw new ArgumentNullException(nameof(hubProtocolResolver))); + Caller = new CallerClientResultsManager( + hubProtocolResolver ?? throw new ArgumentNullException(nameof(hubProtocolResolver)), + serviceEndpointManager ?? throw new ArgumentNullException(nameof(serviceEndpointManager)), + endpointRouter ?? throw new ArgumentNullException(nameof(endpointRouter)) + ); Router = new RoutedClientResultsManager(); } diff --git a/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs b/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs index f35660863..f8af10f96 100644 --- a/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs +++ b/src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs @@ -88,11 +88,6 @@ private static ISignalRServerBuilder AddAzureSignalRCore(this ISignalRServerBuil .AddSingleton() .AddSingleton() .AddSingleton() -#if NET7_0_OR_GREATER - .AddSingleton() -#else - .AddSingleton() -#endif .AddSingleton(typeof(NegotiateHandler<>)); // If a custom router is added, do not add the default router @@ -102,6 +97,14 @@ private static ISignalRServerBuilder AddAzureSignalRCore(this ISignalRServerBuil // If a custom service event handler is added, do not add the default handler. builder.Services.TryAddSingleton(); + // IEndpointRouter and IAccessKeySynchronizer is required to build ClientInvocationManager. + builder.Services +#if NET7_0_OR_GREATER + .AddSingleton(); +#else + .AddSingleton(); +#endif + #if !NETSTANDARD2_0 builder.Services.TryAddSingleton(); builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton()); diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs index 05992ee09..175c2e6ed 100644 --- a/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs +++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs @@ -23,6 +23,7 @@ internal class ServiceLifetimeManager : ServiceLifetimeManagerBase w private readonly IClientInvocationManager _clientInvocationManager; private readonly IClientConnectionManager _clientConnectionManager; private readonly string _callerId; + private readonly string _hub; public ServiceLifetimeManager( IServiceConnectionManager serviceConnectionManager, @@ -49,9 +50,10 @@ public ServiceLifetimeManager( throw new InvalidOperationException(MarkerNotConfiguredError); } #endif + _hub = typeof(THub).Name; if (hubOptions.Value.SupportedProtocols != null && hubOptions.Value.SupportedProtocols.Any(x => x.Equals(Constants.Protocol.BlazorPack, StringComparison.OrdinalIgnoreCase))) { - blazorDetector?.TrySetBlazor(typeof(THub).Name, true); + blazorDetector?.TrySetBlazor(_hub, true); } _callerId = nameProvider?.GetName() ?? throw new ArgumentNullException(nameof(nameProvider)); @@ -128,7 +130,7 @@ public override async Task InvokeConnectionAsync(string connectionId, stri var invocationId = _clientInvocationManager.Caller.GenerateInvocationId(connectionId); var message = AppendMessageTracingId(new ClientInvocationMessage(invocationId, connectionId, _callerId, SerializeAllProtocols(methodName, args, invocationId))); await WriteAsync(message); - var task = _clientInvocationManager.Caller.AddInvocation(connectionId, invocationId, cancellationToken); + var task = _clientInvocationManager.Caller.AddInvocation(_hub, connectionId, invocationId, cancellationToken); // Exception handling follows https://source.dot.net/#Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs,349 try diff --git a/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs b/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs index 3be639312..208c71851 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ClientInvocationManagerTests.cs @@ -3,14 +3,16 @@ #if NET7_0_OR_GREATER using System; using System.Collections.Generic; -using System.Dynamic; using System.Linq; using System.Threading; using Microsoft.AspNetCore.SignalR; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Azure.SignalR.Protocol; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; using Xunit; namespace Microsoft.Azure.SignalR @@ -29,6 +31,30 @@ public class ClientInvocationManagerTests private static readonly List TestConnectionIds = new() { "conn0", "conn1" }; private static readonly List TestInstanceIds = new() { "instance0", "instance1" }; private static readonly List TestServerIds = new() { "server1", "server2" }; + private static readonly string SuccessCompleteResult = "success-result"; + private static readonly string ErrorCompleteResult = "error-result"; + + private static ClientInvocationManager GetTestClientInvocationManager(int endpointCount = 1) + { + var services = new ServiceCollection(); + var endpoints = Enumerable.Range(0, endpointCount) + .Select(i => new ServiceEndpoint($"Endpoint=https://test{i}connectionstring;AccessKey=1")) + .ToArray(); + + var config = new ConfigurationBuilder().Build(); + + var serviceProvider = services.AddLogging() + .AddSignalR().AddAzureSignalR(o => o.Endpoints = endpoints) + .Services + .AddSingleton(config) + .BuildServiceProvider(); + + var manager = serviceProvider.GetService(); + var endpointRouter = serviceProvider.GetService(); + + var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver, manager, endpointRouter); + return clientInvocationManager; + } [Theory] [InlineData(true)] @@ -42,15 +68,13 @@ public class ClientInvocationManagerTests */ public async void TestCompleteWithoutRouterServer(bool isCompletionWithResult) { + var clientInvocationManager = GetTestClientInvocationManager(); var connectionId = TestConnectionIds[0]; - var targetClientInstanceId = TestInstanceIds[0]; - var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver); var invocationId = clientInvocationManager.Caller.GenerateInvocationId(connectionId); - var invocationResult = "invocation-correct-result"; CancellationToken cancellationToken = new CancellationToken(); // Server A knows the InstanceId of Client 2, so `instaceId` in `AddInvocation` is `targetClientInstanceId` - var task = clientInvocationManager.Caller.AddInvocation(connectionId, invocationId, cancellationToken); + var task = clientInvocationManager.Caller.AddInvocation("TestHub", connectionId, invocationId, cancellationToken); var ret = clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out var t); @@ -58,8 +82,8 @@ public async void TestCompleteWithoutRouterServer(bool isCompletionWithResult) Assert.Equal(typeof(string), t); var completionMessage = isCompletionWithResult - ? CompletionMessage.WithResult(invocationId, invocationResult) - : CompletionMessage.WithError(invocationId, invocationResult); + ? CompletionMessage.WithResult(invocationId, SuccessCompleteResult) + : CompletionMessage.WithError(invocationId, ErrorCompleteResult); ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, completionMessage); Assert.True(ret); @@ -68,12 +92,12 @@ public async void TestCompleteWithoutRouterServer(bool isCompletionWithResult) { await task; Assert.True(isCompletionWithResult); - Assert.Equal(invocationResult, task.Result); + Assert.Equal(SuccessCompleteResult, task.Result); } catch (Exception e) { Assert.False(isCompletionWithResult); - Assert.Equal(invocationResult, e.Message); + Assert.Equal(ErrorCompleteResult, e.Message); } } @@ -91,23 +115,21 @@ public async void TestCompleteWithoutRouterServer(bool isCompletionWithResult) public async void TestCompleteWithRouterServer(string protocol, bool isCompletionWithResult) { var serverIds = new string[] { TestServerIds[0], TestServerIds[1] }; - var invocationResult = "invocation-correct-result"; - var ciManagers = new ClientInvocationManager[] - { - new ClientInvocationManager(HubProtocolResolver), - new ClientInvocationManager(HubProtocolResolver), + var ciManagers = new ClientInvocationManager[] { + GetTestClientInvocationManager(), + GetTestClientInvocationManager() }; var invocationId = ciManagers[0].Caller.GenerateInvocationId(TestConnectionIds[0]); CancellationToken cancellationToken = new CancellationToken(); // Server 1 doesn't know the InstanceId of Client 2, so `instaceId` is null for `AddInvocation` - var task = ciManagers[0].Caller.AddInvocation(TestConnectionIds[0], invocationId, cancellationToken); + var task = ciManagers[0].Caller.AddInvocation("TestHub", TestConnectionIds[0], invocationId, cancellationToken); ciManagers[0].Caller.AddServiceMapping(new ServiceMappingMessage(invocationId, TestConnectionIds[1], TestInstanceIds[1])); ciManagers[1].Router.AddInvocation(TestConnectionIds[1], invocationId, serverIds[0], new CancellationToken()); var completionMessage = isCompletionWithResult - ? CompletionMessage.WithResult(invocationId, invocationResult) - : CompletionMessage.WithError(invocationId, invocationResult); + ? CompletionMessage.WithResult(invocationId, SuccessCompleteResult) + : CompletionMessage.WithError(invocationId, ErrorCompleteResult); var ret = ciManagers[1].Router.TryCompleteResult(TestConnectionIds[1], completionMessage); Assert.True(ret); @@ -122,22 +144,22 @@ public async void TestCompleteWithRouterServer(string protocol, bool isCompletio { await task; Assert.True(isCompletionWithResult); - Assert.Equal(invocationResult, task.Result); + Assert.Equal(SuccessCompleteResult, task.Result); } catch (Exception e) { Assert.False(isCompletionWithResult); - Assert.Equal(invocationResult, e.Message); + Assert.Equal(ErrorCompleteResult, e.Message); } } [Fact] public void TestCallerManagerCancellation() { - var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver); + var clientInvocationManager = GetTestClientInvocationManager(); var invocationId = clientInvocationManager.Caller.GenerateInvocationId(TestConnectionIds[0]); var cts = new CancellationTokenSource(); - var task = clientInvocationManager.Caller.AddInvocation(TestConnectionIds[0], invocationId, cts.Token); + var task = clientInvocationManager.Caller.AddInvocation("TestHub", TestConnectionIds[0], invocationId, cts.Token); // Check if the invocation is existing Assert.True(clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out _)); @@ -150,6 +172,106 @@ public void TestCallerManagerCancellation() Assert.False(clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out _)); } + + [Theory] + [InlineData(true, 2)] + [InlineData(false, 2)] + [InlineData(true, 3)] + [InlineData(false, 3)] + // isCompletionWithResult: the invocation is completed with result or error + public async void TestCompleteWithMultiEndpointAtLast(bool isCompletionWithResult, int endpointsCount) + { + Assert.True(endpointsCount > 1); + var clientInvocationManager = GetTestClientInvocationManager(endpointsCount); + var connectionId = TestConnectionIds[0]; + var invocationId = clientInvocationManager.Caller.GenerateInvocationId(connectionId); + + var cancellationToken = new CancellationToken(); + // Server A knows the InstanceId of Client 2, so `instaceId` in `AddInvocation` is `targetClientInstanceId` + var task = clientInvocationManager.Caller.AddInvocation("TestHub", connectionId, invocationId, cancellationToken); + + var ret = clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out var t); + + Assert.True(ret); + Assert.Equal(typeof(string), t); + + var completionMessage = CompletionMessage.WithResult(invocationId, SuccessCompleteResult); + var errorCompletionMessage = CompletionMessage.WithError(invocationId, ErrorCompleteResult); + + // The first `endpointsCount - 1` CompletionMessage complete the invocation with error + // The last one completes the invocation according to `isCompletionWithResult` + // The invocation should be uncompleted until the last one CompletionMessage + for (var i = 0; i < endpointsCount - 1; i++) + { + var currentCompletionMessage = errorCompletionMessage; + ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, currentCompletionMessage); + Assert.False(ret); + } + + ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, isCompletionWithResult ? completionMessage : errorCompletionMessage); + Assert.True(ret); + + try + { + await task; + Assert.True(isCompletionWithResult); + Assert.Equal(SuccessCompleteResult, task.Result); + } + catch (Exception e) + { + Assert.False(isCompletionWithResult); + Assert.Equal(ErrorCompleteResult, e.Message); + } + } + + [Theory] + [InlineData(2)] + [InlineData(3)] + public async void TestCompleteWithMultiEndpointAtMiddle(int endpointsCount) + { + Assert.True(endpointsCount > 1); + var clientInvocationManager = GetTestClientInvocationManager(endpointsCount); + var connectionId = TestConnectionIds[0]; + var invocationId = clientInvocationManager.Caller.GenerateInvocationId(connectionId); + + var cancellationToken = new CancellationToken(); + // Server A knows the InstanceId of Client 2, so `instaceId` in `AddInvocation` is `targetClientInstanceId` + var task = clientInvocationManager.Caller.AddInvocation("TestHub", connectionId, invocationId, cancellationToken); + + var ret = clientInvocationManager.Caller.TryGetInvocationReturnType(invocationId, out var t); + + Assert.True(ret); + Assert.Equal(typeof(string), t); + + var successCompletionMessage = CompletionMessage.WithResult(invocationId, SuccessCompleteResult); + var errorCompletionMessage = CompletionMessage.WithError(invocationId, ErrorCompleteResult); + + // The first `endpointsCount - 2` CompletionMessage complete the invocation with error + // The next one completes the invocation with result + // The last one completes the invocation with error and it shouldn't change the invocation result + for (var i = 0; i < endpointsCount - 2; i++) + { + ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, errorCompletionMessage); + Assert.False(ret); + } + + ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, successCompletionMessage); + Assert.True(ret); + + ret = clientInvocationManager.Caller.TryCompleteResult(connectionId, errorCompletionMessage); + Assert.False(ret); + + try + { + await task; + Assert.Equal(SuccessCompleteResult, task.Result); + } + catch (Exception) + { + Assert.True(false); + } + } + internal static ReadOnlyMemory GetBytes(string proto, HubMessage message) { IHubProtocol hubProtocol = proto == "json" ? new JsonHubProtocol() : new MessagePackHubProtocol(); diff --git a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs index 7a08b65f0..dd445fa60 100644 --- a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs +++ b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/DefaultClientInvocationManager.cs @@ -2,13 +2,9 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -using Microsoft.AspNetCore.SignalR; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Azure.SignalR.Tests; using Microsoft.Extensions.Logging.Abstractions; namespace Microsoft.Azure.SignalR @@ -26,8 +22,13 @@ public DefaultClientInvocationManager() new MessagePackHubProtocol() }, NullLogger.Instance); - - Caller = new CallerClientResultsManager(hubProtocolResolver); + var loggerFactory = new NullLoggerFactory(); + var serviceEndpointManager = new ServiceEndpointManager( + new AccessKeySynchronizer(loggerFactory), + new TestOptionsMonitor(), + loggerFactory + ); + Caller = new CallerClientResultsManager(hubProtocolResolver, serviceEndpointManager, new DefaultEndpointRouter()); Router = new RoutedClientResultsManager(); } diff --git a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/ServiceConnectionProxy.cs b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/ServiceConnectionProxy.cs index e9bfbc2f6..9b9ea2749 100644 --- a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/ServiceConnectionProxy.cs +++ b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/ServiceConnectionProxy.cs @@ -59,11 +59,8 @@ public ServiceConnectionProxy( { ConnectionFactory = connectionFactoryCallback?.Invoke(ConnectionFactoryCallbackAsync) ?? new TestConnectionFactory(ConnectionFactoryCallbackAsync); ClientConnectionManager = new ClientConnectionManager(); - ClientInvocationManager = new ClientInvocationManager(new DefaultHubProtocolResolver(new IHubProtocol[] - { - new JsonHubProtocol(), - new MessagePackHubProtocol(), - }, NullLogger.Instance)); + + ClientInvocationManager = new DefaultClientInvocationManager(); _clientPipeOptions = clientPipeOptions; ConnectionDelegateCallback = callback ?? OnConnectionAsync; diff --git a/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestOptionsMonitor.cs b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestOptionsMonitor.cs new file mode 100644 index 000000000..d5cf6f787 --- /dev/null +++ b/test/Microsoft.Azure.SignalR.Tests/Infrastructure/TestOptionsMonitor.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; + +namespace Microsoft.Azure.SignalR.Tests +{ + internal class TestOptionsMonitor : IOptionsMonitor + { + private readonly IOptionsMonitor _monitor; + + public TestOptionsMonitor() + { + var config = new ConfigurationBuilder().Build(); + + var services = new ServiceCollection(); + var endpoints = new List() { new ServiceEndpoint($"Endpoint=https://testconnectionstring;AccessKey=1") }; + var serviceProvider = services.AddLogging() + .AddSignalR().AddAzureSignalR(o => o.Endpoints = endpoints.ToArray()) + .Services + .AddSingleton(config) + .BuildServiceProvider(); + _monitor = serviceProvider.GetRequiredService>(); + } + + public ServiceOptions CurrentValue => _monitor.CurrentValue; + + public ServiceOptions Get(string name) => _monitor.Get(name); + + public IDisposable OnChange(Action listener) => _monitor.OnChange(listener); + } +} diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs index 9d9001079..f91d0f556 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceLifetimeManagerFacts.cs @@ -65,7 +65,7 @@ public async void ServiceLifetimeManagerTest(string functionName, Type type) { var serviceConnectionManager = new TestServiceConnectionManager(); var blazorDetector = new DefaultBlazorDetector(); - var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver); + var clientInvocationManager = new DefaultClientInvocationManager(); var serviceLifetimeManager = new ServiceLifetimeManager(serviceConnectionManager, new ClientConnectionManager(), HubProtocolResolver, Logger, Marker, _globalHubOptions, _localHubOptions, blazorDetector, new DefaultServerNameProvider(), clientInvocationManager); @@ -86,7 +86,7 @@ public async void ServiceLifetimeManagerGroupTest(string functionName, Type type { var serviceConnectionManager = new TestServiceConnectionManager(); var blazorDetector = new DefaultBlazorDetector(); - var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver); + var clientInvocationManager = new DefaultClientInvocationManager(); var serviceLifetimeManager = new ServiceLifetimeManager( serviceConnectionManager, new ClientConnectionManager(), @@ -126,7 +126,7 @@ public async void ServiceLifetimeManagerIntegrationTest(string methodName, Type var serviceConnectionManager = new ServiceConnectionManager(); serviceConnectionManager.SetServiceConnection(proxy.ServiceConnectionContainer); - var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver); + var clientInvocationManager = new DefaultClientInvocationManager(); var serviceLifetimeManager = new ServiceLifetimeManager(serviceConnectionManager, proxy.ClientConnectionManager, HubProtocolResolver, Logger, Marker, _globalHubOptions, _localHubOptions, blazorDetector, new DefaultServerNameProvider(), clientInvocationManager); @@ -173,7 +173,7 @@ public async void ServiceLifetimeManagerIgnoreBlazorHubProtocolTest(string funct IOptions globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { "json", "messagepack", MockProtocol, "json" } }); IOptions> localHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { "json", "messagepack", MockProtocol } }); var serviceConnectionManager = new TestServiceConnectionManager(); - var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver); + var clientInvocationManager = new DefaultClientInvocationManager(); var serviceLifetimeManager = new ServiceLifetimeManager(serviceConnectionManager, new ClientConnectionManager(), protocolResolver, Logger, Marker, globalHubOptions, localHubOptions, blazorDetector, new DefaultServerNameProvider(), clientInvocationManager); @@ -273,7 +273,7 @@ private HubLifetimeManager MockLifetimeManager(IServiceConnectionManage IOptions globalHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { MockProtocol } }); IOptions> localHubOptions = Options.Create(new HubOptions() { SupportedProtocols = new List() { MockProtocol } }); - var clientInvocationManager = new ClientInvocationManager(protocolResolver); + var clientInvocationManager = new DefaultClientInvocationManager(); return new ServiceLifetimeManager( serviceConnectionManager, @@ -294,7 +294,7 @@ private static ServiceLifetimeManager GetTestClientInvocationServiceLif ServiceConnectionBase serviceConnection, IServiceConnectionManager serviceConnectionManager, ClientConnectionManager clientConnectionManager, - ClientInvocationManager clientInvocationManager = null, + IClientInvocationManager clientInvocationManager = null, ClientConnectionContext clientConnectionContext = null, string protocol = "json" ) @@ -308,7 +308,7 @@ private static ServiceLifetimeManager GetTestClientInvocationServiceLif // Create ServiceLifetimeManager return new ServiceLifetimeManager(serviceConnectionManager, - clientConnectionManager, HubProtocolResolver, Logger, Marker, _globalHubOptions, _localHubOptions, null, new DefaultServerNameProvider(), clientInvocationManager ?? new ClientInvocationManager(HubProtocolResolver)); + clientConnectionManager, HubProtocolResolver, Logger, Marker, _globalHubOptions, _localHubOptions, null, new DefaultServerNameProvider(), clientInvocationManager ?? new DefaultClientInvocationManager()); } private static ClientConnectionContext GetClientConnectionContextWithConnection(string connectionId = null, string protocol = null) @@ -329,7 +329,7 @@ public async void TestClientInvocationOneService(string protocol, bool isComplet var serviceConnection = new TestServiceConnection(); var serviceConnectionManager = new TestServiceConnectionManager(); - var clientInvocationManager = new ClientInvocationManager(HubProtocolResolver); + var clientInvocationManager = new DefaultClientInvocationManager(); var clientConnectionContext = GetClientConnectionContextWithConnection(TestConnectionIds[1], protocol); var serviceLifetimeManager = GetTestClientInvocationServiceLifetimeManager(serviceConnection, serviceConnectionManager, new ClientConnectionManager(), clientInvocationManager, clientConnectionContext, protocol); @@ -380,9 +380,9 @@ public async void TestMultiClientInvocationsMultipleService(string protocol, boo var clientConnectionManager = new ClientConnectionManager(); var serviceConnectionManager = new TestServiceConnectionManager(); - var clientInvocationManagers = new List() { - new ClientInvocationManager(HubProtocolResolver), - new ClientInvocationManager(HubProtocolResolver) + var clientInvocationManagers = new List() { + new DefaultClientInvocationManager(), + new DefaultClientInvocationManager() }; var serviceLifetimeManagers = new List>() {