Skip to content

Commit

Permalink
feat[lang]: support flags from imported interfaces (vyperlang#4253)
Browse files Browse the repository at this point in the history
this commit allows flag types to be imported from `.vyi` interface
files.

---------

Co-authored-by: Charles Cooper <[email protected]>
  • Loading branch information
tserg and charles-cooper authored Oct 15, 2024
1 parent 6bcdb00 commit 990a6fa
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 28 deletions.
28 changes: 28 additions & 0 deletions tests/functional/codegen/modules/test_interface_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,31 @@ def foo() -> bool:
c = get_contract(main, input_bundle=input_bundle)

assert c.foo() is True


def test_import_interface_flags(make_input_bundle, get_contract):
ifaces = """
flag Foo:
BOO
MOO
POO
interface IFoo:
def foo() -> Foo: nonpayable
"""

contract = """
import ifaces
implements: ifaces
@external
def foo() -> ifaces.Foo:
return ifaces.Foo.POO
"""

input_bundle = make_input_bundle({"ifaces.vyi": ifaces})

c = get_contract(contract, input_bundle=input_bundle)

assert c.foo() == 4
57 changes: 29 additions & 28 deletions vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from vyper.semantics.types.base import TYPE_T, VyperType, is_type_t
from vyper.semantics.types.function import ContractFunctionT
from vyper.semantics.types.primitives import AddressT
from vyper.semantics.types.user import EventT, StructT, _UserType
from vyper.semantics.types.user import EventT, FlagT, StructT, _UserType
from vyper.utils import OrderedSet

if TYPE_CHECKING:
Expand All @@ -45,27 +45,29 @@ def __init__(
functions: dict,
events: dict,
structs: dict,
flags: dict,
) -> None:
validate_unique_method_ids(list(functions.values()))

members = functions | events | structs
members = functions | events | structs | flags

# sanity check: by construction, there should be no duplicates.
assert len(members) == len(functions) + len(events) + len(structs)
assert len(members) == len(functions) + len(events) + len(structs) + len(flags)

super().__init__(functions)

self._helper = VyperType(events | structs)
self._helper = VyperType(events | structs | flags)
self._id = _id
self._helper._id = _id
self.functions = functions
self.events = events
self.structs = structs
self.flags = flags

self.decl_node = decl_node

def get_type_member(self, attr, node):
# get an event or struct from this interface
# get an event, struct or flag from this interface
return TYPE_T(self._helper.get_member(attr, node))

@property
Expand Down Expand Up @@ -159,12 +161,14 @@ def _from_lists(
interface_name: str,
decl_node: Optional[vy_ast.VyperNode],
function_list: list[tuple[str, ContractFunctionT]],
event_list: list[tuple[str, EventT]],
struct_list: list[tuple[str, StructT]],
event_list: Optional[list[tuple[str, EventT]]] = None,
struct_list: Optional[list[tuple[str, StructT]]] = None,
flag_list: Optional[list[tuple[str, FlagT]]] = None,
) -> "InterfaceT":
functions = {}
events = {}
structs = {}
functions: dict[str, ContractFunctionT] = {}
events: dict[str, EventT] = {}
structs: dict[str, StructT] = {}
flags: dict[str, FlagT] = {}

seen_items: dict = {}

Expand All @@ -175,19 +179,20 @@ def _mark_seen(name, item):
raise NamespaceCollision(msg, item.decl_node, prev_decl=prev_decl)
seen_items[name] = item

for name, function in function_list:
_mark_seen(name, function)
functions[name] = function
def _process(dst_dict, items):
if items is None:
return

for name, event in event_list:
_mark_seen(name, event)
events[name] = event
for name, item in items:
_mark_seen(name, item)
dst_dict[name] = item

for name, struct in struct_list:
_mark_seen(name, struct)
structs[name] = struct
_process(functions, function_list)
_process(events, event_list)
_process(structs, struct_list)
_process(flags, flag_list)

return cls(interface_name, decl_node, functions, events, structs)
return cls(interface_name, decl_node, functions, events, structs, flags)

@classmethod
def from_json_abi(cls, name: str, abi: dict) -> "InterfaceT":
Expand All @@ -214,8 +219,7 @@ def from_json_abi(cls, name: str, abi: dict) -> "InterfaceT":
for item in [i for i in abi if i.get("type") == "event"]:
events.append((item["name"], EventT.from_abi(item)))

structs: list = [] # no structs in json ABI (as of yet)
return cls._from_lists(name, None, functions, events, structs)
return cls._from_lists(name, None, functions, events)

@classmethod
def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT":
Expand Down Expand Up @@ -247,8 +251,9 @@ def from_ModuleT(cls, module_t: "ModuleT") -> "InterfaceT":
# these are accessible via import, but they do not show up
# in the ABI json
structs = [(node.name, node._metadata["struct_type"]) for node in module_t.struct_defs]
flags = [(node.name, node._metadata["flag_type"]) for node in module_t.flag_defs]

return cls._from_lists(module_t._id, module_t.decl_node, funcs, events, structs)
return cls._from_lists(module_t._id, module_t.decl_node, funcs, events, structs, flags)

@classmethod
def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT":
Expand All @@ -265,11 +270,7 @@ def from_InterfaceDef(cls, node: vy_ast.InterfaceDef) -> "InterfaceT":
)
functions.append((func_ast.name, ContractFunctionT.from_InterfaceDef(func_ast)))

# no structs or events in InterfaceDefs
events: list = []
structs: list = []

return cls._from_lists(node.name, node, functions, events, structs)
return cls._from_lists(node.name, node, functions)


# Datatype to store all module information.
Expand Down

0 comments on commit 990a6fa

Please sign in to comment.