Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support client results in sharding mode #1852

Merged
merged 6 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ internal interface ICallerClientResultsManager : IClientResultsManager
/// Add a invocation which is directly called by current server
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="hub"></param>
/// <param name="connectionId"></param>
/// <param name="invocationId"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
Task<T> AddInvocation<T>(string connectionId, string invocationId, CancellationToken cancellationToken);
Task<T> AddInvocation<T>(string hub, string connectionId, string invocationId, CancellationToken cancellationToken);

void AddServiceMapping(ServiceMappingMessage serviceMappingMessage);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,12 @@ internal IEnumerable<ServiceEndpoint> GetRoutedEndpoints(ServiceMessage message)
return _router.GetEndpointsForConnection(closeConnectionMessage.ConnectionId, endpoints);

case ClientInvocationMessage clientInvocationMessage:
return SingleOrNotSupported(_router.GetEndpointsForConnection(clientInvocationMessage.ConnectionId, endpoints), clientInvocationMessage);

case ServiceMappingMessage serviceMappingMessage:
return SingleOrNotSupported(_router.GetEndpointsForConnection(serviceMappingMessage.ConnectionId, endpoints), serviceMappingMessage);
return _router.GetEndpointsForConnection(clientInvocationMessage.ConnectionId, endpoints);

case ServiceCompletionMessage serviceCompletionMessage:
return SingleOrNotSupported(_router.GetEndpointsForConnection(serviceCompletionMessage.ConnectionId, endpoints), serviceCompletionMessage);
return _router.GetEndpointsForConnection(serviceCompletionMessage.ConnectionId, endpoints);

// ServiceMappingMessage should never be sent to the service

