diff --git a/source/Halibut.Tests/HalibutTimeoutsAndLimitsForTestsBuilder.cs b/source/Halibut.Tests/HalibutTimeoutsAndLimitsForTestsBuilder.cs index ac5ce63c..46850364 100644 --- a/source/Halibut.Tests/HalibutTimeoutsAndLimitsForTestsBuilder.cs +++ b/source/Halibut.Tests/HalibutTimeoutsAndLimitsForTestsBuilder.cs @@ -6,6 +6,8 @@ namespace Halibut.Tests public class HalibutTimeoutsAndLimitsForTestsBuilder { public static readonly TimeSpan HalfTheTcpReceiveTimeout = TimeSpan.FromSeconds(22.5); + static readonly TimeSpan PollingQueueWaitTimeout = TimeSpan.FromSeconds(20); + static readonly TimeSpan ShortTimeout = TimeSpan.FromSeconds(15); public HalibutTimeoutsAndLimits Build() { @@ -22,10 +24,17 @@ public HalibutTimeoutsAndLimits Build() TcpClientSendTimeout = HalfTheTcpReceiveTimeout + HalfTheTcpReceiveTimeout, TcpClientReceiveTimeout = HalfTheTcpReceiveTimeout + HalfTheTcpReceiveTimeout, - TcpClientHeartbeatSendTimeout = TimeSpan.FromSeconds(15), - TcpClientHeartbeatReceiveTimeout = TimeSpan.FromSeconds(15), + TcpClientHeartbeatSendTimeout = ShortTimeout, + TcpClientHeartbeatReceiveTimeout = ShortTimeout, + + TcpClientAuthenticationSendTimeout = ShortTimeout, + TcpClientAuthenticationReceiveTimeout = ShortTimeout, + + TcpClientPollingForNextRequestSendTimeout = ShortTimeout, + TcpClientPollingForNextRequestReceiveTimeout = PollingQueueWaitTimeout + ShortTimeout, + TcpClientConnectTimeout = TimeSpan.FromSeconds(20), - PollingQueueWaitTimeout = TimeSpan.FromSeconds(20) + PollingQueueWaitTimeout = PollingQueueWaitTimeout }; } } diff --git a/source/Halibut.Tests/Support/SerilogLoggerBuilder.cs b/source/Halibut.Tests/Support/SerilogLoggerBuilder.cs index d447c312..52db0f7e 100644 --- a/source/Halibut.Tests/Support/SerilogLoggerBuilder.cs +++ b/source/Halibut.Tests/Support/SerilogLoggerBuilder.cs @@ -57,18 +57,18 @@ public ILogger Build() var testHash = CurrentTestHash(); var logger = Logger.ForContext("TestHash", testHash); - if (!HasLoggedTestHash.Contains(testName)) - { - HasLoggedTestHash.Add(testName); - logger.Information($"Test: {TestContext.CurrentContext.Test.Name} has hash {testHash}"); - } - if (traceFileLogger != null) { TraceLoggers.AddOrUpdate(testName, traceFileLogger, (_, _) => throw new Exception("This should never be updated. If it is, it means that a test is being run multiple times in a single test run")); traceFileLogger.SetTestHash(testHash); } + if (!HasLoggedTestHash.Contains(testName)) + { + HasLoggedTestHash.Add(testName); + logger.Information($"Test: {TestContext.CurrentContext.Test.Name} has hash {testHash}"); + } + return logger; } diff --git a/source/Halibut.Tests/Timeouts/TimeoutsApplyDuringHandShake.cs b/source/Halibut.Tests/Timeouts/TimeoutsApplyDuringHandShake.cs index a0f8fae3..a51decb2 100644 --- a/source/Halibut.Tests/Timeouts/TimeoutsApplyDuringHandShake.cs +++ b/source/Halibut.Tests/Timeouts/TimeoutsApplyDuringHandShake.cs @@ -64,7 +64,7 @@ int writeNumberToPauseOn // Ie pause on the first or second write } sw.Stop(); - sw.Elapsed.Should().BeCloseTo(clientAndService.Service.TimeoutsAndLimits.TcpClientReceiveTimeout, TimeSpan.FromSeconds(15), "Since a paused connection early on should not hang forever."); + sw.Elapsed.Should().BeCloseTo(clientAndService.Service.TimeoutsAndLimits.TcpClientAuthenticationSendTimeout, TimeSpan.FromSeconds(15), "Since a paused connection early on should not hang forever."); await echo.SayHelloAsync("The pump wont be paused here so this should work."); } diff --git a/source/Halibut.Tests/Transport/Protocol/ProtocolFixture.cs b/source/Halibut.Tests/Transport/Protocol/ProtocolFixture.cs index 87c5cf6a..0b5fb9dd 100644 --- a/source/Halibut.Tests/Transport/Protocol/ProtocolFixture.cs +++ b/source/Halibut.Tests/Transport/Protocol/ProtocolFixture.cs @@ -121,22 +121,32 @@ public async Task ShouldExchangeAsSubscriber() AssertOutput(@" --> MX-SUBSCRIBE subscriptionId <-- MX-SERVER +|-- Set Timeout PollingForNextRequestShortTimeout <-- RequestMessage +|-- Revert Timeout PollingForNextRequestShortTimeout --> ResponseMessage --> NEXT <-- PROCEED +|-- Set Timeout PollingForNextRequestShortTimeout <-- RequestMessage +|-- Revert Timeout PollingForNextRequestShortTimeout --> ResponseMessage --> NEXT <-- PROCEED +|-- Set Timeout PollingForNextRequestShortTimeout <-- RequestMessage +|-- Revert Timeout PollingForNextRequestShortTimeout --> ResponseMessage --> NEXT <-- PROCEED +|-- Set Timeout PollingForNextRequestShortTimeout <-- RequestMessage +|-- Revert Timeout PollingForNextRequestShortTimeout --> NEXT <-- PROCEED +|-- Set Timeout PollingForNextRequestShortTimeout <-- RequestMessage +|-- Revert Timeout PollingForNextRequestShortTimeout --> NEXT <-- PROCEED"); } @@ -179,37 +189,57 @@ public async Task ShouldExchangeAsSubscriberWithPooling() AssertOutput(@" --> MX-SUBSCRIBE subscriptionId <-- MX-SERVER +|-- Set Timeout PollingForNextRequestShortTimeout <-- RequestMessage +|-- Revert Timeout PollingForNextRequestShortTimeout --> ResponseMessage --> NEXT <-- PROCEED +|-- Set Timeout PollingForNextRequestShortTimeout <-- RequestMessage +|-- Revert Timeout PollingForNextRequestShortTimeout --> ResponseMessage --> NEXT <-- PROCEED +|-- Set Timeout PollingForNextRequestShortTimeout <-- RequestMessage +|-- Revert Timeout PollingForNextRequestShortTimeout --> NEXT <-- PROCEED +|-- Set Timeout PollingForNextRequestShortTimeout <-- RequestMessage +|-- Revert Timeout PollingForNextRequestShortTimeout --> NEXT <-- PROCEED +|-- Set Timeout PollingForNextRequestShortTimeout <-- RequestMessage +|-- Revert Timeout PollingForNextRequestShortTimeout --> NEXT <-- PROCEED +|-- Set Timeout PollingForNextRequestShortTimeout <-- RequestMessage +|-- Revert Timeout PollingForNextRequestShortTimeout --> ResponseMessage --> NEXT <-- PROCEED +|-- Set Timeout PollingForNextRequestShortTimeout <-- RequestMessage +|-- Revert Timeout PollingForNextRequestShortTimeout --> NEXT <-- PROCEED +|-- Set Timeout PollingForNextRequestShortTimeout <-- RequestMessage +|-- Revert Timeout PollingForNextRequestShortTimeout --> NEXT <-- PROCEED +|-- Set Timeout PollingForNextRequestShortTimeout <-- RequestMessage +|-- Revert Timeout PollingForNextRequestShortTimeout --> NEXT <-- PROCEED +|-- Set Timeout PollingForNextRequestShortTimeout <-- RequestMessage +|-- Revert Timeout PollingForNextRequestShortTimeout --> NEXT <-- PROCEED"); } @@ -279,17 +309,12 @@ public void SetNumberOfReads(int reads) numberOfReads = reads; } - public void IdentifyAsClient() - { - output.AppendLine("--> MX-CLIENT"); - output.AppendLine("<-- MX-SERVER"); - } - public async Task IdentifyAsClientAsync(CancellationToken cancellationToken) { await Task.CompletedTask; - IdentifyAsClient(); + output.AppendLine("--> MX-CLIENT"); + output.AppendLine("<-- MX-SERVER"); } public async Task SendNextAsync(CancellationToken cancellationToken) @@ -371,6 +396,22 @@ public async Task ReceiveAsync(CancellationToken cancellationToken) return (T)(nextReadQueue.Count > 0 ? nextReadQueue.Dequeue() : default(T)); } + public async Task WithTimeout(MessageExchangeStreamTimeout timeout, Func func) + { + output.AppendLine("|-- Set Timeout " + timeout); + await func(); + output.AppendLine("|-- Revert Timeout " + timeout); + } + + public async Task WithTimeout(MessageExchangeStreamTimeout timeout, Func> func) + { + output.AppendLine("|-- Set Timeout " + timeout); + var response = await func(); + output.AppendLine("|-- Revert Timeout " + timeout); + + return response; + } + public override string ToString() { return output.ToString(); diff --git a/source/Halibut/Diagnostics/HalibutTimeoutsAndLimits.cs b/source/Halibut/Diagnostics/HalibutTimeoutsAndLimits.cs index 965939cf..ec43b3bf 100644 --- a/source/Halibut/Diagnostics/HalibutTimeoutsAndLimits.cs +++ b/source/Halibut/Diagnostics/HalibutTimeoutsAndLimits.cs @@ -47,21 +47,26 @@ public HalibutTimeoutsAndLimits() { } /// /// Amount of time to wait for a TCP or SslStream write to complete successfully /// - public TimeSpan TcpClientSendTimeout { get; set; } = TimeSpan.FromMinutes(10); + public TimeSpan TcpClientSendTimeout { get; set; } = TimeSpan.FromMinutes(1); /// /// Amount of time to wait for a TCP or SslStream read to complete successfully /// - public TimeSpan TcpClientReceiveTimeout { get; set; } = TimeSpan.FromMinutes(10); + public TimeSpan TcpClientReceiveTimeout { get; set; } = TimeSpan.FromMinutes(5); /// /// Amount of time a connection can stay in the pool /// - public TimeSpan TcpClientPooledConnectionTimeout { get; set; } = TimeSpan.FromMinutes(9); + public TimeSpan TcpClientPooledConnectionTimeout { get; set; } = TimeSpan.FromMinutes(4.5); public TimeSpan TcpClientHeartbeatSendTimeout { get; set; } = TimeSpan.FromSeconds(60); public TimeSpan TcpClientHeartbeatReceiveTimeout { get; set; } = TimeSpan.FromSeconds(60); + public TimeSpan TcpClientAuthenticationSendTimeout { get; set; } = TimeSpan.FromSeconds(60); + public TimeSpan TcpClientAuthenticationReceiveTimeout { get; set; } = TimeSpan.FromSeconds(60); + public TimeSpan TcpClientPollingForNextRequestSendTimeout { get; set; } = TimeSpan.FromSeconds(60); + public TimeSpan TcpClientPollingForNextRequestReceiveTimeout { get; set; } = TimeSpan.FromSeconds(60); + /// /// Amount of time to wait for a successful TCP or WSS connection /// diff --git a/source/Halibut/Transport/Protocol/IMessageExchangeStream.cs b/source/Halibut/Transport/Protocol/IMessageExchangeStream.cs index 2b6e4b12..364bc8ce 100644 --- a/source/Halibut/Transport/Protocol/IMessageExchangeStream.cs +++ b/source/Halibut/Transport/Protocol/IMessageExchangeStream.cs @@ -26,5 +26,8 @@ public interface IMessageExchangeStream Task SendAsync(T message, CancellationToken cancellationToken); Task ReceiveAsync(CancellationToken cancellationToken); + + Task WithTimeout(MessageExchangeStreamTimeout timeout, Func func); + Task WithTimeout(MessageExchangeStreamTimeout timeout, Func> func); } } \ No newline at end of file diff --git a/source/Halibut/Transport/Protocol/MessageExchangeProtocol.cs b/source/Halibut/Transport/Protocol/MessageExchangeProtocol.cs index 1a79bd68..60957702 100644 --- a/source/Halibut/Transport/Protocol/MessageExchangeProtocol.cs +++ b/source/Halibut/Transport/Protocol/MessageExchangeProtocol.cs @@ -93,7 +93,9 @@ public async Task ExchangeAsSubscriberAsync(Uri subscriptionId, Func> incomingRequestProcessor, CancellationToken cancellationToken) { - var request = await stream.ReceiveAsync(cancellationToken); + var request = await stream.WithTimeout( + MessageExchangeStreamTimeout.PollingForNextRequestShortTimeout, + async () => await stream.ReceiveAsync(cancellationToken)); if (request != null) { diff --git a/source/Halibut/Transport/Protocol/MessageExchangeStream.cs b/source/Halibut/Transport/Protocol/MessageExchangeStream.cs index f5c90674..7ad6a738 100644 --- a/source/Halibut/Transport/Protocol/MessageExchangeStream.cs +++ b/source/Halibut/Transport/Protocol/MessageExchangeStream.cs @@ -35,16 +35,22 @@ public MessageExchangeStream(Stream stream, IMessageSerializer serializer, Halib this.halibutTimeoutsAndLimits = halibutTimeoutsAndLimits; this.controlMessageReader = new ControlMessageReader(halibutTimeoutsAndLimits); this.serializer = serializer; - SetNormalTimeoutsAsync(); + + SetReadAndWriteTimeouts(MessageExchangeStreamTimeout.NormalTimeout); } static int streamCount; public async Task IdentifyAsClientAsync(CancellationToken cancellationToken) { - log.Write(EventType.Diagnostic, "Identifying as a client"); - await SendIdentityMessageAsync($"{MxClient} {currentVersion}", cancellationToken); - await ExpectServerIdentityAsync(cancellationToken); + await WithTimeout( + MessageExchangeStreamTimeout.AuthenticationShortTimeout, + async () => + { + log.Write(EventType.Diagnostic, "Identifying as a client"); + await SendIdentityMessageAsync($"{MxClient} {currentVersion}", cancellationToken); + await ExpectServerIdentityAsync(cancellationToken); + }); } async Task SendControlMessageAsync(string message, CancellationToken cancellationToken) @@ -63,64 +69,82 @@ async Task SendIdentityMessageAsync(string identityLine, CancellationToken cance public async Task SendNextAsync(CancellationToken cancellationToken) { - SetShortTimeoutsAsync(); - await SendControlMessageAsync(Next, cancellationToken); - SetNormalTimeoutsAsync(); + await WithTimeout( + MessageExchangeStreamTimeout.ControlMessageExchangeShortTimeout, + async () => await SendControlMessageAsync(Next, cancellationToken)); } public async Task SendProceedAsync(CancellationToken cancellationToken) { - await SendControlMessageAsync(Proceed, cancellationToken); + await WithTimeout( + MessageExchangeStreamTimeout.ControlMessageExchangeShortTimeout, + async () => await SendControlMessageAsync(Proceed, cancellationToken)); } public async Task SendEndAsync(CancellationToken cancellationToken) { - SetShortTimeoutsAsync(); - await SendControlMessageAsync(End, cancellationToken); - SetNormalTimeoutsAsync(); + await WithTimeout( + MessageExchangeStreamTimeout.ControlMessageExchangeShortTimeout, + async () => await SendControlMessageAsync(End, cancellationToken)); } public async Task ExpectNextOrEndAsync(CancellationToken cancellationToken) { - var line = await controlMessageReader.ReadUntilNonEmptyControlMessageAsync(stream, cancellationToken); + return await WithTimeout( + MessageExchangeStreamTimeout.ControlMessageExchangeShortTimeout, + async () => + { + var line = await controlMessageReader.ReadUntilNonEmptyControlMessageAsync(stream, cancellationToken); - return line switch - { - Next => true, - null => false, - End => false, - _ => throw new ProtocolException($"Expected {Next} or {End}, got: " + line) - }; + return line switch + { + Next => true, + null => false, + End => false, + _ => throw new ProtocolException($"Expected {Next} or {End}, got: " + line) + }; + }); } public async Task ExpectProceedAsync(CancellationToken cancellationToken) { - SetShortTimeoutsAsync(); - - var line = await controlMessageReader.ReadUntilNonEmptyControlMessageAsync(stream, cancellationToken); - - if (line == null) - { - throw new AuthenticationException($"Expected {Proceed}, got no data"); - } + await WithTimeout( + MessageExchangeStreamTimeout.ControlMessageExchangeShortTimeout, + async () => + { + var line = await controlMessageReader.ReadUntilNonEmptyControlMessageAsync(stream, cancellationToken); - if (line != Proceed) - { - throw new ProtocolException($"Expected {Proceed}, got: " + line); - } + if (line == null) + { + throw new AuthenticationException($"Expected {Proceed}, got no data"); + } - SetNormalTimeoutsAsync(); + if (line != Proceed) + { + throw new ProtocolException($"Expected {Proceed}, got: " + line); + } + }); } public async Task IdentifyAsSubscriberAsync(string subscriptionId, CancellationToken cancellationToken) { - await SendIdentityMessageAsync($"{MxSubscriber} {currentVersion} {subscriptionId}", cancellationToken); - await ExpectServerIdentityAsync(cancellationToken); + await WithTimeout( + MessageExchangeStreamTimeout.AuthenticationShortTimeout, + async () => + { + await SendIdentityMessageAsync($"{MxSubscriber} {currentVersion} {subscriptionId}", cancellationToken); + await ExpectServerIdentityAsync(cancellationToken); + }); } public async Task IdentifyAsServerAsync(CancellationToken cancellationToken) { - await SendIdentityMessageAsync($"{MxServer} {currentVersion}", cancellationToken); + await WithTimeout( + MessageExchangeStreamTimeout.AuthenticationShortTimeout, + async () => + { + await SendIdentityMessageAsync($"{MxServer} {currentVersion}", cancellationToken); + }); } public async Task ReadRemoteIdentityAsync(CancellationToken cancellationToken) @@ -182,6 +206,21 @@ public async Task ReceiveAsync(CancellationToken cancellationToken) log.Write(EventType.Diagnostic, "Received: {0}", result); return result; } + + public async Task WithTimeout(MessageExchangeStreamTimeout timeout, Func func) + { + await stream.WithTimeout(halibutTimeoutsAndLimits, timeout, func); + } + + public async Task WithTimeout(MessageExchangeStreamTimeout timeout, Func> func) + { + return await stream.WithTimeout(halibutTimeoutsAndLimits, timeout, func); + } + + void SetReadAndWriteTimeouts(MessageExchangeStreamTimeout timeout) + { + stream.SetReadAndWriteTimeouts(timeout, halibutTimeoutsAndLimits); + } static RemoteIdentityType ParseIdentityType(string identityType) { @@ -282,28 +321,5 @@ async Task WriteEachStreamAsync(IEnumerable streams, CancellationTok await stream.FlushAsync(cancellationToken); } } - - void SetNormalTimeoutsAsync() - { - // TODO - ASYNC ME UP! - // We should always be given a stream that can timeout. - if (!stream.CanTimeout) - return; - - stream.WriteTimeout = (int)this.halibutTimeoutsAndLimits.TcpClientSendTimeout.TotalMilliseconds; - stream.ReadTimeout = (int)this.halibutTimeoutsAndLimits.TcpClientReceiveTimeout.TotalMilliseconds; - } - - void SetShortTimeoutsAsync() - { - - // TODO - ASYNC ME UP! - // We should always be given a stream that can timeout. - if (!stream.CanTimeout) - return; - - stream.WriteTimeout = (int)this.halibutTimeoutsAndLimits.TcpClientHeartbeatSendTimeout.TotalMilliseconds; - stream.ReadTimeout = (int)this.halibutTimeoutsAndLimits.TcpClientHeartbeatReceiveTimeout.TotalMilliseconds; - } } } diff --git a/source/Halibut/Transport/Protocol/MessageExchangeStreamTimeout.cs b/source/Halibut/Transport/Protocol/MessageExchangeStreamTimeout.cs new file mode 100644 index 00000000..4e7ecb00 --- /dev/null +++ b/source/Halibut/Transport/Protocol/MessageExchangeStreamTimeout.cs @@ -0,0 +1,12 @@ +using System; + +namespace Halibut.Transport.Protocol +{ + public enum MessageExchangeStreamTimeout + { + NormalTimeout, + ControlMessageExchangeShortTimeout, + AuthenticationShortTimeout, + PollingForNextRequestShortTimeout + } +} \ No newline at end of file diff --git a/source/Halibut/Transport/Protocol/StreamTimeoutExtensionMethods.cs b/source/Halibut/Transport/Protocol/StreamTimeoutExtensionMethods.cs new file mode 100644 index 00000000..8b80f030 --- /dev/null +++ b/source/Halibut/Transport/Protocol/StreamTimeoutExtensionMethods.cs @@ -0,0 +1,86 @@ +using System; +using System.IO; +using System.Threading.Tasks; +using Halibut.Diagnostics; + +namespace Halibut.Transport.Protocol +{ + public static class StreamTimeoutExtensionMethods + { + public static async Task WithTimeout(this Stream stream, HalibutTimeoutsAndLimits halibutTimeoutsAndLimits, MessageExchangeStreamTimeout timeout, Func func) + { + if (!stream.CanTimeout) + { + await func(); + + return; + } + + var currentReadTimeout = stream.ReadTimeout; + var currentWriteTimeout = stream.WriteTimeout; + + try + { + stream.SetReadAndWriteTimeouts(timeout, halibutTimeoutsAndLimits); + await func(); + } + finally + { + stream.ReadTimeout = currentReadTimeout; + stream.WriteTimeout = currentWriteTimeout; + } + } + + public static async Task WithTimeout(this Stream stream, HalibutTimeoutsAndLimits halibutTimeoutsAndLimits, MessageExchangeStreamTimeout timeout, Func> func) + { + if (!stream.CanTimeout) + { + return await func(); + } + + var currentReadTimeout = stream.ReadTimeout; + var currentWriteTimeout = stream.WriteTimeout; + + try + { + stream.SetReadAndWriteTimeouts(timeout, halibutTimeoutsAndLimits); + return await func(); + } + finally + { + stream.ReadTimeout = currentReadTimeout; + stream.WriteTimeout = currentWriteTimeout; + } + } + + public static void SetReadAndWriteTimeouts(this Stream stream, MessageExchangeStreamTimeout timeout, HalibutTimeoutsAndLimits halibutTimeoutsAndLimits) + { + if (!stream.CanTimeout) + { + return; + } + + switch (timeout) + { + case MessageExchangeStreamTimeout.NormalTimeout: + stream.WriteTimeout = (int)halibutTimeoutsAndLimits.TcpClientSendTimeout.TotalMilliseconds; + stream.ReadTimeout = (int)halibutTimeoutsAndLimits.TcpClientReceiveTimeout.TotalMilliseconds; + break; + case MessageExchangeStreamTimeout.ControlMessageExchangeShortTimeout: + stream.WriteTimeout = (int)halibutTimeoutsAndLimits.TcpClientHeartbeatSendTimeout.TotalMilliseconds; + stream.ReadTimeout = (int)halibutTimeoutsAndLimits.TcpClientHeartbeatReceiveTimeout.TotalMilliseconds; + break; + case MessageExchangeStreamTimeout.AuthenticationShortTimeout: + stream.WriteTimeout = (int)halibutTimeoutsAndLimits.TcpClientAuthenticationSendTimeout.TotalMilliseconds; + stream.ReadTimeout = (int)halibutTimeoutsAndLimits.TcpClientAuthenticationReceiveTimeout.TotalMilliseconds; + break; + case MessageExchangeStreamTimeout.PollingForNextRequestShortTimeout: + stream.WriteTimeout = (int)halibutTimeoutsAndLimits.TcpClientPollingForNextRequestSendTimeout.TotalMilliseconds; + stream.ReadTimeout = (int)halibutTimeoutsAndLimits.TcpClientPollingForNextRequestReceiveTimeout.TotalMilliseconds; + break; + default: + throw new ArgumentOutOfRangeException(nameof(timeout), timeout, null); + } + } + } +} diff --git a/source/Halibut/Transport/Protocol/TcpClientTimeoutExtensionMethods.cs b/source/Halibut/Transport/Protocol/TcpClientTimeoutExtensionMethods.cs new file mode 100644 index 00000000..a0cb4365 --- /dev/null +++ b/source/Halibut/Transport/Protocol/TcpClientTimeoutExtensionMethods.cs @@ -0,0 +1,53 @@ +using System; +using System.IO; +using System.Net.Sockets; +using System.Threading.Tasks; +using Halibut.Diagnostics; + +namespace Halibut.Transport.Protocol +{ + public static class TcpClientTimeoutExtensionMethods + { + public static async Task WithTimeout(this TcpClient stream, HalibutTimeoutsAndLimits halibutTimeoutsAndLimits, MessageExchangeStreamTimeout timeout, Func func) + { + var currentReadTimeout = stream.Client.ReceiveTimeout; + var currentWriteTimeout = stream.Client.SendTimeout; + + try + { + stream.SetReadAndWriteTimeouts(timeout, halibutTimeoutsAndLimits); + await func(); + } + finally + { + stream.ReceiveTimeout = currentReadTimeout; + stream.SendTimeout = currentWriteTimeout; + } + } + + public static void SetReadAndWriteTimeouts(this TcpClient stream, MessageExchangeStreamTimeout timeout, HalibutTimeoutsAndLimits halibutTimeoutsAndLimits) + { + switch (timeout) + { + case MessageExchangeStreamTimeout.NormalTimeout: + stream.Client.SendTimeout = (int)halibutTimeoutsAndLimits.TcpClientSendTimeout.TotalMilliseconds; + stream.Client.ReceiveTimeout = (int)halibutTimeoutsAndLimits.TcpClientReceiveTimeout.TotalMilliseconds; + break; + case MessageExchangeStreamTimeout.ControlMessageExchangeShortTimeout: + stream.Client.SendTimeout = (int)halibutTimeoutsAndLimits.TcpClientHeartbeatSendTimeout.TotalMilliseconds; + stream.Client.ReceiveTimeout = (int)halibutTimeoutsAndLimits.TcpClientHeartbeatReceiveTimeout.TotalMilliseconds; + break; + case MessageExchangeStreamTimeout.AuthenticationShortTimeout: + stream.Client.SendTimeout = (int)halibutTimeoutsAndLimits.TcpClientAuthenticationSendTimeout.TotalMilliseconds; + stream.Client.ReceiveTimeout = (int)halibutTimeoutsAndLimits.TcpClientAuthenticationReceiveTimeout.TotalMilliseconds; + break; + case MessageExchangeStreamTimeout.PollingForNextRequestShortTimeout: + stream.Client.SendTimeout = (int)halibutTimeoutsAndLimits.TcpClientPollingForNextRequestSendTimeout.TotalMilliseconds; + stream.Client.ReceiveTimeout = (int)halibutTimeoutsAndLimits.TcpClientPollingForNextRequestReceiveTimeout.TotalMilliseconds; + break; + default: + throw new ArgumentOutOfRangeException(nameof(timeout), timeout, null); + } + } + } +} diff --git a/source/Halibut/Transport/SecureListener.cs b/source/Halibut/Transport/SecureListener.cs index cc943705..4d383aaf 100644 --- a/source/Halibut/Transport/SecureListener.cs +++ b/source/Halibut/Transport/SecureListener.cs @@ -276,7 +276,7 @@ async Task ExecuteRequest(TcpClient client) finally { if (!connectionAuthorizedAndObserved) - { + { connectionsObserver.ConnectionAccepted(false); } diff --git a/source/Halibut/Transport/TcpConnectionFactory.cs b/source/Halibut/Transport/TcpConnectionFactory.cs index d3e96627..d5eedd7c 100644 --- a/source/Halibut/Transport/TcpConnectionFactory.cs +++ b/source/Halibut/Transport/TcpConnectionFactory.cs @@ -43,16 +43,19 @@ public async Task EstablishNewConnectionAsync(ExchangeProtocolBuild log.Write(EventType.SecurityNegotiation, "Performing TLS handshake"); + await client.WithTimeout(halibutTimeoutsAndLimits, MessageExchangeStreamTimeout.AuthenticationShortTimeout, async () => + { #if NETFRAMEWORK - // TODO: ASYNC ME UP! - // AuthenticateAsClientAsync in .NET 4.8 does not support cancellation tokens. So `cancellationToken` is not respected here. - await ssl.AuthenticateAsClientAsync(serviceEndpoint.BaseUri.Host, new X509Certificate2Collection(clientCertificate), SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false); + // TODO: ASYNC ME UP! + // AuthenticateAsClientAsync in .NET 4.8 does not support cancellation tokens. So `cancellationToken` is not respected here. + await ssl.AuthenticateAsClientAsync(serviceEndpoint.BaseUri.Host, new X509Certificate2Collection(clientCertificate), SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false); #else - await ssl.AuthenticateAsClientEnforcingTimeout(serviceEndpoint, new X509Certificate2Collection(clientCertificate), cancellationToken); + await ssl.AuthenticateAsClientEnforcingTimeout(serviceEndpoint, new X509Certificate2Collection(clientCertificate), cancellationToken); #endif - await ssl.WriteAsync(MxLine, 0, MxLine.Length, cancellationToken); - await ssl.FlushAsync(cancellationToken); + await ssl.WriteAsync(MxLine, 0, MxLine.Length, cancellationToken); + await ssl.FlushAsync(cancellationToken); + }); log.Write(EventType.Security, "Secure connection established. Server at {0} identified by thumbprint: {1}, using protocol {2}", client.Client.RemoteEndPoint, serviceEndpoint.RemoteThumbprint, ssl.SslProtocol.ToString());