Skip to content

Commit

Permalink
Support facets for nested arrays and fields
Browse files: browse the repository at this point in the history
  • Loading branch information
Cito committed Aug 7, 2024
1 parent 933241e commit 1bb1803
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 41 deletions.
62 changes: 30 additions & 32 deletions src/mass/adapters/outbound/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,6 @@ def pipeline_match_text_search(*, query: str) -> JsonObject:
return {"$match": text_search}


def args_for_getfield(*, root_object_name: str, field_name: str) -> tuple[str, str]:
    """Split a (possibly dotted) field name into arguments for `$getField`.

    Fieldpath names can't have '.', so a nested field is specified by handing
    its last path component to `$getField` as the field, and the "$"-prefixed
    path of its parent as the input.

    Args:
        root_object_name: Name of the root object, given without the leading "$".
        field_name: Field name relative to the root object; nesting is expressed
            with dots, e.g. "friends.name".

    Returns:
        A tuple `(prefix, specified_field)`: `prefix` is "$" + root object name,
        extended by the dotted parent path when `field_name` is nested, and
        `specified_field` is the last component of `field_name`.
    """
    # rpartition splits on the LAST dot only; the separator is empty when
    # field_name contains no dot at all, i.e. the field sits directly on the root.
    parent_path, separator, specified_field = field_name.rpartition(".")
    prefix = f"${root_object_name}"
    if separator:
        prefix += f".{parent_path}"
    return prefix, specified_field


def pipeline_match_filters_stage(*, filters: list[models.Filter]) -> JsonObject:
"""Build segment of pipeline to apply search filters"""
filter_values = defaultdict(list)
Expand Down Expand Up @@ -95,35 +83,45 @@ def pipeline_facet_sort_and_paginate(
segment: dict[str, list[JsonObject]] = {}

for facet in facet_fields:
prefix, specified_field = args_for_getfield(
root_object_name="content", field_name=facet.key
)
name = facet.name
if not name:
name = name_from_key(facet.key)
segment[name] = [
{
"$unwind": {
"path": prefix,
"preserveNullAndEmptyArrays": True,
}
},
pipeline: list[JsonObject] = [
{
"$unwind": {
"path": f"{prefix}.{specified_field}",
"path": "$content",
"preserveNullAndEmptyArrays": True,
}
},
{
"$group": {
"_id": {"$getField": {"field": specified_field, "input": prefix}},
"count": {"$sum": 1},
}
},
{"$addFields": {"value": "$_id"}}, # rename "_id" to "value" on each option
{"$unset": "_id"},
{"$sort": {"value": 1}},
]
path = "$content"
for field in facet.key.split("."):
path += f".{field}"
pipeline.append(
{
"$unwind": {
"path": path,
"preserveNullAndEmptyArrays": True,
}
},
)
path, field = path.rsplit(".", 1)
pipeline.extend(
(
{
"$group": {
"_id": {"$getField": {"field": field, "input": path}},
"count": {"$sum": 1},
}
},
{
"$addFields": {"value": "$_id"}
}, # rename "_id" to "value" on each option
{"$unset": "_id"},
{"$sort": {"value": 1}},
)
)
segment[name] = pipeline

# this is the total number of hits, but pagination can mean only a few are returned
segment["count"] = [{"$count": "total"}]
Expand Down
2 changes: 2 additions & 0 deletions tests/fixtures/test_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ searchable_classes:
name: Food
- key: friends.name
name: Friend
- key: special.features.fur.color
name: Fur color
selected_fields:
- key: name
resource_change_event_topic: searchable_resources
Expand Down
51 changes: 51 additions & 0 deletions tests/fixtures/test_data/FilteringTests.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@
],
"id_": "1",
"name": "Jack",
"special": {
"features": [
{
"fur": {
"color": "brown"
}
}
]
},
"species": "monkey"
},
{
Expand All @@ -34,6 +43,15 @@
],
"id_": "2",
"name": "Bruiser",
"special": {
"features": [
{
"fur": {
"color": "light brown"
}
}
]
},
"species": "dog"
},
{
Expand All @@ -54,6 +72,18 @@
],
"id_": "3",
"name": "Lady",
"special": {
"features": [
{
"fur": {
"color": [
"cream",
"brown"
]
}
}
]
},
"species": "dog"
},
{
Expand All @@ -80,6 +110,18 @@
],
"id_": "4",
"name": "Garfield",
"special": {
"features": [
{
"fur": {
"color": [
"orange",
"black"
]
}
}
]
},
"species": "cat"
},
{
Expand All @@ -103,6 +145,15 @@
],
"id_": "5",
"name": "Flipper",
"special": {
"features": [
{
"fur": {
"color": "gray"
}
}
]
},
"species": "dolphin"
}
]
Expand Down
67 changes: 58 additions & 9 deletions tests/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def test_facets(joint_fixture: JointFixture):
results = await joint_fixture.call_search_endpoint(params)