default:
throw new NotSupportedException(message.GetType().Name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Microsoft.Azure.SignalR.Protocol;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.AspNetCore.SignalR;
using System.Linq;

namespace Microsoft.Azure.SignalR
{
Expand All @@ -20,28 +21,42 @@ internal sealed class CallerClientResultsManager : ICallerClientResultsManager,
private long _lastInvocationId = 0;

private readonly IHubProtocolResolver _hubProtocolResolver;
private IEndpointRouter _endpointRouter { get; }
private IServiceEndpointManager _serviceEndpointManager { get; }
private readonly AckHandler _ackHandler = new();

public CallerClientResultsManager(IHubProtocolResolver hubProtocolResolver)
public CallerClientResultsManager(IHubProtocolResolver hubProtocolResolver, IServiceEndpointManager serviceEndpointManager, IEndpointRouter endpointRouter)
{
_hubProtocolResolver = hubProtocolResolver ?? throw new ArgumentNullException(nameof(hubProtocolResolver));
_serviceEndpointManager = serviceEndpointManager ?? throw new ArgumentNullException(nameof(serviceEndpointManager));
_endpointRouter = endpointRouter ?? throw new ArgumentNullException(nameof(endpointRouter));
xingsy97 marked this conversation as resolved.
Show resolved Hide resolved
}

public string GenerateInvocationId(string connectionId)
{
return $"{connectionId}-{_clientResultManagerId}-{Interlocked.Increment(ref _lastInvocationId)}";
}

public Task<T> AddInvocation<T>(string connectionId, string invocationId, CancellationToken cancellationToken)
public Task<T> AddInvocation<T>(string hub, string connectionId, string invocationId, CancellationToken cancellationToken)
{
var tcs = new TaskCompletionSourceWithCancellation<T>(
cancellationToken,
() => TryCompleteResult(connectionId, CompletionMessage.WithError(invocationId, "Canceled")));

var serviceEndpoints = _serviceEndpointManager.GetEndpoints(hub);
var ackNumber = _endpointRouter.GetEndpointsForConnection(connectionId, serviceEndpoints).Count();

var multiAck = _ackHandler.CreateMultiAck(out var ackId);

_ackHandler.SetExpectedCount(ackId, ackNumber);

// When the caller server is also the client router, Azure SignalR service won't send a ServiceMappingMessage to server.
// To handle this condition, CallerClientResultsManager itself should record this mapping information rather than waiting for a ServiceMappingMessage sent by service. Only in this condition, this method is called with instanceId != null.
var result = _pendingInvocations.TryAdd(invocationId,
new PendingInvocation(
typeof(T), connectionId, tcs,
ackId,
multiAck,
static (state, completionMessage) =>
{
var tcs = (TaskCompletionSourceWithCancellation<T>)state;
Expand Down Expand Up @@ -108,14 +123,22 @@ public bool TryCompleteResult(string connectionId, CompletionMessage message)
// Follow https://github.com/dotnet/aspnetcore/blob/main/src/SignalR/common/Shared/ClientResultsManager.cs#L58
throw new InvalidOperationException($"Connection ID '{connectionId}' is not valid for invocation ID '{message.InvocationId}'.");
}

// if false the connection disconnected right after the above TryGetValue
// or someone else completed the invocation (likely a bad client)
// we'll ignore both cases
if (_pendingInvocations.TryRemove(message.InvocationId, out _))

// Considering multiple endpoints, wait until
// 1. Received a non-error CompletionMessage
// or 2. Received messages from all endpoints
_ackHandler.TriggerAck(item.AckId);
if (message.HasResult || item.ackTask.IsCompletedSuccessfully)
{
item.Complete(item.Tcs, message);
return true;
// if false the connection disconnected right after the above TryGetValue
// or someone else completed the invocation (likely a bad client)
// we'll ignore both cases
if (_pendingInvocations.TryRemove(message.InvocationId, out _))
{
item.Complete(item.Tcs, message);
return true;
}
return false;
}
return false;
}
Expand Down Expand Up @@ -189,7 +212,7 @@ public void RemoveInvocation(string invocationId)
// Unused, here to honor the IInvocationBinder interface but should never be called
public Type GetStreamItemType(string streamId) => throw new NotImplementedException();

private record PendingInvocation(Type Type, string ConnectionId, object Tcs, Action<object, CompletionMessage> Complete)
private record PendingInvocation(Type Type, string ConnectionId, object Tcs, int AckId, Task ackTask, Action<object, CompletionMessage> Complete)
{
public string RouterInstanceId { get; set; }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,22 @@
using System;
using Microsoft.AspNetCore.SignalR;

#nullable enable

namespace Microsoft.Azure.SignalR
{
internal sealed class ClientInvocationManager : IClientInvocationManager
{
public ICallerClientResultsManager Caller { get; }
public IRoutedClientResultsManager Router { get; }

public ClientInvocationManager(IHubProtocolResolver hubProtocolResolver)
public ClientInvocationManager(IHubProtocolResolver hubProtocolResolver, IServiceEndpointManager serviceEndpointManager, IEndpointRouter endpointRouter)
{
Caller = new CallerClientResultsManager(hubProtocolResolver ?? throw new ArgumentNullException(nameof(hubProtocolResolver)));
Caller = new CallerClientResultsManager(
hubProtocolResolver ?? throw new ArgumentNullException(nameof(hubProtocolResolver)),
serviceEndpointManager ?? throw new ArgumentNullException(nameof(serviceEndpointManager)),
endpointRouter ?? throw new ArgumentNullException(nameof(endpointRouter))
);
Router = new RoutedClientResultsManager();
}

Expand Down
13 changes: 8 additions & 5 deletions src/Microsoft.Azure.SignalR/DependencyInjectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,6 @@ private static ISignalRServerBuilder AddAzureSignalRCore(this ISignalRServerBuil
.AddSingleton<IClientConnectionFactory, ClientConnectionFactory>()
.AddSingleton<IHostedService, HeartBeat>()
.AddSingleton<IAccessKeySynchronizer, AccessKeySynchronizer>()
#if NET7_0_OR_GREATER
.AddSingleton<IClientInvocationManager, ClientInvocationManager>()
#else
.AddSingleton<IClientInvocationManager, DummyClientInvocationManager>()
#endif
.AddSingleton(typeof(NegotiateHandler<>));

// If a custom router is added, do not add the default router
Expand All @@ -102,6 +97,14 @@ private static ISignalRServerBuilder AddAzureSignalRCore(this ISignalRServerBuil
// If a custom service event handler is added, do not add the default handler.
builder.Services.TryAddSingleton<IServiceEventHandler, DefaultServiceEventHandler>();

// IEndpointRouter and IAccessKeySynchronizer is required to build ClientInvocationManager.
builder.Services
#if NET7_0_OR_GREATER
.AddSingleton<IClientInvocationManager, ClientInvocationManager>();
#else
.AddSingleton<IClientInvocationManager, DummyClientInvocationManager>();
#endif

#if !NETSTANDARD2_0
builder.Services.TryAddSingleton<AzureSignalRHostedService>();
builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton<IStartupFilter, AzureSignalRStartupFilter>());
Expand Down
6 changes: 4 additions & 2 deletions src/Microsoft.Azure.SignalR/HubHost/ServiceLifetimeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ internal class ServiceLifetimeManager<THub> : ServiceLifetimeManagerBase<THub> w
private readonly IClientInvocationManager _clientInvocationManager;
private readonly IClientConnectionManager _clientConnectionManager;
private readonly string _callerId;
private readonly string _hub;

public ServiceLifetimeManager(
IServiceConnectionManager<THub> serviceConnectionManager,
Expand All @@ -49,9 +50,10 @@ public ServiceLifetimeManager(
throw new InvalidOperationException(MarkerNotConfiguredError);
}
#endif
_hub = typeof(THub).Name;
if (hubOptions.Value.SupportedProtocols != null && hubOptions.Value.SupportedProtocols.Any(x => x.Equals(Constants.Protocol.BlazorPack, StringComparison.OrdinalIgnoreCase)))
{
blazorDetector?.TrySetBlazor(typeof(THub).Name, true);
blazorDetector?.TrySetBlazor(_hub, true);
}

_callerId = nameProvider?.GetName() ?? throw new ArgumentNullException(nameof(nameProvider));
Expand Down Expand Up @@ -128,7 +130,7 @@ public override async Task<T> InvokeConnectionAsync<T>(string connectionId, stri
var invocationId = _clientInvocationManager.Caller.GenerateInvocationId(connectionId);
var message = AppendMessageTracingId(new ClientInvocationMessage(invocationId, connectionId, _callerId, SerializeAllProtocols(methodName, args, invocationId)));
await WriteAsync(message);
var task = _clientInvocationManager.Caller.AddInvocation<T>(connectionId, invocationId, cancellationToken);
var task = _clientInvocationManager.Caller.AddInvocation<T>(_hub, connectionId, invocationId, cancellationToken);

// Exception handling follows https://source.dot.net/#Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs,349
try
Expand Down
Loading
Loading