Skip to content

Commit

Permalink
refactor: use vyper bin for 4
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Jan 20, 2025
1 parent 62f14cc commit 1e9bdb1
Show file tree
Hide file tree
Showing 8 changed files with 237 additions and 151 deletions.
92 changes: 90 additions & 2 deletions ape_vyper/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import re
import subprocess
import time
from collections.abc import Iterable
from enum import Enum
Expand All @@ -13,8 +15,9 @@
from eth_utils import is_0x_prefixed
from ethpm_types import ASTNode, PCMap, SourceMapItem
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from vvm.exceptions import UnknownOption, UnknownValue

from ape_vyper.exceptions import RuntimeErrorType, VyperInstallError
from ape_vyper.exceptions import RuntimeErrorType, VyperError, VyperInstallError

if TYPE_CHECKING:
from ape.types.trace import SourceTraceback
Expand Down Expand Up @@ -277,7 +280,7 @@ def has_empty_revert(opcodes: list[str]) -> bool:


def get_pcmap(bytecode: dict) -> PCMap:
# Find the non payable value check.
# Find the non-payable value check.
src_info = bytecode["sourceMapFull"] if "sourceMapFull" in bytecode else bytecode["sourceMap"]
pc_data = {pc: {"location": ln} for pc, ln in src_info["pc_pos_map"].items()}
if not pc_data:
Expand Down Expand Up @@ -515,3 +518,88 @@ def _is_fallback_check(opcodes: list[str], op: str) -> bool:
and opcodes[6] == "SHR"
and opcodes[5] == "0xE0"
)


def compile_files(
binary: Path,
source_files: list[Path],
project_path: Path,
output_format: Optional[list[str]] = None,
additional_paths: Optional[list[Path]] = None,
**kwargs,
) -> dict[str, Any]:
"""
Borrowed (and modified) from vvm.
"""

command = [f"{binary}", *[str(f) for f in source_files]]
if output_format:
command.extend(("-f", ",".join(output_format)))

paths = [project_path, *(additional_paths or [])]
for path in paths:
command.extend(("-p", f"{path}"))

if kwargs:
command.extend(_kwargs_to_cli_options(**kwargs))

process = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if process.returncode != 0:
raise _handle_process_failure(process)

outputs = process.stdout.decode("utf-8").splitlines()
iter_length = len(output_format)
return {
f"{input_source}": dict(zip(output_format, outputs[i : i + iter_length]))
for i, input_source in zip(range(0, len(outputs), iter_length), source_files)
}


def _kwargs_to_cli_options(**kwargs) -> list[str]:
options = []
for key, value in kwargs.items():
if not value:
continue

key = f"-{key}" if len(key) == 1 else f"--{key.replace('_', '-')}"
if value is True:
# Is a flag.
options.append(key)

else:
# Has a value.
value = (
",".join([f"{v}" for v in value])
if isinstance(value, (list, tuple))
else f"{value}"
)
options.extend((key, value))

return options


def _to_string(key: str, value: Any) -> str:
if isinstance(value, (int, str)):
return str(value)

if isinstance(value, (list, tuple)):
return ",".join(_to_string(key, i) for i in value)


def _handle_process_failure(process) -> Exception:
bin_name = process.args[-1].split(os.path.sep)[-1].split("-")
stderr = process.stderr.decode("utf-8")
if stderr.startswith("unrecognised option"):
# unrecognised option '<FLAG>'
flag = stderr.split("'")[1]
return UnknownOption(f"{bin_name} does not support the '{flag}' option'")

if stderr.startswith("Invalid option"):
# Invalid option to <FLAG>: <OPTION>
flag, option = stderr.split(": ")
flag = flag.split(" ")[-1]
return UnknownValue(
f"{bin_name} does not accept '{option}' as an option for the '{flag}' flag"
)

return VyperError.from_process(process)
49 changes: 22 additions & 27 deletions ape_vyper/compiler/_versions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from ethpm_types.ast import ASTClassification
from ethpm_types.source import Content
from vvm import compile_standard as vvm_compile_standard # type: ignore
from vvm.exceptions import VyperError # type: ignore

from ape_vyper._utils import (
DEV_MSG_PATTERN,
Expand All @@ -22,13 +21,14 @@
get_pcmap,
)
from ape_vyper.compiler._versions.utils import output_details
from ape_vyper.exceptions import VyperCompileError
from ape_vyper.exceptions import VyperCompileError, VyperError

if TYPE_CHECKING:
from ape.managers.project import ProjectManager
from packaging.version import Version

from ape_vyper.compiler.api import VyperCompiler
from ape_vyper.config import VyperConfig
from ape_vyper.imports import ImportMap


Expand All @@ -40,6 +40,17 @@ class BaseVyperCompiler(ManagerAccessMixin):
def __init__(self, api: "VyperCompiler"):
self.api = api

@property
def config(self) -> "VyperConfig":
return self.config_manager.vyper # type: ignore

@property
def output_format(self) -> list[str]:
return self.config.output_format or ["*"]

def get_evm_version(self, version: "Version") -> str:
return self.config.evm_version or EVM_VERSION_DEFAULT.get(version.base_version)

