From b4776a9125943cd0fecd50d9f9968bae66ece0e8 Mon Sep 17 00:00:00 2001 From: Benjamin Pelletier Date: Thu, 30 Mar 2023 11:42:35 -0700 Subject: [PATCH] Fix parsing of values in lists (#7) --- src/implicitdict/__init__.py | 15 ++++++----- tests/test_containers.py | 49 ++++++++++++++++++++++++++++++++++++ tests/test_mutability.py | 2 +- tests/test_normal_usage.py | 6 ++--- 4 files changed, 60 insertions(+), 12 deletions(-) create mode 100644 tests/test_containers.py diff --git a/src/implicitdict/__init__.py b/src/implicitdict/__init__.py index 6922eed..9e13680 100644 --- a/src/implicitdict/__init__.py +++ b/src/implicitdict/__init__.py @@ -159,14 +159,13 @@ def _parse_value(value, value_type: Type): # Type is generic arg_types = get_args(value_type) if generic_type is list: - if get_origin(arg_types[0]) is list: - return value - elif issubclass(arg_types[0], ImplicitDict): - # value is a list of some kind of ImplicitDict values - return [ImplicitDict.parse(item, arg_types[0]) for item in value] - else: - # value is a list of non-ImplicitDict values - return value + try: + value_list = [v for v in value] + except TypeError as e: + if "not iterable" in str(e): + raise ValueError(f"Cannot parse non-iterable value '{value}' of type '{type(value).__name__}' into list type '{value_type}'") + raise + return [_parse_value(v, arg_types[0]) for v in value_list] elif generic_type is dict: # value is a dict of some kind diff --git a/tests/test_containers.py b/tests/test_containers.py new file mode 100644 index 0000000..2df4065 --- /dev/null +++ b/tests/test_containers.py @@ -0,0 +1,49 @@ +from typing import List, Optional + +from implicitdict import ImplicitDict + + +class MySpecialClass(str): + @property + def is_special(self) -> bool: + return True + + +class MyContainers(ImplicitDict): + single_value: MySpecialClass + value_list: List[MySpecialClass] + optional_list: Optional[List[MySpecialClass]] + optional_value_list: List[Optional[MySpecialClass]] + list_of_lists: List[List[MySpecialClass]] + + +def test_container_item_value_casting(): + containers: MyContainers = ImplicitDict.parse( + { + "single_value": "foo", + "value_list": ["value1", "value2"], + "optional_list": ["bar"], + "optional_value_list": ["baz", None], + "list_of_lists": [["list1v1", "list1v2"], ["list2v1"]] + }, MyContainers) + + assert containers.single_value.is_special + + assert len(containers.value_list) == 2 + for v in containers.value_list: + assert v.is_special + + assert len(containers.optional_list) == 1 + assert containers.optional_list[0].is_special + + assert len(containers.optional_value_list) == 2 + for v in containers.optional_value_list: + assert (v is None) or v.is_special + + assert len(containers.list_of_lists) == 2 + assert len(containers.list_of_lists[0]) == 2 + for v in containers.list_of_lists[0]: + assert v.is_special + assert len(containers.list_of_lists[1]) == 1 + for v in containers.list_of_lists[1]: + assert v.is_special diff --git a/tests/test_mutability.py b/tests/test_mutability.py index 222fbf8..31801cf 100644 --- a/tests/test_mutability.py +++ b/tests/test_mutability.py @@ -46,7 +46,7 @@ def test_mutability_from_parse(): assert data.primitive != primitive primitive_list[1] = 'three' - assert data.list_of_primitives[1] == 'three' + assert data.list_of_primitives[1] == 'two' # <-- lists are reconstructed with `parse` generic_dict['level1'] = 'bar' assert data.generic_dict['level1'] == 'foo' # <-- dicts are reconstructed with `parse` diff --git a/tests/test_normal_usage.py b/tests/test_normal_usage.py index 5cb3f6d..4eb12b7 100644 --- a/tests/test_normal_usage.py +++ b/tests/test_normal_usage.py @@ -117,7 +117,7 @@ def test_nested_structures(): 'my_list': [{'foo': 'one'}, {'foo': 'two'}], 'my_list_2': [[1, 2], [3, 4, 5]], 'my_list_3': [[[1, 2, 3], [4, 5]], [[6], [7], [8]], [[9, 10]]], - 'my_dict': {'foo': 1.23, 'bar': 4.56} + 'my_dict': {'foo': [1.23], 'bar': [4.56]} } data: NestedStructures = ImplicitDict.parse(src_dict, NestedStructures) @@ -153,5 +153,5 @@ def test_nested_structures(): assert data.my_list_3[2][0][1] == 10 assert len(data.my_dict) == 2 - assert data.my_dict['foo'] == 1.23 - assert data.my_dict['bar'] == 4.56 + assert data.my_dict['foo'] == [1.23] + assert data.my_dict['bar'] == [4.56]