From 21a088d7a7b56af1760f8b40976a286dfd46ab82 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 6 May 2024 15:21:41 +0200 Subject: [PATCH 1/8] Infer azure tenant id --- databricks/sdk/azure.py | 19 +++++++++++++++++++ databricks/sdk/credentials_provider.py | 20 +++++++++----------- tests/test_azure.py | 20 ++++++++++++++++++++ 3 files changed, 48 insertions(+), 11 deletions(-) create mode 100644 tests/test_azure.py diff --git a/databricks/sdk/azure.py b/databricks/sdk/azure.py index ec084cf22..1cb10ff90 100644 --- a/databricks/sdk/azure.py +++ b/databricks/sdk/azure.py @@ -1,10 +1,16 @@ +import logging from dataclasses import dataclass from typing import Dict +from urllib import parse + +import requests from .oauth import TokenSource from .service.provisioning import Workspace +logger = logging.getLogger(__name__) + @dataclass class AzureEnvironment: name: str @@ -52,3 +58,16 @@ def get_azure_resource_id(workspace: Workspace): return (f'/subscriptions/{workspace.azure_workspace_info.subscription_id}' f'/resourceGroups/{workspace.azure_workspace_info.resource_group}' f'/providers/Microsoft.Databricks/workspaces/{workspace.workspace_name}') + + +def _load_azure_tenant_id(cfg: 'Config'): + if not cfg.is_azure or cfg.azure_tenant_id is not None or cfg.host is None: + return + logging.debug(f'Loading tenant ID from {cfg.host}/aad/auth') + resp = requests.get(f'{cfg.host}/aad/auth', allow_redirects=False) + entra_id_endpoint = resp.headers.get('Location') + if entra_id_endpoint is None: + return + url = parse.urlparse(entra_id_endpoint) + cfg.azure_tenant_id = url.path.split('/')[1] + logging.debug(f'Loaded tenant ID: {cfg.azure_tenant_id}') diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 027275998..24b029224 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -18,7 +18,7 @@ from google.auth.transport.requests import Request from google.oauth2 import service_account -from .azure import add_sp_management_token, add_workspace_id_header +from .azure import add_sp_management_token, add_workspace_id_header, _load_azure_tenant_id from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token, TokenCache, TokenSource) @@ -179,11 +179,10 @@ def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenS @credentials_provider('azure-client-secret', - ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id']) + ['is_azure', 'azure_client_id', 'azure_client_secret']) def azure_service_principal(cfg: 'Config') -> HeaderFactory: """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request, while automatically resolving different Azure environment endpoints. """ - def token_source_for(resource: str) -> TokenSource: aad_endpoint = cfg.arm_environment.active_directory_endpoint return ClientCredentials(client_id=cfg.azure_client_id, @@ -192,6 +191,7 @@ def token_source_for(resource: str) -> TokenSource: endpoint_params={"resource": resource}, use_params=True) + _load_azure_tenant_id(cfg) _ensure_host_present(cfg, token_source_for) logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id) inner = token_source_for(cfg.effective_azure_login_app_id) @@ -363,11 +363,13 @@ def refresh(self) -> Token: class AzureCliTokenSource(CliTokenSource): """ Obtain the token granted by `az login` CLI command """ - def __init__(self, resource: str, subscription: str = ""): + def __init__(self, resource: str, subscription: str = "", tenant: str = None): cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"] if subscription != "": cmd.append("--subscription") cmd.append(subscription) + if tenant: + cmd.extend(["--tenant", tenant]) super().__init__(cmd=cmd, token_type_field='tokenType', access_token_field='accessToken', @@ -395,7 +397,7 @@ def is_human_user(self) -> bool: @staticmethod def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource': subscription = AzureCliTokenSource.get_subscription(cfg) - if subscription != "": + if cfg.azure_tenant_id == "" and subscription != "": token_source = AzureCliTokenSource(resource, subscription) try: # This will fail if the user has access to the workspace, but not to the subscription @@ -406,7 +408,7 @@ def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource': except OSError: logger.warning("Failed to get token for subscription. Using resource only token.") - token_source = AzureCliTokenSource(resource) + token_source = AzureCliTokenSource(resource, cfg.azure_tenant_id) token_source.token() return token_source @@ -425,6 +427,7 @@ def get_subscription(cfg: 'Config') -> str: @credentials_provider('azure-cli', ['is_azure']) def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]: """ Adds refreshed OAuth token granted by `az login` command to every request. """ + _load_azure_tenant_id(cfg) token_source = None mgmt_token_source = None try: @@ -448,11 +451,6 @@ def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]: _ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource)) logger.info("Using Azure CLI authentication with AAD tokens") - if not cfg.is_account_client and AzureCliTokenSource.get_subscription(cfg) == "": - logger.warning( - "azure_workspace_resource_id field not provided. " - "It is recommended to specify this field in the Databricks configuration to avoid authentication errors." - ) def inner() -> Dict[str, str]: token = token_source.token() diff --git a/tests/test_azure.py b/tests/test_azure.py new file mode 100644 index 000000000..9d1b1d2fb --- /dev/null +++ b/tests/test_azure.py @@ -0,0 +1,20 @@ +from databricks.sdk.config import Config +import os + +__tests__ = os.path.dirname(__file__) + + +def test_load_azure_tenant_id(requests_mock, monkeypatch): + monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://login.microsoftonline.com/abc123xyz/oauth2/authorize'}) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id == 'abc123xyz' + assert mock.called_once + + +def test_load_azure_tenant_id_tenant_id_set(requests_mock, monkeypatch): + monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://login.microsoftonline.com/abc123xyz/oauth2/authorize'}) + cfg = Config(host="https://abc123.azuredatabricks.net", azure_tenant_id="123456789") + assert cfg.azure_tenant_id == '123456789' + assert mock.call_count == 0 From 0069bd6fba468f762dce4f151335fa492f788223 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 6 May 2024 15:23:40 +0200 Subject: [PATCH 2/8] logging --- databricks/sdk/azure.py | 1 + 1 file changed, 1 insertion(+) diff --git a/databricks/sdk/azure.py b/databricks/sdk/azure.py index 1cb10ff90..151a98cb6 100644 --- a/databricks/sdk/azure.py +++ b/databricks/sdk/azure.py @@ -67,6 +67,7 @@ def _load_azure_tenant_id(cfg: 'Config'): resp = requests.get(f'{cfg.host}/aad/auth', allow_redirects=False) entra_id_endpoint = resp.headers.get('Location') if entra_id_endpoint is None: + logging.debug(f'No Location header in response from {cfg.host}/aad/auth') return url = parse.urlparse(entra_id_endpoint) cfg.azure_tenant_id = url.path.split('/')[1] From 8a59c084055fdf3bc623802c70033f9f80322438 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 18 Jul 2024 09:47:54 +0200 Subject: [PATCH 3/8] take two --- databricks/sdk/azure.py | 17 ------------- databricks/sdk/config.py | 22 ++++++++++++++++ databricks/sdk/credentials_provider.py | 20 +++++++-------- tests/test_azure.py | 20 --------------- tests/test_config.py | 35 ++++++++++++++++++++++++++ 5 files changed, 66 insertions(+), 48 deletions(-) delete mode 100644 tests/test_azure.py diff --git a/databricks/sdk/azure.py b/databricks/sdk/azure.py index f06477f88..372669d61 100644 --- a/databricks/sdk/azure.py +++ b/databricks/sdk/azure.py @@ -1,7 +1,4 @@ from typing import Dict -from urllib import parse - -import requests from .oauth import TokenSource from .service.provisioning import Workspace @@ -28,17 +25,3 @@ def get_azure_resource_id(workspace: Workspace): return (f'/subscriptions/{workspace.azure_workspace_info.subscription_id}' f'/resourceGroups/{workspace.azure_workspace_info.resource_group}' f'/providers/Microsoft.Databricks/workspaces/{workspace.workspace_name}') - - -def _load_azure_tenant_id(cfg: 'Config'): - if not cfg.is_azure or cfg.azure_tenant_id is not None or cfg.host is None: - return - logging.debug(f'Loading tenant ID from {cfg.host}/aad/auth') - resp = requests.get(f'{cfg.host}/aad/auth', allow_redirects=False) - entra_id_endpoint = resp.headers.get('Location') - if entra_id_endpoint is None: - logging.debug(f'No Location header in response from {cfg.host}/aad/auth') - return - url = parse.urlparse(entra_id_endpoint) - cfg.azure_tenant_id = url.path.split('/')[1] - logging.debug(f'Loaded tenant ID: {cfg.azure_tenant_id}') diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 47d0ecc44..0d9823231 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -119,6 +119,7 @@ def __init__(self, self._load_from_env() self._known_file_config_loader() self._fix_host_if_needed() + self._load_azure_tenant_id() self._validate() self.init_auth() self._init_product(product, product_version) @@ -363,6 +364,27 @@ def _fix_host_if_needed(self): self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment)) + def _load_azure_tenant_id(self): + if not self.is_azure or self.azure_tenant_id is not None or self.host is None: + return + login_url = f'{self.host}/aad/auth' + logger.debug(f'Loading tenant ID from {login_url}') + resp = requests.get(login_url, allow_redirects=False) + if resp.status_code // 100 != 3: + logger.debug(f'Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}') + return + entra_id_endpoint = resp.headers.get('Location') + if entra_id_endpoint is None: + logger.debug(f'No Location header in response from {login_url}') + return + url = urllib.parse.urlparse(entra_id_endpoint) + path_segments = url.path.split('/') + if len(path_segments) < 2: + logger.debug(f'Invalid path in Location header: {url.path}') + return + self.azure_tenant_id = path_segments[1] + logger.debug(f'Loaded tenant ID: {self.azure_tenant_id}') + def _set_inner_config(self, keyword_args: Dict[str, any]): for attr in self.attributes(): if attr.name not in keyword_args: diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 8738d5116..27237376a 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -18,7 +18,7 @@ from google.auth.transport.requests import Request from google.oauth2 import service_account -from .azure import add_sp_management_token, add_workspace_id_header, _load_azure_tenant_id +from .azure import add_sp_management_token, add_workspace_id_header from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token, TokenCache, TokenSource) @@ -246,7 +246,6 @@ def token_source_for(resource: str) -> TokenSource: endpoint_params={"resource": resource}, use_params=True) - _load_azure_tenant_id(cfg) _ensure_host_present(cfg, token_source_for) logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id) inner = token_source_for(cfg.effective_azure_login_app_id) @@ -432,9 +431,9 @@ def refresh(self) -> Token: class AzureCliTokenSource(CliTokenSource): """ Obtain the token granted by `az login` CLI command """ - def __init__(self, resource: str, subscription: str = "", tenant: str = None): + def __init__(self, resource: str, subscription: Optional[str] = None, tenant: Optional[str] = None): cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"] - if subscription != "": + if subscription is not None: cmd.append("--subscription") cmd.append(subscription) if tenant: @@ -466,8 +465,8 @@ def is_human_user(self) -> bool: @staticmethod def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource': subscription = AzureCliTokenSource.get_subscription(cfg) - if cfg.azure_tenant_id == "" and subscription != "": - token_source = AzureCliTokenSource(resource, subscription) + if subscription is not None: + token_source = AzureCliTokenSource(resource, subscription=subscription, tenant=cfg.azure_tenant_id) try: # This will fail if the user has access to the workspace, but not to the subscription # itself. @@ -477,26 +476,25 @@ def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource': except OSError: logger.warning("Failed to get token for subscription. Using resource only token.") - token_source = AzureCliTokenSource(resource, cfg.azure_tenant_id) + token_source = AzureCliTokenSource(resource, subscription=None, tenant=cfg.azure_tenant_id) token_source.token() return token_source @staticmethod - def get_subscription(cfg: 'Config') -> str: + def get_subscription(cfg: 'Config') -> Optional[str]: resource = cfg.azure_workspace_resource_id if resource is None or resource == "": - return "" + return None components = resource.split('/') if len(components) < 3: logger.warning("Invalid azure workspace resource ID") - return "" + return None return components[2] @credentials_strategy('azure-cli', ['is_azure']) def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]: """ Adds refreshed OAuth token granted by `az login` command to every request. """ - _load_azure_tenant_id(cfg) token_source = None mgmt_token_source = None try: diff --git a/tests/test_azure.py b/tests/test_azure.py deleted file mode 100644 index 9d1b1d2fb..000000000 --- a/tests/test_azure.py +++ /dev/null @@ -1,20 +0,0 @@ -from databricks.sdk.config import Config -import os - -__tests__ = os.path.dirname(__file__) - - -def test_load_azure_tenant_id(requests_mock, monkeypatch): - monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') - mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://login.microsoftonline.com/abc123xyz/oauth2/authorize'}) - cfg = Config(host="https://abc123.azuredatabricks.net") - assert cfg.azure_tenant_id == 'abc123xyz' - assert mock.called_once - - -def test_load_azure_tenant_id_tenant_id_set(requests_mock, monkeypatch): - monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') - mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://login.microsoftonline.com/abc123xyz/oauth2/authorize'}) - cfg = Config(host="https://abc123.azuredatabricks.net", azure_tenant_id="123456789") - assert cfg.azure_tenant_id == '123456789' - assert mock.call_count == 0 diff --git a/tests/test_config.py b/tests/test_config.py index 4d3a0ebef..701333e38 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -8,6 +8,9 @@ from .conftest import noop_credentials +import os + +__tests__ = os.path.dirname(__file__) def test_config_supports_legacy_credentials_provider(): c = Config(credentials_provider=noop_credentials, product='foo', product_version='1.2.3') @@ -74,3 +77,35 @@ def test_config_copy_deep_copies_user_agent_other_info(config): assert "blueprint/0.4.6" in config.user_agent assert "blueprint/0.4.6" in config_copy.user_agent useragent._reset_extra(original_extra) + + +def test_load_azure_tenant_id_404(requests_mock, monkeypatch): + monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=404) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id is None + assert mock.called_once + + +def test_load_azure_tenant_id_no_location_header(requests_mock, monkeypatch): + monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id is None + assert mock.called_once + + +def test_load_azure_tenant_id_unparsable_location_header(requests_mock, monkeypatch): + monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://unexpected-location'}) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id is None + assert mock.called_once + + +def test_load_azure_tenant_id_happy_path(requests_mock, monkeypatch): + monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://login.microsoftonline.com/tenant-id/oauth2/authorize'}) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id == 'tenant-id' + assert mock.called_once From 5f4e8f6ecc51f46061603184ba9f9aabab9b4844 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 18 Jul 2024 09:51:13 +0200 Subject: [PATCH 4/8] only when needed --- databricks/sdk/config.py | 6 ++++-- databricks/sdk/credentials_provider.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 0d9823231..302ca5c92 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -119,7 +119,6 @@ def __init__(self, self._load_from_env() self._known_file_config_loader() self._fix_host_if_needed() - self._load_azure_tenant_id() self._validate() self.init_auth() self._init_product(product, product_version) @@ -364,7 +363,10 @@ def _fix_host_if_needed(self): self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment)) - def _load_azure_tenant_id(self): + def load_azure_tenant_id(self): + """[Internal] Load the Azure tenant ID from the Azure Databricks login page. + + If the tenant ID is already set, this method does nothing.""" if not self.is_azure or self.azure_tenant_id is not None or self.host is None: return login_url = f'{self.host}/aad/auth' diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 27237376a..3c1f47fb4 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -246,6 +246,7 @@ def token_source_for(resource: str) -> TokenSource: endpoint_params={"resource": resource}, use_params=True) + cfg.load_azure_tenant_id() _ensure_host_present(cfg, token_source_for) logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id) inner = token_source_for(cfg.effective_azure_login_app_id) @@ -495,6 +496,7 @@ def get_subscription(cfg: 'Config') -> Optional[str]: @credentials_strategy('azure-cli', ['is_azure']) def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]: """ Adds refreshed OAuth token granted by `az login` command to every request. """ + cfg.load_azure_tenant_id() token_source = None mgmt_token_source = None try: From 82557229f747e7719d0bc14fbf3cc28f276d3e37 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 18 Jul 2024 09:52:26 +0200 Subject: [PATCH 5/8] flip order --- databricks/sdk/credentials_provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 3c1f47fb4..4f3896d2b 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -246,8 +246,8 @@ def token_source_for(resource: str) -> TokenSource: endpoint_params={"resource": resource}, use_params=True) - cfg.load_azure_tenant_id() _ensure_host_present(cfg, token_source_for) + cfg.load_azure_tenant_id() logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id) inner = token_source_for(cfg.effective_azure_login_app_id) cloud = token_source_for(cfg.arm_environment.service_management_endpoint) From bd842adc0f71822a61e0dec91b7c4980e9a3de99 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 18 Jul 2024 10:00:57 +0200 Subject: [PATCH 6/8] work --- databricks/sdk/config.py | 3 ++- databricks/sdk/credentials_provider.py | 8 +++++--- tests/conftest.py | 13 +++++++++++++ tests/test_auth.py | 9 ++++++--- tests/test_auth_manual_tests.py | 15 ++++++++++----- tests/test_config.py | 13 +++++++++---- 6 files changed, 45 insertions(+), 16 deletions(-) diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 302ca5c92..ad06fc247 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -373,7 +373,8 @@ def load_azure_tenant_id(self): logger.debug(f'Loading tenant ID from {login_url}') resp = requests.get(login_url, allow_redirects=False) if resp.status_code // 100 != 3: - logger.debug(f'Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}') + logger.debug( + f'Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}') return entra_id_endpoint = resp.headers.get('Location') if entra_id_endpoint is None: diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 4f3896d2b..cfdf80e0d 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -233,11 +233,11 @@ def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenS cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}" -@oauth_credentials_strategy('azure-client-secret', - ['is_azure', 'azure_client_id', 'azure_client_secret']) +@oauth_credentials_strategy('azure-client-secret', ['is_azure', 'azure_client_id', 'azure_client_secret']) def azure_service_principal(cfg: 'Config') -> CredentialsProvider: """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request, while automatically resolving different Azure environment endpoints. """ + def token_source_for(resource: str) -> TokenSource: aad_endpoint = cfg.arm_environment.active_directory_endpoint return ClientCredentials(client_id=cfg.azure_client_id, @@ -467,7 +467,9 @@ def is_human_user(self) -> bool: def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource': subscription = AzureCliTokenSource.get_subscription(cfg) if subscription is not None: - token_source = AzureCliTokenSource(resource, subscription=subscription, tenant=cfg.azure_tenant_id) + token_source = AzureCliTokenSource(resource, + subscription=subscription, + tenant=cfg.azure_tenant_id) try: # This will fail if the user has access to the workspace, but not to the subscription # itself. diff --git a/tests/conftest.py b/tests/conftest.py index a7e520dc9..0f415ecf1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -77,3 +77,16 @@ def set_az_path(monkeypatch): monkeypatch.setenv('COMSPEC', 'C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe') else: monkeypatch.setenv('PATH', __tests__ + "/testdata:/bin") + + +@pytest.fixture +def mock_tenant(requests_mock): + + def stub_tenant_request(host, tenant_id="test-tenant-id"): + mock = requests_mock.get( + f'https://{host}/aad/auth', + status_code=302, + headers={'Location': f'https://login.microsoftonline.com/{tenant_id}/oauth2/authorize'}) + return mock + + return stub_tenant_request diff --git a/tests/test_auth.py b/tests/test_auth.py index fd73378b2..cd8f3cfc1 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -193,9 +193,10 @@ def test_config_azure_pat(): assert cfg.is_azure -def test_config_azure_cli_host(monkeypatch): +def test_config_azure_cli_host(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws') assert cfg.auth_type == 'azure-cli' @@ -229,9 +230,10 @@ def test_config_azure_cli_host_pat_conflict_with_config_file_present_without_def cfg = Config(token='x', azure_workspace_resource_id='/sub/rg/ws') -def test_config_azure_cli_host_and_resource_id(monkeypatch): +def test_config_azure_cli_host_and_resource_id(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws') assert cfg.auth_type == 'azure-cli' @@ -239,10 +241,11 @@ def test_config_azure_cli_host_and_resource_id(monkeypatch): assert cfg.is_azure -def test_config_azure_cli_host_and_resource_i_d_configuration_precedence(monkeypatch): +def test_config_azure_cli_host_and_resource_i_d_configuration_precedence(monkeypatch, mock_tenant): monkeypatch.setenv('DATABRICKS_CONFIG_PROFILE', 'justhost') set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws') assert cfg.auth_type == 'azure-cli' diff --git a/tests/test_auth_manual_tests.py b/tests/test_auth_manual_tests.py index e2874c427..34aa3a9c2 100644 --- a/tests/test_auth_manual_tests.py +++ b/tests/test_auth_manual_tests.py @@ -3,9 +3,10 @@ from .conftest import set_az_path, set_home -def test_azure_cli_workspace_header_present(monkeypatch): +def test_azure_cli_workspace_header_present(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123' cfg = Config(auth_type='azure-cli', host='https://adb-123.4.azuredatabricks.net', @@ -14,9 +15,10 @@ def test_azure_cli_workspace_header_present(monkeypatch): assert cfg.authenticate()['X-Databricks-Azure-Workspace-Resource-Id'] == resource_id -def test_azure_cli_user_with_management_access(monkeypatch): +def test_azure_cli_user_with_management_access(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123' cfg = Config(auth_type='azure-cli', host='https://adb-123.4.azuredatabricks.net', @@ -24,9 +26,10 @@ def test_azure_cli_user_with_management_access(monkeypatch): assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate() -def test_azure_cli_user_no_management_access(monkeypatch): +def test_azure_cli_user_no_management_access(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') monkeypatch.setenv('FAIL_IF', 'https://management.core.windows.net/') resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123' cfg = Config(auth_type='azure-cli', @@ -35,9 +38,10 @@ def test_azure_cli_user_no_management_access(monkeypatch): assert 'X-Databricks-Azure-SP-Management-Token' not in cfg.authenticate() -def test_azure_cli_fallback(monkeypatch): +def test_azure_cli_fallback(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') monkeypatch.setenv('FAIL_IF', 'subscription') resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123' cfg = Config(auth_type='azure-cli', @@ -46,9 +50,10 @@ def test_azure_cli_fallback(monkeypatch): assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate() -def test_azure_cli_with_warning_on_stderr(monkeypatch): +def test_azure_cli_with_warning_on_stderr(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') monkeypatch.setenv('WARN', 'this is a warning') resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123' cfg = Config(auth_type='azure-cli', diff --git a/tests/test_config.py b/tests/test_config.py index 701333e38..894dd78a3 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,3 +1,4 @@ +import os import platform import pytest @@ -8,10 +9,9 @@ from .conftest import noop_credentials -import os - __tests__ = os.path.dirname(__file__) + def test_config_supports_legacy_credentials_provider(): c = Config(credentials_provider=noop_credentials, product='foo', product_version='1.2.3') c2 = c.copy() @@ -97,7 +97,9 @@ def test_load_azure_tenant_id_no_location_header(requests_mock, monkeypatch): def test_load_azure_tenant_id_unparsable_location_header(requests_mock, monkeypatch): monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') - mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://unexpected-location'}) + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', + status_code=302, + headers={'Location': 'https://unexpected-location'}) cfg = Config(host="https://abc123.azuredatabricks.net") assert cfg.azure_tenant_id is None assert mock.called_once @@ -105,7 +107,10 @@ def test_load_azure_tenant_id_unparsable_location_header(requests_mock, monkeypa def test_load_azure_tenant_id_happy_path(requests_mock, monkeypatch): monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') - mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://login.microsoftonline.com/tenant-id/oauth2/authorize'}) + mock = requests_mock.get( + 'https://abc123.azuredatabricks.net/aad/auth', + status_code=302, + headers={'Location': 'https://login.microsoftonline.com/tenant-id/oauth2/authorize'}) cfg = Config(host="https://abc123.azuredatabricks.net") assert cfg.azure_tenant_id == 'tenant-id' assert mock.called_once From 825e144a2daa42953f1877947820b83b2a4813a3 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 18 Jul 2024 11:31:57 +0200 Subject: [PATCH 7/8] comment --- databricks/sdk/config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index ad06fc247..28d57ad42 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -380,6 +380,8 @@ def load_azure_tenant_id(self): if entra_id_endpoint is None: logger.debug(f'No Location header in response from {login_url}') return + # The Location header has the following form: https://login.microsoftonline.com//oauth2/authorize?... + # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud). url = urllib.parse.urlparse(entra_id_endpoint) path_segments = url.path.split('/') if len(path_segments) < 2: From be5edeaef80337460c11b5268e2c700b3f1b4ee4 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 18 Jul 2024 11:55:03 +0200 Subject: [PATCH 8/8] fix windows tests --- tests/test_config.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 894dd78a3..4bab85cf1 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,7 +7,7 @@ from databricks.sdk.config import Config, with_product, with_user_agent_extra from databricks.sdk.version import __version__ -from .conftest import noop_credentials +from .conftest import noop_credentials, set_az_path __tests__ = os.path.dirname(__file__) @@ -80,7 +80,7 @@ def test_config_copy_deep_copies_user_agent_other_info(config): def test_load_azure_tenant_id_404(requests_mock, monkeypatch): - monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') + set_az_path(monkeypatch) mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=404) cfg = Config(host="https://abc123.azuredatabricks.net") assert cfg.azure_tenant_id is None @@ -88,7 +88,7 @@ def test_load_azure_tenant_id_404(requests_mock, monkeypatch): def test_load_azure_tenant_id_no_location_header(requests_mock, monkeypatch): - monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') + set_az_path(monkeypatch) mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302) cfg = Config(host="https://abc123.azuredatabricks.net") assert cfg.azure_tenant_id is None @@ -96,7 +96,7 @@ def test_load_azure_tenant_id_no_location_header(requests_mock, monkeypatch): def test_load_azure_tenant_id_unparsable_location_header(requests_mock, monkeypatch): - monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') + set_az_path(monkeypatch) mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://unexpected-location'}) @@ -106,7 +106,7 @@ def test_load_azure_tenant_id_unparsable_location_header(requests_mock, monkeypa def test_load_azure_tenant_id_happy_path(requests_mock, monkeypatch): - monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') + set_az_path(monkeypatch) mock = requests_mock.get( 'https://abc123.azuredatabricks.net/aad/auth', status_code=302,