Skip to content

Commit

Permalink
Increased code coverage for connection pool and review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-mhofman committed Oct 30, 2023
1 parent a7b622f commit 5d4c5af
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 72 deletions.
66 changes: 15 additions & 51 deletions Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolIT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,14 @@ public SFConnectionPoolIT(ConnectionPoolType connectionPoolTypeUnderTest)
{
_connectionPoolTypeUnderTest = connectionPoolTypeUnderTest;
s_previousPoolConfig = new PoolConfig();
SnowflakeDbConnectionPool.SetConnectionPoolVersion(connectionPoolTypeUnderTest);
}

[SetUp]
public new void BeforeTest()
{
SnowflakeDbConnectionPool.GetPool(ConnectionString); // to instantiate the pool used in tests
SnowflakeDbConnectionPool.GetPool(ConnectionString + " retryCount=1");
SnowflakeDbConnectionPool.GetPool(ConnectionString + " retryCount=2");
SnowflakeDbConnectionPool.SetPooling(true);
SnowflakeDbConnectionPool.SetConnectionPoolVersion(_connectionPoolTypeUnderTest);
SnowflakeDbConnectionPool.ClearAllPools();
SnowflakeDbConnectionPool.SetPooling(true);
s_logger.Debug($"---------------- BeforeTest ---------------------");
s_logger.Debug($"Testing Pool Type: {SnowflakeDbConnectionPool.GetConnectionPoolVersion()}");
}
Expand Down Expand Up @@ -85,10 +82,8 @@ static void ConcurrentPoolingHelper(string connectionString, bool closeConnectio
const int PoolTimeout = 3;

// reset to default settings in case it changed by other test cases
SnowflakeDbConnectionPool.GetPool(connectionString); // to instantiate pool
SnowflakeDbConnectionPool.SetPooling(true);
Assert.AreEqual(true, SnowflakeDbConnectionPool.GetPool(connectionString).GetPooling()); // to instantiate pool
SnowflakeDbConnectionPool.SetMaxPoolSize(10);
SnowflakeDbConnectionPool.ClearAllPools();
SnowflakeDbConnectionPool.SetTimeout(PoolTimeout);

var threads = new Task[ThreadNum];
Expand All @@ -100,8 +95,6 @@ static void ConcurrentPoolingHelper(string connectionString, bool closeConnectio
});
}
Task.WaitAll(threads);
// set pooling timeout back to default to avoid impact on other test cases
SnowflakeDbConnectionPool.SetTimeout(3600);
}

// thead to execute query with new connection in a loop
Expand Down Expand Up @@ -160,7 +153,6 @@ public void TestBasicConnectionPool()
[Test]
public void TestConnectionPool()
{
SnowflakeDbConnectionPool.ClearAllPools();
var conn1 = new SnowflakeDbConnection(ConnectionString);
conn1.Open();
Assert.AreEqual(ConnectionState.Open, conn1.State);
Expand All @@ -177,14 +169,11 @@ public void TestConnectionPool()
Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(ConnectionString).GetCurrentPoolSize());
Assert.AreEqual(ConnectionState.Closed, conn1.State);
Assert.AreEqual(ConnectionState.Closed, conn2.State);
SnowflakeDbConnectionPool.ClearAllPools();
}

