Skip to content

Commit

Permalink
SNOW-817091: Add async execution
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-ext-simba-lf committed Mar 9, 2024
1 parent fd193e1 commit 727b460
Show file tree
Hide file tree
Showing 9 changed files with 556 additions and 70 deletions.
190 changes: 185 additions & 5 deletions Snowflake.Data/Client/SnowflakeDbCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,27 @@
using System.Threading.Tasks;
using Newtonsoft.Json;
using Snowflake.Data.Log;
using System.Text.RegularExpressions;

namespace Snowflake.Data.Client
{
[System.ComponentModel.DesignerCategory("Code")]
public class SnowflakeDbCommand : DbCommand
{
private DbConnection connection;
private SnowflakeDbConnection connection;

private SFStatement sfStatement;

private SnowflakeDbParameterCollection parameterCollection;

private SFLogger logger = SFLoggerFactory.GetLogger<SnowflakeDbCommand>();

// Async max retry and retry pattern
private const int AsyncNoDataMaxRetry = 24;
private readonly int[] _asyncRetryPattern = { 1, 1, 2, 3, 4, 8, 10 };

private static readonly Regex UuidRegex = new Regex("^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$");

public SnowflakeDbCommand()
{
logger.Debug("Constructing SnowflakeDbCommand class");
Expand Down Expand Up @@ -274,6 +281,179 @@ protected override async Task<DbDataReader> ExecuteDbDataReaderAsync(CommandBeha
}
}

/// <summary>
/// Execute a query in async mode.
/// </summary>
/// <returns>The query id.</returns>
public string ExecuteInAsyncMode()
{
logger.Debug($"ExecuteInAsyncMode");
SFBaseResultSet resultSet = ExecuteInternal(asyncExec: true);
return resultSet.queryId;
}

/// <summary>
/// Executes an asynchronous query in async mode.
/// </summary>
/// <param name="cancellationToken"></param>
/// <returns>The query id.</returns>
public async Task<string> ExecuteAsyncInAsyncMode(CancellationToken cancellationToken)
{
logger.Debug($"ExecuteAsyncInAsyncMode");
var resultSet = await ExecuteInternalAsync(cancellationToken, asyncExec: true).ConfigureAwait(false);
return resultSet.queryId;
}

/// <summary>
/// Gets the query status based on query ID.
/// </summary>
/// <param name="queryId"></param>
/// <returns>The query status.</returns>
public QueryStatus GetQueryStatus(string queryId)
{
logger.Debug($"GetQueryStatus");

if (UuidRegex.IsMatch(queryId))
{
var sfStatement = new SFStatement(connection.SfSession);
return sfStatement.GetQueryStatus(queryId);
}
else
{
var errorMessage = $"The given query id {queryId} is not valid uuid";
logger.Error(errorMessage);
throw new Exception(errorMessage);
}

}

/// <summary>
/// Gets the query status based on query ID.
/// </summary>
/// <param name="queryId"></param>
/// <param name="cancellationToken"></param>
/// <returns>The query status.</returns>
public async Task<QueryStatus> GetQueryStatusAsync(string queryId, CancellationToken cancellationToken)
{
logger.Debug($"GetQueryStatusAsync");

// Check if queryId is valid uuid
if (UuidRegex.IsMatch(queryId))
{
var sfStatement = new SFStatement(connection.SfSession);
return await sfStatement.GetQueryStatusAsync(queryId, cancellationToken).ConfigureAwait(false);
}
else
{
var errorMessage = $"The given query id {queryId} is not valid uuid";
logger.Error(errorMessage);
throw new Exception(errorMessage);
}
}

/// <summary>
/// Gets the query results based on query ID.
/// </summary>
/// <param name="queryId"></param>
/// <returns>The query results.</returns>
public DbDataReader GetResultsFromQueryId(string queryId)
{
logger.Debug($"GetResultsFromQueryId");

int retryPatternPos = 0;
int noDataCounter = 0;

QueryStatus status;
while (true)
{
status = GetQueryStatus(queryId);

if (!QueryStatuses.IsStillRunning(status))
{
break;
}

// Timeout based on query status retry rules
Thread.Sleep(TimeSpan.FromSeconds(_asyncRetryPattern[retryPatternPos]));

// If no data, increment the no data counter
if (status == QueryStatus.NO_DATA)
{
noDataCounter++;

// Check if retry for no data is exceeded
if (noDataCounter > AsyncNoDataMaxRetry)
{
var errorMessage = "Max retry for no data is reached";
logger.Error(errorMessage);
throw new Exception(errorMessage);
}
}

if (retryPatternPos < _asyncRetryPattern.Length - 1)
{
retryPatternPos++;
}
}

connection.SfSession.AsyncQueries.Remove(queryId);
SFBaseResultSet resultSet = sfStatement.GetResultWithId(queryId);

return new SnowflakeDbDataReader(this, resultSet);
}

/// <summary>
/// Gets the query results based on query ID.
/// </summary>
/// <param name="queryId"></param>
/// <param name="cancellationToken"></param>
/// <returns>The query results.</returns>
public async Task<DbDataReader> GetResultsFromQueryIdAsync(string queryId, CancellationToken cancellationToken)
{
logger.Debug($"GetResultsFromQueryIdAsync");

int retryPatternPos = 0;
int noDataCounter = 0;

QueryStatus status;
while (true)
{
status = GetQueryStatus(queryId);

if (!QueryStatuses.IsStillRunning(status))
{
break;
}

// Timeout based on query status retry rules
await Task.Delay(TimeSpan.FromSeconds(_asyncRetryPattern[retryPatternPos]), cancellationToken).ConfigureAwait(false);

// If no data, increment the no data counter
if (status == QueryStatus.NO_DATA)
{
noDataCounter++;

// Check if retry for no data is exceeded
if (noDataCounter > AsyncNoDataMaxRetry)
{
var errorMessage = "Max retry for no data is reached";
logger.Error(errorMessage);
throw new Exception(errorMessage);
}
}

if (retryPatternPos < _asyncRetryPattern.Length - 1)
{
retryPatternPos++;
}
}

connection.SfSession.AsyncQueries.Remove(queryId);
SFBaseResultSet resultSet = await sfStatement.GetResultWithIdAsync(queryId, cancellationToken).ConfigureAwait(false);

return new SnowflakeDbDataReader(this, resultSet);
}

private static Dictionary<string, BindingDTO> convertToBindList(List<SnowflakeDbParameter> parameters)
{
if (parameters == null || parameters.Count == 0)
Expand Down Expand Up @@ -354,18 +534,18 @@ private void SetStatement()
this.sfStatement = new SFStatement(session);
}

private SFBaseResultSet ExecuteInternal(bool describeOnly = false)
private SFBaseResultSet ExecuteInternal(bool describeOnly = false, bool asyncExec = false)
{
CheckIfCommandTextIsSet();
SetStatement();
return sfStatement.Execute(CommandTimeout, CommandText, convertToBindList(parameterCollection.parameterList), describeOnly);
return sfStatement.Execute(CommandTimeout, CommandText, convertToBindList(parameterCollection.parameterList), describeOnly, asyncExec);
}

private Task<SFBaseResultSet> ExecuteInternalAsync(CancellationToken cancellationToken, bool describeOnly = false)
private Task<SFBaseResultSet> ExecuteInternalAsync(CancellationToken cancellationToken, bool describeOnly = false, bool asyncExec = false)
{
CheckIfCommandTextIsSet();
SetStatement();
return sfStatement.ExecuteAsync(CommandTimeout, CommandText, convertToBindList(parameterCollection.parameterList), describeOnly, cancellationToken);
return sfStatement.ExecuteAsync(CommandTimeout, CommandText, convertToBindList(parameterCollection.parameterList), describeOnly, asyncExec, cancellationToken);
}

private void CheckIfCommandTextIsSet()
Expand Down
17 changes: 11 additions & 6 deletions Snowflake.Data/Client/SnowflakeDbConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,10 @@ public override void Close()
{
var transactionRollbackStatus = SnowflakeDbConnectionPool.GetPooling() ? TerminateTransactionForDirtyConnectionReturningToPool() : TransactionRollbackStatus.Undefined;

if (CanReuseSession(transactionRollbackStatus) && SnowflakeDbConnectionPool.AddSession(SfSession))
if (CanReuseSession(transactionRollbackStatus) &&
SfSession.StillRunningAsyncQueries() &&
SnowflakeDbConnectionPool.AddSession(SfSession)
)
{
logger.Debug($"Session pooled: {SfSession.sessionId}");
}
Expand All @@ -176,7 +179,7 @@ public override async Task CloseAsync()
}
#endif

public virtual Task CloseAsync(CancellationToken cancellationToken)
public virtual async Task CloseAsync(CancellationToken cancellationToken)
{
logger.Debug("Close Connection.");
TaskCompletionSource<object> taskCompletionSource = new TaskCompletionSource<object>();
Expand All @@ -191,15 +194,17 @@ public virtual Task CloseAsync(CancellationToken cancellationToken)
{
var transactionRollbackStatus = SnowflakeDbConnectionPool.GetPooling() ? TerminateTransactionForDirtyConnectionReturningToPool() : TransactionRollbackStatus.Undefined;

if (CanReuseSession(transactionRollbackStatus) && SnowflakeDbConnectionPool.AddSession(SfSession))
if (CanReuseSession(transactionRollbackStatus) &&
await SfSession.StillRunningAsyncQueriesAsync(cancellationToken).ConfigureAwait(false) &&
SnowflakeDbConnectionPool.AddSession(SfSession))
{
logger.Debug($"Session pooled: {SfSession.sessionId}");
_connectionState = ConnectionState.Closed;
taskCompletionSource.SetResult(null);
}
else
{
SfSession.CloseAsync(cancellationToken).ContinueWith(
await SfSession.CloseAsync(cancellationToken).ContinueWith(
previousTask =>
{
if (previousTask.IsFaulted)
Expand All @@ -220,7 +225,7 @@ public virtual Task CloseAsync(CancellationToken cancellationToken)
_connectionState = ConnectionState.Closed;
taskCompletionSource.SetResult(null);
}
}, cancellationToken);
}, cancellationToken).ConfigureAwait(false);
}
}
else
Expand All @@ -229,7 +234,7 @@ public virtual Task CloseAsync(CancellationToken cancellationToken)
taskCompletionSource.SetResult(null);
}
}
return taskCompletionSource.Task;
await taskCompletionSource.Task;
}

