From 5aec6b47ff1011928734837f5c5edac0dbab76ec Mon Sep 17 00:00:00 2001 From: sfc-gh-ext-simba-lf <115584722+sfc-gh-ext-simba-lf@users.noreply.github.com> Date: Mon, 25 Mar 2024 18:01:09 -0700 Subject: [PATCH 01/12] SNOW-817091: Async execution (#887) ### Description - Adds the option to execute queries in async mode - Adds the capability to asynchronously wait for the query to finish and get the results using the query ID - Adds checking the query status - Adds checking if query is still running or encountered an error ### Checklist - [ ] Code compiles correctly - [ ] Code is formatted according to [Coding Conventions](../CodingConventions.md) - [ ] Created tests which fail without the change (if possible) - [ ] All tests passing (`dotnet test`) - [ ] Extended the README / documentation, if necessary - [ ] Provide JIRA issue id (if possible) or GitHub issue id in PR name --- .../IntegrationTests/SFConnectionPoolIT.cs | 2 +- .../IntegrationTests/SFDbCommandIT.cs | 584 ++++++++++++++++++ .../UnitTests/SFStatementTest.cs | 95 ++- Snowflake.Data/Client/SnowflakeDbCommand.cs | 95 ++- .../Client/SnowflakeDbConnection.cs | 18 +- Snowflake.Data/Core/QueryResultsAwaiter.cs | 146 +++++ Snowflake.Data/Core/RestParams.cs | 2 + Snowflake.Data/Core/RestRequest.cs | 7 +- Snowflake.Data/Core/RestResponse.cs | 31 + Snowflake.Data/Core/SFBindUploader.cs | 4 +- Snowflake.Data/Core/SFResultSet.cs | 6 +- Snowflake.Data/Core/SFStatement.cs | 332 ++++++++-- 12 files changed, 1247 insertions(+), 75 deletions(-) create mode 100644 Snowflake.Data/Core/QueryResultsAwaiter.cs diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolIT.cs index 2a403521e..4f5020538 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolIT.cs @@ -333,7 +333,7 @@ void ThreadProcess2(string connstr) Thread.Sleep(5000); SFStatement statement = new 
SFStatement(conn1.SfSession); - SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); SnowflakeDbConnectionPool.ClearAllPools(); diff --git a/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs index 6e92c0aac..a4c84caeb 100755 --- a/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs @@ -175,6 +175,302 @@ public void TestExecuteAsyncWithMaxRetryReached() Assert.GreaterOrEqual(stopwatch.ElapsedMilliseconds, 30 * 1000); } } + + [Test] + public async Task TestAsyncExecQueryAsync() + { + string queryId; + var expectedWaitTime = 5; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Arrange + cmd.CommandText = $"CALL SYSTEM$WAIT({expectedWaitTime}, \'SECONDS\');"; + + // Act + queryId = await cmd.ExecuteAsyncInAsyncMode(CancellationToken.None).ConfigureAwait(false); + var queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.IsTrue(conn.IsStillRunning(queryStatus)); + Assert.IsFalse(conn.IsAnError(queryStatus)); + + // Act + DbDataReader reader = await cmd.GetResultsFromQueryIdAsync(queryId, CancellationToken.None).ConfigureAwait(false); + queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.IsTrue(reader.Read()); + Assert.AreEqual($"waited {expectedWaitTime} seconds", reader.GetString(0)); + Assert.AreEqual(QueryStatus.Success, queryStatus); + } + + await 
conn.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test, NonParallelizable] + public async Task TestExecuteNormalQueryWhileAsyncExecQueryIsRunningAsync() + { + string queryId; + var expectedWaitTime = 5; + + SnowflakeDbConnection[] connections = new SnowflakeDbConnection[3]; + for (int i = 0; i < connections.Length; i++) + { + connections[i] = new SnowflakeDbConnection(ConnectionString); + await connections[i].OpenAsync(CancellationToken.None).ConfigureAwait(false); + } + + // Start the async exec query + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)connections[0].CreateCommand()) + { + // Arrange + cmd.CommandText = $"CALL SYSTEM$WAIT({expectedWaitTime}, \'SECONDS\');"; + + // Act + queryId = await cmd.ExecuteAsyncInAsyncMode(CancellationToken.None).ConfigureAwait(false); + var queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.IsTrue(connections[0].IsStillRunning(queryStatus)); + } + + // Execute a normal query + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)connections[1].CreateCommand()) + { + // Arrange + cmd.CommandText = $"select 1;"; + + // Act + var row = cmd.ExecuteScalar(); + + // Assert + Assert.AreEqual(1, row); + } + + // Get results of the async exec query + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)connections[2].CreateCommand()) + { + // Act + var reader = await cmd.GetResultsFromQueryIdAsync(queryId, CancellationToken.None).ConfigureAwait(false); + var queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.IsTrue(reader.Read()); + Assert.AreEqual($"waited {expectedWaitTime} seconds", reader.GetString(0)); + Assert.AreEqual(QueryStatus.Success, queryStatus); + } + + for (int i = 0; i < connections.Length; i++) + { + await connections[i].CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test] + public async Task 
TestAsyncExecCancelWhileGettingResultsAsync() + { + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Arrange + CancellationTokenSource cancelToken = new CancellationTokenSource(); + cmd.CommandText = $"CALL SYSTEM$WAIT(60, \'SECONDS\');"; + + // Act + var queryId = await cmd.ExecuteAsyncInAsyncMode(CancellationToken.None).ConfigureAwait(false); + var queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.IsTrue(conn.IsStillRunning(queryStatus)); + + // Act + cancelToken.Cancel(); + var thrown = Assert.ThrowsAsync(async () => + await cmd.GetResultsFromQueryIdAsync(queryId, cancelToken.Token).ConfigureAwait(false)); + + // Assert + Assert.IsTrue(thrown.Message.Contains("The operation was canceled")); + } + + await conn.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test] + public async Task TestFailedAsyncExecQueryThrowsErrorAsync() + { + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Arrange + var statusMaxRetryCount = 5; + var statusRetryCount = 0; + cmd.CommandText = $"SELECT * FROM FAKE_TABLE;"; + + // Act + var queryId = await cmd.ExecuteAsyncInAsyncMode(CancellationToken.None).ConfigureAwait(false); + var queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); + while (statusRetryCount < statusMaxRetryCount && conn.IsStillRunning(queryStatus)) + { + Thread.Sleep(1000); + queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); + statusRetryCount++; + 
} + + // Assert + Assert.AreEqual(QueryStatus.FailedWithError, queryStatus); + + // Act + var thrown = Assert.ThrowsAsync(async () => + await cmd.GetResultsFromQueryIdAsync(queryId, CancellationToken.None).ConfigureAwait(false)); + + // Assert + Assert.IsTrue(thrown.Message.Contains("'FAKE_TABLE' does not exist")); + } + + await conn.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test] + public async Task TestGetStatusOfInvalidQueryIdAsync() + { + string fakeQueryId = "fakeQueryId"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Act + var thrown = Assert.ThrowsAsync(async () => + await cmd.GetQueryStatusAsync(fakeQueryId, CancellationToken.None).ConfigureAwait(false)); + + // Assert + Assert.IsTrue(thrown.Message.Contains($"The given query id {fakeQueryId} is not valid uuid")); + } + + await conn.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test] + public async Task TestGetResultsOfInvalidQueryIdAsync() + { + string fakeQueryId = "fakeQueryId"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Act + var thrown = Assert.ThrowsAsync(async () => + await cmd.GetResultsFromQueryIdAsync(fakeQueryId, CancellationToken.None).ConfigureAwait(false)); + + // Assert + Assert.IsTrue(thrown.Message.Contains($"The given query id {fakeQueryId} is not valid uuid")); + } + + await conn.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test, NonParallelizable] + public async Task TestGetStatusOfUnknownQueryIdAsync() + { + string unknownQueryId = 
"ba321edc-1abc-123e-987f-1234a56b789c"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Act + var queryStatus = await cmd.GetQueryStatusAsync(unknownQueryId, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.AreEqual(QueryStatus.NoData, queryStatus); + } + + await conn.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test] + [Ignore("The test takes too long to finish when using the default retry")] + public async Task TestGetResultsOfUnknownQueryIdAsyncWithDefaultRetry() + { + string unknownQueryId = "ab123fed-1abc-987f-987f-1234a56b789c"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Act + var thrown = Assert.ThrowsAsync(async () => + await cmd.GetResultsFromQueryIdAsync(unknownQueryId, CancellationToken.None).ConfigureAwait(false)); + + // Assert + Assert.IsTrue(thrown.Message.Contains($"Max retry for no data is reached")); + } + + await conn.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test] + public async Task TestGetResultsOfUnknownQueryIdAsyncWithConfiguredRetry() + { + var queryResultsRetryCount = 3; + var queryResultsRetryPattern = new int[] { 1, 2 }; + var unknownQueryId = "ab123fed-1abc-987f-987f-1234a56b789c"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Arrange + QueryResultsAwaiter queryResultsAwaiter = new 
QueryResultsAwaiter(new QueryResultsRetryConfig(queryResultsRetryCount, queryResultsRetryPattern)); + + // Act + var thrown = Assert.ThrowsAsync(async () => + await queryResultsAwaiter.RetryUntilQueryResultIsAvailable(conn, unknownQueryId, CancellationToken.None, true).ConfigureAwait(false)); + + // Assert + Assert.IsTrue(thrown.Message.Contains($"Max retry for no data is reached")); + } + + await conn.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } } [TestFixture] @@ -1040,5 +1336,293 @@ public void TestGetQueryId() conn.Close(); } } + + [Test] + public void TestAsyncExecQuery() + { + string queryId; + var expectedWaitTime = 5; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Arrange + cmd.CommandText = $"CALL SYSTEM$WAIT({expectedWaitTime}, \'SECONDS\');"; + + // Act + queryId = cmd.ExecuteInAsyncMode(); + var queryStatus = cmd.GetQueryStatus(queryId); + + // Assert + Assert.IsTrue(conn.IsStillRunning(queryStatus)); + Assert.IsFalse(conn.IsAnError(queryStatus)); + + // Act + DbDataReader reader = cmd.GetResultsFromQueryId(queryId); + + // Assert + Assert.IsTrue(reader.Read()); + Assert.AreEqual($"waited {expectedWaitTime} seconds", reader.GetString(0)); + Assert.AreEqual(QueryStatus.Success, cmd.GetQueryStatus(queryId)); + } + + conn.Close(); + } + } + + [Test, NonParallelizable] + public void TestExecuteNormalQueryWhileAsyncExecQueryIsRunning() + { + string queryId; + var expectedWaitTime = 5; + + SnowflakeDbConnection[] connections = new SnowflakeDbConnection[3]; + for (int i = 0; i < connections.Length; i++) + { + connections[i] = new SnowflakeDbConnection(ConnectionString); + connections[i].Open(); + } + + // Start the async exec query + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)connections[0].CreateCommand()) + { + // Arrange + cmd.CommandText = $"CALL 
SYSTEM$WAIT({expectedWaitTime}, \'SECONDS\');"; + + // Act + queryId = cmd.ExecuteInAsyncMode(); + + // Assert + Assert.IsTrue(connections[0].IsStillRunning(cmd.GetQueryStatus(queryId))); + } + + // Execute a normal query + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)connections[1].CreateCommand()) + { + // Arrange + cmd.CommandText = $"select 1;"; + + // Act + var row = cmd.ExecuteScalar(); + + // Assert + Assert.AreEqual(1, row); + } + + // Get results of the async exec query + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)connections[2].CreateCommand()) + { + // Act + DbDataReader reader = cmd.GetResultsFromQueryId(queryId); + + // Assert + Assert.IsTrue(reader.Read()); + Assert.AreEqual($"waited {expectedWaitTime} seconds", reader.GetString(0)); + Assert.AreEqual(QueryStatus.Success, cmd.GetQueryStatus(queryId)); + } + + for (int i = 0; i < connections.Length; i++) + { + connections[i].Close(); + } + } + + [Test] + public void TestFailedAsyncExecQueryThrowsError() + { + string queryId; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Arrange + var statusMaxRetryCount = 5; + var statusRetryCount = 0; + cmd.CommandText = $"SELECT * FROM FAKE_TABLE;"; + + // Act + queryId = cmd.ExecuteInAsyncMode(); + while (statusRetryCount < statusMaxRetryCount && conn.IsStillRunning(cmd.GetQueryStatus(queryId))) + { + Thread.Sleep(1000); + statusRetryCount++; + } + + // Assert + Assert.AreEqual(QueryStatus.FailedWithError, cmd.GetQueryStatus(queryId)); + + // Act + var thrown = Assert.Throws(() => cmd.GetResultsFromQueryId(queryId)); + + // Assert + Assert.IsTrue(thrown.Message.Contains("'FAKE_TABLE' does not exist")); + } + + conn.Close(); + } + } + + [Test] + public void TestAsyncExecQueryPutGetThrowsNotImplemented() + { + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + 
conn.ConnectionString = ConnectionString; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Arrange + cmd.CommandText = $"PUT file://non_existent_file.csv @~;"; + + // Act + var thrown = Assert.Throws(() => cmd.ExecuteInAsyncMode()); + + // Assert + Assert.IsTrue(thrown.Message.Contains("Get and Put are not supported in async execution mode")); + + // Arrange + cmd.CommandText = "GET @~ file://C:\\tmp\\;"; + + // Act + thrown = Assert.Throws(() => cmd.ExecuteInAsyncMode()); + + // Assert + Assert.IsTrue(thrown.Message.Contains("Get and Put are not supported in async execution mode")); + } + + conn.Close(); + } + } + + [Test] + public void TestGetStatusOfInvalidQueryId() + { + string fakeQueryId = "fakeQueryId"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Act + var thrown = Assert.Throws(() => cmd.GetQueryStatus(fakeQueryId)); + + // Assert + Assert.IsTrue(thrown.Message.Contains($"The given query id {fakeQueryId} is not valid uuid")); + } + + conn.Close(); + } + } + + [Test] + public void TestGetResultsOfInvalidQueryId() + { + string fakeQueryId = "fakeQueryId"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Act + var thrown = Assert.Throws(() => cmd.GetResultsFromQueryId(fakeQueryId)); + + // Assert + Assert.IsTrue(thrown.InnerException.Message.Contains($"The given query id {fakeQueryId} is not valid uuid")); + } + + conn.Close(); + } + } + + [Test, NonParallelizable] + public void TestGetStatusOfUnknownQueryId() + { + string unknownQueryId = "ab123cde-1cba-789a-987f-1234a56b789c"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + 
conn.ConnectionString = ConnectionString; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Act + var queryStatus = cmd.GetQueryStatus(unknownQueryId); + + // Assert + Assert.AreEqual(QueryStatus.NoData, queryStatus); + } + + conn.Close(); + } + } + + [Test] + [Ignore("The test takes too long to finish when using the default retry")] + public void TestGetResultsOfUnknownQueryIdWithDefaultRetry() + { + string unknownQueryId = "ba987def-1abc-987f-987f-1234a56b789c"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Act + var thrown = Assert.Throws(() => cmd.GetResultsFromQueryId(unknownQueryId)); + + // Assert + Assert.IsTrue(thrown.InnerException.Message.Contains($"Max retry for no data is reached")); + } + + conn.Close(); + } + } + + [Test] + public void TestGetResultsOfUnknownQueryIdWithConfiguredRetry() + { + var queryResultsRetryCount = 3; + var queryResultsRetryPattern = new int[] { 1, 2 }; + var unknownQueryId = "ba987def-1abc-987f-987f-1234a56b789c"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Arrange + QueryResultsAwaiter queryResultsAwaiter = new QueryResultsAwaiter(new QueryResultsRetryConfig(queryResultsRetryCount, queryResultsRetryPattern)); + var task = queryResultsAwaiter.RetryUntilQueryResultIsAvailable(conn, unknownQueryId, CancellationToken.None, false); + + // Act + var thrown = Assert.Throws(() => task.Wait()); + + // Assert + Assert.IsTrue(thrown.InnerException.Message.Contains($"Max retry for no data is reached")); + } + + conn.Close(); + } + } } } diff --git a/Snowflake.Data.Tests/UnitTests/SFStatementTest.cs 
b/Snowflake.Data.Tests/UnitTests/SFStatementTest.cs index 330b19f96..24f0f4b0a 100755 --- a/Snowflake.Data.Tests/UnitTests/SFStatementTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFStatementTest.cs @@ -6,6 +6,7 @@ namespace Snowflake.Data.Tests.UnitTests { using Snowflake.Data.Core; using NUnit.Framework; + using System; /** * Mock rest request test @@ -21,7 +22,7 @@ public void TestSessionRenew() SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); SFStatement statement = new SFStatement(sfSession); - SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); Assert.AreEqual("new_session_token", sfSession.sessionToken); @@ -37,7 +38,7 @@ public void TestSessionRenewDuringQueryExec() SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); SFStatement statement = new SFStatement(sfSession); - SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); } @@ -57,7 +58,7 @@ public void TestServiceName() for (int i = 0; i < 5; i++) { SFStatement statement = new SFStatement(sfSession); - SFBaseResultSet resultSet = statement.Execute(0, "SELECT 1", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "SELECT 1", null, false, false); expectServiceName += "a"; Assert.AreEqual(expectServiceName, sfSession.ParameterMap[SFSessionParameter.SERVICE_NAME]); } @@ -73,7 +74,7 @@ public void TestTrimSqlBlockComment() SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); SFStatement statement = new SFStatement(sfSession); - 
SFBaseResultSet resultSet = statement.Execute(0, "/*comment*/select 1/*comment*/", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "/*comment*/select 1/*comment*/", null, false, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); } @@ -88,7 +89,7 @@ public void TestTrimSqlBlockCommentMultiline() SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); SFStatement statement = new SFStatement(sfSession); - SFBaseResultSet resultSet = statement.Execute(0, "/*comment\r\ncomment*/select 1/*comment\r\ncomment*/", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "/*comment\r\ncomment*/select 1/*comment\r\ncomment*/", null, false, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); } @@ -103,7 +104,7 @@ public void TestTrimSqlLineComment() SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); SFStatement statement = new SFStatement(sfSession); - SFBaseResultSet resultSet = statement.Execute(0, "--comment\r\nselect 1\r\n--comment", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "--comment\r\nselect 1\r\n--comment", null, false, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); } @@ -118,9 +119,89 @@ public void TestTrimSqlLineCommentWithClosingNewline() SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); SFStatement statement = new SFStatement(sfSession); - SFBaseResultSet resultSet = statement.Execute(0, "--comment\r\nselect 1\r\n--comment\r\n", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "--comment\r\nselect 1\r\n--comment\r\n", null, false, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); } + + [Test] + [TestCase("running", QueryStatus.Running)] 
+ [TestCase("RUNNING", QueryStatus.Running)] + [TestCase("resuming_warehouse", QueryStatus.ResumingWarehouse)] + [TestCase("RESUMING_WAREHOUSE", QueryStatus.ResumingWarehouse)] + [TestCase("queued", QueryStatus.Queued)] + [TestCase("QUEUED", QueryStatus.Queued)] + [TestCase("queued_reparing_warehouse", QueryStatus.QueuedReparingWarehouse)] + [TestCase("QUEUED_REPARING_WAREHOUSE", QueryStatus.QueuedReparingWarehouse)] + [TestCase("no_data", QueryStatus.NoData)] + [TestCase("NO_DATA", QueryStatus.NoData)] + [TestCase("aborting", QueryStatus.Aborting)] + [TestCase("ABORTING", QueryStatus.Aborting)] + [TestCase("success", QueryStatus.Success)] + [TestCase("SUCCESS", QueryStatus.Success)] + [TestCase("failed_with_error", QueryStatus.FailedWithError)] + [TestCase("FAILED_WITH_ERROR", QueryStatus.FailedWithError)] + [TestCase("aborted", QueryStatus.Aborted)] + [TestCase("ABORTED", QueryStatus.Aborted)] + [TestCase("failed_with_incident", QueryStatus.FailedWithIncident)] + [TestCase("FAILED_WITH_INCIDENT", QueryStatus.FailedWithIncident)] + [TestCase("disconnected", QueryStatus.Disconnected)] + [TestCase("DISCONNECTED", QueryStatus.Disconnected)] + [TestCase("restarted", QueryStatus.Restarted)] + [TestCase("RESTARTED", QueryStatus.Restarted)] + [TestCase("blocked", QueryStatus.Blocked)] + [TestCase("BLOCKED", QueryStatus.Blocked)] + public void TestGetQueryStatusByStringValue(string stringValue, QueryStatus expectedStatus) + { + Assert.AreEqual(expectedStatus, QueryStatusExtensions.GetQueryStatusByStringValue(stringValue)); + } + + [Test] + [TestCase("UNKNOWN")] + [TestCase("RANDOM_STATUS")] + [TestCase("aBcZyX")] + public void TestGetQueryStatusByStringValueThrowsErrorForUnknownStatus(string stringValue) + { + var thrown = Assert.Throws(() => QueryStatusExtensions.GetQueryStatusByStringValue(stringValue)); + Assert.IsTrue(thrown.Message.Contains("The query status returned by the server is not recognized")); + } + + [Test] + [TestCase(QueryStatus.Running, true)] + 
[TestCase(QueryStatus.ResumingWarehouse, true)] + [TestCase(QueryStatus.Queued, true)] + [TestCase(QueryStatus.QueuedReparingWarehouse, true)] + [TestCase(QueryStatus.NoData, true)] + [TestCase(QueryStatus.Aborting, false)] + [TestCase(QueryStatus.Success, false)] + [TestCase(QueryStatus.FailedWithError, false)] + [TestCase(QueryStatus.Aborted, false)] + [TestCase(QueryStatus.FailedWithIncident, false)] + [TestCase(QueryStatus.Disconnected, false)] + [TestCase(QueryStatus.Restarted, false)] + [TestCase(QueryStatus.Blocked, false)] + public void TestIsStillRunning(QueryStatus status, bool expectedResult) + { + Assert.AreEqual(expectedResult, QueryStatusExtensions.IsStillRunning(status)); + } + + [Test] + [TestCase(QueryStatus.Aborting, true)] + [TestCase(QueryStatus.FailedWithError, true)] + [TestCase(QueryStatus.Aborted, true)] + [TestCase(QueryStatus.FailedWithIncident, true)] + [TestCase(QueryStatus.Disconnected, true)] + [TestCase(QueryStatus.Blocked, true)] + [TestCase(QueryStatus.Running, false)] + [TestCase(QueryStatus.ResumingWarehouse, false)] + [TestCase(QueryStatus.Queued, false)] + [TestCase(QueryStatus.QueuedReparingWarehouse, false)] + [TestCase(QueryStatus.NoData, false)] + [TestCase(QueryStatus.Success, false)] + [TestCase(QueryStatus.Restarted, false)] + public void TestIsAnError(QueryStatus status, bool expectedResult) + { + Assert.AreEqual(expectedResult, QueryStatusExtensions.IsAnError(status)); + } } } diff --git a/Snowflake.Data/Client/SnowflakeDbCommand.cs b/Snowflake.Data/Client/SnowflakeDbCommand.cs index ca415bacc..ce004df5c 100755 --- a/Snowflake.Data/Client/SnowflakeDbCommand.cs +++ b/Snowflake.Data/Client/SnowflakeDbCommand.cs @@ -9,7 +9,6 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; -using Newtonsoft.Json; using Snowflake.Data.Log; namespace Snowflake.Data.Client @@ -17,7 +16,7 @@ namespace Snowflake.Data.Client [System.ComponentModel.DesignerCategory("Code")] public class SnowflakeDbCommand 
: DbCommand { - private DbConnection connection; + private SnowflakeDbConnection connection; private SFStatement sfStatement; @@ -25,6 +24,8 @@ public class SnowflakeDbCommand : DbCommand private SFLogger logger = SFLoggerFactory.GetLogger(); + private readonly QueryResultsAwaiter _queryResultsAwaiter = QueryResultsAwaiter.Instance; + public SnowflakeDbCommand() { logger.Debug("Constructing SnowflakeDbCommand class"); @@ -274,6 +275,88 @@ protected override async Task ExecuteDbDataReaderAsync(CommandBeha } } + /// + /// Execute a query in async mode. + /// Async mode means the server will respond immediately with the query ID and execute the query asynchronously + /// + /// The query id. + public string ExecuteInAsyncMode() + { + logger.Debug($"ExecuteInAsyncMode"); + SFBaseResultSet resultSet = ExecuteInternal(asyncExec: true); + return resultSet.queryId; + } + + /// + /// Executes an asynchronous query in async mode. + /// Async mode means the server will respond immediately with the query ID and execute the query asynchronously + /// + /// + /// The query id. + public async Task ExecuteAsyncInAsyncMode(CancellationToken cancellationToken) + { + logger.Debug($"ExecuteAsyncInAsyncMode"); + var resultSet = await ExecuteInternalAsync(cancellationToken, asyncExec: true).ConfigureAwait(false); + return resultSet.queryId; + } + + /// + /// Gets the query status based on query ID. + /// + /// + /// The query status. + public QueryStatus GetQueryStatus(string queryId) + { + logger.Debug($"GetQueryStatus"); + return _queryResultsAwaiter.GetQueryStatus(connection, queryId); + } + + /// + /// Gets the query status based on query ID. + /// + /// + /// + /// The query status. + public async Task GetQueryStatusAsync(string queryId, CancellationToken cancellationToken) + { + logger.Debug($"GetQueryStatusAsync"); + return await _queryResultsAwaiter.GetQueryStatusAsync(connection, queryId, cancellationToken); + } + + /// + /// Gets the query results based on query ID. 
+ /// + /// + /// The query results. + public DbDataReader GetResultsFromQueryId(string queryId) + { + logger.Debug($"GetResultsFromQueryId"); + + Task task = _queryResultsAwaiter.RetryUntilQueryResultIsAvailable(connection, queryId, CancellationToken.None, false); + task.Wait(); + + SFBaseResultSet resultSet = sfStatement.GetResultWithId(queryId); + + return new SnowflakeDbDataReader(this, resultSet); + } + + /// + /// Gets the query results based on query ID. + /// + /// + /// + /// The query results. + public async Task GetResultsFromQueryIdAsync(string queryId, CancellationToken cancellationToken) + { + logger.Debug($"GetResultsFromQueryIdAsync"); + + await _queryResultsAwaiter.RetryUntilQueryResultIsAvailable(connection, queryId, cancellationToken, true); + + SFBaseResultSet resultSet = await sfStatement.GetResultWithIdAsync(queryId, cancellationToken).ConfigureAwait(false); + + return new SnowflakeDbDataReader(this, resultSet); + } + private static Dictionary convertToBindList(List parameters) { if (parameters == null || parameters.Count == 0) @@ -354,18 +437,18 @@ private void SetStatement() this.sfStatement = new SFStatement(session); } - private SFBaseResultSet ExecuteInternal(bool describeOnly = false) + private SFBaseResultSet ExecuteInternal(bool describeOnly = false, bool asyncExec = false) { CheckIfCommandTextIsSet(); SetStatement(); - return sfStatement.Execute(CommandTimeout, CommandText, convertToBindList(parameterCollection.parameterList), describeOnly); + return sfStatement.Execute(CommandTimeout, CommandText, convertToBindList(parameterCollection.parameterList), describeOnly, asyncExec); } - private Task ExecuteInternalAsync(CancellationToken cancellationToken, bool describeOnly = false) + private Task ExecuteInternalAsync(CancellationToken cancellationToken, bool describeOnly = false, bool asyncExec = false) { CheckIfCommandTextIsSet(); SetStatement(); - return sfStatement.ExecuteAsync(CommandTimeout, CommandText, 
convertToBindList(parameterCollection.parameterList), describeOnly, cancellationToken); + return sfStatement.ExecuteAsync(CommandTimeout, CommandText, convertToBindList(parameterCollection.parameterList), describeOnly, asyncExec, cancellationToken); } private void CheckIfCommandTextIsSet() diff --git a/Snowflake.Data/Client/SnowflakeDbConnection.cs b/Snowflake.Data/Client/SnowflakeDbConnection.cs index b773a0150..cce4974fc 100755 --- a/Snowflake.Data/Client/SnowflakeDbConnection.cs +++ b/Snowflake.Data/Client/SnowflakeDbConnection.cs @@ -176,7 +176,7 @@ public override async Task CloseAsync() } #endif - public virtual Task CloseAsync(CancellationToken cancellationToken) + public virtual async Task CloseAsync(CancellationToken cancellationToken) { logger.Debug("Close Connection."); TaskCompletionSource taskCompletionSource = new TaskCompletionSource(); @@ -199,7 +199,7 @@ public virtual Task CloseAsync(CancellationToken cancellationToken) } else { - SfSession.CloseAsync(cancellationToken).ContinueWith( + await SfSession.CloseAsync(cancellationToken).ContinueWith( previousTask => { if (previousTask.IsFaulted) @@ -220,7 +220,7 @@ public virtual Task CloseAsync(CancellationToken cancellationToken) _connectionState = ConnectionState.Closed; taskCompletionSource.SetResult(null); } - }, cancellationToken); + }, cancellationToken).ConfigureAwait(false); } } else @@ -229,7 +229,7 @@ public virtual Task CloseAsync(CancellationToken cancellationToken) taskCompletionSource.SetResult(null); } } - return taskCompletionSource.Task; + await taskCompletionSource.Task; } protected virtual bool CanReuseSession(TransactionRollbackStatus transactionRollbackStatus) @@ -402,6 +402,16 @@ internal void registerConnectionCancellationCallback(CancellationToken externalC } } + public bool IsStillRunning(QueryStatus status) + { + return QueryStatusExtensions.IsStillRunning(status); + } + + public bool IsAnError(QueryStatus status) + { + return QueryStatusExtensions.IsAnError(status); + } + 
~SnowflakeDbConnection() { Dispose(false); diff --git a/Snowflake.Data/Core/QueryResultsAwaiter.cs b/Snowflake.Data/Core/QueryResultsAwaiter.cs new file mode 100644 index 000000000..5ea187fbe --- /dev/null +++ b/Snowflake.Data/Core/QueryResultsAwaiter.cs @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using Snowflake.Data.Log; +using System.Threading.Tasks; +using System.Threading; +using System; +using System.Text.RegularExpressions; +using Snowflake.Data.Client; + +namespace Snowflake.Data.Core +{ + internal class QueryResultsRetryConfig + { + private const int DefaultAsyncNoDataMaxRetry = 24; + + private readonly int[] _defaultAsyncRetryPattern = { 1, 1, 2, 3, 4, 8, 10 }; + + internal readonly int _asyncNoDataMaxRetry; + + internal readonly int[] _asyncRetryPattern; + + internal QueryResultsRetryConfig() + { + _asyncNoDataMaxRetry = DefaultAsyncNoDataMaxRetry; + _asyncRetryPattern = _defaultAsyncRetryPattern; + } + + internal QueryResultsRetryConfig(int asyncNoDataMaxRetry, int[] asyncRetryPattern) + { + _asyncNoDataMaxRetry = asyncNoDataMaxRetry; + _asyncRetryPattern = asyncRetryPattern; + } + } + + internal class QueryResultsAwaiter + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private static readonly Regex UuidRegex = new Regex("^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"); + + private QueryResultsRetryConfig _queryResultsRetryConfig; + + internal static readonly QueryResultsAwaiter Instance = new QueryResultsAwaiter(); + + internal QueryResultsAwaiter() + { + _queryResultsRetryConfig = new QueryResultsRetryConfig(); + } + + internal QueryResultsAwaiter(QueryResultsRetryConfig queryResultsRetryConfig) + { + _queryResultsRetryConfig = queryResultsRetryConfig; + } + + internal QueryStatus GetQueryStatus(SnowflakeDbConnection connection, string queryId) + { + if (UuidRegex.IsMatch(queryId)) + { + var sfStatement = new 
SFStatement(connection.SfSession); + return sfStatement.GetQueryStatus(queryId); + } + else + { + var errorMessage = $"The given query id {queryId} is not valid uuid"; + s_logger.Error(errorMessage); + throw new Exception(errorMessage); + } + } + + internal async Task GetQueryStatusAsync(SnowflakeDbConnection connection, string queryId, CancellationToken cancellationToken) + { + if (UuidRegex.IsMatch(queryId)) + { + var sfStatement = new SFStatement(connection.SfSession); + return await sfStatement.GetQueryStatusAsync(queryId, cancellationToken).ConfigureAwait(false); + } + else + { + var errorMessage = $"The given query id {queryId} is not valid uuid"; + s_logger.Error(errorMessage); + throw new Exception(errorMessage); + } + } + + /// + /// Checks query status until it is done executing. + /// + /// + /// + /// + /// + internal async Task RetryUntilQueryResultIsAvailable(SnowflakeDbConnection connection, string queryId, CancellationToken cancellationToken, bool isAsync) + { + int retryPatternPos = 0; + int noDataCounter = 0; + + QueryStatus status; + while (true) + { + if (cancellationToken.IsCancellationRequested) + { + s_logger.Debug("Cancellation requested for getting results from query id"); + cancellationToken.ThrowIfCancellationRequested(); + } + + status = isAsync ? 
await GetQueryStatusAsync(connection, queryId, cancellationToken) : GetQueryStatus(connection, queryId); + + if (!QueryStatusExtensions.IsStillRunning(status)) + { + return; + } + + // Timeout based on query status retry rules + if (isAsync) + { + await Task.Delay(TimeSpan.FromSeconds(_queryResultsRetryConfig._asyncRetryPattern[retryPatternPos]), cancellationToken).ConfigureAwait(false); + } + else + { + Thread.Sleep(TimeSpan.FromSeconds(_queryResultsRetryConfig._asyncRetryPattern[retryPatternPos])); + } + + // If no data, increment the no data counter + if (status == QueryStatus.NoData) + { + noDataCounter++; + + // Check if retry for no data is exceeded + if (noDataCounter > _queryResultsRetryConfig._asyncNoDataMaxRetry) + { + var errorMessage = "Max retry for no data is reached"; + s_logger.Error(errorMessage); + throw new Exception(errorMessage); + } + } + + if (retryPatternPos < _queryResultsRetryConfig._asyncRetryPattern.Length - 1) + { + retryPatternPos++; + } + } + } + } +} diff --git a/Snowflake.Data/Core/RestParams.cs b/Snowflake.Data/Core/RestParams.cs index 9dd4de8c8..72b910847 100644 --- a/Snowflake.Data/Core/RestParams.cs +++ b/Snowflake.Data/Core/RestParams.cs @@ -39,6 +39,8 @@ internal static class RestPath internal const string SF_QUERY_PATH = "/queries/v1/query-request"; + internal const string SF_MONITOR_QUERY_PATH = "/monitoring/queries/"; + internal const string SF_SESSION_HEARTBEAT_PATH = SF_SESSION_PATH + "/heartbeat"; internal const string SF_CONSOLE_LOGIN = "/console/login"; diff --git a/Snowflake.Data/Core/RestRequest.cs b/Snowflake.Data/Core/RestRequest.cs index bc4ec8d21..112743f77 100644 --- a/Snowflake.Data/Core/RestRequest.cs +++ b/Snowflake.Data/Core/RestRequest.cs @@ -129,6 +129,8 @@ internal SFRestRequest() : base() internal bool _isLogin { get; set; } + internal bool _isStatusRequest { get; set; } + public override string ToString() { return String.Format("SFRestRequest {{url: {0}, request body: {1} }}", Url.ToString(), @@ -154,7 
+156,7 @@ HttpRequestMessage IRestRequest.ToRequestMessage(HttpMethod method) // add quote otherwise it would be reported as error format string osInfo = "(" + SFEnvironment.ClientEnv.osVersion + ")"; - if (isPutGet) + if (isPutGet || _isStatusRequest) { message.Headers.Accept.Add(applicationJson); } @@ -313,6 +315,9 @@ class QueryRequest [JsonProperty(PropertyName = "queryContextDTO", NullValueHandling = NullValueHandling.Ignore)] internal RequestQueryContext QueryContextDTO { get; set; } + + [JsonProperty(PropertyName = "asyncExec")] + internal bool asyncExec { get; set; } } // The query context in query response diff --git a/Snowflake.Data/Core/RestResponse.cs b/Snowflake.Data/Core/RestResponse.cs index 97f3f9772..75f1698ea 100755 --- a/Snowflake.Data/Core/RestResponse.cs +++ b/Snowflake.Data/Core/RestResponse.cs @@ -423,6 +423,37 @@ internal class PutGetEncryptionMaterial internal long smkId { get; set; } } + internal class QueryStatusResponse : BaseRestResponse + { + + [JsonProperty(PropertyName = "data")] + internal QueryStatusData data { get; set; } + } + + internal class QueryStatusData + { + [JsonProperty(PropertyName = "queries", NullValueHandling = NullValueHandling.Ignore)] + internal List queries { get; set; } + } + + internal class QueryStatusDataQueries + { + [JsonProperty(PropertyName = "id", NullValueHandling = NullValueHandling.Ignore)] + internal string id { get; set; } + + [JsonProperty(PropertyName = "status", NullValueHandling = NullValueHandling.Ignore)] + internal string status { get; set; } + + [JsonProperty(PropertyName = "state", NullValueHandling = NullValueHandling.Ignore)] + internal string state { get; set; } + + [JsonProperty(PropertyName = "errorCode", NullValueHandling = NullValueHandling.Ignore)] + internal string errorCode { get; set; } + + [JsonProperty(PropertyName = "errorMessage", NullValueHandling = NullValueHandling.Ignore)] + internal string errorMessage { get; set; } + } + // Retrieved from: 
https://stackoverflow.com/a/18997172 internal class SingleOrArrayConverter : JsonConverter { diff --git a/Snowflake.Data/Core/SFBindUploader.cs b/Snowflake.Data/Core/SFBindUploader.cs index 68af7405b..a1b3f161d 100644 --- a/Snowflake.Data/Core/SFBindUploader.cs +++ b/Snowflake.Data/Core/SFBindUploader.cs @@ -290,7 +290,7 @@ private void CreateStage() try { SFStatement statement = new SFStatement(session); - SFBaseResultSet resultSet = statement.Execute(0, CREATE_STAGE_STMT, null, false); + SFBaseResultSet resultSet = statement.Execute(0, CREATE_STAGE_STMT, null, false, false); session.SetArrayBindStage(STAGE_NAME); } catch (Exception e) @@ -314,7 +314,7 @@ internal async Task CreateStageAsync(CancellationToken cancellationToken) try { SFStatement statement = new SFStatement(session); - var resultSet = await statement.ExecuteAsync(0, CREATE_STAGE_STMT, null, false, cancellationToken).ConfigureAwait(false); + var resultSet = await statement.ExecuteAsync(0, CREATE_STAGE_STMT, null, false, false, cancellationToken).ConfigureAwait(false); session.SetArrayBindStage(STAGE_NAME); } catch (Exception e) diff --git a/Snowflake.Data/Core/SFResultSet.cs b/Snowflake.Data/Core/SFResultSet.cs index 55b069806..03e1794c9 100755 --- a/Snowflake.Data/Core/SFResultSet.cs +++ b/Snowflake.Data/Core/SFResultSet.cs @@ -28,7 +28,7 @@ public SFResultSet(QueryExecResponseData responseData, SFStatement sfStatement, { try { - columnCount = responseData.rowType.Count; + columnCount = responseData.rowType?.Count ?? 0; this.sfStatement = sfStatement; UpdateSessionStatus(responseData); @@ -40,10 +40,10 @@ public SFResultSet(QueryExecResponseData responseData, SFStatement sfStatement, _chunkDownloader = ChunkDownloaderFactory.GetDownloader(responseData, this, cancellationToken); } - _currentChunk = new SFResultChunk(responseData.rowSet); + _currentChunk = responseData.rowSet != null ? 
new SFResultChunk(responseData.rowSet) : null; responseData.rowSet = null; - sfResultSetMetaData = new SFResultSetMetaData(responseData, this.sfStatement.SfSession); + sfResultSetMetaData = responseData.rowType != null ? new SFResultSetMetaData(responseData, this.sfStatement.SfSession) : null; isClosed = false; diff --git a/Snowflake.Data/Core/SFStatement.cs b/Snowflake.Data/Core/SFStatement.cs index a8ada8af2..9252af40e 100644 --- a/Snowflake.Data/Core/SFStatement.cs +++ b/Snowflake.Data/Core/SFStatement.cs @@ -3,13 +3,10 @@ */ using System; -using System.Web; -using Newtonsoft.Json.Linq; using System.Collections.Generic; using System.IO; using System.Linq; using Snowflake.Data.Client; -using Snowflake.Data.Core.FileTransfer; using Snowflake.Data.Log; using System.Threading; using System.Threading.Tasks; @@ -17,6 +14,87 @@ namespace Snowflake.Data.Core { + /// + /// The status types of the query. + /// + public enum QueryStatus + { + [StringAttr(value = "NO_DATA")] + NoData, + [StringAttr(value = "RUNNING")] + Running, + [StringAttr(value = "ABORTING")] + Aborting, + [StringAttr(value = "SUCCESS")] + Success, + [StringAttr(value = "FAILED_WITH_ERROR")] + FailedWithError, + [StringAttr(value = "ABORTED")] + Aborted, + [StringAttr(value = "QUEUED")] + Queued, + [StringAttr(value = "FAILED_WITH_INCIDENT")] + FailedWithIncident, + [StringAttr(value = "DISCONNECTED")] + Disconnected, + [StringAttr(value = "RESUMING_WAREHOUSE")] + ResumingWarehouse, + // purposeful typo + [StringAttr(value = "QUEUED_REPARING_WAREHOUSE")] + QueuedReparingWarehouse, + [StringAttr(value = "RESTARTED")] + Restarted, + [StringAttr(value = "BLOCKED")] + Blocked, + } + + class StringAttr : Attribute + { + public string value { get; set; } + } + + internal static class QueryStatusExtensions + { + internal static QueryStatus GetQueryStatusByStringValue(string stringValue) + { + var statuses = Enum.GetValues(typeof(QueryStatus)) + .Cast() + .Where(v => v.GetAttribute().value.Equals(stringValue, 
StringComparison.OrdinalIgnoreCase)); + return statuses.Any() ? statuses.First() : throw new Exception("The query status returned by the server is not recognized"); + } + + internal static bool IsStillRunning(QueryStatus status) + { + switch (status) + { + case QueryStatus.Running: + case QueryStatus.ResumingWarehouse: + case QueryStatus.Queued: + case QueryStatus.QueuedReparingWarehouse: + case QueryStatus.NoData: + return true; + default: + return false; + } + } + + internal static bool IsAnError(QueryStatus status) + { + switch (status) + { + case QueryStatus.Aborting: + case QueryStatus.FailedWithError: + case QueryStatus.Aborted: + case QueryStatus.FailedWithIncident: + case QueryStatus.Disconnected: + case QueryStatus.Blocked: + return true; + default: + return false; + } + } + } + class SFStatement { static private SFLogger logger = SFLoggerFactory.GetLogger(); @@ -90,7 +168,7 @@ private void ClearQueryRequestId() _requestId = null; } - private SFRestRequest BuildQueryRequest(string sql, Dictionary bindings, bool describeOnly) + private SFRestRequest BuildQueryRequest(string sql, Dictionary bindings, bool describeOnly, bool asyncExec) { AssignQueryRequestId(); @@ -120,8 +198,9 @@ private SFRestRequest BuildQueryRequest(string sql, Dictionary private bool SessionExpired(BaseRestResponse r) => r.code == SFSession.SF_SESSION_EXPIRED_CODE; - internal async Task ExecuteAsync(int timeout, string sql, Dictionary bindings, bool describeOnly, + internal async Task ExecuteAsync(int timeout, string sql, Dictionary bindings, bool describeOnly, bool asyncExec, CancellationToken cancellationToken) { // Trim the sql query and check if this is a PUT/GET command string trimmedSql = TrimSql(sql); - if (IsPutOrGetCommand(trimmedSql)) { + if (IsPutOrGetCommand(trimmedSql)) + { throw new NotImplementedException("Get and Put are not supported in async calls. 
Use Execute() instead of ExecuteAsync()."); } @@ -287,8 +367,8 @@ internal async Task ExecuteAsync(int timeout, string sql, Dicti logger.Warn("Exception encountered trying to upload binds to stage. Attaching binds in payload instead. {0}", e); } } - - var queryRequest = BuildQueryRequest(sql, bindings, describeOnly); + + var queryRequest = BuildQueryRequest(sql, bindings, describeOnly, asyncExec); try { QueryExecResponse response = null; @@ -309,19 +389,22 @@ internal async Task ExecuteAsync(int timeout, string sql, Dicti var lastResultUrl = response.data?.getResultUrl; - while (RequestInProgress(response) || SessionExpired(response)) + if (!asyncExec) { - var req = BuildResultRequest(lastResultUrl); - response = await _restRequester.GetAsync(req, cancellationToken).ConfigureAwait(false); - - if (SessionExpired(response)) - { - logger.Info("Ping pong request failed with session expired, trying to renew the session."); - await SfSession.renewSessionAsync(cancellationToken).ConfigureAwait(false); - } - else + while (RequestInProgress(response) || SessionExpired(response)) { - lastResultUrl = response.data?.getResultUrl; + var req = BuildResultRequest(lastResultUrl); + response = await _restRequester.GetAsync(req, cancellationToken).ConfigureAwait(false); + + if (SessionExpired(response)) + { + logger.Info("Ping pong request failed with session expired, trying to renew the session."); + await SfSession.renewSessionAsync(cancellationToken).ConfigureAwait(false); + } + else + { + lastResultUrl = response.data?.getResultUrl; + } } } @@ -338,8 +421,8 @@ internal async Task ExecuteAsync(int timeout, string sql, Dicti ClearQueryRequestId(); } } - - internal SFBaseResultSet Execute(int timeout, string sql, Dictionary bindings, bool describeOnly) + + internal SFBaseResultSet Execute(int timeout, string sql, Dictionary bindings, bool describeOnly, bool asyncExec) { // Trim the sql query and check if this is a PUT/GET command string trimmedSql = TrimSql(sql); @@ -347,10 +430,14 
@@ internal SFBaseResultSet Execute(int timeout, string sql, Dictionary bindings, bool describeOnly) + private SFBaseResultSet ExecuteSqlOtherThanPutGet(int timeout, string sql, Dictionary bindings, bool describeOnly, bool asyncExec) { try { @@ -437,7 +524,8 @@ private SFBaseResultSet ExecuteSqlOtherThanPutGet(int timeout, string sql, Dicti timeout, sql, bindings, - describeOnly); + describeOnly, + asyncExec); return BuildResultSet(response, CancellationToken.None); } @@ -531,12 +619,13 @@ internal T ExecuteHelper( int timeout, string sql, Dictionary bindings, - bool describeOnly) + bool describeOnly, + bool asyncExec = false) where T : BaseQueryExecResponse where U : IQueryExecResponseData { registerQueryCancellationCallback(timeout, CancellationToken.None); - var queryRequest = BuildQueryRequest(sql, bindings, describeOnly); + var queryRequest = BuildQueryRequest(sql, bindings, describeOnly, asyncExec); try { T response = null; @@ -558,20 +647,24 @@ internal T ExecuteHelper( if (typeof(T) == typeof(QueryExecResponse)) { QueryExecResponse queryResponse = (QueryExecResponse)(object)response; - var lastResultUrl = queryResponse.data?.getResultUrl; - while (RequestInProgress(response) || SessionExpired(response)) + if (!asyncExec) { - var req = BuildResultRequest(lastResultUrl); - response = _restRequester.Get(req); + var lastResultUrl = queryResponse.data?.getResultUrl; - if (SessionExpired(response)) + while (RequestInProgress(response) || SessionExpired(response)) { - logger.Info("Ping pong request failed with session expired, trying to renew the session."); - SfSession.renewSession(); - } - else - { - lastResultUrl = queryResponse.data?.getResultUrl; + var req = BuildResultRequest(lastResultUrl); + response = _restRequester.Get(req); + + if (SessionExpired(response)) + { + logger.Info("Ping pong request failed with session expired, trying to renew the session."); + SfSession.renewSession(); + } + else + { + lastResultUrl = queryResponse.data?.getResultUrl; + } } 
} } @@ -612,13 +705,14 @@ internal async Task ExecuteAsyncHelper( string sql, Dictionary bindings, bool describeOnly, - CancellationToken cancellationToken + CancellationToken cancellationToken, + bool asyncExec = false ) where T : BaseQueryExecResponse where U : IQueryExecResponseData { registerQueryCancellationCallback(timeout, CancellationToken.None); - var queryRequest = BuildQueryRequest(sql, bindings, describeOnly); + var queryRequest = BuildQueryRequest(sql, bindings, describeOnly, asyncExec); try { T response = null; @@ -640,20 +734,24 @@ CancellationToken cancellationToken if (typeof(T) == typeof(QueryExecResponse)) { QueryExecResponse queryResponse = (QueryExecResponse)(object)response; - var lastResultUrl = queryResponse.data?.getResultUrl; - while (RequestInProgress(response) || SessionExpired(response)) + if (!asyncExec) { - var req = BuildResultRequest(lastResultUrl); - response = await _restRequester.GetAsync(req, cancellationToken).ConfigureAwait(false); + var lastResultUrl = queryResponse.data?.getResultUrl; - if (SessionExpired(response)) + while (RequestInProgress(response) || SessionExpired(response)) { - logger.Info("Ping pong request failed with session expired, trying to renew the session."); - await SfSession.renewSessionAsync(cancellationToken).ConfigureAwait(false); - } - else - { - lastResultUrl = queryResponse.data?.getResultUrl; + var req = BuildResultRequest(lastResultUrl); + response = await _restRequester.GetAsync(req, cancellationToken).ConfigureAwait(false); + + if (SessionExpired(response)) + { + logger.Info("Ping pong request failed with session expired, trying to renew the session."); + await SfSession.renewSessionAsync(cancellationToken).ConfigureAwait(false); + } + else + { + lastResultUrl = queryResponse.data?.getResultUrl; + } } } } @@ -680,6 +778,138 @@ CancellationToken cancellationToken } } + /// + /// Creates a request to get the query status based on query ID. + /// + /// + /// The request to get the query status. 
+ private SFRestRequest BuildQueryStatusRequest(string queryId) + { + var queryUri = SfSession.BuildUri(RestPath.SF_MONITOR_QUERY_PATH + queryId); + + return new SFRestRequest + { + Url = queryUri, + authorizationToken = string.Format(SF_AUTHORIZATION_SNOWFLAKE_FMT, SfSession.sessionToken), + serviceName = SfSession.ParameterMap.ContainsKey(SFSessionParameter.SERVICE_NAME) + ? (String)SfSession.ParameterMap[SFSessionParameter.SERVICE_NAME] : null, + HttpTimeout = Timeout.InfiniteTimeSpan, + RestTimeout = Timeout.InfiniteTimeSpan, + sid = SfSession.sessionId, + _isStatusRequest = true + }; + } + + /// + /// Gets the query status based on query ID. + /// + /// + /// The query status. + internal QueryStatus GetQueryStatus(string queryId) + { + var queryRequest = BuildQueryStatusRequest(queryId); + + try + { + QueryStatusResponse response = null; + bool receivedFirstQueryResponse = false; + while (!receivedFirstQueryResponse) + { + response = _restRequester.Get(queryRequest); + if (SessionExpired(response)) + { + SfSession.renewSession(); + queryRequest.authorizationToken = string.Format(SF_AUTHORIZATION_SNOWFLAKE_FMT, SfSession.sessionToken); + } + else + { + receivedFirstQueryResponse = true; + } + } + + if (!response.success) + { + throw new SnowflakeDbException( + response.data.queries[0].state, + response.code, + response.message, + queryId); + } + + QueryStatus queryStatus = QueryStatus.NoData; + if (response.data.queries.Count != 0) + { + queryStatus = QueryStatusExtensions.GetQueryStatusByStringValue(response.data.queries[0].status); + } + + return queryStatus; + } + catch + { + logger.Error("Query execution failed."); + throw; + } + finally + { + ClearQueryRequestId(); + } + } + + /// + /// Gets the query status based on query ID. + /// + /// + /// The query status. 
+ internal async Task GetQueryStatusAsync(string queryId, CancellationToken cancellationToken) + { + var queryRequest = BuildQueryStatusRequest(queryId); + + try + { + QueryStatusResponse response = null; + bool receivedFirstQueryResponse = false; + while (!receivedFirstQueryResponse) + { + response = await _restRequester.GetAsync(queryRequest, cancellationToken).ConfigureAwait(false); + if (SessionExpired(response)) + { + SfSession.renewSession(); + queryRequest.authorizationToken = string.Format(SF_AUTHORIZATION_SNOWFLAKE_FMT, SfSession.sessionToken); + } + else + { + receivedFirstQueryResponse = true; + } + } + + if (!response.success) + { + throw new SnowflakeDbException( + response.data.queries[0].state, + response.code, + response.message, + queryId); + } + + QueryStatus queryStatus = QueryStatus.NoData; + if (response.data.queries.Count != 0) + { + queryStatus = QueryStatusExtensions.GetQueryStatusByStringValue(response.data.queries[0].status); + } + + return queryStatus; + } + catch + { + logger.Error("Query execution failed."); + throw; + } + finally + { + ClearQueryRequestId(); + } + } + /// /// Trim the query by removing spaces and comments at the beginning. /// From ee0cdaa6e893fc8ba7322ac5394190495fc6b5ae Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Tue, 26 Mar 2024 14:27:30 +0100 Subject: [PATCH 02/12] SNOW-817091 async executions documentation (#896) ### Description Async executions documentation. 
### Checklist - [x] Code compiles correctly - [x] Code is formatted according to [Coding Conventions](../CodingConventions.md) - [x] Created tests which fail without the change (if possible) - [x] All tests passing (`dotnet test`) - [x] Extended the README / documentation, if necessary - [x] Provide JIRA issue id (if possible) or GitHub issue id in PR name --- README.md | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/README.md b/README.md index 41a7d61da..cc7aba2ef 100644 --- a/README.md +++ b/README.md @@ -503,6 +503,65 @@ Note that because this method is not available in the generic `IDataReader` inte TimeSpan timeSpanTime = ((SnowflakeDbDataReader)reader).GetTimeSpan(13); ``` +## Execute a query asynchronously on the server + +You can run the query asynchronously on the server. The server responds immediately with `queryId` and continues to execute the query asynchronously. +Then you can use this `queryId` to check the query status or wait until the query is completed and get the results. +It is fine to start the query in one session and continue to query for the results in another one based on the queryId. + +**Note**: There are 2 levels of asynchronous execution. One is asynchronous execution in terms of C# language (`async await`). +Another is asynchronous execution of the query by the server (you can recognize it by method names containing `InAsyncMode`, e.g. `ExecuteInAsyncMode`, `ExecuteAsyncInAsyncMode`). + +Example of synchronous code starting a query to be executed asynchronously on the server: +```cs +using (SnowflakeDbConnection conn = new SnowflakeDbConnection("account=testaccount;username=testusername;password=testpassword")) +{ + conn.Open(); + SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand(); + cmd.CommandText = "SELECT ..."; + var queryId = cmd.ExecuteInAsyncMode(); + // ... 
+} +``` + +Example of asynchronous code starting a query to be executed asynchronously on the server: +```cs +using (SnowflakeDbConnection conn = new SnowflakeDbConnection("account=testaccount;username=testusername;password=testpassword")) +{ + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand(); + cmd.CommandText = "SELECT ..."; + var queryId = await cmd.ExecuteAsyncInAsyncMode(CancellationToken.None).ConfigureAwait(false); + // ... +} +``` + +You can check the status of a query executed asynchronously on the server either in synchronous code: +```cs +var queryStatus = cmd.GetQueryStatus(queryId); +Assert.IsTrue(conn.IsStillRunning(queryStatus)); // assuming that the query is still running +Assert.IsFalse(conn.IsAnError(queryStatus)); // assuming that the query has not finished with error +``` +or the same in asynchronous code: +```cs +var queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); +Assert.IsTrue(conn.IsStillRunning(queryStatus)); // assuming that the query is still running +Assert.IsFalse(conn.IsAnError(queryStatus)); // assuming that the query has not finished with error +``` + +The following example shows how to get query results. +The operation will repeatedly check the query status until the query is completed, a timeout occurs, or the maximum number of attempts is reached. +The synchronous code example: +```cs +DbDataReader reader = cmd.GetResultsFromQueryId(queryId); +``` +and the asynchronous code example: +```cs +DbDataReader reader = await cmd.GetResultsFromQueryIdAsync(queryId, CancellationToken.None).ConfigureAwait(false); +``` + +**Note**: GET/PUT operations are currently not enabled for asynchronous executions. 
+ ## Executing a Batch of SQL Statements (Multi-Statement Support) With version 2.0.18 and later of the .NET connector, you can send From 030351414f2e380aa73ae99f4b039dffc3aac939 Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Wed, 27 Mar 2024 15:38:46 +0100 Subject: [PATCH 03/12] MINOR: Bumped up DotNet connector MINOR version from 3.0.0 to 3.1.0 (#902) @noreview - This is an automated process. No review is required ### Description MINOR: Bumped up DotNet connector MINOR version from 3.0.0 to 3.1.0 ### Checklist - [x] Code compiles correctly - [x] Code is formatted according to [Coding Conventions](../CodingConventions.md) - [x] Created tests which fail without the change (if possible) - [x] All tests passing (`dotnet test`) - [x] Extended the README / documentation, if necessary - [x] Provide JIRA issue id (if possible) or GitHub issue id in PR name Co-authored-by: Jenkins User --- Snowflake.Data/Snowflake.Data.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Snowflake.Data/Snowflake.Data.csproj b/Snowflake.Data/Snowflake.Data.csproj index 397bf9652..0621c5fb0 100644 --- a/Snowflake.Data/Snowflake.Data.csproj +++ b/Snowflake.Data/Snowflake.Data.csproj @@ -13,7 +13,7 @@ Snowflake Connector for .NET howryu, tchen Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
- 3.0.0 + 3.1.0 Full 7.3 From ac0860fec989700d2abfefe2132f7910e287339c Mon Sep 17 00:00:00 2001 From: sfc-gh-ext-simba-lf <115584722+sfc-gh-ext-simba-lf@users.noreply.github.com> Date: Mon, 1 Apr 2024 09:18:45 -0700 Subject: [PATCH 04/12] SNOW-979288: Add explicit DbType Parameter assignment (#889) ### Description Add explicit DbType Parameter assignment ### Checklist - [ ] Code compiles correctly - [ ] Code is formatted according to [Coding Conventions](../CodingConventions.md) - [ ] Created tests which fail without the change (if possible) - [ ] All tests passing (`dotnet test`) - [ ] Extended the README / documentation, if necessary - [ ] Provide JIRA issue id (if possible) or GitHub issue id in PR name --- .../IntegrationTests/SFBindTestIT.cs | 64 +++++++++++++ .../UnitTests/SFDbParameterTest.cs | 90 +++++++++++++++++++ Snowflake.Data/Client/SnowflakeDbParameter.cs | 25 +++++- Snowflake.Data/Core/SFDataConverter.cs | 25 ++++++ 4 files changed, 203 insertions(+), 1 deletion(-) diff --git a/Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs index 94bdc8bd8..e222e5892 100755 --- a/Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs @@ -769,5 +769,69 @@ public void testPutArrayBind1() conn.Close(); } } + + [Test] + public void testExplicitDbTypeAssignmentForSimpleValue() + { + + using (IDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + CreateOrReplaceTable(conn, TableName, new[] + { + "cola INTEGER", + }); + + using (IDbCommand cmd = conn.CreateCommand()) + { + string insertCommand = $"insert into {TableName} values (?)"; + cmd.CommandText = insertCommand; + + var p1 = cmd.CreateParameter(); + p1.ParameterName = "1"; + p1.Value = 1; + cmd.Parameters.Add(p1); + + var count = cmd.ExecuteNonQuery(); + Assert.AreEqual(1, count); + } + + conn.Close(); + } + } + + [Test] + public void 
testExplicitDbTypeAssignmentForArrayValue() + { + + using (IDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + CreateOrReplaceTable(conn, TableName, new[] + { + "cola INTEGER", + }); + + using (IDbCommand cmd = conn.CreateCommand()) + { + string insertCommand = $"insert into {TableName} values (?)"; + cmd.CommandText = insertCommand; + + var p1 = cmd.CreateParameter(); + p1.ParameterName = "1"; + p1.Value = new int[] { 1, 2, 3 }; + cmd.Parameters.Add(p1); + + var count = cmd.ExecuteNonQuery(); + Assert.AreEqual(3, count); + } + + conn.Close(); + } + } } } diff --git a/Snowflake.Data.Tests/UnitTests/SFDbParameterTest.cs b/Snowflake.Data.Tests/UnitTests/SFDbParameterTest.cs index 7674d096a..f9b83cdf1 100644 --- a/Snowflake.Data.Tests/UnitTests/SFDbParameterTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFDbParameterTest.cs @@ -7,7 +7,9 @@ namespace Snowflake.Data.Tests using NUnit.Framework; using Snowflake.Data.Client; using Snowflake.Data.Core; + using System; using System.Data; + using System.Text; [TestFixture] class SFDbParameterTest @@ -125,5 +127,93 @@ public void TestDbParameterResetDbType([Values] SFDataType expectedSFDataType) _parameter.ResetDbType(); Assert.AreEqual(SFDataType.None, _parameter.SFDataType); } + + [Test] + public void TestDbTypeExplicitAssignment([Values] DbType expectedDbType) + { + _parameter = new SnowflakeDbParameter(); + + switch (expectedDbType) + { + case DbType.SByte: + _parameter.Value = new sbyte(); + break; + case DbType.Byte: + _parameter.Value = new byte(); + break; + case DbType.Int16: + _parameter.Value = new short(); + break; + case DbType.Int32: + _parameter.Value = new int(); + break; + case DbType.Int64: + _parameter.Value = new long(); + break; + case DbType.UInt16: + _parameter.Value = new ushort(); + break; + case DbType.UInt32: + _parameter.Value = new uint(); + break; + case DbType.UInt64: + _parameter.Value = new ulong(); + break; + case DbType.Decimal: 
+ _parameter.Value = new decimal(); + break; + case DbType.Boolean: + _parameter.Value = true; + break; + case DbType.Single: + _parameter.Value = new float(); + break; + case DbType.Double: + _parameter.Value = new double(); + break; + case DbType.Guid: + _parameter.Value = new Guid(); + break; + case DbType.String: + _parameter.Value = "thisIsAString"; + break; + case DbType.DateTime: + _parameter.Value = DateTime.Now; + break; + case DbType.DateTimeOffset: + _parameter.Value = DateTimeOffset.Now; + break; + case DbType.Binary: + _parameter.Value = Encoding.UTF8.GetBytes("BinaryData"); + break; + case DbType.Object: + _parameter.Value = new object(); + break; + default: + // Not supported + expectedDbType = default(DbType); + break; + } + + Assert.AreEqual(expectedDbType, _parameter.DbType); + } + + [Test] + public void TestDbTypeExplicitAssignmentWithNullValueAndDefaultDbType() + { + _parameter = new SnowflakeDbParameter(); + _parameter.Value = null; + Assert.AreEqual(default(DbType), _parameter.DbType); + } + + [Test] + public void TestDbTypeExplicitAssignmentWithNullValueAndNonDefaultDbType() + { + var nonDefaultDbType = DbType.String; + _parameter = new SnowflakeDbParameter(); + _parameter.Value = null; + _parameter.DbType = nonDefaultDbType; + Assert.AreEqual(nonDefaultDbType, _parameter.DbType); + } } } diff --git a/Snowflake.Data/Client/SnowflakeDbParameter.cs b/Snowflake.Data/Client/SnowflakeDbParameter.cs index 4cfc02449..c03e785ec 100755 --- a/Snowflake.Data/Client/SnowflakeDbParameter.cs +++ b/Snowflake.Data/Client/SnowflakeDbParameter.cs @@ -15,6 +15,8 @@ public class SnowflakeDbParameter : DbParameter private SFDataType OriginType; + private DbType _dbType; + public SnowflakeDbParameter() { SFDataType = SFDataType.None; @@ -34,7 +36,28 @@ public SnowflakeDbParameter(int ParameterIndex, SFDataType SFDataType) this.SFDataType = SFDataType; } - public override DbType DbType { get; set; } + public override DbType DbType + { + get + { + if (_dbType != 
default(DbType) || Value == null || Value is DBNull) + { + return _dbType; + } + + var type = Value.GetType(); + if (type.IsArray && type != typeof(byte[])) + { + return SFDataConverter.TypeToDbTypeMap[type.GetElementType()]; + } + else + { + return SFDataConverter.TypeToDbTypeMap[type]; + } + } + + set => _dbType = value; + } public override ParameterDirection Direction { diff --git a/Snowflake.Data/Core/SFDataConverter.cs b/Snowflake.Data/Core/SFDataConverter.cs index 1159016af..2e380f73d 100755 --- a/Snowflake.Data/Core/SFDataConverter.cs +++ b/Snowflake.Data/Core/SFDataConverter.cs @@ -3,6 +3,7 @@ */ using System; +using System.Collections.Generic; using System.Data; using System.Globalization; using System.Text; @@ -20,6 +21,30 @@ static class SFDataConverter { internal static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc); + internal static readonly Dictionary TypeToDbTypeMap = new Dictionary() + { + [typeof(byte)] = DbType.Byte, + [typeof(sbyte)] = DbType.SByte, + [typeof(short)] = DbType.Int16, + [typeof(ushort)] = DbType.UInt16, + [typeof(int)] = DbType.Int32, + [typeof(uint)] = DbType.UInt32, + [typeof(long)] = DbType.Int64, + [typeof(ulong)] = DbType.UInt64, + [typeof(float)] = DbType.Single, + [typeof(double)] = DbType.Double, + [typeof(decimal)] = DbType.Decimal, + [typeof(bool)] = DbType.Boolean, + [typeof(string)] = DbType.String, + [typeof(char)] = DbType.StringFixedLength, + [typeof(Guid)] = DbType.Guid, + [typeof(DateTime)] = DbType.DateTime, + [typeof(DateTimeOffset)] = DbType.DateTimeOffset, + [typeof(TimeSpan)] = DbType.Time, + [typeof(byte[])] = DbType.Binary, + [typeof(object)] = DbType.Object + }; + internal static object ConvertToCSharpVal(UTF8Buffer srcVal, SFDataType srcType, Type destType) { if (srcVal == null) From a302e0f72e3e5c3837264d7372783ff7a0a7357b Mon Sep 17 00:00:00 2001 From: Krzysztof Nozderko Date: Wed, 3 Apr 2024 12:56:59 +0200 Subject: [PATCH 05/12] Fix coding conventions PR template 
(#905) ### Description Fix coding conventions PR template ### Checklist - [x] Code compiles correctly - [x] Code is formatted according to [Coding Conventions](../CodingConventions.md) - [x] Created tests which fail without the change (if possible) - [x] All tests passing (`dotnet test`) - [x] Extended the README / documentation, if necessary - [x] Provide JIRA issue id (if possible) or GitHub issue id in PR name --- .github/PULL_REQUEST_TEMPLATE.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index d84de2b17..eb38611c3 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -3,8 +3,8 @@ Please explain the changes you made here. ### Checklist - [ ] Code compiles correctly -- [ ] Code is formatted according to [Coding Conventions](../CodingConventions.md) +- [ ] Code is formatted according to [Coding Conventions](../blob/master/CodingConventions.md) - [ ] Created tests which fail without the change (if possible) - [ ] All tests passing (`dotnet test`) - [ ] Extended the README / documentation, if necessary -- [ ] Provide JIRA issue id (if possible) or GitHub issue id in PR name \ No newline at end of file +- [ ] Provide JIRA issue id (if possible) or GitHub issue id in PR name From dfba44e488735237a744aaadc8969e7e6aa5cdd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Hofman?= Date: Thu, 4 Apr 2024 13:38:50 +0200 Subject: [PATCH 06/12] SNOW-1271212 Fixed values uploaded to stage for bindings exceeding CLIENT_STAGE_ARRAY_BINDING_THRESHOLD (#897) ### Description When number of binded values during query execution exceeds the threshold of a session parameter CLIENT_STAGE_ARRAY_BINDING_THRESHOLD then values are written as a CSV file to a stage and it get's picked during query execution. 
Improper values (or values truncating nanos) has been uploaded prior to this fix for date and time related columns of type: DATE, TIME, TIMESTAMP_LTZ, TIMESTAMP_NTZ, TIMESTAMP_TZ. ### Checklist - [x] Code compiles correctly - [x] Code is formatted according to [Coding Conventions](../CodingConventions.md) - [x] Created tests which fail without the change (if possible) - [x] All tests passing (`dotnet test`) - [x] Extended the README / documentation, if necessary - [x] Provide JIRA issue id (if possible) or GitHub issue id in PR name --- .../IntegrationTests/SFBindTestIT.cs | 266 ++++++++++++++++-- Snowflake.Data.Tests/SFBaseTest.cs | 11 +- .../UnitTests/SFBindUploaderTest.cs | 107 +++++++ .../Util/DbCommandExtensions.cs | 18 ++ .../Util/DbConnectionExtensions.cs | 23 ++ .../Util/TableTypeExtensions.cs | 30 ++ Snowflake.Data/Client/SnowflakeDbCommand.cs | 2 + Snowflake.Data/Core/SFBindUploader.cs | 42 ++- Snowflake.Data/Core/SFStatement.cs | 2 + 9 files changed, 454 insertions(+), 47 deletions(-) create mode 100644 Snowflake.Data.Tests/UnitTests/SFBindUploaderTest.cs create mode 100644 Snowflake.Data.Tests/Util/DbCommandExtensions.cs create mode 100644 Snowflake.Data.Tests/Util/DbConnectionExtensions.cs create mode 100644 Snowflake.Data.Tests/Util/TableTypeExtensions.cs diff --git a/Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs index e222e5892..00a1857a2 100755 --- a/Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFBindTestIT.cs @@ -1,25 +1,31 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
*/ using System; using System.Data; +using System.Linq; +using Microsoft.IdentityModel.Tokens; using Newtonsoft.Json; +using Snowflake.Data.Log; +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core; +using System.Text; +using System.Globalization; +using System.Collections.Generic; +using Snowflake.Data.Tests.Util; namespace Snowflake.Data.Tests.IntegrationTests { - using NUnit.Framework; - using Snowflake.Data.Client; - using Snowflake.Data.Core; - using System.Text; - using System.Globalization; - using System.Collections.Generic; [TestFixture] class SFBindTestIT : SFBaseTest { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + [Test] - public void testArrayBind() + public void TestArrayBind() { using (IDbConnection conn = new SnowflakeDbConnection()) @@ -59,7 +65,7 @@ public void testArrayBind() } [Test] - public void testBindNullValue() + public void TestBindNullValue() { using (SnowflakeDbConnection dbConnection = new SnowflakeDbConnection()) { @@ -196,7 +202,7 @@ public void testBindNullValue() } [Test] - public void testBindValue() + public void TestBindValue() { using (SnowflakeDbConnection dbConnection = new SnowflakeDbConnection()) { @@ -313,7 +319,7 @@ public void testBindValue() command.CommandText = $"insert into {TableName}(stringData) values(:p0)"; param.Value = DBNull.Value; command.Parameters.Add(param); - int rowsInserted = command.ExecuteNonQuery(); + command.ExecuteNonQuery(); } catch (SnowflakeDbException e) { @@ -347,7 +353,7 @@ public void testBindValue() } [Test] - public void testBindValueWithSFDataType() + public void TestBindValueWithSFDataType() { using (SnowflakeDbConnection dbConnection = new SnowflakeDbConnection()) { @@ -440,7 +446,7 @@ public void testBindValueWithSFDataType() command.CommandText = $"insert into {TableName}(unsupportedType) values(:p0)"; param.Value = DBNull.Value; command.Parameters.Add(param); - int rowsInserted = command.ExecuteNonQuery(); + 
command.ExecuteNonQuery(); } catch (SnowflakeDbException e) { @@ -468,7 +474,7 @@ public void testBindValueWithSFDataType() } [Test] - public void testParameterCollection() + public void TestParameterCollection() { using (IDbConnection conn = new SnowflakeDbConnection()) { @@ -524,7 +530,7 @@ public void testParameterCollection() } [Test] - public void testPutArrayBind() + public void TestPutArrayBind() { using (IDbConnection conn = new SnowflakeDbConnection()) { @@ -646,10 +652,6 @@ public void testPutArrayBind() cmd.CommandText = $"SELECT * FROM {TableName}"; IDataReader reader = cmd.ExecuteReader(); Assert.IsTrue(reader.Read()); - - //cmd.CommandText = "drop table if exists testPutArrayBind"; - //cmd.ExecuteNonQuery(); - } conn.Close(); @@ -657,7 +659,7 @@ public void testPutArrayBind() } [Test] - public void testPutArrayBindWorkDespiteOtTypeNameHandlingAuto() + public void TestPutArrayBindWorkDespiteOtTypeNameHandlingAuto() { JsonConvert.DefaultSettings = () => new JsonSerializerSettings { TypeNameHandling = TypeNameHandling.Auto @@ -729,7 +731,7 @@ public void testPutArrayBindWorkDespiteOtTypeNameHandlingAuto() } [Test] - public void testPutArrayBind1() + public void TestPutArrayIntegerBind() { using (IDbConnection conn = new SnowflakeDbConnection()) { @@ -771,7 +773,7 @@ public void testPutArrayBind1() } [Test] - public void testExplicitDbTypeAssignmentForSimpleValue() + public void TestExplicitDbTypeAssignmentForSimpleValue() { using (IDbConnection conn = new SnowflakeDbConnection()) @@ -803,7 +805,7 @@ public void testExplicitDbTypeAssignmentForSimpleValue() } [Test] - public void testExplicitDbTypeAssignmentForArrayValue() + public void TestExplicitDbTypeAssignmentForArrayValue() { using (IDbConnection conn = new SnowflakeDbConnection()) @@ -833,5 +835,223 @@ public void testExplicitDbTypeAssignmentForArrayValue() conn.Close(); } } + + private const string FormatYmd = "yyyy/MM/dd"; + private const string FormatHms = "HH\\:mm\\:ss"; + private const string 
FormatHmsf = "HH\\:mm\\:ss\\.fff"; + private const string FormatYmdHms = "yyyy/MM/dd HH\\:mm\\:ss"; + private const string FormatYmdHmsZ = "yyyy/MM/dd HH\\:mm\\:ss zzz"; + + // STANDARD Tables + [TestCase(ResultFormat.JSON, SFTableType.Standard, SFDataType.DATE, null, DbType.Date, FormatYmd, null)] + [TestCase(ResultFormat.JSON, SFTableType.Standard, SFDataType.TIME, null, DbType.Time, FormatHms, null)] + [TestCase(ResultFormat.JSON, SFTableType.Standard, SFDataType.TIME, 6, DbType.Time, FormatHmsf, null)] + [TestCase(ResultFormat.JSON, SFTableType.Standard, SFDataType.TIMESTAMP_NTZ, 6, DbType.DateTime, FormatYmdHms, null)] + [TestCase(ResultFormat.JSON, SFTableType.Standard, SFDataType.TIMESTAMP_TZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, null)] + [TestCase(ResultFormat.JSON, SFTableType.Standard, SFDataType.TIMESTAMP_LTZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, null)] + [TestCase(ResultFormat.ARROW, SFTableType.Standard, SFDataType.DATE, null, DbType.Date, FormatYmd, null)] + [TestCase(ResultFormat.ARROW, SFTableType.Standard, SFDataType.TIME, null, DbType.Time, FormatHms, null)] + [TestCase(ResultFormat.ARROW, SFTableType.Standard, SFDataType.TIME, 6, DbType.Time, FormatHmsf, null)] + [TestCase(ResultFormat.ARROW, SFTableType.Standard, SFDataType.TIMESTAMP_NTZ, 6, DbType.DateTime, FormatYmdHms, null)] + [TestCase(ResultFormat.ARROW, SFTableType.Standard, SFDataType.TIMESTAMP_TZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, null)] + [TestCase(ResultFormat.ARROW, SFTableType.Standard, SFDataType.TIMESTAMP_LTZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, null)] + /* TODO: Enable when features available on the automated tests environment + // HYBRID Tables + [TestCase(ResultFormat.JSON, SFTableType.Hybrid, SFDataType.DATE, null, DbType.Date, FormatYmd, null)] + [TestCase(ResultFormat.JSON, SFTableType.Hybrid, SFDataType.TIME, null, DbType.Time, FormatHms, null)] + [TestCase(ResultFormat.JSON, SFTableType.Hybrid, SFDataType.TIME, 6, DbType.Time, FormatHmsf, null)] + 
[TestCase(ResultFormat.JSON, SFTableType.Hybrid, SFDataType.TIMESTAMP_NTZ, 6, DbType.DateTime, FormatYmdHms, null)] + [TestCase(ResultFormat.JSON, SFTableType.Hybrid, SFDataType.TIMESTAMP_TZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, null)] + [TestCase(ResultFormat.JSON, SFTableType.Hybrid, SFDataType.TIMESTAMP_LTZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, null)] + [TestCase(ResultFormat.ARROW, SFTableType.Hybrid, SFDataType.DATE, null, DbType.Date, FormatYmd, null)] + [TestCase(ResultFormat.ARROW, SFTableType.Hybrid, SFDataType.TIME, null, DbType.Time, FormatHms, null)] + [TestCase(ResultFormat.ARROW, SFTableType.Hybrid, SFDataType.TIME, 6, DbType.Time, FormatHmsf, null)] + [TestCase(ResultFormat.ARROW, SFTableType.Hybrid, SFDataType.TIMESTAMP_NTZ, 6, DbType.DateTime, FormatYmdHms, null)] + [TestCase(ResultFormat.ARROW, SFTableType.Hybrid, SFDataType.TIMESTAMP_TZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, null)] + [TestCase(ResultFormat.ARROW, SFTableType.Hybrid, SFDataType.TIMESTAMP_LTZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, null)] + // ICEBERG Tables; require env variables: ICEBERG_EXTERNAL_VOLUME, ICEBERG_CATALOG, ICEBERG_BASE_LOCATION. 
+ [TestCase(ResultFormat.JSON, SFTableType.Iceberg, SFDataType.DATE, null, DbType.Date, FormatYmd, null)] + [TestCase(ResultFormat.JSON, SFTableType.Iceberg, SFDataType.TIME, null, DbType.Time, FormatHms, null)] + [TestCase(ResultFormat.JSON, SFTableType.Iceberg, SFDataType.TIME, 6, DbType.Time, FormatHmsf, null)] + [TestCase(ResultFormat.JSON, SFTableType.Iceberg, SFDataType.TIMESTAMP_NTZ, 6, DbType.DateTime, FormatYmdHms, null)] + // [TestCase(ResultFormat.JSON, SFTableType.Iceberg, SFDataType.TIMESTAMP_TZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, null)] // Unsupported data type 'TIMESTAMP_TZ(6)' for iceberg tables + [TestCase(ResultFormat.JSON, SFTableType.Iceberg, SFDataType.TIMESTAMP_LTZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, null)] + [TestCase(ResultFormat.ARROW, SFTableType.Iceberg, SFDataType.DATE, null, DbType.Date, FormatYmd, null)] + [TestCase(ResultFormat.ARROW, SFTableType.Iceberg, SFDataType.TIME, null, DbType.Time, FormatHms, null)] + [TestCase(ResultFormat.ARROW, SFTableType.Iceberg, SFDataType.TIME, 6, DbType.Time, FormatHmsf, null)] + [TestCase(ResultFormat.ARROW, SFTableType.Iceberg, SFDataType.TIMESTAMP_NTZ, 6, DbType.DateTime, FormatYmdHms, null)] + // [TestCase(ResultFormat.ARROW, SFTableType.Iceberg, SFDataType.TIMESTAMP_TZ, 6, DbType.DateTime, FormatYmdHmsZ, null)] // Unsupported data type 'TIMESTAMP_TZ(6)' for iceberg tables + [TestCase(ResultFormat.ARROW, SFTableType.Iceberg, SFDataType.TIMESTAMP_LTZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, null)] + */ + // Session TimeZone cases + [TestCase(ResultFormat.ARROW, SFTableType.Standard, SFDataType.TIMESTAMP_LTZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, "Europe/Warsaw")] + [TestCase(ResultFormat.JSON, SFTableType.Standard, SFDataType.TIMESTAMP_LTZ, 6, DbType.DateTimeOffset, FormatYmdHmsZ, "Asia/Tokyo")] + public void TestDateTimeBinding(ResultFormat resultFormat, SFTableType tableType, SFDataType columnType, Int32? 
columnPrecision, DbType bindingType, string comparisonFormat, string timeZone) + { + // Arrange + var timestamp = "2023/03/15 13:17:29.207 +05:00"; // 08:17:29.207 UTC + var expected = ExpectedTimestampWrapper.From(timestamp, columnType); + var columnWithPrecision = ColumnTypeWithPrecision(columnType, columnPrecision); + var testCase = $"ResultFormat={resultFormat}, TableType={tableType}, ColumnType={columnWithPrecision}, BindingType={bindingType}, ComparisonFormat={comparisonFormat}"; + var bindingThreshold = 65280; // when exceeded enforces bindings via file on stage + var smallBatchRowCount = 2; + var bigBatchRowCount = bindingThreshold / 2; + s_logger.Info(testCase); + + using (IDbConnection conn = new SnowflakeDbConnection(ConnectionString)) + { + conn.Open(); + + conn.ExecuteNonQuery($"alter session set DOTNET_QUERY_RESULT_FORMAT = {resultFormat}"); + if (!timeZone.IsNullOrEmpty()) // Driver ignores this setting and relies on local environment timezone + conn.ExecuteNonQuery($"alter session set TIMEZONE = '{timeZone}'"); + + CreateOrReplaceTable(conn, + TableName, + tableType.TableDDLCreationPrefix(), + new[] { + "id number(10,0) not null primary key", // necessary only for HYBRID tables + $"ts {columnWithPrecision}" + }, + tableType.TableDDLCreationFlags()); + + // Act+Assert + var sqlInsert = $"insert into {TableName} (id, ts) values (?, ?)"; + InsertSingleRecord(conn, sqlInsert, bindingType, 1, expected); + InsertMultipleRecords(conn, sqlInsert, bindingType, 2, expected, smallBatchRowCount, false); + InsertMultipleRecords(conn, sqlInsert, bindingType, smallBatchRowCount+2, expected, bigBatchRowCount, true); + + // Assert + var row = 0; + using (var select = conn.CreateCommand($"select id, ts from {TableName} order by id")) + { + s_logger.Debug(select.CommandText); + var reader = select.ExecuteReader(); + while (reader.Read()) + { + ++row; + string faultMessage = $"Mismatch for row: {row}, {testCase}"; + Assert.AreEqual(row, reader.GetInt32(0)); + 
expected.AssertEqual(reader.GetValue(1), comparisonFormat, faultMessage); + } + } + Assert.AreEqual(1+smallBatchRowCount+bigBatchRowCount, row); + } + } + + private void InsertSingleRecord(IDbConnection conn, string sqlInsert, DbType binding, int identifier, ExpectedTimestampWrapper ts) + { + using (var insert = conn.CreateCommand(sqlInsert)) + { + // Arrange + insert.Add("1", DbType.Int32, identifier); + if (ExpectedTimestampWrapper.IsOffsetType(ts.ExpectedColumnType())) + { + var parameter = (SnowflakeDbParameter)insert.Add("2", binding, ts.GetDateTimeOffset()); + parameter.SFDataType = ts.ExpectedColumnType(); + } + else + { + insert.Add("2", binding, ts.GetDateTime()); + } + + // Act + s_logger.Info(sqlInsert); + var rowsAffected = insert.ExecuteNonQuery(); + + // Assert + Assert.AreEqual(1, rowsAffected); + Assert.IsNull(((SnowflakeDbCommand)insert).GetBindStage()); + } + } + + private void InsertMultipleRecords(IDbConnection conn, string sqlInsert, DbType binding, int initialIdentifier, ExpectedTimestampWrapper ts, int rowsCount, bool shouldUseBinding) + { + using (var insert = conn.CreateCommand(sqlInsert)) + { + // Arrange + insert.Add("1", DbType.Int32, Enumerable.Range(initialIdentifier, rowsCount).ToArray()); + if (ExpectedTimestampWrapper.IsOffsetType(ts.ExpectedColumnType())) + { + var parameter = (SnowflakeDbParameter)insert.Add("2", binding, Enumerable.Repeat(ts.GetDateTimeOffset(), rowsCount).ToArray()); + parameter.SFDataType = ts.ExpectedColumnType(); + } + else + { + insert.Add("2", binding, Enumerable.Repeat(ts.GetDateTime(), rowsCount).ToArray()); + } + + // Act + s_logger.Debug(sqlInsert); + var rowsAffected = insert.ExecuteNonQuery(); + + // Assert + Assert.AreEqual(rowsCount, rowsAffected); + if (shouldUseBinding) + Assert.IsNotEmpty(((SnowflakeDbCommand)insert).GetBindStage()); + else + Assert.IsNull(((SnowflakeDbCommand)insert).GetBindStage()); + } + } + + private static string ColumnTypeWithPrecision(SFDataType columnType, Int32? 
columnPrecision) + => columnPrecision != null ? $"{columnType}({columnPrecision})" : $"{columnType}"; + } + + class ExpectedTimestampWrapper + { + private readonly SFDataType _columnType; + private readonly DateTime? _expectedDateTime; + private readonly DateTimeOffset? _expectedDateTimeOffset; + + internal static ExpectedTimestampWrapper From(string timestampWithTimeZone, SFDataType columnType) + { + if (IsOffsetType(columnType)) + { + var dateTimeOffset = DateTimeOffset.ParseExact(timestampWithTimeZone, "yyyy/MM/dd HH:mm:ss.fff zzz", CultureInfo.InvariantCulture); + return new ExpectedTimestampWrapper(dateTimeOffset, columnType); + } + + var dateTime = DateTime.ParseExact(timestampWithTimeZone, "yyyy/MM/dd HH:mm:ss.fff zzz", CultureInfo.InvariantCulture); + return new ExpectedTimestampWrapper(dateTime, columnType); + } + + private ExpectedTimestampWrapper(DateTime dateTime, SFDataType columnType) + { + _expectedDateTime = dateTime; + _expectedDateTimeOffset = null; + _columnType = columnType; + } + + private ExpectedTimestampWrapper(DateTimeOffset dateTimeOffset, SFDataType columnType) + { + _expectedDateTimeOffset = dateTimeOffset; + _expectedDateTime = null; + _columnType = columnType; + } + + internal SFDataType ExpectedColumnType() => _columnType; + + internal void AssertEqual(object actual, string comparisonFormat, string faultMessage) + { + switch (_columnType) + { + case SFDataType.TIMESTAMP_TZ: + Assert.AreEqual(GetDateTimeOffset().ToString(comparisonFormat), ((DateTimeOffset)actual).ToString(comparisonFormat), faultMessage); + break; + case SFDataType.TIMESTAMP_LTZ: + Assert.AreEqual(GetDateTimeOffset().ToUniversalTime().ToString(comparisonFormat), ((DateTimeOffset)actual).ToUniversalTime().ToString(comparisonFormat), faultMessage); + break; + default: + Assert.AreEqual(GetDateTime().ToString(comparisonFormat), ((DateTime)actual).ToString(comparisonFormat), faultMessage); + break; + } + } + + internal DateTime GetDateTime() => _expectedDateTime ?? 
throw new Exception($"Column {_columnType} is not matching the expected value type {typeof(DateTime)}"); + + internal DateTimeOffset GetDateTimeOffset() => _expectedDateTimeOffset ?? throw new Exception($"Column {_columnType} is not matching the expected value type {typeof(DateTime)}"); + + internal static bool IsOffsetType(SFDataType type) => type == SFDataType.TIMESTAMP_LTZ || type == SFDataType.TIMESTAMP_TZ; } } diff --git a/Snowflake.Data.Tests/SFBaseTest.cs b/Snowflake.Data.Tests/SFBaseTest.cs index 01ae94501..0bb2e1555 100755 --- a/Snowflake.Data.Tests/SFBaseTest.cs +++ b/Snowflake.Data.Tests/SFBaseTest.cs @@ -12,6 +12,7 @@ using System.Runtime.InteropServices; using NUnit.Framework; using Snowflake.Data.Client; +using Snowflake.Data.Log; using Snowflake.Data.Tests.Util; [assembly:LevelOfParallelism(10)] @@ -56,6 +57,8 @@ public static void TearDownContext() #endif public class SFBaseTestAsync { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + private const string ConnectionStringWithoutAuthFmt = "scheme={0};host={1};port={2};" + "account={3};role={4};db={5};schema={6};warehouse={7}"; private const string ConnectionStringSnowflakeAuthFmt = ";user={0};password={1};"; @@ -106,10 +109,16 @@ private void RemoveTables() } protected void CreateOrReplaceTable(IDbConnection conn, string tableName, IEnumerable columns, string additionalQueryStr = null) + { + CreateOrReplaceTable(conn, tableName, "", columns, additionalQueryStr); + } + + protected void CreateOrReplaceTable(IDbConnection conn, string tableName, string tableType, IEnumerable columns, string additionalQueryStr = null) { var columnsStr = string.Join(", ", columns); var cmd = conn.CreateCommand(); - cmd.CommandText = $"CREATE OR REPLACE TABLE {tableName}({columnsStr}) {additionalQueryStr}"; + cmd.CommandText = $"CREATE OR REPLACE {tableType} TABLE {tableName}({columnsStr}) {additionalQueryStr}"; + s_logger.Debug(cmd.CommandText); cmd.ExecuteNonQuery(); 
_tablesToRemove.Add(tableName); diff --git a/Snowflake.Data.Tests/UnitTests/SFBindUploaderTest.cs b/Snowflake.Data.Tests/UnitTests/SFBindUploaderTest.cs new file mode 100644 index 000000000..ac5172086 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/SFBindUploaderTest.cs @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using NUnit.Framework; +using Snowflake.Data.Core; + +namespace Snowflake.Data.Tests.UnitTests +{ + [TestFixture] + class SFBindUploaderTest + { + private readonly SFBindUploader _bindUploader = new SFBindUploader(null, "test"); + + [TestCase(SFDataType.DATE, "0", "1/1/1970")] + [TestCase(SFDataType.DATE, "73785600000", "5/4/1972")] + [TestCase(SFDataType.DATE, "1709164800000", "2/29/2024")] + public void TestCsvDataConversionForDate(SFDataType dbType, string input, string expected) + { + // Arrange + var dateExpected = DateTime.Parse(expected); + var check = SFDataConverter.csharpValToSfVal(SFDataType.DATE, dateExpected); + Assert.AreEqual(check, input); + // Act + DateTime dateActual = DateTime.Parse(_bindUploader.GetCSVData(dbType.ToString(), input)); + // Assert + Assert.AreEqual(dateExpected, dateActual); + } + + [TestCase(SFDataType.TIME, "0", "00:00:00.000000")] + [TestCase(SFDataType.TIME, "100000000", "00:00:00.100000")] + [TestCase(SFDataType.TIME, "1000000000", "00:00:01.000000")] + [TestCase(SFDataType.TIME, "60123456000", "00:01:00.123456")] + [TestCase(SFDataType.TIME, "46801000000000", "13:00:01.000000")] + public void TestCsvDataConversionForTime(SFDataType dbType, string input, string expected) + { + // Arrange + DateTime timeExpected = DateTime.Parse(expected); + var check = SFDataConverter.csharpValToSfVal(SFDataType.TIME, timeExpected); + Assert.AreEqual(check, input); + // Act + DateTime timeActual = DateTime.Parse(_bindUploader.GetCSVData(dbType.ToString(), input)); + // Assert + Assert.AreEqual(timeExpected, timeActual); + } + + 
[TestCase(SFDataType.TIMESTAMP_LTZ, "39600000000000", "1970-01-01T12:00:00.0000000+01:00")] + [TestCase(SFDataType.TIMESTAMP_LTZ, "1341136800000000000", "2012-07-01T12:00:00.0000000+02:00")] + [TestCase(SFDataType.TIMESTAMP_LTZ, "352245599987654000", "1981-02-28T23:59:59.9876540+02:00")] + [TestCase(SFDataType.TIMESTAMP_LTZ, "1678868249207000000", "2023/03/15T13:17:29.207+05:00")] + public void TestCsvDataConversionForTimestampLtz(SFDataType dbType, string input, string expected) + { + // Arrange + var timestampExpected = DateTimeOffset.Parse(expected); + var check = SFDataConverter.csharpValToSfVal(SFDataType.TIMESTAMP_LTZ, timestampExpected); + Assert.AreEqual(check, input); + // Act + var timestampActual = DateTimeOffset.Parse(_bindUploader.GetCSVData(dbType.ToString(), input)); + // Assert + Assert.AreEqual(timestampExpected.ToLocalTime(), timestampActual); + } + + [TestCase(SFDataType.TIMESTAMP_TZ, "1341136800000000000 1560", "2012-07-01 12:00:00.000000 +02:00")] + [TestCase(SFDataType.TIMESTAMP_TZ, "352245599987654000 1560", "1981-02-28 23:59:59.987654 +02:00")] + public void TestCsvDataConversionForTimestampTz(SFDataType dbType, string input, string expected) + { + // Arrange + DateTimeOffset timestampExpected = DateTimeOffset.Parse(expected); + var check = SFDataConverter.csharpValToSfVal(SFDataType.TIMESTAMP_TZ, timestampExpected); + Assert.AreEqual(check, input); + // Act + DateTimeOffset timestampActual = DateTimeOffset.Parse(_bindUploader.GetCSVData(dbType.ToString(), input)); + // Assert + Assert.AreEqual(timestampExpected, timestampActual); + } + + [TestCase(SFDataType.TIMESTAMP_NTZ, "1341144000000000000", "2012-07-01 12:00:00.000000")] + [TestCase(SFDataType.TIMESTAMP_NTZ, "352252799987654000", "1981-02-28 23:59:59.987654")] + public void TestCsvDataConversionForTimestampNtz(SFDataType dbType, string input, string expected) + { + // Arrange + DateTime timestampExpected = DateTime.Parse(expected); + var check = 
SFDataConverter.csharpValToSfVal(SFDataType.TIMESTAMP_NTZ, timestampExpected); + Assert.AreEqual(check, input); + // Act + DateTime timestampActual = DateTime.Parse(_bindUploader.GetCSVData(dbType.ToString(), input)); + // Assert + Assert.AreEqual(timestampExpected, timestampActual); + } + + [TestCase(SFDataType.TEXT, "", "\"\"")] + [TestCase(SFDataType.TEXT, "\"", "\"\"\"\"")] + [TestCase(SFDataType.TEXT, "\n", "\"\n\"")] + [TestCase(SFDataType.TEXT, "\t", "\"\t\"")] + [TestCase(SFDataType.TEXT, ",", "\",\"")] + [TestCase(SFDataType.TEXT, "Sample text", "Sample text")] + public void TestCsvDataConversionForText(SFDataType dbType, string input, string expected) + { + // Act + var actual = _bindUploader.GetCSVData(dbType.ToString(), input); + // Assert + Assert.AreEqual(expected, actual); + } + + } +} diff --git a/Snowflake.Data.Tests/Util/DbCommandExtensions.cs b/Snowflake.Data.Tests/Util/DbCommandExtensions.cs new file mode 100644 index 000000000..fb336d5c3 --- /dev/null +++ b/Snowflake.Data.Tests/Util/DbCommandExtensions.cs @@ -0,0 +1,18 @@ +using System.Data; + +namespace Snowflake.Data.Tests.Util +{ + public static class DbCommandExtensions + { + internal static IDbDataParameter Add(this IDbCommand command, string name, DbType dbType, object value) + { + var parameter = command.CreateParameter(); + parameter.ParameterName = name; + parameter.DbType = dbType; + parameter.Value = value; + command.Parameters.Add(parameter); + return parameter; + } + + } +} diff --git a/Snowflake.Data.Tests/Util/DbConnectionExtensions.cs b/Snowflake.Data.Tests/Util/DbConnectionExtensions.cs new file mode 100644 index 000000000..02b7e47dd --- /dev/null +++ b/Snowflake.Data.Tests/Util/DbConnectionExtensions.cs @@ -0,0 +1,23 @@ +using System.Data; + +namespace Snowflake.Data.Tests.Util +{ + public static class DbConnectionExtensions + { + internal static IDbCommand CreateCommand(this IDbConnection connection, string commandText) + { + var command = connection.CreateCommand(); + 
command.Connection = connection; + command.CommandText = commandText; + return command; + } + + internal static int ExecuteNonQuery(this IDbConnection connection, string commandText) + { + var command = connection.CreateCommand(); + command.Connection = connection; + command.CommandText = commandText; + return command.ExecuteNonQuery(); + } + } +} diff --git a/Snowflake.Data.Tests/Util/TableTypeExtensions.cs b/Snowflake.Data.Tests/Util/TableTypeExtensions.cs new file mode 100644 index 000000000..4c00f3a1d --- /dev/null +++ b/Snowflake.Data.Tests/Util/TableTypeExtensions.cs @@ -0,0 +1,30 @@ +using System; +using NUnit.Framework; + +namespace Snowflake.Data.Tests.Util +{ + public enum SFTableType + { + Standard, + Hybrid, + Iceberg + } + + static class TableTypeExtensions + { + internal static string TableDDLCreationPrefix(this SFTableType val) => val == SFTableType.Standard ? "" : val.ToString().ToUpper(); + + internal static string TableDDLCreationFlags(this SFTableType val) + { + if (val != SFTableType.Iceberg) + return ""; + var externalVolume = Environment.GetEnvironmentVariable("ICEBERG_EXTERNAL_VOLUME"); + var catalog = Environment.GetEnvironmentVariable("ICEBERG_CATALOG"); + var baseLocation = Environment.GetEnvironmentVariable("ICEBERG_BASE_LOCATION"); + Assert.IsNotNull(externalVolume, "env ICEBERG_EXTERNAL_VOLUME not set!"); + Assert.IsNotNull(catalog, "env ICEBERG_CATALOG not set!"); + Assert.IsNotNull(baseLocation, "env ICEBERG_BASE_LOCATION not set!"); + return $"EXTERNAL_VOLUME = '{externalVolume}' CATALOG = '{catalog}' BASE_LOCATION = '{baseLocation}'"; + } + } +} diff --git a/Snowflake.Data/Client/SnowflakeDbCommand.cs b/Snowflake.Data/Client/SnowflakeDbCommand.cs index ce004df5c..36a04f151 100755 --- a/Snowflake.Data/Client/SnowflakeDbCommand.cs +++ b/Snowflake.Data/Client/SnowflakeDbCommand.cs @@ -460,5 +460,7 @@ private void CheckIfCommandTextIsSet() throw new Exception(errorMessage); } } + + internal string GetBindStage() => 
sfStatement?.GetBindStage(); } } diff --git a/Snowflake.Data/Core/SFBindUploader.cs b/Snowflake.Data/Core/SFBindUploader.cs index a1b3f161d..71dec60fb 100644 --- a/Snowflake.Data/Core/SFBindUploader.cs +++ b/Snowflake.Data/Core/SFBindUploader.cs @@ -224,13 +224,13 @@ internal async Task UploadStreamAsync(MemoryStream stream, string destFileName, statement.SetUploadStream(stream, destFileName, stagePath); await statement.ExecuteTransferAsync(putStmt, cancellationToken).ConfigureAwait(false); } - private string GetCSVData(string sType, string sValue) + + internal string GetCSVData(string sType, string sValue) { if (sValue == null) return sValue; - DateTime dateTime = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Unspecified); - DateTimeOffset dateTimeOffset = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Unspecified); + DateTime epoch = SFDataConverter.UnixEpoch; switch (sType) { case "TEXT": @@ -246,33 +246,29 @@ private string GetCSVData(string sType, string sValue) return '"' + sValue.Replace("\"", "\"\"") + '"'; return sValue; case "DATE": - long dateLong = long.Parse(sValue); - DateTime date = dateTime.AddMilliseconds(dateLong).ToUniversalTime(); + long msFromEpoch = long.Parse(sValue); // SFDateConverter.csharpValToSfVal provides in [ms] from Epoch + DateTime date = epoch.AddMilliseconds(msFromEpoch); return date.ToShortDateString(); case "TIME": - long timeLong = long.Parse(sValue); - DateTime time = dateTime.AddMilliseconds(timeLong).ToUniversalTime(); - return time.ToLongTimeString(); + long nsSinceMidnight = long.Parse(sValue); // SFDateConverter.csharpValToSfVal provides in [ns] from Midnight + DateTime time = epoch.AddTicks(nsSinceMidnight/100); + return time.ToString("HH:mm:ss.fffffff"); case "TIMESTAMP_LTZ": - long ltzLong = long.Parse(sValue); - TimeSpan ltzts = new TimeSpan(ltzLong / 100); - DateTime ltzdt = dateTime + ltzts; - return ltzdt.ToString(); + long nsFromEpochLtz = long.Parse(sValue); // SFDateConverter.csharpValToSfVal provides in [ns] 
from Epoch + DateTime ltz = epoch.AddTicks(nsFromEpochLtz/100); + return ltz.ToLocalTime().ToString("O"); // ISO 8601 format case "TIMESTAMP_NTZ": - long ntzLong = long.Parse(sValue); - TimeSpan ts = new TimeSpan(ntzLong/100); - DateTime dt = dateTime + ts; - return dt.ToString("yyyy-MM-dd HH:mm:ss.fffffff"); + long nsFromEpochNtz = long.Parse(sValue); // SFDateConverter.csharpValToSfVal provides in [ns] from Epoch + DateTime ntz = epoch.AddTicks(nsFromEpochNtz/100); + return ntz.ToString("yyyy-MM-dd HH:mm:ss.fffffff"); case "TIMESTAMP_TZ": string[] tstzString = sValue.Split(' '); - long tzLong = long.Parse(tstzString[0]); - int tzInt = (int.Parse(tstzString[1]) - 1440) / 60; - TimeSpan tzts = new TimeSpan(tzLong/100); - DateTime tzdt = dateTime + tzts; - TimeSpan tz = new TimeSpan(tzInt, 0, 0); - DateTimeOffset tzDateTimeOffset = new DateTimeOffset(tzdt, tz); + long nsFromEpochTz = long.Parse(tstzString[0]); // SFDateConverter provides in [ns] from Epoch + int timeZoneOffset = int.Parse(tstzString[1]) - 1440; // SFDateConverter provides in minutes increased by 1440m + DateTime timestamp = epoch.AddTicks(nsFromEpochTz/100).AddMinutes(timeZoneOffset); + TimeSpan offset = TimeSpan.FromMinutes(timeZoneOffset); + DateTimeOffset tzDateTimeOffset = new DateTimeOffset(timestamp.Ticks, offset); return tzDateTimeOffset.ToString("yyyy-MM-dd HH:mm:ss.fffffff zzz"); - } return sValue; } diff --git a/Snowflake.Data/Core/SFStatement.cs b/Snowflake.Data/Core/SFStatement.cs index 9252af40e..05e905263 100644 --- a/Snowflake.Data/Core/SFStatement.cs +++ b/Snowflake.Data/Core/SFStatement.cs @@ -147,6 +147,8 @@ internal SFStatement(SFSession session) _restRequester = session.restRequester; } + internal string GetBindStage() => _bindStage; + private void AssignQueryRequestId() { lock (_requestIdLock) From e09283551ee6c07e27141b2154ebc0b83de554a8 Mon Sep 17 00:00:00 2001 From: Juan Martinez Ramirez <126511805+sfc-gh-jmartinez@users.noreply.github.com> Date: Thu, 4 Apr 2024 09:04:53 
-0600 Subject: [PATCH 07/12] SNOW-977565: Added start and end symbol to match full regex to bypass proxy server. (#899) ### Description Added start and end symbol to match string with full regex to bypass proxy server. ### Checklist - [x] Code compiles correctly - [x] Code is formatted according to [Coding Conventions](../CodingConventions.md) - [x] Created tests which fail without the change (if possible) - [x] All tests passing (`dotnet test`) - [x] Extended the README / documentation, if necessary - [x] Provide JIRA issue id (if possible) or GitHub issue id in PR name --- README.md | 18 +++++- .../IntegrationTests/SFConnectionIT.cs | 57 ++++++++++++++++++- Snowflake.Data.Tests/SFBaseTest.cs | 5 ++ Snowflake.Data/Core/HttpUtil.cs | 7 ++- 4 files changed, 83 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index cc7aba2ef..7c7e65356 100644 --- a/README.md +++ b/README.md @@ -163,7 +163,7 @@ The following table lists all valid connection properties: | PROXYPORT | Depends | The port number of the proxy server.

If USEPROXY is set to `true`, you must set this parameter.

This parameter was introduced in v2.0.4. | | PROXYUSER | No | The username for authenticating to the proxy server.

This parameter was introduced in v2.0.4. | | PROXYPASSWORD | Depends | The password for authenticating to the proxy server.

If USEPROXY is `true` and PROXYUSER is set, you must set this parameter.

This parameter was introduced in v2.0.4. | -| NONPROXYHOSTS | No | The list of hosts that the driver should connect to directly, bypassing the proxy server. Separate the hostnames with a pipe symbol (\|). You can also use an asterisk (`*`) as a wildcard.

This parameter was introduced in v2.0.4. | +| NONPROXYHOSTS | No | The list of hosts that the driver should connect to directly, bypassing the proxy server. Separate the hostnames with a pipe symbol (\|). You can also use an asterisk (`*`) as a wildcard.
The host target value should fully match with any item from the proxy host list to bypass the proxy server.

This parameter was introduced in v2.0.4. | | FILE_TRANSFER_MEMORY_THRESHOLD | No | The maximum number of bytes to store in memory used in order to provide a file encryption. If encrypting/decrypting file size exceeds provided value a temporary file will be created and the work will be continued in the temporary file instead of memory.
If no value provided 1MB will be used as a default value (that is 1048576 bytes).
It is possible to configure any integer value bigger than zero representing maximal number of bytes to reside in memory. | | CLIENT_CONFIG_FILE | No | The location of the client configuration json file. In this file you can configure easy logging feature. | | ALLOWUNDERSCORESINHOST | No | Specifies whether to allow underscores in account names. This impacts PrivateLink customers whose account names contain underscores. In this situation, you must override the default value by setting allowUnderscoresInHost to true. | @@ -377,6 +377,22 @@ using (IDbConnection conn = new SnowflakeDbConnection()) } ``` +The NONPROXYHOSTS property could be set to specify if the server proxy should be bypassed by an specified host. This should be defined using the full host url or including the url + `*` wilcard symbol. + +Examples: + +- `*` (Bypassed all hosts from the proxy server) +- `*.snowflakecomputing.com` ('Bypass all host that ends with `snowflakecomputing.com`') +- `https:\\testaccount.snowflakecomputing.com` (Bypass proxy server using full host url). +- `*.myserver.com | *testaccount*` (You can specify multiple regex for the property divided by `|`) + + +> Note: The nonproxyhost value should match the full url including the http or https section. The '*' wilcard could be added to bypass the hostname successfully. + +- `myaccount.snowflakecomputing.com` (Not bypassed). +- `*myaccount.snowflakecomputing.com` (Bypassed). + + ## Using Connection Pools Instead of creating a connection each time your client application needs to access Snowflake, you can define a cache of Snowflake connections that can be reused as needed. Connection pooling usually reduces the lag time to make a connection. However, it can slow down client failover to an alternative DNS when a DNS problem occurs. 
diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index 8d69fe606..c248ef575 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -1540,6 +1540,60 @@ public void TestInvalidProxySettingFromConnectionString() } } + [Test] + [TestCase("*")] + [TestCase("*{0}*")] + [TestCase("^*{0}*")] + [TestCase("*{0}*$")] + [TestCase("^*{0}*$")] + [TestCase("^nonmatch*{0}$|*")] + [TestCase("*a*", "a")] + [TestCase("*la*", "la")] + public void TestNonProxyHostShouldBypassProxyServer(string regexHost, string proxyHost = "proxyserverhost") + { + using (var conn = new SnowflakeDbConnection()) + { + // Arrange + var host = ResolveHost(); + var nonProxyHosts = string.Format(regexHost, $"{host}"); + conn.ConnectionString = + $"{ConnectionString}USEPROXY=true;PROXYHOST={proxyHost};NONPROXYHOSTS={nonProxyHosts};PROXYPORT=3128;"; + + // Act + conn.Open(); + + // Assert + // The connection would fail to open if the web proxy would be used because the proxy is configured to a non-existent host. + Assert.AreEqual(ConnectionState.Open, conn.State); + } + } + + [Test] + [TestCase("invalid{0}")] + [TestCase("*invalid{0}*")] + [TestCase("^invalid{0}$")] + [TestCase("*a.b")] + [TestCase("a", "a")] + [TestCase("la", "la")] + public void TestNonProxyHostShouldNotBypassProxyServer(string regexHost, string proxyHost = "proxyserverhost") + { + using (var conn = new SnowflakeDbConnection()) + { + // Arrange + var nonProxyHosts = string.Format(regexHost, $"{testConfig.host}"); + conn.ConnectionString = + $"{ConnectionString}connection_timeout=5;USEPROXY=true;PROXYHOST={proxyHost};NONPROXYHOSTS={nonProxyHosts};PROXYPORT=3128;"; + + // Act/Assert + // The connection would fail to open if the web proxy would be used because the proxy is configured to a non-existent host. 
+ var exception = Assert.Throws(() => conn.Open()); + + // Assert + Assert.AreEqual(270001, exception.ErrorCode); + AssertIsConnectionFailure(exception); + } + } + [Test] public void TestUseProxyFalseWithInvalidProxyConnectionString() { @@ -1561,7 +1615,7 @@ public void TestInvalidProxySettingWithByPassListFromConnectionString() = ConnectionString + String.Format( ";useProxy=true;proxyHost=Invalid;proxyPort=8080;nonProxyHosts={0}", - "*.foo.com %7C" + testConfig.account + ".snowflakecomputing.com|" + testConfig.host); + $"*.foo.com %7C{testConfig.account}.snowflakecomputing.com|*{testConfig.host}*"); conn.Open(); // Because testConfig.host is in the bypass list, the proxy should not be used } @@ -2169,6 +2223,7 @@ public void TestNativeOktaSuccess() Assert.AreEqual(ConnectionState.Open, conn.State); } } + } } diff --git a/Snowflake.Data.Tests/SFBaseTest.cs b/Snowflake.Data.Tests/SFBaseTest.cs index 0bb2e1555..6aacb94f9 100755 --- a/Snowflake.Data.Tests/SFBaseTest.cs +++ b/Snowflake.Data.Tests/SFBaseTest.cs @@ -155,6 +155,11 @@ public SFBaseTestAsync() testConfig.password); protected TestConfig testConfig { get; } + + protected string ResolveHost() + { + return testConfig.host ?? $"{testConfig.account}.snowflakecomputing.com"; + } } [SetUpFixture] diff --git a/Snowflake.Data/Core/HttpUtil.cs b/Snowflake.Data/Core/HttpUtil.cs index 9c5e22442..531e76fd7 100755 --- a/Snowflake.Data/Core/HttpUtil.cs +++ b/Snowflake.Data/Core/HttpUtil.cs @@ -12,7 +12,6 @@ using System.Collections.Specialized; using System.Web; using System.Security.Authentication; -using System.Runtime.InteropServices; using System.Linq; using Snowflake.Data.Core.Authenticator; @@ -186,7 +185,11 @@ internal HttpMessageHandler SetupCustomHttpHandler(HttpClientConfig config) entry = entry.Replace(".", "[.]"); // * -> .* because * is a quantifier and need a char or group to apply to entry = entry.Replace("*", ".*"); - + + entry = entry.StartsWith("^") ? entry : $"^{entry}"; + + entry = entry.EndsWith("$") ? 
entry : $"{entry}$"; + // Replace with the valid entry syntax bypassList[i] = entry; From aa5b1015a76e42b76c45ce25f90ba06bcff34334 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Hofman?= Date: Thu, 4 Apr 2024 22:51:14 +0200 Subject: [PATCH 08/12] SNOW-1293828 Updated security policy (#904) ### Description Security policy information ### Checklist - [ ] Code compiles correctly - [ ] Code is formatted according to [Coding Conventions](../CodingConventions.md) - [ ] Created tests which fail without the change (if possible) - [ ] All tests passing (`dotnet test`) - [x] Extended the README / documentation, if necessary - [x] Provide JIRA issue id (if possible) or GitHub issue id in PR name --- SECURITY.md | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 SECURITY.md diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..1940f0be4 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,4 @@ +# Security Policy + +Please refer to the Snowflake [HackerOne program](https://hackerone.com/snowflake?type=team) for our security policies and for reporting any security vulnerabilities. 
+For other security related questions and concerns, please contact the Snowflake security team at security@snowflake.com From b703f9375e0963e49fc637f31cc53e95e5050d87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Hofman?= Date: Mon, 15 Apr 2024 14:14:13 +0200 Subject: [PATCH 09/12] SNOW-1168205 iceberg table compliance testing (#908) ### Description Introduction of ICEBERG table CRUD/Structure type/bindings testing ### Checklist - [x] Code compiles correctly - [x] Code is formatted according to [Coding Conventions](../blob/master/CodingConventions.md) - [x] Created tests which fail without the change (if possible) - [x] All tests passing (`dotnet test`) - [x] Extended the README / documentation, if necessary - [x] Provide JIRA issue id (if possible) or GitHub issue id in PR name --- Snowflake.Data.Tests/App.config | 160 +++--- .../IcebergTests/TestIcebergTable.cs | 539 ++++++++++++++++++ .../Snowflake.Data.Tests.csproj | 2 +- .../Util/DbCommandExtensions.cs | 5 +- .../Util/DbConnectionExtensions.cs | 17 +- Snowflake.Data.Tests/Util/TestDataHelpers.cs | 44 ++ Snowflake.Data/Core/SFDataConverter.cs | 2 +- 7 files changed, 681 insertions(+), 88 deletions(-) create mode 100644 Snowflake.Data.Tests/IcebergTests/TestIcebergTable.cs create mode 100644 Snowflake.Data.Tests/Util/TestDataHelpers.cs diff --git a/Snowflake.Data.Tests/App.config b/Snowflake.Data.Tests/App.config index a7920fad2..5e3dd1335 100755 --- a/Snowflake.Data.Tests/App.config +++ b/Snowflake.Data.Tests/App.config @@ -1,80 +1,80 @@ - - - - -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/Snowflake.Data.Tests/IcebergTests/TestIcebergTable.cs b/Snowflake.Data.Tests/IcebergTests/TestIcebergTable.cs new file mode 100644 index 000000000..44e4e6229 --- /dev/null +++ b/Snowflake.Data.Tests/IcebergTests/TestIcebergTable.cs @@ -0,0 +1,539 @@ +using System; +using System.Data; +using System.Data.Common; +using System.Globalization; +using System.Linq; +using System.Text; +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core; +using Snowflake.Data.Tests.Util; +using static Snowflake.Data.Tests.Util.TestData; + +namespace Snowflake.Data.Tests.IcebergTests +{ + [TestFixture(ResultFormat.ARROW)] + [TestFixture(ResultFormat.JSON)] + [NonParallelizable] + public class TestIcebergTable : SFBaseTest + { + private const string TableNameIceberg = "DOTNET_TEST_DATA_IB"; + private const string TableNameHybrid = "DOTNET_TEST_DATA_HY"; + private const string SqlCreateIcebergTableColumns = @"nu1 number(10,0), + nu2 number(19,0), + nu3 number(18,2), + nu4 number(38,0), + f float, + tx varchar(16777216), + bt boolean, + bf boolean, + dt date, + tm time, + ntz timestamp_ntz(6), + ltz timestamp_ltz(6), + bi binary(5), + ar array(number(10,0)), + ob object(a number(10,0), b varchar), + ma map(varchar, varchar)"; + private const string SqlCreateHybridTableColumns = @"id number(10,0) not null primary key, + nu number(10,0), + tx2 varchar(100)"; + private const string IcebergTableCreateFlags = "external_volume = 'demo_exvol' catalog = 'snowflake' base_location = 'x/'"; + private const string SqlColumnsSimpleTypes = "nu1,nu2,nu3,nu4,f,tx,bt,bf,dt,tm,ntz,ltz,bi"; + private const string SqlColumnsHybridTypes = "id,nu,tx2"; + private const string SqlColumnsStructureTypes = "ar,ob,ma"; + private const int I32 = 1; + private const long I64 = 9223372036854775807; + private const decimal Dec = (decimal)2.67; + private const double Dbl = 3.333e8; + private const float Flt 
= -1.0e7f; + private const string Txt = "Sample text"; + private const bool B1 = true; + private const bool B0 = false; + private const int Id1 = 1; + private const int Id2 = 2; + private const string Txt1 = "sample text for join1"; + private const string Txt2 = "sample text for join2"; + private static readonly DateTime s_ts = DateTime.ParseExact("2023/03/15 13:17:29.207", "yyyy/MM/dd HH:mm:ss.fff", CultureInfo.InvariantCulture); + private readonly DateTime _dt = s_ts.Date; + private readonly DateTime _tm = s_ts; + private readonly DateTime _ntz = s_ts; + private readonly DateTimeOffset _ltz = DateTimeOffset.ParseExact("2023/03/15 13:17:29.207 +05:00", "yyyy/MM/dd HH:mm:ss.fff zzz", CultureInfo.InvariantCulture); + private readonly byte[] _bi = Encoding.Default.GetBytes("flake"); + private readonly ResultFormat _resultFormat; + private const string FormatYmd = "yyyy-MM-dd"; + private const string FormatHms = "HH:mm:ss"; + private const string FormatYmdHms = "yyyy-MM-dd HH:mm:ss"; + private const string FormatYmdHmsf = "yyyy-MM-dd HH:mm:ss.fffffff"; + private const string FormatYmdHmsfZ = "yyyy-MM-dd HH:mm:ss.fffffff zzz"; + + public TestIcebergTable(ResultFormat resultFormat) + { + _resultFormat = resultFormat; + } + + [Test] + [Ignore("TODO: Enable when features available on the automated tests environment")] + public void TestInsertPlainText() + { + // Arrange + using (var conn = OpenConnection()) + { + CreateIcebergTable(conn); + SetResultFormat(conn); + + // Act + conn.ExecuteNonQuery(@$"insert into {TableNameIceberg} ({SqlColumnsSimpleTypes}) + values ({I32}, {I64}, {Dec}, {Dbl}, {Flt}, '{Txt}', {B1}, {B0}, + '{_dt.ToString(FormatYmd)}', + '{_tm.ToString(FormatHms)}', + '{_ntz.ToString(FormatYmdHms)}', + '{_ltz.ToString(FormatYmdHmsfZ)}', + '{ByteArrayToHexString(_bi)}')"); + + // Assert + var reader = conn.ExecuteReader($"select {SqlColumnsSimpleTypes} from {TableNameIceberg}"); + int rowsRead = 0; + while (reader.Read()) + { + rowsRead++; + 
AssertRowValuesEqual(reader, SqlCreateIcebergTableColumns.Split('\n'), I32, I64, Dec, Dbl, Flt, Txt, B1, B0, _dt, _tm, _ntz, _ltz, _bi); + } + Assert.AreEqual(1, rowsRead); + } + } + + + [Test] + [Ignore("TODO: Enable when features available on the automated tests environment")] + public void TestInsertWithValueBinding() + { + // Arrange + using (var conn = OpenConnection()) + { + CreateIcebergTable(conn); + SetResultFormat(conn); + + // Act + InsertSingleRow(conn, I32, I64, Dec, Dbl, Flt, Txt, B1, B0, _dt, _tm, _ntz, _ltz, _bi); + + // Assert + var reader = conn.ExecuteReader($"select {SqlColumnsSimpleTypes} from {TableNameIceberg}"); + int rowsRead = 0; + while (reader.Read()) + { + rowsRead++; + AssertRowValuesEqual(reader, SqlCreateIcebergTableColumns.Split('\n'), I32, I64, Dec, Dbl, Flt, Txt, B1, B0, _dt, _tm, _ntz, _ltz, _bi); + } + Assert.AreEqual(1, rowsRead); + } + } + + [Test] + [Ignore("TODO: Enable when features available on the automated tests environment")] + public void TestUpdateWithValueBinding() + { + // Arrange + var i32 = I32 * 2; + var i64 = I32; + var dec = Dec + (decimal)0.1; + var dbl = Dbl / 16; + var flt = Flt * 2.5; + var txt = Txt + " updated"; + var b1 = !B1; + var b0 = !B0; + var dt = _dt.Add(TimeSpan.FromDays(3)); + var tm = _tm.AddMinutes(7); + var ntz = _ntz.Add(TimeSpan.FromDays(10)); + var ltz = _ltz.Subtract(TimeSpan.FromSeconds(37)); + var bi = Encoding.Default.GetBytes("Snow"); + using (var conn = OpenConnection()) + { + CreateIcebergTable(conn); + SetResultFormat(conn); + InsertSingleRow(conn, I32, I64, Dec, Dbl, Flt, Txt, B1, B0, _dt, _tm, _ntz, _ltz, _bi); + + // Act + using (var cmd = conn.CreateCommand($"update {TableNameIceberg} set nu1=?,nu2=?,nu3=?,nu4=?,f=?,tx=?,bt=?,bf=?,dt=?,tm=?,ntz=?,ltz=?,bi=? where nu1=? and (bt=? 
or dt=?)")) + { + cmd.Add("1", DbType.Int32, i32); + cmd.Add("2", DbType.Int64, i64); + cmd.Add("3", DbType.Decimal, dec); + cmd.Add("4", DbType.Double, dbl); + cmd.Add("5", DbType.Double, flt); + cmd.Add("6", DbType.String, txt); + cmd.Add("7", DbType.Boolean, b1); + cmd.Add("8", DbType.Boolean, b0); + cmd.Add("9", DbType.Date, dt); + cmd.Add("10", DbType.Time, tm); + cmd.Add("11", DbType.DateTime, ntz); + cmd.Add("12", DbType.DateTime, ltz).SFDataType = SFDataType.TIMESTAMP_LTZ; + cmd.Add("13", DbType.Binary, bi); + cmd.Add("14", DbType.Int32, I32); + cmd.Add("15", DbType.Boolean, B1); + cmd.Add("16", DbType.Date, _dt); + Assert.AreEqual(1, cmd.ExecuteNonQuery()); + } + + // Assert + var reader = conn.ExecuteReader($"select {SqlColumnsSimpleTypes} from {TableNameIceberg}"); + int rowsRead = 0; + while (reader.Read()) + { + rowsRead++; + AssertRowValuesEqual(reader, SqlCreateIcebergTableColumns.Split('\n'), i32, i64, dec, dbl, flt, txt, b1, b0, dt, tm, ntz, ltz, bi); + } + Assert.AreEqual(1, rowsRead); + } + } + + [Test] + [Ignore("TODO: Enable when features available on the automated tests environment")] + public void TestJoin() + { + using (var conn = OpenConnection()) + { + // Arrange + CreateIcebergTable(conn); + CreateHybridTable(conn); + InsertManyRows(conn, 10, I32, I64, Dec, Dbl, Flt, Txt, B1, B0, _dt, _tm,_ntz,_ltz,_bi); + InsertHybridTableData(conn); + SetResultFormat(conn); + + // Act + var sql = @$"select i.nu1,i.nu2,i.nu3,i.nu4,i.f,i.tx,i.bt,i.bf,i.dt,i.tm,i.ntz,i.ltz,i.bi, h.id,h.nu,h.tx2 + from {TableNameIceberg} i + join {TableNameHybrid} h + on i.nu1 = h.nu order by i.nu1"; + + // Assert + var resultSetColumns = @"nu1 number(10,0), + nu2 number(19,0), + nu3 number(18,2), + nu4 number(38,0), + f float, + tx varchar(16777216), + bt boolean, + bf boolean, + dt date, + tm time, + ntz timestamp_ntz(6), + ltz timestamp_ltz(6), + bi binary(5), + id number(10,0), + nu number(10,0), + tx2 varchar(100)".Split('\n'); + var reader = 
(DbDataReader)conn.ExecuteReader(sql); + Assert.AreEqual(true, reader.Read()); + AssertRowValuesEqual(reader, resultSetColumns, I32, I64, Dec, Dbl, Flt, Txt, B1, B0, _dt, _tm, _ntz, _ltz, _bi, Id1, I32, Txt1); + Assert.AreEqual(true, reader.Read()); + AssertRowValuesEqual(reader, resultSetColumns, I32, I64, Dec, Dbl, Flt, Txt, B1, B0, _dt, _tm, _ntz, _ltz, _bi, Id2, I32, Txt2); + Assert.AreEqual(false, reader.Read()); + } + } + + [Test] + [Ignore("TODO: Enable when features available on the automated tests environment")] + public void TestDelete() + { + using (var conn = OpenConnection()) + { + // Arrange + CreateIcebergTable(conn); + InsertManyRows(conn, 100, I32, I64, Dec, Dbl, Flt, Txt, B1, B0, _dt, _tm, _ntz, _ltz, _bi); + SetResultFormat(conn); + + // Act + var cmd = conn.CreateCommand($"delete from {TableNameIceberg} where nu1 = ?"); + cmd.Add("1", DbType.Int32, I32); + var removed = cmd.ExecuteReader(); + + // Assert + Assert.AreEqual(1, removed.RecordsAffected); + var left = conn.ExecuteReader($"select count(*) from {TableNameIceberg} where nu1 <> {I32}"); + Assert.AreEqual(true, left.Read()); + Assert.AreEqual(99, left.GetInt32(0)); + } + } + + [Test] + [Ignore("TODO: Enable when features available on the automated tests environment")] + public void TestDeleteAll() + { + using (var conn = OpenConnection()) + { + // Arrange + CreateIcebergTable(conn); + InsertManyRows(conn, 100, I32, I64, Dec, Dbl, Flt, Txt, B1, B0, _dt, _tm, _ntz, _ltz, _bi); + SetResultFormat(conn); + + // Act + var cmd = conn.CreateCommand($@"delete from {TableNameIceberg}"); + var removed = cmd.ExecuteReader(); + + // Assert + Assert.AreEqual(100, removed.RecordsAffected); + var left = conn.ExecuteReader($"select count(*) from {TableNameIceberg}"); + Assert.AreEqual(true, left.Read()); + Assert.AreEqual(0, left.GetInt32(0)); + } + } + + [Test] + [Ignore("TODO: Enable when features available on the automated tests environment")] + public void TestMultiStatement() + { + using (var conn = 
OpenConnection()) + { + // Arrange + CreateIcebergTable(conn); + InsertSingleRow(conn, I32, I64, Dec, Dbl, Flt, Txt, B1, B0, _dt, _tm, _ntz, _ltz, _bi); + SetResultFormat(conn); + + // Act + var cmd = conn.CreateCommand($"select * from {TableNameIceberg};select 1;select current_timestamp;select * from {TableNameIceberg}"); + cmd.Add("MULTI_STATEMENT_COUNT", DbType.Int32, 4); + var reader = cmd.ExecuteReader(); + + // Assert + int rowsRead = 0; + while (reader.Read()) + { + rowsRead++; + AssertRowValuesEqual((DbDataReader)reader, SqlCreateIcebergTableColumns.Split('\n'), I32, I64, Dec, Dbl, Flt, Txt, B1, B0, _dt, _tm, _ntz, _ltz, _bi); + } + Assert.AreEqual(1, rowsRead); + } + } + + [Test] + [Ignore("TODO: Enable when features available on the automated tests environment")] + public void TestBatchInsertForLargeData() + { + using (var conn = OpenConnection()) + { + // Arrange + CreateIcebergTable(conn); + SetResultFormat(conn); + InsertManyRowsWithNulls(conn, 20_000, I32, I64, Dec, Dbl, Flt, Txt, B1, B0, _dt, _tm, _ntz, _ltz, _bi); + + // Act + var reader = conn.ExecuteReader($"select {SqlColumnsSimpleTypes} from {TableNameIceberg} order by nu1"); + + // Assert + var resultSetColumns = SqlCreateIcebergTableColumns.Split('\n'); + var expected = new object[] {I32, I64, Dec, Dbl, Flt, Txt, B1, B0, _dt, _tm, _ntz, _ltz, _bi}; + var rowsRead = 0; + while (reader.Read()) + { + ++rowsRead; + expected[0] = rowsRead; + var expectedRow = NullEachNthValueBesidesFirst(expected, rowsRead-1); + AssertRowValuesEqual(reader, resultSetColumns, expectedRow); + } + Assert.AreEqual(20_000, rowsRead); + } + } + + [Test] + [Ignore("TODO: Enable when features available on the automated tests environment")] + public void TestStructuredTypesAsJsonString() + { + using (var conn = OpenConnection()) + { + SetResultFormat(conn); + CreateIcebergTable(conn); + var sql = @$"insert into {TableNameIceberg} ({SqlColumnsStructureTypes}) + select + [1,2,3]::ARRAY(number), + {{'a' : 1, 'b': 
'two'}}::OBJECT(a number, b varchar), + {{'4':'one', '5': 'two', '6': 'three'}}::MAP(varchar, varchar) + "; + conn.ExecuteNonQuery(sql); + + var dbDataReader = conn.ExecuteReader($"select {SqlColumnsStructureTypes} from {TableNameIceberg}"); + int rowsRead = 0; + while (dbDataReader.Read()) + { + rowsRead++; + Assert.AreEqual("[1,2,3]", RemoveBlanks(dbDataReader.GetString(0))); + Assert.AreEqual("{\"a\":1,\"b\":\"two\"}", RemoveBlanks(dbDataReader.GetString(1))); + Assert.AreEqual("{\"4\":\"one\",\"5\":\"two\",\"6\":\"three\"}", RemoveBlanks(dbDataReader.GetString(2))); + } + Assert.AreEqual(1, rowsRead); + } + } + + private void CreateIcebergTable(SnowflakeDbConnection conn) + => conn.ExecuteNonQuery($"create or replace iceberg table {TableNameIceberg} ({SqlCreateIcebergTableColumns}) {IcebergTableCreateFlags}"); + + private void CreateHybridTable(SnowflakeDbConnection conn) + => conn.ExecuteNonQuery($"create or replace hybrid table {TableNameHybrid} ({SqlCreateHybridTableColumns})"); + + private void SetResultFormat(SnowflakeDbConnection conn) + => conn.ExecuteNonQuery($"alter session set DOTNET_QUERY_RESULT_FORMAT={_resultFormat}"); + + private SnowflakeDbConnection OpenConnection() + { + var conn = new SnowflakeDbConnection(ConnectionString); + conn.Open(); + return conn; + } + + private void InsertSingleRow(SnowflakeDbConnection conn, params object[] bindings) + { + Assert.AreEqual(13, bindings.Length); + var sqlInsert = $"insert into {TableNameIceberg} ({SqlColumnsSimpleTypes}) values (?,?,?,?,?,?,?,?,?,?,?,?,?)"; + using (var cmd = conn.CreateCommand(sqlInsert)) + { + cmd.Add("1", DbType.Int32, bindings[0]); + cmd.Add("2", DbType.Int64, bindings[1]); + cmd.Add("3", DbType.Decimal, bindings[2]); + cmd.Add("4", DbType.Double, bindings[3]); + cmd.Add("5", DbType.Double, bindings[4]); + cmd.Add("6", DbType.String, bindings[5]); + cmd.Add("7", DbType.Boolean, bindings[6]); + cmd.Add("8", DbType.Boolean, bindings[7]); + cmd.Add("9", DbType.DateTime, bindings[8]); 
+ cmd.Add("10", DbType.DateTime, bindings[9]); + cmd.Add("11", DbType.DateTime, bindings[10]); + cmd.Add("12", DbType.DateTimeOffset, bindings[11]).SFDataType = SFDataType.TIMESTAMP_LTZ; + cmd.Add("13", DbType.Binary, bindings[12]); + Assert.AreEqual(1, cmd.ExecuteNonQuery()); + } + } + + private void InsertManyRows(SnowflakeDbConnection conn, int times, params object[] bindings) + { + Assert.AreEqual(13, bindings.Length); + var sqlInsert = $"insert into {TableNameIceberg} ({SqlColumnsSimpleTypes}) values (?,?,?,?,?,?,?,?,?,?,?,?,?)"; + using (var cmd = conn.CreateCommand(sqlInsert)) + { + cmd.Add("1", DbType.Int32, Enumerable.Range((int)bindings[0], times).ToArray()); + cmd.Add("2", DbType.Int64, Enumerable.Repeat((long)bindings[1], times).ToArray()); + cmd.Add("3", DbType.Decimal, Enumerable.Repeat((decimal)bindings[2], times).ToArray()); + cmd.Add("4", DbType.Double, Enumerable.Repeat((double)bindings[3], times).ToArray()); + cmd.Add("5", DbType.Double, Enumerable.Repeat((float)bindings[4], times).ToArray()); + cmd.Add("6", DbType.String, Enumerable.Repeat((string)bindings[5], times).ToArray()); + cmd.Add("7", DbType.Boolean, Enumerable.Repeat((bool)bindings[6], times).ToArray()); + cmd.Add("8", DbType.Boolean, Enumerable.Repeat((bool)bindings[7], times).ToArray()); + cmd.Add("9", DbType.DateTime, Enumerable.Repeat((DateTime)bindings[8], times).ToArray()); + cmd.Add("10", DbType.DateTime, Enumerable.Repeat((DateTime)bindings[9], times).ToArray()); + cmd.Add("11", DbType.DateTime, Enumerable.Repeat((DateTime)bindings[10], times).ToArray()); + cmd.Add("12", DbType.DateTimeOffset, Enumerable.Repeat((DateTimeOffset)bindings[11], times).ToArray()) + .SFDataType = SFDataType.TIMESTAMP_LTZ; + cmd.Add("13", DbType.Binary, Enumerable.Repeat((byte[])bindings[12], times).ToArray()); + Assert.AreEqual(times, cmd.ExecuteNonQuery()); + } + } + + private void InsertManyRowsWithNulls(SnowflakeDbConnection conn, int times, params object[] bindings) + { + Assert.AreEqual(13, 
bindings.Length); + var sqlInsert = $"insert into {TableNameIceberg} ({SqlColumnsSimpleTypes}) values (?,?,?,?,?,?,?,?,?,?,?,?,?)"; + using (var cmd = conn.CreateCommand(sqlInsert)) + { + cmd.Add("1", DbType.Int32, Enumerable.Range((int)bindings[0], times).ToArray()); + + var longArray = Enumerable.Repeat((long?)bindings[1], times).ToArray(); + cmd.Add("2", DbType.Int64, NullEachNthValue(longArray, 2)); + + var decArray = Enumerable.Repeat((decimal?)bindings[2], times).ToArray(); + cmd.Add("3", DbType.Decimal, NullEachNthValue(decArray, 3)); + + var dblArray = Enumerable.Repeat((double?)bindings[3], times).ToArray(); + cmd.Add("4", DbType.Double, NullEachNthValue(dblArray, 4)); + + var fltArray = Enumerable.Repeat((float?)bindings[4], times).ToArray(); + cmd.Add("5", DbType.Double, NullEachNthValue(fltArray, 5)); + + var strArray = Enumerable.Repeat((string)bindings[5], times).ToArray(); + cmd.Add("6", DbType.String, NullEachNthValue(strArray, 6)); + + var bltArray = Enumerable.Repeat((bool?)bindings[6], times).ToArray(); + cmd.Add("7", DbType.Boolean, NullEachNthValue(bltArray, 7)); + + var blfArray = Enumerable.Repeat((bool?)bindings[7], times).ToArray(); + cmd.Add("8", DbType.Boolean, NullEachNthValue(blfArray, 8)); + + var dtArray = Enumerable.Repeat((DateTime?)bindings[8], times).ToArray(); + cmd.Add("9", DbType.Date, NullEachNthValue(dtArray, 9)); + + var tmArray = Enumerable.Repeat((DateTime?)bindings[9], times).ToArray(); + cmd.Add("10", DbType.Time, NullEachNthValue(tmArray, 10)); + + var ntzArray = Enumerable.Repeat((DateTime?)bindings[10], times).ToArray(); + cmd.Add("11", DbType.DateTime, NullEachNthValue(ntzArray, 11)); + + var ltzArray = Enumerable.Repeat((DateTimeOffset?)bindings[11], times).ToArray(); + cmd.Add("12", DbType.DateTimeOffset, NullEachNthValue(ltzArray, 12)) + .SFDataType = SFDataType.TIMESTAMP_LTZ; + + var binArray = Enumerable.Repeat((byte[])bindings[12], times).ToArray(); + cmd.Add("13", DbType.Binary, NullEachNthValue(binArray, 
13)); + + Assert.AreEqual(times, cmd.ExecuteNonQuery()); + } + } + + private void InsertHybridTableData(SnowflakeDbConnection conn) + { + using (var cmd = conn.CreateCommand($"insert into {TableNameHybrid} ({SqlColumnsHybridTypes}) values (?,?,?)")) + { + cmd.Add("1", DbType.Int32, new[]{Id1, Id2}); + cmd.Add("2", DbType.Int32, new[]{I32, I32}); + cmd.Add("3", DbType.String, new[]{Txt1,Txt2}); + cmd.ExecuteNonQuery(); + } + } + + private void AssertRowValuesEqual(DbDataReader actualRow, string[] columns, params object[] expectedRow) + { + foreach (var idx in Enumerable.Range(0, expectedRow.Length)) + { + var expected = expectedRow[idx]; + if (expected is DBNull || expected == null) + { + Assert.IsTrue(actualRow.IsDBNull(idx)); + continue; + } + + var column = columns[idx].ToUpper().Trim(); + var mismatch = $"Mismatch on column {idx}: {column}"; + switch (expected) + { + case Int32 i32: + Assert.AreEqual(i32, actualRow.GetInt32(idx), mismatch); + break; + case Int64 i64: + Assert.AreEqual(i64, actualRow.GetInt64(idx), mismatch); + break; + case Decimal dec: + Assert.AreEqual(dec, actualRow.GetDecimal(idx), mismatch); + break; + case float flt: + Assert.AreEqual(flt, actualRow.GetFloat(idx), mismatch); + break; + case String str: + Assert.AreEqual(str, actualRow.GetString(idx), mismatch); + break; + case Boolean bl: + Assert.AreEqual(bl, actualRow.GetBoolean(idx), mismatch); + break; + case DateTime dt: + var frmt = column.Contains(" TIME") ? 
FormatHms : FormatYmdHmsf; + Assert.AreEqual(dt.ToString(frmt), actualRow.GetDateTime(idx).ToString(frmt), mismatch); + break; + case DateTimeOffset dto: + Assert.AreEqual(dto.ToUniversalTime().ToString(FormatYmdHmsfZ), + actualRow.GetFieldValue(idx).ToUniversalTime().ToString(FormatYmdHmsfZ), + mismatch); + break; + case byte[] bt: + Assert.AreEqual(bt, actualRow.GetFieldValue(idx), mismatch); + break; + } + } + } + } +} diff --git a/Snowflake.Data.Tests/Snowflake.Data.Tests.csproj b/Snowflake.Data.Tests/Snowflake.Data.Tests.csproj index a44fd4499..86decd67a 100644 --- a/Snowflake.Data.Tests/Snowflake.Data.Tests.csproj +++ b/Snowflake.Data.Tests/Snowflake.Data.Tests.csproj @@ -9,7 +9,7 @@ Snowflake Connector for .NET Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. true - 7.3 + 9 $(SEQUENTIAL_ENV) diff --git a/Snowflake.Data.Tests/Util/DbCommandExtensions.cs b/Snowflake.Data.Tests/Util/DbCommandExtensions.cs index fb336d5c3..d9dc0f2f8 100644 --- a/Snowflake.Data.Tests/Util/DbCommandExtensions.cs +++ b/Snowflake.Data.Tests/Util/DbCommandExtensions.cs @@ -1,12 +1,13 @@ using System.Data; +using Snowflake.Data.Client; namespace Snowflake.Data.Tests.Util { public static class DbCommandExtensions { - internal static IDbDataParameter Add(this IDbCommand command, string name, DbType dbType, object value) + internal static SnowflakeDbParameter Add(this IDbCommand command, string name, DbType dbType, object value) { - var parameter = command.CreateParameter(); + var parameter = (SnowflakeDbParameter)command.CreateParameter(); parameter.ParameterName = name; parameter.DbType = dbType; parameter.Value = value; diff --git a/Snowflake.Data.Tests/Util/DbConnectionExtensions.cs b/Snowflake.Data.Tests/Util/DbConnectionExtensions.cs index 02b7e47dd..e8efc371d 100644 --- a/Snowflake.Data.Tests/Util/DbConnectionExtensions.cs +++ b/Snowflake.Data.Tests/Util/DbConnectionExtensions.cs @@ -1,23 +1,32 @@ using System.Data; +using System.Data.Common; +using 
Snowflake.Data.Client; +using Snowflake.Data.Log; +using Snowflake.Data.Tests.IcebergTests; namespace Snowflake.Data.Tests.Util { public static class DbConnectionExtensions { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + internal static IDbCommand CreateCommand(this IDbConnection connection, string commandText) { var command = connection.CreateCommand(); command.Connection = connection; command.CommandText = commandText; + s_logger.Debug(commandText); return command; } internal static int ExecuteNonQuery(this IDbConnection connection, string commandText) { - var command = connection.CreateCommand(); - command.Connection = connection; - command.CommandText = commandText; - return command.ExecuteNonQuery(); + var rowsAffected = connection.CreateCommand(commandText).ExecuteNonQuery(); + s_logger.Debug($"Affected row(s): {rowsAffected}"); + return rowsAffected; } + + public static DbDataReader ExecuteReader(this SnowflakeDbConnection connection, string commandText) + => (DbDataReader)connection.CreateCommand(commandText).ExecuteReader(); } } diff --git a/Snowflake.Data.Tests/Util/TestDataHelpers.cs b/Snowflake.Data.Tests/Util/TestDataHelpers.cs new file mode 100644 index 000000000..170151c7f --- /dev/null +++ b/Snowflake.Data.Tests/Util/TestDataHelpers.cs @@ -0,0 +1,44 @@ +using System.Linq; +using System.Text; + +namespace Snowflake.Data.Tests.Util +{ + internal static class TestData + { + internal static string ByteArrayToHexString(byte[] ba) + { + StringBuilder hex = new StringBuilder(ba.Length * 2); + foreach (byte b in ba) + hex.AppendFormat("{0:x2}", b); + return hex.ToString(); + } + + internal static T?[] NullEachNthValue(T?[] sourceColumn, int nullEachNthItem) where T : struct + { + var destination = new T?[sourceColumn.Length]; + foreach (var rowIndex in Enumerable.Range(0, sourceColumn.Length)) + destination[rowIndex] = rowIndex % nullEachNthItem == 0 ? 
null : sourceColumn[rowIndex]; + return destination; + } + + internal static T?[] NullEachNthValue(T?[] sourceColumn, int nullEachNthItem) where T : class + { + var destination = new T?[sourceColumn.Length]; + foreach (var rowIndex in Enumerable.Range(0, sourceColumn.Length)) + destination[rowIndex] = rowIndex % nullEachNthItem == 0 ? null : sourceColumn[rowIndex]; + return destination; + } + + internal static object[] NullEachNthValueBesidesFirst(object[] sourceRow, int nullEachNthItem) + { + var ret = new object[sourceRow.Length]; + foreach (var column in Enumerable.Range(0, sourceRow.Length)) + ret[column] = column > 0 && nullEachNthItem % (column + 1) == 0 ? null : sourceRow[column]; + return ret; + } + + internal static string RemoveBlanks(string text) + => text.Replace("\n", "").Replace(" ", ""); + + } +} diff --git a/Snowflake.Data/Core/SFDataConverter.cs b/Snowflake.Data/Core/SFDataConverter.cs index 2e380f73d..6822f03f4 100755 --- a/Snowflake.Data/Core/SFDataConverter.cs +++ b/Snowflake.Data/Core/SFDataConverter.cs @@ -327,7 +327,7 @@ internal static string csharpValToSfVal(SFDataType sfDataType, object srcVal) { string destVal = null; - if (srcVal != DBNull.Value) + if (srcVal != DBNull.Value && srcVal != null) { switch (sfDataType) { From 46e496ac2cd49f98b5c5be157acb1dbfff13d6b6 Mon Sep 17 00:00:00 2001 From: Steven Lizano <130484280+sfc-gh-erojaslizano@users.noreply.github.com> Date: Tue, 16 Apr 2024 10:46:59 -0600 Subject: [PATCH 10/12] [SNOW-921048]Add linter action (#892) Add linter action - Add linter.yml workflow - Add .EditorConfig with default rules - Add .DS_Store to gitignore --- .EditorConfig | 61 ++++++++++++++++++++++++++++++++++++ .github/workflows/linter.yml | 43 +++++++++++++++++++++++++ .github/workflows/main.yml | 3 +- .gitignore | 3 ++ 4 files changed, 109 insertions(+), 1 deletion(-) create mode 100644 .EditorConfig create mode 100644 .github/workflows/linter.yml diff --git a/.EditorConfig b/.EditorConfig new file mode 100644 index 
000000000..61ab3ee2c --- /dev/null +++ b/.EditorConfig @@ -0,0 +1,61 @@ +root = true +# All files +[*.*] +indent_style = space +indent_size = 4 +insert_final_newline = true +trim_trailing_whitespace = true +charset = utf-8 +max_line_length=150 + +# Interfaces should start with I and PascalCase +dotnet_naming_rule.interfaces_begin_with_I.severity = warning +dotnet_naming_rule.interfaces_begin_with_I.symbols = interfaces +dotnet_naming_rule.interfaces_begin_with_I.style = prefix_and_pascal_case +dotnet_naming_rule.interfaces_begin_with_I.required_prefix = I +dotnet_naming_symbols.interfaces.applicable_kinds = interface +dotnet_diagnostic.interfaces_begin_with_I.severity = warning +dotnet_diagnostic.interfaces_begin_with_I.enabled = true + +# Static fields should start with _s +dotnet_naming_rule.static_fields_begin_with_s.severity = warning +dotnet_naming_rule.static_fields_begin_with_s.symbols = static_fields +dotnet_naming_rule.static_fields_begin_with_s.style = custom +dotnet_naming_rule.static_fields_begin_with_s.custom_recommended_prefix = _r +dotnet_naming_rule.static_fields_begin_with_s.required_prefix = _r +dotnet_naming_rule.static_fields_begin_with_s.capitalization = camel_case +dotnet_naming_symbols.static_fields.applicable_kinds = field +dotnet_naming_symbols.static_fields.applicable_accessibilities = public, internal, private, protected, protected_internal +dotnet_naming_symbols.static_fields.required_modifiers = static +dotnet_diagnostic.static_fields_begin_with_s.severity = warning +dotnet_diagnostic.static_fields_begin_with_s.enabled = true + +# Enforce use of Pascal case in enums, classes, const and methods +dotnet_naming_rule.enforce_pascal_case.severity = suggestion +dotnet_naming_rule.enforce_pascal_case.symbols = methods, enums, consts, public_methods, public_classes +dotnet_naming_rule.enforce_pascal_case.style = pascal_case +dotnet_naming_symbols.methods.applicable_kinds = method +dotnet_naming_symbols.enums.applicable_kinds = enum 
+dotnet_naming_symbols.consts.applicable_kinds = field +dotnet_naming_symbols.consts.applicable_modifiers = const +dotnet_naming_symbols.public_methods.applicable_kinds = method +dotnet_naming_symbols.public_methods.applicable_accessibilities = public +dotnet_naming_symbols.public_classes.applicable_kinds = class +dotnet_naming_symbols.public_classes.applicable_accessibilities = public +dotnet_diagnostic.enforce_pascal_case.severity = suggestion +dotnet_diagnostic.enforce_pascal_case.enabled = true + +# private and internal members should start with underscore +dotnet_naming_rule.private_and_internal_members_start_with_underscore.severity = warning +dotnet_naming_rule.private_and_internal_members_start_with_underscore.symbols = private_fields, internal_fields, private_properties, internal_properties, private_methods, internal_methods +dotnet_naming_rule.private_and_internal_members_start_with_underscore.style = underscore_prefix +dotnet_naming_symbols.private_fields.applicable_kinds = field +dotnet_naming_symbols.internal_fields.applicable_kinds = field +dotnet_naming_symbols.private_properties.applicable_kinds = property +dotnet_naming_symbols.internal_properties.applicable_kinds = property +dotnet_naming_symbols.private_methods.applicable_kinds = method +dotnet_naming_symbols.internal_methods.applicable_kinds = method +dotnet_naming_symbols.private_methods.applicable_accessibilities = private +dotnet_naming_symbols.internal_methods.applicable_accessibilities = internal +dotnet_diagnostic.private_and_internal_members_start_with_underscore.severity = warning +dotnet_diagnostic.private_and_internal_members_start_with_underscore.enabled = true diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml new file mode 100644 index 000000000..dc328598f --- /dev/null +++ b/.github/workflows/linter.yml @@ -0,0 +1,43 @@ +name: Code standards check + +# Triggers the workflow on pull request events but only for the master branch +on: + pull_request: + branches: 
[ master ] + workflow_dispatch: + inputs: + logLevel: + default: warning + description: "Log level" + required: true + tags: + description: "Linter" + required: false + +concurrency: + # older builds for the same pull request number or branch should be cancelled + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +env: + DOTNET_VERSION: 6.0 + DOTNET_LEGACY_VERSION: 4.7.1 + +jobs: + run-linter: + name: Run linter + runs-on: windows-latest + steps: + - name: Check out Git repository + uses: actions/checkout@v3 + - name: Set up .NET + uses: actions/setup-dotnet@v1 + with: + dotnet-version: "6.0.x" + dotnet-quality: 'ga' + - name: Run linters + uses: wearerequired/lint-action@v2 + with: + dotnet_format: true + continue_on_error: true + check_name: ${linter} run diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 86da91f0e..d0bd7c6bd 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -37,7 +37,8 @@ jobs: dotnet: ['net6.0', 'net472', 'net471'] cloud_env: ['AZURE', 'GCP', 'AWS'] steps: - - uses: actions/checkout@v3 + - name: Checkout code + uses: actions/checkout@v3 - name: Setup Dotnet uses: actions/setup-dotnet@v3 with: diff --git a/.gitignore b/.gitignore index 325ad49d6..268c8f4dc 100644 --- a/.gitignore +++ b/.gitignore @@ -309,3 +309,6 @@ whitesource/ Snowflake.Data.Tests/macos_*_performance.csv Snowflake.Data.Tests/windows_*_performance.csv Snowflake.Data.Tests/unix_*_performance.csv + +# Ignore Mac files +**/.DS_Store \ No newline at end of file From 47235fb9f11656669e0d8f3d4edcea4e106c36dd Mon Sep 17 00:00:00 2001 From: Steven Lizano <130484280+sfc-gh-erojaslizano@users.noreply.github.com> Date: Wed, 17 Apr 2024 13:58:01 -0600 Subject: [PATCH 11/12] [SNOW-1313929] Add https to endpoints without it (#914) ### Description - Add https to endpoints without it - Add unit test for SetCommonClientConfig ### Checklist - [x] Code compiles correctly - [x] Code is 
formatted according to [Coding Conventions](../blob/master/CodingConventions.md) - [x] Created tests which fail without the change (if possible) - [x] All tests passing (`dotnet test`) - [x] Extended the README / documentation, if necessary - [x] Provide JIRA issue id (if possible) or GitHub issue id in PR name --- .../UnitTests/SFS3ClientTest.cs | 31 +++++++++++++++++++ .../FileTransfer/StorageClient/SFS3Client.cs | 29 ++++++++--------- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs b/Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs index 561819623..da3baf531 100644 --- a/Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs @@ -3,6 +3,7 @@ */ using System; +using Amazon.S3.Encryption; namespace Snowflake.Data.Tests.UnitTests { @@ -219,6 +220,36 @@ public void TestUploadFile(string requestKey, ResultStatus expectedResultStatus) AssertForUploadFileTests(expectedResultStatus); } + [Test] + public void TestAppendHttpsToEndpoint() + { + // Arrange + var amazonS3Client = new AmazonS3Config(); + var endpoint = "endpointWithNoHttps.com"; + var expectedEndpoint = "https://endpointWithNoHttps.com"; + + // ACT + SFS3Client.SetCommonClientConfig(amazonS3Client, string.Empty, endpoint, 1, 0); + + // Assert + Assert.That(amazonS3Client.ServiceURL, Is.EqualTo(expectedEndpoint)); + } + + [Test] + public void TestAppendHttpsToEndpointWithBrackets() + { + // Arrange + var amazonS3Client = new AmazonS3Config(); + var endpoint = "[endpointWithNoHttps.com]"; + var expectedEndpoint = "https://endpointWithNoHttps.com"; + + // ACT + SFS3Client.SetCommonClientConfig(amazonS3Client, string.Empty, endpoint, 1, 0); + + // Assert + Assert.That(amazonS3Client.ServiceURL, Is.EqualTo(expectedEndpoint)); + } + [Test] [TestCase(MockS3Client.AwsStatusOk, ResultStatus.UPLOADED)] [TestCase(SFS3Client.EXPIRED_TOKEN, ResultStatus.RENEW_TOKEN)] diff --git 
a/Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs b/Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs index e68fbbd3e..88b20c1d5 100644 --- a/Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs +++ b/Snowflake.Data/Core/FileTransfer/StorageClient/SFS3Client.cs @@ -116,7 +116,7 @@ public SFS3Client( stageInfo.endPoint, maxRetry, parallel); - + // Get the AWS token value and create the S3 client if (stageInfo.stageCredentials.TryGetValue(AWS_TOKEN, out string awsSessionToken)) { @@ -164,7 +164,7 @@ public RemoteLocation ExtractBucketNameAndPath(string stageLocation) { bucketName = stageLocation.Substring(0, stageLocation.IndexOf('/')); - s3path = stageLocation.Substring(stageLocation.IndexOf('/') + 1, + s3path = stageLocation.Substring(stageLocation.IndexOf('/') + 1, stageLocation.Length - stageLocation.IndexOf('/') - 1); if (s3path != null && !s3path.EndsWith("/")) { @@ -287,13 +287,13 @@ private FileHeader HandleFileHeaderResponse(ref SFFileMetadata fileMetadata, Get } /// - /// Set the client configuration common to both client with and without client-side + /// Set the client configuration common to both client with and without client-side /// encryption. /// /// The client config to update. /// The region if any. /// The endpoint if any. 
- private static void SetCommonClientConfig( + internal static void SetCommonClientConfig( AmazonS3Config clientConfig, string region, string endpoint, @@ -309,23 +309,25 @@ private static void SetCommonClientConfig( } // If a specific endpoint is specified use this - if ((null != endpoint) && (0 != endpoint.Length)) + if (!string.IsNullOrEmpty(endpoint)) { var start = endpoint.IndexOf('['); var end = endpoint.IndexOf(']'); - if(start > -1 && end > -1 && end > start) + if (start > -1 && end > -1 && end > start) { endpoint = endpoint.Substring(start + 1, end - start - 1); - if(!endpoint.Contains("https")) - { - endpoint = "https://" + endpoint; - } } + + if (!endpoint.StartsWith("https://", StringComparison.OrdinalIgnoreCase)) + { + endpoint = "https://" + endpoint; + } + clientConfig.ServiceURL = endpoint; } // The region information used to determine the endpoint for the service. - // RegionEndpoint and ServiceURL are mutually exclusive properties. + // RegionEndpoint and ServiceURL are mutually exclusive properties. // If both stageInfo.endPoint and stageInfo.region have a value, stageInfo.region takes // precedence and ServiceUrl will be reset to null. 
if ((null != region) && (0 != region.Length)) @@ -337,7 +339,6 @@ private static void SetCommonClientConfig( // Unavailable for .net framework 4.6 //clientConfig.MaxConnectionsPerServer = parallel; clientConfig.MaxErrorRetry = maxRetry; - } /// @@ -410,7 +411,7 @@ private PutObjectRequest GetPutObjectRequest(ref AmazonS3Client client, SFFileMe { PutGetStageInfo stageInfo = fileMetadata.stageInfo; RemoteLocation location = ExtractBucketNameAndPath(stageInfo.location); - + // Create S3 PUT request fileBytesStream.Position = 0; PutObjectRequest putObjectRequest = new PutObjectRequest @@ -585,4 +586,4 @@ private SFFileMetadata HandleDownloadFileErr(Exception ex, SFFileMetadata fileMe return fileMetadata; } } -} \ No newline at end of file +} From 14cf8a5daaf9611b237eb4b51ccdd61cb44761d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joanna=20Siemi=C5=84ska?= Date: Fri, 19 Apr 2024 08:13:05 +0200 Subject: [PATCH 12/12] SNOW-834812 Adding connection parameter QUERY_TAG. (#916) ### Description Adding connection parameter QUERY_TAG. ### Checklist - [x] Code compiles correctly - [x] Code is formatted according to [Coding Conventions](../blob/master/CodingConventions.md) - [x] Created tests which fail without the change (if possible) - [x] All tests passing (`dotnet test`) - [x] Extended the README / documentation, if necessary - [x] Provide JIRA issue id (if possible) or GitHub issue id in PR name --- README.md | 5 +-- .../IntegrationTests/SFConnectionIT.cs | 18 ++++++++++ .../UnitTests/SFSessionPropertyTest.cs | 35 ++++++++++++++++++- Snowflake.Data/Core/SFStatement.cs | 18 ++++++++-- Snowflake.Data/Core/Session/SFSession.cs | 3 ++ .../Core/Session/SFSessionProperty.cs | 4 ++- 6 files changed, 77 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 7c7e65356..263de2c0f 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,7 @@ The following table lists all valid connection properties:
| Connection Property | Required | Comment | -| ------------------------------ | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +|--------------------------------| -------- |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | ACCOUNT | Yes | Your full account name might include additional segments that identify the region and cloud platform where your account is hosted | | APPLICATION | No | **_Snowflake partner use only_**: Specifies the name of a partner application to connect through .NET. The name must match the following pattern: ^\[A-Za-z](\[A-Za-z0-9.-]){1,50}$ (one letter followed by 1 to 50 letter, digit, .,- or, \_ characters). 
| | DB | No | | @@ -163,10 +163,11 @@ The following table lists all valid connection properties: | PROXYPORT | Depends | The port number of the proxy server.

If USEPROXY is set to `true`, you must set this parameter.

This parameter was introduced in v2.0.4. | | PROXYUSER | No | The username for authenticating to the proxy server.

This parameter was introduced in v2.0.4. | | PROXYPASSWORD | Depends | The password for authenticating to the proxy server.

If USEPROXY is `true` and PROXYUSER is set, you must set this parameter.

This parameter was introduced in v2.0.4. | -| NONPROXYHOSTS | No | The list of hosts that the driver should connect to directly, bypassing the proxy server. Separate the hostnames with a pipe symbol (\|). You can also use an asterisk (`*`) as a wildcard.
The host target value should fully match with any item from the proxy host list to bypass the proxy server.

This parameter was introduced in v2.0.4. | +| NONPROXYHOSTS | No | The list of hosts that the driver should connect to directly, bypassing the proxy server. Separate the hostnames with a pipe symbol (\|). You can also use an asterisk (`*`) as a wildcard.
The host target value should fully match with any item from the proxy host list to bypass the proxy server.

This parameter was introduced in v2.0.4. | | FILE_TRANSFER_MEMORY_THRESHOLD | No | The maximum number of bytes to store in memory used in order to provide a file encryption. If encrypting/decrypting file size exceeds provided value a temporary file will be created and the work will be continued in the temporary file instead of memory.
If no value provided 1MB will be used as a default value (that is 1048576 bytes).
It is possible to configure any integer value bigger than zero representing maximal number of bytes to reside in memory. | | CLIENT_CONFIG_FILE | No | The location of the client configuration json file. In this file you can configure easy logging feature. | | ALLOWUNDERSCORESINHOST | No | Specifies whether to allow underscores in account names. This impacts PrivateLink customers whose account names contain underscores. In this situation, you must override the default value by setting allowUnderscoresInHost to true. | +| QUERY_TAG | No | Optional string that can be used to tag queries and other SQL statements executed within a connection. The tags are displayed in the output of the QUERY_HISTORY , QUERY_HISTORY_BY_* functions. |
diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index c248ef575..cc4fea738 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -2224,6 +2224,24 @@ public void TestNativeOktaSuccess() } } + [Test] + public void TestConnectStringWithQueryTag() + { + using (var conn = new SnowflakeDbConnection()) + { + string expectedQueryTag = "Test QUERY_TAG 12345"; + conn.ConnectionString = ConnectionString + $";query_tag={expectedQueryTag}"; + + conn.Open(); + var command = conn.CreateCommand(); + // This query itself will be part of the history and will have the query tag + command.CommandText = "SELECT QUERY_TAG FROM table(information_schema.query_history_by_session())"; + var queryTag = command.ExecuteScalar(); + + Assert.AreEqual(expectedQueryTag, queryTag); + } + } + } } diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs index 4b2e3ec8f..309570ca6 100644 --- a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs @@ -470,6 +470,38 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.ALLOWUNDERSCORESINHOST, "true" } } }; + var testQueryTag = "Test QUERY_TAG 12345"; + var testCaseQueryTag = new TestCase() + { + ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};QUERY_TAG={testQueryTag}", + ExpectedProperties = new SFSessionProperties() + { + { SFSessionProperty.ACCOUNT, $"{defAccount}" }, + { SFSessionProperty.USER, defUser }, + { SFSessionProperty.HOST, $"{defAccount}.snowflakecomputing.com" }, + { SFSessionProperty.AUTHENTICATOR, defAuthenticator }, + { SFSessionProperty.SCHEME, defScheme }, + { SFSessionProperty.CONNECTION_TIMEOUT, defConnectionTimeout }, + { SFSessionProperty.PASSWORD, defPassword }, + { 
SFSessionProperty.PORT, defPort }, + { SFSessionProperty.VALIDATE_DEFAULT_PARAMETERS, "true" }, + { SFSessionProperty.USEPROXY, "false" }, + { SFSessionProperty.INSECUREMODE, "false" }, + { SFSessionProperty.DISABLERETRY, "false" }, + { SFSessionProperty.FORCERETRYON404, "false" }, + { SFSessionProperty.CLIENT_SESSION_KEEP_ALIVE, "false" }, + { SFSessionProperty.FORCEPARSEERROR, "false" }, + { SFSessionProperty.BROWSER_RESPONSE_TIMEOUT, defBrowserResponseTime }, + { SFSessionProperty.RETRY_TIMEOUT, defRetryTimeout }, + { SFSessionProperty.MAXHTTPRETRIES, defMaxHttpRetries }, + { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason }, + { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }, + { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }, + { SFSessionProperty.ALLOWUNDERSCORESINHOST, "false" }, + { SFSessionProperty.QUERY_TAG, testQueryTag } + } + }; + return new TestCase[] { simpleTestCase, @@ -482,7 +514,8 @@ public static IEnumerable ConnectionStringTestCases() testCaseWithDisableConsoleLogin, testCaseComplicatedAccountName, testCaseUnderscoredAccountName, - testCaseUnderscoredAccountNameWithEnabledAllowUnderscores + testCaseUnderscoredAccountNameWithEnabledAllowUnderscores, + testCaseQueryTag }; } diff --git a/Snowflake.Data/Core/SFStatement.cs b/Snowflake.Data/Core/SFStatement.cs index 05e905263..ae5ecbf4e 100644 --- a/Snowflake.Data/Core/SFStatement.cs +++ b/Snowflake.Data/Core/SFStatement.cs @@ -110,7 +110,9 @@ class SFStatement private const string SF_QUERY_RESULT_PATH = "/queries/{0}/result"; private const string SF_PARAM_MULTI_STATEMENT_COUNT = "MULTI_STATEMENT_COUNT"; - + + private const string SF_PARAM_QUERY_TAG = "QUERY_TAG"; + private const int SF_QUERY_IN_PROGRESS = 333333; private const int SF_QUERY_IN_PROGRESS_ASYNC = 333334; @@ -141,10 +143,13 @@ class SFStatement // the query id of the last query string _lastQueryId = null; + private string _queryTag = null; + internal SFStatement(SFSession 
session) { SfSession = session; _restRequester = session.restRequester; + _queryTag = session._queryTag; } internal string GetBindStage() => _bindStage; @@ -195,7 +200,16 @@ private SFRestRequest BuildQueryRequest(string sql, Dictionary(); + } + bodyParameters[SF_PARAM_QUERY_TAG] = _queryTag; + } QueryRequest postBody = new QueryRequest(); postBody.sqlText = sql; diff --git a/Snowflake.Data/Core/Session/SFSession.cs b/Snowflake.Data/Core/Session/SFSession.cs index 8f56fdda4..3b0c80f8d 100755 --- a/Snowflake.Data/Core/Session/SFSession.cs +++ b/Snowflake.Data/Core/Session/SFSession.cs @@ -83,6 +83,8 @@ public class SFSession internal int _maxRetryTimeout; + internal String _queryTag; + internal void ProcessLoginResponse(LoginResponse authnResponse) { if (authnResponse.success) @@ -168,6 +170,7 @@ internal SFSession( connectionTimeout = extractedProperties.TimeoutDuration(); properties.TryGetValue(SFSessionProperty.CLIENT_CONFIG_FILE, out var easyLoggingConfigFile); _easyLoggingStarter.Init(easyLoggingConfigFile); + properties.TryGetValue(SFSessionProperty.QUERY_TAG, out _queryTag); _maxRetryCount = extractedProperties.maxHttpRetries; _maxRetryTimeout = extractedProperties.retryTimeout; } diff --git a/Snowflake.Data/Core/Session/SFSessionProperty.cs b/Snowflake.Data/Core/Session/SFSessionProperty.cs index 6ed45be81..7ce6b4731 100644 --- a/Snowflake.Data/Core/Session/SFSessionProperty.cs +++ b/Snowflake.Data/Core/Session/SFSessionProperty.cs @@ -94,7 +94,9 @@ internal enum SFSessionProperty [SFSessionPropertyAttr(required = false, defaultValue = "true")] DISABLE_CONSOLE_LOGIN, [SFSessionPropertyAttr(required = false, defaultValue = "false")] - ALLOWUNDERSCORESINHOST + ALLOWUNDERSCORESINHOST, + [SFSessionPropertyAttr(required = false)] + QUERY_TAG } class SFSessionPropertyAttr : Attribute