diff --git a/src/flatten_dict/flatten_dict.py b/src/flatten_dict/flatten_dict.py index 96c0e3c..e3ecc0f 100644 --- a/src/flatten_dict/flatten_dict.py +++ b/src/flatten_dict/flatten_dict.py @@ -31,6 +31,7 @@ def flatten( max_flatten_depth=None, enumerate_types=(), keep_empty_types=(), + keep_element_types=(), ): """Flatten `Mapping` object. @@ -62,6 +63,21 @@ def flatten( >>> flatten({1: 2, 3: {}}, keep_empty_types=(dict,)) {(1,): 2, (3,): {}} + keep_element_types : Sequence[type] + As an option if `enumerate_types` is specified, skip enumerating if at least one of the + elements is listed in `keep_element_types`. + For example, if we set `enumerate_types` to ``(list,)`` and `keep_element_types` to + ``(str, )``, list of str will be kept without being enumerated while list of non-str will + be enumerated: + + >>> flatten({'a': ['b', 'c']}, enumerate_types=(list,), keep_element_types=(str,)) + {('a',): ['b', 'c']} + + >>> flatten({'a': [10, 11]}, enumerate_types=(list,), keep_element_types=(str,)) + {('a', 0): 10, ('a', 1): 11} + + >>> flatten({'a': [{'b': 'foo'}]}, enumerate_types=(list,), keep_element_types=(str,)) + {('a', 0, 'b'): 'foo'} Returns ------- @@ -91,8 +107,14 @@ def _flatten(_d, depth, parent=None): for key, value in key_value_iterable: has_item = True flat_key = reducer(parent, key) - if isinstance(value, flattenable_types) and ( - max_flatten_depth is None or depth < max_flatten_depth + if ( + isinstance(value, flattenable_types) + and (max_flatten_depth is None or depth < max_flatten_depth) + and not ( + keep_element_types + and isinstance(value, enumerate_types) + and any([isinstance(e, keep_element_types) for e in value]) + ) ): # recursively build the result has_child = _flatten(value, depth=depth + 1, parent=flat_key) diff --git a/src/flatten_dict/tests/flatten_dict_test.py b/src/flatten_dict/tests/flatten_dict_test.py index 7cc97bb..4a70f87 100644 --- a/src/flatten_dict/tests/flatten_dict_test.py +++ b/src/flatten_dict/tests/flatten_dict_test.py @@ -316,6 +316,53 @@ def test_flatten_list(): assert flatten([1, 2], enumerate_types=(list,)) == {(0,): 1, (1,): 2} +def test_flatten_dict_with_list_of_str_with_keep_element_types(): + assert ( + flatten( + {"a": ["b", "c"]}, + enumerate_types=(list,), + keep_element_types=(str,), + ) + == {("a",): ["b", "c"]} + ) + + +def test_flatten_dict_with_list_of_int_with_keep_element_types(): + assert ( + flatten( + {"a": [10, 11]}, + enumerate_types=(list,), + keep_element_types=(str,), + ) + == {("a", 0): 10, ("a", 1): 11} + ) + + +def test_flatten_dict_with_list_of_dict_with_keep_element_types(): + assert ( + flatten( + {"a": [{"b": "foo"}]}, + enumerate_types=(list,), + keep_element_types=(str,), + ) + == {("a", 0, "b"): "foo"} + ) + + +def test_flatten_dict_with_list_with_keep_empty_types_and_keep_element_types( + dict_with_list, flat_tuple_dict_with_list +): + assert ( + flatten( + dict_with_list, + enumerate_types=(list,), + keep_empty_types=(list,), + keep_element_types=(str,), + ) + == flat_tuple_dict_with_list + ) + + @pytest.fixture def dict_with_generator(): return {