Skip to content

Commit

Permalink
Refactor sorting tests and add more tests cases
Browse files Browse the repository at this point in the history
Particularly, test the case that some of the sorting parameters are unselected fields.
  • Loading branch information
Cito committed Aug 2, 2024
1 parent a1eb80d commit 6c01a5f
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 89 deletions.
35 changes: 17 additions & 18 deletions src/mass/adapters/outbound/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,23 +89,20 @@ def pipeline_facet_sort_and_paginate(
# this is the total number of hits, but pagination can mean only a few are returned
segment["count"] = [{"$count": "total"}]

segment["hits"] = []
# rename the ID field to id_ to match our model
segment["hits"] = [{"$addFields": {"id_": "$_id"}}, {"$unset": "_id"}]

if project:
# pick only the selected fields
segment["hits"].append({"$project": project})
else:
# rename the ID field to id_ to match our model
segment["hits"] = [{"$addFields": {"id_": "$_id"}}, {"$unset": "_id"}]

# apply sorting parameters
# apply sorting parameters (maybe some of them are unselected fields)
if sort:
segment["hits"].append({"$sort": sort})

# pick only the selected fields
if project:
segment["hits"].append({"$project": project})

# apply skip and limit for pagination
if skip > 0:
segment["hits"].append({"$skip": skip})

if limit:
segment["hits"].append({"$limit": limit})

Expand Down Expand Up @@ -148,17 +145,19 @@ def build_pipeline( # noqa: PLR0913
pipeline.append(pipeline_match_filters_stage(filters=filters))

# turn the selected fields into a formatted pipeline $project
keys = [field.key for field in selected_fields]
project: dict[str, Any] = {
key if key == "id_" else f"content.{key}": "$_id" if key == "id_" else 1
for key in keys
}
if "id_" in project:
project["_id"] = 0
project: dict[str, int] = dict.fromkeys(
[
field.key if field.key == "id_" else f"content.{field.key}"
for field in selected_fields
],
1,
)

# turn the sorting parameters into a formatted pipeline $sort
sort: dict[str, Any] = {
param.field: SORT_ORDER_CONVERSION[param.order.value]
param.field
if param.field == "id_"
else f"content.{param.field}": SORT_ORDER_CONVERSION[param.order.value]
for param in sorting_parameters
}

Expand Down
4 changes: 2 additions & 2 deletions src/mass/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@


class FieldLabel(BaseModel):
"""Contains the key and corresponding user-friendly name for a field"""
"""Contains the field name and corresponding user-friendly name"""

key: str = Field(..., description="The raw field key, such as study.type")
key: str = Field(..., description="The raw field name, such as study.type")
name: str = Field(
default="",
description="A user-friendly name for the field (leave empty to use the key)",
Expand Down
8 changes: 7 additions & 1 deletion tests/fixtures/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class JointFixture:
kafka: KafkaFixture
mongodb: MongoDbFixture
rest_client: AsyncTestClient
resources: dict[str, list[models.Resource]]

def remove_db_data(self) -> None:
"""Delete everything in the database to start from a clean slate"""
Expand All @@ -62,6 +63,7 @@ async def load_test_data(self) -> None:
if match_obj:
collection_name = match_obj.group(1)
resources = get_resources_from_file(filename)
self.resources[collection_name] = resources
for resource in resources:
await self.query_handler.load_resource(
resource=resource, class_name=collection_name
Expand All @@ -70,8 +72,11 @@ async def load_test_data(self) -> None:
async def call_search_endpoint(self, params: QueryParams) -> models.QueryResults:
"""Convenience function to call the /search endpoint"""
response = await self.rest_client.get(url="/search", params=params)
result = response.json()
assert result is not None, result
assert "detail" in result or "hits" in result, result
response.raise_for_status()
return models.QueryResults(**response.json())
return models.QueryResults(**result)


@pytest_asyncio.fixture
Expand All @@ -97,6 +102,7 @@ async def joint_fixture(
kafka=kafka,
mongodb=mongodb,
rest_client=rest_client,
resources={},
)
await joint_fixture.load_test_data()
yield joint_fixture
2 changes: 1 addition & 1 deletion tests/fixtures/test_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ searchable_classes:
- key: id_
name: ID
- key: type
name: Hotel Type
name: Location Type
- key: "has_object.type"
name: Object Type
EmptyCollection:
Expand Down
12 changes: 6 additions & 6 deletions tests/fixtures/test_data/SortingTests.json
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
{
"items": [
{
"field": "some data",
"field": "alpha",
"id_": "i2"
},
{
"field": "some data",
"field": "bravo",
"id_": "i1"
},
{
"field": "some data",
"field": "charlie",
"id_": "i3"
},
{
"field": "some data",
"field": "delta",
"id_": "i5"
},
{
"field": "some data",
"field": "echo",
"id_": "i6"
},
{
"field": "some data",
"field": "foxtrot",
"id_": "i4"
}
]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_index_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
QUERY_STRING = "Backrub"


@pytest.mark.parametrize("create_index_manually", (False, True))
@pytest.mark.parametrize("create_index_manually", [False, True], ids=["auto", "manual"])
@pytest.mark.asyncio
async def test_index_creation(joint_fixture: JointFixture, create_index_manually: bool):
"""Test the index creation function."""
Expand Down
Loading

0 comments on commit 6c01a5f

Please sign in to comment.