Skip to content

Commit

Permalink
Merge branch 'main' into distribution-overrides
Browse files Browse the repository at this point in the history
  • Loading branch information
barneydobson authored Oct 1, 2024
2 parents 0eb27a7 + ec39820 commit f1db662
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 5 deletions.
42 changes: 42 additions & 0 deletions tests/test_demand.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,48 @@ def test_garden_demand(self):
d1 = {"volume": 0.03 * 0.4, "temperature": 0, "phosphate": 0}
self.assertDictAlmostEqual(d1, reply)

def test_demand_overrides(self):
demand = Demand(
name="",
constant_demand=10,
pollutant_load={"phosphate": 0.1, "temperature": 12},
)
demand.apply_overrides(
{"constant_demand": 20, "pollutant_load": {"phosphate": 0.5}}
)
self.assertEqual(demand.constant_demand, 20)
self.assertDictEqual(
demand.pollutant_load, {"phosphate": 0.5, "temperature": 12}
)

def test_residentialdemand_overrides(self):
demand = ResidentialDemand(
name="",
gardening_efficiency=0.4,
pollutant_load={"phosphate": 0.1, "temperature": 12},
)
demand.apply_overrides(
{
"gardening_efficiency": 0.5,
"population": 153.2,
"per_capita": 32.4,
"constant_weighting": 47.5,
"constant_temp": 0.71,
"constant_demand": 20,
"pollutant_load": {"phosphate": 0.5},
}
)
self.assertEqual(demand.gardening_efficiency, 0.5)
self.assertEqual(demand.population, 153.2)
self.assertEqual(demand.per_capita, 32.4)
self.assertEqual(demand.constant_weighting, 47.5)
self.assertEqual(demand.constant_temp, 0.71)

self.assertEqual(demand.constant_demand, 20)
self.assertDictEqual(
demand.pollutant_load, {"phosphate": 0.5, "temperature": 12}
)


if __name__ == "__main__":
unittest.main()
48 changes: 45 additions & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
@author: Barney
"""

# import pytest
import os
import pytest
import unittest
from unittest import TestCase
from unittest import TestCase, mock

from wsimod.arcs.arcs import Arc
from wsimod.nodes.land import Land
Expand Down Expand Up @@ -305,5 +305,47 @@ def test_customise_orchestration(self):
self.assertListEqual(my_model.orchestration, revised_orchestration)


class TestLoadExtensionFiles:
def test_load_extension_files_valid(self, tmp_path_factory):
from wsimod.orchestration.model import load_extension_files

with tmp_path_factory.mktemp("extensions") as tempdir:
valid_file = os.path.join(tempdir, "valid_extension.py")
with open(valid_file, "w") as f:
f.write("def test_func(): pass")

load_extension_files([valid_file])

def test_load_extension_files_invalid_extension(self, tmp_path_factory):
from wsimod.orchestration.model import load_extension_files

with tmp_path_factory.mktemp("extensions") as tempdir:
invalid_file = os.path.join(tempdir, "invalid_extension.txt")
with open(invalid_file, "w") as f:
f.write("This is a text file")

with pytest.raises(ValueError, match="Only .py files are supported"):
load_extension_files([invalid_file])

def test_load_extension_files_nonexistent_file(self):
from wsimod.orchestration.model import load_extension_files

with pytest.raises(
FileNotFoundError, match="File nonexistent_file.py does not exist"
):
load_extension_files(["nonexistent_file.py"])

def test_load_extension_files_import_error(self, tmp_path_factory):
from wsimod.orchestration.model import load_extension_files

with tmp_path_factory.mktemp("extensions") as tempdir:
valid_file = os.path.join(tempdir, "valid_extension.py")
with open(valid_file, "w") as f:
f.write("raise ImportError")

with pytest.raises(ImportError):
load_extension_files([valid_file])


if __name__ == "__main__":
unittest.main()
37 changes: 36 additions & 1 deletion wsimod/nodes/demand.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
Converted to totals BD 2022-05-03
"""
from typing import Any, Dict

