From 07a6afa007b1cc0ac90a01cad3d87d79c6b58b3a Mon Sep 17 00:00:00 2001 From: Alyssa Dai Date: Wed, 17 Jul 2024 13:59:07 -0400 Subject: [PATCH] [ENH] Add authentication to `/query` route (#323) * add implicit OAuth flow + Google token verification for /query route * add dependencies for Google auth library * check auth env vars on startup * mock token/token verification and disable auth as needed in tests * add tests of auth utilities and filter irrelevant warnings * test empty query succeeds when auth is disabled --- app/api/routers/query.py | 30 +++++- app/api/security.py | 45 +++++++++ app/main.py | 10 +- requirements.txt | 14 ++- tests/conftest.py | 33 +++++++ tests/test_app_events.py | 34 +++++-- tests/test_attributes.py | 9 +- tests/test_query.py | 199 +++++++++++++++++++++++++++++++-------- tests/test_security.py | 64 +++++++++++++ 9 files changed, 383 insertions(+), 55 deletions(-) create mode 100644 app/api/security.py create mode 100644 tests/test_security.py diff --git a/app/api/routers/query.py b/app/api/routers/query.py index c1e1b1a..539b27a 100644 --- a/app/api/routers/query.py +++ b/app/api/routers/query.py @@ -2,17 +2,41 @@ from typing import List -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.security import OAuth2 -from .. import crud +from .. import crud, security from ..models import CohortQueryResponse, QueryModel +from ..security import verify_token router = APIRouter(prefix="/query", tags=["query"]) +# Adapted from info in https://github.com/tiangolo/fastapi/discussions/9137#discussioncomment-5157382 +oauth2_scheme = OAuth2( + flows={ + "implicit": { + "authorizationUrl": "https://accounts.google.com/o/oauth2/auth", + } + }, + # Don't automatically error out when request is not authenticated, to support optional authentication + auto_error=False, +) + @router.get("/", response_model=List[CohortQueryResponse]) -async def get_query(query: QueryModel = Depends(QueryModel)): +async def get_query( + query: QueryModel = Depends(QueryModel), + token: str | None = Depends(oauth2_scheme), +): """When a GET request is sent, return list of dicts corresponding to subject-level metadata aggregated by dataset.""" + if security.AUTH_ENABLED: + if token is None: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authenticated", + ) + verify_token(token) + response = await crud.get( query.min_age, query.max_age, diff --git a/app/api/security.py b/app/api/security.py new file mode 100644 index 0000000..3abf4ed --- /dev/null +++ b/app/api/security.py @@ -0,0 +1,45 @@ +"""Functions for handling authentication. Same ones as used in Neurobagel's federation API.""" + +import os + +from fastapi import HTTPException, status +from fastapi.security.utils import get_authorization_scheme_param +from google.auth.exceptions import GoogleAuthError +from google.auth.transport import requests +from google.oauth2 import id_token + +AUTH_ENABLED = os.environ.get("NB_ENABLE_AUTH", "True").lower() == "true" +CLIENT_ID = os.environ.get("NB_QUERY_CLIENT_ID", None) + + +def check_client_id(): + """Check if the CLIENT_ID environment variable is set.""" + # By default, if CLIENT_ID is not provided to verify_oauth2_token, + # Google will simply skip verifying the audience claim of ID tokens. + # This however can be a security risk, so we mandate that CLIENT_ID is set. + if AUTH_ENABLED and CLIENT_ID is None: + raise ValueError( + "Authentication has been enabled (NB_ENABLE_AUTH) but the environment variable NB_QUERY_CLIENT_ID is not set. " + "Please set NB_QUERY_CLIENT_ID to the Google client ID for your Neurobagel query tool deployment, to verify the audience claim of ID tokens." + ) + + +def verify_token(token: str): + """Verify the Google ID token. Raise an HTTPException if the token is invalid.""" + # Adapted from https://developers.google.com/identity/gsi/web/guides/verify-google-id-token#python + try: + # Extract the token from the "Bearer" scheme + # (See https://github.com/tiangolo/fastapi/blob/master/fastapi/security/oauth2.py#L473-L485) + # TODO: Check also if scheme of token is "Bearer"? + _, param = get_authorization_scheme_param(token) + id_info = id_token.verify_oauth2_token( + param, requests.Request(), CLIENT_ID + ) + # TODO: Remove print statement or turn into logging + print("Token verified: ", id_info) + except (GoogleAuthError, ValueError) as exc: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"Invalid token: {exc}", + headers={"WWW-Authenticate": "Bearer"}, + ) from exc diff --git a/app/main.py b/app/main.py index b869e44..c4a88cd 100644 --- a/app/main.py +++ b/app/main.py @@ -13,6 +13,7 @@ from .api import utility as util from .api.routers import attributes, query +from .api.security import check_client_id app = FastAPI( default_response_class=ORJSONResponse, docs_url=None, redoc_url=None @@ -77,7 +78,14 @@ def overridden_redoc(): @app.on_event("startup") async def auth_check(): - """Checks whether username and password environment variables are set.""" + """ + Checks whether authentication has been enabled for API queries and whether the + username and password environment variables for the graph backend have been set. + + TODO: Refactor once startup events have been replaced by lifespan event + """ + check_client_id() + if ( # TODO: Check if this error is still raised when variables are empty strings os.environ.get(util.GRAPH_USERNAME.name) is None diff --git a/requirements.txt b/requirements.txt index 22b4618..edc8b5d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,17 @@ anyio==3.6.2 attrs==22.1.0 +cachetools==5.3.3 certifi==2024.7.4 cfgv==3.3.1 -coverage==7.0.0 +charset-normalizer==3.3.2 click==8.1.3 +colorama==0.4.6 +coverage==7.0.0 distlib==0.3.6 exceptiongroup==1.0.4 fastapi==0.110.1 filelock==3.8.0 +google-auth==2.32.0 h11==0.14.0 httpcore==0.16.2 httpx==0.23.1 @@ -22,15 +26,23 @@ pandas==1.5.2 platformdirs==2.5.4 pluggy==1.0.0 pre-commit==3.6.0 +pyasn1==0.6.0 +pyasn1_modules==0.4.0 pydantic==1.10.13 pyparsing==3.0.9 pytest==7.2.0 +python-dateutil==2.8.2 +pytz==2022.7 PyYAML==6.0 +requests==2.32.3 rfc3986==1.5.0 +rsa==4.9 +six==1.16.0 sniffio==1.3.0 starlette==0.37.2 toml==0.10.2 tomli==2.0.1 typing_extensions==4.11.0 +urllib3==2.2.2 uvicorn==0.20.0 virtualenv==20.16.7 diff --git a/tests/conftest.py b/tests/conftest.py index bb9166d..5d03c2d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,15 @@ def test_app(): yield client +@pytest.fixture +def disable_auth(monkeypatch): + """ + Disable the authentication requirement for the API to skip startup checks + (for when the tested route does not require authentication). + """ + monkeypatch.setattr("app.api.security.AUTH_ENABLED", False) + + @pytest.fixture(scope="function") def set_test_credentials(monkeypatch): """Set random username and password to avoid error from startup check for set credentials.""" @@ -18,6 +27,30 @@ def set_test_credentials(monkeypatch): monkeypatch.setenv(util.GRAPH_PASSWORD.name, "SomePassword") +@pytest.fixture() +def mock_verify_token(): + """Mock a successful token verification that does not raise any exceptions.""" + + def _verify_token(token): + return None + + return _verify_token + + +@pytest.fixture() +def set_mock_verify_token(monkeypatch, mock_verify_token): + """Set the verify_token function to a mock that does not raise any exceptions.""" + monkeypatch.setattr( + "app.api.routers.query.verify_token", mock_verify_token + ) + + +@pytest.fixture() +def mock_auth_header() -> dict: + """Create an authorization header with a mock token that is well-formed for testing purposes.""" + return {"Authorization": "Bearer foo"} + + @pytest.fixture() def test_data(): """Create valid aggregate response data for two toy datasets for testing.""" diff --git a/tests/test_app_events.py b/tests/test_app_events.py index 5bc839e..eaaf001 100644 --- a/tests/test_app_events.py +++ b/tests/test_app_events.py @@ -10,7 +10,10 @@ from app.api import utility as util -def test_start_app_without_environment_vars_fails(test_app, monkeypatch): +@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS") +def test_start_app_without_environment_vars_fails( + test_app, monkeypatch, disable_auth +): """Given non-existing username and password environment variables, raises an informative RuntimeError.""" monkeypatch.delenv(util.GRAPH_USERNAME.name, raising=False) monkeypatch.delenv(util.GRAPH_PASSWORD.name, raising=False) @@ -24,8 +27,11 @@ def test_start_app_without_environment_vars_fails(test_app, monkeypatch): ) -def test_app_with_invalid_environment_vars(test_app, monkeypatch): - """Given invalid environment variables, returns a 401 status code.""" +@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS") +def test_app_with_invalid_environment_vars( + test_app, monkeypatch, mock_auth_header, set_mock_verify_token +): + """Given invalid environment variables for the graph, returns a 401 status code.""" monkeypatch.setenv(util.GRAPH_USERNAME.name, "something") monkeypatch.setenv(util.GRAPH_PASSWORD.name, "cool") @@ -33,12 +39,15 @@ def mock_httpx_post(**kwargs): return httpx.Response(status_code=401) monkeypatch.setattr(httpx, "post", mock_httpx_post) - response = test_app.get("/query/") + response = test_app.get("/query/", headers=mock_auth_header) assert response.status_code == 401 def test_app_with_unset_allowed_origins( - test_app, monkeypatch, set_test_credentials + test_app, + monkeypatch, + set_test_credentials, + disable_auth, ): """Tests that when the environment variable for allowed origins has not been set, a warning is raised and the app uses a default value.""" monkeypatch.delenv(util.ALLOWED_ORIGINS.name, raising=False) @@ -90,6 +99,7 @@ def test_app_with_set_allowed_origins( allowed_origins, parsed_origins, expectation, + disable_auth, ): """ Test that when the environment variable for allowed origins has been explicitly set, the app correctly parses it into a list @@ -108,8 +118,11 @@ def test_app_with_set_allowed_origins( ) +@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS") def test_stored_vocab_lookup_file_created_on_startup( - test_app, set_test_credentials + test_app, + set_test_credentials, + disable_auth, ): """Test that on startup, a non-empty temporary lookup file is created for term ID-label mappings for the locally stored SNOMED CT vocabulary.""" with test_app: @@ -118,8 +131,9 @@ def test_stored_vocab_lookup_file_created_on_startup( assert term_labels_path.stat().st_size > 0 +@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS") def test_external_vocab_is_fetched_on_startup( - test_app, monkeypatch, set_test_credentials + test_app, monkeypatch, set_test_credentials, disable_auth ): """ Tests that on startup, a GET request is made to the Cognitive Atlas API and that when the request succeeds, @@ -160,8 +174,9 @@ def mock_httpx_get(**kwargs): } +@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS") def test_failed_vocab_fetching_on_startup_raises_warning( - test_app, monkeypatch, set_test_credentials + test_app, monkeypatch, set_test_credentials, disable_auth ): """ Tests that when a GET request to the Cognitive Atlas API has a non-success response code (e.g., due to service being unavailable), @@ -186,8 +201,9 @@ def mock_httpx_get(**kwargs): ) +@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS") def test_network_error_on_startup_raises_warning( - test_app, monkeypatch, set_test_credentials + test_app, monkeypatch, set_test_credentials, disable_auth ): """ Tests that when a GET request to the Cognitive Atlas API fails due to a network error (i.e., while issuing the request), diff --git a/tests/test_attributes.py b/tests/test_attributes.py index c8e15a9..3be37a2 100644 --- a/tests/test_attributes.py +++ b/tests/test_attributes.py @@ -7,17 +7,17 @@ from app.api import utility as util -def test_root(test_app, set_test_credentials): +def test_root(test_app): """Given a GET request to the root endpoint, Check for 200 status and expected content.""" - with test_app: - response = test_app.get("/") + response = test_app.get("/") assert response.status_code == 200 assert "Welcome to the Neurobagel REST API!" in response.text assert 'documentation' in response.text +@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS") @pytest.mark.parametrize( "valid_data_element_URI", ["nb:Diagnosis", "nb:Assessment"], @@ -28,6 +28,7 @@ def test_get_terms_valid_data_element_URI( mock_successful_get_terms, valid_data_element_URI, monkeypatch, + disable_auth, ): """Given a valid data element URI, returns a 200 status code and a non-empty list of terms for that data element.""" monkeypatch.setattr(crud, "get_terms", mock_successful_get_terms) @@ -54,7 +55,7 @@ def test_get_terms_invalid_data_element_URI( def test_get_terms_for_attribute_with_vocab_lookup( - test_app, monkeypatch, set_test_credentials + test_app, monkeypatch, set_test_credentials, disable_auth ): """ Given a valid data element URI with a vocabulary lookup file available, returns prefixed term URIs and their human-readable labels (where found) diff --git a/tests/test_query.py b/tests/test_query.py index f5b3feb..752b5a3 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -58,6 +58,8 @@ def test_null_modalities( mock_post_query_to_graph, mock_query_matching_dataset_sizes, monkeypatch, + mock_auth_header, + set_mock_verify_token, ): """Given a response containing a dataset with no recorded modalities, returns an empty list for the imaging modalities.""" @@ -66,17 +68,23 @@ def test_null_modalities( crud, "query_matching_dataset_sizes", mock_query_matching_dataset_sizes ) - response = test_app.get("/query/") + response = test_app.get("/query/", headers=mock_auth_header) assert response.json()[0]["image_modals"] == [ "http://purl.org/nidash/nidm#T1Weighted" ] -def test_get_all(test_app, mock_successful_get, monkeypatch): +def test_get_all( + test_app, + mock_successful_get, + monkeypatch, + mock_auth_header, + set_mock_verify_token, +): """Given no input for any query parameters, returns a 200 status code and a non-empty list of results (should correspond to all subjects in graph).""" monkeypatch.setattr(crud, "get", mock_successful_get) - response = test_app.get("/query/") + response = test_app.get("/query/", headers=mock_auth_header) assert response.status_code == 200 assert response.json() != [] @@ -86,13 +94,20 @@ def test_get_all(test_app, mock_successful_get, monkeypatch): [(30.5, 60), (23, 23)], ) def test_get_valid_age_range( - test_app, mock_successful_get, valid_min_age, valid_max_age, monkeypatch + test_app, + mock_successful_get, + valid_min_age, + valid_max_age, + monkeypatch, + mock_auth_header, + set_mock_verify_token, ): """Given a valid age range, returns a 200 status code and a non-empty list of results.""" monkeypatch.setattr(crud, "get", mock_successful_get) response = test_app.get( - f"/query/?min_age={valid_min_age}&max_age={valid_max_age}" + f"/query/?min_age={valid_min_age}&max_age={valid_max_age}", + headers=mock_auth_header, ) assert response.status_code == 200 assert response.json() != [] @@ -103,12 +118,17 @@ def test_get_valid_age_range( ["min_age=20.75", "max_age=50"], ) def test_get_valid_age_single_bound( - test_app, mock_successful_get, age_keyval, monkeypatch + test_app, + mock_successful_get, + age_keyval, + monkeypatch, + mock_auth_header, + set_mock_verify_token, ): """Given only a valid lower/upper age bound, returns a 200 status code and a non-empty list of results.""" monkeypatch.setattr(crud, "get", mock_successful_get) - response = test_app.get(f"/query/?{age_keyval}") + response = test_app.get(f"/query/?{age_keyval}", headers=mock_auth_header) assert response.status_code == 200 assert response.json() != [] @@ -123,13 +143,20 @@ def test_get_valid_age_single_bound( ], ) def test_get_invalid_age( - test_app, mock_get, invalid_min_age, invalid_max_age, monkeypatch + test_app, + mock_get, + invalid_min_age, + invalid_max_age, + monkeypatch, + mock_auth_header, + set_mock_verify_token, ): """Given an invalid age range, returns a 422 status code.""" monkeypatch.setattr(crud, "get", mock_get) response = test_app.get( - f"/query/?min_age={invalid_min_age}&max_age={invalid_max_age}" + f"/query/?min_age={invalid_min_age}&max_age={invalid_max_age}", + headers=mock_auth_header, ) assert response.status_code == 422 @@ -138,21 +165,32 @@ def test_get_invalid_age( "valid_sex", ["snomed:248153007", "snomed:248152002", "snomed:32570681000036106"], ) -def test_get_valid_sex(test_app, mock_successful_get, valid_sex, monkeypatch): +def test_get_valid_sex( + test_app, + mock_successful_get, + valid_sex, + monkeypatch, + mock_auth_header, + set_mock_verify_token, +): """Given a valid sex string, returns a 200 status code and a non-empty list of results.""" monkeypatch.setattr(crud, "get", mock_successful_get) - response = test_app.get(f"/query/?sex={valid_sex}") + response = test_app.get( + f"/query/?sex={valid_sex}", headers=mock_auth_header + ) assert response.status_code == 200 assert response.json() != [] @pytest.mark.parametrize("mock_get", [None], indirect=True) -def test_get_invalid_sex(test_app, mock_get, monkeypatch): +def test_get_invalid_sex( + test_app, mock_get, monkeypatch, mock_auth_header, set_mock_verify_token +): """Given an invalid sex string, returns a 422 status code.""" monkeypatch.setattr(crud, "get", mock_get) - response = test_app.get("/query/?sex=apple") + response = test_app.get("/query/?sex=apple", headers=mock_auth_header) assert response.status_code == 422 @@ -160,12 +198,19 @@ def test_get_invalid_sex(test_app, mock_get, monkeypatch): "valid_diagnosis", ["snomed:35489007", "snomed:49049000"] ) def test_get_valid_diagnosis( - test_app, mock_successful_get, valid_diagnosis, monkeypatch + test_app, + mock_successful_get, + valid_diagnosis, + monkeypatch, + mock_auth_header, + set_mock_verify_token, ): """Given a valid diagnosis, returns a 200 status code and a non-empty list of results.""" monkeypatch.setattr(crud, "get", mock_successful_get) - response = test_app.get(f"/query/?diagnosis={valid_diagnosis}") + response = test_app.get( + f"/query/?diagnosis={valid_diagnosis}", headers=mock_auth_header + ) assert response.status_code == 200 assert response.json() != [] @@ -175,43 +220,64 @@ def test_get_valid_diagnosis( "invalid_diagnosis", ["sn0med:35489007", "apple", ":123456"] ) def test_get_invalid_diagnosis( - test_app, mock_get, invalid_diagnosis, monkeypatch + test_app, + mock_get, + invalid_diagnosis, + monkeypatch, + mock_auth_header, + set_mock_verify_token, ): """Given an invalid diagnosis, returns a 422 status code.""" monkeypatch.setattr(crud, "get", mock_get) - response = test_app.get(f"/query/?diagnosis={invalid_diagnosis}") + response = test_app.get( + f"/query/?diagnosis={invalid_diagnosis}", headers=mock_auth_header + ) assert response.status_code == 422 @pytest.mark.parametrize("valid_iscontrol", [True, False]) def test_get_valid_iscontrol( - test_app, mock_successful_get, valid_iscontrol, monkeypatch + test_app, + mock_successful_get, + valid_iscontrol, + monkeypatch, + mock_auth_header, + set_mock_verify_token, ): """Given a valid is_control value, returns a 200 status code and a non-empty list of results.""" monkeypatch.setattr(crud, "get", mock_successful_get) - response = test_app.get(f"/query/?is_control={valid_iscontrol}") + response = test_app.get( + f"/query/?is_control={valid_iscontrol}", headers=mock_auth_header + ) assert response.status_code == 200 assert response.json() != [] @pytest.mark.parametrize("mock_get", [None], indirect=True) -def test_get_invalid_iscontrol(test_app, mock_get, monkeypatch): +def test_get_invalid_iscontrol( + test_app, mock_get, monkeypatch, mock_auth_header, set_mock_verify_token +): """Given a non-boolean is_control value, returns a 422 status code.""" monkeypatch.setattr(crud, "get", mock_get) - response = test_app.get("/query/?is_control=apple") + response = test_app.get( + "/query/?is_control=apple", headers=mock_auth_header + ) assert response.status_code == 422 @pytest.mark.parametrize("mock_get", [None], indirect=True) -def test_get_invalid_control_diagnosis_pair(test_app, mock_get, monkeypatch): +def test_get_invalid_control_diagnosis_pair( + test_app, mock_get, monkeypatch, mock_auth_header, set_mock_verify_token +): """Given a non-default diagnosis value and is_control value of True, returns a 422 status code.""" monkeypatch.setattr(crud, "get", mock_get) response = test_app.get( - "/query/?diagnosis=snomed:35489007&is_control=True" + "/query/?diagnosis=snomed:35489007&is_control=True", + headers=mock_auth_header, ) assert response.status_code == 422 assert ( @@ -232,12 +298,15 @@ def test_get_valid_min_num_sessions( session_param, valid_min_num_sessions, monkeypatch, + mock_auth_header, + set_mock_verify_token, ): """Given a valid minimum number of imaging sessions, returns a 200 status code and a non-empty list of results.""" monkeypatch.setattr(crud, "get", mock_successful_get) response = test_app.get( - f"/query/?{session_param}={valid_min_num_sessions}" + f"/query/?{session_param}={valid_min_num_sessions}", + headers=mock_auth_header, ) assert response.status_code == 200 assert response.json() != [] @@ -255,21 +324,32 @@ def test_get_invalid_min_num_sessions( session_param, invalid_min_num_sessions, monkeypatch, + mock_auth_header, + set_mock_verify_token, ): """Given an invalid minimum number of imaging sessions, returns a 422 status code.""" monkeypatch.setattr(crud, "get", mock_get) response = test_app.get( - f"/query/?{session_param}={invalid_min_num_sessions}" + f"/query/?{session_param}={invalid_min_num_sessions}", + headers=mock_auth_header, ) response.status_code = 422 -def test_get_valid_assessment(test_app, mock_successful_get, monkeypatch): +def test_get_valid_assessment( + test_app, + mock_successful_get, + monkeypatch, + mock_auth_header, + set_mock_verify_token, +): """Given a valid assessment, returns a 200 status code and a non-empty list of results.""" monkeypatch.setattr(crud, "get", mock_successful_get) - response = test_app.get("/query/?assessment=nb:cogAtlas-1234") + response = test_app.get( + "/query/?assessment=nb:cogAtlas-1234", headers=mock_auth_header + ) assert response.status_code == 200 assert response.json() != [] @@ -279,12 +359,19 @@ def test_get_valid_assessment(test_app, mock_successful_get, monkeypatch): "invalid_assessment", ["bg01:cogAtlas-1234", "cogAtlas-1234"] ) def test_get_invalid_assessment( - test_app, mock_get, invalid_assessment, monkeypatch + test_app, + mock_get, + invalid_assessment, + monkeypatch, + mock_auth_header, + set_mock_verify_token, ): """Given an invalid assessment, returns a 422 status code.""" monkeypatch.setattr(crud, "get", mock_get) - response = test_app.get(f"/query/?assessment={invalid_assessment}") + response = test_app.get( + f"/query/?assessment={invalid_assessment}", headers=mock_auth_header + ) assert response.status_code == 422 @@ -299,13 +386,19 @@ def test_get_invalid_assessment( ], ) def test_get_valid_available_image_modal( - test_app, mock_successful_get, valid_available_image_modal, monkeypatch + test_app, + mock_successful_get, + valid_available_image_modal, + monkeypatch, + mock_auth_header, + set_mock_verify_token, ): """Given a valid and available image modality, returns a 200 status code and a non-empty list of results.""" monkeypatch.setattr(crud, "get", mock_successful_get) response = test_app.get( - f"/query/?image_modal={valid_available_image_modal}" + f"/query/?image_modal={valid_available_image_modal}", + headers=mock_auth_header, ) assert response.status_code == 200 assert response.json() != [] @@ -317,13 +410,19 @@ def test_get_valid_available_image_modal( ["nidm:Flair", "owl:sameAs", "nb:FlowWeighted", "snomed:something"], ) def test_get_valid_unavailable_image_modal( - test_app, valid_unavailable_image_modal, mock_get, monkeypatch + test_app, + valid_unavailable_image_modal, + mock_get, + monkeypatch, + mock_auth_header, + set_mock_verify_token, ): """Given a valid, pre-defined, and unavailable image modality, returns a 200 status code and an empty list of results.""" monkeypatch.setattr(crud, "get", mock_get) response = test_app.get( - f"/query/?image_modal={valid_unavailable_image_modal}" + f"/query/?image_modal={valid_unavailable_image_modal}", + headers=mock_auth_header, ) assert response.status_code == 200 @@ -335,12 +434,19 @@ def test_get_valid_unavailable_image_modal( "invalid_image_modal", ["2nim:EEG", "apple", "some_thing:cool"] ) def test_get_invalid_image_modal( - test_app, mock_get, invalid_image_modal, monkeypatch + test_app, + mock_get, + invalid_image_modal, + monkeypatch, + mock_auth_header, + set_mock_verify_token, ): """Given an invalid image modality, returns a 422 status code.""" monkeypatch.setattr(crud, "get", mock_get) - response = test_app.get(f"/query/?image_modal={invalid_image_modal}") + response = test_app.get( + f"/query/?image_modal={invalid_image_modal}", headers=mock_auth_header + ) assert response.status_code == 422 @@ -356,12 +462,15 @@ def test_get_undefined_prefix_image_modal( undefined_prefix_image_modal, mock_get_with_exception, monkeypatch, + mock_auth_header, + set_mock_verify_token, ): """Given a valid and undefined prefix image modality, returns a 500 status code.""" monkeypatch.setattr(crud, "get", mock_get_with_exception) response = test_app.get( - f"/query/?image_modal={undefined_prefix_image_modal}" + f"/query/?image_modal={undefined_prefix_image_modal}", + headers=mock_auth_header, ) assert response.status_code == 500 @@ -372,6 +481,8 @@ def test_aggregate_query_response_structure( mock_post_query_to_graph, mock_query_matching_dataset_sizes, monkeypatch, + mock_auth_header, + set_mock_verify_token, ): """Test that when aggregate results are enabled, a cohort query response has the expected structure.""" monkeypatch.setenv(util.RETURN_AGG.name, "true") @@ -380,7 +491,21 @@ def test_aggregate_query_response_structure( crud, "query_matching_dataset_sizes", mock_query_matching_dataset_sizes ) - response = test_app.get("/query/") + response = test_app.get("/query/", headers=mock_auth_header) assert all( dataset["subject_data"] == "protected" for dataset in response.json() ) + + +def test_query_without_token_succeeds_when_auth_disabled( + test_app, + mock_successful_get, + monkeypatch, + disable_auth, +): + """ + Test that when authentication is disabled, a request to the /query route without a token succeeds. + """ + monkeypatch.setattr(crud, "get", mock_successful_get) + response = test_app.get("/query/") + assert response.status_code == 200 diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..6eb5b32 --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,64 @@ +import pytest +from fastapi import HTTPException + +from app.api.security import verify_token + + +@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS") +def test_missing_client_id_raises_error_when_auth_enabled( + monkeypatch, test_app, set_test_credentials +): + """Test that a missing client ID raises an error on startup when authentication is enabled.""" + # We're using what should be default values of CLIENT_ID and AUTH_ENABLED here + # (if the corresponding environment variables are unset), + # but we set the values explicitly here for clarity + monkeypatch.setattr("app.api.security.CLIENT_ID", None) + monkeypatch.setattr("app.api.security.AUTH_ENABLED", True) + + with pytest.raises(ValueError) as exc_info: + with test_app: + pass + + assert "NB_QUERY_CLIENT_ID is not set" in str(exc_info.value) + + +@pytest.mark.filterwarnings("ignore:.*NB_API_ALLOWED_ORIGINS") +def test_missing_client_id_ignored_when_auth_disabled( + monkeypatch, test_app, set_test_credentials +): + """Test that a missing client ID does not raise an error when authentication is disabled.""" + monkeypatch.setattr("app.api.security.CLIENT_ID", None) + monkeypatch.setattr("app.api.security.AUTH_ENABLED", False) + + with test_app: + pass + + +@pytest.mark.parametrize( + "invalid_token", + ["Bearer faketoken", "Bearer", "faketoken", "fakescheme faketoken"], +) +def test_invalid_token_raises_error(invalid_token): + """Test that an invalid token raises an error from the verification process.""" + with pytest.raises(HTTPException) as exc_info: + verify_token(invalid_token) + + assert exc_info.value.status_code == 401 + assert "Invalid token" in exc_info.value.detail + + +@pytest.mark.parametrize( + "invalid_auth_header", + [{}, {"Authorization": ""}, {"badheader": "badvalue"}], +) +def test_query_with_malformed_auth_header_fails( + test_app, set_mock_verify_token, invalid_auth_header +): + """Test that a request to the /query route with a missing or malformed authorization header, fails .""" + + response = test_app.get( + "/query/", + headers=invalid_auth_header, + ) + + assert response.status_code == 403