diff --git a/Snowflake.Data/Core/SnowflakeTomlConnectionBuilder.cs b/Snowflake.Data/Core/SnowflakeTomlConnectionBuilder.cs index 9d702be47..020c4f0a3 100644 --- a/Snowflake.Data/Core/SnowflakeTomlConnectionBuilder.cs +++ b/Snowflake.Data/Core/SnowflakeTomlConnectionBuilder.cs @@ -7,6 +7,7 @@ namespace Snowflake.Data.Core using System; using System.Collections.Generic; using System.IO; + using System.Linq; using System.Text; using Tomlyn; using Tomlyn.Model; @@ -16,12 +17,15 @@ public class SnowflakeTomlConnectionBuilder { private const string DefaultConnectionName = "default"; private const string DefaultSnowflakeFolder = ".snowflake"; + private const string DefaultTokenPath = "/snowflake/session/token"; private Dictionary TomlToNetPropertiesMapper = new Dictionary(StringComparer.InvariantCultureIgnoreCase) { { "DATABASE", "DB" } }; + + private readonly FileOperations _fileOperations; private readonly EnvironmentOperations _environmentOperations; @@ -50,15 +54,42 @@ public string GetConnectionStringFromToml(string connectionName = null) private string GetConnectionStringFromTomlTable(TomlTable connectionToml) { var connectionStringBuilder = new StringBuilder(); + var tokenFilePathValue = string.Empty; + var isOauth = connectionToml.TryGetValue("authenticator", out var authenticator) && authenticator.ToString().Equals("oauth"); foreach (var property in connectionToml.Keys) { + if (isOauth && property.Equals("token_file_path", StringComparison.InvariantCultureIgnoreCase)) + { + tokenFilePathValue = (string)connectionToml[property]; + continue; + } var mappedProperty = TomlToNetPropertiesMapper.TryGetValue(property, out var mapped) ? mapped : property; connectionStringBuilder.Append($"{mappedProperty}={(string)connectionToml[property]};"); } + if (!isOauth || connectionToml.ContainsKey("token")) + return connectionStringBuilder.ToString(); + + var token = LoadTokenFromFile(tokenFilePathValue); + if (!string.IsNullOrEmpty(token)) + { + connectionStringBuilder.Append($"token={token};"); + } + else + { + // log warning TODO + } + + return connectionStringBuilder.ToString(); } + private string LoadTokenFromFile(string tokenFilePathValue) + { + var tokenFile = _fileOperations.Exists(tokenFilePathValue) ? tokenFilePathValue : DefaultTokenPath; + return _fileOperations.Exists(tokenFile) ? _fileOperations.ReadAllText(tokenFile) : null; + } + private TomlTable GetTomlTableFromConfig(string tomlPath, string connectionName) { TomlTable result = null; diff --git a/Snowflake.Data/Core/Tools/FileOperations.cs b/Snowflake.Data/Core/Tools/FileOperations.cs index 8b8311629..8953a8d6e 100644 --- a/Snowflake.Data/Core/Tools/FileOperations.cs +++ b/Snowflake.Data/Core/Tools/FileOperations.cs @@ -7,6 +7,7 @@ namespace Snowflake.Data.Core.Tools { using System.Runtime.InteropServices; + using Mono.Unix; internal class FileOperations { @@ -20,7 +21,12 @@ public virtual bool Exists(string path) public virtual string ReadAllText(string path) { - var contentFile = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? File.ReadAllText(path) : _unixOperations.ReadAllText(path); + return ReadAllText(path, FileAccessPermissions.OtherReadWriteExecute); + } + + public virtual string ReadAllText(string path, FileAccessPermissions? forbiddenPermissions) + { + var contentFile = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? File.ReadAllText(path) : _unixOperations.ReadAllText(path, forbiddenPermissions); return contentFile; } } diff --git a/Snowflake.Data/Core/Tools/UnixOperations.cs b/Snowflake.Data/Core/Tools/UnixOperations.cs index 7c32da863..dab753067 100644 --- a/Snowflake.Data/Core/Tools/UnixOperations.cs +++ b/Snowflake.Data/Core/Tools/UnixOperations.cs @@ -40,7 +40,7 @@ public virtual bool CheckFileHasAnyOfPermissions(string path, FileAccessPermissi /// The content of the file as a string. /// Thrown if the file is not owned by the effective user or group, or if it has forbidden permissions. - public string ReadAllText(string path, FileAccessPermissions forbiddenPermissions = FileAccessPermissions.OtherReadWriteExecute) + public string ReadAllText(string path, FileAccessPermissions? forbiddenPermissions = FileAccessPermissions.OtherReadWriteExecute) { var fileInfo = new UnixFileInfo(path: path); @@ -50,7 +50,7 @@ public string ReadAllText(string path, FileAccessPermissions forbiddenPermission throw new SecurityException("Attempting to read a file not owned by the effective user of the current process"); if (handle.OwnerGroup.GroupId != Syscall.getegid()) throw new SecurityException("Attempting to read a file not owned by the effective group of the current process"); - if ((handle.FileAccessPermissions & forbiddenPermissions) != 0) + if (forbiddenPermissions.HasValue && (handle.FileAccessPermissions & forbiddenPermissions.Value) != 0) throw new SecurityException("Attempting to read a file with too broad permissions assigned"); using (var streamReader = new StreamReader(handle, Encoding.Default)) {