Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1433638 sql trimming only for PUT/GET detection #957

Merged
merged 1 commit into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public void TestCancelExecuteAsync()
}
catch
{
// assert that cancel is not triggered by timeout, but external cancellation
// assert that cancel is not triggered by timeout, but external cancellation
Assert.IsTrue(externalCancel.IsCancellationRequested);
}
Thread.Sleep(2000);
Expand Down Expand Up @@ -503,7 +503,7 @@ public void TestRowsAffectedOverflowInt()
using (IDbConnection conn = new SnowflakeDbConnection(ConnectionString))
{
conn.Open();

CreateOrReplaceTable(conn, TableName, new []{"c1 NUMBER"});

using (IDbCommand command = conn.CreateCommand())
Expand Down Expand Up @@ -608,7 +608,7 @@ public void TestSimpleLargeResultSet()
conn.Close();
}
}

[Test, NonParallelizable]
public void TestUseV1ResultParser()
{
Expand Down Expand Up @@ -1021,13 +1021,13 @@ public void testPutArrayBindAsync()

private void ArrayBindTest(string connstr, string tableName, int size)
{

CancellationTokenSource externalCancel = new CancellationTokenSource(TimeSpan.FromSeconds(100));
using (DbConnection conn = new SnowflakeDbConnection())
{
conn.ConnectionString = connstr;
conn.Open();

CreateOrReplaceTable(conn, tableName, new []
{
"cola INTEGER",
Expand Down Expand Up @@ -1197,7 +1197,7 @@ public void testExecuteScalarAsyncSelect()
{
conn.ConnectionString = ConnectionString;
conn.Open();

CreateOrReplaceTable(conn, TableName, new []{"cola INTEGER"});

using (DbCommand cmd = conn.CreateCommand())
Expand Down Expand Up @@ -1624,7 +1624,7 @@ public void TestGetResultsOfUnknownQueryIdWithConfiguredRetry()
conn.Close();
}
}

[Test]
public void TestSetQueryTagOverridesConnectionString()
{
Expand All @@ -1633,16 +1633,48 @@ public void TestSetQueryTagOverridesConnectionString()
string expectedQueryTag = "Test QUERY_TAG 12345";
string connectQueryTag = "Test 123";
conn.ConnectionString = ConnectionString + $";query_tag={connectQueryTag}";

conn.Open();
var command = conn.CreateCommand();
((SnowflakeDbCommand)command).QueryTag = expectedQueryTag;
// This query itself will be part of the history and will have the query tag
command.CommandText = "SELECT QUERY_TAG FROM table(information_schema.query_history_by_session())";
var queryTag = command.ExecuteScalar();

Assert.AreEqual(expectedQueryTag, queryTag);
}
}

[Test]
public void TestCommandWithCommentEmbedded()
{
using (var conn = new SnowflakeDbConnection(ConnectionString))
{
conn.Open();
var command = conn.CreateCommand();

command.CommandText = "\r\nselect '--'\r\n";
var reader = command.ExecuteReader();

Assert.IsTrue(reader.Read());
Assert.AreEqual("--", reader.GetString(0));
}
}

[Test]
public async Task TestCommandWithCommentEmbeddedAsync()
{
using (var conn = new SnowflakeDbConnection(ConnectionString))
{
conn.Open();
var command = conn.CreateCommand();

command.CommandText = "\r\nselect '--'\r\n";
var reader = await command.ExecuteReaderAsync().ConfigureAwait(false);

Assert.IsTrue(await reader.ReadAsync().ConfigureAwait(false));
Assert.AreEqual("--", reader.GetString(0));
}
}
}
}
44 changes: 16 additions & 28 deletions Snowflake.Data.Tests/UnitTests/SFStatementTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,10 @@ public void TestServiceName()
[Test]
public void TestTrimSqlBlockComment()
{
Mock.MockRestSessionExpiredInQueryExec restRequester = new Mock.MockRestSessionExpiredInQueryExec();
SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester);
sfSession.Open();
SFStatement statement = new SFStatement(sfSession);
SFBaseResultSet resultSet = statement.Execute(0, "/*comment*/select 1/*comment*/", null, false, false);
Assert.AreEqual(true, resultSet.Next());
Assert.AreEqual("1", resultSet.GetString(0));
const string SqlSource = "/*comment*/select 1/*comment*/";
const string SqlExpected = "select 1";

Assert.AreEqual(SqlExpected, SFStatement.TrimSql(SqlSource));
}

