Skip to content

Commit

Permalink
Merge pull request #104 from ImperialCollegeLondon/extensions_patch
Browse files Browse the repository at this point in the history
✨ Add load_extensions_files step
  • Loading branch information
barneydobson authored Oct 1, 2024
2 parents be12a15 + 83cd219 commit 052d00c
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 4 deletions.
48 changes: 45 additions & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
@author: Barney
"""

# import pytest
import os
import pytest
import unittest
from unittest import TestCase
from unittest import TestCase, mock

from wsimod.arcs.arcs import Arc
from wsimod.nodes.land import Land
Expand Down Expand Up @@ -305,5 +305,47 @@ def test_customise_orchestration(self):
self.assertListEqual(my_model.orchestration, revised_orchestration)


class TestLoadExtensionFiles:
def test_load_extension_files_valid(self, tmp_path_factory):
from wsimod.orchestration.model import load_extension_files

with tmp_path_factory.mktemp("extensions") as tempdir:
valid_file = os.path.join(tempdir, "valid_extension.py")
with open(valid_file, "w") as f:
f.write("def test_func(): pass")

load_extension_files([valid_file])

def test_load_extension_files_invalid_extension(self, tmp_path_factory):
from wsimod.orchestration.model import load_extension_files

with tmp_path_factory.mktemp("extensions") as tempdir:
invalid_file = os.path.join(tempdir, "invalid_extension.txt")
with open(invalid_file, "w") as f:
f.write("This is a text file")

with pytest.raises(ValueError, match="Only .py files are supported"):
load_extension_files([invalid_file])

def test_load_extension_files_nonexistent_file(self):
from wsimod.orchestration.model import load_extension_files

with pytest.raises(
FileNotFoundError, match="File nonexistent_file.py does not exist"
):
load_extension_files(["nonexistent_file.py"])

def test_load_extension_files_import_error(self, tmp_path_factory):
from wsimod.orchestration.model import load_extension_files

with tmp_path_factory.mktemp("extensions") as tempdir:
valid_file = os.path.join(tempdir, "valid_extension.py")
with open(valid_file, "w") as f:
f.write("raise ImportError")

with pytest.raises(ImportError):
load_extension_files([valid_file])


if __name__ == "__main__":
unittest.main()
27 changes: 26 additions & 1 deletion wsimod/orchestration/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def load(self, address, config_name="config.yml", overrides={}):
from ..extensions import apply_patches

with open(os.path.join(address, config_name), "r") as file:
data = yaml.safe_load(file)
data: dict = yaml.safe_load(file)

for key, item in overrides.items():
data[key] = item
Expand Down Expand Up @@ -220,6 +220,7 @@ def load(self, address, config_name="config.yml", overrides={}):
if "dates" in data.keys():
self.dates = [to_datetime(x) for x in data["dates"]]

load_extension_files(data.get("extensions", []))
apply_patches(self)

def save(self, address, config_name="config.yml", compress=False):
Expand Down Expand Up @@ -1269,3 +1270,27 @@ def yaml2csv(address, config_name="config.yml", csv_folder_name="csv"):
writer.writerow(
[str(value_[x]) if x in value_.keys() else None for x in fields]
)


def load_extension_files(files: list[str]) -> None:
"""Load extension files from a list of files.
Args:
files (list[str]): List of file paths to load
Raises:
ValueError: If file is not a .py file
FileNotFoundError: If file does not exist
"""
import importlib
from pathlib import Path

for file in files:
if not file.endswith(".py"):
raise ValueError(f"Only .py files are supported. Invalid file: {file}")
if not Path(file).exists():
raise FileNotFoundError(f"File {file} does not exist")

spec = importlib.util.spec_from_file_location("module.name", file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

0 comments on commit 052d00c

Please sign in to comment.