facets = results.facets
assert len(facets) == 3
assert len(facets) == 4

facet = facets[0]
assert facet.key == "species"
Expand Down Expand Up @@ -77,11 +77,25 @@ async def test_facets(joint_fixture: JointFixture):
}
assert list(options) == sorted(options)

facet = facets[3]
assert facet.key == "special.features.fur.color"
assert facet.name == "Fur color"
options = {option.value: option.count for option in facet.options}
assert options == {
"black": 1,
"brown": 2,
"cream": 1,
"gray": 1,
"light brown": 1,
"orange": 1,
}
assert list(options) == sorted(options)


@pytest.mark.parametrize(
"species,names",
[("mouse", []), ("cat", ["Garfield"]), ("dog", ["Bruiser", "Lady"])],
ids=[0, 1, 2],
ids=range(1, 4),
)
async def test_single_valued_with_with_single_filter(
species: str, names: list[str], joint_fixture: JointFixture
Expand All @@ -101,7 +115,7 @@ async def test_single_valued_with_with_single_filter(

# Check that the facet only contains the filtered values
facets = results.facets
assert len(facets) == 3
assert len(facets) == 4
facet = facets[0]
assert facet.key == "species"
assert facet.name == "Species"
Expand All @@ -118,7 +132,7 @@ async def test_single_valued_with_with_single_filter(
@pytest.mark.parametrize(
"food,names",
[("broccoli", []), ("bananas", ["Jack"]), ("fish", ["Garfield", "Flipper"])],
ids=[0, 1, 2],
ids=range(1, 4),
)
async def test_multi_valued_with_with_single_filter(
food: str, names: list[str], joint_fixture: JointFixture
Expand All @@ -138,7 +152,7 @@ async def test_multi_valued_with_with_single_filter(

# Check that the facet only contains the filtered values
facets = results.facets
assert len(facets) == 3
assert len(facets) == 4
facet = facets[1]
assert facet.key == "eats"
assert facet.name == "Food"
Expand Down Expand Up @@ -188,7 +202,7 @@ async def test_multiple_filters(joint_fixture: JointFixture):
("Jon", ["Garfield", "Flipper"]),
("Jack", ["Jack", "Bruiser", "Garfield"]),
],
ids=[0, 1, 2, 3],
ids=range(1, 5),
)
async def test_filter_for_common_friend(
friend: str, names: list[str], joint_fixture: JointFixture
Expand All @@ -208,7 +222,7 @@ async def test_filter_for_common_friend(

# Check that the facet contains the friend in question
facets = results.facets
assert len(facets) == 3
assert len(facets) == 4
facet = facets[2]
assert facet.key == "friends.name"
assert facet.name == "Friend"
Expand All @@ -217,7 +231,7 @@ async def test_filter_for_common_friend(


async def test_filter_for_couple_of_friends(joint_fixture: JointFixture):
"""Test that we can search for animals who have one of the specified friends"""
"""Test that we can search for animals with one of the specified friends"""
params: QueryParams = {
"class_name": CLASS_NAME,
"filter_by": ["friends.name"] * 3,
Expand All @@ -232,11 +246,46 @@ async def test_filter_for_couple_of_friends(joint_fixture: JointFixture):

# Check that the facet contains the selected friends and all their co-friends
facets = results.facets
assert len(facets) == 3
assert len(facets) == 4
facet = facets[2]
assert facet.key == "friends.name"
assert facet.name == "Friend"
option_values = [option.value for option in facet.options]
co_friends = "Buddy Hector Jack Jog Jon Peter Sandy Tramp Trusty".split()
assert option_values == co_friends
assert all(option.count == 1 for option in facet.options)


@pytest.mark.parametrize(
    "color,names",
    [
        ("green", []),
        ("orange", ["Garfield"]),
        ("brown", ["Jack", "Lady"]),
    ],
    ids=range(1, 4),
)
async def test_filter_for_fur_color(
    color: str, names: list[str], joint_fixture: JointFixture
):
    """Test that filtering on the nested fur color facet returns the expected animals"""
    facet_key = "special.features.fur.color"
    params: QueryParams = {
        "class_name": CLASS_NAME,
        "filter_by": facet_key,
        "value": color,
    }

    results = await joint_fixture.call_search_endpoint(params)

    # The hits must be exactly the expected animals, in the expected order
    hit_names = [hit.content["name"] for hit in results.hits]
    assert hit_names == names

    # The fur color facet is the fourth of the four returned facets
    all_facets = results.facets
    assert len(all_facets) == 4
    fur_facet = all_facets[3]
    assert fur_facet.key == facet_key
    assert fur_facet.name == "Fur color"

    # The filtered color is only required among the options when there are hits
    if names:
        assert any(option.value == color for option in fur_facet.options)

0 comments on commit 1bb1803

Please sign in to comment.