def get_import_remapping(self, project: Optional["ProjectManager"] = None) -> dict[str, dict]:
# Overridden on 0.4 to not use.
# Import-remapping is for Vyper versions 0.2 - 0.3 to
Expand All @@ -52,7 +63,6 @@ def compile(
vyper_version: "Version",
settings: dict,
import_map: "ImportMap",
compiler_data: dict,
project: Optional["ProjectManager"] = None,
):
pm = project or self.local_project
Expand All @@ -78,7 +88,7 @@ def compile(
# Output compiler details.
output_details(*output_selection.keys(), version=vyper_version)

comp_kwargs = self._get_compile_kwargs(vyper_version, compiler_data, project=pm)
comp_kwargs = self._get_compile_kwargs(vyper_version, settings, project=pm)

here = Path.cwd()
if pm.path != here:
Expand Down Expand Up @@ -107,7 +117,7 @@ def compile(
evm = output["evm"]
bytecode = evm["deployedBytecode"]
opcodes = bytecode["opcodes"].split(" ")
compressed_src_map = self._parse_source_map(bytecode["sourceMap"])
compressed_src_map = SourceMap(root=bytecode["sourceMap"])
src_map = list(compressed_src_map.parse())[1:]
pcmap = self._get_pcmap(vyper_version, ast, src_map, opcodes, bytecode)

Expand Down Expand Up @@ -141,35 +151,32 @@ def compile(
yield contract_type, settings_key

def _parse_ast(self, ast: dict, content: Content) -> ASTNode:
ast = ASTNode.model_validate(ast)
self._classify_ast(ast)
ast_model = ASTNode.model_validate(ast)
self._classify_ast(ast_model)

# Track function offsets.
function_offsets = []
for node in ast.children:
for node in ast_model.children:
lineno = node.lineno

# NOTE: Constructor is handled elsewhere.
if node.ast_type == "FunctionDef" and "__init__" not in content.get(lineno, ""):
function_offsets.append((node.lineno, node.end_lineno))

return ASTNode
return ast_model

def get_settings(
self,
version: "Version",
source_paths: Iterable[Path],
compiler_data: dict,
project: Optional["ProjectManager"] = None,
) -> dict:
pm = project or self.local_project
default_optimization = self._get_default_optimization(version)
output_selection: dict[str, set[str]] = {}
optimizations_map = get_optimization_pragma_map(source_paths, pm.path, default_optimization)
evm_version_map = get_evm_version_pragma_map(source_paths, pm.path)
default_evm_version = compiler_data.get(
"evm_version", compiler_data.get("evmVersion")
) or EVM_VERSION_DEFAULT.get(version.base_version)
default_evm_version = self.get_evm_version(version)
for source_path in source_paths:
source_id = str(get_relative_path(source_path.absolute(), pm.path))

Expand Down Expand Up @@ -242,23 +249,18 @@ def _get_selection_dictionary(
def _get_compile_kwargs(
self,
vyper_version: "Version",
compiler_data: dict,
settings: dict,
project: Optional["ProjectManager"] = None,
) -> dict:
"""
Generate extra kwargs to pass to Vyper.
"""
pm = project or self.local_project
comp_kwargs = self._get_base_compile_kwargs(vyper_version, compiler_data)
comp_kwargs = self._get_base_compile_kwargs(vyper_version)
# `base_path` is required for pre-0.4 versions or else imports won't resolve.
comp_kwargs["base_path"] = pm.path
return comp_kwargs

def _get_base_compile_kwargs(self, vyper_version: "Version", compiler_data: dict):
vyper_binary = compiler_data[vyper_version]["vyper_binary"]
comp_kwargs = {"vyper_version": vyper_version, "vyper_binary": vyper_binary}
return comp_kwargs

def _get_pcmap(
self,
vyper_version: "Version",
Expand All @@ -272,13 +274,6 @@ def _get_pcmap(
"""
return get_pcmap(bytecode)

def _parse_source_map(self, raw_source_map: Any) -> SourceMap:
"""
Generate the SourceMap.
"""
# All versions < 0.4 use this one
return SourceMap(root=raw_source_map)

def _get_default_optimization(self, vyper_version: "Version") -> Optimization:
"""
The default value for "optimize" in the settings for input JSON.
Expand Down
12 changes: 12 additions & 0 deletions ape_vyper/compiler/_versions/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import re
from pathlib import Path
from typing import TYPE_CHECKING

from ape.logging import logger
from ape.utils.os import clean_path

from ape_vyper._utils import DEV_MSG_PATTERN

if TYPE_CHECKING:
from packaging.version import Version

Expand All @@ -12,3 +15,12 @@ def output_details(*source_ids: str, version: "Version"):
source_ids = "\n\t".join(sorted([clean_path(Path(x)) for x in source_ids]))
log_str = f"Compiling using Vyper compiler '{version}'.\nInput:\n\t{source_ids}"
logger.info(log_str)


def map_dev_messages(content: dict) -> dict:
dev_messages = {}
for line_no, line in content.items():
if match := re.search(DEV_MSG_PATTERN, line):
dev_messages[line_no] = match.group(1).strip()

return dev_messages
Loading

0 comments on commit 1e9bdb1

Please sign in to comment.