Skip to content

Commit

Permalink
feat: refactor non-function root
Browse files Browse the repository at this point in the history
  • Loading branch information
vberlier committed Oct 28, 2023
1 parent 24a9e6e commit b74f756
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 25 deletions.
6 changes: 3 additions & 3 deletions bolt/ast.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__all__ = [
"AstModuleRoot",
"AstNonFunctionRoot",
"AstExpression",
"AstExpressionBinary",
"AstExpressionUnary",
Expand Down Expand Up @@ -83,8 +83,8 @@


@dataclass(frozen=True, slots=True)
class AstModuleRoot(AstRoot):
"""Module root ast node."""
class AstNonFunctionRoot(AstRoot):
"""Non-function root ast node."""


@dataclass(frozen=True, slots=True)
Expand Down
9 changes: 5 additions & 4 deletions bolt/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
from typing import Any, Dict, FrozenSet, List, Literal, Optional, Set, Tuple, Type, cast
from uuid import UUID, uuid4

from beet import Function
from beet.core.utils import extra_field
from mecha import (
AdjacentConstraint,
Expand Down Expand Up @@ -149,7 +150,7 @@
AstMacroMatchArgument,
AstMacroMatchLiteral,
AstMemo,
AstModuleRoot,
AstNonFunctionRoot,
AstProcMacro,
AstProcMacroMarker,
AstProcMacroResult,
Expand All @@ -165,7 +166,7 @@
AstValue,
)
from .emit import CommandEmitter
from .module import Module, ModuleManager, UnusableCompilationUnit
from .module import ModuleManager, UnusableCompilationUnit
from .pattern import (
DOCSTRING_PATTERN,
FALSE_PATTERN,
Expand Down Expand Up @@ -621,8 +622,8 @@ def __call__(self, stream: TokenStream) -> Any:

self.macro_handler.cache_local_spec(stream)

if isinstance(node, AstRoot) and isinstance(current, Module):
node = set_location(AstModuleRoot(commands=node.commands), node)
if isinstance(node, AstRoot) and not isinstance(current, Function):
node = set_location(AstNonFunctionRoot(commands=node.commands), node)

return node

Expand Down
41 changes: 25 additions & 16 deletions bolt/runtime.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__all__ = [
"Runtime",
"Evaluator",
"check_toplevel_commands",
"NonFunctionSerializer",
]


Expand All @@ -18,6 +18,7 @@
AstRoot,
CommandSpec,
CommandTree,
CompilationDatabase,
Diagnostic,
Mecha,
Visitor,
Expand All @@ -28,7 +29,7 @@
from pathspec import PathSpec
from tokenstream import set_location

from .ast import AstModuleRoot
from .ast import AstNonFunctionRoot, AstRoot
from .codegen import Codegen
from .emit import CommandEmitter
from .helpers import get_bolt_helpers
Expand Down Expand Up @@ -136,7 +137,7 @@ def __init__(self, ctx: Union[Context, Mecha]):

mc.steps.insert(0, self.evaluate)

mc.serialize.extend(check_toplevel_commands)
mc.serialize.extend(NonFunctionSerializer(database=mc.database))
mc.cache_backend = ModuleCacheBackend(modules=self.modules)

def expose(self, name: str, function: Callable[..., Any]):
Expand Down Expand Up @@ -235,7 +236,7 @@ def root(self, node: AstRoot) -> Optional[AstRoot]:
compilation_unit, module = self.modules.match_ast(node)

if (
isinstance(node, AstModuleRoot)
isinstance(self.modules.database.current, Module)
and not module.executed
and module.resource_location
and not self.entrypoint_spec.match_file(module.resource_location)
Expand All @@ -257,20 +258,28 @@ def root(self, node: AstRoot) -> Optional[AstRoot]:
compilation_unit.priority = module.execution_index
return node

def restore_module(self, key: TextFileBase[Any], node: AstModuleRoot, step: int):
def restore_module(self, key: TextFileBase[Any], node: AstRoot, step: int):
compilation_unit = self.modules.database[key]
compilation_unit.ast = node
self.modules.database.enqueue(key, step, compilation_unit.priority)


@rule(AstModuleRoot)
def check_toplevel_commands(node: AstModuleRoot, result: List[str]):
"""Emit diagnostic if module has toplevel commands."""
if node.commands:
command = node.commands[0]
name = command.identifier.partition(":")[0]
raise set_location(
Diagnostic("warn", f'Standalone "{name}" command in module.'),
command,
command.arguments[0] if command.arguments else command,
)
@dataclass
class NonFunctionSerializer(Visitor):
"""Serializer that preserves the original source of non-function files."""

database: CompilationDatabase = required_field()

@rule(AstNonFunctionRoot)
def non_function_root(self, node: AstNonFunctionRoot, result: List[str]):
if source := self.database[self.database.current].source:
result.append(source)
if node.commands:
command = node.commands[0]
name = command.identifier.partition(":")[0]
d = Diagnostic(
"warn", f'Ignored top-level "{name}" command outside function.'
)
return set_location(
d, command, command.arguments[0] if command.arguments else command
)
1 change: 0 additions & 1 deletion examples/bolt_patchers/patchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def process_files(ctx: Context):
mc.compile(
patcher,
resource_location=patcher_name,
readonly=True,
report=mc.diagnostics,
)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def test_parse(snapshot: SnapshotFixture, ctx: Context, source: Function):
ast = None
diagnostics = None

mc.database[mc.database.current] = CompilationUnit(resource_location="demo:test")
mc.database.current = source
mc.database[source] = CompilationUnit(resource_location="demo:test")

rand = Random()
rand.seed(42)
Expand Down

0 comments on commit b74f756

Please sign in to comment.