Skip to content

Commit

Permalink
Use responses property for routes
Browse files Browse the repository at this point in the history
Only supply the pydantic model to `response_model` that represents the
successful response.
Add a dictionary of erroneous status codes with the ErrorResponse
pydantic model to `responses` for all routes.

Fix bugs caught by the tests introduced in the latest commit,
modularizing out processing a database response.

Use new specific exception classes from `optimade`.
  • Loading branch information
CasperWA committed Jun 17, 2021
1 parent 63cc168 commit 66e5bc5
Show file tree
Hide file tree
Showing 21 changed files with 179 additions and 132 deletions.
25 changes: 18 additions & 7 deletions optimade_gateway/models/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,17 @@
from optimade.models import (
EntryResource,
EntryResourceAttributes,
EntryResponseMany,
ErrorResponse,
OptimadeError,
ReferenceResource,
ReferenceResponseMany,
ReferenceResponseOne,
Response,
ResponseMeta,
StructureResource,
StructureResponseMany,
StructureResponseOne,
)
from optimade.models.utils import StrictField
from optimade.server.query_params import EntryListingQueryParams
from pydantic import BaseModel, EmailStr, Field, validator

Expand Down Expand Up @@ -138,13 +140,23 @@ class QueryState(Enum):
FINISHED = "finished"


class GatewayQueryResponse(EntryResponseMany):
class GatewayQueryResponse(Response):
"""Response from a Gateway Query."""

