Skip to content

Commit

Permalink
Applying PR suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jmartinezramirez committed Aug 28, 2024
1 parent 99a1d4c commit 4c36d9c
Show file tree
Hide file tree
Showing 13 changed files with 210 additions and 172 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ private static void MockHomeDirectoryReturnsNull()
private static void MockFileFromEnvironmentalVariable()
{
t_environmentOperations
.Setup(e => e.GetEnvironmentVariable(EasyLoggingConfigFinder.ClientConfigEnvironmentName, null))
.Setup(e => e.GetEnvironmentVariable(EasyLoggingConfigFinder.ClientConfigEnvironmentName))
.Returns(EnvironmentalConfigFilePath);
}

Expand Down
16 changes: 8 additions & 8 deletions Snowflake.Data.Tests/UnitTests/SnowflakeDbConnectionTest.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@


using System;
using System.IO;
using Mono.Unix;

namespace Snowflake.Data.Tests.UnitTests
{
using Core;
Expand All @@ -16,21 +20,19 @@ public void TestFillConnectionStringFromTomlConfig()
// Arrange
var mockFileOperations = new Mock<FileOperations>();
var mockEnvironmentOperations = new Mock<EnvironmentOperations>();
mockEnvironmentOperations.Setup(e => e.GetEnvironmentVariable(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string v, string d) => d);
mockEnvironmentOperations.Setup(e => e.GetFolderPath(Environment.SpecialFolder.UserProfile))
.Returns($"{Path.DirectorySeparatorChar}home");
mockFileOperations.Setup(f => f.Exists(It.IsAny<string>())).Returns(true);
mockFileOperations.Setup(f => f.ReadAllText(It.IsAny<string>()))
mockFileOperations.Setup(f => f.ReadAllText(It.IsAny<string>(), It.IsAny<Action<UnixStream>>()))
.Returns("[default]\naccount=\"testaccount\"\nuser=\"testuser\"\npassword=\"testpassword\"\n");
var tomlConnectionBuilder = new SnowflakeTomlConnectionBuilder(mockFileOperations.Object, mockEnvironmentOperations.Object);

// Act
using (var conn = new SnowflakeDbConnection(tomlConnectionBuilder))
{
conn.ConnectionString = "account=user1account;user=user1;password=user1password;";
conn.FillConnectionStringFromTomlConfigIfNotSet();
// Assert
Assert.AreNotEqual("account=testaccount;user=testuser;password=testpassword;", conn.ConnectionString);
Assert.AreNotEqual("account=testaccount;user=testuser;password=testpassword;", conn.ConnectionString);
Assert.AreEqual("account=testaccount;user=testuser;password=testpassword;", conn.ConnectionString);
}
}

Expand All @@ -41,8 +43,6 @@ public void TestFillConnectionStringFromTomlConfigShouldNotBeExecutedIfAlreadySe
var connectionTest = "account=user1account;user=user1;password=user1password;";
var mockFileOperations = new Mock<FileOperations>();
var mockEnvironmentOperations = new Mock<EnvironmentOperations>();
mockEnvironmentOperations.Setup(e => e.GetEnvironmentVariable(It.IsAny<string>(), It.IsAny<string>()))
.Returns((string v, string d) => d);
mockFileOperations.Setup(f => f.Exists(It.IsAny<string>())).Returns(true);
mockFileOperations.Setup(f => f.ReadAllText(It.IsAny<string>()))
.Returns("[default]\naccount=\"testaccount\"\nuser=\"testuser\"\npassword=\"testpassword\"\n");
Expand Down
113 changes: 35 additions & 78 deletions Snowflake.Data.Tests/UnitTests/SnowflakeTomlConnectionBuilderTest.cs

Large diffs are not rendered by default.

24 changes: 20 additions & 4 deletions Snowflake.Data.Tests/UnitTests/Tools/FileOperationsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
*/


using System;

namespace Snowflake.Data.Tests.Tools
{
using System.IO;
Expand Down Expand Up @@ -49,7 +51,7 @@ public void TestReadAllTextOnWindows()
var filePath = CreateConfigTempFile(s_workingDirectory, content);

// act
var result = s_fileOperations.ReadAllText(filePath);
var result = s_fileOperations.ReadAllText(filePath, GetTestFileValidation());

// assert
Assert.AreEqual(content, result);
Expand All @@ -69,14 +71,14 @@ public void TestReadAllTextCheckingPermissions()
Syscall.chmod(filePath, (FilePermissions)filePermissions);

// act
var result = s_fileOperations.ReadAllText(filePath);
var result = s_fileOperations.ReadAllText(filePath, GetTestFileValidation());

// assert
Assert.AreEqual(content, result);
}

