Skip to content

Commit

Permalink
Add more tests and minor fix“
Browse files Browse the repository at this point in the history
  • Loading branch information
vicancy committed Aug 16, 2024
1 parent 48484c2 commit 876ac8d
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ public async Task StartAsync(string target = null)
_ = UpdateAzureIdentityAsync(key, syncTimer);
}
await ProcessIncomingAsync(connection);

// mark the status as Disconnected so that no one will write to this connection anymore
Status = ServiceConnectionStatus.Disconnected;
}
finally
{
Expand All @@ -195,10 +198,7 @@ public async Task StartAsync(string target = null)
finally
{
// wait until all the connections are cleaned up to close the outgoing pipe
// mark the status as Disconnected so that no one will write to this connection anymore
// Don't allow write anymore when the connection is disconnected
Status = ServiceConnectionStatus.Disconnected;

await _writeLock.WaitAsync();
try
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ private IServiceConnection SelectConnection(ServiceMessage message)
var containers = ClientConnectionScope.OutboundServiceConnections;
if (!(containers.TryGetValue(Endpoint.UniqueIndex, out var connectionWeakReference)
&& connectionWeakReference.TryGetTarget(out connection)
&& connection != null))
&& IsActiveConnection(connection)))
{
connection = GetRandomActiveConnection();
ClientConnectionScope.OutboundServiceConnections[Endpoint.UniqueIndex] = new WeakReference<IServiceConnection>(connection);
Expand All @@ -476,13 +476,13 @@ private IServiceConnection SelectConnection(ServiceMessage message)
return new WeakReference<IServiceConnection>(connection);
}, (_, reference) =>
{
if (reference.TryGetTarget(out connection) && connection != null)
if (reference.TryGetTarget(out connection) && IsActiveConnection(connection))
{
return reference;
}
lock (reference)
{
if (reference.TryGetTarget(out connection) && connection != null)
if (reference.TryGetTarget(out connection) && IsActiveConnection(connection))
{
return reference;
}
Expand All @@ -508,6 +508,11 @@ private IServiceConnection SelectConnection(ServiceMessage message)
return connection;
}

private bool IsActiveConnection(IServiceConnection connection)
{
return connection != null && connection.Status == ServiceConnectionStatus.Connected;
}

private IServiceConnection GetRandomActiveConnection()
{
var currentConnections = ServiceConnections;
Expand All @@ -519,11 +524,11 @@ private IServiceConnection GetRandomActiveConnection()
var retry = 0;
var index = (initial & int.MaxValue) % count;
var direction = initial > 0 ? 1 : count - 1;
var connection = currentConnections[index];

while (retry < maxRetry)
{
if (connection != null && connection.Status == ServiceConnectionStatus.Connected)
var connection = currentConnections[index];
if (IsActiveConnection(connection))
{
return connection;
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.Azure.SignalR.Protocols/ConnectionMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ public ConnectionDataMessage(string connectionId, ReadOnlySequence<byte> payload
{
Payload = payload;
TracingId = tracingId;
PartitionKey = GeneratePartitionKey("c." + connectionId);
PartitionKey = GeneratePartitionKey(connectionId);
}

/// <summary>
Expand Down
14 changes: 7 additions & 7 deletions src/Microsoft.Azure.SignalR.Protocols/GroupMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public JoinGroupMessage(string connectionId, string groupName, ulong? tracingId
ConnectionId = connectionId;
GroupName = groupName;
TracingId = tracingId;
PartitionKey = GeneratePartitionKey("g." + groupName);
PartitionKey = GeneratePartitionKey(groupName);
}
}

Expand Down Expand Up @@ -73,7 +73,7 @@ public LeaveGroupMessage(string connectionId, string groupName, ulong? tracingId
ConnectionId = connectionId;
GroupName = groupName;
TracingId = tracingId;
PartitionKey = GeneratePartitionKey("g." + groupName);
PartitionKey = GeneratePartitionKey(groupName);
}
}

Expand Down Expand Up @@ -115,7 +115,7 @@ public UserJoinGroupMessage(string userId, string groupName, ulong? tracingId =
UserId = userId;
GroupName = groupName;
TracingId = tracingId;
PartitionKey = GeneratePartitionKey("g." + groupName);
PartitionKey = GeneratePartitionKey(groupName);
}
}

Expand Down Expand Up @@ -203,7 +203,7 @@ public UserJoinGroupWithAckMessage(string userId, string groupName, int ackId, i
TracingId = tracingId;
AckId = ackId;
Ttl = ttl;
PartitionKey = GeneratePartitionKey("g." + groupName);
PartitionKey = GeneratePartitionKey(groupName);
}
}

Expand Down Expand Up @@ -247,7 +247,7 @@ public UserLeaveGroupWithAckMessage(string userId, string groupName, int ackId,
GroupName = groupName;
TracingId = tracingId;
AckId = ackId;
PartitionKey = GeneratePartitionKey("g." + groupName);
PartitionKey = GeneratePartitionKey(groupName);
}
}

