Skip to content

Commit

Permalink
Merge branch 'main' into demand-overrides
Browse files Browse the repository at this point in the history
  • Loading branch information
barneydobson authored Oct 1, 2024
2 parents d88f93d + be12a15 commit 738d9e8
Show file tree
Hide file tree
Showing 14 changed files with 1,321 additions and 921 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/draft-pdf.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ jobs:
name: Paper Draft
steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: Build draft PDF
uses: openjournals/openjournals-draft-action@master
with:
journal: joss
# This should be the path to the paper within your repo.
paper-path: docs/paper/paper.md
- name: Upload
uses: actions/upload-artifact@v1
uses: actions/upload-artifact@v4
with:
name: paper
# This is the output path where Pandoc will write the compiled
Expand Down
75 changes: 75 additions & 0 deletions docs/demo/examples/test_customise_orchestration_example.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
orchestration:
- Groundwater: infiltrate
- Sewer: make_discharge

nodes:
Sewer:
type_: Sewer
name: my_sewer
capacity: 0.04

Groundwater:
type_: Groundwater
name: my_groundwater
capacity: 100
area: 100

River:
type_: Node
name: my_river

Waste:
type_: Waste
name: my_outlet

arcs:
storm_outflow:
type_: Arc
name: storm_outflow
in_port: my_sewer
out_port: my_river

baseflow:
type_: Arc
name: baseflow
in_port: my_groundwater
out_port: my_river

catchment_outflow:
type_: Arc
name: catchment_outflow
in_port: my_river
out_port: my_outlet

pollutants:
- org-phosphorus
- phosphate
- ammonia
- solids
- temperature
- nitrate
- nitrite
- org-nitrogen
additive_pollutants:
- org-phosphorus
- phosphate
- ammonia
- solids
- nitrate
- nitrite
- org-nitrogen
non_additive_pollutants:
- temperature
float_accuracy: 1.0e-06

dates:
- '2000-01-01'
- '2000-01-02'
- '2000-01-03'
- '2000-01-04'
- '2000-01-05'
- '2000-01-06'
- '2000-01-07'
- '2000-01-08'
- '2000-01-09'
- '2000-01-10'
174 changes: 174 additions & 0 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
from typing import Optional

import pytest


@pytest.fixture
def temp_extension_registry():
from wsimod.extensions import extensions_registry

bkp = extensions_registry.copy()
extensions_registry.clear()
yield
extensions_registry.clear()
extensions_registry.update(bkp)


def test_register_node_patch(temp_extension_registry):
from wsimod.extensions import extensions_registry, register_node_patch

# Define a dummy function to patch a node method
@register_node_patch("node_name", "method_name")
def dummy_patch():
print("Patched method")

# Check if the patch is registered correctly
assert extensions_registry[("node_name", "method_name", None, False)] == dummy_patch

# Another function with other arguments
@register_node_patch("node_name", "method_name", item="default", is_attr=True)
def another_dummy_patch():
print("Another patched method")

# Check if this other patch is registered correctly
assert (
extensions_registry[("node_name", "method_name", "default", True)]
== another_dummy_patch
)


def test_apply_patches(temp_extension_registry):
from wsimod.arcs.arcs import Arc
from wsimod.extensions import (
apply_patches,
extensions_registry,
register_node_patch,
)
from wsimod.nodes import Node
from wsimod.orchestration.model import Model

# Create a dummy model
node = Node("dummy_node")
node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node)
model = Model()
model.nodes[node.name] = node

# 1. Patch a method
@register_node_patch("dummy_node", "apply_overrides")
def dummy_patch():
pass

# 2. Patch an attribute
@register_node_patch("dummy_node", "t", is_attr=True)
def another_dummy_patch(node):
return f"A pathced attribute for {node.name}"

# 3. Patch a method with an item
@register_node_patch("dummy_node", "pull_set_handler", item="default")
def yet_another_dummy_patch():
pass

# 4. Path a method of an attribute
@register_node_patch("dummy_node", "dummy_arc.arc_mass_balance")
def arc_dummy_patch():
pass

