Skip to content

Commit

Permalink
SNOW-1490901 Passcode support for mfa authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-knozderko committed Jun 21, 2024
1 parent 393762a commit 434b03d
Show file tree
Hide file tree
Showing 37 changed files with 430 additions and 165 deletions.
21 changes: 21 additions & 0 deletions Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Data.Common;
using System.Net;
using Snowflake.Data.Core.Session;
using Snowflake.Data.Core.Tools;
using Snowflake.Data.Tests.Util;

namespace Snowflake.Data.Tests.IntegrationTests
Expand Down Expand Up @@ -2271,6 +2272,26 @@ public void TestUseMultiplePoolsConnectionPoolByDefault()
// assert
Assert.AreEqual(ConnectionPoolType.MultipleConnectionPool, poolVersion);
}

[Test]
[Ignore("Requires manual steps and environment with mfa authentication enrolled")] // to enroll to mfa authentication edit your user profile
public void TestMfaWithPasswordConnection()
{
// arrange
using (SnowflakeDbConnection conn = new SnowflakeDbConnection())
{
conn.Passcode = SecureStringHelper.Encode("123456");
// manual action: stop here in breakpoint to provide proper passcode by: conn.Passcode = SecureStringHelper.Encode("...");
conn.ConnectionString = ConnectionString + "minPoolSize=0;application=DuoTest";

// act
conn.Open();

// assert
Assert.AreEqual(ConnectionState.Open, conn.State);
// manual action: verify that you have received no push request for given connection
}
}
}
}

Expand Down
6 changes: 3 additions & 3 deletions Snowflake.Data.Tests/Mock/MockSnowflakeDbConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ public override Task OpenAsync(CancellationToken cancellationToken)
cancellationToken);

}

private void SetMockSession()
{
SfSession = new SFSession(ConnectionString, Password, _restRequester);
SfSession = new SFSession(ConnectionString, Password, Passcode, _restRequester);

_connectionTimeout = (int)SfSession.connectionTimeout.TotalSeconds;

Expand All @@ -92,7 +92,7 @@ private void OnSessionEstablished()
{
_connectionState = ConnectionState.Open;
}

protected override bool CanReuseSession(TransactionRollbackStatus transactionRollbackStatus)
{
return false;
Expand Down
50 changes: 25 additions & 25 deletions Snowflake.Data.Tests/UnitTests/ArrowResultSetTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public void BeforeTest()
// by default generate Int32 values from 1 to RowCount
PrepareTestCase(SFDataType.FIXED, 0, Enumerable.Range(1, RowCount).ToArray());
}

[Test]
public void TestResultFormatIsArrow()
{
Expand Down Expand Up @@ -139,7 +139,7 @@ public void TestGetValueReturnsNull()
var arrowResultSet = new ArrowResultSet(responseData, sfStatement, new CancellationToken());

arrowResultSet.Next();

Assert.AreEqual(true, arrowResultSet.IsDBNull(0));
Assert.AreEqual(DBNull.Value, arrowResultSet.GetValue(0));
}
Expand All @@ -151,7 +151,7 @@ public void TestGetDecimal()

TestGetNumber(testValues);
}

[Test]
public void TestGetNumber64()
{
Expand All @@ -164,7 +164,7 @@ public void TestGetNumber64()
public void TestGetNumber32()
{
var testValues = new int[] { 0, 100, -100, Int32.MaxValue, Int32.MinValue };

TestGetNumber(testValues);
}

Expand All @@ -175,7 +175,7 @@ public void TestGetNumber16()

TestGetNumber(testValues);
}

[Test]
public void TestGetNumber8()
{
Expand All @@ -199,7 +199,7 @@ private void TestGetNumber(IEnumerable testValues)
Assert.AreEqual(expectedValue, _arrowResultSet.GetDecimal(ColumnIndex));
Assert.AreEqual(expectedValue, _arrowResultSet.GetDouble(ColumnIndex));
Assert.AreEqual(expectedValue, _arrowResultSet.GetFloat(ColumnIndex));

if (expectedValue >= Int64.MinValue && expectedValue <= Int64.MaxValue)
{
// get integer value
Expand Down Expand Up @@ -229,7 +229,7 @@ public void TestGetBoolean()
var testValues = new bool[] { true, false };

PrepareTestCase(SFDataType.BOOLEAN, 0, testValues);

foreach (var testValue in testValues)
{
_arrowResultSet.Next();
Expand All @@ -244,15 +244,15 @@ public void TestGetReal()
var testValues = new double[] { 0, Double.MinValue, Double.MaxValue };

PrepareTestCase(SFDataType.REAL, 0, testValues);

foreach (var testValue in testValues)
{
_arrowResultSet.Next();
Assert.AreEqual(testValue, _arrowResultSet.GetValue(ColumnIndex));
Assert.AreEqual(testValue, _arrowResultSet.GetDouble(ColumnIndex));
}
}

[Test]
public void TestGetText()
{
Expand All @@ -263,15 +263,15 @@ public void TestGetText()
};

PrepareTestCase(SFDataType.TEXT, 0, testValues);

foreach (var testValue in testValues)
{
_arrowResultSet.Next();
Assert.AreEqual(testValue, _arrowResultSet.GetValue(ColumnIndex));
Assert.AreEqual(testValue, _arrowResultSet.GetString(ColumnIndex));
}
}

[Test]
public void TestGetTextWithOneChar()
{
Expand All @@ -289,14 +289,14 @@ public void TestGetTextWithOneChar()
#endif

PrepareTestCase(SFDataType.TEXT, 0, testValues);

foreach (var testValue in testValues)
{
_arrowResultSet.Next();
Assert.AreEqual(testValue, _arrowResultSet.GetChar(ColumnIndex));
}
}

[Test]
public void TestGetArray()
{
Expand All @@ -307,7 +307,7 @@ public void TestGetArray()
};

PrepareTestCase(SFDataType.ARRAY, 0, testValues);

foreach (var testValue in testValues)
{
_arrowResultSet.Next();
Expand All @@ -319,7 +319,7 @@ public void TestGetArray()
Assert.AreEqual(testValue.Length, str.Length);
}
}

