From 017c8c8d6bab99d0c9080f50b58fed482c6a3d5d Mon Sep 17 00:00:00 2001
From: sfc-gh-ext-simba-lf <lfarol@magnitude.com>
Date: Tue, 16 Jan 2024 10:31:51 -0800
Subject: [PATCH] SNOW-955536: Add multiple SAML integration

---
 .../IntegrationTests/SFConnectionIT.cs        | 44 +++++++++++
 .../UnitTests/SFSessionPropertyTest.cs        | 55 ++++++++++++--
 .../ExternalBrowserAuthenticator.cs           | 73 ++++++++++++++-----
 Snowflake.Data/Core/RestParams.cs             |  2 +
 Snowflake.Data/Core/Session/SFSession.cs      |  3 +
 .../Core/Session/SFSessionProperty.cs         |  4 +-
 6 files changed, 155 insertions(+), 26 deletions(-)

diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs
index c13cd684b..29867c46b 100644
--- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs
+++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs
@@ -894,6 +894,50 @@ public void TestSSOConnectionWithUserAsync()
             }
         }
         
+        [Test]
+        [Ignore("This test requires manual interaction and therefore cannot be run in CI")]
+        public void TestSSOConnectionWithUserAndDisableConsoleLogin()
+        {
+            // Use external browser to log in using proper password for qa@snowflakecomputing.com
+            using (IDbConnection conn = new SnowflakeDbConnection())
+            {
+                conn.ConnectionString
+                    = ConnectionStringWithoutAuth
+                    + ";authenticator=externalbrowser;user=qa@snowflakecomputing.com;disable_console_login=false;";
+                conn.Open();
+                Assert.AreEqual(ConnectionState.Open, conn.State);
+                using (IDbCommand command = conn.CreateCommand())
+                {
+                    command.CommandText = "SELECT CURRENT_USER()";
+                    Assert.AreEqual("QA", command.ExecuteScalar().ToString());
+                }
+            }
+        }
+
+        [Test]
+        [Ignore("This test requires manual interaction and therefore cannot be run in CI")]
+        public void TestSSOConnectionWithUserAsyncAndDisableConsoleLogin()
+        {
+            // Use external browser to log in using proper password for qa@snowflakecomputing.com
+            using (SnowflakeDbConnection conn = new SnowflakeDbConnection())
+            {
+                conn.ConnectionString
+                    = ConnectionStringWithoutAuth
+                      + ";authenticator=externalbrowser;user=qa@snowflakecomputing.com;disable_console_login=false;";
+
+                Task connectTask = conn.OpenAsync(CancellationToken.None);
+                connectTask.Wait();
+                Assert.AreEqual(ConnectionState.Open, conn.State);
+                using (DbCommand command = conn.CreateCommand())
+                {
+                    command.CommandText = "SELECT CURRENT_USER()";
+                    Task<object> task = command.ExecuteScalarAsync(CancellationToken.None);
+                    task.Wait(CancellationToken.None);
+                    Assert.AreEqual("QA", task.Result);
+                }
+            }
+        }
+
         [Test]
         [Ignore("This test requires manual interaction and therefore cannot be run in CI")]
         public void TestSSOConnectionTimeoutAfter10s()
diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs
index cd72bffce..f8dc08a94 100644
--- a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs
+++ b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs
@@ -78,6 +78,7 @@ public static IEnumerable<TestCase> ConnectionStringTestCases()
             string defMaxHttpRetries = "7";
             string defIncludeRetryReason = "true";
             string defDisableQueryContextCache = "false";
+            string defDisableConsoleLogin = "true";
 
             var simpleTestCase = new TestCase()
             {
@@ -103,7 +104,8 @@ public static IEnumerable<TestCase> ConnectionStringTestCases()
                     { SFSessionProperty.RETRY_TIMEOUT, defRetryTimeout },
                     { SFSessionProperty.MAXHTTPRETRIES, defMaxHttpRetries },
                     { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason },
-                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }
+                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache },
+                    { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }
                 }
             };
             var testCaseWithBrowserResponseTimeout = new TestCase()
@@ -129,7 +131,8 @@ public static IEnumerable<TestCase> ConnectionStringTestCases()
                     { SFSessionProperty.RETRY_TIMEOUT, defRetryTimeout },
                     { SFSessionProperty.MAXHTTPRETRIES, defMaxHttpRetries },
                     { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason },
