Skip to content

Commit

Permalink
chore: raise if several interfaces are found (#113)
Browse files Browse the repository at this point in the history
Signed-off-by: ThibaultFy <[email protected]>
  • Loading branch information
ThibaultFy authored Oct 3, 2024
1 parent d0303e4 commit 0c3f42d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
12 changes: 10 additions & 2 deletions substratools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,17 @@ def load_interface_from_module(module_name, interface_class, interface_signature
)

# find interface class
found_interfaces = []
for _, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, interface_class):
return obj() # return interface instance
if issubclass(obj, interface_class) and obj != interface_class:
found_interfaces.append(obj)

if len(found_interfaces) == 1:
return found_interfaces[0]() # return interface instance
elif len(found_interfaces) > 1:
raise exceptions.InvalidInterfaceError(
f"Multiple interfaces found in module '{module_name}': {found_interfaces}"
)

# backward compatibility; accept methods at module level directly
if interface_signature is None:
Expand Down
30 changes: 30 additions & 0 deletions tests/test_opener.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,36 @@ def test_load_opener_from_path(tmp_cwd, valid_opener_code):
assert o.get_data()[0] == "X"


def test_load_opener_from_path_error_with_inheritance(tmp_cwd):
wrong_opener_code = """
import json
from substratools import Opener
class FakeOpener(Opener):
def get_data(self, folder):
return 'X', list(range(0, 3))
def fake_data(self, n_samples):
return ['Xfake'] * n_samples, [0] * n_samples
class FinalOpener(FakeOpener):
def __init__(self):
super().__init__()
"""
dirpath = tmp_cwd / "myopener"
dirpath.mkdir()
path = dirpath / "my_opener.py"
path.write_text(wrong_opener_code)

with pytest.raises(exceptions.InvalidInterfaceError):
load_interface_from_module(
"opener",
interface_class=Opener,
interface_signature=None, # XXX does not support interface for debugging
path=path,
)


def test_opener_check_folders(tmp_cwd):
script = """
from substratools import Opener
Expand Down

0 comments on commit 0c3f42d

Please sign in to comment.