From 1ab05cfd9fedc3f476ca32aeaa2871b69d952655 Mon Sep 17 00:00:00 2001 From: liuly12 Date: Sat, 16 Mar 2024 15:16:53 +0000 Subject: [PATCH 1/5] Apply overrides in Demand and ResidentialDemand --- tests/test_demand.py | 33 ++++++++++++++++++++++++++++++++- wsimod/nodes/demand.py | 40 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/tests/test_demand.py b/tests/test_demand.py index 6032c95..3c870fa 100644 --- a/tests/test_demand.py +++ b/tests/test_demand.py @@ -118,7 +118,38 @@ def test_garden_demand(self): d1 = {"volume": 0.03 * 0.4, "temperature": 0, "phosphate": 0} self.assertDictAlmostEqual(d1, reply) - + + def test_demand_overrides(self): + demand = Demand( + name="", + constant_demand=10, + pollutant_load={"phosphate": 0.1, "temperature": 12}, + ) + demand.apply_overrides({"constant_demand": 20, + "pollutant_load": {"phosphate": 0.5} + }) + self.assertEqual(demand.constant_demand, 20) + self.assertDictEqual(demand.pollutant_load, {"phosphate": 0.5, "temperature": 12}) + + def test_residentialdemand_overrides(self): + demand = ResidentialDemand(name="", gardening_efficiency=0.4, pollutant_load={"phosphate": 0.1, "temperature": 12}) + demand.apply_overrides({"gardening_efficiency": 0.5, + "population": 153.2, + "per_capita": 32.4, + "constant_weighting": 47.5, + "constant_temp": 0.71, + + "constant_demand": 20, + "pollutant_load": {"phosphate": 0.5} + }) + self.assertEqual(demand.gardening_efficiency, 0.5) + self.assertEqual(demand.population, 153.2) + self.assertEqual(demand.per_capita, 32.4) + self.assertEqual(demand.constant_weighting, 47.5) + self.assertEqual(demand.constant_temp, 0.71) + + self.assertEqual(demand.constant_demand, 20) + self.assertDictEqual(demand.pollutant_load, {"phosphate": 0.5, "temperature": 12}) if __name__ == "__main__": unittest.main() diff --git a/wsimod/nodes/demand.py b/wsimod/nodes/demand.py index 131ac80..0755e22 100644 --- a/wsimod/nodes/demand.py +++ b/wsimod/nodes/demand.py @@ -7,7 +7,7 @@ """ from wsimod.core import constants from wsimod.nodes.nodes import Node - +from typing import Any, Dict class Demand(Node): """""" @@ -28,7 +28,7 @@ def __init__( is used. Defaults to 0. pollutant_load (dict, optional): Pollutant mass per timestep of constant_demand. - Defaults to {}. + Defaults to 0. data_input_dict (dict, optional): Dictionary of data inputs relevant for the node (temperature). Keys are tuples where first value is the name of the variable to read from the dict and the second value is the time. @@ -61,6 +61,20 @@ def __init__( self.mass_balance_out.append(lambda: self.total_backup) self.mass_balance_out.append(lambda: self.total_received) + def apply_overrides(self, overrides: Dict[str, Any] = {}): + """Apply overrides to the sewer. + + Enables a user to override any of the following parameters: + constant_demand, pollutant_load. + + Args: + overrides (dict, optional): Dictionary of overrides. Defaults to {}. + """ + self.constant_demand = overrides.pop("constant_demand", + self.constant_demand) + self.pollutant_load.update(overrides.pop("pollutant_load", {})) + super().apply_overrides(overrides) + def create_demand(self): """Function to call get_demand, which should return a dict with keys that match the keys in directions. @@ -198,6 +212,28 @@ def __init__( # Label as Demand class so that other nodes treat it the same self.__class__.__name__ = "Demand" + def apply_overrides(self, overrides: Dict[str, Any] = {}): + """Apply overrides to the sewer. + + Enables a user to override any of the following parameters: + gardening_efficiency, population, per_capita, constant_weighting, constant_temp. + + Args: + overrides (dict, optional): Dictionary of overrides. Defaults to {}. + """ + self.gardening_efficiency = overrides.pop("gardening_efficiency", + self.gardening_efficiency) + self.population = overrides.pop("population", + self.population) + self.per_capita = overrides.pop("per_capita", + self.per_capita) + self.constant_weighting = overrides.pop("constant_weighting", + self.constant_weighting) + self.constant_temp = overrides.pop("constant_temp", + self.constant_temp) + self.pollutant_load.update(overrides.pop("pollutant_load", {})) + super().apply_overrides(overrides) + def get_demand(self): """Overwrite get_demand and replace with custom functions. From c64b57ae681697bc9cb3b6e4797d2ac2a065843c Mon Sep 17 00:00:00 2001 From: liuly12 Date: Sat, 16 Mar 2024 15:20:05 +0000 Subject: [PATCH 2/5] Remove duplicated pollutant_load overrides --- wsimod/nodes/demand.py | 1 - 1 file changed, 1 deletion(-) diff --git a/wsimod/nodes/demand.py b/wsimod/nodes/demand.py index 0755e22..bd85dfa 100644 --- a/wsimod/nodes/demand.py +++ b/wsimod/nodes/demand.py @@ -231,7 +231,6 @@ def apply_overrides(self, overrides: Dict[str, Any] = {}): self.constant_weighting) self.constant_temp = overrides.pop("constant_temp", self.constant_temp) - self.pollutant_load.update(overrides.pop("pollutant_load", {})) super().apply_overrides(overrides) def get_demand(self): From d88f93d976c30632b061053a02040894b386f868 Mon Sep 17 00:00:00 2001 From: liuly12 Date: Fri, 26 Jul 2024 14:22:42 +0100 Subject: [PATCH 3/5] pre-commit reformatted --- tests/test_demand.py | 47 ++++++++++++++++++++++++++---------------- wsimod/nodes/demand.py | 38 +++++++++++++++++----------------- 2 files changed, 48 insertions(+), 37 deletions(-) diff --git a/tests/test_demand.py b/tests/test_demand.py index 3c870fa..c6d421e 100644 --- a/tests/test_demand.py +++ b/tests/test_demand.py @@ -118,38 +118,49 @@ def test_garden_demand(self): d1 = {"volume": 0.03 * 0.4, "temperature": 0, "phosphate": 0} self.assertDictAlmostEqual(d1, reply) - + def test_demand_overrides(self): demand = Demand( name="", constant_demand=10, pollutant_load={"phosphate": 0.1, "temperature": 12}, ) - demand.apply_overrides({"constant_demand": 20, - "pollutant_load": {"phosphate": 0.5} - }) + demand.apply_overrides( + {"constant_demand": 20, "pollutant_load": {"phosphate": 0.5}} + ) self.assertEqual(demand.constant_demand, 20) - self.assertDictEqual(demand.pollutant_load, {"phosphate": 0.5, "temperature": 12}) - + self.assertDictEqual( + demand.pollutant_load, {"phosphate": 0.5, "temperature": 12} + ) + def test_residentialdemand_overrides(self): - demand = ResidentialDemand(name="", gardening_efficiency=0.4, pollutant_load={"phosphate": 0.1, "temperature": 12}) - demand.apply_overrides({"gardening_efficiency": 0.5, - "population": 153.2, - "per_capita": 32.4, - "constant_weighting": 47.5, - "constant_temp": 0.71, - - "constant_demand": 20, - "pollutant_load": {"phosphate": 0.5} - }) + demand = ResidentialDemand( + name="", + gardening_efficiency=0.4, + pollutant_load={"phosphate": 0.1, "temperature": 12}, + ) + demand.apply_overrides( + { + "gardening_efficiency": 0.5, + "population": 153.2, + "per_capita": 32.4, + "constant_weighting": 47.5, + "constant_temp": 0.71, + "constant_demand": 20, + "pollutant_load": {"phosphate": 0.5}, + } + ) self.assertEqual(demand.gardening_efficiency, 0.5) self.assertEqual(demand.population, 153.2) self.assertEqual(demand.per_capita, 32.4) self.assertEqual(demand.constant_weighting, 47.5) self.assertEqual(demand.constant_temp, 0.71) - + self.assertEqual(demand.constant_demand, 20) - self.assertDictEqual(demand.pollutant_load, {"phosphate": 0.5, "temperature": 12}) + self.assertDictEqual( + demand.pollutant_load, {"phosphate": 0.5, "temperature": 12} + ) + if __name__ == "__main__": unittest.main() diff --git a/wsimod/nodes/demand.py b/wsimod/nodes/demand.py index bd85dfa..5778186 100644 --- a/wsimod/nodes/demand.py +++ b/wsimod/nodes/demand.py @@ -5,9 +5,11 @@ Converted to totals BD 2022-05-03 """ +from typing import Any, Dict + from wsimod.core import constants from wsimod.nodes.nodes import Node -from typing import Any, Dict + class Demand(Node): """""" @@ -63,15 +65,14 @@ def __init__( def apply_overrides(self, overrides: Dict[str, Any] = {}): """Apply overrides to the sewer. - - Enables a user to override any of the following parameters: + + Enables a user to override any of the following parameters: constant_demand, pollutant_load. - + Args: overrides (dict, optional): Dictionary of overrides. Defaults to {}. """ - self.constant_demand = overrides.pop("constant_demand", - self.constant_demand) + self.constant_demand = overrides.pop("constant_demand", self.constant_demand) self.pollutant_load.update(overrides.pop("pollutant_load", {})) super().apply_overrides(overrides) @@ -214,23 +215,22 @@ def __init__( def apply_overrides(self, overrides: Dict[str, Any] = {}): """Apply overrides to the sewer. - - Enables a user to override any of the following parameters: + + Enables a user to override any of the following parameters: gardening_efficiency, population, per_capita, constant_weighting, constant_temp. - + Args: overrides (dict, optional): Dictionary of overrides. Defaults to {}. """ - self.gardening_efficiency = overrides.pop("gardening_efficiency", - self.gardening_efficiency) - self.population = overrides.pop("population", - self.population) - self.per_capita = overrides.pop("per_capita", - self.per_capita) - self.constant_weighting = overrides.pop("constant_weighting", - self.constant_weighting) - self.constant_temp = overrides.pop("constant_temp", - self.constant_temp) + self.gardening_efficiency = overrides.pop( + "gardening_efficiency", self.gardening_efficiency + ) + self.population = overrides.pop("population", self.population) + self.per_capita = overrides.pop("per_capita", self.per_capita) + self.constant_weighting = overrides.pop( + "constant_weighting", self.constant_weighting + ) + self.constant_temp = overrides.pop("constant_temp", self.constant_temp) super().apply_overrides(overrides) def get_demand(self): From 49a7f71a78cd4efb288213acaa32f8b857f34310 Mon Sep 17 00:00:00 2001 From: Diego Alonso Alvarez Date: Wed, 18 Sep 2024 11:56:00 +0100 Subject: [PATCH 4/5] :sparkles: Add load_extensions_files step --- tests/test_model.py | 48 ++++++++++++++++++++++++++++++++--- wsimod/orchestration/model.py | 34 ++++++++++++++++++++++++- 2 files changed, 78 insertions(+), 4 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index a004224..2ff53a6 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -4,10 +4,10 @@ @author: Barney """ - -# import pytest +import os +import pytest import unittest -from unittest import TestCase +from unittest import TestCase, mock from wsimod.arcs.arcs import Arc from wsimod.nodes.land import Land @@ -292,5 +292,47 @@ def test_run(self): ) +class TestLoadExtensionFiles: + def test_load_extension_files_valid(self, tmp_path_factory): + from wsimod.orchestration.model import load_extension_files + + with tmp_path_factory.mktemp("extensions") as tempdir: + valid_file = os.path.join(tempdir, "valid_extension.py") + with open(valid_file, "w") as f: + f.write("def test_func(): pass") + + load_extension_files([valid_file]) + + def test_load_extension_files_invalid_extension(self, tmp_path_factory): + from wsimod.orchestration.model import load_extension_files + + with tmp_path_factory.mktemp("extensions") as tempdir: + invalid_file = os.path.join(tempdir, "invalid_extension.txt") + with open(invalid_file, "w") as f: + f.write("This is a text file") + + with pytest.raises(ValueError, match="Only .py files are supported"): + load_extension_files([invalid_file]) + + def test_load_extension_files_nonexistent_file(self): + from wsimod.orchestration.model import load_extension_files + + with pytest.raises( + FileNotFoundError, match="File nonexistent_file.py does not exist" + ): + load_extension_files(["nonexistent_file.py"]) + + def test_load_extension_files_import_error(self, tmp_path_factory): + from wsimod.orchestration.model import load_extension_files + + with tmp_path_factory.mktemp("extensions") as tempdir: + valid_file = os.path.join(tempdir, "valid_extension.py") + with open(valid_file, "w") as f: + f.write("raise ImportError") + + with pytest.raises(ImportError): + load_extension_files([valid_file]) + + if __name__ == "__main__": unittest.main() diff --git a/wsimod/orchestration/model.py b/wsimod/orchestration/model.py index 5050724..0075a29 100644 --- a/wsimod/orchestration/model.py +++ b/wsimod/orchestration/model.py @@ -160,7 +160,7 @@ def load(self, address, config_name="config.yml", overrides={}): from ..extensions import apply_patches with open(os.path.join(address, config_name), "r") as file: - data = yaml.safe_load(file) + data: dict = yaml.safe_load(file) for key, item in overrides.items(): data[key] = item @@ -193,6 +193,7 @@ def load(self, address, config_name="config.yml", overrides={}): if "dates" in data.keys(): self.dates = [to_datetime(x) for x in data["dates"]] + load_extension_files(data.get("extensions", [])) apply_patches(self) def save(self, address, config_name="config.yml", compress=False): @@ -1285,3 +1286,34 @@ def yaml2csv(address, config_name="config.yml", csv_folder_name="csv"): writer.writerow( [str(value_[x]) if x in value_.keys() else None for x in fields] ) + + +def load_extension_files(files: list[str]) -> None: + """Load extension files from a list of files. + + Args: + files (list[str]): List of file paths to load + + Raises: + ValueError: If file is not a .py file + FileNotFoundError: If file does not exist + """ + import importlib + from pathlib import Path + + invalid_files: list[str] = [] + for file in files: + if not Path(file).exists(): + raise FileNotFoundError(f"File {file} does not exist") + + if file.endswith(".py"): + spec = importlib.util.spec_from_file_location("module.name", file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + else: + invalid_files.append(file) + + if invalid_files: + raise ValueError( + "Only .py files are supported. Invalid files: " + ", ".join(invalid_files) + ) From 36a70318eb1118d63802864060a787a4d27b77d2 Mon Sep 17 00:00:00 2001 From: Diego Alonso Alvarez Date: Wed, 18 Sep 2024 12:55:19 +0100 Subject: [PATCH 5/5] :recycle: Simplify error handling. --- wsimod/orchestration/model.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/wsimod/orchestration/model.py b/wsimod/orchestration/model.py index 0075a29..7746282 100644 --- a/wsimod/orchestration/model.py +++ b/wsimod/orchestration/model.py @@ -1301,19 +1301,12 @@ def load_extension_files(files: list[str]) -> None: import importlib from pathlib import Path - invalid_files: list[str] = [] for file in files: + if not file.endswith(".py"): + raise ValueError(f"Only .py files are supported. Invalid file: {file}") if not Path(file).exists(): raise FileNotFoundError(f"File {file} does not exist") - if file.endswith(".py"): - spec = importlib.util.spec_from_file_location("module.name", file) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - else: - invalid_files.append(file) - - if invalid_files: - raise ValueError( - "Only .py files are supported. Invalid files: " + ", ".join(invalid_files) - ) + spec = importlib.util.spec_from_file_location("module.name", file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module)