diff --git a/tests/test_demand.py b/tests/test_demand.py index 6032c95..c6d421e 100644 --- a/tests/test_demand.py +++ b/tests/test_demand.py @@ -119,6 +119,48 @@ def test_garden_demand(self): d1 = {"volume": 0.03 * 0.4, "temperature": 0, "phosphate": 0} self.assertDictAlmostEqual(d1, reply) + def test_demand_overrides(self): + demand = Demand( + name="", + constant_demand=10, + pollutant_load={"phosphate": 0.1, "temperature": 12}, + ) + demand.apply_overrides( + {"constant_demand": 20, "pollutant_load": {"phosphate": 0.5}} + ) + self.assertEqual(demand.constant_demand, 20) + self.assertDictEqual( + demand.pollutant_load, {"phosphate": 0.5, "temperature": 12} + ) + + def test_residentialdemand_overrides(self): + demand = ResidentialDemand( + name="", + gardening_efficiency=0.4, + pollutant_load={"phosphate": 0.1, "temperature": 12}, + ) + demand.apply_overrides( + { + "gardening_efficiency": 0.5, + "population": 153.2, + "per_capita": 32.4, + "constant_weighting": 47.5, + "constant_temp": 0.71, + "constant_demand": 20, + "pollutant_load": {"phosphate": 0.5}, + } + ) + self.assertEqual(demand.gardening_efficiency, 0.5) + self.assertEqual(demand.population, 153.2) + self.assertEqual(demand.per_capita, 32.4) + self.assertEqual(demand.constant_weighting, 47.5) + self.assertEqual(demand.constant_temp, 0.71) + + self.assertEqual(demand.constant_demand, 20) + self.assertDictEqual( + demand.pollutant_load, {"phosphate": 0.5, "temperature": 12} + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_model.py b/tests/test_model.py index e5e404c..3ffdbf7 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -4,10 +4,10 @@ @author: Barney """ - -# import pytest +import os +import pytest import unittest -from unittest import TestCase +from unittest import TestCase, mock from wsimod.arcs.arcs import Arc from wsimod.nodes.land import Land @@ -305,5 +305,47 @@ def test_customise_orchestration(self): self.assertListEqual(my_model.orchestration, revised_orchestration) +class TestLoadExtensionFiles: + def test_load_extension_files_valid(self, tmp_path_factory): + from wsimod.orchestration.model import load_extension_files + + with tmp_path_factory.mktemp("extensions") as tempdir: + valid_file = os.path.join(tempdir, "valid_extension.py") + with open(valid_file, "w") as f: + f.write("def test_func(): pass") + + load_extension_files([valid_file]) + + def test_load_extension_files_invalid_extension(self, tmp_path_factory): + from wsimod.orchestration.model import load_extension_files + + with tmp_path_factory.mktemp("extensions") as tempdir: + invalid_file = os.path.join(tempdir, "invalid_extension.txt") + with open(invalid_file, "w") as f: + f.write("This is a text file") + + with pytest.raises(ValueError, match="Only .py files are supported"): + load_extension_files([invalid_file]) + + def test_load_extension_files_nonexistent_file(self): + from wsimod.orchestration.model import load_extension_files + + with pytest.raises( + FileNotFoundError, match="File nonexistent_file.py does not exist" + ): + load_extension_files(["nonexistent_file.py"]) + + def test_load_extension_files_import_error(self, tmp_path_factory): + from wsimod.orchestration.model import load_extension_files + + with tmp_path_factory.mktemp("extensions") as tempdir: + valid_file = os.path.join(tempdir, "valid_extension.py") + with open(valid_file, "w") as f: + f.write("raise ImportError") + + with pytest.raises(ImportError): + load_extension_files([valid_file]) + + if __name__ == "__main__": unittest.main() diff --git a/wsimod/nodes/demand.py b/wsimod/nodes/demand.py index 131ac80..5778186 100644 --- a/wsimod/nodes/demand.py +++ b/wsimod/nodes/demand.py @@ -5,6 +5,8 @@ Converted to totals BD 2022-05-03 """ +from typing import Any, Dict + from wsimod.core import constants from wsimod.nodes.nodes import Node @@ -28,7 +30,7 @@ def __init__( is used. Defaults to 0. pollutant_load (dict, optional): Pollutant mass per timestep of constant_demand. - Defaults to {}. + Defaults to 0. data_input_dict (dict, optional): Dictionary of data inputs relevant for the node (temperature). Keys are tuples where first value is the name of the variable to read from the dict and the second value is the time. @@ -61,6 +63,19 @@ def __init__( self.mass_balance_out.append(lambda: self.total_backup) self.mass_balance_out.append(lambda: self.total_received) + def apply_overrides(self, overrides: Dict[str, Any] = {}): + """Apply overrides to the sewer. + + Enables a user to override any of the following parameters: + constant_demand, pollutant_load. + + Args: + overrides (dict, optional): Dictionary of overrides. Defaults to {}. + """ + self.constant_demand = overrides.pop("constant_demand", self.constant_demand) + self.pollutant_load.update(overrides.pop("pollutant_load", {})) + super().apply_overrides(overrides) + def create_demand(self): """Function to call get_demand, which should return a dict with keys that match the keys in directions. @@ -198,6 +213,26 @@ def __init__( # Label as Demand class so that other nodes treat it the same self.__class__.__name__ = "Demand" + def apply_overrides(self, overrides: Dict[str, Any] = {}): + """Apply overrides to the sewer. + + Enables a user to override any of the following parameters: + gardening_efficiency, population, per_capita, constant_weighting, constant_temp. + + Args: + overrides (dict, optional): Dictionary of overrides. Defaults to {}. + """ + self.gardening_efficiency = overrides.pop( + "gardening_efficiency", self.gardening_efficiency + ) + self.population = overrides.pop("population", self.population) + self.per_capita = overrides.pop("per_capita", self.per_capita) + self.constant_weighting = overrides.pop( + "constant_weighting", self.constant_weighting + ) + self.constant_temp = overrides.pop("constant_temp", self.constant_temp) + super().apply_overrides(overrides) + def get_demand(self): """Overwrite get_demand and replace with custom functions. diff --git a/wsimod/orchestration/model.py b/wsimod/orchestration/model.py index 4026a32..5b3347f 100644 --- a/wsimod/orchestration/model.py +++ b/wsimod/orchestration/model.py @@ -178,7 +178,7 @@ def load(self, address, config_name="config.yml", overrides={}): from ..extensions import apply_patches with open(os.path.join(address, config_name), "r") as file: - data = yaml.safe_load(file) + data: dict = yaml.safe_load(file) for key, item in overrides.items(): data[key] = item @@ -220,6 +220,7 @@ def load(self, address, config_name="config.yml", overrides={}): if "dates" in data.keys(): self.dates = [to_datetime(x) for x in data["dates"]] + load_extension_files(data.get("extensions", [])) apply_patches(self) def save(self, address, config_name="config.yml", compress=False): @@ -1269,3 +1270,27 @@ def yaml2csv(address, config_name="config.yml", csv_folder_name="csv"): writer.writerow( [str(value_[x]) if x in value_.keys() else None for x in fields] ) + + +def load_extension_files(files: list[str]) -> None: + """Load extension files from a list of files. + + Args: + files (list[str]): List of file paths to load + + Raises: + ValueError: If file is not a .py file + FileNotFoundError: If file does not exist + """ + import importlib + from pathlib import Path + + for file in files: + if not file.endswith(".py"): + raise ValueError(f"Only .py files are supported. Invalid file: {file}") + if not Path(file).exists(): + raise FileNotFoundError(f"File {file} does not exist") + + spec = importlib.util.spec_from_file_location("module.name", file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module)