Skip to content

Commit

Permalink
SNOW-944715 Validate account domain (#820)
Browse files Browse the repository at this point in the history
### Description
Validate account domain to prevent SSRF.

### Checklist
- [x] Code compiles correctly
- [x] Code is formatted according to [Coding
Conventions](../CodingConventions.md)
- [x] Created tests which fail without the change (if possible)
- [x] All tests passing (`dotnet test`)
- [x] Extended the README / documentation, if necessary
- [x] Provide JIRA issue id (if possible) or GitHub issue id in PR name
  • Loading branch information
sfc-gh-knozderko authored Dec 1, 2023
1 parent 2dce725 commit ecdefc6
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 6 deletions.
52 changes: 49 additions & 3 deletions Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace Snowflake.Data.Tests.UnitTests
class SFSessionPropertyTest
{

[Test, TestCaseSource("ConnectionStringTestCases")]
[Test, TestCaseSource(nameof(ConnectionStringTestCases))]
public void TestThatPropertiesAreParsed(TestCase testcase)
{
// act
Expand All @@ -29,7 +29,10 @@ public void TestThatPropertiesAreParsed(TestCase testcase)
[Test]
[TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;FILE_TRANSFER_MEMORY_THRESHOLD=0;", "Error: Invalid parameter value 0 for FILE_TRANSFER_MEMORY_THRESHOLD")]
[TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;FILE_TRANSFER_MEMORY_THRESHOLD=xyz;", "Error: Invalid parameter value xyz for FILE_TRANSFER_MEMORY_THRESHOLD")]
public void TestThatItFailsForWrongFileTransferMaxBytesInMemoryParameter(string connectionString, string expectedErrorMessagePart)
[TestCase("ACCOUNT=testaccount?;USER=testuser;PASSWORD=testpassword", "Error: Invalid parameter value testaccount? for ACCOUNT")]
[TestCase("ACCOUNT=complicated.long.testaccount?;USER=testuser;PASSWORD=testpassword", "Error: Invalid parameter value complicated.long.testaccount? for ACCOUNT")]
[TestCase("ACCOUNT=?testaccount;USER=testuser;PASSWORD=testpassword", "Error: Invalid parameter value ?testaccount for ACCOUNT")]
public void TestThatItFailsForWrongConnectionParameter(string connectionString, string expectedErrorMessagePart)
{
// act
var exception = Assert.Throws<SnowflakeDbException>(
Expand All @@ -40,6 +43,20 @@ public void TestThatItFailsForWrongFileTransferMaxBytesInMemoryParameter(string
Assert.AreEqual(SFError.INVALID_CONNECTION_PARAMETER_VALUE.GetAttribute<SFErrorAttr>().errorCode, exception.ErrorCode);
Assert.IsTrue(exception.Message.Contains(expectedErrorMessagePart));
}

[Test]
[TestCase("ACCOUNT=;USER=testuser;PASSWORD=testpassword")]
[TestCase("USER=testuser;PASSWORD=testpassword")]
public void TestThatItFailsIfNoAccountSpecified(string connectionString)
{
// act
var exception = Assert.Throws<SnowflakeDbException>(
() => SFSessionProperties.parseConnectionString(connectionString, null)
);

// assert
Assert.AreEqual(SFError.MISSING_CONNECTION_PROPERTY.GetAttribute<SFErrorAttr>().errorCode, exception.ErrorCode);
}

public static IEnumerable<TestCase> ConnectionStringTestCases()
{
Expand Down Expand Up @@ -260,6 +277,34 @@ public static IEnumerable<TestCase> ConnectionStringTestCases()
ConnectionString =
$"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};DISABLEQUERYCONTEXTCACHE=true"
};
var complicatedAccount = $"{defAccount}.region-name.host-name";
var testCaseComplicatedAccountName = new TestCase()
{
ConnectionString = $"ACCOUNT={complicatedAccount};USER={defUser};PASSWORD={defPassword};",
ExpectedProperties = new SFSessionProperties()
{
{ SFSessionProperty.ACCOUNT, defAccount },
{ SFSessionProperty.USER, defUser },
{ SFSessionProperty.HOST, $"{complicatedAccount}.snowflakecomputing.com" },
{ SFSessionProperty.AUTHENTICATOR, defAuthenticator },
{ SFSessionProperty.SCHEME, defScheme },
{ SFSessionProperty.CONNECTION_TIMEOUT, defConnectionTimeout },
{ SFSessionProperty.PASSWORD, defPassword },
{ SFSessionProperty.PORT, defPort },
{ SFSessionProperty.VALIDATE_DEFAULT_PARAMETERS, "true" },
{ SFSessionProperty.USEPROXY, "false" },
{ SFSessionProperty.INSECUREMODE, "false" },
{ SFSessionProperty.DISABLERETRY, "false" },
{ SFSessionProperty.FORCERETRYON404, "false" },
{ SFSessionProperty.CLIENT_SESSION_KEEP_ALIVE, "false" },
{ SFSessionProperty.FORCEPARSEERROR, "false" },
{ SFSessionProperty.BROWSER_RESPONSE_TIMEOUT, defBrowserResponseTime },
{ SFSessionProperty.RETRY_TIMEOUT, defRetryTimeout },
{ SFSessionProperty.MAXHTTPRETRIES, defMaxHttpRetries },
{ SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason },
{ SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }
}
};
return new TestCase[]
{
simpleTestCase,
Expand All @@ -268,7 +313,8 @@ public static IEnumerable<TestCase> ConnectionStringTestCases()
testCaseThatDefaultForUseProxyIsFalse,
testCaseWithFileTransferMaxBytesInMemory,
testCaseWithIncludeRetryReason,
testCaseWithDisableQueryContextCache
testCaseWithDisableQueryContextCache,
testCaseComplicatedAccountName
};
}

Expand Down
26 changes: 23 additions & 3 deletions Snowflake.Data/Core/Session/SFSessionProperty.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Snowflake.Data.Core.Authenticator;
using System.Data.Common;
using System.Linq;
using System.Text.RegularExpressions;

namespace Snowflake.Data.Core
{
Expand Down Expand Up @@ -101,17 +102,19 @@ class SFSessionPropertyAttr : Attribute

class SFSessionProperties : Dictionary<SFSessionProperty, String>
{
static private SFLogger logger = SFLoggerFactory.GetLogger<SFSessionProperties>();
private static SFLogger logger = SFLoggerFactory.GetLogger<SFSessionProperties>();

// Connection string properties to obfuscate in the log
static private List<SFSessionProperty> secretProps =
private static List<SFSessionProperty> secretProps =
new List<SFSessionProperty>{
SFSessionProperty.PASSWORD,
SFSessionProperty.PRIVATE_KEY,
SFSessionProperty.TOKEN,
SFSessionProperty.PRIVATE_KEY_PWD,
SFSessionProperty.PROXYPASSWORD,
};

private const string AccountRegexString = "^\\w[\\w.-]+\\w$";

public override bool Equals(object obj)
{
Expand Down Expand Up @@ -251,7 +254,8 @@ internal static SFSessionProperties parseConnectionString(String connectionStrin

checkSessionProperties(properties);
ValidateFileTransferMaxBytesInMemoryProperty(properties);

ValidateAccountDomain(properties);

// compose host value if not specified
if (!properties.ContainsKey(SFSessionProperty.HOST) ||
(0 == properties[SFSessionProperty.HOST].Length))
Expand All @@ -271,6 +275,22 @@ internal static SFSessionProperties parseConnectionString(String connectionStrin
return properties;
}

private static void ValidateAccountDomain(SFSessionProperties properties)
{
var account = properties[SFSessionProperty.ACCOUNT];
if (string.IsNullOrEmpty(account))
return;
var match = Regex.Match(account, AccountRegexString, RegexOptions.IgnoreCase);
if (match.Success)
return;
logger.Error($"Invalid account {account}");
throw new SnowflakeDbException(
new Exception("Invalid account"),
SFError.INVALID_CONNECTION_PARAMETER_VALUE,
account,
SFSessionProperty.ACCOUNT);
}

private static void checkSessionProperties(SFSessionProperties properties)
{
foreach (SFSessionProperty sessionProperty in Enum.GetValues(typeof(SFSessionProperty)))
Expand Down

0 comments on commit ecdefc6

Please sign in to comment.