diff --git a/Snowflake.Data.Tests/SFDbCommandAsynchronous.cs b/Snowflake.Data.Tests/SFDbCommandAsynchronous.cs new file mode 100644 index 0000000..e8584b9 --- /dev/null +++ b/Snowflake.Data.Tests/SFDbCommandAsynchronous.cs @@ -0,0 +1,122 @@ +using NUnit.Framework; +using System.Data; + +namespace Tortuga.Data.Snowflake.Tests; + +[TestFixture] +class SFDbCommandAsynchronous : SFBaseTest +{ + SnowflakeDbConnection StartSnowflakeConnection() + { + var conn = new SnowflakeDbConnection(); + conn.ConnectionString = ConnectionString; + + conn.Open(); + + return conn; + } + + [Test] + public void TestLongRunningQuery() + { + string queryId; + using (var conn = StartSnowflakeConnection()) + { + using (var cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + cmd.CommandText = "select count(seq4()) from table(generator(timelimit => 15)) v order by 1"; + var status = cmd.StartAsynchronousQuery(); + Assert.False(status.IsQueryDone); + Assert.False(status.IsQuerySuccessful); + queryId = status.QueryId; + } + + Assert.IsNotEmpty(queryId); + } + + // start a new connection to make sure works across sessions + using (var conn = StartSnowflakeConnection()) + { + SnowflakeDbQueryStatus status; + do + { + status = SnowflakeDbAsynchronousQueryHelper.GetQueryStatus(conn, queryId); + if (status.IsQueryDone) + { + break; + } + else + { + Assert.False(status.IsQuerySuccessful); + } + + Thread.Sleep(5000); + } while (true); + + // once it finished, it should be successfull + Assert.True(status.IsQuerySuccessful); + } + + // start a new connection to make sure works across sessions + using (var conn = StartSnowflakeConnection()) + { + using (var cmd = SnowflakeDbAsynchronousQueryHelper.CreateQueryResultsCommand(conn, queryId)) + { + using (IDataReader reader = cmd.ExecuteReader()) + { + // only one result is returned + Assert.IsTrue(reader.Read()); + } + } + + conn.Close(); + } + } + + [Test] + public void TestSimpleCommand() + { + string queryId; + + using (var conn = StartSnowflakeConnection()) + { + using (var cmd = (SnowflakeDbCommand)conn.CreateCommand()) + { + cmd.CommandText = "select 1"; + + var status = cmd.StartAsynchronousQuery(); + // even a fast asynchronous call will not be done initially + Assert.False(status.IsQueryDone); + Assert.False(status.IsQuerySuccessful); + queryId = status.QueryId; + + Assert.IsNotEmpty(queryId); + } + } + + // start a new connection to make sure works across sessions + using (var conn = StartSnowflakeConnection()) + { + SnowflakeDbQueryStatus status; + status = SnowflakeDbAsynchronousQueryHelper.GetQueryStatus(conn, queryId); + // since query is so fast, expect it to be done the first time we check the status + Assert.True(status.IsQueryDone); + Assert.True(status.IsQuerySuccessful); + } + + // start a new connection to make sure works across sessions + using (var conn = StartSnowflakeConnection()) + { + // because this query is so quick, we do not need to check the status before fetching the result + + using (var cmd = SnowflakeDbAsynchronousQueryHelper.CreateQueryResultsCommand(conn, queryId)) + { + var val = cmd.ExecuteScalar(); + + Assert.AreEqual(1L, (long)val); + } + + conn.Close(); + } + } +} diff --git a/Snowflake.Data.Tests/SFStatementTest.cs b/Snowflake.Data.Tests/SFStatementTest.cs index 7b1413a..c2979b4 100755 --- a/Snowflake.Data.Tests/SFStatementTest.cs +++ b/Snowflake.Data.Tests/SFStatementTest.cs @@ -24,7 +24,7 @@ public void TestSessionRenew() SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester, SnowflakeDbConfiguration.Default); sfSession.Open(); SFStatement statement = new SFStatement(sfSession); - SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); Assert.AreEqual("new_session_token", sfSession.m_SessionToken); @@ -40,7 +40,7 @@ public void TestSessionRenewDuringQueryExec() SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester, SnowflakeDbConfiguration.Default); sfSession.Open(); SFStatement statement = new SFStatement(sfSession); - SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false); + SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); } @@ -60,7 +60,7 @@ public void TestServiceName() for (int i = 0; i < 5; i++) { var statement = new SFStatement(sfSession); - statement.Execute(0, "SELECT 1", null, false); + statement.Execute(0, "SELECT 1", null, false, false); expectServiceName += "a"; Assert.AreEqual(expectServiceName, sfSession.ParameterMap[SFSessionParameter.SERVICE_NAME]); } diff --git a/Snowflake.Data/Core/Messages/QueryRequest.cs b/Snowflake.Data/Core/Messages/QueryRequest.cs index 041ed79..442d84c 100644 --- a/Snowflake.Data/Core/Messages/QueryRequest.cs +++ b/Snowflake.Data/Core/Messages/QueryRequest.cs @@ -16,4 +16,10 @@ class QueryRequest [JsonProperty(PropertyName = "bindings")] internal Dictionary? ParameterBindings { get; set; } + + /// + /// indicates whether query should be asynchronous + /// + [JsonProperty(PropertyName = "asyncExec")] + internal bool asyncExec { get; set; } } diff --git a/Snowflake.Data/Core/RequestProcessing/SFStatement.cs b/Snowflake.Data/Core/RequestProcessing/SFStatement.cs index 062f2ce..6c41f6a 100644 --- a/Snowflake.Data/Core/RequestProcessing/SFStatement.cs +++ b/Snowflake.Data/Core/RequestProcessing/SFStatement.cs @@ -63,7 +63,7 @@ private void ClearQueryRequestId() m_RequestId = null; } - private SFRestRequest BuildQueryRequest(string sql, Dictionary? bindings, bool describeOnly) + private SFRestRequest BuildQueryRequest(string sql, Dictionary? bindings, bool describeOnly, bool asyncExec) { AssignQueryRequestId(); @@ -83,6 +83,7 @@ private SFRestRequest BuildQueryRequest(string sql, Dictionary r?.Code == SF_SESSION_EXPIRED_CODE; - internal async Task ExecuteAsync(int timeout, string sql, Dictionary bindings, bool describeOnly, CancellationToken cancellationToken) + static string BuildQueryResultUrl(string queryId) + { + return $"/queries/{queryId}/result"; + } + + internal async Task CheckQueryStatusAsync(int timeout, string queryId, CancellationToken cancellationToken) { RegisterQueryCancellationCallback(timeout, cancellationToken); - var queryRequest = BuildQueryRequest(sql, bindings, describeOnly); + // rest api + var lastResultUrl = BuildQueryResultUrl(queryId); + //// sql api + //var lastResultUrl = $"/api/statements/{queryId}"; try { QueryExecResponse? response = null; - var receivedFirstQueryResponse = false; + bool receivedFirstQueryResponse = false; while (!receivedFirstQueryResponse) { - response = await m_RestRequester.PostAsync(queryRequest, cancellationToken).ConfigureAwait(false); + var req = BuildResultRequest(lastResultUrl); + response = await m_RestRequester.GetAsync(req, cancellationToken).ConfigureAwait(false); if (SessionExpired(response)) { SFSession.renewSession(); - queryRequest.AuthorizationToken = string.Format(CultureInfo.InvariantCulture, SF_AUTHORIZATION_SNOWFLAKE_FMT, SFSession.m_SessionToken); } else { @@ -188,20 +197,117 @@ internal async Task ExecuteAsync(int timeout, string sql, Dicti } } - var lastResultUrl = response!.Data!.GetResultUrl; + var d = BuildQueryStatusFromQueryResponse(response!); + SFSession.UpdateAsynchronousQueryStatus(queryId, d); + return d; + } + finally + { + CleanUpCancellationTokenSources(); + ClearQueryRequestId(); + } + } + + internal static SnowflakeDbQueryStatus BuildQueryStatusFromQueryResponse(QueryExecResponse response) + { + var isDone = !RequestInProgress(response); + var d = new SnowflakeDbQueryStatus(response.Data!.QueryId! + , isDone + // only consider to be successful if also done + , isDone && response.Success); + return d; + } - while (RequestInProgress(response) || SessionExpired(response)) + /// + /// Fetches the result of a query that has already been executed. + /// + /// + /// + /// + /// + internal async Task GetQueryResultAsync(int timeout, string queryId + , CancellationToken cancellationToken) + { + RegisterQueryCancellationCallback(timeout, cancellationToken); + // rest api + var lastResultUrl = BuildQueryResultUrl(queryId); + try + { + QueryExecResponse? response = null; + + bool receivedFirstQueryResponse = false; + + while (!receivedFirstQueryResponse || RequestInProgress(response) || SessionExpired(response)) { - var req = BuildResultRequest(lastResultUrl!); + var req = BuildResultRequest(lastResultUrl); response = await m_RestRequester.GetAsync(req, cancellationToken).ConfigureAwait(false); + receivedFirstQueryResponse = true; + + if (SessionExpired(response)) + { + SFSession.renewSession(); + } + else + { + lastResultUrl = response.Data?.GetResultUrl!; + } + } + + return BuildResultSet(response!, cancellationToken); + } + finally + { + CleanUpCancellationTokenSources(); + ClearQueryRequestId(); + } + } + internal async Task ExecuteAsync(int timeout, string sql, Dictionary bindings, bool describeOnly, bool asyncExec, CancellationToken cancellationToken) + { + RegisterQueryCancellationCallback(timeout, cancellationToken); + var queryRequest = BuildQueryRequest(sql, bindings, describeOnly, asyncExec); + try + { + QueryExecResponse? response = null; + var receivedFirstQueryResponse = false; + while (!receivedFirstQueryResponse) + { + response = await m_RestRequester.PostAsync(queryRequest, cancellationToken).ConfigureAwait(false); if (SessionExpired(response)) + { SFSession.renewSession(); + queryRequest.AuthorizationToken = string.Format(CultureInfo.InvariantCulture, SF_AUTHORIZATION_SNOWFLAKE_FMT, SFSession.m_SessionToken); + } else - lastResultUrl = response.Data?.GetResultUrl; + { + receivedFirstQueryResponse = true; + } } - return BuildResultSet(response, cancellationToken); + SFBaseResultSet? result = null; + if (!asyncExec) + { + var lastResultUrl = response!.Data!.GetResultUrl; + + while (RequestInProgress(response) || SessionExpired(response)) + { + var req = BuildResultRequest(lastResultUrl!); + response = await m_RestRequester.GetAsync(req, cancellationToken).ConfigureAwait(false); + + if (SessionExpired(response)) + SFSession.renewSession(); + else + lastResultUrl = response.Data?.GetResultUrl; + } + } + else + { + // if this was an asynchronous query, need to track it with the session + result = BuildResultSet(response!, cancellationToken); + var d = BuildQueryStatusFromQueryResponse(response!); + SFSession.AddAsynchronousQueryStatus(result.m_QueryId!, d); + } + return result ?? BuildResultSet(response!, cancellationToken); } finally { @@ -210,7 +316,7 @@ internal async Task ExecuteAsync(int timeout, string sql, Dicti } } - internal SFBaseResultSet Execute(int timeout, string sql, Dictionary bindings, bool describeOnly) + internal SFBaseResultSet Execute(int timeout, string sql, Dictionary bindings, bool describeOnly, bool asyncExec) { // Trim the sql query and check if this is a PUT/GET command var trimmedSql = TrimSql(sql); @@ -233,7 +339,7 @@ internal SFBaseResultSet Execute(int timeout, string sql, Dictionary(req); - - if (SessionExpired(response)) - { - SFSession.renewSession(); - } - else + var lastResultUrl = response?.Data?.GetResultUrl; + while (RequestInProgress(response) || SessionExpired(response)) { - lastResultUrl = response.Data?.GetResultUrl; + var req = BuildResultRequest(lastResultUrl!); + response = m_RestRequester.Get(req); + + if (SessionExpired(response)) + { + SFSession.renewSession(); + } + else + { + lastResultUrl = response.Data?.GetResultUrl; + } } } - return BuildResultSet(response!, CancellationToken.None); + else + { + // if this was an asynchronous query, need to track it with the session + result = BuildResultSet(response!, CancellationToken.None); + var d = BuildQueryStatusFromQueryResponse(response!); + SFSession.AddAsynchronousQueryStatus(result.m_QueryId!, d); + } + + return result ?? BuildResultSet(response!, CancellationToken.None); } } finally @@ -332,7 +450,7 @@ internal T ExecuteHelper(int timeout, string sql, Dictionary /// Move cursor back one row. /// /// True if it works, false otherwise. internal abstract bool Rewind(); - protected SFBaseResultSet(SnowflakeDbConfiguration configuration) + protected SFBaseResultSet(SnowflakeDbConfiguration configuration, SnowflakeDbQueryStatus? queryStatus) { Configuration = configuration; + QueryStatus = queryStatus; } internal T GetValue(int columnIndex) diff --git a/Snowflake.Data/Core/ResponseProcessing/SFResultSet.cs b/Snowflake.Data/Core/ResponseProcessing/SFResultSet.cs index 131f532..fdaed9d 100644 --- a/Snowflake.Data/Core/ResponseProcessing/SFResultSet.cs +++ b/Snowflake.Data/Core/ResponseProcessing/SFResultSet.cs @@ -18,16 +18,23 @@ class SFResultSet : SFBaseResultSet IResultChunk m_CurrentChunk; - public SFResultSet(QueryExecResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base(sfStatement.SFSession.Configuration) + public SFResultSet(QueryExecResponse response, SFStatement sfStatement, CancellationToken cancellationToken) : base(sfStatement.SFSession.Configuration, SFStatement.BuildQueryStatusFromQueryResponse(response)) { - if (responseData.RowType == null) - throw new ArgumentException($"responseData.rowType is null", nameof(responseData)); - if (responseData.RowSet == null) - throw new ArgumentException($"responseData.rowSet is null", nameof(responseData)); + if (response.Data == null) + throw new ArgumentException($"response.Data is null", nameof(response)); - m_ColumnCount = responseData.RowType.Count; + //if (response.responseData.RowType == null) + // throw new ArgumentException($"responseData.rowType is null", nameof(responseData)); + //if (response.responseData.RowSet == null) + // throw new ArgumentException($"responseData.rowSet is null", nameof(responseData)); + + var responseData = response.Data; + // async result will not provide parameters, so need to set + responseData.Parameters = responseData.Parameters ?? new List(); + + m_ColumnCount = responseData.RowType?.Count ?? 0; m_CurrentChunkRowIdx = -1; - m_CurrentChunkRowCount = responseData.RowSet.GetLength(0); + m_CurrentChunkRowCount = responseData.RowSet?.GetLength(0) ?? 0; SFStatement = sfStatement; updateSessionStatus(responseData); @@ -72,7 +79,7 @@ public void initializePutGetRowType(List rowType) } } - public SFResultSet(PutGetResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base(sfStatement.SFSession.Configuration) + public SFResultSet(PutGetResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base(sfStatement.SFSession.Configuration, null) { if (responseData.RowSet == null) throw new ArgumentException($"responseData.rowSet is null", nameof(responseData)); diff --git a/Snowflake.Data/Core/ResponseProcessing/SFResultSetMetaData.cs b/Snowflake.Data/Core/ResponseProcessing/SFResultSetMetaData.cs index f5ffa27..0c7b3ed 100644 --- a/Snowflake.Data/Core/ResponseProcessing/SFResultSetMetaData.cs +++ b/Snowflake.Data/Core/ResponseProcessing/SFResultSetMetaData.cs @@ -20,7 +20,7 @@ class SFResultSetMetaData internal readonly string? m_TimestampeTZOutputFormat; - internal List m_RowTypes; + internal List? m_RowTypes; internal readonly SFStatementType m_StatementType; @@ -33,13 +33,13 @@ class SFResultSetMetaData internal SFResultSetMetaData(QueryExecResponseData queryExecResponseData) { - if (queryExecResponseData.RowType == null) - throw new ArgumentException($"queryExecResponseData.rowType is null", nameof(queryExecResponseData)); + //if (queryExecResponseData.RowType == null) + // throw new ArgumentException($"queryExecResponseData.rowType is null", nameof(queryExecResponseData)); if (queryExecResponseData.Parameters == null) throw new ArgumentException($"queryExecResponseData.parameters is null", nameof(queryExecResponseData)); m_RowTypes = queryExecResponseData.RowType; - m_ColumnCount = m_RowTypes.Count; + m_ColumnCount = m_RowTypes?.Count ?? 0; m_StatementType = FindStatementTypeById(queryExecResponseData.StatementTypeId); m_ColumnTypes = InitColumnTypes(); @@ -71,10 +71,13 @@ internal SFResultSetMetaData(PutGetResponseData putGetResponseData) List> InitColumnTypes() { + if (m_RowTypes == null && m_ColumnCount > 0) + throw new InvalidOperationException($"{nameof(m_RowTypes)} is null"); + var types = new List>(); for (var i = 0; i < m_ColumnCount; i++) { - var column = m_RowTypes[i]; + var column = m_RowTypes![i]; //this is not null if m_ColumnCount >= 1 var dataType = SnowflakeDbDataTypeExtensions.FromSql(column.Type!); var nativeType = GetNativeTypeForColumn(dataType, column); @@ -94,6 +97,9 @@ internal int GetColumnIndexByName(string targetColumnName) } else { + if (m_RowTypes == null) + throw new InvalidOperationException($"{nameof(m_RowTypes)} is null"); + var indexCounter = 0; foreach (var rowType in m_RowTypes) { @@ -166,6 +172,9 @@ internal Type GetCSharpTypeByIndex(int targetIndex) if (targetIndex < 0 || targetIndex >= m_ColumnCount) throw new SnowflakeDbException(SnowflakeDbError.ColumnIndexOutOfBound, targetIndex); + if (m_RowTypes == null) + throw new InvalidOperationException($"{nameof(m_RowTypes)} is null"); + var sfType = GetColumnTypeByIndex(targetIndex); return GetNativeTypeForColumn(sfType, m_RowTypes[targetIndex]); } @@ -175,6 +184,9 @@ internal Type GetCSharpTypeByIndex(int targetIndex) if (targetIndex < 0 || targetIndex >= m_ColumnCount) throw new SnowflakeDbException(SnowflakeDbError.ColumnIndexOutOfBound, targetIndex); + if (m_RowTypes == null) + throw new InvalidOperationException($"{nameof(m_RowTypes)} is null"); + return m_RowTypes[targetIndex].Name; } diff --git a/Snowflake.Data/Core/Sessions/SFSession.cs b/Snowflake.Data/Core/Sessions/SFSession.cs index 23be24d..147aa2e 100644 --- a/Snowflake.Data/Core/Sessions/SFSession.cs +++ b/Snowflake.Data/Core/Sessions/SFSession.cs @@ -390,4 +390,61 @@ static HttpClientHandler SetupCustomHttpHandler(HttpClientConfig config) } return httpHandler; } + + /// + /// Tracks asynchronous queries that were started by the session + /// + ConcurrentDictionary AsynchronousQueryStatuses = new ConcurrentDictionary(); + + /// + /// Updates the status of asynchronous query associated with the session + /// + /// + /// + internal void UpdateAsynchronousQueryStatus(string queryId, SnowflakeDbQueryStatus status) + { + if (AsynchronousQueryStatuses.ContainsKey(queryId)) + { + // only track the status if the query was previously associated with this session + AsynchronousQueryStatuses[queryId] = status; + } + } + + /// + /// Associates an asynchronous query with the session + /// + /// + /// + internal void AddAsynchronousQueryStatus(string queryId, SnowflakeDbQueryStatus status) + { + AsynchronousQueryStatuses[queryId] = status; + } + + /// + /// Function that checks if the active session can be closed when the connection is closed. If + /// there are active asynchronous queries running, the session should stay open even if the + /// connection closes so that the queries can finish running. + /// + /// true if it is safe to close this session, false if not + internal bool IsSafeToClose() + { + foreach (var item in AsynchronousQueryStatuses) + { + if (!item.Value.IsQueryDone) + { + // since the last check of the query indicates it was not done, perform another check + + var sfStatement = new SFStatement(this); + var status = sfStatement.CheckQueryStatusAsync(0, item.Key, CancellationToken.None).Result; + if (!status.IsQueryDone) + { + // query is still not done, so it is not safe to close the session, or else + // the asynchronous query will be cancelled + return false; + } + } + } + + return true; + } } diff --git a/Snowflake.Data/SnowflakeDbAsynchronousQueryHelper.cs b/Snowflake.Data/SnowflakeDbAsynchronousQueryHelper.cs new file mode 100644 index 0000000..1fdb284 --- /dev/null +++ b/Snowflake.Data/SnowflakeDbAsynchronousQueryHelper.cs @@ -0,0 +1,115 @@ +using Tortuga.Data.Snowflake.Core.RequestProcessing; + +namespace Tortuga.Data.Snowflake; + +/// +/// Methods to help perform asynchronous queries. +/// +public static class SnowflakeDbAsynchronousQueryHelper +{ + /// + /// Starts a query asynchronously. + /// + /// + /// The query id. + public static SnowflakeDbQueryStatus StartAsynchronousQuery(SnowflakeDbCommand cmd) + { + if (cmd == null) + throw new ArgumentNullException(nameof(cmd), $"{nameof(cmd)} is null."); + + return cmd.StartAsynchronousQuery(); + } + + /// + /// Starts a query asynchronously. + /// + /// + /// + /// The query id. + public static async Task StartAsynchronousQueryAsync(SnowflakeDbCommand cmd, CancellationToken cancellationToken) + { + if (cmd == null) + throw new ArgumentNullException(nameof(cmd), $"{nameof(cmd)} is null."); + + return await cmd.StartAsynchronousQueryAsync(cancellationToken).ConfigureAwait(false); + } + + // https://docs.snowflake.com/en/sql-reference/functions/result_scan.html + // select * from table(result_scan('query id')); + + // https://docs.snowflake.com/en/sql-reference/functions/query_history.html + // select * + // from table(information_schema.query_history()) + // only returns + + // https://docs.snowflake.com/en/sql-reference/account-usage/query_history.html + // Latency for the view may be up to 45 minutes. + + /// + /// Use to get the status of a query to determine if you can fetch the result. + /// + /// + /// + /// + public static SnowflakeDbQueryStatus GetQueryStatus(SnowflakeDbConnection conn, string queryId) + { + if (conn == null) + throw new ArgumentNullException(nameof(conn), $"{nameof(conn)} is null."); + + if (string.IsNullOrEmpty(queryId)) + throw new ArgumentException($"{nameof(queryId)} is null or empty.", nameof(queryId)); + + return GetQueryStatusAsync(conn, queryId, CancellationToken.None).Result; + } + + /// + /// Use to get the status of a query to determine if you can fetch the result. + /// + /// + /// + /// + /// + public static async Task GetQueryStatusAsync(SnowflakeDbConnection conn, + string queryId, CancellationToken cancellationToken) + { + if (conn == null) + throw new ArgumentNullException(nameof(conn), $"{nameof(conn)} is null."); + + if (string.IsNullOrEmpty(queryId)) + throw new ArgumentException($"{nameof(queryId)} is null or empty.", nameof(queryId)); + + return await GetStatusUsingRestApiAsync(conn, queryId, cancellationToken).ConfigureAwait(false); + } + + private static async Task GetStatusUsingRestApiAsync(SnowflakeDbConnection conn, string queryId, CancellationToken cancellationToken) + { + var sfStatement = new SFStatement(conn.SfSession!); + var r = await sfStatement.CheckQueryStatusAsync(0, queryId, cancellationToken).ConfigureAwait(false); + return r; + } + + /// + /// Can use the resulting to fetch the results of the query. + /// + /// + /// + /// + public static SnowflakeDbCommand CreateQueryResultsCommand(SnowflakeDbConnection conn, string queryId) + { + if (conn == null) + throw new ArgumentNullException(nameof(conn), $"{nameof(conn)} is null."); + + if (string.IsNullOrEmpty(queryId)) + throw new ArgumentException($"{nameof(queryId)} is null or empty.", nameof(queryId)); + + return CreateQueryResultsCommandForRestApi(conn, queryId); + } + + private static SnowflakeDbCommand CreateQueryResultsCommandForRestApi(SnowflakeDbConnection conn, string queryId) + { + var cmd = (SnowflakeDbCommand)conn.CreateCommand(); + cmd.HandleAsyncResponse = true; + cmd.CommandText = queryId; + return cmd; + } +} diff --git a/Snowflake.Data/SnowflakeDbCommand.cs b/Snowflake.Data/SnowflakeDbCommand.cs index 1db7df4..08aa147 100644 --- a/Snowflake.Data/SnowflakeDbCommand.cs +++ b/Snowflake.Data/SnowflakeDbCommand.cs @@ -65,6 +65,35 @@ public override UpdateRowSource UpdatedRowSource set => throw new NotSupportedException($"The {nameof(UpdatedRowSource)} property is not supported."); } + /// + /// When true, will expect the CommandText to have the query id and will get a result from an existing query + /// + internal bool HandleAsyncResponse; + + /// + /// Starts a query asynchronously. + /// + /// The query id. + public SnowflakeDbQueryStatus StartAsynchronousQuery() + { + SFBaseResultSet resultSet = ExecuteInternal(asyncExec: true); + return resultSet.QueryStatus!; + } + + /// + /// Starts a query asynchronously. + /// + /// + /// The query id. + public async Task StartAsynchronousQueryAsync(CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + throw new TaskCanceledException(); + + var resultSet = await ExecuteInternalAsync(cancellationToken, asyncExec: true).ConfigureAwait(false); + return resultSet.QueryStatus!; + } + protected override DbConnection? DbConnection { get => m_Connection; @@ -226,18 +255,39 @@ effectiveValue is not char[] && } } - SFBaseResultSet ExecuteInternal(bool describeOnly = false) + SFBaseResultSet ExecuteInternal(bool describeOnly = false, bool asyncExec = false) { if (CommandText == null) throw new InvalidOperationException($"{nameof(CommandText)} is null"); - return SetStatement().Execute(CommandTimeout, CommandText, ConvertToBindList(), describeOnly); + + SetStatement(); //this will ensure m_SFStatement is not null + + if (HandleAsyncResponse) + { + //TODO-JLA - Don't call .Result + return m_SFStatement!.GetQueryResultAsync(CommandTimeout, CommandText, CancellationToken.None).Result; + } + else + { + return m_SFStatement!.Execute(CommandTimeout, CommandText, ConvertToBindList(), describeOnly, asyncExec: asyncExec); + } } - Task ExecuteInternalAsync(CancellationToken cancellationToken, bool describeOnly = false) + Task ExecuteInternalAsync(CancellationToken cancellationToken, bool describeOnly = false, bool asyncExec = false) { if (CommandText == null) throw new InvalidOperationException($"{nameof(CommandText)} is null"); - return SetStatement().ExecuteAsync(CommandTimeout, CommandText, ConvertToBindList(), describeOnly, cancellationToken); + + SetStatement(); //this will ensure m_SFStatement is not null + + if (HandleAsyncResponse) + { + return m_SFStatement!.GetQueryResultAsync(CommandTimeout, CommandText, cancellationToken); + } + else + { + return m_SFStatement!.ExecuteAsync(CommandTimeout, CommandText, ConvertToBindList(), describeOnly, asyncExec, cancellationToken); + } } SFStatement SetStatement() diff --git a/Snowflake.Data/SnowflakeDbConnection.cs b/Snowflake.Data/SnowflakeDbConnection.cs index c21161a..e7e7e68 100644 --- a/Snowflake.Data/SnowflakeDbConnection.cs +++ b/Snowflake.Data/SnowflakeDbConnection.cs @@ -102,7 +102,7 @@ public async Task ChangeDatabaseAsync(string databaseName) public override void Close() { - if (m_ConnectionState != ConnectionState.Closed && SfSession != null) + if (m_ConnectionState != ConnectionState.Closed && SfSession != null && SfSession.IsSafeToClose()) SfSession.Close(); m_ConnectionState = ConnectionState.Closed; } diff --git a/Snowflake.Data/SnowflakeDbDataReader.cs b/Snowflake.Data/SnowflakeDbDataReader.cs index 1b7fb5a..9319063 100644 --- a/Snowflake.Data/SnowflakeDbDataReader.cs +++ b/Snowflake.Data/SnowflakeDbDataReader.cs @@ -199,6 +199,8 @@ static DataTable PopulateSchemaTable(SFBaseResultSet resultSet) if (resultSet.SFResultSetMetaData == null) throw new ArgumentException($"{nameof(resultSet.SFResultSetMetaData)} is null.", nameof(resultSet)); + if (resultSet.SFResultSetMetaData.m_RowTypes == null) + throw new ArgumentException($"{nameof(resultSet.SFResultSetMetaData.m_RowTypes)} is null.", nameof(resultSet)); var columnOrdinal = 0; var sfResultSetMetaData = resultSet.SFResultSetMetaData; diff --git a/Snowflake.Data/SnowflakeDbQueryStatus.cs b/Snowflake.Data/SnowflakeDbQueryStatus.cs new file mode 100644 index 0000000..9cde106 --- /dev/null +++ b/Snowflake.Data/SnowflakeDbQueryStatus.cs @@ -0,0 +1,29 @@ +namespace Tortuga.Data.Snowflake; + +public class SnowflakeDbQueryStatus +{ + /// + /// When true, indicates that the query has finished for one reason or another, and there is no reason to wait further for + /// the query to finish. If false, the query is still executing, so the result will not be available. + /// + public bool IsQueryDone { get; } + + /// + /// true indicates that the query completely finished running without any issues, so the result is available. false indicates + /// the result is not ready. You need to inspect to determine if the query is still running + /// as opposed to encountering an error. + /// + public bool IsQuerySuccessful { get; } + + /// + /// The id used to track the query in Snowflake. + /// + public string QueryId { get; } + + public SnowflakeDbQueryStatus(string queryId, bool isQueryDone, bool isQuerySuccessful) + { + QueryId = queryId; + IsQueryDone = isQueryDone; + IsQuerySuccessful = isQuerySuccessful; + } +}