From 3e76fd0e5700519c03d5aae8c6b7df6ff5d0c982 Mon Sep 17 00:00:00 2001
From: Alyssa Dai
Date: Sun, 14 Jan 2024 17:59:25 -0500
Subject: [PATCH] [ENH] Handle partial nodes success (#55)

* add comment explaining FEDERATION_NODES format

* implement custom HTTP response for partial success federated query

- use custom HTTP success status code
- return both node-specific errors and combined successful query results in response body
- log node errors and query federation success status to console

* test API path response for partial success federated query

* mock individual node request to assert over combined query response

* add new response model returning errors and node query results

* do not use exception object for partial success responses

* update response returned when all nodes succeed or fail

* handle network errors in federated query

* test federated response given unreachable nodes, create fixture for single matching dataset result

* test response when queries to all nodes either fail or succeed

* turn status of node responses into enum

* use model for node error in federated query response

* switch to sending requests to nodes asynchronously

* update mocked get function in tests to be async

* make code a bit cleaner

Co-authored-by: Sebastian Urchs

* rename fixture for mocked data

---------

Co-authored-by: Sebastian Urchs
---
 app/api/crud.py          |  66 +++++++++++--
 app/api/models.py        |  24 +++++
 app/api/routers/query.py |   6 +-
 app/api/utility.py       |  43 ++++----
 tests/test_query.py      | 205 +++++++++++++++++++++++++++++++++++++++
 5 files changed, 315 insertions(+), 29 deletions(-)
 create mode 100644 tests/test_query.py

diff --git a/app/api/crud.py b/app/api/crud.py
index c926018..3444bda 100644
--- a/app/api/crud.py
+++ b/app/api/crud.py
@@ -1,5 +1,11 @@
 """CRUD functions called by path operations."""
+import asyncio
+import warnings
+
+from fastapi import HTTPException, status
+from fastapi.responses import JSONResponse
+
 from . import utility as util
 
 
@@ -13,7 +19,7 @@ async def get(
     assessment: str,
     image_modal: str,
     node_urls: list[str],
-):
+) -> dict:
     """
     Makes GET requests to one or more Neurobagel node APIs using send_get_request
     utility function where the parameters are Neurobagel query parameters.
@@ -45,8 +51,10 @@ async def get(
     """
     cross_node_results = []
+    node_errors = []
 
     node_urls = util.validate_query_node_url_list(node_urls)
+    total_nodes = len(node_urls)
 
     # Node API query parameters
     params = {}
@@ -67,19 +75,61 @@ async def get(
     if image_modal:
         params["image_modal"] = image_modal
 
-    for node_url in node_urls:
-        node_name = util.FEDERATION_NODES[node_url]
-        response = util.send_get_request(node_url + "query/", params)
+    tasks = [
+        util.send_get_request(node_url + "query/", params)
+        for node_url in node_urls
+    ]
+    responses = await asyncio.gather(*tasks, return_exceptions=True)
 
-        for result in response:
-            result["node_name"] = node_name
+    for node_url, response in zip(node_urls, responses):
+        node_name = util.FEDERATION_NODES[node_url]
+        if isinstance(response, HTTPException):
+            node_errors.append(
+                {"node_name": node_name, "error": response.detail}
+            )
+            warnings.warn(
+                f"Query to node {node_name} ({node_url}) did not succeed: {response.detail}"
+            )
+        else:
+            for result in response:
+                result["node_name"] = node_name
+            cross_node_results.extend(response)
+
+    if node_errors:
+        # TODO: Use logger instead of print, see https://github.com/tiangolo/fastapi/issues/5003
+        print(
+            f"Queries to {len(node_errors)}/{total_nodes} nodes failed: {[node_error['node_name'] for node_error in node_errors]}."
+        )
 
-        cross_node_results += response
+        if len(node_errors) == total_nodes:
+            # See https://fastapi.tiangolo.com/advanced/additional-responses/ for more info
+            return JSONResponse(
+                status_code=status.HTTP_207_MULTI_STATUS,
+                content={
+                    "errors": node_errors,
+                    "responses": cross_node_results,
+                    "nodes_response_status": "fail",
+                },
+            )
+        return JSONResponse(
+            status_code=status.HTTP_207_MULTI_STATUS,
+            content={
+                "errors": node_errors,
+                "responses": cross_node_results,
+                "nodes_response_status": "partial success",
+            },
+        )
 
-    return cross_node_results
+    print(f"All nodes queried successfully ({total_nodes}/{total_nodes}).")
+    return {
+        "errors": node_errors,
+        "responses": cross_node_results,
+        "nodes_response_status": "success",
+    }
 
 
 async def get_terms(data_element_URI: str):
+    # TODO: Make this path able to handle partial successes as well
     """
     Makes a GET request to one or more Neurobagel node APIs using send_get_request
     utility function where the only parameter is a data element URI.
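Note on the fan-out above: it hinges on asyncio.gather(*tasks, return_exceptions=True), which waits for every node request and hands raised exceptions back as ordinary objects in the results list instead of propagating the first failure. A minimal, self-contained sketch of that pattern (fake_node_query, the node names, and the simulated failure are invented for illustration, not code from this patch):

    import asyncio

    async def fake_node_query(node_name: str) -> list[dict]:
        # Stand-in for one node API request; "badnode" simulates an unreachable node.
        if node_name == "badnode":
            raise RuntimeError(f"Query to {node_name} did not succeed")
        return [{"node_name": node_name, "num_matching_subjects": 5}]

    async def main():
        node_names = ["goodnode", "badnode"]
        tasks = [fake_node_query(name) for name in node_names]
        # return_exceptions=True: failures come back as exception objects in the
        # results list, so one bad node cannot abort the whole federated query.
        responses = await asyncio.gather(*tasks, return_exceptions=True)
        for name, response in zip(node_names, responses):
            if isinstance(response, Exception):
                print(f"{name} failed: {response}")
            else:
                print(f"{name} succeeded: {response}")

    asyncio.run(main())

Because gather preserves task order, zipping the results back against node_urls (as the new crud.get does) is enough to attribute each success or failure to its node.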
diff --git a/app/api/models.py b/app/api/models.py
index b086776..1c0f7b3 100644
--- a/app/api/models.py
+++ b/app/api/models.py
@@ -1,4 +1,5 @@
 """Data models."""
+from enum import Enum
 from typing import Optional, Union
 
 from fastapi import Query
@@ -36,3 +37,26 @@ class CohortQueryResponse(BaseModel):
     num_matching_subjects: int
     subject_data: Union[list[dict], str]
     image_modals: list
+
+
+class NodesResponseStatus(str, Enum):
+    """Possible values for the status of the responses from the queried nodes."""
+
+    SUCCESS = "success"
+    PARTIAL_SUCCESS = "partial success"
+    FAIL = "fail"
+
+
+class NodeError(BaseModel):
+    """Data model for an error encountered when querying a node."""
+
+    node_name: str
+    error: str
+
+
+class CombinedQueryResponse(BaseModel):
+    """Data model for the combined query results of all matching datasets across all queried nodes."""
+
+    errors: list[NodeError]
+    responses: list[CohortQueryResponse]
+    nodes_response_status: NodesResponseStatus
diff --git a/app/api/routers/query.py b/app/api/routers/query.py
index cd53245..eed690c 100644
--- a/app/api/routers/query.py
+++ b/app/api/routers/query.py
@@ -1,16 +1,14 @@
 """Router for query path operations."""
-from typing import List
-
 from fastapi import APIRouter, Depends
 
 from .. import crud
-from ..models import CohortQueryResponse, QueryModel
+from ..models import CombinedQueryResponse, QueryModel
 
 router = APIRouter(prefix="/query", tags=["query"])
 
 
-@router.get("/", response_model=List[CohortQueryResponse])
+@router.get("/", response_model=CombinedQueryResponse)
 async def get_query(query: QueryModel = Depends(QueryModel)):
     """When a GET request is sent, return list of dicts corresponding to subject-level metadata aggregated by dataset."""
     response = await crud.get(
diff --git a/app/api/utility.py b/app/api/utility.py
index 484e304..6636325 100644
--- a/app/api/utility.py
+++ b/app/api/utility.py
@@ -6,10 +6,12 @@
 import httpx
 import jsonschema
-from fastapi import HTTPException
+from fastapi import HTTPException, status
 from jsonschema import validate
 
 LOCAL_NODE_INDEX_PATH = Path(__file__).parents[2] / "local_nb_nodes.json"
+
+# Stores the names and URLs of all Neurobagel nodes known to the API instance, in the form of {node_url: node_name, ...}
 FEDERATION_NODES = {}
 
 # We use this schema to validate the local_nb_nodes.json file
@@ -196,7 +198,7 @@ def validate_query_node_url_list(node_urls: list) -> list:
     return node_urls
 
 
-def send_get_request(url: str, params: list):
+async def send_get_request(url: str, params: dict) -> dict:
     """
     Makes a GET request to one or more Neurobagel nodes.
@@ -218,19 +220,26 @@ async def send_get_request(url: str, params: list):
     HTTPException
         _description_
     """
-    response = httpx.get(
-        url=url,
-        params=params,
-        # TODO: Revisit timeout value when query performance is improved
-        timeout=30.0,
-        # Enable redirect following (off by default) so
-        # APIs behind a proxy can be reached
-        follow_redirects=True,
-    )
+    async with httpx.AsyncClient() as client:
+        try:
+            response = await client.get(
+                url=url,
+                params=params,
+                # TODO: Revisit timeout value when query performance is improved
+                timeout=30.0,
+                # Enable redirect following (off by default) so
+                # APIs behind a proxy can be reached
+                follow_redirects=True,
+            )
 
-    if not response.is_success:
-        raise HTTPException(
-            status_code=response.status_code,
-            detail=f"{response.reason_phrase}: {response.text}",
-        )
-    return response.json()
+            if not response.is_success:
+                raise HTTPException(
+                    status_code=response.status_code,
+                    detail=f"{response.reason_phrase}: {response.text}",
+                )
+            return response.json()
+        except httpx.NetworkError as exc:
+            raise HTTPException(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                detail=f"Request failed due to a network error or because the node API cannot be reached: {exc}",
+            ) from exc
diff --git a/tests/test_query.py b/tests/test_query.py
new file mode 100644
index 0000000..4cbaf25
--- /dev/null
+++ b/tests/test_query.py
@@ -0,0 +1,205 @@
+import httpx
+import pytest
+from fastapi import status
+
+from app.api import utility as util
+
+
+@pytest.fixture()
+def mocked_single_matching_dataset_result():
+    """Valid aggregate query result for a single matching dataset."""
+    return {
+        "dataset_uuid": "http://neurobagel.org/vocab/12345",
+        "dataset_name": "QPN",
+        "dataset_portal_uri": "https://rpq-qpn.ca/en/researchers-section/databases/",
+        "dataset_total_subjects": 200,
+        "num_matching_subjects": 5,
+        "records_protected": True,
+        "subject_data": "protected",
+        "image_modals": [
+            "http://purl.org/nidash/nidm#T1Weighted",
+            "http://purl.org/nidash/nidm#T2Weighted",
+        ],
+    }
+
+
+def test_partial_node_failure_responses_handled_gracefully(
+    monkeypatch, test_app, capsys, mocked_single_matching_dataset_result
+):
+    """
+    Test that when queries to some nodes return errors, the overall API get request still succeeds,
+    the successful responses are returned along with a list of the encountered errors, and the failed nodes are logged to the console.
+    """
+    monkeypatch.setattr(
+        util,
+        "FEDERATION_NODES",
+        {
+            "https://firstpublicnode.org/": "First Public Node",
+            "https://secondpublicnode.org/": "Second Public Node",
+        },
+    )
+
+    async def mock_httpx_get(self, **kwargs):
+        if kwargs["url"] == "https://firstpublicnode.org/query/":
+            return httpx.Response(
+                status_code=200, json=[mocked_single_matching_dataset_result]
+            )
+
+        return httpx.Response(
+            status_code=500, json={}, text="Some internal server error"
+        )
+
+    monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)
+
+    with pytest.warns(
+        UserWarning,
+        match=r"Second Public Node \(https://secondpublicnode.org/\) did not succeed",
+    ):
+        response = test_app.get("/query/")
+
+    captured = capsys.readouterr()
+
+    assert response.status_code == status.HTTP_207_MULTI_STATUS
+    assert response.json() == {
+        "errors": [
+            {
+                "node_name": "Second Public Node",
+                "error": "Internal Server Error: Some internal server error",
+            },
+        ],
+        "responses": [
+            {
+                **mocked_single_matching_dataset_result,
+                "node_name": "First Public Node",
+            },
+        ],
+        "nodes_response_status": "partial success",
+    }
+    assert (
+        "Queries to 1/2 nodes failed: ['Second Public Node']" in captured.out
+    )
+
+
+def test_partial_node_connection_failures_handled_gracefully(
+    monkeypatch, test_app, capsys, mocked_single_matching_dataset_result
+):
+    """
+    Test that when requests to some nodes fail (e.g., if API is unreachable), the overall API get request still succeeds,
+    the successful responses are returned along with a list of the encountered errors, and the failed nodes are logged to the console.
+    """
+    monkeypatch.setattr(
+        util,
+        "FEDERATION_NODES",
+        {
+            "https://firstpublicnode.org/": "First Public Node",
+            "https://secondpublicnode.org/": "Second Public Node",
+        },
+    )
+
+    async def mock_httpx_get(self, **kwargs):
+        if kwargs["url"] == "https://firstpublicnode.org/query/":
+            return httpx.Response(
+                status_code=200, json=[mocked_single_matching_dataset_result]
+            )
+
+        raise httpx.ConnectError("Some connection error")
+
+    monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)
+
+    with pytest.warns(
+        UserWarning,
+        match=r"Second Public Node \(https://secondpublicnode.org/\) did not succeed",
+    ):
+        response = test_app.get("/query/")
+
+    captured = capsys.readouterr()
+
+    assert response.status_code == status.HTTP_207_MULTI_STATUS
+    assert response.json() == {
+        "errors": [
+            {
+                "node_name": "Second Public Node",
+                "error": "Request failed due to a network error or because the node API cannot be reached: Some connection error",
+            },
+        ],
+        "responses": [
+            {
+                **mocked_single_matching_dataset_result,
+                "node_name": "First Public Node",
+            },
+        ],
+        "nodes_response_status": "partial success",
+    }
+    assert (
+        "Queries to 1/2 nodes failed: ['Second Public Node']" in captured.out
+    )
+
+
+def test_all_nodes_failure_handled_gracefully(monkeypatch, test_app, capsys):
+    """
+    Test that when queries sent to all nodes fail, the federation API get request still succeeds,
+    but includes an overall failure status and all encountered errors in the response.
+    """
+    monkeypatch.setattr(
+        util,
+        "FEDERATION_NODES",
+        {
+            "https://firstpublicnode.org/": "First Public Node",
+            "https://secondpublicnode.org/": "Second Public Node",
+        },
+    )
+
+    async def mock_httpx_get(self, **kwargs):
+        raise httpx.ConnectError("Some connection error")
+
+    monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)
+
+    with pytest.warns(
+        UserWarning,
+    ) as w:
+        response = test_app.get("/query/")
+
+    captured = capsys.readouterr()
+
+    assert len(w) == 2
+    assert response.status_code == status.HTTP_207_MULTI_STATUS
+
+    response = response.json()
+    assert response["nodes_response_status"] == "fail"
+    assert len(response["errors"]) == 2
+    assert response["responses"] == []
+    assert (
+        "Queries to 2/2 nodes failed: ['First Public Node', 'Second Public Node']"
+        in captured.out
+    )
+
+
+def test_all_nodes_success_handled_gracefully(
+    monkeypatch, test_app, capsys, mocked_single_matching_dataset_result
+):
+    """
+    Test that when queries sent to all nodes succeed, the federation API response includes an overall success status and no errors.
+    """
+    monkeypatch.setattr(
+        util,
+        "FEDERATION_NODES",
+        {
+            "https://firstpublicnode.org/": "First Public Node",
+            "https://secondpublicnode.org/": "Second Public Node",
+        },
+    )
+
+    async def mock_httpx_get(self, **kwargs):
+        return httpx.Response(
+            status_code=200, json=[mocked_single_matching_dataset_result]
+        )
+
+    monkeypatch.setattr(httpx.AsyncClient, "get", mock_httpx_get)
+
+    response = test_app.get("/query/")
+    captured = capsys.readouterr()
+
+    assert response.status_code == status.HTTP_200_OK
+
+    response = response.json()
+    assert response["nodes_response_status"] == "success"
+    assert response["errors"] == []
+    assert len(response["responses"]) == 2
+    assert "All nodes queried successfully" in captured.out
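With this patch, /query/ always returns a body containing errors, responses, and nodes_response_status, and signals partial or total node failure with HTTP 207 rather than an error status. For reviewers trying the change locally, a rough sketch of how a client might consume the new contract (run_federated_query, FEDERATION_API_URL, the localhost address, and the example parameter value are assumptions for illustration, not part of the patch):

    import httpx

    # Assumed base URL of a locally running federation API instance; adjust as needed.
    FEDERATION_API_URL = "http://localhost:8000"

    def run_federated_query(**params) -> list[dict]:
        """Query the federation API and return per-dataset results from reachable nodes."""
        response = httpx.get(f"{FEDERATION_API_URL}/query/", params=params)
        body = response.json()

        # 200 means every node succeeded; 207 covers both "partial success" and
        # "fail", distinguished by the nodes_response_status field in the body.
        for node_error in body["errors"]:
            print(f"{node_error['node_name']} failed: {node_error['error']}")
        if body["nodes_response_status"] == "fail":
            raise RuntimeError("Queries to all nodes failed")

        return body["responses"]

    # Example: find datasets with T1-weighted images (illustrative parameter value).
    results = run_federated_query(image_modal="http://purl.org/nidash/nidm#T1Weighted")

Treating 207 as a consumable success keeps a single unreachable node from looking like a total federation outage to downstream tools, while still surfacing exactly which nodes were skipped.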