Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add get_defaults method to allow sub-class to customize parameter loading #138

Merged
merged 1 commit into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paramtools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
"select_ne",
"read_json",
"get_example_paths",
"get_defaults",
"LeafGetter",
"get_leaves",
"ravel",
Expand Down
1 change: 1 addition & 0 deletions paramtools/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class InconsistentLabelsException(ParamToolsError):
"to_dict",
"_parse_validation_messages",
"sel",
"get_defaults",
]


Expand Down
14 changes: 13 additions & 1 deletion paramtools/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
sort_values: bool = True,
**ops,
):
schemafactory = SchemaFactory(self.defaults)
schemafactory = SchemaFactory(self.get_defaults())
(
self._defaults_schema,
self._validator_schema,
Expand Down Expand Up @@ -1403,3 +1403,15 @@ def keyfunc(vo, label, label_values):
)[param]
setattr(self, param, sorted_values)
return data

def get_defaults(self):
"""
Hook for implementing custom behavior for getting the default parameters.


**Returns**

- `params`: String if URL or file path. Dict if this is the loaded params
dict.
"""
return utils.read_json(self.defaults)
2 changes: 0 additions & 2 deletions paramtools/schema_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
get_param_schema,
ParamToolsSchema,
)
from paramtools import utils


class SchemaFactory:
Expand All @@ -26,7 +25,6 @@ class SchemaFactory:
"""

def __init__(self, defaults):
defaults = utils.read_json(defaults)
self.defaults = {k: v for k, v in defaults.items() if k != "schema"}
self.schema = ParamToolsSchema().load(defaults.get("schema", {}))
(self.BaseParamSchema, self.label_validators) = get_param_schema(
Expand Down
30 changes: 30 additions & 0 deletions paramtools/tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,36 @@ class Params(Parameters):
assert params.hello_world == "hello world"
assert params.label_grid == {}

def test_get_defaults_override(self):
class Params(Parameters):
array_first = True
defaults = {
"schema": {
"labels": {
"somelabel": {
"type": "int",
"validators": {"range": {"min": 0, "max": 2}},
}
}
},
"hello_world": {
"title": "Hello, World!",
"description": "Simplest config possible.",
"type": "str",
"value": "hello world",
},
}

def get_defaults(self):
label = self.defaults["schema"]["labels"]["somelabel"]
label["validators"]["range"]["max"] = 5
return self.defaults

params = Params()
assert params.hello_world == "hello world"
assert params.label_grid == {"somelabel": [0, 1, 2, 3, 4, 5]}


def test_schema_not_dropped(self, defaults_spec_path):
with open(defaults_spec_path, "r") as f:
defaults_ = json.loads(f.read())
Expand Down