Skip to content

Commit

Permalink
SNOW-1373257 Make secure password be a part of pool key
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-knozderko committed May 10, 2024
1 parent bcce8ef commit 2f6339a
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 44 deletions.
137 changes: 100 additions & 37 deletions Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
*/

using System;
using System.Net;
using System.Security;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -12,6 +11,7 @@
using Snowflake.Data.Core.Session;
using Moq;
using Snowflake.Data.Client;
using Snowflake.Data.Core.Tools;
using Snowflake.Data.Tests.Util;

namespace Snowflake.Data.Tests.UnitTests
Expand All @@ -22,10 +22,8 @@ class ConnectionPoolManagerTest
private readonly ConnectionPoolManager _connectionPoolManager = new ConnectionPoolManager();
private const string ConnectionString1 = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=1;";
private const string ConnectionString2 = "db=D2;warehouse=W2;account=A2;user=U2;password=P2;role=R2;minPoolSize=1;";
private const string ConnectionString3 = "db=D3;warehouse=W3;account=A3;user=U3;role=R3;minPoolSize=1;";
private readonly SecureString _password1 = null;
private readonly SecureString _password2 = null;
private readonly SecureString _password3 = new NetworkCredential("", "P3").SecurePassword;
private const string ConnectionStringWithoutPassword = "db=D3;warehouse=W3;account=A3;user=U3;role=R3;minPoolSize=1;";
private readonly SecureString _password3 = SecureStringHelper.Encode("P3");
private static PoolConfig s_poolConfig;

[OneTimeSetUp]
Expand Down Expand Up @@ -53,29 +51,29 @@ public void BeforeEach()
public void TestPoolManagerReturnsSessionPoolForGivenConnectionString()
{
// Act
var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, _password1);
var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, null);

// Assert
Assert.AreEqual(ConnectionString1, sessionPool.ConnectionString);
Assert.AreEqual(_password1, sessionPool.Password);
Assert.AreEqual(null, sessionPool.Password);
}

[Test]
public void TestPoolManagerReturnsSessionPoolForGivenConnectionStringAndSecurelyProvidedPassword()
{
// Act
var sessionPool = _connectionPoolManager.GetPool(ConnectionString3, _password3);
var sessionPool = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, _password3);

// Assert
Assert.AreEqual(ConnectionString3, sessionPool.ConnectionString);
Assert.AreEqual(ConnectionStringWithoutPassword, sessionPool.ConnectionString);
Assert.AreEqual(_password3, sessionPool.Password);
}

[Test]
public void TestPoolManagerThrowsWhenPasswordNotProvided()
{
// Act/Assert
Assert.Throws<SnowflakeDbException>(() => _connectionPoolManager.GetPool(ConnectionString3, null));
Assert.Throws<SnowflakeDbException>(() => _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, null));
}

[Test]
Expand All @@ -85,8 +83,8 @@ public void TestPoolManagerReturnsSamePoolForGivenConnectionString()
var anotherConnectionString = ConnectionString1;

// Act
var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1);
var sessionPool2 = _connectionPoolManager.GetPool(anotherConnectionString, _password1);
var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null);
var sessionPool2 = _connectionPoolManager.GetPool(anotherConnectionString, null);

// Assert
Assert.AreEqual(sessionPool1, sessionPool2);
Expand All @@ -99,8 +97,8 @@ public void TestDifferentPoolsAreReturnedForDifferentConnectionStrings()
Assert.AreNotSame(ConnectionString1, ConnectionString2);

// Act
var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1);
var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2);
var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null);
var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null);

// Assert
Assert.AreNotSame(sessionPool1, sessionPool2);
Expand All @@ -113,54 +111,54 @@ public void TestDifferentPoolsAreReturnedForDifferentConnectionStrings()
public void TestGetSessionWorksForSpecifiedConnectionString()
{
// Act
var sfSession = _connectionPoolManager.GetSession(ConnectionString1, _password1);
var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null);

// Assert
Assert.AreEqual(ConnectionString1, sfSession.ConnectionString);
Assert.AreEqual(_password1, sfSession.Password);
Assert.AreEqual(null, sfSession.Password);
}

[Test]
public async Task TestGetSessionAsyncWorksForSpecifiedConnectionString()
{
// Act
var sfSession = await _connectionPoolManager.GetSessionAsync(ConnectionString1, _password1, CancellationToken.None);
var sfSession = await _connectionPoolManager.GetSessionAsync(ConnectionString1, null, CancellationToken.None);

// Assert
Assert.AreEqual(ConnectionString1, sfSession.ConnectionString);
Assert.AreEqual(_password1, sfSession.Password);
Assert.AreEqual(null, sfSession.Password);
}

