diff --git a/target_snowflake/connector.py b/target_snowflake/connector.py index 1bdd4d9..b5d612a 100644 --- a/target_snowflake/connector.py +++ b/target_snowflake/connector.py @@ -9,10 +9,11 @@ from cryptography.hazmat.primitives import serialization from singer_sdk import typing as th from singer_sdk.connectors import SQLConnector +from singer_sdk.connectors.sql import FullyQualifiedName from snowflake.sqlalchemy import URL from snowflake.sqlalchemy.base import SnowflakeIdentifierPreparer from snowflake.sqlalchemy.snowdialect import SnowflakeDialect -from sqlalchemy.sql import quoted_name, text +from sqlalchemy.sql import text from target_snowflake.snowflake_types import NUMBER, TIMESTAMP_NTZ, VARIANT @@ -44,6 +45,23 @@ def evaluate_typemaps(type_maps, compare_value, unmatched_value): # noqa: ANN00 return unmatched_value +class SnowflakeFullyQualifiedName(FullyQualifiedName): + def __init__( + self, + *, + table: str | None = None, + schema: str | None = None, + database: str | None = None, + delimiter: str = ".", + dialect: SnowflakeDialect, + ) -> None: + self.dialect = dialect + super().__init__(table=table, schema=schema, database=database, delimiter=delimiter) + + def prepare_part(self, part: str) -> str: + return self.dialect.identifier_preparer.quote(part) + + class SnowflakeConnector(SQLConnector): """Snowflake Target Connector. @@ -388,7 +406,7 @@ def _get_merge_from_stage_statement( # noqa: ANN202 dedup = f"QUALIFY ROW_NUMBER() OVER (PARTITION BY {dedup_cols} ORDER BY SEQ8() DESC) = 1" return ( text( - f"merge into {quoted_name(full_table_name, quote=True)} d using " # noqa: ISC003 + f"merge into {full_table_name} d using " # noqa: ISC003 + f"(select {json_casting_selects} from '@~/target-snowflake/{sync_id}'" # noqa: S608 + f"(file_format => {file_format}) {dedup}) s " + f"on {join_expr} " @@ -634,3 +652,18 @@ def _adapt_column_type( sql_type, ) raise + + def get_fully_qualified_name( + self, + table_name: str | None = None, + schema_name: str | None = None, + db_name: str | None = None, + delimiter: str = ".", + ) -> SnowflakeFullyQualifiedName: + return SnowflakeFullyQualifiedName( + table=table_name, + schema=schema_name, + database=db_name, + delimiter=delimiter, + dialect=self._dialect, + ) diff --git a/tests/core.py b/tests/core.py index 6643ea8..ae04337 100644 --- a/tests/core.py +++ b/tests/core.py @@ -458,6 +458,61 @@ def setup(self) -> None: ) +class SnowflakeTargetExistingReservedNameTableAlter(TargetFileTestTemplate): + name = "existing_reserved_name_table_alter" + # This sends a schema that will request altering from TIMESTAMP_NTZ to VARCHAR + + @property + def singer_filepath(self) -> Path: + current_dir = Path(__file__).resolve().parent + return current_dir / "target_test_streams" / "reserved_words_in_table.singer" + + def setup(self) -> None: + connector = self.target.default_sink_class.connector_class(self.target.config) + table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.\"order\"".upper() + connector.connection.execute( + f""" + CREATE OR REPLACE TABLE {table} ( + ID VARCHAR(16777216), + COL_STR VARCHAR(16777216), + COL_TS TIMESTAMP_NTZ(9), + COL_INT STRING, + COL_BOOL BOOLEAN, + COL_VARIANT VARIANT, + _SDC_BATCHED_AT TIMESTAMP_NTZ(9), + _SDC_DELETED_AT VARCHAR(16777216), + _SDC_EXTRACTED_AT TIMESTAMP_NTZ(9), + _SDC_RECEIVED_AT TIMESTAMP_NTZ(9), + _SDC_SEQUENCE NUMBER(38,0), + _SDC_TABLE_VERSION NUMBER(38,0), + PRIMARY KEY (ID) + ) + """, + ) + + +class SnowflakeTargetReservedWordsInTable(TargetFileTestTemplate): + # Contains reserved words from + # https://docs.snowflake.com/en/sql-reference/reserved-keywords + # Syncs records then alters schema by adding a non-reserved word column. + name = "reserved_words_in_table" + + @property + def singer_filepath(self) -> Path: + current_dir = Path(__file__).resolve().parent + return current_dir / "target_test_streams" / "reserved_words_in_table.singer" + + def validate(self) -> None: + connector = self.target.default_sink_class.connector_class(self.target.config) + table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.\"order\"".upper() + result = connector.connection.execute( + f"select * from {table}", + ) + assert result.rowcount == 1 + row = result.first() + assert len(row) == 13, f"Row has unexpected length {len(row)}" + + class SnowflakeTargetTypeEdgeCasesTest(TargetFileTestTemplate): name = "type_edge_cases" @@ -540,6 +595,8 @@ def singer_filepath(self) -> Path: SnowflakeTargetColonsInColName, SnowflakeTargetExistingTable, SnowflakeTargetExistingTableAlter, + SnowflakeTargetExistingReservedNameTableAlter, + SnowflakeTargetReservedWordsInTable, SnowflakeTargetTypeEdgeCasesTest, SnowflakeTargetColumnOrderMismatch, ], diff --git a/tests/target_test_streams/reserved_words_in_table.singer b/tests/target_test_streams/reserved_words_in_table.singer new file mode 100644 index 0000000..ede9a18 --- /dev/null +++ b/tests/target_test_streams/reserved_words_in_table.singer @@ -0,0 +1,2 @@ +{ "type": "SCHEMA", "stream": "order", "schema": { "properties": { "id": { "type": [ "string", "null" ] }, "col_str": { "type": [ "string", "null" ] }, "col_ts": { "format": "date-time", "type": [ "string", "null" ] }, "col_int": { "type": "integer" }, "col_bool": { "type": [ "boolean", "null" ] }, "col_variant": {"type": "object"} }, "type": "object" }, "key_properties": [ "id" ], "bookmark_properties": [ "col_ts" ] } +{ "type": "RECORD", "stream": "order", "record": { "id": "123", "col_str": "foo", "col_ts": "2023-06-13 11:50:04.072", "col_int": 5, "col_bool": true, "col_variant": {"key": "val"} }, "time_extracted": "2023-06-14T18:08:23.074716+00:00" }