[Test]
public void TestConnectionPoolIsFull()
{
SnowflakeDbConnectionPool.ClearAllPools();
SnowflakeDbConnectionPool.SetPooling(true);
var pool = SnowflakeDbConnectionPool.GetPool(ConnectionString);
SnowflakeDbConnectionPool.SetMaxPoolSize(2);
var conn1 = new SnowflakeDbConnection();
Expand Down Expand Up @@ -221,10 +210,8 @@ public void TestConnectionPoolIsFull()
[Test]
public void TestConnectionPoolExpirationWorks()
{
SnowflakeDbConnectionPool.ClearAllPools();
SnowflakeDbConnectionPool.SetMaxPoolSize(2);
SnowflakeDbConnectionPool.SetTimeout(10);
SnowflakeDbConnectionPool.SetPooling(true);

var conn1 = new SnowflakeDbConnection();
conn1.ConnectionString = ConnectionString;
Expand Down Expand Up @@ -254,7 +241,6 @@ public void TestConnectionPoolClean()
{
TestOnlyForOldPool();

SnowflakeDbConnectionPool.ClearAllPools();
SnowflakeDbConnectionPool.SetMaxPoolSize(2);
var conn1 = new SnowflakeDbConnection();
conn1.ConnectionString = ConnectionString;
Expand Down Expand Up @@ -282,15 +268,13 @@ public void TestConnectionPoolClean()
Assert.AreEqual(ConnectionState.Closed, conn1.State);
Assert.AreEqual(ConnectionState.Closed, conn2.State);
Assert.AreEqual(ConnectionState.Closed, conn3.State);
SnowflakeDbConnectionPool.ClearAllPools();
}

[Test]
public void TestNewConnectionPoolClean()
{
TestOnlyForNewPool();

SnowflakeDbConnectionPool.ClearAllPools();
SnowflakeDbConnectionPool.SetMaxPoolSize(2);
var conn1 = new SnowflakeDbConnection();
conn1.ConnectionString = ConnectionString;
Expand Down Expand Up @@ -320,17 +304,14 @@ public void TestNewConnectionPoolClean()
Assert.AreEqual(ConnectionState.Closed, conn1.State);
Assert.AreEqual(ConnectionState.Closed, conn2.State);
Assert.AreEqual(ConnectionState.Closed, conn3.State);
SnowflakeDbConnectionPool.ClearAllPools();
}

[Test]
public void TestConnectionPoolFull()
{
TestOnlyForOldPool();

SnowflakeDbConnectionPool.ClearAllPools();
SnowflakeDbConnectionPool.SetMaxPoolSize(2);
SnowflakeDbConnectionPool.SetPooling(true);

var conn1 = new SnowflakeDbConnection();
conn1.ConnectionString = ConnectionString;
Expand Down Expand Up @@ -372,10 +353,8 @@ public void TestNewConnectionPoolFull()
{
TestOnlyForNewPool();

SnowflakeDbConnectionPool.ClearAllPools();
SnowflakeDbConnectionPool.SetPooling(true);
var sessionPool = SnowflakeDbConnectionPool.GetPool(ConnectionString);
SnowflakeDbConnectionPool.SetMaxPoolSize(2);
var pool = SnowflakeDbConnectionPool.GetPool(ConnectionString);
pool.SetMaxPoolSize(2);

var conn1 = new SnowflakeDbConnection();
conn1.ConnectionString = ConnectionString;
Expand All @@ -387,10 +366,10 @@ public void TestNewConnectionPoolFull()
conn2.Open();
Assert.AreEqual(ConnectionState.Open, conn2.State);

Assert.AreEqual(0, sessionPool.GetCurrentPoolSize());
Assert.AreEqual(0, pool.GetCurrentPoolSize());
conn1.Close();
conn2.Close();
Assert.AreEqual(2, sessionPool.GetCurrentPoolSize());
Assert.AreEqual(2, pool.GetCurrentPoolSize());

var conn3 = new SnowflakeDbConnection();
conn3.ConnectionString = ConnectionString;
Expand All @@ -403,15 +382,14 @@ public void TestNewConnectionPoolFull()
Assert.AreEqual(ConnectionState.Open, conn4.State);

conn3.Close();
Assert.AreEqual(1, sessionPool.GetCurrentPoolSize()); // TODO: when SNOW-937189 complete should be 2
Assert.AreEqual(1, pool.GetCurrentPoolSize()); // TODO: when SNOW-937189 complete should be 2
conn4.Close();
Assert.AreEqual(2, sessionPool.GetCurrentPoolSize());
Assert.AreEqual(2, pool.GetCurrentPoolSize());

Assert.AreEqual(ConnectionState.Closed, conn1.State);
Assert.AreEqual(ConnectionState.Closed, conn2.State);
Assert.AreEqual(ConnectionState.Closed, conn3.State);
Assert.AreEqual(ConnectionState.Closed, conn4.State);
SnowflakeDbConnectionPool.ClearAllPools();
}

[Test]
Expand Down Expand Up @@ -457,9 +435,8 @@ void ThreadProcess2(string connstr)
[Test]
public void TestConnectionPoolDisable()
{
SnowflakeDbConnectionPool.ClearAllPools();
SnowflakeDbConnectionPool.GetPool(ConnectionString);
SnowflakeDbConnectionPool.SetPooling(false);
var pool = SnowflakeDbConnectionPool.GetPool(ConnectionString);
pool.SetPooling(false);

var conn1 = new SnowflakeDbConnection();
conn1.ConnectionString = ConnectionString;
Expand All @@ -468,27 +445,18 @@ public void TestConnectionPoolDisable()
conn1.Close();

Assert.AreEqual(ConnectionState.Closed, conn1.State);
Assert.AreEqual(0, SnowflakeDbConnectionPool.GetPool(ConnectionString).GetCurrentPoolSize());
Assert.AreEqual(0, pool.GetCurrentPoolSize());
}

[Test]
public void TestConnectionPoolWithDispose()
{
SnowflakeDbConnectionPool.SetPooling(true);
SnowflakeDbConnectionPool.SetMaxPoolSize(1);
SnowflakeDbConnectionPool.ClearAllPools();


var conn1 = new SnowflakeDbConnection();
conn1.ConnectionString = "bad connection string";
try
{
conn1.Open();
}
catch (SnowflakeDbException ex)
{
Console.WriteLine("connection failed:" + ex);
conn1.Close();
}
Assert.Throws<SnowflakeDbException>(() => conn1.Open());
conn1.Close();

Assert.AreEqual(ConnectionState.Closed, conn1.State);
Assert.AreEqual(0, SnowflakeDbConnectionPool.GetPool(conn1.ConnectionString).GetCurrentPoolSize());
Expand All @@ -500,7 +468,6 @@ public void TestConnectionPoolTurnOff()
SnowflakeDbConnectionPool.SetPooling(false);
SnowflakeDbConnectionPool.SetPooling(true);
SnowflakeDbConnectionPool.SetMaxPoolSize(1);
SnowflakeDbConnectionPool.ClearAllPools();

var conn1 = new SnowflakeDbConnection();
conn1.ConnectionString = ConnectionString;
Expand All @@ -510,9 +477,6 @@ public void TestConnectionPoolTurnOff()

Assert.AreEqual(ConnectionState.Closed, conn1.State);
Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(ConnectionString).GetCurrentPoolSize());

SnowflakeDbConnectionPool.SetPooling(false);
//Put a breakpoint at SFSession close function, after connection pool is off, it will send close session request.
}

private void TestOnlyForOldPool()
Expand Down
23 changes: 23 additions & 0 deletions Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerSwitchTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using NUnit.Framework;
using Snowflake.Data.Client;
using Snowflake.Data.Core.Session;

namespace Snowflake.Data.Tests.UnitTests
{
public class ConnectionPoolManagerSwitchTest
{
private readonly string _connectionString1 = "database=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;";
private readonly string _connectionString2 = "database=D2;warehouse=W2;account=A2;user=U2;password=P2;role=R2;";

[Test]
public void TestRevertPoolToPreviousVersion()
{
SnowflakeDbConnectionPool.SetOldConnectionPoolVersion();

var sessionPool1 = SnowflakeDbConnectionPool.GetPool(_connectionString1);
var sessionPool2 = SnowflakeDbConnectionPool.GetPool(_connectionString2);
Assert.AreEqual(ConnectionPoolType.SingleConnectionCache, SnowflakeDbConnectionPool.GetConnectionPoolVersion());
Assert.AreEqual(sessionPool1, sessionPool2);
}
}
}
54 changes: 49 additions & 5 deletions Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright (c) 2023 Snowflake Computing Inc. All rights reserved.
*/

using System;
using System.Collections.Generic;
using System.Security;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -116,7 +116,6 @@ public void TestCountingOfSessionProvidedByPool()
}

[Test]
[Ignore("Enable after completion of SNOW-937189")] // TODO:
public void TestCountingOfSessionReturnedBackToPool()
{
// Arrange
Expand Down Expand Up @@ -197,7 +196,10 @@ public void TestGetPoolingOnManagerLevelWhenNotAllPoolsEqual()
sessionPool2.SetPooling(false);

// Act/Assert
Assert.Throws<SnowflakeDbException>(() => _connectionPoolManager.GetPooling());
var exception = Assert.Throws<SnowflakeDbException>(() => _connectionPoolManager.GetPooling());
Assert.IsNotNull(exception);
Assert.AreEqual(SFError.INCONSISTENT_RESULT_ERROR.GetAttribute<SFErrorAttr>().errorCode, exception.ErrorCode);
Assert.IsTrue(exception.Message.Contains("Multiple pools have different Pooling values"));
}

[Test]
Expand All @@ -223,7 +225,10 @@ public void TestGetTimeoutOnManagerLevelWhenNotAllPoolsEqual()
sessionPool2.SetTimeout(1313);

// Act/Assert
Assert.Throws<SnowflakeDbException>(() => _connectionPoolManager.GetTimeout());
var exception = Assert.Throws<SnowflakeDbException>(() => _connectionPoolManager.GetTimeout());
Assert.IsNotNull(exception);
Assert.AreEqual(SFError.INCONSISTENT_RESULT_ERROR.GetAttribute<SFErrorAttr>().errorCode, exception.ErrorCode);
Assert.IsTrue(exception.Message.Contains("Multiple pools have different Timeout values"));
}