[Test]
public void TestShouldThrowExceptionIfOtherPermissionsIsSetWhenReadAllText()
public void TestShouldThrowExceptionIfOtherPermissionsIsSetWhenReadConfigurationFile()
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
Expand All @@ -89,8 +91,22 @@ public void TestShouldThrowExceptionIfOtherPermissionsIsSetWhenReadAllText()
Syscall.chmod(filePath, (FilePermissions)filePermissions);

// act and assert
Assert.Throws<SecurityException>(() => s_fileOperations.ReadAllText(filePath),
Assert.Throws<SecurityException>(() => s_fileOperations.ReadAllText(filePath, GetTestFileValidation()),
"Attempting to read a file with too broad permissions assigned");
}

private Action<UnixStream> GetTestFileValidation()
{
return stream =>
{
const FileAccessPermissions forbiddenPermissions = FileAccessPermissions.OtherReadWriteExecute | FileAccessPermissions.GroupReadWriteExecute;
if (stream.OwnerUser.UserId != Syscall.geteuid())
throw new SecurityException("Attempting to read a file not owned by the effective user of the current process");
if (stream.OwnerGroup.GroupId != Syscall.getegid())
throw new SecurityException("Attempting to read a file not owned by the effective group of the current process");
if ((stream.FileAccessPermissions & forbiddenPermissions) != 0)
throw new SecurityException("Attempting to read a file with too broad permissions assigned");
};
}
}
}
8 changes: 4 additions & 4 deletions Snowflake.Data.Tests/UnitTests/Tools/UnixOperationsTest.cs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
using System.Security;
using Mono.Unix;
using Mono.Unix.Native;
using NUnit.Framework;
using Snowflake.Data.Core;
using Snowflake.Data.Core.Tools;
using static Snowflake.Data.Tests.UnitTests.Configuration.EasyLoggingConfigGenerator;

