Skip to content

Commit

Permalink
Sticky to one service connection if not in async scope (#2023)
Browse files Browse the repository at this point in the history
* Add a byte partitionkey for unscoped message sending
  • Loading branch information
vicancy authored Aug 21, 2024
1 parent 2d3b312 commit 4d7792b
Show file tree
Hide file tree
Showing 13 changed files with 443 additions and 110 deletions.
4 changes: 4 additions & 0 deletions samples/ChatSample/ChatSample.Net60/Program.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using ChatSample.Net60;
using ChatSample.Net60.Hubs;

var builder = WebApplication.CreateBuilder(args);
Expand All @@ -6,6 +7,9 @@
builder.Services.AddRazorPages();
builder.Services.AddSignalR().AddAzureSignalR();

// uncomment for streaming outside the scope
// builder.Services.AddHostedService<StreamingService>();

var app = builder.Build();

// Configure the HTTP request pipeline.
Expand Down
50 changes: 50 additions & 0 deletions samples/ChatSample/ChatSample.Net60/StreamingService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.


using ChatSample.Net60.Hubs;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Logging;

namespace ChatSample.Net60
{
public class StreamingService : IHostedService
{
private readonly IHubContext<ChatHub> _hubContext;
private readonly ILogger<StreamingService> _logger;

public StreamingService(IHubContext<ChatHub> hubContext, ILogger<StreamingService> logger) {
_hubContext = hubContext;
_logger = logger;
}
public Task StartAsync(CancellationToken cancellationToken)
{
return Task.Factory.StartNew(() => StreamingTask(cancellationToken), TaskCreationOptions.LongRunning);
}

public Task StopAsync(CancellationToken cancellationToken)
{
return Task.CompletedTask;
}

private async Task StreamingTask(CancellationToken cancellationToken)
{
long counter = 0;

_logger.LogInformation("Waiting");

await Task.Delay(5000);

_logger.LogInformation("Spamming");

while (!cancellationToken.IsCancellationRequested)
{
counter++;

await _hubContext.Clients.All.SendAsync("ReceiveMessage", counter, counter);

await Task.Delay(1);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ public async Task StartAsync(string target = null)
}
finally
{
// mark the status as Disconnected so that no one will write to this connection anymore
Status = ServiceConnectionStatus.Disconnected;
syncTimer?.Stop();

// when ProcessIncoming completes, clean up the connection
Expand All @@ -195,10 +197,7 @@ public async Task StartAsync(string target = null)
finally
{
// wait until all the connections are cleaned up to close the outgoing pipe
// mark the status as Disconnected so that no one will write to this connection anymore
// Don't allow write anymore when the connection is disconnected
Status = ServiceConnectionStatus.Disconnected;

await _writeLock.WaitAsync();
try
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.SignalR.Common;
Expand All @@ -24,11 +26,16 @@ internal abstract class ServiceConnectionContainerBase : IServiceConnectionConta

private static readonly int MaxReconnectBackOffInternalInMilliseconds = 1000;

private static readonly TimeSpan MessageWriteRetryDelay = TimeSpan.FromMilliseconds(200);
private static readonly int MessageWriteMaxRetry = 3;

// Give (interval * 3 + 1) delay when check value expire.
private static readonly long DefaultServersPingTimeoutTicks = Stopwatch.Frequency * ((long)Constants.Periods.DefaultServersPingInterval.TotalSeconds * 3 + 1);

private static readonly Tuple<string, long> DefaultServersTagContext = new Tuple<string, long>(string.Empty, 0);

private readonly IReadOnlyDictionary<byte, StrongBox<WeakReference<IServiceConnection>>> _partitionedCache;

private readonly BackOffPolicy _backOffPolicy = new BackOffPolicy();

private readonly object _lock = new object();
Expand Down Expand Up @@ -155,6 +162,8 @@ protected ServiceConnectionContainerBase(IServiceConnectionFactory serviceConnec
}

_serversPing = new CustomizedPingTimer(Logger, Constants.CustomizedPingTimer.Servers, WriteServersPingAsync, Constants.Periods.DefaultServersPingInterval, Constants.Periods.DefaultServersPingInterval);

_partitionedCache = Enumerable.Range(0, 256).ToDictionary(i => (byte)i, i => new StrongBox<WeakReference<IServiceConnection>>(new WeakReference<IServiceConnection>(null)));
}

public event Action<StatusChange> ConnectionStatusChanged;
Expand Down Expand Up @@ -216,7 +225,7 @@ public void HandleAck(AckMessage ackMessage)

public virtual Task WriteAsync(ServiceMessage serviceMessage)
{
return WriteToScopedOrRandomAvailableConnection(serviceMessage);
return WriteMessageAsync(serviceMessage);
}

public async Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage, CancellationToken cancellationToken = default)
Expand All @@ -233,7 +242,7 @@ public async Task<bool> WriteAckableMessageAsync(ServiceMessage serviceMessage,
// whereas ackable ones complete upon full roundtrip of the message and the ack (or timeout).
// Therefore sending them over different connections creates a possibility for processing them out of original order.
// By sending both message types over the same connection we ensure that they are sent (and processed) in their original order.
await WriteToScopedOrRandomAvailableConnection(serviceMessage);
await WriteMessageAsync(serviceMessage);

var status = await task;
return AckHandler.HandleAckStatus(ackableMessage, status);
Expand Down Expand Up @@ -414,6 +423,115 @@ protected async Task RemoveConnectionAsync(IServiceConnection c, GracefulShutdow
Log.TimeoutWaitingForFinAck(Logger, retry);
}

private async Task WriteMessageAsync(ServiceMessage serviceMessage)
{
var connection = SelectConnection(serviceMessage);

var retry = 0;
var maxRetry = MessageWriteMaxRetry;
var delay = MessageWriteRetryDelay;
while (true)
{
try
{
await connection.WriteAsync(serviceMessage);
return;
}
catch (ServiceConnectionNotActiveException)
{
// enter the re-select logic
retry++;
if (retry == maxRetry)
{
throw;
}

await Task.Delay(delay);
connection = SelectConnection(serviceMessage);
}
}
}

private IServiceConnection SelectConnection(ServiceMessage message)
{
IServiceConnection connection = null;
if (ClientConnectionScope.IsScopeEstablished)
{
// see if the execution context already has the connection stored for this container
var containers = ClientConnectionScope.OutboundServiceConnections;
if (!(containers.TryGetValue(Endpoint.UniqueIndex, out var connectionWeakReference)
&& connectionWeakReference.TryGetTarget(out connection)
&& IsActiveConnection(connection)))
{
connection = GetRandomActiveConnection();
ClientConnectionScope.OutboundServiceConnections[Endpoint.UniqueIndex] = new WeakReference<IServiceConnection>(connection);
}
}
else
{
// if it is not in scope
// if message is partitionable, use the container's partition cache, otherwise use a random connection
if (message is IPartitionableMessage partitionable)
{
var box = _partitionedCache[partitionable.PartitionKey];
if (!box.Value.TryGetTarget(out connection) || !IsActiveConnection(connection))
{
lock (box)
{
if (!box.Value.TryGetTarget(out connection) || !IsActiveConnection(connection))
{
connection = GetRandomActiveConnection();
box.Value.SetTarget(connection);
}
}
}
}
else
{
connection = GetRandomActiveConnection();
}
}

if (connection == null)
{
throw new ServiceConnectionNotActiveException();
}

return connection;
}

private bool IsActiveConnection(IServiceConnection connection)
{
return connection != null && connection.Status == ServiceConnectionStatus.Connected;
}

private IServiceConnection GetRandomActiveConnection()
{
var currentConnections = ServiceConnections;

// go through all the connections, it can be useful when one of the remote service instances is down
var count = currentConnections.Count;
var initial = StaticRandom.Next(-count, count);
var maxRetry = count;
var retry = 0;
var index = (initial & int.MaxValue) % count;
var direction = initial > 0 ? 1 : count - 1;

while (retry < maxRetry)
{
var connection = currentConnections[index];
if (IsActiveConnection(connection))
{
return connection;
}

retry++;
index = (index + direction) % count;
}

return null;
}

private async Task RestartFixedServiceConnectionCoreAsync(int index)
{
if (_terminated)
Expand Down Expand Up @@ -481,81 +599,6 @@ private void OnConnectionStatusChanged(StatusChange obj)
}
}

private async Task WriteToScopedOrRandomAvailableConnection(ServiceMessage serviceMessage)
{
// ServiceConnections can change the collection underneath so we make a local copy and pass it along
var currentConnections = ServiceConnections;

if (ClientConnectionScope.IsScopeEstablished)
{
// see if the execution context already has the connection stored for this container
var containers = ClientConnectionScope.OutboundServiceConnections;
Debug.Assert(containers != null);
containers.TryGetValue(Endpoint.UniqueIndex, out var connectionWeakReference);
IServiceConnection connection = null;
connectionWeakReference?.TryGetTarget(out connection);

var connectionUsed = await WriteWithRetry(serviceMessage, connection, currentConnections);

// Todo:
// There is currently no synchronization when persisting selected connection in ClientConnectionScope.
// This is only a concern when there are concurrent writes involved and when one of the following is true:
// - we need to change the selected connection (e.g. the currently persisted connection status is bad)
// - we need to make the initial connection selection (e.g. no secondary connection in async local yet)
// This lack of synchronization can lead to using multiple connections and cause out of order messages.

// Try to persist the connection choice for the subsequent calls within the same async flow
if (connectionUsed != connection)
{
ClientConnectionScope.OutboundServiceConnections[Endpoint.UniqueIndex] = new WeakReference<IServiceConnection>(connectionUsed);
}
}
else
{
await WriteWithRetry(serviceMessage, null, currentConnections);
}
}

private async Task<IServiceConnection> WriteWithRetry(ServiceMessage serviceMessage, IServiceConnection connection, List<IServiceConnection> currentConnections)
{
// go through all the connections, it can be useful when one of the remote service instances is down
var count = currentConnections.Count;
var initial = StaticRandom.Next(-count, count);
var maxRetry = count;
var retry = 0;
var index = (initial & int.MaxValue) % count;
var direction = initial > 0 ? 1 : count - 1;

// ensure a full sweep starting with the connection flowed with the async context
while (retry <= maxRetry)
{
if (connection != null && connection.Status == ServiceConnectionStatus.Connected)
{
try
{
// still possible the connection is not valid
await connection.WriteAsync(serviceMessage);
return connection;
}
catch (ServiceConnectionNotActiveException)
{
if (retry == maxRetry - 1)
{
throw;
}
}
}

// try current index instead
connection = currentConnections[index];

retry++;
index = (index + direction) % count;
}

throw new ServiceConnectionNotActiveException();
}

private IEnumerable<IServiceConnection> CreateFixedServiceConnection(int count)
{
for (int i = 0; i < count; i++)
Expand Down
10 changes: 7 additions & 3 deletions src/Microsoft.Azure.SignalR.Protocols/CheckWithAckMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ protected CheckWithAckMessage(int ackId, ulong? tracingId)
/// <summary>
/// A waiting for ack check-user-in-group message.
/// </summary>
public class CheckUserInGroupWithAckMessage : CheckWithAckMessage
public class CheckUserInGroupWithAckMessage : CheckWithAckMessage, IPartitionableMessage
{
/// <summary>
/// Gets or sets the user Id.
Expand All @@ -39,6 +39,7 @@ public class CheckUserInGroupWithAckMessage : CheckWithAckMessage
/// Gets or sets the group name.
/// </summary>
public string GroupName { get; set; }
public byte PartitionKey => GeneratePartitionKey(GroupName);

/// <summary>
/// Initializes a new instance of the <see cref="CheckUserInGroupWithAckMessage"/> class.
Expand All @@ -57,13 +58,15 @@ public CheckUserInGroupWithAckMessage(string userId, string groupName, int ackId
/// <summary>
/// A waiting for ack check-any-connection-in-group message.
/// </summary>
public class CheckGroupExistenceWithAckMessage : CheckWithAckMessage
public class CheckGroupExistenceWithAckMessage : CheckWithAckMessage, IPartitionableMessage
{
/// <summary>
/// Gets or sets the group name.
/// </summary>
public string GroupName { get; set; }

public byte PartitionKey => GeneratePartitionKey(GroupName);

/// <summary>
/// Initializes a new instance of the <see cref="CheckGroupExistenceWithAckMessage"/> class.
/// </summary>
Expand Down Expand Up @@ -101,12 +104,13 @@ public CheckConnectionExistenceWithAckMessage(string connectionId, int ackId = 0
/// <summary>
/// A waiting for ack check-user-existence message.
/// </summary>
public class CheckUserExistenceWithAckMessage : CheckWithAckMessage
public class CheckUserExistenceWithAckMessage : CheckWithAckMessage, IPartitionableMessage
{
/// <summary>
/// Gets or sets the user Id.
/// </summary>
public string UserId { get; set; }
public byte PartitionKey => GeneratePartitionKey(UserId);

/// <summary>
/// Initializes a new instance of the <see cref="CheckUserExistenceWithAckMessage"/> class.
Expand Down
Loading

0 comments on commit 4d7792b

Please sign in to comment.