From 9c32acbf58cc69067b644eb312cdf42d3d7727d9 Mon Sep 17 00:00:00 2001 From: Alyssa Dai Date: Mon, 2 Dec 2024 18:31:33 -0500 Subject: [PATCH] [REF] Manually verify ID token using PyJWT instead of google_auth (#139) * manually verify ID token using PyJWT instead of google_auth * update issuer and keys URL for auth0, update comments * update authorization endpoint for oauth2_scheme docs * refactor out token extraction from verification * update tests * remove google-auth from dependencies * update docstring * refactor our JKWS client into a constant * add positive unit test for verify_token * remove print statement --- app/api/crud.py | 2 +- app/api/routers/query.py | 8 ++++-- app/api/security.py | 60 +++++++++++++++++++++++++++------------- requirements.txt | 5 +++- tests/conftest.py | 14 ++++------ tests/test_query.py | 8 +++--- tests/test_security.py | 60 ++++++++++++++++++++++------------------ 7 files changed, 94 insertions(+), 63 deletions(-) diff --git a/app/api/crud.py b/app/api/crud.py index 9715648..ddcc127 100644 --- a/app/api/crud.py +++ b/app/api/crud.py @@ -77,7 +77,7 @@ async def get( node_urls : list[str] List of Neurobagel nodes to send the query to. token : str, optional - Google ID token for authentication, by default None + ID token for authentication, by default None Returns ------- diff --git a/app/api/routers/query.py b/app/api/routers/query.py index b42ba66..77bd253 100644 --- a/app/api/routers/query.py +++ b/app/api/routers/query.py @@ -5,7 +5,7 @@ from .. import crud, security from ..models import CombinedQueryResponse, QueryModel -from ..security import verify_and_extract_token +from ..security import verify_token # from fastapi.security import open_id_connect_url @@ -13,10 +13,12 @@ router = APIRouter(prefix="/query", tags=["query"]) # Adapted from info in https://github.com/tiangolo/fastapi/discussions/9137#discussioncomment-5157382 +# I believe for us this is purely for documentatation/a nice looking interactive API docs page, +# and doesn't actually have any bearing on the ID token validation process. oauth2_scheme = OAuth2( flows={ "implicit": { - "authorizationUrl": "https://accounts.google.com/o/oauth2/auth", + "authorizationUrl": "https://neurobagel.ca.auth0.com/authorize", } }, # Don't automatically error out when request is not authenticated, to support optional authentication @@ -61,7 +63,7 @@ async def get_query( status_code=status.HTTP_403_FORBIDDEN, detail="Not authenticated", ) - token = verify_and_extract_token(token) + token = verify_token(token) response_dict = await crud.get( query.min_age, diff --git a/app/api/security.py b/app/api/security.py index 62625c6..aca9b1c 100644 --- a/app/api/security.py +++ b/app/api/security.py @@ -1,45 +1,67 @@ import os +import jwt from fastapi import HTTPException, status from fastapi.security.utils import get_authorization_scheme_param -from google.auth.exceptions import GoogleAuthError -from google.auth.transport import requests -from google.oauth2 import id_token +from jwt import PyJWKClient, PyJWTError AUTH_ENABLED = os.environ.get("NB_ENABLE_AUTH", "True").lower() == "true" CLIENT_ID = os.environ.get("NB_QUERY_CLIENT_ID", None) +KEYS_URL = "https://neurobagel.ca.auth0.com/.well-known/jwks.json" +ISSUER = "https://neurobagel.ca.auth0.com/" +# We only need to define the JWKS client once because get_signing_key_from_jwt will handle key rotations +# by automatically fetching updated keys when needed +# See https://github.com/jpadilla/pyjwt/blob/3ebbb22f30f2b1b41727b269a08b427e9a85d6bb/jwt/jwks_client.py#L96-L115 +JWKS_CLIENT = PyJWKClient(KEYS_URL) + def check_client_id(): """Check if the CLIENT_ID environment variable is set.""" - # By default, if CLIENT_ID is not provided to verify_oauth2_token, - # Google will simply skip verifying the audience claim of ID tokens. - # This however can be a security risk, so we mandate that CLIENT_ID is set. + # The CLIENT_ID is needed to verify the audience claim of ID tokens. if AUTH_ENABLED and CLIENT_ID is None: raise ValueError( "Authentication has been enabled (NB_ENABLE_AUTH) but the environment variable NB_QUERY_CLIENT_ID is not set. " - "Please set NB_QUERY_CLIENT_ID to the Google client ID for your Neurobagel query tool deployment, to verify the audience claim of ID tokens." + "Please set NB_QUERY_CLIENT_ID to the client ID for your Neurobagel query tool deployment, to verify the audience claim of ID tokens." ) -def verify_and_extract_token(token: str) -> str: +def extract_token(token: str) -> str: + """ + Extract the token from the authorization header. + This ensures that it is passed on to downstream APIs without the authorization scheme. + """ + # Extract the token from the "Bearer" scheme + # (See https://github.com/tiangolo/fastapi/blob/master/fastapi/security/oauth2.py#L473-L485) + # TODO: Check also if scheme of token is "Bearer"? + _, extracted_token = get_authorization_scheme_param(token) + return extracted_token + + +def verify_token(token: str) -> str: """ - Verify and return the Google ID token with the authorization scheme stripped. + Verify the ID token against the identity provider public keys, and return the token with the authorization scheme stripped. Raise an HTTPException if the token is invalid. """ - # Adapted from https://developers.google.com/identity/gsi/web/guides/verify-google-id-token#python try: - # Extract the token from the "Bearer" scheme - # (See https://github.com/tiangolo/fastapi/blob/master/fastapi/security/oauth2.py#L473-L485) - # TODO: Check also if scheme of token is "Bearer"? - _, extracted_token = get_authorization_scheme_param(token) - id_info = id_token.verify_oauth2_token( - extracted_token, requests.Request(), CLIENT_ID + extracted_token = extract_token(token) + # Determine which key was used to sign the token + # Adapted from https://pyjwt.readthedocs.io/en/stable/usage.html#retrieve-rsa-signing-keys-from-a-jwks-endpoint + signing_key = JWKS_CLIENT.get_signing_key_from_jwt(extracted_token) + + # https://pyjwt.readthedocs.io/en/stable/api.html#jwt.decode + jwt.decode( + jwt=extracted_token, + key=signing_key, + options={ + "verify_signature": True, + "require": ["aud", "iss", "exp", "iat"], + }, + audience=CLIENT_ID, + issuer=ISSUER, ) - # TODO: Remove print statement or turn into logging - print("Token verified: ", id_info) return extracted_token - except (GoogleAuthError, ValueError) as exc: + except (PyJWTError, ValueError) as exc: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Invalid token: {exc}", diff --git a/requirements.txt b/requirements.txt index 6c7adc9..d3e87c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,16 +2,17 @@ anyio==3.6.2 attrs==23.1.0 cachetools==5.3.3 certifi==2023.7.22 +cffi==1.17.1 cfgv==3.3.1 charset-normalizer==3.3.2 click==8.1.3 colorama==0.4.6 coverage==7.0.0 +cryptography==44.0.0 distlib==0.3.6 exceptiongroup==1.0.4 fastapi==0.110.1 filelock==3.8.0 -google-auth==2.32.0 h11==0.14.0 httpcore==0.16.2 httpx==0.23.1 @@ -30,7 +31,9 @@ pluggy==1.0.0 pre-commit==2.20.0 pyasn1==0.6.0 pyasn1_modules==0.4.0 +pycparser==2.22 pydantic==1.10.2 +PyJWT==2.10.1 pyparsing==3.0.9 pytest==7.2.0 pytest-asyncio==0.23.7 diff --git a/tests/conftest.py b/tests/conftest.py index cdce163..4e12217 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,23 +28,21 @@ def disable_auth(monkeypatch): @pytest.fixture() -def mock_verify_and_extract_token(): +def mock_verify_token(): """Mock a successful token verification that does not raise any exceptions.""" - def _verify_and_extract_token(token): + def _verify_token(token): return None - return _verify_and_extract_token + return _verify_token @pytest.fixture() -def set_mock_verify_and_extract_token( - monkeypatch, mock_verify_and_extract_token -): +def set_mock_verify_token(monkeypatch, mock_verify_token): """Set the verify_token function to a mock that does not raise any exceptions.""" monkeypatch.setattr( - "app.api.routers.query.verify_and_extract_token", - mock_verify_and_extract_token, + "app.api.routers.query.verify_token", + mock_verify_token, ) diff --git a/tests/test_query.py b/tests/test_query.py index c06d9c1..5891005 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -21,7 +21,7 @@ def test_partial_node_failure_responses_handled_gracefully( set_valid_test_federation_nodes, mocked_single_matching_dataset_result, mock_token, - set_mock_verify_and_extract_token, + set_mock_verify_token, caplog, ): """ @@ -101,7 +101,7 @@ def test_partial_node_request_failures_handled_gracefully( set_valid_test_federation_nodes, mocked_single_matching_dataset_result, mock_token, - set_mock_verify_and_extract_token, + set_mock_verify_token, error_to_raise, expected_node_message, caplog, @@ -155,7 +155,7 @@ def test_all_nodes_failure_handled_gracefully( test_app, mock_failed_connection_httpx_get, mock_token, - set_mock_verify_and_extract_token, + set_mock_verify_token, set_valid_test_federation_nodes, caplog, ): @@ -193,7 +193,7 @@ def test_all_nodes_success_handled_gracefully( set_valid_test_federation_nodes, mocked_single_matching_dataset_result, mock_token, - set_mock_verify_and_extract_token, + set_mock_verify_token, ): """ Test that when queries sent to all nodes succeed, the federation API response includes an overall success status and no errors. diff --git a/tests/test_security.py b/tests/test_security.py index 176d537..4957f2f 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -1,8 +1,7 @@ import pytest from fastapi import HTTPException -from google.oauth2 import id_token -from app.api.security import verify_and_extract_token +from app.api.security import extract_token, verify_token def test_missing_client_id_raises_error_when_auth_enabled( @@ -38,10 +37,12 @@ def test_missing_client_id_ignored_when_auth_disabled(monkeypatch, test_app): "invalid_token", ["Bearer faketoken", "Bearer", "faketoken", "fakescheme faketoken"], ) -def test_invalid_token_raises_error(invalid_token): +def test_invalid_token_raises_error(monkeypatch, invalid_token): """Test that an invalid token raises an error from the verification process.""" + monkeypatch.setattr("app.api.security.CLIENT_ID", "foo.id") + with pytest.raises(HTTPException) as exc_info: - verify_and_extract_token(invalid_token) + verify_token(invalid_token) assert exc_info.value.status_code == 401 assert "Invalid token" in exc_info.value.detail @@ -53,7 +54,7 @@ def test_invalid_token_raises_error(invalid_token): ) def test_query_with_malformed_auth_header_fails( test_app, - set_mock_verify_and_extract_token, + set_mock_verify_token, enable_auth, invalid_auth_header, monkeypatch, @@ -72,33 +73,38 @@ def test_query_with_malformed_auth_header_fails( assert response.status_code == 403 -def test_verified_token_returned_without_auth_scheme(monkeypatch, enable_auth): +def test_token_returned_without_auth_scheme(monkeypatch, enable_auth): """ Test that when a token is valid, verify_token correctly returns the token with the authorization scheme stripped. """ mock_valid_token = "Bearer foo" - mock_id_info = { - "iss": "https://accounts.google.com", - "azp": "123abc.apps.googleusercontent.com", - "aud": "123abc.apps.googleusercontent.com", - "sub": "1234567890", - "email": "jane.doe@gmail.com", - "email_verified": True, - "nbf": 1730476622, - "name": "Jane Doe", - "picture": "https://lh3.googleusercontent.com/a/example1234567890", - "given_name": "Jane", - "family_name": "Doe", - "iat": 1730476922, - "exp": 1730480522, - "jti": "123e4567-e89b", - } - - def mock_oauth2_verify_token(param, request, client_id, **kwargs): - return mock_id_info + assert extract_token(mock_valid_token) == "foo" + + +def test_valid_token_does_not_error_out(monkeypatch, enable_auth): + """ + Test that when a valid token is passed to verify_token, the token is returned without errors. + """ + def mock_get_signing_key_from_jwt(*args, **kwargs): + # NOTE: The actual get_signing_key_from_jwt method should return a key object + return "signingkey" + + def mock_jwt_decode(*args, **kwargs): + return { + "iss": "https://myissuer.com", + "aud": "123abc.myapp.com", + "sub": "1234567890", + "name": "John Doe", + "iat": 1730476922, + "exp": 1730480522, + } + + monkeypatch.setattr("app.api.security.CLIENT_ID", "123abc.myapp.com") monkeypatch.setattr( - id_token, "verify_oauth2_token", mock_oauth2_verify_token + "app.api.security.JWKS_CLIENT.get_signing_key_from_jwt", + mock_get_signing_key_from_jwt, ) + monkeypatch.setattr("app.api.security.jwt.decode", mock_jwt_decode) - assert verify_and_extract_token(mock_valid_token) == "foo" + assert verify_token("Bearer myvalidtoken") == "myvalidtoken"