From c6684195b12feb005806fc28a1fc6d6fd8328b63 Mon Sep 17 00:00:00 2001 From: sfc-gh-ext-simba-lf Date: Fri, 9 Feb 2024 09:48:18 -0800 Subject: [PATCH 1/7] SNOW-916949: Fix Okta retry for SSO/SAML endpoints --- .../Core/Authenticator/OktaAuthenticator.cs | 156 +++++++++++++++--- Snowflake.Data/Core/HttpUtil.cs | 22 ++- Snowflake.Data/Core/RestParams.cs | 6 + Snowflake.Data/Core/Session/SFSession.cs | 6 + 4 files changed, 162 insertions(+), 28 deletions(-) diff --git a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs index 296bf518b..5df3975c4 100644 --- a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs @@ -12,6 +12,7 @@ using Snowflake.Data.Client; using System.Text; using System.Web; +using System.Linq; namespace Snowflake.Data.Core.Authenticator { @@ -22,6 +23,9 @@ class OktaAuthenticator : BaseAuthenticator, IAuthenticator { private static readonly SFLogger logger = SFLoggerFactory.GetLogger(); + internal const string RetryCountHeader = "RetryCount"; + internal const string TimeoutElapsedHeader = "TimeoutElapsed"; + /// /// url of the okta idp /// @@ -59,23 +63,54 @@ async Task IAuthenticator.AuthenticateAsync(CancellationToken cancellationToken) logger.Debug("Checking token url"); VerifyUrls(tokenUrl, oktaUrl); - logger.Debug("step 3: get idp onetime token"); - IdpTokenRestRequest idpTokenRestRequest = BuildIdpTokenRestRequest(tokenUrl); - var idpResponse = await session.restRequester.PostAsync(idpTokenRestRequest, cancellationToken).ConfigureAwait(false); - string onetimeToken = idpResponse.SessionToken != null ? idpResponse.SessionToken : idpResponse.CookieToken; + int retryCount = 0; + int timeoutElapsed = 0; + Exception lastRetryException = null; + HttpResponseMessage samlRawResponse = null; - logger.Debug("step 4: get SAML reponse from sso"); - var samlRestRequest = BuildSAMLRestRequest(ssoUrl, onetimeToken); - using (var samlRawResponse = await session.restRequester.GetAsync(samlRestRequest, cancellationToken).ConfigureAwait(false)) - { - samlRawHtmlString = await samlRawResponse.Content.ReadAsStringAsync().ConfigureAwait(false); + // If VerifyPostbackUrl() fails, retry with new onetimetoken + while (RetryLimitIsNotReached(retryCount, timeoutElapsed)) + { + try + { + logger.Debug("step 3: get idp onetime token"); + IdpTokenRestRequest idpTokenRestRequest = BuildIdpTokenRestRequest(tokenUrl); + var idpResponse = await session.restRequester.PostAsync(idpTokenRestRequest, cancellationToken).ConfigureAwait(false); + string onetimeToken = idpResponse.SessionToken != null ? idpResponse.SessionToken : idpResponse.CookieToken; + + logger.Debug("step 4: get SAML reponse from sso"); + var samlRestRequest = BuildSAMLRestRequest(ssoUrl, onetimeToken); + samlRawResponse = await session.restRequester.GetAsync(samlRestRequest, cancellationToken).ConfigureAwait(false); + samlRawHtmlString = await samlRawResponse.Content.ReadAsStringAsync().ConfigureAwait(false); + + logger.Debug("step 5: verify postback url in SAML reponse"); + VerifyPostbackUrl(); + + logger.Debug("step 6: send SAML reponse to snowflake to login"); + await base.LoginAsync(cancellationToken).ConfigureAwait(false); + break; + } + catch (Exception ex) + { + lastRetryException = ex; + if (IsPostbackUrlNotFound(lastRetryException)) + { + logger.Info("Refreshing token for Okta re-authentication and starting from step 3 again"); + + // Get the current retry count and timeout elapsed from the response headers + retryCount += int.Parse(samlRawResponse.Content.Headers.GetValues(RetryCountHeader).First()); + timeoutElapsed += int.Parse(samlRawResponse.Content.Headers.GetValues(TimeoutElapsedHeader).First()); + } + else + { + logger.Error("Failed to get the correct SAML response from Okta", ex); + throw; + } + } } - logger.Debug("step 5: verify postback url in SAML reponse"); - VerifyPostbackUrl(); - - logger.Debug("step 6: send SAML reponse to snowflake to login"); - await base.LoginAsync(cancellationToken).ConfigureAwait(false); + // Throw exception if max retry count or max timeout has been reached + ThrowRetryLimitException(retryCount, timeoutElapsed, lastRetryException); } void IAuthenticator.Authenticate() @@ -95,23 +130,55 @@ void IAuthenticator.Authenticate() logger.Debug("Checking token url"); VerifyUrls(tokenUrl, oktaUrl); - logger.Debug("step 3: get idp onetime token"); - IdpTokenRestRequest idpTokenRestRequest = BuildIdpTokenRestRequest(tokenUrl); - var idpResponse = session.restRequester.Post(idpTokenRestRequest); - string onetimeToken = idpResponse.SessionToken != null ? idpResponse.SessionToken : idpResponse.CookieToken; + int retryCount = 0; + int timeoutElapsed = 0; + Exception lastRetryException = null; + HttpResponseMessage samlRawResponse = null; - logger.Debug("step 4: get SAML reponse from sso"); - var samlRestRequest = BuildSAMLRestRequest(ssoUrl, onetimeToken); - using (var samlRawResponse = session.restRequester.Get(samlRestRequest)) + // If VerifyPostbackUrl() fails, retry with new onetimetoken + while (RetryLimitIsNotReached(retryCount, timeoutElapsed)) { - samlRawHtmlString = Task.Run(async () => await samlRawResponse.Content.ReadAsStringAsync().ConfigureAwait(false)).Result; - } + try + { + logger.Debug("step 3: get idp onetime token"); + IdpTokenRestRequest idpTokenRestRequest = BuildIdpTokenRestRequest(tokenUrl); + var idpResponse = session.restRequester.Post(idpTokenRestRequest); + string onetimeToken = idpResponse.SessionToken != null ? idpResponse.SessionToken : idpResponse.CookieToken; + + logger.Debug("step 4: get SAML reponse from sso"); + var samlRestRequest = BuildSAMLRestRequest(ssoUrl, onetimeToken); + samlRawResponse = session.restRequester.Get(samlRestRequest); + samlRawHtmlString = Task.Run(async () => await samlRawResponse.Content.ReadAsStringAsync().ConfigureAwait(false)).Result; - logger.Debug("step 5: verify postback url in SAML reponse"); - VerifyPostbackUrl(); + logger.Debug("step 5: verify postback url in SAML reponse"); + VerifyPostbackUrl(); - logger.Debug("step 6: send SAML reponse to snowflake to login"); - base.Login(); + + logger.Debug("step 6: send SAML reponse to snowflake to login"); + base.Login(); + break; + } + catch(Exception ex) + { + lastRetryException = ex; + if (IsPostbackUrlNotFound(lastRetryException)) + { + logger.Debug("Refreshing token for Okta re-authentication and starting from step 3 again"); + + // Get the current retry count and timeout elapsed from the response headers + retryCount += int.Parse(samlRawResponse.Content.Headers.GetValues(RetryCountHeader).First()); + timeoutElapsed += int.Parse(samlRawResponse.Content.Headers.GetValues(TimeoutElapsedHeader).First()); + } + else + { + logger.Error("Failed to get the correct SAML response from Okta", ex); + throw; + } + } + } + + // Throw exception if max retry count or max timeout has been reached + ThrowRetryLimitException(retryCount, timeoutElapsed, lastRetryException); } private SFRestRequest BuildAuthenticatorRestRequest() @@ -215,6 +282,41 @@ private void FilterFailedResponse(BaseRestResponse response) throw e; } } + + private bool RetryLimitIsNotReached(int retryCount, int timeoutElapsed) + { + return retryCount < session._maxRetryCount && timeoutElapsed < session._maxRetryTimeout; + } + + private bool IsPostbackUrlNotFound(Exception ex) + { + if (ex is SnowflakeDbException) + { + SnowflakeDbException error = ex as SnowflakeDbException; + return error.ErrorCode == SFError.IDP_SAML_POSTBACK_NOTFOUND.GetAttribute().errorCode; + } + + return false; + } + + private void ThrowRetryLimitException(int retryCount, int timeoutElapsed, Exception lastRetryException) + { + string errorMessage = ""; + if (retryCount >= session._maxRetryCount) + { + errorMessage = $"The retry count has reached its limit of {session._maxRetryCount}"; + } + if (timeoutElapsed >= session._maxRetryTimeout) + { + errorMessage += string.IsNullOrEmpty(errorMessage) ? "The" : " and the"; + errorMessage += $" timeout elapsed has reached its limit of {session._maxRetryTimeout}"; + + } + errorMessage += " while trying to authenticate through Okta"; + + logger.Error(errorMessage); + throw new SnowflakeDbException(lastRetryException, SFError.INTERNAL_ERROR, errorMessage); + } } internal class IdpTokenRestRequest : BaseRestRequest, IRestRequest diff --git a/Snowflake.Data/Core/HttpUtil.cs b/Snowflake.Data/Core/HttpUtil.cs index b48d5a3a9..9c5e22442 100755 --- a/Snowflake.Data/Core/HttpUtil.cs +++ b/Snowflake.Data/Core/HttpUtil.cs @@ -14,6 +14,7 @@ using System.Security.Authentication; using System.Runtime.InteropServices; using System.Linq; +using Snowflake.Data.Core.Authenticator; namespace Snowflake.Data.Core { @@ -342,7 +343,9 @@ protected override async Task SendAsync(HttpRequestMessage CancellationToken cancellationToken) { HttpResponseMessage response = null; - bool isLoginRequest = IsLoginEndpoint(requestMessage.RequestUri.AbsolutePath); + string absolutePath = requestMessage.RequestUri.AbsolutePath; + bool isLoginRequest = IsLoginEndpoint(absolutePath); + bool isOktaSSORequest = IsOktaSSORequest(requestMessage.RequestUri.Host, absolutePath); int backOffInSec = s_baseBackOffTime; int totalRetryTime = 0; @@ -411,6 +414,12 @@ protected override async Task SendAsync(HttpRequestMessage if (response != null) { + if (isOktaSSORequest) + { + response.Content.Headers.Add(OktaAuthenticator.RetryCountHeader, retryCount.ToString()); + response.Content.Headers.Add(OktaAuthenticator.TimeoutElapsedHeader, totalRetryTime.ToString()); + } + if (response.IsSuccessStatusCode) { logger.Debug($"Success Response: StatusCode: {(int)response.StatusCode}, ReasonPhrase: '{response.ReasonPhrase}'"); @@ -534,6 +543,17 @@ static internal bool IsLoginEndpoint(string endpoint) { return null != s_supportedEndpointsForRetryPolicy.FirstOrDefault(ep => endpoint.Equals(ep)); } + + /// + /// Checks if request is for Okta and an SSO SAML endpoint. + /// + /// The host url to check. + /// The endpoint to check. + /// True if the endpoint is an okta sso saml request, false otherwise. + static internal bool IsOktaSSORequest(string host, string endpoint) + { + return host.Contains(OktaUrl.DOMAIN) && endpoint.Contains(OktaUrl.SSO_SAML_PATH); + } } } diff --git a/Snowflake.Data/Core/RestParams.cs b/Snowflake.Data/Core/RestParams.cs index 1188affb0..9dd4de8c8 100644 --- a/Snowflake.Data/Core/RestParams.cs +++ b/Snowflake.Data/Core/RestParams.cs @@ -44,6 +44,12 @@ internal static class RestPath internal const string SF_CONSOLE_LOGIN = "/console/login"; } + internal static class OktaUrl + { + internal const string DOMAIN = "okta.com"; + internal const string SSO_SAML_PATH = "/sso/saml"; + } + internal class SFEnvironment { static SFEnvironment() diff --git a/Snowflake.Data/Core/Session/SFSession.cs b/Snowflake.Data/Core/Session/SFSession.cs index 2ad440407..e39370f19 100755 --- a/Snowflake.Data/Core/Session/SFSession.cs +++ b/Snowflake.Data/Core/Session/SFSession.cs @@ -79,6 +79,10 @@ public class SFSession internal bool _disableConsoleLogin; + internal int _maxRetryCount; + + internal int _maxRetryTimeout; + internal void ProcessLoginResponse(LoginResponse authnResponse) { if (authnResponse.success) @@ -164,6 +168,8 @@ internal SFSession( connectionTimeout = extractedProperties.TimeoutDuration(); properties.TryGetValue(SFSessionProperty.CLIENT_CONFIG_FILE, out var easyLoggingConfigFile); _easyLoggingStarter.Init(easyLoggingConfigFile); + _maxRetryCount = extractedProperties.maxHttpRetries; + _maxRetryTimeout = extractedProperties.retryTimeout; } catch (Exception e) { From 0fd4856de7c07ca5f8b368e2b88a21cc425cfa53 Mon Sep 17 00:00:00 2001 From: sfc-gh-ext-simba-lf Date: Fri, 9 Feb 2024 10:08:56 -0800 Subject: [PATCH 2/7] SNOW-916949: Add tests --- .../IntegrationTests/SFConnectionIT.cs | 71 +++++++++++++++++++ .../Mock/MockOktaRetryMaxTimeout.cs | 55 ++++++++++++++ .../UnitTests/HttpUtilTest.cs | 18 +++++ Snowflake.Data.Tests/UnitTests/SFOktaTest.cs | 12 +++- 4 files changed, 154 insertions(+), 2 deletions(-) create mode 100644 Snowflake.Data.Tests/Mock/MockOktaRetryMaxTimeout.cs diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index 5a4976162..2470068eb 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -817,6 +817,40 @@ public void TestOktaConnection() } } + [Test] + public void TestOktaConnectionUntilMaxTimeout() + { + var expectedMaxRetryCount = 15; + var expectedMaxConnectionTimeout = 450; + var mockRestRequester = new MockOktaRetryMaxTimeout(expectedMaxRetryCount, expectedMaxConnectionTimeout); + using (DbConnection conn = new MockSnowflakeDbConnection(mockRestRequester)) + { + try + { + conn.ConnectionString + = ConnectionStringWithoutAuth + + String.Format( + ";authenticator={0};user={1};password={2};MAXHTTPRETRIES={3};RETRY_TIMEOUT={4};", + testConfig.oktaUrl, + testConfig.oktaUser, + testConfig.oktaPassword, + expectedMaxRetryCount, + expectedMaxConnectionTimeout); + conn.Open(); + Assert.Fail(); + } + catch (Exception e) + { + Assert.IsInstanceOf(e); + Assert.AreEqual(SFError.INTERNAL_ERROR.GetAttribute().errorCode, ((SnowflakeDbException)e).ErrorCode); + Assert.IsTrue(e.Message.Contains( + $"The retry count has reached its limit of {expectedMaxRetryCount} and " + + $"the timeout elapsed has reached its limit of {expectedMaxConnectionTimeout} " + + "while trying to authenticate through Okta")); + } + } + } + [Test] [Ignore("This test requires manual setup and therefore cannot be run in CI")] public void TestOkta2ConnectionsFollowingEachOther() @@ -2056,6 +2090,43 @@ public void TestExplicitTransactionOperationsTracked() Assert.AreEqual(false, conn.HasActiveExplicitTransaction()); } } + + + [Test] + public void TestAsyncOktaConnectionUntilMaxTimeout() + { + var expectedMaxRetryCount = 15; + var expectedMaxConnectionTimeout = 450; + var mockRestRequester = new MockOktaRetryMaxTimeout(expectedMaxRetryCount, expectedMaxConnectionTimeout); + using (DbConnection conn = new MockSnowflakeDbConnection(mockRestRequester)) + { + Task connectTask = null; + try + { + conn.ConnectionString + = ConnectionStringWithoutAuth + + String.Format( + ";authenticator={0};user={1};password={2};MAXHTTPRETRIES={3};RETRY_TIMEOUT={4};", + testConfig.oktaUrl, + testConfig.oktaUser, + testConfig.oktaPassword, + expectedMaxRetryCount, + expectedMaxConnectionTimeout); + connectTask = conn.OpenAsync(CancellationToken.None); + connectTask.Wait(); + Assert.Fail(); + } + catch (Exception e) + { + Assert.IsInstanceOf(e.InnerException); + Assert.AreEqual(SFError.INTERNAL_ERROR.GetAttribute().errorCode, ((SnowflakeDbException)e.InnerException).ErrorCode); + Assert.IsTrue(e.InnerException.InnerException.Message.Contains( + $"The retry count has reached its limit of {expectedMaxRetryCount} and " + + $"the timeout elapsed has reached its limit of {expectedMaxConnectionTimeout} " + + "while trying to authenticate through Okta")); + } + } + } } } diff --git a/Snowflake.Data.Tests/Mock/MockOktaRetryMaxTimeout.cs b/Snowflake.Data.Tests/Mock/MockOktaRetryMaxTimeout.cs new file mode 100644 index 000000000..20f7c573d --- /dev/null +++ b/Snowflake.Data.Tests/Mock/MockOktaRetryMaxTimeout.cs @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using Newtonsoft.Json; +using Snowflake.Data.Core; +using Snowflake.Data.Core.Authenticator; +using System; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Snowflake.Data.Tests.Mock +{ + + class MockOktaRetryMaxTimeout : RestRequester, IMockRestRequester + { + internal bool _forceTimeoutForNonLoginRequestsOnly = false; + internal int _maxRetryCount; + internal int _maxRetryTimeout; + + public MockOktaRetryMaxTimeout(int maxRetryCount, int maxRetryTimeout) : base(null) + { + _maxRetryCount = maxRetryCount; + _maxRetryTimeout = maxRetryTimeout; + } + + public void setHttpClient(HttpClient httpClient) + { + _HttpClient = httpClient; + } + + protected override async Task SendAsync(HttpRequestMessage message, + TimeSpan restTimeout, + CancellationToken externalCancellationToken, + string sid = "") + { + if (HttpUtil.IsOktaSSORequest(message.RequestUri.Host, message.RequestUri.AbsolutePath)) + { + var mockContent = new StringContent(JsonConvert.SerializeObject("().errorCode, e.ErrorCode); + Assert.AreEqual(SFError.IDP_SAML_POSTBACK_NOTFOUND.GetAttribute().errorCode, ((SnowflakeDbException)e.InnerException).ErrorCode); } + noPostbackContent.Headers.Remove(OktaAuthenticator.RetryCountHeader); + noPostbackContent.Headers.Remove(OktaAuthenticator.TimeoutElapsedHeader); } [Test] From 5bed4df12320e5d4591fba9f3673338a82874f4f Mon Sep 17 00:00:00 2001 From: sfc-gh-ext-simba-lf Date: Fri, 9 Feb 2024 12:48:25 -0800 Subject: [PATCH 3/7] SNOW-916949: Skip test that require manual setup --- Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index 2470068eb..f312a02bc 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -818,6 +818,7 @@ public void TestOktaConnection() } [Test] + [Ignore("This test requires manual setup and therefore cannot be run in CI")] public void TestOktaConnectionUntilMaxTimeout() { var expectedMaxRetryCount = 15; @@ -2093,6 +2094,7 @@ public void TestExplicitTransactionOperationsTracked() [Test] + [Ignore("This test requires manual setup and therefore cannot be run in CI")] public void TestAsyncOktaConnectionUntilMaxTimeout() { var expectedMaxRetryCount = 15; From a572300f1d90406387118304761656d9ca39c075 Mon Sep 17 00:00:00 2001 From: sfc-gh-ext-simba-lf Date: Fri, 9 Feb 2024 14:18:51 -0800 Subject: [PATCH 4/7] SNOW-916949: Refactor mock class and remove skip tag --- .../IntegrationTests/SFConnectionIT.cs | 38 ++++++++----- Snowflake.Data.Tests/Mock/MockOkta.cs | 4 ++ .../Mock/MockOktaRetryMaxTimeout.cs | 55 ------------------- Snowflake.Data.Tests/UnitTests/SFOktaTest.cs | 18 +++--- .../Core/Authenticator/OktaAuthenticator.cs | 1 - 5 files changed, 38 insertions(+), 78 deletions(-) delete mode 100644 Snowflake.Data.Tests/Mock/MockOktaRetryMaxTimeout.cs diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index f312a02bc..534d90a58 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -17,6 +17,7 @@ namespace Snowflake.Data.Tests.IntegrationTests using System.Diagnostics; using Snowflake.Data.Tests.Mock; using System.Runtime.InteropServices; + using System.Net.Http; [TestFixture] class SFConnectionIT : SFBaseTest @@ -818,12 +819,19 @@ public void TestOktaConnection() } [Test] - [Ignore("This test requires manual setup and therefore cannot be run in CI")] public void TestOktaConnectionUntilMaxTimeout() { var expectedMaxRetryCount = 15; var expectedMaxConnectionTimeout = 450; - var mockRestRequester = new MockOktaRetryMaxTimeout(expectedMaxRetryCount, expectedMaxConnectionTimeout); + var oktaUrl = "https://test.okta.com"; + var mockRestRequester = new MockOktaRestRequester() + { + TokenUrl = $"{oktaUrl}/api/v1/sessions?additionalFields=cookieToken", + SSOUrl = $"{oktaUrl}/app/testaccount/sso/saml", + ResponseContent = new StringContent("(IRestRequest request) { @@ -31,6 +33,8 @@ public Task GetAsync(IRestRequest request, CancellationToke { var response = new HttpResponseMessage(System.Net.HttpStatusCode.OK); response.Content = ResponseContent; + response.Content.Headers.Add(OktaAuthenticator.RetryCountHeader, MaxRetryCount.ToString()); + response.Content.Headers.Add(OktaAuthenticator.TimeoutElapsedHeader, MaxRetryTimeout.ToString()); return Task.FromResult(response); } diff --git a/Snowflake.Data.Tests/Mock/MockOktaRetryMaxTimeout.cs b/Snowflake.Data.Tests/Mock/MockOktaRetryMaxTimeout.cs deleted file mode 100644 index 20f7c573d..000000000 --- a/Snowflake.Data.Tests/Mock/MockOktaRetryMaxTimeout.cs +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. - */ - -using Newtonsoft.Json; -using Snowflake.Data.Core; -using Snowflake.Data.Core.Authenticator; -using System; -using System.Net.Http; -using System.Text; -using System.Threading; -using System.Threading.Tasks; - -namespace Snowflake.Data.Tests.Mock -{ - - class MockOktaRetryMaxTimeout : RestRequester, IMockRestRequester - { - internal bool _forceTimeoutForNonLoginRequestsOnly = false; - internal int _maxRetryCount; - internal int _maxRetryTimeout; - - public MockOktaRetryMaxTimeout(int maxRetryCount, int maxRetryTimeout) : base(null) - { - _maxRetryCount = maxRetryCount; - _maxRetryTimeout = maxRetryTimeout; - } - - public void setHttpClient(HttpClient httpClient) - { - _HttpClient = httpClient; - } - - protected override async Task SendAsync(HttpRequestMessage message, - TimeSpan restTimeout, - CancellationToken externalCancellationToken, - string sid = "") - { - if (HttpUtil.IsOktaSSORequest(message.RequestUri.Host, message.RequestUri.AbsolutePath)) - { - var mockContent = new StringContent(JsonConvert.SerializeObject("< html lang =\"en\">\n\n\n\n\n\n\n\n\n\n\n\nSnowflake - Signing in...\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n
\n\n\n\n \n\n
\n
\n
\n
\n \"Please\"Please\"Okta\"
\n
\n

Signing in to TESTACCOUNT (regression)

\n
\n
\n
\n\n \n
\n \n \n
\n\n
\n\n\n\n"); StringContent noPostbackContent = new StringContent(" < !DOCTYPE html >< html lang =\"en\">\n\n\n\n\n\n\n\n\n\n\n\nSnowflake - Signing in...\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n
\n\n\n\n \n\n
\n
\n
\n
\n \"Please\"Please\"Okta\"
\n
\n

Signing in to TESTACCOUNT (regression)

\n
\n
\n
\n\n \n
\n \n \n
\n\n
\n\n\n\n"); + const int MaxRetryCount = 1; + const int MaxRetryTimeout = 1; + [Test] public void TestSsoTokenUrlMismatch() { @@ -21,6 +23,8 @@ public void TestSsoTokenUrlMismatch() { TokenUrl = "https://snowflakecomputing.okta1.com/api/v1/sessions?additionalFields=cookieToken", SSOUrl = "https://snowflakecomputing.okta.com/app/snowflake_testaccountdev_1/blah/sso/saml", + MaxRetryCount = MaxRetryCount, + MaxRetryTimeout = MaxRetryTimeout }; var sfSession = new SFSession("account=test;user=test;password=test;authenticator=https://snowflake.okta.com", null, restRequester); sfSession.Open(); @@ -34,10 +38,6 @@ public void TestSsoTokenUrlMismatch() [Test] public void TestMissingPostbackUrl() { - const int retryCount = 1; - const int retryTimeout = 1; - noPostbackContent.Headers.Add(OktaAuthenticator.RetryCountHeader, retryCount.ToString()); - noPostbackContent.Headers.Add(OktaAuthenticator.TimeoutElapsedHeader, retryTimeout.ToString()); try { var restRequester = new Mock.MockOktaRestRequester() @@ -45,17 +45,17 @@ public void TestMissingPostbackUrl() TokenUrl = "https://snowflakecomputing.okta.com/api/v1/sessions?additionalFields=cookieToken", SSOUrl = "https://snowflakecomputing.okta.com/app/snowflake_testaccountdev_1/blah/sso/saml", ResponseContent = noPostbackContent, + MaxRetryCount = MaxRetryCount, + MaxRetryTimeout = MaxRetryTimeout }; var sfSession = new SFSession("account=test;user=test;password=test;authenticator=https://snowflakecomputing.okta.com;" + - $"host=test;MAXHTTPRETRIES={retryCount};RETRY_TIMEOUT={retryTimeout};", null, restRequester); + $"host=test;MAXHTTPRETRIES={MaxRetryCount};RETRY_TIMEOUT={MaxRetryTimeout};", null, restRequester); sfSession.Open(); Assert.Fail("Should not pass"); } catch (SnowflakeDbException e) { Assert.AreEqual(SFError.IDP_SAML_POSTBACK_NOTFOUND.GetAttribute().errorCode, ((SnowflakeDbException)e.InnerException).ErrorCode); } - noPostbackContent.Headers.Remove(OktaAuthenticator.RetryCountHeader); - noPostbackContent.Headers.Remove(OktaAuthenticator.TimeoutElapsedHeader); } [Test] @@ -68,6 +68,8 @@ public void TestWrongPostbackUrl() TokenUrl = "https://snowflakecomputing.okta.com/api/v1/sessions?additionalFields=cookieToken", SSOUrl = "https://snowflakecomputing.okta.com/app/snowflake_testaccountdev_1/blah/sso/saml", ResponseContent = wrongPostbackContent, + MaxRetryCount = MaxRetryCount, + MaxRetryTimeout = MaxRetryTimeout }; var sfSession = new SFSession("account=test;user=test;password=test;authenticator=https://snowflakecomputing.okta.com;host=test", null, restRequester); sfSession.Open(); diff --git a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs index 5df3975c4..be04ceab7 100644 --- a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs @@ -153,7 +153,6 @@ void IAuthenticator.Authenticate() logger.Debug("step 5: verify postback url in SAML reponse"); VerifyPostbackUrl(); - logger.Debug("step 6: send SAML reponse to snowflake to login"); base.Login(); break; From 14bec0133de2b79f520090b1447235533d900bee Mon Sep 17 00:00:00 2001 From: sfc-gh-ext-simba-lf Date: Fri, 9 Feb 2024 15:06:51 -0800 Subject: [PATCH 5/7] SNOW-916949: Fix async test for NET Framework --- Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs | 8 +++++++- Snowflake.Data.Tests/UnitTests/SFOktaTest.cs | 4 ++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index 534d90a58..fe9419939 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -2132,7 +2132,13 @@ public void TestAsyncOktaConnectionUntilMaxTimeout() { Assert.IsInstanceOf(e.InnerException); Assert.AreEqual(SFError.INTERNAL_ERROR.GetAttribute().errorCode, ((SnowflakeDbException)e.InnerException).ErrorCode); - Assert.IsTrue(e.InnerException.InnerException.Message.Contains( + Exception oktaException; +#if NETFRAMEWORK + oktaException = e.InnerException.InnerException.InnerException; +#else + oktaException = e.InnerException.InnerException; +#endif + Assert.IsTrue(oktaException.Message.Contains( $"The retry count has reached its limit of {expectedMaxRetryCount} and " + $"the timeout elapsed has reached its limit of {expectedMaxConnectionTimeout} " + "while trying to authenticate through Okta")); diff --git a/Snowflake.Data.Tests/UnitTests/SFOktaTest.cs b/Snowflake.Data.Tests/UnitTests/SFOktaTest.cs index db07f1c66..e801bd30d 100644 --- a/Snowflake.Data.Tests/UnitTests/SFOktaTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFOktaTest.cs @@ -11,8 +11,8 @@ class SFOktaTest StringContent wrongPostbackContent = new StringContent(" < !DOCTYPE html >< html lang =\"en\">\n\n\n\n\n\n\n\n\n\n\n\nSnowflake - Signing in...\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n
\n\n\n\n \n\n
\n
\n
\n
\n \"Please\"Please\"Okta\"
\n
\n

Signing in to TESTACCOUNT (regression)

\n
\n
\n
\n\n \n
\n \n \n
\n\n
\n\n\n\n"); StringContent noPostbackContent = new StringContent(" < !DOCTYPE html >< html lang =\"en\">\n\n\n\n\n\n\n\n\n\n\n\nSnowflake - Signing in...\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n
\n\n\n\n \n\n
\n
\n
\n
\n \"Please\"Please\"Okta\"
\n
\n

Signing in to TESTACCOUNT (regression)

\n
\n
\n
\n\n \n
\n \n \n
\n\n
\n\n\n\n"); - const int MaxRetryCount = 1; - const int MaxRetryTimeout = 1; + const int MaxRetryCount = 15; + const int MaxRetryTimeout = 400; [Test] public void TestSsoTokenUrlMismatch() From 1d6ac86fea306fc9a696a31253717c6ceaee75d5 Mon Sep 17 00:00:00 2001 From: sfc-gh-ext-simba-lf Date: Fri, 9 Feb 2024 15:45:27 -0800 Subject: [PATCH 6/7] SNOW-916949: Fix throwing an exception for correct authentication --- Snowflake.Data.Tests/Mock/MockOkta.cs | 45 ++++++++++++++---- Snowflake.Data.Tests/UnitTests/SFOktaTest.cs | 47 ++++++++++++++++++- .../Core/Authenticator/OktaAuthenticator.cs | 4 +- 3 files changed, 84 insertions(+), 12 deletions(-) diff --git a/Snowflake.Data.Tests/Mock/MockOkta.cs b/Snowflake.Data.Tests/Mock/MockOkta.cs index d83c5d1f5..56e245632 100644 --- a/Snowflake.Data.Tests/Mock/MockOkta.cs +++ b/Snowflake.Data.Tests/Mock/MockOkta.cs @@ -47,18 +47,45 @@ public Task PostAsync(IRestRequest postRequest, CancellationToken cancella { if (postRequest is SFRestRequest) { - // authenticator - var authnResponse = new AuthenticatorResponse + if (((SFRestRequest)postRequest).jsonBody is AuthenticatorRequest) { - success = true, - data = new AuthenticatorResponseData + // authenticator + var authnResponse = new AuthenticatorResponse { - tokenUrl = TokenUrl, - ssoUrl = SSOUrl, - } - }; + success = true, + data = new AuthenticatorResponseData + { + tokenUrl = TokenUrl, + ssoUrl = SSOUrl, + } + }; + + return Task.FromResult((T)(object)authnResponse); + } + else + { + // login + var loginResponse = new LoginResponse + { + success = true, + data = new LoginResponseData + { + sessionId = "", + token = "", + masterToken = "", + masterValidityInSeconds = 0, + authResponseSessionInfo = new SessionInfo + { + databaseName = "", + schemaName = "", + roleName = "", + warehouseName = "", + } + } + }; - return Task.FromResult((T)(object)authnResponse); + return Task.FromResult((T)(object)loginResponse); + } } else { diff --git a/Snowflake.Data.Tests/UnitTests/SFOktaTest.cs b/Snowflake.Data.Tests/UnitTests/SFOktaTest.cs index e801bd30d..75c73824d 100644 --- a/Snowflake.Data.Tests/UnitTests/SFOktaTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFOktaTest.cs @@ -2,6 +2,8 @@ using Snowflake.Data.Client; using Snowflake.Data.Core; using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; namespace Snowflake.Data.Tests.UnitTests { @@ -10,7 +12,7 @@ class SFOktaTest { StringContent wrongPostbackContent = new StringContent(" < !DOCTYPE html >< html lang =\"en\">\n\n\n\n\n\n\n\n\n\n\n\nSnowflake - Signing in...\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n
\n\n\n\n \n\n
\n
\n
\n
\n \"Please\"Please\"Okta\"
\n
\n

Signing in to TESTACCOUNT (regression)

\n
\n
\n
\n\n \n
\n \n \n
\n\n
\n\n\n\n"); StringContent noPostbackContent = new StringContent(" < !DOCTYPE html >< html lang =\"en\">\n\n\n\n\n\n\n\n\n\n\n\nSnowflake - Signing in...\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n
\n\n\n\n \n\n
\n
\n
\n
\n \"Please\"Please\"Okta\"
\n
\n

Signing in to TESTACCOUNT (regression)

\n
\n
\n
\n\n \n
\n \n \n
\n\n
\n\n\n\n"); - + StringContent correctPostbackContent = new StringContent("
Date: Fri, 9 Feb 2024 16:12:16 -0800 Subject: [PATCH 7/7] SNOW-916949: Make log level match other log messages --- Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs index dae821878..6e258af5b 100644 --- a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs @@ -95,7 +95,7 @@ async Task IAuthenticator.AuthenticateAsync(CancellationToken cancellationToken) lastRetryException = ex; if (IsPostbackUrlNotFound(lastRetryException)) { - logger.Info("Refreshing token for Okta re-authentication and starting from step 3 again"); + logger.Debug("Refreshing token for Okta re-authentication and starting from step 3 again"); // Get the current retry count and timeout elapsed from the response headers retryCount += int.Parse(samlRawResponse.Content.Headers.GetValues(RetryCountHeader).First());