Skip to content

Commit

Permalink
refactor our JKWS client into a constant
Browse files Browse the repository at this point in the history
  • Loading branch information
alyssadai committed Nov 29, 2024
1 parent 2566f1d commit 8a9af83
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions app/api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
AUTH_ENABLED = os.environ.get("NB_ENABLE_AUTH", "True").lower() == "true"
CLIENT_ID = os.environ.get("NB_QUERY_CLIENT_ID", None)

KEYS_URL = "https://neurobagel.ca.auth0.com/.well-known/jwks.json"
ISSUER = "https://neurobagel.ca.auth0.com/"
# We only need to define the JWKS client once because get_signing_key_from_jwt will handle key rotations
# by automatically fetching updated keys when needed
# See https://github.com/jpadilla/pyjwt/blob/3ebbb22f30f2b1b41727b269a08b427e9a85d6bb/jwt/jwks_client.py#L96-L115
JWKS_CLIENT = PyJWKClient(KEYS_URL)


def check_client_id():
"""Check if the CLIENT_ID environment variable is set."""
Expand Down Expand Up @@ -36,15 +43,11 @@ def verify_token(token: str) -> str:
Verify the ID token against the identity provider public keys, and return the token with the authorization scheme stripped.
Raise an HTTPException if the token is invalid.
"""
keys_url = "https://neurobagel.ca.auth0.com/.well-known/jwks.json"
issuer = "https://neurobagel.ca.auth0.com/"

try:
extracted_token = extract_token(token)
# Determine which key was used to sign the token
# Adapted from https://pyjwt.readthedocs.io/en/stable/usage.html#retrieve-rsa-signing-keys-from-a-jwks-endpoint
jwks_client = PyJWKClient(keys_url)
signing_key = jwks_client.get_signing_key_from_jwt(extracted_token)
signing_key = JWKS_CLIENT.get_signing_key_from_jwt(extracted_token)

# https://pyjwt.readthedocs.io/en/stable/api.html#jwt.decode
id_info = jwt.decode(

Check warning on line 53 in app/api/security.py

View check run for this annotation

Codecov / codecov/patch

app/api/security.py#L53

Added line #L53 was not covered by tests
Expand All @@ -55,7 +58,7 @@ def verify_token(token: str) -> str:
"require": ["aud", "iss", "exp", "iat"],
},
audience=CLIENT_ID,
issuer=issuer,
issuer=ISSUER,
)

# TODO: Remove print statement or turn into logging
Expand Down

0 comments on commit 8a9af83

Please sign in to comment.