From c401f2ca867fc5bb9995ef87cacfb79ca35dcc24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Hofman?= Date: Wed, 11 Oct 2023 11:08:01 +0200 Subject: [PATCH] SNOW-902611 - removed singleton pattern from SessionPool; introduced new interface for ConnectionManager; split tests of ConnectionPool and some cleanup in the classes SNOW-902608 - new Connection Pool Manager version implementation - unit tests for new pool; introduction of SessionFactory for unit testing of new pool manager and session pooling - integration tests for two versions of Connection Pool - ClearAllPools removes entire SessionPools collection in new pool - Fixes to flaky tests of connection pool --- CodingConventions.md | 12 + .../IntegrationTests/SFConnectionPoolIT.cs | 173 +++++++--- .../UnitTests/ConnectionPoolManagerTest.cs | 323 ++++++++++++++++++ .../SnowflakeDbConnectionPoolTest.cs | 25 ++ Snowflake.Data.Tests/Util/PoolConfig.cs | 4 + .../Client/SnowflakeDbConnectionPool.cs | 84 ++++- Snowflake.Data/Core/ErrorMessages.resx | 3 + Snowflake.Data/Core/SFError.cs | 3 + .../Core/Session/ConnectionCacheManager.cs | 7 +- .../Core/Session/ConnectionPoolManager.cs | 150 ++++++++ .../Core/Session/ConnectionPoolType.cs | 8 + .../Core/Session/IConnectionManager.cs | 1 + .../Core/Session/ISessionFactory.cs | 9 + Snowflake.Data/Core/Session/SFSession.cs | 19 +- Snowflake.Data/Core/Session/SessionFactory.cs | 12 + Snowflake.Data/Core/Session/SessionPool.cs | 67 ++-- 16 files changed, 817 insertions(+), 83 deletions(-) create mode 100644 Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs create mode 100644 Snowflake.Data.Tests/UnitTests/SnowflakeDbConnectionPoolTest.cs create mode 100644 Snowflake.Data/Core/Session/ConnectionPoolManager.cs create mode 100644 Snowflake.Data/Core/Session/ConnectionPoolType.cs create mode 100644 Snowflake.Data/Core/Session/ISessionFactory.cs create mode 100644 Snowflake.Data/Core/Session/SessionFactory.cs diff --git a/CodingConventions.md b/CodingConventions.md index 19ca8fc75..0242f583e 100644 --- a/CodingConventions.md +++ b/CodingConventions.md @@ -85,6 +85,18 @@ public class ExampleClass } ``` +#### Property + +Use PascalCase, eg. `SomeProperty`. + +```csharp +public ExampleProperty +{ + get; + set; +} +``` + ### Local variables Use camelCase, eg. `someVariable`. diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolIT.cs index 5c4529225..9c9eaa19b 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolIT.cs @@ -2,35 +2,43 @@ * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. */ -using Snowflake.Data.Tests.Util; using System; using System.Data; using System.Data.Common; using System.Threading; using System.Threading.Tasks; +using NUnit.Framework; using Snowflake.Data.Core; using Snowflake.Data.Client; +using Snowflake.Data.Core.Session; using Snowflake.Data.Log; -using NUnit.Framework; +using Snowflake.Data.Tests.Util; namespace Snowflake.Data.Tests.IntegrationTests { - [TestFixture, NonParallelizable] + [TestFixture(ConnectionPoolType.SingleConnectionCache)] + [TestFixture(ConnectionPoolType.MultipleConnectionPool)] + [NonParallelizable] class SFConnectionPoolIT : SFBaseTest { + private readonly ConnectionPoolType _connectionPoolTypeUnderTest; + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); private static PoolConfig s_previousPoolConfig; - [OneTimeSetUp] - public static void BeforeAllTests() + public SFConnectionPoolIT(ConnectionPoolType connectionPoolTypeUnderTest) { + _connectionPoolTypeUnderTest = connectionPoolTypeUnderTest; s_previousPoolConfig = new PoolConfig(); } - + [SetUp] public new void BeforeTest() { - SnowflakeDbConnectionPool.SetPooling(true); + SnowflakeDbConnectionPool.SetConnectionPoolVersion(_connectionPoolTypeUnderTest); SnowflakeDbConnectionPool.ClearAllPools(); + SnowflakeDbConnectionPool.SetPooling(true); + s_logger.Debug($"---------------- BeforeTest ---------------------"); + s_logger.Debug($"Testing Pool Type: {SnowflakeDbConnectionPool.GetConnectionPoolVersion()}"); } [TearDown] @@ -74,6 +82,7 @@ static void ConcurrentPoolingHelper(string connectionString, bool closeConnectio const int PoolTimeout = 3; // reset to default settings in case it changed by other test cases + Assert.AreEqual(true, SnowflakeDbConnectionPool.GetPool(connectionString).GetPooling()); // to instantiate pool SnowflakeDbConnectionPool.SetMaxPoolSize(10); SnowflakeDbConnectionPool.SetTimeout(PoolTimeout); @@ -86,8 +95,6 @@ static void ConcurrentPoolingHelper(string connectionString, bool closeConnectio }); } Task.WaitAll(threads); - // set pooling timeout back to default to avoid impact on other test cases - SnowflakeDbConnectionPool.SetTimeout(3600); } // thead to execute query with new connection in a loop @@ -140,7 +147,7 @@ public void TestBasicConnectionPool() conn1.Close(); Assert.AreEqual(ConnectionState.Closed, conn1.State); - Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(ConnectionString).GetCurrentPoolSize()); } [Test] @@ -150,24 +157,24 @@ public void TestConnectionPool() conn1.Open(); Assert.AreEqual(ConnectionState.Open, conn1.State); conn1.Close(); - Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(ConnectionString).GetCurrentPoolSize()); var conn2 = new SnowflakeDbConnection(); conn2.ConnectionString = ConnectionString; conn2.Open(); Assert.AreEqual(ConnectionState.Open, conn2.State); - Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetPool(ConnectionString).GetCurrentPoolSize()); conn2.Close(); - Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(ConnectionString).GetCurrentPoolSize()); Assert.AreEqual(ConnectionState.Closed, conn1.State); Assert.AreEqual(ConnectionState.Closed, conn2.State); - SnowflakeDbConnectionPool.ClearAllPools(); } [Test] public void TestConnectionPoolIsFull() { + var pool = SnowflakeDbConnectionPool.GetPool(ConnectionString); SnowflakeDbConnectionPool.SetMaxPoolSize(2); var conn1 = new SnowflakeDbConnection(); conn1.ConnectionString = ConnectionString; @@ -175,22 +182,24 @@ public void TestConnectionPoolIsFull() Assert.AreEqual(ConnectionState.Open, conn1.State); var conn2 = new SnowflakeDbConnection(); - conn2.ConnectionString = ConnectionString + " retryCount=1"; + conn2.ConnectionString = ConnectionString; conn2.Open(); Assert.AreEqual(ConnectionState.Open, conn2.State); var conn3 = new SnowflakeDbConnection(); - conn3.ConnectionString = ConnectionString + " retryCount=2"; + conn3.ConnectionString = ConnectionString; conn3.Open(); Assert.AreEqual(ConnectionState.Open, conn3.State); SnowflakeDbConnectionPool.ClearAllPools(); + pool = SnowflakeDbConnectionPool.GetPool(ConnectionString); + SnowflakeDbConnectionPool.SetMaxPoolSize(2); conn1.Close(); - Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + Assert.AreEqual(1, pool.GetCurrentPoolSize()); conn2.Close(); - Assert.AreEqual(2, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + Assert.AreEqual(2, pool.GetCurrentPoolSize()); conn3.Close(); - Assert.AreEqual(2, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + Assert.AreEqual(2, pool.GetCurrentPoolSize()); Assert.AreEqual(ConnectionState.Closed, conn1.State); Assert.AreEqual(ConnectionState.Closed, conn2.State); @@ -223,13 +232,14 @@ public void TestConnectionPoolExpirationWorks() // The pooling timeout should apply to all connections being pooled, // not just the connections created after the new setting, // so expected result should be 0 - Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize()); - SnowflakeDbConnectionPool.SetPooling(false); + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetPool(ConnectionString).GetCurrentPoolSize()); } [Test] public void TestConnectionPoolClean() { + TestOnlyForOldPool(); + SnowflakeDbConnectionPool.SetMaxPoolSize(2); var conn1 = new SnowflakeDbConnection(); conn1.ConnectionString = ConnectionString; @@ -257,12 +267,49 @@ public void TestConnectionPoolClean() Assert.AreEqual(ConnectionState.Closed, conn1.State); Assert.AreEqual(ConnectionState.Closed, conn2.State); Assert.AreEqual(ConnectionState.Closed, conn3.State); + } + + [Test] + public void TestNewConnectionPoolClean() + { + TestOnlyForNewPool(); + + SnowflakeDbConnectionPool.SetMaxPoolSize(2); + var conn1 = new SnowflakeDbConnection(); + conn1.ConnectionString = ConnectionString; + conn1.Open(); + Assert.AreEqual(ConnectionState.Open, conn1.State); + + var conn2 = new SnowflakeDbConnection(); + conn2.ConnectionString = ConnectionString + " retryCount=1"; + conn2.Open(); + Assert.AreEqual(ConnectionState.Open, conn2.State); + + var conn3 = new SnowflakeDbConnection(); + conn3.ConnectionString = ConnectionString + " retryCount=2"; + conn3.Open(); + Assert.AreEqual(ConnectionState.Open, conn3.State); + + conn1.Close(); + conn2.Close(); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(conn1.ConnectionString).GetCurrentPoolSize()); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(conn2.ConnectionString).GetCurrentPoolSize()); SnowflakeDbConnectionPool.ClearAllPools(); + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetPool(conn1.ConnectionString).GetCurrentPoolSize()); + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetPool(conn2.ConnectionString).GetCurrentPoolSize()); + conn3.Close(); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(conn3.ConnectionString).GetCurrentPoolSize()); + + Assert.AreEqual(ConnectionState.Closed, conn1.State); + Assert.AreEqual(ConnectionState.Closed, conn2.State); + Assert.AreEqual(ConnectionState.Closed, conn3.State); } [Test] public void TestConnectionPoolFull() { + TestOnlyForOldPool(); + SnowflakeDbConnectionPool.SetMaxPoolSize(2); var conn1 = new SnowflakeDbConnection(); @@ -300,6 +347,50 @@ public void TestConnectionPoolFull() SnowflakeDbConnectionPool.ClearAllPools(); } + [Test] + public void TestNewConnectionPoolFull() + { + TestOnlyForNewPool(); + + var pool = SnowflakeDbConnectionPool.GetPool(ConnectionString); + pool.SetMaxPoolSize(2); + + var conn1 = new SnowflakeDbConnection(); + conn1.ConnectionString = ConnectionString; + conn1.Open(); + Assert.AreEqual(ConnectionState.Open, conn1.State); + + var conn2 = new SnowflakeDbConnection(); + conn2.ConnectionString = ConnectionString; + conn2.Open(); + Assert.AreEqual(ConnectionState.Open, conn2.State); + + Assert.AreEqual(0, pool.GetCurrentPoolSize()); + conn1.Close(); + conn2.Close(); + Assert.AreEqual(2, pool.GetCurrentPoolSize()); + + var conn3 = new SnowflakeDbConnection(); + conn3.ConnectionString = ConnectionString; + conn3.Open(); + Assert.AreEqual(ConnectionState.Open, conn3.State); + + var conn4 = new SnowflakeDbConnection(); + conn4.ConnectionString = ConnectionString; + conn4.Open(); + Assert.AreEqual(ConnectionState.Open, conn4.State); + + conn3.Close(); + Assert.AreEqual(1, pool.GetCurrentPoolSize()); // TODO: when SNOW-937189 complete should be 2 + conn4.Close(); + Assert.AreEqual(2, pool.GetCurrentPoolSize()); + + Assert.AreEqual(ConnectionState.Closed, conn1.State); + Assert.AreEqual(ConnectionState.Closed, conn2.State); + Assert.AreEqual(ConnectionState.Closed, conn3.State); + Assert.AreEqual(ConnectionState.Closed, conn4.State); + } + [Test] public void TestConnectionPoolMultiThreading() { @@ -335,6 +426,7 @@ void ThreadProcess2(string connstr) SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); + conn1.Close(); SnowflakeDbConnectionPool.ClearAllPools(); SnowflakeDbConnectionPool.SetMaxPoolSize(0); SnowflakeDbConnectionPool.SetPooling(false); @@ -343,7 +435,8 @@ void ThreadProcess2(string connstr) [Test] public void TestConnectionPoolDisable() { - SnowflakeDbConnectionPool.SetPooling(false); + var pool = SnowflakeDbConnectionPool.GetPool(ConnectionString); + pool.SetPooling(false); var conn1 = new SnowflakeDbConnection(); conn1.ConnectionString = ConnectionString; @@ -352,27 +445,21 @@ public void TestConnectionPoolDisable() conn1.Close(); Assert.AreEqual(ConnectionState.Closed, conn1.State); - Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + Assert.AreEqual(0, pool.GetCurrentPoolSize()); } [Test] public void TestConnectionPoolWithDispose() { SnowflakeDbConnectionPool.SetMaxPoolSize(1); - + var conn1 = new SnowflakeDbConnection(); - conn1.ConnectionString = ""; - try - { - conn1.Open(); - } - catch (SnowflakeDbException ex) - { - conn1.Close(); - } + conn1.ConnectionString = "bad connection string"; + Assert.Throws(() => conn1.Open()); + conn1.Close(); Assert.AreEqual(ConnectionState.Closed, conn1.State); - Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetPool(conn1.ConnectionString).GetCurrentPoolSize()); } [Test] @@ -381,7 +468,6 @@ public void TestConnectionPoolTurnOff() SnowflakeDbConnectionPool.SetPooling(false); SnowflakeDbConnectionPool.SetPooling(true); SnowflakeDbConnectionPool.SetMaxPoolSize(1); - SnowflakeDbConnectionPool.ClearAllPools(); var conn1 = new SnowflakeDbConnection(); conn1.ConnectionString = ConnectionString; @@ -390,10 +476,19 @@ public void TestConnectionPoolTurnOff() conn1.Close(); Assert.AreEqual(ConnectionState.Closed, conn1.State); - Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize()); - - SnowflakeDbConnectionPool.SetPooling(false); - //Put a breakpoint at SFSession close function, after connection pool is off, it will send close session request. + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(ConnectionString).GetCurrentPoolSize()); + } + + private void TestOnlyForOldPool() + { + if (_connectionPoolTypeUnderTest != ConnectionPoolType.SingleConnectionCache) + Assert.Ignore($"Test case relates only to {ConnectionPoolType.SingleConnectionCache} pool type"); + } + + private void TestOnlyForNewPool() + { + if (_connectionPoolTypeUnderTest != ConnectionPoolType.MultipleConnectionPool) + Assert.Ignore($"Test case relates only to {ConnectionPoolType.MultipleConnectionPool} pool type"); } } } diff --git a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs new file mode 100644 index 000000000..a28a3754c --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs @@ -0,0 +1,323 @@ +/* + * Copyright (c) 2023 Snowflake Computing Inc. All rights reserved. + */ + +using System.Collections.Generic; +using System.Security; +using System.Threading; +using System.Threading.Tasks; +using NUnit.Framework; +using Snowflake.Data.Core; +using Snowflake.Data.Core.Session; +using Moq; +using Snowflake.Data.Client; +using Snowflake.Data.Tests.Util; + +namespace Snowflake.Data.Tests.UnitTests +{ + [TestFixture, NonParallelizable] + class ConnectionPoolManagerTest + { + private readonly ConnectionPoolManager _connectionPoolManager = new ConnectionPoolManager(); + private const string ConnectionString1 = "database=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;"; + private const string ConnectionString2 = "database=D2;warehouse=W2;account=A2;user=U2;password=P2;role=R2;"; + private readonly SecureString _password = new SecureString(); + private static PoolConfig s_poolConfig; + + [OneTimeSetUp] + public static void BeforeAllTests() + { + s_poolConfig = new PoolConfig(); + SnowflakeDbConnectionPool.SetConnectionPoolVersion(ConnectionPoolType.MultipleConnectionPool); + SessionPool.SessionFactory = new MockSessionFactory(); + } + + [OneTimeTearDown] + public void AfterAllTests() + { + s_poolConfig.Reset(); + SessionPool.SessionFactory = new SessionFactory(); + } + + [Test] + public void TestPoolManagerReturnsSessionPoolForGivenConnectionString() + { + // Act + var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, _password); + + // Assert + Assert.AreEqual(ConnectionString1, sessionPool.ConnectionString); + Assert.AreEqual(_password, sessionPool.Password); + } + + [Test] + public void TestPoolManagerReturnsSamePoolForGivenConnectionString() + { + // Arrange + var anotherConnectionString = ConnectionString1; + + // Act + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); + var sessionPool2 = _connectionPoolManager.GetPool(anotherConnectionString, _password); + + // Assert + Assert.AreEqual(sessionPool1, sessionPool2); + } + + [Test] + public void TestDifferentPoolsAreReturnedForDifferentConnectionStrings() + { + // Arrange + Assert.AreNotSame(ConnectionString1, ConnectionString2); + + // Act + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password); + + // Assert + Assert.AreNotSame(sessionPool1, sessionPool2); + Assert.AreEqual(ConnectionString1, sessionPool1.ConnectionString); + Assert.AreEqual(ConnectionString2, sessionPool2.ConnectionString); + } + + + [Test] + public void TestGetSessionWorksForSpecifiedConnectionString() + { + // Act + var sfSession = _connectionPoolManager.GetSession(ConnectionString1, _password); + + // Assert + Assert.AreEqual(ConnectionString1, sfSession.ConnectionString); + Assert.AreEqual(_password, sfSession.Password); + } + + [Test] + public async Task TestGetSessionAsyncWorksForSpecifiedConnectionString() + { + // Act + var sfSession = await _connectionPoolManager.GetSessionAsync(ConnectionString1, _password, CancellationToken.None); + + // Assert + Assert.AreEqual(ConnectionString1, sfSession.ConnectionString); + Assert.AreEqual(_password, sfSession.Password); + } + + [Test] + [Ignore("Enable after completion of SNOW-937189")] // TODO: + public void TestCountingOfSessionProvidedByPool() + { + // Act + _connectionPoolManager.GetSession(ConnectionString1, _password); + + // Assert + var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, _password); + Assert.AreEqual(1, sessionPool.GetCurrentPoolSize()); + } + + [Test] + public void TestCountingOfSessionReturnedBackToPool() + { + // Arrange + var sfSession = _connectionPoolManager.GetSession(ConnectionString1, _password); + + // Act + _connectionPoolManager.AddSession(sfSession); + + // Assert + var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, _password); + Assert.AreEqual(1, sessionPool.GetCurrentPoolSize()); + } + + [Test] + public void TestSetMaxPoolSizeForAllPools() + { + // Arrange + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password); + + // Act + _connectionPoolManager.SetMaxPoolSize(3); + + // Assert + Assert.AreEqual(3, sessionPool1.GetMaxPoolSize()); + Assert.AreEqual(3, sessionPool2.GetMaxPoolSize()); + } + + [Test] + public void TestSetTimeoutForAllPools() + { + // Arrange + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password); + + // Act + _connectionPoolManager.SetTimeout(3000); + + // Assert + Assert.AreEqual(3000, sessionPool1.GetTimeout()); + Assert.AreEqual(3000, sessionPool2.GetTimeout()); + } + + [Test] + public void TestSetPoolingDisabledForAllPools() + { + // Arrange + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); + + // Act + _connectionPoolManager.SetPooling(false); + + // Assert + Assert.AreEqual(false, sessionPool1.GetPooling()); + } + + [Test] + public void TestSetPoolingEnabledBack() + { + // Arrange + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); + _connectionPoolManager.SetPooling(false); + + // Act + _connectionPoolManager.SetPooling(true); + + // Assert + Assert.AreEqual(true, sessionPool1.GetPooling()); + } + + [Test] + public void TestGetPoolingOnManagerLevelWhenNotAllPoolsEqual() + { + // Arrange + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password); + sessionPool1.SetPooling(true); + sessionPool2.SetPooling(false); + + // Act/Assert + var exception = Assert.Throws(() => _connectionPoolManager.GetPooling()); + Assert.IsNotNull(exception); + Assert.AreEqual(SFError.INCONSISTENT_RESULT_ERROR.GetAttribute().errorCode, exception.ErrorCode); + Assert.IsTrue(exception.Message.Contains("Multiple pools have different Pooling values")); + } + + [Test] + public void TestGetPoolingOnManagerLevelWorksWhenAllPoolsEqual() + { + // Arrange + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password); + sessionPool1.SetPooling(true); + sessionPool2.SetPooling(true); + + // Act/Assert + Assert.AreEqual(true,_connectionPoolManager.GetPooling()); + } + + [Test] + public void TestGetTimeoutOnManagerLevelWhenNotAllPoolsEqual() + { + // Arrange + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password); + sessionPool1.SetTimeout(299); + sessionPool2.SetTimeout(1313); + + // Act/Assert + var exception = Assert.Throws(() => _connectionPoolManager.GetTimeout()); + Assert.IsNotNull(exception); + Assert.AreEqual(SFError.INCONSISTENT_RESULT_ERROR.GetAttribute().errorCode, exception.ErrorCode); + Assert.IsTrue(exception.Message.Contains("Multiple pools have different Timeout values")); + } + + [Test] + public void TestGetTimeoutOnManagerLevelWhenAllPoolsEqual() + { + // Arrange + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password); + sessionPool1.SetTimeout(3600); + sessionPool2.SetTimeout(3600); + + // Act/Assert + Assert.AreEqual(3600,_connectionPoolManager.GetTimeout()); + } + + [Test] + public void TestGetMaxPoolSizeOnManagerLevelWhenNotAllPoolsEqual() + { + // Arrange + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password); + sessionPool1.SetMaxPoolSize(1); + sessionPool2.SetMaxPoolSize(17); + + // Act/Assert + var exception = Assert.Throws(() => _connectionPoolManager.GetMaxPoolSize()); + Assert.IsNotNull(exception); + Assert.AreEqual(SFError.INCONSISTENT_RESULT_ERROR.GetAttribute().errorCode, exception.ErrorCode); + Assert.IsTrue(exception.Message.Contains("Multiple pools have different Max Pool Size values")); + } + + [Test] + public void TestGetMaxPoolSizeOnManagerLevelWhenAllPoolsEqual() + { + // Arrange + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password); + sessionPool1.SetMaxPoolSize(33); + sessionPool2.SetMaxPoolSize(33); + + // Act/Assert + Assert.AreEqual(33,_connectionPoolManager.GetMaxPoolSize()); + } + + [Test] + public void TestGetCurrentPoolSizeThrowsExceptionWhenNotAllPoolsEqual() + { + // Arrange + EnsurePoolSize(ConnectionString1, 2); + EnsurePoolSize(ConnectionString2, 3); + + // Act/Assert + var exception = Assert.Throws(() => _connectionPoolManager.GetCurrentPoolSize()); + Assert.IsNotNull(exception); + Assert.AreEqual(SFError.INCONSISTENT_RESULT_ERROR.GetAttribute().errorCode, exception.ErrorCode); + Assert.IsTrue(exception.Message.Contains("Multiple pools have different Current Pool Size values")); + } + + private void EnsurePoolSize(string connectionString, int requiredCurrentSize) + { + var sessionPool = _connectionPoolManager.GetPool(connectionString, _password); + sessionPool.SetMaxPoolSize(requiredCurrentSize); + var busySessions = new List(); + for (var i = 0; i < requiredCurrentSize; i++) + { + var sfSession = _connectionPoolManager.GetSession(connectionString, _password); + busySessions.Add(sfSession); + } + + foreach (var session in busySessions) // TODO: remove after SNOW-937189 since sessions will be already counted by GetCurrentPool size + { + session.close(); + _connectionPoolManager.AddSession(session); + } + + Assert.AreEqual(requiredCurrentSize, sessionPool.GetCurrentPoolSize()); + } + } + + class MockSessionFactory : ISessionFactory + { + public SFSession NewSession(string connectionString, SecureString password) + { + var mockSfSession = new Mock(connectionString, password); + mockSfSession.Setup(x => x.Open()).Verifiable(); + mockSfSession.Setup(x => x.OpenAsync(default)).Returns(Task.FromResult(this)); + mockSfSession.Setup(x => x.IsNotOpen()).Returns(false); + mockSfSession.Setup(x => x.IsExpired(It.IsAny(), It.IsAny())).Returns(false); + return mockSfSession.Object; + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/SnowflakeDbConnectionPoolTest.cs b/Snowflake.Data.Tests/UnitTests/SnowflakeDbConnectionPoolTest.cs new file mode 100644 index 000000000..82ad550d9 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/SnowflakeDbConnectionPoolTest.cs @@ -0,0 +1,25 @@ +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Tests.UnitTests +{ + public class SnowflakeDbConnectionPoolTest + { + private readonly string _connectionString1 = "database=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;"; + private readonly string _connectionString2 = "database=D2;warehouse=W2;account=A2;user=U2;password=P2;role=R2;"; + + [Test] + public void TestRevertPoolToPreviousVersion() + { + // act + SnowflakeDbConnectionPool.SetOldConnectionPoolVersion(); + + // assert + var sessionPool1 = SnowflakeDbConnectionPool.GetPool(_connectionString1); + var sessionPool2 = SnowflakeDbConnectionPool.GetPool(_connectionString2); + Assert.AreEqual(ConnectionPoolType.SingleConnectionCache, SnowflakeDbConnectionPool.GetConnectionPoolVersion()); + Assert.AreEqual(sessionPool1, sessionPool2); + } + } +} diff --git a/Snowflake.Data.Tests/Util/PoolConfig.cs b/Snowflake.Data.Tests/Util/PoolConfig.cs index 4856da243..078b6e359 100644 --- a/Snowflake.Data.Tests/Util/PoolConfig.cs +++ b/Snowflake.Data.Tests/Util/PoolConfig.cs @@ -3,6 +3,7 @@ */ using Snowflake.Data.Client; +using Snowflake.Data.Core.Session; namespace Snowflake.Data.Tests.Util { @@ -11,16 +12,19 @@ class PoolConfig private readonly bool _pooling; private readonly long _timeout; private readonly int _maxPoolSize; + private readonly ConnectionPoolType _connectionPoolType; public PoolConfig() { _maxPoolSize = SnowflakeDbConnectionPool.GetMaxPoolSize(); _timeout = SnowflakeDbConnectionPool.GetTimeout(); _pooling = SnowflakeDbConnectionPool.GetPooling(); + _connectionPoolType = SnowflakeDbConnectionPool.GetConnectionPoolVersion(); } public void Reset() { + SnowflakeDbConnectionPool.SetConnectionPoolVersion(_connectionPoolType); SnowflakeDbConnectionPool.SetMaxPoolSize(_maxPoolSize); SnowflakeDbConnectionPool.SetTimeout(_timeout); SnowflakeDbConnectionPool.SetPooling(_pooling); diff --git a/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs b/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs index f643fa5c9..46348cf61 100644 --- a/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs +++ b/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs @@ -2,6 +2,7 @@ * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. */ +using System; using System.Security; using System.Threading; using System.Threading.Tasks; @@ -14,72 +15,127 @@ namespace Snowflake.Data.Client public class SnowflakeDbConnectionPool { private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); - private static readonly IConnectionManager s_connectionManager = new ConnectionCacheManager(); + private static readonly Object s_connectionManagerInstanceLock = new Object(); + private static IConnectionManager s_connectionManager; + private const ConnectionPoolType DefaultConnectionPoolType = ConnectionPoolType.SingleConnectionCache; // TODO: set to MultipleConnectionPool once development of entire ConnectionPoolManager epic is complete + + private static IConnectionManager ConnectionManager + { + get + { + if (s_connectionManager != null) + return s_connectionManager; + SetConnectionPoolVersion(DefaultConnectionPoolType); + return s_connectionManager; + } + } internal static SFSession GetSession(string connectionString, SecureString password) { - s_logger.Debug("SnowflakeDbConnectionPool::GetSession"); - return s_connectionManager.GetSession(connectionString, password); + s_logger.Debug($"SnowflakeDbConnectionPool::GetSession"); + return ConnectionManager.GetSession(connectionString, password); } internal static Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken) { - s_logger.Debug("SnowflakeDbConnectionPool::GetSessionAsync"); - return s_connectionManager.GetSessionAsync(connectionString, password, cancellationToken); + s_logger.Debug($"SnowflakeDbConnectionPool::GetSessionAsync"); + return ConnectionManager.GetSessionAsync(connectionString, password, cancellationToken); + } + + internal static SessionPool GetPool(string connectionString) + { + s_logger.Debug($"SnowflakeDbConnectionPool::GetPool"); + return ConnectionManager.GetPool(connectionString); } internal static bool AddSession(SFSession session) { s_logger.Debug("SnowflakeDbConnectionPool::AddSession"); - return s_connectionManager.AddSession(session); + return ConnectionManager.AddSession(session); } public static void ClearAllPools() { s_logger.Debug("SnowflakeDbConnectionPool::ClearAllPools"); - s_connectionManager.ClearAllPools(); + ConnectionManager.ClearAllPools(); } public static void SetMaxPoolSize(int maxPoolSize) { s_logger.Debug("SnowflakeDbConnectionPool::SetMaxPoolSize"); - s_connectionManager.SetMaxPoolSize(maxPoolSize); + ConnectionManager.SetMaxPoolSize(maxPoolSize); } public static int GetMaxPoolSize() { s_logger.Debug("SnowflakeDbConnectionPool::GetMaxPoolSize"); - return s_connectionManager.GetMaxPoolSize(); + return ConnectionManager.GetMaxPoolSize(); } public static void SetTimeout(long connectionTimeout) { s_logger.Debug("SnowflakeDbConnectionPool::SetTimeout"); - s_connectionManager.SetTimeout(connectionTimeout); + ConnectionManager.SetTimeout(connectionTimeout); } public static long GetTimeout() { s_logger.Debug("SnowflakeDbConnectionPool::GetTimeout"); - return s_connectionManager.GetTimeout(); + return ConnectionManager.GetTimeout(); } public static int GetCurrentPoolSize() { s_logger.Debug("SnowflakeDbConnectionPool::GetCurrentPoolSize"); - return s_connectionManager.GetCurrentPoolSize(); + return ConnectionManager.GetCurrentPoolSize(); } public static bool SetPooling(bool isEnable) { s_logger.Debug("SnowflakeDbConnectionPool::SetPooling"); - return s_connectionManager.SetPooling(isEnable); + return ConnectionManager.SetPooling(isEnable); } public static bool GetPooling() { s_logger.Debug("SnowflakeDbConnectionPool::GetPooling"); - return s_connectionManager.GetPooling(); + return ConnectionManager.GetPooling(); + } + + internal static void SetOldConnectionPoolVersion() // TODO: set to public once development of entire ConnectionPoolManager epic is complete + { + SetConnectionPoolVersion(ConnectionPoolType.SingleConnectionCache); + } + + internal static void SetConnectionPoolVersion(ConnectionPoolType requestedPoolType) + { + lock (s_connectionManagerInstanceLock) + { + s_connectionManager?.ClearAllPools(); + if (requestedPoolType == ConnectionPoolType.MultipleConnectionPool) + { + s_connectionManager = new ConnectionPoolManager(); + s_logger.Info("SnowflakeDbConnectionPool - multiple connection pools enabled"); + } + if (requestedPoolType == ConnectionPoolType.SingleConnectionCache) + { + s_connectionManager = new ConnectionCacheManager(); + s_logger.Warn("SnowflakeDbConnectionPool - connection cache enabled"); + } + } + } + + internal static ConnectionPoolType GetConnectionPoolVersion() + { + if (ConnectionManager != null) + { + switch (ConnectionManager) + { + case ConnectionCacheManager _: return ConnectionPoolType.SingleConnectionCache; + case ConnectionPoolManager _: return ConnectionPoolType.MultipleConnectionPool; + } + } + return DefaultConnectionPoolType; } } } diff --git a/Snowflake.Data/Core/ErrorMessages.resx b/Snowflake.Data/Core/ErrorMessages.resx index 7159b86a0..8c50a5b57 100755 --- a/Snowflake.Data/Core/ErrorMessages.resx +++ b/Snowflake.Data/Core/ErrorMessages.resx @@ -183,4 +183,7 @@ Browser response timed out after {0} seconds. + + Cannot return result set as a scalar value: {0} + \ No newline at end of file diff --git a/Snowflake.Data/Core/SFError.cs b/Snowflake.Data/Core/SFError.cs index 2aac72f16..3a71e95fb 100755 --- a/Snowflake.Data/Core/SFError.cs +++ b/Snowflake.Data/Core/SFError.cs @@ -78,6 +78,9 @@ public enum SFError [SFErrorAttr(errorCode = 270057)] BROWSER_RESPONSE_TIMEOUT, + + [SFErrorAttr(errorCode = 270058)] + INCONSISTENT_RESULT_ERROR, } class SFErrorAttr : Attribute diff --git a/Snowflake.Data/Core/Session/ConnectionCacheManager.cs b/Snowflake.Data/Core/Session/ConnectionCacheManager.cs index e10a984e3..c871c24f1 100644 --- a/Snowflake.Data/Core/Session/ConnectionCacheManager.cs +++ b/Snowflake.Data/Core/Session/ConnectionCacheManager.cs @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + */ + using System.Security; using System.Threading; using System.Threading.Tasks; @@ -6,7 +10,7 @@ namespace Snowflake.Data.Core.Session { internal sealed class ConnectionCacheManager : IConnectionManager { - private readonly SessionPool _sessionPool = new SessionPool(); + private readonly SessionPool _sessionPool = SessionPool.CreateSessionCache(); public SFSession GetSession(string connectionString, SecureString password) => _sessionPool.GetSession(connectionString, password); public Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken) => _sessionPool.GetSessionAsync(connectionString, password, cancellationToken); @@ -19,5 +23,6 @@ public Task GetSessionAsync(string connectionString, SecureString pas public int GetCurrentPoolSize() => _sessionPool.GetCurrentPoolSize(); public bool SetPooling(bool poolingEnabled) => _sessionPool.SetPooling(poolingEnabled); public bool GetPooling() => _sessionPool.GetPooling(); + public SessionPool GetPool(string _) => _sessionPool; } } diff --git a/Snowflake.Data/Core/Session/ConnectionPoolManager.cs b/Snowflake.Data/Core/Session/ConnectionPoolManager.cs new file mode 100644 index 000000000..844daa54d --- /dev/null +++ b/Snowflake.Data/Core/Session/ConnectionPoolManager.cs @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Security; +using System.Threading; +using System.Threading.Tasks; +using Snowflake.Data.Client; +using Snowflake.Data.Log; + +namespace Snowflake.Data.Core.Session +{ + internal sealed class ConnectionPoolManager : IConnectionManager + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + private static readonly Object s_poolsLock = new Object(); + private readonly Dictionary _pools; + + internal ConnectionPoolManager() + { + lock (s_poolsLock) + { + _pools = new Dictionary(); + } + } + + public SFSession GetSession(string connectionString, SecureString password) + { + s_logger.Debug($"ConnectionPoolManager::GetSession"); + return GetPool(connectionString, password).GetSession(); + } + + public Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken) + { + s_logger.Debug($"ConnectionPoolManager::GetSessionAsync"); + return GetPool(connectionString, password).GetSessionAsync(cancellationToken); + } + + public bool AddSession(SFSession session) + { + s_logger.Debug($"ConnectionPoolManager::AddSession for {session.ConnectionString}"); + return GetPool(session.ConnectionString, session.Password).AddSession(session); + } + + public void ClearAllPools() + { + s_logger.Debug("ConnectionPoolManager::ClearAllPools"); + foreach (var sessionPool in _pools.Values) + { + sessionPool.ClearAllPools(); + } + _pools.Clear(); + } + + public void SetMaxPoolSize(int maxPoolSize) + { + s_logger.Debug("ConnectionPoolManager::SetMaxPoolSize for all pools"); + foreach (var pool in _pools.Values) + { + pool.SetMaxPoolSize(maxPoolSize); + } + } + + public int GetMaxPoolSize() + { + s_logger.Debug("ConnectionPoolManager::GetMaxPoolSize"); + var values = _pools.Values.Select(it => it.GetMaxPoolSize()).Distinct().ToList(); + return values.Count == 1 + ? values.First() + : throw new SnowflakeDbException(SFError.INCONSISTENT_RESULT_ERROR, "Multiple pools have different Max Pool Size values"); + } + + public void SetTimeout(long connectionTimeout) + { + s_logger.Debug("ConnectionPoolManager::SetTimeout for all pools"); + foreach (var pool in _pools.Values) + { + pool.SetTimeout(connectionTimeout); + } + } + + public long GetTimeout() + { + s_logger.Debug("ConnectionPoolManager::GetTimeout"); + var values = _pools.Values.Select(it => it.GetTimeout()).Distinct().ToList(); + return values.Count == 1 + ? values.First() + : throw new SnowflakeDbException(SFError.INCONSISTENT_RESULT_ERROR, "Multiple pools have different Timeout values"); + } + + public int GetCurrentPoolSize() + { + s_logger.Debug("ConnectionPoolManager::GetCurrentPoolSize"); + var values = _pools.Values.Select(it => it.GetCurrentPoolSize()).Distinct().ToList(); + return values.Count == 1 + ? values.First() + : throw new SnowflakeDbException(SFError.INCONSISTENT_RESULT_ERROR, "Multiple pools have different Current Pool Size values"); + } + + public bool SetPooling(bool poolingEnabled) + { + s_logger.Debug("ConnectionPoolManager::SetPooling for all pools"); + return _pools.Values + .Select(pool => pool.SetPooling(poolingEnabled)) + .All(setPoolingResult => setPoolingResult); + } + + public bool GetPooling() + { + s_logger.Debug("ConnectionPoolManager::GetPooling"); + var values = _pools.Values.Select(it => it.GetPooling()).Distinct().ToList(); + return values.Count == 1 + ? values.First() + : throw new SnowflakeDbException(SFError.INCONSISTENT_RESULT_ERROR, "Multiple pools have different Pooling values"); + } + + internal SessionPool GetPool(string connectionString, SecureString password) + { + s_logger.Debug($"ConnectionPoolManager::GetPool"); + var poolKey = GetPoolKey(connectionString); + + if (_pools.TryGetValue(poolKey, out var item)) + return item; + lock (s_poolsLock) + { + if (_pools.TryGetValue(poolKey, out var poolCreatedWhileWaitingOnLock)) + return poolCreatedWhileWaitingOnLock; + s_logger.Info($"Creating new pool"); + var pool = SessionPool.CreateSessionPool(connectionString, password); + _pools.Add(poolKey, pool); + return pool; + } + } + + public SessionPool GetPool(string connectionString) + { + s_logger.Debug($"ConnectionPoolManager::GetPool"); + return GetPool(connectionString, null); + } + + // TODO: SNOW-937188 + private string GetPoolKey(string connectionString) + { + return connectionString; + } + } +} diff --git a/Snowflake.Data/Core/Session/ConnectionPoolType.cs b/Snowflake.Data/Core/Session/ConnectionPoolType.cs new file mode 100644 index 000000000..5844878fc --- /dev/null +++ b/Snowflake.Data/Core/Session/ConnectionPoolType.cs @@ -0,0 +1,8 @@ +namespace Snowflake.Data.Core.Session +{ + internal enum ConnectionPoolType + { + SingleConnectionCache, + MultipleConnectionPool + } +} diff --git a/Snowflake.Data/Core/Session/IConnectionManager.cs b/Snowflake.Data/Core/Session/IConnectionManager.cs index e72ade2e7..c64699d54 100644 --- a/Snowflake.Data/Core/Session/IConnectionManager.cs +++ b/Snowflake.Data/Core/Session/IConnectionManager.cs @@ -21,5 +21,6 @@ internal interface IConnectionManager int GetCurrentPoolSize(); bool SetPooling(bool poolingEnabled); bool GetPooling(); + SessionPool GetPool(string connectionString); } } diff --git a/Snowflake.Data/Core/Session/ISessionFactory.cs b/Snowflake.Data/Core/Session/ISessionFactory.cs new file mode 100644 index 000000000..f9416de8d --- /dev/null +++ b/Snowflake.Data/Core/Session/ISessionFactory.cs @@ -0,0 +1,9 @@ +using System.Security; + +namespace Snowflake.Data.Core.Session +{ + internal interface ISessionFactory + { + SFSession NewSession(string connectionString, SecureString password); + } +} diff --git a/Snowflake.Data/Core/Session/SFSession.cs b/Snowflake.Data/Core/Session/SFSession.cs index d153ffff4..f554263d3 100755 --- a/Snowflake.Data/Core/Session/SFSession.cs +++ b/Snowflake.Data/Core/Session/SFSession.cs @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. */ using System; @@ -15,7 +15,6 @@ using System.Threading.Tasks; using System.Net.Http; using System.Text.RegularExpressions; -using Snowflake.Data.Configuration; namespace Snowflake.Data.Core { @@ -69,7 +68,8 @@ public class SFSession private readonly EasyLoggingStarter _easyLoggingStarter = EasyLoggingStarter.Instance; private long _startTime = 0; - internal string connStr = null; + internal string ConnectionString { get; } + internal SecureString Password { get; } private QueryContextCache _queryContextCache = new QueryContextCache(_defaultQueryContextCacheSize); @@ -145,8 +145,9 @@ internal SFSession( EasyLoggingStarter easyLoggingStarter) { _easyLoggingStarter = easyLoggingStarter; - connStr = connectionString; - properties = SFSessionProperties.parseConnectionString(connectionString, password); + ConnectionString = connectionString; + Password = password; + properties = SFSessionProperties.parseConnectionString(ConnectionString, Password); _disableQueryContextCache = bool.Parse(properties[SFSessionProperty.DISABLEQUERYCONTEXTCACHE]); ValidateApplicationName(properties); try @@ -215,7 +216,7 @@ internal Uri BuildUri(string path, Dictionary queryParams = null return uriBuilder.Uri; } - internal void Open() + internal virtual void Open() { logger.Debug("Open Session"); @@ -227,7 +228,7 @@ internal void Open() authenticator.Authenticate(); } - internal async Task OpenAsync(CancellationToken cancellationToken) + internal virtual async Task OpenAsync(CancellationToken cancellationToken) { logger.Debug("Open Session Async"); @@ -557,12 +558,12 @@ internal void heartbeat() } } - internal bool IsNotOpen() + internal virtual bool IsNotOpen() { return _startTime == 0; } - internal bool IsExpired(long timeoutInSeconds, long utcTimeInSeconds) + internal virtual bool IsExpired(long timeoutInSeconds, long utcTimeInSeconds) { return _startTime + timeoutInSeconds <= utcTimeInSeconds; } diff --git a/Snowflake.Data/Core/Session/SessionFactory.cs b/Snowflake.Data/Core/Session/SessionFactory.cs new file mode 100644 index 000000000..2eb0ba6df --- /dev/null +++ b/Snowflake.Data/Core/Session/SessionFactory.cs @@ -0,0 +1,12 @@ +using System.Security; + +namespace Snowflake.Data.Core.Session +{ + internal class SessionFactory : ISessionFactory + { + public SFSession NewSession(string connectionString, SecureString password) + { + return new SFSession(connectionString, password); + } + } +} diff --git a/Snowflake.Data/Core/Session/SessionPool.cs b/Snowflake.Data/Core/Session/SessionPool.cs index a7eae7726..0c62f61be 100644 --- a/Snowflake.Data/Core/Session/SessionPool.cs +++ b/Snowflake.Data/Core/Session/SessionPool.cs @@ -17,22 +17,39 @@ sealed class SessionPool : IDisposable { private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); private static readonly object s_sessionPoolLock = new object(); - private readonly List _sessions; + private static ISessionFactory s_sessionFactory = new SessionFactory(); + + private readonly List _idleSessions; private int _maxPoolSize; private long _timeout; private const int MaxPoolSize = 10; private const long Timeout = 3600; + internal string ConnectionString { get; } + internal SecureString Password { get; } private bool _pooling = true; + private bool _allowExceedMaxPoolSize = true; - internal SessionPool() + private SessionPool() { lock (s_sessionPoolLock) { - _sessions = new List(); + _idleSessions = new List(); _maxPoolSize = MaxPoolSize; _timeout = Timeout; } } + + private SessionPool(string connectionString, SecureString password) : this() + { + ConnectionString = connectionString; + Password = password; + _allowExceedMaxPoolSize = false; // TODO: SNOW-937190 + } + + internal static SessionPool CreateSessionCache() => new SessionPool(); + + internal static SessionPool CreateSessionPool(string connectionString, SecureString password) => + new SessionPool(connectionString, password); ~SessionPool() { @@ -46,6 +63,11 @@ public void Dispose() ClearAllPools(); } + internal static ISessionFactory SessionFactory + { + set => s_sessionFactory = value; + } + private void CleanExpiredSessions() { s_logger.Debug("SessionPool::CleanExpiredSessions"); @@ -53,11 +75,11 @@ private void CleanExpiredSessions() { long timeNow = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); - foreach (var item in _sessions.ToList()) + foreach (var item in _idleSessions.ToList()) { if (item.IsExpired(_timeout, timeNow)) { - _sessions.Remove(item); + _idleSessions.Remove(item); item.close(); } } @@ -82,17 +104,22 @@ internal Task GetSessionAsync(string connStr, SecureString password, return session != null ? Task.FromResult(session) : NewSessionAsync(connStr, password, cancellationToken); } + internal SFSession GetSession() => GetSession(ConnectionString, Password); + + internal Task GetSessionAsync(CancellationToken cancellationToken) => + GetSessionAsync(ConnectionString, Password, cancellationToken); + private SFSession GetIdleSession(string connStr) { s_logger.Debug("SessionPool::GetIdleSession"); lock (s_sessionPoolLock) { - for (int i = 0; i < _sessions.Count; i++) + for (int i = 0; i < _idleSessions.Count; i++) { - if (_sessions[i].connStr.Equals(connStr)) + if (_idleSessions[i].ConnectionString.Equals(connStr)) { - SFSession session = _sessions[i]; - _sessions.RemoveAt(i); + SFSession session = _idleSessions[i]; + _idleSessions.RemoveAt(i); long timeNow = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); if (session.IsExpired(_timeout, timeNow)) { @@ -115,7 +142,7 @@ private SFSession NewSession(String connectionString, SecureString password) s_logger.Debug("SessionPool::NewSession"); try { - var session = new SFSession(connectionString, password); + var session = s_sessionFactory.NewSession(connectionString, password); session.Open(); return session; } @@ -135,7 +162,7 @@ private SFSession NewSession(String connectionString, SecureString password) private Task NewSessionAsync(String connectionString, SecureString password, CancellationToken cancellationToken) { s_logger.Debug("SessionPool::NewSessionAsync"); - var session = new SFSession(connectionString, password); + var session = s_sessionFactory.NewSession(connectionString, password); return session .OpenAsync(cancellationToken) .ContinueWith(previousTask => @@ -164,18 +191,18 @@ internal bool AddSession(SFSession session) lock (s_sessionPoolLock) { - if (_sessions.Count >= _maxPoolSize) + if (_idleSessions.Count >= _maxPoolSize) { CleanExpiredSessions(); } - if (_sessions.Count >= _maxPoolSize) + if (_idleSessions.Count >= _maxPoolSize) { - // pool is full + s_logger.Warn($"Pool is full - unable to add session with sid {session.sessionId}"); return false; } s_logger.Debug($"pool connection with sid {session.sessionId}"); - _sessions.Add(session); + _idleSessions.Add(session); return true; } } @@ -185,22 +212,22 @@ internal void ClearAllPools() s_logger.Debug("SessionPool::ClearAllPools"); lock (s_sessionPoolLock) { - foreach (SFSession session in _sessions) + foreach (SFSession session in _idleSessions) { session.close(); } - _sessions.Clear(); + _idleSessions.Clear(); } } internal async void ClearAllPoolsAsync() { s_logger.Debug("SessionPool::ClearAllPoolsAsync"); - foreach (SFSession session in _sessions) + foreach (SFSession session in _idleSessions) { await session.CloseAsync(CancellationToken.None).ConfigureAwait(false); } - _sessions.Clear(); + _idleSessions.Clear(); } public void SetMaxPoolSize(int size) @@ -225,7 +252,7 @@ public long GetTimeout() public int GetCurrentPoolSize() { - return _sessions.Count; + return _idleSessions.Count; } public bool SetPooling(bool isEnable)