From 53a4ba9b8bb4a46173246d0f26d551277f8a3989 Mon Sep 17 00:00:00 2001 From: Valentin Berlier Date: Thu, 7 Dec 2023 00:32:58 +0100 Subject: [PATCH] fix: refactor codegen to use a statement tree --- bolt/codegen.py | 79 +++++++++++++++++++++++++++++-------------------- 1 file changed, 47 insertions(+), 32 deletions(-) diff --git a/bolt/codegen.py b/bolt/codegen.py index cba409c..1e09695 100644 --- a/bolt/codegen.py +++ b/bolt/codegen.py @@ -1,6 +1,7 @@ __all__ = [ "Codegen", "Accumulator", + "CodegenStatement", "ChildrenCollector", "CommandCollector", "RootCommandCollector", @@ -89,18 +90,34 @@ from .module import CodegenResult, MacroLibrary +@dataclass(slots=True) +class CodegenStatement: + """Python statement emitted by the codegen, which can recursively contain other statements.""" + + code: str + lineno: Optional[int] = None + children: List["CodegenStatement"] = field(default_factory=list) + + def flatten(self, indent: str = "") -> Iterable[Tuple[str, Optional[int]]]: + """Yield the indented statements with their associated line number.""" + yield f"{indent}{self.code}", self.lineno + if self.children: + indent += 4 * " " + for child_statement in self.children: + yield from child_statement.flatten(indent) + + @dataclass class Accumulator: """Utility for generating python code.""" - indentation: str = "" refs: List[Any] = field(default_factory=list) dependencies: Set[str] = field(default_factory=set) prelude_imports: List[AstPrelude] = field(default_factory=list) macros: MacroLibrary = field(default_factory=dict) macro_ids: Dict[str, int] = field(default_factory=dict) memo_index: Dict[AstMemo, int] = field(default_factory=dict) - lines: List[str] = field(default_factory=list) + statements: List["CodegenStatement"] = field(default_factory=list) counter: int = 0 header: Dict[str, str] = field(default_factory=dict) root_scope: bool = True @@ -111,23 +128,21 @@ class Accumulator: def get_source(self) -> str: """Return the source code.""" - header = "".join( - f"{variable} = {expression}\n" + header = [ + CodegenStatement(f"{variable} = {expression}") for variable, expression in self.header.items() - ) + ] lines: List[str] = ["_bolt_lineno = "] numbers1: List[int] = [1] numbers2: List[int] = [1] - for line in (header + "".join(self.lines)).splitlines(): - if line.startswith("!lineno "): - current_line = int(line[8:]) - if numbers2[-1] != current_line: + for statement in header + self.statements: + for code, lineno in statement.flatten(): + if lineno and numbers2[-1] != lineno: numbers1.append(len(lines) + 1) - numbers2.append(current_line) - else: - lines.append(line) + numbers2.append(lineno) + lines.append(code) lines[0] += f"{numbers1}, {numbers2}" @@ -208,27 +223,27 @@ def get_macro(self, name: str) -> str: self.macro_ids[name] = len(self.macro_ids) return f"_bolt_macro{self.macro_ids[name]}" - def lineno(self, lineno: Any): - """Emit line number.""" + def extract_lineno(self, lineno: Any): + """Utility to extract the line number.""" if isinstance(lineno, AstNode) and not lineno.location.unknown: lineno = lineno.location.lineno if isinstance(lineno, int): - self.lines.append(f"!lineno {lineno}\n") + return lineno + return None @contextmanager def block(self): """Wrap statements in an indented block.""" - previous_indentation = self.indentation - self.indentation += " " + previous_statements = self.statements + self.statements = self.statements[-1].children try: yield finally: - self.indentation = previous_indentation + self.statements = previous_statements def statement(self, code: str, *, lineno: Any = None): """Emit statement.""" - self.lineno(lineno) - self.lines.append(f"{self.indentation}{code}\n") + self.statements.append(CodegenStatement(code, self.extract_lineno(lineno))) @contextmanager def function(self, name: str, *args: str, return_type: str = ""): @@ -262,13 +277,13 @@ def else_statement(self): with self.if_statement(self.condition_inverse): yield - def enclose(self, code: str, from_index: int): - """Enclose lines starting from the given index.""" - self.lines[from_index:] = [ - line if line.startswith("!") else f" {line}" - for line in self.lines[from_index:] + def enclose(self, code: str, from_index: int, *, lineno: Any = None): + """Enclose statements starting from the given index.""" + self.statements[from_index:] = [ + CodegenStatement( + code, self.extract_lineno(lineno), self.statements[from_index:] + ) ] - self.lines.insert(from_index, f"{self.indentation}{code}\n") @dataclass @@ -364,7 +379,7 @@ def visit_multiple( """Yield all the nodes and return a single result pointing to the new children.""" current_count = 0 collector: Optional[ChildrenCollector] = None - index = len(acc.lines) + index = len(acc.statements) previous_siblings = acc.current_siblings previous_sibling_index = acc.current_sibling_index @@ -379,14 +394,14 @@ def visit_multiple( if not collector: collector = children_collector(acc, index) - lines = acc.lines[index:] - del acc.lines[index:] + statements = acc.statements[index:] + del acc.statements[index:] collector.add_static(*children[current_count:i]) - acc.lines.extend(lines) + acc.statements.extend(statements) collector.add_dynamic(*result) current_count = i + 1 - index = len(acc.lines) + index = len(acc.statements) acc.current_siblings = previous_siblings acc.current_sibling_index = previous_sibling_index @@ -536,7 +551,7 @@ def command( return [acc.replace(acc.make_ref(node), arguments=arguments)] arguments = yield from visit_multiple(node.arguments[:-1], acc) - nesting_index = len(acc.lines) + nesting_index = len(acc.statements) nesting = yield from visit_single(node.arguments[-1]) if nesting is None: