From 4d7792b20929f14026f9df581501689edf08fd6a Mon Sep 17 00:00:00 2001 From: "Liangying.Wei" Date: Wed, 21 Aug 2024 16:05:02 +0800 Subject: [PATCH] Sticky to one service connection if not in async scope (#2023) * Add a byte partitionkey for unscoped message sending --- .../ChatSample/ChatSample.Net60/Program.cs | 4 + .../ChatSample.Net60/StreamingService.cs | 50 +++++ .../ServiceConnectionBase.cs | 5 +- .../ServiceConnectionContainerBase.cs | 197 +++++++++++------- .../CheckWithAckMessage.cs | 10 +- .../CloseWithAckMessage.cs | 4 +- .../ConnectionMessage.cs | 9 +- .../GroupMessage.cs | 34 ++- .../MulticastDataMessage.cs | 33 ++- .../ServiceMessage.cs | 10 + .../ServiceConnectionContainerBaseTests.cs | 37 ++++ .../ServiceProtocolFacts.cs | 109 +++++++++- .../ServiceConnectionContainerTest.cs | 51 ++++- 13 files changed, 443 insertions(+), 110 deletions(-) create mode 100644 samples/ChatSample/ChatSample.Net60/StreamingService.cs diff --git a/samples/ChatSample/ChatSample.Net60/Program.cs b/samples/ChatSample/ChatSample.Net60/Program.cs index 8a7bb3e96..3a6c08868 100644 --- a/samples/ChatSample/ChatSample.Net60/Program.cs +++ b/samples/ChatSample/ChatSample.Net60/Program.cs @@ -1,3 +1,4 @@ +using ChatSample.Net60; using ChatSample.Net60.Hubs; var builder = WebApplication.CreateBuilder(args); @@ -6,6 +7,9 @@ builder.Services.AddRazorPages(); builder.Services.AddSignalR().AddAzureSignalR(); +// uncomment for streaming outside the scope +// builder.Services.AddHostedService(); + var app = builder.Build(); // Configure the HTTP request pipeline. diff --git a/samples/ChatSample/ChatSample.Net60/StreamingService.cs b/samples/ChatSample/ChatSample.Net60/StreamingService.cs new file mode 100644 index 000000000..58968f68f --- /dev/null +++ b/samples/ChatSample/ChatSample.Net60/StreamingService.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + + +using ChatSample.Net60.Hubs; +using Microsoft.AspNetCore.SignalR; +using Microsoft.Extensions.Logging; + +namespace ChatSample.Net60 +{ + public class StreamingService : IHostedService + { + private readonly IHubContext _hubContext; + private readonly ILogger _logger; + + public StreamingService(IHubContext hubContext, ILogger logger) { + _hubContext = hubContext; + _logger = logger; + } + public Task StartAsync(CancellationToken cancellationToken) + { + return Task.Factory.StartNew(() => StreamingTask(cancellationToken), TaskCreationOptions.LongRunning); + } + + public Task StopAsync(CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + private async Task StreamingTask(CancellationToken cancellationToken) + { + long counter = 0; + + _logger.LogInformation("Waiting"); + + await Task.Delay(5000); + + _logger.LogInformation("Spamming"); + + while (!cancellationToken.IsCancellationRequested) + { + counter++; + + await _hubContext.Clients.All.SendAsync("ReceiveMessage", counter, counter); + + await Task.Delay(1); + } + } + } +} diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs index d2434c2c4..8a544cb06 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs @@ -178,6 +178,8 @@ public async Task StartAsync(string target = null) } finally { + // mark the status as Disconnected so that no one will write to this connection anymore + Status = ServiceConnectionStatus.Disconnected; syncTimer?.Stop(); // when ProcessIncoming completes, clean up the connection @@ -195,10 +197,7 @@ public async Task StartAsync(string target = null) finally { // wait until all the connections are cleaned up to close the outgoing pipe - // mark the status as Disconnected so that no one will write to this connection anymore // Don't allow write anymore when the connection is disconnected - Status = ServiceConnectionStatus.Disconnected; - await _writeLock.WaitAsync(); try { diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs index 9c073287d..f6cf7bf55 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs @@ -2,9 +2,11 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.SignalR.Common; @@ -24,11 +26,16 @@ internal abstract class ServiceConnectionContainerBase : IServiceConnectionConta private static readonly int MaxReconnectBackOffInternalInMilliseconds = 1000; + private static readonly TimeSpan MessageWriteRetryDelay = TimeSpan.FromMilliseconds(200); + private static readonly int MessageWriteMaxRetry = 3; + // Give (interval * 3 + 1) delay when check value expire. private static readonly long DefaultServersPingTimeoutTicks = Stopwatch.Frequency * ((long)Constants.Periods.DefaultServersPingInterval.TotalSeconds * 3 + 1); private static readonly Tuple DefaultServersTagContext = new Tuple(string.Empty, 0); + private readonly IReadOnlyDictionary>> _partitionedCache; + private readonly BackOffPolicy _backOffPolicy = new BackOffPolicy(); private readonly object _lock = new object(); @@ -155,6 +162,8 @@ protected ServiceConnectionContainerBase(IServiceConnectionFactory serviceConnec } _serversPing = new CustomizedPingTimer(Logger, Constants.CustomizedPingTimer.Servers, WriteServersPingAsync, Constants.Periods.DefaultServersPingInterval, Constants.Periods.DefaultServersPingInterval); + + _partitionedCache = Enumerable.Range(0, 256).ToDictionary(i => (byte)i, i => new StrongBox>(new WeakReference(null))); } public event Action ConnectionStatusChanged; @@ -216,7 +225,7 @@ public void HandleAck(AckMessage ackMessage) public virtual Task WriteAsync(ServiceMessage serviceMessage) { - return WriteToScopedOrRandomAvailableConnection(serviceMessage); + return WriteMessageAsync(serviceMessage); } public async Task WriteAckableMessageAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default) @@ -233,7 +242,7 @@ public async Task WriteAckableMessageAsync(ServiceMessage serviceMessage, // whereas ackable ones complete upon full roundtrip of the message and the ack (or timeout). // Therefore sending them over different connections creates a possibility for processing them out of original order. // By sending both message types over the same connection we ensure that they are sent (and processed) in their original order. - await WriteToScopedOrRandomAvailableConnection(serviceMessage); + await WriteMessageAsync(serviceMessage); var status = await task; return AckHandler.HandleAckStatus(ackableMessage, status); @@ -414,6 +423,115 @@ protected async Task RemoveConnectionAsync(IServiceConnection c, GracefulShutdow Log.TimeoutWaitingForFinAck(Logger, retry); } + private async Task WriteMessageAsync(ServiceMessage serviceMessage) + { + var connection = SelectConnection(serviceMessage); + + var retry = 0; + var maxRetry = MessageWriteMaxRetry; + var delay = MessageWriteRetryDelay; + while (true) + { + try + { + await connection.WriteAsync(serviceMessage); + return; + } + catch (ServiceConnectionNotActiveException) + { + // enter the re-select logic + retry++; + if (retry == maxRetry) + { + throw; + } + + await Task.Delay(delay); + connection = SelectConnection(serviceMessage); + } + } + } + + private IServiceConnection SelectConnection(ServiceMessage message) + { + IServiceConnection connection = null; + if (ClientConnectionScope.IsScopeEstablished) + { + // see if the execution context already has the connection stored for this container + var containers = ClientConnectionScope.OutboundServiceConnections; + if (!(containers.TryGetValue(Endpoint.UniqueIndex, out var connectionWeakReference) + && connectionWeakReference.TryGetTarget(out connection) + && IsActiveConnection(connection))) + { + connection = GetRandomActiveConnection(); + ClientConnectionScope.OutboundServiceConnections[Endpoint.UniqueIndex] = new WeakReference(connection); + } + } + else + { + // if it is not in scope + // if message is partitionable, use the container's partition cache, otherwise use a random connection + if (message is IPartitionableMessage partitionable) + { + var box = _partitionedCache[partitionable.PartitionKey]; + if (!box.Value.TryGetTarget(out connection) || !IsActiveConnection(connection)) + { + lock (box) + { + if (!box.Value.TryGetTarget(out connection) || !IsActiveConnection(connection)) + { + connection = GetRandomActiveConnection(); + box.Value.SetTarget(connection); + } + } + } + } + else + { + connection = GetRandomActiveConnection(); + } + } + + if (connection == null) + { + throw new ServiceConnectionNotActiveException(); + } + + return connection; + } + + private bool IsActiveConnection(IServiceConnection connection) + { + return connection != null && connection.Status == ServiceConnectionStatus.Connected; + } + + private IServiceConnection GetRandomActiveConnection() + { + var currentConnections = ServiceConnections; + + // go through all the connections, it can be useful when one of the remote service instances is down + var count = currentConnections.Count; + var initial = StaticRandom.Next(-count, count); + var maxRetry = count; + var retry = 0; + var index = (initial & int.MaxValue) % count; + var direction = initial > 0 ? 1 : count - 1; + + while (retry < maxRetry) + { + var connection = currentConnections[index]; + if (IsActiveConnection(connection)) + { + return connection; + } + + retry++; + index = (index + direction) % count; + } + + return null; + } + private async Task RestartFixedServiceConnectionCoreAsync(int index) { if (_terminated) @@ -481,81 +599,6 @@ private void OnConnectionStatusChanged(StatusChange obj) } } - private async Task WriteToScopedOrRandomAvailableConnection(ServiceMessage serviceMessage) - { - // ServiceConnections can change the collection underneath so we make a local copy and pass it along - var currentConnections = ServiceConnections; - - if (ClientConnectionScope.IsScopeEstablished) - { - // see if the execution context already has the connection stored for this container - var containers = ClientConnectionScope.OutboundServiceConnections; - Debug.Assert(containers != null); - containers.TryGetValue(Endpoint.UniqueIndex, out var connectionWeakReference); - IServiceConnection connection = null; - connectionWeakReference?.TryGetTarget(out connection); - - var connectionUsed = await WriteWithRetry(serviceMessage, connection, currentConnections); - - // Todo: - // There is currently no synchronization when persisting selected connection in ClientConnectionScope. - // This is only a concern when there are concurrent writes involved and when one of the following is true: - // - we need to change the selected connection (e.g. the currently persisted connection status is bad) - // - we need to make the initial connection selection (e.g. no secondary connection in async local yet) - // This lack of synchronization can lead to using multiple connections and cause out of order messages. - - // Try to persist the connection choice for the subsequent calls within the same async flow - if (connectionUsed != connection) - { - ClientConnectionScope.OutboundServiceConnections[Endpoint.UniqueIndex] = new WeakReference(connectionUsed); - } - } - else - { - await WriteWithRetry(serviceMessage, null, currentConnections); - } - } - - private async Task WriteWithRetry(ServiceMessage serviceMessage, IServiceConnection connection, List currentConnections) - { - // go through all the connections, it can be useful when one of the remote service instances is down - var count = currentConnections.Count; - var initial = StaticRandom.Next(-count, count); - var maxRetry = count; - var retry = 0; - var index = (initial & int.MaxValue) % count; - var direction = initial > 0 ? 1 : count - 1; - - // ensure a full sweep starting with the connection flowed with the async context - while (retry <= maxRetry) - { - if (connection != null && connection.Status == ServiceConnectionStatus.Connected) - { - try - { - // still possible the connection is not valid - await connection.WriteAsync(serviceMessage); - return connection; - } - catch (ServiceConnectionNotActiveException) - { - if (retry == maxRetry - 1) - { - throw; - } - } - } - - // try current index instead - connection = currentConnections[index]; - - retry++; - index = (index + direction) % count; - } - - throw new ServiceConnectionNotActiveException(); - } - private IEnumerable CreateFixedServiceConnection(int count) { for (int i = 0; i < count; i++) diff --git a/src/Microsoft.Azure.SignalR.Protocols/CheckWithAckMessage.cs b/src/Microsoft.Azure.SignalR.Protocols/CheckWithAckMessage.cs index 6c3cae8c9..b153a840b 100644 --- a/src/Microsoft.Azure.SignalR.Protocols/CheckWithAckMessage.cs +++ b/src/Microsoft.Azure.SignalR.Protocols/CheckWithAckMessage.cs @@ -28,7 +28,7 @@ protected CheckWithAckMessage(int ackId, ulong? tracingId) /// /// A waiting for ack check-user-in-group message. /// - public class CheckUserInGroupWithAckMessage : CheckWithAckMessage + public class CheckUserInGroupWithAckMessage : CheckWithAckMessage, IPartitionableMessage { /// /// Gets or sets the user Id. @@ -39,6 +39,7 @@ public class CheckUserInGroupWithAckMessage : CheckWithAckMessage /// Gets or sets the group name. /// public string GroupName { get; set; } + public byte PartitionKey => GeneratePartitionKey(GroupName); /// /// Initializes a new instance of the class. @@ -57,13 +58,15 @@ public CheckUserInGroupWithAckMessage(string userId, string groupName, int ackId /// /// A waiting for ack check-any-connection-in-group message. /// - public class CheckGroupExistenceWithAckMessage : CheckWithAckMessage + public class CheckGroupExistenceWithAckMessage : CheckWithAckMessage, IPartitionableMessage { /// /// Gets or sets the group name. /// public string GroupName { get; set; } + public byte PartitionKey => GeneratePartitionKey(GroupName); + /// /// Initializes a new instance of the class. /// @@ -101,12 +104,13 @@ public CheckConnectionExistenceWithAckMessage(string connectionId, int ackId = 0 /// /// A waiting for ack check-user-existence message. /// - public class CheckUserExistenceWithAckMessage : CheckWithAckMessage + public class CheckUserExistenceWithAckMessage : CheckWithAckMessage, IPartitionableMessage { /// /// Gets or sets the user Id. /// public string UserId { get; set; } + public byte PartitionKey => GeneratePartitionKey(UserId); /// /// Initializes a new instance of the class. diff --git a/src/Microsoft.Azure.SignalR.Protocols/CloseWithAckMessage.cs b/src/Microsoft.Azure.SignalR.Protocols/CloseWithAckMessage.cs index b5cd02547..0c842ccc3 100644 --- a/src/Microsoft.Azure.SignalR.Protocols/CloseWithAckMessage.cs +++ b/src/Microsoft.Azure.SignalR.Protocols/CloseWithAckMessage.cs @@ -104,13 +104,15 @@ public CloseUserConnectionsWithAckMessage(string userId, int ackId) : base(ackId /// /// Close connections in a group. /// - public class CloseGroupConnectionsWithAckMessage : CloseMultiConnectionsWithAckMessage + public class CloseGroupConnectionsWithAckMessage : CloseMultiConnectionsWithAckMessage, IPartitionableMessage { /// /// Gets or sets the group name. /// public string GroupName { get; set; } + public byte PartitionKey => GeneratePartitionKey(GroupName); + /// /// Initializes a new instance of the class. /// diff --git a/src/Microsoft.Azure.SignalR.Protocols/ConnectionMessage.cs b/src/Microsoft.Azure.SignalR.Protocols/ConnectionMessage.cs index 79740304d..830eda38b 100644 --- a/src/Microsoft.Azure.SignalR.Protocols/ConnectionMessage.cs +++ b/src/Microsoft.Azure.SignalR.Protocols/ConnectionMessage.cs @@ -126,7 +126,7 @@ public CloseConnectionMessage(string connectionId) : this(connectionId, "") /// /// A connection data message. /// - public class ConnectionDataMessage : ConnectionMessage, IMessageWithTracingId, IHasDataMessageType, IPartializable + public class ConnectionDataMessage : ConnectionMessage, IMessageWithTracingId, IHasDataMessageType, IPartializable, IPartitionableMessage { /// /// Initializes a new instance of the class. @@ -134,10 +134,9 @@ public class ConnectionDataMessage : ConnectionMessage, IMessageWithTracingId, I /// The connection Id. /// Binary data to be delivered. /// The tracing Id of the message - public ConnectionDataMessage(string connectionId, ReadOnlyMemory payload, ulong? tracingId = null) : base(connectionId) + public ConnectionDataMessage(string connectionId, ReadOnlyMemory payload, ulong? tracingId = null) + : this(connectionId, new ReadOnlySequence(payload), tracingId) { - Payload = new ReadOnlySequence(payload); - TracingId = tracingId; } /// @@ -171,6 +170,8 @@ public ConnectionDataMessage(string connectionId, ReadOnlySequence payload /// Gets or sets the payload is partial or not. /// public bool IsPartial { get; set; } + + public byte PartitionKey => GeneratePartitionKey(ConnectionId); } /// diff --git a/src/Microsoft.Azure.SignalR.Protocols/GroupMessage.cs b/src/Microsoft.Azure.SignalR.Protocols/GroupMessage.cs index 954e30bec..1ff99e8e6 100644 --- a/src/Microsoft.Azure.SignalR.Protocols/GroupMessage.cs +++ b/src/Microsoft.Azure.SignalR.Protocols/GroupMessage.cs @@ -6,7 +6,7 @@ namespace Microsoft.Azure.SignalR.Protocol /// /// A join-group message. /// - public class JoinGroupMessage : ExtensibleServiceMessage, IMessageWithTracingId + public class JoinGroupMessage : ExtensibleServiceMessage, IMessageWithTracingId, IPartitionableMessage { /// /// Gets or sets the connection Id. @@ -23,6 +23,8 @@ public class JoinGroupMessage : ExtensibleServiceMessage, IMessageWithTracingId /// public ulong? TracingId { get; set; } + public byte PartitionKey => GeneratePartitionKey(GroupName); + /// /// Initializes a new instance of the class. /// @@ -40,7 +42,7 @@ public JoinGroupMessage(string connectionId, string groupName, ulong? tracingId /// /// A leave-group message. /// - public class LeaveGroupMessage : ExtensibleServiceMessage, IMessageWithTracingId + public class LeaveGroupMessage : ExtensibleServiceMessage, IMessageWithTracingId, IPartitionableMessage { /// /// Gets or sets the connection Id. @@ -57,6 +59,8 @@ public class LeaveGroupMessage : ExtensibleServiceMessage, IMessageWithTracingId /// public ulong? TracingId { get; set; } + public byte PartitionKey => GeneratePartitionKey(GroupName); + /// /// Initializes a new instance of the class. /// @@ -74,7 +78,7 @@ public LeaveGroupMessage(string connectionId, string groupName, ulong? tracingId /// /// A user-join-group message. /// - public class UserJoinGroupMessage : ExtensibleServiceMessage, IMessageWithTracingId, IHasTtl + public class UserJoinGroupMessage : ExtensibleServiceMessage, IMessageWithTracingId, IHasTtl, IPartitionableMessage { /// /// Gets or sets the user Id. @@ -96,6 +100,8 @@ public class UserJoinGroupMessage : ExtensibleServiceMessage, IMessageWithTracin /// public int? Ttl { get; set; } + public byte PartitionKey => GeneratePartitionKey(GroupName); + /// /// Initializes a new instance of the class. /// @@ -113,7 +119,7 @@ public UserJoinGroupMessage(string userId, string groupName, ulong? tracingId = /// /// A user-leave-group message. /// - public class UserLeaveGroupMessage : ExtensibleServiceMessage, IMessageWithTracingId + public class UserLeaveGroupMessage : ExtensibleServiceMessage, IMessageWithTracingId, IPartitionableMessage { /// /// Gets or sets the user Id. @@ -130,6 +136,8 @@ public class UserLeaveGroupMessage : ExtensibleServiceMessage, IMessageWithTraci /// public ulong? TracingId { get; set; } + public byte PartitionKey => GeneratePartitionKey(GroupName); + /// /// Initializes a new instance of the class. /// @@ -147,7 +155,7 @@ public UserLeaveGroupMessage(string userId, string groupName, ulong? tracingId = /// /// A waiting for ack user-join-group message. /// - public class UserJoinGroupWithAckMessage : ExtensibleServiceMessage, IMessageWithTracingId, IHasTtl, IAckableMessage + public class UserJoinGroupWithAckMessage : ExtensibleServiceMessage, IMessageWithTracingId, IHasTtl, IAckableMessage, IPartitionableMessage { /// /// Gets or sets the user Id. @@ -174,6 +182,8 @@ public class UserJoinGroupWithAckMessage : ExtensibleServiceMessage, IMessageWit /// public int AckId { get; set; } + public byte PartitionKey => GeneratePartitionKey(GroupName); + /// /// Initializes a new instance of the class. /// @@ -195,7 +205,7 @@ public UserJoinGroupWithAckMessage(string userId, string groupName, int ackId, i /// /// A waiting for ack user-leave-group message. /// - public class UserLeaveGroupWithAckMessage : ExtensibleServiceMessage, IMessageWithTracingId, IAckableMessage + public class UserLeaveGroupWithAckMessage : ExtensibleServiceMessage, IMessageWithTracingId, IAckableMessage, IPartitionableMessage { /// /// Gets or sets the user Id. @@ -217,6 +227,8 @@ public class UserLeaveGroupWithAckMessage : ExtensibleServiceMessage, IMessageWi /// public int AckId { get; set; } + public byte PartitionKey => GeneratePartitionKey(GroupName); + /// /// Initializes a new instance of the class. /// @@ -236,7 +248,7 @@ public UserLeaveGroupWithAckMessage(string userId, string groupName, int ackId, /// /// A waiting for ack join-group message. /// - public class JoinGroupWithAckMessage : ExtensibleServiceMessage, IAckableMessage, IMessageWithTracingId + public class JoinGroupWithAckMessage : ExtensibleServiceMessage, IAckableMessage, IMessageWithTracingId, IPartitionableMessage { /// /// Gets or sets the connection Id. @@ -258,6 +270,8 @@ public class JoinGroupWithAckMessage : ExtensibleServiceMessage, IAckableMessage /// public ulong? TracingId { get; set; } + public byte PartitionKey => GeneratePartitionKey(GroupName); + /// /// Initializes a new instance of the class. /// @@ -266,7 +280,6 @@ public class JoinGroupWithAckMessage : ExtensibleServiceMessage, IAckableMessage /// The tracing Id of the message. public JoinGroupWithAckMessage(string connectionId, string groupName, ulong? tracingId = null): this(connectionId, groupName, 0, tracingId) { - TracingId = tracingId; } /// @@ -288,7 +301,7 @@ public JoinGroupWithAckMessage(string connectionId, string groupName, int ackId, /// /// A waiting for ack leave-group message. /// - public class LeaveGroupWithAckMessage : ExtensibleServiceMessage, IAckableMessage, IMessageWithTracingId + public class LeaveGroupWithAckMessage : ExtensibleServiceMessage, IAckableMessage, IMessageWithTracingId, IPartitionableMessage { /// /// Gets or sets the connection Id. @@ -310,6 +323,8 @@ public class LeaveGroupWithAckMessage : ExtensibleServiceMessage, IAckableMessag /// public ulong? TracingId { get; set; } + public byte PartitionKey => GeneratePartitionKey(GroupName); + /// /// Initializes a new instance of the class. /// @@ -318,7 +333,6 @@ public class LeaveGroupWithAckMessage : ExtensibleServiceMessage, IAckableMessag /// The tracing Id of the message. public LeaveGroupWithAckMessage(string connectionId, string groupName, ulong? tracingId = null): this(connectionId, groupName, 0, tracingId) { - TracingId = tracingId; } /// diff --git a/src/Microsoft.Azure.SignalR.Protocols/MulticastDataMessage.cs b/src/Microsoft.Azure.SignalR.Protocols/MulticastDataMessage.cs index a4d4d0196..3ac4cb31e 100644 --- a/src/Microsoft.Azure.SignalR.Protocols/MulticastDataMessage.cs +++ b/src/Microsoft.Azure.SignalR.Protocols/MulticastDataMessage.cs @@ -47,8 +47,9 @@ protected MulticastDataMessage(IDictionary> payload /// /// A data message which will be sent to multiple connections. /// - public class MultiConnectionDataMessage : MulticastDataMessage + public class MultiConnectionDataMessage : MulticastDataMessage, IPartitionableMessage { + private static readonly byte Key = GeneratePartitionKey(nameof(MultiConnectionDataMessage)); /// /// Initializes a new instance of the class. /// @@ -65,18 +66,22 @@ public MultiConnectionDataMessage(IReadOnlyList connectionList, /// Gets or sets the list of connections which will receive this message. /// public IReadOnlyList ConnectionList { get; set; } + + public byte PartitionKey => Key; } /// /// A data message which will be sent to a user. /// - public class UserDataMessage : MulticastDataMessage + public class UserDataMessage : MulticastDataMessage, IPartitionableMessage { /// /// Gets or sets the user Id. /// public string UserId { get; set; } + public byte PartitionKey => GeneratePartitionKey(UserId); + /// /// Initializes a new instance of the class. /// @@ -92,8 +97,9 @@ public UserDataMessage(string userId, IDictionary> /// /// A data message which will be sent to multiple users. /// - public class MultiUserDataMessage : MulticastDataMessage + public class MultiUserDataMessage : MulticastDataMessage, IPartitionableMessage { + private static readonly byte Key = GeneratePartitionKey(nameof(MultiUserDataMessage)); /// /// Initializes a new instance of the class. /// @@ -109,18 +115,23 @@ public MultiUserDataMessage(IReadOnlyList userList, IDictionary public IReadOnlyList UserList { get; set; } + + public byte PartitionKey => Key; } /// /// A data message which will be broadcasted. /// - public class BroadcastDataMessage : MulticastDataMessage + public class BroadcastDataMessage : MulticastDataMessage, IPartitionableMessage { + private static readonly byte Key = GeneratePartitionKey(nameof(BroadcastDataMessage)); /// /// Gets or sets the list of excluded connection Ids. /// public IReadOnlyList ExcludedList { get; set; } + public byte PartitionKey => Key; + /// /// Initializes a new instance of the class. /// @@ -145,13 +156,15 @@ public BroadcastDataMessage(IReadOnlyList excludedList, IDictionary /// A data message which will be broadcasted within a group. /// - public class GroupBroadcastDataMessage : MulticastDataMessage + public class GroupBroadcastDataMessage : MulticastDataMessage, IPartitionableMessage { /// /// Gets or sets the group name. /// public string GroupName { get; set; } + public byte PartitionKey => GeneratePartitionKey(GroupName); + /// /// Gets or sets the list of excluded connection Ids. /// @@ -196,13 +209,17 @@ public GroupBroadcastDataMessage(string groupName, IReadOnlyList exclude /// /// A data message which will be broadcasted within multiple groups. /// - public class MultiGroupBroadcastDataMessage : MulticastDataMessage + public class MultiGroupBroadcastDataMessage : MulticastDataMessage, IPartitionableMessage { + private static readonly byte Key = GeneratePartitionKey(nameof(MultiGroupBroadcastDataMessage)); + /// /// Gets or sets the list of group names. /// public IReadOnlyList GroupList { get; set; } + public byte PartitionKey => Key; + /// /// Initializes a new instance of the class. /// @@ -218,7 +235,7 @@ public MultiGroupBroadcastDataMessage(IReadOnlyList groupList, IDictiona /// /// A data message to indicate a client invocation request. /// - public class ClientInvocationMessage : MultiPayloadDataMessage + public class ClientInvocationMessage : MultiPayloadDataMessage, IPartitionableMessage { /// /// Initialize a new instance of class. @@ -250,5 +267,7 @@ public ClientInvocationMessage(string invocationId, string connectionId, string /// Gets or sets the caller server Id that init the client invocation. /// public string CallerServerId { get; set; } + + public byte PartitionKey => GeneratePartitionKey(ConnectionId); } } \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs b/src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs index 732af89a9..09df35130 100644 --- a/src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs +++ b/src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs @@ -54,6 +54,11 @@ public interface IAckableMessage int AckId { get; set; } } + public interface IPartitionableMessage + { + byte PartitionKey { get; } + } + /// /// Base class of messages between Azure SignalR Service and SDK. /// @@ -64,6 +69,11 @@ public abstract class ServiceMessage /// The default implementation is a shallow copy as it fits the current needs. /// public virtual ServiceMessage Clone() => MemberwiseClone() as ServiceMessage; + + public static byte GeneratePartitionKey(string input) + { + return (byte)((input?.GetHashCode() ?? 0) & 0xFF); + } } /// diff --git a/test/Microsoft.Azure.SignalR.Common.Tests/ServiceConnectionContainerBaseTests.cs b/test/Microsoft.Azure.SignalR.Common.Tests/ServiceConnectionContainerBaseTests.cs index 353e7099b..45aa52438 100644 --- a/test/Microsoft.Azure.SignalR.Common.Tests/ServiceConnectionContainerBaseTests.cs +++ b/test/Microsoft.Azure.SignalR.Common.Tests/ServiceConnectionContainerBaseTests.cs @@ -101,4 +101,41 @@ public void TestStrongConnectionStatus() Assert.True(endpoint1.Online); } } + + [Fact] + public void TestWriteMessageOrder() + { + using (StartVerifiableLog(out var loggerFactory, LogLevel.Warning, expectedErrors: e => true, + logChecker: s => + { + Assert.Single(s); + Assert.Equal("EndpointOffline", s[0].Write.EventId.Name); + return true; + })) + { + var endpoint1 = new TestHubServiceEndpoint(); + var conn1 = new TestServiceConnection(); + var scf = new TestServiceConnectionFactory(endpoint1 => conn1); + var container = new StrongServiceConnectionContainer(scf, 5, null, endpoint1, loggerFactory.CreateLogger(nameof(TestStrongConnectionStatus))); + + // When init, consider the endpoint as online + // TODO: improve the logic + Assert.True(endpoint1.Online); + + conn1.SetStatus(ServiceConnectionStatus.Connecting); + Assert.True(endpoint1.Online); + + conn1.SetStatus(ServiceConnectionStatus.Connected); + Assert.True(endpoint1.Online); + + conn1.SetStatus(ServiceConnectionStatus.Disconnected); + Assert.False(endpoint1.Online); + + conn1.SetStatus(ServiceConnectionStatus.Connecting); + Assert.False(endpoint1.Online); + + conn1.SetStatus(ServiceConnectionStatus.Connected); + Assert.True(endpoint1.Online); + } + } } diff --git a/test/Microsoft.Azure.SignalR.Protocols.Tests/ServiceProtocolFacts.cs b/test/Microsoft.Azure.SignalR.Protocols.Tests/ServiceProtocolFacts.cs index b37279871..b84552832 100644 --- a/test/Microsoft.Azure.SignalR.Protocols.Tests/ServiceProtocolFacts.cs +++ b/test/Microsoft.Azure.SignalR.Protocols.Tests/ServiceProtocolFacts.cs @@ -9,7 +9,8 @@ using System.Text; using Microsoft.Extensions.Primitives; - +using Moq; +using Newtonsoft.Json.Linq; using Xunit; namespace Microsoft.Azure.SignalR.Protocol.Tests @@ -810,6 +811,112 @@ public void ParseMessageWithExtraData() Assert.Equal(expectedMessage, openConnectionMessage, ServiceMessageEqualityComparer.Instance); } + [Fact] + public void PartitionKeyTest() + { + // group messages for the same group should have the same partition key + var group = "group1"; + var connectionId = Guid.NewGuid().ToString(); + var userId = "user1"; + var value = new ReadOnlyMemory(new byte[] { 1, 2, 3 }); + var payloads = new Dictionary> { ["1"] = value }; + byte? pk = null; + var groupMessages = new IPartitionableMessage[] { + new GroupBroadcastDataMessage(group, payloads), + new JoinGroupMessage(connectionId, group), + new JoinGroupWithAckMessage(connectionId, group), + new UserJoinGroupMessage(userId, group), + new UserJoinGroupWithAckMessage(userId, group, 0), + new LeaveGroupMessage(connectionId, group), + new LeaveGroupWithAckMessage(connectionId, group), + new UserLeaveGroupMessage(userId, group), + new UserLeaveGroupWithAckMessage(userId, group, 0), + new CheckGroupExistenceWithAckMessage(group), + new CheckUserInGroupWithAckMessage(userId, group), + new CloseGroupConnectionsWithAckMessage(group, 0) + }; + + foreach (var i in groupMessages) + { + pk ??= i.PartitionKey; + Assert.Equal(pk, i.PartitionKey); + } + + var userMessages = new IPartitionableMessage[] + { + new UserDataMessage(userId, payloads), + new CheckUserExistenceWithAckMessage(userId), + }; + + pk = null; + foreach (var i in userMessages) + { + pk ??= i.PartitionKey; + Assert.Equal(pk, i.PartitionKey); + } + + var broadcastMessages = new IPartitionableMessage[] + { + new BroadcastDataMessage(payloads), + new BroadcastDataMessage(payloads), + new BroadcastDataMessage(payloads), + new BroadcastDataMessage(payloads), + new BroadcastDataMessage(payloads), + }; + + pk = null; + foreach (var i in broadcastMessages) + { + pk ??= i.PartitionKey; + Assert.Equal(pk, i.PartitionKey); + } + + var mcm = new IPartitionableMessage[] + { + new MultiConnectionDataMessage([], payloads), + new MultiConnectionDataMessage([], payloads), + new MultiConnectionDataMessage([], payloads), + new MultiConnectionDataMessage([], payloads), + new MultiConnectionDataMessage([], payloads), + }; + pk = null; + foreach (var i in mcm) + { + pk ??= i.PartitionKey; + Assert.Equal(pk, i.PartitionKey); + } + + var mgm = new IPartitionableMessage[] + { + new MultiGroupBroadcastDataMessage([], payloads), + new MultiGroupBroadcastDataMessage([], payloads), + new MultiGroupBroadcastDataMessage([], payloads), + new MultiGroupBroadcastDataMessage([], payloads), + new MultiGroupBroadcastDataMessage([], payloads), + }; + pk = null; + foreach (var i in mgm) + { + pk ??= i.PartitionKey; + Assert.Equal(pk, i.PartitionKey); + } + + var mum = new IPartitionableMessage[] + { + new MultiUserDataMessage([], payloads), + new MultiUserDataMessage([], payloads), + new MultiUserDataMessage([], payloads), + new MultiUserDataMessage([], payloads), + new MultiUserDataMessage([], payloads), + }; + pk = null; + foreach (var i in mum) + { + pk ??= i.PartitionKey; + Assert.Equal(pk, i.PartitionKey); + } + } + private static byte ArrayBytes(int size) { return (byte)(0x90 | size); diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerTest.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerTest.cs index 7fbb2c1b6..f11350128 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerTest.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerTest.cs @@ -22,7 +22,7 @@ public async Task TestServiceConnectionOffline() var container = new StrongServiceConnectionContainer(factory, 3, 3, hubServiceEndpoint, NullLogger.Instance); Assert.True(factory.CreatedConnections.TryGetValue(hubServiceEndpoint, out var conns)); - var connections = conns.Select(x => (TestServiceConnection)x); + var connections = conns.Select(x => (TestServiceConnection)x).ToArray(); foreach (var connection in connections) { @@ -32,7 +32,7 @@ public async Task TestServiceConnectionOffline() // write 100 messages. for (var i = 0; i < 100; i++) { - var message = new ConnectionDataMessage("bar", new byte[12]); + var message = new ConnectionDataMessage(i.ToString(), new byte[12]); await container.WriteAsync(message); } @@ -43,12 +43,12 @@ public async Task TestServiceConnectionOffline() messageCount.TryAdd(connection.ConnectionId, connection.ReceivedMessages.Count); } - connections.First().SetStatus(ServiceConnectionStatus.Disconnected); + connections[0].SetStatus(ServiceConnectionStatus.Disconnected); // write 100 more messages. for (var i = 0; i < 100; i++) { - var message = new ConnectionDataMessage("bar", new byte[12]); + var message = new ConnectionDataMessage(i.ToString(), new byte[12]); await container.WriteAsync(message); } @@ -66,4 +66,47 @@ public async Task TestServiceConnectionOffline() index++; } } + + [Fact] + public async Task TestServiceConnectionStickyWrites() + { + var factory = new TestServiceConnectionFactory(); + var hubServiceEndpoint = new HubServiceEndpoint("foo", null, new TestServiceEndpoint()); + + var container = new StrongServiceConnectionContainer(factory, 3, 3, hubServiceEndpoint, NullLogger.Instance); + + Assert.True(factory.CreatedConnections.TryGetValue(hubServiceEndpoint, out var conns)); + var connections = conns.Select(x => (TestServiceConnection)x); + + foreach (var connection in connections) + { + connection.SetStatus(ServiceConnectionStatus.Connected); + } + + // write 100 messages. + for (var i = 0; i < 100; i++) + { + var message = new ConnectionDataMessage(i.ToString(), new byte[12]); + await container.WriteAsync(message); + } + + var messageCount = new Dictionary(); + foreach (var connection in connections) + { + Assert.NotEmpty(connection.ReceivedMessages); + messageCount.TryAdd(connection.ConnectionId, connection.ReceivedMessages.Count); + } + + // write 100 messages with the same connectionIds should double the message count for each service connection + for (var i = 0; i < 100; i++) + { + var message = new ConnectionDataMessage(i.ToString(), new byte[12]); + await container.WriteAsync(message); + } + + foreach (var connection in connections) + { + Assert.Equal(messageCount[connection.ConnectionId] * 2, connection.ReceivedMessages.Count); + } + } }