Skip to content

Commit

Permalink
Merge pull request #116 from ImperialCollegeLondon/model-extension-at…
Browse files Browse the repository at this point in the history
…tribute

Model didn't save `extensions`
  • Loading branch information
barneydobson authored Oct 15, 2024
2 parents fd9182d + ebd6581 commit 7a053a1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,18 @@ def test_custom_class_from_file():
model.run(dates=[to_datetime("2000-01-01")])
assert model.nodes["node_name"].custom_attr == 2

model.save(temp_dir, "new_config.yml")

# Remove the custom class from the registry to test loading it again
del model
NODES_REGISTRY.pop("CustomNode", None)

model = Model()
model.load(temp_dir, "new_config.yml")
assert model.nodes["node_name"].custom_attr == 1
model.run(dates=[to_datetime("2000-01-01")])
assert model.nodes["node_name"].custom_attr == 2


def test_custom_class_on_the_fly():
"""Test a custom class."""
Expand Down
3 changes: 3 additions & 0 deletions wsimod/orchestration/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def __init__(self):
# self.arcs_type = {} #not sure that this would be necessary
self.nodes = {}
self.nodes_type = {}
self.extensions = []

# Default orchestration
self.orchestration = [
Expand Down Expand Up @@ -200,6 +201,7 @@ def load(self, address, config_name="config.yml", overrides={}):
E.G. ADDITION FOR NEW ORCHESTRATION
"""
load_extension_files(data.get("extensions", []))
self.extensions = data.get("extensions", [])

if "orchestration" in data.keys():
# Update orchestration
Expand Down Expand Up @@ -330,6 +332,7 @@ def save(self, address, config_name="config.yml", compress=False):
"additive_pollutants": constants.ADDITIVE_POLLUTANTS,
"non_additive_pollutants": constants.NON_ADDITIVE_POLLUTANTS,
"float_accuracy": constants.FLOAT_ACCURACY,
"extensions": self.extensions,
}
if hasattr(self, "dates"):
data["dates"] = [str(x) for x in self.dates]
Expand Down

0 comments on commit 7a053a1

Please sign in to comment.