diff --git a/bolt/ast.py b/bolt/ast.py index 513a8ea..550241d 100644 --- a/bolt/ast.py +++ b/bolt/ast.py @@ -468,6 +468,7 @@ class AstClassBases(AstNode): """Ast class bases node.""" inherit: AstChildren[AstExpression] = AstChildren() + kwargs: AstChildren[AstKeyword] = AstChildren() @dataclass(frozen=True, slots=True) diff --git a/bolt/codegen.py b/bolt/codegen.py index 7555c2b..4bd4a79 100644 --- a/bolt/codegen.py +++ b/bolt/codegen.py @@ -780,15 +780,18 @@ def class_definition( value = yield from visit_single(decorator.expression, required=True) decorators.append(value) - inherit: List[str] = [] + class_args: List[str] = [] if isinstance(bases := node.arguments[1], AstClassBases): for base in bases.inherit: result = yield from visit_single(base, required=True) - inherit.append(result) + class_args.append(result) + for kwarg in bases.kwargs: + result = yield from visit_single(kwarg, required=True) + class_args.append(result) - joined_bases = f"({', '.join(inherit)})" if inherit else "" - acc.statement(f"class {name.value}{joined_bases}:", lineno=node) + joined_args = f"({', '.join(class_args)})" if class_args else "" + acc.statement(f"class {name.value}{joined_args}:", lineno=node) with acc.block(): temp_start = acc.counter diff --git a/bolt/parse.py b/bolt/parse.py index c14bd11..0254143 100644 --- a/bolt/parse.py +++ b/bolt/parse.py @@ -31,7 +31,7 @@ "ProcMacroParser", "ProcMacroExpansion", "parse_class_name", - "parse_class_bases", + "ClassBasesParser", "parse_class_root", "parse_del_target", "parse_identifier", @@ -316,7 +316,14 @@ def get_bolt_parsers( ), "bolt:proc_macro": ProcMacroParser(modules), "bolt:class_name": parse_class_name, - "bolt:class_bases": parse_class_bases, + "bolt:class_bases": ClassBasesParser( + AlternativeParser( + [ + KeywordParser(delegate("bolt:expression")), + delegate("bolt:expression"), + ] + ) + ), "bolt:class_root": FlushPendingBindingsParser(parse_class_root, after=True), "bolt:memo_variable": TrailingCommaParser( AlternativeParser( @@ -1771,23 +1778,39 @@ def parse_class_name(stream: TokenStream) -> AstClassName: return node -def parse_class_bases(stream: TokenStream) -> AstClassBases: - """Parse class bases.""" - inherit: List[AstExpression] = [] +@dataclass +class ClassBasesParser: + """Parser for class bases.""" - with stream.syntax(brace=r"\(|\)", comma=","): - token = stream.expect(("brace", "(")) + parser: Parser - with stream.ignore("newline"): - for _ in stream.peek_until(("brace", ")")): - inherit.append(delegate("bolt:expression", stream)) + def __call__(self, stream: TokenStream) -> AstClassBases: + inherit: List[AstExpression] = [] + kwargs: List[AstKeyword] = [] - if not stream.get("comma"): - stream.expect(("brace", ")")) - break + with stream.syntax(brace=r"\(|\)", comma=","): + token = stream.expect(("brace", "(")) - node = AstClassBases(inherit=AstChildren(inherit)) - return set_location(node, token, stream.current) + with stream.ignore("newline"): + for _ in stream.peek_until(("brace", ")")): + arg = self.parser(stream) + + if isinstance(arg, AstKeyword): + kwargs.append(arg) + else: + if kwargs: + exc = InvalidSyntax( + f'Base class not allowed after keyword argument "{kwargs[-1].name}".' + ) + raise set_location(exc, arg) + inherit.append(arg) + + if not stream.get("comma"): + stream.expect(("brace", ")")) + break + + node = AstClassBases(inherit=AstChildren(inherit), kwargs=AstChildren(kwargs)) + return set_location(node, token, stream.current) def parse_class_root(stream: TokenStream) -> AstClassRoot: diff --git a/examples/bolt_class/src/data/demo/functions/pydantic.mcfunction b/examples/bolt_class/src/data/demo/functions/pydantic.mcfunction new file mode 100644 index 0000000..c751c60 --- /dev/null +++ b/examples/bolt_class/src/data/demo/functions/pydantic.mcfunction @@ -0,0 +1,7 @@ +from pydantic import BaseModel + +class Model(BaseModel, frozen=True): + a: int + b: str + +say Model(a="123", b="123").json() diff --git a/tests/resources/bolt_examples.mcfunction b/tests/resources/bolt_examples.mcfunction index 509592a..fcd9f4e 100644 --- a/tests/resources/bolt_examples.mcfunction +++ b/tests/resources/bolt_examples.mcfunction @@ -1342,3 +1342,16 @@ predicate x [] ### with storage ./args: $say $(message) +### +class A(foo=True, bar=1 + 2): + pass +### +class A: + pass +class B(foo=True, A): + pass +### +class A: + pass +class B(A, foo=True): + pass diff --git a/tests/snapshots/bolt__parse_247__0.txt b/tests/snapshots/bolt__parse_247__0.txt index 3eaf33a..98eb722 100644 --- a/tests/snapshots/bolt__parse_247__0.txt +++ b/tests/snapshots/bolt__parse_247__0.txt @@ -24,6 +24,8 @@ class Foo(object): location: SourceLocation(pos=10, lineno=1, colno=11) end_location: SourceLocation(pos=16, lineno=1, colno=17) value: 'object' + kwargs: + location: SourceLocation(pos=23, lineno=2, colno=5) end_location: SourceLocation(pos=27, lineno=2, colno=9) diff --git a/tests/snapshots/bolt__parse_248__0.txt b/tests/snapshots/bolt__parse_248__0.txt index 768b23f..ff2b479 100644 --- a/tests/snapshots/bolt__parse_248__0.txt +++ b/tests/snapshots/bolt__parse_248__0.txt @@ -74,6 +74,8 @@ class Foo(A, B): location: SourceLocation(pos=49, lineno=5, colno=14) end_location: SourceLocation(pos=50, lineno=5, colno=15) value: 'B' + kwargs: + location: SourceLocation(pos=57, lineno=6, colno=5) end_location: SourceLocation(pos=61, lineno=6, colno=9) diff --git a/tests/snapshots/bolt__parse_367__0.txt b/tests/snapshots/bolt__parse_367__0.txt new file mode 100644 index 0000000..5f9abbb --- /dev/null +++ b/tests/snapshots/bolt__parse_367__0.txt @@ -0,0 +1,62 @@ +class A(foo=True, bar=1 + 2): + pass +--- + + location: SourceLocation(pos=0, lineno=1, colno=1) + end_location: SourceLocation(pos=39, lineno=3, colno=1) + commands: + + location: SourceLocation(pos=0, lineno=1, colno=1) + end_location: SourceLocation(pos=38, lineno=2, colno=9) + identifier: 'class:name:bases:body' + arguments: + + location: SourceLocation(pos=6, lineno=1, colno=7) + end_location: SourceLocation(pos=7, lineno=1, colno=8) + decorators: + + value: 'A' + + location: SourceLocation(pos=7, lineno=1, colno=8) + end_location: SourceLocation(pos=28, lineno=1, colno=29) + inherit: + + kwargs: + + location: SourceLocation(pos=8, lineno=1, colno=9) + end_location: SourceLocation(pos=16, lineno=1, colno=17) + name: 'foo' + value: + + location: SourceLocation(pos=12, lineno=1, colno=13) + end_location: SourceLocation(pos=16, lineno=1, colno=17) + value: True + + location: SourceLocation(pos=18, lineno=1, colno=19) + end_location: SourceLocation(pos=27, lineno=1, colno=28) + name: 'bar' + value: + + location: SourceLocation(pos=22, lineno=1, colno=23) + end_location: SourceLocation(pos=27, lineno=1, colno=28) + operator: '+' + left: + + location: SourceLocation(pos=22, lineno=1, colno=23) + end_location: SourceLocation(pos=23, lineno=1, colno=24) + value: 1 + right: + + location: SourceLocation(pos=26, lineno=1, colno=27) + end_location: SourceLocation(pos=27, lineno=1, colno=28) + value: 2 + + location: SourceLocation(pos=34, lineno=2, colno=5) + end_location: SourceLocation(pos=38, lineno=2, colno=9) + commands: + + location: SourceLocation(pos=34, lineno=2, colno=5) + end_location: SourceLocation(pos=38, lineno=2, colno=9) + identifier: 'pass' + arguments: + diff --git a/tests/snapshots/bolt__parse_367__1.txt b/tests/snapshots/bolt__parse_367__1.txt new file mode 100644 index 0000000..9b1a40e --- /dev/null +++ b/tests/snapshots/bolt__parse_367__1.txt @@ -0,0 +1,20 @@ +_bolt_lineno = [1], [1] +_bolt_helper_children = _bolt_runtime.helpers['children'] +_bolt_helper_replace = _bolt_runtime.helpers['replace'] +with _bolt_runtime.scope() as _bolt_var3: + _bolt_var0 = True + _bolt_var1 = 1 + _bolt_var2 = 2 + _bolt_var1 = _bolt_var1 + _bolt_var2 + class A(foo=_bolt_var0, bar=_bolt_var1): + pass +_bolt_var4 = _bolt_helper_replace(_bolt_refs[0], commands=_bolt_helper_children(_bolt_var3)) +--- +output = _bolt_var4 +--- +_bolt_refs[0] + + location: SourceLocation(pos=0, lineno=1, colno=1) + end_location: SourceLocation(pos=39, lineno=3, colno=1) + commands: + diff --git a/tests/snapshots/bolt__parse_368__0.txt b/tests/snapshots/bolt__parse_368__0.txt new file mode 100644 index 0000000..88103f3 --- /dev/null +++ b/tests/snapshots/bolt__parse_368__0.txt @@ -0,0 +1,10 @@ +class A: + pass +#>ERROR Base class not allowed after keyword argument "foo". +# line 3, column 19 +# 2 | pass +# 3 | class B(foo=True, A): +# : ^ +# 4 | pass +class B(foo=True, A): + pass diff --git a/tests/snapshots/bolt__parse_369__0.txt b/tests/snapshots/bolt__parse_369__0.txt new file mode 100644 index 0000000..2fe47bb --- /dev/null +++ b/tests/snapshots/bolt__parse_369__0.txt @@ -0,0 +1,69 @@ +class A: + pass +class B(A, foo=True): + pass +--- + + location: SourceLocation(pos=0, lineno=1, colno=1) + end_location: SourceLocation(pos=49, lineno=5, colno=1) + commands: + + location: SourceLocation(pos=0, lineno=1, colno=1) + end_location: SourceLocation(pos=17, lineno=2, colno=9) + identifier: 'class:name:body' + arguments: + + location: SourceLocation(pos=6, lineno=1, colno=7) + end_location: SourceLocation(pos=7, lineno=1, colno=8) + decorators: + + value: 'A' + + location: SourceLocation(pos=13, lineno=2, colno=5) + end_location: SourceLocation(pos=17, lineno=2, colno=9) + commands: + + location: SourceLocation(pos=13, lineno=2, colno=5) + end_location: SourceLocation(pos=17, lineno=2, colno=9) + identifier: 'pass' + arguments: + + + location: SourceLocation(pos=18, lineno=3, colno=1) + end_location: SourceLocation(pos=48, lineno=4, colno=9) + identifier: 'class:name:bases:body' + arguments: + + location: SourceLocation(pos=24, lineno=3, colno=7) + end_location: SourceLocation(pos=25, lineno=3, colno=8) + decorators: + + value: 'B' + + location: SourceLocation(pos=25, lineno=3, colno=8) + end_location: SourceLocation(pos=38, lineno=3, colno=21) + inherit: + + location: SourceLocation(pos=26, lineno=3, colno=9) + end_location: SourceLocation(pos=27, lineno=3, colno=10) + value: 'A' + kwargs: + + location: SourceLocation(pos=29, lineno=3, colno=12) + end_location: SourceLocation(pos=37, lineno=3, colno=20) + name: 'foo' + value: + + location: SourceLocation(pos=33, lineno=3, colno=16) + end_location: SourceLocation(pos=37, lineno=3, colno=20) + value: True + + location: SourceLocation(pos=44, lineno=4, colno=5) + end_location: SourceLocation(pos=48, lineno=4, colno=9) + commands: + + location: SourceLocation(pos=44, lineno=4, colno=5) + end_location: SourceLocation(pos=48, lineno=4, colno=9) + identifier: 'pass' + arguments: + diff --git a/tests/snapshots/bolt__parse_369__1.txt b/tests/snapshots/bolt__parse_369__1.txt new file mode 100644 index 0000000..c088990 --- /dev/null +++ b/tests/snapshots/bolt__parse_369__1.txt @@ -0,0 +1,21 @@ +_bolt_lineno = [1, 7], [1, 3] +_bolt_helper_children = _bolt_runtime.helpers['children'] +_bolt_helper_replace = _bolt_runtime.helpers['replace'] +with _bolt_runtime.scope() as _bolt_var2: + class A: + pass + _bolt_var0 = A + _bolt_var1 = True + class B(_bolt_var0, foo=_bolt_var1): + pass +_bolt_var3 = _bolt_helper_replace(_bolt_refs[0], commands=_bolt_helper_children(_bolt_var2)) +--- +output = _bolt_var3 +--- +_bolt_refs[0] + + location: SourceLocation(pos=0, lineno=1, colno=1) + end_location: SourceLocation(pos=49, lineno=5, colno=1) + commands: + + diff --git a/tests/snapshots/examples__build_bolt_class__0.pack.md b/tests/snapshots/examples__build_bolt_class__0.pack.md index 59cf286..acd176f 100644 --- a/tests/snapshots/examples__build_bolt_class__0.pack.md +++ b/tests/snapshots/examples__build_bolt_class__0.pack.md @@ -31,3 +31,9 @@ say a say a say b ``` + +`@function demo:pydantic` + +```mcfunction +say {"a": 123, "b": "123"} +```