Skip to content

Commit

Permalink
fix: Invalid f-string syntax in Python 3.7
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon committed Nov 21, 2023
1 parent b589260 commit c830e36
Showing 1 changed file with 58 additions and 28 deletions.
86 changes: 58 additions & 28 deletions target_snowflake/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, *args, **kwargs) -> None:
self.table_cache: dict = {}
self.schema_cache: dict = {}
super().__init__(*args, **kwargs)

def get_table_columns(
self,
full_table_name: str,
Expand Down Expand Up @@ -87,7 +87,8 @@ def get_table_columns(
)
for col_meta in columns
if not column_names
or col_meta["name"].casefold() in {col.casefold() for col in column_names}
or col_meta["name"].casefold()
in {col.casefold() for col in column_names}
}
self.table_cache[full_table_name] = parsed_columns
return parsed_columns
Expand Down Expand Up @@ -118,7 +119,9 @@ def get_sqlalchemy_url(self, config: dict) -> str:
if "password" in config:
params["password"] = config["password"]
elif "private_key_path" not in config:
raise Exception("Neither password nor private_key_path was provided for authentication.")
raise Exception(
"Neither password nor private_key_path was provided for authentication."
)

for option in ["warehouse", "role"]:
if config.get(option):
Expand Down Expand Up @@ -149,7 +152,9 @@ def create_engine(self) -> Engine:
with open(self.config["private_key_path"], "rb") as private_key_file:
private_key = serialization.load_pem_private_key(
private_key_file.read(),
password=self.config["private_key_passphrase"].encode() if "private_key_passphrase" in self.config else None,
password=self.config["private_key_passphrase"].encode()
if "private_key_passphrase" in self.config
else None,
backend=default_backend(),
)
connect_args["private_key"] = private_key.private_bytes(
Expand All @@ -163,9 +168,13 @@ def create_engine(self) -> Engine:
echo=False,
)
connection = engine.connect()
db_names = [db[1] for db in connection.execute(text("SHOW DATABASES;")).fetchall()]
db_names = [
db[1] for db in connection.execute(text("SHOW DATABASES;")).fetchall()
]
if self.config["database"] not in db_names:
raise Exception(f"Database '{self.config['database']}' does not exist or the user/role doesn't have access to it.")
raise Exception(
f"Database '{self.config['database']}' does not exist or the user/role doesn't have access to it."
)
return engine