/// <summary>
Expand All @@ -85,13 +82,10 @@ public void TestTrimSqlBlockComment()
[Test]
public void TestTrimSqlBlockCommentMultiline()
{
Mock.MockRestSessionExpiredInQueryExec restRequester = new Mock.MockRestSessionExpiredInQueryExec();
SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester);
sfSession.Open();
SFStatement statement = new SFStatement(sfSession);
SFBaseResultSet resultSet = statement.Execute(0, "/*comment\r\ncomment*/select 1/*comment\r\ncomment*/", null, false, false);
Assert.AreEqual(true, resultSet.Next());
Assert.AreEqual("1", resultSet.GetString(0));
const string SqlSource = "/*comment\r\ncomment*/select 1/*comment\r\ncomment*/";
const string SqlExpected = "select 1";

Assert.AreEqual(SqlExpected, SFStatement.TrimSql(SqlSource));
}

/// <summary>
Expand All @@ -100,13 +94,10 @@ public void TestTrimSqlBlockCommentMultiline()
[Test]
public void TestTrimSqlLineComment()
{
Mock.MockRestSessionExpiredInQueryExec restRequester = new Mock.MockRestSessionExpiredInQueryExec();
SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester);
sfSession.Open();
SFStatement statement = new SFStatement(sfSession);
SFBaseResultSet resultSet = statement.Execute(0, "--comment\r\nselect 1\r\n--comment", null, false, false);
Assert.AreEqual(true, resultSet.Next());
Assert.AreEqual("1", resultSet.GetString(0));
const string SqlSource = "--comment\r\nselect 1\r\n--comment";
const string SqlExpected = "select 1\r\n--comment";

Assert.AreEqual(SqlExpected, SFStatement.TrimSql(SqlSource));
}

/// <summary>
Expand All @@ -115,13 +106,10 @@ public void TestTrimSqlLineComment()
[Test]
public void TestTrimSqlLineCommentWithClosingNewline()
{
Mock.MockRestSessionExpiredInQueryExec restRequester = new Mock.MockRestSessionExpiredInQueryExec();
SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester);
sfSession.Open();
SFStatement statement = new SFStatement(sfSession);
SFBaseResultSet resultSet = statement.Execute(0, "--comment\r\nselect 1\r\n--comment\r\n", null, false, false);
Assert.AreEqual(true, resultSet.Next());
Assert.AreEqual("1", resultSet.GetString(0));
const string SqlSource = "--comment\r\nselect 1\r\n--comment\r\n";
const string SqlExpected = "select 1";

Assert.AreEqual(SqlExpected, SFStatement.TrimSql(SqlSource));
}

[Test]
Expand Down
44 changes: 22 additions & 22 deletions Snowflake.Data/Core/SFStatement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ class SFStatement
private const string SF_QUERY_RESULT_PATH = "/queries/{0}/result";

private const string SF_PARAM_MULTI_STATEMENT_COUNT = "MULTI_STATEMENT_COUNT";

private const string SF_PARAM_QUERY_TAG = "QUERY_TAG";

private const int SF_QUERY_IN_PROGRESS = 333333;

private const int SF_QUERY_IN_PROGRESS_ASYNC = 333334;
Expand All @@ -124,8 +124,8 @@ class SFStatement
private readonly IRestRequester _restRequester;

private CancellationTokenSource _timeoutTokenSource;
// Merged cancellation token source for all cancellation signal.

// Merged cancellation token source for all cancellation signal.
// Cancel callback will be registered under token issued by this source.
private CancellationTokenSource _linkedCancellationTokenSource;

Expand All @@ -151,21 +151,21 @@ internal SFStatement(SFSession session)
_restRequester = session.restRequester;
_queryTag = session._queryTag;
}
internal SFStatement(SFSession session, string queryTag)

internal SFStatement(SFSession session, string queryTag)
{
SfSession = session;
_restRequester = session.restRequester;
_queryTag = queryTag ?? session._queryTag;
_queryTag = queryTag ?? session._queryTag;
}

internal string GetBindStage() => _bindStage;

private void AssignQueryRequestId()
{
lock (_requestIdLock)
{

if (_requestId != null)
{
logger.Info("Another query is running.");
Expand Down Expand Up @@ -207,16 +207,16 @@ private SFRestRequest BuildQueryRequest(string sql, Dictionary<string, BindingDT
// remove it from parameter bindings so it won't break
// parameter binding feature
bindings.Remove(SF_PARAM_MULTI_STATEMENT_COUNT);
}
}

if (_queryTag != null)
{
if (bodyParameters == null)
{
bodyParameters = new Dictionary<string, string>();
}
bodyParameters[SF_PARAM_QUERY_TAG] = _queryTag;
}
}

QueryRequest postBody = new QueryRequest();
postBody.sqlText = sql;
Expand Down Expand Up @@ -317,7 +317,7 @@ private void SetTimeout(int timeout)
this._timeoutTokenSource = timeout > 0 ? new CancellationTokenSource(timeout * 1000) :
new CancellationTokenSource(Timeout.InfiniteTimeSpan);
}

