From 7096c4e8d7c07fd4a670db35af9cc1a8edc3f1c1 Mon Sep 17 00:00:00 2001 From: antazoey Date: Tue, 11 Jun 2024 22:40:07 -0500 Subject: [PATCH] feat: vyper 0.4 support (#110) --- ape_vyper/__init__.py | 4 +- ape_vyper/compiler.py | 406 +++++++++++++----- ape_vyper/exceptions.py | 9 +- setup.py | 2 +- tests/conftest.py | 8 +- .../interfaces/IFaceZeroFour.vyi | 6 + .../contracts/passing_contracts/zero_four.vy | 16 + .../passing_contracts/zero_four_module.vy | 5 + tests/test_compiler.py | 47 +- 9 files changed, 374 insertions(+), 129 deletions(-) create mode 100644 tests/contracts/passing_contracts/interfaces/IFaceZeroFour.vyi create mode 100644 tests/contracts/passing_contracts/zero_four.vy create mode 100644 tests/contracts/passing_contracts/zero_four_module.vy diff --git a/ape_vyper/__init__.py b/ape_vyper/__init__.py index f3de0527..ebd7db41 100644 --- a/ape_vyper/__init__.py +++ b/ape_vyper/__init__.py @@ -1,6 +1,6 @@ from ape import plugins -from .compiler import VyperCompiler, VyperConfig +from .compiler import FileType, VyperCompiler, VyperConfig @plugins.register(plugins.Config) @@ -10,4 +10,4 @@ def config_class(): @plugins.register(plugins.CompilerPlugin) def register_compiler(): - return (".vy",), VyperCompiler + return tuple(e.value for e in FileType), VyperCompiler diff --git a/ape_vyper/compiler.py b/ape_vyper/compiler.py index 5a9aff2b..5aaf566c 100644 --- a/ape_vyper/compiler.py +++ b/ape_vyper/compiler.py @@ -5,6 +5,7 @@ from base64 import b64encode from collections import defaultdict from collections.abc import Iterable, Iterator +from enum import Enum from fnmatch import fnmatch from importlib import import_module from pathlib import Path @@ -15,15 +16,11 @@ from ape.api.compiler import CompilerAPI from ape.exceptions import ContractLogicError from ape.logging import logger -from ape.managers.project import ProjectManager +from ape.managers.project import LocalProject, ProjectManager from ape.types import ContractSourceCoverage, ContractType, SourceTraceback -from ape.utils import ( - cached_property, - get_full_extension, - get_relative_path, - pragma_str_to_specifier_set, -) +from ape.utils import cached_property, get_relative_path, pragma_str_to_specifier_set from ape.utils._github import _GithubClient +from ape.utils.os import clean_path, get_full_extension from eth_pydantic_types import HexBytes from eth_utils import is_0x_prefixed from ethpm_types import ASTNode, PackageManifest, PCMap, SourceMapItem @@ -76,9 +73,18 @@ "0.3.8": "shanghai", "0.3.9": "shanghai", "0.3.10": "shanghai", + "0.4.0rc6": "shanghai", } +class FileType(str, Enum): + SOURCE = ".vy" + INTERFACE = ".vyi" + + def __str__(self) -> str: + return self.value + + class Remapping(PluginConfig): key: str dependency_name: str @@ -199,7 +205,13 @@ def get_optimization_pragma(source: Union[str, Path]) -> Optional[str]: Returns: ``str``, or None if no valid pragma is found. """ - source_str = source if isinstance(source, str) else source.read_text() + if isinstance(source, str): + source_str = source + elif not source.is_file(): + return None + else: + source_str = source.read_text() + if pragma_match := next( re.finditer(r"(?:\n|^)\s*#pragma\s+optimize\s+([^\n]*)", source_str), None ): @@ -218,7 +230,13 @@ def get_evmversion_pragma(source: Union[str, Path]) -> Optional[str]: Returns: ``str``, or None if no valid pragma is found. """ - source_str = source if isinstance(source, str) else source.read_text() + if isinstance(source, str): + source_str = source + elif not source.is_file(): + return None + else: + source_str = source.read_text() + if pragma_match := next( re.finditer(r"(?:\n|^)\s*#pragma\s+evm-version\s+([^\n]*)", source_str), None ): @@ -265,59 +283,109 @@ def get_imports( contract_filepaths: Iterable[Path], project: Optional[ProjectManager] = None, ) -> dict[str, list[str]]: + pm = project or self.local_project + return self._get_imports(contract_filepaths, project=pm, handled=set()) + + def _get_imports( + self, + contract_filepaths: Iterable[Path], + project: Optional[ProjectManager] = None, + handled: Optional[set[str]] = None, + ): pm = project or self.local_project import_map: defaultdict = defaultdict(list) + handled = handled or set() dependencies = self.get_dependencies(project=pm) for path in contract_filepaths: + if not path.is_file(): + continue + content = path.read_text().splitlines() source_id = str(get_relative_path(path.absolute(), pm.path.absolute())) + handled.add(source_id) for line in content: if line.startswith("import "): import_line_parts = line.replace("import ", "").split(" ") - suffix = import_line_parts[0].strip().replace(".", os.path.sep) + prefix = import_line_parts[0] elif line.startswith("from ") and " import " in line: - import_line_parts = line.replace("from ", "").split(" ") - module_name = import_line_parts[0].strip().replace(".", os.path.sep) - suffix = os.path.sep.join([module_name, import_line_parts[2].strip()]) + import_line_parts = line.replace("from ", "").strip().split(" ") + module_name = import_line_parts[0].strip() + prefix = os.path.sep.join([module_name, import_line_parts[2].strip()]) else: # Not an import line continue - # NOTE: Defaults to JSON (assuming from input JSON or a local JSON), - # unless a Vyper file exists. + dots = "" + while prefix.startswith("."): + dots += prefix[0] + prefix = prefix[1:] + + # Replace rest of dots with slashes. + prefix = prefix.replace(".", os.path.sep) + + if prefix.startswith("vyper/"): + if f"{prefix}.json" not in import_map[source_id]: + import_map[source_id].append(f"{prefix}.json") + + continue + + local_path = (path.parent / dots / prefix.lstrip(os.path.sep)).resolve() + local_prefix = str(local_path).replace(f"{pm.path}", "").lstrip(os.path.sep) + import_source_id = None - if (pm.interfaces_folder.parent / f"{suffix}.vy").is_file(): - import_source_id = f"{suffix}.vy" - - elif (pm.interfaces_folder.parent / f"{suffix}.json").is_file(): - import_source_id = f"{suffix}.json" - - elif suffix.startswith(f"vyper{os.path.sep}"): - # Vyper built-ins. - import_source_id = f"{suffix}.json" - - elif suffix.split(os.path.sep)[0] in dependencies: - dependency_name = suffix.split(os.path.sep)[0] - filestem = suffix.replace(f"{dependency_name}{os.path.sep}", "") - for version_str, dep_project in pm.dependencies[dependency_name].items(): - dependency = pm.dependencies.get_dependency(dependency_name, version_str) - path_id = dependency.package_id.replace("/", "_") - dependency_source_prefix = ( - f"{get_relative_path(dep_project.contracts_folder, dep_project.path)}" - ) - source_id_stem = f"{dependency_source_prefix}{os.path.sep}{filestem}" - for ext in (".vy", ".json"): - if f"{source_id_stem}{ext}" in dep_project.sources: - import_source_id = os.path.sep.join( - (path_id, version_str, f"{source_id_stem}{ext}") - ) - break + is_local = True + # NOTE: Defaults to JSON (assuming from input JSON or a local JSON), + # unless a Vyper file exists. + if (pm.path / f"{local_prefix}{FileType.SOURCE}").is_file(): + ext = FileType.SOURCE.value + elif (pm.path / f"{local_prefix}{FileType.SOURCE}").is_file(): + ext = FileType.INTERFACE.value + elif (pm.path / f"{local_prefix}{FileType.INTERFACE}").is_file(): + ext = FileType.INTERFACE.value else: - logger.error(f"Unable to find dependency {suffix}") - continue + ext = ".json" + dep_key = prefix.split(os.path.sep)[0] + if dep_key in dependencies: + dependency_name = prefix.split(os.path.sep)[0] + filestem = prefix.replace(f"{dependency_name}{os.path.sep}", "") + for version_str, dep_project in pm.dependencies[dependency_name].items(): + dependency = pm.dependencies.get_dependency( + dependency_name, version_str + ) + path_id = dependency.package_id.replace("/", "_") + contracts_path = dep_project.contracts_folder + dependency_source_prefix = ( + f"{get_relative_path(contracts_path, dep_project.path)}" + ) + source_id_stem = f"{dependency_source_prefix}{os.path.sep}{filestem}" + for ext in (".vy", ".json"): + if f"{source_id_stem}{ext}" in dep_project.sources: + import_source_id = os.path.sep.join( + (path_id, version_str, f"{source_id_stem}{ext}") + ) + # Also include imports of imports. + sub_imports = self._get_imports( + (dep_project.path / f"{source_id_stem}{ext}",), + project=dep_project, + handled=handled, + ) + for sub_import_ls in sub_imports.values(): + import_map[source_id].extend(sub_import_ls) + + is_local = False + break + + if is_local: + import_source_id = f"{local_prefix}{ext}" + full_path = local_path.parent / f"{local_path.stem}{ext}" + + # Also include imports of imports. + sub_imports = self._get_imports((full_path,), project=project, handled=handled) + for sub_import_ls in sub_imports.values(): + import_map[source_id].extend(sub_import_ls) if import_source_id and import_source_id not in import_map[source_id]: import_map[source_id].append(import_source_id) @@ -417,7 +485,7 @@ def get_dependencies( self, project: Optional[ProjectManager] = None ) -> dict[str, ProjectManager]: pm = project or self.local_project - config = self.get_config(pm) + config = self.get_config(project=pm) dependencies: dict[str, ProjectManager] = {} handled: set[str] = set() @@ -433,7 +501,9 @@ def get_dependencies( dependency = pm.dependencies.get_dependency(name, _version) dep_id = f"{dependency.name}_{dependency.version}" - if dep_id in handled: + if dep_id in handled or ( + isinstance(dependency.project, LocalProject) and dependency.project.path == pm.path + ): continue handled.add(dep_id) @@ -452,7 +522,9 @@ def get_dependencies( # Add auto-remapped dependencies. for dependency in pm.dependencies.specified: dep_id = f"{dependency.name}_{dependency.version}" - if dep_id in handled: + if dep_id in handled or ( + isinstance(dependency.project, LocalProject) and dependency.project.path == pm.path + ): continue handled.add(dep_id) @@ -500,31 +572,61 @@ def compile( settings: Optional[dict] = None, ) -> Iterator[ContractType]: pm = project or self.local_project - settings = settings or {} - sources = [p for p in contract_filepaths if p.parent.name != "interfaces"] + self.compiler_settings = {**self.compiler_settings, **(settings or {})} contract_types: list[ContractType] = [] - if version := settings.get("version", None): - version_map = {Version(version): set(sources)} - else: - version_map = self.get_version_map(sources, project=project) - - compiler_data = self._get_compiler_arguments(version_map, project=pm) - all_settings: dict = self.get_compiler_settings( - sources, project=project, **(settings or {}) + import_map = self.get_imports(contract_filepaths, project=pm) + config = self.get_config(pm) + version_map = self._get_version_map_from_import_map( + contract_filepaths, + import_map, + project=pm, + config=config, ) + compiler_data = self._get_compiler_arguments(version_map, project=pm, config=config) + all_settings = self._get_compiler_settings_from_version_map(version_map, project=pm) contract_versions: dict[str, tuple[Version, str]] = {} import_remapping = self.get_import_remapping(project=pm) for vyper_version, version_settings in all_settings.items(): for settings_key, settings_set in version_settings.items(): - source_ids = settings_set["outputSelection"] - optimization_paths = {p: pm.path / p for p in source_ids} + if vyper_version >= Version("0.4.0rc1"): + sources = settings_set.get("outputSelection", {}) + if not sources: + continue + + src_dict = {p: {"content": Path(p).read_text()} for p in sources} + for src in sources: + if Path(src).is_absolute(): + src_id = f"{get_relative_path(Path(src), pm.path)}" + else: + src_id = src + + if imports := import_map.get(src_id): + for imp in imports: + if imp in src_dict: + continue + + imp_path = pm.path / imp + if not imp_path.is_file(): + continue + + src_dict[str(imp_path)] = {"content": imp_path.read_text()} + + else: + # NOTE: Pre vyper 0.4.0, interfaces CANNOT be in the source dict, + # but post 0.4.0, they MUST be. + src_dict = { + s: {"content": p.read_text()} + for s, p in { + p: pm.path / p for p in settings_set["outputSelection"] + }.items() + if p.parent != pm.path / "interfaces" + } + input_json: dict = { "language": "Vyper", "settings": settings_set, - "sources": { - s: {"content": p.read_text()} for s, p in optimization_paths.items() - }, + "sources": src_dict, } if interfaces := import_remapping: @@ -532,20 +634,28 @@ def compile( # Output compiler details. keys = ( - "\n\t".join(sorted([x for x in input_json.get("sources", {}).keys()])) + "\n\t".join( + sorted( + [ + clean_path(Path(x)) + for x in settings_set.get("outputSelection", {}).keys() + ] + ) + ) or "No input." ) log_str = f"Compiling using Vyper compiler '{vyper_version}'.\nInput:\n\t{keys}" logger.info(log_str) vyper_binary = compiler_data[vyper_version]["vyper_binary"] + comp_kwargs = {"vyper_version": vyper_version, "vyper_binary": vyper_binary} + + # `base_path` is required for pre-0.4 versions or else imports won't resolve. + if vyper_version < Version("0.4.0rc6"): + comp_kwargs["base_path"] = pm.path + try: - result = vvm_compile_standard( - input_json, - base_path=pm.path, - vyper_version=vyper_version, - vyper_binary=vyper_binary, - ) + result = vvm_compile_standard(input_json, **comp_kwargs) except VyperError as err: raise VyperCompileError(err) from err @@ -573,7 +683,17 @@ def compile( evm = output["evm"] bytecode = evm["deployedBytecode"] opcodes = bytecode["opcodes"].split(" ") - compressed_src_map = SourceMap(root=bytecode["sourceMap"]) + + src_map_raw = bytecode["sourceMap"] + if isinstance(src_map_raw, str): + # <0.4 range. + compressed_src_map = SourceMap(root=src_map_raw) + else: + # >=0.4 range. + compressed_src_map = SourceMap( + root=src_map_raw["pc_pos_map_compressed"] + ) + src_map = list(compressed_src_map.parse())[1:] pcmap = ( @@ -588,18 +708,26 @@ def compile( if match := re.search(DEV_MSG_PATTERN, line): dev_messages[line_no] = match.group(1).strip() - contract_type = ContractType( - ast=ast, - contractName=name, - sourceId=source_id, - deploymentBytecode={"bytecode": evm["bytecode"]["object"]}, - runtimeBytecode={"bytecode": bytecode["object"]}, - abi=output["abi"], - sourcemap=compressed_src_map, - pcmap=pcmap, - userdoc=output["userdoc"], - devdoc=output["devdoc"], - dev_messages=dev_messages, + source_id_path = Path(source_id) + if source_id_path.is_absolute(): + final_source_id = f"{get_relative_path(Path(source_id), pm.path)}" + else: + final_source_id = source_id + + contract_type = ContractType.model_validate( + { + "ast": ast, + "contractName": name, + "sourceId": final_source_id, + "deploymentBytecode": {"bytecode": evm["bytecode"]["object"]}, + "runtimeBytecode": {"bytecode": bytecode["object"]}, + "abi": output["abi"], + "sourcemap": compressed_src_map, + "pcmap": pcmap, + "userdoc": output["userdoc"], + "devdoc": output["devdoc"], + "dev_messages": dev_messages, + } ) contract_types.append(contract_type) contract_versions[name] = (vyper_version, settings_key) @@ -663,11 +791,13 @@ def compile_code( raise VyperCompileError(str(err)) from err output = result.get("", {}) - return ContractType( - abi=output["abi"], - deploymentBytecode={"bytecode": output["bytecode"]}, - runtimeBytecode={"bytecode": output["bytecode_runtime"]}, - **kwargs, + return ContractType.model_validate( + { + "abi": output["abi"], + "deploymentBytecode": {"bytecode": output["bytecode"]}, + "runtimeBytecode": {"bytecode": output["bytecode_runtime"]}, + **kwargs, + } ) def _source_vyper_version(self, code: str) -> Version: @@ -800,17 +930,32 @@ def get_version_map( project: Optional[ProjectManager] = None, ) -> dict[Version, set[Path]]: pm = project or self.local_project - config = self.get_config(pm) + import_map = self.get_imports(contract_filepaths, project=pm) + return self._get_version_map_from_import_map(contract_filepaths, import_map, project=pm) + + def _get_version_map_from_import_map( + self, + contract_filepaths: Iterable[Path], + import_map: dict[str, list[str]], + project: Optional[ProjectManager] = None, + config: Optional[PluginConfig] = None, + ): + pm = project or self.local_project + self.compiler_settings = {**self.compiler_settings} + config = config or self.get_config(pm) version_map: dict[Version, set[Path]] = {} source_path_by_version_spec: dict[SpecifierSet, set[Path]] = {} source_paths_without_pragma = set() # Sort contract_filepaths to promote consistent, reproduce-able behavior for path in sorted(contract_filepaths): + src_id = f"{get_relative_path(path.absolute(), pm.path)}" + imports = [pm.path / imp for imp in import_map.get(src_id, [])] + if config_spec := config.version: - _safe_append(source_path_by_version_spec, config_spec, path) + _safe_append(source_path_by_version_spec, config_spec, {path, *imports}) elif pragma := get_version_pragma_spec(path): - _safe_append(source_path_by_version_spec, pragma, path) + _safe_append(source_path_by_version_spec, pragma, {path, *imports}) else: source_paths_without_pragma.add(path) @@ -844,7 +989,9 @@ def get_version_map( for pragma_spec, path_set in source_path_by_version_spec.items(): versions = sorted(list(pragma_spec.filter(self.installed_versions)), reverse=True) if versions: - _safe_append(version_map, versions[0], path_set) + _safe_append( + version_map, versions[0], {p for p in path_set if p in contract_filepaths} + ) if not self.installed_versions: # If we have no installed versions by this point, we need to install one. @@ -853,12 +1000,21 @@ def get_version_map( # Handle no-pragma sources if source_paths_without_pragma: - max_installed_vyper_version = ( - max(version_map) - if version_map - else max(v for v in self.installed_versions if not v.pre) + versions_given = [x for x in version_map.keys()] + max_installed_vyper_version = None + if versions_given: + version_given_non_pre = [x for x in versions_given if not x.pre] + if version_given_non_pre: + max_installed_vyper_version = max(version_given_non_pre) + + if max_installed_vyper_version is None: + max_installed_vyper_version = max(v for v in self.installed_versions if not v.pre) + + _safe_append( + version_map, + max_installed_vyper_version, + {p for p in source_paths_without_pragma if p in contract_filepaths}, ) - _safe_append(version_map, max_installed_vyper_version, source_paths_without_pragma) return version_map @@ -869,19 +1025,27 @@ def get_compiler_settings( **kwargs, ) -> dict[Version, dict]: pm = project or self.local_project - valid_paths = [p for p in contract_filepaths if p.suffix == ".vy"] - if version := kwargs.pop("version", None): - files_by_vyper_version = {Version(version): set(valid_paths)} - else: - files_by_vyper_version = self.get_version_map(valid_paths, project=pm) - - if not files_by_vyper_version: - return {} + # NOTE: Interfaces cannot be in the outputSelection + # (but are required in `sources` for the 0.4.0 range). + valid_paths = [ + p + for p in contract_filepaths + if get_full_extension(p) == FileType.SOURCE + and not str(p).startswith(str(pm.path / "interfaces")) + ] + version_map = self.get_version_map(valid_paths, project=pm) + return self._get_compiler_settings_from_version_map(version_map, project=pm) - compiler_data = self._get_compiler_arguments(files_by_vyper_version, project=pm) + def _get_compiler_settings_from_version_map( + self, + version_map: dict[Version, set[Path]], + project: Optional[ProjectManager] = None, + ): + pm = project or self.local_project + compiler_data = self._get_compiler_arguments(version_map, project=pm) settings = {} for version, data in compiler_data.items(): - source_paths = list(files_by_vyper_version.get(version, [])) + source_paths = list(version_map.get(version, [])) if not source_paths: continue @@ -909,9 +1073,26 @@ def get_compiler_settings( elif optimization == "false": optimization = False + if version >= Version("0.4.0rc6"): + # Vyper 0.4.0 seems to require absolute paths. + selection_dict = { + (pm.path / s).as_posix(): ["*"] + for s in selection + if (pm.path / s).is_file() + and f"interfaces{os.path.sep}" not in s + and get_full_extension(pm.path / s) != FileType.INTERFACE + } + else: + selection_dict = { + s: ["*"] + for s in selection + if (pm.path / s).is_file() + if "interfaces" not in s + } + version_settings[settings_key] = { "optimize": optimization, - "outputSelection": {s: ["*"] for s in selection}, + "outputSelection": selection_dict, } if evm_version and evm_version not in ("none", "null"): version_settings[settings_key]["evmVersion"] = f"{evm_version}" @@ -1099,10 +1280,13 @@ def _profile(_name: str, _full_name: str): contract_coverage.include(method.name, method.selector) def _get_compiler_arguments( - self, version_map: dict, project: Optional[ProjectManager] = None + self, + version_map: dict, + project: Optional[ProjectManager] = None, + config: Optional[PluginConfig] = None, ) -> dict[Version, dict]: pm = project or self.local_project - config = self.get_config(pm) + config = config or self.get_config(pm) evm_version = config.evm_version arguments_map = {} for vyper_version, source_paths in version_map.items(): @@ -1181,12 +1365,12 @@ def _get_traceback( pcmap = PCMap.model_validate({}) for frame in frames: - if frame["op"] in CALL_OPCODES: + if frame["op"] in [c.value for c in CALL_OPCODES]: start_depth = frame["depth"] called_contract, sub_calldata = self._create_contract_from_call(frame) if called_contract: ext = get_full_extension(Path(called_contract.source_id)) - if ext.endswith(".vy"): + if ext in [x for x in FileType]: # Called another Vyper contract. sub_trace = self._get_traceback( called_contract, frames, sub_calldata, previous_depth=frame["depth"] @@ -1415,7 +1599,7 @@ def _has_empty_revert(opcodes: list[str]) -> bool: def _get_pcmap(bytecode: dict) -> PCMap: # Find the non payable value check. - src_info = bytecode["sourceMapFull"] + src_info = bytecode["sourceMapFull"] if "sourceMapFull" in bytecode else bytecode["sourceMap"] pc_data = {pc: {"location": ln} for pc, ln in src_info["pc_pos_map"].items()} if not pc_data: return PCMap.model_validate({}) diff --git a/ape_vyper/exceptions.py b/ape_vyper/exceptions.py index 3234d483..b45fc9d7 100644 --- a/ape_vyper/exceptions.py +++ b/ape_vyper/exceptions.py @@ -29,8 +29,15 @@ def __init__(self, err: Union[VyperError, str]): message = "\n\n".join( f"{e['sourceLocation']['file']}\n{e['type']}:" f"{e.get('formattedMessage', e['message'])}" - for e in err.error_dict + for e in (err.error_dict or {}) ) + # Try to find any indication of error. + message = message or getattr(err, "message", "") + + # If is only the default, check stderr. + if message == "An error occurred during execution" and getattr(err, "stderr_data", ""): + message = err.stderr_data + else: self.base_err = None message = str(err) diff --git a/setup.py b/setup.py index 57092a0d..f4e367de 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ url="https://github.com/ApeWorX/ape-vyper", include_package_data=True, install_requires=[ - "eth-ape>=0.8.2,<0.9", + "eth-ape>=0.8.3,<0.9", "ethpm-types", # Use same version as eth-ape "tqdm", # Use same version as eth-ape "vvm>=0.2.0,<0.3", diff --git a/tests/conftest.py b/tests/conftest.py index 0d33c100..0c72115e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ import ape import pytest import vvm # type: ignore +from ape.contracts import ContractContainer from click.testing import CliRunner BASE_CONTRACTS_PATH = Path(__file__).parent / "contracts" @@ -28,6 +29,7 @@ "0.3.7", "0.3.9", "0.3.10", + "0.4.0rc6", ) CONTRACT_VERSION_GEN_MAP = { @@ -36,7 +38,7 @@ "0.3.9", "0.3.10", ), - "sub_reverts": ALL_VERSIONS, + "sub_reverts": [v for v in ALL_VERSIONS if "0.4.0" not in v], } @@ -215,7 +217,11 @@ def cli_runner(): def _get_tb_contract(version: str, project, account): + project.load_contracts() + registry_type = project.get_contract(f"registry_{version}") + assert isinstance(registry_type, ContractContainer), "Setup failed - couldn't get container" registry = account.deploy(registry_type) contract = project.get_contract(f"traceback_contract_{version}") + assert isinstance(contract, ContractContainer), "Setup failed - couldn't get container" return account.deploy(contract, registry) diff --git a/tests/contracts/passing_contracts/interfaces/IFaceZeroFour.vyi b/tests/contracts/passing_contracts/interfaces/IFaceZeroFour.vyi new file mode 100644 index 00000000..969691dd --- /dev/null +++ b/tests/contracts/passing_contracts/interfaces/IFaceZeroFour.vyi @@ -0,0 +1,6 @@ +# pragma version ~=0.4.0rc6 + +@external +@view +def implementThisPlease(role: bytes32) -> bool: + ... diff --git a/tests/contracts/passing_contracts/zero_four.vy b/tests/contracts/passing_contracts/zero_four.vy new file mode 100644 index 00000000..f3b72488 --- /dev/null +++ b/tests/contracts/passing_contracts/zero_four.vy @@ -0,0 +1,16 @@ +# pragma version ~=0.4.0rc6 + +import interfaces.IFaceZeroFour as IFaceZeroFour +implements: IFaceZeroFour + +from . import zero_four_module as zero_four_module + +@external +@view +def implementThisPlease(role: bytes32) -> bool: + return True + + +@external +def callModuleFunction(role: bytes32) -> bool: + return zero_four_module.moduleMethod() diff --git a/tests/contracts/passing_contracts/zero_four_module.vy b/tests/contracts/passing_contracts/zero_four_module.vy new file mode 100644 index 00000000..cc00b46c --- /dev/null +++ b/tests/contracts/passing_contracts/zero_four_module.vy @@ -0,0 +1,5 @@ +# pragma version ~=0.4.0rc6 + +@internal +def moduleMethod() -> bool: + return True diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 2278acd3..a2ac7fde 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -5,6 +5,7 @@ import pytest import vvm # type: ignore from ape.exceptions import CompilerError, ContractLogicError +from ape.utils import get_full_extension from ethpm_types import ContractType from packaging.version import Version from vvm.exceptions import VyperError # type: ignore @@ -25,6 +26,7 @@ OLDER_VERSION_FROM_PRAGMA = Version("0.2.16") VERSION_37 = Version("0.3.7") VERSION_FROM_PRAGMA = Version("0.3.10") +VERSION_04 = Version("0.4.0rc6") @pytest.fixture @@ -46,14 +48,26 @@ def dev_revert_source(project): def test_compile_project(project): - contracts = project.load_contracts() - assert len(contracts) == len( - [p.name for p in project.contracts_folder.glob("*.vy") if p.is_file()] + actual = sorted(list(project.load_contracts().keys())) + + # NOTE: Ignore interfaces for this test. + expected = sorted( + [ + p.stem + for p in project.contracts_folder.rglob("*.vy") + if p.is_file() and not p.name.startswith("I") + ] ) - prefix = "contracts/passing_contracts" - assert contracts["contract_039"].source_id == f"{prefix}/contract_039.vy" - assert contracts["contract_no_pragma"].source_id == f"{prefix}/contract_no_pragma.vy" - assert contracts["older_version"].source_id == f"{prefix}/older_version.vy" + if missing := [e for e in expected if e not in actual]: + missing_str = ", ".join(missing) + pytest.xfail(f"Missing the following expected sources: {missing_str}") + if extra := [a for a in actual if a not in expected]: + extra_str = ", ".join(extra) + pytest.xfail(f"Received the following extra sources: {extra_str}") + + assert "contract_039" in actual + assert "contract_no_pragma" in actual + assert "older_version" in actual @pytest.mark.parametrize("contract_name", PASSING_CONTRACT_NAMES) @@ -83,7 +97,9 @@ def test_install_failure(compiler): def test_get_version_map(project, compiler, all_versions): vyper_files = [ - x for x in project.contracts_folder.iterdir() if x.is_file() and x.suffix == ".vy" + x + for x in project.contracts_folder.iterdir() + if x.is_file() and get_full_extension(x) == ".vy" ] actual = compiler.get_version_map(vyper_files, project=project) expected_versions = [Version(v) for v in all_versions] @@ -137,6 +153,11 @@ def test_get_version_map(project, compiler, all_versions): assert not failures, "\n".join(failures) + # Vyper 0.4.0 assertions. + actual4 = {x.name for x in actual[VERSION_04]} + expected4 = {"zero_four_module.vy", "zero_four.vy"} + assert actual4 == expected4 + def test_compiler_data_in_manifest(project): def run_test(manifest): @@ -144,11 +165,11 @@ def run_test(manifest): all_latest = [c for c in manifest.compilers if str(c.version) == str(VERSION_FROM_PRAGMA)] codesize_latest = [c for c in all_latest if c.settings["optimize"] == "codesize"][0] - evm_latest = [c for c in all_latest if c.settings["evmVersion"] == "paris"][0] + evm_latest = [c for c in all_latest if c.settings.get("evmVersion") == "paris"][0] true_latest = [ c for c in all_latest - if c.settings["optimize"] is True and c.settings["evmVersion"] != "paris" + if c.settings["optimize"] is True and c.settings.get("evmVersion") != "paris" ][0] vyper_028 = [ c for c in manifest.compilers if str(c.version) == str(OLDER_VERSION_FROM_PRAGMA) @@ -167,7 +188,7 @@ def run_test(manifest): # There is only one contract with evm-version pragma. assert evm_latest.contractTypes == ["evm_pragma"] - assert evm_latest.settings["evmVersion"] == "paris" + assert evm_latest.settings.get("evmVersion") == "paris" assert len(true_latest.contractTypes) >= 9 assert len(vyper_028.contractTypes) >= 1 @@ -212,8 +233,8 @@ def test_get_imports(compiler, project): actual = compiler.get_imports(vyper_files, project=project) prefix = "contracts/passing_contracts" builtin_import = "vyper/interfaces/ERC20.json" - local_import = "interfaces/IFace.vy" - local_from_import = "interfaces/IFace2.vy" + local_import = f"{prefix}/interfaces/IFace.vy" + local_from_import = f"{prefix}/interfaces/IFace2.vy" dep_key = project.dependencies.get_dependency("exampledependency", "local").package_id.replace( "/", "_" )