From aab71cf7815b2c025af1f0646eec005cfcf3352e Mon Sep 17 00:00:00 2001 From: "SIMBA\\natachab" Date: Tue, 25 May 2021 11:10:17 -0700 Subject: [PATCH 01/13] Add proxy support - Each new proxy configuration is linked to a different httpclient --- .../Mock/MockRetryUntilRestTimeout.cs | 2 +- Snowflake.Data.Tests/SFBaseTest.cs | 23 +- Snowflake.Data.Tests/SFConnectionIT.cs | 234 ++++++++++++++++++ Snowflake.Data.Tests/SFStatementTest.cs | 6 +- .../Client/SnowflakeDbConnection.cs | 22 +- Snowflake.Data/Client/SnowflakeDbException.cs | 41 +-- Snowflake.Data/Core/ChunkDownloaderFactory.cs | 3 +- Snowflake.Data/Core/HttpUtil.cs | 46 +++- Snowflake.Data/Core/RestRequester.cs | 99 +++++++- .../Core/SFBlockingChunkDownloader.cs | 13 +- .../Core/SFBlockingChunkDownloaderV3.cs | 11 +- Snowflake.Data/Core/SFChunkDownloaderV2.cs | 8 +- Snowflake.Data/Core/SFSession.cs | 58 +++-- Snowflake.Data/Core/SFSessionProperty.cs | 60 ++++- Snowflake.Data/Core/SFStatement.cs | 9 +- 15 files changed, 545 insertions(+), 90 deletions(-) diff --git a/Snowflake.Data.Tests/Mock/MockRetryUntilRestTimeout.cs b/Snowflake.Data.Tests/Mock/MockRetryUntilRestTimeout.cs index ece3fb380..eef0ac14b 100644 --- a/Snowflake.Data.Tests/Mock/MockRetryUntilRestTimeout.cs +++ b/Snowflake.Data.Tests/Mock/MockRetryUntilRestTimeout.cs @@ -69,7 +69,7 @@ private async Task SendAsync(HttpRequestMessage request, { // Http timeout of 1ms to force retries request.Properties[BaseRestRequest.HTTP_REQUEST_TIMEOUT_KEY] = TimeSpan.FromMilliseconds(1); - var response = await HttpUtil.getHttpClient().SendAsync(request, HttpCompletionOption.ResponseHeadersRead, linkedCts.Token).ConfigureAwait(false); + var response = await HttpUtil.getHttpClient(null).SendAsync(request, HttpCompletionOption.ResponseHeadersRead, linkedCts.Token).ConfigureAwait(false); response.EnsureSuccessStatusCode(); return response; diff --git a/Snowflake.Data.Tests/SFBaseTest.cs b/Snowflake.Data.Tests/SFBaseTest.cs index 69e19f0d2..27eec224b 100755 --- a/Snowflake.Data.Tests/SFBaseTest.cs +++ b/Snowflake.Data.Tests/SFBaseTest.cs @@ -88,7 +88,7 @@ public void SFTestSetup() String cloud = Environment.GetEnvironmentVariable("snowflake_cloud_env"); Assert.IsTrue(cloud == null || cloud == "AWS" || cloud == "AZURE" || cloud == "GCP", "{0} is not supported. Specify AWS, AZURE or GCP as cloud environment", cloud); - StreamReader reader = new StreamReader("parameters.json"); + StreamReader reader = new StreamReader("C:\\Users\\natachab\\Snowflake\\fromMasterToWorkOnTicket\\Snowflake.Data.Tests\\parameters.json"); var testConfigString = reader.ReadToEnd(); @@ -174,6 +174,27 @@ public class TestConfig [JsonProperty(PropertyName = "SNOWFLAKE_TEST_EXP_OAUTH_TOKEN", NullValueHandling = NullValueHandling.Ignore)] internal string expOauthToken { get; set; } + [JsonProperty(PropertyName = "PROXY_HOST", NullValueHandling = NullValueHandling.Ignore)] + internal string proxyHost { get; set; } + + [JsonProperty(PropertyName = "PROXY_PORT", NullValueHandling = NullValueHandling.Ignore)] + internal string proxyPort { get; set; } + + [JsonProperty(PropertyName = "AUTH_PROXY_HOST", NullValueHandling = NullValueHandling.Ignore)] + internal string authProxyHost { get; set; } + + [JsonProperty(PropertyName = "AUTH_PROXY_PORT", NullValueHandling = NullValueHandling.Ignore)] + internal string authProxyPort { get; set; } + + [JsonProperty(PropertyName = "AUTH_PROXY_USER", NullValueHandling = NullValueHandling.Ignore)] + internal string authProxyUser { get; set; } + + [JsonProperty(PropertyName = "AUTH_PROXY_PWD", NullValueHandling = NullValueHandling.Ignore)] + internal string authProxyPwd { get; set; } + + [JsonProperty(PropertyName = "NON_PROXY_HOSTS", NullValueHandling = NullValueHandling.Ignore)] + internal string nonProxyHosts { get; set; } + public TestConfig() { this.protocol = "https"; diff --git a/Snowflake.Data.Tests/SFConnectionIT.cs b/Snowflake.Data.Tests/SFConnectionIT.cs index c06e9c35a..766b0bc66 100755 --- a/Snowflake.Data.Tests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/SFConnectionIT.cs @@ -45,6 +45,42 @@ public void TestBasicConnection() } } + [Test] + public void TestIncorrectUserOrPasswordBasicConnection() + { + using (var conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = String.Format("scheme={0};host={1};port={2};" + + "account={3};role={4};db={5};schema={6};warehouse={7};user={8};password={9};", + testConfig.protocol, + testConfig.host, + testConfig.port, + testConfig.account, + testConfig.role, + testConfig.database, + testConfig.schema, + testConfig.warehouse, + "unknown", + testConfig.password); + + Assert.AreEqual(conn.State, ConnectionState.Closed); + try + { + conn.Open(); + Assert.Fail(); + + } + catch (SnowflakeDbException e) + { + // Expected + logger.Debug("Failed opening connection ", e); + Assert.AreEqual("08006", e.SqlState); // Connection failure + } + + Assert.AreEqual(ConnectionState.Closed, conn.State); + } + } + [Test] public void TestConnectViaSecureString() { @@ -736,6 +772,204 @@ public void TestValidOAuthExpiredTokenConnection() Assert.AreEqual(390318, e.ErrorCode); } } + + [Test] + [Ignore("Ignore this test until configuration is setup for CI integration. Can be run manually.")] + public void TestCorrectProxySettingFromConnectionString() + { + using (var conn = new SnowflakeDbConnection()) + { + conn.ConnectionString + = ConnectionString + + String.Format( + ";useProxy=true;proxyHost={0};proxyPort={1}", + testConfig.proxyHost, + testConfig.proxyPort); + + conn.Open(); + } + } + + [Test] + [Ignore("Ignore this test until configuration is setup for CI integration. Can be run manually.")] + public void TestCorrectProxyWithCredsSettingFromConnectionString() + { + using (var conn = new SnowflakeDbConnection()) + { + conn.ConnectionString + = ConnectionString + + String.Format( + ";useProxy=true;proxyHost={0};proxyPort={1};proxyUser={2};proxyPassword={3}", + testConfig.authProxyHost, + testConfig.authProxyPort, + testConfig.authProxyUser, + testConfig.authProxyPwd); + + conn.Open(); + } + } + + [Test] + [Ignore("Ignore this test until configuration is setup for CI integration. Can be run manually.")] + public void TestCorrectProxySettingWithByPassListFromConnectionString() + { + using (var conn = new SnowflakeDbConnection()) + { + conn.ConnectionString + = ConnectionString + + String.Format( + ";useProxy=true;proxyHost={0};proxyPort={1};proxyUser={2};proxyPassword={3};nonProxyHosts={4}", + testConfig.authProxyHost, + testConfig.authProxyPort, + testConfig.authProxyUser, + testConfig.authProxyPwd, + "*.foo.com %7C" + testConfig.host + "|localhost"); + + conn.Open(); + } + } + + [Test] + [Ignore("Ignore this test until configuration is setup for CI integration. Can be run manually.")] + public void TestMultipleConnectionWithDifferentProxySettings() + { + // Authenticated proxy + using (var conn1 = new SnowflakeDbConnection()) + { + conn1.ConnectionString = ConnectionString + + String.Format( + ";useProxy=true;proxyHost={0};proxyPort={1};proxyUser={2};proxyPassword={3}", + testConfig.authProxyHost, + testConfig.authProxyPort, + testConfig.authProxyUser, + testConfig.authProxyPwd); + conn1.Open(); + } + + // No proxy + using (var conn2 = new SnowflakeDbConnection()) + { + conn2.ConnectionString = ConnectionString; + conn2.Open(); + } + + // Non authenticated proxy + using (var conn3 = new SnowflakeDbConnection()) + { + conn3.ConnectionString = ConnectionString + + String.Format( + ";useProxy=true;proxyHost={0};proxyPort={1}", + testConfig.proxyHost, + testConfig.proxyPort); + conn3.Open(); + } + + // Invalid proxy + using (var conn4 = new SnowflakeDbConnection()) + { + conn4.ConnectionString = + ConnectionString + "connection_timeout=20;useProxy=true;proxyHost=Invalid;proxyPort=8080"; + try + { + conn4.Open(); + Assert.Fail(); + } + catch + { + // Expected + } + } + + // Another authenticated proxy connection + //Should use same httpclient than previous authenticated proxy connection + using (var conn5 = new SnowflakeDbConnection()) + { + conn5.ConnectionString = ConnectionString + + String.Format( + ";useProxy=true;proxyHost={0};proxyPort={1};proxyUser={2};proxyPassword={3}", + testConfig.authProxyHost, + testConfig.authProxyPort, + testConfig.authProxyUser, + testConfig.authProxyPwd); + conn5.Open(); + } + + // No proxy again + // Should use same httpclient than previous no proxy connection + using (var conn6 = new SnowflakeDbConnection()) + { + conn6.ConnectionString = ConnectionString; + conn6.Open(); + } + + // Another authenticated proxy, but this will create a new httpclient because there is + // a bypass list + using (var conn7 = new SnowflakeDbConnection()) + { + conn7.ConnectionString + = ConnectionString + + String.Format( + ";useProxy=true;proxyHost={0};proxyPort={1};proxyUser={2};proxyPassword={3};nonProxyHosts={4}", + testConfig.authProxyHost, + testConfig.authProxyPort, + testConfig.authProxyUser, + testConfig.authProxyPwd, + "*.foo.com %7C" + testConfig.host + "|localhost"); + + conn7.Open(); + } + + } + + [Test] + public void TestInvalidProxySettingFromConnectionString() + { + using (var conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = + ConnectionString + "connection_timeout=5;useProxy=true;proxyHost=Invalid;proxyPort=8080"; + try + { + conn.Open(); + Assert.Fail(); + } + catch (SnowflakeDbException e) + { + // Expected + logger.Debug("Failed opening connection ", e); + Assert.AreEqual(270001, e.ErrorCode); //Internal error + Assert.AreEqual("08006", e.SqlState); // Connection failure + } + } + } + + [Test] + public void TestUseProxyFalseWithInvalidProxyConnectionString() + { + using (var conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = + ConnectionString + ";useProxy=false;proxyHost=Invalid;proxyPort=8080"; + conn.Open(); + // Because useProxy=false, the proxy settings are ignored + } + } + + [Test] + public void TestInvalidProxySettingWithByPassListFromConnectionString() + { + using (var conn = new SnowflakeDbConnection()) + { + conn.ConnectionString + = ConnectionString + + String.Format( + ";useProxy=true;proxyHost=Invalid;proxyPort=8080;nonProxyHosts={0}", + "*.foo.com %7C" + testConfig.host + "|localhost"); + + conn.Open(); + // Because testConfig.host is in the bypass list, the proxy should not be used + } + } } [TestFixture] diff --git a/Snowflake.Data.Tests/SFStatementTest.cs b/Snowflake.Data.Tests/SFStatementTest.cs index 8e21d8a0f..e7681656f 100755 --- a/Snowflake.Data.Tests/SFStatementTest.cs +++ b/Snowflake.Data.Tests/SFStatementTest.cs @@ -20,7 +20,7 @@ public void TestSessionRenew() Mock.MockRestSessionExpired restRequester = new Mock.MockRestSessionExpired(); SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); - SFStatement statement = new SFStatement(sfSession, restRequester); + SFStatement statement = new SFStatement(sfSession); SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); @@ -36,7 +36,7 @@ public void TestSessionRenewDuringQueryExec() Mock.MockRestSessionExpiredInQueryExec restRequester = new Mock.MockRestSessionExpiredInQueryExec(); SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); - SFStatement statement = new SFStatement(sfSession, restRequester); + SFStatement statement = new SFStatement(sfSession); SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false); Assert.AreEqual(true, resultSet.Next()); Assert.AreEqual("1", resultSet.GetString(0)); @@ -56,7 +56,7 @@ public void TestServiceName() Assert.AreEqual(expectServiceName, sfSession.ParameterMap[SFSessionParameter.SERVICE_NAME]); for (int i = 0; i < 5; i++) { - SFStatement statement = new SFStatement(sfSession, restRequester); + SFStatement statement = new SFStatement(sfSession); SFBaseResultSet resultSet = statement.Execute(0, "SELECT 1", null, false); expectServiceName += "a"; Assert.AreEqual(expectServiceName, sfSession.ParameterMap[SFSessionParameter.SERVICE_NAME]); diff --git a/Snowflake.Data/Client/SnowflakeDbConnection.cs b/Snowflake.Data/Client/SnowflakeDbConnection.cs index 869cce5f4..870c0c951 100755 --- a/Snowflake.Data/Client/SnowflakeDbConnection.cs +++ b/Snowflake.Data/Client/SnowflakeDbConnection.cs @@ -116,9 +116,12 @@ public override void Open() logger.Error("Unable to connect", e); if (!(e.GetType() == typeof(SnowflakeDbException))) { - throw new SnowflakeDbException(e.InnerException, - SFError.INTERNAL_ERROR, - "Unable to connect"); + throw + new SnowflakeDbException( + e, + SnowflakeDbException.CONNECTION_FAILURE_SSTATE, + SFError.INTERNAL_ERROR, + "Unable to connect. " + e.Message); } else { @@ -142,9 +145,11 @@ public override Task OpenAsync(CancellationToken cancellationToken) Exception sfSessionEx = previousTask.Exception; _connectionState = ConnectionState.Closed; logger.Error("Unable to connect", sfSessionEx.InnerException); - throw new SnowflakeDbException(sfSessionEx.InnerException, - SFError.INTERNAL_ERROR, - "Unable to connect"); + throw new SnowflakeDbException( + sfSessionEx, + SnowflakeDbException.CONNECTION_FAILURE_SSTATE, + SFError.INTERNAL_ERROR, + "Unable to connect"); } else if (previousTask.IsCanceled) { @@ -153,9 +158,8 @@ public override Task OpenAsync(CancellationToken cancellationToken) } else { - logger.Debug("All good"); - // Only continue if the session was opened successfully - OnSessionEstablished(); + // Only continue if the session was opened successfully + OnSessionEstablished(); } }, cancellationToken); diff --git a/Snowflake.Data/Client/SnowflakeDbException.cs b/Snowflake.Data/Client/SnowflakeDbException.cs index 3f4962666..7b78f2f11 100755 --- a/Snowflake.Data/Client/SnowflakeDbException.cs +++ b/Snowflake.Data/Client/SnowflakeDbException.cs @@ -17,22 +17,25 @@ namespace Snowflake.Data.Client /// public sealed class SnowflakeDbException : DbException { + // Sql states not coming directly from the server. + internal static string CONNECTION_FAILURE_SSTATE = "08006"; + static private ResourceManager rm = new ResourceManager("Snowflake.Data.Core.ErrorMessages", typeof(SnowflakeDbException).Assembly); - private string sqlState; + public string SqlState { get; private set; } - private int vendorCode; + private int VendorCode; - private string errorMessage; + private string ErrorMessage; - public string queryId { get; } + public string QueryId { get; } public override string Message { get { - return errorMessage; + return ErrorMessage; } } @@ -40,35 +43,43 @@ public override int ErrorCode { get { - return vendorCode; + return VendorCode; } } public SnowflakeDbException(string sqlState, int vendorCode, string errorMessage, string queryId) { - this.sqlState = sqlState; - this.vendorCode = vendorCode; - this.errorMessage = errorMessage; - this.queryId = queryId; + this.SqlState = sqlState; + this.VendorCode = vendorCode; + this.ErrorMessage = errorMessage; + this.QueryId = queryId; } public SnowflakeDbException(SFError error, params object[] args) { - this.errorMessage = string.Format(rm.GetString(error.ToString()), args); - this.vendorCode = error.GetAttribute().errorCode; + this.ErrorMessage = string.Format(rm.GetString(error.ToString()), args); + this.VendorCode = error.GetAttribute().errorCode; } public SnowflakeDbException(Exception innerException, SFError error, params object[] args) : base(string.Format(rm.GetString(error.ToString()), args), innerException) { - this.errorMessage = string.Format(rm.GetString(error.ToString()), args); - this.vendorCode = error.GetAttribute().errorCode; + this.ErrorMessage = string.Format(rm.GetString(error.ToString()), args); + this.VendorCode = error.GetAttribute().errorCode; + } + + public SnowflakeDbException(Exception innerException, string sqlState, SFError error, params object[] args) + : base(string.Format(rm.GetString(error.ToString()), args), innerException) + { + this.ErrorMessage = string.Format(rm.GetString(error.ToString()), args); + this.VendorCode = error.GetAttribute().errorCode; + this.SqlState = sqlState; } public override string ToString() { return string.Format("Error: {0} SqlState: {1}, VendorCode: {2}, QueryId: {3}", - errorMessage, sqlState, vendorCode, queryId); + ErrorMessage, SqlState, VendorCode, QueryId); } } } diff --git a/Snowflake.Data/Core/ChunkDownloaderFactory.cs b/Snowflake.Data/Core/ChunkDownloaderFactory.cs index f4016a99a..e5b16446e 100755 --- a/Snowflake.Data/Core/ChunkDownloaderFactory.cs +++ b/Snowflake.Data/Core/ChunkDownloaderFactory.cs @@ -31,7 +31,8 @@ public static IChunkDownloader GetDownloader(QueryExecResponseData responseData, responseData.chunks, responseData.qrmk, responseData.chunkHeaders, - cancellationToken); + cancellationToken, + resultSet.sfStatement.SfSession.restRequester); default: return new SFBlockingChunkDownloaderV3(responseData.rowType.Count, responseData.chunks, diff --git a/Snowflake.Data/Core/HttpUtil.cs b/Snowflake.Data/Core/HttpUtil.cs index 3425db48b..f206fb387 100755 --- a/Snowflake.Data/Core/HttpUtil.cs +++ b/Snowflake.Data/Core/HttpUtil.cs @@ -17,7 +17,14 @@ namespace Snowflake.Data.Core { class HttpUtil { - static private HttpClient httpClient = null; + private static SFLogger logger = SFLoggerFactory.GetLogger(); + + + // Pool of http clients per proxy settings + private static Dictionary httpClients = + new Dictionary(); + + private static HttpClient noProxyHttpClient; static private CookieContainer cookieContainer = null; @@ -41,20 +48,36 @@ static public void ClearCookies(Uri uri) } } - - static public HttpClient getHttpClient() + /// + /// Get the http client for the given proxy, or null if no proxy are used. + /// + /// The proxy to use or null for no proxy. + /// The corresponding Httpclient instance. + static public HttpClient getHttpClient(IWebProxy proxy) { lock (httpClientInitLock) { - if (httpClient == null) + if ( null == proxy) { - initHttpClient(); + logger.Debug("noProxyHttpClient " + noProxyHttpClient); + // Init noProxyHttpClient + if (null == noProxyHttpClient) noProxyHttpClient = initHttpClient(null); + return noProxyHttpClient; + } + + if (!httpClients.TryGetValue(proxy, out HttpClient httpClient)) + { + //logger.Debug("Need a new HttpClient for this proxy."); + // Need a new HttpClient for this proxy + httpClient = initHttpClient(proxy); + // Add to the pool + httpClients.Add(proxy, httpClient); } return httpClient; } } - static private void initHttpClient() + static private HttpClient initHttpClient(IWebProxy proxy) { HttpClientHandler httpHandler = new HttpClientHandler() @@ -64,12 +87,17 @@ static private void initHttpClient() // Enforce tls v1.2 SslProtocols = SslProtocols.Tls12, AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate, - }; + CookieContainer = cookieContainer = new CookieContainer() + }; - HttpUtil.httpClient = new HttpClient(new RetryHandler(httpHandler)); + if (null != proxy) httpHandler.Proxy = proxy; + + HttpClient httpClient = new HttpClient(new RetryHandler(httpHandler)); // HttpClient has a default timeout of 100 000 ms, we don't want to interfere with our // own connection and command timeout - HttpUtil.httpClient.Timeout = Timeout.InfiniteTimeSpan; + httpClient.Timeout = Timeout.InfiniteTimeSpan; + + return httpClient; } /// diff --git a/Snowflake.Data/Core/RestRequester.cs b/Snowflake.Data/Core/RestRequester.cs index 216268571..9c31cdefc 100644 --- a/Snowflake.Data/Core/RestRequester.cs +++ b/Snowflake.Data/Core/RestRequester.cs @@ -9,6 +9,8 @@ using System.Threading.Tasks; using Snowflake.Data.Client; using Snowflake.Data.Log; +using System.Net; +using System.Collections.Generic; namespace Snowflake.Data.Core { @@ -34,17 +36,87 @@ internal class RestRequester : IRestRequester { private static SFLogger logger = SFLoggerFactory.GetLogger(); - private static readonly RestRequester instance = new RestRequester(); + private static Dictionary RestRequestersByProxy = + new Dictionary(); - private RestRequester() + private static readonly object requesterPoolLock = new object(); + + // The proxy to use making the requests + internal IWebProxy Proxy; + + private RestRequester(IWebProxy proxy) { + Proxy = proxy; } - - static internal RestRequester Instance + + /// + /// Get the RestRequester associated with the given proxy information or create a new one + /// if none exist in the pool already. + /// + /// The proxy host. + /// The proxy port. + /// The proxy username or null if none. + /// The proxy password or null if none. + /// The list of urls to by-pass the proxy. + /// The RestRequester for this proxy. + /// The port is not a valid int + /// The port value is too large + internal static IRestRequester GetRestRequester( + string proxyHost, + string proxyPort, + string proxyUser, + string proxyPassword, + string noProxyList) { - get { return instance; } + string key = string.Join(";", new string[]{ proxyHost, proxyPort, proxyUser, proxyPassword, noProxyList }); + lock(requesterPoolLock) + { + if (!RestRequestersByProxy.TryGetValue(key, out IRestRequester requester)) + { + WebProxy webProxy = null; + if (null != proxyHost) + { + // New proxy needed + webProxy = new WebProxy(proxyHost, int.Parse(proxyPort)); + + // Add credential if provided + if (!String.IsNullOrEmpty(proxyUser)) + { + ICredentials credentials = new NetworkCredential(proxyUser, proxyPassword); + webProxy.Credentials = credentials; + } + + // Add bypasslist if provided + if (!String.IsNullOrEmpty(noProxyList)) + { + string[] bypassList = noProxyList.Split( + new char[] { '|' }, + StringSplitOptions.RemoveEmptyEntries); + // Convert simplified syntax to standard regular expression syntax + string entry = null; + for (int i = 0; i < bypassList.Length; i++) + { + // Get the original entry + entry = bypassList[i].Trim(); + // . -> [.] because . means any char + entry = entry.Replace(".", "[.]"); + // * -> .* because * is a quantifier and need a char or group to apply to + entry = entry.Replace("*", ".*"); + + // Replace with the valid entry syntax + bypassList[i] = entry; + + } + webProxy.BypassList = bypassList; + } + } + requester = new RestRequester(webProxy); + RestRequestersByProxy.Add(key, requester); + } + + return requester; + } } - public T Post(IRestRequest request) { //Run synchronous in a new thread-pool task. @@ -99,14 +171,25 @@ private async Task SendAsync(HttpRequestMessage request, try { - var response = await HttpUtil.getHttpClient().SendAsync(request, HttpCompletionOption.ResponseHeadersRead, linkedCts.Token).ConfigureAwait(false); + //logger.Debug("Execute request with proxy " + ((null != Proxy) ? Proxy.ToString() : "no proxy")); + HttpClient httpClient = HttpUtil.getHttpClient(Proxy); + var response = await httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, linkedCts.Token).ConfigureAwait(false); response.EnsureSuccessStatusCode(); return response; } catch(Exception e) { - throw restRequestTimeout.IsCancellationRequested ? new SnowflakeDbException(SFError.REQUEST_TIMEOUT) : e; + if (restRequestTimeout.IsCancellationRequested) + { + // Timeout or cancellation + throw new SnowflakeDbException(e, SFError.REQUEST_TIMEOUT); + } + else + { + //rethrow + throw; + } } } } diff --git a/Snowflake.Data/Core/SFBlockingChunkDownloader.cs b/Snowflake.Data/Core/SFBlockingChunkDownloader.cs index 8e059ba38..015eaab28 100755 --- a/Snowflake.Data/Core/SFBlockingChunkDownloader.cs +++ b/Snowflake.Data/Core/SFBlockingChunkDownloader.cs @@ -28,12 +28,12 @@ class SFBlockingChunkDownloader : IChunkDownloader private int nextChunkToDownloadIndex; - // External cancellation token, used to stop donwload + // External cancellation token, used to stop download private CancellationToken externalCancellationToken; private readonly int prefetchThreads; - private static IRestRequester restRequester = RestRequester.Instance; + private readonly IRestRequester RestRequester; private Dictionary chunkHeaders; @@ -43,14 +43,15 @@ public SFBlockingChunkDownloader(int colCount, ListchunkInfos, string qrmk, Dictionary chunkHeaders, CancellationToken cancellationToken, - SFBaseResultSet ResultSet) + SFBaseResultSet resultSet) { this.qrmk = qrmk; this.chunkHeaders = chunkHeaders; this.chunks = new List(); this.nextChunkToDownloadIndex = 0; - this.ResultSet = ResultSet; - this.prefetchThreads = GetPrefetchThreads(ResultSet); + this.ResultSet = resultSet; + this.prefetchThreads = GetPrefetchThreads(resultSet); + RestRequester = resultSet.sfStatement.SfSession.restRequester; externalCancellationToken = cancellationToken; var idx = 0; @@ -124,7 +125,7 @@ private async Task DownloadChunkAsync(DownloadContext downloadCont }; - var httpResponse = await restRequester.GetAsync(downloadRequest, downloadContext.cancellationToken).ConfigureAwait(false); + var httpResponse = await RestRequester.GetAsync(downloadRequest, downloadContext.cancellationToken).ConfigureAwait(false); Stream stream = Task.Run(async() => await httpResponse.Content.ReadAsStreamAsync()).Result; IEnumerable encoding; //TODO this shouldn't be required. diff --git a/Snowflake.Data/Core/SFBlockingChunkDownloaderV3.cs b/Snowflake.Data/Core/SFBlockingChunkDownloaderV3.cs index c767620f1..6ad88e068 100755 --- a/Snowflake.Data/Core/SFBlockingChunkDownloaderV3.cs +++ b/Snowflake.Data/Core/SFBlockingChunkDownloaderV3.cs @@ -37,7 +37,7 @@ class SFBlockingChunkDownloaderV3 : IChunkDownloader private readonly int prefetchSlot; - private static IRestRequester restRequester = RestRequester.Instance; + private readonly IRestRequester RestRequester; private Dictionary chunkHeaders; @@ -61,6 +61,7 @@ public SFBlockingChunkDownloaderV3(int colCount, this.chunkInfos = chunkInfos; this.nextChunkToConsumeIndex = 0; this.taskQueues = new List>(); + RestRequester = ResultSet.sfStatement.SfSession.restRequester; externalCancellationToken = cancellationToken; for (int i=0; i GetNextChunkAsync() - { - return _downloadTasks.IsCompleted ? Task.FromResult(null) : _downloadTasks.Take(); - }*/ - public Task GetNextChunkAsync() { logger.Info($"NextChunkToConsume: {nextChunkToConsumeIndex}, NextChunkToDownload: {nextChunkToDownloadIndex}"); @@ -140,7 +135,7 @@ private async Task DownloadChunkAsync(DownloadContextV3 downloadCo chunkHeaders = downloadContext.chunkHeaders }; - using (var httpResponse = await restRequester.GetAsync(downloadRequest, downloadContext.cancellationToken) + using (var httpResponse = await RestRequester.GetAsync(downloadRequest, downloadContext.cancellationToken) .ConfigureAwait(continueOnCapturedContext: false)) using (Stream stream = await httpResponse.Content.ReadAsStreamAsync() .ConfigureAwait(continueOnCapturedContext: false)) diff --git a/Snowflake.Data/Core/SFChunkDownloaderV2.cs b/Snowflake.Data/Core/SFChunkDownloaderV2.cs index d94528f08..9d9b69ce1 100755 --- a/Snowflake.Data/Core/SFChunkDownloaderV2.cs +++ b/Snowflake.Data/Core/SFChunkDownloaderV2.cs @@ -30,16 +30,18 @@ class SFChunkDownloaderV2 : IChunkDownloader //TODO: parameterize prefetch slot private const int prefetchSlot = 5; - private static IRestRequester restRequester = RestRequester.Instance; + private IRestRequester RestRequester; private Dictionary chunkHeaders; public SFChunkDownloaderV2(int colCount, ListchunkInfos, string qrmk, - Dictionary chunkHeaders, CancellationToken cancellationToken) + Dictionary chunkHeaders, CancellationToken cancellationToken, + IRestRequester restRequester) { this.qrmk = qrmk; this.chunkHeaders = chunkHeaders; this.chunks = new List(); + RestRequester = restRequester; externalCancellationToken = cancellationToken; var idx = 0; @@ -130,7 +132,7 @@ private async Task DownloadChunkAsync(DownloadContextV2 downloadCo chunkHeaders = downloadContext.chunkHeaders }; - var httpResponse = await restRequester.GetAsync(downloadRequest, downloadContext.cancellationToken).ConfigureAwait(false); + var httpResponse = await RestRequester.GetAsync(downloadRequest, downloadContext.cancellationToken).ConfigureAwait(false); Stream stream = await httpResponse.Content.ReadAsStreamAsync().ConfigureAwait(false); if (httpResponse.Content.Headers.TryGetValues("Content-Encoding", out var encoding)) diff --git a/Snowflake.Data/Core/SFSession.cs b/Snowflake.Data/Core/SFSession.cs index d199409b8..4113671ed 100755 --- a/Snowflake.Data/Core/SFSession.cs +++ b/Snowflake.Data/Core/SFSession.cs @@ -12,6 +12,7 @@ using Snowflake.Data.Core.Authenticator; using System.Threading; using System.Threading.Tasks; +using System.Net; namespace Snowflake.Data.Core { @@ -55,7 +56,11 @@ internal void ProcessLoginResponse(LoginResponse authnResponse) } else { - SnowflakeDbException e = new SnowflakeDbException("", authnResponse.code, authnResponse.message, ""); + SnowflakeDbException e = new SnowflakeDbException + (SnowflakeDbException.CONNECTION_FAILURE_SSTATE, + authnResponse.code, + authnResponse.message, + ""); logger.Error("Authentication failed", e); throw e; } @@ -85,15 +90,8 @@ internal Uri BuildLoginUrl() /// Constructor /// /// A string in the form of "key1=value1;key2=value2" - internal SFSession(String connectionString, SecureString password) : - this(connectionString, password, RestRequester.Instance) + internal SFSession(String connectionString, SecureString password) { - } - - internal SFSession(String connectionString, SecureString password, IRestRequester restRequester) - { - - this.restRequester = restRequester; properties = SFSessionProperties.parseConnectionString(connectionString, password); ParameterMap = new Dictionary(); @@ -103,13 +101,38 @@ internal SFSession(String connectionString, SecureString password, IRestRequeste { ParameterMap[SFSessionParameter.CLIENT_VALIDATE_DEFAULT_PARAMETERS] = Boolean.Parse(properties[SFSessionProperty.VALIDATE_DEFAULT_PARAMETERS]); + timeoutInSec = int.Parse(properties[SFSessionProperty.CONNECTION_TIMEOUT]); + string proxyHost = null; + string proxyPort = null; + string noProxyHosts = null; + string proxyPwd = null; + string proxyUser = null; + if (Boolean.Parse(properties[SFSessionProperty.USEPROXY])) + { + // Let's try to get the associated RestRequester + properties.TryGetValue(SFSessionProperty.PROXYHOST, out proxyHost); + properties.TryGetValue(SFSessionProperty.PROXYPORT, out proxyPort); + properties.TryGetValue(SFSessionProperty.NONPROXYHOSTS, out noProxyHosts); + properties.TryGetValue(SFSessionProperty.PROXYPASSWORD, out proxyPwd); + properties.TryGetValue(SFSessionProperty.PROXYUSER, out proxyUser); + + if (!String.IsNullOrEmpty(noProxyHosts)) + { + // The list is url-encoded + // Host names are separated with a URL-escaped pipe symbol (%7C). + noProxyHosts = HttpUtility.UrlDecode(noProxyHosts); + } + } - timeoutInSec = int.Parse(properties[SFSessionProperty.CONNECTION_TIMEOUT]); + restRequester = + RestRequester.GetRestRequester(proxyHost, proxyPort, proxyUser, proxyPwd, noProxyHosts); + } catch (Exception e) { - logger.Error(e.Message); - throw new SnowflakeDbException(e.InnerException, + logger.Error("Unable to connect", e); + throw new SnowflakeDbException(e, + SnowflakeDbException.CONNECTION_FAILURE_SSTATE, SFError.INVALID_CONNECTION_STRING, "Unable to connect"); } @@ -124,7 +147,14 @@ internal SFSession(String connectionString, SecureString password, IRestRequeste logger.Warn($"Connection timeout provided is negative. Timeout will be infinite."); } - connectionTimeout = timeoutInSec > 0 ? TimeSpan.FromSeconds(timeoutInSec) : Timeout.InfiniteTimeSpan; + connectionTimeout = timeoutInSec > 0 ? TimeSpan.FromSeconds(timeoutInSec) : Timeout.InfiniteTimeSpan; + } + + internal SFSession(String connectionString, SecureString password, IRestRequester restRequester) : + this(connectionString, password) + { + // Override the Rest requester with a mock for testing if necessary + this.restRequester = restRequester; } internal Uri BuildUri(string path, Dictionary queryParams = null) @@ -174,7 +204,7 @@ internal async Task OpenAsync(CancellationToken cancellationToken) internal void close() { // Nothing to do if the session is not open - if (null != sessionToken) return; + if (null == sessionToken) return; // Send a close session request var queryParams = new Dictionary(); diff --git a/Snowflake.Data/Core/SFSessionProperty.cs b/Snowflake.Data/Core/SFSessionProperty.cs index 603e39bc7..75f77e993 100755 --- a/Snowflake.Data/Core/SFSessionProperty.cs +++ b/Snowflake.Data/Core/SFSessionProperty.cs @@ -48,6 +48,19 @@ internal enum SFSessionProperty PRIVATE_KEY, [SFSessionPropertyAttr(required = false)] TOKEN, + [SFSessionPropertyAttr(required = false, defaultValue = "false")] + USEPROXY, + [SFSessionPropertyAttr(required = false)] + PROXYHOST, + [SFSessionPropertyAttr(required = false)] + PROXYPORT, + [SFSessionPropertyAttr(required = false)] + PROXYUSER, + [SFSessionPropertyAttr(required = false)] + PROXYPASSWORD, + [SFSessionPropertyAttr(required = false)] + NONPROXYHOSTS, + } class SFSessionPropertyAttr : Attribute @@ -67,7 +80,9 @@ class SFSessionProperties : Dictionary SFSessionProperty.PASSWORD, SFSessionProperty.PRIVATE_KEY, SFSessionProperty.TOKEN, - SFSessionProperty.PRIVATE_KEY_PWD}; + SFSessionProperty.PRIVATE_KEY_PWD, + SFSessionProperty.PROXYPASSWORD, + }; public override bool Equals(object obj) { @@ -167,13 +182,14 @@ internal static SFSessionProperties parseConnectionString(String connectionStrin } else { - // An equal sign was not doubled or something else happended + // An equal sign was not doubled or something else happened // making the connection invalid string invalidStringDetail = String.Format("Invalid key value pair {0}", keyVal); SnowflakeDbException e = - new SnowflakeDbException(SFError.INVALID_CONNECTION_STRING, - new object[] { invalidStringDetail }); + new SnowflakeDbException( + SFError.INVALID_CONNECTION_STRING, + new object[] { invalidStringDetail }); logger.Error("Invalid string.", e); throw e; } @@ -193,11 +209,43 @@ internal static SFSessionProperties parseConnectionString(String connectionStrin } } + bool useProxy = false; + if (properties.ContainsKey(SFSessionProperty.USEPROXY)) + { + try + { + useProxy = Boolean.Parse(properties[SFSessionProperty.USEPROXY]); + } + catch (Exception e) + { + // The useProxy setting is not a valid boolean value + logger.Error("Unable to connect", e); + throw new SnowflakeDbException(e, + SFError.INVALID_CONNECTION_STRING, + e.Message); + } + } + + // Based on which proxy settings have been provided, update the required settings list + if (useProxy) + { + // If useProxy is true, then proxyhost and proxy port are mandatory + SFSessionProperty.PROXYHOST.GetAttribute().required = true; + SFSessionProperty.PROXYPORT.GetAttribute().required = true; + + // If a username is provided, then a password is required + if (properties.ContainsKey(SFSessionProperty.PROXYUSER)) + { + SFSessionProperty.PROXYPASSWORD.GetAttribute().required = true; + } + } + + checkSessionProperties(properties); + if (password != null) { properties[SFSessionProperty.PASSWORD] = new NetworkCredential(string.Empty, password).Password; } - checkSessionProperties(properties); // compose host value if not specified if (!properties.ContainsKey(SFSessionProperty.HOST) || @@ -223,7 +271,7 @@ private static void checkSessionProperties(SFSessionProperties properties) { SnowflakeDbException e = new SnowflakeDbException(SFError.MISSING_CONNECTION_PROPERTY, sessionProperty); - logger.Error("Missing connetion property", e); + logger.Error("Missing connection property", e); throw e; } diff --git a/Snowflake.Data/Core/SFStatement.cs b/Snowflake.Data/Core/SFStatement.cs index bd33e748f..95ffac4e2 100755 --- a/Snowflake.Data/Core/SFStatement.cs +++ b/Snowflake.Data/Core/SFStatement.cs @@ -39,19 +39,16 @@ class SFStatement private CancellationTokenSource _timeoutTokenSource; - // Merged cancellation token source for all canellation signal. + // Merged cancellation token source for all cancellation signal. // Cancel callback will be registered under token issued by this source. private CancellationTokenSource _linkedCancellationTokenSouce; - internal SFStatement(SFSession session, IRestRequester rest) + internal SFStatement(SFSession session) { SfSession = session; - _restRequester = rest; + _restRequester = session.restRequester; } - internal SFStatement(SFSession session) : this(session, RestRequester.Instance) - { } - private void AssignQueryRequestId() { lock (_requestIdLock) From 5d50386d0c0ff085fd23a53ff7d38c07d3b4a995 Mon Sep 17 00:00:00 2001 From: "SIMBA\\natachab" Date: Tue, 25 May 2021 11:37:43 -0700 Subject: [PATCH 02/13] Missed file. queryId->QueryId --- Snowflake.Data.Tests/SFDbCommandIT.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Snowflake.Data.Tests/SFDbCommandIT.cs b/Snowflake.Data.Tests/SFDbCommandIT.cs index d10830a4e..9f69130d3 100755 --- a/Snowflake.Data.Tests/SFDbCommandIT.cs +++ b/Snowflake.Data.Tests/SFDbCommandIT.cs @@ -309,7 +309,7 @@ public void TestDataSourceError() catch (SnowflakeDbException e) { Assert.AreEqual(2003, e.ErrorCode); - Assert.AreNotEqual("", e.queryId); + Assert.AreNotEqual("", e.QueryId); } conn.Close(); From cd80a381fb422feb05684517afb76cdcd628b01a Mon Sep 17 00:00:00 2001 From: "SIMBA\\natachab" Date: Tue, 25 May 2021 12:02:06 -0700 Subject: [PATCH 03/13] Remove local path --- Snowflake.Data.Tests/SFBaseTest.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Snowflake.Data.Tests/SFBaseTest.cs b/Snowflake.Data.Tests/SFBaseTest.cs index 27eec224b..a020b253b 100755 --- a/Snowflake.Data.Tests/SFBaseTest.cs +++ b/Snowflake.Data.Tests/SFBaseTest.cs @@ -88,7 +88,7 @@ public void SFTestSetup() String cloud = Environment.GetEnvironmentVariable("snowflake_cloud_env"); Assert.IsTrue(cloud == null || cloud == "AWS" || cloud == "AZURE" || cloud == "GCP", "{0} is not supported. Specify AWS, AZURE or GCP as cloud environment", cloud); - StreamReader reader = new StreamReader("C:\\Users\\natachab\\Snowflake\\fromMasterToWorkOnTicket\\Snowflake.Data.Tests\\parameters.json"); + StreamReader reader = new StreamReader("parameters.json"); var testConfigString = reader.ReadToEnd(); From 40b92bd381a0caa5a1a6988c0b5d7146ad0bea99 Mon Sep 17 00:00:00 2001 From: "SIMBA\\natachab" Date: Wed, 26 May 2021 09:25:17 -0700 Subject: [PATCH 04/13] Update connection string tests,fix secure pwd connection failure --- Snowflake.Data.Tests/SFConnectionIT.cs | 2 +- Snowflake.Data.Tests/SFSessionPropertyTest.cs | 1 + Snowflake.Data/Core/SFSessionProperty.cs | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Snowflake.Data.Tests/SFConnectionIT.cs b/Snowflake.Data.Tests/SFConnectionIT.cs index 766b0bc66..c12b48d95 100755 --- a/Snowflake.Data.Tests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/SFConnectionIT.cs @@ -964,7 +964,7 @@ public void TestInvalidProxySettingWithByPassListFromConnectionString() = ConnectionString + String.Format( ";useProxy=true;proxyHost=Invalid;proxyPort=8080;nonProxyHosts={0}", - "*.foo.com %7C" + testConfig.host + "|localhost"); + "*.foo.com %7C" + testConfig.account+".snowflakecomputing.com|localhost"); conn.Open(); // Because testConfig.host is in the bypass list, the proxy should not be used diff --git a/Snowflake.Data.Tests/SFSessionPropertyTest.cs b/Snowflake.Data.Tests/SFSessionPropertyTest.cs index c0d258e62..2fc56da68 100644 --- a/Snowflake.Data.Tests/SFSessionPropertyTest.cs +++ b/Snowflake.Data.Tests/SFSessionPropertyTest.cs @@ -46,6 +46,7 @@ public void TestValidConnectionString() { SFSessionProperty.PASSWORD, "123" }, { SFSessionProperty.PORT, "443" }, { SFSessionProperty.VALIDATE_DEFAULT_PARAMETERS, "true" }, + { SFSessionProperty.USEPROXY, "false" }, }, }, }; diff --git a/Snowflake.Data/Core/SFSessionProperty.cs b/Snowflake.Data/Core/SFSessionProperty.cs index 75f77e993..b3ca2596d 100755 --- a/Snowflake.Data/Core/SFSessionProperty.cs +++ b/Snowflake.Data/Core/SFSessionProperty.cs @@ -240,13 +240,13 @@ internal static SFSessionProperties parseConnectionString(String connectionStrin } } - checkSessionProperties(properties); - if (password != null) { properties[SFSessionProperty.PASSWORD] = new NetworkCredential(string.Empty, password).Password; } + checkSessionProperties(properties); + // compose host value if not specified if (!properties.ContainsKey(SFSessionProperty.HOST) || (0 == properties[SFSessionProperty.HOST].Length)) From f736a01b195d6a2e18afcd5f588f9562e6103c5e Mon Sep 17 00:00:00 2001 From: "SIMBA\\natachab" Date: Wed, 16 Jun 2021 14:22:07 -0700 Subject: [PATCH 05/13] local test changes --- Snowflake.Data.Tests/SFBaseTest.cs | 2 +- Snowflake.Data.Tests/SFConnectionIT.cs | 42 ++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/Snowflake.Data.Tests/SFBaseTest.cs b/Snowflake.Data.Tests/SFBaseTest.cs index 28df8337a..2c03b7d25 100755 --- a/Snowflake.Data.Tests/SFBaseTest.cs +++ b/Snowflake.Data.Tests/SFBaseTest.cs @@ -88,7 +88,7 @@ public void SFTestSetup() String cloud = Environment.GetEnvironmentVariable("snowflake_cloud_env"); Assert.IsTrue(cloud == null || cloud == "AWS" || cloud == "AZURE" || cloud == "GCP", "{0} is not supported. Specify AWS, AZURE or GCP as cloud environment", cloud); - StreamReader reader = new StreamReader("parameters.json"); + StreamReader reader = new StreamReader("C:\\Users\\natachab\\Snowflake\\fromMasterToWorkOnTicket\\Snowflake.Data.Tests\\parameters.json"); var testConfigString = reader.ReadToEnd(); diff --git a/Snowflake.Data.Tests/SFConnectionIT.cs b/Snowflake.Data.Tests/SFConnectionIT.cs index 415b2f73d..95ea92637 100755 --- a/Snowflake.Data.Tests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/SFConnectionIT.cs @@ -20,6 +20,48 @@ class SFConnectionIT : SFBaseTest { private static SFLogger logger = SFLoggerFactory.GetLogger(); + + [Test] + public void testMulitpleConnectionInParallel() + { + Task[] tasks = new Task[450]; + for (int i = 0; i < 450; i++) + { + tasks[i] = Task.Run(() => + { + using (IDbConnection conn = new SnowflakeDbConnection()) + { + + conn.ConnectionString = ConnectionString + ";CONNECTION_TIMEOUT=30;INSECUREMODE=false"; + Console.WriteLine($"{conn.ConnectionString}"); + + try + { + conn.Open(); + } + catch (Exception e) + { + Console.WriteLine(e); + Console.WriteLine("--------------------------"); + Console.WriteLine(e.InnerException); + } + } + }); + } + try + { + Task.WaitAll(tasks); + } + catch (AggregateException ae) + { + Console.WriteLine("One or more exceptions occurred: "); + foreach (var ex in ae.Flatten().InnerExceptions) + Console.WriteLine(" {0}", ex.Message); + + } + + } + [Test] public void TestBasicConnection() { From 83ba462b159ab29eecc6f8e3a134ffca40eb6492 Mon Sep 17 00:00:00 2001 From: "SIMBA\\natachab" Date: Wed, 16 Jun 2021 14:46:32 -0700 Subject: [PATCH 06/13] Test net46 by default --- Snowflake.Data.Tests/Snowflake.Data.Tests.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Snowflake.Data.Tests/Snowflake.Data.Tests.csproj b/Snowflake.Data.Tests/Snowflake.Data.Tests.csproj index c1e4fa54f..1e985e643 100755 --- a/Snowflake.Data.Tests/Snowflake.Data.Tests.csproj +++ b/Snowflake.Data.Tests/Snowflake.Data.Tests.csproj @@ -1,6 +1,6 @@  - netcoreapp2.2;net46 + net46;netcoreapp2.2 2.2.8 Snowflake.Data.Tests Snowflake Connector for .NET From 2d24b7b2b7db102fe261f9ff086d31a10dbdfa08 Mon Sep 17 00:00:00 2001 From: David Nawn <77072612+dnawnOlo@users.noreply.github.com> Date: Mon, 14 Jun 2021 16:10:07 -0400 Subject: [PATCH 07/13] Prevent unhandled exception when calling Cancel (#322) * Trap any exception thrown by the Cancel request --- Snowflake.Data/Core/SFStatement.cs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/Snowflake.Data/Core/SFStatement.cs b/Snowflake.Data/Core/SFStatement.cs index bd33e748f..455956894 100755 --- a/Snowflake.Data/Core/SFStatement.cs +++ b/Snowflake.Data/Core/SFStatement.cs @@ -149,7 +149,18 @@ private void registerQueryCancellationCallback(int timeout, CancellationToken ex externalCancellationToken); if (!_linkedCancellationTokenSouce.IsCancellationRequested) { - _linkedCancellationTokenSouce.Token.Register(() => Cancel()); + _linkedCancellationTokenSouce.Token.Register(() => + { + try + { + Cancel(); + } + catch (Exception ex) + { + // Prevent an unhandled exception from being thrown + logger.Error("Unable to cancel query.", ex); + } + }); } } From 239a2b5bc2931b817cb9f38a8733adf84bea49c3 Mon Sep 17 00:00:00 2001 From: Danny Guinther Date: Mon, 28 Jun 2021 19:02:37 -0400 Subject: [PATCH 08/13] Better dispose of CancellationTokenSources (#336) * Dispose of CancellationTokenSources to avoid memory leak --- Snowflake.Data/Core/HttpUtil.cs | 9 ++++++-- Snowflake.Data/Core/RestRequester.cs | 34 +++++++++++++++------------- Snowflake.Data/Core/SFStatement.cs | 33 ++++++++++++++++++++++----- 3 files changed, 52 insertions(+), 24 deletions(-) diff --git a/Snowflake.Data/Core/HttpUtil.cs b/Snowflake.Data/Core/HttpUtil.cs index 07ef2080a..278704892 100755 --- a/Snowflake.Data/Core/HttpUtil.cs +++ b/Snowflake.Data/Core/HttpUtil.cs @@ -234,12 +234,17 @@ protected override async Task SendAsync(HttpRequestMessage } } + if (childCts != null) + { + childCts.Dispose(); + } + if (response != null) { if (response.IsSuccessStatusCode) { return response; } - else + else { logger.Debug($"Failed Response: {response.ToString()}"); bool isRetryable = isRetryableHTTPCode((int)response.StatusCode); @@ -250,7 +255,7 @@ protected override async Task SendAsync(HttpRequestMessage } } } - else + else { logger.Info("Response returned was null."); } diff --git a/Snowflake.Data/Core/RestRequester.cs b/Snowflake.Data/Core/RestRequester.cs index da7ffecb0..36595020f 100644 --- a/Snowflake.Data/Core/RestRequester.cs +++ b/Snowflake.Data/Core/RestRequester.cs @@ -89,24 +89,26 @@ private async Task SendAsync(HttpMethod method, HttpRequestMessage message = request.ToRequestMessage(method); // merge multiple cancellation token - CancellationTokenSource restRequestTimeout = new CancellationTokenSource(request.GetRestTimeout()); - CancellationTokenSource linkedCts = CancellationTokenSource.CreateLinkedTokenSource(externalCancellationToken, - restRequestTimeout.Token); - - try - { - var response = await HttpUtil.getHttpClient() - .SendAsync(message, HttpCompletionOption.ResponseHeadersRead, linkedCts.Token) - .ConfigureAwait(false); - response.EnsureSuccessStatusCode(); - - return response; - } - catch(Exception e) + using (CancellationTokenSource restRequestTimeout = new CancellationTokenSource(request.GetRestTimeout())) { - throw restRequestTimeout.IsCancellationRequested ? new SnowflakeDbException(SFError.REQUEST_TIMEOUT) : e; + using (CancellationTokenSource linkedCts = CancellationTokenSource.CreateLinkedTokenSource(externalCancellationToken, + restRequestTimeout.Token)) + { + try + { + var response = await HttpUtil.getHttpClient(request.GetInsecureMode()) + .SendAsync(message, HttpCompletionOption.ResponseHeadersRead, linkedCts.Token) + .ConfigureAwait(false); + response.EnsureSuccessStatusCode(); + + return response; + } + catch(Exception e) + { + throw restRequestTimeout.IsCancellationRequested ? new SnowflakeDbException(SFError.REQUEST_TIMEOUT) : e; + } + } } } } - } diff --git a/Snowflake.Data/Core/SFStatement.cs b/Snowflake.Data/Core/SFStatement.cs index 455956894..c6de07a0d 100755 --- a/Snowflake.Data/Core/SFStatement.cs +++ b/Snowflake.Data/Core/SFStatement.cs @@ -41,7 +41,7 @@ class SFStatement // Merged cancellation token source for all canellation signal. // Cancel callback will be registered under token issued by this source. - private CancellationTokenSource _linkedCancellationTokenSouce; + private CancellationTokenSource _linkedCancellationTokenSource; internal SFStatement(SFSession session, IRestRequester rest) { @@ -119,6 +119,22 @@ private SFRestRequest BuildResultRequest(string resultPath) }; } + private void CleanUpCancellationTokenSources() + { + if (_linkedCancellationTokenSource != null) + { + // This should also take care of cleaning up the cancellation callback that was registered. + // https://github.com/microsoft/referencesource/blob/master/mscorlib/system/threading/CancellationTokenSource.cs#L552 + _linkedCancellationTokenSource.Dispose(); + _linkedCancellationTokenSource = null; + } + if (_timeoutTokenSource != null) + { + _timeoutTokenSource.Dispose(); + _timeoutTokenSource = null; + } + } + private SFBaseResultSet BuildResultSet(QueryExecResponse response, CancellationToken cancellationToken) { if (response.success) @@ -145,11 +161,11 @@ private void SetTimeout(int timeout) private void registerQueryCancellationCallback(int timeout, CancellationToken externalCancellationToken) { SetTimeout(timeout); - _linkedCancellationTokenSouce = CancellationTokenSource.CreateLinkedTokenSource(_timeoutTokenSource.Token, + _linkedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(_timeoutTokenSource.Token, externalCancellationToken); - if (!_linkedCancellationTokenSouce.IsCancellationRequested) + if (!_linkedCancellationTokenSource.IsCancellationRequested) { - _linkedCancellationTokenSouce.Token.Register(() => + _linkedCancellationTokenSource.Token.Register(() => { try { @@ -219,6 +235,7 @@ internal async Task ExecuteAsync(int timeout, string sql, Dicti } finally { + CleanUpCancellationTokenSources(); ClearQueryRequestId(); } } @@ -271,6 +288,7 @@ internal SFBaseResultSet Execute(int timeout, string sql, Dictionary(request); @@ -318,7 +339,7 @@ internal void Cancel() { logger.Warn("Query cancellation failed."); } + CleanUpCancellationTokenSources(); } - } } From efcd40d098bb1ececbe371a9c39e51eb05c4a5c1 Mon Sep 17 00:00:00 2001 From: SimbaGithub <48035983+SimbaGithub@users.noreply.github.com> Date: Fri, 2 Jul 2021 09:30:01 -0700 Subject: [PATCH 09/13] Dispose of the Http response to release resources (#339) Dispose of the Http response to release resources. This is now even more important because we switched to use http streaming on all platforms. --- .../Core/Authenticator/OktaAuthenticator.cs | 12 ++++++---- Snowflake.Data/Core/HttpUtil.cs | 3 +++ Snowflake.Data/Core/RestRequester.cs | 22 ++++++++++++------- Snowflake.Data/Core/SFChunkDownloaderV2.cs | 19 +++++++++------- 4 files changed, 36 insertions(+), 20 deletions(-) diff --git a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs index 9ce3a0848..005bb52ba 100644 --- a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs @@ -71,8 +71,10 @@ async Task IAuthenticator.AuthenticateAsync(CancellationToken cancellationToken) logger.Debug("step 4: get SAML reponse from sso"); var samlRestRequest = BuildSAMLRestRequest(ssoUrl, onetimeToken); - var samlRawResponse = await session.restRequester.GetAsync(samlRestRequest, cancellationToken).ConfigureAwait(false); - samlRawHtmlString = await samlRawResponse.Content.ReadAsStringAsync().ConfigureAwait(false); + using (var samlRawResponse = await session.restRequester.GetAsync(samlRestRequest, cancellationToken).ConfigureAwait(false)) + { + samlRawHtmlString = await samlRawResponse.Content.ReadAsStringAsync().ConfigureAwait(false); + } logger.Debug("step 5: verify postback url in SAML reponse"); VerifyPostbackUrl(); @@ -110,8 +112,10 @@ void IAuthenticator.Authenticate() logger.Debug("step 4: get SAML reponse from sso"); var samlRestRequest = BuildSAMLRestRequest(ssoUrl, onetimeToken); - var samlRawResponse = session.restRequester.Get(samlRestRequest); - samlRawHtmlString = Task.Run(async () => await samlRawResponse.Content.ReadAsStringAsync()).Result; + using (var samlRawResponse = session.restRequester.Get(samlRestRequest)) + { + samlRawHtmlString = Task.Run(async () => await samlRawResponse.Content.ReadAsStringAsync()).Result; + } logger.Debug("step 5: verify postback url in SAML reponse"); VerifyPostbackUrl(); diff --git a/Snowflake.Data/Core/HttpUtil.cs b/Snowflake.Data/Core/HttpUtil.cs index 278704892..49e76c7e3 100755 --- a/Snowflake.Data/Core/HttpUtil.cs +++ b/Snowflake.Data/Core/HttpUtil.cs @@ -260,6 +260,9 @@ protected override async Task SendAsync(HttpRequestMessage logger.Info("Response returned was null."); } + // Disposing of the response if not null now that we don't need it anymore + response?.Dispose(); + requestMessage.RequestUri = updater.Update(); logger.Debug($"Sleep {backOffInSec} seconds and then retry the request"); diff --git a/Snowflake.Data/Core/RestRequester.cs b/Snowflake.Data/Core/RestRequester.cs index 36595020f..ea11737c6 100644 --- a/Snowflake.Data/Core/RestRequester.cs +++ b/Snowflake.Data/Core/RestRequester.cs @@ -53,9 +53,11 @@ public T Post(IRestRequest request) public async Task PostAsync(IRestRequest request, CancellationToken cancellationToken) { - var response = await SendAsync(HttpMethod.Post, request, cancellationToken).ConfigureAwait(false); - var json = await response.Content.ReadAsStringAsync().ConfigureAwait(false); - return JsonConvert.DeserializeObject(json); + using (var response = await SendAsync(HttpMethod.Post, request, cancellationToken).ConfigureAwait(false)) + { + var json = await response.Content.ReadAsStringAsync().ConfigureAwait(false); + return JsonConvert.DeserializeObject(json); + } } public T Get(IRestRequest request) @@ -66,9 +68,11 @@ public T Get(IRestRequest request) public async Task GetAsync(IRestRequest request, CancellationToken cancellationToken) { - HttpResponseMessage response = await GetAsync(request, cancellationToken).ConfigureAwait(false); - var json = await response.Content.ReadAsStringAsync().ConfigureAwait(false); - return JsonConvert.DeserializeObject(json); + using (HttpResponseMessage response = await GetAsync(request, cancellationToken).ConfigureAwait(false)) + { + var json = await response.Content.ReadAsStringAsync().ConfigureAwait(false); + return JsonConvert.DeserializeObject(json); + } } public Task GetAsync(IRestRequest request, CancellationToken cancellationToken) @@ -87,16 +91,16 @@ private async Task SendAsync(HttpMethod method, CancellationToken externalCancellationToken) { HttpRequestMessage message = request.ToRequestMessage(method); - // merge multiple cancellation token using (CancellationTokenSource restRequestTimeout = new CancellationTokenSource(request.GetRestTimeout())) { using (CancellationTokenSource linkedCts = CancellationTokenSource.CreateLinkedTokenSource(externalCancellationToken, restRequestTimeout.Token)) { + HttpResponseMessage response = null; try { - var response = await HttpUtil.getHttpClient(request.GetInsecureMode()) + response = await HttpUtil.getHttpClient(request.GetInsecureMode()) .SendAsync(message, HttpCompletionOption.ResponseHeadersRead, linkedCts.Token) .ConfigureAwait(false); response.EnsureSuccessStatusCode(); @@ -105,6 +109,8 @@ private async Task SendAsync(HttpMethod method, } catch(Exception e) { + // Disposing of the response if not null now that we don't need it anymore + response?.Dispose(); throw restRequestTimeout.IsCancellationRequested ? new SnowflakeDbException(SFError.REQUEST_TIMEOUT) : e; } } diff --git a/Snowflake.Data/Core/SFChunkDownloaderV2.cs b/Snowflake.Data/Core/SFChunkDownloaderV2.cs index d94528f08..0efea6fee 100755 --- a/Snowflake.Data/Core/SFChunkDownloaderV2.cs +++ b/Snowflake.Data/Core/SFChunkDownloaderV2.cs @@ -130,18 +130,21 @@ private async Task DownloadChunkAsync(DownloadContextV2 downloadCo chunkHeaders = downloadContext.chunkHeaders }; - var httpResponse = await restRequester.GetAsync(downloadRequest, downloadContext.cancellationToken).ConfigureAwait(false); - Stream stream = await httpResponse.Content.ReadAsStreamAsync().ConfigureAwait(false); - - if (httpResponse.Content.Headers.TryGetValues("Content-Encoding", out var encoding)) + Stream stream = null; + using (var httpResponse = await restRequester.GetAsync(downloadRequest, downloadContext.cancellationToken).ConfigureAwait(false)) + using (stream = await httpResponse.Content.ReadAsStreamAsync().ConfigureAwait(false)) { - if (string.Equals(encoding.First(), "gzip", StringComparison.OrdinalIgnoreCase)) + + if (httpResponse.Content.Headers.TryGetValues("Content-Encoding", out var encoding)) { - stream = new GZipStream(stream, CompressionMode.Decompress); + if (string.Equals(encoding.First(), "gzip", StringComparison.OrdinalIgnoreCase)) + { + stream = new GZipStream(stream, CompressionMode.Decompress); + } } - } - parseStreamIntoChunk(stream, chunk); + parseStreamIntoChunk(stream, chunk); + } chunk.downloadState = DownloadState.SUCCESS; logger.Info($"Succeed downloading chunk #{downloadContext.chunkIndex+1}"); From cdb48d3182c6e2cdcbd7e5da3939df8182b6b774 Mon Sep 17 00:00:00 2001 From: SimbaGithub <48035983+SimbaGithub@users.noreply.github.com> Date: Mon, 5 Jul 2021 14:33:14 -0700 Subject: [PATCH 10/13] Add SecretDetector and unit tests (#335) * Add SecretDetector and unit tests * Mask messages to be logged with SecretDetector * Add capability to set custom patterns for SecretDetector * Change to array param type when setting custom patterns * Include colon in AWS regex and add test for masking HTTP response * Add test for AWS key with single and double quotes --- .../Mock/MockSecretDetector.cs | 31 ++ Snowflake.Data.Tests/SecretDetectorTest.cs | 473 ++++++++++++++++++ Snowflake.Data/Logger/Log4netImpl.cs | 5 + Snowflake.Data/Logger/SecretDetector.cs | 165 ++++++ 4 files changed, 674 insertions(+) create mode 100644 Snowflake.Data.Tests/Mock/MockSecretDetector.cs create mode 100644 Snowflake.Data.Tests/SecretDetectorTest.cs create mode 100644 Snowflake.Data/Logger/SecretDetector.cs diff --git a/Snowflake.Data.Tests/Mock/MockSecretDetector.cs b/Snowflake.Data.Tests/Mock/MockSecretDetector.cs new file mode 100644 index 000000000..5f4adadde --- /dev/null +++ b/Snowflake.Data.Tests/Mock/MockSecretDetector.cs @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using Snowflake.Data.Log; + +namespace Snowflake.Data.Tests.Mock +{ + class MockSecretDetector + { + public static SecretDetector.Mask MaskSecrets(string text) + { + SecretDetector.Mask result = new SecretDetector.Mask(); + try + { + throw new Exception("Test exception"); + } + catch (Exception ex) + { + //We'll assume that the exception was raised during masking + //to be safe consider that the log has sensitive information + //and do not raise an exception. + result.isMasked = true; + result.maskedText = ex.Message; + result.errStr = ex.Message; + } + return result; + } + } +} diff --git a/Snowflake.Data.Tests/SecretDetectorTest.cs b/Snowflake.Data.Tests/SecretDetectorTest.cs new file mode 100644 index 000000000..ff2ffd039 --- /dev/null +++ b/Snowflake.Data.Tests/SecretDetectorTest.cs @@ -0,0 +1,473 @@ +/* + * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + */ + +namespace Snowflake.Data.Tests +{ + using NUnit.Framework; + using Snowflake.Data.Log; + using Snowflake.Data.Tests.Mock; + using System; + using System.Collections.Generic; + + [TestFixture] + class SecretDetectorTest : SFBaseTest + { + SecretDetector.Mask mask; + + [SetUp] + public void BeforeTest() + { + mask = SecretDetector.MaskSecrets(null); + } + + public void BasicMasking(string text) + { + mask = SecretDetector.MaskSecrets(text); + Assert.IsFalse(mask.isMasked); + Assert.AreEqual(text, mask.maskedText); + Assert.IsNull(mask.errStr); + } + + [Test] + public void TestNullString() + { + BasicMasking(null); + } + + [Test] + public void TestEmptyString() + { + BasicMasking(""); + } + + [Test] + public void TestNoMasking() + { + BasicMasking("This string is innocuous"); + } + + [Test] + public void TestExceptionInMasking() + { + mask = MockSecretDetector.MaskSecrets("This string will raise an exception"); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual("Test exception", mask.maskedText); + Assert.AreEqual("Test exception", mask.errStr); + } + + public void BasicMasking(string text, string expectedText) + { + mask = SecretDetector.MaskSecrets(text); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual(expectedText, mask.maskedText); + Assert.IsNull(mask.errStr); + } + + [Test] + public void TestAWSKeys() + { + // aws_key_id + BasicMasking(@"aws_key_id='aaaaaaaa'", @"aws_key_id='****'"); + + // aws_secret_key + BasicMasking(@"aws_secret_key='aaaaaaaa'", @"aws_secret_key='****'"); + + // access_key_id + BasicMasking(@"access_key_id='aaaaaaaa'", @"access_key_id='****'"); + + // secret_access_key + BasicMasking(@"secret_access_key='aaaaaaaa'", @"secret_access_key='****'"); + + // aws_key_id with colon + BasicMasking(@"aws_key_id:'aaaaaaaa'", @"aws_key_id:'****'"); + + // aws_key_id with single quote on key + BasicMasking(@"'aws_key_id':'aaaaaaaa'", @"'aws_key_id':'****'"); + + // aws_key_id with double quotes on key + BasicMasking(@"""aws_key_id"":'aaaaaaaa'", @"""aws_key_id"":'****'"); + } + + [Test] + public void TestAWSTokens() + { + // accessToken + BasicMasking(@"accessToken:""aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa""", @"accessToken"":""XXXX"""); + + // tempToken + BasicMasking(@"tempToken:""aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa""", @"tempToken"":""XXXX"""); + + // keySecret + BasicMasking(@"keySecret:""aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa""", @"keySecret"":""XXXX"""); + } + + [Test] + public void TestSASTokens() + { + // sig + BasicMasking(@"sig=?Paaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", @"sig=****"); + + // signature + BasicMasking(@"signature=?Paaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", @"signature=****"); + + // AWSAccessKeyId + BasicMasking(@"AWSAccessKeyId=?Paaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", @"AWSAccessKeyId=****"); + + // password + BasicMasking(@"password=?Paaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", @"password=****"); + + // passcode + BasicMasking(@"passcode=?Paaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", @"passcode=****"); + } + + [Test] + public void TestPrivateKey() + { + BasicMasking("-----BEGIN PRIVATE KEY-----\naaaaaaaaaaaaaaaa\naaaaaaaaaaaaaaaa\n-----END PRIVATE KEY-----", + "-----BEGIN PRIVATE KEY-----\\\\nXXXX\\\\n-----END PRIVATE KEY-----"); + } + + [Test] + public void TestPrivateKeyData() + { + BasicMasking(@"""privateKeyData"": ""aaaaaaaaaa""", @"""privateKeyData"": ""XXXX"""); + } + + [Test] + public void TestConnectionTokens() + { + // token + BasicMasking(@"token:aaaaaaaa", @"token:****"); + + // assertion content + BasicMasking(@"assertion content:aaaaaaaa", @"assertion content:****"); + } + + [Test] + public void TestPassword() + { + // password + BasicMasking(@"password:aaaaaaaa", @"password:****"); + + // pwd + BasicMasking(@"pwd:aaaaaaaa", @"pwd:****"); + } + + [Test] + public void TestMaskToken() + { + string longToken = "_Y1ZNETTn5/qfUWj3Jedby7gipDzQs=U" + + "KyJH9DS=nFzzWnfZKGV+C7GopWCGD4Lj" + + "OLLFZKOE26LXHDt3pTi4iI1qwKuSpf/F" + + "mClCMBSissVsU3Ei590FP0lPQQhcSGcD" + + "u69ZL_1X6e9h5z62t/iY7ZkII28n2qU=" + + "nrBJUgPRCIbtJQkVJXIuOHjX4G5yUEKj" + + "ZBAx4w6=_lqtt67bIA=o7D=oUSjfywsR" + + "FoloNIkBPXCwFTv+1RVUHgVA2g8A9Lw5" + + "XdJYuI8vhg=f0bKSq7AhQ2Bh"; + + string tokenStrWithPrefix = "Token =" + longToken; + mask = SecretDetector.MaskSecrets(tokenStrWithPrefix); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual(@"Token =****", mask.maskedText); + Assert.IsNull(mask.errStr); + + string idTokenStrWithPrefix = "idToken : " + longToken; + mask = SecretDetector.MaskSecrets(idTokenStrWithPrefix); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual(@"idToken : ****", mask.maskedText); + Assert.IsNull(mask.errStr); + + string sessionTokenStrWithPrefix = "sessionToken : " + longToken; + mask = SecretDetector.MaskSecrets(sessionTokenStrWithPrefix); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual(@"sessionToken : ****", mask.maskedText); + Assert.IsNull(mask.errStr); + + string masterTokenStrWithPrefix = "masterToken : " + longToken; + mask = SecretDetector.MaskSecrets(masterTokenStrWithPrefix); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual(@"masterToken : ****", mask.maskedText); + Assert.IsNull(mask.errStr); + + string assertionStrWithPrefix = "assertion content: " + longToken; + mask = SecretDetector.MaskSecrets(assertionStrWithPrefix); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual(@"assertion content: ****", mask.maskedText); + Assert.IsNull(mask.errStr); + } + + [Test] + public void TestTokenFalsePositive() + { + string falsePositiveToken = "2020-04-30 23:06:04,069 - MainThread auth.py:397" + + " - write_temporary_credential() - DEBUG - no ID " + + "token is given when try to store temporary credential"; + + mask = SecretDetector.MaskSecrets(falsePositiveToken); + Assert.IsFalse(mask.isMasked); + Assert.AreEqual(falsePositiveToken, mask.maskedText); + Assert.IsNull(mask.errStr); + } + + [Test] + public void TestPasswords() + { + string randomPassword = "Fh[+2J~AcqeqW%?"; + + string randomPasswordWithPrefix = "password:" + randomPassword; + mask = SecretDetector.MaskSecrets(randomPasswordWithPrefix); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual(@"password:****", mask.maskedText); + Assert.IsNull(mask.errStr); + + string randomPasswordCaps = "PASSWORD:" + randomPassword; + mask = SecretDetector.MaskSecrets(randomPasswordCaps); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual(@"PASSWORD:****", mask.maskedText); + Assert.IsNull(mask.errStr); + + string randomPasswordMixCase = "PassWorD:" + randomPassword; + mask = SecretDetector.MaskSecrets(randomPasswordMixCase); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual(@"PassWorD:****", mask.maskedText); + Assert.IsNull(mask.errStr); + + string randomPasswordEqualSign = "password = " + randomPassword; + mask = SecretDetector.MaskSecrets(randomPasswordEqualSign); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual(@"password = ****", mask.maskedText); + Assert.IsNull(mask.errStr); + + string randomPwdWithPrefix = "pwd:" + randomPassword; + mask = SecretDetector.MaskSecrets(randomPwdWithPrefix); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual(@"pwd:****", mask.maskedText); + Assert.IsNull(mask.errStr); + } + + + [Test] + public void TestTokenPassword() + { + string longToken = "_Y1ZNETTn5/qfUWj3Jedby7gipDzQs=U" + + "KyJH9DS=nFzzWnfZKGV+C7GopWCGD4Lj" + + "OLLFZKOE26LXHDt3pTi4iI1qwKuSpf/F" + + "mClCMBSissVsU3Ei590FP0lPQQhcSGcD" + + "u69ZL_1X6e9h5z62t/iY7ZkII28n2qU=" + + "nrBJUgPRCIbtJQkVJXIuOHjX4G5yUEKj" + + "ZBAx4w6=_lqtt67bIA=o7D=oUSjfywsR" + + "FoloNIkBPXCwFTv+1RVUHgVA2g8A9Lw5" + + "XdJYuI8vhg=f0bKSq7AhQ2Bh"; + + string longToken2 = "ktL57KJemuq4-M+Q0pdRjCIMcf1mzcr" + + "MwKteDS5DRE/Pb+5MzvWjDH7LFPV5b_" + + "/tX/yoLG3b4TuC6Q5qNzsARPPn_zs/j" + + "BbDOEg1-IfPpdsbwX6ETeEnhxkHIL4H" + + "sP-V"; + + string randomPwd = "Fh[+2J~AcqeqW%?"; + string randomPwd2 = randomPwd + "vdkav13"; + + string testStringWithPrefix = "token=" + longToken + + " random giberish " + + "password:" + randomPwd; + mask = SecretDetector.MaskSecrets(testStringWithPrefix); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual( + "token=****" + + " random giberish " + + "password:****", + mask.maskedText); + Assert.IsNull(mask.errStr); + + // order reversed + testStringWithPrefix = "password:" + randomPwd + + " random giberish " + + "token=" + longToken; + mask = SecretDetector.MaskSecrets(testStringWithPrefix); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual( + "password:****" + + " random giberish " + + "token=****", + mask.maskedText); + Assert.IsNull(mask.errStr); + + // multiple tokens and password + testStringWithPrefix = "token=" + longToken + + " random giberish " + + "password:" + randomPwd + + " random giberish " + + "idToken:" + longToken2; + mask = SecretDetector.MaskSecrets(testStringWithPrefix); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual( + "token=****" + + " random giberish " + + "password:****" + + " random giberish " + + "idToken:****", + mask.maskedText); + Assert.IsNull(mask.errStr); + + // two passwords + testStringWithPrefix = "password=" + randomPwd + + " random giberish " + + "pwd:" + randomPwd2; + mask = SecretDetector.MaskSecrets(testStringWithPrefix); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual( + "password=****" + + " random giberish " + + "pwd:****", + mask.maskedText); + Assert.IsNull(mask.errStr); + + // multiple passwords + testStringWithPrefix = "password=" + randomPwd + + " random giberish " + + "password=" + randomPwd2 + + " random giberish " + + "password=" + randomPwd; + mask = SecretDetector.MaskSecrets(testStringWithPrefix); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual( + "password=****" + + " random giberish " + + "password=****" + + " random giberish " + + "password=****", + mask.maskedText); + Assert.IsNull(mask.errStr); + } + + [Test] + public void TestCustomPattern() + { + string[] regex = new string[2] + { + @"(testCustomPattern\s*:\s*""([a-z]{8,})"")", + @"(testCustomPattern\s*:\s*""([0-9]{8,})"")" + }; + string[] masks = new string[2] + { + "maskCustomPattern1", + "maskCustomPattern2" + }; + + SecretDetector.SetCustomPatterns(regex, masks); + + // Mask custom pattern + string testString = "testCustomPattern: \"abcdefghijklmnop\""; + mask = SecretDetector.MaskSecrets(testString); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual(masks[0], mask.maskedText); + Assert.IsNull(mask.errStr); + + testString = "testCustomPattern: \"1234567890\""; + mask = SecretDetector.MaskSecrets(testString); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual(masks[1], mask.maskedText); + Assert.IsNull(mask.errStr); + + // Mask password and custom pattern + testString = "password: abcdefghijklmnop testCustomPattern: \"abcdefghijklmnop\""; + mask = SecretDetector.MaskSecrets(testString); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual("password: **** " + masks[0], mask.maskedText); + Assert.IsNull(mask.errStr); + + testString = "password: abcdefghijklmnop testCustomPattern: \"1234567890\""; + mask = SecretDetector.MaskSecrets(testString); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual("password: **** " + masks[1], mask.maskedText); + Assert.IsNull(mask.errStr); + } + + [Test] + public void TestCustomPatternClear() + { + string[] regex = new string[1] { @"(testCustomPattern\s*:\s*""([a-z]{8,})"")" }; + string[] masks = new string[1] { "maskCustomPattern1" }; + + SecretDetector.SetCustomPatterns(regex, masks); + + // Mask custom pattern + string testString = "testCustomPattern: \"abcdefghijklmnop\""; + mask = SecretDetector.MaskSecrets(testString); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual(masks[0], mask.maskedText); + Assert.IsNull(mask.errStr); + + // Clear custom patterns + SecretDetector.ClearCustomPatterns(); + testString = "testCustomPattern: \"abcdefghijklmnop\""; + mask = SecretDetector.MaskSecrets(testString); + Assert.IsFalse(mask.isMasked); + Assert.AreEqual(testString, mask.maskedText); + Assert.IsNull(mask.errStr); + } + + [Test] + public void TestCustomPatternUnequalCount() + { + string[] regex = new string[0]; + string[] masks = new string[1] { "maskCustomPattern1" }; + + // Masks count is greater than regex + try + { + SecretDetector.SetCustomPatterns(regex, masks); + } + catch (Exception ex) + { + Assert.AreEqual("Regex count and mask count must be equal.", ex.Message); + } + + // Regex count is greater than masks + regex = new string[2] + { + @"(testCustomPattern\s*:\s*""([0-9]{8,})"")", + @"(testCustomPattern\s*:\s*""([0-9]{8,})"")" + }; + try + { + SecretDetector.SetCustomPatterns(regex, masks); + } + catch (Exception ex) + { + Assert.AreEqual("Regex count and mask count must be equal.", ex.Message); + } + } + + [Test] + public void TestHttpResponse() + { + string randomHttpResponse = + "\"data\" : {" + + "\"masterToken\" : \"_Y1ZNETTn5/qfUWj3Jedby7gipDzQs=U" + + "\"token\" : \"_Y1ZNETTn5/qfUWj3Jedby7gipDzQs=U" + + "\"remMeValidityInSeconds\" : 0," + + "\"healthCheckInterval\" : 12," + + "\"newClientForUpgrade\" : null," + + "\"sessionId\" : 1234"; + + string randomHttpResponseWithPrefix = "Post response: " + randomHttpResponse; + mask = SecretDetector.MaskSecrets(randomHttpResponseWithPrefix); + Assert.IsTrue(mask.isMasked); + Assert.AreEqual( + "Post response: " + + "\"data\" : {" + + "\"masterToken\" : \"****" + + "\"token\" : \"****" + + "\"remMeValidityInSeconds\" : 0," + + "\"healthCheckInterval\" : 12," + + "\"newClientForUpgrade\" : null," + + "\"sessionId\" : 1234", + mask.maskedText); + Assert.IsNull(mask.errStr); + } + } +} diff --git a/Snowflake.Data/Logger/Log4netImpl.cs b/Snowflake.Data/Logger/Log4netImpl.cs index 44baa3afc..4a9ec00d6 100755 --- a/Snowflake.Data/Logger/Log4netImpl.cs +++ b/Snowflake.Data/Logger/Log4netImpl.cs @@ -45,27 +45,32 @@ public bool IsFatalEnabled() public void Debug(string msg, Exception ex = null) { + msg = SecretDetector.MaskSecrets(msg).maskedText; logger.Debug(msg, ex); } public void Info(string msg, Exception ex = null) { + msg = SecretDetector.MaskSecrets(msg).maskedText; logger.Info(msg, ex); } public void Warn(string msg, Exception ex = null) { + msg = SecretDetector.MaskSecrets(msg).maskedText; logger.Warn(msg, ex); } public void Error(string msg, Exception ex = null) { + msg = SecretDetector.MaskSecrets(msg).maskedText; logger.Error(msg, ex); } public void Fatal(string msg, Exception ex = null) { + msg = SecretDetector.MaskSecrets(msg).maskedText; logger.Fatal(msg, ex); } } diff --git a/Snowflake.Data/Logger/SecretDetector.cs b/Snowflake.Data/Logger/SecretDetector.cs new file mode 100644 index 000000000..971b400d5 --- /dev/null +++ b/Snowflake.Data/Logger/SecretDetector.cs @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Collections.Generic; +using System.Text.RegularExpressions; + +namespace Snowflake.Data.Log +{ + class SecretDetector + { + public struct Mask + { + public Mask(bool isMasked = false, string maskedText = null, string errStr = null) + { + this.isMasked = isMasked; + this.maskedText = maskedText; + this.errStr = errStr; + } + + public bool isMasked { get; set; } + public string maskedText { get; set; } + public string errStr { get; set; } + } + + private static List CUSTOM_PATTERNS_REGEX = new List(); + private static List CUSTOM_PATTERNS_MASK = new List(); + private static int CUSTOM_PATTERNS_LENGTH; + + public static void SetCustomPatterns(string[] customRegex, string[] customMask) + { + if (CUSTOM_PATTERNS_LENGTH != 0) + { + ClearCustomPatterns(); + } + + if (customRegex.Length == customMask.Length) + { + CUSTOM_PATTERNS_LENGTH = customRegex.Length; + for (int index = 0; index < CUSTOM_PATTERNS_LENGTH; index++) + { + CUSTOM_PATTERNS_REGEX.Add(customRegex[index]); + CUSTOM_PATTERNS_MASK.Add(customMask[index]); + } + } + else + { + throw new ArgumentException("Regex count and mask count must be equal."); + } + } + + public static void ClearCustomPatterns() + { + CUSTOM_PATTERNS_REGEX.Clear(); + CUSTOM_PATTERNS_MASK.Clear(); + CUSTOM_PATTERNS_LENGTH = 0; + } + + private static string MaskCustomPatterns(string text) + { + string result; + for (int index = 0; index < CUSTOM_PATTERNS_LENGTH; index++) + { + result = Regex.Replace(text, CUSTOM_PATTERNS_REGEX[index], CUSTOM_PATTERNS_MASK[index], + RegexOptions.IgnoreCase); + + if (result != text) + { + return result; + } + } + return text; + } + private static readonly string AWS_KEY_PATTERN = @"('|"")?(aws_key_id|aws_secret_key|access_key_id|secret_access_key)('|"")?\s*(=|:)\s*'([^']+)'"; + private static readonly string AWS_TOKEN_PATTERN = @"(accessToken|tempToken|keySecret)\s*:\s*""([a-z0-9/+]{32,}={0,2})"""; + private static readonly string SAS_TOKEN_PATTERN = @"(sig|signature|AWSAccessKeyId|password|passcode)=(\?P[a-z0-9%/+]{16,})"; + private static readonly string PRIVATE_KEY_PATTERN = @"-----BEGIN PRIVATE KEY-----\n([a-z0-9/+=\n]{32,})\n-----END PRIVATE KEY-----"; + private static readonly string PRIVATE_KEY_DATA_PATTERN = @"""privateKeyData"": ""([a - z0 - 9 /+=\\n]{10,})"""; + private static readonly string CONNECTION_TOKEN_PATTERN = @"(token|assertion content)([\'\""\s:=]+)([a-z0-9=/_\-\+]{8,})"; + private static readonly string PASSWORD_PATTERN = @"(password|pwd)([\'\""\s:=]+)([a-z0-9!\""#\$%&\\\'\(\)\*\+\,-\./:;<=>\?\@\[\]\^_`\{\|\}~]{8,})"; + + private static string MaskAWSKeys(string text) + { + return Regex.Replace(text, AWS_KEY_PATTERN, @"$1$2$3$4'****'", + RegexOptions.IgnoreCase); + } + + private static string MaskAWSTokens(string text) + { + return Regex.Replace(text, AWS_TOKEN_PATTERN, @"$1"":""XXXX""", + RegexOptions.IgnoreCase); + } + + private static string MaskSASTokens(string text) + { + return Regex.Replace(text, SAS_TOKEN_PATTERN, @"$1=****", + RegexOptions.IgnoreCase); + } + + private static string MaskPrivateKey(string text) + { + return Regex.Replace(text, PRIVATE_KEY_PATTERN, "-----BEGIN PRIVATE KEY-----\\\\nXXXX\\\\n-----END PRIVATE KEY-----", + RegexOptions.IgnoreCase | RegexOptions.Multiline); + } + + private static string MaskPrivateKeyData(string text) + { + return Regex.Replace(text, PRIVATE_KEY_DATA_PATTERN, @"""privateKeyData"": ""XXXX""", + RegexOptions.IgnoreCase | RegexOptions.Multiline); + } + + private static string MaskConnectionTokens(string text) + { + return Regex.Replace(text, CONNECTION_TOKEN_PATTERN, @"$1$2****", + RegexOptions.IgnoreCase); + } + + private static string MaskPassword(string text) + { + return Regex.Replace(text, PASSWORD_PATTERN, @"$1$2****", + RegexOptions.IgnoreCase); + } + + public static Mask MaskSecrets(string text) + { + Mask result = new Mask(maskedText: text); + + if (String.IsNullOrEmpty(text)) + { + return result; + } + + try + { + result.maskedText = + MaskConnectionTokens( + MaskPassword( + MaskPrivateKeyData( + MaskPrivateKey( + MaskAWSTokens( + MaskSASTokens( + MaskAWSKeys(text))))))); + if (CUSTOM_PATTERNS_LENGTH > 0) + { + result.maskedText = MaskCustomPatterns(result.maskedText); + } + if (result.maskedText != text) + { + result.isMasked = true; + } + } + catch (Exception ex) + { + //We'll assume that the exception was raised during masking + //to be safe consider that the log has sensitive information + //and do not raise an exception. + result.isMasked = true; + result.maskedText = ex.Message; + result.errStr = ex.Message; + } + return result; + } + } +} From e71527faa7e116589c8f990495df4b09ee7311e4 Mon Sep 17 00:00:00 2001 From: SimbaGithub <48035983+SimbaGithub@users.noreply.github.com> Date: Tue, 6 Jul 2021 18:33:06 -0700 Subject: [PATCH 11/13] Re-add Http message logging (#342) --- Snowflake.Data/Core/HttpUtil.cs | 1 + Snowflake.Data/Core/RestRequester.cs | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/Snowflake.Data/Core/HttpUtil.cs b/Snowflake.Data/Core/HttpUtil.cs index 49e76c7e3..7f643c927 100755 --- a/Snowflake.Data/Core/HttpUtil.cs +++ b/Snowflake.Data/Core/HttpUtil.cs @@ -242,6 +242,7 @@ protected override async Task SendAsync(HttpRequestMessage if (response != null) { if (response.IsSuccessStatusCode) { + logger.Debug($"Success Response: {response.ToString()}"); return response; } else diff --git a/Snowflake.Data/Core/RestRequester.cs b/Snowflake.Data/Core/RestRequester.cs index ea11737c6..31baf0fd5 100644 --- a/Snowflake.Data/Core/RestRequester.cs +++ b/Snowflake.Data/Core/RestRequester.cs @@ -56,6 +56,7 @@ public async Task PostAsync(IRestRequest request, CancellationToken cancel using (var response = await SendAsync(HttpMethod.Post, request, cancellationToken).ConfigureAwait(false)) { var json = await response.Content.ReadAsStringAsync().ConfigureAwait(false); + logger.Debug($"Post response: {json}"); return JsonConvert.DeserializeObject(json); } } @@ -71,17 +72,23 @@ public async Task GetAsync(IRestRequest request, CancellationToken cancell using (HttpResponseMessage response = await GetAsync(request, cancellationToken).ConfigureAwait(false)) { var json = await response.Content.ReadAsStringAsync().ConfigureAwait(false); + logger.Debug($"Get response: {json}"); return JsonConvert.DeserializeObject(json); } } public Task GetAsync(IRestRequest request, CancellationToken cancellationToken) { + HttpRequestMessage message = request.ToRequestMessage(HttpMethod.Get); + logger.Debug($"Http method: {message.ToString()}, http request message: {message.ToString()}"); return SendAsync(HttpMethod.Get, request, cancellationToken); } public HttpResponseMessage Get(IRestRequest request) { + HttpRequestMessage message = request.ToRequestMessage(HttpMethod.Get); + logger.Debug($"Http method: {message.ToString()}, http request message: {message.ToString()}"); + //Run synchronous in a new thread-pool task. return Task.Run(async () => await GetAsync(request, CancellationToken.None)).Result; } @@ -100,7 +107,7 @@ private async Task SendAsync(HttpMethod method, HttpResponseMessage response = null; try { - response = await HttpUtil.getHttpClient(request.GetInsecureMode()) + response = await HttpUtil.getHttpClient() .SendAsync(message, HttpCompletionOption.ResponseHeadersRead, linkedCts.Token) .ConfigureAwait(false); response.EnsureSuccessStatusCode(); From 350ce78731383a633bf5b4f039eae2f5cf0afab0 Mon Sep 17 00:00:00 2001 From: "SIMBA\\natachab" Date: Thu, 22 Jul 2021 10:49:45 -0700 Subject: [PATCH 12/13] Remove local path from tests --- Snowflake.Data.Tests/SFBaseTest.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Snowflake.Data.Tests/SFBaseTest.cs b/Snowflake.Data.Tests/SFBaseTest.cs index 2c03b7d25..5d6e86af9 100755 --- a/Snowflake.Data.Tests/SFBaseTest.cs +++ b/Snowflake.Data.Tests/SFBaseTest.cs @@ -88,8 +88,7 @@ public void SFTestSetup() String cloud = Environment.GetEnvironmentVariable("snowflake_cloud_env"); Assert.IsTrue(cloud == null || cloud == "AWS" || cloud == "AZURE" || cloud == "GCP", "{0} is not supported. Specify AWS, AZURE or GCP as cloud environment", cloud); - StreamReader reader = new StreamReader("C:\\Users\\natachab\\Snowflake\\fromMasterToWorkOnTicket\\Snowflake.Data.Tests\\parameters.json"); - + StreamReader reader = new StreamReader("parameters.json"); var testConfigString = reader.ReadToEnd(); Dictionary testConfigs = JsonConvert.DeserializeObject>(testConfigString); From 1db22caec1c5b0be5950b908ce7f3446d3913bbb Mon Sep 17 00:00:00 2001 From: sfc-gh-abhatnagar Date: Thu, 22 Jul 2021 11:46:39 -0700 Subject: [PATCH 13/13] SNOW-400208 Version Bump from 1.2.4 to 1.2.5 --- Snowflake.Data/Snowflake.Data.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Snowflake.Data/Snowflake.Data.csproj b/Snowflake.Data/Snowflake.Data.csproj index 2103ab8c7..dc7cd8bf6 100755 --- a/Snowflake.Data/Snowflake.Data.csproj +++ b/Snowflake.Data/Snowflake.Data.csproj @@ -12,7 +12,7 @@ Snowflake Connector for .NET howryu, tchen Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. - 1.2.4 + 1.2.5 Full