[Test]
public void TestGetBinary()
{
Expand All @@ -341,7 +341,7 @@ public void TestGetBinary()
Assert.AreEqual(testValue[j], buffer[j], "position " + j);
}
}

[Test]
public void TestGetDate()
{
Expand All @@ -353,15 +353,15 @@ public void TestGetDate()
};

PrepareTestCase(SFDataType.DATE, 0, testValues);

foreach (var testValue in testValues)
{
_arrowResultSet.Next();
Assert.AreEqual(testValue, _arrowResultSet.GetValue(ColumnIndex));
Assert.AreEqual(testValue, _arrowResultSet.GetDateTime(ColumnIndex));
}
}

[Test]
public void TestGetTime()
{
Expand All @@ -383,7 +383,7 @@ public void TestGetTime()
Assert.AreEqual(testValue, _arrowResultSet.GetValue(ColumnIndex));
Assert.AreEqual(testValue, _arrowResultSet.GetDateTime(ColumnIndex));
}
}
}
}

[Test]
Expand Down Expand Up @@ -473,10 +473,10 @@ private QueryExecResponseData PrepareResponseData(RecordBatch recordBatch, SFDat
return new QueryExecResponseData
{
rowType = recordBatch.Schema.FieldsList
.Select(col =>
.Select(col =>
new ExecResponseRowType
{
name = col.Name,
name = col.Name,
type = sfType.ToString(),
scale = scale
}).ToList(),
Expand All @@ -491,7 +491,7 @@ private string ConvertToBase64String(RecordBatch recordBatch)
{
if (recordBatch == null)
return "";

using (var stream = new MemoryStream())
{
using (var writer = new ArrowStreamWriter(stream, recordBatch.Schema))
Expand All @@ -502,12 +502,12 @@ private string ConvertToBase64String(RecordBatch recordBatch)
return Convert.ToBase64String(stream.ToArray());
}
}

private SFStatement PrepareStatement()
{
SFSession session = new SFSession("user=user;password=password;account=account;", null);
SFSession session = new SFSession("user=user;password=password;account=account;", null, null);
return new SFStatement(session);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public void TestAuthPropertiesValid(string connectionString, string password)
var securePassword = string.IsNullOrEmpty(password) ? null : new NetworkCredential(string.Empty, password).SecurePassword;

// Act/Assert
Assert.DoesNotThrow(() => SFSessionProperties.ParseConnectionString(_necessaryNonAuthProperties + connectionString, securePassword));
Assert.DoesNotThrow(() => SFSessionProperties.ParseConnectionString(_necessaryNonAuthProperties + connectionString, securePassword, null));
}

[TestCase("authenticator=snowflake;", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property PASSWORD is not provided.")]
Expand All @@ -54,7 +54,7 @@ public void TestAuthPropertiesInvalid(string connectionString, string password,
var securePassword = string.IsNullOrEmpty(password) ? null : new NetworkCredential(string.Empty, password).SecurePassword;

// Act
var exception = Assert.Throws<SnowflakeDbException>(() => SFSessionProperties.ParseConnectionString(_necessaryNonAuthProperties + connectionString, securePassword));
var exception = Assert.Throws<SnowflakeDbException>(() => SFSessionProperties.ParseConnectionString(_necessaryNonAuthProperties + connectionString, securePassword, null));

// Assert
SnowflakeDbExceptionAssert.HasErrorCode(exception, expectedError);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ private QueryExecResponseData mockQueryRequestData()
private SFResultSet mockSFResultSet(QueryExecResponseData responseData, CancellationToken token)
{
string connectionString = "user=user;password=password;account=account;";
SFSession session = new SFSession(connectionString, null);
SFSession session = new SFSession(connectionString, null , null);
List<NameValueParameter> list = new List<NameValueParameter>
{
new NameValueParameter { name = "CLIENT_PREFETCH_THREADS", value = "3" }
Expand Down
28 changes: 14 additions & 14 deletions Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public void TestDifferentPoolsAreReturnedForDifferentConnectionStrings()
public void TestGetSessionWorksForSpecifiedConnectionString()
{
// Act
var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null);
var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null, null);

// Assert
Assert.AreEqual(ConnectionString1, sfSession.ConnectionString);
Expand All @@ -122,7 +122,7 @@ public void TestGetSessionWorksForSpecifiedConnectionString()
public async Task TestGetSessionAsyncWorksForSpecifiedConnectionString()
{
// Act
var sfSession = await _connectionPoolManager.GetSessionAsync(ConnectionString1, null, CancellationToken.None);
var sfSession = await _connectionPoolManager.GetSessionAsync(ConnectionString1, null, null, CancellationToken.None);

// Assert
Assert.AreEqual(ConnectionString1, sfSession.ConnectionString);
Expand All @@ -133,7 +133,7 @@ public async Task TestGetSessionAsyncWorksForSpecifiedConnectionString()
public void TestCountingOfSessionProvidedByPool()
{
// Act
_connectionPoolManager.GetSession(ConnectionString1, null);
_connectionPoolManager.GetSession(ConnectionString1, null, null);

// Assert
var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, null);
Expand All @@ -144,7 +144,7 @@ public void TestCountingOfSessionProvidedByPool()
public void TestCountingOfSessionReturnedBackToPool()
{
// Arrange
var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null);
var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null, null);

// Act
_connectionPoolManager.AddSession(sfSession);
Expand Down Expand Up @@ -285,8 +285,8 @@ public void TestGetMaxPoolSizeOnManagerLevelWhenAllPoolsEqual()
public void TestGetCurrentPoolSizeReturnsSumOfPoolSizes()
{
// Arrange
EnsurePoolSize(ConnectionString1, null, 2);
EnsurePoolSize(ConnectionString2, null, 3);
EnsurePoolSize(ConnectionString1, null, null,2);
EnsurePoolSize(ConnectionString2, null, null, 3);

// act
var poolSize = _connectionPoolManager.GetCurrentPoolSize();
Expand All @@ -300,7 +300,7 @@ public void TestReturnPoolForSecurePassword()
{
// arrange
const string AnotherPassword = "anotherPassword";
EnsurePoolSize(ConnectionStringWithoutPassword, _password3, 1);
EnsurePoolSize(ConnectionStringWithoutPassword, _password3, null, 1);

// act
var pool = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, SecureStringHelper.Encode(AnotherPassword)); // a new pool has been created because the password is different
Expand All @@ -315,9 +315,9 @@ public void TestReturnDifferentPoolWhenPasswordProvidedInDifferentWay()
{
// arrange
var connectionStringWithPassword = $"{ConnectionStringWithoutPassword}password={SecureStringHelper.Decode(_password3)}";
EnsurePoolSize(ConnectionStringWithoutPassword, _password3, 2);
EnsurePoolSize(connectionStringWithPassword, null, 5);
EnsurePoolSize(connectionStringWithPassword, _password3, 8);
EnsurePoolSize(ConnectionStringWithoutPassword, _password3, null, 2);
EnsurePoolSize(connectionStringWithPassword, null, null, 5);
EnsurePoolSize(connectionStringWithPassword, _password3, null, 8);

// act
var pool1 = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, _password3);
Expand Down Expand Up @@ -360,23 +360,23 @@ public void TestPoolDoesNotSerializePassword()
Assert.IsFalse(serializedPool.Contains(password));
}

private void EnsurePoolSize(string connectionString, SecureString password, int requiredCurrentSize)
private void EnsurePoolSize(string connectionString, SecureString password, SecureString passcode, int requiredCurrentSize)
{
var sessionPool = _connectionPoolManager.GetPool(connectionString, password);
sessionPool.SetMaxPoolSize(requiredCurrentSize);
for (var i = 0; i < requiredCurrentSize; i++)
{
_connectionPoolManager.GetSession(connectionString, password);
_connectionPoolManager.GetSession(connectionString, password, passcode);
}
Assert.AreEqual(requiredCurrentSize, sessionPool.GetCurrentPoolSize());
}
}

class MockSessionFactory : ISessionFactory
{
public SFSession NewSession(string connectionString, SecureString password)
public SFSession NewSession(string connectionString, SecureString password, SecureString passcode)
{
var mockSfSession = new Mock<SFSession>(connectionString, password);
var mockSfSession = new Mock<SFSession>(connectionString, password, passcode);
mockSfSession.Setup(x => x.Open()).Verifiable();
mockSfSession.Setup(x => x.OpenAsync(default)).Returns(Task.FromResult(this));
mockSfSession.Setup(x => x.IsNotOpen()).Returns(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class SFAuthenticatorFactoryTest
private IAuthenticator GetAuthenticator(string authenticatorName, string extraParams = "")
{
string connectionString = $"account=test;user=test;password=test;authenticator={authenticatorName};{extraParams}";
SFSession session = new SFSession(connectionString, null);
SFSession session = new SFSession(connectionString, null, null);

return AuthenticatorFactory.GetAuthenticator(session);
}
Expand Down
Loading

0 comments on commit 434b03d

Please sign in to comment.