Skip to content

Commit

Permalink
Merge pull request #49 from chisholm/fix-find-references
Browse files Browse the repository at this point in the history
Fix reference search code to search for references inside objects which are inside arrays.
  • Loading branch information
rpiazza authored Nov 28, 2023
2 parents afbe459 + b74b2b3 commit 7319f21
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
10 changes: 6 additions & 4 deletions stix2generator/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,21 @@ def object_generator21():
"obj, findings", [
({"type": "foo", "a_ref": 1, "b_ref": 2, "c": 3}, {("a_ref", 1), ("b_ref", 2)}),
({"type": "foo", "a_refs": [1, 2], "b_ref": 3, "c": 4}, {("a_refs", 1), ("a_refs", 2), ("b_ref", 3)}),
({"type": "foo", "a": {"b": {"c_ref": 1}}}, {("c_ref", 1)})
({"type": "foo", "a": {"b": {"c_ref": 1}}}, {("c_ref", 1)}),
({"type": "foo", "a": [{"b_ref": 1}, {"b_ref": 2}, {"c": [{"d_ref": 3}]}]}, {("b_ref", 1), ("b_ref", 2), ("d_ref", 3)})
]
)
def test_find_references(obj, findings):
for ref_prop, ref_id in stix2generator.utils.find_references(obj):
assert (ref_prop, ref_id) in findings
found = set(stix2generator.utils.find_references(obj))
assert found == findings


@pytest.mark.parametrize(
"obj", [
{"type": "foo", "a_ref": 1, "b_ref": 1, "c": 1},
{"type": "foo", "a_refs": [1, 1], "b_ref": 1, "c": 1},
{"type": "foo", "a": {"b": {"c_ref": 1}}}
{"type": "foo", "a": {"b": {"c_ref": 1}}},
{"type": "foo", "a": [{"b_ref": 1}, {"b_ref": 1}, {"c": [{"d_ref": 1}]}]}
]
)
def test_find_references_assignable(obj):
Expand Down
14 changes: 12 additions & 2 deletions stix2generator/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def recurse_references(obj):
observed-data/objects.
:param obj: An object. Can be any type, but values will only be produced
from mappings.
from mappings (or iterables containing mappings).
"""
if isinstance(obj, collections.abc.Mapping):
for prop, value in obj.items():
Expand All @@ -82,6 +82,11 @@ def recurse_references(obj):
else:
yield from recurse_references(value)

elif isinstance(obj, collections.abc.Iterable) \
and not isinstance(obj, str):
for elt in obj:
yield from recurse_references(elt)


def find_references(obj):
"""
Expand Down Expand Up @@ -116,7 +121,7 @@ def recurse_references_assignable(obj):
special casing for observed-data/objects.
:param obj: An object. Can be any type, but values will only be produced
from mappings.
from mappings (or iterables containing mappings).
"""
if isinstance(obj, collections.abc.Mapping):
for prop, value in obj.items():
Expand All @@ -130,6 +135,11 @@ def recurse_references_assignable(obj):
else:
yield from recurse_references_assignable(value)

elif isinstance(obj, collections.abc.Iterable) \
and not isinstance(obj, str):
for elt in obj:
yield from recurse_references_assignable(elt)


def find_references_assignable(obj):
"""
Expand Down

0 comments on commit 7319f21

Please sign in to comment.