Skip to content

Commit

Permalink
feat: add chained comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
vberlier committed Feb 20, 2024
1 parent 992d832 commit 0cc5da8
Show file tree
Hide file tree
Showing 24 changed files with 2,502 additions and 76 deletions.
9 changes: 9 additions & 0 deletions bolt/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"AstExpression",
"AstExpressionBinary",
"AstExpressionUnary",
"AstChainedComparison",
"AstValue",
"AstIdentifier",
"AstFormatString",
Expand Down Expand Up @@ -115,6 +116,14 @@ class AstExpressionUnary(AstExpression):
value: AstExpression = required_field()


@dataclass(frozen=True, slots=True)
class AstChainedComparison(AstExpression):
"""Ast chained comparison node."""

operators: Tuple[str, ...] = required_field()
operands: AstChildren[AstExpression] = required_field()


@dataclass(frozen=True, slots=True)
class AstValue(AstExpression):
"""Ast value node."""
Expand Down
98 changes: 67 additions & 31 deletions bolt/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
AstAssignment,
AstAttribute,
AstCall,
AstChainedComparison,
AstClassBases,
AstClassName,
AstDict,
Expand Down Expand Up @@ -304,6 +305,43 @@ def else_statement(self):
with self.if_statement(self.condition_inverse):
yield

def binary(self, left: str, op: str, right: str, *, lineno: Any = None):
"""Emit binary operator."""
if op in ["in", "not_in"]:
value = self.helper(f"operator_{op}", left, right)
self.statement(f"{left} = {value}", lineno=lineno)
else:
op = op.replace("_", " ")
self.statement(f"{left} = {left} {op} {right}", lineno=lineno)

def dup(self, target: str, *, lineno: Any = None) -> str:
"""Emit __dup__()."""
dup = self.make_variable()
value = self.helper("get_dup", target)
self.statement(f"{dup} = {value}", lineno=lineno)
self.statement(f"if {dup} is not None:")
with self.block():
self.statement(f"{target} = {dup}()")
return dup

def rebind(self, target: str, op: str, value: str, *, lineno: Any = None):
"""Emit __rebind__()."""
rebind = self.helper("get_rebind", target)
self.statement(f"_bolt_rebind = {rebind}", lineno=lineno)
self.statement(f"{target} {op} {value}")
self.statement(f"if _bolt_rebind is not None:")
with self.block():
self.statement(f"{target} = _bolt_rebind({target})")

def rebind_dup(self, target: str, dup: str, value: str, *, lineno: Any = None):
"""Emit __rebind__() if target was __dup__()."""
self.statement(f"if {dup} is not None:")
with self.block():
self.rebind(target, "=", value, lineno=lineno)
self.statement("else:")
with self.block():
self.statement(f"{target} = {value}")

