From 66e5bc5d5dc619580052ed31ccda6424c6ad85a2 Mon Sep 17 00:00:00 2001 From: Casper Welzel Andersen Date: Wed, 9 Jun 2021 15:42:03 +0200 Subject: [PATCH] Use responses property for routes Only supply the pydantic model to `response_model` that represents the successful response. Add a dictionary of erroneous status codes with the ErrorResponse pydantic model to `responses` for all routes. Fix bugs caught by the tests introduced in the latest commit, modularizing out processing a database response. Use new specific exception classes from `optimade`. --- optimade_gateway/models/queries.py | 25 ++++++++---- optimade_gateway/queries/perform.py | 4 +- optimade_gateway/queries/process.py | 14 ++++--- optimade_gateway/queries/utils.py | 11 +++-- optimade_gateway/routers/databases.py | 14 ++++--- optimade_gateway/routers/gateway/info.py | 22 +++++----- optimade_gateway/routers/gateway/links.py | 10 ++--- optimade_gateway/routers/gateway/queries.py | 19 +++++---- .../routers/gateway/structures.py | 22 +++++++--- optimade_gateway/routers/gateway/utils.py | 6 +-- optimade_gateway/routers/gateway/versions.py | 2 + optimade_gateway/routers/gateways.py | 14 ++++--- optimade_gateway/routers/info.py | 16 ++++---- optimade_gateway/routers/links.py | 9 ++--- optimade_gateway/routers/queries.py | 12 ++++-- optimade_gateway/routers/search.py | 19 +++++---- optimade_gateway/routers/utils.py | 20 ++++------ tests/conftest.py | 12 ++++++ .../gateway/test_gateway_structures.py | 10 +++-- tests/routers/test_queries.py | 10 ++--- tests/routers/test_search.py | 40 ++++++++++--------- 21 files changed, 179 insertions(+), 132 deletions(-) diff --git a/optimade_gateway/models/queries.py b/optimade_gateway/models/queries.py index 28895c48..49fd14d2 100644 --- a/optimade_gateway/models/queries.py +++ b/optimade_gateway/models/queries.py @@ -6,15 +6,17 @@ from optimade.models import ( EntryResource, EntryResourceAttributes, - EntryResponseMany, - ErrorResponse, + OptimadeError, ReferenceResource, ReferenceResponseMany, ReferenceResponseOne, + Response, + ResponseMeta, StructureResource, StructureResponseMany, StructureResponseOne, ) +from optimade.models.utils import StrictField from optimade.server.query_params import EntryListingQueryParams from pydantic import BaseModel, EmailStr, Field, validator @@ -138,13 +140,23 @@ class QueryState(Enum): FINISHED = "finished" -class GatewayQueryResponse(EntryResponseMany): +class GatewayQueryResponse(Response): """Response from a Gateway Query.""" - data: Dict[str, Union[List[EntryResource], List[Dict[str, Any]]]] = Field( - ..., + data: Dict[str, Union[List[EntryResource], List[Dict[str, Any]]]] = StrictField( + ..., uniqueItems=True, description="Outputted Data" + ) + meta: ResponseMeta = StrictField( + ..., description="A meta object containing non-standard information" + ) + errors: Optional[List[OptimadeError]] = StrictField( + [], + description="A list of OPTIMADE-specific JSON API error objects, where the field detail MUST be present.", uniqueItems=True, ) + included: Optional[Union[List[EntryResource], List[Dict[str, Any]]]] = Field( + None, uniqueItems=True + ) class QueryResourceAttributes(EntryResourceAttributes): @@ -165,10 +177,9 @@ class QueryResourceAttributes(EntryResourceAttributes): title="State", type="enum", ) - response: Optional[Union[GatewayQueryResponse, ErrorResponse]] = Field( + response: Optional[GatewayQueryResponse] = Field( None, description="Response from gateway query.", - type="object", ) endpoint: EndpointEntryType = Field( EndpointEntryType.STRUCTURES, diff --git a/optimade_gateway/queries/perform.py b/optimade_gateway/queries/perform.py index ee1a18d2..98f7ca23 100644 --- a/optimade_gateway/queries/perform.py +++ b/optimade_gateway/queries/perform.py @@ -90,7 +90,7 @@ async def perform_query( **{"$set": {"state": QueryState.IN_PROGRESS}}, ) else: - response = EntryResponseMany( + response = query.attributes.endpoint.get_response_model()( data=[], links=ToplevelLinks(next=None), meta=meta_values( @@ -163,7 +163,7 @@ async def perform_query( resource["meta"] = database_id_meta else: resource.meta = Meta(**database_id_meta) - response.data.append(results) + response.data.extend(results) response.meta.data_returned += response_meta["data_returned"] if not response.meta.more_data_available: # Keep it True, if set to True once. diff --git a/optimade_gateway/queries/process.py b/optimade_gateway/queries/process.py index e91764c8..753dc658 100644 --- a/optimade_gateway/queries/process.py +++ b/optimade_gateway/queries/process.py @@ -57,6 +57,10 @@ async def process_db_response( errors = [] meta = {} + from optimade_gateway.common.logger import LOGGER + + LOGGER.debug("database_id: %s", database_id) + if isinstance(response, ErrorResponse): for error in response.errors: if isinstance(error.id, str) and error.id.startswith("OPTIMADE_GATEWAY"): @@ -113,14 +117,12 @@ async def process_db_response( # This ensures an empty list under `response.data.{database_id}` is returned if the case is # simply that there is no results to return. - # It also ensures that only `response.errors.{database_id}` is created if there are any - # errors. + if errors: + extra_updates.update({"$addToSet": {"response.errors": {"$each": errors}}}) await update_query( query, - f"response.errors.{database_id}" - if errors - else f"response.data.{database_id}", - errors or results, + f"response.data.{database_id}", + results, **extra_updates, ) else: diff --git a/optimade_gateway/queries/utils.py b/optimade_gateway/queries/utils.py index 21656b31..f0f22a15 100644 --- a/optimade_gateway/queries/utils.py +++ b/optimade_gateway/queries/utils.py @@ -47,17 +47,20 @@ async def update_query( update_kwargs = {"$set": {"last_modified": update_time}} if mongo_kwargs: - update_kwargs.update(await clean_python_types(mongo_kwargs)) + update_kwargs.update(mongo_kwargs) if operator and operator == "$set": - update_kwargs["$set"].update({field: await clean_python_types(value)}) + update_kwargs["$set"].update({field: value}) elif operator: - update_kwargs.update({operator: {field: await clean_python_types(value)}}) + if operator in update_kwargs: + update_kwargs[operator].update({field: value}) + else: + update_kwargs.update({operator: {field: value}}) # MongoDB result: UpdateResult = await QUERIES_COLLECTION.collection.update_one( filter={"id": {"$eq": query.id}}, - update=update_kwargs, + update=await clean_python_types(update_kwargs), ) if result.matched_count != 1: LOGGER.error( diff --git a/optimade_gateway/routers/databases.py b/optimade_gateway/routers/databases.py index 4d279cb9..a109ecc8 100644 --- a/optimade_gateway/routers/databases.py +++ b/optimade_gateway/routers/databases.py @@ -11,12 +11,11 @@ One can register a new database (by using `POST /databases`) or look through the available databases (by using `GET /databases`) using standard OPTIMADE filtering. """ -from typing import Union - from fastapi import APIRouter, Depends, Request -from optimade.models import ErrorResponse, LinksResource, ToplevelLinks +from optimade.models import LinksResource, ToplevelLinks from optimade.server.query_params import EntryListingQueryParams, SingleEntryQueryParams from optimade.server.routers.utils import handle_response_fields, meta_values +from optimade.server.schemas import ERROR_RESPONSES from optimade_gateway.common.config import CONFIG from optimade_gateway.mappers import DatabasesMapper @@ -39,11 +38,12 @@ @ROUTER.get( "/databases", - response_model=Union[DatabasesResponse, ErrorResponse], + response_model=DatabasesResponse, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Databases"], + responses=ERROR_RESPONSES, ) async def get_databases( request: Request, @@ -65,11 +65,12 @@ async def get_databases( @ROUTER.post( "/databases", - response_model=Union[DatabasesResponseSingle, ErrorResponse], + response_model=DatabasesResponseSingle, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Databases"], + responses=ERROR_RESPONSES, ) async def post_databases( request: Request, database: DatabaseCreate @@ -99,11 +100,12 @@ async def post_databases( @ROUTER.get( "/databases/{database_id:path}", - response_model=Union[DatabasesResponseSingle, ErrorResponse], + response_model=DatabasesResponseSingle, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Databases"], + responses=ERROR_RESPONSES, ) async def get_database( request: Request, diff --git a/optimade_gateway/routers/gateway/info.py b/optimade_gateway/routers/gateway/info.py index 2b863c82..809d2ba4 100644 --- a/optimade_gateway/routers/gateway/info.py +++ b/optimade_gateway/routers/gateway/info.py @@ -6,19 +6,17 @@ where `version` and `entry` may be left out. """ -from typing import Union - from fastapi import APIRouter, Request from optimade import __api_version__ from optimade.models import ( BaseInfoAttributes, BaseInfoResource, EntryInfoResponse, - ErrorResponse, InfoResponse, StructureResource, ) from optimade.server.routers.utils import get_base_url, meta_values +from optimade.server.schemas import ERROR_RESPONSES ROUTER = APIRouter(redirect_slashes=True) @@ -27,11 +25,12 @@ @ROUTER.get( "/gateways/{gateway_id}/info", - response_model=Union[InfoResponse, ErrorResponse], + response_model=InfoResponse, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Info"], + responses=ERROR_RESPONSES, ) async def get_gateways_info( request: Request, @@ -88,11 +87,12 @@ async def get_gateways_info( @ROUTER.get( "/gateways/{gateway_id}/info/{entry}", - response_model=Union[EntryInfoResponse, ErrorResponse], + response_model=EntryInfoResponse, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Info"], + responses=ERROR_RESPONSES, ) async def get_gateways_entry_info( request: Request, gateway_id: str, entry: str @@ -102,7 +102,7 @@ async def get_gateways_entry_info( Get information about the gateway `{gateway_id}`'s entry-listing endpoints. """ from optimade.models import EntryInfoResource - from optimade.server.exceptions import BadRequest + from optimade.server.exceptions import NotFound from optimade_gateway.routers.gateways import GATEWAYS_COLLECTION from optimade_gateway.routers.utils import ( @@ -114,9 +114,7 @@ async def get_gateways_entry_info( valid_entry_info_endpoints = ENTRY_INFO_SCHEMAS.keys() if entry not in valid_entry_info_endpoints: - raise BadRequest( - title="Not Found", - status_code=404, + raise NotFound( detail=( f"Entry info not found for {entry}, valid entry info endpoints are: " f"{', '.join(valid_entry_info_endpoints)}" @@ -147,11 +145,12 @@ async def get_gateways_entry_info( @ROUTER.get( "/gateways/{gateway_id}/{version}/info", - response_model=Union[InfoResponse, ErrorResponse], + response_model=InfoResponse, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Info"], + responses=ERROR_RESPONSES, ) async def get_versioned_gateways_info( request: Request, @@ -170,11 +169,12 @@ async def get_versioned_gateways_info( @ROUTER.get( "/gateways/{gateway_id}/{version}/info/{entry}", - response_model=Union[EntryInfoResponse, ErrorResponse], + response_model=EntryInfoResponse, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Info"], + responses=ERROR_RESPONSES, ) async def get_versioned_gateways_entry_info( request: Request, diff --git a/optimade_gateway/routers/gateway/links.py b/optimade_gateway/routers/gateway/links.py index e75344fe..1abd15d6 100644 --- a/optimade_gateway/routers/gateway/links.py +++ b/optimade_gateway/routers/gateway/links.py @@ -6,25 +6,24 @@ where `version` may be left out. """ -from typing import Union - from fastapi import APIRouter, Depends, Request from optimade.models import ( - ErrorResponse, LinksResponse, ) from optimade.server.query_params import EntryListingQueryParams +from optimade.server.schemas import ERROR_RESPONSES ROUTER = APIRouter(redirect_slashes=True) @ROUTER.get( "/gateways/{gateway_id}/links", - response_model=Union[LinksResponse, ErrorResponse], + response_model=LinksResponse, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Links"], + responses=ERROR_RESPONSES, ) async def get_gateways_links( request: Request, @@ -45,11 +44,12 @@ async def get_gateways_links( @ROUTER.get( "/gateways/{gateway_id}/{version}/links", - response_model=Union[LinksResponse, ErrorResponse], + response_model=LinksResponse, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Links"], + responses=ERROR_RESPONSES, ) async def get_versioned_gateways_links( request: Request, diff --git a/optimade_gateway/routers/gateway/queries.py b/optimade_gateway/routers/gateway/queries.py index 48a94e74..a014e7c6 100644 --- a/optimade_gateway/routers/gateway/queries.py +++ b/optimade_gateway/routers/gateway/queries.py @@ -6,13 +6,11 @@ where `version` and the last `id` may be left out. """ -from typing import Union - from fastapi import APIRouter, Depends, Request, status -from optimade.models import ErrorResponse from optimade.models.responses import EntryResponseMany -from optimade.server.exceptions import BadRequest +from optimade.server.exceptions import Forbidden from optimade.server.query_params import EntryListingQueryParams +from optimade.server.schemas import ERROR_RESPONSES from optimade_gateway.models import ( QueryCreate, @@ -27,11 +25,12 @@ @ROUTER.get( "/gateways/{gateway_id}/queries", - response_model=Union[QueriesResponse, ErrorResponse], + response_model=QueriesResponse, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Gateways", "Queries"], + responses=ERROR_RESPONSES, ) async def get_gateway_queries( request: Request, @@ -57,12 +56,13 @@ async def get_gateway_queries( @ROUTER.post( "/gateways/{gateway_id}/queries", - response_model=Union[QueriesResponseSingle, ErrorResponse], + response_model=QueriesResponseSingle, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Gateways", "Queries"], status_code=status.HTTP_202_ACCEPTED, + responses=ERROR_RESPONSES, ) async def post_gateway_queries( request: Request, @@ -79,9 +79,7 @@ async def post_gateway_queries( await validate_resource(GATEWAYS_COLLECTION, gateway_id) if query.gateway_id and query.gateway_id != gateway_id: - raise BadRequest( - status_code=403, - title="Forbidden", + raise Forbidden( detail=( f"The gateway ID in the posted data () does not align " f"with the gateway ID specified in the URL (/{gateway_id}/)." @@ -93,11 +91,12 @@ async def post_gateway_queries( @ROUTER.get( "/gateways/{gateway_id}/queries/{query_id}", - response_model=Union[EntryResponseMany, ErrorResponse], + response_model=EntryResponseMany, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Gateways", "Queries"], + responses=ERROR_RESPONSES, ) async def get_gateway_query( request: Request, gateway_id: str, query_id: str diff --git a/optimade_gateway/routers/gateway/structures.py b/optimade_gateway/routers/gateway/structures.py index c36f18e1..99972e14 100644 --- a/optimade_gateway/routers/gateway/structures.py +++ b/optimade_gateway/routers/gateway/structures.py @@ -23,6 +23,7 @@ ) from optimade.server.query_params import EntryListingQueryParams, SingleEntryQueryParams from optimade.server.routers.utils import meta_values +from optimade.server.schemas import ERROR_RESPONSES from optimade_gateway.models import QueryResource from optimade_gateway.queries import perform_query @@ -36,11 +37,12 @@ @ROUTER.get( "/gateways/{gateway_id}/structures", - response_model=Union[StructureResponseMany, ErrorResponse], + response_model=StructureResponseMany, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Structures"], + responses=ERROR_RESPONSES, ) async def get_structures( request: Request, @@ -88,17 +90,23 @@ async def get_structures( break else: response.status_code = 500 - - return gateway_response + return gateway_response + elif isinstance(gateway_response, StructureResponseMany): + return gateway_response + else: + raise TypeError( + "The response should be either StructureResponseMany or ErrorResponse." + ) @ROUTER.get( "/gateways/{gateway_id}/structures/{structure_id:path}", - response_model=Union[StructureResponseOne, ErrorResponse], + response_model=StructureResponseOne, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Structures"], + responses=ERROR_RESPONSES, ) async def get_single_structure( request: Request, @@ -241,12 +249,13 @@ async def get_single_structure( @ROUTER.get( "/gateways/{gateway_id}/{version}/structures", - response_model=Union[StructureResponseMany, ErrorResponse], + response_model=StructureResponseMany, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Structures"], include_in_schema=False, + responses=ERROR_RESPONSES, ) async def get_versioned_structures( request: Request, @@ -265,12 +274,13 @@ async def get_versioned_structures( @ROUTER.get( "/gateways/{gateway_id}/{version}/structures/{structure_id:path}", - response_model=Union[StructureResponseOne, ErrorResponse], + response_model=StructureResponseOne, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Structures"], include_in_schema=False, + responses=ERROR_RESPONSES, ) async def get_versioned_single_structure( request: Request, diff --git a/optimade_gateway/routers/gateway/utils.py b/optimade_gateway/routers/gateway/utils.py index bfb6e2a1..7d2062fd 100644 --- a/optimade_gateway/routers/gateway/utils.py +++ b/optimade_gateway/routers/gateway/utils.py @@ -1,4 +1,4 @@ -from optimade.server.exceptions import BadRequest, VersionNotSupported +from optimade.server.exceptions import NotFound, VersionNotSupported from optimade.server.routers.utils import BASE_URL_PREFIXES @@ -17,8 +17,6 @@ async def validate_version(version: str) -> None: detail=f"version {version} is not supported. Supported versions: {valid_versions}" ) else: - raise BadRequest( - title="Not Found", - status_code=404, + raise NotFound( detail=f"version MUST be one of {valid_versions}", ) diff --git a/optimade_gateway/routers/gateway/versions.py b/optimade_gateway/routers/gateway/versions.py index 94ca58ea..cbf162e4 100644 --- a/optimade_gateway/routers/gateway/versions.py +++ b/optimade_gateway/routers/gateway/versions.py @@ -7,6 +7,7 @@ """ from fastapi import APIRouter, Request from optimade.server.routers.versions import CsvResponse +from optimade.server.schemas import ERROR_RESPONSES ROUTER = APIRouter(redirect_slashes=True) @@ -15,6 +16,7 @@ "/gateways/{gateway_id}/versions", response_class=CsvResponse, tags=["Versions"], + responses=ERROR_RESPONSES, ) async def get_gateway_versions(request: Request, gateway_id: str) -> CsvResponse: """`GET /gateways/{gateway_id}/versions` diff --git a/optimade_gateway/routers/gateways.py b/optimade_gateway/routers/gateways.py index 533e777d..122e55ee 100644 --- a/optimade_gateway/routers/gateways.py +++ b/optimade_gateway/routers/gateways.py @@ -6,13 +6,12 @@ where, `id` may be left out. """ -from typing import Union - from fastapi import APIRouter, Depends, Request from fastapi.responses import RedirectResponse -from optimade.models import ErrorResponse, ToplevelLinks +from optimade.models import ToplevelLinks from optimade.server.query_params import EntryListingQueryParams from optimade.server.routers.utils import meta_values +from optimade.server.schemas import ERROR_RESPONSES from optimade_gateway.common.config import CONFIG from optimade_gateway.mappers import GatewaysMapper @@ -36,11 +35,12 @@ @ROUTER.get( "/gateways", - response_model=Union[GatewaysResponse, ErrorResponse], + response_model=GatewaysResponse, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Gateways"], + responses=ERROR_RESPONSES, ) async def get_gateways( request: Request, @@ -62,11 +62,12 @@ async def get_gateways( @ROUTER.post( "/gateways", - response_model=Union[GatewaysResponseSingle, ErrorResponse], + response_model=GatewaysResponseSingle, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Gateways"], + responses=ERROR_RESPONSES, ) async def post_gateways( request: Request, gateway: GatewayCreate @@ -110,11 +111,12 @@ async def post_gateways( @ROUTER.get( "/gateways/{gateway_id}", - response_model=Union[GatewaysResponseSingle, ErrorResponse], + response_model=GatewaysResponseSingle, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Gateways"], + responses=ERROR_RESPONSES, ) async def get_gateway(request: Request, gateway_id: str) -> GatewaysResponseSingle: """`GET /gateways/{gateway ID}` diff --git a/optimade_gateway/routers/info.py b/optimade_gateway/routers/info.py index 4d3c6ad9..78d1df1a 100644 --- a/optimade_gateway/routers/info.py +++ b/optimade_gateway/routers/info.py @@ -6,19 +6,17 @@ where, `entry` may be left out. """ -from typing import Union - from fastapi import APIRouter, Request from optimade import __api_version__ from optimade.models import ( BaseInfoAttributes, BaseInfoResource, EntryInfoResponse, - ErrorResponse, InfoResponse, LinksResource, ) from optimade.server.routers.utils import get_base_url, meta_values +from optimade.server.schemas import ERROR_RESPONSES from optimade_gateway.models import GatewayResource, QueryResource @@ -33,11 +31,12 @@ @ROUTER.get( "/info", - response_model=Union[InfoResponse, ErrorResponse], + response_model=InfoResponse, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Info"], + responses=ERROR_RESPONSES, ) async def get_info(request: Request) -> InfoResponse: """`GET /info` @@ -83,11 +82,12 @@ async def get_info(request: Request) -> InfoResponse: @ROUTER.get( "/info/{entry}", - response_model=Union[EntryInfoResponse, ErrorResponse], + response_model=EntryInfoResponse, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Info"], + responses=ERROR_RESPONSES, ) async def get_entry_info(request: Request, entry: str) -> EntryInfoResponse: """`GET /info/{entry}` @@ -95,15 +95,13 @@ async def get_entry_info(request: Request, entry: str) -> EntryInfoResponse: Get information about the gateway service's entry-listing endpoints. """ from optimade.models import EntryInfoResource - from optimade.server.exceptions import BadRequest + from optimade.server.exceptions import NotFound from optimade_gateway.routers.utils import aretrieve_queryable_properties valid_entry_info_endpoints = ENTRY_INFO_SCHEMAS.keys() if entry not in valid_entry_info_endpoints: - raise BadRequest( - title="Not Found", - status_code=404, + raise NotFound( detail=( f"Entry info not found for {entry}, valid entry info endpoints are: " f"{', '.join(valid_entry_info_endpoints)}" diff --git a/optimade_gateway/routers/links.py b/optimade_gateway/routers/links.py index 3be28677..d66996ed 100644 --- a/optimade_gateway/routers/links.py +++ b/optimade_gateway/routers/links.py @@ -5,17 +5,15 @@ /links """ -from typing import Union - from fastapi import APIRouter, Depends, Request -from optimade.models import ErrorResponse, LinksResponse, LinksResource +from optimade.models import LinksResponse, LinksResource from optimade.server.mappers import LinksMapper from optimade.server.query_params import EntryListingQueryParams +from optimade.server.schemas import ERROR_RESPONSES from optimade_gateway.common.config import CONFIG from optimade_gateway.mongo.collection import AsyncMongoCollection - from optimade_gateway.routers.utils import get_entries ROUTER = APIRouter(redirect_slashes=True) @@ -29,11 +27,12 @@ @ROUTER.get( "/links", - response_model=Union[LinksResponse, ErrorResponse], + response_model=LinksResponse, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Links"], + responses=ERROR_RESPONSES, ) async def get_links( request: Request, params: EntryListingQueryParams = Depends() diff --git a/optimade_gateway/routers/queries.py b/optimade_gateway/routers/queries.py index da70e7f7..87d910a6 100644 --- a/optimade_gateway/routers/queries.py +++ b/optimade_gateway/routers/queries.py @@ -20,6 +20,7 @@ from optimade.models.responses import EntryResponseMany from optimade.server.query_params import EntryListingQueryParams from optimade.server.routers.utils import meta_values +from optimade.server.schemas import ERROR_RESPONSES from optimade_gateway.common.config import CONFIG from optimade_gateway.mappers import QueryMapper @@ -45,11 +46,12 @@ @ROUTER.get( "/queries", - response_model=Union[QueriesResponse, ErrorResponse], + response_model=QueriesResponse, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Queries"], + responses=ERROR_RESPONSES, ) async def get_queries( request: Request, @@ -71,12 +73,13 @@ async def get_queries( @ROUTER.post( "/queries", - response_model=Union[QueriesResponseSingle, ErrorResponse], + response_model=QueriesResponseSingle, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Queries"], status_code=status.HTTP_202_ACCEPTED, + responses=ERROR_RESPONSES, ) async def post_queries( request: Request, @@ -114,17 +117,18 @@ async def post_queries( @ROUTER.get( "/queries/{query_id:path}", - response_model=Union[EntryResponseMany, ErrorResponse, GatewayQueryResponse], + response_model=GatewayQueryResponse, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Queries"], + responses=ERROR_RESPONSES, ) async def get_query( request: Request, query_id: str, response: Response, -) -> Union[EntryResponseMany, ErrorResponse, GatewayQueryResponse]: +) -> Union[ErrorResponse, GatewayQueryResponse]: """`GET /queries/{query_id}` Return the response from a query diff --git a/optimade_gateway/routers/search.py b/optimade_gateway/routers/search.py index 362e1f1e..a5d196ed 100644 --- a/optimade_gateway/routers/search.py +++ b/optimade_gateway/routers/search.py @@ -19,8 +19,6 @@ from fastapi.responses import RedirectResponse from optimade.server.exceptions import BadRequest from optimade.models import ( - EntryResponseMany, - ErrorResponse, LinksResource, LinksResourceAttributes, ToplevelLinks, @@ -28,6 +26,7 @@ from optimade.models.links import LinkType from optimade.server.query_params import EntryListingQueryParams from optimade.server.routers.utils import meta_values +from optimade.server.schemas import ERROR_RESPONSES from pydantic import AnyUrl, ValidationError from optimade_gateway.common.config import CONFIG @@ -40,7 +39,11 @@ QueryResource, Search, ) -from optimade_gateway.models.queries import OptimadeQueryParameters, QueryState +from optimade_gateway.models.queries import ( + GatewayQueryResponse, + OptimadeQueryParameters, + QueryState, +) from optimade_gateway.queries import perform_query, SearchQueryParams @@ -49,12 +52,13 @@ @ROUTER.post( "/search", - response_model=Union[QueriesResponseSingle, ErrorResponse], + response_model=QueriesResponseSingle, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Search"], status_code=status.HTTP_202_ACCEPTED, + responses=ERROR_RESPONSES, ) async def post_search(request: Request, search: Search) -> QueriesResponseSingle: """`POST /search` @@ -118,7 +122,7 @@ async def post_search(request: Request, search: Search) -> QueriesResponseSingle f"{url.user + '@' if url.user else ''}{url.host}" f"{':' + url.port if url.port else ''}" f"{url.path.rstrip('/') if url.path else ''}" - ), + ).replace(".", "__"), type="links", attributes=LinksResourceAttributes( name=( @@ -169,18 +173,19 @@ async def post_search(request: Request, search: Search) -> QueriesResponseSingle @ROUTER.get( "/search", - response_model=Union[EntryResponseMany, ErrorResponse], + response_model=GatewayQueryResponse, response_model_exclude_defaults=False, response_model_exclude_none=False, response_model_exclude_unset=True, tags=["Search"], + responses=ERROR_RESPONSES, ) async def get_search( request: Request, response: Response, search_params: SearchQueryParams = Depends(), entry_params: EntryListingQueryParams = Depends(), -) -> Union[EntryResponseMany, ErrorResponse, RedirectResponse]: +) -> Union[GatewayQueryResponse, RedirectResponse]: """`GET /search` Coordinate a new OPTIMADE query in multiple databases through a gateway: diff --git a/optimade_gateway/routers/utils.py b/optimade_gateway/routers/utils.py index c9de1139..9d681e63 100644 --- a/optimade_gateway/routers/utils.py +++ b/optimade_gateway/routers/utils.py @@ -10,7 +10,7 @@ ToplevelLinks, ) from optimade.models.links import LinkType -from optimade.server.exceptions import BadRequest +from optimade.server.exceptions import NotFound from optimade.server.query_params import EntryListingQueryParams from optimade.server.routers.utils import ( get_base_url, @@ -96,9 +96,7 @@ async def aretrieve_queryable_properties( async def validate_resource(collection: AsyncMongoCollection, entry_id: str) -> None: """Validate whether a resource exists in a collection""" if not await collection.exists(entry_id): - raise BadRequest( - title="Not Found", - status_code=404, + raise NotFound( detail=f"Resource not found in {collection}.", ) @@ -182,8 +180,8 @@ async def resource_factory( mongo_query = { "$or": [ - {"base_url": {"$eq": await clean_python_types(base_url)}}, - {"base_url.href": {"$eq": await clean_python_types(base_url)}}, + {"base_url": {"$eq": base_url}}, + {"base_url.href": {"$eq": base_url}}, ] } elif isinstance(create_resource, GatewayCreate): @@ -194,9 +192,7 @@ async def resource_factory( mongo_query = { "databases": {"$size": len(create_resource.databases)}, "databases.attributes.base_url": { - "$all": await clean_python_types( - [_.attributes.base_url for _ in create_resource.databases] - ) + "$all": [_.attributes.base_url for _ in create_resource.databases] }, } elif isinstance(create_resource, QueryCreate): @@ -212,9 +208,7 @@ async def resource_factory( mongo_query = { "gateway_id": {"$eq": create_resource.gateway_id}, - "query_parameters": { - "$eq": await clean_python_types(create_resource.query_parameters), - }, + "query_parameters": {"$eq": create_resource.query_parameters}, "endpoint": {"$eq": create_resource.endpoint}, } else: @@ -224,7 +218,7 @@ async def resource_factory( ) result, more_data_available, _ = await RESOURCE_COLLECTION.find( - criteria={"filter": mongo_query} + criteria={"filter": await clean_python_types(mongo_query)} ) if more_data_available: diff --git a/tests/conftest.py b/tests/conftest.py index 1a935de7..4e94982d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -173,9 +173,21 @@ def _mock_response(gateway: dict) -> None: top_dir / f"tests/static/db_responses/{database['id']}.json" ) as handle: data = json.load(handle) + + if data.get("errors", []): + for error in data.get("errors", []): + if "status" in error: + status_code = int(error["status"]) + break + else: + status_code = 500 + else: + status_code = 200 + httpx_mock.add_response( url=re.compile(fr"{database['attributes']['base_url']}.*"), json=data, + status_code=status_code, ) def sleep_response(request: httpx.Request, extensions: dict) -> MockResponse: diff --git a/tests/routers/gateway/test_gateway_structures.py b/tests/routers/gateway/test_gateway_structures.py index 84755053..d74a69f3 100644 --- a/tests/routers/gateway/test_gateway_structures.py +++ b/tests/routers/gateway/test_gateway_structures.py @@ -63,15 +63,19 @@ async def test_get_structures( more_data_available = db_response["meta"]["more_data_available"] for datum in db_response["data"]: - datum["id"] = f"{database['id']}/{datum['id']}" + database_id_meta = { + "_optimade_gateway_": {"source_database_id": database["id"]} + } + if "meta" in datum: + datum["meta"].update(database_id_meta) + else: + datum["meta"] = database_id_meta data.append(datum) assert data_returned == response.meta.data_returned assert data_available == response.meta.data_available assert more_data_available == response.meta.more_data_available - data.sort(key=lambda datum: datum["id"]) - data.sort(key=lambda datum: "/".join(datum["id"].split("/")[1:])) assert data == json.loads(response.json(exclude_unset=True))["data"], ( f"IDs in test not in response: {set([_['id'] for _ in data]) - set([_['id'] for _ in json.loads(response.json(exclude_unset=True))['data']])}\n\n" f"IDs in response not in test: {set([_['id'] for _ in json.loads(response.json(exclude_unset=True))['data']]) - set([_['id'] for _ in data])}\n\n" diff --git a/tests/routers/test_queries.py b/tests/routers/test_queries.py index d3b15366..db7e1628 100644 --- a/tests/routers/test_queries.py +++ b/tests/routers/test_queries.py @@ -155,7 +155,6 @@ async def test_query_results( ): """Test POST /queries and GET /queries/{id}""" import asyncio - from optimade.models import EntryResponseMany from optimade_gateway.common.config import CONFIG from optimade_gateway.models.queries import ( @@ -182,8 +181,8 @@ async def test_query_results( response = await client(f"/queries/{data['id']}") assert response.status_code == 200, f"Request failed: {response.json()}" - response = EntryResponseMany(**response.json()) - assert response.data == [] + response = GatewayQueryResponse(**response.json()) + assert response.data == {} query: QueryResource = QueryResource( **getattr(response.meta, f"_{CONFIG.provider.prefix}_query") @@ -215,8 +214,8 @@ async def test_errored_query_results( ): """Test POST /queries and GET /queries/{id} with an erroneous response""" import asyncio - from optimade.models import ErrorResponse + from optimade_gateway.models.queries import GatewayQueryResponse from optimade_gateway.models.responses import QueriesResponseSingle data = { @@ -239,7 +238,8 @@ async def test_errored_query_results( response.status_code == 404 ), f"Request succeeded, where it should have failed:\n{json.dumps(response.json(), indent=2)}" - response = ErrorResponse(**response.json()) + response = GatewayQueryResponse(**response.json()) + assert response.errors @pytest.mark.usefixtures("reset_db_after") diff --git a/tests/routers/test_search.py b/tests/routers/test_search.py index f26ad30b..1da2733d 100644 --- a/tests/routers/test_search.py +++ b/tests/routers/test_search.py @@ -30,9 +30,8 @@ async def test_get_search( this should ensure a new gateway is created, specifically for use with these versioned base URLs, but we can reuse the mock_gateway_responses for the "twodbs" gateway. """ - from optimade.models import StructureResponseMany - from optimade_gateway.common.config import CONFIG + from optimade_gateway.models import GatewayQueryResponse gateway_id = "twodbs" gateway: dict = await get_gateway(gateway_id) @@ -52,7 +51,7 @@ async def test_get_search( assert response.status_code == 200, f"Request failed: {response.json()}" - response = StructureResponseMany(**response.json()) + response = GatewayQueryResponse(**response.json()) assert response.data assert ( getattr(response.meta, f"_{CONFIG.provider.prefix}_query", "NOT FOUND") @@ -72,9 +71,8 @@ async def test_get_search_existing_gateway( caplog: pytest.LogCaptureFixture, ): """Test GET /search for base URLs matching an existing gateway""" - from optimade.models import StructureResponseMany - from optimade_gateway.common.config import CONFIG + from optimade_gateway.models import GatewayQueryResponse gateway_id = "twodbs" gateway: dict = await get_gateway(gateway_id) @@ -115,7 +113,7 @@ async def test_get_search_existing_gateway( assert response.status_code == 200, f"Request failed: {response.json()}" - response = StructureResponseMany(**response.json()) + response = GatewayQueryResponse(**response.json()) assert response.data, f"No data: {response.json(indent=2)}" assert ( getattr(response.meta, f"_{CONFIG.provider.prefix}_query", "NOT FOUND") @@ -135,10 +133,12 @@ async def test_get_search_not_finishing( caplog: pytest.LogCaptureFixture, ): """Test GET /search for unfinished query (redirect to query URL)""" - from optimade.models import EntryResponseMany - from optimade_gateway.common.config import CONFIG - from optimade_gateway.models.queries import QueryResource, QueryState + from optimade_gateway.models.queries import ( + GatewayQueryResponse, + QueryResource, + QueryState, + ) gateway_id = "slow-query" gateway: dict = await get_gateway(gateway_id) @@ -160,12 +160,13 @@ async def test_get_search_not_finishing( assert "A gateway was found and reused for a query" in caplog.text, caplog.text - response = EntryResponseMany(**response.json()) - assert response.data == [], f"Data was found in response: {response.json(indent=2)}" + response = GatewayQueryResponse(**response.json()) + assert response.data == {}, f"Data was found in response: {response.json(indent=2)}" assert getattr( response.meta, f"_{CONFIG.provider.prefix}_query", False ), f"Special __query field not found in meta. Response: {response.json(indent=2)}" + query: QueryResource = QueryResource( **getattr(response.meta, f"_{CONFIG.provider.prefix}_query") ) @@ -175,9 +176,9 @@ async def test_get_search_not_finishing( assert ( query.attributes.query_parameters.page_limit == query_params["page_limit"] ), query - assert ( - query.attributes.response == query.attributes.__fields__["response"].default - ), query + assert isinstance(query.attributes.response, GatewayQueryResponse) + assert query.attributes.response.data == {} + assert query.attributes.response.errors == [] assert query.attributes.gateway_id == gateway_id, query @@ -239,7 +240,7 @@ async def test_post_search( == OptimadeQueryParameters(**data["query_parameters"]).dict() ), f"Response: {datum.attributes.query_parameters!r}\n\nTest data: {OptimadeQueryParameters(**data['query_parameters'])!r}" - assert datum.attributes.state == QueryState.CREATED + assert datum.attributes.state in [QueryState.CREATED, QueryState.STARTED] assert datum.attributes.response is None with open(top_dir / "tests/static/test_gateways.json") as handle: @@ -321,7 +322,7 @@ async def test_post_search_existing_gateway( == OptimadeQueryParameters(**gateway_create_data["query_parameters"]).dict() ), f"Response: {datum.attributes.query_parameters!r}\n\nTest data: {OptimadeQueryParameters(**gateway_create_data['query_parameters'])!r}" - assert datum.attributes.state == QueryState.CREATED + assert datum.attributes.state in [QueryState.CREATED, QueryState.STARTED] assert datum.attributes.response is None assert datum.attributes.gateway_id == gateway_id @@ -345,8 +346,9 @@ async def test_sort_no_effect( This means if the `sort` parameter is used, the response should not change - it should be ignored. """ - from optimade.models import StructureResponseMany, Warnings + from optimade.models import Warnings + from optimade_gateway.models import GatewayQueryResponse from optimade_gateway.models.responses import QueriesResponseSingle from optimade_gateway.warnings import SortNotSupported @@ -373,9 +375,9 @@ async def test_sort_no_effect( assert response_asc.status_code == 200, f"Request failed: {response_asc.json()}" assert response_desc.status_code == 200, f"Request failed: {response_desc.json()}" - response_asc = StructureResponseMany(**response_asc.json()) + response_asc = GatewayQueryResponse(**response_asc.json()) assert response_asc - response_desc = StructureResponseMany(**response_desc.json()) + response_desc = GatewayQueryResponse(**response_desc.json()) assert response_desc assert response_asc.data == response_desc.data