[Test]
Expand All @@ -249,7 +254,10 @@ public void TestGetMaxPoolSizeOnManagerLevelWhenNotAllPoolsEqual()
sessionPool2.SetMaxPoolSize(17);

// Act/Assert
Assert.Throws<SnowflakeDbException>(() => _connectionPoolManager.GetMaxPoolSize());
var exception = Assert.Throws<SnowflakeDbException>(() => _connectionPoolManager.GetMaxPoolSize());
Assert.IsNotNull(exception);
Assert.AreEqual(SFError.INCONSISTENT_RESULT_ERROR.GetAttribute<SFErrorAttr>().errorCode, exception.ErrorCode);
Assert.IsTrue(exception.Message.Contains("Multiple pools have different Max Pool Size values"));
}

[Test]
Expand All @@ -264,6 +272,40 @@ public void TestGetMaxPoolSizeOnManagerLevelWhenAllPoolsEqual()
// Act/Assert
Assert.AreEqual(33,_connectionPoolManager.GetMaxPoolSize());
}

[Test]
public void TestGetCurrentPoolSizeThrowsExceptionWhenNotAllPoolsEqual()
{
// Arrange
EnsurePoolSize(_connectionString1, 2);
EnsurePoolSize(_connectionString2, 3);

// Act/Assert
var exception = Assert.Throws<SnowflakeDbException>(() => _connectionPoolManager.GetCurrentPoolSize());
Assert.IsNotNull(exception);
Assert.AreEqual(SFError.INCONSISTENT_RESULT_ERROR.GetAttribute<SFErrorAttr>().errorCode, exception.ErrorCode);
Assert.IsTrue(exception.Message.Contains("Multiple pools have different Current Pool Size values"));
}

