diff --git a/src/mass/adapters/outbound/utils.py b/src/mass/adapters/outbound/utils.py index 5ddb555..3bede27 100644 --- a/src/mass/adapters/outbound/utils.py +++ b/src/mass/adapters/outbound/utils.py @@ -89,23 +89,20 @@ def pipeline_facet_sort_and_paginate( # this is the total number of hits, but pagination can mean only a few are returned segment["count"] = [{"$count": "total"}] - segment["hits"] = [] + # rename the ID field to id_ to match our model + segment["hits"] = [{"$addFields": {"id_": "$_id"}}, {"$unset": "_id"}] - if project: - # pick only the selected fields - segment["hits"].append({"$project": project}) - else: - # rename the ID field to id_ to match our model - segment["hits"] = [{"$addFields": {"id_": "$_id"}}, {"$unset": "_id"}] - - # apply sorting parameters + # apply sorting parameters (maybe some of them are unselected fields) if sort: segment["hits"].append({"$sort": sort}) + # pick only the selected fields + if project: + segment["hits"].append({"$project": project}) + # apply skip and limit for pagination if skip > 0: segment["hits"].append({"$skip": skip}) - if limit: segment["hits"].append({"$limit": limit}) @@ -148,17 +145,19 @@ def build_pipeline( # noqa: PLR0913 pipeline.append(pipeline_match_filters_stage(filters=filters)) # turn the selected fields into a formatted pipeline $project - keys = [field.key for field in selected_fields] - project: dict[str, Any] = { - key if key == "id_" else f"content.{key}": "$_id" if key == "id_" else 1 - for key in keys - } - if "id_" in project: - project["_id"] = 0 + project: dict[str, int] = dict.fromkeys( + [ + field.key if field.key == "id_" else f"content.{field.key}" + for field in selected_fields + ], + 1, + ) # turn the sorting parameters into a formatted pipeline $sort sort: dict[str, Any] = { - param.field: SORT_ORDER_CONVERSION[param.order.value] + param.field + if param.field == "id_" + else f"content.{param.field}": SORT_ORDER_CONVERSION[param.order.value] for param in sorting_parameters } diff --git a/src/mass/core/models.py b/src/mass/core/models.py index 8d4f58e..5a64d83 100644 --- a/src/mass/core/models.py +++ b/src/mass/core/models.py @@ -22,9 +22,9 @@ class FieldLabel(BaseModel): - """Contains the key and corresponding user-friendly name for a field""" + """Contains the field name and corresponding user-friendly name""" - key: str = Field(..., description="The raw field key, such as study.type") + key: str = Field(..., description="The raw field name, such as study.type") name: str = Field( default="", description="A user-friendly name for the field (leave empty to use the key)", diff --git a/tests/fixtures/joint.py b/tests/fixtures/joint.py index 084dc91..7c19884 100644 --- a/tests/fixtures/joint.py +++ b/tests/fixtures/joint.py @@ -48,6 +48,7 @@ class JointFixture: kafka: KafkaFixture mongodb: MongoDbFixture rest_client: AsyncTestClient + resources: dict[str, list[models.Resource]] def remove_db_data(self) -> None: """Delete everything in the database to start from a clean slate""" @@ -62,6 +63,7 @@ async def load_test_data(self) -> None: if match_obj: collection_name = match_obj.group(1) resources = get_resources_from_file(filename) + self.resources[collection_name] = resources for resource in resources: await self.query_handler.load_resource( resource=resource, class_name=collection_name @@ -70,8 +72,11 @@ async def load_test_data(self) -> None: async def call_search_endpoint(self, params: QueryParams) -> models.QueryResults: """Convenience function to call the /search endpoint""" response = await self.rest_client.get(url="/search", params=params) + result = response.json() + assert result is not None, result + assert "detail" in result or "hits" in result, result response.raise_for_status() - return models.QueryResults(**response.json()) + return models.QueryResults(**result) @pytest_asyncio.fixture @@ -97,6 +102,7 @@ async def joint_fixture( kafka=kafka, mongodb=mongodb, rest_client=rest_client, + resources={}, ) await joint_fixture.load_test_data() yield joint_fixture diff --git a/tests/fixtures/test_config.yaml b/tests/fixtures/test_config.yaml index 12f7553..ab8fc53 100644 --- a/tests/fixtures/test_config.yaml +++ b/tests/fixtures/test_config.yaml @@ -30,7 +30,7 @@ searchable_classes: - key: id_ name: ID - key: type - name: Hotel Type + name: Location Type - key: "has_object.type" name: Object Type EmptyCollection: diff --git a/tests/fixtures/test_data/SortingTests.json b/tests/fixtures/test_data/SortingTests.json index 4206851..538a4a1 100644 --- a/tests/fixtures/test_data/SortingTests.json +++ b/tests/fixtures/test_data/SortingTests.json @@ -1,27 +1,27 @@ { "items": [ { - "field": "some data", + "field": "alpha", "id_": "i2" }, { - "field": "some data", + "field": "bravo", "id_": "i1" }, { - "field": "some data", + "field": "charlie", "id_": "i3" }, { - "field": "some data", + "field": "delta", "id_": "i5" }, { - "field": "some data", + "field": "echo", "id_": "i6" }, { - "field": "some data", + "field": "foxtrot", "id_": "i4" } ] diff --git a/tests/test_index_creation.py b/tests/test_index_creation.py index bbfd4e6..c4bedf4 100644 --- a/tests/test_index_creation.py +++ b/tests/test_index_creation.py @@ -30,7 +30,7 @@ QUERY_STRING = "Backrub" -@pytest.mark.parametrize("create_index_manually", (False, True)) +@pytest.mark.parametrize("create_index_manually", [False, True], ids=["auto", "manual"]) @pytest.mark.asyncio async def test_index_creation(joint_fixture: JointFixture, create_index_manually: bool): """Test the index creation function.""" diff --git a/tests/test_sorting.py b/tests/test_sorting.py index d27e4e7..97c4970 100644 --- a/tests/test_sorting.py +++ b/tests/test_sorting.py @@ -21,53 +21,67 @@ from tests.fixtures.joint import JointFixture, QueryParams CLASS_NAME = "SortingTests" -BASIC_SORT_PARAMETERS = [ - models.SortingParameter(field="id_", order=models.SortOrder.ASCENDING) -] -def multi_column_sort( - resources: list[models.Resource], sorts: list[models.SortingParameter] +def sorted_resources( # noqa: C901 + resources: list[models.Resource], + order_by: list[str] | None = None, + sort: list[str] | None = None, + complete_resources: list[models.Resource] | None = None, ) -> list[models.Resource]: - """This is equivalent to nested sorted() calls. - - This uses the same approach as the sorting function in test_relevance, but the - difference is that this function uses Resource models and doesn't work with the - relevance sorting parameter. There's no spot for a top-level text score parameter in - the resource model, which is why the relevance tests use a slightly different version - of this function. - - The sorting parameters are supplied in order of most significant to least significant, - so we take them off the front and apply sorted(). If there are more parameters to - apply (more sorts), we recurse until we apply the final parameter. The sorted lists - are passed back up the call chain. + """Sort resources by all specified fields. + + This function simulates the sorting that is expected to be done by the database. + Since there's no spot for a top-level text score parameter in the resource model, + the relevance tests need to use a slightly different version of this function. + + In the case that some of the sorted fields are not part of the resources, the + complete resources which contain these missing fields must be passed as well. """ - sorted_list = resources.copy() - sorts = sorts.copy() - - parameter = sorts[0] - del sorts[0] - - # sort descending for DESCENDING and RELEVANCE - reverse = parameter.order != models.SortOrder.ASCENDING - - if len(sorts) > 0: - # if there are more sorting parameters, recurse to nest the sorts - sorted_list = multi_column_sort(sorted_list, sorts) - - if parameter.field == "id_": - return sorted( - sorted_list, - key=lambda result: result.model_dump()[parameter.field], - reverse=reverse, - ) - else: - # the only top-level fields is "_id" -- all else is in "content" - return sorted( - sorted_list, - key=lambda result: result.model_dump()["content"][parameter.field], - reverse=reverse, - ) + if order_by is None: + order_by = [] + if sort is None: + sort = [] + assert len(order_by) == len(sort) + if "id_" not in order_by: + # implicitly add id_ at the end since we also do it in the query handler + order_by.append("id_") + sort.append("ascending") + + def sort_key(resource: models.Resource) -> tuple: + """Create a tuple that can be used as key for sorting the resource.""" + if complete_resources: + for complete_resource in complete_resources: + if complete_resource.id_ == resource.id_: + resource = complete_resource + break + else: + assert False, f"{resource.id_} not found in complete resources" + key = [] + for field, field_sort in zip(order_by, sort, strict=True): + resource_dict = resource.model_dump() + if field != "id_": + # the only top-level fields is "_id" -- all else is in "content" + resource_dict = resource_dict["content"] + # support dotted access + sub_fields = field.split(".") + sub_fields, field = sub_fields[:-1], sub_fields[-1] + for sub_field in sub_fields: + resource_dict = resource_dict.get(sub_field, {}) + value = resource_dict.get(field) + # MongoDB returns nulls first, help Python to sort it properly + key_for_null = value is not None + if field_sort == "descending": + key_for_null = not key_for_null + if isinstance(value, str): + value = tuple(-ord(c) for c in value) + elif isinstance(value, int | float): + value = -value + key.append((key_for_null, value)) + return tuple(key) + + # sort the reversed resources to not rely on the already given order + return sorted(reversed(resources), key=sort_key) @pytest.mark.asyncio @@ -77,35 +91,37 @@ async def test_api_without_sort_parameters(joint_fixture: JointFixture): results = await joint_fixture.call_search_endpoint(params) assert results.count > 0 - expected = multi_column_sort(results.hits, BASIC_SORT_PARAMETERS) + expected = sorted_resources(results.hits) assert results.hits == expected +@pytest.mark.parametrize("reverse", [False, True], ids=["normal", "reversed"]) @pytest.mark.asyncio -async def test_sort_with_id_not_last(joint_fixture: JointFixture): - """Test sorting parameters that contain id_, but id_ is not final sorting field. +async def test_sort_with_id_not_last(joint_fixture: JointFixture, reverse: bool): + """Test sorting parameters that contain id_, but not as the final sorting field. Since we modify sorting parameters based on presence of id_, make sure there aren't any bugs that will break the sort or query process. """ + order_by = ["id_", "field"] + sort = ["ascending", "descending"] + if reverse: + sort.reverse() params: QueryParams = { "class_name": CLASS_NAME, "query": "", "filters": [], - "order_by": ["id_", "field"], - "sort": ["ascending", "descending"], + "order_by": order_by, + "sort": sort, } - sorts_in_model_form = [ - models.SortingParameter(field="id_", order=models.SortOrder.ASCENDING), - models.SortingParameter(field="field", order=models.SortOrder.DESCENDING), - ] results = await joint_fixture.call_search_endpoint(params) - assert results.hits == multi_column_sort(results.hits, sorts_in_model_form) + assert results.hits == sorted_resources(results.hits, order_by, sort) +@pytest.mark.parametrize("reverse", [False, True], ids=["normal", "reversed"]) @pytest.mark.asyncio -async def test_sort_with_params_but_not_id(joint_fixture: JointFixture): +async def test_sort_with_params_but_not_id(joint_fixture: JointFixture, reverse: bool): """Test supplying sorting parameters but omitting id_. In order to provide consistent sorting, id_ should always be included. If it's not @@ -113,14 +129,16 @@ async def test_sort_with_params_but_not_id(joint_fixture: JointFixture): any tie between otherwise equivalent keys. If it is included but is not the final field, then we should not modify the parameters. """ + order_by = ["field"] + sort = ["descending" if reverse else "ascending"] params: QueryParams = { "class_name": CLASS_NAME, - "order_by": ["field"], - "sort": ["ascending"], + "order_by": order_by, + "sort": sort, } results = await joint_fixture.call_search_endpoint(params) - assert results.hits == multi_column_sort(results.hits, BASIC_SORT_PARAMETERS) + assert results.hits == sorted_resources(results.hits, order_by, sort) @pytest.mark.asyncio @@ -138,17 +156,19 @@ async def test_sort_with_invalid_field(joint_fixture: JointFixture): } results = await joint_fixture.call_search_endpoint(params) - assert results.hits == multi_column_sort(results.hits, BASIC_SORT_PARAMETERS) + assert results.hits == sorted_resources(results.hits) @pytest.mark.parametrize("order", [-7, 17, "some_string"]) @pytest.mark.asyncio -async def test_sort_with_invalid_sort_order(joint_fixture: JointFixture, order): +async def test_sort_with_invalid_sort_order( + joint_fixture: JointFixture, order: str | int +): """Test supplying an invalid value for the sort order""" params: QueryParams = { "class_name": CLASS_NAME, "order_by": ["field"], - "sort": [order], + "sort": [order], # type: ignore } response = await joint_fixture.rest_client.get(url="/search", params=params) @@ -220,3 +240,61 @@ async def test_sort_with_superfluous_sort(joint_fixture: JointFixture): assert response.status_code == 422 details = response.json()["detail"] assert details == "Number of fields to order by must match number of sort options" + + +@pytest.mark.parametrize("reverse", [False, True], ids=["normal", "reversed"]) +@pytest.mark.parametrize("field", ["type", "has_object.type"]) +@pytest.mark.asyncio +async def test_sort_with_one_of_the_selected_fields( + joint_fixture: JointFixture, reverse: bool, field: str +): + """Test sorting when fields are selected and one of them is used for sorting.""" + class_name = "DatasetEmbedded" + selected = joint_fixture.config.searchable_classes[class_name].selected_fields + assert selected # this resource has selected fields + assert any(f.key == field for f in selected) # field is selected + + order_by = [field] + sort = ["descending" if reverse else "ascending"] + params: QueryParams = { + "class_name": class_name, + "order_by": order_by, + "sort": sort, + } + + results = await joint_fixture.call_search_endpoint(params) + assert results.hits == sorted_resources(results.hits, order_by, sort) + + +@pytest.mark.parametrize("reverse", [False, True], ids=["normal", "reversed"]) +@pytest.mark.parametrize("field", ["category", "field1"]) +@pytest.mark.asyncio +async def test_sort_with_one_of_the_unselected_fields( + joint_fixture: JointFixture, reverse: bool, field: str +): + """Test sorting when fields are selected but sorted by an unselected field.""" + class_name = "DatasetEmbedded" + selected = joint_fixture.config.searchable_classes[class_name].selected_fields + assert selected # this resource has selected fields + assert not any(f.key == field for f in selected) # field is unselected + + order_by = [field] + sort = ["descending" if reverse else "ascending"] + params: QueryParams = { + "class_name": class_name, + "order_by": order_by, + "sort": sort, + } + + results = await joint_fixture.call_search_endpoint(params) + + # make sure the field is not returned in the results + for resource in results.hits: + assert field not in resource.content + + # therefore, we cannot just sort the results, + # but we need to fetch the field from the complete original resources + complete_resources = joint_fixture.resources[class_name] + assert results.hits == sorted_resources( + results.hits, order_by, sort, complete_resources + )