Skip to content

Commit

Permalink
Correctly handle token refreshes
Browse files Browse the repository at this point in the history
  • Loading branch information
mccoyp committed Oct 5, 2024
1 parent 0f12762 commit 35f5e8c
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy):
"""

def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None:
# Pass `enable_cae` so `enable_cae=True` is always passed to get_token
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
super().__init__(credential, *scopes, enable_cae=True, **kwargs)
self._credential: AsyncTokenCredential = credential
self._token: Optional[AccessToken] = None
Expand Down Expand Up @@ -159,9 +159,11 @@ async def on_request(self, request: PipelineRequest) -> None:
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
# Exclude tenant for AD FS authentication
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
self._token = await self._credential.get_token(scope)
self._token = await self._credential.get_token(scope, enable_cae=True)
else:
self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id)
self._token = await self._credential.get_token(
scope, tenant_id=challenge.tenant_id, enable_cae=True
)

# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy):
"""

def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None:
# Pass `enable_cae` so `enable_cae=True` is always passed to get_token
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs)
self._credential: TokenCredential = credential
self._token: Optional[AccessToken] = None
Expand Down Expand Up @@ -168,9 +168,9 @@ def on_request(self, request: PipelineRequest) -> None:
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
# Exclude tenant for AD FS authentication
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
self._token = self._credential.get_token(scope)
self._token = self._credential.get_token(scope, enable_cae=True)
else:
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id)
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id, enable_cae=True)

# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy):
"""

def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None:
# Pass `enable_cae` so `enable_cae=True` is always passed to get_token
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
super().__init__(credential, *scopes, enable_cae=True, **kwargs)
self._credential: AsyncTokenCredential = credential
self._token: Optional[AccessToken] = None
Expand Down Expand Up @@ -159,9 +159,11 @@ async def on_request(self, request: PipelineRequest) -> None:
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
# Exclude tenant for AD FS authentication
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
self._token = await self._credential.get_token(scope)
self._token = await self._credential.get_token(scope, enable_cae=True)
else:
self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id)
self._token = await self._credential.get_token(
scope, tenant_id=challenge.tenant_id, enable_cae=True
)

# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy):
"""

def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None:
# Pass `enable_cae` so `enable_cae=True` is always passed to get_token
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs)
self._credential: TokenCredential = credential
self._token: Optional[AccessToken] = None
Expand Down Expand Up @@ -168,9 +168,9 @@ def on_request(self, request: PipelineRequest) -> None:
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
# Exclude tenant for AD FS authentication
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
self._token = self._credential.get_token(scope)
self._token = self._credential.get_token(scope, enable_cae=True)
else:
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id)
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id, enable_cae=True)

# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy):
"""

def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None:
# Pass `enable_cae` so `enable_cae=True` is always passed to get_token
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
super().__init__(credential, *scopes, enable_cae=True, **kwargs)
self._credential: AsyncTokenCredential = credential
self._token: Optional[AccessToken] = None
Expand Down Expand Up @@ -159,9 +159,11 @@ async def on_request(self, request: PipelineRequest) -> None:
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
# Exclude tenant for AD FS authentication
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
self._token = await self._credential.get_token(scope)
self._token = await self._credential.get_token(scope, enable_cae=True)
else:
self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id)
self._token = await self._credential.get_token(
scope, tenant_id=challenge.tenant_id, enable_cae=True
)

# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy):
"""

def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None:
# Pass `enable_cae` so `enable_cae=True` is always passed to get_token
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs)
self._credential: TokenCredential = credential
self._token: Optional[AccessToken] = None
Expand Down Expand Up @@ -168,9 +168,9 @@ def on_request(self, request: PipelineRequest) -> None:
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
# Exclude tenant for AD FS authentication
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
self._token = self._credential.get_token(scope)
self._token = self._credential.get_token(scope, enable_cae=True)
else:
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id)
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id, enable_cae=True)

# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
Expand Down
72 changes: 72 additions & 0 deletions sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,3 +704,75 @@ def get_token(*scopes, **kwargs):
assert credential.get_token.call_count == 2

test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM)


@empty_challenge_cache
def test_cae_token_expiry():
"""The policy should avoid sending claims more than once when a token expires."""

expected_content = b"a duck"

def test_with_challenge(claims_challenge, expected_claim):
first_token = "first_token"
second_token = "second_token"
third_token = "third_token"

class Requests:
count = 0

def send(request):
Requests.count += 1
if Requests.count == 1:
# first request should be unauthorized and have no content; triggers a KV challenge response
assert not request.body
assert "Authorization" not in request.headers
assert request.headers["Content-Length"] == "0"
return KV_CHALLENGE_RESPONSE
elif Requests.count == 2:
# second request will trigger a CAE challenge response in this test scenario
assert request.headers["Content-Length"]
assert request.body == expected_content
assert first_token in request.headers["Authorization"]
return claims_challenge
elif Requests.count == 3:
# third request should include the required claims and correctly use content from the first challenge
assert request.headers["Content-Length"]
assert request.body == expected_content
assert second_token in request.headers["Authorization"]
return Mock(status_code=200)
elif Requests.count == 4:
# fourth request should not include claims, but otherwise use content from the first challenge
assert request.headers["Content-Length"]
assert request.body == expected_content
assert third_token in request.headers["Authorization"]
return Mock(status_code=200)
raise ValueError("unexpected request")

