From 8a9af835b88cd000d7a6f04882bda7620e20aa98 Mon Sep 17 00:00:00 2001 From: Alyssa Dai Date: Fri, 29 Nov 2024 13:44:01 -0500 Subject: [PATCH] refactor our JKWS client into a constant --- app/api/security.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/app/api/security.py b/app/api/security.py index 5f4f492..dee3226 100644 --- a/app/api/security.py +++ b/app/api/security.py @@ -8,6 +8,13 @@ AUTH_ENABLED = os.environ.get("NB_ENABLE_AUTH", "True").lower() == "true" CLIENT_ID = os.environ.get("NB_QUERY_CLIENT_ID", None) +KEYS_URL = "https://neurobagel.ca.auth0.com/.well-known/jwks.json" +ISSUER = "https://neurobagel.ca.auth0.com/" +# We only need to define the JWKS client once because get_signing_key_from_jwt will handle key rotations +# by automatically fetching updated keys when needed +# See https://github.com/jpadilla/pyjwt/blob/3ebbb22f30f2b1b41727b269a08b427e9a85d6bb/jwt/jwks_client.py#L96-L115 +JWKS_CLIENT = PyJWKClient(KEYS_URL) + def check_client_id(): """Check if the CLIENT_ID environment variable is set.""" @@ -36,15 +43,11 @@ def verify_token(token: str) -> str: Verify the ID token against the identity provider public keys, and return the token with the authorization scheme stripped. Raise an HTTPException if the token is invalid. """ - keys_url = "https://neurobagel.ca.auth0.com/.well-known/jwks.json" - issuer = "https://neurobagel.ca.auth0.com/" - try: extracted_token = extract_token(token) # Determine which key was used to sign the token # Adapted from https://pyjwt.readthedocs.io/en/stable/usage.html#retrieve-rsa-signing-keys-from-a-jwks-endpoint - jwks_client = PyJWKClient(keys_url) - signing_key = jwks_client.get_signing_key_from_jwt(extracted_token) + signing_key = JWKS_CLIENT.get_signing_key_from_jwt(extracted_token) # https://pyjwt.readthedocs.io/en/stable/api.html#jwt.decode id_info = jwt.decode( @@ -55,7 +58,7 @@ def verify_token(token: str) -> str: "require": ["aud", "iss", "exp", "iat"], }, audience=CLIENT_ID, - issuer=issuer, + issuer=ISSUER, ) # TODO: Remove print statement or turn into logging