Update snowdialect.py
Implement optimization changes suggested by sfc-gh-aling. Thank you!
sfc-gh-kterada authored and sfc-gh-aling committed Nov 20, 2022
1 parent c824dac commit b7d6e31
Showing 1 changed file with 66 additions and 161 deletions.
227 changes: 66 additions & 161 deletions src/snowflake/sqlalchemy/snowdialect.py
@@ -318,35 +318,22 @@ def get_check_constraints(self, connection, table_name, schema, **kw):
         return []
 
     @reflection.cache
-    def _get_table_primary_keys(self, connection, schema, table_name, **kw):
-        fully_qualified_path = self._denormalize_quote_join(
-            schema, self.denormalize_name(table_name)
-        )
-        result = connection.execute(
-            text(
-                f"SHOW /* sqlalchemy:_get_table_primary_keys */ PRIMARY KEYS IN TABLE {fully_qualified_path}"
-            )
-        )
-        ans = {}
-        for row in result:
-            table_name = self.normalize_name(row._mapping["table_name"])
-            if table_name not in ans:
-                ans[table_name] = {
-                    "constrained_columns": [],
-                    "name": self.normalize_name(row._mapping["constraint_name"]),
-                }
-            ans[table_name]["constrained_columns"].append(
-                self.normalize_name(row._mapping["column_name"])
-            )
-        return ans
-
-    @reflection.cache
-    def _get_schema_primary_keys(self, connection, schema, **kw):
-        result = connection.execute(
-            text(
-                f"SHOW /* sqlalchemy:_get_schema_primary_keys */PRIMARY KEYS IN SCHEMA {schema}"
-            )
-        )
+    def _get_schema_primary_keys(self, connection, schema, table_name=None, **kw):
+        if table_name is not None:
+            fully_qualified_path = self._denormalize_quote_join(
+                schema, self.denormalize_name(table_name)
+            )
+            result = connection.execute(
+                text(
+                    f"SHOW /* sqlalchemy:_get_schema_primary_keys */ PRIMARY KEYS IN TABLE {fully_qualified_path}"
+                )
+            )
+        else:
+            result = connection.execute(
+                text(
+                    f"SHOW /* sqlalchemy:_get_schema_primary_keys */PRIMARY KEYS IN SCHEMA {schema}"
+                )
+            )
         ans = {}
         for row in result:
             table_name = self.normalize_name(row._mapping["table_name"])
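
For orientation, a rough sketch of how the merged helper behaves in its two modes. The `dialect`, `conn`, and schema/table names below are illustrative placeholders, not part of the commit; the return shape follows from the loop shown above.

    # Hypothetical call sites; `dialect` is a SnowflakeDialect instance and
    # `conn` an open SQLAlchemy Connection.
    pks = dialect._get_schema_primary_keys(conn, "MYDB.MYSCHEMA", table_name="MYTABLE")
    # -> roughly: SHOW PRIMARY KEYS IN TABLE MYDB.MYSCHEMA.MYTABLE
    all_pks = dialect._get_schema_primary_keys(conn, "MYDB.MYSCHEMA")
    # -> roughly: SHOW PRIMARY KEYS IN SCHEMA MYDB.MYSCHEMA
    # Either way the result maps normalized table name to a constraint dict, e.g.
    # {"mytable": {"constrained_columns": ["id"], "name": "example_pk_name"}}

Folding the per-table variant into the single @reflection.cache-decorated schema-level method keeps one cached entry point, which appears to be the optimization the commit message refers to.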
@@ -368,54 +355,28 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
         full_schema_name = self._denormalize_quote_join(
             current_database, schema if schema else current_schema
         )
-
-        if table_name is not None:
-            return self._get_table_primary_keys(
-                connection,
-                self.denormalize_name(full_schema_name),
-                self.denormalize_name(table_name),
-                **kw,
-            ).get(table_name, {"constrained_columns": [], "name": None})
-        else:
-            return self._get_schema_primary_keys(
-                connection, self.denormalize_name(full_schema_name), **kw
-            ).get(table_name, {"constrained_columns": [], "name": None})
+        return self._get_schema_primary_keys(
+            connection,
+            self.denormalize_name(full_schema_name),
+            table_name=self.denormalize_name(table_name),
+            **kw,
+        ).get(table_name, {"constrained_columns": [], "name": None})
 
     @reflection.cache
-    def _get_table_unique_constraints(self, connection, schema, table_name, **kw):
-        fully_qualified_path = self._denormalize_quote_join(schema, table_name)
-        result = connection.execute(
-            text(
-                f"SHOW /* sqlalchemy:_get_table_unique_constraints */ UNIQUE KEYS IN TABLE {fully_qualified_path}"
-            )
-        )
-        unique_constraints = {}
-        for row in result:
-            name = self.normalize_name(row._mapping["constraint_name"])
-            if name not in unique_constraints:
-                unique_constraints[name] = {
-                    "column_names": [self.normalize_name(row._mapping["column_name"])],
-                    "name": name,
-                    "table_name": self.normalize_name(row._mapping["table_name"]),
-                }
-            else:
-                unique_constraints[name]["column_names"].append(
-                    self.normalize_name(row._mapping["column_name"])
-                )
-
-        ans = defaultdict(list)
-        for constraint in unique_constraints.values():
-            table_name = constraint.pop("table_name")
-            ans[table_name].append(constraint)
-        return ans
-
-    @reflection.cache
-    def _get_schema_unique_constraints(self, connection, schema, **kw):
-        result = connection.execute(
-            text(
-                f"SHOW /* sqlalchemy:_get_schema_unique_constraints */ UNIQUE KEYS IN SCHEMA {schema}"
-            )
-        )
+    def _get_schema_unique_constraints(self, connection, schema, table_name=None, **kw):
+        if table_name is not None:
+            fully_qualified_path = self._denormalize_quote_join(schema, table_name)
+            result = connection.execute(
+                text(
+                    f"SHOW /* sqlalchemy:_get_schema_unique_constraints */ UNIQUE KEYS IN TABLE {fully_qualified_path}"
+                )
+            )
+        else:
+            result = connection.execute(
+                text(
+                    f"SHOW /* sqlalchemy:_get_schema_unique_constraints */ UNIQUE KEYS IN SCHEMA {schema}"
+                )
+            )
         unique_constraints = {}
         for row in result:
             name = self.normalize_name(row._mapping["constraint_name"])
@@ -444,91 +405,31 @@ def get_unique_constraints(self, connection, table_name, schema, **kw):
         full_schema_name = self._denormalize_quote_join(
             current_database, schema if schema else current_schema
         )
-        if table_name is not None:
-            return self._get_table_unique_constraints(
-                connection,
-                self.denormalize_name(full_schema_name),
-                self.denormalize_name(table_name),
-                **kw,
-            ).get(table_name, [])
-        else:
-            return self._get_schema_unique_constraints(
-                connection, self.denormalize_name(full_schema_name), **kw
-            ).get(table_name, [])
+        return self._get_schema_unique_constraints(
+            connection,
+            self.denormalize_name(full_schema_name),
+            table_name=self.denormalize_name(table_name),
+            **kw,
+        ).get(table_name, [])
 
     @reflection.cache
-    def _get_table_foreign_keys(self, connection, schema, table_name, **kw):
+    def _get_schema_foreign_keys(self, connection, schema, table_name=None, **kw):
         _, current_schema = self._current_database_schema(connection, **kw)
-        fully_qualified_path = self._denormalize_quote_join(
-            schema, self.denormalize_name(table_name)
-        )
-        result = connection.execute(
-            text(
-                f"SHOW /* sqlalchemy:_get_table_foreign_keys */ IMPORTED KEYS IN TABLE {fully_qualified_path}"
-            )
-        )
-        foreign_key_map = {}
-        for row in result:
-            name = self.normalize_name(row._mapping["fk_name"])
-            if name not in foreign_key_map:
-                referred_schema = self.normalize_name(row._mapping["pk_schema_name"])
-                foreign_key_map[name] = {
-                    "constrained_columns": [
-                        self.normalize_name(row._mapping["fk_column_name"])
-                    ],
-                    # referred schema should be None in context where it doesn't need to be specified
-                    # https://docs.sqlalchemy.org/en/14/core/reflection.html#reflection-schema-qualified-interaction
-                    "referred_schema": (
-                        referred_schema
-                        if referred_schema
-                        not in (self.default_schema_name, current_schema)
-                        else None
-                    ),
-                    "referred_table": self.normalize_name(
-                        row._mapping["pk_table_name"]
-                    ),
-                    "referred_columns": [
-                        self.normalize_name(row._mapping["pk_column_name"])
-                    ],
-                    "name": name,
-                    "table_name": self.normalize_name(row._mapping["fk_table_name"]),
-                }
-                options = {}
-                if self.normalize_name(row._mapping["delete_rule"]) != "NO ACTION":
-                    options["ondelete"] = self.normalize_name(
-                        row._mapping["delete_rule"]
-                    )
-                if self.normalize_name(row._mapping["update_rule"]) != "NO ACTION":
-                    options["onupdate"] = self.normalize_name(
-                        row._mapping["update_rule"]
-                    )
-                foreign_key_map[name]["options"] = options
-            else:
-                foreign_key_map[name]["constrained_columns"].append(
-                    self.normalize_name(row._mapping["fk_column_name"])
-                )
-                foreign_key_map[name]["referred_columns"].append(
-                    self.normalize_name(row._mapping["pk_column_name"])
-                )
-
-        ans = {}
-
-        for _, v in foreign_key_map.items():
-            if v["table_name"] not in ans:
-                ans[v["table_name"]] = []
-            ans[v["table_name"]].append(
-                {k2: v2 for k2, v2 in v.items() if k2 != "table_name"}
-            )
-        return ans
-
-    @reflection.cache
-    def _get_schema_foreign_keys(self, connection, schema, **kw):
-        _, current_schema = self._current_database_schema(connection, **kw)
-        result = connection.execute(
-            text(
-                f"SHOW /* sqlalchemy:_get_schema_foreign_keys */ IMPORTED KEYS IN SCHEMA {schema}"
-            )
-        )
+        if table_name is not None:
+            fully_qualified_path = self._denormalize_quote_join(
+                schema, self.denormalize_name(table_name)
+            )
+            result = connection.execute(
+                text(
+                    f"SHOW /* sqlalchemy:_get_schema_foreign_keys */ IMPORTED KEYS IN TABLE {fully_qualified_path}"
+                )
+            )
+        else:
+            result = connection.execute(
+                text(
+                    f"SHOW /* sqlalchemy:_get_schema_foreign_keys */ IMPORTED KEYS IN SCHEMA {schema}"
+                )
+            )
         foreign_key_map = {}
         for row in result:
            name = self.normalize_name(row._mapping["fk_name"])
@@ -595,17 +496,12 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
             current_database, schema if schema else current_schema
         )
 
-        if table_name is not None:
-            foreign_key_map = self._get_table_foreign_keys(
-                connection,
-                self.denormalize_name(full_schema_name),
-                self.denormalize_name(table_name),
-                **kw,
-            )
-        else:
-            foreign_key_map = self._get_schema_foreign_keys(
-                connection, self.denormalize_name(full_schema_name), **kw
-            )
+        foreign_key_map = self._get_schema_foreign_keys(
+            connection,
+            self.denormalize_name(full_schema_name),
+            table_name=self.denormalize_name(table_name),
+            **kw,
+        )
         return foreign_key_map.get(table_name, [])
 
     @reflection.cache
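
With this last call site updated, all three public reflection entry points (get_pk_constraint, get_unique_constraints, get_foreign_keys) funnel through the schema-level cached helpers and then pick out the requested table. An end-to-end sketch using the public SQLAlchemy Inspector API; the URL and object names are placeholders:

    from sqlalchemy import create_engine, inspect

    # Placeholder Snowflake URL; substitute real credentials and identifiers.
    engine = create_engine("snowflake://user:password@account/database/schema")
    insp = inspect(engine)

    # Each call delegates to the corresponding _get_schema_* helper with
    # table_name=... and falls back to an empty default when nothing is found.
    print(insp.get_pk_constraint("mytable", schema="myschema"))
    print(insp.get_unique_constraints("mytable", schema="myschema"))
    print(insp.get_foreign_keys("mytable", schema="myschema"))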
@@ -716,8 +612,8 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw):
         ans = []
         current_database, _ = self._current_database_schema(connection, **kw)
         full_schema_name = self._denormalize_quote_join(current_database, schema)
-        table_primary_keys = self._get_table_primary_keys(
-            connection, full_schema_name, table_name, **kw
+        table_primary_keys = self._get_schema_primary_keys(
+            connection, full_schema_name, table_name=table_name, **kw
         )
         result = connection.execute(
             text(
@@ -732,7 +628,9 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw):
                    ic.is_nullable,
                    ic.column_default,
                    ic.is_identity,
-                   ic.comment
+                   ic.comment,
+                   ic.identity_start,
+                   ic.identity_increment
               FROM information_schema.columns ic
              WHERE ic.table_schema=:table_schema
               AND ic.table_name=:table_name
@@ -754,6 +652,8 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw):
             column_default,
             is_identity,
             comment,
+            identity_start,
+            identity_increment,
         ) in result:
             table_name = self.normalize_name(table_name)
             column_name = self.normalize_name(column_name)
@@ -796,6 +696,11 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw):
                     else False,
                 }
             )
+            if is_identity == "YES":
+                ans[-1]["identity"] = {
+                    "start": identity_start,
+                    "increment": identity_increment,
+                }
         return ans
 
     def get_columns(self, connection, table_name, schema=None, **kw):
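
The remaining hunks are a separate improvement: the column query now also selects identity_start and identity_increment, and identity columns gain an "identity" entry in the reflected column dict, matching the format SQLAlchemy 1.4+ uses for Identity columns. A hedged illustration, assuming `engine` is set up as in the earlier sketch and that a table like the one below exists:

    from sqlalchemy import inspect

    # Assumes a Snowflake table created roughly as:
    #   CREATE TABLE events (id INTEGER IDENTITY(1, 1), payload VARCHAR)
    cols = inspect(engine).get_columns("events", schema="myschema")
    id_col = next(c for c in cols if c["name"] == "id")
    # With this change the reflected column carries the identity parameters, e.g.
    # id_col["identity"] == {"start": 1, "increment": 1}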