[Test]
public void TestCountingOfSessionProvidedByPool()
{
// Act
_connectionPoolManager.GetSession(ConnectionString1, _password1);
_connectionPoolManager.GetSession(ConnectionString1, null);

// Assert
var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, _password1);
var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, null);
Assert.AreEqual(1, sessionPool.GetCurrentPoolSize());
}

[Test]
public void TestCountingOfSessionReturnedBackToPool()
{
// Arrange
var sfSession = _connectionPoolManager.GetSession(ConnectionString1, _password1);
var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null);

// Act
_connectionPoolManager.AddSession(sfSession);

// Assert
var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, _password1);
var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, null);
Assert.AreEqual(1, sessionPool.GetCurrentPoolSize());
}

[Test]
public void TestSetMaxPoolSizeForAllPoolsDisabled()
{
// Arrange
_connectionPoolManager.GetPool(ConnectionString1, _password1);
_connectionPoolManager.GetPool(ConnectionString1, null);

// Act
var thrown = Assert.Throws<Exception>(() => _connectionPoolManager.SetMaxPoolSize(3));
Expand All @@ -173,7 +171,7 @@ public void TestSetMaxPoolSizeForAllPoolsDisabled()
public void TestSetTimeoutForAllPoolsDisabled()
{
// Arrange
_connectionPoolManager.GetPool(ConnectionString1, _password1);
_connectionPoolManager.GetPool(ConnectionString1, null);

// Act
var thrown = Assert.Throws<Exception>(() => _connectionPoolManager.SetTimeout(3000));
Expand All @@ -186,7 +184,7 @@ public void TestSetTimeoutForAllPoolsDisabled()
public void TestSetPoolingForAllPoolsDisabled()
{
// Arrange
_connectionPoolManager.GetPool(ConnectionString1, _password1);
_connectionPoolManager.GetPool(ConnectionString1, null);

// Act
var thrown = Assert.Throws<Exception>(() => _connectionPoolManager.SetPooling(false));
Expand All @@ -199,8 +197,8 @@ public void TestSetPoolingForAllPoolsDisabled()
public void TestGetPoolingOnManagerLevelAlwaysTrue()
{
// Arrange
var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1);
var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2);
var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null);
var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null);
sessionPool1.SetPooling(true);
sessionPool2.SetPooling(false);

Expand All @@ -217,8 +215,8 @@ public void TestGetPoolingOnManagerLevelAlwaysTrue()
public void TestGetTimeoutOnManagerLevelWhenNotAllPoolsEqual()
{
// Arrange
var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1);
var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2);
var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null);
var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null);
sessionPool1.SetTimeout(299);
sessionPool2.SetTimeout(1313);

Expand All @@ -233,8 +231,8 @@ public void TestGetTimeoutOnManagerLevelWhenNotAllPoolsEqual()
public void TestGetTimeoutOnManagerLevelWhenAllPoolsEqual()
{
// Arrange
var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1);
var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2);
var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null);
var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null);
sessionPool1.SetTimeout(3600);
sessionPool2.SetTimeout(3600);

Expand All @@ -246,8 +244,8 @@ public void TestGetTimeoutOnManagerLevelWhenAllPoolsEqual()
public void TestGetMaxPoolSizeOnManagerLevelWhenNotAllPoolsEqual()
{
// Arrange
var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1);
var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2);
var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null);
var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null);
sessionPool1.SetMaxPoolSize(1);
sessionPool2.SetMaxPoolSize(17);

Expand All @@ -262,8 +260,8 @@ public void TestGetMaxPoolSizeOnManagerLevelWhenNotAllPoolsEqual()
public void TestGetMaxPoolSizeOnManagerLevelWhenAllPoolsEqual()
{
// Arrange
var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1);
var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2);
var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null);
var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null);
sessionPool1.SetMaxPoolSize(33);
sessionPool2.SetMaxPoolSize(33);

Expand All @@ -275,8 +273,8 @@ public void TestGetMaxPoolSizeOnManagerLevelWhenAllPoolsEqual()
public void TestGetCurrentPoolSizeReturnsSumOfPoolSizes()
{
// Arrange
EnsurePoolSize(ConnectionString1, _password1, 2);
EnsurePoolSize(ConnectionString2, _password2, 3);
EnsurePoolSize(ConnectionString1, null, 2);
EnsurePoolSize(ConnectionString2, null, 3);

// act
var poolSize = _connectionPoolManager.GetCurrentPoolSize();
Expand All @@ -285,6 +283,71 @@ public void TestGetCurrentPoolSizeReturnsSumOfPoolSizes()
Assert.AreEqual(5, poolSize);
}

[Test]
public void TestReturnPoolForSecurePassword()
{
// arrange
const string AnotherPassword = "anotherPassword";
EnsurePoolSize(ConnectionStringWithoutPassword, _password3, 1);

// act
var pool = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, SecureStringHelper.Encode(AnotherPassword)); // a new pool has been created because the password is different

