From 832d55430aa14d562e0f2f7339ead31fb3edb5b3 Mon Sep 17 00:00:00 2001 From: Serge Smertin <259697+nfx@users.noreply.github.com> Date: Tue, 28 Nov 2023 13:44:18 +0100 Subject: [PATCH] Introduce `DatabricksEnvironment` and fix Azure MSI auth from ACR, where IMDS doesn't give host environment information (#700) ## Changes This PR allows determining Azure Environment from a Databricks account or workspace hostname, removing the need for a separate configuration/environment variable and complexities related to Azure MSI from within ACR. Similar functionality in Python SDK: - https://github.com/databricks/databricks-sdk-py/pull/390 Stacked on top of: - https://github.com/databricks/databricks-sdk-go/pull/699 ## Tests - [x] `make test` passing - [x] `make fmt` applied - [ ] relevant integration tests applied --- config/auth_azure_cli.go | 10 +-- config/auth_azure_client_secret.go | 17 +++-- config/auth_azure_msi.go | 52 ++-------------- config/auth_permutations_test.go | 8 +++ config/azure.go | 55 +++-------------- config/azure_test.go | 19 ------ config/config.go | 23 ++++--- config/environments.go | 99 ++++++++++++++++++++++++++++++ 8 files changed, 146 insertions(+), 137 deletions(-) delete mode 100644 config/azure_test.go create mode 100644 config/environments.go diff --git a/config/auth_azure_cli.go b/config/auth_azure_cli.go index ff8762748..6fcf8419f 100644 --- a/config/auth_azure_cli.go +++ b/config/auth_azure_cli.go @@ -30,7 +30,7 @@ func (c AzureCliCredentials) Name() string { // implementing azureHostResolver for ensureWorkspaceUrl to work func (c AzureCliCredentials) tokenSourceFor( - ctx context.Context, cfg *Config, env azureEnvironment, resource string) oauth2.TokenSource { + ctx context.Context, cfg *Config, _, resource string) oauth2.TokenSource { return &azureCliTokenSource{resource: resource} } @@ -44,11 +44,7 @@ func (c AzureCliCredentials) tokenSourceFor( // If the user can't access the service management endpoint, we assume they are in case 2 and do not pass the service // management token. Otherwise, we always pass the service management token. func (c AzureCliCredentials) getVisitor(ctx context.Context, cfg *Config, inner oauth2.TokenSource) (func(*http.Request) error, error) { - env, err := cfg.GetAzureEnvironment() - if err != nil { - return nil, err - } - ts := &azureCliTokenSource{env.ServiceManagementEndpoint, ""} + ts := &azureCliTokenSource{cfg.Environment().AzureServiceManagementEndpoint(), ""} t, err := ts.Token() if err != nil { logger.Debugf(ctx, "Not including service management token in headers: %v", err) @@ -63,7 +59,7 @@ func (c AzureCliCredentials) Configure(ctx context.Context, cfg *Config) (func(* return nil, nil } // Eagerly get a token to fail fast in case the user is not logged in with the Azure CLI. - ts := &azureCliTokenSource{cfg.getAzureLoginAppID(), cfg.AzureResourceID} + ts := &azureCliTokenSource{cfg.Environment().azureApplicationID, cfg.AzureResourceID} t, err := ts.Token() if err != nil { if strings.Contains(err.Error(), "No subscription found") { diff --git a/config/auth_azure_client_secret.go b/config/auth_azure_client_secret.go index bc2c55592..b6a0a3732 100644 --- a/config/auth_azure_client_secret.go +++ b/config/auth_azure_client_secret.go @@ -21,11 +21,11 @@ func (c AzureClientSecretCredentials) Name() string { } func (c AzureClientSecretCredentials) tokenSourceFor( - ctx context.Context, cfg *Config, env azureEnvironment, resource string) oauth2.TokenSource { + ctx context.Context, cfg *Config, aadEndpoint, resource string) oauth2.TokenSource { return (&clientcredentials.Config{ ClientID: cfg.AzureClientID, ClientSecret: cfg.AzureClientSecret, - TokenURL: fmt.Sprintf("%s%s/oauth2/token", env.ActiveDirectoryEndpoint, cfg.AzureTenantID), + TokenURL: fmt.Sprintf("%s%s/oauth2/token", aadEndpoint, cfg.AzureTenantID), EndpointParams: url.Values{ "resource": []string{resource}, }, @@ -43,18 +43,17 @@ func (c AzureClientSecretCredentials) Configure(ctx context.Context, cfg *Config if !cfg.IsAzure() { return nil, nil } - env, err := cfg.GetAzureEnvironment() - if err != nil { - return nil, err - } ctx = httpclient.DefaultClient.InContextForOAuth2(ctx) - err = cfg.azureEnsureWorkspaceUrl(ctx, c) + err := cfg.azureEnsureWorkspaceUrl(ctx, c) if err != nil { return nil, fmt.Errorf("resolve host: %w", err) } logger.Infof(ctx, "Generating AAD token for Service Principal (%s)", cfg.AzureClientID) refreshCtx := context.Background() - inner := azureReuseTokenSource(nil, c.tokenSourceFor(refreshCtx, cfg, env, cfg.getAzureLoginAppID())) - management := azureReuseTokenSource(nil, c.tokenSourceFor(refreshCtx, cfg, env, env.ServiceManagementEndpoint)) + env := cfg.Environment() + aadEndpoint := env.AzureActiveDirectoryEndpoint() + managementEndpoint := env.AzureServiceManagementEndpoint() + inner := azureReuseTokenSource(nil, c.tokenSourceFor(refreshCtx, cfg, aadEndpoint, env.azureApplicationID)) + management := azureReuseTokenSource(nil, c.tokenSourceFor(refreshCtx, cfg, aadEndpoint, managementEndpoint)) return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil } diff --git a/config/auth_azure_msi.go b/config/auth_azure_msi.go index 9f2218ec2..5ffd1e3d2 100644 --- a/config/auth_azure_msi.go +++ b/config/auth_azure_msi.go @@ -31,70 +31,28 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(* if !cfg.IsAzure() || !cfg.AzureUseMSI || (cfg.AzureResourceID == "" && !cfg.IsAccountClient()) { return nil, nil } - env, err := c.getInstanceEnvironment(ctx) - if err != nil { - return nil, err - } + env := cfg.Environment() ctx = httpclient.DefaultClient.InContextForOAuth2(ctx) if !cfg.IsAccountClient() { - err = cfg.azureEnsureWorkspaceUrl(ctx, c) + err := cfg.azureEnsureWorkspaceUrl(ctx, c) if err != nil { return nil, fmt.Errorf("resolve host: %w", err) } } logger.Debugf(ctx, "Generating AAD token via Azure MSI") inner := azureReuseTokenSource(nil, azureMsiTokenSource{ - resource: cfg.getAzureLoginAppID(), + resource: env.azureApplicationID, clientId: cfg.AzureClientID, }) management := azureReuseTokenSource(nil, azureMsiTokenSource{ - resource: env.ServiceManagementEndpoint, + resource: env.AzureServiceManagementEndpoint(), clientId: cfg.AzureClientID, }) return azureVisitor(cfg, serviceToServiceVisitor(inner, management, xDatabricksAzureSpManagementToken)), nil } -func (c AzureMsiCredentials) getInstanceEnvironment(ctx context.Context) (*azureEnvironment, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, - fmt.Sprintf("%s/instance", instanceMetadataPrefix), nil) - if err != nil { - return nil, fmt.Errorf("metadata request: %w", err) - } - query := req.URL.Query() - query.Add("api-version", "2021-12-13") - req.URL.RawQuery = query.Encode() - req.Header.Add("Metadata", "true") - res, err := http.DefaultClient.Do(req) - if err != nil { - return nil, fmt.Errorf("metadata response: %w", err) - } - raw, err := io.ReadAll(res.Body) - if err != nil { - return nil, fmt.Errorf("metadata read: %w", err) - } - if res.StatusCode != 200 { - return nil, fmt.Errorf("metadata error: %s", raw) - } - var metadata struct { - Compute struct { - Environment string `json:"azEnvironment"` - } `json:"compute"` - } - err = json.Unmarshal(raw, &metadata) - if err != nil { - return nil, fmt.Errorf("metadata parse: %w", err) - } - for _, v := range azureEnvironments { - if v.Name == metadata.Compute.Environment { - return &v, nil - } - } - return nil, fmt.Errorf("cannot determine environment: %s", - metadata.Compute.Environment) -} - // implementing azureHostResolver for ensureWorkspaceUrl to work -func (c AzureMsiCredentials) tokenSourceFor(_ context.Context, cfg *Config, _ azureEnvironment, resource string) oauth2.TokenSource { +func (c AzureMsiCredentials) tokenSourceFor(_ context.Context, cfg *Config, _, resource string) oauth2.TokenSource { return azureMsiTokenSource{ resource: resource, clientId: cfg.AzureClientID, diff --git a/config/auth_permutations_test.go b/config/auth_permutations_test.go index d37dd5145..daab9471f 100644 --- a/config/auth_permutations_test.go +++ b/config/auth_permutations_test.go @@ -113,6 +113,14 @@ func (cf configFixture) configureProviderAndReturnConfig(t *testing.T) (*Config, AzureResourceID: cf.AzureResourceID, AuthType: cf.AuthType, } + if client.IsAzure() { + client.DatabricksEnvironments = append(client.DatabricksEnvironments, DatabricksEnvironment{ + Cloud: CloudAzure, + dnsZone: cf.Host, + azureApplicationID: "abc", + azureEnvironment: &publicCloud, + }) + } err := client.Authenticate(&http.Request{Header: http.Header{}}) if err != nil { return nil, err diff --git a/config/azure.go b/config/azure.go index aa93f8140..28b90304c 100644 --- a/config/azure.go +++ b/config/azure.go @@ -3,7 +3,6 @@ package config import ( "context" "fmt" - "strings" "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" @@ -20,73 +19,45 @@ type azureEnvironment struct { // based on github.com/Azure/go-autorest/autorest/azure/azureEnvironments.go var ( publicCloud = azureEnvironment{ - Name: "AzurePublicCloud", + Name: "PUBLIC", ServiceManagementEndpoint: "https://management.core.windows.net/", ResourceManagerEndpoint: "https://management.azure.com/", ActiveDirectoryEndpoint: "https://login.microsoftonline.com/", } usGovernmentCloud = azureEnvironment{ - Name: "AzureUSGovernmentCloud", + Name: "USGOVERNMENT", ServiceManagementEndpoint: "https://management.core.usgovcloudapi.net/", ResourceManagerEndpoint: "https://management.usgovcloudapi.net/", ActiveDirectoryEndpoint: "https://login.microsoftonline.us/", } chinaCloud = azureEnvironment{ - Name: "AzureChinaCloud", + Name: "CHINA", ServiceManagementEndpoint: "https://management.core.chinacloudapi.cn/", ResourceManagerEndpoint: "https://management.chinacloudapi.cn/", ActiveDirectoryEndpoint: "https://login.chinacloudapi.cn/", } - - germanCloud = azureEnvironment{ - Name: "AzureGermanCloud", - ServiceManagementEndpoint: "https://management.core.cloudapi.de/", - ResourceManagerEndpoint: "https://management.microsoftazure.de/", - ActiveDirectoryEndpoint: "https://login.microsoftonline.de/", - } - - azureEnvironments = map[string]azureEnvironment{ - "CHINA": chinaCloud, - "GERMAN": germanCloud, - "PUBLIC": publicCloud, - "USGOVERNMENT": usGovernmentCloud, - } ) -func (c *Config) GetAzureEnvironment() (azureEnvironment, error) { - if c.AzureEnvironment == "" { - c.AzureEnvironment = "public" - } - env, ok := azureEnvironments[strings.ToUpper(c.AzureEnvironment)] - if !ok { - return env, fmt.Errorf("azure environment not found: %s", c.AzureEnvironment) - } - return env, nil -} - type azureHostResolver interface { - tokenSourceFor(ctx context.Context, cfg *Config, env azureEnvironment, resource string) oauth2.TokenSource + tokenSourceFor(ctx context.Context, cfg *Config, aadEndpoint, resource string) oauth2.TokenSource } func (c *Config) azureEnsureWorkspaceUrl(ctx context.Context, ahr azureHostResolver) error { if c.AzureResourceID == "" || c.Host != "" { return nil } - env, err := c.GetAzureEnvironment() - if err != nil { - return err - } + azureEnv := c.Environment().azureEnvironment // azure resource ID can also be used in lieu of host by some of the clients, like Terraform - management := ahr.tokenSourceFor(ctx, c, env, env.ResourceManagerEndpoint) + management := ahr.tokenSourceFor(ctx, c, azureEnv.ActiveDirectoryEndpoint, azureEnv.ResourceManagerEndpoint) var workspaceMetadata struct { Properties struct { WorkspaceURL string `json:"workspaceUrl"` } `json:"properties"` } - requestURL := env.ResourceManagerEndpoint + c.AzureResourceID + "?api-version=2018-04-01" - err = httpclient.DefaultClient.Do(ctx, "GET", requestURL, + requestURL := azureEnv.ResourceManagerEndpoint + c.AzureResourceID + "?api-version=2018-04-01" + err := httpclient.DefaultClient.Do(ctx, "GET", requestURL, httpclient.WithResponseUnmarshal(&workspaceMetadata), httpclient.WithTokenSource(management), ) @@ -97,13 +68,3 @@ func (c *Config) azureEnsureWorkspaceUrl(ctx context.Context, ahr azureHostResol logger.Debugf(ctx, "Discovered workspace url: %s", c.Host) return nil } - -// Resource ID of the Azure application we need to log in. -const azureDatabricksLoginAppID string = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d" - -func (c *Config) getAzureLoginAppID() string { - if c.AzureLoginAppID != "" { - return c.AzureLoginAppID - } - return azureDatabricksLoginAppID -} diff --git a/config/azure_test.go b/config/azure_test.go deleted file mode 100644 index 5482c088e..000000000 --- a/config/azure_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestAzureLoginAppID(t *testing.T) { - var cfg Config - - // It's not set - cfg = Config{} - assert.Equal(t, azureDatabricksLoginAppID, cfg.getAzureLoginAppID()) - - // It's set - cfg = Config{AzureLoginAppID: "foobar"} - assert.Equal(t, "foobar", cfg.getAzureLoginAppID()) -} diff --git a/config/config.go b/config/config.go index abf42b5eb..3f307e9be 100644 --- a/config/config.go +++ b/config/config.go @@ -70,10 +70,15 @@ type Config struct { AzureClientID string `name:"azure_client_id" env:"ARM_CLIENT_ID" auth:"azure"` AzureTenantID string `name:"azure_tenant_id" env:"ARM_TENANT_ID" auth:"azure"` - // AzureEnvironment (Public, UsGov, China, Germany) has specific set of API endpoints. + // AzureEnvironment (PUBLIC, USGOVERNMENT, CHINA) has specific set of API endpoints. Starting from v0.26.0, + // the environment is determined based on the workspace hostname, if it's specified. AzureEnvironment string `name:"azure_environment" env:"ARM_ENVIRONMENT"` - // Azure Login Application ID. Must be set if authenticating for non-production workspaces. + // Azure Login Application ID. Must be set if authenticating for non-production workspaces. Starting from v0.26.0, + // the correct Azure Login App ID is determined based on the Azure Databricks Workspace hostname. + // + // Deprecated: this configuration property no longer has any effect and will be removed in the future + // versions of Go SDK. AzureLoginAppID string `name:"azure_login_app_id" env:"DATABRICKS_AZURE_LOGIN_APP_ID" auth:"azure"` ClientID string `name:"client_id" env:"DATABRICKS_CLIENT_ID" auth:"oauth"` @@ -108,6 +113,9 @@ type Config struct { // HTTPTransport can be overriden for unit testing and together with tooling like https://github.com/google/go-replayers HTTPTransport http.RoundTripper + // Metadata about the environment, where Databricks is deployed. Reserved for internal use. + DatabricksEnvironments []DatabricksEnvironment + Loaders []Loader // marker for configuration resolving @@ -142,16 +150,15 @@ func (c *Config) Authenticate(r *http.Request) error { // IsAzure returns if the client is configured for Azure Databricks. func (c *Config) IsAzure() bool { - isAzureHost := strings.Contains(c.Host, ".azuredatabricks.net") || - strings.Contains(c.Host, "databricks.azure.cn") || - strings.Contains(c.Host, ".databricks.azure.us") - - return isAzureHost || c.AzureResourceID != "" + if c.AzureResourceID != "" { + return true + } + return c.Environment().Cloud == CloudAzure } // IsGcp returns if the client is configured for Databricks on Google Cloud. func (c *Config) IsGcp() bool { - return strings.Contains(c.Host, ".gcp.databricks.com") + return c.Environment().Cloud == CloudGCP } // IsAws returns if the client is configured for Databricks on AWS. diff --git a/config/environments.go b/config/environments.go new file mode 100644 index 000000000..4c49a90e5 --- /dev/null +++ b/config/environments.go @@ -0,0 +1,99 @@ +package config + +import ( + "fmt" + "strings" +) + +type Cloud string + +const ( + CloudUnspecified Cloud = "Unspecified" + CloudAWS Cloud = "AWS" + CloudAzure Cloud = "Azure" + CloudGCP Cloud = "GCP" +) + +type DatabricksEnvironment struct { + Cloud Cloud + dnsZone string + azureApplicationID string + azureEnvironment *azureEnvironment +} + +func (de DatabricksEnvironment) DeploymentURL(name string) string { + return fmt.Sprintf("https://%s%s", name, de.dnsZone) +} + +func (de DatabricksEnvironment) AzureServiceManagementEndpoint() string { + if de.azureEnvironment == nil { + return "" + } + return de.azureEnvironment.ServiceManagementEndpoint +} + +func (de DatabricksEnvironment) AzureResourceManagerEndpoint() string { + if de.azureEnvironment == nil { + return "" + } + return de.azureEnvironment.ResourceManagerEndpoint +} + +func (de DatabricksEnvironment) AzureActiveDirectoryEndpoint() string { + if de.azureEnvironment == nil { + return "" + } + return de.azureEnvironment.ActiveDirectoryEndpoint +} + +// we default to AWS Prod environment since this case will be a hit for PVC +var defaultEnvironment = DatabricksEnvironment{ + Cloud: CloudAWS, + dnsZone: ".cloud.databricks.com", +} + +var envs = []DatabricksEnvironment{ + {Cloud: CloudUnspecified, dnsZone: "localhost"}, + + {Cloud: CloudAWS, dnsZone: ".dev.databricks.com"}, + {Cloud: CloudAWS, dnsZone: ".staging.cloud.databricks.com"}, + {Cloud: CloudAWS, dnsZone: ".cloud.databricks.us"}, + defaultEnvironment, + + {Cloud: CloudAzure, dnsZone: ".dev.azuredatabricks.net", azureApplicationID: "62a912ac-b58e-4c1d-89ea-b2dbfc7358fc", azureEnvironment: &publicCloud}, + {Cloud: CloudAzure, dnsZone: ".staging.azuredatabricks.net", azureApplicationID: "4a67d088-db5c-48f1-9ff2-0aace800ae68", azureEnvironment: &publicCloud}, + {Cloud: CloudAzure, dnsZone: ".azuredatabricks.net", azureApplicationID: "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d", azureEnvironment: &publicCloud}, + {Cloud: CloudAzure, dnsZone: ".databricks.azure.us", azureApplicationID: "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d", azureEnvironment: &usGovernmentCloud}, + {Cloud: CloudAzure, dnsZone: ".databricks.azure.cn", azureApplicationID: "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d", azureEnvironment: &chinaCloud}, + + {Cloud: CloudGCP, dnsZone: ".dev.gcp.databricks.com"}, + {Cloud: CloudGCP, dnsZone: ".staging.gcp.databricks.com"}, + {Cloud: CloudGCP, dnsZone: ".gcp.databricks.com"}, +} + +func (c *Config) Environment() DatabricksEnvironment { + if c.Host == "" && c.AzureResourceID != "" { + // azure resource ID can also be used in lieu of host by some + // of the clients, like Terraform + azureEnv := strings.ToUpper(c.AzureEnvironment) + if azureEnv == "" { + azureEnv = "PUBLIC" + } + for _, v := range envs { + if v.Cloud != CloudAzure { + continue + } + if v.azureEnvironment.Name != azureEnv { + continue + } + return v + } + } + hostname := c.CanonicalHostName() + for _, e := range append(c.DatabricksEnvironments, envs...) { + if strings.HasSuffix(hostname, e.dnsZone) { + return e + } + } + return defaultEnvironment +}