Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support Vyper 0.3.10 #97

Merged
merged 16 commits into from
Oct 26, 2023
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,23 @@ vyper:
Import the voting contract types like this:

```python
# @version 0.3.9
# @version 0.3.10

import voting.ballot as ballot
```

### Pragmas

Ape-Vyper supports Vyper 0.3.10's [new pragma formats](https://github.com/vyperlang/vyper/pull/3493)

#### Version Pragma

```python
#pragma version 0.3.10
```

#### Optimization Pragma

```python
#pragma optimize codesize
```
214 changes: 127 additions & 87 deletions ape_vyper/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def _install_vyper(version: Version):
) from err


def get_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
def get_version_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
"""
Extracts pragma information from Vyper source code.
Extracts version pragma information from Vyper source code.

Args:
source (str): Vyper source code
Expand All @@ -81,7 +81,12 @@ def get_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
source_str = source if isinstance(source, str) else source.read_text()
pragma_match = next(re.finditer(r"(?:\n|^)\s*#\s*@version\s*([^\n]*)", source_str), None)
if pragma_match is None:
return None # Try compiling with latest
# support new pragma syntax
pragma_match = next(
re.finditer(r"(?:\n|^)\s*#\s*pragma\s+version\s*([^\n]*)", source_str), None
z80dev marked this conversation as resolved.
Show resolved Hide resolved
)
if pragma_match is None:
return None # Try compiling with latest
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might work a little nicer

    if not (pragma_match := next(
        itertools.chain(
            re.finditer(r"(?:\n|^)\s*#\s*@version\s*([^\n]*)", source_str),
            re.finditer(r"(?:\n|^)\s*#\s*pragma\s+version\s*([^\n]*)", source_str),
            None,
        )
    ):
        return None  # Try compiling with latest

Copy link
Contributor Author

@z80dev z80dev Oct 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

had to make a few tweaks to this:

  • we want to do something with the match if we have it, so I had to invert the condition here by removing the not
  • all the args to itertools.chain must be iterators, so None caused an exception, I had to wrap it as [None] which accomplishes what we want, but looks a bit weird. Do you have an idea for making it more readable?
def get_version_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
    """
    Extracts version pragma information from Vyper source code.

    Args:
        source (str): Vyper source code

    Returns:
        ``packaging.specifiers.SpecifierSet``, or None if no valid pragma is found.
    """
    source_str = source if isinstance(source, str) else source.read_text()
    if pragma_match := next(
        itertools.chain(
            re.finditer(r"(?:\n|^)\s*#\s*@version\s*([^\n]*)", source_str),
            re.finditer(r"(?:\n|^)\s*#\s*pragma\s+version\s*([^\n]*)", source_str),
            [None],
        )
    ):
        raw_pragma = pragma_match.groups()[0]
        pragma_str = " ".join(raw_pragma.split()).replace("^", "~=")
        if pragma_str and pragma_str[0].isnumeric():
            pragma_str = f"=={pragma_str}"

        try:
            return SpecifierSet(pragma_str)
        except InvalidSpecifier:
            logger.warning(f"Invalid pragma spec: '{raw_pragma}'. Trying latest.")
            return None
    else:
        return None

Copy link
Contributor Author

@z80dev z80dev Oct 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ended up with this:

def get_version_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
    """
    Extracts version pragma information from Vyper source code.

    Args:
        source (str): Vyper source code

    Returns:
        ``packaging.specifiers.SpecifierSet``, or None if no valid pragma is found.
    """
    version_pragma_patterns = [
        r"(?:\n|^)\s*#\s*@version\s*([^\n]*)",
        r"(?:\n|^)\s*#\s*pragma\s+version\s*([^\n]*)",
    ]

    source_str = source if isinstance(source, str) else source.read_text()
    for pattern in version_pragma_patterns:
        for match in re.finditer(pattern, source_str):
            raw_pragma = match.groups()[0]
            pragma_str = " ".join(raw_pragma.split()).replace("^", "~=")
            if pragma_str and pragma_str[0].isnumeric():
                pragma_str = f"=={pragma_str}"

            try:
                return SpecifierSet(pragma_str)
            except InvalidSpecifier:
                logger.warning(f"Invalid pragma spec: '{raw_pragma}'. Trying latest.")
                return None
    return None


raw_pragma = pragma_match.groups()[0]
pragma_str = " ".join(raw_pragma.split()).replace("^", "~=")
Expand All @@ -95,6 +100,23 @@ def get_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
return None


def get_optimization_pragma(source: Union[str, Path]) -> Optional[str]:
"""
Extracts optimization pragma information from Vyper source code.

