Skip to content

Commit

Permalink
Refactor of SessionPool
Browse files Browse the repository at this point in the history
Move responsibility of handling all session operations to the SessionPool.
Introduction of pool versioning
  • Loading branch information
sfc-gh-mhofman committed Sep 15, 2023
1 parent 622680d commit 64a40ed
Show file tree
Hide file tree
Showing 7 changed files with 520 additions and 132 deletions.
1 change: 1 addition & 0 deletions Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class PoolConfig {

public PoolConfig()
{
var connectionPoolManagerBase = SnowflakeDbConnectionPool.Instance; // TODO: check why necessary
_maxPoolSize = SnowflakeDbConnectionPool.GetMaxPoolSize();
_timeout = SnowflakeDbConnectionPool.GetTimeout();
_pooling = SnowflakeDbConnectionPool.GetPooling();
Expand Down
2 changes: 1 addition & 1 deletion Snowflake.Data.Tests/Mock/MockSnowflakeDbConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public override void Open()

public override Task OpenAsync(CancellationToken cancellationToken)
{
registerConnectionCancellationCallback(cancellationToken);
RegisterConnectionCancellationCallback(cancellationToken);

SetMockSession();

Expand Down
128 changes: 57 additions & 71 deletions Snowflake.Data/Client/SnowflakeDbConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public class SnowflakeDbConnection : DbConnection

internal int _connectionTimeout;

private bool disposed = false;
private bool _disposed = false;

private static Mutex _arraybindingMutex = new Mutex();

Expand Down Expand Up @@ -151,9 +151,9 @@ public override void Close()
logger.Debug("Close Connection.");
if (IsNonClosedWithSession())
{
var transactionRollbackStatus = SnowflakeDbConnectionPool.GetPooling() ? TerminateTransactionForDirtyConnectionReturningToPool() : TransactionRollbackStatus.Undefined;
var transactionRollbackStatus = GetPooling() ? TerminateTransactionForDirtyConnectionReturningToPool() : TransactionRollbackStatus.Undefined;

if (CanReuseSession(transactionRollbackStatus) && SnowflakeDbConnectionPool.AddSession(SfSession))
if (CanReuseSession(transactionRollbackStatus) && SnowflakeDbConnectionPool.AddSession(ConnectionString, Password, SfSession))
{
logger.Debug($"Session pooled: {SfSession.sessionId}");
}
Expand All @@ -172,7 +172,7 @@ public override void Close()
// Adding an override for CloseAsync will prevent the need for casting to SnowflakeDbConnection to call CloseAsync(CancellationToken).
public override async Task CloseAsync()
{
await CloseAsync(CancellationToken.None);
await CloseAsync(CancellationToken.None).ConfigureAwait(false);
}
#endif

Expand All @@ -189,9 +189,9 @@ public Task CloseAsync(CancellationToken cancellationToken)
{
if (IsNonClosedWithSession())
{
var transactionRollbackStatus = SnowflakeDbConnectionPool.GetPooling() ? TerminateTransactionForDirtyConnectionReturningToPool() : TransactionRollbackStatus.Undefined;
var transactionRollbackStatus = GetPooling() ? TerminateTransactionForDirtyConnectionReturningToPool() : TransactionRollbackStatus.Undefined;

if (CanReuseSession(transactionRollbackStatus) && SnowflakeDbConnectionPool.AddSession(SfSession))
if (CanReuseSession(transactionRollbackStatus) && SnowflakeDbConnectionPool.AddSession(ConnectionString, Password, SfSession))
{
logger.Debug($"Session pooled: {SfSession.sessionId}");
_connectionState = ConnectionState.Closed;
Expand Down Expand Up @@ -234,86 +234,53 @@ public Task CloseAsync(CancellationToken cancellationToken)

private bool CanReuseSession(TransactionRollbackStatus transactionRollbackStatus)
{
return SnowflakeDbConnectionPool.GetPooling() &&
return GetPooling() &&
transactionRollbackStatus == TransactionRollbackStatus.Success;
}

private bool GetPooling()
{
return SnowflakeDbConnectionPool.GetPool(ConnectionString, Password).GetPooling();
}

public override void Open()
{
logger.Debug("Open Connection.");
if (_connectionState != ConnectionState.Closed)
{
logger.Debug($"Open with a connection already opened: {_connectionState}");
logger.Warn($"Opening a connection already opened: {_connectionState}");
return;
}
SfSession = SnowflakeDbConnectionPool.GetSession(this.ConnectionString);
if (SfSession != null)

try
{
logger.Debug($"Connection open with pooled session: {SfSession.sessionId}");
OnSessionOpen();
SfSession = SnowflakeDbConnectionPool.GetSession(ConnectionString, Password);
OnSessionEstablished();
}
else
catch (Exception e)
{
SetSession();
try
{
SfSession.Open();
}
catch (Exception e)
{
// Otherwise when Dispose() is called, the close request would timeout.
_connectionState = ConnectionState.Closed;
logger.Error("Unable to connect", e);
if (!(e.GetType() == typeof(SnowflakeDbException)))
{
throw
new SnowflakeDbException(
e,
SnowflakeDbException.CONNECTION_FAILURE_SSTATE,
SFError.INTERNAL_ERROR,
"Unable to connect. " + e.Message);
}
else
{
throw;
}
}
RethrowOnSessionOpenFailure(e);
}
OnSessionEstablished();
}

public override Task OpenAsync(CancellationToken cancellationToken)
{
logger.Debug("Open Connection Async.");
if (_connectionState != ConnectionState.Closed)
{
logger.Debug($"Open with a connection already opened: {_connectionState}");
return Task.CompletedTask;
}
SfSession = SnowflakeDbConnectionPool.GetSession(this.ConnectionString);
if (SfSession != null)
{
logger.Debug($"Connection open with pooled session: {SfSession.sessionId}");
OnSessionEstablished();
logger.Warn($"Opening a connection already opened: {_connectionState}");
return Task.CompletedTask;
}

registerConnectionCancellationCallback(cancellationToken);
SetSession();

return SfSession.OpenAsync(cancellationToken).ContinueWith(
previousTask =>
OnSessionOpen();
return SnowflakeDbConnectionPool.GetSessionAsync(ConnectionString, Password, cancellationToken)
.ContinueWith(previousTask =>
{
if (previousTask.IsFaulted)
{
// Exception from SfSession.OpenAsync
Exception sfSessionEx = previousTask.Exception;
_connectionState = ConnectionState.Closed;
logger.Error("Unable to connect", sfSessionEx);
throw new SnowflakeDbException(
sfSessionEx,
SnowflakeDbException.CONNECTION_FAILURE_SSTATE,
SFError.INTERNAL_ERROR,
"Unable to connect");
RethrowOnSessionOpenFailure(previousTask.Exception);
}
else if (previousTask.IsCanceled)
{
Expand All @@ -322,7 +289,7 @@ public override Task OpenAsync(CancellationToken cancellationToken)
}
else
{
logger.Debug("All good");
logger.Debug($"Connection open with pooled session: {SfSession.sessionId}");
// Only continue if the session was opened successfully
OnSessionEstablished();
}
Expand All @@ -344,23 +311,42 @@ public void SetArrayBindStageCreated()
{
_isArrayBindStageCreated = true;
}

/// <summary>
/// Create a new SFsession with the connection string settings.
/// </summary>
/// <exception cref="SnowflakeDbException">If the connection string can't be processed</exception>
private void SetSession()

private void OnSessionOpen()
{
SfSession = new SFSession(ConnectionString, Password);
_connectionTimeout = (int)SfSession.connectionTimeout.TotalSeconds;
logger.Debug("Opening session");
_connectionState = ConnectionState.Connecting;
}

private void OnSessionEstablished()
{
if (SfSession == null)
{
logger.Error("Error during opening session");
throw new SnowflakeDbException(SFError.INTERNAL_ERROR, "Unable to establish a session");
}
logger.Debug("Session established");
_connectionState = ConnectionState.Open;
_connectionTimeout = (int)SfSession.connectionTimeout.TotalSeconds;
logger.Debug($"Connection open with pooled session: {SfSession.sessionId}");
}

private void RethrowOnSessionOpenFailure(Exception exception)
{
// Otherwise when Dispose() is called, the close request would timeout.
_connectionState = ConnectionState.Closed;
logger.Error("Unable to connect: ", exception);
if (exception != null && exception is SnowflakeDbException dbException)
throw dbException;

var errorMessage = "Unable to connect. " + (exception != null ? exception.Message : "");
throw new SnowflakeDbException(
exception,
SnowflakeDbException.CONNECTION_FAILURE_SSTATE,
SFError.INTERNAL_ERROR,
errorMessage);
}

protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel)
{
// Parameterless BeginTransaction() method of the super class calls this method with IsolationLevel.Unspecified,
Expand All @@ -382,20 +368,20 @@ protected override DbCommand CreateDbCommand()

protected override void Dispose(bool disposing)
{
if (disposed)
if (_disposed)
return;

try
{
this.Close();
Close();
}
catch (Exception ex)
{
// Prevent an exception from being thrown when disposing of this object
logger.Error("Unable to close connection", ex);
}

disposed = true;
_disposed = true;

base.Dispose(disposing);
}
Expand All @@ -406,7 +392,7 @@ protected override void Dispose(bool disposing)
/// layer or timeout reached. Whichever comes first would trigger query cancellation.
/// </summary>
/// <param name="externalCancellationToken">cancellation token from upper layer</param>
internal void registerConnectionCancellationCallback(CancellationToken externalCancellationToken)
internal void RegisterConnectionCancellationCallback(CancellationToken externalCancellationToken)
{
if (!externalCancellationToken.IsCancellationRequested)
{
Expand Down
84 changes: 67 additions & 17 deletions Snowflake.Data/Client/SnowflakeDbConnectionPool.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
using Snowflake.Data.Core;
using System;
using System.Security;
using System.Threading;
using System.Threading.Tasks;
using Snowflake.Data.Core;
using Snowflake.Data.Core.Session;
using Snowflake.Data.Log;

Expand All @@ -7,58 +11,104 @@ namespace Snowflake.Data.Client
public class SnowflakeDbConnectionPool
{
private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger<SnowflakeDbConnectionPool>();

internal static SFSession GetSession(string connStr)
private static readonly Object s_instanceLock = new Object();
private static ConnectionPoolManagerBase s_connectionPoolManager;
private static readonly PoolManagerVersion s_poolVersion = PoolManagerVersion.Version1;

public static ConnectionPoolManagerBase Instance
{
s_logger.Debug("SnowflakeDbConnectionPool::GetSession");
return SessionPoolSingleton.Instance.GetSession(connStr);
get
{
if (s_connectionPoolManager != null)
return s_connectionPoolManager;
lock (s_instanceLock)
{
s_connectionPoolManager = ProvideConnectionPoolManager();
}
return s_connectionPoolManager;
}
}

internal static bool AddSession(SFSession session)
public static SessionPool GetPool(string connectionString, SecureString password)
{
s_logger.Debug("SnowflakeDbConnectionPool::AddSession");
return SessionPoolSingleton.Instance.AddSession(session);
s_logger.Debug("SnowflakeDbConnectionPool::GetSession");
return Instance.GetPool(connectionString, password);
}

public static void ClearAllPools()
{
s_logger.Debug("SnowflakeDbConnectionPool::ClearAllPools");
SessionPoolSingleton.Instance.ClearAllPools();
Instance.ClearAllPools();
}

public static void SetMaxPoolSize(int size)
{
SessionPoolSingleton.Instance.SetMaxPoolSize(size);
s_logger.Debug("SnowflakeDbConnectionPool::SetMaxPoolSize");
Instance.SetMaxPoolSize(size);
}

public static int GetMaxPoolSize()
{
return SessionPoolSingleton.Instance.GetMaxPoolSize();
s_logger.Debug("SnowflakeDbConnectionPool::GetMaxPoolSize");
return Instance.GetMaxPoolSize();
}

public static void SetTimeout(long time)
{
SessionPoolSingleton.Instance.SetTimeout(time);
s_logger.Debug("SnowflakeDbConnectionPool::SetTimeout");
Instance.SetTimeout(time);
}

public static long GetTimeout()
{
return SessionPoolSingleton.Instance.GetTimeout();
s_logger.Debug("SnowflakeDbConnectionPool::GetTimeout");
return Instance.GetTimeout();
}

public static int GetCurrentPoolSize()
{
return SessionPoolSingleton.Instance.GetCurrentPoolSize();
s_logger.Debug("SnowflakeDbConnectionPool::GetCurrentPoolSize");
return Instance.GetCurrentPoolSize();
}

public static bool SetPooling(bool isEnable)
{
return SessionPoolSingleton.Instance.SetPooling(isEnable);
s_logger.Debug("SnowflakeDbConnectionPool::SetPooling");
return Instance.SetPooling(isEnable);
}

public static bool GetPooling()
{
return SessionPoolSingleton.Instance.GetPooling();
s_logger.Debug("SnowflakeDbConnectionPool::GetPooling");
return Instance.GetPooling();
}

private static ConnectionPoolManagerBase ProvideConnectionPoolManager()
{
if (s_poolVersion == PoolManagerVersion.Version1)
return new ConnectionPoolManagerV1();

throw new NotSupportedException("Pool version not supported");
}

internal static SFSession GetSession(string connectionString, SecureString password)
{
s_logger.Debug("SnowflakeDbConnectionPool::GetSession");
return Instance.GetSession(connectionString, password);
}

internal static Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken)
{
s_logger.Debug("SnowflakeDbConnectionPool::GetSessionAsync");
return Instance.GetSessionAsync(connectionString, password, cancellationToken);
}

internal static bool AddSession(string connectionString, SecureString password, SFSession session)
{
s_logger.Debug("SnowflakeDbConnectionPool::AddSession");
return Instance.AddSession(connectionString, password, session);
}


}
}
Loading

0 comments on commit 64a40ed

Please sign in to comment.