namespace Snowflake.Data.Tests.Tools
{
using System.Security;

[TestFixture, NonParallelizable]
public class UnixOperationsTest
{
Expand Down Expand Up @@ -96,7 +96,7 @@ public void TestReadAllTextCheckingPermissions()
Syscall.chmod(filePath, (FilePermissions)filePermissions);

// act
var result = s_unixOperations.ReadAllText(filePath);
var result = s_unixOperations.ReadAllText(filePath, SnowflakeTomlConnectionBuilder.GetFileValidations());

// assert
Assert.AreEqual(content, result);
Expand All @@ -115,7 +115,7 @@ public void TestShouldThrowExceptionIfOtherPermissionsIsSetWhenReadAllText()
Syscall.chmod(filePath, (FilePermissions)filePermissions);

// act and assert
Assert.Throws<SecurityException>(() => s_unixOperations.ReadAllText(filePath), "Attempting to read a file with too broad permissions assigned");
Assert.Throws<SecurityException>(() => s_unixOperations.ReadAllText(filePath, SnowflakeTomlConnectionBuilder.GetFileValidations()), "Attempting to read a file with too broad permissions assigned");
}

public static IEnumerable<FilePermissions> UserPermissions()
Expand Down
20 changes: 10 additions & 10 deletions Snowflake.Data/Client/SnowflakeDbConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
* Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved.
*/

using System;
using System.Data;
using System.Data.Common;
using System.Security;
using System.Threading;
using System.Threading.Tasks;
using Snowflake.Data.Core;
using Snowflake.Data.Log;

namespace Snowflake.Data.Client
{
using System;
using System.Data.Common;
using Snowflake.Data.Core;
using System.Security;
using System.Threading.Tasks;
using System.Data;
using System.Threading;
using Snowflake.Data.Log;

[System.ComponentModel.DesignerCategory("Code")]
public class SnowflakeDbConnection : DbConnection
{
Expand Down Expand Up @@ -46,7 +46,7 @@ protected enum TransactionRollbackStatus
Failure
}

public SnowflakeDbConnection() : this(new SnowflakeTomlConnectionBuilder())
public SnowflakeDbConnection() : this(SnowflakeTomlConnectionBuilder.Instance)
{
}

Expand Down
8 changes: 3 additions & 5 deletions Snowflake.Data/Core/EnvironmentVariables.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
// <copyright file="EnvironmentVariables.cs" company="Snowflake Inc">
// Copyright (c) 2019-2023 Snowflake Inc. All rights reserved.
// </copyright>
/*
* Copyright (c) 2024 Snowflake Computing Inc. All rights reserved.
*/

namespace Snowflake.Data.Core
{
public static class EnvironmentVariables
{
public static string SnowflakeDefaultConnectionName = "SNOWFLAKE_DEFAULT_CONNECTION_NAME";
public static string SnowflakeHome = "SNOWFLAKE_HOME";
}
}
81 changes: 56 additions & 25 deletions Snowflake.Data/Core/SnowflakeTomlConnectionBuilder.cs
Original file line number Diff line number Diff line change
@@ -1,25 +1,32 @@
// <copyright file="SnowflakeTomlConnectionBuilder.cs" company="Snowflake Inc">
// Copyright (c) 2024 Snowflake Inc. All rights reserved.
// </copyright>
/*
* Copyright (c) 2024 Snowflake Computing Inc. All rights reserved.
*/

using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
using System.Security;
using System.Text;
using Mono.Unix;
using Mono.Unix.Native;
using Snowflake.Data.Client;
using Snowflake.Data.Core.Tools;
using Snowflake.Data.Log;
using Tomlyn;
using Tomlyn.Model;

namespace Snowflake.Data.Core
{
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using Client;
using Log;
using Tomlyn;
using Tomlyn.Model;
using Tools;

public class SnowflakeTomlConnectionBuilder
internal class SnowflakeTomlConnectionBuilder
{
private const string DefaultConnectionName = "default";
private const string DefaultSnowflakeFolder = ".snowflake";
private const string DefaultTokenPath = "/snowflake/session/token";

internal const string SnowflakeDefaultConnectionName = "SNOWFLAKE_DEFAULT_CONNECTION_NAME";
internal const string SnowflakeHome = "SNOWFLAKE_HOME";

private readonly SFLogger _logger = SFLoggerFactory.GetLogger<SnowflakeDbConnection>();

private readonly Dictionary<string, string> _tomlToNetPropertiesMapper = new Dictionary<string, string>(StringComparer.InvariantCultureIgnoreCase)
Expand All @@ -30,7 +37,9 @@ public class SnowflakeTomlConnectionBuilder
private readonly FileOperations _fileOperations;
private readonly EnvironmentOperations _environmentOperations;

public SnowflakeTomlConnectionBuilder() : this(FileOperations.Instance, EnvironmentOperations.Instance)
internal static readonly SnowflakeTomlConnectionBuilder Instance = new SnowflakeTomlConnectionBuilder();

internal SnowflakeTomlConnectionBuilder() : this(FileOperations.Instance, EnvironmentOperations.Instance)
{
}

Expand Down Expand Up @@ -59,17 +68,28 @@ private string GetConnectionStringFromTomlTable(TomlTable connectionToml)
var isOauth = connectionToml.TryGetValue("authenticator", out var authenticator) && authenticator.ToString().Equals("oauth");
foreach (var property in connectionToml.Keys)
{
var propertyValue = (string)connectionToml[property];
if (isOauth && property.Equals("token_file_path", StringComparison.InvariantCultureIgnoreCase))
{
tokenFilePathValue = (string)connectionToml[property];
tokenFilePathValue = propertyValue;
continue;
}
var mappedProperty = _tomlToNetPropertiesMapper.TryGetValue(property, out var mapped) ? mapped : property;
connectionStringBuilder.Append($"{mappedProperty}={(string)connectionToml[property]};");
connectionStringBuilder.Append($"{mappedProperty}={propertyValue};");
}

AppendTokenFromFileIfNotGivenExplicitly(connectionToml, isOauth, connectionStringBuilder, tokenFilePathValue);

return connectionStringBuilder.ToString();
}

private void AppendTokenFromFileIfNotGivenExplicitly(TomlTable connectionToml, bool isOauth,
StringBuilder connectionStringBuilder, string tokenFilePathValue)
{
if (!isOauth || connectionToml.ContainsKey("token"))
return connectionStringBuilder.ToString();
{
return;
}

var token = LoadTokenFromFile(tokenFilePathValue);
if (!string.IsNullOrEmpty(token))
Expand All @@ -80,16 +100,13 @@ private string GetConnectionStringFromTomlTable(TomlTable connectionToml)
{
_logger.Warn("The token has empty value");
}


return connectionStringBuilder.ToString();
}

private string LoadTokenFromFile(string tokenFilePathValue)
{
var tokenFile = !string.IsNullOrEmpty(tokenFilePathValue) && _fileOperations.Exists(tokenFilePathValue) ? tokenFilePathValue : DefaultTokenPath;
_logger.Debug($"Read token from file path: {tokenFile}");
return _fileOperations.Exists(tokenFile) ? _fileOperations.ReadAllText(tokenFile) : null;
return _fileOperations.Exists(tokenFile) ? _fileOperations.ReadAllText(tokenFile, GetFileValidations()) : null;
}

private TomlTable GetTomlTableFromConfig(string tomlPath, string connectionName)
Expand All @@ -99,11 +116,11 @@ private TomlTable GetTomlTableFromConfig(string tomlPath, string connectionName)
return null;
}

var tomlContent = _fileOperations.ReadAllText(tomlPath) ?? string.Empty;
var tomlContent = _fileOperations.ReadAllText(tomlPath, GetFileValidations()) ?? string.Empty;
var toml = Toml.ToModel(tomlContent);
if (string.IsNullOrEmpty(connectionName))
{
connectionName = _environmentOperations.GetEnvironmentVariable(EnvironmentVariables.SnowflakeDefaultConnectionName, DefaultConnectionName);
connectionName = _environmentOperations.GetEnvironmentVariable(SnowflakeDefaultConnectionName) ?? DefaultConnectionName;
}

var connectionExists = toml.TryGetValue(connectionName, out var connection);
Expand All @@ -119,10 +136,24 @@ private TomlTable GetTomlTableFromConfig(string tomlPath, string connectionName)
private string ResolveConnectionTomlFile()
{
var defaultDirectory = Path.Combine(HomeDirectoryProvider.HomeDirectory(_environmentOperations), DefaultSnowflakeFolder);
var tomlFolder = _environmentOperations.GetEnvironmentVariable(EnvironmentVariables.SnowflakeHome, defaultDirectory);
var tomlFolder = _environmentOperations.GetEnvironmentVariable(SnowflakeHome) ?? defaultDirectory;
var tomlPath = Path.Combine(tomlFolder, "connections.toml");
tomlPath = Path.GetFullPath(tomlPath);
return tomlPath;
}

internal static Action<UnixStream> GetFileValidations()
{
return stream =>
{
const FileAccessPermissions forbiddenPermissions = FileAccessPermissions.OtherReadWriteExecute | FileAccessPermissions.GroupReadWriteExecute;
if (stream.OwnerUser.UserId != Syscall.geteuid())
throw new SecurityException("Attempting to read a file not owned by the effective user of the current process");
if (stream.OwnerGroup.GroupId != Syscall.getegid())
throw new SecurityException("Attempting to read a file not owned by the effective group of the current process");
if ((stream.FileAccessPermissions & forbiddenPermissions) != 0)
throw new SecurityException("Attempting to read a file with too broad permissions assigned");
};
}
}
}
4 changes: 2 additions & 2 deletions Snowflake.Data/Core/Tools/EnvironmentOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ internal class EnvironmentOperations
public static readonly EnvironmentOperations Instance = new EnvironmentOperations();
private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger<EnvironmentOperations>();