protected virtual bool CanReuseSession(TransactionRollbackStatus transactionRollbackStatus)
Expand Down
2 changes: 2 additions & 0 deletions Snowflake.Data/Core/RestParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ internal static class RestPath

internal const string SF_QUERY_PATH = "/queries/v1/query-request";

internal const string SF_MONITOR_QUERY_PATH = "/monitoring/queries/";

internal const string SF_SESSION_HEARTBEAT_PATH = SF_SESSION_PATH + "/heartbeat";

internal const string SF_CONSOLE_LOGIN = "/console/login";
Expand Down
7 changes: 6 additions & 1 deletion Snowflake.Data/Core/RestRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ internal SFRestRequest() : base()

internal bool _isLogin { get; set; }

internal bool _isStatusRequest { get; set; }

public override string ToString()
{
return String.Format("SFRestRequest {{url: {0}, request body: {1} }}", Url.ToString(),
Expand All @@ -154,7 +156,7 @@ HttpRequestMessage IRestRequest.ToRequestMessage(HttpMethod method)
// add quote otherwise it would be reported as error format
string osInfo = "(" + SFEnvironment.ClientEnv.osVersion + ")";

if (isPutGet)
if (isPutGet || _isStatusRequest)
{
message.Headers.Accept.Add(applicationJson);
}
Expand Down Expand Up @@ -313,6 +315,9 @@ class QueryRequest

[JsonProperty(PropertyName = "queryContextDTO", NullValueHandling = NullValueHandling.Ignore)]
internal RequestQueryContext QueryContextDTO { get; set; }

[JsonProperty(PropertyName = "asyncExec")]
internal bool asyncExec { get; set; }
}

// The query context in query response
Expand Down
31 changes: 31 additions & 0 deletions Snowflake.Data/Core/RestResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,37 @@ internal class PutGetEncryptionMaterial
internal long smkId { get; set; }
}

