Skip to content

Commit

Permalink
SNOW-817091: Async execution (#887)
Browse files Browse the repository at this point in the history
### Description
- Adds the option to execute queries in async mode
- Adds the capability to asynchronously wait for the query to finish and
get the results using the query ID
- Adds checking the query status
- Adds checking if query is still running or encountered an error

### Checklist
- [ ] Code compiles correctly
- [ ] Code is formatted according to [Coding
Conventions](../CodingConventions.md)
- [ ] Created tests which fail without the change (if possible)
- [ ] All tests passing (`dotnet test`)
- [ ] Extended the README / documentation, if necessary
- [ ] Provide JIRA issue id (if possible) or GitHub issue id in PR name
  • Loading branch information
sfc-gh-ext-simba-lf authored Mar 26, 2024
1 parent 8fa5df5 commit 5aec6b4
Show file tree
Hide file tree
Showing 12 changed files with 1,247 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ void ThreadProcess2(string connstr)

Thread.Sleep(5000);
SFStatement statement = new SFStatement(conn1.SfSession);
SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false);
SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false, false);
Assert.AreEqual(true, resultSet.Next());
Assert.AreEqual("1", resultSet.GetString(0));
SnowflakeDbConnectionPool.ClearAllPools();
Expand Down
584 changes: 584 additions & 0 deletions Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs

Large diffs are not rendered by default.

95 changes: 88 additions & 7 deletions Snowflake.Data.Tests/UnitTests/SFStatementTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace Snowflake.Data.Tests.UnitTests
{
using Snowflake.Data.Core;
using NUnit.Framework;
using System;

/**
* Mock rest request test
Expand All @@ -21,7 +22,7 @@ public void TestSessionRenew()
SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester);
sfSession.Open();
SFStatement statement = new SFStatement(sfSession);
SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false);
SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false, false);
Assert.AreEqual(true, resultSet.Next());
Assert.AreEqual("1", resultSet.GetString(0));
Assert.AreEqual("new_session_token", sfSession.sessionToken);
Expand All @@ -37,7 +38,7 @@ public void TestSessionRenewDuringQueryExec()
SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester);
sfSession.Open();
SFStatement statement = new SFStatement(sfSession);
SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false);
SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false, false);
Assert.AreEqual(true, resultSet.Next());
Assert.AreEqual("1", resultSet.GetString(0));
}
Expand All @@ -57,7 +58,7 @@ public void TestServiceName()
for (int i = 0; i < 5; i++)
{
SFStatement statement = new SFStatement(sfSession);
SFBaseResultSet resultSet = statement.Execute(0, "SELECT 1", null, false);
SFBaseResultSet resultSet = statement.Execute(0, "SELECT 1", null, false, false);
expectServiceName += "a";
Assert.AreEqual(expectServiceName, sfSession.ParameterMap[SFSessionParameter.SERVICE_NAME]);
}
Expand All @@ -73,7 +74,7 @@ public void TestTrimSqlBlockComment()
SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester);
sfSession.Open();
SFStatement statement = new SFStatement(sfSession);
SFBaseResultSet resultSet = statement.Execute(0, "/*comment*/select 1/*comment*/", null, false);
SFBaseResultSet resultSet = statement.Execute(0, "/*comment*/select 1/*comment*/", null, false, false);
Assert.AreEqual(true, resultSet.Next());
Assert.AreEqual("1", resultSet.GetString(0));
}
Expand All @@ -88,7 +89,7 @@ public void TestTrimSqlBlockCommentMultiline()
SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester);
sfSession.Open();
SFStatement statement = new SFStatement(sfSession);
SFBaseResultSet resultSet = statement.Execute(0, "/*comment\r\ncomment*/select 1/*comment\r\ncomment*/", null, false);
SFBaseResultSet resultSet = statement.Execute(0, "/*comment\r\ncomment*/select 1/*comment\r\ncomment*/", null, false, false);
Assert.AreEqual(true, resultSet.Next());
Assert.AreEqual("1", resultSet.GetString(0));
}
Expand All @@ -103,7 +104,7 @@ public void TestTrimSqlLineComment()
SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester);
sfSession.Open();
SFStatement statement = new SFStatement(sfSession);
SFBaseResultSet resultSet = statement.Execute(0, "--comment\r\nselect 1\r\n--comment", null, false);
SFBaseResultSet resultSet = statement.Execute(0, "--comment\r\nselect 1\r\n--comment", null, false, false);
Assert.AreEqual(true, resultSet.Next());
Assert.AreEqual("1", resultSet.GetString(0));
}
Expand All @@ -118,9 +119,89 @@ public void TestTrimSqlLineCommentWithClosingNewline()
SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester);
sfSession.Open();
SFStatement statement = new SFStatement(sfSession);
SFBaseResultSet resultSet = statement.Execute(0, "--comment\r\nselect 1\r\n--comment\r\n", null, false);
SFBaseResultSet resultSet = statement.Execute(0, "--comment\r\nselect 1\r\n--comment\r\n", null, false, false);
Assert.AreEqual(true, resultSet.Next());
Assert.AreEqual("1", resultSet.GetString(0));
}