private void EnsurePoolSize(string connectionString, int requiredCurrentSize)
{
var sessionPool = _connectionPoolManager.GetPool(connectionString, _password);
sessionPool.SetMaxPoolSize(requiredCurrentSize);
var busySessions = new List<SFSession>();
for (var i = 0; i < requiredCurrentSize; i++)
{
var sfSession = _connectionPoolManager.GetSession(connectionString, _password);
busySessions.Add(sfSession);
}

foreach (var session in busySessions) // TODO: remove after SNOW-937189 since sessions will be already counted by GetCurrentPool size
{
session.close();
_connectionPoolManager.AddSession(session);
}

Assert.AreEqual(requiredCurrentSize, sessionPool.GetCurrentPoolSize());
}
}

class MockSessionFactory : ISessionFactory
Expand All @@ -273,6 +315,8 @@ public SFSession NewSession(string connectionString, SecureString password)
var mockSfSession = new Mock<SFSession>(connectionString, password);
mockSfSession.Setup(x => x.Open()).Verifiable();
mockSfSession.Setup(x => x.OpenAsync(default)).Returns(Task.FromResult(this));
mockSfSession.Setup(x => x.IsNotOpen()).Returns(false);
mockSfSession.Setup(x => x.IsExpired(It.IsAny<long>(), It.IsAny<long>())).Returns(false);
return mockSfSession.Object;
}
}
Expand Down
10 changes: 5 additions & 5 deletions Snowflake.Data/Client/SnowflakeDbConnectionPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public class SnowflakeDbConnectionPool
private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger<SnowflakeDbConnectionPool>();
private static readonly Object s_connectionManagerInstanceLock = new Object();
private static IConnectionManager s_connectionManager;
private const ConnectionPoolType DefaultConnectionPoolType = ConnectionPoolType.SingleConnectionCache; // TODO: set to public once development of entire ConnectionPoolManager epic is complete
private const ConnectionPoolType DefaultConnectionPoolType = ConnectionPoolType.SingleConnectionCache; // TODO: set to MultipleConnectionPool once development of entire ConnectionPoolManager epic is complete

