diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs index 877166c2c..b9d8e57f3 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ClientConnectionContext.cs @@ -54,7 +54,9 @@ internal class ClientConnectionContext : ConnectionContext, IConnectionStatFeature { private const int WritingState = 1; + private const int CompletedState = 2; + private const int IdleState = 0; private static readonly PipeOptions DefaultPipeOptions = new PipeOptions(pauseWriterThreshold: 0, @@ -63,12 +65,13 @@ internal class ClientConnectionContext : ConnectionContext, useSynchronizationContext: false); private readonly TaskCompletionSource _connectionEndTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - private readonly CancellationTokenSource _abortOutgoingCts = new CancellationTokenSource(); - private int _connectionState = IdleState; + private readonly CancellationTokenSource _abortOutgoingCts = new CancellationTokenSource(); private readonly object _heartbeatLock = new object(); + private int _connectionState = IdleState; + private List<(Action handler, object state)> _heartbeatHandlers; private volatile bool _abortOnClose = true; @@ -175,6 +178,7 @@ public async Task WriteMessageAsync(ReadOnlySequence payload) { _lastMessageReceivedAt = DateTime.UtcNow.Ticks; _receivedBytes += payload.Length; + // Start write await WriteMessageAsyncCore(payload); } @@ -237,6 +241,53 @@ public void CancelOutgoing(int millisecondsDelay = 0) } } + internal static bool TryGetRemoteIpAddress(IHeaderDictionary headers, out IPAddress address) + { + var forwardedFor = headers.GetCommaSeparatedValues("X-Forwarded-For"); + if (forwardedFor.Length > 0 && IPAddress.TryParse(forwardedFor[0], out address)) + { + return true; + } + address = null; + return false; + } + + private static void ProcessQuery(string queryString, out string originalPath) + { + originalPath = string.Empty; + var query = QueryHelpers.ParseNullableQuery(queryString); + if (query == null) + { + return; + } + + if (query.TryGetValue(Constants.QueryParameter.RequestCulture, out var culture)) + { + SetCurrentThreadCulture(culture.FirstOrDefault()); + } + if (query.TryGetValue(Constants.QueryParameter.OriginalPath, out var path)) + { + originalPath = path.FirstOrDefault(); + } + } + + private static void SetCurrentThreadCulture(string cultureName) + { + if (!string.IsNullOrEmpty(cultureName)) + { + try + { + var requestCulture = new RequestCulture(cultureName); + CultureInfo.CurrentCulture = requestCulture.Culture; + CultureInfo.CurrentUICulture = requestCulture.UICulture; + } + catch (Exception) + { + // skip invalid culture, normal won't hit. + } + } + } + private FeatureCollection BuildFeatures(OpenConnectionMessage serviceMessage) { var features = new FeatureCollection(); @@ -311,52 +362,5 @@ private string GetInstanceId(IDictionary header) } return string.Empty; } - - internal static bool TryGetRemoteIpAddress(IHeaderDictionary headers, out IPAddress address) - { - var forwardedFor = headers.GetCommaSeparatedValues("X-Forwarded-For"); - if (forwardedFor.Length > 0 && IPAddress.TryParse(forwardedFor[0], out address)) - { - return true; - } - address = null; - return false; - } - - private static void ProcessQuery(string queryString, out string originalPath) - { - originalPath = string.Empty; - var query = QueryHelpers.ParseNullableQuery(queryString); - if (query == null) - { - return; - } - - if (query.TryGetValue(Constants.QueryParameter.RequestCulture, out var culture)) - { - SetCurrentThreadCulture(culture.FirstOrDefault()); - } - if (query.TryGetValue(Constants.QueryParameter.OriginalPath, out var path)) - { - originalPath = path.FirstOrDefault(); - } - } - - private static void SetCurrentThreadCulture(string cultureName) - { - if (!string.IsNullOrEmpty(cultureName)) - { - try - { - var requestCulture = new RequestCulture(cultureName); - CultureInfo.CurrentCulture = requestCulture.Culture; - CultureInfo.CurrentUICulture = requestCulture.UICulture; - } - catch (Exception) - { - // skip invalid culture, normal won't hit. - } - } - } } } diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs index f67636c22..1a5aeda94 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs @@ -21,16 +21,20 @@ internal partial class ServiceConnection : ServiceConnectionBase { private const int DefaultCloseTimeoutMilliseconds = 30000; + private const string ClientConnectionCountInHub = "#clientInHub"; + + private const string ClientConnectionCountInServiceConnection = "#client"; + // Fix issue: https://github.com/Azure/azure-signalr/issues/198 // .NET Framework has restriction about reserved string as the header name like "User-Agent" private static readonly Dictionary CustomHeader = new Dictionary { { Constants.AsrsUserAgent, ProductInfo.GetProductInfo() } }; - private const string ClientConnectionCountInHub = "#clientInHub"; - private const string ClientConnectionCountInServiceConnection = "#client"; - private readonly IConnectionFactory _connectionFactory; + private readonly IClientConnectionFactory _clientConnectionFactory; + private readonly int _closeTimeOutMilliseconds; + private readonly IClientConnectionManager _clientConnectionManager; private readonly ConcurrentDictionary _connectionIds = @@ -155,10 +159,12 @@ protected override Task OnClientDisconnectedAsync(CloseConnectionMessage closeCo { context.AbortOnClose = false; context.Features.Set(new ConnectionMigrationFeature(ServerId, to)); + // We have to prevent SignalR `{type: 7}` (close message) from reaching our client while doing migration. // Since all data messages will be sent to `ServiceConnection` directly. // We can simply ignore all messages came from the application. context.CancelOutgoing(); + // The close connection message must be the last message, so we could complete the pipe. context.CompleteIncoming(); } @@ -211,14 +217,15 @@ protected override Task OnPingMessageAsync(PingMessage pingMessage) if (RuntimeServicePingMessage.TryGetOffline(pingMessage, out var instanceId)) { _clientInvocationManager.Caller.CleanupInvocationsByInstance(instanceId); + // Router invocations will be cleanup by its `CleanupInvocationsByConnection`, which is called by `RemoveClientConnection`. - // In `base.OnPingMessageAsync`, `CleanupClientConnections(instanceId)` will finally execute `RemoveClientConnection` for each ConnectionId. + // In `base.OnPingMessageAsync`, `CleanupClientConnections(instanceId)` will finally execute `RemoveClientConnection` for each ConnectionId. } #endif return base.OnPingMessageAsync(pingMessage); } - private async Task ProcessClientConnectionAsync(ClientConnectionContext connection) + private async Task ProcessClientConnectionAsync(ClientConnectionContext connection) { try { @@ -276,6 +283,7 @@ private async Task ProcessClientConnectionAsync(ClientConnectionContext connecti // Inform the Service that we will remove the client because SignalR told us it is disconnected. var serviceMessage = new CloseConnectionMessage(connection.ConnectionId, errorMessage: exception?.Message); + // when it fails, it means the underlying connection is dropped // service is responsible for closing the client connections in this case and there is no need to throw await SafeWriteAsync(serviceMessage); @@ -494,4 +502,4 @@ private Task OnErrorCompletionAsync(ErrorCompletionMessage errorCompletionMessag return Task.CompletedTask; } } -} \ No newline at end of file +}