Skip to content

Commit

Permalink
Add an auth short timeout
Browse files Browse the repository at this point in the history
Add a polling for next request short timeout
  • Loading branch information
nathanwoctopusdeploy committed Nov 16, 2023
1 parent 0ddba0f commit fccce9d
Show file tree
Hide file tree
Showing 12 changed files with 154 additions and 35 deletions.
15 changes: 12 additions & 3 deletions source/Halibut.Tests/HalibutTimeoutsAndLimitsForTestsBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ namespace Halibut.Tests
public class HalibutTimeoutsAndLimitsForTestsBuilder
{
public static readonly TimeSpan HalfTheTcpReceiveTimeout = TimeSpan.FromSeconds(22.5);
static readonly TimeSpan PollingQueueWaitTimeout = TimeSpan.FromSeconds(20);
static readonly TimeSpan ShortTimeout = TimeSpan.FromSeconds(15);

public HalibutTimeoutsAndLimits Build()
{
Expand All @@ -22,10 +24,17 @@ public HalibutTimeoutsAndLimits Build()
TcpClientSendTimeout = HalfTheTcpReceiveTimeout + HalfTheTcpReceiveTimeout,
TcpClientReceiveTimeout = HalfTheTcpReceiveTimeout + HalfTheTcpReceiveTimeout,

TcpClientHeartbeatSendTimeout = TimeSpan.FromSeconds(15),
TcpClientHeartbeatReceiveTimeout = TimeSpan.FromSeconds(15),
TcpClientHeartbeatSendTimeout = ShortTimeout,
TcpClientHeartbeatReceiveTimeout = ShortTimeout,

TcpClientAuthenticationSendTimeout = ShortTimeout,
TcpClientAuthenticationReceiveTimeout = ShortTimeout,

TcpClientPollingForNextRequestSendTimeout = ShortTimeout,
TcpClientPollingForNextRequestReceiveTimeout = PollingQueueWaitTimeout + ShortTimeout,

TcpClientConnectTimeout = TimeSpan.FromSeconds(20),
PollingQueueWaitTimeout = TimeSpan.FromSeconds(20)
PollingQueueWaitTimeout = PollingQueueWaitTimeout
};
}
}
Expand Down
12 changes: 6 additions & 6 deletions source/Halibut.Tests/Support/SerilogLoggerBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,18 @@ public ILogger Build()
var testHash = CurrentTestHash();
var logger = Logger.ForContext("TestHash", testHash);

if (!HasLoggedTestHash.Contains(testName))
{
HasLoggedTestHash.Add(testName);
logger.Information($"Test: {TestContext.CurrentContext.Test.Name} has hash {testHash}");
}

if (traceFileLogger != null)
{
TraceLoggers.AddOrUpdate(testName, traceFileLogger, (_, _) => throw new Exception("This should never be updated. If it is, it means that a test is being run multiple times in a single test run"));
traceFileLogger.SetTestHash(testHash);
}

if (!HasLoggedTestHash.Contains(testName))
{
HasLoggedTestHash.Add(testName);
logger.Information($"Test: {TestContext.CurrentContext.Test.Name} has hash {testHash}");
}

return logger;
}

Expand Down
14 changes: 14 additions & 0 deletions source/Halibut.Tests/Transport/Protocol/ProtocolFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,20 @@ public async Task<T> ReceiveAsync<T>(CancellationToken cancellationToken)
return (T)(nextReadQueue.Count > 0 ? nextReadQueue.Dequeue() : default(T));
}

public async Task WithTimeout(MessageExchangeStreamTimeout timeout, Func<Task> func)
{
output.AppendLine("|-- Set Timeout " + timeout);

await func();
}

public async Task<T> WithTimeout<T>(MessageExchangeStreamTimeout timeout, Func<Task<T>> func)
{
output.AppendLine("|-- Set Timeout " + timeout);

return await func();
}

