diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..ef1e7a9 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +* text=auto eol=lf +*.java text diff=java diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..f15fa11 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,35 @@ +name: Build and test + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 2 + - name: Set up JDK 11 + uses: actions/setup-java@v4 + with: + java-version: '11' + distribution: 'temurin' + cache: maven + - name: Get hyperd version + id: evaluate-property + run: | + echo "HYPER_VERSION=$(mvn help:evaluate -Dexpression=hyperapi.version -q -DforceStdout)" >> $GITHUB_ENV + - name: Cache hyperd + uses: actions/cache@v3 + with: + path: | + target/.cache + key: ${{ runner.os }}-hyper-${{ env.HYPER_VERSION }} + restore-keys: | + ${{ runner.os }}-hyper-${{ env.HYPER_VERSION }} + - name: Maven package + run: mvn --batch-mode --no-transfer-progress clean package --file pom.xml diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..68e9e88 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,42 @@ +name: Release to staging + +on: + release: + types: [ "created" ] + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up JDK 11 + uses: actions/setup-java@v4 + with: + java-version: '11' + distribution: 'temurin' + server-id: ossrh + server-username: 'MAVEN_USERNAME' + server-password: 'MAVEN_PASSWORD' + gpg-private-key: ${{ secrets.GPG_SIGNING_KEY }} + gpg-passphrase: 'MAVEN_GPG_PASSPHRASE' + - name: Get hyperd version + id: evaluate-property + run: | + echo "HYPER_VERSION=$(mvn help:evaluate -Dexpression=hyperapi.version -q -DforceStdout)" >> $GITHUB_ENV + - name: Cache hyperd + uses: 
actions/cache@v3 + with: + path: | + target/.cache + key: ${{ runner.os }}-hyper-${{ env.HYPER_VERSION }} + restore-keys: | + ${{ runner.os }}-hyper-${{ env.HYPER_VERSION }} + - name: Set version + run: mvn versions:set --no-transfer-progress -DnewVersion=${{ github.event.release.tag_name }} + - name: Build with Maven + run: mvn --batch-mode --no-transfer-progress clean deploy -P release --file pom.xml + env: + MAVEN_USERNAME: ${{ secrets.CENTRAL_TOKEN_USERNAME }} + MAVEN_PASSWORD: ${{ secrets.CENTRAL_TOKEN_PASSWORD }} + MAVEN_GPG_PASSPHRASE: ${{ secrets.GPG_SIGNING_KEY_PASSWORD }} \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d467b07 --- /dev/null +++ b/.gitignore @@ -0,0 +1,22 @@ +.idea +!.idea/externalDependencies.xml +!.idea/palantir-java-format.xml + +.DS_Store + +target/ +pom.xml.tag +pom.xml.releaseBackup +pom.xml.versionsBackup +pom.xml.next +release.properties +dependency-reduced-pom.xml +buildNumber.properties +.mvn/timing.properties +.mvn/wrapper/maven-wrapper.jar +.project +.classpath +src/main/resources/config/config.properties + +*.iml +pom.xml.bak diff --git a/.hooks/pre-commit b/.hooks/pre-commit new file mode 100755 index 0000000..8c878af --- /dev/null +++ b/.hooks/pre-commit @@ -0,0 +1,6 @@ +#!/usr/bin/env bash +set -e + +echo '[git pre-commit] mvn spotless:apply sortpom:sort' +MAVEN_OPTS='-Dorg.slf4j.simpleLogger.defaultLogLevel=error' mvn spotless:apply sortpom:sort +git add --update \ No newline at end of file diff --git a/.idea/externalDependencies.xml b/.idea/externalDependencies.xml new file mode 100644 index 0000000..faf3cba --- /dev/null +++ b/.idea/externalDependencies.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/.idea/palantir-java-format.xml b/.idea/palantir-java-format.xml new file mode 100644 index 0000000..3815718 --- /dev/null +++ b/.idea/palantir-java-format.xml @@ -0,0 +1,6 @@ + + + + + \ No newline at end of file diff --git a/CODEOWNERS 
b/CODEOWNERS index 6010d8c..0393ef9 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,3 +1,4 @@ # Comment line immediately above ownership line is reserved for related other information. Please be careful while editing. -#ECCN:Open Source -#GUSINFO:Open Source,Open Source Workflow \ No newline at end of file +#ECCN: 5D002.c.1 +#GUSINFO:Open Source,Open Source Workflow +* datacloud-query-connector-owners@salesforce.com diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9bdfbf2..2d2ecb9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,46 +1,27 @@ -*This is a suggested `CONTRIBUTING.md` file template for use by open sourced Salesforce projects. The main goal of this file is to make clear the intents and expectations that end-users may have regarding this project and how/if to engage with it. Adjust as needed (especially look for `{project_slug}` which refers to the org and repo name of your project) and remove this paragraph before committing to your repo.* +# Contributing Guide For Data Cloud JDBC Driver -# Contributing Guide For {NAME OF PROJECT} - -This page lists the operational governance model of this project, as well as the recommendations and requirements for how to best contribute to {PROJECT}. We strive to obey these as best as possible. As always, thanks for contributing – we hope these guidelines make it easier and shed some light on our approach and processes. +This page lists the operational governance model of this project, as well as the recommendations and requirements for how to best contribute to Data Cloud JDBC Driver. We strive to obey these as best as possible. As always, thanks for contributing – we hope these guidelines make it easier and shed some light on our approach and processes. # Governance Model -> Pick the most appropriate one - -## Community Based - -The intent and goal of open sourcing this project is to increase the contributor and user base. 
The governance model is one where new project leads (`admins`) will be added to the project based on their contributions and efforts, a so-called "do-acracy" or "meritocracy" similar to that used by all Apache Software Foundation projects. - -> or ## Salesforce Sponsored The intent and goal of open sourcing this project is to increase the contributor and user base. However, only Salesforce employees will be given `admin` rights and will be the final arbitrars of what contributions are accepted or not. -> or - -## Published but not supported - -The intent and goal of open sourcing this project is because it may contain useful or interesting code/concepts that we wish to share with the larger open source community. Although occasional work may be done on it, we will not be looking for or soliciting contributions. - -# Getting started - -Please join the community on {Here list Slack channels, Email lists, Glitter, Discord, etc... links}. Also please make sure to take a look at the project [roadmap](ROADMAP.md) to see where are headed. - # Issues, requests & ideas Use GitHub Issues page to submit issues, enhancement requests and discuss ideas. ### Bug Reports and Fixes -- If you find a bug, please search for it in the [Issues](https://github.com/{project_slug}/issues), and if it isn't already tracked, - [create a new issue](https://github.com/{project_slug}/issues/new). Fill out the "Bug Report" section of the issue template. Even if an Issue is closed, feel free to comment and add details, it will still +- If you find a bug, please search for it in the [Issues](https://github.com/forcedotcom/datacloud-jdbc/issues), and if it isn't already tracked, + [create a new issue](https://github.com/forcedotcom/datacloud-jdbc/issues/new). Fill out the "Bug Report" section of the issue template. Even if an Issue is closed, feel free to comment and add details, it will still be reviewed. 
- Issues that have already been identified as a bug (note: able to reproduce) will be labelled `bug`. - If you'd like to submit a fix for a bug, [send a Pull Request](#creating_a_pull_request) and mention the Issue number. - Include tests that isolate the bug and verifies that it was fixed. ### New Features -- If you'd like to add new functionality to this project, describe the problem you want to solve in a [new Issue](https://github.com/{project_slug}/issues/new). +- If you'd like to add new functionality to this project, describe the problem you want to solve in a [new Issue](https://github.com/forcedotcom/datacloud-jdbc/issues/new). - Issues that have been identified as a feature request will be labelled `enhancement`. - If you'd like to implement the new feature, please wait for feedback from the project maintainers before spending too much time writing the code. In some cases, `enhancement`s may @@ -51,7 +32,7 @@ Use GitHub Issues page to submit issues, enhancement requests and discuss ideas. alternative implementation of something that may have advantages over the way its currently done, or you have any other change, we would be happy to hear about it! - If its a trivial change, go ahead and [send a Pull Request](#creating_a_pull_request) with the changes you have in mind. - - If not, [open an Issue](https://github.com/{project_slug}/issues/new) to discuss the idea first. + - If not, [open an Issue](https://github.com/forcedotcom/datacloud-jdbc/issues/new) to discuss the idea first. If you're new to our project and looking for some way to make your first contribution, look for Issues labelled `good first contribution`. diff --git a/LICENSE.txt b/LICENSE.txt index ae7332a..c2516fc 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -191,7 +191,7 @@ Apache License same "printed page" as the copyright notice for easier identification within third-party archives. 
- Copyright {yyyy} {name of copyright owner} + Copyright 2024 Salesforce Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/README.md b/README.md index b2beac8..4164c31 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,178 @@ -# README +# Salesforce DataCloud JDBC Driver -A repo containing all the basic file templates and general guidelines for any open source project at Salesforce. +With the Salesforce Data Cloud JDBC driver you can efficiently query millions of rows of data with low latency, and perform bulk data extractions. +This driver is read-only and forward-only. +It requires Java 11 or greater. + + +## Getting started + +To add the driver to your project, add the following Maven dependency: + +```xml + + com.salesforce.datacloud + jdbc + ${jdbc.version} + +``` + +The class name for this driver is: + +``` +com.salesforce.datacloud.jdbc.DataCloudJDBCDriver +``` + +## Building the driver: + +Use the following command to build and test the driver: + +```shell +mvn clean install +``` ## Usage -It's required that all files must be placed at the top level of your repository. +### Connection string + +Use `jdbc:salesforce-datacloud://login.salesforce.com` + +### JDBC Driver class + +Use `com.salesforce.datacloud.jdbc.DataCloudJDBCDriver` as the driver class name for the JDBC application. + +### Authentication + +We support three of the [OAuth authorization flows][oauth authorization flows] provided by Salesforce. +All of these flows require a connected app be configured for the driver to authenticate as, see the documentation here: [connected app overview][connected app overview]. 
+Set the following properties appropriately to establish a connection with your chosen OAuth authorization flow: + +| Parameter | Description | +|--------------|----------------------------------------------------------------------------------------------------------------------| +| user | The login name of the user. | +| password | The password of the user. | +| clientId | The consumer key of the connected app. | +| clientSecret | The consumer secret of the connected app. | +| privateKey | The private key of the connected app. | +| coreToken | OAuth token that a connected app uses to request access to a protected resource on behalf of the client application. | +| refreshToken | Token obtained from the web server, user-agent, or hybrid app token flow. | + + +#### username and password authentication: + +The documentation for username and password authentication can be found [here][username flow]. + +To configure username and password, set properties like so: + +```java +Properties properties = new Properties(); +properties.put("user", "${userName}"); +properties.put("password", "${password}"); +properties.put("clientId", "${clientId}"); +properties.put("clientSecret", "${clientSecret}"); +``` + +#### jwt authentication: + +The documentation for jwt authentication can be found [here][jwt flow]. + +Instructions to generate a private key can be found [here](#generating-a-private-key-for-jwt-authentication). + +```java +Properties properties = new Properties(); +properties.put("privateKey", "${privateKey}"); +properties.put("clientId", "${clientId}"); +properties.put("clientSecret", "${clientSecret}"); +``` + +#### refresh token authentication: + +The documentation for refresh token authentication can be found [here][refresh token flow]. 
+ +```java +Properties properties = new Properties(); +properties.put("coreToken", "${coreToken}"); +properties.put("refreshToken", "${refreshToken}"); +properties.put("clientId", "${clientId}"); +properties.put("clientSecret", "${clientSecret}"); +``` + +### Connection settings + +See this page on available [connection settings][connection settings]. +These settings can be configured in properties by using the prefix `serverSetting.` + +For example, to control locale set the following property: + +```java +properties.put("serverSetting.lc_time", "en_US"); +``` + +--- + +### Generating a private key for jwt authentication + +To authenticate using key-pair authentication you'll need to generate a certificate and register it with your connected app. + +```shell +# create a key pair: +openssl genrsa -out keypair.key 2048 +# create a digital certificate, follow the prompts: +openssl req -new -x509 -nodes -sha256 -days 365 -key keypair.key -out certificate.crt +# create a private key from the key pair: +openssl pkcs8 -topk8 -nocrypt -in keypair.key -out private.key +``` + +### Optional configuration + +- `dataspace`: The data space to query, defaults to "default" +- `User-Agent`: The User-Agent string identifies the JDBC driver and, optionally, the client application making the database connection.
+ By default, the User-Agent string will end with "salesforce-datacloud-jdbc/{version}" and we will prepend any User-Agent provided by the client application.
+ For example: "User-Agent: ClientApp/1.2.3 salesforce-datacloud-jdbc/1.0" + + +### Usage sample code + +```java +public static void executeQuery() throws ClassNotFoundException, SQLException { + Class.forName("com.salesforce.datacloud.jdbc.DataCloudJDBCDriver"); + + Properties properties = new Properties(); + properties.put("user", "${userName}"); + properties.put("password", "${password}"); + properties.put("clientId", "${clientId}"); + properties.put("clientSecret", "${clientSecret}"); + + try (var connection = DriverManager.getConnection("jdbc:salesforce-datacloud://login.salesforce.com", properties); + var statement = connection.createStatement()) { + var resultSet = statement.executeQuery("${query}"); + + while (resultSet.next()) { + // Iterate over the result set + } + } +} +``` + +## Generated assertions + +Some of our classes are tested using assertions generated with [the assertj assertions generator][assertion generator]. +Due to some transient test-compile issues we experienced, we checked in generated assertions for some of our classes. +If you make changes to any of these classes, you will need to re-run the assertion generator to have the appropriate assertions available for that class. + +To find examples of these generated assertions, look for files with the path `**/test/**/*Assert.java`. + +To re-generate these assertions execute the following command: + +```shell +mvn assertj:generate-assertions +``` -> **NOTE** Your README should contain detailed, useful information about the project! 
+[oauth authorization flows]: https://help.salesforce.com/s/articleView?id=sf.remoteaccess_oauth_flows.htm&type=5 +[username flow]: https://help.salesforce.com/s/articleView?id=sf.remoteaccess_oauth_username_password_flow.htm&type=5 +[jwt flow]: https://help.salesforce.com/s/articleView?id=sf.remoteaccess_oauth_jwt_flow.htm&type=5 +[refresh token flow]: https://help.salesforce.com/s/articleView?id=sf.remoteaccess_oauth_refresh_token_flow.htm&type=5 +[connection settings]: https://tableau.github.io/hyper-db/docs/hyper-api/connection#connection-settings +[assertion generator]: https://joel-costigliola.github.io/assertj/assertj-assertions-generator-maven-plugin.html#configuration +[connected app overview]: https://help.salesforce.com/s/articleView?id=sf.connected_app_overview.htm&type=5 \ No newline at end of file diff --git a/license-header.txt b/license-header.txt new file mode 100644 index 0000000..2ff21cf --- /dev/null +++ b/license-header.txt @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ \ No newline at end of file diff --git a/license_info.md b/license_info.md deleted file mode 100644 index b8da254..0000000 --- a/license_info.md +++ /dev/null @@ -1,224 +0,0 @@ -License Info ------------- - -Most projects we open source should use the [Apache License v2](https://opensource.org/license/apache-2-0/) license. 
Samples, demos, and blog / doc code examples should instead use [CC-0](https://creativecommons.org/publicdomain/zero/1.0/). If you strongly feel your project should perhaps use a different license clause, please engage with legal team. - -For the ALv2 license, create a `LICENSE.txt` file (or use the one in this template repo) in the root of your repo containing: -``` -Apache License Version 2.0 - -Copyright (c) 2023 Salesforce, Inc. -All rights reserved. - -Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. 
- - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. 
- - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS -``` - -The shorter version of license text should be added as a comment to all Salesforce-authored source code and configuration files that support comments. This include file formats like HTML, CSS, JavaScript, XML, etc. which aren't directly code, but are still critical to your project code. Like: -``` -/* - * Copyright (c) 2023, Salesforce, Inc. - * SPDX-License-Identifier: Apache-2 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - ``` - -Note that there are many tools that exist to do this sort of thing in an automated fashion, without having to manually edit every single file in your project. It is highly recommended that you research some of these tools for your particular language / build system. - -For sample, demo, and example code, we recommend the [Unlicense](https://opensource.org/license/unlicense/) license. Create a `LICENSE.txt` file containing: -``` -This is free and unencumbered software released into the public domain. - -Anyone is free to copy, modify, publish, use, compile, sell, or distribute this software, either in source code form or as a compiled binary, for any purpose, commercial or non-commercial, and by any means. - -In jurisdictions that recognize copyright laws, the author or authors of this software dedicate any and all copyright interest in the software to the public domain. We make this dedication for the benefit of the public at large and to the detriment of our heirs and successors. We intend this dedication to be an overt act of relinquishment in perpetuity of all present and future rights to this software under copyright law. - -THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -``` - -No license header is required for samples, demos, and example code. diff --git a/lombok.config b/lombok.config new file mode 100644 index 0000000..be5091c --- /dev/null +++ b/lombok.config @@ -0,0 +1,3 @@ +config.stopBubbling = true +lombok.nonNull.exceptionType = IllegalArgumentException +lombok.addLombokGeneratedAnnotation = true \ No newline at end of file diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000..dab03f8 --- /dev/null +++ b/pom.xml @@ -0,0 +1,819 @@ + + + 4.0.0 + com.salesforce.datacloud + jdbc + 0.20.0-SNAPSHOT + jar + Salesforce Data Cloud JDBC Driver + Salesforce Data Cloud JDBC Driver + + 18.0.0 + 3.26.3 + 1.25.0 + 1.17.1 + 3.17.0 + ${project.build.directory}/.cache + 3.5.0 + 0.13.0 + + 0.0.20746.reac9bd2d + ${project.build.directory}/hyper + 2.18.0 + 0.8.12 + 11 + 0.12.6 + 5.11.3 + 1.18.34 + {java.version} + {java.version} + 5.14.1 + 4.12.0 + UTF-8 + UTF-8 + 3.25.5 + com.salesforce.datacloud.jdbc.internal.shaded + 1.7.32 + 2.43.0 + + + + + org.apache.arrow + arrow-bom + ${arrow.version} + pom + import + + + org.junit + junit-bom + ${junit-bom.version} + pom + import + + + org.mockito + mockito-bom + ${mockito-bom.version} + pom + import + + + + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + + + com.google.protobuf + protobuf-java + ${protobuf-java.version} + + + com.squareup.okhttp3 + okhttp + ${okhttp.version} + + + commons-cli + commons-cli + 1.6.0 + + + commons-codec + commons-codec + ${commons-codec.version} + + + io.grpc + grpc-netty-shaded + 1.63.0 + + + io.grpc + grpc-protobuf + 1.63.0 + + + io.grpc + grpc-stub + 1.63.0 + + + io.jsonwebtoken + jjwt-api + ${jjwt.version} + + + javax.annotation + javax.annotation-api + 1.3.2 + + + net.jodah + failsafe + 2.4.4 
+ + + org.apache.arrow + arrow-vector + + + org.apache.calcite.avatica + avatica + ${avatica.version} + + + org.apache.commons + commons-collections4 + 4.4 + + + org.apache.commons + commons-lang3 + ${commons-lang3.version} + + + org.slf4j + slf4j-simple + ${slf4j.version} + + + org.projectlombok + lombok + ${lombok.version} + provided + + + io.jsonwebtoken + jjwt-gson + ${jjwt.version} + runtime + + + io.jsonwebtoken + jjwt-impl + ${jjwt.version} + runtime + + + org.apache.arrow + arrow-memory-netty + runtime + + + com.squareup.okhttp3 + mockwebserver + ${okhttp.version} + test + + + org.assertj + assertj-core + ${assertj.version} + test + + + org.grpcmock + grpcmock-junit5 + ${grpcmock.junit5.version} + test + + + org.junit.jupiter + junit-jupiter-engine + test + + + org.junit.jupiter + junit-jupiter-params + test + + + org.junit.platform + junit-platform-launcher + test + + + org.mockito + mockito-core + test + + + org.mockito + mockito-junit-jupiter + test + + + + + + true + src/main/resources + + + true + src/test/resources + + + + + org.xolstice.maven.plugins + protobuf-maven-plugin + 0.6.1 + + com.google.protobuf:protoc:3.19.6:exe:${os.detected.classifier} + grpc-java + io.grpc:protoc-gen-grpc-java:1.63.0:exe:${os.detected.classifier} + false + + + + + compile + compile-custom + + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.6.0 + + true + shaded + jdbc-shaded + false + + + org.apache + ${shadeBase}.apache + + + io.netty + ${shadeBase}.io.netty + + + io.grpc + ${shadeBase}.io.grpc + + + com.fasterxml.jackson + ${shadeBase}.com.fasterxml.jackson + + + io.jsonwebtoken + ${shadeBase}.io.jsonwebtoken + + + io.vavr + ${shadeBase}.io.vavr + + + com.squareup + ${shadeBase}.com.squareup + + + com.google + ${shadeBase}.com.google + + + net.jodah + ${shadeBase}.net.jodah + + + org.projectlombok + ${shadeBase}.org.projectlombok + + + javax.annotation + ${shadeBase}.javax.annotation + + + com.google.protobuf + ${shadeBase}.com.google.protobuf + + + 
commons-cli + ${shadeBase}.commons-cli + + + commons-codec + ${shadeBase}.commons-codec + + + org.slf4j + ${shadeBase}.org.slf4j + + + + + *:* + + META-INF/LICENSE* + META-INF/NOTICE* + META-INF/DEPENDENCIES + META-INF/maven/** + META-INF/services/com.fasterxml.* + META-INF/*.xml + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + .netbeans_automatic_build + git.properties + google-http-client.properties + storage.v1.json + + pipes-fork-server-default-log4j2.xml + dependencies.properties + pipes-fork-server-default-log4j2.xml + + + + org.apache.arrow:arrow-vector + + + codegen/** + + + + org.slf4j:slf4j-simple + + + org/slf4j/** + + + + + + + + + + + + shade + + package + + + + + maven-compiler-plugin + 3.13.0 + + ${java.version} + ${java.version} + + + + org.apache.maven.plugins + maven-assembly-plugin + 3.7.1 + + true + true + + jar-with-dependencies + + + + + make-assembly + + single + + package + + + + + org.projectlombok + lombok-maven-plugin + 1.18.20.0 + + + delombok + + delombok + + + false + ${project.basedir}/src/main/java + ${project.build.directory}/delombok + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + 3.11.1 + + ${project.build.directory}/delombok;${project.build.directory}/generated-sources/protobuf + ${java.version} + true + com.salesforce.hyperdb.grpc + ${project.build.directory}/apidocs + + + + + jar + + package + + + + + org.apache.maven.plugins + maven-source-plugin + 3.3.1 + + + + jar + + package + + + + + maven-surefire-plugin + 3.3.0 + + + org.junit.jupiter + junit-jupiter-engine + ${junit-bom.version} + + + + + maven-failsafe-plugin + 3.3.0 + + + org.junit.jupiter + junit-jupiter-engine + ${junit-bom.version} + + + + + + integration-test + verify + + + + + + org.jacoco + jacoco-maven-plugin + ${jacoco.maven.plugin.version} + + + + prepare-agent + + + + report + + report + + test + + + + + com.diffplug.spotless + spotless-maven-plugin + ${spotless.version} + + origin/main + ${release.profile.active} + + + + .gitattributes + 
.gitignore + + + + + true + 4 + + + + + + + src/main/java/**/*.java + src/test/java/**/*.java + + + 2.39.0 + + true + + + + + + ${project.basedir}/license-header.txt + + + + + + + check + apply + + + + + + com.rudikershaw.gitbuildhook + git-build-hook-maven-plugin + ${git-build-hook-maven-plugin.version} + + + .hooks/pre-commit + + + + + + install + + + + + + org.assertj + assertj-assertions-generator-maven-plugin + 2.2.0 + + + com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessor + + false + src/test/java + false + true + false + false + true + + + + org.apache.maven.plugins + maven-resources-plugin + 3.2.0 + + + ${*} + + + + + com.github.ekryd.sortpom + sortpom-maven-plugin + 3.4.1 + + ${project.build.sourceEncoding} + custom_1 + 4 + \n + scope,groupId,artifactId + true + true + false + stop + + + + + sort + + generate-sources + + + verify-sorted-pom + + verify + + validate + + + + + com.googlecode.maven-download-plugin + download-maven-plugin + 1.11.3 + + + download-hyper-cpp + + wget + + process-test-resources + + ${hyper-download-url}.${hyperapi.version}.zip + true + ${download.cache.directory}/hyper-unzipped + + + + + + org.apache.maven.plugins + maven-antrun-plugin + 3.1.0 + + + flatten-hyperd + + run + + process-test-resources + + + + + + + + + + + + + + + + + chmod-hyperd + + run + + process-test-resources + + + + + + + + + + + + + + + + kr.motd.maven + os-maven-plugin + 1.7.1 + + + + https://github.com/forcedotcom/datacloud-jdbc + + + Apache License Version 2.0 + https://www.apache.org/licenses/LICENSE-2.0.txt + repo + + + + + Data Cloud Query Developer Team + datacloud-query-connector-owners@salesforce.com + Salesforce Data Cloud + https://www.salesforce.com/data/ + + + + scm:git:https://github.com/forcedotcom/datacloud-jdbc.git + scm:git:git@github.com:forcedotcom/datacloud-jdbc.git + https://github.com/forcedotcom/datacloud-jdbc + + + GitHub Issues + https://github.com/forcedotcom/datacloud-jdbc/issues + + + GitHub Actions + 
https://github.com/forcedotcom/datacloud-jdbc/actions + + + + ossrh + https://oss.sonatype.org/content/repositories/snapshots + + + + + windows + + + windows + + + + + https://downloads.tableau.com/tssoftware/tableauhyperapi-cxx-windows-x86_64-release-main + + + + mac-apple-silicon + + + mac + aarch64 + + + + + https://downloads.tableau.com/tssoftware/tableauhyperapi-cxx-macos-arm64-release-main + + + + mac-x86_64 + + + mac + x86_64 + + + + + https://downloads.tableau.com/tssoftware//tableauhyperapi-cxx-macos-x86_64-release-main + + + + linux + + + !mac os x + unix + + + + + https://downloads.tableau.com/tssoftware/tableauhyperapi-cxx-linux-x86_64-release-main + + + + release + + true + + + + + org.sonatype.plugins + nexus-staging-maven-plugin + 1.7.0 + true + + ossrh + https://oss.sonatype.org/ + true + + + + org.apache.maven.plugins + maven-gpg-plugin + 3.2.7 + + + --pinentry-mode + loopback + + + + + sign-artifacts + + sign + + verify + + + + + + + + diff --git a/src/main/java/com/salesforce/datacloud/jdbc/DataCloudDatasource.java b/src/main/java/com/salesforce/datacloud/jdbc/DataCloudDatasource.java new file mode 100644 index 0000000..1a083c4 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/DataCloudDatasource.java @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.SqlErrorCodes; +import java.io.PrintWriter; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Properties; +import java.util.logging.Logger; +import javax.sql.DataSource; +import lombok.SneakyThrows; + +public class DataCloudDatasource implements DataSource { + private static final String USERNAME_PROPERTY = "userName"; + private static final String PASSWORD_PROPERTY = "password"; + private static final String PRIVATE_KEY_PROPERTY = "privateKey"; + private static final String REFRESH_TOKEN_PROPERTY = "refreshToken"; + private static final String CORE_TOKEN_PROPERTY = "coreToken"; + private static final String CLIENT_ID_PROPERTY = "clientId"; + private static final String CLIENT_SECRET_PROPERTY = "clientSecret"; + private static final String INTERNAL_ENDPOINT_PROPERTY = "internalEndpoint"; + private static final String PORT_PROPERTY = "port"; + private static final String TENANT_ID_PROPERTY = "tenantId"; + private static final String DATASPACE_PROPERTY = "dataspace"; + private static final String CORE_TENANT_ID_PROPERTY = "coreTenantId"; + + protected static final String NOT_SUPPORTED_IN_DATACLOUD_QUERY = + "Datasource method is not supported in Data Cloud query"; + + private String connectionUrl; + private final Properties properties = new Properties(); + + @Override + public Connection getConnection() throws SQLException { + try { + return DriverManager.getConnection(getConnectionUrl(), properties); + } catch (SQLException e) { + throw new DataCloudJDBCException(e); + } + } + + @Override + public Connection getConnection(String username, String password) throws SQLException { + setUserName(username); + setPassword(password); + return getConnection(); + } + + @Override + public PrintWriter getLogWriter() throws SQLException { + throw new 
DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setLogWriter(PrintWriter out) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setLoginTimeout(int seconds) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public int getLoginTimeout() throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @SneakyThrows + @Override + public Logger getParentLogger() { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public T unwrap(Class iface) throws SQLException { + return null; + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return false; + } + + private String getConnectionUrl() { + return connectionUrl; + } + + public void setConnectionUrl(String connectionUrl) { + this.connectionUrl = connectionUrl; + } + + public void setUserName(String userName) { + this.properties.setProperty(USERNAME_PROPERTY, userName); + } + + public void setPassword(String password) { + this.properties.setProperty(PASSWORD_PROPERTY, password); + } + + public void setPrivateKey(String privateKey) { + this.properties.setProperty(PRIVATE_KEY_PROPERTY, privateKey); + } + + public void setRefreshToken(String refreshToken) { + this.properties.setProperty(REFRESH_TOKEN_PROPERTY, refreshToken); + } + + public void setCoreToken(String coreToken) { + this.properties.setProperty(CORE_TOKEN_PROPERTY, coreToken); + } + + public void setInternalEndpoint(String internalEndpoint) { + this.properties.setProperty(INTERNAL_ENDPOINT_PROPERTY, internalEndpoint); + } + + public void setPort(String port) { + 
this.properties.setProperty(PORT_PROPERTY, port); + } + + public void setTenantId(String tenantId) { + this.properties.setProperty(TENANT_ID_PROPERTY, tenantId); + } + + public void setDataspace(String dataspace) { + this.properties.setProperty(DATASPACE_PROPERTY, dataspace); + } + + public void setCoreTenantId(String coreTenantId) { + this.properties.setProperty(CORE_TENANT_ID_PROPERTY, coreTenantId); + } + + public void setClientId(String clientId) { + this.properties.setProperty(CLIENT_ID_PROPERTY, clientId); + } + + public void setClientSecret(String clientSecret) { + this.properties.setProperty(CLIENT_SECRET_PROPERTY, clientSecret); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/DataCloudJDBCDriver.java b/src/main/java/com/salesforce/datacloud/jdbc/DataCloudJDBCDriver.java new file mode 100644 index 0000000..eb001eb --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/DataCloudJDBCDriver.java @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc; + +import com.salesforce.datacloud.jdbc.config.DriverVersion; +import com.salesforce.datacloud.jdbc.core.DataCloudConnection; +import java.sql.Connection; +import java.sql.Driver; +import java.sql.DriverManager; +import java.sql.DriverPropertyInfo; +import java.sql.SQLException; +import java.util.Properties; +import java.util.logging.Logger; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class DataCloudJDBCDriver implements Driver { + private static Driver registeredDriver; + + static { + try { + register(); + log.info("DataCloud JDBC driver registered"); + } catch (SQLException e) { + log.error("Error occurred while registering DataCloud JDBC driver. {}", e.getMessage()); + throw new ExceptionInInitializerError(e); + } + } + + private static void register() throws SQLException { + if (isRegistered()) { + throw new IllegalStateException("Driver is already registered. It can only be registered once."); + } + registeredDriver = new DataCloudJDBCDriver(); + DriverManager.registerDriver(registeredDriver); + } + + public static boolean isRegistered() { + return registeredDriver != null; + } + + @Override + public Connection connect(String url, Properties info) throws SQLException { + if (url == null) { + throw new SQLException("Error occurred while registering JDBC driver. 
URL is null."); + } + + if (!this.acceptsURL(url)) { + return null; + } + return DataCloudConnection.of(url, info); + } + + @Override + public boolean acceptsURL(String url) { + return DataCloudConnection.acceptsUrl(url); + } + + @Override + public DriverPropertyInfo[] getPropertyInfo(String url, Properties info) { + return new DriverPropertyInfo[0]; + } + + @Override + public int getMajorVersion() { + return DriverVersion.getMajorVersion(); + } + + @Override + public int getMinorVersion() { + return DriverVersion.getMinorVersion(); + } + + @Override + public boolean jdbcCompliant() { + return false; + } + + @Override + public Logger getParentLogger() { + return null; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/auth/AuthenticationSettings.java b/src/main/java/com/salesforce/datacloud/jdbc/auth/AuthenticationSettings.java new file mode 100644 index 0000000..6d70aac --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/auth/AuthenticationSettings.java @@ -0,0 +1,207 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.auth; + +import static com.salesforce.datacloud.jdbc.util.PropertiesExtensions.copy; +import static com.salesforce.datacloud.jdbc.util.PropertiesExtensions.optional; +import static com.salesforce.datacloud.jdbc.util.PropertiesExtensions.required; + +import com.salesforce.datacloud.jdbc.config.DriverVersion; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.PropertiesExtensions; +import java.net.URI; +import java.net.URISyntaxException; +import java.sql.SQLException; +import java.util.Properties; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import lombok.Getter; +import lombok.NonNull; +import lombok.experimental.UtilityClass; +import lombok.val; + +@Getter +public abstract class AuthenticationSettings { + public static AuthenticationSettings of(@NonNull Properties properties) throws SQLException { + checkNotEmpty(properties); + checkHasAllRequired(properties); + + if (hasPrivateKey(properties)) { + return new PrivateKeyAuthenticationSettings(properties); + } else if (hasPassword(properties)) { + return new PasswordAuthenticationSettings(properties); + } else if (hasRefreshToken(properties)) { + return new RefreshTokenAuthenticationSettings(properties); + } else { + throw new DataCloudJDBCException(Messages.PROPERTIES_MISSING, "28000"); + } + } + + public static boolean hasAll(Properties properties, Set keys) { + return keys.stream().allMatch(k -> optional(properties, k).isPresent()); + } + + public static boolean hasAny(Properties properties) { + return hasPrivateKey(properties) || hasPassword(properties) || hasRefreshToken(properties); + } + + private static boolean hasPrivateKey(Properties properties) { + return hasAll(properties, Keys.PRIVATE_KEY_KEYS); + } + + private static boolean hasPassword(Properties properties) { + return hasAll(properties, Keys.PASSWORD_KEYS); + } + + private static boolean 
hasRefreshToken(Properties properties) { + return hasAll(properties, Keys.REFRESH_TOKEN_KEYS); + } + + private static void checkNotEmpty(@NonNull Properties properties) throws SQLException { + if (properties.isEmpty()) { + throw new DataCloudJDBCException( + Messages.PROPERTIES_EMPTY, "28000", new IllegalArgumentException(Messages.PROPERTIES_EMPTY)); + } + } + + private static void checkHasAllRequired(Properties properties) throws SQLException { + if (hasAll(properties, Keys.REQUIRED_KEYS)) { + return; + } + + val missing = Keys.REQUIRED_KEYS.stream() + .filter(k -> optional(properties, k).isEmpty()) + .collect(Collectors.joining(", ", Messages.PROPERTIES_REQUIRED, "")); + + throw new DataCloudJDBCException(missing, "28000", new IllegalArgumentException(missing)); + } + + final URI getLoginUri() throws SQLException { + try { + return new URI(loginUrl); + } catch (URISyntaxException ex) { + throw new DataCloudJDBCException(ex.getMessage(), "28000", ex); + } + } + + protected AuthenticationSettings(@NonNull Properties properties) throws SQLException { + checkNotEmpty(properties); + + this.relevantProperties = copy(properties, Keys.ALL); + + this.loginUrl = required(relevantProperties, Keys.LOGIN_URL); + this.clientId = required(relevantProperties, Keys.CLIENT_ID); + this.clientSecret = required(relevantProperties, Keys.CLIENT_SECRET); + + this.dataspace = optional(relevantProperties, Keys.DATASPACE).orElse(Defaults.DATASPACE); + this.userAgent = optional(relevantProperties, Keys.USER_AGENT).orElse(Defaults.USER_AGENT); + this.maxRetries = optional(relevantProperties, Keys.MAX_RETRIES) + .map(PropertiesExtensions::toIntegerOrNull) + .orElse(Defaults.MAX_RETRIES); + } + + private final Properties relevantProperties; + + private final String loginUrl; + private final String clientId; + private final String clientSecret; + private final String dataspace; + private final String userAgent; + private final int maxRetries; + + @UtilityClass + protected static class Keys { + 
static final String LOGIN_URL = "loginURL"; + static final String USER_NAME = "userName"; + static final String PASSWORD = "password"; + static final String PRIVATE_KEY = "privateKey"; + static final String CLIENT_SECRET = "clientSecret"; + static final String CLIENT_ID = "clientId"; + static final String DATASPACE = "dataspace"; + static final String MAX_RETRIES = "maxRetries"; + static final String USER_AGENT = "User-Agent"; + static final String REFRESH_TOKEN = "refreshToken"; + + static final Set REQUIRED_KEYS = Set.of(LOGIN_URL, CLIENT_ID, CLIENT_SECRET); + + static final Set OPTIONAL_KEYS = Set.of(DATASPACE, USER_AGENT, MAX_RETRIES); + + static final Set PASSWORD_KEYS = Set.of(USER_NAME, PASSWORD); + + static final Set PRIVATE_KEY_KEYS = Set.of(USER_NAME, PRIVATE_KEY); + + static final Set REFRESH_TOKEN_KEYS = Set.of(REFRESH_TOKEN); + + static final Set ALL = Stream.of( + REQUIRED_KEYS, OPTIONAL_KEYS, PASSWORD_KEYS, PRIVATE_KEY_KEYS, REFRESH_TOKEN_KEYS) + .flatMap(Set::stream) + .collect(Collectors.toSet()); + } + + @UtilityClass + protected static class Defaults { + static final int MAX_RETRIES = 3; + static final String DATASPACE = null; + static final String USER_AGENT = DriverVersion.formatDriverInfo(); + } + + @UtilityClass + protected static class Messages { + static final String PROPERTIES_NULL = "properties is marked non-null but is null"; + static final String PROPERTIES_EMPTY = "Properties cannot be empty when creating AuthenticationSettings."; + static final String PROPERTIES_MISSING = + "Properties did not contain valid settings for known authentication strategies: password, privateKey, or refreshToken with coreToken"; + static final String PROPERTIES_REQUIRED = "Properties did not contain the following required settings: "; + } +} + +@Getter +class PasswordAuthenticationSettings extends AuthenticationSettings { + protected PasswordAuthenticationSettings(@NonNull Properties properties) throws SQLException { + super(properties); + + this.password = 
required(this.getRelevantProperties(), Keys.PASSWORD); + this.userName = required(this.getRelevantProperties(), Keys.USER_NAME); + } + + private final String password; + private final String userName; +} + +@Getter +class PrivateKeyAuthenticationSettings extends AuthenticationSettings { + protected PrivateKeyAuthenticationSettings(@NonNull Properties properties) throws SQLException { + super(properties); + + this.privateKey = required(this.getRelevantProperties(), Keys.PRIVATE_KEY); + this.userName = required(this.getRelevantProperties(), Keys.USER_NAME); + } + + private final String privateKey; + private final String userName; +} + +@Getter +class RefreshTokenAuthenticationSettings extends AuthenticationSettings { + protected RefreshTokenAuthenticationSettings(@NonNull Properties properties) throws SQLException { + super(properties); + + this.refreshToken = required(this.getRelevantProperties(), Keys.REFRESH_TOKEN); + } + + private final String refreshToken; +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/auth/AuthenticationStrategy.java b/src/main/java/com/salesforce/datacloud/jdbc/auth/AuthenticationStrategy.java new file mode 100644 index 0000000..0a79194 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/auth/AuthenticationStrategy.java @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.auth; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.http.FormCommand; +import java.net.URI; +import java.sql.SQLException; +import java.util.Properties; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; +import lombok.experimental.UtilityClass; +import lombok.val; +import org.apache.commons.lang3.StringUtils; + +interface AuthenticationStrategy { + static AuthenticationStrategy of(@NonNull Properties properties) throws SQLException { + val settings = AuthenticationSettings.of(properties); + return of(settings); + } + + static AuthenticationStrategy of(@NonNull AuthenticationSettings settings) throws SQLException { + if (settings instanceof PasswordAuthenticationSettings) { + return new PasswordAuthenticationStrategy((PasswordAuthenticationSettings) settings); + } else if (settings instanceof PrivateKeyAuthenticationSettings) { + return new PrivateKeyAuthenticationStrategy((PrivateKeyAuthenticationSettings) settings); + } else if (settings instanceof RefreshTokenAuthenticationSettings) { + return new RefreshTokenAuthenticationStrategy((RefreshTokenAuthenticationSettings) settings); + } else { + val rootCauseException = new IllegalArgumentException(Messages.UNKNOWN_SETTINGS_TYPE); + throw new DataCloudJDBCException(Messages.UNKNOWN_SETTINGS_TYPE, "28000", rootCauseException); + } + } + + @UtilityClass + class Messages { + static final String UNKNOWN_SETTINGS_TYPE = "Resolved settings were an unknown type of AuthenticationSettings"; + } + + @UtilityClass + class Keys { + static final String GRANT_TYPE = "grant_type"; + static final String CLIENT_ID = "client_id"; + static final String CLIENT_SECRET = "client_secret"; + static final String USER_AGENT = "User-Agent"; + } + + FormCommand buildAuthenticate() throws SQLException; + + AuthenticationSettings getSettings(); 
+} + +abstract class SharedAuthenticationStrategy implements AuthenticationStrategy { + protected final FormCommand.Builder builder(HttpCommandPath path) throws SQLException { + val settings = getSettings(); + val builder = FormCommand.builder(); + + builder.url(settings.getLoginUri()); + builder.suffix(path.getSuffix()); + + builder.header(Keys.USER_AGENT, settings.getUserAgent()); + + return builder; + } +} + +@Getter +@RequiredArgsConstructor +class PasswordAuthenticationStrategy extends SharedAuthenticationStrategy { + private static final String GRANT_TYPE = "password"; + private static final String USERNAME = "username"; + private static final String PASSWORD = "password"; + + private final PasswordAuthenticationSettings settings; + + /** + * username + * password flow docs + */ + @Override + public FormCommand buildAuthenticate() throws SQLException { + val builder = super.builder(HttpCommandPath.AUTHENTICATE); + + builder.bodyEntry(Keys.GRANT_TYPE, GRANT_TYPE); + builder.bodyEntry(USERNAME, settings.getUserName()); + builder.bodyEntry(PASSWORD, settings.getPassword()); + builder.bodyEntry(Keys.CLIENT_ID, settings.getClientId()); + builder.bodyEntry(Keys.CLIENT_SECRET, settings.getClientSecret()); + + return builder.build(); + } +} + +@Getter +@RequiredArgsConstructor +class RefreshTokenAuthenticationStrategy extends SharedAuthenticationStrategy { + private static final String GRANT_TYPE = "refresh_token"; + private static final String REFRESH_TOKEN = "refresh_token"; + + private final RefreshTokenAuthenticationSettings settings; + + /** + * refresh + * token flow docs + */ + @Override + public FormCommand buildAuthenticate() throws SQLException { + val builder = super.builder(HttpCommandPath.AUTHENTICATE); + + builder.bodyEntry(Keys.GRANT_TYPE, GRANT_TYPE); + builder.bodyEntry(REFRESH_TOKEN, settings.getRefreshToken()); + builder.bodyEntry(Keys.CLIENT_ID, settings.getClientId()); + builder.bodyEntry(Keys.CLIENT_SECRET, settings.getClientSecret()); + + 
return builder.build(); + } +} + +@Getter +@RequiredArgsConstructor +class PrivateKeyAuthenticationStrategy extends SharedAuthenticationStrategy { + private static final String GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"; + private static final String ASSERTION = "assertion"; + + private final PrivateKeyAuthenticationSettings settings; + + /** + * private key flow + * docs + */ + @Override + public FormCommand buildAuthenticate() throws SQLException { + val builder = super.builder(HttpCommandPath.AUTHENTICATE); + + builder.bodyEntry(Keys.GRANT_TYPE, GRANT_TYPE); + builder.bodyEntry(ASSERTION, JwtParts.buildJwt(settings)); + + return builder.build(); + } +} + +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) +class ExchangeTokenAuthenticationStrategy { + private static final String GRANT_TYPE = "urn:salesforce:grant-type:external:cdp"; + private static final String ACCESS_TOKEN = "urn:ietf:params:oauth:token-type:access_token"; + static final String SUBJECT_TOKEN_TYPE = "subject_token_type"; + static final String SUBJECT_TOKEN_KEY = "subject_token"; + static final String DATASPACE = "dataspace"; + + static ExchangeTokenAuthenticationStrategy of(@NonNull AuthenticationSettings settings, @NonNull OAuthToken token) { + return new ExchangeTokenAuthenticationStrategy(settings, token); + } + + @Getter + private final AuthenticationSettings settings; + + private final OAuthToken token; + + /** + * exchange + * token flow docs + */ + public FormCommand toCommand() { + val builder = FormCommand.builder(); + + builder.url(token.getInstanceUrl()); + builder.suffix(HttpCommandPath.EXCHANGE.getSuffix()); + + builder.header(AuthenticationStrategy.Keys.USER_AGENT, settings.getUserAgent()); + + builder.bodyEntry(AuthenticationStrategy.Keys.GRANT_TYPE, GRANT_TYPE); + builder.bodyEntry(SUBJECT_TOKEN_TYPE, ACCESS_TOKEN); + builder.bodyEntry(SUBJECT_TOKEN_KEY, token.getToken()); + + if (StringUtils.isNotBlank(settings.getDataspace())) { + 
builder.bodyEntry(DATASPACE, settings.getDataspace()); + } + + return builder.build(); + } +} + +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) +class RevokeTokenAuthenticationStrategy { + static final String REVOKE_TOKEN_KEY = "token"; + + static RevokeTokenAuthenticationStrategy of(@NonNull AuthenticationSettings settings, @NonNull OAuthToken token) { + return new RevokeTokenAuthenticationStrategy(settings, token); + } + + @Getter + private final AuthenticationSettings settings; + + private final OAuthToken token; + + public FormCommand toCommand() { + val builder = FormCommand.builder(); + + builder.url(token.getInstanceUrl()); + builder.suffix(HttpCommandPath.REVOKE.getSuffix()); + + builder.header(AuthenticationStrategy.Keys.USER_AGENT, settings.getUserAgent()); + + builder.bodyEntry(REVOKE_TOKEN_KEY, token.getToken()); + + return builder.build(); + } +} + +@Getter +enum HttpCommandPath { + AUTHENTICATE("services/oauth2/token"), + EXCHANGE("services/a360/token"), + REVOKE("services/oauth2/revoke"); + + private final URI suffix; + + @SneakyThrows + HttpCommandPath(String suffix) { + this.suffix = new URI(suffix); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/auth/DataCloudToken.java b/src/main/java/com/salesforce/datacloud/jdbc/auth/DataCloudToken.java new file mode 100644 index 0000000..f4729f9 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/auth/DataCloudToken.java @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.auth; + +import static com.salesforce.datacloud.jdbc.util.Require.requireNotNullOrBlank; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.salesforce.datacloud.jdbc.auth.model.DataCloudTokenResponse; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.Messages; +import java.io.IOException; +import java.net.URI; +import java.sql.SQLException; +import java.util.Base64; +import java.util.Calendar; +import lombok.AccessLevel; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.apache.commons.lang3.StringUtils; + +@Slf4j +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) +public class DataCloudToken { + private static final int JWT_PAYLOAD_INDEX = 1; + private static final String JWT_DELIMITER = "\\."; + private static final String AUDIENCE_TENANT_ID = "audienceTenantId"; + + private final String type; + private final String token; + private final URI tenant; + private final Calendar expiresIn; + + private static final String TENANT_IO_ERROR_RESPONSE = "Error while decoding tenantId."; + + public static DataCloudToken of(DataCloudTokenResponse model) throws SQLException { + val type = model.getTokenType(); + val token = model.getToken(); + val tenantUrl = model.getInstanceUrl(); + + requireNotNullOrBlank(type, "token_type"); + requireNotNullOrBlank(token, "access_token"); + requireNotNullOrBlank(tenantUrl, "instance_url"); + + val expiresIn = Calendar.getInstance(); + expiresIn.add(Calendar.SECOND, model.getExpiresIn()); + + try { + val tenant = URI.create(tenantUrl); + + return new DataCloudToken(type, token, tenant, expiresIn); + } catch (IllegalArgumentException ex) { + val rootCauseException = new IllegalArgumentException( + "Failed to parse the provided tenantUrl: '" + tenantUrl 
+ "'. " + ex.getMessage(), ex.getCause()); + throw new DataCloudJDBCException(Messages.FAILED_LOGIN, "28000", rootCauseException); + } + } + + public boolean isAlive() { + val now = Calendar.getInstance(); + return now.compareTo(expiresIn) <= 0; + } + + public String getTenantUrl() { + return this.tenant.toString(); + } + + public String getTenantId() throws SQLException { + return getTenantId(this.token); + } + + public String getAccessToken() { + return this.type + StringUtils.SPACE + this.token; + } + + private static String getTenantId(String token) throws SQLException { + String[] chunks = token.split(JWT_DELIMITER, -1); + Base64.Decoder decoder = Base64.getUrlDecoder(); + try { + val chunk = chunks[JWT_PAYLOAD_INDEX]; + val decodedChunk = decoder.decode(chunk); + + return new ObjectMapper() + .readTree(decodedChunk) + .get(AUDIENCE_TENANT_ID) + .asText(); + } catch (IOException e) { + throw new DataCloudJDBCException(TENANT_IO_ERROR_RESPONSE, "58030", e); + } + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/auth/DataCloudTokenProcessor.java b/src/main/java/com/salesforce/datacloud/jdbc/auth/DataCloudTokenProcessor.java new file mode 100644 index 0000000..a07bcbf --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/auth/DataCloudTokenProcessor.java @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.auth; + +import static com.salesforce.datacloud.jdbc.util.PropertiesExtensions.getIntegerOrDefault; +import static org.apache.commons.lang3.StringUtils.isBlank; +import static org.apache.commons.lang3.StringUtils.isNotBlank; + +import com.salesforce.datacloud.jdbc.auth.errors.AuthorizationException; +import com.salesforce.datacloud.jdbc.auth.model.AuthenticationResponseWithError; +import com.salesforce.datacloud.jdbc.auth.model.DataCloudTokenResponse; +import com.salesforce.datacloud.jdbc.auth.model.OAuthTokenResponse; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.http.ClientBuilder; +import com.salesforce.datacloud.jdbc.http.FormCommand; +import java.sql.SQLException; +import java.time.temporal.ChronoUnit; +import java.util.Properties; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import net.jodah.failsafe.Failsafe; +import net.jodah.failsafe.FailsafeException; +import net.jodah.failsafe.RetryPolicy; +import net.jodah.failsafe.function.CheckedSupplier; +import okhttp3.OkHttpClient; + +@Slf4j +@Builder(access = AccessLevel.PRIVATE) +public class DataCloudTokenProcessor implements TokenProcessor { + static final String MAX_RETRIES_KEY = "maxRetries"; + static final int DEFAULT_MAX_RETRIES = 3; + + private static final String CORE_ERROR_RESPONSE = "Received an error when acquiring oauth access token"; + + private static final String OFF_CORE_ERROR_RESPONSE = + "Received an error when exchanging oauth access token for data cloud token"; + + @Getter + private AuthenticationSettings settings; + + private AuthenticationStrategy strategy; + private OkHttpClient client; + private TokenCache cache; + private RetryPolicy policy; + private RetryPolicy exponentialBackOffPolicy; + + private AuthenticationResponseWithError getTokenWithRetry(CheckedSupplier response) + throws SQLException { 
+ try { + return Failsafe.with(this.policy).get(response); + } catch (FailsafeException ex) { + if (ex.getCause() != null) { + throw new DataCloudJDBCException(ex.getCause().getMessage(), "28000", ex); + } + throw new DataCloudJDBCException(ex.getMessage(), "28000", ex); + } + } + + private AuthenticationResponseWithError getDataCloudTokenWithRetry( + CheckedSupplier response) throws SQLException { + try { + return Failsafe.with(this.exponentialBackOffPolicy).get(response); + } catch (FailsafeException ex) { + if (ex.getCause() != null) { + throw new DataCloudJDBCException(ex.getCause().getMessage(), "28000", ex); + } + throw new DataCloudJDBCException(ex.getMessage(), "28000", ex); + } + } + + private OAuthToken fetchOAuthToken() throws SQLException { + val command = strategy.buildAuthenticate(); + val model = (OAuthTokenResponse) getTokenWithRetry(() -> { + val response = FormCommand.post(this.client, command, OAuthTokenResponse.class); + return throwExceptionOnError(response, CORE_ERROR_RESPONSE); + }); + return OAuthToken.of(model); + } + + private DataCloudToken fetchDataCloudToken() throws SQLException { + val model = (DataCloudTokenResponse) getDataCloudTokenWithRetry(() -> { + val oauthToken = getOAuthToken(); + val command = + ExchangeTokenAuthenticationStrategy.of(settings, oauthToken).toCommand(); + val response = FormCommand.post(this.client, command, DataCloudTokenResponse.class); + return throwExceptionOnError(response, OFF_CORE_ERROR_RESPONSE); + }); + return DataCloudToken.of(model); + } + + @Override + public OAuthToken getOAuthToken() throws SQLException { + try { + return fetchOAuthToken(); + } catch (Exception ex) { + throw new DataCloudJDBCException(ex.getMessage(), "28000", ex); + } + } + + @Override + public DataCloudToken getDataCloudToken() throws SQLException { + val cachedDataCloudToken = cache.getDataCloudToken(); + if (cachedDataCloudToken != null && cachedDataCloudToken.isAlive()) { + return cachedDataCloudToken; + } + + try { + return 
retrieveAndCacheDataCloudToken(); + } catch (Exception ex) { + throw new DataCloudJDBCException(ex.getMessage(), "28000", ex); + } + } + + private DataCloudToken retrieveAndCacheDataCloudToken() throws SQLException { + try { + val dataCloudToken = fetchDataCloudToken(); + cache.setDataCloudToken(dataCloudToken); + return dataCloudToken; + } catch (Exception ex) { + cache.clearDataCloudToken(); + throw new DataCloudJDBCException(ex.getMessage(), "28000", ex); + } + } + + private static AuthenticationResponseWithError throwExceptionOnError( + AuthenticationResponseWithError response, String message) throws SQLException { + val token = response.getToken(); + val code = response.getErrorCode(); + val description = response.getErrorDescription(); + + if (isNotBlank(token) && isNotBlank(code) && isNotBlank(description)) { + log.warn("{} but got error code {} : {}", message, code, description); + } else if (isNotBlank(code) || isNotBlank(description)) { + val authorizationException = AuthorizationException.builder() + .message(message + ". 
" + code + ": " + description) + .errorCode(code) + .errorDescription(description) + .build(); + throw new DataCloudJDBCException(authorizationException.getMessage(), "28000", authorizationException); + } else if (isBlank(token)) { + throw new DataCloudJDBCException(message + ", no token in response.", "28000"); + } + + return response; + } + + public static DataCloudTokenProcessor of(Properties properties) throws SQLException { + val settings = AuthenticationSettings.of(properties); + val strategy = AuthenticationStrategy.of(settings); + val client = ClientBuilder.buildOkHttpClient(properties); + val policy = buildRetryPolicy(properties); + val exponentialBackOffPolicy = buildExponentialBackoffRetryPolicy(properties); + val cache = new TokenCacheImpl(); + + return DataCloudTokenProcessor.builder() + .client(client) + .policy(policy) + .exponentialBackOffPolicy(exponentialBackOffPolicy) + .cache(cache) + .strategy(strategy) + .settings(settings) + .build(); + } + + static RetryPolicy buildRetryPolicy(Properties properties) { + val maxRetries = getIntegerOrDefault(properties, MAX_RETRIES_KEY, DEFAULT_MAX_RETRIES); + return new RetryPolicy() + .withMaxRetries(maxRetries) + .handleIf(e -> !(e instanceof AuthorizationException)); + } + + static RetryPolicy buildExponentialBackoffRetryPolicy(Properties properties) { + val maxRetries = getIntegerOrDefault(properties, MAX_RETRIES_KEY, DEFAULT_MAX_RETRIES); + return new RetryPolicy() + .withMaxRetries(maxRetries) + .withBackoff(1, 30, ChronoUnit.SECONDS) + .handleIf(e -> !(e instanceof AuthorizationException)); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/auth/OAuthToken.java b/src/main/java/com/salesforce/datacloud/jdbc/auth/OAuthToken.java new file mode 100644 index 0000000..ed98e09 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/auth/OAuthToken.java @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.auth; + +import com.salesforce.datacloud.jdbc.auth.model.OAuthTokenResponse; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.Messages; +import java.net.URI; +import java.sql.SQLException; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Value; +import lombok.val; +import org.apache.commons.lang3.StringUtils; + +@Value +@Builder(access = AccessLevel.PRIVATE) +public class OAuthToken { + private static final String BEARER_PREFIX = "Bearer "; + + String token; + URI instanceUrl; + + public static OAuthToken of(OAuthTokenResponse response) throws SQLException { + val accessToken = response.getToken(); + + if (StringUtils.isBlank(accessToken)) { + throw new DataCloudJDBCException(Messages.FAILED_LOGIN, "28000"); + } + + try { + val instanceUrl = new URI(response.getInstanceUrl()); + + return OAuthToken.builder() + .token(accessToken) + .instanceUrl(instanceUrl) + .build(); + } catch (Exception ex) { + throw new DataCloudJDBCException(Messages.FAILED_LOGIN, "28000", ex); + } + } + + public String getBearerToken() { + return BEARER_PREFIX + getToken(); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/auth/PrivateKeyHelpers.java b/src/main/java/com/salesforce/datacloud/jdbc/auth/PrivateKeyHelpers.java new file mode 100644 index 0000000..74ab319 --- /dev/null +++ 
b/src/main/java/com/salesforce/datacloud/jdbc/auth/PrivateKeyHelpers.java @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.auth; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import io.jsonwebtoken.Jwts; +import io.jsonwebtoken.SignatureAlgorithm; +import java.security.KeyFactory; +import java.security.interfaces.RSAPrivateKey; +import java.security.spec.PKCS8EncodedKeySpec; +import java.sql.SQLException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Base64; +import java.util.Date; +import lombok.Getter; +import lombok.experimental.UtilityClass; +import lombok.val; + +@Getter +enum Audience { + DEV("login.test1.pc-rnd.salesforce.com"), + PROD("login.salesforce.com"); + + public final String url; + + Audience(String audience) { + this.url = audience; + } + + public static Audience of(String url) throws SQLException { + if (url.contains(TEST_SUFFIX)) { + return Audience.DEV; + } else if (url.endsWith(PROD_SUFFIX)) { + return Audience.PROD; + } else { + val errorMessage = "The specified url: '" + url + "' didn't match any known environments"; + val rootCauseException = new IllegalArgumentException(errorMessage); + throw new DataCloudJDBCException(errorMessage, "28000", rootCauseException); + } + } + + private static final String PROD_SUFFIX = ".salesforce.com"; + private static final String 
TEST_SUFFIX = ".test1.pc-rnd.salesforce.com"; +} + +@UtilityClass +class JwtParts { + public static String buildJwt(PrivateKeyAuthenticationSettings settings) throws SQLException { + try { + Instant now = Instant.now(); + Audience audience = Audience.of(settings.getLoginUrl()); + RSAPrivateKey privateKey = asPrivateKey(settings.getPrivateKey()); + return Jwts.builder() + .setIssuer(settings.getClientId()) + .setSubject(settings.getUserName()) + .setAudience(audience.url) + .setIssuedAt(Date.from(now)) + .setExpiration(Date.from(now.plus(2L, ChronoUnit.MINUTES))) + .signWith(privateKey, SignatureAlgorithm.RS256) + .compact(); + } catch (Exception ex) { + throw new DataCloudJDBCException(JWT_CREATION_FAILURE, "28000", ex); + } + } + + private static RSAPrivateKey asPrivateKey(String privateKey) throws SQLException { + String rsaPrivateKey = privateKey + .replaceFirst(BEGIN_PRIVATE_KEY, "") + .replaceFirst(END_PRIVATE_KEY, "") + .replaceAll("\\s", ""); + + val bytes = decodeBase64(rsaPrivateKey); + + try { + PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(bytes); + val factory = KeyFactory.getInstance("RSA"); + return (RSAPrivateKey) factory.generatePrivate(keySpec); + + } catch (Exception ex) { + throw new DataCloudJDBCException(JWT_CREATION_FAILURE, "28000", ex); + } + } + + private byte[] decodeBase64(String input) { + return Base64.getDecoder().decode(input); + } + + private static final String JWT_CREATION_FAILURE = + "JWT assertion creation failed. 
Please check Username, Client Id, Private key and try again."; + private static final String BEGIN_PRIVATE_KEY = "-----BEGIN PRIVATE KEY-----"; + private static final String END_PRIVATE_KEY = "-----END PRIVATE KEY-----"; +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/auth/TokenCache.java b/src/main/java/com/salesforce/datacloud/jdbc/auth/TokenCache.java new file mode 100644 index 0000000..068b294 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/auth/TokenCache.java @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.auth; + +interface TokenCache { + void setDataCloudToken(DataCloudToken dataCloudToken); + + void clearDataCloudToken(); + + DataCloudToken getDataCloudToken(); +} + +class TokenCacheImpl implements TokenCache { + private DataCloudToken dataCloudToken; + + @Override + public void setDataCloudToken(DataCloudToken dataCloudToken) { + this.dataCloudToken = dataCloudToken; + } + + @Override + public void clearDataCloudToken() { + this.dataCloudToken = null; + } + + @Override + public DataCloudToken getDataCloudToken() { + return this.dataCloudToken; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/auth/TokenProcessor.java b/src/main/java/com/salesforce/datacloud/jdbc/auth/TokenProcessor.java new file mode 100644 index 0000000..19ea97d --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/auth/TokenProcessor.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.auth; + +import java.sql.SQLException; + +public interface TokenProcessor { + AuthenticationSettings getSettings(); + + OAuthToken getOAuthToken() throws SQLException; + + DataCloudToken getDataCloudToken() throws SQLException; +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/auth/errors/AuthorizationException.java b/src/main/java/com/salesforce/datacloud/jdbc/auth/errors/AuthorizationException.java new file mode 100644 index 0000000..9109939 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/auth/errors/AuthorizationException.java @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.auth.errors; + +import lombok.Builder; +import lombok.Getter; + +@Getter +@Builder +public class AuthorizationException extends Exception { + private final String message; + private final String errorCode; + private final String errorDescription; +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/auth/model/AuthenticationResponseWithError.java b/src/main/java/com/salesforce/datacloud/jdbc/auth/model/AuthenticationResponseWithError.java new file mode 100644 index 0000000..cf9c927 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/auth/model/AuthenticationResponseWithError.java @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.auth.model; + +/** + * * Check out the error code docs + */ +public interface AuthenticationResponseWithError { + String getToken(); + + String getErrorCode(); + + String getErrorDescription(); +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/auth/model/DataCloudTokenResponse.java b/src/main/java/com/salesforce/datacloud/jdbc/auth/model/DataCloudTokenResponse.java new file mode 100644 index 0000000..e43e4d0 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/auth/model/DataCloudTokenResponse.java @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.auth.model; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; + +/** + * The shape of this response can be found here + * under the heading "Exchanging Access Token for Data Cloud Token" + */ +@Data +@JsonIgnoreProperties(ignoreUnknown = true) +public class DataCloudTokenResponse implements AuthenticationResponseWithError { + @JsonProperty("access_token") + private String token; + + @JsonProperty("instance_url") + private String instanceUrl; + + @JsonProperty("token_type") + private String tokenType; + + @JsonProperty("expires_in") + private int expiresIn; + + @JsonProperty("error") + private String errorCode; + + @JsonProperty("error_description") + private String errorDescription; +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/auth/model/OAuthTokenResponse.java b/src/main/java/com/salesforce/datacloud/jdbc/auth/model/OAuthTokenResponse.java new file mode 100644 index 0000000..3f18c13 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/auth/model/OAuthTokenResponse.java @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.auth.model; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; + +/** + * The shape of this response can be found here under the + * heading "Salesforce Grants a New Access Token" + */ +@Data +@JsonIgnoreProperties(ignoreUnknown = true) +public class OAuthTokenResponse implements AuthenticationResponseWithError { + private String scope; + + @JsonProperty("access_token") + private String token; + + @JsonProperty("instance_url") + private String instanceUrl; + + @JsonProperty("token_type") + private String tokenType; + + @JsonProperty("issued_at") + private String issuedAt; + + @JsonProperty("error") + private String errorCode; + + @JsonProperty("error_description") + private String errorDescription; +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/config/DriverVersion.java b/src/main/java/com/salesforce/datacloud/jdbc/config/DriverVersion.java new file mode 100644 index 0000000..0a93c95 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/config/DriverVersion.java @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.config; + +import java.util.Properties; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.Value; +import lombok.experimental.UtilityClass; +import lombok.val; + +@UtilityClass +public class DriverVersion { + private static final String DRIVER_NAME = "salesforce-datacloud-jdbc"; + private static final String DATABASE_PRODUCT_NAME = "salesforce-datacloud-queryservice"; + private static final String DATABASE_PRODUCT_VERSION = "1.0"; + + @Getter(lazy = true) + private static final DriverVersionInfo driverVersionInfo = DriverVersionInfo.of(); + + public static int getMajorVersion() { + return getDriverVersionInfo().getMajor(); + } + + public static int getMinorVersion() { + return getDriverVersionInfo().getMinor(); + } + + public static String getDriverName() { + return DRIVER_NAME; + } + + public static String getProductName() { + return DATABASE_PRODUCT_NAME; + } + + public static String getProductVersion() { + return DATABASE_PRODUCT_VERSION; + } + + public static String formatDriverInfo() { + return String.format("%s/%s", getDriverName(), getDriverVersionInfo()); + } + + public static String getDriverVersion() { + return getDriverVersionInfo().toString(); + } +} + +@Value +@Builder(access = AccessLevel.PRIVATE) +class DriverVersionInfo { + @Builder.Default + int major = 0; + + @Builder.Default + int minor = 0; + + static String getVersion(Properties properties) { + String version = properties.getProperty("version"); + if (version == null || version.isEmpty()) { + return "0.0"; + } + return version.replaceAll("-.*$", ""); + } + + static DriverVersionInfo of(Properties properties) { + val builder = DriverVersionInfo.builder(); + String version = getVersion(properties); + if (!version.isEmpty()) { + val chunks = version.split("\\.", -1); + builder.major(Integer.parseInt(chunks[0])); + builder.minor(Integer.parseInt(chunks[1])); + return builder.build(); + } + return builder.build(); + } 
+ + @Override + public String toString() { + return String.format("%d.%d", major, minor); + } + + static DriverVersionInfo of() { + val properties = ResourceReader.readResourceAsProperties("/version.properties"); + return of(properties); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/config/KeywordResources.java b/src/main/java/com/salesforce/datacloud/jdbc/config/KeywordResources.java new file mode 100644 index 0000000..2218c3c --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/config/KeywordResources.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.config; + +import java.util.Set; +import java.util.stream.Collectors; +import lombok.Getter; +import lombok.experimental.UtilityClass; +import lombok.val; + +@UtilityClass +public class KeywordResources { + + // spotless:off + public static final Set SQL_2003_KEYWORDS = Set.of("ADD","ALL","ALLOCATE","ALTER","AND","ANY","ARE","ARRAY", + "AS","ASENSITIVE","ASYMMETRIC","AT","ATOMIC","AUTHORIZATION","BEGIN","BETWEEN","BIGINT","BINARY","BLOB", + "BOOLEAN","BOTH","BY","CALL","CALLED","CASCADED","CASE","CAST","CHAR","CHARACTER","CHECK","CLOB","CLOSE", + "COLLATE","COLUMN","COMMIT","CONDITION","CONNECT","CONSTRAINT","CONTINUE","CORRESPONDING","CREATE","CROSS", + "CUBE","CURRENT","CURRENT_DATE","CURRENT_DEFAULT_TRANSFORM_GROUP","CURRENT_PATH","CURRENT_ROLE", + "CURRENT_TIME","CURRENT_TIMESTAMP","CURRENT_TRANSFORM_GROUP_FOR_TYPE","CURRENT_USER","CURSOR","CYCLE", + "DATE","DAY","DEALLOCATE","DEC","DECIMAL","DECLARE","DEFAULT","DELETE","DEREF","DESCRIBE","DETERMINISTIC", + "DISCONNECT","DISTINCT","DO","DOUBLE","DROP","DYNAMIC","EACH","ELEMENT","ELSE","ELSEIF","END","ESCAPE", + "EXCEPT","EXEC","EXECUTE","EXISTS","EXIT","EXTERNAL","FALSE","FETCH","FILTER","FLOAT","FOR","FOREIGN", + "FREE","FROM","FULL","FUNCTION","GET","GLOBAL","GRANT","GROUP","GROUPING","HANDLER","HAVING","HOLD","HOUR", + "IDENTITY","IF","IMMEDIATE","IN","INDICATOR","INNER","INOUT","INPUT","INSENSITIVE","INSERT","INT","INTEGER", + "INTERSECT","INTERVAL","INTO","IS","ITERATE","JOIN","LANGUAGE","LARGE","LATERAL","LEADING","LEAVE","LEFT", + "LIKE","LOCAL","LOCALTIME","LOCALTIMESTAMP","LOOP","MATCH","MEMBER","MERGE","METHOD","MINUTE","MODIFIES", + "MODULE","MONTH","MULTISET","NATIONAL","NATURAL","NCHAR","NCLOB","NEW","NO","NONE","NOT","NULL","NUMERIC", + "OF","OLD","ON","ONLY","OPEN","OR","ORDER","OUT","OUTER","OUTPUT","OVER","OVERLAPS","PARAMETER","PARTITION", + "PRECISION","PREPARE","PROCEDURE","RANGE","READS","REAL","RECURSIVE","REF","REFERENCES","REFERENCING", + 
"RELEASE","REPEAT","RESIGNAL","RESULT","RETURN","RETURNS","REVOKE","RIGHT","ROLLBACK","ROLLUP","ROW","ROWS", + "SAVEPOINT","SCOPE","SCROLL","SEARCH","SECOND","SELECT","SENSITIVE","SESSION_USER","SET","SIGNAL","SIMILAR", + "SMALLINT","SOME","SPECIFIC","SPECIFICTYPE","SQL","SQLEXCEPTION","SQLSTATE","SQLWARNING","START","STATIC", + "SUBMULTISET","SYMMETRIC","SYSTEM","SYSTEM_USER","TABLE","TABLESAMPLE","THEN","TIME","TIMESTAMP", + "TIMEZONE_HOUR","TIMEZONE_MINUTE","TO","TRAILING","TRANSLATION","TREAT","TRIGGER","TRUE","UNDO","UNION", + "UNIQUE","UNKNOWN","UNNEST","UNTIL","UPDATE","USER","USING","VALUE","VALUES","VARCHAR","VARYING","WHEN", + "WHENEVER","WHERE","WHILE","WINDOW","WITH","WITHIN","WITHOUT","YEAR"); + // spotless:on + + @Getter(lazy = true) + private final String sqlKeywords = loadSqlKeywords(); + + private static String loadSqlKeywords() { + val hyperSqlLexerKeywords = ResourceReader.readResourceAsStringList("/keywords/hyper_sql_lexer_keywords.txt"); + val difference = hyperSqlLexerKeywords.stream() + .map(String::toUpperCase) + .distinct() + .filter(keyword -> !SQL_2003_KEYWORDS.contains(keyword)) + .sorted() + .collect(Collectors.toList()); + return String.join(",", difference); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/config/QueryResources.java b/src/main/java/com/salesforce/datacloud/jdbc/config/QueryResources.java new file mode 100644 index 0000000..f9aea78 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/config/QueryResources.java @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.config; + +import lombok.Getter; +import lombok.experimental.UtilityClass; + +@UtilityClass +public class QueryResources { + @Getter(lazy = true) + private final String columnsQuery = loadQuery("get_columns_query"); + + @Getter(lazy = true) + private final String schemasQuery = loadQuery("get_schemas_query"); + + @Getter(lazy = true) + private final String tablesQuery = loadQuery("get_tables_query"); + + private static String loadQuery(String name) { + return ResourceReader.readResourceAsString("/sql/" + name + ".sql"); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/config/ResourceReader.java b/src/main/java/com/salesforce/datacloud/jdbc/config/ResourceReader.java new file mode 100644 index 0000000..0974a78 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/config/ResourceReader.java @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.config; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.SqlErrorCodes; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import lombok.NonNull; +import lombok.SneakyThrows; +import lombok.experimental.UtilityClass; +import lombok.extern.slf4j.Slf4j; +import lombok.val; + +@Slf4j +@UtilityClass +public class ResourceReader { + public static String readResourceAsString(@NonNull String path) { + val result = new AtomicReference(); + withResourceAsStream(path, in -> result.set(new String(in.readAllBytes(), StandardCharsets.UTF_8))); + return result.get(); + } + + @SneakyThrows + public static Properties readResourceAsProperties(@NonNull String path) { + val result = new Properties(); + withResourceAsStream(path, result::load); + return result; + } + + @SneakyThrows + static void withResourceAsStream(String path, @NonNull IOExceptionThrowingConsumer consumer) { + try (val in = ResourceReader.class.getResourceAsStream(path)) { + if (in == null) { + val message = String.format("%s. path=%s", NOT_FOUND_MESSAGE, path); + throw new DataCloudJDBCException(message, SqlErrorCodes.UNDEFINED_FILE); + } + + consumer.accept(in); + } catch (IOException e) { + val message = String.format("%s. 
path=%s", IO_EXCEPTION_MESSAGE, path); + log.error(message, e); + throw new DataCloudJDBCException(message, SqlErrorCodes.UNDEFINED_FILE, e); + } + } + + public static List readResourceAsStringList(String path) { + return Arrays.stream(readResourceAsString(path).split("\n")) + .map(String::trim) + .collect(Collectors.toList()); + } + + private static final String NOT_FOUND_MESSAGE = "Resource file not found"; + private static final String IO_EXCEPTION_MESSAGE = "Error while loading resource file"; + + @FunctionalInterface + public interface IOExceptionThrowingConsumer { + void accept(T t) throws IOException; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/ArrowStreamReaderCursor.java b/src/main/java/com/salesforce/datacloud/jdbc/core/ArrowStreamReaderCursor.java new file mode 100644 index 0000000..576d877 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/ArrowStreamReaderCursor.java @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core; + +import static com.salesforce.datacloud.jdbc.util.ThrowingFunction.rethrowFunction; + +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorFactory; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.io.IOException; +import java.sql.SQLException; +import java.util.Calendar; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import lombok.AllArgsConstructor; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.calcite.avatica.ColumnMetaData; +import org.apache.calcite.avatica.util.AbstractCursor; +import org.apache.calcite.avatica.util.ArrayImpl; + +@AllArgsConstructor +@Slf4j +class ArrowStreamReaderCursor extends AbstractCursor { + + private static final int INIT_ROW_NUMBER = -1; + + private final ArrowStreamReader reader; + + private final AtomicInteger currentRow = new AtomicInteger(INIT_ROW_NUMBER); + + private void wasNullConsumer(boolean wasNull) { + this.wasNull[0] = wasNull; + } + + @SneakyThrows + private VectorSchemaRoot getSchemaRoot() { + return reader.getVectorSchemaRoot(); + } + + @Override + @SneakyThrows + public List createAccessors( + List types, Calendar localCalendar, ArrayImpl.Factory factory) { + return getSchemaRoot().getFieldVectors().stream() + .map(rethrowFunction(this::createAccessor)) + .collect(Collectors.toList()); + } + + private Accessor createAccessor(FieldVector vector) throws SQLException { + return QueryJDBCAccessorFactory.createAccessor(vector, currentRow::get, this::wasNullConsumer); + } + + private boolean loadNextBatch() throws SQLException { + try { + if (reader.loadNextBatch()) { + currentRow.set(0); + return true; + } + } catch (IOException e) { + throw new 
DataCloudJDBCException(e); + } + return false; + } + + @SneakyThrows + @Override + public boolean next() { + val current = currentRow.incrementAndGet(); + val total = getSchemaRoot().getRowCount(); + + try { + return current < total || loadNextBatch(); + } catch (Exception e) { + throw new DataCloudJDBCException("Failed to load next batch", e); + } + } + + @Override + protected Getter createGetter(int i) { + throw new UnsupportedOperationException(); + } + + @SneakyThrows + @Override + public void close() { + reader.close(); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudConnection.java b/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudConnection.java new file mode 100644 index 0000000..d08b553 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudConnection.java @@ -0,0 +1,468 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core; + +import static com.salesforce.datacloud.jdbc.util.Constants.CONNECTION_PROTOCOL; +import static com.salesforce.datacloud.jdbc.util.Constants.LOGIN_URL; +import static com.salesforce.datacloud.jdbc.util.Constants.USER; +import static com.salesforce.datacloud.jdbc.util.Constants.USER_NAME; + +import com.salesforce.datacloud.jdbc.auth.AuthenticationSettings; +import com.salesforce.datacloud.jdbc.auth.DataCloudTokenProcessor; +import com.salesforce.datacloud.jdbc.auth.TokenProcessor; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.http.ClientBuilder; +import com.salesforce.datacloud.jdbc.interceptor.AuthorizationHeaderInterceptor; +import com.salesforce.datacloud.jdbc.interceptor.DataspaceHeaderInterceptor; +import com.salesforce.datacloud.jdbc.interceptor.HyperDefaultsHeaderInterceptor; +import com.salesforce.datacloud.jdbc.interceptor.TracingHeadersInterceptor; +import com.salesforce.datacloud.jdbc.interceptor.UserAgentHeaderInterceptor; +import com.salesforce.datacloud.jdbc.util.Messages; +import io.grpc.ClientInterceptor; +import io.grpc.ManagedChannelBuilder; +import java.net.URI; +import java.sql.Array; +import java.sql.Blob; +import java.sql.CallableStatement; +import java.sql.Clob; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.NClob; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLWarning; +import java.sql.SQLXML; +import java.sql.Savepoint; +import java.sql.Statement; +import java.sql.Struct; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Properties; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import lombok.AccessLevel; +import 
lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.apache.commons.lang3.StringUtils; + +@Slf4j +@Getter +@Builder(access = AccessLevel.PACKAGE) +public class DataCloudConnection implements Connection, AutoCloseable { + private static final int DEFAULT_PORT = 443; + + private final AtomicBoolean closed = new AtomicBoolean(false); + + private final TokenProcessor tokenProcessor; + + @NonNull @Builder.Default + private final Properties properties = new Properties(); + + @Setter + @Builder.Default + private List interceptors = new ArrayList<>(); + + @NonNull private final HyperGrpcClientExecutor executor; + + public static DataCloudConnection fromChannel(@NonNull ManagedChannelBuilder builder, Properties properties) + throws SQLException { + val interceptors = getClientInterceptors(null, properties); + val executor = HyperGrpcClientExecutor.of(builder.intercept(interceptors), properties); + + return DataCloudConnection.builder() + .executor(executor) + .properties(properties) + .build(); + } + + /** This flow is not supported by the JDBC Driver Manager, only use it if you know what you're doing. */ + public static DataCloudConnection fromTokenSupplier( + AuthorizationHeaderInterceptor authInterceptor, @NonNull String host, int port, Properties properties) + throws SQLException { + val channel = ManagedChannelBuilder.forAddress(host, port); + return fromTokenSupplier(authInterceptor, channel, properties); + } + + /** This flow is not supported by the JDBC Driver Manager, only use it if you know what you're doing. 
*/ + public static DataCloudConnection fromTokenSupplier( + AuthorizationHeaderInterceptor authInterceptor, + @NonNull ManagedChannelBuilder builder, + Properties properties) + throws SQLException { + val interceptors = getClientInterceptors(authInterceptor, properties); + val executor = HyperGrpcClientExecutor.of(builder.intercept(interceptors), properties); + + return DataCloudConnection.builder() + .executor(executor) + .properties(properties) + .build(); + } + + static List getClientInterceptors( + AuthorizationHeaderInterceptor authInterceptor, Properties properties) { + return Stream.of( + authInterceptor, + new HyperDefaultsHeaderInterceptor(), + TracingHeadersInterceptor.of(), + UserAgentHeaderInterceptor.of(properties), + DataspaceHeaderInterceptor.of(properties)) + .filter(Objects::nonNull) + .peek(t -> log.info("Registering interceptor. interceptor={}", t)) + .collect(Collectors.toList()); + } + + public static DataCloudConnection of(String url, Properties properties) throws SQLException { + var serviceRootUrl = getServiceRootUrl(url); + properties.put(LOGIN_URL, serviceRootUrl); + addClientUsernameIfRequired(properties); + + if (!AuthenticationSettings.hasAny(properties)) { + throw new DataCloudJDBCException("No authentication settings provided"); + } + + val tokenProcessor = DataCloudTokenProcessor.of(properties); + + val host = tokenProcessor.getDataCloudToken().getTenantUrl(); + + val builder = ManagedChannelBuilder.forAddress(host, DEFAULT_PORT); + val authInterceptor = AuthorizationHeaderInterceptor.of(tokenProcessor); + + val interceptors = getClientInterceptors(authInterceptor, properties); + val executor = HyperGrpcClientExecutor.of(builder.intercept(interceptors), properties); + + return DataCloudConnection.builder() + .tokenProcessor(tokenProcessor) + .executor(executor) + .properties(properties) + .build(); + } + + @Override + public Statement createStatement() { + return createStatement(ResultSet.TYPE_FORWARD_ONLY, 
ResultSet.CONCUR_READ_ONLY); + } + + @Override + public PreparedStatement prepareStatement(String sql) { + return getQueryPreparedStatement(sql); + } + + private DataCloudPreparedStatement getQueryPreparedStatement(String sql) { + return new DataCloudPreparedStatement(this, sql, new DefaultParameterManager()); + } + + @Override + public CallableStatement prepareCall(String sql) { + return null; + } + + @Override + public String nativeSQL(String sql) { + return sql; + } + + @Override + public void setAutoCommit(boolean autoCommit) {} + + @Override + public boolean getAutoCommit() { + return false; + } + + @Override + public void commit() {} + + @Override + public void rollback() {} + + @Override + public void close() { + try { + if (closed.compareAndSet(false, true)) { + executor.close(); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public boolean isClosed() { + return closed.get(); + } + + @Override + public DatabaseMetaData getMetaData() { + val client = ClientBuilder.buildOkHttpClient(properties); + val userName = this.properties.getProperty("userName"); + val loginUrl = this.properties.getProperty("loginURL"); + return new DataCloudDatabaseMetadata( + getQueryStatement(), Optional.ofNullable(tokenProcessor), client, loginUrl, userName); + } + + private @NonNull DataCloudStatement getQueryStatement() { + return new DataCloudStatement(this); + } + + @Override + public void setReadOnly(boolean readOnly) {} + + @Override + public boolean isReadOnly() { + return true; + } + + @Override + public void setCatalog(String catalog) {} + + @Override + public String getCatalog() { + return ""; + } + + @Override + public void setTransactionIsolation(int level) {} + + @Override + public int getTransactionIsolation() { + return Connection.TRANSACTION_NONE; + } + + @Override + public SQLWarning getWarnings() { + return null; + } + + @Override + public void clearWarnings() {} + + @Override + public Statement createStatement(int 
resultSetType, int resultSetConcurrency) { + return getQueryStatement(); + } + + @Override + public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) { + return getQueryPreparedStatement(sql); + } + + @Override + public CallableStatement prepareCall(String sql, int resultSetType, int resultSetConcurrency) { + return null; + } + + @Override + public Map> getTypeMap() { + return null; + } + + @Override + public void setTypeMap(Map> map) {} + + @Override + public void setHoldability(int holdability) {} + + @Override + public int getHoldability() { + return 0; + } + + @Override + public Savepoint setSavepoint() { + return null; + } + + @Override + public Savepoint setSavepoint(String name) { + return null; + } + + @Override + public void rollback(Savepoint savepoint) {} + + @Override + public void releaseSavepoint(Savepoint savepoint) {} + + @Override + public Statement createStatement(int resultSetType, int resultSetConcurrency, int resultSetHoldability) { + return null; + } + + @Override + public PreparedStatement prepareStatement( + String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) { + return getQueryPreparedStatement(sql); + } + + @Override + public CallableStatement prepareCall( + String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) { + return null; + } + + @Override + public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) { + return null; + } + + @Override + public PreparedStatement prepareStatement(String sql, int[] columnIndexes) { + return null; + } + + @Override + public PreparedStatement prepareStatement(String sql, String[] columnNames) { + return null; + } + + @Override + public Clob createClob() { + return null; + } + + @Override + public Blob createBlob() { + return null; + } + + @Override + public NClob createNClob() { + return null; + } + + @Override + public SQLXML createSQLXML() { + return null; + } + + @Override + public 
boolean isValid(int timeout) throws SQLException { + if (timeout < 0) { + throw new DataCloudJDBCException(String.format("Invalid timeout value: %d", timeout)); + } + return !isClosed(); + } + + @Override + public void setClientInfo(String name, String value) {} + + @Override + public void setClientInfo(Properties properties) {} + + @Override + public String getClientInfo(String name) { + return ""; + } + + @Override + public Properties getClientInfo() { + return properties; + } + + @Override + public Array createArrayOf(String typeName, Object[] elements) { + return null; + } + + @Override + public Struct createStruct(String typeName, Object[] attributes) { + return null; + } + + @Override + public void setSchema(String schema) {} + + @Override + public String getSchema() { + return ""; + } + + @Override + public void abort(Executor executor) {} + + @Override + public void setNetworkTimeout(Executor executor, int milliseconds) {} + + @Override + public int getNetworkTimeout() { + return 0; + } + + @Override + public T unwrap(Class iface) throws SQLException { + if (!iface.isInstance(this)) { + throw new DataCloudJDBCException(this.getClass().getName() + " not unwrappable from " + iface.getName()); + } + return (T) this; + } + + @Override + public boolean isWrapperFor(Class iface) { + return iface.isInstance(this); + } + + public static boolean acceptsUrl(String url) { + return url != null && url.startsWith(CONNECTION_PROTOCOL) && urlDoesNotContainScheme(url); + } + + private static boolean urlDoesNotContainScheme(String url) { + val suffix = url.substring(CONNECTION_PROTOCOL.length()); + return !suffix.startsWith("http://") && !suffix.startsWith("https://"); + } + + /** + * Returns the extracted service url from given jdbc endpoint + * + * @param url of the form jdbc:salesforce-datacloud://login.salesforce.com + * @return service url + * @throws SQLException when given url doesn't belong with required datasource + */ + static String getServiceRootUrl(String url) 
throws SQLException { + if (!acceptsUrl(url)) { + throw new DataCloudJDBCException(Messages.ILLEGAL_CONNECTION_PROTOCOL); + } + + val serviceRootUrl = url.substring(CONNECTION_PROTOCOL.length()); + val noTrailingSlash = StringUtils.removeEnd(serviceRootUrl, "/"); + val host = StringUtils.removeStart(noTrailingSlash, "//"); + + return host.isBlank() ? host : createURI(host).toString(); + } + + private static URI createURI(String host) throws SQLException { + try { + return URI.create("https://" + host); + } catch (IllegalArgumentException e) { + throw new DataCloudJDBCException(Messages.ILLEGAL_CONNECTION_PROTOCOL, e); + } + } + + static void addClientUsernameIfRequired(Properties properties) { + if (properties.containsKey(USER)) { + properties.computeIfAbsent(USER_NAME, p -> properties.get(USER)); + } + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudDatabaseMetadata.java b/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudDatabaseMetadata.java new file mode 100644 index 0000000..4d7cedc --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudDatabaseMetadata.java @@ -0,0 +1,967 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core; + +import com.salesforce.datacloud.jdbc.auth.TokenProcessor; +import com.salesforce.datacloud.jdbc.config.KeywordResources; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.Constants; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.RowIdLifetime; +import java.sql.SQLException; +import java.util.List; +import java.util.Optional; +import lombok.extern.slf4j.Slf4j; +import okhttp3.OkHttpClient; + +@Slf4j +public class DataCloudDatabaseMetadata implements DatabaseMetaData { + private final DataCloudStatement dataCloudStatement; + private final Optional tokenProcessor; + private final OkHttpClient client; + private final String loginURL; + private final String userName; + + public DataCloudDatabaseMetadata( + DataCloudStatement dataCloudStatement, + Optional tokenProcessor, + OkHttpClient client, + String loginURL, + String userName) { + this.dataCloudStatement = dataCloudStatement; + this.tokenProcessor = tokenProcessor; + this.client = client; + this.loginURL = loginURL; + this.userName = userName; + } + + @Override + public boolean allProceduresAreCallable() { + return false; + } + + @Override + public boolean allTablesAreSelectable() { + return true; + } + + @Override + public String getURL() { + return loginURL; + } + + @Override + public String getUserName() { + return userName; + } + + @Override + public boolean isReadOnly() { + return true; + } + + @Override + public boolean nullsAreSortedHigh() { + return false; + } + + @Override + public boolean nullsAreSortedLow() { + return true; + } + + @Override + public boolean nullsAreSortedAtStart() { + return false; + } + + @Override + public boolean nullsAreSortedAtEnd() { + return false; + } + + @Override + public String getDatabaseProductName() { + return Constants.DATABASE_PRODUCT_NAME; + } + + @Override + public String 
getDatabaseProductVersion() { + return Constants.DATABASE_PRODUCT_VERSION; + } + + @Override + public String getDriverName() { + return Constants.DRIVER_NAME; + } + + @Override + public String getDriverVersion() { + return Constants.DRIVER_VERSION; + } + + @Override + public int getDriverMajorVersion() { + return 1; + } + + @Override + public int getDriverMinorVersion() { + return 0; + } + + @Override + public boolean usesLocalFiles() { + return false; + } + + @Override + public boolean usesLocalFilePerTable() { + return false; + } + + @Override + public boolean supportsMixedCaseIdentifiers() { + return false; + } + + @Override + public boolean storesUpperCaseIdentifiers() { + return false; + } + + @Override + public boolean storesLowerCaseIdentifiers() { + return true; + } + + @Override + public boolean storesMixedCaseIdentifiers() { + return false; + } + + @Override + public boolean supportsMixedCaseQuotedIdentifiers() { + return true; + } + + @Override + public boolean storesUpperCaseQuotedIdentifiers() { + return false; + } + + @Override + public boolean storesLowerCaseQuotedIdentifiers() { + return false; + } + + @Override + public boolean storesMixedCaseQuotedIdentifiers() { + return false; + } + + @Override + public String getIdentifierQuoteString() { + return "\""; + } + + @Override + public String getSQLKeywords() { + return KeywordResources.getSqlKeywords(); + } + + @Override + public String getNumericFunctions() { + return null; + } + + @Override + public String getStringFunctions() { + return null; + } + + @Override + public String getSystemFunctions() { + return null; + } + + @Override + public String getTimeDateFunctions() { + return null; + } + + @Override + public String getSearchStringEscape() { + return "\\"; + } + + @Override + public String getExtraNameCharacters() { + return null; + } + + @Override + public boolean supportsAlterTableWithAddColumn() { + return false; + } + + @Override + public boolean supportsAlterTableWithDropColumn() { + 
return false; + } + + @Override + public boolean supportsColumnAliasing() { + return true; + } + + @Override + public boolean nullPlusNonNullIsNull() { + return false; + } + + @Override + public boolean supportsConvert() { + return true; + } + + @Override + public boolean supportsConvert(int fromType, int toType) { + return true; + } + + @Override + public boolean supportsTableCorrelationNames() { + return true; + } + + @Override + public boolean supportsDifferentTableCorrelationNames() { + return false; + } + + @Override + public boolean supportsExpressionsInOrderBy() { + return true; + } + + @Override + public boolean supportsOrderByUnrelated() { + return true; + } + + @Override + public boolean supportsGroupBy() { + return true; + } + + @Override + public boolean supportsGroupByUnrelated() { + return true; + } + + @Override + public boolean supportsGroupByBeyondSelect() { + return true; + } + + @Override + public boolean supportsLikeEscapeClause() { + return true; + } + + @Override + public boolean supportsMultipleResultSets() { + return false; + } + + @Override + public boolean supportsMultipleTransactions() { + return false; + } + + @Override + public boolean supportsNonNullableColumns() { + return true; + } + + /** + * {@inheritDoc} + * + *

This grammar is defined at: + * http://www.microsoft.com/msdn/sdk/platforms/doc/odbc/src/intropr.htm + * + *

In Appendix C. From this description, we seem to support the ODBC minimal (Level 0) grammar. + * + * @return true + */ + @Override + public boolean supportsMinimumSQLGrammar() { + return true; + } + + @Override + public boolean supportsCoreSQLGrammar() { + return false; + } + + @Override + public boolean supportsExtendedSQLGrammar() { + return false; + } + + @Override + public boolean supportsANSI92EntryLevelSQL() { + return true; + } + + @Override + public boolean supportsANSI92IntermediateSQL() { + return true; + } + + @Override + public boolean supportsANSI92FullSQL() { + return true; + } + + @Override + public boolean supportsIntegrityEnhancementFacility() { + return false; + } + + @Override + public boolean supportsOuterJoins() { + return true; + } + + @Override + public boolean supportsFullOuterJoins() { + return true; + } + + @Override + public boolean supportsLimitedOuterJoins() { + return true; + } + + @Override + public String getSchemaTerm() { + return "schema"; + } + + @Override + public String getProcedureTerm() { + return "procedure"; + } + + @Override + public String getCatalogTerm() { + return "database"; + } + + @Override + public boolean isCatalogAtStart() { + return true; + } + + @Override + public String getCatalogSeparator() { + return "."; + } + + @Override + public boolean supportsSchemasInDataManipulation() { + return false; + } + + @Override + public boolean supportsSchemasInProcedureCalls() { + return false; + } + + @Override + public boolean supportsSchemasInTableDefinitions() { + return false; + } + + @Override + public boolean supportsSchemasInIndexDefinitions() { + return false; + } + + @Override + public boolean supportsSchemasInPrivilegeDefinitions() { + return false; + } + + @Override + public boolean supportsCatalogsInDataManipulation() { + return false; + } + + @Override + public boolean supportsCatalogsInProcedureCalls() { + return false; + } + + @Override + public boolean supportsCatalogsInTableDefinitions() { + return false; + 
} + + @Override + public boolean supportsCatalogsInIndexDefinitions() { + return false; + } + + @Override + public boolean supportsCatalogsInPrivilegeDefinitions() { + return false; + } + + @Override + public boolean supportsPositionedDelete() { + return false; + } + + @Override + public boolean supportsPositionedUpdate() { + return false; + } + + @Override + public boolean supportsSelectForUpdate() { + return false; + } + + @Override + public boolean supportsStoredProcedures() { + return false; + } + + @Override + public boolean supportsSubqueriesInComparisons() { + return true; + } + + @Override + public boolean supportsSubqueriesInExists() { + return true; + } + + @Override + public boolean supportsSubqueriesInIns() { + return true; + } + + @Override + public boolean supportsSubqueriesInQuantifieds() { + return true; + } + + @Override + public boolean supportsCorrelatedSubqueries() { + return true; + } + + @Override + public boolean supportsUnion() { + return true; + } + + @Override + public boolean supportsUnionAll() { + return true; + } + + @Override + public boolean supportsOpenCursorsAcrossCommit() { + return false; + } + + @Override + public boolean supportsOpenCursorsAcrossRollback() { + return false; + } + + @Override + public boolean supportsOpenStatementsAcrossCommit() { + return false; + } + + @Override + public boolean supportsOpenStatementsAcrossRollback() { + return false; + } + + @Override + public int getMaxBinaryLiteralLength() { + return 0; + } + + @Override + public int getMaxCharLiteralLength() { + return 0; + } + + @Override + public int getMaxColumnNameLength() { + return 0; + } + + @Override + public int getMaxColumnsInGroupBy() { + return 0; + } + + @Override + public int getMaxColumnsInIndex() { + return 0; + } + + @Override + public int getMaxColumnsInOrderBy() { + return 0; + } + + @Override + public int getMaxColumnsInSelect() { + return 0; + } + + @Override + public int getMaxColumnsInTable() { + return 0; + } + + @Override + public 
int getMaxConnections() { + return 0; + } + + @Override + public int getMaxCursorNameLength() { + return 0; + } + + @Override + public int getMaxIndexLength() { + return 0; + } + + @Override + public int getMaxSchemaNameLength() { + return 0; + } + + @Override + public int getMaxProcedureNameLength() { + return 0; + } + + @Override + public int getMaxCatalogNameLength() { + return 0; + } + + @Override + public int getMaxRowSize() { + return 0; + } + + @Override + public boolean doesMaxRowSizeIncludeBlobs() { + return false; + } + + @Override + public int getMaxStatementLength() { + return 0; + } + + @Override + public int getMaxStatements() { + return 0; + } + + @Override + public int getMaxTableNameLength() { + return 0; + } + + @Override + public int getMaxTablesInSelect() { + return 0; + } + + @Override + public int getMaxUserNameLength() { + return 0; + } + + @Override + public int getDefaultTransactionIsolation() { + return Connection.TRANSACTION_SERIALIZABLE; + } + + @Override + public boolean supportsTransactions() { + return false; + } + + @Override + public boolean supportsTransactionIsolationLevel(int level) { + return false; + } + + @Override + public boolean supportsDataDefinitionAndDataManipulationTransactions() { + return false; + } + + @Override + public boolean supportsDataManipulationTransactionsOnly() { + return false; + } + + @Override + public boolean dataDefinitionCausesTransactionCommit() { + return false; + } + + @Override + public boolean dataDefinitionIgnoredInTransactions() { + return false; + } + + @Override + public ResultSet getProcedures(String catalog, String schemaPattern, String procedureNamePattern) { + return null; + } + + @Override + public ResultSet getProcedureColumns( + String catalog, String schemaPattern, String procedureNamePattern, String columnNamePattern) { + return null; + } + + @Override + public ResultSet getTables(String catalog, String schemaPattern, String tableNamePattern, String[] types) + throws SQLException { + 
return QueryMetadataUtil.createTableResultSet(schemaPattern, tableNamePattern, types, dataCloudStatement); + } + + @Override + public ResultSet getSchemas() throws SQLException { + return QueryMetadataUtil.createSchemaResultSet(null, dataCloudStatement); + } + + @Override + public ResultSet getCatalogs() throws SQLException { + return QueryMetadataUtil.createCatalogsResultSet(tokenProcessor, client); + } + + @Override + public ResultSet getTableTypes() throws SQLException { + return QueryMetadataUtil.createTableTypesResultSet(); + } + + @Override + public ResultSet getColumns(String catalog, String schemaPattern, String tableNamePattern, String columnNamePattern) + throws SQLException { + return QueryMetadataUtil.createColumnResultSet( + schemaPattern, tableNamePattern, columnNamePattern, dataCloudStatement); + } + + @Override + public ResultSet getColumnPrivileges(String catalog, String schema, String table, String columnNamePattern) + throws SQLException { + return MetadataResultSet.of(); + } + + @Override + public ResultSet getTablePrivileges(String catalog, String schemaPattern, String tableNamePattern) + throws SQLException { + return MetadataResultSet.of(); + } + + @Override + public ResultSet getBestRowIdentifier(String catalog, String schema, String table, int scope, boolean nullable) + throws SQLException { + return MetadataResultSet.of(); + } + + @Override + public ResultSet getVersionColumns(String catalog, String schema, String table) throws SQLException { + return MetadataResultSet.of(); + } + + @Override + public ResultSet getPrimaryKeys(String catalog, String schema, String table) throws SQLException { + return MetadataResultSet.of(); + } + + @Override + public ResultSet getImportedKeys(String catalog, String schema, String table) throws SQLException { + return MetadataResultSet.of(); + } + + @Override + public ResultSet getExportedKeys(String catalog, String schema, String table) throws SQLException { + return MetadataResultSet.of(); + } + + 
@Override + public ResultSet getCrossReference( + String parentCatalog, + String parentSchema, + String parentTable, + String foreignCatalog, + String foreignSchema, + String foreignTable) + throws SQLException { + return MetadataResultSet.of(); + } + + @Override + public ResultSet getTypeInfo() throws SQLException { + return MetadataResultSet.of(); + } + + @Override + public ResultSet getIndexInfo(String catalog, String schema, String table, boolean unique, boolean approximate) + throws SQLException { + return MetadataResultSet.of(); + } + + @Override + public boolean supportsResultSetType(int type) { + return ResultSet.TYPE_FORWARD_ONLY == type; + } + + @Override + public boolean supportsResultSetConcurrency(int type, int concurrency) { + return ResultSet.TYPE_FORWARD_ONLY == type && ResultSet.CONCUR_READ_ONLY == concurrency; + } + + @Override + public boolean ownUpdatesAreVisible(int type) { + return false; + } + + @Override + public boolean ownDeletesAreVisible(int type) { + return false; + } + + @Override + public boolean ownInsertsAreVisible(int type) { + return false; + } + + @Override + public boolean othersUpdatesAreVisible(int type) { + return false; + } + + @Override + public boolean othersDeletesAreVisible(int type) { + return false; + } + + @Override + public boolean othersInsertsAreVisible(int type) { + return false; + } + + @Override + public boolean updatesAreDetected(int type) { + return false; + } + + @Override + public boolean deletesAreDetected(int type) { + return false; + } + + @Override + public boolean insertsAreDetected(int type) { + return false; + } + + @Override + public boolean supportsBatchUpdates() { + return false; + } + + @Override + public ResultSet getUDTs(String catalog, String schemaPattern, String typeNamePattern, int[] types) { + return null; + } + + @Override + public Connection getConnection() { + return dataCloudStatement.getConnection(); + } + + @Override + public boolean supportsSavepoints() { + return false; + } + + 
@Override + public boolean supportsNamedParameters() { + return false; + } + + @Override + public boolean supportsMultipleOpenResults() { + return false; + } + + @Override + public boolean supportsGetGeneratedKeys() { + return false; + } + + @Override + public ResultSet getSuperTypes(String catalog, String schemaPattern, String typeNamePattern) { + return null; + } + + @Override + public ResultSet getSuperTables(String catalog, String schemaPattern, String tableNamePattern) { + return null; + } + + @Override + public ResultSet getAttributes( + String catalog, String schemaPattern, String typeNamePattern, String attributeNamePattern) { + return null; + } + + @Override + public boolean supportsResultSetHoldability(int holdability) { + return false; + } + + @Override + public int getResultSetHoldability() { + return 0; + } + + @Override + public int getDatabaseMajorVersion() { + return 1; + } + + @Override + public int getDatabaseMinorVersion() { + return 0; + } + + @Override + public int getJDBCMajorVersion() { + return 1; + } + + @Override + public int getJDBCMinorVersion() { + return 0; + } + + @Override + public int getSQLStateType() { + return 0; + } + + @Override + public boolean locatorsUpdateCopy() { + return false; + } + + @Override + public boolean supportsStatementPooling() { + return false; + } + + @Override + public RowIdLifetime getRowIdLifetime() { + return null; + } + + @Override + public ResultSet getSchemas(String catalog, String schemaPattern) throws SQLException { + return QueryMetadataUtil.createSchemaResultSet(schemaPattern, dataCloudStatement); + } + + @Override + public boolean supportsStoredFunctionsUsingCallSyntax() { + return false; + } + + @Override + public boolean autoCommitFailureClosesAllResultSets() { + return false; + } + + @Override + public ResultSet getClientInfoProperties() { + return null; + } + + @Override + public ResultSet getFunctions(String catalog, String schemaPattern, String functionNamePattern) { + return null; + } + + 
@Override + public ResultSet getFunctionColumns( + String catalog, String schemaPattern, String functionNamePattern, String columnNamePattern) { + return null; + } + + @Override + public ResultSet getPseudoColumns( + String catalog, String schemaPattern, String tableNamePattern, String columnNamePattern) { + return null; + } + + @Override + public boolean generatedKeyAlwaysReturned() { + return false; + } + + public List getDataspaces() throws SQLException { + return QueryMetadataUtil.createDataspacesResponse(tokenProcessor, client); + } + + @Override + public T unwrap(Class iface) throws SQLException { + if (!iface.isInstance(this)) { + throw new DataCloudJDBCException(this.getClass().getName() + " not unwrappable from " + iface.getName()); + } + return (T) this; + } + + @Override + public boolean isWrapperFor(Class iface) { + return iface.isInstance(this); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatement.java b/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatement.java new file mode 100644 index 0000000..ae3349a --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatement.java @@ -0,0 +1,460 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core; + +import static com.salesforce.datacloud.jdbc.util.DateTimeUtils.getUTCDateFromDateAndCalendar; +import static com.salesforce.datacloud.jdbc.util.DateTimeUtils.getUTCTimeFromTimeAndCalendar; +import static com.salesforce.datacloud.jdbc.util.DateTimeUtils.getUTCTimestampFromTimestampAndCalendar; +import static com.salesforce.datacloud.jdbc.util.PropertiesExtensions.optional; + +import com.google.protobuf.ByteString; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.ArrowUtils; +import com.salesforce.datacloud.jdbc.util.Constants; +import com.salesforce.datacloud.jdbc.util.SqlErrorCodes; +import com.salesforce.hyperdb.grpc.QueryParam; +import com.salesforce.hyperdb.grpc.QueryParameterArrow; +import java.io.IOException; +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.net.URL; +import java.sql.Array; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Date; +import java.sql.NClob; +import java.sql.ParameterMetaData; +import java.sql.PreparedStatement; +import java.sql.Ref; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.RowId; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.sql.SQLXML; +import java.sql.Time; +import java.sql.Timestamp; +import java.sql.Types; +import java.time.Duration; +import java.util.Calendar; +import java.util.Map; +import java.util.TimeZone; +import lombok.experimental.UtilityClass; +import lombok.extern.slf4j.Slf4j; +import lombok.val; + +@Slf4j +public class DataCloudPreparedStatement extends DataCloudStatement implements PreparedStatement { + private String sql; + private final ParameterManager parameterManager; + private final Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("UTC")); + + DataCloudPreparedStatement(DataCloudConnection connection, ParameterManager parameterManager) 
{ + super(connection); + this.parameterManager = parameterManager; + } + + DataCloudPreparedStatement(DataCloudConnection connection, String sql, ParameterManager parameterManager) { + super(connection); + this.sql = sql; + this.parameterManager = parameterManager; + } + + private void setParameter(int parameterIndex, int sqlType, T value) throws SQLException { + try { + parameterManager.setParameter(parameterIndex, sqlType, value); + } catch (SQLException e) { + throw new DataCloudJDBCException(e); + } + } + + @Override + public ResultSet executeQuery(String sql) throws SQLException { + this.sql = sql; + return executeQuery(); + } + + @Override + public boolean execute(String sql) throws SQLException { + resultSet = executeQuery(sql); + return true; + } + + @Override + public ResultSet executeQuery() throws SQLException { + final byte[] encodedRow; + try { + encodedRow = ArrowUtils.toArrowByteArray(parameterManager.getParameters(), calendar); + } catch (IOException e) { + throw new DataCloudJDBCException("Failed to encode parameters on prepared statement", e); + } + + val queryParamBuilder = QueryParam.newBuilder() + .setParamStyle(QueryParam.ParameterStyle.QUESTION_MARK) + .setArrowParameters(QueryParameterArrow.newBuilder() + .setData(ByteString.copyFrom(encodedRow)) + .build()) + .build(); + + val client = getQueryExecutor(queryParamBuilder); + val timeout = Duration.ofSeconds(getQueryTimeout()); + + val useSync = optional(this.dataCloudConnection.getProperties(), Constants.FORCE_SYNC) + .map(Boolean::parseBoolean) + .orElse(false); + resultSet = useSync ? 
executeSyncQuery(sql, client) : executeAdaptiveQuery(sql, client, timeout); + return resultSet; + } + + @Override + public int executeUpdate() throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setNull(int parameterIndex, int sqlType) throws SQLException { + setParameter(parameterIndex, sqlType, null); + } + + @Override + public void setBoolean(int parameterIndex, boolean x) throws SQLException { + setParameter(parameterIndex, Types.BOOLEAN, x); + } + + @Override + public void setByte(int parameterIndex, byte x) throws SQLException { + setParameter(parameterIndex, Types.TINYINT, x); + } + + @Override + public void setShort(int parameterIndex, short x) throws SQLException { + setParameter(parameterIndex, Types.SMALLINT, x); + } + + @Override + public void setInt(int parameterIndex, int x) throws SQLException { + setParameter(parameterIndex, Types.INTEGER, x); + } + + @Override + public void setLong(int parameterIndex, long x) throws SQLException { + setParameter(parameterIndex, Types.BIGINT, x); + } + + @Override + public void setFloat(int parameterIndex, float x) throws SQLException { + setParameter(parameterIndex, Types.FLOAT, x); + } + + @Override + public void setDouble(int parameterIndex, double x) throws SQLException { + setParameter(parameterIndex, Types.DOUBLE, x); + } + + @Override + public void setBigDecimal(int parameterIndex, BigDecimal x) throws SQLException { + setParameter(parameterIndex, Types.DECIMAL, x); + } + + @Override + public void setString(int parameterIndex, String x) throws SQLException { + setParameter(parameterIndex, Types.VARCHAR, x); + } + + @Override + public void setBytes(int parameterIndex, byte[] x) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setDate(int parameterIndex, Date x) throws SQLException { + 
setParameter(parameterIndex, Types.DATE, x); + } + + @Override + public void setTime(int parameterIndex, Time x) throws SQLException { + setParameter(parameterIndex, Types.TIME, x); + } + + @Override + public void setTimestamp(int parameterIndex, Timestamp x) throws SQLException { + setParameter(parameterIndex, Types.TIMESTAMP, x); + } + + @Override + public void setAsciiStream(int parameterIndex, InputStream x, int length) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setUnicodeStream(int parameterIndex, InputStream x, int length) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setBinaryStream(int parameterIndex, InputStream x, int length) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void clearParameters() { + parameterManager.clearParameters(); + } + + @Override + public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQLException { + if (x == null) { + setNull(parameterIndex, Types.NULL); + return; + } + setParameter(parameterIndex, targetSqlType, x); + } + + @Override + public void setObject(int parameterIndex, Object x) throws SQLException { + if (x == null) { + setNull(parameterIndex, Types.NULL); + return; + } + + TypeHandler handler = TypeHandlers.typeHandlerMap.get(x.getClass()); + if (handler != null) { + try { + handler.setParameter(this, parameterIndex, x); + } catch (SQLException e) { + throw new DataCloudJDBCException(e); + } + } else { + String message = "Object type not supported for: " + x.getClass().getSimpleName() + " (value: " + x + ")"; + throw new DataCloudJDBCException(new SQLFeatureNotSupportedException(message)); + } + } + + @Override + public boolean execute() throws SQLException { + 
resultSet = executeQuery(); + return true; + } + + @Override + public void addBatch() throws SQLException { + throw new DataCloudJDBCException(BATCH_EXECUTION_IS_NOT_SUPPORTED, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setCharacterStream(int parameterIndex, Reader reader, int length) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setRef(int parameterIndex, Ref x) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setBlob(int parameterIndex, Blob x) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setClob(int parameterIndex, Clob x) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setArray(int parameterIndex, Array x) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public ResultSetMetaData getMetaData() throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setDate(int parameterIndex, Date x, Calendar cal) throws SQLException { + val utcDate = getUTCDateFromDateAndCalendar(x, cal); + setParameter(parameterIndex, Types.DATE, utcDate); + } + + @Override + public void setTime(int parameterIndex, Time x, Calendar cal) throws SQLException { + val utcTime = getUTCTimeFromTimeAndCalendar(x, cal); + setParameter(parameterIndex, Types.TIME, utcTime); + } + + @Override + public void setTimestamp(int parameterIndex, Timestamp x, Calendar cal) throws SQLException { + val utcTimestamp = 
getUTCTimestampFromTimestampAndCalendar(x, cal); + setParameter(parameterIndex, Types.TIMESTAMP, utcTimestamp); + } + + @Override + public void setNull(int parameterIndex, int sqlType, String typeName) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setURL(int parameterIndex, URL x) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public ParameterMetaData getParameterMetaData() throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setRowId(int parameterIndex, RowId x) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setNString(int parameterIndex, String value) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setNCharacterStream(int parameterIndex, Reader value, long length) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setNClob(int parameterIndex, NClob value) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setClob(int parameterIndex, Reader reader, long length) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setBlob(int parameterIndex, InputStream inputStream, long length) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, 
SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setNClob(int parameterIndex, Reader reader, long length) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setSQLXML(int parameterIndex, SQLXML xmlObject) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setObject(int parameterIndex, Object x, int targetSqlType, int scaleOrLength) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setAsciiStream(int parameterIndex, InputStream x, long length) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setBinaryStream(int parameterIndex, InputStream x, long length) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setCharacterStream(int parameterIndex, Reader reader, long length) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setAsciiStream(int parameterIndex, InputStream x) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setBinaryStream(int parameterIndex, InputStream x) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setCharacterStream(int parameterIndex, Reader reader) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, 
SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setNCharacterStream(int parameterIndex, Reader value) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setClob(int parameterIndex, Reader reader) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setBlob(int parameterIndex, InputStream inputStream) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void setNClob(int parameterIndex, Reader reader) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public T unwrap(Class iFace) throws SQLException { + if (iFace.isInstance(this)) { + return iFace.cast(this); + } + throw new DataCloudJDBCException("Cannot unwrap to " + iFace.getName()); + } + + @Override + public boolean isWrapperFor(Class iFace) { + return iFace.isInstance(this); + } +} + +@FunctionalInterface +interface TypeHandler { + void setParameter(PreparedStatement ps, int parameterIndex, Object value) throws SQLException; +} + +@UtilityClass +final class TypeHandlers { + public static final TypeHandler STRING_HANDLER = (ps, idx, value) -> ps.setString(idx, (String) value); + public static final TypeHandler BIGDECIMAL_HANDLER = (ps, idx, value) -> ps.setBigDecimal(idx, (BigDecimal) value); + public static final TypeHandler SHORT_HANDLER = (ps, idx, value) -> ps.setShort(idx, (Short) value); + public static final TypeHandler INTEGER_HANDLER = (ps, idx, value) -> ps.setInt(idx, (Integer) value); + public static final TypeHandler LONG_HANDLER = (ps, idx, value) -> ps.setLong(idx, (Long) value); + public static final TypeHandler FLOAT_HANDLER = (ps, idx, value) -> 
ps.setFloat(idx, (Float) value); + public static final TypeHandler DOUBLE_HANDLER = (ps, idx, value) -> ps.setDouble(idx, (Double) value); + public static final TypeHandler DATE_HANDLER = (ps, idx, value) -> ps.setDate(idx, (Date) value); + public static final TypeHandler TIME_HANDLER = (ps, idx, value) -> ps.setTime(idx, (Time) value); + public static final TypeHandler TIMESTAMP_HANDLER = (ps, idx, value) -> ps.setTimestamp(idx, (Timestamp) value); + public static final TypeHandler BOOLEAN_HANDLER = (ps, idx, value) -> ps.setBoolean(idx, (Boolean) value); + static final Map, TypeHandler> typeHandlerMap = Map.ofEntries( + Map.entry(String.class, STRING_HANDLER), + Map.entry(BigDecimal.class, BIGDECIMAL_HANDLER), + Map.entry(Short.class, SHORT_HANDLER), + Map.entry(Integer.class, INTEGER_HANDLER), + Map.entry(Long.class, LONG_HANDLER), + Map.entry(Float.class, FLOAT_HANDLER), + Map.entry(Double.class, DOUBLE_HANDLER), + Map.entry(Date.class, DATE_HANDLER), + Map.entry(Time.class, TIME_HANDLER), + Map.entry(Timestamp.class, TIMESTAMP_HANDLER), + Map.entry(Boolean.class, BOOLEAN_HANDLER)); +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudResultSet.java b/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudResultSet.java new file mode 100644 index 0000000..98e62b2 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudResultSet.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core; + +import java.sql.ResultSet; + +public interface DataCloudResultSet extends ResultSet { + String getQueryId(); + + String getStatus(); + + boolean isReady(); +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudStatement.java b/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudStatement.java new file mode 100644 index 0000000..41ceeed --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/DataCloudStatement.java @@ -0,0 +1,360 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package com.salesforce.datacloud.jdbc.core;

import static com.salesforce.datacloud.jdbc.util.PropertiesExtensions.getIntegerOrDefault;
import static com.salesforce.datacloud.jdbc.util.PropertiesExtensions.optional;

import com.salesforce.datacloud.jdbc.core.listener.AdaptiveQueryStatusListener;
import com.salesforce.datacloud.jdbc.core.listener.AsyncQueryStatusListener;
import com.salesforce.datacloud.jdbc.core.listener.QueryStatusListener;
import com.salesforce.datacloud.jdbc.core.listener.SyncQueryStatusListener;
import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException;
import com.salesforce.datacloud.jdbc.util.Constants;
import com.salesforce.datacloud.jdbc.util.SqlErrorCodes;
import com.salesforce.hyperdb.grpc.QueryParam;
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.SQLWarning;
import java.sql.Statement;
import java.time.Duration;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;

/**
 * Read-only JDBC {@link Statement} backed by the Hyper gRPC service.
 *
 * <p>Queries run in one of three modes: sync, adaptive (default), or async,
 * each delegating to a {@link QueryStatusListener} implementation. All write
 * and batch operations throw {@link DataCloudJDBCException} with
 * {@code FEATURE_NOT_SUPPORTED}.
 */
@Slf4j
public class DataCloudStatement implements Statement {
    protected ResultSet resultSet;

    protected static final String NOT_SUPPORTED_IN_DATACLOUD_QUERY = "Write is not supported in Data Cloud query";
    protected static final String BATCH_EXECUTION_IS_NOT_SUPPORTED =
            "Batch execution is not supported in Data Cloud query";
    private static final String QUERY_TIMEOUT = "queryTimeout";
    // Default timeout: 3 hours, expressed in seconds.
    public static final int DEFAULT_QUERY_TIMEOUT = 3 * 60 * 60;

    protected final DataCloudConnection dataCloudConnection;

    private int queryTimeout;

    // Fix: track closed state so isClosed() honors the JDBC contract after close().
    private boolean closed = false;

    public DataCloudStatement(@NonNull DataCloudConnection connection) {
        this.dataCloudConnection = connection;
        this.queryTimeout = getIntegerOrDefault(connection.getProperties(), QUERY_TIMEOUT, DEFAULT_QUERY_TIMEOUT);
    }

    // Set by the execute* methods; null until a query has been submitted.
    protected QueryStatusListener listener;

    protected HyperGrpcClientExecutor getQueryExecutor() {
        return getQueryExecutor(null);
    }

    /**
     * Builds a per-statement executor from the connection's executor, carrying
     * over the connection interceptors and this statement's query timeout.
     *
     * @param additionalQueryParams extra query parameters to merge in, or null
     */
    protected HyperGrpcClientExecutor getQueryExecutor(QueryParam additionalQueryParams) {
        val clientBuilder = dataCloudConnection.getExecutor().toBuilder();

        clientBuilder.interceptors(dataCloudConnection.getInterceptors());

        if (additionalQueryParams != null) {
            clientBuilder.additionalQueryParams(additionalQueryParams);
        }

        return clientBuilder.queryTimeout(getQueryTimeout()).build();
    }

    private void assertQueryReady() throws SQLException {
        if (listener == null) {
            throw new DataCloudJDBCException("a query was not executed before attempting to access results");
        }

        if (!listener.isReady()) {
            throw new DataCloudJDBCException("query results were not ready");
        }
    }

    /**
     * @return true once a query has been submitted and its results are ready.
     *     Fix: previously threw NullPointerException when called before any
     *     execute; now reports false instead.
     */
    public boolean isReady() {
        return listener != null && listener.isReady();
    }

    @Override
    public boolean execute(String sql) throws SQLException {
        log.debug("Entering execute");
        this.resultSet = executeQuery(sql);
        // Always a result set: this driver only supports read queries.
        return true;
    }

    @Override
    public ResultSet executeQuery(String sql) throws SQLException {
        log.debug("Entering executeQuery");

        val useSync = optional(this.dataCloudConnection.getProperties(), Constants.FORCE_SYNC)
                .map(Boolean::parseBoolean)
                .orElse(false);
        resultSet = useSync ? executeSyncQuery(sql) : executeAdaptiveQuery(sql);

        return resultSet;
    }

    public DataCloudResultSet executeSyncQuery(String sql) throws SQLException {
        log.debug("Entering executeSyncQuery");
        val client = getQueryExecutor();
        return executeSyncQuery(sql, client);
    }

    protected DataCloudResultSet executeSyncQuery(String sql, HyperGrpcClientExecutor client) throws SQLException {
        listener = SyncQueryStatusListener.of(sql, client);
        resultSet = listener.generateResultSet();
        log.info("executeSyncQuery completed. queryId={}", listener.getQueryId());
        return (DataCloudResultSet) resultSet;
    }

    public DataCloudResultSet executeAdaptiveQuery(String sql) throws SQLException {
        log.debug("Entering executeAdaptiveQuery");
        val client = getQueryExecutor();
        val timeout = Duration.ofSeconds(getQueryTimeout());
        return executeAdaptiveQuery(sql, client, timeout);
    }

    protected DataCloudResultSet executeAdaptiveQuery(String sql, HyperGrpcClientExecutor client, Duration timeout)
            throws SQLException {
        listener = AdaptiveQueryStatusListener.of(sql, client, timeout);
        resultSet = listener.generateResultSet();
        log.info("executeAdaptiveQuery completed. queryId={}", listener.getQueryId());
        return (DataCloudResultSet) resultSet;
    }

    public DataCloudStatement executeAsyncQuery(String sql) throws SQLException {
        log.debug("Entering executeAsyncQuery");
        val client = getQueryExecutor();
        return executeAsyncQuery(sql, client);
    }

    protected DataCloudStatement executeAsyncQuery(String sql, HyperGrpcClientExecutor client) throws SQLException {
        listener = AsyncQueryStatusListener.of(sql, client);
        // Async mode does not materialize a result set here; callers poll via
        // isReady() and then call getResultSet().
        log.info("executeAsyncQuery completed. queryId={}", listener.getQueryId());
        return this;
    }

    @Override
    public int executeUpdate(String sql) throws SQLException {
        throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED);
    }

    @Override
    public void close() throws SQLException {
        log.debug("Entering close");
        // Fix: record closed state; close() is idempotent per the JDBC contract.
        closed = true;
        if (resultSet != null) {
            try {
                resultSet.close();
            } catch (SQLException e) {
                throw new DataCloudJDBCException(e);
            }
        }
        log.debug("Exiting close");
    }

    @Override
    public int getMaxFieldSize() {
        return 0;
    }

    @Override
    public void setMaxFieldSize(int max) {}

    @Override
    public int getMaxRows() {
        return 0;
    }

    @Override
    public void setMaxRows(int max) {}

    @Override
    public void setEscapeProcessing(boolean enable) {}

    @Override
    public int getQueryTimeout() {
        return queryTimeout;
    }

    @Override
    public void setQueryTimeout(int seconds) {
        // Negative values fall back to the default rather than erroring.
        if (seconds < 0) {
            this.queryTimeout = DEFAULT_QUERY_TIMEOUT;
        } else {
            this.queryTimeout = seconds;
        }
    }

    @Override
    public void cancel() {}

    @Override
    public SQLWarning getWarnings() {
        return null;
    }

    @Override
    public void clearWarnings() {}

    @Override
    public void setCursorName(String name) {}

    @Override
    public ResultSet getResultSet() throws SQLException {
        log.debug("Entering getResultSet");
        assertQueryReady();

        // Async queries have no result set yet; build it lazily on first access.
        if (resultSet == null) {
            resultSet = listener.generateResultSet();
        }
        log.info("getResultSet completed. queryId={}", listener.getQueryId());
        return resultSet;
    }

    @Override
    public int getUpdateCount() {
        return 0;
    }

    @Override
    public boolean getMoreResults() {
        return false;
    }

    @Override
    public void setFetchDirection(int direction) {}

    @Override
    public int getFetchDirection() {
        return ResultSet.FETCH_FORWARD;
    }

    @Override
    public void setFetchSize(int rows) {}

    @Override
    public int getFetchSize() {
        return 0;
    }

    @Override
    public int getResultSetConcurrency() {
        return 0;
    }

    @Override
    public int getResultSetType() {
        return ResultSet.TYPE_FORWARD_ONLY;
    }

    @Override
    public void addBatch(String sql) throws SQLException {
        throw new DataCloudJDBCException(BATCH_EXECUTION_IS_NOT_SUPPORTED, SqlErrorCodes.FEATURE_NOT_SUPPORTED);
    }

    @Override
    public void clearBatch() throws SQLException {
        throw new DataCloudJDBCException(BATCH_EXECUTION_IS_NOT_SUPPORTED, SqlErrorCodes.FEATURE_NOT_SUPPORTED);
    }

    @Override
    public int[] executeBatch() throws SQLException {
        throw new DataCloudJDBCException(BATCH_EXECUTION_IS_NOT_SUPPORTED, SqlErrorCodes.FEATURE_NOT_SUPPORTED);
    }

    @Override
    public Connection getConnection() {
        return dataCloudConnection;
    }

    @Override
    public boolean getMoreResults(int current) {
        return false;
    }

    @Override
    public ResultSet getGeneratedKeys() {
        return null;
    }

    // Fix: these are write operations, not batch operations; report the same
    // "write not supported" message as executeUpdate(String) for consistency.
    @Override
    public int executeUpdate(String sql, int autoGeneratedKeys) throws SQLException {
        throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED);
    }

    @Override
    public int executeUpdate(String sql, int[] columnIndexes) throws SQLException {
        throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED);
    }

    @Override
    public int executeUpdate(String sql, String[] columnNames) throws SQLException {
        throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED);
    }

    @Override
    public boolean execute(String sql, int autoGeneratedKeys) throws SQLException {
        throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED);
    }

    @Override
    public boolean execute(String sql, int[] columnIndexes) throws SQLException {
        throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED);
    }

    @Override
    public boolean execute(String sql, String[] columnNames) throws SQLException {
        throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED);
    }

    @Override
    public int getResultSetHoldability() {
        return 0;
    }

    @Override
    public boolean isClosed() {
        // Fix: previously hard-coded to false, so a closed statement still
        // reported itself as open.
        return closed;
    }

    @Override
    public void setPoolable(boolean poolable) {}

    @Override
    public boolean isPoolable() {
        return false;
    }

    @Override
    public void closeOnCompletion() {}

    @Override
    public boolean isCloseOnCompletion() {
        return false;
    }

    @Override
    public <T> T unwrap(Class<T> iFace) throws SQLException {
        if (iFace.isInstance(this)) {
            return iFace.cast(this);
        }
        throw new DataCloudJDBCException("Cannot unwrap to " + iFace.getName());
    }

    @Override
    public boolean isWrapperFor(Class<?> iFace) {
        return iFace.isInstance(this);
    }
}
package com.salesforce.datacloud.jdbc.core;

import com.google.protobuf.ByteString;
import com.salesforce.datacloud.jdbc.util.ConsumingPeekingIterator;
import com.salesforce.hyperdb.grpc.QueryResult;
import com.salesforce.hyperdb.grpc.QueryResultPartBinary;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.util.Iterator;
import java.util.Optional;
import java.util.stream.Stream;
import lombok.extern.slf4j.Slf4j;
import lombok.val;

/**
 * Adapts a stream of gRPC {@link QueryResult} messages into a
 * {@link ReadableByteChannel} so Arrow readers can consume the binary parts.
 */
@Slf4j
public class ExecuteQueryResponseChannel implements ReadableByteChannel {
    private static final ByteBuffer empty = ByteBuffer.allocateDirect(0);
    private final Iterator<ByteBuffer> iterator;

    // Fix: close() previously did nothing, so isOpen() could still report an
    // open channel after close(). Track closed state explicitly.
    private volatile boolean closed = false;

    public static ExecuteQueryResponseChannel of(Stream<QueryResult> stream) {
        return new ExecuteQueryResponseChannel(stream.map(ExecuteQueryResponseChannel::fromQueryResult));
    }

    private ExecuteQueryResponseChannel(Stream<ByteBuffer> stream) {
        // Empty buffers are filtered out so read() always sees data-bearing chunks.
        this.iterator = ConsumingPeekingIterator.of(stream, ExecuteQueryResponseChannel::isNotEmpty);
    }

    /**
     * Wraps the binary payload of a query result in a ByteBuffer; a message
     * with no binary part (or a null message) maps to an empty buffer.
     */
    static ByteBuffer fromQueryResult(QueryResult queryResult) {
        return Optional.ofNullable(queryResult)
                .map(QueryResult::getBinaryPart)
                .map(QueryResultPartBinary::getData)
                .map(ByteString::toByteArray)
                .map(ByteBuffer::wrap)
                .orElse(empty);
    }

    /**
     * Copies as many bytes as fit from the current chunk into destination.
     *
     * @return number of bytes transferred, or -1 at end of stream or after close
     */
    @Override
    public int read(ByteBuffer destination) {
        if (closed || !this.iterator.hasNext()) {
            return -1;
        }
        return transferToDestination(iterator.next(), destination);
    }

    @Override
    public boolean isOpen() {
        // Open until explicitly closed or the underlying stream is exhausted.
        return !closed && iterator.hasNext();
    }

    @Override
    public void close() throws IOException {
        closed = true;
    }

    /**
     * Moves up to min(remaining) bytes from source into destination and
     * advances source's position accordingly.
     */
    static int transferToDestination(ByteBuffer source, ByteBuffer destination) {
        if (source == null) {
            return 0;
        }

        val transfer = Math.min(destination.remaining(), source.remaining());
        if (transfer > 0) {
            // Bulk copy from the backing array (buffers here are heap-backed
            // via ByteBuffer.wrap); honors arrayOffset for sliced buffers.
            destination.put(source.array(), source.arrayOffset() + source.position(), transfer);
            source.position(source.position() + transfer);
        }
        return transfer;
    }

    static boolean isNotEmpty(ByteBuffer buffer) {
        return buffer.hasRemaining();
    }
}
+ */ +package com.salesforce.datacloud.jdbc.core; + +import java.util.Collections; +import java.util.Map; +import java.util.Properties; +import java.util.stream.Collectors; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.val; + +@AllArgsConstructor(access = AccessLevel.PRIVATE) +public class HyperConnectionSettings { + private static final String HYPER_SETTING = "serverSetting."; + private final Map settings; + + public static HyperConnectionSettings of(Properties properties) { + val result = properties.entrySet().stream() + .filter(e -> e.getKey().toString().startsWith(HYPER_SETTING)) + .collect( + Collectors.toMap(e -> e.getKey().toString().substring(HYPER_SETTING.length()), e -> e.getValue() + .toString())); + return new HyperConnectionSettings(result); + } + + public Map getSettings() { + return Collections.unmodifiableMap(settings); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/HyperGrpcClientExecutor.java b/src/main/java/com/salesforce/datacloud/jdbc/core/HyperGrpcClientExecutor.java new file mode 100644 index 0000000..1b9c97c --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/HyperGrpcClientExecutor.java @@ -0,0 +1,218 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package com.salesforce.datacloud.jdbc.core;

import com.salesforce.datacloud.jdbc.interceptor.QueryIdHeaderInterceptor;
import com.salesforce.datacloud.jdbc.util.PropertiesExtensions;
import com.salesforce.hyperdb.grpc.ExecuteQueryResponse;
import com.salesforce.hyperdb.grpc.HyperServiceGrpc;
import com.salesforce.hyperdb.grpc.OutputFormat;
import com.salesforce.hyperdb.grpc.QueryInfo;
import com.salesforce.hyperdb.grpc.QueryInfoParam;
import com.salesforce.hyperdb.grpc.QueryParam;
import com.salesforce.hyperdb.grpc.QueryResult;
import com.salesforce.hyperdb.grpc.QueryResultParam;
import io.grpc.ClientInterceptor;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import java.sql.SQLException;
import java.time.Duration;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;

/**
 * Thin wrapper around the Hyper gRPC blocking stub. Owns the managed channel,
 * merges per-connection settings and per-statement query params into requests,
 * and applies the statement's query timeout as a gRPC deadline.
 *
 * <p>Built via Lombok's {@code @Builder(toBuilder = true)}; statements clone
 * the connection-level instance and customize it per query.
 */
@Slf4j
@Builder(toBuilder = true)
public class HyperGrpcClientExecutor implements AutoCloseable {
    // Arrow result chunks can be large; raise the inbound cap to 128 MiB.
    private static final int GRPC_INBOUND_MESSAGE_MAX_SIZE = 128 * 1024 * 1024;

    @NonNull private final ManagedChannel channel;

    // Per-statement params merged into every executeQuery request.
    @Getter
    private final QueryParam additionalQueryParams;

    // Connection-level "serverSetting.*" properties, pre-packaged as QueryParam.
    private final QueryParam settingsQueryParams;

    // Seconds; values <= 0 mean "no deadline".
    @Builder.Default
    private int queryTimeout = -1;

    private final List<ClientInterceptor> interceptors;

    /**
     * Builds an executor from connection properties: packages server settings,
     * optionally enables gRPC retries (on by default), and opens the channel.
     *
     * @throws SQLException declared for parity with callers' error handling
     */
    public static HyperGrpcClientExecutor of(@NonNull ManagedChannelBuilder<?> builder, @NonNull Properties properties)
            throws SQLException {
        val client = HyperGrpcClientExecutor.builder();

        val settings = HyperConnectionSettings.of(properties).getSettings();
        if (!settings.isEmpty()) {
            client.settingsQueryParams(
                    QueryParam.newBuilder().putAllSettings(settings).build());
        }

        // Retries default to enabled; both the channel flag and the service
        // config must agree on the max attempt count.
        if (PropertiesExtensions.getBooleanOrDefault(properties, "grpc.enableRetries", Boolean.TRUE)) {
            int maxRetryAttempts =
                    PropertiesExtensions.getIntegerOrDefault(properties, "grpc.retryPolicy.maxAttempts", 5);
            builder.enableRetry()
                    .maxRetryAttempts(maxRetryAttempts)
                    .defaultServiceConfig(retryPolicy(maxRetryAttempts));
        }

        val channel =
                builder.maxInboundMessageSize(GRPC_INBOUND_MESSAGE_MAX_SIZE).build();
        return client.channel(channel).build();
    }

    /**
     * gRPC service-config map enabling exponential-backoff retries on
     * UNAVAILABLE for all methods (empty name selector matches everything).
     */
    private static Map<String, Object> retryPolicy(int maxRetryAttempts) {
        return Map.of(
                "methodConfig",
                List.of(Map.of(
                        "name", List.of(Collections.EMPTY_MAP),
                        "retryPolicy",
                        Map.of(
                                "maxAttempts",
                                String.valueOf(maxRetryAttempts),
                                "initialBackoff",
                                "0.5s",
                                "maxBackoff",
                                "30s",
                                "backoffMultiplier",
                                2.0,
                                "retryableStatusCodes",
                                List.of("UNAVAILABLE")))));
    }

    /** Clones a builder with per-query params and timeout applied. */
    public static HyperGrpcClientExecutor of(
            HyperGrpcClientExecutorBuilder builder, QueryParam additionalQueryParams, int queryTimeout) {
        return builder.additionalQueryParams(additionalQueryParams)
                .queryTimeout(queryTimeout)
                .build();
    }

    public Iterator<ExecuteQueryResponse> executeAdaptiveQuery(String sql) throws SQLException {
        return execute(sql, QueryParam.TransferMode.ADAPTIVE);
    }

    public Iterator<ExecuteQueryResponse> executeAsyncQuery(String sql) throws SQLException {
        return execute(sql, QueryParam.TransferMode.ASYNC);
    }

    public Iterator<ExecuteQueryResponse> executeQuery(String sql) throws SQLException {
        return execute(sql, QueryParam.TransferMode.SYNC);
    }

    /** One-shot status poll for a previously submitted query. */
    public Iterator<QueryInfo> getQueryInfo(String queryId) {
        val param = getQueryInfoParam(queryId);
        return getStub(queryId).getQueryInfo(param);
    }

    /** Streaming status updates for a previously submitted query. */
    public Iterator<QueryInfo> getQueryInfoStreaming(String queryId) {
        val param = getQueryInfoParamStreaming(queryId);
        return getStub(queryId).getQueryInfo(param);
    }

    /**
     * Fetches one result chunk.
     *
     * @param omitSchema skip re-sending the Arrow schema for chunks after the first
     */
    public Iterator<QueryResult> getQueryResult(String queryId, long chunkId, boolean omitSchema) {
        val param = getQueryResultParam(queryId, chunkId, omitSchema);
        return getStub(queryId).getQueryResult(param);
    }

    /**
     * Assembles the request: base query + transfer mode + Arrow output, then
     * merges statement-level and connection-level params on top.
     */
    private QueryParam getQueryParams(String sql, QueryParam.TransferMode transferMode) {
        val builder = QueryParam.newBuilder()
                .setQuery(sql)
                .setTransferMode(transferMode)
                .setOutputFormat(OutputFormat.ARROW_V3);

        if (additionalQueryParams != null) {
            builder.mergeFrom(additionalQueryParams);
        }

        if (settingsQueryParams != null) {
            builder.mergeFrom(settingsQueryParams);
        }

        return builder.build();
    }

    private QueryResultParam getQueryResultParam(String queryId, long chunkId, boolean omitSchema) {
        val builder = QueryResultParam.newBuilder()
                .setQueryId(queryId)
                .setChunkId(chunkId)
                .setOutputFormat(OutputFormat.ARROW_V3);

        if (omitSchema) {
            builder.setOmitSchema(true);
        }

        return builder.build();
    }

    private QueryInfoParam getQueryInfoParam(String queryId) {
        return QueryInfoParam.newBuilder().setQueryId(queryId).build();
    }

    private QueryInfoParam getQueryInfoParamStreaming(String queryId) {
        return QueryInfoParam.newBuilder()
                .setQueryId(queryId)
                .setStreaming(true)
                .build();
    }

    private Iterator<ExecuteQueryResponse> execute(String sql, QueryParam.TransferMode mode) throws SQLException {
        val request = getQueryParams(sql, mode);
        return getStub().executeQuery(request);
    }

    // Lombok lazy getter: stub is created once, on first use.
    @Getter(lazy = true, value = AccessLevel.PRIVATE)
    private final HyperServiceGrpc.HyperServiceBlockingStub stub = lazyStub();

    /**
     * Creates the blocking stub, registering any configured interceptors and,
     * when queryTimeout > 0, a deadline of that many seconds.
     */
    private HyperServiceGrpc.HyperServiceBlockingStub lazyStub() {
        var result = HyperServiceGrpc.newBlockingStub(channel);

        log.info("Stub will execute query. deadline={}", queryTimeout > 0 ? Duration.ofSeconds(queryTimeout) : "none");

        if (interceptors != null && !interceptors.isEmpty()) {
            log.info("Registering additional interceptors. count={}", interceptors.size());
            result = result.withInterceptors(interceptors.toArray(ClientInterceptor[]::new));
        }

        if (queryTimeout > 0) {
            return result.withDeadlineAfter(queryTimeout, TimeUnit.SECONDS);
        } else {
            return result;
        }
    }

    /** Stub variant that attaches the query id as a request header. */
    private HyperServiceGrpc.HyperServiceBlockingStub getStub(String queryId) {
        val queryIdHeaderInterceptor = new QueryIdHeaderInterceptor(queryId);
        return getStub().withInterceptors(queryIdHeaderInterceptor);
    }

    @Override
    public void close() throws Exception {
        // Idempotent: skip if the channel is already shut down or terminated.
        if (channel.isShutdown() || channel.isTerminated()) {
            return;
        }

        channel.shutdown();
    }
}
+ */ +package com.salesforce.datacloud.jdbc.core; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.sql.SQLException; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import lombok.NonNull; +import org.apache.calcite.avatica.util.AbstractCursor; + +public class MetadataCursor extends AbstractCursor { + private final int rowCount; + private int currentRow = -1; + private List data; + private final AtomicBoolean closed = new AtomicBoolean(); + + public MetadataCursor(@NonNull List data) { + this.data = data; + this.rowCount = data.size(); + } + + protected class ListGetter extends AbstractGetter { + protected final int index; + + public ListGetter(int index) { + this.index = index; + } + + @Override + public Object getObject() throws SQLException { + Object o; + try { + o = ((List) data.get(currentRow)).get(index); + } catch (RuntimeException e) { + throw new DataCloudJDBCException(e); + } + wasNull[0] = o == null; + return o; + } + } + + @Override + protected Getter createGetter(int i) { + return new ListGetter(i); + } + + @Override + public boolean next() { + currentRow++; + return currentRow < rowCount; + } + + @Override + public void close() { + try { + closed.set(true); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/MetadataResultSet.java b/src/main/java/com/salesforce/datacloud/jdbc/core/MetadataResultSet.java new file mode 100644 index 0000000..ed5b966 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/MetadataResultSet.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.util.List; +import java.util.Map; +import java.util.TimeZone; +import lombok.experimental.UtilityClass; +import lombok.val; +import org.apache.calcite.avatica.AvaticaResultSet; +import org.apache.calcite.avatica.AvaticaResultSetMetaData; +import org.apache.calcite.avatica.AvaticaStatement; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.QueryState; + +@UtilityClass +public class MetadataResultSet { + public static AvaticaResultSet of() throws SQLException { + val signature = new Meta.Signature(List.of(), null, List.of(), Map.of(), null, Meta.StatementType.SELECT); + return of( + null, + new QueryState(), + signature, + new AvaticaResultSetMetaData(null, null, signature), + TimeZone.getDefault(), + null, + List.of()); + } + + public static AvaticaResultSet of( + AvaticaStatement statement, + QueryState state, + Meta.Signature signature, + ResultSetMetaData resultSetMetaData, + TimeZone timeZone, + Meta.Frame firstFrame, + List data) + throws SQLException { + AvaticaResultSet result; + try { + result = new AvaticaResultSet(statement, state, signature, resultSetMetaData, timeZone, firstFrame); + } catch (SQLException e) { + throw new DataCloudJDBCException(e); + } + result.execute2(new MetadataCursor(data), signature.columns); + return result; + } +} diff --git 
a/src/main/java/com/salesforce/datacloud/jdbc/core/ParameterManager.java b/src/main/java/com/salesforce/datacloud/jdbc/core/ParameterManager.java new file mode 100644 index 0000000..3f2dc83 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/ParameterManager.java @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core; + +import com.salesforce.datacloud.jdbc.core.model.ParameterBinding; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import lombok.Getter; + +interface ParameterManager { + void setParameter(int index, int sqlType, Object value) throws SQLException; + + void clearParameters(); + + List getParameters(); +} + +@Getter +class DefaultParameterManager implements ParameterManager { + private final List parameters = new ArrayList<>(); + protected final String PARAMETER_INDEX_ERROR = "Parameter index must be greater than 0"; + + @Override + public void setParameter(int index, int sqlType, Object value) throws SQLException { + if (index <= 0) { + throw new DataCloudJDBCException(PARAMETER_INDEX_ERROR); + } + + while (parameters.size() < index) { + parameters.add(null); + } + parameters.set(index - 1, new ParameterBinding(sqlType, value)); + } + + @Override + public void clearParameters() { + parameters.clear(); + } +} diff 
package com.salesforce.datacloud.jdbc.core;

import com.salesforce.datacloud.jdbc.util.Constants;
import java.sql.Types;
import java.util.List;

/**
 * Column layouts for DatabaseMetaData result sets: for each metadata query,
 * the JDBC-mandated column names with their database type names and
 * {@link java.sql.Types} ids, kept positionally aligned across the three lists.
 */
public enum QueryDBMetadata {
    GET_TABLE_TYPES(List.of("TABLE_TYPE"), List.of(Constants.TEXT), List.of(Types.VARCHAR)),
    GET_CATALOGS(List.of("TABLE_CAT"), List.of(Constants.TEXT), List.of(Types.VARCHAR)),
    GET_SCHEMAS(
            List.of("TABLE_SCHEM", "TABLE_CATALOG"),
            List.of(Constants.TEXT, Constants.TEXT),
            List.of(Types.VARCHAR, Types.VARCHAR)),
    // 10 columns required by DatabaseMetaData.getTables, all VARCHAR.
    GET_TABLES(
            List.of(
                    "TABLE_CAT",
                    "TABLE_SCHEM",
                    "TABLE_NAME",
                    "TABLE_TYPE",
                    "REMARKS",
                    "TYPE_CAT",
                    "TYPE_SCHEM",
                    "TYPE_NAME",
                    "SELF_REFERENCING_COL_NAME",
                    "REF_GENERATION"),
            List.of(
                    Constants.TEXT,
                    Constants.TEXT,
                    Constants.TEXT,
                    Constants.TEXT,
                    Constants.TEXT,
                    Constants.TEXT,
                    Constants.TEXT,
                    Constants.TEXT,
                    Constants.TEXT,
                    Constants.TEXT),
            List.of(
                    Types.VARCHAR,
                    Types.VARCHAR,
                    Types.VARCHAR,
                    Types.VARCHAR,
                    Types.VARCHAR,
                    Types.VARCHAR,
                    Types.VARCHAR,
                    Types.VARCHAR,
                    Types.VARCHAR,
                    Types.VARCHAR)),
    // 24 columns required by DatabaseMetaData.getColumns; mixed text/integer,
    // with SOURCE_DATA_TYPE as SMALLINT per the JDBC spec.
    GET_COLUMNS(
            List.of(
                    "TABLE_CAT",
                    "TABLE_SCHEM",
                    "TABLE_NAME",
                    "COLUMN_NAME",
                    "DATA_TYPE",
                    "TYPE_NAME",
                    "COLUMN_SIZE",
                    "BUFFER_LENGTH",
                    "DECIMAL_DIGITS",
                    "NUM_PREC_RADIX",
                    "NULLABLE",
                    "REMARKS",
                    "COLUMN_DEF",
                    "SQL_DATA_TYPE",
                    "SQL_DATETIME_SUB",
                    "CHAR_OCTET_LENGTH",
                    "ORDINAL_POSITION",
                    "IS_NULLABLE",
                    "SCOPE_CATALOG",
                    "SCOPE_SCHEMA",
                    "SCOPE_TABLE",
                    "SOURCE_DATA_TYPE",
                    "IS_AUTOINCREMENT",
                    "IS_GENERATEDCOLUMN"),
            List.of(
                    Constants.TEXT,
                    Constants.TEXT,
                    Constants.TEXT,
                    Constants.TEXT,
                    Constants.INTEGER,
                    Constants.TEXT,
                    Constants.INTEGER,
                    Constants.INTEGER,
                    Constants.INTEGER,
                    Constants.INTEGER,
                    Constants.INTEGER,
                    Constants.TEXT,
                    Constants.TEXT,
                    Constants.INTEGER,
                    Constants.INTEGER,
                    Constants.INTEGER,
                    Constants.INTEGER,
                    Constants.TEXT,
                    Constants.TEXT,
                    Constants.TEXT,
                    Constants.TEXT,
                    Constants.SHORT,
                    Constants.TEXT,
                    Constants.TEXT),
            List.of(
                    Types.VARCHAR,
                    Types.VARCHAR,
                    Types.VARCHAR,
                    Types.VARCHAR,
                    Types.INTEGER,
                    Types.VARCHAR,
                    Types.INTEGER,
                    Types.INTEGER,
                    Types.INTEGER,
                    Types.INTEGER,
                    Types.INTEGER,
                    Types.VARCHAR,
                    Types.VARCHAR,
                    Types.INTEGER,
                    Types.INTEGER,
                    Types.INTEGER,
                    Types.INTEGER,
                    Types.VARCHAR,
                    Types.VARCHAR,
                    Types.VARCHAR,
                    Types.VARCHAR,
                    Types.SMALLINT,
                    Types.VARCHAR,
                    Types.VARCHAR));

    // JDBC column names, in result-set order.
    private final List<String> columnNames;
    // Database type names (see Constants), aligned with columnNames.
    private final List<String> columnTypes;
    // java.sql.Types ids, aligned with columnNames.
    private final List<Integer> columnTypeIds;

    QueryDBMetadata(List<String> columnNames, List<String> columnTypes, List<Integer> columnTypeIds) {
        this.columnNames = columnNames;
        this.columnTypes = columnTypes;
        this.columnTypeIds = columnTypeIds;
    }

    public List<String> getColumnNames() {
        return columnNames;
    }

    public List<String> getColumnTypes() {
        return columnTypes;
    }

    public List<Integer> getColumnTypeIds() {
        return columnTypeIds;
    }
}
package com.salesforce.datacloud.jdbc.core;

import static com.salesforce.datacloud.jdbc.util.ThrowingFunction.rethrowFunction;

import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorFactory;
import java.sql.SQLException;
import java.util.Calendar;
import java.util.List;
import java.util.stream.Collectors;
import lombok.SneakyThrows;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.calcite.avatica.ColumnMetaData;
import org.apache.calcite.avatica.util.AbstractCursor;
import org.apache.calcite.avatica.util.ArrayImpl;

/**
 * Avatica cursor over an Arrow {@link VectorSchemaRoot}: one accessor per
 * field vector, advancing a shared row index via {@link #next()}.
 */
public class QueryJDBCCursor extends AbstractCursor {
    private VectorSchemaRoot root;
    private final int rowCount;
    // Shared row pointer read by every accessor; -1 means "before first row".
    private int currentRow = -1;

    public QueryJDBCCursor(VectorSchemaRoot schemaRoot) {
        this.root = schemaRoot;
        this.rowCount = schemaRoot.getRowCount();
    }

    /** Builds one accessor per Arrow field vector; metadata args are unused. */
    @SneakyThrows
    @Override
    public List<Accessor> createAccessors(
            List<ColumnMetaData> types, Calendar localCalendar, ArrayImpl.Factory factory) {
        val vectors = root.getFieldVectors();
        return vectors.stream()
                .map(rethrowFunction(this::createAccessor))
                .collect(Collectors.toList());
    }

    private Accessor createAccessor(FieldVector vector) throws SQLException {
        // Accessors pull the current row lazily and report nullness through
        // AbstractCursor's wasNull slot.
        return QueryJDBCAccessorFactory.createAccessor(
                vector, this::getCurrentRow, (boolean wasNull) -> this.wasNull[0] = wasNull);
    }

    @Override
    protected Getter createGetter(int i) {
        // Column access goes through the Arrow accessors, never plain getters.
        throw new UnsupportedOperationException("Not allowed.");
    }

    @Override
    public boolean next() {
        currentRow++;
        return currentRow < rowCount;
    }

    @Override
    public void close() {
        // Release the Arrow buffers; surface any failure as unchecked.
        try {
            AutoCloseables.close(root);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private int getCurrentRow() {
        return currentRow;
    }
}
+ */ +package com.salesforce.datacloud.jdbc.core; + +import static java.util.Map.entry; + +import com.salesforce.datacloud.jdbc.auth.OAuthToken; +import com.salesforce.datacloud.jdbc.auth.TokenProcessor; +import com.salesforce.datacloud.jdbc.config.QueryResources; +import com.salesforce.datacloud.jdbc.core.model.DataspaceResponse; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.http.FormCommand; +import com.salesforce.datacloud.jdbc.util.ArrowUtils; +import com.salesforce.datacloud.jdbc.util.Constants; +import java.net.URI; +import java.net.URISyntaxException; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.TimeZone; +import java.util.stream.Collectors; +import lombok.experimental.UtilityClass; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import okhttp3.OkHttpClient; +import org.apache.calcite.avatica.AvaticaResultSet; +import org.apache.calcite.avatica.ColumnMetaData; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.QueryState; +import org.apache.calcite.avatica.SqlType; +import org.apache.commons.lang3.StringUtils; + +@Slf4j +@UtilityClass +class QueryMetadataUtil { + static final int NUM_TABLE_METADATA_COLUMNS = 10; + static final int NUM_COLUMN_METADATA_COLUMNS = 24; + static final int NUM_SCHEMA_METADATA_COLUMNS = 2; + static final int NUM_TABLE_TYPES_METADATA_COLUMNS = 1; + static final int NUM_CATALOG_METADATA_COLUMNS = 1; + private static final String SOQL_ENDPOINT_SUFFIX = "services/data/v61.0/query/"; + private static final String SOQL_QUERY_PARAM_KEY = "q"; + + private static final int TABLE_CATALOG_INDEX = 0; + private static final int TABLE_SCHEMA_INDEX = 1; + private static final int TABLE_NAME_INDEX = 2; + private static final int 
COLUMN_NAME_INDEX = 3; + private static final int DATA_TYPE_INDEX = 4; + private static final int TYPE_NAME_INDEX = 5; + private static final int COLUMN_SIZE_INDEX = 6; + private static final int BUFFER_LENGTH_INDEX = 7; + private static final int DECIMAL_DIGITS_INDEX = 8; + private static final int NUM_PREC_RADIX_INDEX = 9; + private static final int NULLABLE_INDEX = 10; + private static final int DESCRIPTION_INDEX = 11; + private static final int COLUMN_DEFAULT_INDEX = 12; + private static final int SQL_DATA_TYPE_INDEX = 13; + private static final int SQL_DATE_TIME_SUB_INDEX = 14; + private static final int CHAR_OCTET_LENGTH_INDEX = 15; + private static final int ORDINAL_POSITION_INDEX = 16; + private static final int IS_NULLABLE_INDEX = 17; + private static final int SCOPE_CATALOG_INDEX = 18; + private static final int SCOPE_SCHEMA_INDEX = 19; + private static final int SCOPE_TABLE_INDEX = 20; + private static final int SOURCE_DATA_TYPE_INDEX = 21; + private static final int AUTO_INCREMENT_INDEX = 22; + private static final int GENERATED_COLUMN_INDEX = 23; + + private static final Map dbTypeToSql = Map.ofEntries( + Map.entry("int2", SqlType.SMALLINT.toString()), + Map.entry("int4", SqlType.INTEGER.toString()), + Map.entry("oid", SqlType.BIGINT.toString()), + Map.entry("int8", SqlType.BIGINT.toString()), + Map.entry("float", SqlType.DOUBLE.toString()), + Map.entry("float4", SqlType.REAL.toString()), + Map.entry("float8", SqlType.DOUBLE.toString()), + Map.entry("bool", SqlType.BOOLEAN.toString()), + Map.entry("char", SqlType.CHAR.toString()), + Map.entry("text", SqlType.VARCHAR.toString()), + Map.entry("date", SqlType.DATE.toString()), + Map.entry("time", SqlType.TIME.toString()), + Map.entry("timetz", SqlType.TIME.toString()), + Map.entry("timestamp", SqlType.TIMESTAMP.toString()), + Map.entry("timestamptz", SqlType.TIMESTAMP.toString()), + Map.entry("array", SqlType.ARRAY.toString())); + + public static ResultSet createTableResultSet( + String schemaPattern, 
String tableNamePattern, String[] types, DataCloudStatement dataCloudStatement) + throws SQLException { + + String tablesQuery = getTablesQuery(schemaPattern, tableNamePattern, types); + ResultSet resultSet = dataCloudStatement.executeQuery(tablesQuery); + List data = constructTableData(resultSet); + QueryDBMetadata queryDbMetadata = QueryDBMetadata.GET_TABLES; + + return getMetadataResultSet(queryDbMetadata, NUM_TABLE_METADATA_COLUMNS, data); + } + + static AvaticaResultSet getMetadataResultSet(QueryDBMetadata queryDbMetadata, int columnsCount, List data) + throws SQLException { + QueryResultSetMetadata queryResultSetMetadata = new QueryResultSetMetadata(queryDbMetadata); + List columnMetaData = + ArrowUtils.convertJDBCMetadataToAvaticaColumns(queryResultSetMetadata, columnsCount); + Meta.Signature signature = new Meta.Signature( + columnMetaData, null, Collections.emptyList(), Collections.emptyMap(), null, Meta.StatementType.SELECT); + return MetadataResultSet.of( + null, new QueryState(), signature, queryResultSetMetadata, TimeZone.getDefault(), null, data); + } + + private static List constructTableData(ResultSet resultSet) throws SQLException { + List data = new ArrayList<>(); + try { + while (resultSet.next()) { + List rowData = Arrays.asList( + resultSet.getString("TABLE_CAT"), + resultSet.getString("TABLE_SCHEM"), + resultSet.getString("TABLE_NAME"), + "TABLE", + resultSet.getString("REMARKS"), + resultSet.getString("TYPE_CAT"), + resultSet.getString("TYPE_SCHEM"), + resultSet.getString("TYPE_NAME"), + resultSet.getString("SELF_REFERENCING_COL_NAME"), + resultSet.getString("REF_GENERATION")); + data.add(rowData); + } + } catch (SQLException e) { + throw new DataCloudJDBCException(e); + } + return data; + } + + private static String getTablesQuery(String schemaPattern, String tableNamePattern, String[] types) { + String tablesQuery = QueryResources.getTablesQuery(); + + if (schemaPattern != null && !schemaPattern.isEmpty()) { + tablesQuery += " AND n.nspname 
LIKE " + quoteStringLiteral(schemaPattern); + } + + if (tableNamePattern != null && !tableNamePattern.isEmpty()) { + tablesQuery += " AND c.relname LIKE " + quoteStringLiteral(tableNamePattern); + } + if (types != null && types.length > 0) { + tablesQuery += " AND (false "; + StringBuilder orclause = new StringBuilder(); + for (String type : types) { + Map clauses = tableTypeClauses.get(type); + if (clauses != null) { + String clause = clauses.get("SCHEMAS"); + orclause.append(" OR ( ").append(clause).append(" ) "); + } + } + tablesQuery += orclause.toString() + ") "; + } + + tablesQuery += " ORDER BY TABLE_TYPE,TABLE_SCHEM,TABLE_NAME "; + + return tablesQuery; + } + + public static ResultSet createColumnResultSet( + String schemaPattern, + String tableNamePattern, + String columnNamePattern, + DataCloudStatement dataCloudStatement) + throws SQLException { + + String getColumnsQuery = getColumnsQuery(schemaPattern, tableNamePattern, columnNamePattern); + ResultSet resultSet = dataCloudStatement.executeQuery(getColumnsQuery); + List data = constructColumnData(resultSet); + QueryDBMetadata queryDbMetadata = QueryDBMetadata.GET_COLUMNS; + + return getMetadataResultSet(queryDbMetadata, NUM_COLUMN_METADATA_COLUMNS, data); + } + + private static String getColumnsQuery(String schemaPattern, String tableNamePattern, String columnNamePattern) { + String getColumnsQuery = QueryResources.getColumnsQuery(); + + if (schemaPattern != null && !schemaPattern.isEmpty()) { + getColumnsQuery += " AND n.nspname LIKE " + quoteStringLiteral(schemaPattern); + } + if (tableNamePattern != null && !tableNamePattern.isEmpty()) { + getColumnsQuery += " AND c.relname LIKE " + quoteStringLiteral(tableNamePattern); + } + if (columnNamePattern != null && !columnNamePattern.isEmpty()) { + getColumnsQuery += " AND attname LIKE " + quoteStringLiteral(columnNamePattern); + } + getColumnsQuery += " ORDER BY nspname, c.relname, attnum "; + + return getColumnsQuery; + } + + private static List 
constructColumnData(ResultSet resultSet) throws SQLException { + List data = new ArrayList<>(); + try { + while (resultSet.next()) { + Object[] rowData = new Object[24]; + + String tableCatalog = null; + rowData[TABLE_CATALOG_INDEX] = tableCatalog; + + String tableSchema = resultSet.getString("nspname"); + rowData[TABLE_SCHEMA_INDEX] = tableSchema; + + String tableName = resultSet.getString("relname"); + rowData[TABLE_NAME_INDEX] = tableName; + + String columnName = resultSet.getString("attname"); + rowData[COLUMN_NAME_INDEX] = columnName; + + int dataType = (int) resultSet.getLong("atttypid"); + rowData[DATA_TYPE_INDEX] = dataType; + + String typeName = resultSet.getString("datatype"); + typeName = typeName == null ? StringUtils.EMPTY : typeName; + if (typeName.toLowerCase().contains("numeric")) { + rowData[TYPE_NAME_INDEX] = SqlType.NUMERIC.toString(); + rowData[DATA_TYPE_INDEX] = SqlType.valueOf(SqlType.NUMERIC.toString()).id; + } else { + rowData[TYPE_NAME_INDEX] = dbTypeToSql.getOrDefault(typeName.toLowerCase(), typeName); + dataType = dbTypeToSql.containsKey(typeName.toLowerCase()) + ? SqlType.valueOf(dbTypeToSql.get(typeName.toLowerCase())).id + : (int) resultSet.getLong("atttypid"); + rowData[DATA_TYPE_INDEX] = dataType; + } + int columnSize = 255; + rowData[COLUMN_SIZE_INDEX] = columnSize; + + int decimalDigits = 2; + rowData[DECIMAL_DIGITS_INDEX] = decimalDigits; + + int numPrecRadix = 10; + rowData[NUM_PREC_RADIX_INDEX] = numPrecRadix; + + int nullable = resultSet.getBoolean("attnotnull") + ? 
DatabaseMetaData.columnNoNulls + : DatabaseMetaData.columnNullable; + rowData[NULLABLE_INDEX] = nullable; + + String description = resultSet.getString("description"); + rowData[DESCRIPTION_INDEX] = description; + + String columnDefault = resultSet.getString("adsrc"); + rowData[COLUMN_DEFAULT_INDEX] = columnDefault; + + rowData[SQL_DATA_TYPE_INDEX] = null; + + String sqlDateTimeSub = StringUtils.EMPTY; + rowData[SQL_DATE_TIME_SUB_INDEX] = sqlDateTimeSub; + + int charOctetLength = 2; + rowData[CHAR_OCTET_LENGTH_INDEX] = charOctetLength; + + int ordinalPosition = resultSet.getInt("attnum"); + rowData[ORDINAL_POSITION_INDEX] = ordinalPosition; + + String isNullable = resultSet.getBoolean("attnotnull") ? "NO" : "YES"; + rowData[IS_NULLABLE_INDEX] = isNullable; + + rowData[SCOPE_CATALOG_INDEX] = null; + rowData[SCOPE_SCHEMA_INDEX] = null; + rowData[SCOPE_TABLE_INDEX] = null; + rowData[SOURCE_DATA_TYPE_INDEX] = null; + + String identity = resultSet.getString("attidentity"); + String defval = resultSet.getString("adsrc"); + String autoIncrement = "NO"; + if ((defval != null && defval.contains("nextval(")) || identity != null) { + autoIncrement = "YES"; + } + rowData[AUTO_INCREMENT_INDEX] = autoIncrement; + + String generated = resultSet.getString("attgenerated"); + String generatedColumn = generated != null ? 
"YES" : "NO"; + rowData[GENERATED_COLUMN_INDEX] = generatedColumn; + + data.add(Arrays.asList(rowData)); + } + } catch (SQLException e) { + throw new DataCloudJDBCException(e); + } + return data; + } + + public static ResultSet createSchemaResultSet(String schemaPattern, DataCloudStatement dataCloudStatement) + throws SQLException { + + String schemasQuery = getSchemasQuery(schemaPattern); + ResultSet resultSet = dataCloudStatement.executeQuery(schemasQuery); + List data = constructSchemaData(resultSet); + QueryDBMetadata queryDbMetadata = QueryDBMetadata.GET_SCHEMAS; + + return getMetadataResultSet(queryDbMetadata, NUM_SCHEMA_METADATA_COLUMNS, data); + } + + private static String getSchemasQuery(String schemaPattern) { + String schemasQuery = QueryResources.getSchemasQuery(); + if (StringUtils.isNotEmpty(schemaPattern)) { + schemasQuery += " AND nspname LIKE " + quoteStringLiteral(schemaPattern); + } + return schemasQuery; + } + + private static List constructSchemaData(ResultSet resultSet) throws SQLException { + List data = new ArrayList<>(); + + try { + while (resultSet.next()) { + List rowData = + Arrays.asList(resultSet.getString("TABLE_SCHEM"), resultSet.getString("TABLE_CATALOG")); + data.add(rowData); + } + } catch (SQLException e) { + throw new DataCloudJDBCException(e); + } + return data; + } + + public static ResultSet createTableTypesResultSet() throws SQLException { + + List data = constructTableTypesData(); + QueryDBMetadata queryDbMetadata = QueryDBMetadata.GET_TABLE_TYPES; + + return getMetadataResultSet(queryDbMetadata, NUM_TABLE_TYPES_METADATA_COLUMNS, data); + } + + private static List constructTableTypesData() { + List data = new ArrayList<>(); + + for (Map.Entry> entry : tableTypeClauses.entrySet()) { + List rowData = Arrays.asList(entry.getValue()); + data.add(rowData); + } + return data; + } + + public static ResultSet createCatalogsResultSet(Optional tokenProcessor, OkHttpClient client) + throws SQLException { + val tenantId = 
tokenProcessor.get().getDataCloudToken().getTenantId(); + val dataspaceName = tokenProcessor.get().getSettings().getDataspace(); + List data = List.of(List.of("lakehouse:" + tenantId + ";" + dataspaceName)); + + QueryDBMetadata queryDbMetadata = QueryDBMetadata.GET_CATALOGS; + + return getMetadataResultSet(queryDbMetadata, NUM_CATALOG_METADATA_COLUMNS, data); + } + + public static List createDataspacesResponse(Optional tokenProcessor, OkHttpClient client) + throws SQLException { + + try { + + val dataspaceResponse = getDataSpaceResponse(tokenProcessor, client, false); + return dataspaceResponse.getRecords().stream() + .map(DataspaceResponse.DataSpaceAttributes::getName) + .collect(Collectors.toList()); + } catch (Exception e) { + throw new DataCloudJDBCException(e); + } + } + + private FormCommand buildGetDataspaceFormCommand(OAuthToken oAuthToken) throws URISyntaxException { + val builder = FormCommand.builder(); + builder.url(oAuthToken.getInstanceUrl()); + builder.suffix(new URI(SOQL_ENDPOINT_SUFFIX)); + builder.queryParameters(Map.of(SOQL_QUERY_PARAM_KEY, "SELECT+name+from+Dataspace")); + builder.header(Constants.AUTHORIZATION, oAuthToken.getBearerToken()); + builder.header(FormCommand.CONTENT_TYPE_HEADER_NAME, Constants.CONTENT_TYPE_JSON); + builder.header("User-Agent", "cdp/jdbc"); + builder.header("enable-stream-flow", "false"); + return builder.build(); + } + + private static DataspaceResponse getDataSpaceResponse( + Optional tokenProcessor, OkHttpClient client, boolean isGetCatalog) throws SQLException { + String errorMessage = isGetCatalog + ? "Token processor is empty. getCatalogs() cannot be executed" + : "Token processor is empty. 
getDataspaces() cannot be executed"; + if (tokenProcessor.isEmpty()) { + throw new DataCloudJDBCException(errorMessage); + } + try { + val oAuthToken = tokenProcessor.get().getOAuthToken(); + FormCommand httpFormCommand = buildGetDataspaceFormCommand(oAuthToken); + + return FormCommand.get(client, httpFormCommand, DataspaceResponse.class); + + } catch (Exception e) { + throw new DataCloudJDBCException(e); + } + } + + private static final Map> tableTypeClauses = Map.ofEntries( + entry( + "TABLE", + Map.of( + "SCHEMAS", + "c.relkind = 'r' AND n.nspname !~ '^pg_' AND n.nspname <> 'information_schema'", + "NOSCHEMAS", + "c.relkind = 'r' AND c.relname !~ '^pg_'")), + entry( + "PARTITIONED TABLE", + Map.of( + "SCHEMAS", + "c.relkind = 'p' AND n.nspname !~ '^pg_' AND n.nspname <> 'information_schema'", + "NOSCHEMAS", + "c.relkind = 'p' AND c.relname !~ '^pg_'")), + entry( + "VIEW", + Map.of( + "SCHEMAS", + "c.relkind = 'v' AND n.nspname <> 'pg_catalog' AND n.nspname <> 'information_schema'", + "NOSCHEMAS", + "c.relkind = 'v' AND c.relname !~ '^pg_'")), + entry( + "INDEX", + Map.of( + "SCHEMAS", + "c.relkind = 'i' AND n.nspname !~ '^pg_' AND n.nspname <> 'information_schema'", + "NOSCHEMAS", + "c.relkind = 'i' AND c.relname !~ '^pg_'")), + entry( + "PARTITIONED INDEX", + Map.of( + "SCHEMAS", + "c.relkind = 'I' AND n.nspname !~ '^pg_' AND n.nspname <> 'information_schema'", + "NOSCHEMAS", + "c.relkind = 'I' AND c.relname !~ '^pg_'")), + entry("SEQUENCE", Map.of("SCHEMAS", "c.relkind = 'S'", "NOSCHEMAS", "c.relkind = 'S'")), + entry( + "TYPE", + Map.of( + "SCHEMAS", + "c.relkind = 'c' AND n.nspname !~ '^pg_' AND n.nspname <> 'information_schema'", + "NOSCHEMAS", + "c.relkind = 'c' AND c.relname !~ '^pg_'")), + entry( + "SYSTEM TABLE", + Map.of( + "SCHEMAS", + "c.relkind = 'r' AND (n.nspname = 'pg_catalog' OR n.nspname = 'information_schema')", + "NOSCHEMAS", + "c.relkind = 'r' AND c.relname ~ '^pg_' AND c.relname !~ '^pg_toast_' AND c.relname !~ '^pg_temp_'")), + entry( + 
"SYSTEM TOAST TABLE", + Map.of( + "SCHEMAS", + "c.relkind = 'r' AND n.nspname = 'pg_toast'", + "NOSCHEMAS", + "c.relkind = 'r' AND c.relname ~ '^pg_toast_'")), + entry( + "SYSTEM TOAST INDEX", + Map.of( + "SCHEMAS", + "c.relkind = 'i' AND n.nspname = 'pg_toast'", + "NOSCHEMAS", + "c.relkind = 'i' AND c.relname ~ '^pg_toast_'")), + entry( + "SYSTEM VIEW", + Map.of( + "SCHEMAS", + "c.relkind = 'v' AND (n.nspname = 'pg_catalog' OR n.nspname = 'information_schema') ", + "NOSCHEMAS", + "c.relkind = 'v' AND c.relname ~ '^pg_'")), + entry( + "SYSTEM INDEX", + Map.of( + "SCHEMAS", + "c.relkind = 'i' AND (n.nspname = 'pg_catalog' OR n.nspname = 'information_schema') ", + "NOSCHEMAS", + "c.relkind = 'v' AND c.relname ~ '^pg_' AND c.relname !~ '^pg_toast_' AND c.relname !~ '^pg_temp_'")), + entry( + "TEMPORARY TABLE", + Map.of( + "SCHEMAS", + "c.relkind IN ('r','p') AND n.nspname ~ '^pg_temp_' ", + "NOSCHEMAS", + "c.relkind IN ('r','p') AND c.relname ~ '^pg_temp_' ")), + entry( + "TEMPORARY INDEX", + Map.of( + "SCHEMAS", + "c.relkind = 'i' AND n.nspname ~ '^pg_temp_' ", + "NOSCHEMAS", + "c.relkind = 'i' AND c.relname ~ '^pg_temp_' ")), + entry( + "TEMPORARY VIEW", + Map.of( + "SCHEMAS", + "c.relkind = 'v' AND n.nspname ~ '^pg_temp_' ", + "NOSCHEMAS", + "c.relkind = 'v' AND c.relname ~ '^pg_temp_' ")), + entry( + "TEMPORARY SEQUENCE", + Map.of( + "SCHEMAS", + "c.relkind = 'S' AND n.nspname ~ '^pg_temp_' ", + "NOSCHEMAS", + "c.relkind = 'S' AND c.relname ~ '^pg_temp_' ")), + entry("FOREIGN TABLE", Map.of("SCHEMAS", "c.relkind = 'f'", "NOSCHEMAS", "c.relkind = 'f'")), + entry("MATERIALIZED VIEW", Map.of("SCHEMAS", "c.relkind = 'm'", "NOSCHEMAS", "c.relkind = 'm'"))); + + public static String quoteStringLiteral(String v) { + StringBuilder result = new StringBuilder(); + + result.ensureCapacity(v.length() + 8); + + result.append("E'"); + + boolean escaped = false; + + for (int i = 0; i < v.length(); i++) { + char ch = v.charAt(i); + switch (ch) { + case '\'': + 
result.append("''"); + break; + case '\\': + result.append("\\\\"); + escaped = true; + break; + case '\n': + result.append("\\n"); + escaped = true; + break; + case '\r': + result.append("\\r"); + escaped = true; + break; + case '\t': + result.append("\\t"); + escaped = true; + break; + case '\b': + result.append("\\b"); + escaped = true; + break; + case '\f': + result.append("\\f"); + escaped = true; + break; + default: + if (ch < ' ') { + result.append('\\').append(String.format("%03o", (int) ch)); + escaped = true; + } else { + result.append(ch); + } + } + } + + if (!escaped) { + result.deleteCharAt(0); + } + + return result.append('\'').toString(); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/QueryResultSetMetadata.java b/src/main/java/com/salesforce/datacloud/jdbc/core/QueryResultSetMetadata.java new file mode 100644 index 0000000..083dc5f --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/QueryResultSetMetadata.java @@ -0,0 +1,213 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core; + +import java.sql.ResultSetMetaData; +import java.sql.Types; +import java.util.List; +import org.apache.commons.lang3.StringUtils; + +public class QueryResultSetMetadata implements ResultSetMetaData { + private final List columnNames; + private final List columnTypes; + private final List columnTypeIds; + + public QueryResultSetMetadata(List columnNames, List columnTypes, List columnTypeIds) { + this.columnNames = columnNames; + this.columnTypes = columnTypes; + this.columnTypeIds = columnTypeIds; + } + + public QueryResultSetMetadata(QueryDBMetadata metadata) { + this.columnNames = metadata.getColumnNames(); + this.columnTypes = metadata.getColumnTypes(); + this.columnTypeIds = metadata.getColumnTypeIds(); + } + + @Override + public int getColumnCount() { + return columnNames.size(); + } + + @Override + public boolean isAutoIncrement(int column) { + return false; + } + + @Override + public boolean isCaseSensitive(int column) { + return true; + } + + @Override + public boolean isSearchable(int column) { + return false; + } + + @Override + public boolean isCurrency(int column) { + return false; + } + + @Override + public int isNullable(int column) { + return columnNullable; + } + + @Override + public boolean isSigned(int column) { + return false; + } + + @Override + public int getColumnDisplaySize(int column) { + int columnType = getColumnType(column); + switch (columnType) { + case Types.CHAR: + case Types.VARCHAR: + case Types.BINARY: + return getColumnName(column).length(); + case Types.INTEGER: + case Types.BIGINT: + case Types.SMALLINT: + case Types.TINYINT: + return getPrecision(column) + 1; + case Types.DECIMAL: + return getPrecision(column) + 1 + 1; + case Types.DOUBLE: + return 24; + case Types.BOOLEAN: + return 5; + default: + return 25; + } + } + + @Override + public String getColumnLabel(int column) { + if (columnNames == null || column > columnNames.size()) { + return "C" + (column - 1); + } else { + 
return columnNames.get(column - 1); + } + } + + @Override + public String getColumnName(int column) { + return columnNames.get(column - 1); + } + + @Override + public String getSchemaName(int column) { + return StringUtils.EMPTY; + } + + @Override + public int getPrecision(int column) { + int columnType = getColumnType(column); + switch (columnType) { + case Types.CHAR: + case Types.VARCHAR: + case Types.LONGVARCHAR: + case Types.BINARY: + case Types.BIT: + case Types.VARBINARY: + case Types.LONGVARBINARY: + return getColumnName(column).length(); + case Types.INTEGER: + case Types.DECIMAL: + case Types.BIGINT: + case Types.SMALLINT: + case Types.FLOAT: + case Types.REAL: + case Types.DOUBLE: + case Types.NUMERIC: + case Types.TINYINT: + return 38; + default: + return 0; + } + } + + @Override + public int getScale(int column) { + int columnType = getColumnType(column); + switch (columnType) { + case Types.INTEGER: + case Types.DECIMAL: + case Types.BIGINT: + case Types.SMALLINT: + case Types.FLOAT: + case Types.REAL: + case Types.DOUBLE: + case Types.NUMERIC: + case Types.TINYINT: + return 18; + default: + return 0; + } + } + + @Override + public String getTableName(int column) { + return StringUtils.EMPTY; + } + + @Override + public String getCatalogName(int column) { + return StringUtils.EMPTY; + } + + @Override + public int getColumnType(int column) { + return columnTypeIds.get(column - 1); + } + + @Override + public String getColumnTypeName(int column) { + return columnTypes.get(column - 1); + } + + @Override + public boolean isReadOnly(int column) { + return true; + } + + @Override + public boolean isWritable(int column) { + return false; + } + + @Override + public boolean isDefinitelyWritable(int column) { + return false; + } + + @Override + public String getColumnClassName(int column) { + return null; + } + + @Override + public T unwrap(Class iface) { + return null; + } + + @Override + public boolean isWrapperFor(Class iface) { + return false; + } +} diff 
--git a/src/main/java/com/salesforce/datacloud/jdbc/core/StreamingResultSet.java b/src/main/java/com/salesforce/datacloud/jdbc/core/StreamingResultSet.java new file mode 100644 index 0000000..72da31c --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/StreamingResultSet.java @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core; + +import com.salesforce.datacloud.jdbc.core.listener.QueryStatusListener; +import com.salesforce.datacloud.jdbc.exception.QueryExceptionHandler; +import com.salesforce.datacloud.jdbc.util.ArrowUtils; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.util.Collections; +import java.util.TimeZone; +import lombok.Getter; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.calcite.avatica.AvaticaResultSet; +import org.apache.calcite.avatica.AvaticaResultSetMetaData; +import org.apache.calcite.avatica.AvaticaStatement; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.QueryState; + +@Getter +@Slf4j +public class StreamingResultSet extends AvaticaResultSet implements DataCloudResultSet { + private static final int ROOT_ALLOCATOR_MB_FROM_V2 = 100 * 1024 * 1024; + private final QueryStatusListener listener; + + 
private StreamingResultSet( + QueryStatusListener listener, + AvaticaStatement statement, + QueryState state, + Meta.Signature signature, + ResultSetMetaData resultSetMetaData, + TimeZone timeZone, + Meta.Frame firstFrame) + throws SQLException { + super(statement, state, signature, resultSetMetaData, timeZone, firstFrame); + this.listener = listener; + } + + @SneakyThrows + public static StreamingResultSet of(String sql, QueryStatusListener listener) { + try { + val channel = ExecuteQueryResponseChannel.of(listener.stream()); + val reader = new ArrowStreamReader(channel, new RootAllocator(ROOT_ALLOCATOR_MB_FROM_V2)); + val schemaRoot = reader.getVectorSchemaRoot(); + val columns = ArrowUtils.toColumnMetaData(schemaRoot.getSchema().getFields()); + val timezone = TimeZone.getDefault(); + val state = new QueryState(); + val signature = new Meta.Signature( + columns, sql, Collections.emptyList(), Collections.emptyMap(), null, Meta.StatementType.SELECT); + val metadata = new AvaticaResultSetMetaData(null, null, signature); + val result = new StreamingResultSet(listener, null, state, signature, metadata, timezone, null); + val cursor = new ArrowStreamReaderCursor(reader); + result.execute2(cursor, columns); + + return result; + } catch (Exception ex) { + throw QueryExceptionHandler.createException(QUERY_FAILURE + sql, ex); + } + } + + @Override + public String getQueryId() { + return listener.getQueryId(); + } + + @Override + public String getStatus() { + return listener.getStatus(); + } + + @Override + public boolean isReady() { + return listener.isReady(); + } + + private static final String QUERY_FAILURE = "Failed to execute query: "; +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/QueryJDBCAccessor.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/QueryJDBCAccessor.java new file mode 100644 index 0000000..f490f83 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/QueryJDBCAccessor.java @@ -0,0 
+1,221 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.net.URL; +import java.sql.Array; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Date; +import java.sql.NClob; +import java.sql.Ref; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.sql.SQLXML; +import java.sql.Struct; +import java.sql.Time; +import java.sql.Timestamp; +import java.util.Calendar; +import java.util.Map; +import java.util.function.IntSupplier; +import org.apache.calcite.avatica.util.Cursor.Accessor; + +public abstract class QueryJDBCAccessor implements Accessor { + private final IntSupplier currentRowSupplier; + protected boolean wasNull; + protected QueryJDBCAccessorFactory.WasNullConsumer wasNullConsumer; + + protected QueryJDBCAccessor( + IntSupplier currentRowSupplier, QueryJDBCAccessorFactory.WasNullConsumer wasNullConsumer) { + this.currentRowSupplier = currentRowSupplier; + this.wasNullConsumer = wasNullConsumer; + } + + protected int getCurrentRow() { + return currentRowSupplier.getAsInt(); + } + + public abstract Class getObjectClass(); + + @Override + public boolean wasNull() { + return wasNull; + } + + @Override + public String getString() throws 
SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public boolean getBoolean() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public byte getByte() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public short getShort() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public int getInt() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public long getLong() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public float getFloat() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public double getDouble() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public BigDecimal getBigDecimal() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public BigDecimal getBigDecimal(int i) throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public byte[] getBytes() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public InputStream getAsciiStream() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public InputStream getUnicodeStream() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public InputStream getBinaryStream() throws SQLException { + throw new 
DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public Object getObject() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public Reader getCharacterStream() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public Object getObject(Map> map) throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public Ref getRef() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public Blob getBlob() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public Clob getClob() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public Array getArray() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public Struct getStruct() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public Date getDate(Calendar calendar) throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public Time getTime(Calendar calendar) throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public Timestamp getTimestamp(Calendar calendar) throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public URL getURL() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public NClob getNClob() throws SQLException { + throw new 
DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public SQLXML getSQLXML() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public String getNString() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public Reader getNCharacterStream() throws SQLException { + throw new DataCloudJDBCException(getOperationNotSupported(this.getClass())); + } + + @Override + public T getObject(Class aClass) { + return null; + } + + private static SQLException getOperationNotSupported(final Class type) { + return new SQLFeatureNotSupportedException( + String.format("Operation not supported for type: %s.", type.getName())); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/QueryJDBCAccessorFactory.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/QueryJDBCAccessorFactory.java new file mode 100644 index 0000000..bf09b05 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/QueryJDBCAccessorFactory.java @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.accessor; + +import com.salesforce.datacloud.jdbc.core.accessor.impl.BaseIntVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.BinaryVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.BooleanVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.DateVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.DecimalVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.DoubleVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.LargeListVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.ListVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.TimeStampVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.TimeVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.VarCharVectorAccessor; +import java.sql.SQLException; +import java.util.function.IntSupplier; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.FixedSizeBinaryVector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.LargeVarBinaryVector; +import org.apache.arrow.vector.LargeVarCharVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import 
org.apache.arrow.vector.TimeStampNanoTZVector; +import org.apache.arrow.vector.TimeStampNanoVector; +import org.apache.arrow.vector.TimeStampSecTZVector; +import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.types.Types; + +public class QueryJDBCAccessorFactory { + @FunctionalInterface + public interface WasNullConsumer { + void setWasNull(boolean wasNull); + } + + public static QueryJDBCAccessor createAccessor( + ValueVector vector, IntSupplier getCurrentRow, QueryJDBCAccessorFactory.WasNullConsumer wasNullConsumer) + throws SQLException { + Types.MinorType arrowType = + Types.getMinorTypeForArrowType(vector.getField().getType()); + if (arrowType.equals(Types.MinorType.VARCHAR)) { + return new VarCharVectorAccessor((VarCharVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.LARGEVARCHAR)) { + return new VarCharVectorAccessor((LargeVarCharVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.DECIMAL)) { + return new DecimalVectorAccessor((DecimalVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.BIT)) { + return new BooleanVectorAccessor((BitVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.FLOAT8)) { + return new DoubleVectorAccessor((Float8Vector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.TINYINT)) { + return new BaseIntVectorAccessor((TinyIntVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.SMALLINT)) { + return new 
BaseIntVectorAccessor((SmallIntVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.INT)) { + return new BaseIntVectorAccessor((IntVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.BIGINT)) { + return new BaseIntVectorAccessor((BigIntVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.UINT4)) { + return new BaseIntVectorAccessor((UInt4Vector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.VARBINARY)) { + return new BinaryVectorAccessor((VarBinaryVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.LARGEVARBINARY)) { + return new BinaryVectorAccessor((LargeVarBinaryVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.FIXEDSIZEBINARY)) { + return new BinaryVectorAccessor((FixedSizeBinaryVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.DATEDAY)) { + return new DateVectorAccessor((DateDayVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.DATEMILLI)) { + return new DateVectorAccessor((DateMilliVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.TIMENANO)) { + return new TimeVectorAccessor((TimeNanoVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.TIMEMICRO)) { + return new TimeVectorAccessor((TimeMicroVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.TIMEMILLI)) { + return new TimeVectorAccessor((TimeMilliVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.TIMESEC)) { + return new TimeVectorAccessor((TimeSecVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.TIMESTAMPSECTZ)) { + return new 
TimeStampVectorAccessor((TimeStampSecTZVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.TIMESTAMPSEC)) { + return new TimeStampVectorAccessor((TimeStampSecVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.TIMESTAMPMILLITZ)) { + return new TimeStampVectorAccessor((TimeStampMilliTZVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.TIMESTAMPMILLI)) { + return new TimeStampVectorAccessor((TimeStampMilliVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.TIMESTAMPMICROTZ)) { + return new TimeStampVectorAccessor((TimeStampMicroTZVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.TIMESTAMPMICRO)) { + return new TimeStampVectorAccessor((TimeStampMicroVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.TIMESTAMPNANOTZ)) { + return new TimeStampVectorAccessor((TimeStampNanoTZVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.TIMESTAMPNANO)) { + return new TimeStampVectorAccessor((TimeStampNanoVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.LIST)) { + return new ListVectorAccessor((ListVector) vector, getCurrentRow, wasNullConsumer); + } else if (arrowType.equals(Types.MinorType.LARGELIST)) { + return new LargeListVectorAccessor((LargeListVector) vector, getCurrentRow, wasNullConsumer); + } + + throw new UnsupportedOperationException( + "Unsupported vector type: " + vector.getClass().getName()); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BaseIntVectorAccessor.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BaseIntVectorAccessor.java new file mode 100644 index 0000000..84e5adc --- /dev/null +++ 
b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BaseIntVectorAccessor.java @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import static com.salesforce.datacloud.jdbc.core.accessor.impl.NumericGetter.createGetter; + +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorFactory; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.sql.SQLException; +import java.util.function.IntSupplier; +import lombok.val; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.types.Types.MinorType; + +public class BaseIntVectorAccessor extends QueryJDBCAccessor { + + private final MinorType type; + private final boolean isUnsigned; + private final NumericGetter.Getter getter; + private final NumericGetter.NumericHolder holder; + + private static final String INVALID_TYPE_ERROR_RESPONSE = "Invalid Minor Type provided"; + + public BaseIntVectorAccessor( + TinyIntVector vector, + IntSupplier 
currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer setCursorWasNull) + throws SQLException { + this(vector, currentRowSupplier, false, setCursorWasNull); + } + + public BaseIntVectorAccessor( + SmallIntVector vector, + IntSupplier currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer setCursorWasNull) + throws SQLException { + this(vector, currentRowSupplier, false, setCursorWasNull); + } + + public BaseIntVectorAccessor( + IntVector vector, IntSupplier currentRowSupplier, QueryJDBCAccessorFactory.WasNullConsumer setCursorWasNull) + throws SQLException { + this(vector, currentRowSupplier, false, setCursorWasNull); + } + + public BaseIntVectorAccessor( + BigIntVector vector, + IntSupplier currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer setCursorWasNull) + throws SQLException { + this(vector, currentRowSupplier, false, setCursorWasNull); + } + + private BaseIntVectorAccessor( + BaseIntVector vector, + IntSupplier currentRowSupplier, + boolean isUnsigned, + QueryJDBCAccessorFactory.WasNullConsumer setCursorWasNull) + throws SQLException { + super(currentRowSupplier, setCursorWasNull); + this.type = vector.getMinorType(); + this.holder = new NumericGetter.NumericHolder(); + this.getter = createGetter(vector); + this.isUnsigned = isUnsigned; + } + + public BaseIntVectorAccessor( + UInt4Vector vector, + IntSupplier currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer setCursorWasNull) + throws SQLException { + this(vector, currentRowSupplier, false, setCursorWasNull); + } + + @Override + public long getLong() { + getter.get(getCurrentRow(), holder); + + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return 0; + } + + return holder.value; + } + + @Override + public Class getObjectClass() { + return Long.class; + } + + @Override + public String getString() { + final long number = getLong(); + + if (this.wasNull) { + return null; + } else { + return isUnsigned ? 
Long.toUnsignedString(number) : Long.toString(number); + } + } + + @Override + public byte getByte() { + return (byte) getLong(); + } + + @Override + public short getShort() { + return (short) getLong(); + } + + @Override + public int getInt() { + return (int) getLong(); + } + + @Override + public float getFloat() { + return (float) getLong(); + } + + @Override + public double getDouble() { + return (double) getLong(); + } + + @Override + public BigDecimal getBigDecimal() { + final BigDecimal value = BigDecimal.valueOf(getLong()); + return this.wasNull ? null : value; + } + + @Override + public BigDecimal getBigDecimal(int scale) { + final BigDecimal value = BigDecimal.valueOf(this.getDouble()).setScale(scale, RoundingMode.HALF_UP); + return this.wasNull ? null : value; + } + + @Override + public Number getObject() throws SQLException { + final Number number; + switch (type) { + case TINYINT: + number = getByte(); + break; + case SMALLINT: + number = getShort(); + break; + case INT: + case UINT4: + number = getInt(); + break; + case BIGINT: + number = getLong(); + break; + default: + val rootCauseException = new UnsupportedOperationException(INVALID_TYPE_ERROR_RESPONSE); + throw new DataCloudJDBCException(INVALID_TYPE_ERROR_RESPONSE, "2200G", rootCauseException); + } + return wasNull ? null : number; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BaseListVectorAccessor.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BaseListVectorAccessor.java new file mode 100644 index 0000000..31f1602 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BaseListVectorAccessor.java @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorFactory; +import java.sql.Array; +import java.sql.SQLException; +import java.util.List; +import java.util.function.IntSupplier; +import lombok.val; +import org.apache.arrow.vector.FieldVector; + +public abstract class BaseListVectorAccessor extends QueryJDBCAccessor { + + protected abstract long getStartOffset(int index); + + protected abstract long getEndOffset(int index); + + protected abstract FieldVector getDataVector(); + + protected abstract boolean isNull(int index); + + protected BaseListVectorAccessor( + IntSupplier currentRowSupplier, QueryJDBCAccessorFactory.WasNullConsumer wasNullConsumer) { + super(currentRowSupplier, wasNullConsumer); + } + + @Override + public Class getObjectClass() { + return List.class; + } + + protected List getListObject(VectorProvider vectorProvider) throws SQLException { + List object = vectorProvider.getObject(getCurrentRow()); + this.wasNull = object == null; + this.wasNullConsumer.setWasNull(this.wasNull); + return object; + } + + protected interface VectorProvider { + List getObject(int row) throws SQLException; + } + + @Override + public Array getArray() { + val index = getCurrentRow(); + val dataVector = getDataVector(); + + this.wasNull = isNull(index); + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return null; + } + + val startOffset = getStartOffset(index); + val endOffset 
= getEndOffset(index); + + val valuesCount = endOffset - startOffset; + return new DataCloudArray(dataVector, startOffset, valuesCount); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BinaryVectorAccessor.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BinaryVectorAccessor.java new file mode 100644 index 0000000..7139ea2 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BinaryVectorAccessor.java @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorFactory; +import java.nio.charset.StandardCharsets; +import java.util.function.IntSupplier; +import org.apache.arrow.vector.FixedSizeBinaryVector; +import org.apache.arrow.vector.LargeVarBinaryVector; +import org.apache.arrow.vector.VarBinaryVector; + +public class BinaryVectorAccessor extends QueryJDBCAccessor { + + private interface ByteArrayGetter { + byte[] get(int index); + } + + private final ByteArrayGetter getter; + + public BinaryVectorAccessor( + FixedSizeBinaryVector vector, + IntSupplier currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer wasNullConsumer) { + this(vector::get, currentRowSupplier, wasNullConsumer); + } + + public BinaryVectorAccessor( + VarBinaryVector vector, + IntSupplier currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer wasNullConsumer) { + this(vector::get, currentRowSupplier, wasNullConsumer); + } + + public BinaryVectorAccessor( + LargeVarBinaryVector vector, + IntSupplier currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer wasNullConsumer) { + this(vector::get, currentRowSupplier, wasNullConsumer); + } + + private BinaryVectorAccessor( + ByteArrayGetter getter, + IntSupplier currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer wasNullConsumer) { + super(currentRowSupplier, wasNullConsumer); + this.getter = getter; + } + + @Override + public byte[] getBytes() { + byte[] bytes = getter.get(getCurrentRow()); + this.wasNull = bytes == null; + this.wasNullConsumer.setWasNull(this.wasNull); + + return bytes; + } + + @Override + public Object getObject() { + return this.getBytes(); + } + + @Override + public Class getObjectClass() { + return byte[].class; + } + + @Override + public String getString() { + byte[] bytes = this.getBytes(); + if (bytes == null) { + return null; + } + + return new 
String(bytes, StandardCharsets.UTF_8); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BooleanVectorAccessor.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BooleanVectorAccessor.java new file mode 100644 index 0000000..ae40322 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BooleanVectorAccessor.java @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorFactory; +import java.math.BigDecimal; +import java.util.function.IntSupplier; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.holders.NullableBitHolder; + +public class BooleanVectorAccessor extends QueryJDBCAccessor { + + private final BitVector vector; + private final NullableBitHolder holder; + + public BooleanVectorAccessor( + BitVector vector, IntSupplier getCurrentRow, QueryJDBCAccessorFactory.WasNullConsumer wasNullConsumer) { + super(getCurrentRow, wasNullConsumer); + this.vector = vector; + this.holder = new NullableBitHolder(); + } + + @Override + public Class getObjectClass() { + return Boolean.class; + } + + @Override + public Object getObject() { + final boolean value = this.getBoolean(); + return this.wasNull ? 
null : value; + } + + @Override + public byte getByte() { + return (byte) this.getLong(); + } + + @Override + public short getShort() { + return (short) this.getLong(); + } + + @Override + public int getInt() { + return (int) this.getLong(); + } + + @Override + public float getFloat() { + return this.getLong(); + } + + @Override + public double getDouble() { + return this.getLong(); + } + + @Override + public BigDecimal getBigDecimal() { + final long value = this.getLong(); + + return this.wasNull ? null : BigDecimal.valueOf(value); + } + + @Override + public String getString() { + final boolean value = getBoolean(); + return wasNull ? null : Boolean.toString(value); + } + + @Override + public long getLong() { + vector.get(getCurrentRow(), holder); + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return 0; + } + + return holder.value; + } + + @Override + public boolean getBoolean() { + return this.getLong() != 0; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DataCloudArray.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DataCloudArray.java new file mode 100644 index 0000000..fba3cda --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DataCloudArray.java @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import static com.salesforce.datacloud.jdbc.util.ArrowUtils.getSQLTypeFromArrowType; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.SqlErrorCodes; +import java.sql.Array; +import java.sql.JDBCType; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Map; +import lombok.val; +import org.apache.arrow.memory.util.LargeMemoryUtil; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; + +public class DataCloudArray implements Array { + + private final FieldVector dataVector; + private final long startOffset; + private final long valuesCount; + protected static final String NOT_SUPPORTED_IN_DATACLOUD_QUERY = + "Array method is not supported in Data Cloud query"; + + public DataCloudArray(FieldVector dataVector, long startOffset, long valuesCount) { + this.dataVector = dataVector; + this.startOffset = startOffset; + this.valuesCount = valuesCount; + } + + @Override + public String getBaseTypeName() { + val arrowType = this.dataVector.getField().getType(); + val baseType = getSQLTypeFromArrowType(arrowType); + return JDBCType.valueOf(baseType).getName(); + } + + @Override + public int getBaseType() { + val arrowType = this.dataVector.getField().getType(); + return getSQLTypeFromArrowType(arrowType); + } + + @Override + public Object getArray() throws SQLException { + return getArray(null); + } + + @Override + public Object getArray(Map> map) throws SQLException { + if (map != null) { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + return getArrayNoBoundCheck(this.dataVector, this.startOffset, this.valuesCount); + } + + @Override + public Object getArray(long index, int count) throws SQLException { + return getArray(index, count, null); + } + + @Override + public Object getArray(long index, int count, Map> map) 
throws SQLException { + if (map != null) { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + checkBoundaries(index, count); + val start = LargeMemoryUtil.checkedCastToInt(this.startOffset + index); + return getArrayNoBoundCheck(this.dataVector, start, count); + } + + @Override + public ResultSet getResultSet() throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public ResultSet getResultSet(Map> map) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public ResultSet getResultSet(long index, int count) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public ResultSet getResultSet(long index, int count, Map> map) throws SQLException { + throw new DataCloudJDBCException(NOT_SUPPORTED_IN_DATACLOUD_QUERY, SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Override + public void free() { + // no-op + } + + private static Object getArrayNoBoundCheck(ValueVector dataVector, long start, long count) { + Object[] result = new Object[LargeMemoryUtil.checkedCastToInt(count)]; + for (int i = 0; i < count; i++) { + result[i] = dataVector.getObject(LargeMemoryUtil.checkedCastToInt(start + i)); + } + return result; + } + + private void checkBoundaries(long index, int count) { + if (index < 0 || index + count > this.startOffset + this.valuesCount) { + throw new ArrayIndexOutOfBoundsException(); + } + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DateVectorAccessor.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DateVectorAccessor.java new file mode 100644 index 0000000..61c33ec --- /dev/null +++ 
b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DateVectorAccessor.java @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import static com.salesforce.datacloud.jdbc.core.accessor.impl.DateVectorGetter.createGetter; +import static com.salesforce.datacloud.jdbc.util.DateTimeUtils.getUTCDateFromMilliseconds; + +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorFactory; +import com.salesforce.datacloud.jdbc.core.accessor.impl.DateVectorGetter.Getter; +import com.salesforce.datacloud.jdbc.core.accessor.impl.DateVectorGetter.Holder; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.sql.Date; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.util.Calendar; +import java.util.concurrent.TimeUnit; +import java.util.function.IntSupplier; +import lombok.val; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.ValueVector; + +public class DateVectorAccessor extends QueryJDBCAccessor { + + private final Getter getter; + private final TimeUnit timeUnit; + private final Holder holder; + + private static final String INVALID_VECTOR_ERROR_RESPONSE = "Invalid Arrow vector provided"; + + public DateVectorAccessor( + 
DateDayVector vector, + IntSupplier currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer setCursorWasNull) + throws SQLException { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + public DateVectorAccessor( + DateMilliVector vector, + IntSupplier currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer setCursorWasNull) + throws SQLException { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + @Override + public Class getObjectClass() { + return Date.class; + } + + @Override + public Object getObject() { + return this.getDate(null); + } + + /** + * @param calendar Calendar passed in. Ignores the calendar + * @return Timestamp of Date at current row, at midnight UTC + */ + @Override + public Timestamp getTimestamp(Calendar calendar) { + Date date = getDate(calendar); + if (date == null) { + return null; + } + return new Timestamp(date.getTime()); + } + + /** + * @param calendar Calendar passed in. 
Ignores the calendar + * @return Date of current row in UTC + */ + @Override + public Date getDate(Calendar calendar) { + fillHolder(); + if (this.wasNull) { + return null; + } + + long value = holder.value; + long milliseconds = this.timeUnit.toMillis(value); + + return getUTCDateFromMilliseconds(milliseconds); + } + + @Override + public String getString() { + val date = getDate(null); + if (date == null) { + return null; + } + + return date.toLocalDate().toString(); + } + + private void fillHolder() { + getter.get(getCurrentRow(), holder); + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + } + + protected static TimeUnit getTimeUnitForVector(ValueVector vector) throws SQLException { + if (vector instanceof DateDayVector) { + return TimeUnit.DAYS; + } else if (vector instanceof DateMilliVector) { + return TimeUnit.MILLISECONDS; + } + + val rootCauseException = new IllegalArgumentException(INVALID_VECTOR_ERROR_RESPONSE); + throw new DataCloudJDBCException(INVALID_VECTOR_ERROR_RESPONSE, "22007", rootCauseException); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DateVectorGetter.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DateVectorGetter.java new file mode 100644 index 0000000..308994d --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DateVectorGetter.java @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import lombok.experimental.UtilityClass; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.holders.NullableDateDayHolder; +import org.apache.arrow.vector.holders.NullableDateMilliHolder; + +@UtilityClass +final class DateVectorGetter { + + static class Holder { + int isSet; + long value; + } + + @FunctionalInterface + interface Getter { + void get(int index, Holder holder); + } + + static Getter createGetter(DateDayVector vector) { + NullableDateDayHolder auxHolder = new NullableDateDayHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + static Getter createGetter(DateMilliVector vector) { + NullableDateMilliHolder auxHolder = new NullableDateMilliHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DecimalVectorAccessor.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DecimalVectorAccessor.java new file mode 100644 index 0000000..f874dde --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DecimalVectorAccessor.java @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorFactory; +import java.math.BigDecimal; +import java.util.function.IntSupplier; +import lombok.val; +import org.apache.arrow.vector.DecimalVector; + +public class DecimalVectorAccessor extends QueryJDBCAccessor { + private final DecimalVector vector; + + public DecimalVectorAccessor( + DecimalVector vector, IntSupplier getCurrentRow, QueryJDBCAccessorFactory.WasNullConsumer wasNullConsumer) { + super(getCurrentRow, wasNullConsumer); + this.vector = vector; + } + + @Override + public Class getObjectClass() { + return BigDecimal.class; + } + + @Override + public BigDecimal getBigDecimal() { + final BigDecimal value = vector.getObject(getCurrentRow()); + this.wasNull = value == null; + this.wasNullConsumer.setWasNull(this.wasNull); + return value; + } + + @Override + public Object getObject() { + return getBigDecimal(); + } + + @Override + public String getString() { + val value = this.getBigDecimal(); + if (value == null) { + return null; + } + return value.toString(); + } + + @Override + public int getInt() { + final BigDecimal value = this.getBigDecimal(); + + return this.wasNull ? 
0 : value.intValue(); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DoubleVectorAccessor.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DoubleVectorAccessor.java new file mode 100644 index 0000000..5c39b23 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DoubleVectorAccessor.java @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorFactory; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.sql.SQLException; +import java.util.function.IntSupplier; +import lombok.val; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.holders.NullableFloat8Holder; + +public class DoubleVectorAccessor extends QueryJDBCAccessor { + + private final Float8Vector vector; + private final NullableFloat8Holder holder; + + private static final String INVALID_VALUE_ERROR_RESPONSE = "BigDecimal doesn't support Infinite/NaN"; + + public DoubleVectorAccessor( + Float8Vector vector, + IntSupplier currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new NullableFloat8Holder(); + this.vector = vector; + } + + @Override + public Class getObjectClass() { + return Double.class; + } + + @Override + public double getDouble() { + vector.get(getCurrentRow(), holder); + + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return 0; + } + + return holder.value; + } + + @Override + public Object getObject() { + final double value = this.getDouble(); + + return this.wasNull ? null : value; + } + + @Override + public String getString() { + final double value = this.getDouble(); + return this.wasNull ? 
null : Double.toString(value); + } + + @Override + public boolean getBoolean() { + return this.getDouble() != 0.0; + } + + @Override + public byte getByte() { + return (byte) this.getDouble(); + } + + @Override + public short getShort() { + return (short) this.getDouble(); + } + + @Override + public int getInt() { + return (int) this.getDouble(); + } + + @Override + public long getLong() { + return (long) this.getDouble(); + } + + @Override + public float getFloat() { + return (float) this.getDouble(); + } + + @Override + public BigDecimal getBigDecimal() throws SQLException { + final double value = this.getDouble(); + if (Double.isInfinite(value) || Double.isNaN(value)) { + val rootCauseException = new UnsupportedOperationException(INVALID_VALUE_ERROR_RESPONSE); + throw new DataCloudJDBCException(INVALID_VALUE_ERROR_RESPONSE, "2200G", rootCauseException); + } + return this.wasNull ? null : BigDecimal.valueOf(value); + } + + @Override + public BigDecimal getBigDecimal(int scale) throws SQLException { + final double value = this.getDouble(); + if (Double.isInfinite(value) || Double.isNaN(value)) { + val rootCauseException = new UnsupportedOperationException(INVALID_VALUE_ERROR_RESPONSE); + throw new DataCloudJDBCException(INVALID_VALUE_ERROR_RESPONSE, "2200G", rootCauseException); + } + return this.wasNull ? null : BigDecimal.valueOf(value).setScale(scale, RoundingMode.HALF_UP); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/LargeListVectorAccessor.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/LargeListVectorAccessor.java new file mode 100644 index 0000000..2405c4d --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/LargeListVectorAccessor.java @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorFactory; +import java.sql.SQLException; +import java.util.function.IntSupplier; +import lombok.val; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.complex.LargeListVector; + +public class LargeListVectorAccessor extends BaseListVectorAccessor { + + private final LargeListVector vector; + + public LargeListVectorAccessor( + LargeListVector vector, + IntSupplier currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer wasNullConsumer) { + super(currentRowSupplier, wasNullConsumer); + this.vector = vector; + } + + @Override + public Object getObject() throws SQLException { + return getListObject(vector::getObject); + } + + @Override + protected long getStartOffset(int index) { + val offsetBuffer = vector.getOffsetBuffer(); + return offsetBuffer.getInt((long) index * LargeListVector.OFFSET_WIDTH); + } + + @Override + protected long getEndOffset(int index) { + val offsetBuffer = vector.getOffsetBuffer(); + return offsetBuffer.getInt((long) (index + 1) * LargeListVector.OFFSET_WIDTH); + } + + @Override + protected FieldVector getDataVector() { + return vector.getDataVector(); + } + + @Override + protected boolean isNull(int index) { + return vector.isNull(index); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/ListVectorAccessor.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/ListVectorAccessor.java new file mode 100644 index 
0000000..e16065d --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/ListVectorAccessor.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorFactory; +import java.sql.SQLException; +import java.util.function.IntSupplier; +import lombok.val; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.complex.BaseRepeatedValueVector; +import org.apache.arrow.vector.complex.ListVector; + +public class ListVectorAccessor extends BaseListVectorAccessor { + + private final ListVector vector; + + public ListVectorAccessor( + ListVector vector, + IntSupplier currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer wasNullConsumer) { + super(currentRowSupplier, wasNullConsumer); + this.vector = vector; + } + + @Override + public Object getObject() throws SQLException { + return getListObject(vector::getObject); + } + + @Override + protected long getStartOffset(int index) { + val offsetBuffer = vector.getOffsetBuffer(); + return offsetBuffer.getInt((long) index * BaseRepeatedValueVector.OFFSET_WIDTH); + } + + @Override + protected long getEndOffset(int index) { + val offsetBuffer = vector.getOffsetBuffer(); + return offsetBuffer.getInt((long) (index + 1) * BaseRepeatedValueVector.OFFSET_WIDTH); + } + + @Override + 
protected FieldVector getDataVector() { + return vector.getDataVector(); + } + + @Override + protected boolean isNull(int index) { + return vector.isNull(index); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/NumericGetter.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/NumericGetter.java new file mode 100644 index 0000000..0ac98e9 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/NumericGetter.java @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.sql.SQLException; +import lombok.experimental.UtilityClass; +import lombok.val; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.holders.NullableBigIntHolder; +import org.apache.arrow.vector.holders.NullableIntHolder; +import org.apache.arrow.vector.holders.NullableSmallIntHolder; +import org.apache.arrow.vector.holders.NullableTinyIntHolder; +import org.apache.arrow.vector.holders.NullableUInt4Holder; + +@UtilityClass +class NumericGetter { + + private static final String INVALID_VECTOR_ERROR_RESPONSE = "Invalid Integer Vector provided"; + + static class NumericHolder { + int isSet; + long value; + } + + @FunctionalInterface + interface Getter { + void get(int index, NumericHolder holder); + } + + static Getter createGetter(BaseIntVector vector) throws SQLException { + if (vector instanceof TinyIntVector) { + return createGetter((TinyIntVector) vector); + } else if (vector instanceof SmallIntVector) { + return createGetter((SmallIntVector) vector); + } else if (vector instanceof IntVector) { + return createGetter((IntVector) vector); + } else if (vector instanceof BigIntVector) { + return createGetter((BigIntVector) vector); + } else if (vector instanceof UInt4Vector) { + return createGetter((UInt4Vector) vector); + } + val rootCauseException = new UnsupportedOperationException(INVALID_VECTOR_ERROR_RESPONSE); + throw new DataCloudJDBCException(INVALID_VECTOR_ERROR_RESPONSE, "2200G", rootCauseException); + } + + private static Getter createGetter(TinyIntVector vector) { + NullableTinyIntHolder nullableTinyIntHolder = new NullableTinyIntHolder(); + return (index, 
holder) -> { + vector.get(index, nullableTinyIntHolder); + + holder.isSet = nullableTinyIntHolder.isSet; + holder.value = nullableTinyIntHolder.value; + }; + } + + private static Getter createGetter(SmallIntVector vector) { + NullableSmallIntHolder nullableSmallIntHolder = new NullableSmallIntHolder(); + return (index, holder) -> { + vector.get(index, nullableSmallIntHolder); + + holder.isSet = nullableSmallIntHolder.isSet; + holder.value = nullableSmallIntHolder.value; + }; + } + + private static Getter createGetter(IntVector vector) { + NullableIntHolder nullableIntHolder = new NullableIntHolder(); + return (index, holder) -> { + vector.get(index, nullableIntHolder); + + holder.isSet = nullableIntHolder.isSet; + holder.value = nullableIntHolder.value; + }; + } + + private static Getter createGetter(BigIntVector vector) { + NullableBigIntHolder nullableBigIntHolder = new NullableBigIntHolder(); + return (index, holder) -> { + vector.get(index, nullableBigIntHolder); + + holder.isSet = nullableBigIntHolder.isSet; + holder.value = nullableBigIntHolder.value; + }; + } + + private static Getter createGetter(UInt4Vector vector) { + NullableUInt4Holder nullableUInt4Holder = new NullableUInt4Holder(); + return (index, holder) -> { + vector.get(index, nullableUInt4Holder); + + holder.isSet = nullableUInt4Holder.isSet; + holder.value = nullableUInt4Holder.value; + }; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeStampVectorAccessor.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeStampVectorAccessor.java new file mode 100644 index 0000000..3c33e1a --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeStampVectorAccessor.java @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import static com.salesforce.datacloud.jdbc.core.accessor.impl.TimeStampVectorGetter.createGetter; +import static com.salesforce.datacloud.jdbc.util.Constants.ISO_DATE_TIME_FORMAT; +import static com.salesforce.datacloud.jdbc.util.Constants.ISO_DATE_TIME_SEC_FORMAT; + +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorFactory; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.sql.Date; +import java.sql.SQLException; +import java.sql.Time; +import java.sql.Timestamp; +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.time.temporal.ChronoUnit; +import java.util.Calendar; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; +import java.util.function.IntSupplier; +import lombok.val; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.util.DateUtility; + +public class TimeStampVectorAccessor extends QueryJDBCAccessor { + + private static final String INVALID_UNIT_ERROR_RESPONSE = "Invalid Arrow time unit"; + + @FunctionalInterface + interface LongToLocalDateTime { + LocalDateTime fromLong(long value); + } + + private final TimeZone timeZone; + private final TimeUnit timeUnit; + private final LongToLocalDateTime longToLocalDateTime; + private final TimeStampVectorGetter.Holder holder; + private final TimeStampVectorGetter.Getter 
getter; + + public TimeStampVectorAccessor( + TimeStampVector vector, + IntSupplier currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer wasNullConsumer) + throws SQLException { + super(currentRowSupplier, wasNullConsumer); + this.timeZone = getTimeZoneForVector(vector); + this.timeUnit = getTimeUnitForVector(vector); + this.longToLocalDateTime = getLongToLocalDateTimeForVector(vector, this.timeZone); + this.holder = new TimeStampVectorGetter.Holder(); + this.getter = createGetter(vector); + } + + @Override + public Date getDate(Calendar calendar) { + LocalDateTime localDateTime = getLocalDateTime(calendar); + if (localDateTime == null) { + return null; + } + + return new Date(Timestamp.valueOf(localDateTime).getTime()); + } + + @Override + public Time getTime(Calendar calendar) { + LocalDateTime localDateTime = getLocalDateTime(calendar); + if (localDateTime == null) { + return null; + } + + return new Time(Timestamp.valueOf(localDateTime).getTime()); + } + + @Override + public Timestamp getTimestamp(Calendar calendar) { + LocalDateTime localDateTime = getLocalDateTime(calendar); + if (localDateTime == null) { + return null; + } + + return Timestamp.valueOf(localDateTime); + } + + @Override + public Class getObjectClass() { + return Timestamp.class; + } + + @Override + public Object getObject() { + return this.getTimestamp(null); + } + + @Override + public String getString() { + LocalDateTime localDateTime = getLocalDateTime(null); + if (localDateTime == null) { + return null; + } + + if (this.timeUnit == TimeUnit.SECONDS) { + return localDateTime.format(DateTimeFormatter.ofPattern(ISO_DATE_TIME_SEC_FORMAT)); + } + + return localDateTime.format(DateTimeFormatter.ofPattern(ISO_DATE_TIME_FORMAT)); + } + + private LocalDateTime getLocalDateTime(Calendar calendar) { + getter.get(getCurrentRow(), holder); + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return null; + } + + long value = holder.value; 
+ LocalDateTime localDateTime = this.longToLocalDateTime.fromLong(value); + + if (calendar != null) { + TimeZone timeZone = calendar.getTimeZone(); + long millis = this.timeUnit.toMillis(value); + localDateTime = localDateTime.minus( + (long) timeZone.getOffset(millis) - (long) this.timeZone.getOffset(millis), ChronoUnit.MILLIS); + } + return localDateTime; + } + + private static LongToLocalDateTime getLongToLocalDateTimeForVector(TimeStampVector vector, TimeZone timeZone) + throws SQLException { + String timeZoneID = timeZone.getID(); + + ArrowType.Timestamp arrowType = + (ArrowType.Timestamp) vector.getField().getFieldType().getType(); + + switch (arrowType.getUnit()) { + case NANOSECOND: + return nanoseconds -> DateUtility.getLocalDateTimeFromEpochNano(nanoseconds, timeZoneID); + case MICROSECOND: + return microseconds -> DateUtility.getLocalDateTimeFromEpochMicro(microseconds, timeZoneID); + case MILLISECOND: + return milliseconds -> DateUtility.getLocalDateTimeFromEpochMilli(milliseconds, timeZoneID); + case SECOND: + return seconds -> + DateUtility.getLocalDateTimeFromEpochMilli(TimeUnit.SECONDS.toMillis(seconds), timeZoneID); + default: + val rootCauseException = new UnsupportedOperationException(INVALID_UNIT_ERROR_RESPONSE); + throw new DataCloudJDBCException(INVALID_UNIT_ERROR_RESPONSE, "22007", rootCauseException); + } + } + + protected static TimeZone getTimeZoneForVector(TimeStampVector vector) { + ArrowType.Timestamp arrowType = + (ArrowType.Timestamp) vector.getField().getFieldType().getType(); + + String timezoneName = arrowType.getTimezone(); + + return timezoneName == null ? 
TimeZone.getTimeZone("UTC") : TimeZone.getTimeZone(timezoneName); + } + + protected static TimeUnit getTimeUnitForVector(TimeStampVector vector) throws SQLException { + ArrowType.Timestamp arrowType = + (ArrowType.Timestamp) vector.getField().getFieldType().getType(); + + switch (arrowType.getUnit()) { + case NANOSECOND: + return TimeUnit.NANOSECONDS; + case MICROSECOND: + return TimeUnit.MICROSECONDS; + case MILLISECOND: + return TimeUnit.MILLISECONDS; + case SECOND: + return TimeUnit.SECONDS; + default: + val rootCauseException = new UnsupportedOperationException(INVALID_UNIT_ERROR_RESPONSE); + throw new DataCloudJDBCException(INVALID_UNIT_ERROR_RESPONSE, "22007", rootCauseException); + } + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeStampVectorGetter.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeStampVectorGetter.java new file mode 100644 index 0000000..af75cbc --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeStampVectorGetter.java @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.sql.SQLException; +import lombok.experimental.UtilityClass; +import lombok.val; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampNanoTZVector; +import org.apache.arrow.vector.TimeStampNanoVector; +import org.apache.arrow.vector.TimeStampSecTZVector; +import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.holders.NullableTimeStampMicroHolder; +import org.apache.arrow.vector.holders.NullableTimeStampMicroTZHolder; +import org.apache.arrow.vector.holders.NullableTimeStampMilliHolder; +import org.apache.arrow.vector.holders.NullableTimeStampMilliTZHolder; +import org.apache.arrow.vector.holders.NullableTimeStampNanoHolder; +import org.apache.arrow.vector.holders.NullableTimeStampNanoTZHolder; +import org.apache.arrow.vector.holders.NullableTimeStampSecHolder; +import org.apache.arrow.vector.holders.NullableTimeStampSecTZHolder; + +@UtilityClass +final class TimeStampVectorGetter { + + private static final String INVALID_VECTOR_ERROR_RESPONSE = "Unsupported Timestamp vector provided"; + + static class Holder { + int isSet; + long value; + } + + @FunctionalInterface + interface Getter { + void get(int index, Holder holder); + } + + static Getter createGetter(TimeStampVector vector) throws SQLException { + if (vector instanceof TimeStampNanoVector) { + return createGetter((TimeStampNanoVector) vector); + } else if (vector instanceof TimeStampNanoTZVector) { + return createGetter((TimeStampNanoTZVector) vector); + } else if (vector instanceof TimeStampMicroVector) { + return createGetter((TimeStampMicroVector) vector); + } else if (vector 
instanceof TimeStampMicroTZVector) { + return createGetter((TimeStampMicroTZVector) vector); + } else if (vector instanceof TimeStampMilliVector) { + return createGetter((TimeStampMilliVector) vector); + } else if (vector instanceof TimeStampMilliTZVector) { + return createGetter((TimeStampMilliTZVector) vector); + } else if (vector instanceof TimeStampSecVector) { + return createGetter((TimeStampSecVector) vector); + } else if (vector instanceof TimeStampSecTZVector) { + return createGetter((TimeStampSecTZVector) vector); + } + + val rootCauseException = new UnsupportedOperationException(INVALID_VECTOR_ERROR_RESPONSE); + throw new DataCloudJDBCException(INVALID_VECTOR_ERROR_RESPONSE, "22007", rootCauseException); + } + + private static Getter createGetter(TimeStampNanoVector vector) { + NullableTimeStampNanoHolder auxHolder = new NullableTimeStampNanoHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampNanoTZVector vector) { + NullableTimeStampNanoTZHolder auxHolder = new NullableTimeStampNanoTZHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampMicroVector vector) { + NullableTimeStampMicroHolder auxHolder = new NullableTimeStampMicroHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampMicroTZVector vector) { + NullableTimeStampMicroTZHolder auxHolder = new NullableTimeStampMicroTZHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampMilliVector vector) { + NullableTimeStampMilliHolder auxHolder = 
new NullableTimeStampMilliHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampMilliTZVector vector) { + NullableTimeStampMilliTZHolder auxHolder = new NullableTimeStampMilliTZHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampSecVector vector) { + NullableTimeStampSecHolder auxHolder = new NullableTimeStampSecHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampSecTZVector vector) { + NullableTimeStampSecTZHolder auxHolder = new NullableTimeStampSecTZHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeVectorAccessor.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeVectorAccessor.java new file mode 100644 index 0000000..442778e --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeVectorAccessor.java @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import static com.salesforce.datacloud.jdbc.core.accessor.impl.TimeVectorGetter.Getter; +import static com.salesforce.datacloud.jdbc.core.accessor.impl.TimeVectorGetter.Holder; +import static com.salesforce.datacloud.jdbc.core.accessor.impl.TimeVectorGetter.createGetter; +import static com.salesforce.datacloud.jdbc.util.DateTimeUtils.getUTCTimeFromMilliseconds; + +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorFactory; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.sql.SQLException; +import java.sql.Time; +import java.sql.Timestamp; +import java.time.format.DateTimeFormatter; +import java.util.Calendar; +import java.util.concurrent.TimeUnit; +import java.util.function.IntSupplier; +import lombok.val; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.ValueVector; + +public class TimeVectorAccessor extends QueryJDBCAccessor { + + private final Getter getter; + private final TimeUnit timeUnit; + private final Holder holder; + + private static final String INVALID_VECTOR_ERROR_RESPONSE = "Unsupported Timestamp vector type provided"; + + public TimeVectorAccessor( + TimeNanoVector vector, + IntSupplier currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer setCursorWasNull) + throws SQLException { + super(currentRowSupplier, setCursorWasNull); + this.holder = new TimeVectorGetter.Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + public TimeVectorAccessor( + TimeMicroVector vector, + IntSupplier currentRowSupplier, + 
QueryJDBCAccessorFactory.WasNullConsumer setCursorWasNull) + throws SQLException { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + public TimeVectorAccessor( + TimeMilliVector vector, + IntSupplier currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer setCursorWasNull) + throws SQLException { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + public TimeVectorAccessor( + TimeSecVector vector, + IntSupplier currentRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer setCursorWasNull) + throws SQLException { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + @Override + public Class getObjectClass() { + return Time.class; + } + + @Override + public Object getObject() { + return this.getTime(null); + } + + /** + * @param calendar Calendar passed in. Ignores the calendar + * @return the Time relative to 00:00:00 assuming timezone is UTC + */ + @Override + public Time getTime(Calendar calendar) { + fillHolder(); + if (this.wasNull) { + return null; + } + + long value = holder.value; + long milliseconds = this.timeUnit.toMillis(value); + + return getUTCTimeFromMilliseconds(milliseconds); + } + + private void fillHolder() { + getter.get(getCurrentRow(), holder); + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + } + + /** + * @param calendar Calendar passed in. 
Ignores the calendar + * @return the Timestamp relative to 00:00:00 assuming timezone is UTC + */ + @Override + public Timestamp getTimestamp(Calendar calendar) { + Time time = getTime(calendar); + if (time == null) { + return null; + } + return new Timestamp(time.getTime()); + } + + @Override + public String getString() { + Time time = getTime(null); + if (time == null) { + return null; + } + + return time.toLocalTime().format(DateTimeFormatter.ISO_TIME); + } + + protected static TimeUnit getTimeUnitForVector(ValueVector vector) throws SQLException { + if (vector instanceof TimeNanoVector) { + return TimeUnit.NANOSECONDS; + } else if (vector instanceof TimeMicroVector) { + return TimeUnit.MICROSECONDS; + } else if (vector instanceof TimeMilliVector) { + return TimeUnit.MILLISECONDS; + } else if (vector instanceof TimeSecVector) { + return TimeUnit.SECONDS; + } + + val rootCauseException = new UnsupportedOperationException(INVALID_VECTOR_ERROR_RESPONSE); + throw new DataCloudJDBCException(INVALID_VECTOR_ERROR_RESPONSE, "22007", rootCauseException); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeVectorGetter.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeVectorGetter.java new file mode 100644 index 0000000..e0da278 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeVectorGetter.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import lombok.experimental.UtilityClass; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.holders.NullableTimeMicroHolder; +import org.apache.arrow.vector.holders.NullableTimeMilliHolder; +import org.apache.arrow.vector.holders.NullableTimeNanoHolder; +import org.apache.arrow.vector.holders.NullableTimeSecHolder; + +@UtilityClass +public class TimeVectorGetter { + + static class Holder { + int isSet; + long value; + } + + @FunctionalInterface + interface Getter { + void get(int index, Holder holder); + } + + static Getter createGetter(TimeNanoVector vector) { + NullableTimeNanoHolder auxHolder = new NullableTimeNanoHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + static Getter createGetter(TimeMicroVector vector) { + NullableTimeMicroHolder auxHolder = new NullableTimeMicroHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + static Getter createGetter(TimeMilliVector vector) { + NullableTimeMilliHolder auxHolder = new NullableTimeMilliHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + static Getter createGetter(TimeSecVector vector) { + NullableTimeSecHolder auxHolder = new NullableTimeSecHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } +} diff --git 
a/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/VarCharVectorAccessor.java b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/VarCharVectorAccessor.java new file mode 100644 index 0000000..4e2d09a --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/accessor/impl/VarCharVectorAccessor.java @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorFactory; +import java.nio.charset.StandardCharsets; +import java.util.function.IntSupplier; +import org.apache.arrow.vector.LargeVarCharVector; +import org.apache.arrow.vector.VarCharVector; + +public class VarCharVectorAccessor extends QueryJDBCAccessor { + + @FunctionalInterface + interface Getter { + byte[] get(int index); + } + + private final Getter getter; + + public VarCharVectorAccessor( + VarCharVector vector, + IntSupplier currenRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer wasNullConsumer) { + this(vector::get, currenRowSupplier, wasNullConsumer); + } + + public VarCharVectorAccessor( + LargeVarCharVector vector, + IntSupplier currenRowSupplier, + QueryJDBCAccessorFactory.WasNullConsumer wasNullConsumer) { + this(vector::get, currenRowSupplier, wasNullConsumer); + } + + VarCharVectorAccessor( + 
Getter getter, IntSupplier currentRowSupplier, QueryJDBCAccessorFactory.WasNullConsumer wasNullConsumer) { + super(currentRowSupplier, wasNullConsumer); + this.getter = getter; + } + + @Override + public Class getObjectClass() { + return String.class; + } + + @Override + public byte[] getBytes() { + final byte[] bytes = this.getter.get(getCurrentRow()); + this.wasNull = bytes == null; + this.wasNullConsumer.setWasNull(this.wasNull); + return this.getter.get(getCurrentRow()); + } + + @Override + public String getString() { + return getObject(); + } + + @Override + public String getObject() { + final byte[] bytes = getBytes(); + return bytes == null ? null : new String(bytes, StandardCharsets.UTF_8); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusListener.java b/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusListener.java new file mode 100644 index 0000000..e705e7b --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusListener.java @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.listener; + +import static com.salesforce.datacloud.jdbc.util.ThrowingSupplier.rethrowLongSupplier; +import static com.salesforce.datacloud.jdbc.util.ThrowingSupplier.rethrowSupplier; + +import com.salesforce.datacloud.jdbc.core.DataCloudResultSet; +import com.salesforce.datacloud.jdbc.core.HyperGrpcClientExecutor; +import com.salesforce.datacloud.jdbc.core.StreamingResultSet; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.exception.QueryExceptionHandler; +import com.salesforce.datacloud.jdbc.util.StreamUtilities; +import com.salesforce.hyperdb.grpc.ExecuteQueryResponse; +import com.salesforce.hyperdb.grpc.QueryResult; +import com.salesforce.hyperdb.grpc.QueryStatus; +import io.grpc.StatusRuntimeException; +import java.sql.SQLException; +import java.time.Duration; +import java.time.Instant; +import java.util.Iterator; +import java.util.Optional; +import java.util.function.Supplier; +import java.util.function.UnaryOperator; +import java.util.stream.LongStream; +import java.util.stream.Stream; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import lombok.val; + +@Slf4j +@AllArgsConstructor(access = AccessLevel.PRIVATE) +public class AdaptiveQueryStatusListener implements QueryStatusListener { + @Getter + private final String queryId; + + @Getter + private final String query; + + private final HyperGrpcClientExecutor client; + + private final Duration timeout; + + private final Iterator response; + + private final AdaptiveQueryStatusPoller headPoller; + + private final AsyncQueryStatusPoller tailPoller; + + public static AdaptiveQueryStatusListener of(String query, HyperGrpcClientExecutor client, Duration timeout) + throws SQLException { + try { + val response = client.executeAdaptiveQuery(query); + val queryId = 
response.next().getQueryInfo().getQueryStatus().getQueryId(); + + return new AdaptiveQueryStatusListener( + queryId, + query, + client, + timeout, + response, + new AdaptiveQueryStatusPoller(queryId, client), + new AsyncQueryStatusPoller(queryId, client)); + } catch (StatusRuntimeException ex) { + throw QueryExceptionHandler.createException("Failed to execute query: " + query, ex); + } + } + + @Override + public boolean isReady() { + return true; + } + + @Override + public String getStatus() { + val poller = headPoller.pollChunkCount() > 1 ? tailPoller : headPoller; + return Optional.of(poller) + .map(QueryStatusPoller::pollQueryStatus) + .map(QueryStatus::getCompletionStatus) + .map(Enum::name) + .orElse(QueryStatus.CompletionStatus.RUNNING.name()); + } + + @Override + public DataCloudResultSet generateResultSet() { + return StreamingResultSet.of(query, this); + } + + @Override + public Stream stream() throws SQLException { + return Stream.>>of(this::head, rethrowSupplier(this::tail)) + .flatMap(Supplier::get); + } + + private Stream head() { + return StreamUtilities.toStream(response) + .map(headPoller::map) + .filter(Optional::isPresent) + .map(Optional::get); + } + + private Stream tail() throws SQLException { + return StreamUtilities.lazyLimitedStream(this::infiniteChunks, rethrowLongSupplier(this::getChunkLimit)) + .flatMap(UnaryOperator.identity()); + } + + private Stream> infiniteChunks() { + return LongStream.iterate(1, n -> n + 1).mapToObj(this::tryGetQueryResult); + } + + private long getChunkLimit() throws SQLException { + if (headPoller.pollChunkCount() > 1) { + blockUntilReady(tailPoller, timeout); + return tailPoller.pollChunkCount() - 1; + } + + return 0; + } + + private Stream tryGetQueryResult(long chunkId) { + return StreamUtilities.tryTimes( + 3, + () -> client.getQueryResult(queryId, chunkId, true), + throwable -> log.warn( + "Error when getting chunk for query. 
queryId={}, chunkId={}",
+                    queryId,
+                    chunkId,
+                    throwable))
+                .map(StreamUtilities::toStream)
+                .orElse(Stream.empty());
+    }
+
+    @SneakyThrows
+    private void blockUntilReady(QueryStatusPoller poller, Duration timeout) {
+        val end = Instant.now().plus(timeout);
+        var millis = 1000;
+        while (!poller.pollIsReady() && Instant.now().isBefore(end)) {
+            log.info(
+                    "Waiting for additional query results. queryId={}, timeout={}, sleep={}",
+                    queryId,
+                    timeout,
+                    Duration.ofMillis(millis));
+
+            Thread.sleep(millis);
+            millis *= 2;
+        }
+
+        if (!poller.pollIsReady()) {
+            throw new DataCloudJDBCException(BEFORE_READY + ". queryId=" + queryId + ", timeout=" + timeout);
+        }
+    }
+}
diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusPoller.java b/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusPoller.java
new file mode 100644
index 0000000..d61bdc6
--- /dev/null
+++ b/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusPoller.java
@@ -0,0 +1,126 @@
+/*
+ * Copyright (c) 2024, Salesforce, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package com.salesforce.datacloud.jdbc.core.listener; + +import com.salesforce.datacloud.jdbc.core.HyperGrpcClientExecutor; +import com.salesforce.datacloud.jdbc.exception.QueryExceptionHandler; +import com.salesforce.datacloud.jdbc.util.StreamUtilities; +import com.salesforce.hyperdb.grpc.ExecuteQueryResponse; +import com.salesforce.hyperdb.grpc.QueryInfo; +import com.salesforce.hyperdb.grpc.QueryResult; +import com.salesforce.hyperdb.grpc.QueryStatus; +import io.grpc.StatusRuntimeException; +import java.util.Iterator; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import lombok.val; + +@Slf4j +@AllArgsConstructor(access = AccessLevel.PACKAGE) +public class AdaptiveQueryStatusPoller implements QueryStatusPoller { + private final AtomicLong chunks = new AtomicLong(1); + private final AtomicReference lastStatus = new AtomicReference<>(); + private final String queryId; + private final HyperGrpcClientExecutor client; + + @SneakyThrows + private Iterator getQueryInfoStreaming() { + try { + return client.getQueryInfoStreaming(queryId); + } catch (StatusRuntimeException ex) { + throw QueryExceptionHandler.createException("Failed when getting query status", ex); + } + } + + public Optional map(ExecuteQueryResponse item) { + getQueryStatus(item).ifPresent(this::handleQueryStatus); + return getQueryResult(item); + } + + private void handleQueryStatus(QueryStatus status) { + lastStatus.set(status); + + if (status.getChunkCount() > 1) { + this.chunks.set(status.getChunkCount()); + } + } + + private Optional getQueryStatus(ExecuteQueryResponse item) { + if (item != null && item.hasQueryInfo()) { + val info = item.getQueryInfo(); + if (info.hasQueryStatus()) { + return Optional.of(info.getQueryStatus()); + } + } + + return Optional.empty(); + } + + private Optional 
getQueryResult(ExecuteQueryResponse item) { + if (item != null && item.hasQueryResult()) { + return Optional.of(item.getQueryResult()); + } + + return Optional.empty(); + } + + @Override + public QueryStatus pollQueryStatus() { + return lastStatus.get(); + } + + @Override + public long pollChunkCount() { + val status = Optional.ofNullable(this.lastStatus.get()); + val finalized = status.map(QueryStatus::getCompletionStatus) + .map(t -> t == QueryStatus.CompletionStatus.FINISHED + || t == QueryStatus.CompletionStatus.RESULTS_PRODUCED) + .orElse(false); + + if (finalized) { + return chunks.get(); + } + + val queryInfos = getQueryInfoStreaming(); + val result = StreamUtilities.toStream(queryInfos) + .map(Optional::ofNullable) + .filter(Optional::isPresent) + .map(Optional::get) + .map(QueryInfo::getQueryStatus) + .filter(AdaptiveQueryStatusPoller::isChunksCompleted) + .findFirst(); + + result.ifPresent(it -> { + val completion = it.getCompletionStatus(); + val chunkCount = it.getChunkCount(); + log.info("Polling chunk count. queryId={}, status={}, count={}", this.queryId, completion, chunkCount); + chunks.set(chunkCount); + }); + + return chunks.get(); + } + + private static boolean isChunksCompleted(QueryStatus s) { + val completion = s.getCompletionStatus(); + return completion == QueryStatus.CompletionStatus.RESULTS_PRODUCED + || completion == QueryStatus.CompletionStatus.FINISHED; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusListener.java b/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusListener.java new file mode 100644 index 0000000..14a3936 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusListener.java @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.listener; + +import com.salesforce.datacloud.jdbc.core.DataCloudResultSet; +import com.salesforce.datacloud.jdbc.core.HyperGrpcClientExecutor; +import com.salesforce.datacloud.jdbc.core.StreamingResultSet; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.exception.QueryExceptionHandler; +import com.salesforce.datacloud.jdbc.util.StreamUtilities; +import com.salesforce.hyperdb.grpc.QueryResult; +import com.salesforce.hyperdb.grpc.QueryStatus; +import io.grpc.StatusRuntimeException; +import java.sql.SQLException; +import java.util.Optional; +import java.util.function.UnaryOperator; +import java.util.stream.LongStream; +import java.util.stream.Stream; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import lombok.val; + +@Slf4j +@Builder(access = AccessLevel.PRIVATE) +public class AsyncQueryStatusListener implements QueryStatusListener { + @Getter + private final String queryId; + + @Getter + private final String query; + + private final HyperGrpcClientExecutor client; + + @Getter(value = AccessLevel.PRIVATE, lazy = true) + private final AsyncQueryStatusPoller poller = new AsyncQueryStatusPoller(queryId, client); + + public static AsyncQueryStatusListener of(String query, HyperGrpcClientExecutor client) throws SQLException { + try { + val result = client.executeAsyncQuery(query).next(); + val id = result.getQueryInfo().getQueryStatus().getQueryId(); + + 
return AsyncQueryStatusListener.builder() + .queryId(id) + .query(query) + .client(client) + .build(); + } catch (StatusRuntimeException ex) { + throw QueryExceptionHandler.createException("Failed to execute query: " + query, ex); + } + } + + @Override + public boolean isReady() { + return getPoller().pollIsReady(); + } + + @Override + public String getStatus() { + return Optional.of(getPoller()) + .map(AsyncQueryStatusPoller::pollQueryStatus) + .map(QueryStatus::getCompletionStatus) + .map(Enum::name) + .orElse(null); + } + + @Override + public DataCloudResultSet generateResultSet() { + return StreamingResultSet.of(query, this); + } + + @Override + public Stream stream() throws SQLException { + return StreamUtilities.lazyLimitedStream(this::infiniteChunks, this::getChunkLimit) + .flatMap(UnaryOperator.identity()); + } + + private Stream> infiniteChunks() { + return LongStream.iterate(0, n -> n + 1).mapToObj(this::tryGetQueryResult); + } + + @SneakyThrows + private long getChunkLimit() { + if (!isReady()) { + throw new DataCloudJDBCException(BEFORE_READY); + } + + return getPoller().pollChunkCount(); + } + + private Stream tryGetQueryResult(long chunkId) { + return StreamUtilities.tryTimes( + 3, + () -> client.getQueryResult(queryId, chunkId, chunkId > 0), + throwable -> log.warn( + "Error when getting chunk for query. queryId={}, chunkId={}", + queryId, + chunkId, + throwable)) + .map(StreamUtilities::toStream) + .orElse(Stream.empty()); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusPoller.java b/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusPoller.java new file mode 100644 index 0000000..57e4737 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusPoller.java @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.listener; + +import static com.salesforce.datacloud.jdbc.core.listener.QueryStatusListener.BEFORE_READY; + +import com.salesforce.datacloud.jdbc.core.HyperGrpcClientExecutor; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.exception.QueryExceptionHandler; +import com.salesforce.hyperdb.grpc.QueryInfo; +import com.salesforce.hyperdb.grpc.QueryStatus; +import io.grpc.StatusRuntimeException; +import java.sql.SQLException; +import java.util.Iterator; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.SneakyThrows; +import lombok.val; + +@AllArgsConstructor(access = AccessLevel.PACKAGE) +class AsyncQueryStatusPoller implements QueryStatusPoller { + private final String queryId; + private final HyperGrpcClientExecutor client; + + private final AtomicReference lastStatus = new AtomicReference<>(); + + @SneakyThrows + private Optional getQueryInfo() { + try { + return Optional.ofNullable(client.getQueryInfo(queryId)).map(Iterator::next); + } catch (StatusRuntimeException ex) { + throw QueryExceptionHandler.createException("Failed when getting query status", ex); + } + } + + private Optional fetchQueryStatus() { + val status = getQueryInfo().map(QueryInfo::getQueryStatus); + 
status.ifPresent(this.lastStatus::set); + return status; + } + + @Override + public QueryStatus pollQueryStatus() { + val status = Optional.ofNullable(this.lastStatus.get()); + val finished = status.map(QueryStatus::getCompletionStatus) + .map(t -> t == QueryStatus.CompletionStatus.FINISHED) + .orElse(false); + + return finished ? status.get() : fetchQueryStatus().orElse(null); + } + + @Override + public long pollChunkCount() throws SQLException { + if (!pollIsReady()) { + throw new DataCloudJDBCException(BEFORE_READY); + } + return lastStatus.get().getChunkCount(); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListener.java b/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListener.java new file mode 100644 index 0000000..6783272 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListener.java @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.listener; + +import com.salesforce.datacloud.jdbc.core.DataCloudResultSet; +import com.salesforce.hyperdb.grpc.QueryResult; +import java.sql.SQLException; +import java.util.stream.Stream; + +public interface QueryStatusListener { + String BEFORE_READY = "Results were requested before ready"; + + String getQuery(); + + boolean isReady(); + + String getStatus(); + + String getQueryId(); + + DataCloudResultSet generateResultSet(); + + Stream stream() throws SQLException; +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusPoller.java b/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusPoller.java new file mode 100644 index 0000000..5ec7fac --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusPoller.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.listener; + +import com.salesforce.hyperdb.grpc.QueryStatus; +import java.sql.SQLException; +import java.util.Optional; + +public interface QueryStatusPoller { + QueryStatus pollQueryStatus(); + + long pollChunkCount() throws SQLException; + + default boolean pollIsReady() { + return Optional.ofNullable(pollQueryStatus()) + .map(QueryStatus::getCompletionStatus) + .map(QueryStatusPoller::isReady) + .orElse(false); + } + + static boolean isReady(QueryStatus.CompletionStatus completionStatus) { + return completionStatus == QueryStatus.CompletionStatus.RESULTS_PRODUCED + || completionStatus == QueryStatus.CompletionStatus.FINISHED; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/listener/SyncQueryStatusListener.java b/src/main/java/com/salesforce/datacloud/jdbc/core/listener/SyncQueryStatusListener.java new file mode 100644 index 0000000..c2be449 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/listener/SyncQueryStatusListener.java @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.listener; + +import com.salesforce.datacloud.jdbc.core.DataCloudResultSet; +import com.salesforce.datacloud.jdbc.core.HyperGrpcClientExecutor; +import com.salesforce.datacloud.jdbc.core.StreamingResultSet; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.exception.QueryExceptionHandler; +import com.salesforce.datacloud.jdbc.util.StreamUtilities; +import com.salesforce.hyperdb.grpc.ExecuteQueryResponse; +import com.salesforce.hyperdb.grpc.QueryInfo; +import com.salesforce.hyperdb.grpc.QueryResult; +import com.salesforce.hyperdb.grpc.QueryStatus; +import io.grpc.StatusRuntimeException; +import java.sql.SQLException; +import java.util.Iterator; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import lombok.val; + +@Slf4j +@Builder(access = AccessLevel.PRIVATE) +public class SyncQueryStatusListener implements QueryStatusListener { + @Getter + private final String queryId; + + @Getter + private final String query; + + private final AtomicReference status = new AtomicReference<>(); + + private final Iterator initial; + + public static SyncQueryStatusListener of(String query, HyperGrpcClientExecutor client) throws SQLException { + val result = client.executeQuery(query); + + try { + val id = getQueryId(result.next(), query); + return SyncQueryStatusListener.builder() + .query(query) + .queryId(id) + .initial(result) + .build(); + } catch (StatusRuntimeException ex) { + throw QueryExceptionHandler.createException("Failed to execute query: " + query, ex); + } + } + + @Override + public boolean isReady() { + return true; + } + + @Override + public String getStatus() { + return Optional.ofNullable(status.get()) + .map(QueryStatus::getCompletionStatus) + 
.map(Enum::name) + .orElse(null); + } + + @Override + public DataCloudResultSet generateResultSet() { + return StreamingResultSet.of(query, this); + } + + @Override + public Stream stream() throws SQLException { + return StreamUtilities.toStream(this.initial) + .peek(this::peekQueryStatus) + .map(SyncQueryStatusListener::extractQueryResult) + .filter(Optional::isPresent) + .map(Optional::get); + } + + private static Optional extractQueryResult(ExecuteQueryResponse response) { + return Optional.ofNullable(response).map(ExecuteQueryResponse::getQueryResult); + } + + private void peekQueryStatus(ExecuteQueryResponse response) { + Optional.ofNullable(response) + .map(ExecuteQueryResponse::getQueryInfo) + .map(QueryInfo::getQueryStatus) + .ifPresent(status::set); + } + + @SneakyThrows + private static String getQueryId(ExecuteQueryResponse response, String query) { + val rootErrorMessage = "The server did not supply an ID for the query: " + query; + return Optional.ofNullable(response) + .map(ExecuteQueryResponse::getQueryInfo) + .map(QueryInfo::getQueryStatus) + .map(QueryStatus::getQueryId) + .orElseThrow(() -> + new DataCloudJDBCException(rootErrorMessage, new IllegalStateException(rootErrorMessage))); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/model/DataspaceResponse.java b/src/main/java/com/salesforce/datacloud/jdbc/core/model/DataspaceResponse.java new file mode 100644 index 0000000..87af013 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/model/DataspaceResponse.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.model; + +import com.fasterxml.jackson.annotation.JsonAlias; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import java.util.List; +import java.util.Map; +import lombok.Data; + +@Data +@JsonIgnoreProperties(ignoreUnknown = true) +public class DataspaceResponse { + List records; + Integer totalSize; + Boolean done; + + @Data + public static class DataSpaceAttributes { + Map attributes; + + @JsonAlias({"Name"}) + String name; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/core/model/ParameterBinding.java b/src/main/java/com/salesforce/datacloud/jdbc/core/model/ParameterBinding.java new file mode 100644 index 0000000..ddd9199 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/core/model/ParameterBinding.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.model; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +@AllArgsConstructor +@Getter +public class ParameterBinding { + private final int sqlType; + private final Object value; +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/exception/DataCloudJDBCException.java b/src/main/java/com/salesforce/datacloud/jdbc/exception/DataCloudJDBCException.java new file mode 100644 index 0000000..055dfb0 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/exception/DataCloudJDBCException.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.exception; + +import java.sql.SQLException; +import lombok.Getter; + +@Getter +public class DataCloudJDBCException extends SQLException { + private String customerHint; + + private String customerDetail; + + public DataCloudJDBCException() { + super(); + } + + public DataCloudJDBCException(String reason) { + super(reason); + } + + public DataCloudJDBCException(String reason, String SQLState) { + super(reason, SQLState); + } + + public DataCloudJDBCException(String reason, String SQLState, int vendorCode) { + super(reason, SQLState, vendorCode); + } + + public DataCloudJDBCException(Throwable cause) { + super(cause); + } + + public DataCloudJDBCException(String reason, Throwable cause) { + super(reason, cause); + } + + public DataCloudJDBCException(String reason, String SQLState, Throwable cause) { + super(reason, SQLState, cause); + } + + public DataCloudJDBCException(String reason, String SQLState, int vendorCode, Throwable cause) { + super(reason, SQLState, vendorCode, cause); + } + + public DataCloudJDBCException( + String reason, String SQLState, String customerHint, String customerDetail, Throwable cause) { + super(reason, SQLState, 0, cause); + + this.customerHint = customerHint; + this.customerDetail = customerDetail; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/exception/QueryExceptionHandler.java b/src/main/java/com/salesforce/datacloud/jdbc/exception/QueryExceptionHandler.java new file mode 100644 index 0000000..12ce86f --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/exception/QueryExceptionHandler.java @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.exception; + +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.salesforce.hyperdb.grpc.ErrorInfo; +import io.grpc.StatusRuntimeException; +import io.grpc.protobuf.StatusProto; +import java.sql.SQLException; +import java.util.List; +import lombok.experimental.UtilityClass; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +@UtilityClass +public class QueryExceptionHandler { + + public static DataCloudJDBCException createException(String message, Exception e) { + if (e instanceof StatusRuntimeException) { + StatusRuntimeException ex = (StatusRuntimeException) e; + com.google.rpc.Status status = StatusProto.fromThrowable(ex); + + if (status != null) { + List detailsList = status.getDetailsList(); + Any firstError = detailsList.stream() + .filter(any -> any.is(ErrorInfo.class)) + .findFirst() + .orElse(null); + if (firstError != null) { + ErrorInfo errorInfo; + try { + errorInfo = firstError.unpack(ErrorInfo.class); + } catch (InvalidProtocolBufferException exc) { + return new DataCloudJDBCException("Invalid error info", e); + } + + String sqlState = errorInfo.getSqlstate(); + String customerHint = errorInfo.getCustomerHint(); + String customerDetail = errorInfo.getCustomerDetail(); + String primaryMessage = String.format( + "%s: %s%nDETAIL:%n%s%nHINT:%n%s", + sqlState, errorInfo.getPrimaryMessage(), customerDetail, customerHint); + return new DataCloudJDBCException(primaryMessage, sqlState, customerHint, customerDetail, ex); + } + } + } + return new 
DataCloudJDBCException(message, e); + } + + public static SQLException createException(String message, String sqlState, Exception e) { + return new SQLException(message, sqlState, e.getCause()); + } + + public static SQLException createException(String message, String sqlState) { + return new SQLException(message, sqlState); + } + + public static SQLException createException(String message) { + return new DataCloudJDBCException(message); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/http/ClientBuilder.java b/src/main/java/com/salesforce/datacloud/jdbc/http/ClientBuilder.java new file mode 100644 index 0000000..ac75dd3 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/http/ClientBuilder.java @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.http; + +import static com.salesforce.datacloud.jdbc.util.PropertiesExtensions.getIntegerOrDefault; +import static com.salesforce.datacloud.jdbc.util.PropertiesExtensions.optional; + +import com.salesforce.datacloud.jdbc.util.internal.SFDefaultSocketFactoryWrapper; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import lombok.experimental.UtilityClass; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import okhttp3.OkHttpClient; + +@Slf4j +@UtilityClass +public class ClientBuilder { + + static final String READ_TIME_OUT_SECONDS_KEY = "readTimeOutSeconds"; + static final int DEFAULT_READ_TIME_OUT_SECONDS = 600; + + static final String CONNECT_TIME_OUT_SECONDS_KEY = "connectTimeOutSeconds"; + static final int DEFAULT_CONNECT_TIME_OUT_SECONDS = 600; + + static final String CALL_TIME_OUT_SECONDS_KEY = "callTimeOutSeconds"; + static final int DEFAULT_CALL_TIME_OUT_SECONDS = 600; + + static final String DISABLE_SOCKS_PROXY_KEY = "disableSocksProxy"; + static final Boolean DISABLE_SOCKS_PROXY_DEFAULT = false; + + public static OkHttpClient buildOkHttpClient(Properties properties) { + val disableSocksProxy = optional(properties, DISABLE_SOCKS_PROXY_KEY) + .map(Boolean::valueOf) + .orElse(DISABLE_SOCKS_PROXY_DEFAULT); + + val readTimeout = getIntegerOrDefault(properties, READ_TIME_OUT_SECONDS_KEY, DEFAULT_READ_TIME_OUT_SECONDS); + val connectTimeout = + getIntegerOrDefault(properties, CONNECT_TIME_OUT_SECONDS_KEY, DEFAULT_CONNECT_TIME_OUT_SECONDS); + val callTimeout = getIntegerOrDefault(properties, CALL_TIME_OUT_SECONDS_KEY, DEFAULT_CALL_TIME_OUT_SECONDS); + + return new OkHttpClient.Builder() + .socketFactory(new SFDefaultSocketFactoryWrapper(disableSocksProxy)) + .callTimeout(callTimeout, TimeUnit.SECONDS) + .connectTimeout(connectTimeout, TimeUnit.SECONDS) + .readTimeout(readTimeout, TimeUnit.SECONDS) + .addInterceptor(new MetadataCacheInterceptor()) + .build(); + } +} diff --git 
a/src/main/java/com/salesforce/datacloud/jdbc/http/FormCommand.java b/src/main/java/com/salesforce/datacloud/jdbc/http/FormCommand.java new file mode 100644 index 0000000..9558888 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/http/FormCommand.java @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.http; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.Constants; +import java.io.IOException; +import java.net.URI; +import java.sql.SQLException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import lombok.Builder; +import lombok.NonNull; +import lombok.Singular; +import lombok.Value; +import lombok.val; +import okhttp3.FormBody; +import okhttp3.Headers; +import okhttp3.HttpUrl; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import org.apache.commons.lang3.StringUtils; + +@Value +@Builder(builderClassName = "Builder") +public class FormCommand { + private static final String ACCEPT_HEADER_NAME = "Accept"; + public static final String CONTENT_TYPE_HEADER_NAME = "Content-Type"; + private static final String URL_ENCODED_CONTENT = "application/x-www-form-urlencoded"; + private static ObjectMapper mapper = new ObjectMapper(); + + @NonNull URI url; + + @NonNull 
URI suffix; + + @Singular + Map headers; + + @Singular + Map bodyEntries; + + @Singular + Map queryParameters; + + public static T get(@NonNull OkHttpClient client, @NonNull FormCommand command, Class type) + throws SQLException { + val url = getUrl(command); + val headers = asHeaders(command); + val request = new Request.Builder().url(url).headers(headers).get().build(); + + return executeRequest(client, request, type); + } + + public static T post(@NonNull OkHttpClient client, @NonNull FormCommand command, Class type) + throws SQLException { + val url = getUrl(command); + val headers = asHeaders(command); + val payload = asFormBody(command); + val request = + new Request.Builder().url(url).headers(headers).post(payload).build(); + + return executeRequest(client, request, type); + } + + private static String getUrl(FormCommand command) { + HttpUrl.Builder builder = Objects.requireNonNull( + HttpUrl.parse(command.getUrl().toString())) + .newBuilder(); + builder.addPathSegments(command.suffix.toString()); + command.queryParameters.forEach(builder::addEncodedQueryParameter); + return builder.build().toString(); + } + + private static T executeRequest(@NonNull OkHttpClient client, Request request, Class type) + throws SQLException { + try (Response response = client.newCall(request).execute()) { + if (!response.isSuccessful()) throw new IOException("Unexpected code " + response); + val body = response.body(); + if (body == null || StringUtils.isEmpty(body.toString())) { + throw new IOException("Response Body was null " + response); + } + val json = body.string(); + return mapper.readValue(json, type); + } catch (IOException e) { + throw new DataCloudJDBCException(e); + } + } + + private static FormBody asFormBody(FormCommand command) { + val body = new FormBody.Builder(); + command.getBodyEntries().forEach(body::add); + return body.build(); + } + + private static Headers asHeaders(FormCommand command) { + val headers = new HashMap<>(command.getHeaders()); + + 
headers.putIfAbsent(ACCEPT_HEADER_NAME, Constants.CONTENT_TYPE_JSON); + headers.putIfAbsent(CONTENT_TYPE_HEADER_NAME, URL_ENCODED_CONTENT); + + return Headers.of(headers); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/http/MetadataCacheInterceptor.java b/src/main/java/com/salesforce/datacloud/jdbc/http/MetadataCacheInterceptor.java new file mode 100644 index 0000000..415e78a --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/http/MetadataCacheInterceptor.java @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.http; + +import com.salesforce.datacloud.jdbc.util.Constants; +import com.salesforce.datacloud.jdbc.util.MetadataCacheUtil; +import java.io.IOException; +import java.net.HttpURLConnection; +import lombok.extern.slf4j.Slf4j; +import okhttp3.*; +import org.jetbrains.annotations.NotNull; + +@Slf4j +public class MetadataCacheInterceptor implements Interceptor { + + @NotNull @Override + public Response intercept(@NotNull Chain chain) throws IOException { + Request request = chain.request(); + Response response; + String responseString = MetadataCacheUtil.getMetadata(request.url().toString()); + if (responseString != null) { + log.trace("Getting the metadata response from local cache"); + response = new Response.Builder() + .code(HttpURLConnection.HTTP_OK) + .request(request) + .protocol(Protocol.HTTP_1_1) + .message("OK") + .addHeader("from-local-cache", "true") + .body(ResponseBody.create(responseString, MediaType.parse(Constants.CONTENT_TYPE_JSON))) + .build(); + } else { + log.trace("Cache miss for metadata response. Getting from server"); + response = chain.proceed(request); + } + return response; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/interceptor/AuthorizationHeaderInterceptor.java b/src/main/java/com/salesforce/datacloud/jdbc/interceptor/AuthorizationHeaderInterceptor.java new file mode 100644 index 0000000..d46c9a6 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/interceptor/AuthorizationHeaderInterceptor.java @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.interceptor; + +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; + +import com.salesforce.datacloud.jdbc.auth.TokenProcessor; +import io.grpc.Metadata; +import java.sql.SQLException; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.SneakyThrows; +import lombok.ToString; +import lombok.extern.slf4j.Slf4j; +import lombok.val; + +@Slf4j +@ToString +@AllArgsConstructor(access = AccessLevel.PRIVATE) +public class AuthorizationHeaderInterceptor implements HeaderMutatingClientInterceptor { + + public interface TokenSupplier { + String getToken() throws SQLException; + + default String getAudience() { + return null; + } + } + + public static AuthorizationHeaderInterceptor of(TokenProcessor tokenProcessor) { + val supplier = new TokenProcessorSupplier(tokenProcessor); + return new AuthorizationHeaderInterceptor(supplier, "oauth"); + } + + public static AuthorizationHeaderInterceptor of(TokenSupplier supplier) { + return new AuthorizationHeaderInterceptor(supplier, "custom"); + } + + private static final String AUTH = "Authorization"; + private static final String AUD = "audience"; + + private static final Metadata.Key AUTH_KEY = Metadata.Key.of(AUTH, ASCII_STRING_MARSHALLER); + private static final Metadata.Key AUD_KEY = Metadata.Key.of(AUD, ASCII_STRING_MARSHALLER); + + @ToString.Exclude + private final TokenSupplier tokenSupplier; + + private final String name; + + @SneakyThrows + @Override + public void mutate(final Metadata headers) { + val token = tokenSupplier.getToken(); 
+ headers.put(AUTH_KEY, token); + + val audience = tokenSupplier.getAudience(); + if (audience != null) { + headers.put(AUD_KEY, audience); + } + } + + @AllArgsConstructor + static class TokenProcessorSupplier implements TokenSupplier { + private final TokenProcessor tokenProcessor; + + @SneakyThrows + @Override + public String getToken() { + val token = tokenProcessor.getDataCloudToken(); + return token.getAccessToken(); + } + + @SneakyThrows + @Override + public String getAudience() { + val token = tokenProcessor.getDataCloudToken(); + return token.getTenantId(); + } + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/interceptor/DataspaceHeaderInterceptor.java b/src/main/java/com/salesforce/datacloud/jdbc/interceptor/DataspaceHeaderInterceptor.java new file mode 100644 index 0000000..953a8fd --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/interceptor/DataspaceHeaderInterceptor.java @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.interceptor; + +import static com.salesforce.datacloud.jdbc.util.PropertiesExtensions.optional; +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; + +import io.grpc.Metadata; +import java.util.Properties; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +@AllArgsConstructor(access = AccessLevel.PRIVATE) +public class DataspaceHeaderInterceptor implements HeaderMutatingClientInterceptor { + public static DataspaceHeaderInterceptor of(Properties properties) { + return optional(properties, DATASPACE) + .map(DataspaceHeaderInterceptor::new) + .orElse(null); + } + + @NonNull private final String dataspace; + + static final String DATASPACE = "dataspace"; + + private static final Metadata.Key DATASPACE_KEY = Metadata.Key.of(DATASPACE, ASCII_STRING_MARSHALLER); + + @Override + public void mutate(final Metadata headers) { + headers.put(DATASPACE_KEY, dataspace); + } + + @Override + public String toString() { + return ("DataspaceHeaderInterceptor(dataspace=" + dataspace + ")"); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/interceptor/HeaderMutatingClientInterceptor.java b/src/main/java/com/salesforce/datacloud/jdbc/interceptor/HeaderMutatingClientInterceptor.java new file mode 100644 index 0000000..c885a84 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/interceptor/HeaderMutatingClientInterceptor.java @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.interceptor; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ForwardingClientCall; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import lombok.SneakyThrows; + +public interface HeaderMutatingClientInterceptor extends ClientInterceptor { + void mutate(final Metadata headers); + + @Override + default ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + return new ForwardingClientCall.SimpleForwardingClientCall<>(next.newCall(method, callOptions)) { + @SneakyThrows + @Override + public void start(final Listener responseListener, final Metadata headers) { + try { + mutate(headers); + } catch (Exception ex) { + throw new DataCloudJDBCException( + "Caught exception when mutating headers in client interceptor", ex); + } + + super.start(responseListener, headers); + } + }; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/interceptor/HyperDefaultsHeaderInterceptor.java b/src/main/java/com/salesforce/datacloud/jdbc/interceptor/HyperDefaultsHeaderInterceptor.java new file mode 100644 index 0000000..9cc28ea --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/interceptor/HyperDefaultsHeaderInterceptor.java @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.interceptor; + +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; + +import io.grpc.Metadata; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.ToString; + +@Getter +@ToString +@NoArgsConstructor +public class HyperDefaultsHeaderInterceptor implements HeaderMutatingClientInterceptor { + private static final String GRPC_MAX_METADATA_SIZE = String.valueOf(1024 * 1024); // 1mb + private static final String WORKLOAD_VALUE = "jdbcv3"; + private static final String WORKLOAD_KEY_STR = "x-hyperdb-workload"; + private static final String MAX_METADATA_SIZE = "grpc.max_metadata_size"; + + private static final Metadata.Key WORKLOAD_KEY = Metadata.Key.of(WORKLOAD_KEY_STR, ASCII_STRING_MARSHALLER); + private static final Metadata.Key SIZE_KEY = Metadata.Key.of(MAX_METADATA_SIZE, ASCII_STRING_MARSHALLER); + + @Override + public void mutate(final Metadata headers) { + headers.put(WORKLOAD_KEY, WORKLOAD_VALUE); + headers.put(SIZE_KEY, GRPC_MAX_METADATA_SIZE); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/interceptor/QueryIdHeaderInterceptor.java b/src/main/java/com/salesforce/datacloud/jdbc/interceptor/QueryIdHeaderInterceptor.java new file mode 100644 index 0000000..56d2578 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/interceptor/QueryIdHeaderInterceptor.java @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024, Salesforce, 
Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.interceptor; + +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; + +import io.grpc.Metadata; +import lombok.RequiredArgsConstructor; +import lombok.ToString; + +@ToString +@RequiredArgsConstructor +public class QueryIdHeaderInterceptor implements HeaderMutatingClientInterceptor { + static final String HYPER_QUERY_ID = "x-hyperdb-query-id"; + + public static final Metadata.Key HYPER_QUERY_ID_KEY = + Metadata.Key.of(HYPER_QUERY_ID, ASCII_STRING_MARSHALLER); + + private final String queryId; + + @Override + public void mutate(final Metadata headers) { + headers.put(HYPER_QUERY_ID_KEY, queryId); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/interceptor/TracingHeadersInterceptor.java b/src/main/java/com/salesforce/datacloud/jdbc/interceptor/TracingHeadersInterceptor.java new file mode 100644 index 0000000..2b1569d --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/interceptor/TracingHeadersInterceptor.java @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
package com.salesforce.datacloud.jdbc.interceptor;

import static io.grpc.Metadata.ASCII_STRING_MARSHALLER;

import com.salesforce.datacloud.jdbc.internal.Tracer;
import io.grpc.Metadata;
import java.util.function.Supplier;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import lombok.val;

/**
 * Adds B3 tracing headers to outgoing gRPC calls: a trace id fixed for the
 * lifetime of this interceptor and a freshly generated span id per call.
 */
@Slf4j
@ToString
@Builder(access = AccessLevel.PRIVATE)
public class TracingHeadersInterceptor implements HeaderMutatingClientInterceptor {
    private static final String TRACE_ID = "x-b3-traceid";
    private static final String SPAN_ID = "x-b3-spanid";

    private static final Metadata.Key<String> TRACE_ID_KEY = Metadata.Key.of(TRACE_ID, ASCII_STRING_MARSHALLER);
    private static final Metadata.Key<String> SPAN_ID_KEY = Metadata.Key.of(SPAN_ID, ASCII_STRING_MARSHALLER);

    // Suppliers rather than values: the trace id is captured once, while the
    // span id supplier yields a new id on every invocation.
    private final Supplier<String> getTraceId;
    private final Supplier<String> getSpanId;

    /**
     * Creates an interceptor bound to a newly generated trace id; span ids
     * are generated lazily, one per intercepted call.
     */
    public static TracingHeadersInterceptor of() {
        val tracer = Tracer.get();
        val traceId = tracer.nextId();
        log.info("new tracing interceptor created. traceId={}", traceId);
        return TracingHeadersInterceptor.builder()
                .getTraceId(() -> traceId)
                .getSpanId(tracer::nextSpanId)
                .build();
    }

    /** Writes the trace and span id headers for this call. */
    @Override
    public void mutate(Metadata headers) {
        headers.put(TRACE_ID_KEY, getTraceId.get());
        headers.put(SPAN_ID_KEY, getSpanId.get());
    }
}
package com.salesforce.datacloud.jdbc.interceptor;

import static com.salesforce.datacloud.jdbc.util.PropertiesExtensions.optional;
import static io.grpc.Metadata.ASCII_STRING_MARSHALLER;

import com.salesforce.datacloud.jdbc.config.DriverVersion;
import io.grpc.Metadata;
import java.util.Properties;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import lombok.val;

/**
 * Sets the User-Agent header, combining any client-provided value with the
 * driver's own identification string.
 */
@Slf4j
@ToString
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class UserAgentHeaderInterceptor implements HeaderMutatingClientInterceptor {
    private static final String USER_AGENT = "User-Agent";

    private static final Metadata.Key<String> USER_AGENT_KEY = Metadata.Key.of(USER_AGENT, ASCII_STRING_MARSHALLER);

    private final String userAgent;

    /**
     * Builds an interceptor from connection properties, consulting any
     * client-supplied {@code User-Agent} entry.
     */
    public static UserAgentHeaderInterceptor of(Properties properties) {
        val provided = optional(properties, USER_AGENT);
        return new UserAgentHeaderInterceptor(getCombinedUserAgent(provided.orElse(null)));
    }

    /** Writes the combined user agent to the outgoing headers. */
    @Override
    public void mutate(final Metadata headers) {
        headers.put(USER_AGENT_KEY, userAgent);
    }

    // Driver info alone when the client gave nothing (or gave exactly the
    // driver info); otherwise "<client value> <driver info>".
    private static String getCombinedUserAgent(String clientProvidedUserAgent) {
        val driverInfo = DriverVersion.formatDriverInfo();
        if (clientProvidedUserAgent == null
                || clientProvidedUserAgent.isEmpty()
                || clientProvidedUserAgent.equals(driverInfo)) {
            return driverInfo;
        }
        return String.format("%s %s", clientProvidedUserAgent, driverInfo);
    }
}
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.internal; + +import javax.annotation.concurrent.Immutable; +import lombok.experimental.UtilityClass; + +@Immutable +@UtilityClass +public final class EncodingUtils { + static final int BYTE_BASE16 = 2; + private static final String ALPHABET = "0123456789abcdef"; + private static final char[] ENCODING = buildEncodingArray(); + private static final boolean[] VALID_HEX = buildValidHexArray(); + + private static char[] buildEncodingArray() { + char[] encoding = new char[512]; + for (int i = 0; i < 256; ++i) { + encoding[i] = ALPHABET.charAt(i >>> 4); + encoding[i | 0x100] = ALPHABET.charAt(i & 0xF); + } + return encoding; + } + + private static boolean[] buildValidHexArray() { + boolean[] validHex = new boolean[Character.MAX_VALUE]; + for (int i = 0; i < Character.MAX_VALUE; i++) { + validHex[i] = (48 <= i && i <= 57) || (97 <= i && i <= 102); + } + return validHex; + } + + /** + * Appends the base16 encoding of the specified {@code value} to the {@code dest}. + * + * @param value the value to be converted. + * @param dest the destination char array. + * @param destOffset the starting offset in the destination char array. 
+ */ + public static void longToBase16String(long value, char[] dest, int destOffset) { + byteToBase16((byte) (value >> 56 & 0xFFL), dest, destOffset); + byteToBase16((byte) (value >> 48 & 0xFFL), dest, destOffset + BYTE_BASE16); + byteToBase16((byte) (value >> 40 & 0xFFL), dest, destOffset + 2 * BYTE_BASE16); + byteToBase16((byte) (value >> 32 & 0xFFL), dest, destOffset + 3 * BYTE_BASE16); + byteToBase16((byte) (value >> 24 & 0xFFL), dest, destOffset + 4 * BYTE_BASE16); + byteToBase16((byte) (value >> 16 & 0xFFL), dest, destOffset + 5 * BYTE_BASE16); + byteToBase16((byte) (value >> 8 & 0xFFL), dest, destOffset + 6 * BYTE_BASE16); + byteToBase16((byte) (value & 0xFFL), dest, destOffset + 7 * BYTE_BASE16); + } + + /** + * Encodes the specified byte, and returns the encoded {@code String}. + * + * @param value the value to be converted. + * @param dest the destination char array. + * @param destOffset the starting offset in the destination char array. + */ + public static void byteToBase16(byte value, char[] dest, int destOffset) { + int b = value & 0xFF; + dest[destOffset] = ENCODING[b]; + dest[destOffset + 1] = ENCODING[b | 0x100]; + } + + /** Returns whether the {@link CharSequence} is a valid hex string. */ + public static boolean isValidBase16String(CharSequence value) { + int len = value.length(); + for (int i = 0; i < len; i++) { + char b = value.charAt(i); + if (!isValidBase16Character(b)) { + return false; + } + } + return true; + } + + /** Returns whether the given {@code char} is a valid hex character. */ + public static boolean isValidBase16Character(char b) { + return VALID_HEX[b]; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/internal/TemporaryBuffers.java b/src/main/java/com/salesforce/datacloud/jdbc/internal/TemporaryBuffers.java new file mode 100644 index 0000000..c2fc261 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/internal/TemporaryBuffers.java @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.internal; + +import lombok.experimental.UtilityClass; + +@UtilityClass +public final class TemporaryBuffers { + + private static final ThreadLocal CHAR_ARRAY = new ThreadLocal<>(); + + /** + * A {@link ThreadLocal} {@code char[]} of size {@code len}. Take care when using a large value of {@code len} as + * this buffer will remain for the lifetime of the thread. The returned buffer will not be zeroed and may be larger + * than the requested size, you must make sure to fill the entire content to the desired value and set the length + * explicitly when converting to a {@link String}. + */ + public static char[] chars(int len) { + char[] buffer = CHAR_ARRAY.get(); + if (buffer == null || buffer.length < len) { + buffer = new char[len]; + CHAR_ARRAY.set(buffer); + } + return buffer; + } + + public static void clearChars() { + CHAR_ARRAY.remove(); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/internal/Tracer.java b/src/main/java/com/salesforce/datacloud/jdbc/internal/Tracer.java new file mode 100644 index 0000000..6940866 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/internal/Tracer.java @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.internal; + +public class Tracer { + private static final int TRACE_ID_BYTES_LENGTH = 16; + private static final int TRACE_ID_HEX_LENGTH = 2 * TRACE_ID_BYTES_LENGTH; + private static final int SPAN_ID_BYTES_LENGTH = 8; + private static final int SPAN_ID_HEX_LENGTH = 2 * SPAN_ID_BYTES_LENGTH; + private static final long INVALID_ID = 0; + + private static final String INVALID = "0000000000000000"; + + private static volatile Tracer instance; + + public static synchronized Tracer get() { + if (instance == null) { + synchronized (Tracer.class) { + if (instance == null) { + instance = new Tracer(); + } + } + } + return instance; + } + + public String nextSpanId() { + long id; + do { + id = randomLong(); + } while (id == INVALID_ID); + return fromLong(id); + } + + public boolean isValidSpanId(CharSequence spanId) { + return spanId != null + && spanId.length() == SPAN_ID_HEX_LENGTH + && !INVALID.contentEquals(spanId) + && EncodingUtils.isValidBase16String(spanId); + } + + private String fromLong(long id) { + if (id == 0) { + return INVALID; + } + char[] result = TemporaryBuffers.chars(SPAN_ID_HEX_LENGTH); + EncodingUtils.longToBase16String(id, result, 0); + return new String(result, 0, SPAN_ID_HEX_LENGTH); + } + + private long randomLong() { + return java.util.concurrent.ThreadLocalRandom.current().nextLong(); + } + + public String nextId() { + long idHi = randomLong(); + long idLo; + do { + idLo = randomLong(); + } while (idLo == 0); + return fromLongs(idHi, idLo); + } + + public boolean 
isValidTraceId(CharSequence traceId) { + return traceId != null + && traceId.length() == TRACE_ID_HEX_LENGTH + && !INVALID.contentEquals(traceId) + && EncodingUtils.isValidBase16String(traceId); + } + + private String fromLongs(long traceIdLongHighPart, long traceIdLongLowPart) { + if (traceIdLongHighPart == 0 && traceIdLongLowPart == 0) { + return INVALID; + } + char[] chars = TemporaryBuffers.chars(TRACE_ID_HEX_LENGTH); + EncodingUtils.longToBase16String(traceIdLongHighPart, chars, 0); + EncodingUtils.longToBase16String(traceIdLongLowPart, chars, 16); + return new String(chars, 0, TRACE_ID_HEX_LENGTH); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/util/ArrowUtils.java b/src/main/java/com/salesforce/datacloud/jdbc/util/ArrowUtils.java new file mode 100644 index 0000000..4567af7 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/util/ArrowUtils.java @@ -0,0 +1,314 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package com.salesforce.datacloud.jdbc.util;

import com.salesforce.datacloud.jdbc.core.model.ParameterBinding;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.math.BigDecimal;
import java.sql.JDBCType;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Types;
import java.util.Calendar;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import lombok.experimental.UtilityClass;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.ColumnMetaData;
import org.apache.calcite.avatica.SqlType;
import org.apache.calcite.avatica.proto.Common;

/** Conversions between Arrow, JDBC, and Avatica type/metadata representations. */
@UtilityClass
@Slf4j
public class ArrowUtils {

    /**
     * Converts Arrow fields into Avatica column metadata, assigning ordinals
     * in list order.
     */
    public static List<ColumnMetaData> toColumnMetaData(List<Field> fields) {
        AtomicInteger index = new AtomicInteger();
        return fields.stream()
                .map(field -> {
                    try {
                        return fieldToColumnMetaData(field, index.getAndIncrement());
                    } catch (SQLException e) {
                        throw new RuntimeException(e);
                    }
                })
                .collect(Collectors.toList());
    }

    private static ColumnMetaData fieldToColumnMetaData(Field field, int index) throws SQLException {
        final Common.ColumnMetaData.Builder builder = Common.ColumnMetaData.newBuilder()
                .setOrdinal(index)
                .setColumnName(field.getName())
                .setLabel(field.getName())
                .setType(getAvaticaType(field.getType()).toProto());
        return ColumnMetaData.fromProto(builder.build());
    }

    /**
     * Converts from JDBC metadata to Avatica columns.
     *
     * @param metaData the JDBC result-set metadata; may be null
     * @param maxSize number of leading columns to convert
     * @return the converted columns; empty when {@code metaData} is null or {@code maxSize <= 0}
     */
    public static List<ColumnMetaData> convertJDBCMetadataToAvaticaColumns(ResultSetMetaData metaData, int maxSize) {
        if (metaData == null) {
            return Collections.emptyList();
        }

        // rangeClosed replaces Stream.iterate(...).limit(maxSize): it is bounded by
        // construction and yields an empty stream (instead of limit() throwing
        // IllegalArgumentException) when maxSize <= 0.
        return IntStream.rangeClosed(1, maxSize)
                .mapToObj(i -> {
                    try {
                        val avaticaType = getAvaticaType(metaData.getColumnType(i), metaData.getColumnTypeName(i));
                        return new ColumnMetaData(
                                i - 1,
                                metaData.isAutoIncrement(i),
                                metaData.isCaseSensitive(i),
                                metaData.isSearchable(i),
                                metaData.isCurrency(i),
                                metaData.isNullable(i),
                                metaData.isSigned(i),
                                metaData.getColumnDisplaySize(i),
                                metaData.getColumnLabel(i),
                                metaData.getColumnName(i),
                                metaData.getSchemaName(i),
                                metaData.getPrecision(i),
                                metaData.getScale(i),
                                metaData.getTableName(i),
                                metaData.getCatalogName(i),
                                avaticaType,
                                metaData.isReadOnly(i),
                                metaData.isWritable(i),
                                metaData.isDefinitelyWritable(i),
                                metaData.getColumnClassName(i));
                    } catch (SQLException e) {
                        log.error("Error converting JDBC Metadata to Avatica Columns");
                        throw new RuntimeException(e);
                    }
                })
                .collect(Collectors.toList());
    }

    // java.sql.Types code -> Arrow FieldType factory; DECIMAL needs the binding's value.
    private static final Map<Integer, Function<ParameterBinding, FieldType>> SQL_TYPE_TO_FIELD_TYPE = Map.ofEntries(
            Map.entry(Types.VARCHAR, pb -> FieldType.nullable(new ArrowType.Utf8())),
            Map.entry(Types.INTEGER, pb -> FieldType.nullable(new ArrowType.Int(32, true))),
            Map.entry(Types.BIGINT, pb -> FieldType.nullable(new ArrowType.Int(64, true))),
            Map.entry(Types.BOOLEAN, pb -> FieldType.nullable(new ArrowType.Bool())),
            Map.entry(Types.TINYINT, pb -> FieldType.nullable(new ArrowType.Int(8, true))),
            Map.entry(Types.SMALLINT, pb -> FieldType.nullable(new ArrowType.Int(16, true))),
            Map.entry(Types.DATE, pb -> FieldType.nullable(new ArrowType.Date(DateUnit.DAY))),
            Map.entry(Types.TIME, pb -> FieldType.nullable(new ArrowType.Time(TimeUnit.MICROSECOND, 64))),
            Map.entry(Types.TIMESTAMP, pb -> FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC"))),
            Map.entry(
                    Types.FLOAT, pb -> FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))),
            Map.entry(
                    Types.DOUBLE, pb -> FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))),
            Map.entry(Types.DECIMAL, ArrowUtils::createDecimalFieldType),
            Map.entry(Types.ARRAY, pb -> FieldType.nullable(new ArrowType.List())));

    /**
     * Creates a Schema from a list of ParameterBinding.
     *
     * @param parameterBindings a list of ParameterBinding objects
     * @return a Schema object corresponding to the provided parameters
     */
    public static Schema createSchemaFromParameters(List<ParameterBinding> parameterBindings) {
        if (parameterBindings == null) {
            throw new IllegalArgumentException("ParameterBindings list cannot be null");
        }
        List<Field> fields = IntStream.range(0, parameterBindings.size())
                .mapToObj(i -> createField(parameterBindings.get(i), i + 1))
                .collect(Collectors.toList());

        return new Schema(fields);
    }

    /**
     * Creates a Field based on the ParameterBinding and its index.
     *
     * @param parameterBinding the ParameterBinding object
     * @param index the 1-based position of the parameter, used as the field name
     * @return a Field object with a name based on the index and a FieldType based on the parameter
     */
    private static Field createField(ParameterBinding parameterBinding, int index) {
        FieldType fieldType = determineFieldType(parameterBinding);
        return new Field(String.valueOf(index), fieldType, null);
    }

    /**
     * Determines the Arrow FieldType for a given ParameterBinding.
     *
     * @param parameterBinding the ParameterBinding object
     * @return the corresponding Arrow FieldType
     * @throws IllegalArgumentException when the binding's SQL type is unsupported
     */
    private static FieldType determineFieldType(ParameterBinding parameterBinding) {
        if (parameterBinding == null) {
            // Default type for null values, using VARCHAR for simplicity
            return FieldType.nullable(new ArrowType.Utf8());
        }

        int sqlType = parameterBinding.getSqlType();
        Function<ParameterBinding, FieldType> fieldTypeFunction = SQL_TYPE_TO_FIELD_TYPE.get(sqlType);

        if (fieldTypeFunction != null) {
            return fieldTypeFunction.apply(parameterBinding);
        } else {
            throw new IllegalArgumentException("Unsupported SQL type: " + sqlType);
        }
    }

    /**
     * Creates a Decimal Arrow FieldType based on a ParameterBinding.
     *
     * @param parameterBinding the ParameterBinding object; its value must be a BigDecimal
     * @return the corresponding Arrow FieldType for Decimal
     * @throws IllegalArgumentException when the value is not a BigDecimal
     */
    private static FieldType createDecimalFieldType(ParameterBinding parameterBinding) {
        if (parameterBinding.getValue() instanceof BigDecimal) {
            BigDecimal bd = (BigDecimal) parameterBinding.getValue();
            return FieldType.nullable(new ArrowType.Decimal(bd.precision(), bd.scale(), 128));
        }
        throw new IllegalArgumentException("Decimal type requires a BigDecimal value");
    }

    /**
     * Serializes the parameters as a single-batch Arrow IPC stream.
     *
     * @param parameters the parameter bindings to encode
     * @param calendar calendar used when populating temporal vectors
     * @return the Arrow stream bytes
     * @throws IOException if the stream writer fails
     */
    public static byte[] toArrowByteArray(List<ParameterBinding> parameters, Calendar calendar) throws IOException {
        Schema schema = createSchemaFromParameters(parameters);

        // Fix: the RootAllocator is AutoCloseable and was previously never
        // closed, leaking its accounting (and any retained buffers) per call.
        try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
                VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
            root.allocateNew();
            VectorPopulator.populateVectors(root, parameters, calendar);

            ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
            try (ArrowStreamWriter writer = new ArrowStreamWriter(root, null, outputStream)) {
                writer.start();
                writer.writeBatch();
                writer.end();
            }

            return outputStream.toByteArray();
        }
    }

    /**
     * Maps an Arrow type to the closest {@link java.sql.Types} code.
     *
     * @throws IllegalArgumentException for Arrow types with no JDBC analog
     */
    public static int getSQLTypeFromArrowType(ArrowType arrowType) {
        val typeId = arrowType.getTypeID();
        switch (typeId) {
            case Int:
                return getSQLTypeForInt((ArrowType.Int) arrowType);
            case Bool:
                return Types.BOOLEAN;
            case Utf8:
                return Types.VARCHAR;
            case LargeUtf8:
                return Types.LONGVARCHAR;
            case Binary:
                return Types.VARBINARY;
            case FixedSizeBinary:
                return Types.BINARY;
            case LargeBinary:
                return Types.LONGVARBINARY;
            case FloatingPoint:
                return getSQLTypeForFloatingPoint((ArrowType.FloatingPoint) arrowType);
            case Decimal:
                return Types.DECIMAL;
            case Date:
                return Types.DATE;
            case Time:
                return Types.TIME;
            case Timestamp:
                return Types.TIMESTAMP;
            case List:
            case LargeList:
            case FixedSizeList:
                return Types.ARRAY;
            case Map:
            case Duration:
            case Union:
            case Interval:
                return Types.JAVA_OBJECT;
            case Struct:
                return Types.STRUCT;
            case NONE:
            case Null:
                return Types.NULL;
            default:
                break;
        }
        throw new IllegalArgumentException("Unsupported Arrow type: " + arrowType);
    }

    // Signed integer width -> JDBC integer type.
    private int getSQLTypeForInt(ArrowType.Int arrowType) {
        val bitWidth = arrowType.getBitWidth();
        switch (bitWidth) {
            case 8:
                return Types.TINYINT;
            case 16:
                return Types.SMALLINT;
            case 32:
                return Types.INTEGER;
            case 64:
                return Types.BIGINT;
            default:
                break;
        }
        throw new IllegalArgumentException("Unsupported Arrow Integer Bit Width: " + bitWidth);
    }

    // Floating-point precision -> JDBC float type; HALF has no JDBC analog.
    private int getSQLTypeForFloatingPoint(ArrowType.FloatingPoint arrowType) {
        val precision = arrowType.getPrecision();
        switch (precision) {
            case SINGLE:
                return Types.FLOAT;
            case DOUBLE:
                return Types.DOUBLE;
            default:
                break;
        }
        throw new IllegalArgumentException("Unsupported Arrow Floating Point: " + precision);
    }

    private static ColumnMetaData.AvaticaType getAvaticaType(ArrowType arrowType) throws SQLException {
        val sqlType = getSQLTypeFromArrowType(arrowType);
        return getAvaticaType(sqlType, JDBCType.valueOf(sqlType).getName());
    }

    // Scalar Avatica type for most codes; ARRAY/STRUCT/MULTISET wrap an
    // OBJECT-represented component type.
    private static ColumnMetaData.AvaticaType getAvaticaType(int type, String typeName) throws SQLException {
        final ColumnMetaData.AvaticaType avaticaType;
        final SqlType sqlType = SqlType.valueOf(type);
        final ColumnMetaData.Rep rep = ColumnMetaData.Rep.of(sqlType.internal);
        if (sqlType == SqlType.ARRAY || sqlType == SqlType.STRUCT || sqlType == SqlType.MULTISET) {
            ColumnMetaData.AvaticaType arrayValueType =
                    ColumnMetaData.scalar(java.sql.Types.JAVA_OBJECT, typeName, ColumnMetaData.Rep.OBJECT);
            avaticaType = ColumnMetaData.array(arrayValueType, typeName, rep);
        } else {
            avaticaType = ColumnMetaData.scalar(type, typeName, rep);
        }
        return avaticaType;
    }
}
+ */ +package com.salesforce.datacloud.jdbc.util; + +import lombok.experimental.UtilityClass; + +@UtilityClass +public final class Constants { + + public static final String CDP_URL = "/api/v1"; + public static final String METADATA_URL = "/metadata"; + + public static final String CONNECTION_PROTOCOL = "jdbc:salesforce-datacloud:"; + public static final String HYPER_LAKEHOUSE_ALIAS = "lakehouse"; + public static final String HYPER_LAKEHOUSE_PATH_PREFIX = "lakehouse:"; + + // authentication constants + public static final String LOGIN_URL = "loginURL"; + + // Property constants + public static final String CLIENT_ID = "clientId"; + public static final String CLIENT_SECRET = "clientSecret"; + public static final String USER = "user"; + public static final String USER_NAME = "userName"; + public static final String PRIVATE_KEY = "privateKey"; + public static final String FORCE_SYNC = "force-sync"; + + // Http/grpc client constants + public static final String AUTHORIZATION = "Authorization"; + public static final String CONTENT_TYPE_JSON = "application/json"; + public static final String POST = "POST"; + + // Column Types + public static final String INTEGER = "INTEGER"; + public static final String TEXT = "TEXT"; + public static final String SHORT = "SHORT"; + + public static final String DRIVER_NAME = "salesforce-datacloud-jdbc"; + public static final String DATABASE_PRODUCT_NAME = "salesforce-datacloud-queryservice"; + public static final String DATABASE_PRODUCT_VERSION = "24.8.0"; + public static final String DRIVER_VERSION = "3.0"; + + // Date Time constants + public static final String ISO_DATE_TIME_FORMAT = "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"; + public static final String ISO_DATE_TIME_SEC_FORMAT = "yyyy-MM-dd'T'HH:mm:ss'Z'"; + public static final String ISO_TIME_FORMAT = "HH:mm:ss"; +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/util/ConsumingPeekingIterator.java b/src/main/java/com/salesforce/datacloud/jdbc/util/ConsumingPeekingIterator.java new file 
package com.salesforce.datacloud.jdbc.util;

import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Predicate;
import java.util.stream.Stream;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;

/**
 * An iterator that keeps handing out the same element until the caller has
 * drained it (as judged by {@code isNotEmpty}), and only then advances to the
 * next non-empty element from the underlying iterator.
 */
@Slf4j
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class ConsumingPeekingIterator<T> implements Iterator<T> {
    /**
     * Wraps the stream, pre-filtering out elements that are already empty.
     *
     * @param stream the elements to iterate
     * @param isNotEmpty predicate deciding whether an element still has content to consume
     */
    public static <T> ConsumingPeekingIterator<T> of(Stream<T> stream, Predicate<T> isNotEmpty) {
        return new ConsumingPeekingIterator<>(stream.filter(isNotEmpty).iterator(), isNotEmpty);
    }

    private final Iterator<T> iterator;
    private final Predicate<T> isNotEmpty;
    // Holds the element currently being consumed; re-served until drained.
    private final AtomicReference<T> consumable = new AtomicReference<>();

    // True while the current element still has content left.
    private boolean consumableHasMore() {
        val head = consumable.get();
        return head != null && isNotEmpty.test(head);
    }

    @Override
    public boolean hasNext() {
        return consumableHasMore() || iterator.hasNext();
    }

    @Override
    public T next() {
        // Serve the in-progress element again until the caller drains it.
        if (consumableHasMore()) {
            return consumable.get();
        }

        if (iterator.hasNext()) {
            consumable.set(iterator.next());
            return consumable.get();
        }

        throw new NoSuchElementException();
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException();
    }
}
*/ +@UtilityClass +public class DateTimeUtils { + + public static final long MILLIS_TO_MICRO_SECS_CONVERSION_FACTOR = 1000; + + /** Subtracts default Calendar's timezone offset from epoch milliseconds to get relative UTC milliseconds */ + public static long applyCalendarOffset(long milliseconds) { + final TimeZone defaultTz = TimeZone.getDefault(); + return milliseconds - defaultTz.getOffset(milliseconds); + } + + public static long applyCalendarOffset(long milliseconds, Calendar calendar) { + if (calendar == null) { + return applyCalendarOffset(milliseconds); + } + val timeZone = calendar.getTimeZone(); + return milliseconds - timeZone.getOffset(milliseconds); + } + + public static Date getUTCDateFromMilliseconds(long milliseconds) { + return new Date(applyCalendarOffset(milliseconds)); + } + + public static Time getUTCTimeFromMilliseconds(long milliseconds) { + return new Time(applyCalendarOffset(milliseconds)); + } + + public static Date getUTCDateFromDateAndCalendar(Date date, Calendar calendar) { + val milliseconds = date.getTime(); + return new Date(applyCalendarOffset(milliseconds, calendar)); + } + + public static Time getUTCTimeFromTimeAndCalendar(Time time, Calendar calendar) { + val milliseconds = time.getTime(); + return new Time(applyCalendarOffset(milliseconds, calendar)); + } + + public static Timestamp getUTCTimestampFromTimestampAndCalendar(Timestamp timestamp, Calendar calendar) { + val milliseconds = timestamp.getTime(); + return new Timestamp(applyCalendarOffset(milliseconds, calendar)); + } + + /** + * Converts LocalDateTime to microseconds since epoch. + * + * @param localDateTime The LocalDateTime to convert. + * @return The microseconds since epoch. 
+ */ + public static long localDateTimeToMicrosecondsSinceEpoch(LocalDateTime localDateTime) { + long epochMillis = localDateTime.toInstant(ZoneOffset.UTC).toEpochMilli(); + return millisToMicrosecondsSinceMidnight(epochMillis); + } + + /** + * Converts milliseconds since midnight to microseconds since midnight. + * + * @param millis The milliseconds since midnight. + * @return The microseconds since midnight. + */ + public static long millisToMicrosecondsSinceMidnight(long millis) { + return millis * MILLIS_TO_MICRO_SECS_CONVERSION_FACTOR; + } + + /** + * Adjusts LocalDateTime for the given Calendar's timezone offset. + * + * @param localDateTime The LocalDateTime to adjust. + * @param calendar The Calendar with the target timezone. + * @param defaultTimeZone The default timezone to compare against. + * @return The adjusted LocalDateTime. + */ + public static LocalDateTime adjustForCalendar( + LocalDateTime localDateTime, Calendar calendar, TimeZone defaultTimeZone) { + if (calendar == null) { + return localDateTime; + } + + TimeZone targetTimeZone = calendar.getTimeZone(); + long millis = localDateTime.toInstant(ZoneOffset.UTC).toEpochMilli(); + long offsetMillis = targetTimeZone.getOffset(millis) - (long) defaultTimeZone.getOffset(millis); + return localDateTime.plus(offsetMillis, ChronoUnit.MILLIS); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/util/Messages.java b/src/main/java/com/salesforce/datacloud/jdbc/util/Messages.java new file mode 100644 index 0000000..e0458ed --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/util/Messages.java @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.util; + +import lombok.experimental.UtilityClass; + +@UtilityClass +public final class Messages { + + public static final String FAILED_LOGIN = "Failed to login. Please check credentials"; + + public static final String ILLEGAL_CONNECTION_PROTOCOL = + "URL is specified with invalid datasource, expected jdbc:salesforce-datacloud"; +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/util/MetadataCacheUtil.java b/src/main/java/com/salesforce/datacloud/jdbc/util/MetadataCacheUtil.java new file mode 100644 index 0000000..9a96147 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/util/MetadataCacheUtil.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.util; + +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import java.util.concurrent.TimeUnit; +import lombok.experimental.UtilityClass; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +@UtilityClass +public class MetadataCacheUtil { + + private static Cache cache = CacheBuilder.newBuilder() + .expireAfterWrite(600000, TimeUnit.MILLISECONDS) + .maximumSize(10) + .build(); + + public static String getMetadata(String url) { + return cache.getIfPresent(url); + } + + public static void cacheMetadata(String url, String response) { + cache.put(url, response); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/util/PropertiesExtensions.java b/src/main/java/com/salesforce/datacloud/jdbc/util/PropertiesExtensions.java new file mode 100644 index 0000000..ebcfcd3 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/util/PropertiesExtensions.java @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.util; + +import java.util.Optional; +import java.util.Properties; +import java.util.Set; +import lombok.experimental.UtilityClass; +import lombok.val; + +@UtilityClass +public class PropertiesExtensions { + public static Optional optional(Properties properties, String key) { + if (properties == null) { + return Optional.empty(); + } + + if (key == null || !properties.containsKey(key)) { + return Optional.empty(); + } + + val value = properties.getProperty(key); + return (value == null || value.isBlank()) ? Optional.empty() : Optional.of(value); + } + + public static String required(Properties properties, String key) { + return optional(properties, key) + .orElseThrow(() -> new IllegalArgumentException(Messages.REQUIRED_MISSING_PREFIX + key)); + } + + public static Properties copy(Properties properties, Set filterKeys) { + val result = new Properties(); + for (val key : filterKeys) { + val value = properties.getProperty(key); + if (value != null) { + result.setProperty(key, value); + } + } + return result; + } + + public static Integer getIntegerOrDefault(Properties properties, String key, Integer defaultValue) { + return optional(properties, key) + .map(PropertiesExtensions::toIntegerOrNull) + .orElse(defaultValue); + } + + public static Integer toIntegerOrNull(String s) { + try { + return Integer.parseInt(s); + } catch (Exception ex) { + return null; + } + } + + public static Boolean getBooleanOrDefault(Properties properties, String key, Boolean defaultValue) { + return optional(properties, key) + .map(PropertiesExtensions::toBooleanOrDefault) + .orElse(defaultValue); + } + + public static Boolean toBooleanOrDefault(String s) { + return Boolean.valueOf(s); + } + + @UtilityClass + class Messages { + static final String REQUIRED_MISSING_PREFIX = "Properties missing required value for key: "; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/util/Require.java 
b/src/main/java/com/salesforce/datacloud/jdbc/util/Require.java new file mode 100644 index 0000000..e8f5e59 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/util/Require.java @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.util; + +import lombok.experimental.UtilityClass; +import org.apache.commons.lang3.StringUtils; + +@UtilityClass +public class Require { + public static void requireNotNullOrBlank(String value, String name) { + if (value == null || StringUtils.isBlank(value)) { + throw new IllegalArgumentException("Expected argument '" + name + "' to not be null or blank"); + } + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/util/Result.java b/src/main/java/com/salesforce/datacloud/jdbc/util/Result.java new file mode 100644 index 0000000..882b3bc --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/util/Result.java @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.util; + +import java.util.Optional; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NonNull; + +public abstract class Result { + private Result() {} + + public static Result of(@NonNull ThrowingSupplier supplier) { + try { + return new Success<>(supplier.get()); + } catch (Throwable t) { + return new Failure<>(t); + } + } + + abstract Optional get(); + + abstract Optional getError(); + + @Getter + @AllArgsConstructor + public static class Success extends Result { + private final T value; + + @Override + Optional get() { + return Optional.ofNullable(value); + } + + @Override + Optional getError() { + return Optional.empty(); + } + } + + @Getter + @AllArgsConstructor + public static class Failure extends Result { + private final Throwable error; + + @Override + Optional get() { + return Optional.empty(); + } + + @Override + Optional getError() { + return Optional.ofNullable(error); + } + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/util/SqlErrorCodes.java b/src/main/java/com/salesforce/datacloud/jdbc/util/SqlErrorCodes.java new file mode 100644 index 0000000..531bd74 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/util/SqlErrorCodes.java @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.util; + +import lombok.experimental.UtilityClass; + +@UtilityClass +public class SqlErrorCodes { + public static final String FEATURE_NOT_SUPPORTED = "0A000"; + public static final String UNDEFINED_FILE = "58P01"; +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/util/StreamUtilities.java b/src/main/java/com/salesforce/datacloud/jdbc/util/StreamUtilities.java new file mode 100644 index 0000000..61d153b --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/util/StreamUtilities.java @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.util; + +import java.util.Iterator; +import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.LongSupplier; +import java.util.function.Supplier; +import java.util.function.UnaryOperator; +import java.util.stream.Stream; +import lombok.experimental.UtilityClass; + +@UtilityClass +public class StreamUtilities { + public Stream lazyLimitedStream(Supplier> streamSupplier, LongSupplier limitSupplier) { + return streamSupplier.get().limit(limitSupplier.getAsLong()); + } + + public Stream toStream(Iterator iterator) { + return Stream.iterate(iterator, Iterator::hasNext, UnaryOperator.identity()) + .map(Iterator::next); + } + + public Optional tryTimes( + int times, ThrowingSupplier attempt, Consumer consumer) { + return Stream.iterate(attempt, UnaryOperator.identity()) + .limit(times) + .map(Result::of) + .filter(r -> { + if (r.getError().isPresent()) { + consumer.accept(r.getError().get()); + return false; + } + return true; + }) + .findFirst() + .flatMap(Result::get); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/util/ThrowingFunction.java b/src/main/java/com/salesforce/datacloud/jdbc/util/ThrowingFunction.java new file mode 100644 index 0000000..730dbe7 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/util/ThrowingFunction.java @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.util; + +import java.util.function.Function; + +@FunctionalInterface +public interface ThrowingFunction { + static Function rethrowFunction(ThrowingFunction function) throws E { + return t -> { + try { + return function.apply(t); + } catch (Exception exception) { + throwAsUnchecked(exception); + return null; + } + }; + } + + @SuppressWarnings("unchecked") + private static void throwAsUnchecked(Exception exception) throws E { + throw (E) exception; + } + + R apply(T t) throws E; +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/util/ThrowingSupplier.java b/src/main/java/com/salesforce/datacloud/jdbc/util/ThrowingSupplier.java new file mode 100644 index 0000000..afc6b5a --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/util/ThrowingSupplier.java @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.util; + +import java.util.function.LongSupplier; +import java.util.function.Supplier; +import java.util.stream.Stream; + +public interface ThrowingSupplier { + T get() throws E; + + static Supplier> rethrowSupplier(ThrowingSupplier function) throws E { + return () -> { + try { + return (Stream) function.get(); + } catch (Exception exception) { + throwAsUnchecked(exception); + return null; + } + }; + } + + static LongSupplier rethrowLongSupplier(ThrowingSupplier function) throws E { + return () -> { + try { + return (long) function.get(); + } catch (Exception exception) { + throwAsUnchecked(exception); + return Long.parseLong(null); + } + }; + } + + private static void throwAsUnchecked(Exception exception) throws E { + throw (E) exception; + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/util/VectorPopulator.java b/src/main/java/com/salesforce/datacloud/jdbc/util/VectorPopulator.java new file mode 100644 index 0000000..cbb5534 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/util/VectorPopulator.java @@ -0,0 +1,351 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.util; + +import com.salesforce.datacloud.jdbc.core.model.ParameterBinding; +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.sql.Date; +import java.sql.Time; +import java.sql.Timestamp; +import java.time.LocalDateTime; +import java.util.Calendar; +import java.util.List; +import java.util.Map; +import java.util.TimeZone; +import lombok.experimental.UtilityClass; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Field; + +/** Populates vectors in a VectorSchemaRoot with values from a list of parameters. */ +@UtilityClass +public final class VectorPopulator { + + /** + * Populates the vectors in the given VectorSchemaRoot. + * + * @param root The VectorSchemaRoot to populate. + */ + public static void populateVectors(VectorSchemaRoot root, List parameters, Calendar calendar) { + VectorValueSetterFactory factory = new VectorValueSetterFactory(calendar); + + for (int i = 0; i < parameters.size(); i++) { + Field field = root.getSchema().getFields().get(i); + ValueVector vector = root.getVector(field.getName()); + Object value = parameters.get(i) == null ? 
null : parameters.get(i).getValue(); + + @SuppressWarnings(value = "unchecked") + VectorValueSetter setter = + (VectorValueSetter) factory.getSetter(vector.getClass()); + + if (setter != null) { + setter.setValue(vector, value); + } else { + throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass()); + } + } + root.setRowCount(1); // Set row count to 1 since we have exactly one row + } +} + +@FunctionalInterface +interface VectorValueSetter { + void setValue(T vector, Object value); +} + +/** Factory for creating appropriate setter instances based on vector type. */ +class VectorValueSetterFactory { + private final Map, VectorValueSetter> setterMap; + + VectorValueSetterFactory(Calendar calendar) { + setterMap = Map.ofEntries( + Map.entry(VarCharVector.class, new VarCharVectorSetter()), + Map.entry(Float4Vector.class, new Float4VectorSetter()), + Map.entry(Float8Vector.class, new Float8VectorSetter()), + Map.entry(IntVector.class, new IntVectorSetter()), + Map.entry(SmallIntVector.class, new SmallIntVectorSetter()), + Map.entry(BigIntVector.class, new BigIntVectorSetter()), + Map.entry(BitVector.class, new BitVectorSetter()), + Map.entry(DecimalVector.class, new DecimalVectorSetter()), + Map.entry(DateDayVector.class, new DateDayVectorSetter()), + Map.entry(TimeMicroVector.class, new TimeMicroVectorSetter(calendar)), + Map.entry(TimeStampMicroTZVector.class, new TimeStampMicroTZVectorSetter(calendar)), + Map.entry(TinyIntVector.class, new TinyIntVectorSetter())); + } + + @SuppressWarnings("unchecked") + VectorValueSetter getSetter(Class vectorClass) { + return (VectorValueSetter) setterMap.get(vectorClass); + } +} + +/** Base setter implementation for ValueVectors that need type validation. 
*/ +abstract class BaseVectorSetter implements VectorValueSetter { + private final Class valueType; + + BaseVectorSetter(Class valueType) { + this.valueType = valueType; + } + + @Override + public void setValue(T vector, Object value) { + if (value == null) { + setNullValue(vector); + } else if (valueType.isInstance(value)) { + setValueInternal(vector, valueType.cast(value)); + } else { + throw new IllegalArgumentException( + "Value for " + vector.getClass().getSimpleName() + " must be of type " + valueType.getSimpleName()); + } + } + + protected abstract void setNullValue(T vector); + + protected abstract void setValueInternal(T vector, V value); +} + +/** Setter implementation for VarCharVector. */ +class VarCharVectorSetter extends BaseVectorSetter { + VarCharVectorSetter() { + super(String.class); + } + + @Override + protected void setValueInternal(VarCharVector vector, String value) { + vector.setSafe(0, value.getBytes(StandardCharsets.UTF_8)); + } + + @Override + protected void setNullValue(VarCharVector vector) { + vector.setNull(0); + } +} + +/** Setter implementation for Float4Vector. */ +class Float4VectorSetter extends BaseVectorSetter { + Float4VectorSetter() { + super(Float.class); + } + + @Override + protected void setValueInternal(Float4Vector vector, Float value) { + vector.setSafe(0, value); + } + + @Override + protected void setNullValue(Float4Vector vector) { + vector.setNull(0); + } +} + +/** Setter implementation for Float8Vector. */ +class Float8VectorSetter extends BaseVectorSetter { + Float8VectorSetter() { + super(Double.class); + } + + @Override + protected void setValueInternal(Float8Vector vector, Double value) { + vector.setSafe(0, value); + } + + @Override + protected void setNullValue(Float8Vector vector) { + vector.setNull(0); + } +} + +/** Setter implementation for IntVector. 
*/ +class IntVectorSetter extends BaseVectorSetter { + IntVectorSetter() { + super(Integer.class); + } + + @Override + protected void setValueInternal(IntVector vector, Integer value) { + vector.setSafe(0, value); + } + + @Override + protected void setNullValue(IntVector vector) { + vector.setNull(0); + } +} + +/** Setter implementation for SmallIntVector. */ +class SmallIntVectorSetter extends BaseVectorSetter { + SmallIntVectorSetter() { + super(Short.class); + } + + @Override + protected void setValueInternal(SmallIntVector vector, Short value) { + vector.setSafe(0, value); + } + + @Override + protected void setNullValue(SmallIntVector vector) { + vector.setNull(0); + } +} + +/** Setter implementation for BigIntVector. */ +class BigIntVectorSetter extends BaseVectorSetter { + BigIntVectorSetter() { + super(Long.class); + } + + @Override + protected void setValueInternal(BigIntVector vector, Long value) { + vector.setSafe(0, value); + } + + @Override + protected void setNullValue(BigIntVector vector) { + vector.setNull(0); + } +} + +/** Setter implementation for BitVector. */ +class BitVectorSetter extends BaseVectorSetter { + BitVectorSetter() { + super(Boolean.class); + } + + @Override + protected void setValueInternal(BitVector vector, Boolean value) { + vector.setSafe(0, Boolean.TRUE.equals(value) ? 1 : 0); + } + + @Override + protected void setNullValue(BitVector vector) { + vector.setNull(0); + } +} + +/** Setter implementation for DecimalVector. */ +class DecimalVectorSetter extends BaseVectorSetter { + DecimalVectorSetter() { + super(BigDecimal.class); + } + + @Override + protected void setValueInternal(DecimalVector vector, BigDecimal value) { + vector.setSafe(0, value.unscaledValue().longValue()); + } + + @Override + protected void setNullValue(DecimalVector vector) { + vector.setNull(0); + } +} + +/** Setter implementation for DateDayVector. 
*/ +class DateDayVectorSetter extends BaseVectorSetter { + DateDayVectorSetter() { + super(Date.class); + } + + @Override + protected void setValueInternal(DateDayVector vector, Date value) { + long daysSinceEpoch = value.toLocalDate().toEpochDay(); + vector.setSafe(0, (int) daysSinceEpoch); + } + + @Override + protected void setNullValue(DateDayVector vector) { + vector.setNull(0); + } +} + +/** Setter implementation for TimeMicroVector. */ +class TimeMicroVectorSetter extends BaseVectorSetter { + private final Calendar calendar; + + TimeMicroVectorSetter(Calendar calendar) { + super(Time.class); + this.calendar = calendar; + } + + @Override + protected void setValueInternal(TimeMicroVector vector, Time value) { + LocalDateTime localDateTime = new Timestamp(value.getTime()).toLocalDateTime(); + localDateTime = DateTimeUtils.adjustForCalendar(localDateTime, calendar, TimeZone.getTimeZone("UTC")); + long midnightMillis = localDateTime.toLocalTime().toNanoOfDay() / 1_000_000; + long microsecondsSinceMidnight = DateTimeUtils.millisToMicrosecondsSinceMidnight(midnightMillis); + + vector.setSafe(0, microsecondsSinceMidnight); + } + + @Override + protected void setNullValue(TimeMicroVector vector) { + vector.setNull(0); + } +} + +/** Setter implementation for TimeStampMicroTZVector. 
*/ +class TimeStampMicroTZVectorSetter extends BaseVectorSetter { + private final Calendar calendar; + + TimeStampMicroTZVectorSetter(Calendar calendar) { + super(Timestamp.class); + this.calendar = calendar; + } + + @Override + protected void setValueInternal(TimeStampMicroTZVector vector, Timestamp value) { + LocalDateTime localDateTime = value.toLocalDateTime(); + localDateTime = DateTimeUtils.adjustForCalendar(localDateTime, calendar, TimeZone.getTimeZone("UTC")); + long microsecondsSinceEpoch = DateTimeUtils.localDateTimeToMicrosecondsSinceEpoch(localDateTime); + + vector.setSafe(0, microsecondsSinceEpoch); + } + + @Override + protected void setNullValue(TimeStampMicroTZVector vector) { + vector.setNull(0); + } +} + +/** Setter implementation for TinyIntVectorSetter. */ +class TinyIntVectorSetter extends BaseVectorSetter { + TinyIntVectorSetter() { + super(Byte.class); + } + + @Override + protected void setValueInternal(TinyIntVector vector, Byte value) { + vector.setSafe(0, value); + } + + @Override + protected void setNullValue(TinyIntVector vector) { + vector.setNull(0); + } +} diff --git a/src/main/java/com/salesforce/datacloud/jdbc/util/internal/SFDefaultSocketFactoryWrapper.java b/src/main/java/com/salesforce/datacloud/jdbc/util/internal/SFDefaultSocketFactoryWrapper.java new file mode 100644 index 0000000..4d05749 --- /dev/null +++ b/src/main/java/com/salesforce/datacloud/jdbc/util/internal/SFDefaultSocketFactoryWrapper.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.util.internal; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.Proxy; +import java.net.Socket; +import javax.net.SocketFactory; +import lombok.extern.slf4j.Slf4j; + +/** Default Wrapper for SocketFactory. */ +@Slf4j +public class SFDefaultSocketFactoryWrapper extends SocketFactory { + + private final boolean isSocksProxyDisabled; + private final SocketFactory socketFactory; + + public SFDefaultSocketFactoryWrapper(boolean isSocksProxyDisabled) { + this(isSocksProxyDisabled, SocketFactory.getDefault()); + } + + public SFDefaultSocketFactoryWrapper(boolean isSocksProxyDisabled, SocketFactory socketFactory) { + super(); + this.isSocksProxyDisabled = isSocksProxyDisabled; + this.socketFactory = socketFactory; + } + + /** + * When isSocksProxyDisabled then, socket backed by plain socket impl is returned. Otherwise, delegates + * the socket creation to specified socketFactory + * + * @return socket + * @throws IOException when socket creation fails + */ + @Override + public Socket createSocket() throws IOException { + // avoid creating SocksSocket when SocksProxyDisabled + // this is the method called by okhttp + return isSocksProxyDisabled ? 
new Socket(Proxy.NO_PROXY) : this.socketFactory.createSocket(); + } + + @Override + public Socket createSocket(String host, int port) throws IOException { + return this.socketFactory.createSocket(host, port); + } + + @Override + public Socket createSocket(InetAddress address, int port) throws IOException { + return this.socketFactory.createSocket(address, port); + } + + @Override + public Socket createSocket(String host, int port, InetAddress clientAddress, int clientPort) throws IOException { + return this.socketFactory.createSocket(host, port, clientAddress, clientPort); + } + + @Override + public Socket createSocket(InetAddress address, int port, InetAddress clientAddress, int clientPort) + throws IOException { + return this.socketFactory.createSocket(address, port, clientAddress, clientPort); + } +} diff --git a/src/main/proto/error_details.proto b/src/main/proto/error_details.proto new file mode 100644 index 0000000..1be0e26 --- /dev/null +++ b/src/main/proto/error_details.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package salesforce.hyperdb.grpc.v1; + +option java_multiple_files = true; +option java_package = "com.salesforce.hyperdb.grpc"; + +message TextPosition { + uint64 error_begin_character_offset = 2; + uint64 error_end_character_offset = 3; +} + +message ErrorInfo { + string primary_message = 1; + string sqlstate = 2; + string customer_hint = 3; + string customer_detail = 4; + string system_detail = 5; + TextPosition position = 6; +} \ No newline at end of file diff --git a/src/main/proto/hyper_service.proto b/src/main/proto/hyper_service.proto new file mode 100644 index 0000000..62c70b4 --- /dev/null +++ b/src/main/proto/hyper_service.proto @@ -0,0 +1,199 @@ +syntax = "proto3"; + +import "google/protobuf/empty.proto"; +import "google/protobuf/timestamp.proto"; + +package salesforce.hyperdb.grpc.v1; + + +option java_multiple_files = true; +option java_package = "com.salesforce.hyperdb.grpc"; +option java_outer_classname = 
"HyperDatabaseServiceProto"; + +service HyperService { + rpc ExecuteQuery (QueryParam) returns (stream ExecuteQueryResponse); + rpc GetQueryResult (QueryResultParam) returns (stream QueryResult); + rpc GetQueryInfo (QueryInfoParam) returns (stream QueryInfo); + rpc CancelQuery (CancelQueryParam) returns (google.protobuf.Empty); +} + +message CancelQueryParam { + string query_id = 1; +} + +message QueryResultParam { + string query_id = 1; + OutputFormat output_format = 2; + oneof requested_data { + uint64 chunk_id = 3; + RowRange row_range = 5; + } + bool omit_schema = 4; +} + +message RowRange { + uint64 offset = 1; + uint64 row_count = 2; +} + +message QueryResult { + oneof result { + QueryResultPartBinary binary_part = 1; + QueryResultPartString string_part = 2; + } +} + +message QueryInfoParam { + string query_id = 1; + bool streaming = 2; +} + +message QueryInfo { + oneof content { + QueryStatus query_status = 1; + } + bool optional = 2; +} + +message QueryStatus { + enum CompletionStatus{ + RUNNING = 0; + RESULTS_PRODUCED = 1; + FINISHED = 2; + } + string query_id = 1; + CompletionStatus completion_status = 2; + uint64 chunk_count = 3; + uint64 row_count = 4; + double progress = 5; + google.protobuf.Timestamp expiration_time = 6; +} + +enum OutputFormat { + TEXT_DEBUG = 0; + reserved 1; + ARROW_LEGACY_V1 = 2; + JSON_LEGACY_V1 = 3; + ARROW_LEGACY_V2 = 4; + ARROW_V3 = 5; + JSON_V2 = 6; +} + +message QueryParam { + enum TransferMode { + ADAPTIVE = 0; + ASYNC = 1; + SYNC = 2; + } + + enum ParameterStyle { + QUESTION_MARK = 0; + DOLLAR_NUMBERED = 1; + NAMED = 2; + } + + string query = 1; + repeated AttachedDatabase database = 2; + OutputFormat output_format = 3; + map settings = 4; + TransferMode transfer_mode = 5; + ParameterStyle param_style = 6; + oneof parameters { + QueryParameterArrow arrow_parameters = 7; + QueryParameterJson json_parameters = 8; + } + uint64 max_rows = 9; +} + +message AttachedDatabase { + string path = 1; + string alias = 2; +} + +message 
ExecuteQueryResponse { + oneof result { + QueryResultHeader header = 1; + QueryResultPartBinary binary_part = 4; + QueryResultPartString string_part = 5; + QueryInfo query_info = 6; + QueryResult query_result = 7; + } + bool optional = 9; +} + +message QueryResultHeader { + oneof header { + QueryResultSchema schema = 1; + QueryCommandOk command = 2; + } +} + +message QueryCommandOk { + oneof command_return { + google.protobuf.Empty empty = 2; + uint64 affected_rows = 1; + } +} + +message QueryResultSchema { + repeated ColumnDescription column = 1; +} + +message ColumnDescription { + string name = 1; + SqlType type = 2; +} + +message SqlType { + enum TypeTag { + HYPER_UNSUPPORTED = 0; + HYPER_BOOL = 1; + HYPER_BIG_INT = 2; + HYPER_SMALL_INT = 3; + HYPER_INT = 4; + HYPER_NUMERIC = 5; + HYPER_DOUBLE = 6; + HYPER_OID = 7; + HYPER_BYTE_A = 8; + HYPER_TEXT = 9; + HYPER_VARCHAR = 10; + HYPER_CHAR = 11; + HYPER_JSON = 12; + HYPER_DATE = 13; + HYPER_INTERVAL = 14; + HYPER_TIME = 15; + HYPER_TIMESTAMP = 16; + HYPER_TIMESTAMP_TZ = 17; + HYPER_GEOGRAPHY = 18; + HYPER_FLOAT = 19; + HYPER_ARRAY_OF_FLOAT = 20; + } + + message NumericModifier { + uint32 precision = 1; + uint32 scale = 2; + } + + TypeTag tag = 1; + oneof modifier { + google.protobuf.Empty empty = 2; + uint32 max_length = 3; + NumericModifier numeric_modifier = 4; + } +} + +message QueryResultPartBinary { + bytes data = 127; +} + +message QueryResultPartString { + string data = 127; +} + +message QueryParameterArrow { + bytes data = 127; +} + +message QueryParameterJson { + string data = 127; +} \ No newline at end of file diff --git a/src/main/resources/META-INF/services/java.sql.Driver b/src/main/resources/META-INF/services/java.sql.Driver new file mode 100644 index 0000000..68a8e91 --- /dev/null +++ b/src/main/resources/META-INF/services/java.sql.Driver @@ -0,0 +1 @@ +com.salesforce.datacloud.jdbc.DataCloudJDBCDriver diff --git a/src/main/resources/keywords/hyper_sql_lexer_keywords.txt 
b/src/main/resources/keywords/hyper_sql_lexer_keywords.txt new file mode 100644 index 0000000..2c832a9 --- /dev/null +++ b/src/main/resources/keywords/hyper_sql_lexer_keywords.txt @@ -0,0 +1,450 @@ +abort +absolute +access +action +add +admin +after +aggregate +all +also +alter +always +analyse +analyze +and +any +array +as +asc +assertion +assignment +assumed +asymmetric +at +attach +attribute +authorization +backward +before +begin +between +bigint +binary +bit +boolean +both +bounded +buffered +bulk +by +cache +call +called +cascade +cascaded +case +cast +catalog +chain +character +characteristics +char +check +checkpoint +class +close +cluster +coalesce +collate +collation +column +comment +comments +commit +committed +concurrently +configuration +connection +constraint +constraints +constref +content +continue +continuous +conversion +copy +cost +create +cross +csv +cube +current_catalog +current_date +current +current_role +current_schema +current_time +current_timestamp +current_user +cursor +cycle +database +data +day +deallocate +dec +decimal +declare +default +defaults +deferrable +deferred +definer +delete +delimiter +delimiters +desc +descriptor +detach +dictionary +disable +discard +distinct +do +document +domain +double +drop +each +else +empty +enable +encoding +encrypted +end +enum +escape +event +except +exclude +excluding +exclusive +execute +exists +explain +export +extension +external +extract +false +family +fetch +filter +first +float +following +for +force +foreign +forward +freeze +from +full +function +functions +global +grant +granted +greatest +group +groups +grouping +handler +having +header +hold +hour +identity +if +ignore +ilike +immediate +immutable +implicit +import +including +increment +index +indexes +inherit +inherits +initially +inline +inner +inout +input +insensitive +insert +instead +integer +intersect +interval +into +int +invoker +in +is +isnull +isolation +join +key +keep +label +language +large +last +lateral +leading 
+leakproof +least +left +level +like +limit +listen +load +local +localtime +localtimestamp +location +locked +lock +logged +mapping +match +materialized +maxvalue +minute +minvalue +mode +month +move +names +name +national +natural +nchar +next +no +none +not +nothing +notify +notnull +nowait +nth_value +nullif +nulls +null +numeric +object +of +off +offset +oids +on +only +operator +option +options +or +order +ordinality +others +outer +out +over +overlaps +overlay +owned +owner +parser +partial +partition +pass +passing +password +placing +plans +policy +position +preceding +precision +prepare +prepared +preserve +primary +prior +privileges +procedural +procedure +program +prune +quote +range +read +real +reassign +recheck +recursive +ref +references +refresh +reindex +relative +release +rename +repeatable +replace +replica +reset +respect +restart +restrict +returning +returns +revoke +right +role +rollback +rollup +row +rows +rule +sanitize +savepoint +schema +scroll +search +second +secure +security +select +semantics +sequence +sequences +serializable +server +session +session_user +set +setof +sets +share +show +similar +simple +skip +smallint +snapshot +some +stable +standalone +start +statement +statistics +stdin +stdout +storage +strict +strip +substring +symmetric +sysid +system +table +tables +tablesample +tablespace +temp +template +temporary +text +then +through +throughs +ties +time +timestamp +to +trailing +transaction +treat +trigger +trim +true +truncate +trusted +try_cast +types +type +unbounded +uncommitted +unencrypted +union +unique +unknown +unlisten +unlogged +until +update +user +using +vacuum +valid +validate +validator +values +value +varchar +variadic +varying +verbose +version +view +views +volatile +when +where +whitespace +window +with +within +without +work +wrapper +write +xmlattributes +xmlconcat +xmlelement +xmlexists +xmlforest +xmlparse +xmlpi +xmlroot +xmlserialize +xml +year +yes +zone +cardinality +constant +defragmentation 
+gracefully +forcefully +percent +return +stream +unload +lambda \ No newline at end of file diff --git a/src/main/resources/simplelogger.properties b/src/main/resources/simplelogger.properties new file mode 100644 index 0000000..bafbd34 --- /dev/null +++ b/src/main/resources/simplelogger.properties @@ -0,0 +1,2 @@ +org.slf4j.simpleLogger.logFile=System.out +org.slf4j.simpleLogger.defaultLogLevel=info diff --git a/src/main/resources/sql/get_columns_query.sql b/src/main/resources/sql/get_columns_query.sql new file mode 100644 index 0000000..8c83490 --- /dev/null +++ b/src/main/resources/sql/get_columns_query.sql @@ -0,0 +1,27 @@ +SELECT n.nspname, + c.relname, + a.attname, + a.atttypid, + a.attnotnull OR (t.typtype = 'd' AND t.typnotnull) AS attnotnull, + a.atttypmod, + a.attlen, + t.typtypmod, + a.attnum, + null as attidentity, + null as attgenerated, + pg_catalog.pg_get_expr(def.adbin, def.adrelid) AS adsrc, + dsc.description, + t.typbasetype, + t.typtype, + pg_catalog.format_type(a.atttypid, a.atttypmod) as datatype +FROM pg_catalog.pg_namespace n + JOIN pg_catalog.pg_class c ON (c.relnamespace = n.oid) + JOIN pg_catalog.pg_attribute a ON (a.attrelid = c.oid) + JOIN pg_catalog.pg_type t ON (a.atttypid = t.oid) + LEFT JOIN pg_catalog.pg_attrdef def ON (a.attrelid = def.adrelid AND a.attnum = def.adnum) + LEFT JOIN pg_catalog.pg_description dsc ON (c.oid = dsc.objoid AND a.attnum = dsc.objsubid) + LEFT JOIN pg_catalog.pg_class dc ON (dc.oid = dsc.classoid AND dc.relname = 'pg_class') + LEFT JOIN pg_catalog.pg_namespace dn ON (dc.relnamespace = dn.oid AND dn.nspname = 'pg_catalog') +WHERE c.relkind in ('r', 'p', 'v', 'f', 'm') + and a.attnum > 0 + AND NOT a.attisdropped \ No newline at end of file diff --git a/src/main/resources/sql/get_schemas_query.sql b/src/main/resources/sql/get_schemas_query.sql new file mode 100644 index 0000000..21effa1 --- /dev/null +++ b/src/main/resources/sql/get_schemas_query.sql @@ -0,0 +1,7 @@ +SELECT nspname AS TABLE_SCHEM, NULL AS 
TABLE_CATALOG +FROM pg_catalog.pg_namespace +WHERE nspname <> 'pg_toast' + AND (nspname !~ '^pg_temp_' + OR nspname = (pg_catalog.current_schemas(true))[1]) + AND (nspname !~ '^pg_toast_temp_' + OR nspname = replace((pg_catalog.current_schemas(true))[1], 'pg_temp_', 'pg_toast_temp_')) \ No newline at end of file diff --git a/src/main/resources/sql/get_tables_query.sql b/src/main/resources/sql/get_tables_query.sql new file mode 100644 index 0000000..a58d42d --- /dev/null +++ b/src/main/resources/sql/get_tables_query.sql @@ -0,0 +1,44 @@ +SELECT NULL AS TABLE_CAT, + n.nspname AS TABLE_SCHEM, + c.relname AS TABLE_NAME, + CASE n.nspname ~ '^pg_' OR n.nspname = 'information_schema' WHEN true THEN CASE + WHEN n.nspname = 'pg_catalog' OR n.nspname = 'information_schema' THEN CASE c.relkind + WHEN 'r' THEN 'SYSTEM TABLE' + WHEN 'v' THEN 'SYSTEM VIEW' + WHEN 'i' THEN 'SYSTEM INDEX' + ELSE NULL +END +WHEN n.nspname = 'pg_toast' THEN CASE c.relkind + WHEN 'r' THEN 'SYSTEM TOAST TABLE' + WHEN 'i' THEN 'SYSTEM TOAST INDEX' + ELSE NULL +END +ELSE CASE c.relkind + WHEN 'r' THEN 'TEMPORARY TABLE' + WHEN 'p' THEN 'TEMPORARY TABLE' + WHEN 'i' THEN 'TEMPORARY INDEX' + WHEN 'S' THEN 'TEMPORARY SEQUENCE' + WHEN 'v' THEN 'TEMPORARY VIEW' + ELSE NULL +END +END +WHEN false THEN CASE c.relkind + WHEN 'r' THEN 'TABLE' + WHEN 'p' THEN 'PARTITIONED TABLE' + WHEN 'i' THEN 'INDEX' + WHEN 'P' then 'PARTITIONED INDEX' + WHEN 'S' THEN 'SEQUENCE' + WHEN 'v' THEN 'VIEW' + WHEN 'c' THEN 'TYPE' + WHEN 'f' THEN 'FOREIGN TABLE' + WHEN 'm' THEN 'MATERIALIZED VIEW' + ELSE NULL +END +ELSE NULL +END +AS TABLE_TYPE, d.description AS REMARKS, + '' as TYPE_CAT, '' as TYPE_SCHEM, '' as TYPE_NAME, + '' AS SELF_REFERENCING_COL_NAME, '' AS REF_GENERATION + FROM pg_catalog.pg_namespace n, pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_description d ON (c.oid = d.objoid AND d.objsubid = 0 and d.classoid = 'pg_class'::regclass) + WHERE c.relnamespace = n.oid \ No newline at end of file diff --git 
a/src/main/resources/version.properties b/src/main/resources/version.properties new file mode 100644 index 0000000..e5683df --- /dev/null +++ b/src/main/resources/version.properties @@ -0,0 +1 @@ +version=${project.version} \ No newline at end of file diff --git a/src/test/java/com/salesforce/datacloud/jdbc/DataCloudDatasourceTest.java b/src/test/java/com/salesforce/datacloud/jdbc/DataCloudDatasourceTest.java new file mode 100644 index 0000000..e18bf96 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/DataCloudDatasourceTest.java @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.SqlErrorCodes; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Properties; +import java.util.UUID; +import java.util.stream.Stream; +import lombok.val; +import org.assertj.core.api.AssertionsForClassTypes; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.MockedStatic; + +class DataCloudDatasourceTest { + private static final DataCloudDatasource dataCloudDatasource = new DataCloudDatasource(); + + @Test + void testGetConnectionReturnsInstanceOfConnection() throws SQLException { + val expectedProperties = new Properties(); + val connectionUrl = UUID.randomUUID().toString(); + val userName = UUID.randomUUID().toString(); + val password = UUID.randomUUID().toString(); + val privateKey = UUID.randomUUID().toString(); + val coreToken = UUID.randomUUID().toString(); + val refreshToken = UUID.randomUUID().toString(); + val clientId = UUID.randomUUID().toString(); + val clientSecret = UUID.randomUUID().toString(); + val internalEndpoint = UUID.randomUUID().toString(); + val port = UUID.randomUUID().toString(); + val tenantId = UUID.randomUUID().toString(); + val dataspace = UUID.randomUUID().toString(); + val coreTenantId = UUID.randomUUID().toString(); + + expectedProperties.setProperty("userName", userName); + expectedProperties.setProperty("password", password); + expectedProperties.setProperty("privateKey", privateKey); + expectedProperties.setProperty("refreshToken", refreshToken); + 
expectedProperties.setProperty("coreToken", coreToken); + expectedProperties.setProperty("clientId", clientId); + expectedProperties.setProperty("clientSecret", clientSecret); + expectedProperties.setProperty("internalEndpoint", internalEndpoint); + expectedProperties.setProperty("port", port); + expectedProperties.setProperty("tenantId", tenantId); + expectedProperties.setProperty("dataspace", dataspace); + expectedProperties.setProperty("coreTenantId", coreTenantId); + + val dataCloudDatasource = new DataCloudDatasource(); + dataCloudDatasource.setConnectionUrl(connectionUrl); + dataCloudDatasource.setUserName(userName); + dataCloudDatasource.setPassword(password); + dataCloudDatasource.setPrivateKey(privateKey); + dataCloudDatasource.setRefreshToken(refreshToken); + dataCloudDatasource.setCoreToken(coreToken); + dataCloudDatasource.setInternalEndpoint(internalEndpoint); + dataCloudDatasource.setPort(port); + dataCloudDatasource.setTenantId(tenantId); + dataCloudDatasource.setDataspace(dataspace); + dataCloudDatasource.setCoreTenantId(coreTenantId); + dataCloudDatasource.setClientId(clientId); + dataCloudDatasource.setClientSecret(clientSecret); + Connection mockConnection = mock(Connection.class); + + try (MockedStatic mockedDriverManager = mockStatic(DriverManager.class)) { + mockedDriverManager + .when(() -> DriverManager.getConnection(connectionUrl, expectedProperties)) + .thenReturn(mockConnection); + val connection = dataCloudDatasource.getConnection(); + assertThat(connection).isSameAs(mockConnection); + } + } + + @Test + void testGetConnectionWithUsernameAndPasswordReturnsInstanceOfConnection() throws SQLException { + val expectedProperties = new Properties(); + val connectionUrl = UUID.randomUUID().toString(); + val userName = UUID.randomUUID().toString(); + val password = UUID.randomUUID().toString(); + expectedProperties.setProperty("userName", userName); + expectedProperties.setProperty("password", password); + val dataCloudDatasource = new 
DataCloudDatasource(); + dataCloudDatasource.setConnectionUrl(connectionUrl); + Connection mockConnection = mock(Connection.class); + + try (MockedStatic mockedDriverManager = mockStatic(DriverManager.class)) { + mockedDriverManager + .when(() -> DriverManager.getConnection(connectionUrl, expectedProperties)) + .thenReturn(mockConnection); + val connection = dataCloudDatasource.getConnection(userName, password); + assertThat(connection).isSameAs(mockConnection); + } + } + + private static Stream unsupportedMethods() { + return Stream.of( + () -> dataCloudDatasource.setLoginTimeout(0), + () -> dataCloudDatasource.getLoginTimeout(), + () -> dataCloudDatasource.setLogWriter(null), + () -> dataCloudDatasource.getLogWriter(), + () -> dataCloudDatasource.getParentLogger()); + } + + @ParameterizedTest + @MethodSource("unsupportedMethods") + void throwsOnUnsupportedMethods(Executable func) { + val ex = Assertions.assertThrows(DataCloudJDBCException.class, func); + AssertionsForClassTypes.assertThat(ex) + .hasMessage("Datasource method is not supported in Data Cloud query") + .hasFieldOrPropertyWithValue("SQLState", SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Test + void unwrapMethodsActAsExpected() throws SQLException { + val dataCloudDatasource = new DataCloudDatasource(); + Assertions.assertNull(dataCloudDatasource.unwrap(DataCloudDatasource.class)); + Assertions.assertFalse(dataCloudDatasource.isWrapperFor(DataCloudDatasource.class)); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/DataCloudJDBCDriverTest.java b/src/test/java/com/salesforce/datacloud/jdbc/DataCloudJDBCDriverTest.java new file mode 100644 index 0000000..ef02dfa --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/DataCloudJDBCDriverTest.java @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType; + +import com.salesforce.datacloud.jdbc.config.DriverVersion; +import java.sql.Driver; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Properties; +import java.util.regex.Pattern; +import lombok.val; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class DataCloudJDBCDriverTest { + public static final String VALID_URL = "jdbc:salesforce-datacloud://login.salesforce.com"; + private static final String DRIVER_NAME = "salesforce-datacloud-jdbc"; + private static final String PRODUCT_NAME = "salesforce-datacloud-queryservice"; + private static final String PRODUCT_VERSION = "1.0"; + private final Pattern pattern = Pattern.compile("^\\d+(\\.\\d+)*(-SNAPSHOT)?$"); + + @Test + void testIsDriverRegisteredInDriverManager() throws Exception { + assertThat(DriverManager.getDriver(VALID_URL)).isNotNull().isInstanceOf(DataCloudJDBCDriver.class); + } + + @Test + void testNullUrlNotAllowedWhenConnecting() { + final Driver driver = new DataCloudJDBCDriver(); + Properties properties = new Properties(); + + assertThatExceptionOfType(SQLException.class).isThrownBy(() -> driver.connect(null, properties)); + } + + @Test + void testUnsupportedPrefixUrlNotAllowedWhenConnecting() throws Exception { + final Driver driver = new DataCloudJDBCDriver(); + Properties properties = new Properties(); + + 
assertThat(driver.connect("jdbc:mysql://localhost:3306", properties)).isNull(); + } + + @Test + void testInvalidPrefixUrlNotAccepted() throws Exception { + final Driver driver = new DataCloudJDBCDriver(); + + assertThat(driver.acceptsURL("jdbc:mysql://localhost:3306")).isFalse(); + } + + @Test + void testGetMajorVersion() { + final Driver driver = new DataCloudJDBCDriver(); + assertThat(driver.getMajorVersion()).isEqualTo(DriverVersion.getMajorVersion()); + } + + @Test + void testGetMinorVersion() { + final Driver driver = new DataCloudJDBCDriver(); + assertThat(driver.getMinorVersion()).isEqualTo(DriverVersion.getMinorVersion()); + } + + @Test + void testValidUrlPrefixAccepted() throws Exception { + final Driver driver = new DataCloudJDBCDriver(); + + assertThat(driver.acceptsURL(VALID_URL)).isTrue(); + } + + @Test + void testjdbcCompliant() { + final Driver driver = new DataCloudJDBCDriver(); + assertThat(driver.jdbcCompliant()).isFalse(); + } + + @Test + void testSuccessfulDriverVersion() { + Assertions.assertEquals(DRIVER_NAME, DriverVersion.getDriverName()); + Assertions.assertEquals(PRODUCT_NAME, DriverVersion.getProductName()); + Assertions.assertEquals(PRODUCT_VERSION, DriverVersion.getProductVersion()); + + val version = DriverVersion.getDriverVersion(); + assertThat(version) + .isNotBlank() + .matches(pattern) + .as("We expect this string to start with a digit, if this fails make sure you've run mvn compile"); + + val formattedDriverInfo = DriverVersion.formatDriverInfo(); + Assertions.assertEquals( + String.format("%s/%s", DRIVER_NAME, DriverVersion.getDriverVersion()), formattedDriverInfo); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/OrgIntegrationTest.java b/src/test/java/com/salesforce/datacloud/jdbc/OrgIntegrationTest.java new file mode 100644 index 0000000..1ea5a21 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/OrgIntegrationTest.java @@ -0,0 +1,368 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc; + +import static com.salesforce.datacloud.jdbc.core.StreamingResultSetTest.query; +import static com.salesforce.datacloud.jdbc.util.Constants.CONNECTION_PROTOCOL; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import com.salesforce.datacloud.jdbc.auth.AuthenticationSettings; +import com.salesforce.datacloud.jdbc.core.DataCloudConnection; +import com.salesforce.datacloud.jdbc.core.DataCloudResultSet; +import com.salesforce.datacloud.jdbc.core.DataCloudStatement; +import com.salesforce.datacloud.jdbc.core.StreamingResultSet; +import com.salesforce.datacloud.jdbc.util.ThrowingBiFunction; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.sql.Types; +import java.util.ArrayList; +import java.util.Calendar; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Properties; +import java.util.TimeZone; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import lombok.Getter; +import lombok.SneakyThrows; +import lombok.Value; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import 
org.junit.jupiter.api.condition.EnabledIf; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +/** + * To run this test, set the environment variables for the various AuthenticationSettings strategies. Right-click the + * play button and click "modify run configuration" and paste the following in "Environment Variables" Then you can + * click the little icon on the right side of the field to update the values appropriately. + * loginURL=login.salesforce.com/;userName=xyz@salesforce.com;password=...;clientId=...;clientSecret=...;query=SELECT + * "Description__c" FROM Account_Home__dll LIMIT 100 + */ +@Slf4j +@Value +@EnabledIf("validateProperties") +class OrgIntegrationTest { + static Properties getPropertiesFromEnvironment() { + val properties = new Properties(); + System.getenv().forEach(properties::setProperty); + return properties; + } + + static AuthenticationSettings getSettingsFromEnvironment() { + try { + return AuthenticationSettings.of(getPropertiesFromEnvironment()); + } catch (Exception e) { + return null; + } + } + + @Getter(lazy = true) + Properties properties = getPropertiesFromEnvironment(); + + @Getter(lazy = true) + AuthenticationSettings settings = getSettingsFromEnvironment(); + + private static final int NUM_THREADS = 100; + + @Test + @SneakyThrows + @Disabled + void testDatasource() { + val query = "SELECT * FROM Account_Home__dll LIMIT 100"; + val connectionUrl = CONNECTION_PROTOCOL + getSettings().getLoginUrl(); + Class.forName("com.salesforce.datacloud.jdbc.DataCloudJDBCDriver"); + DataCloudDatasource datasource = new DataCloudDatasource(); + datasource.setConnectionUrl(connectionUrl); + datasource.setUserName(getProperties().getProperty("userName")); + datasource.setPassword(getProperties().getProperty("password")); + datasource.setClientId(getProperties().getProperty("clientId")); + datasource.setClientSecret(getProperties().getProperty("clientSecret")); + + try (val connection = 
datasource.getConnection(); + val statement = connection.createStatement()) { + val resultSet = statement.executeQuery(query); + assertThat(resultSet.next()).isTrue(); + } + + assertThrows(SQLException.class, () -> datasource.getConnection("foo", "bar")); + } + + @Test + @SneakyThrows + @Disabled + void testMetadata() { + try (val connection = getConnection()) { + val tableName = getProperties().getProperty("tableName", "Account_Home__dll"); + ResultSet columnResultSet = connection.getMetaData().getColumns("", "public", tableName, null); + ResultSet tableResultSet = connection.getMetaData().getTables(null, null, "%", null); + ResultSet schemaResultSetWithCatalogAndSchemaPattern = + connection.getMetaData().getSchemas(null, "public"); + ResultSet schemaResultSet = connection.getMetaData().getSchemas(); + ResultSet tableTypesResultSet = connection.getMetaData().getTableTypes(); + ResultSet catalogsResultSet = connection.getMetaData().getCatalogs(); + ResultSet primaryKeys = connection.getMetaData().getPrimaryKeys(null, null, "Account_Home__dll"); + while (primaryKeys.next()) { + log.info("trying to print primary keys"); + } + + assertThat(columnResultSet.isClosed()) + .as("Query ResultSet was closed unexpectedly.") + .isFalse(); + assertThat(tableResultSet.isClosed()) + .as("Query ResultSet was closed unexpectedly.") + .isFalse(); + assertThat(schemaResultSetWithCatalogAndSchemaPattern.isClosed()) + .as("Query ResultSet was closed unexpectedly.") + .isFalse(); + assertThat(schemaResultSet.isClosed()) + .as("Query ResultSet was closed unexpectedly.") + .isFalse(); + assertThat(tableTypesResultSet.isClosed()) + .as("Query ResultSet was closed unexpectedly.") + .isFalse(); + assertThat(catalogsResultSet.isClosed()) + .as("Query ResultSet was closed unexpectedly.") + .isFalse(); + } + } + + @SneakyThrows + @ParameterizedTest + @MethodSource("com.salesforce.datacloud.jdbc.core.StreamingResultSetTest#queryModesWithMax") + public void exerciseQueryMode( + ThrowingBiFunction 
queryMode, int max) { + val sql = query(max); + try (val connection = getConnection(); + val statement = connection.createStatement().unwrap(DataCloudStatement.class)) { + val rs = queryMode.apply(statement, sql); + + assertThat(rs.isReady()).isTrue(); + assertThat(rs).isInstanceOf(StreamingResultSet.class); + + var expected = 0; + while (rs.next()) { + expected++; + } + + log.info("final value: {}", expected); + assertThat(expected).isEqualTo(max); + assertThat(rs.isClosed()) + .as("Query ResultSet was closed unexpectedly.") + .isFalse(); + } + } + + @SneakyThrows + private DataCloudConnection getConnection() { + val connectionUrl = CONNECTION_PROTOCOL + getSettings().getLoginUrl(); + log.info("Connection URL: {}", connectionUrl); + + Class.forName("com.salesforce.datacloud.jdbc.DataCloudJDBCDriver"); + val connection = DriverManager.getConnection(connectionUrl, getProperties()); + return (DataCloudConnection) connection; + } + + @SneakyThrows + @Test + @Disabled + public void testPreparedStatementExecuteWithParams() { + val query = + "SELECT \"Id__c\", \"AnnualRevenue__c\", \"LastModifiedDate__c\" FROM Account_Home__dll WHERE \"Id__c\" = ? AND \"AnnualRevenue__c\" = ? 
AND \"LastModifiedDate__c\" = ?"; + try (val connection = getConnection(); + val statement = connection.prepareStatement(query)) { + val id = "001SB00000K3pP4YAJ"; + val annualRevenue = 100000000; + val lastModifiedDate = Timestamp.valueOf("2024-06-10 05:07:52.0"); + statement.setString(1, id); + statement.setInt(2, annualRevenue); + statement.setTimestamp(3, lastModifiedDate); + + statement.execute(query); + val resultSet = statement.getResultSet(); + + val results = new ArrayList(); + Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("UTC")); + + while (resultSet.next()) { + val idResult = resultSet.getString(1); + val annualRevenueResult = resultSet.getInt(2); + val lastModifiedDateResult = resultSet.getTimestamp(3, cal); + log.info("{} : {} : {}", idResult, annualRevenueResult, lastModifiedDateResult); + assertThat(idResult).isEqualTo(id); + assertThat(annualRevenueResult).isEqualTo(annualRevenue); + assertThat(lastModifiedDateResult).isEqualTo(lastModifiedDate); + val row = resultSet.getRow(); + Optional.ofNullable(resultSet.getObject("Id__c")).ifPresent(t -> results.add(row + " - " + t)); + Optional.ofNullable(resultSet.getObject("AnnualRevenue__c")) + .ifPresent(t -> results.add(row + " - " + t)); + Optional.ofNullable(resultSet.getObject("LastModifiedDate__c")) + .ifPresent(t -> results.add(row + " - " + t)); + } + assertThat(results.stream().filter(t -> !Objects.isNull(t))).hasSizeGreaterThanOrEqualTo(0); + + assertThat(resultSet.isClosed()) + .as("Query ResultSet was closed unexpectedly.") + .isFalse(); + } + } + + @SneakyThrows + @Test + @Disabled + public void testPreparedStatementGetResultSetNoParams() { + val query = "SELECT \"Id__c\", \"AnnualRevenue__c\", \"LastModifiedDate__c\" FROM Account_Home__dll LIMIT 100"; + try (val connection = getConnection(); + val statement = connection.prepareStatement(query)) { + statement.execute(query); + val resultSet = statement.getResultSet(); + + val results = new ArrayList(); + + while (resultSet.next()) 
{ + val row = resultSet.getRow(); + val resultFromColumnIndex = resultSet.getString(1); + val resultFromColumnName = resultSet.getString("Id__c"); + val resultFromColumnIndex2 = resultSet.getString(2); + val resultFromColumnName2 = resultSet.getString("AnnualRevenue__c"); + val resultFromColumnIndex3 = resultSet.getString(3); + val resultFromColumnName3 = resultSet.getString("LastModifiedDate__c"); + assertThat(resultFromColumnIndex).isEqualTo(resultFromColumnName); + assertThat(resultFromColumnIndex2).isEqualTo(resultFromColumnName2); + assertThat(resultFromColumnIndex3).isEqualTo(resultFromColumnName3); + Optional.ofNullable(resultSet.getObject("Id__c")).ifPresent(t -> results.add(row + " - " + t)); + Optional.ofNullable(resultSet.getObject("AnnualRevenue__c")) + .ifPresent(t -> results.add(row + " - " + t)); + Optional.ofNullable(resultSet.getObject("LastModifiedDate__c")) + .ifPresent(t -> results.add(row + " - " + t)); + } + assertThat(results.stream().filter(t -> !Objects.isNull(t))).hasSizeGreaterThanOrEqualTo(1); + + assertThat(resultSet.isClosed()) + .as("Query ResultSet was closed unexpectedly.") + .isFalse(); + } + } + + @SneakyThrows + @Test + @Disabled + public void testPreparedStatementExecuteQueryNoParams() { + val query = "SELECT \"Id__c\", \"AnnualRevenue__c\", \"LastModifiedDate__c\" FROM Account_Home__dll LIMIT 100"; + try (val connection = getConnection(); + val statement = connection.prepareStatement(query)) { + + val resultSet = statement.executeQuery(query); + + val results = new ArrayList(); + + while (resultSet.next()) { + val row = resultSet.getRow(); + val resultFromColumnIndex = resultSet.getString(1); + val resultFromColumnName = resultSet.getString("Id__c"); + assertThat(resultFromColumnIndex).isEqualTo(resultFromColumnName); + Optional.ofNullable(resultSet.getObject("Id__c")).ifPresent(t -> results.add(row + " - " + t)); + Optional.ofNullable(resultSet.getObject("AnnualRevenue__c")) + .ifPresent(t -> results.add(row + " - " + t)); + 
Optional.ofNullable(resultSet.getObject("LastModifiedDate__c")) + .ifPresent(t -> results.add(row + " - " + t)); + } + assertThat(results.stream().filter(t -> !Objects.isNull(t))).hasSizeGreaterThanOrEqualTo(1); + + assertThat(resultSet.isClosed()) + .as("Query ResultSet was closed unexpectedly.") + .isFalse(); + } + } + + @Test + @SneakyThrows + @Disabled + void testArrowFieldConversion() { + Map queries = new HashMap<>(); + queries.put(Types.BOOLEAN, "SELECT 5 > 100 AS \"boolean_output\""); + queries.put(Types.VARCHAR, "SELECT 'a test string' as \"string_column\""); + queries.put(Types.DATE, "SELECT current_date"); + queries.put(Types.TIME, "SELECT current_time"); + queries.put(Types.TIMESTAMP, "SELECT current_timestamp"); + queries.put(Types.DECIMAL, "SELECT 82.3 as \"decimal_column\""); + queries.put(Types.INTEGER, "SELECT 82 as \"Integer_column\""); + try (val connection = getConnection(); + val statement = connection.createStatement()) { + for (var entry : queries.entrySet()) { + val resultSet = statement.executeQuery(entry.getValue().toString()); + val metadata = resultSet.getMetaData(); + log.info("columntypename: {}", metadata.getColumnTypeName(1)); + log.info("columntype: {}", Integer.toString(metadata.getColumnType(1))); + assertEquals( + Integer.toString(metadata.getColumnType(1)), + entry.getKey().toString()); + } + } + } + + @Test + @SneakyThrows + @Disabled + void testMultiThreadedAuth() { + ExecutorService executor = Executors.newFixedThreadPool(NUM_THREADS); + + for (int i = 0; i < NUM_THREADS; i++) { + executor.submit(this::testMainQuery); + } + + executor.shutdown(); + } + + @SneakyThrows + void testMainQuery() { + int max = 100; + val query = query(100); + + try (val connection = getConnection(); + val statement = connection.createStatement().unwrap(DataCloudStatement.class)) { + + log.info("Begin executeQuery"); + long startTime = System.currentTimeMillis(); + ResultSet resultSet = statement.executeAdaptiveQuery(query); + log.info("Query 
executed in {}ms", System.currentTimeMillis() - startTime); + + var expected = 0; + while (resultSet.next()) { + expected++; + } + + log.info("final value: {}", expected); + assertThat(expected).isEqualTo(max); + assertThat(resultSet.isClosed()) + .as("Query ResultSet was closed unexpectedly.") + .isFalse(); + } + } + + static boolean validateProperties() { + AuthenticationSettings getSettings = getSettingsFromEnvironment(); + return getSettings != null; + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/ResponseEnum.java b/src/test/java/com/salesforce/datacloud/jdbc/ResponseEnum.java new file mode 100644 index 0000000..2b4be73 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/ResponseEnum.java @@ -0,0 +1,319 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc; + +public enum ResponseEnum { + INTERNAL_SERVER_ERROR("{\n" + + " \"timestamp\": \"2021-01-08T11:53:29.668+0000\",\n" + + " \"error\": \"Internal Server Error\",\n" + + " \"message\": \"Internal Server Error\",\n" + + " \"internalErrorCode\": \"COMMON_ERROR_GENERIC\",\n" + + " \"details\": {\n" + + " \"status\": \"Internal Server Error\",\n" + + " \"statusCode\": 500\n" + + " }\n" + + "}"), + + TABLE_METADATA("{\n" + + " \"metadata\": [\n" + + " {\n" + + " \"fields\": [\n" + + " {\n" + + " \"name\": \"DataSourceId__c\",\n" + + " \"type\": \"STRING\"\n" + + " },\n" + + " {\n" + + " \"name\": \"DataSourceObjectId__c\",\n" + + " \"type\": \"STRING\"\n" + + " },\n" + + " {\n" + + " \"name\": \"EmailAddress__c\",\n" + + " \"type\": \"STRING\"\n" + + " },\n" + + " {\n" + + " \"name\": \"Id__c\",\n" + + " \"type\": \"STRING\"\n" + + " },\n" + + " {\n" + + " \"name\": \"InternalOrganizationId__c\",\n" + + " \"type\": \"STRING\"\n" + + " },\n" + + " {\n" + + " \"name\": \"active__c\",\n" + + " \"type\": \"BOOLEAN\"\n" + + " },\n" + + " {\n" + + " \"name\": \"PartyId__c\",\n" + + " \"type\": \"STRING\"\n" + + " }\n" + + " ],\n" + + " \"category\": \"Profile\",\n" + + " \"name\": \"ContactPointEmail__dlm\",\n" + + " \"relationships\": [\n" + + " {\n" + + " \"fromEntity\": \"ContactPointEmail__dlm\",\n" + + " \"toEntity\": \"ContactPointEmailIdentityLink__dlm\",\n" + + " \"fromEntityAttribute\": \"Id__c\",\n" + + " \"toEntityAttribute\": \"SourceRecordId__c\",\n" + + " \"cardinality\": \"ONETOONE\"\n" + + " },\n" + + " {\n" + + " \"fromEntity\": \"ContactPointEmail__dlm\",\n" + + " \"toEntity\": \"Individual__dlm\",\n" + + " \"fromEntityAttribute\": \"PartyId__c\",\n" + + " \"toEntityAttribute\": \"Id__c\",\n" + + " \"cardinality\": \"NTOONE\"\n" + + " }\n" + + " ]\n" + + " }\n" + + " ]\n" + + "}"), + NOT_FOUND("{\n" + + " \"timestamp\": \"2021-01-08T11:53:29.668+0000\",\n" + + " \"error\": \"Not Found\",\n" + + " \"message\": \"Not 
Found\",\n" + + "}"), + UNAUTHORIZED("{\n" + + " \"timestamp\": \"2021-01-08T11:53:29.668+0000\",\n" + + " \"error\": \"Unauthorized\",\n" + + " \"message\": \"Authorization header verification failed\",\n" + + " \"internalErrorCode\": \"COMMON_ERROR_GENERIC\",\n" + + " \"details\": {\n" + + " \"status\": \"UNAUTHENTICATED: Invalid JWT (not before or expired)\",\n" + + " \"statusCode\": 401\n" + + " }\n" + + "}"), + QUERY_RESPONSE("{\n" + + " \"data\": [\n" + + " {\n" + + " \"telephonenumber__c\": \"001 5483188\"\n" + + " },\n" + + " {\n" + + " \"telephonenumber__c\": \"001 3205512\"\n" + + " }]," + + " \"startTime\": \"2021-01-11T05:34:34.040931Z\",\n" + + " \"endTime\": \"2021-01-11T05:34:34.040981Z\",\n" + + " \"rowCount\": 2,\n" + + " \"queryId\": \"53c66a0f-e666-4f61-9f84-7718528b7a63\",\n" + + " \"done\": true," + + " \"metadata\": {\n" + + " \"telephonenumber__c\": {\n" + + " \"placeInOrder\": 0,\n" + + " \"typeCode\": 12,\n" + + " \"type\": \"VARCHAR\"\n" + + " }\n" + + "} \n" + + "}"), + QUERY_RESPONSE_V2("{\n" + + " \"data\": [\n" + + " [\n" + + " \"00034d6c-f5b4-348a-9fc2-7707d5b07dba\",\n" + + " \"Larae\"\n" + + " ]\n" + + " ],\n" + + " \"startTime\": \"2021-09-21T10:38:55.520428Z\",\n" + + " \"endTime\": \"2021-09-21T10:39:02.995939Z\",\n" + + " \"rowCount\": 1,\n" + + " \"queryId\": \"20210921_103858_00784_rpkgk\",\n" + + " \"nextBatchId\": \"\",\n" + + " \"done\": true,\n" + + " \"metadata\": {\n" + + " \"Id__c\": {\n" + + " \"type\": \"VARCHAR\",\n" + + " \"placeInOrder\": 0,\n" + + " \"typeCode\": 12\n" + + " },\n" + + " \"FirstName__c\": {\n" + + " \"type\": \"VARCHAR\",\n" + + " \"placeInOrder\": 1,\n" + + " \"typeCode\": 12\n" + + " }\n" + + " }\n" + + "}"), + PAGINATED_RESPONSE_V2("{\n" + + " \"data\": [\n" + + " [\n" + + " \"00034d6c-f5b4-348a-9fc2-7707d5b07dba\",\n" + + " \"Larae\"\n" + + " ]\n" + + " ],\n" + + " \"startTime\": \"2021-09-21T10:38:55.520428Z\",\n" + + " \"endTime\": \"2021-09-21T10:39:02.995939Z\",\n" + + " \"rowCount\": 1,\n" 
+ + " \"queryId\": \"20210921_103858_00784_rpkgk\",\n" + + " \"nextBatchId\": \"f98c7bcd-b1bd-4e8d-b98d-11aabdd6c604\",\n" + + " \"done\": false,\n" + + " \"metadata\": {\n" + + " \"Id__c\": {\n" + + " \"type\": \"VARCHAR\",\n" + + " \"placeInOrder\": 0,\n" + + " \"typeCode\": 12\n" + + " },\n" + + " \"FirstName__c\": {\n" + + " \"type\": \"VARCHAR\",\n" + + " \"placeInOrder\": 1,\n" + + " \"typeCode\": 12\n" + + " }\n" + + " }\n" + + "}"), + EMPTY_RESPONSE("{\n" + + " \"data\": []," + + " \"startTime\": \"2021-01-11T05:34:34.040931Z\",\n" + + " \"endTime\": \"2021-01-11T05:34:34.040981Z\",\n" + + " \"rowCount\": 0,\n" + + " \"queryId\": \"53c66a0f-e666-4f61-9f84-7718528b7a63\",\n" + + " \"done\": true}"), + TOKEN_EXCHANGE("{\n" + + " \"access_token\": \"1234.eyJhdWRpZW5jZVRlbmFudElkIjoiYTM2MC9mYWxjb25kZXYvMTIzNCJ9.5678\",\n" + + " \"instance_url\": \"abcd\",\n" + + " \"token_type\": \"Bearer\",\n" + + " \"issued_token_type\": \"urn:ietf:params:oauth:token-type:jwt\",\n" + + " \"expires_in\": 7193}"), + TOO_MANY_REQUESTS("{\n" + + " \"timestamp\": \"2021-01-08T11:53:29.668+0000\",\n" + + " \"error\": \"Too many requests\",\n" + + " \"message\": \"Too many requests\",\n" + + "}"), + PAGINATION_RESPONSE("{\n" + + " \"data\": [\n" + + " {\n" + + " \"telephonenumber__c\": \"001 6723687\"\n" + + " },\n" + + " {\n" + + " \"telephonenumber__c\": \"001 9387489\"\n" + + " }]," + + " \"startTime\": \"2021-01-11T05:34:34.040931Z\",\n" + + " \"endTime\": \"2021-01-11T05:34:34.040981Z\",\n" + + " \"rowCount\": 2,\n" + + " \"queryId\": \"53c66a0f-e666-4f61-9f84-7718528b7a63\",\n" + + " \"done\": false," + + " \"metadata\": {\n" + + " \"telephonenumber__c\": {\n" + + " \"placeInOrder\": 0,\n" + + " \"typeCode\": 12,\n" + + " \"type\": \"VARCHAR\"\n" + + " }\n" + + "} \n" + + "}"), + RENEWED_CORE_TOKEN("{\n" + + " \"access_token\": \"00DR0000000KvIt\",\n" + + " \"signature\": \"0w96S+=\",\n" + + " \"scope\": \"refresh_token cdpquery api cdpprofile\",\n" + + " \"instance_url\": 
\"https://flash232cdpusercom.my.stmpa.stm.salesforce.com\",\n" + + " \"id\": \"https://login.stmpa.stm.salesforce.com/id/00DR0000000KvItMAK/005R0000000pgIjIAI\",\n" + + " \"token_type\": \"Bearer\",\n" + + " \"issued_at\": \"1611569641915\"\n" + + "}"), + OAUTH_TOKEN_ERROR("{\n" + + " \"error\": \"invalid_grant\",\n" + + " \"error_description\": \"expired authorization code\"\n" + + "}"), + QUERY_RESPONSE_WITHOUT_DONE_FLAG("{\n" + + " \"data\": [\n" + + " {\n" + + " \"telephonenumber__c\": \"001 6723687\"\n" + + " },\n" + + " {\n" + + " \"telephonenumber__c\": \"001 9387489\"\n" + + " }]," + + " \"startTime\": \"2021-01-11T05:34:34.040931Z\",\n" + + " \"endTime\": \"2021-01-11T05:34:34.040981Z\",\n" + + " \"rowCount\": 2,\n" + + " \"queryId\": \"53c66a0f-e666-4f61-9f84-7718528b7a63\"," + + " \"metadata\": {\n" + + " \"telephonenumber__c\": {\n" + + " \"placeInOrder\": 0,\n" + + " \"typeCode\": 12,\n" + + " \"type\": \"VARCHAR\"\n" + + " }\n" + + "} \n" + + "}"), + QUERY_RESPONSE_WITH_METADATA("{\n" + + " \"data\": [\n" + + " {\n" + + " \"count_num\": \"10\"\n" + + " }]," + + " \"startTime\": \"2021-01-11T05:34:34.040931Z\",\n" + + " \"endTime\": \"2021-01-11T05:34:34.040981Z\",\n" + + " \"rowCount\": 1,\n" + + " \"queryId\": \"53c66a0f-e666-4f61-9f84-7718528b7a63\",\n" + + " \"done\": true,\n" + + " \"metadata\": {\n" + + " \"count_num\": {\n" + + " \"placeInOrder\": 0,\n" + + " \"typeCode\": 3,\n" + + " \"type\": \"DECIMAL\"\n" + + " }\n" + + "} \n" + + "}"), + HTML_ERROR_RESPONSE( + "\n" + + "\n" + + "\n" + + "\n" + + "\t\n" + + "\n" + + "\n" + + "\n" + + "\n" + + "\n" + + "\t\n" + + "\n" + + "\n" + + "\n" + + "\n" + + ""), + DATASPACE_RESPONSE("{\n" + + " \"totalSize\": 2,\n" + + " \"done\": true,\n" + + " \"records\": [\n" + + " {\n" + + " \"attributes\": {\n" + + " \"type\": \"DataSpace\",\n" + + " \"url\": \"/services/data/v60.0/sobjects/DataSpace/0vhVF00000003AnYAI\"\n" + + " },\n" + + " \"Name\": \"default\"\n" + + " },\n" + + " {\n" + + " \"attributes\": 
{\n" + + " \"type\": \"DataSpace\",\n" + + " \"url\": \"/services/data/v60.0/sobjects/DataSpace/0vhVF0000000Ch7YAE\"\n" + + " },\n" + + " \"Name\": \"DS2\"\n" + + " }\n" + + " ]\n" + + "}"); + + private final String response; + + ResponseEnum(String response) { + this.response = response; + } + + public String getResponse() { + return response; + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/auth/AuthenticationSettingsTest.java b/src/test/java/com/salesforce/datacloud/jdbc/auth/AuthenticationSettingsTest.java new file mode 100644 index 0000000..999cf07 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/auth/AuthenticationSettingsTest.java @@ -0,0 +1,209 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.auth; + +import static com.salesforce.datacloud.jdbc.auth.PropertiesUtils.allPropertiesExcept; +import static com.salesforce.datacloud.jdbc.auth.PropertiesUtils.propertiesForPassword; +import static com.salesforce.datacloud.jdbc.auth.PropertiesUtils.propertiesForPrivateKey; +import static com.salesforce.datacloud.jdbc.auth.PropertiesUtils.propertiesForRefreshToken; +import static com.salesforce.datacloud.jdbc.auth.PropertiesUtils.randomString; +import static com.salesforce.datacloud.jdbc.util.ThrowingFunction.rethrowFunction; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; +import java.util.function.Function; +import java.util.stream.Stream; +import lombok.SneakyThrows; +import lombok.val; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.SoftAssertions; +import org.assertj.core.api.junit.jupiter.InjectSoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +@ExtendWith(SoftAssertionsExtension.class) +public class AuthenticationSettingsTest { + private static AuthenticationSettings sut(Properties properties) throws SQLException { + return AuthenticationSettings.of(properties); + } + + @InjectSoftAssertions + SoftAssertions softly; + + @SneakyThrows + private static Stream constructors() { + List properties = Arrays.asList(null, new Properties()); + List> ctors = List.of( + rethrowFunction(AuthenticationSettings::of), + 
rethrowFunction(PasswordAuthenticationSettings::new), + rethrowFunction(PrivateKeyAuthenticationSettings::new), + rethrowFunction(RefreshTokenAuthenticationSettings::new)); + + return ctors.stream().flatMap(c -> properties.stream().map(p -> Arguments.of(p, c))); + } + + @ParameterizedTest + @MethodSource("constructors") + void ofWithNullProperties(Properties p, Function ctor) { + if (p == null) { + val expectedMessage = AuthenticationSettings.Messages.PROPERTIES_NULL; + val expectedException = IllegalArgumentException.class; + val e = assertThrows(expectedException, () -> ctor.apply(p)); + softly.assertThat((Throwable) e).hasMessage(expectedMessage).hasNoCause(); + } else { + val expectedMessage = AuthenticationSettings.Messages.PROPERTIES_EMPTY; + val expectedException = DataCloudJDBCException.class; + val e = assertThrows(expectedException, () -> ctor.apply(p)); + softly.assertThat((Throwable) e) + .hasMessage(expectedMessage) + .hasCause(new IllegalArgumentException(expectedMessage)); + } + } + + @Test + void ofWithNoneOfTheRequiredProperties() { + val p = allPropertiesExcept( + AuthenticationSettings.Keys.PRIVATE_KEY, + AuthenticationSettings.Keys.PASSWORD, + AuthenticationSettings.Keys.REFRESH_TOKEN); + val e = assertThrows(DataCloudJDBCException.class, () -> sut(p)); + assertThat((Throwable) e) + .hasMessage(AuthenticationSettings.Messages.PROPERTIES_MISSING) + .hasNoCause(); + } + + @Test + @SneakyThrows + void ofWithPassword() { + val password = randomString(); + val userName = randomString(); + val p = propertiesForPassword(userName, password); + + val sut = sut(p); + + assertThat(sut).isInstanceOf(PasswordAuthenticationSettings.class); + assertThat(((PasswordAuthenticationSettings) sut).getPassword()).isEqualTo(password); + } + + @Test + @SneakyThrows + void ofWithPrivateKey() { + val privateKey = randomString(); + val p = propertiesForPrivateKey(privateKey); + + val sut = sut(p); + + assertThat(sut).isInstanceOf(PrivateKeyAuthenticationSettings.class); + 
assertThat(((PrivateKeyAuthenticationSettings) sut).getPrivateKey()).isEqualTo(privateKey); + } + + @Test + @SneakyThrows + void ofWithRefreshToken() { + val refreshToken = randomString(); + val p = propertiesForRefreshToken(refreshToken); + + val sut = sut(p); + + assertThat(sut).isInstanceOf(RefreshTokenAuthenticationSettings.class); + assertThat((RefreshTokenAuthenticationSettings) sut) + .satisfies(s -> assertThat(s.getRefreshToken()).isEqualTo(refreshToken)); + } + + @Test + @SneakyThrows + void getRelevantPropertiesFiltersUnexpectedProperties() { + val p = allPropertiesExcept(); + p.setProperty("unexpected", randomString()); + + val sut = sut(p); + + assertThat(sut.getRelevantProperties().containsKey("unexpected")).isFalse(); + } + + @Test + @SneakyThrows + void baseAuthenticationOptionalSettingsGettersReturnDefaultValues() { + val p = allPropertiesExcept( + AuthenticationSettings.Keys.USER_AGENT, + AuthenticationSettings.Keys.DATASPACE, + AuthenticationSettings.Keys.MAX_RETRIES); + val sut = sut(p); + + assertThat(sut) + .returns( + AuthenticationSettings.Defaults.USER_AGENT, + Assertions.from(AuthenticationSettings::getUserAgent)) + .returns( + AuthenticationSettings.Defaults.MAX_RETRIES, + Assertions.from(AuthenticationSettings::getMaxRetries)) + .returns( + AuthenticationSettings.Defaults.DATASPACE, + Assertions.from(AuthenticationSettings::getDataspace)); + } + + @Test + @SneakyThrows + void baseAuthenticationSettingsGettersReturnCorrectValues() { + val loginUrl = randomString(); + val userName = randomString(); + val clientId = randomString(); + val clientSecret = randomString(); + val dataspace = randomString(); + val userAgent = randomString(); + val maxRetries = 123; + + val p = allPropertiesExcept(); + p.put(AuthenticationSettings.Keys.LOGIN_URL, loginUrl); + p.put(AuthenticationSettings.Keys.USER_NAME, userName); + p.put(AuthenticationSettings.Keys.CLIENT_ID, clientId); + p.put(AuthenticationSettings.Keys.CLIENT_SECRET, clientSecret); + 
p.put(AuthenticationSettings.Keys.DATASPACE, dataspace); + p.put(AuthenticationSettings.Keys.USER_AGENT, userAgent); + p.put(AuthenticationSettings.Keys.MAX_RETRIES, Integer.toString(maxRetries)); + + val sut = sut(p); + + assertThat(sut) + .returns(loginUrl, Assertions.from(AuthenticationSettings::getLoginUrl)) + .returns(clientId, Assertions.from(AuthenticationSettings::getClientId)) + .returns(clientSecret, Assertions.from(AuthenticationSettings::getClientSecret)) + .returns(userAgent, Assertions.from(AuthenticationSettings::getUserAgent)) + .returns(maxRetries, Assertions.from(AuthenticationSettings::getMaxRetries)) + .returns(dataspace, Assertions.from(AuthenticationSettings::getDataspace)); + } + + @Test + @SneakyThrows + void baseAuthenticationSettingsRequiredSettingsThrow() { + AuthenticationSettings.Keys.REQUIRED_KEYS.forEach(k -> { + val p = allPropertiesExcept(k); + val e = assertThrows(DataCloudJDBCException.class, () -> sut(p)); + assertThat((Throwable) e) + .hasMessage(AuthenticationSettings.Messages.PROPERTIES_REQUIRED + k) + .hasCause(new IllegalArgumentException(AuthenticationSettings.Messages.PROPERTIES_REQUIRED + k)); + }); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/auth/AuthenticationStrategyTest.java b/src/test/java/com/salesforce/datacloud/jdbc/auth/AuthenticationStrategyTest.java new file mode 100644 index 0000000..c74272c --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/auth/AuthenticationStrategyTest.java @@ -0,0 +1,267 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.auth; + +import static com.salesforce.datacloud.jdbc.auth.PropertiesUtils.propertiesForPassword; +import static com.salesforce.datacloud.jdbc.auth.PropertiesUtils.propertiesForPrivateKey; +import static com.salesforce.datacloud.jdbc.auth.PropertiesUtils.propertiesForRefreshToken; +import static com.salesforce.datacloud.jdbc.util.ThrowingFunction.rethrowFunction; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.http.FormCommand; +import java.net.URI; +import java.net.URISyntaxException; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Properties; +import java.util.UUID; +import java.util.function.Function; +import java.util.stream.Stream; +import lombok.SneakyThrows; +import lombok.val; +import org.assertj.core.api.SoftAssertions; +import org.assertj.core.api.junit.jupiter.InjectSoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +@ExtendWith(SoftAssertionsExtension.class) +class AuthenticationStrategyTest { + private static final URI LOGIN; + private static final URI INSTANCE; + + static { + try { + LOGIN = new 
URI("login.test1.pc-rnd.salesforce.com"); + INSTANCE = new URI("https://valid-instance.salesforce.com"); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + private static final String JWT_GRANT = "urn:ietf:params:oauth:grant-type:jwt-bearer"; + + @InjectSoftAssertions + SoftAssertions softly; + + @SneakyThrows + static Stream settings() { + Function from = rethrowFunction(AuthenticationSettings::of); + + return Stream.of( + Arguments.of( + from.apply(propertiesForPrivateKey(PrivateKeyHelpersTest.fakePrivateKey)), + PrivateKeyAuthenticationStrategy.class), + Arguments.of(from.apply(propertiesForPassword("un", "pw")), PasswordAuthenticationStrategy.class), + Arguments.of(from.apply(propertiesForRefreshToken("rt")), RefreshTokenAuthenticationStrategy.class)); + } + + @ParameterizedTest + @MethodSource("settings") + @SneakyThrows + void ofSettingsProperlyMapsTypes(AuthenticationSettings settings, Class type) { + assertThat(AuthenticationStrategy.of(settings)).isInstanceOf(type); + } + + @Test + @SneakyThrows + void ofSettingsProperlyThrowsOnUnknown() { + val anonymous = new AuthenticationSettings(propertiesForPassword("un", "pw")) {}; + val e = assertThrows(DataCloudJDBCException.class, () -> AuthenticationStrategy.of(anonymous)); + assertThat((Throwable) e) + .hasMessage(AuthenticationStrategy.Messages.UNKNOWN_SETTINGS_TYPE) + .hasCause(new IllegalArgumentException(AuthenticationStrategy.Messages.UNKNOWN_SETTINGS_TYPE)); + } + + @Test + @SneakyThrows + void givenTokenSharedStrategyCreatesCorrectRevokeCommand() { + val passwordProperties = propertiesForPassword("un", "pw"); + val privateKeyProperties = propertiesForPrivateKey("pw"); + val refreshTokenProperties = propertiesForRefreshToken("rt"); + + val props = new Properties[] {passwordProperties, privateKeyProperties, refreshTokenProperties}; + Arrays.stream(props).forEach(p -> { + val token = new OAuthToken("token", INSTANCE); + + FormCommand revokeCommand = null; + try { + revokeCommand 
= RevokeTokenAuthenticationStrategy.of(AuthenticationSettings.of(p), token) + .toCommand(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + assertThat(revokeCommand.getUrl()).isEqualTo(INSTANCE); + assertThat(revokeCommand.getSuffix()).isEqualTo(URI.create("services/oauth2/revoke")); + assertThat(revokeCommand.getBodyEntries()).containsKeys("token").containsEntry("token", token.getToken()); + }); + } + + @Test + @SneakyThrows + void givenTokenSharedStrategyCreatesCorrectExchangeCommand() { + val passwordProperties = propertiesForPassword("un", "pw"); + val privateKeyProperties = propertiesForPrivateKey("pw"); + val refreshTokenProperties = propertiesForRefreshToken("rt"); + + val props = new Properties[] {passwordProperties, privateKeyProperties, refreshTokenProperties}; + Arrays.stream(props).forEach(p -> { + val token = new OAuthToken("token", INSTANCE); + + FormCommand exchangeCommand = null; + try { + exchangeCommand = ExchangeTokenAuthenticationStrategy.of(AuthenticationSettings.of(p), token) + .toCommand(); + } catch (SQLException e) { + throw new RuntimeException(e); + } + assertThat(exchangeCommand.getUrl()).isEqualTo(INSTANCE); + assertThat(exchangeCommand.getSuffix()).isEqualTo(URI.create("services/a360/token")); + assertThat(exchangeCommand.getBodyEntries()) + .containsKeys("grant_type", "subject_token_type", "subject_token") + .containsEntry("grant_type", "urn:salesforce:grant-type:external:cdp") + .containsEntry("subject_token_type", "urn:ietf:params:oauth:token-type:access_token") + .containsEntry("subject_token", token.getToken()); + }); + } + + @SneakyThrows + @Test + void privateKeyCreatesCorrectCommand() { + val p = propertiesForPrivateKey(PrivateKeyHelpersTest.fakePrivateKey); + + val actual = AuthenticationStrategy.of(p).buildAuthenticate(); + + assertThat(actual.getUrl()).isEqualTo(LOGIN); + assertThat(actual.getSuffix()).isEqualTo(URI.create("services/oauth2/token")); + assertThat(actual.getBodyEntries()) + 
.containsKeys("grant_type", "assertion") + .containsEntry("grant_type", JWT_GRANT); + PrivateKeyHelpersTest.shouldYieldJwt(actual.getBodyEntries().get("assertion"), PrivateKeyHelpersTest.fakeJwt); + } + + @SneakyThrows + @Test + void passwordCreatesCorrectCommand() { + val password = UUID.randomUUID().toString(); + val userName = UUID.randomUUID().toString(); + val p = propertiesForPassword(userName, password); + val actual = AuthenticationStrategy.of(p).buildAuthenticate(); + + assertThat(actual.getUrl()).isEqualTo(LOGIN); + assertThat(actual.getSuffix()).isEqualTo(URI.create("services/oauth2/token")); + assertThat(actual.getBodyEntries()) + .containsKeys("grant_type", "username", "password", "client_id", "client_secret") + .containsEntry("grant_type", "password") + .containsEntry("username", p.getProperty("userName")) + .containsEntry("password", password); + } + + @SneakyThrows + @Test + void refreshTokenCreatesCorrectCommand() { + val refreshToken = UUID.randomUUID().toString(); + val p = propertiesForRefreshToken(refreshToken); + val actual = AuthenticationStrategy.of(p).buildAuthenticate(); + + assertThat(actual.getUrl()).isEqualTo(LOGIN); + assertThat(actual.getSuffix()).isEqualTo(URI.create("services/oauth2/token")); + assertThat(actual.getBodyEntries()) + .containsKeys("grant_type", "refresh_token", "client_id", "client_secret") + .containsEntry("grant_type", "refresh_token") + .containsEntry("refresh_token", refreshToken); + } + + static Stream grantTypeExpectations() { + return Stream.of( + Arguments.of(JWT_GRANT, propertiesForPrivateKey(PrivateKeyHelpersTest.fakePrivateKey)), + Arguments.of("password", propertiesForPassword("un", "pw")), + Arguments.of("refresh_token", propertiesForRefreshToken("rt"))); + } + + @SneakyThrows + @ParameterizedTest + @MethodSource("grantTypeExpectations") + void allIncludeGrantType(String expectedGrantType, Properties properties) { + val actual = AuthenticationStrategy.of(properties).buildAuthenticate().getBodyEntries(); + 
assertThat(actual).containsEntry("grant_type", expectedGrantType); + } + + static Stream sharedAuthenticationSettings() { + return Stream.of(propertiesForPassword("un", "pw"), propertiesForRefreshToken("rt")); + } + + @SneakyThrows + @ParameterizedTest + @MethodSource("sharedAuthenticationSettings") + void allIncludeSharedSettings(Properties properties) { + + val clientId = properties.getProperty(AuthenticationSettings.Keys.CLIENT_ID); + val clientSecret = properties.getProperty(AuthenticationSettings.Keys.CLIENT_SECRET); + + val actual = AuthenticationStrategy.of(properties).buildAuthenticate().getBodyEntries(); + softly.assertThat(actual) + .containsEntry(AuthenticationStrategy.Keys.CLIENT_ID, clientId) + .containsEntry(AuthenticationStrategy.Keys.CLIENT_SECRET, clientSecret); + } + + static Stream allAuthenticationSettings() { + return Stream.of(propertiesForPassword("un", "pw"), propertiesForRefreshToken("rt")); + } + + @SneakyThrows + @Test + void exchangeTokenAuthenticationStrategyIncludesDataspaceOptionally() { + val properties = propertiesForPassword("un", "pw"); + val key = AuthenticationSettings.Keys.DATASPACE; + properties.remove(key); + + val token = new OAuthToken("token", INSTANCE); + + val none = ExchangeTokenAuthenticationStrategy.of(AuthenticationSettings.of(properties), token) + .toCommand(); + softly.assertThat(none.getBodyEntries()).doesNotContainKey(key); + + val dataspace = UUID.randomUUID().toString(); + properties.put(key, dataspace); + + val some = ExchangeTokenAuthenticationStrategy.of(AuthenticationSettings.of(properties), token) + .toCommand(); + softly.assertThat(some.getBodyEntries()).containsEntry(key, dataspace); + } + + @SneakyThrows + @ParameterizedTest + @MethodSource("allAuthenticationSettings") + void allIncludeUserAgentOptionally(Properties properties) { + val key = AuthenticationSettings.Keys.USER_AGENT; + properties.remove(key); + + val none = AuthenticationStrategy.of(properties).buildAuthenticate(); + 
softly.assertThat(none.getHeaders()).containsEntry(key, AuthenticationSettings.Defaults.USER_AGENT); + + val userAgent = UUID.randomUUID().toString(); + properties.put(key, userAgent); + + val some = AuthenticationStrategy.of(properties).buildAuthenticate(); + softly.assertThat(some.getHeaders()).containsEntry(key, userAgent); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/auth/DataCloudTokenProcessorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/auth/DataCloudTokenProcessorTest.java new file mode 100644 index 0000000..cac8497 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/auth/DataCloudTokenProcessorTest.java @@ -0,0 +1,356 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.auth; + +import static com.salesforce.datacloud.jdbc.auth.PrivateKeyHelpersTest.fakeTenantId; +import static com.salesforce.datacloud.jdbc.auth.PrivateKeyHelpersTest.fakeToken; +import static com.salesforce.datacloud.jdbc.auth.PropertiesUtils.propertiesForPassword; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.salesforce.datacloud.jdbc.auth.errors.AuthorizationException; +import com.salesforce.datacloud.jdbc.auth.model.DataCloudTokenResponse; +import com.salesforce.datacloud.jdbc.auth.model.OAuthTokenResponse; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.sql.SQLException; +import java.util.Properties; +import java.util.Random; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import lombok.SneakyThrows; +import lombok.val; +import net.jodah.failsafe.Failsafe; +import net.jodah.failsafe.FailsafeException; +import net.jodah.failsafe.RetryPolicy; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.SocketPolicy; +import org.assertj.core.api.AssertionsForClassTypes; +import org.assertj.core.api.SoftAssertions; +import org.assertj.core.api.junit.jupiter.InjectSoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(SoftAssertionsExtension.class) +class DataCloudTokenProcessorTest { + @InjectSoftAssertions + private SoftAssertions softly; + + static final Function buildRetry = DataCloudTokenProcessor::buildRetryPolicy; + + Random random = new Random(10); + + @SneakyThrows + @Test + void createsRetryPolicyWithAndWithoutDefault() { + val properties = new Properties(); + val none = 
buildRetry.apply(properties); + + softly.assertThat(none.getMaxRetries()).isEqualTo(DataCloudTokenProcessor.DEFAULT_MAX_RETRIES); + + val retries = random.nextInt(12345); + properties.put(DataCloudTokenProcessor.MAX_RETRIES_KEY, Integer.toString(retries)); + + val some = buildRetry.apply(properties); + softly.assertThat(some.getMaxRetries()).isEqualTo(retries); + } + + @SneakyThrows + @Test + void retryPolicyDoesntHandleAuthorizationException() { + val retry = buildRetry.apply(new Properties()); + + val ex = + assertThrows(FailsafeException.class, () -> Failsafe.with(retry).get(() -> { + throw AuthorizationException.builder().build(); + })); + + AssertionsForClassTypes.assertThat(ex).hasRootCauseInstanceOf(AuthorizationException.class); + } + + @SneakyThrows + @Test + void retryPolicyOnlyHandlesTokenException() { + val retry = buildRetry.apply(new Properties()); + val expected = UUID.randomUUID().toString(); + + assertThrows(IllegalArgumentException.class, () -> Failsafe.with(retry).get(() -> { + throw new IllegalArgumentException(); + })); + + val counter = new AtomicInteger(0); + + val actual = Failsafe.with(retry).get(() -> { + if (counter.getAndIncrement() < DataCloudTokenProcessor.DEFAULT_MAX_RETRIES) { + throw new SQLException("hi"); + } + + return expected; + }); + + assertThat(actual).isEqualTo(expected); + } + + @SneakyThrows + @Test + void retryPolicyRetriesExpectedNumberOfTimesThenGivesUp() { + val properties = propertiesForPassword("un", "pw"); + val expectedTriesCount = DataCloudTokenProcessor.DEFAULT_MAX_RETRIES + 1; + try (val server = new MockWebServer()) { + server.start(); + for (int x = 0; x < expectedTriesCount; x++) { + server.enqueue(new MockResponse().setSocketPolicy(SocketPolicy.DISCONNECT_AT_START)); + } + properties.setProperty( + AuthenticationSettings.Keys.LOGIN_URL, server.url("").toString()); + assertThrows(DataCloudJDBCException.class, () -> DataCloudTokenProcessor.of(properties) + .getOAuthToken()); + 
assertThat(server.getRequestCount()).isEqualTo(expectedTriesCount); + server.shutdown(); + } + } + + @SneakyThrows + @Test + void exponentialBackoffPolicyRetriesExpectedNumberOfTimesThenGivesUp() { + val mapper = new ObjectMapper(); + val oAuthTokenResponse = new OAuthTokenResponse(); + val accessToken = UUID.randomUUID().toString(); + oAuthTokenResponse.setToken(accessToken); + val properties = propertiesForPassword("un", "pw"); + val expectedTriesCount = 2 * DataCloudTokenProcessor.DEFAULT_MAX_RETRIES + 1; + try (val server = new MockWebServer()) { + server.start(); + for (int x = 0; x < expectedTriesCount; x++) { + server.enqueue(new MockResponse().setBody(mapper.writeValueAsString(oAuthTokenResponse))); + server.enqueue(new MockResponse().setSocketPolicy(SocketPolicy.DISCONNECT_AT_START)); + } + properties.setProperty( + AuthenticationSettings.Keys.LOGIN_URL, server.url("").toString()); + assertThrows(DataCloudJDBCException.class, () -> DataCloudTokenProcessor.of(properties) + .getDataCloudToken()); + assertThat(server.getRequestCount()).isEqualTo(expectedTriesCount); + server.shutdown(); + } + } + + @SneakyThrows + @Test + void oauthTokenRetrieved() { + val mapper = new ObjectMapper(); + val properties = propertiesForPassword("un", "pw"); + val oAuthTokenResponse = new OAuthTokenResponse(); + val accessToken = UUID.randomUUID().toString(); + oAuthTokenResponse.setToken(accessToken); + + try (val server = new MockWebServer()) { + server.start(); + oAuthTokenResponse.setInstanceUrl(server.url("").toString()); + properties.setProperty( + AuthenticationSettings.Keys.LOGIN_URL, server.url("").toString()); + + server.enqueue(new MockResponse().setBody(mapper.writeValueAsString(oAuthTokenResponse))); + + val actual = DataCloudTokenProcessor.of(properties).getOAuthToken(); + assertThat(actual.getToken()).as("access token").isEqualTo(accessToken); + assertThat(actual.getInstanceUrl().toString()) + .as("instance url") + .isEqualTo(server.url("").toString()); + } + } + + 
@SneakyThrows + @Test + void bothTokensRetrieved() { + val mapper = new ObjectMapper(); + val properties = propertiesForPassword("un", "pw"); + val oAuthTokenResponse = new OAuthTokenResponse(); + oAuthTokenResponse.setToken(UUID.randomUUID().toString()); + + try (val server = new MockWebServer()) { + server.start(); + oAuthTokenResponse.setInstanceUrl(server.url("").toString()); + val dataCloudTokenResponse = new DataCloudTokenResponse(); + dataCloudTokenResponse.setTokenType(UUID.randomUUID().toString()); + dataCloudTokenResponse.setToken(fakeToken); + dataCloudTokenResponse.setInstanceUrl(server.url("").toString()); + val expected = DataCloudToken.of(dataCloudTokenResponse); + properties.setProperty( + AuthenticationSettings.Keys.LOGIN_URL, server.url("").toString()); + + server.enqueue(new MockResponse().setBody(mapper.writeValueAsString(oAuthTokenResponse))); + server.enqueue(new MockResponse().setBody(mapper.writeValueAsString(dataCloudTokenResponse))); + + val actual = DataCloudTokenProcessor.of(properties).getDataCloudToken(); + assertThat(actual.getAccessToken()).as("access token").isEqualTo(expected.getAccessToken()); + assertThat(actual.getTenantUrl()).as("tenant url").isEqualTo(expected.getTenantUrl()); + assertThat(actual.getTenantId()).as("tenant id").isEqualTo(fakeTenantId); + } + } + + @SneakyThrows + @Test + void throwsExceptionWhenDataCloudTokenResponseContainsErrorDescription() { + val mapper = new ObjectMapper(); + val properties = propertiesForPassword("un", "pw"); + properties.put(DataCloudTokenProcessor.MAX_RETRIES_KEY, "0"); + val oAuthTokenResponse = new OAuthTokenResponse(); + oAuthTokenResponse.setToken(UUID.randomUUID().toString()); + + try (val server = new MockWebServer()) { + server.start(); + oAuthTokenResponse.setInstanceUrl(server.url("").toString()); + val dataCloudTokenResponse = new DataCloudTokenResponse(); + val errorDescription = UUID.randomUUID().toString(); + val errorCode = UUID.randomUUID().toString(); + 
dataCloudTokenResponse.setErrorDescription(errorDescription); + dataCloudTokenResponse.setErrorCode(errorCode); + properties.setProperty( + AuthenticationSettings.Keys.LOGIN_URL, server.url("").toString()); + + server.enqueue(new MockResponse().setBody(mapper.writeValueAsString(oAuthTokenResponse))); + server.enqueue(new MockResponse().setBody(mapper.writeValueAsString(dataCloudTokenResponse))); + + val ex = assertThrows(DataCloudJDBCException.class, () -> DataCloudTokenProcessor.of(properties) + .getDataCloudToken()); + + assertAuthorizationException( + ex, + "Received an error when exchanging oauth access token for data cloud token.", + errorCode + ": " + errorDescription); + } + } + + @SneakyThrows + @Test + void throwsExceptionWhenOauthTokenResponseIsMissingAccessToken() { + val mapper = new ObjectMapper(); + val properties = propertiesForPassword("un", "pw"); + properties.put(DataCloudTokenProcessor.MAX_RETRIES_KEY, "0"); + val oAuthTokenResponse = new OAuthTokenResponse(); + + try (val server = new MockWebServer()) { + server.start(); + oAuthTokenResponse.setInstanceUrl(server.url("").toString()); + + properties.setProperty( + AuthenticationSettings.Keys.LOGIN_URL, server.url("").toString()); + + server.enqueue(new MockResponse().setBody(mapper.writeValueAsString(oAuthTokenResponse))); + + val ex = assertThrows(DataCloudJDBCException.class, () -> DataCloudTokenProcessor.of(properties) + .getDataCloudToken()); + + assertSQLException(ex, "Received an error when acquiring oauth access token, no token in response."); + } + } + + @SneakyThrows + @Test + void throwsExceptionWhenDataCloudTokenResponseIsMissingAccessToken() { + val mapper = new ObjectMapper(); + val properties = propertiesForPassword("un", "pw"); + properties.put(DataCloudTokenProcessor.MAX_RETRIES_KEY, "0"); + val oAuthTokenResponse = new OAuthTokenResponse(); + oAuthTokenResponse.setToken(UUID.randomUUID().toString()); + + try (val server = new MockWebServer()) { + server.start(); + 
oAuthTokenResponse.setInstanceUrl(server.url("").toString()); + val dataCloudTokenResponse = new DataCloudTokenResponse(); + dataCloudTokenResponse.setTokenType(UUID.randomUUID().toString()); + dataCloudTokenResponse.setInstanceUrl(server.url("").toString()); + properties.setProperty( + AuthenticationSettings.Keys.LOGIN_URL, server.url("").toString()); + + server.enqueue(new MockResponse().setBody(mapper.writeValueAsString(oAuthTokenResponse))); + server.enqueue(new MockResponse().setBody(mapper.writeValueAsString(dataCloudTokenResponse))); + + val ex = assertThrows(DataCloudJDBCException.class, () -> DataCloudTokenProcessor.of(properties) + .getDataCloudToken()); + + assertSQLException( + ex, + "Received an error when exchanging oauth access token for data cloud token, no token in response."); + } + } + + @SneakyThrows + @Test + void throwsExceptionWhenOauthTokenResponseIsNull() { + val properties = propertiesForPassword("un", "pw"); + properties.put(DataCloudTokenProcessor.MAX_RETRIES_KEY, "0"); + val oAuthTokenResponse = new OAuthTokenResponse(); + + try (val server = new MockWebServer()) { + server.start(); + oAuthTokenResponse.setInstanceUrl(server.url("").toString()); + properties.setProperty( + AuthenticationSettings.Keys.LOGIN_URL, server.url("").toString()); + + server.enqueue(new MockResponse().setBody("{}")); + + val ex = assertThrows(DataCloudJDBCException.class, () -> DataCloudTokenProcessor.of(properties) + .getDataCloudToken()); + + assertSQLException(ex, "Received an error when acquiring oauth access token, no token in response."); + } + } + + @SneakyThrows + @Test + void throwsExceptionWhenDataCloudTokenResponseIsNull() { + val mapper = new ObjectMapper(); + val properties = propertiesForPassword("un", "pw"); + properties.put(DataCloudTokenProcessor.MAX_RETRIES_KEY, "0"); + val oAuthTokenResponse = new OAuthTokenResponse(); + oAuthTokenResponse.setToken(UUID.randomUUID().toString()); + + try (val server = new MockWebServer()) { + server.start(); + 
oAuthTokenResponse.setInstanceUrl(server.url("").toString()); + properties.setProperty( + AuthenticationSettings.Keys.LOGIN_URL, server.url("").toString()); + + server.enqueue(new MockResponse().setBody(mapper.writeValueAsString(oAuthTokenResponse))); + server.enqueue(new MockResponse().setBody("{}")); + + val ex = assertThrows(DataCloudJDBCException.class, () -> DataCloudTokenProcessor.of(properties) + .getDataCloudToken()); + + assertSQLException( + ex, + "Received an error when exchanging oauth access token for data cloud token, no token in response."); + } + } + + private static void assertAuthorizationException(Throwable actual, CharSequence... messages) { + AssertionsForClassTypes.assertThat(actual) + .hasMessageContainingAll(messages) + .hasCauseInstanceOf(DataCloudJDBCException.class) + .hasRootCauseInstanceOf(AuthorizationException.class); + } + + private static void assertSQLException(Throwable actual, CharSequence... messages) { + AssertionsForClassTypes.assertThat(actual) + .hasMessageContainingAll(messages) + .hasCauseInstanceOf(DataCloudJDBCException.class) + .hasRootCauseInstanceOf(SQLException.class); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/auth/DataCloudTokenTest.java b/src/test/java/com/salesforce/datacloud/jdbc/auth/DataCloudTokenTest.java new file mode 100644 index 0000000..43c8bb2 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/auth/DataCloudTokenTest.java @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.auth; + +import static com.salesforce.datacloud.jdbc.auth.PrivateKeyHelpersTest.fakeTenantId; +import static com.salesforce.datacloud.jdbc.auth.PrivateKeyHelpersTest.fakeToken; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.salesforce.datacloud.jdbc.auth.model.DataCloudTokenResponse; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.Messages; +import java.util.UUID; +import lombok.SneakyThrows; +import lombok.val; +import org.assertj.core.api.SoftAssertions; +import org.assertj.core.api.junit.jupiter.InjectSoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(SoftAssertionsExtension.class) +class DataCloudTokenTest { + private final String validToken = "token-" + UUID.randomUUID(); + private final String validUrl = "https://login.something.salesforce.com"; + + @InjectSoftAssertions + SoftAssertions softly; + + @SneakyThrows + @Test + void whenTokenHasExpiredIsAliveIsFalse() { + val expired = new DataCloudTokenResponse(); + expired.setTokenType("type"); + expired.setToken(validToken); + expired.setInstanceUrl(validUrl); + expired.setExpiresIn(-100); + assertThat(DataCloudToken.of(expired).isAlive()).isFalse(); + } + + @SneakyThrows + @Test + void whenTokenHasNotExpiredIsAliveIsTrue() { + val notExpired = new DataCloudTokenResponse(); + notExpired.setTokenType("type"); + notExpired.setToken(validToken); + notExpired.setInstanceUrl(validUrl); + notExpired.setExpiresIn(100); + + assertThat(DataCloudToken.of(notExpired).isAlive()).isTrue(); + } + + @Test + void throwsWhenIllegalArgumentsAreProvided() { + val noTokenResponse = new 
DataCloudTokenResponse(); + noTokenResponse.setTokenType("type"); + noTokenResponse.setInstanceUrl(validUrl); + noTokenResponse.setExpiresIn(10000); + noTokenResponse.setToken(""); + assertThat(assertThrows(IllegalArgumentException.class, () -> DataCloudToken.of(noTokenResponse))) + .hasMessageContaining("token"); + val noUriResponse = new DataCloudTokenResponse(); + noUriResponse.setTokenType("type"); + noUriResponse.setInstanceUrl(""); + noUriResponse.setExpiresIn(10000); + noUriResponse.setToken(validToken); + assertThat(assertThrows(IllegalArgumentException.class, () -> DataCloudToken.of(noUriResponse))) + .hasMessageContaining("instance_url"); + } + + @Test + void throwsWhenTenantUrlIsIllegal() { + val nonNullOrBlankIllegalUrl = "%XY"; + val bad = new DataCloudTokenResponse(); + bad.setInstanceUrl(nonNullOrBlankIllegalUrl); + bad.setToken("token"); + bad.setTokenType("type"); + bad.setExpiresIn(123); + val exception = assertThrows(DataCloudJDBCException.class, () -> DataCloudToken.of(bad)); + assertThat(exception.getMessage()).contains(Messages.FAILED_LOGIN); + assertThat(exception.getCause().getMessage()) + .contains("Malformed escape pair at index 0: " + nonNullOrBlankIllegalUrl); + } + + @SneakyThrows + @Test + void properlyReturnsCorrectValues() { + val validResponse = new DataCloudTokenResponse(); + val token = fakeToken; + + validResponse.setInstanceUrl(validUrl); + validResponse.setToken(token); + validResponse.setTokenType("Bearer"); + validResponse.setExpiresIn(123); + + val actual = DataCloudToken.of(validResponse); + softly.assertThat(actual.getAccessToken()).isEqualTo("Bearer " + token); + softly.assertThat(actual.getTenantUrl()).isEqualTo(validUrl); + softly.assertThat(actual.getTenantId()).isEqualTo(fakeTenantId); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/auth/OAuthTokenTest.java b/src/test/java/com/salesforce/datacloud/jdbc/auth/OAuthTokenTest.java new file mode 100644 index 0000000..9898aa8 --- /dev/null +++ 
b/src/test/java/com/salesforce/datacloud/jdbc/auth/OAuthTokenTest.java @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.auth; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.salesforce.datacloud.jdbc.auth.model.OAuthTokenResponse; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.Messages; +import lombok.val; +import org.junit.jupiter.api.Test; + +class OAuthTokenTest { + @Test + void throwsOnBadInstanceUrl() { + val response = new OAuthTokenResponse(); + response.setToken("not empty"); + response.setInstanceUrl("%&#("); + val ex = assertThrows(DataCloudJDBCException.class, () -> OAuthToken.of(response)); + assertThat(ex).hasMessage(Messages.FAILED_LOGIN); + } + + @Test + void throwsOnBadToken() { + val response = new OAuthTokenResponse(); + response.setInstanceUrl("login.salesforce.com"); + val ex = assertThrows(DataCloudJDBCException.class, () -> OAuthToken.of(response)); + assertThat(ex).hasMessage(Messages.FAILED_LOGIN).hasNoCause(); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/auth/PrivateKeyHelpersTest.java b/src/test/java/com/salesforce/datacloud/jdbc/auth/PrivateKeyHelpersTest.java new file mode 100644 index 0000000..f2c0931 --- /dev/null +++ 
b/src/test/java/com/salesforce/datacloud/jdbc/auth/PrivateKeyHelpersTest.java @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.auth; + +import static com.salesforce.datacloud.jdbc.auth.PropertiesUtils.propertiesForPrivateKey; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.salesforce.datacloud.jdbc.util.Constants; +import java.nio.charset.StandardCharsets; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Base64; +import java.util.Properties; +import java.util.stream.Collectors; +import lombok.Builder; +import lombok.SneakyThrows; +import lombok.Value; +import lombok.extern.jackson.Jacksonized; +import lombok.val; +import org.junit.jupiter.api.Test; + +@Builder +@Jacksonized +@Value +class Part { + String iss; + String sub; + String aud; + int iat; + int exp; +} + +class PrivateKeyHelpersTest { + + @Test + @SneakyThrows + void testAudience() { + Audience audience = Audience.of("https://something.salesforce.com"); + assertThat(audience).isEqualTo(Audience.PROD); + assertThat(audience.getUrl()).isEqualTo("login.salesforce.com"); + + audience = Audience.of("https://login.test1.pc-rnd.salesforce.com"); + assertThat(audience).isEqualTo(Audience.DEV); + 
assertThat(audience.getUrl()).isEqualTo("login.test1.pc-rnd.salesforce.com"); + + assertThrows( + SQLException.class, + () -> Audience.of("not a url"), + "The specified url: 'not a url' didn't match any known environments"); + } + + @SneakyThrows + @Test + void testJwtParts() { + String audience = "https://login.test1.pc-rnd.salesforce.com"; + Properties properties = propertiesForPrivateKey(fakePrivateKey); + properties.put(Constants.CLIENT_ID, "client_id"); + properties.put(Constants.CLIENT_SECRET, "client_secret"); + properties.put(Constants.USER_NAME, "user_name"); + properties.put(Constants.LOGIN_URL, audience); + properties.put(Constants.PRIVATE_KEY, fakePrivateKey); + String actual = JwtParts.buildJwt((PrivateKeyAuthenticationSettings) AuthenticationSettings.of(properties)); + shouldYieldJwt(actual, fakeJwt); + } + + @SneakyThrows + public static void shouldYieldJwt(String actual, String expected) { + val actualParts = Arrays.stream(actual.split("\\.")).limit(2).collect(Collectors.toList()); + val expectedParts = Arrays.stream(expected.split("\\.")).limit(2).collect(Collectors.toList()); + + assertThat(actualParts.get(0)).isEqualTo(expectedParts.get(0)); + + Part actualPart = new ObjectMapper().readValue(decodeBase64String(actualParts.get(1)), Part.class); + Part expectedPart = new ObjectMapper().readValue(decodeBase64String(expectedParts.get(1)), Part.class); + + assertThat(actualPart.getIss()).isEqualTo(expectedPart.getIss()); + assertThat(actualPart.getAud()).isEqualTo(expectedPart.getAud()); + assertThat(actualPart.getSub()).isEqualTo(expectedPart.getSub()); + + assertThat(actualPart.getExp() - actualPart.getIat()).isGreaterThanOrEqualTo(110); + assertThat(actualPart.getExp() - actualPart.getIat()).isLessThanOrEqualTo(130); + } + + private static String decodeBase64String(String input) { + byte[] decodedBytes = Base64.getDecoder().decode(input); + return new String(decodedBytes, StandardCharsets.UTF_8); + } + + static final String fakePrivateKey = + 
"-----BEGIN PRIVATE KEY-----MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDaaLkxUbuT6lBD9ZPXBjjJkA4+JzewQ/kXPAryD1cvh7hQN07KNy/bn0eBviRbpqp7sbCRo2hk8F2TRGb13yHM/h0uPTzcDGMmwqScZ7BvP0lvWKs8AtvwLJbyPpXOXa8cGy1al5mA9sc0xLLprKGrWW0GK2pzcIzvOiDyG1W9a0oOJPPvA6nS5N9AusvRrnaLhukL4bXo6iYPRMj6vkJwEAKp0S5Lj9/5lqN9QynipYrpwMcBnmFZst+IFGfpu5w9EKxTLR4O/3Cf8CXdcEidGZMQE2jYtQPyTJlscG6fdy421q0qUKf/fFN6n+cRw5aqXRXZsm+UC1aqhIMr3LD5AgMBAAECggEAAWhs5TCdZZvkDRXBGDPEkL0rCpyqR4c94s30/QClG0A+iDaI2iB54HMxqQjhJRYIqYxlSd4WH+6oP+dh5B1pZzaGan4Uh+wxAJsvrG3zt/xrXgaShlpjZlMatTytmKTOIZoPuvBxRf+ruyx4iIFUo72zF9sj02BMysLAocnwzbURT7hFTc3TjzQRC0NWMNASPXqxhTcM+DI6YWMaxJxbXcW47t6a6NNI0/LvRNeP6PyaVEVM2N3Br52zgnM1Arv0l3jBKlCMhLQAavWbKVtRZUVLVMuOMeRdSpfaXlo9AEB2uxtPvuWJm+4p56Irlh/ggu/9idBmhWNjKF7yEJ7OVwKBgQD6AmZGyvl9nQQ1cnRjutaTlCwxPbaka+Z4V0IZ6knUSLMfydbgUJjI17odNkBmaAVBHY0mFjv48t38/ooCai50dzW3RIEqGVS39e9I2/0ONtZlyZrypvNZxrZ3Pa8F4cNTd05/ldSdiwl2Ric6zqXk5QrBspJnpT4aQ56qf21UewKBgQDfpHtzr3VdhBioyoYCR6R7sQIp2wyLEJ1cEBqpTW97RViiiNTLops8P0+kHZK3Rx8etv9/DTlufyuA+qGNLs+qnJ4MOD/GMTeTu/LPnQSx/ilR2coDFQKh3n/Z0p18FfG8evLNfeRjlTGSYvMR/YvIc0OkuqRWCuDBeXwCHxfYGwKBgCXQ4RmKMCzA6FcRReuj4jsWaYzVMeAy9fxz7mqvFpXGnVmMlTT+2+1dPCiZASq8RzcvOh9ts4qXad6Pvd5Zo0c4lOZwtTzh8f+VcqlJpUBWKR3iXc6gVCTbOtRUfznbiUkBvdzsk+l0k2zRdbOeeFdkEbl0wlJtGzSrz78oYSgrAoGAGkeEripe+zcrgqIRrzDl9hbtrydrSOgR5aCK0Xwk7nJOoQK9JpSb8y9pV1qWQ+0ajgxo53ARYJeW8BgDZcirZFv1AnCVpd9grX53YMgNpjC8gD68SzJr1cOEeH8UPGGDv2cfIuB5Nu5wHch80Y9enpZUy4WXC/lJQdLZrJIkxiMCgYEAhTU67BVB/sSKYY+wKqcu96fz6U5FrhEHqgV24326OOpyCjf+aPyK6O0lBw5bYfaXCR1XmO4mwk+ZCcfRg9kqmFUb+OoJGMJipeuQmDXDv2naGqECZ14dOid32KeJva11CJfzAT0SFQFD+tb7HZK5VrY9C/t5FFsSxKF0G2QGb9k=-----END PRIVATE KEY-----"; + static final String fakeJwt = + 
"eyJhbGciOiJSUzI1NiJ9.eyJpc3MiOiJjbGllbnRfaWQiLCJzdWIiOiJ1c2VyX25hbWUiLCJhdWQiOiJsb2dpbi50ZXN0MS5wYy1ybmQuc2FsZXNmb3JjZS5jb20iLCJpYXQiOjE3MTgyMjkxNjAsImV4cCI6MTcxODIyOTI4MH0.QzwM8DP0CInYPvj8u7wP1bJ3bXnOXOh5mTFFXkiInIdZ_NveVyiv5xnupAAm7ri8c2C4it4dkIXX3SlrMkX8dSJoff5JG6J_t8qJCGYqnW4MZHuAwEYg1Pbn9eqDvni0PRsA4kUloYr1OLIkIEFySPnzEBPcD-Yxrm5jdAlMINHC38brHMhBBPjuKnWQIZz4iVcQaLt9HcC6yx351PwMUX64yiUPbht1qP54Ohbu6zAn5wf1h1-47_X-7ewYVytqw7teqbN-2vlDiHiHsMsh1rmnp0wFDicE-q8aWpaY7m5DQJLxpUmwhnGPjQ8-EsEeG6jr0Ul6EQv4FuO9WiinmQ"; + static final String fakeToken = + "eyJraWQiOiJDT1JFLjAwRE9LMDAwMDAwOVp6ci4xNzE4MDUyMTU0NDIyIiwidHlwIjoiSldUIiwiYWxnIjoiRVMyNTYifQ.eyJzdWIiOiJodHRwczovL2xvZ2luLnRlc3QxLnBjLXJuZC5zYWxlc2ZvcmNlLmNvbS9pZC8wMERPSzAwMDAwMDlaenIyQUUvMDA1T0swMDAwMDBVeTkxWUFDIiwic2NwIjoiY2RwX3Byb2ZpbGVfYXBpIGNkcF9pbmdlc3RfYXBpIGNkcF9pZGVudGl0eXJlc29sdXRpb25fYXBpIGNkcF9zZWdtZW50X2FwaSBjZHBfcXVlcnlfYXBpIGNkcF9hcGkiLCJpc3MiOiJodHRwczovL2xvZ2luLnRlc3QxLnBjLXJuZC5zYWxlc2ZvcmNlLmNvbS8iLCJvcmdJZCI6IjAwRE9LMDAwMDAwOVp6ciIsImlzc3VlclRlbmFudElkIjoiY29yZS9mYWxjb250ZXN0MS1jb3JlNG9yYTE1LzAwRE9LMDAwMDAwOVp6cjJBRSIsInNmYXBwaWQiOiIzTVZHOVhOVDlUbEI3VmtZY0tIVm5sUUZzWEd6cUJuMGszUC5zNHJBU0I5V09oRU1OdkgyNzNpM1NFRzF2bWl3WF9YY2NXOUFZbHA3VnJnQ3BGb0ZXIiwiYXVkaWVuY2VUZW5hbnRJZCI6ImEzNjAvZmFsY29uZGV2L2E2ZDcyNmE3M2Y1MzQzMjdhNmE4ZTJlMGYzY2MzODQwIiwiY3VzdG9tX2F0dHJpYnV0ZXMiOnsiZGF0YXNwYWNlIjoiZGVmYXVsdCJ9LCJhdWQiOiJhcGkuYTM2MC5zYWxlc2ZvcmNlLmNvbSIsIm5iZiI6MTcyMDczMTAyMSwic2ZvaWQiOiIwMERPSzAwMDAwMDlaenIiLCJzZnVpZCI6IjAwNU9LMDAwMDAwVXk5MSIsImV4cCI6MTcyMDczODI4MCwiaWF0IjoxNzIwNzMxMDgxLCJqdGkiOiIwYjYwMzc4OS1jMGI2LTQwZTMtYmIzNi03NDQ3MzA2MzAxMzEifQ.lXgeAhJIiGoxgNpBi0W5oBWyn2_auB2bFxxajGuK6DMHlkqDhHJAlFN_uf6QPSjGSJCh5j42Ow5SrEptUDJwmQ"; + static final String fakeTenantId = "a360/falcondev/a6d726a73f534327a6a8e2e0f3cc3840"; +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/auth/PropertiesUtils.java b/src/test/java/com/salesforce/datacloud/jdbc/auth/PropertiesUtils.java new file mode 100644 index 
0000000..78d54d1 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/auth/PropertiesUtils.java @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.auth; + +import java.util.Properties; +import java.util.Set; +import java.util.UUID; +import lombok.val; + +public class PropertiesUtils { + public static Properties allPropertiesExcept(Set except) { + val properties = new Properties(); + AuthenticationSettings.Keys.ALL.stream() + .filter(k -> !except.contains(k)) + .forEach(k -> properties.setProperty(k, randomString())); + return properties; + } + + public static Properties allPropertiesExcept(String... excepts) { + Set except = excepts == null || excepts.length == 0 ? 
Set.of() : Set.of(excepts); + return allPropertiesExcept(except); + } + + public static String randomString() { + return UUID.randomUUID().toString(); + } + + public static Properties propertiesForPrivateKey(String privateKey) { + val properties = + allPropertiesExcept(AuthenticationSettings.Keys.PASSWORD, AuthenticationSettings.Keys.REFRESH_TOKEN); + properties.setProperty(AuthenticationSettings.Keys.PRIVATE_KEY, privateKey); + properties.setProperty(AuthenticationSettings.Keys.LOGIN_URL, "login.test1.pc-rnd.salesforce.com"); + properties.setProperty(AuthenticationSettings.Keys.CLIENT_ID, "client_id"); + properties.setProperty(AuthenticationSettings.Keys.USER_NAME, "user_name"); + return properties; + } + + public static Properties propertiesForPassword(String userName, String password) { + val properties = + allPropertiesExcept(AuthenticationSettings.Keys.PRIVATE_KEY, AuthenticationSettings.Keys.REFRESH_TOKEN); + properties.setProperty(AuthenticationSettings.Keys.USER_NAME, userName); + properties.setProperty(AuthenticationSettings.Keys.PASSWORD, password); + properties.setProperty(AuthenticationSettings.Keys.LOGIN_URL, "login.test1.pc-rnd.salesforce.com"); + return properties; + } + + public static Properties propertiesForRefreshToken(String refreshToken) { + val properties = + allPropertiesExcept(AuthenticationSettings.Keys.PASSWORD, AuthenticationSettings.Keys.PRIVATE_KEY); + properties.setProperty(AuthenticationSettings.Keys.REFRESH_TOKEN, refreshToken); + properties.setProperty(AuthenticationSettings.Keys.LOGIN_URL, "login.test1.pc-rnd.salesforce.com"); + return properties; + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/auth/TokenCacheImplTest.java b/src/test/java/com/salesforce/datacloud/jdbc/auth/TokenCacheImplTest.java new file mode 100644 index 0000000..931a83c --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/auth/TokenCacheImplTest.java @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.auth; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.salesforce.datacloud.jdbc.auth.model.DataCloudTokenResponse; +import java.util.UUID; +import lombok.SneakyThrows; +import lombok.val; +import org.junit.jupiter.api.Test; + +public class TokenCacheImplTest { + private final ObjectMapper mapper = new ObjectMapper(); + + @Test + void canSetGetAndClearADataCloudToken() { + val accessToken = UUID.randomUUID().toString(); + val token = makeDataCloudToken(accessToken); + + val sut = new TokenCacheImpl(); + + assertThat(sut.getDataCloudToken()).isNull(); + sut.setDataCloudToken(token); + assertThat(sut.getDataCloudToken()).isEqualTo(token); + sut.clearDataCloudToken(); + assertThat(sut.getDataCloudToken()).isNull(); + } + + @SneakyThrows + private DataCloudToken makeDataCloudToken(String accessToken) { + val json = String.format( + "{\"access_token\": \"%s\", \"instance_url\": \"something.salesforce.com\", \"token_type\": \"something\", \"expires_in\": 100 }", + accessToken); + val model = mapper.readValue(json, DataCloudTokenResponse.class); + return DataCloudToken.of(model); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/config/QueryResourcesTest.java b/src/test/java/com/salesforce/datacloud/jdbc/config/QueryResourcesTest.java new file mode 100644 index 0000000..eb57c7c --- 
/dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/config/QueryResourcesTest.java @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.config; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import lombok.val; +import org.junit.jupiter.api.Test; + +class QueryResourcesTest { + + @Test + void getColumnsQuery() { + val actual = QueryResources.getColumnsQuery(); + assertThat(actual) + .contains("SELECT n.nspname,") + .contains("FROM pg_catalog.pg_namespace n") + .contains("WHERE c.relkind in ('r', 'p', 'v', 'f', 'm')"); + } + + @Test + void getSchemasQuery() { + val actual = QueryResources.getSchemasQuery(); + assertThat(actual) + .contains("SELECT nspname") + .contains("FROM pg_catalog.pg_namespace") + .contains("WHERE nspname"); + } + + @Test + void getTablesQuery() { + val actual = QueryResources.getTablesQuery(); + assertThat(actual) + .contains("SELECT") + .contains("FROM pg_catalog.pg_namespace") + .contains("LEFT JOIN pg_catalog.pg_description d ON"); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/config/ResourceReaderTest.java b/src/test/java/com/salesforce/datacloud/jdbc/config/ResourceReaderTest.java new file mode 100644 index 0000000..9057e9f --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/config/ResourceReaderTest.java @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.config; + +import static com.salesforce.datacloud.jdbc.config.ResourceReader.readResourceAsProperties; +import static com.salesforce.datacloud.jdbc.config.ResourceReader.readResourceAsString; +import static com.salesforce.datacloud.jdbc.config.ResourceReader.readResourceAsStringList; +import static com.salesforce.datacloud.jdbc.config.ResourceReader.withResourceAsStream; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.io.IOException; +import java.util.UUID; +import lombok.val; +import org.assertj.core.api.AssertionsForInterfaceTypes; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class ResourceReaderTest { + private static final String expectedState = "58P01"; + private static final String validPath = "/simplelogger.properties"; + private static final String validProperty = "org.slf4j.simpleLogger.defaultLogLevel"; + + @Test + void withResourceAsStreamHandlesIOException() { + val message = UUID.randomUUID().toString(); + val ex = Assertions.assertThrows( + DataCloudJDBCException.class, + () -> withResourceAsStream(validPath, in -> { + throw new IOException(message); + })); + + assertThat(ex) + .hasMessage("Error while loading resource file. 
path=" + validPath) + .hasRootCauseMessage(message); + assertThat(ex.getSQLState()).isEqualTo(expectedState); + } + + @Test + void readResourceAsStringThrowsOnNotFound() { + val badPath = "/" + UUID.randomUUID(); + val ex = Assertions.assertThrows(DataCloudJDBCException.class, () -> readResourceAsString(badPath)); + + assertThat(ex).hasMessage("Resource file not found. path=" + badPath); + assertThat(ex.getSQLState()).isEqualTo(expectedState); + } + + @Test + void readResourceAsPropertiesThrowsOnNotFound() { + val badPath = "/" + UUID.randomUUID(); + val ex = Assertions.assertThrows(DataCloudJDBCException.class, () -> readResourceAsProperties(badPath)); + + assertThat(ex).hasMessage("Resource file not found. path=" + badPath); + assertThat(ex.getSQLState()).isEqualTo(expectedState); + } + + @Test + void readResourceAsStringHappyPath() { + assertThat(readResourceAsString(validPath)).contains(validProperty); + } + + @Test + void readResourceAsPropertiesHappyPath() { + assertThat(readResourceAsProperties(validPath).getProperty(validProperty)) + .isNotNull() + .isNotBlank(); + } + + @Test + void readResourceAsStringListHappyPath() { + AssertionsForInterfaceTypes.assertThat(readResourceAsStringList(validPath)) + .hasSizeGreaterThanOrEqualTo(1); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/ArrowStreamReaderCursorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/ArrowStreamReaderCursorTest.java new file mode 100644 index 0000000..beafdd1 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/ArrowStreamReaderCursorTest.java @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; + +import java.util.stream.IntStream; +import lombok.SneakyThrows; +import lombok.val; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class ArrowStreamReaderCursorTest { + @Mock + protected ArrowStreamReader reader; + + @Mock + protected VectorSchemaRoot root; + + @Test + void createGetterIsUnsupported() { + val sut = new ArrowStreamReaderCursor(reader); + Assertions.assertThrows(UnsupportedOperationException.class, () -> sut.createGetter(0)); + } + + @Test + @SneakyThrows + void closesTheReader() { + val sut = new ArrowStreamReaderCursor(reader); + sut.close(); + verify(reader, times(1)).close(); + } + + @Test + @SneakyThrows + void incrementsInternalIndexUntilRowsExhaustedThenLoadsNextBatch() { + val times = 5; + when(reader.getVectorSchemaRoot()).thenReturn(root); + when(reader.loadNextBatch()).thenReturn(true); + when(root.getRowCount()).thenReturn(times); + + val sut = new ArrowStreamReaderCursor(reader); + IntStream.range(0, times + 1).forEach(i -> sut.next()); + + verify(root, times(times + 
1)).getRowCount(); + verify(reader, times(1)).loadNextBatch(); + } + + @ParameterizedTest + @SneakyThrows + @ValueSource(booleans = {true, false}) + void forwardsLoadNextBatch(boolean result) { + when(root.getRowCount()).thenReturn(-10); + when(reader.getVectorSchemaRoot()).thenReturn(root); + when(reader.loadNextBatch()).thenReturn(result); + + val sut = new ArrowStreamReaderCursor(reader); + + assertThat(sut.next()).isEqualTo(result); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/AsyncStreamingResultSetTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/AsyncStreamingResultSetTest.java new file mode 100644 index 0000000..3fda6ce --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/AsyncStreamingResultSetTest.java @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.hyper.HyperTestBase; +import io.grpc.StatusRuntimeException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.regex.Pattern; +import lombok.SneakyThrows; +import lombok.val; +import org.assertj.core.api.AssertionsForClassTypes; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class AsyncStreamingResultSetTest extends HyperTestBase { + private static final String sql = + "select cast(a as numeric(38,18)) a, cast(a as numeric(38,18)) b, cast(a as numeric(38,18)) c from generate_series(1, 1024 * 1024 * 10) as s(a) order by a asc"; + + @Test + @SneakyThrows + public void testThrowsOnNonsenseQueryAsync() { + val ex = Assertions.assertThrows(DataCloudJDBCException.class, () -> { + try (val connection = HyperTestBase.getHyperQueryConnection(); + val statement = connection.createStatement().unwrap(DataCloudStatement.class)) { + val rs = statement.executeAsyncQuery("select * from nonsense"); + waitUntilReady(statement); + rs.getResultSet().next(); + } + }); + + AssertionsForClassTypes.assertThat(ex).hasCauseInstanceOf(StatusRuntimeException.class); + Pattern rootCausePattern = Pattern.compile("^INVALID_ARGUMENT: table \"nonsense\" does not exist.*"); + AssertionsForClassTypes.assertThat(ex.getCause().getMessage()).containsPattern(rootCausePattern); + } + + @Test + @SneakyThrows + public void testNoDataIsLostAsync() { + assertWithStatement(statement -> { + statement.executeAsyncQuery(sql); + + val asyncReady = waitUntilReady(statement); + + val rs = statement.getResultSet(); + assertThat(asyncReady).isTrue(); + assertThat(rs).isInstanceOf(StreamingResultSet.class); + + val expected = new AtomicInteger(0); + + while (rs.next()) { + assertEachRowIsTheSame(rs, expected); + } + + 
assertThat(expected.get()).isEqualTo(1024 * 1024 * 10); + }); + } + + @Test + @SneakyThrows + public void testQueryIdChangesInHeaderAsync() { + try (val connection = getHyperQueryConnection(); + val statement = connection.createStatement().unwrap(DataCloudStatement.class)) { + val rs = statement.executeAsyncQuery(sql); + waitUntilReady(statement); + rs.getResultSet().next(); + + rs.executeAsyncQuery(sql); + waitUntilReady(statement); + } catch (StatusRuntimeException e) { + Assertions.fail(e); + } + } + + @SneakyThrows + static boolean waitUntilReady(DataCloudStatement statement) { + while (!statement.isReady()) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + return false; + } + } + return true; + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/ConnectionSettingsTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/ConnectionSettingsTest.java new file mode 100644 index 0000000..b76f1c4 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/ConnectionSettingsTest.java @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.salesforce.datacloud.jdbc.hyper.HyperTestBase; +import java.time.LocalDate; +import java.time.format.DateTimeFormatter; +import java.util.Map; +import lombok.SneakyThrows; +import lombok.val; +import org.junit.jupiter.api.Test; + +public class ConnectionSettingsTest extends HyperTestBase { + @Test + @SneakyThrows + public void testHyperRespectsConnectionSetting() { + val settings = Map.entry("serverSetting.date_style", "YMD"); + val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd"); + + assertWithStatement( + statement -> { + val result = statement.executeQuery("SELECT CURRENT_DATE"); + result.next(); + + val expected = LocalDate.parse(result.getDate(1).toString(), formatter); + val actual = result.getDate(1); + + assertThat(actual.toString()).isEqualTo(expected.toString()); + }, + settings); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudConnectionTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudConnectionTest.java new file mode 100644 index 0000000..7eabec8 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudConnectionTest.java @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core; + +import static com.salesforce.datacloud.jdbc.auth.PropertiesUtils.propertiesForPassword; +import static com.salesforce.datacloud.jdbc.util.Messages.ILLEGAL_CONNECTION_PROTOCOL; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.sql.Connection; +import java.util.Properties; +import lombok.SneakyThrows; +import lombok.val; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class DataCloudConnectionTest extends HyperGrpcTestBase { + static final String validUrl = "jdbc:salesforce-datacloud://login.salesforce.com"; + + static final Properties properties = propertiesForPassword("un", "pw"); + + @SneakyThrows + @Test + void getServiceRootUrl_acceptsWellFormedUrl() { + val url = "jdbc:salesforce-datacloud://login.salesforce.com"; + assertThat(DataCloudConnection.acceptsUrl(url)).isTrue(); + assertThat(DataCloudConnection.getServiceRootUrl(url)).isEqualTo("https://login.salesforce.com"); + } + + @SneakyThrows + @Test + void getServiceRootUrl_rejectsUrlThatContainsHttps() { + val url = "jdbc:salesforce-datacloud:https://login.salesforce.com"; + assertThat(DataCloudConnection.acceptsUrl(url)).isFalse(); + val ex = + Assertions.assertThrows(DataCloudJDBCException.class, () -> DataCloudConnection.getServiceRootUrl(url)); + assertThat(ex).hasMessage(ILLEGAL_CONNECTION_PROTOCOL); + } + + @SneakyThrows + @Test + void getServiceRootUrl_rejectsUrlThatDoesNotStartWithConnectionProtocol() { + val url = "foo:https://login.salesforce.com"; + assertThat(DataCloudConnection.acceptsUrl(url)).isFalse(); + val ex = + Assertions.assertThrows(DataCloudJDBCException.class, () -> 
DataCloudConnection.getServiceRootUrl(url)); + assertThat(ex).hasMessage(ILLEGAL_CONNECTION_PROTOCOL); + } + + @SneakyThrows + @Test + void getServiceRootUrl_rejectsUrlThatIsMalformed() { + val url = "jdbc:salesforce-datacloud://log^in.sal^esf^orce.c^om"; + assertThat(DataCloudConnection.acceptsUrl(url)).isTrue(); + val ex = + Assertions.assertThrows(DataCloudJDBCException.class, () -> DataCloudConnection.getServiceRootUrl(url)); + assertThat(ex).hasMessage(ILLEGAL_CONNECTION_PROTOCOL).hasCauseInstanceOf(IllegalArgumentException.class); + } + + @SneakyThrows + @Test + void getServiceRootUrl_rejectsUrlThatIsNull() { + assertThat(DataCloudConnection.acceptsUrl(null)).isFalse(); + val ex = Assertions.assertThrows( + DataCloudJDBCException.class, () -> DataCloudConnection.getServiceRootUrl(null)); + assertThat(ex).hasMessage(ILLEGAL_CONNECTION_PROTOCOL); + } + + @Test + void testCreateStatement() { + try (val connection = sut()) { + val statement = connection.createStatement(); + assertThat(statement).isInstanceOf(DataCloudStatement.class); + } + } + + @Test + void testNullUrlThrows() { + Assertions.assertThrows(DataCloudJDBCException.class, () -> DataCloudConnection.of(null, new Properties())); + } + + @Test + void testUnsupportedPrefixUrlNotAllowed() { + val ex = assertThrows(DataCloudJDBCException.class, () -> DataCloudConnection.of("fake-url", new Properties())); + assertThat(ex).hasMessage(ILLEGAL_CONNECTION_PROTOCOL); + } + + @Test + void testClose() { + try (val connection = sut()) { + assertThat(connection.isClosed()).isFalse(); + connection.close(); + assertThat(connection.isClosed()).isTrue(); + } + } + + @Test + void testGetMetadata() { + try (val connection = sut()) { + assertThat(connection.getMetaData()).isInstanceOf(DataCloudDatabaseMetadata.class); + } + } + + @Test + void testGetTransactionIsolation() { + try (val connection = sut()) { + assertThat(connection.getTransactionIsolation()).isEqualTo(Connection.TRANSACTION_NONE); + } + } + + @Test + void 
testIsValidNegativeTimeoutThrows() { + try (val connection = sut()) { + val ex = assertThrows(DataCloudJDBCException.class, () -> connection.isValid(-1)); + assertThat(ex).hasMessage("Invalid timeout value: -1").hasNoCause(); + } + } + + @Test + @SneakyThrows + void testIsValid() { + try (val connection = sut()) { + assertThat(connection.isValid(200)).isTrue(); + } + } + + @Test + @SneakyThrows + void testConnectionUnwrap() { + val connection = sut(); + DataCloudConnection query_conn = connection.unwrap(DataCloudConnection.class); + assertThat(connection.isWrapperFor(DataCloudConnection.class)).isTrue(); + assertThrows(DataCloudJDBCException.class, () -> connection.unwrap(String.class)); + connection.close(); + } + + private DataCloudConnection sut() { + return DataCloudConnection.builder() + .executor(hyperGrpcClient) + .tokenProcessor(mockSession) + .properties(properties) + .build(); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudDatabaseMetadataTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudDatabaseMetadataTest.java new file mode 100644 index 0000000..4e59e1d --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudDatabaseMetadataTest.java @@ -0,0 +1,1553 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core; + +import static com.salesforce.datacloud.jdbc.auth.PropertiesUtils.propertiesForPassword; +import static com.salesforce.datacloud.jdbc.auth.PropertiesUtils.randomString; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.*; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.salesforce.datacloud.jdbc.auth.AuthenticationSettings; +import com.salesforce.datacloud.jdbc.auth.DataCloudToken; +import com.salesforce.datacloud.jdbc.auth.OAuthToken; +import com.salesforce.datacloud.jdbc.auth.TokenProcessor; +import com.salesforce.datacloud.jdbc.auth.model.DataCloudTokenResponse; +import com.salesforce.datacloud.jdbc.auth.model.OAuthTokenResponse; +import com.salesforce.datacloud.jdbc.config.KeywordResources; +import com.salesforce.datacloud.jdbc.core.model.DataspaceResponse; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.http.ClientBuilder; +import com.salesforce.datacloud.jdbc.util.Constants; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.stream.Collectors; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.apache.commons.lang3.StringUtils; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; + 
+@ExtendWith(MockitoExtension.class) +@Slf4j +public class DataCloudDatabaseMetadataTest { + static final int NUM_TABLE_METADATA_COLUMNS = 10; + static final int NUM_COLUMN_METADATA_COLUMNS = 24; + static final int NUM_SCHEMA_METADATA_COLUMNS = 2; + static final int NUM_TABLE_TYPES_METADATA_COLUMNS = 1; + static final int NUM_CATALOG_METADATA_COLUMNS = 1; + private static final String FAKE_TOKEN = + "eyJraWQiOiJDT1JFLjAwRE9LMDAwMDAwOVp6ci4xNzE4MDUyMTU0NDIyIiwidHlwIjoiSldUIiwiYWxnIjoiRVMyNTYifQ.eyJzdWIiOiJodHRwczovL2xvZ2luLnRlc3QxLnBjLXJuZC5zYWxlc2ZvcmNlLmNvbS9pZC8wMERPSzAwMDAwMDlaenIyQUUvMDA1T0swMDAwMDBVeTkxWUFDIiwic2NwIjoiY2RwX3Byb2ZpbGVfYXBpIGNkcF9pbmdlc3RfYXBpIGNkcF9pZGVudGl0eXJlc29sdXRpb25fYXBpIGNkcF9zZWdtZW50X2FwaSBjZHBfcXVlcnlfYXBpIGNkcF9hcGkiLCJpc3MiOiJodHRwczovL2xvZ2luLnRlc3QxLnBjLXJuZC5zYWxlc2ZvcmNlLmNvbS8iLCJvcmdJZCI6IjAwRE9LMDAwMDAwOVp6ciIsImlzc3VlclRlbmFudElkIjoiY29yZS9mYWxjb250ZXN0MS1jb3JlNG9yYTE1LzAwRE9LMDAwMDAwOVp6cjJBRSIsInNmYXBwaWQiOiIzTVZHOVhOVDlUbEI3VmtZY0tIVm5sUUZzWEd6cUJuMGszUC5zNHJBU0I5V09oRU1OdkgyNzNpM1NFRzF2bWl3WF9YY2NXOUFZbHA3VnJnQ3BGb0ZXIiwiYXVkaWVuY2VUZW5hbnRJZCI6ImEzNjAvZmFsY29uZGV2L2E2ZDcyNmE3M2Y1MzQzMjdhNmE4ZTJlMGYzY2MzODQwIiwiY3VzdG9tX2F0dHJpYnV0ZXMiOnsiZGF0YXNwYWNlIjoiZGVmYXVsdCJ9LCJhdWQiOiJhcGkuYTM2MC5zYWxlc2ZvcmNlLmNvbSIsIm5iZiI6MTcyMDczMTAyMSwic2ZvaWQiOiIwMERPSzAwMDAwMDlaenIiLCJzZnVpZCI6IjAwNU9LMDAwMDAwVXk5MSIsImV4cCI6MTcyMDczODI4MCwiaWF0IjoxNzIwNzMxMDgxLCJqdGkiOiIwYjYwMzc4OS1jMGI2LTQwZTMtYmIzNi03NDQ3MzA2MzAxMzEifQ.lXgeAhJIiGoxgNpBi0W5oBWyn2_auB2bFxxajGuK6DMHlkqDhHJAlFN_uf6QPSjGSJCh5j42Ow5SrEptUDJwmQ"; + private static final String FAKE_TENANT_ID = "a360/falcondev/a6d726a73f534327a6a8e2e0f3cc3840"; + + @Mock + DataCloudStatement dataCloudStatement; + + @Mock + TokenProcessor tokenProcessor; + + @Mock + ResultSet resultSetMock; + + @Mock + AuthenticationSettings authenticationSettings; + + DataCloudDatabaseMetadata dataCloudDatabaseMetadata; + + @BeforeEach + public void beforeEach() { + dataCloudStatement = 
mock(DataCloudStatement.class); + tokenProcessor = mock(TokenProcessor.class); + val properties = propertiesForPassword("un", "pw"); + val client = ClientBuilder.buildOkHttpClient(properties); + authenticationSettings = mock(AuthenticationSettings.class); + dataCloudDatabaseMetadata = new DataCloudDatabaseMetadata( + dataCloudStatement, Optional.ofNullable(tokenProcessor), client, "loginURL", "userName"); + } + + @Test + public void testAllProceduresAreCallable() { + assertThat(dataCloudDatabaseMetadata.allProceduresAreCallable()).isFalse(); + } + + @Test + public void testAllTablesAreSelectable() { + assertThat(dataCloudDatabaseMetadata.allTablesAreSelectable()).isTrue(); + } + + @Test + public void testGetURL() { + assertThat(dataCloudDatabaseMetadata.getURL()).isEqualTo("loginURL"); + } + + @Test + public void testGetUserName() { + assertThat(dataCloudDatabaseMetadata.getUserName()).isEqualTo("userName"); + } + + @Test + public void testIsReadOnly() { + assertThat(dataCloudDatabaseMetadata.isReadOnly()).isTrue(); + } + + @Test + public void testNullsAreSortedHigh() { + assertThat(dataCloudDatabaseMetadata.nullsAreSortedHigh()).isFalse(); + } + + @Test + public void testNullsAreSortedLow() { + assertThat(dataCloudDatabaseMetadata.nullsAreSortedLow()).isTrue(); + } + + @Test + public void testNullsAreSortedAtStart() { + assertThat(dataCloudDatabaseMetadata.nullsAreSortedAtStart()).isFalse(); + } + + @Test + public void testNullsAreSortedAtEnd() { + assertThat(dataCloudDatabaseMetadata.nullsAreSortedAtEnd()).isFalse(); + } + + @Test + public void testGetDatabaseProductName() { + assertThat(dataCloudDatabaseMetadata.getDatabaseProductName()).isEqualTo(Constants.DATABASE_PRODUCT_NAME); + } + + @Test + public void testGetDatabaseProductVersion() { + assertThat(dataCloudDatabaseMetadata.getDatabaseProductVersion()).isEqualTo(Constants.DATABASE_PRODUCT_VERSION); + } + + @Test + public void testGetDriverName() { + 
assertThat(dataCloudDatabaseMetadata.getDriverName()).isEqualTo(Constants.DRIVER_NAME); + } + + @Test + public void testGetDriverVersion() { + assertThat(dataCloudDatabaseMetadata.getDriverVersion()).isEqualTo(Constants.DRIVER_VERSION); + } + + @Test + public void testGetDriverMajorVersion() { + assertThat(dataCloudDatabaseMetadata.getDriverMajorVersion()).isEqualTo(1); + } + + @Test + public void testGetDriverMinorVersion() { + assertThat(dataCloudDatabaseMetadata.getDriverMinorVersion()).isEqualTo(0); + } + + @Test + public void testUsesLocalFiles() { + assertThat(dataCloudDatabaseMetadata.usesLocalFiles()).isFalse(); + } + + @Test + public void testUsesLocalFilePerTable() { + assertThat(dataCloudDatabaseMetadata.usesLocalFilePerTable()).isFalse(); + } + + @Test + public void testSupportsMixedCaseIdentifiers() { + assertThat(dataCloudDatabaseMetadata.supportsMixedCaseIdentifiers()).isFalse(); + } + + @Test + public void testStoresUpperCaseIdentifiers() { + assertThat(dataCloudDatabaseMetadata.storesUpperCaseIdentifiers()).isFalse(); + } + + @Test + public void testStoresLowerCaseIdentifiers() { + assertThat(dataCloudDatabaseMetadata.storesLowerCaseIdentifiers()).isTrue(); + } + + @Test + public void testStoresMixedCaseIdentifiers() { + assertThat(dataCloudDatabaseMetadata.storesMixedCaseIdentifiers()).isFalse(); + } + + @Test + public void testSupportsMixedCaseQuotedIdentifiers() { + assertThat(dataCloudDatabaseMetadata.supportsMixedCaseQuotedIdentifiers()) + .isTrue(); + } + + @Test + public void testStoresUpperCaseQuotedIdentifiers() { + assertThat(dataCloudDatabaseMetadata.storesUpperCaseQuotedIdentifiers()).isFalse(); + } + + @Test + public void testStoresLowerCaseQuotedIdentifiers() { + assertThat(dataCloudDatabaseMetadata.storesLowerCaseQuotedIdentifiers()).isFalse(); + } + + @Test + public void testStoresMixedCaseQuotedIdentifiers() { + assertThat(dataCloudDatabaseMetadata.storesMixedCaseQuotedIdentifiers()).isFalse(); + } + + @Test + public void 
testGetIdentifierQuoteString() { + assertThat(dataCloudDatabaseMetadata.getIdentifierQuoteString()).isEqualTo("\""); + } + /** + * The expected output of getSqlKeywords is an alphabetized, all-caps, comma-delimited String made up of all the + * keywords found in Hyper's SQL Lexer, excluding: those that are also SQL:2003 keywords + * pseudo-tokens like "<=", ">=", "==", "=>" tokens ending with "_la" Hyper-Script tokens like "break", "continue", + * "throw", "var", "while", "yield" To add new keywords, adjust the file + * src/main/resources/keywords/hyper_sql_lexer_keywords.txt + */ + @Test + public void testGetSQLKeywords() { + val actual = dataCloudDatabaseMetadata.getSQLKeywords().split(","); + assertThat(actual.length).isGreaterThan(250).isLessThan(300); + KeywordResources.SQL_2003_KEYWORDS.forEach(k -> assertThat(actual).doesNotContain(k)); + val sorted = Arrays.stream(actual).sorted().collect(Collectors.toList()); + val uppercase = Arrays.stream(actual).map(String::toUpperCase).collect(Collectors.toList()); + val distinct = Arrays.stream(actual).distinct().collect(Collectors.toList()); + assertThat(sorted) + .withFailMessage("SQL Keywords should be in alphabetical order.") + .containsExactly(actual); + assertThat(uppercase) + .withFailMessage("SQL Keywords should contain uppercase values.") + .containsExactly(actual); + assertThat(distinct) + .withFailMessage("SQL Keywords should have no duplicates.") + .containsExactly(actual); + } + + @Test + public void testGetNumericFunctions() { + assertThat(dataCloudDatabaseMetadata.getNumericFunctions()).isNull(); + } + + @Test + public void testGetStringFunctions() { + assertThat(dataCloudDatabaseMetadata.getStringFunctions()).isNull(); + } + + @Test + public void testGetSystemFunctions() { + assertThat(dataCloudDatabaseMetadata.getSystemFunctions()).isNull(); + } + + @Test + public void testGetTimeDateFunctions() { + assertThat(dataCloudDatabaseMetadata.getTimeDateFunctions()).isNull(); + } + + @Test + public void 
testGetSearchStringEscape() { + assertThat(dataCloudDatabaseMetadata.getSearchStringEscape()).isEqualTo("\\"); + } + + @Test + public void testGetExtraNameCharacters() { + assertThat(dataCloudDatabaseMetadata.getExtraNameCharacters()).isNull(); + } + + @Test + public void testSupportsAlterTableWithAddColumn() { + assertThat(dataCloudDatabaseMetadata.supportsAlterTableWithAddColumn()).isFalse(); + } + + @Test + public void testSupportsAlterTableWithDropColumn() { + assertThat(dataCloudDatabaseMetadata.supportsAlterTableWithDropColumn()).isFalse(); + } + + @Test + public void testSupportsColumnAliasing() { + assertThat(dataCloudDatabaseMetadata.supportsColumnAliasing()).isTrue(); + } + + @Test + public void testNullPlusNonNullIsNull() { + assertThat(dataCloudDatabaseMetadata.nullPlusNonNullIsNull()).isFalse(); + } + + @Test + public void testSupportsConvert() { + assertThat(dataCloudDatabaseMetadata.supportsConvert()).isTrue(); + } + + @Test + public void testSupportsConvertFromTypeToType() { + assertThat(dataCloudDatabaseMetadata.supportsConvert(1, 1)).isTrue(); + } + + @Test + public void testSupportsTableCorrelationNames() { + assertThat(dataCloudDatabaseMetadata.supportsTableCorrelationNames()).isTrue(); + } + + @Test + public void testSupportsDifferentTableCorrelationNames() { + assertThat(dataCloudDatabaseMetadata.supportsDifferentTableCorrelationNames()) + .isFalse(); + } + + @Test + public void testSupportsExpressionsInOrderBy() { + assertThat(dataCloudDatabaseMetadata.supportsExpressionsInOrderBy()).isTrue(); + } + + @Test + public void testSupportsOrderByUnrelated() { + assertThat(dataCloudDatabaseMetadata.supportsOrderByUnrelated()).isTrue(); + } + + @Test + public void testSupportsGroupBy() { + assertThat(dataCloudDatabaseMetadata.supportsGroupBy()).isTrue(); + } + + @Test + public void testSupportsGroupByUnrelated() { + assertThat(dataCloudDatabaseMetadata.supportsGroupByUnrelated()).isTrue(); + } + + @Test + public void testSupportsGroupByBeyondSelect() 
{ + assertThat(dataCloudDatabaseMetadata.supportsGroupByBeyondSelect()).isTrue(); + } + + @Test + public void testSupportsLikeEscapeClause() { + assertThat(dataCloudDatabaseMetadata.supportsLikeEscapeClause()).isTrue(); + } + + @Test + public void testSupportsMultipleResultSets() { + assertThat(dataCloudDatabaseMetadata.supportsMultipleResultSets()).isFalse(); + } + + @Test + public void testSupportsMultipleTransactions() { + assertThat(dataCloudDatabaseMetadata.supportsMultipleTransactions()).isFalse(); + } + + @Test + public void testSupportsNonNullableColumns() { + assertThat(dataCloudDatabaseMetadata.supportsNonNullableColumns()).isTrue(); + } + + @Test + public void testSupportsMinimumSQLGrammar() { + assertThat(dataCloudDatabaseMetadata.supportsMinimumSQLGrammar()).isTrue(); + } + + @Test + public void testSupportsCoreSQLGrammar() { + assertThat(dataCloudDatabaseMetadata.supportsCoreSQLGrammar()).isFalse(); + } + + @Test + public void testSupportsExtendedSQLGrammar() { + assertThat(dataCloudDatabaseMetadata.supportsExtendedSQLGrammar()).isFalse(); + } + + @Test + public void testSupportsANSI92EntryLevelSQL() { + assertThat(dataCloudDatabaseMetadata.supportsANSI92EntryLevelSQL()).isTrue(); + } + + @Test + public void testSupportsANSI92IntermediateSQL() { + assertThat(dataCloudDatabaseMetadata.supportsANSI92IntermediateSQL()).isTrue(); + } + + @Test + public void testSupportsANSI92FullSQL() { + assertThat(dataCloudDatabaseMetadata.supportsANSI92FullSQL()).isTrue(); + } + + @Test + public void testSupportsIntegrityEnhancementFacility() { + assertThat(dataCloudDatabaseMetadata.supportsIntegrityEnhancementFacility()) + .isFalse(); + } + + @Test + public void testSupportsOuterJoins() { + assertThat(dataCloudDatabaseMetadata.supportsOuterJoins()).isTrue(); + } + + @Test + public void testSupportsFullOuterJoins() { + assertThat(dataCloudDatabaseMetadata.supportsFullOuterJoins()).isTrue(); + } + + @Test + public void testSupportsLimitedOuterJoins() { + 
assertThat(dataCloudDatabaseMetadata.supportsLimitedOuterJoins()).isTrue(); + } + + @Test + public void testGetSchemaTerm() { + assertThat(dataCloudDatabaseMetadata.getSchemaTerm()).isEqualTo("schema"); + } + + @Test + public void testGetProcedureTerm() { + assertThat(dataCloudDatabaseMetadata.getProcedureTerm()).isEqualTo("procedure"); + } + + @Test + public void testGetCatalogTerm() { + assertThat(dataCloudDatabaseMetadata.getCatalogTerm()).isEqualTo("database"); + } + + @Test + public void testIsCatalogAtStart() { + assertThat(dataCloudDatabaseMetadata.isCatalogAtStart()).isTrue(); + } + + @Test + public void testGetCatalogSeparator() { + assertThat(dataCloudDatabaseMetadata.getCatalogSeparator()).isEqualTo("."); + } + + @Test + public void testSupportsSchemasInDataManipulation() { + assertThat(dataCloudDatabaseMetadata.supportsSchemasInDataManipulation()) + .isFalse(); + } + + @Test + public void testSupportsSchemasInProcedureCalls() { + assertThat(dataCloudDatabaseMetadata.supportsSchemasInProcedureCalls()).isFalse(); + } + + @Test + public void testSupportsSchemasInTableDefinitions() { + assertThat(dataCloudDatabaseMetadata.supportsSchemasInTableDefinitions()) + .isFalse(); + } + + @Test + public void testSupportsSchemasInIndexDefinitions() { + assertThat(dataCloudDatabaseMetadata.supportsSchemasInIndexDefinitions()) + .isFalse(); + } + + @Test + public void testSupportsSchemasInPrivilegeDefinitions() { + assertThat(dataCloudDatabaseMetadata.supportsSchemasInPrivilegeDefinitions()) + .isFalse(); + } + + @Test + public void testSupportsCatalogsInDataManipulation() { + assertThat(dataCloudDatabaseMetadata.supportsCatalogsInDataManipulation()) + .isFalse(); + } + + @Test + public void testSupportsCatalogsInProcedureCalls() { + assertThat(dataCloudDatabaseMetadata.supportsCatalogsInProcedureCalls()).isFalse(); + } + + @Test + public void testSupportsCatalogsInTableDefinitions() { + assertThat(dataCloudDatabaseMetadata.supportsCatalogsInTableDefinitions()) + 
.isFalse(); + } + + @Test + public void testSupportsCatalogsInIndexDefinitions() { + assertThat(dataCloudDatabaseMetadata.supportsCatalogsInIndexDefinitions()) + .isFalse(); + } + + @Test + public void testSupportsCatalogsInPrivilegeDefinitions() { + assertThat(dataCloudDatabaseMetadata.supportsCatalogsInPrivilegeDefinitions()) + .isFalse(); + } + + @Test + public void testSupportsPositionedDelete() { + assertThat(dataCloudDatabaseMetadata.supportsPositionedDelete()).isFalse(); + } + + @Test + public void testSupportsPositionedUpdate() { + assertThat(dataCloudDatabaseMetadata.supportsPositionedUpdate()).isFalse(); + } + + @Test + public void testSupportsSelectForUpdate() { + assertThat(dataCloudDatabaseMetadata.supportsSelectForUpdate()).isFalse(); + } + + @Test + public void testSupportsStoredProcedures() { + assertThat(dataCloudDatabaseMetadata.supportsStoredProcedures()).isFalse(); + } + + @Test + public void testSupportsSubqueriesInComparisons() { + assertThat(dataCloudDatabaseMetadata.supportsSubqueriesInComparisons()).isTrue(); + } + + @Test + public void testSupportsSubqueriesInExists() { + assertThat(dataCloudDatabaseMetadata.supportsSubqueriesInExists()).isTrue(); + } + + @Test + public void testSupportsSubqueriesInIns() { + assertThat(dataCloudDatabaseMetadata.supportsSubqueriesInIns()).isTrue(); + } + + @Test + public void testSupportsSubqueriesInQuantifieds() { + assertThat(dataCloudDatabaseMetadata.supportsSubqueriesInQuantifieds()).isTrue(); + } + + @Test + public void testSupportsCorrelatedSubqueries() { + assertThat(dataCloudDatabaseMetadata.supportsCorrelatedSubqueries()).isTrue(); + } + + @Test + public void testSupportsUnion() { + assertThat(dataCloudDatabaseMetadata.supportsUnion()).isTrue(); + } + + @Test + public void testSupportsUnionAll() { + assertThat(dataCloudDatabaseMetadata.supportsUnionAll()).isTrue(); + } + + @Test + public void testSupportsOpenCursorsAcrossCommit() { + 
assertThat(dataCloudDatabaseMetadata.supportsOpenCursorsAcrossCommit()).isFalse(); + } + + @Test + public void testSupportsOpenCursorsAcrossRollback() { + assertThat(dataCloudDatabaseMetadata.supportsOpenCursorsAcrossRollback()) + .isFalse(); + } + + @Test + public void testSupportsOpenStatementsAcrossCommit() { + assertThat(dataCloudDatabaseMetadata.supportsOpenStatementsAcrossCommit()) + .isFalse(); + } + + @Test + public void testSupportsOpenStatementsAcrossRollback() { + assertThat(dataCloudDatabaseMetadata.supportsOpenStatementsAcrossRollback()) + .isFalse(); + } + + @Test + public void testGetMaxBinaryLiteralLength() { + assertThat(dataCloudDatabaseMetadata.getMaxBinaryLiteralLength()).isEqualTo(0); + } + + @Test + public void testGetMaxCharLiteralLength() { + assertThat(dataCloudDatabaseMetadata.getMaxCharLiteralLength()).isEqualTo(0); + } + + @Test + public void testGetMaxColumnNameLength() { + assertThat(dataCloudDatabaseMetadata.getMaxColumnNameLength()).isEqualTo(0); + } + + @Test + public void testGetMaxColumnsInGroupBy() { + assertThat(dataCloudDatabaseMetadata.getMaxColumnsInGroupBy()).isEqualTo(0); + } + + @Test + public void testGetMaxColumnsInIndex() { + assertThat(dataCloudDatabaseMetadata.getMaxColumnsInIndex()).isEqualTo(0); + } + + @Test + public void testGetMaxColumnsInOrderBy() { + assertThat(dataCloudDatabaseMetadata.getMaxColumnsInOrderBy()).isEqualTo(0); + } + + @Test + public void testGetMaxColumnsInSelect() { + assertThat(dataCloudDatabaseMetadata.getMaxColumnsInSelect()).isEqualTo(0); + } + + @Test + public void testGetMaxColumnsInTable() { + assertThat(dataCloudDatabaseMetadata.getMaxColumnsInTable()).isEqualTo(0); + } + + @Test + public void testGetMaxConnections() { + assertThat(dataCloudDatabaseMetadata.getMaxConnections()).isEqualTo(0); + } + + @Test + public void testGetMaxCursorNameLength() { + assertThat(dataCloudDatabaseMetadata.getMaxCursorNameLength()).isEqualTo(0); + } + + @Test + public void testGetMaxIndexLength() { + 
assertThat(dataCloudDatabaseMetadata.getMaxIndexLength()).isEqualTo(0); + } + + @Test + public void testGetMaxSchemaNameLength() { + assertThat(dataCloudDatabaseMetadata.getMaxSchemaNameLength()).isEqualTo(0); + } + + @Test + public void testGetMaxProcedureNameLength() { + assertThat(dataCloudDatabaseMetadata.getMaxProcedureNameLength()).isEqualTo(0); + } + + @Test + public void testGetMaxCatalogNameLength() { + assertThat(dataCloudDatabaseMetadata.getMaxCatalogNameLength()).isEqualTo(0); + } + + @Test + public void testGetMaxRowSize() { + assertThat(dataCloudDatabaseMetadata.getMaxRowSize()).isEqualTo(0); + } + + @Test + public void testDoesMaxRowSizeIncludeBlobs() { + assertThat(dataCloudDatabaseMetadata.doesMaxRowSizeIncludeBlobs()).isFalse(); + } + + @Test + public void testGetMaxStatementLength() { + assertThat(dataCloudDatabaseMetadata.getMaxStatementLength()).isEqualTo(0); + } + + @Test + public void testGetMaxStatements() { + assertThat(dataCloudDatabaseMetadata.getMaxStatements()).isEqualTo(0); + } + + @Test + public void testGetMaxTableNameLength() { + assertThat(dataCloudDatabaseMetadata.getMaxTableNameLength()).isEqualTo(0); + } + + @Test + public void testGetMaxTablesInSelect() { + assertThat(dataCloudDatabaseMetadata.getMaxTablesInSelect()).isEqualTo(0); + } + + @Test + public void testGetMaxUserNameLength() { + assertThat(dataCloudDatabaseMetadata.getMaxUserNameLength()).isEqualTo(0); + } + + @Test + public void testGetDefaultTransactionIsolation() { + assertThat(dataCloudDatabaseMetadata.getDefaultTransactionIsolation()) + .isEqualTo(Connection.TRANSACTION_SERIALIZABLE); + } + + @Test + public void testSupportsTransactions() { + assertThat(dataCloudDatabaseMetadata.supportsTransactions()).isFalse(); + } + + @Test + public void testSupportsTransactionIsolationLevel() { + assertThat(dataCloudDatabaseMetadata.supportsTransactionIsolationLevel(1)) + .isFalse(); + } + + @Test + public void testSupportsDataDefinitionAndDataManipulationTransactions() { + 
assertThat(dataCloudDatabaseMetadata.supportsDataDefinitionAndDataManipulationTransactions()) + .isFalse(); + } + + @Test + public void testSupportsDataManipulationTransactionsOnly() { + assertThat(dataCloudDatabaseMetadata.supportsDataManipulationTransactionsOnly()) + .isFalse(); + } + + @Test + public void testDataDefinitionCausesTransactionCommit() { + assertThat(dataCloudDatabaseMetadata.dataDefinitionCausesTransactionCommit()) + .isFalse(); + } + + @Test + public void testDataDefinitionIgnoredInTransactions() { + assertThat(dataCloudDatabaseMetadata.dataDefinitionIgnoredInTransactions()) + .isFalse(); + } + + @Test + public void testGetProcedures() { + assertThat(dataCloudDatabaseMetadata.getProcedures(StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY)) + .isNull(); + } + + @Test + public void testGetProcedureColumns() { + assertThat(dataCloudDatabaseMetadata.getProcedureColumns( + StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY)) + .isNull(); + } + + @Test + @SneakyThrows + public void testGetTables() { + String[] types = new String[] {}; + Mockito.when(resultSetMock.next()) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(false); + Mockito.when(dataCloudStatement.executeQuery(anyString())).thenReturn(resultSetMock); + + ResultSet resultSet = dataCloudDatabaseMetadata.getTables(null, "schemaName", "tableName", types); + assertThat(resultSet.getMetaData().getColumnCount()).isEqualTo(NUM_TABLE_METADATA_COLUMNS); + assertThat(resultSet.getMetaData().getColumnName(1)).isEqualTo("TABLE_CAT"); + assertThat(resultSet.getMetaData().getColumnName(2)).isEqualTo("TABLE_SCHEM"); + assertThat(resultSet.getMetaData().getColumnName(3)).isEqualTo("TABLE_NAME"); + assertThat(resultSet.getMetaData().getColumnName(4)).isEqualTo("TABLE_TYPE"); + 
assertThat(resultSet.getMetaData().getColumnName(5)).isEqualTo("REMARKS"); + assertThat(resultSet.getMetaData().getColumnName(6)).isEqualTo("TYPE_CAT"); + assertThat(resultSet.getMetaData().getColumnName(7)).isEqualTo("TYPE_SCHEM"); + assertThat(resultSet.getMetaData().getColumnName(8)).isEqualTo("TYPE_NAME"); + assertThat(resultSet.getMetaData().getColumnName(9)).isEqualTo("SELF_REFERENCING_COL_NAME"); + assertThat(resultSet.getMetaData().getColumnName(10)).isEqualTo("REF_GENERATION"); + + assertThat(resultSet.getMetaData().getColumnTypeName(1)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(2)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(3)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(4)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(5)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(6)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(7)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(8)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(9)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(10)).isEqualTo("TEXT"); + } + + @Test + @SneakyThrows + public void testGetTablesNullValues() { + Mockito.when(dataCloudStatement.executeQuery(anyString())).thenReturn(resultSetMock); + Mockito.when(resultSetMock.next()) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(false); + Mockito.when(dataCloudStatement.executeQuery(anyString())).thenReturn(resultSetMock); + + ResultSet resultSet = dataCloudDatabaseMetadata.getTables(null, null, null, null); + assertThat(resultSet.getMetaData().getColumnCount()).isEqualTo(10); + 
assertThat(resultSet.getMetaData().getColumnName(1)).isEqualTo("TABLE_CAT"); + assertThat(resultSet.getMetaData().getColumnName(2)).isEqualTo("TABLE_SCHEM"); + assertThat(resultSet.getMetaData().getColumnName(3)).isEqualTo("TABLE_NAME"); + assertThat(resultSet.getMetaData().getColumnName(4)).isEqualTo("TABLE_TYPE"); + assertThat(resultSet.getMetaData().getColumnName(5)).isEqualTo("REMARKS"); + assertThat(resultSet.getMetaData().getColumnName(6)).isEqualTo("TYPE_CAT"); + assertThat(resultSet.getMetaData().getColumnName(7)).isEqualTo("TYPE_SCHEM"); + assertThat(resultSet.getMetaData().getColumnName(8)).isEqualTo("TYPE_NAME"); + assertThat(resultSet.getMetaData().getColumnName(9)).isEqualTo("SELF_REFERENCING_COL_NAME"); + assertThat(resultSet.getMetaData().getColumnName(10)).isEqualTo("REF_GENERATION"); + + assertThat(resultSet.getMetaData().getColumnTypeName(1)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(2)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(3)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(4)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(5)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(6)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(7)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(8)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(9)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(10)).isEqualTo("TEXT"); + } + + @Test + @SneakyThrows + public void testGetTablesEmptyValues() { + String[] emptyTypes = new String[] {}; + Mockito.when(dataCloudStatement.executeQuery(anyString())).thenReturn(resultSetMock); + Mockito.when(resultSetMock.next()) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + 
.thenReturn(true) + .thenReturn(true) + .thenReturn(false); + Mockito.when(dataCloudStatement.executeQuery(anyString())).thenReturn(resultSetMock); + + ResultSet resultSet = + dataCloudDatabaseMetadata.getTables(null, StringUtils.EMPTY, StringUtils.EMPTY, emptyTypes); + assertThat(resultSet.getMetaData().getColumnCount()).isEqualTo(10); + assertThat(resultSet.getMetaData().getColumnName(1)).isEqualTo("TABLE_CAT"); + assertThat(resultSet.getMetaData().getColumnName(2)).isEqualTo("TABLE_SCHEM"); + assertThat(resultSet.getMetaData().getColumnName(3)).isEqualTo("TABLE_NAME"); + assertThat(resultSet.getMetaData().getColumnName(4)).isEqualTo("TABLE_TYPE"); + assertThat(resultSet.getMetaData().getColumnName(5)).isEqualTo("REMARKS"); + assertThat(resultSet.getMetaData().getColumnName(6)).isEqualTo("TYPE_CAT"); + assertThat(resultSet.getMetaData().getColumnName(7)).isEqualTo("TYPE_SCHEM"); + assertThat(resultSet.getMetaData().getColumnName(8)).isEqualTo("TYPE_NAME"); + assertThat(resultSet.getMetaData().getColumnName(9)).isEqualTo("SELF_REFERENCING_COL_NAME"); + assertThat(resultSet.getMetaData().getColumnName(10)).isEqualTo("REF_GENERATION"); + + assertThat(resultSet.getMetaData().getColumnTypeName(1)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(2)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(3)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(4)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(5)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(6)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(7)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(8)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(9)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(10)).isEqualTo("TEXT"); + } + + @SneakyThrows + @Test + public void testGetDataspaces() 
{ + val mapper = new ObjectMapper(); + val oAuthTokenResponse = new OAuthTokenResponse(); + val accessToken = UUID.randomUUID().toString(); + val dataspaceAttributeName = randomString(); + oAuthTokenResponse.setToken(accessToken); + val dataspaceResponse = new DataspaceResponse(); + val dataspaceAttributes = new DataspaceResponse.DataSpaceAttributes(); + dataspaceAttributes.setName(dataspaceAttributeName); + dataspaceResponse.setRecords(List.of(dataspaceAttributes)); + + try (val server = new MockWebServer()) { + server.start(); + oAuthTokenResponse.setInstanceUrl(server.url("").toString()); + Mockito.when(tokenProcessor.getOAuthToken()).thenReturn(OAuthToken.of(oAuthTokenResponse)); + + server.enqueue(new MockResponse().setBody(mapper.writeValueAsString(dataspaceResponse))); + val actual = dataCloudDatabaseMetadata.getDataspaces(); + List expected = List.of(dataspaceAttributeName); + assertThat(actual).isEqualTo(expected); + + val actualRequest = server.takeRequest(); + val query = "SELECT+name+from+Dataspace"; + assertThat(actualRequest.getMethod()).isEqualTo("GET"); + assertThat(actualRequest.getRequestUrl()).isEqualTo(server.url("services/data/v61.0/query/?q=" + query)); + assertThat(actualRequest.getBody().readUtf8()).isBlank(); + assertThat(actualRequest.getHeader("Authorization")).isEqualTo("Bearer " + accessToken); + assertThat(actualRequest.getHeader("Content-Type")).isEqualTo("application/json"); + assertThat(actualRequest.getHeader("User-Agent")).isEqualTo("cdp/jdbc"); + assertThat(actualRequest.getHeader("enable-stream-flow")).isEqualTo("false"); + } + } + + @SneakyThrows + @Test + public void testGetDataspacesThrowsExceptionWhenCallFails() { + val oAuthTokenResponse = new OAuthTokenResponse(); + val accessToken = UUID.randomUUID().toString(); + val dataspaceAttributeName = randomString(); + oAuthTokenResponse.setToken(accessToken); + val dataspaceResponse = new DataspaceResponse(); + val dataspaceAttributes = new 
DataspaceResponse.DataSpaceAttributes(); + dataspaceAttributes.setName(dataspaceAttributeName); + dataspaceResponse.setRecords(List.of(dataspaceAttributes)); + + try (val server = new MockWebServer()) { + server.start(); + oAuthTokenResponse.setInstanceUrl(server.url("").toString()); + Mockito.when(tokenProcessor.getOAuthToken()).thenReturn(OAuthToken.of(oAuthTokenResponse)); + + server.enqueue(new MockResponse().setResponseCode(500)); + Assertions.assertThrows(DataCloudJDBCException.class, () -> dataCloudDatabaseMetadata.getDataspaces()); + } + } + + @SneakyThrows + @Test + public void testGetCatalogs() { + val mapper = new ObjectMapper(); + val oAuthTokenResponse = new OAuthTokenResponse(); + val dataCloudTokenResponse = new DataCloudTokenResponse(); + val dataSpaceName = randomString(); + oAuthTokenResponse.setToken(FAKE_TOKEN); + dataCloudTokenResponse.setToken(FAKE_TOKEN); + dataCloudTokenResponse.setInstanceUrl(FAKE_TENANT_ID); + dataCloudTokenResponse.setTokenType("token"); + + Mockito.when(tokenProcessor.getDataCloudToken()).thenReturn(DataCloudToken.of(dataCloudTokenResponse)); + Mockito.when(tokenProcessor.getSettings()).thenReturn(authenticationSettings); + Mockito.when(authenticationSettings.getDataspace()).thenReturn(dataSpaceName); + + val actual = dataCloudDatabaseMetadata.getCatalogs(); + assertThat(actual.next()).isTrue(); + assertThat(actual.getString(1)).isEqualTo("lakehouse:" + FAKE_TENANT_ID + ";" + dataSpaceName); + assertThat(actual.getMetaData().getColumnName(1)).isEqualTo("TABLE_CAT"); + assertThat(actual.next()).isFalse(); + } + + @Test + @SneakyThrows + public void testGetTableTypes() { + + ResultSet resultSet = dataCloudDatabaseMetadata.getTableTypes(); + assertThat(resultSet.getMetaData().getColumnCount()).isEqualTo(NUM_TABLE_TYPES_METADATA_COLUMNS); + assertThat(resultSet.getMetaData().getColumnName(1)).isEqualTo("TABLE_TYPE"); + + assertThat(resultSet.getMetaData().getColumnTypeName(1)).isEqualTo("TEXT"); + } + + @Test + @SneakyThrows 
+ public void testGetColumnsContainsCorrectMetadata() { + Mockito.when(resultSetMock.next()) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(false); + Mockito.when(dataCloudStatement.executeQuery(anyString())).thenReturn(resultSetMock); + + ResultSet resultSet = dataCloudDatabaseMetadata.getColumns(null, "schemaName", "tableName", "columnName"); + assertThat(resultSet.getMetaData().getColumnCount()).isEqualTo(NUM_COLUMN_METADATA_COLUMNS); + + assertThat(resultSet.getMetaData().getColumnName(1)).isEqualTo("TABLE_CAT"); + assertThat(resultSet.getMetaData().getColumnTypeName(1)).isEqualTo("TEXT"); + + assertThat(resultSet.getMetaData().getColumnName(2)).isEqualTo("TABLE_SCHEM"); + assertThat(resultSet.getMetaData().getColumnTypeName(2)).isEqualTo("TEXT"); + + assertThat(resultSet.getMetaData().getColumnName(3)).isEqualTo("TABLE_NAME"); + assertThat(resultSet.getMetaData().getColumnTypeName(3)).isEqualTo("TEXT"); + + assertThat(resultSet.getMetaData().getColumnName(4)).isEqualTo("COLUMN_NAME"); + assertThat(resultSet.getMetaData().getColumnTypeName(4)).isEqualTo("TEXT"); + + assertThat(resultSet.getMetaData().getColumnName(5)).isEqualTo("DATA_TYPE"); + assertThat(resultSet.getMetaData().getColumnTypeName(5)).isEqualTo("INTEGER"); + + assertThat(resultSet.getMetaData().getColumnName(6)).isEqualTo("TYPE_NAME"); + assertThat(resultSet.getMetaData().getColumnTypeName(6)).isEqualTo("TEXT"); + + assertThat(resultSet.getMetaData().getColumnName(7)).isEqualTo("COLUMN_SIZE"); + assertThat(resultSet.getMetaData().getColumnTypeName(7)).isEqualTo("INTEGER"); + + assertThat(resultSet.getMetaData().getColumnName(8)).isEqualTo("BUFFER_LENGTH"); + assertThat(resultSet.getMetaData().getColumnTypeName(8)).isEqualTo("INTEGER"); + + assertThat(resultSet.getMetaData().getColumnName(9)).isEqualTo("DECIMAL_DIGITS"); + assertThat(resultSet.getMetaData().getColumnTypeName(9)).isEqualTo("INTEGER"); + + 
assertThat(resultSet.getMetaData().getColumnName(10)).isEqualTo("NUM_PREC_RADIX"); + assertThat(resultSet.getMetaData().getColumnTypeName(10)).isEqualTo("INTEGER"); + + assertThat(resultSet.getMetaData().getColumnName(11)).isEqualTo("NULLABLE"); + assertThat(resultSet.getMetaData().getColumnTypeName(11)).isEqualTo("INTEGER"); + + assertThat(resultSet.getMetaData().getColumnName(12)).isEqualTo("REMARKS"); + assertThat(resultSet.getMetaData().getColumnTypeName(12)).isEqualTo("TEXT"); + + assertThat(resultSet.getMetaData().getColumnName(13)).isEqualTo("COLUMN_DEF"); + assertThat(resultSet.getMetaData().getColumnTypeName(13)).isEqualTo("TEXT"); + + assertThat(resultSet.getMetaData().getColumnName(14)).isEqualTo("SQL_DATA_TYPE"); + assertThat(resultSet.getMetaData().getColumnTypeName(14)).isEqualTo("INTEGER"); + + assertThat(resultSet.getMetaData().getColumnName(15)).isEqualTo("SQL_DATETIME_SUB"); + assertThat(resultSet.getMetaData().getColumnTypeName(15)).isEqualTo("INTEGER"); + + assertThat(resultSet.getMetaData().getColumnName(16)).isEqualTo("CHAR_OCTET_LENGTH"); + assertThat(resultSet.getMetaData().getColumnTypeName(16)).isEqualTo("INTEGER"); + + assertThat(resultSet.getMetaData().getColumnName(17)).isEqualTo("ORDINAL_POSITION"); + assertThat(resultSet.getMetaData().getColumnTypeName(17)).isEqualTo("INTEGER"); + + assertThat(resultSet.getMetaData().getColumnName(18)).isEqualTo("IS_NULLABLE"); + assertThat(resultSet.getMetaData().getColumnTypeName(18)).isEqualTo("TEXT"); + + assertThat(resultSet.getMetaData().getColumnName(19)).isEqualTo("SCOPE_CATALOG"); + assertThat(resultSet.getMetaData().getColumnTypeName(19)).isEqualTo("TEXT"); + + assertThat(resultSet.getMetaData().getColumnName(20)).isEqualTo("SCOPE_SCHEMA"); + assertThat(resultSet.getMetaData().getColumnTypeName(20)).isEqualTo("TEXT"); + + assertThat(resultSet.getMetaData().getColumnName(21)).isEqualTo("SCOPE_TABLE"); + assertThat(resultSet.getMetaData().getColumnTypeName(21)).isEqualTo("TEXT"); + + 
assertThat(resultSet.getMetaData().getColumnName(22)).isEqualTo("SOURCE_DATA_TYPE"); + assertThat(resultSet.getMetaData().getColumnTypeName(22)).isEqualTo("SHORT"); + + assertThat(resultSet.getMetaData().getColumnName(23)).isEqualTo("IS_AUTOINCREMENT"); + assertThat(resultSet.getMetaData().getColumnTypeName(23)).isEqualTo("TEXT"); + + assertThat(resultSet.getMetaData().getColumnName(24)).isEqualTo("IS_GENERATEDCOLUMN"); + assertThat(resultSet.getMetaData().getColumnTypeName(24)).isEqualTo("TEXT"); + } + + @Test + @SneakyThrows + public void testGetColumnsNullValues() { + Mockito.when(dataCloudStatement.executeQuery(anyString())).thenReturn(resultSetMock); + Mockito.when(resultSetMock.next()) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(false); + Mockito.when(dataCloudStatement.executeQuery(anyString())).thenReturn(resultSetMock); + + ResultSet resultSet = dataCloudDatabaseMetadata.getColumns(null, null, null, null); + assertThat(resultSet.getMetaData().getColumnCount()).isEqualTo(24); + assertThat(resultSet.next()).isTrue(); + } + + @Test + @SneakyThrows + public void testGetColumnsEmptyValues() { + Mockito.when(dataCloudStatement.executeQuery(anyString())).thenReturn(resultSetMock); + Mockito.when(resultSetMock.next()) + .thenReturn(true) + .thenReturn(true) + .thenReturn(true) + .thenReturn(false); + Mockito.when(dataCloudStatement.executeQuery(anyString())).thenReturn(resultSetMock); + + ResultSet resultSet = + dataCloudDatabaseMetadata.getColumns(null, StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY); + assertThat(resultSet.getMetaData().getColumnCount()).isEqualTo(24); + assertThat(resultSet.next()).isTrue(); + } + + @Test + public void testTestTest() throws SQLException { + Mockito.when(dataCloudStatement.executeQuery(anyString())).thenReturn(resultSetMock); + Mockito.when(resultSetMock.next()).thenReturn(true).thenReturn(false); + Mockito.when(resultSetMock.getString("nspname")).thenReturn(StringUtils.EMPTY); + 
Mockito.when(resultSetMock.getString("relname")).thenReturn(StringUtils.EMPTY); + Mockito.when(resultSetMock.getString("attname")).thenReturn(StringUtils.EMPTY); + Mockito.when(resultSetMock.getString("attname")).thenReturn(StringUtils.EMPTY); + Mockito.when(resultSetMock.getString("datatype")).thenReturn("TEXT"); + Mockito.when(resultSetMock.getBoolean("attnotnull")).thenReturn(true); + Mockito.when(resultSetMock.getString("description")).thenReturn(StringUtils.EMPTY); + Mockito.when(resultSetMock.getString("adsrc")).thenReturn(StringUtils.EMPTY); + Mockito.when(resultSetMock.getInt("attnum")).thenReturn(1); + Mockito.when(resultSetMock.getBoolean("attnotnull")).thenReturn(true); + Mockito.when(resultSetMock.getString("attidentity")).thenReturn(StringUtils.EMPTY); + Mockito.when(resultSetMock.getString("adsrc")).thenReturn(StringUtils.EMPTY); + Mockito.when(resultSetMock.getString("attgenerated")).thenReturn(StringUtils.EMPTY); + + ResultSet columnResultSet = QueryMetadataUtil.createColumnResultSet( + StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY, dataCloudStatement); + while (columnResultSet.next()) { + assertThat(columnResultSet.getString("TYPE_NAME")).isEqualTo("VARCHAR"); + assertThat(columnResultSet.getString("DATA_TYPE")).isEqualTo("12"); + } + } + + @Test + public void testGetColumnPrivileges() throws SQLException { + assertExpectedEmptyResultSet(dataCloudDatabaseMetadata.getColumnPrivileges( + StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY)); + } + + @Test + @SneakyThrows + public void testGetTablePrivileges() { + assertExpectedEmptyResultSet( + dataCloudDatabaseMetadata.getTablePrivileges(StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY)); + } + + @Test + @SneakyThrows + public void testGetBestRowIdentifier() { + assertExpectedEmptyResultSet(dataCloudDatabaseMetadata.getBestRowIdentifier( + StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY, 1, true)); + } + + @Test + @SneakyThrows + public void 
testGetVersionColumns() { + assertExpectedEmptyResultSet( + dataCloudDatabaseMetadata.getVersionColumns(StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY)); + } + + @Test + @SneakyThrows + public void testGetPrimaryKeys() { + assertExpectedEmptyResultSet( + dataCloudDatabaseMetadata.getPrimaryKeys(StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY)); + } + + @Test + @SneakyThrows + public void testGetImportedKeys() { + assertExpectedEmptyResultSet( + dataCloudDatabaseMetadata.getImportedKeys(StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY)); + } + + @Test + @SneakyThrows + public void testGetExportedKeys() { + assertExpectedEmptyResultSet( + dataCloudDatabaseMetadata.getExportedKeys(StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY)); + } + + @Test + @SneakyThrows + public void testGetCrossReference() { + assertExpectedEmptyResultSet(dataCloudDatabaseMetadata.getCrossReference( + StringUtils.EMPTY, + StringUtils.EMPTY, + StringUtils.EMPTY, + StringUtils.EMPTY, + StringUtils.EMPTY, + StringUtils.EMPTY)); + } + + @Test + @SneakyThrows + public void testGetTypeInfo() { + assertExpectedEmptyResultSet(dataCloudDatabaseMetadata.getTypeInfo()); + } + + @Test + @SneakyThrows + public void testGetIndexInfo() { + assertExpectedEmptyResultSet(dataCloudDatabaseMetadata.getIndexInfo( + StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY, true, true)); + } + + @Test + public void testSupportsResultSetType() { + assertThat(dataCloudDatabaseMetadata.supportsResultSetType(1)).isFalse(); + } + + @Test + public void testSupportsResultSetConcurrency() { + assertThat(dataCloudDatabaseMetadata.supportsResultSetConcurrency(1, 1)).isFalse(); + } + + @Test + public void testOwnUpdatesAreVisible() { + assertThat(dataCloudDatabaseMetadata.ownUpdatesAreVisible(1)).isFalse(); + } + + @Test + public void testOwnDeletesAreVisible() { + assertThat(dataCloudDatabaseMetadata.ownDeletesAreVisible(1)).isFalse(); + } + + @Test + public void testOwnInsertsAreVisible() { + 
assertThat(dataCloudDatabaseMetadata.ownInsertsAreVisible(1)).isFalse(); + } + + @Test + public void testOthersUpdatesAreVisible() { + assertThat(dataCloudDatabaseMetadata.othersUpdatesAreVisible(1)).isFalse(); + } + + @Test + public void testOthersDeletesAreVisible() { + assertThat(dataCloudDatabaseMetadata.othersDeletesAreVisible(1)).isFalse(); + } + + @Test + public void testOthersInsertsAreVisible() { + assertThat(dataCloudDatabaseMetadata.othersInsertsAreVisible(1)).isFalse(); + } + + @Test + public void testUpdatesAreDetected() { + assertThat(dataCloudDatabaseMetadata.updatesAreDetected(1)).isFalse(); + } + + @Test + public void testDeletesAreDetected() { + assertThat(dataCloudDatabaseMetadata.deletesAreDetected(1)).isFalse(); + } + + @Test + public void testInsertsAreDetected() { + assertThat(dataCloudDatabaseMetadata.insertsAreDetected(1)).isFalse(); + } + + @Test + public void testSupportsBatchUpdates() { + assertThat(dataCloudDatabaseMetadata.supportsBatchUpdates()).isFalse(); + } + + @Test + public void testGetUDTs() { + assertThat(dataCloudDatabaseMetadata.getUDTs( + StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY, new int[] {})) + .isNull(); + } + + @Test + public void testGetConnection() { + assertThat(dataCloudDatabaseMetadata.getConnection()).isNull(); + } + + @Test + public void testSupportsSavepoints() { + assertThat(dataCloudDatabaseMetadata.supportsSavepoints()).isFalse(); + } + + @Test + public void testSupportsNamedParameters() { + assertThat(dataCloudDatabaseMetadata.supportsNamedParameters()).isFalse(); + } + + @Test + public void testSupportsMultipleOpenResults() { + assertThat(dataCloudDatabaseMetadata.supportsMultipleOpenResults()).isFalse(); + } + + @Test + public void testSupportsGetGeneratedKeys() { + assertThat(dataCloudDatabaseMetadata.supportsGetGeneratedKeys()).isFalse(); + } + + @Test + public void testGetSuperTypes() { + assertThat(dataCloudDatabaseMetadata.getSuperTypes(StringUtils.EMPTY, StringUtils.EMPTY, 
StringUtils.EMPTY)) + .isNull(); + } + + @Test + public void testGetSuperTables() { + assertThat(dataCloudDatabaseMetadata.getSuperTables(StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY)) + .isNull(); + } + + @Test + public void testGetAttributes() { + assertThat(dataCloudDatabaseMetadata.getAttributes( + StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY)) + .isNull(); + } + + @Test + public void testSupportsResultSetHoldability() { + assertThat(dataCloudDatabaseMetadata.supportsResultSetHoldability(1)).isFalse(); + } + + @Test + public void testGetResultSetHoldability() { + assertThat(dataCloudDatabaseMetadata.getResultSetHoldability()).isEqualTo(0); + } + + @Test + public void testGetDatabaseMajorVersion() { + assertThat(dataCloudDatabaseMetadata.getDatabaseMajorVersion()).isEqualTo(1); + } + + @Test + public void testGetDatabaseMinorVersion() { + assertThat(dataCloudDatabaseMetadata.getDatabaseMinorVersion()).isEqualTo(0); + } + + @Test + public void testGetJDBCMajorVersion() { + assertThat(dataCloudDatabaseMetadata.getJDBCMajorVersion()).isEqualTo(1); + } + + @Test + public void testGetJDBCMinorVersion() { + assertThat(dataCloudDatabaseMetadata.getJDBCMinorVersion()).isEqualTo(0); + } + + @Test + public void testGetSQLStateType() { + assertThat(dataCloudDatabaseMetadata.getSQLStateType()).isEqualTo(0); + } + + @Test + public void testLocatorsUpdateCopy() { + assertThat(dataCloudDatabaseMetadata.locatorsUpdateCopy()).isFalse(); + } + + @Test + public void testSupportsStatementPooling() { + assertThat(dataCloudDatabaseMetadata.supportsStatementPooling()).isFalse(); + } + + @Test + public void testGetRowIdLifetime() { + assertThat(dataCloudDatabaseMetadata.getRowIdLifetime()).isNull(); + } + + @Test + @SneakyThrows + public void testGetSchemas() { + Mockito.when(dataCloudStatement.executeQuery(anyString())).thenReturn(resultSetMock); + Mockito.when(resultSetMock.next()).thenReturn(true).thenReturn(true).thenReturn(false); + 
Mockito.when(resultSetMock.getString("TABLE_SCHEM")).thenReturn(null); + Mockito.when(resultSetMock.getString("TABLE_CATALOG")).thenReturn(null); + + ResultSet resultSet = dataCloudDatabaseMetadata.getSchemas(); + assertThat(resultSet.getMetaData().getColumnCount()).isEqualTo(NUM_SCHEMA_METADATA_COLUMNS); + + assertThat(resultSet.getMetaData().getColumnName(1)).isEqualTo("TABLE_SCHEM"); + assertThat(resultSet.getMetaData().getColumnName(2)).isEqualTo("TABLE_CATALOG"); + while (resultSet.next()) { + assertThat(resultSet.getString("TABLE_SCHEM")).isEqualTo(null); + assertThat(resultSet.getString("TABLE_CATALOG")).isEqualTo(null); + } + + assertThat(resultSet.getMetaData().getColumnTypeName(1)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(2)).isEqualTo("TEXT"); + } + + @Test + @SneakyThrows + public void testGetSchemasCatalogAndSchemaPattern() { + String schemaPattern = "public"; + String tableCatalog = "catalog"; + Mockito.when(resultSetMock.next()).thenReturn(true).thenReturn(true).thenReturn(false); + Mockito.when(dataCloudStatement.executeQuery(anyString())).thenReturn(resultSetMock); + Mockito.when(resultSetMock.getString("TABLE_SCHEM")).thenReturn(schemaPattern); + Mockito.when(resultSetMock.getString("TABLE_CATALOG")).thenReturn(tableCatalog); + + ResultSet resultSet = dataCloudDatabaseMetadata.getSchemas(null, "schemaName"); + assertThat(resultSet.getMetaData().getColumnCount()).isEqualTo(NUM_SCHEMA_METADATA_COLUMNS); + assertThat(resultSet.getMetaData().getColumnName(1)).isEqualTo("TABLE_SCHEM"); + assertThat(resultSet.getMetaData().getColumnName(2)).isEqualTo("TABLE_CATALOG"); + while (resultSet.next()) { + assertThat(resultSet.getString("TABLE_SCHEM")).isEqualTo(schemaPattern); + assertThat(resultSet.getString("TABLE_CATALOG")).isEqualTo(tableCatalog); + } + + assertThat(resultSet.getMetaData().getColumnTypeName(1)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(2)).isEqualTo("TEXT"); + } + + @Test + 
@SneakyThrows + public void testGetSchemasCatalogAndSchemaPatternNullValues() { + Mockito.when(dataCloudStatement.executeQuery(anyString())).thenReturn(resultSetMock); + Mockito.when(resultSetMock.next()).thenReturn(true).thenReturn(true).thenReturn(false); + Mockito.when(resultSetMock.getString("TABLE_SCHEM")).thenReturn(null); + Mockito.when(resultSetMock.getString("TABLE_CATALOG")).thenReturn(null); + + ResultSet resultSet = dataCloudDatabaseMetadata.getSchemas(null, null); + assertThat(resultSet.getMetaData().getColumnCount()).isEqualTo(NUM_SCHEMA_METADATA_COLUMNS); + + assertThat(resultSet.getMetaData().getColumnName(1)).isEqualTo("TABLE_SCHEM"); + assertThat(resultSet.getMetaData().getColumnName(2)).isEqualTo("TABLE_CATALOG"); + while (resultSet.next()) { + assertThat(resultSet.getString("TABLE_SCHEM")).isEqualTo(null); + assertThat(resultSet.getString("TABLE_CATALOG")).isEqualTo(null); + } + + assertThat(resultSet.getMetaData().getColumnTypeName(1)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(2)).isEqualTo("TEXT"); + } + + @Test + @SneakyThrows + public void testGetSchemasEmptyValues() { + Mockito.when(dataCloudStatement.executeQuery(anyString())).thenReturn(resultSetMock); + Mockito.when(resultSetMock.next()).thenReturn(true).thenReturn(true).thenReturn(false); + Mockito.when(resultSetMock.getString("TABLE_SCHEM")).thenReturn(null); + Mockito.when(resultSetMock.getString("TABLE_CATALOG")).thenReturn(null); + + ResultSet resultSet = dataCloudDatabaseMetadata.getSchemas(StringUtils.EMPTY, StringUtils.EMPTY); + assertThat(resultSet.getMetaData().getColumnCount()).isEqualTo(NUM_SCHEMA_METADATA_COLUMNS); + assertThat(resultSet.getMetaData().getColumnName(1)).isEqualTo("TABLE_SCHEM"); + assertThat(resultSet.getMetaData().getColumnName(2)).isEqualTo("TABLE_CATALOG"); + while (resultSet.next()) { + assertThat(resultSet.getString("TABLE_SCHEM")).isEqualTo(null); + assertThat(resultSet.getString("TABLE_CATALOG")).isEqualTo(null); + } + + 
assertThat(resultSet.getMetaData().getColumnTypeName(1)).isEqualTo("TEXT"); + assertThat(resultSet.getMetaData().getColumnTypeName(2)).isEqualTo("TEXT"); + } + + @Test + public void testSupportsStoredFunctionsUsingCallSyntax() { + assertThat(dataCloudDatabaseMetadata.supportsStoredFunctionsUsingCallSyntax()) + .isFalse(); + } + + @Test + public void testAutoCommitFailureClosesAllResultSets() { + assertThat(dataCloudDatabaseMetadata.autoCommitFailureClosesAllResultSets()) + .isFalse(); + } + + @Test + public void testGetClientInfoProperties() { + assertThat(dataCloudDatabaseMetadata.getClientInfoProperties()).isNull(); + } + + @Test + public void testGetFunctions() { + assertThat(dataCloudDatabaseMetadata.getFunctions(StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY)) + .isNull(); + } + + @Test + public void testGetFunctionColumns() { + assertThat(dataCloudDatabaseMetadata.getFunctionColumns( + StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY)) + .isNull(); + } + + @Test + public void testGetPseudoColumns() { + assertThat(dataCloudDatabaseMetadata.getPseudoColumns( + StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY, StringUtils.EMPTY)) + .isNull(); + } + + @Test + public void testGeneratedKeyAlwaysReturned() { + assertThat(dataCloudDatabaseMetadata.generatedKeyAlwaysReturned()).isFalse(); + } + + @Test + public void testUnwrap() { + try { + assertThat(dataCloudDatabaseMetadata.unwrap(DataCloudDatabaseMetadata.class)) + .isInstanceOf(DataCloudDatabaseMetadata.class); + } catch (Exception e) { + fail("Uncaught Exception", e); + } + val ex = assertThrows(DataCloudJDBCException.class, () -> dataCloudDatabaseMetadata.unwrap(String.class)); + } + + @Test + public void testIsWrapperFor() { + try { + assertThat(dataCloudDatabaseMetadata.isWrapperFor(DataCloudDatabaseMetadata.class)) + .isTrue(); + } catch (Exception e) { + fail("Uncaught Exception", e); + } + } + + @Test + public void testQuoteStringLiteral() { + String 
unescapedString = "unescaped"; + String actual = QueryMetadataUtil.quoteStringLiteral(unescapedString); + assertThat(actual).isEqualTo("'unescaped'"); + } + + @Test + public void testQuoteStringLiteralSingleQuote() { + char singleQuote = '\''; + assertThat(QueryMetadataUtil.quoteStringLiteral(String.valueOf(singleQuote))) + .isEqualTo("''''"); + } + + @Test + public void testQuoteStringLiteralBackslash() { + char backslash = '\\'; + assertThat(QueryMetadataUtil.quoteStringLiteral(String.valueOf(backslash))) + .isEqualTo("E'\\\\'"); + } + + @Test + public void testQuoteStringLiteralNewline() { + char newLine = '\n'; + assertThat(QueryMetadataUtil.quoteStringLiteral(String.valueOf(newLine))) + .isEqualTo("E'\\n'"); + } + + @Test + public void testQuoteStringLiteralCarriageReturn() { + char carriageReturn = '\r'; + assertThat(QueryMetadataUtil.quoteStringLiteral(String.valueOf(carriageReturn))) + .isEqualTo("E'\\r'"); + } + + @Test + public void testQuoteStringLiteralTab() { + char tab = '\t'; + assertThat(QueryMetadataUtil.quoteStringLiteral(String.valueOf(tab))).isEqualTo("E'\\t'"); + } + + @Test + public void testQuoteStringLiteralBackspace() { + char backspace = '\b'; + assertThat(QueryMetadataUtil.quoteStringLiteral(String.valueOf(backspace))) + .isEqualTo("E'\\b'"); + } + + @Test + public void testQuoteStringLiteralFormFeed() { + char formFeed = '\f'; + assertThat(QueryMetadataUtil.quoteStringLiteral(String.valueOf(formFeed))) + .isEqualTo("E'\\f'"); + } + + @SneakyThrows + private void assertExpectedEmptyResultSet(ResultSet resultSet) { + assertThat(resultSet).isNotNull(); + assertFalse(resultSet.next()); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementHyperTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementHyperTest.java new file mode 100644 index 0000000..427d194 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementHyperTest.java @@ 
-0,0 +1,232 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import com.salesforce.datacloud.jdbc.hyper.HyperTestBase; +import java.sql.Connection; +import java.sql.Date; +import java.sql.PreparedStatement; +import java.sql.Time; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.util.Calendar; +import java.util.TimeZone; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.junit.jupiter.api.Test; + +@Slf4j +public class DataCloudPreparedStatementHyperTest extends HyperTestBase { + @Test + @SneakyThrows + public void testPreparedStatementDateRange() { + LocalDate startDate = LocalDate.of(2024, 1, 1); + LocalDate endDate = LocalDate.of(2024, 1, 5); + + try (Connection connection = getHyperQueryConnection()) { + try (PreparedStatement preparedStatement = connection.prepareStatement("select ? 
as a")) { + for (LocalDate date = startDate; !date.isAfter(endDate); date = date.plusDays(1)) { + val sqlDate = Date.valueOf(date); + preparedStatement.setDate(1, sqlDate); + + try (var resultSet = preparedStatement.executeQuery()) { + while (resultSet.next()) { + assertThat(resultSet.getDate("a")) + .isEqualTo(sqlDate) + .as("Expected the date to be %s but got %s", sqlDate, resultSet.getDate("a")); + } + } + } + } + } + } + + @Test + @SneakyThrows + public void testPreparedStatementDateWithCalendarRange() { + LocalDate startDate = LocalDate.of(2024, 1, 1); + LocalDate endDate = LocalDate.of(2024, 1, 5); + + TimeZone plusTwoTimeZone = TimeZone.getTimeZone("GMT+2"); + Calendar calendar = Calendar.getInstance(plusTwoTimeZone); + + TimeZone utcTimeZone = TimeZone.getTimeZone("UTC"); + + try (Connection connection = getHyperQueryConnection()) { + try (PreparedStatement preparedStatement = connection.prepareStatement("select ? as a")) { + for (LocalDate date = startDate; !date.isAfter(endDate); date = date.plusDays(1)) { + val sqlDate = Date.valueOf(date); + preparedStatement.setDate(1, sqlDate, calendar); + + val time = sqlDate.getTime(); + + val dateTime = new Timestamp(time).toLocalDateTime(); + + val zonedDateTime = dateTime.atZone(plusTwoTimeZone.toZoneId()); + val convertedDateTime = zonedDateTime.withZoneSameInstant(utcTimeZone.toZoneId()); + val expected = + Date.valueOf(convertedDateTime.toLocalDateTime().toLocalDate()); + + try (var resultSet = preparedStatement.executeQuery()) { + while (resultSet.next()) { + val actual = resultSet.getDate("a"); + assertThat(actual) + .isEqualTo(expected) + .as("Expected the date to be %s in UTC timezone but got %s", sqlDate, actual); + } + } + } + } + } + } + + @Test + @SneakyThrows + public void testPreparedStatementTimeRange() { + LocalTime startTime = LocalTime.of(10, 0, 0, 0); + LocalTime endTime = LocalTime.of(15, 0, 0, 0); + + try (Connection connection = getHyperQueryConnection()) { + try (PreparedStatement 
preparedStatement = connection.prepareStatement("select ? as a")) { + for (LocalTime time = startTime; !time.isAfter(endTime); time = time.plusHours(1)) { + val sqlTime = Time.valueOf(time); + preparedStatement.setTime(1, sqlTime); + + try (val resultSet = preparedStatement.executeQuery()) { + while (resultSet.next()) { + Time actual = resultSet.getTime("a"); + assertThat(actual) + .isEqualTo(sqlTime) + .as("Expected the date to be %s but got %s", sqlTime, actual); + } + } + } + } + } + } + + @Test + @SneakyThrows + public void testPreparedStatementTimeWithCalendarRange() { + LocalTime startTime = LocalTime.of(10, 0, 0, 0); + LocalTime endTime = LocalTime.of(15, 0, 0, 0); + + TimeZone plusTwoTimeZone = TimeZone.getTimeZone("GMT+2"); + Calendar calendar = Calendar.getInstance(plusTwoTimeZone); + + TimeZone utcTimeZone = TimeZone.getTimeZone("UTC"); + + try (Connection connection = getHyperQueryConnection()) { + try (PreparedStatement preparedStatement = connection.prepareStatement("select ? 
as a")) { + for (LocalTime time = startTime; !time.isAfter(endTime); time = time.plusHours(1)) { + + val sqlTime = Time.valueOf(time); + preparedStatement.setTime(1, sqlTime, calendar); + + val dateTime = new Timestamp(sqlTime.getTime()).toLocalDateTime(); + + val zonedDateTime = dateTime.atZone(plusTwoTimeZone.toZoneId()); + val convertedDateTime = zonedDateTime.withZoneSameInstant(utcTimeZone.toZoneId()); + val expected = Time.valueOf(convertedDateTime.toLocalTime()); + + try (val resultSet = preparedStatement.executeQuery()) { + while (resultSet.next()) { + val actual = resultSet.getTime("a"); + assertThat(actual) + .isEqualTo(expected) + .as("Expected the date to be %s in UTC timezone but got %s", sqlTime, actual); + } + } + } + } + } + } + + @Test + @SneakyThrows + public void testPreparedStatementTimestampRange() { + LocalDateTime startDateTime = LocalDateTime.of(2024, 1, 1, 0, 0); + LocalDateTime endDateTime = LocalDateTime.of(2024, 1, 5, 0, 0); + + TimeZone utcTimeZone = TimeZone.getTimeZone("UTC"); + Calendar utcCalendar = Calendar.getInstance(utcTimeZone); + + try (Connection connection = getHyperQueryConnection()) { + try (PreparedStatement preparedStatement = connection.prepareStatement("select ? 
as a")) { + for (LocalDateTime dateTime = startDateTime; + !dateTime.isAfter(endDateTime); + dateTime = dateTime.plusDays(1)) { + val sqlTimestamp = Timestamp.valueOf(dateTime); + preparedStatement.setTimestamp(1, sqlTimestamp); + + try (var resultSet = preparedStatement.executeQuery()) { + while (resultSet.next()) { + val actual = resultSet.getTimestamp("a", utcCalendar); + assertThat(actual) + .isEqualTo(sqlTimestamp) + .as("Expected the date to be %s in UTC timezone but got %s", sqlTimestamp, actual); + } + } + } + } + } + } + + @Test + @SneakyThrows + public void testPreparedStatementTimestampWithCalendarRange() { + LocalDateTime startDateTime = LocalDateTime.of(2024, 1, 1, 0, 0); + LocalDateTime endDateTime = LocalDateTime.of(2024, 1, 5, 0, 0); + + TimeZone plusTwoTimeZone = TimeZone.getTimeZone("GMT+2"); + Calendar calendar = Calendar.getInstance(plusTwoTimeZone); + + TimeZone utcTimeZone = TimeZone.getTimeZone("UTC"); + Calendar utcCalendar = Calendar.getInstance(utcTimeZone); + + try (Connection connection = getHyperQueryConnection()) { + try (PreparedStatement preparedStatement = connection.prepareStatement("select ? 
as a")) { + for (LocalDateTime dateTime = startDateTime; + !dateTime.isAfter(endDateTime); + dateTime = dateTime.plusDays(1)) { + + val sqlTimestamp = Timestamp.valueOf(dateTime); + preparedStatement.setTimestamp(1, sqlTimestamp, calendar); + + val localDateTime = sqlTimestamp.toLocalDateTime(); + + val zonedDateTime = localDateTime.atZone(plusTwoTimeZone.toZoneId()); + val convertedDateTime = zonedDateTime.withZoneSameInstant(utcTimeZone.toZoneId()); + val expected = Timestamp.valueOf(convertedDateTime.toLocalDateTime()); + + try (var resultSet = preparedStatement.executeQuery()) { + while (resultSet.next()) { + val actual = resultSet.getTimestamp("a", utcCalendar); + assertThat(actual) + .isEqualTo(expected) + .as("Expected the date to be %s in UTC timezone but got %s", sqlTimestamp, actual); + } + } + } + } + } + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementTest.java new file mode 100644 index 0000000..d63acf5 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudPreparedStatementTest.java @@ -0,0 +1,418 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Named.named; +import static org.junit.jupiter.params.provider.Arguments.arguments; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.Constants; +import com.salesforce.datacloud.jdbc.util.DateTimeUtils; +import com.salesforce.datacloud.jdbc.util.GrpcUtils; +import com.salesforce.datacloud.jdbc.util.SqlErrorCodes; +import com.salesforce.hyperdb.grpc.HyperServiceGrpc; +import com.salesforce.hyperdb.grpc.QueryParam; +import io.grpc.StatusRuntimeException; +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Date; +import java.sql.NClob; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.Statement; +import java.sql.Time; +import java.sql.Timestamp; +import java.sql.Types; +import java.util.Calendar; +import java.util.List; +import java.util.Properties; +import java.util.TimeZone; +import java.util.stream.Stream; +import lombok.SneakyThrows; +import lombok.val; +import org.assertj.core.api.AssertionsForClassTypes; +import org.assertj.core.api.ThrowingConsumer; +import org.grpcmock.GrpcMock; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import 
org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.MockedStatic; + +public class DataCloudPreparedStatementTest extends HyperGrpcTestBase { + + @Mock + private DataCloudConnection mockConnection; + + @Mock + private ParameterManager mockParameterManager; + + private DataCloudPreparedStatement preparedStatement; + + private final Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("GMT+2")); + + @BeforeEach + public void beforeEach() { + + mockConnection = mock(DataCloudConnection.class); + val properties = new Properties(); + when(mockConnection.getExecutor()).thenReturn(hyperGrpcClient); + when(mockConnection.getProperties()).thenReturn(properties); + + mockParameterManager = mock(ParameterManager.class); + + preparedStatement = new DataCloudPreparedStatement(mockConnection, mockParameterManager); + } + + @Test + @SneakyThrows + public void testExecuteQuery() { + setupHyperGrpcClientWithMockedResultSet("query id", List.of()); + ResultSet resultSet = preparedStatement.executeQuery("SELECT * FROM table"); + assertNotNull(resultSet); + assertThat(resultSet.getMetaData().getColumnCount()).isEqualTo(3); + assertThat(resultSet.getMetaData().getColumnName(1)).isEqualTo("id"); + assertThat(resultSet.getMetaData().getColumnName(2)).isEqualTo("name"); + assertThat(resultSet.getMetaData().getColumnName(3)).isEqualTo("grade"); + } + + @Test + @SneakyThrows + public void testExecute() { + setupHyperGrpcClientWithMockedResultSet("query id", List.of()); + preparedStatement.execute("SELECT * FROM table"); + ResultSet resultSet = preparedStatement.getResultSet(); + assertNotNull(resultSet); + assertThat(resultSet.getMetaData().getColumnCount()).isEqualTo(3); + assertThat(resultSet.getMetaData().getColumnName(1)).isEqualTo("id"); + assertThat(resultSet.getMetaData().getColumnName(2)).isEqualTo("name"); + 
assertThat(resultSet.getMetaData().getColumnName(3)).isEqualTo("grade"); + } + + @SneakyThrows + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testForceSyncOverride(boolean forceSync) { + val p = new Properties(); + p.setProperty(Constants.FORCE_SYNC, Boolean.toString(forceSync)); + when(mockConnection.getProperties()).thenReturn(p); + + val statement = new DataCloudPreparedStatement(mockConnection, mockParameterManager); + + setupHyperGrpcClientWithMockedResultSet( + "query id", List.of(), forceSync ? QueryParam.TransferMode.SYNC : QueryParam.TransferMode.ADAPTIVE); + ResultSet response = statement.executeQuery("SELECT * FROM table"); + AssertionsForClassTypes.assertThat(statement.isReady()).isTrue(); + assertNotNull(response); + AssertionsForClassTypes.assertThat(response.getMetaData().getColumnCount()) + .isEqualTo(3); + } + + @Test + public void testExecuteQueryWithSqlException() { + StatusRuntimeException fakeException = GrpcUtils.getFakeStatusRuntimeExceptionAsInvalidArgument(); + + GrpcMock.stubFor(GrpcMock.unaryMethod(HyperServiceGrpc.getExecuteQueryMethod()) + .willReturn(GrpcMock.exception(fakeException))); + + assertThrows(DataCloudJDBCException.class, () -> preparedStatement.executeQuery("SELECT * FROM table")); + } + + @Test + @SneakyThrows + void testClearParameters() { + preparedStatement.setString(1, "TEST"); + preparedStatement.setInt(2, 123); + + preparedStatement.clearParameters(); + verify(mockParameterManager).clearParameters(); + } + + @Test + void testSetParameterNegativeIndexThrowsSQLException() { + ParameterManager parameterManager = new DefaultParameterManager(); + preparedStatement = new DataCloudPreparedStatement(mockConnection, parameterManager); + + assertThatThrownBy(() -> preparedStatement.setString(0, "TEST")) + .isInstanceOf(DataCloudJDBCException.class) + .hasMessageContaining("Parameter index must be greater than 0"); + + assertThatThrownBy(() -> preparedStatement.setString(-1, "TEST")) + 
.isInstanceOf(DataCloudJDBCException.class) + .hasMessageContaining("Parameter index must be greater than 0"); + } + + @Test + @SneakyThrows + void testAllSetMethods() { + preparedStatement.setString(1, "TEST"); + verify(mockParameterManager).setParameter(1, Types.VARCHAR, "TEST"); + + preparedStatement.setBoolean(2, true); + verify(mockParameterManager).setParameter(2, Types.BOOLEAN, true); + + preparedStatement.setByte(3, (byte) 1); + verify(mockParameterManager).setParameter(3, Types.TINYINT, (byte) 1); + + preparedStatement.setShort(4, (short) 2); + verify(mockParameterManager).setParameter(4, Types.SMALLINT, (short) 2); + + preparedStatement.setInt(5, 3); + verify(mockParameterManager).setParameter(5, Types.INTEGER, 3); + + preparedStatement.setLong(6, 4L); + verify(mockParameterManager).setParameter(6, Types.BIGINT, 4L); + + preparedStatement.setFloat(7, 5.0f); + verify(mockParameterManager).setParameter(7, Types.FLOAT, 5.0f); + + preparedStatement.setDouble(8, 6.0); + verify(mockParameterManager).setParameter(8, Types.DOUBLE, 6.0); + + preparedStatement.setBigDecimal(9, new java.math.BigDecimal("7.0")); + verify(mockParameterManager).setParameter(9, Types.DECIMAL, new java.math.BigDecimal("7.0")); + + Date date = new Date(System.currentTimeMillis()); + preparedStatement.setDate(10, date); + verify(mockParameterManager).setParameter(10, Types.DATE, date); + + Time time = new Time(System.currentTimeMillis()); + preparedStatement.setTime(11, time); + verify(mockParameterManager).setParameter(11, Types.TIME, time); + + Timestamp timestamp = new Timestamp(System.currentTimeMillis()); + preparedStatement.setTimestamp(12, timestamp); + verify(mockParameterManager).setParameter(12, Types.TIMESTAMP, timestamp); + + preparedStatement.setNull(13, Types.NULL); + verify(mockParameterManager).setParameter(13, Types.NULL, null); + + preparedStatement.setObject(14, "TEST"); + verify(mockParameterManager).setParameter(14, Types.VARCHAR, "TEST"); + + 
preparedStatement.setObject(15, null); + verify(mockParameterManager).setParameter(15, Types.NULL, null); + + preparedStatement.setObject(16, "TEST", Types.VARCHAR); + verify(mockParameterManager).setParameter(16, Types.VARCHAR, "TEST"); + + preparedStatement.setObject(17, null, Types.VARCHAR); + verify(mockParameterManager).setParameter(17, Types.NULL, null); + + try (MockedStatic mockedDateTimeUtil = mockStatic(DateTimeUtils.class)) { + mockedDateTimeUtil + .when(() -> DateTimeUtils.getUTCDateFromDateAndCalendar(Date.valueOf("1970-01-01"), calendar)) + .thenReturn(Date.valueOf("1969-12-31")); + + preparedStatement.setDate(18, Date.valueOf("1970-01-01"), calendar); + + mockedDateTimeUtil.verify( + () -> DateTimeUtils.getUTCDateFromDateAndCalendar(Date.valueOf("1970-01-01"), calendar), times(1)); + verify(mockParameterManager).setParameter(18, Types.DATE, Date.valueOf("1969-12-31")); + } + + try (MockedStatic mockedDateTimeUtil = mockStatic(DateTimeUtils.class)) { + mockedDateTimeUtil + .when(() -> DateTimeUtils.getUTCTimeFromTimeAndCalendar(Time.valueOf("00:00:00"), calendar)) + .thenReturn(Time.valueOf("22:00:00")); + + preparedStatement.setTime(19, Time.valueOf("00:00:00"), calendar); + + mockedDateTimeUtil.verify( + () -> DateTimeUtils.getUTCTimeFromTimeAndCalendar(Time.valueOf("00:00:00"), calendar), times(1)); + verify(mockParameterManager).setParameter(19, Types.TIME, Time.valueOf("22:00:00")); + } + + try (MockedStatic mockedDateTimeUtil = mockStatic(DateTimeUtils.class)) { + mockedDateTimeUtil + .when(() -> DateTimeUtils.getUTCTimestampFromTimestampAndCalendar( + Timestamp.valueOf("1970-01-01 00:00:00.000000000"), calendar)) + .thenReturn(Timestamp.valueOf("1969-12-31 22:00:00.000000000")); + + preparedStatement.setTimestamp(20, Timestamp.valueOf("1970-01-01 00:00:00.000000000"), calendar); + + mockedDateTimeUtil.verify( + () -> DateTimeUtils.getUTCTimestampFromTimestampAndCalendar( + Timestamp.valueOf("1970-01-01 00:00:00.000000000"), calendar), + 
times(1)); + verify(mockParameterManager) + .setParameter(20, Types.TIMESTAMP, Timestamp.valueOf("1969-12-31 22:00:00.000000000")); + } + + assertThatThrownBy(() -> preparedStatement.setObject(1, new InvalidClass())) + .isInstanceOf(DataCloudJDBCException.class) + .hasMessageContaining("Object type not supported for:"); + } + + private static Arguments impl(String name, ThrowingConsumer impl) { + return arguments(named(name, impl)); + } + + private static Stream unsupported() { + return Stream.of( + impl("setAsciiStream", s -> s.setAsciiStream(1, null, 0)), + impl("setUnicodeStream", s -> s.setUnicodeStream(1, null, 0)), + impl("setBinaryStream", s -> s.setBinaryStream(1, null, 0)), + impl("addBatch", DataCloudPreparedStatement::addBatch), + impl("clearBatch", DataCloudStatement::clearBatch), + impl("setCharacterStream", s -> s.setCharacterStream(1, null, 0)), + impl("setRef", s -> s.setRef(1, null)), + impl("setBlob", s -> s.setBlob(1, (Blob) null)), + impl("setClob", s -> s.setClob(1, (Clob) null)), + impl("setArray", s -> s.setArray(1, null)), + impl("setURL", s -> s.setURL(1, null)), + impl("setRowId", s -> s.setRowId(1, null)), + impl("setNString", s -> s.setNString(1, null)), + impl("setNCharacterStream", s -> s.setNCharacterStream(1, null, 0)), + impl("setNClob", s -> s.setNClob(1, (NClob) null)), + impl("setClob", s -> s.setClob(1, null, 0)), + impl("setBlob", s -> s.setBlob(1, null, 0)), + impl("setNClob", s -> s.setNClob(1, null, 0)), + impl("setSQLXML", s -> s.setSQLXML(1, null)), + impl("setObject", s -> s.setObject(1, null, Types.OTHER, 0)), + impl("setAsciiStream", s -> s.setAsciiStream(1, null, (long) 0)), + impl("setUnicodeStream", s -> s.setUnicodeStream(1, null, 0)), + impl("setBinaryStream", s -> s.setBinaryStream(1, null, (long) 0)), + impl("setAsciiStream", s -> s.setAsciiStream(1, null, 0)), + impl("setBinaryStream", s -> s.setBinaryStream(1, null, 0)), + impl("setCharacterStream", s -> s.setCharacterStream(1, null, (long) 0)), + 
impl("setAsciiStream", s -> s.setAsciiStream(1, null)), + impl("setBinaryStream", s -> s.setBinaryStream(1, null)), + impl("setCharacterStream", s -> s.setCharacterStream(1, null)), + impl("setNCharacterStream", s -> s.setNCharacterStream(1, null)), + impl("setClob", s -> s.setClob(1, (Reader) null)), + impl("setBlob", s -> s.setBlob(1, (InputStream) null)), + impl("setNClob", s -> s.setNClob(1, (Reader) null)), + impl("setBytes", s -> s.setBytes(1, null)), + impl("setNull", s -> s.setNull(1, Types.ARRAY, "ARRAY")), + impl("executeUpdate", DataCloudPreparedStatement::executeUpdate), + impl("executeUpdate", s -> s.executeUpdate("")), + impl("addBatch", s -> s.addBatch("")), + impl("executeBatch", DataCloudStatement::executeBatch), + impl("executeUpdate", s -> s.executeUpdate("", Statement.RETURN_GENERATED_KEYS)), + impl("executeUpdate", s -> s.executeUpdate("", new int[] {})), + impl("executeUpdate", s -> s.executeUpdate("", new String[] {})), + impl("getMetaData", DataCloudPreparedStatement::getMetaData), + impl("getParameterMetaData", DataCloudPreparedStatement::getParameterMetaData)); + } + + @ParameterizedTest + @MethodSource("unsupported") + void testUnsupportedOperations(ThrowingConsumer func) { + val e = Assertions.assertThrows(RuntimeException.class, () -> func.accept(preparedStatement)); + AssertionsForClassTypes.assertThat(e).hasRootCauseInstanceOf(DataCloudJDBCException.class); + AssertionsForClassTypes.assertThat(e.getCause()) + .hasMessageContaining("is not supported in Data Cloud query") + .hasFieldOrPropertyWithValue("SQLState", SqlErrorCodes.FEATURE_NOT_SUPPORTED); + } + + @Test + @SneakyThrows + void testUnwrapWithCorrectInterface() { + DataCloudPreparedStatement result = preparedStatement.unwrap(DataCloudPreparedStatement.class); + assertThat(result).isExactlyInstanceOf(DataCloudPreparedStatement.class); + + assertThatThrownBy(() -> preparedStatement.unwrap(String.class)) + .isExactlyInstanceOf(DataCloudJDBCException.class) + 
.hasMessageContaining("Cannot unwrap to java.lang.String"); + + assertThat(preparedStatement.isWrapperFor(DataCloudPreparedStatement.class)) + .isTrue(); + assertThat(preparedStatement.isWrapperFor(String.class)).isFalse(); + } + + @Test + void testSetQueryTimeout() { + preparedStatement.setQueryTimeout(30); + assertEquals(30, preparedStatement.getQueryTimeout()); + + preparedStatement.setQueryTimeout(-1); + assertThat(preparedStatement.getQueryTimeout()).isEqualTo(DataCloudStatement.DEFAULT_QUERY_TIMEOUT); + } + + @Test + void testTypeHandlerMapInitialization() { + assertEquals(TypeHandlers.STRING_HANDLER, TypeHandlers.typeHandlerMap.get(String.class)); + assertEquals(TypeHandlers.BIGDECIMAL_HANDLER, TypeHandlers.typeHandlerMap.get(BigDecimal.class)); + assertEquals(TypeHandlers.SHORT_HANDLER, TypeHandlers.typeHandlerMap.get(Short.class)); + assertEquals(TypeHandlers.INTEGER_HANDLER, TypeHandlers.typeHandlerMap.get(Integer.class)); + assertEquals(TypeHandlers.LONG_HANDLER, TypeHandlers.typeHandlerMap.get(Long.class)); + assertEquals(TypeHandlers.FLOAT_HANDLER, TypeHandlers.typeHandlerMap.get(Float.class)); + assertEquals(TypeHandlers.DOUBLE_HANDLER, TypeHandlers.typeHandlerMap.get(Double.class)); + assertEquals(TypeHandlers.DATE_HANDLER, TypeHandlers.typeHandlerMap.get(Date.class)); + assertEquals(TypeHandlers.TIME_HANDLER, TypeHandlers.typeHandlerMap.get(Time.class)); + assertEquals(TypeHandlers.TIMESTAMP_HANDLER, TypeHandlers.typeHandlerMap.get(Timestamp.class)); + assertEquals(TypeHandlers.BOOLEAN_HANDLER, TypeHandlers.typeHandlerMap.get(Boolean.class)); + } + + @Test + @SneakyThrows + void testAllTypeHandlers() { + PreparedStatement ps = mock(PreparedStatement.class); + + TypeHandlers.STRING_HANDLER.setParameter(ps, 1, "test"); + verify(ps).setString(1, "test"); + + TypeHandlers.BIGDECIMAL_HANDLER.setParameter(ps, 1, new BigDecimal("123.45")); + verify(ps).setBigDecimal(1, new BigDecimal("123.45")); + + TypeHandlers.SHORT_HANDLER.setParameter(ps, 1, (short) 
123); + verify(ps).setShort(1, (short) 123); + + TypeHandlers.INTEGER_HANDLER.setParameter(ps, 1, 123); + verify(ps).setInt(1, 123); + + TypeHandlers.LONG_HANDLER.setParameter(ps, 1, 123L); + verify(ps).setLong(1, 123L); + + TypeHandlers.FLOAT_HANDLER.setParameter(ps, 1, 123.45f); + verify(ps).setFloat(1, 123.45f); + + TypeHandlers.DOUBLE_HANDLER.setParameter(ps, 1, 123.45); + verify(ps).setDouble(1, 123.45); + + TypeHandlers.DATE_HANDLER.setParameter(ps, 1, Date.valueOf("2024-08-15")); + verify(ps).setDate(1, Date.valueOf("2024-08-15")); + + TypeHandlers.TIME_HANDLER.setParameter(ps, 1, Time.valueOf("12:34:56")); + verify(ps).setTime(1, Time.valueOf("12:34:56")); + + TypeHandlers.TIMESTAMP_HANDLER.setParameter(ps, 1, Timestamp.valueOf("2024-08-15 12:34:56")); + verify(ps).setTimestamp(1, Timestamp.valueOf("2024-08-15 12:34:56")); + + TypeHandlers.BOOLEAN_HANDLER.setParameter(ps, 1, true); + verify(ps).setBoolean(1, true); + } + + static class InvalidClass {} +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudStatementTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudStatementTest.java new file mode 100644 index 0000000..bd508d9 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/DataCloudStatementTest.java @@ -0,0 +1,205 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/*
 * Copyright (c) 2024, Salesforce, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.salesforce.datacloud.jdbc.core;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.when;

import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException;
import com.salesforce.datacloud.jdbc.util.Constants;
import com.salesforce.datacloud.jdbc.util.GrpcUtils;
import com.salesforce.datacloud.jdbc.util.RequestRecordingInterceptor;
import com.salesforce.datacloud.jdbc.util.SqlErrorCodes;
import com.salesforce.hyperdb.grpc.HyperServiceGrpc;
import com.salesforce.hyperdb.grpc.QueryParam;
import io.grpc.StatusRuntimeException;
import java.sql.ResultSet;
import java.util.List;
import java.util.Properties;
import java.util.stream.Stream;
import lombok.SneakyThrows;
import lombok.val;
import org.assertj.core.api.AssertionsForClassTypes;
import org.grpcmock.GrpcMock;
import org.grpcmock.junit5.InProcessGrpcMockExtension;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.function.Executable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.mockito.Mockito;

/**
 * Tests for DataCloudStatement: execute/executeQuery happy paths against a
 * mocked Hyper gRPC backend, rejection of unsupported batch/update APIs, and
 * query-timeout resolution precedence (statement-level over connection config
 * over default).
 */
@ExtendWith(InProcessGrpcMockExtension.class)
public class DataCloudStatementTest extends HyperGrpcTestBase {
    @Mock
    private DataCloudConnection connection;

    // Static because unsupportedBatchExecutes() is a static @MethodSource and
    // must reference the statement created in beforeEach.
    static DataCloudStatement statement;

    @BeforeEach
    public void beforeEach() {
        connection = Mockito.mock(DataCloudConnection.class);
        val properties = new Properties();
        Mockito.when(connection.getExecutor()).thenReturn(hyperGrpcClient);
        Mockito.when(connection.getProperties()).thenReturn(properties);
        statement = new DataCloudStatement(connection);
    }

    /** The statement only supports forward-only result set traversal. */
    @Test
    @SneakyThrows
    public void forwardOnly() {
        assertThat(statement.getFetchDirection()).isEqualTo(ResultSet.FETCH_FORWARD);
        assertThat(statement.getResultSetType()).isEqualTo(ResultSet.TYPE_FORWARD_ONLY);
    }

    private static Stream<Executable> unsupportedBatchExecutes() {
        return Stream.of(
                () -> statement.execute("", 1),
                () -> statement.execute("", new int[] {}),
                () -> statement.execute("", new String[] {}));
    }

    @ParameterizedTest
    @MethodSource("unsupportedBatchExecutes")
    @SneakyThrows
    public void batchExecutesAreNotSupported(Executable func) {
        val ex = Assertions.assertThrows(DataCloudJDBCException.class, func);
        assertThat(ex)
                .hasMessage("Batch execution is not supported in Data Cloud query")
                .hasFieldOrPropertyWithValue("SQLState", SqlErrorCodes.FEATURE_NOT_SUPPORTED);
    }

    /** The FORCE_SYNC property must switch the gRPC transfer mode between SYNC and ADAPTIVE. */
    @SneakyThrows
    @ParameterizedTest
    @ValueSource(booleans = {true, false})
    public void testForceSyncOverride(boolean forceSync) {
        val p = new Properties();
        p.setProperty(Constants.FORCE_SYNC, Boolean.toString(forceSync));
        when(connection.getProperties()).thenReturn(p);

        val statement = new DataCloudStatement(connection);

        setupHyperGrpcClientWithMockedResultSet(
                "query id", List.of(), forceSync ? QueryParam.TransferMode.SYNC : QueryParam.TransferMode.ADAPTIVE);
        ResultSet response = statement.executeQuery("SELECT * FROM table");
        AssertionsForClassTypes.assertThat(statement.isReady()).isTrue();
        assertNotNull(response);
        AssertionsForClassTypes.assertThat(response.getMetaData().getColumnCount())
                .isEqualTo(3);
    }

    @Test
    @SneakyThrows
    public void testExecuteQuery() {
        setupHyperGrpcClientWithMockedResultSet("query id", List.of());
        ResultSet response = statement.executeQuery("SELECT * FROM table");
        assertThat(statement.isReady()).isTrue();
        assertNotNull(response);
        assertThat(response.getMetaData().getColumnCount()).isEqualTo(3);
        assertThat(response.getMetaData().getColumnName(1)).isEqualTo("id");
        assertThat(response.getMetaData().getColumnName(2)).isEqualTo("name");
        assertThat(response.getMetaData().getColumnName(3)).isEqualTo("grade");
    }

    @Test
    @SneakyThrows
    public void testExecute() {
        setupHyperGrpcClientWithMockedResultSet("query id", List.of());
        statement.execute("SELECT * FROM table");
        ResultSet response = statement.getResultSet();
        assertNotNull(response);
        assertThat(response.getMetaData().getColumnCount()).isEqualTo(3);
        assertThat(response.getMetaData().getColumnName(1)).isEqualTo("id");
        assertThat(response.getMetaData().getColumnName(2)).isEqualTo("name");
        assertThat(response.getMetaData().getColumnName(3)).isEqualTo("grade");
    }

    /** Caller-supplied interceptors must observe every executed query. */
    @Test
    @SneakyThrows
    public void testExecuteQueryIncludesInterceptorsProvidedByCaller() {
        setupHyperGrpcClientWithMockedResultSet("abc", List.of());
        val interceptor = new RequestRecordingInterceptor();
        Mockito.when(connection.getInterceptors()).thenReturn(List.of(interceptor));

        assertThat(interceptor.getQueries().size()).isEqualTo(0);
        statement.executeQuery("SELECT * FROM table");
        assertThat(interceptor.getQueries().size()).isEqualTo(1);
        statement.executeQuery("SELECT * FROM table");
        assertThat(interceptor.getQueries().size()).isEqualTo(2);
        statement.executeQuery("SELECT * FROM table");
        assertThat(interceptor.getQueries().size()).isEqualTo(3);
        assertDoesNotThrow(() -> statement.close());
    }

    /** gRPC status errors are surfaced to the caller as DataCloudJDBCException. */
    @Test
    public void testExecuteQueryWithSqlException() {
        StatusRuntimeException fakeException = GrpcUtils.getFakeStatusRuntimeExceptionAsInvalidArgument();

        GrpcMock.stubFor(GrpcMock.unaryMethod(HyperServiceGrpc.getExecuteQueryMethod())
                .willReturn(GrpcMock.exception(fakeException)));

        assertThrows(DataCloudJDBCException.class, () -> statement.executeQuery("SELECT * FROM table"));
    }

    @Test
    public void testExecuteUpdate() {
        String sql = "UPDATE table SET column = value";
        val e = assertThrows(DataCloudJDBCException.class, () -> statement.executeUpdate(sql));
        assertThat(e)
                .hasMessageContaining("is not supported in Data Cloud query")
                .hasFieldOrPropertyWithValue("SQLState", SqlErrorCodes.FEATURE_NOT_SUPPORTED);
    }

    /** Negative timeouts are ignored in favor of the default. */
    @Test
    public void testSetQueryTimeoutNegativeValue() {
        statement.setQueryTimeout(-100);
        assertThat(statement.getQueryTimeout()).isEqualTo(DataCloudStatement.DEFAULT_QUERY_TIMEOUT);
    }

    @Test
    public void testGetQueryTimeoutDefaultValue() {
        assertThat(statement.getQueryTimeout()).isEqualTo(DataCloudStatement.DEFAULT_QUERY_TIMEOUT);
    }

    /** A "queryTimeout" connection property seeds the statement's timeout. */
    @Test
    public void testGetQueryTimeoutSetByConfig() {
        Properties properties = new Properties();
        properties.setProperty("queryTimeout", Integer.toString(30));
        connection = Mockito.mock(DataCloudConnection.class);
        Mockito.when(connection.getProperties()).thenReturn(properties);
        val statement = new DataCloudStatement(connection);
        assertThat(statement.getQueryTimeout()).isEqualTo(30);
    }

    /** A statement-level setQueryTimeout overrides the configured value. */
    @Test
    public void testGetQueryTimeoutSetInQueryStatementLevel() {
        statement.setQueryTimeout(10);
        assertThat(statement.getQueryTimeout()).isEqualTo(10);
    }

    /** close() must be safe even when no query was ever executed. */
    @Test
    @SneakyThrows
    public void testCloseIsNullSafe() {
        assertDoesNotThrow(() -> statement.close());
    }
}
/*
 * Copyright (c) 2024, Salesforce, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.salesforce.datacloud.jdbc.core;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.salesforce.datacloud.jdbc.core.model.ParameterBinding;
import java.sql.SQLException;
import java.sql.Types;
import java.util.List;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Tests for DefaultParameterManager: 1-based index validation, sparse list
 * growth for out-of-order indices, and clearing of accumulated bindings.
 */
public class DefaultParameterManagerTest {

    private DefaultParameterManager parameterManager;

    @BeforeEach
    void setUp() {
        parameterManager = new DefaultParameterManager();
    }

    /** Binding index 1 stores the value and sql type at list position 0. */
    @Test
    void testSetParameterValidIndex() throws SQLException {
        parameterManager.setParameter(1, Types.VARCHAR, "TEST");

        List<ParameterBinding> bound = parameterManager.getParameters();
        assertEquals(1, bound.size());
        assertEquals("TEST", bound.get(0).getValue());
        assertEquals(Types.VARCHAR, bound.get(0).getSqlType());
    }

    /** Binding a high index pads the intermediate slots with nulls. */
    @Test
    void testSetParameterExpandingList() throws SQLException {
        parameterManager.setParameter(3, Types.INTEGER, 42);

        List<ParameterBinding> bound = parameterManager.getParameters();
        assertEquals(3, bound.size());
        assertNull(bound.get(0));
        assertNull(bound.get(1));
        assertEquals(42, bound.get(2).getValue());
        assertEquals(Types.INTEGER, bound.get(2).getSqlType());
    }

    /** Indices below 1 are rejected with a descriptive SQLException. */
    @Test
    void testSetParameterNegativeIndexThrowsSQLException() {
        for (int badIndex : new int[] {0, -1}) {
            SQLException thrown =
                    assertThrows(SQLException.class, () -> parameterManager.setParameter(badIndex, Types.VARCHAR, "TEST"));
            assertEquals("Parameter index must be greater than 0", thrown.getMessage());
        }
    }

    /** clearParameters discards every previously stored binding. */
    @Test
    void testClearParameters() throws SQLException {
        parameterManager.setParameter(1, Types.VARCHAR, "TEST");
        parameterManager.setParameter(2, Types.INTEGER, 123);

        parameterManager.clearParameters();

        assertTrue(parameterManager.getParameters().isEmpty());
    }
}
/*
 * Copyright (c) 2024, Salesforce, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.salesforce.datacloud.jdbc.core;

import static org.assertj.core.api.Assertions.assertThat;

import com.google.protobuf.ByteString;
import com.salesforce.hyperdb.grpc.QueryResult;
import com.salesforce.hyperdb.grpc.QueryResultPartBinary;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.stream.Stream;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.arrow.vector.ipc.ReadChannel;
import org.junit.jupiter.api.Test;

/**
 * Tests for ExecuteQueryResponseChannel: empty-buffer detection, open/closed
 * state reflecting the underlying iterator, end-of-stream signalling (-1), and
 * — crucially — lazy, on-demand consumption of the response stream.
 */
@Slf4j
class ExecuteQueryResponseChannelTest {
    @Test
    void isNotEmptyDetectsEmpty() {
        val empty = ByteBuffer.allocateDirect(0);
        assertThat(ExecuteQueryResponseChannel.isNotEmpty(empty)).isFalse();
    }

    @Test
    void isNotEmptyDetectsNotEmpty() {
        val notEmpty = ByteBuffer.wrap("not empty".getBytes(StandardCharsets.UTF_8));
        assertThat(ExecuteQueryResponseChannel.isNotEmpty(notEmpty)).isTrue();
    }

    @Test
    @SneakyThrows
    void isOpenDetectsIfIteratorIsExhausted() {
        try (val channel = ExecuteQueryResponseChannel.of(empty())) {
            assertThat(channel.isOpen()).isFalse();
        }
    }

    @Test
    @SneakyThrows
    void isOpenDetectsIfIteratorHasRemaining() {
        try (val channel = ExecuteQueryResponseChannel.of(some())) {
            assertThat(channel.isOpen()).isTrue();
        }
    }

    @Test
    @SneakyThrows
    void readReturnsNegativeOneOnIteratorExhaustion() {
        try (val channel = ExecuteQueryResponseChannel.of(empty())) {
            assertThat(channel.read(ByteBuffer.allocateDirect(0))).isEqualTo(-1);
        }
    }

    /**
     * The channel must pull from the stream only as bytes are requested: after
     * a 5-byte read exactly 5 one-character messages have been consumed, after
     * a second read exactly 10. An eager implementation would drain the
     * (infinite) stream and hang — the peek counter proves laziness.
     */
    @SneakyThrows
    @Test
    void readIsLazy() {
        val first = ByteBuffer.allocate(5);
        val second = ByteBuffer.allocate(5);
        val seen = new ArrayList<QueryResult>();

        val stream = infiniteStream().peek(seen::add);

        val channel = new ReadChannel(ExecuteQueryResponseChannel.of(stream));

        channel.readFully(first);
        assertThat(seen).hasSize(5);
        channel.readFully(second);
        assertThat(seen).hasSize(10);

        assertThat(new String(first.array(), StandardCharsets.UTF_8)).isEqualTo("01234");
        assertThat(new String(second.array(), StandardCharsets.UTF_8)).isEqualTo("56789");
    }

    private static Stream<QueryResult> some() {
        return infiniteStream();
    }

    private static Stream<QueryResult> empty() {
        return infiniteStream().limit(0);
    }

    // 0,1,2,... each rendered as a single-character binary payload.
    private static Stream<QueryResult> infiniteStream() {
        return Stream.iterate(0, i -> i + 1)
                .map(i -> Integer.toString(i))
                .map(ExecuteQueryResponseChannelTest::toMessage);
    }

    private static QueryResult toMessage(String string) {
        val byteString = ByteString.copyFromUtf8(string);
        val binaryPart = QueryResultPartBinary.newBuilder().setData(byteString);
        return QueryResult.newBuilder().setBinaryPart(binaryPart).build();
    }
}
/*
 * Copyright (c) 2024, Salesforce, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.salesforce.datacloud.jdbc.core;

import static org.assertj.core.api.Assertions.assertThat;

import com.salesforce.hyperdb.grpc.HyperServiceGrpc;
import io.grpc.inprocess.InProcessChannelBuilder;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicReference;
import lombok.SneakyThrows;
import lombok.val;
import org.grpcmock.GrpcMock;
import org.junit.jupiter.api.Test;

/**
 * Tests for HyperConnectionSettings: only properties carrying the
 * "serverSetting." prefix are extracted (with the prefix stripped), and the
 * extracted settings are attached to outgoing ExecuteQuery requests.
 */
class HyperConnectionSettingsTest extends HyperGrpcTestBase {
    private static final String HYPER_SETTING = "serverSetting.";

    /** Only the prefixed property survives; the prefix itself is stripped. */
    @Test
    void testGetSettingWithCorrectPrefix() {
        val properties = new Properties();
        properties.setProperty(HYPER_SETTING + "lc_time", "en_US");
        properties.setProperty("username", "alice");

        val settings = HyperConnectionSettings.of(properties).getSettings();

        assertThat(settings).containsExactlyInAnyOrderEntriesOf(Map.of("lc_time", "en_US"));
    }

    /** Properties without the prefix yield an empty settings map. */
    @Test
    void testGetSettingReturnEmptyResultSet() {
        val properties = new Properties();
        properties.setProperty("c_time", "en_US");
        properties.setProperty("username", "alice");

        val settings = HyperConnectionSettings.of(properties).getSettings();

        assertThat(settings).containsExactlyInAnyOrderEntriesOf(Map.of());
    }

    /** No properties at all also yields an empty settings map. */
    @Test
    void testGetSettingWithEmptyProperties() {
        val settings = HyperConnectionSettings.of(new Properties()).getSettings();

        assertThat(settings).containsExactlyInAnyOrderEntriesOf(Map.of());
    }

    /** The extracted settings must arrive in the ExecuteQuery request's settings map. */
    @SneakyThrows
    @Test
    void itSubmitsSettingsOnCall() {
        val settingKey = UUID.randomUUID().toString();
        val settingValue = UUID.randomUUID().toString();
        val properties = new Properties();
        properties.setProperty(HYPER_SETTING + settingKey, settingValue);

        // Captures the settings map from the request observed by the mock server.
        val captured = new AtomicReference<Map<String, String>>();
        val channel = InProcessChannelBuilder.forName(GrpcMock.getGlobalInProcessName())
                .usePlaintext();
        try (val client = HyperGrpcClientExecutor.of(channel, properties)) {
            GrpcMock.stubFor(GrpcMock.serverStreamingMethod(HyperServiceGrpc.getExecuteQueryMethod())
                    .withRequest(request -> {
                        captured.set(request.getSettingsMap());
                        return true;
                    })
                    .willReturn(List.of(executeQueryResponse("", null, null))));

            client.executeQuery("").next();
        }

        assertThat(captured.get()).containsOnly(Map.entry(settingKey, settingValue));
    }
}
/*
 * Copyright (c) 2024, Salesforce, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.salesforce.datacloud.jdbc.core;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.google.protobuf.ByteString;
import com.salesforce.hyperdb.grpc.ExecuteQueryResponse;
import com.salesforce.hyperdb.grpc.HyperServiceGrpc;
import com.salesforce.hyperdb.grpc.QueryParam;
import com.salesforce.hyperdb.grpc.QueryResultPartBinary;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.Status;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import lombok.Getter;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Verifies that the channel-level gRPC retry policy transparently retries
 * UNAVAILABLE responses: a service that fails four times before succeeding
 * must still deliver its payload to the blocking stub.
 */
public class HyperGrpcClientRetryTest {
    private HyperServiceGrpc.HyperServiceBlockingStub hyperServiceStub;
    private HyperServiceImpl hyperService;
    private ManagedChannel channel;
    // Retained so the in-process server can be shut down after each test;
    // previously the started Server was discarded and leaked.
    private Server server;

    @BeforeEach
    public void setUpClient() throws IOException {
        String serverName = InProcessServerBuilder.generateName();
        InProcessServerBuilder serverBuilder =
                InProcessServerBuilder.forName(serverName).directExecutor();

        hyperService = new HyperServiceImpl();
        serverBuilder.addService(hyperService);
        server = serverBuilder.build().start();

        channel = InProcessChannelBuilder.forName(serverName)
                .usePlaintext()
                .enableRetry()
                .maxRetryAttempts(5)
                .directExecutor()
                .defaultServiceConfig(retryPolicy())
                .build();

        hyperServiceStub = HyperServiceGrpc.newBlockingStub(channel);
    }

    @SneakyThrows
    @AfterEach
    public void cleanupClient() {
        if (channel != null) {
            channel.shutdown();

            try {
                assertTrue(channel.awaitTermination(5, TimeUnit.SECONDS));
            } finally {
                channel.shutdownNow();
            }
        }
        // Also tear down the in-process server so resources do not leak
        // between tests.
        if (server != null) {
            server.shutdownNow();
            server.awaitTermination(5, TimeUnit.SECONDS);
        }
    }

    /**
     * Service config enabling up to 5 attempts with exponential backoff for
     * UNAVAILABLE, applied to all methods (empty name selector).
     */
    private Map<String, Object> retryPolicy() {
        return Map.of(
                "methodConfig",
                List.of(Map.of(
                        "name",
                        List.of(Collections.emptyMap()),
                        "retryPolicy",
                        Map.of(
                                "maxAttempts",
                                String.valueOf(5),
                                "initialBackoff",
                                "0.5s",
                                "maxBackoff",
                                "30s",
                                "backoffMultiplier",
                                2.0,
                                "retryableStatusCodes",
                                List.of("UNAVAILABLE")))));
    }

    private final String query = "SELECT * FROM test";
    private static final ExecuteQueryResponse chunk1 = ExecuteQueryResponse.newBuilder()
            .setBinaryPart(QueryResultPartBinary.newBuilder()
                    .setData(ByteString.copyFromUtf8("test 1"))
                    .build())
            .build();

    @Test
    public void testExecuteQueryWithRetry() {
        Iterator<ExecuteQueryResponse> queryResultIterator = hyperServiceStub.executeQuery(
                QueryParam.newBuilder().setQuery(query).build());

        assertDoesNotThrow(() -> {
            boolean responseReceived = false;
            while (queryResultIterator.hasNext()) {
                ExecuteQueryResponse response = queryResultIterator.next();
                if (response.getBinaryPart().getData().toStringUtf8().equals("test 1")) {
                    responseReceived = true;
                }
            }
            assertTrue(responseReceived, "Expected response not received after retries.");
        });

        // Four UNAVAILABLE failures plus the final success = 5 attempts.
        Assertions.assertThat(hyperService.getRetryCount()).isEqualTo(5);
    }

    /** Fails with UNAVAILABLE on the first four attempts, then succeeds. */
    @Getter
    @Slf4j
    public static class HyperServiceImpl extends HyperServiceGrpc.HyperServiceImplBase {
        int retryCount = 1;

        @Override
        public void executeQuery(QueryParam request, StreamObserver<ExecuteQueryResponse> responseObserver) {
            log.warn("Executing query attempt #{}", retryCount);
            if (retryCount < 5) {
                retryCount++;
                responseObserver.onError(Status.UNAVAILABLE
                        .withDescription("Service unavailable")
                        .asRuntimeException());
                return;
            }

            responseObserver.onNext(chunk1);
            responseObserver.onCompleted();
        }
    }
}
/*
 * Copyright (c) 2024, Salesforce, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.salesforce.datacloud.jdbc.core;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;

import com.google.protobuf.ByteString;
import com.salesforce.hyperdb.grpc.ExecuteQueryResponse;
import com.salesforce.hyperdb.grpc.HyperServiceGrpc;
import com.salesforce.hyperdb.grpc.OutputFormat;
import com.salesforce.hyperdb.grpc.QueryParam;
import com.salesforce.hyperdb.grpc.QueryResultPartBinary;
import java.sql.SQLException;
import java.util.Iterator;
import org.grpcmock.GrpcMock;
import org.junit.jupiter.api.Test;

/**
 * Verifies that HyperGrpcClientExecutor.executeQuery sends an ExecuteQuery
 * request with the expected query text, ARROW_V3 output format, and SYNC
 * transfer mode.
 */
class HyperGrpcClientTest extends HyperGrpcTestBase {

    // Single canned binary chunk served by the mocked streaming endpoint.
    private static final ExecuteQueryResponse chunk1 = ExecuteQueryResponse.newBuilder()
            .setBinaryPart(QueryResultPartBinary.newBuilder()
                    .setData(ByteString.copyFromUtf8("test 1"))
                    .build())
            .build();

    @Test
    public void testExecuteQuery() throws SQLException {
        GrpcMock.stubFor(GrpcMock.serverStreamingMethod(HyperServiceGrpc.getExecuteQueryMethod())
                .willReturn(chunk1));

        String query = "SELECT * FROM test";
        Iterator<ExecuteQueryResponse> results = hyperGrpcClient.executeQuery(query);

        // Draining the stream must complete without raising.
        assertDoesNotThrow(() -> results.forEachRemaining(response -> {}));

        QueryParam expectedRequest = QueryParam.newBuilder()
                .setQuery(query)
                .setOutputFormat(OutputFormat.ARROW_V3)
                .setTransferMode(QueryParam.TransferMode.SYNC)
                .build();
        GrpcMock.verifyThat(
                GrpcMock.calledMethod(HyperServiceGrpc.getExecuteQueryMethod()).withRequest(expectedRequest));
    }
}
+ */ +package com.salesforce.datacloud.jdbc.core; + +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; + +import com.salesforce.datacloud.jdbc.auth.AuthenticationSettings; +import com.salesforce.datacloud.jdbc.auth.DataCloudToken; +import com.salesforce.datacloud.jdbc.auth.TokenProcessor; +import com.salesforce.datacloud.jdbc.util.RealisticArrowGenerator; +import com.salesforce.hyperdb.grpc.ExecuteQueryResponse; +import com.salesforce.hyperdb.grpc.HyperServiceGrpc; +import com.salesforce.hyperdb.grpc.QueryInfo; +import com.salesforce.hyperdb.grpc.QueryParam; +import com.salesforce.hyperdb.grpc.QueryStatus; +import io.grpc.inprocess.InProcessChannelBuilder; +import java.io.IOException; +import java.sql.SQLException; +import java.util.List; +import java.util.Properties; +import java.util.function.UnaryOperator; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import lombok.SneakyThrows; +import lombok.val; +import org.grpcmock.GrpcMock; +import org.grpcmock.junit5.InProcessGrpcMockExtension; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; + +@ExtendWith(InProcessGrpcMockExtension.class) +public class HyperGrpcTestBase { + + protected static HyperGrpcClientExecutor hyperGrpcClient; + + @Mock + protected TokenProcessor mockSession; + + @Mock + protected DataCloudToken mockToken; + + @Mock + protected AuthenticationSettings mockSettings; + + @BeforeEach + public void setUpClient() throws SQLException, IOException { + mockToken = mock(DataCloudToken.class); + lenient().when(mockToken.getAccessToken()).thenReturn("1234"); + lenient().when(mockToken.getTenantId()).thenReturn("testTenantId"); + lenient().when(mockToken.getTenantUrl()).thenReturn("tenant.salesforce.com"); + + mockSettings = mock(AuthenticationSettings.class); + 
lenient().when(mockSettings.getUserAgent()).thenReturn("userAgent"); + lenient().when(mockSettings.getDataspace()).thenReturn("testDataspace"); + + mockSession = mock(TokenProcessor.class); + lenient().when(mockSession.getDataCloudToken()).thenReturn(mockToken); + lenient().when(mockSession.getSettings()).thenReturn(mockSettings); + + val channel = InProcessChannelBuilder.forName(GrpcMock.getGlobalInProcessName()) + .usePlaintext(); + hyperGrpcClient = HyperGrpcClientExecutor.of(channel, new Properties()); + } + + @SneakyThrows + @AfterEach + public void cleanup() { + if (hyperGrpcClient != null) { + hyperGrpcClient.close(); + } + } + + private void willReturn(List responses, QueryParam.TransferMode mode) { + GrpcMock.stubFor(GrpcMock.serverStreamingMethod(HyperServiceGrpc.getExecuteQueryMethod()) + .withRequest(req -> mode == null || req.getTransferMode() == mode) + .willReturn(GrpcMock.stream(responses))); + } + + private Stream queryStatusResponse(String queryId) { + return Stream.of(ExecuteQueryResponse.newBuilder() + .setQueryInfo(QueryInfo.newBuilder() + .setQueryStatus( + QueryStatus.newBuilder().setQueryId(queryId).build()) + .build()) + .build()); + } + + public void setupHyperGrpcClientWithMockedResultSet( + String expectedQueryId, List students) { + setupHyperGrpcClientWithMockedResultSet(expectedQueryId, students, null); + } + + public void setupHyperGrpcClientWithMockedResultSet( + String expectedQueryId, List students, QueryParam.TransferMode mode) { + willReturn( + Stream.concat( + queryStatusResponse(expectedQueryId), + RealisticArrowGenerator.getMockedData(students) + .map(t -> ExecuteQueryResponse.newBuilder() + .setQueryResult(t) + .build())) + .collect(Collectors.toList()), + mode); + } + + public void setupExecuteQuery( + String queryId, String query, QueryParam.TransferMode mode, ExecuteQueryResponse... 
responses) { + val first = ExecuteQueryResponse.newBuilder() + .setQueryInfo(QueryInfo.newBuilder() + .setQueryStatus( + QueryStatus.newBuilder().setQueryId(queryId).build()) + .build()) + .build(); + + GrpcMock.stubFor(GrpcMock.serverStreamingMethod(HyperServiceGrpc.getExecuteQueryMethod()) + .withRequest(req -> req.getQuery().equals(query) && req.getTransferMode() == mode) + .willReturn(Stream.concat(Stream.of(first), Stream.of(responses)) + .collect(Collectors.toUnmodifiableList()))); + } + + public void setupGetQueryInfo(String queryId, QueryStatus.CompletionStatus completionStatus) { + setupGetQueryInfo(queryId, completionStatus, 1); + } + + protected void setupGetQueryInfo(String queryId, QueryStatus.CompletionStatus completionStatus, int chunkCount) { + val queryInfo = QueryInfo.newBuilder() + .setQueryStatus(QueryStatus.newBuilder() + .setQueryId(queryId) + .setCompletionStatus(completionStatus) + .setChunkCount(chunkCount) + .build()) + .build(); + GrpcMock.stubFor(GrpcMock.serverStreamingMethod(HyperServiceGrpc.getGetQueryInfoMethod()) + .withRequest(req -> req.getQueryId().equals(queryId)) + .willReturn(queryInfo)); + } + + protected void verifyGetQueryInfo(int times) { + GrpcMock.verifyThat(GrpcMock.calledMethod(HyperServiceGrpc.getGetQueryInfoMethod()), GrpcMock.times(times)); + } + + public void setupGetQueryResult( + String queryId, int chunkId, int parts, List students) { + val results = IntStream.range(0, parts) + .mapToObj(i -> RealisticArrowGenerator.getMockedData(students)) + .flatMap(UnaryOperator.identity()) + .collect(Collectors.toList()); + + GrpcMock.stubFor(GrpcMock.serverStreamingMethod(HyperServiceGrpc.getGetQueryResultMethod()) + .withRequest(req -> req.getQueryId().equals(queryId) + && req.getChunkId() == chunkId + && req.getOmitSchema() == chunkId > 0) + .willReturn(results)); + } + + public void setupAdaptiveInitialResults( + String sql, + String queryId, + int parts, + Integer chunks, + QueryStatus.CompletionStatus status, + List 
students) { + val results = IntStream.range(0, parts) + .mapToObj(i -> RealisticArrowGenerator.getMockedData(students)) + .flatMap(UnaryOperator.identity()) + .map(r -> ExecuteQueryResponse.newBuilder().setQueryResult(r).build()); + + val response = Stream.concat( + Stream.of(executeQueryResponse(queryId, null, null)), + Stream.concat(results, Stream.of(executeQueryResponse(queryId, status, chunks)))) + .collect(Collectors.toUnmodifiableList()); + + GrpcMock.stubFor(GrpcMock.serverStreamingMethod(HyperServiceGrpc.getExecuteQueryMethod()) + .withRequest(req -> req.getQuery().equals(sql)) + .willReturn(response)); + } + + public static ExecuteQueryResponse executeQueryResponseWithData(List students) { + val result = RealisticArrowGenerator.getMockedData(students).findFirst().orElseThrow(); + return ExecuteQueryResponse.newBuilder().setQueryResult(result).build(); + } + + public static ExecuteQueryResponse executeQueryResponse( + String queryId, QueryStatus.CompletionStatus status, Integer chunkCount) { + val queryStatus = QueryStatus.newBuilder().setQueryId(queryId); + + if (status != null) { + queryStatus.setCompletionStatus(status); + } + + if (chunkCount != null) { + queryStatus.setChunkCount(chunkCount); + } + + return ExecuteQueryResponse.newBuilder() + .setQueryInfo(QueryInfo.newBuilder().setQueryStatus(queryStatus).build()) + .build(); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/QueryDBMetadataTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/QueryDBMetadataTest.java new file mode 100644 index 0000000..bf52f8a --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/QueryDBMetadataTest.java @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.sql.Types; +import java.util.Arrays; +import java.util.List; +import org.junit.jupiter.api.Test; + +public class QueryDBMetadataTest { + private static final List COLUMN_NAMES = Arrays.asList( + "TABLE_CAT", + "TABLE_SCHEM", + "TABLE_NAME", + "COLUMN_NAME", + "DATA_TYPE", + "TYPE_NAME", + "COLUMN_SIZE", + "BUFFER_LENGTH", + "DECIMAL_DIGITS", + "NUM_PREC_RADIX", + "NULLABLE", + "REMARKS", + "COLUMN_DEF", + "SQL_DATA_TYPE", + "SQL_DATETIME_SUB", + "CHAR_OCTET_LENGTH", + "ORDINAL_POSITION", + "IS_NULLABLE", + "SCOPE_CATALOG", + "SCOPE_SCHEMA", + "SCOPE_TABLE", + "SOURCE_DATA_TYPE", + "IS_AUTOINCREMENT", + "IS_GENERATEDCOLUMN"); + + private static final List COLUMN_TYPES = Arrays.asList( + "TEXT", "TEXT", "TEXT", "TEXT", "INTEGER", "TEXT", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", + "TEXT", "TEXT", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "TEXT", "TEXT", "TEXT", "TEXT", "SHORT", "TEXT", + "TEXT"); + + private static final List COLUMN_TYPE_IDS = Arrays.asList( + Types.VARCHAR, + Types.VARCHAR, + Types.VARCHAR, + Types.VARCHAR, + Types.INTEGER, + Types.VARCHAR, + Types.INTEGER, + Types.INTEGER, + Types.INTEGER, + Types.INTEGER, + Types.INTEGER, + Types.VARCHAR, + Types.VARCHAR, + Types.INTEGER, + Types.INTEGER, + Types.INTEGER, + Types.INTEGER, + Types.VARCHAR, + Types.VARCHAR, + Types.VARCHAR, + Types.VARCHAR, + Types.SMALLINT, + Types.VARCHAR, + Types.VARCHAR); + + @Test + public void testGetColumnNames() { + 
assertThat(QueryDBMetadata.GET_COLUMNS.getColumnNames()).isEqualTo(COLUMN_NAMES); + assertThat(QueryDBMetadata.GET_COLUMNS.getColumnNames().size()).isEqualTo(24); + assertThat(QueryDBMetadata.GET_COLUMNS.getColumnNames().get(0)).isEqualTo("TABLE_CAT"); + } + + @Test + public void testGetColumnTypes() { + assertThat(QueryDBMetadata.GET_COLUMNS.getColumnTypes()).isEqualTo(COLUMN_TYPES); + assertThat(QueryDBMetadata.GET_COLUMNS.getColumnTypes().size()).isEqualTo(24); + assertThat(QueryDBMetadata.GET_COLUMNS.getColumnTypes().get(0)).isEqualTo("TEXT"); + } + + @Test + public void testGetColumnTypeIds() { + assertThat(QueryDBMetadata.GET_COLUMNS.getColumnTypeIds()).isEqualTo(COLUMN_TYPE_IDS); + assertThat(QueryDBMetadata.GET_COLUMNS.getColumnTypeIds().size()).isEqualTo(24); + assertThat(QueryDBMetadata.GET_COLUMNS.getColumnTypeIds().get(0)).isEqualTo(Types.VARCHAR); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/QueryJDBCAccessorFactoryTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/QueryJDBCAccessorFactoryTest.java new file mode 100644 index 0000000..eda865a --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/QueryJDBCAccessorFactoryTest.java @@ -0,0 +1,382 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorFactory; +import com.salesforce.datacloud.jdbc.core.accessor.impl.BaseIntVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.BinaryVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.BooleanVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.DateVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.DecimalVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.DoubleVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.LargeListVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.ListVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.TimeStampVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.TimeVectorAccessor; +import com.salesforce.datacloud.jdbc.core.accessor.impl.VarCharVectorAccessor; +import com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.function.IntSupplier; +import lombok.SneakyThrows; +import org.apache.arrow.vector.LargeVarCharVector; +import org.apache.arrow.vector.NullVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +public class QueryJDBCAccessorFactoryTest { + public static final IntSupplier GET_CURRENT_ROW = () -> 0; + + List binaryList = List.of( + "BINARY_DATA_0001".getBytes(StandardCharsets.UTF_8), + "BINARY_DATA_0002".getBytes(StandardCharsets.UTF_8), + "BINARY_DATA_0003".getBytes(StandardCharsets.UTF_8)); + + List uint4List = 
List.of( + 0, + 1, + -1, + (int) Byte.MIN_VALUE, + (int) Byte.MAX_VALUE, + (int) Short.MIN_VALUE, + (int) Short.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE); + + @RegisterExtension + public static RootAllocatorTestExtension rootAllocatorTestExtension = new RootAllocatorTestExtension(); + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsVarChar() { + try (ValueVector valueVector = new VarCharVector("VarChar", rootAllocatorTestExtension.getRootAllocator())) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + Assertions.assertInstanceOf(VarCharVectorAccessor.class, accessor); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsLargeVarChar() { + try (ValueVector valueVector = + new LargeVarCharVector("LargeVarChar", rootAllocatorTestExtension.getRootAllocator())) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + Assertions.assertInstanceOf(VarCharVectorAccessor.class, accessor); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsDecimal() { + try (ValueVector valueVector = rootAllocatorTestExtension.createDecimalVector(List.of())) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + Assertions.assertInstanceOf(DecimalVectorAccessor.class, accessor); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsBoolean() { + try (ValueVector valueVector = rootAllocatorTestExtension.createBitVector(List.of())) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + Assertions.assertInstanceOf(BooleanVectorAccessor.class, accessor); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsFloat8() { + try (ValueVector valueVector = 
rootAllocatorTestExtension.createFloat8Vector(List.of())) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + Assertions.assertInstanceOf(DoubleVectorAccessor.class, accessor); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsDateDay() { + try (ValueVector valueVector = rootAllocatorTestExtension.createDateDayVector()) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + assertThat(accessor).isInstanceOf(DateVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsDateMilli() { + try (ValueVector valueVector = rootAllocatorTestExtension.createDateMilliVector()) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + assertThat(accessor).isInstanceOf(DateVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsTimeNano() { + try (ValueVector valueVector = rootAllocatorTestExtension.createTimeNanoVector(List.of())) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + assertThat(accessor).isInstanceOf(TimeVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsTimeMicro() { + try (ValueVector valueVector = rootAllocatorTestExtension.createTimeMicroVector(List.of())) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + assertThat(accessor).isInstanceOf(TimeVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsTimeMilli() { + try (ValueVector valueVector = rootAllocatorTestExtension.createTimeMilliVector(List.of())) { + QueryJDBCAccessor accessor = + 
QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + assertThat(accessor).isInstanceOf(TimeVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsTimeSec() { + try (ValueVector valueVector = rootAllocatorTestExtension.createTimeSecVector(List.of())) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + assertThat(accessor).isInstanceOf(TimeVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsUnsupportedVector() { + try (ValueVector valueVector = new NullVector("Null")) { + Assertions.assertThrows( + UnsupportedOperationException.class, + () -> QueryJDBCAccessorFactory.createAccessor( + valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {})); + } + } + + @Test + @SneakyThrows + public void testCreateAccessorCorrectlyDetectsTinyInt() { + try (ValueVector valueVector = rootAllocatorTestExtension.createTinyIntVector(List.of())) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + Assertions.assertInstanceOf(BaseIntVectorAccessor.class, accessor); + } + } + + @Test + @SneakyThrows + public void testCreateAccessorCorrectlyDetectsSmallInt() { + try (ValueVector valueVector = rootAllocatorTestExtension.createSmallIntVector(List.of())) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + Assertions.assertInstanceOf(BaseIntVectorAccessor.class, accessor); + } + } + + @Test + @SneakyThrows + public void testCreateAccessorCorrectlyDetectsInt() { + try (ValueVector valueVector = rootAllocatorTestExtension.createIntVector(List.of())) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + 
Assertions.assertInstanceOf(BaseIntVectorAccessor.class, accessor); + } + } + + @Test + @SneakyThrows + public void testCreateAccessorCorrectlyDetectsBigInt() { + try (ValueVector valueVector = rootAllocatorTestExtension.createBigIntVector(List.of())) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + Assertions.assertInstanceOf(BaseIntVectorAccessor.class, accessor); + } + } + + @Test + @SneakyThrows + public void testCreateAccessorCorrectlyDetectsUInt4() { + try (ValueVector valueVector = rootAllocatorTestExtension.createUInt4Vector(uint4List)) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + Assertions.assertInstanceOf(BaseIntVectorAccessor.class, accessor); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsVarBinaryVector() { + try (ValueVector valueVector = rootAllocatorTestExtension.createVarBinaryVector(binaryList)) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + + assertThat(accessor).isInstanceOf(BinaryVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsLargeVarBinaryVector() { + try (ValueVector valueVector = rootAllocatorTestExtension.createLargeVarBinaryVector(binaryList)) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + + assertThat(accessor).isInstanceOf(BinaryVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsFixedSizeBinaryVector() { + try (ValueVector valueVector = rootAllocatorTestExtension.createFixedSizeBinaryVector(binaryList)) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + + 
assertThat(accessor).isInstanceOf(BinaryVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsTimeStampNanoVector() { + try (ValueVector valueVector = rootAllocatorTestExtension.createTimeStampNanoVector(List.of())) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + + assertThat(accessor).isInstanceOf(TimeStampVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsTimeStampNanoTZVector() { + try (ValueVector valueVector = rootAllocatorTestExtension.createTimeStampNanoTZVector(List.of(), "UTC")) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + + assertThat(accessor).isInstanceOf(TimeStampVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsTimeStampMicroVector() { + try (ValueVector valueVector = rootAllocatorTestExtension.createTimeStampMicroVector(List.of())) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + + assertThat(accessor).isInstanceOf(TimeStampVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsTimeStampMicroTZVector() { + try (ValueVector valueVector = rootAllocatorTestExtension.createTimeStampMicroTZVector(List.of(), "UTC")) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + + assertThat(accessor).isInstanceOf(TimeStampVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsTimeStampMilliVector() { + try (ValueVector valueVector = rootAllocatorTestExtension.createTimeStampMilliVector(List.of())) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean 
wasNull) -> {}); + + assertThat(accessor).isInstanceOf(TimeStampVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsTimeStampMilliTZVector() { + try (ValueVector valueVector = rootAllocatorTestExtension.createTimeStampMilliTZVector(List.of(), "UTC")) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + + assertThat(accessor).isInstanceOf(TimeStampVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsTimeStampSecVector() { + try (ValueVector valueVector = rootAllocatorTestExtension.createTimeStampSecVector(List.of())) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + + assertThat(accessor).isInstanceOf(TimeStampVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsTimeStampSecTZVector() { + try (ValueVector valueVector = rootAllocatorTestExtension.createTimeStampSecTZVector(List.of(), "UTC")) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + + assertThat(accessor).isInstanceOf(TimeStampVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsListVectorAccessor() { + try (ValueVector valueVector = rootAllocatorTestExtension.createListVector("list-vector")) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, (boolean wasNull) -> {}); + + assertThat(accessor).isInstanceOf(ListVectorAccessor.class); + } + } + + @Test + @SneakyThrows + void testCreateAccessorCorrectlyDetectsLargeListVectorAccessor() { + try (ValueVector valueVector = rootAllocatorTestExtension.createLargeListVector("large-list-vector")) { + QueryJDBCAccessor accessor = + QueryJDBCAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, 
(boolean wasNull) -> {}); + + assertThat(accessor).isInstanceOf(LargeListVectorAccessor.class); + } + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/QueryJDBCCursorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/QueryJDBCCursorTest.java new file mode 100644 index 0000000..e05151e --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/QueryJDBCCursorTest.java @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.google.common.collect.ImmutableList; +import java.sql.SQLException; +import java.util.Collections; +import java.util.List; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.util.Cursor; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +public class QueryJDBCCursorTest { + + static QueryJDBCCursor cursor; + BufferAllocator allocator; + + @AfterEach + public void tearDown() { + allocator.close(); + cursor.close(); + } + + @Test + public void testVarCharVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("VarChar", new ArrowType.Utf8()); + ((VarCharVector) root.getVector("VarChar")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testDecimalVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = 
getVectorSchemaRoot("Decimal", new ArrowType.Decimal(38, 18, 128)); + ((DecimalVector) root.getVector("Decimal")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testBooleanVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("Boolean", new ArrowType.Bool()); + ((BitVector) root.getVector("Boolean")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testDateMilliVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("DateMilli", new ArrowType.Date(DateUnit.MILLISECOND)); + ((DateMilliVector) root.getVector("DateMilli")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testDateDayVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("DateDay", new ArrowType.Date(DateUnit.DAY)); + ((DateDayVector) root.getVector("DateDay")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testTimeNanoVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("TimeNano", new ArrowType.Time(TimeUnit.NANOSECOND, 64)); + ((TimeNanoVector) root.getVector("TimeNano")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testTimeMicroVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("TimeMicro", new ArrowType.Time(TimeUnit.MICROSECOND, 64)); + ((TimeMicroVector) root.getVector("TimeMicro")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testTimeMilliVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("TimeMilli", new ArrowType.Time(TimeUnit.MILLISECOND, 32)); + ((TimeMilliVector) root.getVector("TimeMilli")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testTimeSecVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("TimeSec", new ArrowType.Time(TimeUnit.SECOND, 32)); + ((TimeSecVector) 
root.getVector("TimeSec")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testTimeStampVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = + getVectorSchemaRoot("TimeStamp", new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)); + ((TimeStampMilliVector) root.getVector("TimeStamp")).setNull(0); + testCursorWasNull(root); + } + + private VectorSchemaRoot getVectorSchemaRoot(String name, ArrowType arrowType) { + final Schema schema = new Schema( + ImmutableList.of(new Field(name, new FieldType(true, arrowType, null), Collections.emptyList()))); + allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + root.allocateNew(); + return root; + } + + private void testCursorWasNull(VectorSchemaRoot root) throws SQLException { + root.setRowCount(1); + cursor = new QueryJDBCCursor(root); + cursor.next(); + List accessorList = cursor.createAccessors(null, null, null); + accessorList.get(0).getObject(); + assertThat(cursor.wasNull()).as("cursor.wasNull()").isTrue(); + root.close(); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/QueryJDBCDataCursorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/QueryJDBCDataCursorTest.java new file mode 100644 index 0000000..b7eaccb --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/QueryJDBCDataCursorTest.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +public class QueryJDBCDataCursorTest { + + static MetadataCursor cursor; + + @AfterEach + public void tearDown() { + cursor.close(); + } + + @Test + public void testCursorNextAndRowCount() { + List data = new ArrayList<>(); + data.add("Account_home_dll"); + cursor = new MetadataCursor(data); + cursor.next(); + assertThat(cursor.next()).isFalse(); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/QueryResultSetMetadataTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/QueryResultSetMetadataTest.java new file mode 100644 index 0000000..b198c77 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/QueryResultSetMetadataTest.java @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.sql.ResultSetMetaData; +import java.util.List; +import org.apache.commons.lang3.StringUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class QueryResultSetMetadataTest { + QueryDBMetadata queryDBMetadata = QueryDBMetadata.GET_COLUMNS; + + QueryResultSetMetadata queryResultSetMetadata; + + @BeforeEach + public void init() { + queryResultSetMetadata = new QueryResultSetMetadata(queryDBMetadata); + } + + @Test + public void testGetColumnCount() { + assertThat(queryResultSetMetadata.getColumnCount()).isEqualTo(24); + } + + @Test + public void testIsAutoIncrement() { + assertThat(queryResultSetMetadata.isAutoIncrement(1)).isFalse(); + } + + @Test + public void testIsCaseSensitive() { + assertThat(queryResultSetMetadata.isCaseSensitive(1)).isTrue(); + } + + @Test + public void testIsSearchable() { + assertThat(queryResultSetMetadata.isSearchable(1)).isFalse(); + } + + @Test + public void testIsCurrency() { + assertThat(queryResultSetMetadata.isCurrency(1)).isFalse(); + } + + @Test + public void testIsNullable() { + assertThat(queryResultSetMetadata.isNullable(1)).isEqualTo(1); + } + + @Test + public void testIsSigned() { + assertThat(queryResultSetMetadata.isSigned(1)).isFalse(); + } + + @Test + public void testGetColumnDisplaySize() { + assertThat(queryResultSetMetadata.getColumnDisplaySize(1)).isEqualTo(9); + assertThat(queryResultSetMetadata.getColumnDisplaySize(5)).isEqualTo(39); + } + + @Test + public void testGetColumnLabel() { + for (int i = 1; i <= queryDBMetadata.getColumnNames().size(); i++) { + assertThat(queryResultSetMetadata.getColumnLabel(i)) + .isEqualTo(queryDBMetadata.getColumnNames().get(i - 1)); + } + } + + @Test + public void testGetColumnLabelWithNullColumnNameReturnsDefaultValue() { + queryResultSetMetadata = new QueryResultSetMetadata(null, List.of("Col1"), List.of(1)); + 
assertThat(queryResultSetMetadata.getColumnLabel(1)).isEqualTo("C0"); + } + + @Test + public void testGetColumnName() { + for (int i = 1; i <= queryDBMetadata.getColumnNames().size(); i++) { + assertThat(queryResultSetMetadata.getColumnName(i)) + .isEqualTo(queryDBMetadata.getColumnNames().get(i - 1)); + } + } + + @Test + public void testGetSchemaName() { + assertThat(queryResultSetMetadata.getSchemaName(1)).isEqualTo(StringUtils.EMPTY); + } + + @Test + public void testGetPrecision() { + assertThat(queryResultSetMetadata.getPrecision(1)).isEqualTo(9); + assertThat(queryResultSetMetadata.getPrecision(5)).isEqualTo(38); + } + + @Test + public void testGetScale() { + assertThat(queryResultSetMetadata.getScale(1)).isEqualTo(0); + assertThat(queryResultSetMetadata.getScale(5)).isEqualTo(18); + } + + @Test + public void testGetTableName() { + assertThat(queryResultSetMetadata.getTableName(1)).isEqualTo(StringUtils.EMPTY); + } + + @Test + public void testGetCatalogName() { + assertThat(queryResultSetMetadata.getCatalogName(1)).isEqualTo(StringUtils.EMPTY); + } + + @Test + public void getColumnType() { + for (int i = 1; i <= queryDBMetadata.getColumnTypeIds().size(); i++) { + assertThat(queryResultSetMetadata.getColumnType(i)) + .isEqualTo(queryDBMetadata.getColumnTypeIds().get(i - 1)); + } + } + + @Test + public void getColumnTypeName() { + for (int i = 1; i <= queryDBMetadata.getColumnTypes().size(); i++) { + assertThat(queryResultSetMetadata.getColumnTypeName(i)) + .isEqualTo(queryDBMetadata.getColumnTypes().get(i - 1)); + } + } + + @Test + public void testIsReadOnly() { + assertThat(queryResultSetMetadata.isReadOnly(1)).isTrue(); + } + + @Test + public void isWritable() { + assertThat(queryResultSetMetadata.isWritable(1)).isFalse(); + } + + @Test + public void isDefinitelyWritable() { + assertThat(queryResultSetMetadata.isDefinitelyWritable(1)).isFalse(); + } + + @Test + public void getColumnClassName() { + 
assertThat(queryResultSetMetadata.getColumnClassName(1)).isNull(); + } + + @Test + public void unwrap() { + assertThat(queryResultSetMetadata.unwrap(ResultSetMetaData.class)).isNull(); + } + + @Test + public void isWrapperFor() { + assertThat(queryResultSetMetadata.isWrapperFor(ResultSetMetaData.class)).isFalse(); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/StreamingResultSetTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/StreamingResultSetTest.java new file mode 100644 index 0000000..16fdc4b --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/StreamingResultSetTest.java @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Named.named; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +import com.salesforce.datacloud.jdbc.hyper.HyperTestBase; +import com.salesforce.datacloud.jdbc.util.ThrowingBiFunction; +import io.grpc.StatusRuntimeException; +import java.sql.SQLException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; +import lombok.SneakyThrows; +import lombok.val; +import org.assertj.core.api.AssertionsForClassTypes; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +public class StreamingResultSetTest extends HyperTestBase { + private static final int small = 10; + private static final int large = 10 * 1024 * 1024; + + private static Stream queryModes(int size) { + return Stream.of( + inline("executeSyncQuery", DataCloudStatement::executeSyncQuery, size), + inline("executeAdaptiveQuery", DataCloudStatement::executeAdaptiveQuery, size), + deferred("executeAsyncQuery", DataCloudStatement::executeAsyncQuery, true, size), + deferred("execute", DataCloudStatement::execute, false, size), + deferred("executeQuery", DataCloudStatement::executeQuery, false, size)); + } + + public static Stream queryModesWithMax() { + return Stream.of(small, large).flatMap(StreamingResultSetTest::queryModes); + } + + @SneakyThrows + @Test + public void exercisePreparedStatement() { + val sql = + "select cast(a as numeric(38,18)) a, cast(a as numeric(38,18)) b, cast(a as numeric(38,18)) c from generate_series(1, ?) 
as s(a) order by a asc"; + val expected = new AtomicInteger(0); + + assertWithConnection(conn -> { + try (val statement = conn.prepareStatement(sql)) { + statement.setInt(1, large); + + val rs = statement.executeQuery(); + assertThat(rs).isInstanceOf(StreamingResultSet.class); + assertThat(((StreamingResultSet) rs).isReady()).isTrue(); + + while (rs.next()) { + assertEachRowIsTheSame(rs, expected); + } + } + }); + + assertThat(expected.get()).isEqualTo(large); + } + + @SneakyThrows + @ParameterizedTest + @MethodSource("queryModesWithMax") + public void exerciseQueryMode( + ThrowingBiFunction queryMode, int max) { + val sql = query(max); + val actual = new AtomicInteger(0); + + assertWithStatement(statement -> { + val rs = queryMode.apply(statement, sql); + assertThat(rs).isInstanceOf(StreamingResultSet.class); + assertThat(rs.isReady()).isTrue(); + + while (rs.next()) { + assertEachRowIsTheSame(rs, actual); + } + }); + + assertThat(actual.get()).isEqualTo(max); + } + + private static Stream queryModesWithNoSize() { + return queryModes(-1); + } + + @SneakyThrows + @ParameterizedTest + @MethodSource("queryModesWithNoSize") + public void allModesThrowOnNonsense(ThrowingBiFunction queryMode) { + val ex = Assertions.assertThrows(SQLException.class, () -> { + try (val conn = getHyperQueryConnection(); + val statement = (DataCloudStatement) conn.createStatement()) { + val result = queryMode.apply(statement, "select * from nonsense"); + result.next(); + } + }); + + AssertionsForClassTypes.assertThat(ex).hasRootCauseInstanceOf(StatusRuntimeException.class); + } + + public static String query(int max) { + return String.format( + "select cast(a as numeric(38,18)) a, cast(a as numeric(38,18)) b, cast(a as numeric(38,18)) c from generate_series(1, %d) as s(a) order by a asc", + max); + } + + private static Arguments inline( + String name, ThrowingBiFunction impl, int size) { + return arguments(named(String.format("%s -> DataCloudResultSet", name), impl), size); + } + + private 
static Arguments deferred( + String name, ThrowingBiFunction impl, Boolean wait, int size) { + ThrowingBiFunction deferred = + (DataCloudStatement s, String x) -> { + impl.apply(s, x); + + if (wait) { + waitUntilReady(s); + } + + return (DataCloudResultSet) s.getResultSet(); + }; + return arguments(named(String.format("%s; getResultSet -> DataCloudResultSet", name), deferred), size); + } + + @SneakyThrows + static boolean waitUntilReady(DataCloudStatement statement) { + while (!statement.isReady()) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + return false; + } + } + return true; + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/QueryJDBCAccessorAssert.java b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/QueryJDBCAccessorAssert.java new file mode 100644 index 0000000..2af342f --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/QueryJDBCAccessorAssert.java @@ -0,0 +1,867 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor; + +import org.assertj.core.api.AbstractObjectAssert; +import org.assertj.core.api.Assertions; +import org.assertj.core.util.Objects; + +/** {@link QueryJDBCAccessor} specific assertions - Generated by CustomAssertionGenerator. 
*/ +@javax.annotation.Generated(value = "assertj-assertions-generator") +public class QueryJDBCAccessorAssert extends AbstractObjectAssert { + + /** + * Creates a new {@link QueryJDBCAccessorAssert} to make assertions on actual QueryJDBCAccessor. + * + * @param actual the QueryJDBCAccessor we want to make assertions on. + */ + public QueryJDBCAccessorAssert(QueryJDBCAccessor actual) { + super(actual, QueryJDBCAccessorAssert.class); + } + + /** + * An entry point for QueryJDBCAccessorAssert to follow AssertJ standard assertThat() statements.
+ * With a static import, one can write directly: assertThat(myQueryJDBCAccessor) and get specific + * assertion with code completion. + * + * @param actual the QueryJDBCAccessor we want to make assertions on. + * @return a new {@link QueryJDBCAccessorAssert} + */ + @org.assertj.core.util.CheckReturnValue + public static QueryJDBCAccessorAssert assertThat(QueryJDBCAccessor actual) { + return new QueryJDBCAccessorAssert(actual); + } + + /** + * Verifies that the actual QueryJDBCAccessor's array is equal to the given one. + * + * @param array the given array to compare the actual QueryJDBCAccessor's array to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's array is not equal to the given one. + * @throws java.sql.SQLException if actual.getArray() throws one. + */ + public QueryJDBCAccessorAssert hasArray(java.sql.Array array) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting array of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + java.sql.Array actualArray = actual.getArray(); + if (!Objects.areEqual(actualArray, array)) { + failWithMessage(assertjErrorMessage, actual, array, actualArray); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's asciiStream is equal to the given one. + * + * @param asciiStream the given asciiStream to compare the actual QueryJDBCAccessor's asciiStream to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's asciiStream is not equal to the given one. + * @throws java.sql.SQLException if actual.getAsciiStream() throws one. 
+ */ + public QueryJDBCAccessorAssert hasAsciiStream(java.io.InputStream asciiStream) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting asciiStream of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + java.io.InputStream actualAsciiStream = actual.getAsciiStream(); + if (!Objects.areEqual(actualAsciiStream, asciiStream)) { + failWithMessage(assertjErrorMessage, actual, asciiStream, actualAsciiStream); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's bigDecimal is equal to the given one. + * + * @param bigDecimal the given bigDecimal to compare the actual QueryJDBCAccessor's bigDecimal to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's bigDecimal is not equal to the given one. + * @throws java.sql.SQLException if actual.getBigDecimal() throws one. + */ + public QueryJDBCAccessorAssert hasBigDecimal(java.math.BigDecimal bigDecimal) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting bigDecimal of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + java.math.BigDecimal actualBigDecimal = actual.getBigDecimal(); + if (!Objects.areEqual(actualBigDecimal, bigDecimal)) { + failWithMessage(assertjErrorMessage, actual, bigDecimal, actualBigDecimal); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's binaryStream is equal to the given one. + * + * @param binaryStream the given binaryStream to compare the actual QueryJDBCAccessor's binaryStream to. 
+ * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's binaryStream is not equal to the given one. + * @throws java.sql.SQLException if actual.getBinaryStream() throws one. + */ + public QueryJDBCAccessorAssert hasBinaryStream(java.io.InputStream binaryStream) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting binaryStream of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + java.io.InputStream actualBinaryStream = actual.getBinaryStream(); + if (!Objects.areEqual(actualBinaryStream, binaryStream)) { + failWithMessage(assertjErrorMessage, actual, binaryStream, actualBinaryStream); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's blob is equal to the given one. + * + * @param blob the given blob to compare the actual QueryJDBCAccessor's blob to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's blob is not equal to the given one. + * @throws java.sql.SQLException if actual.getBlob() throws one. + */ + public QueryJDBCAccessorAssert hasBlob(java.sql.Blob blob) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. 
+ isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting blob of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + java.sql.Blob actualBlob = actual.getBlob(); + if (!Objects.areEqual(actualBlob, blob)) { + failWithMessage(assertjErrorMessage, actual, blob, actualBlob); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's boolean is equal to the given one. + * + * @param expectedBoolean the given boolean to compare the actual QueryJDBCAccessor's boolean to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's boolean is not equal to the given one. + * @throws java.sql.SQLException if actual.getBoolean() throws one. + */ + public QueryJDBCAccessorAssert hasBoolean(boolean expectedBoolean) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting boolean of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // check + boolean actualBoolean = actual.getBoolean(); + if (actualBoolean != expectedBoolean) { + failWithMessage(assertjErrorMessage, actual, expectedBoolean, actualBoolean); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's byte is equal to the given one. + * + * @param expectedByte the given byte to compare the actual QueryJDBCAccessor's byte to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's byte is not equal to the given one. + * @throws java.sql.SQLException if actual.getByte() throws one. 
+ */ + public QueryJDBCAccessorAssert hasByte(byte expectedByte) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting byte of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // check + byte actualByte = actual.getByte(); + if (actualByte != expectedByte) { + failWithMessage(assertjErrorMessage, actual, expectedByte, actualByte); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's bytes contains the given byte elements. + * + * @param bytes the given elements that should be contained in actual QueryJDBCAccessor's bytes. + * @return this assertion object. + * @throws AssertionError if the actual QueryJDBCAccessor's bytes does not contain all given byte elements. + * @throws java.sql.SQLException if actual.getBytes() throws one. + */ + public QueryJDBCAccessorAssert hasBytes(byte... bytes) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // check that given byte varargs is not null. + if (bytes == null) failWithMessage("Expecting bytes parameter not to be null."); + + // check with standard error message (use overridingErrorMessage before contains to set your own + // message). + Assertions.assertThat(actual.getBytes()).contains(bytes); + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's bytes contains only the given byte elements and nothing else + * in whatever order. + * + * @param bytes the given elements that should be contained in actual QueryJDBCAccessor's bytes. + * @return this assertion object. + * @throws AssertionError if the actual QueryJDBCAccessor's bytes does not contain all given byte elements and + * nothing else. 
+ * @throws java.sql.SQLException if actual.getBytes() throws one. + */ + public QueryJDBCAccessorAssert hasOnlyBytes(byte... bytes) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // check that given byte varargs is not null. + if (bytes == null) failWithMessage("Expecting bytes parameter not to be null."); + + // check with standard error message (use overridingErrorMessage before contains to set your own + // message). + Assertions.assertThat(actual.getBytes()).containsOnly(bytes); + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's bytes does not contain the given byte elements. + * + * @param bytes the given elements that should not be in actual QueryJDBCAccessor's bytes. + * @return this assertion object. + * @throws AssertionError if the actual QueryJDBCAccessor's bytes contains any given byte elements. + * @throws java.sql.SQLException if actual.getBytes() throws one. + */ + public QueryJDBCAccessorAssert doesNotHaveBytes(byte... bytes) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // check that given byte varargs is not null. + if (bytes == null) failWithMessage("Expecting bytes parameter not to be null."); + + // check with standard error message (use overridingErrorMessage before contains to set your own + // message). + Assertions.assertThat(actual.getBytes()).doesNotContain(bytes); + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor has no bytes. + * + * @return this assertion object. + * @throws AssertionError if the actual QueryJDBCAccessor's bytes is not empty. + * @throws java.sql.SQLException if actual.getBytes() throws one. 
+ */ + public QueryJDBCAccessorAssert hasNoBytes() throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // we override the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting :\n <%s>\nnot to have bytes but had :\n <%s>"; + + // check that it is not empty + if (actual.getBytes().length > 0) { + failWithMessage(assertjErrorMessage, actual, java.util.Arrays.toString(actual.getBytes())); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's characterStream is equal to the given one. + * + * @param characterStream the given characterStream to compare the actual QueryJDBCAccessor's characterStream to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's characterStream is not equal to the given one. + * @throws java.sql.SQLException if actual.getCharacterStream() throws one. + */ + public QueryJDBCAccessorAssert hasCharacterStream(java.io.Reader characterStream) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting characterStream of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + java.io.Reader actualCharacterStream = actual.getCharacterStream(); + if (!Objects.areEqual(actualCharacterStream, characterStream)) { + failWithMessage(assertjErrorMessage, actual, characterStream, actualCharacterStream); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's clob is equal to the given one. + * + * @param clob the given clob to compare the actual QueryJDBCAccessor's clob to. + * @return this assertion object. 
+ * @throws AssertionError - if the actual QueryJDBCAccessor's clob is not equal to the given one. + * @throws java.sql.SQLException if actual.getClob() throws one. + */ + public QueryJDBCAccessorAssert hasClob(java.sql.Clob clob) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting clob of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + java.sql.Clob actualClob = actual.getClob(); + if (!Objects.areEqual(actualClob, clob)) { + failWithMessage(assertjErrorMessage, actual, clob, actualClob); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's double is equal to the given one. + * + * @param expectedDouble the given double to compare the actual QueryJDBCAccessor's double to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's double is not equal to the given one. + * @throws java.sql.SQLException if actual.getDouble() throws one. + */ + public QueryJDBCAccessorAssert hasDouble(double expectedDouble) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting double of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // check value for double + double actualDouble = actual.getDouble(); + if (actualDouble != expectedDouble) { + failWithMessage(assertjErrorMessage, actual, expectedDouble, actualDouble); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's double is close to the given value by less than the given offset. + * + *

If difference is equal to the offset value, assertion is considered successful. + * + * @param expectedDouble the value to compare the actual QueryJDBCAccessor's double to. + * @param assertjOffset the given offset. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's double is not close enough to the given value. + * @throws java.sql.SQLException if actual.getDouble() throws one. + */ + public QueryJDBCAccessorAssert hasDoubleCloseTo(double expectedDouble, double assertjOffset) + throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + double actualDouble = actual.getDouble(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = String.format( + "\nExpecting double:\n <%s>\nto be close to:\n <%s>\nby less than <%s> but difference was <%s>", + actualDouble, expectedDouble, assertjOffset, Math.abs(expectedDouble - actualDouble)); + + // check + Assertions.assertThat(actualDouble) + .overridingErrorMessage(assertjErrorMessage) + .isCloseTo(expectedDouble, Assertions.within(assertjOffset)); + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's float is equal to the given one. + * + * @param expectedFloat the given float to compare the actual QueryJDBCAccessor's float to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's float is not equal to the given one. + * @throws java.sql.SQLException if actual.getFloat() throws one. + */ + public QueryJDBCAccessorAssert hasFloat(float expectedFloat) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. 
+ isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting float of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // check value for float + float actualFloat = actual.getFloat(); + if (actualFloat != expectedFloat) { + failWithMessage(assertjErrorMessage, actual, expectedFloat, actualFloat); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's float is close to the given value by less than the given offset. + * + *

If difference is equal to the offset value, assertion is considered successful. + * + * @param expectedFloat the value to compare the actual QueryJDBCAccessor's float to. + * @param assertjOffset the given offset. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's float is not close enough to the given value. + * @throws java.sql.SQLException if actual.getFloat() throws one. + */ + public QueryJDBCAccessorAssert hasFloatCloseTo(float expectedFloat, float assertjOffset) + throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + float actualFloat = actual.getFloat(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = String.format( + "\nExpecting float:\n <%s>\nto be close to:\n <%s>\nby less than <%s> but difference was <%s>", + actualFloat, expectedFloat, assertjOffset, Math.abs(expectedFloat - actualFloat)); + + // check + Assertions.assertThat(actualFloat) + .overridingErrorMessage(assertjErrorMessage) + .isCloseTo(expectedFloat, Assertions.within(assertjOffset)); + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's int is equal to the given one. + * + * @param expectedInt the given int to compare the actual QueryJDBCAccessor's int to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's int is not equal to the given one. + * @throws java.sql.SQLException if actual.getInt() throws one. + */ + public QueryJDBCAccessorAssert hasInt(int expectedInt) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. 
+ isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting int of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // check + int actualInt = actual.getInt(); + if (actualInt != expectedInt) { + failWithMessage(assertjErrorMessage, actual, expectedInt, actualInt); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's long is equal to the given one. + * + * @param expectedLong the given long to compare the actual QueryJDBCAccessor's long to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's long is not equal to the given one. + * @throws java.sql.SQLException if actual.getLong() throws one. + */ + public QueryJDBCAccessorAssert hasLong(long expectedLong) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting long of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // check + long actualLong = actual.getLong(); + if (actualLong != expectedLong) { + failWithMessage(assertjErrorMessage, actual, expectedLong, actualLong); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's nCharacterStream is equal to the given one. + * + * @param nCharacterStream the given nCharacterStream to compare the actual QueryJDBCAccessor's nCharacterStream to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's nCharacterStream is not equal to the given one. + * @throws java.sql.SQLException if actual.getNCharacterStream() throws one. 
+ */ + public QueryJDBCAccessorAssert hasNCharacterStream(java.io.Reader nCharacterStream) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting nCharacterStream of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + java.io.Reader actualNCharacterStream = actual.getNCharacterStream(); + if (!Objects.areEqual(actualNCharacterStream, nCharacterStream)) { + failWithMessage(assertjErrorMessage, actual, nCharacterStream, actualNCharacterStream); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's nClob is equal to the given one. + * + * @param nClob the given nClob to compare the actual QueryJDBCAccessor's nClob to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's nClob is not equal to the given one. + * @throws java.sql.SQLException if actual.getNClob() throws one. + */ + public QueryJDBCAccessorAssert hasNClob(java.sql.NClob nClob) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting nClob of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + java.sql.NClob actualNClob = actual.getNClob(); + if (!Objects.areEqual(actualNClob, nClob)) { + failWithMessage(assertjErrorMessage, actual, nClob, actualNClob); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's nString is equal to the given one. + * + * @param nString the given nString to compare the actual QueryJDBCAccessor's nString to. + * @return this assertion object. 
+ * @throws AssertionError - if the actual QueryJDBCAccessor's nString is not equal to the given one. + * @throws java.sql.SQLException if actual.getNString() throws one. + */ + public QueryJDBCAccessorAssert hasNString(String nString) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting nString of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + String actualNString = actual.getNString(); + if (!Objects.areEqual(actualNString, nString)) { + failWithMessage(assertjErrorMessage, actual, nString, actualNString); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor was null. + * + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor was not null. + */ + public QueryJDBCAccessorAssert wasNull() { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // check that property call/field access is true + if (!actual.wasNull()) { + failWithMessage("\nExpecting that actual QueryJDBCAccessor was null but was not."); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor was not null. + * + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor was null. + */ + public QueryJDBCAccessorAssert wasNotNull() { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. 
+ isNotNull(); + + // check that property call/field access is false + if (actual.wasNull()) { + failWithMessage("\nExpecting that actual QueryJDBCAccessor was not null but was."); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's object is equal to the given one. + * + * @param object the given object to compare the actual QueryJDBCAccessor's object to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's object is not equal to the given one. + * @throws java.sql.SQLException if actual.getObject() throws one. + */ + public QueryJDBCAccessorAssert hasObject(Object object) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting object of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + Object actualObject = actual.getObject(); + if (!Objects.areEqual(actualObject, object)) { + failWithMessage(assertjErrorMessage, actual, object, actualObject); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's objectClass is equal to the given one. + * + * @param objectClass the given objectClass to compare the actual QueryJDBCAccessor's objectClass to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's objectClass is not equal to the given one. + */ + public QueryJDBCAccessorAssert hasObjectClass(Class objectClass) { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. 
+ isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting objectClass of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + Class actualObjectClass = actual.getObjectClass(); + if (!Objects.areEqual(actualObjectClass, objectClass)) { + failWithMessage(assertjErrorMessage, actual, objectClass, actualObjectClass); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's ref is equal to the given one. + * + * @param ref the given ref to compare the actual QueryJDBCAccessor's ref to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's ref is not equal to the given one. + * @throws java.sql.SQLException if actual.getRef() throws one. + */ + public QueryJDBCAccessorAssert hasRef(java.sql.Ref ref) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting ref of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + java.sql.Ref actualRef = actual.getRef(); + if (!Objects.areEqual(actualRef, ref)) { + failWithMessage(assertjErrorMessage, actual, ref, actualRef); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's sQLXML is equal to the given one. + * + * @param sQLXML the given sQLXML to compare the actual QueryJDBCAccessor's sQLXML to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's sQLXML is not equal to the given one. + * @throws java.sql.SQLException if actual.getSQLXML() throws one. 
+ */ + public QueryJDBCAccessorAssert hasSQLXML(java.sql.SQLXML sQLXML) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting sQLXML of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + java.sql.SQLXML actualSQLXML = actual.getSQLXML(); + if (!Objects.areEqual(actualSQLXML, sQLXML)) { + failWithMessage(assertjErrorMessage, actual, sQLXML, actualSQLXML); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's short is equal to the given one. + * + * @param expectedShort the given short to compare the actual QueryJDBCAccessor's short to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's short is not equal to the given one. + * @throws java.sql.SQLException if actual.getShort() throws one. + */ + public QueryJDBCAccessorAssert hasShort(short expectedShort) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting short of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // check + short actualShort = actual.getShort(); + if (actualShort != expectedShort) { + failWithMessage(assertjErrorMessage, actual, expectedShort, actualShort); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's string is equal to the given one. + * + * @param string the given string to compare the actual QueryJDBCAccessor's string to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's string is not equal to the given one. 
+ * @throws java.sql.SQLException if actual.getString() throws one. + */ + public QueryJDBCAccessorAssert hasString(String string) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting string of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + String actualString = actual.getString(); + if (!Objects.areEqual(actualString, string)) { + failWithMessage(assertjErrorMessage, actual, string, actualString); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's struct is equal to the given one. + * + * @param struct the given struct to compare the actual QueryJDBCAccessor's struct to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's struct is not equal to the given one. + * @throws java.sql.SQLException if actual.getStruct() throws one. + */ + public QueryJDBCAccessorAssert hasStruct(java.sql.Struct struct) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting struct of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + java.sql.Struct actualStruct = actual.getStruct(); + if (!Objects.areEqual(actualStruct, struct)) { + failWithMessage(assertjErrorMessage, actual, struct, actualStruct); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's uRL is equal to the given one. + * + * @param uRL the given uRL to compare the actual QueryJDBCAccessor's uRL to. + * @return this assertion object. 
+ * @throws AssertionError - if the actual QueryJDBCAccessor's uRL is not equal to the given one. + * @throws java.sql.SQLException if actual.getURL() throws one. + */ + public QueryJDBCAccessorAssert hasURL(java.net.URL uRL) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting uRL of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + java.net.URL actualURL = actual.getURL(); + if (!Objects.areEqual(actualURL, uRL)) { + failWithMessage(assertjErrorMessage, actual, uRL, actualURL); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryJDBCAccessor's unicodeStream is equal to the given one. + * + * @param unicodeStream the given unicodeStream to compare the actual QueryJDBCAccessor's unicodeStream to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryJDBCAccessor's unicodeStream is not equal to the given one. + * @throws java.sql.SQLException if actual.getUnicodeStream() throws one. + */ + public QueryJDBCAccessorAssert hasUnicodeStream(java.io.InputStream unicodeStream) throws java.sql.SQLException { + // check that actual QueryJDBCAccessor we want to make assertions on is not null. 
+ isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting unicodeStream of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + java.io.InputStream actualUnicodeStream = actual.getUnicodeStream(); + if (!Objects.areEqual(actualUnicodeStream, unicodeStream)) { + failWithMessage(assertjErrorMessage, actual, unicodeStream, actualUnicodeStream); + } + + // return the current assertion for method chaining + return this; + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/QueryJDBCAccessorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/QueryJDBCAccessorTest.java new file mode 100644 index 0000000..fbd0b99 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/QueryJDBCAccessorTest.java @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.accessor; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import java.sql.SQLFeatureNotSupportedException; +import java.util.Calendar; +import java.util.HashMap; +import java.util.Map; +import lombok.val; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +public class QueryJDBCAccessorTest { + + Calendar calendar = Calendar.getInstance(); + + @Test + public void shouldThrowUnsupportedError() { + QueryJDBCAccessor absCls = Mockito.mock(QueryJDBCAccessor.class, Mockito.CALLS_REAL_METHODS); + + val e1 = assertThrows(DataCloudJDBCException.class, absCls::getBytes); + assertThat(e1).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e2 = assertThrows(DataCloudJDBCException.class, absCls::getShort); + assertThat(e2).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e3 = assertThrows(DataCloudJDBCException.class, absCls::getInt); + assertThat(e3).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e4 = assertThrows(DataCloudJDBCException.class, absCls::getLong); + assertThat(e4).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e5 = assertThrows(DataCloudJDBCException.class, absCls::getFloat); + assertThat(e5).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e6 = assertThrows(DataCloudJDBCException.class, absCls::getDouble); + assertThat(e6).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e7 = assertThrows(DataCloudJDBCException.class, absCls::getBoolean); + assertThat(e7).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e8 = assertThrows(DataCloudJDBCException.class, absCls::getString); + assertThat(e8).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + HashMap stringDateHashMap = new 
HashMap<>(); + val e9 = assertThrows( + DataCloudJDBCException.class, () -> absCls.getObject((Map>) stringDateHashMap)); + assertThat(e9).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e10 = assertThrows(DataCloudJDBCException.class, absCls::getByte); + assertThat(e10).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e11 = assertThrows(DataCloudJDBCException.class, absCls::getBigDecimal); + assertThat(e11).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e12 = assertThrows(DataCloudJDBCException.class, () -> absCls.getBigDecimal(1)); + assertThat(e12).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e13 = assertThrows(DataCloudJDBCException.class, absCls::getAsciiStream); + assertThat(e13).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e14 = assertThrows(DataCloudJDBCException.class, absCls::getUnicodeStream); + assertThat(e14).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e15 = assertThrows(DataCloudJDBCException.class, absCls::getBinaryStream); + assertThat(e15).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e16 = assertThrows(DataCloudJDBCException.class, absCls::getObject); + assertThat(e16).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e17 = assertThrows(DataCloudJDBCException.class, absCls::getCharacterStream); + assertThat(e17).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e18 = assertThrows(DataCloudJDBCException.class, absCls::getRef); + assertThat(e18).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e19 = assertThrows(DataCloudJDBCException.class, absCls::getBlob); + assertThat(e19).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e20 = assertThrows(DataCloudJDBCException.class, absCls::getClob); + assertThat(e20).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e21 = 
assertThrows(DataCloudJDBCException.class, absCls::getArray); + assertThat(e21).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e22 = assertThrows(DataCloudJDBCException.class, absCls::getStruct); + assertThat(e22).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e23 = assertThrows(DataCloudJDBCException.class, absCls::getURL); + assertThat(e23).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e24 = assertThrows(DataCloudJDBCException.class, absCls::getNClob); + assertThat(e24).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e25 = assertThrows(DataCloudJDBCException.class, absCls::getSQLXML); + assertThat(e25).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e26 = assertThrows(DataCloudJDBCException.class, absCls::getNString); + assertThat(e26).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e27 = assertThrows(DataCloudJDBCException.class, absCls::getNCharacterStream); + assertThat(e27).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e28 = assertThrows(DataCloudJDBCException.class, () -> absCls.getDate(calendar)); + assertThat(e28).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e29 = assertThrows(DataCloudJDBCException.class, () -> absCls.getTime(calendar)); + assertThat(e29).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + + val e30 = assertThrows(DataCloudJDBCException.class, () -> absCls.getTimestamp(calendar)); + assertThat(e30).hasRootCauseInstanceOf(SQLFeatureNotSupportedException.class); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/SoftAssertions.java b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/SoftAssertions.java new file mode 100644 index 0000000..42d2259 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/SoftAssertions.java @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor; + +/** Entry point for soft assertions of different data types. */ +@javax.annotation.Generated(value = "assertj-assertions-generator") +public class SoftAssertions extends org.assertj.core.api.SoftAssertions { + + /** + * Creates a new "soft" instance of + * {@link com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorAssert}. + * + * @param actual the actual value. + * @return the created "soft" assertion object. + */ + @org.assertj.core.util.CheckReturnValue + public com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorAssert assertThat( + com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessor actual) { + return proxy( + com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorAssert.class, + com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessor.class, + actual); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BaseIntVectorAccessorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BaseIntVectorAccessorTest.java new file mode 100644 index 0000000..45e9292 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BaseIntVectorAccessorTest.java @@ -0,0 +1,230 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import static com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension.nulledOutVector; + +import com.salesforce.datacloud.jdbc.core.accessor.SoftAssertions; +import com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension; +import com.salesforce.datacloud.jdbc.util.TestWasNullConsumer; +import java.math.BigDecimal; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import lombok.SneakyThrows; +import lombok.val; +import org.assertj.core.api.junit.jupiter.InjectSoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; + +@ExtendWith(SoftAssertionsExtension.class) +public class BaseIntVectorAccessorTest { + + @RegisterExtension + public static RootAllocatorTestExtension extension = new RootAllocatorTestExtension(); + + @InjectSoftAssertions + private SoftAssertions collector; + + @SneakyThrows + @Test + public void testShouldConvertToTinyIntMethodFromBaseIntVector() { + val values = getTinyIntValues(); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createTinyIntVector(values)) { + val i = new AtomicInteger(0); + val sut = new BaseIntVectorAccessor(vector, i::get, consumer); + + for (; i.get() < 
vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()); + collector + .assertThat(sut) + .hasByte(expected) + .hasShort(expected) + .hasInt(expected) + .hasFloat(expected) + .hasDouble(expected) + .hasBigDecimal(new BigDecimal(expected)) + .hasObject(expected) + .hasObjectClass(Long.class); + } + } + consumer.assertThat().hasNullSeen(0).hasNotNullSeen(values.size() * 7); + } + + @SneakyThrows + @Test + public void testShouldConvertToSmallIntMethodFromBaseIntVector() { + val values = getSmallIntValues(); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createSmallIntVector(values)) { + val i = new AtomicInteger(0); + val sut = new BaseIntVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()); + collector + .assertThat(sut) + .hasShort(expected) + .hasInt((int) expected) + .hasFloat(expected) + .hasDouble(expected) + .hasBigDecimal(new BigDecimal(expected)) + .hasObject(expected) + .hasObjectClass(Long.class); + } + } + consumer.assertThat().hasNullSeen(0).hasNotNullSeen(values.size() * 6); + } + + @SneakyThrows + @Test + public void testShouldConvertToIntegerMethodFromBaseIntVector() { + val values = getIntValues(); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createIntVector(values)) { + val i = new AtomicInteger(0); + val sut = new BaseIntVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()); + collector + .assertThat(sut) + .hasInt(expected) + .hasFloat(expected) + .hasDouble(expected) + .hasBigDecimal(new BigDecimal(expected)) + .hasObject(expected) + .hasObjectClass(Long.class); + } + } + consumer.assertThat().hasNullSeen(0).hasNotNullSeen(values.size() * 5); + } + + @SneakyThrows + @Test + public void testShouldConvertToBigIntMethodFromBaseIntVector() { + val values = 
getBigIntValues(); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createBigIntVector(values)) { + val i = new AtomicInteger(0); + val sut = new BaseIntVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()); + collector + .assertThat(sut) + .hasFloat(expected) + .hasDouble(expected) + .hasBigDecimal(new BigDecimal(expected)) + .hasObject(expected) + .hasObjectClass(Long.class); + } + } + consumer.assertThat().hasNullSeen(0).hasNotNullSeen(values.size() * 4); + } + + @SneakyThrows + @Test + public void testShouldConvertToUInt4MethodFromBaseIntVector() { + val values = getIntValues(); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createUInt4Vector(values)) { + val i = new AtomicInteger(0); + val sut = new BaseIntVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()); + collector + .assertThat(sut) + .hasInt(expected) + .hasFloat(expected) + .hasDouble(expected) + .hasBigDecimal(new BigDecimal(expected)) + .hasObject(expected) + .hasObjectClass(Long.class); + } + } + consumer.assertThat().hasNullSeen(0).hasNotNullSeen(values.size() * 5); + } + + @SneakyThrows + @Test + public void testGetBigDecimalGetObjectAndGetObjectClassFromNulledDecimalVector() { + val values = getBigIntValues(); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = nulledOutVector(extension.createBigIntVector(values))) { + val i = new AtomicInteger(0); + val sut = new BaseIntVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + collector.assertThat(sut).hasBigDecimal(null).hasObject(null).hasObjectClass(Long.class); + } + } + + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(values.size() * 2); + } + + private List getTinyIntValues() { + return List.of((byte) 
0, (byte) 1, (byte) -1, Byte.MIN_VALUE, Byte.MAX_VALUE); + } + + private List getSmallIntValues() { + return List.of( + (short) 0, + (short) 1, + (short) -1, + (short) Byte.MIN_VALUE, + (short) Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE); + } + + private List getIntValues() { + return List.of( + 0, + 1, + -1, + (int) Byte.MIN_VALUE, + (int) Byte.MAX_VALUE, + (int) Short.MIN_VALUE, + (int) Short.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE); + } + + private List getBigIntValues() { + return List.of( + (long) 0, + (long) 1, + (long) -1, + (long) Byte.MIN_VALUE, + (long) Byte.MAX_VALUE, + (long) Short.MIN_VALUE, + (long) Short.MAX_VALUE, + (long) Integer.MIN_VALUE, + (long) Integer.MAX_VALUE, + Long.MIN_VALUE, + Long.MAX_VALUE); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BinaryVectorAccessorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BinaryVectorAccessorTest.java new file mode 100644 index 0000000..13c7ffe --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BinaryVectorAccessorTest.java @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import static com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension.nulledOutVector; + +import com.salesforce.datacloud.jdbc.core.accessor.SoftAssertions; +import com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension; +import com.salesforce.datacloud.jdbc.util.TestWasNullConsumer; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import lombok.SneakyThrows; +import lombok.val; +import org.assertj.core.api.junit.jupiter.InjectSoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; + +@ExtendWith(SoftAssertionsExtension.class) +public class BinaryVectorAccessorTest { + @RegisterExtension + public static RootAllocatorTestExtension rootAllocatorTestExtension = new RootAllocatorTestExtension(); + + @InjectSoftAssertions + private SoftAssertions collector; + + private final List binaryList = List.of( + "BINARY_DATA_0001".getBytes(StandardCharsets.UTF_8), + "BINARY_DATA_0002".getBytes(StandardCharsets.UTF_8), + "BINARY_DATA_0003".getBytes(StandardCharsets.UTF_8)); + + @SneakyThrows + @Test + void testGetBytesGetStringGetObjectAndGetObjectClassFromValidVarBinaryVector() { + val values = binaryList; + val expectedNullChecks = values.size() * 3; // seen thrice since getObject and getString both call getBytes + val consumer = new TestWasNullConsumer(collector); + + try (val vector = rootAllocatorTestExtension.createVarBinaryVector(values)) { + val i = new AtomicInteger(0); + val sut = new BinaryVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()); + collector + .assertThat(sut) + .hasObjectClass(byte[].class) + .hasBytes(expected) + .hasObject(expected) + 
.hasString(new String(expected, StandardCharsets.UTF_8)); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNullChecks).hasNullSeen(0); + } + + @SneakyThrows + @Test + void testGetBytesGetStringGetObjectAndGetObjectClassFromNulledVarBinaryVector() { + val expectedNullChecks = binaryList.size() * 3; // seen thrice since getObject and getString both call + val consumer = new TestWasNullConsumer(collector); + + try (val vector = nulledOutVector(rootAllocatorTestExtension.createVarBinaryVector(binaryList))) { + val i = new AtomicInteger(0); + val sut = new BinaryVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + collector + .assertThat(sut) + .hasObjectClass(byte[].class) + .hasObject(null) + .hasString(null); + collector.assertThat(sut.getBytes()).isNull(); + } + } + + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(expectedNullChecks); + } + + @SneakyThrows + @Test + void testGetBytesGetStringGetObjectAndGetObjectClassFromValidLargeVarBinaryVector() { + val values = binaryList; + val expectedNullChecks = values.size() * 3; // seen thrice since getObject and getString both call getBytes + val consumer = new TestWasNullConsumer(collector); + + try (val vector = rootAllocatorTestExtension.createLargeVarBinaryVector(values)) { + val i = new AtomicInteger(0); + val sut = new BinaryVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()); + collector + .assertThat(sut) + .hasObjectClass(byte[].class) + .hasBytes(expected) + .hasObject(expected) + .hasString(new String(expected, StandardCharsets.UTF_8)); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNullChecks).hasNullSeen(0); + } + + @SneakyThrows + @Test + void testGetBytesGetStringGetObjectAndGetObjectClassFromNulledLargeVarCharVector() { + val expectedNullChecks = binaryList.size() * 3; // seen thrice since getObject and getString both call + val 
consumer = new TestWasNullConsumer(collector); + + try (val vector = nulledOutVector(rootAllocatorTestExtension.createLargeVarBinaryVector(binaryList))) { + val i = new AtomicInteger(0); + val sut = new BinaryVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + collector + .assertThat(sut) + .hasObjectClass(byte[].class) + .hasObject(null) + .hasString(null); + collector.assertThat(sut.getBytes()).isNull(); + } + } + + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(expectedNullChecks); + } + + @SneakyThrows + @Test + void testGetBytesGetStringGetObjectAndGetObjectClassFromValidFixedSizeVarBinaryVector() { + val values = binaryList; + val expectedNullChecks = values.size() * 3; // seen thrice since getObject and getString both call getBytes + val consumer = new TestWasNullConsumer(collector); + + try (val vector = rootAllocatorTestExtension.createFixedSizeBinaryVector(values)) { + val i = new AtomicInteger(0); + val sut = new BinaryVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()); + collector + .assertThat(sut) + .hasObjectClass(byte[].class) + .hasBytes(expected) + .hasObject(expected) + .hasString(new String(expected, StandardCharsets.UTF_8)); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNullChecks).hasNullSeen(0); + } + + @SneakyThrows + @Test + void testGetBytesGetStringGetObjectAndGetObjectClassFromNulledFixedSizeVarCharVector() { + val expectedNullChecks = binaryList.size() * 3; // seen thrice since getObject and getString both call + val consumer = new TestWasNullConsumer(collector); + + try (val vector = nulledOutVector(rootAllocatorTestExtension.createFixedSizeBinaryVector(binaryList))) { + val i = new AtomicInteger(0); + val sut = new BinaryVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + collector + .assertThat(sut) + 
.hasObjectClass(byte[].class) + .hasObject(null) + .hasString(null); + collector.assertThat(sut.getBytes()).isNull(); + } + } + + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(expectedNullChecks); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BooleanVectorAccessorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BooleanVectorAccessorTest.java new file mode 100644 index 0000000..2f5f3a2 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/BooleanVectorAccessorTest.java @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import static com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorAssert.assertThat; + +import com.salesforce.datacloud.jdbc.core.accessor.SoftAssertions; +import com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension; +import com.salesforce.datacloud.jdbc.util.TestWasNullConsumer; +import java.math.BigDecimal; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import lombok.val; +import org.assertj.core.api.ThrowingConsumer; +import org.assertj.core.api.junit.jupiter.InjectSoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; + +@ExtendWith(SoftAssertionsExtension.class) +public class BooleanVectorAccessorTest { + @InjectSoftAssertions + private SoftAssertions collector; + + @RegisterExtension + public static RootAllocatorTestExtension extension = new RootAllocatorTestExtension(); + + private static final Random random = new Random(10); + private static final List values = Stream.concat( + IntStream.range(0, 15).mapToObj(i -> random.nextBoolean()), Stream.of(null, null, null, null, null)) + .collect(Collectors.toList()); + private static final int expectedNulls = + (int) values.stream().filter(Objects::isNull).count(); + private static final int expectedNonNulls = values.size() - expectedNulls; + + @BeforeAll + static void setup() { + Collections.shuffle(values); + } + + private void iterate(BuildThrowingConsumer builder) { + val consumer = new TestWasNullConsumer(collector); + try (val vector = extension.createBitVector(values)) { + val i = new 
AtomicInteger(0); + val sut = new BooleanVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()); + val s = builder.buildSatisfies(expected); + collector + .assertThat(sut) + .hasObjectClass(Boolean.class) + .satisfies(b -> s.accept((BooleanVectorAccessor) b)); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls).hasNullSeen(expectedNulls); + } + + @FunctionalInterface + private interface BuildThrowingConsumer { + ThrowingConsumer buildSatisfies(Boolean expected); + } + + @Test + void testShouldGetBooleanMethodFromBitVector() { + iterate(expected -> sut -> assertThat(sut).hasBoolean(expected != null && expected)); + } + + @Test + void testShouldGetByteMethodFromBitVector() { + iterate(expected -> + sut -> collector.assertThat(sut).hasByte(expected == null || !expected ? (byte) 0 : (byte) 1)); + } + + @Test + void testShouldGetShortMethodFromBitVector() { + iterate(expected -> + sut -> collector.assertThat(sut).hasShort(expected == null || !expected ? (short) 0 : (short) 1)); + } + + @Test + void testShouldGetIntMethodFromBitVector() { + iterate(expected -> sut -> collector.assertThat(sut).hasInt(expected == null || !expected ? 0 : 1)); + } + + @Test + void testShouldGetFloatMethodFromBitVector() { + iterate(expected -> + sut -> collector.assertThat(sut).hasFloat(expected == null || !expected ? (float) 0 : (float) 1)); + } + + @Test + void testShouldGetDoubleMethodFromBitVector() { + iterate(expected -> + sut -> collector.assertThat(sut).hasDouble(expected == null || !expected ? (double) 0 : (double) 1)); + } + + @Test + void testShouldGetLongMethodFromBitVector() { + iterate(expected -> + sut -> collector.assertThat(sut).hasLong(expected == null || !expected ? (long) 0 : (long) 1)); + } + + @Test + void testShouldGetBigDecimalMethodFromBitVector() { + iterate(expected -> sut -> collector + .assertThat(sut) + .hasBigDecimal(expected == null ? null : (expected ? 
BigDecimal.ONE : BigDecimal.ZERO))); + } + + @Test + void testShouldGetStringMethodFromBitVector() { + iterate(expected -> + sut -> collector.assertThat(sut).hasString(expected == null ? null : (expected ? "true" : "false"))); + } + + @Test + void testShouldGetBooleanMethodFromBitVectorFromNull() { + iterate(expected -> sut -> collector.assertThat(sut).hasBoolean(expected != null && expected)); + } + + @Test + void testShouldGetObjectMethodFromBitVector() { + iterate(expected -> sut -> collector.assertThat(sut).hasObject(expected)); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DataCloudArrayTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DataCloudArrayTest.java new file mode 100644 index 0000000..86e801e --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DataCloudArrayTest.java @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Named.named; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +import com.salesforce.datacloud.jdbc.core.accessor.SoftAssertions; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension; +import java.sql.Types; +import java.util.HashMap; +import java.util.List; +import java.util.stream.Stream; +import lombok.SneakyThrows; +import lombok.val; +import org.apache.arrow.vector.FieldVector; +import org.assertj.core.api.AssertionsForClassTypes; +import org.assertj.core.api.ThrowingConsumer; +import org.assertj.core.api.junit.jupiter.InjectSoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +@ExtendWith({SoftAssertionsExtension.class}) +public class DataCloudArrayTest { + @InjectSoftAssertions + private SoftAssertions collector; + + FieldVector dataVector; + + @RegisterExtension + public static RootAllocatorTestExtension extension = new RootAllocatorTestExtension(); + + @AfterEach + public void tearDown() { + this.dataVector.close(); + } + + @SneakyThrows + @Test + void testGetBaseTypeReturnsCorrectBaseType() { + val values = List.of(true, false); + dataVector = extension.createBitVector(values); + val array = new DataCloudArray(dataVector, 0, dataVector.getValueCount()); + collector.assertThat(array.getBaseType()).isEqualTo(Types.BOOLEAN); + } + + @SneakyThrows + @Test + void 
testGetBaseTypeNameReturnsCorrectBaseTypeName() { + val values = List.of(1, 2, 3); + dataVector = extension.createIntVector(values); + val array = new DataCloudArray(dataVector, 0, dataVector.getValueCount()); + collector.assertThat(array.getBaseTypeName()).isEqualTo("INTEGER"); + } + + @SneakyThrows + @Test + void testGetArrayReturnsCorrectArray() { + val values = List.of(1, 2, 3); + dataVector = extension.createIntVector(values); + val dataCloudArray = new DataCloudArray(dataVector, 0, dataVector.getValueCount()); + val array = (Object[]) dataCloudArray.getArray(); + val expected = values.toArray(); + collector.assertThat(array).isEqualTo(expected); + collector.assertThat(array.length).isEqualTo(expected.length); + } + + @SneakyThrows + @Test + void testGetArrayWithCorrectOffsetReturnsCorrectArray() { + val values = List.of(1, 2, 3); + dataVector = extension.createIntVector(values); + val dataCloudArray = new DataCloudArray(dataVector, 0, dataVector.getValueCount()); + val array = (Object[]) dataCloudArray.getArray(0, 2); + val expected = List.of(1, 2).toArray(); + collector.assertThat(array).isEqualTo(expected); + collector.assertThat(array.length).isEqualTo(expected.length); + } + + @SneakyThrows + @Test + void testShouldThrowIfGetArrayHasIncorrectOffset() { + val values = List.of(1, 2, 3); + dataVector = extension.createIntVector(values); + val dataCloudArray = new DataCloudArray(dataVector, 0, dataVector.getValueCount()); + assertThrows( + ArrayIndexOutOfBoundsException.class, () -> dataCloudArray.getArray(0, dataVector.getValueCount() + 1)); + } + + @SneakyThrows + @Test + void testShouldThrowIfGetArrayHasIncorrectIndex() { + val values = List.of(1, 2, 3); + dataVector = extension.createIntVector(values); + val dataCloudArray = new DataCloudArray(dataVector, 0, dataVector.getValueCount()); + assertThrows( + ArrayIndexOutOfBoundsException.class, () -> dataCloudArray.getArray(-1, dataVector.getValueCount())); + } + + private static Arguments impl(String name, 
ThrowingConsumer impl) { + return arguments(named(name, impl)); + } + + private static Stream unsupported() { + return Stream.of( + impl("getArray with map", a -> a.getArray(new HashMap<>())), + impl("getArray with map & index", a -> a.getArray(0, 1, new HashMap<>())), + impl("getResultSet", DataCloudArray::getResultSet), + impl("getResultSet with map", a -> a.getResultSet(new HashMap<>())), + impl("getResultSet with map & index", a -> a.getResultSet(0, 1, new HashMap<>())), + impl("getResultSet with count & index", a -> a.getResultSet(0, 1))); + } + + @ParameterizedTest + @MethodSource("unsupported") + @SneakyThrows + void testUnsupportedOperations(ThrowingConsumer func) { + val values = List.of(1, 2, 3); + dataVector = extension.createIntVector(values); + val dataCloudArray = new DataCloudArray(dataVector, 0, dataVector.getValueCount()); + + val e = Assertions.assertThrows(RuntimeException.class, () -> func.accept(dataCloudArray)); + AssertionsForClassTypes.assertThat(e).hasRootCauseInstanceOf(DataCloudJDBCException.class); + AssertionsForClassTypes.assertThat(e).hasMessageContaining("Array method is not supported in Data Cloud query"); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DateVectorAccessorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DateVectorAccessorTest.java new file mode 100644 index 0000000..668174e --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DateVectorAccessorTest.java @@ -0,0 +1,467 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import static com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension.appendDates; +import static com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension.nulledOutVector; +import static org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY; + +import com.salesforce.datacloud.jdbc.core.accessor.SoftAssertions; +import com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension; +import com.salesforce.datacloud.jdbc.util.TestWasNullConsumer; +import java.sql.Date; +import java.time.Instant; +import java.time.LocalDate; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.util.Calendar; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.TimeZone; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import lombok.SneakyThrows; +import lombok.val; +import org.assertj.core.api.junit.jupiter.InjectSoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; + +@ExtendWith(SoftAssertionsExtension.class) +public class DateVectorAccessorTest { + + @InjectSoftAssertions + private SoftAssertions collector; + + public static final String ASIA_BANGKOK = "Asia/Bangkok"; + public static final ZoneId UTC_ZONE_ID = 
TimeZone.getTimeZone("UTC").toZoneId(); + + private static final int expectedZero = 0; + + @RegisterExtension + public static RootAllocatorTestExtension extension = new RootAllocatorTestExtension(); + + private static final List values = Stream.of( + 1625702400000L, // 2021-07-08 00:00:00 UTC, 2021-07-07 17:00:00 PT + 1625788800000L, // 2021-07-09 00:00:00 UTC, 2021-07-08 17:00:00 PT + 0L, // 1970-01-01 00:00:00 UTC, 1969-12-31 16:00:00 PT + -1625702400000L, // 1918-06-27 00:00:00 UTC, 1918-06-26 17:00:00 PT + -601689600000L) // 1950-12-08 00:00:00 UTC, 1950-12-07 16:00:00 PT + .collect(Collectors.toList()); + + private static final int expectedNulls = + (int) values.stream().filter(Objects::isNull).count(); + + private static final int expectedNonNulls = values.size() - expectedNulls; + + @SneakyThrows + @Test + void testDateDayVectorGetObjectClass() { + val consumer = new TestWasNullConsumer(collector); + + try (val vector = appendDates(values, extension.createDateDayVector())) { + val i = new AtomicInteger(0); + val sut = new DateVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + collector.assertThat(sut).hasObjectClass(Date.class); + } + } + } + + @SneakyThrows + @Test + void testDateMilliVectorGetObjectClass() { + val consumer = new TestWasNullConsumer(collector); + + try (val vector = appendDates(values, extension.createDateMilliVector())) { + val i = new AtomicInteger(0); + val sut = new DateVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + collector.assertThat(sut).hasObjectClass(Date.class); + } + } + } + + @SneakyThrows + @Test + void testDateDayVectorGetDateReturnsUTCDate() { + val consumer = new TestWasNullConsumer(collector); + + try (val vector = appendDates(values, extension.createDateDayVector())) { + val i = new AtomicInteger(0); + val sut = new DateVectorAccessor(vector, i::get, consumer); + + for (; i.get() < 
vector.getValueCount(); i.incrementAndGet()) { + val expectedDate = getExpectedDateForTimeZone(values.get(i.get()), UTC_ZONE_ID); + val expectedYear = expectedDate.get("year"); + val expectedMonth = expectedDate.get("month"); + val expectedDay = expectedDate.get("day"); + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(ASIA_BANGKOK)); + Calendar defaultCalendar = Calendar.getInstance(TimeZone.getDefault()); + + collector + .assertThat(sut.getDate(null)) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedZero) + .hasMinute(expectedZero) + .hasSecond(expectedZero) + .hasMillisecond(expectedZero); + collector + .assertThat(sut.getDate(calendar)) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedZero) + .hasMinute(expectedZero) + .hasSecond(expectedZero) + .hasMillisecond(expectedZero); + collector + .assertThat(sut.getDate(defaultCalendar)) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedZero) + .hasMinute(expectedZero) + .hasSecond(expectedZero) + .hasMillisecond(expectedZero); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls * 3).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testDateMilliVectorGetDateReturnsUTCDate() { + val consumer = new TestWasNullConsumer(collector); + + try (val vector = appendDates(values, extension.createDateMilliVector())) { + val i = new AtomicInteger(0); + val sut = new DateVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expectedDate = getExpectedDateForTimeZone(values.get(i.get()), UTC_ZONE_ID); + val expectedYear = expectedDate.get("year"); + val expectedMonth = expectedDate.get("month"); + val expectedDay = expectedDate.get("day"); + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(ASIA_BANGKOK)); + Calendar defaultCalendar = 
Calendar.getInstance(TimeZone.getDefault()); + + collector + .assertThat(sut.getDate(null)) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedZero) + .hasMinute(expectedZero) + .hasSecond(expectedZero) + .hasMillisecond(expectedZero); + collector + .assertThat(sut.getDate(calendar)) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedZero) + .hasMinute(expectedZero) + .hasSecond(expectedZero) + .hasMillisecond(expectedZero); + collector + .assertThat(sut.getDate(defaultCalendar)) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedZero) + .hasMinute(expectedZero) + .hasSecond(expectedZero) + .hasMillisecond(expectedZero); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls * 3).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testDateDayVectorGetObjectReturnsUTCDate() { + val consumer = new TestWasNullConsumer(collector); + + try (val vector = appendDates(values, extension.createDateDayVector())) { + val i = new AtomicInteger(0); + val sut = new DateVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expectedDate = getExpectedDateForTimeZone(values.get(i.get()), UTC_ZONE_ID); + val expectedYear = expectedDate.get("year"); + val expectedMonth = expectedDate.get("month"); + val expectedDay = expectedDate.get("day"); + + collector.assertThat(sut.getObject()).isInstanceOf(Date.class); + collector + .assertThat((Date) sut.getObject()) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedZero) + .hasMinute(expectedZero) + .hasSecond(expectedZero) + .hasMillisecond(expectedZero); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls * 2).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testDateMilliVectorGetObjectReturnsUTCDate() { + val 
consumer = new TestWasNullConsumer(collector); + + try (val vector = appendDates(values, extension.createDateMilliVector())) { + val i = new AtomicInteger(0); + val sut = new DateVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expectedDate = getExpectedDateForTimeZone(values.get(i.get()), UTC_ZONE_ID); + val expectedYear = expectedDate.get("year"); + val expectedMonth = expectedDate.get("month"); + val expectedDay = expectedDate.get("day"); + + collector.assertThat(sut.getObject()).isInstanceOf(Date.class); + collector + .assertThat((Date) sut.getObject()) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedZero) + .hasMinute(expectedZero) + .hasSecond(expectedZero) + .hasMillisecond(expectedZero); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls * 2).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testDateDayVectorGetTimestampReturnsUTCTimestampAtMidnight() { + val consumer = new TestWasNullConsumer(collector); + + try (val vector = appendDates(values, extension.createDateDayVector())) { + val i = new AtomicInteger(0); + val sut = new DateVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expectedDate = getExpectedDateForTimeZone(values.get(i.get()), UTC_ZONE_ID); + val expectedYear = expectedDate.get("year"); + val expectedMonth = expectedDate.get("month"); + val expectedDay = expectedDate.get("day"); + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(ASIA_BANGKOK)); + Calendar defaultCalendar = Calendar.getInstance(TimeZone.getDefault()); + + collector + .assertThat(sut.getTimestamp(null)) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedZero) + .hasMinute(expectedZero) + .hasSecond(expectedZero) + .hasMillisecond(expectedZero); + collector + 
.assertThat(sut.getTimestamp(calendar)) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedZero) + .hasMinute(expectedZero) + .hasSecond(expectedZero); + collector + .assertThat(sut.getTimestamp(defaultCalendar)) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedZero) + .hasMinute(expectedZero) + .hasSecond(expectedZero); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls * 3).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testDateMilliVectorGetTimestampReturnsUTCTimestampAtMidnight() { + val consumer = new TestWasNullConsumer(collector); + + try (val vector = appendDates(values, extension.createDateMilliVector())) { + val i = new AtomicInteger(0); + val sut = new DateVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expectedDate = getExpectedDateForTimeZone(values.get(i.get()), UTC_ZONE_ID); + val expectedYear = expectedDate.get("year"); + val expectedMonth = expectedDate.get("month"); + val expectedDay = expectedDate.get("day"); + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(ASIA_BANGKOK)); + Calendar defaultCalendar = Calendar.getInstance(TimeZone.getDefault()); + + collector + .assertThat(sut.getTimestamp(null)) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedZero) + .hasMinute(expectedZero) + .hasSecond(expectedZero); + collector + .assertThat(sut.getTimestamp(calendar)) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedZero) + .hasMinute(expectedZero) + .hasSecond(expectedZero); + collector + .assertThat(sut.getTimestamp(defaultCalendar)) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedZero) + .hasMinute(expectedZero) + .hasSecond(expectedZero); + } + } + + 
consumer.assertThat().hasNotNullSeen(expectedNonNulls * 3).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testDateMilliVectorGetString() { + val consumer = new TestWasNullConsumer(collector); + + try (val vector = appendDates(values, extension.createDateMilliVector())) { + val i = new AtomicInteger(0); + val sut = new DateVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val stringValue = sut.getString(); + val expected = getISOString(values.get(i.get())); + collector.assertThat(stringValue).isEqualTo(expected); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testDateDayVectorGetString() { + val consumer = new TestWasNullConsumer(collector); + + try (val vector = appendDates(values, extension.createDateDayVector())) { + val i = new AtomicInteger(0); + val sut = new DateVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val stringValue = sut.getString(); + val expected = getISOString(values.get(i.get())); + collector.assertThat(stringValue).isEqualTo(expected); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testNulledOutDateDayVectorReturnsNull() { + val consumer = new TestWasNullConsumer(collector); + + try (val vector = nulledOutVector(appendDates(values, extension.createDateDayVector()))) { + val i = new AtomicInteger(0); + val sut = new DateVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(ASIA_BANGKOK)); + Calendar defaultCalendar = Calendar.getInstance(TimeZone.getDefault()); + + collector.assertThat(sut.getDate(defaultCalendar)).isNull(); + collector.assertThat(sut.getTimestamp(calendar)).isNull(); + 
collector.assertThat(sut.getObject()).isNull(); + collector.assertThat(sut.getString()).isNull(); + } + } + + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(values.size() * 4); + } + + @SneakyThrows + @Test + void testNulledOutDateMilliVectorReturnsNull() { + val consumer = new TestWasNullConsumer(collector); + + try (val vector = nulledOutVector(appendDates(values, extension.createDateMilliVector()))) { + val i = new AtomicInteger(0); + val sut = new DateVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(ASIA_BANGKOK)); + Calendar defaultCalendar = Calendar.getInstance(TimeZone.getDefault()); + + collector.assertThat(sut.getDate(defaultCalendar)).isNull(); + collector.assertThat(sut.getTimestamp(calendar)).isNull(); + collector.assertThat(sut.getObject()).isNull(); + collector.assertThat(sut.getString()).isNull(); + } + } + + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(values.size() * 4); + } + + private Map getExpectedDateForTimeZone(long epochMilli, ZoneId zoneId) { + final Instant instant = Instant.ofEpochMilli(epochMilli); + final ZonedDateTime zdt = ZonedDateTime.ofInstant(instant, zoneId); + + final int expectedYear = zdt.getYear(); + final int expectedMonth = zdt.getMonthValue(); + final int expectedDay = zdt.getDayOfMonth(); + + return Map.of("year", expectedYear, "month", expectedMonth, "day", expectedDay); + } + + private String getISOString(Long millis) { + val epochDays = millis / MILLIS_PER_DAY; + LocalDate localDate = LocalDate.ofEpochDay(epochDays); + return localDate.format(DateTimeFormatter.ISO_DATE); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DecimalVectorAccessorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DecimalVectorAccessorTest.java new file mode 100644 index 0000000..6a4bde5 --- /dev/null +++ 
b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DecimalVectorAccessorTest.java @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import static com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension.nulledOutVector; + +import com.salesforce.datacloud.jdbc.core.accessor.SoftAssertions; +import com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension; +import com.salesforce.datacloud.jdbc.util.TestWasNullConsumer; +import java.math.BigDecimal; +import java.util.List; +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import lombok.SneakyThrows; +import lombok.val; +import org.assertj.core.api.junit.jupiter.InjectSoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; + +@ExtendWith(SoftAssertionsExtension.class) +public class DecimalVectorAccessorTest { + private static final int total = 8; + private final Random random = new Random(10); + + @InjectSoftAssertions + private SoftAssertions collector; + + @RegisterExtension + public static RootAllocatorTestExtension extension = new RootAllocatorTestExtension(); + + @SneakyThrows + @Test + 
void testGetBigDecimalGetObjectAndGetObjectClassFromValidDecimalVector() { + val values = getBigDecimals(); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createDecimalVector(values)) { + val i = new AtomicInteger(0); + val sut = new DecimalVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()); + collector + .assertThat(sut) + .hasBigDecimal(expected) + .hasObject(expected) + .hasObjectClass(BigDecimal.class); + } + } + + consumer.assertThat().hasNullSeen(0).hasNotNullSeen(values.size() * 2); + } + + @SneakyThrows + @Test + void testGetBigDecimalGetObjectAndGetObjectClassFromNulledDecimalVector() { + val values = getBigDecimals(); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = nulledOutVector(extension.createDecimalVector(values))) { + val i = new AtomicInteger(0); + val sut = new DecimalVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + collector.assertThat(sut).hasBigDecimal(null).hasObject(null).hasObjectClass(BigDecimal.class); + } + } + + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(values.size() * 2); + } + + @SneakyThrows + @Test + void testGetStringFromDecimalVector() { + val values = getBigDecimals(); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createDecimalVector(values)) { + val i = new AtomicInteger(0); + val sut = new DecimalVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val stringValue = sut.getString(); + val expected = values.get(i.get()).toString(); + collector.assertThat(stringValue).isEqualTo(expected); + } + } + + consumer.assertThat().hasNotNullSeen(values.size()).hasNullSeen(0); + } + + @SneakyThrows + @Test + void testGetStringFromNullDecimalVector() { + val values = getBigDecimals(); + val consumer = new 
TestWasNullConsumer(collector); + + try (val vector = nulledOutVector(extension.createDecimalVector(values))) { + val i = new AtomicInteger(0); + val sut = new DecimalVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val stringValue = sut.getString(); + collector.assertThat(stringValue).isNull(); + } + } + + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(values.size()); + } + + @SneakyThrows + @Test + void testGetIntFromDecimalVector() { + val values = getBigDecimals(); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createDecimalVector(values)) { + val i = new AtomicInteger(0); + val sut = new DecimalVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val intValue = sut.getInt(); + val expected = values.get(i.get()).intValue(); + collector.assertThat(intValue).isEqualTo(expected); + } + } + + consumer.assertThat().hasNotNullSeen(values.size()).hasNullSeen(0); + } + + @SneakyThrows + @Test + void testGetIntFromNullDecimalVector() { + val values = getBigDecimals(); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = nulledOutVector(extension.createDecimalVector(values))) { + val i = new AtomicInteger(0); + val sut = new DecimalVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val intValue = sut.getInt(); + collector.assertThat(intValue).isZero(); + } + } + + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(values.size()); + } + + private List getBigDecimals() { + return IntStream.range(0, total) + .mapToObj(x -> new BigDecimal(random.nextLong())) + .collect(Collectors.toUnmodifiableList()); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DoubleVectorAccessorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DoubleVectorAccessorTest.java new file mode 100644 
index 0000000..b536f8e --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/DoubleVectorAccessorTest.java @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.salesforce.datacloud.jdbc.core.accessor.SoftAssertions; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension; +import com.salesforce.datacloud.jdbc.util.TestWasNullConsumer; +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.Collections; +import java.util.List; +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import lombok.val; +import org.assertj.core.api.ThrowingConsumer; +import org.assertj.core.api.junit.jupiter.InjectSoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; + +@ExtendWith(SoftAssertionsExtension.class) +public class DoubleVectorAccessorTest { + @InjectSoftAssertions + private SoftAssertions collector; + + @RegisterExtension + 
public static RootAllocatorTestExtension extension = new RootAllocatorTestExtension(); + + private static final int total = 10; + private static final Random random = new Random(10); + private static final List values = IntStream.range(0, total - 2) + .mapToDouble(x -> random.nextDouble()) + .filter(Double::isFinite) + .boxed() + .collect(Collectors.toList()); + + @BeforeAll + static void setup() { + values.add(null); + values.add(null); + Collections.shuffle(values); + } + + private TestWasNullConsumer iterate(List values, BuildThrowingConsumer builder) { + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createFloat8Vector(values)) { + val i = new AtomicInteger(0); + val sut = new DoubleVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()); + val s = builder.buildSatisfies(expected); + collector.assertThat(sut).satisfies(b -> s.accept((DoubleVectorAccessor) b)); + } + } + + return consumer; + } + + private TestWasNullConsumer iterate(BuildThrowingConsumer builder) { + val consumer = iterate(values, builder); + consumer.assertThat().hasNullSeen(2).hasNotNullSeen(values.size() - 2); + return consumer; + } + + @FunctionalInterface + private interface BuildThrowingConsumer { + ThrowingConsumer buildSatisfies(Double expected); + } + + @Test + void testShouldGetDoubleMethodFromFloat8Vector() { + iterate(expected -> sut -> collector.assertThat(sut).hasDouble(expected == null ? 0.0 : expected)); + } + + @Test + void testShouldGetObjectMethodFromFloat8Vector() { + iterate(expected -> sut -> collector.assertThat(sut).hasObject(expected)); + } + + @Test + void testShouldGetStringMethodFromFloat8Vector() { + iterate(expected -> + sut -> collector.assertThat(sut).hasString(expected == null ? 
null : Double.toString(expected))); + } + + @Test + void testShouldGetBooleanMethodFromFloat8Vector() { + iterate(expected -> sut -> collector.assertThat(sut).hasBoolean(expected != null && (expected != 0.0))); + } + + @Test + void testShouldGetByteMethodFromFloat8Vector() { + iterate(expected -> sut -> collector.assertThat(sut).hasByte((byte) (expected == null ? 0.0 : expected))); + } + + @Test + void testShouldGetShortMethodFromFloat8Vector() { + iterate(expected -> sut -> collector.assertThat(sut).hasShort((short) (expected == null ? 0.0 : expected))); + } + + @Test + void testShouldGetIntMethodFromFloat8Vector() { + iterate(expected -> sut -> collector.assertThat(sut).hasInt((int) (expected == null ? 0.0 : expected))); + } + + @Test + void testShouldGetLongMethodFromFloat8Vector() { + iterate(expected -> sut -> collector.assertThat(sut).hasLong((long) (expected == null ? 0.0 : expected))); + } + + @Test + void testShouldGetFloatMethodFromFloat8Vector() { + iterate(expected -> + sut -> collector.assertThat(sut).hasFloat((float) (expected == null ? 
0.0 : expected))); // 0.0f + } + + @Test + void testGetBigDecimalIllegalDoublesMethodFromFloat8Vector() { + val consumer = iterate( + List.of(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, Double.NaN), + expected -> sut -> assertThrows(DataCloudJDBCException.class, sut::getBigDecimal)); + consumer.assertThat().hasNullSeen(0).hasNotNullSeen(3); + } + + @Test + void testShouldGetBigDecimalWithScaleMethodFromFloat4Vector() { + val scale = 9; + val big = Double.MAX_VALUE; + val expected = BigDecimal.valueOf(big).setScale(scale, RoundingMode.HALF_UP); + iterate( + List.of(Double.MAX_VALUE), + e -> sut -> collector.assertThat(sut.getBigDecimal(scale)).isEqualTo(expected)); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/ListVectorAccessorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/ListVectorAccessorTest.java new file mode 100644 index 0000000..4394f6c --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/ListVectorAccessorTest.java @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import static com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension.nulledOutVector; + +import com.salesforce.datacloud.jdbc.core.accessor.SoftAssertions; +import com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension; +import com.salesforce.datacloud.jdbc.util.TestWasNullConsumer; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import lombok.SneakyThrows; +import lombok.val; +import org.assertj.core.api.junit.jupiter.InjectSoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; + +@ExtendWith({SoftAssertionsExtension.class}) +public class ListVectorAccessorTest { + private static final int total = 127; + + @InjectSoftAssertions + private SoftAssertions collector; + + @RegisterExtension + public static RootAllocatorTestExtension extension = new RootAllocatorTestExtension(); + + @SneakyThrows + @Test + void testGetObjectAndGetObjectClassFromValidListVector() { + val values = createListVectors(); + val expectedNullChecks = values.size(); + val consumer = new TestWasNullConsumer(collector); + try (val vector = extension.createListVector("test-list-vector")) { + val i = new AtomicInteger(0); + val sut = new ListVectorAccessor(vector, i::get, consumer); + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()); + collector.assertThat(sut).hasObjectClass(List.class).hasObject(expected); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNullChecks).hasNullSeen(0); + } + + @SneakyThrows + @Test + void testGetArrayFromValidListVector() { + val values = createListVectors(); + val expectedNullChecks = values.size(); + val consumer = new TestWasNullConsumer(collector); + try 
(val vector = extension.createListVector("test-list-vector")) { + val i = new AtomicInteger(0); + val sut = new ListVectorAccessor(vector, i::get, consumer); + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()).toArray(); + val actual = (Object[]) sut.getArray().getArray(); + collector.assertThat(actual).isEqualTo(expected); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNullChecks).hasNullSeen(0); + } + + @SneakyThrows + @Test + void testGetArrayFromValidLargeListVector() { + val values = createListVectors(); + val expectedNullChecks = values.size(); + val consumer = new TestWasNullConsumer(collector); + try (val vector = extension.createLargeListVector("test-list-vector")) { + val i = new AtomicInteger(0); + val sut = new LargeListVectorAccessor(vector, i::get, consumer); + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()).toArray(); + val actual = (Object[]) sut.getArray().getArray(); + collector.assertThat(actual).isEqualTo(expected); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNullChecks).hasNullSeen(0); + } + + @SneakyThrows + @Test + void testGetArrayFromNulledListVector() { + val values = createListVectors(); + val expectedNullChecks = values.size(); + val consumer = new TestWasNullConsumer(collector); + try (val vector = nulledOutVector(extension.createListVector("test-list-vector"))) { + val i = new AtomicInteger(0); + val sut = new ListVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val actual = sut.getArray(); + collector.assertThat(actual).isNull(); + } + } + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(expectedNullChecks); + } + + @SneakyThrows + @Test + void testGetArrayFromNulledLargeListVector() { + val values = createListVectors(); + val expectedNullChecks = values.size(); + val consumer = new TestWasNullConsumer(collector); + try (val vector = 
nulledOutVector(extension.createLargeListVector("test-list-vector"))) { + val i = new AtomicInteger(0); + val sut = new LargeListVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val actual = sut.getArray(); + collector.assertThat(actual).isNull(); + } + } + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(expectedNullChecks); + } + + @SneakyThrows + @Test + void testGetObjectAndGetObjectClassFromValidLargeListVector() { + val values = createListVectors(); + val expectedNullChecks = values.size(); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createLargeListVector("test-large-list-vector")) { + val i = new AtomicInteger(0); + val sut = new LargeListVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()); + collector.assertThat(sut).hasObjectClass(List.class).hasObject(expected); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNullChecks).hasNullSeen(0); + } + + private List> createListVectors() { + return IntStream.range(0, total) + .mapToObj(x -> IntStream.range(0, 5).map(j -> j * x).boxed().collect(Collectors.toList())) + .collect(Collectors.toUnmodifiableList()); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeStampVectorAccessorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeStampVectorAccessorTest.java new file mode 100644 index 0000000..a08f27c --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeStampVectorAccessorTest.java @@ -0,0 +1,512 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import static com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension.nulledOutVector; + +import com.salesforce.datacloud.jdbc.core.accessor.SoftAssertions; +import com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension; +import com.salesforce.datacloud.jdbc.util.TestWasNullConsumer; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Calendar; +import java.util.List; +import java.util.Random; +import java.util.TimeZone; +import java.util.concurrent.atomic.AtomicInteger; +import lombok.SneakyThrows; +import lombok.val; +import org.assertj.core.api.junit.jupiter.InjectSoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; + +@ExtendWith(SoftAssertionsExtension.class) +public class TimeStampVectorAccessorTest { + + @RegisterExtension + public static RootAllocatorTestExtension extension = new RootAllocatorTestExtension(); + + @InjectSoftAssertions + private SoftAssertions collector; + + public static final int BASE_YEAR = 2020; + public static final int NUM_OF_METHODS = 4; + + @Test + @SneakyThrows + void testTimestampNanoVector() { + Calendar calendar = Calendar.getInstance(); + calendar.setTimeZone(TimeZone.getTimeZone("UTC")); + List monthNumber = getRandomMonthNumber(); + val values = getMilliSecondValues(calendar, monthNumber); + val consumer = new 
TestWasNullConsumer(collector); + + try (val vector = extension.createTimeStampNanoVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeStampVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.getAndIncrement()) { + val timestampValue = sut.getTimestamp(calendar); + val dateValue = sut.getDate(calendar); + val timeValue = sut.getTime(calendar); + val stringValue = sut.getString(); + val currentNumber = monthNumber.get(i.get()); + val currentMillis = values.get(i.get()); + + collector + .assertThat(timestampValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector + .assertThat(dateValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector + .assertThat(timeValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector.assertThat(stringValue).isEqualTo(getISOString(currentMillis)); + } + } + consumer.assertThat().hasNullSeen(0).hasNotNullSeen(NUM_OF_METHODS * values.size()); + } + + @Test + @SneakyThrows + void testTimestampNanoTZVector() { + Calendar calendar = Calendar.getInstance(); + calendar.setTimeZone(TimeZone.getTimeZone("UTC")); + List monthNumber = getRandomMonthNumber(); + val values = getMilliSecondValues(calendar, monthNumber); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createTimeStampNanoTZVector(values, "UTC")) { + val i = new AtomicInteger(0); + val sut = new TimeStampVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val timestampValue = 
sut.getTimestamp(calendar); + val dateValue = sut.getDate(calendar); + val timeValue = sut.getTime(calendar); + val stringValue = sut.getString(); + val currentNumber = monthNumber.get(i.get()); + val currentMillis = values.get(i.get()); + + collector + .assertThat(timestampValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector + .assertThat(dateValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector + .assertThat(timeValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector.assertThat(stringValue).isEqualTo(getISOString(currentMillis)); + } + } + consumer.assertThat().hasNullSeen(0).hasNotNullSeen(NUM_OF_METHODS * values.size()); + } + + @Test + @SneakyThrows + void testTimestampMicroVector() { + Calendar calendar = Calendar.getInstance(); + calendar.setTimeZone(TimeZone.getTimeZone("UTC")); + List monthNumber = getRandomMonthNumber(); + val values = getMilliSecondValues(calendar, monthNumber); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createTimeStampMicroVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeStampVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val timestampValue = sut.getTimestamp(calendar); + val dateValue = sut.getDate(calendar); + val timeValue = sut.getTime(calendar); + val stringValue = sut.getString(); + val currentNumber = monthNumber.get(i.get()); + val currentMillis = values.get(i.get()); + + collector + .assertThat(timestampValue) + .hasYear(BASE_YEAR + currentNumber) + 
.hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector + .assertThat(dateValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector + .assertThat(timeValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector.assertThat(stringValue).isEqualTo(getISOString(currentMillis)); + } + } + consumer.assertThat().hasNullSeen(0).hasNotNullSeen(NUM_OF_METHODS * values.size()); + } + + @Test + @SneakyThrows + void testTimestampMicroTZVector() { + Calendar calendar = Calendar.getInstance(); + calendar.setTimeZone(TimeZone.getTimeZone("UTC")); + List monthNumber = getRandomMonthNumber(); + val values = getMilliSecondValues(calendar, monthNumber); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createTimeStampMicroTZVector(values, "UTC")) { + val i = new AtomicInteger(0); + val sut = new TimeStampVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val timestampValue = sut.getTimestamp(calendar); + val dateValue = sut.getDate(calendar); + val timeValue = sut.getTime(calendar); + val stringValue = sut.getString(); + val currentNumber = monthNumber.get(i.get()); + val currentMillis = values.get(i.get()); + + collector + .assertThat(timestampValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector + .assertThat(dateValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + 
.hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector + .assertThat(timeValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector.assertThat(stringValue).isEqualTo(getISOString(currentMillis)); + } + } + consumer.assertThat().hasNullSeen(0).hasNotNullSeen(NUM_OF_METHODS * values.size()); + } + + @Test + @SneakyThrows + void testTimeStampMilliVector() { + Calendar calendar = Calendar.getInstance(); + calendar.setTimeZone(TimeZone.getTimeZone("UTC")); + List monthNumber = getRandomMonthNumber(); + val values = getMilliSecondValues(calendar, monthNumber); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createTimeStampMilliVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeStampVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val timestampValue = sut.getTimestamp(calendar); + val dateValue = sut.getDate(calendar); + val timeValue = sut.getTime(calendar); + val stringValue = sut.getString(); + val currentNumber = monthNumber.get(i.get()); + val currentMillis = values.get(i.get()); + + collector + .assertThat(timestampValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector + .assertThat(dateValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector + .assertThat(timeValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + 
collector.assertThat(stringValue).isEqualTo(getISOString(currentMillis)); + } + } + consumer.assertThat().hasNullSeen(0).hasNotNullSeen(NUM_OF_METHODS * values.size()); + } + + @Test + @SneakyThrows + void testTimeStampMilliTZVector() { + Calendar calendar = Calendar.getInstance(); + calendar.setTimeZone(TimeZone.getTimeZone("UTC")); + List monthNumber = getRandomMonthNumber(); + val values = getMilliSecondValues(calendar, monthNumber); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createTimeStampMilliTZVector(values, "UTC")) { + val i = new AtomicInteger(0); + val sut = new TimeStampVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val timestampValue = sut.getTimestamp(calendar); + val dateValue = sut.getDate(calendar); + val timeValue = sut.getTime(calendar); + val stringValue = sut.getString(); + val currentNumber = monthNumber.get(i.get()); + val currentMillis = values.get(i.get()); + + collector + .assertThat(timestampValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector + .assertThat(dateValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector + .assertThat(timeValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector.assertThat(stringValue).isEqualTo(getISOString(currentMillis)); + } + } + consumer.assertThat().hasNullSeen(0).hasNotNullSeen(NUM_OF_METHODS * values.size()); + } + + @Test + @SneakyThrows + void testTimestampSecVector() { + Calendar calendar = Calendar.getInstance(); + 
calendar.setTimeZone(TimeZone.getTimeZone("UTC")); + List monthNumber = getRandomMonthNumber(); + val values = getMilliSecondValues(calendar, monthNumber); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createTimeStampSecVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeStampVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val timestampValue = sut.getTimestamp(calendar); + val dateValue = sut.getDate(calendar); + val timeValue = sut.getTime(calendar); + val stringValue = sut.getString(); + val currentNumber = monthNumber.get(i.get()); + val currentMillis = (values.get(i.get()) / 1000) * 1000; + + collector + .assertThat(timestampValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector + .assertThat(dateValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector + .assertThat(timeValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector.assertThat(stringValue).isEqualTo(getISOString(currentMillis)); + } + } + consumer.assertThat().hasNullSeen(0).hasNotNullSeen(NUM_OF_METHODS * values.size()); + } + + @Test + @SneakyThrows + void testTimestampSecTZVector() { + Calendar calendar = Calendar.getInstance(); + calendar.setTimeZone(TimeZone.getTimeZone("UTC")); + List monthNumber = getRandomMonthNumber(); + val values = getMilliSecondValues(calendar, monthNumber); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createTimeStampSecTZVector(values, "UTC")) { + val i = new 
AtomicInteger(0); + val sut = new TimeStampVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val timestampValue = sut.getTimestamp(calendar); + val dateValue = sut.getDate(calendar); + val timeValue = sut.getTime(calendar); + val stringValue = sut.getString(); + val currentNumber = monthNumber.get(i.get()); + val currentMillis = (values.get(i.get()) / 1000) * 1000; + + collector + .assertThat(timestampValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector + .assertThat(dateValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector + .assertThat(timeValue) + .hasYear(BASE_YEAR + currentNumber) + .hasMonth(currentNumber + 1) + .hasDayOfMonth(currentNumber) + .hasHourOfDay(currentNumber) + .hasMinute(currentNumber) + .hasSecond(currentNumber); + collector.assertThat(stringValue).isEqualTo(getISOString(currentMillis)); + } + } + consumer.assertThat().hasNullSeen(0).hasNotNullSeen(NUM_OF_METHODS * values.size()); + } + + @Test + @SneakyThrows + void testNulledTimestampVector() { + Calendar calendar = Calendar.getInstance(); + calendar.setTimeZone(TimeZone.getTimeZone("UTC")); + List monthNumber = getRandomMonthNumber(); + val values = getMilliSecondValues(calendar, monthNumber); + val consumer = new TestWasNullConsumer(collector); + + try (val vector = nulledOutVector(extension.createTimeStampSecTZVector(values, "UTC"))) { + val i = new AtomicInteger(0); + val sut = new TimeStampVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val timestampValue = sut.getTimestamp(calendar); + val dateValue = sut.getDate(calendar); + val timeValue = 
sut.getTime(calendar); + val stringValue = sut.getString(); + + collector.assertThat(timestampValue).isNull(); + collector.assertThat(dateValue).isNull(); + collector.assertThat(timeValue).isNull(); + collector.assertThat(stringValue).isNull(); + } + } + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(NUM_OF_METHODS * values.size()); + } + + private List getMilliSecondValues(Calendar calendar, List monthNumber) { + List result = new ArrayList<>(); + for (int currentNumber : monthNumber) { + calendar.set( + BASE_YEAR + currentNumber, + currentNumber, + currentNumber, + currentNumber, + currentNumber, + currentNumber); + result.add(calendar.getTimeInMillis()); + } + return result; + } + + private List getRandomMonthNumber() { + Random rand = new Random(); + int valA = rand.nextInt(10) + 1; + int valB = rand.nextInt(10) + 1; + int valC = rand.nextInt(10) + 1; + return List.of(valA, valB, valC); + } + + private String getISOString(Long millis) { + return Instant.ofEpochMilli(millis).toString(); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeVectorAccessorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeVectorAccessorTest.java new file mode 100644 index 0000000..f65d610 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/TimeVectorAccessorTest.java @@ -0,0 +1,864 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import static com.salesforce.datacloud.jdbc.util.Constants.ISO_TIME_FORMAT; +import static com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension.nulledOutVector; + +import com.salesforce.datacloud.jdbc.core.accessor.SoftAssertions; +import com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension; +import com.salesforce.datacloud.jdbc.util.TestWasNullConsumer; +import java.sql.Time; +import java.time.Instant; +import java.time.LocalTime; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.Calendar; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Random; +import java.util.TimeZone; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import lombok.SneakyThrows; +import lombok.val; +import org.assertj.core.api.junit.jupiter.InjectSoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; + +@ExtendWith(SoftAssertionsExtension.class) +public class TimeVectorAccessorTest { + + @InjectSoftAssertions + private SoftAssertions collector; + + public static final String ASIA_BANGKOK = "Asia/Bangkok"; + + @RegisterExtension + public static RootAllocatorTestExtension extension = new RootAllocatorTestExtension(); + + private static final int[] edgeCaseVals = {0, 1, 60, 60 * 60, (24 * 60 * 60 - 1)}; + public static final int NUM_OF_CALLS = 3; + + @SneakyThrows + @Test + void testTimeNanoVectorGetObjectClass() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomLongs(TimeUnit.NANOSECONDS); + + try (val vector = extension.createTimeNanoVector(values)) { + val i = new 
AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + collector.assertThat(sut).hasObjectClass(Time.class); + } + } + } + + @SneakyThrows + @Test + void testTimeMicroVectorGetObjectClass() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomLongs(TimeUnit.MICROSECONDS); + + try (val vector = extension.createTimeMicroVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + collector.assertThat(sut).hasObjectClass(Time.class); + } + } + } + + @SneakyThrows + @Test + void testTimeMilliVectorGetObjectClass() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomIntegers(TimeUnit.MILLISECONDS); + + try (val vector = extension.createTimeMilliVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + collector.assertThat(sut).hasObjectClass(Time.class); + } + } + } + + @SneakyThrows + @Test + void testTimeSecVectorGetObjectClass() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomIntegers(TimeUnit.SECONDS); + + try (val vector = extension.createTimeSecVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + collector.assertThat(sut).hasObjectClass(Time.class); + } + } + } + + @SneakyThrows + @Test + void testTimeNanoVectorGetTime() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomLongs(TimeUnit.NANOSECONDS); + final int expectedNulls = (int) values.stream().filter(Objects::isNull).count(); + final int expectedNonNulls = values.size() - expectedNulls; + + Calendar calendar 
= Calendar.getInstance(TimeZone.getTimeZone(ASIA_BANGKOK)); + Calendar defaultCalendar = Calendar.getInstance(TimeZone.getDefault()); + + try (val vector = extension.createTimeNanoVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expectedTime = getExpectedTime(TimeUnit.NANOSECONDS.toMillis(values.get(i.get()))); + val expectedHour = expectedTime.get("hour"); + val expectedMinute = expectedTime.get("minute"); + val expectedSecond = expectedTime.get("second"); + val expectedMilli = expectedTime.get("milli"); + + collector + .assertThat(sut.getTime(null)) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + collector + .assertThat(sut.getTime(calendar)) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + collector + .assertThat(sut.getTime(defaultCalendar)) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls * NUM_OF_CALLS).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testTimeMicroVectorGetTime() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomLongs(TimeUnit.MICROSECONDS); + final int expectedNulls = (int) values.stream().filter(Objects::isNull).count(); + final int expectedNonNulls = values.size() - expectedNulls; + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(ASIA_BANGKOK)); + Calendar defaultCalendar = Calendar.getInstance(TimeZone.getDefault()); + + try (val vector = extension.createTimeMicroVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val 
expectedTime = getExpectedTime(TimeUnit.MICROSECONDS.toMillis(values.get(i.get()))); + val expectedHour = expectedTime.get("hour"); + val expectedMinute = expectedTime.get("minute"); + val expectedSecond = expectedTime.get("second"); + val expectedMilli = expectedTime.get("milli"); + + collector + .assertThat(sut.getTime(null)) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + collector + .assertThat(sut.getTime(calendar)) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + collector + .assertThat(sut.getTime(defaultCalendar)) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls * NUM_OF_CALLS).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testTimeMilliVectorGetTime() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomIntegers(TimeUnit.MILLISECONDS); + final int expectedNulls = (int) values.stream().filter(Objects::isNull).count(); + final int expectedNonNulls = values.size() - expectedNulls; + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(ASIA_BANGKOK)); + Calendar defaultCalendar = Calendar.getInstance(TimeZone.getDefault()); + + try (val vector = extension.createTimeMilliVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expectedTime = getExpectedTime(values.get(i.get())); + val expectedHour = expectedTime.get("hour"); + val expectedMinute = expectedTime.get("minute"); + val expectedSecond = expectedTime.get("second"); + val expectedMilli = expectedTime.get("milli"); + + collector + .assertThat(sut.getTime(null)) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + 
.hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + collector + .assertThat(sut.getTime(calendar)) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + collector + .assertThat(sut.getTime(defaultCalendar)) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls * NUM_OF_CALLS).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testTimeSecVectorGetTime() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomIntegers(TimeUnit.SECONDS); + final int expectedNulls = (int) values.stream().filter(Objects::isNull).count(); + final int expectedNonNulls = values.size() - expectedNulls; + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(ASIA_BANGKOK)); + Calendar defaultCalendar = Calendar.getInstance(TimeZone.getDefault()); + + try (val vector = extension.createTimeSecVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expectedTime = getExpectedTime(TimeUnit.SECONDS.toMillis(values.get(i.get()))); + val expectedHour = expectedTime.get("hour"); + val expectedMinute = expectedTime.get("minute"); + val expectedSecond = expectedTime.get("second"); + val expectedMilli = expectedTime.get("milli"); + + collector + .assertThat(sut.getTime(null)) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + collector + .assertThat(sut.getTime(calendar)) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + collector + .assertThat(sut.getTime(defaultCalendar)) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + 
.hasMillisecond(expectedMilli); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls * NUM_OF_CALLS).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testTimeNanoGetObject() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomLongs(TimeUnit.NANOSECONDS); + final int expectedNulls = (int) values.stream().filter(Objects::isNull).count(); + final int expectedNonNulls = values.size() - expectedNulls; + + try (val vector = extension.createTimeNanoVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expectedTime = getExpectedTime(TimeUnit.NANOSECONDS.toMillis(values.get(i.get()))); + val expectedHour = expectedTime.get("hour"); + val expectedMinute = expectedTime.get("minute"); + val expectedSecond = expectedTime.get("second"); + val expectedMilli = expectedTime.get("milli"); + + collector.assertThat(sut.getObject()).isInstanceOf(Time.class); + collector + .assertThat((Time) sut.getObject()) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls * 2).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testTimeMicroGetObject() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomLongs(TimeUnit.MICROSECONDS); + final int expectedNulls = (int) values.stream().filter(Objects::isNull).count(); + final int expectedNonNulls = values.size() - expectedNulls; + + try (val vector = extension.createTimeMicroVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expectedTime = getExpectedTime(TimeUnit.MICROSECONDS.toMillis(values.get(i.get()))); + val expectedHour = 
expectedTime.get("hour"); + val expectedMinute = expectedTime.get("minute"); + val expectedSecond = expectedTime.get("second"); + val expectedMilli = expectedTime.get("milli"); + + collector.assertThat(sut.getObject()).isInstanceOf(Time.class); + collector + .assertThat((Time) sut.getObject()) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls * 2).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testTimeMilliGetObject() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomIntegers(TimeUnit.MILLISECONDS); + final int expectedNulls = (int) values.stream().filter(Objects::isNull).count(); + final int expectedNonNulls = values.size() - expectedNulls; + + try (val vector = extension.createTimeMilliVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expectedTime = getExpectedTime(values.get(i.get())); + val expectedHour = expectedTime.get("hour"); + val expectedMinute = expectedTime.get("minute"); + val expectedSecond = expectedTime.get("second"); + val expectedMilli = expectedTime.get("milli"); + + collector.assertThat(sut.getObject()).isInstanceOf(Time.class); + collector + .assertThat((Time) sut.getObject()) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls * 2).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testTimeSecGetObject() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomIntegers(TimeUnit.SECONDS); + final int expectedNulls = (int) values.stream().filter(Objects::isNull).count(); + final int expectedNonNulls = values.size() - expectedNulls; + + try (val 
vector = extension.createTimeSecVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expectedTime = getExpectedTime(TimeUnit.SECONDS.toMillis(values.get(i.get()))); + val expectedHour = expectedTime.get("hour"); + val expectedMinute = expectedTime.get("minute"); + val expectedSecond = expectedTime.get("second"); + val expectedMilli = expectedTime.get("milli"); + + collector.assertThat(sut.getObject()).isInstanceOf(Time.class); + collector + .assertThat((Time) sut.getObject()) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls * 2).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testTimeNanoGetTimestamp() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomLongs(TimeUnit.NANOSECONDS); + final int expectedNulls = (int) values.stream().filter(Objects::isNull).count(); + final int expectedNonNulls = values.size() - expectedNulls; + + try (val vector = extension.createTimeNanoVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expectedTime = getExpectedTime(TimeUnit.NANOSECONDS.toMillis(values.get(i.get()))); + val expectedHour = expectedTime.get("hour"); + val expectedMinute = expectedTime.get("minute"); + val expectedSecond = expectedTime.get("second"); + val expectedMilli = expectedTime.get("milli"); + + collector + .assertThat(sut.getTimestamp(null)) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void 
testTimeMicroGetTimestamp() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomLongs(TimeUnit.MICROSECONDS); + final int expectedNulls = (int) values.stream().filter(Objects::isNull).count(); + final int expectedNonNulls = values.size() - expectedNulls; + + try (val vector = extension.createTimeMicroVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expectedTime = getExpectedTime(TimeUnit.MICROSECONDS.toMillis(values.get(i.get()))); + val expectedHour = expectedTime.get("hour"); + val expectedMinute = expectedTime.get("minute"); + val expectedSecond = expectedTime.get("second"); + val expectedMilli = expectedTime.get("milli"); + + collector + .assertThat(sut.getTimestamp(null)) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testTimeMilliGetTimestamp() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomIntegers(TimeUnit.MILLISECONDS); + final int expectedNulls = (int) values.stream().filter(Objects::isNull).count(); + final int expectedNonNulls = values.size() - expectedNulls; + + try (val vector = extension.createTimeMilliVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expectedTime = getExpectedTime(values.get(i.get())); + val expectedHour = expectedTime.get("hour"); + val expectedMinute = expectedTime.get("minute"); + val expectedSecond = expectedTime.get("second"); + val expectedMilli = expectedTime.get("milli"); + + collector + .assertThat(sut.getTimestamp(null)) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) 
+ .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testTimeSecGetTimestamp() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomIntegers(TimeUnit.SECONDS); + final int expectedNulls = (int) values.stream().filter(Objects::isNull).count(); + final int expectedNonNulls = values.size() - expectedNulls; + + try (val vector = extension.createTimeSecVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expectedTime = getExpectedTime(TimeUnit.SECONDS.toMillis(values.get(i.get()))); + val expectedHour = expectedTime.get("hour"); + val expectedMinute = expectedTime.get("minute"); + val expectedSecond = expectedTime.get("second"); + val expectedMilli = expectedTime.get("milli"); + + collector + .assertThat(sut.getTimestamp(null)) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMilli); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNonNulls).hasNullSeen(expectedNulls); + } + + @SneakyThrows + @Test + void testTimeNanoGetString() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomLongs(TimeUnit.NANOSECONDS); + final int expectedNulls = (int) values.stream().filter(Objects::isNull).count(); + final int expectedNonNulls = values.size() - expectedNulls; + + try (val vector = extension.createTimeNanoVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val stringValue = sut.getString(); + val currentNanos = values.get(i.get()); + + collector.assertThat(stringValue).isEqualTo(getISOString(currentNanos, TimeUnit.NANOSECONDS)); + } + } + 
consumer.assertThat().hasNullSeen(expectedNulls).hasNotNullSeen(expectedNonNulls); + } + + @SneakyThrows + @Test + void testTimeMicroGetString() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomLongs(TimeUnit.MICROSECONDS); + final int expectedNulls = (int) values.stream().filter(Objects::isNull).count(); + final int expectedNonNulls = values.size() - expectedNulls; + + try (val vector = extension.createTimeMicroVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val stringValue = sut.getString(); + val currentMicros = values.get(i.get()); + + collector.assertThat(stringValue).isEqualTo(getISOString(currentMicros, TimeUnit.MICROSECONDS)); + } + } + consumer.assertThat().hasNullSeen(expectedNulls).hasNotNullSeen(expectedNonNulls); + } + + @SneakyThrows + @Test + void testTimeMilliGetString() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomIntegers(TimeUnit.MILLISECONDS); + final int expectedNulls = (int) values.stream().filter(Objects::isNull).count(); + final int expectedNonNulls = values.size() - expectedNulls; + + try (val vector = extension.createTimeMilliVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val stringValue = sut.getString(); + val currentMillis = values.get(i.get()); + + collector.assertThat(stringValue).isEqualTo(getISOString(currentMillis, TimeUnit.MILLISECONDS)); + } + } + consumer.assertThat().hasNullSeen(expectedNulls).hasNotNullSeen(expectedNonNulls); + } + + @SneakyThrows + @Test + void testTimeSecGetString() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomIntegers(TimeUnit.SECONDS); + final int expectedNulls = (int) values.stream().filter(Objects::isNull).count(); + 
final int expectedNonNulls = values.size() - expectedNulls; + + try (val vector = extension.createTimeSecVector(values)) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val stringValue = sut.getString(); + val currentSec = values.get(i.get()); + + collector.assertThat(stringValue).isEqualTo(getISOString(currentSec, TimeUnit.SECONDS)); + } + } + consumer.assertThat().hasNullSeen(expectedNulls).hasNotNullSeen(expectedNonNulls); + } + + @SneakyThrows + @Test + void testNulledOutTimeNanoVectorReturnsNull() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomLongs(TimeUnit.NANOSECONDS); + + try (val vector = nulledOutVector(extension.createTimeNanoVector(values))) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(ASIA_BANGKOK)); + Calendar defaultCalendar = Calendar.getInstance(TimeZone.getDefault()); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + collector.assertThat(sut.getTime(defaultCalendar)).isNull(); + collector.assertThat(sut.getTimestamp(calendar)).isNull(); + collector.assertThat(sut.getObject()).isNull(); + collector.assertThat(sut.getString()).isNull(); + } + } + + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(values.size() * 4); + } + + @SneakyThrows + @Test + void testNulledOutTimeMicroVectorReturnsNull() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomLongs(TimeUnit.MICROSECONDS); + + try (val vector = nulledOutVector(extension.createTimeMicroVector(values))) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(ASIA_BANGKOK)); + Calendar defaultCalendar = Calendar.getInstance(TimeZone.getDefault()); + + for (; i.get() < 
vector.getValueCount(); i.incrementAndGet()) { + collector.assertThat(sut.getTime(defaultCalendar)).isNull(); + collector.assertThat(sut.getTimestamp(calendar)).isNull(); + collector.assertThat(sut.getObject()).isNull(); + collector.assertThat(sut.getString()).isNull(); + } + } + + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(values.size() * 4); + } + + @SneakyThrows + @Test + void testNulledOutTimeMilliVectorReturnsNull() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomIntegers(TimeUnit.MILLISECONDS); + + try (val vector = nulledOutVector(extension.createTimeMilliVector(values))) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(ASIA_BANGKOK)); + Calendar defaultCalendar = Calendar.getInstance(TimeZone.getDefault()); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + collector.assertThat(sut.getTime(defaultCalendar)).isNull(); + collector.assertThat(sut.getTimestamp(calendar)).isNull(); + collector.assertThat(sut.getObject()).isNull(); + collector.assertThat(sut.getString()).isNull(); + } + } + + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(values.size() * 4); + } + + @SneakyThrows + @Test + void testNulledOutTimeSecVectorReturnsNull() { + val consumer = new TestWasNullConsumer(collector); + + List values = generateRandomIntegers(TimeUnit.SECONDS); + + try (val vector = nulledOutVector(extension.createTimeSecVector(values))) { + val i = new AtomicInteger(0); + val sut = new TimeVectorAccessor(vector, i::get, consumer); + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(ASIA_BANGKOK)); + Calendar defaultCalendar = Calendar.getInstance(TimeZone.getDefault()); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + collector.assertThat(sut.getTime(defaultCalendar)).isNull(); + collector.assertThat(sut.getTimestamp(calendar)).isNull(); + 
collector.assertThat(sut.getObject()).isNull(); + collector.assertThat(sut.getString()).isNull(); + } + } + + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(values.size() * 4); + } + + private static List generateRandomLongs(TimeUnit timeUnit) { + int rightLimit = 86400; + val i = new AtomicInteger(0); + List result = new ArrayList<>(); + + for (; i.get() < edgeCaseVals.length; i.incrementAndGet()) { + switch (timeUnit) { + case NANOSECONDS: + result.add(ThreadLocalRandom.current().nextLong(rightLimit * 1_000_000_000L)); + result.add(edgeCaseVals[i.get()] * 1_000_000_000L); + break; + case MICROSECONDS: + result.add(ThreadLocalRandom.current().nextLong(rightLimit * 1_000_000L)); + result.add(edgeCaseVals[i.get()] * 1_000_000L); + break; + default: + } + } + + return result; + } + + private static List generateRandomIntegers(TimeUnit timeUnit) { + int rightLimit = 86400; + val i = new AtomicInteger(0); + List result = new ArrayList<>(); + Random rnd = new Random(); + + for (; i.get() < edgeCaseVals.length; i.incrementAndGet()) { + switch (timeUnit) { + case MILLISECONDS: + result.add(rnd.nextInt(rightLimit * 1_000)); + result.add(edgeCaseVals[i.get()] * 1_000); + break; + case SECONDS: + result.add(rnd.nextInt(rightLimit)); + result.add(edgeCaseVals[i.get()]); + break; + default: + } + } + + return result; + } + + private Map getExpectedTime(long epochMilli) { + final ZoneId zoneId = TimeZone.getTimeZone("UTC").toZoneId(); + final Instant instant = Instant.ofEpochMilli(epochMilli); + final ZonedDateTime zdt = ZonedDateTime.ofInstant(instant, zoneId); + + final int expectedHour = zdt.getHour(); + final int expectedMinute = zdt.getMinute(); + final int expectedSecond = zdt.getSecond(); + final int expectedMillisecond = (int) TimeUnit.NANOSECONDS.toMillis(zdt.getNano()); + + return Map.of( + "hour", expectedHour, "minute", expectedMinute, "second", expectedSecond, "milli", expectedMillisecond); + } + + private String getISOString(Long value, TimeUnit unit) { + 
Long adjustedNanos; + switch (unit) { + case NANOSECONDS: + adjustedNanos = value; + break; + case MICROSECONDS: + adjustedNanos = value * 1_000; + break; + default: + adjustedNanos = value; + } + + val localTime = LocalTime.ofNanoOfDay(adjustedNanos); + val result = localTime.format(DateTimeFormatter.ofPattern(ISO_TIME_FORMAT)); + return result; + } + + private String getISOString(Integer value, TimeUnit unit) { + Integer adjustedSeconds; + switch (unit) { + case MILLISECONDS: + adjustedSeconds = value / 1_000; + break; + case SECONDS: + adjustedSeconds = value; + break; + default: + adjustedSeconds = value; + } + + val localTime = LocalTime.ofSecondOfDay(adjustedSeconds); + val result = localTime.format(DateTimeFormatter.ofPattern(ISO_TIME_FORMAT)); + return result; + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/VarCharVectorAccessorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/VarCharVectorAccessorTest.java new file mode 100644 index 0000000..3dc1785 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/accessor/impl/VarCharVectorAccessorTest.java @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.accessor.impl; + +import static com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension.nulledOutVector; + +import com.salesforce.datacloud.jdbc.core.accessor.SoftAssertions; +import com.salesforce.datacloud.jdbc.util.RootAllocatorTestExtension; +import com.salesforce.datacloud.jdbc.util.TestWasNullConsumer; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import lombok.SneakyThrows; +import lombok.val; +import org.assertj.core.api.junit.jupiter.InjectSoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.api.extension.RegisterExtension; + +@ExtendWith(SoftAssertionsExtension.class) +public class VarCharVectorAccessorTest { + private static final int total = 8; + + @InjectSoftAssertions + private SoftAssertions collector; + + @RegisterExtension + public static RootAllocatorTestExtension extension = new RootAllocatorTestExtension(); + + @SneakyThrows + @Test + void testGetStringGetObjectAndGetObjectClassFromValidVarCharVector() { + val values = getStrings(); + val expectedNullChecks = values.size() * 3; // seen thrice since getObject and getString both call getBytes + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createVarCharVectorFrom(values)) { + val i = new AtomicInteger(0); + val sut = new VarCharVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()); + collector + .assertThat(sut) + .hasObjectClass(String.class) + .hasBytes(expected.getBytes(StandardCharsets.UTF_8)) + .hasObject(expected) + .hasString(expected); + } + } + + 
consumer.assertThat().hasNotNullSeen(expectedNullChecks).hasNullSeen(0); + } + + @SneakyThrows + @Test + void testGetStringGetObjectAndGetObjectClassFromNulledVarCharVector() { + val values = getStrings(); + val expectedNullChecks = values.size() * 3; // seen thrice since getObject and getString both call getBytes + val consumer = new TestWasNullConsumer(collector); + + try (val vector = nulledOutVector(extension.createVarCharVectorFrom(values))) { + val i = new AtomicInteger(0); + val sut = new VarCharVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + collector + .assertThat(sut) + .hasObjectClass(String.class) + .hasObject(null) + .hasString(null); + collector.assertThat(sut.getBytes()).isNull(); + } + } + + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(expectedNullChecks); + } + + @SneakyThrows + @Test + void testGetStringGetObjectAndGetObjectClassFromValidLargeVarCharVector() { + val values = getStrings(); + val expectedNullChecks = values.size() * 3; // seen thrice since getObject and getString both call getBytes + val consumer = new TestWasNullConsumer(collector); + + try (val vector = extension.createLargeVarCharVectorFrom(values)) { + val i = new AtomicInteger(0); + val sut = new VarCharVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + val expected = values.get(i.get()); + collector + .assertThat(sut) + .hasObjectClass(String.class) + .hasBytes(expected.getBytes(StandardCharsets.UTF_8)) + .hasObject(expected) + .hasString(expected); + } + } + + consumer.assertThat().hasNotNullSeen(expectedNullChecks).hasNullSeen(0); + } + + @SneakyThrows + @Test + void testGetStringGetObjectAndGetObjectClassFromNulledLargeVarCharVector() { + val values = getStrings(); + val expectedNullChecks = values.size() * 3; // seen thrice since getObject and getString both call getBytes + val consumer = new TestWasNullConsumer(collector); + + try (val vector 
= nulledOutVector(extension.createLargeVarCharVectorFrom(values))) { + val i = new AtomicInteger(0); + val sut = new VarCharVectorAccessor(vector, i::get, consumer); + + for (; i.get() < vector.getValueCount(); i.incrementAndGet()) { + collector + .assertThat(sut) + .hasObjectClass(String.class) + .hasObject(null) + .hasString(null); + collector.assertThat(sut.getBytes()).isNull(); + } + } + + consumer.assertThat().hasNotNullSeen(0).hasNullSeen(expectedNullChecks); + } + + private List getStrings() { + return IntStream.range(0, total) + .mapToObj(x -> UUID.randomUUID().toString()) + .collect(Collectors.toUnmodifiableList()); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusListenerTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusListenerTest.java new file mode 100644 index 0000000..740ebf4 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusListenerTest.java @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.listener; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.salesforce.datacloud.jdbc.core.HyperGrpcClientExecutor; +import com.salesforce.datacloud.jdbc.core.HyperGrpcTestBase; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.RealisticArrowGenerator; +import com.salesforce.hyperdb.grpc.QueryParam; +import com.salesforce.hyperdb.grpc.QueryStatus; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import java.time.Duration; +import java.util.List; +import java.util.UUID; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.assertj.core.api.AssertionsForClassTypes; +import org.grpcmock.junit5.InProcessGrpcMockExtension; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; + +@Slf4j +@ExtendWith(MockitoExtension.class) +@ExtendWith(InProcessGrpcMockExtension.class) +public class AdaptiveQueryStatusListenerTest extends HyperGrpcTestBase { + private final String query = "select * from stuff"; + private final QueryParam.TransferMode mode = QueryParam.TransferMode.ADAPTIVE; + + private final RealisticArrowGenerator.Student alice = new RealisticArrowGenerator.Student(1, "alice", 2); + private final RealisticArrowGenerator.Student bob = new RealisticArrowGenerator.Student(2, "bob", 3); + + private final List twoStudents = List.of(alice, bob); + + @SneakyThrows + @Test + void itWillWaitUntilResultsProducedOrFinishedToProduceStatus() { + val queryId = UUID.randomUUID().toString(); + setupExecuteQuery(queryId, query, mode, executeQueryResponse(queryId, null, 1)); + setupGetQueryInfo(queryId, QueryStatus.CompletionStatus.RESULTS_PRODUCED, 3); + val listener = sut(query); + 
QueryStatusListenerAssert.assertThat(listener).hasStatus(QueryStatus.CompletionStatus.RESULTS_PRODUCED.name()); + } + + @SneakyThrows + @Test + void itReturnsNoChunkResults() { + val queryId = UUID.randomUUID().toString(); + setupExecuteQuery( + queryId, + query, + mode, + executeQueryResponseWithData(twoStudents), + executeQueryResponse(queryId, QueryStatus.CompletionStatus.RESULTS_PRODUCED, 1)); + + val resultSet = sut(query).generateResultSet(); + + assertThat(resultSet).isNotNull(); + assertThat(resultSet.getQueryId()).isEqualTo(queryId); + assertThat(resultSet.isReady()).isTrue(); + + resultSet.next(); + assertThat(resultSet.getInt("id")).isEqualTo(alice.getId()); + assertThat(resultSet.getString("name")).isEqualTo(alice.getName()); + assertThat(resultSet.getDouble("grade")).isEqualTo(alice.getGrade()); + + resultSet.next(); + assertThat(resultSet.getInt("id")).isEqualTo(bob.getId()); + assertThat(resultSet.getString("name")).isEqualTo(bob.getName()); + assertThat(resultSet.getDouble("grade")).isEqualTo(bob.getGrade()); + } + + @SneakyThrows + @Test + void itIsAlwaysReadyBecauseWeImmediatelyGetResultsThenBlockForAsyncIfNecessary() { + val queryId = UUID.randomUUID().toString(); + setupExecuteQuery(queryId, query, mode, executeQueryResponse(queryId, null, 1)); + val listener = sut(query); + + assertThat(listener.isReady()).isTrue(); + } + + @SneakyThrows + @Test + void itCatchesAndMakesSqlExceptionWhenQueryFails() { + val client = Mockito.mock(HyperGrpcClientExecutor.class); + Mockito.when(client.executeAdaptiveQuery(Mockito.anyString())) + .thenThrow(new StatusRuntimeException(Status.ABORTED)); + + val ex = Assertions.assertThrows( + DataCloudJDBCException.class, () -> AdaptiveQueryStatusListener.of("any", client, Duration.ZERO)); + AssertionsForClassTypes.assertThat(ex) + .hasMessageContaining("Failed to execute query: ") + .hasRootCauseInstanceOf(StatusRuntimeException.class); + } + + @SneakyThrows + QueryStatusListener sut(String query) { + return 
AdaptiveQueryStatusListener.of(query, hyperGrpcClient, Duration.ofMinutes(1)); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusPollerTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusPollerTest.java new file mode 100644 index 0000000..10a221b --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/listener/AdaptiveQueryStatusPollerTest.java @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.listener; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.salesforce.datacloud.jdbc.core.HyperGrpcTestBase; +import com.salesforce.datacloud.jdbc.util.RealisticArrowGenerator; +import com.salesforce.hyperdb.grpc.ExecuteQueryResponse; +import com.salesforce.hyperdb.grpc.QueryInfo; +import com.salesforce.hyperdb.grpc.QueryStatus; +import java.util.UUID; +import lombok.val; +import org.junit.jupiter.api.Test; + +class AdaptiveQueryStatusPollerTest extends HyperGrpcTestBase { + @Test + void mapCapturesQueryStatus() { + val id = UUID.randomUUID().toString(); + val sut = new AdaptiveQueryStatusPoller(id, hyperGrpcClient); + assertThat(sut.pollQueryStatus()).isEqualTo(null); + sut.map(null); + assertThat(sut.pollQueryStatus()).isEqualTo(null); + val status = QueryStatus.newBuilder() + .setQueryId(UUID.randomUUID().toString()) + .setChunkCount(4L) + .setCompletionStatus(QueryStatus.CompletionStatus.RESULTS_PRODUCED) + .build(); + sut.map(ExecuteQueryResponse.newBuilder() + .setQueryInfo(QueryInfo.newBuilder().setQueryStatus(status)) + .build()); + assertThat(sut.pollQueryStatus()).isEqualTo(status); + sut.map(null); + assertThat(sut.pollQueryStatus()).isEqualTo(status); + } + + @Test + void mapCapturesChunkCount() { + val id = UUID.randomUUID().toString(); + val sut = new AdaptiveQueryStatusPoller(id, hyperGrpcClient); + setupGetQueryInfo(id, QueryStatus.CompletionStatus.RESULTS_PRODUCED, 3); + assertThat(sut.pollChunkCount()).isEqualTo(3L); + sut.map(null); + assertThat(sut.pollChunkCount()).isEqualTo(3L); + sut.map(ExecuteQueryResponse.newBuilder() + .setQueryInfo(QueryInfo.newBuilder() + .setQueryStatus(QueryStatus.newBuilder().setChunkCount(5L))) + .build()); + assertThat(sut.pollChunkCount()).isEqualTo(3L); + sut.map(null); + assertThat(sut.pollChunkCount()).isEqualTo(3L); + } + + @Test + void mapCapturesCompletionStatus() { + val id = UUID.randomUUID().toString(); + val sut = new 
AdaptiveQueryStatusPoller(id, hyperGrpcClient); + setupGetQueryInfo(id, QueryStatus.CompletionStatus.RESULTS_PRODUCED, 3); + assertThat(sut.pollChunkCount()).isEqualTo(3L); + } + + @Test + void pollChunkCountOnlyCallsQueryInfoIfUnfinished() { + val id = UUID.randomUUID().toString(); + val sut = new AdaptiveQueryStatusPoller(id, hyperGrpcClient); + sut.map(ExecuteQueryResponse.newBuilder() + .setQueryInfo(QueryInfo.newBuilder() + .setQueryStatus(QueryStatus.newBuilder() + .setCompletionStatus(QueryStatus.CompletionStatus.RUNNING) + .build()) + .build()) + .build()); + setupGetQueryInfo(id, QueryStatus.CompletionStatus.FINISHED, 3); + assertThat(sut.pollChunkCount()).isEqualTo(3L); + verifyGetQueryInfo(1); + } + + @Test + void pollChunkCountDoesNotCallQueryInfoIfFinished() { + val id = UUID.randomUUID().toString(); + val sut = new AdaptiveQueryStatusPoller(id, hyperGrpcClient); + sut.map(ExecuteQueryResponse.newBuilder() + .setQueryInfo(QueryInfo.newBuilder() + .setQueryStatus(QueryStatus.newBuilder() + .setCompletionStatus(QueryStatus.CompletionStatus.FINISHED) + .setChunkCount(5L) + .build()) + .build()) + .build()); + assertThat(sut.pollChunkCount()).isEqualTo(5L); + verifyGetQueryInfo(0); + } + + @Test + void mapExtractsQueryResult() { + val id = UUID.randomUUID().toString(); + val sut = new AdaptiveQueryStatusPoller(id, hyperGrpcClient); + val data = RealisticArrowGenerator.data(); + val actual = sut.map( + ExecuteQueryResponse.newBuilder().setQueryResult(data).build()) + .orElseThrow(); + assertThat(actual).isEqualTo(data); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusListenerTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusListenerTest.java new file mode 100644 index 0000000..f634fa9 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusListenerTest.java @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.listener; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.salesforce.datacloud.jdbc.core.DataCloudConnection; +import com.salesforce.datacloud.jdbc.core.DataCloudStatement; +import com.salesforce.datacloud.jdbc.core.HyperGrpcTestBase; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.RealisticArrowGenerator; +import com.salesforce.hyperdb.grpc.QueryParam; +import com.salesforce.hyperdb.grpc.QueryStatus; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Properties; +import java.util.Random; +import java.util.UUID; +import java.util.stream.Collectors; +import lombok.SneakyThrows; +import lombok.val; +import org.assertj.core.api.AssertionsForClassTypes; +import org.grpcmock.junit5.InProcessGrpcMockExtension; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mockito; + +@ExtendWith(InProcessGrpcMockExtension.class) +class AsyncQueryStatusListenerTest extends HyperGrpcTestBase { + private final String query = "select * from stuff"; + private final 
QueryParam.TransferMode mode = QueryParam.TransferMode.ASYNC; + + @ParameterizedTest + @ValueSource( + ints = { + QueryStatus.CompletionStatus.RUNNING_VALUE, + QueryStatus.CompletionStatus.RESULTS_PRODUCED_VALUE, + QueryStatus.CompletionStatus.FINISHED_VALUE + }) + void itCanGetStatus(int value) { + val status = QueryStatus.CompletionStatus.forNumber(value); + + val queryId = UUID.randomUUID().toString(); + setupExecuteQuery(queryId, query, mode); + val listener = sut(query); + + setupGetQueryInfo(queryId, status); + assertThat(listener.getStatus()) + .isEqualTo(Objects.requireNonNull(status).name()); + } + + @Test + void itThrowsIfStreamIsRequestedBeforeReady() { + val queryId = UUID.randomUUID().toString(); + + setupExecuteQuery(queryId, query, mode); + val listener = sut(query); + + setupGetQueryInfo(queryId, QueryStatus.CompletionStatus.RUNNING); + + QueryStatusListenerAssert.assertThat(listener).isNotReady(); + val ex = Assertions.assertThrows( + DataCloudJDBCException.class, () -> listener.stream().collect(Collectors.toList())); + AssertionsForClassTypes.assertThat(ex).hasMessageContaining(QueryStatusListener.BEFORE_READY); + } + + @SneakyThrows + @Test + void itCorrectlyReturnsStreamOfParts() { + val queryId = UUID.randomUUID().toString(); + setupExecuteQuery(queryId, query, mode); + setupGetQueryInfo(queryId, QueryStatus.CompletionStatus.FINISHED, 2); + val twoStudents = List.of( + new RealisticArrowGenerator.Student(1, "alice", 2), new RealisticArrowGenerator.Student(2, "bob", 3)); + setupGetQueryResult(queryId, 0, 2, twoStudents); + setupGetQueryResult(queryId, 1, 2, twoStudents); + + val iterator = sut(query).stream().iterator(); + + assertThat(iterator) + .as("Retrieves status under the hood so we now know there are two chunks on the server") + .hasNext(); + iterator.next(); + assertThat(iterator) + .as("There should be an additional QueryResult in memory and one chunk left on the server") + .hasNext(); + iterator.next(); + assertThat(iterator) + 
.as("We've exhausted our in-memory QueryResults but there's still a chunk left on the server") + .hasNext(); + iterator.next(); + assertThat(iterator).as("There's still one QueryResult left in memory").hasNext(); + iterator.next(); + assertThat(iterator) + .as("We've already used up all the QueryResults in memory and there are no remaining chunks") + .isExhausted(); + + Assertions.assertThrows(NoSuchElementException.class, iterator::next); + } + + @SneakyThrows + @Test + public void itCorrectlyReturnsResultSet() { + val random = new Random(10); + val queryId = UUID.randomUUID().toString(); + val studentId = random.nextInt(); + val studentGrade = random.nextDouble(); + val studentName = UUID.randomUUID().toString(); + setupHyperGrpcClientWithMockedResultSet(queryId, List.of()); + setupGetQueryInfo(queryId, QueryStatus.CompletionStatus.FINISHED); + setupGetQueryResult( + queryId, 0, 1, List.of(new RealisticArrowGenerator.Student(studentId, studentName, studentGrade))); + val resultSet = sut(query).generateResultSet(); + assertThat(resultSet).isNotNull(); + assertThat(resultSet.getQueryId()).isEqualTo(queryId); + assertThat(resultSet.isReady()).isTrue(); + + resultSet.next(); + assertThat(resultSet.getInt(1)).isEqualTo(studentId); + assertThat(resultSet.getString(2)).isEqualTo(studentName); + assertThat(resultSet.getDouble(3)).isEqualTo(studentGrade); + } + + @SneakyThrows + @Test + void userShouldExecuteQueryBeforeAccessingResultSet() { + try (val statement = statement()) { + val ex = assertThrows(DataCloudJDBCException.class, statement::getResultSet); + AssertionsForClassTypes.assertThat(ex) + .hasMessageContaining("a query was not executed before attempting to access results"); + } + } + + @SneakyThrows + @Test + void userShouldWaitForQueryBeforeAccessingResultSet() { + val queryId = UUID.randomUUID().toString(); + setupHyperGrpcClientWithMockedResultSet(queryId, List.of()); + setupGetQueryInfo(queryId, QueryStatus.CompletionStatus.RUNNING); + + try (val statement 
= statement().executeAsyncQuery(query)) { + val ex = assertThrows(DataCloudJDBCException.class, statement::getResultSet); + AssertionsForClassTypes.assertThat(ex).hasMessageContaining("query results were not ready"); + } + } + + DataCloudStatement statement() { + val connection = Mockito.mock(DataCloudConnection.class); + Mockito.when(connection.getProperties()).thenReturn(new Properties()); + Mockito.when(connection.getExecutor()).thenReturn(hyperGrpcClient); + + return new DataCloudStatement(connection); + } + + @SneakyThrows + QueryStatusListener sut(String query) { + return AsyncQueryStatusListener.of(query, hyperGrpcClient); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusPollerTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusPollerTest.java new file mode 100644 index 0000000..2497cec --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/listener/AsyncQueryStatusPollerTest.java @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.listener; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.salesforce.datacloud.jdbc.core.HyperGrpcTestBase; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.datacloud.jdbc.util.GrpcUtils; +import com.salesforce.hyperdb.grpc.HyperServiceGrpc; +import com.salesforce.hyperdb.grpc.QueryStatus; +import io.grpc.StatusRuntimeException; +import java.sql.SQLException; +import java.util.UUID; +import lombok.val; +import org.assertj.core.api.AssertionsForClassTypes; +import org.grpcmock.GrpcMock; +import org.grpcmock.junit5.InProcessGrpcMockExtension; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(InProcessGrpcMockExtension.class) +class AsyncQueryStatusPollerTest extends HyperGrpcTestBase { + @Test + void testPollTracksChunkCount() throws SQLException { + val id = UUID.randomUUID().toString(); + val poller = new AsyncQueryStatusPoller(id, hyperGrpcClient); + + setupGetQueryInfo(id, QueryStatus.CompletionStatus.FINISHED, 3); + + assertThat(poller.pollChunkCount()).isEqualTo(3); + } + + @Test + void testPollChunkCountThrowsWhenNotReady() { + val id = UUID.randomUUID().toString(); + val poller = new AsyncQueryStatusPoller(id, hyperGrpcClient); + + setupGetQueryInfo(id, QueryStatus.CompletionStatus.RUNNING); + + Assertions.assertThrows(DataCloudJDBCException.class, poller::pollChunkCount); + } + + @Test + void testPollIsReadyStopsFetchingAfterFinished() { + val id = UUID.randomUUID().toString(); + val poller = new AsyncQueryStatusPoller(id, hyperGrpcClient); + + setupGetQueryInfo(id, QueryStatus.CompletionStatus.FINISHED); + + assertThat(poller.pollIsReady()).isTrue(); + assertThat(poller.pollIsReady()).isTrue(); + 
assertThat(poller.pollQueryStatus().getCompletionStatus()).isEqualTo(QueryStatus.CompletionStatus.FINISHED); + assertThat(poller.pollQueryStatus().getCompletionStatus()).isEqualTo(QueryStatus.CompletionStatus.FINISHED); + + verifyGetQueryInfo(1); + } + + @Test + void testPollHappyPath() { + val id = UUID.randomUUID().toString(); + val poller = new AsyncQueryStatusPoller(id, hyperGrpcClient); + + setupGetQueryInfo(id, QueryStatus.CompletionStatus.RUNNING); + assertThat(poller.pollIsReady()).isFalse(); + setupGetQueryInfo(id, QueryStatus.CompletionStatus.RUNNING); + assertThat(poller.pollQueryStatus().getCompletionStatus()).isEqualTo(QueryStatus.CompletionStatus.RUNNING); + + setupGetQueryInfo(id, QueryStatus.CompletionStatus.RESULTS_PRODUCED); + assertThat(poller.pollIsReady()).isTrue(); + assertThat(poller.pollQueryStatus().getCompletionStatus()) + .isEqualTo(QueryStatus.CompletionStatus.RESULTS_PRODUCED); + + setupGetQueryInfo(id, QueryStatus.CompletionStatus.FINISHED); + assertThat(poller.pollIsReady()).isTrue(); + assertThat(poller.pollQueryStatus().getCompletionStatus()).isEqualTo(QueryStatus.CompletionStatus.FINISHED); + + verifyGetQueryInfo(5); + } + + @Test + void throwsDataCloudJDBCExceptionOnPollFailure() { + + val fakeException = GrpcUtils.getFakeStatusRuntimeExceptionAsInvalidArgument(); + GrpcMock.stubFor(GrpcMock.unaryMethod(HyperServiceGrpc.getGetQueryInfoMethod()) + .willReturn(GrpcMock.exception(fakeException))); + + val poller = new AsyncQueryStatusPoller(UUID.randomUUID().toString(), hyperGrpcClient); + + val ex = assertThrows(DataCloudJDBCException.class, poller::pollQueryStatus); + AssertionsForClassTypes.assertThat(ex) + .hasMessageContaining("Table not found") + .hasRootCauseInstanceOf(StatusRuntimeException.class); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListenerAssert.java b/src/test/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListenerAssert.java new file mode 100644 index 
0000000..c7b17e5 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListenerAssert.java @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.core.listener; + +import org.assertj.core.api.AbstractObjectAssert; +import org.assertj.core.util.Objects; + +/** {@link QueryStatusListener} specific assertions - Generated by CustomAssertionGenerator. */ +@javax.annotation.Generated(value = "assertj-assertions-generator") +public class QueryStatusListenerAssert extends AbstractObjectAssert { + + /** + * Creates a new {@link QueryStatusListenerAssert} to make assertions on actual QueryStatusListener. + * + * @param actual the QueryStatusListener we want to make assertions on. + */ + public QueryStatusListenerAssert(QueryStatusListener actual) { + super(actual, QueryStatusListenerAssert.class); + } + + /** + * An entry point for QueryStatusListenerAssert to follow AssertJ standard assertThat() statements.
+ * With a static import, one can write directly: assertThat(myQueryStatusListener) and get specific + * assertion with code completion. + * + * @param actual the QueryStatusListener we want to make assertions on. + * @return a new {@link QueryStatusListenerAssert} + */ + @org.assertj.core.util.CheckReturnValue + public static QueryStatusListenerAssert assertThat(QueryStatusListener actual) { + return new QueryStatusListenerAssert(actual); + } + + /** + * Verifies that the actual QueryStatusListener's query is equal to the given one. + * + * @param query the given query to compare the actual QueryStatusListener's query to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryStatusListener's query is not equal to the given one. + */ + public QueryStatusListenerAssert hasQuery(String query) { + // check that actual QueryStatusListener we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting query of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + String actualQuery = actual.getQuery(); + if (!Objects.areEqual(actualQuery, query)) { + failWithMessage(assertjErrorMessage, actual, query, actualQuery); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryStatusListener's queryId is equal to the given one. + * + * @param queryId the given queryId to compare the actual QueryStatusListener's queryId to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryStatusListener's queryId is not equal to the given one. + */ + public QueryStatusListenerAssert hasQueryId(String queryId) { + // check that actual QueryStatusListener we want to make assertions on is not null. 
+ isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting queryId of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + String actualQueryId = actual.getQueryId(); + if (!Objects.areEqual(actualQueryId, queryId)) { + failWithMessage(assertjErrorMessage, actual, queryId, actualQueryId); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryStatusListener is ready. + * + * @return this assertion object. + * @throws AssertionError - if the actual QueryStatusListener is not ready. + */ + public QueryStatusListenerAssert isReady() { + // check that actual QueryStatusListener we want to make assertions on is not null. + isNotNull(); + + // check that property call/field access is true + if (!actual.isReady()) { + failWithMessage("\nExpecting that actual QueryStatusListener is ready but is not."); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryStatusListener is not ready. + * + * @return this assertion object. + * @throws AssertionError - if the actual QueryStatusListener is ready. + */ + public QueryStatusListenerAssert isNotReady() { + // check that actual QueryStatusListener we want to make assertions on is not null. + isNotNull(); + + // check that property call/field access is false + if (actual.isReady()) { + failWithMessage("\nExpecting that actual QueryStatusListener is not ready but is."); + } + + // return the current assertion for method chaining + return this; + } + + /** + * Verifies that the actual QueryStatusListener's status is equal to the given one. + * + * @param status the given status to compare the actual QueryStatusListener's status to. + * @return this assertion object. + * @throws AssertionError - if the actual QueryStatusListener's status is not equal to the given one. 
+ */ + public QueryStatusListenerAssert hasStatus(String status) { + // check that actual QueryStatusListener we want to make assertions on is not null. + isNotNull(); + + // overrides the default error message with a more explicit one + String assertjErrorMessage = "\nExpecting status of:\n <%s>\nto be:\n <%s>\nbut was:\n <%s>"; + + // null safe check + String actualStatus = actual.getStatus(); + if (!Objects.areEqual(actualStatus, status)) { + failWithMessage(assertjErrorMessage, actual, status, actualStatus); + } + + // return the current assertion for method chaining + return this; + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListenerTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListenerTest.java new file mode 100644 index 0000000..a604732 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/listener/QueryStatusListenerTest.java @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.listener; + +import com.salesforce.datacloud.jdbc.core.HyperGrpcTestBase; +import com.salesforce.hyperdb.grpc.QueryParam; +import java.time.Duration; +import java.util.UUID; +import java.util.stream.Stream; +import lombok.SneakyThrows; +import lombok.val; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +class QueryStatusListenerTest extends HyperGrpcTestBase { + private final String query = "select * from stuff"; + + @SneakyThrows + private QueryStatusListener sut(String query, QueryParam.TransferMode mode) { + return sut(query, mode, Duration.ofMinutes(1)); + } + + @SneakyThrows + private QueryStatusListener sut(String query, QueryParam.TransferMode mode, Duration timeout) { + switch (mode) { + case SYNC: + return SyncQueryStatusListener.of(query, hyperGrpcClient); + case ASYNC: + return AsyncQueryStatusListener.of(query, hyperGrpcClient); + case ADAPTIVE: + return AdaptiveQueryStatusListener.of(query, hyperGrpcClient, timeout); + default: + Assertions.fail("QueryStatusListener mode not supported. 
mode=" + mode.name()); + return null; + } + } + + private static Stream supported() { + return Stream.of(QueryParam.TransferMode.SYNC, QueryParam.TransferMode.ASYNC, QueryParam.TransferMode.ADAPTIVE); + } + + @ParameterizedTest + @MethodSource("supported") + void itKeepsTrackOfQueryAndQueryId(QueryParam.TransferMode mode) { + val query = this.query + UUID.randomUUID(); + val queryId = UUID.randomUUID().toString(); + + setupExecuteQuery(queryId, query, mode); + val listener = sut(query, mode); + + QueryStatusListenerAssert.assertThat(listener).hasQueryId(queryId).hasQuery(query); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/core/listener/SyncQueryStatusListenerTest.java b/src/test/java/com/salesforce/datacloud/jdbc/core/listener/SyncQueryStatusListenerTest.java new file mode 100644 index 0000000..5353401 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/core/listener/SyncQueryStatusListenerTest.java @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.core.listener; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.salesforce.datacloud.jdbc.core.HyperGrpcTestBase; +import com.salesforce.hyperdb.grpc.ExecuteQueryResponse; +import com.salesforce.hyperdb.grpc.QueryInfo; +import com.salesforce.hyperdb.grpc.QueryParam; +import com.salesforce.hyperdb.grpc.QueryStatus; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import lombok.SneakyThrows; +import lombok.val; +import org.grpcmock.junit5.InProcessGrpcMockExtension; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(InProcessGrpcMockExtension.class) +class SyncQueryStatusListenerTest extends HyperGrpcTestBase { + private final String query = "select * from stuff"; + private final QueryParam.TransferMode mode = QueryParam.TransferMode.SYNC; + + @SneakyThrows + @Test + void isReady() { + setupExecuteQuery("", query, QueryParam.TransferMode.SYNC); + val listener = SyncQueryStatusListener.of(query, hyperGrpcClient); + QueryStatusListenerAssert.assertThat(listener).isReady(); + } + + @SneakyThrows + @Test + void getStatus() { + val expected = Stream.of(QueryStatus.CompletionStatus.values()) + .filter(cs -> cs != QueryStatus.CompletionStatus.UNRECOGNIZED) + .collect(Collectors.toList()); + val responses = expected.stream() + .map(s -> ExecuteQueryResponse.newBuilder() + .setQueryInfo(QueryInfo.newBuilder() + .setQueryStatus(QueryStatus.newBuilder() + .setCompletionStatus(s) + .build()) + .build()) + .build()) + .collect(Collectors.toUnmodifiableList()); + setupExecuteQuery("", query, QueryParam.TransferMode.SYNC, responses.toArray(ExecuteQueryResponse[]::new)); + val listener = SyncQueryStatusListener.of(query, hyperGrpcClient); + assertThat(listener.getStatus()).isNull(); + val iterator = listener.stream().iterator(); + + expected.forEach(exp -> { + iterator.next(); + 
QueryStatusListenerAssert.assertThat(listener).hasStatus(exp.name()); + }); + } + + @SneakyThrows + @Test + void getQuery() { + val randomQuery = this.query + UUID.randomUUID(); + val id = UUID.randomUUID().toString(); + setupExecuteQuery(id, randomQuery, QueryParam.TransferMode.SYNC); + val listener = SyncQueryStatusListener.of(randomQuery, hyperGrpcClient); + QueryStatusListenerAssert.assertThat(listener).hasQuery(randomQuery); + } + + @SneakyThrows + @Test + void getQueryId() { + val id = UUID.randomUUID().toString(); + setupExecuteQuery(id, query, QueryParam.TransferMode.SYNC); + val listener = SyncQueryStatusListener.of(query, hyperGrpcClient); + QueryStatusListenerAssert.assertThat(listener).hasQueryId(id); + } + + @SneakyThrows + @Test + void getResultSet() { + val id = UUID.randomUUID().toString(); + setupHyperGrpcClientWithMockedResultSet(id, List.of()); + + val listener = SyncQueryStatusListener.of(query, hyperGrpcClient); + val resultSet = listener.generateResultSet(); + assertThat(resultSet).isNotNull(); + assertThat(resultSet.getQueryId()).isEqualTo(id); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/exception/QueryExceptionHandlerTest.java b/src/test/java/com/salesforce/datacloud/jdbc/exception/QueryExceptionHandlerTest.java new file mode 100644 index 0000000..2639868 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/exception/QueryExceptionHandlerTest.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.exception; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +import com.salesforce.datacloud.jdbc.util.GrpcUtils; +import io.grpc.StatusRuntimeException; +import java.sql.SQLException; +import org.junit.jupiter.api.Test; + +class QueryExceptionHandlerTest { + + @Test + public void testCreateExceptionWithStatusRuntimeException() { + StatusRuntimeException fakeException = GrpcUtils.getFakeStatusRuntimeExceptionAsInvalidArgument(); + SQLException actualException = QueryExceptionHandler.createException("test message", fakeException); + + assertInstanceOf(SQLException.class, actualException); + assertEquals("42P01", actualException.getSQLState()); + assertEquals("42P01: Table not found\n" + "DETAIL:\n" + "\nHINT:\n", actualException.getMessage()); + assertEquals(StatusRuntimeException.class, actualException.getCause().getClass()); + } + + @Test + void testCreateExceptionWithGenericException() { + Exception mockException = new Exception("Generic exception"); + SQLException sqlException = QueryExceptionHandler.createException("Default message", mockException); + + assertEquals("Default message", sqlException.getMessage()); + assertEquals(mockException, sqlException.getCause()); + } + + @Test + void testCreateException() { + SQLException actualException = QueryExceptionHandler.createException("test message"); + + assertInstanceOf(SQLException.class, actualException); + assertEquals("test message", actualException.getMessage()); + } + + @Test + public void testCreateExceptionWithSQLStateAndThrowableCause() { + Exception mockException = new Exception("Generic exception"); + String mockSQLState = "42P01"; + SQLException sqlException = QueryExceptionHandler.createException("test message", mockSQLState, mockException); + + 
assertInstanceOf(SQLException.class, sqlException); + assertEquals("42P01", sqlException.getSQLState()); + assertEquals("test message", sqlException.getMessage()); + } + + @Test + public void testCreateExceptionWithSQLStateAndMessage() { + String mockSQLState = "42P01"; + SQLException sqlException = QueryExceptionHandler.createException("test message", mockSQLState); + + assertInstanceOf(SQLException.class, sqlException); + assertEquals("42P01", sqlException.getSQLState()); + assertEquals("test message", sqlException.getMessage()); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/http/ClientBuilderTest.java b/src/test/java/com/salesforce/datacloud/jdbc/http/ClientBuilderTest.java new file mode 100644 index 0000000..88501e1 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/http/ClientBuilderTest.java @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.http; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.salesforce.datacloud.jdbc.util.internal.SFDefaultSocketFactoryWrapper; +import java.util.Optional; +import java.util.Properties; +import java.util.Random; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.stream.Stream; +import lombok.SneakyThrows; +import lombok.val; +import okhttp3.OkHttpClient; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; + +public class ClientBuilderTest { + + static final Function buildClient = ClientBuilder::buildOkHttpClient; + Random random = new Random(10); + + @FunctionalInterface + interface OkHttpClientTimeout { + int getTimeout(OkHttpClient client); + } + + static Stream timeoutArguments() { + return Stream.of( + Arguments.of("readTimeOutSeconds", 600, (OkHttpClientTimeout) OkHttpClient::readTimeoutMillis), + Arguments.of("connectTimeOutSeconds", 600, (OkHttpClientTimeout) OkHttpClient::connectTimeoutMillis), + Arguments.of("callTimeOutSeconds", 600, (OkHttpClientTimeout) OkHttpClient::callTimeoutMillis)); + } + + @ParameterizedTest + @MethodSource("timeoutArguments") + void createsClientWithAppropriateTimeouts(String key, int defaultSeconds, OkHttpClientTimeout actual) { + val properties = new Properties(); + val none = buildClient.apply(properties); + assertThat(actual.getTimeout(none)).isEqualTo(defaultSeconds * 1000); + + val notDefaultSeconds = defaultSeconds + random.nextInt(12345); + properties.setProperty(key, Integer.toString(notDefaultSeconds)); + val some = buildClient.apply(properties); + assertThat(actual.getTimeout(some)).isEqualTo(notDefaultSeconds * 1000); + } + + @SneakyThrows + @Test + void createsClientWithSocketFactoryIfSocksProxyEnabled() { + val actual = 
new AtomicReference<>(Optional.empty()); + + try (val x = Mockito.mockConstruction( + SFDefaultSocketFactoryWrapper.class, + (mock, context) -> actual.set(Optional.of(context.arguments().get(0))))) { + val client = buildClient.apply(new Properties()); + assertThat(client).isNotNull(); + } + + assertThat(actual.get()).isPresent().isEqualTo(Optional.of(false)); + } + + @SneakyThrows + @Test + void createsClientWithSocketFactoryIfSocksProxyDisabled() { + val actual = new AtomicReference<>(Optional.empty()); + + val properties = new Properties(); + properties.put("disableSocksProxy", "true"); + + try (val x = Mockito.mockConstruction( + SFDefaultSocketFactoryWrapper.class, + (mock, context) -> actual.set(Optional.of(context.arguments().get(0))))) { + val client = buildClient.apply(properties); + assertThat(client).isNotNull(); + } + + assertThat(actual.get()).isPresent().isEqualTo(Optional.of(true)); + } + + @SneakyThrows + @Test + void createClientHasSomeDefaults() { + val client = buildClient.apply(new Properties()); + assertThat(client.retryOnConnectionFailure()).isTrue(); + assertThat(client.interceptors()).hasAtLeastOneElementOfType(MetadataCacheInterceptor.class); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/http/FormCommandTest.java b/src/test/java/com/salesforce/datacloud/jdbc/http/FormCommandTest.java new file mode 100644 index 0000000..0a57417 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/http/FormCommandTest.java @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.http; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.stream.Collectors; +import lombok.Builder; +import lombok.Data; +import lombok.SneakyThrows; +import lombok.extern.jackson.Jacksonized; +import lombok.val; +import okhttp3.HttpUrl; +import okhttp3.OkHttpClient; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.assertj.core.api.SoftAssertions; +import org.assertj.core.api.junit.jupiter.InjectSoftAssertions; +import org.assertj.core.api.junit.jupiter.SoftAssertionsExtension; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +@ExtendWith(SoftAssertionsExtension.class) +class FormCommandTest { + static URI VALID; + + static { + try { + VALID = new URI("https://localhost"); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + @InjectSoftAssertions + private SoftAssertions softly; + + private MockWebServer server; + private OkHttpClient client; + + @BeforeEach + @SneakyThrows + void beforeEach() { + this.server = new MockWebServer(); + this.server.start(); 
+ this.client = new OkHttpClient.Builder().build(); + } + + @SneakyThrows + @AfterEach + void afterEach() { + this.server.close(); + } + + @Test + void throwsOnNullClient() { + val ex = assertThrows( + IllegalArgumentException.class, + () -> FormCommand.post( + null, new FormCommand(VALID, new URI("/suffix"), Map.of(), Map.of(), Map.of()), Object.class)); + assertThat(ex).hasMessage("client is marked non-null but is null").hasNoCause(); + } + + @Test + void throwsOnNullCommand() { + val ex = assertThrows( + IllegalArgumentException.class, + () -> FormCommand.post(new OkHttpClient.Builder().build(), null, Object.class)); + assertThat(ex).hasMessage("command is marked non-null but is null").hasNoCause(); + } + + @Test + @SneakyThrows + void postProperlyMakesFormFromHttpCommandFormContent() { + val body = "{ \"numbers\": [1, 2, 3] }"; + val userAgent = UUID.randomUUID().toString(); + val a = UUID.randomUUID().toString(); + val b = UUID.randomUUID().toString(); + val expectedBody = Map.of("a", a, "b", b); + val command = new FormCommand( + new URI(server.url("").toString()), + new URI("foo"), + Map.of("User-Agent", userAgent), + expectedBody, + Map.of()); + + server.enqueue(new MockResponse().setBody(body)); + + val resp = FormCommand.post(client, command, FakeCommandResp.class); + + assertThat(resp.getNumbers()).hasSameElementsAs(List.of(1, 2, 3)); + val r = server.takeRequest(); + softly.assertThat(r.getRequestLine()).startsWith("POST"); + softly.assertThat(r.getPath()).as("path").isEqualTo("/foo"); + softly.assertThat(r.getHeader("Accept")).as("Accept header").isEqualTo("application/json"); + softly.assertThat(r.getHeader("Content-Type")) + .as("Content-Type header") + .isEqualTo("application/x-www-form-urlencoded"); + softly.assertThat(r.getHeader("User-Agent")).as("User-Agent header").isEqualTo(userAgent); + + val actualBody = getBody(r); + softly.assertThat(actualBody).containsExactlyInAnyOrderEntriesOf(expectedBody); + } + + @Test + @SneakyThrows + void 
getProperlyMakesFormFromHttpCommandFormContent() { + val body = "{ \"numbers\": [1, 2, 3] }"; + val userAgent = UUID.randomUUID().toString(); + val a = UUID.randomUUID().toString(); + val b = UUID.randomUUID().toString(); + val expectedQueryParams = Map.of("a", a, "b", b); + val command = new FormCommand( + new URI(server.url("").toString()), + new URI("foo"), + Map.of("User-Agent", userAgent), + Map.of(), + expectedQueryParams); + + server.enqueue(new MockResponse().setBody(body)); + + val resp = FormCommand.get(client, command, FakeCommandResp.class); + + assertThat(resp.getNumbers()).hasSameElementsAs(List.of(1, 2, 3)); + val r = server.takeRequest(); + val actualRequestUrl = r.getRequestUrl(); + softly.assertThat(actualRequestUrl.queryParameter("a")).isEqualTo(a); + softly.assertThat(actualRequestUrl.queryParameter("b")).isEqualTo(b); + val expectedUrl = HttpUrl.get(server.url("").toString()); + softly.assertThat(actualRequestUrl.scheme()).isEqualTo(expectedUrl.scheme()); + softly.assertThat(actualRequestUrl.host()).isEqualTo(expectedUrl.host()); + softly.assertThat(actualRequestUrl.port()).isEqualTo(expectedUrl.port()); + softly.assertThat(actualRequestUrl.pathSegments()).isEqualTo(List.of("foo")); + softly.assertThat(r.getRequestLine()).startsWith("GET"); + softly.assertThat(r.getHeader("Accept")).as("Accept header").isEqualTo("application/json"); + softly.assertThat(r.getHeader("Content-Type")) + .as("Content-Type header") + .isEqualTo("application/x-www-form-urlencoded"); + softly.assertThat(r.getHeader("User-Agent")).as("User-Agent header").isEqualTo(userAgent); + var actualBody = getBody(r); + softly.assertThat(actualBody).isEqualTo(Map.of()); + } + + private Map getBody(RecordedRequest request) { + return Arrays.stream(request.getBody().readUtf8().split("&")) + .map(p -> p.split("=")) + .filter(t -> t.length == 2) + .collect(Collectors.toMap(arr -> arr[0], arr -> arr[1])); + } + + @Data + @Builder + @Jacksonized + @JsonIgnoreProperties(ignoreUnknown = 
true) + static class FakeCommandResp { + private List numbers; + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/http/MetadataCacheInterceptorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/http/MetadataCacheInterceptorTest.java new file mode 100644 index 0000000..b2a8cee --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/http/MetadataCacheInterceptorTest.java @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.http; + +import static com.salesforce.datacloud.jdbc.ResponseEnum.EMPTY_RESPONSE; +import static com.salesforce.datacloud.jdbc.ResponseEnum.QUERY_RESPONSE; +import static com.salesforce.datacloud.jdbc.ResponseEnum.TABLE_METADATA; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import com.salesforce.datacloud.jdbc.ResponseEnum; +import com.salesforce.datacloud.jdbc.util.Constants; +import com.salesforce.datacloud.jdbc.util.MetadataCacheUtil; +import lombok.SneakyThrows; +import okhttp3.Interceptor; +import okhttp3.MediaType; +import okhttp3.Protocol; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import okhttp3.ResponseBody; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class MetadataCacheInterceptorTest { + + private Interceptor.Chain chain; + + private MetadataCacheInterceptor metadataCacheInterceptor; + + @BeforeEach + public void init() { + chain = mock(Interceptor.Chain.class); + metadataCacheInterceptor = new MetadataCacheInterceptor(); + doReturn(buildRequest()).when(chain).request(); + } + + @Test + @SneakyThrows + public void testMetadataRequestWithNoCachePresent() { + doReturn(buildResponse(200, EMPTY_RESPONSE)) + .doReturn(buildResponse(200, QUERY_RESPONSE)) + .when(chain) + .proceed(any(Request.class)); + metadataCacheInterceptor.intercept(chain); + verify(chain, times(1)).proceed(any(Request.class)); + } + + @Test + @SneakyThrows + public void testMetadataFromCache() { + MetadataCacheUtil.cacheMetadata( + "https://mjrgg9bzgy2dsyzvmjrgkmzzg1.c360a.salesforce.com" + Constants.CDP_URL + Constants.METADATA_URL, + TABLE_METADATA.getResponse()); + metadataCacheInterceptor.intercept(chain); + verify(chain, times(0)).proceed(any(Request.class)); + } + + private Request 
buildRequest() { + return new Request.Builder() + .url("https://mjrgg9bzgy2dsyzvmjrgkmzzg1.c360a.salesforce.com" + + Constants.CDP_URL + + Constants.METADATA_URL) + .method(Constants.POST, RequestBody.create("{test: test}", MediaType.parse("application/json"))) + .build(); + } + + private Response buildResponse(int statusCode, ResponseEnum responseEnum) { + String jsonString = responseEnum.getResponse(); + return new Response.Builder() + .code(statusCode) + .request(buildRequest()) + .protocol(Protocol.HTTP_1_1) + .message("Redirected") + .body(ResponseBody.create(jsonString, MediaType.parse("application/json"))) + .build(); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/hyper/HyperTestBase.java b/src/test/java/com/salesforce/datacloud/jdbc/hyper/HyperTestBase.java new file mode 100644 index 0000000..cbe5570 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/hyper/HyperTestBase.java @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.hyper; + +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; + +import com.salesforce.datacloud.jdbc.core.DataCloudConnection; +import com.salesforce.datacloud.jdbc.core.DataCloudStatement; +import com.salesforce.datacloud.jdbc.interceptor.AuthorizationHeaderInterceptor; +import io.grpc.ManagedChannelBuilder; +import java.io.BufferedReader; +import java.io.File; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.file.Paths; +import java.sql.ResultSet; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.assertj.core.api.ThrowingConsumer; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInstance; + +@Slf4j +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class HyperTestBase { + private static final String LISTENING = "gRPC listening on"; + + @SneakyThrows + public static void assertEachRowIsTheSame(ResultSet rs, AtomicInteger prev) { + val expected = prev.incrementAndGet(); + val a = rs.getBigDecimal(1).intValue(); + assertThat(expected).isEqualTo(a); + } + + @SafeVarargs + @SneakyThrows + public static void assertWithConnection( + ThrowingConsumer assertion, Map.Entry... settings) { + try (val connection = getHyperQueryConnection(settings == null ? 
Map.of() : Map.ofEntries(settings))) { + assertion.accept(connection); + } + } + + @SafeVarargs + @SneakyThrows + public static void assertWithStatement( + ThrowingConsumer assertion, Map.Entry... settings) { + try (val connection = getHyperQueryConnection(settings == null ? Map.of() : Map.ofEntries(settings)); + val result = connection.createStatement().unwrap(DataCloudStatement.class)) { + assertion.accept(result); + } + } + + public static DataCloudConnection getHyperQueryConnection() { + return getHyperQueryConnection(Map.of()); + } + + @SneakyThrows + public static DataCloudConnection getHyperQueryConnection(Map connectionSettings) { + + val properties = new Properties(); + properties.putAll(connectionSettings); + val auth = AuthorizationHeaderInterceptor.of(new NoopTokenSupplier()); + val channel = ManagedChannelBuilder.forAddress("localhost", 8181).usePlaintext(); + + return DataCloudConnection.fromTokenSupplier(auth, channel, properties); + } + + private static Process hyperProcess; + private static final ExecutorService hyperMonitors = Executors.newFixedThreadPool(2); + + public static boolean enabled() { + return hyperProcess != null && hyperProcess.isAlive(); + } + + @AfterAll + public void afterAll() { + try { + if (hyperProcess != null && hyperProcess.isAlive()) { + hyperProcess.destroy(); + } + } catch (Throwable e) { + log.error("Failed to destroy hyperd", e); + } + + try { + hyperMonitors.shutdown(); + } catch (Throwable e) { + log.error("Failed to shutdown hyper monitor thread pool", e); + } + } + + @SneakyThrows + @BeforeAll + public void beforeAll() { + if (hyperProcess != null) { + log.info("hyperd was started but not cleaned up?"); + return; + } else { + log.info("starting hyperd, this might take a few seconds"); + } + + val hyperd = new File("./target/hyper/hyperd"); + val properties = Paths.get(requireNonNull(HyperTestBase.class.getResource("/hyper.yaml")) + .toURI()) + .toFile(); + + if (!hyperd.exists()) { + Assertions.fail("hyperd 
executable couldn't be found, have you run mvn process-test-resources? expected=" + + hyperd.getAbsolutePath()); + } + + hyperProcess = new ProcessBuilder() + .command(hyperd.getAbsolutePath(), "--config", properties.getAbsolutePath(), "--no-password", "run") + .start(); + + val latch = new CountDownLatch(1); + + hyperMonitors.execute(() -> logStream(hyperProcess.getErrorStream(), log::error)); + hyperMonitors.execute(() -> logStream(hyperProcess.getInputStream(), line -> { + log.info(line); + if (line.contains(LISTENING)) { + latch.countDown(); + } + })); + + if (!latch.await(30, TimeUnit.SECONDS)) { + Assertions.fail("failed to start instance of hyper within 30 seconds"); + } + } + + @BeforeEach + public void assumeHyperEnabled() { + Assumptions.assumeTrue(enabled(), "Hyper wasn't started so skipping test"); + } + + static class NoopTokenSupplier implements AuthorizationHeaderInterceptor.TokenSupplier { + @Override + public String getToken() { + return ""; + } + } + + private static void logStream(InputStream inputStream, Consumer consumer) { + try (val reader = new BufferedReader(new BufferedReader(new InputStreamReader(inputStream)))) { + String line; + while ((line = reader.readLine()) != null) { + consumer.accept(line); + } + } catch (Exception e) { + log.error("Caught exception while consuming log stream", e); + } + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/interceptor/AuthorizationHeaderInterceptorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/interceptor/AuthorizationHeaderInterceptorTest.java new file mode 100644 index 0000000..f781040 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/interceptor/AuthorizationHeaderInterceptorTest.java @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.interceptor; + +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.salesforce.datacloud.jdbc.auth.DataCloudToken; +import com.salesforce.datacloud.jdbc.auth.TokenProcessor; +import io.grpc.Metadata; +import java.util.UUID; +import lombok.SneakyThrows; +import lombok.val; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class AuthorizationHeaderInterceptorTest { + private static final String AUTH = "Authorization"; + private static final String AUD = "audience"; + + private static final Metadata.Key AUTH_KEY = Metadata.Key.of(AUTH, ASCII_STRING_MARSHALLER); + private static final Metadata.Key AUD_KEY = Metadata.Key.of(AUD, ASCII_STRING_MARSHALLER); + + @Mock + private TokenProcessor mockTokenProcessor; + + @BeforeEach + void beforeEach() { + Mockito.reset(mockTokenProcessor); + } + + @SneakyThrows + @Test + void interceptorCallsGetDataCloudTokenTwice() { + val token = UUID.randomUUID().toString(); + val aud = UUID.randomUUID().toString(); + setupToken(token, aud); + setupToken(token, aud); + + val sut = sut(); + val metadata = new Metadata(); + + sut.mutate(metadata); + + 
assertThat(metadata.get(AUTH_KEY)).isEqualTo(token); + assertThat(metadata.get(AUD_KEY)).isEqualTo(aud); + } + + @SneakyThrows + @Test + void interceptorIgnoresNullAudience() { + setupToken("", null); + setupToken("", null); + + val sut = sut(); + val metadata = new Metadata(); + + sut.mutate(metadata); + + assertThat(metadata.get(AUD_KEY)).isNull(); + } + + private AuthorizationHeaderInterceptor sut() { + return AuthorizationHeaderInterceptor.of(mockTokenProcessor); + } + + @SneakyThrows + private void setupToken(String token, String aud) { + val newToken = mock(DataCloudToken.class); + lenient().when(newToken.getAccessToken()).thenReturn(token); + lenient().when(newToken.getTenantId()).thenReturn(aud); + + when(mockTokenProcessor.getDataCloudToken()).thenReturn(newToken); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/interceptor/DataspaceHeaderInterceptorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/interceptor/DataspaceHeaderInterceptorTest.java new file mode 100644 index 0000000..f74965e --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/interceptor/DataspaceHeaderInterceptorTest.java @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.interceptor; + +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; +import static org.assertj.core.api.Assertions.assertThat; + +import io.grpc.Metadata; +import java.util.Properties; +import java.util.UUID; +import lombok.val; +import org.junit.jupiter.api.Test; + +class DataspaceHeaderInterceptorTest { + @Test + void ofReturnsNullWithNoDataspace() { + assertThat(DataspaceHeaderInterceptor.of(new Properties())).isNull(); + } + + @Test + void appliesDataspaceValueToMetadata() { + val expected = UUID.randomUUID().toString(); + + val metadata = new Metadata(); + sut(expected).mutate(metadata); + + assertThat(metadata.get(Metadata.Key.of("dataspace", ASCII_STRING_MARSHALLER))) + .isEqualTo(expected); + } + + @Test + void hasNiceToString() { + val expected = UUID.randomUUID().toString(); + assertThat(sut(expected).toString()).contains("dataspace=" + expected); + } + + private static DataspaceHeaderInterceptor sut(String dataspace) { + val properties = new Properties(); + properties.put("dataspace", dataspace); + + return DataspaceHeaderInterceptor.of(properties); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/interceptor/HeaderMutatingClientInterceptorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/interceptor/HeaderMutatingClientInterceptorTest.java new file mode 100644 index 0000000..43272fe --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/interceptor/HeaderMutatingClientInterceptorTest.java @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.interceptor; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import com.salesforce.datacloud.jdbc.core.HyperGrpcTestBase; +import com.salesforce.datacloud.jdbc.exception.DataCloudJDBCException; +import com.salesforce.hyperdb.grpc.QueryParam; +import io.grpc.Metadata; +import java.util.List; +import java.util.UUID; +import java.util.function.Consumer; +import lombok.AllArgsConstructor; +import lombok.SneakyThrows; +import lombok.val; +import org.assertj.core.api.AssertionsForClassTypes; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class HeaderMutatingClientInterceptorTest extends HyperGrpcTestBase { + String query = UUID.randomUUID().toString(); + String queryId = UUID.randomUUID().toString(); + + @Test + @SneakyThrows + void interceptCallAlwaysCallsMutate() { + Consumer mockConsumer = mock(Consumer.class); + val sut = new Sut(mockConsumer); + + try (val client = hyperGrpcClient.toBuilder().interceptors(List.of(sut)).build()) { + setupExecuteQuery(queryId, query, QueryParam.TransferMode.SYNC); + client.executeQuery(query); + } + + val argumentCaptor = ArgumentCaptor.forClass(Metadata.class); + verify(mockConsumer).accept(argumentCaptor.capture()); + } + + @Test + 
void interceptCallCatchesMutateAndWrapsException() { + val message = UUID.randomUUID().toString(); + Consumer mockConsumer = mock(Consumer.class); + + doAnswer(invocation -> { + throw new RuntimeException(message); + }) + .when(mockConsumer) + .accept(any()); + + val sut = new Sut(mockConsumer); + + val ex = Assertions.assertThrows(DataCloudJDBCException.class, () -> { + try (val client = + hyperGrpcClient.toBuilder().interceptors(List.of(sut)).build()) { + setupExecuteQuery(queryId, query, QueryParam.TransferMode.SYNC); + client.executeQuery(query); + } + }); + + AssertionsForClassTypes.assertThat(ex) + .hasRootCauseMessage(message) + .hasMessage("Caught exception when mutating headers in client interceptor"); + + val argumentCaptor = ArgumentCaptor.forClass(Metadata.class); + verify(mockConsumer).accept(argumentCaptor.capture()); + } + + @AllArgsConstructor + static class Sut implements HeaderMutatingClientInterceptor { + private final Consumer headersConsumer; + + @Override + public void mutate(Metadata headers) { + headersConsumer.accept(headers); + } + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/interceptor/HyperDefaultsHeaderInterceptorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/interceptor/HyperDefaultsHeaderInterceptorTest.java new file mode 100644 index 0000000..86bc2db --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/interceptor/HyperDefaultsHeaderInterceptorTest.java @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.interceptor; + +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; +import static org.assertj.core.api.Assertions.assertThat; + +import io.grpc.Metadata; +import java.util.Objects; +import lombok.val; +import org.junit.jupiter.api.Test; + +class HyperDefaultsHeaderInterceptorTest { + + private static final HyperDefaultsHeaderInterceptor sut = new HyperDefaultsHeaderInterceptor(); + + @Test + void setsWorkload() { + val key = Metadata.Key.of("x-hyperdb-workload", ASCII_STRING_MARSHALLER); + assertThat(actual().get(key)).isEqualTo("jdbcv3"); + } + + @Test + void setsMaxSize() { + val key = Metadata.Key.of("grpc.max_metadata_size", ASCII_STRING_MARSHALLER); + assertThat(Integer.parseInt(Objects.requireNonNull(actual().get(key)))).isEqualTo(1024 * 1024); + } + + @Test + void hasNiceToString() { + assertThat(sut.toString()).isEqualTo("HyperDefaultsHeaderInterceptor()"); + } + + private static Metadata actual() { + val metadata = new Metadata(); + sut.mutate(metadata); + return metadata; + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/interceptor/QueryIdHeaderInterceptorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/interceptor/QueryIdHeaderInterceptorTest.java new file mode 100644 index 0000000..c5f921e --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/interceptor/QueryIdHeaderInterceptorTest.java @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.interceptor; + +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import io.grpc.Metadata; +import java.util.UUID; +import lombok.val; +import org.junit.jupiter.api.Test; + +class QueryIdHeaderInterceptorTest { + @Test + void appliesQueryIdToHeaders() { + val key = Metadata.Key.of("x-hyperdb-query-id", ASCII_STRING_MARSHALLER); + val id = UUID.randomUUID().toString(); + val interceptor = new QueryIdHeaderInterceptor(id); + val headers = new Metadata(); + interceptor.mutate(headers); + assertThat(headers.get(key)).isEqualTo(id); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/interceptor/TracingHeadersInterceptorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/interceptor/TracingHeadersInterceptorTest.java new file mode 100644 index 0000000..d6ba292 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/interceptor/TracingHeadersInterceptorTest.java @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.interceptor; + +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; +import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; + +import io.grpc.Metadata; +import lombok.val; +import org.junit.jupiter.api.Test; + +class TracingHeadersInterceptorTest { + private static final TracingHeadersInterceptor sut = TracingHeadersInterceptor.of(); + + private static final Metadata.Key trace = Metadata.Key.of("x-b3-traceid", ASCII_STRING_MARSHALLER); + private static final Metadata.Key span = Metadata.Key.of("x-b3-spanid", ASCII_STRING_MARSHALLER); + + @Test + void itAppliesIdsFromTracerToHeaders() { + val metadata = new Metadata(); + + sut.mutate(metadata); + + val traceA = metadata.get(trace); + val spanA = metadata.get(span); + + sut.mutate(metadata); + + val traceB = metadata.get(trace); + val spanB = metadata.get(span); + + assertThat(traceA).isNotBlank(); + assertThat(traceB).isNotBlank(); + assertThat(traceA).isEqualTo(traceB); + + assertThat(spanA).isNotBlank(); + assertThat(spanB).isNotBlank(); + assertThat(spanA).isNotEqualTo(spanB); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/interceptor/UserAgentHeaderInterceptorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/interceptor/UserAgentHeaderInterceptorTest.java new file mode 100644 index 0000000..7017bb0 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/interceptor/UserAgentHeaderInterceptorTest.java @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.interceptor; + +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; +import static org.assertj.core.api.Assertions.assertThat; + +import com.salesforce.datacloud.jdbc.config.DriverVersion; +import io.grpc.Metadata; +import java.util.Properties; +import java.util.UUID; +import lombok.val; +import org.junit.jupiter.api.Test; + +class UserAgentHeaderInterceptorTest { + @Test + void ofReturnsBasicWithNoUserAgent() { + val metadata = new Metadata(); + sut(null).mutate(metadata); + + assertThat(metadata.get(Metadata.Key.of("User-Agent", ASCII_STRING_MARSHALLER))) + .isEqualTo(DriverVersion.formatDriverInfo()); + } + + @Test + void noDuplicateDriverInfo() { + val metadata = new Metadata(); + sut(DriverVersion.formatDriverInfo()).mutate(metadata); + + assertThat(metadata.get(Metadata.Key.of("User-Agent", ASCII_STRING_MARSHALLER))) + .isEqualTo(DriverVersion.formatDriverInfo()); + } + + @Test + void appliesDataspaceValueToMetadata() { + val expected = UUID.randomUUID().toString(); + + val metadata = new Metadata(); + sut(expected).mutate(metadata); + + assertThat(metadata.get(Metadata.Key.of("User-Agent", ASCII_STRING_MARSHALLER))) + .isEqualTo(expected + " " + DriverVersion.formatDriverInfo()); + } + + private static UserAgentHeaderInterceptor sut(String userAgent) { + val properties = new Properties(); + + if (userAgent != null) { + properties.put("User-Agent", userAgent); + } + + return UserAgentHeaderInterceptor.of(properties); + } +} diff --git 
a/src/test/java/com/salesforce/datacloud/jdbc/internal/EncodingUtilsTest.java b/src/test/java/com/salesforce/datacloud/jdbc/internal/EncodingUtilsTest.java new file mode 100644 index 0000000..f65eae1 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/internal/EncodingUtilsTest.java @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.internal; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +class EncodingUtilsTest { + private static final long FIRST_LONG = 0x1213141516171819L; + private static final int LONG_BYTES = Long.SIZE / Byte.SIZE; + private static final int BYTE_BASE16 = 2; + private static final int LONG_BASE16 = BYTE_BASE16 * LONG_BYTES; + + private static final char[] FIRST_CHAR_ARRAY = + new char[] {'1', '2', '1', '3', '1', '4', '1', '5', '1', '6', '1', '7', '1', '8', '1', '9'}; + private static final long SECOND_LONG = 0xFFEEDDCCBBAA9988L; + private static final char[] SECOND_CHAR_ARRAY = + new char[] {'f', 'f', 'e', 'e', 'd', 'd', 'c', 'c', 'b', 'b', 'a', 'a', '9', '9', '8', '8'}; + private static final char[] BOTH_CHAR_ARRAY = new char[] { + '1', '2', '1', '3', '1', '4', '1', '5', '1', '6', '1', '7', '1', '8', '1', '9', 'f', 'f', 'e', 'e', 'd', 'd', + 'c', 'c', 'b', 'b', 'a', 'a', '9', '9', '8', '8' + }; + + @Test + void testLongToBase16String() { + char[] chars1 = new 
char[LONG_BASE16]; + EncodingUtils.longToBase16String(FIRST_LONG, chars1, 0); + assertThat(chars1).isEqualTo(FIRST_CHAR_ARRAY); + + char[] chars2 = new char[LONG_BASE16]; + EncodingUtils.longToBase16String(SECOND_LONG, chars2, 0); + assertThat(chars2).isEqualTo(SECOND_CHAR_ARRAY); + + char[] chars3 = new char[2 * LONG_BASE16]; + EncodingUtils.longToBase16String(FIRST_LONG, chars3, 0); + EncodingUtils.longToBase16String(SECOND_LONG, chars3, LONG_BASE16); + assertThat(chars3).isEqualTo(BOTH_CHAR_ARRAY); + } + + @Test + void testValidHex() { + assertThat(EncodingUtils.isValidBase16String("abcdef1234567890")).isTrue(); + assertThat(EncodingUtils.isValidBase16String("abcdefg1234567890")).isFalse(); + assertThat(EncodingUtils.isValidBase16String(" arrowFields = List.of(new Field("id", FieldType.nullable(new ArrowType.Utf8()), null)); + final List expectedColumnMetadata = + List.of(ColumnMetaData.fromProto(Common.ColumnMetaData.newBuilder() + .setColumnName("id") + .setOrdinal(0) + .build())); + List actualColumnMetadata = ArrowUtils.toColumnMetaData(arrowFields); + assertThat(actualColumnMetadata).hasSameSizeAs(expectedColumnMetadata); + val expected = expectedColumnMetadata.get(0); + val actual = actualColumnMetadata.get(0); + + softly.assertThat(actual.columnName).isEqualTo(expected.columnName); + softly.assertThat(actual.ordinal).isEqualTo(expected.ordinal); + softly.assertThat(actual.type.name) + .isEqualTo(JDBCType.valueOf(Types.VARCHAR).getName()); + } + + @Test + void testConvertArrowFieldsToColumnMetaDataTypes() { + Map> testCases = new HashMap<>(); + testCases.put( + JDBCType.valueOf(Types.TINYINT).getName(), + List.of(new Field("", FieldType.nullable(new ArrowType.Int(8, true)), null))); + testCases.put( + JDBCType.valueOf(Types.SMALLINT).getName(), + List.of(new Field("", FieldType.nullable(new ArrowType.Int(16, true)), null))); + testCases.put( + JDBCType.valueOf(Types.INTEGER).getName(), + List.of(new Field("", FieldType.nullable(new ArrowType.Int(32, 
true)), null))); + testCases.put( + JDBCType.valueOf(Types.BIGINT).getName(), + List.of(new Field("", FieldType.nullable(new ArrowType.Int(64, true)), null))); + testCases.put( + JDBCType.valueOf(Types.BOOLEAN).getName(), + List.of(new Field("", FieldType.nullable(new ArrowType.Bool()), null))); + testCases.put( + JDBCType.valueOf(Types.VARCHAR).getName(), + List.of(new Field("", FieldType.nullable(new ArrowType.Utf8()), null))); + testCases.put( + JDBCType.valueOf(Types.FLOAT).getName(), + List.of(new Field( + "", FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), null))); + testCases.put( + JDBCType.valueOf(Types.DOUBLE).getName(), + List.of(new Field( + "", FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null))); + testCases.put( + JDBCType.valueOf(Types.DATE).getName(), + List.of(new Field("", FieldType.nullable(new ArrowType.Date(DateUnit.DAY)), null))); + testCases.put( + JDBCType.valueOf(Types.TIME).getName(), + List.of(new Field("", FieldType.nullable(new ArrowType.Time(TimeUnit.MICROSECOND, 64)), null))); + testCases.put( + JDBCType.valueOf(Types.TIMESTAMP).getName(), + List.of(new Field("", FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC")), null))); + testCases.put( + JDBCType.valueOf(Types.DECIMAL).getName(), + List.of(new Field("", FieldType.nullable(new ArrowType.Decimal(1, 1, 128)), null))); + testCases.put( + JDBCType.valueOf(Types.ARRAY).getName(), + List.of(new Field("", FieldType.nullable(new ArrowType.List()), null))); + + for (var entry : testCases.entrySet()) { + List actual = ArrowUtils.toColumnMetaData(entry.getValue()); + softly.assertThat(actual.get(0).type.name).isEqualTo(entry.getKey()); + } + } + + private static Stream arrowTypes() { + return Stream.of( + Arguments.of(new ArrowType.Int(8, true), Types.TINYINT), + Arguments.of(new ArrowType.Int(16, true), Types.SMALLINT), + Arguments.of(new ArrowType.Int(32, true), Types.INTEGER), + Arguments.of(new 
ArrowType.Int(64, true), Types.BIGINT), + Arguments.of(new ArrowType.Bool(), Types.BOOLEAN), + Arguments.of(new ArrowType.Utf8(), Types.VARCHAR), + Arguments.of(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE), Types.FLOAT), + Arguments.of(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE), Types.DOUBLE), + Arguments.of(new ArrowType.LargeUtf8(), Types.LONGVARCHAR), + Arguments.of(new ArrowType.Binary(), Types.VARBINARY), + Arguments.of(new ArrowType.FixedSizeBinary(8), Types.BINARY), + Arguments.of(new ArrowType.LargeBinary(), Types.LONGVARBINARY), + Arguments.of(new ArrowType.Decimal(1, 1, 128), Types.DECIMAL), + Arguments.of(new ArrowType.Date(DateUnit.DAY), Types.DATE), + Arguments.of(new ArrowType.Time(TimeUnit.MICROSECOND, 64), Types.TIME), + Arguments.of(new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC"), Types.TIMESTAMP), + Arguments.of(new ArrowType.List(), Types.ARRAY), + Arguments.of(new ArrowType.LargeList(), Types.ARRAY), + Arguments.of(new ArrowType.FixedSizeList(1), Types.ARRAY), + Arguments.of(new ArrowType.Map(true), Types.JAVA_OBJECT), + Arguments.of(new ArrowType.Duration(TimeUnit.MICROSECOND), Types.JAVA_OBJECT), + Arguments.of(new ArrowType.Interval(IntervalUnit.DAY_TIME), Types.JAVA_OBJECT), + Arguments.of(new ArrowType.Struct(), Types.STRUCT), + Arguments.of(new ArrowType.Null(), Types.NULL)); + } + + @ParameterizedTest + @MethodSource("arrowTypes") + void testGetSQLTypeFromArrowTypes(ArrowType arrowType, int expectedSqlType) { + softly.assertThat(ArrowUtils.getSQLTypeFromArrowType(arrowType)).isEqualTo(expectedSqlType); + } + + @Test + void testConvertJDBCMetadataToAvaticaColumns() throws SQLException { + ResultSetMetaData resultSetMetaData = mockResultSetMetadata(); + List columnMetaDataList = ArrowUtils.convertJDBCMetadataToAvaticaColumns(resultSetMetaData, 7); + + for (int i = 0; i < columnMetaDataList.size(); i++) { + val actual = columnMetaDataList.get(i); + + 
softly.assertThat(actual.type.id).isEqualTo(resultSetMetaData.getColumnType(i + 1)); + softly.assertThat(actual.columnName).isEqualTo(resultSetMetaData.getColumnName(i + 1)); + softly.assertThat(actual.type.name).isEqualTo(resultSetMetaData.getColumnTypeName(i + 1)); + } + } + + @Test + void testConvertJDBCMetadataToAvaticaColumnsEmptyMetadata() { + assertThat(ArrowUtils.convertJDBCMetadataToAvaticaColumns(null, 7)).isEmpty(); + } + + private ResultSetMetaData mockResultSetMetadata() { + String[] columnNames = {"col1", "col2", "col3", "col4", "col5", "col6", "col7"}; + String[] columnTypes = { + "INTEGER", "VARCHAR", "DECIMAL", "TIMESTAMP WITH TIME ZONE", "ARRAY", "STRUCT", "MULTISET" + }; + Integer[] columnTypeIds = {4, 12, 3, 2013, 2003, 2002, 2003}; + + return new QueryResultSetMetadata( + Arrays.asList(columnNames), Arrays.asList(columnTypes), Arrays.asList(columnTypeIds)); + } + + @Test + void testCreateSchemaFromParametersValid() { + List parameterBindings = Arrays.asList( + new ParameterBinding(Types.VARCHAR, "string"), + new ParameterBinding(Types.INTEGER, 1), + new ParameterBinding(Types.BIGINT, 123456789L), + new ParameterBinding(Types.BOOLEAN, true), + new ParameterBinding(Types.TINYINT, (byte) 1), + new ParameterBinding(Types.SMALLINT, (short) 1), + new ParameterBinding(Types.DATE, new Date(1)), + new ParameterBinding(Types.TIME, new Time(1)), + new ParameterBinding(Types.TIMESTAMP, new Timestamp(1)), + new ParameterBinding(Types.DECIMAL, new BigDecimal("123.45")), + new ParameterBinding(Types.ARRAY, List.of(1, 2, 3))); + + Schema schema = ArrowUtils.createSchemaFromParameters(parameterBindings); + + Assertions.assertNotNull(schema); + assertEquals(11, schema.getFields().size()); + + Field field = schema.getFields().get(0); + assertEquals("1", field.getName()); + assertInstanceOf(ArrowType.Utf8.class, field.getType()); + + field = schema.getFields().get(1); + assertInstanceOf(ArrowType.Int.class, field.getType()); + ArrowType.Int intType = 
(ArrowType.Int) field.getType(); + assertEquals(32, intType.getBitWidth()); + + field = schema.getFields().get(2); + assertInstanceOf(ArrowType.Int.class, field.getType()); + ArrowType.Int longType = (ArrowType.Int) field.getType(); + assertEquals(64, longType.getBitWidth()); + + field = schema.getFields().get(3); + assertInstanceOf(ArrowType.Bool.class, field.getType()); + + field = schema.getFields().get(4); + assertInstanceOf(ArrowType.Int.class, field.getType()); + ArrowType.Int byteType = (ArrowType.Int) field.getType(); + assertEquals(8, byteType.getBitWidth()); + + field = schema.getFields().get(5); + assertInstanceOf(ArrowType.Int.class, field.getType()); + ArrowType.Int shortType = (ArrowType.Int) field.getType(); + assertEquals(16, shortType.getBitWidth()); + + field = schema.getFields().get(6); + assertInstanceOf(ArrowType.Date.class, field.getType()); + ArrowType.Date dateType = (ArrowType.Date) field.getType(); + assertEquals(DateUnit.DAY, dateType.getUnit()); + + field = schema.getFields().get(7); + assertInstanceOf(ArrowType.Time.class, field.getType()); + ArrowType.Time timeType = (ArrowType.Time) field.getType(); + assertEquals(TimeUnit.MICROSECOND, timeType.getUnit()); + assertEquals(64, timeType.getBitWidth()); + + field = schema.getFields().get(8); + assertInstanceOf(ArrowType.Timestamp.class, field.getType()); + ArrowType.Timestamp timestampType = (ArrowType.Timestamp) field.getType(); + assertEquals(TimeUnit.MICROSECOND, timestampType.getUnit()); + assertEquals("UTC", timestampType.getTimezone()); + + field = schema.getFields().get(9); + assertInstanceOf(ArrowType.Decimal.class, field.getType()); + ArrowType.Decimal decimalType = (ArrowType.Decimal) field.getType(); + assertEquals(5, decimalType.getPrecision()); + assertEquals(2, decimalType.getScale()); + + field = schema.getFields().get(10); + assertInstanceOf(ArrowType.List.class, field.getType()); + } +} diff --git 
a/src/test/java/com/salesforce/datacloud/jdbc/util/ConsumingPeekingIteratorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/util/ConsumingPeekingIteratorTest.java new file mode 100644 index 0000000..a741d86 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/util/ConsumingPeekingIteratorTest.java @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.util; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import lombok.val; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class ConsumingPeekingIteratorTest { + private final List expected = IntStream.rangeClosed(1, 20).boxed().collect(Collectors.toList()); + + private List> getData() { + return Stream.of( + List.of(1, 2, 3), + List.of(4, 5, 6, 7, 8, 9), + new ArrayList(), + new ArrayList(), + new ArrayList(), + new ArrayList(), + List.of(10, 11, 12, 13, 14, 15, 16, 17, 18, 19), + List.of(20), + new ArrayList()) + .map(ArrayDeque::new) + .collect(Collectors.toList()); + } + + ConsumingPeekingIterator> getSut() { + return ConsumingPeekingIterator.of(getData().stream(), t -> !t.isEmpty()); + } + + 
@Test + void returnsSameValueUntilCallerConsumesItem() { + val sut = getSut(); + assertThat(sut.next()).isSameAs(sut.next()); + } + + @Test + void consumesAllData() { + val sut = getSut(); + val actual = new ArrayList(); + while (sut.hasNext()) { + val current = sut.next(); + actual.add(current.removeFirst()); + } + + assertThat(actual).hasSameElementsAs(expected); + } + + @Test + void throwsWhenExhausted() { + val sut = getSut(); + while (sut.hasNext()) { + sut.next().removeFirst(); + } + + Assertions.assertThrows(NoSuchElementException.class, sut::next); + } + + @Test + void removeIsNotSupported() { + val sut = ConsumingPeekingIterator.of(Stream.empty(), t -> true); + Assertions.assertThrows(UnsupportedOperationException.class, sut::remove); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/util/DateTimeUtilsTest.java b/src/test/java/com/salesforce/datacloud/jdbc/util/DateTimeUtilsTest.java new file mode 100644 index 0000000..bfeb4e8 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/util/DateTimeUtilsTest.java @@ -0,0 +1,436 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.util; + +import static com.salesforce.datacloud.jdbc.util.DateTimeUtils.getUTCDateFromDateAndCalendar; +import static com.salesforce.datacloud.jdbc.util.DateTimeUtils.getUTCDateFromMilliseconds; +import static com.salesforce.datacloud.jdbc.util.DateTimeUtils.getUTCTimeFromMilliseconds; +import static com.salesforce.datacloud.jdbc.util.DateTimeUtils.getUTCTimeFromTimeAndCalendar; +import static com.salesforce.datacloud.jdbc.util.DateTimeUtils.getUTCTimestampFromTimestampAndCalendar; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.sql.Date; +import java.sql.Time; +import java.sql.Timestamp; +import java.time.Instant; +import java.time.LocalDateTime; +import java.time.ZoneId; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import java.util.Calendar; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; +import lombok.val; +import org.junit.jupiter.api.Test; + +public class DateTimeUtilsTest { + + private final long positiveEpochMilli = 959817600000L; // 2000-06-01 00:00:00 UTC + private final long negativeEpochMilli = -618105600000L; // 1950-06-01 00:00:00 UTC + + private final long zeroEpochMilli = 0L; // 1970-01-01 00:00:00 UTC + + private final long maxMillisecondsInDay = 86400000L - 1; // 23:59:59 + + private final long randomMillisecondsInDay = 44444444L; // 12:20:44 + + private final ZoneId UTCZoneId = TimeZone.getTimeZone("UTC").toZoneId(); + + @Test + void testShouldGetUTCDateFromEpochZero() { + // actual Date 1970, 01, 01 UTC + val instant = Instant.ofEpochMilli(zeroEpochMilli); + val zdt = ZonedDateTime.ofInstant(instant, UTCZoneId); + + val expectedYear = zdt.getYear(); + val expectedMonth = zdt.getMonthValue(); + val expectedDay = zdt.getDayOfMonth(); + + val actual = getUTCDateFromMilliseconds(zeroEpochMilli); + + 
assertThat(actual).hasYear(expectedYear).hasMonth(expectedMonth).hasDayOfMonth(expectedDay); + } + + @Test + void testShouldGetUTCDateFromPositiveEpoch() { + // actual Date 2000-06-01 00:00:00 UTC + val instant = Instant.ofEpochMilli(positiveEpochMilli); + val zdt = ZonedDateTime.ofInstant(instant, UTCZoneId); + + val expectedYear = zdt.getYear(); + val expectedMonth = zdt.getMonthValue(); + val expectedDay = zdt.getDayOfMonth(); + + val actual = getUTCDateFromMilliseconds(positiveEpochMilli); + + assertThat(actual).hasYear(expectedYear).hasMonth(expectedMonth).hasDayOfMonth(expectedDay); + } + + @Test + void testShouldGetUTCDateFromNegativeEpoch() { + // actual Date 1950-06-01 00:00:00 UTC + val instant = Instant.ofEpochMilli(negativeEpochMilli); + val zdt = ZonedDateTime.ofInstant(instant, UTCZoneId); + + val expectedYear = zdt.getYear(); + val expectedMonth = zdt.getMonthValue(); + val expectedDay = zdt.getDayOfMonth(); + + val actual = getUTCDateFromMilliseconds(negativeEpochMilli); + + assertThat(actual).hasYear(expectedYear).hasMonth(expectedMonth).hasDayOfMonth(expectedDay); + } + + @Test + void testShouldGetUTCTimeFromEpochZero() { + // actual Time 00:00:00.000 UTC + val instant = Instant.ofEpochMilli(zeroEpochMilli); + val zdt = ZonedDateTime.ofInstant(instant, UTCZoneId); + + val expectedHour = zdt.getHour(); + val expectedMinute = zdt.getMinute(); + val expectedSeconds = zdt.getSecond(); + val expectedMillis = (int) TimeUnit.NANOSECONDS.toMillis(zdt.getNano()); + + val actual = getUTCTimeFromMilliseconds(zeroEpochMilli); + + assertThat(actual) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSeconds) + .hasMillisecond(expectedMillis); + } + + @Test + void testShouldGetUTCTimeFromMaxMillisecondsInDay() { + // actual Time 23:59:59.999 UTC + val instant = Instant.ofEpochMilli(maxMillisecondsInDay); + val zdt = ZonedDateTime.ofInstant(instant, UTCZoneId); + + val expectedHour = zdt.getHour(); + val expectedMinute = 
zdt.getMinute(); + val expectedSeconds = zdt.getSecond(); + val expectedMillis = (int) TimeUnit.NANOSECONDS.toMillis(zdt.getNano()); + + val actual = getUTCTimeFromMilliseconds(maxMillisecondsInDay); + + assertThat(actual) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSeconds) + .hasMillisecond(expectedMillis); + } + + @Test + void testShouldGetUTCTimeFromRandomMillisecondsInDay() { + // actual Time 12:20:44 UTC + val instant = Instant.ofEpochMilli(randomMillisecondsInDay); + val zdt = ZonedDateTime.ofInstant(instant, UTCZoneId); + + val expectedHour = zdt.getHour(); + val expectedMinute = zdt.getMinute(); + val expectedSeconds = zdt.getSecond(); + val expectedMillis = (int) TimeUnit.NANOSECONDS.toMillis(zdt.getNano()); + + val actual = getUTCTimeFromMilliseconds(randomMillisecondsInDay); + + assertThat(actual) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSeconds) + .hasMillisecond(expectedMillis); + } + + @Test + void testShouldGetUTCDateFromEpochZeroDifferentCalendar() { + // UTC Date would be 1969-12-31 + TimeZone timeZone = TimeZone.getTimeZone("GMT+2"); + Calendar calendar = Calendar.getInstance(timeZone); + + val date = Date.valueOf("1970-01-01"); + val time = date.getTime(); + + val dateTime = new Timestamp(time).toLocalDateTime(); + + val zonedDateTime = dateTime.atZone(timeZone.toZoneId()); + val convertedDateTime = zonedDateTime.withZoneSameInstant(UTCZoneId); + val finalDateTime = convertedDateTime.toLocalDateTime(); + + val expectedYear = finalDateTime.getYear(); + val expectedMonth = finalDateTime.getMonthValue(); + val expectedDay = finalDateTime.getDayOfMonth(); + + val actual = getUTCDateFromDateAndCalendar(date, calendar); + + assertThat(actual).hasYear(expectedYear).hasMonth(expectedMonth).hasDayOfMonth(expectedDay); + } + + @Test + void testShouldGetUTCDateFromEpochZeroSameCalendar() { + // UTC Date would be 1970-01-01 + TimeZone timeZone = TimeZone.getTimeZone("UTC"); + 
Calendar calendar = Calendar.getInstance(timeZone); + + val date = Date.valueOf("1970-01-01"); + val time = date.getTime(); + + val dateTime = new Timestamp(time).toLocalDateTime(); + + val zonedDateTime = dateTime.atZone(timeZone.toZoneId()); + val convertedDateTime = zonedDateTime.withZoneSameInstant(UTCZoneId); + val finalDateTime = convertedDateTime.toLocalDateTime(); + + val expectedYear = finalDateTime.getYear(); + val expectedMonth = finalDateTime.getMonthValue(); + val expectedDay = finalDateTime.getDayOfMonth(); + + val actual = getUTCDateFromDateAndCalendar(date, calendar); + + assertThat(actual).hasYear(expectedYear).hasMonth(expectedMonth).hasDayOfMonth(expectedDay); + } + + @Test + void testShouldGetUTCDateFromEpochZeroNullCalendar() { + // UTC Date would be 1970-01-01 + TimeZone defaultTz = TimeZone.getDefault(); + TimeZone timeZone = TimeZone.getTimeZone("GMT-8"); + TimeZone.setDefault(timeZone); + + val date = Date.valueOf("1970-01-01"); + val time = date.getTime(); + + val dateTime = new Timestamp(time).toLocalDateTime(); + + val zonedDateTime = dateTime.atZone(timeZone.toZoneId()); + val convertedDateTime = zonedDateTime.withZoneSameInstant(UTCZoneId); + val finalDateTime = convertedDateTime.toLocalDateTime(); + + val expectedYear = finalDateTime.getYear(); + val expectedMonth = finalDateTime.getMonthValue(); + val expectedDay = finalDateTime.getDayOfMonth(); + + val actual = getUTCDateFromDateAndCalendar(date, null); + + assertThat(actual).hasYear(expectedYear).hasMonth(expectedMonth).hasDayOfMonth(expectedDay); + + TimeZone.setDefault(defaultTz); + } + + @Test + void testShouldGetUTCDateFromNegativeEpochGivenCalendar() { + // UTC Date would be 1949-12-31 22:00 + TimeZone timeZone = TimeZone.getTimeZone("GMT+2"); + Calendar calendar = Calendar.getInstance(timeZone); + + val date = Date.valueOf("1950-01-01"); + val time = date.getTime(); + + val dateTime = new Timestamp(time).toLocalDateTime(); + + val zonedDateTime = 
dateTime.atZone(timeZone.toZoneId()); + val convertedDateTime = zonedDateTime.withZoneSameInstant(UTCZoneId); + val finalDateTime = convertedDateTime.toLocalDateTime(); + + val expectedYear = finalDateTime.getYear(); + val expectedMonth = finalDateTime.getMonthValue(); + val expectedDay = finalDateTime.getDayOfMonth(); + + val actual = getUTCDateFromDateAndCalendar(date, calendar); + + assertThat(actual).hasYear(expectedYear).hasMonth(expectedMonth).hasDayOfMonth(expectedDay); + } + + @Test + void testShouldGetUTCTimeFromEpochZeroDifferentCalendar() { + // UTC Time would be 22:00:00 + TimeZone timeZone = TimeZone.getTimeZone("GMT+2"); + Calendar calendar = Calendar.getInstance(timeZone); + + val time = Time.valueOf("00:00:00"); + + val actual = getUTCTimeFromTimeAndCalendar(time, calendar); + + assertThat(actual).hasHourOfDay(22).hasMinute(0).hasSecond(0).hasMillisecond(0); + } + + @Test + void testShouldGetUTCTimeFromEpochZeroSameCalendar() { + // UTC Time would be 00:00:00 + TimeZone timeZone = TimeZone.getTimeZone("UTC"); + Calendar calendar = Calendar.getInstance(timeZone); + + val time = Time.valueOf("00:00:00"); + + val actual = getUTCTimeFromTimeAndCalendar(time, calendar); + + assertThat(actual).hasHourOfDay(0).hasMinute(0).hasSecond(0).hasMillisecond(0); + } + + @Test + void testShouldGetUTCTimeFromEpochZeroNullCalendar() { + // UTC Time would be 08:00:00 + TimeZone defaultTz = TimeZone.getDefault(); + TimeZone timeZone = TimeZone.getTimeZone("GMT-8"); + TimeZone.setDefault(timeZone); + + val time = Time.valueOf("00:00:00"); + + val actual = getUTCTimeFromTimeAndCalendar(time, null); + + assertThat(actual).hasHourOfDay(8).hasMinute(0).hasSecond(0).hasMillisecond(0); + + TimeZone.setDefault(defaultTz); + } + + @Test + void testShouldGetUTCTimestampFromEpochZeroDifferentCalendar() { + // UTC Timestamp would be 1969-12-31 22:00:00.000000000 + TimeZone plusTwoTimeZone = TimeZone.getTimeZone("GMT+2"); + Calendar calendar = Calendar.getInstance(plusTwoTimeZone); 
+ + val timestamp = Timestamp.valueOf("1970-01-01 00:00:00.000000000"); + + val dateTime = timestamp.toLocalDateTime(); + + val zonedDateTime = dateTime.atZone(plusTwoTimeZone.toZoneId()); + val convertedDateTime = zonedDateTime.withZoneSameInstant(UTCZoneId); + val finalDateTime = convertedDateTime.toLocalDateTime(); + + val expectedYear = finalDateTime.getYear(); + val expectedMonth = finalDateTime.getMonthValue(); + val expectedDay = finalDateTime.getDayOfMonth(); + val expectedHour = finalDateTime.getHour(); + val expectedMinute = finalDateTime.getMinute(); + val expectedSecond = finalDateTime.getSecond(); + val expectedMillis = (int) TimeUnit.NANOSECONDS.toMillis(finalDateTime.getNano()); + + val actual = getUTCTimestampFromTimestampAndCalendar(timestamp, calendar); + + assertThat(actual) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMillis); + } + + @Test + void testShouldGetUTCTimestampFromEpochZeroSameCalendar() { + // UTC Timestamp would be 1970-01-01 00:00:00.000000000 + TimeZone UTCTimeZone = TimeZone.getTimeZone("UTC"); + Calendar calendar = Calendar.getInstance(UTCTimeZone); + + val timestamp = Timestamp.valueOf("1970-01-01 00:00:00.000000000"); + + val dateTime = timestamp.toLocalDateTime(); + + val zonedDateTime = dateTime.atZone(UTCTimeZone.toZoneId()); + val convertedDateTime = zonedDateTime.withZoneSameInstant(UTCZoneId); + val finalDateTime = convertedDateTime.toLocalDateTime(); + + val expectedYear = finalDateTime.getYear(); + val expectedMonth = finalDateTime.getMonthValue(); + val expectedDay = finalDateTime.getDayOfMonth(); + val expectedHour = finalDateTime.getHour(); + val expectedMinute = finalDateTime.getMinute(); + val expectedSecond = finalDateTime.getSecond(); + val expectedMillis = (int) TimeUnit.NANOSECONDS.toMillis(finalDateTime.getNano()); + + val actual = 
getUTCTimestampFromTimestampAndCalendar(timestamp, calendar); + + assertThat(actual) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMillis); + } + + @Test + void testShouldGetUTCTimestampFromEpochZeroNullCalendar() { + // UTC Time would be 1970-01-01 08:00:00.000000000 + TimeZone defaultTz = TimeZone.getDefault(); + TimeZone minusEightTimeZone = TimeZone.getTimeZone("GMT-8"); + TimeZone.setDefault(minusEightTimeZone); + + val timestamp = Timestamp.valueOf("1970-01-01 00:00:00.000000000"); + + val dateTime = timestamp.toLocalDateTime(); + + val zonedDateTime = dateTime.atZone(minusEightTimeZone.toZoneId()); + val convertedDateTime = zonedDateTime.withZoneSameInstant(UTCZoneId); + val finalDateTime = convertedDateTime.toLocalDateTime(); + + val expectedYear = finalDateTime.getYear(); + val expectedMonth = finalDateTime.getMonthValue(); + val expectedDay = finalDateTime.getDayOfMonth(); + val expectedHour = finalDateTime.getHour(); + val expectedMinute = finalDateTime.getMinute(); + val expectedSecond = finalDateTime.getSecond(); + val expectedMillis = (int) TimeUnit.NANOSECONDS.toMillis(finalDateTime.getNano()); + + val actual = getUTCTimestampFromTimestampAndCalendar(timestamp, null); + + assertThat(actual) + .hasYear(expectedYear) + .hasMonth(expectedMonth) + .hasDayOfMonth(expectedDay) + .hasHourOfDay(expectedHour) + .hasMinute(expectedMinute) + .hasSecond(expectedSecond) + .hasMillisecond(expectedMillis); + + TimeZone.setDefault(defaultTz); + } + + @Test + void testMillisToMicrosecondsSinceMidnight() { + long millisSinceMidnight = 3600000; // 1 hour in milliseconds + long expectedMicroseconds = millisSinceMidnight * 1000; + assertEquals(expectedMicroseconds, DateTimeUtils.millisToMicrosecondsSinceMidnight(millisSinceMidnight)); + } + + @Test + void testLocalDateTimeToMicrosecondsSinceEpoch() { + LocalDateTime 
localDateTime = LocalDateTime.of(2024, 1, 1, 0, 0, 0); + long expectedMicroseconds = localDateTime.toInstant(ZoneOffset.UTC).toEpochMilli() * 1000; + assertEquals(expectedMicroseconds, DateTimeUtils.localDateTimeToMicrosecondsSinceEpoch(localDateTime)); + } + + @Test + void testAdjustForCalendar() { + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("America/New_York")); + LocalDateTime localDateTime = LocalDateTime.of(2024, 1, 1, 0, 0, 0); + LocalDateTime expectedAdjusted = localDateTime.plusHours(-5); // NYC is UTC-5 + + LocalDateTime adjusted = DateTimeUtils.adjustForCalendar(localDateTime, calendar, TimeZone.getTimeZone("UTC")); + + assertEquals(expectedAdjusted, adjusted); + } + + @Test + void testAdjustForCalendarWithNullCalendar() { + LocalDateTime localDateTime = LocalDateTime.of(2024, 1, 1, 0, 0, 0); + LocalDateTime adjusted = DateTimeUtils.adjustForCalendar(localDateTime, null, TimeZone.getTimeZone("UTC")); + + assertEquals(localDateTime, adjusted); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/util/GrpcUtils.java b/src/test/java/com/salesforce/datacloud/jdbc/util/GrpcUtils.java new file mode 100644 index 0000000..c8555a6 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/util/GrpcUtils.java @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.util; + +import com.google.protobuf.Any; +import com.salesforce.hyperdb.grpc.ErrorInfo; +import io.grpc.Metadata; +import io.grpc.StatusRuntimeException; +import io.grpc.protobuf.StatusProto; +import lombok.experimental.UtilityClass; + +@UtilityClass +public class GrpcUtils { + private static final com.google.rpc.Status rpcStatus = com.google.rpc.Status.newBuilder() + .setCode(io.grpc.Status.INVALID_ARGUMENT.getCode().value()) + .setMessage("Resource Not Found") + .addDetails(Any.pack(ErrorInfo.newBuilder() + .setSqlstate("42P01") + .setPrimaryMessage("Table not found") + .build())) + .build(); + + private static final Metadata.Key metaDataKey = + Metadata.Key.of("test-metadata", Metadata.ASCII_STRING_MARSHALLER); + private static final String metaDataValue = "test metadata value"; + + public static Metadata getFakeMetaData() { + Metadata metadata = new Metadata(); + metadata.put(metaDataKey, metaDataValue); + return metadata; + } + + public static StatusRuntimeException getFakeStatusRuntimeExceptionAsInvalidArgument() { + return StatusProto.toStatusRuntimeException(rpcStatus, getFakeMetaData()); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/util/PropertiesExtensionsTest.java b/src/test/java/com/salesforce/datacloud/jdbc/util/PropertiesExtensionsTest.java new file mode 100644 index 0000000..19b64b1 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/util/PropertiesExtensionsTest.java @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.util; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.Locale; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Stream; +import lombok.val; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +class PropertiesExtensionsTest { + private final String key = UUID.randomUUID().toString(); + + @Test + void optionalValidKeyAndValue() { + val expected = UUID.randomUUID().toString(); + val p = new Properties(); + p.put(key, expected); + + val some = PropertiesExtensions.optional(p, key); + assertThat(some).isPresent().contains(expected); + } + + @Test + void optionalNotPresentKey() { + val none = PropertiesExtensions.optional(new Properties(), "key"); + assertThat(none).isNotPresent(); + } + + @Test + void optionalNotPresentOnNullProperties() { + assertThat(PropertiesExtensions.optional(null, "key")).isNotPresent(); + } + + @ParameterizedTest + @ValueSource(strings = {" ", "\t", "\n"}) + void optionalEmptyOnIllegalValue(String input) { + val p = new Properties(); + p.put(key, input); + + val none = PropertiesExtensions.optional(p, UUID.randomUUID().toString()); + assertThat(none).isNotPresent(); + } + + @Test + void requiredValidKeyAndValue() { + val expected = UUID.randomUUID().toString(); + val p = new Properties(); + p.put(key, expected); + + val some = 
PropertiesExtensions.required(p, key); + assertThat(some).isEqualTo(expected); + } + + @ParameterizedTest + @ValueSource(strings = {" ", "\t", "\n"}) + void requiredThrowsOnBadValue(String input) { + val p = new Properties(); + p.put(key, input); + + val e = assertThrows(IllegalArgumentException.class, () -> PropertiesExtensions.required(p, key)); + assertThat(e).hasMessage(PropertiesExtensions.Messages.REQUIRED_MISSING_PREFIX + key); + } + + @Test + void copy() { + val included = Set.of("a", "b", "c", "d", "e"); + val excluded = Set.of("1", "2", "3", "4", "5"); + + val p = new Properties(); + Stream.concat(included.stream(), excluded.stream()).forEach(k -> p.put(k, k.toUpperCase(Locale.ROOT))); + + val actual = PropertiesExtensions.copy(p, included); + + assertThat(actual).containsExactlyInAnyOrderEntriesOf(Map.of("a", "A", "b", "B", "c", "C", "d", "D", "e", "E")); + } + + @Test + void toIntegerOrNull() { + assertThat(PropertiesExtensions.toIntegerOrNull("123")).isEqualTo(123); + assertThat(PropertiesExtensions.toIntegerOrNull("asdfasdf")).isNull(); + } + + @Test + void getBooleanOrDefaultKeyExistsValidInvalidValues() { + Properties properties = new Properties(); + properties.setProperty("myKeyTrue", "true"); + Boolean resultTrue = PropertiesExtensions.getBooleanOrDefault(properties, "myKeyTrue", false); + assertThat(resultTrue).isEqualTo(true); + + properties.setProperty("myKeyFalse", "false"); + Boolean resultFalse = PropertiesExtensions.getBooleanOrDefault(properties, "myKeyFalse", true); + assertThat(resultFalse).isEqualTo(false); + + properties.setProperty("myKeyEmpty", ""); + Boolean resultEmpty = PropertiesExtensions.getBooleanOrDefault(properties, "myKeyEmpty", false); + assertThat(resultEmpty).isEqualTo(false); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/util/RealisticArrowGenerator.java b/src/test/java/com/salesforce/datacloud/jdbc/util/RealisticArrowGenerator.java new file mode 100644 index 0000000..83ef758 --- /dev/null +++ 
b/src/test/java/com/salesforce/datacloud/jdbc/util/RealisticArrowGenerator.java @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.util; + +import com.google.protobuf.ByteString; +import com.salesforce.hyperdb.grpc.QueryResult; +import com.salesforce.hyperdb.grpc.QueryResultPartBinary; +import java.io.ByteArrayOutputStream; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Random; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import lombok.AllArgsConstructor; +import lombok.SneakyThrows; +import lombok.Value; +import lombok.experimental.UtilityClass; +import lombok.val; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; + +@UtilityClass +public class RealisticArrowGenerator { + private static final Random random = new Random(10); + + public static QueryResult data() { + val student = new 
Student(random.nextInt(), "Somebody", random.nextDouble()); + return getMockedData(List.of(student)).findFirst().orElse(null); + } + + @Value + @AllArgsConstructor + public static class Student { + int id; + String name; + double grade; + } + + public static Stream getMockedData(List students) { + val qr = QueryResult.newBuilder() + .setBinaryPart(QueryResultPartBinary.newBuilder() + .setData(convertStudentsToArrowBinary(students)) + .build()) + .build(); + return Stream.of(qr); + } + + private Schema getSchema() { + val intField = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null); + val stringField = new Field("name", FieldType.nullable(new ArrowType.Utf8()), null); + val doubleField = new Field( + "grade", FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null); + + val fields = List.of(intField, stringField, doubleField); + return new Schema(fields); + } + + @SneakyThrows + private ByteString convertStudentsToArrowBinary(List students) { + val rowCount = students.size(); + + try (val allocator = new RootAllocator()) { + val schemaPerson = getSchema(); + try (val vectorSchemaRoot = VectorSchemaRoot.create(schemaPerson, allocator)) { + val idVector = (IntVector) vectorSchemaRoot.getVector("id"); + val nameVector = (VarCharVector) vectorSchemaRoot.getVector("name"); + val scoreVector = (Float8Vector) vectorSchemaRoot.getVector("grade"); + + idVector.allocateNew(rowCount); + nameVector.allocateNew(rowCount); + scoreVector.allocateNew(rowCount); + + IntStream.range(0, rowCount).forEach(i -> { + idVector.setSafe(i, students.get(i).id); + nameVector.setSafe(i, students.get(i).name.getBytes(StandardCharsets.UTF_8)); + scoreVector.setSafe(i, students.get(i).grade); + }); + + vectorSchemaRoot.setRowCount(rowCount); + + val outputStream = new ByteArrayOutputStream(); + try (val writer = new ArrowStreamWriter(vectorSchemaRoot, null, outputStream)) { + writer.start(); + writer.writeBatch(); + writer.end(); + } + + val s = 
outputStream.toByteArray(); + return ByteString.copyFrom(s); + } + } + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/util/RequestRecordingInterceptor.java b/src/test/java/com/salesforce/datacloud/jdbc/util/RequestRecordingInterceptor.java new file mode 100644 index 0000000..96f9840 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/util/RequestRecordingInterceptor.java @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.util; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ForwardingClientCall; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import java.util.ArrayList; +import java.util.List; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import lombok.val; + +/** https://grpc.github.io/grpc-java/javadoc/io/grpc/auth/ClientAuthInterceptor.html */ +@Getter +@Slf4j +public class RequestRecordingInterceptor implements ClientInterceptor { + + private final List queries = new ArrayList<>(); + + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + val name = method.getFullMethodName(); + + queries.add(name); + log.info("Executing grpc endpoint: " + name); + + return new ForwardingClientCall.SimpleForwardingClientCall<>(next.newCall(method, callOptions)) { + @Override + public void start(final Listener responseListener, final Metadata headers) { + headers.put(Metadata.Key.of("FOO", Metadata.ASCII_STRING_MARSHALLER), "BAR"); + super.start(responseListener, headers); + } + }; + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/util/RequireTest.java b/src/test/java/com/salesforce/datacloud/jdbc/util/RequireTest.java new file mode 100644 index 0000000..3c826e6 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/util/RequireTest.java @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.util; + +import static com.salesforce.datacloud.jdbc.util.Require.requireNotNullOrBlank; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import lombok.val; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.NullSource; +import org.junit.jupiter.params.provider.ValueSource; + +class RequireTest { + @ParameterizedTest(name = "#{index} - requireNotNullOrBlank throws on args='{0}'") + @NullSource + @ValueSource(strings = {"", " "}) + void requireThrowsOn(String value) { + val exception = assertThrows(IllegalArgumentException.class, () -> requireNotNullOrBlank(value, "thing")); + val expected = "Expected argument 'thing' to not be null or blank"; + assertThat(exception).hasMessage(expected); + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/util/RootAllocatorTestExtension.java b/src/test/java/com/salesforce/datacloud/jdbc/util/RootAllocatorTestExtension.java new file mode 100644 index 0000000..e0a6fc7 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/util/RootAllocatorTestExtension.java @@ -0,0 +1,457 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.util; + +import static java.lang.Byte.MAX_VALUE; + +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; +import lombok.Getter; +import lombok.val; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.FixedSizeBinaryVector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.LargeVarBinaryVector; +import org.apache.arrow.vector.LargeVarCharVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampNanoTZVector; +import org.apache.arrow.vector.TimeStampNanoVector; +import org.apache.arrow.vector.TimeStampSecTZVector; +import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.LargeListVector; +import 
org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionLargeListWriter; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.junit.jupiter.api.extension.AfterAllCallback; +import org.junit.jupiter.api.extension.ExtensionContext; + +@Getter +public class RootAllocatorTestExtension implements AfterAllCallback, AutoCloseable { + + private final BufferAllocator rootAllocator = new RootAllocator(); + private final Random random = new Random(10); + + @Override + public void afterAll(ExtensionContext context) { + try { + close(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public void close() throws Exception { + this.rootAllocator.getChildAllocators().forEach(BufferAllocator::close); + AutoCloseables.close(this.rootAllocator); + } + + public Float8Vector createFloat8Vector(List values) { + val vector = new Float8Vector("test-double-vector", getRootAllocator()); + vector.allocateNew(values.size()); + for (int i = 0; i < values.size(); i++) { + Double d = values.get(i); + if (d == null) { + vector.setNull(i); + } else { + vector.setSafe(i, d); + } + } + vector.setValueCount(values.size()); + return vector; + } + + public VarCharVector createVarCharVectorFrom(List values) { + val vector = new VarCharVector("test-varchar-vector", getRootAllocator()); + vector.setValueCount(values.size()); + for (int i = 0; i < values.size(); i++) { + vector.setSafe(i, values.get(i).getBytes(StandardCharsets.UTF_8)); + } + return vector; + } + + public LargeVarCharVector createLargeVarCharVectorFrom(List values) { + val vector = new LargeVarCharVector("test-large-varchar-vector", getRootAllocator()); + vector.setValueCount(values.size()); + for (int i = 0; i < values.size(); i++) { + vector.setSafe(i, values.get(i).getBytes(StandardCharsets.UTF_8)); + } + return vector; + } + + public DecimalVector createDecimalVector(List values) { + val vector = new DecimalVector("test-decimal-vector", 
getRootAllocator(), 39, 0); + vector.setValueCount(values.size()); + for (int i = 0; i < values.size(); i++) { + vector.setSafe(i, values.get(i)); + } + return vector; + } + + public BitVector createBitVector(List values) { + BitVector vector = new BitVector("Value", this.getRootAllocator()); + vector.allocateNew(values.size()); + for (int i = 0; i < values.size(); i++) { + Boolean b = values.get(i); + if (b == null) { + vector.setNull(i); + } else { + vector.setSafe(i, b ? 1 : 0); + } + } + vector.setValueCount(values.size()); + return vector; + } + + public static T nulledOutVector(T vector) { + val original = vector.getValueCount(); + vector.clear(); + vector.setValueCount(original); + return vector; + } + + public DateDayVector createDateDayVector() { + return new DateDayVector("DateDay", this.getRootAllocator()); + } + + public DateMilliVector createDateMilliVector() { + return new DateMilliVector("DateMilli", this.getRootAllocator()); + } + + public static T appendDates(List values, T vector) { + AtomicInteger i = new AtomicInteger(0); + vector.allocateNew(values.size()); + values.stream() + .map(TimeUnit.MILLISECONDS::toDays) + .forEachOrdered(t -> vector.setSafe(i.getAndIncrement(), Math.toIntExact(t))); + vector.setValueCount(values.size()); + return vector; + } + + public static T appendDates(List values, T vector) { + AtomicInteger i = new AtomicInteger(0); + vector.allocateNew(values.size()); + values.forEach(t -> vector.setSafe(i.getAndIncrement(), t)); + vector.setValueCount(values.size()); + return vector; + } + + public TimeNanoVector createTimeNanoVector(List values) { + TimeNanoVector vector = new TimeNanoVector("TimeNano", this.getRootAllocator()); + AtomicInteger i = new AtomicInteger(0); + vector.allocateNew(values.size()); + values.forEach(t -> vector.setSafe(i.getAndIncrement(), t)); + vector.setValueCount(values.size()); + return vector; + } + + public TimeMicroVector createTimeMicroVector(List values) { + TimeMicroVector vector = new 
TimeMicroVector("TimeMicro", this.getRootAllocator()); + AtomicInteger i = new AtomicInteger(0); + vector.allocateNew(values.size()); + values.forEach(t -> vector.setSafe(i.getAndIncrement(), t)); + vector.setValueCount(values.size()); + return vector; + } + + public TimeMilliVector createTimeMilliVector(List values) { + TimeMilliVector vector = new TimeMilliVector("TimeMilli", this.getRootAllocator()); + AtomicInteger i = new AtomicInteger(0); + vector.allocateNew(values.size()); + values.forEach(t -> vector.setSafe(i.getAndIncrement(), t)); + vector.setValueCount(values.size()); + return vector; + } + + public TimeSecVector createTimeSecVector(List values) { + TimeSecVector vector = new TimeSecVector("TimeSec", this.getRootAllocator()); + AtomicInteger i = new AtomicInteger(0); + vector.allocateNew(values.size()); + values.forEach(t -> vector.setSafe(i.getAndIncrement(), t)); + vector.setValueCount(values.size()); + return vector; + } + + /** + * Create a VarBinaryVector to be used in the accessor tests. + * + * @return VarBinaryVector + */ + public VarBinaryVector createVarBinaryVector(List values) { + VarBinaryVector vector = new VarBinaryVector("test-varbinary-vector", this.getRootAllocator()); + vector.allocateNew(values.size()); + for (int i = 0; i < values.size(); i++) { + byte[] b = values.get(i); + if (b == null) { + vector.setNull(i); + } else { + vector.setSafe(i, values.get(i)); + } + } + vector.setValueCount(values.size()); + return vector; + } + + /** + * Create a LargeVarBinaryVector to be used in the accessor tests. 
+ * + * @return LargeVarBinaryVector + */ + public LargeVarBinaryVector createLargeVarBinaryVector(List values) { + LargeVarBinaryVector vector = new LargeVarBinaryVector("test-large-varbinary-vector", this.getRootAllocator()); + vector.allocateNew(values.size()); + for (int i = 0; i < values.size(); i++) { + byte[] b = values.get(i); + if (b == null) { + vector.setNull(i); + } else { + vector.setSafe(i, values.get(i)); + } + } + vector.setValueCount(values.size()); + return vector; + } + + /** + * Create a FixedSizeBinaryVector to be used in the accessor tests. + * + * @return FixedSizeBinaryVector + */ + public FixedSizeBinaryVector createFixedSizeBinaryVector(List values) { + FixedSizeBinaryVector vector = + new FixedSizeBinaryVector("test-fixedsize-varbinary-vector", this.getRootAllocator(), 16); + vector.allocateNew(values.size()); + for (int i = 0; i < values.size(); i++) { + byte[] b = values.get(i); + if (b == null) { + vector.setNull(i); + } else { + vector.setSafe(i, values.get(i)); + } + } + vector.setValueCount(values.size()); + return vector; + } + + /** + * Create a TinyIntVector to be used in the accessor tests. + * + * @return TinyIntVector + */ + public TinyIntVector createTinyIntVector(List values) { + TinyIntVector result = new TinyIntVector("test-tinyInt-Vector", this.getRootAllocator()); + result.setValueCount(values.size()); + for (int i = 0; i < values.size(); i++) { + result.setSafe(i, values.get(i)); + } + + return result; + } + + /** + * Create a SmallIntVector to be used in the accessor tests. + * + * @return SmallIntVector + */ + public SmallIntVector createSmallIntVector(List values) { + + SmallIntVector result = new SmallIntVector("test-smallInt-Vector", this.getRootAllocator()); + result.setValueCount(values.size()); + for (int i = 0; i < values.size(); i++) { + result.setSafe(i, values.get(i)); + } + + return result; + } + + /** + * Create a IntVector to be used in the accessor tests. 
+ * + * @return IntVector + */ + public IntVector createIntVector(List values) { + IntVector result = new IntVector("test-int-vector", this.getRootAllocator()); + result.setValueCount(values.size()); + for (int i = 0; i < values.size(); i++) { + result.setSafe(i, values.get(i)); + } + + return result; + } + + /** + * Create a UInt4Vector to be used in the accessor tests. + * + * @return UInt4Vector + */ + public UInt4Vector createUInt4Vector(List values) { + UInt4Vector result = new UInt4Vector("test-uint4-vector", this.getRootAllocator()); + result.setValueCount(values.size()); + for (int i = 0; i < values.size(); i++) { + result.setSafe(i, values.get(i)); + } + + return result; + } + + /** + * Create a BigIntVector to be used in the accessor tests. + * + * @return BigIntVector + */ + public BigIntVector createBigIntVector(List values) { + BigIntVector result = new BigIntVector("test-bitInt-vector", this.getRootAllocator()); + result.setValueCount(values.size()); + for (int i = 0; i < values.size(); i++) { + result.setSafe(i, values.get(i)); + } + return result; + } + + public TimeStampNanoVector createTimeStampNanoVector(List values) { + TimeStampNanoVector result = new TimeStampNanoVector("", this.getRootAllocator()); + result.setValueCount(values.size()); + for (int i = 0; i < values.size(); i++) { + result.setSafe(i, TimeUnit.MILLISECONDS.toNanos(values.get(i))); + } + return result; + } + + public TimeStampNanoTZVector createTimeStampNanoTZVector(List values, String timeZone) { + TimeStampNanoTZVector result = new TimeStampNanoTZVector("", this.getRootAllocator(), timeZone); + result.setValueCount(values.size()); + for (int i = 0; i < values.size(); i++) { + result.setSafe(i, TimeUnit.MILLISECONDS.toNanos(values.get(i))); + } + return result; + } + + public TimeStampMicroVector createTimeStampMicroVector(List values) { + TimeStampMicroVector result = new TimeStampMicroVector("", this.getRootAllocator()); + result.setValueCount(values.size()); + for (int i = 
0; i < values.size(); i++) { + result.setSafe(i, TimeUnit.MILLISECONDS.toMicros(values.get(i))); + } + return result; + } + + public TimeStampMicroTZVector createTimeStampMicroTZVector(List values, String timeZone) { + TimeStampMicroTZVector result = new TimeStampMicroTZVector("", this.getRootAllocator(), timeZone); + result.setValueCount(values.size()); + for (int i = 0; i < values.size(); i++) { + result.setSafe(i, TimeUnit.MILLISECONDS.toMicros(values.get(i))); + } + return result; + } + + public TimeStampMilliVector createTimeStampMilliVector(List values) { + TimeStampMilliVector result = new TimeStampMilliVector("test-milli-vector", this.getRootAllocator()); + result.setValueCount(values.size()); + for (int i = 0; i < values.size(); i++) { + result.setSafe(i, values.get(i)); + } + return result; + } + + public TimeStampMilliTZVector createTimeStampMilliTZVector(List values, String timeZone) { + TimeStampMilliTZVector result = new TimeStampMilliTZVector("", this.getRootAllocator(), timeZone); + result.setValueCount(values.size()); + for (int i = 0; i < values.size(); i++) { + result.setSafe(i, values.get(i)); + } + return result; + } + + public TimeStampSecVector createTimeStampSecVector(List values) { + TimeStampSecVector result = new TimeStampSecVector("", this.getRootAllocator()); + result.setValueCount(values.size()); + for (int i = 0; i < values.size(); i++) { + result.setSafe(i, TimeUnit.MILLISECONDS.toSeconds(values.get(i))); + } + return result; + } + + public TimeStampSecTZVector createTimeStampSecTZVector(List values, String timeZone) { + TimeStampSecTZVector result = new TimeStampSecTZVector("", this.getRootAllocator(), timeZone); + result.setValueCount(values.size()); + for (int i = 0; i < values.size(); i++) { + result.setSafe(i, TimeUnit.MILLISECONDS.toSeconds(values.get(i))); + } + return result; + } + + public ListVector createListVector(String fieldName) { + ListVector listVector = ListVector.empty(fieldName, this.getRootAllocator()); + 
listVector.setInitialCapacity(MAX_VALUE); + + UnionListWriter writer = listVector.getWriter(); + + IntStream.range(0, MAX_VALUE).forEach(row -> { + writer.startList(); + writer.setPosition(row); + IntStream.range(0, 5).map(j -> j * row).forEach(writer::writeInt); + writer.setValueCount(5); + writer.endList(); + }); + + listVector.setValueCount(MAX_VALUE); + return listVector; + } + + public LargeListVector createLargeListVector(String fieldName) { + LargeListVector largeListVector = LargeListVector.empty(fieldName, this.getRootAllocator()); + largeListVector.setInitialCapacity(MAX_VALUE); + + UnionLargeListWriter writer = largeListVector.getWriter(); + + IntStream.range(0, MAX_VALUE).forEach(row -> { + writer.startList(); + writer.setPosition(row); + IntStream.range(0, 5).map(j -> j * row).forEach(writer::writeInt); + writer.setValueCount(5); + writer.endList(); + }); + + largeListVector.setValueCount(MAX_VALUE); + return largeListVector; + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/util/TestWasNullConsumer.java b/src/test/java/com/salesforce/datacloud/jdbc/util/TestWasNullConsumer.java new file mode 100644 index 0000000..a9e27c1 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/util/TestWasNullConsumer.java @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.util; + +import com.salesforce.datacloud.jdbc.core.accessor.QueryJDBCAccessorFactory; +import lombok.Data; +import org.assertj.core.api.SoftAssertions; + +@Data +public class TestWasNullConsumer implements QueryJDBCAccessorFactory.WasNullConsumer { + private final SoftAssertions collector; + + private int wasNullSeen = 0; + private int wasNotNullSeen = 0; + + @Override + public void setWasNull(boolean wasNull) { + if (wasNull) wasNullSeen++; + else wasNotNullSeen++; + } + + public TestWasNullConsumer hasNullSeen(int nullsSeen) { + collector.assertThat(this.wasNullSeen).as("witnessed null count").isEqualTo(nullsSeen); + return this; + } + + public TestWasNullConsumer hasNotNullSeen(int notNullSeen) { + collector.assertThat(this.wasNotNullSeen).as("witnessed not null count").isEqualTo(notNullSeen); + return this; + } + + public TestWasNullConsumer assertThat() { + return this; + } +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/util/ThrowingBiFunction.java b/src/test/java/com/salesforce/datacloud/jdbc/util/ThrowingBiFunction.java new file mode 100644 index 0000000..20f9115 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/util/ThrowingBiFunction.java @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.salesforce.datacloud.jdbc.util; + +public interface ThrowingBiFunction { + R apply(T t, U u) throws Exception; +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/util/ThrowingFunction.java b/src/test/java/com/salesforce/datacloud/jdbc/util/ThrowingFunction.java new file mode 100644 index 0000000..730dbe7 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/util/ThrowingFunction.java @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.util; + +import java.util.function.Function; + +@FunctionalInterface +public interface ThrowingFunction { + static Function rethrowFunction(ThrowingFunction function) throws E { + return t -> { + try { + return function.apply(t); + } catch (Exception exception) { + throwAsUnchecked(exception); + return null; + } + }; + } + + @SuppressWarnings("unchecked") + private static void throwAsUnchecked(Exception exception) throws E { + throw (E) exception; + } + + R apply(T t) throws E; +} diff --git a/src/test/java/com/salesforce/datacloud/jdbc/util/VectorPopulatorTest.java b/src/test/java/com/salesforce/datacloud/jdbc/util/VectorPopulatorTest.java new file mode 100644 index 0000000..9067c42 --- /dev/null +++ b/src/test/java/com/salesforce/datacloud/jdbc/util/VectorPopulatorTest.java @@ -0,0 +1,236 @@ +/* + * Copyright (c) 2024, Salesforce, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.salesforce.datacloud.jdbc.util; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.salesforce.datacloud.jdbc.core.model.ParameterBinding; +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.sql.Date; +import java.sql.Time; +import java.sql.Timestamp; +import java.sql.Types; +import java.time.LocalDateTime; +import java.util.Arrays; +import java.util.Calendar; +import java.util.List; +import java.util.TimeZone; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; 
+import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class VectorPopulatorTest { + + private VectorSchemaRoot mockRoot; + private VarCharVector varcharVector; + private Float4Vector float4Vector; + private Float8Vector float8Vector; + private IntVector intVector; + private SmallIntVector smallIntVector; + private BigIntVector bigIntVector; + private BitVector bitVector; + private DecimalVector decimalVector; + private DateDayVector dateDayVector; + private TimeMicroVector timeMicroVector; + private TimeStampMicroTZVector timestampMicroTZVector; + private List parameterBindings; + private Calendar calendar; + + @BeforeEach + public void setUp() { + mockRoot = mock(VectorSchemaRoot.class); + varcharVector = mock(VarCharVector.class); + float4Vector = mock(Float4Vector.class); + float8Vector = mock(Float8Vector.class); + intVector = mock(IntVector.class); + smallIntVector = mock(SmallIntVector.class); + bigIntVector = mock(BigIntVector.class); + bitVector = mock(BitVector.class); + decimalVector = mock(DecimalVector.class); + dateDayVector = mock(DateDayVector.class); + timeMicroVector = mock(TimeMicroVector.class); + timestampMicroTZVector = mock(TimeStampMicroTZVector.class); + + Schema schema = new Schema(List.of( + new Field("1", FieldType.nullable(new ArrowType.Utf8()), null), + new Field("2", FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), null), + new Field("3", FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), null), + new Field("4", FieldType.nullable(new ArrowType.Int(32, true)), null), + new Field("5", FieldType.nullable(new ArrowType.Int(16, true)), null), + new Field("6", FieldType.nullable(new ArrowType.Int(64, true)), null), + new Field("7", FieldType.nullable(new ArrowType.Bool()), null), + new Field("8", FieldType.nullable(new ArrowType.Decimal(10, 2, 
128)), null), + new Field("9", FieldType.nullable(new ArrowType.Date(DateUnit.DAY)), null), + new Field("10", FieldType.nullable(new ArrowType.Time(TimeUnit.MICROSECOND, 64)), null), + new Field("11", FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC")), null))); + + when(mockRoot.getSchema()).thenReturn(schema); + + when(mockRoot.getVector("1")).thenReturn(varcharVector); + when(mockRoot.getVector("2")).thenReturn(float4Vector); + when(mockRoot.getVector("3")).thenReturn(float8Vector); + when(mockRoot.getVector("4")).thenReturn(intVector); + when(mockRoot.getVector("5")).thenReturn(smallIntVector); + when(mockRoot.getVector("6")).thenReturn(bigIntVector); + when(mockRoot.getVector("7")).thenReturn(bitVector); + when(mockRoot.getVector("8")).thenReturn(decimalVector); + when(mockRoot.getVector("9")).thenReturn(dateDayVector); + when(mockRoot.getVector("10")).thenReturn(timeMicroVector); + when(mockRoot.getVector("11")).thenReturn(timestampMicroTZVector); + + // Initialize parameter bindings + parameterBindings = Arrays.asList( + new ParameterBinding(Types.VARCHAR, "test"), + new ParameterBinding(Types.FLOAT, 1.23f), + new ParameterBinding(Types.DOUBLE, 4.56), + new ParameterBinding(Types.INTEGER, 123), + new ParameterBinding(Types.SMALLINT, (short) 12345), + new ParameterBinding(Types.BIGINT, 123456789L), + new ParameterBinding(Types.BOOLEAN, true), + new ParameterBinding(Types.DECIMAL, new BigDecimal("12345.67")), + new ParameterBinding(Types.DATE, Date.valueOf("2024-01-01")), + new ParameterBinding(Types.TIME, Time.valueOf("12:34:56")), + new ParameterBinding(Types.TIMESTAMP, Timestamp.valueOf("2024-01-01 12:34:56"))); + + calendar = Calendar.getInstance(); // Mock calendar as needed + } + + @Test + public void testPopulateVectors() { + VectorPopulator.populateVectors(mockRoot, parameterBindings, null); + + verify(mockRoot, times(1)).getVector("1"); + verify(varcharVector, times(1)).setSafe(0, "test".getBytes(StandardCharsets.UTF_8)); + 
verify(mockRoot, times(1)).getVector("2"); + verify(float4Vector, times(1)).setSafe(0, 1.23f); + verify(mockRoot, times(1)).getVector("3"); + verify(float8Vector, times(1)).setSafe(0, 4.56); + verify(mockRoot, times(1)).getVector("4"); + verify(intVector, times(1)).setSafe(0, 123); + verify(mockRoot, times(1)).getVector("5"); + verify(smallIntVector, times(1)).setSafe(0, (short) 12345); + verify(mockRoot, times(1)).getVector("6"); + verify(bigIntVector, times(1)).setSafe(0, 123456789L); + verify(mockRoot, times(1)).getVector("7"); + verify(bitVector, times(1)).setSafe(0, 1); + verify(mockRoot, times(1)).getVector("8"); + verify(decimalVector, times(1)) + .setSafe(0, new BigDecimal("12345.67").unscaledValue().longValue()); + verify(mockRoot, times(1)).getVector("9"); + verify(dateDayVector, times(1)) + .setSafe(0, (int) Date.valueOf("2024-01-01").toLocalDate().toEpochDay()); + Time time = Time.valueOf("12:34:56"); + LocalDateTime localDateTime = new java.sql.Timestamp(time.getTime()).toLocalDateTime(); + localDateTime = DateTimeUtils.adjustForCalendar(localDateTime, calendar, TimeZone.getDefault()); + + long midnightMillis = localDateTime.toLocalTime().toNanoOfDay() / 1_000_000; + long expectedMicroseconds = DateTimeUtils.millisToMicrosecondsSinceMidnight(midnightMillis); + + verify(mockRoot, times(1)).getVector("10"); + verify(timeMicroVector, times(1)).setSafe(0, expectedMicroseconds); + + Timestamp timestamp = Timestamp.valueOf("2024-01-01 12:34:56"); + LocalDateTime localDateTime1 = timestamp.toLocalDateTime(); + localDateTime1 = DateTimeUtils.adjustForCalendar(localDateTime1, calendar, TimeZone.getDefault()); + + long expectedTimestampInMicroseconds = DateTimeUtils.localDateTimeToMicrosecondsSinceEpoch(localDateTime1); + + verify(mockRoot, times(1)).getVector("11"); + verify(timestampMicroTZVector, times(1)).setSafe(0, expectedTimestampInMicroseconds); + + verify(mockRoot, times(1)).setRowCount(1); + } + + @Test + public void 
testPopulateVectorsWithNullParameterBindings() { + Schema schema = new Schema(List.of( + new Field("1", FieldType.nullable(new ArrowType.Utf8()), null), + new Field("2", FieldType.nullable(new ArrowType.Utf8()), null), + new Field("3", FieldType.nullable(new ArrowType.Utf8()), null), + new Field("4", FieldType.nullable(new ArrowType.Utf8()), null), + new Field("5", FieldType.nullable(new ArrowType.Utf8()), null), + new Field("6", FieldType.nullable(new ArrowType.Utf8()), null), + new Field("7", FieldType.nullable(new ArrowType.Utf8()), null), + new Field("8", FieldType.nullable(new ArrowType.Utf8()), null), + new Field("9", FieldType.nullable(new ArrowType.Utf8()), null), + new Field("10", FieldType.nullable(new ArrowType.Utf8()), null), + new Field("11", FieldType.nullable(new ArrowType.Utf8()), null))); + + when(mockRoot.getSchema()).thenReturn(schema); + + List parameterBindings = List.of( + new ParameterBinding(Types.VARCHAR, null), + new ParameterBinding(Types.FLOAT, null), + new ParameterBinding(Types.DOUBLE, null), + new ParameterBinding(Types.INTEGER, null), + new ParameterBinding(Types.SMALLINT, null), + new ParameterBinding(Types.BIGINT, null), + new ParameterBinding(Types.BOOLEAN, null), + new ParameterBinding(Types.DECIMAL, null), + new ParameterBinding(Types.DATE, null), + new ParameterBinding(Types.TIME, null), + new ParameterBinding(Types.TIMESTAMP, null)); + + VectorPopulator.populateVectors(mockRoot, parameterBindings, calendar); + + verify(mockRoot, times(1)).getVector("1"); + verify(varcharVector, times(1)).setNull(0); + verify(mockRoot, times(1)).getVector("2"); + verify(float4Vector, times(1)).setNull(0); + verify(mockRoot, times(1)).getVector("3"); + verify(float8Vector, times(1)).setNull(0); + verify(mockRoot, times(1)).getVector("4"); + verify(intVector, times(1)).setNull(0); + verify(mockRoot, times(1)).getVector("5"); + verify(smallIntVector, times(1)).setNull(0); + verify(mockRoot, times(1)).getVector("6"); + verify(bigIntVector, 
times(1)).setNull(0); + verify(mockRoot, times(1)).getVector("7"); + verify(bitVector, times(1)).setNull(0); + verify(mockRoot, times(1)).getVector("8"); + verify(decimalVector, times(1)).setNull(0); + verify(mockRoot, times(1)).getVector("9"); + verify(dateDayVector, times(1)).setNull(0); + verify(mockRoot, times(1)).getVector("10"); + verify(timeMicroVector, times(1)).setNull(0); + verify(mockRoot, times(1)).getVector("11"); + verify(timestampMicroTZVector, times(1)).setNull(0); + + verify(mockRoot, times(1)).setRowCount(1); + } +} diff --git a/src/test/resources/hyper.yaml b/src/test/resources/hyper.yaml new file mode 100644 index 0000000..8d99af3 --- /dev/null +++ b/src/test/resources/hyper.yaml @@ -0,0 +1,8 @@ +listen-connection: tcp.grpc://0.0.0.0:8181 +skip-license: true +language: en_US +no-password: true +use_v3_new_endpoints: true +use_result_spooling: true +grpc_persist_results: true +log_pipelines: true \ No newline at end of file diff --git a/src/test/resources/simplelogger.properties b/src/test/resources/simplelogger.properties new file mode 100644 index 0000000..f14e91d --- /dev/null +++ b/src/test/resources/simplelogger.properties @@ -0,0 +1,4 @@ +org.slf4j.simpleLogger.logFile=System.out +org.slf4j.simpleLogger.defaultLogLevel=info + +org.slf4j.simpleLogger.log.com.salesforce.cdp.queryservice.core.QueryServiceConnection=debug