-                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }
+                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache },
+                    { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }
                 }
             };
             var testCaseWithProxySettings = new TestCase()
@@ -158,7 +161,8 @@ public static IEnumerable<TestCase> ConnectionStringTestCases()
                     { SFSessionProperty.RETRY_TIMEOUT, defRetryTimeout },
                     { SFSessionProperty.MAXHTTPRETRIES, defMaxHttpRetries },
                     { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason },
-                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }
+                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache },
+                    { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }
                 },
                 ConnectionString =
                     $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};useProxy=true;proxyHost=proxy.com;proxyPort=1234;nonProxyHosts=localhost"
@@ -189,7 +193,8 @@ public static IEnumerable<TestCase> ConnectionStringTestCases()
                     { SFSessionProperty.RETRY_TIMEOUT, defRetryTimeout },
                     { SFSessionProperty.MAXHTTPRETRIES, defMaxHttpRetries },
                     { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason },
-                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }
+                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache },
+                    { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }
                 },
                 ConnectionString =
                     $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};proxyHost=proxy.com;proxyPort=1234;nonProxyHosts=localhost"
@@ -219,7 +224,8 @@ public static IEnumerable<TestCase> ConnectionStringTestCases()
                     { SFSessionProperty.MAXHTTPRETRIES, defMaxHttpRetries },
                     { SFSessionProperty.FILE_TRANSFER_MEMORY_THRESHOLD, "25" },
                     { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason },
-                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }
+                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache },
+                    { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }
                 }
             };
             var testCaseWithIncludeRetryReason = new TestCase()
@@ -246,7 +252,8 @@ public static IEnumerable<TestCase> ConnectionStringTestCases()
                     { SFSessionProperty.RETRY_TIMEOUT, defRetryTimeout },
                     { SFSessionProperty.MAXHTTPRETRIES, defMaxHttpRetries },
                     { SFSessionProperty.INCLUDERETRYREASON, "false" },
-                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }
+                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache },
+                    { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }
                 }
             };
             var testCaseWithDisableQueryContextCache = new TestCase()
@@ -272,11 +279,41 @@ public static IEnumerable<TestCase> ConnectionStringTestCases()
                     { SFSessionProperty.RETRY_TIMEOUT, defRetryTimeout },
                     { SFSessionProperty.MAXHTTPRETRIES, defMaxHttpRetries },
                     { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason },
-                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, "true" }
+                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, "true" },
+                    { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }
                 },
                 ConnectionString =
                     $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};DISABLEQUERYCONTEXTCACHE=true"
             };
+            var testCaseWithDisableConsoleLogin = new TestCase()
+            {
+                ExpectedProperties = new SFSessionProperties()
+                {
+                    { SFSessionProperty.ACCOUNT, defAccount },
+                    { SFSessionProperty.USER, defUser },
+                    { SFSessionProperty.HOST, defHost },
+                    { SFSessionProperty.AUTHENTICATOR, defAuthenticator },
+                    { SFSessionProperty.SCHEME, defScheme },
+                    { SFSessionProperty.CONNECTION_TIMEOUT, defConnectionTimeout },
+                    { SFSessionProperty.PASSWORD, defPassword },
+                    { SFSessionProperty.PORT, defPort },
+                    { SFSessionProperty.VALIDATE_DEFAULT_PARAMETERS, "true" },
+                    { SFSessionProperty.USEPROXY, "false" },
+                    { SFSessionProperty.INSECUREMODE, "false" },
+                    { SFSessionProperty.DISABLERETRY, "false" },
+                    { SFSessionProperty.FORCERETRYON404, "false" },
+                    { SFSessionProperty.CLIENT_SESSION_KEEP_ALIVE, "false" },
+                    { SFSessionProperty.FORCEPARSEERROR, "false" },
+                    { SFSessionProperty.BROWSER_RESPONSE_TIMEOUT, defBrowserResponseTime },
+                    { SFSessionProperty.RETRY_TIMEOUT, defRetryTimeout },
+                    { SFSessionProperty.MAXHTTPRETRIES, defMaxHttpRetries },
+                    { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason },
+                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache },
+                    { SFSessionProperty.DISABLE_CONSOLE_LOGIN, "false" }
+                },
+                ConnectionString =
+                    $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};DISABLE_CONSOLE_LOGIN=false"
+            };
             var complicatedAccount = $"{defAccount}.region-name.host-name";
             var testCaseComplicatedAccountName = new TestCase()
             {
@@ -302,7 +339,8 @@ public static IEnumerable<TestCase> ConnectionStringTestCases()
                     { SFSessionProperty.RETRY_TIMEOUT, defRetryTimeout },
                     { SFSessionProperty.MAXHTTPRETRIES, defMaxHttpRetries },
                     { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason },
-                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }
+                    { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache },
+                    { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }
                 }
             };
             return new TestCase[]
