diff --git a/Snowflake.Data.Tests/UnitTests/ConnectionCacheManagerTest.cs b/Snowflake.Data.Tests/UnitTests/ConnectionCacheManagerTest.cs new file mode 100644 index 000000000..589565ddf --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/ConnectionCacheManagerTest.cs @@ -0,0 +1,46 @@ +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Session; +using Snowflake.Data.Tests.Util; + +namespace Snowflake.Data.Tests.UnitTests +{ + [TestFixture, NonParallelizable] + public class ConnectionCacheManagerTest + { + private readonly ConnectionCacheManager _connectionCacheManager = new ConnectionCacheManager(); + private const string ConnectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=1;"; + private static PoolConfig s_poolConfig; + + [OneTimeSetUp] + public static void BeforeAllTests() + { + s_poolConfig = new PoolConfig(); + SnowflakeDbConnectionPool.SetConnectionPoolVersion(ConnectionPoolType.SingleConnectionCache); + SessionPool.SessionFactory = new MockSessionFactory(); + } + + [OneTimeTearDown] + public static void AfterAllTests() + { + s_poolConfig.Reset(); + SessionPool.SessionFactory = new SessionFactory(); + } + + [SetUp] + public void BeforeEach() + { + _connectionCacheManager.ClearAllPools(); + } + + [Test] + public void TestEnablePoolingRegardlessOfConnectionStringProperty() + { + // act + var pool = _connectionCacheManager.GetPool(ConnectionString + "poolingEnabled=false"); + + // assert + Assert.IsTrue(pool.GetPooling()); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs index b7d29c1b1..70efa47fb 100644 --- a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs +++ b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs @@ -211,6 +211,18 @@ public void TestGetPoolingOnManagerLevelAlwaysTrue() Assert.IsFalse(sessionPool2.GetPooling()); } + [Test] + [TestCase("authenticator=externalbrowser;account=test;user=test;")] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key_file=/some/file.key")] + public void TestDisabledPoolingWhenSecretesProvidedExternally(string connectionString) + { + // act + var pool = _connectionPoolManager.GetPool(connectionString, null); + + // assert + Assert.IsFalse(pool.GetPooling()); + } + [Test] public void TestGetTimeoutOnManagerLevelWhenNotAllPoolsEqual() { diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs index 71af2b1c9..40c7551f8 100644 --- a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs @@ -139,40 +139,6 @@ public void TestValidateSupportEscapedQuotesInsideValuesForObjectProperties(stri Assert.AreEqual(expectedValue, properties[sessionProperty]); } - [Test] - [TestCase("", "false")] - [TestCase("poolingEnabled=true", "true")] - [TestCase("poolingEnabled=false", "false")] - public void TestPoolingEnabledForExternalBrowserAuthenticator(string connectionParam, string expectedPoolingEnabled) - { - // arrange - var connectionString = $"ACCOUNT=test;AUTHENTICATOR=externalbrowser;{connectionParam}"; - - // act - var properties = SFSessionProperties.ParseConnectionString(connectionString, null); - - // assert - Assert.AreEqual(expectedPoolingEnabled, properties[SFSessionProperty.POOLINGENABLED]); - } - - [Test] - [TestCase(BasicAuthenticator.AUTH_NAME, "true")] - [TestCase(KeyPairAuthenticator.AUTH_NAME, "true")] - [TestCase(OAuthAuthenticator.AUTH_NAME, "true")] - [TestCase(OktaAuthenticator.AUTH_NAME, "true")] - [TestCase(ExternalBrowserAuthenticator.AUTH_NAME, "false")] - public void TestDefaultPoolingEnabledForAuthenticator(string authenticator, string expectedPoolingEnabled) - { - // arrange - var connectionString = $"ACCOUNT=test;USER=test;PASSWORD=test;TOKEN=test;AUTHENTICATOR={authenticator}"; - - // act - var properties = SFSessionProperties.ParseConnectionString(connectionString, null); - - // assert - Assert.AreEqual(expectedPoolingEnabled, properties[SFSessionProperty.POOLINGENABLED]); - } - public static IEnumerable ConnectionStringTestCases() { string defAccount = "testaccount"; @@ -263,7 +229,7 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.CHANGEDSESSION, DefaultValue(SFSessionProperty.CHANGEDSESSION) }, { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, - { SFSessionProperty.POOLINGENABLED, "false" } // connection pooling is disabled for external browser authentication + { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) } } }; var testCaseWithProxySettings = new TestCase() diff --git a/Snowflake.Data.Tests/UnitTests/Session/ConnectionPoolConfigExtractorTest.cs b/Snowflake.Data.Tests/UnitTests/Session/ConnectionPoolConfigExtractorTest.cs index 1f1c18758..0cc61f28b 100644 --- a/Snowflake.Data.Tests/UnitTests/Session/ConnectionPoolConfigExtractorTest.cs +++ b/Snowflake.Data.Tests/UnitTests/Session/ConnectionPoolConfigExtractorTest.cs @@ -219,6 +219,28 @@ public void TestExtractPoolingEnabled(string propertyValue, bool poolingEnabled) Assert.AreEqual(poolingEnabled, result.PoolingEnabled); } + [Test] + [TestCase("account=test;user=test;password=test;", true)] + [TestCase("authenticator=externalbrowser;account=test;user=test;", false)] + [TestCase("authenticator=externalbrowser;account=test;user=test;poolingEnabled=true;", true)] + [TestCase("authenticator=externalbrowser;account=test;user=test;poolingEnabled=false;", false)] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key_file=/some/file.key", false)] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key_file=/some/file.key;poolingEnabled=true;", true)] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key_file=/some/file.key;poolingEnabled=false;", false)] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key=secretKey", true)] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key=secretKey;poolingEnabled=true;", true)] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key=secretKey;poolingEnabled=false;", false)] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key_file=/some/file.key;private_key_pwd=secretPwd", true)] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key_file=/some/file.key;private_key_pwd=", false)] + public void TestDisablePoolingDefaultWhenSecretsProvidedExternally(string connectionString, bool poolingEnabled) + { + // act + var result = ExtractConnectionPoolConfig(connectionString); + + // assert + Assert.AreEqual(poolingEnabled, result.PoolingEnabled); + } + [Test] [TestCase("wrong_value")] [TestCase("15")] @@ -252,12 +274,8 @@ public void TestExtractChangedSessionBehaviour(string propertyValue, ChangedSess Assert.AreEqual(expectedChangedSession, result.ChangedSession); } - private ConnectionPoolConfig ExtractConnectionPoolConfig(string connectionString) - { - var properties = SFSessionProperties.ParseConnectionString(connectionString, null); - var extractedProperties = SFSessionHttpClientProperties.ExtractAndValidate(properties); - return extractedProperties.BuildConnectionPoolConfig(); - } + private ConnectionPoolConfig ExtractConnectionPoolConfig(string connectionString) => + SessionPool.ExtractConfig(connectionString, null).Item1; public class TimeoutTestCase { diff --git a/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs b/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs index b79c332c7..2ba4709e7 100644 --- a/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs +++ b/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs @@ -47,6 +47,34 @@ public static SFSessionHttpClientProperties ExtractAndValidate(SFSessionProperti return extractedProperties; } + public void DisablePoolingDefaultIfSecretsProvidedExternally(SFSessionProperties properties) + { + var authenticator = properties[SFSessionProperty.AUTHENTICATOR].ToLower(); + if (ExternalBrowserAuthenticator.AUTH_NAME.Equals(authenticator)) + { + DisablePoolingIfNotExplicitlyEnabled(properties, "external browser"); + + } else if (KeyPairAuthenticator.AUTH_NAME.Equals(authenticator) + && properties.IsNonEmptyValueProvided(SFSessionProperty.PRIVATE_KEY_FILE) + && !properties.IsNonEmptyValueProvided(SFSessionProperty.PRIVATE_KEY_PWD)) + { + DisablePoolingIfNotExplicitlyEnabled(properties, "key pair with private key in a file"); + } + } + + private void DisablePoolingIfNotExplicitlyEnabled(SFSessionProperties properties, string authenticationDescription) + { + if (!properties.IsPoolingEnabledValueProvided && _poolingEnabled) + { + _poolingEnabled = false; + s_logger.Info($"Disabling connection pooling for {authenticationDescription} authentication"); + } + else if (properties.IsPoolingEnabledValueProvided && _poolingEnabled) + { + s_logger.Warn($"Connection pooling is enabled for {authenticationDescription} authentication which is not recommended"); + } + } + private void CheckPropertiesAreValid() { try diff --git a/Snowflake.Data/Core/Session/SFSessionProperty.cs b/Snowflake.Data/Core/Session/SFSessionProperty.cs index c2e3777b6..12b650ce6 100644 --- a/Snowflake.Data/Core/Session/SFSessionProperty.cs +++ b/Snowflake.Data/Core/Session/SFSessionProperty.cs @@ -128,6 +128,8 @@ class SFSessionProperties : Dictionary internal string ConnectionStringWithoutSecrets { get; set; } + internal bool IsPoolingEnabledValueProvided { get; set; } + // Connection string properties to obfuscate in the log private static readonly List s_secretProps = Enum.GetValues(typeof(SFSessionProperty)) .Cast() @@ -254,7 +256,7 @@ internal static SFSessionProperties ParseConnectionString(string connectionStrin } ValidateAuthenticator(properties); - DisableConnectionPoolingForExternalBrowser(properties); + properties.IsPoolingEnabledValueProvided = properties.IsNonEmptyValueProvided(SFSessionProperty.POOLINGENABLED); CheckSessionProperties(properties); ValidateFileTransferMaxBytesInMemoryProperty(properties); ValidateAccountDomain(properties); @@ -308,25 +310,8 @@ private static void ValidateAuthenticator(SFSessionProperties properties) } } - private static void DisableConnectionPoolingForExternalBrowser(SFSessionProperties properties) - { - if (properties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator)) - { - authenticator = authenticator.ToLower(); - if (authenticator.Equals(ExternalBrowserAuthenticator.AUTH_NAME)) - { - if (!properties.TryGetValue(SFSessionProperty.POOLINGENABLED, out var poolingEnabledStr)) - { - properties.Add(SFSessionProperty.POOLINGENABLED, "false"); - logger.Info("Connection pooling is disabled for external browser authentication"); - } - else if (Boolean.TryParse(poolingEnabledStr, out var poolingEnabled) && poolingEnabled) - { - logger.Warn("Connection pooling is enabled for external browser authentication"); - } - } - } - } + internal bool IsNonEmptyValueProvided(SFSessionProperty property) => + TryGetValue(property, out var propertyValueStr) && !string.IsNullOrEmpty(propertyValueStr); private static string BuildConnectionStringWithoutSecrets(ref string[] keys, ref string[] values) { diff --git a/Snowflake.Data/Core/Session/SessionPool.cs b/Snowflake.Data/Core/Session/SessionPool.cs index e0d72bc4d..c28db40c0 100644 --- a/Snowflake.Data/Core/Session/SessionPool.cs +++ b/Snowflake.Data/Core/Session/SessionPool.cs @@ -103,12 +103,13 @@ private void CleanExpiredSessions() } } - private static Tuple ExtractConfig(string connectionString, SecureString password) + internal static Tuple ExtractConfig(string connectionString, SecureString password) { try { var properties = SFSessionProperties.ParseConnectionString(connectionString, password); var extractedProperties = SFSessionHttpClientProperties.ExtractAndValidate(properties); + extractedProperties.DisablePoolingDefaultIfSecretsProvidedExternally(properties); return Tuple.Create(extractedProperties.BuildConnectionPoolConfig(), properties.ConnectionStringWithoutSecrets); } catch (Exception exception)