From e727d1a2310fbb54478d824fd4ff3f9099c713b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Hofman?= Date: Fri, 26 Apr 2024 22:27:53 +0200 Subject: [PATCH] SNOW-937188 pool mode where a session gets destroyed on pooling when some of its settings gets changed and no longer match the pool initial setup --- .../ConnectionChangedSessionIT.cs | 160 ++++ .../EasyLoggingStarterTest.cs | 0 .../UnitTests/SFSessionTest.cs | 58 +- Snowflake.Data/Core/ArrowResultSet.cs | 4 +- .../Core/SFMultiStatementsResultSet.cs | 2 +- Snowflake.Data/Core/SFResultSet.cs | 752 +++++++++--------- .../Core/Session/ChangedSessionBehavior.cs | 2 +- Snowflake.Data/Core/Session/SFSession.cs | 35 +- Snowflake.Data/Core/Session/SessionPool.cs | 10 + 9 files changed, 619 insertions(+), 404 deletions(-) create mode 100644 Snowflake.Data.Tests/IntegrationTests/ConnectionChangedSessionIT.cs rename Snowflake.Data.Tests/UnitTests/{Session => Logger}/EasyLoggingStarterTest.cs (100%) diff --git a/Snowflake.Data.Tests/IntegrationTests/ConnectionChangedSessionIT.cs b/Snowflake.Data.Tests/IntegrationTests/ConnectionChangedSessionIT.cs new file mode 100644 index 000000000..0919223de --- /dev/null +++ b/Snowflake.Data.Tests/IntegrationTests/ConnectionChangedSessionIT.cs @@ -0,0 +1,160 @@ +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core; +using Snowflake.Data.Core.Session; +using Snowflake.Data.Tests.Util; + +namespace Snowflake.Data.Tests.IntegrationTests +{ + [TestFixture] + [NonParallelizable] + public class ConnectionChangedSessionIT : SFBaseTest + { + private readonly QueryExecResponseData _queryExecResponseChangedRole = new() + { + finalDatabaseName = TestEnvironment.TestConfig.database, + finalSchemaName = TestEnvironment.TestConfig.schema, + finalRoleName = "role change", + finalWarehouseName = TestEnvironment.TestConfig.warehouse + }; + + private readonly QueryExecResponseData _queryExecResponseChangedDatabase = new() + { + finalDatabaseName = "database changed", + finalSchemaName = TestEnvironment.TestConfig.schema, + finalRoleName = TestEnvironment.TestConfig.role, + finalWarehouseName = TestEnvironment.TestConfig.warehouse + }; + + private readonly QueryExecResponseData _queryExecResponseChangedSchema = new() + { + finalDatabaseName = TestEnvironment.TestConfig.database, + finalSchemaName = "schema changed", + finalRoleName = TestEnvironment.TestConfig.role, + finalWarehouseName = TestEnvironment.TestConfig.warehouse + }; + + private readonly QueryExecResponseData _queryExecResponseChangedWarehouse = new() + { + finalDatabaseName = TestEnvironment.TestConfig.database, + finalSchemaName = TestEnvironment.TestConfig.schema, + finalRoleName = TestEnvironment.TestConfig.role, + finalWarehouseName = "warehouse changed" + }; + + private static PoolConfig s_previousPoolConfigRestorer; + + [OneTimeSetUp] + public static void BeforeAllTests() + { + s_previousPoolConfigRestorer = new PoolConfig(); + SnowflakeDbConnectionPool.SetConnectionPoolVersion(ConnectionPoolType.MultipleConnectionPool); + } + + [SetUp] + public new void BeforeTest() + { + SnowflakeDbConnectionPool.ClearAllPools(); + } + + [TearDown] + public new void AfterTest() + { + SnowflakeDbConnectionPool.ClearAllPools(); + } + + [OneTimeTearDown] + public static void AfterAllTests() + { + s_previousPoolConfigRestorer.Reset(); + } + + [Test] + public void TestPoolDestroysConnectionWhenChangedSessionProperties() + { + var connectionString = ConnectionString + "application=Destroy;ChangedSession=Destroy;minPoolSize=0;maxPoolSize=3"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + connection.SfSession.UpdateSessionProperties(_queryExecResponseChangedDatabase); + connection.Close(); + + Assert.AreEqual(0, pool.GetCurrentPoolSize()); + } + + [Test] + public void TestPoolingWhenSessionPropertiesUnchanged() + { + var connectionString = ConnectionString + "application=NoSessionChanges;ChangedSession=Destroy;minPoolSize=0;maxPoolSize=3"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + connection.Close(); + + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + } + + [Test] + public void TestPoolingWhenConnectionPropertiesChangedForOriginalPoolMode() + { + var connectionString = ConnectionString + "application=OriginalPoolMode;ChangedSession=OriginalPool;minPoolSize=0;maxPoolSize=3"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + connection.SfSession.UpdateSessionProperties(_queryExecResponseChangedWarehouse); + var sessionId = connection.SfSession.sessionId; + connection.Close(); + + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + connection.Close(); + + var connection2 = new SnowflakeDbConnection(connectionString); + connection2.Open(); + Assert.AreEqual(sessionId, connection2.SfSession.sessionId); + connection2.Close(); + } + + [Test] + public void TestPoolingWhenConnectionPropertiesChangedForDefaultPoolMode() + { + var connectionString = ConnectionString + "application=DefaultPoolMode;minPoolSize=0;maxPoolSize=3"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + connection.SfSession.UpdateSessionProperties(_queryExecResponseChangedRole); + var sessionId = connection.SfSession.sessionId; + connection.Close(); + + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + + var connection2 = new SnowflakeDbConnection(connectionString); + connection2.Open(); + Assert.AreEqual(sessionId, connection2.SfSession.sessionId); + connection2.Close(); + } + + [Test] + public void TestPoolDestroysAndRecreatesConnection() + { + var connectionString = ConnectionString + "application=DestroyRecreateSession;ChangedSession=Destroy;minPoolSize=1;maxPoolSize=3"; + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + var sessionId = connection.SfSession.sessionId; + connection.SfSession.UpdateSessionProperties(_queryExecResponseChangedSchema); + connection.Close(); + + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + + var connection2 = new SnowflakeDbConnection(connectionString); + connection2.Open(); + Assert.AreNotEqual(sessionId, connection2.SfSession.sessionId); + connection2.Close(); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Session/EasyLoggingStarterTest.cs b/Snowflake.Data.Tests/UnitTests/Logger/EasyLoggingStarterTest.cs similarity index 100% rename from Snowflake.Data.Tests/UnitTests/Session/EasyLoggingStarterTest.cs rename to Snowflake.Data.Tests/UnitTests/Logger/EasyLoggingStarterTest.cs diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs index b9530b83b..39f28fb04 100644 --- a/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs @@ -1,15 +1,12 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ -using Snowflake.Data.Configuration; -using Snowflake.Data.Log; +using Snowflake.Data.Core; +using NUnit.Framework; namespace Snowflake.Data.Tests.UnitTests { - using Snowflake.Data.Core; - using NUnit.Framework; - [TestFixture] class SFSessionTest { @@ -20,26 +17,61 @@ public void TestSessionGoneWhenClose() Mock.MockCloseSessionGone restRequester = new Mock.MockCloseSessionGone(); SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); - sfSession.close(); // no exception is raised. + Assert.DoesNotThrow(() => sfSession.close()); } [Test] - public void TestUpdateDatabaseAndSchema() + public void TestUpdateSessionProperties() { + // arrange string databaseName = "DB_TEST"; string schemaName = "SC_TEST"; - + string warehouseName = "WH_TEST"; + string roleName = "ROLE_TEST"; + QueryExecResponseData queryExecResponseData = new QueryExecResponseData + { + finalSchemaName = schemaName, + finalDatabaseName = databaseName, + finalRoleName = roleName, + finalWarehouseName = warehouseName + }; + + // act SFSession sfSession = new SFSession("account=test;user=test;password=test", null); - sfSession.UpdateDatabaseAndSchema(databaseName, schemaName); + sfSession.UpdateSessionProperties(queryExecResponseData); + // assert Assert.AreEqual(databaseName, sfSession.database); Assert.AreEqual(schemaName, sfSession.schema); + Assert.AreEqual(warehouseName, sfSession.warehouse); + Assert.AreEqual(roleName, sfSession.role); + } + + [Test] + public void TestSkipUpdateSessionPropertiesWhenPropertiesMissing() + { + // arrange + string databaseName = "DB_TEST"; + string schemaName = "SC_TEST"; + string warehouseName = "WH_TEST"; + string roleName = "ROLE_TEST"; + SFSession sfSession = new SFSession("account=test;user=test;password=test", null); + sfSession.database = databaseName; + sfSession.warehouse = warehouseName; + sfSession.role = roleName; + sfSession.schema = schemaName; + // act + QueryExecResponseData queryExecResponseWithoutData = new QueryExecResponseData(); + sfSession.UpdateSessionProperties(queryExecResponseWithoutData); + + // assert // when database or schema name is missing in the response, // the cached value should keep unchanged - sfSession.UpdateDatabaseAndSchema(null, null); Assert.AreEqual(databaseName, sfSession.database); Assert.AreEqual(schemaName, sfSession.schema); + Assert.AreEqual(warehouseName, sfSession.warehouse); + Assert.AreEqual(roleName, sfSession.role); } [Test] @@ -54,10 +86,10 @@ public void TestThatConfiguresEasyLogging(string configPath) var connectionString = configPath == null ? simpleConnectionString : $"{simpleConnectionString}client_config_file={configPath};"; - + // act new SFSession(connectionString, null, easyLoggingStarter.Object); - + // assert easyLoggingStarter.Verify(starter => starter.Init(configPath)); } diff --git a/Snowflake.Data/Core/ArrowResultSet.cs b/Snowflake.Data/Core/ArrowResultSet.cs index 31e0eccca..56a636c4e 100755 --- a/Snowflake.Data/Core/ArrowResultSet.cs +++ b/Snowflake.Data/Core/ArrowResultSet.cs @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ using System; @@ -398,7 +398,7 @@ internal override string GetString(int ordinal) private void UpdateSessionStatus(QueryExecResponseData responseData) { SFSession session = this.sfStatement.SfSession; - session.UpdateDatabaseAndSchema(responseData.finalDatabaseName, responseData.finalSchemaName); + session.UpdateSessionProperties(responseData); session.UpdateSessionParameterMap(responseData.parameters); } diff --git a/Snowflake.Data/Core/SFMultiStatementsResultSet.cs b/Snowflake.Data/Core/SFMultiStatementsResultSet.cs index c811deb8b..18eb4f650 100644 --- a/Snowflake.Data/Core/SFMultiStatementsResultSet.cs +++ b/Snowflake.Data/Core/SFMultiStatementsResultSet.cs @@ -112,7 +112,7 @@ internal override bool Rewind() private void updateSessionStatus(QueryExecResponseData responseData) { SFSession session = this.sfStatement.SfSession; - session.UpdateDatabaseAndSchema(responseData.finalDatabaseName, responseData.finalSchemaName); + session.UpdateSessionProperties(responseData); session.UpdateSessionParameterMap(responseData.parameters); session.UpdateQueryContextCache(responseData.QueryContext); } diff --git a/Snowflake.Data/Core/SFResultSet.cs b/Snowflake.Data/Core/SFResultSet.cs index 03e1794c9..a7586f2c3 100755 --- a/Snowflake.Data/Core/SFResultSet.cs +++ b/Snowflake.Data/Core/SFResultSet.cs @@ -1,376 +1,376 @@ -/* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. - */ - -using System; -using System.Threading; -using System.Threading.Tasks; -using Snowflake.Data.Log; -using Snowflake.Data.Client; -using System.Collections.Generic; -using System.Diagnostics; - -namespace Snowflake.Data.Core -{ - class SFResultSet : SFBaseResultSet - { - internal override ResultFormat ResultFormat => ResultFormat.JSON; - - private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); - - private readonly int _totalChunkCount; - - private readonly IChunkDownloader _chunkDownloader; - - private BaseResultChunk _currentChunk; - - public SFResultSet(QueryExecResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base() - { - try - { - columnCount = responseData.rowType?.Count ?? 0; - - this.sfStatement = sfStatement; - UpdateSessionStatus(responseData); - - if (responseData.chunks != null) - { - // counting the first chunk - _totalChunkCount = responseData.chunks.Count; - _chunkDownloader = ChunkDownloaderFactory.GetDownloader(responseData, this, cancellationToken); - } - - _currentChunk = responseData.rowSet != null ? new SFResultChunk(responseData.rowSet) : null; - responseData.rowSet = null; - - sfResultSetMetaData = responseData.rowType != null ? new SFResultSetMetaData(responseData, this.sfStatement.SfSession) : null; - - isClosed = false; - - queryId = responseData.queryId; - } - catch(System.Exception ex) - { - s_logger.Error("Result set error queryId="+responseData.queryId, ex); - throw; - } - } - - public enum PutGetResponseRowTypeInfo { - SourceFileName = 0, - DestinationFileName = 1, - SourceFileSize = 2, - DestinationFileSize = 3, - SourceCompressionType = 4, - DestinationCompressionType = 5, - ResultStatus = 6, - ErrorDetails = 7 - } - - public void InitializePutGetRowType(List rowType) - { - foreach (PutGetResponseRowTypeInfo t in System.Enum.GetValues(typeof(PutGetResponseRowTypeInfo))) - { - rowType.Add(new ExecResponseRowType() - { - name = t.ToString(), - type = "text" - }); - } - } - - public SFResultSet(PutGetResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base() - { - responseData.rowType = new List(); - InitializePutGetRowType(responseData.rowType); - - columnCount = responseData.rowType.Count; - - this.sfStatement = sfStatement; - - _currentChunk = new SFResultChunk(responseData.rowSet); - responseData.rowSet = null; - - sfResultSetMetaData = new SFResultSetMetaData(responseData); - - isClosed = false; - - queryId = responseData.queryId; - } - - internal void ResetChunkInfo(BaseResultChunk nextChunk) - { - s_logger.Debug($"Received chunk #{nextChunk.ChunkIndex + 1} of {_totalChunkCount}"); - _currentChunk.RowSet = null; - _currentChunk = nextChunk; - } - - internal override async Task NextAsync() - { - ThrowIfClosed(); - - if (_currentChunk.Next()) - return true; - - if (_chunkDownloader != null) - { - // GetNextChunk could be blocked if download result is not done yet. - // So put this piece of code in a seperate task - s_logger.Debug($"Get next chunk from chunk downloader, chunk: {_currentChunk.ChunkIndex + 1}/{_totalChunkCount}" + - $" rows: {_currentChunk.RowCount}, size compressed: {_currentChunk.CompressedSize}," + - $" size uncompressed: {_currentChunk.UncompressedSize}"); - BaseResultChunk nextChunk = await _chunkDownloader.GetNextChunkAsync().ConfigureAwait(false); - if (nextChunk != null) - { - ResetChunkInfo(nextChunk); - return _currentChunk.Next(); - } - } - - return false; - } - - internal override bool Next() - { - ThrowIfClosed(); - - if (_currentChunk.Next()) - return true; - - if (_chunkDownloader != null) - { - s_logger.Debug($"Get next chunk from chunk downloader, chunk: {_currentChunk.ChunkIndex + 1}/{_totalChunkCount}" + - $" rows: {_currentChunk.RowCount}, size compressed: {_currentChunk.CompressedSize}," + - $" size uncompressed: {_currentChunk.UncompressedSize}"); - BaseResultChunk nextChunk = Task.Run(async() => await (_chunkDownloader.GetNextChunkAsync()).ConfigureAwait(false)).Result; - if (nextChunk != null) - { - ResetChunkInfo(nextChunk); - return _currentChunk.Next(); - } - } - return false; - } - - internal override bool NextResult() - { - return false; - } - - internal override async Task NextResultAsync(CancellationToken cancellationToken) - { - return await Task.FromResult(false); - } - - internal override bool HasRows() - { - ThrowIfClosed(); - - return _currentChunk.RowCount > 0 || _totalChunkCount > 0; - } - - /// - /// Move cursor back one row. - /// - /// True if it works, false otherwise. - internal override bool Rewind() - { - ThrowIfClosed(); - - return _currentChunk.Rewind(); - } - - internal UTF8Buffer GetObjectInternal(int ordinal) - { - ThrowIfClosed(); - ThrowIfOutOfBounds(ordinal); - - return _currentChunk.ExtractCell(ordinal); - } - - private void UpdateSessionStatus(QueryExecResponseData responseData) - { - SFSession session = this.sfStatement.SfSession; - session.UpdateDatabaseAndSchema(responseData.finalDatabaseName, responseData.finalSchemaName); - session.UpdateSessionParameterMap(responseData.parameters); - session.UpdateQueryContextCache(responseData.QueryContext); - } - - internal override bool IsDBNull(int ordinal) - { - return (null == GetObjectInternal(ordinal)); - } - - internal override bool GetBoolean(int ordinal) - { - return GetValue(ordinal); - } - - internal override byte GetByte(int ordinal) - { - return GetValue(ordinal); - } - - internal override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length) - { - return ReadSubset(ordinal, dataOffset, buffer, bufferOffset, length); - } - - internal override char GetChar(int ordinal) - { - string val = GetString(ordinal); - return val[0]; - } - - internal override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) - { - return ReadSubset(ordinal, dataOffset, buffer, bufferOffset, length); - } - - internal override DateTime GetDateTime(int ordinal) - { - return GetValue(ordinal); - } - - internal override TimeSpan GetTimeSpan(int ordinal) - { - return GetValue(ordinal); - } - - internal override decimal GetDecimal(int ordinal) - { - return GetValue(ordinal); - } - - internal override double GetDouble(int ordinal) - { - return GetValue(ordinal); - } - - internal override float GetFloat(int ordinal) - { - return GetValue(ordinal); - } - - internal override Guid GetGuid(int ordinal) - { - return GetValue(ordinal); - } - - internal override short GetInt16(int ordinal) - { - return GetValue(ordinal); - } - - internal override int GetInt32(int ordinal) - { - return GetValue(ordinal); - } - - internal override long GetInt64(int ordinal) - { - return GetValue(ordinal); - } - - internal override string GetString(int ordinal) - { - ThrowIfOutOfBounds(ordinal); - - var type = sfResultSetMetaData.GetColumnTypeByIndex(ordinal); - switch (type) - { - case SFDataType.DATE: - var val = GetValue(ordinal); - if (val == DBNull.Value) - return null; - return SFDataConverter.toDateString((DateTime)val, sfResultSetMetaData.dateOutputFormat); - - default: - return GetObjectInternal(ordinal).SafeToString(); - } - } - - internal override object GetValue(int ordinal) - { - UTF8Buffer val = GetObjectInternal(ordinal); - var types = sfResultSetMetaData.GetTypesByIndex(ordinal); - return SFDataConverter.ConvertToCSharpVal(val, types.Item1, types.Item2); - } - - private T GetValue(int ordinal) - { - UTF8Buffer val = GetObjectInternal(ordinal); - var types = sfResultSetMetaData.GetTypesByIndex(ordinal); - return (T)SFDataConverter.ConvertToCSharpVal(val, types.Item1, typeof(T)); - } - - // - // Summary: - // Reads a subset of data starting at location indicated by dataOffset into the buffer, - // starting at the location indicated by bufferOffset. - // - // Parameters: - // ordinal: - // The zero-based column ordinal. - // - // dataOffset: - // The index within the data from which to begin the read operation. - // - // buffer: - // The buffer into which to copy the data. - // - // bufferOffset: - // The index with the buffer to which the data will be copied. - // - // length: - // The maximum number of elements to read. - // - // Returns: - // The actual number of elements read. - private long ReadSubset(int ordinal, long dataOffset, T[] buffer, int bufferOffset, int length) where T : struct - { - if (dataOffset < 0) - { - throw new ArgumentOutOfRangeException("dataOffset", "Non negative number is required."); - } - - if (bufferOffset < 0) - { - throw new ArgumentOutOfRangeException("bufferOffset", "Non negative number is required."); - } - - if ((null != buffer) && (bufferOffset > buffer.Length)) - { - throw new System.ArgumentException("Destination buffer is not long enough. " + - "Check the buffer offset, length, and the buffer's lower bounds.", "buffer"); - } - - T[] data = GetValue(ordinal); - - // https://docs.microsoft.com/en-us/dotnet/api/system.data.idatarecord.getbytes?view=net-5.0#remarks - // If you pass a buffer that is null, GetBytes returns the length of the row in bytes. - // https://docs.microsoft.com/en-us/dotnet/api/system.data.idatarecord.getchars?view=net-5.0#remarks - // If you pass a buffer that is null, GetChars returns the length of the field in characters. - if (null == buffer) - { - return data.Length; - } - - if (dataOffset > data.Length) - { - throw new System.ArgumentException("Source data is not long enough. " + - "Check the data offset, length, and the data's lower bounds." ,"dataOffset"); - } - else - { - // How much data is available after the offset - long dataLength = data.Length - dataOffset; - // How much data to read - long elementsRead = Math.Min(length, dataLength); - Array.Copy(data, dataOffset, buffer, bufferOffset, elementsRead); - - return elementsRead; - } - } - } -} +/* + * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Threading; +using System.Threading.Tasks; +using Snowflake.Data.Log; +using Snowflake.Data.Client; +using System.Collections.Generic; +using System.Diagnostics; + +namespace Snowflake.Data.Core +{ + class SFResultSet : SFBaseResultSet + { + internal override ResultFormat ResultFormat => ResultFormat.JSON; + + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private readonly int _totalChunkCount; + + private readonly IChunkDownloader _chunkDownloader; + + private BaseResultChunk _currentChunk; + + public SFResultSet(QueryExecResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base() + { + try + { + columnCount = responseData.rowType?.Count ?? 0; + + this.sfStatement = sfStatement; + UpdateSessionStatus(responseData); + + if (responseData.chunks != null) + { + // counting the first chunk + _totalChunkCount = responseData.chunks.Count; + _chunkDownloader = ChunkDownloaderFactory.GetDownloader(responseData, this, cancellationToken); + } + + _currentChunk = responseData.rowSet != null ? new SFResultChunk(responseData.rowSet) : null; + responseData.rowSet = null; + + sfResultSetMetaData = responseData.rowType != null ? new SFResultSetMetaData(responseData, this.sfStatement.SfSession) : null; + + isClosed = false; + + queryId = responseData.queryId; + } + catch(System.Exception ex) + { + s_logger.Error("Result set error queryId="+responseData.queryId, ex); + throw; + } + } + + public enum PutGetResponseRowTypeInfo { + SourceFileName = 0, + DestinationFileName = 1, + SourceFileSize = 2, + DestinationFileSize = 3, + SourceCompressionType = 4, + DestinationCompressionType = 5, + ResultStatus = 6, + ErrorDetails = 7 + } + + public void InitializePutGetRowType(List rowType) + { + foreach (PutGetResponseRowTypeInfo t in System.Enum.GetValues(typeof(PutGetResponseRowTypeInfo))) + { + rowType.Add(new ExecResponseRowType() + { + name = t.ToString(), + type = "text" + }); + } + } + + public SFResultSet(PutGetResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base() + { + responseData.rowType = new List(); + InitializePutGetRowType(responseData.rowType); + + columnCount = responseData.rowType.Count; + + this.sfStatement = sfStatement; + + _currentChunk = new SFResultChunk(responseData.rowSet); + responseData.rowSet = null; + + sfResultSetMetaData = new SFResultSetMetaData(responseData); + + isClosed = false; + + queryId = responseData.queryId; + } + + internal void ResetChunkInfo(BaseResultChunk nextChunk) + { + s_logger.Debug($"Received chunk #{nextChunk.ChunkIndex + 1} of {_totalChunkCount}"); + _currentChunk.RowSet = null; + _currentChunk = nextChunk; + } + + internal override async Task NextAsync() + { + ThrowIfClosed(); + + if (_currentChunk.Next()) + return true; + + if (_chunkDownloader != null) + { + // GetNextChunk could be blocked if download result is not done yet. + // So put this piece of code in a seperate task + s_logger.Debug($"Get next chunk from chunk downloader, chunk: {_currentChunk.ChunkIndex + 1}/{_totalChunkCount}" + + $" rows: {_currentChunk.RowCount}, size compressed: {_currentChunk.CompressedSize}," + + $" size uncompressed: {_currentChunk.UncompressedSize}"); + BaseResultChunk nextChunk = await _chunkDownloader.GetNextChunkAsync().ConfigureAwait(false); + if (nextChunk != null) + { + ResetChunkInfo(nextChunk); + return _currentChunk.Next(); + } + } + + return false; + } + + internal override bool Next() + { + ThrowIfClosed(); + + if (_currentChunk.Next()) + return true; + + if (_chunkDownloader != null) + { + s_logger.Debug($"Get next chunk from chunk downloader, chunk: {_currentChunk.ChunkIndex + 1}/{_totalChunkCount}" + + $" rows: {_currentChunk.RowCount}, size compressed: {_currentChunk.CompressedSize}," + + $" size uncompressed: {_currentChunk.UncompressedSize}"); + BaseResultChunk nextChunk = Task.Run(async() => await (_chunkDownloader.GetNextChunkAsync()).ConfigureAwait(false)).Result; + if (nextChunk != null) + { + ResetChunkInfo(nextChunk); + return _currentChunk.Next(); + } + } + return false; + } + + internal override bool NextResult() + { + return false; + } + + internal override async Task NextResultAsync(CancellationToken cancellationToken) + { + return await Task.FromResult(false); + } + + internal override bool HasRows() + { + ThrowIfClosed(); + + return _currentChunk.RowCount > 0 || _totalChunkCount > 0; + } + + /// + /// Move cursor back one row. + /// + /// True if it works, false otherwise. + internal override bool Rewind() + { + ThrowIfClosed(); + + return _currentChunk.Rewind(); + } + + internal UTF8Buffer GetObjectInternal(int ordinal) + { + ThrowIfClosed(); + ThrowIfOutOfBounds(ordinal); + + return _currentChunk.ExtractCell(ordinal); + } + + private void UpdateSessionStatus(QueryExecResponseData responseData) + { + SFSession session = this.sfStatement.SfSession; + session.UpdateSessionProperties(responseData); + session.UpdateSessionParameterMap(responseData.parameters); + session.UpdateQueryContextCache(responseData.QueryContext); + } + + internal override bool IsDBNull(int ordinal) + { + return (null == GetObjectInternal(ordinal)); + } + + internal override bool GetBoolean(int ordinal) + { + return GetValue(ordinal); + } + + internal override byte GetByte(int ordinal) + { + return GetValue(ordinal); + } + + internal override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length) + { + return ReadSubset(ordinal, dataOffset, buffer, bufferOffset, length); + } + + internal override char GetChar(int ordinal) + { + string val = GetString(ordinal); + return val[0]; + } + + internal override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) + { + return ReadSubset(ordinal, dataOffset, buffer, bufferOffset, length); + } + + internal override DateTime GetDateTime(int ordinal) + { + return GetValue(ordinal); + } + + internal override TimeSpan GetTimeSpan(int ordinal) + { + return GetValue(ordinal); + } + + internal override decimal GetDecimal(int ordinal) + { + return GetValue(ordinal); + } + + internal override double GetDouble(int ordinal) + { + return GetValue(ordinal); + } + + internal override float GetFloat(int ordinal) + { + return GetValue(ordinal); + } + + internal override Guid GetGuid(int ordinal) + { + return GetValue(ordinal); + } + + internal override short GetInt16(int ordinal) + { + return GetValue(ordinal); + } + + internal override int GetInt32(int ordinal) + { + return GetValue(ordinal); + } + + internal override long GetInt64(int ordinal) + { + return GetValue(ordinal); + } + + internal override string GetString(int ordinal) + { + ThrowIfOutOfBounds(ordinal); + + var type = sfResultSetMetaData.GetColumnTypeByIndex(ordinal); + switch (type) + { + case SFDataType.DATE: + var val = GetValue(ordinal); + if (val == DBNull.Value) + return null; + return SFDataConverter.toDateString((DateTime)val, sfResultSetMetaData.dateOutputFormat); + + default: + return GetObjectInternal(ordinal).SafeToString(); + } + } + + internal override object GetValue(int ordinal) + { + UTF8Buffer val = GetObjectInternal(ordinal); + var types = sfResultSetMetaData.GetTypesByIndex(ordinal); + return SFDataConverter.ConvertToCSharpVal(val, types.Item1, types.Item2); + } + + private T GetValue(int ordinal) + { + UTF8Buffer val = GetObjectInternal(ordinal); + var types = sfResultSetMetaData.GetTypesByIndex(ordinal); + return (T)SFDataConverter.ConvertToCSharpVal(val, types.Item1, typeof(T)); + } + + // + // Summary: + // Reads a subset of data starting at location indicated by dataOffset into the buffer, + // starting at the location indicated by bufferOffset. + // + // Parameters: + // ordinal: + // The zero-based column ordinal. + // + // dataOffset: + // The index within the data from which to begin the read operation. + // + // buffer: + // The buffer into which to copy the data. + // + // bufferOffset: + // The index with the buffer to which the data will be copied. + // + // length: + // The maximum number of elements to read. + // + // Returns: + // The actual number of elements read. + private long ReadSubset(int ordinal, long dataOffset, T[] buffer, int bufferOffset, int length) where T : struct + { + if (dataOffset < 0) + { + throw new ArgumentOutOfRangeException("dataOffset", "Non negative number is required."); + } + + if (bufferOffset < 0) + { + throw new ArgumentOutOfRangeException("bufferOffset", "Non negative number is required."); + } + + if ((null != buffer) && (bufferOffset > buffer.Length)) + { + throw new System.ArgumentException("Destination buffer is not long enough. " + + "Check the buffer offset, length, and the buffer's lower bounds.", "buffer"); + } + + T[] data = GetValue(ordinal); + + // https://docs.microsoft.com/en-us/dotnet/api/system.data.idatarecord.getbytes?view=net-5.0#remarks + // If you pass a buffer that is null, GetBytes returns the length of the row in bytes. + // https://docs.microsoft.com/en-us/dotnet/api/system.data.idatarecord.getchars?view=net-5.0#remarks + // If you pass a buffer that is null, GetChars returns the length of the field in characters. + if (null == buffer) + { + return data.Length; + } + + if (dataOffset > data.Length) + { + throw new System.ArgumentException("Source data is not long enough. " + + "Check the data offset, length, and the data's lower bounds." ,"dataOffset"); + } + else + { + // How much data is available after the offset + long dataLength = data.Length - dataOffset; + // How much data to read + long elementsRead = Math.Min(length, dataLength); + Array.Copy(data, dataOffset, buffer, bufferOffset, elementsRead); + + return elementsRead; + } + } + } +} diff --git a/Snowflake.Data/Core/Session/ChangedSessionBehavior.cs b/Snowflake.Data/Core/Session/ChangedSessionBehavior.cs index caf7ded2a..50cf9893c 100644 --- a/Snowflake.Data/Core/Session/ChangedSessionBehavior.cs +++ b/Snowflake.Data/Core/Session/ChangedSessionBehavior.cs @@ -20,7 +20,7 @@ public static List StringValues() { return Enum.GetValues(typeof(ChangedSessionBehavior)) .Cast() - .Where(e => e == ChangedSessionBehavior.OriginalPool) // currently we support only OriginalPool case; TODO: SNOW-937188 + .Where(e => e != ChangedSessionBehavior.ChangePool) // no support yet for ChangedSessionBehavior.ChangePool case .Select(b => b.ToString()) .ToList(); } diff --git a/Snowflake.Data/Core/Session/SFSession.cs b/Snowflake.Data/Core/Session/SFSession.cs index 1e9785cac..fc6aedef8 100755 --- a/Snowflake.Data/Core/Session/SFSession.cs +++ b/Snowflake.Data/Core/Session/SFSession.cs @@ -46,8 +46,10 @@ public class SFSession internal SFSessionProperties properties; internal string database; - internal string schema; + internal string role; + internal string warehouse; + internal bool sessionPropertiesChanged = false; internal string serverVersion; @@ -101,6 +103,8 @@ internal void ProcessLoginResponse(LoginResponse authnResponse) masterToken = authnResponse.data.masterToken; database = authnResponse.data.authResponseSessionInfo.databaseName; schema = authnResponse.data.authResponseSessionInfo.schemaName; + role = authnResponse.data.authResponseSessionInfo.roleName; + warehouse = authnResponse.data.authResponseSessionInfo.warehouseName; serverVersion = authnResponse.data.serverVersion; masterValidityInSeconds = authnResponse.data.masterValidityInSeconds; UpdateSessionParameterMap(authnResponse.data.nameValueParameter); @@ -460,20 +464,30 @@ internal RequestQueryContext GetQueryContextRequest() return _queryContextCache.GetQueryContextRequest(); } - internal void UpdateDatabaseAndSchema(string databaseName, string schemaName) + internal void UpdateSessionProperties(QueryExecResponseData responseData) { - // with HTAP session metadata removal database/schema - // might be not returened in query result - if (!String.IsNullOrEmpty(databaseName)) - { - this.database = databaseName; - } - if (!String.IsNullOrEmpty(schemaName)) + // with HTAP session metadata removal database/schema might be not returned in query result + UpdateSessionProperty(ref database, responseData.finalDatabaseName); + UpdateSessionProperty(ref schema, responseData.finalSchemaName); + UpdateSessionProperty(ref role, responseData.finalRoleName); + UpdateSessionProperty(ref warehouse, responseData.finalWarehouseName); + } + + private void UpdateSessionProperty(ref string initialSessionValue, string finalSessionValue) + { + // with HTAP session metadata removal database/schema might be not returned in query result + if (!String.IsNullOrEmpty(finalSessionValue)) { - this.schema = schemaName; + if (!String.IsNullOrEmpty(initialSessionValue) && initialSessionValue != finalSessionValue) + { + sessionPropertiesChanged = true; + } + initialSessionValue = finalSessionValue; } } + internal bool SessionPropertiesChanged => sessionPropertiesChanged; + internal void startHeartBeatForThisSession() { if (!this.isHeartBeatEnabled) @@ -592,4 +606,3 @@ internal virtual bool IsExpired(TimeSpan timeout, long utcTimeInMillis) internal long GetStartTime() => _startTime; } } - diff --git a/Snowflake.Data/Core/Session/SessionPool.cs b/Snowflake.Data/Core/Session/SessionPool.cs index 401e435f6..5495bb562 100644 --- a/Snowflake.Data/Core/Session/SessionPool.cs +++ b/Snowflake.Data/Core/Session/SessionPool.cs @@ -413,6 +413,16 @@ internal bool AddSession(SFSession session, bool ensureMinPoolSize) ReleaseBusySession(session); return false; } + + if (session.SessionPropertiesChanged && _poolConfig.ChangedSession == ChangedSessionBehavior.Destroy) + { + session.SetPooling(false); + Task.Run(() => session.CloseAsync(CancellationToken.None)).ConfigureAwait(false); + ReleaseBusySession(session); + ScheduleNewIdleSessions(ConnectionString, Password, RegisterSessionCreationsWhenReturningSessionToPool()); + return false; + } + const string AddSessionMessage = "SessionPool::AddSession"; var addSessionMessage = IsMultiplePoolsVersion() ? $"{AddSessionMessage} - returning session to pool identified by connection string: {ConnectionString}"