From ee2bff85a5d4804649b1b9dc225c69aa72121a42 Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Thu, 23 Nov 2023 15:46:28 +0100 Subject: [PATCH] SNOW-937190 Wait for idle sessions available --- .../IntegrationTests/ConnectingThreads.cs | 159 ++++++++++++++++ .../ConnectionMultiplePoolsIT.cs | 76 ++++++++ .../ConnectionPoolCommonIT.cs | 124 +----------- .../ConnectionSinglePoolCacheIT.cs | 85 +++++++++ .../Session/CreateSessionTokenTest.cs | 38 ++++ .../Session/CreateSessionTokensTest.cs | 83 ++++++++ .../UnitTests/Session/NoOneWaitingTest.cs | 55 ++++++ .../NotCountingCreateSessionTokensTest.cs | 51 +++++ .../Session/SemaphoreBasedQueueTest.cs | 139 ++++++++++++++ .../Core/Session/ConnectionPoolManager.cs | 2 +- .../Core/Session/CreateSessionToken.cs | 20 ++ .../Core/Session/CreateSessionTokens.cs | 43 +++++ .../Core/Session/ICreateSessionTokens.cs | 13 ++ Snowflake.Data/Core/Session/IWaitingQueue.cs | 21 +++ Snowflake.Data/Core/Session/NoOneWaiting.cs | 37 ++++ .../Session/NotCountingCreateSessionTokens.cs | 13 ++ .../Core/Session/SemaphoreBasedQueue.cs | 54 ++++++ .../Core/Session/SessionOrCreateToken.cs | 22 +++ Snowflake.Data/Core/Session/SessionPool.cs | 178 ++++++++++++++---- 19 files changed, 1056 insertions(+), 157 deletions(-) create mode 100644 Snowflake.Data.Tests/IntegrationTests/ConnectingThreads.cs create mode 100644 Snowflake.Data.Tests/UnitTests/Session/CreateSessionTokenTest.cs create mode 100644 Snowflake.Data.Tests/UnitTests/Session/CreateSessionTokensTest.cs create mode 100644 Snowflake.Data.Tests/UnitTests/Session/NoOneWaitingTest.cs create mode 100644 Snowflake.Data.Tests/UnitTests/Session/NotCountingCreateSessionTokensTest.cs create mode 100644 Snowflake.Data.Tests/UnitTests/Session/SemaphoreBasedQueueTest.cs create mode 100644 Snowflake.Data/Core/Session/CreateSessionToken.cs create mode 100644 Snowflake.Data/Core/Session/CreateSessionTokens.cs create mode 100644 Snowflake.Data/Core/Session/ICreateSessionTokens.cs create mode 100644 Snowflake.Data/Core/Session/IWaitingQueue.cs create mode 100644 Snowflake.Data/Core/Session/NoOneWaiting.cs create mode 100644 Snowflake.Data/Core/Session/NotCountingCreateSessionTokens.cs create mode 100644 Snowflake.Data/Core/Session/SemaphoreBasedQueue.cs create mode 100644 Snowflake.Data/Core/Session/SessionOrCreateToken.cs diff --git a/Snowflake.Data.Tests/IntegrationTests/ConnectingThreads.cs b/Snowflake.Data.Tests/IntegrationTests/ConnectingThreads.cs new file mode 100644 index 000000000..be33f1ce5 --- /dev/null +++ b/Snowflake.Data.Tests/IntegrationTests/ConnectingThreads.cs @@ -0,0 +1,159 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using Snowflake.Data.Client; + +namespace Snowflake.Data.Tests.IntegrationTests +{ + class ConnectingThreads + { + private string _connectionString; + + private ConcurrentQueue _events = new ConcurrentQueue(); + + private List threads = new List(); + + public ConnectingThreads(string connectionString) + { + _connectionString = connectionString; + } + + public ConnectingThreads NewThread(string name, + long waitBeforeConnectMillis, + long waitAfterConnectMillis, + bool closeOnExit) + { + var thread = new ConnectingThread( + name, + _events, + _connectionString, + waitBeforeConnectMillis, + waitAfterConnectMillis, + closeOnExit).Build(); + threads.Add(thread); + return this; + } + + public ConnectingThreads StartAll() + { + threads.ForEach(thread => thread.Start()); + return this; + } + + public ConnectingThreads JoinAll() + { + threads.ForEach(thread => thread.Join()); + return this; + } + + public IEnumerable Events() => _events.ToArray().OfType(); + } + + class ConnectingThread + { + private string _name; + + private ConcurrentQueue _events; + + private string _connectionString; + + private long _waitBeforeConnectMillis; + + private long _waitAfterConnectMillis; + + private bool _closeOnExit; + + public ConnectingThread( + string name, + ConcurrentQueue events, + string connectionString, + long waitBeforeConnectMillis, + long waitAfterConnectMillis, + bool closeOnExit) + { + _name = name; + _events = events; + _connectionString = connectionString; + _waitBeforeConnectMillis = waitBeforeConnectMillis; + _waitAfterConnectMillis = waitAfterConnectMillis; + _closeOnExit = closeOnExit; + } + + public Thread Build() + { + var thread = new Thread(Execute); + thread.Name = "thread_" + _name; + return thread; + } + + private void Execute() + { + var connection = new SnowflakeDbConnection(); + connection.ConnectionString = _connectionString; + Sleep(_waitBeforeConnectMillis); + var watch = new Stopwatch(); + watch.Start(); + var connected = false; + try + { + connection.Open(); + connected = true; + } + catch (Exception exception) + { + watch.Stop(); + _events.Enqueue(ThreadEvent.EventConnectingFailed(_name, exception, watch.ElapsedMilliseconds)); + } + if (connected) + { + watch.Stop(); + _events.Enqueue(ThreadEvent.EventConnected(_name, watch.ElapsedMilliseconds)); + } + Sleep(_waitAfterConnectMillis); + if (_closeOnExit) + { + connection.Close(); + } + } + + private void Sleep(long millis) + { + if (millis <= 0) + { + return; + } + System.Threading.Thread.Sleep((int) millis); + } + } + + class ThreadEvent + { + public string ThreadName { get; set; } + + public string EventName { get; set; } + + public Exception Error { get; set; } + + public long Timestamp { get; set; } + + public long Duration { get; set; } + + public ThreadEvent(string threadName, string eventName, Exception error, long duration) + { + ThreadName = threadName; + EventName = eventName; + Error = error; + Timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + Duration = duration; + } + + public static ThreadEvent EventConnected(string threadName, long duration) => + new ThreadEvent(threadName, "CONNECTED", null, duration); + + public static ThreadEvent EventConnectingFailed(string threadName, Exception exception, long duration) => + new ThreadEvent(threadName, "FAILED_TO_CONNECT", exception, duration); + } +} diff --git a/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsIT.cs b/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsIT.cs index bfab3f2a6..ffd97b781 100644 --- a/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsIT.cs @@ -1,5 +1,7 @@ using System; using System.Data; +using System.Diagnostics; +using System.Linq; using NUnit.Framework; using Snowflake.Data.Client; using Snowflake.Data.Core.Session; @@ -95,6 +97,72 @@ public void TestReuseSessionInConnectionPoolReachingMaxConnections() // old name Assert.AreEqual(ConnectionState.Closed, conn3.State); Assert.AreEqual(ConnectionState.Closed, conn4.State); } + + [Test] + public void TestWaitForTheIdleConnectionWhenExceedingMaxConnectionsLimit() + { + // arrange + var pool = SnowflakeDbConnectionPool.GetPool(ConnectionString); + pool.SetMaxPoolSize(2); + pool.SetWaitingTimeout(1000); + var conn1 = OpenedConnection(); + var conn2 = OpenedConnection(); + var watch = new Stopwatch(); + + // act + watch.Start(); + var thrown = Assert.Throws(() => OpenedConnection()); + watch.Stop(); + + // assert + Assert.That(thrown.Message, Does.Contain("Unable to connect. Could not obtain a connection from the pool within a given timeout")); + Assert.GreaterOrEqual(watch.ElapsedMilliseconds, 1000); + Assert.LessOrEqual(watch.ElapsedMilliseconds, 1500); + Assert.AreEqual(pool.GetCurrentPoolSize(), 2); + + // cleanup + conn1.Close(); + conn2.Close(); + } + + [Test] + public void TestWaitInAQueueForAnIdleSession() + { + // arrange + var pool = SnowflakeDbConnectionPool.GetPool(ConnectionString); + pool.SetMaxPoolSize(2); + pool.SetWaitingTimeout(3000); + var threads = new ConnectingThreads(ConnectionString) + .NewThread("A", 0, 2000, true) + .NewThread("B", 50, 2000, true) + .NewThread("C", 100, 0, true) + .NewThread("D", 150, 0, true); + var watch = new Stopwatch(); + + // act + watch.Start(); + threads.StartAll().JoinAll(); + watch.Stop(); + + // assert + var events = threads.Events().ToList(); + Assert.AreEqual(4, events.Count); + CollectionAssert.AreEqual( + new[] + { + Tuple.Create("A", "CONNECTED"), + Tuple.Create("B", "CONNECTED"), + Tuple.Create("C", "CONNECTED"), + Tuple.Create("D", "CONNECTED") + }, + events.Select(e => Tuple.Create(e.ThreadName, e.EventName))); + Assert.LessOrEqual(events[0].Duration, 1000); + Assert.LessOrEqual(events[1].Duration, 1000); + Assert.GreaterOrEqual(events[2].Duration, 2000); + Assert.LessOrEqual(events[2].Duration, 3100); + Assert.GreaterOrEqual(events[3].Duration, 2000); + Assert.LessOrEqual(events[3].Duration, 3100); + } [Test] public void TestBusyAndIdleConnectionsCountedInPoolSize() @@ -186,5 +254,13 @@ public void TestNewConnectionPoolClean() Assert.AreEqual(ConnectionState.Closed, conn2.State); Assert.AreEqual(ConnectionState.Closed, conn3.State); } + + private SnowflakeDbConnection OpenedConnection() + { + var connection = new SnowflakeDbConnection(); + connection.ConnectionString = ConnectionString; + connection.Open(); + return connection; + } } } diff --git a/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolCommonIT.cs b/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolCommonIT.cs index 7129e5f5f..a75c19790 100644 --- a/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolCommonIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolCommonIT.cs @@ -2,11 +2,8 @@ * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. */ -using System; using System.Data; -using System.Data.Common; using System.Threading; -using System.Threading.Tasks; using NUnit.Framework; using Snowflake.Data.Core; using Snowflake.Data.Client; @@ -52,90 +49,7 @@ public static void AfterAllTests() { SnowflakeDbConnectionPool.ClearAllPools(); } - - [Test] - // test connection pooling with concurrent connection - public void TestConcurrentConnectionPooling() - { - // add test case name in connection string to make in unique for each test case - string connStr = ConnectionString + ";application=TestConcurrentConnectionPooling"; - ConcurrentPoolingHelper(connStr, true); - } - - [Test] - // test connection pooling with concurrent connection and no close - // call for connection. Connection is closed when Dispose() is called - // by framework. - public void TestConcurrentConnectionPoolingDispose() - { - // add test case name in connection string to make in unique for each test case - string connStr = ConnectionString + ";application=TestConcurrentConnectionPoolingNoClose"; - ConcurrentPoolingHelper(connStr, false); - } - - static void ConcurrentPoolingHelper(string connectionString, bool closeConnection) - { - // thread number a bit larger than pool size so some connections - // would fail on pooling while some connections could success - const int ThreadNum = 12; - // set short pooling timeout to cover the case that connection expired - const int PoolTimeout = 3; - - // reset to default settings in case it changed by other test cases - Assert.AreEqual(true, SnowflakeDbConnectionPool.GetPool(connectionString).GetPooling()); // to instantiate pool - SnowflakeDbConnectionPool.SetMaxPoolSize(10); - SnowflakeDbConnectionPool.SetTimeout(PoolTimeout); - - var threads = new Task[ThreadNum]; - for (int i = 0; i < ThreadNum; i++) - { - threads[i] = Task.Factory.StartNew(() => - { - QueryExecutionThread(connectionString, closeConnection); - }); - } - Task.WaitAll(threads); - } - - // thead to execute query with new connection in a loop - static void QueryExecutionThread(string connectionString, bool closeConnection) - { - for (int i = 0; i < 100; i++) - { - using (DbConnection conn = new SnowflakeDbConnection(connectionString)) - { - conn.Open(); - using (DbCommand cmd = conn.CreateCommand()) - { - cmd.CommandText = "select 1, 2, 3"; - try - { - using (var reader = cmd.ExecuteReader()) - { - while (reader.Read()) - { - for (int j = 0; j < reader.FieldCount; j++) - { - // Process each column as appropriate - reader.GetFieldValue(j); - } - } - } - } - catch (Exception e) - { - Assert.Fail("Caught unexpected exception: " + e); - } - } - - if (closeConnection) - { - conn.Close(); - } - } - } - } - + [Test] public void TestBasicConnectionPool() { @@ -150,42 +64,6 @@ public void TestBasicConnectionPool() Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(ConnectionString).GetCurrentPoolSize()); } - [Test] - public void TestConnectionPoolIsFull() - { - var pool = SnowflakeDbConnectionPool.GetPool(ConnectionString); - SnowflakeDbConnectionPool.SetMaxPoolSize(2); - var conn1 = new SnowflakeDbConnection(); - conn1.ConnectionString = ConnectionString; - conn1.Open(); - Assert.AreEqual(ConnectionState.Open, conn1.State); - - var conn2 = new SnowflakeDbConnection(); - conn2.ConnectionString = ConnectionString; - conn2.Open(); - Assert.AreEqual(ConnectionState.Open, conn2.State); - - var conn3 = new SnowflakeDbConnection(); - conn3.ConnectionString = ConnectionString; - conn3.Open(); - Assert.AreEqual(ConnectionState.Open, conn3.State); - SnowflakeDbConnectionPool.ClearAllPools(); - pool = SnowflakeDbConnectionPool.GetPool(ConnectionString); - SnowflakeDbConnectionPool.SetMaxPoolSize(2); - - conn1.Close(); - Assert.AreEqual(1, pool.GetCurrentPoolSize()); - conn2.Close(); - Assert.AreEqual(2, pool.GetCurrentPoolSize()); - conn3.Close(); - Assert.AreEqual(2, pool.GetCurrentPoolSize()); - - Assert.AreEqual(ConnectionState.Closed, conn1.State); - Assert.AreEqual(ConnectionState.Closed, conn2.State); - Assert.AreEqual(ConnectionState.Closed, conn3.State); - SnowflakeDbConnectionPool.ClearAllPools(); - } - [Test] public void TestConnectionPoolExpirationWorks() { diff --git a/Snowflake.Data.Tests/IntegrationTests/ConnectionSinglePoolCacheIT.cs b/Snowflake.Data.Tests/IntegrationTests/ConnectionSinglePoolCacheIT.cs index 5e0e278c1..42ae22b74 100644 --- a/Snowflake.Data.Tests/IntegrationTests/ConnectionSinglePoolCacheIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/ConnectionSinglePoolCacheIT.cs @@ -1,4 +1,7 @@ +using System; using System.Data; +using System.Data.Common; +using System.Threading.Tasks; using NUnit.Framework; using Snowflake.Data.Client; using Snowflake.Data.Core.Session; @@ -32,6 +35,88 @@ public static void AfterAllTests() SnowflakeDbConnectionPool.ClearAllPools(); } + [Test] + public void TestConcurrentConnectionPooling() + { + // add test case name in connection string to make in unique for each test case + string connStr = ConnectionString + ";application=TestConcurrentConnectionPooling"; + ConcurrentPoolingHelper(connStr, true); + } + + [Test] + // test connection pooling with concurrent connection and no close + // call for connection. Connection is closed when Dispose() is called + // by framework. + public void TestConcurrentConnectionPoolingDispose() + { + // add test case name in connection string to make in unique for each test case + string connStr = ConnectionString + ";application=TestConcurrentConnectionPoolingNoClose"; + ConcurrentPoolingHelper(connStr, false); + } + + static void ConcurrentPoolingHelper(string connectionString, bool closeConnection) + { + // thread number a bit larger than pool size so some connections + // would fail on pooling while some connections could success + const int ThreadNum = 12; + // set short pooling timeout to cover the case that connection expired + const int PoolTimeout = 3; + + // reset to default settings in case it changed by other test cases + Assert.AreEqual(true, SnowflakeDbConnectionPool.GetPool(connectionString).GetPooling()); // to instantiate pool + SnowflakeDbConnectionPool.SetMaxPoolSize(10); + SnowflakeDbConnectionPool.SetTimeout(PoolTimeout); + + var threads = new Task[ThreadNum]; + for (int i = 0; i < ThreadNum; i++) + { + threads[i] = Task.Factory.StartNew(() => + { + QueryExecutionThread(connectionString, closeConnection); + }); + } + Task.WaitAll(threads); + } + + // thead to execute query with new connection in a loop + static void QueryExecutionThread(string connectionString, bool closeConnection) + { + for (int i = 0; i < 100; i++) + { + using (DbConnection conn = new SnowflakeDbConnection(connectionString)) + { + conn.Open(); + using (DbCommand cmd = conn.CreateCommand()) + { + cmd.CommandText = "select 1, 2, 3"; + try + { + using (var reader = cmd.ExecuteReader()) + { + while (reader.Read()) + { + for (int j = 0; j < reader.FieldCount; j++) + { + // Process each column as appropriate + reader.GetFieldValue(j); + } + } + } + } + catch (Exception e) + { + Assert.Fail("Caught unexpected exception: " + e); + } + } + + if (closeConnection) + { + conn.Close(); + } + } + } + } + [Test] public void TestPoolContainsClosedConnections() // old name: TestConnectionPool { diff --git a/Snowflake.Data.Tests/UnitTests/Session/CreateSessionTokenTest.cs b/Snowflake.Data.Tests/UnitTests/Session/CreateSessionTokenTest.cs new file mode 100644 index 000000000..17b48403d --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Session/CreateSessionTokenTest.cs @@ -0,0 +1,38 @@ +using System; +using NUnit.Framework; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Tests.UnitTests.Session +{ + [TestFixture] + public class CreateSessionTokenTest + { + private readonly long _timeout = 30000; // 30 seconds in millis + + [Test] + public void TestTokenIsNotExpired() + { + // arrange + var token = new CreateSessionToken(_timeout); + + // act + var isExpired = token.IsExpired(DateTimeOffset.UtcNow.ToUnixTimeMilliseconds()); + + // assert + Assert.IsFalse(isExpired); + } + + [Test] + public void TestTokenIsExpired() + { + // arrange + var token = new CreateSessionToken(_timeout); + + // act + var isExpired = token.IsExpired(DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() + _timeout + 1); + + // assert + Assert.IsTrue(isExpired); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Session/CreateSessionTokensTest.cs b/Snowflake.Data.Tests/UnitTests/Session/CreateSessionTokensTest.cs new file mode 100644 index 000000000..bc341df44 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Session/CreateSessionTokensTest.cs @@ -0,0 +1,83 @@ +using System.Threading; +using NUnit.Framework; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Tests.UnitTests.Session +{ + [TestFixture] + public class CreateSessionTokensTest + { + [Test] + public void TestGrantSessionCreation() + { + // arrange + var tokens = new CreateSessionTokens(); + + // act + tokens.BeginCreate(); + + // assert + Assert.AreEqual(1, tokens.Count()); + + // act + tokens.BeginCreate(); + + // assert + Assert.AreEqual(2, tokens.Count()); + } + + [Test] + public void TestCompleteSessionCreation() + { + // arrange + var tokens = new CreateSessionTokens(); + var token1 = tokens.BeginCreate(); + var token2 = tokens.BeginCreate(); + + // act + tokens.EndCreate(token1); + + // assert + Assert.AreEqual(1, tokens.Count()); + + // act + tokens.EndCreate(token2); + + // assert + Assert.AreEqual(0, tokens.Count()); + } + + [Test] + public void TestCompleteUnknownTokenDoesNotThrowExceptions() + { + // arrange + var tokens = new CreateSessionTokens(); + tokens.BeginCreate(); + var unknownToken = new CreateSessionToken(0); + + // act + tokens.EndCreate(unknownToken); + + // assert + Assert.AreEqual(1, tokens.Count()); + } + + [Test] + public void TestCompleteCleansExpiredTokens() + { + // arrange + var tokens = new CreateSessionTokens(); + tokens._timeout = 50; + var token = tokens.BeginCreate(); + tokens.BeginCreate(); // this token will be cleaned because of expiration + Assert.AreEqual(2, tokens.Count()); + Thread.Sleep((int) tokens._timeout); + + // act + tokens.EndCreate(token); + + // assert + Assert.AreEqual(0, tokens.Count()); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Session/NoOneWaitingTest.cs b/Snowflake.Data.Tests/UnitTests/Session/NoOneWaitingTest.cs new file mode 100644 index 000000000..b3f9e4c50 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Session/NoOneWaitingTest.cs @@ -0,0 +1,55 @@ +using System.Diagnostics; +using System.Threading; +using NUnit.Framework; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Tests.UnitTests.Session +{ + [TestFixture] + public class NoOneWaitingTest + { + [Test] + public void TestWaitDoesNotHangAndReturnsFalse() + { + // arrange + var noOneWaiting = new NoOneWaiting(); + var watch = new Stopwatch(); + + // act + watch.Start(); + var result = noOneWaiting.Wait(10000, CancellationToken.None); + watch.Stop(); + + // assert + Assert.IsFalse(result); + Assert.LessOrEqual(watch.ElapsedMilliseconds, 50); + } + + [Test] + public void TestNoOneIsWaiting() + { + // arrange + var noOneWaiting = new NoOneWaiting(); + noOneWaiting.Wait(10000, CancellationToken.None); + + // act + var isAnyoneWaiting = noOneWaiting.IsAnyoneWaiting(); + + // assert + Assert.IsFalse(isAnyoneWaiting); + } + + [Test] + public void TestWaitingDisabled() + { + // arrange + var noOneWaiting = new NoOneWaiting(); + + // act + var isWaitingEnabled = noOneWaiting.IsWaitingEnabled(); + + // assert + Assert.IsFalse(isWaitingEnabled); + } + } +} \ No newline at end of file diff --git a/Snowflake.Data.Tests/UnitTests/Session/NotCountingCreateSessionTokensTest.cs b/Snowflake.Data.Tests/UnitTests/Session/NotCountingCreateSessionTokensTest.cs new file mode 100644 index 000000000..7be440d23 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Session/NotCountingCreateSessionTokensTest.cs @@ -0,0 +1,51 @@ +using NUnit.Framework; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Tests.UnitTests.Session +{ + [TestFixture] + public class NotCountingCreateSessionTokensTest + { + [Test] + public void TestGrantSessionCreation() + { + // arrange + var tokens = new NotCountingCreateSessionTokens(); + + // act + tokens.BeginCreate(); + + // assert + Assert.AreEqual(0, tokens.Count()); + } + + [Test] + public void TestCompleteSessionCreation() + { + // arrange + var tokens = new NotCountingCreateSessionTokens(); + var token = tokens.BeginCreate(); + + // act + tokens.EndCreate(token); + + // assert + Assert.AreEqual(0, tokens.Count()); + } + + [Test] + public void TestCompleteUnknownTokenDoesNotThrowExceptions() + { + // arrange + var tokens = new NotCountingCreateSessionTokens(); + tokens.BeginCreate(); + var unknownToken = new CreateSessionToken(0); + + // act + tokens.EndCreate(unknownToken); + + // assert + Assert.AreEqual(0, tokens.Count()); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Session/SemaphoreBasedQueueTest.cs b/Snowflake.Data.Tests/UnitTests/Session/SemaphoreBasedQueueTest.cs new file mode 100644 index 000000000..9fdd2143c --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Session/SemaphoreBasedQueueTest.cs @@ -0,0 +1,139 @@ +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using NUnit.Framework; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Tests.UnitTests.Session +{ + [TestFixture] + public class SemaphoreBasedQueueTest + { + [Test] + public void TestWaitForTheResourceUntilTimeout() + { + // arrange + var queue = new SemaphoreBasedQueue(); + var watch = new Stopwatch(); + + // act + watch.Start(); + var result = queue.Wait(50, CancellationToken.None); + watch.Stop(); + + // assert + Assert.IsFalse(result); + Assert.GreaterOrEqual(watch.ElapsedMilliseconds, 50); + Assert.LessOrEqual(watch.ElapsedMilliseconds, 100); + } + + [Test] + public void TestWaitForTheResourceUntilCancellation() + { + // arrange + var queue = new SemaphoreBasedQueue(); + var cancellationSource = new CancellationTokenSource(50); + var watch = new Stopwatch(); + + // act + watch.Start(); + var result = queue.Wait(30000, cancellationSource.Token); + watch.Stop(); + + // assert + Assert.IsFalse(result); + Assert.GreaterOrEqual(watch.ElapsedMilliseconds, 50); + Assert.LessOrEqual(watch.ElapsedMilliseconds, 100); + } + + [Test] + public void TestWaitUntilResourceAvailable() + { + // arrange + var queue = new SemaphoreBasedQueue(); + var watch = new Stopwatch(); + Task.Run(() => + { + Thread.Sleep(50); + queue.OnResourceIncrease(); + }); + + // act + watch.Start(); + var result = queue.Wait(30000, CancellationToken.None); + watch.Stop(); + + // assert + Assert.IsTrue(result); + Assert.GreaterOrEqual(watch.ElapsedMilliseconds, 50); + Assert.LessOrEqual(watch.ElapsedMilliseconds, 500); + } + + [Test] + public void TestWaitingEnabled() + { + // arrange + var queue = new SemaphoreBasedQueue(); + + // act + var isWaitingEnabled = queue.IsWaitingEnabled(); + + // assert + Assert.IsTrue(isWaitingEnabled); + } + + [Test] + public void TestNoOneIsWaiting() + { + // arrange + var queue = new SemaphoreBasedQueue(); + + // act + var isAnyoneWaiting = queue.IsAnyoneWaiting(); + + // assert + Assert.IsFalse(isAnyoneWaiting); + } + + [Test] + public void TestSomeoneIsWaiting() + { + // arrange + var queue = new SemaphoreBasedQueue(); + var syncThreadsSemaphore = new SemaphoreSlim(0, 1); + Task.Run(() => + { + syncThreadsSemaphore.Release(); + return queue.Wait(1000, CancellationToken.None); + }); + syncThreadsSemaphore.Wait(10000); // make sure scheduled thread execution has started + Thread.Sleep(50); + + // act + var isAnyoneWaiting = queue.IsAnyoneWaiting(); + + // assert + Assert.IsTrue(isAnyoneWaiting); + } + + [Test] + public void TestDecreaseResources() + { + // arrange + var queue = new SemaphoreBasedQueue(); + queue.OnResourceIncrease(); + var watch = new Stopwatch(); + + // act + queue.OnResourceDecrease(); + watch.Start(); + var result = queue.Wait(50, CancellationToken.None); + watch.Stop(); + + // assert + Assert.IsFalse(result); + Assert.GreaterOrEqual(watch.ElapsedMilliseconds, 50); + Assert.LessOrEqual(watch.ElapsedMilliseconds, 500); + } + } +} diff --git a/Snowflake.Data/Core/Session/ConnectionPoolManager.cs b/Snowflake.Data/Core/Session/ConnectionPoolManager.cs index a6d9c7294..1e5147b8a 100644 --- a/Snowflake.Data/Core/Session/ConnectionPoolManager.cs +++ b/Snowflake.Data/Core/Session/ConnectionPoolManager.cs @@ -144,7 +144,7 @@ public SessionPool GetPool(string connectionString) return GetPool(connectionString, null); } - // TODO: SNOW-937188 + // TODO: SNOW-937188 private string GetPoolKey(string connectionString) { return connectionString; diff --git a/Snowflake.Data/Core/Session/CreateSessionToken.cs b/Snowflake.Data/Core/Session/CreateSessionToken.cs new file mode 100644 index 000000000..00c5ea643 --- /dev/null +++ b/Snowflake.Data/Core/Session/CreateSessionToken.cs @@ -0,0 +1,20 @@ +using System; + +namespace Snowflake.Data.Core.Session +{ + internal class CreateSessionToken + { + public Guid Id { get; } + private readonly long _grantedAtAsEpocMillis; + private readonly long _timeout; + + public CreateSessionToken(long timeout) + { + Id = Guid.NewGuid(); + _grantedAtAsEpocMillis = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + _timeout = timeout; + } + + public bool IsExpired(long nowMillis) => nowMillis > _grantedAtAsEpocMillis + _timeout; + } +} diff --git a/Snowflake.Data/Core/Session/CreateSessionTokens.cs b/Snowflake.Data/Core/Session/CreateSessionTokens.cs new file mode 100644 index 000000000..940025446 --- /dev/null +++ b/Snowflake.Data/Core/Session/CreateSessionTokens.cs @@ -0,0 +1,43 @@ +using System; +using System.Collections.Generic; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Core.Session +{ + internal class CreateSessionTokens: ICreateSessionTokens + { + internal long _timeout { get; set; } = Timeout; + private const long Timeout = 30000; // 30 seconds as default + private readonly object _tokenLock = new object(); + private readonly List _tokens = new List(); + private int _tokenCount = 0; + + public CreateSessionToken BeginCreate() + { + lock (_tokenLock) + { + var token = new CreateSessionToken(_timeout); + _tokens.Add(token); + var now = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + _tokens.RemoveAll(t => t.IsExpired(now)); + _tokenCount = _tokens.Count; + return token; + } + } + + public void EndCreate(CreateSessionToken token) + { + lock (_tokenLock) + { + var now = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + _tokens.RemoveAll(t => token.Id == t.Id || t.IsExpired(now)); + _tokenCount = _tokens.Count; + } + } + + public int Count() + { + return _tokenCount; + } + } +} \ No newline at end of file diff --git a/Snowflake.Data/Core/Session/ICreateSessionTokens.cs b/Snowflake.Data/Core/Session/ICreateSessionTokens.cs new file mode 100644 index 000000000..edba11181 --- /dev/null +++ b/Snowflake.Data/Core/Session/ICreateSessionTokens.cs @@ -0,0 +1,13 @@ +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Core.Session +{ + internal interface ICreateSessionTokens + { + CreateSessionToken BeginCreate(); + + void EndCreate(CreateSessionToken token); + + int Count(); + } +} \ No newline at end of file diff --git a/Snowflake.Data/Core/Session/IWaitingQueue.cs b/Snowflake.Data/Core/Session/IWaitingQueue.cs new file mode 100644 index 000000000..b8a6c673b --- /dev/null +++ b/Snowflake.Data/Core/Session/IWaitingQueue.cs @@ -0,0 +1,21 @@ +using System.Threading; + +namespace Snowflake.Data.Core.Session +{ + public interface IWaitingQueue + { + bool Wait(int millisecondsTimeout, CancellationToken cancellationToken); + + void OnResourceIncrease(); + + void OnResourceDecrease(); + + bool IsAnyoneWaiting(); + + bool IsWaitingEnabled(); + + long GetWaitingTimeoutMillis(); + + void SetWaitingTimeout(long timeoutMillis); + } +} diff --git a/Snowflake.Data/Core/Session/NoOneWaiting.cs b/Snowflake.Data/Core/Session/NoOneWaiting.cs new file mode 100644 index 000000000..ff0b4ef39 --- /dev/null +++ b/Snowflake.Data/Core/Session/NoOneWaiting.cs @@ -0,0 +1,37 @@ +using System.Threading; + +namespace Snowflake.Data.Core.Session +{ + public class NoOneWaiting: IWaitingQueue + { + public bool Wait(int millisecondsTimeout, CancellationToken cancellationToken) + { + return false; + } + + public void OnResourceIncrease() + { + } + + public void OnResourceDecrease() + { + } + + public bool IsAnyoneWaiting() + { + return false; + } + + public bool IsWaitingEnabled() + { + return false; + } + + public long GetWaitingTimeoutMillis() => 0; + + public void SetWaitingTimeout(long timeoutMillis) + { + throw new System.NotImplementedException(); + } + } +} diff --git a/Snowflake.Data/Core/Session/NotCountingCreateSessionTokens.cs b/Snowflake.Data/Core/Session/NotCountingCreateSessionTokens.cs new file mode 100644 index 000000000..eae65a06e --- /dev/null +++ b/Snowflake.Data/Core/Session/NotCountingCreateSessionTokens.cs @@ -0,0 +1,13 @@ +namespace Snowflake.Data.Core.Session +{ + internal class NotCountingCreateSessionTokens: ICreateSessionTokens + { + public CreateSessionToken BeginCreate() => new CreateSessionToken(0); + + public void EndCreate(CreateSessionToken token) + { + } + + public int Count() => 0; + } +} diff --git a/Snowflake.Data/Core/Session/SemaphoreBasedQueue.cs b/Snowflake.Data/Core/Session/SemaphoreBasedQueue.cs new file mode 100644 index 000000000..7426723a3 --- /dev/null +++ b/Snowflake.Data/Core/Session/SemaphoreBasedQueue.cs @@ -0,0 +1,54 @@ +using System; +using System.Threading; + +namespace Snowflake.Data.Core.Session +{ + public class SemaphoreBasedQueue: IWaitingQueue + { + private readonly SemaphoreSlim _semaphore = new SemaphoreSlim(0, 1000); // TODO: how not to set it? + private readonly object _lock = new object(); + private int _waitingCount = 0; + private long _waitingTimeoutMillis = 30000; // 30 seconds as default + + public bool Wait(int millisecondsTimeout, CancellationToken cancellationToken) + { + lock (_lock) + { + _waitingCount++; + } + try + { + return _semaphore.Wait(millisecondsTimeout, cancellationToken); + } + catch (OperationCanceledException exception) + { + return false; + } + finally + { + lock (_lock) + { + _waitingCount--; + } + } + } + + public void OnResourceIncrease() + { + _semaphore.Release(1); + } + + public void OnResourceDecrease() + { + _semaphore.Wait(0, CancellationToken.None); + } + + public bool IsAnyoneWaiting() => _waitingCount > 0; + + public bool IsWaitingEnabled() => true; + + public long GetWaitingTimeoutMillis() => _waitingTimeoutMillis; + + public void SetWaitingTimeout(long timeoutMillis) => _waitingTimeoutMillis = timeoutMillis; + } +} diff --git a/Snowflake.Data/Core/Session/SessionOrCreateToken.cs b/Snowflake.Data/Core/Session/SessionOrCreateToken.cs new file mode 100644 index 000000000..1f3ffea8c --- /dev/null +++ b/Snowflake.Data/Core/Session/SessionOrCreateToken.cs @@ -0,0 +1,22 @@ +using System; + +namespace Snowflake.Data.Core.Session +{ + internal class SessionOrCreateToken + { + public SFSession Session { get; } + public CreateSessionToken CreateToken { get; } + + public SessionOrCreateToken(SFSession session) + { + Session = session ?? throw new Exception("Internal error: missing session"); + CreateToken = null; + } + + public SessionOrCreateToken(CreateSessionToken createToken) + { + Session = null; + CreateToken = createToken ?? throw new Exception("Internal error: missing create token"); + } + } +} diff --git a/Snowflake.Data/Core/Session/SessionPool.cs b/Snowflake.Data/Core/Session/SessionPool.cs index ee30664b1..63a792b83 100644 --- a/Snowflake.Data/Core/Session/SessionPool.cs +++ b/Snowflake.Data/Core/Session/SessionPool.cs @@ -20,6 +20,8 @@ sealed class SessionPool : IDisposable private static ISessionFactory s_sessionFactory = new SessionFactory(); private readonly List _idleSessions; + private readonly IWaitingQueue _waitingQueue; + private readonly ICreateSessionTokens _createSessionTokens; private int _maxPoolSize; private long _timeout; private const int MaxPoolSize = 10; @@ -28,7 +30,6 @@ sealed class SessionPool : IDisposable internal SecureString Password { get; } private bool _pooling = true; private readonly ICounter _busySessionsCounter; - private readonly bool _allowExceedMaxPoolSize = true; private SessionPool() { @@ -37,6 +38,8 @@ private SessionPool() _maxPoolSize = MaxPoolSize; _timeout = Timeout; _busySessionsCounter = new FixedZeroCounter(); + _waitingQueue = new NoOneWaiting(); + _createSessionTokens = new NotCountingCreateSessionTokens(); } private SessionPool(string connectionString, SecureString password) @@ -48,7 +51,8 @@ private SessionPool(string connectionString, SecureString password) _busySessionsCounter = new NonNegativeCounter(); ConnectionString = connectionString; Password = password; - _allowExceedMaxPoolSize = false; // TODO: SNOW-937190 + _waitingQueue = new SemaphoreBasedQueue(); + _createSessionTokens = new CreateSessionTokens(); } internal static SessionPool CreateSessionCache() => new SessionPool(); @@ -85,6 +89,7 @@ private void CleanExpiredSessions() if (item.IsExpired(_timeout, timeNow)) { _idleSessions.Remove(item); + _waitingQueue.OnResourceDecrease(); item.close(); } } @@ -95,18 +100,20 @@ internal SFSession GetSession(string connStr, SecureString password) { s_logger.Debug("SessionPool::GetSession"); if (!_pooling) - return NewSession(connStr, password); - SFSession session = GetIdleSession(connStr); - return session ?? NewSession(connStr, password); + return NewSession(connStr, password, new CreateSessionToken(0)); + var sessionOrCreateToken = GetIdleSession(connStr); + return sessionOrCreateToken.Session ?? NewSession(connStr, password, sessionOrCreateToken.CreateToken); } internal Task GetSessionAsync(string connStr, SecureString password, CancellationToken cancellationToken) { s_logger.Debug("SessionPool::GetSessionAsync"); if (!_pooling) - return NewSessionAsync(connStr, password, cancellationToken); - SFSession session = GetIdleSession(connStr); - return session != null ? Task.FromResult(session) : NewSessionAsync(connStr, password, cancellationToken); + return NewSessionAsync(connStr, password, new CreateSessionToken(0), cancellationToken); + var sessionOrCreateToken = GetIdleSession(connStr); + return sessionOrCreateToken.Session != null + ? Task.FromResult(sessionOrCreateToken.Session) + : NewSessionAsync(connStr, password, sessionOrCreateToken.CreateToken, cancellationToken); } internal SFSession GetSession() => GetSession(ConnectionString, Password); @@ -114,49 +121,133 @@ internal Task GetSessionAsync(string connStr, SecureString password, internal Task GetSessionAsync(CancellationToken cancellationToken) => GetSessionAsync(ConnectionString, Password, cancellationToken); - private SFSession GetIdleSession(string connStr) + private SessionOrCreateToken GetIdleSession(string connStr) { s_logger.Debug("SessionPool::GetIdleSession"); lock (_sessionPoolLock) { - for (int i = 0; i < _idleSessions.Count; i++) + if (_waitingQueue.IsAnyoneWaiting()) { - if (_idleSessions[i].ConnectionString.Equals(connStr)) + s_logger.Debug("SessionPool::GetIdleSession - someone is already waiting for a session, request is going to be queued"); + } + else + { + var session = ExtractIdleSession(connStr); + if (session != null) { - SFSession session = _idleSessions[i]; - _idleSessions.RemoveAt(i); - long timeNow = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); - if (session.IsExpired(_timeout, timeNow)) - { - session.close(); - i--; - } - else - { - s_logger.Debug($"reuse pooled session with sid {session.sessionId}"); - _busySessionsCounter.Increase(); - return session; - } + s_logger.Debug("SessionPool::GetIdleSession - no one was waiting for a session, a session was extracted from idle sessions"); + _waitingQueue.OnResourceDecrease(); + return new SessionOrCreateToken(session); + } + s_logger.Debug("SessionPool::GetIdleSession - no one was waiting for session, but could not find any idle session available"); + if (IsAllowedToCreateNewSession()) + { + // there is no need to wait for a session since we can create a new one + return new SessionOrCreateToken(_createSessionTokens.BeginCreate()); + } + } + } + return new SessionOrCreateToken(WaitForSession(connStr)); + } + + private bool IsAllowedToCreateNewSession() + { + if (!_waitingQueue.IsWaitingEnabled()) + { + s_logger.Debug($"SessionPool - new session creation granted"); + return true; + } + var currentSize = GetCurrentPoolSize(); + if (currentSize < _maxPoolSize) + { + s_logger.Debug($"SessionPool - new session creation granted because current size is {currentSize} out of {_maxPoolSize}"); + return true; + } + s_logger.Debug($"SessionPool - could not grant new session creation because current size is {currentSize} out of {_maxPoolSize}"); + return false; + } + + private SFSession WaitForSession(string connStr) + { + var timeout = _waitingQueue.GetWaitingTimeoutMillis(); + s_logger.Debug($"SessionPool::WaitForSession for {timeout} millis timeout"); + var beforeWaitingTime = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); + long nowTime = beforeWaitingTime; + while (nowTime < beforeWaitingTime + timeout) // we loop to handle the case if someone overtook us after being woken or session which we were promised has just expired + { + var timeoutLeft = beforeWaitingTime + timeout - nowTime; + var successful = _waitingQueue.Wait((int) timeoutLeft, CancellationToken.None); + if (!successful) + { + s_logger.Debug($"SessionPool::WaitForSession - woken without a session granted"); + throw WaitingFailedException(); + } + s_logger.Debug($"SessionPool::WaitForSession - woken with a session granted"); + lock (_sessionPoolLock) + { + var session = ExtractIdleSession(connStr); + if (session != null) + { + s_logger.Debug($"SessionPool::WaitForSession - a session was extracted from idle sessions"); + return session; + } + } + nowTime = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); + } + s_logger.Debug($"SessionPool::WaitForSession - could not find any idle session available withing a given timeout"); + throw WaitingFailedException(); + } + + private static Exception WaitingFailedException() => new Exception("Could not obtain a connection from the pool within a given timeout"); + + private SFSession ExtractIdleSession(string connStr) + { + for (int i = 0; i < _idleSessions.Count; i++) + { + if (_idleSessions[i].ConnectionString.Equals(connStr)) + { + SFSession session = _idleSessions[i]; + _idleSessions.RemoveAt(i); // we don't do _waitingQueue.OnResourceDecrease() here because it happens in GetIdleSession() + long timeNow = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); + if (session.IsExpired(_timeout, timeNow)) + { + session.close(); // TODO: change to session required + _waitingQueue.OnResourceDecrease(); + i--; + } + else + { + s_logger.Debug($"reuse pooled session with sid {session.sessionId}"); + _busySessionsCounter.Increase(); + return session; } } } return null; } - private SFSession NewSession(String connectionString, SecureString password) + private SFSession NewSession(String connectionString, SecureString password, CreateSessionToken createSessionToken) { s_logger.Debug("SessionPool::NewSession"); try { var session = s_sessionFactory.NewSession(connectionString, password); session.Open(); + s_logger.Debug("SessionPool::NewSession - opened"); if (_pooling) - _busySessionsCounter.Increase(); + { + lock (_sessionPoolLock) + { + _createSessionTokens.EndCreate(createSessionToken); + _busySessionsCounter.Increase(); + } + } return session; } catch (Exception e) { // Otherwise when Dispose() is called, the close request would timeout. + _createSessionTokens.EndCreate(createSessionToken); if (e is SnowflakeDbException) throw; throw new SnowflakeDbException( @@ -167,7 +258,7 @@ private SFSession NewSession(String connectionString, SecureString password) } } - private Task NewSessionAsync(String connectionString, SecureString password, CancellationToken cancellationToken) + private Task NewSessionAsync(String connectionString, SecureString password, CreateSessionToken createSessionToken, CancellationToken cancellationToken) { s_logger.Debug("SessionPool::NewSessionAsync"); var session = s_sessionFactory.NewSession(connectionString, password); @@ -175,6 +266,11 @@ private Task NewSessionAsync(String connectionString, SecureString pa .OpenAsync(cancellationToken) .ContinueWith(previousTask => { + if (previousTask.IsFaulted) + { + _createSessionTokens.EndCreate(createSessionToken); + } + if (previousTask.IsFaulted && previousTask.Exception != null) throw previousTask.Exception; @@ -185,10 +281,16 @@ private Task NewSessionAsync(String connectionString, SecureString pa "Failure while opening session async"); if (_pooling) - _busySessionsCounter.Increase(); + { + lock (_sessionPoolLock) + { + _createSessionTokens.EndCreate(createSessionToken); + _busySessionsCounter.Increase(); + } + } return session; - }, TaskContinuationOptions.NotOnCanceled); + }, TaskContinuationOptions.None); // previously it was NotOnCanceled but we would like to execute it even in case of cancellation to properly update counters } internal bool AddSession(SFSession session) @@ -218,6 +320,7 @@ internal bool AddSession(SFSession session) s_logger.Debug($"pool connection with sid {session.sessionId}"); _idleSessions.Add(session); + _waitingQueue.OnResourceIncrease(); return true; } } @@ -231,6 +334,7 @@ internal void ClearIdleSessions() { session.close(); } + _idleSessions.ForEach(session => _waitingQueue.OnResourceDecrease()); _idleSessions.Clear(); } } @@ -238,11 +342,17 @@ internal void ClearIdleSessions() internal async void ClearAllPoolsAsync() { s_logger.Debug("SessionPool::ClearAllPoolsAsync"); - foreach (SFSession session in _idleSessions) + IEnumerable idleSessionsCopy; + lock (_sessionPoolLock) + { + idleSessionsCopy = _idleSessions.Select(session => session); + _idleSessions.ForEach(session => _waitingQueue.OnResourceDecrease()); + _idleSessions.Clear(); + } + foreach (SFSession session in idleSessionsCopy) { await session.CloseAsync(CancellationToken.None).ConfigureAwait(false); } - _idleSessions.Clear(); } public void SetMaxPoolSize(int size) @@ -267,7 +377,7 @@ public long GetTimeout() public int GetCurrentPoolSize() { - return _idleSessions.Count + _busySessionsCounter.Count(); + return _idleSessions.Count + _busySessionsCounter.Count() + _createSessionTokens.Count(); } public bool SetPooling(bool isEnable) @@ -287,5 +397,7 @@ public bool GetPooling() { return _pooling; } + + public void SetWaitingTimeout(long timeoutMillis) => _waitingQueue.SetWaitingTimeout(timeoutMillis); } }