From ccd64b9df4a6e459a10e988a95949e3ea6aeddc2 Mon Sep 17 00:00:00 2001 From: mccoyp Date: Wed, 11 Sep 2024 17:45:57 -0700 Subject: [PATCH 01/25] Add tests --- .../tests/test_challenge_auth.py | 55 +++++++++++++++++- .../tests/test_challenge_auth_async.py | 57 ++++++++++++++++++- 2 files changed, 109 insertions(+), 3 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py index 8d9eff46b208..152d52e393ee 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py @@ -6,6 +6,7 @@ Tests for the HTTP challenge authentication implementation. These tests aren't parallelizable, because the challenge cache is global to the process. """ +import base64 import functools import os import time @@ -21,7 +22,6 @@ from azure.core.pipeline import Pipeline from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.core.rest import HttpRequest -from azure.identity import AzureCliCredential, AzurePowerShellCredential, ClientSecretCredential from azure.keyvault.keys import KeyClient from azure.keyvault.keys._shared import ChallengeAuthPolicy, HttpChallenge, HttpChallengeCache from azure.keyvault.keys._shared.client_base import DEFAULT_VERSION @@ -536,3 +536,56 @@ def get_token(*_, **__): else: key = client.get_key("key-name") assert key.name == "key-name" + + +@empty_challenge_cache +def test_cae(): + """The policy should correctly handle claims in a challenge response""" + + expected_content = b"a duck" + + def test_with_challenge(challenge, expected_claim): + expected_token = "expected_token" + + class Requests: + count = 0 + + def send(request): + Requests.count += 1 + if Requests.count == 1: + # first request should be unauthorized and have no content + assert not request.body + assert request.headers["Content-Length"] == "0" + return challenge + elif Requests.count == 2: + # second request should be authorized according to challenge and have the expected content + assert request.headers["Content-Length"] + assert request.body == expected_content + assert expected_token in request.headers["Authorization"] + return Mock(status_code=200) + raise ValueError("unexpected request") + + def get_token(*_, **kwargs): + assert kwargs.get("claims") == expected_claim + return AccessToken(expected_token, 0) + + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) + request = HttpRequest("POST", get_random_url()) + request.set_bytes_body(expected_content) + pipeline.run(request) + + assert credential.get_token.call_count == 1 + + url = f'authorization_uri="{get_random_url()}"' + resource = 'resource="https://vault.azure.net"' + cid = 'client_id="00000003-0000-0000-c000-000000000000"' + err = 'error="insufficient_claims"' + claim = '{"access_token": {"foo": "bar"}}' + # Claim token is a string of the base64 encoding of the claim + claim_token = base64.b64encode(claim.encode()).decode() + challenge = f'Bearer realm="", {url}, {resource}, {cid}, {err}, claims="{claim_token}"' + + challenge_response = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) + + test_with_challenge(challenge_response, claim) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py index 980753fbdc0a..f5c88a2362b3 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py @@ -7,6 +7,7 @@ the challenge cache is global to the process. """ import asyncio +import base64 import os import time from unittest.mock import Mock, patch @@ -18,7 +19,6 @@ from azure.core.pipeline import AsyncPipeline from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.core.rest import HttpRequest -from azure.identity.aio import AzureCliCredential, AzurePowerShellCredential, ClientSecretCredential from azure.keyvault.keys._shared import AsyncChallengeAuthPolicy,HttpChallenge, HttpChallengeCache from azure.keyvault.keys._shared.client_base import DEFAULT_VERSION from azure.keyvault.keys.aio import KeyClient @@ -469,7 +469,6 @@ async def get_token(*_, **__): @pytest.mark.asyncio -@empty_challenge_cache @pytest.mark.parametrize("verify_challenge_resource", [True, False]) async def test_verify_challenge_resource_valid(verify_challenge_resource): """The auth policy should raise if the challenge resource isn't a valid URL unless check is disabled""" @@ -502,3 +501,57 @@ async def get_token(*_, **__): else: key = await client.get_key("key-name") assert key.name == "key-name" + + +@pytest.mark.asyncio +@empty_challenge_cache +async def test_cae(): + """The policy should correctly handle claims in a challenge response""" + + expected_content = b"a duck" + + async def test_with_challenge(challenge, expected_claim): + expected_token = "expected_token" + + class Requests: + count = 0 + + async def send(request): + Requests.count += 1 + if Requests.count == 1: + # first request should be unauthorized and have no content + assert not request.body + assert request.headers["Content-Length"] == "0" + return challenge + elif Requests.count == 2: + # second request should be authorized according to challenge and have the expected content + assert request.headers["Content-Length"] + assert request.body == expected_content + assert expected_token in request.headers["Authorization"] + return Mock(status_code=200) + raise ValueError("unexpected request") + + async def get_token(*_, **kwargs): + assert kwargs.get("claims") == expected_claim + return AccessToken(expected_token, 0) + + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + pipeline = AsyncPipeline(policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) + request = HttpRequest("POST", get_random_url()) + request.set_bytes_body(expected_content) + await pipeline.run(request) + + assert credential.get_token.call_count == 1 + + url = f'authorization_uri="{get_random_url()}"' + resource = 'resource="https://vault.azure.net"' + cid = 'client_id="00000003-0000-0000-c000-000000000000"' + err = 'error="insufficient_claims"' + claim = '{"access_token": {"foo": "bar"}}' + # Claim token is a string of the base64 encoding of the claim + claim_token = base64.b64encode(claim.encode()).decode() + challenge = f'Bearer realm="", {url}, {resource}, {cid}, {err}, claims="{claim_token}"' + + challenge_response = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) + + await test_with_challenge(challenge_response, claim) From 34fcfed9af147404047c307e88051ee2914300df Mon Sep 17 00:00:00 2001 From: mccoyp Date: Thu, 12 Sep 2024 14:45:46 -0700 Subject: [PATCH 02/25] Implement CAE support --- .../_shared/async_challenge_auth_policy.py | 10 +++-- .../keys/_shared/challenge_auth_policy.py | 10 +++-- .../keyvault/keys/_shared/http_challenge.py | 37 +++++++++++++++---- 3 files changed, 42 insertions(+), 15 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py index 1a872f36b6a8..17cd0674e894 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py @@ -53,9 +53,11 @@ async def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope) + self._token = await self._credential.get_token(scope, claims=challenge.claims) else: - self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = await self._credential.get_token( + scope, claims=challenge.claims, tenant_id=challenge.tenant_id + ) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -104,9 +106,9 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope) + await self.authorize_request(request, scope, claims=challenge.claims) else: - await self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + await self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py index f16297aa5026..b9858736b13d 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py @@ -82,9 +82,11 @@ def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope) + self._token = self._credential.get_token(scope, claims=challenge.claims) else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = self._credential.get_token( + scope, claims=challenge.claims, tenant_id=challenge.tenant_id + ) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -132,9 +134,9 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope) + self.authorize_request(request, scope, claims=challenge.claims) else: - self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/http_challenge.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/http_challenge.py index df9055c7bda6..39a75ed3cd03 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/http_challenge.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/http_challenge.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import base64 from typing import Dict, MutableMapping, Optional from urllib import parse @@ -18,7 +19,13 @@ class HttpChallenge(object): def __init__( self, request_uri: str, challenge: str, response_headers: "Optional[MutableMapping[str, str]]" = None ) -> None: - """Parses an HTTP WWW-Authentication Bearer challenge from a server.""" + """Parses an HTTP WWW-Authentication Bearer challenge from a server. + + Example challenge with claims: + Bearer authorization="https://login.windows-ppe.net/", error="invalid_token", + error_description="User session has been revoked", + claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0=" + """ self.source_authority = self._validate_request_uri(request_uri) self.source_uri = request_uri self._parameters: "Dict[str, str]" = {} @@ -29,16 +36,32 @@ def __init__( self.scheme = split_challenge[0] trimmed_challenge = split_challenge[1] + self.claims = None + encoded_claims = None # split trimmed challenge into comma-separated name=value pairs. Values are expected # to be surrounded by quotes which are stripped here. for item in trimmed_challenge.split(","): + # special case for claims, which can contain = symbols as padding + if "claims=" in item: + if encoded_claims: + # multiple claims challenges, e.g. for cross-tenant auth, would require special handling + # we can't support this scenario for now, so we ignore claims altogether if there are multiple + self.claims = None + encoded_claims = item[item.index("=") + 1 :].strip(" \"'") + padding_needed = -len(encoded_claims) % 4 + try: + decoded_claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode() + self.claims = decoded_claims + except Exception: # pylint:disable=broad-except + continue # process name=value pairs - comps = item.split("=") - if len(comps) == 2: - key = comps[0].strip(' "') - value = comps[1].strip(' "') - if key: - self._parameters[key] = value + else: + comps = item.split("=") + if len(comps) == 2: + key = comps[0].strip(' "') + value = comps[1].strip(' "') + if key: + self._parameters[key] = value # minimum set of parameters if not self._parameters: From d07e0a5f2f155d34c9ca65d79b793e369175288f Mon Sep 17 00:00:00 2001 From: mccoyp Date: Thu, 12 Sep 2024 18:02:02 -0700 Subject: [PATCH 03/25] Share implementation across libraries --- .../_internal/async_challenge_auth_policy.py | 10 +++-- .../_internal/challenge_auth_policy.py | 10 +++-- .../_internal/http_challenge.py | 37 +++++++++++++++---- .../_shared/async_challenge_auth_policy.py | 10 +++-- .../_shared/challenge_auth_policy.py | 10 +++-- .../certificates/_shared/http_challenge.py | 37 +++++++++++++++---- .../_shared/async_challenge_auth_policy.py | 10 +++-- .../secrets/_shared/challenge_auth_policy.py | 10 +++-- .../secrets/_shared/http_challenge.py | 37 +++++++++++++++---- 9 files changed, 126 insertions(+), 45 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py index 1a872f36b6a8..17cd0674e894 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py @@ -53,9 +53,11 @@ async def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope) + self._token = await self._credential.get_token(scope, claims=challenge.claims) else: - self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = await self._credential.get_token( + scope, claims=challenge.claims, tenant_id=challenge.tenant_id + ) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -104,9 +106,9 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope) + await self.authorize_request(request, scope, claims=challenge.claims) else: - await self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + await self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py index f16297aa5026..b9858736b13d 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py @@ -82,9 +82,11 @@ def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope) + self._token = self._credential.get_token(scope, claims=challenge.claims) else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = self._credential.get_token( + scope, claims=challenge.claims, tenant_id=challenge.tenant_id + ) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -132,9 +134,9 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope) + self.authorize_request(request, scope, claims=challenge.claims) else: - self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/http_challenge.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/http_challenge.py index df9055c7bda6..39a75ed3cd03 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/http_challenge.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/http_challenge.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import base64 from typing import Dict, MutableMapping, Optional from urllib import parse @@ -18,7 +19,13 @@ class HttpChallenge(object): def __init__( self, request_uri: str, challenge: str, response_headers: "Optional[MutableMapping[str, str]]" = None ) -> None: - """Parses an HTTP WWW-Authentication Bearer challenge from a server.""" + """Parses an HTTP WWW-Authentication Bearer challenge from a server. + + Example challenge with claims: + Bearer authorization="https://login.windows-ppe.net/", error="invalid_token", + error_description="User session has been revoked", + claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0=" + """ self.source_authority = self._validate_request_uri(request_uri) self.source_uri = request_uri self._parameters: "Dict[str, str]" = {} @@ -29,16 +36,32 @@ def __init__( self.scheme = split_challenge[0] trimmed_challenge = split_challenge[1] + self.claims = None + encoded_claims = None # split trimmed challenge into comma-separated name=value pairs. Values are expected # to be surrounded by quotes which are stripped here. for item in trimmed_challenge.split(","): + # special case for claims, which can contain = symbols as padding + if "claims=" in item: + if encoded_claims: + # multiple claims challenges, e.g. for cross-tenant auth, would require special handling + # we can't support this scenario for now, so we ignore claims altogether if there are multiple + self.claims = None + encoded_claims = item[item.index("=") + 1 :].strip(" \"'") + padding_needed = -len(encoded_claims) % 4 + try: + decoded_claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode() + self.claims = decoded_claims + except Exception: # pylint:disable=broad-except + continue # process name=value pairs - comps = item.split("=") - if len(comps) == 2: - key = comps[0].strip(' "') - value = comps[1].strip(' "') - if key: - self._parameters[key] = value + else: + comps = item.split("=") + if len(comps) == 2: + key = comps[0].strip(' "') + value = comps[1].strip(' "') + if key: + self._parameters[key] = value # minimum set of parameters if not self._parameters: diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py index 1a872f36b6a8..17cd0674e894 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py @@ -53,9 +53,11 @@ async def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope) + self._token = await self._credential.get_token(scope, claims=challenge.claims) else: - self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = await self._credential.get_token( + scope, claims=challenge.claims, tenant_id=challenge.tenant_id + ) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -104,9 +106,9 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope) + await self.authorize_request(request, scope, claims=challenge.claims) else: - await self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + await self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py index f16297aa5026..b9858736b13d 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py @@ -82,9 +82,11 @@ def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope) + self._token = self._credential.get_token(scope, claims=challenge.claims) else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = self._credential.get_token( + scope, claims=challenge.claims, tenant_id=challenge.tenant_id + ) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -132,9 +134,9 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope) + self.authorize_request(request, scope, claims=challenge.claims) else: - self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/http_challenge.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/http_challenge.py index df9055c7bda6..39a75ed3cd03 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/http_challenge.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/http_challenge.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import base64 from typing import Dict, MutableMapping, Optional from urllib import parse @@ -18,7 +19,13 @@ class HttpChallenge(object): def __init__( self, request_uri: str, challenge: str, response_headers: "Optional[MutableMapping[str, str]]" = None ) -> None: - """Parses an HTTP WWW-Authentication Bearer challenge from a server.""" + """Parses an HTTP WWW-Authentication Bearer challenge from a server. + + Example challenge with claims: + Bearer authorization="https://login.windows-ppe.net/", error="invalid_token", + error_description="User session has been revoked", + claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0=" + """ self.source_authority = self._validate_request_uri(request_uri) self.source_uri = request_uri self._parameters: "Dict[str, str]" = {} @@ -29,16 +36,32 @@ def __init__( self.scheme = split_challenge[0] trimmed_challenge = split_challenge[1] + self.claims = None + encoded_claims = None # split trimmed challenge into comma-separated name=value pairs. Values are expected # to be surrounded by quotes which are stripped here. for item in trimmed_challenge.split(","): + # special case for claims, which can contain = symbols as padding + if "claims=" in item: + if encoded_claims: + # multiple claims challenges, e.g. for cross-tenant auth, would require special handling + # we can't support this scenario for now, so we ignore claims altogether if there are multiple + self.claims = None + encoded_claims = item[item.index("=") + 1 :].strip(" \"'") + padding_needed = -len(encoded_claims) % 4 + try: + decoded_claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode() + self.claims = decoded_claims + except Exception: # pylint:disable=broad-except + continue # process name=value pairs - comps = item.split("=") - if len(comps) == 2: - key = comps[0].strip(' "') - value = comps[1].strip(' "') - if key: - self._parameters[key] = value + else: + comps = item.split("=") + if len(comps) == 2: + key = comps[0].strip(' "') + value = comps[1].strip(' "') + if key: + self._parameters[key] = value # minimum set of parameters if not self._parameters: diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py index 1a872f36b6a8..17cd0674e894 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py @@ -53,9 +53,11 @@ async def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope) + self._token = await self._credential.get_token(scope, claims=challenge.claims) else: - self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = await self._credential.get_token( + scope, claims=challenge.claims, tenant_id=challenge.tenant_id + ) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -104,9 +106,9 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope) + await self.authorize_request(request, scope, claims=challenge.claims) else: - await self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + await self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py index f16297aa5026..b9858736b13d 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py @@ -82,9 +82,11 @@ def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope) + self._token = self._credential.get_token(scope, claims=challenge.claims) else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = self._credential.get_token( + scope, claims=challenge.claims, tenant_id=challenge.tenant_id + ) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -132,9 +134,9 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope) + self.authorize_request(request, scope, claims=challenge.claims) else: - self.authorize_request(request, scope, tenant_id=challenge.tenant_id) + self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/http_challenge.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/http_challenge.py index df9055c7bda6..39a75ed3cd03 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/http_challenge.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/http_challenge.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import base64 from typing import Dict, MutableMapping, Optional from urllib import parse @@ -18,7 +19,13 @@ class HttpChallenge(object): def __init__( self, request_uri: str, challenge: str, response_headers: "Optional[MutableMapping[str, str]]" = None ) -> None: - """Parses an HTTP WWW-Authentication Bearer challenge from a server.""" + """Parses an HTTP WWW-Authentication Bearer challenge from a server. + + Example challenge with claims: + Bearer authorization="https://login.windows-ppe.net/", error="invalid_token", + error_description="User session has been revoked", + claims="eyJhY2Nlc3NfdG9rZW4iOnsibmJmIjp7ImVzc2VudGlhbCI6dHJ1ZSwgInZhbHVlIjoiMTYwMzc0MjgwMCJ9fX0=" + """ self.source_authority = self._validate_request_uri(request_uri) self.source_uri = request_uri self._parameters: "Dict[str, str]" = {} @@ -29,16 +36,32 @@ def __init__( self.scheme = split_challenge[0] trimmed_challenge = split_challenge[1] + self.claims = None + encoded_claims = None # split trimmed challenge into comma-separated name=value pairs. Values are expected # to be surrounded by quotes which are stripped here. for item in trimmed_challenge.split(","): + # special case for claims, which can contain = symbols as padding + if "claims=" in item: + if encoded_claims: + # multiple claims challenges, e.g. for cross-tenant auth, would require special handling + # we can't support this scenario for now, so we ignore claims altogether if there are multiple + self.claims = None + encoded_claims = item[item.index("=") + 1 :].strip(" \"'") + padding_needed = -len(encoded_claims) % 4 + try: + decoded_claims = base64.urlsafe_b64decode(encoded_claims + "=" * padding_needed).decode() + self.claims = decoded_claims + except Exception: # pylint:disable=broad-except + continue # process name=value pairs - comps = item.split("=") - if len(comps) == 2: - key = comps[0].strip(' "') - value = comps[1].strip(' "') - if key: - self._parameters[key] = value + else: + comps = item.split("=") + if len(comps) == 2: + key = comps[0].strip(' "') + value = comps[1].strip(' "') + if key: + self._parameters[key] = value # minimum set of parameters if not self._parameters: From 84351af917f7431e6297cd0de87002b71ad93d60 Mon Sep 17 00:00:00 2001 From: mccoyp Date: Tue, 17 Sep 2024 19:48:21 -0700 Subject: [PATCH 04/25] Enable CAE; provide claims only in challenges --- .../_internal/async_challenge_auth_policy.py | 12 ++++++------ .../_internal/challenge_auth_policy.py | 12 ++++++------ .../_shared/async_challenge_auth_policy.py | 12 ++++++------ .../certificates/_shared/challenge_auth_policy.py | 12 ++++++------ .../keys/_shared/async_challenge_auth_policy.py | 12 ++++++------ .../keyvault/keys/_shared/challenge_auth_policy.py | 12 ++++++------ .../secrets/_shared/async_challenge_auth_policy.py | 12 ++++++------ .../secrets/_shared/challenge_auth_policy.py | 12 ++++++------ 8 files changed, 48 insertions(+), 48 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py index 17cd0674e894..ddbdaa6122da 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py @@ -53,11 +53,9 @@ async def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope, claims=challenge.claims) + self._token = await self._credential.get_token(scope) else: - self._token = await self._credential.get_token( - scope, claims=challenge.claims, tenant_id=challenge.tenant_id - ) + self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -106,9 +104,11 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope, claims=challenge.claims) + await self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True) else: - await self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) + await self.authorize_request( + request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id + ) return True diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py index b9858736b13d..8cfaccb37695 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py @@ -82,11 +82,9 @@ def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope, claims=challenge.claims) + self._token = self._credential.get_token(scope) else: - self._token = self._credential.get_token( - scope, claims=challenge.claims, tenant_id=challenge.tenant_id - ) + self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -134,9 +132,11 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope, claims=challenge.claims) + self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True) else: - self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) + self.authorize_request( + request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id + ) return True diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py index 17cd0674e894..ddbdaa6122da 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py @@ -53,11 +53,9 @@ async def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope, claims=challenge.claims) + self._token = await self._credential.get_token(scope) else: - self._token = await self._credential.get_token( - scope, claims=challenge.claims, tenant_id=challenge.tenant_id - ) + self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -106,9 +104,11 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope, claims=challenge.claims) + await self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True) else: - await self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) + await self.authorize_request( + request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id + ) return True diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py index b9858736b13d..8cfaccb37695 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py @@ -82,11 +82,9 @@ def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope, claims=challenge.claims) + self._token = self._credential.get_token(scope) else: - self._token = self._credential.get_token( - scope, claims=challenge.claims, tenant_id=challenge.tenant_id - ) + self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -134,9 +132,11 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope, claims=challenge.claims) + self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True) else: - self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) + self.authorize_request( + request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id + ) return True diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py index 17cd0674e894..ddbdaa6122da 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py @@ -53,11 +53,9 @@ async def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope, claims=challenge.claims) + self._token = await self._credential.get_token(scope) else: - self._token = await self._credential.get_token( - scope, claims=challenge.claims, tenant_id=challenge.tenant_id - ) + self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -106,9 +104,11 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope, claims=challenge.claims) + await self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True) else: - await self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) + await self.authorize_request( + request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id + ) return True diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py index b9858736b13d..8cfaccb37695 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py @@ -82,11 +82,9 @@ def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope, claims=challenge.claims) + self._token = self._credential.get_token(scope) else: - self._token = self._credential.get_token( - scope, claims=challenge.claims, tenant_id=challenge.tenant_id - ) + self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -134,9 +132,11 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope, claims=challenge.claims) + self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True) else: - self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) + self.authorize_request( + request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id + ) return True diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py index 17cd0674e894..ddbdaa6122da 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py @@ -53,11 +53,9 @@ async def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope, claims=challenge.claims) + self._token = await self._credential.get_token(scope) else: - self._token = await self._credential.get_token( - scope, claims=challenge.claims, tenant_id=challenge.tenant_id - ) + self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -106,9 +104,11 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope, claims=challenge.claims) + await self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True) else: - await self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) + await self.authorize_request( + request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id + ) return True diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py index b9858736b13d..8cfaccb37695 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py @@ -82,11 +82,9 @@ def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope, claims=challenge.claims) + self._token = self._credential.get_token(scope) else: - self._token = self._credential.get_token( - scope, claims=challenge.claims, tenant_id=challenge.tenant_id - ) + self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore @@ -134,9 +132,11 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope, claims=challenge.claims) + self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True) else: - self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) + self.authorize_request( + request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id + ) return True From c22cb08d8e24645117cf3ec8d25c01d072a9850b Mon Sep 17 00:00:00 2001 From: mccoyp Date: Wed, 18 Sep 2024 17:26:50 -0700 Subject: [PATCH 05/25] Update tests for success scenarios --- .../tests/test_challenge_auth.py | 120 ++++++++++++++++- .../tests/test_challenge_auth_async.py | 121 +++++++++++++++++- 2 files changed, 227 insertions(+), 14 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py index 152d52e393ee..cee6f815970c 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py @@ -540,12 +540,21 @@ def get_token(*_, **__): @empty_challenge_cache def test_cae(): - """The policy should correctly handle claims in a challenge response""" + """The policy should handle claims in a challenge response after having successfully authenticated prior.""" expected_content = b"a duck" def test_with_challenge(challenge, expected_claim): + first_token = "first_token" expected_token = "expected_token" + tenant = "tenant-id" + endpoint = f"https://authority.net/{tenant}" + resource = "https://vault.azure.net" + + first_challenge = Mock( + status_code=401, + headers={"WWW-Authenticate": f'Bearer authorization="{endpoint}", resource={resource}'}, + ) class Requests: count = 0 @@ -556,18 +565,115 @@ def send(request): # first request should be unauthorized and have no content assert not request.body assert request.headers["Content-Length"] == "0" - return challenge + return first_challenge elif Requests.count == 2: # second request should be authorized according to challenge and have the expected content assert request.headers["Content-Length"] assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return Mock(status_code=200) + elif Requests.count == 3: + # third request will trigger a CAE challenge response in this test scenario + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return challenge + elif Requests.count == 4: + # fourth request should include the required claims and correctly use context from the first challenge + assert request.headers["Content-Length"] + assert request.body == expected_content assert expected_token in request.headers["Authorization"] return Mock(status_code=200) raise ValueError("unexpected request") def get_token(*_, **kwargs): - assert kwargs.get("claims") == expected_claim - return AccessToken(expected_token, 0) + if Requests.count == 1: + assert kwargs.get("claims") == None + assert kwargs.get("enable_cae") == True + assert kwargs.get("tenant_id") == tenant + return AccessToken(first_token, time.time() + 3600) + elif Requests.count == 3: + assert kwargs.get("claims") == expected_claim + assert kwargs.get("enable_cae") == True + assert kwargs.get("tenant_id") == tenant + return AccessToken(expected_token, time.time() + 3600) + + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) + request = HttpRequest("POST", get_random_url()) + request.set_bytes_body(expected_content) + pipeline.run(request) # Send the request once to trigger a regular auth challenge + pipeline.run(request) # Send the request again to trigger a CAE challenge + + assert credential.get_token.call_count == 2 + + url = f'authorization_uri="{get_random_url()}"' + cid = 'client_id="00000003-0000-0000-c000-000000000000"' + err = 'error="insufficient_claims"' + claim = '{"access_token": {"foo": "bar"}}' + # Claim token is a string of the base64 encoding of the claim + claim_token = base64.b64encode(claim.encode()).decode() + # Note that no resource or scope is necessarily proovided in a CAE challenge + challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' + + challenge_response = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) + + test_with_challenge(challenge_response, claim) + + +@empty_challenge_cache +def test_cae_consecutive_challenges(): + """The policy should correctly handle consecutive challenges in cases where the flow is valid or invalid.""" + + expected_content = b"a duck" + + def test_with_challenge(challenge, expected_claim): + first_token = "first_token" + expected_token = "expected_token" + tenant = "tenant-id" + endpoint = f"https://authority.net/{tenant}" + resource = "https://vault.azure.net" + + first_challenge = Mock( + status_code=401, + headers={"WWW-Authenticate": f'Bearer authorization="{endpoint}", resource={resource}'}, + ) + + class Requests: + count = 0 + + def send(request): + Requests.count += 1 + if Requests.count == 1: + # first request should be unauthorized and have no content + assert not request.body + assert request.headers["Content-Length"] == "0" + return first_challenge + elif Requests.count == 2: + # second request will trigger a CAE challenge response in this test scenario + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return challenge + elif Requests.count == 3: + # third request should include the required claims and correctly use context from the first challenge + assert request.headers["Content-Length"] + assert request.body == expected_content + assert expected_token in request.headers["Authorization"] + return Mock(status_code=200) + raise ValueError("unexpected request") + + def get_token(*_, **kwargs): + if Requests.count == 1: + assert kwargs.get("claims") == None + assert kwargs.get("enable_cae") == True + assert kwargs.get("tenant_id") == tenant + return AccessToken(first_token, time.time() + 3600) + elif Requests.count == 2: + assert kwargs.get("claims") == expected_claim + assert kwargs.get("enable_cae") == True + assert kwargs.get("tenant_id") == tenant + return AccessToken(expected_token, time.time() + 3600) credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) @@ -575,16 +681,16 @@ def get_token(*_, **kwargs): request.set_bytes_body(expected_content) pipeline.run(request) - assert credential.get_token.call_count == 1 + assert credential.get_token.call_count == 2 url = f'authorization_uri="{get_random_url()}"' - resource = 'resource="https://vault.azure.net"' cid = 'client_id="00000003-0000-0000-c000-000000000000"' err = 'error="insufficient_claims"' claim = '{"access_token": {"foo": "bar"}}' # Claim token is a string of the base64 encoding of the claim claim_token = base64.b64encode(claim.encode()).decode() - challenge = f'Bearer realm="", {url}, {resource}, {cid}, {err}, claims="{claim_token}"' + # Note that no resource or scope is necessarily proovided in a CAE challenge + challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' challenge_response = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py index f5c88a2362b3..098674790c23 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py @@ -506,12 +506,21 @@ async def get_token(*_, **__): @pytest.mark.asyncio @empty_challenge_cache async def test_cae(): - """The policy should correctly handle claims in a challenge response""" + """The policy should handle claims in a challenge response after having successfully authenticated prior.""" expected_content = b"a duck" async def test_with_challenge(challenge, expected_claim): + first_token = "first_token" expected_token = "expected_token" + tenant = "tenant-id" + endpoint = f"https://authority.net/{tenant}" + resource = "https://vault.azure.net" + + first_challenge = Mock( + status_code=401, + headers={"WWW-Authenticate": f'Bearer authorization="{endpoint}", resource={resource}'}, + ) class Requests: count = 0 @@ -522,18 +531,116 @@ async def send(request): # first request should be unauthorized and have no content assert not request.body assert request.headers["Content-Length"] == "0" - return challenge + return first_challenge elif Requests.count == 2: # second request should be authorized according to challenge and have the expected content assert request.headers["Content-Length"] assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return Mock(status_code=200) + elif Requests.count == 3: + # third request will trigger a CAE challenge response in this test scenario + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return challenge + elif Requests.count == 4: + # fourth request should include the required claims and correctly use context from the first challenge + assert request.headers["Content-Length"] + assert request.body == expected_content assert expected_token in request.headers["Authorization"] return Mock(status_code=200) raise ValueError("unexpected request") async def get_token(*_, **kwargs): - assert kwargs.get("claims") == expected_claim - return AccessToken(expected_token, 0) + if Requests.count == 1: + assert kwargs.get("claims") == None + assert kwargs.get("enable_cae") == True + assert kwargs.get("tenant_id") == tenant + return AccessToken(first_token, time.time() + 3600) + elif Requests.count == 3: + assert kwargs.get("claims") == expected_claim + assert kwargs.get("enable_cae") == True + assert kwargs.get("tenant_id") == tenant + return AccessToken(expected_token, time.time() + 3600) + + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + pipeline = AsyncPipeline(policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) + request = HttpRequest("POST", get_random_url()) + request.set_bytes_body(expected_content) + await pipeline.run(request) # Send the request once to trigger a regular auth challenge + await pipeline.run(request) # Send the request again to trigger a CAE challenge + + assert credential.get_token.call_count == 2 + + url = f'authorization_uri="{get_random_url()}"' + cid = 'client_id="00000003-0000-0000-c000-000000000000"' + err = 'error="insufficient_claims"' + claim = '{"access_token": {"foo": "bar"}}' + # Claim token is a string of the base64 encoding of the claim + claim_token = base64.b64encode(claim.encode()).decode() + # Note that no resource or scope is necessarily proovided in a CAE challenge + challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' + + challenge_response = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) + + await test_with_challenge(challenge_response, claim) + + +@pytest.mark.asyncio +@empty_challenge_cache +async def test_cae_consecutive_challenges(): + """The policy should correctly handle consecutive challenges in cases where the flow is valid or invalid.""" + + expected_content = b"a duck" + + async def test_with_challenge(challenge, expected_claim): + first_token = "first_token" + expected_token = "expected_token" + tenant = "tenant-id" + endpoint = f"https://authority.net/{tenant}" + resource = "https://vault.azure.net" + + first_challenge = Mock( + status_code=401, + headers={"WWW-Authenticate": f'Bearer authorization="{endpoint}", resource={resource}'}, + ) + + class Requests: + count = 0 + + async def send(request): + Requests.count += 1 + if Requests.count == 1: + # first request should be unauthorized and have no content + assert not request.body + assert request.headers["Content-Length"] == "0" + return first_challenge + elif Requests.count == 2: + # second request will trigger a CAE challenge response in this test scenario + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return challenge + elif Requests.count == 3: + # third request should include the required claims and correctly use context from the first challenge + assert request.headers["Content-Length"] + assert request.body == expected_content + assert expected_token in request.headers["Authorization"] + return Mock(status_code=200) + raise ValueError("unexpected request") + + async def get_token(*_, **kwargs): + if Requests.count == 1: + assert kwargs.get("claims") == None + assert kwargs.get("enable_cae") == True + assert kwargs.get("tenant_id") == tenant + return AccessToken(first_token, time.time() + 3600) + elif Requests.count == 2: + assert kwargs.get("claims") == expected_claim + assert kwargs.get("enable_cae") == True + assert kwargs.get("tenant_id") == tenant + return AccessToken(expected_token, time.time() + 3600) credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) pipeline = AsyncPipeline(policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) @@ -541,16 +648,16 @@ async def get_token(*_, **kwargs): request.set_bytes_body(expected_content) await pipeline.run(request) - assert credential.get_token.call_count == 1 + assert credential.get_token.call_count == 2 url = f'authorization_uri="{get_random_url()}"' - resource = 'resource="https://vault.azure.net"' cid = 'client_id="00000003-0000-0000-c000-000000000000"' err = 'error="insufficient_claims"' claim = '{"access_token": {"foo": "bar"}}' # Claim token is a string of the base64 encoding of the claim claim_token = base64.b64encode(claim.encode()).decode() - challenge = f'Bearer realm="", {url}, {resource}, {cid}, {err}, claims="{claim_token}"' + # Note that no resource or scope is necessarily proovided in a CAE challenge + challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' challenge_response = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) From d854071ca5abf2ee4e96c1085fa9f011ba922de0 Mon Sep 17 00:00:00 2001 From: mccoyp Date: Wed, 18 Sep 2024 18:02:03 -0700 Subject: [PATCH 06/25] Handle non-consecutive challenges (in Keys) --- .../_shared/async_challenge_auth_policy.py | 21 ++++++++++++++--- .../keys/_shared/challenge_auth_policy.py | 23 +++++++++++++++---- sdk/keyvault/azure-keyvault-keys/setup.py | 2 +- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py index ddbdaa6122da..e2874fec25e4 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py @@ -42,6 +42,7 @@ def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None + self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token async def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) @@ -78,6 +79,14 @@ async def on_request(self, request: PipelineRequest) -> None: async def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: try: + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary + old_scope: Optional[str] = None + old_tenant: Optional[str] = None + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if cached_challenge: + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" + old_tenant = cached_challenge.tenant_id + challenge = _update_challenge(request, response) # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" @@ -87,7 +96,13 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons if self._verify_challenge_resource: resource_domain = urlparse(scope).netloc if not resource_domain: - raise ValueError(f"The challenge contains invalid scope '{scope}'.") + # Use the old scope for CAE challenges. The parsing will succeed here since it did before + if challenge.claims and old_scope: + resource_domain = urlparse(old_scope).netloc + challenge._parameters["scope"] = old_scope + challenge.tenant_id = old_tenant + else: + raise ValueError(f"The challenge contains invalid scope '{scope}'.") request_domain = urlparse(request.http_request.url).netloc if not request_domain.lower().endswith(f".{resource_domain.lower()}"): @@ -104,10 +119,10 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True) + await self.authorize_request(request, scope, claims=challenge.claims) else: await self.authorize_request( - request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id + request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id ) return True diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py index 8cfaccb37695..ddb0a9f9757e 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py @@ -71,6 +71,7 @@ def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None + self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) @@ -106,6 +107,14 @@ def on_request(self, request: PipelineRequest) -> None: def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: try: + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary + old_scope: Optional[str] = None + old_tenant: Optional[str] = None + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if cached_challenge: + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" + old_tenant = cached_challenge.tenant_id + challenge = _update_challenge(request, response) # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" @@ -115,7 +124,13 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> if self._verify_challenge_resource: resource_domain = urlparse(scope).netloc if not resource_domain: - raise ValueError(f"The challenge contains invalid scope '{scope}'.") + # Use the old scope for CAE challenges. The parsing will succeed here since it did before + if challenge.claims and old_scope: + resource_domain = urlparse(old_scope).netloc + challenge._parameters["scope"] = old_scope + challenge.tenant_id = old_tenant + else: + raise ValueError(f"The challenge contains invalid scope '{scope}'.") request_domain = urlparse(request.http_request.url).netloc if not request_domain.lower().endswith(f".{resource_domain.lower()}"): @@ -132,11 +147,9 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True) + self.authorize_request(request, scope, claims=challenge.claims) else: - self.authorize_request( - request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id - ) + self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True diff --git a/sdk/keyvault/azure-keyvault-keys/setup.py b/sdk/keyvault/azure-keyvault-keys/setup.py index cbb26cf86d49..7bbe28af42c5 100644 --- a/sdk/keyvault/azure-keyvault-keys/setup.py +++ b/sdk/keyvault/azure-keyvault-keys/setup.py @@ -68,7 +68,7 @@ ), python_requires=">=3.8", install_requires=[ - "azure-core>=1.29.5", + "azure-core>=1.31.0", "cryptography>=2.1.4", "isodate>=0.6.1", "typing-extensions>=4.0.1", From 6c19bbc40533436e8af102817df4e143a4900fe1 Mon Sep 17 00:00:00 2001 From: mccoyp Date: Fri, 20 Sep 2024 14:11:10 -0700 Subject: [PATCH 07/25] Cover invalid challenge flows --- .../tests/test_challenge_auth.py | 36 ++++++++++--------- .../tests/test_challenge_auth_async.py | 36 ++++++++++--------- 2 files changed, 40 insertions(+), 32 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py index cee6f815970c..a55f0d977c3e 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py @@ -544,14 +544,14 @@ def test_cae(): expected_content = b"a duck" - def test_with_challenge(challenge, expected_claim): + def test_with_challenge(claims_challenge, expected_claim): first_token = "first_token" expected_token = "expected_token" tenant = "tenant-id" endpoint = f"https://authority.net/{tenant}" resource = "https://vault.azure.net" - first_challenge = Mock( + kv_challenge = Mock( status_code=401, headers={"WWW-Authenticate": f'Bearer authorization="{endpoint}", resource={resource}'}, ) @@ -562,10 +562,10 @@ class Requests: def send(request): Requests.count += 1 if Requests.count == 1: - # first request should be unauthorized and have no content + # first request should be unauthorized and have no content; triggers a KV challenge response assert not request.body assert request.headers["Content-Length"] == "0" - return first_challenge + return kv_challenge elif Requests.count == 2: # second request should be authorized according to challenge and have the expected content assert request.headers["Content-Length"] @@ -577,13 +577,14 @@ def send(request): assert request.headers["Content-Length"] assert request.body == expected_content assert first_token in request.headers["Authorization"] - return challenge + return claims_challenge elif Requests.count == 4: # fourth request should include the required claims and correctly use context from the first challenge + # we return another KV challenge to verify that the policy doesn't try to handle this invalid flow assert request.headers["Content-Length"] assert request.body == expected_content assert expected_token in request.headers["Authorization"] - return Mock(status_code=200) + return kv_challenge raise ValueError("unexpected request") def get_token(*_, **kwargs): @@ -605,6 +606,7 @@ def get_token(*_, **kwargs): pipeline.run(request) # Send the request once to trigger a regular auth challenge pipeline.run(request) # Send the request again to trigger a CAE challenge + # get_token is called for the first KV challenge and CAE challenge, but not the second KV challenge assert credential.get_token.call_count == 2 url = f'authorization_uri="{get_random_url()}"' @@ -616,9 +618,9 @@ def get_token(*_, **kwargs): # Note that no resource or scope is necessarily proovided in a CAE challenge challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' - challenge_response = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) + claims_challenge = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) - test_with_challenge(challenge_response, claim) + test_with_challenge(claims_challenge, claim) @empty_challenge_cache @@ -627,14 +629,14 @@ def test_cae_consecutive_challenges(): expected_content = b"a duck" - def test_with_challenge(challenge, expected_claim): + def test_with_challenge(claims_challenge, expected_claim): first_token = "first_token" expected_token = "expected_token" tenant = "tenant-id" endpoint = f"https://authority.net/{tenant}" resource = "https://vault.azure.net" - first_challenge = Mock( + kv_challenge = Mock( status_code=401, headers={"WWW-Authenticate": f'Bearer authorization="{endpoint}", resource={resource}'}, ) @@ -645,22 +647,23 @@ class Requests: def send(request): Requests.count += 1 if Requests.count == 1: - # first request should be unauthorized and have no content + # first request should be unauthorized and have no content; triggers a KV challenge response assert not request.body assert request.headers["Content-Length"] == "0" - return first_challenge + return kv_challenge elif Requests.count == 2: # second request will trigger a CAE challenge response in this test scenario assert request.headers["Content-Length"] assert request.body == expected_content assert first_token in request.headers["Authorization"] - return challenge + return claims_challenge elif Requests.count == 3: # third request should include the required claims and correctly use context from the first challenge + # we return another CAE challenge to verify that the policy will return consecutive CAE 401s to the user assert request.headers["Content-Length"] assert request.body == expected_content assert expected_token in request.headers["Authorization"] - return Mock(status_code=200) + return claims_challenge raise ValueError("unexpected request") def get_token(*_, **kwargs): @@ -681,6 +684,7 @@ def get_token(*_, **kwargs): request.set_bytes_body(expected_content) pipeline.run(request) + # get_token is called for the KV challenge and first CAE challenge, but not the second CAE challenge assert credential.get_token.call_count == 2 url = f'authorization_uri="{get_random_url()}"' @@ -692,6 +696,6 @@ def get_token(*_, **kwargs): # Note that no resource or scope is necessarily proovided in a CAE challenge challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' - challenge_response = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) + claims_challenge = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) - test_with_challenge(challenge_response, claim) + test_with_challenge(claims_challenge, claim) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py index 098674790c23..cf7936b9470e 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py @@ -510,14 +510,14 @@ async def test_cae(): expected_content = b"a duck" - async def test_with_challenge(challenge, expected_claim): + async def test_with_challenge(claims_challenge, expected_claim): first_token = "first_token" expected_token = "expected_token" tenant = "tenant-id" endpoint = f"https://authority.net/{tenant}" resource = "https://vault.azure.net" - first_challenge = Mock( + kv_challenge = Mock( status_code=401, headers={"WWW-Authenticate": f'Bearer authorization="{endpoint}", resource={resource}'}, ) @@ -528,10 +528,10 @@ class Requests: async def send(request): Requests.count += 1 if Requests.count == 1: - # first request should be unauthorized and have no content + # first request should be unauthorized and have no content; triggers a KV challenge response assert not request.body assert request.headers["Content-Length"] == "0" - return first_challenge + return kv_challenge elif Requests.count == 2: # second request should be authorized according to challenge and have the expected content assert request.headers["Content-Length"] @@ -543,13 +543,14 @@ async def send(request): assert request.headers["Content-Length"] assert request.body == expected_content assert first_token in request.headers["Authorization"] - return challenge + return claims_challenge elif Requests.count == 4: # fourth request should include the required claims and correctly use context from the first challenge + # we return another KV challenge to verify that the policy doesn't try to handle this invalid flow assert request.headers["Content-Length"] assert request.body == expected_content assert expected_token in request.headers["Authorization"] - return Mock(status_code=200) + return kv_challenge raise ValueError("unexpected request") async def get_token(*_, **kwargs): @@ -582,9 +583,10 @@ async def get_token(*_, **kwargs): # Note that no resource or scope is necessarily proovided in a CAE challenge challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' - challenge_response = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) + # get_token is called for the first KV challenge and CAE challenge, but not the second KV challenge + claims_challenge = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) - await test_with_challenge(challenge_response, claim) + await test_with_challenge(claims_challenge, claim) @pytest.mark.asyncio @@ -594,14 +596,14 @@ async def test_cae_consecutive_challenges(): expected_content = b"a duck" - async def test_with_challenge(challenge, expected_claim): + async def test_with_challenge(claims_challenge, expected_claim): first_token = "first_token" expected_token = "expected_token" tenant = "tenant-id" endpoint = f"https://authority.net/{tenant}" resource = "https://vault.azure.net" - first_challenge = Mock( + kv_challenge = Mock( status_code=401, headers={"WWW-Authenticate": f'Bearer authorization="{endpoint}", resource={resource}'}, ) @@ -612,22 +614,23 @@ class Requests: async def send(request): Requests.count += 1 if Requests.count == 1: - # first request should be unauthorized and have no content + # first request should be unauthorized and have no content; triggers a KV challenge response assert not request.body assert request.headers["Content-Length"] == "0" - return first_challenge + return kv_challenge elif Requests.count == 2: # second request will trigger a CAE challenge response in this test scenario assert request.headers["Content-Length"] assert request.body == expected_content assert first_token in request.headers["Authorization"] - return challenge + return claims_challenge elif Requests.count == 3: # third request should include the required claims and correctly use context from the first challenge + # we return another CAE challenge to verify that the policy will return consecutive CAE 401s to the user assert request.headers["Content-Length"] assert request.body == expected_content assert expected_token in request.headers["Authorization"] - return Mock(status_code=200) + return claims_challenge raise ValueError("unexpected request") async def get_token(*_, **kwargs): @@ -648,6 +651,7 @@ async def get_token(*_, **kwargs): request.set_bytes_body(expected_content) await pipeline.run(request) + # get_token is called for the KV challenge and first CAE challenge, but not the second CAE challenge assert credential.get_token.call_count == 2 url = f'authorization_uri="{get_random_url()}"' @@ -659,6 +663,6 @@ async def get_token(*_, **kwargs): # Note that no resource or scope is necessarily proovided in a CAE challenge challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' - challenge_response = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) + claims_challenge = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) - await test_with_challenge(challenge_response, claim) + await test_with_challenge(claims_challenge, claim) From c919056dfd52172563187ace6710ac05a180c9ac Mon Sep 17 00:00:00 2001 From: mccoyp Date: Fri, 20 Sep 2024 14:11:30 -0700 Subject: [PATCH 08/25] Handle (in)valid challenge flows --- .../_shared/async_challenge_auth_policy.py | 104 +++++++++++++++++- .../keys/_shared/challenge_auth_policy.py | 69 +++++++++++- 2 files changed, 170 insertions(+), 3 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py index e2874fec25e4..e7018ffd6ec3 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py @@ -16,18 +16,48 @@ from copy import deepcopy import time -from typing import Any, Optional +from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union +from typing_extensions import ParamSpec from urllib.parse import urlparse from azure.core.credentials import AccessToken from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from . import http_challenge_cache as ChallengeCache from .challenge_auth_policy import _enforce_tls, _update_challenge + +P = ParamSpec("P") +T = TypeVar("T") + + +@overload +async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +@overload +async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T: + """If func returns an awaitable, await it. + + :param func: The function to run. + :type func: callable + :param args: The positional arguments to pass to the function. + :type args: list + :rtype: any + :return: The result of the function + """ + result = func(*args, **kwargs) + if isinstance(result, Awaitable): + return await result + return result + + class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): """Policy for handling HTTP authentication challenges. @@ -44,6 +74,76 @@ def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any self._request_copy: Optional[HttpRequest] = None self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token + async def send( + self, request: PipelineRequest[HttpRequest] + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: + """Authorize request with a bearer token and send it to the next policy. + + We implement this method to account for the valid scenario where a Key Vault authentication challenge is + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to + the caller, but we should handle that second challenge as well (and only return any third 401 response). + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + await await_result(self.on_request, request) + response: PipelineResponse[HttpRequest, AsyncHttpResponse] + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + await await_result(self.on_response, request, response) + + if response.http_response.status_code == 401: + return await self.handle_challenge_flow(request, response) + return response + + async def handle_challenge_flow( + self, + request: PipelineRequest[HttpRequest], + response: PipelineResponse[HttpRequest, AsyncHttpResponse], + consecutive_challenge: bool = False, + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: + """Handle the challenge flow of Key Vault and CAE authentication. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. + True if the preceding challenge was a Key Vault challenge; False otherwise. + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + request_authorized = await self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + + # If consecutive_challenge == True, this could be a third consecutive 401 + if response.http_response.status_code == 401 and not consecutive_challenge: + # If the previous challenge wasn't from CAE, we can try this function one more time + challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if challenge and not challenge.claims: + return await self.handle_challenge_flow(request, response, consecutive_challenge=True) + await await_result(self.on_response, request, response) + return response + + async def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py index ddb0a9f9757e..b2e260f20166 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py @@ -23,7 +23,7 @@ from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import BearerTokenCredentialPolicy -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from .http_challenge import HttpChallenge from . import http_challenge_cache as ChallengeCache @@ -73,6 +73,73 @@ def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> self._request_copy: Optional[HttpRequest] = None self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token + def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[HttpRequest, HttpResponse]: + """Authorize request with a bearer token and send it to the next policy. + + We implement this method to account for the valid scenario where a Key Vault authentication challenge is + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to + the caller, but we should handle that second challenge as well (and only return any third 401 response). + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self.on_request(request) + try: + response = self.next.send(request) + except Exception: # pylint:disable=broad-except + self.on_exception(request) + raise + + self.on_response(request, response) + if response.http_response.status_code == 401: + return self.handle_challenge_flow(request, response) + return response + + def handle_challenge_flow( + self, + request: PipelineRequest[HttpRequest], + response: PipelineResponse[HttpRequest, HttpResponse], + consecutive_challenge: bool = False, + ) -> PipelineResponse[HttpRequest, HttpResponse]: + """Handle the challenge flow of Key Vault and CAE authentication. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. + True if the preceding challenge was a Key Vault challenge; False otherwise. + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + request_authorized = self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = self.next.send(request) + except Exception: # pylint:disable=broad-except + self.on_exception(request) + raise + + # If consecutive_challenge == True, this could be a third consecutive 401 + if response.http_response.status_code == 401 and not consecutive_challenge: + # If the previous challenge wasn't from CAE, we can try this function one more time + challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if challenge and not challenge.claims: + return self.handle_challenge_flow(request, response, consecutive_challenge=True) + self.on_response(request, response) + return response + def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) From ff731e8376d11d1822323e331315ad77568b92e8 Mon Sep 17 00:00:00 2001 From: mccoyp Date: Fri, 20 Sep 2024 14:31:09 -0700 Subject: [PATCH 09/25] Share updates across libraries --- .../_internal/async_challenge_auth_policy.py | 125 +++++++++++++++++- .../_internal/challenge_auth_policy.py | 92 ++++++++++++- .../_shared/async_challenge_auth_policy.py | 125 +++++++++++++++++- .../_shared/challenge_auth_policy.py | 92 ++++++++++++- sdk/keyvault/azure-keyvault-keys/setup.py | 2 +- .../_shared/async_challenge_auth_policy.py | 125 +++++++++++++++++- .../secrets/_shared/challenge_auth_policy.py | 92 ++++++++++++- 7 files changed, 619 insertions(+), 34 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py index ddbdaa6122da..e7018ffd6ec3 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py @@ -16,18 +16,48 @@ from copy import deepcopy import time -from typing import Any, Optional +from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union +from typing_extensions import ParamSpec from urllib.parse import urlparse from azure.core.credentials import AccessToken from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from . import http_challenge_cache as ChallengeCache from .challenge_auth_policy import _enforce_tls, _update_challenge + +P = ParamSpec("P") +T = TypeVar("T") + + +@overload +async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +@overload +async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T: + """If func returns an awaitable, await it. + + :param func: The function to run. + :type func: callable + :param args: The positional arguments to pass to the function. + :type args: list + :rtype: any + :return: The result of the function + """ + result = func(*args, **kwargs) + if isinstance(result, Awaitable): + return await result + return result + + class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): """Policy for handling HTTP authentication challenges. @@ -42,6 +72,77 @@ def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None + self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token + + async def send( + self, request: PipelineRequest[HttpRequest] + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: + """Authorize request with a bearer token and send it to the next policy. + + We implement this method to account for the valid scenario where a Key Vault authentication challenge is + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to + the caller, but we should handle that second challenge as well (and only return any third 401 response). + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + await await_result(self.on_request, request) + response: PipelineResponse[HttpRequest, AsyncHttpResponse] + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + await await_result(self.on_response, request, response) + + if response.http_response.status_code == 401: + return await self.handle_challenge_flow(request, response) + return response + + async def handle_challenge_flow( + self, + request: PipelineRequest[HttpRequest], + response: PipelineResponse[HttpRequest, AsyncHttpResponse], + consecutive_challenge: bool = False, + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: + """Handle the challenge flow of Key Vault and CAE authentication. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. + True if the preceding challenge was a Key Vault challenge; False otherwise. + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + request_authorized = await self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + + # If consecutive_challenge == True, this could be a third consecutive 401 + if response.http_response.status_code == 401 and not consecutive_challenge: + # If the previous challenge wasn't from CAE, we can try this function one more time + challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if challenge and not challenge.claims: + return await self.handle_challenge_flow(request, response, consecutive_challenge=True) + await await_result(self.on_response, request, response) + return response + async def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) @@ -78,6 +179,14 @@ async def on_request(self, request: PipelineRequest) -> None: async def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: try: + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary + old_scope: Optional[str] = None + old_tenant: Optional[str] = None + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if cached_challenge: + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" + old_tenant = cached_challenge.tenant_id + challenge = _update_challenge(request, response) # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" @@ -87,7 +196,13 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons if self._verify_challenge_resource: resource_domain = urlparse(scope).netloc if not resource_domain: - raise ValueError(f"The challenge contains invalid scope '{scope}'.") + # Use the old scope for CAE challenges. The parsing will succeed here since it did before + if challenge.claims and old_scope: + resource_domain = urlparse(old_scope).netloc + challenge._parameters["scope"] = old_scope + challenge.tenant_id = old_tenant + else: + raise ValueError(f"The challenge contains invalid scope '{scope}'.") request_domain = urlparse(request.http_request.url).netloc if not request_domain.lower().endswith(f".{resource_domain.lower()}"): @@ -104,10 +219,10 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True) + await self.authorize_request(request, scope, claims=challenge.claims) else: await self.authorize_request( - request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id + request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id ) return True diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py index 8cfaccb37695..b2e260f20166 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py @@ -23,7 +23,7 @@ from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import BearerTokenCredentialPolicy -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from .http_challenge import HttpChallenge from . import http_challenge_cache as ChallengeCache @@ -71,6 +71,74 @@ def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None + self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token + + def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[HttpRequest, HttpResponse]: + """Authorize request with a bearer token and send it to the next policy. + + We implement this method to account for the valid scenario where a Key Vault authentication challenge is + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to + the caller, but we should handle that second challenge as well (and only return any third 401 response). + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self.on_request(request) + try: + response = self.next.send(request) + except Exception: # pylint:disable=broad-except + self.on_exception(request) + raise + + self.on_response(request, response) + if response.http_response.status_code == 401: + return self.handle_challenge_flow(request, response) + return response + + def handle_challenge_flow( + self, + request: PipelineRequest[HttpRequest], + response: PipelineResponse[HttpRequest, HttpResponse], + consecutive_challenge: bool = False, + ) -> PipelineResponse[HttpRequest, HttpResponse]: + """Handle the challenge flow of Key Vault and CAE authentication. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. + True if the preceding challenge was a Key Vault challenge; False otherwise. + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + request_authorized = self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = self.next.send(request) + except Exception: # pylint:disable=broad-except + self.on_exception(request) + raise + + # If consecutive_challenge == True, this could be a third consecutive 401 + if response.http_response.status_code == 401 and not consecutive_challenge: + # If the previous challenge wasn't from CAE, we can try this function one more time + challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if challenge and not challenge.claims: + return self.handle_challenge_flow(request, response, consecutive_challenge=True) + self.on_response(request, response) + return response def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) @@ -106,6 +174,14 @@ def on_request(self, request: PipelineRequest) -> None: def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: try: + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary + old_scope: Optional[str] = None + old_tenant: Optional[str] = None + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if cached_challenge: + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" + old_tenant = cached_challenge.tenant_id + challenge = _update_challenge(request, response) # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" @@ -115,7 +191,13 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> if self._verify_challenge_resource: resource_domain = urlparse(scope).netloc if not resource_domain: - raise ValueError(f"The challenge contains invalid scope '{scope}'.") + # Use the old scope for CAE challenges. The parsing will succeed here since it did before + if challenge.claims and old_scope: + resource_domain = urlparse(old_scope).netloc + challenge._parameters["scope"] = old_scope + challenge.tenant_id = old_tenant + else: + raise ValueError(f"The challenge contains invalid scope '{scope}'.") request_domain = urlparse(request.http_request.url).netloc if not request_domain.lower().endswith(f".{resource_domain.lower()}"): @@ -132,11 +214,9 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True) + self.authorize_request(request, scope, claims=challenge.claims) else: - self.authorize_request( - request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id - ) + self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py index ddbdaa6122da..e7018ffd6ec3 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py @@ -16,18 +16,48 @@ from copy import deepcopy import time -from typing import Any, Optional +from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union +from typing_extensions import ParamSpec from urllib.parse import urlparse from azure.core.credentials import AccessToken from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from . import http_challenge_cache as ChallengeCache from .challenge_auth_policy import _enforce_tls, _update_challenge + +P = ParamSpec("P") +T = TypeVar("T") + + +@overload +async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +@overload +async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T: + """If func returns an awaitable, await it. + + :param func: The function to run. + :type func: callable + :param args: The positional arguments to pass to the function. + :type args: list + :rtype: any + :return: The result of the function + """ + result = func(*args, **kwargs) + if isinstance(result, Awaitable): + return await result + return result + + class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): """Policy for handling HTTP authentication challenges. @@ -42,6 +72,77 @@ def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None + self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token + + async def send( + self, request: PipelineRequest[HttpRequest] + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: + """Authorize request with a bearer token and send it to the next policy. + + We implement this method to account for the valid scenario where a Key Vault authentication challenge is + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to + the caller, but we should handle that second challenge as well (and only return any third 401 response). + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + await await_result(self.on_request, request) + response: PipelineResponse[HttpRequest, AsyncHttpResponse] + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + await await_result(self.on_response, request, response) + + if response.http_response.status_code == 401: + return await self.handle_challenge_flow(request, response) + return response + + async def handle_challenge_flow( + self, + request: PipelineRequest[HttpRequest], + response: PipelineResponse[HttpRequest, AsyncHttpResponse], + consecutive_challenge: bool = False, + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: + """Handle the challenge flow of Key Vault and CAE authentication. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. + True if the preceding challenge was a Key Vault challenge; False otherwise. + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + request_authorized = await self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + + # If consecutive_challenge == True, this could be a third consecutive 401 + if response.http_response.status_code == 401 and not consecutive_challenge: + # If the previous challenge wasn't from CAE, we can try this function one more time + challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if challenge and not challenge.claims: + return await self.handle_challenge_flow(request, response, consecutive_challenge=True) + await await_result(self.on_response, request, response) + return response + async def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) @@ -78,6 +179,14 @@ async def on_request(self, request: PipelineRequest) -> None: async def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: try: + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary + old_scope: Optional[str] = None + old_tenant: Optional[str] = None + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if cached_challenge: + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" + old_tenant = cached_challenge.tenant_id + challenge = _update_challenge(request, response) # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" @@ -87,7 +196,13 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons if self._verify_challenge_resource: resource_domain = urlparse(scope).netloc if not resource_domain: - raise ValueError(f"The challenge contains invalid scope '{scope}'.") + # Use the old scope for CAE challenges. The parsing will succeed here since it did before + if challenge.claims and old_scope: + resource_domain = urlparse(old_scope).netloc + challenge._parameters["scope"] = old_scope + challenge.tenant_id = old_tenant + else: + raise ValueError(f"The challenge contains invalid scope '{scope}'.") request_domain = urlparse(request.http_request.url).netloc if not request_domain.lower().endswith(f".{resource_domain.lower()}"): @@ -104,10 +219,10 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True) + await self.authorize_request(request, scope, claims=challenge.claims) else: await self.authorize_request( - request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id + request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id ) return True diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py index 8cfaccb37695..b2e260f20166 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py @@ -23,7 +23,7 @@ from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import BearerTokenCredentialPolicy -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from .http_challenge import HttpChallenge from . import http_challenge_cache as ChallengeCache @@ -71,6 +71,74 @@ def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None + self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token + + def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[HttpRequest, HttpResponse]: + """Authorize request with a bearer token and send it to the next policy. + + We implement this method to account for the valid scenario where a Key Vault authentication challenge is + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to + the caller, but we should handle that second challenge as well (and only return any third 401 response). + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self.on_request(request) + try: + response = self.next.send(request) + except Exception: # pylint:disable=broad-except + self.on_exception(request) + raise + + self.on_response(request, response) + if response.http_response.status_code == 401: + return self.handle_challenge_flow(request, response) + return response + + def handle_challenge_flow( + self, + request: PipelineRequest[HttpRequest], + response: PipelineResponse[HttpRequest, HttpResponse], + consecutive_challenge: bool = False, + ) -> PipelineResponse[HttpRequest, HttpResponse]: + """Handle the challenge flow of Key Vault and CAE authentication. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. + True if the preceding challenge was a Key Vault challenge; False otherwise. + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + request_authorized = self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = self.next.send(request) + except Exception: # pylint:disable=broad-except + self.on_exception(request) + raise + + # If consecutive_challenge == True, this could be a third consecutive 401 + if response.http_response.status_code == 401 and not consecutive_challenge: + # If the previous challenge wasn't from CAE, we can try this function one more time + challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if challenge and not challenge.claims: + return self.handle_challenge_flow(request, response, consecutive_challenge=True) + self.on_response(request, response) + return response def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) @@ -106,6 +174,14 @@ def on_request(self, request: PipelineRequest) -> None: def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: try: + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary + old_scope: Optional[str] = None + old_tenant: Optional[str] = None + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if cached_challenge: + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" + old_tenant = cached_challenge.tenant_id + challenge = _update_challenge(request, response) # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" @@ -115,7 +191,13 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> if self._verify_challenge_resource: resource_domain = urlparse(scope).netloc if not resource_domain: - raise ValueError(f"The challenge contains invalid scope '{scope}'.") + # Use the old scope for CAE challenges. The parsing will succeed here since it did before + if challenge.claims and old_scope: + resource_domain = urlparse(old_scope).netloc + challenge._parameters["scope"] = old_scope + challenge.tenant_id = old_tenant + else: + raise ValueError(f"The challenge contains invalid scope '{scope}'.") request_domain = urlparse(request.http_request.url).netloc if not request_domain.lower().endswith(f".{resource_domain.lower()}"): @@ -132,11 +214,9 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True) + self.authorize_request(request, scope, claims=challenge.claims) else: - self.authorize_request( - request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id - ) + self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True diff --git a/sdk/keyvault/azure-keyvault-keys/setup.py b/sdk/keyvault/azure-keyvault-keys/setup.py index 7bbe28af42c5..cbb26cf86d49 100644 --- a/sdk/keyvault/azure-keyvault-keys/setup.py +++ b/sdk/keyvault/azure-keyvault-keys/setup.py @@ -68,7 +68,7 @@ ), python_requires=">=3.8", install_requires=[ - "azure-core>=1.31.0", + "azure-core>=1.29.5", "cryptography>=2.1.4", "isodate>=0.6.1", "typing-extensions>=4.0.1", diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py index ddbdaa6122da..e7018ffd6ec3 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py @@ -16,18 +16,48 @@ from copy import deepcopy import time -from typing import Any, Optional +from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union +from typing_extensions import ParamSpec from urllib.parse import urlparse from azure.core.credentials import AccessToken from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy -from azure.core.rest import HttpRequest +from azure.core.rest import AsyncHttpResponse, HttpRequest from . import http_challenge_cache as ChallengeCache from .challenge_auth_policy import _enforce_tls, _update_challenge + +P = ParamSpec("P") +T = TypeVar("T") + + +@overload +async def await_result(func: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +@overload +async def await_result(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +async def await_result(func: Callable[P, Union[T, Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T: + """If func returns an awaitable, await it. + + :param func: The function to run. + :type func: callable + :param args: The positional arguments to pass to the function. + :type args: list + :rtype: any + :return: The result of the function + """ + result = func(*args, **kwargs) + if isinstance(result, Awaitable): + return await result + return result + + class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): """Policy for handling HTTP authentication challenges. @@ -42,6 +72,77 @@ def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None + self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token + + async def send( + self, request: PipelineRequest[HttpRequest] + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: + """Authorize request with a bearer token and send it to the next policy. + + We implement this method to account for the valid scenario where a Key Vault authentication challenge is + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to + the caller, but we should handle that second challenge as well (and only return any third 401 response). + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + await await_result(self.on_request, request) + response: PipelineResponse[HttpRequest, AsyncHttpResponse] + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + await await_result(self.on_response, request, response) + + if response.http_response.status_code == 401: + return await self.handle_challenge_flow(request, response) + return response + + async def handle_challenge_flow( + self, + request: PipelineRequest[HttpRequest], + response: PipelineResponse[HttpRequest, AsyncHttpResponse], + consecutive_challenge: bool = False, + ) -> PipelineResponse[HttpRequest, AsyncHttpResponse]: + """Handle the challenge flow of Key Vault and CAE authentication. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. + True if the preceding challenge was a Key Vault challenge; False otherwise. + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + request_authorized = await self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = await self.next.send(request) + except Exception: # pylint:disable=broad-except + await await_result(self.on_exception, request) + raise + + # If consecutive_challenge == True, this could be a third consecutive 401 + if response.http_response.status_code == 401 and not consecutive_challenge: + # If the previous challenge wasn't from CAE, we can try this function one more time + challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if challenge and not challenge.claims: + return await self.handle_challenge_flow(request, response, consecutive_challenge=True) + await await_result(self.on_response, request, response) + return response + async def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) @@ -78,6 +179,14 @@ async def on_request(self, request: PipelineRequest) -> None: async def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: try: + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary + old_scope: Optional[str] = None + old_tenant: Optional[str] = None + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if cached_challenge: + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" + old_tenant = cached_challenge.tenant_id + challenge = _update_challenge(request, response) # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" @@ -87,7 +196,13 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons if self._verify_challenge_resource: resource_domain = urlparse(scope).netloc if not resource_domain: - raise ValueError(f"The challenge contains invalid scope '{scope}'.") + # Use the old scope for CAE challenges. The parsing will succeed here since it did before + if challenge.claims and old_scope: + resource_domain = urlparse(old_scope).netloc + challenge._parameters["scope"] = old_scope + challenge.tenant_id = old_tenant + else: + raise ValueError(f"The challenge contains invalid scope '{scope}'.") request_domain = urlparse(request.http_request.url).netloc if not request_domain.lower().endswith(f".{resource_domain.lower()}"): @@ -104,10 +219,10 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - await self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True) + await self.authorize_request(request, scope, claims=challenge.claims) else: await self.authorize_request( - request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id + request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id ) return True diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py index 8cfaccb37695..b2e260f20166 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py @@ -23,7 +23,7 @@ from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import BearerTokenCredentialPolicy -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, HttpResponse from .http_challenge import HttpChallenge from . import http_challenge_cache as ChallengeCache @@ -71,6 +71,74 @@ def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None + self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token + + def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[HttpRequest, HttpResponse]: + """Authorize request with a bearer token and send it to the next policy. + + We implement this method to account for the valid scenario where a Key Vault authentication challenge is + immediately followed by a CAE claims challenge. The base class's implementation would return the second 401 to + the caller, but we should handle that second challenge as well (and only return any third 401 response). + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self.on_request(request) + try: + response = self.next.send(request) + except Exception: # pylint:disable=broad-except + self.on_exception(request) + raise + + self.on_response(request, response) + if response.http_response.status_code == 401: + return self.handle_challenge_flow(request, response) + return response + + def handle_challenge_flow( + self, + request: PipelineRequest[HttpRequest], + response: PipelineResponse[HttpRequest, HttpResponse], + consecutive_challenge: bool = False, + ) -> PipelineResponse[HttpRequest, HttpResponse]: + """Handle the challenge flow of Key Vault and CAE authentication. + + :param request: The pipeline request object + :type request: ~azure.core.pipeline.PipelineRequest + :param response: The pipeline response object + :type response: ~azure.core.pipeline.PipelineResponse + :param bool consecutive_challenge: Whether the challenge is arriving immediately after another challenge. + Consecutive challenges can only be valid if a Key Vault challenge is followed by a CAE claims challenge. + True if the preceding challenge was a Key Vault challenge; False otherwise. + + :return: The pipeline response object + :rtype: ~azure.core.pipeline.PipelineResponse + """ + self._token = None # any cached token is invalid + if "WWW-Authenticate" in response.http_response.headers: + request_authorized = self.on_challenge(request, response) + if request_authorized: + # if we receive a challenge response, we retrieve a new token + # which matches the new target. In this case, we don't want to remove + # token from the request so clear the 'insecure_domain_change' tag + request.context.options.pop("insecure_domain_change", False) + try: + response = self.next.send(request) + except Exception: # pylint:disable=broad-except + self.on_exception(request) + raise + + # If consecutive_challenge == True, this could be a third consecutive 401 + if response.http_response.status_code == 401 and not consecutive_challenge: + # If the previous challenge wasn't from CAE, we can try this function one more time + challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if challenge and not challenge.claims: + return self.handle_challenge_flow(request, response, consecutive_challenge=True) + self.on_response(request, response) + return response def on_request(self, request: PipelineRequest) -> None: _enforce_tls(request) @@ -106,6 +174,14 @@ def on_request(self, request: PipelineRequest) -> None: def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> bool: try: + # CAE challenges may not include a scope or tenant; cache from the previous challenge to use if necessary + old_scope: Optional[str] = None + old_tenant: Optional[str] = None + cached_challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) + if cached_challenge: + old_scope = cached_challenge.get_scope() or cached_challenge.get_resource() + "/.default" + old_tenant = cached_challenge.tenant_id + challenge = _update_challenge(request, response) # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" @@ -115,7 +191,13 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> if self._verify_challenge_resource: resource_domain = urlparse(scope).netloc if not resource_domain: - raise ValueError(f"The challenge contains invalid scope '{scope}'.") + # Use the old scope for CAE challenges. The parsing will succeed here since it did before + if challenge.claims and old_scope: + resource_domain = urlparse(old_scope).netloc + challenge._parameters["scope"] = old_scope + challenge.tenant_id = old_tenant + else: + raise ValueError(f"The challenge contains invalid scope '{scope}'.") request_domain = urlparse(request.http_request.url).netloc if not request_domain.lower().endswith(f".{resource_domain.lower()}"): @@ -132,11 +214,9 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # The tenant parsed from AD FS challenges is "adfs"; we don't actually need a tenant for AD FS authentication # For AD FS we skip cross-tenant authentication per https://github.com/Azure/azure-sdk-for-python/issues/28648 if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self.authorize_request(request, scope, claims=challenge.claims, enable_cae=True) + self.authorize_request(request, scope, claims=challenge.claims) else: - self.authorize_request( - request, scope, claims=challenge.claims, enable_cae=True, tenant_id=challenge.tenant_id - ) + self.authorize_request(request, scope, claims=challenge.claims, tenant_id=challenge.tenant_id) return True From 237c57b1d626049c3880705785e3828a156671fa Mon Sep 17 00:00:00 2001 From: mccoyp Date: Fri, 20 Sep 2024 16:07:31 -0700 Subject: [PATCH 10/25] Fix spelling, pylint --- .../administration/_internal/async_challenge_auth_policy.py | 5 +++-- .../administration/_internal/challenge_auth_policy.py | 2 +- .../certificates/_shared/async_challenge_auth_policy.py | 5 +++-- .../keyvault/certificates/_shared/challenge_auth_policy.py | 2 +- .../keyvault/keys/_shared/async_challenge_auth_policy.py | 5 +++-- .../azure/keyvault/keys/_shared/challenge_auth_policy.py | 2 +- .../azure-keyvault-keys/tests/test_challenge_auth.py | 4 ++-- .../azure-keyvault-keys/tests/test_challenge_auth_async.py | 4 ++-- .../keyvault/secrets/_shared/async_challenge_auth_policy.py | 5 +++-- .../azure/keyvault/secrets/_shared/challenge_auth_policy.py | 2 +- 10 files changed, 20 insertions(+), 16 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py index e7018ffd6ec3..ec303cdc6756 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py @@ -17,9 +17,10 @@ from copy import deepcopy import time from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union -from typing_extensions import ParamSpec from urllib.parse import urlparse +from typing_extensions import ParamSpec + from azure.core.credentials import AccessToken from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline import PipelineRequest, PipelineResponse @@ -199,7 +200,7 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # Use the old scope for CAE challenges. The parsing will succeed here since it did before if challenge.claims and old_scope: resource_domain = urlparse(old_scope).netloc - challenge._parameters["scope"] = old_scope + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access challenge.tenant_id = old_tenant else: raise ValueError(f"The challenge contains invalid scope '{scope}'.") diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py index b2e260f20166..9edd1555865b 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py @@ -194,7 +194,7 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # Use the old scope for CAE challenges. The parsing will succeed here since it did before if challenge.claims and old_scope: resource_domain = urlparse(old_scope).netloc - challenge._parameters["scope"] = old_scope + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access challenge.tenant_id = old_tenant else: raise ValueError(f"The challenge contains invalid scope '{scope}'.") diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py index e7018ffd6ec3..ec303cdc6756 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py @@ -17,9 +17,10 @@ from copy import deepcopy import time from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union -from typing_extensions import ParamSpec from urllib.parse import urlparse +from typing_extensions import ParamSpec + from azure.core.credentials import AccessToken from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline import PipelineRequest, PipelineResponse @@ -199,7 +200,7 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # Use the old scope for CAE challenges. The parsing will succeed here since it did before if challenge.claims and old_scope: resource_domain = urlparse(old_scope).netloc - challenge._parameters["scope"] = old_scope + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access challenge.tenant_id = old_tenant else: raise ValueError(f"The challenge contains invalid scope '{scope}'.") diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py index b2e260f20166..9edd1555865b 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py @@ -194,7 +194,7 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # Use the old scope for CAE challenges. The parsing will succeed here since it did before if challenge.claims and old_scope: resource_domain = urlparse(old_scope).netloc - challenge._parameters["scope"] = old_scope + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access challenge.tenant_id = old_tenant else: raise ValueError(f"The challenge contains invalid scope '{scope}'.") diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py index e7018ffd6ec3..ec303cdc6756 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py @@ -17,9 +17,10 @@ from copy import deepcopy import time from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union -from typing_extensions import ParamSpec from urllib.parse import urlparse +from typing_extensions import ParamSpec + from azure.core.credentials import AccessToken from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline import PipelineRequest, PipelineResponse @@ -199,7 +200,7 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # Use the old scope for CAE challenges. The parsing will succeed here since it did before if challenge.claims and old_scope: resource_domain = urlparse(old_scope).netloc - challenge._parameters["scope"] = old_scope + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access challenge.tenant_id = old_tenant else: raise ValueError(f"The challenge contains invalid scope '{scope}'.") diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py index b2e260f20166..9edd1555865b 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py @@ -194,7 +194,7 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # Use the old scope for CAE challenges. The parsing will succeed here since it did before if challenge.claims and old_scope: resource_domain = urlparse(old_scope).netloc - challenge._parameters["scope"] = old_scope + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access challenge.tenant_id = old_tenant else: raise ValueError(f"The challenge contains invalid scope '{scope}'.") diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py index a55f0d977c3e..e56573568a74 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py @@ -615,7 +615,7 @@ def get_token(*_, **kwargs): claim = '{"access_token": {"foo": "bar"}}' # Claim token is a string of the base64 encoding of the claim claim_token = base64.b64encode(claim.encode()).decode() - # Note that no resource or scope is necessarily proovided in a CAE challenge + # Note that no resource or scope is necessarily provided in a CAE challenge challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' claims_challenge = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) @@ -693,7 +693,7 @@ def get_token(*_, **kwargs): claim = '{"access_token": {"foo": "bar"}}' # Claim token is a string of the base64 encoding of the claim claim_token = base64.b64encode(claim.encode()).decode() - # Note that no resource or scope is necessarily proovided in a CAE challenge + # Note that no resource or scope is necessarily provided in a CAE challenge challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' claims_challenge = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py index cf7936b9470e..a13b2b675e18 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py @@ -580,7 +580,7 @@ async def get_token(*_, **kwargs): claim = '{"access_token": {"foo": "bar"}}' # Claim token is a string of the base64 encoding of the claim claim_token = base64.b64encode(claim.encode()).decode() - # Note that no resource or scope is necessarily proovided in a CAE challenge + # Note that no resource or scope is necessarily provided in a CAE challenge challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' # get_token is called for the first KV challenge and CAE challenge, but not the second KV challenge @@ -660,7 +660,7 @@ async def get_token(*_, **kwargs): claim = '{"access_token": {"foo": "bar"}}' # Claim token is a string of the base64 encoding of the claim claim_token = base64.b64encode(claim.encode()).decode() - # Note that no resource or scope is necessarily proovided in a CAE challenge + # Note that no resource or scope is necessarily provided in a CAE challenge challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' claims_challenge = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py index e7018ffd6ec3..ec303cdc6756 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py @@ -17,9 +17,10 @@ from copy import deepcopy import time from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union -from typing_extensions import ParamSpec from urllib.parse import urlparse +from typing_extensions import ParamSpec + from azure.core.credentials import AccessToken from azure.core.credentials_async import AsyncTokenCredential from azure.core.pipeline import PipelineRequest, PipelineResponse @@ -199,7 +200,7 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons # Use the old scope for CAE challenges. The parsing will succeed here since it did before if challenge.claims and old_scope: resource_domain = urlparse(old_scope).netloc - challenge._parameters["scope"] = old_scope + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access challenge.tenant_id = old_tenant else: raise ValueError(f"The challenge contains invalid scope '{scope}'.") diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py index b2e260f20166..9edd1555865b 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py @@ -194,7 +194,7 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> # Use the old scope for CAE challenges. The parsing will succeed here since it did before if challenge.claims and old_scope: resource_domain = urlparse(old_scope).netloc - challenge._parameters["scope"] = old_scope + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access challenge.tenant_id = old_tenant else: raise ValueError(f"The challenge contains invalid scope '{scope}'.") From 013673b7a4d394d879ffab79f80af33e746d5972 Mon Sep 17 00:00:00 2001 From: mccoyp Date: Wed, 25 Sep 2024 17:06:34 -0700 Subject: [PATCH 11/25] Update changelogs --- sdk/keyvault/azure-keyvault-administration/CHANGELOG.md | 1 + sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md | 1 + sdk/keyvault/azure-keyvault-keys/CHANGELOG.md | 1 + sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md | 1 + 4 files changed, 4 insertions(+) diff --git a/sdk/keyvault/azure-keyvault-administration/CHANGELOG.md b/sdk/keyvault/azure-keyvault-administration/CHANGELOG.md index 340928d8759d..5fbdb54852ae 100644 --- a/sdk/keyvault/azure-keyvault-administration/CHANGELOG.md +++ b/sdk/keyvault/azure-keyvault-administration/CHANGELOG.md @@ -4,6 +4,7 @@ ### Features Added - Added support for service API version `7.6-preview.1` +- Added support for Continuous Access Evaluation (CAE) ### Breaking Changes diff --git a/sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md b/sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md index 086f14d6e349..130535c01007 100644 --- a/sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md +++ b/sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md @@ -4,6 +4,7 @@ ### Features Added - Added support for service API version `7.6-preview.1` +- Added support for Continuous Access Evaluation (CAE) ### Breaking Changes diff --git a/sdk/keyvault/azure-keyvault-keys/CHANGELOG.md b/sdk/keyvault/azure-keyvault-keys/CHANGELOG.md index 4c7444151bf1..6cb678d30d13 100644 --- a/sdk/keyvault/azure-keyvault-keys/CHANGELOG.md +++ b/sdk/keyvault/azure-keyvault-keys/CHANGELOG.md @@ -4,6 +4,7 @@ ### Features Added - Added support for service API version `7.6-preview.1` +- Added support for Continuous Access Evaluation (CAE) ### Breaking Changes diff --git a/sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md b/sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md index a366b4af33a0..2a886c7aedce 100644 --- a/sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md +++ b/sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md @@ -4,6 +4,7 @@ ### Features Added - Added support for service API version `7.6-preview.1` +- Added support for Continuous Access Evaluation (CAE) ### Breaking Changes From 5da13ff90437ed213234321b34d0e1f54098f2c9 Mon Sep 17 00:00:00 2001 From: mccoyp Date: Wed, 25 Sep 2024 17:07:38 -0700 Subject: [PATCH 12/25] Update tests for feedback --- .../tests/test_challenge_auth.py | 21 +++++++++++-------- .../tests/test_challenge_auth_async.py | 21 +++++++++++-------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py index e56573568a74..f92d440465e2 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py @@ -564,6 +564,7 @@ def send(request): if Requests.count == 1: # first request should be unauthorized and have no content; triggers a KV challenge response assert not request.body + assert "Authorization" not in request.headers assert request.headers["Content-Length"] == "0" return kv_challenge elif Requests.count == 2: @@ -588,15 +589,15 @@ def send(request): raise ValueError("unexpected request") def get_token(*_, **kwargs): + assert kwargs.get("enable_cae") == True + assert kwargs.get("tenant_id") == tenant + # Response to KV challenge if Requests.count == 1: assert kwargs.get("claims") == None - assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == tenant return AccessToken(first_token, time.time() + 3600) + # Response to CAE challenge elif Requests.count == 3: assert kwargs.get("claims") == expected_claim - assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == tenant return AccessToken(expected_token, time.time() + 3600) credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) @@ -613,8 +614,9 @@ def get_token(*_, **kwargs): cid = 'client_id="00000003-0000-0000-c000-000000000000"' err = 'error="insufficient_claims"' claim = '{"access_token": {"foo": "bar"}}' - # Claim token is a string of the base64 encoding of the claim + # Claim token is a string of the base64 encoding of the claim. Trim the padding to ensure the policy can handle it claim_token = base64.b64encode(claim.encode()).decode() + claim_token = claim_token.strip("=") # Note that no resource or scope is necessarily provided in a CAE challenge challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' @@ -649,6 +651,7 @@ def send(request): if Requests.count == 1: # first request should be unauthorized and have no content; triggers a KV challenge response assert not request.body + assert "Authorization" not in request.headers assert request.headers["Content-Length"] == "0" return kv_challenge elif Requests.count == 2: @@ -667,15 +670,15 @@ def send(request): raise ValueError("unexpected request") def get_token(*_, **kwargs): + assert kwargs.get("enable_cae") == True + assert kwargs.get("tenant_id") == tenant + # Response to KV challenge if Requests.count == 1: assert kwargs.get("claims") == None - assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == tenant return AccessToken(first_token, time.time() + 3600) + # Response to first CAE challenge elif Requests.count == 2: assert kwargs.get("claims") == expected_claim - assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == tenant return AccessToken(expected_token, time.time() + 3600) credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py index a13b2b675e18..f76496633d0d 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py @@ -530,6 +530,7 @@ async def send(request): if Requests.count == 1: # first request should be unauthorized and have no content; triggers a KV challenge response assert not request.body + assert "Authorization" not in request.headers assert request.headers["Content-Length"] == "0" return kv_challenge elif Requests.count == 2: @@ -554,15 +555,15 @@ async def send(request): raise ValueError("unexpected request") async def get_token(*_, **kwargs): + assert kwargs.get("enable_cae") == True + assert kwargs.get("tenant_id") == tenant + # Response to KV challenge if Requests.count == 1: assert kwargs.get("claims") == None - assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == tenant return AccessToken(first_token, time.time() + 3600) + # Response to CAE challenge elif Requests.count == 3: assert kwargs.get("claims") == expected_claim - assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == tenant return AccessToken(expected_token, time.time() + 3600) credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) @@ -578,8 +579,9 @@ async def get_token(*_, **kwargs): cid = 'client_id="00000003-0000-0000-c000-000000000000"' err = 'error="insufficient_claims"' claim = '{"access_token": {"foo": "bar"}}' - # Claim token is a string of the base64 encoding of the claim + # Claim token is a string of the base64 encoding of the claim. Trim the padding to ensure the policy can handle it claim_token = base64.b64encode(claim.encode()).decode() + claim_token = claim_token.strip("=") # Note that no resource or scope is necessarily provided in a CAE challenge challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' @@ -616,6 +618,7 @@ async def send(request): if Requests.count == 1: # first request should be unauthorized and have no content; triggers a KV challenge response assert not request.body + assert "Authorization" not in request.headers assert request.headers["Content-Length"] == "0" return kv_challenge elif Requests.count == 2: @@ -634,15 +637,15 @@ async def send(request): raise ValueError("unexpected request") async def get_token(*_, **kwargs): + assert kwargs.get("enable_cae") == True + assert kwargs.get("tenant_id") == tenant + # Response to KV challenge if Requests.count == 1: assert kwargs.get("claims") == None - assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == tenant return AccessToken(first_token, time.time() + 3600) + # Response to first CAE challenge elif Requests.count == 2: assert kwargs.get("claims") == expected_claim - assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == tenant return AccessToken(expected_token, time.time() + 3600) credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) From 36cb9fdd6548a65f05c2865560652853d43b4d27 Mon Sep 17 00:00:00 2001 From: mccoyp Date: Wed, 25 Sep 2024 17:07:59 -0700 Subject: [PATCH 13/25] Use super() instead of private attribute --- .../administration/_internal/async_challenge_auth_policy.py | 4 ++-- .../administration/_internal/challenge_auth_policy.py | 4 ++-- .../certificates/_shared/async_challenge_auth_policy.py | 4 ++-- .../keyvault/certificates/_shared/challenge_auth_policy.py | 4 ++-- .../keyvault/keys/_shared/async_challenge_auth_policy.py | 4 ++-- .../azure/keyvault/keys/_shared/challenge_auth_policy.py | 4 ++-- .../keyvault/secrets/_shared/async_challenge_auth_policy.py | 4 ++-- .../azure/keyvault/secrets/_shared/challenge_auth_policy.py | 4 ++-- 8 files changed, 16 insertions(+), 16 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py index ec303cdc6756..5484c3b773b9 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py @@ -68,12 +68,12 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): """ def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: - super().__init__(credential, *scopes, **kwargs) + # Pass `enable_cae` so `enable_cae=True` is always passed to get_token + super().__init__(credential, *scopes, enable_cae=True, **kwargs) self._credential: AsyncTokenCredential = credential self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None - self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token async def send( self, request: PipelineRequest[HttpRequest] diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py index 9edd1555865b..d62307d58060 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py @@ -66,12 +66,12 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): """ def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: - super(ChallengeAuthPolicy, self).__init__(credential, *scopes, **kwargs) + # Pass `enable_cae` so `enable_cae=True` is always passed to get_token + super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs) self._credential: TokenCredential = credential self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None - self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[HttpRequest, HttpResponse]: """Authorize request with a bearer token and send it to the next policy. diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py index ec303cdc6756..5484c3b773b9 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py @@ -68,12 +68,12 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): """ def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: - super().__init__(credential, *scopes, **kwargs) + # Pass `enable_cae` so `enable_cae=True` is always passed to get_token + super().__init__(credential, *scopes, enable_cae=True, **kwargs) self._credential: AsyncTokenCredential = credential self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None - self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token async def send( self, request: PipelineRequest[HttpRequest] diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py index 9edd1555865b..d62307d58060 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py @@ -66,12 +66,12 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): """ def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: - super(ChallengeAuthPolicy, self).__init__(credential, *scopes, **kwargs) + # Pass `enable_cae` so `enable_cae=True` is always passed to get_token + super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs) self._credential: TokenCredential = credential self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None - self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[HttpRequest, HttpResponse]: """Authorize request with a bearer token and send it to the next policy. diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py index ec303cdc6756..5484c3b773b9 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py @@ -68,12 +68,12 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): """ def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: - super().__init__(credential, *scopes, **kwargs) + # Pass `enable_cae` so `enable_cae=True` is always passed to get_token + super().__init__(credential, *scopes, enable_cae=True, **kwargs) self._credential: AsyncTokenCredential = credential self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None - self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token async def send( self, request: PipelineRequest[HttpRequest] diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py index 9edd1555865b..d62307d58060 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py @@ -66,12 +66,12 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): """ def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: - super(ChallengeAuthPolicy, self).__init__(credential, *scopes, **kwargs) + # Pass `enable_cae` so `enable_cae=True` is always passed to get_token + super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs) self._credential: TokenCredential = credential self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None - self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[HttpRequest, HttpResponse]: """Authorize request with a bearer token and send it to the next policy. diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py index ec303cdc6756..5484c3b773b9 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py @@ -68,12 +68,12 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): """ def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: - super().__init__(credential, *scopes, **kwargs) + # Pass `enable_cae` so `enable_cae=True` is always passed to get_token + super().__init__(credential, *scopes, enable_cae=True, **kwargs) self._credential: AsyncTokenCredential = credential self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None - self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token async def send( self, request: PipelineRequest[HttpRequest] diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py index 9edd1555865b..d62307d58060 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py @@ -66,12 +66,12 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): """ def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: - super(ChallengeAuthPolicy, self).__init__(credential, *scopes, **kwargs) + # Pass `enable_cae` so `enable_cae=True` is always passed to get_token + super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs) self._credential: TokenCredential = credential self._token: Optional[AccessToken] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None - self._enable_cae = kwargs.pop("enable_cae", True) # When True, `enable_cae=True` is always passed to get_token def send(self, request: PipelineRequest[HttpRequest]) -> PipelineResponse[HttpRequest, HttpResponse]: """Authorize request with a bearer token and send it to the next policy. From f9ff1765c092bf29ad593e338d6db94ba9b4e20f Mon Sep 17 00:00:00 2001 From: mccoyp Date: Thu, 26 Sep 2024 16:32:34 -0700 Subject: [PATCH 14/25] Add live test; assert scope --- .../tests/test_challenge_auth.py | 36 ++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py index f92d440465e2..b2588d0b01a9 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py @@ -17,8 +17,8 @@ from devtools_testutils import recorded_by_proxy import pytest -from azure.core.credentials import AccessToken -from azure.core.exceptions import ServiceRequestError +from azure.core.credentials import AccessToken, TokenCredential +from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError from azure.core.pipeline import Pipeline from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.core.rest import HttpRequest @@ -69,6 +69,32 @@ def test_multitenant_authentication(self, client, is_hsm, **kwargs): else: os.environ.pop("AZURE_TENANT_ID") + @pytest.mark.skip("Manual test for specific, CAE-enabled environments.") + @pytest.mark.live_test_only + def test_cae_live(self, **kwargs): + class CredentialWrapper(TokenCredential): + def __init__(self, credential): + self._credential = credential + self._claims = None + + def get_token(self, *scopes, **kwargs): + assert kwargs["enable_cae"] == True + if kwargs.get("claims"): + # We should only receive claims once; subsequent challenges should be returned to the caller + assert self._claims is None + self._claims = kwargs["claims"] + return self._credential.get_token(*scopes, **kwargs) + + credential = self.get_credential(KeyClient) + wrapped = CredentialWrapper(credential) + client = KeyClient(vault_url=os.environ["AZURE_KEYVAULT_URL"], credential=wrapped) + try: + client.create_rsa_key("key-name") # Basic request meant to just trigger CAE challenges + # Test environment may continuously return claims challenges; a second consecutive challenge will raise + except ClientAuthenticationError as e: + assert "continuous access evaluation" in str(e).lower() + assert wrapped._claims is not None # Ensure we passed a claim to a token request + def empty_challenge_cache(fn): @functools.wraps(fn) def wrapper(**kwargs): @@ -588,9 +614,10 @@ def send(request): return kv_challenge raise ValueError("unexpected request") - def get_token(*_, **kwargs): + def get_token(*scopes, **kwargs): assert kwargs.get("enable_cae") == True assert kwargs.get("tenant_id") == tenant + assert scopes[0] == resource + "/.default" # Response to KV challenge if Requests.count == 1: assert kwargs.get("claims") == None @@ -669,9 +696,10 @@ def send(request): return claims_challenge raise ValueError("unexpected request") - def get_token(*_, **kwargs): + def get_token(*scopes, **kwargs): assert kwargs.get("enable_cae") == True assert kwargs.get("tenant_id") == tenant + assert scopes[0] == resource + "/.default" # Response to KV challenge if Requests.count == 1: assert kwargs.get("claims") == None From bf8f054819c5bceef6cd1b91dc886189be4c498a Mon Sep 17 00:00:00 2001 From: mccoyp Date: Thu, 26 Sep 2024 16:33:01 -0700 Subject: [PATCH 15/25] Fix auth policy to send scope correctly --- .../keyvault/keys/_shared/challenge_auth_policy.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py index d62307d58060..58d37de9ee73 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py @@ -183,6 +183,10 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> old_tenant = cached_challenge.tenant_id challenge = _update_challenge(request, response) + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary + if challenge.claims and old_scope: + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access + challenge.tenant_id = old_tenant # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" except ValueError: @@ -191,13 +195,7 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> if self._verify_challenge_resource: resource_domain = urlparse(scope).netloc if not resource_domain: - # Use the old scope for CAE challenges. The parsing will succeed here since it did before - if challenge.claims and old_scope: - resource_domain = urlparse(old_scope).netloc - challenge._parameters["scope"] = old_scope # pylint:disable=protected-access - challenge.tenant_id = old_tenant - else: - raise ValueError(f"The challenge contains invalid scope '{scope}'.") + raise ValueError(f"The challenge contains invalid scope '{scope}'.") request_domain = urlparse(request.http_request.url).netloc if not request_domain.lower().endswith(f".{resource_domain.lower()}"): From 850e6e83b05fdb8124ba6ff6e84317961a08fc2a Mon Sep 17 00:00:00 2001 From: mccoyp Date: Thu, 26 Sep 2024 16:55:07 -0700 Subject: [PATCH 16/25] Async tests; sync challenge policy code --- .../_internal/async_challenge_auth_policy.py | 12 +++---- .../_internal/challenge_auth_policy.py | 12 +++---- .../_internal/http_challenge.py | 7 +--- .../_shared/async_challenge_auth_policy.py | 12 +++---- .../_shared/challenge_auth_policy.py | 12 +++---- .../certificates/_shared/http_challenge.py | 7 +--- .../_shared/async_challenge_auth_policy.py | 12 +++---- .../keyvault/keys/_shared/http_challenge.py | 7 +--- .../tests/test_challenge_auth_async.py | 36 +++++++++++++++++-- .../_shared/async_challenge_auth_policy.py | 12 +++---- .../secrets/_shared/challenge_auth_policy.py | 12 +++---- .../secrets/_shared/http_challenge.py | 7 +--- 12 files changed, 72 insertions(+), 76 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py index 5484c3b773b9..b6d390aa4e0c 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py @@ -189,6 +189,10 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons old_tenant = cached_challenge.tenant_id challenge = _update_challenge(request, response) + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary + if challenge.claims and old_scope: + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access + challenge.tenant_id = old_tenant # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" except ValueError: @@ -197,13 +201,7 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons if self._verify_challenge_resource: resource_domain = urlparse(scope).netloc if not resource_domain: - # Use the old scope for CAE challenges. The parsing will succeed here since it did before - if challenge.claims and old_scope: - resource_domain = urlparse(old_scope).netloc - challenge._parameters["scope"] = old_scope # pylint:disable=protected-access - challenge.tenant_id = old_tenant - else: - raise ValueError(f"The challenge contains invalid scope '{scope}'.") + raise ValueError(f"The challenge contains invalid scope '{scope}'.") request_domain = urlparse(request.http_request.url).netloc if not request_domain.lower().endswith(f".{resource_domain.lower()}"): diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py index d62307d58060..58d37de9ee73 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py @@ -183,6 +183,10 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> old_tenant = cached_challenge.tenant_id challenge = _update_challenge(request, response) + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary + if challenge.claims and old_scope: + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access + challenge.tenant_id = old_tenant # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" except ValueError: @@ -191,13 +195,7 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> if self._verify_challenge_resource: resource_domain = urlparse(scope).netloc if not resource_domain: - # Use the old scope for CAE challenges. The parsing will succeed here since it did before - if challenge.claims and old_scope: - resource_domain = urlparse(old_scope).netloc - challenge._parameters["scope"] = old_scope # pylint:disable=protected-access - challenge.tenant_id = old_tenant - else: - raise ValueError(f"The challenge contains invalid scope '{scope}'.") + raise ValueError(f"The challenge contains invalid scope '{scope}'.") request_domain = urlparse(request.http_request.url).netloc if not request_domain.lower().endswith(f".{resource_domain.lower()}"): diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/http_challenge.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/http_challenge.py index 39a75ed3cd03..0320df5a868b 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/http_challenge.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/http_challenge.py @@ -37,16 +37,11 @@ def __init__( trimmed_challenge = split_challenge[1] self.claims = None - encoded_claims = None # split trimmed challenge into comma-separated name=value pairs. Values are expected # to be surrounded by quotes which are stripped here. for item in trimmed_challenge.split(","): - # special case for claims, which can contain = symbols as padding + # Special case for claims, which can contain = symbols as padding. Assume at most one claim per challenge if "claims=" in item: - if encoded_claims: - # multiple claims challenges, e.g. for cross-tenant auth, would require special handling - # we can't support this scenario for now, so we ignore claims altogether if there are multiple - self.claims = None encoded_claims = item[item.index("=") + 1 :].strip(" \"'") padding_needed = -len(encoded_claims) % 4 try: diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py index 5484c3b773b9..b6d390aa4e0c 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py @@ -189,6 +189,10 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons old_tenant = cached_challenge.tenant_id challenge = _update_challenge(request, response) + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary + if challenge.claims and old_scope: + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access + challenge.tenant_id = old_tenant # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" except ValueError: @@ -197,13 +201,7 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons if self._verify_challenge_resource: resource_domain = urlparse(scope).netloc if not resource_domain: - # Use the old scope for CAE challenges. The parsing will succeed here since it did before - if challenge.claims and old_scope: - resource_domain = urlparse(old_scope).netloc - challenge._parameters["scope"] = old_scope # pylint:disable=protected-access - challenge.tenant_id = old_tenant - else: - raise ValueError(f"The challenge contains invalid scope '{scope}'.") + raise ValueError(f"The challenge contains invalid scope '{scope}'.") request_domain = urlparse(request.http_request.url).netloc if not request_domain.lower().endswith(f".{resource_domain.lower()}"): diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py index d62307d58060..58d37de9ee73 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py @@ -183,6 +183,10 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> old_tenant = cached_challenge.tenant_id challenge = _update_challenge(request, response) + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary + if challenge.claims and old_scope: + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access + challenge.tenant_id = old_tenant # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" except ValueError: @@ -191,13 +195,7 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> if self._verify_challenge_resource: resource_domain = urlparse(scope).netloc if not resource_domain: - # Use the old scope for CAE challenges. The parsing will succeed here since it did before - if challenge.claims and old_scope: - resource_domain = urlparse(old_scope).netloc - challenge._parameters["scope"] = old_scope # pylint:disable=protected-access - challenge.tenant_id = old_tenant - else: - raise ValueError(f"The challenge contains invalid scope '{scope}'.") + raise ValueError(f"The challenge contains invalid scope '{scope}'.") request_domain = urlparse(request.http_request.url).netloc if not request_domain.lower().endswith(f".{resource_domain.lower()}"): diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/http_challenge.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/http_challenge.py index 39a75ed3cd03..0320df5a868b 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/http_challenge.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/http_challenge.py @@ -37,16 +37,11 @@ def __init__( trimmed_challenge = split_challenge[1] self.claims = None - encoded_claims = None # split trimmed challenge into comma-separated name=value pairs. Values are expected # to be surrounded by quotes which are stripped here. for item in trimmed_challenge.split(","): - # special case for claims, which can contain = symbols as padding + # Special case for claims, which can contain = symbols as padding. Assume at most one claim per challenge if "claims=" in item: - if encoded_claims: - # multiple claims challenges, e.g. for cross-tenant auth, would require special handling - # we can't support this scenario for now, so we ignore claims altogether if there are multiple - self.claims = None encoded_claims = item[item.index("=") + 1 :].strip(" \"'") padding_needed = -len(encoded_claims) % 4 try: diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py index 5484c3b773b9..b6d390aa4e0c 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py @@ -189,6 +189,10 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons old_tenant = cached_challenge.tenant_id challenge = _update_challenge(request, response) + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary + if challenge.claims and old_scope: + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access + challenge.tenant_id = old_tenant # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" except ValueError: @@ -197,13 +201,7 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons if self._verify_challenge_resource: resource_domain = urlparse(scope).netloc if not resource_domain: - # Use the old scope for CAE challenges. The parsing will succeed here since it did before - if challenge.claims and old_scope: - resource_domain = urlparse(old_scope).netloc - challenge._parameters["scope"] = old_scope # pylint:disable=protected-access - challenge.tenant_id = old_tenant - else: - raise ValueError(f"The challenge contains invalid scope '{scope}'.") + raise ValueError(f"The challenge contains invalid scope '{scope}'.") request_domain = urlparse(request.http_request.url).netloc if not request_domain.lower().endswith(f".{resource_domain.lower()}"): diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/http_challenge.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/http_challenge.py index 39a75ed3cd03..0320df5a868b 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/http_challenge.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/http_challenge.py @@ -37,16 +37,11 @@ def __init__( trimmed_challenge = split_challenge[1] self.claims = None - encoded_claims = None # split trimmed challenge into comma-separated name=value pairs. Values are expected # to be surrounded by quotes which are stripped here. for item in trimmed_challenge.split(","): - # special case for claims, which can contain = symbols as padding + # Special case for claims, which can contain = symbols as padding. Assume at most one claim per challenge if "claims=" in item: - if encoded_claims: - # multiple claims challenges, e.g. for cross-tenant auth, would require special handling - # we can't support this scenario for now, so we ignore claims altogether if there are multiple - self.claims = None encoded_claims = item[item.index("=") + 1 :].strip(" \"'") padding_needed = -len(encoded_claims) % 4 try: diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py index f76496633d0d..9f748f0aa2a2 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py @@ -15,7 +15,8 @@ import pytest from azure.core.credentials import AccessToken -from azure.core.exceptions import ServiceRequestError +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError from azure.core.pipeline import AsyncPipeline from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.core.rest import HttpRequest @@ -70,6 +71,33 @@ async def test_multitenant_authentication(self, client, is_hsm, **kwargs): else: os.environ.pop("AZURE_TENANT_ID") + @pytest.mark.skip("Manual test for specific, CAE-enabled environments.") + @pytest.mark.asyncio + @pytest.mark.live_test_only + async def test_cae_live(self, **kwargs): + class CredentialWrapper(AsyncTokenCredential): + def __init__(self, credential): + self._credential = credential + self._claims = None + + async def get_token(self, *scopes, **kwargs): + assert kwargs["enable_cae"] == True + if kwargs.get("claims"): + # We should only receive claims once; subsequent challenges should be returned to the caller + assert self._claims is None + self._claims = kwargs["claims"] + return await self._credential.get_token(*scopes, **kwargs) + + credential = self.get_credential(KeyClient, is_async=True) + wrapped = CredentialWrapper(credential) + client = KeyClient(vault_url=os.environ["AZURE_KEYVAULT_URL"], credential=wrapped) + try: + await client.create_rsa_key("key-name") # Basic request meant to just trigger CAE challenges + # Test environment may continuously return claims challenges; a second consecutive challenge will raise + except ClientAuthenticationError as e: + assert "continuous access evaluation" in str(e).lower() + assert wrapped._claims is not None # Ensure we passed a claim to a token request + @pytest.mark.asyncio @empty_challenge_cache @@ -554,9 +582,10 @@ async def send(request): return kv_challenge raise ValueError("unexpected request") - async def get_token(*_, **kwargs): + async def get_token(*scopes, **kwargs): assert kwargs.get("enable_cae") == True assert kwargs.get("tenant_id") == tenant + assert scopes[0] == resource + "/.default" # Response to KV challenge if Requests.count == 1: assert kwargs.get("claims") == None @@ -636,9 +665,10 @@ async def send(request): return claims_challenge raise ValueError("unexpected request") - async def get_token(*_, **kwargs): + async def get_token(*scopes, **kwargs): assert kwargs.get("enable_cae") == True assert kwargs.get("tenant_id") == tenant + assert scopes[0] == resource + "/.default" # Response to KV challenge if Requests.count == 1: assert kwargs.get("claims") == None diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py index 5484c3b773b9..b6d390aa4e0c 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py @@ -189,6 +189,10 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons old_tenant = cached_challenge.tenant_id challenge = _update_challenge(request, response) + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary + if challenge.claims and old_scope: + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access + challenge.tenant_id = old_tenant # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" except ValueError: @@ -197,13 +201,7 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons if self._verify_challenge_resource: resource_domain = urlparse(scope).netloc if not resource_domain: - # Use the old scope for CAE challenges. The parsing will succeed here since it did before - if challenge.claims and old_scope: - resource_domain = urlparse(old_scope).netloc - challenge._parameters["scope"] = old_scope # pylint:disable=protected-access - challenge.tenant_id = old_tenant - else: - raise ValueError(f"The challenge contains invalid scope '{scope}'.") + raise ValueError(f"The challenge contains invalid scope '{scope}'.") request_domain = urlparse(request.http_request.url).netloc if not request_domain.lower().endswith(f".{resource_domain.lower()}"): diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py index d62307d58060..58d37de9ee73 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py @@ -183,6 +183,10 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> old_tenant = cached_challenge.tenant_id challenge = _update_challenge(request, response) + # CAE challenges may not include a scope or tenant; use the previous challenge's values if necessary + if challenge.claims and old_scope: + challenge._parameters["scope"] = old_scope # pylint:disable=protected-access + challenge.tenant_id = old_tenant # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" except ValueError: @@ -191,13 +195,7 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> if self._verify_challenge_resource: resource_domain = urlparse(scope).netloc if not resource_domain: - # Use the old scope for CAE challenges. The parsing will succeed here since it did before - if challenge.claims and old_scope: - resource_domain = urlparse(old_scope).netloc - challenge._parameters["scope"] = old_scope # pylint:disable=protected-access - challenge.tenant_id = old_tenant - else: - raise ValueError(f"The challenge contains invalid scope '{scope}'.") + raise ValueError(f"The challenge contains invalid scope '{scope}'.") request_domain = urlparse(request.http_request.url).netloc if not request_domain.lower().endswith(f".{resource_domain.lower()}"): diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/http_challenge.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/http_challenge.py index 39a75ed3cd03..0320df5a868b 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/http_challenge.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/http_challenge.py @@ -37,16 +37,11 @@ def __init__( trimmed_challenge = split_challenge[1] self.claims = None - encoded_claims = None # split trimmed challenge into comma-separated name=value pairs. Values are expected # to be surrounded by quotes which are stripped here. for item in trimmed_challenge.split(","): - # special case for claims, which can contain = symbols as padding + # Special case for claims, which can contain = symbols as padding. Assume at most one claim per challenge if "claims=" in item: - if encoded_claims: - # multiple claims challenges, e.g. for cross-tenant auth, would require special handling - # we can't support this scenario for now, so we ignore claims altogether if there are multiple - self.claims = None encoded_claims = item[item.index("=") + 1 :].strip(" \"'") padding_needed = -len(encoded_claims) % 4 try: From e78d4e9a960d5a8618af532c6a0b074f3b6951c9 Mon Sep 17 00:00:00 2001 From: mccoyp Date: Thu, 3 Oct 2024 16:08:04 -0700 Subject: [PATCH 17/25] Ensure no re-sending claims in tests --- .../tests/test_challenge_auth.py | 130 +++++++----------- .../tests/test_challenge_auth_async.py | 124 ++++++----------- 2 files changed, 95 insertions(+), 159 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py index b2588d0b01a9..990855528d00 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py @@ -69,31 +69,6 @@ def test_multitenant_authentication(self, client, is_hsm, **kwargs): else: os.environ.pop("AZURE_TENANT_ID") - @pytest.mark.skip("Manual test for specific, CAE-enabled environments.") - @pytest.mark.live_test_only - def test_cae_live(self, **kwargs): - class CredentialWrapper(TokenCredential): - def __init__(self, credential): - self._credential = credential - self._claims = None - - def get_token(self, *scopes, **kwargs): - assert kwargs["enable_cae"] == True - if kwargs.get("claims"): - # We should only receive claims once; subsequent challenges should be returned to the caller - assert self._claims is None - self._claims = kwargs["claims"] - return self._credential.get_token(*scopes, **kwargs) - - credential = self.get_credential(KeyClient) - wrapped = CredentialWrapper(credential) - client = KeyClient(vault_url=os.environ["AZURE_KEYVAULT_URL"], credential=wrapped) - try: - client.create_rsa_key("key-name") # Basic request meant to just trigger CAE challenges - # Test environment may continuously return claims challenges; a second consecutive challenge will raise - except ClientAuthenticationError as e: - assert "continuous access evaluation" in str(e).lower() - assert wrapped._claims is not None # Ensure we passed a claim to a token request def empty_challenge_cache(fn): @functools.wraps(fn) @@ -111,6 +86,25 @@ def get_random_url(): return f"https://{uuid4()}.vault.azure.net/{uuid4()}".replace("-", "") +URL = f'authorization_uri="{get_random_url()}"' +CLIENT_ID = 'client_id="00000003-0000-0000-c000-000000000000"' +CAE_ERROR = 'error="insufficient_claims"' +CAE_DECODED_CLAIM = '{"access_token": {"foo": "bar"}}' +# Claim token is a string of the base64 encoding of the claim +CLAIM_TOKEN = base64.b64encode(CAE_DECODED_CLAIM.encode()).decode() +# Note that no resource or scope is necessarily provided in a CAE challenge +CLAIM_CHALLENGE = f'Bearer realm="", {URL}, {CLIENT_ID}, {CAE_ERROR}, claims="{CLAIM_TOKEN}"' +CAE_CHALLENGE_HEADER = Mock(status_code=401, headers={"WWW-Authenticate": CLAIM_CHALLENGE}) + +KV_CHALLENGE_TENANT = "tenant-id" +ENDPOINT = f"https://authority.net/{KV_CHALLENGE_TENANT}" +RESOURCE = "https://vault.azure.net" +KV_CHALLENGE_HEADER = Mock( + status_code=401, + headers={"WWW-Authenticate": f'Bearer authorization="{ENDPOINT}", resource={RESOURCE}'}, +) + + def add_url_port(url: str): """Like `get_random_url`, but includes a port number (comes after the domain, and before the path of the URL).""" @@ -573,14 +567,6 @@ def test_cae(): def test_with_challenge(claims_challenge, expected_claim): first_token = "first_token" expected_token = "expected_token" - tenant = "tenant-id" - endpoint = f"https://authority.net/{tenant}" - resource = "https://vault.azure.net" - - kv_challenge = Mock( - status_code=401, - headers={"WWW-Authenticate": f'Bearer authorization="{endpoint}", resource={resource}'}, - ) class Requests: count = 0 @@ -592,7 +578,7 @@ def send(request): assert not request.body assert "Authorization" not in request.headers assert request.headers["Content-Length"] == "0" - return kv_challenge + return KV_CHALLENGE_HEADER elif Requests.count == 2: # second request should be authorized according to challenge and have the expected content assert request.headers["Content-Length"] @@ -606,18 +592,30 @@ def send(request): assert first_token in request.headers["Authorization"] return claims_challenge elif Requests.count == 4: - # fourth request should include the required claims and correctly use context from the first challenge - # we return another KV challenge to verify that the policy doesn't try to handle this invalid flow + # fourth request should include the required claims and correctly use content from the first challenge + assert request.headers["Content-Length"] + assert request.body == expected_content + assert expected_token in request.headers["Authorization"] + return Mock(status_code=200) + elif Requests.count == 5: + # fifth request should be a regular request with the expected token assert request.headers["Content-Length"] assert request.body == expected_content assert expected_token in request.headers["Authorization"] - return kv_challenge + return KV_CHALLENGE_HEADER + elif Requests.count == 6: + # sixth request should respond to the KV challenge WITHOUT including claims + # we return another challenge to confirm that the policy will return consecutive 401s to the user + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return KV_CHALLENGE_HEADER raise ValueError("unexpected request") def get_token(*scopes, **kwargs): assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == tenant - assert scopes[0] == resource + "/.default" + assert kwargs.get("tenant_id") == KV_CHALLENGE_TENANT + assert scopes[0] == RESOURCE + "/.default" # Response to KV challenge if Requests.count == 1: assert kwargs.get("claims") == None @@ -626,6 +624,12 @@ def get_token(*scopes, **kwargs): elif Requests.count == 3: assert kwargs.get("claims") == expected_claim return AccessToken(expected_token, time.time() + 3600) + # Response to second KV challenge + elif Requests.count == 5: + assert kwargs.get("claims") == None + return AccessToken(first_token, time.time() + 3600) + elif Requests.count == 6: + raise ValueError("unexpected token request") credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) @@ -633,23 +637,12 @@ def get_token(*scopes, **kwargs): request.set_bytes_body(expected_content) pipeline.run(request) # Send the request once to trigger a regular auth challenge pipeline.run(request) # Send the request again to trigger a CAE challenge + pipeline.run(request) # Send the request once to trigger another regular auth challenge - # get_token is called for the first KV challenge and CAE challenge, but not the second KV challenge - assert credential.get_token.call_count == 2 - - url = f'authorization_uri="{get_random_url()}"' - cid = 'client_id="00000003-0000-0000-c000-000000000000"' - err = 'error="insufficient_claims"' - claim = '{"access_token": {"foo": "bar"}}' - # Claim token is a string of the base64 encoding of the claim. Trim the padding to ensure the policy can handle it - claim_token = base64.b64encode(claim.encode()).decode() - claim_token = claim_token.strip("=") - # Note that no resource or scope is necessarily provided in a CAE challenge - challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' + # get_token is called for the CAE challenge and first two KV challenges, but not the final KV challenge + assert credential.get_token.call_count == 3 - claims_challenge = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) - - test_with_challenge(claims_challenge, claim) + test_with_challenge(CAE_CHALLENGE_HEADER, CAE_DECODED_CLAIM) @empty_challenge_cache @@ -661,14 +654,6 @@ def test_cae_consecutive_challenges(): def test_with_challenge(claims_challenge, expected_claim): first_token = "first_token" expected_token = "expected_token" - tenant = "tenant-id" - endpoint = f"https://authority.net/{tenant}" - resource = "https://vault.azure.net" - - kv_challenge = Mock( - status_code=401, - headers={"WWW-Authenticate": f'Bearer authorization="{endpoint}", resource={resource}'}, - ) class Requests: count = 0 @@ -680,7 +665,7 @@ def send(request): assert not request.body assert "Authorization" not in request.headers assert request.headers["Content-Length"] == "0" - return kv_challenge + return KV_CHALLENGE_HEADER elif Requests.count == 2: # second request will trigger a CAE challenge response in this test scenario assert request.headers["Content-Length"] @@ -688,7 +673,7 @@ def send(request): assert first_token in request.headers["Authorization"] return claims_challenge elif Requests.count == 3: - # third request should include the required claims and correctly use context from the first challenge + # third request should include the required claims and correctly use content from the first challenge # we return another CAE challenge to verify that the policy will return consecutive CAE 401s to the user assert request.headers["Content-Length"] assert request.body == expected_content @@ -698,8 +683,8 @@ def send(request): def get_token(*scopes, **kwargs): assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == tenant - assert scopes[0] == resource + "/.default" + assert kwargs.get("tenant_id") == KV_CHALLENGE_TENANT + assert scopes[0] == RESOURCE + "/.default" # Response to KV challenge if Requests.count == 1: assert kwargs.get("claims") == None @@ -718,15 +703,4 @@ def get_token(*scopes, **kwargs): # get_token is called for the KV challenge and first CAE challenge, but not the second CAE challenge assert credential.get_token.call_count == 2 - url = f'authorization_uri="{get_random_url()}"' - cid = 'client_id="00000003-0000-0000-c000-000000000000"' - err = 'error="insufficient_claims"' - claim = '{"access_token": {"foo": "bar"}}' - # Claim token is a string of the base64 encoding of the claim - claim_token = base64.b64encode(claim.encode()).decode() - # Note that no resource or scope is necessarily provided in a CAE challenge - challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' - - claims_challenge = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) - - test_with_challenge(claims_challenge, claim) + test_with_challenge(CAE_CHALLENGE_HEADER, CAE_DECODED_CLAIM) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py index 9f748f0aa2a2..ddf3a0a924d0 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py @@ -29,7 +29,16 @@ from _shared.helpers import Request, mock_response from _shared.helpers_async import async_validating_transport from _shared.test_case_async import KeyVaultTestCase -from test_challenge_auth import empty_challenge_cache, get_random_url, add_url_port +from test_challenge_auth import ( + empty_challenge_cache, + get_random_url, + add_url_port, + CAE_CHALLENGE_HEADER, + CAE_DECODED_CLAIM, + KV_CHALLENGE_HEADER, + KV_CHALLENGE_TENANT, + RESOURCE, +) only_default_version = get_decorator(is_async=True, api_versions=[DEFAULT_VERSION]) @@ -71,33 +80,6 @@ async def test_multitenant_authentication(self, client, is_hsm, **kwargs): else: os.environ.pop("AZURE_TENANT_ID") - @pytest.mark.skip("Manual test for specific, CAE-enabled environments.") - @pytest.mark.asyncio - @pytest.mark.live_test_only - async def test_cae_live(self, **kwargs): - class CredentialWrapper(AsyncTokenCredential): - def __init__(self, credential): - self._credential = credential - self._claims = None - - async def get_token(self, *scopes, **kwargs): - assert kwargs["enable_cae"] == True - if kwargs.get("claims"): - # We should only receive claims once; subsequent challenges should be returned to the caller - assert self._claims is None - self._claims = kwargs["claims"] - return await self._credential.get_token(*scopes, **kwargs) - - credential = self.get_credential(KeyClient, is_async=True) - wrapped = CredentialWrapper(credential) - client = KeyClient(vault_url=os.environ["AZURE_KEYVAULT_URL"], credential=wrapped) - try: - await client.create_rsa_key("key-name") # Basic request meant to just trigger CAE challenges - # Test environment may continuously return claims challenges; a second consecutive challenge will raise - except ClientAuthenticationError as e: - assert "continuous access evaluation" in str(e).lower() - assert wrapped._claims is not None # Ensure we passed a claim to a token request - @pytest.mark.asyncio @empty_challenge_cache @@ -541,14 +523,6 @@ async def test_cae(): async def test_with_challenge(claims_challenge, expected_claim): first_token = "first_token" expected_token = "expected_token" - tenant = "tenant-id" - endpoint = f"https://authority.net/{tenant}" - resource = "https://vault.azure.net" - - kv_challenge = Mock( - status_code=401, - headers={"WWW-Authenticate": f'Bearer authorization="{endpoint}", resource={resource}'}, - ) class Requests: count = 0 @@ -560,7 +534,7 @@ async def send(request): assert not request.body assert "Authorization" not in request.headers assert request.headers["Content-Length"] == "0" - return kv_challenge + return KV_CHALLENGE_HEADER elif Requests.count == 2: # second request should be authorized according to challenge and have the expected content assert request.headers["Content-Length"] @@ -574,18 +548,30 @@ async def send(request): assert first_token in request.headers["Authorization"] return claims_challenge elif Requests.count == 4: - # fourth request should include the required claims and correctly use context from the first challenge - # we return another KV challenge to verify that the policy doesn't try to handle this invalid flow + # fourth request should include the required claims and correctly use content from the first challenge + assert request.headers["Content-Length"] + assert request.body == expected_content + assert expected_token in request.headers["Authorization"] + return Mock(status_code=200) + elif Requests.count == 5: + # fifth request should be a regular request with the expected token assert request.headers["Content-Length"] assert request.body == expected_content assert expected_token in request.headers["Authorization"] - return kv_challenge + return KV_CHALLENGE_HEADER + elif Requests.count == 6: + # sixth request should respond to the KV challenge WITHOUT including claims + # we return another challenge to confirm that the policy will return consecutive 401s to the user + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return KV_CHALLENGE_HEADER raise ValueError("unexpected request") async def get_token(*scopes, **kwargs): assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == tenant - assert scopes[0] == resource + "/.default" + assert kwargs.get("tenant_id") == KV_CHALLENGE_TENANT + assert scopes[0] == RESOURCE + "/.default" # Response to KV challenge if Requests.count == 1: assert kwargs.get("claims") == None @@ -594,6 +580,12 @@ async def get_token(*scopes, **kwargs): elif Requests.count == 3: assert kwargs.get("claims") == expected_claim return AccessToken(expected_token, time.time() + 3600) + # Response to second KV challenge + elif Requests.count == 5: + assert kwargs.get("claims") == None + return AccessToken(first_token, time.time() + 3600) + elif Requests.count == 6: + raise ValueError("unexpected token request") credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) pipeline = AsyncPipeline(policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) @@ -601,23 +593,12 @@ async def get_token(*scopes, **kwargs): request.set_bytes_body(expected_content) await pipeline.run(request) # Send the request once to trigger a regular auth challenge await pipeline.run(request) # Send the request again to trigger a CAE challenge + await pipeline.run(request) # Send the request once to trigger another regular auth challenge - assert credential.get_token.call_count == 2 - - url = f'authorization_uri="{get_random_url()}"' - cid = 'client_id="00000003-0000-0000-c000-000000000000"' - err = 'error="insufficient_claims"' - claim = '{"access_token": {"foo": "bar"}}' - # Claim token is a string of the base64 encoding of the claim. Trim the padding to ensure the policy can handle it - claim_token = base64.b64encode(claim.encode()).decode() - claim_token = claim_token.strip("=") - # Note that no resource or scope is necessarily provided in a CAE challenge - challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' + # get_token is called for the CAE challenge and first two KV challenges, but not the final KV challenge + assert credential.get_token.call_count == 3 - # get_token is called for the first KV challenge and CAE challenge, but not the second KV challenge - claims_challenge = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) - - await test_with_challenge(claims_challenge, claim) + await test_with_challenge(CAE_CHALLENGE_HEADER, CAE_DECODED_CLAIM) @pytest.mark.asyncio @@ -630,14 +611,6 @@ async def test_cae_consecutive_challenges(): async def test_with_challenge(claims_challenge, expected_claim): first_token = "first_token" expected_token = "expected_token" - tenant = "tenant-id" - endpoint = f"https://authority.net/{tenant}" - resource = "https://vault.azure.net" - - kv_challenge = Mock( - status_code=401, - headers={"WWW-Authenticate": f'Bearer authorization="{endpoint}", resource={resource}'}, - ) class Requests: count = 0 @@ -649,7 +622,7 @@ async def send(request): assert not request.body assert "Authorization" not in request.headers assert request.headers["Content-Length"] == "0" - return kv_challenge + return KV_CHALLENGE_HEADER elif Requests.count == 2: # second request will trigger a CAE challenge response in this test scenario assert request.headers["Content-Length"] @@ -657,7 +630,7 @@ async def send(request): assert first_token in request.headers["Authorization"] return claims_challenge elif Requests.count == 3: - # third request should include the required claims and correctly use context from the first challenge + # third request should include the required claims and correctly use content from the first challenge # we return another CAE challenge to verify that the policy will return consecutive CAE 401s to the user assert request.headers["Content-Length"] assert request.body == expected_content @@ -667,8 +640,8 @@ async def send(request): async def get_token(*scopes, **kwargs): assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == tenant - assert scopes[0] == resource + "/.default" + assert kwargs.get("tenant_id") == KV_CHALLENGE_TENANT + assert scopes[0] == RESOURCE + "/.default" # Response to KV challenge if Requests.count == 1: assert kwargs.get("claims") == None @@ -687,15 +660,4 @@ async def get_token(*scopes, **kwargs): # get_token is called for the KV challenge and first CAE challenge, but not the second CAE challenge assert credential.get_token.call_count == 2 - url = f'authorization_uri="{get_random_url()}"' - cid = 'client_id="00000003-0000-0000-c000-000000000000"' - err = 'error="insufficient_claims"' - claim = '{"access_token": {"foo": "bar"}}' - # Claim token is a string of the base64 encoding of the claim - claim_token = base64.b64encode(claim.encode()).decode() - # Note that no resource or scope is necessarily provided in a CAE challenge - challenge = f'Bearer realm="", {url}, {cid}, {err}, claims="{claim_token}"' - - claims_challenge = Mock(status_code=401, headers={"WWW-Authenticate": challenge}) - - await test_with_challenge(claims_challenge, claim) + await test_with_challenge(CAE_CHALLENGE_HEADER, CAE_DECODED_CLAIM) From 1a6c9f767e46ef15eaba7fa94a486c3505f3586f Mon Sep 17 00:00:00 2001 From: mccoyp Date: Thu, 3 Oct 2024 16:08:25 -0700 Subject: [PATCH 18/25] Fix policy to handle KV -> KV challenge --- .../_shared/async_challenge_auth_policy.py | 10 ++++++--- .../keys/_shared/challenge_auth_policy.py | 22 +++++++++++++++++-- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py index b6d390aa4e0c..46b4de7bdfa4 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py @@ -28,7 +28,7 @@ from azure.core.rest import AsyncHttpResponse, HttpRequest from . import http_challenge_cache as ChallengeCache -from .challenge_auth_policy import _enforce_tls, _update_challenge +from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge P = ParamSpec("P") @@ -123,6 +123,11 @@ async def handle_challenge_flow( """ self._token = None # any cached token is invalid if "WWW-Authenticate" in response.http_response.headers: + # If the previous challenge was a KV challenge and this one is too, return the 401 + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) + if consecutive_challenge and not claims_challenge: + return response + request_authorized = await self.on_challenge(request, response) if request_authorized: # if we receive a challenge response, we retrieve a new token @@ -138,8 +143,7 @@ async def handle_challenge_flow( # If consecutive_challenge == True, this could be a third consecutive 401 if response.http_response.status_code == 401 and not consecutive_challenge: # If the previous challenge wasn't from CAE, we can try this function one more time - challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) - if challenge and not challenge.claims: + if not claims_challenge: return await self.handle_challenge_flow(request, response, consecutive_challenge=True) await await_result(self.on_response, request, response) return response diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py index 58d37de9ee73..4ccd04d2cfee 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py @@ -36,6 +36,20 @@ def _enforce_tls(request: PipelineRequest) -> None: ) +def _has_claims(challenge: str) -> bool: + """Check if a challenge header contains claims. + + :param challenge: The challenge header to check. + :type challenge: str + + :returns: True if the challenge contains claims; False otherwise. + :rtype: bool + """ + # Split the challenge into its scheme and parameters, then check if any parameter contains claims + split_challenge = challenge.strip().split(" ", 1) + return any("claims=" in item for item in split_challenge[1].split(",")) + + def _update_challenge(request: PipelineRequest, challenger: PipelineResponse) -> HttpChallenge: """Parse challenge from a challenge response, cache it, and return it. @@ -119,6 +133,11 @@ def handle_challenge_flow( """ self._token = None # any cached token is invalid if "WWW-Authenticate" in response.http_response.headers: + # If the previous challenge was a KV challenge and this one is too, return the 401 + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) + if consecutive_challenge and not claims_challenge: + return response + request_authorized = self.on_challenge(request, response) if request_authorized: # if we receive a challenge response, we retrieve a new token @@ -134,8 +153,7 @@ def handle_challenge_flow( # If consecutive_challenge == True, this could be a third consecutive 401 if response.http_response.status_code == 401 and not consecutive_challenge: # If the previous challenge wasn't from CAE, we can try this function one more time - challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) - if challenge and not challenge.claims: + if not claims_challenge: return self.handle_challenge_flow(request, response, consecutive_challenge=True) self.on_response(request, response) return response From 3fad4db493c82e433be18748263114768bc43777 Mon Sep 17 00:00:00 2001 From: mccoyp Date: Thu, 3 Oct 2024 16:19:11 -0700 Subject: [PATCH 19/25] Share bug fix across libraries --- .../_internal/async_challenge_auth_policy.py | 10 ++++++--- .../_internal/challenge_auth_policy.py | 22 +++++++++++++++++-- .../_shared/async_challenge_auth_policy.py | 10 ++++++--- .../_shared/challenge_auth_policy.py | 22 +++++++++++++++++-- .../_shared/async_challenge_auth_policy.py | 10 ++++++--- .../secrets/_shared/challenge_auth_policy.py | 22 +++++++++++++++++-- 6 files changed, 81 insertions(+), 15 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py index b6d390aa4e0c..46b4de7bdfa4 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py @@ -28,7 +28,7 @@ from azure.core.rest import AsyncHttpResponse, HttpRequest from . import http_challenge_cache as ChallengeCache -from .challenge_auth_policy import _enforce_tls, _update_challenge +from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge P = ParamSpec("P") @@ -123,6 +123,11 @@ async def handle_challenge_flow( """ self._token = None # any cached token is invalid if "WWW-Authenticate" in response.http_response.headers: + # If the previous challenge was a KV challenge and this one is too, return the 401 + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) + if consecutive_challenge and not claims_challenge: + return response + request_authorized = await self.on_challenge(request, response) if request_authorized: # if we receive a challenge response, we retrieve a new token @@ -138,8 +143,7 @@ async def handle_challenge_flow( # If consecutive_challenge == True, this could be a third consecutive 401 if response.http_response.status_code == 401 and not consecutive_challenge: # If the previous challenge wasn't from CAE, we can try this function one more time - challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) - if challenge and not challenge.claims: + if not claims_challenge: return await self.handle_challenge_flow(request, response, consecutive_challenge=True) await await_result(self.on_response, request, response) return response diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py index 58d37de9ee73..4ccd04d2cfee 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py @@ -36,6 +36,20 @@ def _enforce_tls(request: PipelineRequest) -> None: ) +def _has_claims(challenge: str) -> bool: + """Check if a challenge header contains claims. + + :param challenge: The challenge header to check. + :type challenge: str + + :returns: True if the challenge contains claims; False otherwise. + :rtype: bool + """ + # Split the challenge into its scheme and parameters, then check if any parameter contains claims + split_challenge = challenge.strip().split(" ", 1) + return any("claims=" in item for item in split_challenge[1].split(",")) + + def _update_challenge(request: PipelineRequest, challenger: PipelineResponse) -> HttpChallenge: """Parse challenge from a challenge response, cache it, and return it. @@ -119,6 +133,11 @@ def handle_challenge_flow( """ self._token = None # any cached token is invalid if "WWW-Authenticate" in response.http_response.headers: + # If the previous challenge was a KV challenge and this one is too, return the 401 + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) + if consecutive_challenge and not claims_challenge: + return response + request_authorized = self.on_challenge(request, response) if request_authorized: # if we receive a challenge response, we retrieve a new token @@ -134,8 +153,7 @@ def handle_challenge_flow( # If consecutive_challenge == True, this could be a third consecutive 401 if response.http_response.status_code == 401 and not consecutive_challenge: # If the previous challenge wasn't from CAE, we can try this function one more time - challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) - if challenge and not challenge.claims: + if not claims_challenge: return self.handle_challenge_flow(request, response, consecutive_challenge=True) self.on_response(request, response) return response diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py index b6d390aa4e0c..46b4de7bdfa4 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py @@ -28,7 +28,7 @@ from azure.core.rest import AsyncHttpResponse, HttpRequest from . import http_challenge_cache as ChallengeCache -from .challenge_auth_policy import _enforce_tls, _update_challenge +from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge P = ParamSpec("P") @@ -123,6 +123,11 @@ async def handle_challenge_flow( """ self._token = None # any cached token is invalid if "WWW-Authenticate" in response.http_response.headers: + # If the previous challenge was a KV challenge and this one is too, return the 401 + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) + if consecutive_challenge and not claims_challenge: + return response + request_authorized = await self.on_challenge(request, response) if request_authorized: # if we receive a challenge response, we retrieve a new token @@ -138,8 +143,7 @@ async def handle_challenge_flow( # If consecutive_challenge == True, this could be a third consecutive 401 if response.http_response.status_code == 401 and not consecutive_challenge: # If the previous challenge wasn't from CAE, we can try this function one more time - challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) - if challenge and not challenge.claims: + if not claims_challenge: return await self.handle_challenge_flow(request, response, consecutive_challenge=True) await await_result(self.on_response, request, response) return response diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py index 58d37de9ee73..4ccd04d2cfee 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py @@ -36,6 +36,20 @@ def _enforce_tls(request: PipelineRequest) -> None: ) +def _has_claims(challenge: str) -> bool: + """Check if a challenge header contains claims. + + :param challenge: The challenge header to check. + :type challenge: str + + :returns: True if the challenge contains claims; False otherwise. + :rtype: bool + """ + # Split the challenge into its scheme and parameters, then check if any parameter contains claims + split_challenge = challenge.strip().split(" ", 1) + return any("claims=" in item for item in split_challenge[1].split(",")) + + def _update_challenge(request: PipelineRequest, challenger: PipelineResponse) -> HttpChallenge: """Parse challenge from a challenge response, cache it, and return it. @@ -119,6 +133,11 @@ def handle_challenge_flow( """ self._token = None # any cached token is invalid if "WWW-Authenticate" in response.http_response.headers: + # If the previous challenge was a KV challenge and this one is too, return the 401 + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) + if consecutive_challenge and not claims_challenge: + return response + request_authorized = self.on_challenge(request, response) if request_authorized: # if we receive a challenge response, we retrieve a new token @@ -134,8 +153,7 @@ def handle_challenge_flow( # If consecutive_challenge == True, this could be a third consecutive 401 if response.http_response.status_code == 401 and not consecutive_challenge: # If the previous challenge wasn't from CAE, we can try this function one more time - challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) - if challenge and not challenge.claims: + if not claims_challenge: return self.handle_challenge_flow(request, response, consecutive_challenge=True) self.on_response(request, response) return response diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py index b6d390aa4e0c..46b4de7bdfa4 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py @@ -28,7 +28,7 @@ from azure.core.rest import AsyncHttpResponse, HttpRequest from . import http_challenge_cache as ChallengeCache -from .challenge_auth_policy import _enforce_tls, _update_challenge +from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge P = ParamSpec("P") @@ -123,6 +123,11 @@ async def handle_challenge_flow( """ self._token = None # any cached token is invalid if "WWW-Authenticate" in response.http_response.headers: + # If the previous challenge was a KV challenge and this one is too, return the 401 + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) + if consecutive_challenge and not claims_challenge: + return response + request_authorized = await self.on_challenge(request, response) if request_authorized: # if we receive a challenge response, we retrieve a new token @@ -138,8 +143,7 @@ async def handle_challenge_flow( # If consecutive_challenge == True, this could be a third consecutive 401 if response.http_response.status_code == 401 and not consecutive_challenge: # If the previous challenge wasn't from CAE, we can try this function one more time - challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) - if challenge and not challenge.claims: + if not claims_challenge: return await self.handle_challenge_flow(request, response, consecutive_challenge=True) await await_result(self.on_response, request, response) return response diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py index 58d37de9ee73..4ccd04d2cfee 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py @@ -36,6 +36,20 @@ def _enforce_tls(request: PipelineRequest) -> None: ) +def _has_claims(challenge: str) -> bool: + """Check if a challenge header contains claims. + + :param challenge: The challenge header to check. + :type challenge: str + + :returns: True if the challenge contains claims; False otherwise. + :rtype: bool + """ + # Split the challenge into its scheme and parameters, then check if any parameter contains claims + split_challenge = challenge.strip().split(" ", 1) + return any("claims=" in item for item in split_challenge[1].split(",")) + + def _update_challenge(request: PipelineRequest, challenger: PipelineResponse) -> HttpChallenge: """Parse challenge from a challenge response, cache it, and return it. @@ -119,6 +133,11 @@ def handle_challenge_flow( """ self._token = None # any cached token is invalid if "WWW-Authenticate" in response.http_response.headers: + # If the previous challenge was a KV challenge and this one is too, return the 401 + claims_challenge = _has_claims(response.http_response.headers["WWW-Authenticate"]) + if consecutive_challenge and not claims_challenge: + return response + request_authorized = self.on_challenge(request, response) if request_authorized: # if we receive a challenge response, we retrieve a new token @@ -134,8 +153,7 @@ def handle_challenge_flow( # If consecutive_challenge == True, this could be a third consecutive 401 if response.http_response.status_code == 401 and not consecutive_challenge: # If the previous challenge wasn't from CAE, we can try this function one more time - challenge = ChallengeCache.get_challenge_for_url(request.http_request.url) - if challenge and not challenge.claims: + if not claims_challenge: return self.handle_challenge_flow(request, response, consecutive_challenge=True) self.on_response(request, response) return response From ba2a954f76bae10f1830fad0a861606b9048dc04 Mon Sep 17 00:00:00 2001 From: mccoyp Date: Thu, 3 Oct 2024 17:15:12 -0700 Subject: [PATCH 20/25] Clarify test variable names --- .../tests/test_challenge_auth.py | 16 ++++++++-------- .../tests/test_challenge_auth_async.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py index 990855528d00..a1188292bbce 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py @@ -94,12 +94,12 @@ def get_random_url(): CLAIM_TOKEN = base64.b64encode(CAE_DECODED_CLAIM.encode()).decode() # Note that no resource or scope is necessarily provided in a CAE challenge CLAIM_CHALLENGE = f'Bearer realm="", {URL}, {CLIENT_ID}, {CAE_ERROR}, claims="{CLAIM_TOKEN}"' -CAE_CHALLENGE_HEADER = Mock(status_code=401, headers={"WWW-Authenticate": CLAIM_CHALLENGE}) +CAE_CHALLENGE_RESPONSE = Mock(status_code=401, headers={"WWW-Authenticate": CLAIM_CHALLENGE}) KV_CHALLENGE_TENANT = "tenant-id" ENDPOINT = f"https://authority.net/{KV_CHALLENGE_TENANT}" RESOURCE = "https://vault.azure.net" -KV_CHALLENGE_HEADER = Mock( +KV_CHALLENGE_RESPONSE = Mock( status_code=401, headers={"WWW-Authenticate": f'Bearer authorization="{ENDPOINT}", resource={RESOURCE}'}, ) @@ -578,7 +578,7 @@ def send(request): assert not request.body assert "Authorization" not in request.headers assert request.headers["Content-Length"] == "0" - return KV_CHALLENGE_HEADER + return KV_CHALLENGE_RESPONSE elif Requests.count == 2: # second request should be authorized according to challenge and have the expected content assert request.headers["Content-Length"] @@ -602,14 +602,14 @@ def send(request): assert request.headers["Content-Length"] assert request.body == expected_content assert expected_token in request.headers["Authorization"] - return KV_CHALLENGE_HEADER + return KV_CHALLENGE_RESPONSE elif Requests.count == 6: # sixth request should respond to the KV challenge WITHOUT including claims # we return another challenge to confirm that the policy will return consecutive 401s to the user assert request.headers["Content-Length"] assert request.body == expected_content assert first_token in request.headers["Authorization"] - return KV_CHALLENGE_HEADER + return KV_CHALLENGE_RESPONSE raise ValueError("unexpected request") def get_token(*scopes, **kwargs): @@ -642,7 +642,7 @@ def get_token(*scopes, **kwargs): # get_token is called for the CAE challenge and first two KV challenges, but not the final KV challenge assert credential.get_token.call_count == 3 - test_with_challenge(CAE_CHALLENGE_HEADER, CAE_DECODED_CLAIM) + test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) @empty_challenge_cache @@ -665,7 +665,7 @@ def send(request): assert not request.body assert "Authorization" not in request.headers assert request.headers["Content-Length"] == "0" - return KV_CHALLENGE_HEADER + return KV_CHALLENGE_RESPONSE elif Requests.count == 2: # second request will trigger a CAE challenge response in this test scenario assert request.headers["Content-Length"] @@ -703,4 +703,4 @@ def get_token(*scopes, **kwargs): # get_token is called for the KV challenge and first CAE challenge, but not the second CAE challenge assert credential.get_token.call_count == 2 - test_with_challenge(CAE_CHALLENGE_HEADER, CAE_DECODED_CLAIM) + test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py index ddf3a0a924d0..155002976cee 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py @@ -33,9 +33,9 @@ empty_challenge_cache, get_random_url, add_url_port, - CAE_CHALLENGE_HEADER, + CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM, - KV_CHALLENGE_HEADER, + KV_CHALLENGE_RESPONSE, KV_CHALLENGE_TENANT, RESOURCE, ) @@ -534,7 +534,7 @@ async def send(request): assert not request.body assert "Authorization" not in request.headers assert request.headers["Content-Length"] == "0" - return KV_CHALLENGE_HEADER + return KV_CHALLENGE_RESPONSE elif Requests.count == 2: # second request should be authorized according to challenge and have the expected content assert request.headers["Content-Length"] @@ -558,14 +558,14 @@ async def send(request): assert request.headers["Content-Length"] assert request.body == expected_content assert expected_token in request.headers["Authorization"] - return KV_CHALLENGE_HEADER + return KV_CHALLENGE_RESPONSE elif Requests.count == 6: # sixth request should respond to the KV challenge WITHOUT including claims # we return another challenge to confirm that the policy will return consecutive 401s to the user assert request.headers["Content-Length"] assert request.body == expected_content assert first_token in request.headers["Authorization"] - return KV_CHALLENGE_HEADER + return KV_CHALLENGE_RESPONSE raise ValueError("unexpected request") async def get_token(*scopes, **kwargs): @@ -598,7 +598,7 @@ async def get_token(*scopes, **kwargs): # get_token is called for the CAE challenge and first two KV challenges, but not the final KV challenge assert credential.get_token.call_count == 3 - await test_with_challenge(CAE_CHALLENGE_HEADER, CAE_DECODED_CLAIM) + await test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) @pytest.mark.asyncio @@ -622,7 +622,7 @@ async def send(request): assert not request.body assert "Authorization" not in request.headers assert request.headers["Content-Length"] == "0" - return KV_CHALLENGE_HEADER + return KV_CHALLENGE_RESPONSE elif Requests.count == 2: # second request will trigger a CAE challenge response in this test scenario assert request.headers["Content-Length"] @@ -660,4 +660,4 @@ async def get_token(*scopes, **kwargs): # get_token is called for the KV challenge and first CAE challenge, but not the second CAE challenge assert credential.get_token.call_count == 2 - await test_with_challenge(CAE_CHALLENGE_HEADER, CAE_DECODED_CLAIM) + await test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) From 92c6704aa2d759d124692afad08c3582836e9371 Mon Sep 17 00:00:00 2001 From: mccoyp Date: Sat, 5 Oct 2024 11:01:28 -0700 Subject: [PATCH 21/25] Correctly handle token refreshes --- .../_internal/async_challenge_auth_policy.py | 8 +- .../_internal/challenge_auth_policy.py | 6 +- .../_shared/async_challenge_auth_policy.py | 8 +- .../_shared/challenge_auth_policy.py | 6 +- .../_shared/async_challenge_auth_policy.py | 8 +- .../keys/_shared/challenge_auth_policy.py | 6 +- .../tests/test_challenge_auth.py | 72 ++++++++++++++++++ .../tests/test_challenge_auth_async.py | 73 +++++++++++++++++++ .../_shared/async_challenge_auth_policy.py | 8 +- .../secrets/_shared/challenge_auth_policy.py | 6 +- 10 files changed, 177 insertions(+), 24 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py index 46b4de7bdfa4..8ece1ecc018a 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py @@ -68,7 +68,7 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): """ def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: - # Pass `enable_cae` so `enable_cae=True` is always passed to get_token + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request super().__init__(credential, *scopes, enable_cae=True, **kwargs) self._credential: AsyncTokenCredential = credential self._token: Optional[AccessToken] = None @@ -159,9 +159,11 @@ async def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope) + self._token = await self._credential.get_token(scope, enable_cae=True) else: - self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = await self._credential.get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py index 4ccd04d2cfee..8c5815a51d0e 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py @@ -80,7 +80,7 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): """ def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: - # Pass `enable_cae` so `enable_cae=True` is always passed to get_token + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs) self._credential: TokenCredential = credential self._token: Optional[AccessToken] = None @@ -168,9 +168,9 @@ def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope) + self._token = self._credential.get_token(scope, enable_cae=True) else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id, enable_cae=True) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py index 46b4de7bdfa4..8ece1ecc018a 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py @@ -68,7 +68,7 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): """ def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: - # Pass `enable_cae` so `enable_cae=True` is always passed to get_token + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request super().__init__(credential, *scopes, enable_cae=True, **kwargs) self._credential: AsyncTokenCredential = credential self._token: Optional[AccessToken] = None @@ -159,9 +159,11 @@ async def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope) + self._token = await self._credential.get_token(scope, enable_cae=True) else: - self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = await self._credential.get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py index 4ccd04d2cfee..8c5815a51d0e 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py @@ -80,7 +80,7 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): """ def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: - # Pass `enable_cae` so `enable_cae=True` is always passed to get_token + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs) self._credential: TokenCredential = credential self._token: Optional[AccessToken] = None @@ -168,9 +168,9 @@ def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope) + self._token = self._credential.get_token(scope, enable_cae=True) else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id, enable_cae=True) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py index 46b4de7bdfa4..8ece1ecc018a 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py @@ -68,7 +68,7 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): """ def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: - # Pass `enable_cae` so `enable_cae=True` is always passed to get_token + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request super().__init__(credential, *scopes, enable_cae=True, **kwargs) self._credential: AsyncTokenCredential = credential self._token: Optional[AccessToken] = None @@ -159,9 +159,11 @@ async def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope) + self._token = await self._credential.get_token(scope, enable_cae=True) else: - self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = await self._credential.get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py index 4ccd04d2cfee..8c5815a51d0e 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py @@ -80,7 +80,7 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): """ def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: - # Pass `enable_cae` so `enable_cae=True` is always passed to get_token + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs) self._credential: TokenCredential = credential self._token: Optional[AccessToken] = None @@ -168,9 +168,9 @@ def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope) + self._token = self._credential.get_token(scope, enable_cae=True) else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id, enable_cae=True) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py index a1188292bbce..40e40f2f730c 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py @@ -704,3 +704,75 @@ def get_token(*scopes, **kwargs): assert credential.get_token.call_count == 2 test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) + + +@empty_challenge_cache +def test_cae_token_expiry(): + """The policy should avoid sending claims more than once when a token expires.""" + + expected_content = b"a duck" + + def test_with_challenge(claims_challenge, expected_claim): + first_token = "first_token" + second_token = "second_token" + third_token = "third_token" + + class Requests: + count = 0 + + def send(request): + Requests.count += 1 + if Requests.count == 1: + # first request should be unauthorized and have no content; triggers a KV challenge response + assert not request.body + assert "Authorization" not in request.headers + assert request.headers["Content-Length"] == "0" + return KV_CHALLENGE_RESPONSE + elif Requests.count == 2: + # second request will trigger a CAE challenge response in this test scenario + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return claims_challenge + elif Requests.count == 3: + # third request should include the required claims and correctly use content from the first challenge + assert request.headers["Content-Length"] + assert request.body == expected_content + assert second_token in request.headers["Authorization"] + return Mock(status_code=200) + elif Requests.count == 4: + # fourth request should not include claims, but otherwise use content from the first challenge + assert request.headers["Content-Length"] + assert request.body == expected_content + assert third_token in request.headers["Authorization"] + return Mock(status_code=200) + raise ValueError("unexpected request") + + def get_token(*scopes, **kwargs): + assert kwargs.get("enable_cae") == True + assert kwargs.get("tenant_id") == KV_CHALLENGE_TENANT + assert scopes[0] == RESOURCE + "/.default" + # Response to KV challenge + if Requests.count == 1: + assert kwargs.get("claims") == None + return AccessToken(first_token, time.time() + 3600) + # Response to first CAE challenge + elif Requests.count == 2: + assert kwargs.get("claims") == expected_claim + return AccessToken(second_token, 0) # Return a token that expires immediately to trigger a refresh + # Token refresh before making the final request + elif Requests.count == 3: + assert kwargs.get("claims") == None + return AccessToken(third_token, time.time() + 3600) + + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) + request = HttpRequest("POST", get_random_url()) + request.set_bytes_body(expected_content) + pipeline.run(request) + pipeline.run(request) # Send the request again to trigger a token refresh upon expiry + + # get_token is called for the KV and CAE challenges, as well as for the token refresh + assert credential.get_token.call_count == 3 + + test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py index 155002976cee..2fa1da39b35d 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py @@ -661,3 +661,76 @@ async def get_token(*scopes, **kwargs): assert credential.get_token.call_count == 2 await test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) + + +@pytest.mark.asyncio +@empty_challenge_cache +async def test_cae_token_expiry(): + """The policy should avoid sending claims more than once when a token expires.""" + + expected_content = b"a duck" + + async def test_with_challenge(claims_challenge, expected_claim): + first_token = "first_token" + second_token = "second_token" + third_token = "third_token" + + class Requests: + count = 0 + + async def send(request): + Requests.count += 1 + if Requests.count == 1: + # first request should be unauthorized and have no content; triggers a KV challenge response + assert not request.body + assert "Authorization" not in request.headers + assert request.headers["Content-Length"] == "0" + return KV_CHALLENGE_RESPONSE + elif Requests.count == 2: + # second request will trigger a CAE challenge response in this test scenario + assert request.headers["Content-Length"] + assert request.body == expected_content + assert first_token in request.headers["Authorization"] + return claims_challenge + elif Requests.count == 3: + # third request should include the required claims and correctly use content from the first challenge + assert request.headers["Content-Length"] + assert request.body == expected_content + assert second_token in request.headers["Authorization"] + return Mock(status_code=200) + elif Requests.count == 4: + # fourth request should not include claims, but otherwise use content from the first challenge + assert request.headers["Content-Length"] + assert request.body == expected_content + assert third_token in request.headers["Authorization"] + return Mock(status_code=200) + raise ValueError("unexpected request") + + async def get_token(*scopes, **kwargs): + assert kwargs.get("enable_cae") == True + assert kwargs.get("tenant_id") == KV_CHALLENGE_TENANT + assert scopes[0] == RESOURCE + "/.default" + # Response to KV challenge + if Requests.count == 1: + assert kwargs.get("claims") == None + return AccessToken(first_token, time.time() + 3600) + # Response to first CAE challenge + elif Requests.count == 2: + assert kwargs.get("claims") == expected_claim + return AccessToken(second_token, 0) # Return a token that expires immediately to trigger a refresh + # Token refresh before making the final request + elif Requests.count == 3: + assert kwargs.get("claims") == None + return AccessToken(third_token, time.time() + 3600) + + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + pipeline = AsyncPipeline(policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) + request = HttpRequest("POST", get_random_url()) + request.set_bytes_body(expected_content) + await pipeline.run(request) + await pipeline.run(request) # Send the request again to trigger a token refresh upon expiry + + # get_token is called for the KV and CAE challenges, as well as for the token refresh + assert credential.get_token.call_count == 3 + + await test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py index 46b4de7bdfa4..8ece1ecc018a 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py @@ -68,7 +68,7 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): """ def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: - # Pass `enable_cae` so `enable_cae=True` is always passed to get_token + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request super().__init__(credential, *scopes, enable_cae=True, **kwargs) self._credential: AsyncTokenCredential = credential self._token: Optional[AccessToken] = None @@ -159,9 +159,11 @@ async def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope) + self._token = await self._credential.get_token(scope, enable_cae=True) else: - self._token = await self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = await self._credential.get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py index 4ccd04d2cfee..8c5815a51d0e 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py @@ -80,7 +80,7 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): """ def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: - # Pass `enable_cae` so `enable_cae=True` is always passed to get_token + # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs) self._credential: TokenCredential = credential self._token: Optional[AccessToken] = None @@ -168,9 +168,9 @@ def on_request(self, request: PipelineRequest) -> None: scope = challenge.get_scope() or challenge.get_resource() + "/.default" # Exclude tenant for AD FS authentication if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope) + self._token = self._credential.get_token(scope, enable_cae=True) else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id) + self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id, enable_cae=True) # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore From 8e65726754ef3096146035ce8fb5a2b1de40e8f2 Mon Sep 17 00:00:00 2001 From: mccoyp Date: Tue, 8 Oct 2024 10:21:04 -0700 Subject: [PATCH 22/25] Bump Core dep for SupportsTokenInfo protocol support --- sdk/keyvault/azure-keyvault-administration/setup.py | 2 +- sdk/keyvault/azure-keyvault-certificates/setup.py | 2 +- sdk/keyvault/azure-keyvault-keys/setup.py | 2 +- sdk/keyvault/azure-keyvault-secrets/setup.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-administration/setup.py b/sdk/keyvault/azure-keyvault-administration/setup.py index 60bbdd2a58a6..30c88eb275bd 100644 --- a/sdk/keyvault/azure-keyvault-administration/setup.py +++ b/sdk/keyvault/azure-keyvault-administration/setup.py @@ -68,7 +68,7 @@ ), python_requires=">=3.8", install_requires=[ - "azure-core>=1.29.5", + "azure-core>=1.31.0", "isodate>=0.6.1", "typing-extensions>=4.0.1", ], diff --git a/sdk/keyvault/azure-keyvault-certificates/setup.py b/sdk/keyvault/azure-keyvault-certificates/setup.py index 00390ee336a0..347d79fb93a0 100644 --- a/sdk/keyvault/azure-keyvault-certificates/setup.py +++ b/sdk/keyvault/azure-keyvault-certificates/setup.py @@ -68,7 +68,7 @@ ), python_requires=">=3.8", install_requires=[ - "azure-core>=1.29.5", + "azure-core>=1.31.0", "isodate>=0.6.1", "typing-extensions>=4.0.1", ], diff --git a/sdk/keyvault/azure-keyvault-keys/setup.py b/sdk/keyvault/azure-keyvault-keys/setup.py index cbb26cf86d49..7bbe28af42c5 100644 --- a/sdk/keyvault/azure-keyvault-keys/setup.py +++ b/sdk/keyvault/azure-keyvault-keys/setup.py @@ -68,7 +68,7 @@ ), python_requires=">=3.8", install_requires=[ - "azure-core>=1.29.5", + "azure-core>=1.31.0", "cryptography>=2.1.4", "isodate>=0.6.1", "typing-extensions>=4.0.1", diff --git a/sdk/keyvault/azure-keyvault-secrets/setup.py b/sdk/keyvault/azure-keyvault-secrets/setup.py index 989d08c1e98a..54019bafe3b8 100644 --- a/sdk/keyvault/azure-keyvault-secrets/setup.py +++ b/sdk/keyvault/azure-keyvault-secrets/setup.py @@ -68,7 +68,7 @@ ), python_requires=">=3.8", install_requires=[ - "azure-core>=1.29.5", + "azure-core>=1.31.0", "isodate>=0.6.1", "typing-extensions>=4.0.1", ], From 93c7eaa72c0b0064ce389522b6c9ccedb644819c Mon Sep 17 00:00:00 2001 From: mccoyp Date: Tue, 8 Oct 2024 15:17:41 -0700 Subject: [PATCH 23/25] (Async)SupportsTokenInfo support/tests --- .../_internal/async_challenge_auth_policy.py | 55 +++-- .../_internal/challenge_auth_policy.py | 58 +++-- .../_shared/async_challenge_auth_policy.py | 55 +++-- .../_shared/challenge_auth_policy.py | 58 +++-- .../_shared/async_challenge_auth_policy.py | 55 +++-- .../keys/_shared/challenge_auth_policy.py | 58 +++-- .../tests/test_challenge_auth.py | 230 ++++++++++++------ .../tests/test_challenge_auth_async.py | 216 ++++++++++------ .../_shared/async_challenge_auth_policy.py | 55 +++-- .../secrets/_shared/challenge_auth_policy.py | 58 +++-- 10 files changed, 623 insertions(+), 275 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py index 8ece1ecc018a..cf1bdcf7c9ae 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py @@ -16,17 +16,18 @@ from copy import deepcopy import time -from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union +from typing import Any, Awaitable, Callable, cast, Optional, overload, TypeVar, Union from urllib.parse import urlparse from typing_extensions import ParamSpec -from azure.core.credentials import AccessToken -from azure.core.credentials_async import AsyncTokenCredential +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy from azure.core.rest import AsyncHttpResponse, HttpRequest +from .http_challenge import HttpChallenge from . import http_challenge_cache as ChallengeCache from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge @@ -64,14 +65,14 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): :param credential: An object which can provide an access token for the vault, such as a credential from :mod:`azure.identity.aio` - :type credential: ~azure.core.credentials_async.AsyncTokenCredential + :type credential: ~azure.core.credentials_async.AsyncTokenProvider """ - def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: + def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None: # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request super().__init__(credential, *scopes, enable_cae=True, **kwargs) - self._credential: AsyncTokenCredential = credential - self._token: Optional[AccessToken] = None + self._credential: AsyncTokenProvider = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None @@ -157,16 +158,10 @@ async def on_request(self, request: PipelineRequest) -> None: if self._need_new_token(): # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" - # Exclude tenant for AD FS authentication - if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope, enable_cae=True) - else: - self._token = await self._credential.get_token( - scope, tenant_id=challenge.tenant_id, enable_cae=True - ) - - # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore + await self._request_kv_token(scope, challenge) + + bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" return # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, @@ -233,4 +228,28 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons return True def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + async def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None: + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. + + :param str scope: The scope for which to request a token. + :param challenge: The challenge for the request being made. + """ + # Exclude tenant for AD FS authentication + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") + # The AsyncSupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {"enable_cae": True} + if challenge.tenant_id and not exclude_tenant: + options["tenant_id"] = challenge.tenant_id + self._token = await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(scope, options=options) + else: + if exclude_tenant: + self._token = await self._credential.get_token(scope, enable_cae=True) + else: + self._token = await cast(AsyncTokenCredential, self._credential).get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py index 8c5815a51d0e..71e133819300 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py @@ -16,10 +16,17 @@ from copy import deepcopy import time -from typing import Any, Optional +from typing import Any, cast, Optional, Union from urllib.parse import urlparse -from azure.core.credentials import AccessToken, TokenCredential +from azure.core.credentials import ( + AccessToken, + AccessTokenInfo, + TokenCredential, + TokenProvider, + TokenRequestOptions, + SupportsTokenInfo, +) from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import BearerTokenCredentialPolicy @@ -76,14 +83,15 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): :param credential: An object which can provide an access token for the vault, such as a credential from :mod:`azure.identity` - :type credential: ~azure.core.credentials.TokenCredential + :type credential: ~azure.core.credentials.TokenProvider + :param str scopes: Lets you specify the type of access needed. """ - def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: + def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None: # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs) - self._credential: TokenCredential = credential - self._token: Optional[AccessToken] = None + self._credential: TokenProvider = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None @@ -166,14 +174,10 @@ def on_request(self, request: PipelineRequest) -> None: if self._need_new_token: # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" - # Exclude tenant for AD FS authentication - if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope, enable_cae=True) - else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id, enable_cae=True) - - # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore + self._request_kv_token(scope, challenge) + + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" return # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, @@ -238,4 +242,28 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> @property def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None: + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. + + :param str scope: The scope for which to request a token. + :param challenge: The challenge for the request being made. + """ + # Exclude tenant for AD FS authentication + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") + # The SupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {"enable_cae": True} + if challenge.tenant_id and not exclude_tenant: + options["tenant_id"] = challenge.tenant_id + self._token = cast(SupportsTokenInfo, self._credential).get_token_info(scope, options=options) + else: + if exclude_tenant: + self._token = self._credential.get_token(scope, enable_cae=True) + else: + self._token = cast(TokenCredential, self._credential).get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py index 8ece1ecc018a..cf1bdcf7c9ae 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py @@ -16,17 +16,18 @@ from copy import deepcopy import time -from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union +from typing import Any, Awaitable, Callable, cast, Optional, overload, TypeVar, Union from urllib.parse import urlparse from typing_extensions import ParamSpec -from azure.core.credentials import AccessToken -from azure.core.credentials_async import AsyncTokenCredential +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy from azure.core.rest import AsyncHttpResponse, HttpRequest +from .http_challenge import HttpChallenge from . import http_challenge_cache as ChallengeCache from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge @@ -64,14 +65,14 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): :param credential: An object which can provide an access token for the vault, such as a credential from :mod:`azure.identity.aio` - :type credential: ~azure.core.credentials_async.AsyncTokenCredential + :type credential: ~azure.core.credentials_async.AsyncTokenProvider """ - def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: + def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None: # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request super().__init__(credential, *scopes, enable_cae=True, **kwargs) - self._credential: AsyncTokenCredential = credential - self._token: Optional[AccessToken] = None + self._credential: AsyncTokenProvider = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None @@ -157,16 +158,10 @@ async def on_request(self, request: PipelineRequest) -> None: if self._need_new_token(): # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" - # Exclude tenant for AD FS authentication - if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope, enable_cae=True) - else: - self._token = await self._credential.get_token( - scope, tenant_id=challenge.tenant_id, enable_cae=True - ) - - # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore + await self._request_kv_token(scope, challenge) + + bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" return # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, @@ -233,4 +228,28 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons return True def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + async def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None: + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. + + :param str scope: The scope for which to request a token. + :param challenge: The challenge for the request being made. + """ + # Exclude tenant for AD FS authentication + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") + # The AsyncSupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {"enable_cae": True} + if challenge.tenant_id and not exclude_tenant: + options["tenant_id"] = challenge.tenant_id + self._token = await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(scope, options=options) + else: + if exclude_tenant: + self._token = await self._credential.get_token(scope, enable_cae=True) + else: + self._token = await cast(AsyncTokenCredential, self._credential).get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py index 8c5815a51d0e..71e133819300 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py @@ -16,10 +16,17 @@ from copy import deepcopy import time -from typing import Any, Optional +from typing import Any, cast, Optional, Union from urllib.parse import urlparse -from azure.core.credentials import AccessToken, TokenCredential +from azure.core.credentials import ( + AccessToken, + AccessTokenInfo, + TokenCredential, + TokenProvider, + TokenRequestOptions, + SupportsTokenInfo, +) from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import BearerTokenCredentialPolicy @@ -76,14 +83,15 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): :param credential: An object which can provide an access token for the vault, such as a credential from :mod:`azure.identity` - :type credential: ~azure.core.credentials.TokenCredential + :type credential: ~azure.core.credentials.TokenProvider + :param str scopes: Lets you specify the type of access needed. """ - def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: + def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None: # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs) - self._credential: TokenCredential = credential - self._token: Optional[AccessToken] = None + self._credential: TokenProvider = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None @@ -166,14 +174,10 @@ def on_request(self, request: PipelineRequest) -> None: if self._need_new_token: # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" - # Exclude tenant for AD FS authentication - if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope, enable_cae=True) - else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id, enable_cae=True) - - # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore + self._request_kv_token(scope, challenge) + + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" return # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, @@ -238,4 +242,28 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> @property def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None: + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. + + :param str scope: The scope for which to request a token. + :param challenge: The challenge for the request being made. + """ + # Exclude tenant for AD FS authentication + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") + # The SupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {"enable_cae": True} + if challenge.tenant_id and not exclude_tenant: + options["tenant_id"] = challenge.tenant_id + self._token = cast(SupportsTokenInfo, self._credential).get_token_info(scope, options=options) + else: + if exclude_tenant: + self._token = self._credential.get_token(scope, enable_cae=True) + else: + self._token = cast(TokenCredential, self._credential).get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py index 8ece1ecc018a..cf1bdcf7c9ae 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py @@ -16,17 +16,18 @@ from copy import deepcopy import time -from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union +from typing import Any, Awaitable, Callable, cast, Optional, overload, TypeVar, Union from urllib.parse import urlparse from typing_extensions import ParamSpec -from azure.core.credentials import AccessToken -from azure.core.credentials_async import AsyncTokenCredential +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy from azure.core.rest import AsyncHttpResponse, HttpRequest +from .http_challenge import HttpChallenge from . import http_challenge_cache as ChallengeCache from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge @@ -64,14 +65,14 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): :param credential: An object which can provide an access token for the vault, such as a credential from :mod:`azure.identity.aio` - :type credential: ~azure.core.credentials_async.AsyncTokenCredential + :type credential: ~azure.core.credentials_async.AsyncTokenProvider """ - def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: + def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None: # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request super().__init__(credential, *scopes, enable_cae=True, **kwargs) - self._credential: AsyncTokenCredential = credential - self._token: Optional[AccessToken] = None + self._credential: AsyncTokenProvider = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None @@ -157,16 +158,10 @@ async def on_request(self, request: PipelineRequest) -> None: if self._need_new_token(): # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" - # Exclude tenant for AD FS authentication - if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope, enable_cae=True) - else: - self._token = await self._credential.get_token( - scope, tenant_id=challenge.tenant_id, enable_cae=True - ) - - # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore + await self._request_kv_token(scope, challenge) + + bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" return # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, @@ -233,4 +228,28 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons return True def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + async def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None: + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. + + :param str scope: The scope for which to request a token. + :param challenge: The challenge for the request being made. + """ + # Exclude tenant for AD FS authentication + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") + # The AsyncSupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {"enable_cae": True} + if challenge.tenant_id and not exclude_tenant: + options["tenant_id"] = challenge.tenant_id + self._token = await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(scope, options=options) + else: + if exclude_tenant: + self._token = await self._credential.get_token(scope, enable_cae=True) + else: + self._token = await cast(AsyncTokenCredential, self._credential).get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py index 8c5815a51d0e..71e133819300 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py @@ -16,10 +16,17 @@ from copy import deepcopy import time -from typing import Any, Optional +from typing import Any, cast, Optional, Union from urllib.parse import urlparse -from azure.core.credentials import AccessToken, TokenCredential +from azure.core.credentials import ( + AccessToken, + AccessTokenInfo, + TokenCredential, + TokenProvider, + TokenRequestOptions, + SupportsTokenInfo, +) from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import BearerTokenCredentialPolicy @@ -76,14 +83,15 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): :param credential: An object which can provide an access token for the vault, such as a credential from :mod:`azure.identity` - :type credential: ~azure.core.credentials.TokenCredential + :type credential: ~azure.core.credentials.TokenProvider + :param str scopes: Lets you specify the type of access needed. """ - def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: + def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None: # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs) - self._credential: TokenCredential = credential - self._token: Optional[AccessToken] = None + self._credential: TokenProvider = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None @@ -166,14 +174,10 @@ def on_request(self, request: PipelineRequest) -> None: if self._need_new_token: # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" - # Exclude tenant for AD FS authentication - if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope, enable_cae=True) - else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id, enable_cae=True) - - # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore + self._request_kv_token(scope, challenge) + + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" return # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, @@ -238,4 +242,28 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> @property def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None: + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. + + :param str scope: The scope for which to request a token. + :param challenge: The challenge for the request being made. + """ + # Exclude tenant for AD FS authentication + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") + # The SupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {"enable_cae": True} + if challenge.tenant_id and not exclude_tenant: + options["tenant_id"] = challenge.tenant_id + self._token = cast(SupportsTokenInfo, self._credential).get_token_info(scope, options=options) + else: + if exclude_tenant: + self._token = self._credential.get_token(scope, enable_cae=True) + else: + self._token = cast(TokenCredential, self._credential).get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py index 40e40f2f730c..de26cc59f07d 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth.py @@ -8,6 +8,7 @@ """ import base64 import functools +from itertools import product import os import time from unittest.mock import Mock, patch @@ -17,8 +18,8 @@ from devtools_testutils import recorded_by_proxy import pytest -from azure.core.credentials import AccessToken, TokenCredential -from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError +from azure.core.credentials import AccessToken, AccessTokenInfo +from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import Pipeline from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.core.rest import HttpRequest @@ -33,6 +34,8 @@ only_default_version = get_decorator(api_versions=[DEFAULT_VERSION]) +TOKEN_TYPES = [AccessToken, AccessTokenInfo] + class TestChallengeAuth(KeyVaultTestCase, KeysTestCase): @pytest.mark.parametrize("api_version,is_hsm", only_default_version) @KeysClientPreparer() @@ -155,7 +158,8 @@ def test_challenge_parsing(): @empty_challenge_cache -def test_scope(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_scope(token_type): """The policy's token requests should always be for an AADv2 scope""" expected_content = b"a duck" @@ -184,15 +188,21 @@ def send(request): def get_token(*scopes, **_): assert len(scopes) == 1 assert scopes[0] == expected_scope - return AccessToken(expected_token, 0) + return token_type(expected_token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) request = HttpRequest("POST", get_random_url()) request.set_bytes_body(expected_content) pipeline.run(request) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 endpoint = "https://authority.net/tenant" @@ -214,7 +224,8 @@ def get_token(*scopes, **_): @empty_challenge_cache -def test_tenant(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_tenant(token_type): """The policy's token requests should pass the parsed tenant ID from the challenge""" expected_content = b"a duck" @@ -240,17 +251,24 @@ def send(request): return Mock(status_code=200) raise ValueError("unexpected request") - def get_token(*_, **kwargs): - assert kwargs.get("tenant_id") == expected_tenant - return AccessToken(expected_token, 0) + def get_token(*_, options=None, **kwargs): + options_bag = options if options else kwargs + assert options_bag.get("tenant_id") == expected_tenant + return token_type(expected_token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) request = HttpRequest("POST", get_random_url()) request.set_bytes_body(expected_content) pipeline.run(request) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 tenant = "tenant-id" endpoint = f"https://authority.net/{tenant}" @@ -265,7 +283,8 @@ def get_token(*_, **kwargs): @empty_challenge_cache -def test_adfs(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_adfs(token_type): """The policy should handle AD FS challenges as a special case and omit the tenant ID from token requests""" expected_content = b"a duck" @@ -294,15 +313,21 @@ def send(request): def get_token(*_, **kwargs): # we shouldn't provide a tenant ID during AD FS authentication assert "tenant_id" not in kwargs - return AccessToken(expected_token, 0) + return token_type(expected_token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) policy = ChallengeAuthPolicy(credential=credential) pipeline = Pipeline(policies=[policy], transport=Mock(send=send)) request = HttpRequest("POST", get_random_url()) request.set_bytes_body(expected_content) pipeline.run(request) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 # Regression test: https://github.com/Azure/azure-sdk-for-python/issues/33621 policy._token = None @@ -321,7 +346,8 @@ def get_token(*_, **kwargs): test_with_challenge(challenge, tenant) -def test_policy_updates_cache(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_policy_updates_cache(token_type): """ It's possible for the challenge returned for a request to change, e.g. when a vault is moved to a new tenant. When the policy receives a 401, it should update the cached challenge for the requested URL, if one exists. @@ -360,23 +386,38 @@ def test_policy_updates_cache(): ), ) - credential = Mock(spec_set=["get_token"], get_token=Mock(return_value=AccessToken(first_token, time.time() + 3600))) + token = token_type(first_token, time.time() + 3600) + + def get_token(*_, **__): + return token + + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=transport) # policy should complete and cache the first challenge and access token for _ in range(2): pipeline.run(HttpRequest("GET", url)) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 # The next request will receive a new challenge. The policy should handle it and update caches. - credential.get_token.return_value = AccessToken(second_token, time.time() + 3600) + token = token_type(second_token, time.time() + 3600) for _ in range(2): pipeline.run(HttpRequest("GET", url)) - assert credential.get_token.call_count == 2 - + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 2 + else: + assert credential.get_token_info.call_count == 2 -def test_token_expiration(): +@empty_challenge_cache +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_token_expiration(token_type): """policy should not use a cached token which has expired""" url = get_random_url() @@ -386,12 +427,15 @@ def test_token_expiration(): second_token = "**" resource = "https://vault.azure.net" - token = AccessToken(first_token, expires_on) + token = token_type(first_token, expires_on) def get_token(*_, **__): return token - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) transport = validating_transport( requests=[ Request(), @@ -410,16 +454,23 @@ def get_token(*_, **__): for _ in range(2): pipeline.run(HttpRequest("GET", url)) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 - token = AccessToken(second_token, time.time() + 3600) + token = token_type(second_token, time.time() + 3600) with patch("time.time", lambda: expires_on): pipeline.run(HttpRequest("GET", url)) - assert credential.get_token.call_count == 2 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 2 + else: + assert credential.get_token_info.call_count == 2 @empty_challenge_cache -def test_preserves_options_and_headers(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_preserves_options_and_headers(token_type): """After a challenge, the policy should send the original request with its options and headers preserved""" url = get_random_url() @@ -427,9 +478,12 @@ def test_preserves_options_and_headers(): resource = "https://vault.azure.net" def get_token(*_, **__): - return AccessToken(token, 0) + return token_type(token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) transport = validating_transport( requests=[Request()] * 2 + [Request(required_headers={"Authorization": "Bearer " + token})], @@ -471,8 +525,8 @@ def verify(request): @empty_challenge_cache -@pytest.mark.parametrize("verify_challenge_resource", [True, False]) -def test_verify_challenge_resource_matches(verify_challenge_resource): +@pytest.mark.parametrize("verify_challenge_resource,token_type", product([True, False], TOKEN_TYPES)) +def test_verify_challenge_resource_matches(verify_challenge_resource, token_type): """The auth policy should raise if the challenge resource doesn't match the request URL unless check is disabled""" url = get_random_url() @@ -481,9 +535,12 @@ def test_verify_challenge_resource_matches(verify_challenge_resource): resource = "https://myvault.azure.net" # Doesn't match a "".vault.azure.net" resource because of the "my" prefix def get_token(*_, **__): - return AccessToken(token, 0) + return token_type(token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) transport = validating_transport( requests=[Request(), Request(required_headers={"Authorization": f"Bearer {token}"})], @@ -524,8 +581,8 @@ def get_token(*_, **__): @empty_challenge_cache -@pytest.mark.parametrize("verify_challenge_resource", [True, False]) -def test_verify_challenge_resource_valid(verify_challenge_resource): +@pytest.mark.parametrize("verify_challenge_resource,token_type", product([True, False], TOKEN_TYPES)) +def test_verify_challenge_resource_valid(verify_challenge_resource, token_type): """The auth policy should raise if the challenge resource isn't a valid URL unless check is disabled""" url = get_random_url() @@ -533,9 +590,12 @@ def test_verify_challenge_resource_valid(verify_challenge_resource): resource = "bad-resource" def get_token(*_, **__): - return AccessToken(token, 0) + return token_type(token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) transport = validating_transport( requests=[Request(), Request(required_headers={"Authorization": f"Bearer {token}"})], @@ -559,7 +619,8 @@ def get_token(*_, **__): @empty_challenge_cache -def test_cae(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_cae(token_type): """The policy should handle claims in a challenge response after having successfully authenticated prior.""" expected_content = b"a duck" @@ -612,26 +673,30 @@ def send(request): return KV_CHALLENGE_RESPONSE raise ValueError("unexpected request") - def get_token(*scopes, **kwargs): - assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == KV_CHALLENGE_TENANT + def get_token(*scopes, options=None, **kwargs): + options_bag = options if options else kwargs + assert options_bag.get("enable_cae") == True + assert options_bag.get("tenant_id") == KV_CHALLENGE_TENANT assert scopes[0] == RESOURCE + "/.default" # Response to KV challenge if Requests.count == 1: - assert kwargs.get("claims") == None + assert options_bag.get("claims") == None return AccessToken(first_token, time.time() + 3600) # Response to CAE challenge elif Requests.count == 3: - assert kwargs.get("claims") == expected_claim + assert options_bag.get("claims") == expected_claim return AccessToken(expected_token, time.time() + 3600) # Response to second KV challenge elif Requests.count == 5: - assert kwargs.get("claims") == None + assert options_bag.get("claims") == None return AccessToken(first_token, time.time() + 3600) elif Requests.count == 6: raise ValueError("unexpected token request") - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) request = HttpRequest("POST", get_random_url()) request.set_bytes_body(expected_content) @@ -639,14 +704,18 @@ def get_token(*scopes, **kwargs): pipeline.run(request) # Send the request again to trigger a CAE challenge pipeline.run(request) # Send the request once to trigger another regular auth challenge - # get_token is called for the CAE challenge and first two KV challenges, but not the final KV challenge - assert credential.get_token.call_count == 3 + # token requests made for the CAE challenge and first two KV challenges, but not the final KV challenge + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 3 + else: + assert credential.get_token_info.call_count == 3 test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) @empty_challenge_cache -def test_cae_consecutive_challenges(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_cae_consecutive_challenges(token_type): """The policy should correctly handle consecutive challenges in cases where the flow is valid or invalid.""" expected_content = b"a duck" @@ -681,33 +750,41 @@ def send(request): return claims_challenge raise ValueError("unexpected request") - def get_token(*scopes, **kwargs): - assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == KV_CHALLENGE_TENANT + def get_token(*scopes, options=None, **kwargs): + options_bag = options if options else kwargs + assert options_bag.get("enable_cae") == True + assert options_bag.get("tenant_id") == KV_CHALLENGE_TENANT assert scopes[0] == RESOURCE + "/.default" # Response to KV challenge if Requests.count == 1: - assert kwargs.get("claims") == None - return AccessToken(first_token, time.time() + 3600) + assert options_bag.get("claims") == None + return token_type(first_token, time.time() + 3600) # Response to first CAE challenge elif Requests.count == 2: - assert kwargs.get("claims") == expected_claim - return AccessToken(expected_token, time.time() + 3600) + assert options_bag.get("claims") == expected_claim + return token_type(expected_token, time.time() + 3600) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) request = HttpRequest("POST", get_random_url()) request.set_bytes_body(expected_content) pipeline.run(request) - # get_token is called for the KV challenge and first CAE challenge, but not the second CAE challenge - assert credential.get_token.call_count == 2 + # token requests made for the KV challenge and first CAE challenge, but not the second CAE challenge + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 2 + else: + assert credential.get_token_info.call_count == 2 test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) @empty_challenge_cache -def test_cae_token_expiry(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +def test_cae_token_expiry(token_type): """The policy should avoid sending claims more than once when a token expires.""" expected_content = b"a duck" @@ -748,31 +825,38 @@ def send(request): return Mock(status_code=200) raise ValueError("unexpected request") - def get_token(*scopes, **kwargs): - assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == KV_CHALLENGE_TENANT + def get_token(*scopes, options=None, **kwargs): + options_bag = options if options else kwargs + assert options_bag.get("enable_cae") == True + assert options_bag.get("tenant_id") == KV_CHALLENGE_TENANT assert scopes[0] == RESOURCE + "/.default" # Response to KV challenge if Requests.count == 1: - assert kwargs.get("claims") == None - return AccessToken(first_token, time.time() + 3600) + assert options_bag.get("claims") == None + return token_type(first_token, time.time() + 3600) # Response to first CAE challenge elif Requests.count == 2: - assert kwargs.get("claims") == expected_claim - return AccessToken(second_token, 0) # Return a token that expires immediately to trigger a refresh + assert options_bag.get("claims") == expected_claim + return token_type(second_token, 0) # Return a token that expires immediately to trigger a refresh # Token refresh before making the final request elif Requests.count == 3: - assert kwargs.get("claims") == None - return AccessToken(third_token, time.time() + 3600) + assert options_bag.get("claims") == None + return token_type(third_token, time.time() + 3600) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = Pipeline(policies=[ChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) request = HttpRequest("POST", get_random_url()) request.set_bytes_body(expected_content) pipeline.run(request) pipeline.run(request) # Send the request again to trigger a token refresh upon expiry - # get_token is called for the KV and CAE challenges, as well as for the token refresh - assert credential.get_token.call_count == 3 + # token requests made for the KV and CAE challenges, as well as for the token refresh + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 3 + else: + assert credential.get_token_info.call_count == 3 test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py index 2fa1da39b35d..81ec711f6ad2 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_challenge_auth_async.py @@ -7,16 +7,15 @@ the challenge cache is global to the process. """ import asyncio -import base64 +from itertools import product import os import time from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from azure.core.credentials import AccessToken -from azure.core.credentials_async import AsyncTokenCredential -from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError +from azure.core.credentials import AccessToken, AccessTokenInfo +from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import AsyncPipeline from azure.core.pipeline.policies import SansIOHTTPPolicy from azure.core.rest import HttpRequest @@ -38,6 +37,7 @@ KV_CHALLENGE_RESPONSE, KV_CHALLENGE_TENANT, RESOURCE, + TOKEN_TYPES, ) only_default_version = get_decorator(is_async=True, api_versions=[DEFAULT_VERSION]) @@ -95,7 +95,8 @@ async def test_enforces_tls(): @pytest.mark.asyncio @empty_challenge_cache -async def test_scope(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +async def test_scope(token_type): """The policy's token requests should always be for an AADv2 scope""" expected_content = b"a duck" @@ -124,9 +125,12 @@ async def send(request): async def get_token(*scopes, **_): assert len(scopes) == 1 assert scopes[0] == expected_scope - return AccessToken(expected_token, 0) + return token_type(expected_token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = AsyncPipeline( policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send) ) @@ -134,7 +138,10 @@ async def get_token(*scopes, **_): request.set_bytes_body(expected_content) await pipeline.run(request) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 endpoint = "https://authority.net/tenant" @@ -157,7 +164,8 @@ async def get_token(*scopes, **_): @pytest.mark.asyncio @empty_challenge_cache -async def test_tenant(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +async def test_tenant(token_type): """The policy's token requests should pass the parsed tenant ID from the challenge""" expected_content = b"a duck" @@ -183,11 +191,15 @@ async def send(request): return Mock(status_code=200) raise ValueError("unexpected request") - async def get_token(*_, **kwargs): - assert kwargs.get("tenant_id") == expected_tenant - return AccessToken(expected_token, 0) + async def get_token(*_, options=None, **kwargs): + options_bag = options if options else kwargs + assert options_bag.get("tenant_id") == expected_tenant + return token_type(expected_token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = AsyncPipeline( policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send) ) @@ -195,7 +207,10 @@ async def get_token(*_, **kwargs): request.set_bytes_body(expected_content) await pipeline.run(request) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 tenant = "tenant-id" endpoint = f"https://authority.net/{tenant}" @@ -211,7 +226,8 @@ async def get_token(*_, **kwargs): @pytest.mark.asyncio @empty_challenge_cache -async def test_adfs(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +async def test_adfs(token_type): """The policy should handle AD FS challenges as a special case and omit the tenant ID from token requests""" expected_content = b"a duck" @@ -240,15 +256,21 @@ async def send(request): async def get_token(*_, **kwargs): # we shouldn't provide a tenant ID during AD FS authentication assert "tenant_id" not in kwargs - return AccessToken(expected_token, 0) + return token_type(expected_token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) policy = AsyncChallengeAuthPolicy(credential=credential) pipeline = AsyncPipeline(policies=[policy], transport=Mock(send=send)) request = HttpRequest("POST", get_random_url()) request.set_bytes_body(expected_content) await pipeline.run(request) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 # Regression test: https://github.com/Azure/azure-sdk-for-python/issues/33621 policy._token = None @@ -269,7 +291,8 @@ async def get_token(*_, **kwargs): @pytest.mark.asyncio @empty_challenge_cache -async def test_policy_updates_cache(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +async def test_policy_updates_cache(token_type): """ It's possible for the challenge returned for a request to change, e.g. when a vault is moved to a new tenant. When the policy receives a 401, it should update the cached challenge for the requested URL, if one exists. @@ -308,29 +331,39 @@ async def test_policy_updates_cache(): ), ) - token = AccessToken(first_token, time.time() + 3600) + token = token_type(first_token, time.time() + 3600) async def get_token(*_, **__): return token - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = AsyncPipeline(policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=transport) # policy should complete and cache the first challenge and access token for _ in range(2): await pipeline.run(HttpRequest("GET", url)) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 # The next request will receive a new challenge. The policy should handle it and update caches. - token = AccessToken(second_token, time.time() + 3600) + token = token_type(second_token, time.time() + 3600) for _ in range(2): await pipeline.run(HttpRequest("GET", url)) - assert credential.get_token.call_count == 2 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 2 + else: + assert credential.get_token_info.call_count == 2 @pytest.mark.asyncio @empty_challenge_cache -async def test_token_expiration(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +async def test_token_expiration(token_type): """policy should not use a cached token which has expired""" url = get_random_url() @@ -340,12 +373,15 @@ async def test_token_expiration(): second_token = "**" resource = "https://vault.azure.net" - token = AccessToken(first_token, expires_on) + token = token_type(first_token, expires_on) async def get_token(*_, **__): return token - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) transport = async_validating_transport( requests=[ Request(), @@ -364,17 +400,24 @@ async def get_token(*_, **__): for _ in range(2): await pipeline.run(HttpRequest("GET", url)) - assert credential.get_token.call_count == 1 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 1 + else: + assert credential.get_token_info.call_count == 1 - token = AccessToken(second_token, time.time() + 3600) + token = token_type(second_token, time.time() + 3600) with patch("time.time", lambda: expires_on): await pipeline.run(HttpRequest("GET", url)) - assert credential.get_token.call_count == 2 + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 2 + else: + assert credential.get_token_info.call_count == 2 @pytest.mark.asyncio @empty_challenge_cache -async def test_preserves_options_and_headers(): +@pytest.mark.parametrize("token_type", TOKEN_TYPES) +async def test_preserves_options_and_headers(token_type): """After a challenge, the policy should send the original request with its options and headers preserved""" url = get_random_url() @@ -382,9 +425,12 @@ async def test_preserves_options_and_headers(): resource = "https://vault.azure.net" async def get_token(*_, **__): - return AccessToken(token, 0) + return token_type(token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) transport = async_validating_transport( requests=[Request()] * 2 + [Request(required_headers={"Authorization": "Bearer " + token})], @@ -426,8 +472,8 @@ def verify(request): @pytest.mark.asyncio @empty_challenge_cache -@pytest.mark.parametrize("verify_challenge_resource", [True, False]) -async def test_verify_challenge_resource_matches(verify_challenge_resource): +@pytest.mark.parametrize("verify_challenge_resource,token_type", product([True, False], TOKEN_TYPES)) +async def test_verify_challenge_resource_matches(verify_challenge_resource, token_type): """The auth policy should raise if the challenge resource doesn't match the request URL unless check is disabled""" url = get_random_url() @@ -436,9 +482,12 @@ async def test_verify_challenge_resource_matches(verify_challenge_resource): resource = "https://myvault.azure.net" # Doesn't match a "".vault.azure.net" resource because of the "my" prefix async def get_token(*_, **__): - return AccessToken(token, 0) + return token_type(token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) transport = async_validating_transport( requests=[Request(), Request(required_headers={"Authorization": f"Bearer {token}"})], @@ -479,8 +528,8 @@ async def get_token(*_, **__): @pytest.mark.asyncio -@pytest.mark.parametrize("verify_challenge_resource", [True, False]) -async def test_verify_challenge_resource_valid(verify_challenge_resource): +@pytest.mark.parametrize("verify_challenge_resource,token_type", product([True, False], TOKEN_TYPES)) +async def test_verify_challenge_resource_valid(verify_challenge_resource, token_type): """The auth policy should raise if the challenge resource isn't a valid URL unless check is disabled""" url = get_random_url() @@ -488,9 +537,12 @@ async def test_verify_challenge_resource_valid(verify_challenge_resource): resource = "bad-resource" async def get_token(*_, **__): - return AccessToken(token, 0) + return token_type(token, 0) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) transport = async_validating_transport( requests=[Request(), Request(required_headers={"Authorization": f"Bearer {token}"})], @@ -515,7 +567,8 @@ async def get_token(*_, **__): @pytest.mark.asyncio @empty_challenge_cache -async def test_cae(): +@pytest.mark.parametrize("token_type", [AccessToken, AccessTokenInfo]) +async def test_cae(token_type): """The policy should handle claims in a challenge response after having successfully authenticated prior.""" expected_content = b"a duck" @@ -568,26 +621,30 @@ async def send(request): return KV_CHALLENGE_RESPONSE raise ValueError("unexpected request") - async def get_token(*scopes, **kwargs): - assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == KV_CHALLENGE_TENANT + async def get_token(*scopes, options=None, **kwargs): + options_bag = options if options else kwargs + assert options_bag.get("enable_cae") == True + assert options_bag.get("tenant_id") == KV_CHALLENGE_TENANT assert scopes[0] == RESOURCE + "/.default" # Response to KV challenge if Requests.count == 1: - assert kwargs.get("claims") == None + assert options_bag.get("claims") == None return AccessToken(first_token, time.time() + 3600) # Response to CAE challenge elif Requests.count == 3: - assert kwargs.get("claims") == expected_claim + assert options_bag.get("claims") == expected_claim return AccessToken(expected_token, time.time() + 3600) # Response to second KV challenge elif Requests.count == 5: - assert kwargs.get("claims") == None + assert options_bag.get("claims") == None return AccessToken(first_token, time.time() + 3600) elif Requests.count == 6: raise ValueError("unexpected token request") - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = AsyncPipeline(policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) request = HttpRequest("POST", get_random_url()) request.set_bytes_body(expected_content) @@ -595,15 +652,19 @@ async def get_token(*scopes, **kwargs): await pipeline.run(request) # Send the request again to trigger a CAE challenge await pipeline.run(request) # Send the request once to trigger another regular auth challenge - # get_token is called for the CAE challenge and first two KV challenges, but not the final KV challenge - assert credential.get_token.call_count == 3 + # token requests made for the CAE challenge and first two KV challenges, but not the final KV challenge + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 3 + else: + assert credential.get_token_info.call_count == 3 await test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) @pytest.mark.asyncio @empty_challenge_cache -async def test_cae_consecutive_challenges(): +@pytest.mark.parametrize("token_type", [AccessToken, AccessTokenInfo]) +async def test_cae_consecutive_challenges(token_type): """The policy should correctly handle consecutive challenges in cases where the flow is valid or invalid.""" expected_content = b"a duck" @@ -638,34 +699,42 @@ async def send(request): return claims_challenge raise ValueError("unexpected request") - async def get_token(*scopes, **kwargs): - assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == KV_CHALLENGE_TENANT + async def get_token(*scopes, options=None, **kwargs): + options_bag = options if options else kwargs + assert options_bag.get("enable_cae") == True + assert options_bag.get("tenant_id") == KV_CHALLENGE_TENANT assert scopes[0] == RESOURCE + "/.default" # Response to KV challenge if Requests.count == 1: - assert kwargs.get("claims") == None + assert options_bag.get("claims") == None return AccessToken(first_token, time.time() + 3600) # Response to first CAE challenge elif Requests.count == 2: - assert kwargs.get("claims") == expected_claim + assert options_bag.get("claims") == expected_claim return AccessToken(expected_token, time.time() + 3600) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = AsyncPipeline(policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) request = HttpRequest("POST", get_random_url()) request.set_bytes_body(expected_content) await pipeline.run(request) - # get_token is called for the KV challenge and first CAE challenge, but not the second CAE challenge - assert credential.get_token.call_count == 2 + # token requests made for the KV challenge and first CAE challenge, but not the second CAE challenge + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 2 + else: + assert credential.get_token_info.call_count == 2 await test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) @pytest.mark.asyncio @empty_challenge_cache -async def test_cae_token_expiry(): +@pytest.mark.parametrize("token_type", [AccessToken, AccessTokenInfo]) +async def test_cae_token_expiry(token_type): """The policy should avoid sending claims more than once when a token expires.""" expected_content = b"a duck" @@ -706,31 +775,38 @@ async def send(request): return Mock(status_code=200) raise ValueError("unexpected request") - async def get_token(*scopes, **kwargs): - assert kwargs.get("enable_cae") == True - assert kwargs.get("tenant_id") == KV_CHALLENGE_TENANT + async def get_token(*scopes, options=None, **kwargs): + options_bag = options if options else kwargs + assert options_bag.get("enable_cae") == True + assert options_bag.get("tenant_id") == KV_CHALLENGE_TENANT assert scopes[0] == RESOURCE + "/.default" # Response to KV challenge if Requests.count == 1: - assert kwargs.get("claims") == None + assert options_bag.get("claims") == None return AccessToken(first_token, time.time() + 3600) # Response to first CAE challenge elif Requests.count == 2: - assert kwargs.get("claims") == expected_claim + assert options_bag.get("claims") == expected_claim return AccessToken(second_token, 0) # Return a token that expires immediately to trigger a refresh # Token refresh before making the final request elif Requests.count == 3: - assert kwargs.get("claims") == None + assert options_bag.get("claims") == None return AccessToken(third_token, time.time() + 3600) - credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + if token_type == AccessToken: + credential = Mock(spec_set=["get_token"], get_token=Mock(wraps=get_token)) + else: + credential = Mock(spec_set=["get_token_info"], get_token_info=Mock(wraps=get_token)) pipeline = AsyncPipeline(policies=[AsyncChallengeAuthPolicy(credential=credential)], transport=Mock(send=send)) request = HttpRequest("POST", get_random_url()) request.set_bytes_body(expected_content) await pipeline.run(request) await pipeline.run(request) # Send the request again to trigger a token refresh upon expiry - # get_token is called for the KV and CAE challenges, as well as for the token refresh - assert credential.get_token.call_count == 3 + # token requests made for the KV and CAE challenges, as well as for the token refresh + if hasattr(credential, "get_token"): + assert credential.get_token.call_count == 3 + else: + assert credential.get_token_info.call_count == 3 await test_with_challenge(CAE_CHALLENGE_RESPONSE, CAE_DECODED_CLAIM) diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py index 8ece1ecc018a..cf1bdcf7c9ae 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py @@ -16,17 +16,18 @@ from copy import deepcopy import time -from typing import Any, Awaitable, Callable, Optional, overload, TypeVar, Union +from typing import Any, Awaitable, Callable, cast, Optional, overload, TypeVar, Union from urllib.parse import urlparse from typing_extensions import ParamSpec -from azure.core.credentials import AccessToken -from azure.core.credentials_async import AsyncTokenCredential +from azure.core.credentials import AccessToken, AccessTokenInfo, TokenRequestOptions +from azure.core.credentials_async import AsyncSupportsTokenInfo, AsyncTokenCredential, AsyncTokenProvider from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy from azure.core.rest import AsyncHttpResponse, HttpRequest +from .http_challenge import HttpChallenge from . import http_challenge_cache as ChallengeCache from .challenge_auth_policy import _enforce_tls, _has_claims, _update_challenge @@ -64,14 +65,14 @@ class AsyncChallengeAuthPolicy(AsyncBearerTokenCredentialPolicy): :param credential: An object which can provide an access token for the vault, such as a credential from :mod:`azure.identity.aio` - :type credential: ~azure.core.credentials_async.AsyncTokenCredential + :type credential: ~azure.core.credentials_async.AsyncTokenProvider """ - def __init__(self, credential: AsyncTokenCredential, *scopes: str, **kwargs: Any) -> None: + def __init__(self, credential: AsyncTokenProvider, *scopes: str, **kwargs: Any) -> None: # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request super().__init__(credential, *scopes, enable_cae=True, **kwargs) - self._credential: AsyncTokenCredential = credential - self._token: Optional[AccessToken] = None + self._credential: AsyncTokenProvider = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None @@ -157,16 +158,10 @@ async def on_request(self, request: PipelineRequest) -> None: if self._need_new_token(): # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" - # Exclude tenant for AD FS authentication - if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = await self._credential.get_token(scope, enable_cae=True) - else: - self._token = await self._credential.get_token( - scope, tenant_id=challenge.tenant_id, enable_cae=True - ) - - # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore + await self._request_kv_token(scope, challenge) + + bearer_token = cast(Union[AccessToken, AccessTokenInfo], self._token).token + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" return # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, @@ -233,4 +228,28 @@ async def on_challenge(self, request: PipelineRequest, response: PipelineRespons return True def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + async def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None: + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. + + :param str scope: The scope for which to request a token. + :param challenge: The challenge for the request being made. + """ + # Exclude tenant for AD FS authentication + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") + # The AsyncSupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {"enable_cae": True} + if challenge.tenant_id and not exclude_tenant: + options["tenant_id"] = challenge.tenant_id + self._token = await cast(AsyncSupportsTokenInfo, self._credential).get_token_info(scope, options=options) + else: + if exclude_tenant: + self._token = await self._credential.get_token(scope, enable_cae=True) + else: + self._token = await cast(AsyncTokenCredential, self._credential).get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py index 8c5815a51d0e..71e133819300 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py @@ -16,10 +16,17 @@ from copy import deepcopy import time -from typing import Any, Optional +from typing import Any, cast, Optional, Union from urllib.parse import urlparse -from azure.core.credentials import AccessToken, TokenCredential +from azure.core.credentials import ( + AccessToken, + AccessTokenInfo, + TokenCredential, + TokenProvider, + TokenRequestOptions, + SupportsTokenInfo, +) from azure.core.exceptions import ServiceRequestError from azure.core.pipeline import PipelineRequest, PipelineResponse from azure.core.pipeline.policies import BearerTokenCredentialPolicy @@ -76,14 +83,15 @@ class ChallengeAuthPolicy(BearerTokenCredentialPolicy): :param credential: An object which can provide an access token for the vault, such as a credential from :mod:`azure.identity` - :type credential: ~azure.core.credentials.TokenCredential + :type credential: ~azure.core.credentials.TokenProvider + :param str scopes: Lets you specify the type of access needed. """ - def __init__(self, credential: TokenCredential, *scopes: str, **kwargs: Any) -> None: + def __init__(self, credential: TokenProvider, *scopes: str, **kwargs: Any) -> None: # Pass `enable_cae` so `enable_cae=True` is always passed through self.authorize_request super(ChallengeAuthPolicy, self).__init__(credential, *scopes, enable_cae=True, **kwargs) - self._credential: TokenCredential = credential - self._token: Optional[AccessToken] = None + self._credential: TokenProvider = credential + self._token: Optional[Union["AccessToken", "AccessTokenInfo"]] = None self._verify_challenge_resource = kwargs.pop("verify_challenge_resource", True) self._request_copy: Optional[HttpRequest] = None @@ -166,14 +174,10 @@ def on_request(self, request: PipelineRequest) -> None: if self._need_new_token: # azure-identity credentials require an AADv2 scope but the challenge may specify an AADv1 resource scope = challenge.get_scope() or challenge.get_resource() + "/.default" - # Exclude tenant for AD FS authentication - if challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs"): - self._token = self._credential.get_token(scope, enable_cae=True) - else: - self._token = self._credential.get_token(scope, tenant_id=challenge.tenant_id, enable_cae=True) - - # ignore mypy's warning -- although self._token is Optional, get_token raises when it fails to get a token - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" # type: ignore + self._request_kv_token(scope, challenge) + + bearer_token = cast(Union["AccessToken", "AccessTokenInfo"], self._token).token + request.http_request.headers["Authorization"] = f"Bearer {bearer_token}" return # else: discover authentication information by eliciting a challenge from Key Vault. Remove any request data, @@ -238,4 +242,28 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) -> @property def _need_new_token(self) -> bool: - return not self._token or self._token.expires_on - time.time() < 300 + now = time.time() + refresh_on = getattr(self._token, "refresh_on", None) + return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 + + def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None: + """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. + + :param str scope: The scope for which to request a token. + :param challenge: The challenge for the request being made. + """ + # Exclude tenant for AD FS authentication + exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") + # The SupportsTokenInfo protocol needs TokenRequestOptions for token requests instead of kwargs + if hasattr(self._credential, "get_token_info"): + options: TokenRequestOptions = {"enable_cae": True} + if challenge.tenant_id and not exclude_tenant: + options["tenant_id"] = challenge.tenant_id + self._token = cast(SupportsTokenInfo, self._credential).get_token_info(scope, options=options) + else: + if exclude_tenant: + self._token = self._credential.get_token(scope, enable_cae=True) + else: + self._token = cast(TokenCredential, self._credential).get_token( + scope, tenant_id=challenge.tenant_id, enable_cae=True + ) From bf25378ab744af202d0da057f2ffbc4a2393acda Mon Sep 17 00:00:00 2001 From: mccoyp Date: Tue, 8 Oct 2024 16:03:30 -0700 Subject: [PATCH 24/25] Pylint --- .../administration/_internal/async_challenge_auth_policy.py | 3 ++- .../keyvault/administration/_internal/challenge_auth_policy.py | 3 ++- .../certificates/_shared/async_challenge_auth_policy.py | 3 ++- .../keyvault/certificates/_shared/challenge_auth_policy.py | 3 ++- .../azure/keyvault/keys/_shared/async_challenge_auth_policy.py | 3 ++- .../azure/keyvault/keys/_shared/challenge_auth_policy.py | 3 ++- .../keyvault/secrets/_shared/async_challenge_auth_policy.py | 3 ++- .../azure/keyvault/secrets/_shared/challenge_auth_policy.py | 3 ++- 8 files changed, 16 insertions(+), 8 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py index cf1bdcf7c9ae..e9b44fc68e55 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/async_challenge_auth_policy.py @@ -232,11 +232,12 @@ def _need_new_token(self) -> bool: refresh_on = getattr(self._token, "refresh_on", None) return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 - async def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None: + async def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. :param str scope: The scope for which to request a token. :param challenge: The challenge for the request being made. + :type challenge: HttpChallenge """ # Exclude tenant for AD FS authentication exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") diff --git a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py index 71e133819300..eb4073d0e699 100644 --- a/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-administration/azure/keyvault/administration/_internal/challenge_auth_policy.py @@ -246,11 +246,12 @@ def _need_new_token(self) -> bool: refresh_on = getattr(self._token, "refresh_on", None) return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 - def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None: + def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. :param str scope: The scope for which to request a token. :param challenge: The challenge for the request being made. + :type challenge: HttpChallenge """ # Exclude tenant for AD FS authentication exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py index cf1bdcf7c9ae..e9b44fc68e55 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_challenge_auth_policy.py @@ -232,11 +232,12 @@ def _need_new_token(self) -> bool: refresh_on = getattr(self._token, "refresh_on", None) return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 - async def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None: + async def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. :param str scope: The scope for which to request a token. :param challenge: The challenge for the request being made. + :type challenge: HttpChallenge """ # Exclude tenant for AD FS authentication exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py index 71e133819300..eb4073d0e699 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/challenge_auth_policy.py @@ -246,11 +246,12 @@ def _need_new_token(self) -> bool: refresh_on = getattr(self._token, "refresh_on", None) return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 - def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None: + def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. :param str scope: The scope for which to request a token. :param challenge: The challenge for the request being made. + :type challenge: HttpChallenge """ # Exclude tenant for AD FS authentication exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py index cf1bdcf7c9ae..e9b44fc68e55 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_challenge_auth_policy.py @@ -232,11 +232,12 @@ def _need_new_token(self) -> bool: refresh_on = getattr(self._token, "refresh_on", None) return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 - async def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None: + async def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. :param str scope: The scope for which to request a token. :param challenge: The challenge for the request being made. + :type challenge: HttpChallenge """ # Exclude tenant for AD FS authentication exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py index 71e133819300..eb4073d0e699 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/challenge_auth_policy.py @@ -246,11 +246,12 @@ def _need_new_token(self) -> bool: refresh_on = getattr(self._token, "refresh_on", None) return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 - def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None: + def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. :param str scope: The scope for which to request a token. :param challenge: The challenge for the request being made. + :type challenge: HttpChallenge """ # Exclude tenant for AD FS authentication exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py index cf1bdcf7c9ae..e9b44fc68e55 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_challenge_auth_policy.py @@ -232,11 +232,12 @@ def _need_new_token(self) -> bool: refresh_on = getattr(self._token, "refresh_on", None) return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 - async def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None: + async def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. :param str scope: The scope for which to request a token. :param challenge: The challenge for the request being made. + :type challenge: HttpChallenge """ # Exclude tenant for AD FS authentication exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py index 71e133819300..eb4073d0e699 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/challenge_auth_policy.py @@ -246,11 +246,12 @@ def _need_new_token(self) -> bool: refresh_on = getattr(self._token, "refresh_on", None) return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300 - def _request_kv_token(self, scope: str, challenge: HttpChallenge, **kwargs: Any) -> None: + def _request_kv_token(self, scope: str, challenge: HttpChallenge) -> None: """Implementation of BearerTokenCredentialPolicy's _request_token method, but specific to Key Vault. :param str scope: The scope for which to request a token. :param challenge: The challenge for the request being made. + :type challenge: HttpChallenge """ # Exclude tenant for AD FS authentication exclude_tenant = challenge.tenant_id and challenge.tenant_id.lower().endswith("adfs") From 3cb6481008ee88fd3d65ffdcb75d498d3667e787 Mon Sep 17 00:00:00 2001 From: mccoyp Date: Wed, 16 Oct 2024 11:10:38 -0700 Subject: [PATCH 25/25] Mention Core bump, enable_cae kwarg in changelogs --- sdk/keyvault/azure-keyvault-administration/CHANGELOG.md | 3 ++- sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md | 3 ++- sdk/keyvault/azure-keyvault-keys/CHANGELOG.md | 3 ++- sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sdk/keyvault/azure-keyvault-administration/CHANGELOG.md b/sdk/keyvault/azure-keyvault-administration/CHANGELOG.md index 5fbdb54852ae..cc1929a10aa2 100644 --- a/sdk/keyvault/azure-keyvault-administration/CHANGELOG.md +++ b/sdk/keyvault/azure-keyvault-administration/CHANGELOG.md @@ -4,7 +4,7 @@ ### Features Added - Added support for service API version `7.6-preview.1` -- Added support for Continuous Access Evaluation (CAE) +- Added support for Continuous Access Evaluation (CAE). `enable_cae=True` is passed to all `get_token` requests. ### Breaking Changes @@ -13,6 +13,7 @@ ([#34744](https://github.com/Azure/azure-sdk-for-python/issues/34744)) ### Other Changes +- Updated minimum `azure-core` version to 1.31.0 - Key Vault API version `7.6-preview.1` is now the default ## 4.4.0 (2024-02-22) diff --git a/sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md b/sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md index 130535c01007..40fbe4b1674a 100644 --- a/sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md +++ b/sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md @@ -4,7 +4,7 @@ ### Features Added - Added support for service API version `7.6-preview.1` -- Added support for Continuous Access Evaluation (CAE) +- Added support for Continuous Access Evaluation (CAE). `enable_cae=True` is passed to all `get_token` requests. ### Breaking Changes @@ -13,6 +13,7 @@ ([#34744](https://github.com/Azure/azure-sdk-for-python/issues/34744)) ### Other Changes +- Updated minimum `azure-core` version to 1.31.0 - Key Vault API version `7.6-preview.1` is now the default ## 4.8.0 (2024-02-22) diff --git a/sdk/keyvault/azure-keyvault-keys/CHANGELOG.md b/sdk/keyvault/azure-keyvault-keys/CHANGELOG.md index 6cb678d30d13..faf4085e14e1 100644 --- a/sdk/keyvault/azure-keyvault-keys/CHANGELOG.md +++ b/sdk/keyvault/azure-keyvault-keys/CHANGELOG.md @@ -4,7 +4,7 @@ ### Features Added - Added support for service API version `7.6-preview.1` -- Added support for Continuous Access Evaluation (CAE) +- Added support for Continuous Access Evaluation (CAE). `enable_cae=True` is passed to all `get_token` requests. ### Breaking Changes @@ -13,6 +13,7 @@ ([#34744](https://github.com/Azure/azure-sdk-for-python/issues/34744)) ### Other Changes +- Updated minimum `azure-core` version to 1.31.0 - Key Vault API version `7.6-preview.1` is now the default ## 4.9.0 (2024-02-22) diff --git a/sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md b/sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md index 2a886c7aedce..e02843b0c0cd 100644 --- a/sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md +++ b/sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md @@ -4,7 +4,7 @@ ### Features Added - Added support for service API version `7.6-preview.1` -- Added support for Continuous Access Evaluation (CAE) +- Added support for Continuous Access Evaluation (CAE). `enable_cae=True` is passed to all `get_token` requests. ### Breaking Changes @@ -13,6 +13,7 @@ ([#34744](https://github.com/Azure/azure-sdk-for-python/issues/34744)) ### Other Changes +- Updated minimum `azure-core` version to 1.31.0 - Key Vault API version `7.6-preview.1` is now the default ## 4.8.0 (2024-02-22)