From 3445c35001707aecfcb427a4d8bcfeb0764a877f Mon Sep 17 00:00:00 2001 From: Serhii Buniak Date: Thu, 11 Nov 2021 19:18:28 +0200 Subject: [PATCH] Verify claims before signature, thus reduce http calls. --- okta_jwt_verifier/jwt_verifier.py | 12 ++++++------ tests/unit/test_jwt_verifier.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/okta_jwt_verifier/jwt_verifier.py b/okta_jwt_verifier/jwt_verifier.py index 23481ae..d5b36aa 100644 --- a/okta_jwt_verifier/jwt_verifier.py +++ b/okta_jwt_verifier/jwt_verifier.py @@ -89,12 +89,12 @@ async def verify_access_token(self, token, claims_to_verify=('iss', 'aud', 'exp' if headers.get('alg') != 'RS256': raise JWTValidationException('Header claim "alg" is invalid.') - okta_jwk = await self.get_jwk(headers['kid']) - self.verify_signature(token, okta_jwk) - self.verify_claims(claims, claims_to_verify=claims_to_verify, leeway=self.leeway) + + okta_jwk = await self.get_jwk(headers['kid']) + self.verify_signature(token, okta_jwk) except JWTValidationException: raise except Exception as err: @@ -125,13 +125,13 @@ async def verify_id_token(self, token, claims_to_verify=('iss', 'exp'), nonce=No if headers.get('alg') != 'RS256': raise JWTValidationException('Header claim "alg" is invalid.') - okta_jwk = await self.get_jwk(headers['kid']) - self.verify_signature(token, okta_jwk) - self.verify_claims(claims, claims_to_verify=claims_to_verify, leeway=self.leeway) + okta_jwk = await self.get_jwk(headers['kid']) + self.verify_signature(token, okta_jwk) + # verify client_id and nonce self.verify_client_id(claims['aud']) if 'nonce' in claims and claims['nonce'] != nonce: diff --git a/tests/unit/test_jwt_verifier.py b/tests/unit/test_jwt_verifier.py index e55b9d8..3693b63 100644 --- a/tests/unit/test_jwt_verifier.py +++ b/tests/unit/test_jwt_verifier.py @@ -188,6 +188,36 @@ def test_verify_claims_invalid(): jwt_verifier.verify_claims(claims, ('iss', 'aud', 'exp')) +@pytest.mark.asyncio +async def test_invalid_claims_fail_first(mocker): + """Check if claims are invalid, exception is raised and no network call is needed.""" + client_id = 'test_client_id' + audience = 'api://default' + headers = {'alg': 'RS256', 'kid': 'test_kid'} + iss_time = time.time() + claims = {'ver': 1, + 'jti': 'test_jti_str', + 'iss': 'https://test_issuer.com', + 'aud': audience, + 'iat': iss_time, + 'exp': iss_time+300, + 'cid': client_id, + 'uid': 'test_uid', + 'scp': ['openid'], + 'sub': 'test_jwt@okta.com'} + signing_input = 'test_signing_input' + signature = 'test_signature' + mock_parse_token = lambda token: (headers, claims, signing_input, signature) + mocker.patch('okta_jwt_verifier.jwt_utils.JWTUtils.parse_token', mock_parse_token) + + token = 'test_token' + issuer = 'https://invalid_issuer.com' + jwt_verifier = AccessTokenVerifier(issuer) + with pytest.raises(JWTValidationException) as err: + await jwt_verifier.verify(token) + assert str(err.value) == 'Invalid issuer' + + def test_verify_claims_missing_claim(): """Check if method verify_claims raises an exception if required claim is missing.""" client_id = 'test_client_id'