def enclose(self, code: str, from_index: int, *, lineno: Any = None):
"""Enclose statements starting from the given index."""
self.statements[from_index:] = [
Expand Down Expand Up @@ -498,12 +536,7 @@ def visit_binding(

for node, target, value in zip(nodes, targets, values):
if isinstance(node, AstTargetIdentifier) and node.rebind:
rebind = acc.helper("get_rebind", target)
acc.statement(f"_bolt_rebind = {rebind}", lineno=node)
acc.statement(f"{target} {op} {value}")
acc.statement(f"if _bolt_rebind is not None:")
with acc.block():
acc.statement(f"{target} = _bolt_rebind({target})")
acc.rebind(target, op, value, lineno=node)
else:
acc.statement(f"{target} {op} {value}", lineno=node)

Expand Down Expand Up @@ -1125,12 +1158,7 @@ def binary(
) -> Generator[AstNode, Optional[List[str]], Optional[List[str]]]:
left = yield from visit_single(node.left, required=True)
right = yield from visit_single(node.right, required=True)
if node.operator in ["in", "not_in"]:
value = acc.helper(f"operator_{node.operator}", left, right)
acc.statement(f"{left} = {value}", lineno=node)
else:
op = node.operator.replace("_", " ")
acc.statement(f"{left} = {left} {op} {right}", lineno=node)
acc.binary(left, node.operator, right, lineno=node)
return [left]

@rule(AstExpressionBinary, operator="and")
Expand All @@ -1146,28 +1174,11 @@ def binary_logical(
value = acc.helper("operator_not", left) if node.operator == "or" else left
acc.statement(f"{condition} = {value}", lineno=node.left)

dup = acc.make_variable()
value = acc.helper("get_dup", left)
acc.statement(f"{dup} = {value}")
acc.statement(f"if {dup} is not None:")
with acc.block():
acc.statement(f"{left} = {dup}()")
dup = acc.dup(left)

with acc.if_statement(condition):
right = yield from visit_single(node.right, required=True)

acc.statement(f"if {dup} is not None:")
with acc.block():
rebind = acc.helper("get_rebind", left)
acc.statement(f"_bolt_rebind = {rebind}", lineno=node.right)
acc.statement(f"{left} = {right}")
acc.statement(f"if _bolt_rebind is not None:")
with acc.block():
acc.statement(f"{left} = _bolt_rebind({left})")

acc.statement("else:")
with acc.block():
acc.statement(f"{left} = {right}")
acc.rebind_dup(left, dup, right, lineno=node.right)

return [left]

Expand All @@ -1186,6 +1197,31 @@ def unary(
acc.statement(f"{result} = {op} {result}", lineno=node)
return [result]

@rule(AstChainedComparison)
def chained_comparison(
self,
node: AstChainedComparison,
acc: Accumulator,
) -> Generator[AstNode, Optional[List[str]], Optional[List[str]]]:
left = yield from visit_single(node.operands[0], required=True)
right = yield from visit_single(node.operands[1], required=True)
acc.binary(left, node.operators[0], right, lineno=node)

condition = acc.make_variable()

for op, operand in zip(node.operators[1:], node.operands[2:]):
acc.statement(f"{condition} = {left}")

dup = acc.dup(left)

with acc.if_statement(condition):
current = right
right = yield from visit_single(operand, required=True)
acc.binary(current, op, right, lineno=node)
acc.rebind_dup(left, dup, current, lineno=operand)

return [left]

@rule(AstValue)
def value(self, node: AstValue, acc: Accumulator) -> Optional[List[str]]:
result = acc.make_variable()
Expand Down
35 changes: 33 additions & 2 deletions bolt/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"RootScopeHandler",
"BinaryParser",
"UnaryParser",
"ChainedComparisonParser",
"UnpackParser",
"UnpackConstraint",
"KeywordParser",
Expand Down Expand Up @@ -112,6 +113,7 @@
AstAssignment,
AstAttribute,
AstCall,
AstChainedComparison,
AstClassBases,
AstClassName,
AstClassRoot,
Expand Down Expand Up @@ -380,7 +382,7 @@ def get_bolt_parsers(
operators=[r"\bnot\b"],
parser=delegate("bolt:comparison"),
),
"bolt:comparison": BinaryParser(
"bolt:comparison": ChainedComparisonParser(
operators=[
"==",
"!=",
Expand Down Expand Up @@ -2283,7 +2285,9 @@ class BinaryParser:
parser: Parser
right_associative: bool = False

def __call__(self, stream: TokenStream) -> Any:
def parse_operands(
self, stream: TokenStream
) -> Tuple[List[AstExpression], List[str]]:
with stream.syntax(operator="|".join(self.operators)):
nodes = [self.parser(stream)]
operations: List[str] = []
Expand All @@ -2292,6 +2296,11 @@ def __call__(self, stream: TokenStream) -> Any:
nodes.append(self.parser(stream))
operations.append(normalize_whitespace(op.value))

return nodes, operations

def __call__(self, stream: TokenStream) -> Any:
nodes, operations = self.parse_operands(stream)

if self.right_associative:
result = nodes[-1]
nodes = nodes[-2::-1]
Expand Down Expand Up @@ -2325,6 +2334,28 @@ def __call__(self, stream: TokenStream) -> Any:
return self.parser(stream)


@dataclass
class ChainedComparisonParser(BinaryParser):
"""Parser for chained comparisons."""

def __call__(self, stream: TokenStream) -> Any:
nodes, operations = self.parse_operands(stream)

if len(nodes) == 1:
return nodes[0]

if len(operations) == 1:
node = AstExpressionBinary(
operator=operations[0], left=nodes[0], right=nodes[1]
)
return set_location(node, node.left, node.right)

node = AstChainedComparison(
operators=tuple(operations), operands=AstChildren(nodes)
)
return set_location(node, nodes[0], nodes[-1])


@dataclass
class UnpackParser:
"""Parser for unpacking."""
Expand Down
2 changes: 1 addition & 1 deletion examples/bolt_basic/src/data/demo/functions/foo.mcfunction
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def thing_equal(self, item):
return f"thing == {item}"

Thing = type("Thing", (), {"__within__": thing_within, "__contains__": thing_contains, "__eq__": thing_equal})
Thing() in [1, 2, 3] in [99]
(Thing() in [1, 2, 3]) in [99]
"world" in ("hello" in Thing())

from contextlib import contextmanager
Expand Down
6 changes: 6 additions & 0 deletions examples/bolt_chained_comp/beet.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
require:
- bolt
data_pack:
load: "src"
pipeline:
- mecha
80 changes: 80 additions & 0 deletions examples/bolt_chained_comp/src/data/demo/functions/foo.mcfunction
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from contextlib import contextmanager
from itertools import combinations


class Tmp:
def __init__(self, name = None):
if name is None:
name = ctx.generate.format("tmp{incr}")
self.name = name

def __dup__(self):
result = Tmp()
result = self
return result

def __rebind__(self, rhs):
if isinstance(rhs, Tmp):
scoreboard players operation global self = global rhs
else:
scoreboard players set global self int(rhs)
return self

@contextmanager
def __branch__(self):
unless score global self matches 0:
yield True

def __not__(self):
result = Tmp()
result = 1
unless score global self matches 0:
result = 0
return result

def __eq__(self, rhs):
result = Tmp()
result = 0
if isinstance(rhs, Tmp):
if score global self = global rhs:
result = 1
else:
if score global self matches int(rhs):
result = 1
return result

def __str__(self):
return self.name


for s in [7, Tmp("seven")]:
if s == s and s == s and s == s:
say 1

if s == s == s == s:
say 2


if 1 == (Tmp("foo") == Tmp("bar")) == Tmp("thing"):
say 3


for a, b in combinations([123, 456, Tmp("foo"), Tmp("bar")], 2):
raw #
raw f"# check {a}, {b}"
say (a == a == a == a)
say (a == a == a == b)
say (a == a == b == a)
say (a == a == b == b)
say (a == b == a == a)
say (a == b == a == b)
say (a == b == b == a)
say (a == b == b == b)
say (b == a == a == a)
say (b == a == a == b)
say (b == a == b == a)
say (b == a == b == b)
say (b == b == a == a)
say (b == b == a == b)
say (b == b == b == a)
say (b == b == b == b)
21 changes: 20 additions & 1 deletion tests/resources/bolt_examples.mcfunction
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ for a in [False, True]:
for b in [False, True]:
for c in [False, True]:
for d in [False, True]:
print((a and b or c and not d) in [True] not in [False])
print(((a and b or c and not d) in [True]) not in [False])
###
a = 1
def f():
Expand Down Expand Up @@ -1364,3 +1364,22 @@ text = "demo:foo"
class A:
text: str
data: str
###
1 == 2 == 3
###
if 1 < 2 < 3 < 4:
pass
###
print(123 == 123 == 123 == 123)
###
if 7 == 7 and 7 == 7 and 7 == 7:
say yep
###
if 7 == 7 == 7 == 7:
say 2
###
if 7 == (7 == 7) == 7:
say 2
###
if (7 == 7) == (7 == 7):
say 2
Loading

0 comments on commit 0cc5da8

Please sign in to comment.