data: Dict[str, Union[List[EntryResource], List[Dict[str, Any]]]] = Field(
...,
data: Dict[str, Union[List[EntryResource], List[Dict[str, Any]]]] = StrictField(
..., uniqueItems=True, description="Outputted Data"
)
meta: ResponseMeta = StrictField(
..., description="A meta object containing non-standard information"
)
errors: Optional[List[OptimadeError]] = StrictField(
[],
description="A list of OPTIMADE-specific JSON API error objects, where the field detail MUST be present.",
uniqueItems=True,
)
included: Optional[Union[List[EntryResource], List[Dict[str, Any]]]] = Field(
None, uniqueItems=True
)


class QueryResourceAttributes(EntryResourceAttributes):
Expand All @@ -165,10 +177,9 @@ class QueryResourceAttributes(EntryResourceAttributes):
title="State",
type="enum",
)
response: Optional[Union[GatewayQueryResponse, ErrorResponse]] = Field(
response: Optional[GatewayQueryResponse] = Field(
None,
description="Response from gateway query.",
type="object",
)
endpoint: EndpointEntryType = Field(
EndpointEntryType.STRUCTURES,
Expand Down
4 changes: 2 additions & 2 deletions optimade_gateway/queries/perform.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def perform_query(
**{"$set": {"state": QueryState.IN_PROGRESS}},
)
else:
response = EntryResponseMany(
response = query.attributes.endpoint.get_response_model()(
data=[],
links=ToplevelLinks(next=None),
meta=meta_values(
Expand Down Expand Up @@ -163,7 +163,7 @@ async def perform_query(
resource["meta"] = database_id_meta
else:
resource.meta = Meta(**database_id_meta)
response.data.append(results)
response.data.extend(results)
response.meta.data_returned += response_meta["data_returned"]
if not response.meta.more_data_available:
# Keep it True, if set to True once.
Expand Down
14 changes: 8 additions & 6 deletions optimade_gateway/queries/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ async def process_db_response(
errors = []
meta = {}

from optimade_gateway.common.logger import LOGGER

LOGGER.debug("database_id: %s", database_id)

if isinstance(response, ErrorResponse):
for error in response.errors:
if isinstance(error.id, str) and error.id.startswith("OPTIMADE_GATEWAY"):
Expand Down Expand Up @@ -113,14 +117,12 @@ async def process_db_response(

# This ensures an empty list under `response.data.{database_id}` is returned if the case is
# simply that there is no results to return.
# It also ensures that only `response.errors.{database_id}` is created if there are any
# errors.
if errors:
extra_updates.update({"$addToSet": {"response.errors": {"$each": errors}}})
await update_query(
query,
f"response.errors.{database_id}"
if errors
else f"response.data.{database_id}",
errors or results,
f"response.data.{database_id}",
results,
**extra_updates,
)
else:
Expand Down
11 changes: 7 additions & 4 deletions optimade_gateway/queries/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,20 @@ async def update_query(
update_kwargs = {"$set": {"last_modified": update_time}}

if mongo_kwargs:
update_kwargs.update(await clean_python_types(mongo_kwargs))
update_kwargs.update(mongo_kwargs)

if operator and operator == "$set":
update_kwargs["$set"].update({field: await clean_python_types(value)})
update_kwargs["$set"].update({field: value})
elif operator:
update_kwargs.update({operator: {field: await clean_python_types(value)}})
if operator in update_kwargs:
update_kwargs[operator].update({field: value})
else:
update_kwargs.update({operator: {field: value}})

# MongoDB
result: UpdateResult = await QUERIES_COLLECTION.collection.update_one(
filter={"id": {"$eq": query.id}},
update=update_kwargs,
update=await clean_python_types(update_kwargs),
)
if result.matched_count != 1:
LOGGER.error(
Expand Down
14 changes: 8 additions & 6 deletions optimade_gateway/routers/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@
One can register a new database (by using `POST /databases`) or look through the available
databases (by using `GET /databases`) using standard OPTIMADE filtering.
"""
from typing import Union

from fastapi import APIRouter, Depends, Request
from optimade.models import ErrorResponse, LinksResource, ToplevelLinks
from optimade.models import LinksResource, ToplevelLinks
from optimade.server.query_params import EntryListingQueryParams, SingleEntryQueryParams
from optimade.server.routers.utils import handle_response_fields, meta_values
from optimade.server.schemas import ERROR_RESPONSES

from optimade_gateway.common.config import CONFIG
from optimade_gateway.mappers import DatabasesMapper
Expand All @@ -39,11 +38,12 @@

@ROUTER.get(
"/databases",
response_model=Union[DatabasesResponse, ErrorResponse],
response_model=DatabasesResponse,
response_model_exclude_defaults=False,
response_model_exclude_none=False,
response_model_exclude_unset=True,
tags=["Databases"],
responses=ERROR_RESPONSES,
)
async def get_databases(
request: Request,
Expand All @@ -65,11 +65,12 @@ async def get_databases(

@ROUTER.post(
"/databases",
response_model=Union[DatabasesResponseSingle, ErrorResponse],
response_model=DatabasesResponseSingle,
response_model_exclude_defaults=False,
response_model_exclude_none=False,
response_model_exclude_unset=True,
tags=["Databases"],
responses=ERROR_RESPONSES,
)
async def post_databases(
request: Request, database: DatabaseCreate
Expand Down Expand Up @@ -99,11 +100,12 @@ async def post_databases(

@ROUTER.get(
"/databases/{database_id:path}",
response_model=Union[DatabasesResponseSingle, ErrorResponse],
response_model=DatabasesResponseSingle,
response_model_exclude_defaults=False,
response_model_exclude_none=False,
response_model_exclude_unset=True,
tags=["Databases"],
responses=ERROR_RESPONSES,
)
async def get_database(
request: Request,
Expand Down
22 changes: 11 additions & 11 deletions optimade_gateway/routers/gateway/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,17 @@
where `version` and `entry` may be left out.
"""
from typing import Union

from fastapi import APIRouter, Request
from optimade import __api_version__
from optimade.models import (
BaseInfoAttributes,
BaseInfoResource,
EntryInfoResponse,
ErrorResponse,
InfoResponse,
StructureResource,
)
from optimade.server.routers.utils import get_base_url, meta_values
from optimade.server.schemas import ERROR_RESPONSES

ROUTER = APIRouter(redirect_slashes=True)

Expand All @@ -27,11 +25,12 @@

@ROUTER.get(
"/gateways/{gateway_id}/info",
response_model=Union[InfoResponse, ErrorResponse],
response_model=InfoResponse,
response_model_exclude_defaults=False,
response_model_exclude_none=False,
response_model_exclude_unset=True,
tags=["Info"],
responses=ERROR_RESPONSES,
)
async def get_gateways_info(
request: Request,
Expand Down Expand Up @@ -88,11 +87,12 @@ async def get_gateways_info(

@ROUTER.get(
"/gateways/{gateway_id}/info/{entry}",
response_model=Union[EntryInfoResponse, ErrorResponse],
response_model=EntryInfoResponse,
response_model_exclude_defaults=False,
response_model_exclude_none=False,
response_model_exclude_unset=True,
tags=["Info"],
responses=ERROR_RESPONSES,
)
async def get_gateways_entry_info(
request: Request, gateway_id: str, entry: str
Expand All @@ -102,7 +102,7 @@ async def get_gateways_entry_info(
Get information about the gateway `{gateway_id}`'s entry-listing endpoints.
"""
from optimade.models import EntryInfoResource
from optimade.server.exceptions import BadRequest
from optimade.server.exceptions import NotFound

from optimade_gateway.routers.gateways import GATEWAYS_COLLECTION
from optimade_gateway.routers.utils import (
Expand All @@ -114,9 +114,7 @@ async def get_gateways_entry_info(

valid_entry_info_endpoints = ENTRY_INFO_SCHEMAS.keys()
if entry not in valid_entry_info_endpoints:
raise BadRequest(
title="Not Found",
status_code=404,
raise NotFound(
detail=(
f"Entry info not found for {entry}, valid entry info endpoints are: "
f"{', '.join(valid_entry_info_endpoints)}"
Expand Down Expand Up @@ -147,11 +145,12 @@ async def get_gateways_entry_info(

@ROUTER.get(
"/gateways/{gateway_id}/{version}/info",
response_model=Union[InfoResponse, ErrorResponse],
response_model=InfoResponse,
response_model_exclude_defaults=False,
response_model_exclude_none=False,
response_model_exclude_unset=True,
tags=["Info"],
responses=ERROR_RESPONSES,
)
async def get_versioned_gateways_info(
request: Request,
Expand All @@ -170,11 +169,12 @@ async def get_versioned_gateways_info(

@ROUTER.get(
"/gateways/{gateway_id}/{version}/info/{entry}",
response_model=Union[EntryInfoResponse, ErrorResponse],
response_model=EntryInfoResponse,
response_model_exclude_defaults=False,
response_model_exclude_none=False,
response_model_exclude_unset=True,
tags=["Info"],
responses=ERROR_RESPONSES,
)
async def get_versioned_gateways_entry_info(
request: Request,
Expand Down
10 changes: 5 additions & 5 deletions optimade_gateway/routers/gateway/links.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,24 @@
where `version` may be left out.
"""
from typing import Union

from fastapi import APIRouter, Depends, Request
from optimade.models import (
ErrorResponse,
LinksResponse,
)
from optimade.server.query_params import EntryListingQueryParams
from optimade.server.schemas import ERROR_RESPONSES

ROUTER = APIRouter(redirect_slashes=True)


@ROUTER.get(
"/gateways/{gateway_id}/links",
response_model=Union[LinksResponse, ErrorResponse],
response_model=LinksResponse,
response_model_exclude_defaults=False,
response_model_exclude_none=False,
response_model_exclude_unset=True,
tags=["Links"],
responses=ERROR_RESPONSES,
)
async def get_gateways_links(
request: Request,
Expand All @@ -45,11 +44,12 @@ async def get_gateways_links(

@ROUTER.get(
"/gateways/{gateway_id}/{version}/links",
response_model=Union[LinksResponse, ErrorResponse],
response_model=LinksResponse,
response_model_exclude_defaults=False,
response_model_exclude_none=False,
response_model_exclude_unset=True,
tags=["Links"],
responses=ERROR_RESPONSES,
)
async def get_versioned_gateways_links(
request: Request,
Expand Down
19 changes: 9 additions & 10 deletions optimade_gateway/routers/gateway/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@
where `version` and the last `id` may be left out.
"""
from typing import Union

from fastapi import APIRouter, Depends, Request, status
from optimade.models import ErrorResponse
from optimade.models.responses import EntryResponseMany
from optimade.server.exceptions import BadRequest
from optimade.server.exceptions import Forbidden
from optimade.server.query_params import EntryListingQueryParams
from optimade.server.schemas import ERROR_RESPONSES

from optimade_gateway.models import (
QueryCreate,
Expand All @@ -27,11 +25,12 @@

@ROUTER.get(
"/gateways/{gateway_id}/queries",
response_model=Union[QueriesResponse, ErrorResponse],
response_model=QueriesResponse,
response_model_exclude_defaults=False,
response_model_exclude_none=False,
response_model_exclude_unset=True,
tags=["Gateways", "Queries"],
responses=ERROR_RESPONSES,
)
async def get_gateway_queries(
request: Request,
Expand All @@ -57,12 +56,13 @@ async def get_gateway_queries(

@ROUTER.post(
"/gateways/{gateway_id}/queries",
response_model=Union[QueriesResponseSingle, ErrorResponse],
response_model=QueriesResponseSingle,
response_model_exclude_defaults=False,
response_model_exclude_none=False,
response_model_exclude_unset=True,
tags=["Gateways", "Queries"],
status_code=status.HTTP_202_ACCEPTED,
responses=ERROR_RESPONSES,
)
async def post_gateway_queries(
request: Request,
Expand All @@ -79,9 +79,7 @@ async def post_gateway_queries(
await validate_resource(GATEWAYS_COLLECTION, gateway_id)

if query.gateway_id and query.gateway_id != gateway_id:
raise BadRequest(
status_code=403,
title="Forbidden",
raise Forbidden(
detail=(
f"The gateway ID in the posted data (<gateway={query.gateway_id}>) does not align "
f"with the gateway ID specified in the URL (/{gateway_id}/)."
Expand All @@ -93,11 +91,12 @@ async def post_gateway_queries(

@ROUTER.get(
"/gateways/{gateway_id}/queries/{query_id}",
response_model=Union[EntryResponseMany, ErrorResponse],
response_model=EntryResponseMany,
response_model_exclude_defaults=False,
response_model_exclude_none=False,
response_model_exclude_unset=True,
tags=["Gateways", "Queries"],
responses=ERROR_RESPONSES,
)
async def get_gateway_query(
request: Request, gateway_id: str, query_id: str
Expand Down
Loading

0 comments on commit 66e5bc5

Please sign in to comment.