diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index fdcd0508b..ddd73167d 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -3,6 +3,7 @@ */ using System.Data.Common; +using System.Net; using Snowflake.Data.Tests.Util; namespace Snowflake.Data.Tests.IntegrationTests @@ -520,12 +521,12 @@ public void TestDefaultLoginTimeout() } [Test] - public void TestConnectionFailFast() + public void TestConnectionFailFastForNonRetried404OnLogin() { using (var conn = new SnowflakeDbConnection()) { // Just a way to get a 404 on the login request and make sure there are no retry - string invalidConnectionString = "host=learn.microsoft.com;" + string invalidConnectionString = "host=google.com/404;" + "connection_timeout=0;account=testFailFast;user=testFailFast;password=testFailFast;"; conn.ConnectionString = invalidConnectionString; @@ -538,8 +539,12 @@ public void TestConnectionFailFast() } catch (SnowflakeDbException e) { - Assert.AreEqual(SFError.INTERNAL_ERROR.GetAttribute().errorCode, - e.ErrorCode); + SnowflakeDbExceptionAssert.HasHttpErrorCodeInExceptionChain(e, HttpStatusCode.NotFound); + SnowflakeDbExceptionAssert.HasMessageInExceptionChain(e, "404 (Not Found)"); + } + catch (Exception unexpected) + { + Assert.Fail($"Unexpected {unexpected.GetType()} exception occurred"); } Assert.AreEqual(ConnectionState.Closed, conn.State); @@ -547,11 +552,11 @@ public void TestConnectionFailFast() } [Test] - public void TestEnableRetry() + public void TestEnableLoginRetryOn404() { using (var conn = new SnowflakeDbConnection()) { - string invalidConnectionString = "host=learn.microsoft.com;" + string invalidConnectionString = "host=google.com/404;" + "connection_timeout=0;account=testFailFast;user=testFailFast;password=testFailFast;disableretry=true;forceretryon404=true"; conn.ConnectionString = invalidConnectionString; @@ -563,8 +568,12 @@ public void TestEnableRetry() } catch (SnowflakeDbException e) { - Assert.AreEqual(SFError.INTERNAL_ERROR.GetAttribute().errorCode, - e.ErrorCode); + SnowflakeDbExceptionAssert.HasErrorCode(e, SFError.INTERNAL_ERROR); + SnowflakeDbExceptionAssert.HasHttpErrorCodeInExceptionChain(e, HttpStatusCode.NotFound); + } + catch (Exception unexpected) + { + Assert.Fail($"Unexpected {unexpected.GetType()} exception occurred"); } Assert.AreEqual(ConnectionState.Closed, conn.State); @@ -1947,12 +1956,12 @@ public void TestAsyncDefaultLoginTimeout() } [Test] - public void TestAsyncConnectionFailFast() + public void TestAsyncConnectionFailFastForNonRetried404OnLogin() { using (var conn = new SnowflakeDbConnection()) { // Just a way to get a 404 on the login request and make sure there are no retry - string invalidConnectionString = "host=learn.microsoft.com;" + string invalidConnectionString = "host=google.com/404;" + "connection_timeout=0;account=testFailFast;user=testFailFast;password=testFailFast;"; conn.ConnectionString = invalidConnectionString; @@ -1967,7 +1976,12 @@ public void TestAsyncConnectionFailFast() } catch (AggregateException e) { - SnowflakeDbExceptionAssert.HasErrorCode((SnowflakeDbException)e.InnerException, SFError.INTERNAL_ERROR); + SnowflakeDbExceptionAssert.HasHttpErrorCodeInExceptionChain(e, HttpStatusCode.NotFound); + SnowflakeDbExceptionAssert.HasMessageInExceptionChain(e, "404 (Not Found)"); + } + catch (Exception unexpected) + { + Assert.Fail($"Unexpected {unexpected.GetType()} exception occurred"); } Assert.AreEqual(ConnectionState.Closed, conn.State); diff --git a/Snowflake.Data.Tests/Util/SnowflakeDbExceptionAssert.cs b/Snowflake.Data.Tests/Util/SnowflakeDbExceptionAssert.cs index b3d3e50db..62d1e24ff 100644 --- a/Snowflake.Data.Tests/Util/SnowflakeDbExceptionAssert.cs +++ b/Snowflake.Data.Tests/Util/SnowflakeDbExceptionAssert.cs @@ -1,3 +1,8 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.Http; using Snowflake.Data.Core; using Snowflake.Data.Client; using NUnit.Framework; @@ -10,5 +15,73 @@ public static void HasErrorCode(SnowflakeDbException exception, SFError sfError) { Assert.AreEqual(exception.ErrorCode, sfError.GetAttribute().errorCode); } + + public static void HasErrorCode(Exception exception, SFError sfError) + { + Assert.NotNull(exception); + switch (exception) + { + case SnowflakeDbException snowflakeDbException: + Assert.AreEqual(snowflakeDbException.ErrorCode, sfError.GetAttribute().errorCode); + break; + default: + Assert.Fail(exception.GetType() + " type is not " + typeof(SnowflakeDbException)); + break; + } + } + + public static void HasHttpErrorCodeInExceptionChain(Exception exception, HttpStatusCode expected) + { + var exceptions = CollectExceptions(exception); + Assert.AreEqual(true, + exceptions.Any(e => + { + switch (e) + { + case SnowflakeDbException se: + return se.ErrorCode == (int)expected; + case HttpRequestException he: +#if NETFRAMEWORK + return he.Message.Contains(((int)expected).ToString()); +#else + return he.StatusCode == expected; +#endif + default: + return false; + } + }), + $"Any of exceptions in the chain should have HTTP Status: {expected}"); + } + + public static void HasMessageInExceptionChain(Exception exception, string expected) + { + var exceptions = CollectExceptions(exception); + Assert.AreEqual(true, + exceptions.Any(e => e.Message.Contains(expected)), + $"Any of exceptions in the chain should contain message: {expected}"); + } + + private static List CollectExceptions(Exception exception) + { + var collected = new List(); + if (exception is null) + return collected; + switch (exception) + { + case AggregateException aggregate: + var inner = aggregate.Flatten().InnerExceptions; + // collected.AddRange(inner.OfType()); + collected.AddRange(inner); + collected.AddRange(inner + .Where(e => e.InnerException != null) + .SelectMany(e => CollectExceptions(e.InnerException))); + break; + case Exception general: + collected.AddRange(CollectExceptions(general.InnerException)); + collected.Add(general); + break; + } + return collected; + } } }