internal class QueryStatusResponse : BaseRestResponse
{

[JsonProperty(PropertyName = "data")]
internal QueryStatusData data { get; set; }
}

internal class QueryStatusData
{
[JsonProperty(PropertyName = "queries", NullValueHandling = NullValueHandling.Ignore)]
internal List<QueryStatusDataQueries> queries { get; set; }
}

internal class QueryStatusDataQueries
{
[JsonProperty(PropertyName = "id", NullValueHandling = NullValueHandling.Ignore)]
internal string id { get; set; }

[JsonProperty(PropertyName = "status", NullValueHandling = NullValueHandling.Ignore)]
internal string status { get; set; }

[JsonProperty(PropertyName = "state", NullValueHandling = NullValueHandling.Ignore)]
internal string state { get; set; }

[JsonProperty(PropertyName = "errorCode", NullValueHandling = NullValueHandling.Ignore)]
internal string errorCode { get; set; }

[JsonProperty(PropertyName = "errorMessage", NullValueHandling = NullValueHandling.Ignore)]
internal string errorMessage { get; set; }
}

// Retrieved from: https://stackoverflow.com/a/18997172
internal class SingleOrArrayConverter<T> : JsonConverter
{
Expand Down
4 changes: 2 additions & 2 deletions Snowflake.Data/Core/SFBindUploader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ private void CreateStage()
try
{
SFStatement statement = new SFStatement(session);
SFBaseResultSet resultSet = statement.Execute(0, CREATE_STAGE_STMT, null, false);
SFBaseResultSet resultSet = statement.Execute(0, CREATE_STAGE_STMT, null, false, false);
session.SetArrayBindStage(STAGE_NAME);
}
catch (Exception e)
Expand All @@ -314,7 +314,7 @@ internal async Task CreateStageAsync(CancellationToken cancellationToken)
try
{
SFStatement statement = new SFStatement(session);
var resultSet = await statement.ExecuteAsync(0, CREATE_STAGE_STMT, null, false, cancellationToken).ConfigureAwait(false);
var resultSet = await statement.ExecuteAsync(0, CREATE_STAGE_STMT, null, false, false, cancellationToken).ConfigureAwait(false);
session.SetArrayBindStage(STAGE_NAME);
}
catch (Exception e)
Expand Down
Loading

0 comments on commit 727b460

Please sign in to comment.