feat: support configured dependency cache directory #133

Merged · 9 commits · Feb 27, 2024
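
For orientation, here is a minimal sketch (not taken from this PR) of pointing ape-solidity at a custom dependency cache directory from Python. It leans on the `compile.cache_folder` option exercised in tests/test_compiler.py below and assumes an Ape project is already on disk and that `project.load_contracts` is available; the `.cache_deps` name is only illustrative.

from ape import project

# Read the `compile` plugin config and move the dependency cache away from
# the default `contracts/.cache` location (".cache_deps" is a made-up name).
compile_config = project.config_manager.get_config("compile")
original_cache_folder = compile_config.cache_folder
compile_config.cache_folder = project.path / ".cache_deps"

# Recompile; dependency manifests and import remappings are now rebuilt
# under the configured folder.
project.load_contracts(use_cache=False)

# Restore the old value afterwards, since the config object is stateful
# for the lifetime of the session.
compile_config.cache_folder = original_cache_folder
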
7 changes: 3 additions & 4 deletions ape_solidity/_utils.py
@@ -95,22 +95,21 @@ def package_id(self) -> Path:
class ImportRemappingBuilder:
def __init__(self, contracts_cache: Path):
# import_map maps import keys like `@openzeppelin/contracts`
# to str paths in the contracts' .cache folder.
# to str paths in the compiler cache folder.
self.import_map: Dict[str, str] = {}
self.dependencies_added: Set[Path] = set()
self.contracts_cache = contracts_cache

def add_entry(self, remapping: ImportRemapping):
path = remapping.package_id
if not str(path).startswith(f".cache{os.path.sep}"):
path = Path(".cache") / path
if self.contracts_cache not in path.parents:
path = self.contracts_cache / path

self.import_map[remapping.key] = str(path)


def get_import_lines(source_paths: Set[Path]) -> Dict[Path, List[str]]:
imports_dict: Dict[Path, List[str]] = {}

for filepath in source_paths:
import_set = set()
if not filepath.is_file():
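The `_utils.py` change above swaps the hard-coded `.cache` prefix check for a check against whatever cache folder was passed to the builder. Below is a simplified, self-contained sketch of that behaviour; the real `add_entry` takes an `ImportRemapping` object rather than a bare key/path pair.

from pathlib import Path
from typing import Dict, Set


class ImportRemappingBuilder:
    def __init__(self, contracts_cache: Path):
        self.import_map: Dict[str, str] = {}
        self.dependencies_added: Set[Path] = set()
        self.contracts_cache = contracts_cache

    def add_entry(self, key: str, package_id: Path):
        # Prefix with the configured cache folder only when the path is not
        # already inside it (simplified signature for illustration).
        path = package_id
        if self.contracts_cache not in path.parents:
            path = self.contracts_cache / path
        self.import_map[key] = str(path)


builder = ImportRemappingBuilder(Path(".cache_deps"))
builder.add_entry("@openzeppelin/contracts", Path("OpenZeppelin/v4.9.3"))
assert builder.import_map["@openzeppelin/contracts"] == str(Path(".cache_deps/OpenZeppelin/v4.9.3"))
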
23 changes: 15 additions & 8 deletions ape_solidity/compiler.py
@@ -221,9 +221,10 @@ def get_import_remapping(self, base_path: Optional[Path] = None) -> Dict[str, st
raise IncorrectMappingFormatError()

# We use these helpers to transform the values configured
# to values matching files in the `contracts/.cache` folder.
contracts_cache = base_path / ".cache"
builder = ImportRemappingBuilder(contracts_cache)
# to values matching files in the compiler cache folder.
builder = ImportRemappingBuilder(
get_relative_path(self.project_manager.compiler_cache_folder, base_path)
)
packages_cache = self.config_manager.packages_folder

# Here we hash and validate if there were changes to remappings.
Expand All @@ -233,7 +234,7 @@ def get_import_remapping(self, base_path: Optional[Path] = None) -> Dict[str, st
if (
self._import_remapping_hash
and self._import_remapping_hash == hash(remappings_tuple)
and contracts_cache.is_dir()
and self.project_manager.compiler_cache_folder.is_dir()
):
return self._cached_import_map

@@ -265,7 +266,7 @@ def get_import_remapping(self, base_path: Optional[Path] = None) -> Dict[str, st
data_folder_cache = packages_cache / package_id

# Re-build a downloaded dependency manifest into the .cache directory for imports.
sub_contracts_cache = contracts_cache / package_id
sub_contracts_cache = self.project_manager.compiler_cache_folder / package_id
if not sub_contracts_cache.is_dir() or not list(sub_contracts_cache.iterdir()):
cached_manifest_file = data_folder_cache / f"{remapping_obj.name}.json"
if not cached_manifest_file.is_file():
@@ -428,14 +429,16 @@ def _get_used_remappings(
# No remappings used at all.
return {}

relative_cache = get_relative_path(self.project_manager.compiler_cache_folder, base_path)

# Filter out unused import remapping.
return {
k: v
for source in (
x
for sources in self.get_imports(list(sources), base_path=base_path).values()
for x in sources
if x.startswith(".cache")
for sourceset in self.get_imports(list(sources), base_path=base_path).values()
for x in sourceset
if str(relative_cache) in x
)
for parent_key in (
os.path.sep.join(source.split(os.path.sep)[:3]) for source in [source]
@@ -471,6 +474,7 @@ def get_standard_input_json(
x: {"content": (base_path / x).read_text()}
for x in vers_settings["outputSelection"]
}

input_jsons[solc_version] = {
"sources": sources,
"settings": vers_settings,
@@ -496,6 +500,9 @@ def compile(
if solc_version >= Version("0.6.9"):
arguments["base_path"] = base_path

if self.project_manager.compiler_cache_folder.is_dir():
arguments["allow_paths"] = self.project_manager.compiler_cache_folder

# Allow empty contracts, like Vyper does.
arguments["allow_empty"] = True

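The `compile()` change above hands the configured cache folder to solc through `allow_paths`, so imports that resolve outside `base_path` remain readable. Below is a rough sketch of the resulting py-solc-x call, with illustrative paths and input rather than the plugin's exact code.

from pathlib import Path

import solcx

solcx.install_solc("0.8.17")  # make sure a compiler binary is available

base_path = Path.cwd()
cache_folder = base_path / ".cache_deps"  # hypothetical configured cache folder

arguments = {"base_path": base_path, "solc_version": "0.8.17"}
if cache_folder.is_dir():
    # Without this, solc refuses to read files outside base_path.
    arguments["allow_paths"] = cache_folder

input_json = {
    "language": "Solidity",
    "sources": {
        "Example.sol": {
            "content": "// SPDX-License-Identifier: MIT\npragma solidity ^0.8.17;\ncontract Example {}"
        }
    },
    "settings": {"outputSelection": {"*": {"*": ["abi"]}}},
}
output = solcx.compile_standard(input_json, **arguments)
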
2 changes: 1 addition & 1 deletion setup.py
@@ -69,7 +69,7 @@
include_package_data=True,
install_requires=[
"py-solc-x>=2.0.2,<3",
"eth-ape>=0.7.0,<0.8",
"eth-ape>=0.7.10,<0.8",
"ethpm-types", # Use the version ape requires
"eth-pydantic-types", # Use the version ape requires
"packaging", # Use the version ape requires
43 changes: 42 additions & 1 deletion tests/test_compiler.py
@@ -141,7 +141,12 @@ def test_compile_just_a_struct(compiler, project):


def test_get_imports(project, compiler):
import_dict = compiler.get_imports(TEST_CONTRACT_PATHS, BASE_PATH)
test_contract_paths = [
p
for p in project.contracts_folder.iterdir()
if ".cache" not in str(p) and not p.is_dir() and p.suffix == Extension.SOL.value
]
import_dict = compiler.get_imports(test_contract_paths, project.contracts_folder)
contract_imports = import_dict["Imports.sol"]
# NOTE: make sure there aren't duplicates
assert len([x for x in contract_imports if contract_imports.count(x) > 1]) == 0
@@ -161,6 +166,42 @@ def test_get_imports(project, compiler):
assert set(contract_imports) == expected


def test_get_imports_cache_folder(project, compiler):
"""Test imports when cache folder is configured"""
compile_config = project.config_manager.get_config("compile")
og_cache_colder = compile_config.cache_folder
compile_config.cache_folder = project.path / ".cash"
# assert False
test_contract_paths = [
p
for p in project.contracts_folder.iterdir()
if ".cache" not in str(p) and not p.is_dir() and p.suffix == Extension.SOL.value
]
# Using a different base path here because the cache folder is in the project root
import_dict = compiler.get_imports(test_contract_paths, project.path)
contract_imports = import_dict["contracts/Imports.sol"]
# NOTE: make sure there aren't duplicates
assert len([x for x in contract_imports if contract_imports.count(x) > 1]) == 0
# NOTE: returning a list
assert isinstance(contract_imports, list)
# NOTE: in case order changes
expected = {
".cash/BrownieDependency/local/BrownieContract.sol",
".cash/BrownieStyleDependency/local/BrownieStyleDependency.sol",
".cash/TestDependency/local/Dependency.sol",
".cash/gnosis/v1.3.0/common/Enum.sol",
"contracts/CompilesOnce.sol",
"contracts/MissingPragma.sol",
"contracts/NumerousDefinitions.sol",
"contracts/subfolder/Relativecontract.sol",
}
assert set(contract_imports) == expected

# Reset because this config is stateful across tests
compile_config.cache_folder = og_cache_folder
shutil.rmtree(og_cache_folder)


def test_get_imports_raises_when_non_solidity_files(compiler, vyper_source_path):
with raises_because_not_sol:
compiler.get_imports([vyper_source_path])