Skip to content

Commit

Permalink
fix: handle implicit relative imports (#128)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Jul 23, 2024
1 parent 86a607f commit 6d05c26
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ repos:
- id: flake8

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.1
rev: v1.11.0
hooks:
- id: mypy
additional_dependencies: [types-setuptools, pydantic==1.10.4]
Expand Down
74 changes: 61 additions & 13 deletions ape_vyper/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def _get_imports(
if not path.is_file():
continue

content = path.read_text().splitlines()
content = path.read_text(encoding="utf8").splitlines()
source_id = (
str(path.absolute())
if use_absolute_paths
Expand Down Expand Up @@ -410,7 +410,10 @@ def _get_imports(
dots += prefix[0]
prefix = prefix[1:]

is_relative = dots != ""
is_relative: Optional[bool] = None
if dots != "":
is_relative = True
# else: we are unsure since dots are not required.

# Replace rest of dots with slashes.
prefix = prefix.replace(".", os.path.sep)
Expand All @@ -421,25 +424,70 @@ def _get_imports(

continue

local_path = (
(path.parent / dots / prefix.lstrip(os.path.sep)).resolve()
if is_relative
else (pm.path / prefix.lstrip(os.path.sep)).resolve()
relative_path = None
abs_path = None
if is_relative is True:
relative_path = (path.parent / dots / prefix.lstrip(os.path.sep)).resolve()
elif is_relative is False:
abs_path = (pm.path / prefix.lstrip(os.path.sep)).resolve()
elif is_relative is None:
relative_path = (path.parent / dots / prefix.lstrip(os.path.sep)).resolve()
abs_path = (pm.path / prefix.lstrip(os.path.sep)).resolve()

local_prefix_relative = (
None
if relative_path is None
else str(relative_path).replace(f"{pm.path}", "").lstrip(os.path.sep)
)
local_prefix_abs = (
None
if abs_path is None
else str(abs_path).replace(f"{pm.path}", "").lstrip(os.path.sep)
)
local_prefix = str(local_path).replace(f"{pm.path}", "").lstrip(os.path.sep)

import_source_id = None
is_local = True
local_path = None # TBD
local_prefix = None # TBD

# NOTE: Defaults to JSON (assuming from input JSON or a local JSON),
# unless a Vyper file exists.
if (pm.path / f"{local_prefix}{FileType.SOURCE}").is_file():
if (pm.path / f"{local_prefix_relative}{FileType.SOURCE}").is_file():
# Relative source.
ext = FileType.SOURCE.value
elif (pm.path / f"{local_prefix}{FileType.SOURCE}").is_file():
local_path = relative_path
local_prefix = local_prefix_relative

elif (pm.path / f"{local_prefix_relative}{FileType.INTERFACE}").is_file():
# Relative interface.
ext = FileType.INTERFACE.value
elif (pm.path / f"{local_prefix}{FileType.INTERFACE}").is_file():
local_path = relative_path
local_prefix = local_prefix_relative

elif (pm.path / f"{local_prefix_relative}.json").is_file():
# Relative JSON interface.
ext = ".json"
local_path = relative_path
local_prefix = local_prefix_relative

elif (pm.path / f"{local_prefix_abs}{FileType.SOURCE}").is_file():
# Absolute source.
ext = FileType.SOURCE.value
local_path = abs_path
local_prefix = local_prefix_abs

elif (pm.path / f"{local_prefix_abs}{FileType.INTERFACE}").is_file():
# Absolute interface.
ext = FileType.INTERFACE.value
local_path = abs_path
local_prefix = local_prefix_abs

elif (pm.path / f"{local_prefix_abs}.json").is_file():
# Absolute JSON interface.
ext = ".json"
local_path = abs_path
local_prefix = local_prefix_abs

else:
# Must be an interface JSON specified in the input JSON.
ext = ".json"
dep_key = prefix.split(os.path.sep)[0]
dependency_name = prefix.split(os.path.sep)[0]
Expand Down Expand Up @@ -488,7 +536,7 @@ def _get_imports(

is_local = False

if is_local:
if is_local and local_prefix is not None and local_path is not None:
import_source_id = f"{local_prefix}{ext}"
full_path = local_path.parent / f"{local_path.stem}{ext}"

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
],
"lint": [
"black>=24.4.2,<25", # Auto-formatter and linter
"mypy>=1.10.1,<2", # Static type analyzer
"mypy>=1.11.0,<2", # Static type analyzer
"types-setuptools", # Needed due to mypy typeshed
"flake8>=7.1.0,<8", # Style linter
"isort>=5.13.2", # Import sorting linter
Expand Down
11 changes: 11 additions & 0 deletions tests/contracts/passing_contracts/zero_four_snekmate_erc20.vy
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# pragma version ~=0.4.0
from snekmate.auth import ownable
from snekmate.tokens import erc20

initializes: ownable
initializes: erc20[ownable := ownable]

@deploy
def __init__(_name: String[25]):
ownable.__init__()
erc20.__init__(_name, "ERC20", 18, "name", "name2")
3 changes: 2 additions & 1 deletion tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def test_get_version_map(project, compiler, all_versions):
"zero_four.vy",
"zero_four_module.vy",
"zero_four_module_2.vy",
"zero_four_snekmate_erc20.vy",
}
assert actual4 == expected4

Expand Down Expand Up @@ -397,7 +398,7 @@ def test_get_imports(compiler, project):

actual_iface_use = actual[use_iface_key]
for expected in (local_import, local_from_import, dependency_import, local_nested_import):
assert any(k for k in actual_iface_use if expected in k)
assert any(k for k in actual_iface_use if expected in k), f"{expected} not found"

assert actual[use_iface2_key][0].endswith(local_import)

Expand Down

0 comments on commit 6d05c26

Please sign in to comment.