From 9e81dc4978f7b769ff3d80b333ac853851648a26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edgar=20Ram=C3=ADrez=20Mondrag=C3=B3n?= <16805946+edgarrmondragon@users.noreply.github.com> Date: Tue, 21 Nov 2023 15:58:30 -0600 Subject: [PATCH] fix: Invalid f-string syntax in Python 3.7 (#132) --- .pre-commit-config.yaml | 14 +- output/.gitignore | 2 +- pyproject.toml | 21 +- target_snowflake/__init__.py | 1 + target_snowflake/connector.py | 257 ++++++++++-------- target_snowflake/initializer.py | 56 +++- target_snowflake/sinks.py | 62 +++-- target_snowflake/snowflake_types.py | 11 +- tests/__init__.py | 2 +- tests/batch.py | 47 ++-- tests/conftest.py | 3 +- tests/core.py | 60 ++-- .../batch_multiple_state_messages__a.jsonl | 2 +- .../batch_multiple_state_messages__b.jsonl | 2 +- ...ch_record_missing_required_property.singer | 1 - .../type_edge_cases.singer | 2 +- tests/test_target_snowflake.py | 29 +- tox.ini | 1 - 18 files changed, 336 insertions(+), 237 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 24251f4..522e8cd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-json - id: check-toml @@ -14,24 +14,20 @@ repos: - id: trailing-whitespace - repo: https://github.com/python-jsonschema/check-jsonschema - rev: 0.23.0 + rev: 0.27.1 hooks: - id: check-dependabot - id: check-github-workflows - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.269 + rev: v0.1.6 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - -- repo: https://github.com/psf/black - rev: 23.3.0 - hooks: - - id: black + - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.3.0 + rev: v1.7.0 hooks: - id: mypy additional_dependencies: diff --git a/output/.gitignore b/output/.gitignore index c96a04f..d6b7ef3 100644 --- a/output/.gitignore +++ b/output/.gitignore @@ -1,2 +1,2 @@ * -!.gitignore \ No newline at end of file +!.gitignore diff --git a/pyproject.toml b/pyproject.toml index 0451e7d..f99cda9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,21 +29,32 @@ singer-sdk = { version="0.30.0", extras = ["testing"] } coverage = "^7.2.7" [tool.ruff] +line-length = 120 +src = ["target_snowflake"] +target-version = "py37" + +[tool.ruff.lint] ignore = [ "ANN101", # missing-type-self "ANN102", # missing-type-cls + "ANN201", + "TD", + "D", + "FIX", ] select = ["ALL"] -src = ["target_snowflake"] -target-version = "py37" -[tool.ruff.flake8-annotations] +[tool.ruff.lint.flake8-annotations] allow-star-arg-any = true -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["target_snowflake"] +required-imports = ["from __future__ import annotations"] + +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["S101", "S608", "PLR2004", "ANN"] -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "google" [build-system] diff --git a/target_snowflake/__init__.py b/target_snowflake/__init__.py index f9d415b..66819ec 100644 --- a/target_snowflake/__init__.py +++ b/target_snowflake/__init__.py @@ -1,3 +1,4 @@ """Singer.io Target for the Snowflake data warehouse platform.""" +from __future__ import annotations __version__ = "0.0.0" diff --git a/target_snowflake/connector.py b/target_snowflake/connector.py index 12a48cb..88d5b2f 100644 --- a/target_snowflake/connector.py +++ b/target_snowflake/connector.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from operator import contains, eq -from typing import Dict, 
List, Sequence, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Iterable, Sequence, cast import snowflake.sqlalchemy.custom_types as sct import sqlalchemy @@ -10,21 +12,23 @@ from snowflake.sqlalchemy import URL from snowflake.sqlalchemy.base import SnowflakeIdentifierPreparer from snowflake.sqlalchemy.snowdialect import SnowflakeDialect -from sqlalchemy.engine import Engine from sqlalchemy.sql import text from target_snowflake.snowflake_types import NUMBER, TIMESTAMP_NTZ, VARIANT +if TYPE_CHECKING: + from sqlalchemy.engine import Engine + SNOWFLAKE_MAX_STRING_LENGTH = 16777216 class TypeMap: - def __init__(self, operator, map_value, match_value=None): + def __init__(self, operator, map_value, match_value=None) -> None: # noqa: ANN001 self.operator = operator self.map_value = map_value self.match_value = match_value - def match(self, compare_value): + def match(self, compare_value): # noqa: ANN001 try: if self.match_value: return self.operator(compare_value, self.match_value) @@ -33,7 +37,7 @@ def match(self, compare_value): return False -def evaluate_typemaps(type_maps, compare_value, unmatched_value): +def evaluate_typemaps(type_maps, compare_value, unmatched_value): # noqa: ANN001 for type_map in type_maps: if type_map.match(compare_value): return type_map.map_value @@ -52,16 +56,16 @@ class SnowflakeConnector(SQLConnector): allow_merge_upsert: bool = False # Whether MERGE UPSERT is supported. allow_temp_tables: bool = True # Whether temp tables are supported. - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: self.table_cache: dict = {} self.schema_cache: dict = {} super().__init__(*args, **kwargs) - + def get_table_columns( self, full_table_name: str, - column_names: Union[List[str], None] = None, - ) -> Dict[str, sqlalchemy.Column]: + column_names: list[str] | None = None, + ) -> dict[str, sqlalchemy.Column]: """Return a list of table columns. Args: @@ -71,37 +75,36 @@ def get_table_columns( Returns: An ordered list of column objects. """ - # Cache these columns because they're frequently used. 
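An aside on the helpers reworked in the hunk above: TypeMap and evaluate_typemaps implement a small first-match predicate dispatch — each entry pairs an operator with the value to return when the operator matches, and the first hit wins. A minimal sketch using the helpers as defined in this patch (the compare values and results below are illustrative, not the target's real type maps):

    from operator import contains, eq

    from target_snowflake.connector import TypeMap, evaluate_typemaps

    submaps = [
        TypeMap(eq, "TIMESTAMP_NTZ", "date-time"),  # eq(compare, "date-time")
        TypeMap(contains, "TIME", "time"),          # i.e. "time" in compare
    ]

    # First match wins; the third argument is the unmatched fallback.
    evaluate_typemaps(submaps, "date-time", "VARCHAR")  # -> "TIMESTAMP_NTZ"
    evaluate_typemaps(submaps, "integer", "VARCHAR")    # -> "VARCHAR"
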
if full_table_name in self.table_cache: return self.table_cache[full_table_name] - else: - _, schema_name, table_name = self.parse_full_table_name(full_table_name) - inspector = sqlalchemy.inspect(self._engine) - columns = inspector.get_columns(table_name, schema_name) - - parsed_columns = { - col_meta["name"]: sqlalchemy.Column( - col_meta["name"], - self._convert_type(col_meta["type"]), - nullable=col_meta.get("nullable", False), - ) - for col_meta in columns - if not column_names - or col_meta["name"].casefold() in {col.casefold() for col in column_names} - } - self.table_cache[full_table_name] = parsed_columns - return parsed_columns + _, schema_name, table_name = self.parse_full_table_name(full_table_name) + inspector = sqlalchemy.inspect(self._engine) + columns = inspector.get_columns(table_name, schema_name) + + parsed_columns = { + col_meta["name"]: sqlalchemy.Column( + col_meta["name"], + self._convert_type(col_meta["type"]), + nullable=col_meta.get("nullable", False), + ) + for col_meta in columns + if not column_names or col_meta["name"].casefold() in {col.casefold() for col in column_names} + } + self.table_cache[full_table_name] = parsed_columns + return parsed_columns @staticmethod - def _convert_type(sql_type): + def _convert_type(sql_type): # noqa: ANN205, ANN001 if isinstance(sql_type, sct.TIMESTAMP_NTZ): return TIMESTAMP_NTZ - elif isinstance(sql_type, sct.NUMBER): + + if isinstance(sql_type, sct.NUMBER): return NUMBER - elif isinstance(sql_type, sct.VARIANT): + + if isinstance(sql_type, sct.VARIANT): return VARIANT - else: - return sql_type + + return sql_type def get_sqlalchemy_url(self, config: dict) -> str: """Generates a SQLAlchemy URL for Snowflake. @@ -118,7 +121,8 @@ def get_sqlalchemy_url(self, config: dict) -> str: if "password" in config: params["password"] = config["password"] elif "private_key_path" not in config: - raise Exception("Neither password nor private_key_path was provided for authentication.") + msg = "Neither password nor private_key_path was provided for authentication." + raise Exception(msg) # noqa: TRY002 for option in ["warehouse", "role"]: if config.get(option): @@ -143,13 +147,15 @@ def create_engine(self) -> Engine: connect_args = { "session_parameters": { "QUOTED_IDENTIFIERS_IGNORE_CASE": "TRUE", - } + }, } if "private_key_path" in self.config: - with open(self.config["private_key_path"], "rb") as private_key_file: + with open(self.config["private_key_path"], "rb") as private_key_file: # noqa: PTH123 private_key = serialization.load_pem_private_key( private_key_file.read(), - password=self.config["private_key_passphrase"].encode() if "private_key_passphrase" in self.config else None, + password=self.config["private_key_passphrase"].encode() + if "private_key_passphrase" in self.config + else None, backend=default_backend(), ) connect_args["private_key"] = private_key.private_bytes( @@ -165,7 +171,8 @@ def create_engine(self) -> Engine: connection = engine.connect() db_names = [db[1] for db in connection.execute(text("SHOW DATABASES;")).fetchall()] if self.config["database"] not in db_names: - raise Exception(f"Database '{self.config['database']}' does not exist or the user/role doesn't have access to it.") + msg = f"Database '{self.config['database']}' does not exist or the user/role doesn't have access to it." 
+ raise Exception(msg) # noqa: TRY002 return engine def prepare_column( @@ -187,9 +194,13 @@ def prepare_column( column_name, sql_type, ) - except Exception as e: - self.logger.error(f"Error preparing column for {full_table_name=} {column_name=}") - raise e + except Exception: + self.logger.exception( + "Error preparing column for '%s.%s'", + full_table_name, + column_name, + ) + raise @staticmethod def get_column_rename_ddl( @@ -208,7 +219,9 @@ def get_column_rename_ddl( @staticmethod def get_column_alter_ddl( - table_name: str, column_name: str, column_type: sqlalchemy.types.TypeEngine + table_name: str, + column_name: str, + column_type: sqlalchemy.types.TypeEngine, ) -> sqlalchemy.DDL: """Get the alter column DDL statement. @@ -235,7 +248,7 @@ def get_column_alter_ddl( ) @staticmethod - def _conform_max_length(jsonschema_type): + def _conform_max_length(jsonschema_type): # noqa: ANN205, ANN001 """Alter jsonschema representations to limit max length to Snowflake's VARCHAR length.""" max_length = jsonschema_type.get("maxLength") if max_length and max_length > SNOWFLAKE_MAX_STRING_LENGTH: @@ -268,13 +281,13 @@ def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine: TypeMap(eq, sqlalchemy.types.VARCHAR(maxlength), None), ] type_maps = [ - TypeMap(th._jsonschema_type_check, NUMBER(), ("integer",)), - TypeMap(th._jsonschema_type_check, VARIANT(), ("object",)), - TypeMap(th._jsonschema_type_check, VARIANT(), ("array",)), - TypeMap(th._jsonschema_type_check, sct.DOUBLE(), ("number",)), + TypeMap(th._jsonschema_type_check, NUMBER(), ("integer",)), # noqa: SLF001 + TypeMap(th._jsonschema_type_check, VARIANT(), ("object",)), # noqa: SLF001 + TypeMap(th._jsonschema_type_check, VARIANT(), ("array",)), # noqa: SLF001 + TypeMap(th._jsonschema_type_check, sct.DOUBLE(), ("number",)), # noqa: SLF001 ] # apply type maps - if th._jsonschema_type_check(jsonschema_type, ("string",)): + if th._jsonschema_type_check(jsonschema_type, ("string",)): # noqa: SLF001 datelike_type = th.get_datelike_property_type(jsonschema_type) target_type = evaluate_typemaps(string_submaps, datelike_type, target_type) else: @@ -285,37 +298,42 @@ def to_sql_type(jsonschema_type: dict) -> sqlalchemy.types.TypeEngine: def schema_exists(self, schema_name: str) -> bool: if schema_name in self.schema_cache: return True - else: - schema_names = sqlalchemy.inspect(self._engine).get_schema_names() - self.schema_cache = schema_names - formatter = SnowflakeIdentifierPreparer(SnowflakeDialect()) - # Make quoted schema names upper case because we create them that way - # and the metadata that SQLAlchemy returns is case insensitive only for non-quoted - # schema names so these will look like they dont exist yet. - if '"' in formatter.format_collation(schema_name): - schema_name = schema_name.upper() - return schema_name in schema_names + schema_names = sqlalchemy.inspect(self._engine).get_schema_names() + self.schema_cache = schema_names + formatter = SnowflakeIdentifierPreparer(SnowflakeDialect()) + # Make quoted schema names upper case because we create them that way + # and the metadata that SQLAlchemy returns is case insensitive only for + # non-quoted schema names so these will look like they dont exist yet. 
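A note on _conform_max_length in the hunk above: Snowflake's VARCHAR tops out at 16,777,216 characters, so larger JSON Schema maxLength values are clamped before type mapping (the type_edge_cases test stream later in this patch exercises exactly this with maxLength 4294967295). A standalone sketch of the same clamping logic:

    SNOWFLAKE_MAX_STRING_LENGTH = 16777216  # Snowflake's VARCHAR cap

    def conform_max_length(jsonschema_type: dict) -> dict:
        """Clamp maxLength to what a Snowflake VARCHAR can hold."""
        max_length = jsonschema_type.get("maxLength")
        if max_length and max_length > SNOWFLAKE_MAX_STRING_LENGTH:
            jsonschema_type["maxLength"] = SNOWFLAKE_MAX_STRING_LENGTH
        return jsonschema_type

    conform_max_length({"type": "string", "maxLength": 4294967295})
    # -> {"type": "string", "maxLength": 16777216}
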
+ if '"' in formatter.format_collation(schema_name): + schema_name = schema_name.upper() + return schema_name in schema_names # Custom SQL get methods - def _get_put_statement(self, sync_id: str, file_uri: str) -> Tuple[text, dict]: + def _get_put_statement(self, sync_id: str, file_uri: str) -> tuple[text, dict]: # noqa: ARG002 """Get Snowflake PUT statement.""" return (text(f"put :file_uri '@~/target-snowflake/{sync_id}'"), {}) @staticmethod - def _format_column_selections(column_selections: dict, format: str) -> str: + def _format_column_selections(column_selections: list, format: str) -> str: # noqa: A002 if format == "json_casting": - return ', '.join( + return ", ".join( [ - f"$1:{col['clean_property_name']}::{col['sql_type']} as {col['clean_alias']}" for col in column_selections - ] + f"$1:{col['clean_property_name']}::{col['sql_type']} as {col['clean_alias']}" + for col in column_selections + ], ) - elif format == "col_alias": + if format == "col_alias": return f"({', '.join([col['clean_alias'] for col in column_selections])})" - else: - raise NotImplementedError(f"Column format not implemented: {format}") - def _get_column_selections(self, schema: dict, formatter: SnowflakeIdentifierPreparer) -> list: + error_message = f"Column format not implemented: {format}" + raise NotImplementedError(error_message) + + def _get_column_selections( + self, + schema: dict, + formatter: SnowflakeIdentifierPreparer, + ) -> list: column_selections = [] for property_name, property_def in schema["properties"].items(): clean_property_name = formatter.format_collation(property_name) @@ -327,81 +345,91 @@ def _get_column_selections(self, schema: dict, formatter: SnowflakeIdentifierPre "clean_property_name": clean_property_name, "sql_type": self.to_sql_type(property_def), "clean_alias": clean_alias, - } + }, ) return column_selections - def _get_merge_from_stage_statement( - self, full_table_name, schema, sync_id, file_format, key_properties + def _get_merge_from_stage_statement( # noqa: ANN202, PLR0913 + self, + full_table_name: str, + schema: dict, + sync_id: str, + file_format: str, + key_properties: Iterable[str], ): """Get Snowflake MERGE statement.""" - formatter = SnowflakeIdentifierPreparer(SnowflakeDialect()) column_selections = self._get_column_selections(schema, formatter) - json_casting_selects = self._format_column_selections(column_selections, "json_casting") + json_casting_selects = self._format_column_selections( + column_selections, + "json_casting", + ) # use UPPER from here onwards - formatted_properties = [formatter.format_collation(col) for col in schema["properties"].keys()] + formatted_properties = [formatter.format_collation(col) for col in schema["properties"]] formatted_key_properties = [formatter.format_collation(col) for col in key_properties] join_expr = " and ".join( - [f'd.{key} = s.{key}' for key in formatted_key_properties] + [f"d.{key} = s.{key}" for key in formatted_key_properties], ) matched_clause = ", ".join( - [f'd.{col} = s.{col}' for col in formatted_properties] + [f"d.{col} = s.{col}" for col in formatted_properties], ) not_matched_insert_cols = ", ".join(formatted_properties) not_matched_insert_values = ", ".join( - [f's.{col}' for col in formatted_properties] + [f"s.{col}" for col in formatted_properties], ) - dedup_cols = ", ".join([key for key in formatted_key_properties]) + dedup_cols = ", ".join(list(formatted_key_properties)) dedup = f"QUALIFY ROW_NUMBER() OVER (PARTITION BY {dedup_cols} ORDER BY SEQ8() DESC) = 1" return ( text( - f"merge into 
{full_table_name} d using " - + f"(select {json_casting_selects} from '@~/target-snowflake/{sync_id}'" + f"merge into {full_table_name} d using " # noqa: S608, ISC003 + + f"(select {json_casting_selects} from '@~/target-snowflake/{sync_id}'" # noqa: S608 + f"(file_format => {file_format}) {dedup}) s " + f"on {join_expr} " + f"when matched then update set {matched_clause} " + f"when not matched then insert ({not_matched_insert_cols}) " - + f"values ({not_matched_insert_values})" + + f"values ({not_matched_insert_values})", ), {}, ) - def _get_copy_statement(self, full_table_name, schema, sync_id, file_format): + def _get_copy_statement(self, full_table_name, schema, sync_id, file_format): # noqa: ANN202, ANN001 """Get Snowflake COPY statement.""" formatter = SnowflakeIdentifierPreparer(SnowflakeDialect()) column_selections = self._get_column_selections(schema, formatter) - json_casting_selects = self._format_column_selections(column_selections, "json_casting") - col_alias_selects = self._format_column_selections(column_selections, "col_alias") + json_casting_selects = self._format_column_selections( + column_selections, + "json_casting", + ) + col_alias_selects = self._format_column_selections( + column_selections, + "col_alias", + ) return ( text( - f"copy into {full_table_name} {col_alias_selects} from " - + f"(select {json_casting_selects} from " + f"copy into {full_table_name} {col_alias_selects} from " # noqa: S608, ISC003 + + f"(select {json_casting_selects} from " # noqa: S608 + f"'@~/target-snowflake/{sync_id}')" - + f"file_format = (format_name='{file_format}')" + + f"file_format = (format_name='{file_format}')", ), {}, ) - def _get_file_format_statement(self, file_format): + def _get_file_format_statement(self, file_format): # noqa: ANN202, ANN001 """Get Snowflake CREATE FILE FORMAT statement.""" return ( - text( - f"create or replace file format {file_format}" - + "type = 'JSON' compression = 'AUTO'" - ), + text(f"create or replace file format {file_format}type = 'JSON' compression = 'AUTO'"), {}, ) - def _get_drop_file_format_statement(self, file_format): + def _get_drop_file_format_statement(self, file_format): # noqa: ANN202, ANN001 """Get Snowflake DROP FILE FORMAT statement.""" return ( text(f"drop file format if exists {file_format}"), {}, ) - def _get_stage_files_remove_statement(self, sync_id): + def _get_stage_files_remove_statement(self, sync_id): # noqa: ANN202, ANN001 """Get Snowflake REMOVE statement.""" return ( text(f"remove '@~/target-snowflake/{sync_id}/'"), @@ -420,7 +448,8 @@ def put_batches_to_stage(self, sync_id: str, files: Sequence[str]) -> None: with self._connect() as conn: for file_uri in files: put_statement, kwargs = self._get_put_statement( - sync_id=sync_id, file_uri=file_uri + sync_id=sync_id, + file_uri=file_uri, ) # sqlalchemy.text stripped a slash, which caused windows to fail so we used bound parameters instead # See https://github.com/MeltanoLabs/target-snowflake/issues/87 for more information about this error @@ -434,14 +463,15 @@ def create_file_format(self, file_format: str) -> None: """ with self._connect() as conn: file_format_statement, kwargs = self._get_file_format_statement( - file_format=file_format + file_format=file_format, ) self.logger.debug( - f"Creating file format with SQL: {file_format_statement!s}" + "Creating file format with SQL: %s", + file_format_statement, ) conn.execute(file_format_statement, **kwargs) - def merge_from_stage( + def merge_from_stage( # noqa: PLR0913 self, full_table_name: str, schema: dict, @@ -464,11 
+494,15 @@ def merge_from_stage( file_format=file_format, key_properties=key_properties, ) - self.logger.debug(f"Merging with SQL: {merge_statement!s}") + self.logger.debug("Merging with SQL: %s", merge_statement) conn.execute(merge_statement, **kwargs) def copy_from_stage( - self, full_table_name: str, schema: dict, sync_id: str, file_format: str + self, + full_table_name: str, + schema: dict, + sync_id: str, + file_format: str, ): """Copy data from a stage into a table. @@ -485,7 +519,7 @@ def copy_from_stage( sync_id=sync_id, file_format=file_format, ) - self.logger.debug(f"Copying with SQL: {copy_statement!s}") + self.logger.debug("Copying with SQL: %s", copy_statement) conn.execute(copy_statement, **kwargs) def drop_file_format(self, file_format: str) -> None: @@ -496,9 +530,9 @@ def drop_file_format(self, file_format: str) -> None: """ with self._connect() as conn: drop_statement, kwargs = self._get_drop_file_format_statement( - file_format=file_format + file_format=file_format, ) - self.logger.debug(f"Dropping file format with SQL: {drop_statement!s}") + self.logger.debug("Dropping file format with SQL: %s", drop_statement) conn.execute(drop_statement, **kwargs) def remove_staged_files(self, sync_id: str) -> None: @@ -509,13 +543,13 @@ def remove_staged_files(self, sync_id: str) -> None: """ with self._connect() as conn: remove_statement, kwargs = self._get_stage_files_remove_statement( - sync_id=sync_id + sync_id=sync_id, ) - self.logger.debug(f"Removing staged files with SQL: {remove_statement!s}") + self.logger.debug("Removing staged files with SQL: %s", remove_statement) conn.execute(remove_statement, **kwargs) @staticmethod - def get_initialize_script(role, user, password, warehouse, database): + def get_initialize_script(role, user, password, warehouse, database) -> str: # noqa: ANN001 # https://fivetran.com/docs/destinations/snowflake/setup-guide return f""" begin; @@ -558,7 +592,7 @@ def get_initialize_script(role, user, password, warehouse, database): grant CREATE SCHEMA, MONITOR, USAGE on database {database} to role {role}; - + commit; """ @@ -579,13 +613,18 @@ def _adapt_column_type( Raises: NotImplementedError: if altering columns is not supported. 
""" - try: super()._adapt_column_type(full_table_name, column_name, sql_type) - except Exception as e: + except Exception: current_type: sqlalchemy.types.TypeEngine = self._get_column_type( full_table_name, column_name, ) - self.logger.error(f"Error adapting column type for {full_table_name=} {column_name=}, {current_type=} {sql_type=} (new sql type)") - raise e \ No newline at end of file + self.logger.exception( + "Error adapting column type for '%s.%s', '%s' to '%s' (new sql type)", + full_table_name, + column_name, + current_type, + sql_type, + ) + raise diff --git a/target_snowflake/initializer.py b/target_snowflake/initializer.py index 7504daa..d7476aa 100644 --- a/target_snowflake/initializer.py +++ b/target_snowflake/initializer.py @@ -1,8 +1,12 @@ -import click -from target_snowflake.connector import SnowflakeConnector +from __future__ import annotations + import sys + +import click from sqlalchemy import text +from target_snowflake.connector import SnowflakeConnector + def initializer(): click.echo("") @@ -10,10 +14,18 @@ def initializer(): click.echo("✨Initializing Snowflake account.✨") click.echo("Note: You will always be asked to confirm before anything is executed.") click.echo("") - click.echo("Additionally you can run in `dry_run` mode which will print the SQL without running it.") - dry_run = click.prompt("Would you like to run in `dry_run` mode?", default=False, type=bool) + click.echo( + "Additionally you can run in `dry_run` mode which will print the SQL without running it.", + ) + dry_run = click.prompt( + "Would you like to run in `dry_run` mode?", + default=False, + type=bool, + ) click.echo("") - click.echo("We will now interactively create (or the print queries) for all the following objects in your Snowflake account:") + click.echo( + "We will now interactively create (or the print queries) for all the following objects in your Snowflake account:", # noqa: E501 + ) click.echo(" - Role") click.echo(" - User") click.echo(" - Warehouse") @@ -22,9 +34,23 @@ def initializer(): role = click.prompt("Meltano Role Name:", type=str, default="MELTANO_ROLE") user = click.prompt("Meltano User Name:", type=str, default="MELTANO_USER") password = click.prompt("Meltano Password", type=str, confirmation_prompt=True) - warehouse = click.prompt("Meltano Warehouse Name", type=str, default="MELTANO_WAREHOUSE") - database = click.prompt("Meltano Database Name", type=str, default="MELTANO_DATABASE") - script = SnowflakeConnector.get_initialize_script(role, user, password, warehouse, database) + warehouse = click.prompt( + "Meltano Warehouse Name", + type=str, + default="MELTANO_WAREHOUSE", + ) + database = click.prompt( + "Meltano Database Name", + type=str, + default="MELTANO_DATABASE", + ) + script = SnowflakeConnector.get_initialize_script( + role, + user, + password, + warehouse, + database, + ) if dry_run: click.echo(script) sys.exit(0) @@ -39,23 +65,23 @@ def initializer(): "password": admin_pass, "role": "SYSADMIN", "user": admin_user, - } + }, ) - connector + try: click.echo("Initialization Started") - with connector._connect() as conn: - click.echo(f"Executing:") + with connector._connect() as conn: # noqa: SLF001 + click.echo("Executing:") click.echo(f"{script}") click.prompt("Confirm?", default=True, type=bool) click.echo("Initialization Started...") - for statement in script.split(';'): + for statement in script.split(";"): if len(statement.strip()) > 0: conn.execute( - text(statement) + text(statement), ) click.echo("Success!") click.echo("Initialization Complete") - except 
Exception as e: + except Exception as e: # noqa: BLE001 click.echo(f"Initialization Failed: {e}") sys.exit(1) diff --git a/target_snowflake/sinks.py b/target_snowflake/sinks.py index 8db7294..63d99f1 100644 --- a/target_snowflake/sinks.py +++ b/target_snowflake/sinks.py @@ -6,7 +6,6 @@ from urllib.parse import urlparse from uuid import uuid4 -from singer_sdk import PluginBase, SQLConnector from singer_sdk.batch import JSONLinesBatcher from singer_sdk.helpers._batch import ( BaseBatchFileEncoding, @@ -20,6 +19,9 @@ from target_snowflake.connector import SnowflakeConnector +if t.TYPE_CHECKING: + from singer_sdk import PluginBase, SQLConnector + DEFAULT_BATCH_CONFIG = { "encoding": {"format": "jsonl", "compression": "gzip"}, "storage": {"root": "file://"}, @@ -31,7 +33,7 @@ class SnowflakeSink(SQLSink): connector_class = SnowflakeConnector - def __init__( + def __init__( # noqa: PLR0913 self, target: PluginBase, stream_name: str, @@ -50,12 +52,12 @@ def __init__( ) @property - def schema_name(self) -> t.Optional[str]: + def schema_name(self) -> str | None: schema = super().schema_name or self.config.get("schema") return schema.upper() if schema else None @property - def database_name(self) -> t.Optional[str]: + def database_name(self) -> str | None: db = super().database_name or self.config.get("database") return db.upper() if db else None @@ -72,10 +74,7 @@ def setup(self) -> None: if self.schema_name: # Needed to conform schema name self.connector.prepare_schema( - self.conform_name( - self.schema_name, - object_type="schema" - ), + self.conform_name(self.schema_name, object_type="schema"), ) try: self.connector.prepare_table( @@ -84,30 +83,34 @@ def setup(self) -> None: primary_keys=self.key_properties, as_temp_table=False, ) - except Exception as e: - self.logger.error(f"Error creating {self.full_table_name=} {self.conform_schema(self.schema)=}") - raise e - + except Exception: + ( + self.logger.exception( + "Error creating %s %s", + self.full_table_name, + self.conform_schema(self.schema), + ), + ) + raise def conform_name( self, name: str, - object_type: str | None = None, # noqa: ARG002 + object_type: str | None = None, ) -> str: - if not object_type or object_type == "column": - formatter = SnowflakeIdentifierPreparer(SnowflakeDialect()) - if '"' not in formatter.format_collation(name.lower()): - name = name.lower() - return name - else: + if object_type and object_type != "column": return super().conform_name(name=name, object_type=object_type) + formatter = SnowflakeIdentifierPreparer(SnowflakeDialect()) + if '"' not in formatter.format_collation(name.lower()): + name = name.lower() + return name def bulk_insert_records( self, full_table_name: str, schema: dict, - records: t.Iterable[t.Dict[str, t.Any]], - ) -> t.Optional[int]: + records: t.Iterable[dict[str, t.Any]], + ) -> int | None: """Bulk insert records to an existing destination table. The default implementation uses a generic SQLAlchemy bulk insert operation. 
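The change repeated across connector.py and sinks.py is the actual subject of this PR: self-documenting f-string specifiers such as f"{full_table_name=}" were introduced in Python 3.8 and are a SyntaxError on 3.7. The replacements use lazy %-style logging arguments, which parse on 3.7, defer formatting until a record is actually emitted, and — via logger.exception plus a bare raise — capture the active traceback without binding the exception. Illustrative before/after with placeholder names:

    import logging

    logger = logging.getLogger(__name__)
    full_table_name, column_name = "DB.SCHEMA.ORDERS", "ID"

    # Before -- valid only on Python 3.8+; SyntaxError on 3.7:
    #   logger.error(f"Error preparing column for {full_table_name=} {column_name=}")

    # After -- 3.7-compatible, lazily formatted, traceback included:
    try:
        raise ValueError("boom")  # stand-in for a failed ALTER/CREATE
    except ValueError:
        logger.exception(
            "Error preparing column for '%s.%s'",
            full_table_name,
            column_name,
        )
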
@@ -180,10 +183,7 @@ def insert_batch_files_via_internal_stage( file_format = f'{self.database_name}.{self.schema_name}."{sync_id}"' self.connector.put_batches_to_stage(sync_id=sync_id, files=files) self.connector.prepare_schema( - self.conform_name( - self.schema_name, - object_type="schema" - ), + self.conform_name(self.schema_name, object_type="schema"), # type: ignore[arg-type] ) self.connector.create_file_format(file_format=file_format) @@ -213,11 +213,13 @@ def insert_batch_files_via_internal_stage( if self.config.get("clean_up_batch_files"): for file_url in files: file_path = urlparse(file_url).path - if os.path.exists(file_path): - os.remove(file_path) + if os.path.exists(file_path): # noqa: PTH110 + os.remove(file_path) # noqa: PTH107 def process_batch_files( - self, encoding: BaseBatchFileEncoding, files: t.Sequence[str] + self, + encoding: BaseBatchFileEncoding, + files: t.Sequence[str], ) -> None: """Process a batch file with the given batch context. @@ -234,8 +236,9 @@ def process_batch_files( files=files, ) else: + msg = f"Unsupported batch file encoding: {encoding.format}" raise NotImplementedError( - f"Unsupported batch file encoding: {encoding.format}" + msg, ) # TODO: remove after https://github.com/meltano/sdk/issues/1819 is fixed @@ -248,4 +251,3 @@ def _singer_validate_message(self, record: dict) -> None: Raises: MissingKeyPropertiesError: If record is missing one or more key properties. """ - pass diff --git a/target_snowflake/snowflake_types.py b/target_snowflake/snowflake_types.py index 02512df..920816c 100644 --- a/target_snowflake/snowflake_types.py +++ b/target_snowflake/snowflake_types.py @@ -1,12 +1,15 @@ +from __future__ import annotations + import datetime as dt +import typing as t import snowflake.sqlalchemy.custom_types as sct -class TIMESTAMP_NTZ(sct.TIMESTAMP_NTZ): +class TIMESTAMP_NTZ(sct.TIMESTAMP_NTZ): # noqa: N801 """Snowflake TIMESTAMP_NTZ type.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: super().__init__(*args, **kwargs) @property @@ -17,7 +20,7 @@ def python_type(self): class NUMBER(sct.NUMBER): """Snowflake NUMBER type.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: super().__init__(*args, **kwargs) @property @@ -28,7 +31,7 @@ def python_type(self): class VARIANT(sct.VARIANT): """Snowflake VARIANT type.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: super().__init__(*args, **kwargs) @property diff --git a/tests/__init__.py b/tests/__init__.py index 563c810..afb9572 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ -"""Test suite for {{ cookiecutter.target_id }}.""" \ No newline at end of file +"""Test suite for {{ cookiecutter.target_id }}.""" diff --git a/tests/batch.py b/tests/batch.py index f0bf3aa..315c86b 100644 --- a/tests/batch.py +++ b/tests/batch.py @@ -1,8 +1,9 @@ """BATCH Tests for Target Snowflake.""" +from __future__ import annotations +import typing as t from pathlib import Path -import pytest from singer_sdk.testing.suites import TestSuite from singer_sdk.testing.target_tests import ( TargetNoPrimaryKeys, @@ -39,7 +40,8 @@ def singer_filepath(self) -> Path: class SnowflakeTargetBatchArrayData( - SnowflakeTargetCustomTestTemplate, SnowflakeTargetArrayData + SnowflakeTargetCustomTestTemplate, + SnowflakeTargetArrayData, ): """Test that the target can handle batch messages.""" @@ -47,7 +49,8 @@ class SnowflakeTargetBatchArrayData( class 
SnowflakeTargetBatchCamelcase( - SnowflakeTargetCustomTestTemplate, SnowflakeTargetCamelcaseTest + SnowflakeTargetCustomTestTemplate, + SnowflakeTargetCamelcaseTest, ): """Test that the target can handle batch messages.""" @@ -56,7 +59,8 @@ class SnowflakeTargetBatchCamelcase( class SnowflakeTargetBatchDuplicateRecords( - SnowflakeTargetCustomTestTemplate, SnowflakeTargetDuplicateRecords + SnowflakeTargetCustomTestTemplate, + SnowflakeTargetDuplicateRecords, ): """Test that the target can handle batch messages.""" @@ -64,12 +68,13 @@ class SnowflakeTargetBatchDuplicateRecords( class SnowflakeTargetBatchEncodedStringData( - SnowflakeTargetCustomTestTemplate, SnowflakeTargetEncodedStringData + SnowflakeTargetCustomTestTemplate, + SnowflakeTargetEncodedStringData, ): """Test that the target can handle batch messages.""" name = "batch_encoded_string_data" - stream_names = [ + stream_names: t.ClassVar[list[str]] = [ "test_batch_strings", "test_batch_strings_in_objects", "test_batch_strings_in_arrays", @@ -81,7 +86,7 @@ class SnowflakeTargetBatchEncodedStringData( # ): # """Test that the target can handle batch messages.""" -# name = "batch_multiple_state_messages" +# name = "batch_multiple_state_messages" # noqa: ERA001 # class SnowflakeTargetBatchNoPrimaryKeysAppend( @@ -89,11 +94,12 @@ class SnowflakeTargetBatchEncodedStringData( # ): # """Test that the target can handle batch messages.""" -# name = "batch_no_primary_keys_append" +# name = "batch_no_primary_keys_append" # noqa: ERA001 class SnowflakeTargetBatchNoPrimaryKeys( - SnowflakeTargetCustomTestTemplate, TargetNoPrimaryKeys + SnowflakeTargetCustomTestTemplate, + TargetNoPrimaryKeys, ): """Test that the target can handle batch messages.""" @@ -101,7 +107,8 @@ class SnowflakeTargetBatchNoPrimaryKeys( class SnowflakeTargetBatchOptionalAttributes( - SnowflakeTargetCustomTestTemplate, SnowflakeTargetOptionalAttributes + SnowflakeTargetCustomTestTemplate, + SnowflakeTargetOptionalAttributes, ): """Test that the target can handle batch messages.""" @@ -109,7 +116,8 @@ class SnowflakeTargetBatchOptionalAttributes( class SnowflakeTargetBatchRecordBeforeSchemaTest( - SnowflakeTargetCustomTestTemplate, SnowflakeTargetRecordBeforeSchemaTest + SnowflakeTargetCustomTestTemplate, + SnowflakeTargetRecordBeforeSchemaTest, ): """Test that the target can handle batch messages.""" @@ -117,7 +125,8 @@ class SnowflakeTargetBatchRecordBeforeSchemaTest( class SnowflakeTargetBatchRecordMissingKeyProperty( - SnowflakeTargetCustomTestTemplate, SnowflakeTargetRecordMissingKeyProperty + SnowflakeTargetCustomTestTemplate, + SnowflakeTargetRecordMissingKeyProperty, ): """Test that the target can handle batch messages.""" @@ -125,7 +134,8 @@ class SnowflakeTargetBatchRecordMissingKeyProperty( class SnowflakeTargetBatchRecordMissingRequiredProperty( - SnowflakeTargetCustomTestTemplate, SnowflakeTargetRecordMissingRequiredProperty + SnowflakeTargetCustomTestTemplate, + SnowflakeTargetRecordMissingRequiredProperty, ): """Test that the target can handle batch messages.""" @@ -133,19 +143,21 @@ class SnowflakeTargetBatchRecordMissingRequiredProperty( class SnowflakeTargetBatchSchemaNoProperties( - SnowflakeTargetCustomTestTemplate, SnowflakeTargetSchemaNoProperties + SnowflakeTargetCustomTestTemplate, + SnowflakeTargetSchemaNoProperties, ): """Test that the target can handle batch messages.""" name = "batch_schema_no_properties" - stream_names = [ + stream_names: t.ClassVar[list[str]] = [ "test_batch_object_schema_with_properties", 
"test_batch_object_schema_no_properties", ] class SnowflakeTargetBatchSchemaUpdates( - SnowflakeTargetCustomTestTemplate, SnowflakeTargetSchemaUpdates + SnowflakeTargetCustomTestTemplate, + SnowflakeTargetSchemaUpdates, ): """Test that the target can handle batch messages.""" @@ -153,7 +165,8 @@ class SnowflakeTargetBatchSchemaUpdates( class SnowflakeTargetBatchSpecialCharsInAttributes( - SnowflakeTargetCustomTestTemplate, TargetSpecialCharsInAttributes + SnowflakeTargetCustomTestTemplate, + TargetSpecialCharsInAttributes, ): """Test that the target can handle batch messages.""" diff --git a/tests/conftest.py b/tests/conftest.py index fb69280..4133230 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ """Test Configuration.""" +from __future__ import annotations -pytest_plugins = ("singer_sdk.testing.pytest_plugin",) \ No newline at end of file +pytest_plugins = ("singer_sdk.testing.pytest_plugin",) diff --git a/tests/core.py b/tests/core.py index 683a8e8..8514429 100644 --- a/tests/core.py +++ b/tests/core.py @@ -1,4 +1,5 @@ -import typing as t +from __future__ import annotations + from pathlib import Path import pytest @@ -28,7 +29,9 @@ class SnowflakeTargetArrayData(TargetArrayData): def validate(self) -> None: connector = self.target.default_sink_class.connector_class(self.target.config) - table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.test_{self.name}".upper() + table = ( + f"{self.target.config['database']}.{self.target.config['default_target_schema']}.test_{self.name}".upper() + ) result = connector.connection.execute( f"select * from {table} order by 1", ) @@ -42,7 +45,7 @@ def validate(self) -> None: assert row[1] == '[\n "apple",\n "orange",\n "pear"\n]' table_schema = connector.get_table(table) expected_types = { - "id": sct._CUSTOM_DECIMAL, + "id": sct._CUSTOM_DECIMAL, # noqa: SLF001 "fruits": sct.VARIANT, "_sdc_extracted_at": sct.TIMESTAMP_NTZ, "_sdc_batched_at": sct.TIMESTAMP_NTZ, @@ -59,10 +62,10 @@ def validate(self) -> None: class SnowflakeTargetCamelcaseComplexSchema(TargetCamelcaseComplexSchema): def validate(self) -> None: connector = self.target.default_sink_class.connector_class(self.target.config) - table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.ForecastingTypeToCategory".upper() + table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.ForecastingTypeToCategory".upper() # noqa: E501 table_schema = connector.get_table(table) expected_types = { - "id": sct._CUSTOM_DECIMAL, + "id": sct._CUSTOM_DECIMAL, # noqa: SLF001 "isdeleted": sqlalchemy.types.BOOLEAN, "createddate": sct.TIMESTAMP_NTZ, "createdbyid": sct.STRING, @@ -91,7 +94,9 @@ def validate(self) -> None: class SnowflakeTargetDuplicateRecords(TargetDuplicateRecords): def validate(self) -> None: connector = self.target.default_sink_class.connector_class(self.target.config) - table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.test_{self.name}".upper() + table = ( + f"{self.target.config['database']}.{self.target.config['default_target_schema']}.test_{self.name}".upper() + ) result = connector.connection.execute( f"select * from {table} order by 1", ) @@ -128,7 +133,9 @@ def stream_name(self) -> str: def validate(self) -> None: connector = self.target.default_sink_class.connector_class(self.target.config) - table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{self.stream_name}".upper() + table = ( + 
f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{self.stream_name}".upper() + ) connector.connection.execute( f"select * from {table} order by 1", ) @@ -151,13 +158,15 @@ def validate(self) -> None: class SnowflakeTargetEncodedStringData(TargetEncodedStringData): @property - def stream_names(self) -> t.List[str]: + def stream_names(self) -> list[str]: return ["test_strings", "test_strings_in_objects", "test_strings_in_arrays"] def validate(self) -> None: connector = self.target.default_sink_class.connector_class(self.target.config) for table_name in self.stream_names: - table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{table_name}".upper() + table = ( + f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{table_name}".upper() + ) connector.connection.execute( f"select * from {table} order by 1", ) @@ -166,20 +175,20 @@ def validate(self) -> None: class SnowflakeTargetInvalidSchemaTest(TargetInvalidSchemaTest): def test(self) -> None: - with pytest.raises(Exception): + with pytest.raises(Exception): # noqa: B017, PT011 self.runner.sync_all() class SnowflakeTargetRecordBeforeSchemaTest(TargetRecordBeforeSchemaTest): def test(self) -> None: - with pytest.raises(Exception): + with pytest.raises(Exception): # noqa: B017, PT011 self.runner.sync_all() class SnowflakeTargetRecordMissingKeyProperty(TargetRecordMissingKeyProperty): def test(self) -> None: # TODO: catch exact exception, currently snowflake throws an integrity error - with pytest.raises(Exception): + with pytest.raises(Exception): # noqa: B017, PT011 self.runner.sync_all() @@ -188,19 +197,19 @@ def test(self) -> None: # `ERROR_ON_COLUMN_COUNT_MISMATCH=FALSE` class SnowflakeTargetOptionalAttributes(TargetOptionalAttributes): def test(self) -> None: - with pytest.raises(Exception): + with pytest.raises(Exception): # noqa: B017, PT011 self.runner.sync_all() class SnowflakeTargetRecordMissingRequiredProperty(TargetRecordMissingRequiredProperty): def test(self) -> None: - with pytest.raises(Exception): + with pytest.raises(Exception): # noqa: B017, PT011 self.runner.sync_all() class SnowflakeTargetSchemaNoProperties(TargetSchemaNoProperties): @property - def stream_names(self) -> t.List[str]: + def stream_names(self) -> list[str]: return [ "test_object_schema_with_properties", "test_object_schema_no_properties", @@ -209,9 +218,11 @@ def stream_names(self) -> t.List[str]: def validate(self) -> None: for table_name in self.stream_names: connector = self.target.default_sink_class.connector_class( - self.target.config + self.target.config, + ) + table = ( + f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{table_name}".upper() ) - table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.{table_name}".upper() result = connector.connection.execute( f"select * from {table} order by 1", ) @@ -240,7 +251,9 @@ def validate(self) -> None: class SnowflakeTargetSchemaUpdates(TargetSchemaUpdates): def validate(self) -> None: connector = self.target.default_sink_class.connector_class(self.target.config) - table = f"{self.target.config['database']}.{self.target.config['default_target_schema']}.test_{self.name}".upper() + table = ( + f"{self.target.config['database']}.{self.target.config['default_target_schema']}.test_{self.name}".upper() + ) result = connector.connection.execute( f"select * from {table} order by 1", ) @@ -378,7 +391,7 @@ def setup(self) -> None: _SDC_TABLE_VERSION 
NUMBER(38,0), PRIMARY KEY (ID) ) - """ + """, ) def validate(self) -> None: @@ -391,8 +404,8 @@ def validate(self) -> None: row = result.first() assert len(row) == 12 -class SnowflakeTargetExistingTableAlter(SnowflakeTargetExistingTable): +class SnowflakeTargetExistingTableAlter(SnowflakeTargetExistingTable): name = "existing_table_alter" # This sends a schema that will request altering from TIMESTAMP_NTZ to VARCHAR @@ -416,7 +429,7 @@ def setup(self) -> None: _SDC_TABLE_VERSION NUMBER(38,0), PRIMARY KEY (ID) ) - """ + """, ) @@ -449,9 +462,7 @@ def validate(self) -> None: isinstance(column.type, expected_types[column.name]) - class SnowflakeTargetColumnOrderMismatch(TargetFileTestTemplate): - name = "column_order_mismatch" def setup(self) -> None: @@ -465,7 +476,7 @@ def setup(self) -> None: COL3 TIMESTAMP_NTZ(9), COL2 BOOLEAN ) - """ + """, ) @property @@ -473,6 +484,7 @@ def singer_filepath(self) -> Path: current_dir = Path(__file__).resolve().parent return current_dir / "target_test_streams" / f"{self.name}.singer" + target_tests = TestSuite( kind="target", tests=[ diff --git a/tests/target_test_streams/batch_multiple_state_messages__a.jsonl b/tests/target_test_streams/batch_multiple_state_messages__a.jsonl index 06ef31f..4c9f19c 100644 --- a/tests/target_test_streams/batch_multiple_state_messages__a.jsonl +++ b/tests/target_test_streams/batch_multiple_state_messages__a.jsonl @@ -3,4 +3,4 @@ {"type": "RECORD", "stream": "test_multiple_state_messages_a", "record": {"id": 3, "metric": 300}} {"type": "RECORD", "stream": "test_multiple_state_messages_a", "record": {"id": 4, "metric": 400}} {"type": "RECORD", "stream": "test_multiple_state_messages_a", "record": {"id": 5, "metric": 500}} -{"type": "RECORD", "stream": "test_multiple_state_messages_a", "record": {"id": 6, "metric": 600}} \ No newline at end of file +{"type": "RECORD", "stream": "test_multiple_state_messages_a", "record": {"id": 6, "metric": 600}} diff --git a/tests/target_test_streams/batch_multiple_state_messages__b.jsonl b/tests/target_test_streams/batch_multiple_state_messages__b.jsonl index 602caca..a7b51a0 100644 --- a/tests/target_test_streams/batch_multiple_state_messages__b.jsonl +++ b/tests/target_test_streams/batch_multiple_state_messages__b.jsonl @@ -3,4 +3,4 @@ {"type": "RECORD", "stream": "test_multiple_state_messages_b", "record": {"id": 3, "metric": 330}} {"type": "RECORD", "stream": "test_multiple_state_messages_b", "record": {"id": 4, "metric": 440}} {"type": "RECORD", "stream": "test_multiple_state_messages_b", "record": {"id": 5, "metric": 550}} -{"type": "RECORD", "stream": "test_multiple_state_messages_b", "record": {"id": 6, "metric": 660}} \ No newline at end of file +{"type": "RECORD", "stream": "test_multiple_state_messages_b", "record": {"id": 6, "metric": 660}} diff --git a/tests/target_test_streams/batch_record_missing_required_property.singer b/tests/target_test_streams/batch_record_missing_required_property.singer index 998e8b4..09c5eb0 100644 --- a/tests/target_test_streams/batch_record_missing_required_property.singer +++ b/tests/target_test_streams/batch_record_missing_required_property.singer @@ -1,3 +1,2 @@ {"type": "SCHEMA", "stream": "test_batch_record_missing_required_property", "key_properties": [], "schema": {"required": ["id"], "type": "object", "properties": {"id": {"type": "integer"}, "metric": {"type": "integer"}}}} {"type": "BATCH", "stream": "test_batch_record_missing_required_property", "encoding": {"format": "jsonl"}, "manifest": 
["file://tests/target_test_streams/batch_record_missing_required_property.jsonl"]} - diff --git a/tests/target_test_streams/type_edge_cases.singer b/tests/target_test_streams/type_edge_cases.singer index 72d5f17..0f79324 100644 --- a/tests/target_test_streams/type_edge_cases.singer +++ b/tests/target_test_streams/type_edge_cases.singer @@ -1,2 +1,2 @@ {"type": "SCHEMA", "stream": "type_edge_cases", "key_properties": ["id"], "schema": {"required": ["id"], "type": "object", "properties": {"id": {"type": "integer"}, "col_max_length_str": {"maxLength": 4294967295, "type": [ "null", "string" ] }, "col_multiple_of": {"multipleOf": 0.0001, "type": [ "null", "number" ] }, "col_multiple_of_int": {"multipleOf": 10, "type": [ "null", "number" ] }}}} -{"type": "RECORD", "stream": "type_edge_cases", "record": {"id": 1, "col_max_length_str": "foo", "col_multiple_of": 123.456, "col_multiple_of_int": 100}} \ No newline at end of file +{"type": "RECORD", "stream": "type_edge_cases", "record": {"id": 1, "col_max_length_str": "foo", "col_multiple_of": 123.456, "col_multiple_of_int": 100}} diff --git a/tests/test_target_snowflake.py b/tests/test_target_snowflake.py index fa37892..8a20d9a 100644 --- a/tests/test_target_snowflake.py +++ b/tests/test_target_snowflake.py @@ -32,11 +32,11 @@ class BaseSnowflakeTargetTests: @pytest.fixture() def connection(self, runner): return runner.singer_class.default_sink_class.connector_class( - runner.config + runner.config, ).connection @pytest.fixture() - def resource(self, runner, connection): # noqa: ANN201 + def resource(self, runner, connection): # noqa: PT004 """Generic external resource. This fixture is useful for setup and teardown of external resources, @@ -46,19 +46,17 @@ def resource(self, runner, connection): # noqa: ANN201 https://github.com/meltano/sdk/tree/main/tests/samples """ connection.execute( - f"create schema {runner.config['database']}.{runner.config['default_target_schema']}" + f"create schema {runner.config['database']}.{runner.config['default_target_schema']}", ) yield connection.execute( - f"drop schema if exists {runner.config['database']}.{runner.config['default_target_schema']}" + f"drop schema if exists {runner.config['database']}.{runner.config['default_target_schema']}", ) # Custom so I can implement all validate methods STANDARD_TEST_CONFIG = copy.deepcopy(SAMPLE_CONFIG) -STANDARD_TEST_CONFIG[ - "default_target_schema" -] = f"TARGET_SNOWFLAKE_{uuid.uuid4().hex[0:6]!s}" +STANDARD_TEST_CONFIG["default_target_schema"] = f"TARGET_SNOWFLAKE_{uuid.uuid4().hex[0:6]!s}" StandardTargetTests = get_target_test_class( target_class=TargetSnowflake, config=STANDARD_TEST_CONFIG, @@ -68,15 +66,13 @@ def resource(self, runner, connection): # noqa: ANN201 ) -class TestTargetSnowflake(BaseSnowflakeTargetTests, StandardTargetTests): # type: ignore[misc, valid-type] # noqa: E501 +class TestTargetSnowflake(BaseSnowflakeTargetTests, StandardTargetTests): # type: ignore[misc, valid-type] """Standard Target Tests.""" # Custom so I can implement all validate methods BATCH_TEST_CONFIG = copy.deepcopy(SAMPLE_CONFIG) -BATCH_TEST_CONFIG[ - "default_target_schema" -] = f"TARGET_SNOWFLAKE_{uuid.uuid4().hex[0:6]!s}" +BATCH_TEST_CONFIG["default_target_schema"] = f"TARGET_SNOWFLAKE_{uuid.uuid4().hex[0:6]!s}" BATCH_TEST_CONFIG["add_record_metadata"] = False BatchTargetTests = get_target_test_class( target_class=TargetSnowflake, @@ -87,16 +83,17 @@ class TestTargetSnowflake(BaseSnowflakeTargetTests, StandardTargetTests): # typ ) -class 
TestTargetSnowflakeBatch(BaseSnowflakeTargetTests, BatchTargetTests): # type: ignore[misc, valid-type] # noqa: E501 +class TestTargetSnowflakeBatch(BaseSnowflakeTargetTests, BatchTargetTests): # type: ignore[misc, valid-type] """Batch Target Tests.""" + def test_invalid_database(): - INVALID_TEST_CONFIG = copy.deepcopy(SAMPLE_CONFIG) + INVALID_TEST_CONFIG = copy.deepcopy(SAMPLE_CONFIG) # noqa: N806 INVALID_TEST_CONFIG["database"] = "FOO_BAR_DOESNT_EXIST" runner = TargetTestRunner( TargetSnowflake, config=INVALID_TEST_CONFIG, - input_filepath="tests/target_test_streams/existing_table.singer" + input_filepath="tests/target_test_streams/existing_table.singer", ) - with pytest.raises(Exception): - runner.sync_all() \ No newline at end of file + with pytest.raises(Exception): # noqa: B017, PT011 + runner.sync_all() diff --git a/tox.ini b/tox.ini index d277733..72d423c 100644 --- a/tox.ini +++ b/tox.ini @@ -27,4 +27,3 @@ commands = poetry install -v poetry run coverage run -m pytest --capture=no {posargs} poetry run coverage html -d tests/codecoverage -
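
A closing note on the other recurring idiom in this patch: ruff's new required-imports setting forces `from __future__ import annotations` into every module, which makes all annotations lazily evaluated strings, so PEP 585/604 syntax like `list[str] | None` (used throughout the rewritten signatures) parses even on Python 3.7; combined with TYPE_CHECKING guards, annotation-only imports such as sqlalchemy.engine.Engine also disappear at runtime. A minimal sketch of the combined pattern (the function and names are illustrative):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by type checkers only; never imported at runtime.
        from sqlalchemy.engine import Engine


    def pick_engine(engines: list[Engine] | None) -> Engine | None:
        """Return the first configured engine, if any."""
        return engines[0] if engines else None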