diff --git a/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs b/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs index da9de769b..265c05f63 100644 --- a/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs +++ b/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs @@ -2,6 +2,7 @@ * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. */ +using System; using System.Security; using System.Threading; using System.Threading.Tasks; @@ -14,72 +15,103 @@ namespace Snowflake.Data.Client public class SnowflakeDbConnectionPool { private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); - private static readonly IConnectionManager s_connectionManager = new ConnectionManagerV1(); + private static readonly Object s_connectionManagerInstanceLock = new Object(); + private static IConnectionManager s_connectionManager; + private static IConnectionManager ConnectionManager + { + get + { + if (s_connectionManager != null) + return s_connectionManager; + lock (s_connectionManagerInstanceLock) + { + s_connectionManager = new ConnectionManagerV1(); // old implementation of the pool as a default + } + return s_connectionManager; + } + } internal static SFSession GetSession(string connectionString, SecureString password) { s_logger.Debug("SnowflakeDbConnectionPool::GetSession"); - return s_connectionManager.GetSession(connectionString, password); + return ConnectionManager.GetSession(connectionString, password); } internal static Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken) { s_logger.Debug("SnowflakeDbConnectionPool::GetSessionAsync"); - return s_connectionManager.GetSessionAsync(connectionString, password, cancellationToken); + return ConnectionManager.GetSessionAsync(connectionString, password, cancellationToken); } internal static bool AddSession(SFSession session) { s_logger.Debug("SnowflakeDbConnectionPool::AddSession"); - return s_connectionManager.AddSession(session); + return ConnectionManager.AddSession(session); } public static void ClearAllPools() { s_logger.Debug("SnowflakeDbConnectionPool::ClearAllPools"); - s_connectionManager.ClearAllPools(); + ConnectionManager.ClearAllPools(); } public static void SetMaxPoolSize(int maxPoolSize) { s_logger.Debug("SnowflakeDbConnectionPool::SetMaxPoolSize"); - s_connectionManager.SetMaxPoolSize(maxPoolSize); + ConnectionManager.SetMaxPoolSize(maxPoolSize); } public static int GetMaxPoolSize() { s_logger.Debug("SnowflakeDbConnectionPool::GetMaxPoolSize"); - return s_connectionManager.GetMaxPoolSize(); + return ConnectionManager.GetMaxPoolSize(); } public static void SetTimeout(long connectionTimeout) { s_logger.Debug("SnowflakeDbConnectionPool::SetTimeout"); - s_connectionManager.SetTimeout(connectionTimeout); + ConnectionManager.SetTimeout(connectionTimeout); } public static long GetTimeout() { s_logger.Debug("SnowflakeDbConnectionPool::GetTimeout"); - return s_connectionManager.GetTimeout(); + return ConnectionManager.GetTimeout(); } public static int GetCurrentPoolSize() { s_logger.Debug("SnowflakeDbConnectionPool::GetCurrentPoolSize"); - return s_connectionManager.GetCurrentPoolSize(); + return ConnectionManager.GetCurrentPoolSize(); } public static bool SetPooling(bool isEnable) { s_logger.Debug("SnowflakeDbConnectionPool::SetPooling"); - return s_connectionManager.SetPooling(isEnable); + return ConnectionManager.SetPooling(isEnable); } public static bool GetPooling() { s_logger.Debug("SnowflakeDbConnectionPool::GetPooling"); - return s_connectionManager.GetPooling(); + return ConnectionManager.GetPooling(); + } + + internal static void SwapVersion() + { + lock (s_connectionManagerInstanceLock) + { + if (ConnectionManager is ConnectionManagerV1) + { + s_connectionManager.ClearAllPools(); + s_connectionManager = new ConnectionManagerV2(); + } + if (ConnectionManager is ConnectionManagerV2) + { + s_connectionManager.ClearAllPools(); + s_connectionManager = new ConnectionManagerV1(); + } + } } } } diff --git a/Snowflake.Data/Core/Session/ConnectionManagerV1.cs b/Snowflake.Data/Core/Session/ConnectionManagerV1.cs index 0ad0520e0..e93f176a4 100644 --- a/Snowflake.Data/Core/Session/ConnectionManagerV1.cs +++ b/Snowflake.Data/Core/Session/ConnectionManagerV1.cs @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. + */ + using System.Security; using System.Threading; using System.Threading.Tasks; @@ -6,7 +10,7 @@ namespace Snowflake.Data.Core.Session { internal sealed class ConnectionManagerV1 : IConnectionManager { - private readonly SessionPool _sessionPool = new SessionPool(); + private readonly SessionPool _sessionPool = SessionPool.CreateSessionPoolV1(); public SFSession GetSession(string connectionString, SecureString password) => _sessionPool.GetSession(connectionString, password); public Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken) => _sessionPool.GetSessionAsync(connectionString, password, cancellationToken); diff --git a/Snowflake.Data/Core/Session/ConnectionManagerV2.cs b/Snowflake.Data/Core/Session/ConnectionManagerV2.cs new file mode 100644 index 000000000..e0042ac53 --- /dev/null +++ b/Snowflake.Data/Core/Session/ConnectionManagerV2.cs @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Collections.Generic; +using System.Security; +using System.Threading; +using System.Threading.Tasks; +using Snowflake.Data.Log; + +namespace Snowflake.Data.Core.Session +{ + internal sealed class ConnectionManagerV2 : IConnectionManager + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + private static readonly Object s_poolsLock = new Object(); + private readonly Dictionary _pools; + + internal ConnectionManagerV2() + { + lock (s_poolsLock) + { + _pools = new Dictionary(); + } + } + + public SFSession GetSession(string connectionString, SecureString password) + => GetPool(connectionString, password).GetSession(); + + public Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken) + => GetPool(connectionString, password).GetSessionAsync(cancellationToken); + + public bool AddSession(SFSession session) + => GetPool(session.ConnectionString, session.Password).AddSession(session); + + public void ClearAllPools() + { + foreach (var sessionPool in _pools.Values) + { + sessionPool.ClearAllPools(); + } + } + + public void SetMaxPoolSize(int maxPoolSize) + { + foreach (var pool in _pools.Values) + { + pool.SetMaxPoolSize(maxPoolSize); + } + } + + public int GetMaxPoolSize() => throw ApiNotSupportedException(); + + public void SetTimeout(long connectionTimeout) + { + foreach (var pool in _pools.Values) + { + pool.SetTimeout(connectionTimeout); + } + } + + public long GetTimeout() => throw ApiNotSupportedException(); + + public int GetCurrentPoolSize() => throw ApiNotSupportedException(); + + public bool SetPooling(bool poolingEnabled) + { + bool switched = true; + foreach (var pool in _pools.Values) + { + if (!pool.SetPooling(poolingEnabled)) + switched = false; + } + return switched; + } + + public bool GetPooling() => throw ApiNotSupportedException(); + + private NotSupportedException ApiNotSupportedException() + { + var message = "Pool settings are controlled with connection string parameters or from a Session Pool object"; + s_logger.Error(message); + return new NotSupportedException(message); + } + + internal SessionPool GetPool(string connectionString, SecureString password) + { + var poolKey = GetPoolKey(connectionString); + + if (_pools.TryGetValue(poolKey, out var item)) + return item; + lock (s_poolsLock) + { + var pool = SessionPool.CreateSessionPoolV2(connectionString, password); + _pools.Add(poolKey, pool); + return pool; + } + } + + // TODO: SNOW-937188 + private string GetPoolKey(string connectionString) + { + return connectionString; + } + } +} diff --git a/Snowflake.Data/Core/Session/SFSession.cs b/Snowflake.Data/Core/Session/SFSession.cs index ad9aa07cd..625664795 100755 --- a/Snowflake.Data/Core/Session/SFSession.cs +++ b/Snowflake.Data/Core/Session/SFSession.cs @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. */ using System; @@ -15,7 +15,6 @@ using System.Threading.Tasks; using System.Net.Http; using System.Text.RegularExpressions; -using Snowflake.Data.Configuration; namespace Snowflake.Data.Core { @@ -69,7 +68,8 @@ public class SFSession private readonly EasyLoggingStarter _easyLoggingStarter = EasyLoggingStarter.Instance; private long _startTime = 0; - internal string connStr = null; + internal readonly string ConnectionString; + internal readonly SecureString Password; private QueryContextCache _queryContextCache = new QueryContextCache(_defaultQueryContextCacheSize); @@ -145,8 +145,9 @@ internal SFSession( EasyLoggingStarter easyLoggingStarter) { _easyLoggingStarter = easyLoggingStarter; - connStr = connectionString; - properties = SFSessionProperties.parseConnectionString(connectionString, password); + ConnectionString = connectionString; + Password = password; + properties = SFSessionProperties.parseConnectionString(ConnectionString, Password); _disableQueryContextCache = bool.Parse(properties[SFSessionProperty.DISABLEQUERYCONTEXTCACHE]); ValidateApplicationName(properties); try diff --git a/Snowflake.Data/Core/Session/SessionPool.cs b/Snowflake.Data/Core/Session/SessionPool.cs index 5acd72855..e2c037169 100644 --- a/Snowflake.Data/Core/Session/SessionPool.cs +++ b/Snowflake.Data/Core/Session/SessionPool.cs @@ -17,22 +17,37 @@ sealed class SessionPool : IDisposable { private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); private static readonly object s_sessionPoolLock = new object(); - private readonly List _sessionPool; + private readonly List _idleSessions; private int _maxPoolSize; private long _timeout; private const int MaxPoolSize = 10; private const long Timeout = 3600; + private string _connectionString; + private SecureString _password; private bool _pooling = true; + private bool _allowExceedMaxPoolSize = true; internal SessionPool() { lock (s_sessionPoolLock) { - _sessionPool = new List(); + _idleSessions = new List(); _maxPoolSize = MaxPoolSize; _timeout = Timeout; } } + + internal SessionPool(string connectionString, SecureString password) : this() + { + _connectionString = connectionString; + _password = password; + _allowExceedMaxPoolSize = false; // TODO: SNOW-937190 + } + + internal static SessionPool CreateSessionPoolV1() => new SessionPool(); + + internal static SessionPool CreateSessionPoolV2(string connectionString, SecureString password) => + new SessionPool(connectionString, password); ~SessionPool() { @@ -51,11 +66,11 @@ private void CleanExpiredSessions() { long timeNow = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); - foreach (var item in _sessionPool.ToList()) + foreach (var item in _idleSessions.ToList()) { if (item.IsExpired(_timeout, timeNow)) { - _sessionPool.Remove(item); + _idleSessions.Remove(item); item.close(); } } @@ -80,17 +95,22 @@ internal Task GetSessionAsync(string connStr, SecureString password, return session != null ? Task.FromResult(session) : NewSessionAsync(connStr, password, cancellationToken); } + internal SFSession GetSession() => GetSession(_connectionString, _password); + + internal Task GetSessionAsync(CancellationToken cancellationToken) => + GetSessionAsync(_connectionString, _password, cancellationToken); + private SFSession GetIdleSession(string connStr) { s_logger.Debug("SessionPool::GetIdleSession"); lock (s_sessionPoolLock) { - for (int i = 0; i < _sessionPool.Count; i++) + for (int i = 0; i < _idleSessions.Count; i++) { - if (_sessionPool[i].connStr.Equals(connStr)) + if (_idleSessions[i].ConnectionString.Equals(connStr)) { - SFSession session = _sessionPool[i]; - _sessionPool.RemoveAt(i); + SFSession session = _idleSessions[i]; + _idleSessions.RemoveAt(i); long timeNow = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); if (session.IsExpired(_timeout, timeNow)) { @@ -162,18 +182,18 @@ internal bool AddSession(SFSession session) lock (s_sessionPoolLock) { - if (_sessionPool.Count >= _maxPoolSize) + if (_idleSessions.Count >= _maxPoolSize) { CleanExpiredSessions(); } - if (_sessionPool.Count >= _maxPoolSize) + if (_idleSessions.Count >= _maxPoolSize) { - // pool is full + s_logger.Warn($"Pool is full - unable to add session with sid {session.sessionId}"); return false; } s_logger.Debug($"pool connection with sid {session.sessionId}"); - _sessionPool.Add(session); + _idleSessions.Add(session); return true; } } @@ -183,11 +203,11 @@ internal void ClearAllPools() s_logger.Debug("SessionPool::ClearAllPools"); lock (s_sessionPoolLock) { - foreach (SFSession session in _sessionPool) + foreach (SFSession session in _idleSessions) { session.close(); } - _sessionPool.Clear(); + _idleSessions.Clear(); } } @@ -213,7 +233,7 @@ public long GetTimeout() public int GetCurrentPoolSize() { - return _sessionPool.Count; + return _idleSessions.Count; } public bool SetPooling(bool isEnable)