Args:
source (str): Vyper source code
z80dev marked this conversation as resolved.
Show resolved Hide resolved

Returns:
z80dev marked this conversation as resolved.
Show resolved Hide resolved
``str``, or True if no valid pragma is found (for backwards compatibility).
"""
source_str = source if isinstance(source, str) else source.read_text()
pragma_match = next(re.finditer(r"(?:\n|^)\s*#pragma\s+optimize\s+([^\n]*)", source_str), None)
if pragma_match is None:
return None
z80dev marked this conversation as resolved.
Show resolved Hide resolved
return pragma_match.groups()[0]


class VyperCompiler(CompilerAPI):
@property
def config(self) -> VyperConfig:
Expand Down Expand Up @@ -145,7 +167,7 @@ def get_imports(
def get_versions(self, all_paths: List[Path]) -> Set[str]:
versions = set()
for path in all_paths:
if version_spec := get_pragma_spec(path):
if version_spec := get_version_pragma_spec(path):
try:
# Make sure we have the best compiler available to compile this
version_iter = version_spec.filter(self.available_versions)
Expand Down Expand Up @@ -281,90 +303,112 @@ def compile(
all_settings = self.get_compiler_settings(sources, base_path=base_path)

for vyper_version, source_paths in version_map.items():
settings = all_settings.get(vyper_version, {})
path_args = {str(get_relative_path(p.absolute(), base_path)): p for p in source_paths}
input_json = {
"language": "Vyper",
"settings": settings,
"sources": {s: {"content": p.read_text()} for s, p in path_args.items()},
}
if interfaces := self.import_remapping:
input_json["interfaces"] = interfaces

vyper_binary = compiler_data[vyper_version]["vyper_binary"]
try:
result = vvm.compile_standard(
input_json,
base_path=base_path,
vyper_version=vyper_version,
vyper_binary=vyper_binary,
)
except VyperError as err:
raise VyperCompileError(err) from err

def classify_ast(_node: ASTNode):
if _node.ast_type in _FUNCTION_AST_TYPES:
_node.classification = ASTClassification.FUNCTION
version_settings = all_settings.get(vyper_version, {})
optimizations_map = self.get_optimization_pragma_map(list(source_paths))

for optimization, source_paths in optimizations_map.items():
settings: Dict[str, Any] = version_settings.copy()
settings["optimize"] = optimization or True
path_args = {
str(get_relative_path(p.absolute(), base_path)): p for p in source_paths
}
settings["outputSelection"] = {s: ["*"] for s in path_args}
input_json = {
"language": "Vyper",
"settings": settings,
"sources": {s: {"content": p.read_text()} for s, p in path_args.items()},
}

for child in _node.children:
classify_ast(child)
if interfaces := self.import_remapping:
input_json["interfaces"] = interfaces

for source_id, output_items in result["contracts"].items():
content = {
i + 1: ln
for i, ln in enumerate((base_path / source_id).read_text().splitlines())
}
for name, output in output_items.items():
# De-compress source map to get PC POS map.
ast = ASTNode.parse_obj(result["sources"][source_id]["ast"])
classify_ast(ast)

# Track function offsets.
function_offsets = []
for node in ast.children:
lineno = node.lineno

# NOTE: Constructor is handled elsewhere.
if node.ast_type == "FunctionDef" and "__init__" not in content.get(
lineno, ""
):
function_offsets.append((node.lineno, node.end_lineno))

evm = output["evm"]
bytecode = evm["deployedBytecode"]
opcodes = bytecode["opcodes"].split(" ")
compressed_src_map = SourceMap(__root__=bytecode["sourceMap"])
src_map = list(compressed_src_map.parse())[1:]

pcmap = (
_get_legacy_pcmap(ast, src_map, opcodes)
if vyper_version <= Version("0.3.7")
else _get_pcmap(bytecode)
vyper_binary = compiler_data[vyper_version]["vyper_binary"]
try:
result = vvm.compile_standard(
input_json,
base_path=base_path,
vyper_version=vyper_version,
vyper_binary=vyper_binary,
)
except VyperError as err:
raise VyperCompileError(err) from err

def classify_ast(_node: ASTNode):
z80dev marked this conversation as resolved.
Show resolved Hide resolved
if _node.ast_type in _FUNCTION_AST_TYPES:
_node.classification = ASTClassification.FUNCTION

for child in _node.children:
classify_ast(child)

for source_id, output_items in result["contracts"].items():
content = {
i + 1: ln
for i, ln in enumerate((base_path / source_id).read_text().splitlines())
}
for name, output in output_items.items():
# De-compress source map to get PC POS map.
ast = ASTNode.parse_obj(result["sources"][source_id]["ast"])
classify_ast(ast)

# Track function offsets.
function_offsets = []
for node in ast.children:
lineno = node.lineno

# NOTE: Constructor is handled elsewhere.
if node.ast_type == "FunctionDef" and "__init__" not in content.get(
lineno, ""
):
function_offsets.append((node.lineno, node.end_lineno))

evm = output["evm"]
bytecode = evm["deployedBytecode"]
opcodes = bytecode["opcodes"].split(" ")
compressed_src_map = SourceMap(__root__=bytecode["sourceMap"])
src_map = list(compressed_src_map.parse())[1:]

pcmap = (
_get_legacy_pcmap(ast, src_map, opcodes)
if vyper_version <= Version("0.3.7")
else _get_pcmap(bytecode)
)

# Find content-specified dev messages.
dev_messages = {}
for line_no, line in content.items():
if match := re.search(DEV_MSG_PATTERN, line):
dev_messages[line_no] = match.group(1).strip()

contract_type = ContractType(
ast=ast,
contractName=name,
sourceId=source_id,
deploymentBytecode={"bytecode": evm["bytecode"]["object"]},
runtimeBytecode={"bytecode": bytecode["object"]},
abi=output["abi"],
sourcemap=compressed_src_map,
pcmap=pcmap,
userdoc=output["userdoc"],
devdoc=output["devdoc"],
dev_messages=dev_messages,
)
contract_types.append(contract_type)
# Find content-specified dev messages.
dev_messages = {}
for line_no, line in content.items():
if match := re.search(DEV_MSG_PATTERN, line):
dev_messages[line_no] = match.group(1).strip()

contract_type = ContractType(
ast=ast,
contractName=name,
sourceId=source_id,
deploymentBytecode={"bytecode": evm["bytecode"]["object"]},
runtimeBytecode={"bytecode": bytecode["object"]},
abi=output["abi"],
sourcemap=compressed_src_map,
pcmap=pcmap,
userdoc=output["userdoc"],
devdoc=output["devdoc"],
dev_messages=dev_messages,
)
contract_types.append(contract_type)

return contract_types

def get_optimization_pragma_map(
self, contract_filepaths: List[Path], base_path: Optional[Path] = None
) -> Dict[Union[str, bool], Set[Path]]:
base_path = base_path or self.config_manager.contracts_folder
optimization_pragma_map: Dict[Union[str, bool], Set[Path]] = {}
for path in contract_filepaths:
pragma = get_optimization_pragma(path) or True
if pragma not in optimization_pragma_map:
optimization_pragma_map[pragma] = set()
optimization_pragma_map[pragma].add(path)

return optimization_pragma_map

def get_version_map(
self, contract_filepaths: List[Path], base_path: Optional[Path] = None
) -> Dict[Version, Set[Path]]:
Expand All @@ -374,7 +418,7 @@ def get_version_map(

# Sort contract_filepaths to promote consistent, reproduce-able behavior
for path in sorted(contract_filepaths):
if pragma := get_pragma_spec(path):
if pragma := get_version_pragma_spec(path):
_safe_append(source_path_by_pragma_spec, pragma, path)
else:
source_paths_without_pragma.add(path)
Expand Down Expand Up @@ -441,10 +485,6 @@ def get_compiler_settings(
continue

version_settings: Dict = {"optimize": True}
path_args = {
str(get_relative_path(p.absolute(), contracts_path)): p for p in source_paths
}
version_settings["outputSelection"] = {s: ["*"] for s in path_args}
if evm_version := data.get("evm_version"):
version_settings["evmVersion"] = evm_version

Expand Down Expand Up @@ -955,7 +995,7 @@ def _get_pcmap(bytecode: Dict) -> PCMap:
error_str = RuntimeErrorType.FALLBACK_NOT_DEFINED.value
use_loc = False
elif "bad calldatasize or callvalue" in error_type:
# Only on >=0.3.10rc3.
# Only on >=0.3.10.
# NOTE: We are no longer able to get Nonpayable checks errors since they
# are now combined.
error_str = RuntimeErrorType.INVALID_CALLDATA_OR_VALUE.value
Expand Down
2 changes: 1 addition & 1 deletion ape_vyper/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, **kwargs):

class InvalidCalldataOrValueError(VyperRuntimeError):
"""
Raises on Vyper versions >= 0.3.10rc3 in place of NonPayableError.
Raises on Vyper versions >= 0.3.10 in place of NonPayableError.
"""

def __init__(self, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"black>=23.3.0,<24", # Auto-formatter and linter
"mypy>=0.991,<1", # Static type analyzer
"types-setuptools", # Needed due to mypy typeshed
"pydantic<2.0", # Needed for successful type check. TODO: Remove after full v2 support.
"flake8>=6.0.0,<7", # Style linter
"isort>=5.10.1", # Import sorting linter
"mdformat>=0.7.16", # Auto-formatter for markdown
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@
"0.3.4",
"0.3.7",
"0.3.9",
"0.3.10rc3",
"0.3.10",
)

CONTRACT_VERSION_GEN_MAP = {
"": (
"0.3.7",
"0.3.9",
"0.3.10rc3",
"0.3.10",
),
"sub_reverts": ALL_VERSIONS,
}
Expand Down Expand Up @@ -188,7 +188,7 @@ def account():
return ape.accounts.test_accounts[0]


@pytest.fixture(params=("037", "039", "0310rc3"))
@pytest.fixture(params=("037", "039", "0310"))
def traceback_contract(request, account, project, geth_provider):
return _get_tb_contract(request.param, project, account)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# @version 0.3.9
# @version 0.3.10

# Test dev messages in various code placements
@external
Expand Down
8 changes: 8 additions & 0 deletions tests/contracts/passing_contracts/optimize_codesize.vy
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#pragma version 0.3.10
z80dev marked this conversation as resolved.
Show resolved Hide resolved
#pragma optimize codesize

x: uint256

@external
def __init__():
self.x = 0
7 changes: 7 additions & 0 deletions tests/contracts/passing_contracts/pragma_with_space.vy
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# pragma version 0.3.10

x: uint256

@external
def __init__():
self.x = 0
2 changes: 1 addition & 1 deletion tests/test_ape_reverts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def older_reverts_contract(account, project, geth_provider, request):
return container.deploy(sender=account)


@pytest.fixture(params=("037", "039", "0310rc3"))
@pytest.fixture(params=("037", "039", "0310"))
def reverts_contract_instance(account, project, geth_provider, request):
sub_reverts_container = project.get_contract(f"sub_reverts_{request.param}")
sub_reverts = sub_reverts_container.deploy(sender=account)
Expand Down
Loading
Loading