[Test]
[TestCase("running", QueryStatus.Running)]
[TestCase("RUNNING", QueryStatus.Running)]
[TestCase("resuming_warehouse", QueryStatus.ResumingWarehouse)]
[TestCase("RESUMING_WAREHOUSE", QueryStatus.ResumingWarehouse)]
[TestCase("queued", QueryStatus.Queued)]
[TestCase("QUEUED", QueryStatus.Queued)]
[TestCase("queued_reparing_warehouse", QueryStatus.QueuedReparingWarehouse)]
[TestCase("QUEUED_REPARING_WAREHOUSE", QueryStatus.QueuedReparingWarehouse)]
[TestCase("no_data", QueryStatus.NoData)]
[TestCase("NO_DATA", QueryStatus.NoData)]
[TestCase("aborting", QueryStatus.Aborting)]
[TestCase("ABORTING", QueryStatus.Aborting)]
[TestCase("success", QueryStatus.Success)]
[TestCase("SUCCESS", QueryStatus.Success)]
[TestCase("failed_with_error", QueryStatus.FailedWithError)]
[TestCase("FAILED_WITH_ERROR", QueryStatus.FailedWithError)]
[TestCase("aborted", QueryStatus.Aborted)]
[TestCase("ABORTED", QueryStatus.Aborted)]
[TestCase("failed_with_incident", QueryStatus.FailedWithIncident)]
[TestCase("FAILED_WITH_INCIDENT", QueryStatus.FailedWithIncident)]
[TestCase("disconnected", QueryStatus.Disconnected)]
[TestCase("DISCONNECTED", QueryStatus.Disconnected)]
[TestCase("restarted", QueryStatus.Restarted)]
[TestCase("RESTARTED", QueryStatus.Restarted)]
[TestCase("blocked", QueryStatus.Blocked)]
[TestCase("BLOCKED", QueryStatus.Blocked)]
public void TestGetQueryStatusByStringValue(string stringValue, QueryStatus expectedStatus)
{
Assert.AreEqual(expectedStatus, QueryStatusExtensions.GetQueryStatusByStringValue(stringValue));
}

[Test]
[TestCase("UNKNOWN")]
[TestCase("RANDOM_STATUS")]
[TestCase("aBcZyX")]
public void TestGetQueryStatusByStringValueThrowsErrorForUnknownStatus(string stringValue)
{
var thrown = Assert.Throws<Exception>(() => QueryStatusExtensions.GetQueryStatusByStringValue(stringValue));
Assert.IsTrue(thrown.Message.Contains("The query status returned by the server is not recognized"));
}

[Test]
[TestCase(QueryStatus.Running, true)]
[TestCase(QueryStatus.ResumingWarehouse, true)]
[TestCase(QueryStatus.Queued, true)]
[TestCase(QueryStatus.QueuedReparingWarehouse, true)]
[TestCase(QueryStatus.NoData, true)]
[TestCase(QueryStatus.Aborting, false)]
[TestCase(QueryStatus.Success, false)]
[TestCase(QueryStatus.FailedWithError, false)]
[TestCase(QueryStatus.Aborted, false)]
[TestCase(QueryStatus.FailedWithIncident, false)]
[TestCase(QueryStatus.Disconnected, false)]
[TestCase(QueryStatus.Restarted, false)]
[TestCase(QueryStatus.Blocked, false)]
public void TestIsStillRunning(QueryStatus status, bool expectedResult)
{
Assert.AreEqual(expectedResult, QueryStatusExtensions.IsStillRunning(status));
}