@@ -314,6 +352,7 @@ public static IEnumerable<TestCase> ConnectionStringTestCases()
                 testCaseWithFileTransferMaxBytesInMemory,
                 testCaseWithIncludeRetryReason,
                 testCaseWithDisableQueryContextCache,
+                testCaseWithDisableConsoleLogin,
                 testCaseComplicatedAccountName
             };
         }
diff --git a/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs b/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs
index 3d7b19110..d6ead6818 100644
--- a/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs
+++ b/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs
@@ -12,6 +12,7 @@
 using Snowflake.Data.Log;
 using Snowflake.Data.Client;
 using System.Text.RegularExpressions;
+using System.Collections.Generic;
 
 namespace Snowflake.Data.Core.Authenticator
 {
@@ -54,19 +55,28 @@ async Task IAuthenticator.AuthenticateAsync(CancellationToken cancellationToken)
                 httpListener.Start();
 
                 logger.Debug("Get IdpUrl and ProofKey");
-                var authenticatorRestRequest = BuildAuthenticatorRestRequest(localPort);
-                var authenticatorRestResponse =
-                    await session.restRequester.PostAsync<AuthenticatorResponse>(
-                        authenticatorRestRequest,
-                        cancellationToken
-                    ).ConfigureAwait(false);
-                authenticatorRestResponse.FilterFailedResponse();
+                string loginUrl;
+                if (session._disableConsoleLogin)
+                {
+                    var authenticatorRestRequest = BuildAuthenticatorRestRequest(localPort);
+                    var authenticatorRestResponse =
+                        await session.restRequester.PostAsync<AuthenticatorResponse>(
+                            authenticatorRestRequest,
+                            cancellationToken
+                        ).ConfigureAwait(false);
+                    authenticatorRestResponse.FilterFailedResponse();
 
-                var idpUrl = authenticatorRestResponse.data.ssoUrl;
-                _proofKey = authenticatorRestResponse.data.proofKey;
+                    loginUrl = authenticatorRestResponse.data.ssoUrl;
+                    _proofKey = authenticatorRestResponse.data.proofKey;
+                }
+                else
+                {
+                    _proofKey = GenerateProofKey();
+                    loginUrl = GetLoginUrl(_proofKey, localPort);
+                }
 
                 logger.Debug("Open browser");
-                StartBrowser(idpUrl);
+                StartBrowser(loginUrl);
 
                 logger.Debug("Get the redirect SAML request");
                 _successEvent = new ManualResetEvent(false);
@@ -96,15 +106,24 @@ void IAuthenticator.Authenticate()
                 httpListener.Start();
 
                 logger.Debug("Get IdpUrl and ProofKey");
-                var authenticatorRestRequest = BuildAuthenticatorRestRequest(localPort);
-                var authenticatorRestResponse = session.restRequester.Post<AuthenticatorResponse>(authenticatorRestRequest);
-                authenticatorRestResponse.FilterFailedResponse();
+                string loginUrl;
+                if (session._disableConsoleLogin)
+                {
+                    var authenticatorRestRequest = BuildAuthenticatorRestRequest(localPort);
+                    var authenticatorRestResponse = session.restRequester.Post<AuthenticatorResponse>(authenticatorRestRequest);
+                    authenticatorRestResponse.FilterFailedResponse();
 
-                var idpUrl = authenticatorRestResponse.data.ssoUrl;
-                _proofKey = authenticatorRestResponse.data.proofKey;
+                    loginUrl = authenticatorRestResponse.data.ssoUrl;
+                    _proofKey = authenticatorRestResponse.data.proofKey;
+                }
+                else
+                {
+                    _proofKey = GenerateProofKey();
+                    loginUrl = GetLoginUrl(_proofKey, localPort);
+                }
 
                 logger.Debug("Open browser");
-                StartBrowser(idpUrl);
+                StartBrowser(loginUrl);
 
                 logger.Debug("Get the redirect SAML request");
                 _successEvent = new ManualResetEvent(false);
@@ -187,7 +206,7 @@ private static void StartBrowser(string url)
             // The following code is learnt from https://brockallen.com/2016/09/24/process-start-for-urls-on-net-core/
 #if NETFRAMEWORK
             // .net standard would pass here
-            Process.Start(url);
+            Process.Start(new ProcessStartInfo(url) { UseShellExecute = true });
 #else
             // hack because of this: https://github.com/dotnet/corefx/issues/10361
             if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
@@ -247,5 +266,25 @@ protected override void SetSpecializedAuthenticatorData(ref LoginRequestData dat
             data.Token = _samlResponseToken;
             data.ProofKey = _proofKey;
         }
+
+        private string GetLoginUrl(string proofKey, int localPort)
+        {
+            Dictionary<string, string> parameters = new Dictionary<string, string>()
+            {
+                { "login_name", session.properties[SFSessionProperty.USER]},
+                { "proof_key", proofKey },
+                { "browser_mode_redirect_port", localPort.ToString() }
+            };
+            Uri loginUrl = session.BuildUri(RestPath.SF_CONSOLE_LOGIN, parameters);
+            return loginUrl.ToString();
+        }
+
+        private string GenerateProofKey()
+        {
+            Random rnd = new Random();
+            Byte[] randomness = new Byte[32];
+            rnd.NextBytes(randomness);
+            return Convert.ToBase64String(randomness);
+        }
     }
 }
