Skip to content

Commit

Permalink
fix column-lineage handling of correlated subqueries (run pre-type-annotation optimizer rules; hoist joins/laterals/pivots when translating UPDATE ... FROM)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 committed Oct 25, 2023
1 parent 11e15cf commit 7656ea6
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 3 deletions.
36 changes: 33 additions & 3 deletions metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib
import enum
import functools
import itertools
import logging
import pathlib
from collections import defaultdict
Expand All @@ -11,8 +12,8 @@
import sqlglot.errors
import sqlglot.lineage
import sqlglot.optimizer.annotate_types
import sqlglot.optimizer.optimizer
import sqlglot.optimizer.qualify
import sqlglot.optimizer.qualify_columns
from pydantic import BaseModel
from typing_extensions import TypedDict

Expand Down Expand Up @@ -47,6 +48,14 @@
SQL_PARSE_RESULT_CACHE_SIZE = 1000


# The prefix of sqlglot's default optimizer rule sequence that runs *before*
# type annotation. We stop at annotate_types because we annotate types
# ourselves later, against our own schema resolver, and the rules after it
# assume annotation has already happened.
RULES_BEFORE_TYPE_ANNOTATION = list(
    itertools.takewhile(
        # Functions compare by identity, so `is not` matches the original
        # `!=` check exactly while making the intent explicit.
        lambda rule: rule is not sqlglot.optimizer.annotate_types.annotate_types,
        sqlglot.optimizer.optimizer.RULES,
    )
)


class GraphQLSchemaField(TypedDict):
fieldPath: str
nativeDataType: str
Expand Down Expand Up @@ -571,17 +580,20 @@ def _schema_aware_fuzzy_column_resolve(
# - the select instead of the full outer statement
# - schema info
# - column qualification enabled
# - running the full pre-type annotation optimizer

# logger.debug("Schema: %s", sqlglot_db_schema.mapping)
statement = sqlglot.optimizer.qualify.qualify(
statement = sqlglot.optimizer.optimizer.optimize(
statement,
dialect=dialect,
schema=sqlglot_db_schema,
qualify_columns=True,
validate_qualify_columns=False,
identify=True,
# sqlglot calls the db -> schema -> table hierarchy "catalog", "db", "table".
catalog=default_db,
db=default_schema,
rules=RULES_BEFORE_TYPE_ANNOTATION,
)
except (sqlglot.errors.OptimizeError, ValueError) as e:
raise SqlUnderstandingError(
Expand Down Expand Up @@ -751,6 +763,7 @@ def _extract_select_from_create(
# Args that exist on sqlglot's Update expression but not on Select; these must
# be dropped when we translate an UPDATE statement into an equivalent SELECT
# for lineage extraction. (Iterating a dict yields its keys, so set(d) is the
# same as set(d.keys()).)
_UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT: Set[str] = set(
    sqlglot.exp.Update.arg_types
) - set(sqlglot.exp.Select.arg_types)
# Table-level args on the UPDATE's FROM clause that need to be hoisted onto the
# generated top-level Select instead of staying attached to the table node.
_UPDATE_FROM_TABLE_ARGS_TO_MOVE = {"joins", "laterals", "pivot"}


def _extract_select_from_update(
Expand All @@ -777,21 +790,38 @@ def _extract_select_from_update(
# they'll get caught later.
new_expressions.append(expr)

# Special translation for the `from` clause.
extra_args = {}
original_from = statement.args.get("from")
if original_from and isinstance(original_from.this, sqlglot.exp.Table):
# Move joins, laterals, and pivots from the Update->From->Table->field
# to the top-level Select->field.

for k in _UPDATE_FROM_TABLE_ARGS_TO_MOVE:
if k in original_from.this.args:
# Mutate the from table clause in-place.
extra_args[k] = original_from.this.args.pop(k)

select_statement = sqlglot.exp.Select(
**{
**{
k: v
for k, v in statement.args.items()
if k not in _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT
},
**extra_args,
"expressions": new_expressions,
}
)

# Update statements always implicitly have the updated table in context.
# TODO: Retain table name alias.
if select_statement.args.get("from"):
select_statement = select_statement.join(statement.this, append=True)
# select_statement = sqlglot.parse_one(select_statement.sql(dialect=dialect))

select_statement = select_statement.join(
statement.this, append=True, join_kind="cross"
)
else:
select_statement = select_statement.from_(statement.this)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
{
"query_type": "SELECT",
"in_tables": [
"urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.table1,PROD)",
"urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.table2,PROD)"
],
"out_tables": [],
"column_lineage": [
{
"downstream": {
"table": null,
"column": "a",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.NumberType": {}
}
},
"native_column_type": "INT"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.table1,PROD)",
"column": "a"
}
]
},
{
"downstream": {
"table": null,
"column": "b",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.NumberType": {}
}
},
"native_column_type": "INT"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.table1,PROD)",
"column": "b"
}
]
},
{
"downstream": {
"table": null,
"column": "c",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.ArrayType": {}
}
},
"native_column_type": "INT[]"
},
"upstreams": [
{
"table": "urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.table2,PROD)",
"column": "c"
}
]
}
]
}
27 changes: 27 additions & 0 deletions metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,33 @@ def test_snowflake_update_self():
)


def test_postgres_select_subquery():
    """A correlated scalar subquery in the SELECT list should produce column
    lineage back to the subquery's source table (table2.c -> c), alongside the
    ordinary lineage from the outer table (table1.a/b)."""
    table1_urn = "urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.table1,PROD)"
    table2_urn = "urn:li:dataset:(urn:li:dataPlatform:postgres,my_db.my_schema.table2,PROD)"
    schemas = {
        table1_urn: {
            "id": "INTEGER",
            "a": "INTEGER",
            "b": "INTEGER",
        },
        table2_urn: {
            "id": "INTEGER",
            "c": "INTEGER",
        },
    }
    assert_sql_result(
        """
SELECT
    a,
    b,
    (SELECT c FROM table2 WHERE table2.id = table1.id) as c
FROM table1
""",
        dialect="postgres",
        default_db="my_db",
        default_schema="my_schema",
        schemas=schemas,
        expected_file=RESOURCE_DIR / "test_postgres_select_subquery.json",
    )


@pytest.mark.skip(reason="We can't parse column-list syntax with sub-selects yet")
def test_postgres_update_subselect():
assert_sql_result(
Expand Down

0 comments on commit 7656ea6

Please sign in to comment.