Skip to content

Commit

Permalink
remove IConnectionMigrationFeature on normal close
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan committed Oct 17, 2024
1 parent 5caea7f commit 26ae6c3
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ protected override Task OnClientConnectedAsync(OpenConnectionMessage message)
var connection = _clientConnectionFactory.CreateConnection(message, ConfigureContext) as ClientConnectionContext;
connection.ServiceConnection = this;

connection.Features.Set<IConnectionMigrationFeature>(null);
if (message.Headers.TryGetValue(Constants.AsrsMigrateFrom, out var from))
{
connection.Features.Set<IConnectionMigrationFeature>(new ConnectionMigrationFeature(from, ServerId));
Expand Down Expand Up @@ -184,6 +185,8 @@ protected override Task OnClientDisconnectedAsync(CloseConnectionMessage message
{
if (_clientConnectionManager.TryRemoveClientConnection(message.ConnectionId, out var c) && c is ClientConnectionContext connection)
{
connection.Features.Set<IConnectionMigrationFeature>(null);

if (message.Headers.TryGetValue(Constants.AsrsMigrateTo, out var to))
{
connection.AbortOnClose = false;
Expand Down
44 changes: 42 additions & 2 deletions test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,45 @@ public async Task TestCloseConnectionMessageWithMigrateOut()
await connection.StopAsync();
}

[Fact]
public async Task TestMigrateInConnectionAndNormalClose()
{
var clientConnectionFactory = new TestClientConnectionFactory();
var clientInvocationManager = new DefaultClientInvocationManager();
var connection = CreateServiceConnection(clientConnectionFactory: clientConnectionFactory, clientInvocationManager: clientInvocationManager);
_ = connection.StartAsync();
await connection.ConnectionInitializedTask.OrTimeout();

var openConnectionMessage = new OpenConnectionMessage("foo", Array.Empty<Claim>()) { Protocol = "json" };
openConnectionMessage.Headers.Add(Constants.AsrsMigrateFrom, "another-server");
_ = connection.WriteFromServiceAsync(openConnectionMessage);
await connection.ClientConnectedTask.OrTimeout();

Assert.Equal(1, clientConnectionFactory.Connections.Count);
var clientConnection = clientConnectionFactory.Connections[0];
var feature = clientConnection.Features.Get<IConnectionMigrationFeature>();
Assert.NotNull(feature);
Assert.Equal("another-server", feature.MigrateFrom);

// write a handshake response
var message = new SignalRProtocol.HandshakeResponseMessage("");
SignalRProtocol.HandshakeProtocol.WriteResponseMessage(message, clientConnection.Transport.Output);
await clientConnection.Transport.Output.FlushAsync();

// signalr handshake response should be skipped.
await Assert.ThrowsAsync<TimeoutException>(async () => await connection.ExpectSignalRMessage(SignalRProtocol.HandshakeResponseMessage.Empty).OrTimeout(1000));

// write close connection message
await connection.WriteFromServiceAsync(new CloseConnectionMessage(clientConnection.ConnectionId));

// wait until app task completed.
await clientConnection.LifetimeTask;

await connection.StopAsync();

Assert.Null(clientConnection.Features.Get<IConnectionMigrationFeature>());
}

[Fact]
public async Task TestCloseConnectionMessage()
{
Expand All @@ -143,7 +182,6 @@ public async Task TestCloseConnectionMessage()
await connection.WriteFromServiceAsync(new CloseConnectionMessage(clientConnection.ConnectionId));

// wait until app task completed.
await Assert.ThrowsAsync<TimeoutException>(async () => await clientConnection.LifetimeTask.OrTimeout(1000));
await clientConnection.LifetimeTask;

await connection.ExpectSignalRMessage(SignalRProtocol.HandshakeResponseMessage.Empty).OrTimeout(1000);
Expand Down Expand Up @@ -516,6 +554,7 @@ public async Task ExpectSignalRMessage<T>(T message, string connectionId = null)
Assert.IsType<T>(actual);
}
_payload = payload;
throw new Exception(_payload.ToString());
}

public void CompleteWriteFromService()
Expand Down Expand Up @@ -561,7 +600,8 @@ private async Task<ReadOnlySequence<byte>> GetPayloadAsync(string connectionId =
{
Assert.Equal(connectionId, dataMessage.ConnectionId);
}
Reader.AdvanceTo(buffer.Start);
Reader.AdvanceTo(buffer.Start, buffer.End);

return dataMessage.Payload;
}
else
Expand Down

0 comments on commit 26ae6c3

Please sign in to comment.