Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-916949: Fix Okta retry for SSO/SAML endpoints #865

Merged
merged 8 commits into from
Feb 12, 2024
89 changes: 89 additions & 0 deletions Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace Snowflake.Data.Tests.IntegrationTests
using System.Diagnostics;
using Snowflake.Data.Tests.Mock;
using System.Runtime.InteropServices;
using System.Net.Http;

[TestFixture]
class SFConnectionIT : SFBaseTest
Expand Down Expand Up @@ -817,6 +818,46 @@ public void TestOktaConnection()
}
}

[Test]
public void TestOktaConnectionUntilMaxTimeout()
{
var expectedMaxRetryCount = 15;
var expectedMaxConnectionTimeout = 450;
var oktaUrl = "https://test.okta.com";
var mockRestRequester = new MockOktaRestRequester()
{
TokenUrl = $"{oktaUrl}/api/v1/sessions?additionalFields=cookieToken",
SSOUrl = $"{oktaUrl}/app/testaccount/sso/saml",
ResponseContent = new StringContent("<form=error}"),
MaxRetryCount = expectedMaxRetryCount,
MaxRetryTimeout = expectedMaxConnectionTimeout
};
using (DbConnection conn = new MockSnowflakeDbConnection(mockRestRequester))
{
try
{
conn.ConnectionString
= ConnectionStringWithoutAuth
+ String.Format(
";authenticator={0};user=test;password=test;MAXHTTPRETRIES={1};RETRY_TIMEOUT={2};",
oktaUrl,
expectedMaxRetryCount,
expectedMaxConnectionTimeout);
conn.Open();
Assert.Fail();
}
catch (Exception e)
{
Assert.IsInstanceOf<SnowflakeDbException>(e);
Assert.AreEqual(SFError.INTERNAL_ERROR.GetAttribute<SFErrorAttr>().errorCode, ((SnowflakeDbException)e).ErrorCode);
Assert.IsTrue(e.Message.Contains(
$"The retry count has reached its limit of {expectedMaxRetryCount} and " +
$"the timeout elapsed has reached its limit of {expectedMaxConnectionTimeout} " +
"while trying to authenticate through Okta"));
}
}
}

[Test]
[Ignore("This test requires manual setup and therefore cannot be run in CI")]
public void TestOkta2ConnectionsFollowingEachOther()
Expand Down Expand Up @@ -2056,6 +2097,54 @@ public void TestExplicitTransactionOperationsTracked()
Assert.AreEqual(false, conn.HasActiveExplicitTransaction());
}
}


[Test]
public void TestAsyncOktaConnectionUntilMaxTimeout()
{
var expectedMaxRetryCount = 15;
var expectedMaxConnectionTimeout = 450;
var oktaUrl = "https://test.okta.com";
var mockRestRequester = new MockOktaRestRequester()
{
TokenUrl = $"{oktaUrl}/api/v1/sessions?additionalFields=cookieToken",
SSOUrl = $"{oktaUrl}/app/testaccount/sso/saml",
ResponseContent = new StringContent("<form=error}"),
MaxRetryCount = expectedMaxRetryCount,
MaxRetryTimeout = expectedMaxConnectionTimeout
};
using (DbConnection conn = new MockSnowflakeDbConnection(mockRestRequester))
{
try
{
conn.ConnectionString
= ConnectionStringWithoutAuth
+ String.Format(
";authenticator={0};user=test;password=test;MAXHTTPRETRIES={1};RETRY_TIMEOUT={2};",
oktaUrl,
expectedMaxRetryCount,
expectedMaxConnectionTimeout);
Task connectTask = conn.OpenAsync(CancellationToken.None);
connectTask.Wait();
Assert.Fail();
}
catch (Exception e)
{
Assert.IsInstanceOf<SnowflakeDbException>(e.InnerException);
Assert.AreEqual(SFError.INTERNAL_ERROR.GetAttribute<SFErrorAttr>().errorCode, ((SnowflakeDbException)e.InnerException).ErrorCode);
Exception oktaException;
#if NETFRAMEWORK
oktaException = e.InnerException.InnerException.InnerException;
#else
oktaException = e.InnerException.InnerException;
#endif
Assert.IsTrue(oktaException.Message.Contains(
$"The retry count has reached its limit of {expectedMaxRetryCount} and " +
$"the timeout elapsed has reached its limit of {expectedMaxConnectionTimeout} " +
"while trying to authenticate through Okta"));
}
}
}
}
}

