Skip to content

Commit

Permalink
feat: Support Vyper 0.3.10 (#97)
Browse files Browse the repository at this point in the history
* feat: support new pragma formats

* fix: lint

* fix: properly handle rest of settings

needed for per-optimization-level compilation

* fix: typ hint

* fix: type hint should use Union

* fix: make mypy happy

* feat: update tests to support 0.3.10

* add info to README

* fix: remove pprint

* fix: mdformat

* fix: update metadata test

* fix: address PR comments

* fix: pin pydantic

* walrus operator was swallowing stuff

* fix: make version pragma check more readable

* fix: address PR feedback
  • Loading branch information
z80dev authored Oct 26, 2023
1 parent 798874a commit 888ab4c
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 110 deletions.
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,23 @@ vyper:
Import the voting contract types like this:
```python
# @version 0.3.9
# @version 0.3.10

import voting.ballot as ballot
```

### Pragmas

Ape-Vyper supports Vyper 0.3.10's [new pragma formats](https://github.com/vyperlang/vyper/pull/3493)

#### Version Pragma

```python
#pragma version 0.3.10
```

#### Optimization Pragma

```python
#pragma optimize codesize
```
233 changes: 136 additions & 97 deletions ape_vyper/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,31 +68,52 @@ def _install_vyper(version: Version):
) from err


def get_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
def get_version_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
"""
Extracts pragma information from Vyper source code.
Extracts version pragma information from Vyper source code.
Args:
source (str): Vyper source code
Returns:
``packaging.specifiers.SpecifierSet``, or None if no valid pragma is found.
"""
version_pragma_patterns = [
r"(?:\n|^)\s*#\s*@version\s*([^\n]*)",
r"(?:\n|^)\s*#\s*pragma\s+version\s*([^\n]*)",
]

source_str = source if isinstance(source, str) else source.read_text()
pragma_match = next(re.finditer(r"(?:\n|^)\s*#\s*@version\s*([^\n]*)", source_str), None)
if pragma_match is None:
return None # Try compiling with latest
for pattern in version_pragma_patterns:
for match in re.finditer(pattern, source_str):
raw_pragma = match.groups()[0]
pragma_str = " ".join(raw_pragma.split()).replace("^", "~=")
if pragma_str and pragma_str[0].isnumeric():
pragma_str = f"=={pragma_str}"

raw_pragma = pragma_match.groups()[0]
pragma_str = " ".join(raw_pragma.split()).replace("^", "~=")
if pragma_str and pragma_str[0].isnumeric():
pragma_str = f"=={pragma_str}"
try:
return SpecifierSet(pragma_str)
except InvalidSpecifier:
logger.warning(f"Invalid pragma spec: '{raw_pragma}'. Trying latest.")
return None
return None

try:
return SpecifierSet(pragma_str)
except InvalidSpecifier:
logger.warning(f"Invalid pragma spec: '{raw_pragma}'. Trying latest.")

def get_optimization_pragma(source: Union[str, Path]) -> Optional[str]:
"""
Extracts optimization pragma information from Vyper source code.
Args:
source (Union[str, Path]): Vyper source code
Returns:
``str``, or None if no valid pragma is found.
"""
source_str = source if isinstance(source, str) else source.read_text()
pragma_match = next(re.finditer(r"(?:\n|^)\s*#pragma\s+optimize\s+([^\n]*)", source_str), None)
if pragma_match is None:
return None
return pragma_match.groups()[0]


