diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolIT.cs index 2a403521e..4f5020538 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolIT.cs @@ -333,7 +333,7 @@ void ThreadProcess2(string connstr) Thread.Sleep(5000); SFStatement statement = new SFStatement(conn1.SfSession); - SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); SnowflakeDbConnectionPool.ClearAllPools(); diff --git a/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs index 6e92c0aac..a4c84caeb 100755 --- a/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs @@ -175,6 +175,302 @@ public void TestExecuteAsyncWithMaxRetryReached() Assert.GreaterOrEqual(stopwatch.ElapsedMilliseconds, 30 * 1000); } } + + [Test] + public async Task TestAsyncExecQueryAsync() + { + string queryId; + var expectedWaitTime = 5; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Arrange + cmd.CommandText = $"CALL SYSTEM$WAIT({expectedWaitTime}, \'SECONDS\');"; + + // Act + queryId = await cmd.ExecuteAsyncInAsyncMode(CancellationToken.None).ConfigureAwait(false); + var queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.IsTrue(conn.IsStillRunning(queryStatus)); + Assert.IsFalse(conn.IsAnError(queryStatus)); + + // Act + DbDataReader reader = await cmd.GetResultsFromQueryIdAsync(queryId, CancellationToken.None).ConfigureAwait(false); + queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.IsTrue(reader.Read()); + Assert.AreEqual($"waited {expectedWaitTime} seconds", reader.GetString(0)); + Assert.AreEqual(QueryStatus.Success, queryStatus); + } + + await conn.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test, NonParallelizable] + public async Task TestExecuteNormalQueryWhileAsyncExecQueryIsRunningAsync() + { + string queryId; + var expectedWaitTime = 5; + + SnowflakeDbConnection[] connections = new SnowflakeDbConnection[3]; + for (int i = 0; i < connections.Length; i++) + { + connections[i] = new SnowflakeDbConnection(ConnectionString); + await connections[i].OpenAsync(CancellationToken.None).ConfigureAwait(false); + } + + // Start the async exec query + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)connections[0].CreateCommand()) + { + // Arrange + cmd.CommandText = $"CALL SYSTEM$WAIT({expectedWaitTime}, \'SECONDS\');"; + + // Act + queryId = await cmd.ExecuteAsyncInAsyncMode(CancellationToken.None).ConfigureAwait(false); + var queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.IsTrue(connections[0].IsStillRunning(queryStatus)); + } + + // Execute a normal query + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)connections[1].CreateCommand()) + { + // Arrange + cmd.CommandText = $"select 1;"; + + // Act + var row = cmd.ExecuteScalar(); + + // Assert + Assert.AreEqual(1, row); + } + + // Get results of the async exec query + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)connections[2].CreateCommand()) + { + // Act + var reader = await cmd.GetResultsFromQueryIdAsync(queryId, CancellationToken.None).ConfigureAwait(false); + var queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.IsTrue(reader.Read()); + Assert.AreEqual($"waited {expectedWaitTime} seconds", reader.GetString(0)); + Assert.AreEqual(QueryStatus.Success, queryStatus); + } + + for (int i = 0; i < connections.Length; i++) + { + await connections[i].CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test] + public async Task TestAsyncExecCancelWhileGettingResultsAsync() + { + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Arrange + CancellationTokenSource cancelToken = new CancellationTokenSource(); + cmd.CommandText = $"CALL SYSTEM$WAIT(60, \'SECONDS\');"; + + // Act + var queryId = await cmd.ExecuteAsyncInAsyncMode(CancellationToken.None).ConfigureAwait(false); + var queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.IsTrue(conn.IsStillRunning(queryStatus)); + + // Act + cancelToken.Cancel(); + var thrown = Assert.ThrowsAsync(async () => + await cmd.GetResultsFromQueryIdAsync(queryId, cancelToken.Token).ConfigureAwait(false)); + + // Assert + Assert.IsTrue(thrown.Message.Contains("The operation was canceled")); + } + + await conn.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test] + public async Task TestFailedAsyncExecQueryThrowsErrorAsync() + { + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Arrange + var statusMaxRetryCount = 5; + var statusRetryCount = 0; + cmd.CommandText = $"SELECT * FROM FAKE_TABLE;"; + + // Act + var queryId = await cmd.ExecuteAsyncInAsyncMode(CancellationToken.None).ConfigureAwait(false); + var queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); + while (statusRetryCount < statusMaxRetryCount && conn.IsStillRunning(queryStatus)) + { + Thread.Sleep(1000); + queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); + statusRetryCount++; + } + + // Assert + Assert.AreEqual(QueryStatus.FailedWithError, queryStatus); + + // Act + var thrown = Assert.ThrowsAsync(async () => + await cmd.GetResultsFromQueryIdAsync(queryId, CancellationToken.None).ConfigureAwait(false)); + + // Assert + Assert.IsTrue(thrown.Message.Contains("'FAKE_TABLE' does not exist")); + } + + await conn.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test] + public async Task TestGetStatusOfInvalidQueryIdAsync() + { + string fakeQueryId = "fakeQueryId"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Act + var thrown = Assert.ThrowsAsync(async () => + await cmd.GetQueryStatusAsync(fakeQueryId, CancellationToken.None).ConfigureAwait(false)); + + // Assert + Assert.IsTrue(thrown.Message.Contains($"The given query id {fakeQueryId} is not valid uuid")); + } + + await conn.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test] + public async Task TestGetResultsOfInvalidQueryIdAsync() + { + string fakeQueryId = "fakeQueryId"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Act + var thrown = Assert.ThrowsAsync(async () => + await cmd.GetResultsFromQueryIdAsync(fakeQueryId, CancellationToken.None).ConfigureAwait(false)); + + // Assert + Assert.IsTrue(thrown.Message.Contains($"The given query id {fakeQueryId} is not valid uuid")); + } + + await conn.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test, NonParallelizable] + public async Task TestGetStatusOfUnknownQueryIdAsync() + { + string unknownQueryId = "ba321edc-1abc-123e-987f-1234a56b789c"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Act + var queryStatus = await cmd.GetQueryStatusAsync(unknownQueryId, CancellationToken.None).ConfigureAwait(false); + + // Assert + Assert.AreEqual(QueryStatus.NoData, queryStatus); + } + + await conn.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test] + [Ignore("The test takes too long to finish when using the default retry")] + public async Task TestGetResultsOfUnknownQueryIdAsyncWithDefaultRetry() + { + string unknownQueryId = "ab123fed-1abc-987f-987f-1234a56b789c"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Act + var thrown = Assert.ThrowsAsync(async () => + await cmd.GetResultsFromQueryIdAsync(unknownQueryId, CancellationToken.None).ConfigureAwait(false)); + + // Assert + Assert.IsTrue(thrown.Message.Contains($"Max retry for no data is reached")); + } + + await conn.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } + + [Test] + public async Task TestGetResultsOfUnknownQueryIdAsyncWithConfiguredRetry() + { + var queryResultsRetryCount = 3; + var queryResultsRetryPattern = new int[] { 1, 2 }; + var unknownQueryId = "ab123fed-1abc-987f-987f-1234a56b789c"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Arrange + QueryResultsAwaiter queryResultsAwaiter = new QueryResultsAwaiter(new QueryResultsRetryConfig(queryResultsRetryCount, queryResultsRetryPattern)); + + // Act + var thrown = Assert.ThrowsAsync(async () => + await queryResultsAwaiter.RetryUntilQueryResultIsAvailable(conn, unknownQueryId, CancellationToken.None, true).ConfigureAwait(false)); + + // Assert + Assert.IsTrue(thrown.Message.Contains($"Max retry for no data is reached")); + } + + await conn.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + } } [TestFixture] @@ -1040,5 +1336,293 @@ public void TestGetQueryId() conn.Close(); } } + + [Test] + public void TestAsyncExecQuery() + { + string queryId; + var expectedWaitTime = 5; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Arrange + cmd.CommandText = $"CALL SYSTEM$WAIT({expectedWaitTime}, \'SECONDS\');"; + + // Act + queryId = cmd.ExecuteInAsyncMode(); + var queryStatus = cmd.GetQueryStatus(queryId); + + // Assert + Assert.IsTrue(conn.IsStillRunning(queryStatus)); + Assert.IsFalse(conn.IsAnError(queryStatus)); + + // Act + DbDataReader reader = cmd.GetResultsFromQueryId(queryId); + + // Assert + Assert.IsTrue(reader.Read()); + Assert.AreEqual($"waited {expectedWaitTime} seconds", reader.GetString(0)); + Assert.AreEqual(QueryStatus.Success, cmd.GetQueryStatus(queryId)); + } + + conn.Close(); + } + } + + [Test, NonParallelizable] + public void TestExecuteNormalQueryWhileAsyncExecQueryIsRunning() + { + string queryId; + var expectedWaitTime = 5; + + SnowflakeDbConnection[] connections = new SnowflakeDbConnection[3]; + for (int i = 0; i < connections.Length; i++) + { + connections[i] = new SnowflakeDbConnection(ConnectionString); + connections[i].Open(); + } + + // Start the async exec query + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)connections[0].CreateCommand()) + { + // Arrange + cmd.CommandText = $"CALL SYSTEM$WAIT({expectedWaitTime}, \'SECONDS\');"; + + // Act + queryId = cmd.ExecuteInAsyncMode(); + + // Assert + Assert.IsTrue(connections[0].IsStillRunning(cmd.GetQueryStatus(queryId))); + } + + // Execute a normal query + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)connections[1].CreateCommand()) + { + // Arrange + cmd.CommandText = $"select 1;"; + + // Act + var row = cmd.ExecuteScalar(); + + // Assert + Assert.AreEqual(1, row); + } + + // Get results of the async exec query + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)connections[2].CreateCommand()) + { + // Act + DbDataReader reader = cmd.GetResultsFromQueryId(queryId); + + // Assert + Assert.IsTrue(reader.Read()); + Assert.AreEqual($"waited {expectedWaitTime} seconds", reader.GetString(0)); + Assert.AreEqual(QueryStatus.Success, cmd.GetQueryStatus(queryId)); + } + + for (int i = 0; i < connections.Length; i++) + { + connections[i].Close(); + } + } + + [Test] + public void TestFailedAsyncExecQueryThrowsError() + { + string queryId; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Arrange + var statusMaxRetryCount = 5; + var statusRetryCount = 0; + cmd.CommandText = $"SELECT * FROM FAKE_TABLE;"; + + // Act + queryId = cmd.ExecuteInAsyncMode(); + while (statusRetryCount < statusMaxRetryCount && conn.IsStillRunning(cmd.GetQueryStatus(queryId))) + { + Thread.Sleep(1000); + statusRetryCount++; + } + + // Assert + Assert.AreEqual(QueryStatus.FailedWithError, cmd.GetQueryStatus(queryId)); + + // Act + var thrown = Assert.Throws(() => cmd.GetResultsFromQueryId(queryId)); + + // Assert + Assert.IsTrue(thrown.Message.Contains("'FAKE_TABLE' does not exist")); + } + + conn.Close(); + } + } + + [Test] + public void TestAsyncExecQueryPutGetThrowsNotImplemented() + { + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Arrange + cmd.CommandText = $"PUT file://non_existent_file.csv @~;"; + + // Act + var thrown = Assert.Throws(() => cmd.ExecuteInAsyncMode()); + + // Assert + Assert.IsTrue(thrown.Message.Contains("Get and Put are not supported in async execution mode")); + + // Arrange + cmd.CommandText = "GET @~ file://C:\\tmp\\;"; + + // Act + thrown = Assert.Throws(() => cmd.ExecuteInAsyncMode()); + + // Assert + Assert.IsTrue(thrown.Message.Contains("Get and Put are not supported in async execution mode")); + } + + conn.Close(); + } + } + + [Test] + public void TestGetStatusOfInvalidQueryId() + { + string fakeQueryId = "fakeQueryId"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Act + var thrown = Assert.Throws(() => cmd.GetQueryStatus(fakeQueryId)); + + // Assert + Assert.IsTrue(thrown.Message.Contains($"The given query id {fakeQueryId} is not valid uuid")); + } + + conn.Close(); + } + } + + [Test] + public void TestGetResultsOfInvalidQueryId() + { + string fakeQueryId = "fakeQueryId"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Act + var thrown = Assert.Throws(() => cmd.GetResultsFromQueryId(fakeQueryId)); + + // Assert + Assert.IsTrue(thrown.InnerException.Message.Contains($"The given query id {fakeQueryId} is not valid uuid")); + } + + conn.Close(); + } + } + + [Test, NonParallelizable] + public void TestGetStatusOfUnknownQueryId() + { + string unknownQueryId = "ab123cde-1cba-789a-987f-1234a56b789c"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Act + var queryStatus = cmd.GetQueryStatus(unknownQueryId); + + // Assert + Assert.AreEqual(QueryStatus.NoData, queryStatus); + } + + conn.Close(); + } + } + + [Test] + [Ignore("The test takes too long to finish when using the default retry")] + public void TestGetResultsOfUnknownQueryIdWithDefaultRetry() + { + string unknownQueryId = "ba987def-1abc-987f-987f-1234a56b789c"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Act + var thrown = Assert.Throws(() => cmd.GetResultsFromQueryId(unknownQueryId)); + + // Assert + Assert.IsTrue(thrown.InnerException.Message.Contains($"Max retry for no data is reached")); + } + + conn.Close(); + } + } + + [Test] + public void TestGetResultsOfUnknownQueryIdWithConfiguredRetry() + { + var queryResultsRetryCount = 3; + var queryResultsRetryPattern = new int[] { 1, 2 }; + var unknownQueryId = "ba987def-1abc-987f-987f-1234a56b789c"; + + using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = ConnectionString; + conn.Open(); + + using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + // Arrange + QueryResultsAwaiter queryResultsAwaiter = new QueryResultsAwaiter(new QueryResultsRetryConfig(queryResultsRetryCount, queryResultsRetryPattern)); + var task = queryResultsAwaiter.RetryUntilQueryResultIsAvailable(conn, unknownQueryId, CancellationToken.None, false); + + // Act + var thrown = Assert.Throws(() => task.Wait()); + + // Assert + Assert.IsTrue(thrown.InnerException.Message.Contains($"Max retry for no data is reached")); + } + + conn.Close(); + } + } } } diff --git a/Snowflake.Data.Tests/UnitTests/SFStatementTest.cs b/Snowflake.Data.Tests/UnitTests/SFStatementTest.cs index 330b19f96..24f0f4b0a 100755 --- a/Snowflake.Data.Tests/UnitTests/SFStatementTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFStatementTest.cs @@ -6,6 +6,7 @@ namespace Snowflake.Data.Tests.UnitTests { using Snowflake.Data.Core; using NUnit.Framework; + using System; /** * Mock rest request test @@ -21,7 +22,7 @@ public void TestSessionRenew() SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); SFStatement statement = new SFStatement(sfSession); - SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); Assert.AreEqual("new_session_token", sfSession.sessionToken); @@ -37,7 +38,7 @@ public void TestSessionRenewDuringQueryExec() SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); SFStatement statement = new SFStatement(sfSession); - SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); } @@ -57,7 +58,7 @@ public void TestServiceName() for (int i = 0; i < 5; i++) { SFStatement statement = new SFStatement(sfSession); - SFBaseResultSet resultSet = statement.Execute(0, "SELECT 1", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "SELECT 1", null, false, false); expectServiceName += "a"; Assert.AreEqual(expectServiceName, sfSession.ParameterMap[SFSessionParameter.SERVICE_NAME]); } @@ -73,7 +74,7 @@ public void TestTrimSqlBlockComment() SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); SFStatement statement = new SFStatement(sfSession); - SFBaseResultSet resultSet = statement.Execute(0, "/*comment*/select 1/*comment*/", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "/*comment*/select 1/*comment*/", null, false, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); } @@ -88,7 +89,7 @@ public void TestTrimSqlBlockCommentMultiline() SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); SFStatement statement = new SFStatement(sfSession); - SFBaseResultSet resultSet = statement.Execute(0, "/*comment\r\ncomment*/select 1/*comment\r\ncomment*/", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "/*comment\r\ncomment*/select 1/*comment\r\ncomment*/", null, false, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); } @@ -103,7 +104,7 @@ public void TestTrimSqlLineComment() SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); SFStatement statement = new SFStatement(sfSession); - SFBaseResultSet resultSet = statement.Execute(0, "--comment\r\nselect 1\r\n--comment", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "--comment\r\nselect 1\r\n--comment", null, false, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); } @@ -118,9 +119,89 @@ public void TestTrimSqlLineCommentWithClosingNewline() SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); SFStatement statement = new SFStatement(sfSession); - SFBaseResultSet resultSet = statement.Execute(0, "--comment\r\nselect 1\r\n--comment\r\n", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "--comment\r\nselect 1\r\n--comment\r\n", null, false, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); } + + [Test] + [TestCase("running", QueryStatus.Running)] + [TestCase("RUNNING", QueryStatus.Running)] + [TestCase("resuming_warehouse", QueryStatus.ResumingWarehouse)] + [TestCase("RESUMING_WAREHOUSE", QueryStatus.ResumingWarehouse)] + [TestCase("queued", QueryStatus.Queued)] + [TestCase("QUEUED", QueryStatus.Queued)] + [TestCase("queued_reparing_warehouse", QueryStatus.QueuedReparingWarehouse)] + [TestCase("QUEUED_REPARING_WAREHOUSE", QueryStatus.QueuedReparingWarehouse)] + [TestCase("no_data", QueryStatus.NoData)] + [TestCase("NO_DATA", QueryStatus.NoData)] + [TestCase("aborting", QueryStatus.Aborting)] + [TestCase("ABORTING", QueryStatus.Aborting)] + [TestCase("success", QueryStatus.Success)] + [TestCase("SUCCESS", QueryStatus.Success)] + [TestCase("failed_with_error", QueryStatus.FailedWithError)] + [TestCase("FAILED_WITH_ERROR", QueryStatus.FailedWithError)] + [TestCase("aborted", QueryStatus.Aborted)] + [TestCase("ABORTED", QueryStatus.Aborted)] + [TestCase("failed_with_incident", QueryStatus.FailedWithIncident)] + [TestCase("FAILED_WITH_INCIDENT", QueryStatus.FailedWithIncident)] + [TestCase("disconnected", QueryStatus.Disconnected)] + [TestCase("DISCONNECTED", QueryStatus.Disconnected)] + [TestCase("restarted", QueryStatus.Restarted)] + [TestCase("RESTARTED", QueryStatus.Restarted)] + [TestCase("blocked", QueryStatus.Blocked)] + [TestCase("BLOCKED", QueryStatus.Blocked)] + public void TestGetQueryStatusByStringValue(string stringValue, QueryStatus expectedStatus) + { + Assert.AreEqual(expectedStatus, QueryStatusExtensions.GetQueryStatusByStringValue(stringValue)); + } + + [Test] + [TestCase("UNKNOWN")] + [TestCase("RANDOM_STATUS")] + [TestCase("aBcZyX")] + public void TestGetQueryStatusByStringValueThrowsErrorForUnknownStatus(string stringValue) + { + var thrown = Assert.Throws(() => QueryStatusExtensions.GetQueryStatusByStringValue(stringValue)); + Assert.IsTrue(thrown.Message.Contains("The query status returned by the server is not recognized")); + } + + [Test] + [TestCase(QueryStatus.Running, true)] + [TestCase(QueryStatus.ResumingWarehouse, true)] + [TestCase(QueryStatus.Queued, true)] + [TestCase(QueryStatus.QueuedReparingWarehouse, true)] + [TestCase(QueryStatus.NoData, true)] + [TestCase(QueryStatus.Aborting, false)] + [TestCase(QueryStatus.Success, false)] + [TestCase(QueryStatus.FailedWithError, false)] + [TestCase(QueryStatus.Aborted, false)] + [TestCase(QueryStatus.FailedWithIncident, false)] + [TestCase(QueryStatus.Disconnected, false)] + [TestCase(QueryStatus.Restarted, false)] + [TestCase(QueryStatus.Blocked, false)] + public void TestIsStillRunning(QueryStatus status, bool expectedResult) + { + Assert.AreEqual(expectedResult, QueryStatusExtensions.IsStillRunning(status)); + } + + [Test] + [TestCase(QueryStatus.Aborting, true)] + [TestCase(QueryStatus.FailedWithError, true)] + [TestCase(QueryStatus.Aborted, true)] + [TestCase(QueryStatus.FailedWithIncident, true)] + [TestCase(QueryStatus.Disconnected, true)] + [TestCase(QueryStatus.Blocked, true)] + [TestCase(QueryStatus.Running, false)] + [TestCase(QueryStatus.ResumingWarehouse, false)] + [TestCase(QueryStatus.Queued, false)] + [TestCase(QueryStatus.QueuedReparingWarehouse, false)] + [TestCase(QueryStatus.NoData, false)] + [TestCase(QueryStatus.Success, false)] + [TestCase(QueryStatus.Restarted, false)] + public void TestIsAnError(QueryStatus status, bool expectedResult) + { + Assert.AreEqual(expectedResult, QueryStatusExtensions.IsAnError(status)); + } } } diff --git a/Snowflake.Data/Client/SnowflakeDbCommand.cs b/Snowflake.Data/Client/SnowflakeDbCommand.cs index ca415bacc..ce004df5c 100755 --- a/Snowflake.Data/Client/SnowflakeDbCommand.cs +++ b/Snowflake.Data/Client/SnowflakeDbCommand.cs @@ -9,7 +9,6 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; -using Newtonsoft.Json; using Snowflake.Data.Log; namespace Snowflake.Data.Client @@ -17,7 +16,7 @@ namespace Snowflake.Data.Client [System.ComponentModel.DesignerCategory("Code")] public class SnowflakeDbCommand : DbCommand { - private DbConnection connection; + private SnowflakeDbConnection connection; private SFStatement sfStatement; @@ -25,6 +24,8 @@ public class SnowflakeDbCommand : DbCommand private SFLogger logger = SFLoggerFactory.GetLogger(); + private readonly QueryResultsAwaiter _queryResultsAwaiter = QueryResultsAwaiter.Instance; + public SnowflakeDbCommand() { logger.Debug("Constructing SnowflakeDbCommand class"); @@ -274,6 +275,88 @@ protected override async Task ExecuteDbDataReaderAsync(CommandBeha } } + /// + /// Execute a query in async mode. + /// Async mode means the server will respond immediately with the query ID and execute the query asynchronously + /// + /// The query id. + public string ExecuteInAsyncMode() + { + logger.Debug($"ExecuteInAsyncMode"); + SFBaseResultSet resultSet = ExecuteInternal(asyncExec: true); + return resultSet.queryId; + } + + /// + /// Executes an asynchronous query in async mode. + /// Async mode means the server will respond immediately with the query ID and execute the query asynchronously + /// + /// + /// The query id. + public async Task ExecuteAsyncInAsyncMode(CancellationToken cancellationToken) + { + logger.Debug($"ExecuteAsyncInAsyncMode"); + var resultSet = await ExecuteInternalAsync(cancellationToken, asyncExec: true).ConfigureAwait(false); + return resultSet.queryId; + } + + /// + /// Gets the query status based on query ID. + /// + /// + /// The query status. + public QueryStatus GetQueryStatus(string queryId) + { + logger.Debug($"GetQueryStatus"); + return _queryResultsAwaiter.GetQueryStatus(connection, queryId); + } + + /// + /// Gets the query status based on query ID. + /// + /// + /// + /// The query status. + public async Task GetQueryStatusAsync(string queryId, CancellationToken cancellationToken) + { + logger.Debug($"GetQueryStatusAsync"); + return await _queryResultsAwaiter.GetQueryStatusAsync(connection, queryId, cancellationToken); + } + + /// + /// Gets the query results based on query ID. + /// + /// + /// The query results. + public DbDataReader GetResultsFromQueryId(string queryId) + { + logger.Debug($"GetResultsFromQueryId"); + + Task task = _queryResultsAwaiter.RetryUntilQueryResultIsAvailable(connection, queryId, CancellationToken.None, false); + task.Wait(); + + SFBaseResultSet resultSet = sfStatement.GetResultWithId(queryId); + + return new SnowflakeDbDataReader(this, resultSet); + } + + /// + /// Gets the query results based on query ID. + /// + /// + /// + /// The query results. + public async Task GetResultsFromQueryIdAsync(string queryId, CancellationToken cancellationToken) + { + logger.Debug($"GetResultsFromQueryIdAsync"); + + await _queryResultsAwaiter.RetryUntilQueryResultIsAvailable(connection, queryId, cancellationToken, true); + + SFBaseResultSet resultSet = await sfStatement.GetResultWithIdAsync(queryId, cancellationToken).ConfigureAwait(false); + + return new SnowflakeDbDataReader(this, resultSet); + } + private static Dictionary convertToBindList(List parameters) { if (parameters == null || parameters.Count == 0) @@ -354,18 +437,18 @@ private void SetStatement() this.sfStatement = new SFStatement(session); } - private SFBaseResultSet ExecuteInternal(bool describeOnly = false) + private SFBaseResultSet ExecuteInternal(bool describeOnly = false, bool asyncExec = false) { CheckIfCommandTextIsSet(); SetStatement(); - return sfStatement.Execute(CommandTimeout, CommandText, convertToBindList(parameterCollection.parameterList), describeOnly); + return sfStatement.Execute(CommandTimeout, CommandText, convertToBindList(parameterCollection.parameterList), describeOnly, asyncExec); } - private Task ExecuteInternalAsync(CancellationToken cancellationToken, bool describeOnly = false) + private Task ExecuteInternalAsync(CancellationToken cancellationToken, bool describeOnly = false, bool asyncExec = false) { CheckIfCommandTextIsSet(); SetStatement(); - return sfStatement.ExecuteAsync(CommandTimeout, CommandText, convertToBindList(parameterCollection.parameterList), describeOnly, cancellationToken); + return sfStatement.ExecuteAsync(CommandTimeout, CommandText, convertToBindList(parameterCollection.parameterList), describeOnly, asyncExec, cancellationToken); } private void CheckIfCommandTextIsSet() diff --git a/Snowflake.Data/Client/SnowflakeDbConnection.cs b/Snowflake.Data/Client/SnowflakeDbConnection.cs index b773a0150..cce4974fc 100755 --- a/Snowflake.Data/Client/SnowflakeDbConnection.cs +++ b/Snowflake.Data/Client/SnowflakeDbConnection.cs @@ -176,7 +176,7 @@ public override async Task CloseAsync() } #endif - public virtual Task CloseAsync(CancellationToken cancellationToken) + public virtual async Task CloseAsync(CancellationToken cancellationToken) { logger.Debug("Close Connection."); TaskCompletionSource taskCompletionSource = new TaskCompletionSource(); @@ -199,7 +199,7 @@ public virtual Task CloseAsync(CancellationToken cancellationToken) } else { - SfSession.CloseAsync(cancellationToken).ContinueWith( + await SfSession.CloseAsync(cancellationToken).ContinueWith( previousTask => { if (previousTask.IsFaulted) @@ -220,7 +220,7 @@ public virtual Task CloseAsync(CancellationToken cancellationToken) _connectionState = ConnectionState.Closed; taskCompletionSource.SetResult(null); } - }, cancellationToken); + }, cancellationToken).ConfigureAwait(false); } } else @@ -229,7 +229,7 @@ public virtual Task CloseAsync(CancellationToken cancellationToken) taskCompletionSource.SetResult(null); } } - return taskCompletionSource.Task; + await taskCompletionSource.Task; } protected virtual bool CanReuseSession(TransactionRollbackStatus transactionRollbackStatus) @@ -402,6 +402,16 @@ internal void registerConnectionCancellationCallback(CancellationToken externalC } } + public bool IsStillRunning(QueryStatus status) + { + return QueryStatusExtensions.IsStillRunning(status); + } + + public bool IsAnError(QueryStatus status) + { + return QueryStatusExtensions.IsAnError(status); + } + ~SnowflakeDbConnection() { Dispose(false); diff --git a/Snowflake.Data/Core/QueryResultsAwaiter.cs b/Snowflake.Data/Core/QueryResultsAwaiter.cs new file mode 100644 index 000000000..5ea187fbe --- /dev/null +++ b/Snowflake.Data/Core/QueryResultsAwaiter.cs @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using Snowflake.Data.Log; +using System.Threading.Tasks; +using System.Threading; +using System; +using System.Text.RegularExpressions; +using Snowflake.Data.Client; + +namespace Snowflake.Data.Core +{ + internal class QueryResultsRetryConfig + { + private const int DefaultAsyncNoDataMaxRetry = 24; + + private readonly int[] _defaultAsyncRetryPattern = { 1, 1, 2, 3, 4, 8, 10 }; + + internal readonly int _asyncNoDataMaxRetry; + + internal readonly int[] _asyncRetryPattern; + + internal QueryResultsRetryConfig() + { + _asyncNoDataMaxRetry = DefaultAsyncNoDataMaxRetry; + _asyncRetryPattern = _defaultAsyncRetryPattern; + } + + internal QueryResultsRetryConfig(int asyncNoDataMaxRetry, int[] asyncRetryPattern) + { + _asyncNoDataMaxRetry = asyncNoDataMaxRetry; + _asyncRetryPattern = asyncRetryPattern; + } + } + + internal class QueryResultsAwaiter + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private static readonly Regex UuidRegex = new Regex("^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$"); + + private QueryResultsRetryConfig _queryResultsRetryConfig; + + internal static readonly QueryResultsAwaiter Instance = new QueryResultsAwaiter(); + + internal QueryResultsAwaiter() + { + _queryResultsRetryConfig = new QueryResultsRetryConfig(); + } + + internal QueryResultsAwaiter(QueryResultsRetryConfig queryResultsRetryConfig) + { + _queryResultsRetryConfig = queryResultsRetryConfig; + } + + internal QueryStatus GetQueryStatus(SnowflakeDbConnection connection, string queryId) + { + if (UuidRegex.IsMatch(queryId)) + { + var sfStatement = new SFStatement(connection.SfSession); + return sfStatement.GetQueryStatus(queryId); + } + else + { + var errorMessage = $"The given query id {queryId} is not valid uuid"; + s_logger.Error(errorMessage); + throw new Exception(errorMessage); + } + } + + internal async Task GetQueryStatusAsync(SnowflakeDbConnection connection, string queryId, CancellationToken cancellationToken) + { + if (UuidRegex.IsMatch(queryId)) + { + var sfStatement = new SFStatement(connection.SfSession); + return await sfStatement.GetQueryStatusAsync(queryId, cancellationToken).ConfigureAwait(false); + } + else + { + var errorMessage = $"The given query id {queryId} is not valid uuid"; + s_logger.Error(errorMessage); + throw new Exception(errorMessage); + } + } + + /// + /// Checks query status until it is done executing. + /// + /// + /// + /// + /// + internal async Task RetryUntilQueryResultIsAvailable(SnowflakeDbConnection connection, string queryId, CancellationToken cancellationToken, bool isAsync) + { + int retryPatternPos = 0; + int noDataCounter = 0; + + QueryStatus status; + while (true) + { + if (cancellationToken.IsCancellationRequested) + { + s_logger.Debug("Cancellation requested for getting results from query id"); + cancellationToken.ThrowIfCancellationRequested(); + } + + status = isAsync ? await GetQueryStatusAsync(connection, queryId, cancellationToken) : GetQueryStatus(connection, queryId); + + if (!QueryStatusExtensions.IsStillRunning(status)) + { + return; + } + + // Timeout based on query status retry rules + if (isAsync) + { + await Task.Delay(TimeSpan.FromSeconds(_queryResultsRetryConfig._asyncRetryPattern[retryPatternPos]), cancellationToken).ConfigureAwait(false); + } + else + { + Thread.Sleep(TimeSpan.FromSeconds(_queryResultsRetryConfig._asyncRetryPattern[retryPatternPos])); + } + + // If no data, increment the no data counter + if (status == QueryStatus.NoData) + { + noDataCounter++; + + // Check if retry for no data is exceeded + if (noDataCounter > _queryResultsRetryConfig._asyncNoDataMaxRetry) + { + var errorMessage = "Max retry for no data is reached"; + s_logger.Error(errorMessage); + throw new Exception(errorMessage); + } + } + + if (retryPatternPos < _queryResultsRetryConfig._asyncRetryPattern.Length - 1) + { + retryPatternPos++; + } + } + } + } +} diff --git a/Snowflake.Data/Core/RestParams.cs b/Snowflake.Data/Core/RestParams.cs index 9dd4de8c8..72b910847 100644 --- a/Snowflake.Data/Core/RestParams.cs +++ b/Snowflake.Data/Core/RestParams.cs @@ -39,6 +39,8 @@ internal static class RestPath internal const string SF_QUERY_PATH = "/queries/v1/query-request"; + internal const string SF_MONITOR_QUERY_PATH = "/monitoring/queries/"; + internal const string SF_SESSION_HEARTBEAT_PATH = SF_SESSION_PATH + "/heartbeat"; internal const string SF_CONSOLE_LOGIN = "/console/login"; diff --git a/Snowflake.Data/Core/RestRequest.cs b/Snowflake.Data/Core/RestRequest.cs index bc4ec8d21..112743f77 100644 --- a/Snowflake.Data/Core/RestRequest.cs +++ b/Snowflake.Data/Core/RestRequest.cs @@ -129,6 +129,8 @@ internal SFRestRequest() : base() internal bool _isLogin { get; set; } + internal bool _isStatusRequest { get; set; } + public override string ToString() { return String.Format("SFRestRequest {{url: {0}, request body: {1} }}", Url.ToString(), @@ -154,7 +156,7 @@ HttpRequestMessage IRestRequest.ToRequestMessage(HttpMethod method) // add quote otherwise it would be reported as error format string osInfo = "(" + SFEnvironment.ClientEnv.osVersion + ")"; - if (isPutGet) + if (isPutGet || _isStatusRequest) { message.Headers.Accept.Add(applicationJson); } @@ -313,6 +315,9 @@ class QueryRequest [JsonProperty(PropertyName = "queryContextDTO", NullValueHandling = NullValueHandling.Ignore)] internal RequestQueryContext QueryContextDTO { get; set; } + + [JsonProperty(PropertyName = "asyncExec")] + internal bool asyncExec { get; set; } } // The query context in query response diff --git a/Snowflake.Data/Core/RestResponse.cs b/Snowflake.Data/Core/RestResponse.cs index 97f3f9772..75f1698ea 100755 --- a/Snowflake.Data/Core/RestResponse.cs +++ b/Snowflake.Data/Core/RestResponse.cs @@ -423,6 +423,37 @@ internal class PutGetEncryptionMaterial internal long smkId { get; set; } } + internal class QueryStatusResponse : BaseRestResponse + { + + [JsonProperty(PropertyName = "data")] + internal QueryStatusData data { get; set; } + } + + internal class QueryStatusData + { + [JsonProperty(PropertyName = "queries", NullValueHandling = NullValueHandling.Ignore)] + internal List queries { get; set; } + } + + internal class QueryStatusDataQueries + { + [JsonProperty(PropertyName = "id", NullValueHandling = NullValueHandling.Ignore)] + internal string id { get; set; } + + [JsonProperty(PropertyName = "status", NullValueHandling = NullValueHandling.Ignore)] + internal string status { get; set; } + + [JsonProperty(PropertyName = "state", NullValueHandling = NullValueHandling.Ignore)] + internal string state { get; set; } + + [JsonProperty(PropertyName = "errorCode", NullValueHandling = NullValueHandling.Ignore)] + internal string errorCode { get; set; } + + [JsonProperty(PropertyName = "errorMessage", NullValueHandling = NullValueHandling.Ignore)] + internal string errorMessage { get; set; } + } + // Retrieved from: https://stackoverflow.com/a/18997172 internal class SingleOrArrayConverter : JsonConverter { diff --git a/Snowflake.Data/Core/SFBindUploader.cs b/Snowflake.Data/Core/SFBindUploader.cs index 68af7405b..a1b3f161d 100644 --- a/Snowflake.Data/Core/SFBindUploader.cs +++ b/Snowflake.Data/Core/SFBindUploader.cs @@ -290,7 +290,7 @@ private void CreateStage() try { SFStatement statement = new SFStatement(session); - SFBaseResultSet resultSet = statement.Execute(0, CREATE_STAGE_STMT, null, false); + SFBaseResultSet resultSet = statement.Execute(0, CREATE_STAGE_STMT, null, false, false); session.SetArrayBindStage(STAGE_NAME); } catch (Exception e) @@ -314,7 +314,7 @@ internal async Task CreateStageAsync(CancellationToken cancellationToken) try { SFStatement statement = new SFStatement(session); - var resultSet = await statement.ExecuteAsync(0, CREATE_STAGE_STMT, null, false, cancellationToken).ConfigureAwait(false); + var resultSet = await statement.ExecuteAsync(0, CREATE_STAGE_STMT, null, false, false, cancellationToken).ConfigureAwait(false); session.SetArrayBindStage(STAGE_NAME); } catch (Exception e) diff --git a/Snowflake.Data/Core/SFResultSet.cs b/Snowflake.Data/Core/SFResultSet.cs index 55b069806..03e1794c9 100755 --- a/Snowflake.Data/Core/SFResultSet.cs +++ b/Snowflake.Data/Core/SFResultSet.cs @@ -28,7 +28,7 @@ public SFResultSet(QueryExecResponseData responseData, SFStatement sfStatement, { try { - columnCount = responseData.rowType.Count; + columnCount = responseData.rowType?.Count ?? 0; this.sfStatement = sfStatement; UpdateSessionStatus(responseData); @@ -40,10 +40,10 @@ public SFResultSet(QueryExecResponseData responseData, SFStatement sfStatement, _chunkDownloader = ChunkDownloaderFactory.GetDownloader(responseData, this, cancellationToken); } - _currentChunk = new SFResultChunk(responseData.rowSet); + _currentChunk = responseData.rowSet != null ? new SFResultChunk(responseData.rowSet) : null; responseData.rowSet = null; - sfResultSetMetaData = new SFResultSetMetaData(responseData, this.sfStatement.SfSession); + sfResultSetMetaData = responseData.rowType != null ? new SFResultSetMetaData(responseData, this.sfStatement.SfSession) : null; isClosed = false; diff --git a/Snowflake.Data/Core/SFStatement.cs b/Snowflake.Data/Core/SFStatement.cs index a8ada8af2..9252af40e 100644 --- a/Snowflake.Data/Core/SFStatement.cs +++ b/Snowflake.Data/Core/SFStatement.cs @@ -3,13 +3,10 @@ */ using System; -using System.Web; -using Newtonsoft.Json.Linq; using System.Collections.Generic; using System.IO; using System.Linq; using Snowflake.Data.Client; -using Snowflake.Data.Core.FileTransfer; using Snowflake.Data.Log; using System.Threading; using System.Threading.Tasks; @@ -17,6 +14,87 @@ namespace Snowflake.Data.Core { + /// + /// The status types of the query. + /// + public enum QueryStatus + { + [StringAttr(value = "NO_DATA")] + NoData, + [StringAttr(value = "RUNNING")] + Running, + [StringAttr(value = "ABORTING")] + Aborting, + [StringAttr(value = "SUCCESS")] + Success, + [StringAttr(value = "FAILED_WITH_ERROR")] + FailedWithError, + [StringAttr(value = "ABORTED")] + Aborted, + [StringAttr(value = "QUEUED")] + Queued, + [StringAttr(value = "FAILED_WITH_INCIDENT")] + FailedWithIncident, + [StringAttr(value = "DISCONNECTED")] + Disconnected, + [StringAttr(value = "RESUMING_WAREHOUSE")] + ResumingWarehouse, + // purposeful typo + [StringAttr(value = "QUEUED_REPARING_WAREHOUSE")] + QueuedReparingWarehouse, + [StringAttr(value = "RESTARTED")] + Restarted, + [StringAttr(value = "BLOCKED")] + Blocked, + } + + class StringAttr : Attribute + { + public string value { get; set; } + } + + internal static class QueryStatusExtensions + { + internal static QueryStatus GetQueryStatusByStringValue(string stringValue) + { + var statuses = Enum.GetValues(typeof(QueryStatus)) + .Cast() + .Where(v => v.GetAttribute().value.Equals(stringValue, StringComparison.OrdinalIgnoreCase)); + return statuses.Any() ? statuses.First() : throw new Exception("The query status returned by the server is not recognized"); + } + + internal static bool IsStillRunning(QueryStatus status) + { + switch (status) + { + case QueryStatus.Running: + case QueryStatus.ResumingWarehouse: + case QueryStatus.Queued: + case QueryStatus.QueuedReparingWarehouse: + case QueryStatus.NoData: + return true; + default: + return false; + } + } + + internal static bool IsAnError(QueryStatus status) + { + switch (status) + { + case QueryStatus.Aborting: + case QueryStatus.FailedWithError: + case QueryStatus.Aborted: + case QueryStatus.FailedWithIncident: + case QueryStatus.Disconnected: + case QueryStatus.Blocked: + return true; + default: + return false; + } + } + } + class SFStatement { static private SFLogger logger = SFLoggerFactory.GetLogger(); @@ -90,7 +168,7 @@ private void ClearQueryRequestId() _requestId = null; } - private SFRestRequest BuildQueryRequest(string sql, Dictionary bindings, bool describeOnly) + private SFRestRequest BuildQueryRequest(string sql, Dictionary bindings, bool describeOnly, bool asyncExec) { AssignQueryRequestId(); @@ -120,8 +198,9 @@ private SFRestRequest BuildQueryRequest(string sql, Dictionary private bool SessionExpired(BaseRestResponse r) => r.code == SFSession.SF_SESSION_EXPIRED_CODE; - internal async Task ExecuteAsync(int timeout, string sql, Dictionary bindings, bool describeOnly, + internal async Task ExecuteAsync(int timeout, string sql, Dictionary bindings, bool describeOnly, bool asyncExec, CancellationToken cancellationToken) { // Trim the sql query and check if this is a PUT/GET command string trimmedSql = TrimSql(sql); - if (IsPutOrGetCommand(trimmedSql)) { + if (IsPutOrGetCommand(trimmedSql)) + { throw new NotImplementedException("Get and Put are not supported in async calls. Use Execute() instead of ExecuteAsync()."); } @@ -287,8 +367,8 @@ internal async Task ExecuteAsync(int timeout, string sql, Dicti logger.Warn("Exception encountered trying to upload binds to stage. Attaching binds in payload instead. {0}", e); } } - - var queryRequest = BuildQueryRequest(sql, bindings, describeOnly); + + var queryRequest = BuildQueryRequest(sql, bindings, describeOnly, asyncExec); try { QueryExecResponse response = null; @@ -309,19 +389,22 @@ internal async Task ExecuteAsync(int timeout, string sql, Dicti var lastResultUrl = response.data?.getResultUrl; - while (RequestInProgress(response) || SessionExpired(response)) + if (!asyncExec) { - var req = BuildResultRequest(lastResultUrl); - response = await _restRequester.GetAsync(req, cancellationToken).ConfigureAwait(false); - - if (SessionExpired(response)) - { - logger.Info("Ping pong request failed with session expired, trying to renew the session."); - await SfSession.renewSessionAsync(cancellationToken).ConfigureAwait(false); - } - else + while (RequestInProgress(response) || SessionExpired(response)) { - lastResultUrl = response.data?.getResultUrl; + var req = BuildResultRequest(lastResultUrl); + response = await _restRequester.GetAsync(req, cancellationToken).ConfigureAwait(false); + + if (SessionExpired(response)) + { + logger.Info("Ping pong request failed with session expired, trying to renew the session."); + await SfSession.renewSessionAsync(cancellationToken).ConfigureAwait(false); + } + else + { + lastResultUrl = response.data?.getResultUrl; + } } } @@ -338,8 +421,8 @@ internal async Task ExecuteAsync(int timeout, string sql, Dicti ClearQueryRequestId(); } } - - internal SFBaseResultSet Execute(int timeout, string sql, Dictionary bindings, bool describeOnly) + + internal SFBaseResultSet Execute(int timeout, string sql, Dictionary bindings, bool describeOnly, bool asyncExec) { // Trim the sql query and check if this is a PUT/GET command string trimmedSql = TrimSql(sql); @@ -347,10 +430,14 @@ internal SFBaseResultSet Execute(int timeout, string sql, Dictionary bindings, bool describeOnly) + private SFBaseResultSet ExecuteSqlOtherThanPutGet(int timeout, string sql, Dictionary bindings, bool describeOnly, bool asyncExec) { try { @@ -437,7 +524,8 @@ private SFBaseResultSet ExecuteSqlOtherThanPutGet(int timeout, string sql, Dicti timeout, sql, bindings, - describeOnly); + describeOnly, + asyncExec); return BuildResultSet(response, CancellationToken.None); } @@ -531,12 +619,13 @@ internal T ExecuteHelper( int timeout, string sql, Dictionary bindings, - bool describeOnly) + bool describeOnly, + bool asyncExec = false) where T : BaseQueryExecResponse where U : IQueryExecResponseData { registerQueryCancellationCallback(timeout, CancellationToken.None); - var queryRequest = BuildQueryRequest(sql, bindings, describeOnly); + var queryRequest = BuildQueryRequest(sql, bindings, describeOnly, asyncExec); try { T response = null; @@ -558,20 +647,24 @@ internal T ExecuteHelper( if (typeof(T) == typeof(QueryExecResponse)) { QueryExecResponse queryResponse = (QueryExecResponse)(object)response; - var lastResultUrl = queryResponse.data?.getResultUrl; - while (RequestInProgress(response) || SessionExpired(response)) + if (!asyncExec) { - var req = BuildResultRequest(lastResultUrl); - response = _restRequester.Get(req); + var lastResultUrl = queryResponse.data?.getResultUrl; - if (SessionExpired(response)) + while (RequestInProgress(response) || SessionExpired(response)) { - logger.Info("Ping pong request failed with session expired, trying to renew the session."); - SfSession.renewSession(); - } - else - { - lastResultUrl = queryResponse.data?.getResultUrl; + var req = BuildResultRequest(lastResultUrl); + response = _restRequester.Get(req); + + if (SessionExpired(response)) + { + logger.Info("Ping pong request failed with session expired, trying to renew the session."); + SfSession.renewSession(); + } + else + { + lastResultUrl = queryResponse.data?.getResultUrl; + } } } } @@ -612,13 +705,14 @@ internal async Task ExecuteAsyncHelper( string sql, Dictionary bindings, bool describeOnly, - CancellationToken cancellationToken + CancellationToken cancellationToken, + bool asyncExec = false ) where T : BaseQueryExecResponse where U : IQueryExecResponseData { registerQueryCancellationCallback(timeout, CancellationToken.None); - var queryRequest = BuildQueryRequest(sql, bindings, describeOnly); + var queryRequest = BuildQueryRequest(sql, bindings, describeOnly, asyncExec); try { T response = null; @@ -640,20 +734,24 @@ CancellationToken cancellationToken if (typeof(T) == typeof(QueryExecResponse)) { QueryExecResponse queryResponse = (QueryExecResponse)(object)response; - var lastResultUrl = queryResponse.data?.getResultUrl; - while (RequestInProgress(response) || SessionExpired(response)) + if (!asyncExec) { - var req = BuildResultRequest(lastResultUrl); - response = await _restRequester.GetAsync(req, cancellationToken).ConfigureAwait(false); + var lastResultUrl = queryResponse.data?.getResultUrl; - if (SessionExpired(response)) + while (RequestInProgress(response) || SessionExpired(response)) { - logger.Info("Ping pong request failed with session expired, trying to renew the session."); - await SfSession.renewSessionAsync(cancellationToken).ConfigureAwait(false); - } - else - { - lastResultUrl = queryResponse.data?.getResultUrl; + var req = BuildResultRequest(lastResultUrl); + response = await _restRequester.GetAsync(req, cancellationToken).ConfigureAwait(false); + + if (SessionExpired(response)) + { + logger.Info("Ping pong request failed with session expired, trying to renew the session."); + await SfSession.renewSessionAsync(cancellationToken).ConfigureAwait(false); + } + else + { + lastResultUrl = queryResponse.data?.getResultUrl; + } } } } @@ -680,6 +778,138 @@ CancellationToken cancellationToken } } + /// + /// Creates a request to get the query status based on query ID. + /// + /// + /// The request to get the query status. + private SFRestRequest BuildQueryStatusRequest(string queryId) + { + var queryUri = SfSession.BuildUri(RestPath.SF_MONITOR_QUERY_PATH + queryId); + + return new SFRestRequest + { + Url = queryUri, + authorizationToken = string.Format(SF_AUTHORIZATION_SNOWFLAKE_FMT, SfSession.sessionToken), + serviceName = SfSession.ParameterMap.ContainsKey(SFSessionParameter.SERVICE_NAME) + ? (String)SfSession.ParameterMap[SFSessionParameter.SERVICE_NAME] : null, + HttpTimeout = Timeout.InfiniteTimeSpan, + RestTimeout = Timeout.InfiniteTimeSpan, + sid = SfSession.sessionId, + _isStatusRequest = true + }; + } + + /// + /// Gets the query status based on query ID. + /// + /// + /// The query status. + internal QueryStatus GetQueryStatus(string queryId) + { + var queryRequest = BuildQueryStatusRequest(queryId); + + try + { + QueryStatusResponse response = null; + bool receivedFirstQueryResponse = false; + while (!receivedFirstQueryResponse) + { + response = _restRequester.Get(queryRequest); + if (SessionExpired(response)) + { + SfSession.renewSession(); + queryRequest.authorizationToken = string.Format(SF_AUTHORIZATION_SNOWFLAKE_FMT, SfSession.sessionToken); + } + else + { + receivedFirstQueryResponse = true; + } + } + + if (!response.success) + { + throw new SnowflakeDbException( + response.data.queries[0].state, + response.code, + response.message, + queryId); + } + + QueryStatus queryStatus = QueryStatus.NoData; + if (response.data.queries.Count != 0) + { + queryStatus = QueryStatusExtensions.GetQueryStatusByStringValue(response.data.queries[0].status); + } + + return queryStatus; + } + catch + { + logger.Error("Query execution failed."); + throw; + } + finally + { + ClearQueryRequestId(); + } + } + + /// + /// Gets the query status based on query ID. + /// + /// + /// The query status. + internal async Task GetQueryStatusAsync(string queryId, CancellationToken cancellationToken) + { + var queryRequest = BuildQueryStatusRequest(queryId); + + try + { + QueryStatusResponse response = null; + bool receivedFirstQueryResponse = false; + while (!receivedFirstQueryResponse) + { + response = await _restRequester.GetAsync(queryRequest, cancellationToken).ConfigureAwait(false); + if (SessionExpired(response)) + { + SfSession.renewSession(); + queryRequest.authorizationToken = string.Format(SF_AUTHORIZATION_SNOWFLAKE_FMT, SfSession.sessionToken); + } + else + { + receivedFirstQueryResponse = true; + } + } + + if (!response.success) + { + throw new SnowflakeDbException( + response.data.queries[0].state, + response.code, + response.message, + queryId); + } + + QueryStatus queryStatus = QueryStatus.NoData; + if (response.data.queries.Count != 0) + { + queryStatus = QueryStatusExtensions.GetQueryStatusByStringValue(response.data.queries[0].status); + } + + return queryStatus; + } + catch + { + logger.Error("Query execution failed."); + throw; + } + finally + { + ClearQueryRequestId(); + } + } + /// /// Trim the query by removing spaces and comments at the beginning. ///