Skip to content

Commit

Permalink
feat: Vyper 0.4 flattener (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Jul 12, 2024
1 parent 2a9bb55 commit 86a607f
Show file tree
Hide file tree
Showing 8 changed files with 346 additions and 91 deletions.
245 changes: 164 additions & 81 deletions ape_vyper/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
extract_imports,
extract_meta,
generate_interface,
iface_name_from_file,
)

DEV_MSG_PATTERN = re.compile(r".*\s*#\s*(dev:.+)")
Expand Down Expand Up @@ -277,6 +276,77 @@ def get_evm_version_pragma_map(
return pragmas


def _lookup_source_from_site_packages(
dependency_name: str,
filestem: str,
config_override: Optional[dict] = None,
) -> Optional[tuple[Path, ProjectManager]]:
# Attempt looking up dependency from site-packages.
config_override = config_override or {}
if "contracts_folder" not in config_override:
# Default to looking through the whole release for
# contracts. Most often, Python-based dependencies publish
# only their contracts this way, and we are only looking
# for sources so accurate project configuration is not required.
config_override["contracts_folder"] = "."

try:
imported_project = ProjectManager.from_python_library(
dependency_name,
config_override=config_override,
)
except ProjectError as err:
# Still attempt to let Vyper handle this during compilation.
logger.error(
f"'{dependency_name}' may not be installed. "
"Could not find it in Ape dependencies or Python's site-packages. "
f"Error: {err}"
)
else:
extensions = [*[f"{t}" for t in FileType], ".json"]

def seek() -> Optional[Path]:
for ext in extensions:
try_source_id = f"{filestem}{ext}"
if source_path := imported_project.sources.lookup(try_source_id):
return source_path

return None

if res := seek():
return res, imported_project

# Still not found. Try again without contracts_folder set.
# This will attempt to use Ape's contracts_folder detection system.
# However, I am not sure this situation occurs, as Vyper-python
# based dependencies are new at the time of writing this.
new_override = config_override or {}
if "contracts_folder" in new_override:
del new_override["contracts_folder"]

imported_project.reconfigure(**new_override)
if res := seek():
return res, imported_project

# Still not found. Log a very helpful message.
existing_filestems = [f.stem for f in imported_project.path.iterdir()]
fs_str = ", ".join(existing_filestems)
contracts_folder = imported_project.contracts_folder
path = imported_project.path

# This will log the calculated / user-set contracts_folder.
contracts_path = f"{get_relative_path(contracts_folder, path)}"

logger.error(
f"Source for stem '{filestem}' not found in "
f"'{imported_project.path}'."
f"Contracts folder: {contracts_path}, "
f"Existing file(s): {fs_str}"
)

return None


class VyperCompiler(CompilerAPI):
@property
def name(self) -> str:
Expand Down Expand Up @@ -309,12 +379,17 @@ def _get_imports(
if not path.is_file():
continue

content = path.read_text(encoding="utf8").splitlines()
content = path.read_text().splitlines()
source_id = (
str(path.absolute())
if use_absolute_paths
else str(get_relative_path(path.absolute(), pm.path.absolute()))
)

# Prevent infinitely handling imports when they cross over.
if source_id in handled:
continue

handled.add(source_id)
for line in content:
if line.startswith("import "):
Expand Down Expand Up @@ -385,6 +460,7 @@ def _get_imports(
import_source_id = os.path.sep.join(
(path_id, version_str, f"{source_id_stem}{ext}")
)

# Also include imports of imports.
sub_imports = self._get_imports(
(dep_project.path / f"{source_id_stem}{ext}",),
Expand All @@ -396,9 +472,9 @@ def _get_imports(

is_local = False
break
else:
elif dependency_name:
# Attempt looking up dependency from site-packages.
if res := self._lookup_source_from_site_packages(dependency_name, filestem):
if res := _lookup_source_from_site_packages(dependency_name, filestem):
source_path, imported_project = res
import_source_id = str(source_path)
# Also include imports of imports.
Expand Down Expand Up @@ -429,77 +505,6 @@ def _get_imports(

return dict(import_map)

def _lookup_source_from_site_packages(
self,
dependency_name: str,
filestem: str,
config_override: Optional[dict] = None,
) -> Optional[tuple[Path, ProjectManager]]:
# Attempt looking up dependency from site-packages.
config_override = config_override or {}
if "contracts_folder" not in config_override:
# Default to looking through the whole release for
# contracts. Most often, Python-based dependencies publish
# only their contracts this way, and we are only looking
# for sources so accurate project configuration is not required.
config_override["contracts_folder"] = "."

try:
imported_project = ProjectManager.from_python_library(
dependency_name,
config_override=config_override,
)
except ProjectError as err:
# Still attempt to let Vyper handle this during compilation.
logger.error(
f"'{dependency_name}' may not be installed. "
"Could not find it in Ape dependencies or Python's site-packages. "
f"Error: {err}"
)
else:
extensions = [*[f"{t}" for t in FileType], ".json"]

def seek() -> Optional[Path]:
for ext in extensions:
try_source_id = f"{filestem}{ext}"
if source_path := imported_project.sources.lookup(try_source_id):
return source_path

return None

if res := seek():
return res, imported_project

# Still not found. Try again without contracts_folder set.
# This will attempt to use Ape's contracts_folder detection system.
# However, I am not sure this situation occurs, as Vyper-python
# based dependencies are new at the time of writing this.
new_override = config_override or {}
if "contracts_folder" in new_override:
del new_override["contracts_folder"]

imported_project.reconfigure(**new_override)
if res := seek():
return res, imported_project

# Still not found. Log a very helpful message.
existing_filestems = [f.stem for f in imported_project.path.iterdir()]
fs_str = ", ".join(existing_filestems)
contracts_folder = imported_project.contracts_folder
path = imported_project.path

# This will log the calculated / user-set contracts_folder.
contracts_path = f"{get_relative_path(contracts_folder, path)}"

logger.error(
f"Source for stem '{filestem}' not found in "
f"'{imported_project.path}'."
f"Contracts folder: {contracts_path}, "
f"Existing file(s): {fs_str}"
)

return None

def get_versions(self, all_paths: Iterable[Path]) -> set[str]:
versions = set()
for path in all_paths:
Expand Down Expand Up @@ -962,9 +967,17 @@ def first_full_release(versions: Iterable[Version]) -> Optional[Version]:

return next(version_spec.filter(self.available_versions))

def _flatten_source(self, path: Path, project: Optional[ProjectManager] = None) -> str:
def _flatten_source(
self,
path: Path,
project: Optional[ProjectManager] = None,
include_pragma: bool = True,
sources_handled: Optional[set[Path]] = None,
warn_flattening_modules: bool = True,
) -> str:
pm = project or self.local_project

handled = sources_handled or set()
handled.add(path)
# Get the non stdlib import paths for our contracts
imports = list(
filter(
Expand Down Expand Up @@ -992,7 +1005,10 @@ def _flatten_source(self, path: Path, project: Optional[ProjectManager] = None)
# Get info about imports and source meta
aliases = extract_import_aliases(og_source)
pragma, source_without_meta = extract_meta(og_source)
version_specifier = get_version_pragma_spec(pragma) if pragma else None
stdlib_imports, _, source_without_imports = extract_imports(source_without_meta)
flattened_modules = ""
modules_prefixes: set[str] = set()

for import_path in sorted(imports):
import_file = None
Expand All @@ -1007,7 +1023,7 @@ def _flatten_source(self, path: Path, project: Optional[ProjectManager] = None)
import_file = pm.path / import_path

# Vyper imported interface names come from their file names
file_name = iface_name_from_file(import_file)
file_name = import_file.stem
# If we have a known alias, ("import X as Y"), use the alias as interface name
iface_name = aliases[file_name] if file_name in aliases else file_name

Expand All @@ -1032,20 +1048,87 @@ def _match_source(imp_path: str) -> Optional[PackageManifest]:

# Generate an ABI from the source code
elif import_file.is_file():
abis = source_to_abi(import_file.read_text(encoding="utf8"))
interfaces_source += generate_interface(abis, iface_name)
if (
version_specifier
and version_specifier.contains("0.4.0")
and import_file.suffix != ".vyi"
):
if warn_flattening_modules:
logger.warning(
"Flattening modules DOES NOT yield the same bytecode! "
"This is **NOT** valid for contract-verification."
)
warn_flattening_modules = False

modules_prefixes.add(import_file.stem)
if import_file in handled:
# We have already included this source somewhere.
continue

# Is a module or an interface imported from a module.
# Copy in the source code directly.
flattened_module = self._flatten_source(
import_file,
include_pragma=False,
sources_handled=handled,
warn_flattening_modules=warn_flattening_modules,
)
flattened_modules = f"{flattened_modules}\n\n{flattened_module}"

else:
# Vyper <0.4 interface from folder other than interfaces/
# such as a .vyi file in the contracts folder.
abis = source_to_abi(import_file.read_text(encoding="utf8"))
interfaces_source += generate_interface(abis, iface_name)

def no_nones(it: Iterable[Optional[str]]) -> Iterable[str]:
# Type guard like generator to remove Nones and make mypy happy
for el in it:
if el is not None:
yield el

pragma_to_include = pragma if include_pragma else ""

# Join all the OG and generated parts back together
flattened_source = "\n\n".join(
no_nones((pragma, stdlib_imports, interfaces_source, source_without_imports))
no_nones(
(
pragma_to_include,
stdlib_imports,
interfaces_source,
flattened_modules,
source_without_imports,
)
)
)

# Clear module-usage prefixes.
for prefix in modules_prefixes:
# Replace usage lines like 'zero_four_module.moduleMethod()'
# with 'self.moduleMethod()'.
flattened_source = flattened_source.replace(f"{prefix}.", "self.")

# Remove module-level doc-strings, as it causes compilation issues
# when used in root contracts.
lines_no_doc: list[str] = []
in_str_comment = False
for line in flattened_source.splitlines():
line_stripped = line.rstrip()
if not in_str_comment and line_stripped.startswith('"""'):
if line_stripped == '"""' or not line_stripped.endswith('"""'):
in_str_comment = True
continue

elif in_str_comment:
if line_stripped.endswith('"""'):
in_str_comment = False

continue

lines_no_doc.append(line)

flattened_source = "\n".join(lines_no_doc)

# TODO: Replace this nonsense with a real code formatter
def format_source(source: str) -> str:
while "\n\n\n\n" in source:
Expand Down
8 changes: 1 addition & 7 deletions ape_vyper/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Tools for working with ABI specs and Vyper interface source code
"""

from pathlib import Path
from typing import Any, Optional, Union

from ethpm_types import ABI, MethodABI
Expand All @@ -17,11 +16,6 @@ def indent_line(line: str, level=1) -> str:
return f"{INDENT * level}{line}"


def iface_name_from_file(fpath: Path) -> str:
"""Get Interface name from file path"""
return fpath.name.split(".")[0]


def generate_inputs(inputs: list[ABIType]) -> str:
"""Generate the source code input args from ABI inputs"""
return ", ".join(f"{i.name}: {i.type}" for i in inputs)
Expand Down Expand Up @@ -71,7 +65,7 @@ def generate_interface(abi: Union[list[dict[str, Any]], list[ABI]], iface_name:


def extract_meta(source_code: str) -> tuple[Optional[str], str]:
"""Extract version pragma, and returne cleaned source"""
"""Extract version pragma, and return cleaned source"""
version_pragma: Optional[str] = None
cleaned_source_lines: list[str] = []

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"ethpm-types", # Use same version as eth-ape
"tqdm", # Use same version as eth-ape
"vvm>=0.2.0,<0.3",
"vyper~=0.3.7",
"vyper>=0.3.7,<0.5",
],
python_requires=">=3.10,<4",
extras_require=extras_require,
Expand Down
9 changes: 9 additions & 0 deletions tests/contracts/passing_contracts/zero_four.vy
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ from snekmate.auth import ownable
# (new in Vyper 0.4).
from ethereum.ercs import IERC20

# `zero_four_module.vy` also imports this next line.
# We are testing that the flattener can handle that.
from . import zero_four_module_2 as zero_four_module_2

@external
@view
def implementThisPlease(role: bytes32) -> bool:
Expand All @@ -20,3 +24,8 @@ def implementThisPlease(role: bytes32) -> bool:
@external
def callModuleFunction(role: bytes32) -> bool:
return zero_four_module.moduleMethod()


@external
def callModule2Function(role: bytes32) -> bool:
return zero_four_module_2.moduleMethod2()
Loading

0 comments on commit 86a607f

Please sign in to comment.