/// <summary>
/// Register cancel callback. Two factors: either external cancellation token passed down from upper
/// layer or timeout reached. Whichever comes first would trigger query cancellation.
Expand Down Expand Up @@ -363,7 +363,7 @@ internal async Task<SFBaseResultSet> ExecuteAsync(int timeout, string sql, Dicti
}

registerQueryCancellationCallback(timeout, cancellationToken);

int arrayBindingThreshold = 0;
if (SfSession.ParameterMap.ContainsKey(SFSessionParameter.CLIENT_STAGE_ARRAY_BINDING_THRESHOLD))
{
Expand Down Expand Up @@ -457,10 +457,10 @@ internal SFBaseResultSet Execute(int timeout, string sql, Dictionary<string, Bin
{
throw new NotImplementedException("Get and Put are not supported in async execution mode");
}
return ExecuteSqlWithPutGet(timeout, trimmedSql, bindings, describeOnly);
return ExecuteSqlWithPutGet(timeout, sql, trimmedSql, bindings, describeOnly);
}

return ExecuteSqlOtherThanPutGet(timeout, trimmedSql, bindings, describeOnly, asyncExec);
return ExecuteSqlOtherThanPutGet(timeout, sql, bindings, describeOnly, asyncExec);
}
finally
{
Expand All @@ -469,7 +469,7 @@ internal SFBaseResultSet Execute(int timeout, string sql, Dictionary<string, Bin
}
}

private SFBaseResultSet ExecuteSqlWithPutGet(int timeout, string sql, Dictionary<string, BindingDTO> bindings, bool describeOnly)
private SFBaseResultSet ExecuteSqlWithPutGet(int timeout, string sql, string trimmedSql, Dictionary<string, BindingDTO> bindings, bool describeOnly)
{
try
{
Expand All @@ -484,7 +484,7 @@ private SFBaseResultSet ExecuteSqlWithPutGet(int timeout, string sql, Dictionary
logger.Debug("PUT/GET queryId: " + (response.data != null ? response.data.queryId : "Unknown"));

SFFileTransferAgent fileTransferAgent =
new SFFileTransferAgent(sql, SfSession, response.data, CancellationToken.None);
new SFFileTransferAgent(trimmedSql, SfSession, response.data, CancellationToken.None);

// Start the file transfer
fileTransferAgent.execute();
Expand All @@ -507,7 +507,7 @@ private SFBaseResultSet ExecuteSqlWithPutGet(int timeout, string sql, Dictionary
throw new SnowflakeDbException(ex, SFError.INTERNAL_ERROR);
}
}

private SFBaseResultSet ExecuteSqlOtherThanPutGet(int timeout, string sql, Dictionary<string, BindingDTO> bindings, bool describeOnly, bool asyncExec)
{
try
Expand Down Expand Up @@ -562,7 +562,7 @@ private SFBaseResultSet ExecuteSqlOtherThanPutGet(int timeout, string sql, Dicti
throw;
}
}

internal async Task<SFBaseResultSet> GetResultWithIdAsync(string resultId, CancellationToken cancellationToken)
{
var req = BuildResultRequestWithId(resultId);
Expand Down Expand Up @@ -938,7 +938,7 @@ internal async Task<QueryStatus> GetQueryStatusAsync(string queryId, Cancellatio
/// </summary>
/// <param name="originalSql">The original sql query.</param>
/// <returns>The query without the blanks and comments at the beginning.</returns>
private string TrimSql(string originalSql)
internal static string TrimSql(string originalSql)
{
char[] sqlQueryBuf = originalSql.ToCharArray();
var builder = new StringBuilder();
Expand Down Expand Up @@ -1054,7 +1054,7 @@ internal SFBaseResultSet ExecuteTransfer(string sql)
false);

PutGetStageInfo stageInfo = new PutGetStageInfo();

SFFileTransferAgent fileTransferAgent =
new SFFileTransferAgent(sql, SfSession, response.data, ref _uploadStream, _destFilename, _stagePath, CancellationToken.None);

Expand Down
Loading