From b74f756229f2ba4f22a09c00caa512aebf810a1e Mon Sep 17 00:00:00 2001 From: Valentin Berlier Date: Sat, 28 Oct 2023 02:02:29 +0200 Subject: [PATCH] feat: refactor non-function root --- bolt/ast.py | 6 ++--- bolt/parse.py | 9 ++++--- bolt/runtime.py | 41 ++++++++++++++++++------------ examples/bolt_patchers/patchers.py | 1 - tests/test_bolt.py | 3 ++- 5 files changed, 35 insertions(+), 25 deletions(-) diff --git a/bolt/ast.py b/bolt/ast.py index af5f996..513a8ea 100644 --- a/bolt/ast.py +++ b/bolt/ast.py @@ -1,5 +1,5 @@ __all__ = [ - "AstModuleRoot", + "AstNonFunctionRoot", "AstExpression", "AstExpressionBinary", "AstExpressionUnary", @@ -83,8 +83,8 @@ @dataclass(frozen=True, slots=True) -class AstModuleRoot(AstRoot): - """Module root ast node.""" +class AstNonFunctionRoot(AstRoot): + """Non-function root ast node.""" @dataclass(frozen=True, slots=True) diff --git a/bolt/parse.py b/bolt/parse.py index 8f1e8f7..c14bd11 100644 --- a/bolt/parse.py +++ b/bolt/parse.py @@ -66,6 +66,7 @@ from typing import Any, Dict, FrozenSet, List, Literal, Optional, Set, Tuple, Type, cast from uuid import UUID, uuid4 +from beet import Function from beet.core.utils import extra_field from mecha import ( AdjacentConstraint, @@ -149,7 +150,7 @@ AstMacroMatchArgument, AstMacroMatchLiteral, AstMemo, - AstModuleRoot, + AstNonFunctionRoot, AstProcMacro, AstProcMacroMarker, AstProcMacroResult, @@ -165,7 +166,7 @@ AstValue, ) from .emit import CommandEmitter -from .module import Module, ModuleManager, UnusableCompilationUnit +from .module import ModuleManager, UnusableCompilationUnit from .pattern import ( DOCSTRING_PATTERN, FALSE_PATTERN, @@ -621,8 +622,8 @@ def __call__(self, stream: TokenStream) -> Any: self.macro_handler.cache_local_spec(stream) - if isinstance(node, AstRoot) and isinstance(current, Module): - node = set_location(AstModuleRoot(commands=node.commands), node) + if isinstance(node, AstRoot) and not isinstance(current, Function): + node = set_location(AstNonFunctionRoot(commands=node.commands), node) return node diff --git a/bolt/runtime.py b/bolt/runtime.py index 80927ab..4b79d1e 100644 --- a/bolt/runtime.py +++ b/bolt/runtime.py @@ -1,7 +1,7 @@ __all__ = [ "Runtime", "Evaluator", - "check_toplevel_commands", + "NonFunctionSerializer", ] @@ -18,6 +18,7 @@ AstRoot, CommandSpec, CommandTree, + CompilationDatabase, Diagnostic, Mecha, Visitor, @@ -28,7 +29,7 @@ from pathspec import PathSpec from tokenstream import set_location -from .ast import AstModuleRoot +from .ast import AstNonFunctionRoot, AstRoot from .codegen import Codegen from .emit import CommandEmitter from .helpers import get_bolt_helpers @@ -136,7 +137,7 @@ def __init__(self, ctx: Union[Context, Mecha]): mc.steps.insert(0, self.evaluate) - mc.serialize.extend(check_toplevel_commands) + mc.serialize.extend(NonFunctionSerializer(database=mc.database)) mc.cache_backend = ModuleCacheBackend(modules=self.modules) def expose(self, name: str, function: Callable[..., Any]): @@ -235,7 +236,7 @@ def root(self, node: AstRoot) -> Optional[AstRoot]: compilation_unit, module = self.modules.match_ast(node) if ( - isinstance(node, AstModuleRoot) + isinstance(self.modules.database.current, Module) and not module.executed and module.resource_location and not self.entrypoint_spec.match_file(module.resource_location) @@ -257,20 +258,28 @@ def root(self, node: AstRoot) -> Optional[AstRoot]: compilation_unit.priority = module.execution_index return node - def restore_module(self, key: TextFileBase[Any], node: AstModuleRoot, step: int): + def restore_module(self, key: TextFileBase[Any], node: AstRoot, step: int): compilation_unit = self.modules.database[key] compilation_unit.ast = node self.modules.database.enqueue(key, step, compilation_unit.priority) -@rule(AstModuleRoot) -def check_toplevel_commands(node: AstModuleRoot, result: List[str]): - """Emit diagnostic if module has toplevel commands.""" - if node.commands: - command = node.commands[0] - name = command.identifier.partition(":")[0] - raise set_location( - Diagnostic("warn", f'Standalone "{name}" command in module.'), - command, - command.arguments[0] if command.arguments else command, - ) +@dataclass +class NonFunctionSerializer(Visitor): + """Serializer that preserves the original source of non-function files.""" + + database: CompilationDatabase = required_field() + + @rule(AstNonFunctionRoot) + def non_function_root(self, node: AstNonFunctionRoot, result: List[str]): + if source := self.database[self.database.current].source: + result.append(source) + if node.commands: + command = node.commands[0] + name = command.identifier.partition(":")[0] + d = Diagnostic( + "warn", f'Ignored top-level "{name}" command outside function.' + ) + return set_location( + d, command, command.arguments[0] if command.arguments else command + ) diff --git a/examples/bolt_patchers/patchers.py b/examples/bolt_patchers/patchers.py index e88e87e..29dc5f9 100644 --- a/examples/bolt_patchers/patchers.py +++ b/examples/bolt_patchers/patchers.py @@ -21,7 +21,6 @@ def process_files(ctx: Context): mc.compile( patcher, resource_location=patcher_name, - readonly=True, report=mc.diagnostics, ) diff --git a/tests/test_bolt.py b/tests/test_bolt.py index e437bf4..3183769 100644 --- a/tests/test_bolt.py +++ b/tests/test_bolt.py @@ -30,7 +30,8 @@ def test_parse(snapshot: SnapshotFixture, ctx: Context, source: Function): ast = None diagnostics = None - mc.database[mc.database.current] = CompilationUnit(resource_location="demo:test") + mc.database.current = source + mc.database[source] = CompilationUnit(resource_location="demo:test") rand = Random() rand.seed(42)