diff --git a/Snowflake.Data/Core/RestParams.cs b/Snowflake.Data/Core/RestParams.cs
index 11bca113f..1188affb0 100644
--- a/Snowflake.Data/Core/RestParams.cs
+++ b/Snowflake.Data/Core/RestParams.cs
@@ -40,6 +40,8 @@ internal static class RestPath
         internal const string SF_QUERY_PATH = "/queries/v1/query-request";
 
         internal const string SF_SESSION_HEARTBEAT_PATH = SF_SESSION_PATH + "/heartbeat";
+
+        internal const string SF_CONSOLE_LOGIN = "/console/login";
     }
 
     internal class SFEnvironment
diff --git a/Snowflake.Data/Core/Session/SFSession.cs b/Snowflake.Data/Core/Session/SFSession.cs
index cb9e4f0fe..2ad440407 100755
--- a/Snowflake.Data/Core/Session/SFSession.cs
+++ b/Snowflake.Data/Core/Session/SFSession.cs
@@ -77,6 +77,8 @@ public class SFSession
 
         private bool _disableQueryContextCache = false;
 
+        internal bool _disableConsoleLogin;
+
         internal void ProcessLoginResponse(LoginResponse authnResponse)
         {
             if (authnResponse.success)
@@ -148,6 +150,7 @@ internal SFSession(
             connStr = connectionString;
             properties = SFSessionProperties.parseConnectionString(connectionString, password);
             _disableQueryContextCache = bool.Parse(properties[SFSessionProperty.DISABLEQUERYCONTEXTCACHE]);
+            _disableConsoleLogin = bool.Parse(properties[SFSessionProperty.DISABLE_CONSOLE_LOGIN]);
             ValidateApplicationName(properties);
             try
             {
diff --git a/Snowflake.Data/Core/Session/SFSessionProperty.cs b/Snowflake.Data/Core/Session/SFSessionProperty.cs
index b07015a88..c2e8e9d55 100755
--- a/Snowflake.Data/Core/Session/SFSessionProperty.cs
+++ b/Snowflake.Data/Core/Session/SFSessionProperty.cs
@@ -90,7 +90,9 @@ internal enum SFSessionProperty
         [SFSessionPropertyAttr(required = false, defaultValue = "false")]
         DISABLEQUERYCONTEXTCACHE,
         [SFSessionPropertyAttr(required = false)]
-        CLIENT_CONFIG_FILE
+        CLIENT_CONFIG_FILE,
+        [SFSessionPropertyAttr(required = false, defaultValue = "true")]
+        DISABLE_CONSOLE_LOGIN
     }
 
     class SFSessionPropertyAttr : Attribute