public virtual string GetEnvironmentVariable(string variable, string defaultValue = null)
public virtual string GetEnvironmentVariable(string variable)
{
return Environment.GetEnvironmentVariable(variable) ?? defaultValue;
return Environment.GetEnvironmentVariable(variable);
}

public virtual string GetFolderPath(Environment.SpecialFolder folder)
Expand Down
8 changes: 5 additions & 3 deletions Snowflake.Data/Core/Tools/FileOperations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
* Copyright (c) 2023 Snowflake Computing Inc. All rights reserved.
*/

using System;
using System.IO;
using System.Linq;

namespace Snowflake.Data.Core.Tools
{
Expand All @@ -21,12 +23,12 @@ public virtual bool Exists(string path)

public virtual string ReadAllText(string path)
{
return ReadAllText(path, FileAccessPermissions.OtherReadWriteExecute);
return ReadAllText(path, null);
}

public virtual string ReadAllText(string path, FileAccessPermissions? forbiddenPermissions)
public virtual string ReadAllText(string path, Action<UnixStream> validation)
{
var contentFile = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? File.ReadAllText(path) : _unixOperations.ReadAllText(path, forbiddenPermissions);
var contentFile = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) || validation == null ? File.ReadAllText(path) : _unixOperations.ReadAllText(path, validation);
return contentFile;
}
}
Expand Down
Loading

0 comments on commit 4c36d9c

Please sign in to comment.