From f1e8c07fb71d1a6b7551b01ebaf1920405da835f Mon Sep 17 00:00:00 2001 From: Adam Ling Date: Sat, 19 Nov 2022 23:46:24 -0800 Subject: [PATCH] improve coding style --- src/snowflake/sqlalchemy/snowdialect.py | 79 ++++++++++--------------- 1 file changed, 31 insertions(+), 48 deletions(-) diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 4f641ac6..fbb63b5a 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -319,21 +319,17 @@ def get_check_constraints(self, connection, table_name, schema, **kw): @reflection.cache def _get_schema_primary_keys(self, connection, schema, table_name=None, **kw): - if table_name is not None: - fully_qualified_path = self._denormalize_quote_join( - schema, self.denormalize_name(table_name) - ) - result = connection.execute( - text( - f"SHOW /* sqlalchemy:_get_schema_primary_keys */ PRIMARY KEYS IN TABLE {fully_qualified_path}" - ) - ) - else: - result = connection.execute( - text( - f"SHOW /* sqlalchemy:_get_schema_primary_keys */PRIMARY KEYS IN SCHEMA {schema}" - ) + fully_qualified_path = ( + self._denormalize_quote_join(schema, self.denormalize_name(table_name)) + if table_name is not None + else schema + ) + result = connection.execute( + text( + f"SHOW /* sqlalchemy:_get_schema_primary_keys */ PRIMARY KEYS IN " + f"{'TABLE' if table_name is not None else 'SCHEMA'} {fully_qualified_path}" ) + ) ans = {} for row in result: table_name = self.normalize_name(row._mapping["table_name"]) @@ -364,19 +360,17 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw): @reflection.cache def _get_schema_unique_constraints(self, connection, schema, table_name=None, **kw): - if table_name is not None: - fully_qualified_path = self._denormalize_quote_join(schema, table_name) - result = connection.execute( - text( - f"SHOW /* sqlalchemy:_get_schema_unique_constraints */ UNIQUE KEYS IN TABLE {fully_qualified_path}" - ) - ) - else: - result = connection.execute( - text( - f"SHOW /* sqlalchemy:_get_schema_unique_constraints */ UNIQUE KEYS IN SCHEMA {schema}" - ) + fully_qualified_path = ( + self._denormalize_quote_join(schema, self.denormalize_name(table_name)) + if table_name is not None + else schema + ) + result = connection.execute( + text( + f"SHOW /* sqlalchemy:_get_schema_unique_constraints */ UNIQUE KEYS IN " + f"{'TABLE' if table_name is not None else 'SCHEMA'} {fully_qualified_path}" ) + ) unique_constraints = {} for row in result: name = self.normalize_name(row._mapping["constraint_name"]) @@ -415,21 +409,17 @@ def get_unique_constraints(self, connection, table_name, schema, **kw): @reflection.cache def _get_schema_foreign_keys(self, connection, schema, table_name=None, **kw): _, current_schema = self._current_database_schema(connection, **kw) - if table_name is not None: - fully_qualified_path = self._denormalize_quote_join( - schema, self.denormalize_name(table_name) - ) - result = connection.execute( - text( - f"SHOW /* sqlalchemy:_get_schema_foreign_keys */ IMPORTED KEYS IN TABLE {fully_qualified_path}" - ) - ) - else: - result = connection.execute( - text( - f"SHOW /* sqlalchemy:_get_schema_foreign_keys */ IMPORTED KEYS IN SCHEMA {schema}" - ) + fully_qualified_path = ( + self._denormalize_quote_join(schema, self.denormalize_name(table_name)) + if table_name is not None + else schema + ) + result = connection.execute( + text( + f"SHOW /* sqlalchemy:_get_schema_foreign_keys */ IMPORTED KEYS IN " + f"{'TABLE' if table_name is not None else 'SCHEMA'} {fully_qualified_path}" ) + ) foreign_key_map = {} for row in result: name = self.normalize_name(row._mapping["fk_name"]) @@ -711,14 +701,7 @@ def get_columns(self, connection, table_name, schema=None, **kw): if not schema: _, schema = self._current_database_schema(connection, **kw) - if table_name is not None: - return self._get_table_columns(connection, table_name, schema, **kw) - else: - schema_columns = self._get_schema_columns(connection, schema, **kw) - if schema_columns is None: - # Too many results, fall back to only query about single table - return self._get_table_columns(connection, table_name, schema, **kw) - return schema_columns[self.normalize_name(table_name)] + return self._get_table_columns(connection, table_name, schema, **kw) @reflection.cache def get_table_names(self, connection, schema=None, **kw):