diff --git a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs index d1660b41f..66249da3a 100644 --- a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs +++ b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs @@ -3,7 +3,6 @@ */ using System; -using System.Net; using System.Security; using System.Threading; using System.Threading.Tasks; @@ -12,6 +11,7 @@ using Snowflake.Data.Core.Session; using Moq; using Snowflake.Data.Client; +using Snowflake.Data.Core.Tools; using Snowflake.Data.Tests.Util; namespace Snowflake.Data.Tests.UnitTests @@ -22,10 +22,8 @@ class ConnectionPoolManagerTest private readonly ConnectionPoolManager _connectionPoolManager = new ConnectionPoolManager(); private const string ConnectionString1 = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=1;"; private const string ConnectionString2 = "db=D2;warehouse=W2;account=A2;user=U2;password=P2;role=R2;minPoolSize=1;"; - private const string ConnectionString3 = "db=D3;warehouse=W3;account=A3;user=U3;role=R3;minPoolSize=1;"; - private readonly SecureString _password1 = null; - private readonly SecureString _password2 = null; - private readonly SecureString _password3 = new NetworkCredential("", "P3").SecurePassword; + private const string ConnectionStringWithoutPassword = "db=D3;warehouse=W3;account=A3;user=U3;role=R3;minPoolSize=1;"; + private readonly SecureString _password3 = SecureStringHelper.Encode("P3"); private static PoolConfig s_poolConfig; [OneTimeSetUp] @@ -53,21 +51,21 @@ public void BeforeEach() public void TestPoolManagerReturnsSessionPoolForGivenConnectionString() { // Act - var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, _password1); + var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, null); // Assert Assert.AreEqual(ConnectionString1, sessionPool.ConnectionString); - Assert.AreEqual(_password1, sessionPool.Password); + Assert.AreEqual(null, sessionPool.Password); } [Test] public void TestPoolManagerReturnsSessionPoolForGivenConnectionStringAndSecurelyProvidedPassword() { // Act - var sessionPool = _connectionPoolManager.GetPool(ConnectionString3, _password3); + var sessionPool = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, _password3); // Assert - Assert.AreEqual(ConnectionString3, sessionPool.ConnectionString); + Assert.AreEqual(ConnectionStringWithoutPassword, sessionPool.ConnectionString); Assert.AreEqual(_password3, sessionPool.Password); } @@ -75,7 +73,7 @@ public void TestPoolManagerReturnsSessionPoolForGivenConnectionStringAndSecurely public void TestPoolManagerThrowsWhenPasswordNotProvided() { // Act/Assert - Assert.Throws(() => _connectionPoolManager.GetPool(ConnectionString3, null)); + Assert.Throws(() => _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, null)); } [Test] @@ -85,8 +83,8 @@ public void TestPoolManagerReturnsSamePoolForGivenConnectionString() var anotherConnectionString = ConnectionString1; // Act - var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1); - var sessionPool2 = _connectionPoolManager.GetPool(anotherConnectionString, _password1); + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null); + var sessionPool2 = _connectionPoolManager.GetPool(anotherConnectionString, null); // Assert Assert.AreEqual(sessionPool1, sessionPool2); @@ -99,8 +97,8 @@ public void TestDifferentPoolsAreReturnedForDifferentConnectionStrings() Assert.AreNotSame(ConnectionString1, ConnectionString2); // Act - var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1); - var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2); + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null); // Assert Assert.AreNotSame(sessionPool1, sessionPool2); @@ -113,32 +111,32 @@ public void TestDifferentPoolsAreReturnedForDifferentConnectionStrings() public void TestGetSessionWorksForSpecifiedConnectionString() { // Act - var sfSession = _connectionPoolManager.GetSession(ConnectionString1, _password1); + var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null); // Assert Assert.AreEqual(ConnectionString1, sfSession.ConnectionString); - Assert.AreEqual(_password1, sfSession.Password); + Assert.AreEqual(null, sfSession.Password); } [Test] public async Task TestGetSessionAsyncWorksForSpecifiedConnectionString() { // Act - var sfSession = await _connectionPoolManager.GetSessionAsync(ConnectionString1, _password1, CancellationToken.None); + var sfSession = await _connectionPoolManager.GetSessionAsync(ConnectionString1, null, CancellationToken.None); // Assert Assert.AreEqual(ConnectionString1, sfSession.ConnectionString); - Assert.AreEqual(_password1, sfSession.Password); + Assert.AreEqual(null, sfSession.Password); } [Test] public void TestCountingOfSessionProvidedByPool() { // Act - _connectionPoolManager.GetSession(ConnectionString1, _password1); + _connectionPoolManager.GetSession(ConnectionString1, null); // Assert - var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, _password1); + var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, null); Assert.AreEqual(1, sessionPool.GetCurrentPoolSize()); } @@ -146,13 +144,13 @@ public void TestCountingOfSessionProvidedByPool() public void TestCountingOfSessionReturnedBackToPool() { // Arrange - var sfSession = _connectionPoolManager.GetSession(ConnectionString1, _password1); + var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null); // Act _connectionPoolManager.AddSession(sfSession); // Assert - var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, _password1); + var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, null); Assert.AreEqual(1, sessionPool.GetCurrentPoolSize()); } @@ -160,7 +158,7 @@ public void TestCountingOfSessionReturnedBackToPool() public void TestSetMaxPoolSizeForAllPoolsDisabled() { // Arrange - _connectionPoolManager.GetPool(ConnectionString1, _password1); + _connectionPoolManager.GetPool(ConnectionString1, null); // Act var thrown = Assert.Throws(() => _connectionPoolManager.SetMaxPoolSize(3)); @@ -173,7 +171,7 @@ public void TestSetMaxPoolSizeForAllPoolsDisabled() public void TestSetTimeoutForAllPoolsDisabled() { // Arrange - _connectionPoolManager.GetPool(ConnectionString1, _password1); + _connectionPoolManager.GetPool(ConnectionString1, null); // Act var thrown = Assert.Throws(() => _connectionPoolManager.SetTimeout(3000)); @@ -186,7 +184,7 @@ public void TestSetTimeoutForAllPoolsDisabled() public void TestSetPoolingForAllPoolsDisabled() { // Arrange - _connectionPoolManager.GetPool(ConnectionString1, _password1); + _connectionPoolManager.GetPool(ConnectionString1, null); // Act var thrown = Assert.Throws(() => _connectionPoolManager.SetPooling(false)); @@ -199,8 +197,8 @@ public void TestSetPoolingForAllPoolsDisabled() public void TestGetPoolingOnManagerLevelAlwaysTrue() { // Arrange - var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1); - var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2); + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null); sessionPool1.SetPooling(true); sessionPool2.SetPooling(false); @@ -217,8 +215,8 @@ public void TestGetPoolingOnManagerLevelAlwaysTrue() public void TestGetTimeoutOnManagerLevelWhenNotAllPoolsEqual() { // Arrange - var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1); - var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2); + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null); sessionPool1.SetTimeout(299); sessionPool2.SetTimeout(1313); @@ -233,8 +231,8 @@ public void TestGetTimeoutOnManagerLevelWhenNotAllPoolsEqual() public void TestGetTimeoutOnManagerLevelWhenAllPoolsEqual() { // Arrange - var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1); - var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2); + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null); sessionPool1.SetTimeout(3600); sessionPool2.SetTimeout(3600); @@ -246,8 +244,8 @@ public void TestGetTimeoutOnManagerLevelWhenAllPoolsEqual() public void TestGetMaxPoolSizeOnManagerLevelWhenNotAllPoolsEqual() { // Arrange - var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1); - var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2); + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null); sessionPool1.SetMaxPoolSize(1); sessionPool2.SetMaxPoolSize(17); @@ -262,8 +260,8 @@ public void TestGetMaxPoolSizeOnManagerLevelWhenNotAllPoolsEqual() public void TestGetMaxPoolSizeOnManagerLevelWhenAllPoolsEqual() { // Arrange - var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1); - var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2); + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null); sessionPool1.SetMaxPoolSize(33); sessionPool2.SetMaxPoolSize(33); @@ -275,8 +273,8 @@ public void TestGetMaxPoolSizeOnManagerLevelWhenAllPoolsEqual() public void TestGetCurrentPoolSizeReturnsSumOfPoolSizes() { // Arrange - EnsurePoolSize(ConnectionString1, _password1, 2); - EnsurePoolSize(ConnectionString2, _password2, 3); + EnsurePoolSize(ConnectionString1, null, 2); + EnsurePoolSize(ConnectionString2, null, 3); // act var poolSize = _connectionPoolManager.GetCurrentPoolSize(); @@ -285,6 +283,71 @@ public void TestGetCurrentPoolSizeReturnsSumOfPoolSizes() Assert.AreEqual(5, poolSize); } + [Test] + public void TestReturnPoolForSecurePassword() + { + // arrange + const string AnotherPassword = "anotherPassword"; + EnsurePoolSize(ConnectionStringWithoutPassword, _password3, 1); + + // act + var pool = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, SecureStringHelper.Encode(AnotherPassword)); // a new pool has been created because the password is different + + // assert + Assert.AreEqual(0, pool.GetCurrentPoolSize()); + Assert.AreEqual(AnotherPassword, SecureStringHelper.Decode(pool.Password)); + } + + [Test] + public void TestReturnDifferentPoolWhenPasswordProvidedInDifferentWay() + { + // arrange + var connectionStringWithPassword = $"{ConnectionStringWithoutPassword}password={SecureStringHelper.Decode(_password3)}"; + EnsurePoolSize(ConnectionStringWithoutPassword, _password3, 2); + EnsurePoolSize(connectionStringWithPassword, null, 5); + EnsurePoolSize(connectionStringWithPassword, _password3, 8); + + // act + var pool1 = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, _password3); + var pool2 = _connectionPoolManager.GetPool(connectionStringWithPassword, null); + var pool3 = _connectionPoolManager.GetPool(connectionStringWithPassword, _password3); + + // assert + Assert.AreEqual(2, pool1.GetCurrentPoolSize()); + Assert.AreEqual(5, pool2.GetCurrentPoolSize()); + Assert.AreEqual(8, pool3.GetCurrentPoolSize()); + } + + [Test] + [TestCase(null)] + [TestCase("")] + public void TestGetPoolFailsWhenNoPasswordProvided(string password) + { + // arrange + var securePassword = password == null ? null : SecureStringHelper.Encode(password); + + // act + var thrown = Assert.Throws(() => _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, securePassword)); + + // assert + Assert.That(thrown.Message, Does.Contain("Required property PASSWORD is not provided")); + } + + [Test] + public void TestPoolDoesNotSerializePassword() + { + // arrange + var password = SecureStringHelper.Decode(_password3); + var connectionStringWithPassword = $"{ConnectionStringWithoutPassword}password={password}"; + var pool = _connectionPoolManager.GetPool(connectionStringWithPassword, _password3); + + // act + var serializedPool = pool.ToString(); + + // assert + Assert.IsFalse(serializedPool.Contains(password, StringComparison.OrdinalIgnoreCase)); + } + private void EnsurePoolSize(string connectionString, SecureString password, int requiredCurrentSize) { var sessionPool = _connectionPoolManager.GetPool(connectionString, password); diff --git a/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs b/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs index 95b4e596e..fca8f7de1 100644 --- a/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs +++ b/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs @@ -1,9 +1,11 @@ +using System; using System.Net; using System.Text.RegularExpressions; using NUnit.Framework; using Snowflake.Data.Client; using Snowflake.Data.Core; using Snowflake.Data.Core.Session; +using Snowflake.Data.Core.Tools; using Snowflake.Data.Tests.Util; namespace Snowflake.Data.Tests.UnitTests.Session @@ -130,5 +132,39 @@ public void TestPoolIdentificationForOldPool() // assert Assert.AreEqual("", poolIdentification); } + + [Test] + [TestCase(null)] + [TestCase("")] + [TestCase("anyPassword")] + public void TestValidateValidSecurePassword(string password) + { + // arrange + var securePassword = password == null ? null : SecureStringHelper.Encode(password); + var pool = SessionPool.CreateSessionPool(ConnectionString, securePassword); + + // act + Assert.DoesNotThrow(() => pool.ValidateSecurePassword(securePassword)); + } + + [Test] + [TestCase("somePassword", null)] + [TestCase("somePassword", "")] + [TestCase("somePassword", "anotherPassword")] + [TestCase("", "anotherPassword")] + [TestCase(null, "anotherPassword")] + public void TestFailToValidateNotMatchingSecurePassword(string poolPassword, string notMatchingPassword) + { + // arrange + var poolSecurePassword = poolPassword == null ? null : SecureStringHelper.Encode(poolPassword); + var notMatchingSecurePassword = notMatchingPassword == null ? null : SecureStringHelper.Encode(notMatchingPassword); + var pool = SessionPool.CreateSessionPool(ConnectionString, poolSecurePassword); + + // act + var thrown = Assert.Throws(() => pool.ValidateSecurePassword(notMatchingSecurePassword)); + + // assert + Assert.That(thrown.Message, Does.Contain("Could not get a pool because of password mismatch")); + } } } diff --git a/Snowflake.Data.Tests/UnitTests/Tools/SecureStringHelperTest.cs b/Snowflake.Data.Tests/UnitTests/Tools/SecureStringHelperTest.cs new file mode 100644 index 000000000..52b10ed17 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Tools/SecureStringHelperTest.cs @@ -0,0 +1,23 @@ +using NUnit.Framework; +using Snowflake.Data.Core.Tools; + +namespace Snowflake.Data.Tests.UnitTests.Tools +{ + [TestFixture] + public class SecureStringHelperTest + { + [Test] + public void TestConvertPassword() + { + // arrange + var passwordText = "testPassword"; + + // act + var securePassword = SecureStringHelper.Encode(passwordText); + var decodedPassword = SecureStringHelper.Decode(securePassword); + + // assert + Assert.AreEqual(passwordText, decodedPassword); + } + } +} diff --git a/Snowflake.Data/Core/Session/ConnectionPoolManager.cs b/Snowflake.Data/Core/Session/ConnectionPoolManager.cs index ea1b8ba3b..aa0271952 100644 --- a/Snowflake.Data/Core/Session/ConnectionPoolManager.cs +++ b/Snowflake.Data/Core/Session/ConnectionPoolManager.cs @@ -9,6 +9,7 @@ using System.Threading; using System.Threading.Tasks; using Snowflake.Data.Client; +using Snowflake.Data.Core.Tools; using Snowflake.Data.Log; namespace Snowflake.Data.Core.Session @@ -122,14 +123,21 @@ public bool GetPooling() public SessionPool GetPool(string connectionString, SecureString password) { s_logger.Debug($"ConnectionPoolManager::GetPool"); - var poolKey = GetPoolKey(connectionString); + var poolKey = GetPoolKey(connectionString, password); if (_pools.TryGetValue(poolKey, out var item)) + { + item.ValidateSecurePassword(password); return item; + } + lock (s_poolsLock) { if (_pools.TryGetValue(poolKey, out var poolCreatedWhileWaitingOnLock)) + { + poolCreatedWhileWaitingOnLock.ValidateSecurePassword(password); return poolCreatedWhileWaitingOnLock; + } s_logger.Info($"Creating new pool"); var pool = SessionPool.CreateSessionPool(connectionString, password); _pools.Add(poolKey, pool); @@ -143,9 +151,9 @@ public SessionPool GetPool(string connectionString) return GetPool(connectionString, null); } - private string GetPoolKey(string connectionString) - { - return connectionString; - } + private string GetPoolKey(string connectionString, SecureString password) => + password != null && password.Length > 0 + ? connectionString + ";password=" + SecureStringHelper.Decode(password) + ";" + : connectionString + ";password=;"; } } diff --git a/Snowflake.Data/Core/Session/SFSessionProperty.cs b/Snowflake.Data/Core/Session/SFSessionProperty.cs index 49a9a0e75..a29a239ba 100644 --- a/Snowflake.Data/Core/Session/SFSessionProperty.cs +++ b/Snowflake.Data/Core/Session/SFSessionProperty.cs @@ -13,6 +13,7 @@ using System.Linq; using System.Text; using System.Text.RegularExpressions; +using Snowflake.Data.Core.Tools; namespace Snowflake.Data.Core { @@ -249,7 +250,7 @@ internal static SFSessionProperties ParseConnectionString(string connectionStrin if (password != null && password.Length > 0) { - properties[SFSessionProperty.PASSWORD] = new NetworkCredential(string.Empty, password).Password; + properties[SFSessionProperty.PASSWORD] = SecureStringHelper.Decode(password); } ValidateAuthenticator(properties); diff --git a/Snowflake.Data/Core/Session/SessionPool.cs b/Snowflake.Data/Core/Session/SessionPool.cs index 5d38c2e6e..fc3f154d2 100644 --- a/Snowflake.Data/Core/Session/SessionPool.cs +++ b/Snowflake.Data/Core/Session/SessionPool.cs @@ -9,7 +9,6 @@ using System.Threading; using System.Threading.Tasks; using Snowflake.Data.Client; -using Snowflake.Data.Core.Authenticator; using Snowflake.Data.Core.Tools; using Snowflake.Data.Log; @@ -119,6 +118,17 @@ private static Tuple ExtractConfig(string connecti } } + internal void ValidateSecurePassword(SecureString password) + { + if (!ExtractPassword(Password).Equals(ExtractPassword(password))) + { + throw new Exception("Could not get a pool because of password mismatch"); + } + } + + private string ExtractPassword(SecureString password) => + password == null ? string.Empty : SecureStringHelper.Decode(password); + internal SFSession GetSession(string connStr, SecureString password) { s_logger.Debug("SessionPool::GetSession" + PoolIdentification()); diff --git a/Snowflake.Data/Core/Tools/SecureStringHelper.cs b/Snowflake.Data/Core/Tools/SecureStringHelper.cs new file mode 100644 index 000000000..5d7b685c1 --- /dev/null +++ b/Snowflake.Data/Core/Tools/SecureStringHelper.cs @@ -0,0 +1,12 @@ +using System.Net; +using System.Security; + +namespace Snowflake.Data.Core.Tools +{ + internal static class SecureStringHelper + { + public static string Decode(SecureString password) => new NetworkCredential(string.Empty, password).Password; + + public static SecureString Encode(string password) => new NetworkCredential(string.Empty, password).SecurePassword; + } +}