diff --git a/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsAsyncIT.cs b/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsAsyncIT.cs index 9089cc996..aa5d431ed 100644 --- a/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsAsyncIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsAsyncIT.cs @@ -50,7 +50,7 @@ public async Task TestAddToPoolOnOpenAsync() } [Test] - public async Task TestDoNotAddToPoolInvalidConnectionAsync() + public async Task TestFailForInvalidConnectionAsync() { // arrange var invalidConnectionString = ";connection_timeout=123"; @@ -63,15 +63,10 @@ public async Task TestDoNotAddToPoolInvalidConnectionAsync() Assert.Fail("OpenAsync should fail for invalid connection string"); } catch {} + var thrown = Assert.Throws(() => SnowflakeDbConnectionPool.GetPool(connection.ConnectionString)); // assert - var pool = SnowflakeDbConnectionPool.GetPool(connection.ConnectionString); - var poolState = pool.GetCurrentState(); - logger.Warn($"Pool state: {poolState}"); - Assert.Less(pool.GetCurrentPoolSize(), SFSessionHttpClientProperties.DefaultMinPoolSize); // for invalid connection string it is used default min pool size - - // cleanup - await connection.CloseAsync(CancellationToken.None).ConfigureAwait(false); + Assert.That(thrown.Message, Does.Contain("Required property ACCOUNT is not provided")); } [Test] diff --git a/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsIT.cs b/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsIT.cs index 97b54ec59..201e50ca8 100644 --- a/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsIT.cs @@ -50,6 +50,7 @@ public void TestBasicConnectionPool() // assert Assert.AreEqual(ConnectionState.Closed, conn1.State); Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(connectionString).GetCurrentPoolSize()); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(connectionString, null).GetCurrentPoolSize()); } [Test] @@ -177,7 +178,7 @@ public void TestWaitInAQueueForAnIdleSession() { // arrange var connectionString = ConnectionString + "application=TestWaitForMaxSize3;waitingForIdleSessionTimeout=3s;maxPoolSize=2;minPoolSize=0"; - var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + var pool = SnowflakeDbConnectionPool.GetPoolInternal(connectionString); Assert.AreEqual(0, pool.GetCurrentPoolSize(), "the pool is expected to be empty"); const long ADelay = 0; const long BDelay = 400; @@ -262,8 +263,7 @@ public void TestConnectionPoolNotPossibleToDisableForAllPools() public void TestConnectionPoolDisable() { // arrange - var pool = SnowflakeDbConnectionPool.GetPool(ConnectionString); - pool.SetPooling(false); + var pool = SnowflakeDbConnectionPool.GetPool(ConnectionString + ";poolingEnabled=false"); var conn1 = new SnowflakeDbConnection(); conn1.ConnectionString = ConnectionString; @@ -322,7 +322,7 @@ public void TestConnectionPoolExpirationWorks() // arrange const int ExpirationTimeoutInSeconds = 10; var connectionString = ConnectionString + $"expirationTimeout={ExpirationTimeoutInSeconds};maxPoolSize=4;minPoolSize=2"; - var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + var pool = SnowflakeDbConnectionPool.GetPoolInternal(connectionString); Assert.AreEqual(0, pool.GetCurrentPoolSize()); // act diff --git a/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolChangedSessionIT.cs b/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolChangedSessionIT.cs new file mode 100644 index 000000000..801916cb0 --- /dev/null +++ b/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolChangedSessionIT.cs @@ -0,0 +1,216 @@ +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core; +using Snowflake.Data.Core.Session; +using Snowflake.Data.Tests.Util; + +namespace Snowflake.Data.Tests.IntegrationTests +{ + [TestFixture] + [NonParallelizable] + public class ConnectionPoolChangedSessionIT : SFBaseTest + { + private readonly QueryExecResponseData _queryExecResponseChangedRole = new() + { + finalDatabaseName = TestEnvironment.TestConfig.database, + finalSchemaName = TestEnvironment.TestConfig.schema, + finalRoleName = "role change", + finalWarehouseName = TestEnvironment.TestConfig.warehouse + }; + + private readonly QueryExecResponseData _queryExecResponseChangedDatabase = new() + { + finalDatabaseName = "database changed", + finalSchemaName = TestEnvironment.TestConfig.schema, + finalRoleName = TestEnvironment.TestConfig.role, + finalWarehouseName = TestEnvironment.TestConfig.warehouse + }; + + private readonly QueryExecResponseData _queryExecResponseChangedSchema = new() + { + finalDatabaseName = TestEnvironment.TestConfig.database, + finalSchemaName = "schema changed", + finalRoleName = TestEnvironment.TestConfig.role, + finalWarehouseName = TestEnvironment.TestConfig.warehouse + }; + + private readonly QueryExecResponseData _queryExecResponseChangedWarehouse = new() + { + finalDatabaseName = TestEnvironment.TestConfig.database, + finalSchemaName = TestEnvironment.TestConfig.schema, + finalRoleName = TestEnvironment.TestConfig.role, + finalWarehouseName = "warehouse changed" + }; + + private static PoolConfig s_previousPoolConfigRestorer; + + [OneTimeSetUp] + public static void BeforeAllTests() + { + s_previousPoolConfigRestorer = new PoolConfig(); + SnowflakeDbConnectionPool.SetConnectionPoolVersion(ConnectionPoolType.MultipleConnectionPool); + } + + [SetUp] + public new void BeforeTest() + { + SnowflakeDbConnectionPool.ClearAllPools(); + } + + [TearDown] + public new void AfterTest() + { + SnowflakeDbConnectionPool.ClearAllPools(); + } + + [OneTimeTearDown] + public static void AfterAllTests() + { + s_previousPoolConfigRestorer.Reset(); + } + + [Test] + public void TestPoolDestroysConnectionWhenChangedSessionProperties() + { + var connectionString = ConnectionString + "application=Destroy;ChangedSession=Destroy;minPoolSize=0;maxPoolSize=3"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + connection.SfSession.UpdateSessionProperties(_queryExecResponseChangedDatabase); + connection.Close(); + + Assert.AreEqual(0, pool.GetCurrentPoolSize()); + } + + [Test] + public void TestPoolingWhenSessionPropertiesUnchanged() + { + var connectionString = ConnectionString + "application=NoSessionChanges;ChangedSession=Destroy;minPoolSize=0;maxPoolSize=3"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + connection.Close(); + + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + } + + [Test] + public void TestPoolingWhenConnectionPropertiesChangedForOriginalPoolMode() + { + var connectionString = ConnectionString + "application=OriginalPoolMode;ChangedSession=OriginalPool;minPoolSize=0;maxPoolSize=3"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + connection.SfSession.UpdateSessionProperties(_queryExecResponseChangedWarehouse); + var sessionId = connection.SfSession.sessionId; + connection.Close(); + + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + connection.Close(); + + var connection2 = new SnowflakeDbConnection(connectionString); + connection2.Open(); + Assert.AreEqual(sessionId, connection2.SfSession.sessionId); + connection2.Close(); + } + + [Test] + public void TestPoolingWhenConnectionPropertiesChangedForDefaultPoolMode() + { + var connectionString = ConnectionString + "application=DefaultPoolMode;minPoolSize=0;maxPoolSize=3"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + connection.SfSession.UpdateSessionProperties(_queryExecResponseChangedRole); + var sessionId = connection.SfSession.sessionId; + connection.Close(); + + Assert.AreEqual(0, pool.GetCurrentPoolSize()); + + var connection2 = new SnowflakeDbConnection(connectionString); + connection2.Open(); + Assert.AreNotEqual(sessionId, connection2.SfSession.sessionId); + connection2.Close(); + } + + [Test] + public void TestPoolDestroysAndRecreatesConnection() + { + var connectionString = ConnectionString + "application=DestroyRecreateSession;ChangedSession=Destroy;minPoolSize=1;maxPoolSize=3"; + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + var sessionId = connection.SfSession.sessionId; + connection.SfSession.UpdateSessionProperties(_queryExecResponseChangedSchema); + connection.Close(); + + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + + var connection2 = new SnowflakeDbConnection(connectionString); + connection2.Open(); + Assert.AreNotEqual(sessionId, connection2.SfSession.sessionId); + connection2.Close(); + } + + [Test] + public void TestCompareSessionChangesCaseInsensitiveWhenUnquoted() + { + var connectionString = ConnectionString + "application=CompareCaseInsensitive;ChangedSession=Destroy;minPoolSize=1;maxPoolSize=3"; + + var responseData = new QueryExecResponseData() + { + finalDatabaseName = TestEnvironment.TestConfig.database.ToLower(), + finalSchemaName = TestEnvironment.TestConfig.schema.ToUpper(), + finalRoleName = $"{char.ToUpper(TestEnvironment.TestConfig.role[0])}{TestEnvironment.TestConfig.role.Substring(1).ToLower()}", + finalWarehouseName = TestEnvironment.TestConfig.warehouse.ToLower() + }; + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + var sessionId = connection.SfSession.sessionId; + connection.SfSession.UpdateSessionProperties(responseData); + connection.Close(); + + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + + var connection2 = new SnowflakeDbConnection(connectionString); + connection2.Open(); + Assert.AreEqual(sessionId, connection2.SfSession.sessionId); + connection2.Close(); + } + + [Test] + public void TestCompareSessionChangesCaseSensitiveWhenQuoted() + { + var connectionString = ConnectionString + "application=CompareCaseSensitive;ChangedSession=Destroy;minPoolSize=1;maxPoolSize=3"; + + var responseData = new QueryExecResponseData() + { + finalDatabaseName = TestEnvironment.TestConfig.database, + finalSchemaName = TestEnvironment.TestConfig.schema, + finalRoleName = $"\\\"SomeQuotedValue\\\"", + finalWarehouseName = TestEnvironment.TestConfig.warehouse.ToLower() + }; + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + var sessionId = connection.SfSession.sessionId; + connection.SfSession.UpdateSessionProperties(responseData); + connection.Close(); + + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + + var connection2 = new SnowflakeDbConnection(connectionString); + connection2.Open(); + Assert.AreNotEqual(sessionId, connection2.SfSession.sessionId); + connection2.Close(); + } + } +} diff --git a/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolCommonIT.cs b/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolCommonIT.cs index e05e342f6..6a0745b23 100644 --- a/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolCommonIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolCommonIT.cs @@ -103,10 +103,17 @@ public void TestConnectionPoolWithDispose() conn1.ConnectionString = "bad connection string"; Assert.Throws(() => conn1.Open()); conn1.Close(); - Thread.Sleep(3000); // minPoolSize = 2 causes that another thread has been started. We sleep to make that thread finish. Assert.AreEqual(ConnectionState.Closed, conn1.State); - Assert.AreEqual(0, SnowflakeDbConnectionPool.GetPool(conn1.ConnectionString).GetCurrentPoolSize()); + if (_connectionPoolTypeUnderTest == ConnectionPoolType.SingleConnectionCache) + { + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetPool(conn1.ConnectionString).GetCurrentPoolSize()); + } + else + { + var thrown = Assert.Throws(() => SnowflakeDbConnectionPool.GetPool(conn1.ConnectionString)); + Assert.That(thrown.Message, Does.Contain("Connection string is invalid")); + } } [Test] diff --git a/Snowflake.Data.Tests/IntegrationTests/ConnectionSinglePoolCacheIT.cs b/Snowflake.Data.Tests/IntegrationTests/ConnectionSinglePoolCacheIT.cs index d5f8de2fb..956f7f00c 100644 --- a/Snowflake.Data.Tests/IntegrationTests/ConnectionSinglePoolCacheIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/ConnectionSinglePoolCacheIT.cs @@ -223,7 +223,7 @@ public void TestConnectionPoolDisable() { // arrange var pool = SnowflakeDbConnectionPool.GetPool(ConnectionString); - pool.SetPooling(false); + SnowflakeDbConnectionPool.SetPooling(false); var conn1 = new SnowflakeDbConnection(); conn1.ConnectionString = ConnectionString; diff --git a/Snowflake.Data.Tests/IntegrationTests/EasyLoggingIT.cs b/Snowflake.Data.Tests/IntegrationTests/EasyLoggingIT.cs index 595fbb65d..fd2e79409 100644 --- a/Snowflake.Data.Tests/IntegrationTests/EasyLoggingIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/EasyLoggingIT.cs @@ -24,7 +24,7 @@ public static void BeforeAll() Directory.CreateDirectory(s_workingDirectory); } } - + [OneTimeTearDown] public static void AfterAll() { @@ -36,7 +36,7 @@ public static void AfterEach() { EasyLoggingStarter.Instance.Reset(EasyLoggingLogLevel.Warn); } - + [Test] public void TestEnableEasyLogging() { @@ -48,7 +48,7 @@ public void TestEnableEasyLogging() // act conn.Open(); - + // assert Assert.IsTrue(EasyLoggerManager.HasEasyLoggingAppender()); } @@ -65,13 +65,13 @@ public void TestFailToEnableEasyLoggingForWrongConfiguration() // act var thrown = Assert.Throws(() => conn.Open()); - + // assert - Assert.That(thrown.Message, Does.Contain("Connection string is invalid: Unable to connect")); + Assert.That(thrown.Message, Does.Contain("Connection string is invalid: Unable to initialize session")); Assert.IsFalse(EasyLoggerManager.HasEasyLoggingAppender()); } } - + [Test] public void TestFailToEnableEasyLoggingWhenConfigHasWrongPermissions() { @@ -79,19 +79,19 @@ public void TestFailToEnableEasyLoggingWhenConfigHasWrongPermissions() { Assert.Ignore("skip test on Windows"); } - + // arrange var configFilePath = CreateConfigTempFile(s_workingDirectory, Config("WARN", s_workingDirectory)); Syscall.chmod(configFilePath, FilePermissions.S_IRUSR | FilePermissions.S_IWUSR | FilePermissions.S_IWGRP); using (IDbConnection conn = new SnowflakeDbConnection()) { conn.ConnectionString = ConnectionString + $"CLIENT_CONFIG_FILE={configFilePath}"; - + // act var thrown = Assert.Throws(() => conn.Open()); - + // assert - Assert.That(thrown.Message, Does.Contain("Connection string is invalid: Unable to connect")); + Assert.That(thrown.Message, Does.Contain("Connection string is invalid: Unable to initialize session")); Assert.IsFalse(EasyLoggerManager.HasEasyLoggingAppender()); } } @@ -103,22 +103,22 @@ public void TestFailToEnableEasyLoggingWhenLogDirectoryNotAccessible() { Assert.Ignore("skip test on Windows"); } - + // arrange var configFilePath = CreateConfigTempFile(s_workingDirectory, Config("WARN", "/")); using (IDbConnection conn = new SnowflakeDbConnection()) { conn.ConnectionString = ConnectionString + $"CLIENT_CONFIG_FILE={configFilePath}"; - + // act var thrown = Assert.Throws(() => conn.Open()); - + // assert - Assert.That(thrown.Message, Does.Contain("Connection string is invalid: Unable to connect")); + Assert.That(thrown.Message, Does.Contain("Connection string is invalid: Unable to initialize session")); Assert.That(thrown.InnerException.Message, Does.Contain("Failed to create logs directory")); Assert.IsFalse(EasyLoggerManager.HasEasyLoggingAppender()); } } } -} \ No newline at end of file +} diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index e057612d0..ba22c9007 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -1717,7 +1717,7 @@ public void TestEscapeChar() conn.Open(); Assert.AreEqual(ConnectionState.Open, conn.State); - Assert.AreEqual(SFSessionHttpClientProperties.DefaultRetryTimeout, conn.ConnectionTimeout); + Assert.AreEqual(SFSessionHttpClientProperties.DefaultRetryTimeout.TotalSeconds, conn.ConnectionTimeout); // Data source is empty string for now Assert.AreEqual("", ((SnowflakeDbConnection)conn).DataSource); @@ -1743,7 +1743,7 @@ public void TestEscapeChar1() conn.Open(); Assert.AreEqual(ConnectionState.Open, conn.State); - Assert.AreEqual(SFSessionHttpClientProperties.DefaultRetryTimeout, conn.ConnectionTimeout); + Assert.AreEqual(SFSessionHttpClientProperties.DefaultRetryTimeout.TotalSeconds, conn.ConnectionTimeout); // Data source is empty string for now Assert.AreEqual("", ((SnowflakeDbConnection)conn).DataSource); diff --git a/Snowflake.Data.Tests/UnitTests/AuthenticationPropertiesValidatorTest.cs b/Snowflake.Data.Tests/UnitTests/AuthenticationPropertiesValidatorTest.cs new file mode 100644 index 000000000..4a6a03a33 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/AuthenticationPropertiesValidatorTest.cs @@ -0,0 +1,64 @@ +using System.Net; +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core; +using Snowflake.Data.Tests.Util; + + +namespace Snowflake.Data.Tests.UnitTests +{ + [TestFixture] + public class AuthenticationPropertiesValidatorTest + { + private const string _necessaryNonAuthProperties = "account=a;"; + + [TestCase("authenticator=snowflake;user=test;password=test", null)] + [TestCase("authenticator=Snowflake;user=test", "test")] + [TestCase("authenticator=ExternalBrowser", null)] + [TestCase("authenticator=snowflake_jwt;user=test;private_key_file=key.file", null)] + [TestCase("authenticator=SNOWFLAKE_JWT;user=test;private_key=key", null)] + [TestCase("authenticator=Snowflake_jwt;user=test;private_key=key;private_key_pwd=test", null)] + [TestCase("authenticator=oauth;token=value", null)] + [TestCase("AUTHENTICATOR=HTTPS://SOMETHING.OKTA.COM;USER=TEST;PASSWORD=TEST", null)] + [TestCase("authenticator=https://something.oktapreview.com;user=test;password=test", null)] + [TestCase("authenticator=https://vanity.url/snowflake/okta;USER=TEST;PASSWORD=TEST", null)] + public void TestAuthPropertiesValid(string connectionString, string password) + { + // Arrange + var securePassword = string.IsNullOrEmpty(password) ? null : new NetworkCredential(string.Empty, password).SecurePassword; + + // Act/Assert + Assert.DoesNotThrow(() => SFSessionProperties.ParseConnectionString(_necessaryNonAuthProperties + connectionString, securePassword)); + } + + [TestCase("authenticator=snowflake;", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property PASSWORD is not provided.")] + [TestCase("authenticator=snowflake;", "test", SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property USER is not provided")] + [TestCase("authenticator=snowflake;user=;password=", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property PASSWORD is not provided.")] + [TestCase("authenticator=snowflake;user=;", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property PASSWORD is not provided")] + [TestCase("authenticator=snowflake;user=;", "test", SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property USER is not provided")] + [TestCase("authenticator=snowflake_jwt;private_key_file=", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property USER is not provided")] + [TestCase("authenticator=snowflake_jwt;private_key=", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property USER is not provided")] + [TestCase("authenticator=snowflake_jwt;", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property USER is not provided")] + [TestCase("authenticator=oauth;TOKen=", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property TOKEN is not provided")] + [TestCase("authenticator=oauth;", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property TOKEN is not provided")] + [TestCase("authenticator=okta;user=;password=", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property PASSWORD is not provided")] + [TestCase("authenticator=okta;user=", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property PASSWORD is not provided")] + [TestCase("authenticator=okta;password=", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property PASSWORD is not provided")] + [TestCase("authenticator=okta;", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property PASSWORD is not provided")] + [TestCase("authenticator=unknown;", null, SFError.UNKNOWN_AUTHENTICATOR, "Unknown authenticator")] + [TestCase("authenticator=http://unknown.okta.com;", null, SFError.UNKNOWN_AUTHENTICATOR, "Unknown authenticator")] + [TestCase("authenticator=https://unknown;", null, SFError.UNKNOWN_AUTHENTICATOR, "Unknown authenticator")] + public void TestAuthPropertiesInvalid(string connectionString, string password, SFError expectedError, string expectedErrorMessage) + { + // Arrange + var securePassword = string.IsNullOrEmpty(password) ? null : new NetworkCredential(string.Empty, password).SecurePassword; + + // Act + var exception = Assert.Throws(() => SFSessionProperties.ParseConnectionString(_necessaryNonAuthProperties + connectionString, securePassword)); + + // Assert + SnowflakeDbExceptionAssert.HasErrorCode(exception, expectedError); + Assert.That(exception.Message.Contains(expectedErrorMessage), $"Expecting:\n\t{exception.Message}\nto contain:\n\t{expectedErrorMessage}"); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs index 3a2bf3eb0..d1660b41f 100644 --- a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs +++ b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs @@ -3,6 +3,7 @@ */ using System; +using System.Net; using System.Security; using System.Threading; using System.Threading.Tasks; @@ -21,7 +22,10 @@ class ConnectionPoolManagerTest private readonly ConnectionPoolManager _connectionPoolManager = new ConnectionPoolManager(); private const string ConnectionString1 = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=1;"; private const string ConnectionString2 = "db=D2;warehouse=W2;account=A2;user=U2;password=P2;role=R2;minPoolSize=1;"; - private readonly SecureString _password = new SecureString(); + private const string ConnectionString3 = "db=D3;warehouse=W3;account=A3;user=U3;role=R3;minPoolSize=1;"; + private readonly SecureString _password1 = null; + private readonly SecureString _password2 = null; + private readonly SecureString _password3 = new NetworkCredential("", "P3").SecurePassword; private static PoolConfig s_poolConfig; [OneTimeSetUp] @@ -49,11 +53,29 @@ public void BeforeEach() public void TestPoolManagerReturnsSessionPoolForGivenConnectionString() { // Act - var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, _password); + var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, _password1); // Assert Assert.AreEqual(ConnectionString1, sessionPool.ConnectionString); - Assert.AreEqual(_password, sessionPool.Password); + Assert.AreEqual(_password1, sessionPool.Password); + } + + [Test] + public void TestPoolManagerReturnsSessionPoolForGivenConnectionStringAndSecurelyProvidedPassword() + { + // Act + var sessionPool = _connectionPoolManager.GetPool(ConnectionString3, _password3); + + // Assert + Assert.AreEqual(ConnectionString3, sessionPool.ConnectionString); + Assert.AreEqual(_password3, sessionPool.Password); + } + + [Test] + public void TestPoolManagerThrowsWhenPasswordNotProvided() + { + // Act/Assert + Assert.Throws(() => _connectionPoolManager.GetPool(ConnectionString3, null)); } [Test] @@ -63,8 +85,8 @@ public void TestPoolManagerReturnsSamePoolForGivenConnectionString() var anotherConnectionString = ConnectionString1; // Act - var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); - var sessionPool2 = _connectionPoolManager.GetPool(anotherConnectionString, _password); + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1); + var sessionPool2 = _connectionPoolManager.GetPool(anotherConnectionString, _password1); // Assert Assert.AreEqual(sessionPool1, sessionPool2); @@ -77,8 +99,8 @@ public void TestDifferentPoolsAreReturnedForDifferentConnectionStrings() Assert.AreNotSame(ConnectionString1, ConnectionString2); // Act - var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); - var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password); + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2); // Assert Assert.AreNotSame(sessionPool1, sessionPool2); @@ -91,32 +113,32 @@ public void TestDifferentPoolsAreReturnedForDifferentConnectionStrings() public void TestGetSessionWorksForSpecifiedConnectionString() { // Act - var sfSession = _connectionPoolManager.GetSession(ConnectionString1, _password); + var sfSession = _connectionPoolManager.GetSession(ConnectionString1, _password1); // Assert Assert.AreEqual(ConnectionString1, sfSession.ConnectionString); - Assert.AreEqual(_password, sfSession.Password); + Assert.AreEqual(_password1, sfSession.Password); } [Test] public async Task TestGetSessionAsyncWorksForSpecifiedConnectionString() { // Act - var sfSession = await _connectionPoolManager.GetSessionAsync(ConnectionString1, _password, CancellationToken.None); + var sfSession = await _connectionPoolManager.GetSessionAsync(ConnectionString1, _password1, CancellationToken.None); // Assert Assert.AreEqual(ConnectionString1, sfSession.ConnectionString); - Assert.AreEqual(_password, sfSession.Password); + Assert.AreEqual(_password1, sfSession.Password); } [Test] public void TestCountingOfSessionProvidedByPool() { // Act - _connectionPoolManager.GetSession(ConnectionString1, _password); + _connectionPoolManager.GetSession(ConnectionString1, _password1); // Assert - var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, _password); + var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, _password1); Assert.AreEqual(1, sessionPool.GetCurrentPoolSize()); } @@ -124,13 +146,13 @@ public void TestCountingOfSessionProvidedByPool() public void TestCountingOfSessionReturnedBackToPool() { // Arrange - var sfSession = _connectionPoolManager.GetSession(ConnectionString1, _password); + var sfSession = _connectionPoolManager.GetSession(ConnectionString1, _password1); // Act _connectionPoolManager.AddSession(sfSession); // Assert - var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, _password); + var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, _password1); Assert.AreEqual(1, sessionPool.GetCurrentPoolSize()); } @@ -138,7 +160,7 @@ public void TestCountingOfSessionReturnedBackToPool() public void TestSetMaxPoolSizeForAllPoolsDisabled() { // Arrange - _connectionPoolManager.GetPool(ConnectionString1, _password); + _connectionPoolManager.GetPool(ConnectionString1, _password1); // Act var thrown = Assert.Throws(() => _connectionPoolManager.SetMaxPoolSize(3)); @@ -151,7 +173,7 @@ public void TestSetMaxPoolSizeForAllPoolsDisabled() public void TestSetTimeoutForAllPoolsDisabled() { // Arrange - _connectionPoolManager.GetPool(ConnectionString1, _password); + _connectionPoolManager.GetPool(ConnectionString1, _password1); // Act var thrown = Assert.Throws(() => _connectionPoolManager.SetTimeout(3000)); @@ -164,7 +186,7 @@ public void TestSetTimeoutForAllPoolsDisabled() public void TestSetPoolingForAllPoolsDisabled() { // Arrange - _connectionPoolManager.GetPool(ConnectionString1, _password); + _connectionPoolManager.GetPool(ConnectionString1, _password1); // Act var thrown = Assert.Throws(() => _connectionPoolManager.SetPooling(false)); @@ -177,8 +199,8 @@ public void TestSetPoolingForAllPoolsDisabled() public void TestGetPoolingOnManagerLevelAlwaysTrue() { // Arrange - var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); - var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password); + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2); sessionPool1.SetPooling(true); sessionPool2.SetPooling(false); @@ -195,8 +217,8 @@ public void TestGetPoolingOnManagerLevelAlwaysTrue() public void TestGetTimeoutOnManagerLevelWhenNotAllPoolsEqual() { // Arrange - var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); - var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password); + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2); sessionPool1.SetTimeout(299); sessionPool2.SetTimeout(1313); @@ -211,8 +233,8 @@ public void TestGetTimeoutOnManagerLevelWhenNotAllPoolsEqual() public void TestGetTimeoutOnManagerLevelWhenAllPoolsEqual() { // Arrange - var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); - var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password); + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2); sessionPool1.SetTimeout(3600); sessionPool2.SetTimeout(3600); @@ -224,8 +246,8 @@ public void TestGetTimeoutOnManagerLevelWhenAllPoolsEqual() public void TestGetMaxPoolSizeOnManagerLevelWhenNotAllPoolsEqual() { // Arrange - var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); - var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password); + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2); sessionPool1.SetMaxPoolSize(1); sessionPool2.SetMaxPoolSize(17); @@ -240,8 +262,8 @@ public void TestGetMaxPoolSizeOnManagerLevelWhenNotAllPoolsEqual() public void TestGetMaxPoolSizeOnManagerLevelWhenAllPoolsEqual() { // Arrange - var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password); - var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password); + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, _password1); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, _password2); sessionPool1.SetMaxPoolSize(33); sessionPool2.SetMaxPoolSize(33); @@ -253,8 +275,8 @@ public void TestGetMaxPoolSizeOnManagerLevelWhenAllPoolsEqual() public void TestGetCurrentPoolSizeReturnsSumOfPoolSizes() { // Arrange - EnsurePoolSize(ConnectionString1, 2); - EnsurePoolSize(ConnectionString2, 3); + EnsurePoolSize(ConnectionString1, _password1, 2); + EnsurePoolSize(ConnectionString2, _password2, 3); // act var poolSize = _connectionPoolManager.GetCurrentPoolSize(); @@ -263,13 +285,13 @@ public void TestGetCurrentPoolSizeReturnsSumOfPoolSizes() Assert.AreEqual(5, poolSize); } - private void EnsurePoolSize(string connectionString, int requiredCurrentSize) + private void EnsurePoolSize(string connectionString, SecureString password, int requiredCurrentSize) { - var sessionPool = _connectionPoolManager.GetPool(connectionString, _password); + var sessionPool = _connectionPoolManager.GetPool(connectionString, password); sessionPool.SetMaxPoolSize(requiredCurrentSize); for (var i = 0; i < requiredCurrentSize; i++) { - _connectionPoolManager.GetSession(connectionString, _password); + _connectionPoolManager.GetSession(connectionString, password); } Assert.AreEqual(requiredCurrentSize, sessionPool.GetCurrentPoolSize()); } diff --git a/Snowflake.Data.Tests/UnitTests/Session/EasyLoggingStarterTest.cs b/Snowflake.Data.Tests/UnitTests/Logger/EasyLoggingStarterTest.cs similarity index 100% rename from Snowflake.Data.Tests/UnitTests/Session/EasyLoggingStarterTest.cs rename to Snowflake.Data.Tests/UnitTests/Logger/EasyLoggingStarterTest.cs diff --git a/Snowflake.Data.Tests/UnitTests/SFAuthenticatorFactoryTest.cs b/Snowflake.Data.Tests/UnitTests/SFAuthenticatorFactoryTest.cs index 3157619ae..d7399bd65 100644 --- a/Snowflake.Data.Tests/UnitTests/SFAuthenticatorFactoryTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFAuthenticatorFactoryTest.cs @@ -68,7 +68,7 @@ public void TestGetAuthenticatorOAuth() public void TestGetAuthenticatorOAuthWithMissingToken() { SnowflakeDbException ex = Assert.Throws(() => GetAuthenticator(OAuthAuthenticator.AUTH_NAME)); - Assert.AreEqual(SFError.INVALID_CONNECTION_STRING.GetAttribute().errorCode, ex.ErrorCode); + Assert.AreEqual(SFError.MISSING_CONNECTION_PROPERTY.GetAttribute().errorCode, ex.ErrorCode); } [Test] diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs index 8c3dfcc10..1bf07f037 100644 --- a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs @@ -102,20 +102,6 @@ public void TestValidateSupportEscapedQuotesValuesForObjectProperties(string pro Assert.AreEqual(value, properties[sessionProperty]); } - [Test] - public void TestProcessEmptyUserAndPasswordInConnectionString() - { - // arrange - var connectionString = $"ACCOUNT=test;USER=;PASSWORD=;"; - - // act - var properties = SFSessionProperties.ParseConnectionString(connectionString, null); - - // assert - Assert.AreEqual(string.Empty, properties[SFSessionProperty.USER]); - Assert.AreEqual(string.Empty, properties[SFSessionProperty.PASSWORD]); - } - public static IEnumerable ConnectionStringTestCases() { string defAccount = "testaccount"; diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs index b9530b83b..fa5eafbd1 100644 --- a/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs @@ -1,15 +1,12 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ -using Snowflake.Data.Configuration; -using Snowflake.Data.Log; +using Snowflake.Data.Core; +using NUnit.Framework; namespace Snowflake.Data.Tests.UnitTests { - using Snowflake.Data.Core; - using NUnit.Framework; - [TestFixture] class SFSessionTest { @@ -20,26 +17,61 @@ public void TestSessionGoneWhenClose() Mock.MockCloseSessionGone restRequester = new Mock.MockCloseSessionGone(); SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); - sfSession.close(); // no exception is raised. + Assert.DoesNotThrow(() => sfSession.close()); } [Test] - public void TestUpdateDatabaseAndSchema() + public void TestUpdateSessionProperties() { + // arrange string databaseName = "DB_TEST"; string schemaName = "SC_TEST"; - + string warehouseName = "WH_TEST"; + string roleName = "ROLE_TEST"; + QueryExecResponseData queryExecResponseData = new QueryExecResponseData + { + finalSchemaName = schemaName, + finalDatabaseName = databaseName, + finalRoleName = roleName, + finalWarehouseName = warehouseName + }; + + // act SFSession sfSession = new SFSession("account=test;user=test;password=test", null); - sfSession.UpdateDatabaseAndSchema(databaseName, schemaName); + sfSession.UpdateSessionProperties(queryExecResponseData); + // assert Assert.AreEqual(databaseName, sfSession.database); Assert.AreEqual(schemaName, sfSession.schema); + Assert.AreEqual(warehouseName, sfSession.warehouse); + Assert.AreEqual(roleName, sfSession.role); + } + + [Test] + public void TestSkipUpdateSessionPropertiesWhenPropertiesMissing() + { + // arrange + string databaseName = "DB_TEST"; + string schemaName = "SC_TEST"; + string warehouseName = "WH_TEST"; + string roleName = "ROLE_TEST"; + SFSession sfSession = new SFSession("account=test;user=test;password=test", null); + sfSession.database = databaseName; + sfSession.warehouse = warehouseName; + sfSession.role = roleName; + sfSession.schema = schemaName; + + // act + QueryExecResponseData queryExecResponseWithoutData = new QueryExecResponseData(); + sfSession.UpdateSessionProperties(queryExecResponseWithoutData); + // assert // when database or schema name is missing in the response, // the cached value should keep unchanged - sfSession.UpdateDatabaseAndSchema(null, null); Assert.AreEqual(databaseName, sfSession.database); Assert.AreEqual(schemaName, sfSession.schema); + Assert.AreEqual(warehouseName, sfSession.warehouse); + Assert.AreEqual(roleName, sfSession.role); } [Test] @@ -54,12 +86,42 @@ public void TestThatConfiguresEasyLogging(string configPath) var connectionString = configPath == null ? simpleConnectionString : $"{simpleConnectionString}client_config_file={configPath};"; - + // act new SFSession(connectionString, null, easyLoggingStarter.Object); - + // assert easyLoggingStarter.Verify(starter => starter.Init(configPath)); } + + [TestCase(null, "accountDefault", "accountDefault", false)] + [TestCase("initial", "initial", "initial", false)] + [TestCase("initial", null, "initial", false)] + [TestCase("initial", "IniTiaL", "initial", false)] + [TestCase("initial", "final", "final", true)] + [TestCase("initial", "\\\"final\\\"", "\"final\"", true)] + [TestCase("initial", "\\\"Final\\\"", "\"Final\"", true)] + [TestCase("\"Ini\\t\"ial\"", "\\\"Ini\\t\"ial\\\"", "\"Ini\\t\"ial\"", false)] + [TestCase("\"initial\"", "initial", "initial", true)] + [TestCase("\"initial\"", "\\\"initial\\\"", "\"initial\"", false)] + [TestCase("init\"ial", "init\"ial", "init\"ial", false)] + [TestCase("\"init\"ial\"", "\\\"init\"ial\\\"", "\"init\"ial\"", false)] + [TestCase("\"init\"ial\"", "\\\"Init\"ial\\\"", "\"Init\"ial\"", true)] + public void TestSessionPropertyQuotationSafeUpdateOnServerResponse(string sessionInitialValue, string serverResponseFinalSessionValue, string unquotedExpectedFinalValue, bool wasChanged) + { + // Arrange + SFSession sfSession = new SFSession("account=test;user=test;password=test", null); + var changedSessionValue = sessionInitialValue; + + // Act + sfSession.UpdateSessionProperty(ref changedSessionValue, serverResponseFinalSessionValue); + + // Assert + Assert.AreEqual(sfSession.SessionPropertiesChanged, wasChanged); + if (wasChanged || sessionInitialValue is null) + Assert.AreEqual(unquotedExpectedFinalValue, changedSessionValue); + else + Assert.AreEqual(sessionInitialValue, changedSessionValue); + } } } diff --git a/Snowflake.Data.Tests/UnitTests/Session/ConnectionPoolConfigExtractorTest.cs b/Snowflake.Data.Tests/UnitTests/Session/ConnectionPoolConfigExtractorTest.cs index 954a7c037..1f1c18758 100644 --- a/Snowflake.Data.Tests/UnitTests/Session/ConnectionPoolConfigExtractorTest.cs +++ b/Snowflake.Data.Tests/UnitTests/Session/ConnectionPoolConfigExtractorTest.cs @@ -234,6 +234,24 @@ public void TestExtractFailsForWrongValueOfPoolingEnabled(string propertyValue) Assert.That(thrown.Message, Does.Contain($"Invalid value of parameter {SFSessionProperty.POOLINGENABLED.ToString()}")); } + [Test] + [TestCase("OriginalPool", ChangedSessionBehavior.OriginalPool)] + [TestCase("originalpool", ChangedSessionBehavior.OriginalPool)] + [TestCase("ORIGINALPOOL", ChangedSessionBehavior.OriginalPool)] + [TestCase("Destroy", ChangedSessionBehavior.Destroy)] + [TestCase("DESTROY", ChangedSessionBehavior.Destroy)] + public void TestExtractChangedSessionBehaviour(string propertyValue, ChangedSessionBehavior expectedChangedSession) + { + // arrange + var connectionString = $"account=test;user=test;password=test;changedSession={propertyValue}"; + + // act + var result = ExtractConnectionPoolConfig(connectionString); + + // assert + Assert.AreEqual(expectedChangedSession, result.ChangedSession); + } + private ConnectionPoolConfig ExtractConnectionPoolConfig(string connectionString) { var properties = SFSessionProperties.ParseConnectionString(connectionString, null); diff --git a/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs b/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs index a66146823..95b4e596e 100644 --- a/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs +++ b/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs @@ -1,7 +1,10 @@ using System.Net; using System.Text.RegularExpressions; using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core; using Snowflake.Data.Core.Session; +using Snowflake.Data.Tests.Util; namespace Snowflake.Data.Tests.UnitTests.Session { @@ -87,17 +90,17 @@ public void TestPoolIdentificationBasedOnConnectionString(string connectionStrin } [Test] - public void TestPoolIdentificationForInvalidConnectionString() + public void TestRetrievePoolFailureForInvalidConnectionString() { // arrange var invalidConnectionString = "account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443"; // invalid because password is not provided - var pool = SessionPool.CreateSessionPool(invalidConnectionString, null); // act - var poolIdentification = pool.PoolIdentificationBasedOnConnectionString; + var exception = Assert.Throws(() => SessionPool.CreateSessionPool(invalidConnectionString, null)); // assert - Assert.AreEqual(" [pool: could not parse connection string]", poolIdentification); + SnowflakeDbExceptionAssert.HasErrorCode(exception, SFError.MISSING_CONNECTION_PROPERTY); + Assert.IsTrue(exception.Message.Contains("Required property PASSWORD is not provided")); } [Test] diff --git a/Snowflake.Data.Tests/UnitTests/SnowflakeDbConnectionPoolTest.cs b/Snowflake.Data.Tests/UnitTests/SnowflakeDbConnectionPoolTest.cs index 82ad550d9..e2863f0b5 100644 --- a/Snowflake.Data.Tests/UnitTests/SnowflakeDbConnectionPoolTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SnowflakeDbConnectionPoolTest.cs @@ -14,10 +14,10 @@ public void TestRevertPoolToPreviousVersion() { // act SnowflakeDbConnectionPool.SetOldConnectionPoolVersion(); - + // assert - var sessionPool1 = SnowflakeDbConnectionPool.GetPool(_connectionString1); - var sessionPool2 = SnowflakeDbConnectionPool.GetPool(_connectionString2); + var sessionPool1 = SnowflakeDbConnectionPool.GetPoolInternal(_connectionString1); + var sessionPool2 = SnowflakeDbConnectionPool.GetPoolInternal(_connectionString2); Assert.AreEqual(ConnectionPoolType.SingleConnectionCache, SnowflakeDbConnectionPool.GetConnectionPoolVersion()); Assert.AreEqual(sessionPool1, sessionPool2); } diff --git a/Snowflake.Data.Tests/Util/SnowflakeDbExceptionAssert.cs b/Snowflake.Data.Tests/Util/SnowflakeDbExceptionAssert.cs index 63432da31..881cba861 100644 --- a/Snowflake.Data.Tests/Util/SnowflakeDbExceptionAssert.cs +++ b/Snowflake.Data.Tests/Util/SnowflakeDbExceptionAssert.cs @@ -13,16 +13,16 @@ public static class SnowflakeDbExceptionAssert { public static void HasErrorCode(SnowflakeDbException exception, SFError sfError) { - Assert.AreEqual(exception.ErrorCode, sfError.GetAttribute().errorCode); + Assert.AreEqual(sfError.GetAttribute().errorCode, exception.ErrorCode); } - + public static void HasErrorCode(Exception exception, SFError sfError) { Assert.NotNull(exception); switch (exception) { case SnowflakeDbException snowflakeDbException: - Assert.AreEqual(snowflakeDbException.ErrorCode, sfError.GetAttribute().errorCode); + Assert.AreEqual(sfError.GetAttribute().errorCode, snowflakeDbException.ErrorCode); break; default: Assert.Fail(exception.GetType() + " type is not " + typeof(SnowflakeDbException)); @@ -45,7 +45,7 @@ public static void HasHttpErrorCodeInExceptionChain(Exception exception, HttpSta return he.Message.Contains(((int)expected).ToString()); #else return he.StatusCode == expected; -#endif +#endif default: return false; } diff --git a/Snowflake.Data/Client/SnowflakeDbConnection.cs b/Snowflake.Data/Client/SnowflakeDbConnection.cs index b0592db80..fc0ba199d 100755 --- a/Snowflake.Data/Client/SnowflakeDbConnection.cs +++ b/Snowflake.Data/Client/SnowflakeDbConnection.cs @@ -18,12 +18,12 @@ public class SnowflakeDbConnection : DbConnection { private SFLogger logger = SFLoggerFactory.GetLogger(); - internal SFSession SfSession { get; set; } + internal SFSession SfSession { get; set; } internal ConnectionState _connectionState; protected override DbProviderFactory DbProviderFactory => new SnowflakeDbFactory(); - + internal int _connectionTimeout; private bool _disposed = false; @@ -47,7 +47,7 @@ protected enum TransactionRollbackStatus public SnowflakeDbConnection() { _connectionState = ConnectionState.Closed; - _connectionTimeout = + _connectionTimeout = int.Parse(SFSessionProperty.CONNECTION_TIMEOUT.GetAttribute(). defaultValue); _isArrayBindStageCreated = false; @@ -84,12 +84,12 @@ private bool IsNonClosedWithSession() public override int ConnectionTimeout => this._connectionTimeout; /// - /// If the connection to the database is closed, the DataSource returns whatever is contained - /// in the ConnectionString for the DataSource keyword. If the connection is open and the - /// ConnectionString data source keyword's value starts with "|datadirectory|", the property - /// returns whatever is contained in the ConnectionString for the DataSource keyword only. If - /// the connection to the database is open, the property returns what the native provider - /// returns for the DBPROP_INIT_DATASOURCE, and if that is empty, the native provider's + /// If the connection to the database is closed, the DataSource returns whatever is contained + /// in the ConnectionString for the DataSource keyword. If the connection is open and the + /// ConnectionString data source keyword's value starts with "|datadirectory|", the property + /// returns whatever is contained in the ConnectionString for the DataSource keyword only. If + /// the connection to the database is open, the property returns what the native provider + /// returns for the DBPROP_INIT_DATASOURCE, and if that is empty, the native provider's /// DBPROP_DATASOURCENAME is returned. /// Note: not yet implemented /// @@ -115,7 +115,7 @@ public void PreventPooling() SfSession.SetPooling(false); logger.Debug($"Session {SfSession.sessionId} marked not to be pooled any more"); } - + internal bool HasActiveExplicitTransaction() => ExplicitTransaction != null && ExplicitTransaction.IsActive; private bool TryToReturnSessionToPool() @@ -150,12 +150,12 @@ private TransactionRollbackStatus TerminateTransactionForDirtyConnectionReturnin // error to indicate a problem within application code that a connection was closed while still having a pending transaction logger.Error("Closing dirty connection: rollback transaction in session " + SfSession.sessionId + " succeeded."); ExplicitTransaction = null; - return TransactionRollbackStatus.Success; + return TransactionRollbackStatus.Success; } } catch (Exception exception) { - // error to indicate a problem with rollback of an active transaction and inability to return dirty connection to the pool + // error to indicate a problem with rollback of an active transaction and inability to return dirty connection to the pool logger.Error("Closing dirty connection: rollback transaction in session: " + SfSession.sessionId + " failed, exception: " + exception.Message); return TransactionRollbackStatus.Failure; // connection won't be pooled } @@ -254,10 +254,10 @@ await SfSession.CloseAsync(cancellationToken).ContinueWith( protected virtual bool CanReuseSession(TransactionRollbackStatus transactionRollbackStatus) { - return SnowflakeDbConnectionPool.GetPooling() && + return SnowflakeDbConnectionPool.GetPooling() && transactionRollbackStatus == TransactionRollbackStatus.Success; } - + public override void Open() { logger.Debug("Open Connection."); @@ -401,7 +401,7 @@ protected override void Dispose(bool disposing) SfSession = null; _connectionState = ConnectionState.Closed; } - + _disposed = true; } diff --git a/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs b/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs index e3c21e20a..617c07ebd 100644 --- a/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs +++ b/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ using System; @@ -42,9 +42,21 @@ internal static Task GetSessionAsync(string connectionString, SecureS return ConnectionManager.GetSessionAsync(connectionString, password, cancellationToken); } - internal static SessionPool GetPool(string connectionString) + public static SnowflakeDbSessionPool GetPool(string connectionString, SecureString password) { s_logger.Debug($"SnowflakeDbConnectionPool::GetPool"); + return new SnowflakeDbSessionPool(ConnectionManager.GetPool(connectionString, password)); + } + + public static SnowflakeDbSessionPool GetPool(string connectionString) + { + s_logger.Debug($"SnowflakeDbConnectionPool::GetPool"); + return new SnowflakeDbSessionPool(ConnectionManager.GetPool(connectionString)); + } + + internal static SessionPool GetPoolInternal(string connectionString) + { + s_logger.Debug($"SnowflakeDbConnectionPool::GetPoolInternal"); return ConnectionManager.GetPool(connectionString); } diff --git a/Snowflake.Data/Client/SnowflakeDbSessionPool.cs b/Snowflake.Data/Client/SnowflakeDbSessionPool.cs new file mode 100644 index 000000000..2bd3caecc --- /dev/null +++ b/Snowflake.Data/Client/SnowflakeDbSessionPool.cs @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Client +{ + public class SnowflakeDbSessionPool + { + private readonly SessionPool _sessionPool; + + internal SnowflakeDbSessionPool(SessionPool sessionPool) + => _sessionPool = sessionPool ?? throw new NullReferenceException("SessionPool not provided!"); + + public bool GetPooling() => _sessionPool.GetPooling(); + + public int GetMinPoolSize() => _sessionPool.GetMinPoolSize(); + + public int GetMaxPoolSize() => _sessionPool.GetMaxPoolSize(); + + public int GetCurrentPoolSize() => _sessionPool.GetCurrentPoolSize(); + + public long GetExpirationTimeout() => _sessionPool.GetTimeout(); + + public long GetConnectionTimeout() => _sessionPool.GetConnectionTimeout(); + + public long GetWaitForIdleSessionTimeout() => _sessionPool.GetWaitForIdleSessionTimeout(); + + public void ClearPool() => _sessionPool.ClearSessions(); + + public ChangedSessionBehavior GetChangedSession() => _sessionPool.GetChangedSession(); + } +} diff --git a/Snowflake.Data/Core/ArrowResultSet.cs b/Snowflake.Data/Core/ArrowResultSet.cs index 31e0eccca..56a636c4e 100755 --- a/Snowflake.Data/Core/ArrowResultSet.cs +++ b/Snowflake.Data/Core/ArrowResultSet.cs @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ using System; @@ -398,7 +398,7 @@ internal override string GetString(int ordinal) private void UpdateSessionStatus(QueryExecResponseData responseData) { SFSession session = this.sfStatement.SfSession; - session.UpdateDatabaseAndSchema(responseData.finalDatabaseName, responseData.finalSchemaName); + session.UpdateSessionProperties(responseData); session.UpdateSessionParameterMap(responseData.parameters); } diff --git a/Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs b/Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs index d7a7a29d1..a26d542d3 100644 --- a/Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs @@ -10,7 +10,7 @@ namespace Snowflake.Data.Core.Authenticator { class BasicAuthenticator : BaseAuthenticator, IAuthenticator { - public static readonly string AUTH_NAME = "snowflake"; + public const string AUTH_NAME = "snowflake"; private static readonly SFLogger logger = SFLoggerFactory.GetLogger(); internal BasicAuthenticator(SFSession session) : base(session, AUTH_NAME) diff --git a/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs b/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs index d6ead6818..3b882a05b 100644 --- a/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs @@ -21,7 +21,7 @@ namespace Snowflake.Data.Core.Authenticator /// class ExternalBrowserAuthenticator : BaseAuthenticator, IAuthenticator { - public static readonly string AUTH_NAME = "externalbrowser"; + public const string AUTH_NAME = "externalbrowser"; private static readonly SFLogger logger = SFLoggerFactory.GetLogger(); private static readonly string TOKEN_REQUEST_PREFIX = "?token="; private static readonly byte[] SUCCESS_RESPONSE = System.Text.Encoding.UTF8.GetBytes( @@ -87,7 +87,7 @@ await session.restRequester.PostAsync( logger.Warn("Browser response timeout"); throw new SnowflakeDbException(SFError.BROWSER_RESPONSE_TIMEOUT, timeoutInSec); } - + httpListener.Stop(); } @@ -134,7 +134,7 @@ void IAuthenticator.Authenticate() logger.Warn("Browser response timeout"); throw new SnowflakeDbException(SFError.BROWSER_RESPONSE_TIMEOUT, timeoutInSec); } - + httpListener.Stop(); } @@ -150,7 +150,7 @@ private void GetContextCallback(IAsyncResult result) { HttpListenerContext context = httpListener.EndGetContext(result); HttpListenerRequest request = context.Request; - + _samlResponseToken = ValidateAndExtractToken(request); HttpListenerResponse response = context.Response; try diff --git a/Snowflake.Data/Core/Authenticator/IAuthenticator.cs b/Snowflake.Data/Core/Authenticator/IAuthenticator.cs index 150551f91..7a41a8335 100644 --- a/Snowflake.Data/Core/Authenticator/IAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/IAuthenticator.cs @@ -17,7 +17,7 @@ namespace Snowflake.Data.Core.Authenticator internal interface IAuthenticator { /// - /// Process the authentication asynchronouly + /// Process the authentication asynchronously /// /// /// @@ -49,19 +49,19 @@ internal abstract class BaseAuthenticator SFLoggerFactory.GetLogger(); // The name of the authenticator. - protected string authName; + private string authName; // The session which created this authenticator. protected SFSession session; // The client environment properties - protected LoginRequestClientEnv ClientEnv = SFEnvironment.ClientEnv; + private LoginRequestClientEnv ClientEnv = SFEnvironment.ClientEnv; /// /// The abstract base for all authenticators. /// /// The session which created the authenticator. - public BaseAuthenticator(SFSession session, string authName) + protected BaseAuthenticator(SFSession session, string authName) { this.session = session; this.authName = authName; @@ -104,7 +104,7 @@ protected void Login() /// /// Builds a simple login request. Each authenticator will fill the Data part with their /// specialized information. The common Data attributes are already filled (clientAppId, - /// ClienAppVersion...). + /// ClientAppVersion...). /// /// A login request to send to the server. private SFRestRequest BuildLoginRequest() @@ -129,10 +129,10 @@ private SFRestRequest BuildLoginRequest() } } - /// - /// Authenticator Factory to build authenticators - /// - internal class AuthenticatorFactory + /// + /// Authenticator Factory to build authenticators + /// + internal class AuthenticatorFactory { private static readonly SFLogger logger = SFLoggerFactory.GetLogger(); /// @@ -155,8 +155,8 @@ internal static IAuthenticator GetAuthenticator(SFSession session) else if (type.Equals(KeyPairAuthenticator.AUTH_NAME, StringComparison.InvariantCultureIgnoreCase)) { // Get private key path or private key from connection settings - if (!session.properties.TryGetValue(SFSessionProperty.PRIVATE_KEY_FILE, out var pkPath) && - !session.properties.TryGetValue(SFSessionProperty.PRIVATE_KEY, out var pkContent)) + if ((!session.properties.TryGetValue(SFSessionProperty.PRIVATE_KEY_FILE, out var pkPath) || string.IsNullOrEmpty(pkPath)) && + (!session.properties.TryGetValue(SFSessionProperty.PRIVATE_KEY, out var pkContent) || string.IsNullOrEmpty(pkContent))) { // There is no PRIVATE_KEY_FILE defined, can't authenticate with key-pair string invalidStringDetail = @@ -192,12 +192,8 @@ internal static IAuthenticator GetAuthenticator(SFSession session) { return new OktaAuthenticator(session, type); } - - var e = new SnowflakeDbException(SFError.UNKNOWN_AUTHENTICATOR, type); - - logger.Error("Unknown authenticator", e); - - throw e; + logger.Error($"Unknown authenticator {type}"); + throw new SnowflakeDbException(SFError.UNKNOWN_AUTHENTICATOR, type); } } } diff --git a/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs b/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs index fcfb70695..e0c28d4ef 100644 --- a/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs @@ -28,7 +28,7 @@ namespace Snowflake.Data.Core.Authenticator class KeyPairAuthenticator : BaseAuthenticator, IAuthenticator { // The authenticator setting value to use to authenticate using key pair authentication. - public static readonly string AUTH_NAME = "snowflake_jwt"; + public const string AUTH_NAME = "snowflake_jwt"; // The logger. private static readonly SFLogger logger = @@ -85,9 +85,9 @@ private string GenerateJwtToken() { logger.Info("Key-pair Authentication"); - bool hasPkPath = + bool hasPkPath = session.properties.TryGetValue(SFSessionProperty.PRIVATE_KEY_FILE, out var pkPath); - bool hasPkContent = + bool hasPkContent = session.properties.TryGetValue(SFSessionProperty.PRIVATE_KEY, out var pkContent); session.properties.TryGetValue(SFSessionProperty.PRIVATE_KEY_PWD, out var pkPwd); @@ -152,31 +152,31 @@ private string GenerateJwtToken() byte[] sha256Hash = SHA256Encoder.ComputeHash(publicKeyEncoded); publicKeyFingerPrint = "SHA256:" + Convert.ToBase64String(sha256Hash); } - - // Generating the token + + // Generating the token var now = DateTime.UtcNow; System.DateTime dtDateTime = new DateTime(1970, 1, 1, 0, 0, 0, 0, System.DateTimeKind.Utc); long secondsSinceEpoch = (long)((now - dtDateTime).TotalSeconds); - /* + /* * Payload content - * iss : $accountName.$userName.$pulicKeyFingerprint + * iss : $accountName.$userName.$publicKeyFingerprint * sub : $accountName.$userName * iat : $now * exp : $now + LIFETIME - * + * * Note : Lifetime = 120sec for Python impl, 60sec for Jdbc and Odbc */ - String accountUser = - session.properties[SFSessionProperty.ACCOUNT].ToUpper() + - "." + + String accountUser = + session.properties[SFSessionProperty.ACCOUNT].ToUpper() + + "." + session.properties[SFSessionProperty.USER].ToUpper(); String issuer = accountUser + "." + publicKeyFingerPrint; var claims = new[] { new Claim( - JwtRegisteredClaimNames.Iat, - secondsSinceEpoch.ToString(), + JwtRegisteredClaimNames.Iat, + secondsSinceEpoch.ToString(), System.Security.Claims.ClaimValueTypes.Integer64), new Claim(JwtRegisteredClaimNames.Sub, accountUser), }; diff --git a/Snowflake.Data/Core/Authenticator/OAuthAuthenticator.cs b/Snowflake.Data/Core/Authenticator/OAuthAuthenticator.cs index 5e8f4a310..f36d0353e 100644 --- a/Snowflake.Data/Core/Authenticator/OAuthAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/OAuthAuthenticator.cs @@ -14,7 +14,7 @@ namespace Snowflake.Data.Core.Authenticator class OAuthAuthenticator : BaseAuthenticator, IAuthenticator { // The authenticator setting value to use to authenticate using key pair authentication. - public static readonly string AUTH_NAME = "oauth"; + public const string AUTH_NAME = "oauth"; // The logger. private static readonly SFLogger logger = diff --git a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs index 6ec4843d4..1780ccffc 100644 --- a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs @@ -22,6 +22,7 @@ namespace Snowflake.Data.Core.Authenticator /// class OktaAuthenticator : BaseAuthenticator, IAuthenticator { + public const string AUTH_NAME = "okta"; private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); internal const string RetryCountHeader = "RetryCount"; diff --git a/Snowflake.Data/Core/SFMultiStatementsResultSet.cs b/Snowflake.Data/Core/SFMultiStatementsResultSet.cs index c811deb8b..18eb4f650 100644 --- a/Snowflake.Data/Core/SFMultiStatementsResultSet.cs +++ b/Snowflake.Data/Core/SFMultiStatementsResultSet.cs @@ -112,7 +112,7 @@ internal override bool Rewind() private void updateSessionStatus(QueryExecResponseData responseData) { SFSession session = this.sfStatement.SfSession; - session.UpdateDatabaseAndSchema(responseData.finalDatabaseName, responseData.finalSchemaName); + session.UpdateSessionProperties(responseData); session.UpdateSessionParameterMap(responseData.parameters); session.UpdateQueryContextCache(responseData.QueryContext); } diff --git a/Snowflake.Data/Core/SFResultSet.cs b/Snowflake.Data/Core/SFResultSet.cs index 03e1794c9..a7586f2c3 100755 --- a/Snowflake.Data/Core/SFResultSet.cs +++ b/Snowflake.Data/Core/SFResultSet.cs @@ -1,376 +1,376 @@ -/* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. - */ - -using System; -using System.Threading; -using System.Threading.Tasks; -using Snowflake.Data.Log; -using Snowflake.Data.Client; -using System.Collections.Generic; -using System.Diagnostics; - -namespace Snowflake.Data.Core -{ - class SFResultSet : SFBaseResultSet - { - internal override ResultFormat ResultFormat => ResultFormat.JSON; - - private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); - - private readonly int _totalChunkCount; - - private readonly IChunkDownloader _chunkDownloader; - - private BaseResultChunk _currentChunk; - - public SFResultSet(QueryExecResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base() - { - try - { - columnCount = responseData.rowType?.Count ?? 0; - - this.sfStatement = sfStatement; - UpdateSessionStatus(responseData); - - if (responseData.chunks != null) - { - // counting the first chunk - _totalChunkCount = responseData.chunks.Count; - _chunkDownloader = ChunkDownloaderFactory.GetDownloader(responseData, this, cancellationToken); - } - - _currentChunk = responseData.rowSet != null ? new SFResultChunk(responseData.rowSet) : null; - responseData.rowSet = null; - - sfResultSetMetaData = responseData.rowType != null ? new SFResultSetMetaData(responseData, this.sfStatement.SfSession) : null; - - isClosed = false; - - queryId = responseData.queryId; - } - catch(System.Exception ex) - { - s_logger.Error("Result set error queryId="+responseData.queryId, ex); - throw; - } - } - - public enum PutGetResponseRowTypeInfo { - SourceFileName = 0, - DestinationFileName = 1, - SourceFileSize = 2, - DestinationFileSize = 3, - SourceCompressionType = 4, - DestinationCompressionType = 5, - ResultStatus = 6, - ErrorDetails = 7 - } - - public void InitializePutGetRowType(List rowType) - { - foreach (PutGetResponseRowTypeInfo t in System.Enum.GetValues(typeof(PutGetResponseRowTypeInfo))) - { - rowType.Add(new ExecResponseRowType() - { - name = t.ToString(), - type = "text" - }); - } - } - - public SFResultSet(PutGetResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base() - { - responseData.rowType = new List(); - InitializePutGetRowType(responseData.rowType); - - columnCount = responseData.rowType.Count; - - this.sfStatement = sfStatement; - - _currentChunk = new SFResultChunk(responseData.rowSet); - responseData.rowSet = null; - - sfResultSetMetaData = new SFResultSetMetaData(responseData); - - isClosed = false; - - queryId = responseData.queryId; - } - - internal void ResetChunkInfo(BaseResultChunk nextChunk) - { - s_logger.Debug($"Received chunk #{nextChunk.ChunkIndex + 1} of {_totalChunkCount}"); - _currentChunk.RowSet = null; - _currentChunk = nextChunk; - } - - internal override async Task NextAsync() - { - ThrowIfClosed(); - - if (_currentChunk.Next()) - return true; - - if (_chunkDownloader != null) - { - // GetNextChunk could be blocked if download result is not done yet. - // So put this piece of code in a seperate task - s_logger.Debug($"Get next chunk from chunk downloader, chunk: {_currentChunk.ChunkIndex + 1}/{_totalChunkCount}" + - $" rows: {_currentChunk.RowCount}, size compressed: {_currentChunk.CompressedSize}," + - $" size uncompressed: {_currentChunk.UncompressedSize}"); - BaseResultChunk nextChunk = await _chunkDownloader.GetNextChunkAsync().ConfigureAwait(false); - if (nextChunk != null) - { - ResetChunkInfo(nextChunk); - return _currentChunk.Next(); - } - } - - return false; - } - - internal override bool Next() - { - ThrowIfClosed(); - - if (_currentChunk.Next()) - return true; - - if (_chunkDownloader != null) - { - s_logger.Debug($"Get next chunk from chunk downloader, chunk: {_currentChunk.ChunkIndex + 1}/{_totalChunkCount}" + - $" rows: {_currentChunk.RowCount}, size compressed: {_currentChunk.CompressedSize}," + - $" size uncompressed: {_currentChunk.UncompressedSize}"); - BaseResultChunk nextChunk = Task.Run(async() => await (_chunkDownloader.GetNextChunkAsync()).ConfigureAwait(false)).Result; - if (nextChunk != null) - { - ResetChunkInfo(nextChunk); - return _currentChunk.Next(); - } - } - return false; - } - - internal override bool NextResult() - { - return false; - } - - internal override async Task NextResultAsync(CancellationToken cancellationToken) - { - return await Task.FromResult(false); - } - - internal override bool HasRows() - { - ThrowIfClosed(); - - return _currentChunk.RowCount > 0 || _totalChunkCount > 0; - } - - /// - /// Move cursor back one row. - /// - /// True if it works, false otherwise. - internal override bool Rewind() - { - ThrowIfClosed(); - - return _currentChunk.Rewind(); - } - - internal UTF8Buffer GetObjectInternal(int ordinal) - { - ThrowIfClosed(); - ThrowIfOutOfBounds(ordinal); - - return _currentChunk.ExtractCell(ordinal); - } - - private void UpdateSessionStatus(QueryExecResponseData responseData) - { - SFSession session = this.sfStatement.SfSession; - session.UpdateDatabaseAndSchema(responseData.finalDatabaseName, responseData.finalSchemaName); - session.UpdateSessionParameterMap(responseData.parameters); - session.UpdateQueryContextCache(responseData.QueryContext); - } - - internal override bool IsDBNull(int ordinal) - { - return (null == GetObjectInternal(ordinal)); - } - - internal override bool GetBoolean(int ordinal) - { - return GetValue(ordinal); - } - - internal override byte GetByte(int ordinal) - { - return GetValue(ordinal); - } - - internal override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length) - { - return ReadSubset(ordinal, dataOffset, buffer, bufferOffset, length); - } - - internal override char GetChar(int ordinal) - { - string val = GetString(ordinal); - return val[0]; - } - - internal override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) - { - return ReadSubset(ordinal, dataOffset, buffer, bufferOffset, length); - } - - internal override DateTime GetDateTime(int ordinal) - { - return GetValue(ordinal); - } - - internal override TimeSpan GetTimeSpan(int ordinal) - { - return GetValue(ordinal); - } - - internal override decimal GetDecimal(int ordinal) - { - return GetValue(ordinal); - } - - internal override double GetDouble(int ordinal) - { - return GetValue(ordinal); - } - - internal override float GetFloat(int ordinal) - { - return GetValue(ordinal); - } - - internal override Guid GetGuid(int ordinal) - { - return GetValue(ordinal); - } - - internal override short GetInt16(int ordinal) - { - return GetValue(ordinal); - } - - internal override int GetInt32(int ordinal) - { - return GetValue(ordinal); - } - - internal override long GetInt64(int ordinal) - { - return GetValue(ordinal); - } - - internal override string GetString(int ordinal) - { - ThrowIfOutOfBounds(ordinal); - - var type = sfResultSetMetaData.GetColumnTypeByIndex(ordinal); - switch (type) - { - case SFDataType.DATE: - var val = GetValue(ordinal); - if (val == DBNull.Value) - return null; - return SFDataConverter.toDateString((DateTime)val, sfResultSetMetaData.dateOutputFormat); - - default: - return GetObjectInternal(ordinal).SafeToString(); - } - } - - internal override object GetValue(int ordinal) - { - UTF8Buffer val = GetObjectInternal(ordinal); - var types = sfResultSetMetaData.GetTypesByIndex(ordinal); - return SFDataConverter.ConvertToCSharpVal(val, types.Item1, types.Item2); - } - - private T GetValue(int ordinal) - { - UTF8Buffer val = GetObjectInternal(ordinal); - var types = sfResultSetMetaData.GetTypesByIndex(ordinal); - return (T)SFDataConverter.ConvertToCSharpVal(val, types.Item1, typeof(T)); - } - - // - // Summary: - // Reads a subset of data starting at location indicated by dataOffset into the buffer, - // starting at the location indicated by bufferOffset. - // - // Parameters: - // ordinal: - // The zero-based column ordinal. - // - // dataOffset: - // The index within the data from which to begin the read operation. - // - // buffer: - // The buffer into which to copy the data. - // - // bufferOffset: - // The index with the buffer to which the data will be copied. - // - // length: - // The maximum number of elements to read. - // - // Returns: - // The actual number of elements read. - private long ReadSubset(int ordinal, long dataOffset, T[] buffer, int bufferOffset, int length) where T : struct - { - if (dataOffset < 0) - { - throw new ArgumentOutOfRangeException("dataOffset", "Non negative number is required."); - } - - if (bufferOffset < 0) - { - throw new ArgumentOutOfRangeException("bufferOffset", "Non negative number is required."); - } - - if ((null != buffer) && (bufferOffset > buffer.Length)) - { - throw new System.ArgumentException("Destination buffer is not long enough. " + - "Check the buffer offset, length, and the buffer's lower bounds.", "buffer"); - } - - T[] data = GetValue(ordinal); - - // https://docs.microsoft.com/en-us/dotnet/api/system.data.idatarecord.getbytes?view=net-5.0#remarks - // If you pass a buffer that is null, GetBytes returns the length of the row in bytes. - // https://docs.microsoft.com/en-us/dotnet/api/system.data.idatarecord.getchars?view=net-5.0#remarks - // If you pass a buffer that is null, GetChars returns the length of the field in characters. - if (null == buffer) - { - return data.Length; - } - - if (dataOffset > data.Length) - { - throw new System.ArgumentException("Source data is not long enough. " + - "Check the data offset, length, and the data's lower bounds." ,"dataOffset"); - } - else - { - // How much data is available after the offset - long dataLength = data.Length - dataOffset; - // How much data to read - long elementsRead = Math.Min(length, dataLength); - Array.Copy(data, dataOffset, buffer, bufferOffset, elementsRead); - - return elementsRead; - } - } - } -} +/* + * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Threading; +using System.Threading.Tasks; +using Snowflake.Data.Log; +using Snowflake.Data.Client; +using System.Collections.Generic; +using System.Diagnostics; + +namespace Snowflake.Data.Core +{ + class SFResultSet : SFBaseResultSet + { + internal override ResultFormat ResultFormat => ResultFormat.JSON; + + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private readonly int _totalChunkCount; + + private readonly IChunkDownloader _chunkDownloader; + + private BaseResultChunk _currentChunk; + + public SFResultSet(QueryExecResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base() + { + try + { + columnCount = responseData.rowType?.Count ?? 0; + + this.sfStatement = sfStatement; + UpdateSessionStatus(responseData); + + if (responseData.chunks != null) + { + // counting the first chunk + _totalChunkCount = responseData.chunks.Count; + _chunkDownloader = ChunkDownloaderFactory.GetDownloader(responseData, this, cancellationToken); + } + + _currentChunk = responseData.rowSet != null ? new SFResultChunk(responseData.rowSet) : null; + responseData.rowSet = null; + + sfResultSetMetaData = responseData.rowType != null ? new SFResultSetMetaData(responseData, this.sfStatement.SfSession) : null; + + isClosed = false; + + queryId = responseData.queryId; + } + catch(System.Exception ex) + { + s_logger.Error("Result set error queryId="+responseData.queryId, ex); + throw; + } + } + + public enum PutGetResponseRowTypeInfo { + SourceFileName = 0, + DestinationFileName = 1, + SourceFileSize = 2, + DestinationFileSize = 3, + SourceCompressionType = 4, + DestinationCompressionType = 5, + ResultStatus = 6, + ErrorDetails = 7 + } + + public void InitializePutGetRowType(List rowType) + { + foreach (PutGetResponseRowTypeInfo t in System.Enum.GetValues(typeof(PutGetResponseRowTypeInfo))) + { + rowType.Add(new ExecResponseRowType() + { + name = t.ToString(), + type = "text" + }); + } + } + + public SFResultSet(PutGetResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base() + { + responseData.rowType = new List(); + InitializePutGetRowType(responseData.rowType); + + columnCount = responseData.rowType.Count; + + this.sfStatement = sfStatement; + + _currentChunk = new SFResultChunk(responseData.rowSet); + responseData.rowSet = null; + + sfResultSetMetaData = new SFResultSetMetaData(responseData); + + isClosed = false; + + queryId = responseData.queryId; + } + + internal void ResetChunkInfo(BaseResultChunk nextChunk) + { + s_logger.Debug($"Received chunk #{nextChunk.ChunkIndex + 1} of {_totalChunkCount}"); + _currentChunk.RowSet = null; + _currentChunk = nextChunk; + } + + internal override async Task NextAsync() + { + ThrowIfClosed(); + + if (_currentChunk.Next()) + return true; + + if (_chunkDownloader != null) + { + // GetNextChunk could be blocked if download result is not done yet. + // So put this piece of code in a seperate task + s_logger.Debug($"Get next chunk from chunk downloader, chunk: {_currentChunk.ChunkIndex + 1}/{_totalChunkCount}" + + $" rows: {_currentChunk.RowCount}, size compressed: {_currentChunk.CompressedSize}," + + $" size uncompressed: {_currentChunk.UncompressedSize}"); + BaseResultChunk nextChunk = await _chunkDownloader.GetNextChunkAsync().ConfigureAwait(false); + if (nextChunk != null) + { + ResetChunkInfo(nextChunk); + return _currentChunk.Next(); + } + } + + return false; + } + + internal override bool Next() + { + ThrowIfClosed(); + + if (_currentChunk.Next()) + return true; + + if (_chunkDownloader != null) + { + s_logger.Debug($"Get next chunk from chunk downloader, chunk: {_currentChunk.ChunkIndex + 1}/{_totalChunkCount}" + + $" rows: {_currentChunk.RowCount}, size compressed: {_currentChunk.CompressedSize}," + + $" size uncompressed: {_currentChunk.UncompressedSize}"); + BaseResultChunk nextChunk = Task.Run(async() => await (_chunkDownloader.GetNextChunkAsync()).ConfigureAwait(false)).Result; + if (nextChunk != null) + { + ResetChunkInfo(nextChunk); + return _currentChunk.Next(); + } + } + return false; + } + + internal override bool NextResult() + { + return false; + } + + internal override async Task NextResultAsync(CancellationToken cancellationToken) + { + return await Task.FromResult(false); + } + + internal override bool HasRows() + { + ThrowIfClosed(); + + return _currentChunk.RowCount > 0 || _totalChunkCount > 0; + } + + /// + /// Move cursor back one row. + /// + /// True if it works, false otherwise. + internal override bool Rewind() + { + ThrowIfClosed(); + + return _currentChunk.Rewind(); + } + + internal UTF8Buffer GetObjectInternal(int ordinal) + { + ThrowIfClosed(); + ThrowIfOutOfBounds(ordinal); + + return _currentChunk.ExtractCell(ordinal); + } + + private void UpdateSessionStatus(QueryExecResponseData responseData) + { + SFSession session = this.sfStatement.SfSession; + session.UpdateSessionProperties(responseData); + session.UpdateSessionParameterMap(responseData.parameters); + session.UpdateQueryContextCache(responseData.QueryContext); + } + + internal override bool IsDBNull(int ordinal) + { + return (null == GetObjectInternal(ordinal)); + } + + internal override bool GetBoolean(int ordinal) + { + return GetValue(ordinal); + } + + internal override byte GetByte(int ordinal) + { + return GetValue(ordinal); + } + + internal override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length) + { + return ReadSubset(ordinal, dataOffset, buffer, bufferOffset, length); + } + + internal override char GetChar(int ordinal) + { + string val = GetString(ordinal); + return val[0]; + } + + internal override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) + { + return ReadSubset(ordinal, dataOffset, buffer, bufferOffset, length); + } + + internal override DateTime GetDateTime(int ordinal) + { + return GetValue(ordinal); + } + + internal override TimeSpan GetTimeSpan(int ordinal) + { + return GetValue(ordinal); + } + + internal override decimal GetDecimal(int ordinal) + { + return GetValue(ordinal); + } + + internal override double GetDouble(int ordinal) + { + return GetValue(ordinal); + } + + internal override float GetFloat(int ordinal) + { + return GetValue(ordinal); + } + + internal override Guid GetGuid(int ordinal) + { + return GetValue(ordinal); + } + + internal override short GetInt16(int ordinal) + { + return GetValue(ordinal); + } + + internal override int GetInt32(int ordinal) + { + return GetValue(ordinal); + } + + internal override long GetInt64(int ordinal) + { + return GetValue(ordinal); + } + + internal override string GetString(int ordinal) + { + ThrowIfOutOfBounds(ordinal); + + var type = sfResultSetMetaData.GetColumnTypeByIndex(ordinal); + switch (type) + { + case SFDataType.DATE: + var val = GetValue(ordinal); + if (val == DBNull.Value) + return null; + return SFDataConverter.toDateString((DateTime)val, sfResultSetMetaData.dateOutputFormat); + + default: + return GetObjectInternal(ordinal).SafeToString(); + } + } + + internal override object GetValue(int ordinal) + { + UTF8Buffer val = GetObjectInternal(ordinal); + var types = sfResultSetMetaData.GetTypesByIndex(ordinal); + return SFDataConverter.ConvertToCSharpVal(val, types.Item1, types.Item2); + } + + private T GetValue(int ordinal) + { + UTF8Buffer val = GetObjectInternal(ordinal); + var types = sfResultSetMetaData.GetTypesByIndex(ordinal); + return (T)SFDataConverter.ConvertToCSharpVal(val, types.Item1, typeof(T)); + } + + // + // Summary: + // Reads a subset of data starting at location indicated by dataOffset into the buffer, + // starting at the location indicated by bufferOffset. + // + // Parameters: + // ordinal: + // The zero-based column ordinal. + // + // dataOffset: + // The index within the data from which to begin the read operation. + // + // buffer: + // The buffer into which to copy the data. + // + // bufferOffset: + // The index with the buffer to which the data will be copied. + // + // length: + // The maximum number of elements to read. + // + // Returns: + // The actual number of elements read. + private long ReadSubset(int ordinal, long dataOffset, T[] buffer, int bufferOffset, int length) where T : struct + { + if (dataOffset < 0) + { + throw new ArgumentOutOfRangeException("dataOffset", "Non negative number is required."); + } + + if (bufferOffset < 0) + { + throw new ArgumentOutOfRangeException("bufferOffset", "Non negative number is required."); + } + + if ((null != buffer) && (bufferOffset > buffer.Length)) + { + throw new System.ArgumentException("Destination buffer is not long enough. " + + "Check the buffer offset, length, and the buffer's lower bounds.", "buffer"); + } + + T[] data = GetValue(ordinal); + + // https://docs.microsoft.com/en-us/dotnet/api/system.data.idatarecord.getbytes?view=net-5.0#remarks + // If you pass a buffer that is null, GetBytes returns the length of the row in bytes. + // https://docs.microsoft.com/en-us/dotnet/api/system.data.idatarecord.getchars?view=net-5.0#remarks + // If you pass a buffer that is null, GetChars returns the length of the field in characters. + if (null == buffer) + { + return data.Length; + } + + if (dataOffset > data.Length) + { + throw new System.ArgumentException("Source data is not long enough. " + + "Check the data offset, length, and the data's lower bounds." ,"dataOffset"); + } + else + { + // How much data is available after the offset + long dataLength = data.Length - dataOffset; + // How much data to read + long elementsRead = Math.Min(length, dataLength); + Array.Copy(data, dataOffset, buffer, bufferOffset, elementsRead); + + return elementsRead; + } + } + } +} diff --git a/Snowflake.Data/Core/Session/ChangedSessionBehavior.cs b/Snowflake.Data/Core/Session/ChangedSessionBehavior.cs index caf7ded2a..a6771930c 100644 --- a/Snowflake.Data/Core/Session/ChangedSessionBehavior.cs +++ b/Snowflake.Data/Core/Session/ChangedSessionBehavior.cs @@ -1,33 +1,15 @@ -using System; -using System.Collections.Generic; -using System.Linq; +/* + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. + */ namespace Snowflake.Data.Core.Session { /** - * It describes what should happen to a session with a changed state (e. g. schema/role/etc.) when it is being returned to the pool. + * ChangedSessionBehavior describes what should happen to a session with a changed state (schema/role/database/warehouse) when it returns to the pool. */ - internal enum ChangedSessionBehavior + public enum ChangedSessionBehavior { OriginalPool, - ChangePool, Destroy } - - internal static class ChangedSessionBehaviorExtensions - { - public static List StringValues() - { - return Enum.GetValues(typeof(ChangedSessionBehavior)) - .Cast() - .Where(e => e == ChangedSessionBehavior.OriginalPool) // currently we support only OriginalPool case; TODO: SNOW-937188 - .Select(b => b.ToString()) - .ToList(); - } - - public static ChangedSessionBehavior From(string changedSession) - { - return (ChangedSessionBehavior) Enum.Parse(typeof(ChangedSessionBehavior), changedSession, true); - } - } } diff --git a/Snowflake.Data/Core/Session/ConnectionCacheManager.cs b/Snowflake.Data/Core/Session/ConnectionCacheManager.cs index b7c885234..febecbbce 100644 --- a/Snowflake.Data/Core/Session/ConnectionCacheManager.cs +++ b/Snowflake.Data/Core/Session/ConnectionCacheManager.cs @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ using System.Security; @@ -24,6 +24,7 @@ public Task GetSessionAsync(string connectionString, SecureString pas public int GetCurrentPoolSize() => _sessionPool.GetCurrentPoolSize(); public bool SetPooling(bool poolingEnabled) => _sessionPool.SetPooling(poolingEnabled); public bool GetPooling() => _sessionPool.GetPooling(); - public SessionPool GetPool(string _) => _sessionPool; + public SessionPool GetPool(string connectionString) => _sessionPool; + public SessionPool GetPool(string connectionString, SecureString password) => _sessionPool; } } diff --git a/Snowflake.Data/Core/Session/ConnectionPoolManager.cs b/Snowflake.Data/Core/Session/ConnectionPoolManager.cs index 4f70ab691..ea1b8ba3b 100644 --- a/Snowflake.Data/Core/Session/ConnectionPoolManager.cs +++ b/Snowflake.Data/Core/Session/ConnectionPoolManager.cs @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ using System; @@ -119,7 +119,7 @@ public bool GetPooling() return true; // in new pool pooling is always enabled by default, disabling only by connection string parameter } - internal SessionPool GetPool(string connectionString, SecureString password) + public SessionPool GetPool(string connectionString, SecureString password) { s_logger.Debug($"ConnectionPoolManager::GetPool"); var poolKey = GetPoolKey(connectionString); @@ -143,7 +143,6 @@ public SessionPool GetPool(string connectionString) return GetPool(connectionString, null); } - // TODO: SNOW-937188 private string GetPoolKey(string connectionString) { return connectionString; diff --git a/Snowflake.Data/Core/Session/IConnectionManager.cs b/Snowflake.Data/Core/Session/IConnectionManager.cs index 247cbe2e6..01cfa3e8c 100644 --- a/Snowflake.Data/Core/Session/IConnectionManager.cs +++ b/Snowflake.Data/Core/Session/IConnectionManager.cs @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ using System.Security; @@ -23,5 +23,6 @@ internal interface IConnectionManager bool SetPooling(bool poolingEnabled); bool GetPooling(); SessionPool GetPool(string connectionString); + SessionPool GetPool(string connectionString, SecureString password); } } diff --git a/Snowflake.Data/Core/Session/SFSession.cs b/Snowflake.Data/Core/Session/SFSession.cs index 1e9785cac..ba078d047 100755 --- a/Snowflake.Data/Core/Session/SFSession.cs +++ b/Snowflake.Data/Core/Session/SFSession.cs @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ using System; @@ -46,8 +46,10 @@ public class SFSession internal SFSessionProperties properties; internal string database; - internal string schema; + internal string role; + internal string warehouse; + internal bool sessionPropertiesChanged = false; internal string serverVersion; @@ -101,6 +103,8 @@ internal void ProcessLoginResponse(LoginResponse authnResponse) masterToken = authnResponse.data.masterToken; database = authnResponse.data.authResponseSessionInfo.databaseName; schema = authnResponse.data.authResponseSessionInfo.schemaName; + role = authnResponse.data.authResponseSessionInfo.roleName; + warehouse = authnResponse.data.authResponseSessionInfo.warehouseName; serverVersion = authnResponse.data.serverVersion; masterValidityInSeconds = authnResponse.data.masterValidityInSeconds; UpdateSessionParameterMap(authnResponse.data.nameValueParameter); @@ -181,13 +185,18 @@ internal SFSession( _maxRetryCount = extractedProperties.maxHttpRetries; _maxRetryTimeout = extractedProperties.retryTimeout; } + catch (SnowflakeDbException e) + { + logger.Error("Unable to initialize session ", e); + throw; + } catch (Exception e) { - logger.Error("Unable to connect", e); + logger.Error("Unable to initialize session ", e); throw new SnowflakeDbException(e, SnowflakeDbException.CONNECTION_FAILURE_SSTATE, SFError.INVALID_CONNECTION_STRING, - "Unable to connect"); + "Unable to initialize session "); } } @@ -460,20 +469,48 @@ internal RequestQueryContext GetQueryContextRequest() return _queryContextCache.GetQueryContextRequest(); } - internal void UpdateDatabaseAndSchema(string databaseName, string schemaName) + internal void UpdateSessionProperties(QueryExecResponseData responseData) { - // with HTAP session metadata removal database/schema - // might be not returened in query result - if (!String.IsNullOrEmpty(databaseName)) - { - this.database = databaseName; - } - if (!String.IsNullOrEmpty(schemaName)) + // with HTAP session metadata removal database/schema might be not returned in query result + UpdateSessionProperty(ref database, responseData.finalDatabaseName); + UpdateSessionProperty(ref schema, responseData.finalSchemaName); + UpdateSessionProperty(ref role, responseData.finalRoleName); + UpdateSessionProperty(ref warehouse, responseData.finalWarehouseName); + } + + internal void UpdateSessionProperty(ref string initialSessionValue, string finalSessionValue) + { + // with HTAP session metadata removal database/schema might be not returned in query result + if (!string.IsNullOrEmpty(finalSessionValue)) { - this.schema = schemaName; + bool quoted = false; + string unquotedFinalValue = UnquoteJson(finalSessionValue, ref quoted); + if (!string.IsNullOrEmpty(initialSessionValue)) + { + quoted |= initialSessionValue.StartsWith("\""); + if (!string.Equals(initialSessionValue, unquotedFinalValue, quoted ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase)) + { + sessionPropertiesChanged = true; + initialSessionValue = unquotedFinalValue; + } + } + else // null session value gets populated and is not treated as a session property change + { + initialSessionValue = unquotedFinalValue; + } } } + private static string UnquoteJson(string value, ref bool unquoted) + { + if (value is null) + return value; + unquoted = value.Length >= 4 && value.StartsWith("\\\"") && value.EndsWith("\\\""); + return unquoted ? value.Replace("\\\"", "\"") : value; + } + + internal bool SessionPropertiesChanged => sessionPropertiesChanged; + internal void startHeartBeatForThisSession() { if (!this.isHeartBeatEnabled) @@ -592,4 +629,3 @@ internal virtual bool IsExpired(TimeSpan timeout, long utcTimeInMillis) internal long GetStartTime() => _startTime; } } - diff --git a/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs b/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs index 0b029a40a..b79c332c7 100644 --- a/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs +++ b/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs @@ -1,7 +1,7 @@ using System; using System.Collections.Generic; -using System.Linq; using Snowflake.Data.Client; +using Snowflake.Data.Core.Authenticator; using Snowflake.Data.Core.Session; using Snowflake.Data.Core.Tools; using Snowflake.Data.Log; @@ -14,7 +14,7 @@ internal class SFSessionHttpClientProperties private static readonly Extractor s_propertiesExtractor = new Extractor(new SFSessionHttpClientProxyProperties.Extractor()); public const int DefaultMaxPoolSize = 10; public const int DefaultMinPoolSize = 2; - public const ChangedSessionBehavior DefaultChangedSession = ChangedSessionBehavior.OriginalPool; + public const ChangedSessionBehavior DefaultChangedSession = ChangedSessionBehavior.Destroy; public static readonly TimeSpan DefaultWaitingForIdleSessionTimeout = TimeSpan.FromSeconds(30); public static readonly TimeSpan DefaultConnectionTimeout = TimeSpan.FromMinutes(5); public static readonly TimeSpan DefaultExpirationTimeout = TimeSpan.FromHours(1); @@ -22,8 +22,7 @@ internal class SFSessionHttpClientProperties public const int DefaultMaxHttpRetries = 7; public static readonly TimeSpan DefaultRetryTimeout = TimeSpan.FromSeconds(300); private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); - private static readonly List s_changedSessionValues = ChangedSessionBehaviorExtensions.StringValues(); - + internal bool validateDefaultParameters; internal bool clientSessionKeepAlive; internal TimeSpan connectionTimeout; @@ -47,7 +46,7 @@ public static SFSessionHttpClientProperties ExtractAndValidate(SFSessionProperti extractedProperties.CheckPropertiesAreValid(); return extractedProperties; } - + private void CheckPropertiesAreValid() { try @@ -126,7 +125,7 @@ private void ValidateHttpRetries() s_logger.Warn($"Max retry count provided is 0. Retry count will be infinite"); } } - + private void ValidateMinMaxPoolSize() { if (_minPoolSize > _maxPoolSize) @@ -143,7 +142,7 @@ private void ValidateWaitingForSessionIdleTimeout() } if (TimeoutHelper.IsZeroLength(_waitingForSessionIdleTimeout)) { - s_logger.Warn("Waiting for idle session timeout is 0. There will be no waiting for idle session"); + s_logger.Warn("Waiting for idle session timeout is 0. There will be no waiting for idle session"); } } @@ -219,14 +218,14 @@ public SFSessionHttpClientProperties ExtractProperties(SFSessionProperties prope _poolingEnabled = extractor.ExtractBooleanWithDefaultValue(SFSessionProperty.POOLINGENABLED) }; } - + private ChangedSessionBehavior ExtractChangedSession( SessionPropertiesWithDefaultValuesExtractor extractor, SFSessionProperty property) => extractor.ExtractPropertyWithDefaultValue( property, - ChangedSessionBehaviorExtensions.From, - s => !string.IsNullOrEmpty(s) && s_changedSessionValues.Contains(s, StringComparer.OrdinalIgnoreCase), + i => (ChangedSessionBehavior)Enum.Parse(typeof(ChangedSessionBehavior), i, true), + s => !string.IsNullOrEmpty(s), b => true ); } diff --git a/Snowflake.Data/Core/Session/SFSessionProperty.cs b/Snowflake.Data/Core/Session/SFSessionProperty.cs index 3896f809a..49a9a0e75 100644 --- a/Snowflake.Data/Core/Session/SFSessionProperty.cs +++ b/Snowflake.Data/Core/Session/SFSessionProperty.cs @@ -102,7 +102,7 @@ internal enum SFSessionProperty MAXPOOLSIZE, [SFSessionPropertyAttr(required = false, defaultValue = "2")] MINPOOLSIZE, - [SFSessionPropertyAttr(required = false, defaultValue = "OriginalPool")] + [SFSessionPropertyAttr(required = false, defaultValue = "Destroy")] CHANGEDSESSION, [SFSessionPropertyAttr(required = false, defaultValue = "30s")] WAITINGFORIDLESESSIONTIMEOUT, @@ -247,12 +247,13 @@ internal static SFSessionProperties ParseConnectionString(string connectionStrin } } - if (password != null) + if (password != null && password.Length > 0) { properties[SFSessionProperty.PASSWORD] = new NetworkCredential(string.Empty, password).Password; } - checkSessionProperties(properties); + ValidateAuthenticator(properties); + CheckSessionProperties(properties); ValidateFileTransferMaxBytesInMemoryProperty(properties); ValidateAccountDomain(properties); @@ -283,6 +284,21 @@ internal static SFSessionProperties ParseConnectionString(string connectionStrin return properties; } + private static void ValidateAuthenticator(SFSessionProperties properties) + { + var knownAuthenticators = new[] { BasicAuthenticator.AUTH_NAME, OktaAuthenticator.AUTH_NAME, OAuthAuthenticator.AUTH_NAME, KeyPairAuthenticator.AUTH_NAME, ExternalBrowserAuthenticator.AUTH_NAME }; + if (properties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator)) + { + authenticator = authenticator.ToLower(); + if (!knownAuthenticators.Contains(authenticator) && !(authenticator.Contains(OktaAuthenticator.AUTH_NAME) && authenticator.StartsWith("https://"))) + { + var error = $"Unknown authenticator: {authenticator}"; + logger.Error(error); + throw new SnowflakeDbException(SFError.UNKNOWN_AUTHENTICATOR, authenticator); + } + } + } + private static string BuildConnectionStringWithoutSecrets(ref string[] keys, ref string[] values) { var count = keys.Length; @@ -368,7 +384,7 @@ private static bool IsAccountRegexMatched(string account) => .Select(regex => Regex.Match(account, regex, RegexOptions.IgnoreCase)) .All(match => match.Success); - private static void checkSessionProperties(SFSessionProperties properties) + private static void CheckSessionProperties(SFSessionProperties properties) { foreach (SFSessionProperty sessionProperty in Enum.GetValues(typeof(SFSessionProperty))) { @@ -376,17 +392,23 @@ private static void checkSessionProperties(SFSessionProperties properties) if (IsRequired(sessionProperty, properties) && !properties.ContainsKey(sessionProperty)) { - SnowflakeDbException e = new SnowflakeDbException(SFError.MISSING_CONNECTION_PROPERTY, - sessionProperty); + SnowflakeDbException e = new SnowflakeDbException(SFError.MISSING_CONNECTION_PROPERTY, sessionProperty); logger.Error("Missing connection property", e); throw e; } + if (IsRequired(sessionProperty, properties) && string.IsNullOrEmpty(properties[sessionProperty])) + { + SnowflakeDbException e = new SnowflakeDbException(SFError.MISSING_CONNECTION_PROPERTY, sessionProperty); + logger.Error("Empty connection property", e); + throw e; + } + // add default value to the map string defaultVal = sessionProperty.GetAttribute().defaultValue; if (defaultVal != null && !properties.ContainsKey(sessionProperty)) { - logger.Debug($"Sesssion property {sessionProperty} set to default value: {defaultVal}"); + logger.Debug($"Session property {sessionProperty} set to default value: {defaultVal}"); properties.Add(sessionProperty, defaultVal); } } @@ -450,6 +472,12 @@ private static bool IsRequired(SFSessionProperty sessionProperty, SFSessionPrope return !authenticatorDefined || !authenticatorsWithoutUsername .Any(auth => auth.Equals(authenticator, StringComparison.OrdinalIgnoreCase)); } + else if (sessionProperty.Equals(SFSessionProperty.TOKEN)) + { + var authenticatorDefined = properties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator); + + return !authenticatorDefined || authenticator.Equals(OAuthAuthenticator.AUTH_NAME); + } else { return sessionProperty.GetAttribute().required; diff --git a/Snowflake.Data/Core/Session/SessionPool.cs b/Snowflake.Data/Core/Session/SessionPool.cs index 8b50144b0..5d38c2e6e 100644 --- a/Snowflake.Data/Core/Session/SessionPool.cs +++ b/Snowflake.Data/Core/Session/SessionPool.cs @@ -9,6 +9,7 @@ using System.Threading; using System.Threading.Tasks; using Snowflake.Data.Client; +using Snowflake.Data.Core.Authenticator; using Snowflake.Data.Core.Tools; using Snowflake.Data.Log; @@ -33,6 +34,8 @@ sealed class SessionPool : IDisposable private readonly ConnectionPoolConfig _poolConfig; private bool _configOverriden = false; + private static readonly InvalidOperationException s_notSupportedInCachePoolException = new InvalidOperationException("Feature not supported in a Connection Cache"); + private SessionPool() { // acquiring a lock not needed because one is already acquired in SnowflakeDbConnectionPool @@ -109,10 +112,10 @@ private static Tuple ExtractConfig(string connecti var extractedProperties = SFSessionHttpClientProperties.ExtractAndValidate(properties); return Tuple.Create(extractedProperties.BuildConnectionPoolConfig(), properties.ConnectionStringWithoutSecrets); } - catch (SnowflakeDbException exception) + catch (Exception exception) { - s_logger.Error("Could not extract pool configuration, using default one", exception); - return Tuple.Create(new ConnectionPoolConfig(), "could not parse connection string"); + s_logger.Error("Failed to extract pool configuration", exception); + throw; } } @@ -422,14 +425,29 @@ internal void ReleaseBusySession(SFSession session) internal bool AddSession(SFSession session, bool ensureMinPoolSize) { + s_logger.Debug("SessionPool::AddSession" + PoolIdentification()); + if (!GetPooling()) return false; + + if (IsMultiplePoolsVersion() && + session.SessionPropertiesChanged && + _poolConfig.ChangedSession == ChangedSessionBehavior.Destroy) + { + s_logger.Debug($"Session returning to pool was changed. Destroying the session: {session.sessionId}."); + session.SetPooling(false); + } + if (!session.GetPooling()) { ReleaseBusySession(session); + if (ensureMinPoolSize) + { + ScheduleNewIdleSessions(ConnectionString, Password, RegisterSessionCreationsWhenReturningSessionToPool()); + } return false; } - s_logger.Debug("SessionPool::AddSession" + PoolIdentification()); + var result = ReturnSessionToPool(session, ensureMinPoolSize); var wasSessionReturnedToPool = result.Item1; var sessionCreationTokens = result.Item2; @@ -529,9 +547,28 @@ public void SetMaxPoolSize(int size) _configOverriden = true; } - public int GetMaxPoolSize() + public int GetMaxPoolSize() => _poolConfig.MaxPoolSize; + + public int GetMinPoolSize() + { + return IsMultiplePoolsVersion() + ? _poolConfig.MinPoolSize + : throw s_notSupportedInCachePoolException; + } + + public ChangedSessionBehavior GetChangedSession() => + IsMultiplePoolsVersion() + ? _poolConfig.ChangedSession + : throw s_notSupportedInCachePoolException; + + public long GetWaitForIdleSessionTimeout() => + IsMultiplePoolsVersion() + ? (long)_poolConfig.WaitingForIdleSessionTimeout.TotalSeconds + : throw s_notSupportedInCachePoolException; + + public long GetConnectionTimeout() { - return _poolConfig.MaxPoolSize; + return TimeoutHelper.IsInfinite(_poolConfig.ConnectionTimeout) ? -1 : (long)_poolConfig.ConnectionTimeout.TotalSeconds; } public void SetTimeout(long seconds) diff --git a/Snowflake.Data/Core/Session/SessionPoolState.cs b/Snowflake.Data/Core/Session/SessionPoolState.cs index 4f29858c3..e548c186e 100644 --- a/Snowflake.Data/Core/Session/SessionPoolState.cs +++ b/Snowflake.Data/Core/Session/SessionPoolState.cs @@ -1,6 +1,6 @@ namespace Snowflake.Data.Core.Session { - public class SessionPoolState + internal class SessionPoolState { private readonly int _idleSessionsCount; private readonly int _busySessionsCount;