From 63d02fe15de7fb9399658a934f23a68ec2fbb42c Mon Sep 17 00:00:00 2001 From: antazoey Date: Tue, 28 May 2024 07:56:07 -0500 Subject: [PATCH] feat!: project refactor (#113) --- .pre-commit-config.yaml | 2 +- ape_vyper/compiler.py | 373 ++++++++++++------ tests/ape-config.yaml | 7 +- tests/conftest.py | 63 +-- .../contracts/passing_contracts/flatten_me.vy | 2 +- .../contracts/passing_contracts/use_iface.vy | 2 +- tests/test_compiler.py | 118 +++--- 7 files changed, 349 insertions(+), 218 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 23d53854..47442f7b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: rev: 0.7.17 hooks: - id: mdformat - additional_dependencies: [mdformat-gfm, mdformat-frontmatter] + additional_dependencies: [mdformat-gfm, mdformat-frontmatter, mdformat-pyproject] default_language_version: python: python3 diff --git a/ape_vyper/compiler.py b/ape_vyper/compiler.py index a2b88737..04f6569f 100644 --- a/ape_vyper/compiler.py +++ b/ape_vyper/compiler.py @@ -3,6 +3,7 @@ import shutil import time from base64 import b64encode +from collections import defaultdict from collections.abc import Iterable, Iterator from fnmatch import fnmatch from importlib import import_module @@ -10,10 +11,11 @@ from typing import Any, Optional, Union, cast import vvm # type: ignore -from ape.api import PluginConfig -from ape.api.compiler import CompilerAPI, TraceAPI +from ape.api import PluginConfig, TraceAPI +from ape.api.compiler import CompilerAPI from ape.exceptions import ContractLogicError from ape.logging import logger +from ape.managers.project import ProjectManager from ape.types import ContractSourceCoverage, ContractType, SourceTraceback from ape.utils import ( cached_property, @@ -33,7 +35,7 @@ from evm_trace.geth import create_call_node_data from packaging.specifiers import InvalidSpecifier, SpecifierSet from packaging.version import Version -from pydantic import field_serializer, field_validator +from pydantic import field_serializer, field_validator, model_validator from vvm import compile_standard as vvm_compile_standard from vvm.exceptions import VyperError # type: ignore @@ -77,6 +79,38 @@ } +class Remapping(PluginConfig): + key: str + dependency_name: str + dependency_version: Optional[None] = None + + @model_validator(mode="before") + @classmethod + def validate_str(cls, value): + if isinstance(value, str): + parts = value.split("=") + key = parts[0].strip() + value = parts[1].strip() + if "@" in value: + value_parts = value.split("@") + dep_name = value_parts[0].strip() + dep_version = value_parts[1].strip() + else: + dep_name = value + dep_version = None + + return {"key": key, "dependency_name": dep_name, "dependency_version": dep_version} + + return value + + def __str__(self) -> str: + value = self.dependency_name + if _version := self.dependency_version: + value = f"{value}@{_version}" + + return f"{self.key}={value}" + + class VyperConfig(PluginConfig): version: Optional[SpecifierSet] = None """ @@ -89,7 +123,7 @@ class VyperConfig(PluginConfig): The evm-version or hard-fork name. """ - import_remapping: list[str] = [] + import_remapping: list[Remapping] = [] """ Configuration of an import name mapped to a dependency listing. To use a specific version of a dependency, specify using ``@`` symbol. @@ -226,22 +260,17 @@ class VyperCompiler(CompilerAPI): def name(self) -> str: return "vyper" - @property - def settings(self) -> VyperConfig: - return cast(VyperConfig, super().settings) - - @property - def evm_version(self) -> Optional[str]: - return self.settings.evm_version - def get_imports( - self, contract_filepaths: Iterable[Path], base_path: Optional[Path] = None + self, + contract_filepaths: Iterable[Path], + project: Optional[ProjectManager] = None, ) -> dict[str, list[str]]: - base_path = (base_path or self.project_manager.contracts_folder).absolute() - import_map = {} + pm = project or self.local_project + import_map: defaultdict = defaultdict(list) + dependencies = self.get_dependencies(project=pm) for path in contract_filepaths: content = path.read_text().splitlines() - source_id = str(get_relative_path(path.absolute(), base_path.absolute())) + source_id = str(get_relative_path(path.absolute(), pm.path.absolute())) for line in content: if line.startswith("import "): import_line_parts = line.replace("import ", "").split(" ") @@ -258,15 +287,42 @@ def get_imports( # NOTE: Defaults to JSON (assuming from input JSON or a local JSON), # unless a Vyper file exists. - ext = "vy" if (base_path / f"{suffix}.vy").is_file() else "json" + import_source_id = None + if (pm.interfaces_folder.parent / f"{suffix}.vy").is_file(): + import_source_id = f"{suffix}.vy" + + elif (pm.interfaces_folder.parent / f"{suffix}.json").is_file(): + import_source_id = f"{suffix}.json" + + elif suffix.startswith(f"vyper{os.path.sep}"): + # Vyper built-ins. + import_source_id = f"{suffix}.json" + + elif suffix.split(os.path.sep)[0] in dependencies: + dependency_name = suffix.split(os.path.sep)[0] + filestem = suffix.replace(f"{dependency_name}{os.path.sep}", "") + for version_str, dep_project in pm.dependencies[dependency_name].items(): + dependency = pm.dependencies.get_dependency(dependency_name, version_str) + path_id = dependency.package_id.replace("/", "_") + dependency_source_prefix = ( + f"{get_relative_path(dep_project.contracts_folder, dep_project.path)}" + ) + source_id_stem = f"{dependency_source_prefix}{os.path.sep}{filestem}" + for ext in (".vy", ".json"): + if f"{source_id_stem}{ext}" in dep_project.sources: + import_source_id = os.path.sep.join( + (path_id, version_str, f"{source_id_stem}{ext}") + ) + break - import_source_id = f"{suffix}.{ext}" - if source_id not in import_map: - import_map[source_id] = [import_source_id] - elif import_source_id not in import_map[source_id]: + else: + logger.error(f"Unable to find dependency {suffix}") + continue + + if import_source_id and import_source_id not in import_map[source_id]: import_map[source_id].append(import_source_id) - return import_map + return dict(import_map) def get_versions(self, all_paths: Iterable[Path]) -> set[str]: versions = set() @@ -357,62 +413,76 @@ def vyper_json(self): except ImportError: return None - @property - def config_version_pragma(self) -> Optional[SpecifierSet]: - if version := self.settings.version: - return version - - return None + def get_dependencies( + self, project: Optional[ProjectManager] = None + ) -> dict[str, ProjectManager]: + pm = project or self.local_project + config = self.get_config(pm) + dependencies: dict[str, ProjectManager] = {} + handled: set[str] = set() + + # Add remappings from config. + for remapping in config.import_remapping: + name = remapping.dependency_name + if not (_version := remapping.dependency_version): + versions = pm.dependencies[name] + if len(versions) == 1: + _version = versions[0] + else: + continue - @property - def remapped_manifests(self) -> dict[str, PackageManifest]: - """ - Interface import manifests. - """ + dependency = pm.dependencies.get_dependency(name, _version) + dep_id = f"{dependency.name}_{dependency.version}" + if dep_id in handled: + continue - dependencies: dict[str, PackageManifest] = {} + handled.add(dep_id) - for remapping in self.settings.import_remapping: - key, value = remapping.split("=") + try: + dependency.compile() + except Exception as err: + logger.warning( + f"Failed to compile dependency '{dependency.name}' @ '{dependency.version}'.\n" + f"Reason: {err}" + ) + continue - if remapping in dependencies: - dependency = dependencies[remapping] - else: - parts = value.split("@") - dep_name = parts[0] - dependency_versions = self.project_manager.dependencies[dep_name] - if not dependency_versions: - raise VyperCompileError(f"Missing dependency '{dep_name}'.") + dependencies[remapping.key] = dependency.project - elif len(parts) == 1 and len(dependency_versions) < 2: - # Use only version. - version = list(dependency_versions.keys())[0] + # Add auto-remapped dependencies. + for dependency in pm.dependencies.specified: + dep_id = f"{dependency.name}_{dependency.version}" + if dep_id in handled: + continue - elif parts[1] not in dependency_versions: - raise VyperCompileError(f"Missing dependency '{dep_name}'.") + handled.add(dep_id) - else: - version = parts[1] + try: + dependency.compile() + except Exception as err: + logger.warning( + f"Failed to compile dependency '{dependency.name}' @ '{dependency.version}'.\n" + f"Reason: {err}" + ) + continue - dependency = dependency_versions[version].compile() - dependencies[remapping] = dependency + dependencies[dependency.name] = dependency.project return dependencies - @property - def import_remapping(self) -> dict[str, dict]: + def get_import_remapping(self, project: Optional[ProjectManager] = None) -> dict[str, dict]: """ Configured interface imports from dependencies. """ - - interfaces = {} - - for remapping in self.settings.import_remapping: - key, _ = remapping.split("=") - for name, ct in (self.remapped_manifests[remapping].contract_types or {}).items(): - interfaces[f"{key}/{name}.json"] = { - "abi": [x.model_dump(mode="json", by_alias=True) for x in ct.abi] - } + pm = project or self.local_project + dependencies = self.get_dependencies(project=pm) + interfaces: dict[str, dict] = {} + for key, dependency_project in dependencies.items(): + manifest = dependency_project.manifest + for name, ct in (manifest.contract_types or {}).items(): + filename = f"{key}/{name}.json" + abi_list = [x.model_dump(mode="json", by_alias=True) for x in ct.abi] + interfaces[filename] = {"abi": abi_list} return interfaces @@ -424,36 +494,55 @@ def classify_ast(self, _node: ASTNode): self.classify_ast(child) def compile( - self, contract_filepaths: Iterable[Path], base_path: Optional[Path] = None + self, + contract_filepaths: Iterable[Path], + project: Optional[ProjectManager] = None, + settings: Optional[dict] = None, ) -> Iterator[ContractType]: - contract_types = [] - base_path = base_path or self.config_manager.contracts_folder + pm = project or self.local_project + settings = settings or {} sources = [p for p in contract_filepaths if p.parent.name != "interfaces"] - version_map = self.get_version_map(sources) - compiler_data = self._get_compiler_arguments(version_map, base_path) - all_settings = self.get_compiler_settings(sources, base_path=base_path) + contract_types: list[ContractType] = [] + if version := settings.get("version", None): + version_map = {Version(version): set(sources)} + else: + version_map = self.get_version_map(sources, project=project) + + compiler_data = self._get_compiler_arguments(version_map, project=pm) + all_settings: dict = self.get_compiler_settings( + sources, project=project, **(settings or {}) + ) contract_versions: dict[str, tuple[Version, str]] = {} + import_remapping = self.get_import_remapping(project=pm) for vyper_version, version_settings in all_settings.items(): - for settings_key, settings in version_settings.items(): - source_ids = settings["outputSelection"] - optimization_paths = {p: base_path / p for p in source_ids} - input_json = { + for settings_key, settings_set in version_settings.items(): + source_ids = settings_set["outputSelection"] + optimization_paths = {p: pm.path / p for p in source_ids} + input_json: dict = { "language": "Vyper", - "settings": settings, + "settings": settings_set, "sources": { s: {"content": p.read_text()} for s, p in optimization_paths.items() }, } - if interfaces := self.import_remapping: + if interfaces := import_remapping: input_json["interfaces"] = interfaces + # Output compiler details. + keys = ( + "\n\t".join(sorted([x for x in input_json.get("sources", {}).keys()])) + or "No input." + ) + log_str = f"Compiling using Vyper compiler '{vyper_version}'.\nInput:\n\t{keys}" + logger.info(log_str) + vyper_binary = compiler_data[vyper_version]["vyper_binary"] try: result = vvm_compile_standard( input_json, - base_path=base_path, + base_path=pm.path, vyper_version=vyper_version, vyper_binary=vyper_binary, ) @@ -463,7 +552,7 @@ def compile( for source_id, output_items in result["contracts"].items(): content = { i + 1: ln - for i, ln in enumerate((base_path / source_id).read_text().splitlines()) + for i, ln in enumerate((pm.path / source_id).read_text().splitlines()) } for name, output in output_items.items(): # De-compress source map to get PC POS map. @@ -529,13 +618,14 @@ def compile( if ct_version not in compilers_used: compilers_used[ct_version] = {} + contract_id = f"{ct.source_id}:{ct.name}" if ct_settings_key in compilers_used[ct_version] and ct.name not in ( compilers_used[ct_version][ct_settings_key].contractTypes or [] ): # Add contractType to already-tracked compiler. compilers_used[ct_version][ct_settings_key].contractTypes = [ *(compilers_used[ct_version][ct_settings_key].contractTypes or []), - ct.name, + contract_id, ] elif ct_settings_key not in compilers_used[ct_version]: @@ -543,7 +633,7 @@ def compile( compilers_used[ct_version][ct_settings_key] = Compiler( name=self.name.lower(), version=f"{ct_version}", - contractTypes=[ct.name], + contractTypes=[contract_id], settings=settings, ) @@ -556,10 +646,12 @@ def compile( # NOTE: This method handles merging contractTypes and filtered out # no longer used Compilers. - self.project_manager.local_project.add_compiler_data(compilers_ls) + pm.add_compiler_data(compilers_ls) - def compile_code(self, code: str, base_path: Optional[Path] = None, **kwargs) -> ContractType: - base_path = base_path or self.project_manager.contracts_folder + def compile_code( + self, code: str, project: Optional[ProjectManager] = None, **kwargs + ) -> ContractType: + pm = project or self.local_project # Figure out what compiler version we need for this contract... version = self._source_vyper_version(code) @@ -567,7 +659,7 @@ def compile_code(self, code: str, base_path: Optional[Path] = None, **kwargs) -> _install_vyper(version) try: - result = vvm.compile_source(code, base_path=base_path, vyper_version=version) + result = vvm.compile_source(code, base_path=pm.path, vyper_version=version) except Exception as err: raise VyperCompileError(str(err)) from err @@ -587,42 +679,43 @@ def first_full_release(versions: Iterable[Version]) -> Optional[Version]: for vers in versions: if not vers.is_devrelease and not vers.is_postrelease and not vers.is_prerelease: return vers + return None if version_spec is None: if version := first_full_release(self.installed_versions + self.available_versions): return version + raise VyperInstallError("No available version.") return next(version_spec.filter(self.available_versions)) - def _flatten_source( - self, path: Path, base_path: Optional[Path] = None, raw_import_name: Optional[str] = None - ) -> str: - base_path = base_path or self.config_manager.contracts_folder + def _flatten_source(self, path: Path, project: Optional[ProjectManager] = None) -> str: + pm = project or self.local_project # Get the non stdlib import paths for our contracts imports = list( filter( lambda x: not x.startswith("vyper/"), - [y for x in self.get_imports([path], base_path).values() for y in x], + [y for x in self.get_imports((path,), project=pm).values() for y in x], ) ) dependencies: dict[str, PackageManifest] = {} - for key, manifest in self.remapped_manifests.items(): + dependency_projects = self.get_dependencies(project=pm) + for key, dependency_project in dependency_projects.items(): package = key.split("=")[0] - + base = dependency_project.path if hasattr(dependency_project, "path") else package + manifest = dependency_project.manifest if manifest.sources is None: continue for source_id in manifest.sources.keys(): - import_match = f"{package}/{source_id}" + import_match = f"{base}/{source_id}" dependencies[import_match] = manifest - flattened_source = "" interfaces_source = "" - og_source = (base_path / path).read_text() + og_source = (pm.path / path).read_text() # Get info about imports and source meta aliases = extract_import_aliases(og_source) @@ -630,20 +723,27 @@ def _flatten_source( stdlib_imports, _, source_without_imports = extract_imports(source_without_meta) for import_path in sorted(imports): - import_file = base_path / import_path + import_file = None + for base in (pm.path, pm.interfaces_folder): + for opt in {import_path, import_path.replace(f"interfaces{os.path.sep}", "")}: + try_import_file = base / opt + if try_import_file.is_file(): + import_file = try_import_file + break + + if import_file is None: + import_file = pm.path / import_path # Vyper imported interface names come from their file names file_name = iface_name_from_file(import_file) # If we have a known alias, ("import X as Y"), use the alias as interface name iface_name = aliases[file_name] if file_name in aliases else file_name - # We need to compare without extensions because sometimes they're made up for some - # reason. TODO: Cleaner way to deal with this? - def _match_source(import_path: str) -> Optional[PackageManifest]: - import_path_name = ".".join(import_path.split(".")[:-1]) + def _match_source(imp_path: str) -> Optional[PackageManifest]: for source_path in dependencies.keys(): - if source_path.startswith(import_path_name): + if source_path.endswith(imp_path): return dependencies[source_path] + return None if matched_source := _match_source(import_path): @@ -658,11 +758,10 @@ def _match_source(import_path: str) -> Optional[PackageManifest]: interfaces_source += generate_interface(abis, iface_name) continue - # Vyper imported interface names come from their file names - file_name = iface_name_from_file(import_file) # Generate an ABI from the source code - abis = source_to_abi(import_file.read_text()) - interfaces_source += generate_interface(abis, iface_name) + elif import_file.is_file(): + abis = source_to_abi(import_file.read_text()) + interfaces_source += generate_interface(abis, iface_name) def no_nones(it: Iterable[Optional[str]]) -> Iterable[str]: # Type guard like generator to remove Nones and make mypy happy @@ -683,23 +782,33 @@ def format_source(source: str) -> str: return format_source(flattened_source) - def flatten_contract(self, path: Path, base_path: Optional[Path] = None) -> Content: + def flatten_contract( + self, + path: Path, + project: Optional[ProjectManager] = None, + **kwargs, + ) -> Content: """ Returns the flattened contract suitable for compilation or verification as a single file """ - source = self._flatten_source(path, base_path, path.name) - return Content({i: ln for i, ln in enumerate(source.splitlines())}) + pm = project or self.local_project + src = self._flatten_source(path, project=pm) + return Content({i: ln for i, ln in enumerate(src.splitlines())}) def get_version_map( - self, contract_filepaths: Iterable[Path], base_path: Optional[Path] = None + self, + contract_filepaths: Iterable[Path], + project: Optional[ProjectManager] = None, ) -> dict[Version, set[Path]]: + pm = project or self.local_project + config = self.get_config(pm) version_map: dict[Version, set[Path]] = {} source_path_by_version_spec: dict[SpecifierSet, set[Path]] = {} source_paths_without_pragma = set() # Sort contract_filepaths to promote consistent, reproduce-able behavior for path in sorted(contract_filepaths): - if config_spec := self.config_version_pragma: + if config_spec := config.version: _safe_append(source_path_by_version_spec, config_spec, path) elif pragma := get_version_pragma_spec(path): _safe_append(source_path_by_version_spec, pragma, path) @@ -755,15 +864,22 @@ def get_version_map( return version_map def get_compiler_settings( - self, contract_filepaths: Iterable[Path], base_path: Optional[Path] = None + self, + contract_filepaths: Iterable[Path], + project: Optional[ProjectManager] = None, + **kwargs, ) -> dict[Version, dict]: - valid_paths = [p for p in contract_filepaths if get_full_extension(p) == ".vy"] - contracts_path = base_path or self.config_manager.contracts_folder - files_by_vyper_version = self.get_version_map(valid_paths, base_path=contracts_path) + pm = project or self.local_project + valid_paths = [p for p in contract_filepaths if p.suffix == ".vy"] + if version := kwargs.pop("version", None): + files_by_vyper_version = {Version(version): set(valid_paths)} + else: + files_by_vyper_version = self.get_version_map(valid_paths, project=pm) + if not files_by_vyper_version: return {} - compiler_data = self._get_compiler_arguments(files_by_vyper_version, contracts_path) + compiler_data = self._get_compiler_arguments(files_by_vyper_version, project=pm) settings = {} for version, data in compiler_data.items(): source_paths = list(files_by_vyper_version.get(version, [])) @@ -771,15 +887,13 @@ def get_compiler_settings( continue output_selection: dict[str, set[str]] = {} - optimizations_map = get_optimization_pragma_map(source_paths, contracts_path) - evm_version_map = get_evm_version_pragma_map(source_paths, contracts_path) - default_evm_version = ( - data.get("evm_version") - or data.get("evmVersion") - or EVM_VERSION_DEFAULT.get(version.base_version) - ) + optimizations_map = get_optimization_pragma_map(source_paths, pm.path) + evm_version_map = get_evm_version_pragma_map(source_paths, pm.path) + default_evm_version = data.get( + "evm_version", data.get("evmVersion") + ) or EVM_VERSION_DEFAULT.get(version.base_version) for source_path in source_paths: - source_id = str(get_relative_path(source_path.absolute(), contracts_path)) + source_id = str(get_relative_path(source_path.absolute(), pm.path)) optimization = optimizations_map.get(source_id, True) evm_version = evm_version_map.get(source_id, default_evm_version) settings_key = f"{optimization}%{evm_version}".lower() @@ -985,15 +1099,18 @@ def _profile(_name: str, _full_name: str): # Auto-getter found. Profile function without statements. contract_coverage.include(method.name, method.selector) - def _get_compiler_arguments(self, version_map: dict, base_path: Path) -> dict[Version, dict]: - base_path = base_path or self.project_manager.contracts_folder + def _get_compiler_arguments( + self, version_map: dict, project: Optional[ProjectManager] = None + ) -> dict[Version, dict]: + pm = project or self.local_project + config = self.get_config(pm) + evm_version = config.evm_version arguments_map = {} for vyper_version, source_paths in version_map.items(): bin_arg = self._get_vyper_bin(vyper_version) arguments_map[vyper_version] = { - "base_path": str(base_path), - "evm_version": self.evm_version - or EVM_VERSION_DEFAULT.get(vyper_version.base_version), + "base_path": f"{pm.path}", + "evm_version": evm_version, "vyper_version": str(vyper_version), "vyper_binary": bin_arg, } @@ -1275,7 +1392,7 @@ def _create_contract_from_call(self, frame: dict) -> tuple[Optional[ContractSour return None, calldata called_contract = self.chain_manager.contracts[address] - return self.project_manager._create_contract_source(called_contract), calldata + return self.local_project._create_contract_source(called_contract), calldata def _safe_append(data: dict, version: Union[Version, SpecifierSet], paths: Union[Path, set]): diff --git a/tests/ape-config.yaml b/tests/ape-config.yaml index ec6f7227..bd35d1e3 100644 --- a/tests/ape-config.yaml +++ b/tests/ape-config.yaml @@ -3,10 +3,5 @@ contracts_folder: contracts/passing_contracts # Specify a dependency to use in Vyper imports. dependencies: - - name: ExampleDependency + - name: exampledependency local: ./ExampleDependency - -vyper: - # Allows importing dependencies. - import_remapping: - - "exampledep=ExampleDependency" diff --git a/tests/conftest.py b/tests/conftest.py index 0f20707c..0d33c100 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import os import shutil +import tempfile from contextlib import contextmanager from pathlib import Path from tempfile import mkdtemp @@ -9,12 +10,6 @@ import vvm # type: ignore from click.testing import CliRunner -# NOTE: Ensure that we don't use local paths for these -DATA_FOLDER = Path(mkdtemp()).resolve() -PROJECT_FOLDER = Path(mkdtemp()).resolve() -ape.config.DATA_FOLDER = DATA_FOLDER -ape.config.PROJECT_FOLDER = PROJECT_FOLDER - BASE_CONTRACTS_PATH = Path(__file__).parent / "contracts" TEMPLATES_PATH = BASE_CONTRACTS_PATH / "templates" FAILING_BASE = BASE_CONTRACTS_PATH / "failing_contracts" @@ -45,6 +40,31 @@ } +@pytest.fixture(scope="session", autouse=True) +def from_tests_dir(): + # Makes default project correct. + here = Path(__file__).parent + orig = Path.cwd() + if orig != here: + os.chdir(f"{here}") + + yield + + if Path.cwd() != orig: + os.chdir(f"{orig}") + + +@pytest.fixture(scope="session", autouse=True) +def config(): + cfg = ape.config + + # Ensure we don't persist any .ape data. + with tempfile.TemporaryDirectory() as temp_dir: + path = Path(temp_dir).resolve() + cfg.DATA_FOLDER = path + yield cfg + + def contract_test_cases(passing: bool) -> list[str]: """ Returns test-case names for outputting nicely with pytest. @@ -125,16 +145,6 @@ def temp_vvm_path(monkeypatch): yield path -@pytest.fixture -def data_folder(): - return DATA_FOLDER - - -@pytest.fixture -def project_folder(): - return PROJECT_FOLDER - - @pytest.fixture def compiler_manager(): return ape.compilers @@ -145,26 +155,17 @@ def compiler(compiler_manager): return compiler_manager.vyper -@pytest.fixture -def config(): - return ape.config - - -@pytest.fixture(autouse=True) -def project(config, project_folder): +@pytest.fixture(scope="session", autouse=True) +def project(config): project_source_dir = Path(__file__).parent - project_dest_dir = project_folder / project_source_dir.name - shutil.rmtree(project_dest_dir, ignore_errors=True) # Delete build / .cache that may exist pre-copy - project_path = Path(__file__).parent - cache = project_path / ".build" + cache = project_source_dir / ".build" shutil.rmtree(cache, ignore_errors=True) - shutil.copytree(project_source_dir, project_dest_dir, dirs_exist_ok=True) - with config.using_project(project_dest_dir) as project: - yield project - shutil.rmtree(project.local_project._cache_folder, ignore_errors=True) + root_project = ape.Project(project_source_dir) + with root_project.isolate_in_tempdir() as tmp_project: + yield tmp_project @pytest.fixture diff --git a/tests/contracts/passing_contracts/flatten_me.vy b/tests/contracts/passing_contracts/flatten_me.vy index 456e2457..77b993d2 100644 --- a/tests/contracts/passing_contracts/flatten_me.vy +++ b/tests/contracts/passing_contracts/flatten_me.vy @@ -4,7 +4,7 @@ from vyper.interfaces import ERC20 from interfaces import IFace2 as IFaceTwo import interfaces.IFace as IFace -import exampledep.Dependency as Dep +import exampledependency.Dependency as Dep @external diff --git a/tests/contracts/passing_contracts/use_iface.vy b/tests/contracts/passing_contracts/use_iface.vy index ac136ddc..38f99932 100644 --- a/tests/contracts/passing_contracts/use_iface.vy +++ b/tests/contracts/passing_contracts/use_iface.vy @@ -4,7 +4,7 @@ import interfaces.IFace as IFace # Import from input JSON (ape-config.yaml). -import exampledep.Dependency as Dep +import exampledependency.Dependency as Dep from interfaces import IFace2 as IFace2 diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 9a569c67..22585b9d 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -1,8 +1,10 @@ import re +from pathlib import Path +import ape import pytest import vvm # type: ignore -from ape.exceptions import ContractLogicError +from ape.exceptions import CompilerError, ContractLogicError from ethpm_types import ContractType from packaging.version import Version from vvm.exceptions import VyperError # type: ignore @@ -48,39 +50,42 @@ def test_compile_project(project): assert len(contracts) == len( [p.name for p in project.contracts_folder.glob("*.vy") if p.is_file()] ) - assert contracts["contract_039"].source_id == "contract_039.vy" - assert contracts["contract_no_pragma"].source_id == "contract_no_pragma.vy" - assert contracts["older_version"].source_id == "older_version.vy" + prefix = "contracts/passing_contracts" + assert contracts["contract_039"].source_id == f"{prefix}/contract_039.vy" + assert contracts["contract_no_pragma"].source_id == f"{prefix}/contract_no_pragma.vy" + assert contracts["older_version"].source_id == f"{prefix}/older_version.vy" @pytest.mark.parametrize("contract_name", PASSING_CONTRACT_NAMES) def test_compile_individual_contracts(project, contract_name, compiler): path = project.contracts_folder / contract_name - assert list(compiler.compile([path])) + assert list(compiler.compile((path,), project=project)) @pytest.mark.parametrize( "contract_name", [n for n in FAILING_CONTRACT_NAMES if n != "contract_unknown_pragma.vy"] ) def test_compile_failures(contract_name, compiler): + failing_project = ape.Project(FAILING_BASE) path = FAILING_BASE / contract_name with pytest.raises(VyperCompileError, match=EXPECTED_FAIL_PATTERNS[path.stem]) as err: - list(compiler.compile([path], base_path=FAILING_BASE)) + list(compiler.compile((path,), project=failing_project)) assert isinstance(err.value.base_err, VyperError) def test_install_failure(compiler): + failing_project = ape.Project(FAILING_BASE) path = FAILING_BASE / "contract_unknown_pragma.vy" with pytest.raises(VyperInstallError, match="No available version to install."): - list(compiler.compile([path])) + list(compiler.compile((path,), project=failing_project)) def test_get_version_map(project, compiler, all_versions): vyper_files = [ x for x in project.contracts_folder.iterdir() if x.is_file() and x.suffix == ".vy" ] - actual = compiler.get_version_map(vyper_files) + actual = compiler.get_version_map(vyper_files, project=project) expected_versions = [Version(v) for v in all_versions] for version, sources in actual.items(): @@ -157,23 +162,30 @@ def run_test(manifest): assert true_latest.settings["evmVersion"] == "shanghai" # There is only one contract with codesize pragma. - assert codesize_latest.contractTypes == ["optimize_codesize"] + assert codesize_latest.contractTypes == [ + "contracts/passing_contracts/optimize_codesize.vy:optimize_codesize" + ] assert codesize_latest.settings["optimize"] == "codesize" # There is only one contract with evm-version pragma. - assert evm_latest.contractTypes == ["evm_pragma"] + assert evm_latest.contractTypes == ["contracts/passing_contracts/evm_pragma.vy:evm_pragma"] assert evm_latest.settings["evmVersion"] == "paris" assert len(true_latest.contractTypes) >= 9 assert len(vyper_028.contractTypes) >= 1 - assert "contract_0310" in true_latest.contractTypes - assert "older_version" in vyper_028.contractTypes + assert ( + "contracts/passing_contracts/contract_0310.vy:contract_0310" + in true_latest.contractTypes + ) + assert ( + "contracts/passing_contracts/older_version.vy:older_version" in vyper_028.contractTypes + ) for compiler in (true_latest, vyper_028): assert compiler.settings["optimize"] is True - project.local_project.update_manifest(compilers=[]) + project.update_manifest(compilers=[]) project.load_contracts(use_cache=False) - run_test(project.local_project.manifest) + run_test(project.manifest) man = project.extract_manifest() run_test(man) @@ -185,7 +197,7 @@ def test_compile_parse_dev_messages(compiler, dev_revert_source, project): The compiler will output a map that maps dev messages to line numbers. See contract_with_dev_messages.vy for more information. """ - result = list(compiler.compile([dev_revert_source], base_path=project.contracts_folder)) + result = list(compiler.compile((dev_revert_source,), project=project)) assert len(result) == 1 @@ -204,18 +216,22 @@ def test_get_imports(compiler, project): vyper_files = [ x for x in project.contracts_folder.iterdir() if x.is_file() and x.suffix == ".vy" ] - actual = compiler.get_imports(vyper_files) + actual = compiler.get_imports(vyper_files, project=project) + prefix = "contracts/passing_contracts" builtin_import = "vyper/interfaces/ERC20.json" local_import = "interfaces/IFace.vy" local_from_import = "interfaces/IFace2.vy" - dependency_import = "exampledep/Dependency.json" - - assert len(actual["contract_037.vy"]) == 1 - assert set(actual["contract_037.vy"]) == {builtin_import} - assert len(actual["use_iface.vy"]) == 3 - assert set(actual["use_iface.vy"]) == {local_import, local_from_import, dependency_import} - assert len(actual["use_iface2.vy"]) == 1 - assert set(actual["use_iface2.vy"]) == {local_import} + dep_key = project.dependencies.get_dependency("exampledependency", "local").package_id.replace( + "/", "_" + ) + dependency_import = f"{dep_key}/local/contracts/Dependency.vy" + assert set(actual[f"{prefix}/contract_037.vy"]) == {builtin_import} + assert set(actual[f"{prefix}/use_iface.vy"]) == { + local_import, + local_from_import, + dependency_import, + } + assert set(actual[f"{prefix}/use_iface2.vy"]) == {local_import} @pytest.mark.parametrize("src,vers", [("contract_039", "0.3.9"), ("contract_037", "0.3.7")]) @@ -225,15 +241,16 @@ def test_pc_map(compiler, project, src, vers): from `compile_src()` which includes the uncompressed source map data. """ - path = project.contracts_folder / f"{src}.vy" - result = list(compiler.compile([path], base_path=project.contracts_folder))[0] + path = project.sources.lookup(src) + result = list(compiler.compile((path,), project=project))[0] actual = result.pcmap.root code = path.read_text() vvm.install_vyper(vers) - compile_result = vvm.compile_source(code, vyper_version=vers, evm_version=compiler.evm_version)[ - "" - ] - src_map = compile_result["source_map"] + cfg = compiler.get_config(project=project) + evm_version = cfg.evm_version + compile_result = vvm.compile_source(code, vyper_version=vers, evm_version=evm_version) + std_result = compile_result[""] + src_map = std_result["source_map"] lines = code.splitlines() # Use the old-fashioned way of gathering PCMap to ensure our creative way works @@ -389,7 +406,7 @@ def test_enrich_error_handle_when_name(compiler, geth_provider, mocker): def test_trace_source(account, geth_provider, project, traceback_contract, arguments): receipt = traceback_contract.addBalance(*arguments, sender=account) actual = receipt.source_traceback - base_folder = project.contracts_folder + base_folder = Path(__file__).parent / "contracts" / "passing_contracts" contract_name = traceback_contract.contract_type.name expected = rf""" Traceback (most recent call last) @@ -437,7 +454,7 @@ def test_trace_err_source(account, geth_provider, project, traceback_contract): receipt = geth_provider.get_receipt(txn.txn_hash.hex()) actual = receipt.source_traceback - base_folder = project.contracts_folder + base_folder = Path(__file__).parent / "contracts" / "passing_contracts" contract_name = traceback_contract.contract_type.name version_key = contract_name.split("traceback_contract_")[-1] expected = rf""" @@ -478,19 +495,20 @@ def test_compile_with_version_set_in_config(config, projects_path, compiler, moc path = projects_path / "version_in_config" version_from_config = "0.3.7" spy = mocker.patch("ape_vyper.compiler.vvm_compile_standard") - with config.using_project(path) as project: - contract = project.contracts_folder / "v_contract.vy" - settings = compiler.get_compiler_settings((contract,)) - assert str(list(settings.keys())[0]) == version_from_config + project = ape.Project(path) - # Show it uses this version in the compiler. - project.load_contracts(use_cache=False) - assert str(spy.call_args[1]["vyper_version"]) == version_from_config + contract = project.contracts_folder / "v_contract.vy" + settings = compiler.get_compiler_settings((contract,), project=project) + assert str(list(settings.keys())[0]) == version_from_config + + # Show it uses this version in the compiler. + project.load_contracts(use_cache=False) + assert str(spy.call_args[1]["vyper_version"]) == version_from_config -def test_compile_code(compiler, dev_revert_source): +def test_compile_code(project, compiler, dev_revert_source): code = dev_revert_source.read_text() - actual = compiler.compile_code(code, contractName="MyContract") + actual = compiler.compile_code(code, project=project, contractName="MyContract") assert isinstance(actual, ContractType) assert actual.name == "MyContract" assert len(actual.abi) > 1 @@ -501,13 +519,13 @@ def test_compile_code(compiler, dev_revert_source): def test_compile_with_version_set_in_settings_dict(config, compiler_manager, projects_path): path = projects_path / "version_in_config" contract = path / "contracts" / "v_contract.vy" - - with config.using_project(path): - expected = ( - '.*Version specification "0.3.10" is not compatible with compiler version "0.3.3"' - ) - with pytest.raises(VyperCompileError, match=expected): - list(compiler_manager.compile([contract], settings={"version": "0.3.3"})) + project = ape.Project(path) + expected = '.*Version specification "0.3.10" is not compatible with compiler version "0.3.3"' + iterator = compiler_manager.compile( + (contract,), project=project, settings={"vyper": {"version": "0.3.3"}} + ) + with pytest.raises(CompilerError, match=expected): + _ = list(iterator) @pytest.mark.parametrize( @@ -529,8 +547,8 @@ def test_compile_with_version_set_in_settings_dict(config, compiler_manager, pro ) def test_flatten_contract(all_versions, project, contract_name, compiler): path = project.contracts_folder / contract_name - source = compiler.flatten_contract(path) + source = compiler.flatten_contract(path, project=project) source_code = str(source) version = compiler._source_vyper_version(source_code) vvm.install_vyper(str(version)) - vvm.compile_source(source_code, base_path=project.contracts_folder, vyper_version=version) + vvm.compile_source(source_code, base_path=project.path, vyper_version=version)