Skip to content

Commit

Permalink
feat: support new pragma formats
Browse files Browse the repository at this point in the history
  • Loading branch information
z80dev committed Oct 24, 2023
1 parent 798874a commit f0b9533
Showing 1 changed file with 116 additions and 79 deletions.
195 changes: 116 additions & 79 deletions ape_vyper/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def _install_vyper(version: Version):
) from err


def get_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
def get_version_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
"""
Extracts pragma information from Vyper source code.
Extracts version pragma information from Vyper source code.
Args:
source (str): Vyper source code
Expand All @@ -81,7 +81,10 @@ def get_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
source_str = source if isinstance(source, str) else source.read_text()
pragma_match = next(re.finditer(r"(?:\n|^)\s*#\s*@version\s*([^\n]*)", source_str), None)
if pragma_match is None:
return None # Try compiling with latest
# support new pragma syntax
pragma_match = next(re.finditer(r"(?:\n|^)\s*#pragma\s+version\s*([^\n]*)", source_str), None)
if pragma_match is None:
return None # Try compiling with latest

raw_pragma = pragma_match.groups()[0]
pragma_str = " ".join(raw_pragma.split()).replace("^", "~=")
Expand All @@ -94,6 +97,22 @@ def get_pragma_spec(source: Union[str, Path]) -> Optional[SpecifierSet]:
logger.warning(f"Invalid pragma spec: '{raw_pragma}'. Trying latest.")
return None

def get_optimization_pragma(source: Union[str, Path]) -> Optional[str | bool]:
"""
Extracts optimization pragma information from Vyper source code.
Args:
source (str): Vyper source code
Returns:
``str``, or True if no valid pragma is found (for backwards compatibility).
"""
source_str = source if isinstance(source, str) else source.read_text()
pragma_match = next(re.finditer(r"(?:\n|^)\s*#pragma\s+optimize\s+([^\n]*)", source_str), None)
if pragma_match is None:
return True
return pragma_match.groups()[0]



class VyperCompiler(CompilerAPI):
@property
Expand Down Expand Up @@ -145,7 +164,7 @@ def get_imports(
def get_versions(self, all_paths: List[Path]) -> Set[str]:
versions = set()
for path in all_paths:
if version_spec := get_pragma_spec(path):
if version_spec := get_version_pragma_spec(path):
try:
# Make sure we have the best compiler available to compile this
version_iter = version_spec.filter(self.available_versions)
Expand Down Expand Up @@ -283,88 +302,106 @@ def compile(
for vyper_version, source_paths in version_map.items():
settings = all_settings.get(vyper_version, {})
path_args = {str(get_relative_path(p.absolute(), base_path)): p for p in source_paths}
input_json = {
"language": "Vyper",
"settings": settings,
"sources": {s: {"content": p.read_text()} for s, p in path_args.items()},
}
if interfaces := self.import_remapping:
input_json["interfaces"] = interfaces

vyper_binary = compiler_data[vyper_version]["vyper_binary"]
try:
result = vvm.compile_standard(
input_json,
base_path=base_path,
vyper_version=vyper_version,
vyper_binary=vyper_binary,
)
except VyperError as err:
raise VyperCompileError(err) from err
optimizations_map = self.get_optimization_pragma_map(list(source_paths))

def classify_ast(_node: ASTNode):
if _node.ast_type in _FUNCTION_AST_TYPES:
_node.classification = ASTClassification.FUNCTION
for optimization, paths in optimizations_map.items():
input_json = {
"language": "Vyper",
"settings": settings,
"sources": {s: {"content": p.read_text()} for s, p in path_args.items()},
}

for child in _node.children:
classify_ast(child)
input_json["settings"]["optimize"] = optimization
if interfaces := self.import_remapping:
input_json["interfaces"] = interfaces

for source_id, output_items in result["contracts"].items():
content = {
i + 1: ln
for i, ln in enumerate((base_path / source_id).read_text().splitlines())
}
for name, output in output_items.items():
# De-compress source map to get PC POS map.
ast = ASTNode.parse_obj(result["sources"][source_id]["ast"])
classify_ast(ast)

# Track function offsets.
function_offsets = []
for node in ast.children:
lineno = node.lineno

# NOTE: Constructor is handled elsewhere.
if node.ast_type == "FunctionDef" and "__init__" not in content.get(
lineno, ""
):
function_offsets.append((node.lineno, node.end_lineno))

evm = output["evm"]
bytecode = evm["deployedBytecode"]
opcodes = bytecode["opcodes"].split(" ")
compressed_src_map = SourceMap(__root__=bytecode["sourceMap"])
src_map = list(compressed_src_map.parse())[1:]

pcmap = (
_get_legacy_pcmap(ast, src_map, opcodes)
if vyper_version <= Version("0.3.7")
else _get_pcmap(bytecode)
vyper_binary = compiler_data[vyper_version]["vyper_binary"]
try:
result = vvm.compile_standard(
input_json,
base_path=base_path,
vyper_version=vyper_version,
vyper_binary=vyper_binary,
)
except VyperError as err:
raise VyperCompileError(err) from err

def classify_ast(_node: ASTNode):
if _node.ast_type in _FUNCTION_AST_TYPES:
_node.classification = ASTClassification.FUNCTION

for child in _node.children:
classify_ast(child)

for source_id, output_items in result["contracts"].items():
content = {
i + 1: ln
for i, ln in enumerate((base_path / source_id).read_text().splitlines())
}
for name, output in output_items.items():
# De-compress source map to get PC POS map.
ast = ASTNode.parse_obj(result["sources"][source_id]["ast"])
classify_ast(ast)

# Track function offsets.
function_offsets = []
for node in ast.children:
lineno = node.lineno

# NOTE: Constructor is handled elsewhere.
if node.ast_type == "FunctionDef" and "__init__" not in content.get(
lineno, ""
):
function_offsets.append((node.lineno, node.end_lineno))

evm = output["evm"]
bytecode = evm["deployedBytecode"]
opcodes = bytecode["opcodes"].split(" ")
compressed_src_map = SourceMap(__root__=bytecode["sourceMap"])
src_map = list(compressed_src_map.parse())[1:]

pcmap = (
_get_legacy_pcmap(ast, src_map, opcodes)
if vyper_version <= Version("0.3.7")
else _get_pcmap(bytecode)
)

# Find content-specified dev messages.
dev_messages = {}
for line_no, line in content.items():
if match := re.search(DEV_MSG_PATTERN, line):
dev_messages[line_no] = match.group(1).strip()

contract_type = ContractType(
ast=ast,
contractName=name,
sourceId=source_id,
deploymentBytecode={"bytecode": evm["bytecode"]["object"]},
runtimeBytecode={"bytecode": bytecode["object"]},
abi=output["abi"],
sourcemap=compressed_src_map,
pcmap=pcmap,
userdoc=output["userdoc"],
devdoc=output["devdoc"],
dev_messages=dev_messages,
)
contract_types.append(contract_type)
# Find content-specified dev messages.
dev_messages = {}
for line_no, line in content.items():
if match := re.search(DEV_MSG_PATTERN, line):
dev_messages[line_no] = match.group(1).strip()

contract_type = ContractType(
ast=ast,
contractName=name,
sourceId=source_id,
deploymentBytecode={"bytecode": evm["bytecode"]["object"]},
runtimeBytecode={"bytecode": bytecode["object"]},
abi=output["abi"],
sourcemap=compressed_src_map,
pcmap=pcmap,
userdoc=output["userdoc"],
devdoc=output["devdoc"],
dev_messages=dev_messages,
)
contract_types.append(contract_type)

return contract_types

def get_optimization_pragma_map(
self, contract_filepaths: List[Path], base_path: Optional[Path] = None
) -> Dict[str, Set[Path]]:
base_path = base_path or self.config_manager.contracts_folder
optimization_pragma_map: Dict[str, Set[Path]] = {}
for path in contract_filepaths:
if pragma := get_optimization_pragma(path):
if pragma not in optimization_pragma_map:
optimization_pragma_map[pragma] = set()
optimization_pragma_map[pragma].add(path)

return optimization_pragma_map

def get_version_map(
self, contract_filepaths: List[Path], base_path: Optional[Path] = None
) -> Dict[Version, Set[Path]]:
Expand All @@ -374,7 +411,7 @@ def get_version_map(

# Sort contract_filepaths to promote consistent, reproduce-able behavior
for path in sorted(contract_filepaths):
if pragma := get_pragma_spec(path):
if pragma := get_version_pragma_spec(path):
_safe_append(source_path_by_pragma_spec, pragma, path)
else:
source_paths_without_pragma.add(path)
Expand Down

0 comments on commit f0b9533

Please sign in to comment.