From 1bb1803239b6f5e00bb01096007ba0b6bb75249e Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Wed, 7 Aug 2024 11:26:59 +0000 Subject: [PATCH] Support facets for nested arrays and fields --- src/mass/adapters/outbound/utils.py | 62 +++++++++--------- tests/fixtures/test_config.yaml | 2 + tests/fixtures/test_data/FilteringTests.json | 51 +++++++++++++++ tests/test_filtering.py | 67 +++++++++++++++++--- 4 files changed, 141 insertions(+), 41 deletions(-) diff --git a/src/mass/adapters/outbound/utils.py b/src/mass/adapters/outbound/utils.py index 8c24b29..e14a284 100644 --- a/src/mass/adapters/outbound/utils.py +++ b/src/mass/adapters/outbound/utils.py @@ -41,18 +41,6 @@ def pipeline_match_text_search(*, query: str) -> JsonObject: return {"$match": text_search} -def args_for_getfield(*, root_object_name: str, field_name: str) -> tuple[str, str]: - """Fieldpath names can't have '.', so specify any nested fields with $getField""" - prefix = f"${root_object_name}" - specified_field = field_name - if "." in field_name: - pieces = field_name.split(".") - specified_field = pieces[-1] - prefix += "." + ".".join(pieces[:-1]) - - return prefix, specified_field - - def pipeline_match_filters_stage(*, filters: list[models.Filter]) -> JsonObject: """Build segment of pipeline to apply search filters""" filter_values = defaultdict(list) @@ -95,35 +83,45 @@ def pipeline_facet_sort_and_paginate( segment: dict[str, list[JsonObject]] = {} for facet in facet_fields: - prefix, specified_field = args_for_getfield( - root_object_name="content", field_name=facet.key - ) name = facet.name if not name: name = name_from_key(facet.key) - segment[name] = [ - { - "$unwind": { - "path": prefix, - "preserveNullAndEmptyArrays": True, - } - }, + pipeline: list[JsonObject] = [ { "$unwind": { - "path": f"{prefix}.{specified_field}", + "path": "$content", "preserveNullAndEmptyArrays": True, } }, - { - "$group": { - "_id": {"$getField": {"field": specified_field, "input": prefix}}, - "count": {"$sum": 1}, - } - }, - {"$addFields": {"value": "$_id"}}, # rename "_id" to "value" on each option - {"$unset": "_id"}, - {"$sort": {"value": 1}}, ] + path = "$content" + for field in facet.key.split("."): + path += f".{field}" + pipeline.append( + { + "$unwind": { + "path": path, + "preserveNullAndEmptyArrays": True, + } + }, + ) + path, field = path.rsplit(".", 1) + pipeline.extend( + ( + { + "$group": { + "_id": {"$getField": {"field": field, "input": path}}, + "count": {"$sum": 1}, + } + }, + { + "$addFields": {"value": "$_id"} + }, # rename "_id" to "value" on each option + {"$unset": "_id"}, + {"$sort": {"value": 1}}, + ) + ) + segment[name] = pipeline # this is the total number of hits, but pagination can mean only a few are returned segment["count"] = [{"$count": "total"}] diff --git a/tests/fixtures/test_config.yaml b/tests/fixtures/test_config.yaml index 700fc91..987edd7 100644 --- a/tests/fixtures/test_config.yaml +++ b/tests/fixtures/test_config.yaml @@ -61,6 +61,8 @@ searchable_classes: name: Food - key: friends.name name: Friend + - key: special.features.fur.color + name: Fur color selected_fields: - key: name resource_change_event_topic: searchable_resources diff --git a/tests/fixtures/test_data/FilteringTests.json b/tests/fixtures/test_data/FilteringTests.json index 480d979..3ccd750 100644 --- a/tests/fixtures/test_data/FilteringTests.json +++ b/tests/fixtures/test_data/FilteringTests.json @@ -14,6 +14,15 @@ ], "id_": "1", "name": "Jack", + "special": { + "features": [ + { + "fur": { + "color": "brown" + } + } + ] + }, "species": "monkey" }, { @@ -34,6 +43,15 @@ ], "id_": "2", "name": "Bruiser", + "special": { + "features": [ + { + "fur": { + "color": "light brown" + } + } + ] + }, "species": "dog" }, { @@ -54,6 +72,18 @@ ], "id_": "3", "name": "Lady", + "special": { + "features": [ + { + "fur": { + "color": [ + "cream", + "brown" + ] + } + } + ] + }, "species": "dog" }, { @@ -80,6 +110,18 @@ ], "id_": "4", "name": "Garfield", + "special": { + "features": [ + { + "fur": { + "color": [ + "orange", + "black" + ] + } + } + ] + }, "species": "cat" }, { @@ -103,6 +145,15 @@ ], "id_": "5", "name": "Flipper", + "special": { + "features": [ + { + "fur": { + "color": "gray" + } + } + ] + }, "species": "dolphin" } ] diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 399fa29..e4616a0 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -31,7 +31,7 @@ async def test_facets(joint_fixture: JointFixture): results = await joint_fixture.call_search_endpoint(params) facets = results.facets - assert len(facets) == 3 + assert len(facets) == 4 facet = facets[0] assert facet.key == "species" @@ -77,11 +77,25 @@ async def test_facets(joint_fixture: JointFixture): } assert list(options) == sorted(options) + facet = facets[3] + assert facet.key == "special.features.fur.color" + assert facet.name == "Fur color" + options = {option.value: option.count for option in facet.options} + assert options == { + "black": 1, + "brown": 2, + "cream": 1, + "gray": 1, + "light brown": 1, + "orange": 1, + } + assert list(options) == sorted(options) + @pytest.mark.parametrize( "species,names", [("mouse", []), ("cat", ["Garfield"]), ("dog", ["Bruiser", "Lady"])], - ids=[0, 1, 2], + ids=range(1, 4), ) async def test_single_valued_with_with_single_filter( species: str, names: list[str], joint_fixture: JointFixture @@ -101,7 +115,7 @@ async def test_single_valued_with_with_single_filter( # Check that the facet only contains the filtered values facets = results.facets - assert len(facets) == 3 + assert len(facets) == 4 facet = facets[0] assert facet.key == "species" assert facet.name == "Species" @@ -118,7 +132,7 @@ async def test_single_valued_with_with_single_filter( @pytest.mark.parametrize( "food,names", [("broccoli", []), ("bananas", ["Jack"]), ("fish", ["Garfield", "Flipper"])], - ids=[0, 1, 2], + ids=range(1, 4), ) async def test_multi_valued_with_with_single_filter( food: str, names: list[str], joint_fixture: JointFixture @@ -138,7 +152,7 @@ async def test_multi_valued_with_with_single_filter( # Check that the facet only contains the filtered values facets = results.facets - assert len(facets) == 3 + assert len(facets) == 4 facet = facets[1] assert facet.key == "eats" assert facet.name == "Food" @@ -188,7 +202,7 @@ async def test_multiple_filters(joint_fixture: JointFixture): ("Jon", ["Garfield", "Flipper"]), ("Jack", ["Jack", "Bruiser", "Garfield"]), ], - ids=[0, 1, 2, 3], + ids=range(1, 5), ) async def test_filter_for_common_friend( friend: str, names: list[str], joint_fixture: JointFixture @@ -208,7 +222,7 @@ async def test_filter_for_common_friend( # Check that the facet contains the friend in question facets = results.facets - assert len(facets) == 3 + assert len(facets) == 4 facet = facets[2] assert facet.key == "friends.name" assert facet.name == "Friend" @@ -217,7 +231,7 @@ async def test_filter_for_common_friend( async def test_filter_for_couple_of_friends(joint_fixture: JointFixture): - """Test that we can search for animals who have one of the specified friends""" + """Test that we can search for animals with one of the specified friends""" params: QueryParams = { "class_name": CLASS_NAME, "filter_by": ["friends.name"] * 3, @@ -232,7 +246,7 @@ async def test_filter_for_couple_of_friends(joint_fixture: JointFixture): # Check that the facet contains the selected friends and all their co-friends facets = results.facets - assert len(facets) == 3 + assert len(facets) == 4 facet = facets[2] assert facet.key == "friends.name" assert facet.name == "Friend" @@ -240,3 +254,38 @@ async def test_filter_for_couple_of_friends(joint_fixture: JointFixture): co_friends = "Buddy Hector Jack Jog Jon Peter Sandy Tramp Trusty".split() assert option_values == co_friends assert all(option.count == 1 for option in facet.options) + + +@pytest.mark.parametrize( + "color,names", + [ + ("green", []), + ("orange", ["Garfield"]), + ("brown", ["Jack", "Lady"]), + ], + ids=range(1, 4), +) +async def test_filter_for_fur_color( + color: str, names: list[str], joint_fixture: JointFixture +): + """Test that we can search for animals with a given fur color""" + params: QueryParams = { + "class_name": CLASS_NAME, + "filter_by": "special.features.fur.color", + "value": color, + } + + results = await joint_fixture.call_search_endpoint(params) + + # Check that the expected names are returned + returned_names = [resource.content["name"] for resource in results.hits] + assert returned_names == names + + # Check that the facet contains the color in question + facets = results.facets + assert len(facets) == 4 + facet = facets[3] + assert facet.key == "special.features.fur.color" + assert facet.name == "Fur color" + if names: + assert any(option.value == color for option in facet.options)