diff --git a/src/implicitdict/__init__.py b/src/implicitdict/__init__.py index f1c6846..1e64695 100644 --- a/src/implicitdict/__init__.py +++ b/src/implicitdict/__init__.py @@ -1,13 +1,16 @@ +import inspect +from dataclasses import dataclass + import arrow import datetime -from typing import get_args, get_origin, get_type_hints, Dict, Literal, Optional, Type, Union +from typing import get_args, get_origin, get_type_hints, Dict, Literal, \ + Optional, Type, Union, Set, Tuple import pytimeparse _DICT_FIELDS = set(dir({})) -_KEY_ALL_FIELDS = '_all_fields' -_KEY_OPTIONAL_FIELDS = '_optional_fields' +_KEY_FIELDS_INFO = '_fields_info' class ImplicitDict(dict): @@ -87,48 +90,17 @@ def parse(cls, source: Dict, parse_type: Type): return parse_type(**kwargs) def __init__(self, previous_instance: Optional[dict]=None, **kwargs): - super(ImplicitDict, self).__init__() + ancestor_kwargs = {} subtype = type(self) - if not hasattr(subtype, _KEY_ALL_FIELDS): - # Enumerate all fields and default values defined for the subclass - all_fields = set() - annotations = type(self).__annotations__ if hasattr(type(self), '__annotations__') else {} - for key in annotations: - all_fields.add(key) - - attributes = set() - for key in dir(self): - if key not in _DICT_FIELDS and key[0:2] != '__' and not callable(getattr(self, key)): - all_fields.add(key) - attributes.add(key) - - # Identify which fields are Optional - optional_fields = set() - for key, field_type in annotations.items(): - generic_type = get_origin(field_type) - if generic_type is Optional: - optional_fields.add(key) - elif generic_type is Union: - generic_args = get_args(field_type) - if len(generic_args) == 2 and generic_args[1] is type(None): - optional_fields.add(key) - for key in attributes: - if key not in annotations: - optional_fields.add(key) - - setattr(subtype, _KEY_ALL_FIELDS, all_fields) - setattr(subtype, _KEY_OPTIONAL_FIELDS, optional_fields) - else: - all_fields = getattr(subtype, _KEY_ALL_FIELDS) - optional_fields = getattr(subtype, _KEY_OPTIONAL_FIELDS) + all_fields, optional_fields = _get_fields(subtype) # Copy explicit field values passed to the constructor provided_values = set() if previous_instance: for key, value in previous_instance.items(): if key in all_fields: - self[key] = value + ancestor_kwargs[key] = value provided_values.add(key) for key, value in kwargs.items(): if key in all_fields: @@ -137,33 +109,44 @@ def __init__(self, previous_instance: Optional[dict]=None, **kwargs): # actually providing a value; instead, consider it omitting a value. pass else: - self[key] = value + ancestor_kwargs[key] = value provided_values.add(key) # Copy default field values for key in all_fields: if key not in provided_values: - if hasattr(type(self), key): - self[key] = super(ImplicitDict, self).__getattribute__(key) + if hasattr(subtype, key): + ancestor_kwargs[key] = super(ImplicitDict, self).__getattribute__(key) # Make sure all fields without a default and not labeled Optional were provided for key in all_fields: - if key not in self and key not in optional_fields: - raise ValueError('Required field "{}" not specified in {}'.format(key, type(self).__name__)) + if key not in ancestor_kwargs and key not in optional_fields: + raise ValueError('Required field "{}" not specified in {}'.format(key, subtype.__name__)) + + super(ImplicitDict, self).__init__(**ancestor_kwargs) def __getattribute__(self, item): - if hasattr(type(self), _KEY_ALL_FIELDS) and item in getattr(type(self), _KEY_ALL_FIELDS): - return self[item] + self_type = type(self) + if hasattr(self_type, _KEY_FIELDS_INFO): + fields_info_by_type: Dict[str, FieldsInfo] = getattr(self_type, _KEY_FIELDS_INFO) + self_type_name = _fullname(self_type) + if self_type_name in fields_info_by_type: + if item in fields_info_by_type[self_type_name].all_fields: + return self[item] return super(ImplicitDict, self).__getattribute__(item) def __setattr__(self, key, value): - if hasattr(type(self), _KEY_ALL_FIELDS): - if key in getattr(type(self), _KEY_ALL_FIELDS): - self[key] = value - else: - raise KeyError('Attribute "{}" is not defined for "{}" object'.format(key, type(self).__name__)) - else: - super(ImplicitDict, self).__setattr__(key, value) + self_type = type(self) + if hasattr(self_type, _KEY_FIELDS_INFO): + fields_info_by_type: Dict[str, FieldsInfo] = getattr(self_type, _KEY_FIELDS_INFO) + self_type_name = _fullname(self_type) + if self_type_name in fields_info_by_type: + if key in fields_info_by_type[self_type_name].all_fields: + self[key] = value + return + else: + raise KeyError('Attribute "{}" is not defined for "{}" object'.format(key, type(self).__name__)) + super(ImplicitDict, self).__setattr__(key, value) def has_field_with_value(self, field_name: str) -> bool: return field_name in self and self[field_name] is not None @@ -216,6 +199,83 @@ def _parse_value(value, value_type: Type): return value_type(value) if value_type else value +@dataclass +class FieldsInfo(object): + all_fields: Set[str] + optional_fields: Set[str] + + + +def _get_fields(subtype: Type) -> Tuple[Set[str], Set[str]]: + """Determine all fields and optional fields for the specified type. + + When all & optional fields are determined for a type, the result is cached + as an entry in the _KEY_FIELDS_INFO attribute added to the type itself so + this evaluation only needs to be performed once per type. + + Returns: + * Names of all fields for subtype + * Names of all optional fields for subtype + """ + if not hasattr(subtype, _KEY_FIELDS_INFO): + setattr(subtype, _KEY_FIELDS_INFO, {}) + fields_info_by_type: Dict[str, FieldsInfo] = getattr(subtype, _KEY_FIELDS_INFO) + subtype_name = _fullname(subtype) + if subtype_name not in fields_info_by_type: + # Enumerate fields defined for superclasses + all_fields = set() + optional_fields = set() + ancestors = inspect.getmro(subtype) + for ancestor in ancestors: + if issubclass(ancestor, ImplicitDict) and ancestor is not subtype and ancestor is not ImplicitDict: + ancestor_all_fields, ancestor_optional_fields = _get_fields(ancestor) + all_fields = all_fields.union(ancestor_all_fields) + optional_fields = optional_fields.union(ancestor_optional_fields) + + # Enumerate all fields defined for the subclass + annotations = subtype.__annotations__ if hasattr(subtype, '__annotations__') else {} + for key in annotations: + all_fields.add(key) + + attributes = set() + for key in dir(subtype): + if ( + key != _KEY_FIELDS_INFO + and key not in _DICT_FIELDS + and key[0:2] != '__' + and not callable(getattr(subtype, key)) + ): + all_fields.add(key) + attributes.add(key) + + # Identify which fields are Optional + for key, field_type in annotations.items(): + generic_type = get_origin(field_type) + if generic_type is Optional: + optional_fields.add(key) + elif generic_type is Union: + generic_args = get_args(field_type) + if len(generic_args) == 2 and generic_args[1] is type(None): + optional_fields.add(key) + for key in attributes: + if key not in annotations: + optional_fields.add(key) + + fields_info_by_type[subtype_name] = FieldsInfo( + all_fields=all_fields, + optional_fields=optional_fields + ) + result = fields_info_by_type[subtype_name] + return result.all_fields, result.optional_fields + + +def _fullname(class_type: Type) -> str: + module = class_type.__module__ + if module == "builtins": + return class_type.__qualname__ # avoid outputs like 'builtins.str' + return module + "." + class_type.__qualname__ + + class StringBasedTimeDelta(str): """String that only allows values which describe a timedelta.""" def __new__(cls, value): diff --git a/tests/test_inheritance.py b/tests/test_inheritance.py new file mode 100644 index 0000000..f588cec --- /dev/null +++ b/tests/test_inheritance.py @@ -0,0 +1,71 @@ +import json +from typing import Optional + +from implicitdict import ImplicitDict + + +class MyData(ImplicitDict): + foo: str + bar: int = 0 + baz: Optional[float] + has_default_baseclass: str = "In MyData" + + def hello(self) -> str: + return "MyData" + + def base_method(self) -> int: + return 123 + + +class MySubclass(MyData): + buzz: Optional[str] + has_default_subclass: str = "In MySubclass" + + def hello(self) -> str: + return "MySubclass" + + +def test_inheritance(): + data: MyData = ImplicitDict.parse({'foo': 'asdf', 'bar': 1}, MyData) + assert json.loads(json.dumps(data)) == {"foo": "asdf", "bar": 1, "has_default_baseclass": "In MyData"} + assert data.hello() == "MyData" + assert data.has_default_baseclass == "In MyData" + + subclass: MySubclass = ImplicitDict.parse(json.loads(json.dumps(data)), MySubclass) + assert subclass.foo == "asdf" + assert subclass.bar == 1 + assert "baz" not in subclass + assert subclass.hello() == "MySubclass" + assert subclass.base_method() == 123 + subclass.buzz = "burrs" + assert subclass.has_default_baseclass == "In MyData" + assert subclass.has_default_subclass == "In MySubclass" + subclass.has_default_baseclass = "In MyData 2" + subclass.has_default_subclass = "In MySubclass 2" + + subclass = MySubclass(data) + assert subclass.foo == "asdf" + assert subclass.bar == 1 + assert "baz" not in subclass + assert subclass.hello() == "MySubclass" + assert subclass.base_method() == 123 + assert "buzz" not in subclass + subclass.buzz = "burrs" + assert subclass.has_default_baseclass == "In MyData" + assert subclass.has_default_subclass == "In MySubclass" + subclass.has_default_baseclass = "In MyData 3" + subclass.has_default_subclass = "In MySubclass 3" + + data2 = MyData(subclass) + assert data2.foo == "asdf" + assert data2.bar == 1 + assert "baz" not in data2 + assert data2.hello() == "MyData" + assert "buzz" not in data2 + assert data2.has_default_baseclass == "In MyData 3" + data2.has_default_baseclass = "In MyData 4" + + subclass2 = MySubclass(subclass) + assert subclass2.buzz == "burrs" + assert subclass.has_default_baseclass == "In MyData 3" + assert subclass.has_default_subclass == "In MySubclass 3" diff --git a/tests/test_mutability.py b/tests/test_mutability.py new file mode 100644 index 0000000..135745b --- /dev/null +++ b/tests/test_mutability.py @@ -0,0 +1,56 @@ +import json +from typing import Optional, List + +from implicitdict import ImplicitDict + + +class MyData(ImplicitDict): + primitive: str + list_of_primitives: List[str] + generic_dict: dict + subtype: Optional["MyData"] + + +def test_mutability_from_constructor(): + primitive = 'foobar' + primitive_list = ['one', 'two'] + generic_dict = {'level1': 'foo', 'level2': {'bar': 'baz'}} + data = MyData(primitive=primitive, list_of_primitives=primitive_list, generic_dict=generic_dict) + assert data.primitive == primitive + assert data.list_of_primitives == primitive_list + assert data.generic_dict == generic_dict + + primitive = 'foobar2' + assert data.primitive != primitive + + primitive_list[1] = 'three' + assert data.list_of_primitives[1] == 'three' + + generic_dict['level1'] = 'bar' + assert data.generic_dict['level1'] == 'bar' + + generic_dict['level2']['bar'] = 'buzz' + assert data.generic_dict['level2']['bar'] == 'buzz' + + +def test_mutability_from_parse(): + primitive = 'foobar' + primitive_list = ['one', 'two'] + generic_dict = {'level1': 'foo', 'level2': {'bar': 'baz'}} + data_source = MyData(primitive=primitive, list_of_primitives=primitive_list, generic_dict=generic_dict) + data: MyData = ImplicitDict.parse(data_source, MyData) + assert data.primitive == primitive + assert data.list_of_primitives == primitive_list + assert data.generic_dict == generic_dict + + primitive = 'foobar2' + assert data.primitive != primitive + + primitive_list[1] = 'three' + assert data.list_of_primitives[1] == 'three' + + generic_dict['level1'] = 'bar' + assert data.generic_dict['level1'] == 'foo' # <-- dicts are reconstructed with `parse` + + generic_dict['level2']['bar'] = 'buzz' + assert data.generic_dict['level2']['bar'] == 'buzz' diff --git a/tests/test_optional.py b/tests/test_optional.py new file mode 100644 index 0000000..44a3d2a --- /dev/null +++ b/tests/test_optional.py @@ -0,0 +1,130 @@ +import json +from typing import Optional + +import pytest + +from implicitdict import ImplicitDict + + +class MyData(ImplicitDict): + required_field: str + optional_field1: Optional[str] + field_with_default: str = "default value" + optional_field2_with_none_default: Optional[str] = None + optional_field3_with_default: Optional[str] = "concrete default" + + +def test_fully_defined(): + data = MyData( + required_field="foo1", + optional_field1="foo2", + field_with_default="foo3", + optional_field2_with_none_default="foo4", + optional_field3_with_default="foo5", + ) + assert "required_field" in data + assert "optional_field1" in data + assert "field_with_default" in data + assert "optional_field2_with_none_default" in data + assert "optional_field3_with_default" in data + s = json.dumps(data) + assert "required_field" in s + assert "optional_field1" in s + assert "field_with_default" in s + assert "optional_field2" in s + assert "optional_field3" in s + assert "foo1" in s + assert "foo2" in s + assert "foo3" in s + assert "foo4" in s + assert "foo5" in s + + +def test_minimally_defined(): + # An unspecified optional field will not be present in the object at all + data = MyData(required_field="foo1") + assert "required_field" in data + assert "optional_field1" not in data + assert "field_with_default" in data + assert "optional_field2_with_none_default" in data + assert "optional_field3_with_default" in data + with pytest.raises(KeyError): + # Trying to reference the Optional field will result in a KeyError + # To determine whether an Optional field is present, the user must check + # whether `"" in ` (see above). + assert data.optional_field1 == None + s = json.dumps(data) + assert "required_field" in s + assert "optional_field1" not in s + assert "field_with_default" in s + assert "optional_field2" in s + assert "optional_field3" in s + assert "foo1" in s + + +def test_provide_optional_field(): + data = MyData(required_field="foo1", optional_field1="foo2") + assert "required_field" in data + assert "optional_field1" in data + assert "field_with_default" in data + assert "optional_field2_with_none_default" in data + assert "optional_field3_with_default" in data + s = json.dumps(data) + assert "required_field" in s + assert "optional_field1" in s + assert "field_with_default" in s + assert "optional_field2" in s + assert "optional_field3" in s + assert "foo1" in s + assert "foo2" in s + + +def test_provide_optional_field_as_none(): + # If an optional field with no default is explicitly provided as None, then that field will not be included in the object + data = MyData(required_field="foo1", optional_field1=None) + assert "required_field" in data + assert "optional_field1" not in data # <-- + assert "field_with_default" in data + assert "optional_field2_with_none_default" in data + assert "optional_field3_with_default" in data + s = json.dumps(data) + assert "required_field" in s + assert "optional_field1" not in s # <-- + assert "field_with_default" in s + assert "optional_field2" in s + assert "optional_field3" in s + assert "foo1" in s + + +def test_provide_optional_field_with_none_default_as_none(): + # If a field has a default value, the field will always be present in the object, even if that default value is None and the field is Optional + data = MyData(required_field="foo1", optional_field2_with_none_default=None) + assert "required_field" in data + assert "optional_field1" not in data + assert "field_with_default" in data + assert "optional_field2_with_none_default" in data # <-- + assert "optional_field3_with_default" in data + s = json.dumps(data) + assert "required_field" in s + assert "optional_field1" not in s + assert "field_with_default" in s + assert "optional_field2" in s # <-- + assert "optional_field3" in s + assert "foo1" in s + + +def test_provide_optional_field_with_default_as_none(): + # If a field has a default value, the field will always be present in the object + data = MyData(required_field="foo1", optional_field3_with_default=None) + assert "required_field" in data + assert "optional_field1" not in data + assert "field_with_default" in data + assert "optional_field2_with_none_default" in data + assert "optional_field3_with_default" in data # <-- + s = json.dumps(data) + assert "required_field" in s + assert "optional_field1" not in s + assert "field_with_default" in s + assert "optional_field2" in s + assert "optional_field3" in s # <-- + assert "foo1" in s