From f732ff72202d9cb6c867d79cc5ae4ad75861522e Mon Sep 17 00:00:00 2001 From: hdoupe Date: Fri, 26 Jul 2024 13:38:33 -0400 Subject: [PATCH] Add get_defaults method to allow sub-class to customize parameter loading --- paramtools/__init__.py | 1 + paramtools/exceptions.py | 1 + paramtools/parameters.py | 14 +++++++++++++- paramtools/schema_factory.py | 2 -- paramtools/tests/test_parameters.py | 30 +++++++++++++++++++++++++++++ 5 files changed, 45 insertions(+), 3 deletions(-) diff --git a/paramtools/__init__.py b/paramtools/__init__.py index d8088f2..b99bcd0 100644 --- a/paramtools/__init__.py +++ b/paramtools/__init__.py @@ -89,6 +89,7 @@ "select_ne", "read_json", "get_example_paths", + "get_defaults", "LeafGetter", "get_leaves", "ravel", diff --git a/paramtools/exceptions.py b/paramtools/exceptions.py index fe8b9ef..c4f51ed 100644 --- a/paramtools/exceptions.py +++ b/paramtools/exceptions.py @@ -94,6 +94,7 @@ class InconsistentLabelsException(ParamToolsError): "to_dict", "_parse_validation_messages", "sel", + "get_defaults", ] diff --git a/paramtools/parameters.py b/paramtools/parameters.py index eff211e..fe3e908 100644 --- a/paramtools/parameters.py +++ b/paramtools/parameters.py @@ -80,7 +80,7 @@ def __init__( sort_values: bool = True, **ops, ): - schemafactory = SchemaFactory(self.defaults) + schemafactory = SchemaFactory(self.get_defaults()) ( self._defaults_schema, self._validator_schema, @@ -1403,3 +1403,15 @@ def keyfunc(vo, label, label_values): )[param] setattr(self, param, sorted_values) return data + + def get_defaults(self): + """ + Hook for implementing custom behavior for getting the default parameters. + + + **Returns** + + - `params`: String if URL or file path. Dict if this is the loaded params + dict. + """ + return utils.read_json(self.defaults) \ No newline at end of file diff --git a/paramtools/schema_factory.py b/paramtools/schema_factory.py index 32a5480..b841169 100644 --- a/paramtools/schema_factory.py +++ b/paramtools/schema_factory.py @@ -8,7 +8,6 @@ get_param_schema, ParamToolsSchema, ) -from paramtools import utils class SchemaFactory: @@ -26,7 +25,6 @@ class SchemaFactory: """ def __init__(self, defaults): - defaults = utils.read_json(defaults) self.defaults = {k: v for k, v in defaults.items() if k != "schema"} self.schema = ParamToolsSchema().load(defaults.get("schema", {})) (self.BaseParamSchema, self.label_validators) = get_param_schema( diff --git a/paramtools/tests/test_parameters.py b/paramtools/tests/test_parameters.py index 150c9ed..1c5a6ec 100644 --- a/paramtools/tests/test_parameters.py +++ b/paramtools/tests/test_parameters.py @@ -137,6 +137,36 @@ class Params(Parameters): assert params.hello_world == "hello world" assert params.label_grid == {} + def test_get_defaults_override(self): + class Params(Parameters): + array_first = True + defaults = { + "schema": { + "labels": { + "somelabel": { + "type": "int", + "validators": {"range": {"min": 0, "max": 2}}, + } + } + }, + "hello_world": { + "title": "Hello, World!", + "description": "Simplest config possible.", + "type": "str", + "value": "hello world", + }, + } + + def get_defaults(self): + label = self.defaults["schema"]["labels"]["somelabel"] + label["validators"]["range"]["max"] = 5 + return self.defaults + + params = Params() + assert params.hello_world == "hello world" + assert params.label_grid == {"somelabel": [0, 1, 2, 3, 4, 5]} + + def test_schema_not_dropped(self, defaults_spec_path): with open(defaults_spec_path, "r") as f: defaults_ = json.loads(f.read())