From c982f4b23fb7d373157b648a70c49e722037a97c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Hofman?= Date: Wed, 6 Mar 2024 16:28:30 +0100 Subject: [PATCH] SNOW-1206259 Handle OKTA Auth invalid responses more gracefully --- .../IntegrationTests/SFConnectionIT.cs | 16 ++ .../Core/Authenticator/OktaAuthenticator.cs | 196 ++++++++---------- 2 files changed, 108 insertions(+), 104 deletions(-) diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index c8c250ed4..816e064b3 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -2153,6 +2153,22 @@ public void TestAsyncOktaConnectionUntilMaxTimeout() } } } + + [Test] + [Ignore("This test requires established dev Okta SSO and credentials matching Snowflake user")] + public void TestNativeOktaSuccess() + { + var oktaUrl = "https://***.okta.com/"; + var oktaUser = "***"; + var oktaPassword = "***"; + using (IDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionStringWithoutAuth + + $";authenticator={oktaUrl};user={oktaUser};password={oktaPassword};"; + conn.Open(); + Assert.AreEqual(ConnectionState.Open, conn.State); + } + } } } diff --git a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs index 6e258af5b..cca377512 100644 --- a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ using System; @@ -17,11 +17,11 @@ namespace Snowflake.Data.Core.Authenticator { /// - /// OktaAuthenticator would perform serveral steps of authentication with Snowflake and Okta idp + /// OktaAuthenticator would perform several steps of authentication with Snowflake and Okta IdP /// class OktaAuthenticator : BaseAuthenticator, IAuthenticator { - private static readonly SFLogger logger = SFLoggerFactory.GetLogger(); + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); internal const string RetryCountHeader = "RetryCount"; internal const string TimeoutElapsedHeader = "TimeoutElapsed"; @@ -29,10 +29,9 @@ class OktaAuthenticator : BaseAuthenticator, IAuthenticator /// /// url of the okta idp /// - private Uri oktaUrl; + private readonly Uri _oktaUrl; - // The raw Saml token. - private string samlRawHtmlString; + private string _rawSamlTokenHtmlString; /// /// Constructor of the Okta authenticator @@ -42,72 +41,63 @@ class OktaAuthenticator : BaseAuthenticator, IAuthenticator internal OktaAuthenticator(SFSession session, string oktaUriString) : base(session, oktaUriString) { - oktaUrl = new Uri(oktaUriString); + _oktaUrl = new Uri(oktaUriString); } /// async Task IAuthenticator.AuthenticateAsync(CancellationToken cancellationToken) { - logger.Info("Okta Authentication"); + s_logger.Info("Okta Authentication"); - logger.Debug("step 1: get sso and token url"); + s_logger.Debug("step 1: Get SSO and token URL"); var authenticatorRestRequest = BuildAuthenticatorRestRequest(); var authenticatorResponse = await session.restRequester.PostAsync(authenticatorRestRequest, cancellationToken).ConfigureAwait(false); authenticatorResponse.FilterFailedResponse(); Uri ssoUrl = new Uri(authenticatorResponse.data.ssoUrl); Uri tokenUrl = new Uri(authenticatorResponse.data.tokenUrl); - logger.Debug("step 2: verify urls fetched from step 1"); - logger.Debug("Checking sso url"); - VerifyUrls(ssoUrl, oktaUrl); - logger.Debug("Checking token url"); - VerifyUrls(tokenUrl, oktaUrl); + s_logger.Debug("step 2: Verify URLs fetched from step 1"); + s_logger.Debug("Checking SSO Okta URL"); + VerifyUrls(ssoUrl, _oktaUrl); + s_logger.Debug("Checking token URL"); + VerifyUrls(tokenUrl, _oktaUrl); int retryCount = 0; int timeoutElapsed = 0; Exception lastRetryException = null; HttpResponseMessage samlRawResponse = null; - // If VerifyPostbackUrl() fails, retry with new onetimetoken + // If VerifyPostbackUrl() fails, retry with new one-time token while (RetryLimitIsNotReached(retryCount, timeoutElapsed)) { try { - logger.Debug("step 3: get idp onetime token"); + s_logger.Debug("step 3: Get IdP one-time token"); IdpTokenRestRequest idpTokenRestRequest = BuildIdpTokenRestRequest(tokenUrl); var idpResponse = await session.restRequester.PostAsync(idpTokenRestRequest, cancellationToken).ConfigureAwait(false); - string onetimeToken = idpResponse.SessionToken != null ? idpResponse.SessionToken : idpResponse.CookieToken; + string onetimeToken = idpResponse.SessionToken ?? idpResponse.CookieToken; - logger.Debug("step 4: get SAML reponse from sso"); - var samlRestRequest = BuildSAMLRestRequest(ssoUrl, onetimeToken); + s_logger.Debug("step 4: Get SAML response from SSO"); + var samlRestRequest = BuildSamlRestRequest(ssoUrl, onetimeToken); samlRawResponse = await session.restRequester.GetAsync(samlRestRequest, cancellationToken).ConfigureAwait(false); - samlRawHtmlString = await samlRawResponse.Content.ReadAsStringAsync().ConfigureAwait(false); - - logger.Debug("step 5: verify postback url in SAML reponse"); +#if NETFRAMEWORK + _rawSamlTokenHtmlString = await samlRawResponse.Content.ReadAsStringAsync().ConfigureAwait(false); +#else + _rawSamlTokenHtmlString = await samlRawResponse.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); +#endif + s_logger.Debug("step 5: Verify postback URL in SAML response"); VerifyPostbackUrl(); - logger.Debug("step 6: send SAML reponse to snowflake to login"); - await base.LoginAsync(cancellationToken).ConfigureAwait(false); + s_logger.Debug("step 6: Send SAML response to Snowflake to login"); + await LoginAsync(cancellationToken).ConfigureAwait(false); return; } catch (Exception ex) { lastRetryException = ex; - if (IsPostbackUrlNotFound(lastRetryException)) - { - logger.Debug("Refreshing token for Okta re-authentication and starting from step 3 again"); - - // Get the current retry count and timeout elapsed from the response headers - retryCount += int.Parse(samlRawResponse.Content.Headers.GetValues(RetryCountHeader).First()); - timeoutElapsed += int.Parse(samlRawResponse.Content.Headers.GetValues(TimeoutElapsedHeader).First()); - } - else - { - logger.Error("Failed to get the correct SAML response from Okta", ex); - throw; - } + HandleAuthenticatorException(ex, samlRawResponse, ref retryCount, ref timeoutElapsed); } - } + } // while retry // Throw exception if max retry count or max timeout has been reached ThrowRetryLimitException(retryCount, timeoutElapsed, lastRetryException); @@ -115,90 +105,100 @@ async Task IAuthenticator.AuthenticateAsync(CancellationToken cancellationToken) void IAuthenticator.Authenticate() { - logger.Info("Okta Authentication"); + s_logger.Info("Okta Authentication"); - logger.Debug("step 1: get sso and token url"); + s_logger.Debug("step 1: Get SSO and token URL"); var authenticatorRestRequest = BuildAuthenticatorRestRequest(); var authenticatorResponse = session.restRequester.Post(authenticatorRestRequest); authenticatorResponse.FilterFailedResponse(); Uri ssoUrl = new Uri(authenticatorResponse.data.ssoUrl); Uri tokenUrl = new Uri(authenticatorResponse.data.tokenUrl); - logger.Debug("step 2: verify urls fetched from step 1"); - logger.Debug("Checking sso url"); - VerifyUrls(ssoUrl, oktaUrl); - logger.Debug("Checking token url"); - VerifyUrls(tokenUrl, oktaUrl); + s_logger.Debug("step 2: Verify URLs fetched from step 1"); + s_logger.Debug("Checking SSO Okta URL"); + VerifyUrls(ssoUrl, _oktaUrl); + s_logger.Debug("Checking token URL"); + VerifyUrls(tokenUrl, _oktaUrl); int retryCount = 0; int timeoutElapsed = 0; Exception lastRetryException = null; HttpResponseMessage samlRawResponse = null; - // If VerifyPostbackUrl() fails, retry with new onetimetoken + // If VerifyPostbackUrl() fails, retry with new one-time token while (RetryLimitIsNotReached(retryCount, timeoutElapsed)) { try { - logger.Debug("step 3: get idp onetime token"); + s_logger.Debug("step 3: Get IdP one-time token"); IdpTokenRestRequest idpTokenRestRequest = BuildIdpTokenRestRequest(tokenUrl); var idpResponse = session.restRequester.Post(idpTokenRestRequest); - string onetimeToken = idpResponse.SessionToken != null ? idpResponse.SessionToken : idpResponse.CookieToken; + string onetimeToken = idpResponse.SessionToken ?? idpResponse.CookieToken; - logger.Debug("step 4: get SAML reponse from sso"); - var samlRestRequest = BuildSAMLRestRequest(ssoUrl, onetimeToken); + s_logger.Debug("step 4: Get SAML response from SSO"); + var samlRestRequest = BuildSamlRestRequest(ssoUrl, onetimeToken); samlRawResponse = session.restRequester.Get(samlRestRequest); - samlRawHtmlString = Task.Run(async () => await samlRawResponse.Content.ReadAsStringAsync().ConfigureAwait(false)).Result; + _rawSamlTokenHtmlString = Task.Run(async () => await samlRawResponse.Content.ReadAsStringAsync().ConfigureAwait(false)).Result; - logger.Debug("step 5: verify postback url in SAML reponse"); + s_logger.Debug("step 5: Verify postback URL in SAML response"); VerifyPostbackUrl(); - logger.Debug("step 6: send SAML reponse to snowflake to login"); - base.Login(); + s_logger.Debug("step 6: Send SAML response to Snowflake to login"); + Login(); return; } catch(Exception ex) { lastRetryException = ex; - if (IsPostbackUrlNotFound(lastRetryException)) - { - logger.Debug("Refreshing token for Okta re-authentication and starting from step 3 again"); - - // Get the current retry count and timeout elapsed from the response headers - retryCount += int.Parse(samlRawResponse.Content.Headers.GetValues(RetryCountHeader).First()); - timeoutElapsed += int.Parse(samlRawResponse.Content.Headers.GetValues(TimeoutElapsedHeader).First()); - } - else - { - logger.Error("Failed to get the correct SAML response from Okta", ex); - throw; - } + HandleAuthenticatorException(ex, samlRawResponse, ref retryCount, ref timeoutElapsed); } - } + } // while retry // Throw exception if max retry count or max timeout has been reached ThrowRetryLimitException(retryCount, timeoutElapsed, lastRetryException); } + private void HandleAuthenticatorException(Exception ex, HttpResponseMessage samlRawResponse, ref int retryCount, ref int timeoutElapsed) + { + if (IsPostbackUrlNotFound(ex)) + { + s_logger.Debug("Refreshing token for Okta re-authentication and starting from step 3 again"); + + if (samlRawResponse is null) + { + var errorNullSamlResponse = "Failure getting SAML response from Okta SSO"; + s_logger.Error(errorNullSamlResponse); + throw new SnowflakeDbException(ex, SFError.IDP_SAML_POSTBACK_INVALID); + } + + // Get the current retry count and timeout elapsed from the response headers + retryCount += int.Parse(samlRawResponse.Content.Headers.GetValues(RetryCountHeader).First()); + timeoutElapsed += int.Parse(samlRawResponse.Content.Headers.GetValues(TimeoutElapsedHeader).First()); + } + else + { + s_logger.Error("Failed to get the correct SAML response from Okta SSO", ex); + throw ex; + } + } + private SFRestRequest BuildAuthenticatorRestRequest() { var fedUrl = session.BuildUri(RestPath.SF_AUTHENTICATOR_REQUEST_PATH); - var data = new AuthenticatorRequestData() + var data = new AuthenticatorRequestData { AccountName = session.properties[SFSessionProperty.ACCOUNT], - Authenticator = oktaUrl.ToString(), - DriverVersion = System.Reflection.Assembly.GetExecutingAssembly().GetName().Version.ToString(), + Authenticator = _oktaUrl.ToString(), + DriverVersion = System.Reflection.Assembly.GetExecutingAssembly().GetName().Version?.ToString(), DriverName = ".NET" }; - int connectionTimeoutSec = int.Parse(session.properties[SFSessionProperty.CONNECTION_TIMEOUT]); - - return session.BuildTimeoutRestRequest(fedUrl, new AuthenticatorRequest() { Data = data }); + return session.BuildTimeoutRestRequest(fedUrl, new AuthenticatorRequest { Data = data }); } private IdpTokenRestRequest BuildIdpTokenRestRequest(Uri tokenUrl) { - return new IdpTokenRestRequest() + return new IdpTokenRestRequest { Url = tokenUrl, RestTimeout = session.connectionTimeout, @@ -211,9 +211,9 @@ private IdpTokenRestRequest BuildIdpTokenRestRequest(Uri tokenUrl) }; } - private SAMLRestRequest BuildSAMLRestRequest(Uri ssoUrl, string onetimeToken) + private SamlRestRequest BuildSamlRestRequest(Uri ssoUrl, string onetimeToken) { - return new SAMLRestRequest() + return new SamlRestRequest() { Url = ssoUrl, RestTimeout = session.connectionTimeout, @@ -225,7 +225,7 @@ private SAMLRestRequest BuildSAMLRestRequest(Uri ssoUrl, string onetimeToken) /// protected override void SetSpecializedAuthenticatorData(ref LoginRequestData data) { - data.RawSamlResponse = samlRawHtmlString; + data.RawSamlResponse = _rawSamlTokenHtmlString; } private void VerifyUrls(Uri tokenOrSsoUrl, Uri sessionUrl) @@ -233,28 +233,27 @@ private void VerifyUrls(Uri tokenOrSsoUrl, Uri sessionUrl) if (tokenOrSsoUrl.Scheme != sessionUrl.Scheme || tokenOrSsoUrl.Host != sessionUrl.Host) { var e = new SnowflakeDbException( - SFError.IDP_SSO_TOKEN_URL_MISMATCH, tokenOrSsoUrl.ToString(), oktaUrl.ToString()); - logger.Error("Different urls", e); + SFError.IDP_SSO_TOKEN_URL_MISMATCH, tokenOrSsoUrl.ToString(), _oktaUrl.ToString()); + s_logger.Error("Different urls", e); throw e; } } private void VerifyPostbackUrl() { - int formIndex = samlRawHtmlString.IndexOf("().errorCode; } @@ -313,21 +301,21 @@ private void ThrowRetryLimitException(int retryCount, int timeoutElapsed, Except } errorMessage += " while trying to authenticate through Okta"; - logger.Error(errorMessage); + s_logger.Error(errorMessage); throw new SnowflakeDbException(lastRetryException, SFError.INTERNAL_ERROR, errorMessage); } } internal class IdpTokenRestRequest : BaseRestRequest, IRestRequest { - private static MediaTypeWithQualityHeaderValue jsonHeader = new MediaTypeWithQualityHeaderValue("application/json"); + private static readonly MediaTypeWithQualityHeaderValue s_jsonHeader = new MediaTypeWithQualityHeaderValue("application/json"); internal IdpTokenRequest JsonBody { get; set; } HttpRequestMessage IRestRequest.ToRequestMessage(HttpMethod method) { HttpRequestMessage message = newMessage(method, Url); - message.Headers.Accept.Add(jsonHeader); + message.Headers.Accept.Add(s_jsonHeader); var json = JsonConvert.SerializeObject(JsonBody, JsonUtils.JsonSettings); message.Content = new StringContent(json, Encoding.UTF8, "application/json"); @@ -353,7 +341,7 @@ class IdpTokenResponse internal String SessionToken { get; set; } } - class SAMLRestRequest : BaseRestRequest, IRestRequest + class SamlRestRequest : BaseRestRequest, IRestRequest { internal string OnetimeToken { set; get; }