diff --git a/README.md b/README.md index b3fb5866..bd0c6820 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,23 @@ vyper: Import the voting contract types like this: ```python -# @version 0.3.9 +# @version 0.3.10 import voting.ballot as ballot ``` + +### Pragmas + +Ape-Vyper supports Vyper 0.3.10's [new pragma formats](https://github.com/vyperlang/vyper/pull/3493) + +#### Version Pragma + +```python +#pragma version 0.3.10 +``` + +#### Optimization Pragma + +```python +#pragma optimize codesize +``` diff --git a/ape_vyper/compiler.py b/ape_vyper/compiler.py index a1284ac6..411fe195 100644 --- a/ape_vyper/compiler.py +++ b/ape_vyper/compiler.py @@ -68,9 +68,9 @@ def _install_vyper(version: Version): ) from err -def get_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]: +def get_version_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]: """ - Extracts pragma information from Vyper source code. + Extracts version pragma information from Vyper source code. Args: source (str): Vyper source code @@ -78,21 +78,42 @@ def get_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]: Returns: ``packaging.specifiers.SpecifierSet``, or None if no valid pragma is found. """ + version_pragma_patterns = [ + r"(?:\n|^)\s*#\s*@version\s*([^\n]*)", + r"(?:\n|^)\s*#\s*pragma\s+version\s*([^\n]*)", + ] + source_str = source if isinstance(source, str) else source.read_text() - pragma_match = next(re.finditer(r"(?:\n|^)\s*#\s*@version\s*([^\n]*)", source_str), None) - if pragma_match is None: - return None # Try compiling with latest + for pattern in version_pragma_patterns: + for match in re.finditer(pattern, source_str): + raw_pragma = match.groups()[0] + pragma_str = " ".join(raw_pragma.split()).replace("^", "~=") + if pragma_str and pragma_str[0].isnumeric(): + pragma_str = f"=={pragma_str}" - raw_pragma = pragma_match.groups()[0] - pragma_str = " ".join(raw_pragma.split()).replace("^", "~=") - if pragma_str and pragma_str[0].isnumeric(): - pragma_str = f"=={pragma_str}" + try: + return SpecifierSet(pragma_str) + except InvalidSpecifier: + logger.warning(f"Invalid pragma spec: '{raw_pragma}'. Trying latest.") + return None + return None - try: - return SpecifierSet(pragma_str) - except InvalidSpecifier: - logger.warning(f"Invalid pragma spec: '{raw_pragma}'. Trying latest.") + +def get_optimization_pragma(source: Union[str, Path]) -> Optional[str]: + """ + Extracts optimization pragma information from Vyper source code. + + Args: + source (Union[str, Path]): Vyper source code + + Returns: + ``str``, or None if no valid pragma is found. + """ + source_str = source if isinstance(source, str) else source.read_text() + pragma_match = next(re.finditer(r"(?:\n|^)\s*#pragma\s+optimize\s+([^\n]*)", source_str), None) + if pragma_match is None: return None + return pragma_match.groups()[0] class VyperCompiler(CompilerAPI): @@ -145,7 +166,7 @@ def get_imports( def get_versions(self, all_paths: List[Path]) -> Set[str]: versions = set() for path in all_paths: - if version_spec := get_pragma_spec(path): + if version_spec := get_version_pragma_spec(path): try: # Make sure we have the best compiler available to compile this version_iter = version_spec.filter(self.available_versions) @@ -270,6 +291,13 @@ def import_remapping(self) -> Dict[str, Dict]: return interfaces + def classify_ast(self, _node: ASTNode): + if _node.ast_type in _FUNCTION_AST_TYPES: + _node.classification = ASTClassification.FUNCTION + + for child in _node.children: + self.classify_ast(child) + def compile( self, contract_filepaths: List[Path], base_path: Optional[Path] = None ) -> List[ContractType]: @@ -281,90 +309,105 @@ def compile( all_settings = self.get_compiler_settings(sources, base_path=base_path) for vyper_version, source_paths in version_map.items(): - settings = all_settings.get(vyper_version, {}) - path_args = {str(get_relative_path(p.absolute(), base_path)): p for p in source_paths} - input_json = { - "language": "Vyper", - "settings": settings, - "sources": {s: {"content": p.read_text()} for s, p in path_args.items()}, - } - if interfaces := self.import_remapping: - input_json["interfaces"] = interfaces - - vyper_binary = compiler_data[vyper_version]["vyper_binary"] - try: - result = vvm.compile_standard( - input_json, - base_path=base_path, - vyper_version=vyper_version, - vyper_binary=vyper_binary, - ) - except VyperError as err: - raise VyperCompileError(err) from err - - def classify_ast(_node: ASTNode): - if _node.ast_type in _FUNCTION_AST_TYPES: - _node.classification = ASTClassification.FUNCTION + version_settings = all_settings.get(vyper_version, {}) + optimizations_map = self.get_optimization_pragma_map(list(source_paths)) + + for optimization, source_paths in optimizations_map.items(): + settings: Dict[str, Any] = version_settings.copy() + settings["optimize"] = optimization or True + path_args = { + str(get_relative_path(p.absolute(), base_path)): p for p in source_paths + } + settings["outputSelection"] = {s: ["*"] for s in path_args} + input_json = { + "language": "Vyper", + "settings": settings, + "sources": {s: {"content": p.read_text()} for s, p in path_args.items()}, + } - for child in _node.children: - classify_ast(child) + if interfaces := self.import_remapping: + input_json["interfaces"] = interfaces - for source_id, output_items in result["contracts"].items(): - content = { - i + 1: ln - for i, ln in enumerate((base_path / source_id).read_text().splitlines()) - } - for name, output in output_items.items(): - # De-compress source map to get PC POS map. - ast = ASTNode.parse_obj(result["sources"][source_id]["ast"]) - classify_ast(ast) - - # Track function offsets. - function_offsets = [] - for node in ast.children: - lineno = node.lineno - - # NOTE: Constructor is handled elsewhere. - if node.ast_type == "FunctionDef" and "__init__" not in content.get( - lineno, "" - ): - function_offsets.append((node.lineno, node.end_lineno)) - - evm = output["evm"] - bytecode = evm["deployedBytecode"] - opcodes = bytecode["opcodes"].split(" ") - compressed_src_map = SourceMap(__root__=bytecode["sourceMap"]) - src_map = list(compressed_src_map.parse())[1:] - - pcmap = ( - _get_legacy_pcmap(ast, src_map, opcodes) - if vyper_version <= Version("0.3.7") - else _get_pcmap(bytecode) + vyper_binary = compiler_data[vyper_version]["vyper_binary"] + try: + result = vvm.compile_standard( + input_json, + base_path=base_path, + vyper_version=vyper_version, + vyper_binary=vyper_binary, ) + except VyperError as err: + raise VyperCompileError(err) from err + + for source_id, output_items in result["contracts"].items(): + content = { + i + 1: ln + for i, ln in enumerate((base_path / source_id).read_text().splitlines()) + } + for name, output in output_items.items(): + # De-compress source map to get PC POS map. + ast = ASTNode.parse_obj(result["sources"][source_id]["ast"]) + self.classify_ast(ast) + + # Track function offsets. + function_offsets = [] + for node in ast.children: + lineno = node.lineno + + # NOTE: Constructor is handled elsewhere. + if node.ast_type == "FunctionDef" and "__init__" not in content.get( + lineno, "" + ): + function_offsets.append((node.lineno, node.end_lineno)) + + evm = output["evm"] + bytecode = evm["deployedBytecode"] + opcodes = bytecode["opcodes"].split(" ") + compressed_src_map = SourceMap(__root__=bytecode["sourceMap"]) + src_map = list(compressed_src_map.parse())[1:] + + pcmap = ( + _get_legacy_pcmap(ast, src_map, opcodes) + if vyper_version <= Version("0.3.7") + else _get_pcmap(bytecode) + ) - # Find content-specified dev messages. - dev_messages = {} - for line_no, line in content.items(): - if match := re.search(DEV_MSG_PATTERN, line): - dev_messages[line_no] = match.group(1).strip() - - contract_type = ContractType( - ast=ast, - contractName=name, - sourceId=source_id, - deploymentBytecode={"bytecode": evm["bytecode"]["object"]}, - runtimeBytecode={"bytecode": bytecode["object"]}, - abi=output["abi"], - sourcemap=compressed_src_map, - pcmap=pcmap, - userdoc=output["userdoc"], - devdoc=output["devdoc"], - dev_messages=dev_messages, - ) - contract_types.append(contract_type) + # Find content-specified dev messages. + dev_messages = {} + for line_no, line in content.items(): + if match := re.search(DEV_MSG_PATTERN, line): + dev_messages[line_no] = match.group(1).strip() + + contract_type = ContractType( + ast=ast, + contractName=name, + sourceId=source_id, + deploymentBytecode={"bytecode": evm["bytecode"]["object"]}, + runtimeBytecode={"bytecode": bytecode["object"]}, + abi=output["abi"], + sourcemap=compressed_src_map, + pcmap=pcmap, + userdoc=output["userdoc"], + devdoc=output["devdoc"], + dev_messages=dev_messages, + ) + contract_types.append(contract_type) return contract_types + def get_optimization_pragma_map( + self, contract_filepaths: List[Path], base_path: Optional[Path] = None + ) -> Dict[Union[str, bool], Set[Path]]: + base_path = base_path or self.config_manager.contracts_folder + optimization_pragma_map: Dict[Union[str, bool], Set[Path]] = {} + for path in contract_filepaths: + pragma = get_optimization_pragma(path) or True + if pragma not in optimization_pragma_map: + optimization_pragma_map[pragma] = set() + optimization_pragma_map[pragma].add(path) + + return optimization_pragma_map + def get_version_map( self, contract_filepaths: List[Path], base_path: Optional[Path] = None ) -> Dict[Version, Set[Path]]: @@ -374,7 +417,7 @@ def get_version_map( # Sort contract_filepaths to promote consistent, reproduce-able behavior for path in sorted(contract_filepaths): - if pragma := get_pragma_spec(path): + if pragma := get_version_pragma_spec(path): _safe_append(source_path_by_pragma_spec, pragma, path) else: source_paths_without_pragma.add(path) @@ -441,10 +484,6 @@ def get_compiler_settings( continue version_settings: Dict = {"optimize": True} - path_args = { - str(get_relative_path(p.absolute(), contracts_path)): p for p in source_paths - } - version_settings["outputSelection"] = {s: ["*"] for s in path_args} if evm_version := data.get("evm_version"): version_settings["evmVersion"] = evm_version @@ -955,7 +994,7 @@ def _get_pcmap(bytecode: Dict) -> PCMap: error_str = RuntimeErrorType.FALLBACK_NOT_DEFINED.value use_loc = False elif "bad calldatasize or callvalue" in error_type: - # Only on >=0.3.10rc3. + # Only on >=0.3.10. # NOTE: We are no longer able to get Nonpayable checks errors since they # are now combined. error_str = RuntimeErrorType.INVALID_CALLDATA_OR_VALUE.value diff --git a/ape_vyper/exceptions.py b/ape_vyper/exceptions.py index d70f3427..b3467af9 100644 --- a/ape_vyper/exceptions.py +++ b/ape_vyper/exceptions.py @@ -86,7 +86,7 @@ def __init__(self, **kwargs): class InvalidCalldataOrValueError(VyperRuntimeError): """ - Raises on Vyper versions >= 0.3.10rc3 in place of NonPayableError. + Raises on Vyper versions >= 0.3.10 in place of NonPayableError. """ def __init__(self, **kwargs): diff --git a/setup.py b/setup.py index bf635bd5..30a487ce 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ "black>=23.3.0,<24", # Auto-formatter and linter "mypy>=0.991,<1", # Static type analyzer "types-setuptools", # Needed due to mypy typeshed + "pydantic<2.0", # Needed for successful type check. TODO: Remove after full v2 support. "flake8>=6.0.0,<7", # Style linter "isort>=5.10.1", # Import sorting linter "mdformat>=0.7.16", # Auto-formatter for markdown diff --git a/tests/conftest.py b/tests/conftest.py index 616edad7..dd803fba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,14 +35,14 @@ "0.3.4", "0.3.7", "0.3.9", - "0.3.10rc3", + "0.3.10", ) CONTRACT_VERSION_GEN_MAP = { "": ( "0.3.7", "0.3.9", - "0.3.10rc3", + "0.3.10", ), "sub_reverts": ALL_VERSIONS, } @@ -188,7 +188,7 @@ def account(): return ape.accounts.test_accounts[0] -@pytest.fixture(params=("037", "039", "0310rc3")) +@pytest.fixture(params=("037", "039", "0310")) def traceback_contract(request, account, project, geth_provider): return _get_tb_contract(request.param, project, account) diff --git a/tests/contracts/passing_contracts/contract_with_dev_messages.vy b/tests/contracts/passing_contracts/contract_with_dev_messages.vy index 04c39a81..9655de45 100644 --- a/tests/contracts/passing_contracts/contract_with_dev_messages.vy +++ b/tests/contracts/passing_contracts/contract_with_dev_messages.vy @@ -1,4 +1,4 @@ -# @version 0.3.9 +# @version 0.3.10 # Test dev messages in various code placements @external diff --git a/tests/contracts/passing_contracts/optimize_codesize.vy b/tests/contracts/passing_contracts/optimize_codesize.vy new file mode 100644 index 00000000..754f0ff0 --- /dev/null +++ b/tests/contracts/passing_contracts/optimize_codesize.vy @@ -0,0 +1,8 @@ +#pragma version 0.3.10 +#pragma optimize codesize + +x: uint256 + +@external +def __init__(): + self.x = 0 diff --git a/tests/contracts/passing_contracts/pragma_with_space.vy b/tests/contracts/passing_contracts/pragma_with_space.vy new file mode 100644 index 00000000..7b69ecf8 --- /dev/null +++ b/tests/contracts/passing_contracts/pragma_with_space.vy @@ -0,0 +1,7 @@ +# pragma version 0.3.10 + +x: uint256 + +@external +def __init__(): + self.x = 0 diff --git a/tests/test_ape_reverts.py b/tests/test_ape_reverts.py index 1c1e7d70..abebef26 100644 --- a/tests/test_ape_reverts.py +++ b/tests/test_ape_reverts.py @@ -10,7 +10,7 @@ def older_reverts_contract(account, project, geth_provider, request): return container.deploy(sender=account) -@pytest.fixture(params=("037", "039", "0310rc3")) +@pytest.fixture(params=("037", "039", "0310")) def reverts_contract_instance(account, project, geth_provider, request): sub_reverts_container = project.get_contract(f"sub_reverts_{request.param}") sub_reverts = sub_reverts_container.deploy(sender=account) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index a6a107ca..b4258b5f 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -21,7 +21,7 @@ OLDER_VERSION_FROM_PRAGMA = Version("0.2.16") VERSION_37 = Version("0.3.7") -VERSION_FROM_PRAGMA = Version("0.3.9") +VERSION_FROM_PRAGMA = Version("0.3.10") @pytest.fixture @@ -98,12 +98,16 @@ def test_get_version_map(project, compiler, all_versions): "contract_with_dev_messages.vy", "erc20.vy", "use_iface.vy", + "optimize_codesize.vy", "use_iface2.vy", + "contract_no_pragma.vy", # no pragma should compile with latest version + "empty.vy", # empty file still compiles with latest version + "pragma_with_space.vy", ] - # Add the 0.3.9 contracts. + # Add the 0.3.10 contracts. for template in TEMPLATES: - expected.append(f"{template}_039.vy") + expected.append(f"{template}_0310.vy") names = [x.name for x in actual[VERSION_FROM_PRAGMA]] failures = [] @@ -141,7 +145,7 @@ def test_compiler_data_in_manifest(project): assert len(vyper_latest.contractTypes) >= 9 assert len(vyper_028.contractTypes) >= 1 - assert "contract_039" in vyper_latest.contractTypes + assert "contract_0310" in vyper_latest.contractTypes assert "older_version" in vyper_028.contractTypes for compiler in (vyper_latest, vyper_028): assert compiler.settings["evmVersion"] == "istanbul" @@ -263,7 +267,7 @@ def line(cont: str) -> int: if nonpayable_checks: assert len(nonpayable_checks) >= 1 else: - # NOTE: Vyper 0.3.10rc3 doesn't have these anymore. + # NOTE: Vyper 0.3.10 doesn't have these anymore. # But they do have a new error type instead. checks = _all(RuntimeErrorType.INVALID_CALLDATA_OR_VALUE) assert len(checks) >= 1 @@ -322,7 +326,7 @@ def test_enrich_error_int_overflow(geth_provider, traceback_contract, account): def test_enrich_error_non_payable_check(geth_provider, traceback_contract, account): - if traceback_contract.contract_type.name.endswith("0310rc3"): + if traceback_contract.contract_type.name.endswith("0310"): # NOTE: Nonpayable error is combined with calldata check now. with pytest.raises(InvalidCalldataOrValueError): traceback_contract.addBalance(123, sender=account, value=1)