diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs index dd31f1c16..cd72bffce 100644 --- a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs @@ -14,7 +14,7 @@ namespace Snowflake.Data.Tests.UnitTests class SFSessionPropertyTest { - [Test, TestCaseSource("ConnectionStringTestCases")] + [Test, TestCaseSource(nameof(ConnectionStringTestCases))] public void TestThatPropertiesAreParsed(TestCase testcase) { // act @@ -29,7 +29,10 @@ public void TestThatPropertiesAreParsed(TestCase testcase) [Test] [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;FILE_TRANSFER_MEMORY_THRESHOLD=0;", "Error: Invalid parameter value 0 for FILE_TRANSFER_MEMORY_THRESHOLD")] [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;FILE_TRANSFER_MEMORY_THRESHOLD=xyz;", "Error: Invalid parameter value xyz for FILE_TRANSFER_MEMORY_THRESHOLD")] - public void TestThatItFailsForWrongFileTransferMaxBytesInMemoryParameter(string connectionString, string expectedErrorMessagePart) + [TestCase("ACCOUNT=testaccount?;USER=testuser;PASSWORD=testpassword", "Error: Invalid parameter value testaccount? for ACCOUNT")] + [TestCase("ACCOUNT=complicated.long.testaccount?;USER=testuser;PASSWORD=testpassword", "Error: Invalid parameter value complicated.long.testaccount? for ACCOUNT")] + [TestCase("ACCOUNT=?testaccount;USER=testuser;PASSWORD=testpassword", "Error: Invalid parameter value ?testaccount for ACCOUNT")] + public void TestThatItFailsForWrongConnectionParameter(string connectionString, string expectedErrorMessagePart) { // act var exception = Assert.Throws( @@ -40,6 +43,20 @@ public void TestThatItFailsForWrongFileTransferMaxBytesInMemoryParameter(string Assert.AreEqual(SFError.INVALID_CONNECTION_PARAMETER_VALUE.GetAttribute().errorCode, exception.ErrorCode); Assert.IsTrue(exception.Message.Contains(expectedErrorMessagePart)); } + + [Test] + [TestCase("ACCOUNT=;USER=testuser;PASSWORD=testpassword")] + [TestCase("USER=testuser;PASSWORD=testpassword")] + public void TestThatItFailsIfNoAccountSpecified(string connectionString) + { + // act + var exception = Assert.Throws( + () => SFSessionProperties.parseConnectionString(connectionString, null) + ); + + // assert + Assert.AreEqual(SFError.MISSING_CONNECTION_PROPERTY.GetAttribute().errorCode, exception.ErrorCode); + } public static IEnumerable ConnectionStringTestCases() { @@ -260,6 +277,34 @@ public static IEnumerable ConnectionStringTestCases() ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};DISABLEQUERYCONTEXTCACHE=true" }; + var complicatedAccount = $"{defAccount}.region-name.host-name"; + var testCaseComplicatedAccountName = new TestCase() + { + ConnectionString = $"ACCOUNT={complicatedAccount};USER={defUser};PASSWORD={defPassword};", + ExpectedProperties = new SFSessionProperties() + { + { SFSessionProperty.ACCOUNT, defAccount }, + { SFSessionProperty.USER, defUser }, + { SFSessionProperty.HOST, $"{complicatedAccount}.snowflakecomputing.com" }, + { SFSessionProperty.AUTHENTICATOR, defAuthenticator }, + { SFSessionProperty.SCHEME, defScheme }, + { SFSessionProperty.CONNECTION_TIMEOUT, defConnectionTimeout }, + { SFSessionProperty.PASSWORD, defPassword }, + { SFSessionProperty.PORT, defPort }, + { SFSessionProperty.VALIDATE_DEFAULT_PARAMETERS, "true" }, + { SFSessionProperty.USEPROXY, "false" }, + { SFSessionProperty.INSECUREMODE, "false" }, + { SFSessionProperty.DISABLERETRY, "false" }, + { SFSessionProperty.FORCERETRYON404, "false" }, + { SFSessionProperty.CLIENT_SESSION_KEEP_ALIVE, "false" }, + { SFSessionProperty.FORCEPARSEERROR, "false" }, + { SFSessionProperty.BROWSER_RESPONSE_TIMEOUT, defBrowserResponseTime }, + { SFSessionProperty.RETRY_TIMEOUT, defRetryTimeout }, + { SFSessionProperty.MAXHTTPRETRIES, defMaxHttpRetries }, + { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason }, + { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache } + } + }; return new TestCase[] { simpleTestCase, @@ -268,7 +313,8 @@ public static IEnumerable ConnectionStringTestCases() testCaseThatDefaultForUseProxyIsFalse, testCaseWithFileTransferMaxBytesInMemory, testCaseWithIncludeRetryReason, - testCaseWithDisableQueryContextCache + testCaseWithDisableQueryContextCache, + testCaseComplicatedAccountName }; } diff --git a/Snowflake.Data/Core/Session/SFSessionProperty.cs b/Snowflake.Data/Core/Session/SFSessionProperty.cs index 28ef2fa7e..b07015a88 100755 --- a/Snowflake.Data/Core/Session/SFSessionProperty.cs +++ b/Snowflake.Data/Core/Session/SFSessionProperty.cs @@ -11,6 +11,7 @@ using Snowflake.Data.Core.Authenticator; using System.Data.Common; using System.Linq; +using System.Text.RegularExpressions; namespace Snowflake.Data.Core { @@ -101,10 +102,10 @@ class SFSessionPropertyAttr : Attribute class SFSessionProperties : Dictionary { - static private SFLogger logger = SFLoggerFactory.GetLogger(); + private static SFLogger logger = SFLoggerFactory.GetLogger(); // Connection string properties to obfuscate in the log - static private List secretProps = + private static List secretProps = new List{ SFSessionProperty.PASSWORD, SFSessionProperty.PRIVATE_KEY, @@ -112,6 +113,8 @@ class SFSessionProperties : Dictionary SFSessionProperty.PRIVATE_KEY_PWD, SFSessionProperty.PROXYPASSWORD, }; + + private const string AccountRegexString = "^\\w[\\w.-]+\\w$"; public override bool Equals(object obj) { @@ -251,7 +254,8 @@ internal static SFSessionProperties parseConnectionString(String connectionStrin checkSessionProperties(properties); ValidateFileTransferMaxBytesInMemoryProperty(properties); - + ValidateAccountDomain(properties); + // compose host value if not specified if (!properties.ContainsKey(SFSessionProperty.HOST) || (0 == properties[SFSessionProperty.HOST].Length)) @@ -271,6 +275,22 @@ internal static SFSessionProperties parseConnectionString(String connectionStrin return properties; } + private static void ValidateAccountDomain(SFSessionProperties properties) + { + var account = properties[SFSessionProperty.ACCOUNT]; + if (string.IsNullOrEmpty(account)) + return; + var match = Regex.Match(account, AccountRegexString, RegexOptions.IgnoreCase); + if (match.Success) + return; + logger.Error($"Invalid account {account}"); + throw new SnowflakeDbException( + new Exception("Invalid account"), + SFError.INVALID_CONNECTION_PARAMETER_VALUE, + account, + SFSessionProperty.ACCOUNT); + } + private static void checkSessionProperties(SFSessionProperties properties) { foreach (SFSessionProperty sessionProperty in Enum.GetValues(typeof(SFSessionProperty)))