Skip to content

Commit

Permalink
Add fix + test (#1035)
Browse files Browse the repository at this point in the history
Co-authored-by: Gabriel Beaudoin <[email protected]>
  • Loading branch information
GabrielBeaudoin and Gabriel Beaudoin authored Dec 10, 2023
1 parent d87b256 commit 7872e7d
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 15 deletions.
62 changes: 47 additions & 15 deletions src/core/SIP/Channels/SIPTLSChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;

Expand Down Expand Up @@ -89,29 +90,60 @@ public SIPTLSChannel(X509Certificate2 serverCertificate, IPAddress listenAddress
/// </summary>
/// <param name="streamConnection">The stream connection holding the newly accepted client socket.</param>
protected override void OnAccept(SIPStreamConnection streamConnection)
{
OnAcceptAsync(streamConnection).ConfigureAwait(false);
}

/// <summary>
/// For the TLS channel the SSL stream must be created and any authentication actions undertaken.
/// </summary>
/// <param name="streamConnection">The stream connection holding the newly accepted client socket.</param>
protected async Task OnAcceptAsync(SIPStreamConnection streamConnection)
{
NetworkStream networkStream = new NetworkStream(streamConnection.StreamSocket, true);
SslStream sslStream = new SslStream(networkStream, false);
SslStream sslStream = null;

try
{
sslStream = new SslStream(networkStream, false);
using (var cts = new CancellationTokenSource())
{
var authTask = sslStream.AuthenticateAsServerAsync(m_serverCertificate);
var timeoutTask = Task.Delay(TLS_ATTEMPT_CONNECT_TIMEOUT, cts.Token);

//await sslStream.AuthenticateAsServerAsync(m_serverCertificate).ConfigureAwait(false);
sslStream.AuthenticateAsServer(m_serverCertificate);
var resultTask = await Task.WhenAny(authTask, timeoutTask);
if (resultTask == timeoutTask)
{
logger.LogWarning("SIP TLS Channel failed to connect to remote host. The authentication handshake timed out.");
sslStream.Close();
return;
}
cts.Cancel();

logger.LogDebug($"SIP TLS Channel successfully upgraded accepted client to SSL stream for {ListeningSIPEndPoint}<-{streamConnection.RemoteSIPEndPoint}.");
logger.LogDebug($"SIP TLS Channel successfully upgraded accepted client to SSL stream for {ListeningSIPEndPoint}<-{streamConnection.RemoteSIPEndPoint}.");

//// Display the properties and settings for the authenticated stream.
////DisplaySecurityLevel(sslStream);
////DisplaySecurityServices(sslStream);
////DisplayCertificateInformation(sslStream);
////DisplayStreamProperties(sslStream);

//// Set timeouts for the read and write to 5 seconds.
//sslStream.ReadTimeout = 5000;
//sslStream.WriteTimeout = 5000;
//// Display the properties and settings for the authenticated stream.
////DisplaySecurityLevel(sslStream);
////DisplaySecurityServices(sslStream);
////DisplayCertificateInformation(sslStream);
////DisplayStreamProperties(sslStream);

streamConnection.SslStream = new SIPStreamWrapper(sslStream);
streamConnection.SslStreamBuffer = new byte[2 * SIPStreamConnection.MaxSIPTCPMessageSize];
//// Set timeouts for the read and write to 5 seconds.
//sslStream.ReadTimeout = 5000;
//sslStream.WriteTimeout = 5000;

streamConnection.SslStream = new SIPStreamWrapper(sslStream);
streamConnection.SslStreamBuffer = new byte[2 * SIPStreamConnection.MaxSIPTCPMessageSize];
}

sslStream.BeginRead(streamConnection.SslStreamBuffer, 0, SIPStreamConnection.MaxSIPTCPMessageSize, new AsyncCallback(OnReadCallback), streamConnection);
sslStream.BeginRead(streamConnection.SslStreamBuffer, 0, SIPStreamConnection.MaxSIPTCPMessageSize, new AsyncCallback(OnReadCallback), streamConnection);
}
catch(Exception ex)
{
logger.LogError(ex, "SIP TLS channel could not connect to remote host. {exception}", ex.Message);
sslStream?.Close();
}
}

/// <summary>
Expand Down
64 changes: 64 additions & 0 deletions test/integration/core/SIPTransportIntegrationTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,70 @@ public async void WebSocketLoopbackLargeSendReceiveTest()
logger.LogDebug("Test complete.");
}

[Fact]
public void TlsDoesNotGetStuckOnIncompleteTcpConnection()
{
// Arrange
logger.LogDebug("--> " + System.Reflection.MethodBase.GetCurrentMethod().Name);
logger.BeginScope(System.Reflection.MethodBase.GetCurrentMethod().Name);

ManualResetEventSlim serverReadyEvent = new ManualResetEventSlim(false);
CancellationTokenSource cancelServer = new CancellationTokenSource();
TaskCompletionSource<bool> testComplete = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

Assert.True(File.Exists(@"certs/localhost.pfx"), "The TLS transport channel test was missing the localhost.pfx certificate file.");
var serverCertificate = new X509Certificate2(@"certs/localhost.pfx", "");
serverCertificate.Verify();

var serverChannel = new SIPTLSChannel(serverCertificate, IPAddress.Loopback, 0);
serverChannel.DisableLocalTCPSocketsCheck = true;
var serverTask = Task.Run(() => { RunServer(serverChannel, cancelServer, serverReadyEvent); });

var tlsClientChannel = new SIPTLSChannel(new IPEndPoint(IPAddress.Loopback, 0));
tlsClientChannel.DisableLocalTCPSocketsCheck = true;

var tcpConnection = new TcpClient(new IPEndPoint(IPAddress.Loopback, 0));

// Act
try
{
tcpConnection.Connect(serverChannel.ListeningEndPoint);

var clientTask = Task.Run(async () =>
{
// Try to connect a TLS client
await RunClient(
tlsClientChannel,
serverChannel.GetContactURI(SIPSchemesEnum.sips, new SIPEndPoint(SIPProtocolsEnum.tls, serverChannel.ListeningEndPoint)),
testComplete,
cancelServer,
serverReadyEvent);
});

// Assert
if (!Task.WhenAny(new Task[] { serverTask, clientTask }).Wait(TRANSPORT_TEST_TIMEOUT))
{
logger.LogWarning($"Tasks timed out");
}

if (testComplete.Task.IsCompleted == false)
{
// The client did not set the completed signal. This means the delay task must have completed and hence the test failed.
testComplete.SetResult(false);
}

Assert.True(testComplete.Task.Result);

}
finally
{
tcpConnection.Close();
cancelServer.Cancel();
}
logger.LogDebug("Test complete.");

}

/// <summary>
/// Initialises a SIP transport to act as a server in single request/response exchange.
/// </summary>
Expand Down

0 comments on commit 7872e7d

Please sign in to comment.