Skip to content

Commit

Permalink
[REF] Manually verify ID token using PyJWT instead of google_auth (#139)
Browse files Browse the repository at this point in the history
* manually verify ID token using PyJWT instead of google_auth

* update issuer and keys URL for auth0, update comments

* update authorization endpoint for oauth2_scheme docs

* refactor out token extraction from verification

* update tests

* remove google-auth from dependencies

* update docstring

* refactor out JWKS client into a constant

* add positive unit test for verify_token

* remove print statement
  • Loading branch information
alyssadai authored Dec 2, 2024
1 parent c4d68a4 commit 9c32acb
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 63 deletions.
2 changes: 1 addition & 1 deletion app/api/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async def get(
node_urls : list[str]
List of Neurobagel nodes to send the query to.
token : str, optional
Google ID token for authentication, by default None
ID token for authentication, by default None
Returns
-------
Expand Down
8 changes: 5 additions & 3 deletions app/api/routers/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,20 @@

from .. import crud, security
from ..models import CombinedQueryResponse, QueryModel
from ..security import verify_and_extract_token
from ..security import verify_token

# from fastapi.security import open_id_connect_url


router = APIRouter(prefix="/query", tags=["query"])

# Adapted from info in https://github.com/tiangolo/fastapi/discussions/9137#discussioncomment-5157382
# I believe for us this is purely for documentation/a nice-looking interactive API docs page,
# and doesn't actually have any bearing on the ID token validation process.
oauth2_scheme = OAuth2(
flows={
"implicit": {
"authorizationUrl": "https://accounts.google.com/o/oauth2/auth",
"authorizationUrl": "https://neurobagel.ca.auth0.com/authorize",
}
},
# Don't automatically error out when request is not authenticated, to support optional authentication
Expand Down Expand Up @@ -61,7 +63,7 @@ async def get_query(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authenticated",
)
token = verify_and_extract_token(token)
token = verify_token(token)

response_dict = await crud.get(
query.min_age,
Expand Down
60 changes: 41 additions & 19 deletions app/api/security.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,67 @@
import os

import jwt
from fastapi import HTTPException, status
from fastapi.security.utils import get_authorization_scheme_param
from google.auth.exceptions import GoogleAuthError
from google.auth.transport import requests
from google.oauth2 import id_token
from jwt import PyJWKClient, PyJWTError

AUTH_ENABLED = os.environ.get("NB_ENABLE_AUTH", "True").lower() == "true"
CLIENT_ID = os.environ.get("NB_QUERY_CLIENT_ID", None)

KEYS_URL = "https://neurobagel.ca.auth0.com/.well-known/jwks.json"
ISSUER = "https://neurobagel.ca.auth0.com/"
# We only need to define the JWKS client once because get_signing_key_from_jwt will handle key rotations
# by automatically fetching updated keys when needed
# See https://github.com/jpadilla/pyjwt/blob/3ebbb22f30f2b1b41727b269a08b427e9a85d6bb/jwt/jwks_client.py#L96-L115
JWKS_CLIENT = PyJWKClient(KEYS_URL)


def check_client_id():
"""Check if the CLIENT_ID environment variable is set."""
# By default, if CLIENT_ID is not provided to verify_oauth2_token,
# Google will simply skip verifying the audience claim of ID tokens.
# This however can be a security risk, so we mandate that CLIENT_ID is set.
# The CLIENT_ID is needed to verify the audience claim of ID tokens.
if AUTH_ENABLED and CLIENT_ID is None:
raise ValueError(
"Authentication has been enabled (NB_ENABLE_AUTH) but the environment variable NB_QUERY_CLIENT_ID is not set. "
"Please set NB_QUERY_CLIENT_ID to the Google client ID for your Neurobagel query tool deployment, to verify the audience claim of ID tokens."
"Please set NB_QUERY_CLIENT_ID to the client ID for your Neurobagel query tool deployment, to verify the audience claim of ID tokens."
)


def verify_and_extract_token(token: str) -> str:
def extract_token(token: str) -> str:
    """
    Return the token portion of an authorization header value.

    The authorization scheme is dropped so that the bare token can be
    passed on to downstream APIs.
    """
    # Split the header value into its scheme (e.g., "Bearer") and the token itself
    # (See https://github.com/tiangolo/fastapi/blob/master/fastapi/security/oauth2.py#L473-L485)
    # TODO: Check also if scheme of token is "Bearer"?
    scheme_and_token = get_authorization_scheme_param(token)
    return scheme_and_token[1]


def verify_token(token: str) -> str:
"""
Verify and return the Google ID token with the authorization scheme stripped.
Verify the ID token against the identity provider public keys, and return the token with the authorization scheme stripped.
Raise an HTTPException if the token is invalid.
"""
# Adapted from https://developers.google.com/identity/gsi/web/guides/verify-google-id-token#python
try:
# Extract the token from the "Bearer" scheme
# (See https://github.com/tiangolo/fastapi/blob/master/fastapi/security/oauth2.py#L473-L485)
# TODO: Check also if scheme of token is "Bearer"?
_, extracted_token = get_authorization_scheme_param(token)
id_info = id_token.verify_oauth2_token(
extracted_token, requests.Request(), CLIENT_ID
extracted_token = extract_token(token)
# Determine which key was used to sign the token
# Adapted from https://pyjwt.readthedocs.io/en/stable/usage.html#retrieve-rsa-signing-keys-from-a-jwks-endpoint
signing_key = JWKS_CLIENT.get_signing_key_from_jwt(extracted_token)

# https://pyjwt.readthedocs.io/en/stable/api.html#jwt.decode
jwt.decode(
jwt=extracted_token,
key=signing_key,
options={
"verify_signature": True,
"require": ["aud", "iss", "exp", "iat"],
},
audience=CLIENT_ID,
issuer=ISSUER,
)
# TODO: Remove print statement or turn into logging
print("Token verified: ", id_info)
return extracted_token
except (GoogleAuthError, ValueError) as exc:
except (PyJWTError, ValueError) as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Invalid token: {exc}",
Expand Down
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@ anyio==3.6.2
attrs==23.1.0
cachetools==5.3.3
certifi==2023.7.22
cffi==1.17.1
cfgv==3.3.1
charset-normalizer==3.3.2
click==8.1.3
colorama==0.4.6
coverage==7.0.0
cryptography==44.0.0
distlib==0.3.6
exceptiongroup==1.0.4
fastapi==0.110.1
filelock==3.8.0
google-auth==2.32.0
h11==0.14.0
httpcore==0.16.2
httpx==0.23.1
Expand All @@ -30,7 +31,9 @@ pluggy==1.0.0
pre-commit==2.20.0
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycparser==2.22
pydantic==1.10.2
PyJWT==2.10.1
pyparsing==3.0.9
pytest==7.2.0
pytest-asyncio==0.23.7
Expand Down
14 changes: 6 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,21 @@ def disable_auth(monkeypatch):


@pytest.fixture()
def mock_verify_and_extract_token():
def mock_verify_token():
"""Mock a successful token verification that does not raise any exceptions."""

def _verify_and_extract_token(token):
def _verify_token(token):
return None

return _verify_and_extract_token
return _verify_token


@pytest.fixture()
def set_mock_verify_and_extract_token(
monkeypatch, mock_verify_and_extract_token
):
def set_mock_verify_token(monkeypatch, mock_verify_token):
"""Set the verify_token function to a mock that does not raise any exceptions."""
monkeypatch.setattr(
"app.api.routers.query.verify_and_extract_token",
mock_verify_and_extract_token,
"app.api.routers.query.verify_token",
mock_verify_token,
)


Expand Down
8 changes: 4 additions & 4 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_partial_node_failure_responses_handled_gracefully(
set_valid_test_federation_nodes,
mocked_single_matching_dataset_result,
mock_token,
set_mock_verify_and_extract_token,
set_mock_verify_token,
caplog,
):
"""
Expand Down Expand Up @@ -101,7 +101,7 @@ def test_partial_node_request_failures_handled_gracefully(
set_valid_test_federation_nodes,
mocked_single_matching_dataset_result,
mock_token,
set_mock_verify_and_extract_token,
set_mock_verify_token,
error_to_raise,
expected_node_message,
caplog,
Expand Down Expand Up @@ -155,7 +155,7 @@ def test_all_nodes_failure_handled_gracefully(
test_app,
mock_failed_connection_httpx_get,
mock_token,
set_mock_verify_and_extract_token,
set_mock_verify_token,
set_valid_test_federation_nodes,
caplog,
):
Expand Down Expand Up @@ -193,7 +193,7 @@ def test_all_nodes_success_handled_gracefully(
set_valid_test_federation_nodes,
mocked_single_matching_dataset_result,
mock_token,
set_mock_verify_and_extract_token,
set_mock_verify_token,
):
"""
Test that when queries sent to all nodes succeed, the federation API response includes an overall success status and no errors.
Expand Down
60 changes: 33 additions & 27 deletions tests/test_security.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pytest
from fastapi import HTTPException
from google.oauth2 import id_token

from app.api.security import verify_and_extract_token
from app.api.security import extract_token, verify_token


def test_missing_client_id_raises_error_when_auth_enabled(
Expand Down Expand Up @@ -38,10 +37,12 @@ def test_missing_client_id_ignored_when_auth_disabled(monkeypatch, test_app):
"invalid_token",
["Bearer faketoken", "Bearer", "faketoken", "fakescheme faketoken"],
)
def test_invalid_token_raises_error(invalid_token):
def test_invalid_token_raises_error(monkeypatch, invalid_token):
"""Test that an invalid token raises an error from the verification process."""
monkeypatch.setattr("app.api.security.CLIENT_ID", "foo.id")

with pytest.raises(HTTPException) as exc_info:
verify_and_extract_token(invalid_token)
verify_token(invalid_token)

assert exc_info.value.status_code == 401
assert "Invalid token" in exc_info.value.detail
Expand All @@ -53,7 +54,7 @@ def test_invalid_token_raises_error(invalid_token):
)
def test_query_with_malformed_auth_header_fails(
test_app,
set_mock_verify_and_extract_token,
set_mock_verify_token,
enable_auth,
invalid_auth_header,
monkeypatch,
Expand All @@ -72,33 +73,38 @@ def test_query_with_malformed_auth_header_fails(
assert response.status_code == 403


def test_verified_token_returned_without_auth_scheme(monkeypatch, enable_auth):
def test_token_returned_without_auth_scheme(monkeypatch, enable_auth):
"""
Test that when a token is valid, verify_token correctly returns the token with the authorization scheme stripped.
"""
mock_valid_token = "Bearer foo"
mock_id_info = {
"iss": "https://accounts.google.com",
"azp": "123abc.apps.googleusercontent.com",
"aud": "123abc.apps.googleusercontent.com",
"sub": "1234567890",
"email": "[email protected]",
"email_verified": True,
"nbf": 1730476622,
"name": "Jane Doe",
"picture": "https://lh3.googleusercontent.com/a/example1234567890",
"given_name": "Jane",
"family_name": "Doe",
"iat": 1730476922,
"exp": 1730480522,
"jti": "123e4567-e89b",
}

def mock_oauth2_verify_token(param, request, client_id, **kwargs):
return mock_id_info
assert extract_token(mock_valid_token) == "foo"


def test_valid_token_does_not_error_out(monkeypatch, enable_auth):
"""
Test that when a valid token is passed to verify_token, the token is returned without errors.
"""

def mock_get_signing_key_from_jwt(*args, **kwargs):
# NOTE: The actual get_signing_key_from_jwt method should return a key object
return "signingkey"

def mock_jwt_decode(*args, **kwargs):
return {
"iss": "https://myissuer.com",
"aud": "123abc.myapp.com",
"sub": "1234567890",
"name": "John Doe",
"iat": 1730476922,
"exp": 1730480522,
}

monkeypatch.setattr("app.api.security.CLIENT_ID", "123abc.myapp.com")
monkeypatch.setattr(
id_token, "verify_oauth2_token", mock_oauth2_verify_token
"app.api.security.JWKS_CLIENT.get_signing_key_from_jwt",
mock_get_signing_key_from_jwt,
)
monkeypatch.setattr("app.api.security.jwt.decode", mock_jwt_decode)

assert verify_and_extract_token(mock_valid_token) == "foo"
assert verify_token("Bearer myvalidtoken") == "myvalidtoken"

0 comments on commit 9c32acb

Please sign in to comment.