diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 430698154..171778d11 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -60,7 +60,7 @@ jobs: run: | cd Snowflake.Data.Tests dotnet restore - dotnet build -f ${{ matrix.dotnet }} + dotnet build -f ${{ matrix.dotnet }} '-p:DefineAdditionalConstants=SF_PUBLIC_ENVIRONMENT' - name: Run Tests run: | cd Snowflake.Data.Tests @@ -119,7 +119,7 @@ jobs: - name: Build Driver run: | dotnet restore - dotnet build + dotnet build '-p:DefineAdditionalConstants=SF_PUBLIC_ENVIRONMENT' - name: Run Tests run: | cd Snowflake.Data.Tests @@ -178,7 +178,7 @@ jobs: - name: Build Driver run: | dotnet restore - dotnet build + dotnet build '-p:DefineAdditionalConstants=SF_PUBLIC_ENVIRONMENT' - name: Run Tests run: | cd Snowflake.Data.Tests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 895581a9c..c2af46b46 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ repos: - repo: git@github.com:snowflakedb/casec_precommit.git - rev: v1.20 + rev: v1.35.4 hooks: - id: secret-scanner diff --git a/CodingConventions.md b/CodingConventions.md index 19ca8fc75..0242f583e 100644 --- a/CodingConventions.md +++ b/CodingConventions.md @@ -85,6 +85,18 @@ public class ExampleClass } ``` +#### Property + +Use PascalCase, eg. `SomeProperty`. + +```csharp +public ExampleProperty +{ + get; + set; +} +``` + ### Local variables Use camelCase, eg. `someVariable`. diff --git a/README.md b/README.md index 59daa2d88..7b8766853 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ The Snowflake .NET connector supports the the following .NET framework and libra - .NET 6.0 - .NET 8.0 -Please refer to the Notice section below for information about safe usage of the .NET Driver +Please refer to the [Notice](#notice) section below for information about safe usage of the .NET Driver # Coding conventions for the project @@ -69,934 +69,46 @@ Alternatively, packages can also be downloaded using Package Manager Console: PM> Install-Package Snowflake.Data ``` -# Testing the Connector +# Testing and Code Coverage -Before running tests, create a parameters.json file under Snowflake.Data.Tests\ directory. In this file, specify username, password and account info that tests will run against. Here is a sample parameters.json file +[Running tests](doc/Testing.md) -``` -{ - "testconnection": { - "SNOWFLAKE_TEST_USER": "snowman", - "SNOWFLAKE_TEST_PASSWORD": "XXXXXXX", - "SNOWFLAKE_TEST_ACCOUNT": "TESTACCOUNT", - "SNOWFLAKE_TEST_WAREHOUSE": "TESTWH", - "SNOWFLAKE_TEST_DATABASE": "TESTDB", - "SNOWFLAKE_TEST_SCHEMA": "TESTSCHEMA", - "SNOWFLAKE_TEST_ROLE": "TESTROLE", - "SNOWFLAKE_TEST_HOST": "testaccount.snowflakecomputing.com" - } -} -``` - -## Command Prompt - -The build solution file builds the connector and tests binaries. Issue the following command from the command line to run the tests. The test binary is located in the Debug directory if you built the solution file in Debug mode. - -```{r, engine='bash', code_block_name} -cd Snowflake.Data.Tests -dotnet test -f net6.0 -l "console;verbosity=normal" -``` - -Tests can also be run under code coverage: - -```{r, engine='bash', code_block_name} -dotnet-coverage collect "dotnet test --framework net6.0 --no-build -l console;verbosity=normal" --output net6.0_coverage.xml --output-format cobertura --settings coverage.config -``` - -You can run only specific suite of tests (integration or unit). - -Running unit tests: - -```bash -cd Snowflake.Data.Tests -dotnet test -l "console;verbosity=normal" --filter FullyQualifiedName~UnitTests -l console;verbosity=normal -``` - -Running integration tests: - -```bash -cd Snowflake.Data.Tests -dotnet test -l "console;verbosity=normal" --filter FullyQualifiedName~IntegrationTests -``` +[Code coverage](doc/CodeCoverage.md) -## Visual Studio 2017 - -Tests can also be run under Visual Studio 2017. Open the solution file in Visual Studio 2017 and run tests using Test Explorer. +--- # Usage ## Create a Connection -To connect to Snowflake, specify a valid connection string composed of key-value pairs separated by semicolons, -i.e "\=\;\=\...". - -**Note**: If the keyword or value contains an equal sign (=), you must precede the equal sign with another equal sign. For example, if the keyword is "key" and the value is "value_part1=value_part2", use "key=value_part1==value_part2". - -The following table lists all valid connection properties: -
- -| Connection Property | Required | Comment | -|--------------------------------| -------- |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| ACCOUNT | Yes | Your full account name might include additional segments that identify the region and cloud platform where your account is hosted | -| APPLICATION | No | **_Snowflake partner use only_**: Specifies the name of a partner application to connect through .NET. The name must match the following pattern: ^\[A-Za-z](\[A-Za-z0-9.-]){1,50}$ (one letter followed by 1 to 50 letter, digit, .,- or, \_ characters). | -| DB | No | | -| HOST | No | Specifies the hostname for your account in the following format: \.snowflakecomputing.com.
If no value is specified, the driver uses \.snowflakecomputing.com. | -| PASSWORD | Depends | Required if AUTHENTICATOR is set to `snowflake` (the default value) or the URL for native SSO through Okta. Ignored for all the other authentication types. | -| ROLE | No | | -| SCHEMA | No | | -| USER | Depends | If AUTHENTICATOR is set to `externalbrowser` this is optional. For native SSO through Okta, set this to the login name for your identity provider (IdP). | -| WAREHOUSE | No | | -| CONNECTION_TIMEOUT | No | Total timeout in seconds when connecting to Snowflake. The default is 300 seconds | -| RETRY_TIMEOUT | No | Total timeout in seconds for supported endpoints of retry policy. The default is 300 seconds. The value can only be increased from the default value or set to 0 for infinite timeout | -| MAXHTTPRETRIES | No | Maximum number of times to retry failed HTTP requests (default: 7). You can set `MAXHTTPRETRIES=0` to remove the retry limit, but doing so runs the risk of the .NET driver infinitely retrying failed HTTP calls. | -| CLIENT_SESSION_KEEP_ALIVE | No | Whether to keep the current session active after a period of inactivity, or to force the user to login again. If the value is `true`, Snowflake keeps the session active indefinitely, even if there is no activity from the user. If the value is `false`, the user must log in again after four hours of inactivity. The default is `false`. Setting this value overrides the server session property for the current session. | -| BROWSER_RESPONSE_TIMEOUT | No | Number to seconds to wait for authentication in an external browser (default: 120). | -| DISABLERETRY | No | Set this property to `true` to prevent the driver from reconnecting automatically when the connection fails or drops. The default value is `false`. | -| AUTHENTICATOR | No | The method of authentication. Currently supports the following values:
- snowflake (default): You must also set USER and PASSWORD.
- [the URL for native SSO through Okta](https://docs.snowflake.com/en/user-guide/admin-security-fed-auth-use.html#native-sso-okta-only): You must also set USER and PASSWORD.
- [externalbrowser](https://docs.snowflake.com/en/user-guide/admin-security-fed-auth-use.html#browser-based-sso): You must also set USER.
- [snowflake_jwt](https://docs.snowflake.com/en/user-guide/key-pair-auth.html): You must also set PRIVATE_KEY_FILE or PRIVATE_KEY.
- [oauth](https://docs.snowflake.com/en/user-guide/oauth.html): You must also set TOKEN. | -| VALIDATE_DEFAULT_PARAMETERS | No | Whether DB, SCHEMA and WAREHOUSE should be verified when making connection. Default to be true. | -| PRIVATE_KEY_FILE | Depends | The path to the private key file to use for key-pair authentication. Must be used in combination with AUTHENTICATOR=snowflake_jwt | -| PRIVATE_KEY_PWD | No | The passphrase to use for decrypting the private key, if the key is encrypted. | -| PRIVATE_KEY | Depends | The private key to use for key-pair authentication. Must be used in combination with AUTHENTICATOR=snowflake_jwt.
If the private key value includes any equal signs (=), make sure to replace each equal sign with two signs (==) to ensure that the connection string is parsed correctly. | -| TOKEN | Depends | The OAuth token to use for OAuth authentication. Must be used in combination with AUTHENTICATOR=oauth. | -| INSECUREMODE | No | Set to true to disable the certificate revocation list check. Default is false. | -| USEPROXY | No | Set to true if you need to use a proxy server. The default value is false.

This parameter was introduced in v2.0.4. | -| PROXYHOST | Depends | The hostname of the proxy server.

If USEPROXY is set to `true`, you must set this parameter.

This parameter was introduced in v2.0.4. | -| PROXYPORT | Depends | The port number of the proxy server.

If USEPROXY is set to `true`, you must set this parameter.

This parameter was introduced in v2.0.4. | -| PROXYUSER | No | The username for authenticating to the proxy server.

This parameter was introduced in v2.0.4. | -| PROXYPASSWORD | Depends | The password for authenticating to the proxy server.

If USEPROXY is `true` and PROXYUSER is set, you must set this parameter.

This parameter was introduced in v2.0.4. | -| NONPROXYHOSTS | No | The list of hosts that the driver should connect to directly, bypassing the proxy server. Separate the hostnames with a pipe symbol (\|). You can also use an asterisk (`*`) as a wildcard.
The host target value should fully match with any item from the proxy host list to bypass the proxy server.

This parameter was introduced in v2.0.4. | -| FILE_TRANSFER_MEMORY_THRESHOLD | No | The maximum number of bytes to store in memory used in order to provide a file encryption. If encrypting/decrypting file size exceeds provided value a temporary file will be created and the work will be continued in the temporary file instead of memory.
If no value provided 1MB will be used as a default value (that is 1048576 bytes).
It is possible to configure any integer value bigger than zero representing maximal number of bytes to reside in memory. | -| CLIENT_CONFIG_FILE | No | The location of the client configuration json file. In this file you can configure easy logging feature. | -| ALLOWUNDERSCORESINHOST | No | Specifies whether to allow underscores in account names. This impacts PrivateLink customers whose account names contain underscores. In this situation, you must override the default value by setting allowUnderscoresInHost to true. | -| QUERY_TAG | No | Optional string that can be used to tag queries and other SQL statements executed within a connection. The tags are displayed in the output of the QUERY_HISTORY , QUERY_HISTORY_BY_* functions.
To set QUERY_TAG on the statement level you can use SnowflakeDbCommand.QueryTag. | - -
- -### Password-based Authentication - -The following example demonstrates how to open a connection to Snowflake. This example uses a password for authentication. - -```cs -using (IDbConnection conn = new SnowflakeDbConnection()) -{ - conn.ConnectionString = "account=testaccount;user=testuser;password=XXXXX;db=testdb;schema=testschema"; - - conn.Open(); - - conn.Close(); -} -``` - - - -Beginning with version 2.0.18, the .NET connector uses Microsoft [DbConnectionStringBuilder](https://learn.microsoft.com/en-us/dotnet/api/system.data.oledb.oledbconnection.connectionstring?view=dotnet-plat-ext-6.0#remarks) to follow the .NET specification for escaping characters in connection strings. - -The following examples show how you can include different types of special characters in a connection string: - -- To include a single quote (') character: - - ```cs - string connectionString = String.Format( - "account=testaccount; " + - "user=testuser; " + - "password=test'password;" - ); - ``` - -- To include a double quote (") character: - - ```cs - string connectionString = String.Format( - "account=testaccount; " + - "user=testuser; " + - "password=test\"password;" - ); - ``` - -- To include a semicolon (;): - - ```cs - string connectionString = String.Format( - "account=testaccount; " + - "user=testuser; " + - "password=\"test;password\";" - ); - ``` - -- To include an equal sign (=): - - ```cs - string connectionString = String.Format( - "account=testaccount; " + - "user=testuser; " + - "password=test=password;" - ); - ``` - - Note that previously you needed to use a double equal sign (==) to escape the character. However, beginning with version 2.0.18, you can use a single equal size. - - -Snowflake supports using [double quote identifiers](https://docs.snowflake.com/en/sql-reference/identifiers-syntax#double-quoted-identifiers) for object property values (WAREHOUSE, DATABASE, SCHEMA AND ROLES). The value should be delimited with `\"` in the connection string. The value is case-sensitive and allow to use special characters as part of the value. - - ```cs - string connectionString = String.Format( - "account=testaccount; " + - "database=\"testDB\";" - ); - ``` - - To include a `"` character as part of the value should be escaped using `\"\"`. - - ```cs - string connectionString = String.Format( - "account=testaccount; " + - "database=\"\"\"test\"\"user\"\"\";" // DATABASE => ""test"db"" - ); - ``` - - -### Other Authentication Methods - -If you are using a different method for authentication, see the examples below: - -- **Key-pair authentication** - - After setting up [key-pair authentication](https://docs.snowflake.com/en/user-guide/key-pair-auth.html), you can specify the - private key for authentication in one of the following ways: - - - Specify the file containing an unencrypted private key: - - ```cs - using (IDbConnection conn = new SnowflakeDbConnection()) - { - conn.ConnectionString = "account=testaccount;authenticator=snowflake_jwt;user=testuser;private_key_file={pathToThePrivateKeyFile};db=testdb;schema=testschema"; - - conn.Open(); - - conn.Close(); - } - ``` - - where: - - - `{pathToThePrivateKeyFile}` is the path to the file containing the unencrypted private key. - - - Specify the file containing an encrypted private key: - - ```cs - using (IDbConnection conn = new SnowflakeDbConnection()) - { - conn.ConnectionString = "account=testaccount;authenticator=snowflake_jwt;user=testuser;private_key_file={pathToThePrivateKeyFile};private_key_pwd={passwordForDecryptingThePrivateKey};db=testdb;schema=testschema"; - - conn.Open(); - - conn.Close(); - } - ``` - - where: - - - `{pathToThePrivateKeyFile}` is the path to the file containing the unencrypted private key. - - `{passwordForDecryptingThePrivateKey}` is the password for decrypting the private key. - - - Specify an unencrypted private key (read from a file): - - ```cs - using (IDbConnection conn = new SnowflakeDbConnection()) - { - string privateKeyContent = File.ReadAllText({pathToThePrivateKeyFile}); - - conn.ConnectionString = String.Format("account=testaccount;authenticator=snowflake_jwt;user=testuser;private_key={0};db=testdb;schema=testschema", privateKeyContent); - - conn.Open(); - - conn.Close(); - } - ``` - - where: - - - `{pathToThePrivateKeyFile}` is the path to the file containing the unencrypted private key. - -- **OAuth** - - After setting up [OAuth](https://docs.snowflake.com/en/user-guide/oauth.html), set `AUTHENTICATOR=oauth` and `TOKEN` to the - OAuth token in the connection string. - - ```cs - using (IDbConnection conn = new SnowflakeDbConnection()) - { - conn.ConnectionString = "account=testaccount;user=testuser;authenticator=oauth;token={oauthTokenValue};db=testdb;schema=testschema"; - - conn.Open(); - - conn.Close(); - } - ``` - - where: - - - `{oauthTokenValue}` is the oauth token to use for authentication. - -- **Browser-based SSO** - - In the connection string, set `AUTHENTICATOR=externalbrowser`. - Optionally, `USER` can be set. In that case only if user authenticated via external browser matches the one from configuration, authentication will complete. - - ```cs - using (IDbConnection conn = new SnowflakeDbConnection()) - { - conn.ConnectionString = "account=testaccount;authenticator=externalbrowser;user={login_name_for_IdP};db=testdb;schema=testschema"; - - conn.Open(); - - conn.Close(); - } - ``` - - where: - - - `{login_name_for_IdP}` is your login name for your IdP. - - You can override the default timeout after which external browser authentication is marked as failed. - The timeout prevents the infinite hang when the user does not provide the login details, e.g. when closing the browser tab. - To override, you can provide `BROWSER_RESPONSE_TIMEOUT` parameter (in seconds). - -- **Native SSO through Okta** - - In the connection string, set `AUTHENTICATOR` to the - [URL of the endpoint for your Okta account](https://docs.snowflake.com/en/user-guide/admin-security-fed-auth-use.html#label-native-sso-okta), - and set `USER` to the login name for your IdP. - - ```cs - using (IDbConnection conn = new SnowflakeDbConnection()) - { - conn.ConnectionString = "account=testaccount;authenticator={okta_url_endpoint};user={login_name_for_IdP};db=testdb;schema=testschema"; - - conn.Open(); - - conn.Close(); - } - ``` - - where: - - - `{okta_url_endpoint}` is the URL for the endpoint for your Okta account (e.g. `https://.okta.com`). - - `{login_name_for_IdP}` is your login name for your IdP. - -In v2.0.4 and later releases, you can configure the driver to connect through a proxy server. The following example configures the -driver to connect through the proxy server `myproxyserver` on port `8888`. The driver authenticates to the proxy server as the -user `test` with the password `test`: - -```cs -using (IDbConnection conn = new SnowflakeDbConnection()) -{ - conn.ConnectionString = "account=testaccount;user=testuser;password=XXXXX;db=testdb;schema=testschema;useProxy=true;proxyHost=myproxyserver;proxyPort=8888;proxyUser=test;proxyPassword=test"; - - conn.Open(); - - conn.Close(); -} -``` - -The NONPROXYHOSTS property could be set to specify if the server proxy should be bypassed by an specified host. This should be defined using the full host url or including the url + `*` wilcard symbol. - -Examples: - -- `*` (Bypassed all hosts from the proxy server) -- `*.snowflakecomputing.com` ('Bypass all host that ends with `snowflakecomputing.com`') -- `https:\\testaccount.snowflakecomputing.com` (Bypass proxy server using full host url). -- `*.myserver.com | *testaccount*` (You can specify multiple regex for the property divided by `|`) - - -> Note: The nonproxyhost value should match the full url including the http or https section. The '*' wilcard could be added to bypass the hostname successfully. - -- `myaccount.snowflakecomputing.com` (Not bypassed). -- `*myaccount.snowflakecomputing.com` (Bypassed). - +To create a connection get familiar with: [Connecting and Authentication Methods](doc/Connecting.md) ## Using Connection Pools -Instead of creating a connection each time your client application needs to access Snowflake, you can define a cache of Snowflake connections that can be reused as needed. Connection pooling usually reduces the lag time to make a connection. However, it can slow down client failover to an alternative DNS when a DNS problem occurs. +Connection pooling description: [Multiple Connection Pools](doc/ConnectionPooling.md). -The Snowflake .NET driver provides the following functions for managing connection pools. +Pooling prior to v4.0.0 is described: [Single Connection Pool](doc/ConnectionPoolingDeprecated.md) - `deprecated` -| Function | Description | -| ---------------------------------------------- | ------------------------------------------------------------------------------------------------------- | -| SnowflakeDbConnectionPool.ClearAllPools() | Removes all connections from the connection pool. | -| SnowflakeDbConnection.SetMaxPoolSize(n) | Sets the maximum number of connections for the connection pool, where _n_ is the number of connections. | -| SnowflakeDBConnection.SetTimeout(n) | Sets the number of seconds to keep an unresponsive connection in the connection pool. | -| SnowflakeDbConnectionPool.GetCurrentPoolSize() | Returns the number of connections currently in the connection pool. | -| SnowflakeDbConnectionPool.SetPooling() | Determines whether to enable (`true`) or disable (`false`) connecing pooling. Default: `true`. | +## Data Types and Formats -The following sample demonstrates how to monitor the size of a connection pool as connections are added and dropped from the pool. +Snowflake data types and their .NET types is covered in: [Data Types and Data Formats](doc/DataTypes.md) -```cs -public void TestConnectionPoolClean() -{ - SnowflakeDbConnectionPool.ClearAllPools(); - SnowflakeDbConnectionPool.SetMaxPoolSize(2); - var conn1 = new SnowflakeDbConnection(); - conn1.ConnectionString = ConnectionString; - conn1.Open(); - Assert.AreEqual(ConnectionState.Open, conn1.State); +## Querying Data - var conn2 = new SnowflakeDbConnection(); - conn2.ConnectionString = ConnectionString + " retryCount=1"; - conn2.Open(); - Assert.AreEqual(ConnectionState.Open, conn2.State); - Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize()); - conn1.Close(); - conn2.Close(); - Assert.AreEqual(2, SnowflakeDbConnectionPool.GetCurrentPoolSize()); - var conn3 = new SnowflakeDbConnection(); - conn3.ConnectionString = ConnectionString + " retryCount=2"; - conn3.Open(); - Assert.AreEqual(ConnectionState.Open, conn3.State); +How execute a query, use query bindings, run queries synchronously and asynchronously: +[Running Queries and Reading Results](doc/QueryingData.md) - var conn4 = new SnowflakeDbConnection(); - conn4.ConnectionString = ConnectionString + " retryCount=3"; - conn4.Open(); - Assert.AreEqual(ConnectionState.Open, conn4.State); +## Stage Files - conn3.Close(); - Assert.AreEqual(2, SnowflakeDbConnectionPool.GetCurrentPoolSize()); - conn4.Close(); - Assert.AreEqual(2, SnowflakeDbConnectionPool.GetCurrentPoolSize()); - - Assert.AreEqual(ConnectionState.Closed, conn1.State); - Assert.AreEqual(ConnectionState.Closed, conn2.State); - Assert.AreEqual(ConnectionState.Closed, conn3.State); - Assert.AreEqual(ConnectionState.Closed, conn4.State); -} -``` - -## Mapping .NET and Snowflake Data Types - -The .NET driver supports the following mappings from .NET to Snowflake data types. - -| .NET Framekwork Data Type | Data Type in Snowflake | -| ------------------------- | ---------------------- | -| `int`, `long` | `NUMBER(38, 0)` | -| `decimal` | `NUMBER(38, )` | -| `double` | `REAL` | -| `string` | `TEXT` | -| `bool` | `BOOLEAN` | -| `byte` | `BINARY` | -| `datetime` | `DATE` | - -## Arrow data format - -The .NET connector, starting with v2.1.3, supports the [Arrow data format](https://arrow.apache.org/) -as a [preview](https://docs.snowflake.com/en/release-notes/preview-features) feature for data transfers -between Snowflake and a .NET client. The Arrow data format avoids extra -conversions between binary and textual representations of the data. The Arrow -data format can improve performance and reduce memory consumption in clients. - -The data format is controlled by the -DOTNET_QUERY_RESULT_FORMAT parameter. To use Arrow format, execute: - -```snowflake --- at the session level -ALTER SESSION SET DOTNET_QUERY_RESULT_FORMAT = ARROW; --- or at the user level -ALTER USER SET DOTNET_QUERY_RESULT_FORMAT = ARROW; --- or at the account level -ALTER ACCOUNT SET DOTNET_QUERY_RESULT_FORMAT = ARROW; -``` - -The valid values for the parameter are: - -- ARROW -- JSON (default) - -## Run a Query and Read Data - -```cs -using (IDbConnection conn = new SnowflakeDbConnection()) -{ - conn.ConnectionString = connectionString; - conn.Open(); - - IDbCommand cmd = conn.CreateCommand(); - cmd.CommandText = "select * from t"; - IDataReader reader = cmd.ExecuteReader(); - - while(reader.Read()) - { - Console.WriteLine(reader.GetString(0)); - } - - conn.Close(); -} -``` - -Note that for a `TIME` column, the reader returns a `System.DateTime` value. If you need a `System.TimeSpan` column, call the -`getTimeSpan` method in `SnowflakeDbDataReader`. This method was introduced in the v2.0.4 release. - -Note that because this method is not available in the generic `IDataReader` interface, you must cast the object as -`SnowflakeDbDataReader` before calling the method. For example: - -```cs -TimeSpan timeSpanTime = ((SnowflakeDbDataReader)reader).GetTimeSpan(13); -``` - -## Execute a query asynchronously on the server - -You can run the query asynchronously on the server. The server responds immediately with `queryId` and continues to execute the query asynchronously. -Then you can use this `queryId` to check the query status or wait until the query is completed and get the results. -It is fine to start the query in one session and continue to query for the results in another one based on the queryId. - -**Note**: There are 2 levels of asynchronous execution. One is asynchronous execution in terms of C# language (`async await`). -Another is asynchronous execution of the query by the server (you can recognize it by `InAsyncMode` containing method names, e. g. `ExecuteInAsyncMode`, `ExecuteAsyncInAsyncMode`). - -Example of synchronous code starting a query to be executed asynchronously on the server: -```cs -using (SnowflakeDbConnection conn = new SnowflakeDbConnection("account=testaccount;username=testusername;password=testpassword")) -{ - conn.Open(); - SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand(); - cmd.CommandText = "SELECT ..."; - var queryId = cmd.ExecuteInAsyncMode(); - // ... -} -``` - -Example of asynchronous code starting a query to be executed asynchronously on the server: -```cs -using (SnowflakeDbConnection conn = new SnowflakeDbConnection("account=testaccount;username=testusername;password=testpassword")) -{ - await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); - SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) - cmd.CommandText = "SELECT ..."; - var queryId = await cmd.ExecuteAsyncInAsyncMode(CancellationToken.None).ConfigureAwait(false); - // ... -} -``` - -You can check the status of a query executed asynchronously on the server either in synchronous code: -```cs -var queryStatus = cmd.GetQueryStatus(queryId); -Assert.IsTrue(conn.IsStillRunning(queryStatus)); // assuming that the query is still running -Assert.IsFalse(conn.IsAnError(queryStatus)); // assuming that the query has not finished with error -``` -or the same in an asynchronous code: -```cs -var queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); -Assert.IsTrue(conn.IsStillRunning(queryStatus)); // assuming that the query is still running -Assert.IsFalse(conn.IsAnError(queryStatus)); // assuming that the query has not finished with error -``` - -The following example shows how to get query results. -The operation will repeatedly check the query status until the query is completed or timeout happened or reaching the maximum number of attempts. -The synchronous code example: -```cs -DbDataReader reader = cmd.GetResultsFromQueryId(queryId); -``` -and the asynchronous code example: -```cs -DbDataReader reader = await cmd.GetResultsFromQueryIdAsync(queryId, CancellationToken.None).ConfigureAwait(false); -``` - -**Note**: GET/PUT operations are currently not enabled for asynchronous executions. - -## Executing a Batch of SQL Statements (Multi-Statement Support) - -With version 2.0.18 and later of the .NET connector, you can send -a batch of SQL statements, separated by semicolons, -to be executed in a single request. - -**Note**: Snowflake does not currently support variable binding in multi-statement SQL requests. - ---- - -**Note** - -By default, Snowflake returns an error for queries issued with multiple statements to protect against SQL injection attacks. The multiple statements feature makes your system more vulnerable to SQL injections, and so it should be used carefully. You can reduce the risk by using the MULTI_STATEMENT_COUNT parameter to specify the number of statements to be executed, which makes it more difficult to inject a statement by appending to it. - ---- - -You can execute multiple statements as a batch in the same way you execute queries with single statements, except that the query string contains multiple statements separated by semicolons. Note that multiple statements execute sequentially, not in parallel. - -You can set this parameter at the session level using the following command: - -``` -ALTER SESSION SET MULTI_STATEMENT_COUNT = <0/1>; -``` - -where: - -- **0**: Enables an unspecified number of SQL statements in a query. - - Using this value allows batch queries to contain any number of SQL statements without needing to specify the MULTI_STATEMENT_COUNT statement parameter. However, be aware that using this value reduces the protection against SQL injection attacks. - -- **1**: Allows one SQL statement or a specified number of statement in a query string (default). - - You must include MULTI_STATEMENT_COUNT as a statement parameter to specify the number of statements included when the query string contains more than one statement. If the number of statements sent in the query string does not match the MULTI_STATEMENT_COUNT value, the .NET driver rejects the request. You can, however, omit this parameter if you send a single statement. - -The following example sets the MULTI_STATEMENT_COUNT session parameter to 1. Then for an individual command, it sets MULTI_STATEMENT_COUNT=3 to indicate that the query contains precisely three SQL commands. The query string, `cmd.CommandText` , then contains the three statements to execute. - -```cs -using (IDbConnection conn = new SnowflakeDbConnection()) -{ - conn.ConnectionString = ConnectionString; - conn.Open(); - IDbCommand cmd = conn.CreateCommand(); - cmd.CommandText = "ALTER SESSION SET MULTI_STATEMENT_COUNT = 1;"; - cmd.ExecuteNonQuery(); - conn.Close(); -} - -using (DbCommand cmd = conn.CreateCommand()) -{ - // Set statement count - var stmtCountParam = cmd.CreateParameter(); - stmtCountParam.ParameterName = "MULTI_STATEMENT_COUNT"; - stmtCountParam.DbType = DbType.Int16; - stmtCountParam.Value = 3; - cmd.Parameters.Add(stmtCountParam); - cmd.CommandText = "CREATE OR REPLACE TABLE test(n int); INSERT INTO test values(1), (2); SELECT * FROM test ORDER BY n; - DbDataReader reader = cmd.ExecuteReader(); - do - { - if (reader.HasRow) - { - while (reader.Read()) - { - // read data - } - } - } - while (reader.NextResult()); -} -``` - -## Bind Parameter - -**Note**: Snowflake does not currently support variable binding in multi-statement SQL requests. - -This example shows how bound parameters are converted from C# data types to -Snowflake data types. For example, if the data type of the Snowflake column -is INTEGER, then you can bind C# data types Int32 or Int16. - -This example inserts 3 rows into a table with one column. - -```cs -using (IDbConnection conn = new SnowflakeDbConnection()) -{ - conn.ConnectionString = connectionString; - conn.Open(); - - IDbCommand cmd = conn.CreateCommand(); - cmd.CommandText = "create or replace table T(cola int)"; - int count = cmd.ExecuteNonQuery(); - Assert.AreEqual(0, count); - - IDbCommand cmd = conn.CreateCommand(); - cmd.CommandText = "insert into t values (?), (?), (?)"; - - var p1 = cmd.CreateParameter(); - p1.ParameterName = "1"; - p1.Value = 10; - p1.DbType = DbType.Int32; - cmd.Parameters.Add(p1); - - var p2 = cmd.CreateParameter(); - p2.ParameterName = "2"; - p2.Value = 10000L; - p2.DbType = DbType.Int32; - cmd.Parameters.Add(p2); - - var p3 = cmd.CreateParameter(); - p3.ParameterName = "3"; - p3.Value = (short)1; - p3.DbType = DbType.Int16; - cmd.Parameters.Add(p3); - - var count = cmd.ExecuteNonQuery(); - Assert.AreEqual(3, count); - - cmd.CommandText = "drop table if exists T"; - count = cmd.ExecuteNonQuery(); - Assert.AreEqual(0, count); - - conn.Close(); -} -``` - -## Bind Array Variables - -The sample code creates a table with a single integer column and then uses array binding to populate the table with values 0 to 70000. - -```cs -using (IDbConnection conn = new SnowflakeDbConnection()) -{ - conn.ConnectionString = ConnectionString; - conn.Open(); - - using (IDbCommand cmd = conn.CreateCommand()) - { - cmd.CommandText = "create or replace table putArrayBind(colA integer)"; - cmd.ExecuteNonQuery(); - - string insertCommand = "insert into putArrayBind values (?)"; - cmd.CommandText = insertCommand; - - int total = 70000; - - List arrint = new List(); - for (int i = 0; i < total; i++) - { - arrint.Add(i); - } - var p1 = cmd.CreateParameter(); - p1.ParameterName = "1"; - p1.DbType = DbType.Int16; - p1.Value = arrint.ToArray(); - cmd.Parameters.Add(p1); - - count = cmd.ExecuteNonQuery(); // count = 70000 - } - - conn.Close(); -} -``` - -## PUT local files to stage - -PUT command can be used to upload files of a local directory or a single local file to the Snowflake stages (named, internal table stage or internal user stage). -Such staging files can be used to load data into a table. -More on this topic: [File staging with PUT](https://docs.snowflake.com/en/sql-reference/sql/put). - -In the driver the command can be executed in a bellow way: - -```cs -using (IDbConnection conn = new SnowflakeDbConnection()) -{ - try - { - conn.ConnectionString = ""; - conn.Open(); - var cmd = (SnowflakeDbCommand)conn.CreateCommand(); // cast allows get QueryId from the command - - cmd.CommandText = "PUT file://some_data.csv @my_schema.my_stage AUTO_COMPRESS=TRUE"; - var reader = cmd.ExecuteReader(); - Assert.IsTrue(reader.read()); - Assert.DoesNotThrow(() => Guid.Parse(cmd.GetQueryId())); - } - catch (SnowflakeDbException e) - { - Assert.DoesNotThrow(() => Guid.Parse(e.QueryId)); // when failed - Assert.That(e.InnerException.GetType(), Is.EqualTo(typeof(FileNotFoundException))); - } -``` - -In case of a failure a SnowflakeDbException exception will be thrown with affected QueryId if possible. -If it was after the query got executed this exception will be a SnowflakeDbException containing affected QueryId. -In case of the initial phase of execution QueryId might not be provided. -Inner exception (if applicable) will provide some details on the failure cause and -it will be for example: FileNotFoundException, DirectoryNotFoundException. - -## GET stage files - -GET command allows to download stage directories or files to a local directory. -It can be used in connection with named stage, table internal stage or user stage. -Detailed information on the command: [Downloading files with GET](https://docs.snowflake.com/en/sql-reference/sql/get). - -To use the command in a driver similar code can be executed in a client app: - -```cs - try - { - conn.ConnectionString = ""; - conn.Open(); - var cmd = (SnowflakeDbCommand)conn.CreateCommand(); // cast allows get QueryId from the command - - cmd.CommandText = "GET @my_schema.my_stage/stage_file.csv file://local_file.csv AUTO_COMPRESS=TRUE"; - var reader = cmd.ExecuteReader(); - Assert.IsTrue(reader.read()); // True on success, False if failure - Assert.DoesNotThrow(() => Guid.Parse(cmd.GetQueryId())); - } - catch (SnowflakeDbException e) - { - Assert.DoesNotThrow(() => Guid.Parse(e.QueryId)); // on failure - } -``` - -In case of a failure a SnowflakeDbException will be thrown with affected QueryId if possible. -When no technical or syntax errors occurred but the DBDataReader has no data to process it returns False -without throwing an exception. - -## Close the Connection - -To close the connection, call the `Close` method of `SnowflakeDbConnection`. - -If you want to avoid blocking threads while the connection is closing, call the `CloseAsync` method instead, passing in a -`CancellationToken`. This method was introduced in the v2.0.4 release. - -Note that because this method is not available in the generic `IDbConnection` interface, you must cast the object as -`SnowflakeDbConnection` before calling the method. For example: - -```cs -CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(); -// Close the connection -((SnowflakeDbConnection)conn).CloseAsync(cancellationTokenSource.Token); -``` +Using stage files within PUT/GET commands: +[PUT and GET Files to/from Stage](doc/StageFiles.md) ## Logging -The Snowflake Connector for .NET uses [log4net](http://logging.apache.org/log4net/) as the logging framework. +Logging description and configuration: +[Logging and Easy Logging](doc/Logging.md) -Here is a sample app.config file that uses [log4net](http://logging.apache.org/log4net/) - -```xml - -
- - - - - - - - - - - - - - - - - - - - - -``` - -## Easy logging - -The Easy Logging feature lets you change the log level for all driver classes and add an extra file appender for logs from the driver's classes at runtime. You can specify the log levels and the directory in which to save log files in a configuration file (default: `sf_client_config.json`). - -You typically change log levels only when debugging your application. - -**Note** -This logging configuration file features support only the following log levels: - -- OFF -- ERROR -- WARNING -- INFO -- DEBUG -- TRACE - -This configuration file uses JSON to define the `log_level` and `log_path` logging parameters, as follows: - -```json -{ - "common": { - "log_level": "INFO", - "log_path": "c:\\some-path\\some-directory" - } -} -``` - -where: - -- `log_level` is the desired logging level. -- `log_path` is the location to store the log files. The driver automatically creates a `dotnet` subdirectory in the specified `log_path`. For example, if you set log_path to `c:\logs`, the drivers creates the `c:\logs\dotnet` directory and stores the logs there. - -The driver looks for the location of the configuration file in the following order: - -- `CLIENT_CONFIG_FILE` connection parameter, containing the full path to the configuration file (e.g. `"ACCOUNT=test;USER=test;PASSWORD=test;CLIENT_CONFIG_FILE=C:\\some-path\\client_config.json;"`) -- `SF_CLIENT_CONFIG_FILE` environment variable, containing the full path to the configuration file. -- .NET driver/application directory, where the file must be named `sf_client_config.json`. -- User’s home directory, where the file must be named `sf_client_config.json`. - -**Note** -To enhance security, the driver no longer searches a temporary directory for easy logging configurations. Additionally, the driver now requires the logging configuration file on Unix-style systems to limit file permissions to allow only the file owner to modify the files (such as `chmod 0600` or `chmod 0644`). - -To minimize the number of searches for a configuration file, the driver reads the file only for: - -- The first connection. -- The first connection with `CLIENT_CONFIG_FILE` parameter. - -The extra logs are stored in a `dotnet` subfolder of the specified directory, such as `C:\some-path\some-directory\dotnet`. - -If a client uses the `log4net` library for application logging, enabling easy logging affects the log level in those logs as well. - -## Getting the code coverage - -1. Go to .NET project directory - -2. Clean the directory - -``` -dotnet clean snowflake-connector-net.sln && dotnet nuget locals all --clear -``` - -3. Create parameters.json containing connection info for AWS, AZURE, or GCP account and place inside the Snowflake.Data.Tests folder - -4. Build the project for .NET6 - -``` -dotnet build snowflake-connector-net.sln /p:DebugType=Full -``` - -5. Run dotnet-cover on the .NET6 build - -``` -dotnet-coverage collect "dotnet test --framework net6.0 --no-build -l console;verbosity=normal" --output net6.0_AWS_coverage.xml --output-format cobertura --settings coverage.config -``` - -6. Build the project for .NET Framework - -``` -msbuild snowflake-connector-net.sln -p:Configuration=Release -``` - -7. Run dotnet-cover on the .NET Framework build - -``` -dotnet-coverage collect "dotnet test --framework net472 --no-build -l console;verbosity=normal" --output net472_AWS_coverage.xml --output-format cobertura --settings coverage.config -``` - -
-Repeat steps 3, 5, and 7 for the other cloud providers.
-Note: no need to rebuild the connector again.

- -For Azure:
- -3. Create parameters.json containing connection info for AZURE account and place inside the Snowflake.Data.Tests folder - -4. Run dotnet-cover on the .NET6 build - -``` -dotnet-coverage collect "dotnet test --framework net6.0 --no-build -l console;verbosity=normal" --output net6.0_AZURE_coverage.xml --output-format cobertura --settings coverage.config -``` - -7. Run dotnet-cover on the .NET Framework build - -``` -dotnet-coverage collect "dotnet test --framework net472 --no-build -l console;verbosity=normal" --output net472_AZURE_coverage.xml --output-format cobertura --settings coverage.config -``` - -
-For GCP:
- -3. Create parameters.json containing connection info for GCP account and place inside the Snowflake.Data.Tests folder - -4. Run dotnet-cover on the .NET6 build - -``` -dotnet-coverage collect "dotnet test --framework net6.0 --no-build -l console;verbosity=normal" --output net6.0_GCP_coverage.xml --output-format cobertura --settings coverage.config -``` - -7. Run dotnet-cover on the .NET Framework build - -``` -dotnet-coverage collect "dotnet test --framework net472 --no-build -l console;verbosity=normal" --output net472_GCP_coverage.xml --output-format cobertura --settings coverage.config -``` +--------------- ## Notice @@ -1027,4 +139,14 @@ dotnet-coverage collect "dotnet test --framework net472 --no-build -l console;ve Snowflake has identified an issue where the driver is globally enforcing TLS 1.2 and certificate revocation checks with the .NET Driver v1.2.1 and earlier versions. Starting with v2.0.0, the driver will set these locally. +4. Certificate Revocation List not performed where insecureMode was disabled - + Snowflake has identified vulnerability where the checks against the Certificate Revocation List (CRL) + were not performed where the insecureMode flag was set to false, which is the default setting. + From version v2.1.5 CRL is working back as intended. + Note that the driver is now targeting .NET 6.0. When upgrading, you might also need to run “Update-Package -reinstall” to update the dependencies. + +See more: +* [Security Policy](SECURITY.md) +* [Security Advisories](/security/advisories) + diff --git a/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsAsyncIT.cs b/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsAsyncIT.cs new file mode 100644 index 000000000..aa5d431ed --- /dev/null +++ b/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsAsyncIT.cs @@ -0,0 +1,197 @@ +using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; +using Moq; +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core; +using Snowflake.Data.Core.Session; +using Snowflake.Data.Log; +using Snowflake.Data.Tests.Mock; +using Snowflake.Data.Tests.Util; + +namespace Snowflake.Data.Tests.IntegrationTests +{ + [TestFixture] + [NonParallelizable] + public class ConnectionMultiplePoolsAsyncIT: SFBaseTestAsync + { + private readonly PoolConfig _previousPoolConfig = new PoolConfig(); + private readonly SFLogger logger = SFLoggerFactory.GetLogger(); + + [SetUp] + public new void BeforeTest() + { + SnowflakeDbConnectionPool.SetConnectionPoolVersion(ConnectionPoolType.MultipleConnectionPool); + SnowflakeDbConnectionPool.ClearAllPools(); + } + + [TearDown] + public new void AfterTest() + { + _previousPoolConfig.Reset(); + } + + [Test] + public async Task TestAddToPoolOnOpenAsync() + { + // arrange + var connection = new SnowflakeDbConnection(ConnectionString + "minPoolSize=1"); + + // act + await connection.OpenAsync().ConfigureAwait(false); + + // assert + var pool = SnowflakeDbConnectionPool.GetPool(connection.ConnectionString); + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + + // cleanup + await connection.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + + [Test] + public async Task TestFailForInvalidConnectionAsync() + { + // arrange + var invalidConnectionString = ";connection_timeout=123"; + var connection = new SnowflakeDbConnection(invalidConnectionString); + + // act + try + { + await connection.OpenAsync().ConfigureAwait(false); + Assert.Fail("OpenAsync should fail for invalid connection string"); + } + catch {} + var thrown = Assert.Throws(() => SnowflakeDbConnectionPool.GetPool(connection.ConnectionString)); + + // assert + Assert.That(thrown.Message, Does.Contain("Required property ACCOUNT is not provided")); + } + + [Test] + public void TestConnectionPoolWithInvalidOpenAsync() + { + // make the connection string unique so it won't pick up connection + // pooled by other test cases. + string connStr = ConnectionString + "minPoolSize=0;maxPoolSize=10;application=conn_pool_test_invalid_openasync2"; + using (var connection = new SnowflakeDbConnection()) + { + connection.ConnectionString = connStr; + // call openAsync but do not wait and destroy it direct + // so the session is initialized with empty token + connection.OpenAsync(); + } + + // use the same connection string to make a new connection + // to ensure the invalid connection made previously is not pooled + using (var connection1 = new SnowflakeDbConnection()) + { + connection1.ConnectionString = connStr; + // this will not open a new session but get the invalid connection from pool + connection1.Open(); + // Now run query with connection1 + var command = connection1.CreateCommand(); + command.CommandText = "select 1, 2, 3"; + + try + { + using (var reader = command.ExecuteReader()) + { + while (reader.Read()) + { + for (int i = 0; i < reader.FieldCount; i++) + { + // Process each column as appropriate + reader.GetFieldValue(i); + } + } + } + } + catch (SnowflakeDbException) + { + // fail the test case if anything wrong. + Assert.Fail(); + } + } + } + + [Test] + public async Task TestMinPoolSizeAsync() + { + // arrange + var connection = new SnowflakeDbConnection(); + connection.ConnectionString = ConnectionString + "application=TestMinPoolSizeAsync;minPoolSize=3"; + + // act + await connection.OpenAsync().ConfigureAwait(false); + Thread.Sleep(3000); + + // assert + var pool = SnowflakeDbConnectionPool.GetPool(connection.ConnectionString); + Assert.AreEqual(3, pool.GetCurrentPoolSize()); + + // cleanup + await connection.CloseAsync(CancellationToken.None).ConfigureAwait(false); + } + + [Test] + public async Task TestPreventConnectionFromReturningToPool() + { + // arrange + var connectionString = ConnectionString + "minPoolSize=0"; + var connection = new SnowflakeDbConnection(connectionString); + await connection.OpenAsync().ConfigureAwait(false); + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + + // act + connection.PreventPooling(); + await connection.CloseAsync(CancellationToken.None).ConfigureAwait(false); + + // assert + Assert.AreEqual(0, pool.GetCurrentPoolSize()); + } + + [Test] + public async Task TestReleaseConnectionWhenRollbackFailsAsync() + { + // arrange + var connectionString = ConnectionString + "minPoolSize=0"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + var commandThrowingExceptionOnlyForRollback = MockHelper.CommandThrowingExceptionOnlyForRollback(); + var mockDbProviderFactory = new Mock(); + mockDbProviderFactory.Setup(p => p.CreateCommand()).Returns(commandThrowingExceptionOnlyForRollback.Object); + Assert.AreEqual(0, pool.GetCurrentPoolSize()); + var connection = new TestSnowflakeDbConnection(mockDbProviderFactory.Object); + connection.ConnectionString = connectionString; + await connection.OpenAsync().ConfigureAwait(false); + connection.BeginTransaction(); // not using async version because it is not available on .net framework + Assert.AreEqual(true, connection.HasActiveExplicitTransaction()); + + // act + await connection.CloseAsync(CancellationToken.None).ConfigureAwait(false); + + // assert + Assert.AreEqual(0, pool.GetCurrentPoolSize(), "Should not return connection to the pool"); + } + + [Test(Description = "test connection pooling with concurrent connection using async calls")] + public void TestConcurrentConnectionPoolingAsync() + { + // add test case name in connection string to make in unique for each test case + // set short expiration timeout to cover the case that connection expired + string connStr = ConnectionString + ";application=TestConcurrentConnectionPoolingAsync2;ExpirationTimeout=3"; + ConnectionSinglePoolCacheAsyncIT.ConcurrentPoolingAsyncHelper(connStr, true, 7, 100, 2); + } + + [Test(Description = "test connection pooling with concurrent connection and using async calls no close call for connection. Connection is closed when Dispose() is called by framework.")] + public void TestConcurrentConnectionPoolingDisposeAsync() + { + // add test case name in connection string to make in unique for each test case + // set short expiration timeout to cover the case that connection expired + string connStr = ConnectionString + ";application=TestConcurrentConnectionPoolingDisposeAsync2;ExpirationTimeout=3"; + ConnectionSinglePoolCacheAsyncIT.ConcurrentPoolingAsyncHelper(connStr, false, 7, 100, 2); + } + } +} diff --git a/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsIT.cs b/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsIT.cs new file mode 100644 index 000000000..4b7ec61f0 --- /dev/null +++ b/Snowflake.Data.Tests/IntegrationTests/ConnectionMultiplePoolsIT.cs @@ -0,0 +1,444 @@ +using System; +using System.Data; +using System.Data.Common; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Moq; +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Session; +using Snowflake.Data.Tests.Mock; +using Snowflake.Data.Tests.Util; + +namespace Snowflake.Data.Tests.IntegrationTests +{ + [TestFixture] + [NonParallelizable] + public class ConnectionMultiplePoolsIT: SFBaseTest + { + private readonly PoolConfig _previousPoolConfig = new PoolConfig(); + + [SetUp] + public new void BeforeTest() + { + SnowflakeDbConnectionPool.SetConnectionPoolVersion(ConnectionPoolType.MultipleConnectionPool); + SnowflakeDbConnectionPool.ClearAllPools(); + } + + [TearDown] + public new void AfterTest() + { + _previousPoolConfig.Reset(); + } + + [OneTimeTearDown] + public static void AfterAllTests() + { + SnowflakeDbConnectionPool.ClearAllPools(); + } + + [Test] + public void TestBasicConnectionPool() + { + var connectionString = ConnectionString + "minPoolSize=0;maxPoolSize=1"; + var conn1 = new SnowflakeDbConnection(connectionString); + conn1.Open(); + Assert.AreEqual(ConnectionState.Open, conn1.State); + conn1.Close(); + + // assert + Assert.AreEqual(ConnectionState.Closed, conn1.State); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(connectionString).GetCurrentPoolSize()); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(connectionString, null).GetCurrentPoolSize()); + } + + [Test] + public void TestReuseSessionInConnectionPool() // old name: TestConnectionPool + { + var connectionString = ConnectionString + "minPoolSize=1"; + var conn1 = new SnowflakeDbConnection(connectionString); + conn1.Open(); + Assert.AreEqual(ConnectionState.Open, conn1.State); + conn1.Close(); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(connectionString).GetCurrentPoolSize()); + + var conn2 = new SnowflakeDbConnection(); + conn2.ConnectionString = connectionString; + conn2.Open(); + Assert.AreEqual(ConnectionState.Open, conn2.State); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(connectionString).GetCurrentPoolSize()); + + conn2.Close(); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(connectionString).GetCurrentPoolSize()); + Assert.AreEqual(ConnectionState.Closed, conn1.State); + Assert.AreEqual(ConnectionState.Closed, conn2.State); + } + + [Test] + public void TestReuseSessionInConnectionPoolReachingMaxConnections() // old name: TestConnectionPoolFull + { + var connectionString = ConnectionString + "maxPoolSize=2;minPoolSize=1"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + + var conn1 = new SnowflakeDbConnection(); + conn1.ConnectionString = connectionString; + conn1.Open(); + Assert.AreEqual(ConnectionState.Open, conn1.State); + + var conn2 = new SnowflakeDbConnection(); + conn2.ConnectionString = connectionString; + conn2.Open(); + Assert.AreEqual(ConnectionState.Open, conn2.State); + + Assert.AreEqual(2, pool.GetCurrentPoolSize()); + conn1.Close(); + conn2.Close(); + Assert.AreEqual(2, pool.GetCurrentPoolSize()); + + var conn3 = new SnowflakeDbConnection(); + conn3.ConnectionString = connectionString; + conn3.Open(); + Assert.AreEqual(ConnectionState.Open, conn3.State); + + var conn4 = new SnowflakeDbConnection(); + conn4.ConnectionString = connectionString; + conn4.Open(); + Assert.AreEqual(ConnectionState.Open, conn4.State); + + conn3.Close(); + Assert.AreEqual(2, pool.GetCurrentPoolSize()); + conn4.Close(); + Assert.AreEqual(2, pool.GetCurrentPoolSize()); + + Assert.AreEqual(ConnectionState.Closed, conn1.State); + Assert.AreEqual(ConnectionState.Closed, conn2.State); + Assert.AreEqual(ConnectionState.Closed, conn3.State); + Assert.AreEqual(ConnectionState.Closed, conn4.State); + } + + [Test] + public void TestWaitForTheIdleConnectionWhenExceedingMaxConnectionsLimit() + { + // arrange + var connectionString = ConnectionString + "application=TestWaitForMaxSize1;waitingForIdleSessionTimeout=1s;maxPoolSize=2;minPoolSize=1"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + Assert.AreEqual(0, pool.GetCurrentPoolSize(), "expecting pool to be empty"); + var conn1 = OpenConnection(connectionString); + var conn2 = OpenConnection(connectionString); + var watch = new StopWatch(); + + // act + watch.Start(); + var thrown = Assert.Throws(() => OpenConnection(connectionString)); + watch.Stop(); + + // assert + Assert.That(thrown.Message, Does.Contain("Unable to connect. Could not obtain a connection from the pool within a given timeout")); + Assert.That(watch.ElapsedMilliseconds, Is.InRange(1000, 1500)); + Assert.AreEqual(pool.GetCurrentPoolSize(), 2); + + // cleanup + conn1.Close(); + conn2.Close(); + } + + [Test] + public void TestWaitForTheIdleConnectionWhenExceedingMaxConnectionsLimitAsync() + { + // arrange + var connectionString = ConnectionString + "application=TestWaitForMaxSize2;waitingForIdleSessionTimeout=1s;maxPoolSize=2;minPoolSize=1"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + Assert.AreEqual(0, pool.GetCurrentPoolSize(), "expecting pool to be empty"); + var conn1 = OpenConnection(connectionString); + var conn2 = OpenConnection(connectionString); + var watch = new StopWatch(); + + // act + watch.Start(); + var thrown = Assert.ThrowsAsync(() => OpenConnectionAsync(connectionString)); + watch.Stop(); + + // assert + Assert.That(thrown.Message, Does.Contain("Unable to connect")); + Assert.IsTrue(thrown.InnerException is AggregateException); + var nestedException = ((AggregateException)thrown.InnerException).InnerException; + Assert.That(nestedException.Message, Does.Contain("Could not obtain a connection from the pool within a given timeout")); + Assert.That(watch.ElapsedMilliseconds, Is.InRange(1000, 1500)); + Assert.AreEqual(pool.GetCurrentPoolSize(), 2); + + // cleanup + conn1.Close(); + conn2.Close(); + } + + [Test] + [Retry(2)] + public void TestWaitInAQueueForAnIdleSession() + { + // arrange + var connectionString = ConnectionString + "application=TestWaitForMaxSize3;waitingForIdleSessionTimeout=3s;maxPoolSize=2;minPoolSize=0"; + var pool = SnowflakeDbConnectionPool.GetPoolInternal(connectionString); + Assert.AreEqual(0, pool.GetCurrentPoolSize(), "the pool is expected to be empty"); + const long ADelay = 0; + const long BDelay = 400; + const long CDelay = 2 * BDelay; + const long DDelay = 3 * BDelay; + const long ABDelayAfterConnect = 2000; + const long ConnectPessimisticEstimate = 1300; + const long StartDelayPessimisticEstimate = 350; + const long AMinConnectionReleaseTime = ADelay + ABDelayAfterConnect; // 2000 + const long AMaxConnectionReleaseTime = ADelay + StartDelayPessimisticEstimate + ConnectPessimisticEstimate + ABDelayAfterConnect; // 3650 + const long BMinConnectionReleaseTime = BDelay + ABDelayAfterConnect; // 2400 + const long BMaxConnectionReleaseTime = BDelay + StartDelayPessimisticEstimate + ConnectPessimisticEstimate + ABDelayAfterConnect; // 4050 + const long CMinConnectDuration = AMinConnectionReleaseTime - CDelay - StartDelayPessimisticEstimate; // 2000 - 800 - 350 = 850 + const long CMaxConnectDuration = AMaxConnectionReleaseTime - CDelay; // 3650 - 800 = 2850 + const long DMinConnectDuration = BMinConnectionReleaseTime - DDelay - StartDelayPessimisticEstimate; // 2400 - 1200 - 350 = 850 + const long DMaxConnectDuration = BMaxConnectionReleaseTime - DDelay; // 3650 - 800 = 2850 + const long MeasurementTolerance = 25; + + var threads = new ConnectingThreads(connectionString) + .NewThread("A", ADelay, ABDelayAfterConnect, true) + .NewThread("B", BDelay, ABDelayAfterConnect, true) + .NewThread("C", CDelay, 0, true) + .NewThread("D", DDelay, 0, true); + pool.SetSessionPoolEventHandler(new SessionPoolThreadEventHandler(threads)); + + // act + threads.StartAll().JoinAll(); + + // assert + var events = threads.Events().ToList(); + Assert.AreEqual(6, events.Count); // A,B - connected; C,D - waiting, connected + var waitingEvents = events.Where(e => e.IsWaitingEvent()).ToList(); + Assert.AreEqual(2, waitingEvents.Count); + CollectionAssert.AreEquivalent(new[] { "C", "D" }, waitingEvents.Select(e => e.ThreadName)); // equivalent = in any order + var connectedEvents = events.Where(e => e.IsConnectedEvent()).ToList(); + Assert.AreEqual(4, connectedEvents.Count); + var firstConnectedEventsGroup = connectedEvents.GetRange(0, 2); + CollectionAssert.AreEquivalent(new[] { "A", "B" }, firstConnectedEventsGroup.Select(e => e.ThreadName)); + var lastConnectingEventsGroup = connectedEvents.GetRange(2, 2); + CollectionAssert.AreEquivalent(new[] { "C", "D" }, lastConnectingEventsGroup.Select(e => e.ThreadName)); + Assert.LessOrEqual(firstConnectedEventsGroup[0].Duration, ConnectPessimisticEstimate); + Assert.LessOrEqual(firstConnectedEventsGroup[1].Duration, ConnectPessimisticEstimate); + // first to wait from C and D should first to connect, because we won't create a new session, we just reuse sessions returned by A and B threads + Assert.AreEqual(waitingEvents[0].ThreadName, lastConnectingEventsGroup[0].ThreadName); + Assert.AreEqual(waitingEvents[1].ThreadName, lastConnectingEventsGroup[1].ThreadName); + Assert.That(lastConnectingEventsGroup[0].Duration, Is.InRange(CMinConnectDuration - MeasurementTolerance, CMaxConnectDuration)); + Assert.That(lastConnectingEventsGroup[1].Duration, Is.InRange(DMinConnectDuration - MeasurementTolerance, DMaxConnectDuration)); + } + + [Test] + public void TestBusyAndIdleConnectionsCountedInPoolSize() + { + // arrange + var connectionString = ConnectionString + "maxPoolSize=2;minPoolSize=1"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + var connection = new SnowflakeDbConnection(); + connection.ConnectionString = connectionString; + + // act + connection.Open(); + + // assert + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + + // act + connection.Close(); + + // assert + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + } + + [Test] + public void TestConnectionPoolNotPossibleToDisableForAllPools() + { + // act + var thrown = Assert.Throws(() => SnowflakeDbConnectionPool.SetPooling(false)); + + // assert + Assert.IsNotNull(thrown); + } + + [Test] + public void TestConnectionPoolDisable() + { + // arrange + var pool = SnowflakeDbConnectionPool.GetPool(ConnectionString + ";poolingEnabled=false"); + var conn1 = new SnowflakeDbConnection(); + conn1.ConnectionString = ConnectionString; + + // act + conn1.Open(); + + // assert + Assert.AreEqual(ConnectionState.Open, conn1.State); + Assert.AreEqual(0, pool.GetCurrentPoolSize()); + + // act + conn1.Close(); + + // assert + Assert.AreEqual(ConnectionState.Closed, conn1.State); + Assert.AreEqual(0, pool.GetCurrentPoolSize()); + } + + [Test] + public void TestNewConnectionPoolClean() + { + var connectionString = ConnectionString + "maxPoolSize=2;minPoolSize=1;"; + var conn1 = new SnowflakeDbConnection(); + conn1.ConnectionString = connectionString; + conn1.Open(); + Assert.AreEqual(ConnectionState.Open, conn1.State); + + var conn2 = new SnowflakeDbConnection(); + conn2.ConnectionString = connectionString + "retryCount=1"; + conn2.Open(); + Assert.AreEqual(ConnectionState.Open, conn2.State); + + var conn3 = new SnowflakeDbConnection(); + conn3.ConnectionString = connectionString + "retryCount=2"; + conn3.Open(); + Assert.AreEqual(ConnectionState.Open, conn3.State); + + conn1.Close(); + conn2.Close(); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(conn1.ConnectionString).GetCurrentPoolSize()); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(conn2.ConnectionString).GetCurrentPoolSize()); + SnowflakeDbConnectionPool.ClearAllPools(); + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetPool(conn1.ConnectionString).GetCurrentPoolSize()); + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetPool(conn2.ConnectionString).GetCurrentPoolSize()); + conn3.Close(); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(conn3.ConnectionString).GetCurrentPoolSize()); + + Assert.AreEqual(ConnectionState.Closed, conn1.State); + Assert.AreEqual(ConnectionState.Closed, conn2.State); + Assert.AreEqual(ConnectionState.Closed, conn3.State); + } + + [Test] + public void TestConnectionPoolExpirationWorks() + { + // arrange + const int ExpirationTimeoutInSeconds = 10; + var connectionString = ConnectionString + $"expirationTimeout={ExpirationTimeoutInSeconds};maxPoolSize=4;minPoolSize=2"; + var pool = SnowflakeDbConnectionPool.GetPoolInternal(connectionString); + Assert.AreEqual(0, pool.GetCurrentPoolSize()); + + // act + var conn1 = OpenConnection(connectionString); + var conn2 = OpenConnection(connectionString); + var conn3 = OpenConnection(connectionString); + var conn4 = OpenConnection(connectionString); + + // assert + Assert.AreEqual(4, pool.GetCurrentPoolSize()); + + // act + WaitUntilAllSessionsCreatedOrTimeout(pool); + var beforeSleepMillis = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + Thread.Sleep(TimeSpan.FromSeconds(ExpirationTimeoutInSeconds)); + conn1.Close(); + conn2.Close(); + conn3.Close(); + conn4.Close(); + + // assert + Assert.AreEqual(2, pool.GetCurrentPoolSize()); // 2 idle sessions, but expired because close doesn't remove expired sessions + + // act + WaitUntilAllSessionsCreatedOrTimeout(pool); + var conn5 = OpenConnection(connectionString); + WaitUntilAllSessionsCreatedOrTimeout(pool); + + // assert + Assert.AreEqual(2, pool.GetCurrentPoolSize()); // 1 idle session and 1 busy + var sessionStartTimes = pool.GetIdleSessionsStartTimes(); + Assert.AreEqual(1, sessionStartTimes.Count); + Assert.That(sessionStartTimes.First(), Is.GreaterThan(beforeSleepMillis)); + Assert.That(conn5.SfSession.GetStartTime(), Is.GreaterThan(beforeSleepMillis)); + } + + [Test] + public void TestMinPoolSize() + { + // arrange + var connection = new SnowflakeDbConnection(); + connection.ConnectionString = ConnectionString + "application=TestMinPoolSize;minPoolSize=3"; + + // act + connection.Open(); + Thread.Sleep(3000); + + // assert + var pool = SnowflakeDbConnectionPool.GetPool(connection.ConnectionString); + Assert.AreEqual(3, pool.GetCurrentPoolSize()); + + // cleanup + connection.Close(); + } + + [Test] + public void TestPreventConnectionFromReturningToPool() + { + // arrange + var connectionString = ConnectionString + "minPoolSize=0"; + var connection = OpenConnection(connectionString); + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + + // act + connection.PreventPooling(); + connection.Close(); + + // assert + Assert.AreEqual(0, pool.GetCurrentPoolSize()); + } + + [Test] + public void TestReleaseConnectionWhenRollbackFails() + { + // arrange + var connectionString = ConnectionString + "minPoolSize=0"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + var commandThrowingExceptionOnlyForRollback = MockHelper.CommandThrowingExceptionOnlyForRollback(); + var mockDbProviderFactory = new Mock(); + mockDbProviderFactory.Setup(p => p.CreateCommand()).Returns(commandThrowingExceptionOnlyForRollback.Object); + Assert.AreEqual(0, pool.GetCurrentPoolSize()); + var connection = new TestSnowflakeDbConnection(mockDbProviderFactory.Object); + connection.ConnectionString = connectionString; + connection.Open(); + connection.BeginTransaction(); + Assert.AreEqual(true, connection.HasActiveExplicitTransaction()); + + // act + connection.Close(); + + // assert + Assert.AreEqual(0, pool.GetCurrentPoolSize(), "Should not return connection to the pool"); + } + + private void WaitUntilAllSessionsCreatedOrTimeout(SessionPool pool) + { + var expectingToWaitAtMostForSessionCreations = TimeSpan.FromSeconds(15); + Awaiter.WaitUntilConditionOrTimeout(() => pool.OngoingSessionCreationsCount() == 0, expectingToWaitAtMostForSessionCreations); + } + + private SnowflakeDbConnection OpenConnection(string connectionString) + { + var connection = new SnowflakeDbConnection(); + connection.ConnectionString = connectionString; + connection.Open(); + return connection; + } + + private async Task OpenConnectionAsync(string connectionString) + { + var connection = new SnowflakeDbConnection(); + connection.ConnectionString = connectionString; + await connection.OpenAsync().ConfigureAwait(false); + return connection; + } + } +} diff --git a/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolChangedSessionIT.cs b/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolChangedSessionIT.cs new file mode 100644 index 000000000..801916cb0 --- /dev/null +++ b/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolChangedSessionIT.cs @@ -0,0 +1,216 @@ +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core; +using Snowflake.Data.Core.Session; +using Snowflake.Data.Tests.Util; + +namespace Snowflake.Data.Tests.IntegrationTests +{ + [TestFixture] + [NonParallelizable] + public class ConnectionPoolChangedSessionIT : SFBaseTest + { + private readonly QueryExecResponseData _queryExecResponseChangedRole = new() + { + finalDatabaseName = TestEnvironment.TestConfig.database, + finalSchemaName = TestEnvironment.TestConfig.schema, + finalRoleName = "role change", + finalWarehouseName = TestEnvironment.TestConfig.warehouse + }; + + private readonly QueryExecResponseData _queryExecResponseChangedDatabase = new() + { + finalDatabaseName = "database changed", + finalSchemaName = TestEnvironment.TestConfig.schema, + finalRoleName = TestEnvironment.TestConfig.role, + finalWarehouseName = TestEnvironment.TestConfig.warehouse + }; + + private readonly QueryExecResponseData _queryExecResponseChangedSchema = new() + { + finalDatabaseName = TestEnvironment.TestConfig.database, + finalSchemaName = "schema changed", + finalRoleName = TestEnvironment.TestConfig.role, + finalWarehouseName = TestEnvironment.TestConfig.warehouse + }; + + private readonly QueryExecResponseData _queryExecResponseChangedWarehouse = new() + { + finalDatabaseName = TestEnvironment.TestConfig.database, + finalSchemaName = TestEnvironment.TestConfig.schema, + finalRoleName = TestEnvironment.TestConfig.role, + finalWarehouseName = "warehouse changed" + }; + + private static PoolConfig s_previousPoolConfigRestorer; + + [OneTimeSetUp] + public static void BeforeAllTests() + { + s_previousPoolConfigRestorer = new PoolConfig(); + SnowflakeDbConnectionPool.SetConnectionPoolVersion(ConnectionPoolType.MultipleConnectionPool); + } + + [SetUp] + public new void BeforeTest() + { + SnowflakeDbConnectionPool.ClearAllPools(); + } + + [TearDown] + public new void AfterTest() + { + SnowflakeDbConnectionPool.ClearAllPools(); + } + + [OneTimeTearDown] + public static void AfterAllTests() + { + s_previousPoolConfigRestorer.Reset(); + } + + [Test] + public void TestPoolDestroysConnectionWhenChangedSessionProperties() + { + var connectionString = ConnectionString + "application=Destroy;ChangedSession=Destroy;minPoolSize=0;maxPoolSize=3"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + connection.SfSession.UpdateSessionProperties(_queryExecResponseChangedDatabase); + connection.Close(); + + Assert.AreEqual(0, pool.GetCurrentPoolSize()); + } + + [Test] + public void TestPoolingWhenSessionPropertiesUnchanged() + { + var connectionString = ConnectionString + "application=NoSessionChanges;ChangedSession=Destroy;minPoolSize=0;maxPoolSize=3"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + connection.Close(); + + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + } + + [Test] + public void TestPoolingWhenConnectionPropertiesChangedForOriginalPoolMode() + { + var connectionString = ConnectionString + "application=OriginalPoolMode;ChangedSession=OriginalPool;minPoolSize=0;maxPoolSize=3"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + connection.SfSession.UpdateSessionProperties(_queryExecResponseChangedWarehouse); + var sessionId = connection.SfSession.sessionId; + connection.Close(); + + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + connection.Close(); + + var connection2 = new SnowflakeDbConnection(connectionString); + connection2.Open(); + Assert.AreEqual(sessionId, connection2.SfSession.sessionId); + connection2.Close(); + } + + [Test] + public void TestPoolingWhenConnectionPropertiesChangedForDefaultPoolMode() + { + var connectionString = ConnectionString + "application=DefaultPoolMode;minPoolSize=0;maxPoolSize=3"; + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + connection.SfSession.UpdateSessionProperties(_queryExecResponseChangedRole); + var sessionId = connection.SfSession.sessionId; + connection.Close(); + + Assert.AreEqual(0, pool.GetCurrentPoolSize()); + + var connection2 = new SnowflakeDbConnection(connectionString); + connection2.Open(); + Assert.AreNotEqual(sessionId, connection2.SfSession.sessionId); + connection2.Close(); + } + + [Test] + public void TestPoolDestroysAndRecreatesConnection() + { + var connectionString = ConnectionString + "application=DestroyRecreateSession;ChangedSession=Destroy;minPoolSize=1;maxPoolSize=3"; + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + var sessionId = connection.SfSession.sessionId; + connection.SfSession.UpdateSessionProperties(_queryExecResponseChangedSchema); + connection.Close(); + + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + + var connection2 = new SnowflakeDbConnection(connectionString); + connection2.Open(); + Assert.AreNotEqual(sessionId, connection2.SfSession.sessionId); + connection2.Close(); + } + + [Test] + public void TestCompareSessionChangesCaseInsensitiveWhenUnquoted() + { + var connectionString = ConnectionString + "application=CompareCaseInsensitive;ChangedSession=Destroy;minPoolSize=1;maxPoolSize=3"; + + var responseData = new QueryExecResponseData() + { + finalDatabaseName = TestEnvironment.TestConfig.database.ToLower(), + finalSchemaName = TestEnvironment.TestConfig.schema.ToUpper(), + finalRoleName = $"{char.ToUpper(TestEnvironment.TestConfig.role[0])}{TestEnvironment.TestConfig.role.Substring(1).ToLower()}", + finalWarehouseName = TestEnvironment.TestConfig.warehouse.ToLower() + }; + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + var sessionId = connection.SfSession.sessionId; + connection.SfSession.UpdateSessionProperties(responseData); + connection.Close(); + + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + + var connection2 = new SnowflakeDbConnection(connectionString); + connection2.Open(); + Assert.AreEqual(sessionId, connection2.SfSession.sessionId); + connection2.Close(); + } + + [Test] + public void TestCompareSessionChangesCaseSensitiveWhenQuoted() + { + var connectionString = ConnectionString + "application=CompareCaseSensitive;ChangedSession=Destroy;minPoolSize=1;maxPoolSize=3"; + + var responseData = new QueryExecResponseData() + { + finalDatabaseName = TestEnvironment.TestConfig.database, + finalSchemaName = TestEnvironment.TestConfig.schema, + finalRoleName = $"\\\"SomeQuotedValue\\\"", + finalWarehouseName = TestEnvironment.TestConfig.warehouse.ToLower() + }; + + var connection = new SnowflakeDbConnection(connectionString); + connection.Open(); + var sessionId = connection.SfSession.sessionId; + connection.SfSession.UpdateSessionProperties(responseData); + connection.Close(); + + var pool = SnowflakeDbConnectionPool.GetPool(connectionString); + Assert.AreEqual(1, pool.GetCurrentPoolSize()); + + var connection2 = new SnowflakeDbConnection(connectionString); + connection2.Open(); + Assert.AreNotEqual(sessionId, connection2.SfSession.sessionId); + connection2.Close(); + } + } +} diff --git a/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolCommonIT.cs b/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolCommonIT.cs new file mode 100644 index 000000000..6a0745b23 --- /dev/null +++ b/Snowflake.Data.Tests/IntegrationTests/ConnectionPoolCommonIT.cs @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Data; +using System.Threading; +using NUnit.Framework; +using Snowflake.Data.Core; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Session; +using Snowflake.Data.Log; +using Snowflake.Data.Tests.Util; + +namespace Snowflake.Data.Tests.IntegrationTests +{ + [TestFixture(ConnectionPoolType.SingleConnectionCache)] + [TestFixture(ConnectionPoolType.MultipleConnectionPool)] + [NonParallelizable] + class ConnectionPoolCommonIT : SFBaseTest + { + private readonly ConnectionPoolType _connectionPoolTypeUnderTest; + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + private readonly PoolConfig _previousPoolConfig; + + public ConnectionPoolCommonIT(ConnectionPoolType connectionPoolTypeUnderTest) + { + _connectionPoolTypeUnderTest = connectionPoolTypeUnderTest; + _previousPoolConfig = new PoolConfig(); + } + + [SetUp] + public new void BeforeTest() + { + SnowflakeDbConnectionPool.SetConnectionPoolVersion(_connectionPoolTypeUnderTest); + SnowflakeDbConnectionPool.ClearAllPools(); + if (_connectionPoolTypeUnderTest == ConnectionPoolType.SingleConnectionCache) + { + SnowflakeDbConnectionPool.SetPooling(true); + } + s_logger.Debug($"---------------- BeforeTest ---------------------"); + s_logger.Debug($"Testing Pool Type: {SnowflakeDbConnectionPool.GetConnectionPoolVersion()}"); + } + + [TearDown] + public new void AfterTest() + { + _previousPoolConfig.Reset(); + } + + [OneTimeTearDown] + public static void AfterAllTests() + { + SnowflakeDbConnectionPool.ClearAllPools(); + } + + [Test] + public void TestConnectionPoolMultiThreading() + { + Thread t1 = new Thread(() => ThreadProcess1(ConnectionString)); + Thread t2 = new Thread(() => ThreadProcess2(ConnectionString)); + + t1.Start(); + t2.Start(); + + t1.Join(); + t2.Join(); + } + + void ThreadProcess1(string connstr) + { + var conn1 = new SnowflakeDbConnection(); + conn1.ConnectionString = connstr; + conn1.Open(); + Thread.Sleep(1000); + conn1.Close(); + Thread.Sleep(4000); + Assert.AreEqual(ConnectionState.Closed, conn1.State); + } + + void ThreadProcess2(string connstr) + { + var conn1 = new SnowflakeDbConnection(); + conn1.ConnectionString = connstr; + conn1.Open(); + + Thread.Sleep(5000); + SFStatement statement = new SFStatement(conn1.SfSession); + SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false, false); + Assert.AreEqual(true, resultSet.Next()); + Assert.AreEqual("1", resultSet.GetString(0)); + conn1.Close(); + } + + [Test] + public void TestConnectionPoolWithDispose() + { + if (_connectionPoolTypeUnderTest == ConnectionPoolType.SingleConnectionCache) + { + SnowflakeDbConnectionPool.SetMaxPoolSize(1); + } + var conn1 = new SnowflakeDbConnection(); + conn1.ConnectionString = "bad connection string"; + Assert.Throws(() => conn1.Open()); + conn1.Close(); + + Assert.AreEqual(ConnectionState.Closed, conn1.State); + if (_connectionPoolTypeUnderTest == ConnectionPoolType.SingleConnectionCache) + { + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetPool(conn1.ConnectionString).GetCurrentPoolSize()); + } + else + { + var thrown = Assert.Throws(() => SnowflakeDbConnectionPool.GetPool(conn1.ConnectionString)); + Assert.That(thrown.Message, Does.Contain("Connection string is invalid")); + } + } + + [Test] + public void TestFailWhenPreventingFromReturningToPoolNotOpenedConnection() + { + // arrange + var connection = new SnowflakeDbConnection(ConnectionString); + + // act + var thrown = Assert.Throws(() => connection.PreventPooling()); + + // assert + Assert.That(thrown.Message, Does.Contain("Session not yet created for this connection. Unable to prevent the session from pooling")); + } + + [Test] + public void TestRollbackTransactionOnPooledWhenExceptionOccurred() + { + var connectionString = SetPoolWithOneElement(); + object firstOpenedSessionId; + using (var connection = new SnowflakeDbConnection(connectionString)) + { + connection.Open(); + firstOpenedSessionId = connection.SfSession.sessionId; + connection.BeginTransaction(); + Assert.AreEqual(true, connection.HasActiveExplicitTransaction()); + Assert.Throws(() => + { + using (var command = connection.CreateCommand()) + { + command.CommandText = "invalid command will throw exception and leave session with an unfinished transaction"; + command.ExecuteNonQuery(); + } + }); + } + + using (var connectionWithSessionReused = new SnowflakeDbConnection(connectionString)) + { + connectionWithSessionReused.Open(); + + Assert.AreEqual(firstOpenedSessionId, connectionWithSessionReused.SfSession.sessionId); + Assert.AreEqual(false, connectionWithSessionReused.HasActiveExplicitTransaction()); + using (var cmd = connectionWithSessionReused.CreateCommand()) + { + cmd.CommandText = "SELECT CURRENT_TRANSACTION()"; + Assert.AreEqual(DBNull.Value, cmd.ExecuteScalar()); + } + } + + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize(), "Connection should be reused and any pending transaction rolled back before it gets back to the pool"); + } + + [Test] + public void TestTransactionStatusNotTrackedForNonExplicitTransactionCalls() + { + var connectionString = SetPoolWithOneElement(); + using (var connection = new SnowflakeDbConnection(connectionString)) + { + connection.Open(); + using (var command = connection.CreateCommand()) + { + command.CommandText = "BEGIN"; // in general can be put as a part of a multi statement call and mixed with commit as well + command.ExecuteNonQuery(); + Assert.AreEqual(false, connection.HasActiveExplicitTransaction()); + } + } + } + + [Test] + public void TestRollbackTransactionOnPooledWhenConnectionClose() + { + var connectionString = SetPoolWithOneElement(); + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize(), "Connection should be returned to the pool"); + + string firstOpenedSessionId; + using (var connection1 = new SnowflakeDbConnection(connectionString)) + { + connection1.Open(); + Assert.AreEqual(ExpectedPoolCountAfterOpen(), SnowflakeDbConnectionPool.GetCurrentPoolSize(), "Connection session is added to the pool after close connection"); + connection1.BeginTransaction(); + Assert.AreEqual(true, connection1.HasActiveExplicitTransaction()); + using (var command = connection1.CreateCommand()) + { + firstOpenedSessionId = connection1.SfSession.sessionId; + command.CommandText = "SELECT CURRENT_TRANSACTION()"; + Assert.AreNotEqual(DBNull.Value, command.ExecuteScalar()); + } + } + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize(), "Connection should be returned to the pool"); + + using (var connection2 = new SnowflakeDbConnection(connectionString)) + { + connection2.Open(); + Assert.AreEqual(ExpectedPoolCountAfterOpen(), SnowflakeDbConnectionPool.GetCurrentPoolSize(), "Connection session should be now removed from the pool"); + Assert.AreEqual(false, connection2.HasActiveExplicitTransaction()); + using (var command = connection2.CreateCommand()) + { + Assert.AreEqual(firstOpenedSessionId, connection2.SfSession.sessionId); + command.CommandText = "SELECT CURRENT_TRANSACTION()"; + Assert.AreEqual(DBNull.Value, command.ExecuteScalar()); + } + } + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize(), "Connection should be returned to the pool"); + } + + + private string SetPoolWithOneElement() + { + if (_connectionPoolTypeUnderTest == ConnectionPoolType.SingleConnectionCache) + { + SnowflakeDbConnectionPool.SetMaxPoolSize(1); + return ConnectionString; + } + return ConnectionString + "maxPoolSize=1;minPoolSize=0"; + } + + private int ExpectedPoolCountAfterOpen() + { + return _connectionPoolTypeUnderTest == ConnectionPoolType.SingleConnectionCache ? 0 : 1; + } + + } +} diff --git a/Snowflake.Data.Tests/IntegrationTests/ConnectionSinglePoolCacheAsyncIT.cs b/Snowflake.Data.Tests/IntegrationTests/ConnectionSinglePoolCacheAsyncIT.cs new file mode 100644 index 000000000..1b0ac0cf8 --- /dev/null +++ b/Snowflake.Data.Tests/IntegrationTests/ConnectionSinglePoolCacheAsyncIT.cs @@ -0,0 +1,217 @@ +using System; +using System.Data; +using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Session; +using Snowflake.Data.Tests.Mock; +using Snowflake.Data.Tests.Util; + +namespace Snowflake.Data.Tests.IntegrationTests +{ + [TestFixture] + [NonParallelizable] + public class ConnectionSinglePoolCacheAsyncIT: SFBaseTestAsync + { + private readonly PoolConfig _previousPoolConfig = new PoolConfig(); + + [SetUp] + public new void BeforeTest() + { + SnowflakeDbConnectionPool.SetConnectionPoolVersion(ConnectionPoolType.SingleConnectionCache); + SnowflakeDbConnectionPool.ClearAllPools(); + } + + [TearDown] + public new void AfterTest() + { + _previousPoolConfig.Reset(); + } + + + [Test] + public async Task TestPutConnectionToPoolOnCloseAsync() + { + // arrange + using (var conn = new SnowflakeDbConnection(ConnectionString)) + { + Assert.AreEqual(conn.State, ConnectionState.Closed); + CancellationTokenSource connectionCancelToken = new CancellationTokenSource(); + await conn.OpenAsync(connectionCancelToken.Token).ConfigureAwait(false); + + // act + await conn.CloseAsync(connectionCancelToken.Token).ConfigureAwait(false); + + // assert + Assert.AreEqual(ConnectionState.Closed, conn.State); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + } + } + + [Test] + public async Task TestDoNotPutInvalidConnectionToPoolAsync() + { + // arrange + var invalidConnectionString = ";connection_timeout=0"; + using (var conn = new SnowflakeDbConnection(invalidConnectionString)) + { + Assert.AreEqual(conn.State, ConnectionState.Closed); + CancellationTokenSource connectionCancelToken = new CancellationTokenSource(); + try + { + await conn.OpenAsync(connectionCancelToken.Token).ConfigureAwait(false); + Assert.Fail("OpenAsync should throw exception"); + } + catch {} + + // act + await conn.CloseAsync(connectionCancelToken.Token).ConfigureAwait(false); + + // assert + Assert.AreEqual(ConnectionState.Closed, conn.State); + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + } + } + + [Test] + public void TestConnectionPoolWithInvalidOpenAsync() + { + SnowflakeDbConnectionPool.SetMaxPoolSize(10); + // make the connection string unique so it won't pick up connection + // pooled by other test cases. + string connStr = ConnectionString + "application=conn_pool_test_invalid_openasync"; + using (var connection = new SnowflakeDbConnection()) + { + connection.ConnectionString = connStr; + // call openAsync but do not wait and destroy it direct + // so the session is initialized with empty token + connection.OpenAsync(); + } + + // use the same connection string to make a new connection + // to ensure the invalid connection made previously is not pooled + using (var connection1 = new SnowflakeDbConnection()) + { + connection1.ConnectionString = connStr; + // this will not open a new session but get the invalid connection from pool + connection1.Open(); + // Now run query with connection1 + var command = connection1.CreateCommand(); + command.CommandText = "select 1, 2, 3"; + + try + { + using (var reader = command.ExecuteReader()) + { + while (reader.Read()) + { + for (int i = 0; i < reader.FieldCount; i++) + { + // Process each column as appropriate + reader.GetFieldValue(i); + } + } + } + } + catch (SnowflakeDbException) + { + // fail the test case if anything wrong. + Assert.Fail(); + } + } + } + + [Test(Description = "test connection pooling with concurrent connection using async calls")] + public void TestConcurrentConnectionPoolingAsync() + { + // add test case name in connection string to make in unique for each test case + string connStr = ConnectionString + ";application=TestConcurrentConnectionPoolingAsync"; + SnowflakeDbConnectionPool.SetMaxPoolSize(10); + SnowflakeDbConnectionPool.SetTimeout(3); // set short pooling timeout to cover the case that connection expired + ConcurrentPoolingAsyncHelper(connStr, true, 12, 100, 100); + SnowflakeDbConnectionPool.SetTimeout(3600); + } + + [Test(Description = "test connection pooling with concurrent connection and using async calls no close call for connection. Connection is closed when Dispose() is called by framework.")] + public void TestConcurrentConnectionPoolingDisposeAsync() + { + // add test case name in connection string to make in unique for each test case + string connStr = ConnectionString + ";application=TestConcurrentConnectionPoolingDisposeAsync"; + SnowflakeDbConnectionPool.SetMaxPoolSize(10); + SnowflakeDbConnectionPool.SetTimeout(3); // set short pooling timeout to cover the case that connection expired + ConcurrentPoolingAsyncHelper(connStr, false, 12, 100, 100); + SnowflakeDbConnectionPool.SetTimeout(3600); + } + + public static void ConcurrentPoolingAsyncHelper(string connectionString, bool closeConnection, int tasksCount, int connectionsInTask, int abandonedConnectionsCount) + { + var tasks = new Task[tasksCount + 1]; + for (int i = 0; i < tasksCount; i++) + { + tasks[i] = QueryExecutionTaskAsync(connectionString, closeConnection, connectionsInTask); + } + // cover the case of invalid sessions to ensure that won't + // break connection pooling + tasks[tasksCount] = InvalidConnectionTaskAsync(connectionString, abandonedConnectionsCount); + Task.WaitAll(tasks); + } + + // task to execute query with new connection in a loop + static async Task QueryExecutionTaskAsync(string connectionString, bool closeConnection, int times) + { + for (int i = 0; i < times; i++) + { + using (var conn = new SnowflakeDbConnection(connectionString)) + { + await conn.OpenAsync().ConfigureAwait(false); + using (DbCommand cmd = conn.CreateCommand()) + { + cmd.CommandText = "select 1, 2, 3"; + try + { + using (DbDataReader reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false)) + { + while (await reader.ReadAsync().ConfigureAwait(false)) + { + for (int j = 0; j < reader.FieldCount; j++) + { + // Process each column as appropriate + await reader.GetFieldValueAsync(j).ConfigureAwait(false); + } + } + } + } + catch (Exception e) + { + Assert.Fail("Caught unexpected exception: " + e); + } + } + if (closeConnection) + { + await conn.CloseAsync(new CancellationTokenSource().Token).ConfigureAwait(false); + } + } + } + } + + // task to generate invalid(not finish open) connections in a loop + static async Task InvalidConnectionTaskAsync(string connectionString, int times) + { + for (int i = 0; i < times; i++) + { + using (var conn = new SnowflakeDbConnection(connectionString)) + { + // intentionally not using await so the connection + // will be disposed with invalid underlying session + conn.OpenAsync(); + }; + // wait 100ms each time so the invalid sessions are generated + // roughly at the same speed as connections for query tasks + await Task.Delay(100).ConfigureAwait(false); + } + } + + } +} diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolIT.cs b/Snowflake.Data.Tests/IntegrationTests/ConnectionSinglePoolCacheIT.cs similarity index 69% rename from Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolIT.cs rename to Snowflake.Data.Tests/IntegrationTests/ConnectionSinglePoolCacheIT.cs index 4f5020538..956f7f00c 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/ConnectionSinglePoolCacheIT.cs @@ -1,43 +1,35 @@ -/* - * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. - */ - -using Snowflake.Data.Tests.Util; using System; using System.Data; using System.Data.Common; using System.Threading; using System.Threading.Tasks; -using Snowflake.Data.Core; -using Snowflake.Data.Client; -using Snowflake.Data.Log; using NUnit.Framework; +using Snowflake.Data.Client; using Snowflake.Data.Core.Session; +using Snowflake.Data.Tests.Mock; +using Snowflake.Data.Tests.Util; +using Moq; namespace Snowflake.Data.Tests.IntegrationTests { - [TestFixture, NonParallelizable] - class SFConnectionPoolIT : SFBaseTest + [TestFixture] + [NonParallelizable] + public class ConnectionSinglePoolCacheIT: SFBaseTest { - private static PoolConfig s_previousPoolConfig; + private readonly PoolConfig _previousPoolConfig = new PoolConfig(); - [OneTimeSetUp] - public static void BeforeAllTests() - { - s_previousPoolConfig = new PoolConfig(); - } - [SetUp] public new void BeforeTest() { - SnowflakeDbConnectionPool.SetPooling(true); + SnowflakeDbConnectionPool.SetConnectionPoolVersion(ConnectionPoolType.SingleConnectionCache); SnowflakeDbConnectionPool.ClearAllPools(); + SnowflakeDbConnectionPool.SetPooling(true); } [TearDown] public new void AfterTest() { - s_previousPoolConfig.Reset(); + _previousPoolConfig.Reset(); } [OneTimeTearDown] @@ -47,7 +39,20 @@ public static void AfterAllTests() } [Test] - // test connection pooling with concurrent connection + public void TestBasicConnectionPool() + { + SnowflakeDbConnectionPool.SetMaxPoolSize(1); + + var conn1 = new SnowflakeDbConnection(ConnectionString); + conn1.Open(); + Assert.AreEqual(ConnectionState.Open, conn1.State); + conn1.Close(); + + Assert.AreEqual(ConnectionState.Closed, conn1.State); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(ConnectionString).GetCurrentPoolSize()); + } + + [Test] public void TestConcurrentConnectionPooling() { // add test case name in connection string to make in unique for each test case @@ -75,6 +80,7 @@ static void ConcurrentPoolingHelper(string connectionString, bool closeConnectio const int PoolTimeout = 3; // reset to default settings in case it changed by other test cases + Assert.AreEqual(true, SnowflakeDbConnectionPool.GetPool(connectionString).GetPooling()); // to instantiate pool SnowflakeDbConnectionPool.SetMaxPoolSize(10); SnowflakeDbConnectionPool.SetTimeout(PoolTimeout); @@ -87,8 +93,6 @@ static void ConcurrentPoolingHelper(string connectionString, bool closeConnectio }); } Task.WaitAll(threads); - // set pooling timeout back to default to avoid impact on other test cases - SnowflakeDbConnectionPool.SetTimeout(3600); } // thead to execute query with new connection in a loop @@ -131,45 +135,31 @@ static void QueryExecutionThread(string connectionString, bool closeConnection) } [Test] - public void TestBasicConnectionPool() - { - SnowflakeDbConnectionPool.SetMaxPoolSize(1); - - var conn1 = new SnowflakeDbConnection(ConnectionString); - conn1.Open(); - Assert.AreEqual(ConnectionState.Open, conn1.State); - conn1.Close(); - - Assert.AreEqual(ConnectionState.Closed, conn1.State); - Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize()); - } - - [Test] - public void TestConnectionPool() + public void TestPoolContainsClosedConnections() // old name: TestConnectionPool { var conn1 = new SnowflakeDbConnection(ConnectionString); conn1.Open(); Assert.AreEqual(ConnectionState.Open, conn1.State); conn1.Close(); - Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(ConnectionString).GetCurrentPoolSize()); var conn2 = new SnowflakeDbConnection(); conn2.ConnectionString = ConnectionString; conn2.Open(); Assert.AreEqual(ConnectionState.Open, conn2.State); - Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetPool(ConnectionString).GetCurrentPoolSize()); conn2.Close(); - Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(ConnectionString).GetCurrentPoolSize()); Assert.AreEqual(ConnectionState.Closed, conn1.State); Assert.AreEqual(ConnectionState.Closed, conn2.State); - SnowflakeDbConnectionPool.ClearAllPools(); } [Test] - public void TestConnectionPoolIsFull() + public void TestPoolContainsAtMostMaxPoolSizeConnections() // old name: TestConnectionPoolFull { SnowflakeDbConnectionPool.SetMaxPoolSize(2); + var conn1 = new SnowflakeDbConnection(); conn1.ConnectionString = ConnectionString; conn1.Open(); @@ -179,53 +169,77 @@ public void TestConnectionPoolIsFull() conn2.ConnectionString = ConnectionString + " retryCount=1"; conn2.Open(); Assert.AreEqual(ConnectionState.Open, conn2.State); - + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + conn1.Close(); + conn2.Close(); + Assert.AreEqual(2, SnowflakeDbConnectionPool.GetCurrentPoolSize()); var conn3 = new SnowflakeDbConnection(); conn3.ConnectionString = ConnectionString + " retryCount=2"; conn3.Open(); Assert.AreEqual(ConnectionState.Open, conn3.State); - SnowflakeDbConnectionPool.ClearAllPools(); - conn1.Close(); - Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize()); - conn2.Close(); - Assert.AreEqual(2, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + var conn4 = new SnowflakeDbConnection(); + conn4.ConnectionString = ConnectionString + " retryCount=3"; + conn4.Open(); + Assert.AreEqual(ConnectionState.Open, conn4.State); + conn3.Close(); Assert.AreEqual(2, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + conn4.Close(); + Assert.AreEqual(2, SnowflakeDbConnectionPool.GetCurrentPoolSize()); Assert.AreEqual(ConnectionState.Closed, conn1.State); Assert.AreEqual(ConnectionState.Closed, conn2.State); Assert.AreEqual(ConnectionState.Closed, conn3.State); + Assert.AreEqual(ConnectionState.Closed, conn4.State); SnowflakeDbConnectionPool.ClearAllPools(); } [Test] - public void TestConnectionPoolExpirationWorks() + public void TestConnectionPoolDisableFromPoolManagerLevel() { - SnowflakeDbConnectionPool.SetMaxPoolSize(2); - SnowflakeDbConnectionPool.SetTimeout(10); - + // arrange + SnowflakeDbConnectionPool.SetPooling(false); var conn1 = new SnowflakeDbConnection(); conn1.ConnectionString = ConnectionString; + // act conn1.Open(); - conn1.Close(); - SnowflakeDbConnectionPool.SetTimeout(-1); - var conn2 = new SnowflakeDbConnection(); - conn2.ConnectionString = ConnectionString; - conn2.Open(); - conn2.Close(); - var conn3 = new SnowflakeDbConnection(); - conn3.ConnectionString = ConnectionString; - conn3.Open(); - conn3.Close(); + // assert + Assert.AreEqual(ConnectionState.Open, conn1.State); + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize()); - // The pooling timeout should apply to all connections being pooled, - // not just the connections created after the new setting, - // so expected result should be 0 + // act + conn1.Close(); + + // assert + Assert.AreEqual(ConnectionState.Closed, conn1.State); Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + } + + [Test] + public void TestConnectionPoolDisable() + { + // arrange + var pool = SnowflakeDbConnectionPool.GetPool(ConnectionString); SnowflakeDbConnectionPool.SetPooling(false); + var conn1 = new SnowflakeDbConnection(); + conn1.ConnectionString = ConnectionString; + + // act + conn1.Open(); + + // assert + Assert.AreEqual(ConnectionState.Open, conn1.State); + Assert.AreEqual(0, pool.GetCurrentPoolSize()); + + // act + conn1.Close(); + + // assert + Assert.AreEqual(ConnectionState.Closed, conn1.State); + Assert.AreEqual(0, pool.GetCurrentPoolSize()); } [Test] @@ -258,143 +272,75 @@ public void TestConnectionPoolClean() Assert.AreEqual(ConnectionState.Closed, conn1.State); Assert.AreEqual(ConnectionState.Closed, conn2.State); Assert.AreEqual(ConnectionState.Closed, conn3.State); - SnowflakeDbConnectionPool.ClearAllPools(); } [Test] - public void TestConnectionPoolFull() + public void TestConnectionPoolExpirationWorks() { SnowflakeDbConnectionPool.SetMaxPoolSize(2); + SnowflakeDbConnectionPool.SetTimeout(10); var conn1 = new SnowflakeDbConnection(); conn1.ConnectionString = ConnectionString; + conn1.Open(); - Assert.AreEqual(ConnectionState.Open, conn1.State); + conn1.Close(); + SnowflakeDbConnectionPool.SetTimeout(0); var conn2 = new SnowflakeDbConnection(); - conn2.ConnectionString = ConnectionString + " retryCount=1"; + conn2.ConnectionString = ConnectionString; conn2.Open(); - Assert.AreEqual(ConnectionState.Open, conn2.State); - Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize()); - conn1.Close(); conn2.Close(); - Assert.AreEqual(2, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + var conn3 = new SnowflakeDbConnection(); - conn3.ConnectionString = ConnectionString + " retryCount=2"; + conn3.ConnectionString = ConnectionString; conn3.Open(); - Assert.AreEqual(ConnectionState.Open, conn3.State); - - var conn4 = new SnowflakeDbConnection(); - conn4.ConnectionString = ConnectionString + " retryCount=3"; - conn4.Open(); - Assert.AreEqual(ConnectionState.Open, conn4.State); - conn3.Close(); - Assert.AreEqual(2, SnowflakeDbConnectionPool.GetCurrentPoolSize()); - conn4.Close(); - Assert.AreEqual(2, SnowflakeDbConnectionPool.GetCurrentPoolSize()); - - Assert.AreEqual(ConnectionState.Closed, conn1.State); - Assert.AreEqual(ConnectionState.Closed, conn2.State); - Assert.AreEqual(ConnectionState.Closed, conn3.State); - Assert.AreEqual(ConnectionState.Closed, conn4.State); - SnowflakeDbConnectionPool.ClearAllPools(); - } - - [Test] - public void TestConnectionPoolMultiThreading() - { - Thread t1 = new Thread(() => ThreadProcess1(ConnectionString)); - Thread t2 = new Thread(() => ThreadProcess2(ConnectionString)); - - t1.Start(); - t2.Start(); - - t1.Join(); - t2.Join(); - } - - void ThreadProcess1(string connstr) - { - var conn1 = new SnowflakeDbConnection(); - conn1.ConnectionString = connstr; - conn1.Open(); - Thread.Sleep(1000); - conn1.Close(); - Thread.Sleep(4000); - Assert.AreEqual(ConnectionState.Closed, conn1.State); - } - - void ThreadProcess2(string connstr) - { - var conn1 = new SnowflakeDbConnection(); - conn1.ConnectionString = connstr; - conn1.Open(); - Thread.Sleep(5000); - SFStatement statement = new SFStatement(conn1.SfSession); - SFBaseResultSet resultSet = statement.Execute(0, "select 1", null, false, false); - Assert.AreEqual(true, resultSet.Next()); - Assert.AreEqual("1", resultSet.GetString(0)); - SnowflakeDbConnectionPool.ClearAllPools(); - SnowflakeDbConnectionPool.SetMaxPoolSize(0); - SnowflakeDbConnectionPool.SetPooling(false); + // The pooling timeout should apply to all connections being pooled, + // not just the connections created after the new setting, + // so expected result should be 0 + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetPool(ConnectionString).GetCurrentPoolSize()); } [Test] - public void TestConnectionPoolDisable() + public void TestPreventConnectionFromReturningToPool() { - SnowflakeDbConnectionPool.SetPooling(false); + // arrange + var connection = new SnowflakeDbConnection(ConnectionString); + connection.Open(); + var pool = SnowflakeDbConnectionPool.GetPool(ConnectionString); + Assert.AreEqual(0, pool.GetCurrentPoolSize()); - var conn1 = new SnowflakeDbConnection(); - conn1.ConnectionString = ConnectionString; - conn1.Open(); - Assert.AreEqual(ConnectionState.Open, conn1.State); - conn1.Close(); + // act + connection.PreventPooling(); + connection.Close(); - Assert.AreEqual(ConnectionState.Closed, conn1.State); - Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + // assert + Assert.AreEqual(0, pool.GetCurrentPoolSize()); } [Test] - public void TestConnectionPoolWithDispose() + public void TestReleaseConnectionWhenRollbackFails() { - SnowflakeDbConnectionPool.SetMaxPoolSize(1); - - var conn1 = new SnowflakeDbConnection(); - conn1.ConnectionString = ""; - try - { - conn1.Open(); - } - catch (SnowflakeDbException ex) - { - conn1.Close(); - } - - Assert.AreEqual(ConnectionState.Closed, conn1.State); + // arrange + SnowflakeDbConnectionPool.SetMaxPoolSize(10); + var commandThrowingExceptionOnlyForRollback = MockHelper.CommandThrowingExceptionOnlyForRollback(); + var mockDbProviderFactory = new Mock(); + mockDbProviderFactory.Setup(p => p.CreateCommand()).Returns(commandThrowingExceptionOnlyForRollback.Object); Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize()); - } + var connection = new TestSnowflakeDbConnection(mockDbProviderFactory.Object); + connection.ConnectionString = ConnectionString; + connection.Open(); + connection.BeginTransaction(); + Assert.AreEqual(true, connection.HasActiveExplicitTransaction()); + // no Rollback or Commit; during internal Rollback while closing a connection a mocked exception will be thrown - [Test] - public void TestConnectionPoolTurnOff() - { - SnowflakeDbConnectionPool.SetPooling(false); - SnowflakeDbConnectionPool.SetPooling(true); - SnowflakeDbConnectionPool.SetMaxPoolSize(1); - SnowflakeDbConnectionPool.ClearAllPools(); - - var conn1 = new SnowflakeDbConnection(); - conn1.ConnectionString = ConnectionString; - conn1.Open(); - Assert.AreEqual(ConnectionState.Open, conn1.State); - conn1.Close(); - - Assert.AreEqual(ConnectionState.Closed, conn1.State); - Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + // act + connection.Close(); - SnowflakeDbConnectionPool.SetPooling(false); - //Put a breakpoint at SFSession close function, after connection pool is off, it will send close session request. + // assert + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize(), "Should not return connection to the pool"); } [Test] @@ -411,14 +357,14 @@ public void TestCloseSessionAfterTimeout() Assert.IsTrue(session.IsEstablished()); Thread.Sleep(SessionTimeoutSeconds * 1000); // wait until the session is expired var conn2 = new SnowflakeDbConnection(ConnectionString); - + // act conn2.Open(); // it gets a session from the caching pool firstly closing session of conn1 in background Thread.Sleep(TimeForBackgroundSessionCloseMillis); // wait for closing expired session - + // assert Assert.IsFalse(session.IsEstablished()); - + // cleanup conn2.Close(); } diff --git a/Snowflake.Data.Tests/IntegrationTests/EasyLoggingIT.cs b/Snowflake.Data.Tests/IntegrationTests/EasyLoggingIT.cs index 595fbb65d..fd2e79409 100644 --- a/Snowflake.Data.Tests/IntegrationTests/EasyLoggingIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/EasyLoggingIT.cs @@ -24,7 +24,7 @@ public static void BeforeAll() Directory.CreateDirectory(s_workingDirectory); } } - + [OneTimeTearDown] public static void AfterAll() { @@ -36,7 +36,7 @@ public static void AfterEach() { EasyLoggingStarter.Instance.Reset(EasyLoggingLogLevel.Warn); } - + [Test] public void TestEnableEasyLogging() { @@ -48,7 +48,7 @@ public void TestEnableEasyLogging() // act conn.Open(); - + // assert Assert.IsTrue(EasyLoggerManager.HasEasyLoggingAppender()); } @@ -65,13 +65,13 @@ public void TestFailToEnableEasyLoggingForWrongConfiguration() // act var thrown = Assert.Throws(() => conn.Open()); - + // assert - Assert.That(thrown.Message, Does.Contain("Connection string is invalid: Unable to connect")); + Assert.That(thrown.Message, Does.Contain("Connection string is invalid: Unable to initialize session")); Assert.IsFalse(EasyLoggerManager.HasEasyLoggingAppender()); } } - + [Test] public void TestFailToEnableEasyLoggingWhenConfigHasWrongPermissions() { @@ -79,19 +79,19 @@ public void TestFailToEnableEasyLoggingWhenConfigHasWrongPermissions() { Assert.Ignore("skip test on Windows"); } - + // arrange var configFilePath = CreateConfigTempFile(s_workingDirectory, Config("WARN", s_workingDirectory)); Syscall.chmod(configFilePath, FilePermissions.S_IRUSR | FilePermissions.S_IWUSR | FilePermissions.S_IWGRP); using (IDbConnection conn = new SnowflakeDbConnection()) { conn.ConnectionString = ConnectionString + $"CLIENT_CONFIG_FILE={configFilePath}"; - + // act var thrown = Assert.Throws(() => conn.Open()); - + // assert - Assert.That(thrown.Message, Does.Contain("Connection string is invalid: Unable to connect")); + Assert.That(thrown.Message, Does.Contain("Connection string is invalid: Unable to initialize session")); Assert.IsFalse(EasyLoggerManager.HasEasyLoggingAppender()); } } @@ -103,22 +103,22 @@ public void TestFailToEnableEasyLoggingWhenLogDirectoryNotAccessible() { Assert.Ignore("skip test on Windows"); } - + // arrange var configFilePath = CreateConfigTempFile(s_workingDirectory, Config("WARN", "/")); using (IDbConnection conn = new SnowflakeDbConnection()) { conn.ConnectionString = ConnectionString + $"CLIENT_CONFIG_FILE={configFilePath}"; - + // act var thrown = Assert.Throws(() => conn.Open()); - + // assert - Assert.That(thrown.Message, Does.Contain("Connection string is invalid: Unable to connect")); + Assert.That(thrown.Message, Does.Contain("Connection string is invalid: Unable to initialize session")); Assert.That(thrown.InnerException.Message, Does.Contain("Failed to create logs directory")); Assert.IsFalse(EasyLoggerManager.HasEasyLoggingAppender()); } } } -} \ No newline at end of file +} diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs index cc4fea738..7c17f9bbd 100644 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFConnectionIT.cs @@ -4,6 +4,7 @@ using System.Data.Common; using System.Net; +using Snowflake.Data.Core.Session; using Snowflake.Data.Tests.Util; namespace Snowflake.Data.Tests.IntegrationTests @@ -35,7 +36,7 @@ public void TestBasicConnection() conn.Open(); Assert.AreEqual(ConnectionState.Open, conn.State); - Assert.AreEqual(SFSessionHttpClientProperties.s_retryTimeoutDefault, conn.ConnectionTimeout); + Assert.AreEqual(SFSessionHttpClientProperties.DefaultRetryTimeout.TotalSeconds, conn.ConnectionTimeout); // Data source is empty string for now Assert.AreEqual("", ((SnowflakeDbConnection)conn).DataSource); @@ -121,7 +122,7 @@ public void TestIncorrectUserOrPasswordBasicConnection() { conn.Open(); Assert.Fail(); - + } catch (SnowflakeDbException e) { @@ -165,7 +166,6 @@ public void TestConnectionIsNotMarkedAsOpenWhenWasNotCorrectlyOpenedBefore(bool [Test] public void TestConnectionIsNotMarkedAsOpenWhenWasNotCorrectlyOpenedWithUsingClause() { - SnowflakeDbConnectionPool.SetPooling(true); for (int i = 0; i < 2; ++i) { s_logger.Debug($"Running try #{i}"); @@ -268,7 +268,7 @@ public void TestConnectString() cmd.CommandText = $"insert into {TableName} Values ('test 1', 1);"; cmd.ExecuteNonQuery(); } - + using (var conn1 = new SnowflakeDbConnection()) { conn1.ConnectionString = String.Format("scheme={0};host={1};port={2};" + @@ -297,9 +297,9 @@ public void TestConnectString() } conn1.Close(); - Assert.AreEqual(ConnectionState.Closed, conn1.State); + Assert.AreEqual(ConnectionState.Closed, conn1.State); } - + using (IDbCommand cmd = conn.CreateCommand()) { //cmd.CommandText = "drop database \"dlTest\""; @@ -376,6 +376,7 @@ public void TestConnectViaSecureString() } [Test] + [Retry(2)] public void TestLoginTimeout() { using (IDbConnection conn = new MockSnowflakeDbConnection()) @@ -417,7 +418,7 @@ public void TestLoginWithMaxRetryReached() { using (IDbConnection conn = new MockSnowflakeDbConnection()) { - string maxRetryConnStr = ConnectionString + "maxHttpRetries=5"; + string maxRetryConnStr = ConnectionString + "maxHttpRetries=7"; conn.ConnectionString = maxRetryConnStr; @@ -437,15 +438,16 @@ public void TestLoginWithMaxRetryReached() } stopwatch.Stop(); - // retry 5 times with starting backoff of 1 second - // but should not delay more than the max possible seconds after 5 retries - // and should not take less time than the minimum possible seconds after 5 retries - Assert.Less(stopwatch.ElapsedMilliseconds, 79 * 1000); + // retry 7 times with starting backoff of 1 second + // backoff is chosen randomly it can drop to 0. So the minimal backoff time could be 1 + 0 + 0 + 0 + 0 + 0 + 0 = 1 + // The maximal backoff time could be 1 + 2 + 5 + 10 + 21 + 42 + 85 = 166 + Assert.Less(stopwatch.ElapsedMilliseconds, 166 * 1000); Assert.GreaterOrEqual(stopwatch.ElapsedMilliseconds, 1 * 1000); } } [Test] + [Retry(2)] public void TestLoginTimeoutWithRetryTimeoutLesserThanConnectionTimeout() { using (IDbConnection conn = new MockSnowflakeDbConnection()) @@ -492,13 +494,13 @@ public void TestDefaultLoginTimeout() conn.ConnectionString = ConnectionString; // Default timeout is 300 sec - Assert.AreEqual(SFSessionHttpClientProperties.s_retryTimeoutDefault, conn.ConnectionTimeout); + Assert.AreEqual(SFSessionHttpClientProperties.DefaultRetryTimeout, conn.ConnectionTimeout); Assert.AreEqual(conn.State, ConnectionState.Closed); Stopwatch stopwatch = Stopwatch.StartNew(); try - { - conn.Open(); + { + conn.Open(); Assert.Fail(); } catch (AggregateException e) @@ -771,7 +773,7 @@ public void TestConnectionDispose() using (IDbConnection conn = new SnowflakeDbConnection()) { - // Previous connection would be disposed and + // Previous connection would be disposed and // uncommitted txn would rollback at this point conn.ConnectionString = ConnectionString; conn.Open(); @@ -912,6 +914,30 @@ public void TestSSOConnectionWithUser() + ";authenticator=externalbrowser;user=qa@snowflakecomputing.com"; conn.Open(); Assert.AreEqual(ConnectionState.Open, conn.State); + + // connection pooling is disabled for external browser by default + Assert.AreEqual(false, SnowflakeDbConnectionPool.GetPool(conn.ConnectionString).GetPooling()); + using (IDbCommand command = conn.CreateCommand()) + { + command.CommandText = "SELECT CURRENT_USER()"; + Assert.AreEqual("QA", command.ExecuteScalar().ToString()); + } + } + } + + [Test] + [Ignore("This test requires manual interaction and therefore cannot be run in CI")] + public void TestSSOConnectionWithPoolingEnabled() + { + // Use external browser to log in using proper password for qa@snowflakecomputing.com + using (IDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString + = ConnectionStringWithoutAuth + + ";authenticator=externalbrowser;user=qa@snowflakecomputing.com;POOLINGENABLED=TRUE"; + conn.Open(); + Assert.AreEqual(ConnectionState.Open, conn.State); + Assert.AreEqual(true, SnowflakeDbConnectionPool.GetPool(conn.ConnectionString).GetPooling()); using (IDbCommand command = conn.CreateCommand()) { command.CommandText = "SELECT CURRENT_USER()"; @@ -943,7 +969,7 @@ public void TestSSOConnectionWithUserAsync() } } } - + [Test] [Ignore("This test requires manual interaction and therefore cannot be run in CI")] public void TestSSOConnectionWithUserAndDisableConsoleLogin() @@ -992,7 +1018,7 @@ public void TestSSOConnectionWithUserAsyncAndDisableConsoleLogin() [Ignore("This test requires manual interaction and therefore cannot be run in CI")] public void TestSSOConnectionTimeoutAfter10s() { - // Do not log in by external browser - timeout after 10s should happen + // Do not log in by external browser - timeout after 10s should happen int waitSeconds = 10; Stopwatch stopwatch = Stopwatch.StartNew(); Assert.Throws(() => @@ -1016,7 +1042,7 @@ public void TestSSOConnectionTimeoutAfter10s() // timeout after specified number of seconds Assert.GreaterOrEqual(stopwatch.ElapsedMilliseconds, waitSeconds * 1000); - // and not later than 5s after expected time + // and not later than 5s after expected time Assert.LessOrEqual(stopwatch.ElapsedMilliseconds, (waitSeconds + 5) * 1000); } @@ -1148,7 +1174,7 @@ public void TestJwtMissingConnectionSettingConnection() } catch (SnowflakeDbException e) { - // Missing PRIVATE_KEY_FILE connection setting required for + // Missing PRIVATE_KEY_FILE connection setting required for // authenticator =snowflake_jwt Assert.AreEqual(270008, e.ErrorCode); } @@ -1474,7 +1500,7 @@ public void TestMultipleConnectionWithDifferentHttpHandlerSettings() conn8.Open(); } - // Another authenticated proxy with bypasslist, but this will create a new httpclient because + // Another authenticated proxy with bypasslist, but this will create a new httpclient because // InsecureMode=true using (var conn9 = new SnowflakeDbConnection()) { @@ -1558,10 +1584,10 @@ public void TestNonProxyHostShouldBypassProxyServer(string regexHost, string pro var nonProxyHosts = string.Format(regexHost, $"{host}"); conn.ConnectionString = $"{ConnectionString}USEPROXY=true;PROXYHOST={proxyHost};NONPROXYHOSTS={nonProxyHosts};PROXYPORT=3128;"; - + // Act conn.Open(); - + // Assert // The connection would fail to open if the web proxy would be used because the proxy is configured to a non-existent host. Assert.AreEqual(ConnectionState.Open, conn.State); @@ -1583,11 +1609,11 @@ public void TestNonProxyHostShouldNotBypassProxyServer(string regexHost, string var nonProxyHosts = string.Format(regexHost, $"{testConfig.host}"); conn.ConnectionString = $"{ConnectionString}connection_timeout=5;USEPROXY=true;PROXYHOST={proxyHost};NONPROXYHOSTS={nonProxyHosts};PROXYPORT=3128;"; - + // Act/Assert // The connection would fail to open if the web proxy would be used because the proxy is configured to a non-existent host. var exception = Assert.Throws(() => conn.Open()); - + // Assert Assert.AreEqual(270001, exception.ErrorCode); AssertIsConnectionFailure(exception); @@ -1711,12 +1737,11 @@ public void TestEscapeChar() { using (IDbConnection conn = new SnowflakeDbConnection()) { - SnowflakeDbConnectionPool.SetPooling(false); - conn.ConnectionString = ConnectionString + "key1=test\'password;key2=test\"password;key3=test==password"; + conn.ConnectionString = ConnectionString + "poolingEnabled=false;key1=test\'password;key2=test\"password;key3=test==password"; conn.Open(); Assert.AreEqual(ConnectionState.Open, conn.State); - Assert.AreEqual(SFSessionHttpClientProperties.s_retryTimeoutDefault, conn.ConnectionTimeout); + Assert.AreEqual(SFSessionHttpClientProperties.DefaultRetryTimeout.TotalSeconds, conn.ConnectionTimeout); // Data source is empty string for now Assert.AreEqual("", ((SnowflakeDbConnection)conn).DataSource); @@ -1738,12 +1763,11 @@ public void TestEscapeChar1() { using (IDbConnection conn = new SnowflakeDbConnection()) { - SnowflakeDbConnectionPool.SetPooling(false); - conn.ConnectionString = ConnectionString + "key==word=value; key1=\"test;password\"; key2=\"test=password\""; + conn.ConnectionString = ConnectionString + "poolingEnabled=false;key==word=value; key1=\"test;password\"; key2=\"test=password\""; conn.Open(); Assert.AreEqual(ConnectionState.Open, conn.State); - Assert.AreEqual(SFSessionHttpClientProperties.s_retryTimeoutDefault, conn.ConnectionTimeout); + Assert.AreEqual(SFSessionHttpClientProperties.DefaultRetryTimeout.TotalSeconds, conn.ConnectionTimeout); // Data source is empty string for now Assert.AreEqual("", ((SnowflakeDbConnection)conn).DataSource); @@ -1758,14 +1782,13 @@ public void TestEscapeChar1() Assert.AreEqual(ConnectionState.Closed, conn.State); } } - + [Test] [Ignore("Ignore this test. Please run this manually, since it takes 4 hrs to finish.")] public void TestHeartBeat() { - SnowflakeDbConnectionPool.SetPooling(false); var conn = new SnowflakeDbConnection(); - conn.ConnectionString = ConnectionString + ";CLIENT_SESSION_KEEP_ALIVE=true"; + conn.ConnectionString = ConnectionString + "poolingEnabled=false;CLIENT_SESSION_KEEP_ALIVE=true"; conn.Open(); Thread.Sleep(TimeSpan.FromSeconds(14430)); // more than 4 hrs @@ -1784,17 +1807,14 @@ public void TestHeartBeat() public void TestHeartBeatWithConnectionPool() { SnowflakeDbConnectionPool.ClearAllPools(); - SnowflakeDbConnectionPool.SetMaxPoolSize(2); - SnowflakeDbConnectionPool.SetTimeout(14800); - SnowflakeDbConnectionPool.SetPooling(true); var conn = new SnowflakeDbConnection(); - conn.ConnectionString = ConnectionString + ";CLIENT_SESSION_KEEP_ALIVE=true"; + conn.ConnectionString = ConnectionString + "maxPoolSize=2;minPoolSize=0;expirationTimeout=14800;CLIENT_SESSION_KEEP_ALIVE=true"; conn.Open(); conn.Close(); Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize()); - + var conn1 = new SnowflakeDbConnection(); conn1.ConnectionString = ConnectionString + ";CLIENT_SESSION_KEEP_ALIVE=true"; conn1.Open(); @@ -1816,10 +1836,9 @@ public void TestKeepAlive() { // create 100 connections, one per second var connCount = 100; - // pooled connectin expire in 5 seconds so after 5 seconds, + // pooled connection expires in 5 seconds so after 5 seconds, // one connection per second will be closed - SnowflakeDbConnectionPool.SetTimeout(5); - SnowflakeDbConnectionPool.SetMaxPoolSize(20); + var connectionString = ConnectionString + "maxPoolSize=20;ExpirationTimeout=5;CLIENT_SESSION_KEEP_ALIVE=true"; // heart beat interval is validity/4 so send out per 5 seconds HeartBeatBackground.setValidity(20); try @@ -1828,7 +1847,7 @@ public void TestKeepAlive() { using (var conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString + ";CLIENT_SESSION_KEEP_ALIVE=true"; + conn.ConnectionString = connectionString; conn.Open(); } Thread.Sleep(TimeSpan.FromSeconds(1)); @@ -1865,7 +1884,7 @@ public void TestCancelLoginBeforeTimeout() conn.ConnectionString = infiniteLoginTimeOut; Assert.AreEqual(conn.State, ConnectionState.Closed); - // At this point the connection string has not been parsed, it will return the + // At this point the connection string has not been parsed, it will return the // default value //Assert.AreEqual(SFSessionHttpClientProperties.s_retryTimeoutDefault, conn.ConnectionTimeout); @@ -1873,11 +1892,11 @@ public void TestCancelLoginBeforeTimeout() Task connectTask = conn.OpenAsync(connectionCancelToken.Token); // Sleep for more than the default timeout to make sure there are no false positive) - Thread.Sleep((SFSessionHttpClientProperties.s_retryTimeoutDefault + 10) * 1000); + Thread.Sleep(SFSessionHttpClientProperties.DefaultRetryTimeout.Add(TimeSpan.FromSeconds(10))); Assert.AreEqual(ConnectionState.Connecting, conn.State); - // Cancel the connection because it will never succeed since there is no + // Cancel the connection because it will never succeed since there is no // connection_timeout defined logger.Debug("connectionCancelToken.Cancel "); connectionCancelToken.Cancel(); @@ -1891,7 +1910,7 @@ public void TestCancelLoginBeforeTimeout() Assert.AreEqual( "System.Threading.Tasks.TaskCanceledException", e.InnerException.GetType().ToString()); - + } Assert.AreEqual(ConnectionState.Closed, conn.State); @@ -1935,6 +1954,7 @@ public void TestAsyncLoginTimeout() } [Test] + [Retry(2)] public void TestAsyncLoginTimeoutWithRetryTimeoutLesserThanConnectionTimeout() { using (var conn = new MockSnowflakeDbConnection()) @@ -1999,7 +2019,7 @@ public void TestAsyncDefaultLoginTimeout() Assert.LessOrEqual(stopwatch.ElapsedMilliseconds, (conn.ConnectionTimeout + 1) * 1000); Assert.AreEqual(ConnectionState.Closed, conn.State); - Assert.AreEqual(SFSessionHttpClientProperties.s_retryTimeoutDefault, conn.ConnectionTimeout); + Assert.AreEqual(SFSessionHttpClientProperties.DefaultRetryTimeout.TotalSeconds, conn.ConnectionTimeout); } } @@ -2042,7 +2062,7 @@ public void TestCloseAsyncWithCancellation() { // https://docs.microsoft.com/en-us/dotnet/api/system.data.common.dbconnection.close // https://docs.microsoft.com/en-us/dotnet/api/system.data.common.dbconnection.closeasync - // An application can call Close or CloseAsync more than one time. + // An application can call Close or CloseAsync more than one time. // No exception is generated. using (var conn = new SnowflakeDbConnection()) { @@ -2078,7 +2098,7 @@ public void TestCloseAsync() { // https://docs.microsoft.com/en-us/dotnet/api/system.data.common.dbconnection.close // https://docs.microsoft.com/en-us/dotnet/api/system.data.common.dbconnection.closeasync - // An application can call Close or CloseAsync more than one time. + // An application can call Close or CloseAsync more than one time. // No exception is generated. using (var conn = new SnowflakeDbConnection()) { @@ -2154,7 +2174,7 @@ public void TestExplicitTransactionOperationsTracked() conn.BeginTransaction().Rollback(); Assert.AreEqual(false, conn.HasActiveExplicitTransaction()); - + conn.BeginTransaction().Commit(); Assert.AreEqual(false, conn.HasActiveExplicitTransaction()); } @@ -2207,7 +2227,7 @@ public void TestAsyncOktaConnectionUntilMaxTimeout() } } } - + [Test] [Ignore("This test requires established dev Okta SSO and credentials matching Snowflake user")] public void TestNativeOktaSuccess() @@ -2217,13 +2237,13 @@ public void TestNativeOktaSuccess() var oktaPassword = "***"; using (IDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionStringWithoutAuth + + conn.ConnectionString = ConnectionStringWithoutAuth + $";authenticator={oktaUrl};user={oktaUser};password={oktaPassword};"; conn.Open(); Assert.AreEqual(ConnectionState.Open, conn.State); } } - + [Test] public void TestConnectStringWithQueryTag() { @@ -2231,17 +2251,26 @@ public void TestConnectStringWithQueryTag() { string expectedQueryTag = "Test QUERY_TAG 12345"; conn.ConnectionString = ConnectionString + $";query_tag={expectedQueryTag}"; - + conn.Open(); var command = conn.CreateCommand(); // This query itself will be part of the history and will have the query tag command.CommandText = "SELECT QUERY_TAG FROM table(information_schema.query_history_by_session())"; var queryTag = command.ExecuteScalar(); - + Assert.AreEqual(expectedQueryTag, queryTag); } } - + + [Test] + public void TestUseMultiplePoolsConnectionPoolByDefault() + { + // act + var poolVersion = SnowflakeDbConnectionPool.GetConnectionPoolVersion(); + + // assert + Assert.AreEqual(ConnectionPoolType.MultipleConnectionPool, poolVersion); + } } } diff --git a/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolAsyncIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolAsyncIT.cs deleted file mode 100644 index 05f7ed17f..000000000 --- a/Snowflake.Data.Tests/IntegrationTests/SFConnectionPoolAsyncIT.cs +++ /dev/null @@ -1,360 +0,0 @@ -/* - * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. - */ - -using Snowflake.Data.Tests.Util; -using System; -using System.Data; -using System.Data.Common; -using System.Threading; -using System.Threading.Tasks; -using Snowflake.Data.Client; -using Snowflake.Data.Core; -using Snowflake.Data.Log; -using Snowflake.Data.Tests.Mock; -using Moq; -using NUnit.Framework; - -namespace Snowflake.Data.Tests.IntegrationTests -{ - [TestFixture, NonParallelizable] - class SFConnectionPoolITAsync : SFBaseTestAsync - { - private static PoolConfig s_previousPoolConfigRestorer; - - [OneTimeSetUp] - public static void BeforeAllTests() - { - s_previousPoolConfigRestorer = new PoolConfig(); - } - - [SetUp] - public new void BeforeTest() - { - SnowflakeDbConnectionPool.SetPooling(true); - SnowflakeDbConnectionPool.ClearAllPools(); - } - - [TearDown] - public new void AfterTest() - { - s_previousPoolConfigRestorer.Reset(); - } - - [OneTimeTearDown] - public static void AfterAllTests() - { - SnowflakeDbConnectionPool.ClearAllPools(); - } - - [Test] - public void TestConnectionPoolWithAsync() - { - using (var conn = new MockSnowflakeDbConnection()) - { - SnowflakeDbConnectionPool.SetMaxPoolSize(1); - - int timeoutSec = 0; - string infiniteLoginTimeOut = $";connection_timeout={timeoutSec}"; - - conn.ConnectionString = infiniteLoginTimeOut; - - Assert.AreEqual(conn.State, ConnectionState.Closed); - - CancellationTokenSource connectionCancelToken = new CancellationTokenSource(); - try - { - conn.OpenAsync(connectionCancelToken.Token); - } - catch (SnowflakeDbException ex) - { - conn.CloseAsync(connectionCancelToken.Token); - } - - Thread.Sleep(10 * 1000); - Assert.AreEqual(ConnectionState.Closed, conn.State); - Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize()); - } - } - - [Test] - public void TestConnectionPoolWithInvalidOpenAsync() - { - SnowflakeDbConnectionPool.SetMaxPoolSize(10); - // make the connection string unique so it won't pick up connection - // pooled by other test cases. - string connStr = ConnectionString + ";application=conn_pool_test_invalid_openasync"; - using (var connection = new SnowflakeDbConnection()) - { - connection.ConnectionString = connStr; - // call openAsync but do not wait and destroy it direct - // so the session is initialized with empty token - connection.OpenAsync(); - } - - // use the same connection string to make a new connection - // to ensure the invalid connection made previously is not pooled - using (var connection1 = new SnowflakeDbConnection()) - { - connection1.ConnectionString = connStr; - // this will not open a new session but get the invalid connection from pool - connection1.Open(); - // Now run query with connection1 - var command = connection1.CreateCommand(); - command.CommandText = "select 1, 2, 3"; - - try - { - using (var reader = command.ExecuteReader()) - { - while (reader.Read()) - { - for (int i = 0; i < reader.FieldCount; i++) - { - // Process each column as appropriate - reader.GetFieldValue(i); - } - } - } - } - catch (SnowflakeDbException) - { - // fail the test case if anything wrong. - Assert.Fail(); - } - } - } - - [Test] - // test connection pooling with concurrent connection using async calls - public void TestConcurrentConnectionPoolingAsync() - { - // add test case name in connection string to make in unique for each test case - string connStr = ConnectionString + ";application=TestConcurrentConnectionPoolingAsync"; - ConcurrentPoolingAsyncHelper(connStr, true); - } - - [Test] - public void TestRollbackTransactionOnPooledWhenExceptionOccurred() - { - SnowflakeDbConnectionPool.SetMaxPoolSize(1); - - object firstOpenedSessionId; - using (var connection = new SnowflakeDbConnection()) - { - connection.ConnectionString = ConnectionString; - connection.Open(); - firstOpenedSessionId = connection.SfSession.sessionId; - connection.BeginTransaction(); - Assert.AreEqual(true, connection.HasActiveExplicitTransaction()); - Assert.Throws(() => - { - using (var command = connection.CreateCommand()) - { - command.CommandText = "invalid command will throw exception and leave session with an unfinished transaction"; - command.ExecuteNonQuery(); - } - }); - } - - using (var connectionWithSessionReused = new SnowflakeDbConnection()) - { - connectionWithSessionReused.ConnectionString = ConnectionString; - connectionWithSessionReused.Open(); - - Assert.AreEqual(firstOpenedSessionId, connectionWithSessionReused.SfSession.sessionId); - Assert.AreEqual(false, connectionWithSessionReused.HasActiveExplicitTransaction()); - using (var cmd = connectionWithSessionReused.CreateCommand()) - { - cmd.CommandText = "SELECT CURRENT_TRANSACTION()"; - Assert.AreEqual(DBNull.Value, cmd.ExecuteScalar()); - } - } - - Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize(), "Connection should be reused and any pending transaction rolled back before it gets back to the pool"); - } - - [Test] - public void TestTransactionStatusNotTrackedForNonExplicitTransactionCalls() - { - SnowflakeDbConnectionPool.SetMaxPoolSize(1); - using (var connection = new SnowflakeDbConnection()) - { - connection.ConnectionString = ConnectionString; - connection.Open(); - using (var command = connection.CreateCommand()) - { - command.CommandText = "BEGIN"; // in general can be put as a part of a multi statement call and mixed with commit as well - command.ExecuteNonQuery(); - Assert.AreEqual(false, connection.HasActiveExplicitTransaction()); - } - } - } - - [Test] - public void TestRollbackTransactionOnPooledWhenConnectionClose() - { - SnowflakeDbConnectionPool.SetMaxPoolSize(1); - Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize(), "Connection should be returned to the pool"); - - string firstOpenedSessionId; - using (var connection1 = new SnowflakeDbConnection()) - { - connection1.ConnectionString = ConnectionString; - connection1.Open(); - Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize(), "Connection session is added to the pool after close connection"); - connection1.BeginTransaction(); - Assert.AreEqual(true, connection1.HasActiveExplicitTransaction()); - using (var command = connection1.CreateCommand()) - { - firstOpenedSessionId = connection1.SfSession.sessionId; - command.CommandText = "SELECT CURRENT_TRANSACTION()"; - Assert.AreNotEqual(DBNull.Value, command.ExecuteScalar()); - } - } - Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize(), "Connection should be returned to the pool"); - - using (var connection2 = new SnowflakeDbConnection()) - { - connection2.ConnectionString = ConnectionString; - connection2.Open(); - Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize(), "Connection session should be now removed from the pool"); - Assert.AreEqual(false, connection2.HasActiveExplicitTransaction()); - using (var command = connection2.CreateCommand()) - { - Assert.AreEqual(firstOpenedSessionId, connection2.SfSession.sessionId); - command.CommandText = "SELECT CURRENT_TRANSACTION()"; - Assert.AreEqual(DBNull.Value, command.ExecuteScalar()); - } - } - Assert.AreEqual(1, SnowflakeDbConnectionPool.GetCurrentPoolSize(), "Connection should be returned to the pool"); - } - - [Test] - public void TestFailureOfTransactionRollbackOnConnectionClosePreventsAddingToPool() - { - SnowflakeDbConnectionPool.SetMaxPoolSize(10); - var commandThrowingExceptionOnlyForRollback = new Mock(); - commandThrowingExceptionOnlyForRollback.CallBase = true; - commandThrowingExceptionOnlyForRollback.SetupSet(it => it.CommandText = "ROLLBACK") - .Throws(new SnowflakeDbException(SFError.INTERNAL_ERROR, "Unexpected failure on transaction rollback when connection is returned to the pool with pending transaction")); - var mockDbProviderFactory = new Mock(); - mockDbProviderFactory.Setup(p => p.CreateCommand()).Returns(commandThrowingExceptionOnlyForRollback.Object); - - Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize()); - using (var connection = new TestSnowflakeDbConnection(mockDbProviderFactory.Object)) - { - connection.ConnectionString = ConnectionString; - connection.Open(); - connection.BeginTransaction(); - Assert.AreEqual(true, connection.HasActiveExplicitTransaction()); - // no Rollback or Commit; during internal Rollback while closing a connection a mocked exception will be thrown - } - - Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize(), "Should not return connection to the pool"); - } - - [Test] - // test connection pooling with concurrent connection and using async calls no close - // call for connection. Connection is closed when Dispose() is called - // by framework. - public void TestConcurrentConnectionPoolingDisposeAsync() - { - // add test case name in connection string to make in unique for each test case - string connStr = ConnectionString + ";application=TestConcurrentConnectionPoolingDisposeAsync"; - ConcurrentPoolingAsyncHelper(connStr, false); - } - - static void ConcurrentPoolingAsyncHelper(string connectionString, bool closeConnection) - { - // task number a bit larger than pool size so some connections - // would fail on pooling while some connections could success - const int TaskNum = 12; - // set short pooling timeout to cover the case that connection expired - const int PoolTimeout = 3; - - // reset to default settings in case it changed by other test cases - SnowflakeDbConnectionPool.SetMaxPoolSize(10); - SnowflakeDbConnectionPool.SetTimeout(PoolTimeout); - - var tasks = new Task[TaskNum + 1]; - for (int i = 0; i < TaskNum; i++) - { - tasks[i] = QueryExecutionTaskAsync(connectionString, closeConnection); - } - // cover the case of invalid sessions to ensure that won't - // break connection pooling - tasks[TaskNum] = InvalidConnectionTaskAsync(connectionString); - Task.WaitAll(tasks); - - // set pooling timeout back to default to avoid impact on other test cases - SnowflakeDbConnectionPool.SetTimeout(3600); - } - - // task to execute query with new connection in a loop - static async Task QueryExecutionTaskAsync(string connectionString, bool closeConnection) - { - for (int i = 0; i < 100; i++) - { - using (var conn = new SnowflakeDbConnection(connectionString)) - { - await conn.OpenAsync(); - using (DbCommand cmd = conn.CreateCommand()) - { - cmd.CommandText = "select 1, 2, 3"; - try - { - using (DbDataReader reader = await cmd.ExecuteReaderAsync()) - { - while (await reader.ReadAsync()) - { - for (int j = 0; j < reader.FieldCount; j++) - { - // Process each column as appropriate - await reader.GetFieldValueAsync(j); - } - } - } - } - catch (Exception e) - { - Assert.Fail("Caught unexpected exception: " + e); - } - } - - if (closeConnection) - { - await conn.CloseAsync(new CancellationTokenSource().Token); - } - } - } - } - - // task to generate invalid(not finish open) connections in a loop - static async Task InvalidConnectionTaskAsync(string connectionString) - { - for (int i = 0; i < 100; i++) - { - using (var conn = new SnowflakeDbConnection(connectionString)) - { - // intentionally not using await so the connection - // will be disposed with invalid underlying session - conn.OpenAsync(); - }; - // wait 100ms each time so the invalid sessions are generated - // roughly at the same speed as connections for query tasks - await Task.Delay(100); - } - } - - private class TestSnowflakeDbConnection : SnowflakeDbConnection - { - public TestSnowflakeDbConnection(DbProviderFactory dbProviderFactory) - { - DbProviderFactory = dbProviderFactory; - } - - protected override DbProviderFactory DbProviderFactory { get; } - } - } -} diff --git a/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs b/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs index 8891e2e2a..cd30779aa 100755 --- a/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs +++ b/Snowflake.Data.Tests/IntegrationTests/SFDbCommandIT.cs @@ -7,6 +7,7 @@ using System; using System.Threading; using System.Threading.Tasks; +using Snowflake.Data.Core; namespace Snowflake.Data.Tests.IntegrationTests { @@ -17,7 +18,6 @@ namespace Snowflake.Data.Tests.IntegrationTests using System.Collections.Generic; using System.Globalization; using Snowflake.Data.Tests.Mock; - using Snowflake.Data.Core; [TestFixture] class SFDbCommandITAsync : SFBaseTestAsync @@ -25,10 +25,9 @@ class SFDbCommandITAsync : SFBaseTestAsync [Test] public void TestExecAsyncAPI() { - SnowflakeDbConnectionPool.ClearAllPools(); using (DbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; Task connectTask = conn.OpenAsync(CancellationToken.None); connectTask.Wait(); @@ -63,10 +62,9 @@ public void TestExecAsyncAPI() [Test] public void TestExecAsyncAPIParallel() { - SnowflakeDbConnectionPool.ClearAllPools(); using (DbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; Task connectTask = conn.OpenAsync(CancellationToken.None); connectTask.Wait(); @@ -112,8 +110,7 @@ public void TestCancelExecuteAsync() using (DbConnection conn = new SnowflakeDbConnection()) { - SnowflakeDbConnectionPool.ClearAllPools(); - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false";; conn.Open(); @@ -147,7 +144,7 @@ public void TestExecuteAsyncWithMaxRetryReached() using (DbConnection conn = new MockSnowflakeDbConnection(mockRestRequester)) { - string maxRetryConnStr = ConnectionString + "maxHttpRetries=5"; + string maxRetryConnStr = ConnectionString + "maxHttpRetries=8;poolingEnabled=false"; conn.ConnectionString = maxRetryConnStr; conn.Open(); @@ -169,10 +166,11 @@ public void TestExecuteAsyncWithMaxRetryReached() } stopwatch.Stop(); - // retry 5 times with backoff 1, 2, 4, 8, 16 seconds + var totalDelaySeconds = 1 + 2 + 4 + 8 + 16 + 16 + 16 + 16; + // retry 8 times with backoff 1, 2, 4, 8, 16, 16, 16, 16 seconds // but should not delay more than another 16 seconds - Assert.Less(stopwatch.ElapsedMilliseconds, 51 * 1000); - Assert.GreaterOrEqual(stopwatch.ElapsedMilliseconds, 30 * 1000); + Assert.Less(stopwatch.ElapsedMilliseconds, (totalDelaySeconds + 20) * 1000); + Assert.GreaterOrEqual(stopwatch.ElapsedMilliseconds, totalDelaySeconds * 1000); } } @@ -184,7 +182,7 @@ public async Task TestAsyncExecQueryAsync() using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) @@ -223,7 +221,7 @@ public async Task TestExecuteNormalQueryWhileAsyncExecQueryIsRunningAsync() SnowflakeDbConnection[] connections = new SnowflakeDbConnection[3]; for (int i = 0; i < connections.Length; i++) { - connections[i] = new SnowflakeDbConnection(ConnectionString); + connections[i] = new SnowflakeDbConnection(ConnectionString + "poolingEnabled=false"); await connections[i].OpenAsync(CancellationToken.None).ConfigureAwait(false); } @@ -278,7 +276,7 @@ public async Task TestAsyncExecCancelWhileGettingResultsAsync() { using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false";; await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) @@ -312,7 +310,7 @@ public async Task TestFailedAsyncExecQueryThrowsErrorAsync() { using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) @@ -354,7 +352,7 @@ public async Task TestGetStatusOfInvalidQueryIdAsync() using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) @@ -378,7 +376,7 @@ public async Task TestGetResultsOfInvalidQueryIdAsync() using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) @@ -402,7 +400,7 @@ public async Task TestGetStatusOfUnknownQueryIdAsync() using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) @@ -426,7 +424,7 @@ public async Task TestGetResultsOfUnknownQueryIdAsyncWithDefaultRetry() using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) @@ -452,7 +450,7 @@ public async Task TestGetResultsOfUnknownQueryIdAsyncWithConfiguredRetry() using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) @@ -481,7 +479,7 @@ public void TestLongRunningQuery() { using (IDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); @@ -500,7 +498,7 @@ public void TestLongRunningQuery() [Ignore("This test case takes too much time so run it manually")] public void TestRowsAffectedOverflowInt() { - using (IDbConnection conn = new SnowflakeDbConnection(ConnectionString)) + using (IDbConnection conn = new SnowflakeDbConnection(ConnectionString + "poolingEnabled=false")) { conn.Open(); @@ -526,7 +524,7 @@ public void TestSimpleCommand() { using (IDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); IDbCommand cmd = conn.CreateCommand(); @@ -589,7 +587,7 @@ public void TestSimpleLargeResultSet() { using (IDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); @@ -612,6 +610,7 @@ public void TestSimpleLargeResultSet() [Test, NonParallelizable] public void TestUseV1ResultParser() { + var connectionString = ConnectionString + "poolingEnabled=false"; var chunkParserVersion = SFConfiguration.Instance().ChunkParserVersion; int chunkDownloaderVersion = SFConfiguration.Instance().ChunkDownloaderVersion; SFConfiguration.Instance().ChunkParserVersion = 1; @@ -619,7 +618,7 @@ public void TestUseV1ResultParser() using (IDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = connectionString; conn.Open(); @@ -633,7 +632,6 @@ public void TestUseV1ResultParser() // don't test the second column as it has random values just to increase the response size counter++; } - conn.Close(); } SFConfiguration.Instance().ChunkParserVersion = chunkParserVersion; SFConfiguration.Instance().ChunkDownloaderVersion = chunkDownloaderVersion; @@ -642,6 +640,7 @@ public void TestUseV1ResultParser() [Test, NonParallelizable] public void TestUseV2ChunkDownloader() { + var connectionString = ConnectionString + "poolingEnabled=false"; var chunkParserVersion = SFConfiguration.Instance().ChunkParserVersion; int chunkDownloaderVersion = SFConfiguration.Instance().ChunkDownloaderVersion; SFConfiguration.Instance().ChunkParserVersion = 2; @@ -649,7 +648,7 @@ public void TestUseV2ChunkDownloader() using (IDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = connectionString; conn.Open(); @@ -663,7 +662,6 @@ public void TestUseV2ChunkDownloader() // don't test the second column as it has random values just to increase the response size counter++; } - conn.Close(); } SFConfiguration.Instance().ChunkParserVersion = chunkParserVersion; SFConfiguration.Instance().ChunkDownloaderVersion = chunkDownloaderVersion; @@ -673,7 +671,7 @@ public void TestUseV2ChunkDownloader() [Parallelizable(ParallelScope.Children)] public void TestDefaultChunkDownloaderWithPrefetchThreads([Values(1, 2, 4)] int prefetchThreads) { - using (SnowflakeDbConnection conn = new SnowflakeDbConnection(ConnectionString)) + using (SnowflakeDbConnection conn = new SnowflakeDbConnection(ConnectionString + "poolingEnabled=false")) { conn.Open(); @@ -701,7 +699,7 @@ public void TestDataSourceError() { using (IDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); @@ -727,7 +725,7 @@ public void TestCancelQuery() { using (IDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); @@ -782,7 +780,7 @@ public void TestQueryTimeout() { using (IDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); @@ -819,7 +817,7 @@ public void TestTransaction() { using (IDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); @@ -888,7 +886,7 @@ public void TestRowsAffected() using (IDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); @@ -911,7 +909,7 @@ public void TestExecuteScalarNull() { using (IDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); using (IDbCommand command = conn.CreateCommand()) @@ -935,7 +933,7 @@ public void TestExecuteWithMaxRetryReached() using (IDbConnection conn = new MockSnowflakeDbConnection(mockRestRequester)) { - string maxRetryConnStr = ConnectionString + "maxHttpRetries=5"; + string maxRetryConnStr = ConnectionString + "maxHttpRetries=8;poolingEnabled=false"; conn.ConnectionString = maxRetryConnStr; conn.Open(); @@ -956,10 +954,11 @@ public void TestExecuteWithMaxRetryReached() } stopwatch.Stop(); - // retry 5 times with backoff 1, 2, 4, 8, 16 seconds + var totalDelaySeconds = 1 + 2 + 4 + 8 + 16 + 16 + 16 + 16; + // retry 8 times with backoff 1, 2, 4, 8, 16, 16, 16, 16 seconds // but should not delay more than another 16 seconds - Assert.Less(stopwatch.ElapsedMilliseconds, 51 * 1000); - Assert.GreaterOrEqual(stopwatch.ElapsedMilliseconds, 30 * 1000); + Assert.Less(stopwatch.ElapsedMilliseconds, (totalDelaySeconds + 20) * 1000); + Assert.GreaterOrEqual(stopwatch.ElapsedMilliseconds, totalDelaySeconds * 1000); } } @@ -984,7 +983,7 @@ public void TestRowsAffectedUnload() { using (IDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); using (IDbCommand command = conn.CreateCommand()) @@ -1016,7 +1015,7 @@ public void TestRowsAffectedUnload() //[Ignore("Ignore flaky unstable test case for now. Will revisit later and sdk issue created (210)")] public void testPutArrayBindAsync() { - ArrayBindTest(ConnectionString, TableName, 15000); + ArrayBindTest(ConnectionString + "poolingEnabled=false", TableName, 15000); } private void ArrayBindTest(string connstr, string tableName, int size) @@ -1156,8 +1155,8 @@ public void TestPutArrayBindAsyncMultiThreading() var t1TableName = TableName + 1; var t2TableName = TableName + 2; - Thread t1 = new Thread(() => ThreadProcess1(ConnectionString, t1TableName)); - Thread t2 = new Thread(() => ThreadProcess2(ConnectionString, t2TableName)); + Thread t1 = new Thread(() => ThreadProcess1(ConnectionString + "poolingEnabled=false", t1TableName)); + Thread t2 = new Thread(() => ThreadProcess2(ConnectionString + "poolingEnabled=false", t2TableName)); //Thread t3 = new Thread(() => ThreadProcess3(ConnectionString)); //Thread t4 = new Thread(() => ThreadProcess4(ConnectionString)); @@ -1195,7 +1194,7 @@ public void testExecuteScalarAsyncSelect() CancellationTokenSource externalCancel = new CancellationTokenSource(); using (DbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); CreateOrReplaceTable(conn, TableName, new []{"cola INTEGER"}); @@ -1235,9 +1234,7 @@ public void testExecuteLargeQueryWithGcsDownscopedToken() { using (IDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString - + String.Format( - ";GCS_USE_DOWNSCOPED_CREDENTIAL=true"); + conn.ConnectionString = ConnectionString + "GCS_USE_DOWNSCOPED_CREDENTIAL=true;poolingEnabled=false"; conn.Open(); int rowCount = 100000; @@ -1256,7 +1253,7 @@ public void TestGetQueryId() { using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); // query id is null when no query executed @@ -1345,7 +1342,7 @@ public void TestAsyncExecQuery() using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) @@ -1383,7 +1380,7 @@ public void TestExecuteNormalQueryWhileAsyncExecQueryIsRunning() SnowflakeDbConnection[] connections = new SnowflakeDbConnection[3]; for (int i = 0; i < connections.Length; i++) { - connections[i] = new SnowflakeDbConnection(ConnectionString); + connections[i] = new SnowflakeDbConnection(ConnectionString + "poolingEnabled=false"); connections[i].Open(); } @@ -1438,7 +1435,7 @@ public void TestFailedAsyncExecQueryThrowsError() using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) @@ -1475,7 +1472,7 @@ public void TestAsyncExecQueryPutGetThrowsNotImplemented() { using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) @@ -1510,7 +1507,7 @@ public void TestGetStatusOfInvalidQueryId() using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) @@ -1533,7 +1530,7 @@ public void TestGetResultsOfInvalidQueryId() using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) @@ -1556,7 +1553,7 @@ public void TestGetStatusOfUnknownQueryId() using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) @@ -1580,7 +1577,7 @@ public void TestGetResultsOfUnknownQueryIdWithDefaultRetry() using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) @@ -1605,7 +1602,7 @@ public void TestGetResultsOfUnknownQueryIdWithConfiguredRetry() using (SnowflakeDbConnection conn = new SnowflakeDbConnection()) { - conn.ConnectionString = ConnectionString; + conn.ConnectionString = ConnectionString + "poolingEnabled=false"; conn.Open(); using (SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) diff --git a/Snowflake.Data.Tests/Mock/MockHelper.cs b/Snowflake.Data.Tests/Mock/MockHelper.cs new file mode 100644 index 000000000..e4cf1d218 --- /dev/null +++ b/Snowflake.Data.Tests/Mock/MockHelper.cs @@ -0,0 +1,18 @@ +using Moq; +using Snowflake.Data.Client; +using Snowflake.Data.Core; + +namespace Snowflake.Data.Tests.Mock +{ + public static class MockHelper + { + public static Mock CommandThrowingExceptionOnlyForRollback() + { + var command = new Mock(); + command.CallBase = true; + command.SetupSet(it => it.CommandText = "ROLLBACK") + .Throws(new SnowflakeDbException(SFError.INTERNAL_ERROR, "Unexpected failure on transaction rollback when connection is returned to the pool with pending transaction")); + return command; + } + } +} diff --git a/Snowflake.Data.Tests/Mock/MockLoginStoringRestRequester.cs b/Snowflake.Data.Tests/Mock/MockLoginStoringRestRequester.cs new file mode 100644 index 000000000..a17dfd9e2 --- /dev/null +++ b/Snowflake.Data.Tests/Mock/MockLoginStoringRestRequester.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Snowflake.Data.Core; + +namespace Snowflake.Data.Tests.Mock +{ + class MockLoginStoringRestRequester: IMockRestRequester + { + internal List LoginRequests { get; } = new(); + + public T Get(IRestRequest request) + { + return Task.Run(async () => await (GetAsync(request, CancellationToken.None)).ConfigureAwait(false)).Result; + } + + public Task GetAsync(IRestRequest request, CancellationToken cancellationToken) + { + return Task.FromResult((T)(object)null); + } + + public Task GetAsync(IRestRequest request, CancellationToken cancellationToken) + { + return Task.FromResult(null); + } + + public HttpResponseMessage Get(IRestRequest request) + { + return null; + } + + public T Post(IRestRequest postRequest) + { + return Task.Run(async () => await (PostAsync(postRequest, CancellationToken.None)).ConfigureAwait(false)).Result; + } + + public Task PostAsync(IRestRequest postRequest, CancellationToken cancellationToken) + { + SFRestRequest sfRequest = (SFRestRequest)postRequest; + if (sfRequest.jsonBody is LoginRequest) + { + LoginRequests.Add((LoginRequest) sfRequest.jsonBody); + LoginResponse authnResponse = new LoginResponse + { + data = new LoginResponseData() + { + token = "session_token", + masterToken = "master_token", + authResponseSessionInfo = new SessionInfo(), + nameValueParameter = new List() + }, + success = true + }; + + // login request return success + return Task.FromResult((T)(object)authnResponse); + } + throw new NotImplementedException(); + } + + public void setHttpClient(HttpClient httpClient) + { + // Nothing to do + } + } +} diff --git a/Snowflake.Data.Tests/Mock/TestSnowflakeDbConnection.cs b/Snowflake.Data.Tests/Mock/TestSnowflakeDbConnection.cs new file mode 100644 index 000000000..621ca5dd9 --- /dev/null +++ b/Snowflake.Data.Tests/Mock/TestSnowflakeDbConnection.cs @@ -0,0 +1,15 @@ +using System.Data.Common; +using Snowflake.Data.Client; + +namespace Snowflake.Data.Tests.Mock +{ + public class TestSnowflakeDbConnection : SnowflakeDbConnection + { + public TestSnowflakeDbConnection(DbProviderFactory dbProviderFactory) + { + DbProviderFactory = dbProviderFactory; + } + + protected override DbProviderFactory DbProviderFactory { get; } + } +} diff --git a/Snowflake.Data.Tests/SFBaseTest.cs b/Snowflake.Data.Tests/SFBaseTest.cs index 6aacb94f9..a6929eed5 100755 --- a/Snowflake.Data.Tests/SFBaseTest.cs +++ b/Snowflake.Data.Tests/SFBaseTest.cs @@ -25,9 +25,9 @@ namespace Snowflake.Data.Tests using Newtonsoft.Json.Serialization; /* - * This is the base class for all tests that call blocking methods in the library - it uses MockSynchronizationContext to verify that + * This is the base class for all tests that call blocking methods in the library - it uses MockSynchronizationContext to verify that * there are no async deadlocks in the library - * + * */ [TestFixture] public class SFBaseTest : SFBaseTestAsync @@ -47,7 +47,7 @@ public static void TearDownContext() /* * This is the base class for all tests that call async methods in the library - it does not use a special SynchronizationContext - * + * */ [TestFixture] [FixtureLifeCycle(LifeCycle.InstancePerTestCase)] @@ -65,12 +65,12 @@ public class SFBaseTestAsync protected virtual string TestName => TestContext.CurrentContext.Test.MethodName; protected string TestNameWithWorker => TestName + TestContext.CurrentContext.WorkerId?.Replace("#", "_"); protected string TableName => TestNameWithWorker; - + private Stopwatch _stopwatch; private List _tablesToRemove; - + [SetUp] public void BeforeTest() { @@ -93,7 +93,7 @@ private void RemoveTables() { if (_tablesToRemove.Count == 0) return; - + using (var conn = new SnowflakeDbConnection(ConnectionString)) { conn.Open(); @@ -148,26 +148,26 @@ public SFBaseTestAsync() string.Format(ConnectionStringSnowflakeAuthFmt, testConfig.user, testConfig.password); - + protected string ConnectionStringWithInvalidUserName => ConnectionStringWithoutAuth + string.Format(ConnectionStringSnowflakeAuthFmt, "unknown", testConfig.password); protected TestConfig testConfig { get; } - + protected string ResolveHost() { return testConfig.host ?? $"{testConfig.account}.snowflakecomputing.com"; } } - + [SetUpFixture] public class TestEnvironment { - private const string ConnectionStringFmt = "scheme={0};host={1};port={2};" + + private const string ConnectionStringFmt = "scheme={0};host={1};port={2};" + "account={3};role={4};db={5};warehouse={6};user={7};password={8};"; - + public static TestConfig TestConfig { get; private set; } private static Dictionary s_testPerformance; @@ -178,7 +178,7 @@ public static void RecordTestPerformance(string name, TimeSpan time) { lock (s_testPerformanceLock) { - s_testPerformance.Add(name, time); + s_testPerformance[name] = time; } } @@ -201,7 +201,7 @@ public void Setup() var testConfigString = reader.ReadToEnd(); - // Local JSON settings to avoid using system wide settings which could be different + // Local JSON settings to avoid using system wide settings which could be different // than the default ones var jsonSettings = new JsonSerializerSettings { @@ -221,16 +221,16 @@ public void Setup() { Assert.Fail("Failed to load test configuration"); } - + ModifySchema(TestConfig.schema, SchemaAction.CREATE); } - + [OneTimeTearDown] public void Cleanup() { ModifySchema(TestConfig.schema, SchemaAction.DROP); } - + [OneTimeSetUp] public void SetupTestPerformance() { @@ -243,12 +243,12 @@ public void CreateTestTimeArtifact() var resultText = "test;time_in_ms\n"; resultText += string.Join("\n", s_testPerformance.Select(test => $"{test.Key};{Math.Round(test.Value.TotalMilliseconds,0)}")); - + var dotnetVersion = Environment.GetEnvironmentVariable("net_version"); var cloudEnv = Environment.GetEnvironmentVariable("snowflake_cloud_env"); var separator = Path.DirectorySeparatorChar; - + // We have to go up 3 times as the working directory path looks as follows: // Snowflake.Data.Tests/bin/debug/{.net_version}/ File.WriteAllText($"..{separator}..{separator}..{separator}{GetOs()}_{dotnetVersion}_{cloudEnv}_performance.csv", resultText); diff --git a/Snowflake.Data.Tests/UnitTests/AuthenticationPropertiesValidatorTest.cs b/Snowflake.Data.Tests/UnitTests/AuthenticationPropertiesValidatorTest.cs new file mode 100644 index 000000000..4a6a03a33 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/AuthenticationPropertiesValidatorTest.cs @@ -0,0 +1,64 @@ +using System.Net; +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core; +using Snowflake.Data.Tests.Util; + + +namespace Snowflake.Data.Tests.UnitTests +{ + [TestFixture] + public class AuthenticationPropertiesValidatorTest + { + private const string _necessaryNonAuthProperties = "account=a;"; + + [TestCase("authenticator=snowflake;user=test;password=test", null)] + [TestCase("authenticator=Snowflake;user=test", "test")] + [TestCase("authenticator=ExternalBrowser", null)] + [TestCase("authenticator=snowflake_jwt;user=test;private_key_file=key.file", null)] + [TestCase("authenticator=SNOWFLAKE_JWT;user=test;private_key=key", null)] + [TestCase("authenticator=Snowflake_jwt;user=test;private_key=key;private_key_pwd=test", null)] + [TestCase("authenticator=oauth;token=value", null)] + [TestCase("AUTHENTICATOR=HTTPS://SOMETHING.OKTA.COM;USER=TEST;PASSWORD=TEST", null)] + [TestCase("authenticator=https://something.oktapreview.com;user=test;password=test", null)] + [TestCase("authenticator=https://vanity.url/snowflake/okta;USER=TEST;PASSWORD=TEST", null)] + public void TestAuthPropertiesValid(string connectionString, string password) + { + // Arrange + var securePassword = string.IsNullOrEmpty(password) ? null : new NetworkCredential(string.Empty, password).SecurePassword; + + // Act/Assert + Assert.DoesNotThrow(() => SFSessionProperties.ParseConnectionString(_necessaryNonAuthProperties + connectionString, securePassword)); + } + + [TestCase("authenticator=snowflake;", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property PASSWORD is not provided.")] + [TestCase("authenticator=snowflake;", "test", SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property USER is not provided")] + [TestCase("authenticator=snowflake;user=;password=", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property PASSWORD is not provided.")] + [TestCase("authenticator=snowflake;user=;", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property PASSWORD is not provided")] + [TestCase("authenticator=snowflake;user=;", "test", SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property USER is not provided")] + [TestCase("authenticator=snowflake_jwt;private_key_file=", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property USER is not provided")] + [TestCase("authenticator=snowflake_jwt;private_key=", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property USER is not provided")] + [TestCase("authenticator=snowflake_jwt;", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property USER is not provided")] + [TestCase("authenticator=oauth;TOKen=", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property TOKEN is not provided")] + [TestCase("authenticator=oauth;", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property TOKEN is not provided")] + [TestCase("authenticator=okta;user=;password=", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property PASSWORD is not provided")] + [TestCase("authenticator=okta;user=", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property PASSWORD is not provided")] + [TestCase("authenticator=okta;password=", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property PASSWORD is not provided")] + [TestCase("authenticator=okta;", null, SFError.MISSING_CONNECTION_PROPERTY, "Error: Required property PASSWORD is not provided")] + [TestCase("authenticator=unknown;", null, SFError.UNKNOWN_AUTHENTICATOR, "Unknown authenticator")] + [TestCase("authenticator=http://unknown.okta.com;", null, SFError.UNKNOWN_AUTHENTICATOR, "Unknown authenticator")] + [TestCase("authenticator=https://unknown;", null, SFError.UNKNOWN_AUTHENTICATOR, "Unknown authenticator")] + public void TestAuthPropertiesInvalid(string connectionString, string password, SFError expectedError, string expectedErrorMessage) + { + // Arrange + var securePassword = string.IsNullOrEmpty(password) ? null : new NetworkCredential(string.Empty, password).SecurePassword; + + // Act + var exception = Assert.Throws(() => SFSessionProperties.ParseConnectionString(_necessaryNonAuthProperties + connectionString, securePassword)); + + // Assert + SnowflakeDbExceptionAssert.HasErrorCode(exception, expectedError); + Assert.That(exception.Message.Contains(expectedErrorMessage), $"Expecting:\n\t{exception.Message}\nto contain:\n\t{expectedErrorMessage}"); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/ConnectionCacheManagerTest.cs b/Snowflake.Data.Tests/UnitTests/ConnectionCacheManagerTest.cs new file mode 100644 index 000000000..589565ddf --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/ConnectionCacheManagerTest.cs @@ -0,0 +1,46 @@ +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Session; +using Snowflake.Data.Tests.Util; + +namespace Snowflake.Data.Tests.UnitTests +{ + [TestFixture, NonParallelizable] + public class ConnectionCacheManagerTest + { + private readonly ConnectionCacheManager _connectionCacheManager = new ConnectionCacheManager(); + private const string ConnectionString = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=1;"; + private static PoolConfig s_poolConfig; + + [OneTimeSetUp] + public static void BeforeAllTests() + { + s_poolConfig = new PoolConfig(); + SnowflakeDbConnectionPool.SetConnectionPoolVersion(ConnectionPoolType.SingleConnectionCache); + SessionPool.SessionFactory = new MockSessionFactory(); + } + + [OneTimeTearDown] + public static void AfterAllTests() + { + s_poolConfig.Reset(); + SessionPool.SessionFactory = new SessionFactory(); + } + + [SetUp] + public void BeforeEach() + { + _connectionCacheManager.ClearAllPools(); + } + + [Test] + public void TestEnablePoolingRegardlessOfConnectionStringProperty() + { + // act + var pool = _connectionCacheManager.GetPool(ConnectionString + "poolingEnabled=false"); + + // assert + Assert.IsTrue(pool.GetPooling()); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs new file mode 100644 index 000000000..70efa47fb --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/ConnectionPoolManagerTest.cs @@ -0,0 +1,387 @@ +/* + * Copyright (c) 2023 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Security; +using System.Threading; +using System.Threading.Tasks; +using NUnit.Framework; +using Snowflake.Data.Core; +using Snowflake.Data.Core.Session; +using Moq; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Tools; +using Snowflake.Data.Tests.Util; + +namespace Snowflake.Data.Tests.UnitTests +{ + [TestFixture, NonParallelizable] + class ConnectionPoolManagerTest + { + private readonly ConnectionPoolManager _connectionPoolManager = new ConnectionPoolManager(); + private const string ConnectionString1 = "db=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;minPoolSize=1;"; + private const string ConnectionString2 = "db=D2;warehouse=W2;account=A2;user=U2;password=P2;role=R2;minPoolSize=1;"; + private const string ConnectionStringWithoutPassword = "db=D3;warehouse=W3;account=A3;user=U3;role=R3;minPoolSize=1;"; + private readonly SecureString _password3 = SecureStringHelper.Encode("P3"); + private static PoolConfig s_poolConfig; + + [OneTimeSetUp] + public static void BeforeAllTests() + { + s_poolConfig = new PoolConfig(); + SnowflakeDbConnectionPool.SetConnectionPoolVersion(ConnectionPoolType.MultipleConnectionPool); + SessionPool.SessionFactory = new MockSessionFactory(); + } + + [OneTimeTearDown] + public static void AfterAllTests() + { + s_poolConfig.Reset(); + SessionPool.SessionFactory = new SessionFactory(); + } + + [SetUp] + public void BeforeEach() + { + _connectionPoolManager.ClearAllPools(); + } + + [Test] + public void TestPoolManagerReturnsSessionPoolForGivenConnectionString() + { + // Act + var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, null); + + // Assert + Assert.AreEqual(ConnectionString1, sessionPool.ConnectionString); + Assert.AreEqual(null, sessionPool.Password); + } + + [Test] + public void TestPoolManagerReturnsSessionPoolForGivenConnectionStringAndSecurelyProvidedPassword() + { + // Act + var sessionPool = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, _password3); + + // Assert + Assert.AreEqual(ConnectionStringWithoutPassword, sessionPool.ConnectionString); + Assert.AreEqual(_password3, sessionPool.Password); + } + + [Test] + public void TestPoolManagerThrowsWhenPasswordNotProvided() + { + // Act/Assert + Assert.Throws(() => _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, null)); + } + + [Test] + public void TestPoolManagerReturnsSamePoolForGivenConnectionString() + { + // Arrange + var anotherConnectionString = ConnectionString1; + + // Act + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null); + var sessionPool2 = _connectionPoolManager.GetPool(anotherConnectionString, null); + + // Assert + Assert.AreEqual(sessionPool1, sessionPool2); + } + + [Test] + public void TestDifferentPoolsAreReturnedForDifferentConnectionStrings() + { + // Arrange + Assert.AreNotSame(ConnectionString1, ConnectionString2); + + // Act + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null); + + // Assert + Assert.AreNotSame(sessionPool1, sessionPool2); + Assert.AreEqual(ConnectionString1, sessionPool1.ConnectionString); + Assert.AreEqual(ConnectionString2, sessionPool2.ConnectionString); + } + + + [Test] + public void TestGetSessionWorksForSpecifiedConnectionString() + { + // Act + var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null); + + // Assert + Assert.AreEqual(ConnectionString1, sfSession.ConnectionString); + Assert.AreEqual(null, sfSession.Password); + } + + [Test] + public async Task TestGetSessionAsyncWorksForSpecifiedConnectionString() + { + // Act + var sfSession = await _connectionPoolManager.GetSessionAsync(ConnectionString1, null, CancellationToken.None); + + // Assert + Assert.AreEqual(ConnectionString1, sfSession.ConnectionString); + Assert.AreEqual(null, sfSession.Password); + } + + [Test] + public void TestCountingOfSessionProvidedByPool() + { + // Act + _connectionPoolManager.GetSession(ConnectionString1, null); + + // Assert + var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, null); + Assert.AreEqual(1, sessionPool.GetCurrentPoolSize()); + } + + [Test] + public void TestCountingOfSessionReturnedBackToPool() + { + // Arrange + var sfSession = _connectionPoolManager.GetSession(ConnectionString1, null); + + // Act + _connectionPoolManager.AddSession(sfSession); + + // Assert + var sessionPool = _connectionPoolManager.GetPool(ConnectionString1, null); + Assert.AreEqual(1, sessionPool.GetCurrentPoolSize()); + } + + [Test] + public void TestSetMaxPoolSizeForAllPoolsDisabled() + { + // Arrange + _connectionPoolManager.GetPool(ConnectionString1, null); + + // Act + var thrown = Assert.Throws(() => _connectionPoolManager.SetMaxPoolSize(3)); + + // Assert + Assert.That(thrown.Message, Does.Contain("You cannot change connection pool parameters for all the pools. Instead you can change it on a particular pool")); + } + + [Test] + public void TestSetTimeoutForAllPoolsDisabled() + { + // Arrange + _connectionPoolManager.GetPool(ConnectionString1, null); + + // Act + var thrown = Assert.Throws(() => _connectionPoolManager.SetTimeout(3000)); + + // Assert + Assert.That(thrown.Message, Does.Contain("You cannot change connection pool parameters for all the pools. Instead you can change it on a particular pool")); + } + + [Test] + public void TestSetPoolingForAllPoolsDisabled() + { + // Arrange + _connectionPoolManager.GetPool(ConnectionString1, null); + + // Act + var thrown = Assert.Throws(() => _connectionPoolManager.SetPooling(false)); + + // Assert + Assert.That(thrown.Message, Does.Contain("You cannot change connection pool parameters for all the pools. Instead you can change it on a particular pool")); + } + + [Test] + public void TestGetPoolingOnManagerLevelAlwaysTrue() + { + // Arrange + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null); + sessionPool1.SetPooling(true); + sessionPool2.SetPooling(false); + + // Act + var pooling = _connectionPoolManager.GetPooling(); + + // Assert + Assert.IsTrue(pooling); + Assert.IsTrue(sessionPool1.GetPooling()); + Assert.IsFalse(sessionPool2.GetPooling()); + } + + [Test] + [TestCase("authenticator=externalbrowser;account=test;user=test;")] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key_file=/some/file.key")] + public void TestDisabledPoolingWhenSecretesProvidedExternally(string connectionString) + { + // act + var pool = _connectionPoolManager.GetPool(connectionString, null); + + // assert + Assert.IsFalse(pool.GetPooling()); + } + + [Test] + public void TestGetTimeoutOnManagerLevelWhenNotAllPoolsEqual() + { + // Arrange + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null); + sessionPool1.SetTimeout(299); + sessionPool2.SetTimeout(1313); + + // Act/Assert + var exception = Assert.Throws(() => _connectionPoolManager.GetTimeout()); + Assert.IsNotNull(exception); + Assert.AreEqual(SFError.INCONSISTENT_RESULT_ERROR.GetAttribute().errorCode, exception.ErrorCode); + Assert.IsTrue(exception.Message.Contains("Multiple pools have different Timeout values")); + } + + [Test] + public void TestGetTimeoutOnManagerLevelWhenAllPoolsEqual() + { + // Arrange + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null); + sessionPool1.SetTimeout(3600); + sessionPool2.SetTimeout(3600); + + // Act/Assert + Assert.AreEqual(3600,_connectionPoolManager.GetTimeout()); + } + + [Test] + public void TestGetMaxPoolSizeOnManagerLevelWhenNotAllPoolsEqual() + { + // Arrange + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null); + sessionPool1.SetMaxPoolSize(1); + sessionPool2.SetMaxPoolSize(17); + + // Act/Assert + var exception = Assert.Throws(() => _connectionPoolManager.GetMaxPoolSize()); + Assert.IsNotNull(exception); + Assert.AreEqual(SFError.INCONSISTENT_RESULT_ERROR.GetAttribute().errorCode, exception.ErrorCode); + Assert.IsTrue(exception.Message.Contains("Multiple pools have different Max Pool Size values")); + } + + [Test] + public void TestGetMaxPoolSizeOnManagerLevelWhenAllPoolsEqual() + { + // Arrange + var sessionPool1 = _connectionPoolManager.GetPool(ConnectionString1, null); + var sessionPool2 = _connectionPoolManager.GetPool(ConnectionString2, null); + sessionPool1.SetMaxPoolSize(33); + sessionPool2.SetMaxPoolSize(33); + + // Act/Assert + Assert.AreEqual(33,_connectionPoolManager.GetMaxPoolSize()); + } + + [Test] + public void TestGetCurrentPoolSizeReturnsSumOfPoolSizes() + { + // Arrange + EnsurePoolSize(ConnectionString1, null, 2); + EnsurePoolSize(ConnectionString2, null, 3); + + // act + var poolSize = _connectionPoolManager.GetCurrentPoolSize(); + + // assert + Assert.AreEqual(5, poolSize); + } + + [Test] + public void TestReturnPoolForSecurePassword() + { + // arrange + const string AnotherPassword = "anotherPassword"; + EnsurePoolSize(ConnectionStringWithoutPassword, _password3, 1); + + // act + var pool = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, SecureStringHelper.Encode(AnotherPassword)); // a new pool has been created because the password is different + + // assert + Assert.AreEqual(0, pool.GetCurrentPoolSize()); + Assert.AreEqual(AnotherPassword, SecureStringHelper.Decode(pool.Password)); + } + + [Test] + public void TestReturnDifferentPoolWhenPasswordProvidedInDifferentWay() + { + // arrange + var connectionStringWithPassword = $"{ConnectionStringWithoutPassword}password={SecureStringHelper.Decode(_password3)}"; + EnsurePoolSize(ConnectionStringWithoutPassword, _password3, 2); + EnsurePoolSize(connectionStringWithPassword, null, 5); + EnsurePoolSize(connectionStringWithPassword, _password3, 8); + + // act + var pool1 = _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, _password3); + var pool2 = _connectionPoolManager.GetPool(connectionStringWithPassword, null); + var pool3 = _connectionPoolManager.GetPool(connectionStringWithPassword, _password3); + + // assert + Assert.AreEqual(2, pool1.GetCurrentPoolSize()); + Assert.AreEqual(5, pool2.GetCurrentPoolSize()); + Assert.AreEqual(8, pool3.GetCurrentPoolSize()); + } + + [Test] + [TestCase(null)] + [TestCase("")] + public void TestGetPoolFailsWhenNoPasswordProvided(string password) + { + // arrange + var securePassword = password == null ? null : SecureStringHelper.Encode(password); + + // act + var thrown = Assert.Throws(() => _connectionPoolManager.GetPool(ConnectionStringWithoutPassword, securePassword)); + + // assert + Assert.That(thrown.Message, Does.Contain("Required property PASSWORD is not provided")); + } + + [Test] + public void TestPoolDoesNotSerializePassword() + { + // arrange + var password = SecureStringHelper.Decode(_password3); + var connectionStringWithPassword = $"{ConnectionStringWithoutPassword}password={password}"; + var pool = _connectionPoolManager.GetPool(connectionStringWithPassword, _password3); + + // act + var serializedPool = pool.ToString(); + + // assert + Assert.IsFalse(serializedPool.Contains(password)); + } + + private void EnsurePoolSize(string connectionString, SecureString password, int requiredCurrentSize) + { + var sessionPool = _connectionPoolManager.GetPool(connectionString, password); + sessionPool.SetMaxPoolSize(requiredCurrentSize); + for (var i = 0; i < requiredCurrentSize; i++) + { + _connectionPoolManager.GetSession(connectionString, password); + } + Assert.AreEqual(requiredCurrentSize, sessionPool.GetCurrentPoolSize()); + } + } + + class MockSessionFactory : ISessionFactory + { + public SFSession NewSession(string connectionString, SecureString password) + { + var mockSfSession = new Mock(connectionString, password); + mockSfSession.Setup(x => x.Open()).Verifiable(); + mockSfSession.Setup(x => x.OpenAsync(default)).Returns(Task.FromResult(this)); + mockSfSession.Setup(x => x.IsNotOpen()).Returns(false); + mockSfSession.Setup(x => x.IsExpired(It.IsAny(), It.IsAny())).Returns(false); + return mockSfSession.Object; + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Session/EasyLoggingStarterTest.cs b/Snowflake.Data.Tests/UnitTests/Logger/EasyLoggingStarterTest.cs similarity index 100% rename from Snowflake.Data.Tests/UnitTests/Session/EasyLoggingStarterTest.cs rename to Snowflake.Data.Tests/UnitTests/Logger/EasyLoggingStarterTest.cs diff --git a/Snowflake.Data.Tests/UnitTests/SFAuthenticatorFactoryTest.cs b/Snowflake.Data.Tests/UnitTests/SFAuthenticatorFactoryTest.cs index 3157619ae..d7399bd65 100644 --- a/Snowflake.Data.Tests/UnitTests/SFAuthenticatorFactoryTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFAuthenticatorFactoryTest.cs @@ -68,7 +68,7 @@ public void TestGetAuthenticatorOAuth() public void TestGetAuthenticatorOAuthWithMissingToken() { SnowflakeDbException ex = Assert.Throws(() => GetAuthenticator(OAuthAuthenticator.AUTH_NAME)); - Assert.AreEqual(SFError.INVALID_CONNECTION_STRING.GetAttribute().errorCode, ex.ErrorCode); + Assert.AreEqual(SFError.MISSING_CONNECTION_PROPERTY.GetAttribute().errorCode, ex.ErrorCode); } [Test] diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs index 53b130f25..40c7551f8 100644 --- a/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFSessionPropertyTest.cs @@ -8,6 +8,7 @@ using NUnit.Framework; using Snowflake.Data.Client; using Snowflake.Data.Core.Authenticator; +using Snowflake.Data.Core.Tools; namespace Snowflake.Data.Tests.UnitTests { @@ -83,6 +84,26 @@ public void TestThatItFailsIfNoAccountSpecified(string connectionString) Assert.AreEqual(SFError.MISSING_CONNECTION_PROPERTY.GetAttribute().errorCode, exception.ErrorCode); } + [Test] + [TestCase("ACCOUNT=testaccount;USER=testuser;PASSWORD=", null)] + [TestCase("ACCOUNT=testaccount;USER=testuser;", "")] + [TestCase("authenticator=okta;ACCOUNT=testaccount;USER=testuser;PASSWORD=", null)] + [TestCase("authenticator=okta;ACCOUNT=testaccount;USER=testuser;", "")] + public void TestFailWhenNoPasswordProvided(string connectionString, string password) + { + // arrange + var securePassword = password == null ? null : SecureStringHelper.Encode(password); + + // act + var exception = Assert.Throws( + () => SFSessionProperties.ParseConnectionString(connectionString, securePassword) + ); + + // assert + Assert.AreEqual(SFError.MISSING_CONNECTION_PROPERTY.GetAttribute().errorCode, exception.ErrorCode); + Assert.That(exception.Message, Does.Contain("Required property PASSWORD is not provided")); + } + [Test] [TestCase("DB", SFSessionProperty.DB, "\"testdb\"")] [TestCase("SCHEMA", SFSessionProperty.SCHEMA, "\"quotedSchema\"")] @@ -118,20 +139,6 @@ public void TestValidateSupportEscapedQuotesInsideValuesForObjectProperties(stri Assert.AreEqual(expectedValue, properties[sessionProperty]); } - [Test] - public void TestProcessEmptyUserAndPasswordInConnectionString() - { - // arrange - var connectionString = $"ACCOUNT=test;USER=;PASSWORD=;"; - - // act - var properties = SFSessionProperties.ParseConnectionString(connectionString, null); - - // assert - Assert.AreEqual(string.Empty, properties[SFSessionProperty.USER]); - Assert.AreEqual(string.Empty, properties[SFSessionProperty.PASSWORD]); - } - public static IEnumerable ConnectionStringTestCases() { string defAccount = "testaccount"; @@ -181,7 +188,13 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason }, { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }, { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }, - { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost } + { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost }, + { SFSessionProperty.MAXPOOLSIZE, DefaultValue(SFSessionProperty.MAXPOOLSIZE) }, + { SFSessionProperty.MINPOOLSIZE, DefaultValue(SFSessionProperty.MINPOOLSIZE) }, + { SFSessionProperty.CHANGEDSESSION, DefaultValue(SFSessionProperty.CHANGEDSESSION) }, + { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, + { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, + { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) } } }; @@ -210,7 +223,13 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason }, { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }, { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }, - { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost } + { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost }, + { SFSessionProperty.MAXPOOLSIZE, DefaultValue(SFSessionProperty.MAXPOOLSIZE) }, + { SFSessionProperty.MINPOOLSIZE, DefaultValue(SFSessionProperty.MINPOOLSIZE) }, + { SFSessionProperty.CHANGEDSESSION, DefaultValue(SFSessionProperty.CHANGEDSESSION) }, + { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, + { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, + { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) } } }; var testCaseWithProxySettings = new TestCase() @@ -241,7 +260,13 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason }, { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }, { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }, - { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost } + { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost }, + { SFSessionProperty.MAXPOOLSIZE, DefaultValue(SFSessionProperty.MAXPOOLSIZE) }, + { SFSessionProperty.MINPOOLSIZE, DefaultValue(SFSessionProperty.MINPOOLSIZE) }, + { SFSessionProperty.CHANGEDSESSION, DefaultValue(SFSessionProperty.CHANGEDSESSION) }, + { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, + { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, + { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) } }, ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};useProxy=true;proxyHost=proxy.com;proxyPort=1234;nonProxyHosts=localhost" @@ -274,7 +299,13 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason }, { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }, { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }, - { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost } + { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost }, + { SFSessionProperty.MAXPOOLSIZE, DefaultValue(SFSessionProperty.MAXPOOLSIZE) }, + { SFSessionProperty.MINPOOLSIZE, DefaultValue(SFSessionProperty.MINPOOLSIZE) }, + { SFSessionProperty.CHANGEDSESSION, DefaultValue(SFSessionProperty.CHANGEDSESSION) }, + { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, + { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, + { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) } }, ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};proxyHost=proxy.com;proxyPort=1234;nonProxyHosts=localhost" @@ -306,7 +337,13 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason }, { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }, { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }, - { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost } + { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost }, + { SFSessionProperty.MAXPOOLSIZE, DefaultValue(SFSessionProperty.MAXPOOLSIZE) }, + { SFSessionProperty.MINPOOLSIZE, DefaultValue(SFSessionProperty.MINPOOLSIZE) }, + { SFSessionProperty.CHANGEDSESSION, DefaultValue(SFSessionProperty.CHANGEDSESSION) }, + { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, + { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, + { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) } } }; var testCaseWithIncludeRetryReason = new TestCase() @@ -335,7 +372,13 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.INCLUDERETRYREASON, "false" }, { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }, { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }, - { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost } + { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost }, + { SFSessionProperty.MAXPOOLSIZE, DefaultValue(SFSessionProperty.MAXPOOLSIZE) }, + { SFSessionProperty.MINPOOLSIZE, DefaultValue(SFSessionProperty.MINPOOLSIZE) }, + { SFSessionProperty.CHANGEDSESSION, DefaultValue(SFSessionProperty.CHANGEDSESSION) }, + { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, + { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, + { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) } } }; var testCaseWithDisableQueryContextCache = new TestCase() @@ -363,7 +406,13 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason }, { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, "true" }, { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }, - { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost } + { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost }, + { SFSessionProperty.MAXPOOLSIZE, DefaultValue(SFSessionProperty.MAXPOOLSIZE) }, + { SFSessionProperty.MINPOOLSIZE, DefaultValue(SFSessionProperty.MINPOOLSIZE) }, + { SFSessionProperty.CHANGEDSESSION, DefaultValue(SFSessionProperty.CHANGEDSESSION) }, + { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, + { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, + { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) } }, ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};DISABLEQUERYCONTEXTCACHE=true" @@ -393,7 +442,13 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason }, { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }, { SFSessionProperty.DISABLE_CONSOLE_LOGIN, "false" }, - { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost } + { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost }, + { SFSessionProperty.MAXPOOLSIZE, DefaultValue(SFSessionProperty.MAXPOOLSIZE) }, + { SFSessionProperty.MINPOOLSIZE, DefaultValue(SFSessionProperty.MINPOOLSIZE) }, + { SFSessionProperty.CHANGEDSESSION, DefaultValue(SFSessionProperty.CHANGEDSESSION) }, + { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, + { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, + { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) } }, ConnectionString = $"ACCOUNT={defAccount};USER={defUser};PASSWORD={defPassword};DISABLE_CONSOLE_LOGIN=false" @@ -425,7 +480,13 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason }, { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }, { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }, - { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost } + { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost }, + { SFSessionProperty.MAXPOOLSIZE, DefaultValue(SFSessionProperty.MAXPOOLSIZE) }, + { SFSessionProperty.MINPOOLSIZE, DefaultValue(SFSessionProperty.MINPOOLSIZE) }, + { SFSessionProperty.CHANGEDSESSION, DefaultValue(SFSessionProperty.CHANGEDSESSION) }, + { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, + { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, + { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) } } }; var testCaseUnderscoredAccountName = new TestCase() @@ -454,7 +515,13 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason }, { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }, { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }, - { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost } + { SFSessionProperty.ALLOWUNDERSCORESINHOST, defAllowUnderscoresInHost }, + { SFSessionProperty.MAXPOOLSIZE, DefaultValue(SFSessionProperty.MAXPOOLSIZE) }, + { SFSessionProperty.MINPOOLSIZE, DefaultValue(SFSessionProperty.MINPOOLSIZE) }, + { SFSessionProperty.CHANGEDSESSION, DefaultValue(SFSessionProperty.CHANGEDSESSION) }, + { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, + { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, + { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) } } }; var testCaseUnderscoredAccountNameWithEnabledAllowUnderscores = new TestCase() @@ -483,7 +550,13 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.INCLUDERETRYREASON, defIncludeRetryReason }, { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }, { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }, - { SFSessionProperty.ALLOWUNDERSCORESINHOST, "true" } + { SFSessionProperty.ALLOWUNDERSCORESINHOST, "true" }, + { SFSessionProperty.MAXPOOLSIZE, DefaultValue(SFSessionProperty.MAXPOOLSIZE) }, + { SFSessionProperty.MINPOOLSIZE, DefaultValue(SFSessionProperty.MINPOOLSIZE) }, + { SFSessionProperty.CHANGEDSESSION, DefaultValue(SFSessionProperty.CHANGEDSESSION) }, + { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, + { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, + { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) } } }; var testQueryTag = "Test QUERY_TAG 12345"; @@ -514,7 +587,13 @@ public static IEnumerable ConnectionStringTestCases() { SFSessionProperty.DISABLEQUERYCONTEXTCACHE, defDisableQueryContextCache }, { SFSessionProperty.DISABLE_CONSOLE_LOGIN, defDisableConsoleLogin }, { SFSessionProperty.ALLOWUNDERSCORESINHOST, "false" }, - { SFSessionProperty.QUERY_TAG, testQueryTag } + { SFSessionProperty.QUERY_TAG, testQueryTag }, + { SFSessionProperty.MAXPOOLSIZE, DefaultValue(SFSessionProperty.MAXPOOLSIZE) }, + { SFSessionProperty.MINPOOLSIZE, DefaultValue(SFSessionProperty.MINPOOLSIZE) }, + { SFSessionProperty.CHANGEDSESSION, DefaultValue(SFSessionProperty.CHANGEDSESSION) }, + { SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT, DefaultValue(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT) }, + { SFSessionProperty.EXPIRATIONTIMEOUT, DefaultValue(SFSessionProperty.EXPIRATIONTIMEOUT) }, + { SFSessionProperty.POOLINGENABLED, DefaultValue(SFSessionProperty.POOLINGENABLED) } } }; @@ -535,6 +614,9 @@ public static IEnumerable ConnectionStringTestCases() }; } + private static string DefaultValue(SFSessionProperty property) => + property.GetAttribute().defaultValue; + internal class TestCase { public string ConnectionString { get; set; } diff --git a/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs b/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs index b9530b83b..a1f795026 100644 --- a/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SFSessionTest.cs @@ -1,15 +1,14 @@ /* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ -using Snowflake.Data.Configuration; -using Snowflake.Data.Log; +using Newtonsoft.Json; +using Snowflake.Data.Core; +using NUnit.Framework; +using Snowflake.Data.Tests.Mock; namespace Snowflake.Data.Tests.UnitTests { - using Snowflake.Data.Core; - using NUnit.Framework; - [TestFixture] class SFSessionTest { @@ -20,26 +19,61 @@ public void TestSessionGoneWhenClose() Mock.MockCloseSessionGone restRequester = new Mock.MockCloseSessionGone(); SFSession sfSession = new SFSession("account=test;user=test;password=test", null, restRequester); sfSession.Open(); - sfSession.close(); // no exception is raised. + Assert.DoesNotThrow(() => sfSession.close()); } [Test] - public void TestUpdateDatabaseAndSchema() + public void TestUpdateSessionProperties() { + // arrange string databaseName = "DB_TEST"; string schemaName = "SC_TEST"; - + string warehouseName = "WH_TEST"; + string roleName = "ROLE_TEST"; + QueryExecResponseData queryExecResponseData = new QueryExecResponseData + { + finalSchemaName = schemaName, + finalDatabaseName = databaseName, + finalRoleName = roleName, + finalWarehouseName = warehouseName + }; + + // act SFSession sfSession = new SFSession("account=test;user=test;password=test", null); - sfSession.UpdateDatabaseAndSchema(databaseName, schemaName); + sfSession.UpdateSessionProperties(queryExecResponseData); + // assert Assert.AreEqual(databaseName, sfSession.database); Assert.AreEqual(schemaName, sfSession.schema); + Assert.AreEqual(warehouseName, sfSession.warehouse); + Assert.AreEqual(roleName, sfSession.role); + } + [Test] + public void TestSkipUpdateSessionPropertiesWhenPropertiesMissing() + { + // arrange + string databaseName = "DB_TEST"; + string schemaName = "SC_TEST"; + string warehouseName = "WH_TEST"; + string roleName = "ROLE_TEST"; + SFSession sfSession = new SFSession("account=test;user=test;password=test", null); + sfSession.database = databaseName; + sfSession.warehouse = warehouseName; + sfSession.role = roleName; + sfSession.schema = schemaName; + + // act + QueryExecResponseData queryExecResponseWithoutData = new QueryExecResponseData(); + sfSession.UpdateSessionProperties(queryExecResponseWithoutData); + + // assert // when database or schema name is missing in the response, // the cached value should keep unchanged - sfSession.UpdateDatabaseAndSchema(null, null); Assert.AreEqual(databaseName, sfSession.database); Assert.AreEqual(schemaName, sfSession.schema); + Assert.AreEqual(warehouseName, sfSession.warehouse); + Assert.AreEqual(roleName, sfSession.role); } [Test] @@ -54,12 +88,65 @@ public void TestThatConfiguresEasyLogging(string configPath) var connectionString = configPath == null ? simpleConnectionString : $"{simpleConnectionString}client_config_file={configPath};"; - + // act new SFSession(connectionString, null, easyLoggingStarter.Object); - + // assert easyLoggingStarter.Verify(starter => starter.Init(configPath)); } + + [TestCase(null, "accountDefault", "accountDefault", false)] + [TestCase("initial", "initial", "initial", false)] + [TestCase("initial", null, "initial", false)] + [TestCase("initial", "IniTiaL", "initial", false)] + [TestCase("initial", "final", "final", true)] + [TestCase("initial", "\\\"final\\\"", "\"final\"", true)] + [TestCase("initial", "\\\"Final\\\"", "\"Final\"", true)] + [TestCase("\"Ini\\t\"ial\"", "\\\"Ini\\t\"ial\\\"", "\"Ini\\t\"ial\"", false)] + [TestCase("\"initial\"", "initial", "initial", true)] + [TestCase("\"initial\"", "\\\"initial\\\"", "\"initial\"", false)] + [TestCase("init\"ial", "init\"ial", "init\"ial", false)] + [TestCase("\"init\"ial\"", "\\\"init\"ial\\\"", "\"init\"ial\"", false)] + [TestCase("\"init\"ial\"", "\\\"Init\"ial\\\"", "\"Init\"ial\"", true)] + public void TestSessionPropertyQuotationSafeUpdateOnServerResponse(string sessionInitialValue, string serverResponseFinalSessionValue, string unquotedExpectedFinalValue, bool wasChanged) + { + // Arrange + SFSession sfSession = new SFSession("account=test;user=test;password=test", null); + var changedSessionValue = sessionInitialValue; + + // Act + sfSession.UpdateSessionProperty(ref changedSessionValue, serverResponseFinalSessionValue); + + // Assert + Assert.AreEqual(sfSession.SessionPropertiesChanged, wasChanged); + if (wasChanged || sessionInitialValue is null) + Assert.AreEqual(unquotedExpectedFinalValue, changedSessionValue); + else + Assert.AreEqual(sessionInitialValue, changedSessionValue); + } + + [Test] + public void TestHandlePasswordWithQuotations() + { + // arrange + MockLoginStoringRestRequester restRequester = new MockLoginStoringRestRequester(); + SFSession sfSession = new SFSession("account=test;user=test;password=test\"with'quotations{}", null, restRequester); + + // act + sfSession.Open(); + + // assert + Assert.AreEqual(1, restRequester.LoginRequests.Count); + var loginRequest = restRequester.LoginRequests[0]; + Assert.AreEqual("test\"with'quotations{}", loginRequest.data.password); + + // act + var json = JsonConvert.SerializeObject(loginRequest, JsonUtils.JsonSettings); + var deserializedLoginRequest = (LoginRequest) JsonConvert.DeserializeObject(json, typeof(LoginRequest)); + + // assert + Assert.AreEqual(loginRequest.data.password, deserializedLoginRequest.data.password); + } } } diff --git a/Snowflake.Data.Tests/UnitTests/SecretDetectorTest.cs b/Snowflake.Data.Tests/UnitTests/SecretDetectorTest.cs index 862a7c248..82c59a63c 100644 --- a/Snowflake.Data.Tests/UnitTests/SecretDetectorTest.cs +++ b/Snowflake.Data.Tests/UnitTests/SecretDetectorTest.cs @@ -1,16 +1,15 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ -using Amazon.S3.Model.Internal.MarshallTransformations; +using NUnit.Framework; +using Snowflake.Data.Log; +using Snowflake.Data.Tests.Mock; +using System; +using System.Text; namespace Snowflake.Data.Tests.UnitTests { - using NUnit.Framework; - using Snowflake.Data.Log; - using Snowflake.Data.Tests.Mock; - using System; - using System.Collections.Generic; [TestFixture] class SecretDetectorTest @@ -95,7 +94,7 @@ public void TestAWSKeys() BasicMasking(@"""aws_key_id""='aaaaaaaa'", @"""aws_key_id""='****'"); //aws_key_id|aws_secret_key|access_key_id|secret_access_key)('|"")?(\s*[:|=]\s*)'([^']+)' - // Delimiters before start of value to mask + // Delimiters before start of value to mask BasicMasking(@"aws_key_id:'aaaaaaaa'", @"aws_key_id:'****'"); BasicMasking(@"aws_key_id='aaaaaaaa'", @"aws_key_id='****'"); } @@ -144,7 +143,7 @@ public void TestSASTokens() BasicMasking(@"sig=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", @"sig=****"); // signature - BasicMasking(@"signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", @"signature=****"); + BasicMasking(@"signature=aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", @"signature=****"); // AWSAccessKeyId BasicMasking(@"AWSAccessKeyId=ABCDEFGHIJKL01234", @"AWSAccessKeyId=****"); // pragma: allowlist secret @@ -167,6 +166,32 @@ public void TestPrivateKey() "-----BEGIN PRIVATE KEY-----\\\\nXXXX\\\\n-----END PRIVATE KEY-----"); // pragma: allowlist secret } + [Test] + public void TestPrivateKeyProperty() + { + BasicMasking(@"something=anything;private_key=aaaaaa", @"something=anything;private_key=****"); + BasicMasking("something=anything;private_key \r\n =aaaaaa", "something=anything;private_key \r\n =****"); + BasicMasking(@"something=anything;private_key=aaaaaaaaaaaaaaaaaa", @"something=anything;private_key=****"); + BasicMasking(@"something=anything;private_key=a", @"something=anything;private_key=****"); + BasicMasking(@"something=anything;private_key=""a"";someOtherProperty=someValue", @"something=anything;private_key=****"); + BasicMasking(@"something=anything;private_key='a';someOtherProperty=someValue", @"something=anything;private_key=****"); + BasicMasking($"something=anything;private_key ={GetStringWithManyWeirdCharacters()}\r\nxxxxxx\r\nyyyyyy;someOtherProperty=someValue", @"something=anything;private_key =****"); + } + + private string GetStringWithManyWeirdCharacters() + { + var bytes = new byte[256]; + for (var i = 0; i < 256; i++) + { + if (i < 20) + { + bytes[i] = 58; + } + bytes[i] = (byte) i; + } + return Encoding.Default.GetString(bytes); + } + [Test] public void TestPrivateKeyData() { @@ -185,12 +210,12 @@ public void TestConnectionTokens() // assertion content BasicMasking(@"assertion content:aaaaaaaa", @"assertion content:****"); - // Delimiters before start of value to mask + // Delimiters before start of value to mask BasicMasking(@"token""aaaaaaaa", @"token""****"); // " BasicMasking(@"token'aaaaaaaa", @"token'****"); // ' BasicMasking(@"token=aaaaaaaa", @"token=****"); // = BasicMasking(@"token aaaaaaaa", @"token ****"); // {space} - BasicMasking(@"token ="" 'aaaaaaaa", @"token ="" '****"); // Mix + BasicMasking(@"token ="" 'aaaaaaaa", @"token =****"); // Mix // Verify that all allowed characters are correctly supported BasicMasking(@"Token:a=b/c_d-e+F:025", @"Token:****"); @@ -211,17 +236,57 @@ public void TestPassword() // passcode BasicMasking(@"passcode:aaaaaaaa", @"passcode:****"); - // Delimiters before start of value to mask + // Delimiters before start of value to mask BasicMasking(@"password""aaaaaaaa", @"password""****"); // " BasicMasking(@"password'aaaaaaaa", @"password'****"); // ' BasicMasking(@"password=aaaaaaaa", @"password=****"); // = BasicMasking(@"password aaaaaaaa", @"password ****"); // {space} - BasicMasking(@"password ="" 'aaaaaaaa", @"password ="" '****"); // Mix + BasicMasking(@"password ="" 'aaaaaaaa", @"password =****"); // Mix // Verify that all allowed characters are correctly supported BasicMasking(@"password:a!b""c#d$e%f&g'h(i)k*k+l,m;nq?r@s[t]u^v_w`x{y|z}Az0123", @"password:****"); } + [Test] + public void TestPasswordProperty() + { + BasicMasking(@"somethingBefore=cccc;password=aa", @"somethingBefore=cccc;password=****"); + BasicMasking(@"somethingBefore=cccc;password=aa;somethingNext=bbbb", @"somethingBefore=cccc;password=****"); + BasicMasking(@"somethingBefore=cccc;password=""aa"";somethingNext=bbbb", @"somethingBefore=cccc;password=****"); + BasicMasking(@"somethingBefore=cccc;password=;somethingNext=bbbb", @"somethingBefore=cccc;password=****"); + BasicMasking(@"somethingBefore=cccc;password=", @"somethingBefore=cccc;password=****"); + BasicMasking(@"somethingBefore=cccc;password =aa;somethingNext=bbbb", @"somethingBefore=cccc;password =****"); + BasicMasking(@"somethingBefore=cccc;password="" 'aa", @"somethingBefore=cccc;password=****"); + + BasicMasking(@"somethingBefore=cccc;proxypassword=aa", @"somethingBefore=cccc;proxypassword=****"); + BasicMasking(@"somethingBefore=cccc;proxypassword=aa;somethingNext=bbbb", @"somethingBefore=cccc;proxypassword=****"); + BasicMasking(@"somethingBefore=cccc;proxypassword=""aa"";somethingNext=bbbb", @"somethingBefore=cccc;proxypassword=****"); + BasicMasking(@"somethingBefore=cccc;proxypassword=;somethingNext=bbbb", @"somethingBefore=cccc;proxypassword=****"); + BasicMasking(@"somethingBefore=cccc;proxypassword=", @"somethingBefore=cccc;proxypassword=****"); + BasicMasking(@"somethingBefore=cccc;proxypassword =aa;somethingNext=bbbb", @"somethingBefore=cccc;proxypassword =****"); + BasicMasking(@"somethingBefore=cccc;proxypassword="" 'aa", @"somethingBefore=cccc;proxypassword=****"); + + BasicMasking(@"somethingBefore=cccc;private_key_pwd=aa", @"somethingBefore=cccc;private_key_pwd=****"); + BasicMasking(@"somethingBefore=cccc;private_key_pwd=aa;somethingNext=bbbb", @"somethingBefore=cccc;private_key_pwd=****"); + BasicMasking(@"somethingBefore=cccc;private_key_pwd=""aa"";somethingNext=bbbb", @"somethingBefore=cccc;private_key_pwd=****"); + BasicMasking(@"somethingBefore=cccc;private_key_pwd=;somethingNext=bbbb", @"somethingBefore=cccc;private_key_pwd=****"); + BasicMasking(@"somethingBefore=cccc;private_key_pwd=", @"somethingBefore=cccc;private_key_pwd=****"); + BasicMasking(@"somethingBefore=cccc;private_key_pwd =aa;somethingNext=bbbb", @"somethingBefore=cccc;private_key_pwd =****"); + BasicMasking(@"somethingBefore=cccc;private_key_pwd="" 'aa", @"somethingBefore=cccc;private_key_pwd=****"); + } + + [Test] + [TestCase("2020-04-30 23:06:04,069 - MainThread auth.py:397 - write_temporary_credential() - DEBUG - no ID password was not given")] + [TestCase("2020-04-30 23:06:04,069 - MainThread auth.py:397 - write_temporary_credential() - DEBUG - no ID proxyPassword was not given")] + [TestCase("2020-04-30 23:06:04,069 - MainThread auth.py:397 - write_temporary_credential() - DEBUG - no ID private_key_pwd was not given")] + public void TestPasswordFalsePositive(string falsePositiveMessage) + { + mask = SecretDetector.MaskSecrets(falsePositiveMessage); + Assert.IsFalse(mask.isMasked); + Assert.AreEqual(falsePositiveMessage, mask.maskedText); + Assert.IsNull(mask.errStr); + } + [Test] public void TestMaskToken() { @@ -268,7 +333,7 @@ public void TestMaskToken() string snowFlakeAuthToken = "Authorization: Snowflake Token=\"ver:1-hint:92019676298218-ETMsDgAAAXswwgJhABRBRVMvQ0JDL1BLQ1M1UGFkZGluZwEAABAAEF1tbNM3myWX6A9sNSK6rpIAAACA6StojDJS4q1Vi3ID+dtFEucCEvGMOte0eapK+reb39O6hTHYxLfOgSGsbvbM5grJ4dYdNJjrzDf1r07tID4I2RJJRYjS4/DWBJn98Untd3xeNnXE1/45HgvwKVHlmZQLVwfWAxI7ifl2MVDwJlcXBufLZoVMYhUd4np121d7zFwAFGQzKyzUYQwI3M9Nqja9syHgaotG\""; mask = SecretDetector.MaskSecrets(snowFlakeAuthToken); Assert.IsTrue(mask.isMasked); - Assert.AreEqual(@"Authorization: Snowflake Token=""****""", mask.maskedText); + Assert.AreEqual(@"Authorization: Snowflake Token=****", mask.maskedText); Assert.IsNull(mask.errStr); } @@ -311,7 +376,7 @@ public void TestPasswords() string randomPasswordEqualSign = "password = " + randomPassword; mask = SecretDetector.MaskSecrets(randomPasswordEqualSign); Assert.IsTrue(mask.isMasked); - Assert.AreEqual(@"password = ****", mask.maskedText); + Assert.AreEqual(@"password =****", mask.maskedText); Assert.IsNull(mask.errStr); string randomPwdWithPrefix = "pwd:" + randomPassword; @@ -350,9 +415,7 @@ public void TestTokenPassword() mask = SecretDetector.MaskSecrets(testStringWithPrefix); Assert.IsTrue(mask.isMasked); Assert.AreEqual( - "token=****" + - " random giberish " + - "password:****", + "token=****", mask.maskedText); Assert.IsNull(mask.errStr); @@ -378,11 +441,7 @@ public void TestTokenPassword() mask = SecretDetector.MaskSecrets(testStringWithPrefix); Assert.IsTrue(mask.isMasked); Assert.AreEqual( - "token=****" + - " random giberish " + - "password:****" + - " random giberish " + - "idToken:****", + "token=****", mask.maskedText); Assert.IsNull(mask.errStr); @@ -393,10 +452,7 @@ public void TestTokenPassword() mask = SecretDetector.MaskSecrets(testStringWithPrefix); Assert.IsTrue(mask.isMasked); Assert.AreEqual( - "password=****" + - " random giberish " + - "pwd:****", - mask.maskedText); + "password=****", mask.maskedText); Assert.IsNull(mask.errStr); // multiple passwords @@ -408,15 +464,23 @@ public void TestTokenPassword() mask = SecretDetector.MaskSecrets(testStringWithPrefix); Assert.IsTrue(mask.isMasked); Assert.AreEqual( - "password=****" + - " random giberish " + - "password=****" + - " random giberish " + "password=****", mask.maskedText); Assert.IsNull(mask.errStr); } + [Test] + public void TestTokenProperty() + { + BasicMasking(@"somethingBefore=cccc;token=aa", @"somethingBefore=cccc;token=****"); + BasicMasking(@"somethingBefore=cccc;token=aa;somethingNext=bbbb", @"somethingBefore=cccc;token=****"); + BasicMasking(@"somethingBefore=cccc;token=""aa"";somethingNext=bbbb", @"somethingBefore=cccc;token=****"); + BasicMasking(@"somethingBefore=cccc;token=;somethingNext=bbbb", @"somethingBefore=cccc;token=****"); + BasicMasking(@"somethingBefore=cccc;token=", @"somethingBefore=cccc;token=****"); + BasicMasking(@"somethingBefore=cccc;token =aa;somethingNext=bbbb", @"somethingBefore=cccc;token =****"); + BasicMasking(@"somethingBefore=cccc;token="" 'aa", @"somethingBefore=cccc;token=****"); + } + [Test] public void TestCustomPattern() { diff --git a/Snowflake.Data.Tests/UnitTests/Session/ConnectionPoolConfigExtractorTest.cs b/Snowflake.Data.Tests/UnitTests/Session/ConnectionPoolConfigExtractorTest.cs new file mode 100644 index 000000000..0cc61f28b --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Session/ConnectionPoolConfigExtractorTest.cs @@ -0,0 +1,339 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core; +using Snowflake.Data.Core.Session; +using Snowflake.Data.Core.Tools; + +namespace Snowflake.Data.Tests.UnitTests.Session +{ + [TestFixture] + public class ConnectionPoolConfigExtractorTest + { + [Test] + public void TestExtractDefaultValues() + { + // arrange + var connectionString = "account=test;user=test;password=test;"; + + // act + var result = ExtractConnectionPoolConfig(connectionString); + + // assert + Assert.AreEqual(SFSessionHttpClientProperties.DefaultMaxPoolSize, result.MaxPoolSize, "max pool size"); + Assert.AreEqual(SFSessionHttpClientProperties.DefaultMinPoolSize, result.MinPoolSize, "min pool size"); + Assert.AreEqual(SFSessionHttpClientProperties.DefaultChangedSession, result.ChangedSession, "changed session"); + Assert.AreEqual(SFSessionHttpClientProperties.DefaultExpirationTimeout, result.ExpirationTimeout, "expiration timeout"); + Assert.AreEqual(SFSessionHttpClientProperties.DefaultWaitingForIdleSessionTimeout, result.WaitingForIdleSessionTimeout, "waiting for idle session timeout"); + Assert.AreEqual(SFSessionHttpClientProperties.DefaultConnectionTimeout, result.ConnectionTimeout, "connection timeout"); + Assert.AreEqual(SFSessionHttpClientProperties.DefaultPoolingEnabled, result.PoolingEnabled, "pooling enabled"); + } + + [Test] + public void TestExtractMaxPoolSize() + { + // arrange + var maxPoolSize = 15; + var connectionString = $"account=test;user=test;password=test;maxPoolSize={maxPoolSize}"; + + // act + var result = ExtractConnectionPoolConfig(connectionString); + + // assert + Assert.AreEqual(maxPoolSize, result.MaxPoolSize); + } + + [Test] + [TestCase("wrong_value")] + [TestCase("0")] + [TestCase("-1")] + public void TestExtractFailsForWrongValueOfMaxPoolSize(string maxPoolSize) + { + // arrange + var connectionString = $"account=test;user=test;password=test;maxPoolSize={maxPoolSize}"; + + // act + var thrown = Assert.Throws(() => ExtractConnectionPoolConfig(connectionString)); + + // assert + Assert.That(thrown.Message, Does.Contain($"Invalid value of parameter {SFSessionProperty.MAXPOOLSIZE.ToString()}")); + } + + [Test] + [TestCase("0", 0)] + [TestCase("7", 7)] + [TestCase("10", 10)] + public void TestExtractMinPoolSize(string propertyValue, int expectedMinPoolSize) + { + // arrange + var connectionString = $"account=test;user=test;password=test;minPoolSize={propertyValue}"; + + // act + var result = ExtractConnectionPoolConfig(connectionString); + + // assert + Assert.AreEqual(expectedMinPoolSize, result.MinPoolSize); + } + + [Test] + [TestCase("wrong_value")] + [TestCase("-1")] + public void TestExtractFailsForWrongValueOfMinPoolSize(string minPoolSize) + { + // arrange + var connectionString = $"account=test;user=test;password=test;minPoolSize={minPoolSize}"; + + // act + var thrown = Assert.Throws(() => ExtractConnectionPoolConfig(connectionString)); + + // assert + Assert.That(thrown.Message, Does.Contain($"Invalid value of parameter {SFSessionProperty.MINPOOLSIZE.ToString()}")); + } + + [Test] + public void TestExtractFailsWhenMinPoolSizeGreaterThanMaxPoolSize() + { + // arrange + var connectionString = $"account=test;user=test;password=test;minPoolSize=10;maxPoolSize=9"; + + // act + var thrown = Assert.Throws(() => ExtractConnectionPoolConfig(connectionString)); + + // assert + Assert.That(thrown.Message, Does.Contain("MinPoolSize cannot be greater than MaxPoolSize")); + } + + [Test] + [TestCaseSource(nameof(CorrectTimeoutsWithZeroUnchanged))] + public void TestExtractExpirationTimeout(TimeoutTestCase testCase) + { + // arrange + var connectionString = $"account=test;user=test;password=test;expirationTimeout={testCase.PropertyValue}"; + + // act + var result = ExtractConnectionPoolConfig(connectionString); + + // assert + Assert.AreEqual(testCase.ExpectedTimeout, result.ExpirationTimeout); + } + + [Test] + [TestCaseSource(nameof(IncorrectTimeouts))] + public void TestExtractExpirationTimeoutFailsWhenWrongValue(string propertyValue) + { + // arrange + var connectionString = $"account=test;user=test;password=test;expirationTimeout={propertyValue}"; + + // act + var thrown = Assert.Throws(() => ExtractConnectionPoolConfig(connectionString)); + + // assert + Assert.That(thrown.Message, Does.Contain($"Invalid value of parameter {SFSessionProperty.EXPIRATIONTIMEOUT.ToString()}")); + } + + [Test] + [TestCaseSource(nameof(PositiveTimeoutsAndZeroUnchanged))] + public void TestExtractWaitingForIdleSessionTimeout(TimeoutTestCase testCase) + { + // arrange + var connectionString = $"account=test;user=test;password=test;waitingForIdleSessionTimeout={testCase.PropertyValue}"; + + // act + var result = ExtractConnectionPoolConfig(connectionString); + + // assert + Assert.AreEqual(testCase.ExpectedTimeout, result.WaitingForIdleSessionTimeout); + } + + [Test] + public void TestExtractWaitingForIdleSessionTimeoutFailsForInfiniteTimeout() + { + // arrange + var connectionString = $"account=test;user=test;password=test;waitingForIdleSessionTimeout=-1"; + + // act + var thrown = Assert.Throws(() => ExtractConnectionPoolConfig(connectionString)); + + // assert + Assert.That(thrown.Message, Does.Contain("Waiting for idle session timeout cannot be infinite")); + } + + [Test] + [TestCaseSource(nameof(IncorrectTimeouts))] + public void TestExtractWaitingForIdleSessionTimeoutFailsWhenWrongValue(string propertyValue) + { + // arrange + var connectionString = $"account=test;user=test;password=test;waitingForIdleSessionTimeout={propertyValue}"; + + // act + var thrown = Assert.Throws(() => ExtractConnectionPoolConfig(connectionString)); + + // assert + Assert.That(thrown.Message, Does.Contain($"Invalid value of parameter {SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT.ToString()}")); + } + + [Test] + [TestCaseSource(nameof(CorrectTimeoutsWithZeroAsInfinite))] + public void TestExtractConnectionTimeout(TimeoutTestCase testCase) + { + // arrange + var connectionString = $"account=test;user=test;password=test;CONNECTION_TIMEOUT={testCase.PropertyValue};RETRY_TIMEOUT=60m"; + + // act + var result = ExtractConnectionPoolConfig(connectionString); + + // assert + Assert.AreEqual(testCase.ExpectedTimeout, result.ConnectionTimeout); + } + + [Test] + [TestCaseSource(nameof(IncorrectTimeouts))] + public void TestExtractConnectionTimeoutFailsForWrongValue(string propertyValue) + { + // arrange + var connectionString = $"account=test;user=test;password=test;CONNECTION_TIMEOUT={propertyValue}"; + + // act + var thrown = Assert.Throws(() => ExtractConnectionPoolConfig(connectionString)); + + // assert + Assert.That(thrown.Message, Does.Contain($"Invalid value of parameter {SFSessionProperty.CONNECTION_TIMEOUT.ToString()}")); + } + + [Test] + [TestCase("true", true)] + [TestCase("TRUE", true)] + [TestCase("false", false)] + [TestCase("FALSE", false)] + public void TestExtractPoolingEnabled(string propertyValue, bool poolingEnabled) + { + // arrange + var connectionString = $"account=test;user=test;password=test;poolingEnabled={propertyValue}"; + + // act + var result = ExtractConnectionPoolConfig(connectionString); + + // assert + Assert.AreEqual(poolingEnabled, result.PoolingEnabled); + } + + [Test] + [TestCase("account=test;user=test;password=test;", true)] + [TestCase("authenticator=externalbrowser;account=test;user=test;", false)] + [TestCase("authenticator=externalbrowser;account=test;user=test;poolingEnabled=true;", true)] + [TestCase("authenticator=externalbrowser;account=test;user=test;poolingEnabled=false;", false)] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key_file=/some/file.key", false)] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key_file=/some/file.key;poolingEnabled=true;", true)] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key_file=/some/file.key;poolingEnabled=false;", false)] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key=secretKey", true)] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key=secretKey;poolingEnabled=true;", true)] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key=secretKey;poolingEnabled=false;", false)] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key_file=/some/file.key;private_key_pwd=secretPwd", true)] + [TestCase("authenticator=snowflake_jwt;account=test;user=test;private_key_file=/some/file.key;private_key_pwd=", false)] + public void TestDisablePoolingDefaultWhenSecretsProvidedExternally(string connectionString, bool poolingEnabled) + { + // act + var result = ExtractConnectionPoolConfig(connectionString); + + // assert + Assert.AreEqual(poolingEnabled, result.PoolingEnabled); + } + + [Test] + [TestCase("wrong_value")] + [TestCase("15")] + public void TestExtractFailsForWrongValueOfPoolingEnabled(string propertyValue) + { + // arrange + var connectionString = $"account=test;user=test;password=test;poolingEnabled={propertyValue}"; + + // act + var thrown = Assert.Throws(() => ExtractConnectionPoolConfig(connectionString)); + + // assert + Assert.That(thrown.Message, Does.Contain($"Invalid value of parameter {SFSessionProperty.POOLINGENABLED.ToString()}")); + } + + [Test] + [TestCase("OriginalPool", ChangedSessionBehavior.OriginalPool)] + [TestCase("originalpool", ChangedSessionBehavior.OriginalPool)] + [TestCase("ORIGINALPOOL", ChangedSessionBehavior.OriginalPool)] + [TestCase("Destroy", ChangedSessionBehavior.Destroy)] + [TestCase("DESTROY", ChangedSessionBehavior.Destroy)] + public void TestExtractChangedSessionBehaviour(string propertyValue, ChangedSessionBehavior expectedChangedSession) + { + // arrange + var connectionString = $"account=test;user=test;password=test;changedSession={propertyValue}"; + + // act + var result = ExtractConnectionPoolConfig(connectionString); + + // assert + Assert.AreEqual(expectedChangedSession, result.ChangedSession); + } + + private ConnectionPoolConfig ExtractConnectionPoolConfig(string connectionString) => + SessionPool.ExtractConfig(connectionString, null).Item1; + + public class TimeoutTestCase + { + public string PropertyValue { get; } + public TimeSpan ExpectedTimeout { get; } + + public TimeoutTestCase(string propertyValue, TimeSpan expectedTimeout) + { + PropertyValue = propertyValue; + ExpectedTimeout = expectedTimeout; + } + } + + public static IEnumerable CorrectTimeoutsWithZeroUnchanged() => + CorrectTimeoutsWithoutZero().Concat(ZeroUnchangedTimeouts()); + + public static IEnumerable CorrectTimeoutsWithZeroAsInfinite() => + CorrectTimeoutsWithoutZero().Concat(ZeroAsInfiniteTimeouts()); + + public static IEnumerable PositiveTimeoutsAndZeroUnchanged() => + PositiveTimeouts().Concat(ZeroUnchangedTimeouts()); + + private static IEnumerable CorrectTimeoutsWithoutZero() => + NegativeAsInfinityTimeouts().Concat(PositiveTimeouts()); + + private static IEnumerable NegativeAsInfinityTimeouts() + { + yield return new TimeoutTestCase("-1", TimeoutHelper.Infinity()); + } + + private static IEnumerable PositiveTimeouts() + { + yield return new TimeoutTestCase("5", TimeSpan.FromSeconds(5)); + yield return new TimeoutTestCase("6s", TimeSpan.FromSeconds(6)); + yield return new TimeoutTestCase("7S", TimeSpan.FromSeconds(7)); + yield return new TimeoutTestCase("8m", TimeSpan.FromMinutes(8)); + yield return new TimeoutTestCase("9M", TimeSpan.FromMinutes(9)); + yield return new TimeoutTestCase("10ms", TimeSpan.FromMilliseconds(10)); + yield return new TimeoutTestCase("11ms", TimeSpan.FromMilliseconds(11)); + } + + private static IEnumerable ZeroAsInfiniteTimeouts() + { + yield return new TimeoutTestCase("0", TimeoutHelper.Infinity()); + yield return new TimeoutTestCase("0ms", TimeoutHelper.Infinity()); + } + + private static IEnumerable ZeroUnchangedTimeouts() + { + yield return new TimeoutTestCase("0", TimeSpan.Zero); + yield return new TimeoutTestCase("0ms", TimeSpan.Zero); + } + + public static IEnumerable IncorrectTimeouts() + { + yield return "wrong value"; + yield return "1h"; + yield return "1s1s"; + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Session/FixedZeroCounterTest.cs b/Snowflake.Data.Tests/UnitTests/Session/FixedZeroCounterTest.cs new file mode 100644 index 000000000..fd04be9af --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Session/FixedZeroCounterTest.cs @@ -0,0 +1,61 @@ +using NUnit.Framework; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Tests.UnitTests.Session +{ + [TestFixture] + public class FixedZeroCounterTest + { + [Test] + public void TestInitialZero() + { + // arrange + var counter = new FixedZeroCounter(); + + // act + var count = counter.Count(); + + // assert + Assert.AreEqual(0, count); + } + + [Test] + public void TestZeroAfterIncrease() + { + // arrange + var counter = new FixedZeroCounter(); + + // act + counter.Increase(); + + // assert + Assert.AreEqual(0, counter.Count()); + } + + [Test] + public void TestZeroAfterDecrease() + { + // arrange + var counter = new FixedZeroCounter(); + + // act + counter.Decrease(); + + // assert + Assert.AreEqual(0, counter.Count()); + } + + [Test] + public void TestZeroAfterReset() + { + // arrange + var counter = new FixedZeroCounter(); + + // act + counter.Reset(); + + // assert + Assert.AreEqual(0, counter.Count()); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Session/NonCountingSessionCreationTokenCounterTest.cs b/Snowflake.Data.Tests/UnitTests/Session/NonCountingSessionCreationTokenCounterTest.cs new file mode 100644 index 000000000..9f73c7c7a --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Session/NonCountingSessionCreationTokenCounterTest.cs @@ -0,0 +1,52 @@ +using NUnit.Framework; +using Snowflake.Data.Core; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Tests.UnitTests.Session +{ + [TestFixture] + public class NonCountingSessionCreationTokenCounterTest + { + [Test] + public void TestGrantSessionCreation() + { + // arrange + var tokens = new NonCountingSessionCreationTokenCounter(); + + // act + tokens.NewToken(); + + // assert + Assert.AreEqual(0, tokens.Count()); + } + + [Test] + public void TestCompleteSessionCreation() + { + // arrange + var tokens = new NonCountingSessionCreationTokenCounter(); + var token = tokens.NewToken(); + + // act + tokens.RemoveToken(token); + + // assert + Assert.AreEqual(0, tokens.Count()); + } + + [Test] + public void TestCompleteUnknownTokenDoesNotThrowExceptions() + { + // arrange + var tokens = new NonCountingSessionCreationTokenCounter(); + tokens.NewToken(); + var unknownToken = new SessionCreationToken(SFSessionHttpClientProperties.DefaultConnectionTimeout); + + // act + tokens.RemoveToken(unknownToken); + + // assert + Assert.AreEqual(0, tokens.Count()); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Session/NonNegativeCounterTest.cs b/Snowflake.Data.Tests/UnitTests/Session/NonNegativeCounterTest.cs new file mode 100644 index 000000000..638299532 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Session/NonNegativeCounterTest.cs @@ -0,0 +1,91 @@ +using NUnit.Framework; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Tests.UnitTests.Session +{ + [TestFixture] + public class NonNegativeCounterTest + { + [Test] + public void TestInitialZero() + { + // arrange + var counter = new NonNegativeCounter(); + + // act + var count = counter.Count(); + + // assert + Assert.AreEqual(0, count); + } + + [Test] + public void TestIncrease() + { + // arrange + var counter = new NonNegativeCounter(); + + // act + counter.Increase(); + + // assert + Assert.AreEqual(1, counter.Count()); + + // act + counter.Increase(); + + // assert + Assert.AreEqual(2, counter.Count()); + } + + + [Test] + public void TestDecrease() + { + // arrange + var counter = new NonNegativeCounter(); + counter.Increase(); + counter.Increase(); + + // act + counter.Decrease(); + + // assert + Assert.AreEqual(1, counter.Count()); + + // act + counter.Decrease(); + + // assert + Assert.AreEqual(0, counter.Count()); + } + + [Test] + public void TestDecreaseDoesNotGoBelowZero() + { + // arrange + var counter = new NonNegativeCounter(); + + // act + counter.Decrease(); + + // assert + Assert.AreEqual(0, counter.Count()); + } + + [Test] + public void TestReset() + { + // arrange + var counter = new NonNegativeCounter(); + counter.Increase(); + counter.Increase(); + + // act + counter.Reset(); + + // assert + Assert.AreEqual(0, counter.Count()); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Session/NonWaitingQueueTest.cs b/Snowflake.Data.Tests/UnitTests/Session/NonWaitingQueueTest.cs new file mode 100644 index 000000000..4d519b3a9 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Session/NonWaitingQueueTest.cs @@ -0,0 +1,65 @@ +using System.Diagnostics; +using System.Threading; +using NUnit.Framework; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Tests.UnitTests.Session +{ + [TestFixture] + public class NonWaitingQueueTest + { + [Test] + public void TestWaitDoesNotHangAndReturnsFalse() + { + // arrange + var nonWaitingQueue = new NonWaitingQueue(); + var watch = new Stopwatch(); + + // act + watch.Start(); + var result = nonWaitingQueue.Wait(10000, CancellationToken.None); + watch.Stop(); + + // assert + Assert.IsFalse(result); + Assert.LessOrEqual(watch.ElapsedMilliseconds, 50); + } + + [Test] + public void TestNoOneIsWaiting() + { + // arrange + var nonWaitingQueue = new NonWaitingQueue(); + nonWaitingQueue.Wait(10000, CancellationToken.None); + + // act + var isAnyoneWaiting = nonWaitingQueue.IsAnyoneWaiting(); + + // assert + Assert.IsFalse(isAnyoneWaiting); + } + + [Test] + public void TestWaitingDisabled() + { + // arrange + var nonWaitingQueue = new NonWaitingQueue(); + + // act + var isWaitingEnabled = nonWaitingQueue.IsWaitingEnabled(); + + // assert + Assert.IsFalse(isWaitingEnabled); + } + + [Test] + public void TestReset() + { + // arrange + var nonWaitingQueue = new NonWaitingQueue(); + + // act/assert + Assert.DoesNotThrow(() => nonWaitingQueue.Reset()); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Session/SFHttpClientPropertiesTest.cs b/Snowflake.Data.Tests/UnitTests/Session/SFHttpClientPropertiesTest.cs index 617e3d429..18f1ff7d7 100644 --- a/Snowflake.Data.Tests/UnitTests/Session/SFHttpClientPropertiesTest.cs +++ b/Snowflake.Data.Tests/UnitTests/Session/SFHttpClientPropertiesTest.cs @@ -2,15 +2,15 @@ * Copyright (c) 2023 Snowflake Computing Inc. All rights reserved. */ +using System; using System.Collections.Generic; -using Moq; using NUnit.Framework; using Snowflake.Data.Core; +using Snowflake.Data.Core.Tools; using Snowflake.Data.Tests.Util; namespace Snowflake.Data.Tests.UnitTests.Session { - [TestFixture] public class SFHttpClientPropertiesTest { @@ -32,18 +32,18 @@ public void TestConvertToMapOnly2Properties( { validateDefaultParameters = validateDefaultParameters, clientSessionKeepAlive = clientSessionKeepAlive, - timeoutInSec = SFSessionHttpClientProperties.s_retryTimeoutDefault, + connectionTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, insecureMode = false, disableRetry = false, forceRetryOn404 = false, - retryTimeout = SFSessionHttpClientProperties.s_retryTimeoutDefault, + retryTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, maxHttpRetries = 7, proxyProperties = proxyProperties }; - + // act var parameterMap = properties.ToParameterMap(); - + // assert Assert.AreEqual(2, parameterMap.Count); Assert.AreEqual(validateDefaultParameters, parameterMap[SFSessionParameter.CLIENT_VALIDATE_DEFAULT_PARAMETERS]); @@ -55,7 +55,7 @@ public void TestBuildHttpClientConfig() { // arrange var properties = RandomSFSessionHttpClientProperties(); - + // act var config = properties.BuildHttpClientConfig(); @@ -80,11 +80,11 @@ public void TestCrlCheckEnabledToBeOppositeInsecureMode([Values] bool insecureMo // act var config = properties.BuildHttpClientConfig(); - + // assert Assert.AreEqual(!insecureMode, config.CrlCheckEnabled); } - + private SFSessionHttpClientProperties RandomSFSessionHttpClientProperties() { var proxyProperties = new SFSessionHttpClientProxyProperties() @@ -99,11 +99,11 @@ private SFSessionHttpClientProperties RandomSFSessionHttpClientProperties() { validateDefaultParameters = TestDataGenarator.NextBool(), clientSessionKeepAlive = TestDataGenarator.NextBool(), - timeoutInSec = TestDataGenarator.NextInt(30, 151), + connectionTimeout = TimeSpan.FromSeconds(TestDataGenarator.NextInt(30, 151)), insecureMode = TestDataGenarator.NextBool(), disableRetry = TestDataGenarator.NextBool(), forceRetryOn404 = TestDataGenarator.NextBool(), - retryTimeout = TestDataGenarator.NextInt(300, 600), + retryTimeout = TimeSpan.FromSeconds(TestDataGenarator.NextInt(300, 600)), maxHttpRetries = TestDataGenarator.NextInt(0, 15), proxyProperties = proxyProperties }; @@ -113,31 +113,24 @@ private SFSessionHttpClientProperties RandomSFSessionHttpClientProperties() public void TestExtractProperties(PropertiesTestCase testCase) { // arrange - var proxyExtractorMock = new Moq.Mock(); - var extractor = new SFSessionHttpClientProperties.Extractor(proxyExtractorMock.Object); var properties = SFSessionProperties.ParseConnectionString(testCase.conectionString, null); var proxyProperties = new SFSessionHttpClientProxyProperties(); - proxyExtractorMock - .Setup(e => e.ExtractProperties(properties)) - .Returns(proxyProperties); // act - var extractedProperties = extractor.ExtractProperties(properties); - extractedProperties.CheckPropertiesAreValid(); + var extractedProperties = SFSessionHttpClientProperties.ExtractAndValidate(properties); // assert Assert.AreEqual(testCase.expectedProperties.validateDefaultParameters, extractedProperties.validateDefaultParameters); Assert.AreEqual(testCase.expectedProperties.clientSessionKeepAlive, extractedProperties.clientSessionKeepAlive); - Assert.AreEqual(testCase.expectedProperties.timeoutInSec, extractedProperties.timeoutInSec); + Assert.AreEqual(testCase.expectedProperties.connectionTimeout, extractedProperties.connectionTimeout); Assert.AreEqual(testCase.expectedProperties.insecureMode, extractedProperties.insecureMode); Assert.AreEqual(testCase.expectedProperties.disableRetry, extractedProperties.disableRetry); Assert.AreEqual(testCase.expectedProperties.forceRetryOn404, extractedProperties.forceRetryOn404); Assert.AreEqual(testCase.expectedProperties.retryTimeout, extractedProperties.retryTimeout); Assert.AreEqual(testCase.expectedProperties.maxHttpRetries, extractedProperties.maxHttpRetries); - Assert.AreEqual(proxyProperties, extractedProperties.proxyProperties); - proxyExtractorMock.Verify(e => e.ExtractProperties(properties), Times.Once); + Assert.NotNull(proxyProperties); } - + public static IEnumerable PropertiesProvider() { var defaultProperties = new PropertiesTestCase() @@ -147,12 +140,12 @@ public static IEnumerable PropertiesProvider() { validateDefaultParameters = true, clientSessionKeepAlive = false, - timeoutInSec = SFSessionHttpClientProperties.s_retryTimeoutDefault, + connectionTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, insecureMode = false, disableRetry = false, forceRetryOn404 = false, - retryTimeout = SFSessionHttpClientProperties.s_retryTimeoutDefault, - maxHttpRetries = SFSessionHttpClientProperties.s_maxHttpRetriesDefault + retryTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, + maxHttpRetries = SFSessionHttpClientProperties.DefaultMaxHttpRetries } }; var propertiesWithValidateDefaultParametersChanged = new PropertiesTestCase() @@ -162,12 +155,12 @@ public static IEnumerable PropertiesProvider() { validateDefaultParameters = false, clientSessionKeepAlive = false, - timeoutInSec = SFSessionHttpClientProperties.s_retryTimeoutDefault, + connectionTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, insecureMode = false, disableRetry = false, forceRetryOn404 = false, - retryTimeout = SFSessionHttpClientProperties.s_retryTimeoutDefault, - maxHttpRetries = SFSessionHttpClientProperties.s_maxHttpRetriesDefault + retryTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, + maxHttpRetries = SFSessionHttpClientProperties.DefaultMaxHttpRetries } }; var propertiesWithClientSessionKeepAliveChanged = new PropertiesTestCase() @@ -177,12 +170,12 @@ public static IEnumerable PropertiesProvider() { validateDefaultParameters = true, clientSessionKeepAlive = true, - timeoutInSec = SFSessionHttpClientProperties.s_retryTimeoutDefault, + connectionTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, insecureMode = false, disableRetry = false, forceRetryOn404 = false, - retryTimeout = SFSessionHttpClientProperties.s_retryTimeoutDefault, - maxHttpRetries = SFSessionHttpClientProperties.s_maxHttpRetriesDefault + retryTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, + maxHttpRetries = SFSessionHttpClientProperties.DefaultMaxHttpRetries } }; var propertiesWithTimeoutChanged = new PropertiesTestCase() @@ -192,12 +185,12 @@ public static IEnumerable PropertiesProvider() { validateDefaultParameters = true, clientSessionKeepAlive = false, - timeoutInSec = 15, + connectionTimeout = TimeSpan.FromSeconds(15), insecureMode = false, disableRetry = false, forceRetryOn404 = false, - retryTimeout = SFSessionHttpClientProperties.s_retryTimeoutDefault, - maxHttpRetries = SFSessionHttpClientProperties.s_maxHttpRetriesDefault + retryTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, + maxHttpRetries = SFSessionHttpClientProperties.DefaultMaxHttpRetries } }; var propertiesWithInsecureModeChanged = new PropertiesTestCase() @@ -207,12 +200,12 @@ public static IEnumerable PropertiesProvider() { validateDefaultParameters = true, clientSessionKeepAlive = false, - timeoutInSec = SFSessionHttpClientProperties.s_retryTimeoutDefault, + connectionTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, insecureMode = true, disableRetry = false, forceRetryOn404 = false, - retryTimeout = SFSessionHttpClientProperties.s_retryTimeoutDefault, - maxHttpRetries = SFSessionHttpClientProperties.s_maxHttpRetriesDefault + retryTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, + maxHttpRetries = SFSessionHttpClientProperties.DefaultMaxHttpRetries } }; var propertiesWithDisableRetryChanged = new PropertiesTestCase() @@ -222,12 +215,12 @@ public static IEnumerable PropertiesProvider() { validateDefaultParameters = true, clientSessionKeepAlive = false, - timeoutInSec = SFSessionHttpClientProperties.s_retryTimeoutDefault, + connectionTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, insecureMode = false, disableRetry = true, forceRetryOn404 = false, - retryTimeout = SFSessionHttpClientProperties.s_retryTimeoutDefault, - maxHttpRetries = SFSessionHttpClientProperties.s_maxHttpRetriesDefault + retryTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, + maxHttpRetries = SFSessionHttpClientProperties.DefaultMaxHttpRetries } }; var propertiesWithForceRetryOn404Changed = new PropertiesTestCase() @@ -237,12 +230,12 @@ public static IEnumerable PropertiesProvider() { validateDefaultParameters = true, clientSessionKeepAlive = false, - timeoutInSec = SFSessionHttpClientProperties.s_retryTimeoutDefault, + connectionTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, insecureMode = false, disableRetry = false, forceRetryOn404 = true, - retryTimeout = SFSessionHttpClientProperties.s_retryTimeoutDefault, - maxHttpRetries = SFSessionHttpClientProperties.s_maxHttpRetriesDefault + retryTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, + maxHttpRetries = SFSessionHttpClientProperties.DefaultMaxHttpRetries } }; var propertiesWithRetryTimeoutChangedToAValueAbove300 = new PropertiesTestCase() @@ -252,12 +245,12 @@ public static IEnumerable PropertiesProvider() { validateDefaultParameters = true, clientSessionKeepAlive = false, - timeoutInSec = SFSessionHttpClientProperties.s_retryTimeoutDefault, + connectionTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, insecureMode = false, disableRetry = false, forceRetryOn404 = false, - retryTimeout = 600, - maxHttpRetries = SFSessionHttpClientProperties.s_maxHttpRetriesDefault + retryTimeout = TimeSpan.FromSeconds(600), + maxHttpRetries = SFSessionHttpClientProperties.DefaultMaxHttpRetries } }; var propertiesWithRetryTimeoutChangedToAValueBelow300 = new PropertiesTestCase() @@ -267,12 +260,12 @@ public static IEnumerable PropertiesProvider() { validateDefaultParameters = true, clientSessionKeepAlive = false, - timeoutInSec = SFSessionHttpClientProperties.s_retryTimeoutDefault, + connectionTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, insecureMode = false, disableRetry = false, forceRetryOn404 = false, - retryTimeout = SFSessionHttpClientProperties.s_retryTimeoutDefault, - maxHttpRetries = SFSessionHttpClientProperties.s_maxHttpRetriesDefault + retryTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, + maxHttpRetries = SFSessionHttpClientProperties.DefaultMaxHttpRetries } }; var propertiesWithRetryTimeoutChangedToZero = new PropertiesTestCase() @@ -282,12 +275,12 @@ public static IEnumerable PropertiesProvider() { validateDefaultParameters = true, clientSessionKeepAlive = false, - timeoutInSec = 0, + connectionTimeout = SFSessionHttpClientProperties.DefaultConnectionTimeout, insecureMode = false, disableRetry = false, forceRetryOn404 = false, - retryTimeout = 0, - maxHttpRetries = SFSessionHttpClientProperties.s_maxHttpRetriesDefault + retryTimeout = TimeoutHelper.Infinity(), + maxHttpRetries = SFSessionHttpClientProperties.DefaultMaxHttpRetries } }; var propertiesWithMaxHttpRetriesChangedToAValueAbove7 = new PropertiesTestCase() @@ -297,11 +290,11 @@ public static IEnumerable PropertiesProvider() { validateDefaultParameters = true, clientSessionKeepAlive = false, - timeoutInSec = SFSessionHttpClientProperties.s_retryTimeoutDefault, + connectionTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, insecureMode = false, disableRetry = false, forceRetryOn404 = false, - retryTimeout = SFSessionHttpClientProperties.s_retryTimeoutDefault, + retryTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, maxHttpRetries = 10 } }; @@ -312,12 +305,12 @@ public static IEnumerable PropertiesProvider() { validateDefaultParameters = true, clientSessionKeepAlive = false, - timeoutInSec = SFSessionHttpClientProperties.s_retryTimeoutDefault, + connectionTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, insecureMode = false, disableRetry = false, forceRetryOn404 = false, - retryTimeout = SFSessionHttpClientProperties.s_retryTimeoutDefault, - maxHttpRetries = SFSessionHttpClientProperties.s_maxHttpRetriesDefault + retryTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, + maxHttpRetries = SFSessionHttpClientProperties.DefaultMaxHttpRetries } }; var propertiesWithMaxHttpRetriesChangedToZero = new PropertiesTestCase() @@ -327,11 +320,11 @@ public static IEnumerable PropertiesProvider() { validateDefaultParameters = true, clientSessionKeepAlive = false, - timeoutInSec = SFSessionHttpClientProperties.s_retryTimeoutDefault, + connectionTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, insecureMode = false, disableRetry = false, forceRetryOn404 = false, - retryTimeout = SFSessionHttpClientProperties.s_retryTimeoutDefault, + retryTimeout = SFSessionHttpClientProperties.DefaultRetryTimeout, maxHttpRetries = 0 } }; diff --git a/Snowflake.Data.Tests/UnitTests/Session/SessionCreationTokenCounterTest.cs b/Snowflake.Data.Tests/UnitTests/Session/SessionCreationTokenCounterTest.cs new file mode 100644 index 000000000..ee825b44a --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Session/SessionCreationTokenCounterTest.cs @@ -0,0 +1,103 @@ +using System; +using System.Threading; +using NUnit.Framework; +using Snowflake.Data.Core; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Tests.UnitTests.Session +{ + [TestFixture] + public class SessionCreationTokenCounterTest + { + private static readonly TimeSpan s_longTime = TimeSpan.FromSeconds(30); + private static readonly TimeSpan s_shortTime = TimeSpan.FromMilliseconds(50); + + [Test] + public void TestGrantSessionCreation() + { + // arrange + var tokens = new SessionCreationTokenCounter(s_longTime); + + // act + tokens.NewToken(); + + // assert + Assert.AreEqual(1, tokens.Count()); + + // act + tokens.NewToken(); + + // assert + Assert.AreEqual(2, tokens.Count()); + } + + [Test] + public void TestCompleteSessionCreation() + { + // arrange + var tokens = new SessionCreationTokenCounter(s_longTime); + var token1 = tokens.NewToken(); + var token2 = tokens.NewToken(); + + // act + tokens.RemoveToken(token1); + + // assert + Assert.AreEqual(1, tokens.Count()); + + // act + tokens.RemoveToken(token2); + + // assert + Assert.AreEqual(0, tokens.Count()); + } + + [Test] + public void TestCompleteUnknownTokenDoesNotThrowExceptions() + { + // arrange + var tokens = new SessionCreationTokenCounter(s_longTime); + tokens.NewToken(); + var unknownToken = new SessionCreationToken(SFSessionHttpClientProperties.DefaultConnectionTimeout); + + // act + tokens.RemoveToken(unknownToken); + + // assert + Assert.AreEqual(1, tokens.Count()); + } + + [Test] + public void TestCompleteCleansExpiredTokens() + { + // arrange + var tokens = new SessionCreationTokenCounter(s_shortTime); + var token = tokens.NewToken(); + tokens.NewToken(); // this token will be cleaned because of expiration + Assert.AreEqual(2, tokens.Count()); + const int EpsilonMillis = 5; + Thread.Sleep((int) s_shortTime.TotalMilliseconds + EpsilonMillis); + + // act + tokens.RemoveToken(token); + + // assert + Assert.AreEqual(0, tokens.Count()); + } + + [Test] + public void TestResetTokens() + { + // arrange + var tokens = new SessionCreationTokenCounter(s_longTime); + tokens.NewToken(); + tokens.NewToken(); + + // act + tokens.Reset(); + + // assert + Assert.AreEqual(0, tokens.Count()); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Session/SessionCreationTokenTest.cs b/Snowflake.Data.Tests/UnitTests/Session/SessionCreationTokenTest.cs new file mode 100644 index 000000000..13b45b9b1 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Session/SessionCreationTokenTest.cs @@ -0,0 +1,39 @@ +using System; +using NUnit.Framework; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Tests.UnitTests.Session +{ + [TestFixture] + public class SessionCreationTokenTest + { + private static readonly TimeSpan s_timeout30Seconds = TimeSpan.FromSeconds(30); + + [Test] + public void TestTokenIsNotExpired() + { + // arrange + var token = new SessionCreationToken(s_timeout30Seconds); + + // act + var isExpired = token.IsExpired(DateTimeOffset.UtcNow.ToUnixTimeMilliseconds()); + + // assert + Assert.IsFalse(isExpired); + } + + [Test] + public void TestTokenIsExpired() + { + // arrange + var token = new SessionCreationToken(s_timeout30Seconds); + var timeout30SecondsAsMillis = (long) s_timeout30Seconds.TotalMilliseconds; + + // act + var isExpired = token.IsExpired(DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() + timeout30SecondsAsMillis + 1); + + // assert + Assert.IsTrue(isExpired); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Session/SessionOrCreationTokensTest.cs b/Snowflake.Data.Tests/UnitTests/Session/SessionOrCreationTokensTest.cs new file mode 100644 index 000000000..7d2b1a603 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Session/SessionOrCreationTokensTest.cs @@ -0,0 +1,62 @@ +using System; +using System.Linq; +using NUnit.Framework; +using Snowflake.Data.Core; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Tests.UnitTests.Session +{ + [TestFixture] + public class SessionOrCreationTokensTest + { + private SFSession _session = new SFSession("account=test;user=test;password=test", null); + + [Test] + public void TestNoBackgroundSessionsToCreateWhenInitialisedWithSession() + { + // arrange + var sessionOrTokens = new SessionOrCreationTokens(_session); + + // act + var backgroundCreationTokens = sessionOrTokens.BackgroundSessionCreationTokens(); + + Assert.AreEqual(0, backgroundCreationTokens.Count); + } + + [Test] + public void TestReturnFirstCreationToken() + { + // arrange + var sessionCreationTokenCounter = new SessionCreationTokenCounter(TimeSpan.FromSeconds(10)); + var tokens = Enumerable.Range(1, 3) + .Select(_ => sessionCreationTokenCounter.NewToken()) + .ToList(); + var sessionOrTokens = new SessionOrCreationTokens(tokens); + + // act + var token = sessionOrTokens.SessionCreationToken(); + + // assert + Assert.AreSame(tokens[0], token); + } + + [Test] + public void TestReturnCreationTokensFromTheSecondOneForBackgroundExecution() + { + // arrange + var sessionCreationTokenCounter = new SessionCreationTokenCounter(TimeSpan.FromSeconds(10)); + var tokens = Enumerable.Range(1, 3) + .Select(_ => sessionCreationTokenCounter.NewToken()) + .ToList(); + var sessionOrTokens = new SessionOrCreationTokens(tokens); + + // act + var backgroundTokens = sessionOrTokens.BackgroundSessionCreationTokens(); + + // assert + Assert.AreEqual(2, backgroundTokens.Count); + Assert.AreSame(tokens[1], backgroundTokens[0]); + Assert.AreSame(tokens[2], backgroundTokens[1]); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs b/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs new file mode 100644 index 000000000..fca8f7de1 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs @@ -0,0 +1,170 @@ +using System; +using System.Net; +using System.Text.RegularExpressions; +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core; +using Snowflake.Data.Core.Session; +using Snowflake.Data.Core.Tools; +using Snowflake.Data.Tests.Util; + +namespace Snowflake.Data.Tests.UnitTests.Session +{ + [TestFixture] + public class SessionPoolTest + { + private const string ConnectionString = "ACCOUNT=testaccount;USER=testuser;PASSWORD=testpassword;"; + + [Test] + public void TestPoolParametersAreNotOverriden() + { + // act + var pool = SessionPool.CreateSessionPool(ConnectionString, null); + + // assert + Assert.IsFalse(pool.IsConfigOverridden()); + } + + [Test] + public void TestOverrideMaxPoolSize() + { + // arrange + var pool = SessionPool.CreateSessionPool(ConnectionString, null); + var newMaxPoolSize = 15; + + // act + pool.SetMaxPoolSize(newMaxPoolSize); + + // assert + Assert.AreEqual(newMaxPoolSize, pool.GetMaxPoolSize()); + Assert.IsTrue(pool.IsConfigOverridden()); + } + + [Test] + public void TestOverrideExpirationTimeout() + { + // arrange + var pool = SessionPool.CreateSessionPool(ConnectionString, null); + var newExpirationTimeoutSeconds = 15; + + // act + pool.SetTimeout(newExpirationTimeoutSeconds); + + // assert + Assert.AreEqual(newExpirationTimeoutSeconds, pool.GetTimeout()); + Assert.IsTrue(pool.IsConfigOverridden()); + } + + [Test] + public void TestOverrideSetPooling() + { + // arrange + var pool = SessionPool.CreateSessionPool(ConnectionString, null); + + // act + pool.SetPooling(false); + + // assert + Assert.IsFalse(pool.GetPooling()); + Assert.IsTrue(pool.IsConfigOverridden()); + } + + [Test] + [TestCase("account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443", "somePassword", " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;user=SomeUser;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;user=SomeUser;private_key=SomePrivateKey;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;user=SomeUser;token=someToken;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;user=SomeUser;private_key_pwd=somePrivateKeyPwd;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("account=someAccount;db=someDb;host=someHost;password=somePassword;user=SomeUser;proxyPassword=someProxyPassword;port=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("ACCOUNT=someAccount;DB=someDb;HOST=someHost;PASSWORD=somePassword;USER=SomeUser;PORT=443", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + [TestCase("ACCOUNT=\"someAccount\";DB=\"someDb\";HOST=\"someHost\";PASSWORD=\"somePassword\";USER=\"SomeUser\";PORT=\"443\"", null, " [pool: account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443;]")] + public void TestPoolIdentificationBasedOnConnectionString(string connectionString, string password, string expectedPoolIdentification) + { + // arrange + var securePassword = password == null ? null : new NetworkCredential("", password).SecurePassword; + var pool = SessionPool.CreateSessionPool(connectionString, securePassword); + + // act + var poolIdentification = pool.PoolIdentificationBasedOnConnectionString; + + // assert + Assert.AreEqual(expectedPoolIdentification, poolIdentification); + } + + [Test] + public void TestRetrievePoolFailureForInvalidConnectionString() + { + // arrange + var invalidConnectionString = "account=someAccount;db=someDb;host=someHost;user=SomeUser;port=443"; // invalid because password is not provided + + // act + var exception = Assert.Throws(() => SessionPool.CreateSessionPool(invalidConnectionString, null)); + + // assert + SnowflakeDbExceptionAssert.HasErrorCode(exception, SFError.MISSING_CONNECTION_PROPERTY); + Assert.IsTrue(exception.Message.Contains("Required property PASSWORD is not provided")); + } + + [Test] + public void TestPoolIdentificationBasedOnInternalId() + { + // arrange + var connectionString = "account=someAccount;db=someDb;host=someHost;password=somePassword;user=SomeUser;port=443"; + var pool = SessionPool.CreateSessionPool(connectionString, null); + var poolIdRegex = new Regex(@"^ \[pool: [0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\]$"); + + // act + var poolIdentification = pool.PoolIdentificationBasedOnInternalId; + + // assert + Assert.IsTrue(poolIdRegex.IsMatch(poolIdentification)); + } + + [Test] + public void TestPoolIdentificationForOldPool() + { + // arrange + var pool = SessionPool.CreateSessionCache(); + + // act + var poolIdentification = pool.PoolIdentification(); + + // assert + Assert.AreEqual("", poolIdentification); + } + + [Test] + [TestCase(null)] + [TestCase("")] + [TestCase("anyPassword")] + public void TestValidateValidSecurePassword(string password) + { + // arrange + var securePassword = password == null ? null : SecureStringHelper.Encode(password); + var pool = SessionPool.CreateSessionPool(ConnectionString, securePassword); + + // act + Assert.DoesNotThrow(() => pool.ValidateSecurePassword(securePassword)); + } + + [Test] + [TestCase("somePassword", null)] + [TestCase("somePassword", "")] + [TestCase("somePassword", "anotherPassword")] + [TestCase("", "anotherPassword")] + [TestCase(null, "anotherPassword")] + public void TestFailToValidateNotMatchingSecurePassword(string poolPassword, string notMatchingPassword) + { + // arrange + var poolSecurePassword = poolPassword == null ? null : SecureStringHelper.Encode(poolPassword); + var notMatchingSecurePassword = notMatchingPassword == null ? null : SecureStringHelper.Encode(notMatchingPassword); + var pool = SessionPool.CreateSessionPool(ConnectionString, poolSecurePassword); + + // act + var thrown = Assert.Throws(() => pool.ValidateSecurePassword(notMatchingSecurePassword)); + + // assert + Assert.That(thrown.Message, Does.Contain("Could not get a pool because of password mismatch")); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Session/SessionPropertiesWithDefaultValuesExtractorTest.cs b/Snowflake.Data.Tests/UnitTests/Session/SessionPropertiesWithDefaultValuesExtractorTest.cs new file mode 100644 index 000000000..3192c083e --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Session/SessionPropertiesWithDefaultValuesExtractorTest.cs @@ -0,0 +1,213 @@ +using System; +using NUnit.Framework; +using Snowflake.Data.Core; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Tests.UnitTests.Session +{ + [TestFixture] + public class SessionPropertiesWithDefaultValuesExtractorTest + { + [Test] + public void TestReturnExtractedValue() + { + // arrange + var properties = SFSessionProperties.ParseConnectionString("account=test;user=test;password=test;connection_timeout=15", null); + var extractor = new SessionPropertiesWithDefaultValuesExtractor(properties, false); + + // act + var value = extractor.ExtractPropertyWithDefaultValue( + SFSessionProperty.CONNECTION_TIMEOUT, + int.Parse, + s => true, + i => true + ); + + // assert + Assert.AreEqual(15, value); + } + + [Test] + public void TestReturnDefaultValueWhenValueIsMissing( + [Values] bool failOnWrongValue) + { + // arrange + var properties = SFSessionProperties.ParseConnectionString($"account=test;user=test;password=test", null); + var extractor = new SessionPropertiesWithDefaultValuesExtractor(properties, false); + var defaultValue = GetDefaultIntSessionProperty(SFSessionProperty.CONNECTION_TIMEOUT); + + // act + var value = extractor.ExtractPropertyWithDefaultValue( + SFSessionProperty.CONNECTION_TIMEOUT, + int.Parse, + s => true, + i => true + ); + + // assert + Assert.AreEqual(defaultValue, value); + } + + [Test] + public void TestReturnDefaultValueWhenPreValidationFails() + { + // arrange + var properties = SFSessionProperties.ParseConnectionString("account=test;user=test;password=test;connection_timeout=15", null); + var extractor = new SessionPropertiesWithDefaultValuesExtractor(properties, false); + var defaultValue = GetDefaultIntSessionProperty(SFSessionProperty.CONNECTION_TIMEOUT); + + // act + var value = extractor.ExtractPropertyWithDefaultValue( + SFSessionProperty.CONNECTION_TIMEOUT, + int.Parse, + s => false, + i => true + ); + + // assert + Assert.AreEqual(defaultValue, value); + } + + [Test] + public void TestFailForPropertyWithInvalidDefaultValue() + { + // arrange + var properties = SFSessionProperties.ParseConnectionString("account=test;user=test;password=test;", null); + var extractor = new SessionPropertiesWithDefaultValuesExtractor(properties, false); + + // act + var thrown = Assert.Throws(() => extractor.ExtractPropertyWithDefaultValue( + SFSessionProperty.CONNECTION_TIMEOUT, + s => s, + s => true, + s => false)); + + // assert + Assert.That(thrown.Message, Does.Contain("Invalid default value of CONNECTION_TIMEOUT")); + } + + [Test] + public void TestReturnDefaultValueForNullProperty() + { + // arrange + var properties = SFSessionProperties.ParseConnectionString("account=test;user=test;password=test;", null); + properties[SFSessionProperty.CONNECTION_TIMEOUT] = null; + var extractor = new SessionPropertiesWithDefaultValuesExtractor(properties, false); + var defaultValue = GetDefaultIntSessionProperty(SFSessionProperty.CONNECTION_TIMEOUT); + + // act + var value = extractor.ExtractPropertyWithDefaultValue( + SFSessionProperty.CONNECTION_TIMEOUT, + int.Parse, + s => true, + i => true); + + // assert + Assert.AreEqual(defaultValue, value); + } + + [Test] + public void TestReturnDefaultValueWhenPostValidationFails() + { + // arrange + var properties = SFSessionProperties.ParseConnectionString("account=test;user=test;password=test;connection_timeout=15", null); + var extractor = new SessionPropertiesWithDefaultValuesExtractor(properties, false); + var defaultValue = GetDefaultIntSessionProperty(SFSessionProperty.CONNECTION_TIMEOUT); + + // act + var value = extractor.ExtractPropertyWithDefaultValue( + SFSessionProperty.CONNECTION_TIMEOUT, + int.Parse, + s => true, + i => i == defaultValue + ); + + // assert + Assert.AreEqual(defaultValue, value); + } + + [Test] + public void TestReturnDefaultValueWhenExtractFails() + { + // arrange + var properties = SFSessionProperties.ParseConnectionString("account=test;user=test;password=test;connection_timeout=15X", null); + var extractor = new SessionPropertiesWithDefaultValuesExtractor(properties, false); + var defaultValue = GetDefaultIntSessionProperty(SFSessionProperty.CONNECTION_TIMEOUT); + + // act + var value = extractor.ExtractPropertyWithDefaultValue( + SFSessionProperty.CONNECTION_TIMEOUT, + int.Parse, + s => true, + i => true + ); + + // assert + Assert.AreEqual(defaultValue, value); + } + + [Test] + public void TestFailWhenPreValidationFails() + { + // arrange + var properties = SFSessionProperties.ParseConnectionString("account=test;user=test;password=test;connection_timeout=15", null); + var extractor = new SessionPropertiesWithDefaultValuesExtractor(properties, true); + + // act + var thrown = Assert.Throws(() => + extractor.ExtractPropertyWithDefaultValue( + SFSessionProperty.CONNECTION_TIMEOUT, + int.Parse, + s => false, + i => true + )); + + // assert + Assert.That(thrown.Message, Does.Contain("Invalid value of parameter CONNECTION_TIMEOUT")); + } + + [Test] + public void TestFailWhenPostValidationFails() + { + // arrange + var properties = SFSessionProperties.ParseConnectionString("account=test;user=test;password=test;connection_timeout=15", null); + var extractor = new SessionPropertiesWithDefaultValuesExtractor(properties, true); + var defaultValue = GetDefaultIntSessionProperty(SFSessionProperty.CONNECTION_TIMEOUT); + + // act + var thrown = Assert.Throws(() => + extractor.ExtractPropertyWithDefaultValue( + SFSessionProperty.CONNECTION_TIMEOUT, + int.Parse, + s => true, + i => i == defaultValue + )); + + // assert + Assert.That(thrown.Message, Does.Contain("Invalid value of parameter CONNECTION_TIMEOUT")); + } + + [Test] + public void TestFailWhenExtractFails() + { + // arrange + var properties = SFSessionProperties.ParseConnectionString("account=test;user=test;password=test;connection_timeout=15X", null); + var extractor = new SessionPropertiesWithDefaultValuesExtractor(properties, true); + + // act + var thrown = Assert.Throws(() => + extractor.ExtractPropertyWithDefaultValue( + SFSessionProperty.CONNECTION_TIMEOUT, + int.Parse, + s => true, + i => true + )); + + // assert + Assert.That(thrown.Message, Does.Contain("Invalid value of parameter CONNECTION_TIMEOUT")); + } + + private int GetDefaultIntSessionProperty(SFSessionProperty property) => + int.Parse(SFSessionProperty.CONNECTION_TIMEOUT.GetAttribute().defaultValue); + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Session/WaitingQueueTest.cs b/Snowflake.Data.Tests/UnitTests/Session/WaitingQueueTest.cs new file mode 100644 index 000000000..530eae133 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Session/WaitingQueueTest.cs @@ -0,0 +1,140 @@ +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using NUnit.Framework; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Tests.UnitTests.Session +{ + [TestFixture] + public class WaitingQueueTest + { + [Test] + public void TestWaitForTheResourceUntilTimeout() + { + // arrange + var queue = new WaitingQueue(); + var watch = new Stopwatch(); + + // act + watch.Start(); + var result = queue.Wait(50, CancellationToken.None); + watch.Stop(); + + // assert + Assert.IsFalse(result); + Assert.That(watch.ElapsedMilliseconds, Is.InRange(45, 1500)); // sometimes Wait takes a bit smaller amount of time than it should. Thus we expect it to be greater than 45, not just 50. + } + + [Test] + public void TestWaitForTheResourceUntilCancellation() + { + // arrange + var queue = new WaitingQueue(); + var cancellationSource = new CancellationTokenSource(50); + var watch = new Stopwatch(); + + // act + watch.Start(); + var result = queue.Wait(30000, cancellationSource.Token); + watch.Stop(); + + // assert + Assert.IsFalse(result); + Assert.That(watch.ElapsedMilliseconds, Is.InRange(45, 1500)); // sometimes Wait takes a bit smaller amount of time than it should. Thus we expect it to be greater than 45, not just 50. + } + + [Test] + [Retry(2)] + public void TestWaitUntilResourceAvailable() + { + // arrange + var queue = new WaitingQueue(); + var watch = new Stopwatch(); + Task.Run(() => + { + Thread.Sleep(50); + queue.OnResourceIncrease(); + }); + + // act + watch.Start(); + var result = queue.Wait(30000, CancellationToken.None); + watch.Stop(); + + // assert + Assert.IsTrue(result); + Assert.That(watch.ElapsedMilliseconds, Is.InRange(50, 1500)); + } + + [Test] + public void TestWaitingEnabled() + { + // arrange + var queue = new WaitingQueue(); + + // act + var isWaitingEnabled = queue.IsWaitingEnabled(); + + // assert + Assert.IsTrue(isWaitingEnabled); + } + + [Test] + public void TestNoOneIsWaiting() + { + // arrange + var queue = new WaitingQueue(); + + // act + var isAnyoneWaiting = queue.IsAnyoneWaiting(); + + // assert + Assert.IsFalse(isAnyoneWaiting); + } + + [Test] + public void TestSomeoneIsWaiting() + { + // arrange + var queue = new WaitingQueue(); + var syncThreadsSemaphore = new SemaphoreSlim(0, 1); + Task.Run(() => + { + syncThreadsSemaphore.Release(); + return queue.Wait(1000, CancellationToken.None); + }); + syncThreadsSemaphore.Wait(10000); // make sure scheduled thread execution has started + Thread.Sleep(50); + + // act + var isAnyoneWaiting = queue.IsAnyoneWaiting(); + + // assert + Assert.IsTrue(isAnyoneWaiting); + } + + [Test] + [Retry(2)] + public void TestReturnUnsuccessfulOnResetWhileWaiting() + { + // arrange + var queue = new WaitingQueue(); + var watch = new Stopwatch(); + Task.Run(() => + { + Thread.Sleep(50); + queue.Reset(); + }); + + // act + watch.Start(); + var result = queue.Wait(30000, CancellationToken.None); + watch.Stop(); + + // assert + Assert.IsFalse(result); + Assert.That(watch.ElapsedMilliseconds, Is.InRange(50, 1500)); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/SnowflakeDbConnectionPoolTest.cs b/Snowflake.Data.Tests/UnitTests/SnowflakeDbConnectionPoolTest.cs new file mode 100644 index 000000000..e2863f0b5 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/SnowflakeDbConnectionPoolTest.cs @@ -0,0 +1,25 @@ +using NUnit.Framework; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Tests.UnitTests +{ + public class SnowflakeDbConnectionPoolTest + { + private readonly string _connectionString1 = "database=D1;warehouse=W1;account=A1;user=U1;password=P1;role=R1;"; + private readonly string _connectionString2 = "database=D2;warehouse=W2;account=A2;user=U2;password=P2;role=R2;"; + + [Test] + public void TestRevertPoolToPreviousVersion() + { + // act + SnowflakeDbConnectionPool.SetOldConnectionPoolVersion(); + + // assert + var sessionPool1 = SnowflakeDbConnectionPool.GetPoolInternal(_connectionString1); + var sessionPool2 = SnowflakeDbConnectionPool.GetPoolInternal(_connectionString2); + Assert.AreEqual(ConnectionPoolType.SingleConnectionCache, SnowflakeDbConnectionPool.GetConnectionPoolVersion()); + Assert.AreEqual(sessionPool1, sessionPool2); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Tools/SecureStringHelperTest.cs b/Snowflake.Data.Tests/UnitTests/Tools/SecureStringHelperTest.cs new file mode 100644 index 000000000..52b10ed17 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Tools/SecureStringHelperTest.cs @@ -0,0 +1,23 @@ +using NUnit.Framework; +using Snowflake.Data.Core.Tools; + +namespace Snowflake.Data.Tests.UnitTests.Tools +{ + [TestFixture] + public class SecureStringHelperTest + { + [Test] + public void TestConvertPassword() + { + // arrange + var passwordText = "testPassword"; + + // act + var securePassword = SecureStringHelper.Encode(passwordText); + var decodedPassword = SecureStringHelper.Decode(securePassword); + + // assert + Assert.AreEqual(passwordText, decodedPassword); + } + } +} diff --git a/Snowflake.Data.Tests/UnitTests/Tools/TimeoutHelperTest.cs b/Snowflake.Data.Tests/UnitTests/Tools/TimeoutHelperTest.cs new file mode 100644 index 000000000..8bda40b73 --- /dev/null +++ b/Snowflake.Data.Tests/UnitTests/Tools/TimeoutHelperTest.cs @@ -0,0 +1,169 @@ +using System; +using System.Collections.Generic; +using NUnit.Framework; +using Snowflake.Data.Core.Tools; + +namespace Snowflake.Data.Tests.UnitTests.Tools +{ + [TestFixture] + public class TimeoutHelperTest + { + [Test] + [TestCaseSource(nameof(InfiniteTimeouts))] + public void TestInfinity(TimeSpan infiniteTimeout) + { + // act + var isInfinite = TimeoutHelper.IsInfinite(infiniteTimeout); + + // assert + Assert.IsTrue(isInfinite); + } + + [Test] + [TestCaseSource(nameof(FiniteTimeouts))] + public void TestFiniteValue(TimeSpan finiteTimeout) + { + // act + var isInfinite = TimeoutHelper.IsInfinite(finiteTimeout); + + // assert + Assert.IsFalse(isInfinite); + } + + [Test] + [TestCaseSource(nameof(ZeroLengthTimeouts))] + public void TestZeroLength(TimeSpan zeroTimeout) + { + // act + var isZeroLength = TimeoutHelper.IsZeroLength(zeroTimeout); + + // assert + Assert.IsTrue(isZeroLength); + } + + [Test] + [TestCaseSource(nameof(NonZeroLengthTimeouts))] + public void TestNonZeroLength(TimeSpan nonZeroTimeout) + { + // act + var isZeroLength = TimeoutHelper.IsZeroLength(nonZeroTimeout); + + // assert + Assert.IsFalse(isZeroLength); + } + + [Test] + [TestCase(1000, 1000)] + [TestCase(1000, 2000)] + public void TestInfiniteTimeoutDoesNotExpire(long startedAtMillis, long nowMillis) + { + // act + var isExpired = TimeoutHelper.IsExpired(startedAtMillis, nowMillis, TimeoutHelper.Infinity()); + + // assert + Assert.IsFalse(isExpired); + } + + [Test] + [TestCase(1000, 1000, 0, true)] + [TestCase(1000, 2000, 0, true)] + [TestCase(1000, 1100, 100, true)] + [TestCase(1000, 1099, 100, false)] + [TestCase(1000, 2000, 100, true)] + public void TestExpiredTimeout(long startedAtMillis, long nowMillis, long timeoutMillis, bool expectedIsExpired) + { + // arrange + var timeout = TimeSpan.FromMilliseconds(timeoutMillis); + + // act + var isExpired = TimeoutHelper.IsExpired(startedAtMillis, nowMillis, timeout); + + // assert + Assert.AreEqual(expectedIsExpired, isExpired); + } + + [Test] + public void TestInfiniteTimeoutNeverExpires() + { + // act + var isExpired = TimeoutHelper.IsExpired(1000, TimeoutHelper.Infinity()); + + // assert + Assert.IsFalse(isExpired); + } + + [Test] + [TestCase(0, 0, true)] + [TestCase(1000, 0, true)] + [TestCase(100, 100, true)] + [TestCase(99, 100, false)] + [TestCase(1000, 100, true)] + public void TestExpiredTimeoutByDuration(long durationMillis, long timeoutMillis, bool expectedIsExpired) + { + // arrange + var timeout = TimeSpan.FromMilliseconds(timeoutMillis); + + // act + var isExpired = TimeoutHelper.IsExpired(durationMillis, timeout); + + // assert + Assert.AreEqual(expectedIsExpired, isExpired); + } + + [Test] + public void TestFiniteTimeoutLeftFailsForInfiniteTimeout() + { + // act + var thrown = Assert.Throws(() => + TimeoutHelper.FiniteTimeoutLeftMillis(1000, 2000, TimeoutHelper.Infinity())); + + // assert + Assert.That(thrown.Message, Does.Contain("Infinite timeout cannot be used to determine milliseconds left")); + } + + + [Test] + [TestCase(1000, 1000, 0, 0)] + [TestCase(1000, 2000, 0, 0)] + [TestCase(1000, 1100, 100, 0)] + [TestCase(1000, 1095, 100, 5)] + public void TestFiniteTimeoutLeft(long startedAtMillis, long nowMillis, long timeoutMillis, long expectedMillisLeft) + { + // arrange + var timeout = TimeSpan.FromMilliseconds(timeoutMillis); + + // act + var millisLeft = TimeoutHelper.FiniteTimeoutLeftMillis(startedAtMillis, nowMillis, timeout); + + // assert + Assert.AreEqual(expectedMillisLeft, millisLeft); + } + + public static IEnumerable InfiniteTimeouts() + { + yield return TimeoutHelper.Infinity(); + yield return TimeSpan.FromMilliseconds(-1); + } + + public static IEnumerable FiniteTimeouts() + { + yield return TimeSpan.Zero; + yield return TimeSpan.FromMilliseconds(1); + yield return TimeSpan.FromSeconds(2); + } + + public static IEnumerable ZeroLengthTimeouts() + { + yield return TimeSpan.Zero; + yield return TimeSpan.FromMilliseconds(0); + yield return TimeSpan.FromSeconds(0); + } + + public static IEnumerable NonZeroLengthTimeouts() + { + yield return TimeoutHelper.Infinity(); + yield return TimeSpan.FromMilliseconds(3); + yield return TimeSpan.FromSeconds(5); + } + } +} diff --git a/Snowflake.Data.Tests/Util/Awaiter.cs b/Snowflake.Data.Tests/Util/Awaiter.cs new file mode 100644 index 000000000..7f48732c7 --- /dev/null +++ b/Snowflake.Data.Tests/Util/Awaiter.cs @@ -0,0 +1,26 @@ +using System; +using System.Threading; +using Snowflake.Data.Core.Tools; + +namespace Snowflake.Data.Tests.Util +{ + public static class Awaiter + { + private static readonly TimeSpan s_defaultDelay = TimeSpan.FromMilliseconds(200); + + public static void WaitUntilConditionOrTimeout(Func condition, TimeSpan timeout) + { + WaitUntilConditionOrTimeout(condition, timeout, s_defaultDelay); + } + + public static void WaitUntilConditionOrTimeout(Func condition, TimeSpan timeout, TimeSpan delay) + { + var startTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + var breakTime = TimeoutHelper.IsInfinite(timeout) ? long.MaxValue : startTime + timeout.TotalMilliseconds; + while (!condition() && DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() < breakTime) + { + Thread.Sleep(delay); + } + } + } +} diff --git a/Snowflake.Data.Tests/Util/AwaiterTest.cs b/Snowflake.Data.Tests/Util/AwaiterTest.cs new file mode 100644 index 000000000..a446f9c53 --- /dev/null +++ b/Snowflake.Data.Tests/Util/AwaiterTest.cs @@ -0,0 +1,54 @@ +using System; +using NUnit.Framework; + +namespace Snowflake.Data.Tests.Util +{ + [TestFixture] + public class AwaiterTest + { + private readonly TimeSpan _maxDurationRegardedAsImmediately = TimeSpan.FromSeconds(1); + + [Test] + public void TestReturnsImmediatelyWhenConditionIsMet() + { + // act + var millis = MillisecondsOfWaiting(() => true, TimeSpan.FromHours(1)); + + // assert + Assert.LessOrEqual(millis, _maxDurationRegardedAsImmediately.TotalMilliseconds); + } + + [Test] + public void TestReturnsImmediatelyOnZeroTimeout() + { + // act + var millis = MillisecondsOfWaiting(() => false, TimeSpan.FromMilliseconds(0)); + + // assert + Assert.LessOrEqual(millis, _maxDurationRegardedAsImmediately.TotalMilliseconds); + } + + [Test] + public void TestReturnsOnTimeout() + { + // arrange + var timeout = TimeSpan.FromSeconds(2); + + // act + var millis = MillisecondsOfWaiting(() => false, TimeSpan.FromSeconds(2)); + + // assert + Assert.GreaterOrEqual(millis, _maxDurationRegardedAsImmediately.TotalMilliseconds); + Assert.LessOrEqual(millis, timeout.TotalMilliseconds + _maxDurationRegardedAsImmediately.TotalMilliseconds); + } + + private long MillisecondsOfWaiting(Func condition, TimeSpan timeout) + { + var watch = new StopWatch(); + watch.Start(); + Awaiter.WaitUntilConditionOrTimeout(condition, timeout); + watch.Stop(); + return watch.ElapsedMilliseconds; + } + } +} diff --git a/Snowflake.Data.Tests/Util/ConnectingThreads.cs b/Snowflake.Data.Tests/Util/ConnectingThreads.cs new file mode 100644 index 000000000..beba4720c --- /dev/null +++ b/Snowflake.Data.Tests/Util/ConnectingThreads.cs @@ -0,0 +1,220 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Session; +using Snowflake.Data.Log; +using Snowflake.Data.Tests.Util; + +namespace Snowflake.Data.Tests.IntegrationTests +{ + class ConnectingThreads + { + private string _connectionString; + + private ConcurrentQueue _events = new ConcurrentQueue(); + + private List threads = new List(); + + public ConnectingThreads(string connectionString) + { + _connectionString = connectionString; + } + + public ConnectingThreads NewThread(string name, + long waitBeforeConnectMillis, + long waitAfterConnectMillis, + bool closeOnExit) + { + var thread = new ConnectingThread( + name, + _events, + _connectionString, + waitBeforeConnectMillis, + waitAfterConnectMillis, + closeOnExit).Build(); + threads.Add(thread); + return this; + } + + public ConnectingThreads StartAll() + { + threads.ForEach(thread => thread.Start()); + return this; + } + + public ConnectingThreads JoinAll() + { + threads.ForEach(thread => thread.Join()); + return this; + } + + public IEnumerable Events() => _events.ToArray().OfType(); + + public void Enqueue(ThreadEvent threadEvent) => _events.Enqueue(threadEvent); + + public static SFLogger Logger() => SFLoggerFactory.GetLogger(); // we have to choose a class from Snowflake.Data package otherwise it will be visible in GH build output + } + + class ConnectingThread + { + private static readonly SFLogger s_logger = ConnectingThreads.Logger(); + + private string _name; + + private ConcurrentQueue _events; + + private string _connectionString; + + private long _waitBeforeConnectMillis; + + private long _waitAfterConnectMillis; + + private bool _closeOnExit; + + internal const string NamePrefix = "thread_"; + + public ConnectingThread( + string name, + ConcurrentQueue events, + string connectionString, + long waitBeforeConnectMillis, + long waitAfterConnectMillis, + bool closeOnExit) + { + _name = name; + _events = events; + _connectionString = connectionString; + _waitBeforeConnectMillis = waitBeforeConnectMillis; + _waitAfterConnectMillis = waitAfterConnectMillis; + _closeOnExit = closeOnExit; + } + + public Thread Build() + { + var thread = new Thread(Execute); + thread.Name = NamePrefix + _name; + return thread; + } + + private void Execute() + { + var connection = new SnowflakeDbConnection(); + connection.ConnectionString = _connectionString; + s_logger.Debug($"Execution started, will sleep for {_waitBeforeConnectMillis} ms"); + Sleep(_waitBeforeConnectMillis); + var watch = new StopWatch(); + watch.Start(); + var connected = false; + try + { + s_logger.Debug("Opening the connection"); + connection.Open(); + connected = true; + } + catch (Exception exception) + { + watch.Stop(); + s_logger.Error($"Execution failed because of the error: {exception}"); + _events.Enqueue(ThreadEvent.EventConnectingFailed(_name, exception, watch.ElapsedMilliseconds)); + } + if (connected) + { + watch.Stop(); + _events.Enqueue(ThreadEvent.EventConnected(_name, watch.ElapsedMilliseconds)); + } + Sleep(_waitAfterConnectMillis); + if (_closeOnExit) + { + s_logger.Debug($"Closing the connection"); + connection.Close(); + } + } + + private void Sleep(long millis) + { + if (millis <= 0) + { + return; + } + Thread.Sleep((int) millis); + } + } + + class ThreadEvent + { + public string ThreadName { get; set; } + + public string EventName { get; set; } + + public Exception Error { get; set; } + + public long Timestamp { get; set; } + + public long Duration { get; set; } + + private const string Connected = "CONNECTED"; + private const string WaitingForSession = "WAITING_FOR_SESSION"; + private const string FailedToConnect = "FAILED_TO_CONNECT"; + + public ThreadEvent(string threadName, string eventName, Exception error, long duration) + { + ThreadName = threadName; + EventName = eventName; + Error = error; + Timestamp = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + Duration = duration; + } + + public bool IsConnectedEvent() => EventName.Equals(Connected); + + public bool IsWaitingEvent() => EventName.Equals(WaitingForSession); + + public static ThreadEvent EventConnected(string threadName, long duration) => + new ThreadEvent(threadName, Connected, null, duration); + + public static ThreadEvent EventConnectingFailed(string threadName, Exception exception, long duration) => + new ThreadEvent(threadName, FailedToConnect, exception, duration); + + public static ThreadEvent EventWaitingForSessionStarted(string threadName) => + new ThreadEvent(threadName, WaitingForSession, null, 0); + } + + class SessionPoolThreadEventHandler: SessionPoolEventHandler + { + private static readonly SFLogger s_logger = ConnectingThreads.Logger(); + private readonly ConnectingThreads _connectingThreads; + + public SessionPoolThreadEventHandler(ConnectingThreads connectingThreads) + { + _connectingThreads = connectingThreads; + } + + public override void OnWaitingForSessionStarted(SessionPool sessionPool) + { + var threadName = Thread.CurrentThread.Name; + var realThreadName = threadName.StartsWith(ConnectingThread.NamePrefix) + ? threadName.Substring(ConnectingThread.NamePrefix.Length) : threadName; + s_logger.Warn($"Thread is going to wait for an available session. Current time in milliseconds: {DateTimeOffset.UtcNow.ToUnixTimeMilliseconds()}"); + var waitingStartedEvent = ThreadEvent.EventWaitingForSessionStarted(realThreadName); + _connectingThreads.Enqueue(waitingStartedEvent); + } + + public override void OnWaitingForSessionStarted(SessionPool sessionPool, long millisLeft) + { + s_logger.Warn($"Thread is going to wait with milliseconds timeout of {millisLeft}. Current time in milliseconds: {DateTimeOffset.UtcNow.ToUnixTimeMilliseconds()}"); + } + + public override void OnWaitingForSessionSuccessful(SessionPool sessionPool) + { + s_logger.Warn($"Thread has been woken with a session granted. Current time in milliseconds: {DateTimeOffset.UtcNow.ToUnixTimeMilliseconds()}"); + } + + public override void OnSessionProvided(SessionPool sessionPool) + { + s_logger.Warn($"Thread has got a session. Current time in milliseconds: {DateTimeOffset.UtcNow.ToUnixTimeMilliseconds()}"); + } + } +} diff --git a/Snowflake.Data.Tests/Util/PoolConfig.cs b/Snowflake.Data.Tests/Util/PoolConfig.cs index 4856da243..4291c2f81 100644 --- a/Snowflake.Data.Tests/Util/PoolConfig.cs +++ b/Snowflake.Data.Tests/Util/PoolConfig.cs @@ -3,6 +3,8 @@ */ using Snowflake.Data.Client; +using Snowflake.Data.Core; +using Snowflake.Data.Core.Session; namespace Snowflake.Data.Tests.Util { @@ -11,16 +13,21 @@ class PoolConfig private readonly bool _pooling; private readonly long _timeout; private readonly int _maxPoolSize; + private readonly ConnectionPoolType _connectionPoolType; public PoolConfig() { - _maxPoolSize = SnowflakeDbConnectionPool.GetMaxPoolSize(); - _timeout = SnowflakeDbConnectionPool.GetTimeout(); - _pooling = SnowflakeDbConnectionPool.GetPooling(); + _maxPoolSize = SFSessionHttpClientProperties.DefaultMaxPoolSize; + _timeout = (long) SFSessionHttpClientProperties.DefaultExpirationTimeout.TotalSeconds; + _pooling = SFSessionHttpClientProperties.DefaultPoolingEnabled; + _connectionPoolType = SnowflakeDbConnectionPool.DefaultConnectionPoolType; } public void Reset() { + SnowflakeDbConnectionPool.SetConnectionPoolVersion(_connectionPoolType); + if (_connectionPoolType == ConnectionPoolType.MultipleConnectionPool) + return; // for multiple connection pool setting parameters for all the pools doesn't work by design SnowflakeDbConnectionPool.SetMaxPoolSize(_maxPoolSize); SnowflakeDbConnectionPool.SetTimeout(_timeout); SnowflakeDbConnectionPool.SetPooling(_pooling); diff --git a/Snowflake.Data.Tests/Util/SnowflakeDbExceptionAssert.cs b/Snowflake.Data.Tests/Util/SnowflakeDbExceptionAssert.cs index 63432da31..881cba861 100644 --- a/Snowflake.Data.Tests/Util/SnowflakeDbExceptionAssert.cs +++ b/Snowflake.Data.Tests/Util/SnowflakeDbExceptionAssert.cs @@ -13,16 +13,16 @@ public static class SnowflakeDbExceptionAssert { public static void HasErrorCode(SnowflakeDbException exception, SFError sfError) { - Assert.AreEqual(exception.ErrorCode, sfError.GetAttribute().errorCode); + Assert.AreEqual(sfError.GetAttribute().errorCode, exception.ErrorCode); } - + public static void HasErrorCode(Exception exception, SFError sfError) { Assert.NotNull(exception); switch (exception) { case SnowflakeDbException snowflakeDbException: - Assert.AreEqual(snowflakeDbException.ErrorCode, sfError.GetAttribute().errorCode); + Assert.AreEqual(sfError.GetAttribute().errorCode, snowflakeDbException.ErrorCode); break; default: Assert.Fail(exception.GetType() + " type is not " + typeof(SnowflakeDbException)); @@ -45,7 +45,7 @@ public static void HasHttpErrorCodeInExceptionChain(Exception exception, HttpSta return he.Message.Contains(((int)expected).ToString()); #else return he.StatusCode == expected; -#endif +#endif default: return false; } diff --git a/Snowflake.Data.Tests/Util/StopWatch.cs b/Snowflake.Data.Tests/Util/StopWatch.cs new file mode 100644 index 000000000..c9ef2e201 --- /dev/null +++ b/Snowflake.Data.Tests/Util/StopWatch.cs @@ -0,0 +1,31 @@ +using System; + +namespace Snowflake.Data.Tests.Util +{ + /** + * StopWatch class is a measure time based on DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(). + * The class System.Diagnostics.Stopwatch uses ticks of processor and calculates the time based on tick frequency. + * If the code which we are testing uses DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() it is better to use this StopWatch class + * because the tests are less flaky in GH builds. + */ + internal class StopWatch + { + private long _startMillis; + private long _stopMillis; + + public long ElapsedMilliseconds + { + get => _stopMillis - _startMillis; + } + + public void Start() + { + _startMillis = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + } + + public void Stop() + { + _stopMillis = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + } + } +} diff --git a/Snowflake.Data/Client/SnowflakeDbConnection.cs b/Snowflake.Data/Client/SnowflakeDbConnection.cs index cce4974fc..fc0ba199d 100755 --- a/Snowflake.Data/Client/SnowflakeDbConnection.cs +++ b/Snowflake.Data/Client/SnowflakeDbConnection.cs @@ -18,12 +18,12 @@ public class SnowflakeDbConnection : DbConnection { private SFLogger logger = SFLoggerFactory.GetLogger(); - internal SFSession SfSession { get; set; } + internal SFSession SfSession { get; set; } internal ConnectionState _connectionState; protected override DbProviderFactory DbProviderFactory => new SnowflakeDbFactory(); - + internal int _connectionTimeout; private bool _disposed = false; @@ -47,7 +47,7 @@ protected enum TransactionRollbackStatus public SnowflakeDbConnection() { _connectionState = ConnectionState.Closed; - _connectionTimeout = + _connectionTimeout = int.Parse(SFSessionProperty.CONNECTION_TIMEOUT.GetAttribute(). defaultValue); _isArrayBindStageCreated = false; @@ -84,12 +84,12 @@ private bool IsNonClosedWithSession() public override int ConnectionTimeout => this._connectionTimeout; /// - /// If the connection to the database is closed, the DataSource returns whatever is contained - /// in the ConnectionString for the DataSource keyword. If the connection is open and the - /// ConnectionString data source keyword's value starts with "|datadirectory|", the property - /// returns whatever is contained in the ConnectionString for the DataSource keyword only. If - /// the connection to the database is open, the property returns what the native provider - /// returns for the DBPROP_INIT_DATASOURCE, and if that is empty, the native provider's + /// If the connection to the database is closed, the DataSource returns whatever is contained + /// in the ConnectionString for the DataSource keyword. If the connection is open and the + /// ConnectionString data source keyword's value starts with "|datadirectory|", the property + /// returns whatever is contained in the ConnectionString for the DataSource keyword only. If + /// the connection to the database is open, the property returns what the native provider + /// returns for the DBPROP_INIT_DATASOURCE, and if that is empty, the native provider's /// DBPROP_DATASOURCENAME is returned. /// Note: not yet implemented /// @@ -105,9 +105,37 @@ public override string DataSource public override ConnectionState State => _connectionState; internal SnowflakeDbTransaction ExplicitTransaction { get; set; } // tracks only explicit transaction operations - + + public void PreventPooling() + { + if (SfSession == null) + { + throw new Exception("Session not yet created for this connection. Unable to prevent the session from pooling"); + } + SfSession.SetPooling(false); + logger.Debug($"Session {SfSession.sessionId} marked not to be pooled any more"); + } + internal bool HasActiveExplicitTransaction() => ExplicitTransaction != null && ExplicitTransaction.IsActive; + private bool TryToReturnSessionToPool() + { + var pooling = SnowflakeDbConnectionPool.GetPooling() && SfSession.GetPooling(); + var transactionRollbackStatus = pooling ? TerminateTransactionForDirtyConnectionReturningToPool() : TransactionRollbackStatus.Undefined; + var canReuseSession = CanReuseSession(transactionRollbackStatus); + if (!canReuseSession) + { + SnowflakeDbConnectionPool.ReleaseBusySession(SfSession); + return false; + } + var sessionReturnedToPool = SnowflakeDbConnectionPool.AddSession(SfSession); + if (sessionReturnedToPool) + { + logger.Debug($"Session pooled: {SfSession.sessionId}"); + } + return sessionReturnedToPool; + } + private TransactionRollbackStatus TerminateTransactionForDirtyConnectionReturningToPool() { if (!HasActiveExplicitTransaction()) @@ -122,12 +150,12 @@ private TransactionRollbackStatus TerminateTransactionForDirtyConnectionReturnin // error to indicate a problem within application code that a connection was closed while still having a pending transaction logger.Error("Closing dirty connection: rollback transaction in session " + SfSession.sessionId + " succeeded."); ExplicitTransaction = null; - return TransactionRollbackStatus.Success; + return TransactionRollbackStatus.Success; } } - catch (SnowflakeDbException exception) + catch (Exception exception) { - // error to indicate a problem with rollback of an active transaction and inability to return dirty connection to the pool + // error to indicate a problem with rollback of an active transaction and inability to return dirty connection to the pool logger.Error("Closing dirty connection: rollback transaction in session: " + SfSession.sessionId + " failed, exception: " + exception.Message); return TransactionRollbackStatus.Failure; // connection won't be pooled } @@ -151,19 +179,13 @@ public override void Close() logger.Debug("Close Connection."); if (IsNonClosedWithSession()) { - var transactionRollbackStatus = SnowflakeDbConnectionPool.GetPooling() ? TerminateTransactionForDirtyConnectionReturningToPool() : TransactionRollbackStatus.Undefined; - - if (CanReuseSession(transactionRollbackStatus) && SnowflakeDbConnectionPool.AddSession(SfSession)) - { - logger.Debug($"Session pooled: {SfSession.sessionId}"); - } - else + var returnedToPool = TryToReturnSessionToPool(); + if (!returnedToPool) { SfSession.close(); } SfSession = null; } - _connectionState = ConnectionState.Closed; } @@ -189,11 +211,9 @@ public virtual async Task CloseAsync(CancellationToken cancellationToken) { if (IsNonClosedWithSession()) { - var transactionRollbackStatus = SnowflakeDbConnectionPool.GetPooling() ? TerminateTransactionForDirtyConnectionReturningToPool() : TransactionRollbackStatus.Undefined; - - if (CanReuseSession(transactionRollbackStatus) && SnowflakeDbConnectionPool.AddSession(SfSession)) + var returnedToPool = TryToReturnSessionToPool(); + if (returnedToPool) { - logger.Debug($"Session pooled: {SfSession.sessionId}"); _connectionState = ConnectionState.Closed; taskCompletionSource.SetResult(null); } @@ -234,10 +254,10 @@ await SfSession.CloseAsync(cancellationToken).ContinueWith( protected virtual bool CanReuseSession(TransactionRollbackStatus transactionRollbackStatus) { - return SnowflakeDbConnectionPool.GetPooling() && + return SnowflakeDbConnectionPool.GetPooling() && transactionRollbackStatus == TransactionRollbackStatus.Success; } - + public override void Open() { logger.Debug("Open Connection."); @@ -381,7 +401,7 @@ protected override void Dispose(bool disposing) SfSession = null; _connectionState = ConnectionState.Closed; } - + _disposed = true; } diff --git a/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs b/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs index f643fa5c9..617c07ebd 100644 --- a/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs +++ b/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs @@ -1,7 +1,8 @@ /* - * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ +using System; using System.Security; using System.Threading; using System.Threading.Tasks; @@ -14,72 +15,145 @@ namespace Snowflake.Data.Client public class SnowflakeDbConnectionPool { private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); - private static readonly IConnectionManager s_connectionManager = new ConnectionCacheManager(); - + private static readonly Object s_connectionManagerInstanceLock = new Object(); + private static IConnectionManager s_connectionManager; + internal const ConnectionPoolType DefaultConnectionPoolType = ConnectionPoolType.MultipleConnectionPool; + + private static IConnectionManager ConnectionManager + { + get + { + if (s_connectionManager != null) + return s_connectionManager; + SetConnectionPoolVersion(DefaultConnectionPoolType); + return s_connectionManager; + } + } + internal static SFSession GetSession(string connectionString, SecureString password) { - s_logger.Debug("SnowflakeDbConnectionPool::GetSession"); - return s_connectionManager.GetSession(connectionString, password); + s_logger.Debug($"SnowflakeDbConnectionPool::GetSession"); + return ConnectionManager.GetSession(connectionString, password); } - + internal static Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken) { - s_logger.Debug("SnowflakeDbConnectionPool::GetSessionAsync"); - return s_connectionManager.GetSessionAsync(connectionString, password, cancellationToken); + s_logger.Debug($"SnowflakeDbConnectionPool::GetSessionAsync"); + return ConnectionManager.GetSessionAsync(connectionString, password, cancellationToken); + } + + public static SnowflakeDbSessionPool GetPool(string connectionString, SecureString password) + { + s_logger.Debug($"SnowflakeDbConnectionPool::GetPool"); + return new SnowflakeDbSessionPool(ConnectionManager.GetPool(connectionString, password)); + } + + public static SnowflakeDbSessionPool GetPool(string connectionString) + { + s_logger.Debug($"SnowflakeDbConnectionPool::GetPool"); + return new SnowflakeDbSessionPool(ConnectionManager.GetPool(connectionString)); + } + + internal static SessionPool GetPoolInternal(string connectionString) + { + s_logger.Debug($"SnowflakeDbConnectionPool::GetPoolInternal"); + return ConnectionManager.GetPool(connectionString); } - + internal static bool AddSession(SFSession session) { s_logger.Debug("SnowflakeDbConnectionPool::AddSession"); - return s_connectionManager.AddSession(session); + return ConnectionManager.AddSession(session); + } + + internal static void ReleaseBusySession(SFSession session) + { + s_logger.Debug("SnowflakeDbConnectionPool::ReleaseBusySession"); + ConnectionManager.ReleaseBusySession(session); } public static void ClearAllPools() { s_logger.Debug("SnowflakeDbConnectionPool::ClearAllPools"); - s_connectionManager.ClearAllPools(); + ConnectionManager.ClearAllPools(); } public static void SetMaxPoolSize(int maxPoolSize) { s_logger.Debug("SnowflakeDbConnectionPool::SetMaxPoolSize"); - s_connectionManager.SetMaxPoolSize(maxPoolSize); + ConnectionManager.SetMaxPoolSize(maxPoolSize); } public static int GetMaxPoolSize() { s_logger.Debug("SnowflakeDbConnectionPool::GetMaxPoolSize"); - return s_connectionManager.GetMaxPoolSize(); + return ConnectionManager.GetMaxPoolSize(); } public static void SetTimeout(long connectionTimeout) { s_logger.Debug("SnowflakeDbConnectionPool::SetTimeout"); - s_connectionManager.SetTimeout(connectionTimeout); + ConnectionManager.SetTimeout(connectionTimeout); } - + public static long GetTimeout() { s_logger.Debug("SnowflakeDbConnectionPool::GetTimeout"); - return s_connectionManager.GetTimeout(); + return ConnectionManager.GetTimeout(); } public static int GetCurrentPoolSize() { s_logger.Debug("SnowflakeDbConnectionPool::GetCurrentPoolSize"); - return s_connectionManager.GetCurrentPoolSize(); + return ConnectionManager.GetCurrentPoolSize(); } public static bool SetPooling(bool isEnable) { s_logger.Debug("SnowflakeDbConnectionPool::SetPooling"); - return s_connectionManager.SetPooling(isEnable); + return ConnectionManager.SetPooling(isEnable); } public static bool GetPooling() { s_logger.Debug("SnowflakeDbConnectionPool::GetPooling"); - return s_connectionManager.GetPooling(); + return ConnectionManager.GetPooling(); + } + + public static void SetOldConnectionPoolVersion() + { + SetConnectionPoolVersion(ConnectionPoolType.SingleConnectionCache); + } + + internal static void SetConnectionPoolVersion(ConnectionPoolType requestedPoolType) + { + lock (s_connectionManagerInstanceLock) + { + s_connectionManager?.ClearAllPools(); + if (requestedPoolType == ConnectionPoolType.MultipleConnectionPool) + { + s_connectionManager = new ConnectionPoolManager(); + s_logger.Info("SnowflakeDbConnectionPool - multiple connection pools enabled"); + } + if (requestedPoolType == ConnectionPoolType.SingleConnectionCache) + { + s_connectionManager = new ConnectionCacheManager(); + s_logger.Warn("SnowflakeDbConnectionPool - connection cache enabled"); + } + } + } + + internal static ConnectionPoolType GetConnectionPoolVersion() + { + if (ConnectionManager != null) + { + switch (ConnectionManager) + { + case ConnectionCacheManager _: return ConnectionPoolType.SingleConnectionCache; + case ConnectionPoolManager _: return ConnectionPoolType.MultipleConnectionPool; + } + } + return DefaultConnectionPoolType; } } } diff --git a/Snowflake.Data/Client/SnowflakeDbSessionPool.cs b/Snowflake.Data/Client/SnowflakeDbSessionPool.cs new file mode 100644 index 000000000..2bd3caecc --- /dev/null +++ b/Snowflake.Data/Client/SnowflakeDbSessionPool.cs @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using Snowflake.Data.Core.Session; + +namespace Snowflake.Data.Client +{ + public class SnowflakeDbSessionPool + { + private readonly SessionPool _sessionPool; + + internal SnowflakeDbSessionPool(SessionPool sessionPool) + => _sessionPool = sessionPool ?? throw new NullReferenceException("SessionPool not provided!"); + + public bool GetPooling() => _sessionPool.GetPooling(); + + public int GetMinPoolSize() => _sessionPool.GetMinPoolSize(); + + public int GetMaxPoolSize() => _sessionPool.GetMaxPoolSize(); + + public int GetCurrentPoolSize() => _sessionPool.GetCurrentPoolSize(); + + public long GetExpirationTimeout() => _sessionPool.GetTimeout(); + + public long GetConnectionTimeout() => _sessionPool.GetConnectionTimeout(); + + public long GetWaitForIdleSessionTimeout() => _sessionPool.GetWaitForIdleSessionTimeout(); + + public void ClearPool() => _sessionPool.ClearSessions(); + + public ChangedSessionBehavior GetChangedSession() => _sessionPool.GetChangedSession(); + } +} diff --git a/Snowflake.Data/Core/ArrowResultSet.cs b/Snowflake.Data/Core/ArrowResultSet.cs index 31e0eccca..56a636c4e 100755 --- a/Snowflake.Data/Core/ArrowResultSet.cs +++ b/Snowflake.Data/Core/ArrowResultSet.cs @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ using System; @@ -398,7 +398,7 @@ internal override string GetString(int ordinal) private void UpdateSessionStatus(QueryExecResponseData responseData) { SFSession session = this.sfStatement.SfSession; - session.UpdateDatabaseAndSchema(responseData.finalDatabaseName, responseData.finalSchemaName); + session.UpdateSessionProperties(responseData); session.UpdateSessionParameterMap(responseData.parameters); } diff --git a/Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs b/Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs index d7a7a29d1..a26d542d3 100644 --- a/Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/BasicAuthenticator.cs @@ -10,7 +10,7 @@ namespace Snowflake.Data.Core.Authenticator { class BasicAuthenticator : BaseAuthenticator, IAuthenticator { - public static readonly string AUTH_NAME = "snowflake"; + public const string AUTH_NAME = "snowflake"; private static readonly SFLogger logger = SFLoggerFactory.GetLogger(); internal BasicAuthenticator(SFSession session) : base(session, AUTH_NAME) diff --git a/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs b/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs index d6ead6818..3b882a05b 100644 --- a/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/ExternalBrowserAuthenticator.cs @@ -21,7 +21,7 @@ namespace Snowflake.Data.Core.Authenticator /// class ExternalBrowserAuthenticator : BaseAuthenticator, IAuthenticator { - public static readonly string AUTH_NAME = "externalbrowser"; + public const string AUTH_NAME = "externalbrowser"; private static readonly SFLogger logger = SFLoggerFactory.GetLogger(); private static readonly string TOKEN_REQUEST_PREFIX = "?token="; private static readonly byte[] SUCCESS_RESPONSE = System.Text.Encoding.UTF8.GetBytes( @@ -87,7 +87,7 @@ await session.restRequester.PostAsync( logger.Warn("Browser response timeout"); throw new SnowflakeDbException(SFError.BROWSER_RESPONSE_TIMEOUT, timeoutInSec); } - + httpListener.Stop(); } @@ -134,7 +134,7 @@ void IAuthenticator.Authenticate() logger.Warn("Browser response timeout"); throw new SnowflakeDbException(SFError.BROWSER_RESPONSE_TIMEOUT, timeoutInSec); } - + httpListener.Stop(); } @@ -150,7 +150,7 @@ private void GetContextCallback(IAsyncResult result) { HttpListenerContext context = httpListener.EndGetContext(result); HttpListenerRequest request = context.Request; - + _samlResponseToken = ValidateAndExtractToken(request); HttpListenerResponse response = context.Response; try diff --git a/Snowflake.Data/Core/Authenticator/IAuthenticator.cs b/Snowflake.Data/Core/Authenticator/IAuthenticator.cs index 150551f91..7a41a8335 100644 --- a/Snowflake.Data/Core/Authenticator/IAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/IAuthenticator.cs @@ -17,7 +17,7 @@ namespace Snowflake.Data.Core.Authenticator internal interface IAuthenticator { /// - /// Process the authentication asynchronouly + /// Process the authentication asynchronously /// /// /// @@ -49,19 +49,19 @@ internal abstract class BaseAuthenticator SFLoggerFactory.GetLogger(); // The name of the authenticator. - protected string authName; + private string authName; // The session which created this authenticator. protected SFSession session; // The client environment properties - protected LoginRequestClientEnv ClientEnv = SFEnvironment.ClientEnv; + private LoginRequestClientEnv ClientEnv = SFEnvironment.ClientEnv; /// /// The abstract base for all authenticators. /// /// The session which created the authenticator. - public BaseAuthenticator(SFSession session, string authName) + protected BaseAuthenticator(SFSession session, string authName) { this.session = session; this.authName = authName; @@ -104,7 +104,7 @@ protected void Login() /// /// Builds a simple login request. Each authenticator will fill the Data part with their /// specialized information. The common Data attributes are already filled (clientAppId, - /// ClienAppVersion...). + /// ClientAppVersion...). /// /// A login request to send to the server. private SFRestRequest BuildLoginRequest() @@ -129,10 +129,10 @@ private SFRestRequest BuildLoginRequest() } } - /// - /// Authenticator Factory to build authenticators - /// - internal class AuthenticatorFactory + /// + /// Authenticator Factory to build authenticators + /// + internal class AuthenticatorFactory { private static readonly SFLogger logger = SFLoggerFactory.GetLogger(); /// @@ -155,8 +155,8 @@ internal static IAuthenticator GetAuthenticator(SFSession session) else if (type.Equals(KeyPairAuthenticator.AUTH_NAME, StringComparison.InvariantCultureIgnoreCase)) { // Get private key path or private key from connection settings - if (!session.properties.TryGetValue(SFSessionProperty.PRIVATE_KEY_FILE, out var pkPath) && - !session.properties.TryGetValue(SFSessionProperty.PRIVATE_KEY, out var pkContent)) + if ((!session.properties.TryGetValue(SFSessionProperty.PRIVATE_KEY_FILE, out var pkPath) || string.IsNullOrEmpty(pkPath)) && + (!session.properties.TryGetValue(SFSessionProperty.PRIVATE_KEY, out var pkContent) || string.IsNullOrEmpty(pkContent))) { // There is no PRIVATE_KEY_FILE defined, can't authenticate with key-pair string invalidStringDetail = @@ -192,12 +192,8 @@ internal static IAuthenticator GetAuthenticator(SFSession session) { return new OktaAuthenticator(session, type); } - - var e = new SnowflakeDbException(SFError.UNKNOWN_AUTHENTICATOR, type); - - logger.Error("Unknown authenticator", e); - - throw e; + logger.Error($"Unknown authenticator {type}"); + throw new SnowflakeDbException(SFError.UNKNOWN_AUTHENTICATOR, type); } } } diff --git a/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs b/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs index fcfb70695..e0c28d4ef 100644 --- a/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/KeyPairAuthenticator.cs @@ -28,7 +28,7 @@ namespace Snowflake.Data.Core.Authenticator class KeyPairAuthenticator : BaseAuthenticator, IAuthenticator { // The authenticator setting value to use to authenticate using key pair authentication. - public static readonly string AUTH_NAME = "snowflake_jwt"; + public const string AUTH_NAME = "snowflake_jwt"; // The logger. private static readonly SFLogger logger = @@ -85,9 +85,9 @@ private string GenerateJwtToken() { logger.Info("Key-pair Authentication"); - bool hasPkPath = + bool hasPkPath = session.properties.TryGetValue(SFSessionProperty.PRIVATE_KEY_FILE, out var pkPath); - bool hasPkContent = + bool hasPkContent = session.properties.TryGetValue(SFSessionProperty.PRIVATE_KEY, out var pkContent); session.properties.TryGetValue(SFSessionProperty.PRIVATE_KEY_PWD, out var pkPwd); @@ -152,31 +152,31 @@ private string GenerateJwtToken() byte[] sha256Hash = SHA256Encoder.ComputeHash(publicKeyEncoded); publicKeyFingerPrint = "SHA256:" + Convert.ToBase64String(sha256Hash); } - - // Generating the token + + // Generating the token var now = DateTime.UtcNow; System.DateTime dtDateTime = new DateTime(1970, 1, 1, 0, 0, 0, 0, System.DateTimeKind.Utc); long secondsSinceEpoch = (long)((now - dtDateTime).TotalSeconds); - /* + /* * Payload content - * iss : $accountName.$userName.$pulicKeyFingerprint + * iss : $accountName.$userName.$publicKeyFingerprint * sub : $accountName.$userName * iat : $now * exp : $now + LIFETIME - * + * * Note : Lifetime = 120sec for Python impl, 60sec for Jdbc and Odbc */ - String accountUser = - session.properties[SFSessionProperty.ACCOUNT].ToUpper() + - "." + + String accountUser = + session.properties[SFSessionProperty.ACCOUNT].ToUpper() + + "." + session.properties[SFSessionProperty.USER].ToUpper(); String issuer = accountUser + "." + publicKeyFingerPrint; var claims = new[] { new Claim( - JwtRegisteredClaimNames.Iat, - secondsSinceEpoch.ToString(), + JwtRegisteredClaimNames.Iat, + secondsSinceEpoch.ToString(), System.Security.Claims.ClaimValueTypes.Integer64), new Claim(JwtRegisteredClaimNames.Sub, accountUser), }; diff --git a/Snowflake.Data/Core/Authenticator/OAuthAuthenticator.cs b/Snowflake.Data/Core/Authenticator/OAuthAuthenticator.cs index 5e8f4a310..f36d0353e 100644 --- a/Snowflake.Data/Core/Authenticator/OAuthAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/OAuthAuthenticator.cs @@ -14,7 +14,7 @@ namespace Snowflake.Data.Core.Authenticator class OAuthAuthenticator : BaseAuthenticator, IAuthenticator { // The authenticator setting value to use to authenticate using key pair authentication. - public static readonly string AUTH_NAME = "oauth"; + public const string AUTH_NAME = "oauth"; // The logger. private static readonly SFLogger logger = diff --git a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs index cca377512..1780ccffc 100644 --- a/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs +++ b/Snowflake.Data/Core/Authenticator/OktaAuthenticator.cs @@ -13,6 +13,7 @@ using System.Text; using System.Web; using System.Linq; +using Snowflake.Data.Core.Tools; namespace Snowflake.Data.Core.Authenticator { @@ -21,6 +22,7 @@ namespace Snowflake.Data.Core.Authenticator /// class OktaAuthenticator : BaseAuthenticator, IAuthenticator { + public const string AUTH_NAME = "okta"; private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); internal const string RetryCountHeader = "RetryCount"; @@ -38,7 +40,7 @@ class OktaAuthenticator : BaseAuthenticator, IAuthenticator /// /// /// - internal OktaAuthenticator(SFSession session, string oktaUriString) : + internal OktaAuthenticator(SFSession session, string oktaUriString) : base(session, oktaUriString) { _oktaUrl = new Uri(oktaUriString); @@ -80,11 +82,11 @@ async Task IAuthenticator.AuthenticateAsync(CancellationToken cancellationToken) s_logger.Debug("step 4: Get SAML response from SSO"); var samlRestRequest = BuildSamlRestRequest(ssoUrl, onetimeToken); samlRawResponse = await session.restRequester.GetAsync(samlRestRequest, cancellationToken).ConfigureAwait(false); -#if NETFRAMEWORK +#if NETFRAMEWORK _rawSamlTokenHtmlString = await samlRawResponse.Content.ReadAsStringAsync().ConfigureAwait(false); #else _rawSamlTokenHtmlString = await samlRawResponse.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); -#endif +#endif s_logger.Debug("step 5: Verify postback URL in SAML response"); VerifyPostbackUrl(); @@ -270,10 +272,11 @@ private void VerifyPostbackUrl() throw e; } } - + private bool RetryLimitIsNotReached(int retryCount, int timeoutElapsed) { - return retryCount < session._maxRetryCount && timeoutElapsed < session._maxRetryTimeout; + var elapsedMillis = timeoutElapsed * 1000; + return retryCount < session._maxRetryCount && !TimeoutHelper.IsExpired(elapsedMillis, session._maxRetryTimeout); } private bool IsPostbackUrlNotFound(Exception ex) @@ -293,10 +296,10 @@ private void ThrowRetryLimitException(int retryCount, int timeoutElapsed, Except { errorMessage = $"The retry count has reached its limit of {session._maxRetryCount}"; } - if (timeoutElapsed >= session._maxRetryTimeout) + if (TimeoutHelper.IsExpired(timeoutElapsed * 1000, session._maxRetryTimeout)) { errorMessage += string.IsNullOrEmpty(errorMessage) ? "The" : " and the"; - errorMessage += $" timeout elapsed has reached its limit of {session._maxRetryTimeout}"; + errorMessage += $" timeout elapsed has reached its limit of {session._maxRetryTimeout.TotalSeconds}"; } errorMessage += " while trying to authenticate through Okta"; @@ -307,11 +310,11 @@ private void ThrowRetryLimitException(int retryCount, int timeoutElapsed, Except } internal class IdpTokenRestRequest : BaseRestRequest, IRestRequest - { + { private static readonly MediaTypeWithQualityHeaderValue s_jsonHeader = new MediaTypeWithQualityHeaderValue("application/json"); internal IdpTokenRequest JsonBody { get; set; } - + HttpRequestMessage IRestRequest.ToRequestMessage(HttpMethod method) { HttpRequestMessage message = newMessage(method, Url); diff --git a/Snowflake.Data/Core/ErrorMessages.resx b/Snowflake.Data/Core/ErrorMessages.resx index b7db8c58c..c8e65e465 100755 --- a/Snowflake.Data/Core/ErrorMessages.resx +++ b/Snowflake.Data/Core/ErrorMessages.resx @@ -1,17 +1,17 @@  - @@ -146,7 +146,7 @@ Invalid parameter value {0} for {1} - + Failed to convert data {0} from type {1} to type {2}. @@ -183,10 +183,13 @@ Browser response timed out after {0} seconds. + + Cannot return result set as a scalar value: {0} + IO operation failed. Error: {0} - + Executing command on a non-opened connection. - \ No newline at end of file + diff --git a/Snowflake.Data/Core/SFError.cs b/Snowflake.Data/Core/SFError.cs index ee59e9241..e4e03618f 100755 --- a/Snowflake.Data/Core/SFError.cs +++ b/Snowflake.Data/Core/SFError.cs @@ -16,7 +16,7 @@ public enum SFError [SFErrorAttr(errorCode = 270003)] INVALID_DATA_CONVERSION, - + [SFErrorAttr(errorCode = 270004)] STATEMENT_ALREADY_RUNNING_QUERY, @@ -81,16 +81,19 @@ public enum SFError [SFErrorAttr(errorCode = 270058)] IO_ERROR_ON_GETPUT_COMMAND, - + [SFErrorAttr(errorCode = 270059)] - EXECUTE_COMMAND_ON_CLOSED_CONNECTION + EXECUTE_COMMAND_ON_CLOSED_CONNECTION, + + [SFErrorAttr(errorCode = 270060)] + INCONSISTENT_RESULT_ERROR } class SFErrorAttr : Attribute { public int errorCode { get; set; } } - + public class SqlState { public const string WARNING = "01000"; diff --git a/Snowflake.Data/Core/SFMultiStatementsResultSet.cs b/Snowflake.Data/Core/SFMultiStatementsResultSet.cs index c811deb8b..18eb4f650 100644 --- a/Snowflake.Data/Core/SFMultiStatementsResultSet.cs +++ b/Snowflake.Data/Core/SFMultiStatementsResultSet.cs @@ -112,7 +112,7 @@ internal override bool Rewind() private void updateSessionStatus(QueryExecResponseData responseData) { SFSession session = this.sfStatement.SfSession; - session.UpdateDatabaseAndSchema(responseData.finalDatabaseName, responseData.finalSchemaName); + session.UpdateSessionProperties(responseData); session.UpdateSessionParameterMap(responseData.parameters); session.UpdateQueryContextCache(responseData.QueryContext); } diff --git a/Snowflake.Data/Core/SFResultSet.cs b/Snowflake.Data/Core/SFResultSet.cs index 03e1794c9..a7586f2c3 100755 --- a/Snowflake.Data/Core/SFResultSet.cs +++ b/Snowflake.Data/Core/SFResultSet.cs @@ -1,376 +1,376 @@ -/* - * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. - */ - -using System; -using System.Threading; -using System.Threading.Tasks; -using Snowflake.Data.Log; -using Snowflake.Data.Client; -using System.Collections.Generic; -using System.Diagnostics; - -namespace Snowflake.Data.Core -{ - class SFResultSet : SFBaseResultSet - { - internal override ResultFormat ResultFormat => ResultFormat.JSON; - - private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); - - private readonly int _totalChunkCount; - - private readonly IChunkDownloader _chunkDownloader; - - private BaseResultChunk _currentChunk; - - public SFResultSet(QueryExecResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base() - { - try - { - columnCount = responseData.rowType?.Count ?? 0; - - this.sfStatement = sfStatement; - UpdateSessionStatus(responseData); - - if (responseData.chunks != null) - { - // counting the first chunk - _totalChunkCount = responseData.chunks.Count; - _chunkDownloader = ChunkDownloaderFactory.GetDownloader(responseData, this, cancellationToken); - } - - _currentChunk = responseData.rowSet != null ? new SFResultChunk(responseData.rowSet) : null; - responseData.rowSet = null; - - sfResultSetMetaData = responseData.rowType != null ? new SFResultSetMetaData(responseData, this.sfStatement.SfSession) : null; - - isClosed = false; - - queryId = responseData.queryId; - } - catch(System.Exception ex) - { - s_logger.Error("Result set error queryId="+responseData.queryId, ex); - throw; - } - } - - public enum PutGetResponseRowTypeInfo { - SourceFileName = 0, - DestinationFileName = 1, - SourceFileSize = 2, - DestinationFileSize = 3, - SourceCompressionType = 4, - DestinationCompressionType = 5, - ResultStatus = 6, - ErrorDetails = 7 - } - - public void InitializePutGetRowType(List rowType) - { - foreach (PutGetResponseRowTypeInfo t in System.Enum.GetValues(typeof(PutGetResponseRowTypeInfo))) - { - rowType.Add(new ExecResponseRowType() - { - name = t.ToString(), - type = "text" - }); - } - } - - public SFResultSet(PutGetResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base() - { - responseData.rowType = new List(); - InitializePutGetRowType(responseData.rowType); - - columnCount = responseData.rowType.Count; - - this.sfStatement = sfStatement; - - _currentChunk = new SFResultChunk(responseData.rowSet); - responseData.rowSet = null; - - sfResultSetMetaData = new SFResultSetMetaData(responseData); - - isClosed = false; - - queryId = responseData.queryId; - } - - internal void ResetChunkInfo(BaseResultChunk nextChunk) - { - s_logger.Debug($"Received chunk #{nextChunk.ChunkIndex + 1} of {_totalChunkCount}"); - _currentChunk.RowSet = null; - _currentChunk = nextChunk; - } - - internal override async Task NextAsync() - { - ThrowIfClosed(); - - if (_currentChunk.Next()) - return true; - - if (_chunkDownloader != null) - { - // GetNextChunk could be blocked if download result is not done yet. - // So put this piece of code in a seperate task - s_logger.Debug($"Get next chunk from chunk downloader, chunk: {_currentChunk.ChunkIndex + 1}/{_totalChunkCount}" + - $" rows: {_currentChunk.RowCount}, size compressed: {_currentChunk.CompressedSize}," + - $" size uncompressed: {_currentChunk.UncompressedSize}"); - BaseResultChunk nextChunk = await _chunkDownloader.GetNextChunkAsync().ConfigureAwait(false); - if (nextChunk != null) - { - ResetChunkInfo(nextChunk); - return _currentChunk.Next(); - } - } - - return false; - } - - internal override bool Next() - { - ThrowIfClosed(); - - if (_currentChunk.Next()) - return true; - - if (_chunkDownloader != null) - { - s_logger.Debug($"Get next chunk from chunk downloader, chunk: {_currentChunk.ChunkIndex + 1}/{_totalChunkCount}" + - $" rows: {_currentChunk.RowCount}, size compressed: {_currentChunk.CompressedSize}," + - $" size uncompressed: {_currentChunk.UncompressedSize}"); - BaseResultChunk nextChunk = Task.Run(async() => await (_chunkDownloader.GetNextChunkAsync()).ConfigureAwait(false)).Result; - if (nextChunk != null) - { - ResetChunkInfo(nextChunk); - return _currentChunk.Next(); - } - } - return false; - } - - internal override bool NextResult() - { - return false; - } - - internal override async Task NextResultAsync(CancellationToken cancellationToken) - { - return await Task.FromResult(false); - } - - internal override bool HasRows() - { - ThrowIfClosed(); - - return _currentChunk.RowCount > 0 || _totalChunkCount > 0; - } - - /// - /// Move cursor back one row. - /// - /// True if it works, false otherwise. - internal override bool Rewind() - { - ThrowIfClosed(); - - return _currentChunk.Rewind(); - } - - internal UTF8Buffer GetObjectInternal(int ordinal) - { - ThrowIfClosed(); - ThrowIfOutOfBounds(ordinal); - - return _currentChunk.ExtractCell(ordinal); - } - - private void UpdateSessionStatus(QueryExecResponseData responseData) - { - SFSession session = this.sfStatement.SfSession; - session.UpdateDatabaseAndSchema(responseData.finalDatabaseName, responseData.finalSchemaName); - session.UpdateSessionParameterMap(responseData.parameters); - session.UpdateQueryContextCache(responseData.QueryContext); - } - - internal override bool IsDBNull(int ordinal) - { - return (null == GetObjectInternal(ordinal)); - } - - internal override bool GetBoolean(int ordinal) - { - return GetValue(ordinal); - } - - internal override byte GetByte(int ordinal) - { - return GetValue(ordinal); - } - - internal override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length) - { - return ReadSubset(ordinal, dataOffset, buffer, bufferOffset, length); - } - - internal override char GetChar(int ordinal) - { - string val = GetString(ordinal); - return val[0]; - } - - internal override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) - { - return ReadSubset(ordinal, dataOffset, buffer, bufferOffset, length); - } - - internal override DateTime GetDateTime(int ordinal) - { - return GetValue(ordinal); - } - - internal override TimeSpan GetTimeSpan(int ordinal) - { - return GetValue(ordinal); - } - - internal override decimal GetDecimal(int ordinal) - { - return GetValue(ordinal); - } - - internal override double GetDouble(int ordinal) - { - return GetValue(ordinal); - } - - internal override float GetFloat(int ordinal) - { - return GetValue(ordinal); - } - - internal override Guid GetGuid(int ordinal) - { - return GetValue(ordinal); - } - - internal override short GetInt16(int ordinal) - { - return GetValue(ordinal); - } - - internal override int GetInt32(int ordinal) - { - return GetValue(ordinal); - } - - internal override long GetInt64(int ordinal) - { - return GetValue(ordinal); - } - - internal override string GetString(int ordinal) - { - ThrowIfOutOfBounds(ordinal); - - var type = sfResultSetMetaData.GetColumnTypeByIndex(ordinal); - switch (type) - { - case SFDataType.DATE: - var val = GetValue(ordinal); - if (val == DBNull.Value) - return null; - return SFDataConverter.toDateString((DateTime)val, sfResultSetMetaData.dateOutputFormat); - - default: - return GetObjectInternal(ordinal).SafeToString(); - } - } - - internal override object GetValue(int ordinal) - { - UTF8Buffer val = GetObjectInternal(ordinal); - var types = sfResultSetMetaData.GetTypesByIndex(ordinal); - return SFDataConverter.ConvertToCSharpVal(val, types.Item1, types.Item2); - } - - private T GetValue(int ordinal) - { - UTF8Buffer val = GetObjectInternal(ordinal); - var types = sfResultSetMetaData.GetTypesByIndex(ordinal); - return (T)SFDataConverter.ConvertToCSharpVal(val, types.Item1, typeof(T)); - } - - // - // Summary: - // Reads a subset of data starting at location indicated by dataOffset into the buffer, - // starting at the location indicated by bufferOffset. - // - // Parameters: - // ordinal: - // The zero-based column ordinal. - // - // dataOffset: - // The index within the data from which to begin the read operation. - // - // buffer: - // The buffer into which to copy the data. - // - // bufferOffset: - // The index with the buffer to which the data will be copied. - // - // length: - // The maximum number of elements to read. - // - // Returns: - // The actual number of elements read. - private long ReadSubset(int ordinal, long dataOffset, T[] buffer, int bufferOffset, int length) where T : struct - { - if (dataOffset < 0) - { - throw new ArgumentOutOfRangeException("dataOffset", "Non negative number is required."); - } - - if (bufferOffset < 0) - { - throw new ArgumentOutOfRangeException("bufferOffset", "Non negative number is required."); - } - - if ((null != buffer) && (bufferOffset > buffer.Length)) - { - throw new System.ArgumentException("Destination buffer is not long enough. " + - "Check the buffer offset, length, and the buffer's lower bounds.", "buffer"); - } - - T[] data = GetValue(ordinal); - - // https://docs.microsoft.com/en-us/dotnet/api/system.data.idatarecord.getbytes?view=net-5.0#remarks - // If you pass a buffer that is null, GetBytes returns the length of the row in bytes. - // https://docs.microsoft.com/en-us/dotnet/api/system.data.idatarecord.getchars?view=net-5.0#remarks - // If you pass a buffer that is null, GetChars returns the length of the field in characters. - if (null == buffer) - { - return data.Length; - } - - if (dataOffset > data.Length) - { - throw new System.ArgumentException("Source data is not long enough. " + - "Check the data offset, length, and the data's lower bounds." ,"dataOffset"); - } - else - { - // How much data is available after the offset - long dataLength = data.Length - dataOffset; - // How much data to read - long elementsRead = Math.Min(length, dataLength); - Array.Copy(data, dataOffset, buffer, bufferOffset, elementsRead); - - return elementsRead; - } - } - } -} +/* + * Copyright (c) 2012-2019 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Threading; +using System.Threading.Tasks; +using Snowflake.Data.Log; +using Snowflake.Data.Client; +using System.Collections.Generic; +using System.Diagnostics; + +namespace Snowflake.Data.Core +{ + class SFResultSet : SFBaseResultSet + { + internal override ResultFormat ResultFormat => ResultFormat.JSON; + + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + + private readonly int _totalChunkCount; + + private readonly IChunkDownloader _chunkDownloader; + + private BaseResultChunk _currentChunk; + + public SFResultSet(QueryExecResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base() + { + try + { + columnCount = responseData.rowType?.Count ?? 0; + + this.sfStatement = sfStatement; + UpdateSessionStatus(responseData); + + if (responseData.chunks != null) + { + // counting the first chunk + _totalChunkCount = responseData.chunks.Count; + _chunkDownloader = ChunkDownloaderFactory.GetDownloader(responseData, this, cancellationToken); + } + + _currentChunk = responseData.rowSet != null ? new SFResultChunk(responseData.rowSet) : null; + responseData.rowSet = null; + + sfResultSetMetaData = responseData.rowType != null ? new SFResultSetMetaData(responseData, this.sfStatement.SfSession) : null; + + isClosed = false; + + queryId = responseData.queryId; + } + catch(System.Exception ex) + { + s_logger.Error("Result set error queryId="+responseData.queryId, ex); + throw; + } + } + + public enum PutGetResponseRowTypeInfo { + SourceFileName = 0, + DestinationFileName = 1, + SourceFileSize = 2, + DestinationFileSize = 3, + SourceCompressionType = 4, + DestinationCompressionType = 5, + ResultStatus = 6, + ErrorDetails = 7 + } + + public void InitializePutGetRowType(List rowType) + { + foreach (PutGetResponseRowTypeInfo t in System.Enum.GetValues(typeof(PutGetResponseRowTypeInfo))) + { + rowType.Add(new ExecResponseRowType() + { + name = t.ToString(), + type = "text" + }); + } + } + + public SFResultSet(PutGetResponseData responseData, SFStatement sfStatement, CancellationToken cancellationToken) : base() + { + responseData.rowType = new List(); + InitializePutGetRowType(responseData.rowType); + + columnCount = responseData.rowType.Count; + + this.sfStatement = sfStatement; + + _currentChunk = new SFResultChunk(responseData.rowSet); + responseData.rowSet = null; + + sfResultSetMetaData = new SFResultSetMetaData(responseData); + + isClosed = false; + + queryId = responseData.queryId; + } + + internal void ResetChunkInfo(BaseResultChunk nextChunk) + { + s_logger.Debug($"Received chunk #{nextChunk.ChunkIndex + 1} of {_totalChunkCount}"); + _currentChunk.RowSet = null; + _currentChunk = nextChunk; + } + + internal override async Task NextAsync() + { + ThrowIfClosed(); + + if (_currentChunk.Next()) + return true; + + if (_chunkDownloader != null) + { + // GetNextChunk could be blocked if download result is not done yet. + // So put this piece of code in a seperate task + s_logger.Debug($"Get next chunk from chunk downloader, chunk: {_currentChunk.ChunkIndex + 1}/{_totalChunkCount}" + + $" rows: {_currentChunk.RowCount}, size compressed: {_currentChunk.CompressedSize}," + + $" size uncompressed: {_currentChunk.UncompressedSize}"); + BaseResultChunk nextChunk = await _chunkDownloader.GetNextChunkAsync().ConfigureAwait(false); + if (nextChunk != null) + { + ResetChunkInfo(nextChunk); + return _currentChunk.Next(); + } + } + + return false; + } + + internal override bool Next() + { + ThrowIfClosed(); + + if (_currentChunk.Next()) + return true; + + if (_chunkDownloader != null) + { + s_logger.Debug($"Get next chunk from chunk downloader, chunk: {_currentChunk.ChunkIndex + 1}/{_totalChunkCount}" + + $" rows: {_currentChunk.RowCount}, size compressed: {_currentChunk.CompressedSize}," + + $" size uncompressed: {_currentChunk.UncompressedSize}"); + BaseResultChunk nextChunk = Task.Run(async() => await (_chunkDownloader.GetNextChunkAsync()).ConfigureAwait(false)).Result; + if (nextChunk != null) + { + ResetChunkInfo(nextChunk); + return _currentChunk.Next(); + } + } + return false; + } + + internal override bool NextResult() + { + return false; + } + + internal override async Task NextResultAsync(CancellationToken cancellationToken) + { + return await Task.FromResult(false); + } + + internal override bool HasRows() + { + ThrowIfClosed(); + + return _currentChunk.RowCount > 0 || _totalChunkCount > 0; + } + + /// + /// Move cursor back one row. + /// + /// True if it works, false otherwise. + internal override bool Rewind() + { + ThrowIfClosed(); + + return _currentChunk.Rewind(); + } + + internal UTF8Buffer GetObjectInternal(int ordinal) + { + ThrowIfClosed(); + ThrowIfOutOfBounds(ordinal); + + return _currentChunk.ExtractCell(ordinal); + } + + private void UpdateSessionStatus(QueryExecResponseData responseData) + { + SFSession session = this.sfStatement.SfSession; + session.UpdateSessionProperties(responseData); + session.UpdateSessionParameterMap(responseData.parameters); + session.UpdateQueryContextCache(responseData.QueryContext); + } + + internal override bool IsDBNull(int ordinal) + { + return (null == GetObjectInternal(ordinal)); + } + + internal override bool GetBoolean(int ordinal) + { + return GetValue(ordinal); + } + + internal override byte GetByte(int ordinal) + { + return GetValue(ordinal); + } + + internal override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length) + { + return ReadSubset(ordinal, dataOffset, buffer, bufferOffset, length); + } + + internal override char GetChar(int ordinal) + { + string val = GetString(ordinal); + return val[0]; + } + + internal override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length) + { + return ReadSubset(ordinal, dataOffset, buffer, bufferOffset, length); + } + + internal override DateTime GetDateTime(int ordinal) + { + return GetValue(ordinal); + } + + internal override TimeSpan GetTimeSpan(int ordinal) + { + return GetValue(ordinal); + } + + internal override decimal GetDecimal(int ordinal) + { + return GetValue(ordinal); + } + + internal override double GetDouble(int ordinal) + { + return GetValue(ordinal); + } + + internal override float GetFloat(int ordinal) + { + return GetValue(ordinal); + } + + internal override Guid GetGuid(int ordinal) + { + return GetValue(ordinal); + } + + internal override short GetInt16(int ordinal) + { + return GetValue(ordinal); + } + + internal override int GetInt32(int ordinal) + { + return GetValue(ordinal); + } + + internal override long GetInt64(int ordinal) + { + return GetValue(ordinal); + } + + internal override string GetString(int ordinal) + { + ThrowIfOutOfBounds(ordinal); + + var type = sfResultSetMetaData.GetColumnTypeByIndex(ordinal); + switch (type) + { + case SFDataType.DATE: + var val = GetValue(ordinal); + if (val == DBNull.Value) + return null; + return SFDataConverter.toDateString((DateTime)val, sfResultSetMetaData.dateOutputFormat); + + default: + return GetObjectInternal(ordinal).SafeToString(); + } + } + + internal override object GetValue(int ordinal) + { + UTF8Buffer val = GetObjectInternal(ordinal); + var types = sfResultSetMetaData.GetTypesByIndex(ordinal); + return SFDataConverter.ConvertToCSharpVal(val, types.Item1, types.Item2); + } + + private T GetValue(int ordinal) + { + UTF8Buffer val = GetObjectInternal(ordinal); + var types = sfResultSetMetaData.GetTypesByIndex(ordinal); + return (T)SFDataConverter.ConvertToCSharpVal(val, types.Item1, typeof(T)); + } + + // + // Summary: + // Reads a subset of data starting at location indicated by dataOffset into the buffer, + // starting at the location indicated by bufferOffset. + // + // Parameters: + // ordinal: + // The zero-based column ordinal. + // + // dataOffset: + // The index within the data from which to begin the read operation. + // + // buffer: + // The buffer into which to copy the data. + // + // bufferOffset: + // The index with the buffer to which the data will be copied. + // + // length: + // The maximum number of elements to read. + // + // Returns: + // The actual number of elements read. + private long ReadSubset(int ordinal, long dataOffset, T[] buffer, int bufferOffset, int length) where T : struct + { + if (dataOffset < 0) + { + throw new ArgumentOutOfRangeException("dataOffset", "Non negative number is required."); + } + + if (bufferOffset < 0) + { + throw new ArgumentOutOfRangeException("bufferOffset", "Non negative number is required."); + } + + if ((null != buffer) && (bufferOffset > buffer.Length)) + { + throw new System.ArgumentException("Destination buffer is not long enough. " + + "Check the buffer offset, length, and the buffer's lower bounds.", "buffer"); + } + + T[] data = GetValue(ordinal); + + // https://docs.microsoft.com/en-us/dotnet/api/system.data.idatarecord.getbytes?view=net-5.0#remarks + // If you pass a buffer that is null, GetBytes returns the length of the row in bytes. + // https://docs.microsoft.com/en-us/dotnet/api/system.data.idatarecord.getchars?view=net-5.0#remarks + // If you pass a buffer that is null, GetChars returns the length of the field in characters. + if (null == buffer) + { + return data.Length; + } + + if (dataOffset > data.Length) + { + throw new System.ArgumentException("Source data is not long enough. " + + "Check the data offset, length, and the data's lower bounds." ,"dataOffset"); + } + else + { + // How much data is available after the offset + long dataLength = data.Length - dataOffset; + // How much data to read + long elementsRead = Math.Min(length, dataLength); + Array.Copy(data, dataOffset, buffer, bufferOffset, elementsRead); + + return elementsRead; + } + } + } +} diff --git a/Snowflake.Data/Core/Session/ChangedSessionBehavior.cs b/Snowflake.Data/Core/Session/ChangedSessionBehavior.cs new file mode 100644 index 000000000..a6771930c --- /dev/null +++ b/Snowflake.Data/Core/Session/ChangedSessionBehavior.cs @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. + */ + +namespace Snowflake.Data.Core.Session +{ + /** + * ChangedSessionBehavior describes what should happen to a session with a changed state (schema/role/database/warehouse) when it returns to the pool. + */ + public enum ChangedSessionBehavior + { + OriginalPool, + Destroy + } +} diff --git a/Snowflake.Data/Core/Session/ConnectionCacheManager.cs b/Snowflake.Data/Core/Session/ConnectionCacheManager.cs index e10a984e3..febecbbce 100644 --- a/Snowflake.Data/Core/Session/ConnectionCacheManager.cs +++ b/Snowflake.Data/Core/Session/ConnectionCacheManager.cs @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. + */ + using System.Security; using System.Threading; using System.Threading.Tasks; @@ -6,12 +10,13 @@ namespace Snowflake.Data.Core.Session { internal sealed class ConnectionCacheManager : IConnectionManager { - private readonly SessionPool _sessionPool = new SessionPool(); + private readonly SessionPool _sessionPool = SessionPool.CreateSessionCache(); public SFSession GetSession(string connectionString, SecureString password) => _sessionPool.GetSession(connectionString, password); public Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken) => _sessionPool.GetSessionAsync(connectionString, password, cancellationToken); - public bool AddSession(SFSession session) => _sessionPool.AddSession(session); - public void ClearAllPools() => _sessionPool.ClearAllPools(); + public bool AddSession(SFSession session) => _sessionPool.AddSession(session, false); + public void ReleaseBusySession(SFSession session) => _sessionPool.ReleaseBusySession(session); + public void ClearAllPools() => _sessionPool.ClearSessions(); public void SetMaxPoolSize(int maxPoolSize) => _sessionPool.SetMaxPoolSize(maxPoolSize); public int GetMaxPoolSize() => _sessionPool.GetMaxPoolSize(); public void SetTimeout(long connectionTimeout) => _sessionPool.SetTimeout(connectionTimeout); @@ -19,5 +24,7 @@ public Task GetSessionAsync(string connectionString, SecureString pas public int GetCurrentPoolSize() => _sessionPool.GetCurrentPoolSize(); public bool SetPooling(bool poolingEnabled) => _sessionPool.SetPooling(poolingEnabled); public bool GetPooling() => _sessionPool.GetPooling(); + public SessionPool GetPool(string connectionString) => _sessionPool; + public SessionPool GetPool(string connectionString, SecureString password) => _sessionPool; } } diff --git a/Snowflake.Data/Core/Session/ConnectionPoolConfig.cs b/Snowflake.Data/Core/Session/ConnectionPoolConfig.cs new file mode 100644 index 000000000..25f1fcd46 --- /dev/null +++ b/Snowflake.Data/Core/Session/ConnectionPoolConfig.cs @@ -0,0 +1,15 @@ +using System; + +namespace Snowflake.Data.Core.Session +{ + internal class ConnectionPoolConfig + { + public int MaxPoolSize { get; set; } = SFSessionHttpClientProperties.DefaultMaxPoolSize; + public int MinPoolSize { get; set; } = SFSessionHttpClientProperties.DefaultMinPoolSize; + public ChangedSessionBehavior ChangedSession { get; set; } = SFSessionHttpClientProperties.DefaultChangedSession; + public TimeSpan WaitingForIdleSessionTimeout { get; set; } = SFSessionHttpClientProperties.DefaultWaitingForIdleSessionTimeout; + public TimeSpan ExpirationTimeout { get; set; } = SFSessionHttpClientProperties.DefaultExpirationTimeout; + public bool PoolingEnabled { get; set; } = SFSessionHttpClientProperties.DefaultPoolingEnabled; + public TimeSpan ConnectionTimeout { get; set; } = SFSessionHttpClientProperties.DefaultConnectionTimeout; + } +} diff --git a/Snowflake.Data/Core/Session/ConnectionPoolManager.cs b/Snowflake.Data/Core/Session/ConnectionPoolManager.cs new file mode 100644 index 000000000..09bfa5821 --- /dev/null +++ b/Snowflake.Data/Core/Session/ConnectionPoolManager.cs @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Security; +using System.Threading; +using System.Threading.Tasks; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Tools; +using Snowflake.Data.Log; + +namespace Snowflake.Data.Core.Session +{ + internal sealed class ConnectionPoolManager : IConnectionManager + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + private static readonly Object s_poolsLock = new Object(); + private static readonly Exception s_operationNotAvailable = new Exception("You cannot change connection pool parameters for all the pools. Instead you can change it on a particular pool"); + private readonly Dictionary _pools; + + internal ConnectionPoolManager() + { + lock (s_poolsLock) + { + _pools = new Dictionary(); + } + } + + public SFSession GetSession(string connectionString, SecureString password) + { + s_logger.Debug($"ConnectionPoolManager::GetSession"); + return GetPool(connectionString, password).GetSession(); + } + + public Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken) + { + s_logger.Debug($"ConnectionPoolManager::GetSessionAsync"); + return GetPool(connectionString, password).GetSessionAsync(cancellationToken); + } + + public bool AddSession(SFSession session) + { + s_logger.Debug("ConnectionPoolManager::AddSession"); + return GetPool(session.ConnectionString, session.Password).AddSession(session, true); + } + + public void ReleaseBusySession(SFSession session) + { + s_logger.Debug("ConnectionPoolManager::ReleaseBusySession"); + GetPool(session.ConnectionString, session.Password).ReleaseBusySession(session); + } + + public void ClearAllPools() + { + s_logger.Debug("ConnectionPoolManager::ClearAllPools"); + foreach (var sessionPool in _pools.Values) + { + sessionPool.DestroyPool(); + } + _pools.Clear(); + } + + public void SetMaxPoolSize(int maxPoolSize) + { + throw s_operationNotAvailable; + } + + public int GetMaxPoolSize() + { + s_logger.Debug("ConnectionPoolManager::GetMaxPoolSize"); + var values = _pools.Values.Select(it => it.GetMaxPoolSize()).Distinct().ToList(); + switch (values.Count) + { + case 0: + return SFSessionHttpClientProperties.DefaultMaxPoolSize; + case 1: + return values.First(); + default: + throw new SnowflakeDbException(SFError.INCONSISTENT_RESULT_ERROR, "Multiple pools have different Max Pool Size values"); + } + } + + public void SetTimeout(long connectionTimeout) + { + throw s_operationNotAvailable; + } + + public long GetTimeout() + { + s_logger.Debug("ConnectionPoolManager::GetTimeout"); + var values = _pools.Values.Select(it => it.GetTimeout()).Distinct().ToList(); + switch (values.Count) + { + case 0: + return (long) SFSessionHttpClientProperties.DefaultExpirationTimeout.TotalSeconds; + case 1: + return values.First(); + default: + throw new SnowflakeDbException(SFError.INCONSISTENT_RESULT_ERROR, "Multiple pools have different Timeout values"); + } + } + + public int GetCurrentPoolSize() + { + s_logger.Debug("ConnectionPoolManager::GetCurrentPoolSize"); + return _pools.Values.Select(it => it.GetCurrentPoolSize()).Sum(); + } + + public bool SetPooling(bool poolingEnabled) + { + throw s_operationNotAvailable; + } + + public bool GetPooling() + { + s_logger.Debug("ConnectionPoolManager::GetPooling"); + return true; // in new pool pooling is always enabled by default, disabling only by connection string parameter + } + + public SessionPool GetPool(string connectionString, SecureString password) + { + s_logger.Debug("ConnectionPoolManager::GetPool with connection string and secure password"); + var poolKey = GetPoolKey(connectionString, password); + + if (_pools.TryGetValue(poolKey, out var item)) + { + item.ValidateSecurePassword(password); + return item; + } + + lock (s_poolsLock) + { + if (_pools.TryGetValue(poolKey, out var poolCreatedWhileWaitingOnLock)) + { + poolCreatedWhileWaitingOnLock.ValidateSecurePassword(password); + return poolCreatedWhileWaitingOnLock; + } + s_logger.Info($"Creating new pool"); + var pool = SessionPool.CreateSessionPool(connectionString, password); + _pools.Add(poolKey, pool); + return pool; + } + } + + public SessionPool GetPool(string connectionString) + { + s_logger.Debug("ConnectionPoolManager::GetPool with connection string"); + return GetPool(connectionString, null); + } + + private string GetPoolKey(string connectionString, SecureString password) => + password != null && password.Length > 0 + ? connectionString + ";password=" + SecureStringHelper.Decode(password) + ";" + : connectionString + ";password=;"; + } +} diff --git a/Snowflake.Data/Core/Session/ConnectionPoolType.cs b/Snowflake.Data/Core/Session/ConnectionPoolType.cs new file mode 100644 index 000000000..5844878fc --- /dev/null +++ b/Snowflake.Data/Core/Session/ConnectionPoolType.cs @@ -0,0 +1,8 @@ +namespace Snowflake.Data.Core.Session +{ + internal enum ConnectionPoolType + { + SingleConnectionCache, + MultipleConnectionPool + } +} diff --git a/Snowflake.Data/Core/Session/FixedZeroCounter.cs b/Snowflake.Data/Core/Session/FixedZeroCounter.cs new file mode 100644 index 000000000..f1d8467be --- /dev/null +++ b/Snowflake.Data/Core/Session/FixedZeroCounter.cs @@ -0,0 +1,19 @@ +namespace Snowflake.Data.Core.Session +{ + internal class FixedZeroCounter: ICounter + { + public int Count() => 0; + + public void Increase() + { + } + + public void Decrease() + { + } + + public void Reset() + { + } + } +} diff --git a/Snowflake.Data/Core/Session/IConnectionManager.cs b/Snowflake.Data/Core/Session/IConnectionManager.cs index e72ade2e7..01cfa3e8c 100644 --- a/Snowflake.Data/Core/Session/IConnectionManager.cs +++ b/Snowflake.Data/Core/Session/IConnectionManager.cs @@ -1,5 +1,5 @@ /* - * Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ using System.Security; @@ -13,6 +13,7 @@ internal interface IConnectionManager SFSession GetSession(string connectionString, SecureString password); Task GetSessionAsync(string connectionString, SecureString password, CancellationToken cancellationToken); bool AddSession(SFSession session); + void ReleaseBusySession(SFSession session); void ClearAllPools(); void SetMaxPoolSize(int maxPoolSize); int GetMaxPoolSize(); @@ -21,5 +22,7 @@ internal interface IConnectionManager int GetCurrentPoolSize(); bool SetPooling(bool poolingEnabled); bool GetPooling(); + SessionPool GetPool(string connectionString); + SessionPool GetPool(string connectionString, SecureString password); } } diff --git a/Snowflake.Data/Core/Session/ICounter.cs b/Snowflake.Data/Core/Session/ICounter.cs new file mode 100644 index 000000000..5a38878e7 --- /dev/null +++ b/Snowflake.Data/Core/Session/ICounter.cs @@ -0,0 +1,13 @@ +namespace Snowflake.Data.Core.Session +{ + internal interface ICounter + { + int Count(); + + void Increase(); + + void Decrease(); + + void Reset(); + } +} diff --git a/Snowflake.Data/Core/Session/ISessionCreationTokenCounter.cs b/Snowflake.Data/Core/Session/ISessionCreationTokenCounter.cs new file mode 100644 index 000000000..9b98c01e5 --- /dev/null +++ b/Snowflake.Data/Core/Session/ISessionCreationTokenCounter.cs @@ -0,0 +1,13 @@ +namespace Snowflake.Data.Core.Session +{ + internal interface ISessionCreationTokenCounter + { + SessionCreationToken NewToken(); + + void RemoveToken(SessionCreationToken creationToken); + + int Count(); + + void Reset(); + } +} diff --git a/Snowflake.Data/Core/Session/ISessionFactory.cs b/Snowflake.Data/Core/Session/ISessionFactory.cs new file mode 100644 index 000000000..f9416de8d --- /dev/null +++ b/Snowflake.Data/Core/Session/ISessionFactory.cs @@ -0,0 +1,9 @@ +using System.Security; + +namespace Snowflake.Data.Core.Session +{ + internal interface ISessionFactory + { + SFSession NewSession(string connectionString, SecureString password); + } +} diff --git a/Snowflake.Data/Core/Session/ISessionPoolEventHandler.cs b/Snowflake.Data/Core/Session/ISessionPoolEventHandler.cs new file mode 100644 index 000000000..2b16959a2 --- /dev/null +++ b/Snowflake.Data/Core/Session/ISessionPoolEventHandler.cs @@ -0,0 +1,15 @@ +namespace Snowflake.Data.Core.Session +{ + internal interface ISessionPoolEventHandler + { + void OnNewSessionCreated(SessionPool sessionPool); + + void OnWaitingForSessionStarted(SessionPool sessionPool); + + void OnWaitingForSessionStarted(SessionPool sessionPool, long millisLeft); + + void OnWaitingForSessionSuccessful(SessionPool sessionPool); + + void OnSessionProvided(SessionPool sessionPool); + } +} diff --git a/Snowflake.Data/Core/Session/IWaitingQueue.cs b/Snowflake.Data/Core/Session/IWaitingQueue.cs new file mode 100644 index 000000000..26bc45d0d --- /dev/null +++ b/Snowflake.Data/Core/Session/IWaitingQueue.cs @@ -0,0 +1,19 @@ +using System.Threading; + +namespace Snowflake.Data.Core.Session +{ + internal interface IWaitingQueue + { + bool Wait(int millisecondsTimeout, CancellationToken cancellationToken); + + void OnResourceIncrease(); + + bool IsAnyoneWaiting(); + + int WaitingCount(); + + bool IsWaitingEnabled(); + + void Reset(); + } +} diff --git a/Snowflake.Data/Core/Session/NonCountingSessionCreationTokenCounter.cs b/Snowflake.Data/Core/Session/NonCountingSessionCreationTokenCounter.cs new file mode 100644 index 000000000..44292d755 --- /dev/null +++ b/Snowflake.Data/Core/Session/NonCountingSessionCreationTokenCounter.cs @@ -0,0 +1,21 @@ +using System; + +namespace Snowflake.Data.Core.Session +{ + internal class NonCountingSessionCreationTokenCounter: ISessionCreationTokenCounter + { + private static readonly TimeSpan s_irrelevantCreateSessionTimeout = SFSessionHttpClientProperties.DefaultConnectionTimeout; // in case of old caching pool or pooling disabled we do not remove expired ones nor even store them + + public SessionCreationToken NewToken() => new SessionCreationToken(s_irrelevantCreateSessionTimeout); + + public void RemoveToken(SessionCreationToken creationToken) + { + } + + public int Count() => 0; + + public void Reset() + { + } + } +} diff --git a/Snowflake.Data/Core/Session/NonNegativeCounter.cs b/Snowflake.Data/Core/Session/NonNegativeCounter.cs new file mode 100644 index 000000000..5f1fa5959 --- /dev/null +++ b/Snowflake.Data/Core/Session/NonNegativeCounter.cs @@ -0,0 +1,20 @@ +using System; + +namespace Snowflake.Data.Core.Session +{ + internal class NonNegativeCounter : ICounter + { + private int _value; + + public int Count() => _value; + + public void Increase() => _value++; + + public void Decrease() + { + _value = Math.Max(_value - 1, 0); + } + + public void Reset() => _value = 0; + } +} diff --git a/Snowflake.Data/Core/Session/NonWaitingQueue.cs b/Snowflake.Data/Core/Session/NonWaitingQueue.cs new file mode 100644 index 000000000..46ec84677 --- /dev/null +++ b/Snowflake.Data/Core/Session/NonWaitingQueue.cs @@ -0,0 +1,35 @@ +using System.Threading; + +namespace Snowflake.Data.Core.Session +{ + internal class NonWaitingQueue: IWaitingQueue + { + public bool Wait(int millisecondsTimeout, CancellationToken cancellationToken) + { + return false; + } + + public void OnResourceIncrease() + { + } + + public bool IsAnyoneWaiting() + { + return false; + } + + public int WaitingCount() + { + return 0; + } + + public bool IsWaitingEnabled() + { + return false; + } + + public void Reset() + { + } + } +} diff --git a/Snowflake.Data/Core/Session/SFSession.cs b/Snowflake.Data/Core/Session/SFSession.cs index 3b0c80f8d..af8b1e55b 100755 --- a/Snowflake.Data/Core/Session/SFSession.cs +++ b/Snowflake.Data/Core/Session/SFSession.cs @@ -1,10 +1,9 @@ /* - * Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. */ using System; using System.Collections.Generic; -using System.IO; using System.Linq; using System.Security; using System.Web; @@ -15,7 +14,8 @@ using System.Threading.Tasks; using System.Net.Http; using System.Text.RegularExpressions; -using Snowflake.Data.Configuration; +using Snowflake.Data.Core.Session; +using Snowflake.Data.Core.Tools; namespace Snowflake.Data.Core { @@ -46,12 +46,16 @@ public class SFSession internal SFSessionProperties properties; internal string database; - internal string schema; + internal string role; + internal string warehouse; + internal bool sessionPropertiesChanged = false; internal string serverVersion; - internal TimeSpan connectionTimeout; + private readonly ConnectionPoolConfig _poolConfig; + + internal TimeSpan connectionTimeout => _poolConfig.ConnectionTimeout; internal bool InsecureMode; @@ -62,14 +66,12 @@ public class SFSession private string arrayBindStage = null; private int arrayBindStageThreshold = 0; internal int masterValidityInSeconds = 0; - - internal static readonly SFSessionHttpClientProperties.Extractor propertiesExtractor = new SFSessionHttpClientProperties.Extractor( - new SFSessionHttpClientProxyProperties.Extractor()); private readonly EasyLoggingStarter _easyLoggingStarter = EasyLoggingStarter.Instance; private long _startTime = 0; - internal string connStr = null; + internal string ConnectionString { get; } + internal SecureString Password { get; } private QueryContextCache _queryContextCache = new QueryContextCache(_defaultQueryContextCacheSize); @@ -81,7 +83,16 @@ public class SFSession internal int _maxRetryCount; - internal int _maxRetryTimeout; + internal TimeSpan _maxRetryTimeout; + + private string _user; + + public bool GetPooling() => _poolConfig.PoolingEnabled; + + public void SetPooling(bool isEnabled) + { + _poolConfig.PoolingEnabled = isEnabled; + } internal String _queryTag; @@ -94,6 +105,8 @@ internal void ProcessLoginResponse(LoginResponse authnResponse) masterToken = authnResponse.data.masterToken; database = authnResponse.data.authResponseSessionInfo.databaseName; schema = authnResponse.data.authResponseSessionInfo.schemaName; + role = authnResponse.data.authResponseSessionInfo.roleName; + warehouse = authnResponse.data.authResponseSessionInfo.warehouseName; serverVersion = authnResponse.data.serverVersion; masterValidityInSeconds = authnResponse.data.masterValidityInSeconds; UpdateSessionParameterMap(authnResponse.data.nameValueParameter); @@ -102,7 +115,7 @@ internal void ProcessLoginResponse(LoginResponse authnResponse) logger.Debug("Query context cache disabled."); } logger.Debug($"Session opened: {sessionId}"); - _startTime = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); + _startTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); } else { @@ -138,7 +151,7 @@ internal Uri BuildLoginUrl() } /// - /// Constructor + /// Constructor /// /// A string in the form of "key1=value1;key2=value2" internal SFSession( @@ -153,37 +166,43 @@ internal SFSession( EasyLoggingStarter easyLoggingStarter) { _easyLoggingStarter = easyLoggingStarter; - connStr = connectionString; - properties = SFSessionProperties.ParseConnectionString(connectionString, password); + ConnectionString = connectionString; + Password = password; + properties = SFSessionProperties.ParseConnectionString(ConnectionString, Password); _disableQueryContextCache = bool.Parse(properties[SFSessionProperty.DISABLEQUERYCONTEXTCACHE]); _disableConsoleLogin = bool.Parse(properties[SFSessionProperty.DISABLE_CONSOLE_LOGIN]); + properties.TryGetValue(SFSessionProperty.USER, out _user); ValidateApplicationName(properties); try { - var extractedProperties = propertiesExtractor.ExtractProperties(properties); + var extractedProperties = SFSessionHttpClientProperties.ExtractAndValidate(properties); var httpClientConfig = extractedProperties.BuildHttpClientConfig(); ParameterMap = extractedProperties.ToParameterMap(); InsecureMode = extractedProperties.insecureMode; _HttpClient = HttpUtil.Instance.GetHttpClient(httpClientConfig); restRequester = new RestRequester(_HttpClient); - extractedProperties.CheckPropertiesAreValid(); - connectionTimeout = extractedProperties.TimeoutDuration(); + _poolConfig = extractedProperties.BuildConnectionPoolConfig(); properties.TryGetValue(SFSessionProperty.CLIENT_CONFIG_FILE, out var easyLoggingConfigFile); _easyLoggingStarter.Init(easyLoggingConfigFile); properties.TryGetValue(SFSessionProperty.QUERY_TAG, out _queryTag); _maxRetryCount = extractedProperties.maxHttpRetries; _maxRetryTimeout = extractedProperties.retryTimeout; } + catch (SnowflakeDbException e) + { + logger.Error("Unable to initialize session ", e); + throw; + } catch (Exception e) { - logger.Error("Unable to connect", e); + logger.Error("Unable to initialize session ", e); throw new SnowflakeDbException(e, SnowflakeDbException.CONNECTION_FAILURE_SSTATE, SFError.INVALID_CONNECTION_STRING, - "Unable to connect"); + "Unable to initialize session "); } } - + private void ValidateApplicationName(SFSessionProperties properties) { // If there is an "application" setting, verify that it matches the expect pattern @@ -227,7 +246,7 @@ internal Uri BuildUri(string path, Dictionary queryParams = null return uriBuilder.Uri; } - internal void Open() + internal virtual void Open() { logger.Debug("Open Session"); @@ -239,7 +258,7 @@ internal void Open() authenticator.Authenticate(); } - internal async Task OpenAsync(CancellationToken cancellationToken) + internal virtual async Task OpenAsync(CancellationToken cancellationToken) { logger.Debug("Open Session Async"); @@ -255,7 +274,7 @@ internal void close() { // Nothing to do if the session is not open if (!IsEstablished()) return; - + logger.Debug($"Closing session with id: {sessionId}, user: {_user}, database: {database}, schema: {schema}, role: {role}, warehouse: {warehouse}, connection start timestamp: {_startTime}"); stopHeartBeatForThisSession(); // Send a close session request @@ -287,7 +306,7 @@ internal async Task CloseAsync(CancellationToken cancellationToken) { // Nothing to do if the session is not open if (!IsEstablished()) return; - + logger.Debug($"Closing session with id: {sessionId}, user: {_user}, database: {database}, schema: {schema}, role: {role}, warehouse: {warehouse}, connection start timestamp: {_startTime}"); stopHeartBeatForThisSession(); // Send a close session request @@ -453,20 +472,48 @@ internal RequestQueryContext GetQueryContextRequest() return _queryContextCache.GetQueryContextRequest(); } - internal void UpdateDatabaseAndSchema(string databaseName, string schemaName) + internal void UpdateSessionProperties(QueryExecResponseData responseData) { - // with HTAP session metadata removal database/schema - // might be not returened in query result - if (!String.IsNullOrEmpty(databaseName)) - { - this.database = databaseName; - } - if (!String.IsNullOrEmpty(schemaName)) + // with HTAP session metadata removal database/schema might be not returned in query result + UpdateSessionProperty(ref database, responseData.finalDatabaseName); + UpdateSessionProperty(ref schema, responseData.finalSchemaName); + UpdateSessionProperty(ref role, responseData.finalRoleName); + UpdateSessionProperty(ref warehouse, responseData.finalWarehouseName); + } + + internal void UpdateSessionProperty(ref string initialSessionValue, string finalSessionValue) + { + // with HTAP session metadata removal database/schema might be not returned in query result + if (!string.IsNullOrEmpty(finalSessionValue)) { - this.schema = schemaName; + bool quoted = false; + string unquotedFinalValue = UnquoteJson(finalSessionValue, ref quoted); + if (!string.IsNullOrEmpty(initialSessionValue)) + { + quoted |= initialSessionValue.StartsWith("\""); + if (!string.Equals(initialSessionValue, unquotedFinalValue, quoted ? StringComparison.Ordinal : StringComparison.OrdinalIgnoreCase)) + { + sessionPropertiesChanged = true; + initialSessionValue = unquotedFinalValue; + } + } + else // null session value gets populated and is not treated as a session property change + { + initialSessionValue = unquotedFinalValue; + } } } - + + private static string UnquoteJson(string value, ref bool unquoted) + { + if (value is null) + return value; + unquoted = value.Length >= 4 && value.StartsWith("\\\"") && value.EndsWith("\\\""); + return unquoted ? value.Replace("\\\"", "\"") : value; + } + + internal bool SessionPropertiesChanged => sessionPropertiesChanged; + internal void startHeartBeatForThisSession() { if (!this.isHeartBeatEnabled) @@ -571,15 +618,17 @@ internal void heartbeat() } } - internal bool IsNotOpen() + internal virtual bool IsNotOpen() { return _startTime == 0; } - internal bool IsExpired(long timeoutInSeconds, long utcTimeInSeconds) + internal virtual bool IsExpired(TimeSpan timeout, long utcTimeInMillis) { - return _startTime + timeoutInSeconds <= utcTimeInSeconds; + var hasEverBeenOpened = !IsNotOpen(); + return hasEverBeenOpened && TimeoutHelper.IsExpired(_startTime, utcTimeInMillis, timeout); } + + internal long GetStartTime() => _startTime; } } - diff --git a/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs b/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs index f129de25a..2ba4709e7 100644 --- a/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs +++ b/Snowflake.Data/Core/Session/SFSessionHttpClientProperties.cs @@ -1,6 +1,9 @@ using System; using System.Collections.Generic; -using System.Threading; +using Snowflake.Data.Client; +using Snowflake.Data.Core.Authenticator; +using Snowflake.Data.Core.Session; +using Snowflake.Data.Core.Tools; using Snowflake.Data.Log; namespace Snowflake.Data.Core @@ -8,71 +11,170 @@ namespace Snowflake.Data.Core internal class SFSessionHttpClientProperties { - internal static readonly int s_maxHttpRetriesDefault = 7; - internal static readonly int s_retryTimeoutDefault = 300; - private static readonly SFLogger logger = SFLoggerFactory.GetLogger(); + private static readonly Extractor s_propertiesExtractor = new Extractor(new SFSessionHttpClientProxyProperties.Extractor()); + public const int DefaultMaxPoolSize = 10; + public const int DefaultMinPoolSize = 2; + public const ChangedSessionBehavior DefaultChangedSession = ChangedSessionBehavior.Destroy; + public static readonly TimeSpan DefaultWaitingForIdleSessionTimeout = TimeSpan.FromSeconds(30); + public static readonly TimeSpan DefaultConnectionTimeout = TimeSpan.FromMinutes(5); + public static readonly TimeSpan DefaultExpirationTimeout = TimeSpan.FromHours(1); + public const bool DefaultPoolingEnabled = true; + public const int DefaultMaxHttpRetries = 7; + public static readonly TimeSpan DefaultRetryTimeout = TimeSpan.FromSeconds(300); + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); internal bool validateDefaultParameters; internal bool clientSessionKeepAlive; - internal int timeoutInSec; + internal TimeSpan connectionTimeout; internal bool insecureMode; internal bool disableRetry; internal bool forceRetryOn404; - internal int retryTimeout; + internal TimeSpan retryTimeout; internal int maxHttpRetries; internal bool includeRetryReason; internal SFSessionHttpClientProxyProperties proxyProperties; + private int _maxPoolSize; + private int _minPoolSize; + private ChangedSessionBehavior _changedSession; + private TimeSpan _waitingForSessionIdleTimeout; + private TimeSpan _expirationTimeout; + private bool _poolingEnabled; - internal void CheckPropertiesAreValid() + public static SFSessionHttpClientProperties ExtractAndValidate(SFSessionProperties properties) { - if (timeoutInSec < s_retryTimeoutDefault) + var extractedProperties = s_propertiesExtractor.ExtractProperties(properties); + extractedProperties.CheckPropertiesAreValid(); + return extractedProperties; + } + + public void DisablePoolingDefaultIfSecretsProvidedExternally(SFSessionProperties properties) + { + var authenticator = properties[SFSessionProperty.AUTHENTICATOR].ToLower(); + if (ExternalBrowserAuthenticator.AUTH_NAME.Equals(authenticator)) { - logger.Warn($"Connection timeout provided is less than recommended minimum value of" + - $" {s_retryTimeoutDefault}"); + DisablePoolingIfNotExplicitlyEnabled(properties, "external browser"); + + } else if (KeyPairAuthenticator.AUTH_NAME.Equals(authenticator) + && properties.IsNonEmptyValueProvided(SFSessionProperty.PRIVATE_KEY_FILE) + && !properties.IsNonEmptyValueProvided(SFSessionProperty.PRIVATE_KEY_PWD)) + { + DisablePoolingIfNotExplicitlyEnabled(properties, "key pair with private key in a file"); } + } - if (timeoutInSec < 0) + private void DisablePoolingIfNotExplicitlyEnabled(SFSessionProperties properties, string authenticationDescription) + { + if (!properties.IsPoolingEnabledValueProvided && _poolingEnabled) + { + _poolingEnabled = false; + s_logger.Info($"Disabling connection pooling for {authenticationDescription} authentication"); + } + else if (properties.IsPoolingEnabledValueProvided && _poolingEnabled) { - logger.Warn($"Connection timeout provided is negative. Timeout will be infinite."); + s_logger.Warn($"Connection pooling is enabled for {authenticationDescription} authentication which is not recommended"); } + } - if (retryTimeout > 0 && retryTimeout < s_retryTimeoutDefault) + private void CheckPropertiesAreValid() + { + try + { + ValidateConnectionTimeout(); + ValidateRetryTimeout(); + ShortenConnectionTimeoutByRetryTimeout(); + ValidateHttpRetries(); + ValidateMinMaxPoolSize(); + ValidateWaitingForSessionIdleTimeout(); + } + catch (SnowflakeDbException) + { + throw; + } + catch (Exception exception) + { + throw new SnowflakeDbException(SFError.INVALID_CONNECTION_STRING, exception); + } + } + + private void ValidateConnectionTimeout() + { + if (TimeoutHelper.IsZeroLength(connectionTimeout)) + { + s_logger.Warn("Connection timeout provided is 0. Timeout will be infinite"); + connectionTimeout = TimeoutHelper.Infinity(); + } + else if (TimeoutHelper.IsInfinite(connectionTimeout)) + { + s_logger.Warn("Connection timeout provided is negative. Timeout will be infinite."); + } + if (!TimeoutHelper.IsInfinite(connectionTimeout) && connectionTimeout < DefaultRetryTimeout) { - logger.Warn($"Max retry timeout provided is less than the allowed minimum value of" + - $" {s_retryTimeoutDefault}"); + s_logger.Warn($"Connection timeout provided is less than recommended minimum value of {DefaultRetryTimeout}"); + } + } - retryTimeout = s_retryTimeoutDefault; + private void ValidateRetryTimeout() + { + if (retryTimeout.TotalMilliseconds > 0 && retryTimeout < DefaultRetryTimeout) + { + s_logger.Warn($"Max retry timeout provided is less than the allowed minimum value of {DefaultRetryTimeout}"); + retryTimeout = DefaultRetryTimeout; + } + else if (TimeoutHelper.IsZeroLength(retryTimeout)) + { + s_logger.Warn($"Max retry timeout provided is 0. Timeout will be infinite"); + retryTimeout = TimeoutHelper.Infinity(); } - else if (retryTimeout == 0) + else if (TimeoutHelper.IsInfinite(retryTimeout)) { - logger.Warn($"Max retry timeout provided is 0. Timeout will be infinite"); + s_logger.Warn($"Max retry timeout provided is negative. Timeout will be infinite"); } + } - // Use the shorter timeout between CONNECTION_TIMEOUT and RETRY_TIMEOUT - if (retryTimeout < timeoutInSec) + private void ShortenConnectionTimeoutByRetryTimeout() + { + if (!TimeoutHelper.IsInfinite(retryTimeout) && retryTimeout < connectionTimeout) { - timeoutInSec = retryTimeout; + s_logger.Warn($"Connection timeout greater than retry timeout. Setting connection time same as retry timeout"); + connectionTimeout = retryTimeout; } + } - if (maxHttpRetries > 0 && maxHttpRetries < s_maxHttpRetriesDefault) + private void ValidateHttpRetries() + { + if (maxHttpRetries > 0 && maxHttpRetries < DefaultMaxHttpRetries) { - logger.Warn($"Max retry count provided is less than the allowed minimum value of" + - $" {s_maxHttpRetriesDefault}"); + s_logger.Warn($"Max retry count provided is less than the allowed minimum value of {DefaultMaxHttpRetries}"); - maxHttpRetries = s_maxHttpRetriesDefault; + maxHttpRetries = DefaultMaxHttpRetries; } else if (maxHttpRetries == 0) { - logger.Warn($"Max retry count provided is 0. Retry count will be infinite"); + s_logger.Warn($"Max retry count provided is 0. Retry count will be infinite"); } } - internal TimeSpan TimeoutDuration() + private void ValidateMinMaxPoolSize() { - return timeoutInSec > 0 ? TimeSpan.FromSeconds(timeoutInSec) : Timeout.InfiniteTimeSpan; + if (_minPoolSize > _maxPoolSize) + { + throw new Exception("MinPoolSize cannot be greater than MaxPoolSize"); + } } - internal HttpClientConfig BuildHttpClientConfig() + private void ValidateWaitingForSessionIdleTimeout() + { + if (TimeoutHelper.IsInfinite(_waitingForSessionIdleTimeout)) + { + throw new Exception("Waiting for idle session timeout cannot be infinite"); + } + if (TimeoutHelper.IsZeroLength(_waitingForSessionIdleTimeout)) + { + s_logger.Warn("Waiting for idle session timeout is 0. There will be no waiting for idle session"); + } + } + + public HttpClientConfig BuildHttpClientConfig() { return new HttpClientConfig( !insecureMode, @@ -87,6 +189,18 @@ internal HttpClientConfig BuildHttpClientConfig() includeRetryReason); } + public ConnectionPoolConfig BuildConnectionPoolConfig() => + new ConnectionPoolConfig() + { + MaxPoolSize = _maxPoolSize, + MinPoolSize = _minPoolSize, + ChangedSession = _changedSession, + WaitingForIdleSessionTimeout = _waitingForSessionIdleTimeout, + ExpirationTimeout = _expirationTimeout, + PoolingEnabled = _poolingEnabled, + ConnectionTimeout = connectionTimeout + }; + internal Dictionary ToParameterMap() { var parameterMap = new Dictionary(); @@ -111,20 +225,37 @@ public Extractor(SFSessionHttpClientProxyProperties.IExtractor proxyPropertiesEx public SFSessionHttpClientProperties ExtractProperties(SFSessionProperties propertiesDictionary) { + var extractor = new SessionPropertiesWithDefaultValuesExtractor(propertiesDictionary, true); return new SFSessionHttpClientProperties() { validateDefaultParameters = Boolean.Parse(propertiesDictionary[SFSessionProperty.VALIDATE_DEFAULT_PARAMETERS]), clientSessionKeepAlive = Boolean.Parse(propertiesDictionary[SFSessionProperty.CLIENT_SESSION_KEEP_ALIVE]), - timeoutInSec = int.Parse(propertiesDictionary[SFSessionProperty.CONNECTION_TIMEOUT]), + connectionTimeout = extractor.ExtractTimeout(SFSessionProperty.CONNECTION_TIMEOUT), insecureMode = Boolean.Parse(propertiesDictionary[SFSessionProperty.INSECUREMODE]), disableRetry = Boolean.Parse(propertiesDictionary[SFSessionProperty.DISABLERETRY]), forceRetryOn404 = Boolean.Parse(propertiesDictionary[SFSessionProperty.FORCERETRYON404]), - retryTimeout = int.Parse(propertiesDictionary[SFSessionProperty.RETRY_TIMEOUT]), + retryTimeout = extractor.ExtractTimeout(SFSessionProperty.RETRY_TIMEOUT), maxHttpRetries = int.Parse(propertiesDictionary[SFSessionProperty.MAXHTTPRETRIES]), includeRetryReason = Boolean.Parse(propertiesDictionary[SFSessionProperty.INCLUDERETRYREASON]), - proxyProperties = proxyPropertiesExtractor.ExtractProperties(propertiesDictionary) + proxyProperties = proxyPropertiesExtractor.ExtractProperties(propertiesDictionary), + _maxPoolSize = extractor.ExtractPositiveIntegerWithDefaultValue(SFSessionProperty.MAXPOOLSIZE), + _minPoolSize = extractor.ExtractNonNegativeIntegerWithDefaultValue(SFSessionProperty.MINPOOLSIZE), + _changedSession = ExtractChangedSession(extractor, SFSessionProperty.CHANGEDSESSION), + _waitingForSessionIdleTimeout = extractor.ExtractTimeout(SFSessionProperty.WAITINGFORIDLESESSIONTIMEOUT), + _expirationTimeout = extractor.ExtractTimeout(SFSessionProperty.EXPIRATIONTIMEOUT), + _poolingEnabled = extractor.ExtractBooleanWithDefaultValue(SFSessionProperty.POOLINGENABLED) }; } + + private ChangedSessionBehavior ExtractChangedSession( + SessionPropertiesWithDefaultValuesExtractor extractor, + SFSessionProperty property) => + extractor.ExtractPropertyWithDefaultValue( + property, + i => (ChangedSessionBehavior)Enum.Parse(typeof(ChangedSessionBehavior), i, true), + s => !string.IsNullOrEmpty(s), + b => true + ); } } -} \ No newline at end of file +} diff --git a/Snowflake.Data/Core/Session/SFSessionProperty.cs b/Snowflake.Data/Core/Session/SFSessionProperty.cs index 08f3dcdee..12b650ce6 100644 --- a/Snowflake.Data/Core/Session/SFSessionProperty.cs +++ b/Snowflake.Data/Core/Session/SFSessionProperty.cs @@ -11,7 +11,9 @@ using Snowflake.Data.Core.Authenticator; using System.Data.Common; using System.Linq; +using System.Text; using System.Text.RegularExpressions; +using Snowflake.Data.Core.Tools; namespace Snowflake.Data.Core { @@ -23,7 +25,7 @@ internal enum SFSessionProperty DB, [SFSessionPropertyAttr(required = false)] HOST, - [SFSessionPropertyAttr(required = true)] + [SFSessionPropertyAttr(required = true, IsSecret = true)] PASSWORD, [SFSessionPropertyAttr(required = false, defaultValue = "443")] PORT, @@ -45,11 +47,11 @@ internal enum SFSessionProperty VALIDATE_DEFAULT_PARAMETERS, [SFSessionPropertyAttr(required = false)] PRIVATE_KEY_FILE, - [SFSessionPropertyAttr(required = false)] + [SFSessionPropertyAttr(required = false, IsSecret = true)] PRIVATE_KEY_PWD, - [SFSessionPropertyAttr(required = false)] + [SFSessionPropertyAttr(required = false, IsSecret = true)] PRIVATE_KEY, - [SFSessionPropertyAttr(required = false)] + [SFSessionPropertyAttr(required = false, IsSecret = true)] TOKEN, [SFSessionPropertyAttr(required = false, defaultValue = "false")] INSECUREMODE, @@ -61,7 +63,7 @@ internal enum SFSessionProperty PROXYPORT, [SFSessionPropertyAttr(required = false)] PROXYUSER, - [SFSessionPropertyAttr(required = false)] + [SFSessionPropertyAttr(required = false, IsSecret = true)] PROXYPASSWORD, [SFSessionPropertyAttr(required = false)] NONPROXYHOSTS, @@ -96,7 +98,19 @@ internal enum SFSessionProperty [SFSessionPropertyAttr(required = false, defaultValue = "false")] ALLOWUNDERSCORESINHOST, [SFSessionPropertyAttr(required = false)] - QUERY_TAG + QUERY_TAG, + [SFSessionPropertyAttr(required = false, defaultValue = "10")] + MAXPOOLSIZE, + [SFSessionPropertyAttr(required = false, defaultValue = "2")] + MINPOOLSIZE, + [SFSessionPropertyAttr(required = false, defaultValue = "Destroy")] + CHANGEDSESSION, + [SFSessionPropertyAttr(required = false, defaultValue = "30s")] + WAITINGFORIDLESESSIONTIMEOUT, + [SFSessionPropertyAttr(required = false, defaultValue = "60m")] + EXPIRATIONTIMEOUT, + [SFSessionPropertyAttr(required = false, defaultValue = "true")] + POOLINGENABLED } class SFSessionPropertyAttr : Attribute @@ -104,21 +118,24 @@ class SFSessionPropertyAttr : Attribute public bool required { get; set; } public string defaultValue { get; set; } + + public bool IsSecret { get; set; } = false; } class SFSessionProperties : Dictionary { private static SFLogger logger = SFLoggerFactory.GetLogger(); + internal string ConnectionStringWithoutSecrets { get; set; } + + internal bool IsPoolingEnabledValueProvided { get; set; } + // Connection string properties to obfuscate in the log - private static List secretProps = - new List{ - SFSessionProperty.PASSWORD, - SFSessionProperty.PRIVATE_KEY, - SFSessionProperty.TOKEN, - SFSessionProperty.PRIVATE_KEY_PWD, - SFSessionProperty.PROXYPASSWORD, - }; + private static readonly List s_secretProps = Enum.GetValues(typeof(SFSessionProperty)) + .Cast() + .Where(p => p.GetAttribute().IsSecret) + .Select(p => p.ToString()) + .ToList(); private static readonly List s_accountRegexStrings = new List { @@ -184,6 +201,8 @@ internal static SFSessionProperties ParseConnectionString(string connectionStrin builder.Keys.CopyTo(keys, 0); builder.Values.CopyTo(values,0); + properties.ConnectionStringWithoutSecrets = BuildConnectionStringWithoutSecrets(ref keys, ref values); + for(var i=0; i 0) { - properties[SFSessionProperty.PASSWORD] = new NetworkCredential(string.Empty, password).Password; + properties[SFSessionProperty.PASSWORD] = SecureStringHelper.Decode(password); } - checkSessionProperties(properties); + ValidateAuthenticator(properties); + properties.IsPoolingEnabledValueProvided = properties.IsNonEmptyValueProvided(SFSessionProperty.POOLINGENABLED); + CheckSessionProperties(properties); ValidateFileTransferMaxBytesInMemoryProperty(properties); ValidateAccountDomain(properties); @@ -267,6 +288,53 @@ internal static SFSessionProperties ParseConnectionString(string connectionStrin return properties; } + private static void ValidateAuthenticator(SFSessionProperties properties) + { + var knownAuthenticators = new[] { + BasicAuthenticator.AUTH_NAME, + OktaAuthenticator.AUTH_NAME, + OAuthAuthenticator.AUTH_NAME, + KeyPairAuthenticator.AUTH_NAME, + ExternalBrowserAuthenticator.AUTH_NAME + }; + + if (properties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator)) + { + authenticator = authenticator.ToLower(); + if (!knownAuthenticators.Contains(authenticator) && !(authenticator.Contains(OktaAuthenticator.AUTH_NAME) && authenticator.StartsWith("https://"))) + { + var error = $"Unknown authenticator: {authenticator}"; + logger.Error(error); + throw new SnowflakeDbException(SFError.UNKNOWN_AUTHENTICATOR, authenticator); + } + } + } + + internal bool IsNonEmptyValueProvided(SFSessionProperty property) => + TryGetValue(property, out var propertyValueStr) && !string.IsNullOrEmpty(propertyValueStr); + + private static string BuildConnectionStringWithoutSecrets(ref string[] keys, ref string[] values) + { + var count = keys.Length; + var result = new StringBuilder(); + for (var i = 0; i < count; i++ ) + { + if (!IsSecretProperty(keys[i])) + { + result.Append(keys[i]); + result.Append("="); + result.Append(values[i]); + result.Append(";"); + } + } + return result.ToString(); + } + + private static bool IsSecretProperty(string propertyName) + { + return s_secretProps.Contains(propertyName, StringComparer.OrdinalIgnoreCase); + } + private static void UpdatePropertiesForSpecialCases(SFSessionProperties properties, string connectionString) { var propertyEntry = connectionString.Split(';'); @@ -342,7 +410,7 @@ private static bool IsAccountRegexMatched(string account) => .Select(regex => Regex.Match(account, regex, RegexOptions.IgnoreCase)) .All(match => match.Success); - private static void checkSessionProperties(SFSessionProperties properties) + private static void CheckSessionProperties(SFSessionProperties properties) { foreach (SFSessionProperty sessionProperty in Enum.GetValues(typeof(SFSessionProperty))) { @@ -350,17 +418,23 @@ private static void checkSessionProperties(SFSessionProperties properties) if (IsRequired(sessionProperty, properties) && !properties.ContainsKey(sessionProperty)) { - SnowflakeDbException e = new SnowflakeDbException(SFError.MISSING_CONNECTION_PROPERTY, - sessionProperty); + SnowflakeDbException e = new SnowflakeDbException(SFError.MISSING_CONNECTION_PROPERTY, sessionProperty); logger.Error("Missing connection property", e); throw e; } + if (IsRequired(sessionProperty, properties) && string.IsNullOrEmpty(properties[sessionProperty])) + { + SnowflakeDbException e = new SnowflakeDbException(SFError.MISSING_CONNECTION_PROPERTY, sessionProperty); + logger.Error("Empty connection property", e); + throw e; + } + // add default value to the map string defaultVal = sessionProperty.GetAttribute().defaultValue; if (defaultVal != null && !properties.ContainsKey(sessionProperty)) { - logger.Debug($"Sesssion property {sessionProperty} set to default value: {defaultVal}"); + logger.Debug($"Session property {sessionProperty} set to default value: {defaultVal}"); properties.Add(sessionProperty, defaultVal); } } @@ -424,6 +498,12 @@ private static bool IsRequired(SFSessionProperty sessionProperty, SFSessionPrope return !authenticatorDefined || !authenticatorsWithoutUsername .Any(auth => auth.Equals(authenticator, StringComparison.OrdinalIgnoreCase)); } + else if (sessionProperty.Equals(SFSessionProperty.TOKEN)) + { + var authenticatorDefined = properties.TryGetValue(SFSessionProperty.AUTHENTICATOR, out var authenticator); + + return !authenticatorDefined || authenticator.Equals(OAuthAuthenticator.AUTH_NAME); + } else { return sessionProperty.GetAttribute().required; diff --git a/Snowflake.Data/Core/Session/SessionCreationToken.cs b/Snowflake.Data/Core/Session/SessionCreationToken.cs new file mode 100644 index 000000000..8d26a3261 --- /dev/null +++ b/Snowflake.Data/Core/Session/SessionCreationToken.cs @@ -0,0 +1,22 @@ +using System; +using Snowflake.Data.Core.Tools; + +namespace Snowflake.Data.Core.Session +{ + internal class SessionCreationToken + { + public Guid Id { get; } + private readonly long _grantedAtAsEpochMillis; + private readonly TimeSpan _timeout; + + public SessionCreationToken(TimeSpan timeout) + { + Id = Guid.NewGuid(); + _grantedAtAsEpochMillis = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + _timeout = timeout; + } + + public bool IsExpired(long nowMillis) => + TimeoutHelper.IsExpired(_grantedAtAsEpochMillis, nowMillis, _timeout); + } +} diff --git a/Snowflake.Data/Core/Session/SessionCreationTokenCounter.cs b/Snowflake.Data/Core/Session/SessionCreationTokenCounter.cs new file mode 100644 index 000000000..32ba7e55b --- /dev/null +++ b/Snowflake.Data/Core/Session/SessionCreationTokenCounter.cs @@ -0,0 +1,75 @@ +using System; +using System.Collections.Generic; +using System.Threading; + +namespace Snowflake.Data.Core.Session +{ + internal class SessionCreationTokenCounter: ISessionCreationTokenCounter + { + private readonly TimeSpan _timeout; + private readonly ReaderWriterLockSlim _tokenLock = new ReaderWriterLockSlim(); + private readonly List _tokens = new List(); + + public SessionCreationTokenCounter(TimeSpan timeout) + { + _timeout = timeout; + } + + public SessionCreationToken NewToken() + { + _tokenLock.EnterWriteLock(); + try + { + var token = new SessionCreationToken(_timeout); + _tokens.Add(token); + var now = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + _tokens.RemoveAll(t => t.IsExpired(now)); + return token; + } + finally + { + _tokenLock.ExitWriteLock(); + } + } + + public void RemoveToken(SessionCreationToken creationToken) + { + _tokenLock.EnterWriteLock(); + try + { + var now = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + _tokens.RemoveAll(t => creationToken.Id == t.Id || t.IsExpired(now)); + } + finally + { + _tokenLock.ExitWriteLock(); + } + } + + public int Count() + { + _tokenLock.EnterReadLock(); + try + { + return _tokens.Count; + } + finally + { + _tokenLock.ExitReadLock(); + } + } + + public void Reset() + { + _tokenLock.EnterWriteLock(); + try + { + _tokens.Clear(); + } + finally + { + _tokenLock.ExitWriteLock(); + } + } + } +} diff --git a/Snowflake.Data/Core/Session/SessionFactory.cs b/Snowflake.Data/Core/Session/SessionFactory.cs new file mode 100644 index 000000000..2eb0ba6df --- /dev/null +++ b/Snowflake.Data/Core/Session/SessionFactory.cs @@ -0,0 +1,12 @@ +using System.Security; + +namespace Snowflake.Data.Core.Session +{ + internal class SessionFactory : ISessionFactory + { + public SFSession NewSession(string connectionString, SecureString password) + { + return new SFSession(connectionString, password); + } + } +} diff --git a/Snowflake.Data/Core/Session/SessionOrCreationTokens.cs b/Snowflake.Data/Core/Session/SessionOrCreationTokens.cs new file mode 100644 index 000000000..2185c3f51 --- /dev/null +++ b/Snowflake.Data/Core/Session/SessionOrCreationTokens.cs @@ -0,0 +1,41 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Snowflake.Data.Core.Session +{ + internal class SessionOrCreationTokens + { + internal static readonly List s_emptySessionCreationTokenList = new List(); // used as a memory optimization not to create many short living empty list + + public SFSession Session { get; } + public List SessionCreationTokens { get; } + + public SessionOrCreationTokens(SFSession session) + { + Session = session ?? throw new Exception("Internal error: missing session"); + SessionCreationTokens = s_emptySessionCreationTokenList; + } + + public SessionOrCreationTokens(List sessionCreationTokens) + { + Session = null; + if (sessionCreationTokens == null || sessionCreationTokens.Count == 0) + { + throw new Exception("Internal error: missing session creation token"); + } + SessionCreationTokens = sessionCreationTokens; + } + + public List BackgroundSessionCreationTokens() + { + if (Session == null) + { + return SessionCreationTokens.Skip(1).ToList(); + } + return SessionCreationTokens; + } + + public SessionCreationToken SessionCreationToken() => SessionCreationTokens.First(); + } +} diff --git a/Snowflake.Data/Core/Session/SessionPool.cs b/Snowflake.Data/Core/Session/SessionPool.cs index 14ad70848..de66c2240 100644 --- a/Snowflake.Data/Core/Session/SessionPool.cs +++ b/Snowflake.Data/Core/Session/SessionPool.cs @@ -9,6 +9,7 @@ using System.Threading; using System.Threading.Tasks; using Snowflake.Data.Client; +using Snowflake.Data.Core.Tools; using Snowflake.Data.Log; namespace Snowflake.Data.Core.Session @@ -16,112 +17,350 @@ namespace Snowflake.Data.Core.Session sealed class SessionPool : IDisposable { private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); - private static readonly object s_sessionPoolLock = new object(); - private readonly List _sessions; - private int _maxPoolSize; - private long _timeout; - private const int MaxPoolSize = 10; - private const long Timeout = 3600; - private bool _pooling = true; + private readonly object _sessionPoolLock = new object(); + private static ISessionFactory s_sessionFactory = new SessionFactory(); - internal SessionPool() + private readonly Guid _id = Guid.NewGuid(); + private readonly List _idleSessions; + private readonly IWaitingQueue _waitingForIdleSessionQueue; + private readonly ISessionCreationTokenCounter _sessionCreationTokenCounter; + private readonly ISessionCreationTokenCounter _noPoolingSessionCreationTokenCounter = new NonCountingSessionCreationTokenCounter(); + internal string ConnectionString { get; } + internal SecureString Password { get; } + private readonly string _connectionStringWithoutSecrets; + private readonly ICounter _busySessionsCounter; + private ISessionPoolEventHandler _sessionPoolEventHandler = new SessionPoolEventHandler(); // a way to inject some additional behaviour after certain events. Can be used for example to measure time of given steps. + private readonly ConnectionPoolConfig _poolConfig; + private bool _configOverriden = false; + private bool _underDestruction = false; + + private static readonly InvalidOperationException s_notSupportedInCachePoolException = new InvalidOperationException("Feature not supported in a Connection Cache"); + + private SessionPool() { - lock (s_sessionPoolLock) - { - _sessions = new List(); - _maxPoolSize = MaxPoolSize; - _timeout = Timeout; - } + // acquiring a lock not needed because one is already acquired in SnowflakeDbConnectionPool + _idleSessions = new List(); + _busySessionsCounter = new FixedZeroCounter(); + _waitingForIdleSessionQueue = new NonWaitingQueue(); + _sessionCreationTokenCounter = new NonCountingSessionCreationTokenCounter(); + _poolConfig = new ConnectionPoolConfig(); + } + + private SessionPool(string connectionString, SecureString password, ConnectionPoolConfig poolConfig, string connectionStringWithoutSecrets) + { + // acquiring a lock not needed because one is already acquired in ConnectionPoolManager + _idleSessions = new List(); + _busySessionsCounter = new NonNegativeCounter(); + ConnectionString = connectionString; + Password = password; + _connectionStringWithoutSecrets = connectionStringWithoutSecrets; + _waitingForIdleSessionQueue = new WaitingQueue(); + _poolConfig = poolConfig; + _sessionCreationTokenCounter = new SessionCreationTokenCounter(_poolConfig.ConnectionTimeout); + } + + internal static SessionPool CreateSessionCache() => new SessionPool(); + + internal static SessionPool CreateSessionPool(string connectionString, SecureString password) + { + s_logger.Debug("Creating a connection pool"); + var extracted = ExtractConfig(connectionString, password); + s_logger.Debug("Creating a connection pool identified by: " + extracted.Item2); + return new SessionPool(connectionString, password, extracted.Item1, extracted.Item2); } - + ~SessionPool() { // Use async for the finalizer due to possible deadlock // when waiting for the CloseResponse task while closing the session - ClearAllPoolsAsync(); + DestroyPoolAsync(); } public void Dispose() { - ClearAllPools(); + DestroyPool(); + } + + internal static ISessionFactory SessionFactory + { + set => s_sessionFactory = value; } private void CleanExpiredSessions() { - s_logger.Debug("SessionPool::CleanExpiredSessions"); - lock (s_sessionPoolLock) + s_logger.Debug("SessionPool::CleanExpiredSessions" + PoolIdentification()); + lock (_sessionPoolLock) { - long timeNow = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); + var timeNow = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); - foreach (var item in _sessions.ToList()) + foreach (var item in _idleSessions.ToList()) { - if (item.IsExpired(_timeout, timeNow)) + if (item.IsExpired(_poolConfig.ExpirationTimeout, timeNow)) { Task.Run(() => item.close()); - _sessions.Remove(item); + _idleSessions.Remove(item); } } } } + internal static Tuple ExtractConfig(string connectionString, SecureString password) + { + try + { + var properties = SFSessionProperties.ParseConnectionString(connectionString, password); + var extractedProperties = SFSessionHttpClientProperties.ExtractAndValidate(properties); + extractedProperties.DisablePoolingDefaultIfSecretsProvidedExternally(properties); + return Tuple.Create(extractedProperties.BuildConnectionPoolConfig(), properties.ConnectionStringWithoutSecrets); + } + catch (Exception exception) + { + s_logger.Error("Failed to extract pool configuration", exception); + throw; + } + } + + internal void ValidateSecurePassword(SecureString password) + { + if (!ExtractPassword(Password).Equals(ExtractPassword(password))) + { + var errorMessage = "Could not get a pool because of password mismatch"; + s_logger.Error(errorMessage + PoolIdentification()); + throw new Exception(errorMessage); + } + } + + private string ExtractPassword(SecureString password) => + password == null ? string.Empty : SecureStringHelper.Decode(password); + internal SFSession GetSession(string connStr, SecureString password) { - s_logger.Debug("SessionPool::GetSession"); - if (!_pooling) - return NewSession(connStr, password); - SFSession session = GetIdleSession(connStr); - return session ?? NewSession(connStr, password); + s_logger.Debug("SessionPool::GetSession" + PoolIdentification()); + if (!GetPooling()) + return NewNonPoolingSession(connStr, password); + var sessionOrCreateTokens = GetIdleSession(connStr); + if (sessionOrCreateTokens.Session != null) + { + _sessionPoolEventHandler.OnSessionProvided(this); + } + ScheduleNewIdleSessions(connStr, password, sessionOrCreateTokens.BackgroundSessionCreationTokens()); + WarnAboutOverridenConfig(); + return sessionOrCreateTokens.Session ?? NewSession(connStr, password, sessionOrCreateTokens.SessionCreationToken()); + } + + internal async Task GetSessionAsync(string connStr, SecureString password, CancellationToken cancellationToken) + { + s_logger.Debug("SessionPool::GetSessionAsync" + PoolIdentification()); + if (!GetPooling()) + return await NewNonPoolingSessionAsync(connStr, password, cancellationToken).ConfigureAwait(false); + var sessionOrCreateTokens = GetIdleSession(connStr); + if (sessionOrCreateTokens.Session != null) + { + _sessionPoolEventHandler.OnSessionProvided(this); + } + ScheduleNewIdleSessions(connStr, password, sessionOrCreateTokens.BackgroundSessionCreationTokens()); + WarnAboutOverridenConfig(); + return sessionOrCreateTokens.Session ?? await NewSessionAsync(connStr, password, sessionOrCreateTokens.SessionCreationToken(), cancellationToken).ConfigureAwait(false); } - - internal Task GetSessionAsync(string connStr, SecureString password, CancellationToken cancellationToken) + + private void ScheduleNewIdleSessions(string connStr, SecureString password, List tokens) { - s_logger.Debug("SessionPool::GetSessionAsync"); - if (!_pooling) - return NewSessionAsync(connStr, password, cancellationToken); - SFSession session = GetIdleSession(connStr); - return session != null ? Task.FromResult(session) : NewSessionAsync(connStr, password, cancellationToken); + tokens.ForEach(token => ScheduleNewIdleSession(connStr, password, token)); } - private SFSession GetIdleSession(string connStr) + private void ScheduleNewIdleSession(string connStr, SecureString password, SessionCreationToken token) { - s_logger.Debug("SessionPool::GetIdleSession"); - lock (s_sessionPoolLock) + Task.Run(() => { - for (int i = 0; i < _sessions.Count; i++) + var session = NewSession(connStr, password, token); + AddSession(session, false); // we don't want to ensure min pool size here because we could get into infinite recursion if expirationTimeout would be very low + }); + } + + private void WarnAboutOverridenConfig() + { + if (IsConfigOverridden() && GetPooling() && IsMultiplePoolsVersion()) + { + s_logger.Warn("Providing a connection from a pool for which technical configuration has been overriden by the user"); + } + } + + internal bool IsConfigOverridden() => _configOverriden; + + internal SFSession GetSession() => GetSession(ConnectionString, Password); + + internal Task GetSessionAsync(CancellationToken cancellationToken) => + GetSessionAsync(ConnectionString, Password, cancellationToken); + + internal void SetSessionPoolEventHandler(ISessionPoolEventHandler sessionPoolEventHandler) + { + _sessionPoolEventHandler = sessionPoolEventHandler; + } + + private SessionOrCreationTokens GetIdleSession(string connStr) + { + s_logger.Debug("SessionPool::GetIdleSession" + PoolIdentification()); + lock (_sessionPoolLock) + { + if (_waitingForIdleSessionQueue.IsAnyoneWaiting()) { - if (_sessions[i].connStr.Equals(connStr)) + s_logger.Debug("SessionPool::GetIdleSession - someone is already waiting for a session, request is going to be queued" + PoolIdentification()); + } + else + { + var session = ExtractIdleSession(connStr); + if (session != null) { - SFSession session = _sessions[i]; - _sessions.RemoveAt(i); - long timeNow = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); - if (session.IsExpired(_timeout, timeNow)) - { - Task.Run(() => session.close()); - i--; - } - else + s_logger.Debug("SessionPool::GetIdleSession - no thread was waiting for a session, an idle session was retrieved from the pool" + PoolIdentification()); + return new SessionOrCreationTokens(session); + } + s_logger.Debug("SessionPool::GetIdleSession - no thread was waiting for a session, but could not find any idle session available in the pool" + PoolIdentification()); + var sessionsCount = AllowedNumberOfNewSessionCreations(1); + if (sessionsCount > 0) + { + // there is no need to wait for a session since we can create new ones + return new SessionOrCreationTokens(RegisterSessionCreations(sessionsCount)); + } + } + } + return new SessionOrCreationTokens(WaitForSession(connStr)); + } + + private List RegisterSessionCreationsWhenReturningSessionToPool() + { + var count = AllowedNumberOfNewSessionCreations(0); + return RegisterSessionCreations(count); + } + + private List RegisterSessionCreations(int sessionsCount) => + Enumerable.Range(1, sessionsCount) + .Select(_ => _sessionCreationTokenCounter.NewToken()) + .ToList(); + + private int AllowedNumberOfNewSessionCreations(int atLeastCount) + { + // we are expecting to create atLeast 1 session in case of opening a connection (atLeastCount = 1) + // but we have no expectations when closing a connection (atLeastCount = 0) + if (!IsMultiplePoolsVersion()) + { + if (atLeastCount > 0) + s_logger.Debug($"SessionPool - creating of new sessions is not limited"); + return atLeastCount; // we are either in old pool or there is no pooling + } + var currentSize = GetCurrentPoolSize(); + if (currentSize < _poolConfig.MaxPoolSize) + { + var maxSessionsToCreate = _poolConfig.MaxPoolSize - currentSize; + var sessionsNeeded = Math.Max(_poolConfig.MinPoolSize - currentSize, atLeastCount); + var sessionsToCreate = Math.Min(sessionsNeeded, maxSessionsToCreate); + s_logger.Debug($"SessionPool - allowed to create {sessionsToCreate} sessions, current pool size is {currentSize} out of {_poolConfig.MaxPoolSize}" + PoolIdentification()); + return sessionsToCreate; + } + s_logger.Debug($"SessionPool - not allowed to create a session, current pool size is {currentSize} out of {_poolConfig.MaxPoolSize}" + PoolIdentification()); + return 0; + } + + private bool IsMultiplePoolsVersion() => _waitingForIdleSessionQueue.IsWaitingEnabled(); + + private SFSession WaitForSession(string connStr) + { + if (TimeoutHelper.IsInfinite(_poolConfig.WaitingForIdleSessionTimeout)) + throw new Exception("WaitingForIdleSessionTimeout cannot be infinite"); + s_logger.Info($"SessionPool::WaitForSession for {(long) _poolConfig.WaitingForIdleSessionTimeout.TotalMilliseconds} ms timeout" + PoolIdentification()); + _sessionPoolEventHandler.OnWaitingForSessionStarted(this); + var beforeWaitingTimeMillis = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + long nowTimeMillis = beforeWaitingTimeMillis; + while (GetPooling() && !_underDestruction && !TimeoutHelper.IsExpired(beforeWaitingTimeMillis, nowTimeMillis, _poolConfig.WaitingForIdleSessionTimeout)) // we loop to handle the case if someone overtook us after being woken or session which we were promised has just expired + { + var timeoutLeftMillis = TimeoutHelper.FiniteTimeoutLeftMillis(beforeWaitingTimeMillis, nowTimeMillis, _poolConfig.WaitingForIdleSessionTimeout); + _sessionPoolEventHandler.OnWaitingForSessionStarted(this, timeoutLeftMillis); + var successful = _waitingForIdleSessionQueue.Wait((int) timeoutLeftMillis, CancellationToken.None); + if (successful) + { + s_logger.Debug($"SessionPool::WaitForSession - woken with a session granted" + PoolIdentification()); + _sessionPoolEventHandler.OnWaitingForSessionSuccessful(this); + lock (_sessionPoolLock) + { + var session = ExtractIdleSession(connStr); + if (session != null) { - s_logger.Debug($"reuse pooled session with sid {session.sessionId}"); + s_logger.Debug("SessionPool::WaitForSession - provided an idle session" + PoolIdentification()); return session; } } } + else + { + s_logger.Debug("SessionPool::WaitForSession - woken without a session granted" + PoolIdentification()); + } + nowTimeMillis = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + } + s_logger.Info("SessionPool::WaitForSession - could not find any idle session available withing a given timeout" + PoolIdentification()); + throw WaitingFailedException(); + } + + private static Exception WaitingFailedException() => new Exception("Could not obtain a connection from the pool within a given timeout"); + + private SFSession ExtractIdleSession(string connStr) + { + for (int i = 0; i < _idleSessions.Count; i++) + { + if (_idleSessions[i].ConnectionString.Equals(connStr)) + { + SFSession session = _idleSessions[i]; + _idleSessions.RemoveAt(i); + var timeNow = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + if (session.IsExpired(_poolConfig.ExpirationTimeout, timeNow)) + { + Task.Run(() => session.close()); + i--; + } + else + { + s_logger.Debug($"reuse pooled session with sid {session.sessionId}" + PoolIdentification()); + _busySessionsCounter.Increase(); + return session; + } + } } return null; } - private SFSession NewSession(String connectionString, SecureString password) + private SFSession NewNonPoolingSession(String connectionString, SecureString password) => + NewSession(connectionString, password, _noPoolingSessionCreationTokenCounter.NewToken()); + + private SFSession NewSession(String connectionString, SecureString password, SessionCreationToken sessionCreationToken) { - s_logger.Debug("SessionPool::NewSession"); + s_logger.Debug("SessionPool::NewSession" + PoolIdentification()); try { - var session = new SFSession(connectionString, password); + var session = s_sessionFactory.NewSession(connectionString, password); session.Open(); + s_logger.Debug("SessionPool::NewSession - opened" + PoolIdentification()); + if (GetPooling() && !_underDestruction) + { + lock (_sessionPoolLock) + { + _sessionCreationTokenCounter.RemoveToken(sessionCreationToken); + _busySessionsCounter.Increase(); + s_logger.Debug($"Pool state after creating a session {GetCurrentState()}" + PoolIdentification()); + } + } + _sessionPoolEventHandler.OnNewSessionCreated(this); + _sessionPoolEventHandler.OnSessionProvided(this); return session; } catch (Exception e) { // Otherwise when Dispose() is called, the close request would timeout. + _sessionCreationTokenCounter.RemoveToken(sessionCreationToken); + if (GetPooling()) + { + lock (_sessionPoolLock) + { + s_logger.Debug($"Failed to create a new session {GetCurrentState()}" + PoolIdentification()); + } + } if (e is SnowflakeDbException) throw; throw new SnowflakeDbException( @@ -132,14 +371,32 @@ private SFSession NewSession(String connectionString, SecureString password) } } - private Task NewSessionAsync(String connectionString, SecureString password, CancellationToken cancellationToken) + private Task NewNonPoolingSessionAsync( + String connectionString, + SecureString password, + CancellationToken cancellationToken) => + NewSessionAsync(connectionString, password, _noPoolingSessionCreationTokenCounter.NewToken(), cancellationToken); + + private Task NewSessionAsync(String connectionString, SecureString password, SessionCreationToken sessionCreationToken, CancellationToken cancellationToken) { - s_logger.Debug("SessionPool::NewSessionAsync"); - var session = new SFSession(connectionString, password); + s_logger.Debug("SessionPool::NewSessionAsync" + PoolIdentification()); + var session = s_sessionFactory.NewSession(connectionString, password); return session .OpenAsync(cancellationToken) .ContinueWith(previousTask => { + if (previousTask.IsFaulted || previousTask.IsCanceled) + { + _sessionCreationTokenCounter.RemoveToken(sessionCreationToken); + if (GetPooling()) + { + lock (_sessionPoolLock) + { + s_logger.Debug($"Failed to create a new session {GetCurrentState()}" + PoolIdentification()); + } + } + } + if (previousTask.IsFaulted && previousTask.Exception != null) throw previousTask.Exception; @@ -149,101 +406,287 @@ private Task NewSessionAsync(String connectionString, SecureString pa SFError.INTERNAL_ERROR, "Failure while opening session async"); + if (!previousTask.IsCanceled) + { + if (GetPooling() && !_underDestruction) + { + lock (_sessionPoolLock) + { + _sessionCreationTokenCounter.RemoveToken(sessionCreationToken); + _busySessionsCounter.Increase(); + s_logger.Debug($"Pool state after creating a session {GetCurrentState()}" + PoolIdentification()); + } + } + + _sessionPoolEventHandler.OnNewSessionCreated(this); + _sessionPoolEventHandler.OnSessionProvided(this); + } return session; }, TaskContinuationOptions.NotOnCanceled); } - internal bool AddSession(SFSession session) + internal void ReleaseBusySession(SFSession session) { - s_logger.Debug("SessionPool::AddSession"); - if (!_pooling) + s_logger.Debug("SessionPool::ReleaseBusySession" + PoolIdentification()); + SessionPoolState poolState; + lock (_sessionPoolLock) + { + _busySessionsCounter.Decrease(); + poolState = GetCurrentState(); + } + s_logger.Debug($"After releasing a busy session from the pool {poolState}" + PoolIdentification()); + } + + internal bool AddSession(SFSession session, bool ensureMinPoolSize) + { + s_logger.Debug("SessionPool::AddSession" + PoolIdentification()); + + if (!GetPooling() || _underDestruction) return false; - long timeNow = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); - if (session.IsNotOpen() || session.IsExpired(_timeout, timeNow)) + + if (IsMultiplePoolsVersion() && + session.SessionPropertiesChanged && + _poolConfig.ChangedSession == ChangedSessionBehavior.Destroy) + { + s_logger.Debug($"Session returning to pool was changed. Destroying the session: {session.sessionId}."); + session.SetPooling(false); + } + + if (!session.GetPooling()) + { + ReleaseBusySession(session); + if (ensureMinPoolSize) + { + ScheduleNewIdleSessions(ConnectionString, Password, RegisterSessionCreationsWhenReturningSessionToPool()); + } return false; + } + + var result = ReturnSessionToPool(session, ensureMinPoolSize); + var wasSessionReturnedToPool = result.Item1; + var sessionCreationTokens = result.Item2; + ScheduleNewIdleSessions(ConnectionString, Password, sessionCreationTokens); + return wasSessionReturnedToPool; + } + + private Tuple> ReturnSessionToPool(SFSession session, bool ensureMinPoolSize) + { + long timeNow = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds(); + if (session.IsNotOpen() || session.IsExpired(_poolConfig.ExpirationTimeout, timeNow)) + { + lock (_sessionPoolLock) + { + _busySessionsCounter.Decrease(); + var sessionCreationTokens = ensureMinPoolSize + ? RegisterSessionCreationsWhenReturningSessionToPool() + : SessionOrCreationTokens.s_emptySessionCreationTokenList; + var poolState = GetCurrentState(); + s_logger.Debug($"Could not return session to pool {poolState}" + PoolIdentification()); + return Tuple.Create(false, sessionCreationTokens); + } + } - lock (s_sessionPoolLock) + lock (_sessionPoolLock) { - if (_sessions.Count >= _maxPoolSize) + _busySessionsCounter.Decrease(); + CleanExpiredSessions(); + if (session.IsExpired(_poolConfig.ExpirationTimeout, DateTimeOffset.UtcNow.ToUnixTimeMilliseconds())) // checking again because we could have spent some time waiting for a lock { - CleanExpiredSessions(); + var sessionCreationTokens = ensureMinPoolSize + ? RegisterSessionCreationsWhenReturningSessionToPool() + : SessionOrCreationTokens.s_emptySessionCreationTokenList; + var poolState = GetCurrentState(); + s_logger.Debug($"Could not return session to pool {poolState}" + PoolIdentification()); + return Tuple.Create(false, sessionCreationTokens); } - if (_sessions.Count >= _maxPoolSize) + var poolStateBeforeReturningToPool = GetCurrentState(); + if (poolStateBeforeReturningToPool.Count() >= _poolConfig.MaxPoolSize) { - // pool is full - return false; + s_logger.Warn($"Pool is full - unable to add session with sid {session.sessionId} {poolStateBeforeReturningToPool}"); + return Tuple.Create(false, SessionOrCreationTokens.s_emptySessionCreationTokenList); } + _idleSessions.Add(session); + _waitingForIdleSessionQueue.OnResourceIncrease(); + var sessionCreationTokensAfterReturningToPool = ensureMinPoolSize + ? RegisterSessionCreationsWhenReturningSessionToPool() + : SessionOrCreationTokens.s_emptySessionCreationTokenList; + var poolStateAfterReturningToPool = GetCurrentState(); + s_logger.Debug($"returned session with sid {session.sessionId} to pool {poolStateAfterReturningToPool}" + PoolIdentification()); + return Tuple.Create(true, sessionCreationTokensAfterReturningToPool); + } + } + + internal void DestroyPool() + { + s_logger.Debug("SessionPool::DestroyPool" + PoolIdentification()); + lock (_sessionPoolLock) + { + _underDestruction = true; + ClearIdleSessions(); + _busySessionsCounter.Reset(); + _waitingForIdleSessionQueue.Reset(); + _sessionCreationTokenCounter.Reset(); + } + } - s_logger.Debug($"pool connection with sid {session.sessionId}"); - _sessions.Add(session); - return true; + internal void DestroyPoolAsync() + { + s_logger.Debug("SessionPool::DestroyPoolAsync" + PoolIdentification()); + lock (_sessionPoolLock) + { + _underDestruction = true; + ClearIdleSessionsAsync(); + _busySessionsCounter.Reset(); + _waitingForIdleSessionQueue.Reset(); + _sessionCreationTokenCounter.Reset(); } } - internal void ClearAllPools() + internal void ClearSessions() { - s_logger.Debug("SessionPool::ClearAllPools"); - lock (s_sessionPoolLock) + s_logger.Debug($"SessionPool::ClearSessions" + PoolIdentification()); + lock (_sessionPoolLock) { - foreach (SFSession session in _sessions) + _busySessionsCounter.Reset(); + ClearIdleSessions(); + _waitingForIdleSessionQueue.Reset(); + } + } + + internal void ClearIdleSessions() + { + s_logger.Debug("SessionPool::ClearIdleSessions" + PoolIdentification()); + lock (_sessionPoolLock) + { + foreach (SFSession session in _idleSessions) { session.close(); // it is left synchronously here because too much async tasks slows down testing } - _sessions.Clear(); + _idleSessions.Clear(); } } - internal async void ClearAllPoolsAsync() + internal async void ClearIdleSessionsAsync() { - s_logger.Debug("SessionPool::ClearAllPoolsAsync"); - foreach (SFSession session in _sessions) + s_logger.Debug("SessionPool::ClearIdleSessionsAsync" + PoolIdentification()); + IEnumerable idleSessionsCopy; + lock (_sessionPoolLock) + { + idleSessionsCopy = _idleSessions.Select(session => session); + _idleSessions.Clear(); + } + foreach (SFSession session in idleSessionsCopy) { await session.CloseAsync(CancellationToken.None).ConfigureAwait(false); } - _sessions.Clear(); } public void SetMaxPoolSize(int size) { - _maxPoolSize = size; + s_logger.Debug($"SessionPool::SetMaxPoolSize({size})" + PoolIdentification()); + _poolConfig.MaxPoolSize = size; + _configOverriden = true; } - public int GetMaxPoolSize() + public int GetMaxPoolSize() => _poolConfig.MaxPoolSize; + + public int GetMinPoolSize() { - return _maxPoolSize; + return IsMultiplePoolsVersion() + ? _poolConfig.MinPoolSize + : throw s_notSupportedInCachePoolException; } - public void SetTimeout(long time) + public ChangedSessionBehavior GetChangedSession() => + IsMultiplePoolsVersion() + ? _poolConfig.ChangedSession + : throw s_notSupportedInCachePoolException; + + public long GetWaitForIdleSessionTimeout() => + IsMultiplePoolsVersion() + ? (long)_poolConfig.WaitingForIdleSessionTimeout.TotalSeconds + : throw s_notSupportedInCachePoolException; + + public long GetConnectionTimeout() + { + return TimeoutHelper.IsInfinite(_poolConfig.ConnectionTimeout) ? -1 : (long)_poolConfig.ConnectionTimeout.TotalSeconds; + } + + public void SetTimeout(long seconds) { - _timeout = time; + s_logger.Debug($"SessionPool::SetTimeout({seconds})" + PoolIdentification()); + var timeout = seconds < 0 ? TimeoutHelper.Infinity() : TimeSpan.FromSeconds(seconds); + _poolConfig.ExpirationTimeout = timeout; + _configOverriden = true; } public long GetTimeout() { - return _timeout; + return TimeoutHelper.IsInfinite(_poolConfig.ExpirationTimeout) ? -1 : (long)_poolConfig.ExpirationTimeout.TotalSeconds; } public int GetCurrentPoolSize() { - return _sessions.Count; + return _idleSessions.Count + _busySessionsCounter.Count() + _sessionCreationTokenCounter.Count(); + } + + public SessionPoolState GetCurrentState() + { + return new SessionPoolState( + _idleSessions.Count, + _busySessionsCounter.Count(), + _sessionCreationTokenCounter.Count(), + _waitingForIdleSessionQueue.WaitingCount(), + IsMultiplePoolsVersion() + ); } public bool SetPooling(bool isEnable) { - s_logger.Info($"SessionPool::SetPooling({isEnable})"); - if (_pooling == isEnable) + s_logger.Info($"SessionPool::SetPooling({isEnable})" + PoolIdentification()); + if (_poolConfig.PoolingEnabled == isEnable) return false; - _pooling = isEnable; - if (!_pooling) + _poolConfig.PoolingEnabled = isEnable; + if (!_poolConfig.PoolingEnabled) { - ClearAllPools(); + ClearSessions(); } + _configOverriden = true; return true; } - public bool GetPooling() + public bool GetPooling() => _poolConfig.PoolingEnabled; + + internal int OngoingSessionCreationsCount() { - return _pooling; + lock (_sessionPoolLock) + { + return _sessionCreationTokenCounter.Count(); + } } + + internal List GetIdleSessionsStartTimes() + { + lock (_sessionPoolLock) + { + return _idleSessions.Select(s => s.GetStartTime()).ToList(); + } + } + + internal string PoolIdentification() + { + if (!IsMultiplePoolsVersion()) + return ""; + return +#if SF_PUBLIC_ENVIRONMENT + PoolIdentificationBasedOnInternalId; +#else + PoolIdentificationBasedOnConnectionString; +#endif + } + + internal string PoolIdentificationBasedOnConnectionString => " [pool: " + _connectionStringWithoutSecrets + "]"; + + internal string PoolIdentificationBasedOnInternalId => " [pool: " + _id + "]"; } } diff --git a/Snowflake.Data/Core/Session/SessionPoolEventHandler.cs b/Snowflake.Data/Core/Session/SessionPoolEventHandler.cs new file mode 100644 index 000000000..b1a067d6b --- /dev/null +++ b/Snowflake.Data/Core/Session/SessionPoolEventHandler.cs @@ -0,0 +1,25 @@ +namespace Snowflake.Data.Core.Session +{ + internal class SessionPoolEventHandler: ISessionPoolEventHandler + { + public virtual void OnNewSessionCreated(SessionPool sessionPool) + { + } + + public virtual void OnWaitingForSessionStarted(SessionPool sessionPool) + { + } + + public virtual void OnWaitingForSessionStarted(SessionPool sessionPool, long millisLeft) + { + } + + public virtual void OnWaitingForSessionSuccessful(SessionPool sessionPool) + { + } + + public virtual void OnSessionProvided(SessionPool sessionPool) + { + } + } +} diff --git a/Snowflake.Data/Core/Session/SessionPoolState.cs b/Snowflake.Data/Core/Session/SessionPoolState.cs new file mode 100644 index 000000000..e548c186e --- /dev/null +++ b/Snowflake.Data/Core/Session/SessionPoolState.cs @@ -0,0 +1,29 @@ +namespace Snowflake.Data.Core.Session +{ + internal class SessionPoolState + { + private readonly int _idleSessionsCount; + private readonly int _busySessionsCount; + private readonly int _sessionCreationsCount; + private readonly int _waitingCount; + private readonly bool _extensiveFormat; + + public SessionPoolState(int idleSessionsCount, int busySessionsCount, int sessionCreationsCount, int waitingCount, bool extensiveFormat) + { + _idleSessionsCount = idleSessionsCount; + _busySessionsCount = busySessionsCount; + _sessionCreationsCount = sessionCreationsCount; + _waitingCount = waitingCount; + _extensiveFormat = extensiveFormat; + } + + public int Count() => _idleSessionsCount + _busySessionsCount + _sessionCreationsCount; + + public override string ToString() + { + return _extensiveFormat + ? $"[pool size: {Count()} (idle sessions: {_idleSessionsCount}, busy sessions: {_busySessionsCount}, sessions under creation: {_sessionCreationsCount}), waiting sessions: {_waitingCount}]" + : $"[pool size: {Count()}]"; + } + } +} diff --git a/Snowflake.Data/Core/Session/SessionPropertiesWithDefaultValuesExtractor.cs b/Snowflake.Data/Core/Session/SessionPropertiesWithDefaultValuesExtractor.cs new file mode 100644 index 000000000..f9092c7f9 --- /dev/null +++ b/Snowflake.Data/Core/Session/SessionPropertiesWithDefaultValuesExtractor.cs @@ -0,0 +1,143 @@ +using System; +using System.Linq; +using System.Text.RegularExpressions; +using Snowflake.Data.Core.Tools; +using Snowflake.Data.Log; + +namespace Snowflake.Data.Core.Session +{ + internal class SessionPropertiesWithDefaultValuesExtractor + { + private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger(); + private static readonly Regex s_timeoutFormatRegex = new Regex(@"^(-)?[0-9]{1,10}[mM]?[sS]?$"); + + private readonly SFSessionProperties _propertiesDictionary; + private readonly bool _failOnWrongValue; + + public SessionPropertiesWithDefaultValuesExtractor(SFSessionProperties propertiesDictionary, bool failOnWrongValue) + { + _propertiesDictionary = propertiesDictionary; + _failOnWrongValue = failOnWrongValue; + } + + public bool ExtractBooleanWithDefaultValue(SFSessionProperty property) => + ExtractPropertyWithDefaultValue( + property, + Boolean.Parse, + s => true, + b => true + ); + + public int ExtractPositiveIntegerWithDefaultValue( + SFSessionProperty property) => + ExtractPropertyWithDefaultValue( + property, + int.Parse, + s => true, + i => i > 0 + ); + + public int ExtractNonNegativeIntegerWithDefaultValue( + SFSessionProperty property) => + ExtractPropertyWithDefaultValue( + property, + int.Parse, + s => true, + i => i >= 0 + ); + + public TimeSpan ExtractTimeout( + SFSessionProperty property) => + ExtractPropertyWithDefaultValue( + property, + ExtractTimeout, + ValidateTimeoutFormat, + t => true + ); + + public T ExtractPropertyWithDefaultValue( + SFSessionProperty property, + Func extractor, + Func preExtractValidation, + Func postExtractValidation) + { + var propertyAttribute = property.GetAttribute(); + var defaultValueString = propertyAttribute.defaultValue; + var defaultValue = extractor(defaultValueString); + if (!postExtractValidation(defaultValue)) + { + throw new Exception($"Invalid default value of {property}"); + } + var valueString = _propertiesDictionary[property]; + if (string.IsNullOrEmpty(valueString)) + { + s_logger.Warn($"Parameter {property} not defined. Using a default value: {defaultValue}"); + return defaultValue; + } + if (!preExtractValidation(valueString)) + { + return handleFailedValidation(defaultValue, valueString, property); + } + T value; + try + { + value = extractor(valueString); + } + catch (Exception e) + { + if (_failOnWrongValue) + { + s_logger.Error($"Invalid value of parameter {property}. Error: {e}"); + throw new Exception($"Invalid value of parameter {property}", e); + } + s_logger.Warn($"Invalid value of parameter {property}. Using a default a default value: {defaultValue}"); + return defaultValue; + } + if (!postExtractValidation(value)) + { + return handleFailedValidation(defaultValue, value, property); + } + return value; + } + + private TResult handleFailedValidation( + TResult defaultValue, + TValue value, + SFSessionProperty property) + { + if (_failOnWrongValue) + { + s_logger.Error($"Invalid value of parameter {property}: {value}"); + throw new Exception($"Invalid value of parameter {property}"); + } + s_logger.Warn($"Invalid value of parameter {property}. Using a default value: {defaultValue}"); + return defaultValue; + } + + private static bool ValidateTimeoutFormat(string value) => + !string.IsNullOrEmpty(value) && s_timeoutFormatRegex.IsMatch(value); + + private static TimeSpan ExtractTimeout(string value) + { + var numericValueString = string.Concat(value.Where(IsNumberOrMinus)); + var unitValue = value.Substring(numericValueString.Length).ToLower(); + var numericValue = int.Parse(numericValueString); + if (numericValue < 0) + return TimeoutHelper.Infinity(); + switch (unitValue) + { + case "": + case "s": + return TimeSpan.FromSeconds(numericValue); + case "ms": + return TimeSpan.FromMilliseconds(numericValue); + case "m": + return TimeSpan.FromMinutes(numericValue); + default: + throw new Exception($"unknown timeout unit value: {unitValue}"); + } + } + + private static bool IsNumberOrMinus(char value) => char.IsNumber(value) || value.Equals('-'); + } +} diff --git a/Snowflake.Data/Core/Session/WaitingQueue.cs b/Snowflake.Data/Core/Session/WaitingQueue.cs new file mode 100644 index 000000000..8eeab2282 --- /dev/null +++ b/Snowflake.Data/Core/Session/WaitingQueue.cs @@ -0,0 +1,131 @@ +using System; +using System.Collections.Generic; +using System.Threading; + +namespace Snowflake.Data.Core.Session +{ + internal class WaitingQueue: IWaitingQueue + { + private readonly ReaderWriterLockSlim _lock = new ReaderWriterLockSlim(); + private readonly List _queue = new List(); + private readonly HashSet _notSuccessfulCollection = new HashSet(); + + public bool Wait(int millisecondsTimeout, CancellationToken cancellationToken) + { + var semaphore = new SemaphoreSlim(0, 1); + _lock.EnterWriteLock(); + try + { + _queue.Add(semaphore); + } + finally + { + _lock.ExitWriteLock(); + } + try + { + var waitingResult = semaphore.Wait(millisecondsTimeout, cancellationToken); + bool shouldFail; + _lock.EnterReadLock(); + try + { + shouldFail = _notSuccessfulCollection.Contains(semaphore); + } + finally + { + _lock.ExitReadLock(); + } + if (shouldFail) + { + _lock.EnterWriteLock(); + try + { + _notSuccessfulCollection.Remove(semaphore); + } + finally + { + _lock.ExitWriteLock(); + } + return false; + } + return waitingResult; + } + catch (OperationCanceledException) + { + return false; + } + finally + { + bool removed; + _lock.EnterWriteLock(); + try + { + removed = _queue.Remove(semaphore); + } + finally + { + _lock.ExitWriteLock(); + } + if (!removed && semaphore.CurrentCount > 0) // that means that it was removed by OnResourceIncrease() and not consumed by this waiting because of timeout + { + OnResourceIncrease(); + } + } + } + + public void OnResourceIncrease() + { + SemaphoreSlim semaphore = null; + _lock.EnterWriteLock(); + try + { + if (_queue.Count > 0) + { + semaphore = _queue[0]; + _queue.RemoveAt(0); + } + } + finally + { + _lock.ExitWriteLock(); + } + semaphore?.Release(); + } + + public bool IsAnyoneWaiting() => WaitingCount() > 0; + + public int WaitingCount() + { + _lock.EnterReadLock(); + try + { + return _queue.Count; + } + finally + { + _lock.ExitReadLock(); + } + } + + public bool IsWaitingEnabled() => true; + + public void Reset() + { + _lock.EnterWriteLock(); + try + { + while (_queue.Count > 0) + { + var semaphore = _queue[0]; + _queue.RemoveAt(0); + _notSuccessfulCollection.Add(semaphore); + semaphore?.Release(); + } + } + finally + { + _lock.ExitWriteLock(); + } + } + } +} diff --git a/Snowflake.Data/Core/Tools/SecureStringHelper.cs b/Snowflake.Data/Core/Tools/SecureStringHelper.cs new file mode 100644 index 000000000..5d7b685c1 --- /dev/null +++ b/Snowflake.Data/Core/Tools/SecureStringHelper.cs @@ -0,0 +1,12 @@ +using System.Net; +using System.Security; + +namespace Snowflake.Data.Core.Tools +{ + internal static class SecureStringHelper + { + public static string Decode(SecureString password) => new NetworkCredential(string.Empty, password).Password; + + public static SecureString Encode(string password) => new NetworkCredential(string.Empty, password).SecurePassword; + } +} diff --git a/Snowflake.Data/Core/Tools/TimeoutHelper.cs b/Snowflake.Data/Core/Tools/TimeoutHelper.cs new file mode 100644 index 000000000..ae29de795 --- /dev/null +++ b/Snowflake.Data/Core/Tools/TimeoutHelper.cs @@ -0,0 +1,44 @@ +using System; +using System.Threading; + +namespace Snowflake.Data.Core.Tools +{ + internal class TimeoutHelper + { + public static bool IsExpired(long startedAtMillis, long nowMillis, TimeSpan timeout) + { + if (IsInfinite(timeout)) + return false; + var timeoutInMillis = (long) timeout.TotalMilliseconds; + return startedAtMillis + timeoutInMillis <= nowMillis; + } + + public static bool IsExpired(long elapsedMillis, TimeSpan timeout) + { + if (IsInfinite(timeout)) + return false; + return elapsedMillis >= timeout.TotalMilliseconds; + } + + public static bool IsInfinite(TimeSpan timeout) => timeout == Timeout.InfiniteTimeSpan; + + public static bool IsZeroLength(TimeSpan timeout) + { + if (IsInfinite(timeout)) + return false; + return TimeSpan.Zero == timeout; + } + + public static TimeSpan Infinity() => Timeout.InfiniteTimeSpan; + + public static long FiniteTimeoutLeftMillis(long startedAtMillis, long nowMillis, TimeSpan timeout) + { + if (IsInfinite(timeout)) + { + throw new Exception("Infinite timeout cannot be used to determine milliseconds left"); + } + var passedMillis = nowMillis - startedAtMillis; + return Math.Max((long) timeout.TotalMilliseconds - passedMillis, 0); + } + } +} diff --git a/Snowflake.Data/Logger/SecretDetector.cs b/Snowflake.Data/Logger/SecretDetector.cs index 2c0524fb1..59cd810d6 100644 --- a/Snowflake.Data/Logger/SecretDetector.cs +++ b/Snowflake.Data/Logger/SecretDetector.cs @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Snowflake Computing Inc. All rights reserved. + * Copyright (c) 2021-2024 Snowflake Computing Inc. All rights reserved. */ using System; @@ -76,69 +76,107 @@ private static string MaskCustomPatterns(string text) /* * https://docs.microsoft.com/en-us/dotnet/standard/base-types/character-escapes-in-regular-expressions * . $ ^ { [ ( | ) * + ? \ - * The characters are special regular expression language elements. + * The characters are special regular expression language elements. * To match them in a regular expression, they must be escaped or included in a positive character group. * [ ] \ - ^ * The characters are special character group element. * To match them in a character group, they must be escaped. */ - private static readonly string AWS_KEY_PATTERN = @"(aws_key_id|aws_secret_key|access_key_id|secret_access_key)('|"")?(\s*[:=]\s*)'([^']+)'"; - private static readonly string AWS_TOKEN_PATTERN = @"(accessToken|tempToken|keySecret)\""\s*:\s*\""([a-z0-9/+]{32,}={0,2})\"""; - private static readonly string AWS_SERVER_SIDE_PATTERN = @"((x-amz-server-side-encryption)([a-z0-9\-])*)\s*(:|=)\s*([a-z0-9/_\-+:=])+"; - private static readonly string SAS_TOKEN_PATTERN = @"(sig|signature|AWSAccessKeyId|password|passcode)=([a-z0-9%/+]{16,})"; - private static readonly string PRIVATE_KEY_PATTERN = @"-----BEGIN PRIVATE KEY-----\n([a-z0-9/+=\n]{32,})\n-----END PRIVATE KEY-----"; - private static readonly string PRIVATE_KEY_DATA_PATTERN = @"""privateKeyData"": ""([a-z0-9/+=\n]{10,})"""; - private static readonly string CONNECTION_TOKEN_PATTERN = @"(token|assertion content)(['""\s:=]+)([a-z0-9=/_\-+:]{8,})"; - private static readonly string PASSWORD_PATTERN = @"(password|passcode|pwd|proxypassword)(['""\s:=]+)([a-z0-9!""#$%&'\()*+,-./:;<=>?@\[\]\^_`{|}~]{6,})"; + private const string AwsKeyPattern = @"(aws_key_id|aws_secret_key|access_key_id|secret_access_key)('|"")?(\s*[:=]\s*)'([^']+)'"; + private const string AwsTokenPattern = @"(accessToken|tempToken|keySecret)\""\s*:\s*\""([a-z0-9/+]{32,}={0,2})\"""; + private const string AwsServerSidePattern = @"((x-amz-server-side-encryption)([a-z0-9\-])*)\s*(:|=)\s*([a-z0-9/_\-+:=])+"; + private const string SasTokenPattern = @"(sig|signature|AWSAccessKeyId|password|passcode)=([a-z0-9%/+]{16,})"; + private const string PrivateKeyPattern = @"-----BEGIN PRIVATE KEY-----\n([a-z0-9/+=\n]{32,})\n-----END PRIVATE KEY-----"; // pragma: allowlist secret + private const string PrivateKeyDataPattern = @"""privateKeyData"": ""([a-z0-9/+=\n]{10,})"""; + private const string PrivateKeyPropertyPrefixPattern = @"(private_key\s*=)"; + private const string ConnectionTokenPattern = @"(token|assertion content)(['""\s:=]+)([a-z0-9=/_\-+:]{8,})"; + private const string TokenPropertyPattern = @"(token)(\s*=)(.*)"; + private const string PasswordPattern = @"(password|passcode|pwd|proxypassword|private_key_pwd)(['""\s:=]+)([a-z0-9!""#$%&'\()*+,-./:;<=>?@\[\]\^_`{|}~]{6,})"; + private const string PasswordPropertyPattern = @"(password|proxypassword|private_key_pwd)(\s*=)(.*)"; + + private static readonly Func[] s_maskFunctions = { + MaskAWSServerSide, + MaskAWSKeys, + MaskSASTokens, + MaskAWSTokens, + MaskPrivateKey, + MaskPrivateKeyData, + MaskPrivateKeyProperty, + MaskPassword, + MaskPasswordProperty, + MaskConnectionTokens, + MaskTokenProperty + }; private static string MaskAWSKeys(string text) { - return Regex.Replace(text, AWS_KEY_PATTERN, @"$1$2$3'****'", + return Regex.Replace(text, AwsKeyPattern, @"$1$2$3'****'", RegexOptions.IgnoreCase); } private static string MaskAWSTokens(string text) { - return Regex.Replace(text, AWS_TOKEN_PATTERN, @"$1"":""XXXX""", + return Regex.Replace(text, AwsTokenPattern, @"$1"":""XXXX""", RegexOptions.IgnoreCase); } private static string MaskAWSServerSide(string text) { - return Regex.Replace(text, AWS_SERVER_SIDE_PATTERN, @"$1:....", + return Regex.Replace(text, AwsServerSidePattern, @"$1:....", RegexOptions.IgnoreCase); } private static string MaskSASTokens(string text) { - return Regex.Replace(text, SAS_TOKEN_PATTERN, @"$1=****", + return Regex.Replace(text, SasTokenPattern, @"$1=****", RegexOptions.IgnoreCase); } private static string MaskPrivateKey(string text) { - return Regex.Replace(text, PRIVATE_KEY_PATTERN, "-----BEGIN PRIVATE KEY-----\\\\nXXXX\\\\n-----END PRIVATE KEY-----", + return Regex.Replace(text, PrivateKeyPattern, "-----BEGIN PRIVATE KEY-----\\\\nXXXX\\\\n-----END PRIVATE KEY-----", // pragma: allowlist secret RegexOptions.IgnoreCase | RegexOptions.Multiline); } + private static string MaskPrivateKeyProperty(string text) + { + var match = Regex.Match(text, PrivateKeyPropertyPrefixPattern, RegexOptions.IgnoreCase); + if (match.Success) + { + int length = match.Index + match.Value.Length; + return text.Substring(0, length) + "****"; + } + return text; + } + private static string MaskPrivateKeyData(string text) { - return Regex.Replace(text, PRIVATE_KEY_DATA_PATTERN, @"""privateKeyData"": ""XXXX""", + return Regex.Replace(text, PrivateKeyDataPattern, @"""privateKeyData"": ""XXXX""", RegexOptions.IgnoreCase | RegexOptions.Multiline); } private static string MaskConnectionTokens(string text) { - return Regex.Replace(text, CONNECTION_TOKEN_PATTERN, @"$1$2****", + return Regex.Replace(text, ConnectionTokenPattern, @"$1$2****", RegexOptions.IgnoreCase); } private static string MaskPassword(string text) { - return Regex.Replace(text, PASSWORD_PATTERN, @"$1$2****", + return Regex.Replace(text, PasswordPattern, @"$1$2****", RegexOptions.IgnoreCase); } + private static string MaskPasswordProperty(string text) + { + return Regex.Replace(text, PasswordPropertyPattern, @"$1$2****", RegexOptions.IgnoreCase); + } + + private static string MaskTokenProperty(string text) + { + return Regex.Replace(text, TokenPropertyPattern, @"$1$2****", RegexOptions.IgnoreCase); + } + public static Mask MaskSecrets(string text) { Mask result = new Mask(maskedText: text); @@ -150,19 +188,7 @@ public static Mask MaskSecrets(string text) try { - result.maskedText = - MaskConnectionTokens( - MaskPassword( - MaskPrivateKeyData( - MaskPrivateKey( - MaskAWSTokens( - MaskSASTokens( - MaskAWSKeys( - MaskAWSServerSide(text)))))))); - if (CUSTOM_PATTERNS_LENGTH > 0) - { - result.maskedText = MaskCustomPatterns(result.maskedText); - } + result.maskedText = MaskAllPatterns(text); if (result.maskedText != text) { result.isMasked = true; @@ -179,5 +205,19 @@ public static Mask MaskSecrets(string text) } return result; } + + private static string MaskAllPatterns(string text) + { + string result = text; + foreach (var maskFunction in s_maskFunctions) + { + result = maskFunction.Invoke(result); + } + if (CUSTOM_PATTERNS_LENGTH > 0) + { + result = MaskCustomPatterns(result); + } + return result; + } } } diff --git a/Snowflake.Data/Snowflake.Data.csproj b/Snowflake.Data/Snowflake.Data.csproj index 6670c5bce..d2c2065f2 100644 --- a/Snowflake.Data/Snowflake.Data.csproj +++ b/Snowflake.Data/Snowflake.Data.csproj @@ -67,6 +67,10 @@ $(Version) + + $(DefineConstants);$(DefineAdditionalConstants) + + diff --git a/doc/CodeCoverage.md b/doc/CodeCoverage.md new file mode 100644 index 000000000..497219494 --- /dev/null +++ b/doc/CodeCoverage.md @@ -0,0 +1,72 @@ +## Getting the code coverage + +1. Go to .NET project directory + +2. Clean the directory + +``` +dotnet clean snowflake-connector-net.sln && dotnet nuget locals all --clear +``` + +3. Create parameters.json containing connection info for AWS, AZURE, or GCP account and place inside the Snowflake.Data.Tests folder + +4. Build the project for .NET6 + +``` +dotnet build snowflake-connector-net.sln /p:DebugType=Full +``` + +5. Run dotnet-cover on the .NET6 build + +``` +dotnet-coverage collect "dotnet test --framework net6.0 --no-build -l console;verbosity=normal" --output net6.0_AWS_coverage.xml --output-format cobertura --settings coverage.config +``` + +6. Build the project for .NET Framework + +``` +msbuild snowflake-connector-net.sln -p:Configuration=Release +``` + +7. Run dotnet-cover on the .NET Framework build + +``` +dotnet-coverage collect "dotnet test --framework net472 --no-build -l console;verbosity=normal" --output net472_AWS_coverage.xml --output-format cobertura --settings coverage.config +``` + +
+Repeat steps 3, 5, and 7 for the other cloud providers.
+Note: no need to rebuild the connector again.

+ +For Azure:
+ +3. Create parameters.json containing connection info for AZURE account and place inside the Snowflake.Data.Tests folder + +4. Run dotnet-cover on the .NET6 build + +``` +dotnet-coverage collect "dotnet test --framework net6.0 --no-build -l console;verbosity=normal" --output net6.0_AZURE_coverage.xml --output-format cobertura --settings coverage.config +``` + +7. Run dotnet-cover on the .NET Framework build + +``` +dotnet-coverage collect "dotnet test --framework net472 --no-build -l console;verbosity=normal" --output net472_AZURE_coverage.xml --output-format cobertura --settings coverage.config +``` + +
+For GCP:
+ +3. Create parameters.json containing connection info for GCP account and place inside the Snowflake.Data.Tests folder + +4. Run dotnet-cover on the .NET6 build + +``` +dotnet-coverage collect "dotnet test --framework net6.0 --no-build -l console;verbosity=normal" --output net6.0_GCP_coverage.xml --output-format cobertura --settings coverage.config +``` + +7. Run dotnet-cover on the .NET Framework build + +``` +dotnet-coverage collect "dotnet test --framework net472 --no-build -l console;verbosity=normal" --output net472_GCP_coverage.xml --output-format cobertura --settings coverage.config +``` diff --git a/doc/Connecting.md b/doc/Connecting.md new file mode 100644 index 000000000..b88281388 --- /dev/null +++ b/doc/Connecting.md @@ -0,0 +1,296 @@ +## Connecting + +To connect to Snowflake, specify a valid connection string composed of key-value pairs separated by semicolons, +i.e "\=\;\=\...". + +**Note**: If the value specified in the connection string contains any signs like semicolon (`;`) or equal sign (`=`) or any phrases which can interfere with parsing the connection string, +please surround the value with double quotation marks (`""`). For example `password="=;;;=dummy==password;;"`. + +The following table lists all valid connection properties: +
+ +| Connection Property | Required | Comment | +|--------------------------------| -------- |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| ACCOUNT | Yes | Your full account name might include additional segments that identify the region and cloud platform where your account is hosted | +| APPLICATION | No | **_Snowflake partner use only_**: Specifies the name of a partner application to connect through .NET. The name must match the following pattern: ^\[A-Za-z](\[A-Za-z0-9.-]){1,50}$ (one letter followed by 1 to 50 letter, digit, .,- or, \_ characters). | +| DB | No | | +| HOST | No | Specifies the hostname for your account in the following format: \.snowflakecomputing.com.
If no value is specified, the driver uses \.snowflakecomputing.com. | +| PASSWORD | Depends | Required if AUTHENTICATOR is set to `snowflake` (the default value) or the URL for native SSO through Okta. Ignored for all the other authentication types. | +| ROLE | No | | +| SCHEMA | No | | +| USER | Depends | If AUTHENTICATOR is set to `externalbrowser` this is optional. For native SSO through Okta, set this to the login name for your identity provider (IdP). | +| WAREHOUSE | No | | +| CONNECTION_TIMEOUT | No | Total timeout in seconds when connecting to Snowflake. The default is 300 seconds | +| RETRY_TIMEOUT | No | Total timeout in seconds for supported endpoints of retry policy. The default is 300 seconds. The value can only be increased from the default value or set to 0 for infinite timeout | +| MAXHTTPRETRIES | No | Maximum number of times to retry failed HTTP requests (default: 7). You can set `MAXHTTPRETRIES=0` to remove the retry limit, but doing so runs the risk of the .NET driver infinitely retrying failed HTTP calls. | +| CLIENT_SESSION_KEEP_ALIVE | No | Whether to keep the current session active after a period of inactivity, or to force the user to login again. If the value is `true`, Snowflake keeps the session active indefinitely, even if there is no activity from the user. If the value is `false`, the user must log in again after four hours of inactivity. The default is `false`. Setting this value overrides the server session property for the current session. | +| BROWSER_RESPONSE_TIMEOUT | No | Number to seconds to wait for authentication in an external browser (default: 120). | +| DISABLERETRY | No | Set this property to `true` to prevent the driver from reconnecting automatically when the connection fails or drops. The default value is `false`. | +| AUTHENTICATOR | No | The method of authentication. Currently supports the following values:
- snowflake (default): You must also set USER and PASSWORD.
- [the URL for native SSO through Okta](https://docs.snowflake.com/en/user-guide/admin-security-fed-auth-use.html#native-sso-okta-only): You must also set USER and PASSWORD.
- [externalbrowser](https://docs.snowflake.com/en/user-guide/admin-security-fed-auth-use.html#browser-based-sso): You must also set USER.
- [snowflake_jwt](https://docs.snowflake.com/en/user-guide/key-pair-auth.html): You must also set PRIVATE_KEY_FILE or PRIVATE_KEY.
- [oauth](https://docs.snowflake.com/en/user-guide/oauth.html): You must also set TOKEN. | +| VALIDATE_DEFAULT_PARAMETERS | No | Whether DB, SCHEMA and WAREHOUSE should be verified when making connection. Default to be true. | +| PRIVATE_KEY_FILE | Depends | The path to the private key file to use for key-pair authentication. Must be used in combination with AUTHENTICATOR=snowflake_jwt | +| PRIVATE_KEY_PWD | No | The passphrase to use for decrypting the private key, if the key is encrypted. | +| PRIVATE_KEY | Depends | The private key to use for key-pair authentication. Must be used in combination with AUTHENTICATOR=snowflake_jwt.
If the private key value includes any equal signs (=), make sure to replace each equal sign with two signs (==) to ensure that the connection string is parsed correctly. | +| TOKEN | Depends | The OAuth token to use for OAuth authentication. Must be used in combination with AUTHENTICATOR=oauth. | +| INSECUREMODE | No | Set to true to disable the certificate revocation list check. Default is false. | +| USEPROXY | No | Set to true if you need to use a proxy server. The default value is false.

This parameter was introduced in v2.0.4. | +| PROXYHOST | Depends | The hostname of the proxy server.

If USEPROXY is set to `true`, you must set this parameter.

This parameter was introduced in v2.0.4. | +| PROXYPORT | Depends | The port number of the proxy server.

If USEPROXY is set to `true`, you must set this parameter.

This parameter was introduced in v2.0.4. | +| PROXYUSER | No | The username for authenticating to the proxy server.

This parameter was introduced in v2.0.4. | +| PROXYPASSWORD | Depends | The password for authenticating to the proxy server.

If USEPROXY is `true` and PROXYUSER is set, you must set this parameter.

This parameter was introduced in v2.0.4. | +| NONPROXYHOSTS | No | The list of hosts that the driver should connect to directly, bypassing the proxy server. Separate the hostnames with a pipe symbol (\|). You can also use an asterisk (`*`) as a wildcard.
The host target value should fully match with any item from the proxy host list to bypass the proxy server.

This parameter was introduced in v2.0.4. | +| FILE_TRANSFER_MEMORY_THRESHOLD | No | The maximum number of bytes to store in memory used in order to provide a file encryption. If encrypting/decrypting file size exceeds provided value a temporary file will be created and the work will be continued in the temporary file instead of memory.
If no value provided 1MB will be used as a default value (that is 1048576 bytes).
It is possible to configure any integer value bigger than zero representing maximal number of bytes to reside in memory. | +| CLIENT_CONFIG_FILE | No | The location of the client configuration json file. In this file you can configure easy logging feature. | +| ALLOWUNDERSCORESINHOST | No | Specifies whether to allow underscores in account names. This impacts PrivateLink customers whose account names contain underscores. In this situation, you must override the default value by setting allowUnderscoresInHost to true. | +| QUERY_TAG | No | Optional string that can be used to tag queries and other SQL statements executed within a connection. The tags are displayed in the output of the QUERY_HISTORY , QUERY_HISTORY_BY_* functions.
To set QUERY_TAG on the statement level you can use SnowflakeDbCommand.QueryTag. | +| MAXPOOLSIZE | No | Maximum number of connections in a pool. Default value is 10. `maxPoolSize` value cannot be lower than `minPoolSize` value. | +| MINPOOLSIZE | No | Expected minimum number of connections in pool. When you get a connection from the pool, more connections might be initialised in background to increase the pool size to `minPoolSize`. If you specify 0 or 1 there will be no attempts to create extra initialisations in background. The default value is 2. `maxPoolSize` value cannot be lower than `minPoolSize` value. The parameter is used only in a new version of connection pool. | +| CHANGEDSESSION | No | Specifies what should happen with a closed connection when some of its session variables are altered (e. g. you used `ALTER SESSION SET SCHEMA` to change the databese schema). The default behaviour is `OriginalPool` which means the session stays in the original pool. Currently no other option is possible. Parameter used only in a new version of connection pool. | +| WAITINGFORIDLESESSIONTIMEOUT | No | Timeout for waiting for an idle session when pool is full. It happens when there is no idle session and we cannot create a new one because of reaching `maxPoolSize`. The default value is 30 seconds. Usage of units possible and allowed are: e. g. `1000ms` (milliseconds), `15s` (seconds), `2m` (minutes) where seconds are default for a skipped postfix. Special values: `0` - immediate fail for new connection to open when session is full. You cannot specify infinite value. | +| EXPIRATIONTIMEOUT | No | Timeout for using each connection. Connections which last more than specified timeout are considered to be expired and are being removed from the pool. The default is 1 hour. Usage of units possible and allowed are: e. g. `360000ms` (milliseconds), `3600s` (seconds), `60m` (minutes) where seconds are default for a skipped postfix. Special values: `0` - immediate expiration of the connection just after its creation. Expiration timeout cannot be set to infinity. | +| POOLINGENABLED | No | Boolean flag indicating if the connection should be a part of a pool. The default value is `true`. | + +
+ +**Note**: Connections should not be shared across multiple threads. + +### Password-based Authentication + +The following example demonstrates how to open a connection to Snowflake. This example uses a password for authentication. + +```cs +using (IDbConnection conn = new SnowflakeDbConnection()) +{ + conn.ConnectionString = "account=testaccount;user=testuser;password=XXXXX;db=testdb;schema=testschema"; + + conn.Open(); + + conn.Close(); +} +``` + + + +Beginning with version 2.0.18, the .NET connector uses Microsoft [DbConnectionStringBuilder](https://learn.microsoft.com/en-us/dotnet/api/system.data.oledb.oledbconnection.connectionstring?view=dotnet-plat-ext-6.0#remarks) to follow the .NET specification for escaping characters in connection strings. + +The following examples show how you can include different types of special characters in a connection string: + +- To include a single quote (') character: + + ```cs + string connectionString = String.Format( + "account=testaccount; " + + "user=testuser; " + + "password=test'password;" + ); + ``` + +- To include a double quote (") character: + + ```cs + string connectionString = String.Format( + "account=testaccount; " + + "user=testuser; " + + "password=test\"password;" + ); + ``` + +- To include a semicolon (;): + + ```cs + string connectionString = String.Format( + "account=testaccount; " + + "user=testuser; " + + "password=\"test;password\";" + ); + ``` + +- To include an equal sign (=): + + ```cs + string connectionString = String.Format( + "account=testaccount; " + + "user=testuser; " + + "password=test=password;" + ); + ``` + + Note that previously you needed to use a double equal sign (==) to escape the character. However, beginning with version 2.0.18, you can use a single equal size. + + +Snowflake supports using [double quote identifiers](https://docs.snowflake.com/en/sql-reference/identifiers-syntax#double-quoted-identifiers) for object property values (WAREHOUSE, DATABASE, SCHEMA AND ROLES). The value should be delimited with `\"` in the connection string. The value is case-sensitive and allow to use special characters as part of the value. + + ```cs + string connectionString = String.Format( + "account=testaccount; " + + "database=\"testDB\";" + ); + ``` +- To include a `"` character as part of the value should be escaped using `\"\"`. + + ```cs + string connectionString = String.Format( + "account=testaccount; " + + "database=\"\"\"test\"\"user\"\"\";" // DATABASE => ""test"db"" + ); + ``` + +### Other Authentication Methods + +If you are using a different method for authentication, see the examples below: + +- **Key-pair authentication** + + After setting up [key-pair authentication](https://docs.snowflake.com/en/user-guide/key-pair-auth.html), you can specify the + private key for authentication in one of the following ways: + + - Specify the file containing an unencrypted private key: + + ```cs + using (IDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = "account=testaccount;authenticator=snowflake_jwt;user=testuser;private_key_file={pathToThePrivateKeyFile};db=testdb;schema=testschema"; + + conn.Open(); + + conn.Close(); + } + ``` + + where: + + - `{pathToThePrivateKeyFile}` is the path to the file containing the unencrypted private key. + + - Specify the file containing an encrypted private key: + + ```cs + using (IDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = "account=testaccount;authenticator=snowflake_jwt;user=testuser;private_key_file={pathToThePrivateKeyFile};private_key_pwd={passwordForDecryptingThePrivateKey};db=testdb;schema=testschema"; + + conn.Open(); + + conn.Close(); + } + ``` + + where: + + - `{pathToThePrivateKeyFile}` is the path to the file containing the unencrypted private key. + - `{passwordForDecryptingThePrivateKey}` is the password for decrypting the private key. + + - Specify an unencrypted private key (read from a file): + + ```cs + using (IDbConnection conn = new SnowflakeDbConnection()) + { + string privateKeyContent = File.ReadAllText({pathToThePrivateKeyFile}); + + conn.ConnectionString = String.Format("account=testaccount;authenticator=snowflake_jwt;user=testuser;private_key={0};db=testdb;schema=testschema", privateKeyContent); + + conn.Open(); + + conn.Close(); + } + ``` + + where: + + - `{pathToThePrivateKeyFile}` is the path to the file containing the unencrypted private key. + +- **OAuth** + + After setting up [OAuth](https://docs.snowflake.com/en/user-guide/oauth.html), set `AUTHENTICATOR=oauth` and `TOKEN` to the + OAuth token in the connection string. + + ```cs + using (IDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = "account=testaccount;user=testuser;authenticator=oauth;token={oauthTokenValue};db=testdb;schema=testschema"; + + conn.Open(); + + conn.Close(); + } + ``` + + where: + + - `{oauthTokenValue}` is the oauth token to use for authentication. + +- **Browser-based SSO** + + In the connection string, set `AUTHENTICATOR=externalbrowser`. + Optionally, `USER` can be set. In that case only if user authenticated via external browser matches the one from configuration, authentication will complete. + + ```cs + using (IDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = "account=testaccount;authenticator=externalbrowser;user={login_name_for_IdP};db=testdb;schema=testschema"; + + conn.Open(); + + conn.Close(); + } + ``` + + where: + + - `{login_name_for_IdP}` is your login name for your IdP. + + You can override the default timeout after which external browser authentication is marked as failed. + The timeout prevents the infinite hang when the user does not provide the login details, e.g. when closing the browser tab. + To override, you can provide `BROWSER_RESPONSE_TIMEOUT` parameter (in seconds). + +- **Native SSO through Okta** + + In the connection string, set `AUTHENTICATOR` to the + [URL of the endpoint for your Okta account](https://docs.snowflake.com/en/user-guide/admin-security-fed-auth-use.html#label-native-sso-okta), + and set `USER` to the login name for your IdP. + + ```cs + using (IDbConnection conn = new SnowflakeDbConnection()) + { + conn.ConnectionString = "account=testaccount;authenticator={okta_url_endpoint};user={login_name_for_IdP};db=testdb;schema=testschema"; + + conn.Open(); + + conn.Close(); + } + ``` + + where: + + - `{okta_url_endpoint}` is the URL for the endpoint for your Okta account (e.g. `https://.okta.com`). + - `{login_name_for_IdP}` is your login name for your IdP. + +In v2.0.4 and later releases, you can configure the driver to connect through a proxy server. The following example configures the +driver to connect through the proxy server `myproxyserver` on port `8888`. The driver authenticates to the proxy server as the +user `test` with the password `test`: + +```cs +using (IDbConnection conn = new SnowflakeDbConnection()) +{ + conn.ConnectionString = "account=testaccount;user=testuser;password=XXXXX;db=testdb;schema=testschema;useProxy=true;proxyHost=myproxyserver;proxyPort=8888;proxyUser=test;proxyPassword=test"; + + conn.Open(); + + conn.Close(); +} +``` + +The NONPROXYHOSTS property could be set to specify if the server proxy should be bypassed by an specified host. This should be defined using the full host url or including the url + `*` wilcard symbol. + +Examples: + +- `*` (Bypassed all hosts from the proxy server) +- `*.snowflakecomputing.com` ('Bypass all host that ends with `snowflakecomputing.com`') +- `https:\\testaccount.snowflakecomputing.com` (Bypass proxy server using full host url). +- `*.myserver.com | *testaccount*` (You can specify multiple regex for the property divided by `|`) + + +> Note: The nonproxyhost value should match the full url including the http or https section. The '*' wilcard could be added to bypass the hostname successfully. + +- `myaccount.snowflakecomputing.com` (Not bypassed). +- `*myaccount.snowflakecomputing.com` (Bypassed). + diff --git a/doc/ConnectionPooling.md b/doc/ConnectionPooling.md new file mode 100644 index 000000000..f9b8dad86 --- /dev/null +++ b/doc/ConnectionPooling.md @@ -0,0 +1,405 @@ +## Using Connection Pools + +### Multiple Connection Pools + +Snowflake .NET Driver v4.0.0 provides multiple pools with couple of additional features in comparison to the previous implementation. + +Each pool is identified by the entire connection string. Order of connection string parameters is significant and the same connection parameters +ordered differently lead to two different pools being used. + +All the pool parameters can be controlled from the connection string. + +Pool interface is also maintained by the [SnowflakeDbConnectionPool.cs](https://github.com/snowflakedb/snowflake-connector-net/blob/master/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs). +However, some operations (eg. setting pool parameters from this SnowflakeDbConnectionPool class) are not possible having in mind multiple pools and possibly their different setup. +For that a [SnowflakeDbSessionPool.cs](/Snowflake.Data/Client/SnowflakeDbSessionPool.cs) is provided by +- [SnowflakeDbSessionPool.GetPool(connectionString)](https://github.com/snowflakedb/snowflake-connector-net/blob/master/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs#L45) +- [SnowflakeDbSessionPool.GetPool(connectionString, securePassword)](https://github.com/snowflakedb/snowflake-connector-net/blob/master/Snowflake.Data/Client/SnowflakeDbConnectionPool.cs#L51). +to control pool settings from the code. Changed pool settings are not reflected by their connection string therefore recommended way is to control the pool from the connection string. + +### Pool Lifecycle + +Single pool is instantiated each time an application creates and opens a connection for the first time using particular connection string. +Pool can be also initialized when accessing for the first time from SnowflakeDbConnectionPool.GetPool. + +From that moment the pool tracks and maintains connections matching exactly this connection string. +Pool is responsible for destroying and recreating connections which are old enough (see [Expiration Timeout](#expiration-timeout)). +Number of connections is maintained within [Minimum pool size](#minimum-pool-size) and [Maximum pool size](#maximum-pool-size). +Connections in all their statuses are tracked: +- opening phase +- busy phase +- closed and returned to the pool (idle) +User can clean up the pool using methods: [Clear Pool](#clear-pool). + +### Connection Lifecycle + +#### Opening + +When an application request to open a connection from the pool there are couple of possibilities: + +1) Pool has idle connections already opened and they are provided immediately to the application +2) Pool has no idle connections but [Maximum pool size](#maximum-pool-size) is not reached in which case pool will open connection. +The slot for the new connection is reserved in the pool from the very beginning. +Even though opening a connection may take a while other threads are not blocked from accessing the pool. +3) When [Maximum pool size](#maximum-pool-size) is reached connection is waiting to be opened within period of +time controlled with [Pool Size Exceeded Timeout](#pool-size-exceeded-timeout). +When the timeout is exceeded then an exception will be thrown. + +#### Busy + +`Busy` connection is provided by the pool and it is counted to the pool size. It is returned to be reused during Close operation. +When application does not close connections it may hit the limit of [Maximum pool size](#maximum-pool-size). + +#### Closing + +When application closes the connection couple of things happen: +- Pending transactions will be rolled back (if any) +- Connection can be pooled when its properties are not changed +- Connection with changed: database, schema, warehouse or role can be: + - pooled when OriginalPool mode enabled, see more: [Changed Session Behavior](#changed-session-behavior) + - destroyed when Destroy mode is set + +#### Evicting Connection + +In order to prevent connection pooling the easiest way is to disable pooling. More on this here: [Pooling Enabled](#pooling-enabled). + +However, in special cases an application may need to mark a single, opened connection to evict without turning off the pool. +When such a connection is closed it will not be pooled. Pool will create a new connection to maintain [Minimum pool size](#minimum-pool-size) if needed. + +```cs +using (var connection = new SnowflakeDbConnection(ConnectionString)) +{ + connection.Open(); + connection.PreventPooling(); +} +``` + +### Pool Interfaces + +| Connection Pool Feature | Connection String Parameter | Default | Method | Info | +|-----------------------------------------------------------|------------------------------|---------|---------------------------------|----------------------------------------------------------------------------------------------------------------------------| +| [Multiple pools](#multiple-pools) | | | | | +| [Minimum pool size](#minimum-pool-size) | MinPoolSize | 2 | | | +| [Maximum pool size](#maximum-pool-size) | MaxPoolSize | 10 | | | +| [Changed Session Behavior](#changed-session-behavior) | ChangedSession | Destroy | | Destroy or OriginalPool | +| [Pool Size Exceeded Timeout](#pool-size-exceeded-timeout) | WaitingForIdleSessionTimeout | 30s | | Values can be provided with postfix [ms], [s], [m] | +| [Expiration Timeout](#expiration-timeout) | ExpirationTimeout | 60m | | | +| [Pooling Enabled](#connection-timeout) | PoolingEnabled | true | | Pooling connections authenticated with External Browser or Key-Pair Authentication without password is disabled by default | +| [Connection Timeout](#pooling-enabled) | | 300s | | | +| [Current Pool Size](#current-pool-size) | | | GetCurrentPoolSize() | | +| [Clear Pool](#clear-pool) | | | ClearPool() or ClearAllPools() | | + +#### Multiple pools + +When a first connection is opened, a connection pool is created based on an exact matching algorithm that associates the pool with the connection string of the connection. Each connection pool is associated with a distinct connection string. When a new connection is opened, if the connection string is not an exact match to an existing pool, a new pool is created. + +Different pools can have separate settings from the above settings for instance: minimum pool size or changed session behavior. + +```cs +using (var connection = new SnowflakeDbConnection(ConnectionString + ";application=App1")) +{ + connection.Open(); + // Pool 1 is created +} + +using (var connection = new SnowflakeDbConnection(ConnectionString + ";application=App2")) +{ + connection.Open(); + // Pool 2 is created +} +``` + +#### Minimum pool size + +Ensures minimum specified size of the connections in a pool. Additional connections are created in the background during connection opening request. +When connections are being closed Connection Timeout is analysed for all the connections in a pool and the expired ones are being closed. +After that some connections will get recreated to ensure minimum size of the pool. + +```cs +var connectionString = ConnectionString + ";MinPoolSize=10"; +using (var connection = new SnowflakeDbConnection(connectionString)) +{ + connection.Open(); + // Pool of size 10 is created +} +var poolSize = SnowflakeDbConnectionPool.GetPool(connectionString).GetCurrentPoolSize(); +Assert.AreEqual(10, poolSize); +``` + +#### Maximum pool size + +Latest pool version ensures maximum size of the pool. +What counts for that are: +- idle connections +- busy connections (provided by the pool to the application) +- connections during opening phase + +When a maximum pool size is reached any request to provide (open) another connection is waiting for any idle session to be returned to the pool. +When an Idle Session Timeout is reached and an idle session is not returned within that period an exception will get thrown. + +```cs +var connectionString = ConnectionString + ";MaxPoolSize=2"; + +Task[] tasks = new Task[8]; +for (int i = 0; i < tasks.Length; i++) +{ + var taskName = $"Task {i}"; + tasks[i] = Task.Run(() => + { + using (var connection = new SnowflakeDbConnection(connectionString)) + { + StopWatch sw = new StopWatch(); + + // register opening time + sw.Start(); + connection.Open(); + sw.Stop(); + + // output + Console.WriteLine($"{taskName} waited {Math.Round((double)sw.ElapsedMilliseconds / 1000)} seconds"); + + // wait 2s before closing the connection + Thread.Sleep(2000); + } + }); +} +Task.WaitAll(tasks); + +// check current pool size +var poolSize = SnowflakeDbConnectionPool.GetPool(connectionString).GetCurrentPoolSize(); +Assert.AreEqual(2, poolSize); + +// output: +// Task 1 waited 0 seconds +// Task 4 waited 0 seconds +// Task 7 waited 2 seconds +// Task 0 waited 2 seconds +// Task 6 waited 4 seconds +// Task 3 waited 4 seconds +// Task 2 waited 6 seconds +// Task 5 waited 6 seconds +``` + +#### Changed Session Behavior + +When an application does a change to the connection using one of SQL commands, for instance: +* `use schema`, `create schema` +* `use database`, `create database` +* `use warehouse`, `create warehouse` +* `use role`, `create role` +* `drop` +then such an affected connection is marked internally as no longer matching with the pool it originated from (it becomes a "dirty" connection). +Keep in mind that create commands automatically set active the created object within current connection +(eg. [create database](https://docs.snowflake.com/en/sql-reference/sql/create-database#general-usage-notes)). + +Pool has two different approaches to connections altered with above way: +* Destroy connection +* Pool it back to the Original Pool + +1) Destroy Connection Mode + +To enable this pool mode parameter ChangedSession should be set to `Destroy` or entirely skipped (Destroy is the default pool behavior). +In this mode application may safely alter connection properties: schema, database, warehouse or role. Such a dirty connection no longer matching +with the connection string will not get pooled any more. The pool marks it internally as `dirty` and ensures it gets removed +when no longer used (closed) by the application. + +Since such connections do not return to the pool, it will recreate necessary number of connections to satisfy the Minimum Pool Size requirement. + +```cs +var connectionString = ConnectionString + ";ChangedSession=Destroy"; +var connection = new SnowflakeDbConnection(connectionString); + +connection.Open(); +var randomSchemaName = Guid.NewGuid(); +connection.CreateCommand($"create schema \"{randomSchemaName}\").ExecuteNonQuery(); // schema gets changed +// application is running commands on a schema with random name +connection.Close(); // connection does not return to the original pool and gets destroyed; pool will reconstruct the pool + // with new connections accordingly to the MinPoolSize + +var connection2 = new SnowflakeDbConnection(connectionString); +connection2.Open(); +// operations here will be performed against schema indicated in the ConnectionString +``` + +2) Pooling Changed Session to the Original Pool + +When parameter ChangedSession is set to `OriginalPool` it allows the connection to be pooled back to the original pool from which it came from. + +Disclaimer for OriginalPool Mode + +When application reuses connections affected by the above commands (use/create) it might get to a point when using a connection +provided by the pool it gets SQL syntax errors since tables, procedures, stages and other database objects do not exists because the operations +are executed using changed database, schema, user or role no longer matching connection string. +Reusing connection from a pool requires attention from the code perspective and ensuring that each retrieved connection uses appropriate database, schema, warehouse or role. +This mode is purely for backward compatibility but is not recommended to be used. It is also not a default. + +```cs +var connectionString = ConnectionString + ";ChangedSession=OriginalPool;MinPoolSize=1;MaxPoolSize=1"; +var connection = new SnowflakeDbConnection(connectionString); + +connection.Open(); +var randomSchemaName = Guid.NewGuid(); +connection.CreateCommand($"create schema \"{randomSchemaName}\").ExecuteNonQuery(); // schema gets changed +// application is running commands on a schema with random name +connection.Close(); // connection returns to the original pool but it's schema will no longer match with initial value + +var connection2 = new SnowflakeDbConnection(connectionString); +connection2.Open(); +// operations here will be performed against schema: randomSchemaName +``` + +#### Pool Size Exceeded Timeout + +The timeout for providing a connection when Max Pool Size is reached. +* When timeout to provide new connection is exceeded and there are no idle connections in the pool an exception will be thrown +* When specified as 0, an exception will be thrown immediately if there are no idle connections in the pool + +```cs +var connectionString = ConnectionString + ";MaxPoolSize=2;WaitingForIdleSessionTimeout=3"; + +Task[] tasks = new Task[8]; +for (int i = 0; i < tasks.Length; i++) +{ + var taskName = $"Task {i}"; + tasks[i] = Task.Run(() => + { + try + { + using (var connection = new SnowflakeDbConnection(connectionString)) + { + StopWatch sw = new StopWatch(); + + // register opening time + sw.Start(); + connection.Open(); + sw.Stop(); + + // output + Console.WriteLine($"{taskName} waited {Math.Round((double)sw.ElapsedMilliseconds / 1000)} seconds"); + + // wait 2s before closing the connection + Thread.Sleep(2000); + } + } + catch (SnowflakeDbException ex) + { + Console.WriteLine($"{taskName} - {ex.Message}"); + } + }); +} +Task.WaitAll(tasks); + +// check current pool size +var poolSize = SnowflakeDbConnectionPool.GetPool(connectionString).GetCurrentPoolSize(); +Assert.AreEqual(2, poolSize); + +// output: +// Task 3 waited 0 seconds +// Task 0 waited 0 seconds +// Task 5 waited 2 seconds +// Task 6 waited 2 seconds +// Task 4 - Error: Snowflake Internal Error: Unable to connect. Could not obtain a connection from the pool within a given timeout SqlState: 08006, VendorCode: 270001, QueryId: +// Task 7 - Error: Snowflake Internal Error: Unable to connect. Could not obtain a connection from the pool within a given timeout SqlState: 08006, VendorCode: 270001, QueryId: +// Task 1 - Error: Snowflake Internal Error: Unable to connect. Could not obtain a connection from the pool within a given timeout SqlState: 08006, VendorCode: 270001, QueryId: +// Task 2 - Error: Snowflake Internal Error: Unable to connect. Could not obtain a connection from the pool within a given timeout SqlState: 08006, VendorCode: 270001, QueryId: +``` + +#### Expiration Timeout + +Overall timeout for entire connection lifetime +* When reached connection is always removed +* After pruning, Min Pool Size is checked to achieve expected number of connections in the pool + +```cs +var connectionString = ConnectionString + ";MinPoolSize=1;ExpirationTimeout=2"; +var connection1 = new SnowflakeDbConnection(connectionString); +var connection2 = new SnowflakeDbConnection(connectionString); +var connection3 = new SnowflakeDbConnection(connectionString); + +connection1.Open(); +connection2.Open(); +connection1.Close(); +connection2.Close(); + +// 2 connections are in the pool +Assert.AreEqual(2, SnowflakeDbConnectionPool.GetPool(connectionString).GetCurrentPoolSize()); + +Thread.Sleep(2000); + +connection3.Open(); +connection3.Close(); + +// both previous connections have expired +Assert.AreEqual(1, SnowflakeDbConnectionPool.GetPool(connectionString).GetCurrentPoolSize()); +``` + +#### Connection Timeout + +Total timeout in seconds when connecting to Snowflake. +Equivalent of https://learn.microsoft.com/en-us/dotnet/api/system.data.idbconnection.connectiontimeout?view=net-6.0 + +```cs +var connectionString = ConnectionString + ";connection_timeout=160"; +using (var connection = new SnowflakeDbConnection(connectionString)) +{ + connection.Open(); +} +``` + +#### Pooling Enabled + +Enables or disables connection pooling for the pool identified by a given connection string. + +For security reasons pooling is disabled by default for External Browser or Key-Pair Authentication (unless password for key is provided). + +It can be enabled with a connection string parameter if needed. +However, be warned that using: +- token key file accessible by others and used to authorize connection +- shared environment with an external browser authenticated connections +leads to vulnerabilities and is not recommended. + +```cs +var connectionString = ConnectionString + ";PoolingEnabled=false"; +using (var connection = new SnowflakeDbConnection(connectionString)) +{ + connection.Open(); +} + +// no connection in the pool +var poolSize = SnowflakeDbConnectionPool.GetPool(connectionString).GetCurrentPoolSize(); +Assert.AreEqual(0, poolSize); +``` + +#### Current Pool Size + +Allows to check size of the given pool programatically. It is total number of all the connections: idle, busy and during initialization. + +```cs +var pool = SnowflakeDbConnectionPool.GetPool(connectionString); +var poolSize = pool.GetCurrentPoolSize(); +// default pool size is 2 +Assert.AreEqual(2, poolSize); +``` + +At the SnowflakeDbConnectionPool there is also a way to get sum of connections from all the pools. + +```cs +var pool1 = SnowflakeDbConnectionPool.GetPool(connectionString + ";MinPoolSize=2"); +var pool2 = SnowflakeDbConnectionPool.GetPool(connectionString + ";MinPoolSize=3"); +var poolsSize = SnowflakeDbConnectionPool.GetCurrentPoolSize(); +Assert.AreEqual(5, poolSize); +``` + +#### Clear Pool + +Interface allows to clear a particular pool or all the pools initiated by an application. +Please keep in mind that a default of min pool size will be maintained. + +```cs +var pool = SnowflakeDbConnectionPool.GetPool(connectionString); +``` + +There is also a way to clear all the pools initiated by an application. + +```cs +SnowflakeDbConnectionPool.ClearAllPools(); +``` diff --git a/doc/ConnectionPoolingDeprecated.md b/doc/ConnectionPoolingDeprecated.md new file mode 100644 index 000000000..f5f6005e9 --- /dev/null +++ b/doc/ConnectionPoolingDeprecated.md @@ -0,0 +1,70 @@ +## Using Connection Pools + +### Single Connection Pool (DEPRECATED) + +DEPRECATED VERSION + +Instead of creating a connection each time your client application needs to access Snowflake, you can define a cache of Snowflake connections that can be reused as needed. +Connection pooling usually reduces the lag time to make a connection. However, it can slow down client failover to an alternative DNS when a DNS problem occurs. + +The Snowflake .NET driver provides the following functions for managing connection pools. + +| Function | Description | +|-------------------------------------------------|---------------------------------------------------------------------------------------------------------| +| SnowflakeDbConnectionPool.ClearAllPools() | Removes all connections from the connection pool. | +| SnowflakeDbConnection.SetMaxPoolSize(n) | Sets the maximum number of connections for the connection pool, where _n_ is the number of connections. | +| SnowflakeDBConnection.SetTimeout(n) | Sets the number of seconds to keep an unresponsive connection in the connection pool. | +| SnowflakeDbConnectionPool.GetCurrentPoolSize() | Returns the number of connections currently in the connection pool. | +| SnowflakeDbConnectionPool.SetPooling() | Determines whether to enable (`true`) or disable (`false`) connection pooling. Default: `true`. | + +The following sample demonstrates how to monitor the size of a connection pool as connections are added and dropped from the pool. + +```cs +public void TestConnectionPoolClean() +{ + SnowflakeDbConnectionPool.ClearAllPools(); + SnowflakeDbConnectionPool.SetMaxPoolSize(2); + var conn1 = new SnowflakeDbConnection(); + conn1.ConnectionString = ConnectionString; + conn1.Open(); + Assert.AreEqual(ConnectionState.Open, conn1.State); + + var conn2 = new SnowflakeDbConnection(); + conn2.ConnectionString = ConnectionString + " retryCount=1"; + conn2.Open(); + Assert.AreEqual(ConnectionState.Open, conn2.State); + Assert.AreEqual(0, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + conn1.Close(); + conn2.Close(); + Assert.AreEqual(2, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + var conn3 = new SnowflakeDbConnection(); + conn3.ConnectionString = ConnectionString + " retryCount=2"; + conn3.Open(); + Assert.AreEqual(ConnectionState.Open, conn3.State); + + var conn4 = new SnowflakeDbConnection(); + conn4.ConnectionString = ConnectionString + " retryCount=3"; + conn4.Open(); + Assert.AreEqual(ConnectionState.Open, conn4.State); + + conn3.Close(); + Assert.AreEqual(2, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + conn4.Close(); + Assert.AreEqual(2, SnowflakeDbConnectionPool.GetCurrentPoolSize()); + + Assert.AreEqual(ConnectionState.Closed, conn1.State); + Assert.AreEqual(ConnectionState.Closed, conn2.State); + Assert.AreEqual(ConnectionState.Closed, conn3.State); + Assert.AreEqual(ConnectionState.Closed, conn4.State); +} +``` + +Note +Some of the features and configurations available for [Multiple Connection Pools](ConnectionPooling.md) are not available for the old pool. +Following configurations/settings have no effect on [Single Connection Pool](ConnectionPoolingDeprecated.md): +- `poolingEnabled` setting, feature not configurable by connection string, instead you could use `SnowflakeDbConnectionPool.SetPooling(false)` +- `changedSession` setting, only `OriginalPool` behavior available +- `maxPoolSize` setting, feature not configurable by connection string, instead you could use `SnowflakeDbConnectionPool.SetMaxPoolSize()` +- `minPoolSize` setting, feature not available +- `waitingForIdleSessionTimeout` setting, feature not available +- `expirationTimeout` setting, feature not configurable by connection string, instead you could use `SnowflakeDbConnectionPool.SetTimeout()`. diff --git a/doc/DataTypes.md b/doc/DataTypes.md new file mode 100644 index 000000000..7db51cdb7 --- /dev/null +++ b/doc/DataTypes.md @@ -0,0 +1,41 @@ +## Data Types and Formats + +## Mapping .NET and Snowflake Data Types + +The .NET driver supports the following mappings from .NET to Snowflake data types. + +| .NET Framekwork Data Type | Data Type in Snowflake | +| ------------------------- | ---------------------- | +| `int`, `long` | `NUMBER(38, 0)` | +| `decimal` | `NUMBER(38, )` | +| `double` | `REAL` | +| `string` | `TEXT` | +| `bool` | `BOOLEAN` | +| `byte` | `BINARY` | +| `datetime` | `DATE` | + +## Arrow data format + +The .NET connector, starting with v2.1.3, supports the [Arrow data format](https://arrow.apache.org/) +as a [preview](https://docs.snowflake.com/en/release-notes/preview-features) feature for data transfers +between Snowflake and a .NET client. The Arrow data format avoids extra +conversions between binary and textual representations of the data. The Arrow +data format can improve performance and reduce memory consumption in clients. + +The data format is controlled by the +DOTNET_QUERY_RESULT_FORMAT parameter. To use Arrow format, execute: + +```snowflake +-- at the session level +ALTER SESSION SET DOTNET_QUERY_RESULT_FORMAT = ARROW; +-- or at the user level +ALTER USER SET DOTNET_QUERY_RESULT_FORMAT = ARROW; +-- or at the account level +ALTER ACCOUNT SET DOTNET_QUERY_RESULT_FORMAT = ARROW; +``` + +The valid values for the parameter are: + +- ARROW +- JSON (default) + diff --git a/doc/Disconnecting.md b/doc/Disconnecting.md new file mode 100644 index 000000000..b6fd5b18a --- /dev/null +++ b/doc/Disconnecting.md @@ -0,0 +1,21 @@ +## Close the Connection + +To close the connection, call the `Close` method of `SnowflakeDbConnection`. + +If you want to avoid blocking threads while the connection is closing, call the `CloseAsync` method instead, passing in a +`CancellationToken`. This method was introduced in the v2.0.4 release. + +Note that because this method is not available in the generic `IDbConnection` interface, you must cast the object as +`SnowflakeDbConnection` before calling the method. For example: + +```cs +CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(); +// Close the connection +((SnowflakeDbConnection)conn).CloseAsync(cancellationTokenSource.Token); +``` + +## Evict the Connection + +For the open connection, call the `PreventPooling()` to mark the connection to be removed on close instead being still pooled. +The busy sessions counter will be decreased when the connection is closed. + diff --git a/doc/Logging.md b/doc/Logging.md new file mode 100644 index 000000000..18c235e7e --- /dev/null +++ b/doc/Logging.md @@ -0,0 +1,82 @@ +## Logging + +The Snowflake Connector for .NET uses [log4net](http://logging.apache.org/log4net/) as the logging framework. + +Here is a sample app.config file that uses [log4net](http://logging.apache.org/log4net/) + +```xml + +
+ + + + + + + + + + + + + + + + + + + + + +``` + +## Easy logging + +The Easy Logging feature lets you change the log level for all driver classes and add an extra file appender for logs from the driver's classes at runtime. You can specify the log levels and the directory in which to save log files in a configuration file (default: `sf_client_config.json`). + +You typically change log levels only when debugging your application. + +**Note** +This logging configuration file features support only the following log levels: + +- OFF +- ERROR +- WARNING +- INFO +- DEBUG +- TRACE + +This configuration file uses JSON to define the `log_level` and `log_path` logging parameters, as follows: + +```json +{ + "common": { + "log_level": "INFO", + "log_path": "c:\\some-path\\some-directory" + } +} +``` + +where: + +- `log_level` is the desired logging level. +- `log_path` is the location to store the log files. The driver automatically creates a `dotnet` subdirectory in the specified `log_path`. For example, if you set log_path to `c:\logs`, the drivers creates the `c:\logs\dotnet` directory and stores the logs there. + +The driver looks for the location of the configuration file in the following order: + +- `CLIENT_CONFIG_FILE` connection parameter, containing the full path to the configuration file (e.g. `"ACCOUNT=test;USER=test;PASSWORD=test;CLIENT_CONFIG_FILE=C:\\some-path\\client_config.json;"`) +- `SF_CLIENT_CONFIG_FILE` environment variable, containing the full path to the configuration file. +- .NET driver/application directory, where the file must be named `sf_client_config.json`. +- User’s home directory, where the file must be named `sf_client_config.json`. + +**Note** +To enhance security, the driver no longer searches a temporary directory for easy logging configurations. Additionally, the driver now requires the logging configuration file on Unix-style systems to limit file permissions to allow only the file owner to modify the files (such as `chmod 0600` or `chmod 0644`). + +To minimize the number of searches for a configuration file, the driver reads the file only for: + +- The first connection. +- The first connection with `CLIENT_CONFIG_FILE` parameter. + +The extra logs are stored in a `dotnet` subfolder of the specified directory, such as `C:\some-path\some-directory\dotnet`. + +If a client uses the `log4net` library for application logging, enabling easy logging affects the log level in those logs as well. diff --git a/doc/QueryingData.md b/doc/QueryingData.md new file mode 100644 index 000000000..cec2323bb --- /dev/null +++ b/doc/QueryingData.md @@ -0,0 +1,252 @@ +## Run a Query and Read Data + +```cs +using (IDbConnection conn = new SnowflakeDbConnection()) +{ + conn.ConnectionString = connectionString; + conn.Open(); + + IDbCommand cmd = conn.CreateCommand(); + cmd.CommandText = "select * from t"; + IDataReader reader = cmd.ExecuteReader(); + + while(reader.Read()) + { + Console.WriteLine(reader.GetString(0)); + } + + conn.Close(); +} +``` + +Note that for a `TIME` column, the reader returns a `System.DateTime` value. If you need a `System.TimeSpan` column, call the +`getTimeSpan` method in `SnowflakeDbDataReader`. This method was introduced in the v2.0.4 release. + +Note that because this method is not available in the generic `IDataReader` interface, you must cast the object as +`SnowflakeDbDataReader` before calling the method. For example: + +```cs +TimeSpan timeSpanTime = ((SnowflakeDbDataReader)reader).GetTimeSpan(13); +``` + +## Execute a query asynchronously on the server + +You can run the query asynchronously on the server. The server responds immediately with `queryId` and continues to execute the query asynchronously. +Then you can use this `queryId` to check the query status or wait until the query is completed and get the results. +It is fine to start the query in one session and continue to query for the results in another one based on the queryId. + +**Note**: There are 2 levels of asynchronous execution. One is asynchronous execution in terms of C# language (`async await`). +Another is asynchronous execution of the query by the server (you can recognize it by `InAsyncMode` containing method names, e. g. `ExecuteInAsyncMode`, `ExecuteAsyncInAsyncMode`). + +Example of synchronous code starting a query to be executed asynchronously on the server: +```cs +using (SnowflakeDbConnection conn = new SnowflakeDbConnection("account=testaccount;username=testusername;password=testpassword")) +{ + conn.Open(); + SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand(); + cmd.CommandText = "SELECT ..."; + var queryId = cmd.ExecuteInAsyncMode(); + // ... +} +``` + +Example of asynchronous code starting a query to be executed asynchronously on the server: +```cs +using (SnowflakeDbConnection conn = new SnowflakeDbConnection("account=testaccount;username=testusername;password=testpassword")) +{ + await conn.OpenAsync(CancellationToken.None).ConfigureAwait(false); + SnowflakeDbCommand cmd = (SnowflakeDbCommand)conn.CreateCommand()) + cmd.CommandText = "SELECT ..."; + var queryId = await cmd.ExecuteAsyncInAsyncMode(CancellationToken.None).ConfigureAwait(false); + // ... +} +``` + +You can check the status of a query executed asynchronously on the server either in synchronous code: +```cs +var queryStatus = cmd.GetQueryStatus(queryId); +Assert.IsTrue(conn.IsStillRunning(queryStatus)); // assuming that the query is still running +Assert.IsFalse(conn.IsAnError(queryStatus)); // assuming that the query has not finished with error +``` +or the same in an asynchronous code: +```cs +var queryStatus = await cmd.GetQueryStatusAsync(queryId, CancellationToken.None).ConfigureAwait(false); +Assert.IsTrue(conn.IsStillRunning(queryStatus)); // assuming that the query is still running +Assert.IsFalse(conn.IsAnError(queryStatus)); // assuming that the query has not finished with error +``` + +The following example shows how to get query results. +The operation will repeatedly check the query status until the query is completed or timeout happened or reaching the maximum number of attempts. +The synchronous code example: +```cs +DbDataReader reader = cmd.GetResultsFromQueryId(queryId); +``` +and the asynchronous code example: +```cs +DbDataReader reader = await cmd.GetResultsFromQueryIdAsync(queryId, CancellationToken.None).ConfigureAwait(false); +``` + +**Note**: GET/PUT operations are currently not enabled for asynchronous executions. + +## Executing a Batch of SQL Statements (Multi-Statement Support) + +With version 2.0.18 and later of the .NET connector, you can send +a batch of SQL statements, separated by semicolons, +to be executed in a single request. + +**Note**: Snowflake does not currently support variable binding in multi-statement SQL requests. + +--- + +**Note** + +By default, Snowflake returns an error for queries issued with multiple statements to protect against SQL injection attacks. The multiple statements feature makes your system more vulnerable to SQL injections, and so it should be used carefully. You can reduce the risk by using the MULTI_STATEMENT_COUNT parameter to specify the number of statements to be executed, which makes it more difficult to inject a statement by appending to it. + +--- + +You can execute multiple statements as a batch in the same way you execute queries with single statements, except that the query string contains multiple statements separated by semicolons. Note that multiple statements execute sequentially, not in parallel. + +You can set this parameter at the session level using the following command: + +``` +ALTER SESSION SET MULTI_STATEMENT_COUNT = <0/1>; +``` + +where: + +- **0**: Enables an unspecified number of SQL statements in a query. + + Using this value allows batch queries to contain any number of SQL statements without needing to specify the MULTI_STATEMENT_COUNT statement parameter. However, be aware that using this value reduces the protection against SQL injection attacks. + +- **1**: Allows one SQL statement or a specified number of statement in a query string (default). + + You must include MULTI_STATEMENT_COUNT as a statement parameter to specify the number of statements included when the query string contains more than one statement. If the number of statements sent in the query string does not match the MULTI_STATEMENT_COUNT value, the .NET driver rejects the request. You can, however, omit this parameter if you send a single statement. + +The following example sets the MULTI_STATEMENT_COUNT session parameter to 1. Then for an individual command, it sets MULTI_STATEMENT_COUNT=3 to indicate that the query contains precisely three SQL commands. The query string, `cmd.CommandText` , then contains the three statements to execute. + +```cs +using (IDbConnection conn = new SnowflakeDbConnection()) +{ + conn.ConnectionString = ConnectionString; + conn.Open(); + IDbCommand cmd = conn.CreateCommand(); + cmd.CommandText = "ALTER SESSION SET MULTI_STATEMENT_COUNT = 1;"; + cmd.ExecuteNonQuery(); + conn.Close(); +} + +using (DbCommand cmd = conn.CreateCommand()) +{ + // Set statement count + var stmtCountParam = cmd.CreateParameter(); + stmtCountParam.ParameterName = "MULTI_STATEMENT_COUNT"; + stmtCountParam.DbType = DbType.Int16; + stmtCountParam.Value = 3; + cmd.Parameters.Add(stmtCountParam); + cmd.CommandText = "CREATE OR REPLACE TABLE test(n int); INSERT INTO test values(1), (2); SELECT * FROM test ORDER BY n; + DbDataReader reader = cmd.ExecuteReader(); + do + { + if (reader.HasRow) + { + while (reader.Read()) + { + // read data + } + } + } + while (reader.NextResult()); +} +``` + +## Bind Parameter + +**Note**: Snowflake does not currently support variable binding in multi-statement SQL requests. + +This example shows how bound parameters are converted from C# data types to +Snowflake data types. For example, if the data type of the Snowflake column +is INTEGER, then you can bind C# data types Int32 or Int16. + +This example inserts 3 rows into a table with one column. + +```cs +using (IDbConnection conn = new SnowflakeDbConnection()) +{ + conn.ConnectionString = connectionString; + conn.Open(); + + IDbCommand cmd = conn.CreateCommand(); + cmd.CommandText = "create or replace table T(cola int)"; + int count = cmd.ExecuteNonQuery(); + Assert.AreEqual(0, count); + + IDbCommand cmd = conn.CreateCommand(); + cmd.CommandText = "insert into t values (?), (?), (?)"; + + var p1 = cmd.CreateParameter(); + p1.ParameterName = "1"; + p1.Value = 10; + p1.DbType = DbType.Int32; + cmd.Parameters.Add(p1); + + var p2 = cmd.CreateParameter(); + p2.ParameterName = "2"; + p2.Value = 10000L; + p2.DbType = DbType.Int32; + cmd.Parameters.Add(p2); + + var p3 = cmd.CreateParameter(); + p3.ParameterName = "3"; + p3.Value = (short)1; + p3.DbType = DbType.Int16; + cmd.Parameters.Add(p3); + + var count = cmd.ExecuteNonQuery(); + Assert.AreEqual(3, count); + + cmd.CommandText = "drop table if exists T"; + count = cmd.ExecuteNonQuery(); + Assert.AreEqual(0, count); + + conn.Close(); +} +``` + +## Bind Array Variables + +The sample code creates a table with a single integer column and then uses array binding to populate the table with values 0 to 70000. + +```cs +using (IDbConnection conn = new SnowflakeDbConnection()) +{ + conn.ConnectionString = ConnectionString; + conn.Open(); + + using (IDbCommand cmd = conn.CreateCommand()) + { + cmd.CommandText = "create or replace table putArrayBind(colA integer)"; + cmd.ExecuteNonQuery(); + + string insertCommand = "insert into putArrayBind values (?)"; + cmd.CommandText = insertCommand; + + int total = 70000; + + List arrint = new List(); + for (int i = 0; i < total; i++) + { + arrint.Add(i); + } + var p1 = cmd.CreateParameter(); + p1.ParameterName = "1"; + p1.DbType = DbType.Int16; + p1.Value = arrint.ToArray(); + cmd.Parameters.Add(p1); + + count = cmd.ExecuteNonQuery(); // count = 70000 + } + + conn.Close(); +} +``` + diff --git a/doc/StageFiles.md b/doc/StageFiles.md new file mode 100644 index 000000000..aa59b82b9 --- /dev/null +++ b/doc/StageFiles.md @@ -0,0 +1,64 @@ +## PUT local files to stage + +PUT command can be used to upload files of a local directory or a single local file to the Snowflake stages (named, internal table stage or internal user stage). +Such staging files can be used to load data into a table. +More on this topic: [File staging with PUT](https://docs.snowflake.com/en/sql-reference/sql/put). + +In the driver the command can be executed in a bellow way: + +```cs +using (IDbConnection conn = new SnowflakeDbConnection()) +{ + try + { + conn.ConnectionString = ""; + conn.Open(); + var cmd = (SnowflakeDbCommand)conn.CreateCommand(); // cast allows get QueryId from the command + + cmd.CommandText = "PUT file://some_data.csv @my_schema.my_stage AUTO_COMPRESS=TRUE"; + var reader = cmd.ExecuteReader(); + Assert.IsTrue(reader.read()); + Assert.DoesNotThrow(() => Guid.Parse(cmd.GetQueryId())); + } + catch (SnowflakeDbException e) + { + Assert.DoesNotThrow(() => Guid.Parse(e.QueryId)); // when failed + Assert.That(e.InnerException.GetType(), Is.EqualTo(typeof(FileNotFoundException))); + } +``` + +In case of a failure a SnowflakeDbException exception will be thrown with affected QueryId if possible. +If it was after the query got executed this exception will be a SnowflakeDbException containing affected QueryId. +In case of the initial phase of execution QueryId might not be provided. +Inner exception (if applicable) will provide some details on the failure cause and +it will be for example: FileNotFoundException, DirectoryNotFoundException. + +## GET stage files + +GET command allows to download stage directories or files to a local directory. +It can be used in connection with named stage, table internal stage or user stage. +Detailed information on the command: [Downloading files with GET](https://docs.snowflake.com/en/sql-reference/sql/get). + +To use the command in a driver similar code can be executed in a client app: + +```cs + try + { + conn.ConnectionString = ""; + conn.Open(); + var cmd = (SnowflakeDbCommand)conn.CreateCommand(); // cast allows get QueryId from the command + + cmd.CommandText = "GET @my_schema.my_stage/stage_file.csv file://local_file.csv AUTO_COMPRESS=TRUE"; + var reader = cmd.ExecuteReader(); + Assert.IsTrue(reader.read()); // True on success, False if failure + Assert.DoesNotThrow(() => Guid.Parse(cmd.GetQueryId())); + } + catch (SnowflakeDbException e) + { + Assert.DoesNotThrow(() => Guid.Parse(e.QueryId)); // on failure + } +``` + +In case of a failure a SnowflakeDbException will be thrown with affected QueryId if possible. +When no technical or syntax errors occurred but the DBDataReader has no data to process it returns False +without throwing an exception. diff --git a/doc/Testing.md b/doc/Testing.md new file mode 100644 index 000000000..70ee63f28 --- /dev/null +++ b/doc/Testing.md @@ -0,0 +1,54 @@ +# Testing the Connector + +Before running tests, create a parameters.json file under Snowflake.Data.Tests\ directory. In this file, specify username, password and account info that tests will run against. Here is a sample parameters.json file + +``` +{ + "testconnection": { + "SNOWFLAKE_TEST_USER": "snowman", + "SNOWFLAKE_TEST_PASSWORD": "XXXXXXX", + "SNOWFLAKE_TEST_ACCOUNT": "TESTACCOUNT", + "SNOWFLAKE_TEST_WAREHOUSE": "TESTWH", + "SNOWFLAKE_TEST_DATABASE": "TESTDB", + "SNOWFLAKE_TEST_SCHEMA": "TESTSCHEMA", + "SNOWFLAKE_TEST_ROLE": "TESTROLE", + "SNOWFLAKE_TEST_HOST": "testaccount.snowflakecomputing.com" + } +} +``` + +## Command Prompt + +The build solution file builds the connector and tests binaries. Issue the following command from the command line to run the tests. The test binary is located in the Debug directory if you built the solution file in Debug mode. + +```{r, engine='bash', code_block_name} +cd Snowflake.Data.Tests +dotnet test -f net6.0 -l "console;verbosity=normal" +``` + +Tests can also be run under code coverage: + +```{r, engine='bash', code_block_name} +dotnet-coverage collect "dotnet test --framework net6.0 --no-build -l console;verbosity=normal" --output net6.0_coverage.xml --output-format cobertura --settings coverage.config +``` + +You can run only specific suite of tests (integration or unit). + +Running unit tests: + +```bash +cd Snowflake.Data.Tests +dotnet test -l "console;verbosity=normal" --filter FullyQualifiedName~UnitTests -l console;verbosity=normal +``` + +Running integration tests: + +```bash +cd Snowflake.Data.Tests +dotnet test -l "console;verbosity=normal" --filter FullyQualifiedName~IntegrationTests +``` + +## Visual Studio 2017 + +Tests can also be run under Visual Studio 2017. Open the solution file in Visual Studio 2017 and run tests using Test Explorer. +