Expand Down
49 changes: 40 additions & 9 deletions Snowflake.Data.Tests/Mock/MockOkta.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class MockOktaRestRequester : IMockRestRequester
public string TokenUrl { get; set; }
public string SSOUrl { get; set; }
public StringContent ResponseContent { get; set; }
public int MaxRetryCount { get; set; }
public int MaxRetryTimeout { get; set; }

public T Get<T>(IRestRequest request)
{
Expand All @@ -31,6 +33,8 @@ public Task<HttpResponseMessage> GetAsync(IRestRequest request, CancellationToke
{
var response = new HttpResponseMessage(System.Net.HttpStatusCode.OK);
response.Content = ResponseContent;
response.Content.Headers.Add(OktaAuthenticator.RetryCountHeader, MaxRetryCount.ToString());
response.Content.Headers.Add(OktaAuthenticator.TimeoutElapsedHeader, MaxRetryTimeout.ToString());
return Task.FromResult(response);
}

Expand All @@ -43,18 +47,45 @@ public Task<T> PostAsync<T>(IRestRequest postRequest, CancellationToken cancella
{
if (postRequest is SFRestRequest)
{
// authenticator
var authnResponse = new AuthenticatorResponse
if (((SFRestRequest)postRequest).jsonBody is AuthenticatorRequest)
{
success = true,
data = new AuthenticatorResponseData
// authenticator
var authnResponse = new AuthenticatorResponse
{
tokenUrl = TokenUrl,
ssoUrl = SSOUrl,
}
};
success = true,
data = new AuthenticatorResponseData
{
tokenUrl = TokenUrl,
ssoUrl = SSOUrl,
}
};

return Task.FromResult<T>((T)(object)authnResponse);
}
else
{
// login
var loginResponse = new LoginResponse
{
success = true,
data = new LoginResponseData
{
sessionId = "",
token = "",
masterToken = "",
masterValidityInSeconds = 0,
authResponseSessionInfo = new SessionInfo
{
databaseName = "",
schemaName = "",
roleName = "",
warehouseName = "",
}
}
};

return Task.FromResult<T>((T)(object)authnResponse);
return Task.FromResult<T>((T)(object)loginResponse);
}
}
else
{
Expand Down
18 changes: 18 additions & 0 deletions Snowflake.Data.Tests/UnitTests/HttpUtilTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,24 @@ public void TestIsLoginUrl(string requestUrl, bool expectedIsLoginEndpoint)
Assert.AreEqual(expectedIsLoginEndpoint, isLoginEndpoint);
}

// Parameters: request url, expected value
[TestCase("https://dev.okta.com/sso/saml", true)]
[TestCase("https://test.snowflakecomputing.com/session/v1/login-request", false)]
[TestCase("https://test.snowflakecomputing.com/session/authenticator-request", false)]
[TestCase("https://test.snowflakecomputing.com/session/token-request", false)]
[Test]
public void TestIsOktaSSORequest(string requestUrl, bool expectedIsOktaSSORequest)
{
// given
var uri = new Uri(requestUrl);

// when
bool isOktaSSORequest = HttpUtil.IsOktaSSORequest(uri.Host, uri.AbsolutePath);

// then
Assert.AreEqual(expectedIsOktaSSORequest, isOktaSSORequest);
}

// Parameters: time in seconds
[TestCase(4)]
[TestCase(8)]
Expand Down
59 changes: 57 additions & 2 deletions Snowflake.Data.Tests/UnitTests/SFOktaTest.cs

Large diffs are not rendered by default.

Loading
Loading