def prepare_column(
Expand All @@ -188,7 +197,9 @@ def prepare_column(
sql_type,
)
except Exception as e:
self.logger.error(f"Error preparing column for {full_table_name=} {column_name=}")
self.logger.error(
f"Error preparing column for {full_table_name=} {column_name=}"
)
raise e

@staticmethod
Expand Down Expand Up @@ -305,17 +316,20 @@ def _get_put_statement(self, sync_id: str, file_uri: str) -> Tuple[text, dict]:
@staticmethod
def _format_column_selections(column_selections: dict, format: str) -> str:
if format == "json_casting":
return ', '.join(
return ", ".join(
[
f"$1:{col['clean_property_name']}::{col['sql_type']} as {col['clean_alias']}" for col in column_selections
f"$1:{col['clean_property_name']}::{col['sql_type']} as {col['clean_alias']}"
for col in column_selections
]
)
elif format == "col_alias":
return f"({', '.join([col['clean_alias'] for col in column_selections])})"
else:
raise NotImplementedError(f"Column format not implemented: {format}")

def _get_column_selections(self, schema: dict, formatter: SnowflakeIdentifierPreparer) -> list:
def _get_column_selections(
self, schema: dict, formatter: SnowflakeIdentifierPreparer
) -> list:
column_selections = []
for property_name, property_def in schema["properties"].items():
clean_property_name = formatter.format_collation(property_name)
Expand All @@ -338,20 +352,26 @@ def _get_merge_from_stage_statement(

formatter = SnowflakeIdentifierPreparer(SnowflakeDialect())
column_selections = self._get_column_selections(schema, formatter)
json_casting_selects = self._format_column_selections(column_selections, "json_casting")
json_casting_selects = self._format_column_selections(
column_selections, "json_casting"
)

# use UPPER from here onwards
formatted_properties = [formatter.format_collation(col) for col in schema["properties"].keys()]
formatted_key_properties = [formatter.format_collation(col) for col in key_properties]
formatted_properties = [
formatter.format_collation(col) for col in schema["properties"].keys()
]
formatted_key_properties = [
formatter.format_collation(col) for col in key_properties
]
join_expr = " and ".join(
[f'd.{key} = s.{key}' for key in formatted_key_properties]
[f"d.{key} = s.{key}" for key in formatted_key_properties]
)
matched_clause = ", ".join(
[f'd.{col} = s.{col}' for col in formatted_properties]
[f"d.{col} = s.{col}" for col in formatted_properties]
)
not_matched_insert_cols = ", ".join(formatted_properties)
not_matched_insert_values = ", ".join(
[f's.{col}' for col in formatted_properties]
[f"s.{col}" for col in formatted_properties]
)
dedup_cols = ", ".join([key for key in formatted_key_properties])
dedup = f"QUALIFY ROW_NUMBER() OVER (PARTITION BY {dedup_cols} ORDER BY SEQ8() DESC) = 1"
Expand All @@ -372,8 +392,12 @@ def _get_copy_statement(self, full_table_name, schema, sync_id, file_format):
"""Get Snowflake COPY statement."""
formatter = SnowflakeIdentifierPreparer(SnowflakeDialect())
column_selections = self._get_column_selections(schema, formatter)
json_casting_selects = self._format_column_selections(column_selections, "json_casting")
col_alias_selects = self._format_column_selections(column_selections, "col_alias")
json_casting_selects = self._format_column_selections(
column_selections, "json_casting"
)
col_alias_selects = self._format_column_selections(
column_selections, "col_alias"
)
return (
text(
f"copy into {full_table_name} {col_alias_selects} from "
Expand Down Expand Up @@ -437,7 +461,8 @@ def create_file_format(self, file_format: str) -> None:
file_format=file_format
)
self.logger.debug(
f"Creating file format with SQL: {file_format_statement!s}"
"Creating file format with SQL: %s",
file_format_statement,
)
conn.execute(file_format_statement, **kwargs)

Expand All @@ -464,7 +489,7 @@ def merge_from_stage(
file_format=file_format,
key_properties=key_properties,
)
self.logger.debug(f"Merging with SQL: {merge_statement!s}")
self.logger.debug("Merging with SQL: %s", merge_statement)
conn.execute(merge_statement, **kwargs)

def copy_from_stage(
Expand All @@ -485,7 +510,7 @@ def copy_from_stage(
sync_id=sync_id,
file_format=file_format,
)
self.logger.debug(f"Copying with SQL: {copy_statement!s}")
self.logger.debug("Copying with SQL: %s", copy_statement)
conn.execute(copy_statement, **kwargs)

def drop_file_format(self, file_format: str) -> None:
Expand All @@ -498,7 +523,7 @@ def drop_file_format(self, file_format: str) -> None:
drop_statement, kwargs = self._get_drop_file_format_statement(
file_format=file_format
)
self.logger.debug(f"Dropping file format with SQL: {drop_statement!s}")
self.logger.debug("Dropping file format with SQL: %s", drop_statement)
conn.execute(drop_statement, **kwargs)

def remove_staged_files(self, sync_id: str) -> None:
Expand All @@ -511,7 +536,7 @@ def remove_staged_files(self, sync_id: str) -> None:
remove_statement, kwargs = self._get_stage_files_remove_statement(
sync_id=sync_id
)
self.logger.debug(f"Removing staged files with SQL: {remove_statement!s}")
self.logger.debug("Removing staged files with SQL: %s", remove_statement)
conn.execute(remove_statement, **kwargs)

@staticmethod
Expand Down Expand Up @@ -558,7 +583,7 @@ def get_initialize_script(role, user, password, warehouse, database):
grant CREATE SCHEMA, MONITOR, USAGE
on database {database}
to role {role};
commit;
"""
Expand All @@ -579,13 +604,18 @@ def _adapt_column_type(
Raises:
NotImplementedError: if altering columns is not supported.
"""

try:
super()._adapt_column_type(full_table_name, column_name, sql_type)
except Exception as e:
except Exception:
current_type: sqlalchemy.types.TypeEngine = self._get_column_type(
full_table_name,
column_name,
)
self.logger.error(f"Error adapting column type for {full_table_name=} {column_name=}, {current_type=} {sql_type=} (new sql type)")
raise e
self.logger.exception(
"Error adapting column type for '%s.%s', '%s' to '%s' (new sql type)",
full_table_name,
column_name,
current_type,
sql_type,
)
raise

0 comments on commit c830e36

Please sign in to comment.