# Check if all patches are registered
assert len(extensions_registry) == 4

# Apply the patches
apply_patches(model)

# Verify that the patches are applied correctly
assert (
model.nodes[node.name].apply_overrides.__qualname__ == dummy_patch.__qualname__
)
assert (
model.nodes[node.name]._patched_apply_overrides.__qualname__
== "Node.apply_overrides"
)
assert model.nodes[node.name].t == another_dummy_patch(node)
assert model.nodes[node.name]._patched_t == None
assert (
model.nodes[node.name].pull_set_handler["default"].__qualname__
== yet_another_dummy_patch.__qualname__
)
assert (
model.nodes[node.name].dummy_arc.arc_mass_balance.__qualname__
== arc_dummy_patch.__qualname__
)
assert (
model.nodes[node.name].dummy_arc._patched_arc_mass_balance.__qualname__
== "Arc.arc_mass_balance"
)


def assert_dict_almost_equal(d1: dict, d2: dict, tol: Optional[float] = None):
"""Check if two dictionaries are almost equal.
Args:
d1 (dict): The first dictionary.
d2 (dict): The second dictionary.
tol (float | None, optional): Relative tolerance. Defaults to 1e-6,
`pytest.approx` default.
"""
for key in d1.keys():
assert d1[key] == pytest.approx(d2[key], rel=tol)


def test_path_method_with_reuse(temp_extension_registry):
from wsimod.arcs.arcs import Arc
from wsimod.extensions import apply_patches, register_node_patch
from wsimod.nodes.storage import Reservoir
from wsimod.orchestration.model import Model

# Create a dummy model
node = Reservoir(name="dummy_node", initial_storage=10, capacity=10)
node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node)

vq = node.pull_distributed({"volume": 5})
assert_dict_almost_equal(vq, node.v_change_vqip(node.empty_vqip(), 5))

model = Model()
model.nodes[node.name] = node

@register_node_patch("dummy_node", "pull_distributed")
def new_pull_distributed(self, vqip, of_type=None, tag="default"):
return self._patched_pull_distributed(vqip, of_type=["Node"], tag=tag)

# Apply the patches
apply_patches(model)

# Check appropriate result
assert node.tank.storage["volume"] == 5
vq = model.nodes[node.name].pull_distributed({"volume": 5})
assert_dict_almost_equal(vq, node.empty_vqip())
assert node.tank.storage["volume"] == 5


def test_handler_extensions(temp_extension_registry):
from wsimod.arcs.arcs import Arc
from wsimod.extensions import apply_patches, register_node_patch
from wsimod.nodes import Node
from wsimod.orchestration.model import Model

# Create a dummy model
node = Node("dummy_node")
node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node)
model = Model()
model.nodes[node.name] = node

# 1. Patch a handler
@register_node_patch("dummy_node", "pull_check_handler", item="default")
def dummy_patch(self, *args, **kwargs):
return "dummy_patch"

# 2. Patch a handler with access to self
@register_node_patch("dummy_node", "pull_set_handler", item="default")
def dummy_patch(self, vqip, *args, **kwargs):
return f"{self.name} - {vqip['volume']}"

apply_patches(model)

assert node.pull_check() == "dummy_patch"
assert node.pull_set({"volume": 1}) == "dummy_node - 1"
13 changes: 13 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from wsimod.nodes.sewer import Sewer
from wsimod.nodes.waste import Waste
from wsimod.orchestration.model import Model, to_datetime
import os


class MyTestClass(TestCase):
Expand Down Expand Up @@ -291,6 +292,18 @@ def test_run(self):
0.03, my_model.nodes["my_land"].get_surface("urban").storage["volume"]
)

def test_customise_orchestration(self):
my_model = Model()
my_model.load(
os.path.join(os.getcwd(), "docs", "demo", "examples"),
config_name="test_customise_orchestration_example.yaml",
)
revised_orchestration = [
{"Groundwater": "infiltrate"},
{"Sewer": "make_discharge"},
]
self.assertListEqual(my_model.orchestration, revised_orchestration)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 738d9e8

Please sign in to comment.