def get_optimization_pragma(source: Union[str, Path]) -> Union[str, bool]:
    """
    Extracts optimization pragma information from Vyper source code.

    Only the first ``pragma optimize <mode>`` comment is considered, and the
    mode must appear on the same line as the pragma itself.

    Args:
        source (Union[str, Path]): Vyper source code, or a path to a file
          containing it.

    Returns:
        ``str``: The optimization mode with surrounding whitespace removed, or
        ``True`` if no valid pragma is found (for backwards compatibility).
    """
    source_str = source if isinstance(source, str) else source.read_text()

    # Only horizontal whitespace is allowed between the tokens so the match can
    # never spill onto the following source line, and ``\S`` demands a
    # non-empty mode. Whitespace after ``#`` is tolerated, mirroring the
    # ``#\s*@version`` pattern used for version pragmas in this module.
    pragma_match = re.search(
        r"^[ \t]*#[ \t]*pragma[ \t]+optimize[ \t]+(\S[^\n]*)",
        source_str,
        flags=re.MULTILINE,
    )
    if pragma_match is None:
        return True

    # Drop trailing spaces and any "\r" left over from CRLF line endings so the
    # value can be used directly as a compiler setting / dictionary key.
    return pragma_match.group(1).strip()
output_items in result["contracts"].items(): - content = { - i + 1: ln - for i, ln in enumerate((base_path / source_id).read_text().splitlines()) - } - for name, output in output_items.items(): - # De-compress source map to get PC POS map. - ast = ASTNode.parse_obj(result["sources"][source_id]["ast"]) - classify_ast(ast) - - # Track function offsets. - function_offsets = [] - for node in ast.children: - lineno = node.lineno - - # NOTE: Constructor is handled elsewhere. - if node.ast_type == "FunctionDef" and "__init__" not in content.get( - lineno, "" - ): - function_offsets.append((node.lineno, node.end_lineno)) - - evm = output["evm"] - bytecode = evm["deployedBytecode"] - opcodes = bytecode["opcodes"].split(" ") - compressed_src_map = SourceMap(__root__=bytecode["sourceMap"]) - src_map = list(compressed_src_map.parse())[1:] - - pcmap = ( - _get_legacy_pcmap(ast, src_map, opcodes) - if vyper_version <= Version("0.3.7") - else _get_pcmap(bytecode) + vyper_binary = compiler_data[vyper_version]["vyper_binary"] + try: + result = vvm.compile_standard( + input_json, + base_path=base_path, + vyper_version=vyper_version, + vyper_binary=vyper_binary, ) + except VyperError as err: + raise VyperCompileError(err) from err + + def classify_ast(_node: ASTNode): + if _node.ast_type in _FUNCTION_AST_TYPES: + _node.classification = ASTClassification.FUNCTION + + for child in _node.children: + classify_ast(child) + + for source_id, output_items in result["contracts"].items(): + content = { + i + 1: ln + for i, ln in enumerate((base_path / source_id).read_text().splitlines()) + } + for name, output in output_items.items(): + # De-compress source map to get PC POS map. + ast = ASTNode.parse_obj(result["sources"][source_id]["ast"]) + classify_ast(ast) + + # Track function offsets. + function_offsets = [] + for node in ast.children: + lineno = node.lineno + + # NOTE: Constructor is handled elsewhere. 
+ if node.ast_type == "FunctionDef" and "__init__" not in content.get( + lineno, "" + ): + function_offsets.append((node.lineno, node.end_lineno)) + + evm = output["evm"] + bytecode = evm["deployedBytecode"] + opcodes = bytecode["opcodes"].split(" ") + compressed_src_map = SourceMap(__root__=bytecode["sourceMap"]) + src_map = list(compressed_src_map.parse())[1:] + + pcmap = ( + _get_legacy_pcmap(ast, src_map, opcodes) + if vyper_version <= Version("0.3.7") + else _get_pcmap(bytecode) + ) - # Find content-specified dev messages. - dev_messages = {} - for line_no, line in content.items(): - if match := re.search(DEV_MSG_PATTERN, line): - dev_messages[line_no] = match.group(1).strip() - - contract_type = ContractType( - ast=ast, - contractName=name, - sourceId=source_id, - deploymentBytecode={"bytecode": evm["bytecode"]["object"]}, - runtimeBytecode={"bytecode": bytecode["object"]}, - abi=output["abi"], - sourcemap=compressed_src_map, - pcmap=pcmap, - userdoc=output["userdoc"], - devdoc=output["devdoc"], - dev_messages=dev_messages, - ) - contract_types.append(contract_type) + # Find content-specified dev messages. 
def get_optimization_pragma_map(
    self, contract_filepaths: List[Path], base_path: Optional[Path] = None
) -> Dict[Union[str, bool], Set[Path]]:
    """
    Groups contract paths by the optimization setting their source requests.

    Args:
        contract_filepaths (List[Path]): Paths of the Vyper sources to inspect.
        base_path (Optional[Path]): Accepted so the signature parallels
          ``get_version_map``; it is not used when building the map.

    Returns:
        Dict[Union[str, bool], Set[Path]]: Each value returned by
        ``get_optimization_pragma`` mapped to the set of paths that produced it.
        Every input path appears in exactly one set.
    """
    optimization_pragma_map: Dict[Union[str, bool], Set[Path]] = {}

    # Sort for reproducible dictionary ordering, as ``get_version_map`` does.
    for path in sorted(contract_filepaths):
        # No truthiness filter here: a path must never be dropped from the
        # map, otherwise it would silently be left out of compilation.
        pragma = get_optimization_pragma(path)
        optimization_pragma_map.setdefault(pragma, set()).add(path)

    return optimization_pragma_map