Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FOU-471] Implement AssetSelectionVisitor #25704

Open
wants to merge 7 commits into
base: briantu/set-up-antlr-asset-selection
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from antlr4 import CommonTokenStream, InputStream
from antlr4.error.ErrorListener import ErrorListener

from dagster._annotations import experimental
from dagster._core.definitions.antlr_asset_selection.generated.AssetSelectionLexer import (
Expand All @@ -7,18 +8,170 @@
from dagster._core.definitions.antlr_asset_selection.generated.AssetSelectionParser import (
AssetSelectionParser,
)
from dagster._core.definitions.antlr_asset_selection.generated.AssetSelectionVisitor import (
AssetSelectionVisitor,
)
from dagster._core.definitions.asset_selection import AssetSelection, CodeLocationAssetSelection
from dagster._core.storage.tags import KIND_PREFIX


class AntlrInputErrorListener(ErrorListener):
    """Error listener that aborts on the first reported syntax error."""

    def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
        """Raise an ``Exception`` describing where the syntax error occurred.

        Only ``line``, ``column`` and ``msg`` are used to build the message;
        the remaining arguments are accepted to match the listener signature
        and are ignored.
        """
        message = f"Syntax error at line {line}, column {column}: {msg}"
        raise Exception(message)


class AntlrAssetSelectionVisitor(AssetSelectionVisitor):
    """Translate an asset-selection parse tree into an ``AssetSelection``.

    Each ``visit*`` method handles one rule (or labeled alternative) of
    ``AssetSelectionParser`` and returns the value that node contributes:
    a selection object for expressions, a string for values and function
    names, and an optional integer depth for traversals.

    Branches that should be unreachable for a well-formed parse tree raise
    ``ValueError`` instead of implicitly returning ``None``, so a mismatch
    between this visitor and the grammar fails loudly at the offending node.
    """

    # Visit a parse tree produced by AssetSelectionParser#start.
    def visitStart(self, ctx: AssetSelectionParser.StartContext):
        """Return the result of visiting the top-level expression."""
        return self.visit(ctx.expr())

    # Visit a parse tree produced by AssetSelectionParser#AssetExpression.
    def visitAssetExpression(self, ctx: AssetSelectionParser.AssetExpressionContext):
        """Delegate to the nested asset expression."""
        return self.visit(ctx.assetExpr())

    # Visit a parse tree produced by AssetSelectionParser#ParenthesizedExpression.
    def visitParenthesizedExpression(
        self, ctx: AssetSelectionParser.ParenthesizedExpressionContext
    ):
        """Return the result of visiting the wrapped expression."""
        return self.visit(ctx.expr())

    # Visit a parse tree produced by AssetSelectionParser#UpTraversalExpression.
    def visitUpTraversalExpression(self, ctx: AssetSelectionParser.UpTraversalExpressionContext):
        """Return ``selection.upstream(...)`` using the visited traversal depth."""
        selection: AssetSelection = self.visit(ctx.expr())
        traversal_depth = self.visit(ctx.traversal())
        return selection.upstream(depth=traversal_depth)

    # Visit a parse tree produced by AssetSelectionParser#AndExpression.
    def visitAndExpression(self, ctx: AssetSelectionParser.AndExpressionContext):
        """Combine the two operand selections with ``&``."""
        left: AssetSelection = self.visit(ctx.expr(0))
        right: AssetSelection = self.visit(ctx.expr(1))
        return left & right

    # Visit a parse tree produced by AssetSelectionParser#NotExpression.
    def visitNotExpression(self, ctx: AssetSelectionParser.NotExpressionContext):
        """Return ``AssetSelection.all()`` minus the operand selection."""
        selection: AssetSelection = self.visit(ctx.expr())
        return AssetSelection.all() - selection

    # Visit a parse tree produced by AssetSelectionParser#DownTraversalExpression.
    def visitDownTraversalExpression(
        self, ctx: AssetSelectionParser.DownTraversalExpressionContext
    ):
        """Return ``selection.downstream(...)`` using the visited traversal depth."""
        selection: AssetSelection = self.visit(ctx.expr())
        traversal_depth = self.visit(ctx.traversal())
        return selection.downstream(depth=traversal_depth)

    # Visit a parse tree produced by AssetSelectionParser#OrExpression.
    def visitOrExpression(self, ctx: AssetSelectionParser.OrExpressionContext):
        """Combine the two operand selections with ``|``."""
        left: AssetSelection = self.visit(ctx.expr(0))
        right: AssetSelection = self.visit(ctx.expr(1))
        return left | right

    # Visit a parse tree produced by AssetSelectionParser#AttributeExpression.
    def visitAttributeExpression(self, ctx: AssetSelectionParser.AttributeExpressionContext):
        """Delegate to the nested attribute expression."""
        return self.visit(ctx.attributeExpr())

    # Visit a parse tree produced by AssetSelectionParser#FunctionCallExpression.
    def visitFunctionCallExpression(self, ctx: AssetSelectionParser.FunctionCallExpressionContext):
        """Apply ``sinks()`` or ``roots()`` to the argument selection.

        Raises:
            ValueError: If the visited function name is neither ``"sinks"``
                nor ``"roots"``.
        """
        function = self.visit(ctx.functionName())
        selection: AssetSelection = self.visit(ctx.expr())
        if function == "sinks":
            return selection.sinks()
        if function == "roots":
            return selection.roots()
        # Falling through here used to return None, which only surfaced later
        # as an unrelated attribute error on the "selection".
        raise ValueError(f"Unsupported asset selection function: {function!r}")

    # Visit a parse tree produced by AssetSelectionParser#UpAndDownTraversalExpression.
    def visitUpAndDownTraversalExpression(
        self, ctx: AssetSelectionParser.UpAndDownTraversalExpressionContext
    ):
        """Return the union of the upstream and downstream traversals.

        The first traversal node supplies the upstream depth and the second
        supplies the downstream depth.
        """
        selection: AssetSelection = self.visit(ctx.expr())
        up_depth = self.visit(ctx.traversal(0))
        down_depth = self.visit(ctx.traversal(1))
        return selection.upstream(depth=up_depth) | selection.downstream(depth=down_depth)

    # Visit a parse tree produced by AssetSelectionParser#traversal.
    def visitTraversal(self, ctx: AssetSelectionParser.TraversalContext):
        """Return the traversal depth encoded by a traversal node.

        Returns:
            ``None`` for a star (no depth limit), otherwise the number of
            ``+`` tokens.

        Raises:
            ValueError: If the node contains neither a star nor any ``+``.
        """
        if ctx.STAR():
            return None  # Star means no depth limit
        plus_tokens = ctx.PLUS()
        if plus_tokens:
            return len(plus_tokens)  # Depth is the count of '+'
        # None is a meaningful return value here (unlimited depth), so an
        # unrecognized node must not be allowed to fall through to it.
        raise ValueError(f"Unrecognized traversal: {ctx.getText()!r}")

    # Visit a parse tree produced by AssetSelectionParser#functionName.
    def visitFunctionName(self, ctx: AssetSelectionParser.FunctionNameContext):
        """Return ``"sinks"`` or ``"roots"`` depending on the matched token.

        Raises:
            ValueError: If neither token is present on the node.
        """
        if ctx.SINKS():
            return "sinks"
        if ctx.ROOTS():
            return "roots"
        raise ValueError(f"Unrecognized function name: {ctx.getText()!r}")

    # Visit a parse tree produced by AssetSelectionParser#TagAttributeExpr.
    def visitTagAttributeExpr(self, ctx: AssetSelectionParser.TagAttributeExprContext):
        """Build a tag selection from ``key`` or ``key=value``.

        When no ``=`` token is present the tag value is the empty string.
        """
        key = self.visit(ctx.value(0))
        value = self.visit(ctx.value(1)) if ctx.EQUAL() else ""
        return AssetSelection.tag(key, value)

    # Visit a parse tree produced by AssetSelectionParser#OwnerAttributeExpr.
    def visitOwnerAttributeExpr(self, ctx: AssetSelectionParser.OwnerAttributeExprContext):
        """Build an owner selection from the attribute value."""
        owner = self.visit(ctx.value())
        return AssetSelection.owner(owner)

    # Visit a parse tree produced by AssetSelectionParser#GroupAttributeExpr.
    def visitGroupAttributeExpr(self, ctx: AssetSelectionParser.GroupAttributeExprContext):
        """Build a group selection from the attribute value."""
        group = self.visit(ctx.value())
        return AssetSelection.groups(group)

    # Visit a parse tree produced by AssetSelectionParser#KindAttributeExpr.
    def visitKindAttributeExpr(self, ctx: AssetSelectionParser.KindAttributeExprContext):
        """Build a tag selection keyed by ``KIND_PREFIX`` + kind with an empty value."""
        kind = self.visit(ctx.value())
        return AssetSelection.tag(f"{KIND_PREFIX}{kind}", "")

    # Visit a parse tree produced by AssetSelectionParser#CodeLocationAttributeExpr.
    def visitCodeLocationAttributeExpr(
        self, ctx: AssetSelectionParser.CodeLocationAttributeExprContext
    ):
        """Build a ``CodeLocationAssetSelection`` from the attribute value."""
        code_location = self.visit(ctx.value())
        return CodeLocationAssetSelection(selected_code_location=code_location)

    # Visit a parse tree produced by AssetSelectionParser#value.
    def visitValue(self, ctx: AssetSelectionParser.ValueContext):
        """Return the text of a value node, without surrounding double quotes.

        Raises:
            ValueError: If the node holds neither a quoted nor an unquoted
                string token.
        """
        quoted = ctx.QUOTED_STRING()
        if quoted:
            return quoted.getText().strip('"')
        unquoted = ctx.UNQUOTED_STRING()
        if unquoted:
            return unquoted.getText()
        raise ValueError(f"Unrecognized value: {ctx.getText()!r}")

    # Visit a parse tree produced by AssetSelectionParser#ExactMatchAsset.
    def visitExactMatchAsset(self, ctx: AssetSelectionParser.ExactMatchAssetContext):
        """Build an exact-key selection from a quoted asset string."""
        asset = ctx.QUOTED_STRING().getText().strip('"')
        return AssetSelection.assets(asset)

    # Visit a parse tree produced by AssetSelectionParser#PrefixMatchAsset.
    def visitPrefixMatchAsset(self, ctx: AssetSelectionParser.PrefixMatchAssetContext):
        """Build a key-prefix selection from an unquoted asset string."""
        asset = ctx.UNQUOTED_STRING().getText()
        return AssetSelection.key_prefixes(asset)


@experimental
class AntlrAssetSelection:
    """Parse a selection string and expose its parse tree and selection.

    Construction lexes and parses ``selection_str`` and then runs the shared
    visitor over the resulting tree. Both the lexer and the parser have their
    default error listeners replaced with ``AntlrInputErrorListener``.
    """

    # One visitor instance is shared by every AntlrAssetSelection.
    _visitor: AntlrAssetSelectionVisitor = AntlrAssetSelectionVisitor()

    def __init__(self, selection_str: str):
        """Parse ``selection_str`` and build the corresponding selection."""
        char_stream = InputStream(selection_str)

        lexer = AssetSelectionLexer(char_stream)
        lexer.removeErrorListeners()
        lexer.addErrorListener(AntlrInputErrorListener())

        parser = AssetSelectionParser(CommonTokenStream(lexer))
        parser.removeErrorListeners()
        parser.addErrorListener(AntlrInputErrorListener())

        parse_tree = parser.start()
        self._tree = parse_tree
        self._tree_str = parse_tree.toStringTree(recog=parser)
        self._asset_selection = self._visitor.visit(parse_tree)

    @property
    def tree_str(self) -> str:
        """String rendering of the parse tree, produced at construction time."""
        return self._tree_str

    @property
    def asset_selection(self) -> AssetSelection:
        """Selection produced by visiting the parse tree."""
        return self._asset_selection
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from dagster._core.definitions.antlr_asset_selection.antlr_asset_selection import (
AntlrAssetSelection,
)
from dagster._core.definitions.asset_selection import AssetSelection, CodeLocationAssetSelection
from dagster._core.definitions.decorators.asset_decorator import asset
from dagster._core.storage.tags import KIND_PREFIX


@pytest.mark.parametrize(
Expand All @@ -18,8 +21,8 @@
),
("kind:my_kind", "(start (expr (attributeExpr kind : (value my_kind))) <EOF>)"),
(
"codelocation:my_location",
"(start (expr (attributeExpr codelocation : (value my_location))) <EOF>)",
"code_location:my_location",
"(start (expr (attributeExpr code_location : (value my_location))) <EOF>)",
),
(
'((("a")))',
Expand All @@ -43,3 +46,71 @@
def test_antlr_tree(selection_str, expected_tree_str):
    """The parse tree string for each selection matches the expected rendering."""
    actual_tree_str = AntlrAssetSelection(selection_str).tree_str
    assert actual_tree_str == expected_tree_str


@pytest.mark.parametrize(
    "selection_str",
    [
        "not",
        "a b",
        "a and and",
        "a and",
        "sinks",
        "owner",
        "tag:foo=",
        "owner:[email protected]",
    ],
)
def test_antlr_tree_invalid(selection_str):
    """Constructing a selection from each malformed string raises."""
    with pytest.raises(Exception):
        AntlrAssetSelection(selection_str)


@pytest.mark.parametrize(
    "selection_str, expected_assets",
    [
        ('"a"', AssetSelection.assets("a")),
        ("not a", AssetSelection.all() - AssetSelection.assets("a")),
        ("a and b", AssetSelection.assets("a") & AssetSelection.assets("b")),
        ("a or b", AssetSelection.assets("a") | AssetSelection.assets("b")),
        ("+a", AssetSelection.assets("a").upstream(1)),
        ("++a", AssetSelection.assets("a").upstream(2)),
        ("a+", AssetSelection.assets("a").downstream(1)),
        ("a++", AssetSelection.assets("a").downstream(2)),
        (
            "a* and *b",
            # Must be ``&``: Python's ``and`` would just evaluate to the
            # right-hand operand instead of intersecting the two selections.
            AssetSelection.assets("a").downstream() & AssetSelection.assets("b").upstream(),
        ),
        ("sinks(a)", AssetSelection.assets("a").sinks()),
        ("roots(c)", AssetSelection.assets("c").roots()),
        ("tag:foo", AssetSelection.tag("foo", "")),
        ("tag:foo=bar", AssetSelection.tag("foo", "bar")),
        ('owner:"[email protected]"', AssetSelection.owner("[email protected]")),
        ("group:my_group", AssetSelection.groups("my_group")),
        ("kind:my_kind", AssetSelection.tag(f"{KIND_PREFIX}my_kind", "")),
    ],
)
def test_antlr_visit_basic(selection_str, expected_assets):
    """Parsed selections resolve to the same assets as hand-built ones.

    Both the parsed selection and ``expected_assets`` are resolved against
    the same three-asset graph and the results are compared.
    """

    # a -> b -> c
    @asset(tags={"foo": "bar"}, owners=["team:billing"])
    def a(): ...

    @asset(deps=[a], kinds={"python", "snowflake"})
    def b(): ...

    @asset(
        deps=[b],
        group_name="my_group",
    )
    def c(): ...

    defs = [a, b, c]

    assert AntlrAssetSelection(selection_str).asset_selection.resolve(
        defs
    ) == expected_assets.resolve(defs)


def test_code_location() -> None:
    """A code-location selection parses but cannot be resolved."""

    @asset
    def my_asset(): ...

    # Parsing succeeds and yields the matching selection object.
    parsed = AntlrAssetSelection("code_location:code_location1").asset_selection
    expected = CodeLocationAssetSelection(selected_code_location="code_location1")
    assert parsed == expected

    # Resolution is expected to raise NotImplementedError.
    with pytest.raises(NotImplementedError):
        parsed.resolve([my_asset])