class VyperCompiler(CompilerAPI):
Expand Down Expand Up @@ -145,7 +166,7 @@ def get_imports(
def get_versions(self, all_paths: List[Path]) -> Set[str]:
versions = set()
for path in all_paths:
if version_spec := get_pragma_spec(path):
if version_spec := get_version_pragma_spec(path):
try:
# Make sure we have the best compiler available to compile this
version_iter = version_spec.filter(self.available_versions)
Expand Down Expand Up @@ -270,6 +291,13 @@ def import_remapping(self) -> Dict[str, Dict]:

return interfaces

def classify_ast(self, _node: ASTNode):
if _node.ast_type in _FUNCTION_AST_TYPES:
_node.classification = ASTClassification.FUNCTION

for child in _node.children:
self.classify_ast(child)

def compile(
self, contract_filepaths: List[Path], base_path: Optional[Path] = None
) -> List[ContractType]:
Expand All @@ -281,90 +309,105 @@ def compile(
all_settings = self.get_compiler_settings(sources, base_path=base_path)

for vyper_version, source_paths in version_map.items():
settings = all_settings.get(vyper_version, {})
path_args = {str(get_relative_path(p.absolute(), base_path)): p for p in source_paths}
input_json = {
"language": "Vyper",
"settings": settings,
"sources": {s: {"content": p.read_text()} for s, p in path_args.items()},
}
if interfaces := self.import_remapping:
input_json["interfaces"] = interfaces

vyper_binary = compiler_data[vyper_version]["vyper_binary"]
try:
result = vvm.compile_standard(
input_json,
base_path=base_path,
vyper_version=vyper_version,
vyper_binary=vyper_binary,
)
except VyperError as err:
raise VyperCompileError(err) from err

def classify_ast(_node: ASTNode):
if _node.ast_type in _FUNCTION_AST_TYPES:
_node.classification = ASTClassification.FUNCTION
version_settings = all_settings.get(vyper_version, {})
optimizations_map = self.get_optimization_pragma_map(list(source_paths))

for optimization, source_paths in optimizations_map.items():
settings: Dict[str, Any] = version_settings.copy()
settings["optimize"] = optimization or True
path_args = {
str(get_relative_path(p.absolute(), base_path)): p for p in source_paths
}
settings["outputSelection"] = {s: ["*"] for s in path_args}
input_json = {
"language": "Vyper",
"settings": settings,
"sources": {s: {"content": p.read_text()} for s, p in path_args.items()},
}

for child in _node.children:
classify_ast(child)
if interfaces := self.import_remapping:
input_json["interfaces"] = interfaces

for source_id, output_items in result["contracts"].items():
content = {
i + 1: ln
for i, ln in enumerate((base_path / source_id).read_text().splitlines())
}
for name, output in output_items.items():
# De-compress source map to get PC POS map.
ast = ASTNode.parse_obj(result["sources"][source_id]["ast"])
classify_ast(ast)

# Track function offsets.
function_offsets = []
for node in ast.children:
lineno = node.lineno

# NOTE: Constructor is handled elsewhere.
if node.ast_type == "FunctionDef" and "__init__" not in content.get(
lineno, ""
):
function_offsets.append((node.lineno, node.end_lineno))

evm = output["evm"]
bytecode = evm["deployedBytecode"]
opcodes = bytecode["opcodes"].split(" ")
compressed_src_map = SourceMap(__root__=bytecode["sourceMap"])
src_map = list(compressed_src_map.parse())[1:]

pcmap = (
_get_legacy_pcmap(ast, src_map, opcodes)
if vyper_version <= Version("0.3.7")
else _get_pcmap(bytecode)
vyper_binary = compiler_data[vyper_version]["vyper_binary"]
try:
result = vvm.compile_standard(
input_json,
base_path=base_path,
vyper_version=vyper_version,
vyper_binary=vyper_binary,
)
except VyperError as err:
raise VyperCompileError(err) from err

for source_id, output_items in result["contracts"].items():
content = {
i + 1: ln
for i, ln in enumerate((base_path / source_id).read_text().splitlines())
}
for name, output in output_items.items():
# De-compress source map to get PC POS map.
ast = ASTNode.parse_obj(result["sources"][source_id]["ast"])
self.classify_ast(ast)

# Track function offsets.
function_offsets = []
for node in ast.children:
lineno = node.lineno

# NOTE: Constructor is handled elsewhere.
if node.ast_type == "FunctionDef" and "__init__" not in content.get(
lineno, ""
):
function_offsets.append((node.lineno, node.end_lineno))

evm = output["evm"]
bytecode = evm["deployedBytecode"]
opcodes = bytecode["opcodes"].split(" ")
compressed_src_map = SourceMap(__root__=bytecode["sourceMap"])
src_map = list(compressed_src_map.parse())[1:]

pcmap = (
_get_legacy_pcmap(ast, src_map, opcodes)
if vyper_version <= Version("0.3.7")
else _get_pcmap(bytecode)
)

# Find content-specified dev messages.
dev_messages = {}
for line_no, line in content.items():
if match := re.search(DEV_MSG_PATTERN, line):
dev_messages[line_no] = match.group(1).strip()

contract_type = ContractType(
ast=ast,
contractName=name,
sourceId=source_id,
deploymentBytecode={"bytecode": evm["bytecode"]["object"]},
runtimeBytecode={"bytecode": bytecode["object"]},
abi=output["abi"],
sourcemap=compressed_src_map,
pcmap=pcmap,
userdoc=output["userdoc"],
devdoc=output["devdoc"],
dev_messages=dev_messages,
)
contract_types.append(contract_type)
# Find content-specified dev messages.
dev_messages = {}
for line_no, line in content.items():
if match := re.search(DEV_MSG_PATTERN, line):
dev_messages[line_no] = match.group(1).strip()

contract_type = ContractType(
ast=ast,
contractName=name,
sourceId=source_id,
deploymentBytecode={"bytecode": evm["bytecode"]["object"]},
runtimeBytecode={"bytecode": bytecode["object"]},
abi=output["abi"],
sourcemap=compressed_src_map,
pcmap=pcmap,
userdoc=output["userdoc"],
devdoc=output["devdoc"],
dev_messages=dev_messages,
)
contract_types.append(contract_type)

return contract_types

def get_optimization_pragma_map(
self, contract_filepaths: List[Path], base_path: Optional[Path] = None
) -> Dict[Union[str, bool], Set[Path]]:
base_path = base_path or self.config_manager.contracts_folder
optimization_pragma_map: Dict[Union[str, bool], Set[Path]] = {}
for path in contract_filepaths:
pragma = get_optimization_pragma(path) or True
if pragma not in optimization_pragma_map:
optimization_pragma_map[pragma] = set()
optimization_pragma_map[pragma].add(path)

return optimization_pragma_map

def get_version_map(
self, contract_filepaths: List[Path], base_path: Optional[Path] = None
) -> Dict[Version, Set[Path]]:
Expand All @@ -374,7 +417,7 @@ def get_version_map(

# Sort contract_filepaths to promote consistent, reproduce-able behavior
for path in sorted(contract_filepaths):
if pragma := get_pragma_spec(path):
if pragma := get_version_pragma_spec(path):
_safe_append(source_path_by_pragma_spec, pragma, path)
else:
source_paths_without_pragma.add(path)
Expand Down Expand Up @@ -441,10 +484,6 @@ def get_compiler_settings(
continue

version_settings: Dict = {"optimize": True}
path_args = {
str(get_relative_path(p.absolute(), contracts_path)): p for p in source_paths
}
version_settings["outputSelection"] = {s: ["*"] for s in path_args}
if evm_version := data.get("evm_version"):
version_settings["evmVersion"] = evm_version

Expand Down Expand Up @@ -955,7 +994,7 @@ def _get_pcmap(bytecode: Dict) -> PCMap:
error_str = RuntimeErrorType.FALLBACK_NOT_DEFINED.value
use_loc = False
elif "bad calldatasize or callvalue" in error_type:
# Only on >=0.3.10rc3.
# Only on >=0.3.10.
# NOTE: We are no longer able to get Nonpayable checks errors since they
# are now combined.
error_str = RuntimeErrorType.INVALID_CALLDATA_OR_VALUE.value
Expand Down
2 changes: 1 addition & 1 deletion ape_vyper/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, **kwargs):

class InvalidCalldataOrValueError(VyperRuntimeError):
"""
Raises on Vyper versions >= 0.3.10rc3 in place of NonPayableError.
Raises on Vyper versions >= 0.3.10 in place of NonPayableError.
"""

def __init__(self, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"black>=23.3.0,<24", # Auto-formatter and linter
"mypy>=0.991,<1", # Static type analyzer
"types-setuptools", # Needed due to mypy typeshed
"pydantic<2.0", # Needed for successful type check. TODO: Remove after full v2 support.
"flake8>=6.0.0,<7", # Style linter
"isort>=5.10.1", # Import sorting linter
"mdformat>=0.7.16", # Auto-formatter for markdown
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@
"0.3.4",
"0.3.7",
"0.3.9",
"0.3.10rc3",
"0.3.10",
)

CONTRACT_VERSION_GEN_MAP = {
"": (
"0.3.7",
"0.3.9",
"0.3.10rc3",
"0.3.10",
),
"sub_reverts": ALL_VERSIONS,
}
Expand Down Expand Up @@ -188,7 +188,7 @@ def account():
return ape.accounts.test_accounts[0]


@pytest.fixture(params=("037", "039", "0310rc3"))
@pytest.fixture(params=("037", "039", "0310"))
def traceback_contract(request, account, project, geth_provider):
return _get_tb_contract(request.param, project, account)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @version 0.3.9
# @version 0.3.10

# Test dev messages in various code placements
@external
Expand Down
8 changes: 8 additions & 0 deletions tests/contracts/passing_contracts/optimize_codesize.vy
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#pragma version 0.3.10
#pragma optimize codesize

x: uint256

@external
def __init__():
self.x = 0
7 changes: 7 additions & 0 deletions tests/contracts/passing_contracts/pragma_with_space.vy
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# pragma version 0.3.10

x: uint256

@external
def __init__():
self.x = 0
Loading

0 comments on commit 888ab4c

Please sign in to comment.