From 26ae6c371fd3680e7deaa1aa3eb4ec93bb9fe1c0 Mon Sep 17 00:00:00 2001 From: Terence Fan Date: Thu, 17 Oct 2024 12:36:40 +0800 Subject: [PATCH] remove `IConnectionMigrationFeature` on normal close --- .../ServerConnections/ServiceConnection.cs | 3 ++ .../ServiceMessageTests.cs | 44 ++++++++++++++++++- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs index a139cf943..811ddd206 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs @@ -148,6 +148,7 @@ protected override Task OnClientConnectedAsync(OpenConnectionMessage message) var connection = _clientConnectionFactory.CreateConnection(message, ConfigureContext) as ClientConnectionContext; connection.ServiceConnection = this; + connection.Features.Set(null); if (message.Headers.TryGetValue(Constants.AsrsMigrateFrom, out var from)) { connection.Features.Set(new ConnectionMigrationFeature(from, ServerId)); @@ -184,6 +185,8 @@ protected override Task OnClientDisconnectedAsync(CloseConnectionMessage message { if (_clientConnectionManager.TryRemoveClientConnection(message.ConnectionId, out var c) && c is ClientConnectionContext connection) { + connection.Features.Set(null); + if (message.Headers.TryGetValue(Constants.AsrsMigrateTo, out var to)) { connection.AbortOnClose = false; diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs index 395f9a201..f8388a121 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceMessageTests.cs @@ -117,6 +117,45 @@ public async Task TestCloseConnectionMessageWithMigrateOut() await connection.StopAsync(); } + [Fact] + public async Task TestMigrateInConnectionAndNormalClose() + { + var clientConnectionFactory = new TestClientConnectionFactory(); + var clientInvocationManager = new DefaultClientInvocationManager(); + var connection = CreateServiceConnection(clientConnectionFactory: clientConnectionFactory, clientInvocationManager: clientInvocationManager); + _ = connection.StartAsync(); + await connection.ConnectionInitializedTask.OrTimeout(); + + var openConnectionMessage = new OpenConnectionMessage("foo", Array.Empty()) { Protocol = "json" }; + openConnectionMessage.Headers.Add(Constants.AsrsMigrateFrom, "another-server"); + _ = connection.WriteFromServiceAsync(openConnectionMessage); + await connection.ClientConnectedTask.OrTimeout(); + + Assert.Equal(1, clientConnectionFactory.Connections.Count); + var clientConnection = clientConnectionFactory.Connections[0]; + var feature = clientConnection.Features.Get(); + Assert.NotNull(feature); + Assert.Equal("another-server", feature.MigrateFrom); + + // write a handshake response + var message = new SignalRProtocol.HandshakeResponseMessage(""); + SignalRProtocol.HandshakeProtocol.WriteResponseMessage(message, clientConnection.Transport.Output); + await clientConnection.Transport.Output.FlushAsync(); + + // signalr handshake response should be skipped. + await Assert.ThrowsAsync(async () => await connection.ExpectSignalRMessage(SignalRProtocol.HandshakeResponseMessage.Empty).OrTimeout(1000)); + + // write close connection message + await connection.WriteFromServiceAsync(new CloseConnectionMessage(clientConnection.ConnectionId)); + + // wait until app task completed. + await clientConnection.LifetimeTask; + + await connection.StopAsync(); + + Assert.Null(clientConnection.Features.Get()); + } + [Fact] public async Task TestCloseConnectionMessage() { @@ -143,7 +182,6 @@ public async Task TestCloseConnectionMessage() await connection.WriteFromServiceAsync(new CloseConnectionMessage(clientConnection.ConnectionId)); // wait until app task completed. - await Assert.ThrowsAsync(async () => await clientConnection.LifetimeTask.OrTimeout(1000)); await clientConnection.LifetimeTask; await connection.ExpectSignalRMessage(SignalRProtocol.HandshakeResponseMessage.Empty).OrTimeout(1000); @@ -516,6 +554,7 @@ public async Task ExpectSignalRMessage(T message, string connectionId = null) Assert.IsType(actual); } _payload = payload; + throw new Exception(_payload.ToString()); } public void CompleteWriteFromService() @@ -561,7 +600,8 @@ private async Task> GetPayloadAsync(string connectionId = { Assert.Equal(connectionId, dataMessage.ConnectionId); } - Reader.AdvanceTo(buffer.Start); + Reader.AdvanceTo(buffer.Start, buffer.End); + return dataMessage.Payload; } else