public override string ToString()
{
return output.ToString();
Expand Down
5 changes: 5 additions & 0 deletions source/Halibut/Diagnostics/HalibutTimeoutsAndLimits.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ public HalibutTimeoutsAndLimits() { }
public TimeSpan TcpClientHeartbeatSendTimeout { get; set; } = TimeSpan.FromSeconds(60);
public TimeSpan TcpClientHeartbeatReceiveTimeout { get; set; } = TimeSpan.FromSeconds(60);

public TimeSpan TcpClientAuthenticationSendTimeout { get; set; } = TimeSpan.FromSeconds(60);
public TimeSpan TcpClientAuthenticationReceiveTimeout { get; set; } = TimeSpan.FromSeconds(60);
public TimeSpan TcpClientPollingForNextRequestSendTimeout { get; set; } = TimeSpan.FromSeconds(60);
public TimeSpan TcpClientPollingForNextRequestReceiveTimeout { get; set; } = TimeSpan.FromSeconds(30) + TimeSpan.FromSeconds(60);

/// <summary>
/// Amount of time to wait for a successful TCP or WSS connection
/// </summary>
Expand Down
3 changes: 3 additions & 0 deletions source/Halibut/Transport/Protocol/IMessageExchangeStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,8 @@ public interface IMessageExchangeStream
Task SendAsync<T>(T message, CancellationToken cancellationToken);

Task<T> ReceiveAsync<T>(CancellationToken cancellationToken);

Task WithTimeout(MessageExchangeStreamTimeout timeout, Func<Task> func);
Task<T> WithTimeout<T>(MessageExchangeStreamTimeout timeout, Func<Task<T>> func);
}
}
4 changes: 3 additions & 1 deletion source/Halibut/Transport/Protocol/MessageExchangeProtocol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ public async Task ExchangeAsSubscriberAsync(Uri subscriptionId, Func<RequestMess

static async Task ReceiveAndProcessRequestAsync(IMessageExchangeStream stream, Func<RequestMessage, Task<ResponseMessage>> incomingRequestProcessor, CancellationToken cancellationToken)
{
var request = await stream.ReceiveAsync<RequestMessage>(cancellationToken);
var request = await stream.WithTimeout(
MessageExchangeStreamTimeout.PollingForNextRequestShortTimeout,
async () => await stream.ReceiveAsync<RequestMessage>(cancellationToken));

if (request != null)
{
Expand Down
54 changes: 37 additions & 17 deletions source/Halibut/Transport/Protocol/MessageExchangeStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,14 @@ public MessageExchangeStream(Stream stream, IMessageSerializer serializer, Halib

public async Task IdentifyAsClientAsync(CancellationToken cancellationToken)
{
log.Write(EventType.Diagnostic, "Identifying as a client");
await SendIdentityMessageAsync($"{MxClient} {currentVersion}", cancellationToken);
await ExpectServerIdentityAsync(cancellationToken);
await WithTimeout(
MessageExchangeStreamTimeout.AuthenticationShortTimeout,
async () =>
{
log.Write(EventType.Diagnostic, "Identifying as a client");
await SendIdentityMessageAsync($"{MxClient} {currentVersion}", cancellationToken);
await ExpectServerIdentityAsync(cancellationToken);
});
}

async Task SendControlMessageAsync(string message, CancellationToken cancellationToken)
Expand Down Expand Up @@ -85,15 +90,20 @@ await WithTimeout(

public async Task<bool> ExpectNextOrEndAsync(CancellationToken cancellationToken)
{
var line = await controlMessageReader.ReadUntilNonEmptyControlMessageAsync(stream, cancellationToken);

return line switch
{
Next => true,
null => false,
End => false,
_ => throw new ProtocolException($"Expected {Next} or {End}, got: " + line)
};
return await WithTimeout(
MessageExchangeStreamTimeout.ControlMessageExchangeShortTimeout,
async () =>
{
var line = await controlMessageReader.ReadUntilNonEmptyControlMessageAsync(stream, cancellationToken);

return line switch
{
Next => true,
null => false,
End => false,
_ => throw new ProtocolException($"Expected {Next} or {End}, got: " + line)
};
});
}

public async Task ExpectProceedAsync(CancellationToken cancellationToken)
Expand All @@ -118,13 +128,23 @@ await WithTimeout(

public async Task IdentifyAsSubscriberAsync(string subscriptionId, CancellationToken cancellationToken)
{
await SendIdentityMessageAsync($"{MxSubscriber} {currentVersion} {subscriptionId}", cancellationToken);
await ExpectServerIdentityAsync(cancellationToken);
await WithTimeout(
MessageExchangeStreamTimeout.AuthenticationShortTimeout,
async () =>
{
await SendIdentityMessageAsync($"{MxSubscriber} {currentVersion} {subscriptionId}", cancellationToken);
await ExpectServerIdentityAsync(cancellationToken);
});
}

public async Task IdentifyAsServerAsync(CancellationToken cancellationToken)
{
await SendIdentityMessageAsync($"{MxServer} {currentVersion}", cancellationToken);
await WithTimeout(
MessageExchangeStreamTimeout.AuthenticationShortTimeout,
async () =>
{
await SendIdentityMessageAsync($"{MxServer} {currentVersion}", cancellationToken);
});
}

public async Task<RemoteIdentity> ReadRemoteIdentityAsync(CancellationToken cancellationToken)
Expand Down Expand Up @@ -187,12 +207,12 @@ public async Task<T> ReceiveAsync<T>(CancellationToken cancellationToken)
return result;
}

async Task WithTimeout(MessageExchangeStreamTimeout timeout, Func<Task> func)
public async Task WithTimeout(MessageExchangeStreamTimeout timeout, Func<Task> func)
{
await stream.WithTimeout(halibutTimeoutsAndLimits, timeout, func);
}

async Task<T> WithTimeout<T>(MessageExchangeStreamTimeout timeout, Func<Task<T>> func)
public async Task<T> WithTimeout<T>(MessageExchangeStreamTimeout timeout, Func<Task<T>> func)
{
return await stream.WithTimeout(halibutTimeoutsAndLimits, timeout, func);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ namespace Halibut.Transport.Protocol
public enum MessageExchangeStreamTimeout
{
NormalTimeout,
ControlMessageExchangeShortTimeout
ControlMessageExchangeShortTimeout,
AuthenticationShortTimeout,
PollingForNextRequestShortTimeout
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ public static void SetReadAndWriteTimeouts(this Stream stream, MessageExchangeSt
stream.WriteTimeout = (int)halibutTimeoutsAndLimits.TcpClientHeartbeatSendTimeout.TotalMilliseconds;
stream.ReadTimeout = (int)halibutTimeoutsAndLimits.TcpClientHeartbeatReceiveTimeout.TotalMilliseconds;
break;
case MessageExchangeStreamTimeout.AuthenticationShortTimeout:
stream.WriteTimeout = (int)halibutTimeoutsAndLimits.TcpClientAuthenticationSendTimeout.TotalMilliseconds;
stream.ReadTimeout = (int)halibutTimeoutsAndLimits.TcpClientAuthenticationReceiveTimeout.TotalMilliseconds;
break;
case MessageExchangeStreamTimeout.PollingForNextRequestShortTimeout:
stream.WriteTimeout = (int)halibutTimeoutsAndLimits.TcpClientPollingForNextRequestSendTimeout.TotalMilliseconds;
stream.ReadTimeout = (int)halibutTimeoutsAndLimits.TcpClientPollingForNextRequestReceiveTimeout.TotalMilliseconds;
break;
default:
throw new ArgumentOutOfRangeException(nameof(timeout), timeout, null);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using System;
using System.IO;
using System.Net.Sockets;
using System.Threading.Tasks;
using Halibut.Diagnostics;

namespace Halibut.Transport.Protocol
{
public static class TcpClientTimeoutExtensionMethods
{
public static async Task WithTimeout(this TcpClient stream, HalibutTimeoutsAndLimits halibutTimeoutsAndLimits, MessageExchangeStreamTimeout timeout, Func<Task> func)
{
var currentReadTimeout = stream.Client.ReceiveTimeout;
var currentWriteTimeout = stream.Client.SendTimeout;

try
{
stream.SetReadAndWriteTimeouts(timeout, halibutTimeoutsAndLimits);
await func();
}
finally
{
stream.ReceiveTimeout = currentReadTimeout;
stream.SendTimeout = currentWriteTimeout;
}
}

public static void SetReadAndWriteTimeouts(this TcpClient stream, MessageExchangeStreamTimeout timeout, HalibutTimeoutsAndLimits halibutTimeoutsAndLimits)
{
switch (timeout)
{
case MessageExchangeStreamTimeout.NormalTimeout:
stream.Client.SendTimeout = (int)halibutTimeoutsAndLimits.TcpClientSendTimeout.TotalMilliseconds;
stream.Client.ReceiveTimeout = (int)halibutTimeoutsAndLimits.TcpClientReceiveTimeout.TotalMilliseconds;
break;
case MessageExchangeStreamTimeout.ControlMessageExchangeShortTimeout:
stream.Client.SendTimeout = (int)halibutTimeoutsAndLimits.TcpClientHeartbeatSendTimeout.TotalMilliseconds;
stream.Client.ReceiveTimeout = (int)halibutTimeoutsAndLimits.TcpClientHeartbeatReceiveTimeout.TotalMilliseconds;
break;
case MessageExchangeStreamTimeout.AuthenticationShortTimeout:
stream.Client.SendTimeout = (int)halibutTimeoutsAndLimits.TcpClientAuthenticationSendTimeout.TotalMilliseconds;
stream.Client.ReceiveTimeout = (int)halibutTimeoutsAndLimits.TcpClientAuthenticationReceiveTimeout.TotalMilliseconds;
break;
case MessageExchangeStreamTimeout.PollingForNextRequestShortTimeout:
stream.Client.SendTimeout = (int)halibutTimeoutsAndLimits.TcpClientPollingForNextRequestSendTimeout.TotalMilliseconds;
stream.Client.ReceiveTimeout = (int)halibutTimeoutsAndLimits.TcpClientPollingForNextRequestReceiveTimeout.TotalMilliseconds;
break;
default:
throw new ArgumentOutOfRangeException(nameof(timeout), timeout, null);
}
}
}
}
2 changes: 1 addition & 1 deletion source/Halibut/Transport/SecureListener.cs
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ async Task ExecuteRequest(TcpClient client)
finally
{
if (!connectionAuthorizedAndObserved)
{
{
connectionsObserver.ConnectionAccepted(false);
}

Expand Down
15 changes: 9 additions & 6 deletions source/Halibut/Transport/TcpConnectionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,19 @@ public async Task<IConnection> EstablishNewConnectionAsync(ExchangeProtocolBuild

log.Write(EventType.SecurityNegotiation, "Performing TLS handshake");

await client.WithTimeout(halibutTimeoutsAndLimits, MessageExchangeStreamTimeout.AuthenticationShortTimeout, async () =>
{
#if NETFRAMEWORK
// TODO: ASYNC ME UP!
// AuthenticateAsClientAsync in .NET 4.8 does not support cancellation tokens. So `cancellationToken` is not respected here.
await ssl.AuthenticateAsClientAsync(serviceEndpoint.BaseUri.Host, new X509Certificate2Collection(clientCertificate), SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false);
// TODO: ASYNC ME UP!
// AuthenticateAsClientAsync in .NET 4.8 does not support cancellation tokens. So `cancellationToken` is not respected here.
await ssl.AuthenticateAsClientAsync(serviceEndpoint.BaseUri.Host, new X509Certificate2Collection(clientCertificate), SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12, false);
#else
await ssl.AuthenticateAsClientEnforcingTimeout(serviceEndpoint, new X509Certificate2Collection(clientCertificate), cancellationToken);
await ssl.AuthenticateAsClientEnforcingTimeout(serviceEndpoint, new X509Certificate2Collection(clientCertificate), cancellationToken);
#endif

await ssl.WriteAsync(MxLine, 0, MxLine.Length, cancellationToken);
await ssl.FlushAsync(cancellationToken);
await ssl.WriteAsync(MxLine, 0, MxLine.Length, cancellationToken);
await ssl.FlushAsync(cancellationToken);
});

log.Write(EventType.Security, "Secure connection established. Server at {0} identified by thumbprint: {1}, using protocol {2}", client.Client.RemoteEndPoint, serviceEndpoint.RemoteThumbprint, ssl.SslProtocol.ToString());

Expand Down

0 comments on commit fccce9d

Please sign in to comment.