Skip to content

Commit

Permalink
Simplify items package scanning (#128)
Browse files Browse the repository at this point in the history
  • Loading branch information
soininen authored Feb 15, 2024
2 parents b4bedc3 + 6c7f229 commit 2542f77
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 38 deletions.
41 changes: 3 additions & 38 deletions spine_engine/load_project_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,8 @@
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################

"""
Functions to load project item modules.
"""
import pathlib
""" Functions to load project item modules. """
import importlib
import importlib.util


def load_item_specification_factories(items_package_name):
Expand All @@ -30,21 +24,7 @@ def load_item_specification_factories(items_package_name):
dict: a map from item type to specification factory
"""
items = importlib.import_module(items_package_name)
items_root = pathlib.Path(items.__file__).parent
factories = dict()
for child in items_root.iterdir():
if child.is_dir() and (
child.joinpath("specification_factory.py").exists()
or child.is_dir()
and child.joinpath("specification_factory.pyc").exists()
):
spec = importlib.util.find_spec(f"{items_package_name}.{child.stem}.specification_factory")
m = importlib.util.module_from_spec(spec)
spec.loader.exec_module(m)
if hasattr(m, "SpecificationFactory"):
item_type = m.SpecificationFactory.item_type()
factories[item_type] = m.SpecificationFactory
return factories
return items.item_specification_factories()


def load_executable_item_classes(items_package_name):
Expand All @@ -58,19 +38,4 @@ def load_executable_item_classes(items_package_name):
dict: a map from item type to the executable item class
"""
items = importlib.import_module(items_package_name)
items_root = pathlib.Path(items.__file__).parent
classes = dict()
for child in items_root.iterdir():
if (
child.is_dir()
and child.joinpath("executable_item.py").exists()
or (child.is_dir() and child.joinpath("executable_item.pyc").exists())
):
spec = importlib.util.find_spec(f"{items_package_name}.{child.stem}.executable_item")
m = importlib.util.module_from_spec(spec)
spec.loader.exec_module(m)
if hasattr(m, "ExecutableItem"):
item_class = m.ExecutableItem
item_type = item_class.item_type()
classes[item_type] = item_class
return classes
return items.executable_items()
10 changes: 10 additions & 0 deletions tests/mock_project_items/items_module/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
def item_specification_factories():
from .test_item.specification_factory import SpecificationFactory

return {"TestItem": SpecificationFactory}


def executable_items():
from .test_item.executable_item import ExecutableItem

return {"TestItem": ExecutableItem}

0 comments on commit 2542f77

Please sign in to comment.