diff --git a/substratools/utils.py b/substratools/utils.py index 02cf5da..2f05b99 100644 --- a/substratools/utils.py +++ b/substratools/utils.py @@ -101,9 +101,17 @@ def load_interface_from_module(module_name, interface_class, interface_signature ) # find interface class + found_interfaces = [] for _, obj in inspect.getmembers(module, inspect.isclass): - if issubclass(obj, interface_class): - return obj() # return interface instance + if issubclass(obj, interface_class) and obj != interface_class: + found_interfaces.append(obj) + + if len(found_interfaces) == 1: + return found_interfaces[0]() # return interface instance + elif len(found_interfaces) > 1: + raise exceptions.InvalidInterfaceError( + f"Multiple interfaces found in module '{module_name}': {found_interfaces}" + ) # backward compatibility; accept methods at module level directly if interface_signature is None: diff --git a/tests/test_opener.py b/tests/test_opener.py index 58e8869..8c235d2 100644 --- a/tests/test_opener.py +++ b/tests/test_opener.py @@ -74,6 +74,36 @@ def test_load_opener_from_path(tmp_cwd, valid_opener_code): assert o.get_data()[0] == "X" +def test_load_opener_from_path_error_with_inheritance(tmp_cwd): + wrong_opener_code = """ +import json +from substratools import Opener + +class FakeOpener(Opener): + def get_data(self, folder): + return 'X', list(range(0, 3)) + + def fake_data(self, n_samples): + return ['Xfake'] * n_samples, [0] * n_samples + +class FinalOpener(FakeOpener): + def __init__(self): + super().__init__() +""" + dirpath = tmp_cwd / "myopener" + dirpath.mkdir() + path = dirpath / "my_opener.py" + path.write_text(wrong_opener_code) + + with pytest.raises(exceptions.InvalidInterfaceError): + load_interface_from_module( + "opener", + interface_class=Opener, + interface_signature=None, # XXX does not support interface for debugging + path=path, + ) + + def test_opener_check_folders(tmp_cwd): script = """ from substratools import Opener