Skip to content

Commit

Permalink
Merge pull request #138 from PSLmodels/get-defaults-override
Browse files Browse the repository at this point in the history
Add get_defaults method to allow sub-class to customize parameter loading
  • Loading branch information
jdebacker authored Nov 23, 2024
2 parents 8d5ff38 + f732ff7 commit d68b28a
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 3 deletions.
1 change: 1 addition & 0 deletions paramtools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
"select_ne",
"read_json",
"get_example_paths",
"get_defaults",
"LeafGetter",
"get_leaves",
"ravel",
Expand Down
1 change: 1 addition & 0 deletions paramtools/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class InconsistentLabelsException(ParamToolsError):
"to_dict",
"_parse_validation_messages",
"sel",
"get_defaults",
]


Expand Down
14 changes: 13 additions & 1 deletion paramtools/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
sort_values: bool = True,
**ops,
):
schemafactory = SchemaFactory(self.defaults)
schemafactory = SchemaFactory(self.get_defaults())
(
self._defaults_schema,
self._validator_schema,
Expand Down Expand Up @@ -1403,3 +1403,15 @@ def keyfunc(vo, label, label_values):
)[param]
setattr(self, param, sorted_values)
return data

def get_defaults(self):
"""
Hook for implementing custom behavior for getting the default parameters.
**Returns**
- `params`: String if URL or file path. Dict if this is the loaded params
dict.
"""
return utils.read_json(self.defaults)
2 changes: 0 additions & 2 deletions paramtools/schema_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
get_param_schema,
ParamToolsSchema,
)
from paramtools import utils


class SchemaFactory:
Expand All @@ -26,7 +25,6 @@ class SchemaFactory:
"""

def __init__(self, defaults):
defaults = utils.read_json(defaults)
self.defaults = {k: v for k, v in defaults.items() if k != "schema"}
self.schema = ParamToolsSchema().load(defaults.get("schema", {}))
(self.BaseParamSchema, self.label_validators) = get_param_schema(
Expand Down
30 changes: 30 additions & 0 deletions paramtools/tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,36 @@ class Params(Parameters):
assert params.hello_world == "hello world"
assert params.label_grid == {}

def test_get_defaults_override(self):
class Params(Parameters):
array_first = True
defaults = {
"schema": {
"labels": {
"somelabel": {
"type": "int",
"validators": {"range": {"min": 0, "max": 2}},
}
}
},
"hello_world": {
"title": "Hello, World!",
"description": "Simplest config possible.",
"type": "str",
"value": "hello world",
},
}

def get_defaults(self):
label = self.defaults["schema"]["labels"]["somelabel"]
label["validators"]["range"]["max"] = 5
return self.defaults

params = Params()
assert params.hello_world == "hello world"
assert params.label_grid == {"somelabel": [0, 1, 2, 3, 4, 5]}


def test_schema_not_dropped(self, defaults_spec_path):
with open(defaults_spec_path, "r") as f:
defaults_ = json.loads(f.read())
Expand Down

0 comments on commit d68b28a

Please sign in to comment.