[Test]
[TestCase(QueryStatus.Aborting, true)]
[TestCase(QueryStatus.FailedWithError, true)]
[TestCase(QueryStatus.Aborted, true)]
[TestCase(QueryStatus.FailedWithIncident, true)]
[TestCase(QueryStatus.Disconnected, true)]
[TestCase(QueryStatus.Blocked, true)]
[TestCase(QueryStatus.Running, false)]
[TestCase(QueryStatus.ResumingWarehouse, false)]
[TestCase(QueryStatus.Queued, false)]
[TestCase(QueryStatus.QueuedReparingWarehouse, false)]
[TestCase(QueryStatus.NoData, false)]
[TestCase(QueryStatus.Success, false)]
[TestCase(QueryStatus.Restarted, false)]
public void TestIsAnError(QueryStatus status, bool expectedResult)
{
Assert.AreEqual(expectedResult, QueryStatusExtensions.IsAnError(status));
}
}
}
95 changes: 89 additions & 6 deletions Snowflake.Data/Client/SnowflakeDbCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,23 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Newtonsoft.Json;
using Snowflake.Data.Log;

namespace Snowflake.Data.Client
{
[System.ComponentModel.DesignerCategory("Code")]
public class SnowflakeDbCommand : DbCommand
{
private DbConnection connection;
private SnowflakeDbConnection connection;

private SFStatement sfStatement;

private SnowflakeDbParameterCollection parameterCollection;

private SFLogger logger = SFLoggerFactory.GetLogger<SnowflakeDbCommand>();

private readonly QueryResultsAwaiter _queryResultsAwaiter = QueryResultsAwaiter.Instance;

public SnowflakeDbCommand()
{
logger.Debug("Constructing SnowflakeDbCommand class");
Expand Down Expand Up @@ -274,6 +275,88 @@ protected override async Task<DbDataReader> ExecuteDbDataReaderAsync(CommandBeha
}
}

/// <summary>
/// Execute a query in async mode.
/// Async mode means the server will respond immediately with the query ID and execute the query asynchronously
/// </summary>
/// <returns>The query id.</returns>
public string ExecuteInAsyncMode()
{
logger.Debug($"ExecuteInAsyncMode");
SFBaseResultSet resultSet = ExecuteInternal(asyncExec: true);
return resultSet.queryId;
}

/// <summary>
/// Executes an asynchronous query in async mode.
/// Async mode means the server will respond immediately with the query ID and execute the query asynchronously
/// </summary>
/// <param name="cancellationToken"></param>
/// <returns>The query id.</returns>
public async Task<string> ExecuteAsyncInAsyncMode(CancellationToken cancellationToken)
{
logger.Debug($"ExecuteAsyncInAsyncMode");
var resultSet = await ExecuteInternalAsync(cancellationToken, asyncExec: true).ConfigureAwait(false);
return resultSet.queryId;
}

/// <summary>
/// Gets the query status based on query ID.
/// </summary>
/// <param name="queryId"></param>
/// <returns>The query status.</returns>
public QueryStatus GetQueryStatus(string queryId)
{
logger.Debug($"GetQueryStatus");
return _queryResultsAwaiter.GetQueryStatus(connection, queryId);
}

/// <summary>
/// Gets the query status based on query ID.
/// </summary>
/// <param name="queryId"></param>
/// <param name="cancellationToken"></param>
/// <returns>The query status.</returns>
public async Task<QueryStatus> GetQueryStatusAsync(string queryId, CancellationToken cancellationToken)
{
logger.Debug($"GetQueryStatusAsync");
return await _queryResultsAwaiter.GetQueryStatusAsync(connection, queryId, cancellationToken);
}

/// <summary>
/// Gets the query results based on query ID.
/// </summary>
/// <param name="queryId"></param>
/// <returns>The query results.</returns>
public DbDataReader GetResultsFromQueryId(string queryId)
{
logger.Debug($"GetResultsFromQueryId");

Task task = _queryResultsAwaiter.RetryUntilQueryResultIsAvailable(connection, queryId, CancellationToken.None, false);
task.Wait();

SFBaseResultSet resultSet = sfStatement.GetResultWithId(queryId);

return new SnowflakeDbDataReader(this, resultSet);
}

/// <summary>
/// Gets the query results based on query ID.
/// </summary>
/// <param name="queryId"></param>
/// <param name="cancellationToken"></param>
/// <returns>The query results.</returns>
public async Task<DbDataReader> GetResultsFromQueryIdAsync(string queryId, CancellationToken cancellationToken)
{
logger.Debug($"GetResultsFromQueryIdAsync");

await _queryResultsAwaiter.RetryUntilQueryResultIsAvailable(connection, queryId, cancellationToken, true);

SFBaseResultSet resultSet = await sfStatement.GetResultWithIdAsync(queryId, cancellationToken).ConfigureAwait(false);

return new SnowflakeDbDataReader(this, resultSet);
}

private static Dictionary<string, BindingDTO> convertToBindList(List<SnowflakeDbParameter> parameters)
{
if (parameters == null || parameters.Count == 0)
Expand Down Expand Up @@ -354,18 +437,18 @@ private void SetStatement()
this.sfStatement = new SFStatement(session);
}

private SFBaseResultSet ExecuteInternal(bool describeOnly = false)
private SFBaseResultSet ExecuteInternal(bool describeOnly = false, bool asyncExec = false)
{
CheckIfCommandTextIsSet();
SetStatement();
return sfStatement.Execute(CommandTimeout, CommandText, convertToBindList(parameterCollection.parameterList), describeOnly);
return sfStatement.Execute(CommandTimeout, CommandText, convertToBindList(parameterCollection.parameterList), describeOnly, asyncExec);
}

private Task<SFBaseResultSet> ExecuteInternalAsync(CancellationToken cancellationToken, bool describeOnly = false)
private Task<SFBaseResultSet> ExecuteInternalAsync(CancellationToken cancellationToken, bool describeOnly = false, bool asyncExec = false)
{
CheckIfCommandTextIsSet();
SetStatement();
return sfStatement.ExecuteAsync(CommandTimeout, CommandText, convertToBindList(parameterCollection.parameterList), describeOnly, cancellationToken);
return sfStatement.ExecuteAsync(CommandTimeout, CommandText, convertToBindList(parameterCollection.parameterList), describeOnly, asyncExec, cancellationToken);
}

private void CheckIfCommandTextIsSet()
Expand Down
18 changes: 14 additions & 4 deletions Snowflake.Data/Client/SnowflakeDbConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ public override async Task CloseAsync()
}
#endif

public virtual Task CloseAsync(CancellationToken cancellationToken)
public virtual async Task CloseAsync(CancellationToken cancellationToken)
{
logger.Debug("Close Connection.");
TaskCompletionSource<object> taskCompletionSource = new TaskCompletionSource<object>();
Expand All @@ -199,7 +199,7 @@ public virtual Task CloseAsync(CancellationToken cancellationToken)
}
else
{
SfSession.CloseAsync(cancellationToken).ContinueWith(
await SfSession.CloseAsync(cancellationToken).ContinueWith(
previousTask =>
{
if (previousTask.IsFaulted)
Expand All @@ -220,7 +220,7 @@ public virtual Task CloseAsync(CancellationToken cancellationToken)
_connectionState = ConnectionState.Closed;
taskCompletionSource.SetResult(null);
}
}, cancellationToken);
}, cancellationToken).ConfigureAwait(false);
}
}
else
Expand All @@ -229,7 +229,7 @@ public virtual Task CloseAsync(CancellationToken cancellationToken)
taskCompletionSource.SetResult(null);
}
}
return taskCompletionSource.Task;
await taskCompletionSource.Task;
}

protected virtual bool CanReuseSession(TransactionRollbackStatus transactionRollbackStatus)
Expand Down Expand Up @@ -402,6 +402,16 @@ internal void registerConnectionCancellationCallback(CancellationToken externalC
}
}

public bool IsStillRunning(QueryStatus status)
{
return QueryStatusExtensions.IsStillRunning(status);
}

public bool IsAnError(QueryStatus status)
{
return QueryStatusExtensions.IsAnError(status);
}

~SnowflakeDbConnection()
{
Dispose(false);
Expand Down
Loading

0 comments on commit 5aec6b4

Please sign in to comment.