Expand Down Expand Up @@ -301,7 +301,7 @@ public JoinGroupWithAckMessage(string connectionId, string groupName, int ackId,
GroupName = groupName;
AckId = ackId;
TracingId = tracingId;
PartitionKey = GeneratePartitionKey("g." + groupName);
PartitionKey = GeneratePartitionKey(groupName);
}
}

Expand Down Expand Up @@ -355,7 +355,7 @@ public LeaveGroupWithAckMessage(string connectionId, string groupName, int ackId
GroupName = groupName;
AckId = ackId;
TracingId = tracingId;
PartitionKey = GeneratePartitionKey("g." + groupName);
PartitionKey = GeneratePartitionKey(groupName);
}
}
}
19 changes: 12 additions & 7 deletions src/Microsoft.Azure.SignalR.Protocols/MulticastDataMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ protected MulticastDataMessage(IDictionary<string, ReadOnlyMemory<byte>> payload
/// </summary>
public class MultiConnectionDataMessage : MulticastDataMessage, IPartitionableMessage
{
private static readonly byte Key = GeneratePartitionKey(nameof(MultiConnectionDataMessage));
/// <summary>
/// Initializes a new instance of the <see cref="MultiConnectionDataMessage"/> class.
/// </summary>
Expand All @@ -59,7 +60,7 @@ public MultiConnectionDataMessage(IReadOnlyList<string> connectionList,
IDictionary<string, ReadOnlyMemory<byte>> payloads, ulong? tracingId = null) : base(payloads, tracingId)
{
ConnectionList = connectionList;
PartitionKey = GeneratePartitionKey(nameof(MultiConnectionDataMessage));
PartitionKey = Key;
}

/// <summary>
Expand Down Expand Up @@ -91,7 +92,7 @@ public class UserDataMessage : MulticastDataMessage, IPartitionableMessage
public UserDataMessage(string userId, IDictionary<string, ReadOnlyMemory<byte>> payloads, ulong? tracingId = null) : base(payloads, tracingId)
{
UserId = userId;
PartitionKey = GeneratePartitionKey("u." + userId);
PartitionKey = GeneratePartitionKey(userId);
}
}

Expand All @@ -100,6 +101,7 @@ public UserDataMessage(string userId, IDictionary<string, ReadOnlyMemory<byte>>
/// </summary>
public class MultiUserDataMessage : MulticastDataMessage, IPartitionableMessage
{
private static readonly byte Key = GeneratePartitionKey(nameof(MultiUserDataMessage));
/// <summary>
/// Initializes a new instance of the <see cref="MultiUserDataMessage"/> class.
/// </summary>
Expand All @@ -109,7 +111,7 @@ public class MultiUserDataMessage : MulticastDataMessage, IPartitionableMessage
public MultiUserDataMessage(IReadOnlyList<string> userList, IDictionary<string, ReadOnlyMemory<byte>> payloads, ulong? tracingId = null) : base(payloads, tracingId)
{
UserList = userList;
PartitionKey = GeneratePartitionKey(nameof(MultiUserDataMessage));
PartitionKey = Key;
}

/// <summary>
Expand All @@ -125,6 +127,7 @@ public MultiUserDataMessage(IReadOnlyList<string> userList, IDictionary<string,
/// </summary>
public class BroadcastDataMessage : MulticastDataMessage, IPartitionableMessage
{
private static readonly byte Key = GeneratePartitionKey(nameof(BroadcastDataMessage));
/// <summary>
/// Gets or sets the list of excluded connection Ids.
/// </summary>
Expand All @@ -150,7 +153,7 @@ public BroadcastDataMessage(IDictionary<string, ReadOnlyMemory<byte>> payloads,
public BroadcastDataMessage(IReadOnlyList<string> excludedList, IDictionary<string, ReadOnlyMemory<byte>> payloads, ulong? tracingId = null) : base(payloads, tracingId)
{
ExcludedList = excludedList;
PartitionKey = GeneratePartitionKey(nameof(BroadcastDataMessage));
PartitionKey = Key;
}
}

Expand Down Expand Up @@ -204,7 +207,7 @@ public GroupBroadcastDataMessage(string groupName, IReadOnlyList<string> exclude
{
GroupName = groupName;
ExcludedList = excludedList;
PartitionKey = GeneratePartitionKey("g." + groupName);
PartitionKey = GeneratePartitionKey(groupName);
}
}

Expand All @@ -213,6 +216,8 @@ public GroupBroadcastDataMessage(string groupName, IReadOnlyList<string> exclude
/// </summary>
public class MultiGroupBroadcastDataMessage : MulticastDataMessage, IPartitionableMessage
{
private static readonly byte Key = GeneratePartitionKey(nameof(MultiGroupBroadcastDataMessage));

/// <summary>
/// Gets or sets the list of group names.
/// </summary>
Expand All @@ -229,7 +234,7 @@ public class MultiGroupBroadcastDataMessage : MulticastDataMessage, IPartitionab
public MultiGroupBroadcastDataMessage(IReadOnlyList<string> groupList, IDictionary<string, ReadOnlyMemory<byte>> payloads, ulong? tracingId = null) : base(payloads, tracingId)
{
GroupList = groupList;
PartitionKey = GeneratePartitionKey(nameof(MultiGroupBroadcastDataMessage));
PartitionKey = Key;
}
}

Expand All @@ -252,7 +257,7 @@ public ClientInvocationMessage(string invocationId, string connectionId, string
InvocationId = invocationId;
ConnectionId = connectionId;
CallerServerId = callerServerId;
PartitionKey = GeneratePartitionKey("c." + connectionId);
PartitionKey = GeneratePartitionKey(connectionId);
}

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.Azure.SignalR.Protocols/ServiceMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public abstract class ServiceMessage
/// </summary>
public virtual ServiceMessage Clone() => MemberwiseClone() as ServiceMessage;

public byte GeneratePartitionKey(string input)
public static byte GeneratePartitionKey(string input)
{
return (byte)(input.GetHashCode() & 0xFF);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public async Task TestServiceConnectionOffline()
var container = new StrongServiceConnectionContainer(factory, 3, 3, hubServiceEndpoint, NullLogger.Instance);

Assert.True(factory.CreatedConnections.TryGetValue(hubServiceEndpoint, out var conns));
var connections = conns.Select(x => (TestServiceConnection)x);
var connections = conns.Select(x => (TestServiceConnection)x).ToArray();

foreach (var connection in connections)
{
Expand All @@ -32,7 +32,7 @@ public async Task TestServiceConnectionOffline()
// write 100 messages.
for (var i = 0; i < 100; i++)
{
var message = new ConnectionDataMessage("bar", new byte[12]);
var message = new ConnectionDataMessage(i.ToString(), new byte[12]);
await container.WriteAsync(message);
}

Expand All @@ -43,12 +43,12 @@ public async Task TestServiceConnectionOffline()
messageCount.TryAdd(connection.ConnectionId, connection.ReceivedMessages.Count);
}

connections.First().SetStatus(ServiceConnectionStatus.Disconnected);
connections[0].SetStatus(ServiceConnectionStatus.Disconnected);

// write 100 more messages.
for (var i = 0; i < 100; i++)
{
var message = new ConnectionDataMessage("bar", new byte[12]);
var message = new ConnectionDataMessage(i.ToString(), new byte[12]);
await container.WriteAsync(message);
}

Expand All @@ -66,4 +66,47 @@ public async Task TestServiceConnectionOffline()
index++;
}
}

[Fact]
public async Task TestServiceConnectionStickyWrites()
{
var factory = new TestServiceConnectionFactory();
var hubServiceEndpoint = new HubServiceEndpoint("foo", null, new TestServiceEndpoint());

var container = new StrongServiceConnectionContainer(factory, 3, 3, hubServiceEndpoint, NullLogger.Instance);

Assert.True(factory.CreatedConnections.TryGetValue(hubServiceEndpoint, out var conns));
var connections = conns.Select(x => (TestServiceConnection)x);

foreach (var connection in connections)
{
connection.SetStatus(ServiceConnectionStatus.Connected);
}

// write 100 messages.
for (var i = 0; i < 100; i++)
{
var message = new ConnectionDataMessage(i.ToString(), new byte[12]);
await container.WriteAsync(message);
}

var messageCount = new Dictionary<string, int>();
foreach (var connection in connections)
{
Assert.NotEmpty(connection.ReceivedMessages);
messageCount.TryAdd(connection.ConnectionId, connection.ReceivedMessages.Count);
}

// write 100 messages with the same connectionIds should double the message count for each service connection
for (var i = 0; i < 100; i++)
{
var message = new ConnectionDataMessage(i.ToString(), new byte[12]);
await container.WriteAsync(message);
}

foreach (var connection in connections)
{
Assert.Equal(messageCount[connection.ConnectionId] * 2, connection.ReceivedMessages.Count);
}
}
}

0 comments on commit 876ac8d

Please sign in to comment.