Skip to content

Commit

Permalink
feat: decimal config
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Sep 28, 2024
1 parent a554daf commit e592cf9
Show file tree
Hide file tree
Showing 12 changed files with 171 additions and 94 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,15 @@ Import the voting contract types like this:
import voting.ballot as ballot
```

### Decimals

To use decimals on Vyper 0.4, use the following config:

```yaml
vyper:
enable_decimals: true
```
### Pragmas
Ape-Vyper supports Vyper 0.3.10's [new pragma formats](https://github.com/vyperlang/vyper/pull/3493)
Expand Down
7 changes: 5 additions & 2 deletions tests/ape-config.yaml → ape-config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# Allows compiling to work from the project-level.
contracts_folder: contracts/passing_contracts
contracts_folder: tests/contracts/passing_contracts

# Specify a dependency to use in Vyper imports.
dependencies:
- name: exampledependency
local: ./ExampleDependency
local: ./tests/ExampleDependency

# NOTE: Snekmate does not need to be listed here since
# it is installed in site-packages. However, we include it
# to show it doesn't cause problems when included.
- python: snekmate
config_override:
contracts_folder: .

vyper:
enable_decimals: true
3 changes: 1 addition & 2 deletions ape_vyper/compiler/_versions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,7 @@ def get_settings(
optimization = False

selection_dict = self._get_selection_dictionary(selection, project=pm)
search_paths = [*getsitepackages()]
search_paths.append(".")
search_paths = [*getsitepackages(), "."]

version_settings[settings_key] = {
"optimize": optimization,
Expand Down
19 changes: 19 additions & 0 deletions ape_vyper/compiler/_versions/vyper_04.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,25 @@ def get_import_remapping(self, project: Optional[ProjectManager] = None) -> dict
# You always import via module or package name.
return {}

def get_settings(
self,
version: Version,
source_paths: Iterable[Path],
compiler_data: dict,
project: Optional[ProjectManager] = None,
) -> dict:
pm = project or self.local_project

enable_decimals = self.api.get_config(project=pm).enable_decimals
if enable_decimals is None:
enable_decimals = False

settings = super().get_settings(version, source_paths, compiler_data, project=pm)
for settings_set in settings.values():
settings_set["enable_decimals"] = enable_decimals

return settings

def _get_sources_dictionary(
self, source_ids: Iterable[str], project: Optional[ProjectManager] = None, **kwargs
) -> dict[str, dict]:
Expand Down
16 changes: 12 additions & 4 deletions ape_vyper/compiler/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,17 @@ def compile(
settings: Optional[dict] = None,
) -> Iterator[ContractType]:
pm = project or self.local_project

original_settings = self.compiler_settings
self.compiler_settings = {**self.compiler_settings, **(settings or {})}
try:
yield from self._compile(contract_filepaths, project=pm)
finally:
self.compiler_settings = original_settings

def _compile(
self, contract_filepaths: Iterable[Path], project: Optional[ProjectManager] = None
):
pm = project or self.local_project
contract_types: list[ContractType] = []
import_map = self._import_resolver.get_imports(pm, contract_filepaths)
config = self.get_config(pm)
Expand Down Expand Up @@ -514,12 +523,11 @@ def init_coverage_profile(
def enrich_error(self, err: ContractLogicError) -> ContractLogicError:
return enrich_error(err)

# TODO: In 0.9, make sure project is a kwarg here.
def trace_source(
self, contract_source: ContractSource, trace: TraceAPI, calldata: HexBytes
) -> SourceTraceback:
frames = trace.get_raw_frames()
tracer = SourceTracer(contract_source, frames, calldata)
return tracer.trace()
return SourceTracer.trace(trace.get_raw_frames(), contract_source, calldata)

def _get_compiler_arguments(
self,
Expand Down
7 changes: 7 additions & 0 deletions ape_vyper/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ class VyperConfig(PluginConfig):
"""

enable_decimals: Optional[bool] = None
"""
On Vyper 0.4, to use decimal types, you must enable it.
Defaults to ``None`` to avoid misleading that ``False``
means you cannot use decimals on a lower version.
"""

@field_validator("version", mode="before")
def validate_version(cls, value):
return pragma_str_to_specifier_set(value) if isinstance(value, str) else value
Expand Down
26 changes: 6 additions & 20 deletions ape_vyper/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def __init__(self, project: ProjectManager, paths: list[Path]):
# Even though we build up mappings of all sources, as may be referenced
# later on and that prevents re-calculating over again, we only
# "show" the items requested.
self._request_view: list[Path] = paths
self.paths: list[Path] = paths

def __getitem__(self, item: Union[str, Path], *args, **kwargs) -> list[Import]:
if isinstance(item, str) or not item.is_absolute():
Expand Down Expand Up @@ -294,7 +294,7 @@ def keys(self) -> list[Path]: # type: ignore
result = []
keys = sorted(list(super().keys()))
for path in keys:
if path not in self._request_view:
if path not in self.paths:
continue

result.append(path)
Expand All @@ -311,7 +311,7 @@ def values(self) -> list[list[Import]]: # type: ignore
def items(self) -> list[tuple[Path, list[Import]]]: # type: ignore
result = []
for path in self.keys(): # sorted
if path not in self._request_view:
if path not in self.paths:
continue

result.append((path, self[path]))
Expand All @@ -328,30 +328,16 @@ class ImportResolver(ManagerAccessMixin):
_projects: dict[str, ImportMap] = {}
_dependency_attempted_compile: set[str] = set()

def get_imports(
self,
project: ProjectManager,
contract_filepaths: Iterable[Path],
) -> ImportMap:
def get_imports(self, project: ProjectManager, contract_filepaths: Iterable[Path]) -> ImportMap:
paths = list(contract_filepaths)
reset_view = None
if project.project_id not in self._projects:
self._projects[project.project_id] = ImportMap(project, paths)
else:
# Change the items we "view". Some (or all) may need to be added as well.
reset_view = self._projects[project.project_id]._request_view
self._projects[project.project_id]._request_view = paths

try:
import_map = self._get_imports(paths, project)
finally:
if reset_view is not None:
self._projects[project.project_id]._request_view = reset_view

return import_map
return self._get_imports(paths, project)

def _get_imports(self, paths: list[Path], project: ProjectManager) -> ImportMap:
import_map = self._projects[project.project_id]
import_map.paths = list({*import_map.paths, *paths})
for path in paths:
if path in import_map:
# Already handled.
Expand Down
74 changes: 39 additions & 35 deletions ape_vyper/traceback.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path
from typing import Optional, cast

from ape.managers import ProjectManager
from ape.types import SourceTraceback
from ape.utils import ManagerAccessMixin, get_full_extension
from eth_pydantic_types import HexBytes
Expand All @@ -20,58 +21,57 @@ class SourceTracer(ManagerAccessMixin):
Use EVM data to create a trace of Vyper source lines.
"""

def __init__(self, contract_source: ContractSource, frames: Iterator[dict], calldata: HexBytes):
self.contract_source = contract_source
self.frames = frames
self.calldata = calldata

@classmethod
def trace(
self,
contract: Optional[ContractSource] = None,
calldata: Optional[HexBytes] = None,
cls,
frames: Iterator[dict],
contract: ContractSource,
calldata: HexBytes,
previous_depth: Optional[int] = None,
project: Optional[ProjectManager] = None,
) -> SourceTraceback:
contract_source = self.contract_source if contract is None else contract
calldata = self.calldata if calldata is None else calldata
pm = project or cls.local_project
method_id = HexBytes(calldata[:4])
traceback = SourceTraceback.model_validate([])
completed = False
pcmap = PCMap.model_validate({})

for frame in self.frames:
for frame in frames:
if frame["op"] in [c.value for c in CALL_OPCODES]:
start_depth = frame["depth"]
called_contract, sub_calldata = self._create_contract_from_call(frame)
called_contract, sub_calldata = cls._create_contract_from_call(frame, project=pm)
if called_contract:
ext = get_full_extension(Path(called_contract.source_id))
if ext in [x for x in FileType]:
# Called another Vyper contract.
sub_trace = self.trace(
contract=called_contract,
calldata=sub_calldata,
sub_trace = cls.trace(
frames,
called_contract,
sub_calldata,
previous_depth=frame["depth"],
project=pm,
)
traceback.extend(sub_trace)

else:
# Not a Vyper contract!
compiler = self.compiler_manager.registered_compilers[ext]
compiler = cls.compiler_manager.registered_compilers[ext]
try:
sub_trace = compiler.trace_source(
called_contract.contract_type, self.frames, sub_calldata
called_contract.contract_type, frames, sub_calldata
)
traceback.extend(sub_trace)
except NotImplementedError:
# Compiler not supported. Fast forward out of this call.
for fr in self.frames:
for fr in frames:
if fr["depth"] <= start_depth:
break

continue

else:
# Contract not found. Fast forward out of this call.
for fr in self.frames:
for fr in frames:
if fr["depth"] <= start_depth:
break

Expand All @@ -83,14 +83,14 @@ def trace(
completed = previous_depth is not None

pcs_to_try_adding = set()
if "PUSH" in frame["op"] and frame["pc"] in contract_source.pcmap:
if "PUSH" in frame["op"] and frame["pc"] in contract.pcmap:
# Check if next op is SSTORE to properly use AST from push op.
next_frame: Optional[dict] = frame
loc = contract_source.pcmap[frame["pc"]]
loc = contract.pcmap[frame["pc"]]
pcs_to_try_adding.add(frame["pc"])

while next_frame and "PUSH" in next_frame["op"]:
next_frame = next(self.frames, None)
next_frame = next(frames, None)
if next_frame and "PUSH" in next_frame["op"]:
pcs_to_try_adding.add(next_frame["pc"])

Expand All @@ -103,15 +103,15 @@ def trace(
completed = True

else:
pcmap = contract_source.pcmap
pcmap = contract.pcmap
dev_val = str((loc.get("dev") or "")).replace("dev: ", "")
is_non_payable_hit = dev_val == RuntimeErrorType.NONPAYABLE_CHECK.value

if not is_non_payable_hit and next_frame:
frame = next_frame

else:
pcmap = contract_source.pcmap
pcmap = contract.pcmap

pcs_to_try_adding.add(frame["pc"])
pcs_to_try_adding = {pc for pc in pcs_to_try_adding if pc in pcmap}
Expand Down Expand Up @@ -147,7 +147,7 @@ def trace(
# New group.
pc_groups.append([location, {pc}, dev])

dev_messages = contract_source.contract_type.dev_messages or {}
dev_messages = contract.contract_type.dev_messages or {}
for location, pcs, dev in pc_groups:
if dev in [m.value for m in RuntimeErrorType if m != RuntimeErrorType.USER_ASSERT]:
error_type = RuntimeErrorType(dev)
Expand All @@ -160,9 +160,9 @@ def trace(
name = traceback.last.closure.name
full_name = traceback.last.closure.full_name

elif method_id in contract_source.contract_type.methods:
elif method_id in contract.contract_type.methods:
# For non-payable checks, they should hit here.
method_checked = contract_source.contract_type.methods[method_id]
method_checked = contract.contract_type.methods[method_id]
name = method_checked.name
full_name = method_checked.selector

Expand All @@ -186,15 +186,15 @@ def trace(
f"dev: {dev}",
full_name=full_name,
pcs=pcs,
source_path=contract_source.source_path,
source_path=contract.source_path,
)
continue

elif not location:
# Unknown.
continue

if not (function := contract_source.lookup_function(location, method_id=method_id)):
if not (function := contract.lookup_function(location, method_id=method_id)):
continue

if (
Expand All @@ -213,7 +213,7 @@ def trace(
function,
depth,
pcs=pcs,
source_path=contract_source.source_path,
source_path=contract.source_path,
)
else:
traceback.extend_last(location, pcs=pcs)
Expand All @@ -235,20 +235,24 @@ def trace(

return traceback

def _create_contract_from_call(self, frame: dict) -> tuple[Optional[ContractSource], HexBytes]:
@classmethod
def _create_contract_from_call(
cls, frame: dict, project: Optional[ProjectManager] = None
) -> tuple[Optional[ContractSource], HexBytes]:
pm = project or cls.local_project
evm_frame = TraceFrame(**frame)
data = create_call_node_data(evm_frame)
calldata = data.get("calldata", HexBytes(""))
if not (address := (data.get("address", evm_frame.contract_address) or None)):
return None, calldata

try:
address = self.provider.network.ecosystem.decode_address(address)
address = cls.provider.network.ecosystem.decode_address(address)
except Exception:
return None, calldata

if address not in self.chain_manager.contracts:
if address not in cls.chain_manager.contracts:
return None, calldata

called_contract = self.chain_manager.contracts[address]
return self.local_project._create_contract_source(called_contract), calldata
called_contract = cls.chain_manager.contracts[address]
return pm._create_contract_source(called_contract), calldata
Loading

0 comments on commit e592cf9

Please sign in to comment.