Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REF] Manually verify ID token using PyJWT instead of google_auth #139

Merged
merged 10 commits into from
Dec 2, 2024
2 changes: 1 addition & 1 deletion app/api/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async def get(
node_urls : list[str]
List of Neurobagel nodes to send the query to.
token : str, optional
Google ID token for authentication, by default None
ID token for authentication, by default None

Returns
-------
Expand Down
8 changes: 5 additions & 3 deletions app/api/routers/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,20 @@

from .. import crud, security
from ..models import CombinedQueryResponse, QueryModel
from ..security import verify_and_extract_token
from ..security import verify_token

# from fastapi.security import open_id_connect_url


router = APIRouter(prefix="/query", tags=["query"])

# Adapted from info in https://github.com/tiangolo/fastapi/discussions/9137#discussioncomment-5157382
# I believe for us this is purely for documentation (i.e., a nice-looking interactive API docs page),
# and doesn't actually have any bearing on the ID token validation process.
oauth2_scheme = OAuth2(
flows={
"implicit": {
"authorizationUrl": "https://accounts.google.com/o/oauth2/auth",
"authorizationUrl": "https://neurobagel.ca.auth0.com/authorize",
}
},
# Don't automatically error out when request is not authenticated, to support optional authentication
Expand Down Expand Up @@ -61,7 +63,7 @@ async def get_query(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authenticated",
)
token = verify_and_extract_token(token)
token = verify_token(token)

response_dict = await crud.get(
query.min_age,
Expand Down
60 changes: 41 additions & 19 deletions app/api/security.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,67 @@
import os

import jwt
from fastapi import HTTPException, status
from fastapi.security.utils import get_authorization_scheme_param
from google.auth.exceptions import GoogleAuthError
from google.auth.transport import requests
from google.oauth2 import id_token
from jwt import PyJWKClient, PyJWTError

AUTH_ENABLED = os.environ.get("NB_ENABLE_AUTH", "True").lower() == "true"
CLIENT_ID = os.environ.get("NB_QUERY_CLIENT_ID", None)

KEYS_URL = "https://neurobagel.ca.auth0.com/.well-known/jwks.json"
ISSUER = "https://neurobagel.ca.auth0.com/"
# We only need to define the JWKS client once because get_signing_key_from_jwt will handle key rotations
# by automatically fetching updated keys when needed
# See https://github.com/jpadilla/pyjwt/blob/3ebbb22f30f2b1b41727b269a08b427e9a85d6bb/jwt/jwks_client.py#L96-L115
JWKS_CLIENT = PyJWKClient(KEYS_URL)


def check_client_id():
"""Check if the CLIENT_ID environment variable is set."""
# By default, if CLIENT_ID is not provided to verify_oauth2_token,
# Google will simply skip verifying the audience claim of ID tokens.
# This however can be a security risk, so we mandate that CLIENT_ID is set.
# The CLIENT_ID is needed to verify the audience claim of ID tokens.
if AUTH_ENABLED and CLIENT_ID is None:
raise ValueError(
"Authentication has been enabled (NB_ENABLE_AUTH) but the environment variable NB_QUERY_CLIENT_ID is not set. "
"Please set NB_QUERY_CLIENT_ID to the Google client ID for your Neurobagel query tool deployment, to verify the audience claim of ID tokens."
"Please set NB_QUERY_CLIENT_ID to the client ID for your Neurobagel query tool deployment, to verify the audience claim of ID tokens."
)


def verify_and_extract_token(token: str) -> str:
def extract_token(token: str) -> str:
alyssadai marked this conversation as resolved.
Show resolved Hide resolved
"""
Extract the token from the authorization header.
This ensures that it is passed on to downstream APIs without the authorization scheme.
"""
# Extract the token from the "Bearer" scheme
# (See https://github.com/tiangolo/fastapi/blob/master/fastapi/security/oauth2.py#L473-L485)
# TODO: Check also if scheme of token is "Bearer"?
_, extracted_token = get_authorization_scheme_param(token)
return extracted_token


def verify_token(token: str) -> str:
"""
Verify and return the Google ID token with the authorization scheme stripped.
Verify the ID token against the identity provider public keys, and return the token with the authorization scheme stripped.
Raise an HTTPException if the token is invalid.
"""
# Adapted from https://developers.google.com/identity/gsi/web/guides/verify-google-id-token#python
try:
# Extract the token from the "Bearer" scheme
# (See https://github.com/tiangolo/fastapi/blob/master/fastapi/security/oauth2.py#L473-L485)
# TODO: Check also if scheme of token is "Bearer"?
_, extracted_token = get_authorization_scheme_param(token)
id_info = id_token.verify_oauth2_token(
extracted_token, requests.Request(), CLIENT_ID
extracted_token = extract_token(token)
surchs marked this conversation as resolved.
Show resolved Hide resolved
# Determine which key was used to sign the token
# Adapted from https://pyjwt.readthedocs.io/en/stable/usage.html#retrieve-rsa-signing-keys-from-a-jwks-endpoint
signing_key = JWKS_CLIENT.get_signing_key_from_jwt(extracted_token)

# https://pyjwt.readthedocs.io/en/stable/api.html#jwt.decode
jwt.decode(
jwt=extracted_token,
key=signing_key,
options={
"verify_signature": True,
"require": ["aud", "iss", "exp", "iat"],
alyssadai marked this conversation as resolved.
Show resolved Hide resolved
},
audience=CLIENT_ID,
issuer=ISSUER,
)
# TODO: Remove print statement or turn into logging
print("Token verified: ", id_info)
return extracted_token
except (GoogleAuthError, ValueError) as exc:
except (PyJWTError, ValueError) as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Invalid token: {exc}",
Expand Down
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@ anyio==3.6.2
attrs==23.1.0
cachetools==5.3.3
certifi==2023.7.22
cffi==1.17.1
cfgv==3.3.1
charset-normalizer==3.3.2
click==8.1.3
colorama==0.4.6
coverage==7.0.0
cryptography==44.0.0
distlib==0.3.6
exceptiongroup==1.0.4
fastapi==0.110.1
filelock==3.8.0
google-auth==2.32.0
h11==0.14.0
httpcore==0.16.2
httpx==0.23.1
Expand All @@ -30,7 +31,9 @@ pluggy==1.0.0
pre-commit==2.20.0
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycparser==2.22
pydantic==1.10.2
PyJWT==2.10.1
pyparsing==3.0.9
pytest==7.2.0
pytest-asyncio==0.23.7
Expand Down
14 changes: 6 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,21 @@ def disable_auth(monkeypatch):


@pytest.fixture()
def mock_verify_and_extract_token():
def mock_verify_token():
"""Mock a successful token verification that does not raise any exceptions."""

def _verify_and_extract_token(token):
def _verify_token(token):
return None

return _verify_and_extract_token
return _verify_token


@pytest.fixture()
def set_mock_verify_and_extract_token(
monkeypatch, mock_verify_and_extract_token
):
def set_mock_verify_token(monkeypatch, mock_verify_token):
"""Set the verify_token function to a mock that does not raise any exceptions."""
monkeypatch.setattr(
"app.api.routers.query.verify_and_extract_token",
mock_verify_and_extract_token,
"app.api.routers.query.verify_token",
mock_verify_token,
)


Expand Down
8 changes: 4 additions & 4 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_partial_node_failure_responses_handled_gracefully(
set_valid_test_federation_nodes,
mocked_single_matching_dataset_result,
mock_token,
set_mock_verify_and_extract_token,
set_mock_verify_token,
caplog,
):
"""
Expand Down Expand Up @@ -101,7 +101,7 @@ def test_partial_node_request_failures_handled_gracefully(
set_valid_test_federation_nodes,
mocked_single_matching_dataset_result,
mock_token,
set_mock_verify_and_extract_token,
set_mock_verify_token,
error_to_raise,
expected_node_message,
caplog,
Expand Down Expand Up @@ -155,7 +155,7 @@ def test_all_nodes_failure_handled_gracefully(
test_app,
mock_failed_connection_httpx_get,
mock_token,
set_mock_verify_and_extract_token,
set_mock_verify_token,
set_valid_test_federation_nodes,
caplog,
):
Expand Down Expand Up @@ -193,7 +193,7 @@ def test_all_nodes_success_handled_gracefully(
set_valid_test_federation_nodes,
mocked_single_matching_dataset_result,
mock_token,
set_mock_verify_and_extract_token,
set_mock_verify_token,
):
"""
Test that when queries sent to all nodes succeed, the federation API response includes an overall success status and no errors.
Expand Down
60 changes: 33 additions & 27 deletions tests/test_security.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pytest
from fastapi import HTTPException
from google.oauth2 import id_token

from app.api.security import verify_and_extract_token
from app.api.security import extract_token, verify_token


def test_missing_client_id_raises_error_when_auth_enabled(
Expand Down Expand Up @@ -38,10 +37,12 @@ def test_missing_client_id_ignored_when_auth_disabled(monkeypatch, test_app):
"invalid_token",
["Bearer faketoken", "Bearer", "faketoken", "fakescheme faketoken"],
)
def test_invalid_token_raises_error(invalid_token):
def test_invalid_token_raises_error(monkeypatch, invalid_token):
"""Test that an invalid token raises an error from the verification process."""
monkeypatch.setattr("app.api.security.CLIENT_ID", "foo.id")

with pytest.raises(HTTPException) as exc_info:
verify_and_extract_token(invalid_token)
verify_token(invalid_token)

assert exc_info.value.status_code == 401
assert "Invalid token" in exc_info.value.detail
Expand All @@ -53,7 +54,7 @@ def test_invalid_token_raises_error(invalid_token):
)
def test_query_with_malformed_auth_header_fails(
test_app,
set_mock_verify_and_extract_token,
set_mock_verify_token,
enable_auth,
invalid_auth_header,
monkeypatch,
Expand All @@ -72,33 +73,38 @@ def test_query_with_malformed_auth_header_fails(
assert response.status_code == 403
surchs marked this conversation as resolved.
Show resolved Hide resolved


def test_verified_token_returned_without_auth_scheme(monkeypatch, enable_auth):
def test_token_returned_without_auth_scheme(monkeypatch, enable_auth):
"""
Test that extract_token correctly returns the token with the authorization scheme stripped.
"""
mock_valid_token = "Bearer foo"
mock_id_info = {
"iss": "https://accounts.google.com",
"azp": "123abc.apps.googleusercontent.com",
"aud": "123abc.apps.googleusercontent.com",
"sub": "1234567890",
"email": "[email protected]",
"email_verified": True,
"nbf": 1730476622,
"name": "Jane Doe",
"picture": "https://lh3.googleusercontent.com/a/example1234567890",
"given_name": "Jane",
"family_name": "Doe",
"iat": 1730476922,
"exp": 1730480522,
"jti": "123e4567-e89b",
}

def mock_oauth2_verify_token(param, request, client_id, **kwargs):
return mock_id_info
assert extract_token(mock_valid_token) == "foo"


def test_valid_token_does_not_error_out(monkeypatch, enable_auth):
"""
Test that when a valid token is passed to verify_token, the token is returned without errors.
"""

def mock_get_signing_key_from_jwt(*args, **kwargs):
# NOTE: The actual get_signing_key_from_jwt method should return a key object
return "signingkey"

def mock_jwt_decode(*args, **kwargs):
return {
"iss": "https://myissuer.com",
"aud": "123abc.myapp.com",
"sub": "1234567890",
"name": "John Doe",
"iat": 1730476922,
"exp": 1730480522,
}

monkeypatch.setattr("app.api.security.CLIENT_ID", "123abc.myapp.com")
monkeypatch.setattr(
id_token, "verify_oauth2_token", mock_oauth2_verify_token
"app.api.security.JWKS_CLIENT.get_signing_key_from_jwt",
mock_get_signing_key_from_jwt,
)
monkeypatch.setattr("app.api.security.jwt.decode", mock_jwt_decode)

assert verify_and_extract_token(mock_valid_token) == "foo"
assert verify_token("Bearer myvalidtoken") == "myvalidtoken"
Loading