diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index cc4fea738..d80a6cbae 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. */ diff --git a/Snowflake.Data.Tests/SFBaseTest.cs b/Snowflake.Data.Tests/SFBaseTest.cs index 6aacb94f9..43e4d92a7 100755 --- a/Snowflake.Data.Tests/SFBaseTest.cs +++ b/Snowflake.Data.Tests/SFBaseTest.cs @@ -1,447 +1,447 @@ -/* - * Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. - */ - -using System; -using System.Collections.Generic; -using System.Data; -using System.Diagnostics; -using System.Reflection; -using System.IO; -using System.Linq; -using System.Runtime.InteropServices; -using NUnit.Framework; -using Snowflake.Data.Client; -using Snowflake.Data.Log; -using Snowflake.Data.Tests.Util; - -[assembly:LevelOfParallelism(10)] - -namespace Snowflake.Data.Tests -{ - using NUnit.Framework; - using NUnit.Framework.Interfaces; - using Newtonsoft.Json; - using Newtonsoft.Json.Serialization; - - /* - * This is the base class for all tests that call blocking methods in the library - it uses MockSynchronizationContext to verify that - * there are no async deadlocks in the library - * - */ - [TestFixture] - public class SFBaseTest : SFBaseTestAsync - { - [SetUp] - public static void SetUpContext() - { - MockSynchronizationContext.SetupContext(); - } - - [TearDown] - public static void TearDownContext() - { - MockSynchronizationContext.Verify(); - } - } - - /* - * This is the base class for all tests that call async methods in the library - it does not use a special SynchronizationContext - * - */ - [TestFixture] - [FixtureLifeCycle(LifeCycle.InstancePerTestCase)] - [SetCulture("en-US")] - #if !SEQUENTIAL_TEST_RUN - [Parallelizable(ParallelScope.All)] - #endif - public class SFBaseTestAsync - { - private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); - - private const string ConnectionStringWithoutAuthFmt = "scheme={0};host={1};port={2};" + - "account={3};role={4};db={5};schema={6};warehouse={7}"; - private const string ConnectionStringSnowflakeAuthFmt = ";user={0};password={1};"; - protected virtual string TestName => TestContext.CurrentContext.Test.MethodName; - protected string TestNameWithWorker => TestName + TestContext.CurrentContext.WorkerId?.Replace("#", "_"); - protected string TableName => TestNameWithWorker; - - - private Stopwatch _stopwatch; - - private List _tablesToRemove; - - [SetUp] - public void BeforeTest() - { - _stopwatch = new Stopwatch(); - _stopwatch.Start(); - _tablesToRemove = new List(); - } - - [TearDown] - public void AfterTest() - { - _stopwatch.Stop(); - var testName = $"{TestContext.CurrentContext.Test.FullName}"; - - TestEnvironment.RecordTestPerformance(testName, _stopwatch.Elapsed); - RemoveTables(); - } - - private void RemoveTables() - { - if (_tablesToRemove.Count == 0) - return; - - using (var conn = new SnowflakeDbConnection(ConnectionString)) - { - conn.Open(); - - var cmd = conn.CreateCommand(); - - foreach (var table in _tablesToRemove) - { - cmd.CommandText = $"DROP TABLE IF EXISTS {table}"; - cmd.ExecuteNonQuery(); - } - } - } - - protected void CreateOrReplaceTable(IDbConnection conn, string tableName, IEnumerable columns, string additionalQueryStr = null) - { - CreateOrReplaceTable(conn, tableName, "", columns, additionalQueryStr); - } - - protected void CreateOrReplaceTable(IDbConnection conn, string tableName, string tableType, IEnumerable columns, string additionalQueryStr = null) - { - var columnsStr = string.Join(", ", columns); - var cmd = conn.CreateCommand(); - cmd.CommandText = $"CREATE OR REPLACE {tableType} TABLE {tableName}({columnsStr}) {additionalQueryStr}"; - s_logger.Debug(cmd.CommandText); - cmd.ExecuteNonQuery(); - - _tablesToRemove.Add(tableName); - } - - protected void AddTableToRemoveList(string tableName) - { - _tablesToRemove.Add(tableName); - } - - public SFBaseTestAsync() - { - testConfig = TestEnvironment.TestConfig; - } - - protected string ConnectionStringWithoutAuth => string.Format(ConnectionStringWithoutAuthFmt, - testConfig.protocol, - testConfig.host, - testConfig.port, - testConfig.account, - testConfig.role, - testConfig.database, - testConfig.schema, - testConfig.warehouse); - - protected string ConnectionString => ConnectionStringWithoutAuth + - string.Format(ConnectionStringSnowflakeAuthFmt, - testConfig.user, - testConfig.password); - - protected string ConnectionStringWithInvalidUserName => ConnectionStringWithoutAuth + - string.Format(ConnectionStringSnowflakeAuthFmt, - "unknown", - testConfig.password); - - protected TestConfig testConfig { get; } - - protected string ResolveHost() - { - return testConfig.host ?? $"{testConfig.account}.snowflakecomputing.com"; - } - } - - [SetUpFixture] - public class TestEnvironment - { - private const string ConnectionStringFmt = "scheme={0};host={1};port={2};" + - "account={3};role={4};db={5};warehouse={6};user={7};password={8};"; - - public static TestConfig TestConfig { get; private set; } - - private static Dictionary s_testPerformance; - - private static readonly object s_testPerformanceLock = new object(); - - public static void RecordTestPerformance(string name, TimeSpan time) - { - lock (s_testPerformanceLock) - { - s_testPerformance.Add(name, time); - } - } - - [OneTimeSetUp] - public void Setup() - { -#if NETFRAMEWORK - log4net.GlobalContext.Properties["framework"] = "net471"; - log4net.Config.XmlConfigurator.Configure(); - -#else - log4net.GlobalContext.Properties["framework"] = "net6.0"; - var logRepository = log4net.LogManager.GetRepository(Assembly.GetEntryAssembly()); - log4net.Config.XmlConfigurator.Configure(logRepository, new FileInfo("App.config")); -#endif - var cloud = Environment.GetEnvironmentVariable("snowflake_cloud_env"); - Assert.IsTrue(cloud == null || cloud == "AWS" || cloud == "AZURE" || cloud == "GCP", "{0} is not supported. Specify AWS, AZURE or GCP as cloud environment", cloud); - - var reader = new StreamReader("parameters.json"); - - var testConfigString = reader.ReadToEnd(); - - // Local JSON settings to avoid using system wide settings which could be different - // than the default ones - var jsonSettings = new JsonSerializerSettings - { - ContractResolver = new DefaultContractResolver - { - NamingStrategy = new DefaultNamingStrategy() - } - }; - var testConfigs = JsonConvert.DeserializeObject>(testConfigString, jsonSettings); - - if (testConfigs.TryGetValue("testconnection", out var testConnectionConfig)) - { - TestConfig = testConnectionConfig; - TestConfig.schema = TestConfig.schema + "_" + Guid.NewGuid().ToString().Replace("-", "_"); - } - else - { - Assert.Fail("Failed to load test configuration"); - } - - ModifySchema(TestConfig.schema, SchemaAction.CREATE); - } - - [OneTimeTearDown] - public void Cleanup() - { - ModifySchema(TestConfig.schema, SchemaAction.DROP); - } - - [OneTimeSetUp] - public void SetupTestPerformance() - { - s_testPerformance = new Dictionary(); - } - - [OneTimeTearDown] - public void CreateTestTimeArtifact() - { - var resultText = "test;time_in_ms\n"; - resultText += string.Join("\n", - s_testPerformance.Select(test => $"{test.Key};{Math.Round(test.Value.TotalMilliseconds,0)}")); - - var dotnetVersion = Environment.GetEnvironmentVariable("net_version"); - var cloudEnv = Environment.GetEnvironmentVariable("snowflake_cloud_env"); - - var separator = Path.DirectorySeparatorChar; - - // We have to go up 3 times as the working directory path looks as follows: - // Snowflake.Data.Tests/bin/debug/{.net_version}/ - File.WriteAllText($"..{separator}..{separator}..{separator}{GetOs()}_{dotnetVersion}_{cloudEnv}_performance.csv", resultText); - } - - private static string s_connectionString => string.Format(ConnectionStringFmt, - TestConfig.protocol, - TestConfig.host, - TestConfig.port, - TestConfig.account, - TestConfig.role, - TestConfig.database, - TestConfig.warehouse, - TestConfig.user, - TestConfig.password); - - private enum SchemaAction - { - CREATE, - DROP - } - - private static void ModifySchema(string schemaName, SchemaAction schemaAction) - { - using (IDbConnection conn = new SnowflakeDbConnection()) - { - conn.ConnectionString = s_connectionString; - conn.Open(); - var dbCommand = conn.CreateCommand(); - switch (schemaAction) - { - case SchemaAction.CREATE: - dbCommand.CommandText = $"CREATE OR REPLACE SCHEMA {schemaName}"; - break; - case SchemaAction.DROP: - dbCommand.CommandText = $"DROP SCHEMA IF EXISTS {schemaName}"; - break; - default: - Assert.Fail($"Not supported action on schema: {schemaAction}"); - break; - } - try - { - dbCommand.ExecuteNonQuery(); - } - catch (InvalidOperationException e) - { - Assert.Fail($"Unable to {schemaAction.ToString().ToLower()} schema {schemaName}:\n{e.StackTrace}"); - } - } - } - - private static string GetOs() - { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - return "windows"; - } - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - return "linux"; - } - if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) - { - return "macos"; - } - - return "unknown"; - } - } - - public class TestConfig - { - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_USER", NullValueHandling = NullValueHandling.Ignore)] - internal string user { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_PASSWORD", NullValueHandling = NullValueHandling.Ignore)] - internal string password { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_ACCOUNT", NullValueHandling = NullValueHandling.Ignore)] - internal string account { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_HOST", NullValueHandling = NullValueHandling.Ignore)] - internal string host { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_PORT", NullValueHandling = NullValueHandling.Ignore)] - internal string port { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_WAREHOUSE", NullValueHandling = NullValueHandling.Ignore)] - internal string warehouse { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_DATABASE", NullValueHandling = NullValueHandling.Ignore)] - internal string database { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_SCHEMA", NullValueHandling = NullValueHandling.Ignore)] - internal string schema { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_ROLE", NullValueHandling = NullValueHandling.Ignore)] - internal string role { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_PROTOCOL", NullValueHandling = NullValueHandling.Ignore)] - internal string protocol { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_OKTA_USER", NullValueHandling = NullValueHandling.Ignore)] - internal string oktaUser { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_OKTA_PASSWORD", NullValueHandling = NullValueHandling.Ignore)] - internal string oktaPassword { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_OKTA_URL", NullValueHandling = NullValueHandling.Ignore)] - internal string oktaUrl { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_JWT_USER", NullValueHandling = NullValueHandling.Ignore)] - internal string jwtAuthUser { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_PEM_FILE", NullValueHandling = NullValueHandling.Ignore)] - internal string pemFilePath { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_P8_FILE", NullValueHandling = NullValueHandling.Ignore)] - internal string p8FilePath { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_PWD_PROTECTED_PK_FILE", NullValueHandling = NullValueHandling.Ignore)] - internal string pwdProtectedPrivateKeyFilePath { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_PK_CONTENT", NullValueHandling = NullValueHandling.Ignore)] - internal string privateKey { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_PROTECTED_PK_CONTENT", NullValueHandling = NullValueHandling.Ignore)] - internal string pwdProtectedPrivateKey { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_PK_PWD", NullValueHandling = NullValueHandling.Ignore)] - internal string privateKeyFilePwd { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_OAUTH_TOKEN", NullValueHandling = NullValueHandling.Ignore)] - internal string oauthToken { get; set; } - - [JsonProperty(PropertyName = "SNOWFLAKE_TEST_EXP_OAUTH_TOKEN", NullValueHandling = NullValueHandling.Ignore)] - internal string expOauthToken { get; set; } - - [JsonProperty(PropertyName = "PROXY_HOST", NullValueHandling = NullValueHandling.Ignore)] - internal string proxyHost { get; set; } - - [JsonProperty(PropertyName = "PROXY_PORT", NullValueHandling = NullValueHandling.Ignore)] - internal string proxyPort { get; set; } - - [JsonProperty(PropertyName = "AUTH_PROXY_HOST", NullValueHandling = NullValueHandling.Ignore)] - internal string authProxyHost { get; set; } - - [JsonProperty(PropertyName = "AUTH_PROXY_PORT", NullValueHandling = NullValueHandling.Ignore)] - internal string authProxyPort { get; set; } - - [JsonProperty(PropertyName = "AUTH_PROXY_USER", NullValueHandling = NullValueHandling.Ignore)] - internal string authProxyUser { get; set; } - - [JsonProperty(PropertyName = "AUTH_PROXY_PWD", NullValueHandling = NullValueHandling.Ignore)] - internal string authProxyPwd { get; set; } - - [JsonProperty(PropertyName = "NON_PROXY_HOSTS", NullValueHandling = NullValueHandling.Ignore)] - internal string nonProxyHosts { get; set; } - - public TestConfig() - { - protocol = "https"; - port = "443"; - } - } - - public class IgnoreOnEnvIsAttribute : Attribute, ITestAction - { - private readonly string _key; - - private readonly string[] _values; - public IgnoreOnEnvIsAttribute(string key, string[] values) - { - _key = key; - _values = values; - } - - public void BeforeTest(ITest test) - { - foreach (var value in _values) - { - if (Environment.GetEnvironmentVariable(_key) == value) - { - Assert.Ignore("Test is ignored when environment variable {0} is {1} ", _key, value); - } - } - } - - public void AfterTest(ITest test) - { - } - - public ActionTargets Targets => ActionTargets.Test | ActionTargets.Suite; - } -} +/* + * Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Collections.Generic; +using System.Data; +using System.Diagnostics; +using System.Reflection; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Log; +using Snowflake.Data.Tests.Util; + +[assembly:LevelOfParallelism(10)] + +namespace Snowflake.Data.Tests +{ + using NUnit.Framework; + using NUnit.Framework.Interfaces; + using Newtonsoft.Json; + using Newtonsoft.Json.Serialization; + + /* + * This is the base class for all tests that call blocking methods in the library - it uses MockSynchronizationContext to verify that + * there are no async deadlocks in the library + * + */ + [TestFixture] + public class SFBaseTest : SFBaseTestAsync + { + [SetUp] + public static void SetUpContext() + { + MockSynchronizationContext.SetupContext(); + } + + [TearDown] + public static void TearDownContext() + { + MockSynchronizationContext.Verify(); + } + } + + /* + * This is the base class for all tests that call async methods in the library - it does not use a special SynchronizationContext + * + */ + [TestFixture] + [FixtureLifeCycle(LifeCycle.InstancePerTestCase)] + [SetCulture("en-US")] + #if !SEQUENTIAL_TEST_RUN + [Parallelizable(ParallelScope.All)] + #endif + public class SFBaseTestAsync + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private const string ConnectionStringWithoutAuthFmt = "scheme={0};host={1};port={2};" + + "account={3};role={4};db={5};schema={6};warehouse={7}"; + private const string ConnectionStringSnowflakeAuthFmt = ";user={0};password={1};"; + protected virtual string TestName => TestContext.CurrentContext.Test.MethodName; + protected string TestNameWithWorker => TestName + TestContext.CurrentContext.WorkerId?.Replace("#", "_"); + protected string TableName => TestNameWithWorker; + + + private Stopwatch _stopwatch; + + private List _tablesToRemove; + + [SetUp] + public void BeforeTest() + { + _stopwatch = new Stopwatch(); + _stopwatch.Start(); + _tablesToRemove = new List(); + } + + [TearDown] + public void AfterTest() + { + _stopwatch.Stop(); + var testName = $"{TestContext.CurrentContext.Test.FullName}"; + + TestEnvironment.RecordTestPerformance(testName, _stopwatch.Elapsed); + RemoveTables(); + } + + private void RemoveTables() + { + if (_tablesToRemove.Count == 0) + return; + + using (var conn = new SnowflakeDbConnection(ConnectionString)) + { + conn.Open(); + + var cmd = conn.CreateCommand(); + + foreach (var table in _tablesToRemove) + { + cmd.CommandText = $"DROP TABLE IF EXISTS {table}"; + cmd.ExecuteNonQuery(); + } + } + } + + protected void CreateOrReplaceTable(IDbConnection conn, string tableName, IEnumerable columns, string additionalQueryStr = null) + { + CreateOrReplaceTable(conn, tableName, "", columns, additionalQueryStr); + } + + protected void CreateOrReplaceTable(IDbConnection conn, string tableName, string tableType, IEnumerable columns, string additionalQueryStr = null) + { + var columnsStr = string.Join(", ", columns); + var cmd = conn.CreateCommand(); + cmd.CommandText = $"CREATE OR REPLACE {tableType} TABLE {tableName}({columnsStr}) {additionalQueryStr}"; + s_logger.Debug(cmd.CommandText); + cmd.ExecuteNonQuery(); + + _tablesToRemove.Add(tableName); + } + + protected void AddTableToRemoveList(string tableName) + { + _tablesToRemove.Add(tableName); + } + + public SFBaseTestAsync() + { + testConfig = TestEnvironment.TestConfig; + } + + protected string ConnectionStringWithoutAuth => string.Format(ConnectionStringWithoutAuthFmt, + testConfig.protocol, + testConfig.host, + testConfig.port, + testConfig.account, + testConfig.role, + testConfig.database, + testConfig.schema, + testConfig.warehouse); + + protected string ConnectionString => ConnectionStringWithoutAuth + + string.Format(ConnectionStringSnowflakeAuthFmt, + testConfig.user, + testConfig.password); + + protected string ConnectionStringWithInvalidUserName => ConnectionStringWithoutAuth + + string.Format(ConnectionStringSnowflakeAuthFmt, + "unknown", + testConfig.password); + + protected TestConfig testConfig { get; } + + protected string ResolveHost() + { + return testConfig.host ?? $"{testConfig.account}.snowflakecomputing.com"; + } + } + + [SetUpFixture] + public class TestEnvironment + { + private const string ConnectionStringFmt = "scheme={0};host={1};port={2};" + + "account={3};role={4};db={5};warehouse={6};user={7};password={8};"; + + public static TestConfig TestConfig { get; private set; } + + private static Dictionary s_testPerformance; + + private static readonly object s_testPerformanceLock = new object(); + + public static void RecordTestPerformance(string name, TimeSpan time) + { + lock (s_testPerformanceLock) + { + s_testPerformance.Add(name, time); + } + } + + [OneTimeSetUp] + public void Setup() + { +#if NETFRAMEWORK + log4net.GlobalContext.Properties["framework"] = "net471"; + log4net.Config.XmlConfigurator.Configure(); + +#else + log4net.GlobalContext.Properties["framework"] = "net6.0"; + var logRepository = log4net.LogManager.GetRepository(Assembly.GetEntryAssembly()); + log4net.Config.XmlConfigurator.Configure(logRepository, new FileInfo("App.config")); +#endif + var cloud = Environment.GetEnvironmentVariable("snowflake_cloud_env"); + Assert.IsTrue(cloud == null || cloud == "AWS" || cloud == "AZURE" || cloud == "GCP", "{0} is not supported. Specify AWS, AZURE or GCP as cloud environment", cloud); + + var reader = new StreamReader("parameters.json"); + + var testConfigString = reader.ReadToEnd(); + + // Local JSON settings to avoid using system wide settings which could be different + // than the default ones + var jsonSettings = new JsonSerializerSettings + { + ContractResolver = new DefaultContractResolver + { + NamingStrategy = new DefaultNamingStrategy() + } + }; + var testConfigs = JsonConvert.DeserializeObject>(testConfigString, jsonSettings); + + if (testConfigs.TryGetValue("testconnection", out var testConnectionConfig)) + { + TestConfig = testConnectionConfig; + TestConfig.schema = TestConfig.schema + "_" + Guid.NewGuid().ToString().Replace("-", "_"); + } + else + { + Assert.Fail("Failed to load test configuration"); + } + + ModifySchema(TestConfig.schema, SchemaAction.CREATE); + } + + [OneTimeTearDown] + public void Cleanup() + { + ModifySchema(TestConfig.schema, SchemaAction.DROP); + } + + [OneTimeSetUp] + public void SetupTestPerformance() + { + s_testPerformance = new Dictionary(); + } + + [OneTimeTearDown] + public void CreateTestTimeArtifact() + { + var resultText = "test;time_in_ms\n"; + resultText += string.Join("\n", + s_testPerformance.Select(test => $"{test.Key};{Math.Round(test.Value.TotalMilliseconds,0)}")); + + var dotnetVersion = Environment.GetEnvironmentVariable("net_version"); + var cloudEnv = Environment.GetEnvironmentVariable("snowflake_cloud_env"); + + var separator = Path.DirectorySeparatorChar; + + // We have to go up 3 times as the working directory path looks as follows: + // Snowflake.Data.Tests/bin/debug/{.net_version}/ + File.WriteAllText($"..{separator}..{separator}..{separator}{GetOs()}_{dotnetVersion}_{cloudEnv}_performance.csv", resultText); + } + + private static string s_connectionString => string.Format(ConnectionStringFmt, + TestConfig.protocol, + TestConfig.host, + TestConfig.port, + TestConfig.account, + TestConfig.role, + TestConfig.database, + TestConfig.warehouse, + TestConfig.user, + TestConfig.password); + + private enum SchemaAction + { + CREATE, + DROP + } + + private static void ModifySchema(string schemaName, SchemaAction schemaAction) + { + using (IDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = s_connectionString; + conn.Open(); + var dbCommand = conn.CreateCommand(); + switch (schemaAction) + { + case SchemaAction.CREATE: + dbCommand.CommandText = $"CREATE OR REPLACE SCHEMA {schemaName}"; + break; + case SchemaAction.DROP: + dbCommand.CommandText = $"DROP SCHEMA IF EXISTS {schemaName}"; + break; + default: + Assert.Fail($"Not supported action on schema: {schemaAction}"); + break; + } + try + { + dbCommand.ExecuteNonQuery(); + } + catch (InvalidOperationException e) + { + Assert.Fail($"Unable to {schemaAction.ToString().ToLower()} schema {schemaName}:\n{e.StackTrace}"); + } + } + } + + private static string GetOs() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + return "windows"; + } + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + return "linux"; + } + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return "macos"; + } + + return "unknown"; + } + } + + public class TestConfig + { + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_USER", NullValueHandling = NullValueHandling.Ignore)] + internal string user { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_PASSWORD", NullValueHandling = NullValueHandling.Ignore)] + internal string password { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_ACCOUNT", NullValueHandling = NullValueHandling.Ignore)] + internal string account { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_HOST", NullValueHandling = NullValueHandling.Ignore)] + internal string host { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_PORT", NullValueHandling = NullValueHandling.Ignore)] + internal string port { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_WAREHOUSE", NullValueHandling = NullValueHandling.Ignore)] + internal string warehouse { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_DATABASE", NullValueHandling = NullValueHandling.Ignore)] + internal string database { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_SCHEMA", NullValueHandling = NullValueHandling.Ignore)] + internal string schema { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_ROLE", NullValueHandling = NullValueHandling.Ignore)] + internal string role { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_PROTOCOL", NullValueHandling = NullValueHandling.Ignore)] + internal string protocol { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_OKTA_USER", NullValueHandling = NullValueHandling.Ignore)] + internal string oktaUser { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_OKTA_PASSWORD", NullValueHandling = NullValueHandling.Ignore)] + internal string oktaPassword { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_OKTA_URL", NullValueHandling = NullValueHandling.Ignore)] + internal string oktaUrl { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_JWT_USER", NullValueHandling = NullValueHandling.Ignore)] + internal string jwtAuthUser { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_PEM_FILE", NullValueHandling = NullValueHandling.Ignore)] + internal string pemFilePath { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_P8_FILE", NullValueHandling = NullValueHandling.Ignore)] + internal string p8FilePath { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_PWD_PROTECTED_PK_FILE", NullValueHandling = NullValueHandling.Ignore)] + internal string pwdProtectedPrivateKeyFilePath { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_PK_CONTENT", NullValueHandling = NullValueHandling.Ignore)] + internal string privateKey { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_PROTECTED_PK_CONTENT", NullValueHandling = NullValueHandling.Ignore)] + internal string pwdProtectedPrivateKey { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_PK_PWD", NullValueHandling = NullValueHandling.Ignore)] + internal string privateKeyFilePwd { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_OAUTH_TOKEN", NullValueHandling = NullValueHandling.Ignore)] + internal string oauthToken { get; set; } + + [JsonProperty(PropertyName = "SNOWFLAKE_TEST_EXP_OAUTH_TOKEN", NullValueHandling = NullValueHandling.Ignore)] + internal string expOauthToken { get; set; } + + [JsonProperty(PropertyName = "PROXY_HOST", NullValueHandling = NullValueHandling.Ignore)] + internal string proxyHost { get; set; } + + [JsonProperty(PropertyName = "PROXY_PORT", NullValueHandling = NullValueHandling.Ignore)] + internal string proxyPort { get; set; } + + [JsonProperty(PropertyName = "AUTH_PROXY_HOST", NullValueHandling = NullValueHandling.Ignore)] + internal string authProxyHost { get; set; } + + [JsonProperty(PropertyName = "AUTH_PROXY_PORT", NullValueHandling = NullValueHandling.Ignore)] + internal string authProxyPort { get; set; } + + [JsonProperty(PropertyName = "AUTH_PROXY_USER", NullValueHandling = NullValueHandling.Ignore)] + internal string authProxyUser { get; set; } + + [JsonProperty(PropertyName = "AUTH_PROXY_PWD", NullValueHandling = NullValueHandling.Ignore)] + internal string authProxyPwd { get; set; } + + [JsonProperty(PropertyName = "NON_PROXY_HOSTS", NullValueHandling = NullValueHandling.Ignore)] + internal string nonProxyHosts { get; set; } + + public TestConfig() + { + protocol = "https"; + port = "443"; + } + } + + public class IgnoreOnEnvIsAttribute : Attribute, ITestAction + { + private readonly string _key; + + private readonly string[] _values; + public IgnoreOnEnvIsAttribute(string key, string[] values) + { + _key = key; + _values = values; + } + + public void BeforeTest(ITest test) + { + foreach (var value in _values) + { + if (Environment.GetEnvironmentVariable(_key) == value) + { + Assert.Ignore("Test is ignored when environment variable {0} is {1} ", _key, value); + } + } + } + + public void AfterTest(ITest test) + { + } + + public ActionTargets Targets => ActionTargets.Test | ActionTargets.Suite; + } +} diff --git a/Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs b/Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs index da3baf531..733c707a0 100644 --- a/Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFS3ClientTest.cs @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. */ diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs index 309570ca6..702ddcf81 100644 --- a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs @@ -1,4 +1,4 @@ -/* +/* * Copyright (c) 2019 Snowflake Computing Inc. All rights reserved. */ diff --git a/Snowflake.Data/Core/HttpUtil.cs b/Snowflake.Data/Core/HttpUtil.cs index 531e76fd7..a26ae164c 100755 --- a/Snowflake.Data/Core/HttpUtil.cs +++ b/Snowflake.Data/Core/HttpUtil.cs @@ -1,563 +1,563 @@ -/* - * Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. - */ - -using System.Threading.Tasks; -using System.Net.Http; -using System.Net; -using System; -using System.Threading; -using System.Collections.Generic; -using Snowflake.Data.Log; -using System.Collections.Specialized; -using System.Web; -using System.Security.Authentication; -using System.Linq; -using Snowflake.Data.Core.Authenticator; - -namespace Snowflake.Data.Core -{ - public class HttpClientConfig - { - public HttpClientConfig( - bool crlCheckEnabled, - string proxyHost, - string proxyPort, - string proxyUser, - string proxyPassword, - string noProxyList, - bool disableRetry, - bool forceRetryOn404, - int maxHttpRetries, - bool includeRetryReason = true) - { - CrlCheckEnabled = crlCheckEnabled; - ProxyHost = proxyHost; - ProxyPort = proxyPort; - ProxyUser = proxyUser; - ProxyPassword = proxyPassword; - NoProxyList = noProxyList; - DisableRetry = disableRetry; - ForceRetryOn404 = forceRetryOn404; - MaxHttpRetries = maxHttpRetries; - IncludeRetryReason = includeRetryReason; - - ConfKey = string.Join(";", - new string[] { - crlCheckEnabled.ToString(), - proxyHost, - proxyPort, - proxyUser, - proxyPassword, - noProxyList, - disableRetry.ToString(), - forceRetryOn404.ToString(), - maxHttpRetries.ToString(), - includeRetryReason.ToString()}); - } - - public readonly bool CrlCheckEnabled; - public readonly string ProxyHost; - public readonly string ProxyPort; - public readonly string ProxyUser; - public readonly string ProxyPassword; - public readonly string NoProxyList; - public readonly bool DisableRetry; - public readonly bool ForceRetryOn404; - public readonly int MaxHttpRetries; - public readonly bool IncludeRetryReason; - - // Key used to identify the HttpClient with the configuration matching the settings - public readonly string ConfKey; - } - - public sealed class HttpUtil - { - static internal readonly int MAX_BACKOFF = 16; - private static readonly int s_baseBackOffTime = 1; - private static readonly int s_exponentialFactor = 2; - private static readonly SFLogger logger = SFLoggerFactory.GetLogger(); - - private static readonly List s_supportedEndpointsForRetryPolicy = new List - { - RestPath.SF_LOGIN_PATH, - RestPath.SF_AUTHENTICATOR_REQUEST_PATH, - RestPath.SF_TOKEN_REQUEST_PATH - }; - - private HttpUtil() - { - // This value is used by AWS SDK and can cause deadlock, - // so we need to increase the default value of 2 - // See: https://github.com/aws/aws-sdk-net/issues/152 - ServicePointManager.DefaultConnectionLimit = 50; - } - - internal static HttpUtil Instance { get; } = new HttpUtil(); - - private readonly object httpClientProviderLock = new object(); - - private Dictionary _HttpClients = new Dictionary(); - - internal HttpClient GetHttpClient(HttpClientConfig config) - { - lock (httpClientProviderLock) - { - return RegisterNewHttpClientIfNecessary(config); - } - } - - - private HttpClient RegisterNewHttpClientIfNecessary(HttpClientConfig config) - { - string name = config.ConfKey; - if (!_HttpClients.ContainsKey(name)) - { - logger.Debug("Http client not registered. Adding."); - - var httpClient = new HttpClient( - new RetryHandler(SetupCustomHttpHandler(config), config.DisableRetry, config.ForceRetryOn404, config.MaxHttpRetries, config.IncludeRetryReason)) - { - Timeout = Timeout.InfiniteTimeSpan - }; - - // Add the new client key to the list - _HttpClients.Add(name, httpClient); - } - - return _HttpClients[name]; - } - - internal HttpMessageHandler SetupCustomHttpHandler(HttpClientConfig config) - { - HttpMessageHandler httpHandler; - try - { - httpHandler = new HttpClientHandler() - { - // Verify no certificates have been revoked - CheckCertificateRevocationList = config.CrlCheckEnabled, - // Enforce tls v1.2 - SslProtocols = SslProtocols.Tls12, - AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate, - UseCookies = false, // Disable cookies - UseProxy = false - }; - } - // special logic for .NET framework 4.7.1 that - // CheckCertificateRevocationList and SslProtocols are not supported - catch (PlatformNotSupportedException) - { - httpHandler = new HttpClientHandler() - { - AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate, - UseCookies = false, // Disable cookies - UseProxy = false - }; - } - - // Add a proxy if necessary - if (null != config.ProxyHost) - { - // Proxy needed - WebProxy proxy = new WebProxy(config.ProxyHost, int.Parse(config.ProxyPort)); - - // Add credential if provided - if (!String.IsNullOrEmpty(config.ProxyUser)) - { - ICredentials credentials = new NetworkCredential(config.ProxyUser, config.ProxyPassword); - proxy.Credentials = credentials; - } - - // Add bypasslist if provided - if (!String.IsNullOrEmpty(config.NoProxyList)) - { - string[] bypassList = config.NoProxyList.Split( - new char[] { '|' }, - StringSplitOptions.RemoveEmptyEntries); - // Convert simplified syntax to standard regular expression syntax - string entry = null; - for (int i = 0; i < bypassList.Length; i++) - { - // Get the original entry - entry = bypassList[i].Trim(); - // . -> [.] because . means any char - entry = entry.Replace(".", "[.]"); - // * -> .* because * is a quantifier and need a char or group to apply to - entry = entry.Replace("*", ".*"); - - entry = entry.StartsWith("^") ? entry : $"^{entry}"; - - entry = entry.EndsWith("$") ? entry : $"{entry}$"; - - // Replace with the valid entry syntax - bypassList[i] = entry; - - } - proxy.BypassList = bypassList; - } - - HttpClientHandler httpHandlerWithProxy = (HttpClientHandler)httpHandler; - httpHandlerWithProxy.UseProxy = true; - httpHandlerWithProxy.Proxy = proxy; - return httpHandlerWithProxy; - } - return httpHandler; - } - - /// - /// UriUpdater would update the uri in each retry. During construction, it would take in an uri that would later - /// be updated in each retry and figure out the rules to apply when updating. - /// - internal class UriUpdater - { - /// - /// IRule defines how the queryParams of a uri should be updated in each retry - /// - interface IRule - { - void apply(NameValueCollection queryParams); - } - - /// - /// RetryCountRule would update the retryCount parameter - /// - class RetryCountRule : IRule - { - int retryCount; - - internal RetryCountRule() - { - retryCount = 1; - } - - void IRule.apply(NameValueCollection queryParams) - { - if (retryCount == 1) - { - queryParams.Add(RestParams.SF_QUERY_RETRY_COUNT, retryCount.ToString()); - } - else - { - queryParams.Set(RestParams.SF_QUERY_RETRY_COUNT, retryCount.ToString()); - } - retryCount++; - } - } - - /// - /// RequestUUIDRule would update the request_guid query with a new RequestGUID - /// - class RequestUUIDRule : IRule - { - void IRule.apply(NameValueCollection queryParams) - { - queryParams.Set(RestParams.SF_QUERY_REQUEST_GUID, Guid.NewGuid().ToString()); - } - } - - /// - /// RetryReasonRule would update the retryReason parameter - /// - class RetryReasonRule : IRule - { - int retryReason; - - internal RetryReasonRule() - { - retryReason = 0; - } - - public void SetRetryReason(int reason) - { - retryReason = reason; - } - - void IRule.apply(NameValueCollection queryParams) - { - queryParams.Set(RestParams.SF_QUERY_RETRY_REASON, retryReason.ToString()); - } - } - - UriBuilder uriBuilder; - List rules; - internal UriUpdater(Uri uri, bool includeRetryReason = true) - { - uriBuilder = new UriBuilder(uri); - rules = new List(); - - if (uri.AbsolutePath.StartsWith(RestPath.SF_QUERY_PATH)) - { - rules.Add(new RetryCountRule()); - if (includeRetryReason) - { - rules.Add(new RetryReasonRule()); - } - } - - if (uri.Query != null && uri.Query.Contains(RestParams.SF_QUERY_REQUEST_GUID)) - { - rules.Add(new RequestUUIDRule()); - } - } - - internal Uri Update(int retryReason = 0) - { - // Optimization to bypass parsing if there is no rules at all. - if (rules.Count == 0) - { - return uriBuilder.Uri; - } - - var queryParams = HttpUtility.ParseQueryString(uriBuilder.Query); - - foreach (IRule rule in rules) - { - if (rule is RetryReasonRule) - { - ((RetryReasonRule)rule).SetRetryReason(retryReason); - } - rule.apply(queryParams); - } - - uriBuilder.Query = queryParams.ToString(); - - return uriBuilder.Uri; - } - } - private class RetryHandler : DelegatingHandler - { - static private SFLogger logger = SFLoggerFactory.GetLogger(); - - private bool disableRetry; - private bool forceRetryOn404; - private int maxRetryCount; - private bool includeRetryReason; - - internal RetryHandler(HttpMessageHandler innerHandler, bool disableRetry, bool forceRetryOn404, int maxRetryCount, bool includeRetryReason) : base(innerHandler) - { - this.disableRetry = disableRetry; - this.forceRetryOn404 = forceRetryOn404; - this.maxRetryCount = maxRetryCount; - this.includeRetryReason = includeRetryReason; - } - - protected override async Task SendAsync(HttpRequestMessage requestMessage, - CancellationToken cancellationToken) - { - HttpResponseMessage response = null; - string absolutePath = requestMessage.RequestUri.AbsolutePath; - bool isLoginRequest = IsLoginEndpoint(absolutePath); - bool isOktaSSORequest = IsOktaSSORequest(requestMessage.RequestUri.Host, absolutePath); - int backOffInSec = s_baseBackOffTime; - int totalRetryTime = 0; - - ServicePoint p = ServicePointManager.FindServicePoint(requestMessage.RequestUri); - p.Expect100Continue = false; // Saves about 100 ms per request - p.UseNagleAlgorithm = false; // Saves about 200 ms per request - p.ConnectionLimit = 20; // Default value is 2, we need more connections for performing multiple parallel queries - - TimeSpan httpTimeout = (TimeSpan)requestMessage.Properties[SFRestRequest.HTTP_REQUEST_TIMEOUT_KEY]; - TimeSpan restTimeout = (TimeSpan)requestMessage.Properties[SFRestRequest.REST_REQUEST_TIMEOUT_KEY]; - - if (logger.IsDebugEnabled()) - { - logger.Debug("Http request timeout : " + httpTimeout); - logger.Debug("Rest request timeout : " + restTimeout); - } - - CancellationTokenSource childCts = null; - - UriUpdater updater = new UriUpdater(requestMessage.RequestUri, includeRetryReason); - int retryCount = 0; - - while (true) - { - - try - { - childCts = null; - - if (!httpTimeout.Equals(Timeout.InfiniteTimeSpan)) - { - childCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - if (httpTimeout.Ticks == 0) - childCts.Cancel(); - else - childCts.CancelAfter(httpTimeout); - } - response = await base.SendAsync(requestMessage, childCts == null ? - cancellationToken : childCts.Token).ConfigureAwait(false); - } - catch (Exception e) - { - if (cancellationToken.IsCancellationRequested) - { - logger.Debug("SF rest request timeout or explicit cancel called."); - cancellationToken.ThrowIfCancellationRequested(); - } - else if (childCts != null && childCts.Token.IsCancellationRequested) - { - logger.Warn("Http request timeout. Retry the request"); - totalRetryTime += (int)httpTimeout.TotalSeconds; - } - else - { - //TODO: Should probably check to see if the error is recoverable or transient. - logger.Warn("Error occurred during request, retrying...", e); - } - } - - if (childCts != null) - { - childCts.Dispose(); - } - - int errorReason = 0; - - if (response != null) - { - if (isOktaSSORequest) - { - response.Content.Headers.Add(OktaAuthenticator.RetryCountHeader, retryCount.ToString()); - response.Content.Headers.Add(OktaAuthenticator.TimeoutElapsedHeader, totalRetryTime.ToString()); - } - - if (response.IsSuccessStatusCode) - { - logger.Debug($"Success Response: StatusCode: {(int)response.StatusCode}, ReasonPhrase: '{response.ReasonPhrase}'"); - return response; - } - else - { - logger.Debug($"Failed Response: StatusCode: {(int)response.StatusCode}, ReasonPhrase: '{response.ReasonPhrase}'"); - bool isRetryable = isRetryableHTTPCode((int)response.StatusCode, forceRetryOn404); - - if (!isRetryable || disableRetry) - { - // No need to keep retrying, stop here - return response; - } - } - errorReason = (int)response.StatusCode; - } - else - { - logger.Info("Response returned was null."); - } - - retryCount++; - if ((maxRetryCount > 0) && (retryCount > maxRetryCount)) - { - logger.Debug($"stop retry as maxHttpRetries {maxRetryCount} reached"); - if (response != null) - { - return response; - } - throw new OperationCanceledException($"http request failed and max retry {maxRetryCount} reached"); - } - - // Disposing of the response if not null now that we don't need it anymore - response?.Dispose(); - - requestMessage.RequestUri = updater.Update(errorReason); - - logger.Debug($"Sleep {backOffInSec} seconds and then retry the request, retryCount: {retryCount}"); - - await Task.Delay(TimeSpan.FromSeconds(backOffInSec), cancellationToken).ConfigureAwait(false); - totalRetryTime += backOffInSec; - - var jitter = GetJitter(backOffInSec); - - // Set backoff time - if (isLoginRequest) - { - // Choose between previous sleep time and new base sleep time for login requests - backOffInSec = (int)ChooseRandom( - backOffInSec + jitter, - Math.Pow(s_exponentialFactor, retryCount) + jitter); - } - else if (backOffInSec < MAX_BACKOFF) - { - // Multiply sleep by 2 for non-login requests - backOffInSec *= 2; - } - - if ((restTimeout.TotalSeconds > 0) && (totalRetryTime + backOffInSec > restTimeout.TotalSeconds)) - { - // No need to wait more than necessary if it can be avoided. - // If the rest timeout will be reached before the next back-off, - // then use the remaining connection timeout - backOffInSec = Math.Min(backOffInSec, (int)restTimeout.TotalSeconds - totalRetryTime); - } - } - } - } - - /// - /// Check whether or not the error is retryable or not. - /// - /// The http status code. - /// True if the request should be retried, false otherwise. - static public bool isRetryableHTTPCode(int statusCode, bool forceRetryOn404) - { - if (forceRetryOn404 && statusCode == 404) - return true; - return (500 <= statusCode) && (statusCode < 600) || - // Forbidden - (statusCode == 403) || - // Request timeout - (statusCode == 408) || - // Too many requests - (statusCode == 429); - } - - /// - /// Get the jitter amount based on current wait time. - /// - /// The current retry backoff time. - /// The new jitter amount. - static internal double GetJitter(double curWaitTime) - { - double multiplicationFactor = ChooseRandom(-1, 1); - double jitterAmount = 0.5 * curWaitTime * multiplicationFactor; - return jitterAmount; - } - - /// - /// Randomly generates a number between a given range. - /// - /// The min range (inclusive). - /// The max range (inclusive). - /// The random number. - static double ChooseRandom(double min, double max) - { - var next = new Random().NextDouble(); - - return min + (next * (max - min)); - } - - /// - /// Checks if the endpoint is a login request. - /// - /// The endpoint to check. - /// True if the endpoint is a login request, false otherwise. - static internal bool IsLoginEndpoint(string endpoint) - { - return null != s_supportedEndpointsForRetryPolicy.FirstOrDefault(ep => endpoint.Equals(ep)); - } - - /// - /// Checks if request is for Okta and an SSO SAML endpoint. - /// - /// The host url to check. - /// The endpoint to check. - /// True if the endpoint is an okta sso saml request, false otherwise. - static internal bool IsOktaSSORequest(string host, string endpoint) - { - return host.Contains(OktaUrl.DOMAIN) && endpoint.Contains(OktaUrl.SSO_SAML_PATH); - } - } -} - - +/* + * Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. + */ + +using System.Threading.Tasks; +using System.Net.Http; +using System.Net; +using System; +using System.Threading; +using System.Collections.Generic; +using Snowflake.Data.Log; +using System.Collections.Specialized; +using System.Web; +using System.Security.Authentication; +using System.Linq; +using Snowflake.Data.Core.Authenticator; + +namespace Snowflake.Data.Core +{ + public class HttpClientConfig + { + public HttpClientConfig( + bool crlCheckEnabled, + string proxyHost, + string proxyPort, + string proxyUser, + string proxyPassword, + string noProxyList, + bool disableRetry, + bool forceRetryOn404, + int maxHttpRetries, + bool includeRetryReason = true) + { + CrlCheckEnabled = crlCheckEnabled; + ProxyHost = proxyHost; + ProxyPort = proxyPort; + ProxyUser = proxyUser; + ProxyPassword = proxyPassword; + NoProxyList = noProxyList; + DisableRetry = disableRetry; + ForceRetryOn404 = forceRetryOn404; + MaxHttpRetries = maxHttpRetries; + IncludeRetryReason = includeRetryReason; + + ConfKey = string.Join(";", + new string[] { + crlCheckEnabled.ToString(), + proxyHost, + proxyPort, + proxyUser, + proxyPassword, + noProxyList, + disableRetry.ToString(), + forceRetryOn404.ToString(), + maxHttpRetries.ToString(), + includeRetryReason.ToString()}); + } + + public readonly bool CrlCheckEnabled; + public readonly string ProxyHost; + public readonly string ProxyPort; + public readonly string ProxyUser; + public readonly string ProxyPassword; + public readonly string NoProxyList; + public readonly bool DisableRetry; + public readonly bool ForceRetryOn404; + public readonly int MaxHttpRetries; + public readonly bool IncludeRetryReason; + + // Key used to identify the HttpClient with the configuration matching the settings + public readonly string ConfKey; + } + + public sealed class HttpUtil + { + static internal readonly int MAX_BACKOFF = 16; + private static readonly int s_baseBackOffTime = 1; + private static readonly int s_exponentialFactor = 2; + private static readonly SFLogger logger = SFLoggerFactory.GetLogger(); + + private static readonly List s_supportedEndpointsForRetryPolicy = new List + { + RestPath.SF_LOGIN_PATH, + RestPath.SF_AUTHENTICATOR_REQUEST_PATH, + RestPath.SF_TOKEN_REQUEST_PATH + }; + + private HttpUtil() + { + // This value is used by AWS SDK and can cause deadlock, + // so we need to increase the default value of 2 + // See: https://github.com/aws/aws-sdk-net/issues/152 + ServicePointManager.DefaultConnectionLimit = 50; + } + + internal static HttpUtil Instance { get; } = new HttpUtil(); + + private readonly object httpClientProviderLock = new object(); + + private Dictionary _HttpClients = new Dictionary(); + + internal HttpClient GetHttpClient(HttpClientConfig config) + { + lock (httpClientProviderLock) + { + return RegisterNewHttpClientIfNecessary(config); + } + } + + + private HttpClient RegisterNewHttpClientIfNecessary(HttpClientConfig config) + { + string name = config.ConfKey; + if (!_HttpClients.ContainsKey(name)) + { + logger.Debug("Http client not registered. Adding."); + + var httpClient = new HttpClient( + new RetryHandler(SetupCustomHttpHandler(config), config.DisableRetry, config.ForceRetryOn404, config.MaxHttpRetries, config.IncludeRetryReason)) + { + Timeout = Timeout.InfiniteTimeSpan + }; + + // Add the new client key to the list + _HttpClients.Add(name, httpClient); + } + + return _HttpClients[name]; + } + + internal HttpMessageHandler SetupCustomHttpHandler(HttpClientConfig config) + { + HttpMessageHandler httpHandler; + try + { + httpHandler = new HttpClientHandler() + { + // Verify no certificates have been revoked + CheckCertificateRevocationList = config.CrlCheckEnabled, + // Enforce tls v1.2 + SslProtocols = SslProtocols.Tls12, + AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate, + UseCookies = false, // Disable cookies + UseProxy = false + }; + } + // special logic for .NET framework 4.7.1 that + // CheckCertificateRevocationList and SslProtocols are not supported + catch (PlatformNotSupportedException) + { + httpHandler = new HttpClientHandler() + { + AutomaticDecompression = DecompressionMethods.GZip | DecompressionMethods.Deflate, + UseCookies = false, // Disable cookies + UseProxy = false + }; + } + + // Add a proxy if necessary + if (null != config.ProxyHost) + { + // Proxy needed + WebProxy proxy = new WebProxy(config.ProxyHost, int.Parse(config.ProxyPort)); + + // Add credential if provided + if (!String.IsNullOrEmpty(config.ProxyUser)) + { + ICredentials credentials = new NetworkCredential(config.ProxyUser, config.ProxyPassword); + proxy.Credentials = credentials; + } + + // Add bypasslist if provided + if (!String.IsNullOrEmpty(config.NoProxyList)) + { + string[] bypassList = config.NoProxyList.Split( + new char[] { '|' }, + StringSplitOptions.RemoveEmptyEntries); + // Convert simplified syntax to standard regular expression syntax + string entry = null; + for (int i = 0; i < bypassList.Length; i++) + { + // Get the original entry + entry = bypassList[i].Trim(); + // . -> [.] because . means any char + entry = entry.Replace(".", "[.]"); + // * -> .* because * is a quantifier and need a char or group to apply to + entry = entry.Replace("*", ".*"); + + entry = entry.StartsWith("^") ? entry : $"^{entry}"; + + entry = entry.EndsWith("$") ? entry : $"{entry}$"; + + // Replace with the valid entry syntax + bypassList[i] = entry; + + } + proxy.BypassList = bypassList; + } + + HttpClientHandler httpHandlerWithProxy = (HttpClientHandler)httpHandler; + httpHandlerWithProxy.UseProxy = true; + httpHandlerWithProxy.Proxy = proxy; + return httpHandlerWithProxy; + } + return httpHandler; + } + + /// + /// UriUpdater would update the uri in each retry. During construction, it would take in an uri that would later + /// be updated in each retry and figure out the rules to apply when updating. + /// + internal class UriUpdater + { + /// + /// IRule defines how the queryParams of a uri should be updated in each retry + /// + interface IRule + { + void apply(NameValueCollection queryParams); + } + + /// + /// RetryCountRule would update the retryCount parameter + /// + class RetryCountRule : IRule + { + int retryCount; + + internal RetryCountRule() + { + retryCount = 1; + } + + void IRule.apply(NameValueCollection queryParams) + { + if (retryCount == 1) + { + queryParams.Add(RestParams.SF_QUERY_RETRY_COUNT, retryCount.ToString()); + } + else + { + queryParams.Set(RestParams.SF_QUERY_RETRY_COUNT, retryCount.ToString()); + } + retryCount++; + } + } + + /// + /// RequestUUIDRule would update the request_guid query with a new RequestGUID + /// + class RequestUUIDRule : IRule + { + void IRule.apply(NameValueCollection queryParams) + { + queryParams.Set(RestParams.SF_QUERY_REQUEST_GUID, Guid.NewGuid().ToString()); + } + } + + /// + /// RetryReasonRule would update the retryReason parameter + /// + class RetryReasonRule : IRule + { + int retryReason; + + internal RetryReasonRule() + { + retryReason = 0; + } + + public void SetRetryReason(int reason) + { + retryReason = reason; + } + + void IRule.apply(NameValueCollection queryParams) + { + queryParams.Set(RestParams.SF_QUERY_RETRY_REASON, retryReason.ToString()); + } + } + + UriBuilder uriBuilder; + List rules; + internal UriUpdater(Uri uri, bool includeRetryReason = true) + { + uriBuilder = new UriBuilder(uri); + rules = new List(); + + if (uri.AbsolutePath.StartsWith(RestPath.SF_QUERY_PATH)) + { + rules.Add(new RetryCountRule()); + if (includeRetryReason) + { + rules.Add(new RetryReasonRule()); + } + } + + if (uri.Query != null && uri.Query.Contains(RestParams.SF_QUERY_REQUEST_GUID)) + { + rules.Add(new RequestUUIDRule()); + } + } + + internal Uri Update(int retryReason = 0) + { + // Optimization to bypass parsing if there is no rules at all. + if (rules.Count == 0) + { + return uriBuilder.Uri; + } + + var queryParams = HttpUtility.ParseQueryString(uriBuilder.Query); + + foreach (IRule rule in rules) + { + if (rule is RetryReasonRule) + { + ((RetryReasonRule)rule).SetRetryReason(retryReason); + } + rule.apply(queryParams); + } + + uriBuilder.Query = queryParams.ToString(); + + return uriBuilder.Uri; + } + } + private class RetryHandler : DelegatingHandler + { + static private SFLogger logger = SFLoggerFactory.GetLogger(); + + private bool disableRetry; + private bool forceRetryOn404; + private int maxRetryCount; + private bool includeRetryReason; + + internal RetryHandler(HttpMessageHandler innerHandler, bool disableRetry, bool forceRetryOn404, int maxRetryCount, bool includeRetryReason) : base(innerHandler) + { + this.disableRetry = disableRetry; + this.forceRetryOn404 = forceRetryOn404; + this.maxRetryCount = maxRetryCount; + this.includeRetryReason = includeRetryReason; + } + + protected override async Task SendAsync(HttpRequestMessage requestMessage, + CancellationToken cancellationToken) + { + HttpResponseMessage response = null; + string absolutePath = requestMessage.RequestUri.AbsolutePath; + bool isLoginRequest = IsLoginEndpoint(absolutePath); + bool isOktaSSORequest = IsOktaSSORequest(requestMessage.RequestUri.Host, absolutePath); + int backOffInSec = s_baseBackOffTime; + int totalRetryTime = 0; + + ServicePoint p = ServicePointManager.FindServicePoint(requestMessage.RequestUri); + p.Expect100Continue = false; // Saves about 100 ms per request + p.UseNagleAlgorithm = false; // Saves about 200 ms per request + p.ConnectionLimit = 20; // Default value is 2, we need more connections for performing multiple parallel queries + + TimeSpan httpTimeout = (TimeSpan)requestMessage.Properties[SFRestRequest.HTTP_REQUEST_TIMEOUT_KEY]; + TimeSpan restTimeout = (TimeSpan)requestMessage.Properties[SFRestRequest.REST_REQUEST_TIMEOUT_KEY]; + + if (logger.IsDebugEnabled()) + { + logger.Debug("Http request timeout : " + httpTimeout); + logger.Debug("Rest request timeout : " + restTimeout); + } + + CancellationTokenSource childCts = null; + + UriUpdater updater = new UriUpdater(requestMessage.RequestUri, includeRetryReason); + int retryCount = 0; + + while (true) + { + + try + { + childCts = null; + + if (!httpTimeout.Equals(Timeout.InfiniteTimeSpan)) + { + childCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + if (httpTimeout.Ticks == 0) + childCts.Cancel(); + else + childCts.CancelAfter(httpTimeout); + } + response = await base.SendAsync(requestMessage, childCts == null ? + cancellationToken : childCts.Token).ConfigureAwait(false); + } + catch (Exception e) + { + if (cancellationToken.IsCancellationRequested) + { + logger.Debug("SF rest request timeout or explicit cancel called."); + cancellationToken.ThrowIfCancellationRequested(); + } + else if (childCts != null && childCts.Token.IsCancellationRequested) + { + logger.Warn("Http request timeout. Retry the request"); + totalRetryTime += (int)httpTimeout.TotalSeconds; + } + else + { + //TODO: Should probably check to see if the error is recoverable or transient. + logger.Warn("Error occurred during request, retrying...", e); + } + } + + if (childCts != null) + { + childCts.Dispose(); + } + + int errorReason = 0; + + if (response != null) + { + if (isOktaSSORequest) + { + response.Content.Headers.Add(OktaAuthenticator.RetryCountHeader, retryCount.ToString()); + response.Content.Headers.Add(OktaAuthenticator.TimeoutElapsedHeader, totalRetryTime.ToString()); + } + + if (response.IsSuccessStatusCode) + { + logger.Debug($"Success Response: StatusCode: {(int)response.StatusCode}, ReasonPhrase: '{response.ReasonPhrase}'"); + return response; + } + else + { + logger.Debug($"Failed Response: StatusCode: {(int)response.StatusCode}, ReasonPhrase: '{response.ReasonPhrase}'"); + bool isRetryable = isRetryableHTTPCode((int)response.StatusCode, forceRetryOn404); + + if (!isRetryable || disableRetry) + { + // No need to keep retrying, stop here + return response; + } + } + errorReason = (int)response.StatusCode; + } + else + { + logger.Info("Response returned was null."); + } + + retryCount++; + if ((maxRetryCount > 0) && (retryCount > maxRetryCount)) + { + logger.Debug($"stop retry as maxHttpRetries {maxRetryCount} reached"); + if (response != null) + { + return response; + } + throw new OperationCanceledException($"http request failed and max retry {maxRetryCount} reached"); + } + + // Disposing of the response if not null now that we don't need it anymore + response?.Dispose(); + + requestMessage.RequestUri = updater.Update(errorReason); + + logger.Debug($"Sleep {backOffInSec} seconds and then retry the request, retryCount: {retryCount}"); + + await Task.Delay(TimeSpan.FromSeconds(backOffInSec), cancellationToken).ConfigureAwait(false); + totalRetryTime += backOffInSec; + + var jitter = GetJitter(backOffInSec); + + // Set backoff time + if (isLoginRequest) + { + // Choose between previous sleep time and new base sleep time for login requests + backOffInSec = (int)ChooseRandom( + backOffInSec + jitter, + Math.Pow(s_exponentialFactor, retryCount) + jitter); + } + else if (backOffInSec < MAX_BACKOFF) + { + // Multiply sleep by 2 for non-login requests + backOffInSec *= 2; + } + + if ((restTimeout.TotalSeconds > 0) && (totalRetryTime + backOffInSec > restTimeout.TotalSeconds)) + { + // No need to wait more than necessary if it can be avoided. + // If the rest timeout will be reached before the next back-off, + // then use the remaining connection timeout + backOffInSec = Math.Min(backOffInSec, (int)restTimeout.TotalSeconds - totalRetryTime); + } + } + } + } + + /// + /// Check whether or not the error is retryable or not. + /// + /// The http status code. + /// True if the request should be retried, false otherwise. + static public bool isRetryableHTTPCode(int statusCode, bool forceRetryOn404) + { + if (forceRetryOn404 && statusCode == 404) + return true; + return (500 <= statusCode) && (statusCode < 600) || + // Forbidden + (statusCode == 403) || + // Request timeout + (statusCode == 408) || + // Too many requests + (statusCode == 429); + } + + /// + /// Get the jitter amount based on current wait time. + /// + /// The current retry backoff time. + /// The new jitter amount. + static internal double GetJitter(double curWaitTime) + { + double multiplicationFactor = ChooseRandom(-1, 1); + double jitterAmount = 0.5 * curWaitTime * multiplicationFactor; + return jitterAmount; + } + + /// + /// Randomly generates a number between a given range. + /// + /// The min range (inclusive). + /// The max range (inclusive). + /// The random number. + static double ChooseRandom(double min, double max) + { + var next = new Random().NextDouble(); + + return min + (next * (max - min)); + } + + /// + /// Checks if the endpoint is a login request. + /// + /// The endpoint to check. + /// True if the endpoint is a login request, false otherwise. + static internal bool IsLoginEndpoint(string endpoint) + { + return null != s_supportedEndpointsForRetryPolicy.FirstOrDefault(ep => endpoint.Equals(ep)); + } + + /// + /// Checks if request is for Okta and an SSO SAML endpoint. + /// + /// The host url to check. + /// The endpoint to check. + /// True if the endpoint is an okta sso saml request, false otherwise. + static internal bool IsOktaSSORequest(string host, string endpoint) + { + return host.Contains(OktaUrl.DOMAIN) && endpoint.Contains(OktaUrl.SSO_SAML_PATH); + } + } +} + + diff --git a/Snowflake.Data/Core/Session/SFSession.cs b/Snowflake.Data/Core/Session/SFSession.cs index 3b0c80f8d..2dd594f54 100755 --- a/Snowflake.Data/Core/Session/SFSession.cs +++ b/Snowflake.Data/Core/Session/SFSession.cs @@ -1,585 +1,585 @@ -/* - * Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. - */ - -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Security; -using System.Web; -using Snowflake.Data.Log; -using Snowflake.Data.Client; -using Snowflake.Data.Core.Authenticator; -using System.Threading; -using System.Threading.Tasks; -using System.Net.Http; -using System.Text.RegularExpressions; -using Snowflake.Data.Configuration; - -namespace Snowflake.Data.Core -{ - public class SFSession - { - public const int SF_SESSION_EXPIRED_CODE = 390112; - - private static readonly SFLogger logger = SFLoggerFactory.GetLogger(); - - private static readonly Regex APPLICATION_REGEX = new Regex(@"^[A-Za-z]([A-Za-z0-9.\-_]){1,50}$"); - - private const string SF_AUTHORIZATION_BASIC = "Basic"; - - private const string SF_AUTHORIZATION_SNOWFLAKE_FMT = "Snowflake Token=\"{0}\""; - - private const int _defaultQueryContextCacheSize = 5; - - internal string sessionId; - - internal string sessionToken; - - internal string masterToken; - - internal IRestRequester restRequester { get; private set; } - - private IAuthenticator authenticator; - - internal SFSessionProperties properties; - - internal string database; - - internal string schema; - - internal string serverVersion; - - internal TimeSpan connectionTimeout; - - internal bool InsecureMode; - - internal bool isHeartBeatEnabled; - - private HttpClient _HttpClient; - - private string arrayBindStage = null; - private int arrayBindStageThreshold = 0; - internal int masterValidityInSeconds = 0; - - internal static readonly SFSessionHttpClientProperties.Extractor propertiesExtractor = new SFSessionHttpClientProperties.Extractor( - new SFSessionHttpClientProxyProperties.Extractor()); - - private readonly EasyLoggingStarter _easyLoggingStarter = EasyLoggingStarter.Instance; - - private long _startTime = 0; - internal string connStr = null; - - private QueryContextCache _queryContextCache = new QueryContextCache(_defaultQueryContextCacheSize); - - private int _queryContextCacheSize = _defaultQueryContextCacheSize; - - private bool _disableQueryContextCache = false; - - internal bool _disableConsoleLogin; - - internal int _maxRetryCount; - - internal int _maxRetryTimeout; - - internal String _queryTag; - - internal void ProcessLoginResponse(LoginResponse authnResponse) - { - if (authnResponse.success) - { - sessionId = authnResponse.data.sessionId; - sessionToken = authnResponse.data.token; - masterToken = authnResponse.data.masterToken; - database = authnResponse.data.authResponseSessionInfo.databaseName; - schema = authnResponse.data.authResponseSessionInfo.schemaName; - serverVersion = authnResponse.data.serverVersion; - masterValidityInSeconds = authnResponse.data.masterValidityInSeconds; - UpdateSessionParameterMap(authnResponse.data.nameValueParameter); - if (_disableQueryContextCache) - { - logger.Debug("Query context cache disabled."); - } - logger.Debug($"Session opened: {sessionId}"); - _startTime = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); - } - else - { - SnowflakeDbException e = new SnowflakeDbException - (SnowflakeDbException.CONNECTION_FAILURE_SSTATE, - authnResponse.code, - authnResponse.message, - ""); - - logger.Error("Authentication failed", e); - throw e; - } - } - - internal readonly Dictionary ParameterMap; - - internal Uri BuildLoginUrl() - { - var queryParams = new Dictionary(); - string warehouseValue; - string dbValue; - string schemaValue; - string roleName; - queryParams[RestParams.SF_QUERY_WAREHOUSE] = properties.TryGetValue(SFSessionProperty.WAREHOUSE, out warehouseValue) ? warehouseValue : ""; - queryParams[RestParams.SF_QUERY_DB] = properties.TryGetValue(SFSessionProperty.DB, out dbValue) ? dbValue : ""; - queryParams[RestParams.SF_QUERY_SCHEMA] = properties.TryGetValue(SFSessionProperty.SCHEMA, out schemaValue) ? schemaValue : ""; - queryParams[RestParams.SF_QUERY_ROLE] = properties.TryGetValue(SFSessionProperty.ROLE, out roleName) ? roleName : ""; - queryParams[RestParams.SF_QUERY_REQUEST_ID] = Guid.NewGuid().ToString(); - queryParams[RestParams.SF_QUERY_REQUEST_GUID] = Guid.NewGuid().ToString(); - - var loginUrl = BuildUri(RestPath.SF_LOGIN_PATH, queryParams); - return loginUrl; - } - - /// - /// Constructor - /// - /// A string in the form of "key1=value1;key2=value2" - internal SFSession( - String connectionString, - SecureString password) : this(connectionString, password, EasyLoggingStarter.Instance) - { - } - - internal SFSession( - String connectionString, - SecureString password, - EasyLoggingStarter easyLoggingStarter) - { - _easyLoggingStarter = easyLoggingStarter; - connStr = connectionString; - properties = SFSessionProperties.ParseConnectionString(connectionString, password); - _disableQueryContextCache = bool.Parse(properties[SFSessionProperty.DISABLEQUERYCONTEXTCACHE]); - _disableConsoleLogin = bool.Parse(properties[SFSessionProperty.DISABLE_CONSOLE_LOGIN]); - ValidateApplicationName(properties); - try - { - var extractedProperties = propertiesExtractor.ExtractProperties(properties); - var httpClientConfig = extractedProperties.BuildHttpClientConfig(); - ParameterMap = extractedProperties.ToParameterMap(); - InsecureMode = extractedProperties.insecureMode; - _HttpClient = HttpUtil.Instance.GetHttpClient(httpClientConfig); - restRequester = new RestRequester(_HttpClient); - extractedProperties.CheckPropertiesAreValid(); - connectionTimeout = extractedProperties.TimeoutDuration(); - properties.TryGetValue(SFSessionProperty.CLIENT_CONFIG_FILE, out var easyLoggingConfigFile); - _easyLoggingStarter.Init(easyLoggingConfigFile); - properties.TryGetValue(SFSessionProperty.QUERY_TAG, out _queryTag); - _maxRetryCount = extractedProperties.maxHttpRetries; - _maxRetryTimeout = extractedProperties.retryTimeout; - } - catch (Exception e) - { - logger.Error("Unable to connect", e); - throw new SnowflakeDbException(e, - SnowflakeDbException.CONNECTION_FAILURE_SSTATE, - SFError.INVALID_CONNECTION_STRING, - "Unable to connect"); - } - } - - private void ValidateApplicationName(SFSessionProperties properties) - { - // If there is an "application" setting, verify that it matches the expect pattern - properties.TryGetValue(SFSessionProperty.APPLICATION, out string applicationNameSetting); - if (!String.IsNullOrEmpty(applicationNameSetting) && !APPLICATION_REGEX.IsMatch(applicationNameSetting)) - { - throw new SnowflakeDbException( - SnowflakeDbException.CONNECTION_FAILURE_SSTATE, - SFError.INVALID_CONNECTION_PARAMETER_VALUE, - applicationNameSetting, - SFSessionProperty.APPLICATION.ToString() - ); - } - } - - internal SFSession(String connectionString, SecureString password, IMockRestRequester restRequester) : this(connectionString, password) - { - // Inject the HttpClient to use with the Mock requester - restRequester.setHttpClient(_HttpClient); - // Override the Rest requester with the mock for testing - this.restRequester = restRequester; - } - - internal Uri BuildUri(string path, Dictionary queryParams = null) - { - UriBuilder uriBuilder = new UriBuilder(); - uriBuilder.Scheme = properties[SFSessionProperty.SCHEME]; - uriBuilder.Host = properties[SFSessionProperty.HOST]; - uriBuilder.Port = int.Parse(properties[SFSessionProperty.PORT]); - uriBuilder.Path = path; - - if (queryParams != null && queryParams.Any()) - { - var queryString = HttpUtility.ParseQueryString(string.Empty); - foreach (var kvp in queryParams) - queryString[kvp.Key] = kvp.Value; - - uriBuilder.Query = queryString.ToString(); - } - - return uriBuilder.Uri; - } - - internal void Open() - { - logger.Debug("Open Session"); - - if (authenticator == null) - { - authenticator = AuthenticatorFactory.GetAuthenticator(this); - } - - authenticator.Authenticate(); - } - - internal async Task OpenAsync(CancellationToken cancellationToken) - { - logger.Debug("Open Session Async"); - - if (authenticator == null) - { - authenticator = AuthenticatorFactory.GetAuthenticator(this); - } - - await authenticator.AuthenticateAsync(cancellationToken).ConfigureAwait(false); - } - - internal void close() - { - // Nothing to do if the session is not open - if (!IsEstablished()) return; - - stopHeartBeatForThisSession(); - - // Send a close session request - var queryParams = new Dictionary(); - queryParams[RestParams.SF_QUERY_SESSION_DELETE] = "true"; - queryParams[RestParams.SF_QUERY_REQUEST_ID] = Guid.NewGuid().ToString(); - queryParams[RestParams.SF_QUERY_REQUEST_GUID] = Guid.NewGuid().ToString(); - - SFRestRequest closeSessionRequest = new SFRestRequest - { - Url = BuildUri(RestPath.SF_SESSION_PATH, queryParams), - authorizationToken = string.Format(SF_AUTHORIZATION_SNOWFLAKE_FMT, sessionToken), - sid = sessionId - }; - - logger.Debug($"Send closeSessionRequest"); - var response = restRequester.Post(closeSessionRequest); - if (!response.success) - { - logger.Debug($"Failed to delete session: {sessionId}, error ignored. Code: {response.code} Message: {response.message}"); - } - - logger.Debug($"Session closed: {sessionId}"); - // Just in case the session won't be closed twice - sessionToken = null; - } - - internal async Task CloseAsync(CancellationToken cancellationToken) - { - // Nothing to do if the session is not open - if (!IsEstablished()) return; - - stopHeartBeatForThisSession(); - - // Send a close session request - var queryParams = new Dictionary(); - queryParams[RestParams.SF_QUERY_SESSION_DELETE] = "true"; - queryParams[RestParams.SF_QUERY_REQUEST_ID] = Guid.NewGuid().ToString(); - queryParams[RestParams.SF_QUERY_REQUEST_GUID] = Guid.NewGuid().ToString(); - - SFRestRequest closeSessionRequest = new SFRestRequest() - { - Url = BuildUri(RestPath.SF_SESSION_PATH, queryParams), - authorizationToken = string.Format(SF_AUTHORIZATION_SNOWFLAKE_FMT, sessionToken), - sid = sessionId - }; - - logger.Debug($"Send async closeSessionRequest"); - var response = await restRequester.PostAsync(closeSessionRequest, cancellationToken).ConfigureAwait(false); - if (!response.success) - { - logger.Debug($"Failed to delete session {sessionId}, error ignored. Code: {response.code} Message: {response.message}"); - } - - logger.Debug($"Session closed: {sessionId}"); - // Just in case the session won't be closed twice - sessionToken = null; - } - - internal bool IsEstablished() => sessionToken != null; - - internal void renewSession() - { - logger.Info("Renew the session."); - var response = restRequester.Post(getRenewSessionRequest()); - if (!response.success) - { - SnowflakeDbException e = new SnowflakeDbException("", - response.code, response.message, sessionId); - logger.Error($"Renew session (ID: {sessionId}) failed", e); - throw e; - } - else - { - sessionToken = response.data.sessionToken; - masterToken = response.data.masterToken; - } - } - - internal async Task renewSessionAsync(CancellationToken cancellationToken) - { - logger.Info("Renew the session."); - var response = - await restRequester.PostAsync( - getRenewSessionRequest(), - cancellationToken - ).ConfigureAwait(false); - if (!response.success) - { - SnowflakeDbException e = new SnowflakeDbException("", - response.code, response.message, sessionId); - logger.Error($"Renew session (ID: {sessionId}) failed", e); - throw e; - } - else - { - sessionToken = response.data.sessionToken; - masterToken = response.data.masterToken; - } - } - - internal SFRestRequest getRenewSessionRequest() - { - RenewSessionRequest postBody = new RenewSessionRequest() - { - oldSessionToken = this.sessionToken, - requestType = "RENEW" - }; - - var parameters = new Dictionary - { - { RestParams.SF_QUERY_REQUEST_ID, Guid.NewGuid().ToString() }, - { RestParams.SF_QUERY_REQUEST_GUID, Guid.NewGuid().ToString() }, - }; - - return new SFRestRequest - { - jsonBody = postBody, - Url = BuildUri(RestPath.SF_TOKEN_REQUEST_PATH, parameters), - authorizationToken = string.Format(SF_AUTHORIZATION_SNOWFLAKE_FMT, masterToken), - RestTimeout = Timeout.InfiniteTimeSpan, - _isLogin = true - }; - } - - internal SFRestRequest BuildTimeoutRestRequest(Uri uri, Object body) - { - return new SFRestRequest() - { - jsonBody = body, - Url = uri, - authorizationToken = SF_AUTHORIZATION_BASIC, - RestTimeout = connectionTimeout, - _isLogin = true - }; - } - - internal void UpdateSessionParameterMap(List parameterList) - { - logger.Debug("Update parameter map"); - // with HTAP parameter removal parameters might not returned - // query response - if (parameterList is null) - { - return; - } - - foreach (NameValueParameter parameter in parameterList) - { - if (Enum.TryParse(parameter.name, out SFSessionParameter parameterName)) - { - ParameterMap[parameterName] = parameter.value; - } - } - if (ParameterMap.ContainsKey(SFSessionParameter.CLIENT_STAGE_ARRAY_BINDING_THRESHOLD)) - { - string val = ParameterMap[SFSessionParameter.CLIENT_STAGE_ARRAY_BINDING_THRESHOLD].ToString(); - this.arrayBindStageThreshold = Int32.Parse(val); - } - if (ParameterMap.ContainsKey(SFSessionParameter.CLIENT_SESSION_KEEP_ALIVE)) - { - bool keepAlive = Boolean.Parse(ParameterMap[SFSessionParameter.CLIENT_SESSION_KEEP_ALIVE].ToString()); - if(keepAlive) - { - startHeartBeatForThisSession(); - } - else - { - stopHeartBeatForThisSession(); - } - } - if ((!_disableQueryContextCache) && - (ParameterMap.ContainsKey(SFSessionParameter.QUERY_CONTEXT_CACHE_SIZE))) - { - string val = ParameterMap[SFSessionParameter.QUERY_CONTEXT_CACHE_SIZE].ToString(); - _queryContextCacheSize = Int32.Parse(val); - _queryContextCache.SetCapacity(_queryContextCacheSize); - } - } - - internal void UpdateQueryContextCache(ResponseQueryContext queryContext) - { - if (!_disableQueryContextCache) - { - _queryContextCache.Update(queryContext); - } - } - - internal RequestQueryContext GetQueryContextRequest() - { - if (_disableQueryContextCache) - { - return null; - } - return _queryContextCache.GetQueryContextRequest(); - } - - internal void UpdateDatabaseAndSchema(string databaseName, string schemaName) - { - // with HTAP session metadata removal database/schema - // might be not returened in query result - if (!String.IsNullOrEmpty(databaseName)) - { - this.database = databaseName; - } - if (!String.IsNullOrEmpty(schemaName)) - { - this.schema = schemaName; - } - } - - internal void startHeartBeatForThisSession() - { - if (!this.isHeartBeatEnabled) - { - HeartBeatBackground heartBeatBg = HeartBeatBackground.Instance; - if (this.masterValidityInSeconds == 0) - { - //In case server doesnot provide the default timeout - var DEFAULT_TIMEOUT_IN_SECOND = 14400; - this.masterValidityInSeconds = DEFAULT_TIMEOUT_IN_SECOND; - } - heartBeatBg.addConnection(this, this.masterValidityInSeconds); - this.isHeartBeatEnabled = true; - } - } - internal void stopHeartBeatForThisSession() - { - if (this.isHeartBeatEnabled) - { - HeartBeatBackground heartBeatBg = HeartBeatBackground.Instance; - heartBeatBg.removeConnection(this); - this.isHeartBeatEnabled = false; - } - - } - - public string GetArrayBindStage() - { - return arrayBindStage; - } - - public void SetArrayBindStage(string arrayBindStage) - { - this.arrayBindStage = string.Format("{0}.{1}.{2}", this.database, this.schema, arrayBindStage); - } - - public int GetArrayBindStageThreshold() - { - return this.arrayBindStageThreshold; - } - - public void SetArrayBindStageThreshold(int arrayBindStageThreshold) - { - this.arrayBindStageThreshold = arrayBindStageThreshold; - } - - internal void heartbeat() - { - logger.Debug("heartbeat"); - - bool retry = false; - if (IsEstablished()) - { - do - { - var parameters = new Dictionary - { - { RestParams.SF_QUERY_REQUEST_ID, Guid.NewGuid().ToString() }, - { RestParams.SF_QUERY_REQUEST_GUID, Guid.NewGuid().ToString() }, - }; - - SFRestRequest heartBeatSessionRequest = new SFRestRequest - { - Url = BuildUri(RestPath.SF_SESSION_HEARTBEAT_PATH, parameters), - authorizationToken = string.Format(SF_AUTHORIZATION_SNOWFLAKE_FMT, sessionToken), - RestTimeout = Timeout.InfiniteTimeSpan - }; - var response = restRequester.Post(heartBeatSessionRequest); - - logger.Debug("heartbeat response=" + response); - if (response.success) - { - logger.Debug("SFSession::heartbeat success, session token did not expire."); - } - else - { - if (response.code == SF_SESSION_EXPIRED_CODE) - { - logger.Debug($"SFSession ::heartbeat Session ID: {sessionId} session token expired and retry heartbeat"); - try - { - renewSession(); - retry = true; - continue; - } - catch (Exception ex) - { - // Since we don't lock the heart beat queue when sending - // the heart beat, it's possible that the session get - // closed when sending renew request and caused exception - // thrown from renewSession(), simply ignore that - logger.Error($"renew session (ID: {sessionId}) failed.", ex); - } - } - else - { - logger.Error($"heartbeat failed for session ID: {sessionId}."); - } - } - retry = false; - } while (retry); - } - } - - internal bool IsNotOpen() - { - return _startTime == 0; - } - - internal bool IsExpired(long timeoutInSeconds, long utcTimeInSeconds) - { - return _startTime + timeoutInSeconds <= utcTimeInSeconds; - } - } -} - +/* + * Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Security; +using System.Web; +using Snowflake.Data.Log; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Authenticator; +using System.Threading; +using System.Threading.Tasks; +using System.Net.Http; +using System.Text.RegularExpressions; +using Snowflake.Data.Configuration; + +namespace Snowflake.Data.Core +{ + public class SFSession + { + public const int SF_SESSION_EXPIRED_CODE = 390112; + + private static readonly SFLogger logger = SFLoggerFactory.GetLogger(); + + private static readonly Regex APPLICATION_REGEX = new Regex(@"^[A-Za-z]([A-Za-z0-9.\-_]){1,50}$"); + + private const string SF_AUTHORIZATION_BASIC = "Basic"; + + private const string SF_AUTHORIZATION_SNOWFLAKE_FMT = "Snowflake Token=\"{0}\""; + + private const int _defaultQueryContextCacheSize = 5; + + internal string sessionId; + + internal string sessionToken; + + internal string masterToken; + + internal IRestRequester restRequester { get; private set; } + + private IAuthenticator authenticator; + + internal SFSessionProperties properties; + + internal string database; + + internal string schema; + + internal string serverVersion; + + internal TimeSpan connectionTimeout; + + internal bool InsecureMode; + + internal bool isHeartBeatEnabled; + + private HttpClient _HttpClient; + + private string arrayBindStage = null; + private int arrayBindStageThreshold = 0; + internal int masterValidityInSeconds = 0; + + internal static readonly SFSessionHttpClientProperties.Extractor propertiesExtractor = new SFSessionHttpClientProperties.Extractor( + new SFSessionHttpClientProxyProperties.Extractor()); + + private readonly EasyLoggingStarter _easyLoggingStarter = EasyLoggingStarter.Instance; + + private long _startTime = 0; + internal string connStr = null; + + private QueryContextCache _queryContextCache = new QueryContextCache(_defaultQueryContextCacheSize); + + private int _queryContextCacheSize = _defaultQueryContextCacheSize; + + private bool _disableQueryContextCache = false; + + internal bool _disableConsoleLogin; + + internal int _maxRetryCount; + + internal int _maxRetryTimeout; + + internal String _queryTag; + + internal void ProcessLoginResponse(LoginResponse authnResponse) + { + if (authnResponse.success) + { + sessionId = authnResponse.data.sessionId; + sessionToken = authnResponse.data.token; + masterToken = authnResponse.data.masterToken; + database = authnResponse.data.authResponseSessionInfo.databaseName; + schema = authnResponse.data.authResponseSessionInfo.schemaName; + serverVersion = authnResponse.data.serverVersion; + masterValidityInSeconds = authnResponse.data.masterValidityInSeconds; + UpdateSessionParameterMap(authnResponse.data.nameValueParameter); + if (_disableQueryContextCache) + { + logger.Debug("Query context cache disabled."); + } + logger.Debug($"Session opened: {sessionId}"); + _startTime = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); + } + else + { + SnowflakeDbException e = new SnowflakeDbException + (SnowflakeDbException.CONNECTION_FAILURE_SSTATE, + authnResponse.code, + authnResponse.message, + ""); + + logger.Error("Authentication failed", e); + throw e; + } + } + + internal readonly Dictionary ParameterMap; + + internal Uri BuildLoginUrl() + { + var queryParams = new Dictionary(); + string warehouseValue; + string dbValue; + string schemaValue; + string roleName; + queryParams[RestParams.SF_QUERY_WAREHOUSE] = properties.TryGetValue(SFSessionProperty.WAREHOUSE, out warehouseValue) ? warehouseValue : ""; + queryParams[RestParams.SF_QUERY_DB] = properties.TryGetValue(SFSessionProperty.DB, out dbValue) ? dbValue : ""; + queryParams[RestParams.SF_QUERY_SCHEMA] = properties.TryGetValue(SFSessionProperty.SCHEMA, out schemaValue) ? schemaValue : ""; + queryParams[RestParams.SF_QUERY_ROLE] = properties.TryGetValue(SFSessionProperty.ROLE, out roleName) ? roleName : ""; + queryParams[RestParams.SF_QUERY_REQUEST_ID] = Guid.NewGuid().ToString(); + queryParams[RestParams.SF_QUERY_REQUEST_GUID] = Guid.NewGuid().ToString(); + + var loginUrl = BuildUri(RestPath.SF_LOGIN_PATH, queryParams); + return loginUrl; + } + + /// + /// Constructor + /// + /// A string in the form of "key1=value1;key2=value2" + internal SFSession( + String connectionString, + SecureString password) : this(connectionString, password, EasyLoggingStarter.Instance) + { + } + + internal SFSession( + String connectionString, + SecureString password, + EasyLoggingStarter easyLoggingStarter) + { + _easyLoggingStarter = easyLoggingStarter; + connStr = connectionString; + properties = SFSessionProperties.ParseConnectionString(connectionString, password); + _disableQueryContextCache = bool.Parse(properties[SFSessionProperty.DISABLEQUERYCONTEXTCACHE]); + _disableConsoleLogin = bool.Parse(properties[SFSessionProperty.DISABLE_CONSOLE_LOGIN]); + ValidateApplicationName(properties); + try + { + var extractedProperties = propertiesExtractor.ExtractProperties(properties); + var httpClientConfig = extractedProperties.BuildHttpClientConfig(); + ParameterMap = extractedProperties.ToParameterMap(); + InsecureMode = extractedProperties.insecureMode; + _HttpClient = HttpUtil.Instance.GetHttpClient(httpClientConfig); + restRequester = new RestRequester(_HttpClient); + extractedProperties.CheckPropertiesAreValid(); + connectionTimeout = extractedProperties.TimeoutDuration(); + properties.TryGetValue(SFSessionProperty.CLIENT_CONFIG_FILE, out var easyLoggingConfigFile); + _easyLoggingStarter.Init(easyLoggingConfigFile); + properties.TryGetValue(SFSessionProperty.QUERY_TAG, out _queryTag); + _maxRetryCount = extractedProperties.maxHttpRetries; + _maxRetryTimeout = extractedProperties.retryTimeout; + } + catch (Exception e) + { + logger.Error("Unable to connect", e); + throw new SnowflakeDbException(e, + SnowflakeDbException.CONNECTION_FAILURE_SSTATE, + SFError.INVALID_CONNECTION_STRING, + "Unable to connect"); + } + } + + private void ValidateApplicationName(SFSessionProperties properties) + { + // If there is an "application" setting, verify that it matches the expect pattern + properties.TryGetValue(SFSessionProperty.APPLICATION, out string applicationNameSetting); + if (!String.IsNullOrEmpty(applicationNameSetting) && !APPLICATION_REGEX.IsMatch(applicationNameSetting)) + { + throw new SnowflakeDbException( + SnowflakeDbException.CONNECTION_FAILURE_SSTATE, + SFError.INVALID_CONNECTION_PARAMETER_VALUE, + applicationNameSetting, + SFSessionProperty.APPLICATION.ToString() + ); + } + } + + internal SFSession(String connectionString, SecureString password, IMockRestRequester restRequester) : this(connectionString, password) + { + // Inject the HttpClient to use with the Mock requester + restRequester.setHttpClient(_HttpClient); + // Override the Rest requester with the mock for testing + this.restRequester = restRequester; + } + + internal Uri BuildUri(string path, Dictionary queryParams = null) + { + UriBuilder uriBuilder = new UriBuilder(); + uriBuilder.Scheme = properties[SFSessionProperty.SCHEME]; + uriBuilder.Host = properties[SFSessionProperty.HOST]; + uriBuilder.Port = int.Parse(properties[SFSessionProperty.PORT]); + uriBuilder.Path = path; + + if (queryParams != null && queryParams.Any()) + { + var queryString = HttpUtility.ParseQueryString(string.Empty); + foreach (var kvp in queryParams) + queryString[kvp.Key] = kvp.Value; + + uriBuilder.Query = queryString.ToString(); + } + + return uriBuilder.Uri; + } + + internal void Open() + { + logger.Debug("Open Session"); + + if (authenticator == null) + { + authenticator = AuthenticatorFactory.GetAuthenticator(this); + } + + authenticator.Authenticate(); + } + + internal async Task OpenAsync(CancellationToken cancellationToken) + { + logger.Debug("Open Session Async"); + + if (authenticator == null) + { + authenticator = AuthenticatorFactory.GetAuthenticator(this); + } + + await authenticator.AuthenticateAsync(cancellationToken).ConfigureAwait(false); + } + + internal void close() + { + // Nothing to do if the session is not open + if (!IsEstablished()) return; + + stopHeartBeatForThisSession(); + + // Send a close session request + var queryParams = new Dictionary(); + queryParams[RestParams.SF_QUERY_SESSION_DELETE] = "true"; + queryParams[RestParams.SF_QUERY_REQUEST_ID] = Guid.NewGuid().ToString(); + queryParams[RestParams.SF_QUERY_REQUEST_GUID] = Guid.NewGuid().ToString(); + + SFRestRequest closeSessionRequest = new SFRestRequest + { + Url = BuildUri(RestPath.SF_SESSION_PATH, queryParams), + authorizationToken = string.Format(SF_AUTHORIZATION_SNOWFLAKE_FMT, sessionToken), + sid = sessionId + }; + + logger.Debug($"Send closeSessionRequest"); + var response = restRequester.Post(closeSessionRequest); + if (!response.success) + { + logger.Debug($"Failed to delete session: {sessionId}, error ignored. Code: {response.code} Message: {response.message}"); + } + + logger.Debug($"Session closed: {sessionId}"); + // Just in case the session won't be closed twice + sessionToken = null; + } + + internal async Task CloseAsync(CancellationToken cancellationToken) + { + // Nothing to do if the session is not open + if (!IsEstablished()) return; + + stopHeartBeatForThisSession(); + + // Send a close session request + var queryParams = new Dictionary(); + queryParams[RestParams.SF_QUERY_SESSION_DELETE] = "true"; + queryParams[RestParams.SF_QUERY_REQUEST_ID] = Guid.NewGuid().ToString(); + queryParams[RestParams.SF_QUERY_REQUEST_GUID] = Guid.NewGuid().ToString(); + + SFRestRequest closeSessionRequest = new SFRestRequest() + { + Url = BuildUri(RestPath.SF_SESSION_PATH, queryParams), + authorizationToken = string.Format(SF_AUTHORIZATION_SNOWFLAKE_FMT, sessionToken), + sid = sessionId + }; + + logger.Debug($"Send async closeSessionRequest"); + var response = await restRequester.PostAsync(closeSessionRequest, cancellationToken).ConfigureAwait(false); + if (!response.success) + { + logger.Debug($"Failed to delete session {sessionId}, error ignored. Code: {response.code} Message: {response.message}"); + } + + logger.Debug($"Session closed: {sessionId}"); + // Just in case the session won't be closed twice + sessionToken = null; + } + + internal bool IsEstablished() => sessionToken != null; + + internal void renewSession() + { + logger.Info("Renew the session."); + var response = restRequester.Post(getRenewSessionRequest()); + if (!response.success) + { + SnowflakeDbException e = new SnowflakeDbException("", + response.code, response.message, sessionId); + logger.Error($"Renew session (ID: {sessionId}) failed", e); + throw e; + } + else + { + sessionToken = response.data.sessionToken; + masterToken = response.data.masterToken; + } + } + + internal async Task renewSessionAsync(CancellationToken cancellationToken) + { + logger.Info("Renew the session."); + var response = + await restRequester.PostAsync( + getRenewSessionRequest(), + cancellationToken + ).ConfigureAwait(false); + if (!response.success) + { + SnowflakeDbException e = new SnowflakeDbException("", + response.code, response.message, sessionId); + logger.Error($"Renew session (ID: {sessionId}) failed", e); + throw e; + } + else + { + sessionToken = response.data.sessionToken; + masterToken = response.data.masterToken; + } + } + + internal SFRestRequest getRenewSessionRequest() + { + RenewSessionRequest postBody = new RenewSessionRequest() + { + oldSessionToken = this.sessionToken, + requestType = "RENEW" + }; + + var parameters = new Dictionary + { + { RestParams.SF_QUERY_REQUEST_ID, Guid.NewGuid().ToString() }, + { RestParams.SF_QUERY_REQUEST_GUID, Guid.NewGuid().ToString() }, + }; + + return new SFRestRequest + { + jsonBody = postBody, + Url = BuildUri(RestPath.SF_TOKEN_REQUEST_PATH, parameters), + authorizationToken = string.Format(SF_AUTHORIZATION_SNOWFLAKE_FMT, masterToken), + RestTimeout = Timeout.InfiniteTimeSpan, + _isLogin = true + }; + } + + internal SFRestRequest BuildTimeoutRestRequest(Uri uri, Object body) + { + return new SFRestRequest() + { + jsonBody = body, + Url = uri, + authorizationToken = SF_AUTHORIZATION_BASIC, + RestTimeout = connectionTimeout, + _isLogin = true + }; + } + + internal void UpdateSessionParameterMap(List parameterList) + { + logger.Debug("Update parameter map"); + // with HTAP parameter removal parameters might not returned + // query response + if (parameterList is null) + { + return; + } + + foreach (NameValueParameter parameter in parameterList) + { + if (Enum.TryParse(parameter.name, out SFSessionParameter parameterName)) + { + ParameterMap[parameterName] = parameter.value; + } + } + if (ParameterMap.ContainsKey(SFSessionParameter.CLIENT_STAGE_ARRAY_BINDING_THRESHOLD)) + { + string val = ParameterMap[SFSessionParameter.CLIENT_STAGE_ARRAY_BINDING_THRESHOLD].ToString(); + this.arrayBindStageThreshold = Int32.Parse(val); + } + if (ParameterMap.ContainsKey(SFSessionParameter.CLIENT_SESSION_KEEP_ALIVE)) + { + bool keepAlive = Boolean.Parse(ParameterMap[SFSessionParameter.CLIENT_SESSION_KEEP_ALIVE].ToString()); + if(keepAlive) + { + startHeartBeatForThisSession(); + } + else + { + stopHeartBeatForThisSession(); + } + } + if ((!_disableQueryContextCache) && + (ParameterMap.ContainsKey(SFSessionParameter.QUERY_CONTEXT_CACHE_SIZE))) + { + string val = ParameterMap[SFSessionParameter.QUERY_CONTEXT_CACHE_SIZE].ToString(); + _queryContextCacheSize = Int32.Parse(val); + _queryContextCache.SetCapacity(_queryContextCacheSize); + } + } + + internal void UpdateQueryContextCache(ResponseQueryContext queryContext) + { + if (!_disableQueryContextCache) + { + _queryContextCache.Update(queryContext); + } + } + + internal RequestQueryContext GetQueryContextRequest() + { + if (_disableQueryContextCache) + { + return null; + } + return _queryContextCache.GetQueryContextRequest(); + } + + internal void UpdateDatabaseAndSchema(string databaseName, string schemaName) + { + // with HTAP session metadata removal database/schema + // might be not returened in query result + if (!String.IsNullOrEmpty(databaseName)) + { + this.database = databaseName; + } + if (!String.IsNullOrEmpty(schemaName)) + { + this.schema = schemaName; + } + } + + internal void startHeartBeatForThisSession() + { + if (!this.isHeartBeatEnabled) + { + HeartBeatBackground heartBeatBg = HeartBeatBackground.Instance; + if (this.masterValidityInSeconds == 0) + { + //In case server doesnot provide the default timeout + var DEFAULT_TIMEOUT_IN_SECOND = 14400; + this.masterValidityInSeconds = DEFAULT_TIMEOUT_IN_SECOND; + } + heartBeatBg.addConnection(this, this.masterValidityInSeconds); + this.isHeartBeatEnabled = true; + } + } + internal void stopHeartBeatForThisSession() + { + if (this.isHeartBeatEnabled) + { + HeartBeatBackground heartBeatBg = HeartBeatBackground.Instance; + heartBeatBg.removeConnection(this); + this.isHeartBeatEnabled = false; + } + + } + + public string GetArrayBindStage() + { + return arrayBindStage; + } + + public void SetArrayBindStage(string arrayBindStage) + { + this.arrayBindStage = string.Format("{0}.{1}.{2}", this.database, this.schema, arrayBindStage); + } + + public int GetArrayBindStageThreshold() + { + return this.arrayBindStageThreshold; + } + + public void SetArrayBindStageThreshold(int arrayBindStageThreshold) + { + this.arrayBindStageThreshold = arrayBindStageThreshold; + } + + internal void heartbeat() + { + logger.Debug("heartbeat"); + + bool retry = false; + if (IsEstablished()) + { + do + { + var parameters = new Dictionary + { + { RestParams.SF_QUERY_REQUEST_ID, Guid.NewGuid().ToString() }, + { RestParams.SF_QUERY_REQUEST_GUID, Guid.NewGuid().ToString() }, + }; + + SFRestRequest heartBeatSessionRequest = new SFRestRequest + { + Url = BuildUri(RestPath.SF_SESSION_HEARTBEAT_PATH, parameters), + authorizationToken = string.Format(SF_AUTHORIZATION_SNOWFLAKE_FMT, sessionToken), + RestTimeout = Timeout.InfiniteTimeSpan + }; + var response = restRequester.Post(heartBeatSessionRequest); + + logger.Debug("heartbeat response=" + response); + if (response.success) + { + logger.Debug("SFSession::heartbeat success, session token did not expire."); + } + else + { + if (response.code == SF_SESSION_EXPIRED_CODE) + { + logger.Debug($"SFSession ::heartbeat Session ID: {sessionId} session token expired and retry heartbeat"); + try + { + renewSession(); + retry = true; + continue; + } + catch (Exception ex) + { + // Since we don't lock the heart beat queue when sending + // the heart beat, it's possible that the session get + // closed when sending renew request and caused exception + // thrown from renewSession(), simply ignore that + logger.Error($"renew session (ID: {sessionId}) failed.", ex); + } + } + else + { + logger.Error($"heartbeat failed for session ID: {sessionId}."); + } + } + retry = false; + } while (retry); + } + } + + internal bool IsNotOpen() + { + return _startTime == 0; + } + + internal bool IsExpired(long timeoutInSeconds, long utcTimeInSeconds) + { + return _startTime + timeoutInSeconds <= utcTimeInSeconds; + } + } +} + diff --git a/Snowflake.Data/Core/Session/SFSessionProperty.cs b/Snowflake.Data/Core/Session/SFSessionProperty.cs index 7ce6b4731..3b3b86646 100644 --- a/Snowflake.Data/Core/Session/SFSessionProperty.cs +++ b/Snowflake.Data/Core/Session/SFSessionProperty.cs @@ -1,450 +1,450 @@ -/* - * Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. - */ - -using System; -using System.Collections.Generic; -using System.Net; -using System.Security; -using Snowflake.Data.Log; -using Snowflake.Data.Client; -using Snowflake.Data.Core.Authenticator; -using System.Data.Common; -using System.Linq; -using System.Text.RegularExpressions; - -namespace Snowflake.Data.Core -{ - internal enum SFSessionProperty - { - [SFSessionPropertyAttr(required = true)] - ACCOUNT, - [SFSessionPropertyAttr(required = false)] - DB, - [SFSessionPropertyAttr(required = false)] - HOST, - [SFSessionPropertyAttr(required = true)] - PASSWORD, - [SFSessionPropertyAttr(required = false, defaultValue = "443")] - PORT, - [SFSessionPropertyAttr(required = false)] - ROLE, - [SFSessionPropertyAttr(required = false)] - SCHEMA, - [SFSessionPropertyAttr(required = false, defaultValue = "https")] - SCHEME, - [SFSessionPropertyAttr(required = true, defaultValue = "")] - USER, - [SFSessionPropertyAttr(required = false)] - WAREHOUSE, - [SFSessionPropertyAttr(required = false, defaultValue = "300")] - CONNECTION_TIMEOUT, - [SFSessionPropertyAttr(required = false, defaultValue = "snowflake")] - AUTHENTICATOR, - [SFSessionPropertyAttr(required = false, defaultValue = "true")] - VALIDATE_DEFAULT_PARAMETERS, - [SFSessionPropertyAttr(required = false)] - PRIVATE_KEY_FILE, - [SFSessionPropertyAttr(required = false)] - PRIVATE_KEY_PWD, - [SFSessionPropertyAttr(required = false)] - PRIVATE_KEY, - [SFSessionPropertyAttr(required = false)] - TOKEN, - [SFSessionPropertyAttr(required = false, defaultValue = "false")] - INSECUREMODE, - [SFSessionPropertyAttr(required = false, defaultValue = "false")] - USEPROXY, - [SFSessionPropertyAttr(required = false)] - PROXYHOST, - [SFSessionPropertyAttr(required = false)] - PROXYPORT, - [SFSessionPropertyAttr(required = false)] - PROXYUSER, - [SFSessionPropertyAttr(required = false)] - PROXYPASSWORD, - [SFSessionPropertyAttr(required = false)] - NONPROXYHOSTS, - [SFSessionPropertyAttr(required = false)] - APPLICATION, - [SFSessionPropertyAttr(required = false, defaultValue = "false")] - DISABLERETRY, - [SFSessionPropertyAttr(required = false, defaultValue = "false")] - FORCERETRYON404, - [SFSessionPropertyAttr(required = false, defaultValue = "false")] - CLIENT_SESSION_KEEP_ALIVE, - [SFSessionPropertyAttr(required = false)] - GCS_USE_DOWNSCOPED_CREDENTIAL, - [SFSessionPropertyAttr(required = false, defaultValue = "false")] - FORCEPARSEERROR, - [SFSessionPropertyAttr(required = false, defaultValue = "120")] - BROWSER_RESPONSE_TIMEOUT, - [SFSessionPropertyAttr(required = false, defaultValue = "300")] - RETRY_TIMEOUT, - [SFSessionPropertyAttr(required = false, defaultValue = "7")] - MAXHTTPRETRIES, - [SFSessionPropertyAttr(required = false)] - FILE_TRANSFER_MEMORY_THRESHOLD, - [SFSessionPropertyAttr(required = false, defaultValue = "true")] - INCLUDERETRYREASON, - [SFSessionPropertyAttr(required = false, defaultValue = "false")] - DISABLEQUERYCONTEXTCACHE, - [SFSessionPropertyAttr(required = false)] - CLIENT_CONFIG_FILE, - [SFSessionPropertyAttr(required = false, defaultValue = "true")] - DISABLE_CONSOLE_LOGIN, - [SFSessionPropertyAttr(required = false, defaultValue = "false")] - ALLOWUNDERSCORESINHOST, - [SFSessionPropertyAttr(required = false)] - QUERY_TAG - } - - class SFSessionPropertyAttr : Attribute - { - public bool required { get; set; } - - public string defaultValue { get; set; } - } - - class SFSessionProperties : Dictionary - { - private static SFLogger logger = SFLoggerFactory.GetLogger(); - - // Connection string properties to obfuscate in the log - private static List secretProps = - new List{ - SFSessionProperty.PASSWORD, - SFSessionProperty.PRIVATE_KEY, - SFSessionProperty.TOKEN, - SFSessionProperty.PRIVATE_KEY_PWD, - SFSessionProperty.PROXYPASSWORD, - }; - - private static readonly List s_accountRegexStrings = new List - { - "^\\w", - "\\w$", - "^[\\w.-]+$" - }; - - public override bool Equals(object obj) - { - if (obj == null) return false; - try - { - SFSessionProperties prop = (SFSessionProperties)obj; - foreach (SFSessionProperty sessionProperty in Enum.GetValues(typeof(SFSessionProperty))) - { - if (this.ContainsKey(sessionProperty) ^ prop.ContainsKey(sessionProperty)) - { - return false; - } - if (!this.ContainsKey(sessionProperty)) - { - continue; - } - if (!this[sessionProperty].Equals(prop[sessionProperty])) - { - return false; - } - } - return true; - } - catch (InvalidCastException) - { - logger.Warn("Invalid casting to SFSessionProperties"); - return false; - } - } - - public override int GetHashCode() - { - return base.GetHashCode(); - } - - internal static SFSessionProperties ParseConnectionString(string connectionString, SecureString password) - { - logger.Info("Start parsing connection string."); - var builder = new DbConnectionStringBuilder(); - try - { - builder.ConnectionString = connectionString; - } - catch (ArgumentException e) - { - logger.Warn("Invalid connectionString", e); - throw new SnowflakeDbException(e, - SFError.INVALID_CONNECTION_STRING, - e.Message); - } - var properties = new SFSessionProperties(); - - var keys = new string[builder.Keys.Count]; - var values = new string[builder.Values.Count]; - builder.Keys.CopyTo(keys, 0); - builder.Values.CopyTo(values,0); - - for(var i=0; i().required = true; - SFSessionProperty.PROXYPORT.GetAttribute().required = true; - - // If a username is provided, then a password is required - if (properties.ContainsKey(SFSessionProperty.PROXYUSER)) - { - SFSessionProperty.PROXYPASSWORD.GetAttribute().required = true; - } - } - - if (password != null) - { - properties[SFSessionProperty.PASSWORD] = new NetworkCredential(string.Empty, password).Password; - } - - checkSessionProperties(properties); - ValidateFileTransferMaxBytesInMemoryProperty(properties); - ValidateAccountDomain(properties); - - var allowUnderscoresInHost = ParseAllowUnderscoresInHost(properties); - - // compose host value if not specified - if (!properties.ContainsKey(SFSessionProperty.HOST) || - (0 == properties[SFSessionProperty.HOST].Length)) - { - var compliantAccountName = properties[SFSessionProperty.ACCOUNT]; - if (!allowUnderscoresInHost && compliantAccountName.Contains('_')) - { - compliantAccountName = compliantAccountName.Replace('_', '-'); - logger.Info($"Replacing _ with - in the account name. Old: {properties[SFSessionProperty.ACCOUNT]}, new: {compliantAccountName}."); - } - var hostName = $"{compliantAccountName}.snowflakecomputing.com"; - // Remove in case it's here but empty - properties.Remove(SFSessionProperty.HOST); - properties.Add(SFSessionProperty.HOST, hostName); - logger.Info($"Compose host name: {hostName}"); - } - - // Trim the account name to remove the region and cloud platform if any were provided - // because the login request data does not expect region and cloud information to be - // passed on for account_name - properties[SFSessionProperty.ACCOUNT] = properties[SFSessionProperty.ACCOUNT].Split('.')[0]; - - return properties; - } - - private static void UpdatePropertiesForSpecialCases(SFSessionProperties properties, string connectionString) - { - var propertyEntry = connectionString.Split(';'); - foreach(var keyVal in propertyEntry) - { - if(keyVal.Length > 0) - { - var tokens = keyVal.Split(new string[] { "=" }, StringSplitOptions.None); - var propertyName = tokens[0].ToUpper(); - switch (propertyName) - { - case "DB": - case "SCHEMA": - case "WAREHOUSE": - case "ROLE": - { - if (tokens.Length == 2) - { - var sessionProperty = (SFSessionProperty)Enum.Parse( - typeof(SFSessionProperty), propertyName); - properties[sessionProperty]= tokens[1]; - } - - break; - } - case "USER": - case "PASSWORD": - { - - var sessionProperty = (SFSessionProperty)Enum.Parse( - typeof(SFSessionProperty), propertyName); - if (!properties.ContainsKey(sessionProperty)) - { - properties.Add(sessionProperty, ""); - } - - break; - } - } - } - } - } - - private static void ValidateAccountDomain(SFSessionProperties properties) - { - var account = properties[SFSessionProperty.ACCOUNT]; - if (string.IsNullOrEmpty(account)) - return; - if (IsAccountRegexMatched(account)) - return; - logger.Error($"Invalid account {account}"); - throw new SnowflakeDbException( - new Exception("Invalid account"), - SFError.INVALID_CONNECTION_PARAMETER_VALUE, - account, - SFSessionProperty.ACCOUNT); - } - - private static bool IsAccountRegexMatched(string account) => - s_accountRegexStrings - .Select(regex => Regex.Match(account, regex, RegexOptions.IgnoreCase)) - .All(match => match.Success); - - private static void checkSessionProperties(SFSessionProperties properties) - { - foreach (SFSessionProperty sessionProperty in Enum.GetValues(typeof(SFSessionProperty))) - { - // if required property, check if exists in the dictionary - if (IsRequired(sessionProperty, properties) && - !properties.ContainsKey(sessionProperty)) - { - SnowflakeDbException e = new SnowflakeDbException(SFError.MISSING_CONNECTION_PROPERTY, - sessionProperty); - logger.Error("Missing connection property", e); - throw e; - } - - // add default value to the map - string defaultVal = sessionProperty.GetAttribute().defaultValue; - if (defaultVal != null && !properties.ContainsKey(sessionProperty)) - { - logger.Debug($"Sesssion property {sessionProperty} set to default value: {defaultVal}"); - properties.Add(sessionProperty, defaultVal); - } - } - } - - private static void ValidateFileTransferMaxBytesInMemoryProperty(SFSessionProperties properties) - { - if (!properties.TryGetValue(SFSessionProperty.FILE_TRANSFER_MEMORY_THRESHOLD, out var maxBytesInMemoryString)) - { - return; - } - - var propertyName = SFSessionProperty.FILE_TRANSFER_MEMORY_THRESHOLD.ToString(); - int maxBytesInMemory; - try - { - maxBytesInMemory = int.Parse(maxBytesInMemoryString); - } - catch (Exception e) - { - logger.Error($"Value for parameter {propertyName} could not be parsed"); - throw new SnowflakeDbException(e, SFError.INVALID_CONNECTION_PARAMETER_VALUE, maxBytesInMemoryString, propertyName); - } - - if (maxBytesInMemory <= 0) - { - logger.Error($"Value for parameter {propertyName} should be greater than 0"); - throw new SnowflakeDbException( - new Exception($"Value for parameter {propertyName} should be greater than 0"), - SFError.INVALID_CONNECTION_PARAMETER_VALUE, maxBytesInMemoryString, propertyName); - } - } - - private static bool IsRequired(SFSessionProperty sessionProperty, SFSessionProperties properties) - { - if (sessionProperty.Equals(SFSessionProperty.PASSWORD)) - { - var authenticatorDefined = - properties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator); - - var authenticatorsWithoutPassword = new List() - { - ExternalBrowserAuthenticator.AUTH_NAME, - KeyPairAuthenticator.AUTH_NAME, - OAuthAuthenticator.AUTH_NAME - }; - // External browser, jwt and oauth don't require a password for authenticating - return !authenticatorDefined || !authenticatorsWithoutPassword - .Any(auth => auth.Equals(authenticator, StringComparison.OrdinalIgnoreCase)); - } - else if (sessionProperty.Equals(SFSessionProperty.USER)) - { - var authenticatorDefined = - properties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator); - - var authenticatorsWithoutUsername = new List() - { - OAuthAuthenticator.AUTH_NAME, - ExternalBrowserAuthenticator.AUTH_NAME - }; - return !authenticatorDefined || !authenticatorsWithoutUsername - .Any(auth => auth.Equals(authenticator, StringComparison.OrdinalIgnoreCase)); - } - else - { - return sessionProperty.GetAttribute().required; - } - } - - private static bool ParseAllowUnderscoresInHost(SFSessionProperties properties) - { - var allowUnderscoresInHost = bool.Parse(SFSessionProperty.ALLOWUNDERSCORESINHOST.GetAttribute().defaultValue); - if (!properties.TryGetValue(SFSessionProperty.ALLOWUNDERSCORESINHOST, out var property)) - return allowUnderscoresInHost; - try - { - allowUnderscoresInHost = bool.Parse(property); - } - catch (Exception e) - { - logger.Warn("Unable to parse property 'allowUnderscoresInHost'", e); - } - - return allowUnderscoresInHost; - } - } - - public static class EnumExtensions - { - public static TAttribute GetAttribute(this Enum value) - where TAttribute : Attribute - { - var type = value.GetType(); - var memInfo = type.GetMember(value.ToString()); - var attributes = memInfo[0].GetCustomAttributes(typeof(TAttribute), false); - return (attributes.Length > 0) ? (TAttribute)attributes[0] : null; - } - } -} +/* + * Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Collections.Generic; +using System.Net; +using System.Security; +using Snowflake.Data.Log; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Authenticator; +using System.Data.Common; +using System.Linq; +using System.Text.RegularExpressions; + +namespace Snowflake.Data.Core +{ + internal enum SFSessionProperty + { + [SFSessionPropertyAttr(required = true)] + ACCOUNT, + [SFSessionPropertyAttr(required = false)] + DB, + [SFSessionPropertyAttr(required = false)] + HOST, + [SFSessionPropertyAttr(required = true)] + PASSWORD, + [SFSessionPropertyAttr(required = false, defaultValue = "443")] + PORT, + [SFSessionPropertyAttr(required = false)] + ROLE, + [SFSessionPropertyAttr(required = false)] + SCHEMA, + [SFSessionPropertyAttr(required = false, defaultValue = "https")] + SCHEME, + [SFSessionPropertyAttr(required = true, defaultValue = "")] + USER, + [SFSessionPropertyAttr(required = false)] + WAREHOUSE, + [SFSessionPropertyAttr(required = false, defaultValue = "300")] + CONNECTION_TIMEOUT, + [SFSessionPropertyAttr(required = false, defaultValue = "snowflake")] + AUTHENTICATOR, + [SFSessionPropertyAttr(required = false, defaultValue = "true")] + VALIDATE_DEFAULT_PARAMETERS, + [SFSessionPropertyAttr(required = false)] + PRIVATE_KEY_FILE, + [SFSessionPropertyAttr(required = false)] + PRIVATE_KEY_PWD, + [SFSessionPropertyAttr(required = false)] + PRIVATE_KEY, + [SFSessionPropertyAttr(required = false)] + TOKEN, + [SFSessionPropertyAttr(required = false, defaultValue = "false")] + INSECUREMODE, + [SFSessionPropertyAttr(required = false, defaultValue = "false")] + USEPROXY, + [SFSessionPropertyAttr(required = false)] + PROXYHOST, + [SFSessionPropertyAttr(required = false)] + PROXYPORT, + [SFSessionPropertyAttr(required = false)] + PROXYUSER, + [SFSessionPropertyAttr(required = false)] + PROXYPASSWORD, + [SFSessionPropertyAttr(required = false)] + NONPROXYHOSTS, + [SFSessionPropertyAttr(required = false)] + APPLICATION, + [SFSessionPropertyAttr(required = false, defaultValue = "false")] + DISABLERETRY, + [SFSessionPropertyAttr(required = false, defaultValue = "false")] + FORCERETRYON404, + [SFSessionPropertyAttr(required = false, defaultValue = "false")] + CLIENT_SESSION_KEEP_ALIVE, + [SFSessionPropertyAttr(required = false)] + GCS_USE_DOWNSCOPED_CREDENTIAL, + [SFSessionPropertyAttr(required = false, defaultValue = "false")] + FORCEPARSEERROR, + [SFSessionPropertyAttr(required = false, defaultValue = "120")] + BROWSER_RESPONSE_TIMEOUT, + [SFSessionPropertyAttr(required = false, defaultValue = "300")] + RETRY_TIMEOUT, + [SFSessionPropertyAttr(required = false, defaultValue = "7")] + MAXHTTPRETRIES, + [SFSessionPropertyAttr(required = false)] + FILE_TRANSFER_MEMORY_THRESHOLD, + [SFSessionPropertyAttr(required = false, defaultValue = "true")] + INCLUDERETRYREASON, + [SFSessionPropertyAttr(required = false, defaultValue = "false")] + DISABLEQUERYCONTEXTCACHE, + [SFSessionPropertyAttr(required = false)] + CLIENT_CONFIG_FILE, + [SFSessionPropertyAttr(required = false, defaultValue = "true")] + DISABLE_CONSOLE_LOGIN, + [SFSessionPropertyAttr(required = false, defaultValue = "false")] + ALLOWUNDERSCORESINHOST, + [SFSessionPropertyAttr(required = false)] + QUERY_TAG + } + + class SFSessionPropertyAttr : Attribute + { + public bool required { get; set; } + + public string defaultValue { get; set; } + } + + class SFSessionProperties : Dictionary + { + private static SFLogger logger = SFLoggerFactory.GetLogger(); + + // Connection string properties to obfuscate in the log + private static List secretProps = + new List{ + SFSessionProperty.PASSWORD, + SFSessionProperty.PRIVATE_KEY, + SFSessionProperty.TOKEN, + SFSessionProperty.PRIVATE_KEY_PWD, + SFSessionProperty.PROXYPASSWORD, + }; + + private static readonly List s_accountRegexStrings = new List + { + "^\\w", + "\\w$", + "^[\\w.-]+$" + }; + + public override bool Equals(object obj) + { + if (obj == null) return false; + try + { + SFSessionProperties prop = (SFSessionProperties)obj; + foreach (SFSessionProperty sessionProperty in Enum.GetValues(typeof(SFSessionProperty))) + { + if (this.ContainsKey(sessionProperty) ^ prop.ContainsKey(sessionProperty)) + { + return false; + } + if (!this.ContainsKey(sessionProperty)) + { + continue; + } + if (!this[sessionProperty].Equals(prop[sessionProperty])) + { + return false; + } + } + return true; + } + catch (InvalidCastException) + { + logger.Warn("Invalid casting to SFSessionProperties"); + return false; + } + } + + public override int GetHashCode() + { + return base.GetHashCode(); + } + + internal static SFSessionProperties ParseConnectionString(string connectionString, SecureString password) + { + logger.Info("Start parsing connection string."); + var builder = new DbConnectionStringBuilder(); + try + { + builder.ConnectionString = connectionString; + } + catch (ArgumentException e) + { + logger.Warn("Invalid connectionString", e); + throw new SnowflakeDbException(e, + SFError.INVALID_CONNECTION_STRING, + e.Message); + } + var properties = new SFSessionProperties(); + + var keys = new string[builder.Keys.Count]; + var values = new string[builder.Values.Count]; + builder.Keys.CopyTo(keys, 0); + builder.Values.CopyTo(values,0); + + for(var i=0; i().required = true; + SFSessionProperty.PROXYPORT.GetAttribute().required = true; + + // If a username is provided, then a password is required + if (properties.ContainsKey(SFSessionProperty.PROXYUSER)) + { + SFSessionProperty.PROXYPASSWORD.GetAttribute().required = true; + } + } + + if (password != null) + { + properties[SFSessionProperty.PASSWORD] = new NetworkCredential(string.Empty, password).Password; + } + + checkSessionProperties(properties); + ValidateFileTransferMaxBytesInMemoryProperty(properties); + ValidateAccountDomain(properties); + + var allowUnderscoresInHost = ParseAllowUnderscoresInHost(properties); + + // compose host value if not specified + if (!properties.ContainsKey(SFSessionProperty.HOST) || + (0 == properties[SFSessionProperty.HOST].Length)) + { + var compliantAccountName = properties[SFSessionProperty.ACCOUNT]; + if (!allowUnderscoresInHost && compliantAccountName.Contains('_')) + { + compliantAccountName = compliantAccountName.Replace('_', '-'); + logger.Info($"Replacing _ with - in the account name. Old: {properties[SFSessionProperty.ACCOUNT]}, new: {compliantAccountName}."); + } + var hostName = $"{compliantAccountName}.snowflakecomputing.com"; + // Remove in case it's here but empty + properties.Remove(SFSessionProperty.HOST); + properties.Add(SFSessionProperty.HOST, hostName); + logger.Info($"Compose host name: {hostName}"); + } + + // Trim the account name to remove the region and cloud platform if any were provided + // because the login request data does not expect region and cloud information to be + // passed on for account_name + properties[SFSessionProperty.ACCOUNT] = properties[SFSessionProperty.ACCOUNT].Split('.')[0]; + + return properties; + } + + private static void UpdatePropertiesForSpecialCases(SFSessionProperties properties, string connectionString) + { + var propertyEntry = connectionString.Split(';'); + foreach(var keyVal in propertyEntry) + { + if(keyVal.Length > 0) + { + var tokens = keyVal.Split(new string[] { "=" }, StringSplitOptions.None); + var propertyName = tokens[0].ToUpper(); + switch (propertyName) + { + case "DB": + case "SCHEMA": + case "WAREHOUSE": + case "ROLE": + { + if (tokens.Length == 2) + { + var sessionProperty = (SFSessionProperty)Enum.Parse( + typeof(SFSessionProperty), propertyName); + properties[sessionProperty]= tokens[1]; + } + + break; + } + case "USER": + case "PASSWORD": + { + + var sessionProperty = (SFSessionProperty)Enum.Parse( + typeof(SFSessionProperty), propertyName); + if (!properties.ContainsKey(sessionProperty)) + { + properties.Add(sessionProperty, ""); + } + + break; + } + } + } + } + } + + private static void ValidateAccountDomain(SFSessionProperties properties) + { + var account = properties[SFSessionProperty.ACCOUNT]; + if (string.IsNullOrEmpty(account)) + return; + if (IsAccountRegexMatched(account)) + return; + logger.Error($"Invalid account {account}"); + throw new SnowflakeDbException( + new Exception("Invalid account"), + SFError.INVALID_CONNECTION_PARAMETER_VALUE, + account, + SFSessionProperty.ACCOUNT); + } + + private static bool IsAccountRegexMatched(string account) => + s_accountRegexStrings + .Select(regex => Regex.Match(account, regex, RegexOptions.IgnoreCase)) + .All(match => match.Success); + + private static void checkSessionProperties(SFSessionProperties properties) + { + foreach (SFSessionProperty sessionProperty in Enum.GetValues(typeof(SFSessionProperty))) + { + // if required property, check if exists in the dictionary + if (IsRequired(sessionProperty, properties) && + !properties.ContainsKey(sessionProperty)) + { + SnowflakeDbException e = new SnowflakeDbException(SFError.MISSING_CONNECTION_PROPERTY, + sessionProperty); + logger.Error("Missing connection property", e); + throw e; + } + + // add default value to the map + string defaultVal = sessionProperty.GetAttribute().defaultValue; + if (defaultVal != null && !properties.ContainsKey(sessionProperty)) + { + logger.Debug($"Sesssion property {sessionProperty} set to default value: {defaultVal}"); + properties.Add(sessionProperty, defaultVal); + } + } + } + + private static void ValidateFileTransferMaxBytesInMemoryProperty(SFSessionProperties properties) + { + if (!properties.TryGetValue(SFSessionProperty.FILE_TRANSFER_MEMORY_THRESHOLD, out var maxBytesInMemoryString)) + { + return; + } + + var propertyName = SFSessionProperty.FILE_TRANSFER_MEMORY_THRESHOLD.ToString(); + int maxBytesInMemory; + try + { + maxBytesInMemory = int.Parse(maxBytesInMemoryString); + } + catch (Exception e) + { + logger.Error($"Value for parameter {propertyName} could not be parsed"); + throw new SnowflakeDbException(e, SFError.INVALID_CONNECTION_PARAMETER_VALUE, maxBytesInMemoryString, propertyName); + } + + if (maxBytesInMemory <= 0) + { + logger.Error($"Value for parameter {propertyName} should be greater than 0"); + throw new SnowflakeDbException( + new Exception($"Value for parameter {propertyName} should be greater than 0"), + SFError.INVALID_CONNECTION_PARAMETER_VALUE, maxBytesInMemoryString, propertyName); + } + } + + private static bool IsRequired(SFSessionProperty sessionProperty, SFSessionProperties properties) + { + if (sessionProperty.Equals(SFSessionProperty.PASSWORD)) + { + var authenticatorDefined = + properties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator); + + var authenticatorsWithoutPassword = new List() + { + ExternalBrowserAuthenticator.AUTH_NAME, + KeyPairAuthenticator.AUTH_NAME, + OAuthAuthenticator.AUTH_NAME + }; + // External browser, jwt and oauth don't require a password for authenticating + return !authenticatorDefined || !authenticatorsWithoutPassword + .Any(auth => auth.Equals(authenticator, StringComparison.OrdinalIgnoreCase)); + } + else if (sessionProperty.Equals(SFSessionProperty.USER)) + { + var authenticatorDefined = + properties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator); + + var authenticatorsWithoutUsername = new List() + { + OAuthAuthenticator.AUTH_NAME, + ExternalBrowserAuthenticator.AUTH_NAME + }; + return !authenticatorDefined || !authenticatorsWithoutUsername + .Any(auth => auth.Equals(authenticator, StringComparison.OrdinalIgnoreCase)); + } + else + { + return sessionProperty.GetAttribute().required; + } + } + + private static bool ParseAllowUnderscoresInHost(SFSessionProperties properties) + { + var allowUnderscoresInHost = bool.Parse(SFSessionProperty.ALLOWUNDERSCORESINHOST.GetAttribute().defaultValue); + if (!properties.TryGetValue(SFSessionProperty.ALLOWUNDERSCORESINHOST, out var property)) + return allowUnderscoresInHost; + try + { + allowUnderscoresInHost = bool.Parse(property); + } + catch (Exception e) + { + logger.Warn("Unable to parse property 'allowUnderscoresInHost'", e); + } + + return allowUnderscoresInHost; + } + } + + public static class EnumExtensions + { + public static TAttribute GetAttribute(this Enum value) + where TAttribute : Attribute + { + var type = value.GetType(); + var memInfo = type.GetMember(value.ToString()); + var attributes = memInfo[0].GetCustomAttributes(typeof(TAttribute), false); + return (attributes.Length > 0) ? (TAttribute)attributes[0] : null; + } + } +}