feat(ingest): use mainline sqlglot #11693

Merged 8 commits on Oct 23, 2024

14 changes: 1 addition & 13 deletions metadata-ingestion-modules/dagster-plugin/setup.py
@@ -13,14 +13,6 @@ def get_long_description():
     return pathlib.Path(os.path.join(root, "README.md")).read_text()
 
 
-rest_common = {"requests", "requests_file"}
-
-sqlglot_lib = {
-    # Using an Acryl fork of sqlglot.
-    # https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:main?expand=1
-    "acryl-sqlglot[rs]==24.0.1.dev7",
-}
-
 _version: str = package_metadata["__version__"]
 _self_pin = (
     f"=={_version}"
@@ -32,11 +24,7 @@ def get_long_description():
     # Actual dependencies.
     "dagster >= 1.3.3",
     "dagit >= 1.3.3",
-    *rest_common,
-    # Ignoring the dependency below because it causes issues with the vercel built wheel install
-    # f"acryl-datahub[datahub-rest]{_self_pin}",
-    "acryl-datahub[datahub-rest]",
-    *sqlglot_lib,
+    f"acryl-datahub[datahub-rest,sql-parser]{_self_pin}",
 }
 
 mypy_stubs = {

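With this change, the Dagster plugin no longer pins acryl-sqlglot itself; sqlglot now arrives transitively through the acryl-datahub[sql-parser] extra, keeping the plugin and the main package on the same parser version. A quick hedged sanity check (illustrative only, not part of this PR; assumes the plugin is installed as acryl-datahub-dagster-plugin):

# Illustrative: the sqlglot requirement should now resolve through
# acryl-datahub[sql-parser] rather than a local acryl-sqlglot pin.
import importlib.metadata

for dist in ("acryl-datahub", "sqlglot"):
    print(dist, importlib.metadata.version(dist))
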
2 changes: 1 addition & 1 deletion metadata-ingestion/pyproject.toml
@@ -14,7 +14,7 @@ target-version = ['py37', 'py38', 'py39', 'py310']
 [tool.isort]
 combine_as_imports = true
 indent = '    '
-known_future_library = ['__future__', 'datahub.utilities._markupsafe_compat', 'datahub_provider._airflow_compat']
+known_future_library = ['__future__', 'datahub.utilities._markupsafe_compat', 'datahub.sql_parsing._sqlglot_patch']
 profile = 'black'
 sections = 'FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER'
 skip_glob = 'src/datahub/metadata'

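The isort tweak above is load-bearing: modules listed in known_future_library sort alongside __future__ imports, so the patch import stays at the very top of any file that uses it, guaranteeing the monkeypatches run before sqlglot is used. A minimal sketch of the resulting import order (this mirrors the sqlglot_lineage.py and sqlglot_utils.py diffs below):

# isort keeps the patch import first, so patches apply before sqlglot is used.
from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED

import sqlglot

assert SQLGLOT_PATCHED  # fails loudly if the patch module was skipped
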
6 changes: 4 additions & 2 deletions metadata-ingestion/setup.py
@@ -99,9 +99,11 @@
 }
 
 sqlglot_lib = {
-    # Using an Acryl fork of sqlglot.
+    # We heavily monkeypatch sqlglot.
+    # Prior to the patching, we originally maintained an acryl-sqlglot fork:
     # https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:main?expand=1
-    "acryl-sqlglot[rs]==25.25.2.dev9",
+    "sqlglot[rs]==25.26.0",
+    "patchy==2.8.0",
 }
 
 classification_lib = {

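Because the patches in _sqlglot_patch.py below are unified diffs applied against sqlglot's exact source text, the dependency is pinned to an exact version; even a patch-level sqlglot bump could make a diff stop applying cleanly. A hedged guard one might add (illustrative, not in this PR):

# Illustrative only: assert the installed sqlglot matches the pin the
# patches were written against.
import sqlglot

assert sqlglot.__version__ == "25.26.0", "sqlglot patches target this exact pin"
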
215 changes: 215 additions & 0 deletions metadata-ingestion/src/datahub/sql_parsing/_sqlglot_patch.py
@@ -0,0 +1,215 @@
import dataclasses
import difflib
import logging

import patchy.api
import sqlglot
import sqlglot.expressions
import sqlglot.lineage
import sqlglot.optimizer.scope
import sqlglot.optimizer.unnest_subqueries

from datahub.utilities.is_pytest import is_pytest_running
from datahub.utilities.unified_diff import apply_diff

# This injects a few patches into sqlglot to add features and mitigate
# some bugs and performance issues.
# The diffs in this file should match the diffs declared in our fork.
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:main
# For a diff-formatted view, see:
# https://github.com/tobymao/sqlglot/compare/main...hsheth2:sqlglot:main.diff

_DEBUG_PATCHER = is_pytest_running() or True
logger = logging.getLogger(__name__)

_apply_diff_subprocess = patchy.api._apply_patch


def _new_apply_patch(source: str, patch_text: str, forwards: bool, name: str) -> str:
    assert forwards, "Only forward patches are supported"

    result = apply_diff(source, patch_text)

    # TODO: When in testing mode, still run the subprocess and check that the
    # results line up.
    if _DEBUG_PATCHER:
        result_subprocess = _apply_diff_subprocess(source, patch_text, forwards, name)
        if result_subprocess != result:
            logger.info("Results from subprocess and _apply_diff do not match")
            logger.debug(f"Subprocess result:\n{result_subprocess}")
            logger.debug(f"Our result:\n{result}")
            diff = difflib.unified_diff(
                result_subprocess.splitlines(), result.splitlines()
            )
            logger.debug("Diff:\n" + "\n".join(diff))
            raise ValueError("Results from subprocess and _apply_diff do not match")

    return result


patchy.api._apply_patch = _new_apply_patch


def _patch_deepcopy() -> None:
    patchy.patch(
        sqlglot.expressions.Expression.__deepcopy__,
        """\
@@ -1,4 +1,7 @@ def meta(self) -> t.Dict[str, t.Any]:
 def __deepcopy__(self, memo):
+    import datahub.utilities.cooperative_timeout
+    datahub.utilities.cooperative_timeout.cooperate()
+
     root = self.__class__()
     stack = [(self, root)]
""",
    )


def _patch_scope_traverse() -> None:
    # Circular scope dependencies can happen in somewhat specific circumstances
    # due to our usage of sqlglot.
    # See https://github.com/tobymao/sqlglot/pull/4244
    patchy.patch(
        sqlglot.optimizer.scope.Scope.traverse,
        """\
@@ -5,9 +5,16 @@ def traverse(self):
         Scope: scope instances in depth-first-search post-order
     \"""
     stack = [self]
+    seen_scopes = set()
     result = []
     while stack:
         scope = stack.pop()
+
+        # Scopes aren't hashable, so we use id(scope) instead.
+        if id(scope) in seen_scopes:
+            raise OptimizeError(f"Scope {scope} has a circular scope dependency")
+        seen_scopes.add(id(scope))
+
         result.append(scope)
         stack.extend(
             itertools.chain(
""",
    )


def _patch_unnest_subqueries() -> None:
    patchy.patch(
        sqlglot.optimizer.unnest_subqueries.decorrelate,
        """\
@@ -261,16 +261,19 @@ def remove_aggs(node):
         if key in group_by:
             key.replace(nested)
         elif isinstance(predicate, exp.EQ):
-            parent_predicate = _replace(
-                parent_predicate,
-                f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
-            )
+            if parent_predicate:
+                parent_predicate = _replace(
+                    parent_predicate,
+                    f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
+                )
         else:
             key.replace(exp.to_identifier("_x"))
-            parent_predicate = _replace(
-                parent_predicate,
-                f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
-            )
+
+            if parent_predicate:
+                parent_predicate = _replace(
+                    parent_predicate,
+                    f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))",
+                )
""",
    )


def _patch_lineage() -> None:
    # Add the "subfield" attribute to sqlglot.lineage.Node.
    # With dataclasses, the easiest way to do this is with inheritance.
    # Unfortunately, mypy won't pick up on the new field, so we need to
    # use type ignores everywhere we use subfield.
    @dataclasses.dataclass(frozen=True)
    class Node(sqlglot.lineage.Node):
        subfield: str = ""

    sqlglot.lineage.Node = Node  # type: ignore

    patchy.patch(
        sqlglot.lineage.lineage,
        """\
@@ -12,7 +12,8 @@ def lineage(
     \"""

     expression = maybe_parse(sql, dialect=dialect)
-    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
+    # column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
+    assert isinstance(column, str)

     if sources:
         expression = exp.expand(
""",
    )

    patchy.patch(
        sqlglot.lineage.to_node,
        """\
@@ -235,11 +237,12 @@ def to_node(
     )

     # Find all columns that went into creating this one to list their lineage nodes.
-    source_columns = set(find_all_in_scope(select, exp.Column))
+    source_columns = list(find_all_in_scope(select, exp.Column))

-    # If the source is a UDTF find columns used in the UTDF to generate the table
+    # If the source is a UDTF find columns used in the UDTF to generate the table
+    source = scope.expression
     if isinstance(source, exp.UDTF):
-        source_columns |= set(source.find_all(exp.Column))
+        source_columns += list(source.find_all(exp.Column))
     derived_tables = [
         source.expression.parent
         for source in scope.sources.values()
@@ -254,6 +257,7 @@ def to_node(
         if dt.comments and dt.comments[0].startswith("source: ")
     }

+    c: exp.Column
     for c in source_columns:
         table = c.table
         source = scope.sources.get(table)
@@ -281,8 +285,21 @@ def to_node(
         # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
         # is not passed into the `sources` map.
         source = source or exp.Placeholder()
+
+        subfields = []
+        field: exp.Expression = c
+        while isinstance(field.parent, exp.Dot):
+            field = field.parent
+            subfields.append(field.name)
+        subfield = ".".join(subfields)
+
         node.downstream.append(
-            Node(name=c.sql(comments=False), source=source, expression=source)
+            Node(
+                name=c.sql(comments=False),
+                source=source,
+                expression=source,
+                subfield=subfield,
+            )
         )

     return node
""",
    )


_patch_deepcopy()
_patch_scope_traverse()
_patch_unnest_subqueries()
_patch_lineage()

SQLGLOT_PATCHED = True
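
For readers unfamiliar with patchy: patchy.patch(func, patch_text) fetches the function's source, applies the unified diff to it, recompiles, and swaps in the new code object, which is what each _patch_* helper above does to sqlglot internals. A minimal self-contained sketch against a toy function (not sqlglot):

import patchy


def sample() -> int:
    return 1


# Apply a unified diff to sample()'s source; afterwards it returns 9001.
patchy.patch(
    sample,
    """\
@@ -1,2 +1,2 @@
 def sample() -> int:
-    return 1
+    return 9001
""",
)

assert sample() == 9001

Note also that _new_apply_patch above replaces patchy's subprocess-based patch application (patchy.api._apply_patch) with DataHub's pure-Python apply_diff, while _DEBUG_PATCHER keeps the original implementation around as a consistency cross-check.
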
4 changes: 4 additions & 0 deletions metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
@@ -1,3 +1,5 @@
+from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED
+
 import dataclasses
 import functools
 import logging
@@ -53,6 +55,8 @@
     cooperative_timeout,
 )
 
+assert SQLGLOT_PATCHED
+
 logger = logging.getLogger(__name__)
 
 Urn = str

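The practical effect of _patch_lineage shows up here: column-level lineage can now distinguish struct subfield accesses. A hedged sketch of reading the new attribute (output is illustrative; subfield is the field the patch adds to sqlglot.lineage.Node):

from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED

import sqlglot.lineage

assert SQLGLOT_PATCHED

node = sqlglot.lineage.lineage(
    "col",
    "SELECT my_struct.field AS col FROM tbl",
    dialect="bigquery",
)
for downstream in node.walk():
    # subfield is "" for plain columns and e.g. "field" for struct accesses.
    print(downstream.name, getattr(downstream, "subfield", ""))
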
4 changes: 4 additions & 0 deletions metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py
@@ -1,3 +1,5 @@
+from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED
+
 import functools
 import hashlib
 import logging
@@ -8,6 +10,8 @@
 import sqlglot.errors
 import sqlglot.optimizer.eliminate_ctes
 
+assert SQLGLOT_PATCHED
+
 logger = logging.getLogger(__name__)
 DialectOrStr = Union[sqlglot.Dialect, str]
 SQL_PARSE_CACHE_SIZE = 1000

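The same import-then-assert pattern appears here, and it is what makes the deepcopy patch useful: every copy of an Expression calls cooperate(), so copying a huge parse tree can be aborted by an enclosing timeout instead of hanging. A hedged sketch (assumes cooperative_timeout takes a timeout in seconds, per its use in sqlglot_lineage.py above):

from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED

import copy

import sqlglot

from datahub.utilities.cooperative_timeout import cooperative_timeout

assert SQLGLOT_PATCHED

expression = sqlglot.parse_one("SELECT a, b FROM t")
with cooperative_timeout(timeout=10.0):
    # The patched Expression.__deepcopy__ calls cooperate(), which raises
    # if the surrounding timeout has expired.
    copied = copy.deepcopy(expression)
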
5 changes: 5 additions & 0 deletions metadata-ingestion/src/datahub/utilities/is_pytest.py
@@ -0,0 +1,5 @@
import sys


def is_pytest_running() -> bool:
    return "pytest" in sys.modules