diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index c35ad7283..232de654b 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -22,6 +22,8 @@ namespace Snowflake.Data.Tests.IntegrationTests using System.Runtime.InteropServices; using System.Net.Http; using System.Security.Authentication; + using Moq.Protected; + using Moq; [TestFixture] class SFConnectionIT : SFBaseTest @@ -584,13 +586,25 @@ public void TestEnableLoginRetryOn404() [Test] - public void TestAuthenticationExceptionThrowsExceptionAndNotRetried() + public void TestNonRetryableHttpExceptionThrowsError() { - var mockRestRequester = new MockInfiniteTimeout(); + var handler = new Mock(); + handler.Protected() + .Setup>( + "SendAsync", + ItExpr.Is(req => req.RequestUri.ToString().Contains("https://authenticationexceptiontest.com/")), + ItExpr.IsAny()) + .ThrowsAsync(new HttpRequestException("", new AuthenticationException())); + + var httpClient = HttpUtil.Instance.GetHttpClient( + new HttpClientConfig(false, "fakeHost", "fakePort", "user", "password", "fakeProxyList", false, false, 7), + handler.Object); + + var mockRestRequester = new MockInfiniteTimeout(httpClient); using (var conn = new MockSnowflakeDbConnection(mockRestRequester)) { - string invalidConnectionString = "host=google.com/404;" + string invalidConnectionString = "host=authenticationexceptiontest.com;" + "account=account;user=user;password=password;"; conn.ConnectionString = invalidConnectionString; @@ -603,10 +617,7 @@ public void TestAuthenticationExceptionThrowsExceptionAndNotRetried() catch (AggregateException e) { Assert.IsInstanceOf(e.InnerException); -#if NET6_0_OR_GREATER Assert.IsInstanceOf(e.InnerException.InnerException); - Assert.IsTrue(e.InnerException.InnerException.Message.Contains("The remote certificate is invalid because of errors in the certificate chain: RevocationStatusUnknown")); -#endif } catch (Exception unexpected) { diff --git a/Snowflake.Data.Tests/Mock/MockInfiniteTimeout.cs b/Snowflake.Data.Tests/Mock/MockInfiniteTimeout.cs index 2fd5330a1..49aee556f 100644 --- a/Snowflake.Data.Tests/Mock/MockInfiniteTimeout.cs +++ b/Snowflake.Data.Tests/Mock/MockInfiniteTimeout.cs @@ -13,14 +13,23 @@ namespace Snowflake.Data.Tests.Mock class MockInfiniteTimeout : RestRequester, IMockRestRequester { - public MockInfiniteTimeout() : base(null) + HttpClient mockHttpClient; + + public MockInfiniteTimeout(HttpClient mockHttpClient = null) : base(null) { - // Does nothing + this.mockHttpClient = mockHttpClient; } public void setHttpClient(HttpClient httpClient) { - base._HttpClient = httpClient; + if (mockHttpClient != null) + { + base._HttpClient = mockHttpClient; + } + else + { + base._HttpClient = httpClient; + } } protected override async Task SendAsync(HttpRequestMessage message, diff --git a/Snowflake.Data/Core/HttpUtil.cs b/Snowflake.Data/Core/HttpUtil.cs index 7c1760ac5..d71ba1362 100755 --- a/Snowflake.Data/Core/HttpUtil.cs +++ b/Snowflake.Data/Core/HttpUtil.cs @@ -100,16 +100,16 @@ private HttpUtil() private Dictionary _HttpClients = new Dictionary(); - internal HttpClient GetHttpClient(HttpClientConfig config) + internal HttpClient GetHttpClient(HttpClientConfig config, DelegatingHandler customHandler = null) { lock (httpClientProviderLock) { - return RegisterNewHttpClientIfNecessary(config); + return RegisterNewHttpClientIfNecessary(config, customHandler); } } - private HttpClient RegisterNewHttpClientIfNecessary(HttpClientConfig config) + private HttpClient RegisterNewHttpClientIfNecessary(HttpClientConfig config, DelegatingHandler customHandler = null) { string name = config.ConfKey; if (!_HttpClients.ContainsKey(name)) @@ -117,7 +117,7 @@ private HttpClient RegisterNewHttpClientIfNecessary(HttpClientConfig config) logger.Debug("Http client not registered. Adding."); var httpClient = new HttpClient( - new RetryHandler(SetupCustomHttpHandler(config), config.DisableRetry, config.ForceRetryOn404, config.MaxHttpRetries, config.IncludeRetryReason)) + new RetryHandler(SetupCustomHttpHandler(config, customHandler), config.DisableRetry, config.ForceRetryOn404, config.MaxHttpRetries, config.IncludeRetryReason)) { Timeout = Timeout.InfiniteTimeSpan }; @@ -129,8 +129,13 @@ private HttpClient RegisterNewHttpClientIfNecessary(HttpClientConfig config) return _HttpClients[name]; } - internal HttpMessageHandler SetupCustomHttpHandler(HttpClientConfig config) + internal HttpMessageHandler SetupCustomHttpHandler(HttpClientConfig config, DelegatingHandler customHandler = null) { + if (customHandler != null) + { + return customHandler; + } + HttpMessageHandler httpHandler; try { @@ -394,9 +399,6 @@ protected override async Task SendAsync(HttpRequestMessage catch (Exception e) { lastException = e; - Exception mostInnerException = e; - while (mostInnerException.InnerException != null) mostInnerException = mostInnerException.InnerException; - if (cancellationToken.IsCancellationRequested) { logger.Debug("SF rest request timeout or explicit cancel called."); @@ -407,15 +409,21 @@ protected override async Task SendAsync(HttpRequestMessage logger.Warn("Http request timeout. Retry the request"); totalRetryTime += (int)httpTimeout.TotalSeconds; } - else if (mostInnerException is AuthenticationException) - { - logger.Error("Non-retryable error encountered: ", e); - throw; - } else { - //TODO: Should probably check to see if the error is recoverable or transient. - logger.Warn("Error occurred during request, retrying...", e); + Exception innermostException = e; + while (innermostException.InnerException != null) innermostException = innermostException.InnerException; + + if (innermostException is AuthenticationException) + { + logger.Error("Non-retryable error encountered: ", e); + throw; + } + else + { + //TODO: Should probably check to see if the error is recoverable or transient. + logger.Warn("Error occurred during request, retrying...", e); + } } }