def get_token(*scopes, **kwargs):
assert kwargs.get("enable_cae") == True
assert kwargs.get("tenant_id") == KV_CHALLENGE_TENANT
assert scopes[0] == RESOURCE + "/.default"
# Response to KV challenge
if Requests.count == 1:
assert kwargs.get("claims") == None
return AccessToken(first_token, time.time() + 3600)
# Response to first CAE challenge
elif Requests.count == 2:
assert kwargs.get("claims") == expected_claim
return AccessToken(second_token, 0) # Return a token that expires immediately to trigger a refresh
# Token refresh before making the final request
elif Requests.count == 3:
assert kwargs.get("claims") == None
return AccessToken(third_token, time.time() + 3600)

credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token))
pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send))
request = HttpRequest("POST", get_random_url())
request.set_bytes_body(expected_content)
pipeline.run(request)
pipeline.run(request) # Send the request again to trigger a token refresh upon expiry

# get_token is called for the KV and CAE challenges, as well as for the token refresh
assert credential.get_token.call_count == 3

test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM)
Original file line number Diff line number Diff line change
Expand Up @@ -661,3 +661,76 @@ async def get_token(*scopes, **kwargs):
assert credential.get_token.call_count == 2

await test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM)


@pytest.mark.asyncio
@empty_challenge_cache
async def test_cae_token_expiry():
"""The policy should avoid sending claims more than once when a token expires."""

expected_content = b"a duck"

async def test_with_challenge(claims_challenge, expected_claim):
first_token = "first_token"
second_token = "second_token"
third_token = "third_token"

class Requests:
count = 0

async def send(request):
Requests.count += 1
if Requests.count == 1:
# first request should be unauthorized and have no content; triggers a KV challenge response
assert not request.body
assert "Authorization" not in request.headers
assert request.headers["Content-Length"] == "0"
return KV_CHALLENGE_RESPONSE
elif Requests.count == 2:
# second request will trigger a CAE challenge response in this test scenario
assert request.headers["Content-Length"]
assert request.body == expected_content
assert first_token in request.headers["Authorization"]
return claims_challenge
elif Requests.count == 3:
# third request should include the required claims and correctly use content from the first challenge
assert request.headers["Content-Length"]
assert request.body == expected_content
assert second_token in request.headers["Authorization"]
return Mock(status_code=200)
elif Requests.count == 4:
# fourth request should not include claims, but otherwise use content from the first challenge
assert request.headers["Content-Length"]
assert request.body == expected_content
assert third_token in request.headers["Authorization"]
return Mock(status_code=200)
raise ValueError("unexpected request")

async def get_token(*scopes, **kwargs):
assert kwargs.get("enable_cae") == True
assert kwargs.get("tenant_id") == KV_CHALLENGE_TENANT
assert scopes[0] == RESOURCE + "/.default"
# Response to KV challenge
if Requests.count == 1:
assert kwargs.get("claims") == None
return AccessToken(first_token, time.time() + 3600)
# Response to first CAE challenge
elif Requests.count == 2:
assert kwargs.get("claims") == expected_claim
return AccessToken(second_token, 0) # Return a token that expires immediately to trigger a refresh
# Token refresh before making the final request
elif Requests.count == 3:
assert kwargs.get("claims") == None
return AccessToken(third_token, time.time() + 3600)

credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token))
pipeline = AsyncPipeline(policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send))
request = HttpRequest("POST", get_random_url())
request.set_bytes_body(expected_content)
await pipeline.run(request)
await pipeline.run(request) # Send the request again to trigger a token refresh upon expiry

# get_token is called for the KV and CAE challenges, as well as for the token refresh
assert credential.get_token.call_count == 3

await test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM)
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy):
"""

def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None:
# Pass `enable_cae` so `enable_cae=True` is always passed to get_token
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
super().__init__(credential, *scopes, enable_cae=True, **kwargs)
self._credential: AsyncTokenCredential = credential
self._token: Optional[AccessToken] = None
Expand Down Expand Up @@ -159,9 +159,11 @@ async def on_request(self, request: PipelineRequest) -> None:
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
# Exclude tenant for AD FS authentication
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
self._token = await self._credential.get_token(scope)
self._token = await self._credential.get_token(scope, enable_cae=True)
else:
self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id)
self._token = await self._credential.get_token(
scope, tenant_id=challenge.tenant_id, enable_cae=True
)

# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy):
"""

def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None:
# Pass `enable_cae` so `enable_cae=True` is always passed to get_token
# Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request
super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs)
self._credential: TokenCredential = credential
self._token: Optional[AccessToken] = None
Expand Down Expand Up @@ -168,9 +168,9 @@ def on_request(self, request: PipelineRequest) -> None:
scope = challenge.get_scope() or challenge.get_resource() + "/.default"
# Exclude tenant for AD FS authentication
if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"):
self._token = self._credential.get_token(scope)
self._token = self._credential.get_token(scope, enable_cae=True)
else:
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id)
self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id, enable_cae=True)

# ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore
Expand Down

0 comments on commit 35f5e8c

Please sign in to comment.