// assert
Assert.AreEqual(0, pool.GetCurrentPoolSize());
Assert.AreEqual(AnotherPassword, SecureStringHelper.Decode(pool.Password));
}

[Test]
public void TestReturnDifferentPoolWhenPasswordProvidedInDifferentWay()
{
// arrange
var connectionStringWithPassword = $"{ConnectionStringWithoutPassword}password={SecureStringHelper.Decode(_password3)}";
EnsurePoolSize(ConnectionStringWithoutPassword, _password3, 2);
EnsurePoolSize(connectionStringWithPassword, null, 5);
EnsurePoolSize(connectionStringWithPassword, _password3, 8);

// act
var pool1 = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, _password3);
var pool2 = _connectionPoolManager.GetPool(connectionStringWithPassword, null);
var pool3 = _connectionPoolManager.GetPool(connectionStringWithPassword, _password3);

// assert
Assert.AreEqual(2, pool1.GetCurrentPoolSize());
Assert.AreEqual(5, pool2.GetCurrentPoolSize());
Assert.AreEqual(8, pool3.GetCurrentPoolSize());
}

[Test]
[TestCase(null)]
[TestCase("")]
public void TestGetPoolFailsWhenNoPasswordProvided(string password)
{
// arrange
var securePassword = password == null ? null : SecureStringHelper.Encode(password);

// act
var thrown = Assert.Throws<SnowflakeDbException>(() => _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, securePassword));

// assert
Assert.That(thrown.Message, Does.Contain("Required property PASSWORD is not provided"));
}

[Test]
public void TestPoolDoesNotSerializePassword()
{
// arrange
var password = SecureStringHelper.Decode(_password3);
var connectionStringWithPassword = $"{ConnectionStringWithoutPassword}password={password}";
var pool = _connectionPoolManager.GetPool(connectionStringWithPassword, _password3);

// act
var serializedPool = pool.ToString();

// assert
Assert.IsFalse(serializedPool.Contains(password, StringComparison.OrdinalIgnoreCase));
}

private void EnsurePoolSize(string connectionString, SecureString password, int requiredCurrentSize)
{
var sessionPool = _connectionPoolManager.GetPool(connectionString, password);
Expand Down
36 changes: 36 additions & 0 deletions Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using System;
using System.Net;
using System.Text.RegularExpressions;
using NUnit.Framework;
using Snowflake.Data.Client;
using Snowflake.Data.Core;
using Snowflake.Data.Core.Session;
using Snowflake.Data.Core.Tools;
using Snowflake.Data.Tests.Util;

namespace Snowflake.Data.Tests.UnitTests.Session
Expand Down Expand Up @@ -130,5 +132,39 @@ public void TestPoolIdentificationForOldPool()
// assert
Assert.AreEqual("", poolIdentification);
}

[Test]
[TestCase(null)]
[TestCase("")]
[TestCase("anyPassword")]
public void TestValidateValidSecurePassword(string password)
{
// arrange
var securePassword = password == null ? null : SecureStringHelper.Encode(password);
var pool = SessionPool.CreateSessionPool(ConnectionString, securePassword);

// act
Assert.DoesNotThrow(() => pool.ValidateSecurePassword(securePassword));
}

[Test]
[TestCase("somePassword", null)]
[TestCase("somePassword", "")]
[TestCase("somePassword", "anotherPassword")]
[TestCase("", "anotherPassword")]
[TestCase(null, "anotherPassword")]
public void TestFailToValidateNotMatchingSecurePassword(string poolPassword, string notMatchingPassword)
{
// arrange
var poolSecurePassword = poolPassword == null ? null : SecureStringHelper.Encode(poolPassword);
var notMatchingSecurePassword = notMatchingPassword == null ? null : SecureStringHelper.Encode(notMatchingPassword);
var pool = SessionPool.CreateSessionPool(ConnectionString, poolSecurePassword);

// act
var thrown = Assert.Throws<Exception>(() => pool.ValidateSecurePassword(notMatchingSecurePassword));

// assert
Assert.That(thrown.Message, Does.Contain("Could not get a pool because of password mismatch"));
}
}
}
23 changes: 23 additions & 0 deletions Snowflake.Data.Tests/UnitTests/Tools/SecureStringHelperTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using NUnit.Framework;
using Snowflake.Data.Core.Tools;

namespace Snowflake.Data.Tests.UnitTests.Tools
{
[TestFixture]
public class SecureStringHelperTest
{
[Test]
public void TestConvertPassword()
{
// arrange
var passwordText = "testPassword";

// act
var securePassword = SecureStringHelper.Encode(passwordText);
var decodedPassword = SecureStringHelper.Decode(securePassword);

// assert
Assert.AreEqual(passwordText, decodedPassword);
}
}
}
Loading

0 comments on commit 2f6339a

Please sign in to comment.