From e8589df8507ba90a782901e79cc0b8b90f3c237a Mon Sep 17 00:00:00 2001 From: Juliya Smith Date: Wed, 10 Apr 2024 19:17:20 -0500 Subject: [PATCH] refactor: project --- .github/workflows/commitlint.yaml | 4 +- .github/workflows/prtitle.yaml | 4 +- .github/workflows/publish.yaml | 4 +- .github/workflows/test.yaml | 16 +- .pre-commit-config.yaml | 6 +- ape_vyper/compiler.py | 361 ++++++++++++------ setup.py | 6 +- tests/ape-config.yaml | 6 +- tests/conftest.py | 66 ++-- .../contracts/passing_contracts/flatten_me.vy | 2 +- .../contracts/passing_contracts/use_iface.vy | 2 +- tests/test_compiler.py | 113 +++--- 12 files changed, 358 insertions(+), 232 deletions(-) diff --git a/.github/workflows/commitlint.yaml b/.github/workflows/commitlint.yaml index 17294ec2..3b44d690 100644 --- a/.github/workflows/commitlint.yaml +++ b/.github/workflows/commitlint.yaml @@ -9,12 +9,12 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" diff --git a/.github/workflows/prtitle.yaml b/.github/workflows/prtitle.yaml index 4de42565..f523a078 100644 --- a/.github/workflows/prtitle.yaml +++ b/.github/workflows/prtitle.yaml @@ -12,10 +12,10 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 88b2e64f..2c550693 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -10,10 +10,10 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 9d961395..1b2d345c 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -17,10 +17,10 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" @@ -45,10 +45,10 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" @@ -72,10 +72,10 @@ jobs: GETH_VERSION: 1.12.0 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -130,10 +130,10 @@ jobs: # fail-fast: true # # steps: -# - uses: actions/checkout@v3 +# - uses: actions/checkout@v4 # # - name: Setup Python -# uses: actions/setup-python@v4 +# uses: actions/setup-python@v5 # with: # python-version: "3.10" # diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 28881d1f..19b9642a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,18 +10,18 @@ repos: - id: isort - repo: https://github.com/psf/black - rev: 23.12.0 + rev: 24.3.0 hooks: - id: black name: black - repo: https://github.com/pycqa/flake8 - rev: 6.1.0 + rev: 7.0.0 hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.7.1 + rev: v1.9.0 hooks: - id: mypy additional_dependencies: [types-setuptools, pydantic==1.10.4] diff --git a/ape_vyper/compiler.py b/ape_vyper/compiler.py index e7691fd0..95508923 100644 --- a/ape_vyper/compiler.py +++ b/ape_vyper/compiler.py @@ -13,6 +13,7 @@ from ape.api.compiler import CompilerAPI from ape.exceptions import ContractLogicError from ape.logging import logger +from ape.managers.project import ProjectManager from ape.types import ContractSourceCoverage, ContractType, SourceTraceback, TraceFrame from ape.utils import GithubClient, cached_property, get_relative_path, pragma_str_to_specifier_set from eth_pydantic_types import HexBytes @@ -24,7 +25,7 @@ from evm_trace.enums import CALL_OPCODES from packaging.specifiers import InvalidSpecifier, SpecifierSet from packaging.version import Version -from pydantic import field_serializer, field_validator +from pydantic import field_serializer, field_validator, model_validator from vvm import compile_standard as vvm_compile_standard from vvm.exceptions import VyperError # type: ignore @@ -53,6 +54,38 @@ Optimization = Union[str, bool] +class Remapping(PluginConfig): + key: str + dependency_name: str + dependency_version: Optional[None] = None + + @model_validator(mode="before") + @classmethod + def validate_str(cls, value): + if isinstance(value, str): + parts = value.split("=") + key = parts[0].strip() + value = parts[1].strip() + if "@" in value: + value_parts = value.split("@") + dep_name = value_parts[0].strip() + dep_version = value_parts[1].strip() + else: + dep_name = value + dep_version = None + + return {"key": key, "dependency_name": dep_name, "dependency_version": dep_version} + + return value + + def __str__(self) -> str: + value = self.dependency_name + if _version := self.dependency_version: + value = f"{value}@{_version}" + + return f"{self.key}={value}" + + class VyperConfig(PluginConfig): version: Optional[SpecifierSet] = None """ @@ -65,7 +98,7 @@ class VyperConfig(PluginConfig): The evm-version or hard-fork name. """ - import_remapping: List[str] = [] + import_remapping: List[Remapping] = [] """ Configuration of an import name mapped to a dependency listing. To use a specific version of a dependency, specify using ``@`` symbol. @@ -202,22 +235,17 @@ class VyperCompiler(CompilerAPI): def name(self) -> str: return "vyper" - @property - def settings(self) -> VyperConfig: - return cast(VyperConfig, super().settings) - - @property - def evm_version(self) -> Optional[str]: - return self.settings.evm_version - def get_imports( - self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None + self, + contract_filepaths: Sequence[Path], + project: Optional[ProjectManager] = None, ) -> Dict[str, List[str]]: - base_path = (base_path or self.project_manager.contracts_folder).absolute() + pm = project or self.project_manager import_map = {} + dependencies = self.get_dependencies(project=pm) for path in contract_filepaths: content = path.read_text().splitlines() - source_id = str(get_relative_path(path.absolute(), base_path.absolute())) + source_id = str(get_relative_path(path.absolute(), pm.path.absolute())) for line in content: if line.startswith("import "): import_line_parts = line.replace("import ", "").split(" ") @@ -234,12 +262,40 @@ def get_imports( # NOTE: Defaults to JSON (assuming from input JSON or a local JSON), # unless a Vyper file exists. - ext = "vy" if (base_path / f"{suffix}.vy").is_file() else "json" + import_source_id = None + if (pm.interfaces_folder.parent / f"{suffix}.vy").is_file(): + import_source_id = f"{suffix}.vy" + + elif (pm.interfaces_folder.parent / f"{suffix}.json").is_file(): + import_source_id = f"{suffix}.json" + + elif suffix.startswith(f"vyper{os.path.sep}"): + # Vyper built-ins. + import_source_id = f"{suffix}.json" + + elif suffix.split(os.path.sep)[0] in dependencies: + dependency_name = suffix.split(os.path.sep)[0] + filestem = suffix.replace(f"{dependency_name}{os.path.sep}", "") + for version_str, version_installed in pm.dependencies[dependency_name].items(): + dep_project = version_installed.project + dependency_source_prefix = ( + f"{get_relative_path(dep_project.contracts_folder, dep_project.path)}" + ) + source_id_stem = f"{dependency_source_prefix}{os.path.sep}{filestem}" + for ext in (".vy", ".json"): + if f"{source_id_stem}{ext}" in dep_project.sources: + import_source_id = os.path.sep.join( + (dependency_name, version_str, f"{source_id_stem}{ext}") + ) + break + + else: + logger.error(f"Unable to find dependency {suffix}") + continue - import_source_id = f"{suffix}.{ext}" - if source_id not in import_map: + if import_source_id and source_id not in import_map: import_map[source_id] = [import_source_id] - elif import_source_id not in import_map[source_id]: + elif import_source_id and import_source_id not in import_map[source_id]: import_map[source_id].append(import_source_id) return import_map @@ -333,62 +389,74 @@ def vyper_json(self): except ImportError: return None - @property - def config_version_pragma(self) -> Optional[SpecifierSet]: - if version := self.settings.version: - return version - - return None + def get_dependencies( + self, project: Optional[ProjectManager] = None + ) -> Dict[str, ProjectManager]: + pm = project or self.project_manager + config = self.get_config(pm) + dependencies: Dict[str, ProjectManager] = {} + handled: Set[str] = set() + + # Add remappings from config. + for remapping in config.import_remapping: + name = remapping.dependency_name + if not (_version := remapping.dependency_version): + versions = pm.dependencies[name] + if len(versions) == 1: + _version = versions[0] + else: + continue - @property - def remapped_manifests(self) -> Dict[str, PackageManifest]: - """ - Interface import manifests. - """ + dependency = pm.dependencies.get_dependency(name, _version) + dep_id = f"{dependency.name}_{dependency.version}" + if dep_id in handled: + continue - dependencies: Dict[str, PackageManifest] = {} + handled.add(dep_id) - for remapping in self.settings.import_remapping: - key, value = remapping.split("=") + try: + dependency.compile() + except Exception: + logger.warning( + f"Failed to compile dependency '{dependency.name}' @ '{dependency.version}'." + ) + continue - if remapping in dependencies: - dependency = dependencies[remapping] - else: - parts = value.split("@") - dep_name = parts[0] - dependency_versions = self.project_manager.dependencies[dep_name] - if not dependency_versions: - raise VyperCompileError(f"Missing dependency '{dep_name}'.") + dependencies[remapping.key] = dependency.project - elif len(parts) == 1 and len(dependency_versions) < 2: - # Use only version. - version = list(dependency_versions.keys())[0] + # Add auto-remapped dependencies. + for dependency in pm.dependencies.specified: + dep_id = f"{dependency.name}_{dependency.version}" + if dep_id in handled: + continue - elif parts[1] not in dependency_versions: - raise VyperCompileError(f"Missing dependency '{dep_name}'.") + handled.add(dep_id) - else: - version = parts[1] + try: + dependency.compile() + except Exception: + logger.warning( + f"Failed to compile dependency '{dependency.name}' @ '{dependency.version}'." + ) + continue - dependency = dependency_versions[version].compile() - dependencies[remapping] = dependency + dependencies[dependency.name] = dependency.project return dependencies - @property - def import_remapping(self) -> Dict[str, Dict]: + def get_import_remapping(self, project: Optional[ProjectManager] = None) -> Dict[str, Dict]: """ Configured interface imports from dependencies. """ - - interfaces = {} - - for remapping in self.settings.import_remapping: - key, _ = remapping.split("=") - for name, ct in (self.remapped_manifests[remapping].contract_types or {}).items(): - interfaces[f"{key}/{name}.json"] = { - "abi": [x.model_dump(mode="json", by_alias=True) for x in ct.abi] - } + pm = project or self.project_manager + dependencies = self.get_dependencies(project=pm) + interfaces: Dict[str, Dict] = {} + for key, dependency_project in dependencies.items(): + manifest = dependency_project.manifest + for name, ct in (manifest.contract_types or {}).items(): + filename = f"{key}/{name}.json" + abi_list = [x.model_dump(mode="json", by_alias=True) for x in ct.abi] + interfaces[filename] = {"abi": abi_list} return interfaces @@ -400,36 +468,55 @@ def classify_ast(self, _node: ASTNode): self.classify_ast(child) def compile( - self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None - ) -> List[ContractType]: - contract_types = [] - base_path = base_path or self.config_manager.contracts_folder + self, + contract_filepaths: Sequence[Path], + project: Optional[ProjectManager] = None, + settings: Optional[Dict] = None, + ) -> Iterator[ContractType]: + pm = project or self.project_manager + settings = settings or {} sources = [p for p in contract_filepaths if p.parent.name != "interfaces"] - version_map = self.get_version_map(sources) - compiler_data = self._get_compiler_arguments(version_map, base_path) - all_settings = self.get_compiler_settings(sources, base_path=base_path) + contract_types: List[ContractType] = [] + if version := settings.get("version", None): + version_map = {Version(version): set(sources)} + else: + version_map = self.get_version_map(sources, project=project) + + compiler_data = self._get_compiler_arguments(version_map, project=pm) + all_settings: Dict = self.get_compiler_settings( + sources, project=project, **(settings or {}) + ) contract_versions: Dict[str, Tuple[Version, str]] = {} + import_remapping = self.get_import_remapping(project=pm) for vyper_version, version_settings in all_settings.items(): - for settings_key, settings in version_settings.items(): - source_ids = settings["outputSelection"] - optimization_paths = {p: base_path / p for p in source_ids} - input_json = { + for settings_key, settings_set in version_settings.items(): + source_ids = settings_set["outputSelection"] + optimization_paths = {p: pm.path / p for p in source_ids} + input_json: Dict = { "language": "Vyper", - "settings": settings, + "settings": settings_set, "sources": { s: {"content": p.read_text()} for s, p in optimization_paths.items() }, } - if interfaces := self.import_remapping: + if interfaces := import_remapping: input_json["interfaces"] = interfaces + # Output compiler details. + keys = ( + "\n\t".join(sorted([x for x in input_json.get("sources", {}).keys()])) + or "No input." + ) + log_str = f"Compiling using Vyper compiler '{vyper_version}'.\nInput:\n\t{keys}" + logger.info(log_str) + vyper_binary = compiler_data[vyper_version]["vyper_binary"] try: result = vvm_compile_standard( input_json, - base_path=base_path, + base_path=pm.path, vyper_version=vyper_version, vyper_binary=vyper_binary, ) @@ -439,7 +526,7 @@ def compile( for source_id, output_items in result["contracts"].items(): content = { i + 1: ln - for i, ln in enumerate((base_path / source_id).read_text().splitlines()) + for i, ln in enumerate((pm.path / source_id).read_text().splitlines()) } for name, output in output_items.items(): # De-compress source map to get PC POS map. @@ -489,6 +576,7 @@ def compile( dev_messages=dev_messages, ) contract_types.append(contract_type) + yield contract_type contract_versions[name] = (vyper_version, settings_key) # Output compiler data used. @@ -504,13 +592,14 @@ def compile( if ct_version not in compilers_used: compilers_used[ct_version] = {} + contract_id = f"{ct.source_id}:{ct.name}" if ct_settings_key in compilers_used[ct_version] and ct.name not in ( compilers_used[ct_version][ct_settings_key].contractTypes or [] ): # Add contractType to already-tracked compiler. compilers_used[ct_version][ct_settings_key].contractTypes = [ *(compilers_used[ct_version][ct_settings_key].contractTypes or []), - ct.name, + contract_id, ] elif ct_settings_key not in compilers_used[ct_version]: @@ -518,7 +607,7 @@ def compile( compilers_used[ct_version][ct_settings_key] = Compiler( name=self.name.lower(), version=f"{ct_version}", - contractTypes=[ct.name], + contractTypes=[contract_id], settings=settings, ) @@ -531,12 +620,12 @@ def compile( # NOTE: This method handles merging contractTypes and filtered out # no longer used Compilers. - self.project_manager.local_project.add_compiler_data(compilers_ls) + pm.add_compiler_data(compilers_ls) - return contract_types - - def compile_code(self, code: str, base_path: Optional[Path] = None, **kwargs) -> ContractType: - base_path = base_path or self.project_manager.contracts_folder + def compile_code( + self, code: str, project: Optional[ProjectManager] = None, **kwargs + ) -> ContractType: + pm = project or self.project_manager # Figure out what compiler version we need for this contract... version = self._source_vyper_version(code) @@ -544,7 +633,7 @@ def compile_code(self, code: str, base_path: Optional[Path] = None, **kwargs) -> _install_vyper(version) try: - result = vvm.compile_source(code, base_path=base_path, vyper_version=version) + result = vvm.compile_source(code, base_path=pm.path, vyper_version=version) except Exception as err: raise VyperCompileError(str(err)) from err @@ -564,42 +653,43 @@ def first_full_release(versions: Iterable[Version]) -> Optional[Version]: for vers in versions: if not vers.is_devrelease and not vers.is_postrelease and not vers.is_prerelease: return vers + return None if version_spec is None: if version := first_full_release(self.installed_versions + self.available_versions): return version + raise VyperInstallError("No available version.") return next(version_spec.filter(self.available_versions)) - def _flatten_source( - self, path: Path, base_path: Optional[Path] = None, raw_import_name: Optional[str] = None - ) -> str: - base_path = base_path or self.config_manager.contracts_folder + def _flatten_source(self, path: Path, project: Optional[ProjectManager] = None) -> str: + pm = project or self.project_manager # Get the non stdlib import paths for our contracts imports = list( filter( lambda x: not x.startswith("vyper/"), - [y for x in self.get_imports([path], base_path).values() for y in x], + [y for x in self.get_imports((path,), project=pm).values() for y in x], ) ) dependencies: Dict[str, PackageManifest] = {} - for key, manifest in self.remapped_manifests.items(): + dependency_projects = self.get_dependencies(project=pm) + for key, dependency_project in dependency_projects.items(): package = key.split("=")[0] - + base = dependency_project.path if hasattr(dependency_project, "path") else package + manifest = dependency_project.manifest if manifest.sources is None: continue for source_id in manifest.sources.keys(): - import_match = f"{package}/{source_id}" + import_match = f"{base}/{source_id}" dependencies[import_match] = manifest - flattened_source = "" interfaces_source = "" - og_source = (base_path / path).read_text() + og_source = (pm.path / path).read_text() # Get info about imports and source meta aliases = extract_import_aliases(og_source) @@ -607,20 +697,27 @@ def _flatten_source( stdlib_imports, _, source_without_imports = extract_imports(source_without_meta) for import_path in sorted(imports): - import_file = base_path / import_path + import_file = None + for base in (pm.path, pm.interfaces_folder): + for opt in {import_path, import_path.replace(f"interfaces{os.path.sep}", "")}: + try_import_file = base / opt + if try_import_file.is_file(): + import_file = try_import_file + break + + if import_file is None: + import_file = pm.path / import_path # Vyper imported interface names come from their file names file_name = iface_name_from_file(import_file) # If we have a known alias, ("import X as Y"), use the alias as interface name iface_name = aliases[file_name] if file_name in aliases else file_name - # We need to compare without extensions because sometimes they're made up for some - # reason. TODO: Cleaner way to deal with this? - def _match_source(import_path: str) -> Optional[PackageManifest]: - import_path_name = ".".join(import_path.split(".")[:-1]) + def _match_source(imp_path: str) -> Optional[PackageManifest]: for source_path in dependencies.keys(): - if source_path.startswith(import_path_name): + if source_path.endswith(imp_path): return dependencies[source_path] + return None if matched_source := _match_source(import_path): @@ -635,11 +732,10 @@ def _match_source(import_path: str) -> Optional[PackageManifest]: interfaces_source += generate_interface(abis, iface_name) continue - # Vyper imported interface names come from their file names - file_name = iface_name_from_file(import_file) # Generate an ABI from the source code - abis = source_to_abi(import_file.read_text()) - interfaces_source += generate_interface(abis, iface_name) + elif import_file.is_file(): + abis = source_to_abi(import_file.read_text()) + interfaces_source += generate_interface(abis, iface_name) def no_nones(it: Iterable[Optional[str]]) -> Iterable[str]: # Type guard like generator to remove Nones and make mypy happy @@ -660,23 +756,33 @@ def format_source(source: str) -> str: return format_source(flattened_source) - def flatten_contract(self, path: Path, base_path: Optional[Path] = None) -> Content: + def flatten_contract( + self, + path: Path, + project: Optional[ProjectManager] = None, + **kwargs, + ) -> Content: """ Returns the flattened contract suitable for compilation or verification as a single file """ - source = self._flatten_source(path, base_path, path.name) + pm = project or self.project_manager + source = self._flatten_source(path, project=pm) return Content({i: ln for i, ln in enumerate(source.splitlines())}) def get_version_map( - self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None + self, + contract_filepaths: Sequence[Path], + project: Optional[ProjectManager] = None, ) -> Dict[Version, Set[Path]]: + pm = project or self.project_manager + config = self.get_config(pm) version_map: Dict[Version, Set[Path]] = {} source_path_by_version_spec: Dict[SpecifierSet, Set[Path]] = {} source_paths_without_pragma = set() # Sort contract_filepaths to promote consistent, reproduce-able behavior for path in sorted(contract_filepaths): - if config_spec := self.config_version_pragma: + if config_spec := config.version: _safe_append(source_path_by_version_spec, config_spec, path) elif pragma := get_version_pragma_spec(path): _safe_append(source_path_by_version_spec, pragma, path) @@ -730,15 +836,22 @@ def get_version_map( return version_map def get_compiler_settings( - self, contract_filepaths: Sequence[Path], base_path: Optional[Path] = None + self, + contract_filepaths: Sequence[Path], + project: Optional[ProjectManager] = None, + **kwargs, ) -> Dict[Version, Dict]: + pm = project or self.project_manager valid_paths = [p for p in contract_filepaths if p.suffix == ".vy"] - contracts_path = base_path or self.config_manager.contracts_folder - files_by_vyper_version = self.get_version_map(valid_paths, base_path=contracts_path) + if version := kwargs.pop("version", None): + files_by_vyper_version = {Version(version): set(valid_paths)} + else: + files_by_vyper_version = self.get_version_map(valid_paths, project=pm) + if not files_by_vyper_version: return {} - compiler_data = self._get_compiler_arguments(files_by_vyper_version, contracts_path) + compiler_data = self._get_compiler_arguments(files_by_vyper_version, project=pm) settings = {} for version, data in compiler_data.items(): source_paths = list(files_by_vyper_version.get(version, [])) @@ -746,11 +859,11 @@ def get_compiler_settings( continue output_selection: Dict[str, Set[str]] = {} - optimizations_map = get_optimization_pragma_map(source_paths, contracts_path) - evm_version_map = get_evm_version_pragma_map(source_paths, contracts_path) + optimizations_map = get_optimization_pragma_map(source_paths, pm.path) + evm_version_map = get_evm_version_pragma_map(source_paths, pm.path) default_evm_version = data.get("evm_version", data.get("evmVersion")) for source_path in source_paths: - source_id = str(get_relative_path(source_path.absolute(), contracts_path)) + source_id = str(get_relative_path(source_path.absolute(), pm.path)) optimization = optimizations_map.get(source_id, True) evm_version = evm_version_map.get(source_id, default_evm_version) settings_key = f"{optimization}%{evm_version}".lower() @@ -956,14 +1069,18 @@ def _profile(_name: str, _full_name: str): # Auto-getter found. Profile function without statements. contract_coverage.include(method.name, method.selector) - def _get_compiler_arguments(self, version_map: Dict, base_path: Path) -> Dict[Version, Dict]: - base_path = base_path or self.project_manager.contracts_folder + def _get_compiler_arguments( + self, version_map: Dict, project: Optional[ProjectManager] = None + ) -> Dict[Version, Dict]: + pm = project or self.project_manager + config = self.get_config(pm) + evm_version = config.evm_version arguments_map = {} for vyper_version, source_paths in version_map.items(): bin_arg = self._get_vyper_bin(vyper_version) arguments_map[vyper_version] = { - "base_path": str(base_path), - "evm_version": self.evm_version, + "base_path": f"{pm.path}", + "evm_version": evm_version, "vyper_version": str(vyper_version), "vyper_binary": bin_arg, } @@ -1017,12 +1134,12 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError: ) def trace_source( - self, contract_type: ContractType, trace: Iterator[TraceFrame], calldata: HexBytes + self, + contract_source: ContractSource, + trace: Iterator[TraceFrame], + calldata: HexBytes, ) -> SourceTraceback: - if source_contract_type := self.project_manager._create_contract_source(contract_type): - return self._get_traceback(source_contract_type, trace, calldata) - - return SourceTraceback.model_validate([]) + return self._get_traceback(contract_source, trace, calldata) def _get_traceback( self, diff --git a/setup.py b/setup.py index 678406a1..aebc8768 100644 --- a/setup.py +++ b/setup.py @@ -11,10 +11,10 @@ "hypothesis>=6.2.0,<7.0", # Strategy-based fuzzer ], "lint": [ - "black>=23.12.0,<24", # Auto-formatter and linter - "mypy>=1.7.1", # Static type analyzer + "black>=24.3.0,<25", # Auto-formatter and linter + "mypy>=1.9.0,<2", # Static type analyzer "types-setuptools", # Needed due to mypy typeshed - "flake8>=6.1.0,<7", # Style linter + "flake8>=7.0.0,<8", # Style linter "isort>=5.10.1", # Import sorting linter "mdformat>=0.7.17", # Auto-formatter for markdown "mdformat-gfm>=0.3.5", # Needed for formatting GitHub-flavored markdown diff --git a/tests/ape-config.yaml b/tests/ape-config.yaml index 1677b3e9..b0209580 100644 --- a/tests/ape-config.yaml +++ b/tests/ape-config.yaml @@ -3,12 +3,8 @@ contracts_folder: contracts/passing_contracts # Specify a dependency to use in Vyper imports. dependencies: - - name: ExampleDependency + - name: exampledependency local: ./ExampleDependency vyper: evm_version: istanbul - - # Allows importing dependencies. - import_remapping: - - "exampledep=ExampleDependency" diff --git a/tests/conftest.py b/tests/conftest.py index 54b79248..90e9ca95 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import os import shutil +import tempfile from contextlib import contextmanager -from distutils.dir_util import copy_tree from pathlib import Path from tempfile import mkdtemp from typing import List @@ -11,12 +11,6 @@ import vvm # type: ignore from click.testing import CliRunner -# NOTE: Ensure that we don't use local paths for these -DATA_FOLDER = Path(mkdtemp()).resolve() -PROJECT_FOLDER = Path(mkdtemp()).resolve() -ape.config.DATA_FOLDER = DATA_FOLDER -ape.config.PROJECT_FOLDER = PROJECT_FOLDER - BASE_CONTRACTS_PATH = Path(__file__).parent / "contracts" TEMPLATES_PATH = BASE_CONTRACTS_PATH / "templates" FAILING_BASE = BASE_CONTRACTS_PATH / "failing_contracts" @@ -47,6 +41,31 @@ } +@pytest.fixture(scope="session", autouse=True) +def from_tests_dir(): + # Makes default project correct. + here = Path(__file__).parent + orig = Path.cwd() + if orig != here: + os.chdir(f"{here}") + + yield + + if Path.cwd() != orig: + os.chdir(f"{orig}") + + +@pytest.fixture(scope="session", autouse=True) +def config(): + cfg = ape.config + + # Ensure we don't persist any .ape data. + with tempfile.TemporaryDirectory() as temp_dir: + path = Path(temp_dir).resolve() + cfg.DATA_FOLDER = path + yield cfg + + def contract_test_cases(passing: bool) -> List[str]: """ Returns test-case names for outputting nicely with pytest. @@ -127,16 +146,6 @@ def temp_vvm_path(monkeypatch): yield path -@pytest.fixture -def data_folder(): - return DATA_FOLDER - - -@pytest.fixture -def project_folder(): - return PROJECT_FOLDER - - @pytest.fixture def compiler_manager(): return ape.compilers @@ -147,28 +156,17 @@ def compiler(compiler_manager): return compiler_manager.vyper -@pytest.fixture -def config(): - return ape.config - - -@pytest.fixture(autouse=True) +@pytest.fixture(scope="session", autouse=True) def project(config): project_source_dir = Path(__file__).parent - project_dest_dir = config.PROJECT_FOLDER / project_source_dir.name # Delete build / .cache that may exist pre-copy - project_path = Path(__file__).parent - cache = project_path / ".build" - - if cache.is_dir(): - shutil.rmtree(cache) + cache = project_source_dir / ".build" + shutil.rmtree(cache, ignore_errors=True) - copy_tree(project_source_dir.as_posix(), project_dest_dir.as_posix()) - with config.using_project(project_dest_dir) as project: - yield project - if project.local_project._cache_folder.is_dir(): - shutil.rmtree(project.local_project._cache_folder) + root_project = ape.Project(project_source_dir) + with root_project.isolate_in_tempdir() as tmp_project: + yield tmp_project @pytest.fixture diff --git a/tests/contracts/passing_contracts/flatten_me.vy b/tests/contracts/passing_contracts/flatten_me.vy index 456e2457..77b993d2 100644 --- a/tests/contracts/passing_contracts/flatten_me.vy +++ b/tests/contracts/passing_contracts/flatten_me.vy @@ -4,7 +4,7 @@ from vyper.interfaces import ERC20 from interfaces import IFace2 as IFaceTwo import interfaces.IFace as IFace -import exampledep.Dependency as Dep +import exampledependency.Dependency as Dep @external diff --git a/tests/contracts/passing_contracts/use_iface.vy b/tests/contracts/passing_contracts/use_iface.vy index ac136ddc..38f99932 100644 --- a/tests/contracts/passing_contracts/use_iface.vy +++ b/tests/contracts/passing_contracts/use_iface.vy @@ -4,7 +4,7 @@ import interfaces.IFace as IFace # Import from input JSON (ape-config.yaml). -import exampledep.Dependency as Dep +import exampledependency.Dependency as Dep from interfaces import IFace2 as IFace2 diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 98f00972..c9a94fec 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -1,8 +1,10 @@ import re +from pathlib import Path +import ape import pytest import vvm # type: ignore -from ape.exceptions import ContractLogicError +from ape.exceptions import CompilerError, ContractLogicError from ethpm_types import ContractType from packaging.version import Version from vvm.exceptions import VyperError # type: ignore @@ -48,39 +50,42 @@ def test_compile_project(project): assert len(contracts) == len( [p.name for p in project.contracts_folder.glob("*.vy") if p.is_file()] ) - assert contracts["contract_039"].source_id == "contract_039.vy" - assert contracts["contract_no_pragma"].source_id == "contract_no_pragma.vy" - assert contracts["older_version"].source_id == "older_version.vy" + prefix = "contracts/passing_contracts" + assert contracts["contract_039"].source_id == f"{prefix}/contract_039.vy" + assert contracts["contract_no_pragma"].source_id == f"{prefix}/contract_no_pragma.vy" + assert contracts["older_version"].source_id == f"{prefix}/older_version.vy" @pytest.mark.parametrize("contract_name", PASSING_CONTRACT_NAMES) def test_compile_individual_contracts(project, contract_name, compiler): path = project.contracts_folder / contract_name - assert compiler.compile([path]) + assert compiler.compile((path,), project=project) @pytest.mark.parametrize( "contract_name", [n for n in FAILING_CONTRACT_NAMES if n != "contract_unknown_pragma.vy"] ) def test_compile_failures(contract_name, compiler): + failing_project = ape.Project(FAILING_BASE) path = FAILING_BASE / contract_name with pytest.raises(VyperCompileError, match=EXPECTED_FAIL_PATTERNS[path.stem]) as err: - compiler.compile([path], base_path=FAILING_BASE) + compiler.compile((path,), project=failing_project) assert isinstance(err.value.base_err, VyperError) def test_install_failure(compiler): + failing_project = ape.Project(FAILING_BASE) path = FAILING_BASE / "contract_unknown_pragma.vy" with pytest.raises(VyperInstallError, match="No available version to install."): - compiler.compile([path]) + compiler.compile((path,), project=failing_project) def test_get_version_map(project, compiler, all_versions): vyper_files = [ x for x in project.contracts_folder.iterdir() if x.is_file() and x.suffix == ".vy" ] - actual = compiler.get_version_map(vyper_files) + actual = compiler.get_version_map(vyper_files, project=project) expected_versions = [Version(v) for v in all_versions] for version, sources in actual.items(): @@ -154,23 +159,30 @@ def run_test(manifest): assert compiler.settings["evmVersion"] == "istanbul" # There is only one contract with codesize pragma. - assert codesize_latest.contractTypes == ["optimize_codesize"] + assert codesize_latest.contractTypes == [ + "contracts/passing_contracts/optimize_codesize.vy:optimize_codesize" + ] assert codesize_latest.settings["optimize"] == "codesize" # There is only one contract with evm-version pragma. - assert evm_latest.contractTypes == ["evm_pragma"] + assert evm_latest.contractTypes == ["contracts/passing_contracts/evm_pragma.vy:evm_pragma"] assert evm_latest.settings["evmVersion"] == "paris" assert len(true_latest.contractTypes) >= 9 assert len(vyper_028.contractTypes) >= 1 - assert "contract_0310" in true_latest.contractTypes - assert "older_version" in vyper_028.contractTypes + assert ( + "contracts/passing_contracts/contract_0310.vy:contract_0310" + in true_latest.contractTypes + ) + assert ( + "contracts/passing_contracts/older_version.vy:older_version" in vyper_028.contractTypes + ) for compiler in (true_latest, vyper_028): assert compiler.settings["optimize"] is True - project.local_project.update_manifest(compilers=[]) + project.update_manifest(compilers=[]) project.load_contracts(use_cache=False) - run_test(project.local_project.manifest) + run_test(project.manifest) man = project.extract_manifest() run_test(man) @@ -182,7 +194,7 @@ def test_compile_parse_dev_messages(compiler, dev_revert_source, project): The compiler will output a map that maps dev messages to line numbers. See contract_with_dev_messages.vy for more information. """ - result = compiler.compile([dev_revert_source], base_path=project.contracts_folder) + result = compiler.compile((dev_revert_source,), project=project) assert len(result) == 1 @@ -201,18 +213,19 @@ def test_get_imports(compiler, project): vyper_files = [ x for x in project.contracts_folder.iterdir() if x.is_file() and x.suffix == ".vy" ] - actual = compiler.get_imports(vyper_files) + actual = compiler.get_imports(vyper_files, project=project) + prefix = "contracts/passing_contracts" builtin_import = "vyper/interfaces/ERC20.json" local_import = "interfaces/IFace.vy" local_from_import = "interfaces/IFace2.vy" - dependency_import = "exampledep/Dependency.json" - - assert len(actual["contract_037.vy"]) == 1 - assert set(actual["contract_037.vy"]) == {builtin_import} - assert len(actual["use_iface.vy"]) == 3 - assert set(actual["use_iface.vy"]) == {local_import, local_from_import, dependency_import} - assert len(actual["use_iface2.vy"]) == 1 - assert set(actual["use_iface2.vy"]) == {local_import} + dependency_import = "exampledependency/contracts/Dependency.vy" + assert set(actual[f"{prefix}/contract_037.vy"]) == {builtin_import} + assert set(actual[f"{prefix}/use_iface.vy"]) == { + local_import, + local_from_import, + dependency_import, + } + assert set(actual[f"{prefix}/use_iface2.vy"]) == {local_import} @pytest.mark.parametrize("src,vers", [("contract_039", "0.3.9"), ("contract_037", "0.3.7")]) @@ -222,15 +235,16 @@ def test_pc_map(compiler, project, src, vers): from `compile_src()` which includes the uncompressed source map data. """ - path = project.contracts_folder / f"{src}.vy" - result = compiler.compile([path], base_path=project.contracts_folder)[0] + path = project.sources.lookup(src) + result = compiler.compile((path,), project=project)[0] actual = result.pcmap.root code = path.read_text() vvm.install_vyper(vers) - compile_result = vvm.compile_source(code, vyper_version=vers, evm_version=compiler.evm_version)[ - "" - ] - src_map = compile_result["source_map"] + cfg = compiler.get_config(project=project) + evm_version = cfg.evm_version + compile_result = vvm.compile_source(code, vyper_version=vers, evm_version=evm_version) + std_result = compile_result[""] + src_map = std_result["source_map"] lines = code.splitlines() # Use the old-fashioned way of gathering PCMap to ensure our creative way works @@ -386,7 +400,7 @@ def test_enrich_error_handle_when_name(compiler, geth_provider, mocker): def test_trace_source(account, geth_provider, project, traceback_contract, arguments): receipt = traceback_contract.addBalance(*arguments, sender=account) actual = receipt.source_traceback - base_folder = project.contracts_folder + base_folder = Path(__file__).parent / "contracts" / "passing_contracts" contract_name = traceback_contract.contract_type.name expected = rf""" Traceback (most recent call last) @@ -434,7 +448,7 @@ def test_trace_err_source(account, geth_provider, project, traceback_contract): receipt = geth_provider.get_receipt(txn.txn_hash.hex()) actual = receipt.source_traceback - base_folder = project.contracts_folder + base_folder = Path(__file__).parent / "contracts" / "passing_contracts" contract_name = traceback_contract.contract_type.name version_key = contract_name.split("traceback_contract_")[-1] expected = rf""" @@ -475,19 +489,20 @@ def test_compile_with_version_set_in_config(config, projects_path, compiler, moc path = projects_path / "version_in_config" version_from_config = "0.3.7" spy = mocker.patch("ape_vyper.compiler.vvm_compile_standard") - with config.using_project(path) as project: - contract = project.contracts_folder / "v_contract.vy" - settings = compiler.get_compiler_settings((contract,)) - assert str(list(settings.keys())[0]) == version_from_config + project = ape.Project(path) - # Show it uses this version in the compiler. - project.load_contracts(use_cache=False) - assert str(spy.call_args[1]["vyper_version"]) == version_from_config + contract = project.contracts_folder / "v_contract.vy" + settings = compiler.get_compiler_settings((contract,), project=project) + assert str(list(settings.keys())[0]) == version_from_config + # Show it uses this version in the compiler. + project.load_contracts(use_cache=False) + assert str(spy.call_args[1]["vyper_version"]) == version_from_config -def test_compile_code(compiler, dev_revert_source): + +def test_compile_code(project, compiler, dev_revert_source): code = dev_revert_source.read_text() - actual = compiler.compile_code(code, contractName="MyContract") + actual = compiler.compile_code(code, project=project, contractName="MyContract") assert isinstance(actual, ContractType) assert actual.name == "MyContract" assert len(actual.abi) > 1 @@ -498,13 +513,13 @@ def test_compile_code(compiler, dev_revert_source): def test_compile_with_version_set_in_settings_dict(config, compiler_manager, projects_path): path = projects_path / "version_in_config" contract = path / "contracts" / "v_contract.vy" - - with config.using_project(path): - expected = ( - '.*Version specification "0.3.10" is not compatible with compiler version "0.3.3"' - ) - with pytest.raises(VyperCompileError, match=expected): - compiler_manager.compile([contract], settings={"version": "0.3.3"}) + project = ape.Project(path) + expected = '.*Version specification "0.3.10" is not compatible with compiler version "0.3.3"' + iterator = compiler_manager.compile( + (contract,), project=project, settings={"vyper": {"version": "0.3.3"}} + ) + with pytest.raises(CompilerError, match=expected): + _ = list(iterator) @pytest.mark.parametrize( @@ -526,7 +541,7 @@ def test_compile_with_version_set_in_settings_dict(config, compiler_manager, pro ) def test_flatten_contract(all_versions, project, contract_name, compiler): path = project.contracts_folder / contract_name - source = compiler.flatten_contract(path) + source = compiler.flatten_contract(path, project=project) source_code = str(source) version = compiler._source_vyper_version(source_code) vvm.install_vyper(str(version))