private static IConnectionManager ConnectionManager
{
Expand All @@ -32,13 +32,13 @@ private static IConnectionManager ConnectionManager

internal static SFSession GetSession(string connectionString, SecureString password)
{
s_logger.Debug($"SnowflakeDbConnectionPool::GetSession for {connectionString}");
s_logger.Debug($"SnowflakeDbConnectionPool::GetSession");
return ConnectionManager.GetSession(connectionString, password);
}

internal static Task<SFSession> GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken)
{
s_logger.Debug($"SnowflakeDbConnectionPool::GetSessionAsync for {connectionString}");
s_logger.Debug($"SnowflakeDbConnectionPool::GetSessionAsync");
return ConnectionManager.GetSessionAsync(connectionString, password, cancellationToken);
}

Expand Down Expand Up @@ -112,12 +112,12 @@ internal static void SetConnectionPoolVersion(ConnectionPoolType requestedPoolTy
lock (s_connectionManagerInstanceLock)
{
s_connectionManager?.ClearAllPools();
if (ConnectionPoolType.MultipleConnectionPool.Equals(requestedPoolType))
if (requestedPoolType == ConnectionPoolType.MultipleConnectionPool)
{
s_connectionManager = new ConnectionPoolManager();
s_logger.Info("SnowflakeDbConnectionPool - multiple connection pools enabled");
}
if (ConnectionPoolType.SingleConnectionCache.Equals(requestedPoolType))
if (requestedPoolType == ConnectionPoolType.SingleConnectionCache)
{
s_connectionManager = new ConnectionCacheManager();
s_logger.Warn("SnowflakeDbConnectionPool - connection cache enabled");
Expand Down
3 changes: 3 additions & 0 deletions Snowflake.Data/Core/ErrorMessages.resx
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,7 @@
<data name="BROWSER_RESPONSE_TIMEOUT" xml:space="preserve">
<value>Browser response timed out after {0} seconds.</value>
</data>
<data name="INCONSISTENT_RESULT_ERROR" xml:space="preserve">
<value>Cannot return result set as a scalar value: {0}</value>
</data>
</root>
Loading

0 comments on commit 5d4c5af

Please sign in to comment.