from wsimod.core import constants
from wsimod.nodes.nodes import Node

Expand All @@ -28,7 +30,7 @@ def __init__(
is used. Defaults to 0.
pollutant_load (dict, optional): Pollutant mass per timestep of
constant_demand.
Defaults to {}.
Defaults to 0.
data_input_dict (dict, optional): Dictionary of data inputs relevant for
the node (temperature). Keys are tuples where first value is the name of
the variable to read from the dict and the second value is the time.
Expand Down Expand Up @@ -61,6 +63,19 @@ def __init__(
self.mass_balance_out.append(lambda: self.total_backup)
self.mass_balance_out.append(lambda: self.total_received)

def apply_overrides(self, overrides: Dict[str, Any] = {}):
"""Apply overrides to the sewer.
Enables a user to override any of the following parameters:
constant_demand, pollutant_load.
Args:
overrides (dict, optional): Dictionary of overrides. Defaults to {}.
"""
self.constant_demand = overrides.pop("constant_demand", self.constant_demand)
self.pollutant_load.update(overrides.pop("pollutant_load", {}))
super().apply_overrides(overrides)

def create_demand(self):
"""Function to call get_demand, which should return a dict with keys that match
the keys in directions.
Expand Down Expand Up @@ -198,6 +213,26 @@ def __init__(
# Label as Demand class so that other nodes treat it the same
self.__class__.__name__ = "Demand"

def apply_overrides(self, overrides: Dict[str, Any] = {}):
"""Apply overrides to the sewer.
Enables a user to override any of the following parameters:
gardening_efficiency, population, per_capita, constant_weighting, constant_temp.
Args:
overrides (dict, optional): Dictionary of overrides. Defaults to {}.
"""
self.gardening_efficiency = overrides.pop(
"gardening_efficiency", self.gardening_efficiency
)
self.population = overrides.pop("population", self.population)
self.per_capita = overrides.pop("per_capita", self.per_capita)
self.constant_weighting = overrides.pop(
"constant_weighting", self.constant_weighting
)
self.constant_temp = overrides.pop("constant_temp", self.constant_temp)
super().apply_overrides(overrides)

def get_demand(self):
"""Overwrite get_demand and replace with custom functions.
Expand Down
27 changes: 26 additions & 1 deletion wsimod/orchestration/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def load(self, address, config_name="config.yml", overrides={}):
from ..extensions import apply_patches

with open(os.path.join(address, config_name), "r") as file:
data = yaml.safe_load(file)
data: dict = yaml.safe_load(file)

for key, item in overrides.items():
data[key] = item
Expand Down Expand Up @@ -220,6 +220,7 @@ def load(self, address, config_name="config.yml", overrides={}):
if "dates" in data.keys():
self.dates = [to_datetime(x) for x in data["dates"]]

load_extension_files(data.get("extensions", []))
apply_patches(self)

def save(self, address, config_name="config.yml", compress=False):
Expand Down Expand Up @@ -1269,3 +1270,27 @@ def yaml2csv(address, config_name="config.yml", csv_folder_name="csv"):
writer.writerow(
[str(value_[x]) if x in value_.keys() else None for x in fields]
)


def load_extension_files(files: list[str]) -> None:
"""Load extension files from a list of files.
Args:
files (list[str]): List of file paths to load
Raises:
ValueError: If file is not a .py file
FileNotFoundError: If file does not exist
"""
import importlib
from pathlib import Path

for file in files:
if not file.endswith(".py"):
raise ValueError(f"Only .py files are supported. Invalid file: {file}")
if not Path(file).exists():
raise FileNotFoundError(f"File {file} does not exist")

spec = importlib.util.spec_from_file_location("module.name", file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

0 comments on commit f1db662

Please sign in to comment.